diff --git a/pythonbpf/vmlinux_parser/vmlinux_class_handler.py b/pythonbpf/vmlinux_parser/vmlinux_class_handler.py index a5bdf5e..8f9e904 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_class_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_class_handler.py @@ -4,6 +4,7 @@ from functools import lru_cache import importlib from .dependency_handler import DependencyHandler from .dependency_node import DependencyNode +import ctypes logger = logging.getLogger(__name__) @@ -12,7 +13,8 @@ def get_module_symbols(module_name: str): imported_module = importlib.import_module(module_name) return [name for name in dir(imported_module)], imported_module -def process_vmlinux_class(node, llvm_module, handler: DependencyHandler, parent=""): +# Recursive function that gets all the dependent classes and adds them to handler +def process_vmlinux_class(node, llvm_module, handler: DependencyHandler): symbols_in_module, imported_module = get_module_symbols("vmlinux") current_symbol_name = node.name if current_symbol_name not in symbols_in_module: @@ -28,7 +30,7 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler, parent= # Inspect the class fields # Assuming class_obj has fields stored in some standard way - #If it's a ctypes-like structure with _fields_ + # If it's a ctypes-like structure with _fields_ if hasattr(class_obj, '_fields_'): for field_name, field_type in class_obj._fields_: field_table[field_name] = field_type @@ -42,5 +44,21 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler, parent= raise TypeError("Could not get required class and definition") logger.info(f"Extracted fields for {current_symbol_name}: {field_table}") - return field_table + if handler.has_node(current_symbol_name): + logger.info("Extraction pruned due to already available field") + return True + else: + new_dep_node = DependencyNode(name=current_symbol_name) + for elem_name, elem_type in field_table.items(): + module_name = getattr(elem_type, "__module__", None) + if module_name == ctypes.__name__: + new_dep_node.add_field(elem_name, elem_type, ready=True) + elif module_name == "vmlinux": + new_dep_node.add_field(elem_name, elem_type, ready=False) + if process_vmlinux_class(elem_type, llvm_module, handler): + new_dep_node.set_field_ready(elem_name, True) + else: + print(f"[other] {elem_name} -> {elem_type}") + handler.add_node(new_dep_node) + return True diff --git a/tests/failing_tests/xdp_pass.py b/tests/failing_tests/xdp_pass.py index 2e3c644..dddce92 100644 --- a/tests/failing_tests/xdp_pass.py +++ b/tests/failing_tests/xdp_pass.py @@ -2,7 +2,6 @@ from pythonbpf import bpf, map, section, bpfglobal, compile, compile_to_ir from pythonbpf.maps import HashMap from pythonbpf.helper import XDP_PASS from vmlinux import struct_xdp_md -# from vmlinux import XDP_PASS from ctypes import c_int64 # Instructions to how to run this program