5 Commits

5 changed files with 90 additions and 31 deletions

View File

@ -1,6 +1,7 @@
from enum import Enum, auto from enum import Enum, auto
from typing import Any, Dict, List, Optional, TypedDict from typing import Any, Dict, List, Optional, TypedDict
from dataclasses import dataclass from dataclasses import dataclass
import llvmlite.ir as ir
from pythonbpf.vmlinux_parser.dependency_node import Field from pythonbpf.vmlinux_parser.dependency_node import Field
@ -32,4 +33,4 @@ class AssignmentInfo(TypedDict):
# The key of the dict is the name of the field. # The key of the dict is the name of the field.
# Value is a tuple that contains the global variable representing that field # Value is a tuple that contains the global variable representing that field
# along with all the information about that field as a Field type. # along with all the information about that field as a Field type.
members: Optional[Dict[str, tuple[str, Field]]] # For structs. members: Optional[Dict[str, tuple[ir.GlobalVariable, Field]]] # For structs.

View File

@ -2,9 +2,8 @@ import logging
from functools import lru_cache from functools import lru_cache
import importlib import importlib
from .assignment_info import AssignmentInfo, AssignmentType
from .dependency_handler import DependencyHandler from .dependency_handler import DependencyHandler
from .dependency_node import DependencyNode, Field from .dependency_node import DependencyNode
import ctypes import ctypes
from typing import Optional, Any, Dict from typing import Optional, Any, Dict
@ -21,12 +20,11 @@ def process_vmlinux_class(
node, node,
llvm_module, llvm_module,
handler: DependencyHandler, handler: DependencyHandler,
assignments: dict[str, AssignmentInfo],
): ):
symbols_in_module, imported_module = get_module_symbols("vmlinux") symbols_in_module, imported_module = get_module_symbols("vmlinux")
if node.name in symbols_in_module: if node.name in symbols_in_module:
vmlinux_type = getattr(imported_module, node.name) vmlinux_type = getattr(imported_module, node.name)
process_vmlinux_post_ast(vmlinux_type, llvm_module, handler, assignments) process_vmlinux_post_ast(vmlinux_type, llvm_module, handler)
else: else:
raise ImportError(f"{node.name} not in vmlinux") raise ImportError(f"{node.name} not in vmlinux")
@ -35,7 +33,6 @@ def process_vmlinux_post_ast(
elem_type_class, elem_type_class,
llvm_handler, llvm_handler,
handler: DependencyHandler, handler: DependencyHandler,
assignments: dict[str, AssignmentInfo],
processing_stack=None, processing_stack=None,
): ):
# Initialize processing stack on first call # Initialize processing stack on first call
@ -103,9 +100,6 @@ def process_vmlinux_post_ast(
else: else:
raise TypeError("Could not get required class and definition") raise TypeError("Could not get required class and definition")
# Create a members dictionary for AssignmentInfo
members_dict: Dict[str, tuple[str, Field]] = {}
logger.debug(f"Extracted fields for {current_symbol_name}: {field_table}") logger.debug(f"Extracted fields for {current_symbol_name}: {field_table}")
for elem in field_table.items(): for elem in field_table.items():
elem_name, elem_temp_list = elem elem_name, elem_temp_list = elem
@ -113,11 +107,6 @@ def process_vmlinux_post_ast(
local_module_name = getattr(elem_type, "__module__", None) local_module_name = getattr(elem_type, "__module__", None)
new_dep_node.add_field(elem_name, elem_type, ready=False) new_dep_node.add_field(elem_name, elem_type, ready=False)
# Store field reference for struct assignment info
field_ref = new_dep_node.get_field(elem_name)
if field_ref:
members_dict[elem_name] = (elem_name, field_ref)
if local_module_name == ctypes.__name__: if local_module_name == ctypes.__name__:
# TODO: need to process pointer to ctype and also CFUNCTYPES here recursively. Current processing is a single dereference # TODO: need to process pointer to ctype and also CFUNCTYPES here recursively. Current processing is a single dereference
new_dep_node.set_field_bitfield_size(elem_name, elem_bitfield_size) new_dep_node.set_field_bitfield_size(elem_name, elem_bitfield_size)
@ -229,7 +218,6 @@ def process_vmlinux_post_ast(
containing_type, containing_type,
llvm_handler, llvm_handler,
handler, handler,
assignments, # Pass assignments to recursive call
processing_stack, processing_stack,
) )
new_dep_node.set_field_ready(elem_name, True) new_dep_node.set_field_ready(elem_name, True)
@ -250,7 +238,6 @@ def process_vmlinux_post_ast(
elem_type, elem_type,
llvm_handler, llvm_handler,
handler, handler,
assignments,
processing_stack, processing_stack,
) )
new_dep_node.set_field_ready(elem_name, True) new_dep_node.set_field_ready(elem_name, True)
@ -259,17 +246,6 @@ def process_vmlinux_post_ast(
f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver" f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver"
) )
# Add struct to assignments dictionary
assignments[current_symbol_name] = AssignmentInfo(
value_type=AssignmentType.STRUCT,
python_type=elem_type_class,
value=None,
pointer_level=None,
signature=None,
members=members_dict,
)
logger.info(f"Added struct assignment info for {current_symbol_name}")
else: else:
raise ImportError("UNSUPPORTED Module") raise ImportError("UNSUPPORTED Module")

