8 Commits

20 changed files with 449 additions and 365 deletions

View File

@ -1,6 +1,6 @@
import ast import ast
import logging import logging
import ctypes
from llvmlite import ir from llvmlite import ir
from .local_symbol import LocalSymbol from .local_symbol import LocalSymbol
from pythonbpf.helper import HelperHandlerRegistry from pythonbpf.helper import HelperHandlerRegistry
@ -81,7 +81,7 @@ def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab):
call_type = rval.func.id call_type = rval.func.id
# C type constructors # C type constructors
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64", "c_void_p"):
ir_type = ctypes_to_ir(call_type) ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
@ -249,7 +249,58 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_
].var = base_ptr # This is repurposing of var to store the pointer of the base type ].var = base_ptr # This is repurposing of var to store the pointer of the base type
local_sym_tab[struct_var].ir_type = field_ir local_sym_tab[struct_var].ir_type = field_ir
actual_ir_type = ir.IntType(64) # Determine the actual IR type based on the field's type
actual_ir_type = None
# Check if it's a ctypes primitive
if field.type.__module__ == ctypes.__name__:
try:
field_size_bytes = ctypes.sizeof(field.type)
field_size_bits = field_size_bytes * 8
if field_size_bits in [8, 16, 32, 64]:
# Special case: struct_xdp_md i32 fields should allocate as i64
# because load_ctx_field will zero-extend them to i64
if (
vmlinux_struct_name == "struct_xdp_md"
and field_size_bits == 32
):
actual_ir_type = ir.IntType(64)
logger.info(
f"Allocating {var_name} as i64 for i32 field from struct_xdp_md.{field_name} "
"(will be zero-extended during load)"
)
else:
actual_ir_type = ir.IntType(field_size_bits)
else:
logger.warning(
f"Unusual field size {field_size_bits} bits for {field_name}"
)
actual_ir_type = ir.IntType(64)
except Exception as e:
logger.warning(
f"Could not determine size for ctypes field {field_name}: {e}"
)
actual_ir_type = ir.IntType(64)
# Check if it's a nested vmlinux struct or complex type
elif field.type.__module__ == "vmlinux":
# For pointers to structs, use pointer type (64-bit)
if field.ctype_complex_type is not None and issubclass(
field.ctype_complex_type, ctypes._Pointer
):
actual_ir_type = ir.IntType(64) # Pointer is always 64-bit
# For embedded structs, this is more complex - might need different handling
else:
logger.warning(
f"Field {field_name} is a nested vmlinux struct, using i64 for now"
)
actual_ir_type = ir.IntType(64)
else:
logger.warning(
f"Unknown field type module {field.type.__module__} for {field_name}"
)
actual_ir_type = ir.IntType(64)
# Allocate with the actual IR type, not the GlobalVariable # Allocate with the actual IR type, not the GlobalVariable
var = _allocate_with_type(builder, var_name, actual_ir_type) var = _allocate_with_type(builder, var_name, actual_ir_type)

View File

@ -152,15 +152,30 @@ def handle_variable_assignment(
if val_type != var_type: if val_type != var_type:
if isinstance(val_type, Field): if isinstance(val_type, Field):
logger.info("Handling assignment to struct field") logger.info("Handling assignment to struct field")
# Special handling for struct_xdp_md i32 fields that are zero-extended to i64
# The load_ctx_field already extended them, so val is i64 but val_type.type shows c_uint
if (
hasattr(val_type, "type")
and val_type.type.__name__ == "c_uint"
and isinstance(var_type, ir.IntType)
and var_type.width == 64
):
# This is the struct_xdp_md case - value is already i64
builder.store(val, var_ptr)
logger.info(
f"Assigned zero-extended struct_xdp_md i32 field to {var_name} (i64)"
)
return True
# TODO: handling only ctype struct fields for now. Handle other stuff too later. # TODO: handling only ctype struct fields for now. Handle other stuff too later.
if var_type == ctypes_to_ir(val_type.type.__name__): elif var_type == ctypes_to_ir(val_type.type.__name__):
builder.store(val, var_ptr) builder.store(val, var_ptr)
logger.info(f"Assigned ctype struct field to {var_name}") logger.info(f"Assigned ctype struct field to {var_name}")
return True return True
logger.error( else:
f"Failed to assign ctype struct field to {var_name}: {val_type} != {var_type}" logger.error(
) f"Failed to assign ctype struct field to {var_name}: {val_type} != {var_type}"
return False )
return False
elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType): elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType):
# Allow implicit int widening # Allow implicit int widening
if val_type.width < var_type.width: if val_type.width < var_type.width:

View File

