PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs

This commit is contained in:
Pragyansh Chaturvedi
2026-02-21 18:59:33 +05:30
parent 45d85c416f
commit ec4a6852ec
14 changed files with 455 additions and 497 deletions

View File

@ -4,7 +4,6 @@ import logging
from pythonbpf.helper import (
HelperHandlerRegistry,
reset_scratch_pool,
)
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.expr import (
@ -76,36 +75,30 @@ def count_temps_in_call(call_node, local_sym_tab):
def handle_if_allocation(
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
compilation_context, builder, stmt, func, ret_type, local_sym_tab
):
"""Recursively handle allocations in if/else branches."""
if stmt.body:
allocate_mem(
module,
compilation_context,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
allocate_mem(
module,
compilation_context,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
def allocate_mem(compilation_context, builder, body, func, ret_type, local_sym_tab):
max_temps_needed = {}
def merge_type_counts(count_dict):
@ -137,19 +130,15 @@ def allocate_mem(
# Handle allocations
if isinstance(stmt, ast.If):
handle_if_allocation(
module,
compilation_context,
builder,
stmt,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign_allocation(
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
)
handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
@ -161,9 +150,7 @@ def allocate_mem(
# ============================================================================
def handle_assign(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
):
def handle_assign(func, compilation_context, builder, stmt, local_sym_tab):
"""Handle assignment statements in the function body."""
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
@ -175,13 +162,11 @@ def handle_assign(
var_name = target.id
result = handle_variable_assignment(
func,
module,
compilation_context,
builder,
var_name,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if not result:
logger.error(f"Failed to handle assignment to {var_name}")
@ -191,13 +176,11 @@ def handle_assign(
# NOTE: Struct field assignment case: pkt.field = value
handle_struct_field_assignment(
func,
module,
compilation_context,
builder,
target,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
continue
@ -205,18 +188,12 @@ def handle_assign(
logger.error(f"Unsupported assignment target: {ast.dump(target)}")
def handle_cond(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
val = eval_expr(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab
)[0]
def handle_cond(func, compilation_context, builder, cond, local_sym_tab):
val = eval_expr(func, compilation_context, builder, cond, local_sym_tab)[0]
return convert_to_bool(builder, val)
def handle_if(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None
):
def handle_if(func, compilation_context, builder, stmt, local_sym_tab):
"""Handle if statements in the function body."""
logger.info("Handling if statement")
# start = builder.block.parent
@ -227,9 +204,7 @@ def handle_if(
else:
else_block = None
cond = handle_cond(
func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab
)
cond = handle_cond(func, compilation_context, builder, stmt.test, local_sym_tab)
if else_block:
builder.cbranch(cond, then_block, else_block)
else:
@ -237,9 +212,7 @@ def handle_if(
builder.position_at_end(then_block)
for s in stmt.body:
process_stmt(
func, module, builder, s, local_sym_tab, map_sym_tab, structs_sym_tab, False
)
process_stmt(func, compilation_context, builder, s, local_sym_tab, False)
if not builder.block.is_terminated:
builder.branch(merge_block)
@ -248,12 +221,10 @@ def handle_if(
for s in stmt.orelse:
process_stmt(
func,
module,
compilation_context,
builder,
s,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
False,
)
if not builder.block.is_terminated:
@ -262,21 +233,25 @@ def handle_if(
builder.position_at_end(merge_block)
def handle_return(builder, stmt, local_sym_tab, ret_type):
def handle_return(builder, stmt, local_sym_tab, ret_type, compilation_context=None):
logger.info(f"Handling return statement: {ast.dump(stmt)}")
if stmt.value is None:
return handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
return handle_xdp_return(stmt, builder, ret_type)
else:
# Fallback for now if ctx not passed, but caller should pass it
if compilation_context is None:
raise RuntimeError(
"CompilationContext required for return statement evaluation"
)
val = eval_expr(
func=None,
module=None,
compilation_context=compilation_context,
builder=builder,
expr=stmt.value,
local_sym_tab=local_sym_tab,
map_sym_tab={},
structs_sym_tab={},
)
logger.info(f"Evaluated return expression to {val}")
builder.ret(val[0])
@ -285,43 +260,34 @@ def handle_return(builder, stmt, local_sym_tab, ret_type):
def process_stmt(
func,
module,
compilation_context,
builder,
stmt,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
did_return,
ret_type=ir.IntType(64),
):
logger.info(f"Processing statement: {ast.dump(stmt)}")
reset_scratch_pool()
# Use context scratch pool
compilation_context.scratch_pool.reset()
if isinstance(stmt, ast.Expr):
handle_expr(
func,
module,
compilation_context,
builder,
stmt,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
handle_assign(func, compilation_context, builder, stmt, local_sym_tab)
elif isinstance(stmt, ast.AugAssign):
raise SyntaxError("Augmented assignment not supported")
elif isinstance(stmt, ast.If):
handle_if(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
handle_if(func, compilation_context, builder, stmt, local_sym_tab)
elif isinstance(stmt, ast.Return):
did_return = handle_return(
builder,
stmt,
local_sym_tab,
ret_type,
builder, stmt, local_sym_tab, ret_type, compilation_context
)
return did_return
@ -332,13 +298,11 @@ def process_stmt(
def process_func_body(
module,
compilation_context,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
):
"""Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now
@ -360,6 +324,9 @@ def process_func_body(
raise TypeError(
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
)
# Use context's handler if available, else usage of VmlinuxHandlerRegistry
# For now relying on VmlinuxHandlerRegistry which relies on codegen setting it
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
resolved_type = VmlinuxHandlerRegistry.get_struct_type(
context_type_name
@ -370,14 +337,12 @@ def process_func_body(
# pre-allocate dynamic variables
local_sym_tab = allocate_mem(
module,
compilation_context,
builder,
func_node.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
logger.info(f"Local symbol table: {local_sym_tab.keys()}")
@ -385,12 +350,10 @@ def process_func_body(
for stmt in func_node.body:
did_return = process_stmt(
func,
module,
compilation_context,
builder,
stmt,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
did_return,
ret_type,
)
@ -399,9 +362,12 @@ def process_func_body(
builder.ret(ir.Constant(ir.IntType(64), 0))
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):
def process_bpf_chunk(func_node, compilation_context, return_type):
"""Process a single BPF chunk (function) and emit corresponding LLVM IR."""
# Set current function in context (optional but good for future)
compilation_context.current_func = func_node
func_name = func_node.name
ret_type = return_type
@ -413,7 +379,7 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
param_types.append(ir.PointerType())
func_ty = ir.FunctionType(ret_type, param_types)
func = ir.Function(module, func_ty, func_name)
func = ir.Function(compilation_context.module, func_ty, func_name)
func.linkage = "dso_local"
func.attributes.add("nounwind")
@ -433,13 +399,11 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
builder = ir.IRBuilder(block)
process_func_body(
module,
compilation_context,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
)
return func
@ -449,23 +413,32 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
# ============================================================================
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
def func_proc(tree, compilation_context, chunks):
"""Process all functions decorated with @bpf and @bpfglobal"""
for func_node in chunks:
# Ignore structs and maps
# Check against the lists
if (
func_node.name in compilation_context.structs_sym_tab
or func_node.name in compilation_context.map_sym_tab
):
continue
# Also check decorators to be sure
decorators = [d.id for d in func_node.decorator_list if isinstance(d, ast.Name)]
if "struct" in decorators or "map" in decorators:
continue
if is_global_function(func_node):
continue
func_type = get_probe_string(func_node)
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
func = process_bpf_chunk(
func_node,
module,
ctypes_to_ir(infer_return_type(func_node)),
map_sym_tab,
structs_sym_tab,
)
return_type = ctypes_to_ir(infer_return_type(func_node))
func = process_bpf_chunk(func_node, compilation_context, return_type)
logger.info(f"Generating Debug Info for Function {func_node.name}")
generate_function_debug_info(func_node, module, func)
generate_function_debug_info(func_node, compilation_context.module, func)
# TODO: WIP, for string assignment to fixed-size arrays