Unify struct and pointer to struct handling, abstract null check in ir_ops

This commit is contained in:
Pragyansh Chaturvedi
2025-11-23 06:26:09 +05:30
parent fed6af1ed6
commit cbe365d760
5 changed files with 146 additions and 131 deletions

View File

@ -1,6 +1,6 @@
from .expr_pass import eval_expr, handle_expr, get_operand_value from .expr_pass import eval_expr, handle_expr, get_operand_value
from .type_normalization import convert_to_bool, get_base_type_and_depth from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth from .ir_ops import deref_to_depth, access_struct_field
from .call_registry import CallHandlerRegistry from .call_registry import CallHandlerRegistry
from .vmlinux_registry import VmlinuxHandlerRegistry from .vmlinux_registry import VmlinuxHandlerRegistry
@ -10,6 +10,7 @@ __all__ = [
"convert_to_bool", "convert_to_bool",
"get_base_type_and_depth", "get_base_type_and_depth",
"deref_to_depth", "deref_to_depth",
"access_struct_field",
"get_operand_value", "get_operand_value",
"CallHandlerRegistry", "CallHandlerRegistry",
"VmlinuxHandlerRegistry", "VmlinuxHandlerRegistry",

View File

@ -6,11 +6,11 @@ from typing import Dict
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
from .call_registry import CallHandlerRegistry from .call_registry import CallHandlerRegistry
from .ir_ops import deref_to_depth, access_struct_field
from .type_normalization import ( from .type_normalization import (
convert_to_bool, convert_to_bool,
handle_comparator, handle_comparator,
get_base_type_and_depth, get_base_type_and_depth,
deref_to_depth,
) )
from .vmlinux_registry import VmlinuxHandlerRegistry from .vmlinux_registry import VmlinuxHandlerRegistry
from ..vmlinux_parser.dependency_node import Field from ..vmlinux_parser.dependency_node import Field
@ -77,89 +77,6 @@ def _handle_attribute_expr(
logger.info( logger.info(
f"Variable type: {var_type}, Variable ptr: {var_ptr}, Variable Metadata: {var_metadata}" f"Variable type: {var_type}, Variable ptr: {var_ptr}, Variable Metadata: {var_metadata}"
) )
# Check if this is a pointer to a struct (from map lookup)
if (
isinstance(var_type, ir.PointerType)
and var_metadata
and isinstance(var_metadata, str)
):
if var_metadata in structs_sym_tab:
logger.info(
f"Handling pointer to struct {var_metadata} from map lookup"
)
if func is None:
raise ValueError(
f"func parameter required for null-safe pointer access to {var_name}.{attr_name}"
)
# Load the pointer value (ptr<struct>)
struct_ptr = builder.load(var_ptr)
# Create blocks for null check
null_check_block = builder.block
not_null_block = func.append_basic_block(
name=f"{var_name}_not_null"
)
merge_block = func.append_basic_block(name=f"{var_name}_merge")
# Check if pointer is null
null_ptr = ir.Constant(struct_ptr.type, None)
is_not_null = builder.icmp_signed("!=", struct_ptr, null_ptr)
logger.info(f"Inserted null check for pointer {var_name}")
builder.cbranch(is_not_null, not_null_block, merge_block)
# Not-null block: Access the field
builder.position_at_end(not_null_block)
# Get struct metadata
metadata = structs_sym_tab[var_metadata]
struct_ptr = builder.bitcast(
struct_ptr, metadata.ir_type.as_pointer()
)
if attr_name not in metadata.fields:
raise ValueError(
f"Field '{attr_name}' not found in struct '{var_metadata}'"
)
# GEP to field
field_gep = metadata.gep(builder, struct_ptr, attr_name)
# Load field value
field_val = builder.load(field_gep)
field_type = metadata.field_type(attr_name)
logger.info(
f"Loaded field {attr_name} from struct pointer, type: {field_type}"
)
# Branch to merge
not_null_after_load = builder.block
builder.branch(merge_block)
# Merge block: PHI node for the result
builder.position_at_end(merge_block)
phi = builder.phi(field_type, name=f"{var_name}_{attr_name}")
# If null, return zero/default value
if isinstance(field_type, ir.IntType):
zero_value = ir.Constant(field_type, 0)
elif isinstance(field_type, ir.PointerType):
zero_value = ir.Constant(field_type, None)
elif isinstance(field_type, ir.ArrayType):
# For arrays, we can't easily create a zero constant
# This case is tricky - for now, just use undef
zero_value = ir.Constant(field_type, ir.Undefined)
else:
zero_value = ir.Constant(field_type, ir.Undefined)
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(field_val, not_null_after_load)
logger.info(f"Created PHI node for {var_name}.{attr_name}")
return phi, field_type
if ( if (
hasattr(var_metadata, "__module__") hasattr(var_metadata, "__module__")
and var_metadata.__module__ == "vmlinux" and var_metadata.__module__ == "vmlinux"
@ -180,13 +97,23 @@ def _handle_attribute_expr(
) )
return None return None
# Regular user-defined struct if var_metadata in structs_sym_tab:
metadata = structs_sym_tab.get(var_metadata) return access_struct_field(
if metadata and attr_name in metadata.fields: builder,
gep = metadata.gep(builder, var_ptr, attr_name) var_ptr,
val = builder.load(gep) var_type,
field_type = metadata.field_type(attr_name) var_metadata,
return val, field_type expr.attr,
structs_sym_tab,
func,
)
else:
logger.error(f"Struct metadata for '{var_name}' not found")
else:
logger.error(f"Undefined variable '{var_name}' for attribute access")
else:
logger.error("Unsupported attribute base expression type")
return None return None

View File

@ -17,41 +17,108 @@ def deref_to_depth(func, builder, val, target_depth):
# dereference with null check # dereference with null check
pointee_type = cur_type.pointee pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
null_ptr = ir.Constant(cur_type, None) def load_op(builder, ptr):
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr) return builder.load(ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
builder.cbranch(is_not_null, not_null_block, merge_block) cur_val = _null_checked_operation(
func, builder, cur_val, load_op, pointee_type, f"deref_{depth}"
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
) )
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type cur_type = pointee_type
logger.debug(f"Dereferenced to depth {depth}, type: {pointee_type}")
return cur_val return cur_val
def deref_struct_ptr( def _null_checked_operation(func, builder, ptr, operation, result_type, name_prefix):
func, builder, struct_ptr, struct_metadata, field_name, structs_sym_tab """
Generic null-checked operation on a pointer.
"""
curr_block = builder.block
not_null_block = func.append_basic_block(name=f"{name_prefix}_not_null")
merge_block = func.append_basic_block(name=f"{name_prefix}_merge")
# Null check
null_ptr = ir.Constant(ptr.type, None)
is_not_null = builder.icmp_signed("!=", ptr, null_ptr)
builder.cbranch(is_not_null, not_null_block, merge_block)
# Not-null path: execute operation
builder.position_at_end(not_null_block)
result = operation(builder, ptr)
not_null_after = builder.block
builder.branch(merge_block)
# Merge with PHI
builder.position_at_end(merge_block)
phi = builder.phi(result_type, name=f"{name_prefix}_result")
# Null fallback value
if isinstance(result_type, ir.IntType):
null_val = ir.Constant(result_type, 0)
elif isinstance(result_type, ir.PointerType):
null_val = ir.Constant(result_type, None)
else:
null_val = ir.Constant(result_type, ir.Undefined)
phi.add_incoming(null_val, curr_block)
phi.add_incoming(result, not_null_after)
return phi
def access_struct_field(
builder, var_ptr, var_type, var_metadata, field_name, structs_sym_tab, func=None
): ):
"""Dereference a pointer to a struct type.""" """
return deref_to_depth(func, builder, struct_ptr, 1) Access a struct field - automatically returns value or pointer based on field type.
"""
# Get struct metadata
metadata = (
structs_sym_tab.get(var_metadata)
if isinstance(var_metadata, str)
else var_metadata
)
if not metadata or field_name not in metadata.fields:
raise ValueError(f"Field '{field_name}' not found in struct")
field_type = metadata.field_type(field_name)
is_ptr_to_struct = isinstance(var_type, ir.PointerType) and isinstance(
var_metadata, str
)
# Get struct pointer
struct_ptr = builder.load(var_ptr) if is_ptr_to_struct else var_ptr
# Decide: load value or return pointer?
should_load = not isinstance(field_type, ir.ArrayType)
# Define the field access operation
def field_access_op(builder, ptr):
typed_ptr = builder.bitcast(ptr, metadata.ir_type.as_pointer())
field_ptr = metadata.gep(builder, typed_ptr, field_name)
return builder.load(field_ptr) if should_load else field_ptr
# Handle null check for pointer-to-struct
if is_ptr_to_struct:
if func is None:
raise ValueError("func required for null-safe struct pointer access")
if should_load:
result_type = field_type
else:
result_type = field_type.as_pointer()
result = _null_checked_operation(
func,
builder,
struct_ptr,
field_access_op,
result_type,
f"field_{field_name}",
)
return result, field_type
# No null check needed
field_ptr = metadata.gep(builder, struct_ptr, field_name)
result = builder.load(field_ptr) if should_load else field_ptr
return result, field_type

View File

@ -5,6 +5,7 @@ from llvmlite import ir
from pythonbpf.expr import ( from pythonbpf.expr import (
get_operand_value, get_operand_value,
eval_expr, eval_expr,
access_struct_field,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -135,7 +136,7 @@ def get_or_create_ptr_from_arg(
and field_type.element.width == 8 and field_type.element.width == 8
): ):
ptr, sz = get_char_array_ptr_and_size( ptr, sz = get_char_array_ptr_and_size(
arg, builder, local_sym_tab, struct_sym_tab arg, builder, local_sym_tab, struct_sym_tab, func
) )
if not ptr: if not ptr:
raise ValueError("Failed to get char array pointer from struct field") raise ValueError("Failed to get char array pointer from struct field")
@ -266,7 +267,9 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
) )
def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab): def get_char_array_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab, func=None
):
"""Get pointer to char array and its size.""" """Get pointer to char array and its size."""
# Struct field: obj.field # Struct field: obj.field
@ -277,11 +280,11 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab)
if not (local_sym_tab and var_name in local_sym_tab): if not (local_sym_tab and var_name in local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found") raise ValueError(f"Variable '{var_name}' not found")
struct_type = local_sym_tab[var_name].metadata struct_ptr, struct_type, struct_metadata = local_sym_tab[var_name]
if not (struct_sym_tab and struct_type in struct_sym_tab): if not (struct_sym_tab and struct_metadata in struct_sym_tab):
raise ValueError(f"Struct type '{struct_type}' not found") raise ValueError(f"Struct type '{struct_metadata}' not found")
struct_info = struct_sym_tab[struct_type] struct_info = struct_sym_tab[struct_metadata]
if field_name not in struct_info.fields: if field_name not in struct_info.fields:
raise ValueError(f"Field '{field_name}' not found") raise ValueError(f"Field '{field_name}' not found")
@ -292,8 +295,25 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab)
) )
return None, 0 return None, 0
struct_ptr = local_sym_tab[var_name].var # Check if char array
field_ptr = struct_info.gep(builder, struct_ptr, field_name) if not (
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
):
logger.warning("Field is not a char array")
return None, 0
# Get field pointer (automatically handles null checks!)
field_ptr, _ = access_struct_field(
builder,
struct_ptr,
struct_type,
struct_metadata,
field_name,
struct_sym_tab,
func,
)
# GEP to first element: [N x i8]* -> i8* # GEP to first element: [N x i8]* -> i8*
buf_ptr = builder.gep( buf_ptr = builder.gep(

View File

@ -222,7 +222,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
# Special case: struct field char array needs pointer to first element # Special case: struct field char array needs pointer to first element
if isinstance(expr, ast.Attribute): if isinstance(expr, ast.Attribute):
char_array_ptr, _ = get_char_array_ptr_and_size( char_array_ptr, _ = get_char_array_ptr_and_size(
expr, builder, local_sym_tab, struct_sym_tab expr, builder, local_sym_tab, struct_sym_tab, func
) )
if char_array_ptr: if char_array_ptr:
return char_array_ptr return char_array_ptr