mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
add symbol resolution to import detection
This commit is contained in:
@ -1,6 +1,8 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
from .vmlinux_class_handler import process_vmlinux_class
|
from .vmlinux_class_handler import process_vmlinux_class
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -55,7 +57,7 @@ def detect_import_statement(tree: ast.AST) -> List[Tuple[str, str]]:
|
|||||||
import_name = alias.name
|
import_name = alias.name
|
||||||
# Use alias if provided, otherwise use the original name
|
# Use alias if provided, otherwise use the original name
|
||||||
as_name = alias.asname if alias.asname else alias.name
|
as_name = alias.asname if alias.asname else alias.name
|
||||||
vmlinux_imports.append(("vmlinux", import_name))
|
vmlinux_imports.append(("vmlinux", node))
|
||||||
logger.info(f"Found vmlinux import: {import_name}")
|
logger.info(f"Found vmlinux import: {import_name}")
|
||||||
|
|
||||||
# Handle "import vmlinux" statements (not typical but should be rejected)
|
# Handle "import vmlinux" statements (not typical but should be rejected)
|
||||||
@ -77,20 +79,40 @@ def vmlinux_proc(tree: ast.AST, module):
|
|||||||
logger.info("No vmlinux imports found")
|
logger.info("No vmlinux imports found")
|
||||||
return
|
return
|
||||||
|
|
||||||
vmlinux_types = set()
|
# Import vmlinux module directly
|
||||||
for module_name, imported_item in import_statements:
|
try:
|
||||||
vmlinux_types.add(imported_item)
|
vmlinux_mod = importlib.import_module("vmlinux")
|
||||||
logger.info(f"Registered vmlinux type: {imported_item}")
|
except ImportError:
|
||||||
|
logger.warning("Could not import vmlinux module")
|
||||||
|
return
|
||||||
|
|
||||||
for node in ast.walk(tree):
|
source_file = inspect.getsourcefile(vmlinux_mod)
|
||||||
if isinstance(node, ast.ClassDef):
|
if source_file is None:
|
||||||
# Check if this class uses vmlinux types
|
logger.warning("Cannot find source for vmlinux module")
|
||||||
logger.info(f"Processing ClassDef with vmlinux types: {node.name}")
|
return
|
||||||
process_vmlinux_class(node, module, vmlinux_types)
|
|
||||||
|
|
||||||
elif isinstance(node, ast.Assign):
|
with open(source_file, "r") as f:
|
||||||
logger.info(f"Processing Assign with vmlinux types")
|
mod_ast = ast.parse(f.read(), filename=source_file)
|
||||||
process_vmlinux_assign(node, module, vmlinux_types)
|
|
||||||
|
|
||||||
def process_vmlinux_assign(node, module, vmlinux_types):
|
for import_mod, import_node in import_statements:
|
||||||
|
for alias in import_node.names:
|
||||||
|
imported_name = alias.name
|
||||||
|
found = False
|
||||||
|
for mod_node in mod_ast.body:
|
||||||
|
if isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name:
|
||||||
|
process_vmlinux_class(mod_node, module)
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
if isinstance(mod_node, ast.Assign):
|
||||||
|
for target in mod_node.targets:
|
||||||
|
if isinstance(target, ast.Name) and target.id == imported_name:
|
||||||
|
process_vmlinux_assign(mod_node, module)
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
if found:
|
||||||
|
break
|
||||||
|
if not found:
|
||||||
|
logger.info(f"{imported_name} not found as ClassDef or Assign in vmlinux")
|
||||||
|
|
||||||
|
def process_vmlinux_assign(node, module):
|
||||||
raise NotImplementedError("Assignment handling has not been implemented yet")
|
raise NotImplementedError("Assignment handling has not been implemented yet")
|
||||||
|
|||||||
@ -1,8 +1,16 @@
|
|||||||
import ast
|
import ast
|
||||||
import logging
|
import logging
|
||||||
|
import importlib
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def process_vmlinux_class(node, module, vmlinux_types):
|
|
||||||
|
def get_module_symbols(module_name: str):
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
return [name for name in dir(module)]
|
||||||
|
|
||||||
|
def process_vmlinux_class(node, module):
|
||||||
# Process ClassDef nodes that use vmlinux imports
|
# Process ClassDef nodes that use vmlinux imports
|
||||||
|
symbols = get_module_symbols("vmlinux")
|
||||||
|
# print(symbols)
|
||||||
pass
|
pass
|
||||||
Reference in New Issue
Block a user