@ -12,6 +12,7 @@ from .type_normalization import (
get_base_type_and_depth, get_base_type_and_depth,
deref_to_depth, deref_to_depth,
) )
from pythonbpf.vmlinux_parser.assignment_info import Field
from .vmlinux_registry import VmlinuxHandlerRegistry from .vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
@ -279,16 +280,45 @@ def _handle_ctypes_call(
call_type = expr.func.id call_type = expr.func.id
expected_type = ctypes_to_ir(call_type) expected_type = ctypes_to_ir(call_type)
if val[1] != expected_type: # Extract the actual IR value and type
# val could be (value, ir_type) or (value, Field)
value, val_type = val
# If val_type is a Field object (from vmlinux struct), get the actual IR type of the value
if isinstance(val_type, Field):
# The value is already the correct IR value (potentially zero-extended)
# Get the IR type from the value itself
actual_ir_type = value.type
logger.info(
f"Converting vmlinux field {val_type.name} (IR type: {actual_ir_type}) to {call_type}"
)
else:
actual_ir_type = val_type
if actual_ir_type != expected_type:
# NOTE: We are only considering casting to and from int types for now # NOTE: We are only considering casting to and from int types for now
if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType): if isinstance(actual_ir_type, ir.IntType) and isinstance(
if val[1].width < expected_type.width: expected_type, ir.IntType
val = (builder.sext(val[0], expected_type), expected_type) ):
if actual_ir_type.width < expected_type.width:
value = builder.sext(value, expected_type)
logger.info(
f"Sign-extended from i{actual_ir_type.width} to i{expected_type.width}"
)
elif actual_ir_type.width > expected_type.width:
value = builder.trunc(value, expected_type)
logger.info(
f"Truncated from i{actual_ir_type.width} to i{expected_type.width}"
)
else: else:
val = (builder.trunc(val[0], expected_type), expected_type) # Same width, just use as-is (e.g., both i64)
pass
else: else:
raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}") raise ValueError(
return val f"Type mismatch: expected {expected_type}, got {actual_ir_type} (original type: {val_type})"
)
return value, expected_type
def _handle_compare( def _handle_compare(

View File

@ -49,17 +49,27 @@ def generate_function_debug_info(
"The first argument should always be a pointer to a struct or a void pointer" "The first argument should always be a pointer to a struct or a void pointer"
) )
context_debug_info = VmlinuxHandlerRegistry.get_struct_debug_info(annotation.id) context_debug_info = VmlinuxHandlerRegistry.get_struct_debug_info(annotation.id)
# Create pointer to context this must be created fresh for each function
# to avoid circular reference issues when the same struct is used in multiple functions
pointer_to_context_debug_info = generator.create_pointer_type( pointer_to_context_debug_info = generator.create_pointer_type(
context_debug_info, 64 context_debug_info, 64
) )
# Create subroutine type - also fresh for each function
subroutine_type = generator.create_subroutine_type( subroutine_type = generator.create_subroutine_type(
return_type, pointer_to_context_debug_info return_type, pointer_to_context_debug_info
) )
# Create local variable - fresh for each function with unique name
context_local_variable = generator.create_local_variable_debug_info( context_local_variable = generator.create_local_variable_debug_info(
leading_argument_name, 1, pointer_to_context_debug_info leading_argument_name, 1, pointer_to_context_debug_info
) )
retained_nodes = [context_local_variable] retained_nodes = [context_local_variable]
print("function name", func_node.name) logger.info(f"Generating debug info for function {func_node.name}")
# Create subprogram with is_distinct=True to ensure each function gets unique debug info
subprogram_debug_info = generator.create_subprogram( subprogram_debug_info = generator.create_subprogram(
func_node.name, subroutine_type, retained_nodes func_node.name, subroutine_type, retained_nodes
) )

View File

@ -16,6 +16,8 @@ mapping = {
"c_long": ir.IntType(64), "c_long": ir.IntType(64),
"c_ulong": ir.IntType(64), "c_ulong": ir.IntType(64),
"c_longlong": ir.IntType(64), "c_longlong": ir.IntType(64),
"c_uint": ir.IntType(32),
"c_int": ir.IntType(32),
# Not so sure about this one # Not so sure about this one
"str": ir.PointerType(ir.IntType(8)), "str": ir.PointerType(ir.IntType(8)),
} }

View File

