mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
feat:user defined struct casting
This commit is contained in:
@ -114,9 +114,18 @@ def _allocate_for_call(
|
||||
# Struct constructors
|
||||
elif call_type in structs_sym_tab:
|
||||
struct_info = structs_sym_tab[call_type]
|
||||
var = builder.alloca(struct_info.ir_type, name=var_name)
|
||||
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
|
||||
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
|
||||
if len(rval.args) == 0:
|
||||
# Zero-arg constructor: allocate the struct itself
|
||||
var = builder.alloca(struct_info.ir_type, name=var_name)
|
||||
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
|
||||
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
|
||||
else:
|
||||
# Pointer cast: allocate as pointer to struct
|
||||
ptr_type = ir.PointerType(struct_info.ir_type)
|
||||
var = builder.alloca(ptr_type, name=var_name)
|
||||
var.align = 8
|
||||
local_sym_tab[var_name] = LocalSymbol(var, ptr_type, call_type)
|
||||
logger.info(f"Pre-allocated {var_name} for struct pointer cast to {call_type}")
|
||||
|
||||
elif VmlinuxHandlerRegistry.is_vmlinux_struct(call_type):
|
||||
# When calling struct_name(pointer), we're doing a cast, not construction
|
||||
|
||||
@ -174,6 +174,23 @@ def handle_variable_assignment(
|
||||
f"Type mismatch: vmlinux struct pointer requires i64, got {var_type}"
|
||||
)
|
||||
return False
|
||||
# Handle user-defined struct pointer casts
|
||||
# val_type is a string (struct name), var_type is a pointer to the struct
|
||||
if isinstance(val_type, str) and val_type in structs_sym_tab:
|
||||
struct_info = structs_sym_tab[val_type]
|
||||
expected_ptr_type = ir.PointerType(struct_info.ir_type)
|
||||
|
||||
# Check if var_type matches the expected pointer type
|
||||
if isinstance(var_type, ir.PointerType):
|
||||
# val is already the correct pointer type from inttoptr/bitcast
|
||||
builder.store(val, var_ptr)
|
||||
logger.info(f"Assigned user-defined struct pointer cast to {var_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
f"Type mismatch: user-defined struct pointer cast requires pointer type, got {var_type}"
|
||||
)
|
||||
return False
|
||||
if isinstance(val_type, Field):
|
||||
logger.info("Handling assignment to struct field")
|
||||
# Special handling for struct_xdp_md i32 fields that are zero-extended to i64
|
||||
|
||||
@ -618,7 +618,7 @@ def _handle_boolean_op(
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# VMLinux casting
|
||||
# Struct casting (including vmlinux struct casting)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@ -667,7 +667,7 @@ def _handle_vmlinux_cast(
|
||||
# If arg_val is an integer type, we need to inttoptr it
|
||||
ptr_type = ir.PointerType()
|
||||
# TODO: add a field value type check here
|
||||
print(arg_type)
|
||||
# print(arg_type)
|
||||
if isinstance(arg_type, Field):
|
||||
if ctypes_to_ir(arg_type.type.__name__):
|
||||
# Cast integer to pointer
|
||||
@ -681,6 +681,69 @@ def _handle_vmlinux_cast(
|
||||
return casted_ptr, vmlinux_struct_type
|
||||
|
||||
|
||||
def _handle_user_defined_struct_cast(
|
||||
func,
|
||||
module,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
):
|
||||
"""Handle user-defined struct cast expressions like iphdr(nh).
|
||||
|
||||
This casts a pointer/integer value to a pointer to the user-defined struct,
|
||||
similar to how vmlinux struct casts work but for user-defined @struct types.
|
||||
"""
|
||||
if len(expr.args) != 1:
|
||||
logger.info("User-defined struct cast takes exactly one argument")
|
||||
return None
|
||||
|
||||
# Get the struct name
|
||||
struct_name = expr.func.id
|
||||
|
||||
if struct_name not in structs_sym_tab:
|
||||
logger.error(f"Struct {struct_name} not found in structs_sym_tab")
|
||||
return None
|
||||
|
||||
struct_info = structs_sym_tab[struct_name]
|
||||
|
||||
# Evaluate the argument (e.g.,
|
||||
# an address/pointer value)
|
||||
arg_result = eval_expr(
|
||||
func,
|
||||
module,
|
||||
builder,
|
||||
expr.args[0],
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
if arg_result is None:
|
||||
logger.info("Failed to evaluate argument to user-defined struct cast")
|
||||
return None
|
||||
|
||||
arg_val, arg_type = arg_result
|
||||
|
||||
# Cast the integer/pointer value to a pointer to the struct type
|
||||
# The struct pointer type is a pointer to the struct's IR type
|
||||
struct_ptr_type = ir.PointerType(struct_info.ir_type)
|
||||
|
||||
# If arg_val is an integer type (like i64), convert to pointer using inttoptr
|
||||
if isinstance(arg_val.type, ir.IntType):
|
||||
casted_ptr = builder.inttoptr(arg_val, struct_ptr_type)
|
||||
logger.info(f"Cast integer to pointer for struct {struct_name}")
|
||||
elif isinstance(arg_val.type, ir.PointerType):
|
||||
# If already a pointer, bitcast to the struct pointer type
|
||||
casted_ptr = builder.bitcast(arg_val, struct_ptr_type)
|
||||
logger.info(f"Bitcast pointer to struct pointer for {struct_name}")
|
||||
else:
|
||||
logger.error(f"Unsupported type for user-defined struct cast: {arg_val.type}")
|
||||
return None
|
||||
|
||||
return casted_ptr, struct_name
|
||||
|
||||
# ============================================================================
|
||||
# Expression Dispatcher
|
||||
# ============================================================================
|
||||
@ -726,6 +789,16 @@ def eval_expr(
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab):
|
||||
return _handle_user_defined_struct_cast(
|
||||
func,
|
||||
module,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
result = CallHandlerRegistry.handle_call(
|
||||
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
|
||||
@ -20,6 +20,8 @@ mapping = {
|
||||
"c_int": ir.IntType(32),
|
||||
"c_ushort": ir.IntType(16),
|
||||
"c_short": ir.IntType(16),
|
||||
"c_ubyte": ir.IntType(8),
|
||||
"c_byte": ir.IntType(8),
|
||||
# Not so sure about this one
|
||||
"str": ir.PointerType(ir.IntType(8)),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user