diff --git a/pythonbpf/vmlinux_parser/import_detector.py b/pythonbpf/vmlinux_parser/import_detector.py index d90c478..48fe403 100644 --- a/pythonbpf/vmlinux_parser/import_detector.py +++ b/pythonbpf/vmlinux_parser/import_detector.py @@ -25,7 +25,7 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]: List of tuples containing (module_name, imported_item) for each vmlinux import Raises: - SyntaxError: If multiple imports from vmlinux are attempted or import * is used + SyntaxError: If import * is used """ vmlinux_imports = [] @@ -40,28 +40,19 @@ def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]: "Please import specific types explicitly." ) - # Check for multiple imports: from vmlinux import A, B, C - if len(node.names) > 1: - imported_names = [alias.name for alias in node.names] - raise SyntaxError( - f"Multiple imports from vmlinux are not supported. " - f"Found: {', '.join(imported_names)}. " - f"Please use separate import statements for each type." - ) - # Check if no specific import is specified (should not happen with valid Python) if len(node.names) == 0: raise SyntaxError( "Import from vmlinux must specify at least one type." ) - # Valid single import + # Support multiple imports: from vmlinux import A, B, C for alias in node.names: import_name = alias.name - # Use alias if provided, otherwise use the original name (commented) - # as_name = alias.asname if alias.asname else alias.name - vmlinux_imports.append(("vmlinux", node)) - logger.info(f"Found vmlinux import: {import_name}") + # Use alias if provided, otherwise use the original name + as_name = alias.asname if alias.asname else alias.name + vmlinux_imports.append(("vmlinux", node, import_name, as_name)) + logger.info(f"Found vmlinux import: {import_name} as {as_name}") # Handle "import vmlinux" statements (not typical but should be rejected) elif isinstance(node, ast.Import): @@ -103,40 +94,37 @@ def vmlinux_proc(tree: ast.AST, module): with open(source_file, "r") as f: mod_ast = ast.parse(f.read(), filename=source_file) - for import_mod, import_node in import_statements: - for alias in import_node.names: - imported_name = alias.name - found = False - for mod_node in mod_ast.body: - if ( - isinstance(mod_node, ast.ClassDef) - and mod_node.name == imported_name - ): - process_vmlinux_class(mod_node, module, handler) - found = True - break - if isinstance(mod_node, ast.Assign): - for target in mod_node.targets: - if isinstance(target, ast.Name) and target.id == imported_name: - process_vmlinux_assign(mod_node, module, assignments) - found = True - break - if found: - break - if not found: - logger.info( - f"{imported_name} not found as ClassDef or Assign in vmlinux" - ) + for import_mod, import_node, imported_name, as_name in import_statements: + found = False + for mod_node in mod_ast.body: + if isinstance(mod_node, ast.ClassDef) and mod_node.name == imported_name: + process_vmlinux_class(mod_node, module, handler) + found = True + break + if isinstance(mod_node, ast.Assign): + for target in mod_node.targets: + if isinstance(target, ast.Name) and target.id == imported_name: + process_vmlinux_assign(mod_node, module, assignments, as_name) + found = True + break + if found: + break + if not found: + logger.info(f"{imported_name} not found as ClassDef or Assign in vmlinux") IRGenerator(module, handler, assignments) return assignments -def process_vmlinux_assign(node, module, assignments: dict[str, AssignmentInfo]): +def process_vmlinux_assign( + node, module, assignments: dict[str, AssignmentInfo], target_name=None +): """Process assignments from vmlinux module.""" # Only handle single-target assignments if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): - target_name = node.targets[0].id + # Use provided target_name (for aliased imports) or fall back to original name + if target_name is None: + target_name = node.targets[0].id # Handle constant value assignments if isinstance(node.value, ast.Constant): diff --git a/tests/failing_tests/vmlinux/requests.py b/tests/failing_tests/vmlinux/requests.py new file mode 100644 index 0000000..f19256b --- /dev/null +++ b/tests/failing_tests/vmlinux/requests.py @@ -0,0 +1,21 @@ +from vmlinux import struct_request, struct_pt_regs, XDP_PASS +from pythonbpf import bpf, section, bpfglobal, compile_to_ir +import logging + + +@bpf +@section("kprobe/blk_mq_start_request") +def example(ctx: struct_pt_regs): + req = struct_request(ctx.di) + c = req.__data_len + d = XDP_PASS + print(f"data length {c} and test {d}") + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("requests.py", "requests.ll", loglevel=logging.INFO)