@ -16,37 +16,10 @@ def get_module_symbols(module_name: str):
return [name for name in dir(imported_module)], imported_module return [name for name in dir(imported_module)], imported_module
def unwrap_pointer_type(type_obj: Any) -> Any:
"""
Recursively unwrap all pointer layers to get the base type.
This handles multiply nested pointers like LP_LP_struct_attribute_group
and returns the base type (struct_attribute_group).
Stops unwrapping when reaching a non-pointer type (one without _type_ attribute).
Args:
type_obj: The type object to unwrap
Returns:
The base type after unwrapping all pointer layers
"""
current_type = type_obj
# Keep unwrapping while it's a pointer/array type (has _type_)
# But stop if _type_ is just a string or basic type marker
while hasattr(current_type, "_type_"):
next_type = current_type._type_
# Stop if _type_ is a string (like 'c' for c_char)
if isinstance(next_type, str):
break
current_type = next_type
return current_type
def process_vmlinux_class( def process_vmlinux_class(
node, node,
llvm_module, llvm_module,
handler: DependencyHandler, handler: DependencyHandler,
): ):
symbols_in_module, imported_module = get_module_symbols("vmlinux") symbols_in_module, imported_module = get_module_symbols("vmlinux")
if node.name in symbols_in_module: if node.name in symbols_in_module:
@ -57,10 +30,10 @@ def process_vmlinux_class(
def process_vmlinux_post_ast( def process_vmlinux_post_ast(
elem_type_class, elem_type_class,
llvm_handler, llvm_handler,
handler: DependencyHandler, handler: DependencyHandler,
processing_stack=None, processing_stack=None,
): ):
# Initialize processing stack on first call # Initialize processing stack on first call
if processing_stack is None: if processing_stack is None:
@ -140,7 +113,7 @@ def process_vmlinux_post_ast(
# Process pointer to ctype # Process pointer to ctype
if isinstance(elem_type, type) and issubclass( if isinstance(elem_type, type) and issubclass(
elem_type, ctypes._Pointer elem_type, ctypes._Pointer
): ):
# Get the pointed-to type # Get the pointed-to type
pointed_type = elem_type._type_ pointed_type = elem_type._type_
@ -153,7 +126,7 @@ def process_vmlinux_post_ast(
# Process function pointers (CFUNCTYPE) # Process function pointers (CFUNCTYPE)
elif hasattr(elem_type, "_restype_") and hasattr( elif hasattr(elem_type, "_restype_") and hasattr(
elem_type, "_argtypes_" elem_type, "_argtypes_"
): ):
# This is a CFUNCTYPE or similar # This is a CFUNCTYPE or similar
logger.info( logger.info(
@ -185,90 +158,13 @@ def process_vmlinux_post_ast(
if hasattr(elem_type, "_length_") and is_complex_type: if hasattr(elem_type, "_length_") and is_complex_type:
type_length = elem_type._length_ type_length = elem_type._length_
# Unwrap all pointer layers to get the base type for dependency tracking if containing_type.__module__ == "vmlinux":
base_type = unwrap_pointer_type(elem_type) new_dep_node.add_dependent(
base_type_module = getattr(base_type, "__module__", None) elem_type._type_.__name__
if hasattr(elem_type._type_, "__name__")
if base_type_module == "vmlinux": else str(elem_type._type_)
base_type_name = (
base_type.__name__
if hasattr(base_type, "__name__")
else str(base_type)
)
# ONLY add vmlinux types as dependencies
new_dep_node.add_dependent(base_type_name)
logger.debug(
f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}"
)
new_dep_node.set_field_containing_type(
elem_name, containing_type
)
new_dep_node.set_field_type_size(elem_name, type_length)
new_dep_node.set_field_ctype_complex_type(
elem_name, ctype_complex_type
)
new_dep_node.set_field_type(elem_name, elem_type)
# Check the containing_type module to decide whether to recurse
containing_type_module = getattr(
containing_type, "__module__", None
)
if containing_type_module == "vmlinux":
# Also unwrap containing_type to get base type name
base_containing_type = unwrap_pointer_type(
containing_type
)
containing_type_name = (
base_containing_type.__name__
if hasattr(base_containing_type, "__name__")
else str(base_containing_type)
)
# Check for self-reference or already processed
if containing_type_name == current_symbol_name:
# Self-referential pointer
logger.debug(
f"Self-referential pointer in {current_symbol_name}.{elem_name}"
)
new_dep_node.set_field_ready(elem_name, True)
elif handler.has_node(containing_type_name):
# Already processed
logger.debug(
f"Reusing already processed {containing_type_name}"
)
new_dep_node.set_field_ready(elem_name, True)
else:
# Process recursively - use base containing type, not the pointer wrapper
new_dep_node.add_dependent(containing_type_name)
process_vmlinux_post_ast(
base_containing_type,
llvm_handler,
handler,
processing_stack,
)
new_dep_node.set_field_ready(elem_name, True)
elif (
containing_type_module == ctypes.__name__
or containing_type_module is None
):
logger.debug(
f"Processing ctype internal{containing_type}"
)
new_dep_node.set_field_ready(elem_name, True)
else:
raise TypeError(
f"Module not supported in recursive resolution: {containing_type_module}"
)
elif (
base_type_module == ctypes.__name__
or base_type_module is None
):
# Handle ctypes or types with no module (like some internal ctypes types)
# DO NOT add ctypes as dependencies - just set field metadata and mark ready
logger.debug(
f"Base type {base_type} is ctypes - NOT adding as dependency, just processing field"
) )
elif containing_type.__module__ == ctypes.__name__:
if isinstance(elem_type, type): if isinstance(elem_type, type):
if issubclass(elem_type, ctypes.Array): if issubclass(elem_type, ctypes.Array):
ctype_complex_type = ctypes.Array ctype_complex_type = ctypes.Array
@ -280,20 +176,57 @@ def process_vmlinux_post_ast(
) )
else: else:
raise TypeError("Unsupported ctypes subclass") raise TypeError("Unsupported ctypes subclass")
# Set field metadata but DO NOT add dependency or recurse
new_dep_node.set_field_containing_type(
elem_name, containing_type
)
new_dep_node.set_field_type_size(elem_name, type_length)
new_dep_node.set_field_ctype_complex_type(
elem_name, ctype_complex_type
)
new_dep_node.set_field_type(elem_name, elem_type)
new_dep_node.set_field_ready(elem_name, True)
else: else:
raise ImportError( raise ImportError(
f"Unsupported module of {base_type}: {base_type_module}" f"Unsupported module of {containing_type}"
)
logger.debug(
f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}"
)
new_dep_node.set_field_containing_type(
elem_name, containing_type
)
new_dep_node.set_field_type_size(elem_name, type_length)
new_dep_node.set_field_ctype_complex_type(
elem_name, ctype_complex_type
)
new_dep_node.set_field_type(elem_name, elem_type)
if containing_type.__module__ == "vmlinux":
containing_type_name = (
containing_type.__name__
if hasattr(containing_type, "__name__")
else str(containing_type)
)
# Check for self-reference or already processed
if containing_type_name == current_symbol_name:
# Self-referential pointer
logger.debug(
f"Self-referential pointer in {current_symbol_name}.{elem_name}"
)
new_dep_node.set_field_ready(elem_name, True)
elif handler.has_node(containing_type_name):
# Already processed
logger.debug(
f"Reusing already processed {containing_type_name}"
)
new_dep_node.set_field_ready(elem_name, True)
else:
# Process recursively - THIS WAS MISSING
new_dep_node.add_dependent(containing_type_name)
process_vmlinux_post_ast(
containing_type,
llvm_handler,
handler,
processing_stack,
)
new_dep_node.set_field_ready(elem_name, True)
elif containing_type.__module__ == ctypes.__name__:
logger.debug(f"Processing ctype internal{containing_type}")
new_dep_node.set_field_ready(elem_name, True)
else:
raise TypeError(
"Module not supported in recursive resolution"
) )
else: else:
new_dep_node.add_dependent( new_dep_node.add_dependent(
@ -312,12 +245,9 @@ def process_vmlinux_post_ast(
raise ValueError( raise ValueError(
f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver" f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver"
) )
elif module_name == ctypes.__name__ or module_name is None:
# Handle ctypes types - these don't need processing, just return
logger.debug(f"Skipping ctypes type {current_symbol_name}")
return True
else: else:
raise ImportError(f"UNSUPPORTED Module {module_name}") raise ImportError("UNSUPPORTED Module")
logger.info( logger.info(
f"{current_symbol_name} processed and handler readiness {handler.is_ready}" f"{current_symbol_name} processed and handler readiness {handler.is_ready}"

View File

@ -11,9 +11,7 @@ from .class_handler import process_vmlinux_class
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def detect_import_statement( def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]:
tree: ast.AST,
) -> list[tuple[str, ast.ImportFrom, str, str]]:
""" """
Parse AST and detect import statements from vmlinux. Parse AST and detect import statements from vmlinux.
@ -27,7 +25,7 @@ def detect_import_statement(
List of tuples containing (module_name, imported_item) for each vmlinux import List of tuples containing (module_name, imported_item) for each vmlinux import
Raises: Raises:
SyntaxError: If import * is used SyntaxError: If multiple imports from vmlinux are attempted or import * is used
""" """
vmlinux_imports = [] vmlinux_imports = []
@ -42,19 +40,28 @@ def detect_import_statement(
"Please import specific types explicitly." "Please import specific types explicitly."
) )
# Check for multiple imports: from vmlinux import A, B, C
if len(node.names) > 1:
imported_names = [alias.name for alias in node.names]
raise SyntaxError(
f"Multiple imports from vmlinux are not supported. "
f"Found: {', '.join(imported_names)}. "
f"Please use separate import statements for each type."
)
# Check if no specific import is specified (should not happen with valid Python) # Check if no specific import is specified (should not happen with valid Python)
if len(node.names) == 0: if len(node.names) == 0:
raise SyntaxError( raise SyntaxError(
"Import from vmlinux must specify at least one type." "Import from vmlinux must specify at least one type."
) )
# Support multiple imports: from vmlinux import A, B, C # Valid single import
for alias in node.names: for alias in node.names:
import_name = alias.name import_name = alias.name
# Use alias if provided, otherwise use the original name # Use alias if provided, otherwise use the original name (commented)
as_name = alias.asname if alias.asname else alias.name # as_name = alias.asname if alias.asname else alias.name
vmlinux_imports.append(("vmlinux", node, import_name, as_name)) vmlinux_imports.append(("vmlinux", node))
logger.info(f"Found vmlinux import: {import_name} as {as_name}") logger.info(f"Found vmlinux import: {import_name}")
# Handle "import vmlinux" statements (not typical but should be rejected) # Handle "import vmlinux" statements (not typical but should be rejected)
elif isinstance(node, ast.Import): elif isinstance(node, ast.Import):
@ -66,7 +73,6 @@ def detect_import_statement(
) )
logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}") logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}")
# print(f"\n**************\n{vmlinux_imports}\n**************\n")
return vmlinux_imports return vmlinux_imports
@ -97,37 +103,40 @@ def vmlinux_proc(tree: ast.AST, module):
with open(source_file, "r") as f: with open(source_file, "r") as f:
mod_ast = ast.parse(f.read(), filename=source_file) mod_ast = ast.parse(f.read(), filename=source_file)
for import_mod, import_node, imported_name, as_name in import_statements: for import_mod, import_node in import_statements:
found = False for alias in import_node.names:
for mod_node in mod_ast.body: imported_name = alias.name
if isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name: found = False
process_vmlinux_class(mod_node, module, handler) for mod_node in mod_ast.body:
found = True if (
break isinstance(mod_node, ast.ClassDef)
if isinstance(mod_node, ast.Assign): and mod_node.name == imported_name
for target in mod_node.targets: ):
if isinstance(target, ast.Name) and target.id == imported_name: process_vmlinux_class(mod_node, module, handler)
process_vmlinux_assign(mod_node, module, assignments, as_name) found = True
found = True break
break if isinstance(mod_node, ast.Assign):
if found: for target in mod_node.targets:
break if isinstance(target, ast.Name) and target.id == imported_name:
if not found: process_vmlinux_assign(mod_node, module, assignments)
logger.info(f"{imported_name} not found as ClassDef or Assign in vmlinux") found = True
break
if found:
break
if not found:
logger.info(
f"{imported_name} not found as ClassDef or Assign in vmlinux"
)
IRGenerator(module, handler, assignments) IRGenerator(module, handler, assignments)
return assignments return assignments
def process_vmlinux_assign( def process_vmlinux_assign(node, module, assignments: dict[str, AssignmentInfo]):
node, module, assignments: dict[str, AssignmentInfo], target_name=None
):
"""Process assignments from vmlinux module.""" """Process assignments from vmlinux module."""
# Only handle single-target assignments # Only handle single-target assignments
if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
# Use provided target_name (for aliased imports) or fall back to original name target_name = node.targets[0].id
if target_name is None:
target_name = node.targets[0].id
# Handle constant value assignments # Handle constant value assignments
if isinstance(node.value, ast.Constant): if isinstance(node.value, ast.Constant):

View File

@ -46,14 +46,13 @@ def debug_info_generation(
if struct.name.startswith("struct_"): if struct.name.startswith("struct_"):
struct_name = struct.name.removeprefix("struct_") struct_name = struct.name.removeprefix("struct_")
# Create struct type with all members
struct_type = generator.create_struct_type_with_name(
struct_name, members, struct.__sizeof__() * 8, is_distinct=True
)
else: else:
logger.warning("Blindly handling Unions present in vmlinux dependencies") raise ValueError("Unions are not supported in the current version")
struct_type = None # Create struct type with all members
# raise ValueError("Unions are not supported in the current version") struct_type = generator.create_struct_type_with_name(
struct_name, members, struct.__sizeof__() * 8, is_distinct=True
)
return struct_type return struct_type
@ -63,7 +62,7 @@ def _get_field_debug_type(
generator: DebugInfoGenerator, generator: DebugInfoGenerator,
parent_struct: DependencyNode, parent_struct: DependencyNode,
generated_debug_info: List[Tuple[DependencyNode, Any]], generated_debug_info: List[Tuple[DependencyNode, Any]],
) -> tuple[Any, int] | None: ) -> tuple[Any, int]:
""" """
Determine the appropriate debug type for a field based on its Python/ctypes type. Determine the appropriate debug type for a field based on its Python/ctypes type.
@ -79,11 +78,7 @@ def _get_field_debug_type(
""" """
# Handle complex types (arrays, pointers) # Handle complex types (arrays, pointers)
if field.ctype_complex_type is not None: if field.ctype_complex_type is not None:
# TODO: Check if this is a CFUNCTYPE (function pointer), but sadly it just checks callable for now if issubclass(field.ctype_complex_type, ctypes.Array):
if callable(field.ctype_complex_type):
# Handle function pointer types, create a void pointer as a placeholder
return generator.create_pointer_type(None), 64
elif issubclass(field.ctype_complex_type, ctypes.Array):
# Handle array types # Handle array types
element_type, base_type_size = _get_basic_debug_type( element_type, base_type_size = _get_basic_debug_type(
field.containing_type, generator field.containing_type, generator

View File

@ -11,10 +11,6 @@ logger = logging.getLogger(__name__)
class IRGenerator: class IRGenerator:
# This field keeps track of the non_struct names to avoid duplicate name errors.
type_number = 0
unprocessed_store = []
# get the assignments dict and add this stuff to it. # get the assignments dict and add this stuff to it.
def __init__(self, llvm_module, handler: DependencyHandler, assignments): def __init__(self, llvm_module, handler: DependencyHandler, assignments):
self.llvm_module = llvm_module self.llvm_module = llvm_module
@ -72,7 +68,6 @@ class IRGenerator:
dep_node_from_dependency, processing_stack dep_node_from_dependency, processing_stack
) )
else: else:
print(struct)
raise RuntimeError( raise RuntimeError(
f"Warning: Dependency {dependency} not found in handler" f"Warning: Dependency {dependency} not found in handler"
) )
@ -87,7 +82,6 @@ class IRGenerator:
members_dict = {} members_dict = {}
for field_name, field in struct.fields.items(): for field_name, field in struct.fields.items():
# Get the generated field name from our dictionary, or use field_name if not found # Get the generated field name from our dictionary, or use field_name if not found
print(f"DEBUG: {struct.name}, {field_name}")
if ( if (
struct.name in self.generated_field_names struct.name in self.generated_field_names
and field_name in self.generated_field_names[struct.name] and field_name in self.generated_field_names[struct.name]
@ -135,20 +129,7 @@ class IRGenerator:
for field_name, field in struct.fields.items(): for field_name, field in struct.fields.items():
# does not take arrays and similar types into consideration yet. # does not take arrays and similar types into consideration yet.
if callable(field.ctype_complex_type): if field.ctype_complex_type is not None and issubclass(
# Function pointer case - generate a simple field accessor
field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index
)
print(field_co_re_name)
field_index += 1
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
elif field.ctype_complex_type is not None and issubclass(
field.ctype_complex_type, ctypes.Array field.ctype_complex_type, ctypes.Array
): ):
array_size = field.type_size array_size = field.type_size
@ -156,7 +137,7 @@ class IRGenerator:
if containing_type.__module__ == ctypes.__name__: if containing_type.__module__ == ctypes.__name__:
containing_type_size = ctypes.sizeof(containing_type) containing_type_size = ctypes.sizeof(containing_type)
if array_size == 0: if array_size == 0:
field_co_re_name, returned = self._struct_name_generator( field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, 0, containing_type_size struct, field, field_index, True, 0, containing_type_size
) )
globvar = ir.GlobalVariable( globvar = ir.GlobalVariable(
@ -168,7 +149,7 @@ class IRGenerator:
field_index += 1 field_index += 1
continue continue
for i in range(0, array_size): for i in range(0, array_size):
field_co_re_name, returned = self._struct_name_generator( field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, i, containing_type_size struct, field, field_index, True, i, containing_type_size
) )
globvar = ir.GlobalVariable( globvar = ir.GlobalVariable(
@ -182,30 +163,12 @@ class IRGenerator:
array_size = field.type_size array_size = field.type_size
containing_type = field.containing_type containing_type = field.containing_type
if containing_type.__module__ == "vmlinux": if containing_type.__module__ == "vmlinux":
print(struct) containing_type_size = self.handler[
# Unwrap all pointer layers to get the base struct type containing_type.__name__
base_containing_type = containing_type ].current_offset
while hasattr(base_containing_type, "_type_"): for i in range(0, array_size):
next_type = base_containing_type._type_ field_co_re_name = self._struct_name_generator(
# Stop if _type_ is a string (like 'c' for c_char) struct, field, field_index, True, i, containing_type_size
# TODO: stacked pointers not handl;ing ctypes check here as well
if isinstance(next_type, str):
break
base_containing_type = next_type
# Get the base struct name
base_struct_name = (
base_containing_type.__name__
if hasattr(base_containing_type, "__name__")
else str(base_containing_type)
)
# Look up the size using the base struct name
containing_type_size = self.handler[base_struct_name].current_offset
print(f"GAY: {array_size}, {struct.name}, {field_name}")
if array_size == 0:
field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index, True, 0, containing_type_size
) )
globvar = ir.GlobalVariable( globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name self.llvm_module, ir.IntType(64), name=field_co_re_name
@ -213,30 +176,9 @@ class IRGenerator:
globvar.linkage = "external" globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info) globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar self.generated_field_names[struct.name][field_name] = globvar
field_index += 1 field_index += 1
else:
for i in range(0, array_size):
field_co_re_name, returned = self._struct_name_generator(
struct,
field,
field_index,
True,
i,
containing_type_size,
)
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata(
"llvm.preserve.access.index", debug_info
)
self.generated_field_names[struct.name][field_name] = (
globvar
)
field_index += 1
else: else:
field_co_re_name, returned = self._struct_name_generator( field_co_re_name = self._struct_name_generator(
struct, field, field_index struct, field, field_index
) )
field_index += 1 field_index += 1
@ -256,7 +198,7 @@ class IRGenerator:
is_indexed: bool = False, is_indexed: bool = False,
index: int = 0, index: int = 0,
containing_type_size: int = 0, containing_type_size: int = 0,
) -> tuple[str, bool]: ) -> str:
# TODO: Does not support Unions as well as recursive pointer and array type naming # TODO: Does not support Unions as well as recursive pointer and array type naming
if is_indexed: if is_indexed:
name = ( name = (
@ -266,7 +208,7 @@ class IRGenerator:
+ "$" + "$"
+ f"0:{field_index}:{index}" + f"0:{field_index}:{index}"
) )
return name, True return name
elif struct.name.startswith("struct_"): elif struct.name.startswith("struct_"):
name = ( name = (
"llvm." "llvm."
@ -275,18 +217,9 @@ class IRGenerator:
+ "$" + "$"
+ f"0:{field_index}" + f"0:{field_index}"
) )
return name, True return name
else: else:
logger.warning( print(self.handler[struct.name])
"Blindly handling non-struct type to avoid type errors in vmlinux IR generation. Possibly a union." raise TypeError(
"Name generation cannot occur due to type name not starting with struct"
) )
self.type_number += 1
unprocessed_type = "unprocessed_type_" + str(self.handler[struct.name].name)
if self.unprocessed_store.__contains__(unprocessed_type):
return unprocessed_type + "_" + str(self.type_number), False
else:
self.unprocessed_store.append(unprocessed_type)
return unprocessed_type, False
# raise TypeError(
# "Name generation cannot occur due to type name not starting with struct"
# )

View File

@ -1,6 +1,6 @@
import logging import logging
from typing import Any from typing import Any
import ctypes
from llvmlite import ir from llvmlite import ir
from pythonbpf.local_symbol import LocalSymbol from pythonbpf.local_symbol import LocalSymbol
@ -94,22 +94,19 @@ class VmlinuxHandler:
f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}" f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}"
) )
python_type: type = var_info.metadata python_type: type = var_info.metadata
globvar_ir, field_data = self.get_field_type( struct_name = python_type.__name__
python_type.__name__, field_name globvar_ir, field_data = self.get_field_type(struct_name, field_name)
)
builder.function.args[0].type = ir.PointerType(ir.IntType(8)) builder.function.args[0].type = ir.PointerType(ir.IntType(8))
print(builder.function.args[0])
field_ptr = self.load_ctx_field( field_ptr = self.load_ctx_field(
builder, builder.function.args[0], globvar_ir builder, builder.function.args[0], globvar_ir, field_data, struct_name
) )
print(field_ptr)
# Return pointer to field and field type # Return pointer to field and field type
return field_ptr, field_data return field_ptr, field_data
else: else:
raise RuntimeError("Variable accessed not found in symbol table") raise RuntimeError("Variable accessed not found in symbol table")
@staticmethod @staticmethod
def load_ctx_field(builder, ctx_arg, offset_global): def load_ctx_field(builder, ctx_arg, offset_global, field_data, struct_name=None):
""" """
Generate LLVM IR to load a field from BPF context using offset. Generate LLVM IR to load a field from BPF context using offset.
@ -117,9 +114,10 @@ class VmlinuxHandler:
builder: llvmlite IRBuilder instance builder: llvmlite IRBuilder instance
ctx_arg: The context pointer argument (ptr/i8*) ctx_arg: The context pointer argument (ptr/i8*)
offset_global: Global variable containing the field offset (i64) offset_global: Global variable containing the field offset (i64)
field_data: contains data about the field
struct_name: Name of the struct being accessed (optional)
Returns: Returns:
The loaded value (i64 register) The loaded value (i64 register or appropriately sized)
""" """
# Load the offset value # Load the offset value
@ -164,13 +162,61 @@ class VmlinuxHandler:
passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True
) )
# Bitcast to i64* (assuming field is 64-bit, adjust if needed) # Determine the appropriate IR type based on field information
i64_ptr_type = ir.PointerType(ir.IntType(64)) int_width = 64 # Default to 64-bit
typed_ptr = builder.bitcast(verified_ptr, i64_ptr_type) needs_zext = False # Track if we need zero-extension for xdp_md
if field_data is not None:
# Try to determine the size from field metadata
if field_data.type.__module__ == ctypes.__name__:
try:
field_size_bytes = ctypes.sizeof(field_data.type)
field_size_bits = field_size_bytes * 8
if field_size_bits in [8, 16, 32, 64]:
int_width = field_size_bits
logger.info(f"Determined field size: {int_width} bits")
# Special handling for struct_xdp_md i32 fields
# Load as i32 but extend to i64 before storing
if struct_name == "struct_xdp_md" and int_width == 32:
needs_zext = True
logger.info(
"struct_xdp_md i32 field detected, will zero-extend to i64"
)
else:
logger.warning(
f"Unusual field size {field_size_bits} bits, using default 64"
)
except Exception as e:
logger.warning(
f"Could not determine field size: {e}, using default 64"
)
elif field_data.type.__module__ == "vmlinux":
# For pointers to structs or complex vmlinux types
if field_data.ctype_complex_type is not None and issubclass(
field_data.ctype_complex_type, ctypes._Pointer
):
int_width = 64 # Pointers are always 64-bit
logger.info("Field is a pointer type, using 64 bits")
# TODO: Add handling for other complex types (arrays, embedded structs, etc.)
else:
logger.warning("Complex vmlinux field type, using default 64 bits")
# Bitcast to appropriate pointer type based on determined width
ptr_type = ir.PointerType(ir.IntType(int_width))
typed_ptr = builder.bitcast(verified_ptr, ptr_type)
# Load and return the value # Load and return the value
value = builder.load(typed_ptr) value = builder.load(typed_ptr)
# Zero-extend i32 to i64 for struct_xdp_md fields
if needs_zext:
value = builder.zext(value, ir.IntType(64))
logger.info("Zero-extended i32 value to i64 for struct_xdp_md field")
return value return value
def has_field(self, struct_name, field_name): def has_field(self, struct_name, field_name):

