From 0f6971bcc229abbe87ee2b3effa91f8219f0e068 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Sun, 12 Oct 2025 11:34:40 +0530 Subject: [PATCH] Refactor allocate_mem --- pythonbpf/functions/functions_pass.py | 342 ++++++++++++++++---------- 1 file changed, 211 insertions(+), 131 deletions(-) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 68314f1..a024ca5 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -220,6 +220,203 @@ def process_stmt( return did_return +def _is_helper_call(call_node): + """Check if a call node is a BPF helper function call.""" + if isinstance(call_node.func, ast.Name): + # Exclude print from requiring temps (handles f-strings differently) + func_name = call_node.func.id + return HelperHandlerRegistry.has_handler(func_name) and func_name != "print" + + elif isinstance(call_node.func, ast.Attribute): + return HelperHandlerRegistry.has_handler(call_node.func.attr) + + return False + + +def _handle_if_allocation( + module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab +): + """Recursively handle allocations in if/else branches.""" + if stmt.body: + allocate_mem( + module, + builder, + stmt.body, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + if stmt.orelse: + allocate_mem( + module, + builder, + stmt.orelse, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + + +def _handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): + """Handle memory allocation for assignment statements.""" + + # Validate assignment + if len(stmt.targets) != 1: + logger.warning("Multi-target assignment not supported, skipping allocation") + return + + target = stmt.targets[0] + + # Skip non-name targets (e.g., struct field assignments) + if isinstance(target, ast.Attribute): + logger.debug(f"Struct field assignment to {target.attr}, no allocation needed") + return + + if not isinstance(target, ast.Name): + logger.warning(f"Unsupported assignment target type: {type(target).__name__}") + return + + var_name = target.id + rval = stmt.value + + # Skip if already allocated + if var_name in local_sym_tab: + logger.debug(f"Variable {var_name} already allocated, skipping") + return + + # Determine type and allocate based on rval + if isinstance(rval, ast.Call): + _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) + elif isinstance(rval, ast.Constant): + _allocate_for_constant(builder, var_name, rval, local_sym_tab) + elif isinstance(rval, ast.BinOp): + _allocate_for_binop(builder, var_name, local_sym_tab) + else: + logger.warning( + f"Unsupported assignment value type for {var_name}: {type(rval).__name__}" + ) + + +def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): + """Allocate memory for variable assigned from a call.""" + + if isinstance(rval.func, ast.Name): + call_type = rval.func.id + + # C type constructors + if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): + ir_type = ctypes_to_ir(call_type) + var = builder.alloca(ir_type, name=var_name) + var.align = ir_type.width // 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as {call_type}") + + # Helper functions + elif HelperHandlerRegistry.has_handler(call_type): + ir_type = ir.IntType(64) # Assume i64 return type + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for helper {call_type}") + + # Deref function + elif call_type == "deref": + ir_type = ir.IntType(64) # Assume i64 return type + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for deref") + + # Struct constructors + elif call_type in structs_sym_tab: + struct_info = structs_sym_tab[call_type] + var = builder.alloca(struct_info.ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) + logger.info(f"Pre-allocated {var_name} for struct {call_type}") + + else: + logger.warning(f"Unknown call type for allocation: {call_type}") + + elif isinstance(rval.func, ast.Attribute): + # Map method calls - need double allocation for ptr handling + _allocate_for_map_method(builder, var_name, local_sym_tab) + + else: + logger.warning(f"Unsupported call function type for {var_name}") + + +def _allocate_for_map_method(builder, var_name, local_sym_tab): + """Allocate memory for variable assigned from map method (double alloc).""" + + # Main variable (pointer to pointer) + ir_type = ir.PointerType(ir.IntType(64)) + var = builder.alloca(ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + + # Temporary variable for computed values + tmp_ir_type = ir.IntType(64) + var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") + local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) + + logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method") + + +def _allocate_for_constant(builder, var_name, rval, local_sym_tab): + """Allocate memory for variable assigned from a constant.""" + + if isinstance(rval.value, bool): + ir_type = ir.IntType(1) + var = builder.alloca(ir_type, name=var_name) + var.align = 1 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as bool") + + elif isinstance(rval.value, int): + ir_type = ir.IntType(64) + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as i64") + + elif isinstance(rval.value, str): + ir_type = ir.PointerType(ir.IntType(8)) + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as string") + + else: + logger.warning( + f"Unsupported constant type for {var_name}: {type(rval.value).__name__}" + ) + + +def _allocate_for_binop(builder, var_name, local_sym_tab): + """Allocate memory for variable assigned from a binary operation.""" + ir_type = ir.IntType(64) # Assume i64 result + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for binop result") + + +def _allocate_temp_pool(builder, max_temps, local_sym_tab): + """Allocate the temporary scratch space pool for helper arguments.""" + if max_temps == 0: + return + + logger.info(f"Allocating temp pool of {max_temps} variables") + for i in range(max_temps): + temp_name = f"__helper_temp_{i}" + temp_var = builder.alloca(ir.IntType(64), name=temp_name) + temp_var.align = 8 + local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) + + def count_temps_in_call(call_node, local_sym_tab): """Count the number of temporary variables needed for a function call.""" @@ -255,7 +452,6 @@ def count_temps_in_call(call_node, local_sym_tab): def allocate_mem( module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab ): - double_alloc = False max_temps_needed = 0 def update_max_temps_for_stmt(stmt): @@ -276,139 +472,23 @@ def allocate_mem( for stmt in body: update_max_temps_for_stmt(stmt) - has_metadata = False + + # Handle allocations if isinstance(stmt, ast.If): - if stmt.body: - local_sym_tab = allocate_mem( - module, - builder, - stmt.body, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) - if stmt.orelse: - local_sym_tab = allocate_mem( - module, - builder, - stmt.orelse, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) + _handle_if_allocation( + module, + builder, + stmt, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) elif isinstance(stmt, ast.Assign): - if len(stmt.targets) != 1: - logger.info("Unsupported multiassignment") - continue - target = stmt.targets[0] - if not isinstance(target, ast.Name) and not isinstance( - target, ast.Attribute - ): - logger.info("Unsupported assignment target") - continue - if isinstance(target, ast.Attribute): - logger.info( - f"Struct field {target.attr} assignment, will be handled later" - ) - continue - var_name = target.id - rval = stmt.value - if var_name in local_sym_tab: - logger.info(f"Variable {var_name} already allocated") - continue - if isinstance(rval, ast.Call): - if isinstance(rval.func, ast.Name): - call_type = rval.func.id - if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): - ir_type = ctypes_to_ir(call_type) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info( - f"Pre-allocated variable {var_name} of type {call_type}" - ) - elif HelperHandlerRegistry.has_handler(call_type): - # Assume return type is int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} for helper") - elif call_type == "deref" and len(rval.args) == 1: - # Assume return type is int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} for deref") - elif call_type in structs_sym_tab: - struct_info = structs_sym_tab[call_type] - ir_type = struct_info.ir_type - var = builder.alloca(ir_type, name=var_name) - has_metadata = True - logger.info( - f"Pre-allocated variable {var_name} for struct {call_type}" - ) - elif isinstance(rval.func, ast.Attribute): - # Map method call - ir_type = ir.PointerType(ir.IntType(64)) - var = builder.alloca(ir_type, name=var_name) + _handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) - # declare an intermediate ptr type for map lookup - tmp_ir_type = ir.IntType(64) - var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") - double_alloc = True - # var.align = ir_type.width // 8 - logger.info( - f"Pre-allocated variable {var_name} and {var_name}_tmp for map" - ) - else: - logger.info("Unsupported assignment call function type") - continue - elif isinstance(rval, ast.Constant): - if isinstance(rval.value, bool): - ir_type = ir.IntType(1) - var = builder.alloca(ir_type, name=var_name) - var.align = 1 - logger.info(f"Pre-allocated variable {var_name} of type c_bool") - elif isinstance(rval.value, int): - # Assume c_int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} of type c_int64") - elif isinstance(rval.value, str): - ir_type = ir.PointerType(ir.IntType(8)) - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - logger.info(f"Pre-allocated variable {var_name} of type string") - else: - logger.info("Unsupported constant type") - continue - elif isinstance(rval, ast.BinOp): - # Assume c_int64 for now - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - logger.info(f"Pre-allocated variable {var_name} of type c_int64") - else: - logger.info("Unsupported assignment value type") - continue - - if has_metadata: - local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type) - else: - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - - if double_alloc: - local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) - - logger.info(f"Temporary scratch space needed for calls: {max_temps_needed}") - for i in range(max_temps_needed): - temp_var = builder.alloca(ir.IntType(64), name=f"__helper_temp_{i}") - temp_var.align = 8 - local_sym_tab[f"__helper_temp_{i}"] = LocalSymbol(temp_var, ir.IntType(64)) + _allocate_temp_pool(builder, max_temps_needed, local_sym_tab) return local_sym_tab