PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs

This commit is contained in:
Pragyansh Chaturvedi
2026-02-21 18:59:33 +05:30
parent 45d85c416f
commit ec4a6852ec
14 changed files with 455 additions and 497 deletions

View File

@ -12,14 +12,14 @@ from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__)
def maps_proc(tree, module, chunks, structs_sym_tab):
def maps_proc(tree, compilation_context, chunks):
"""Process all functions decorated with @map to find BPF maps"""
map_sym_tab = {}
map_sym_tab = compilation_context.map_sym_tab
for func_node in chunks:
if is_map(func_node):
logger.info(f"Found BPF map: {func_node.name}")
map_sym_tab[func_node.name] = process_bpf_map(
func_node, module, structs_sym_tab
func_node, compilation_context
)
return map_sym_tab
@ -51,11 +51,11 @@ def create_bpf_map(module, map_name, map_params):
return MapSymbol(type=map_params["type"], sym=map_global, params=map_params)
def _parse_map_params(rval, expected_args=None):
def _parse_map_params(rval, compilation_context, expected_args=None):
"""Parse map parameters from call arguments and keywords."""
params = {}
handler = VmlinuxHandlerRegistry.get_handler()
handler = compilation_context.vmlinux_handler
# Parse positional arguments
if expected_args:
for i, arg_name in enumerate(expected_args):
@ -83,12 +83,23 @@ def _get_vmlinux_enum(handler, name):
if handler and handler.is_vmlinux_enum(name):
return handler.get_vmlinux_enum_value(name)
# Fallback to VmlinuxHandlerRegistry if handler invalid
# This is for backward compatibility or if refactoring isn't complete
if (
VmlinuxHandlerRegistry.get_handler()
and VmlinuxHandlerRegistry.get_handler().is_vmlinux_enum(name)
):
return VmlinuxHandlerRegistry.get_handler().get_vmlinux_enum_value(name)
return None
@MapProcessorRegistry.register("RingBuffer")
def process_ringbuf_map(map_name, rval, module, structs_sym_tab):
def process_ringbuf_map(map_name, rval, compilation_context):
"""Process a BPF_RINGBUF map declaration"""
logger.info(f"Processing Ringbuf: {map_name}")
map_params = _parse_map_params(rval, expected_args=["max_entries"])
map_params = _parse_map_params(
rval, compilation_context, expected_args=["max_entries"]
)
map_params["type"] = BPFMapType.RINGBUF
# NOTE: constraints borrowed from https://docs.ebpf.io/linux/map-type/BPF_MAP_TYPE_RINGBUF/
@ -104,42 +115,62 @@ def process_ringbuf_map(map_name, rval, module, structs_sym_tab):
logger.info(f"Ringbuf map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params)
map_global = create_bpf_map(compilation_context.module, map_name, map_params)
create_ringbuf_debug_info(
module, map_global.sym, map_name, map_params, structs_sym_tab
compilation_context.module,
map_global.sym,
map_name,
map_params,
compilation_context.structs_sym_tab,
)
return map_global
@MapProcessorRegistry.register("HashMap")
def process_hash_map(map_name, rval, module, structs_sym_tab):
def process_hash_map(map_name, rval, compilation_context):
"""Process a BPF_HASH map declaration"""
logger.info(f"Processing HashMap: {map_name}")
map_params = _parse_map_params(rval, expected_args=["key", "value", "max_entries"])
map_params = _parse_map_params(
rval, compilation_context, expected_args=["key", "value", "max_entries"]
)
map_params["type"] = BPFMapType.HASH
logger.info(f"Map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params)
map_global = create_bpf_map(compilation_context.module, map_name, map_params)
# Generate debug info for BTF
create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab)
create_map_debug_info(
compilation_context.module,
map_global.sym,
map_name,
map_params,
compilation_context.structs_sym_tab,
)
return map_global
@MapProcessorRegistry.register("PerfEventArray")
def process_perf_event_map(map_name, rval, module, structs_sym_tab):
def process_perf_event_map(map_name, rval, compilation_context):
"""Process a BPF_PERF_EVENT_ARRAY map declaration"""
logger.info(f"Processing PerfEventArray: {map_name}")
map_params = _parse_map_params(rval, expected_args=["key_size", "value_size"])
map_params = _parse_map_params(
rval, compilation_context, expected_args=["key_size", "value_size"]
)
map_params["type"] = BPFMapType.PERF_EVENT_ARRAY
logger.info(f"Map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params)
map_global = create_bpf_map(compilation_context.module, map_name, map_params)
# Generate debug info for BTF
create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab)
create_map_debug_info(
compilation_context.module,
map_global.sym,
map_name,
map_params,
compilation_context.structs_sym_tab,
)
return map_global
def process_bpf_map(func_node, module, structs_sym_tab):
def process_bpf_map(func_node, compilation_context):
"""Process a BPF map (a function decorated with @map)"""
map_name = func_node.name
logger.info(f"Processing BPF map: {map_name}")
@ -158,9 +189,9 @@ def process_bpf_map(func_node, module, structs_sym_tab):
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
handler = MapProcessorRegistry.get_processor(rval.func.id)
if handler:
return handler(map_name, rval, module, structs_sym_tab)
return handler(map_name, rval, compilation_context)
else:
logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap")
return process_hash_map(map_name, rval, module)
return process_hash_map(map_name, rval, compilation_context)
else:
raise ValueError("Function under @map must return a map")