4 Commits

6 changed files with 223 additions and 93 deletions

View File

@ -16,10 +16,37 @@ def get_module_symbols(module_name: str):
return [name for name in dir(imported_module)], imported_module return [name for name in dir(imported_module)], imported_module
def unwrap_pointer_type(type_obj: Any) -> Any:
    """
    Recursively unwrap all pointer/array layers to get the base type.

    The ``_type_`` attribute is followed repeatedly, so a multiply nested
    pointer such as LP_LP_struct_attribute_group resolves to its base type
    (struct_attribute_group).

    Unwrapping stops as soon as one of the following holds:
      * the current type has no ``_type_`` attribute,
      * ``_type_`` is a string (like 'c' for c_char), i.e. a type-code
        marker rather than another type to descend into,
      * following ``_type_`` would revisit a type already seen on this
        chain (guards against looping forever on a self-referential chain).

    Args:
        type_obj: The type object to unwrap

    Returns:
        The base type after unwrapping all pointer layers
    """
    current_type = type_obj
    # Keep the visited objects themselves (not their ids) so they stay alive
    # for the duration of the walk; chains are short, so a list scan is fine.
    visited = [current_type]
    # Keep unwrapping while it's a pointer/array type (has _type_)
    while hasattr(current_type, "_type_"):
        next_type = current_type._type_
        # Stop if _type_ is a string (like 'c' for c_char)
        if isinstance(next_type, str):
            break
        # Stop if the chain loops back onto a type we already unwrapped
        if any(next_type is earlier for earlier in visited):
            break
        visited.append(next_type)
        current_type = next_type
    return current_type
