Refactor null-checked pointer access into shared IR helpers.

Adds _null_checked_operation() and access_struct_field() to
pythonbpf/expr/ir_ops.py, removes deref_struct_ptr(), routes
_handle_attribute_expr() and get_char_array_ptr_and_size() through
access_struct_field(), and passes `func` to get_char_array_ptr_and_size()
from its two call sites.

NOTE(review): the line breaks and indentation of this patch were rebuilt from
a whitespace-collapsed copy. Hunk line counts match the @@ headers, but
context-line indentation is inferred; apply with
`git apply --ignore-whitespace` and confirm against the target tree.
NOTE(review): the ir_ops.py hunk differs from the collapsed copy in three
places (zero instead of undef as the null-path PHI value, separate
"struct not found" / "field not found" errors, longer docstrings), so the
post-image blob id on its `index` line is no longer accurate.

diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py
index ac3a975..dfd2128 100644
--- a/pythonbpf/expr/__init__.py
+++ b/pythonbpf/expr/__init__.py
@@ -1,6 +1,6 @@
 from .expr_pass import eval_expr, handle_expr, get_operand_value
 from .type_normalization import convert_to_bool, get_base_type_and_depth
-from .ir_ops import deref_to_depth
+from .ir_ops import deref_to_depth, access_struct_field
 from .call_registry import CallHandlerRegistry
 from .vmlinux_registry import VmlinuxHandlerRegistry
 
