seperate expr handling logic to a different file to prevent circular import, add format strings

This commit is contained in:
Pragyansh Chaturvedi
2025-09-11 02:51:24 +05:30
parent cfdc14137c
commit 1936ded032
4 changed files with 149 additions and 58 deletions

View File

@ -34,9 +34,9 @@ def hello_again(ctx: c_void_p) -> c_int64:
y = False y = False
if x > 0: if x > 0:
if x < 2: if x < 2:
print("we prevailed") print(f"we prevailed {x}")
else: else:
print("we did not prevail") print(f"we did not prevail {x}")
ts = ktime() ts = ktime()
last().update(key, ts) last().update(key, ts)

View File

@ -1,8 +1,9 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from .expr_pass import eval_expr
def bpf_ktime_get_ns_emitter(call, module, builder, func): def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, local_sym_tab=None):
""" """
Emit LLVM IR for bpf_ktime_get_ns helper function call. Emit LLVM IR for bpf_ktime_get_ns helper function call.
""" """
@ -62,10 +63,87 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
return result return result
def bpf_printk_emitter(call, module, builder, func): def bpf_printk_emitter(call, map_ptr, module, builder, func, local_sym_tab=None):
if not hasattr(func, "_fmt_counter"): if not hasattr(func, "_fmt_counter"):
func._fmt_counter = 0 func._fmt_counter = 0
if not call.args:
raise ValueError("print expects at least one argument")
if isinstance(call.args[0], ast.JoinedStr):
fmt_parts = []
exprs = []
for value in call.args[0].values:
if isinstance(value, ast.Constant):
if isinstance(value.value, str):
fmt_parts.append(value.value)
elif isinstance(value.value, int):
fmt_parts.append("%lld")
exprs.append(ir.Constant(ir.IntType(64), value.value))
else:
raise NotImplementedError(
"Only string and integer constants are supported in f-string.")
elif isinstance(value, ast.FormattedValue):
# Assume int for now
fmt_parts.append("%d")
if isinstance(value.value, ast.Name):
exprs.append(value.value)
else:
raise NotImplementedError(
"Only simple variable names are supported in formatted values.")
else:
raise NotImplementedError(
"Unsupported value type in f-string.")
fmt_str = "".join(fmt_parts) + "\n" + "\0"
fmt_name = f"{func.name}____fmt{func._fmt_counter}"
func._fmt_counter += 1
fmt_gvar = ir.GlobalVariable(
module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name)
fmt_gvar.global_constant = True
fmt_gvar.initializer = ir.Constant(
ir.ArrayType(ir.IntType(8), len(fmt_str)),
bytearray(fmt_str.encode("utf8"))
)
fmt_gvar.linkage = "internal"
fmt_gvar.align = 1
fmt_ptr = builder.bitcast(fmt_gvar, ir.PointerType())
args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))]
# Only 3 args supported in bpf_printk
if len(exprs) > 3:
print(
"Warning: bpf_printk supports up to 3 arguments, extra arguments will be ignored.")
for expr in exprs[:3]:
val = eval_expr(func, module, builder, expr, local_sym_tab, None)
if val:
if isinstance(val.type, ir.PointerType):
val = builder.ptrtoint(val, ir.IntType(64))
elif isinstance(val.type, ir.IntType):
if val.type.width < 64:
val = builder.sext(val, ir.IntType(64))
else:
print(
"Warning: Only integer and pointer types are supported in bpf_printk arguments. Others will be converted to 0.")
val = ir.Constant(ir.IntType(64), 0)
args.append(val)
else:
print(
"Warning: Failed to evaluate expression for bpf_printk argument. It will be converted to 0.")
args.append(ir.Constant(ir.IntType(64), 0))
fn_type = ir.FunctionType(ir.IntType(
64), [ir.PointerType(), ir.IntType(32)], var_arg=True)
fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType(64), 6)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
return builder.call(fn_ptr, args, tail=True)
for arg in call.args: for arg in call.args:
if isinstance(arg, ast.Constant) and isinstance(arg.value, str): if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
fmt_str = arg.value + "\n" + "\0" fmt_str = arg.value + "\n" + "\0"
@ -93,6 +171,7 @@ def bpf_printk_emitter(call, module, builder, func):
builder.call(fn_ptr, [fmt_ptr, ir.Constant( builder.call(fn_ptr, [fmt_ptr, ir.Constant(
ir.IntType(32), len(fmt_str))], tail=True) ir.IntType(32), len(fmt_str))], tail=True)
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None): def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None):
""" """
Emit LLVM IR for bpf_map_update_elem helper function call. Emit LLVM IR for bpf_map_update_elem helper function call.
@ -101,11 +180,11 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
if not call.args or len(call.args) < 2 or len(call.args) > 3: if not call.args or len(call.args) < 2 or len(call.args) > 3:
raise ValueError("Map update expects 2 or 3 arguments (key, value, flags), got " raise ValueError("Map update expects 2 or 3 arguments (key, value, flags), got "
f"{len(call.args)}") f"{len(call.args)}")
key_arg = call.args[0] key_arg = call.args[0]
value_arg = call.args[1] value_arg = call.args[1]
flags_arg = call.args[2] if len(call.args) > 2 else None flags_arg = call.args[2] if len(call.args) > 2 else None
# Handle key # Handle key
if isinstance(key_arg, ast.Name): if isinstance(key_arg, ast.Name):
key_name = key_arg.id key_name = key_arg.id
@ -124,7 +203,7 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple variable names and integer constants are supported as keys in map update.") "Only simple variable names and integer constants are supported as keys in map update.")
# Handle value # Handle value
if isinstance(value_arg, ast.Name): if isinstance(value_arg, ast.Name):
value_name = value_arg.id value_name = value_arg.id
@ -143,7 +222,7 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple variable names and integer constants are supported as values in map update.") "Only simple variable names and integer constants are supported as values in map update.")
# Handle flags argument (defaults to 0) # Handle flags argument (defaults to 0)
if flags_arg is not None: if flags_arg is not None:
if isinstance(flags_arg, ast.Constant) and isinstance(flags_arg.value, int): if isinstance(flags_arg, ast.Constant) and isinstance(flags_arg.value, int):
@ -162,7 +241,7 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
"Only integer constants and simple variable names are supported as flags in map update.") "Only integer constants and simple variable names are supported as flags in map update.")
else: else:
flags_val = 0 flags_val = 0
if key_ptr is None or value_ptr is None: if key_ptr is None or value_ptr is None:
raise ValueError("Key pointer or value pointer is None.") raise ValueError("Key pointer or value pointer is None.")
@ -173,20 +252,22 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
var_arg=False var_arg=False
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
# helper id # helper id
fn_addr = ir.Constant(ir.IntType(64), 2) fn_addr = ir.Constant(ir.IntType(64), 2)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
if isinstance(flags_val, int): if isinstance(flags_val, int):
flags_const = ir.Constant(ir.IntType(64), flags_val) flags_const = ir.Constant(ir.IntType(64), flags_val)
else: else:
flags_const = flags_val flags_const = flags_val
result = builder.call(fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False) result = builder.call(
fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False)
return result return result
helper_func_list = { helper_func_list = {
"lookup": bpf_map_lookup_elem_emitter, "lookup": bpf_map_lookup_elem_emitter,
"print": bpf_printk_emitter, "print": bpf_printk_emitter,
@ -200,7 +281,7 @@ def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_
func_name = call.func.id func_name = call.func.id
if func_name in helper_func_list: if func_name in helper_func_list:
# it is not a map method call # it is not a map method call
return helper_func_list[func_name](call, module, builder, func) return helper_func_list[func_name](call, None, module, builder, func, local_sym_tab)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Function {func_name} is not implemented as a helper function.") f"Function {func_name} is not implemented as a helper function.")

49
pythonbpf/expr_pass.py Normal file
View File

@ -0,0 +1,49 @@
import ast
from llvmlite import ir
def eval_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
print(f"Evaluating expression: {expr}")
if isinstance(expr, ast.Name):
if expr.id in local_sym_tab:
var = local_sym_tab[expr.id]
val = builder.load(var)
return val
else:
print(f"Undefined variable {expr.id}")
return None
elif isinstance(expr, ast.Constant):
if isinstance(expr.value, int):
return ir.Constant(ir.IntType(64), expr.value)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value))
else:
print("Unsupported constant type")
return None
elif isinstance(expr, ast.Call):
# delayed import to avoid circular dependency
from .bpf_helper_handler import helper_func_list, handle_helper_call
if isinstance(expr.func, ast.Name):
# check for helpers first
if expr.func.id in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab)
elif isinstance(expr.func, ast.Attribute):
if isinstance(expr.func.value, ast.Call) and isinstance(expr.func.value.func, ast.Name):
method_name = expr.func.attr
if method_name in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab)
print("Unsupported expression evaluation")
return None
def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
"""Handle expression statements in the function body."""
print(f"Handling expression: {ast.dump(expr)}")
call = expr.value
if isinstance(call, ast.Call):
eval_expr(func, module, builder, call, local_sym_tab, map_sym_tab)
else:
print("Unsupported expression type")