View File

@ -1,19 +1,23 @@
BPF_CLANG := clang BPF_CLANG := clang
CFLAGS := -O0 -emit-llvm -target bpf -c CFLAGS := -emit-llvm -target bpf -c
SRC := $(wildcard *.bpf.c) SRC := $(wildcard *.bpf.c)
LL := $(SRC:.bpf.c=.bpf.ll) LL := $(SRC:.bpf.c=.bpf.ll)
LL2 := $(SRC:.bpf.c=.bpf.o2.ll)
OBJ := $(SRC:.bpf.c=.bpf.o) OBJ := $(SRC:.bpf.c=.bpf.o)
.PHONY: all clean .PHONY: all clean
all: $(LL) $(OBJ) all: $(LL) $(OBJ) $(LL2)
%.bpf.o: %.bpf.c %.bpf.o: %.bpf.c
$(BPF_CLANG) -O2 -g -target bpf -c $< -o $@ $(BPF_CLANG) -O2 -g -target bpf -c $< -o $@
%.bpf.ll: %.bpf.c %.bpf.ll: %.bpf.c
$(BPF_CLANG) $(CFLAGS) -g -S $< -o $@ $(BPF_CLANG) -O0 $(CFLAGS) -g -S $< -o $@
%.bpf.o2.ll: %.bpf.c
$(BPF_CLANG) -O2 $(CFLAGS) -g -S $< -o $@
clean: clean:
rm -f $(LL) $(OBJ) rm -f $(LL) $(OBJ) $(LL2)

