PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs

This commit is contained in:
Pragyansh Chaturvedi
2026-02-21 18:59:33 +05:30
parent 45d85c416f
commit ec4a6852ec
14 changed files with 455 additions and 497 deletions

View File

@ -37,7 +37,7 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
raise SyntaxError(f"Undefined variable {expr.id}")
def _handle_constant_expr(module, builder, expr: ast.Constant):
def _handle_constant_expr(compilation_context, builder, expr: ast.Constant):
"""Handle ast.Constant expressions."""
if isinstance(expr.value, int) or isinstance(expr.value, bool):
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
@ -48,7 +48,9 @@ def _handle_constant_expr(module, builder, expr: ast.Constant):
str_constant = ir.Constant(str_type, bytearray(str_bytes))
# Create global variable
global_str = ir.GlobalVariable(module, str_type, name=str_name)
global_str = ir.GlobalVariable(
compilation_context.module, str_type, name=str_name
)
global_str.linkage = "internal"
global_str.global_constant = True
global_str.initializer = str_constant
@ -64,10 +66,11 @@ def _handle_attribute_expr(
func,
expr: ast.Attribute,
local_sym_tab: Dict,
structs_sym_tab: Dict,
compilation_context,
builder: ir.IRBuilder,
):
"""Handle ast.Attribute expressions for struct field access."""
structs_sym_tab = compilation_context.structs_sym_tab
if isinstance(expr.value, ast.Name):
var_name = expr.value.id
attr_name = expr.attr
@ -157,9 +160,7 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
# ============================================================================
def get_operand_value(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
def get_operand_value(func, compilation_context, operand, builder, local_sym_tab):
"""Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name):
@ -187,13 +188,11 @@ def get_operand_value(
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
res = _handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, operand, builder, local_sym_tab
)
return res
else:
res = eval_expr(
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
res = eval_expr(func, compilation_context, builder, operand, local_sym_tab)
if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res
@ -205,15 +204,13 @@ def get_operand_value(
raise TypeError(f"Unsupported operand type: {type(operand)}")
def _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
def _handle_binary_op_impl(func, compilation_context, rval, builder, local_sym_tab):
op = rval.op
left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, rval.left, builder, local_sym_tab
)
right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, rval.right, builder, local_sym_tab
)
logger.info(f"left is {left}, right is {right}, op is {op}")
@ -249,16 +246,14 @@ def _handle_binary_op_impl(
def _handle_binary_op(
func,
module,
compilation_context,
rval,
builder,
var_name,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
result = _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, rval, builder, local_sym_tab
)
if var_name and var_name in local_sym_tab:
logger.info(
@ -275,12 +270,10 @@ def _handle_binary_op(
def _handle_ctypes_call(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
"""Handle ctypes type constructor calls."""
if len(expr.args) != 1:
@ -290,12 +283,10 @@ def _handle_ctypes_call(
arg = expr.args[0]
val = eval_expr(
func,
module,
compilation_context,
builder,
arg,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if val is None:
logger.info("Failed to evaluate argument to ctypes constructor")
@ -344,9 +335,7 @@ def _handle_ctypes_call(
return value, expected_type
def _handle_compare(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
def _handle_compare(func, compilation_context, builder, cond, local_sym_tab):
"""Handle ast.Compare expressions."""
if len(cond.ops) != 1 or len(cond.comparators) != 1:
@ -354,21 +343,17 @@ def _handle_compare(
return None
lhs = eval_expr(
func,
module,
compilation_context,
builder,
cond.left,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
rhs = eval_expr(
func,
module,
compilation_context,
builder,
cond.comparators[0],
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if lhs is None or rhs is None:
@ -382,12 +367,10 @@ def _handle_compare(
def _handle_unary_op(
func,
module,
compilation_context,
builder,
expr: ast.UnaryOp,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
"""Handle ast.UnaryOp expressions."""
if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub):
@ -395,7 +378,7 @@ def _handle_unary_op(
return None
operand = get_operand_value(
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, expr.operand, builder, local_sym_tab
)
if operand is None:
logger.error("Failed to evaluate operand for unary operation")
@ -418,7 +401,7 @@ def _handle_unary_op(
# ============================================================================
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
def _handle_and_op(func, builder, expr, local_sym_tab, compilation_context):
"""Handle `and` boolean operations."""
logger.debug(f"Handling 'and' operator with {len(expr.values)} operands")
@ -433,7 +416,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_
# Evaluate current operand
operand_result = eval_expr(
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, builder, value, local_sym_tab
)
if operand_result is None:
logger.error(f"Failed to evaluate operand {i} in 'and' expression")
@ -471,7 +454,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_
return phi, ir.IntType(1)
def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
def _handle_or_op(func, builder, expr, local_sym_tab, compilation_context):
"""Handle `or` boolean operations."""
logger.debug(f"Handling 'or' operator with {len(expr.values)} operands")
@ -486,7 +469,7 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t
# Evaluate current operand
operand_result = eval_expr(
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, builder, value, local_sym_tab
)
if operand_result is None:
logger.error(f"Failed to evaluate operand {i} in 'or' expression")
@ -526,23 +509,17 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t
def _handle_boolean_op(
func,
module,
compilation_context,
builder,
expr: ast.BoolOp,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
"""Handle `and` and `or` boolean operations."""
if isinstance(expr.op, ast.And):
return _handle_and_op(
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
return _handle_and_op(func, builder, expr, local_sym_tab, compilation_context)
elif isinstance(expr.op, ast.Or):
return _handle_or_op(
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
return _handle_or_op(func, builder, expr, local_sym_tab, compilation_context)
else:
logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}")
return None
@ -555,12 +532,10 @@ def _handle_boolean_op(
def _handle_vmlinux_cast(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
# handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux
# struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64
@ -576,12 +551,10 @@ def _handle_vmlinux_cast(
# Evaluate the argument (e.g., ctx.di which is a c_uint64)
arg_result = eval_expr(
func,
module,
compilation_context,
builder,
expr.args[0],
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if arg_result is None:
@ -614,18 +587,17 @@ def _handle_vmlinux_cast(
def _handle_user_defined_struct_cast(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
):
"""Handle user-defined struct cast expressions like iphdr(nh).
This casts a pointer/integer value to a pointer to the user-defined struct,
similar to how vmlinux struct casts work but for user-defined @struct types.
"""
structs_sym_tab = compilation_context.structs_sym_tab
if len(expr.args) != 1:
logger.info("User-defined struct cast takes exactly one argument")
return None
@ -643,12 +615,10 @@ def _handle_user_defined_struct_cast(
# an address/pointer value)
arg_result = eval_expr(
func,
module,
compilation_context,
builder,
expr.args[0],
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if arg_result is None:
@ -683,30 +653,28 @@ def _handle_user_defined_struct_cast(
def eval_expr(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
structs_sym_tab = compilation_context.structs_sym_tab
logger.info(f"Evaluating expression: {ast.dump(expr)}")
if isinstance(expr, ast.Name):
return _handle_name_expr(expr, local_sym_tab, builder)
elif isinstance(expr, ast.Constant):
return _handle_constant_expr(module, builder, expr)
return _handle_constant_expr(compilation_context, builder, expr)
elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
expr.func.id
):
return _handle_vmlinux_cast(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder)
@ -714,26 +682,23 @@ def eval_expr(
if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
return _handle_ctypes_call(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab):
return _handle_user_defined_struct_cast(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
# NOTE: Updated handle_call signature
result = CallHandlerRegistry.handle_call(
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
expr, compilation_context, builder, func, local_sym_tab
)
if result is not None:
return result
@ -742,30 +707,24 @@ def eval_expr(
return None
elif isinstance(expr, ast.Attribute):
return _handle_attribute_expr(
func, expr, local_sym_tab, structs_sym_tab, builder
func, expr, local_sym_tab, compilation_context, builder
)
elif isinstance(expr, ast.BinOp):
return _handle_binary_op(
func,
module,
compilation_context,
expr,
builder,
None,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr, ast.Compare):
return _handle_compare(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
return _handle_compare(func, compilation_context, builder, expr, local_sym_tab)
elif isinstance(expr, ast.UnaryOp):
return _handle_unary_op(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
return _handle_unary_op(func, compilation_context, builder, expr, local_sym_tab)
elif isinstance(expr, ast.BoolOp):
return _handle_boolean_op(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
func, compilation_context, builder, expr, local_sym_tab
)
logger.info("Unsupported expression evaluation")
return None
@ -773,12 +732,10 @@ def eval_expr(
def handle_expr(
func,
module,
compilation_context,
builder,
expr,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
):
"""Handle expression statements in the function body."""
logger.info(f"Handling expression: {ast.dump(expr)}")
@ -786,12 +743,10 @@ def handle_expr(
if isinstance(call, ast.Call):
eval_expr(
func,
module,
compilation_context,
builder,
call,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
else:
logger.info("Unsupported expression type")