diff --git a/pythonbpf/bpf_helper_handler.py b/pythonbpf/bpf_helper_handler.py index a34c5bd..3c45098 100644 --- a/pythonbpf/bpf_helper_handler.py +++ b/pythonbpf/bpf_helper_handler.py @@ -6,6 +6,7 @@ def bpf_ktime_get_ns_emitter(call, module, builder, func): """ Emit LLVM IR for bpf_ktime_get_ns helper function call. """ + # func is an arg to just have a uniform signature with other emitters helper_id = ir.Constant(ir.IntType(64), 5) fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) fn_ptr_type = ir.PointerType(fn_type) @@ -14,15 +15,38 @@ def bpf_ktime_get_ns_emitter(call, module, builder, func): return result -def bpf_map_lookup_elem_emitter(map_ptr, key_ptr, module, builder): +def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None): """ Emit LLVM IR for bpf_map_lookup_elem helper function call. """ - # Cast pointers to void* + + if call.args and len(call.args) != 1: + raise ValueError("Map lookup expects exactly one argument, got " + f"{len(call.args)}") + key_arg = call.args[0] + if isinstance(key_arg, ast.Name): + key_name = key_arg.id + if local_sym_tab and key_name in local_sym_tab: + key_ptr = local_sym_tab[key_name] + else: + raise ValueError( + f"Key variable {key_name} not found in local symbol table.") + elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): + # handle constant integer keys + key_val = key_arg.value + key_type = ir.IntType(64) + key_ptr = builder.alloca(key_type) + key_ptr.align = key_type // 8 + builder.store(ir.Constant(key_type, key_val), key_ptr) + else: + raise NotImplementedError( + "Only simple variable names are supported as keys in map lookup.") + + if key_ptr is None: + raise ValueError("Key pointer is None.") + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) - # Define function type for bpf_map_lookup_elem - # The function takes two void* arguments and returns void* fn_type = ir.FunctionType( ir.PointerType(), # Return type: void* [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) @@ -34,7 +58,6 @@ def bpf_map_lookup_elem_emitter(map_ptr, key_ptr, module, builder): fn_addr = ir.Constant(ir.IntType(64), 1) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - # Call the helper function result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) return result @@ -79,5 +102,31 @@ helper_func_list = { } -def handle_helper_call(call, module, builder, func): - return None +def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_tab=None): + if isinstance(call.func, ast.Name): + func_name = call.func.id + if func_name in helper_func_list: + # it is not a map method call + helper_func_list[func_name](call, module, builder, func) + else: + raise NotImplementedError( + f"Function {func_name} is not implemented as a helper function.") + elif isinstance(call.func, ast.Attribute): + # likely a map method call + if isinstance(call.func.value, ast.Call) and isinstance(call.func.value.func, ast.Name): + map_name = call.func.value.func.id + method_name = call.func.attr + if map_sym_tab and map_name in map_sym_tab: + map_ptr = map_sym_tab[map_name] + if method_name in helper_func_list: + helper_func_list[method_name]( + call, map_ptr, module, builder, local_sym_tab) + else: + raise NotImplementedError( + f"Map method {method_name} is not implemented as a helper function.") + else: + raise ValueError( + f"Map variable {map_name} not found in symbol tables.") + else: + raise NotImplementedError( + "Attribute not supported for map method calls.") diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index d58d36d..10fbb69 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -1,7 +1,7 @@ from llvmlite import ir import ast -from .bpf_helper_handler import bpf_printk_emitter, bpf_ktime_get_ns_emitter, bpf_map_lookup_elem_emitter +from .bpf_helper_handler import helper_func_list, handle_helper_call from .type_deducer import ctypes_to_ir @@ -60,55 +60,16 @@ def handle_assign(module, builder, stmt, map_sym_tab, local_sym_tab): else: print(f"Unsupported assignment call type: {call_type}") elif isinstance(rval.func, ast.Attribute): - if isinstance(rval.func.attr, str) and rval.func.attr == "lookup": - # Get map name and check symtab - # maps are called as funcs - if isinstance(rval.func.value, ast.Call) and isinstance(rval.func.value.func, ast.Name): - map_name = rval.func.value.func.id - if map_name in map_sym_tab: - map_global = map_sym_tab[map_name] - print(f"Found map {map_name} in symtab for lookup") - if len(rval.args) != 1: - print("Unsupported lookup with != 1 arg") - return - key_arg = rval.args[0] - print(f"Lookup key arg type: {type(key_arg)}") - # TODO: implement a parse_arg ffs as this can be a fucking expr - if isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): - key_val = key_arg.value - key_type = ir.IntType(64) - print(f"Key type: {key_type}") - print(f"Key val: {key_val}") - key_var = builder.alloca(key_type) - key_var.align = key_type // 8 - builder.store(ir.Constant( - key_type, key_val), key_var) - elif isinstance(key_arg, ast.Name): - # Check in local symtab first - if key_arg.id in local_sym_tab: - key_var = local_sym_tab[key_arg.id] - key_type = key_var.type.pointee - elif key_arg.id in map_sym_tab: - key_var = map_sym_tab[key_arg.id] - key_type = key_var.type.pointee - else: - print("Key variable " - f"{key_arg.id} not found in symtabs") - return - print(f"Found key variable {key_arg.id} in symtab") - print(f"Key type: {key_type}") - else: - print("Unsupported lookup key arg") - return - - # TODO: generate call to bpf_map_lookup_elem - result_ptr = bpf_map_lookup_elem_emitter( - map_global, key_var, module, builder) - - else: - print(f"Map {map_name} not found in symbol table") + if isinstance(rval.func.value, ast.Call) and isinstance(rval.func.value.func, ast.Name): + map_name = rval.func.value.func.id + method_name = rval.func.attr + if map_name in map_sym_tab: + map_ptr = map_sym_tab[map_name] + if method_name in helper_func_list: + handle_helper_call( + rval, module, builder, None, local_sym_tab, map_sym_tab) else: - print("Unsupported assignment from method call") + print("Unsupported assignment call structure") def process_func_body(module, builder, func_node, func, ret_type, map_sym_tab): @@ -121,10 +82,11 @@ def process_func_body(module, builder, func_node, func, ret_type, map_sym_tab): for stmt in func_node.body: if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call): call = stmt.value - if isinstance(call.func, ast.Name) and call.func.id == "print": - bpf_printk_emitter(call, module, builder, func) - if isinstance(call.func, ast.Name) and call.func.id == "ktime": - bpf_ktime_get_ns_emitter(call, module, builder, func) + if isinstance(call.func, ast.Name): + # check for helpers first + if call.func.id in helper_func_list: + handle_helper_call( + call, module, builder, func, local_sym_tab, map_sym_tab) elif isinstance(stmt, ast.Assign): handle_assign(module, builder, stmt, map_sym_tab, local_sym_tab) elif isinstance(stmt, ast.Return):