View File

@ -0,0 +1,15 @@
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
SEC("xdp")
int print_xdp_data(struct xdp_md *ctx)
{
// 'data' is a pointer to the start of packet data
long data = (long)ctx->data;
bpf_printk("ctx->data = %lld\n", data);
return XDP_PASS;
}
char LICENSE[] SEC("license") = "GPL";

View File

@ -0,0 +1,30 @@
import logging
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
from pythonbpf import compile # noqa: F401
from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
from ctypes import c_int64, c_int32, c_void_p # noqa: F401
# from vmlinux import struct_uinput_device
# from vmlinux import struct_blk_integrity_iter
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
b = ctx.args
c = b[0]
print(f"This is context args field {c}")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("args_test.py", "args_test.ll", loglevel=logging.INFO)
compile()

View File

@ -1,22 +0,0 @@
from vmlinux import XDP_PASS
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging
from ctypes import c_int64, c_void_p
@bpf
@section("kprobe/blk_mq_start_request")
def example(ctx: c_void_p) -> c_int64:
d = XDP_PASS # This gives an error, but
e = XDP_PASS + 0 # this does not
print(f"test1 {e} test2 {d}")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("assignment_handling.py", "assignment_handling.ll", loglevel=logging.INFO)

View File

@ -1,22 +0,0 @@
from vmlinux import struct_request, struct_pt_regs
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging
from ctypes import c_int64
@bpf
@section("kprobe/blk_mq_start_request")
def example(ctx: struct_pt_regs) -> c_int64:
req = struct_request(ctx.di)
c = req.__data_len
print(f"data length {c}")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("requests.py", "requests.ll", loglevel=logging.INFO)

