Files
python-bpf/pythonbpf/functions/functions_pass.py

515 lines
15 KiB
Python

from llvmlite import ir
import ast
import logging
from pythonbpf.helper import (
HelperHandlerRegistry,
reset_scratch_pool,
)
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.expr import (
eval_expr,
handle_expr,
convert_to_bool,
VmlinuxHandlerRegistry,
)
from pythonbpf.assign_pass import (
handle_variable_assignment,
handle_struct_field_assignment,
)
from pythonbpf.allocation_pass import (
handle_assign_allocation,
allocate_temp_pool,
create_targets_and_rvals,
LocalSymbol,
)
from .function_debug_info import generate_function_debug_info
from .return_utils import handle_none_return, handle_xdp_return, is_xdp_name
from .function_metadata import get_probe_string, is_global_function, infer_return_type
logger = logging.getLogger(__name__)
# ============================================================================
# SECTION 1: Memory Allocation
# ============================================================================
def count_temps_in_call(call_node, local_sym_tab):
"""Count the number of temporary variables needed for a function call."""
count = {}
is_helper = False
# NOTE: We exclude print calls for now
if isinstance(call_node.func, ast.Name):
if (
HelperHandlerRegistry.has_handler(call_node.func.id)
and call_node.func.id != "print"
):
is_helper = True
func_name = call_node.func.id
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True
func_name = call_node.func.attr
if not is_helper:
return {} # No temps needed
for arg_idx in range(len(call_node.args)):
# NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab
arg = call_node.args[arg_idx]
if isinstance(arg, ast.Name) or (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
):
continue
param_type = HelperHandlerRegistry.get_param_type(func_name, arg_idx)
if isinstance(param_type, ir.PointerType):
pointee_type = param_type.pointee
count[pointee_type] = count.get(pointee_type, 0) + 1
return count
def handle_if_allocation(
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Recursively handle allocations in if/else branches."""
if stmt.body:
allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
max_temps_needed = {}
def merge_type_counts(count_dict):
nonlocal max_temps_needed
for typ, cnt in count_dict.items():
max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt)
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
if isinstance(stmt, ast.If):
for s in stmt.body:
update_max_temps_for_stmt(s)
for s in stmt.orelse:
update_max_temps_for_stmt(s)
return
stmt_temps = {}
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
call_temps = count_temps_in_call(node, local_sym_tab)
for typ, cnt in call_temps.items():
stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt
merge_type_counts(stmt_temps)
for stmt in body:
update_max_temps_for_stmt(stmt)
# Handle allocations
if isinstance(stmt, ast.If):
handle_if_allocation(
module,
builder,
stmt,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign_allocation(
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
return local_sym_tab
# ============================================================================
# SECTION 2: Statement Handlers
# ============================================================================
def handle_assign(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Handle assignment statements in the function body."""
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
targets, rvals = create_targets_and_rvals(stmt)
for target, rval in zip(targets, rvals):
if isinstance(target, ast.Name):
# NOTE: Simple variable assignment case: x = 5
var_name = target.id
result = handle_variable_assignment(
func,
module,
builder,
var_name,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if not result:
logger.error(f"Failed to handle assignment to {var_name}")
continue
if isinstance(target, ast.Attribute):
# NOTE: Struct field assignment case: pkt.field = value
handle_struct_field_assignment(
func,
module,
builder,
target,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
continue
# Unsupported target type
logger.error(f"Unsupported assignment target: {ast.dump(target)}")
def handle_cond(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
val = eval_expr(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab
)[0]
return convert_to_bool(builder, val)
def handle_if(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None
):
"""Handle if statements in the function body."""
logger.info("Handling if statement")
# start = builder.block.parent
then_block = func.append_basic_block(name="if.then")
merge_block = func.append_basic_block(name="if.end")
if stmt.orelse:
else_block = func.append_basic_block(name="if.else")
else:
else_block = None
cond = handle_cond(
func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab
)
if else_block:
builder.cbranch(cond, then_block, else_block)
else:
builder.cbranch(cond, then_block, merge_block)
builder.position_at_end(then_block)
for s in stmt.body:
process_stmt(
func, module, builder, s, local_sym_tab, map_sym_tab, structs_sym_tab, False
)
if not builder.block.is_terminated:
builder.branch(merge_block)
if else_block:
builder.position_at_end(else_block)
for s in stmt.orelse:
process_stmt(
func,
module,
builder,
s,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
False,
)
if not builder.block.is_terminated:
builder.branch(merge_block)
builder.position_at_end(merge_block)
def handle_return(builder, stmt, local_sym_tab, ret_type):
logger.info(f"Handling return statement: {ast.dump(stmt)}")
if stmt.value is None:
return handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
return handle_xdp_return(stmt, builder, ret_type)
else:
val = eval_expr(
func=None,
module=None,
builder=builder,
expr=stmt.value,
local_sym_tab=local_sym_tab,
map_sym_tab={},
structs_sym_tab={},
)
logger.info(f"Evaluated return expression to {val}")
builder.ret(val[0])
return True
def process_stmt(
func,
module,
builder,
stmt,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
did_return,
ret_type=ir.IntType(64),
):
logger.info(f"Processing statement: {ast.dump(stmt)}")
reset_scratch_pool()
if isinstance(stmt, ast.Expr):
handle_expr(
func,
module,
builder,
stmt,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
elif isinstance(stmt, ast.AugAssign):
raise SyntaxError("Augmented assignment not supported")
elif isinstance(stmt, ast.If):
handle_if(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
elif isinstance(stmt, ast.Return):
did_return = handle_return(
builder,
stmt,
local_sym_tab,
ret_type,
)
return did_return
# ============================================================================
# SECTION 3: Function Body Processing
# ============================================================================
def process_func_body(
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
):
"""Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now
did_return = False
local_sym_tab = {}
# Add the context parameter (first function argument) to the local symbol table
if func_node.args.args and len(func_node.args.args) > 0:
context_arg = func_node.args.args[0]
context_name = context_arg.arg
if hasattr(context_arg, "annotation") and context_arg.annotation:
if isinstance(context_arg.annotation, ast.Name):
context_type_name = context_arg.annotation.id
elif isinstance(context_arg.annotation, ast.Attribute):
context_type_name = context_arg.annotation.attr
else:
raise TypeError(
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
)
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
resolved_type = VmlinuxHandlerRegistry.get_struct_type(
context_type_name
)
context_type = LocalSymbol(None, None, resolved_type)
local_sym_tab[context_name] = context_type
logger.info(f"Added argument '{context_name}' to local symbol table")
# pre-allocate dynamic variables
local_sym_tab = allocate_mem(
module,
builder,
func_node.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
logger.info(f"Local symbol table: {local_sym_tab.keys()}")
for stmt in func_node.body:
did_return = process_stmt(
func,
module,
builder,
stmt,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
did_return,
ret_type,
)
if not did_return:
builder.ret(ir.Constant(ir.IntType(64), 0))
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):
"""Process a single BPF chunk (function) and emit corresponding LLVM IR."""
func_name = func_node.name
ret_type = return_type
# TODO: parse parameters
param_types = []
if func_node.args.args:
# Assume first arg to be ctx
param_types.append(ir.PointerType())
func_ty = ir.FunctionType(ret_type, param_types)
func = ir.Function(module, func_ty, func_name)
func.linkage = "dso_local"
func.attributes.add("nounwind")
func.attributes.add("noinline")
# func.attributes.add("optnone")
if func_node.args.args:
# Only look at the first argument for now
param = func.args[0]
param.add_attribute("nocapture")
probe_string = get_probe_string(func_node)
if probe_string is not None:
func.section = probe_string
block = func.append_basic_block(name="entry")
builder = ir.IRBuilder(block)
process_func_body(
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
)
return func
# ============================================================================
# SECTION 4: Top-Level Function Processor
# ============================================================================
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
for func_node in chunks:
if is_global_function(func_node):
continue
func_type = get_probe_string(func_node)
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
func = process_bpf_chunk(
func_node,
module,
ctypes_to_ir(infer_return_type(func_node)),
map_sym_tab,
structs_sym_tab,
)
logger.info(f"Generating Debug Info for Function {func_node.name}")
generate_function_debug_info(func_node, module, func)
# TODO: WIP, for string assignment to fixed-size arrays
def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length):
"""
Copy a string (i8*) to a fixed-size array ([N x i8]*)
"""
# Create a loop to copy characters one by one
# entry_block = builder.block
copy_block = builder.append_basic_block("copy_char")
end_block = builder.append_basic_block("copy_end")
# Create loop counter
i = builder.alloca(ir.IntType(32))
builder.store(ir.Constant(ir.IntType(32), 0), i)
# Start the loop
builder.branch(copy_block)
# Copy loop
builder.position_at_end(copy_block)
idx = builder.load(i)
in_bounds = builder.icmp_unsigned(
"<", idx, ir.Constant(ir.IntType(32), array_length)
)
builder.cbranch(in_bounds, copy_block, end_block)
with builder.if_then(in_bounds):
# Load character from source
src_ptr = builder.gep(source_string_ptr, [idx])
char = builder.load(src_ptr)
# Store character in target
dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
builder.store(char, dst_ptr)
# Increment counter
next_idx = builder.add(idx, ir.Constant(ir.IntType(32), 1))
builder.store(next_idx, i)
builder.position_at_end(end_block)
# Ensure null termination
last_idx = ir.Constant(ir.IntType(32), array_length - 1)
null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
builder.store(ir.Constant(ir.IntType(8), 0), null_ptr)