From 1a0e21eaa8cbcf34f521f15af57990e02e5fac11 Mon Sep 17 00:00:00 2001 From: varun-r-mallya Date: Tue, 21 Oct 2025 04:50:34 +0530 Subject: [PATCH] support vmlinux enum in map arguments --- pythonbpf/expr/expr_pass.py | 1 + pythonbpf/maps/maps_pass.py | 11 ++++++++-- .../vmlinux_parser/vmlinux_exports_handler.py | 8 +++++++ .../vmlinux/simple_struct_test.py | 21 +++++++++++++++---- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 5e1163a..2a7cd5f 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -349,6 +349,7 @@ def _handle_unary_op( neg_one = ir.Constant(ir.IntType(64), -1) result = builder.mul(operand, neg_one) return result, ir.IntType(64) + return None # ============================================================================ diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index 8459848..85837d7 100644 --- a/pythonbpf/maps/maps_pass.py +++ b/pythonbpf/maps/maps_pass.py @@ -6,6 +6,8 @@ from llvmlite import ir from .maps_utils import MapProcessorRegistry from .map_types import BPFMapType from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info +from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry + logger: Logger = logging.getLogger(__name__) @@ -51,7 +53,7 @@ def _parse_map_params(rval, expected_args=None): """Parse map parameters from call arguments and keywords.""" params = {} - + handler = VmlinuxHandlerRegistry.get_handler() # Parse positional arguments if expected_args: for i, arg_name in enumerate(expected_args): @@ -65,7 +67,12 @@ def _parse_map_params(rval, expected_args=None): # Parse keyword arguments (override positional) for keyword in rval.keywords: if isinstance(keyword.value, ast.Name): - params[keyword.arg] = keyword.value.id + name = keyword.value.id + if handler and handler.is_vmlinux_enum(name): + result = handler.get_vmlinux_enum_value(name) + params[keyword.arg] = result if result is not None else name + else: + params[keyword.arg] = name elif isinstance(keyword.value, ast.Constant): params[keyword.arg] = keyword.value.value diff --git a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py index f821520..1986b44 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py @@ -54,6 +54,14 @@ class VmlinuxHandler: return ir.Constant(ir.IntType(64), value), ir.IntType(64) return None + def get_vmlinux_enum_value(self, name): + """Handle vmlinux enum constants by returning LLVM IR constants""" + if self.is_vmlinux_enum(name): + value = self.vmlinux_symtab[name]["value"] + logger.info(f"The value of vmlinux enum {name} = {value}") + return value + return None + def handle_vmlinux_struct(self, struct_name, module, builder): """Handle vmlinux struct initializations""" if self.is_vmlinux_struct(struct_name): diff --git a/tests/passing_tests/vmlinux/simple_struct_test.py b/tests/passing_tests/vmlinux/simple_struct_test.py index 9c6d272..97ab54a 100644 --- a/tests/passing_tests/vmlinux/simple_struct_test.py +++ b/tests/passing_tests/vmlinux/simple_struct_test.py @@ -1,13 +1,26 @@ import logging -from pythonbpf import bpf, section, bpfglobal, compile_to_ir +from pythonbpf import bpf, section, bpfglobal, compile_to_ir, map from pythonbpf import compile # noqa: F401 from vmlinux import TASK_COMM_LEN # noqa: F401 from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 +from ctypes import c_uint64, c_int32, c_int64 +from pythonbpf.maps import HashMap # from vmlinux import struct_uinput_device # from vmlinux import struct_blk_integrity_iter -from ctypes import c_int64 + + +@bpf +@map +def mymap() -> HashMap: + return HashMap(key=c_int32, value=c_uint64, max_entries=TASK_COMM_LEN) + + +@bpf +@map +def mymap2() -> HashMap: + return HashMap(key=c_int32, value=c_uint64, max_entries=18) # Instructions to how to run this program @@ -21,7 +34,7 @@ from ctypes import c_int64 def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: a = 2 + TASK_COMM_LEN + TASK_COMM_LEN print(f"Hello, World{TASK_COMM_LEN} and {a}") - return c_int64(TASK_COMM_LEN) + return c_int64(TASK_COMM_LEN + 2) @bpf @@ -31,4 +44,4 @@ def LICENSE() -> str: compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG) -compile() +# compile()