format chore

This commit is contained in:
2025-10-11 22:00:25 +05:30
parent abbf17748d
commit 75d3ad4fe2
6 changed files with 104 additions and 78 deletions

View File

@ -1 +1,3 @@
from .import_detector import vmlinux_proc
__all__ = ["vmlinux_proc"]

View File

@ -1,4 +1,3 @@
import ast
import logging
from functools import lru_cache
import importlib
@ -20,9 +19,9 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
symbols_in_module, imported_module = get_module_symbols("vmlinux")
# Handle both node objects and type objects
if hasattr(node, 'name'):
if hasattr(node, "name"):
current_symbol_name = node.name
elif hasattr(node, '__name__'):
elif hasattr(node, "__name__"):
current_symbol_name = node.__name__
else:
current_symbol_name = str(node)
@ -30,7 +29,9 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
if current_symbol_name not in symbols_in_module:
raise ImportError(f"{current_symbol_name} not present in module vmlinux")
logger.info(f"Resolving vmlinux class {current_symbol_name}")
logger.debug(f"Current handler state: {handler.is_ready} readiness and {handler.get_all_nodes()} all nodes")
logger.debug(
f"Current handler state: {handler.is_ready} readiness and {handler.get_all_nodes()} all nodes"
)
field_table = {} # should contain the field and it's type.
# Get the class object from the module
@ -42,12 +43,12 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
# Inspect the class fields
# Assuming class_obj has fields stored in some standard way
# If it's a ctypes-like structure with _fields_
if hasattr(class_obj, '_fields_'):
if hasattr(class_obj, "_fields_"):
for field_name, field_type in class_obj._fields_:
field_table[field_name] = field_type
# If it's using __annotations__
elif hasattr(class_obj, '__annotations__'):
elif hasattr(class_obj, "__annotations__"):
for field_name, field_type in class_obj.__annotations__.items():
field_table[field_name] = field_type
@ -69,17 +70,24 @@ def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
print("elem_name:", elem_name, "elem_type:", elem_type)
# currently fails when a non-normal type appears which is basically everytime
identify_ctypes_type(elem_type)
symbol_name = elem_type.__name__ if hasattr(elem_type, '__name__') else str(elem_type)
symbol_name = (
elem_type.__name__
if hasattr(elem_type, "__name__")
else str(elem_type)
)
vmlinux_symbol = getattr(imported_module, symbol_name)
if process_vmlinux_class(vmlinux_symbol, llvm_module, handler):
new_dep_node.set_field_ready(elem_name, True)
else:
raise ValueError(f"{elem_name} with type {elem_type} not supported in recursive resolver")
raise ValueError(
f"{elem_name} with type {elem_type} not supported in recursive resolver"
)
handler.add_node(new_dep_node)
logger.info(f"added node: {current_symbol_name}")
return True
def identify_ctypes_type(t):
if isinstance(t, type): # t is a type/class
if issubclass(t, ctypes.Array):

View File

@ -5,6 +5,7 @@ from typing import Dict, Any, Optional
@dataclass
class Field:
"""Represents a field in a dependency node with its type and readiness state."""
name: str
type: type
value: Any = None
@ -64,13 +65,22 @@ class DependencyNode:
ready_fields = somestruct.get_ready_fields()
print(f"Ready fields: {[field.name for field in ready_fields.values()]}") # ['field_1', 'field_2']
"""
name: str
fields: Dict[str, Field] = field(default_factory=dict)
_ready_cache: Optional[bool] = field(default=None, repr=False)
def add_field(self, name: str, field_type: type, initial_value: Any = None, ready: bool = False) -> None:
def add_field(
self,
name: str,
field_type: type,
initial_value: Any = None,
ready: bool = False,
) -> None:
"""Add a field to the node with an optional initial value and readiness state."""
self.fields[name] = Field(name=name, type=field_type, value=initial_value, ready=ready)
self.fields[name] = Field(
name=name, type=field_type, value=initial_value, ready=ready
)
# Invalidate readiness cache
self._ready_cache = None

View File

@ -6,7 +6,7 @@ import inspect
from .dependency_handler import DependencyHandler
from .ir_generation import IRGenerator
from .vmlinux_class_handler import process_vmlinux_class
from .class_handler import process_vmlinux_class
logger = logging.getLogger(__name__)
@ -58,8 +58,8 @@ def detect_import_statement(tree: ast.AST) -> List[Tuple[str, ast.ImportFrom]]:
# Valid single import
for alias in node.names:
import_name = alias.name
# Use alias if provided, otherwise use the original name
as_name = alias.asname if alias.asname else alias.name
# Use alias if provided, otherwise use the original name (commented)
# as_name = alias.asname if alias.asname else alias.name
vmlinux_imports.append(("vmlinux", node))
logger.info(f"Found vmlinux import: {import_name}")
@ -68,13 +68,14 @@ def detect_import_statement(tree: ast.AST) -> List[Tuple[str, ast.ImportFrom]]:
for alias in node.names:
if alias.name == "vmlinux" or alias.name.startswith("vmlinux."):
raise SyntaxError(
f"Direct import of vmlinux module is not supported. "
f"Use 'from vmlinux import <type>' instead."
"Direct import of vmlinux module is not supported. "
"Use 'from vmlinux import <type>' instead."
)
logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}")
return vmlinux_imports
def vmlinux_proc(tree: ast.AST, module):
import_statements = detect_import_statement(tree)
@ -107,7 +108,10 @@ def vmlinux_proc(tree: ast.AST, module):
imported_name = alias.name
found = False
for mod_node in mod_ast.body:
if isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name:
if (
isinstance(mod_node, ast.ClassDef)
and mod_node.name == imported_name
):
process_vmlinux_class(mod_node, module, handler)
found = True
break
@ -120,9 +124,12 @@ def vmlinux_proc(tree: ast.AST, module):
if found:
break
if not found:
logger.info(f"{imported_name} not found as ClassDef or Assign in vmlinux")
logger.info(
f"{imported_name} not found as ClassDef or Assign in vmlinux"
)
IRGenerator(module, handler)
def process_vmlinux_assign(node, module, assignments: Dict[str, type]):
raise NotImplementedError("Assignment handling has not been implemented yet")