feat:user defined struct casting

This commit is contained in:
2025-11-27 12:41:57 +05:30
parent 4905649700
commit 1593b7bcfe
6 changed files with 176 additions and 63 deletions

View File

@ -114,9 +114,18 @@ def _allocate_for_call(
# Struct constructors
elif call_type in structs_sym_tab:
struct_info = structs_sym_tab[call_type]
var = builder.alloca(struct_info.ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
if len(rval.args) == 0:
# Zero-arg constructor: allocate the struct itself
var = builder.alloca(struct_info.ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
else:
# Pointer cast: allocate as pointer to struct
ptr_type = ir.PointerType(struct_info.ir_type)
var = builder.alloca(ptr_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ptr_type, call_type)
logger.info(f"Pre-allocated {var_name} for struct pointer cast to {call_type}")
elif VmlinuxHandlerRegistry.is_vmlinux_struct(call_type):
# When calling struct_name(pointer), we're doing a cast, not construction

View File

@ -174,6 +174,23 @@ def handle_variable_assignment(
f"Type mismatch: vmlinux struct pointer requires i64, got {var_type}"
)
return False
# Handle user-defined struct pointer casts
# val_type is a string (struct name), var_type is a pointer to the struct
if isinstance(val_type, str) and val_type in structs_sym_tab:
struct_info = structs_sym_tab[val_type]
expected_ptr_type = ir.PointerType(struct_info.ir_type)
# Check if var_type matches the expected pointer type
if isinstance(var_type, ir.PointerType):
# val is already the correct pointer type from inttoptr/bitcast
builder.store(val, var_ptr)
logger.info(f"Assigned user-defined struct pointer cast to {var_name}")
return True
else:
logger.error(
f"Type mismatch: user-defined struct pointer cast requires pointer type, got {var_type}"
)
return False
if isinstance(val_type, Field):
logger.info("Handling assignment to struct field")
# Special handling for struct_xdp_md i32 fields that are zero-extended to i64

View File

@ -618,7 +618,7 @@ def _handle_boolean_op(
# ============================================================================
# VMLinux casting
# Struct casting (including vmlinux struct casting)
# ============================================================================
@ -667,7 +667,7 @@ def _handle_vmlinux_cast(
# If arg_val is an integer type, we need to inttoptr it
ptr_type = ir.PointerType()
# TODO: add a field value type check here
print(arg_type)
# print(arg_type)
if isinstance(arg_type, Field):
if ctypes_to_ir(arg_type.type.__name__):
# Cast integer to pointer
@ -681,6 +681,69 @@ def _handle_vmlinux_cast(
return casted_ptr, vmlinux_struct_type
def _handle_user_defined_struct_cast(
func,
module,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
):
"""Handle user-defined struct cast expressions like iphdr(nh).
This casts a pointer/integer value to a pointer to the user-defined struct,
similar to how vmlinux struct casts work but for user-defined @struct types.
"""
if len(expr.args) != 1:
logger.info("User-defined struct cast takes exactly one argument")
return None
# Get the struct name
struct_name = expr.func.id
if struct_name not in structs_sym_tab:
logger.error(f"Struct {struct_name} not found in structs_sym_tab")
return None
struct_info = structs_sym_tab[struct_name]
# Evaluate the argument (e.g.,
# an address/pointer value)
arg_result = eval_expr(
func,
module,
builder,
expr.args[0],
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if arg_result is None:
logger.info("Failed to evaluate argument to user-defined struct cast")
return None
arg_val, arg_type = arg_result
# Cast the integer/pointer value to a pointer to the struct type
# The struct pointer type is a pointer to the struct's IR type
struct_ptr_type = ir.PointerType(struct_info.ir_type)
# If arg_val is an integer type (like i64), convert to pointer using inttoptr
if isinstance(arg_val.type, ir.IntType):
casted_ptr = builder.inttoptr(arg_val, struct_ptr_type)
logger.info(f"Cast integer to pointer for struct {struct_name}")
elif isinstance(arg_val.type, ir.PointerType):
# If already a pointer, bitcast to the struct pointer type
casted_ptr = builder.bitcast(arg_val, struct_ptr_type)
logger.info(f"Bitcast pointer to struct pointer for {struct_name}")
else:
logger.error(f"Unsupported type for user-defined struct cast: {arg_val.type}")
return None
return casted_ptr, struct_name
# ============================================================================
# Expression Dispatcher
# ============================================================================
@ -726,6 +789,16 @@ def eval_expr(
map_sym_tab,
structs_sym_tab,
)
if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab):
return _handle_user_defined_struct_cast(
func,
module,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
result = CallHandlerRegistry.handle_call(
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab

View File

@ -20,6 +20,8 @@ mapping = {
"c_int": ir.IntType(32),
"c_ushort": ir.IntType(16),
"c_short": ir.IntType(16),
"c_ubyte": ir.IntType(8),
"c_byte": ir.IntType(8),
# Not so sure about this one
"str": ir.PointerType(ir.IntType(8)),
}

View File

@ -1,72 +1,38 @@
// xdp_ip_map.c
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <bpf/bpf_helpers.h>
struct ip_key {
__u8 family; // 4 = IPv4
__u8 pad[3]; // padding for alignment
__u8 addr[16]; // IPv4 uses first 4 bytes
struct fake_iphdr {
unsigned short useless;
unsigned short tot_len;
unsigned short id;
unsigned short frag_off;
unsigned char ttl;
unsigned char protocol;
unsigned short check;
unsigned int saddr;
unsigned int daddr;
};
// key → packet count
struct {
__uint(type, BPF_MAP_TYPE_HASH);
__uint(max_entries, 16384);
__type(key, struct ip_key);
__type(value, __u64);
} ip_count_map SEC(".maps");
SEC("xdp")
int xdp_ip_map(struct xdp_md *ctx)
int xdp_prog(struct xdp_md *ctx)
{
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
struct ethhdr *eth = data;
void *data_end = (void *)(long)ctx->data_end;
void *data = (void *)(long)ctx->data;
if (eth + 1 > (struct ethhdr *)data_end)
return XDP_PASS;
struct ethhdr *eth = data;
if ((void *)(eth + 1) > data_end)
return XDP_ABORTED;
if (eth->h_proto != __constant_htons(ETH_P_IP))
return XDP_PASS;
__u16 h_proto = eth->h_proto;
void *nh = data + sizeof(*eth);
struct fake_iphdr *iph = (struct fake_iphdr *)(eth + 1);
if ((void *)(iph + 1) > data_end)
return XDP_ABORTED;
bpf_printk("%d", iph->saddr);
// VLAN handling: single tag
if (h_proto == bpf_htons(ETH_P_8021Q) ||
h_proto == bpf_htons(ETH_P_8021AD)) {
if (nh + 4 > data_end)
return XDP_PASS;
h_proto = *(__u16 *)(nh + 2);
nh += 4;
}
struct ip_key key = {};
// IPv4
if (h_proto == bpf_htons(ETH_P_IP)) {
struct iphdr *iph = nh;
if (iph + 1 > (struct iphdr *)data_end)
return XDP_PASS;
key.family = 4;
// Copy 4 bytes of IPv4 address
__builtin_memcpy(key.addr, &iph->saddr, 4);
__u64 *val = bpf_map_lookup_elem(&ip_count_map, &key);
if (val)
(*val)++;
else {
__u64 init = 1;
bpf_map_update_elem(&ip_count_map, &key, &init, BPF_ANY);
}
return XDP_PASS;
}
return XDP_PASS;
return XDP_PASS;
}
char _license[] SEC("license") = "GPL";

View File

@ -0,0 +1,46 @@
from vmlinux import XDP_PASS, XDP_DROP
from vmlinux import (
struct_xdp_md,
struct_ethhdr,
)
from pythonbpf import bpf, section, bpfglobal, compile, compile_to_ir, struct
from ctypes import c_int64, c_ubyte, c_ushort, c_uint32
@bpf
@struct
class iphdr:
useless: c_ushort
tot_len: c_ushort
id: c_ushort
frag_off: c_ushort
ttl: c_ubyte
protocol: c_ubyte
check: c_ushort
saddr: c_uint32
daddr: c_uint32
@bpf
@section("xdp")
def ip_detector(ctx: struct_xdp_md) -> c_int64:
data = ctx.data
data_end = ctx.data_end
eth = struct_ethhdr(ctx.data)
nh = ctx.data + 14
if nh + 20 > data_end:
return c_int64(XDP_DROP)
iph = iphdr(nh)
h_proto = eth.h_proto
h_proto_ext = c_int64(h_proto)
ipv4 = iph.saddr
print(f"ipaddress: {ipv4}")
return c_int64(XDP_PASS)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("xdp_test_1.py", "xdp_test_1.ll")
compile()