def process_vmlinux_class( def process_vmlinux_class(
node, node,
llvm_module, llvm_module,
handler: DependencyHandler, handler: DependencyHandler,
): ):
symbols_in_module, imported_module = get_module_symbols("vmlinux") symbols_in_module, imported_module = get_module_symbols("vmlinux")
if node.name in symbols_in_module: if node.name in symbols_in_module:
@ -30,10 +57,10 @@ def process_vmlinux_class(
def process_vmlinux_post_ast( def process_vmlinux_post_ast(
elem_type_class, elem_type_class,
llvm_handler, llvm_handler,
handler: DependencyHandler, handler: DependencyHandler,
processing_stack=None, processing_stack=None,
): ):
# Initialize processing stack on first call # Initialize processing stack on first call
if processing_stack is None: if processing_stack is None:
@ -113,7 +140,7 @@ def process_vmlinux_post_ast(
# Process pointer to ctype # Process pointer to ctype
if isinstance(elem_type, type) and issubclass( if isinstance(elem_type, type) and issubclass(
elem_type, ctypes._Pointer elem_type, ctypes._Pointer
): ):
# Get the pointed-to type # Get the pointed-to type
pointed_type = elem_type._type_ pointed_type = elem_type._type_
@ -126,7 +153,7 @@ def process_vmlinux_post_ast(
# Process function pointers (CFUNCTYPE) # Process function pointers (CFUNCTYPE)
elif hasattr(elem_type, "_restype_") and hasattr( elif hasattr(elem_type, "_restype_") and hasattr(
elem_type, "_argtypes_" elem_type, "_argtypes_"
): ):
# This is a CFUNCTYPE or similar # This is a CFUNCTYPE or similar
logger.info( logger.info(
@ -158,13 +185,19 @@ def process_vmlinux_post_ast(
if hasattr(elem_type, "_length_") and is_complex_type: if hasattr(elem_type, "_length_") and is_complex_type:
type_length = elem_type._length_ type_length = elem_type._length_
if containing_type.__module__ == "vmlinux": # Unwrap all pointer layers to get the base type for dependency tracking
new_dep_node.add_dependent( base_type = unwrap_pointer_type(elem_type)
elem_type._type_.__name__ base_type_module = getattr(base_type, "__module__", None)
if hasattr(elem_type._type_, "__name__")
else str(elem_type._type_) if base_type_module == "vmlinux":
base_type_name = (
base_type.__name__
if hasattr(base_type, "__name__")
else str(base_type)
) )
elif containing_type.__module__ == ctypes.__name__: new_dep_node.add_dependent(base_type_name)
elif base_type_module == ctypes.__name__ or base_type_module is None:
# Handle ctypes or types with no module (like some internal ctypes types)
if isinstance(elem_type, type): if isinstance(elem_type, type):
if issubclass(elem_type, ctypes.Array): if issubclass(elem_type, ctypes.Array):
ctype_complex_type = ctypes.Array ctype_complex_type = ctypes.Array
@ -178,7 +211,7 @@ def process_vmlinux_post_ast(
raise TypeError("Unsupported ctypes subclass") raise TypeError("Unsupported ctypes subclass")
else: else:
raise ImportError( raise ImportError(
f"Unsupported module of {containing_type}" f"Unsupported module of {base_type}: {base_type_module}"
) )
logger.debug( logger.debug(
f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}" f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}"
@ -191,11 +224,16 @@ def process_vmlinux_post_ast(
elem_name, ctype_complex_type elem_name, ctype_complex_type
) )
new_dep_node.set_field_type(elem_name, elem_type) new_dep_node.set_field_type(elem_name, elem_type)
if containing_type.__module__ == "vmlinux":
# Check the containing_type module to decide whether to recurse
containing_type_module = getattr(containing_type, "__module__", None)
if containing_type_module == "vmlinux":
# Also unwrap containing_type to get base type name
base_containing_type = unwrap_pointer_type(containing_type)
containing_type_name = ( containing_type_name = (
containing_type.__name__ base_containing_type.__name__
if hasattr(containing_type, "__name__") if hasattr(base_containing_type, "__name__")
else str(containing_type) else str(base_containing_type)
) )
# Check for self-reference or already processed # Check for self-reference or already processed
@ -212,21 +250,21 @@ def process_vmlinux_post_ast(
) )
new_dep_node.set_field_ready(elem_name, True) new_dep_node.set_field_ready(elem_name, True)
else: else:
# Process recursively - THIS WAS MISSING # Process recursively - use base containing type, not the pointer wrapper
new_dep_node.add_dependent(containing_type_name) new_dep_node.add_dependent(containing_type_name)
process_vmlinux_post_ast( process_vmlinux_post_ast(
containing_type, base_containing_type,
llvm_handler, llvm_handler,
handler, handler,
processing_stack, processing_stack,
) )
new_dep_node.set_field_ready(elem_name, True) new_dep_node.set_field_ready(elem_name, True)
elif containing_type.__module__ == ctypes.__name__: elif containing_type_module == ctypes.__name__ or containing_type_module is None:
logger.debug(f"Processing ctype internal{containing_type}") logger.debug(f"Processing ctype internal{containing_type}")
new_dep_node.set_field_ready(elem_name, True) new_dep_node.set_field_ready(elem_name, True)
else: else:
raise TypeError( raise TypeError(
"Module not supported in recursive resolution" f"Module not supported in recursive resolution: {containing_type_module}"
) )
else: else:
new_dep_node.add_dependent( new_dep_node.add_dependent(
@ -245,9 +283,12 @@ def process_vmlinux_post_ast(
raise ValueError( raise ValueError(
f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver" f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver"
) )
elif module_name == ctypes.__name__ or module_name is None:
# Handle ctypes types - these don't need processing, just return
logger.debug(f"Skipping ctypes type {current_symbol_name}")
return True
else: else:
raise ImportError("UNSUPPORTED Module") raise ImportError(f"UNSUPPORTED Module {module_name}")
logger.info( logger.info(
f"{current_symbol_name} processed and handler readiness {handler.is_ready}" f"{current_symbol_name} processed and handler readiness {handler.is_ready}"

View File

@ -11,7 +11,9 @@ from .class_handler import process_vmlinux_class
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]: def detect_import_statement(
tree: ast.AST,
) -> list[tuple[str, ast.ImportFrom, str, str]]:
""" """
Parse AST and detect import statements from vmlinux. Parse AST and detect import statements from vmlinux.
@ -25,7 +27,7 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]:
List of tuples containing (module_name, imported_item) for each vmlinux import List of tuples containing (module_name, imported_item) for each vmlinux import
Raises: Raises:
SyntaxError: If multiple imports from vmlinux are attempted or import * is used SyntaxError: If import * is used
""" """
vmlinux_imports = [] vmlinux_imports = []
@ -40,28 +42,19 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]:
"Please import specific types explicitly." "Please import specific types explicitly."
) )
# Check for multiple imports: from vmlinux import A, B, C
if len(node.names) > 1:
imported_names = [alias.name for alias in node.names]
raise SyntaxError(
f"Multiple imports from vmlinux are not supported. "
f"Found: {', '.join(imported_names)}. "
f"Please use separate import statements for each type."
)
# Check if no specific import is specified (should not happen with valid Python) # Check if no specific import is specified (should not happen with valid Python)
if len(node.names) == 0: if len(node.names) == 0:
raise SyntaxError( raise SyntaxError(
"Import from vmlinux must specify at least one type." "Import from vmlinux must specify at least one type."
) )
# Valid single import # Support multiple imports: from vmlinux import A, B, C
for alias in node.names: for alias in node.names:
import_name = alias.name import_name = alias.name
# Use alias if provided, otherwise use the original name (commented) # Use alias if provided, otherwise use the original name
# as_name = alias.asname if alias.asname else alias.name as_name = alias.asname if alias.asname else alias.name
vmlinux_imports.append(("vmlinux", node)) vmlinux_imports.append(("vmlinux", node, import_name, as_name))
logger.info(f"Found vmlinux import: {import_name}") logger.info(f"Found vmlinux import: {import_name} as {as_name}")
# Handle "import vmlinux" statements (not typical but should be rejected) # Handle "import vmlinux" statements (not typical but should be rejected)
elif isinstance(node, ast.Import): elif isinstance(node, ast.Import):
@ -73,6 +66,7 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]:
) )
logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}") logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}")
# print(f"\n**************\n{vmlinux_imports}\n**************\n")
return vmlinux_imports return vmlinux_imports
@ -103,40 +97,37 @@ def vmlinux_proc(tree: ast.AST, module):
with open(source_file, "r") as f: with open(source_file, "r") as f:
mod_ast = ast.parse(f.read(), filename=source_file) mod_ast = ast.parse(f.read(), filename=source_file)
for import_mod, import_node in import_statements: for import_mod, import_node, imported_name, as_name in import_statements:
for alias in import_node.names: found = False
imported_name = alias.name for mod_node in mod_ast.body:
found = False if isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name:
for mod_node in mod_ast.body: process_vmlinux_class(mod_node, module, handler)
if ( found = True
isinstance(mod_node, ast.ClassDef) break
and mod_node.name == imported_name if isinstance(mod_node, ast.Assign):
): for target in mod_node.targets:
process_vmlinux_class(mod_node, module, handler) if isinstance(target, ast.Name) and target.id == imported_name:
found = True process_vmlinux_assign(mod_node, module, assignments, as_name)
break found = True
if isinstance(mod_node, ast.Assign): break
for target in mod_node.targets: if found:
if isinstance(target, ast.Name) and target.id == imported_name: break
process_vmlinux_assign(mod_node, module, assignments) if not found:
found = True logger.info(f"{imported_name} not found as ClassDef or Assign in vmlinux")
break
if found:
break
if not found:
logger.info(
f"{imported_name} not found as ClassDef or Assign in vmlinux"
)
IRGenerator(module, handler, assignments) IRGenerator(module, handler, assignments)
return assignments return assignments
def process_vmlinux_assign(node, module, assignments: dict[str, AssignmentInfo]): def process_vmlinux_assign(
node, module, assignments: dict[str, AssignmentInfo], target_name=None
):
"""Process assignments from vmlinux module.""" """Process assignments from vmlinux module."""
# Only handle single-target assignments # Only handle single-target assignments
if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
target_name = node.targets[0].id # Use provided target_name (for aliased imports) or fall back to original name
if target_name is None:
target_name = node.targets[0].id
# Handle constant value assignments # Handle constant value assignments
if isinstance(node.value, ast.Constant): if isinstance(node.value, ast.Constant):

View File

@ -46,13 +46,14 @@ def debug_info_generation(
if struct.name.startswith("struct_"): if struct.name.startswith("struct_"):
struct_name = struct.name.removeprefix("struct_") struct_name = struct.name.removeprefix("struct_")
# Create struct type with all members
struct_type = generator.create_struct_type_with_name(
struct_name, members, struct.__sizeof__() * 8, is_distinct=True
)
else: else:
raise ValueError("Unions are not supported in the current version") logger.warning("Blindly handling Unions present in vmlinux dependencies")
# Create struct type with all members struct_type = None
struct_type = generator.create_struct_type_with_name( # raise ValueError("Unions are not supported in the current version")
struct_name, members, struct.__sizeof__() * 8, is_distinct=True
)
return struct_type return struct_type
@ -62,7 +63,7 @@ def _get_field_debug_type(
generator: DebugInfoGenerator, generator: DebugInfoGenerator,
parent_struct: DependencyNode, parent_struct: DependencyNode,
generated_debug_info: List[Tuple[DependencyNode, Any]], generated_debug_info: List[Tuple[DependencyNode, Any]],
) -> tuple[Any, int]: ) -> tuple[Any, int] | None:
""" """
Determine the appropriate debug type for a field based on its Python/ctypes type. Determine the appropriate debug type for a field based on its Python/ctypes type.
@ -78,7 +79,11 @@ def _get_field_debug_type(
""" """
# Handle complex types (arrays, pointers) # Handle complex types (arrays, pointers)
if field.ctype_complex_type is not None: if field.ctype_complex_type is not None:
if issubclass(field.ctype_complex_type, ctypes.Array): #TODO: Check if this is a CFUNCTYPE (function pointer), but sadly it just checks callable for now
if callable(field.ctype_complex_type):
# Handle function pointer types, create a void pointer as a placeholder
return generator.create_pointer_type(None), 64
elif issubclass(field.ctype_complex_type, ctypes.Array):
# Handle array types # Handle array types
element_type, base_type_size = _get_basic_debug_type( element_type, base_type_size = _get_basic_debug_type(
field.containing_type, generator field.containing_type, generator

View File

@ -11,6 +11,9 @@ logger = logging.getLogger(__name__)
class IRGenerator: class IRGenerator:
# This field keeps track of the non_struct names to avoid duplicate name errors.
type_number = 0
unprocessed_store = []
# get the assignments dict and add this stuff to it. # get the assignments dict and add this stuff to it.
def __init__(self, llvm_module, handler: DependencyHandler, assignments): def __init__(self, llvm_module, handler: DependencyHandler, assignments):
self.llvm_module = llvm_module self.llvm_module = llvm_module
@ -68,6 +71,7 @@ class IRGenerator:
dep_node_from_dependency, processing_stack dep_node_from_dependency, processing_stack
) )
else: else:
print(struct)
raise RuntimeError( raise RuntimeError(
f"Warning: Dependency {dependency} not found in handler" f"Warning: Dependency {dependency} not found in handler"
) )
@ -82,6 +86,7 @@ class IRGenerator:
members_dict = {} members_dict = {}
for field_name, field in struct.fields.items(): for field_name, field in struct.fields.items():
# Get the generated field name from our dictionary, or use field_name if not found # Get the generated field name from our dictionary, or use field_name if not found
print(f"DEBUG: {struct.name}, {field_name}")
if ( if (
struct.name in self.generated_field_names struct.name in self.generated_field_names
and field_name in self.generated_field_names[struct.name] and field_name in self.generated_field_names[struct.name]
@ -129,7 +134,20 @@ class IRGenerator:
for field_name, field in struct.fields.items(): for field_name, field in struct.fields.items():
# does not take arrays and similar types into consideration yet. # does not take arrays and similar types into consideration yet.
if field.ctype_complex_type is not None and issubclass( if callable(field.ctype_complex_type):
# Function pointer case - generate a simple field accessor
field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index
)
print(field_co_re_name)
field_index += 1
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
elif field.ctype_complex_type is not None and issubclass(
field.ctype_complex_type, ctypes.Array field.ctype_complex_type, ctypes.Array
): ):
array_size = field.type_size array_size = field.type_size
@ -137,7 +155,7 @@ class IRGenerator:
if containing_type.__module__ == ctypes.__name__: if containing_type.__module__ == ctypes.__name__:
containing_type_size = ctypes.sizeof(containing_type) containing_type_size = ctypes.sizeof(containing_type)
if array_size == 0: if array_size == 0:
field_co_re_name = self._struct_name_generator( field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index, True, 0, containing_type_size struct, field, field_index, True, 0, containing_type_size
) )
globvar = ir.GlobalVariable( globvar = ir.GlobalVariable(
@ -149,7 +167,7 @@ class IRGenerator:
field_index += 1 field_index += 1
continue continue
for i in range(0, array_size): for i in range(0, array_size):
field_co_re_name = self._struct_name_generator( field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index, True, i, containing_type_size struct, field, field_index, True, i, containing_type_size
) )
globvar = ir.GlobalVariable( globvar = ir.GlobalVariable(
@ -163,12 +181,26 @@ class IRGenerator:
array_size = field.type_size array_size = field.type_size
containing_type = field.containing_type containing_type = field.containing_type
if containing_type.__module__ == "vmlinux": if containing_type.__module__ == "vmlinux":
containing_type_size = self.handler[ print(struct)
containing_type.__name__ # Unwrap all pointer layers to get the base struct type
].current_offset base_containing_type = containing_type
for i in range(0, array_size): while hasattr(base_containing_type, "_type_"):
field_co_re_name = self._struct_name_generator( next_type = base_containing_type._type_
struct, field, field_index, True, i, containing_type_size # Stop if _type_ is a string (like 'c' for c_char)
#TODO: stacked pointers not handl;ing ctypes check here as well
if isinstance(next_type, str):
break
base_containing_type = next_type
# Get the base struct name
base_struct_name = base_containing_type.__name__ if hasattr(base_containing_type, "__name__") else str(base_containing_type)
# Look up the size using the base struct name
containing_type_size = self.handler[base_struct_name].current_offset
print(f"GAY: {array_size}, {struct.name}, {field_name}")
if array_size == 0:
field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index, True, 0, containing_type_size
) )
globvar = ir.GlobalVariable( globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name self.llvm_module, ir.IntType(64), name=field_co_re_name
@ -176,9 +208,21 @@ class IRGenerator:
globvar.linkage = "external" globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info) globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar self.generated_field_names[struct.name][field_name] = globvar
field_index += 1 field_index += 1
else:
for i in range(0, array_size):
field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index, True, i, containing_type_size
)
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1
else: else:
field_co_re_name = self._struct_name_generator( field_co_re_name, returned = self._struct_name_generator(
struct, field, field_index struct, field, field_index
) )
field_index += 1 field_index += 1
@ -198,7 +242,7 @@ class IRGenerator:
is_indexed: bool = False, is_indexed: bool = False,
index: int = 0, index: int = 0,
containing_type_size: int = 0, containing_type_size: int = 0,
) -> str: ) -> tuple[str, bool]:
# TODO: Does not support Unions as well as recursive pointer and array type naming # TODO: Does not support Unions as well as recursive pointer and array type naming
if is_indexed: if is_indexed:
name = ( name = (
@ -208,7 +252,7 @@ class IRGenerator:
+ "$" + "$"
+ f"0:{field_index}:{index}" + f"0:{field_index}:{index}"
) )
return name return name, True
elif struct.name.startswith("struct_"): elif struct.name.startswith("struct_"):
name = ( name = (
"llvm." "llvm."
@ -217,9 +261,18 @@ class IRGenerator:
+ "$" + "$"
+ f"0:{field_index}" + f"0:{field_index}"
) )
return name return name, True
else: else:
print(self.handler[struct.name]) logger.warning(
raise TypeError( "Blindly handling non-struct type to avoid type errors in vmlinux IR generation. Possibly a union."
"Name generation cannot occur due to type name not starting with struct"
) )
self.type_number += 1
unprocessed_type = "unprocessed_type_" + str(self.handler[struct.name].name)
if self.unprocessed_store.__contains__(unprocessed_type):
return unprocessed_type + "_" + str(self.type_number), False
else:
self.unprocessed_store.append(unprocessed_type)
return unprocessed_type, False
# raise TypeError(
# "Name generation cannot occur due to type name not starting with struct"
# )

View File

@ -0,0 +1,21 @@
from vmlinux import struct_request, struct_pt_regs, XDP_PASS
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging
# Example program exercising a multi-name vmlinux import
# (struct_request, struct_pt_regs and the XDP_PASS constant).
# It is decorated with @bpf and placed in the
# "kprobe/blk_mq_start_request" section.
# NOTE(review): only '#' comments are used here; the body is presumably
# compiled by pythonbpf, so no docstring statement is added — confirm.
@bpf
@section("kprobe/blk_mq_start_request")
def example(ctx: struct_pt_regs):
# Build a struct_request from the value in ctx.di
# (presumably a pointer cast of the probed function's first argument — TODO confirm).
req = struct_request(ctx.di)
# Read the __data_len field of the request.
c = req.__data_len
# XDP_PASS is the constant imported from vmlinux alongside the structs.
d = XDP_PASS
# Emit both values through print.
print(f"data length {c} and test {d}")
# Global declared with @bpf and @bpfglobal that returns the string "GPL"
# (presumably used as the program's license declaration — TODO confirm).
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Run the compiler at import time with INFO-level logging.
# NOTE(review): the arguments are presumably (source file, output IR file);
# confirm "requests.py" is really this script's own filename.
compile_to_ir("requests.py", "requests.ll", loglevel=logging.INFO)

View File

@ -0,0 +1,19 @@
from vmlinux import struct_kobj_type
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging
from ctypes import c_void_p
# Minimal program decorated with @bpf and placed in the
# "kprobe/blk_mq_start_request" section. It takes a c_void_p context that
# the body never reads and only prints a fixed message.
# NOTE(review): the message reads "data lengt" (likely a typo for
# "data length") and the f-string has no placeholders; both are left
# unchanged here because they are runtime text.
@bpf
@section("kprobe/blk_mq_start_request")
def example(ctx: c_void_p):
print(f"data lengt")
# Global declared with @bpf and @bpfglobal that returns the string "GPL"
# (presumably used as the program's license declaration — TODO confirm).
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Run the compiler at import time with INFO-level logging.
# NOTE(review): the hard-coded "requests.py" / "requests.ll" paths look
# copied from another example; confirm they match this script's filename.
compile_to_ir("requests.py", "requests.ll", loglevel=logging.INFO)