diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index a096739..56c039f 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -121,11 +121,10 @@ def _allocate_for_call( elif VmlinuxHandlerRegistry.is_vmlinux_struct(call_type): # When calling struct_name(pointer), we're doing a cast, not construction # So we allocate as a pointer (i64) not as the actual struct - ir_type = ir.IntType(64) # Pointer type - var = builder.alloca(ir_type, name=var_name) + var = builder.alloca(ir.PointerType(), name=var_name) var.align = 8 local_sym_tab[var_name] = LocalSymbol( - var, ir_type, VmlinuxHandlerRegistry.get_struct_type(call_type) + var, ir.PointerType(), VmlinuxHandlerRegistry.get_struct_type(call_type) ) logger.info( f"Pre-allocated {var_name} for vmlinux struct pointer cast to {call_type}" @@ -340,11 +339,11 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ field_ir, field = field_type # TODO: For now, we only support integer type allocations. # This always assumes first argument of function to be the context struct - base_ptr = builder.function.args[0] - local_sym_tab[ - struct_var - ].var = base_ptr # This is repurposing of var to store the pointer of the base type - local_sym_tab[struct_var].ir_type = field_ir + # base_ptr = builder.function.args[0] + # local_sym_tab[ + # struct_var + # ].var = base_ptr # This is repurposing of var to store the pointer of the base type + # local_sym_tab[struct_var].ir_type = field_ir # Determine the actual IR type based on the field's type actual_ir_type = None diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py index 0bd48c6..fc84238 100644 --- a/pythonbpf/assign_pass.py +++ b/pythonbpf/assign_pass.py @@ -1,5 +1,7 @@ import ast import logging +from inspect import isclass + from llvmlite import ir from pythonbpf.expr import eval_expr from pythonbpf.helper import emit_probe_read_kernel_str_call @@ -150,6 +152,9 @@ def handle_variable_assignment( val, val_type = val_result logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}") if val_type != var_type: + # if isclass(val_type) and (val_type.__module__ == "vmlinux"): + # logger.info("Handling typecast to vmlinux struct") + # print(val_type, var_type) if isinstance(val_type, Field): logger.info("Handling assignment to struct field") # Special handling for struct_xdp_md i32 fields that are zero-extended to i64 diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index a9eab98..335a764 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -524,6 +524,64 @@ def _handle_boolean_op( logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}") return None +# ============================================================================ +# VMLinux casting +# ============================================================================ + +def _handle_vmlinux_cast( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + # handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux + # struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64 + # which needs to be cast to a pointer. This is also a field of another vmlinux struct + """Handle vmlinux struct cast expressions like struct_request(ctx.di).""" + if len(expr.args) != 1: + logger.info("vmlinux struct cast takes exactly one argument") + return None + + # Get the struct name + struct_name = expr.func.id + + # Evaluate the argument (e.g., ctx.di which is a c_uint64) + arg_result = eval_expr( + func, + module, + builder, + expr.args[0], + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + + if arg_result is None: + logger.info("Failed to evaluate argument to vmlinux struct cast") + return None + + arg_val, arg_type = arg_result + # Get the vmlinux struct type + vmlinux_struct_type = VmlinuxHandlerRegistry.get_struct_type(struct_name) + if vmlinux_struct_type is None: + logger.error(f"Failed to get vmlinux struct type for {struct_name}") + return None + # Cast the integer/value to a pointer to the struct + # If arg_val is an integer type, we need to inttoptr it + ptr_type = ir.PointerType() + #TODO: add a integer check here later + if ctypes_to_ir(arg_type.type.__name__): + # Cast integer to pointer + casted_ptr = builder.inttoptr(arg_val, ptr_type) + else: + logger.error(f"Unsupported type for vmlinux cast: {arg_type}") + return None + + return casted_ptr, vmlinux_struct_type + # ============================================================================ # Expression Dispatcher @@ -545,6 +603,16 @@ def eval_expr( elif isinstance(expr, ast.Constant): return _handle_constant_expr(module, builder, expr) elif isinstance(expr, ast.Call): + if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(expr.func.id): + return _handle_vmlinux_cast( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) if isinstance(expr.func, ast.Name) and expr.func.id == "deref": return _handle_deref_call(expr, local_sym_tab, builder) diff --git a/tests/c-form/requests.bpf.c b/tests/c-form/requests.bpf.c index 0e14e98..55b1239 100644 --- a/tests/c-form/requests.bpf.c +++ b/tests/c-form/requests.bpf.c @@ -1,15 +1,18 @@ #include "vmlinux.h" #include #include +#include char LICENSE[] SEC("license") = "GPL"; SEC("kprobe/blk_mq_start_request") int example(struct pt_regs *ctx) { + u64 a = ctx->r15; struct request *req = (struct request *)(ctx->di); - u32 data_len = req->__data_len; - bpf_printk("data length %u\n", data_len); + unsigned int something_ns = BPF_CORE_READ(req, timeout); + unsigned int data_len = BPF_CORE_READ(req, __data_len); + bpf_printk("data length %lld %ld %ld\n", data_len, something_ns, a); return 0; } diff --git a/tests/failing_tests/vmlinux/requests.py b/tests/failing_tests/vmlinux/requests.py index bab809f..a32636e 100644 --- a/tests/failing_tests/vmlinux/requests.py +++ b/tests/failing_tests/vmlinux/requests.py @@ -1,5 +1,5 @@ from vmlinux import struct_request, struct_pt_regs -from pythonbpf import bpf, section, bpfglobal, compile_to_ir +from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile import logging from ctypes import c_int64 @@ -7,9 +7,11 @@ from ctypes import c_int64 @bpf @section("kprobe/blk_mq_start_request") def example(ctx: struct_pt_regs) -> c_int64: + a = ctx.r15 req = struct_request(ctx.di) - c = req.__data_len - print(f"data length {c}") + d = req.__data_len + c = req.timeout + print(f"data length {d} and {c} and {a}") return c_int64(0) @@ -20,3 +22,4 @@ def LICENSE() -> str: compile_to_ir("requests.py", "requests.ll", loglevel=logging.INFO) +compile()