From 19b42b9a19ad111c9a93f23574b4823dd354adb1 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Wed, 19 Nov 2025 04:09:51 +0530 Subject: [PATCH] Allocate hashmap lookup return vars based on the value type of said hashmap --- pythonbpf/allocation_pass.py | 80 +++++++++++++++++++++++++-- pythonbpf/functions/functions_pass.py | 4 +- pythonbpf/maps/maps_pass.py | 2 +- pythonbpf/maps/maps_utils.py | 1 + 4 files changed, 79 insertions(+), 8 deletions(-) diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index f16c36f..5e44eb4 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -7,6 +7,7 @@ from pythonbpf.helper import HelperHandlerRegistry from pythonbpf.vmlinux_parser.dependency_node import Field from .expr import VmlinuxHandlerRegistry from pythonbpf.type_deducer import ctypes_to_ir +from pythonbpf.maps import BPFMapType logger = logging.getLogger(__name__) @@ -25,7 +26,9 @@ def create_targets_and_rvals(stmt): return stmt.targets, [stmt.value] -def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): +def handle_assign_allocation( + builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab +): """Handle memory allocation for assignment statements.""" logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}") @@ -55,7 +58,9 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): # Determine type and allocate based on rval if isinstance(rval, ast.Call): - _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) + _allocate_for_call( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + ) elif isinstance(rval, ast.Constant): _allocate_for_constant(builder, var_name, rval, local_sym_tab) elif isinstance(rval, ast.BinOp): @@ -74,7 +79,9 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): ) -def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): +def _allocate_for_call( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab +): """Allocate memory for variable assigned from a call.""" if isinstance(rval.func, ast.Name): @@ -116,15 +123,74 @@ def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): elif isinstance(rval.func, ast.Attribute): # Map method calls - need double allocation for ptr handling - _allocate_for_map_method(builder, var_name, local_sym_tab) + _allocate_for_map_method( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + ) else: logger.warning(f"Unsupported call function type for {var_name}") -def _allocate_for_map_method(builder, var_name, local_sym_tab): +def _allocate_for_map_method( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab +): """Allocate memory for variable assigned from map method (double alloc).""" + map_name = rval.func.value.id + method_name = rval.func.attr + + # NOTE: We will have to special case HashMap.lookup which returns a pointer to value type + # The value type can be a struct as well, so we need to handle that properly + # This special casing is not ideal, as over time other map methods may need similar handling + # But for now, we will just handle lookup specifically + if map_name not in map_sym_tab: + logger.error(f"Map '{map_name}' not found for allocation") + return + + if method_name != "lookup": + # Fallback allocation for other map methods + _allocate_for_map_method_fallback(builder, var_name, local_sym_tab) + return + + map_params = map_sym_tab[map_name].params + if map_params["type"] != BPFMapType.HASH: + logger.warning( + "Map method lookup used on non-hash map, using fallback allocation" + ) + _allocate_for_map_method_fallback(builder, var_name, local_sym_tab) + return + + value_type = map_params["value"] + # Determine IR type for value + if isinstance(value_type, str) and value_type in structs_sym_tab: + struct_info = structs_sym_tab[value_type] + value_ir_type = struct_info.ir_type + else: + value_ir_type = ctypes_to_ir(value_type) + + if value_ir_type is None: + logger.warning( + f"Could not determine IR type for map value '{value_type}', using fallback allocation" + ) + _allocate_for_map_method_fallback(builder, var_name, local_sym_tab) + return + + # Main variable (pointer to pointer) + ir_type = ir.PointerType(value_ir_type) + var = builder.alloca(ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + # Temporary variable for computed values + tmp_ir_type = value_ir_type + var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") + local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) + logger.info( + f"Pre-allocated {var_name} and {var_name}_tmp for map method lookup of type {value_ir_type}" + ) + + +def _allocate_for_map_method_fallback(builder, var_name, local_sym_tab): + """Fallback allocation for map method variable (i64* and i64**).""" + # Main variable (pointer to pointer) ir_type = ir.PointerType(ir.IntType(64)) var = builder.alloca(ir_type, name=var_name) @@ -135,7 +201,9 @@ def _allocate_for_map_method(builder, var_name, local_sym_tab): var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) - logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method") + logger.info( + f"Pre-allocated {var_name} and {var_name}_tmp for map method (fallback)" + ) def _allocate_for_constant(builder, var_name, rval, local_sym_tab): diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 8e0a995..f78ed92 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -147,7 +147,9 @@ def allocate_mem( structs_sym_tab, ) elif isinstance(stmt, ast.Assign): - handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) + handle_assign_allocation( + builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab + ) allocate_temp_pool(builder, max_temps_needed, local_sym_tab) diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index ed60958..ac498dc 100644 --- a/pythonbpf/maps/maps_pass.py +++ b/pythonbpf/maps/maps_pass.py @@ -48,7 +48,7 @@ def create_bpf_map(module, map_name, map_params): map_global.align = 8 logger.info(f"Created BPF map: {map_name} with params {map_params}") - return MapSymbol(type=map_params["type"], sym=map_global) + return MapSymbol(type=map_params["type"], sym=map_global, params=map_params) def _parse_map_params(rval, expected_args=None): diff --git a/pythonbpf/maps/maps_utils.py b/pythonbpf/maps/maps_utils.py index eaa43b2..194b408 100644 --- a/pythonbpf/maps/maps_utils.py +++ b/pythonbpf/maps/maps_utils.py @@ -11,6 +11,7 @@ class MapSymbol: type: BPFMapType sym: ir.GlobalVariable + params: dict[str, Any] | None = None class MapProcessorRegistry: