diff --git a/pythonbpf/vmlinux_parser/class_handler.py b/pythonbpf/vmlinux_parser/class_handler.py index 58168b3..a508ff7 100644 --- a/pythonbpf/vmlinux_parser/class_handler.py +++ b/pythonbpf/vmlinux_parser/class_handler.py @@ -2,9 +2,8 @@ import logging from functools import lru_cache import importlib -from .assignment_info import AssignmentInfo, AssignmentType from .dependency_handler import DependencyHandler -from .dependency_node import DependencyNode, Field +from .dependency_node import DependencyNode import ctypes from typing import Optional, Any, Dict @@ -21,12 +20,11 @@ def process_vmlinux_class( node, llvm_module, handler: DependencyHandler, - assignments: dict[str, AssignmentInfo], ): symbols_in_module, imported_module = get_module_symbols("vmlinux") if node.name in symbols_in_module: vmlinux_type = getattr(imported_module, node.name) - process_vmlinux_post_ast(vmlinux_type, llvm_module, handler, assignments) + process_vmlinux_post_ast(vmlinux_type, llvm_module, handler) else: raise ImportError(f"{node.name} not in vmlinux") @@ -35,7 +33,6 @@ def process_vmlinux_post_ast( elem_type_class, llvm_handler, handler: DependencyHandler, - assignments: dict[str, AssignmentInfo], processing_stack=None, ): # Initialize processing stack on first call @@ -103,9 +100,6 @@ def process_vmlinux_post_ast( else: raise TypeError("Could not get required class and definition") - # Create a members dictionary for AssignmentInfo - members_dict: Dict[str, tuple[str, Field]] = {} - logger.debug(f"Extracted fields for {current_symbol_name}: {field_table}") for elem in field_table.items(): elem_name, elem_temp_list = elem @@ -113,11 +107,6 @@ def process_vmlinux_post_ast( local_module_name = getattr(elem_type, "__module__", None) new_dep_node.add_field(elem_name, elem_type, ready=False) - # Store field reference for struct assignment info - field_ref = new_dep_node.get_field(elem_name) - if field_ref: - members_dict[elem_name] = (elem_name, field_ref) - if local_module_name == ctypes.__name__: # TODO: need to process pointer to ctype and also CFUNCTYPES here recursively. Current processing is a single dereference new_dep_node.set_field_bitfield_size(elem_name, elem_bitfield_size) @@ -229,7 +218,6 @@ def process_vmlinux_post_ast( containing_type, llvm_handler, handler, - assignments, # Pass assignments to recursive call processing_stack, ) new_dep_node.set_field_ready(elem_name, True) @@ -250,7 +238,6 @@ def process_vmlinux_post_ast( elem_type, llvm_handler, handler, - assignments, processing_stack, ) new_dep_node.set_field_ready(elem_name, True) @@ -259,17 +246,6 @@ def process_vmlinux_post_ast( f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver" ) - # Add struct to assignments dictionary - assignments[current_symbol_name] = AssignmentInfo( - value_type=AssignmentType.STRUCT, - python_type=elem_type_class, - value=None, - pointer_level=None, - signature=None, - members=members_dict, - ) - logger.info(f"Added struct assignment info for {current_symbol_name}") - else: raise ImportError("UNSUPPORTED Module") diff --git a/pythonbpf/vmlinux_parser/import_detector.py b/pythonbpf/vmlinux_parser/import_detector.py index d8bd78f..6df7a98 100644 --- a/pythonbpf/vmlinux_parser/import_detector.py +++ b/pythonbpf/vmlinux_parser/import_detector.py @@ -112,7 +112,7 @@ def vmlinux_proc(tree: ast.AST, module): isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name ): - process_vmlinux_class(mod_node, module, handler, assignments) + process_vmlinux_class(mod_node, module, handler) found = True break if isinstance(mod_node, ast.Assign): @@ -128,7 +128,7 @@ def vmlinux_proc(tree: ast.AST, module): f"{imported_name} not found as ClassDef or Assign in vmlinux" ) - IRGenerator(module, handler) + IRGenerator(module, handler, assignments) return assignments diff --git a/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py b/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py index cacd2e7..bd0adfa 100644 --- a/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py +++ b/pythonbpf/vmlinux_parser/ir_gen/ir_generation.py @@ -1,5 +1,8 @@ import ctypes import logging + +from ..dependency_node import Field +from ..assignment_info import AssignmentInfo, AssignmentType from ..dependency_handler import DependencyHandler from .debug_info_gen import debug_info_generation from ..dependency_node import DependencyNode @@ -10,11 +13,13 @@ logger = logging.getLogger(__name__) class IRGenerator: # get the assignments dict and add this stuff to it. - def __init__(self, llvm_module, handler: DependencyHandler, assignment=None): + def __init__(self, llvm_module, handler: DependencyHandler, assignments): self.llvm_module = llvm_module self.handler: DependencyHandler = handler self.generated: list[str] = [] self.generated_debug_info: list = [] + self.generated_field_names: dict[Field, str] = {} + self.assignments: dict[str, AssignmentInfo] = assignments if not handler.is_ready: raise ImportError( "Semantic analysis of vmlinux imports failed. Cannot generate IR" @@ -67,6 +72,24 @@ class IRGenerator: f"Warning: Dependency {dependency} not found in handler" ) + # Fill the assignments dictionary with struct information + if struct.name not in self.assignments: + # Create a members dictionary for AssignmentInfo + members_dict = {} + for field_name, field in struct.fields.items(): + members_dict[field_name] = (self.generated_field_names[field], field) + + # Add struct to assignments dictionary + self.assignments[struct.name] = AssignmentInfo( + value_type=AssignmentType.STRUCT, + python_type=struct.ctype_struct, + value=None, + pointer_level=None, + signature=None, + members=members_dict, + ) + logger.info(f"Added struct assignment info for {struct.name}") + # Actual processor logic here after dependencies are resolved self.generated_debug_info.append( (struct, self.gen_ir(struct, self.generated_debug_info)) @@ -98,6 +121,7 @@ class IRGenerator: field_co_re_name = self._struct_name_generator( struct, field, field_index, True, i, containing_type_size ) + self.generated_field_names[field] = field_co_re_name globvar = ir.GlobalVariable( self.llvm_module, ir.IntType(64), name=field_co_re_name ) @@ -115,6 +139,7 @@ class IRGenerator: field_co_re_name = self._struct_name_generator( struct, field, field_index, True, i, containing_type_size ) + self.generated_field_names[field] = field_co_re_name globvar = ir.GlobalVariable( self.llvm_module, ir.IntType(64), name=field_co_re_name ) @@ -125,6 +150,7 @@ class IRGenerator: field_co_re_name = self._struct_name_generator( struct, field, field_index ) + self.generated_field_names[field] = field_co_re_name field_index += 1 globvar = ir.GlobalVariable( self.llvm_module, ir.IntType(64), name=field_co_re_name