diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index 104078c..d6301a3 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -2,7 +2,7 @@ import ast from llvmlite import ir from pythonbpf.expr_pass import eval_expr from enum import Enum -from .helper_utils import HelperHandlerRegistry +from .helper_utils import HelperHandlerRegistry, get_key_ptr class BPFHelperID(Enum): @@ -38,31 +38,7 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, """ Emit LLVM IR for bpf_map_lookup_elem helper function call. """ - if call.args and len(call.args) != 1: - raise ValueError("Map lookup expects exactly one argument, got " - f"{len(call.args)}") - key_arg = call.args[0] - if isinstance(key_arg, ast.Name): - key_name = key_arg.id - if local_sym_tab and key_name in local_sym_tab: - key_ptr = local_sym_tab[key_name][0] - else: - raise ValueError( - f"Key variable {key_name} not found in local symbol table.") - elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): - # handle constant integer keys - key_val = key_arg.value - key_type = ir.IntType(64) - key_ptr = builder.alloca(key_type) - key_ptr.align = key_type // 8 - builder.store(ir.Constant(key_type, key_val), key_ptr) - else: - raise NotImplementedError( - "Only simple variable names are supported as keys in map lookup.") - - if key_ptr is None: - raise ValueError("Key pointer is None.") - + key_ptr = get_key_ptr(call, builder, local_sym_tab) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) fn_type = ir.FunctionType( @@ -72,7 +48,6 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, ) fn_ptr_type = ir.PointerType(fn_type) - # Helper ID 1 is bpf_map_lookup_elem fn_addr = ir.Constant(ir.IntType( 64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 8ff7276..3594c3c 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -1,3 +1,7 @@ +import ast +from llvmlite import ir + + class HelperHandlerRegistry: """Registry for BPF helpers""" _handlers = {} @@ -14,3 +18,38 @@ class HelperHandlerRegistry: def get_handler(cls, helper_name): """Get the handler function for a helper""" return cls._handlers.get(helper_name) + + +def get_var_ptr_from_name(var_name, local_sym_tab): + """Get a pointer to a variable from the symbol table.""" + if local_sym_tab and var_name in local_sym_tab: + return local_sym_tab[var_name][0] + raise ValueError(f"Variable '{var_name}' not found in local symbol table") + + +def create_int_constant_ptr(value, builder, int_width=64): + """Create a pointer to an integer constant.""" + # Default to 64-bit integer + int_type = ir.IntType(int_width) + ptr = builder.alloca(int_type) + ptr.align = int_type.width // 8 + builder.store(ir.Constant(int_type, value), ptr) + return ptr + + +def get_key_ptr(call, builder, local_sym_tab): + """Extract key pointer from the call arguments.""" + if not call.args or len(call.args) != 1: + raise ValueError("Map lookup expects exactly one argument, got " + f"{len(call.args)}") + + key_arg = call.args[0] + + if isinstance(key_arg, ast.Name): + key_ptr = get_var_ptr_from_name(key_arg.id, local_sym_tab) + elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): + key_ptr = create_int_constant_ptr(key_arg.value, builder) + else: + raise NotImplementedError( + "Only simple variable names are supported as keys in map lookup.") + return key_ptr