View File

@ -18,6 +18,31 @@ class Field:
value: Any = None value: Any = None
ready: bool = False ready: bool = False
def __hash__(self):
"""
Create a hash based on the immutable attributes that define this field's identity.
This allows Field objects to be used as dictionary keys.
"""
# Use a tuple of the fields that uniquely identify this field
identity = (
self.name,
id(self.type), # Use id for non-hashable types
id(self.ctype_complex_type) if self.ctype_complex_type else None,
id(self.containing_type) if self.containing_type else None,
self.type_size,
self.bitfield_size,
self.offset,
self.value if self.value else None,
)
return hash(identity)
def __eq__(self, other):
"""
Define equality consistent with the hash function.
Two fields are equal if they have they are the same
"""
return self is other
def set_ready(self, is_ready: bool = True) -> None: def set_ready(self, is_ready: bool = True) -> None:
"""Set the readiness state of this field.""" """Set the readiness state of this field."""
self.ready = is_ready self.ready = is_ready

View File

@ -112,7 +112,7 @@ def vmlinux_proc(tree: ast.AST, module):
isinstance(mod_node, ast.ClassDef) isinstance(mod_node, ast.ClassDef)
and mod_node.name == imported_name and mod_node.name == imported_name
): ):
process_vmlinux_class(mod_node, module, handler, assignments) process_vmlinux_class(mod_node, module, handler)
found = True found = True
break break
if isinstance(mod_node, ast.Assign): if isinstance(mod_node, ast.Assign):
@ -128,7 +128,7 @@ def vmlinux_proc(tree: ast.AST, module):
f"{imported_name} not found as ClassDef or Assign in vmlinux" f"{imported_name} not found as ClassDef or Assign in vmlinux"
) )
IRGenerator(module, handler) IRGenerator(module, handler, assignments)
return assignments return assignments

View File

