PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs

This commit is contained in:
Pragyansh Chaturvedi
2026-02-21 18:59:33 +05:30
parent 45d85c416f
commit ec4a6852ec
14 changed files with 455 additions and 497 deletions

View File

@ -50,12 +50,10 @@ class BPFHelperID(Enum):
def bpf_ktime_get_ns_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_ktime_get_ns helper function call.
@ -77,12 +75,10 @@ def bpf_ktime_get_ns_emitter(
def bpf_get_current_cgroup_id(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_cgroup_id helper function call.
@ -104,12 +100,10 @@ def bpf_get_current_cgroup_id(
def bpf_map_lookup_elem_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_lookup_elem helper function call.
@ -119,7 +113,7 @@ def bpf_map_lookup_elem_emitter(
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, compilation_context, call.args[0], builder, local_sym_tab
)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -147,12 +141,10 @@ def bpf_map_lookup_elem_emitter(
def bpf_printk_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""Emit LLVM IR for bpf_printk helper function call."""
if not hasattr(func, "_fmt_counter"):
@ -165,16 +157,18 @@ def bpf_printk_emitter(
if isinstance(call.args[0], ast.JoinedStr):
args = handle_fstring_print(
call.args[0],
module,
compilation_context.module,
builder,
func,
local_sym_tab,
struct_sym_tab,
compilation_context.structs_sym_tab,
)
elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str):
# TODO: We are only supporting single arguments for now.
# In case of multiple args, the first one will be taken.
args = simple_string_print(call.args[0].value, module, builder, func)
args = simple_string_print(
call.args[0].value, compilation_context.module, builder, func
)
else:
raise NotImplementedError(
"Only simple strings or f-strings are supported in bpf_printk."
@ -203,12 +197,10 @@ def bpf_printk_emitter(
def bpf_map_update_elem_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_update_elem helper function call.
@ -224,10 +216,10 @@ def bpf_map_update_elem_emitter(
flags_arg = call.args[2] if len(call.args) > 2 else None
key_ptr = get_or_create_ptr_from_arg(
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, compilation_context, key_arg, builder, local_sym_tab
)
value_ptr = get_or_create_ptr_from_arg(
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, compilation_context, value_arg, builder, local_sym_tab
)
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
@ -262,12 +254,10 @@ def bpf_map_update_elem_emitter(
def bpf_map_delete_elem_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_delete_elem helper function call.
@ -278,7 +268,7 @@ def bpf_map_delete_elem_emitter(
f"Map delete expects exactly one argument (key), got {len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, compilation_context, call.args[0], builder, local_sym_tab
)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -306,12 +296,10 @@ def bpf_map_delete_elem_emitter(
def bpf_get_current_comm_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_comm helper function call.
@ -327,7 +315,7 @@ def bpf_get_current_comm_emitter(
# Extract buffer pointer and size
buf_ptr, buf_size = get_buffer_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab
buf_arg, builder, local_sym_tab, compilation_context
)
# Validate it's a char array
@ -367,12 +355,10 @@ def bpf_get_current_comm_emitter(
def bpf_get_current_pid_tgid_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
@ -394,12 +380,10 @@ def bpf_get_current_pid_tgid_emitter(
def bpf_perf_event_output_handler(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_perf_event_output helper function call.
@ -412,7 +396,9 @@ def bpf_perf_event_output_handler(
data_arg = call.args[0]
ctx_ptr = func.args[0] # First argument to the function is ctx
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab)
data_ptr, size_val = get_data_ptr_and_size(
data_arg, local_sym_tab, compilation_context.structs_sym_tab
)
# BPF_F_CURRENT_CPU is -1 in 32 bit
flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF)
@ -445,12 +431,10 @@ def bpf_perf_event_output_handler(
def bpf_ringbuf_output_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_ringbuf_output helper function call.
@ -461,7 +445,9 @@ def bpf_ringbuf_output_emitter(
f"Ringbuf output expects exactly one argument, got {len(call.args)}"
)
data_arg = call.args[0]
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab)
data_ptr, size_val = get_data_ptr_and_size(
data_arg, local_sym_tab, compilation_context.structs_sym_tab
)
flags_val = ir.Constant(ir.IntType(64), 0)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -496,38 +482,32 @@ def bpf_ringbuf_output_emitter(
def handle_output_helper(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Route output helper to the appropriate emitter based on map type.
"""
match map_sym_tab[map_ptr.name].type:
match compilation_context.map_sym_tab[map_ptr.name].type:
case BPFMapType.PERF_EVENT_ARRAY:
return bpf_perf_event_output_handler(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab,
struct_sym_tab,
map_sym_tab,
)
case BPFMapType.RINGBUF:
return bpf_ringbuf_output_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab,
struct_sym_tab,
map_sym_tab,
)
case _:
logger.error("Unsupported map type for output helper.")
@ -572,12 +552,10 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr):
def bpf_probe_read_kernel_str_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""Emit LLVM IR for bpf_probe_read_kernel_str helper."""
@ -588,12 +566,12 @@ def bpf_probe_read_kernel_str_emitter(
# Get destination buffer (char array -> i8*)
dst_ptr, dst_size = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, compilation_context, call.args[0], builder, local_sym_tab
)
# Get source pointer (evaluate expression)
src_ptr, src_type = get_ptr_from_arg(
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
call.args[1], func, compilation_context, builder, local_sym_tab
)
# Emit the helper call
@ -641,12 +619,10 @@ def emit_probe_read_kernel_call(builder, dst_ptr, dst_size, src_ptr):
def bpf_probe_read_kernel_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""Emit LLVM IR for bpf_probe_read_kernel helper."""
@ -657,12 +633,12 @@ def bpf_probe_read_kernel_emitter(
# Get destination buffer (char array -> i8*)
dst_ptr, dst_size = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, compilation_context, call.args[0], builder, local_sym_tab
)
# Get source pointer (evaluate expression)
src_ptr, src_type = get_ptr_from_arg(
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
call.args[1], func, compilation_context, builder, local_sym_tab
)
# Emit the helper call
@ -680,12 +656,10 @@ def bpf_probe_read_kernel_emitter(
def bpf_get_prandom_u32_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_prandom_u32 helper function call.
@ -710,12 +684,10 @@ def bpf_get_prandom_u32_emitter(
def bpf_probe_read_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_probe_read helper function
@ -726,31 +698,25 @@ def bpf_probe_read_emitter(
return
dst_ptr = get_or_create_ptr_from_arg(
func,
module,
compilation_context,
call.args[0],
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.IntType(8),
)
size_val = get_int_value_from_arg(
call.args[1],
func,
module,
compilation_context,
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
)
src_ptr = get_or_create_ptr_from_arg(
func,
module,
compilation_context,
call.args[2],
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.IntType(8),
)
fn_type = ir.FunctionType(
@ -783,12 +749,10 @@ def bpf_probe_read_emitter(
def bpf_get_smp_processor_id_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_smp_processor_id helper function call.
@ -810,12 +774,10 @@ def bpf_get_smp_processor_id_emitter(
def bpf_get_current_uid_gid_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_uid_gid helper function call.
@ -846,12 +808,10 @@ def bpf_get_current_uid_gid_emitter(
def bpf_skb_store_bytes_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_skb_store_bytes helper function call.
@ -875,30 +835,24 @@ def bpf_skb_store_bytes_emitter(
offset_val = get_int_value_from_arg(
call.args[0],
func,
module,
compilation_context,
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
)
from_ptr = get_or_create_ptr_from_arg(
func,
module,
compilation_context,
call.args[1],
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
args_signature[2],
)
len_val = get_int_value_from_arg(
call.args[2],
func,
module,
compilation_context,
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
)
if len(call.args) == 4:
flags_val = get_flags_val(call.args[3], builder, local_sym_tab)
@ -940,12 +894,10 @@ def bpf_skb_store_bytes_emitter(
def bpf_ringbuf_reserve_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_ringbuf_reserve helper function call.
@ -960,11 +912,9 @@ def bpf_ringbuf_reserve_emitter(
size_val = get_int_value_from_arg(
call.args[0],
func,
module,
compilation_context,
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -991,12 +941,10 @@ def bpf_ringbuf_reserve_emitter(
def bpf_ringbuf_submit_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_ringbuf_submit helper function call.
@ -1013,12 +961,10 @@ def bpf_ringbuf_submit_emitter(
data_ptr = get_or_create_ptr_from_arg(
func,
module,
compilation_context,
data_arg,
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.PointerType(ir.IntType(8)),
)
@ -1050,12 +996,10 @@ def bpf_ringbuf_submit_emitter(
def bpf_get_stack_emitter(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_stack helper function call.
@ -1068,7 +1012,7 @@ def bpf_get_stack_emitter(
buf_arg = call.args[0]
flags_arg = call.args[1] if len(call.args) == 2 else None
buf_ptr, buf_size = get_buffer_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab
buf_arg, builder, local_sym_tab, compilation_context
)
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
if isinstance(flags_val, int):
@ -1098,12 +1042,10 @@ def bpf_get_stack_emitter(
def handle_helper_call(
call,
module,
compilation_context,
builder,
func,
local_sym_tab=None,
map_sym_tab=None,
struct_sym_tab=None,
):
"""Process a BPF helper function call and emit the appropriate LLVM IR."""
@ -1117,14 +1059,14 @@ def handle_helper_call(
return handler(
call,
map_ptr,
module,
compilation_context,
builder,
func,
local_sym_tab,
struct_sym_tab,
map_sym_tab,
)
map_sym_tab = compilation_context.map_sym_tab
# Handle direct function calls (e.g., print(), ktime())
if isinstance(call.func, ast.Name):
return invoke_helper(call.func.id)