diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index 3149c75..a0559a2 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -51,7 +51,7 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): return # When allocating a variable, check if it's a vmlinux struct type - if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct( + if isinstance(rval, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct( stmt.value.id ): # Handle vmlinux struct allocation diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index c3dd45f..0bf8c2d 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -36,6 +36,7 @@ def finalize_module(original_str): replacement = r'\1 "btf_ama"' return re.sub(pattern, replacement, original_str) + def bpf_passthrough_gen(module): i32_ty = ir.IntType(32) ptr_ty = ir.PointerType(ir.IntType(8)) diff --git a/pythonbpf/expr/vmlinux_registry.py b/pythonbpf/expr/vmlinux_registry.py index 9e9d52e..7c25095 100644 --- a/pythonbpf/expr/vmlinux_registry.py +++ b/pythonbpf/expr/vmlinux_registry.py @@ -1,5 +1,7 @@ import ast +from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler + class VmlinuxHandlerRegistry: """Registry for vmlinux handler operations""" @@ -7,7 +9,7 @@ class VmlinuxHandlerRegistry: _handler = None @classmethod - def set_handler(cls, handler): + def set_handler(cls, handler: VmlinuxHandler): """Set the vmlinux handler""" cls._handler = handler @@ -43,3 +45,10 @@ class VmlinuxHandlerRegistry: if cls._handler is None: return False return cls._handler.is_vmlinux_struct(name) + + @classmethod + def get_struct_type(cls, name): + """Try to handle a struct name as vmlinux struct""" + if cls._handler is None: + return None + return cls._handler.get_vmlinux_struct_type(name) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 8243344..359efe1 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -7,7 +7,12 @@ from pythonbpf.helper import ( reset_scratch_pool, ) from pythonbpf.type_deducer import ctypes_to_ir -from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool +from pythonbpf.expr import ( + eval_expr, + handle_expr, + convert_to_bool, + VmlinuxHandlerRegistry, +) from pythonbpf.assign_pass import ( handle_variable_assignment, handle_struct_field_assignment, @@ -337,6 +342,35 @@ def process_func_body( structs_sym_tab, ) + # Add the context parameter (first function argument) to the local symbol table + if func_node.args.args and len(func_node.args.args) > 0: + context_arg = func_node.args.args[0] + context_name = context_arg.arg + + if hasattr(context_arg, "annotation") and context_arg.annotation: + if isinstance(context_arg.annotation, ast.Name): + context_type_name = context_arg.annotation.id + elif isinstance(context_arg.annotation, ast.Attribute): + context_type_name = context_arg.annotation.attr + else: + raise TypeError( + f"Unsupported annotation type: {ast.dump(context_arg.annotation)}" + ) + if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name): + resolved_type = VmlinuxHandlerRegistry.get_struct_type( + context_type_name + ) + context_type = {"type": ir.PointerType(resolved_type), "ptr": True} + else: + try: + resolved_type = ctypes_to_ir(context_type_name) + context_type = {"type": ir.PointerType(resolved_type), "ptr": True} + except Exception: + raise TypeError(f"Type '{context_type_name}' not declared") + + local_sym_tab[context_name] = context_type + logger.info(f"Added argument '{context_name}' to local symbol table") + logger.info(f"Local symbol table: {local_sym_tab.keys()}") for stmt in func_node.body: diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index 6e29614..5b06e5b 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -73,9 +73,9 @@ def bpf_map_lookup_elem_emitter( map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) # TODO: I have changed the return type to i64*, as we are - # allocating space for that type in allocate_mem. This is - # temporary, and we will honour other widths later. But this - # allows us to have cool binary ops on the returned value. + # allocating space for that type in allocate_mem. This is + # temporary, and we will honour other widths later. But this + # allows us to have cool binary ops on the returned value. fn_type = ir.FunctionType( ir.PointerType(ir.IntType(64)), # Return type: void* [ir.PointerType(), ir.PointerType()], # Args: (void*, void*) diff --git a/pythonbpf/vmlinux_parser/import_detector.py b/pythonbpf/vmlinux_parser/import_detector.py index 549e5f2..7cffae6 100644 --- a/pythonbpf/vmlinux_parser/import_detector.py +++ b/pythonbpf/vmlinux_parser/import_detector.py @@ -2,7 +2,6 @@ import ast import logging import importlib import inspect -import llvmlite.ir as ir from .assignment_info import AssignmentInfo, AssignmentType from .dependency_handler import DependencyHandler diff --git a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py index 1986b44..e244cfc 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py @@ -39,6 +39,16 @@ class VmlinuxHandler: and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT ) + def get_vmlinux_struct_type(self, name): + """Check if name is a vmlinux struct type""" + if ( + name in self.vmlinux_symtab + and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT + ): + return self.vmlinux_symtab[name]["python_type"] + else: + raise ValueError(f"{name} is not a vmlinux struct type") + def is_vmlinux_struct(self, name): """Check if name is a vmlinux struct""" return ( diff --git a/tests/failing_tests/vmlinux/struct_field_access.py b/tests/failing_tests/vmlinux/struct_field_access.py index f1f33cd..f3db2bf 100644 --- a/tests/failing_tests/vmlinux/struct_field_access.py +++ b/tests/failing_tests/vmlinux/struct_field_access.py @@ -4,7 +4,8 @@ from pythonbpf import bpf, section, bpfglobal, compile_to_ir from pythonbpf import compile # noqa: F401 from vmlinux import TASK_COMM_LEN # noqa: F401 from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 -from ctypes import c_int64 +from ctypes import c_int64, c_void_p # noqa: F401 + # from vmlinux import struct_uinput_device # from vmlinux import struct_blk_integrity_iter @@ -14,7 +15,9 @@ from ctypes import c_int64 @section("tracepoint/syscalls/sys_enter_execve") def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: a = 2 + TASK_COMM_LEN + TASK_COMM_LEN + # b = ctx print(f"Hello, World{TASK_COMM_LEN} and {a}") + # print(f"This is context field {b}") return c_int64(TASK_COMM_LEN + 2)