diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index b96a9cf..4955f3f 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -1,6 +1,6 @@ import ast import logging - +import ctypes from llvmlite import ir from .local_symbol import LocalSymbol from pythonbpf.helper import HelperHandlerRegistry @@ -249,7 +249,46 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ ].var = base_ptr # This is repurposing of var to store the pointer of the base type local_sym_tab[struct_var].ir_type = field_ir - actual_ir_type = ir.IntType(64) + # Determine the actual IR type based on the field's type + actual_ir_type = None + + # Check if it's a ctypes primitive + if field.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + actual_ir_type = ir.IntType(field_size_bits) + else: + logger.warning( + f"Unusual field size {field_size_bits} bits for {field_name}" + ) + actual_ir_type = ir.IntType(64) + except Exception as e: + logger.warning( + f"Could not determine size for ctypes field {field_name}: {e}" + ) + actual_ir_type = ir.IntType(64) + + # Check if it's a nested vmlinux struct or complex type + elif field.type.__module__ == "vmlinux": + # For pointers to structs, use pointer type (64-bit) + if field.ctype_complex_type is not None and issubclass( + field.ctype_complex_type, ctypes._Pointer + ): + actual_ir_type = ir.IntType(64) # Pointer is always 64-bit + # For embedded structs, this is more complex - might need different handling + else: + logger.warning( + f"Field {field_name} is a nested vmlinux struct, using i64 for now" + ) + actual_ir_type = ir.IntType(64) + else: + logger.warning( + f"Unknown field type module {field.type.__module__} for {field_name}" + ) + actual_ir_type = ir.IntType(64) # Allocate with the actual IR type, not the GlobalVariable var = _allocate_with_type(builder, var_name, actual_ir_type) diff --git a/pythonbpf/functions/function_debug_info.py b/pythonbpf/functions/function_debug_info.py index f924ebc..e22e9ad 100644 --- a/pythonbpf/functions/function_debug_info.py +++ b/pythonbpf/functions/function_debug_info.py @@ -59,7 +59,6 @@ def generate_function_debug_info( leading_argument_name, 1, pointer_to_context_debug_info ) retained_nodes = [context_local_variable] - print("function name", func_node.name) subprogram_debug_info = generator.create_subprogram( func_node.name, subroutine_type, retained_nodes ) diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index a6834a9..fd589ae 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -16,6 +16,8 @@ mapping = { "c_long": ir.IntType(64), "c_ulong": ir.IntType(64), "c_longlong": ir.IntType(64), + "c_uint": ir.IntType(32), + "c_int": ir.IntType(32), # Not so sure about this one "str": ir.PointerType(ir.IntType(8)), } diff --git a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py index 62c0327..6ec5988 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py @@ -1,6 +1,6 @@ import logging from typing import Any - +import ctypes from llvmlite import ir from pythonbpf.local_symbol import LocalSymbol @@ -98,18 +98,16 @@ class VmlinuxHandler: python_type.__name__, field_name ) builder.function.args[0].type = ir.PointerType(ir.IntType(8)) - print(builder.function.args[0]) field_ptr = self.load_ctx_field( - builder, builder.function.args[0], globvar_ir + builder, builder.function.args[0], globvar_ir, field_data ) - print(field_ptr) # Return pointer to field and field type return field_ptr, field_data else: raise RuntimeError("Variable accessed not found in symbol table") @staticmethod - def load_ctx_field(builder, ctx_arg, offset_global): + def load_ctx_field(builder, ctx_arg, offset_global, field_data): """ Generate LLVM IR to load a field from BPF context using offset. @@ -117,7 +115,7 @@ class VmlinuxHandler: builder: llvmlite IRBuilder instance ctx_arg: The context pointer argument (ptr/i8*) offset_global: Global variable containing the field offset (i64) - + field_data: contains data about the field Returns: The loaded value (i64 register) """ @@ -164,9 +162,43 @@ class VmlinuxHandler: passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True ) - # Bitcast to i64* (assuming field is 64-bit, adjust if needed) - i64_ptr_type = ir.PointerType(ir.IntType(64)) - typed_ptr = builder.bitcast(verified_ptr, i64_ptr_type) + # Determine the appropriate IR type based on field information + int_width = 64 # Default to 64-bit + + if field_data is not None: + # Try to determine the size from field metadata + if field_data.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field_data.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + int_width = field_size_bits + logger.info(f"Determined field size: {int_width} bits") + else: + logger.warning( + f"Unusual field size {field_size_bits} bits, using default 64" + ) + except Exception as e: + logger.warning( + f"Could not determine field size: {e}, using default 64" + ) + + elif field_data.type.__module__ == "vmlinux": + # For pointers to structs or complex vmlinux types + if field_data.ctype_complex_type is not None and issubclass( + field_data.ctype_complex_type, ctypes._Pointer + ): + int_width = 64 # Pointers are always 64-bit + logger.info("Field is a pointer type, using 64 bits") + # TODO: Add handling for other complex types (arrays, embedded structs, etc.) + else: + logger.warning("Complex vmlinux field type, using default 64 bits") + + # Bitcast to appropriate pointer type based on determined width + ptr_type = ir.PointerType(ir.IntType(int_width)) + + typed_ptr = builder.bitcast(verified_ptr, ptr_type) # Load and return the value value = builder.load(typed_ptr) diff --git a/tests/c-form/i32test.bpf.c b/tests/c-form/i32test.bpf.c new file mode 100644 index 0000000..71b479f --- /dev/null +++ b/tests/c-form/i32test.bpf.c @@ -0,0 +1,15 @@ +#include +#include + +SEC("xdp") +int print_xdp_data(struct xdp_md *ctx) +{ + // 'data' is a pointer to the start of packet data + void *data = (void *)(long)ctx->data; + + bpf_printk("ctx->data = %p\n", data); + + return XDP_PASS; +} + +char LICENSE[] SEC("license") = "GPL"; diff --git a/tests/failing_tests/vmlinux/args_test.py b/tests/failing_tests/vmlinux/args_test.py new file mode 100644 index 0000000..7acca25 --- /dev/null +++ b/tests/failing_tests/vmlinux/args_test.py @@ -0,0 +1,30 @@ +import logging + +from pythonbpf import bpf, section, bpfglobal, compile_to_ir +from pythonbpf import compile # noqa: F401 +from vmlinux import TASK_COMM_LEN # noqa: F401 +from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 +from ctypes import c_int64, c_int32, c_void_p # noqa: F401 + + +# from vmlinux import struct_uinput_device +# from vmlinux import struct_blk_integrity_iter + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: + b = ctx.args + c = b[0] + print(f"This is context args field {c}") + return c_int64(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("args_test.py", "args_test.ll", loglevel=logging.INFO) +compile() diff --git a/tests/failing_tests/vmlinux/i32_test.py b/tests/failing_tests/vmlinux/i32_test.py new file mode 100644 index 0000000..86ce2b7 --- /dev/null +++ b/tests/failing_tests/vmlinux/i32_test.py @@ -0,0 +1,23 @@ +from ctypes import c_int64, c_int32 +from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = ctx.data # 32-bit field: packet start pointer + something = c_int32(2 + data) + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("i32_test.py", "i32_test.ll") +compile() diff --git a/tests/passing_tests/vmlinux/simple_struct_test.py b/tests/passing_tests/vmlinux/simple_struct_test.py index 97ab54a..2f34ba4 100644 --- a/tests/passing_tests/vmlinux/simple_struct_test.py +++ b/tests/passing_tests/vmlinux/simple_struct_test.py @@ -44,4 +44,4 @@ def LICENSE() -> str: compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG) -# compile() +compile()