From 1593b7bcfe03f1e38d1aaea55963979cf8074b84 Mon Sep 17 00:00:00 2001 From: varun-r-mallya Date: Thu, 27 Nov 2025 12:41:57 +0530 Subject: [PATCH] feat:user defined struct casting --- pythonbpf/allocation_pass.py | 15 ++++- pythonbpf/assign_pass.py | 17 ++++++ pythonbpf/expr/expr_pass.py | 77 ++++++++++++++++++++++++- pythonbpf/type_deducer.py | 2 + tests/c-form/xdp_test.bpf.c | 82 ++++++++------------------- tests/failing_tests/xdp/xdp_test_1.py | 46 +++++++++++++++ 6 files changed, 176 insertions(+), 63 deletions(-) create mode 100644 tests/failing_tests/xdp/xdp_test_1.py diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index fc6f21e..9f5cb20 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -114,9 +114,18 @@ def _allocate_for_call( # Struct constructors elif call_type in structs_sym_tab: struct_info = structs_sym_tab[call_type] - var = builder.alloca(struct_info.ir_type, name=var_name) - local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) - logger.info(f"Pre-allocated {var_name} for struct {call_type}") + if len(rval.args) == 0: + # Zero-arg constructor: allocate the struct itself + var = builder.alloca(struct_info.ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type) + logger.info(f"Pre-allocated {var_name} for struct {call_type}") + else: + # Pointer cast: allocate as pointer to struct + ptr_type = ir.PointerType(struct_info.ir_type) + var = builder.alloca(ptr_type, name=var_name) + var.align = 8 + local_sym_tab[var_name] = LocalSymbol(var, ptr_type, call_type) + logger.info(f"Pre-allocated {var_name} for struct pointer cast to {call_type}") elif VmlinuxHandlerRegistry.is_vmlinux_struct(call_type): # When calling struct_name(pointer), we're doing a cast, not construction diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py index 5d73cf3..87e5657 100644 --- a/pythonbpf/assign_pass.py +++ b/pythonbpf/assign_pass.py @@ -174,6 +174,23 @@ def handle_variable_assignment( f"Type mismatch: vmlinux struct pointer requires i64, got {var_type}" ) return False + # Handle user-defined struct pointer casts + # val_type is a string (struct name), var_type is a pointer to the struct + if isinstance(val_type, str) and val_type in structs_sym_tab: + struct_info = structs_sym_tab[val_type] + expected_ptr_type = ir.PointerType(struct_info.ir_type) + + # Check if var_type matches the expected pointer type + if isinstance(var_type, ir.PointerType): + # val is already the correct pointer type from inttoptr/bitcast + builder.store(val, var_ptr) + logger.info(f"Assigned user-defined struct pointer cast to {var_name}") + return True + else: + logger.error( + f"Type mismatch: user-defined struct pointer cast requires pointer type, got {var_type}" + ) + return False if isinstance(val_type, Field): logger.info("Handling assignment to struct field") # Special handling for struct_xdp_md i32 fields that are zero-extended to i64 diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index abafb86..f52e924 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -618,7 +618,7 @@ def _handle_boolean_op( # ============================================================================ -# VMLinux casting +# Struct casting (including vmlinux struct casting) # ============================================================================ @@ -667,7 +667,7 @@ def _handle_vmlinux_cast( # If arg_val is an integer type, we need to inttoptr it ptr_type = ir.PointerType() # TODO: add a field value type check here - print(arg_type) + # print(arg_type) if isinstance(arg_type, Field): if ctypes_to_ir(arg_type.type.__name__): # Cast integer to pointer @@ -681,6 +681,69 @@ def _handle_vmlinux_cast( return casted_ptr, vmlinux_struct_type +def _handle_user_defined_struct_cast( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab, +): + """Handle user-defined struct cast expressions like iphdr(nh). + + This casts a pointer/integer value to a pointer to the user-defined struct, + similar to how vmlinux struct casts work but for user-defined @struct types. + """ + if len(expr.args) != 1: + logger.info("User-defined struct cast takes exactly one argument") + return None + + # Get the struct name + struct_name = expr.func.id + + if struct_name not in structs_sym_tab: + logger.error(f"Struct {struct_name} not found in structs_sym_tab") + return None + + struct_info = structs_sym_tab[struct_name] + + # Evaluate the argument (e.g., + # an address/pointer value) + arg_result = eval_expr( + func, + module, + builder, + expr.args[0], + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + + if arg_result is None: + logger.info("Failed to evaluate argument to user-defined struct cast") + return None + + arg_val, arg_type = arg_result + + # Cast the integer/pointer value to a pointer to the struct type + # The struct pointer type is a pointer to the struct's IR type + struct_ptr_type = ir.PointerType(struct_info.ir_type) + + # If arg_val is an integer type (like i64), convert to pointer using inttoptr + if isinstance(arg_val.type, ir.IntType): + casted_ptr = builder.inttoptr(arg_val, struct_ptr_type) + logger.info(f"Cast integer to pointer for struct {struct_name}") + elif isinstance(arg_val.type, ir.PointerType): + # If already a pointer, bitcast to the struct pointer type + casted_ptr = builder.bitcast(arg_val, struct_ptr_type) + logger.info(f"Bitcast pointer to struct pointer for {struct_name}") + else: + logger.error(f"Unsupported type for user-defined struct cast: {arg_val.type}") + return None + + return casted_ptr, struct_name + # ============================================================================ # Expression Dispatcher # ============================================================================ @@ -726,6 +789,16 @@ def eval_expr( map_sym_tab, structs_sym_tab, ) + if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab): + return _handle_user_defined_struct_cast( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) result = CallHandlerRegistry.handle_call( expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index 734b80b..2e4c77f 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -20,6 +20,8 @@ mapping = { "c_int": ir.IntType(32), "c_ushort": ir.IntType(16), "c_short": ir.IntType(16), + "c_ubyte": ir.IntType(8), + "c_byte": ir.IntType(8), # Not so sure about this one "str": ir.PointerType(ir.IntType(8)), } diff --git a/tests/c-form/xdp_test.bpf.c b/tests/c-form/xdp_test.bpf.c index 152b0ae..0e90ce1 100644 --- a/tests/c-form/xdp_test.bpf.c +++ b/tests/c-form/xdp_test.bpf.c @@ -1,72 +1,38 @@ -// xdp_ip_map.c #include -#include -#include #include #include +#include -struct ip_key { - __u8 family; // 4 = IPv4 - __u8 pad[3]; // padding for alignment - __u8 addr[16]; // IPv4 uses first 4 bytes +struct fake_iphdr { + unsigned short useless; + unsigned short tot_len; + unsigned short id; + unsigned short frag_off; + unsigned char ttl; + unsigned char protocol; + unsigned short check; + unsigned int saddr; + unsigned int daddr; }; -// key → packet count -struct { - __uint(type, BPF_MAP_TYPE_HASH); - __uint(max_entries, 16384); - __type(key, struct ip_key); - __type(value, __u64); -} ip_count_map SEC(".maps"); - SEC("xdp") -int xdp_ip_map(struct xdp_md *ctx) +int xdp_prog(struct xdp_md *ctx) { - void *data_end = (void *)(long)ctx->data_end; - void *data = (void *)(long)ctx->data; - struct ethhdr *eth = data; + void *data_end = (void *)(long)ctx->data_end; + void *data = (void *)(long)ctx->data; - if (eth + 1 > (struct ethhdr *)data_end) - return XDP_PASS; + struct ethhdr *eth = data; + if ((void *)(eth + 1) > data_end) + return XDP_ABORTED; + if (eth->h_proto != __constant_htons(ETH_P_IP)) + return XDP_PASS; - __u16 h_proto = eth->h_proto; - void *nh = data + sizeof(*eth); + struct fake_iphdr *iph = (struct fake_iphdr *)(eth + 1); + if ((void *)(iph + 1) > data_end) + return XDP_ABORTED; + bpf_printk("%d", iph->saddr); - // VLAN handling: single tag - if (h_proto == bpf_htons(ETH_P_8021Q) || - h_proto == bpf_htons(ETH_P_8021AD)) { - - if (nh + 4 > data_end) - return XDP_PASS; - - h_proto = *(__u16 *)(nh + 2); - nh += 4; - } - - struct ip_key key = {}; - - // IPv4 - if (h_proto == bpf_htons(ETH_P_IP)) { - struct iphdr *iph = nh; - if (iph + 1 > (struct iphdr *)data_end) - return XDP_PASS; - - key.family = 4; - // Copy 4 bytes of IPv4 address - __builtin_memcpy(key.addr, &iph->saddr, 4); - - __u64 *val = bpf_map_lookup_elem(&ip_count_map, &key); - if (val) - (*val)++; - else { - __u64 init = 1; - bpf_map_update_elem(&ip_count_map, &key, &init, BPF_ANY); - } - - return XDP_PASS; - } - - return XDP_PASS; + return XDP_PASS; } char _license[] SEC("license") = "GPL"; diff --git a/tests/failing_tests/xdp/xdp_test_1.py b/tests/failing_tests/xdp/xdp_test_1.py new file mode 100644 index 0000000..b15414e --- /dev/null +++ b/tests/failing_tests/xdp/xdp_test_1.py @@ -0,0 +1,46 @@ +from vmlinux import XDP_PASS, XDP_DROP +from vmlinux import ( + struct_xdp_md, + struct_ethhdr, +) +from pythonbpf import bpf, section, bpfglobal, compile, compile_to_ir, struct +from ctypes import c_int64, c_ubyte, c_ushort, c_uint32 + +@bpf +@struct +class iphdr: + useless: c_ushort + tot_len: c_ushort + id: c_ushort + frag_off: c_ushort + ttl: c_ubyte + protocol: c_ubyte + check: c_ushort + saddr: c_uint32 + daddr: c_uint32 + +@bpf +@section("xdp") +def ip_detector(ctx: struct_xdp_md) -> c_int64: + data = ctx.data + data_end = ctx.data_end + eth = struct_ethhdr(ctx.data) + nh = ctx.data + 14 + if nh + 20 > data_end: + return c_int64(XDP_DROP) + iph = iphdr(nh) + h_proto = eth.h_proto + h_proto_ext = c_int64(h_proto) + ipv4 = iph.saddr + print(f"ipaddress: {ipv4}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("xdp_test_1.py", "xdp_test_1.ll") +compile()