mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2026-03-17 18:51:28 +00:00
PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs
This commit is contained in:
@ -4,7 +4,6 @@ import logging
|
||||
|
||||
from pythonbpf.helper import (
|
||||
HelperHandlerRegistry,
|
||||
reset_scratch_pool,
|
||||
)
|
||||
from pythonbpf.type_deducer import ctypes_to_ir
|
||||
from pythonbpf.expr import (
|
||||
@ -76,36 +75,30 @@ def count_temps_in_call(call_node, local_sym_tab):
|
||||
|
||||
|
||||
def handle_if_allocation(
|
||||
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
compilation_context, builder, stmt, func, ret_type, local_sym_tab
|
||||
):
|
||||
"""Recursively handle allocations in if/else branches."""
|
||||
if stmt.body:
|
||||
allocate_mem(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt.body,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if stmt.orelse:
|
||||
allocate_mem(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt.orelse,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
|
||||
def allocate_mem(
|
||||
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
):
|
||||
def allocate_mem(compilation_context, builder, body, func, ret_type, local_sym_tab):
|
||||
max_temps_needed = {}
|
||||
|
||||
def merge_type_counts(count_dict):
|
||||
@ -137,19 +130,15 @@ def allocate_mem(
|
||||
# Handle allocations
|
||||
if isinstance(stmt, ast.If):
|
||||
handle_if_allocation(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
handle_assign_allocation(
|
||||
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab)
|
||||
|
||||
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
|
||||
|
||||
@ -161,9 +150,7 @@ def allocate_mem(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def handle_assign(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
):
|
||||
def handle_assign(func, compilation_context, builder, stmt, local_sym_tab):
|
||||
"""Handle assignment statements in the function body."""
|
||||
|
||||
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
|
||||
@ -175,13 +162,11 @@ def handle_assign(
|
||||
var_name = target.id
|
||||
result = handle_variable_assignment(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
var_name,
|
||||
rval,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if not result:
|
||||
logger.error(f"Failed to handle assignment to {var_name}")
|
||||
@ -191,13 +176,11 @@ def handle_assign(
|
||||
# NOTE: Struct field assignment case: pkt.field = value
|
||||
handle_struct_field_assignment(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
target,
|
||||
rval,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
continue
|
||||
|
||||
@ -205,18 +188,12 @@ def handle_assign(
|
||||
logger.error(f"Unsupported assignment target: {ast.dump(target)}")
|
||||
|
||||
|
||||
def handle_cond(
|
||||
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
val = eval_expr(
|
||||
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)[0]
|
||||
def handle_cond(func, compilation_context, builder, cond, local_sym_tab):
|
||||
val = eval_expr(func, compilation_context, builder, cond, local_sym_tab)[0]
|
||||
return convert_to_bool(builder, val)
|
||||
|
||||
|
||||
def handle_if(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
def handle_if(func, compilation_context, builder, stmt, local_sym_tab):
|
||||
"""Handle if statements in the function body."""
|
||||
logger.info("Handling if statement")
|
||||
# start = builder.block.parent
|
||||
@ -227,9 +204,7 @@ def handle_if(
|
||||
else:
|
||||
else_block = None
|
||||
|
||||
cond = handle_cond(
|
||||
func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
cond = handle_cond(func, compilation_context, builder, stmt.test, local_sym_tab)
|
||||
if else_block:
|
||||
builder.cbranch(cond, then_block, else_block)
|
||||
else:
|
||||
@ -237,9 +212,7 @@ def handle_if(
|
||||
|
||||
builder.position_at_end(then_block)
|
||||
for s in stmt.body:
|
||||
process_stmt(
|
||||
func, module, builder, s, local_sym_tab, map_sym_tab, structs_sym_tab, False
|
||||
)
|
||||
process_stmt(func, compilation_context, builder, s, local_sym_tab, False)
|
||||
if not builder.block.is_terminated:
|
||||
builder.branch(merge_block)
|
||||
|
||||
@ -248,12 +221,10 @@ def handle_if(
|
||||
for s in stmt.orelse:
|
||||
process_stmt(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
s,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
False,
|
||||
)
|
||||
if not builder.block.is_terminated:
|
||||
@ -262,21 +233,25 @@ def handle_if(
|
||||
builder.position_at_end(merge_block)
|
||||
|
||||
|
||||
def handle_return(builder, stmt, local_sym_tab, ret_type):
|
||||
def handle_return(builder, stmt, local_sym_tab, ret_type, compilation_context=None):
|
||||
logger.info(f"Handling return statement: {ast.dump(stmt)}")
|
||||
if stmt.value is None:
|
||||
return handle_none_return(builder)
|
||||
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
|
||||
return handle_xdp_return(stmt, builder, ret_type)
|
||||
else:
|
||||
# Fallback for now if ctx not passed, but caller should pass it
|
||||
if compilation_context is None:
|
||||
raise RuntimeError(
|
||||
"CompilationContext required for return statement evaluation"
|
||||
)
|
||||
|
||||
val = eval_expr(
|
||||
func=None,
|
||||
module=None,
|
||||
compilation_context=compilation_context,
|
||||
builder=builder,
|
||||
expr=stmt.value,
|
||||
local_sym_tab=local_sym_tab,
|
||||
map_sym_tab={},
|
||||
structs_sym_tab={},
|
||||
)
|
||||
logger.info(f"Evaluated return expression to {val}")
|
||||
builder.ret(val[0])
|
||||
@ -285,43 +260,34 @@ def handle_return(builder, stmt, local_sym_tab, ret_type):
|
||||
|
||||
def process_stmt(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
did_return,
|
||||
ret_type=ir.IntType(64),
|
||||
):
|
||||
logger.info(f"Processing statement: {ast.dump(stmt)}")
|
||||
reset_scratch_pool()
|
||||
# Use context scratch pool
|
||||
compilation_context.scratch_pool.reset()
|
||||
|
||||
if isinstance(stmt, ast.Expr):
|
||||
handle_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
handle_assign(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
)
|
||||
handle_assign(func, compilation_context, builder, stmt, local_sym_tab)
|
||||
elif isinstance(stmt, ast.AugAssign):
|
||||
raise SyntaxError("Augmented assignment not supported")
|
||||
elif isinstance(stmt, ast.If):
|
||||
handle_if(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
)
|
||||
handle_if(func, compilation_context, builder, stmt, local_sym_tab)
|
||||
elif isinstance(stmt, ast.Return):
|
||||
did_return = handle_return(
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
ret_type,
|
||||
builder, stmt, local_sym_tab, ret_type, compilation_context
|
||||
)
|
||||
return did_return
|
||||
|
||||
@ -332,13 +298,11 @@ def process_stmt(
|
||||
|
||||
|
||||
def process_func_body(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func_node,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
):
|
||||
"""Process the body of a bpf function"""
|
||||
# TODO: A lot. We just have print -> bpf_trace_printk for now
|
||||
@ -360,6 +324,9 @@ def process_func_body(
|
||||
raise TypeError(
|
||||
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
|
||||
)
|
||||
|
||||
# Use context's handler if available, else usage of VmlinuxHandlerRegistry
|
||||
# For now relying on VmlinuxHandlerRegistry which relies on codegen setting it
|
||||
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
|
||||
resolved_type = VmlinuxHandlerRegistry.get_struct_type(
|
||||
context_type_name
|
||||
@ -370,14 +337,12 @@ def process_func_body(
|
||||
|
||||
# pre-allocate dynamic variables
|
||||
local_sym_tab = allocate_mem(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func_node.body,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
logger.info(f"Local symbol table: {local_sym_tab.keys()}")
|
||||
@ -385,12 +350,10 @@ def process_func_body(
|
||||
for stmt in func_node.body:
|
||||
did_return = process_stmt(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
did_return,
|
||||
ret_type,
|
||||
)
|
||||
@ -399,9 +362,12 @@ def process_func_body(
|
||||
builder.ret(ir.Constant(ir.IntType(64), 0))
|
||||
|
||||
|
||||
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):
|
||||
def process_bpf_chunk(func_node, compilation_context, return_type):
|
||||
"""Process a single BPF chunk (function) and emit corresponding LLVM IR."""
|
||||
|
||||
# Set current function in context (optional but good for future)
|
||||
compilation_context.current_func = func_node
|
||||
|
||||
func_name = func_node.name
|
||||
|
||||
ret_type = return_type
|
||||
@ -413,7 +379,7 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
||||
param_types.append(ir.PointerType())
|
||||
|
||||
func_ty = ir.FunctionType(ret_type, param_types)
|
||||
func = ir.Function(module, func_ty, func_name)
|
||||
func = ir.Function(compilation_context.module, func_ty, func_name)
|
||||
|
||||
func.linkage = "dso_local"
|
||||
func.attributes.add("nounwind")
|
||||
@ -433,13 +399,11 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
||||
builder = ir.IRBuilder(block)
|
||||
|
||||
process_func_body(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func_node,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
return func
|
||||
|
||||
@ -449,23 +413,32 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
|
||||
def func_proc(tree, compilation_context, chunks):
|
||||
"""Process all functions decorated with @bpf and @bpfglobal"""
|
||||
for func_node in chunks:
|
||||
# Ignore structs and maps
|
||||
# Check against the lists
|
||||
if (
|
||||
func_node.name in compilation_context.structs_sym_tab
|
||||
or func_node.name in compilation_context.map_sym_tab
|
||||
):
|
||||
continue
|
||||
|
||||
# Also check decorators to be sure
|
||||
decorators = [d.id for d in func_node.decorator_list if isinstance(d, ast.Name)]
|
||||
if "struct" in decorators or "map" in decorators:
|
||||
continue
|
||||
|
||||
if is_global_function(func_node):
|
||||
continue
|
||||
func_type = get_probe_string(func_node)
|
||||
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
|
||||
|
||||
func = process_bpf_chunk(
|
||||
func_node,
|
||||
module,
|
||||
ctypes_to_ir(infer_return_type(func_node)),
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
return_type = ctypes_to_ir(infer_return_type(func_node))
|
||||
func = process_bpf_chunk(func_node, compilation_context, return_type)
|
||||
|
||||
logger.info(f"Generating Debug Info for Function {func_node.name}")
|
||||
generate_function_debug_info(func_node, module, func)
|
||||
generate_function_debug_info(func_node, compilation_context.module, func)
|
||||
|
||||
|
||||
# TODO: WIP, for string assignment to fixed-size arrays
|
||||
|
||||
Reference in New Issue
Block a user