@@ -10,6 +10,7 @@ __all__ = [
     "convert_to_bool",
     "get_base_type_and_depth",
     "deref_to_depth",
+    "access_struct_field",
     "get_operand_value",
     "CallHandlerRegistry",
     "VmlinuxHandlerRegistry",
diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py
index 9f3bfa4..d34dff5 100644
--- a/pythonbpf/expr/expr_pass.py
+++ b/pythonbpf/expr/expr_pass.py
@@ -6,11 +6,11 @@ from typing import Dict
 
 from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
 from .call_registry import CallHandlerRegistry
+from .ir_ops import deref_to_depth, access_struct_field
 from .type_normalization import (
     convert_to_bool,
     handle_comparator,
     get_base_type_and_depth,
-    deref_to_depth,
 )
 from .vmlinux_registry import VmlinuxHandlerRegistry
 from ..vmlinux_parser.dependency_node import Field
@@ -77,89 +77,6 @@ def _handle_attribute_expr(
             logger.info(
                 f"Variable type: {var_type}, Variable ptr: {var_ptr}, Variable Metadata: {var_metadata}"
             )
-            # Check if this is a pointer to a struct (from map lookup)
-            if (
-                isinstance(var_type, ir.PointerType)
-                and var_metadata
-                and isinstance(var_metadata, str)
-            ):
-                if var_metadata in structs_sym_tab:
-                    logger.info(
-                        f"Handling pointer to struct {var_metadata} from map lookup"
-                    )
-
-                    if func is None:
-                        raise ValueError(
-                            f"func parameter required for null-safe pointer access to {var_name}.{attr_name}"
-                        )
-
-                    # Load the pointer value (ptr)
-                    struct_ptr = builder.load(var_ptr)
-
-                    # Create blocks for null check
-                    null_check_block = builder.block
-                    not_null_block = func.append_basic_block(
-                        name=f"{var_name}_not_null"
-                    )
-                    merge_block = func.append_basic_block(name=f"{var_name}_merge")
-
-                    # Check if pointer is null
-                    null_ptr = ir.Constant(struct_ptr.type, None)
-                    is_not_null = builder.icmp_signed("!=", struct_ptr, null_ptr)
-                    logger.info(f"Inserted null check for pointer {var_name}")
-
-                    builder.cbranch(is_not_null, not_null_block, merge_block)
-
-                    # Not-null block: Access the field
-                    builder.position_at_end(not_null_block)
-
-                    # Get struct metadata
-                    metadata = structs_sym_tab[var_metadata]
-                    struct_ptr = builder.bitcast(
-                        struct_ptr, metadata.ir_type.as_pointer()
-                    )
-
-                    if attr_name not in metadata.fields:
-                        raise ValueError(
-                            f"Field '{attr_name}' not found in struct '{var_metadata}'"
-                        )
-
-                    # GEP to field
-                    field_gep = metadata.gep(builder, struct_ptr, attr_name)
-
-                    # Load field value
-                    field_val = builder.load(field_gep)
-                    field_type = metadata.field_type(attr_name)
-
-                    logger.info(
-                        f"Loaded field {attr_name} from struct pointer, type: {field_type}"
-                    )
-
-                    # Branch to merge
-                    not_null_after_load = builder.block
-                    builder.branch(merge_block)
-
-                    # Merge block: PHI node for the result
-                    builder.position_at_end(merge_block)
-                    phi = builder.phi(field_type, name=f"{var_name}_{attr_name}")
-
-                    # If null, return zero/default value
-                    if isinstance(field_type, ir.IntType):
-                        zero_value = ir.Constant(field_type, 0)
-                    elif isinstance(field_type, ir.PointerType):
-                        zero_value = ir.Constant(field_type, None)
-                    elif isinstance(field_type, ir.ArrayType):
-                        # For arrays, we can't easily create a zero constant
-                        # This case is tricky - for now, just use undef
-                        zero_value = ir.Constant(field_type, ir.Undefined)
-                    else:
-                        zero_value = ir.Constant(field_type, ir.Undefined)
-
-                    phi.add_incoming(zero_value, null_check_block)
-                    phi.add_incoming(field_val, not_null_after_load)
-
-                    logger.info(f"Created PHI node for {var_name}.{attr_name}")
-                    return phi, field_type
             if (
                 hasattr(var_metadata, "__module__")
                 and var_metadata.__module__ == "vmlinux"
@@ -180,13 +97,23 @@ def _handle_attribute_expr(
                 )
                 return None
 
-            # Regular user-defined struct
-            metadata = structs_sym_tab.get(var_metadata)
-            if metadata and attr_name in metadata.fields:
-                gep = metadata.gep(builder, var_ptr, attr_name)
-                val = builder.load(gep)
-                field_type = metadata.field_type(attr_name)
-                return val, field_type
+            if var_metadata in structs_sym_tab:
+                return access_struct_field(
+                    builder,
+                    var_ptr,
+                    var_type,
+                    var_metadata,
+                    expr.attr,
+                    structs_sym_tab,
+                    func,
+                )
+            else:
+                logger.error(f"Struct metadata for '{var_name}' not found")
+        else:
+            logger.error(f"Undefined variable '{var_name}' for attribute access")
+    else:
+        logger.error("Unsupported attribute base expression type")
+
     return None
 
 
diff --git a/pythonbpf/expr/ir_ops.py b/pythonbpf/expr/ir_ops.py
index f6835e2..b4961f9 100644
--- a/pythonbpf/expr/ir_ops.py
+++ b/pythonbpf/expr/ir_ops.py
@@ -17,41 +17,126 @@ def deref_to_depth(func, builder, val, target_depth):
         # dereference with null check
         pointee_type = cur_type.pointee
 
-        null_check_block = builder.block
-        not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
-        merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
-        null_ptr = ir.Constant(cur_type, None)
-        is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
-        logger.debug(f"Inserted null check for pointer at depth {depth}")
+        def load_op(builder, ptr):
+            return builder.load(ptr)
 
-        builder.cbranch(is_not_null, not_null_block, merge_block)
-
-        builder.position_at_end(not_null_block)
-        dereferenced_val = builder.load(cur_val)
-        logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
-        builder.branch(merge_block)
-
-        builder.position_at_end(merge_block)
-        phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
-
-        zero_value = (
-            ir.Constant(pointee_type, 0)
-            if isinstance(pointee_type, ir.IntType)
-            else ir.Constant(pointee_type, None)
+        cur_val = _null_checked_operation(
+            func, builder, cur_val, load_op, pointee_type, f"deref_{depth}"
         )
-        phi.add_incoming(zero_value, null_check_block)
-
-        phi.add_incoming(dereferenced_val, not_null_block)
-
-        # Continue with phi result
-        cur_val = phi
         cur_type = pointee_type
+        logger.debug(f"Dereferenced to depth {depth}, type: {pointee_type}")
 
     return cur_val
 
 
-def deref_struct_ptr(
-    func, builder, struct_ptr, struct_metadata, field_name, structs_sym_tab
+def _null_checked_operation(func, builder, ptr, operation, result_type, name_prefix):
+    """
+    Generic null-checked operation on a pointer.
+
+    Branches on ``ptr != null``, runs ``operation(builder, ptr)`` in the
+    not-null block, and merges both paths with a PHI node of ``result_type``.
+    On the null path the PHI receives the zero/null constant of
+    ``result_type``. Block and PHI names are prefixed with ``name_prefix``.
+
+    Returns the PHI node.
+    """
+    curr_block = builder.block
+    not_null_block = func.append_basic_block(name=f"{name_prefix}_not_null")
+    merge_block = func.append_basic_block(name=f"{name_prefix}_merge")
+
+    # Null check
+    null_ptr = ir.Constant(ptr.type, None)
+    is_not_null = builder.icmp_signed("!=", ptr, null_ptr)
+    builder.cbranch(is_not_null, not_null_block, merge_block)
+
+    # Not-null path: execute operation
+    builder.position_at_end(not_null_block)
+    result = operation(builder, ptr)
+    not_null_after = builder.block
+    builder.branch(merge_block)
+
+    # Merge with PHI
+    builder.position_at_end(merge_block)
+    phi = builder.phi(result_type, name=f"{name_prefix}_result")
+
+    # Null fallback value
+    if isinstance(result_type, ir.IntType):
+        null_val = ir.Constant(result_type, 0)
+    else:
+        # None is the type's null/zero value: `null` for pointers and
+        # `zeroinitializer` for aggregates (never `undef`)
+        null_val = ir.Constant(result_type, None)
+
+    phi.add_incoming(null_val, curr_block)
+    phi.add_incoming(result, not_null_after)
+
+    return phi
+
+
+def access_struct_field(
+    builder, var_ptr, var_type, var_metadata, field_name, structs_sym_tab, func=None
 ):
-    """Dereference a pointer to a struct type."""
-    return deref_to_depth(func, builder, struct_ptr, 1)
+    """
+    Access a struct field - automatically returns value or pointer based on field type.
+
+    Array fields yield a pointer to the field; all other fields are loaded.
+    When ``var_type`` is a pointer and ``var_metadata`` is a struct name,
+    ``var_ptr`` is loaded to get the struct pointer and the access is
+    null-checked, which requires ``func``.
+
+    Returns ``(result, field_type)``.
+    Raises ValueError for an unknown struct or field, or when ``func`` is
+    missing for a null-checked access.
+    """
+    # Get struct metadata
+    metadata = (
+        structs_sym_tab.get(var_metadata)
+        if isinstance(var_metadata, str)
+        else var_metadata
+    )
+    if not metadata:
+        raise ValueError(f"Struct metadata '{var_metadata}' not found")
+    if field_name not in metadata.fields:
+        raise ValueError(f"Field '{field_name}' not found in struct")
+
+    field_type = metadata.field_type(field_name)
+    is_ptr_to_struct = isinstance(var_type, ir.PointerType) and isinstance(
+        var_metadata, str
+    )
+
+    # Get struct pointer
+    struct_ptr = builder.load(var_ptr) if is_ptr_to_struct else var_ptr
+
+    # Decide: load value or return pointer?
+    should_load = not isinstance(field_type, ir.ArrayType)
+
+    # Define the field access operation
+    def field_access_op(builder, ptr):
+        typed_ptr = builder.bitcast(ptr, metadata.ir_type.as_pointer())
+        field_ptr = metadata.gep(builder, typed_ptr, field_name)
+        return builder.load(field_ptr) if should_load else field_ptr
+
+    # Handle null check for pointer-to-struct
+    if is_ptr_to_struct:
+        if func is None:
+            raise ValueError("func required for null-safe struct pointer access")
+
+        if should_load:
+            result_type = field_type
+        else:
+            result_type = field_type.as_pointer()
+
+        result = _null_checked_operation(
+            func,
+            builder,
+            struct_ptr,
+            field_access_op,
+            result_type,
+            f"field_{field_name}",
+        )
+        return result, field_type
+
+    # No null check needed
+    field_ptr = metadata.gep(builder, struct_ptr, field_name)
+    result = builder.load(field_ptr) if should_load else field_ptr
+    return result, field_type
diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py
index aecb5e9..c6ec3f6 100644
--- a/pythonbpf/helper/helper_utils.py
+++ b/pythonbpf/helper/helper_utils.py
@@ -5,6 +5,7 @@ from llvmlite import ir
 from pythonbpf.expr import (
     get_operand_value,
     eval_expr,
+    access_struct_field,
 )
 
 logger = logging.getLogger(__name__)
@@ -135,7 +136,7 @@ def get_or_create_ptr_from_arg(
             and field_type.element.width == 8
         ):
             ptr, sz = get_char_array_ptr_and_size(
-                arg, builder, local_sym_tab, struct_sym_tab
+                arg, builder, local_sym_tab, struct_sym_tab, func
             )
             if not ptr:
                 raise ValueError("Failed to get char array pointer from struct field")
@@ -266,7 +267,9 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
     )
 
 
-def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
+def get_char_array_ptr_and_size(
+    buf_arg, builder, local_sym_tab, struct_sym_tab, func=None
+):
     """Get pointer to char array and its size."""
 
     # Struct field: obj.field
@@ -277,11 +280,11 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab)
         if not (local_sym_tab and var_name in local_sym_tab):
             raise ValueError(f"Variable '{var_name}' not found")
 
-        struct_type = local_sym_tab[var_name].metadata
-        if not (struct_sym_tab and struct_type in struct_sym_tab):
-            raise ValueError(f"Struct type '{struct_type}' not found")
+        struct_ptr, struct_type, struct_metadata = local_sym_tab[var_name]
+        if not (struct_sym_tab and struct_metadata in struct_sym_tab):
+            raise ValueError(f"Struct type '{struct_metadata}' not found")
 
-        struct_info = struct_sym_tab[struct_type]
+        struct_info = struct_sym_tab[struct_metadata]
         if field_name not in struct_info.fields:
             raise ValueError(f"Field '{field_name}' not found")
 
@@ -292,8 +295,25 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab)
             )
             return None, 0
 
-        struct_ptr = local_sym_tab[var_name].var
-        field_ptr = struct_info.gep(builder, struct_ptr, field_name)
+        # Check if char array
+        if not (
+            isinstance(field_type, ir.ArrayType)
+            and isinstance(field_type.element, ir.IntType)
+            and field_type.element.width == 8
+        ):
+            logger.warning("Field is not a char array")
+            return None, 0
+
+        # Get field pointer (automatically handles null checks!)
+        field_ptr, _ = access_struct_field(
+            builder,
+            struct_ptr,
+            struct_type,
+            struct_metadata,
+            field_name,
+            struct_sym_tab,
+            func,
+        )
 
         # GEP to first element: [N x i8]* -> i8*
         buf_ptr = builder.gep(
diff --git a/pythonbpf/helper/printk_formatter.py b/pythonbpf/helper/printk_formatter.py
index 721213e..4364166 100644
--- a/pythonbpf/helper/printk_formatter.py
+++ b/pythonbpf/helper/printk_formatter.py
@@ -222,7 +222,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
     # Special case: struct field char array needs pointer to first element
     if isinstance(expr, ast.Attribute):
         char_array_ptr, _ = get_char_array_ptr_and_size(
-            expr, builder, local_sym_tab, struct_sym_tab
+            expr, builder, local_sym_tab, struct_sym_tab, func
         )
         if char_array_ptr:
             return char_array_ptr