diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py new file mode 100644 index 0000000..5ec631a --- /dev/null +++ b/pythonbpf/allocation_pass.py @@ -0,0 +1,191 @@ +import ast +import logging + +from llvmlite import ir +from dataclasses import dataclass +from typing import Any +from pythonbpf.helper import HelperHandlerRegistry +from pythonbpf.type_deducer import ctypes_to_ir + +logger = logging.getLogger(__name__) + + +@dataclass +class LocalSymbol: + var: ir.AllocaInstr + ir_type: ir.Type + metadata: Any = None + + def __iter__(self): + yield self.var + yield self.ir_type + yield self.metadata + + +def _is_helper_call(call_node): + """Check if a call node is a BPF helper function call.""" + if isinstance(call_node.func, ast.Name): + # Exclude print from requiring temps (handles f-strings differently) + func_name = call_node.func.id + return HelperHandlerRegistry.has_handler(func_name) and func_name != "print" + + elif isinstance(call_node.func, ast.Attribute): + return HelperHandlerRegistry.has_handler(call_node.func.attr) + + return False + + +def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): + """Handle memory allocation for assignment statements.""" + + # Validate assignment + if len(stmt.targets) != 1: + logger.warning("Multi-target assignment not supported, skipping allocation") + return + + target = stmt.targets[0] + + # Skip non-name targets (e.g., struct field assignments) + if isinstance(target, ast.Attribute): + logger.debug(f"Struct field assignment to {target.attr}, no allocation needed") + return + + if not isinstance(target, ast.Name): + logger.warning(f"Unsupported assignment target type: {type(target).__name__}") + return + + var_name = target.id + rval = stmt.value + + # Skip if already allocated + if var_name in local_sym_tab: + logger.debug(f"Variable {var_name} already allocated, skipping") + return + + # Determine type and allocate based on rval + if isinstance(rval, ast.Call): + _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) + elif isinstance(rval, ast.Constant): + _allocate_for_constant(builder, var_name, rval, local_sym_tab) + elif isinstance(rval, ast.BinOp): + _allocate_for_binop(builder, var_name, local_sym_tab) + else: + logger.warning( + f"Unsupported assignment value type for {var_name}: {type(rval).__name__}" + ) + + +def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): + """Allocate memory for variable assigned from a call.""" + + if isinstance(rval.func, ast.Name): + call_type = rval.func.id + + # C type constructors + if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): + ir_type = ctypes_to_ir(call_type) + var = builder.alloca(ir_type, name=var_name) + var.align = ir_type.width // 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as {call_type}") + + # Helper functions + elif HelperHandlerRegistry.has_handler(call_type): + ir_type = ir.IntType(64) # Assume i64 return type + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for helper {call_type}") + + # Deref function + elif call_type == "deref": + ir_type = ir.IntType(64) # Assume i64 return type + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for deref") + + # Struct constructors + elif call_type in structs_sym_tab: + struct_info = structs_sym_tab[call_type] + var = builder.alloca(struct_info.ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) + logger.info(f"Pre-allocated {var_name} for struct {call_type}") + + else: + logger.warning(f"Unknown call type for allocation: {call_type}") + + elif isinstance(rval.func, ast.Attribute): + # Map method calls - need double allocation for ptr handling + _allocate_for_map_method(builder, var_name, local_sym_tab) + + else: + logger.warning(f"Unsupported call function type for {var_name}") + + +def _allocate_for_map_method(builder, var_name, local_sym_tab): + """Allocate memory for variable assigned from map method (double alloc).""" + + # Main variable (pointer to pointer) + ir_type = ir.PointerType(ir.IntType(64)) + var = builder.alloca(ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + + # Temporary variable for computed values + tmp_ir_type = ir.IntType(64) + var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") + local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) + + logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method") + + +def _allocate_for_constant(builder, var_name, rval, local_sym_tab): + """Allocate memory for variable assigned from a constant.""" + + if isinstance(rval.value, bool): + ir_type = ir.IntType(1) + var = builder.alloca(ir_type, name=var_name) + var.align = 1 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as bool") + + elif isinstance(rval.value, int): + ir_type = ir.IntType(64) + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as i64") + + elif isinstance(rval.value, str): + ir_type = ir.PointerType(ir.IntType(8)) + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} as string") + + else: + logger.warning( + f"Unsupported constant type for {var_name}: {type(rval.value).__name__}" + ) + + +def _allocate_for_binop(builder, var_name, local_sym_tab): + """Allocate memory for variable assigned from a binary operation.""" + ir_type = ir.IntType(64) # Assume i64 result + var = builder.alloca(ir_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + logger.info(f"Pre-allocated {var_name} for binop result") + + +def allocate_temp_pool(builder, max_temps, local_sym_tab): + """Allocate the temporary scratch space pool for helper arguments.""" + if max_temps == 0: + return + + logger.info(f"Allocating temp pool of {max_temps} variables") + for i in range(max_temps): + temp_name = f"__helper_temp_{i}" + temp_var = builder.alloca(ir.IntType(64), name=temp_name) + temp_var.align = 8 + local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index a024ca5..45d7b0a 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -1,8 +1,6 @@ from llvmlite import ir import ast import logging -from typing import Any -from dataclasses import dataclass from pythonbpf.helper import ( HelperHandlerRegistry, @@ -14,6 +12,7 @@ from pythonbpf.assign_pass import ( handle_variable_assignment, handle_struct_field_assignment, ) +from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name @@ -21,18 +20,6 @@ from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name logger = logging.getLogger(__name__) -@dataclass -class LocalSymbol: - var: ir.AllocaInstr - ir_type: ir.Type - metadata: Any = None - - def __iter__(self): - yield self.var - yield self.ir_type - yield self.metadata - - def get_probe_string(func_node): """Extract the probe string from the decorator of the function node.""" # TODO: right now we have the whole string in the section decorator @@ -220,20 +207,7 @@ def process_stmt( return did_return -def _is_helper_call(call_node): - """Check if a call node is a BPF helper function call.""" - if isinstance(call_node.func, ast.Name): - # Exclude print from requiring temps (handles f-strings differently) - func_name = call_node.func.id - return HelperHandlerRegistry.has_handler(func_name) and func_name != "print" - - elif isinstance(call_node.func, ast.Attribute): - return HelperHandlerRegistry.has_handler(call_node.func.attr) - - return False - - -def _handle_if_allocation( +def handle_if_allocation( module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab ): """Recursively handle allocations in if/else branches.""" @@ -261,162 +235,6 @@ def _handle_if_allocation( ) -def _handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): - """Handle memory allocation for assignment statements.""" - - # Validate assignment - if len(stmt.targets) != 1: - logger.warning("Multi-target assignment not supported, skipping allocation") - return - - target = stmt.targets[0] - - # Skip non-name targets (e.g., struct field assignments) - if isinstance(target, ast.Attribute): - logger.debug(f"Struct field assignment to {target.attr}, no allocation needed") - return - - if not isinstance(target, ast.Name): - logger.warning(f"Unsupported assignment target type: {type(target).__name__}") - return - - var_name = target.id - rval = stmt.value - - # Skip if already allocated - if var_name in local_sym_tab: - logger.debug(f"Variable {var_name} already allocated, skipping") - return - - # Determine type and allocate based on rval - if isinstance(rval, ast.Call): - _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) - elif isinstance(rval, ast.Constant): - _allocate_for_constant(builder, var_name, rval, local_sym_tab) - elif isinstance(rval, ast.BinOp): - _allocate_for_binop(builder, var_name, local_sym_tab) - else: - logger.warning( - f"Unsupported assignment value type for {var_name}: {type(rval).__name__}" - ) - - -def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): - """Allocate memory for variable assigned from a call.""" - - if isinstance(rval.func, ast.Name): - call_type = rval.func.id - - # C type constructors - if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): - ir_type = ctypes_to_ir(call_type) - var = builder.alloca(ir_type, name=var_name) - var.align = ir_type.width // 8 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} as {call_type}") - - # Helper functions - elif HelperHandlerRegistry.has_handler(call_type): - ir_type = ir.IntType(64) # Assume i64 return type - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} for helper {call_type}") - - # Deref function - elif call_type == "deref": - ir_type = ir.IntType(64) # Assume i64 return type - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} for deref") - - # Struct constructors - elif call_type in structs_sym_tab: - struct_info = structs_sym_tab[call_type] - var = builder.alloca(struct_info.ir_type, name=var_name) - local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) - logger.info(f"Pre-allocated {var_name} for struct {call_type}") - - else: - logger.warning(f"Unknown call type for allocation: {call_type}") - - elif isinstance(rval.func, ast.Attribute): - # Map method calls - need double allocation for ptr handling - _allocate_for_map_method(builder, var_name, local_sym_tab) - - else: - logger.warning(f"Unsupported call function type for {var_name}") - - -def _allocate_for_map_method(builder, var_name, local_sym_tab): - """Allocate memory for variable assigned from map method (double alloc).""" - - # Main variable (pointer to pointer) - ir_type = ir.PointerType(ir.IntType(64)) - var = builder.alloca(ir_type, name=var_name) - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - - # Temporary variable for computed values - tmp_ir_type = ir.IntType(64) - var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") - local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) - - logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method") - - -def _allocate_for_constant(builder, var_name, rval, local_sym_tab): - """Allocate memory for variable assigned from a constant.""" - - if isinstance(rval.value, bool): - ir_type = ir.IntType(1) - var = builder.alloca(ir_type, name=var_name) - var.align = 1 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} as bool") - - elif isinstance(rval.value, int): - ir_type = ir.IntType(64) - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} as i64") - - elif isinstance(rval.value, str): - ir_type = ir.PointerType(ir.IntType(8)) - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} as string") - - else: - logger.warning( - f"Unsupported constant type for {var_name}: {type(rval.value).__name__}" - ) - - -def _allocate_for_binop(builder, var_name, local_sym_tab): - """Allocate memory for variable assigned from a binary operation.""" - ir_type = ir.IntType(64) # Assume i64 result - var = builder.alloca(ir_type, name=var_name) - var.align = 8 - local_sym_tab[var_name] = LocalSymbol(var, ir_type) - logger.info(f"Pre-allocated {var_name} for binop result") - - -def _allocate_temp_pool(builder, max_temps, local_sym_tab): - """Allocate the temporary scratch space pool for helper arguments.""" - if max_temps == 0: - return - - logger.info(f"Allocating temp pool of {max_temps} variables") - for i in range(max_temps): - temp_name = f"__helper_temp_{i}" - temp_var = builder.alloca(ir.IntType(64), name=temp_name) - temp_var.align = 8 - local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) - - def count_temps_in_call(call_node, local_sym_tab): """Count the number of temporary variables needed for a function call.""" @@ -475,7 +293,7 @@ def allocate_mem( # Handle allocations if isinstance(stmt, ast.If): - _handle_if_allocation( + handle_if_allocation( module, builder, stmt, @@ -486,9 +304,9 @@ def allocate_mem( structs_sym_tab, ) elif isinstance(stmt, ast.Assign): - _handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) + handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) - _allocate_temp_pool(builder, max_temps_needed, local_sym_tab) + allocate_temp_pool(builder, max_temps_needed, local_sym_tab) return local_sym_tab diff --git a/pythonbpf/local_symbol.py b/pythonbpf/local_symbol.py deleted file mode 100644 index d0f76d5..0000000 --- a/pythonbpf/local_symbol.py +++ /dev/null @@ -1,15 +0,0 @@ -from llvm import ir -from dataclasses import dataclass -from typing import Any - - -@dataclass -class LocalSymbol: - var: ir.AllocaInstr - ir_type: ir.Type - metadata: Any = None - - def __iter__(self): - yield self.var - yield self.ir_type - yield self.metadata