diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index f5ae9a0..0dd2ba4 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -64,7 +64,9 @@ def bpf_map_lookup_elem_emitter( raise ValueError( f"Map lookup expects exactly one argument (key), got {len(call.args)}" ) - key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) + key_ptr = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, struct_sym_tab + ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) fn_type = ir.FunctionType( @@ -152,8 +154,12 @@ def bpf_map_update_elem_emitter( value_arg = call.args[1] flags_arg = call.args[2] if len(call.args) > 2 else None - key_ptr = get_or_create_ptr_from_arg(key_arg, builder, local_sym_tab) - value_ptr = get_or_create_ptr_from_arg(value_arg, builder, local_sym_tab) + key_ptr = get_or_create_ptr_from_arg( + func, module, key_arg, builder, local_sym_tab, struct_sym_tab + ) + value_ptr = get_or_create_ptr_from_arg( + func, module, value_arg, builder, local_sym_tab, struct_sym_tab + ) flags_val = get_flags_val(flags_arg, builder, local_sym_tab) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) @@ -197,7 +203,9 @@ def bpf_map_delete_elem_emitter( raise ValueError( f"Map delete expects exactly one argument (key), got {len(call.args)}" ) - key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) + key_ptr = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, struct_sym_tab + ) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) # Define function type for bpf_map_delete_elem diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 53198fd..c12e56b 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -87,23 +87,25 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): return ptr -def get_or_create_ptr_from_arg(arg, builder, local_sym_tab): +def get_or_create_ptr_from_arg( + func, module, arg, builder, local_sym_tab, struct_sym_tab=None +): """Extract or create pointer from the call arguments.""" if isinstance(arg, ast.Name): ptr = get_var_ptr_from_name(arg.id, local_sym_tab) elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab) - elif isinstance(arg, ast.BinOp): + else: # Evaluate the expression and store the result in a temp variable val, _ = eval_expr( - None, - None, + func, + module, builder, arg, local_sym_tab, None, - None, + struct_sym_tab, ) if val is None: raise ValueError("Failed to evaluate expression for helper arg.") @@ -113,10 +115,6 @@ def get_or_create_ptr_from_arg(arg, builder, local_sym_tab): logger.debug(f"Using temp variable '{temp_name}' for expression result") builder.store(val, ptr) - else: - raise NotImplementedError( - "Only simple variable names are supported as args in map helpers." - ) return ptr