@ -1,5 +1,7 @@
import ctypes import ctypes
import logging import logging
from ..assignment_info import AssignmentInfo, AssignmentType
from ..dependency_handler import DependencyHandler from ..dependency_handler import DependencyHandler
from .debug_info_gen import debug_info_generation from .debug_info_gen import debug_info_generation
from ..dependency_node import DependencyNode from ..dependency_node import DependencyNode
@ -10,11 +12,14 @@ logger = logging.getLogger(__name__)
class IRGenerator: class IRGenerator:
# get the assignments dict and add this stuff to it. # get the assignments dict and add this stuff to it.
def __init__(self, llvm_module, handler: DependencyHandler, assignment=None): def __init__(self, llvm_module, handler: DependencyHandler, assignments):
self.llvm_module = llvm_module self.llvm_module = llvm_module
self.handler: DependencyHandler = handler self.handler: DependencyHandler = handler
self.generated: list[str] = [] self.generated: list[str] = []
self.generated_debug_info: list = [] self.generated_debug_info: list = []
# Use struct_name and field_name as key instead of Field object
self.generated_field_names: dict[str, dict[str, ir.GlobalVariable]] = {}
self.assignments: dict[str, AssignmentInfo] = assignments
if not handler.is_ready: if not handler.is_ready:
raise ImportError( raise ImportError(
"Semantic analysis of vmlinux imports failed. Cannot generate IR" "Semantic analysis of vmlinux imports failed. Cannot generate IR"
@ -67,10 +72,42 @@ class IRGenerator:
f"Warning: Dependency {dependency} not found in handler" f"Warning: Dependency {dependency} not found in handler"
) )
# Actual processor logic here after dependencies are resolved # Generate IR first to populate field names
self.generated_debug_info.append( self.generated_debug_info.append(
(struct, self.gen_ir(struct, self.generated_debug_info)) (struct, self.gen_ir(struct, self.generated_debug_info))
) )
# Fill the assignments dictionary with struct information
if struct.name not in self.assignments:
# Create a members dictionary for AssignmentInfo
members_dict = {}
for field_name, field in struct.fields.items():
# Get the generated field name from our dictionary, or use field_name if not found
if (
struct.name in self.generated_field_names
and field_name in self.generated_field_names[struct.name]
):
field_global_variable = self.generated_field_names[struct.name][
field_name
]
members_dict[field_name] = (field_global_variable, field)
else:
raise ValueError(
f"llvm global name not found for struct field {field_name}"
)
# members_dict[field_name] = (field_name, field)
# Add struct to assignments dictionary
self.assignments[struct.name] = AssignmentInfo(
value_type=AssignmentType.STRUCT,
python_type=struct.ctype_struct,
value=None,
pointer_level=None,
signature=None,
members=members_dict,
)
logger.info(f"Added struct assignment info for {struct.name}")
self.generated.append(struct.name) self.generated.append(struct.name)
finally: finally:
@ -85,6 +122,11 @@ class IRGenerator:
struct, self.llvm_module, generated_debug_info struct, self.llvm_module, generated_debug_info
) )
field_index = 0 field_index = 0
# Make sure the struct has an entry in our field names dictionary
if struct.name not in self.generated_field_names:
self.generated_field_names[struct.name] = {}
for field_name, field in struct.fields.items(): for field_name, field in struct.fields.items():
# does not take arrays and similar types into consideration yet. # does not take arrays and similar types into consideration yet.
if field.ctype_complex_type is not None and issubclass( if field.ctype_complex_type is not None and issubclass(
@ -94,6 +136,18 @@ class IRGenerator:
containing_type = field.containing_type containing_type = field.containing_type
if containing_type.__module__ == ctypes.__name__: if containing_type.__module__ == ctypes.__name__:
containing_type_size = ctypes.sizeof(containing_type) containing_type_size = ctypes.sizeof(containing_type)
if array_size == 0:
field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, 0, containing_type_size
)
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1
continue
for i in range(0, array_size): for i in range(0, array_size):
field_co_re_name = self._struct_name_generator( field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, i, containing_type_size struct, field, field_index, True, i, containing_type_size
@ -103,6 +157,7 @@ class IRGenerator:
) )
globvar.linkage = "external" globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info) globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1 field_index += 1
elif field.type_size is not None: elif field.type_size is not None:
array_size = field.type_size array_size = field.type_size
@ -120,6 +175,7 @@ class IRGenerator:
) )
globvar.linkage = "external" globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info) globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1 field_index += 1
else: else:
field_co_re_name = self._struct_name_generator( field_co_re_name = self._struct_name_generator(
@ -131,6 +187,7 @@ class IRGenerator:
) )
globvar.linkage = "external" globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info) globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
return debug_info return debug_info
def _struct_name_generator( def _struct_name_generator(