View File

@ -1,9 +1,11 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
from .bpf_helper_handler import helper_func_list, handle_helper_call from .bpf_helper_handler import helper_func_list, handle_helper_call
from .type_deducer import ctypes_to_ir from .type_deducer import ctypes_to_ir
from .binary_ops import handle_binary_op from .binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr
def get_probe_string(func_node): def get_probe_string(func_node):
@ -22,6 +24,7 @@ def get_probe_string(func_node):
return arg.value return arg.value
return "helper" return "helper"
def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab): def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab):
"""Handle assignment statements in the function body.""" """Handle assignment statements in the function body."""
if len(stmt.targets) != 1: if len(stmt.targets) != 1:
@ -96,54 +99,12 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab):
else: else:
print("Unsupported assignment call function type") print("Unsupported assignment call function type")
elif isinstance(rval, ast.BinOp): elif isinstance(rval, ast.BinOp):
handle_binary_op(rval, module, builder, func, local_sym_tab, map_sym_tab) handle_binary_op(rval, module, builder, func,
local_sym_tab, map_sym_tab)
else: else:
print("Unsupported assignment value type") print("Unsupported assignment value type")
def eval_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
if isinstance(expr, ast.Name):
if expr.id in local_sym_tab:
var = local_sym_tab[expr.id]
val = builder.load(var)
return val
else:
print(f"Undefined variable {expr.id}")
return None
elif isinstance(expr, ast.Constant):
if isinstance(expr.value, int):
return ir.Constant(ir.IntType(64), expr.value)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value))
else:
print("Unsupported constant type")
return None
elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name):
# check for helpers first
if expr.func.id in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab)
elif isinstance(expr.func, ast.Attribute):
if isinstance(expr.func.value, ast.Call) and isinstance(expr.func.value.func, ast.Name):
method_name = expr.func.attr
if method_name in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab)
print("Unsupported expression evaluation")
return None
def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
"""Handle expression statements in the function body."""
print(f"Handling expression: {ast.dump(expr)}")
call = expr.value
if isinstance(call, ast.Call):
eval_expr(func, module, builder, call, local_sym_tab, map_sym_tab)
else:
print("Unsupported expression type")
def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab): def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
if isinstance(cond, ast.Constant): if isinstance(cond, ast.Constant):
if isinstance(cond.value, bool): if isinstance(cond.value, bool):