View File

@ -1,21 +0,0 @@
from vmlinux import struct_pt_regs, struct_request
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging
from ctypes import c_int64
@bpf
@section("kprobe/blk_mq_start_request")
def example(ctx: struct_pt_regs) -> c_int64:
req = ctx.di
print(f"data length {req}")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("requests2.py", "requests2.ll", loglevel=logging.INFO)

View File

@ -0,0 +1,31 @@
from ctypes import c_int64, c_void_p
from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile
from vmlinux import struct_xdp_md
from vmlinux import XDP_PASS
@bpf
@section("xdp")
def print_xdp_dat2a(ct2x: struct_xdp_md) -> c_int64:
data = ct2x.data # 32-bit field: packet start pointer
print(f"ct2x->data = {data}")
return c_int64(XDP_PASS)
@bpf
@section("xdp")
def print_xdp_data(ctx: struct_xdp_md) -> c_int64:
data = ctx.data # 32-bit field: packet start pointer
something = c_void_p(data)
print(f"ctx->data = {something}")
return c_int64(XDP_PASS)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("i32_test.py", "i32_test.ll")
compile()

View File

@ -0,0 +1,24 @@
from ctypes import c_int64
from pythonbpf import bpf, section, bpfglobal, compile
from vmlinux import struct_xdp_md
from vmlinux import XDP_PASS
import logging
@bpf
@section("xdp")
def print_xdp_data(ctx: struct_xdp_md) -> c_int64:
data = 0
data = ctx.data # 32-bit field: packet start pointer
something = 2 + data
print(f"ctx->data = {something}")
return c_int64(XDP_PASS)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile(logging.INFO)

View File

@ -0,0 +1,24 @@
from ctypes import c_int64
from pythonbpf import bpf, section, bpfglobal, compile, compile_to_ir
from vmlinux import struct_xdp_md
from vmlinux import XDP_PASS
import logging
@bpf
@section("xdp")
def print_xdp_data(ctx: struct_xdp_md) -> c_int64:
data = c_int64(ctx.data) # 32-bit field: packet start pointer
something = 2 + data
print(f"ctx->data = {something}")
return c_int64(XDP_PASS)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("i32_test_fail_2.py", "i32_test_fail_2.ll")
compile(logging.INFO)

View File

@ -44,4 +44,4 @@ def LICENSE() -> str:
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG) compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG)
# compile() compile()