From 2cf68f64735e48d0c90faa109842956f3da9f221 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Sun, 12 Oct 2025 07:57:55 +0530 Subject: [PATCH] Allow map-based helpers to be used as helper args / within binops which are helper args --- pythonbpf/binary_ops.py | 4 +++ pythonbpf/functions/functions_pass.py | 38 ++++++++++++++++++++------ pythonbpf/helper/bpf_helper_handler.py | 22 +++++++++++---- pythonbpf/helper/helper_utils.py | 10 ++++--- 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/pythonbpf/binary_ops.py b/pythonbpf/binary_ops.py index a5b8dbe..6ea534b 100644 --- a/pythonbpf/binary_ops.py +++ b/pythonbpf/binary_ops.py @@ -39,6 +39,10 @@ def get_operand_value( if res is None: raise ValueError(f"Failed to evaluate call expression: {operand}") val, _ = res + logger.info(f"Evaluated expr to {val} of type {val.type}") + base_type, depth = get_base_type_and_depth(val.type) + if depth > 0: + val = deref_to_depth(func, builder, val, depth) return val raise TypeError(f"Unsupported operand type: {type(operand)}") diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 64acad4..ae3f94b 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -388,14 +388,18 @@ def process_stmt( return did_return -def count_temps_in_call(call_node): +def count_temps_in_call(call_node, local_sym_tab): """Count the number of temporary variables needed for a function call.""" count = 0 is_helper = False + # NOTE: We exclude print calls for now if isinstance(call_node.func, ast.Name): - if HelperHandlerRegistry.has_handler(call_node.func.id): + if ( + HelperHandlerRegistry.has_handler(call_node.func.id) + and call_node.func.id != "print" + ): is_helper = True elif isinstance(call_node.func, ast.Attribute): if HelperHandlerRegistry.has_handler(call_node.func.attr): @@ -405,10 +409,11 @@ def count_temps_in_call(call_node): return 0 for arg in call_node.args: - if ( - isinstance(arg, ast.BinOp) - or isinstance(arg, ast.Constant) - or isinstance(arg, ast.UnaryOp) + # NOTE: Count all non-name arguments + # For struct fields, if it is being passed as an argument, + # The struct object should already exist in the local_sym_tab + if not isinstance(arg, ast.Name) and not ( + isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab ): count += 1 @@ -423,11 +428,19 @@ def allocate_mem( def update_max_temps_for_stmt(stmt): nonlocal max_temps_needed + temps_needed = 0 + + if isinstance(stmt, ast.If): + for s in stmt.body: + update_max_temps_for_stmt(s) + for s in stmt.orelse: + update_max_temps_for_stmt(s) + return for node in ast.walk(stmt): if isinstance(node, ast.Call): - temps_needed = count_temps_in_call(node) - max_temps_needed = max(max_temps_needed, temps_needed) + temps_needed += count_temps_in_call(node, local_sym_tab) + max_temps_needed = max(max_temps_needed, temps_needed) for stmt in body: update_max_temps_for_stmt(stmt) @@ -460,9 +473,16 @@ def allocate_mem( logger.info("Unsupported multiassignment") continue target = stmt.targets[0] - if not isinstance(target, ast.Name): + if not isinstance(target, ast.Name) and not isinstance( + target, ast.Attribute + ): logger.info("Unsupported assignment target") continue + if isinstance(target, ast.Attribute): + logger.info( + f"Struct field {target.attr} assignment, will be handled later" + ) + continue var_name = target.id rval = stmt.value if var_name in local_sym_tab: diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index 0dd2ba4..44731d7 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -34,6 +34,7 @@ def bpf_ktime_get_ns_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_ktime_get_ns helper function call. @@ -56,6 +57,7 @@ def bpf_map_lookup_elem_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_lookup_elem helper function call. @@ -65,12 +67,16 @@ def bpf_map_lookup_elem_emitter( f"Map lookup expects exactly one argument (key), got {len(call.args)}" ) key_ptr = get_or_create_ptr_from_arg( - func, module, call.args[0], builder, local_sym_tab, struct_sym_tab + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + # TODO: I have changed the return typr to i64*, as we are + # allocating space for that type in allocate_mem. This is + # temporary, and we will honour other widths later. But this + # allows us to have cool binary ops on the returned value. fn_type = ir.FunctionType( - ir.PointerType(), # Return type: void* + ir.PointerType(ir.IntType(64)), # Return type: void* [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) var_arg=False, ) @@ -93,6 +99,7 @@ def bpf_printk_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """Emit LLVM IR for bpf_printk helper function call.""" if not hasattr(func, "_fmt_counter"): @@ -140,6 +147,7 @@ def bpf_map_update_elem_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_update_elem helper function call. @@ -155,10 +163,10 @@ def bpf_map_update_elem_emitter( flags_arg = call.args[2] if len(call.args) > 2 else None key_ptr = get_or_create_ptr_from_arg( - func, module, key_arg, builder, local_sym_tab, struct_sym_tab + func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab ) value_ptr = get_or_create_ptr_from_arg( - func, module, value_arg, builder, local_sym_tab, struct_sym_tab + func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab ) flags_val = get_flags_val(flags_arg, builder, local_sym_tab) @@ -194,6 +202,7 @@ def bpf_map_delete_elem_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_map_delete_elem helper function call. @@ -204,7 +213,7 @@ def bpf_map_delete_elem_emitter( f"Map delete expects exactly one argument (key), got {len(call.args)}" ) key_ptr = get_or_create_ptr_from_arg( - func, module, call.args[0], builder, local_sym_tab, struct_sym_tab + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -233,6 +242,7 @@ def bpf_get_current_pid_tgid_emitter( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): """ Emit LLVM IR for bpf_get_current_pid_tgid helper function call. @@ -259,6 +269,7 @@ def bpf_perf_event_output_handler( func, local_sym_tab=None, struct_sym_tab=None, + map_sym_tab=None, ): if len(call.args) != 1: raise ValueError( @@ -323,6 +334,7 @@ def handle_helper_call( func, local_sym_tab, struct_sym_tab, + map_sym_tab, ) # Handle direct function calls (e.g., print(), ktime()) diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index c12e56b..5960c9a 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -81,14 +81,14 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): # Default to 64-bit integer ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) - logger.debug(f"Using temp variable '{temp_name}' for int constant {value}") + logger.info(f"Using temp variable '{temp_name}' for int constant {value}") const_val = ir.Constant(ir.IntType(int_width), value) builder.store(const_val, ptr) return ptr def get_or_create_ptr_from_arg( - func, module, arg, builder, local_sym_tab, struct_sym_tab=None + func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None ): """Extract or create pointer from the call arguments.""" @@ -104,15 +104,17 @@ def get_or_create_ptr_from_arg( builder, arg, local_sym_tab, - None, + map_sym_tab, struct_sym_tab, ) if val is None: raise ValueError("Failed to evaluate expression for helper arg.") # NOTE: We assume the result is an int64 for now + # if isinstance(arg, ast.Attribute): + # return val ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) - logger.debug(f"Using temp variable '{temp_name}' for expression result") + logger.info(f"Using temp variable '{temp_name}' for expression result") builder.store(val, ptr) return ptr