PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs

This commit is contained in:
Pragyansh Chaturvedi
2026-02-21 18:59:33 +05:30
parent 45d85c416f
commit ec4a6852ec
14 changed files with 455 additions and 497 deletions

View File

@ -26,9 +26,7 @@ def create_targets_and_rvals(stmt):
return stmt.targets, [stmt.value] return stmt.targets, [stmt.value]
def handle_assign_allocation( def handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab):
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle memory allocation for assignment statements.""" """Handle memory allocation for assignment statements."""
logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}") logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}")
@ -59,7 +57,7 @@ def handle_assign_allocation(
# Determine type and allocate based on rval # Determine type and allocate based on rval
if isinstance(rval, ast.Call): if isinstance(rval, ast.Call):
_allocate_for_call( _allocate_for_call(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab builder, var_name, rval, local_sym_tab, compilation_context
) )
elif isinstance(rval, ast.Constant): elif isinstance(rval, ast.Constant):
_allocate_for_constant(builder, var_name, rval, local_sym_tab) _allocate_for_constant(builder, var_name, rval, local_sym_tab)
@ -71,7 +69,7 @@ def handle_assign_allocation(
elif isinstance(rval, ast.Attribute): elif isinstance(rval, ast.Attribute):
# Struct field-to-variable assignment (a = dat.fld) # Struct field-to-variable assignment (a = dat.fld)
_allocate_for_attribute( _allocate_for_attribute(
builder, var_name, rval, local_sym_tab, structs_sym_tab builder, var_name, rval, local_sym_tab, compilation_context
) )
else: else:
logger.warning( logger.warning(
@ -79,10 +77,9 @@ def handle_assign_allocation(
) )
def _allocate_for_call( def _allocate_for_call(builder, var_name, rval, local_sym_tab, compilation_context):
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Allocate memory for variable assigned from a call.""" """Allocate memory for variable assigned from a call."""
structs_sym_tab = compilation_context.structs_sym_tab
if isinstance(rval.func, ast.Name): if isinstance(rval.func, ast.Name):
call_type = rval.func.id call_type = rval.func.id
@ -149,7 +146,7 @@ def _allocate_for_call(
elif isinstance(rval.func, ast.Attribute): elif isinstance(rval.func, ast.Attribute):
# Map method calls - need double allocation for ptr handling # Map method calls - need double allocation for ptr handling
_allocate_for_map_method( _allocate_for_map_method(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab builder, var_name, rval, local_sym_tab, compilation_context
) )
else: else:
@ -157,9 +154,11 @@ def _allocate_for_call(
def _allocate_for_map_method( def _allocate_for_map_method(
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab builder, var_name, rval, local_sym_tab, compilation_context
): ):
"""Allocate memory for variable assigned from map method (double alloc).""" """Allocate memory for variable assigned from map method (double alloc)."""
map_sym_tab = compilation_context.map_sym_tab
structs_sym_tab = compilation_context.structs_sym_tab
map_name = rval.func.value.id map_name = rval.func.value.id
method_name = rval.func.attr method_name = rval.func.attr
@ -299,6 +298,15 @@ def allocate_temp_pool(builder, max_temps, local_sym_tab):
logger.debug(f"Allocated temp variable: {temp_name}") logger.debug(f"Allocated temp variable: {temp_name}")
def _get_alignment(tmp_type):
"""Return alignment for a given type."""
if isinstance(tmp_type, ir.PointerType):
return 8
elif isinstance(tmp_type, ir.IntType):
return tmp_type.width // 8
return 8
def _allocate_for_name(builder, var_name, rval, local_sym_tab): def _allocate_for_name(builder, var_name, rval, local_sym_tab):
"""Allocate memory for variable-to-variable assignment (b = a).""" """Allocate memory for variable-to-variable assignment (b = a)."""
source_var = rval.id source_var = rval.id
@ -321,8 +329,22 @@ def _allocate_for_name(builder, var_name, rval, local_sym_tab):
) )
def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_tab): def _allocate_with_type(builder, var_name, ir_type):
"""Allocate memory for a variable with a specific type."""
var = builder.alloca(ir_type, name=var_name)
if isinstance(ir_type, ir.IntType):
var.align = ir_type.width // 8
elif isinstance(ir_type, ir.PointerType):
var.align = 8
return var
def _allocate_for_attribute(
builder, var_name, rval, local_sym_tab, compilation_context
):
"""Allocate memory for struct field-to-variable assignment (a = dat.fld).""" """Allocate memory for struct field-to-variable assignment (a = dat.fld)."""
structs_sym_tab = compilation_context.structs_sym_tab
if not isinstance(rval.value, ast.Name): if not isinstance(rval.value, ast.Name):
logger.warning(f"Complex attribute access not supported for {var_name}") logger.warning(f"Complex attribute access not supported for {var_name}")
return return
@ -455,20 +477,3 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_
logger.info( logger.info(
f"Pre-allocated {var_name} from {struct_var}.{field_name} with type {alloc_type}" f"Pre-allocated {var_name} from {struct_var}.{field_name} with type {alloc_type}"
) )
def _allocate_with_type(builder, var_name, ir_type):
"""Allocate variable with appropriate alignment for type."""
var = builder.alloca(ir_type, name=var_name)
var.align = _get_alignment(ir_type)
return var
def _get_alignment(ir_type):
"""Get appropriate alignment for IR type."""
if isinstance(ir_type, ir.IntType):
return ir_type.width // 8
elif isinstance(ir_type, ir.ArrayType) and isinstance(ir_type.element, ir.IntType):
return ir_type.element.width // 8
else:
return 8 # Default: pointer size

View File

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
def handle_struct_field_assignment( def handle_struct_field_assignment(
func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, builder, target, rval, local_sym_tab
): ):
"""Handle struct field assignment (obj.field = value).""" """Handle struct field assignment (obj.field = value)."""
@ -24,7 +24,7 @@ def handle_struct_field_assignment(
return return
struct_type = local_sym_tab[var_name].metadata struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type] struct_info = compilation_context.structs_sym_tab[struct_type]
if field_name not in struct_info.fields: if field_name not in struct_info.fields:
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'") logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
@ -33,9 +33,7 @@ def handle_struct_field_assignment(
# Get field pointer and evaluate value # Get field pointer and evaluate value
field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name) field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
field_type = struct_info.field_type(field_name) field_type = struct_info.field_type(field_name)
val_result = eval_expr( val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab)
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
if val_result is None: if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}.{field_name}") logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
@ -47,14 +45,12 @@ def handle_struct_field_assignment(
if _is_char_array(field_type) and _is_i8_ptr(val_type): if _is_char_array(field_type) and _is_i8_ptr(val_type):
_copy_string_to_char_array( _copy_string_to_char_array(
func, func,
module, compilation_context,
builder, builder,
val, val,
field_ptr, field_ptr,
field_type, field_type,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
logger.info(f"Copied string to char array {var_name}.{field_name}") logger.info(f"Copied string to char array {var_name}.{field_name}")
return return
@ -66,14 +62,12 @@ def handle_struct_field_assignment(
def _copy_string_to_char_array( def _copy_string_to_char_array(
func, func,
module, compilation_context,
builder, builder,
src_ptr, src_ptr,
dst_ptr, dst_ptr,
array_type, array_type,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
): ):
"""Copy string (i8*) to char array ([N x i8]) using bpf_probe_read_kernel_str""" """Copy string (i8*) to char array ([N x i8]) using bpf_probe_read_kernel_str"""
@ -109,7 +103,7 @@ def _is_i8_ptr(ir_type):
def handle_variable_assignment( def handle_variable_assignment(
func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, builder, var_name, rval, local_sym_tab
): ):
"""Handle single named variable assignment.""" """Handle single named variable assignment."""
@ -120,6 +114,8 @@ def handle_variable_assignment(
var_ptr = local_sym_tab[var_name].var var_ptr = local_sym_tab[var_name].var
var_type = local_sym_tab[var_name].ir_type var_type = local_sym_tab[var_name].ir_type
structs_sym_tab = compilation_context.structs_sym_tab
# NOTE: Special case for struct initialization # NOTE: Special case for struct initialization
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name): if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
struct_name = rval.func.id struct_name = rval.func.id
@ -142,9 +138,7 @@ def handle_variable_assignment(
logger.info(f"Assigned char array pointer to {var_name}") logger.info(f"Assigned char array pointer to {var_name}")
return True return True
val_result = eval_expr( val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab)
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
if val_result is None: if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}") logger.error(f"Failed to evaluate value for {var_name}")
return False return False

View File

@ -1,5 +1,6 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from .context import CompilationContext
from .license_pass import license_processing from .license_pass import license_processing
from .functions import func_proc from .functions import func_proc
from .maps import maps_proc from .maps import maps_proc
@ -67,9 +68,10 @@ def find_bpf_chunks(tree):
return bpf_functions return bpf_functions
def processor(source_code, filename, module): def processor(source_code, filename, compilation_context):
tree = ast.parse(source_code, filename) tree = ast.parse(source_code, filename)
logger.debug(ast.dump(tree, indent=4)) logger.debug(ast.dump(tree, indent=4))
module = compilation_context.module
bpf_chunks = find_bpf_chunks(tree) bpf_chunks = find_bpf_chunks(tree)
for func_node in bpf_chunks: for func_node in bpf_chunks:
@ -81,15 +83,18 @@ def processor(source_code, filename, module):
if vmlinux_symtab: if vmlinux_symtab:
handler = VmlinuxHandler.initialize(vmlinux_symtab) handler = VmlinuxHandler.initialize(vmlinux_symtab)
VmlinuxHandlerRegistry.set_handler(handler) VmlinuxHandlerRegistry.set_handler(handler)
compilation_context.vmlinux_handler = handler
populate_global_symbol_table(tree, module) populate_global_symbol_table(tree, compilation_context)
license_processing(tree, module) license_processing(tree, compilation_context)
globals_processing(tree, module) globals_processing(tree, compilation_context)
structs_sym_tab = structs_proc(tree, module, bpf_chunks) structs_sym_tab = structs_proc(tree, compilation_context, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks, structs_sym_tab)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
globals_list_creation(tree, module) map_sym_tab = maps_proc(tree, compilation_context, bpf_chunks)
func_proc(tree, compilation_context, bpf_chunks)
globals_list_creation(tree, compilation_context)
return structs_sym_tab, map_sym_tab return structs_sym_tab, map_sym_tab
@ -104,6 +109,8 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
module.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128" module.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128"
module.triple = "bpf" module.triple = "bpf"
compilation_context = CompilationContext(module)
if not hasattr(module, "_debug_compile_unit"): if not hasattr(module, "_debug_compile_unit"):
debug_generator = DebugInfoGenerator(module) debug_generator = DebugInfoGenerator(module)
debug_generator.generate_file_metadata(filename, os.path.dirname(filename)) debug_generator.generate_file_metadata(filename, os.path.dirname(filename))
@ -116,7 +123,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
True, True,
) )
structs_sym_tab, maps_sym_tab = processor(source, filename, module) structs_sym_tab, maps_sym_tab = processor(source, filename, compilation_context)
wchar_size = module.add_metadata( wchar_size = module.add_metadata(
[ [

82
pythonbpf/context.py Normal file
View File

@ -0,0 +1,82 @@
from llvmlite import ir
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pythonbpf.structs.struct_type import StructType
from pythonbpf.maps.maps_utils import MapSymbol
logger = logging.getLogger(__name__)
class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab"""
def __init__(self):
self._counters = {}
@property
def counter(self):
return sum(self._counters.values())
def reset(self):
self._counters.clear()
logger.debug("Scratch pool counter reset to 0")
def _get_type_name(self, ir_type):
if isinstance(ir_type, ir.PointerType):
return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0
counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1
if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Type: {type_name} Counter: {counter}"
)
logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name
class CompilationContext:
"""
Holds the state for a single compilation run.
This replaces global mutable state modules.
"""
def __init__(self, module: ir.Module):
self.module = module
# Symbol tables
self.global_sym_tab: list[ir.GlobalVariable] = []
self.structs_sym_tab: dict[str, "StructType"] = {}
self.map_sym_tab: dict[str, "MapSymbol"] = {}
# Helper management
self.scratch_pool = ScratchPoolManager()
# Vmlinux handling (optional, specialized)
self.vmlinux_handler = None # Can be VmlinuxHandler instance
# Current function context (optional, if needed globally during function processing)
self.current_func = None
def reset(self):
"""Reset state between functions if necessary, though new context per compile is preferred."""
self.scratch_pool.reset()
self.current_func = None

View File

@ -9,12 +9,8 @@ class CallHandlerRegistry:
cls._handler = handler cls._handler = handler
@classmethod @classmethod
def handle_call( def handle_call(cls, call, compilation_context, builder, func, local_sym_tab):
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle a call using the registered handler""" """Handle a call using the registered handler"""
if cls._handler is None: if cls._handler is None:
return None return None
return cls._handler( return cls._handler(call, compilation_context, builder, func, local_sym_tab)
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
)

View File

@ -37,7 +37,7 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
raise SyntaxError(f"Undefined variable {expr.id}") raise SyntaxError(f"Undefined variable {expr.id}")
def _handle_constant_expr(module, builder, expr: ast.Constant): def _handle_constant_expr(compilation_context, builder, expr: ast.Constant):
"""Handle ast.Constant expressions.""" """Handle ast.Constant expressions."""
if isinstance(expr.value, int) or isinstance(expr.value, bool): if isinstance(expr.value, int) or isinstance(expr.value, bool):
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
@ -48,7 +48,9 @@ def _handle_constant_expr(module, builder, expr: ast.Constant):
str_constant = ir.Constant(str_type, bytearray(str_bytes)) str_constant = ir.Constant(str_type, bytearray(str_bytes))
# Create global variable # Create global variable
global_str = ir.GlobalVariable(module, str_type, name=str_name) global_str = ir.GlobalVariable(
compilation_context.module, str_type, name=str_name
)
global_str.linkage = "internal" global_str.linkage = "internal"
global_str.global_constant = True global_str.global_constant = True
global_str.initializer = str_constant global_str.initializer = str_constant
@ -64,10 +66,11 @@ def _handle_attribute_expr(
func, func,
expr: ast.Attribute, expr: ast.Attribute,
local_sym_tab: Dict, local_sym_tab: Dict,
structs_sym_tab: Dict, compilation_context,
builder: ir.IRBuilder, builder: ir.IRBuilder,
): ):
"""Handle ast.Attribute expressions for struct field access.""" """Handle ast.Attribute expressions for struct field access."""
structs_sym_tab = compilation_context.structs_sym_tab
if isinstance(expr.value, ast.Name): if isinstance(expr.value, ast.Name):
var_name = expr.value.id var_name = expr.value.id
attr_name = expr.attr attr_name = expr.attr
@ -157,9 +160,7 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
# ============================================================================ # ============================================================================
def get_operand_value( def get_operand_value(func, compilation_context, operand, builder, local_sym_tab):
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Extract the value from an operand, handling variables and constants.""" """Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}") logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name): if isinstance(operand, ast.Name):
@ -187,13 +188,11 @@ def get_operand_value(
raise TypeError(f"Unsupported constant type: {type(operand.value)}") raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp): elif isinstance(operand, ast.BinOp):
res = _handle_binary_op_impl( res = _handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, operand, builder, local_sym_tab
) )
return res return res
else: else:
res = eval_expr( res = eval_expr(func, compilation_context, builder, operand, local_sym_tab)
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if res is None: if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}") raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res val, _ = res
@ -205,15 +204,13 @@ def get_operand_value(
raise TypeError(f"Unsupported operand type: {type(operand)}") raise TypeError(f"Unsupported operand type: {type(operand)}")
def _handle_binary_op_impl( def _handle_binary_op_impl(func, compilation_context, rval, builder, local_sym_tab):
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
op = rval.op op = rval.op
left = get_operand_value( left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, rval.left, builder, local_sym_tab
) )
right = get_operand_value( right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, rval.right, builder, local_sym_tab
) )
logger.info(f"left is {left}, right is {right}, op is {op}") logger.info(f"left is {left}, right is {right}, op is {op}")
@ -249,16 +246,14 @@ def _handle_binary_op_impl(
def _handle_binary_op( def _handle_binary_op(
func, func,
module, compilation_context,
rval, rval,
builder, builder,
var_name, var_name,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
): ):
result = _handle_binary_op_impl( result = _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, rval, builder, local_sym_tab
) )
if var_name and var_name in local_sym_tab: if var_name and var_name in local_sym_tab:
logger.info( logger.info(
@ -275,12 +270,10 @@ def _handle_binary_op(
def _handle_ctypes_call( def _handle_ctypes_call(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
): ):
"""Handle ctypes type constructor calls.""" """Handle ctypes type constructor calls."""
if len(expr.args) != 1: if len(expr.args) != 1:
@ -290,12 +283,10 @@ def _handle_ctypes_call(
arg = expr.args[0] arg = expr.args[0]
val = eval_expr( val = eval_expr(
func, func,
module, compilation_context,
builder, builder,
arg, arg,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if val is None: if val is None:
logger.info("Failed to evaluate argument to ctypes constructor") logger.info("Failed to evaluate argument to ctypes constructor")
@ -344,9 +335,7 @@ def _handle_ctypes_call(
return value, expected_type return value, expected_type
def _handle_compare( def _handle_compare(func, compilation_context, builder, cond, local_sym_tab):
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Handle ast.Compare expressions.""" """Handle ast.Compare expressions."""
if len(cond.ops) != 1 or len(cond.comparators) != 1: if len(cond.ops) != 1 or len(cond.comparators) != 1:
@ -354,21 +343,17 @@ def _handle_compare(
return None return None
lhs = eval_expr( lhs = eval_expr(
func, func,
module, compilation_context,
builder, builder,
cond.left, cond.left,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
rhs = eval_expr( rhs = eval_expr(
func, func,
module, compilation_context,
builder, builder,
cond.comparators[0], cond.comparators[0],
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if lhs is None or rhs is None: if lhs is None or rhs is None:
@ -382,12 +367,10 @@ def _handle_compare(
def _handle_unary_op( def _handle_unary_op(
func, func,
module, compilation_context,
builder, builder,
expr: ast.UnaryOp, expr: ast.UnaryOp,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
): ):
"""Handle ast.UnaryOp expressions.""" """Handle ast.UnaryOp expressions."""
if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub): if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub):
@ -395,7 +378,7 @@ def _handle_unary_op(
return None return None
operand = get_operand_value( operand = get_operand_value(
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, expr.operand, builder, local_sym_tab
) )
if operand is None: if operand is None:
logger.error("Failed to evaluate operand for unary operation") logger.error("Failed to evaluate operand for unary operation")
@ -418,7 +401,7 @@ def _handle_unary_op(
# ============================================================================ # ============================================================================
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): def _handle_and_op(func, builder, expr, local_sym_tab, compilation_context):
"""Handle `and` boolean operations.""" """Handle `and` boolean operations."""
logger.debug(f"Handling 'and' operator with {len(expr.values)} operands") logger.debug(f"Handling 'and' operator with {len(expr.values)} operands")
@ -433,7 +416,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_
# Evaluate current operand # Evaluate current operand
operand_result = eval_expr( operand_result = eval_expr(
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, builder, value, local_sym_tab
) )
if operand_result is None: if operand_result is None:
logger.error(f"Failed to evaluate operand {i} in 'and' expression") logger.error(f"Failed to evaluate operand {i} in 'and' expression")
@ -471,7 +454,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_
return phi, ir.IntType(1) return phi, ir.IntType(1)
def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): def _handle_or_op(func, builder, expr, local_sym_tab, compilation_context):
"""Handle `or` boolean operations.""" """Handle `or` boolean operations."""
logger.debug(f"Handling 'or' operator with {len(expr.values)} operands") logger.debug(f"Handling 'or' operator with {len(expr.values)} operands")
@ -486,7 +469,7 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t
# Evaluate current operand # Evaluate current operand
operand_result = eval_expr( operand_result = eval_expr(
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, builder, value, local_sym_tab
) )
if operand_result is None: if operand_result is None:
logger.error(f"Failed to evaluate operand {i} in 'or' expression") logger.error(f"Failed to evaluate operand {i} in 'or' expression")
@ -526,23 +509,17 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t
def _handle_boolean_op( def _handle_boolean_op(
func, func,
module, compilation_context,
builder, builder,
expr: ast.BoolOp, expr: ast.BoolOp,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
): ):
"""Handle `and` and `or` boolean operations.""" """Handle `and` and `or` boolean operations."""
if isinstance(expr.op, ast.And): if isinstance(expr.op, ast.And):
return _handle_and_op( return _handle_and_op(func, builder, expr, local_sym_tab, compilation_context)
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
elif isinstance(expr.op, ast.Or): elif isinstance(expr.op, ast.Or):
return _handle_or_op( return _handle_or_op(func, builder, expr, local_sym_tab, compilation_context)
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
else: else:
logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}") logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}")
return None return None
@ -555,12 +532,10 @@ def _handle_boolean_op(
def _handle_vmlinux_cast( def _handle_vmlinux_cast(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
): ):
# handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux # handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux
# struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64 # struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64
@ -576,12 +551,10 @@ def _handle_vmlinux_cast(
# Evaluate the argument (e.g., ctx.di which is a c_uint64) # Evaluate the argument (e.g., ctx.di which is a c_uint64)
arg_result = eval_expr( arg_result = eval_expr(
func, func,
module, compilation_context,
builder, builder,
expr.args[0], expr.args[0],
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if arg_result is None: if arg_result is None:
@ -614,18 +587,17 @@ def _handle_vmlinux_cast(
def _handle_user_defined_struct_cast( def _handle_user_defined_struct_cast(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
): ):
"""Handle user-defined struct cast expressions like iphdr(nh). """Handle user-defined struct cast expressions like iphdr(nh).
This casts a pointer/integer value to a pointer to the user-defined struct, This casts a pointer/integer value to a pointer to the user-defined struct,
similar to how vmlinux struct casts work but for user-defined @struct types. similar to how vmlinux struct casts work but for user-defined @struct types.
""" """
structs_sym_tab = compilation_context.structs_sym_tab
if len(expr.args) != 1: if len(expr.args) != 1:
logger.info("User-defined struct cast takes exactly one argument") logger.info("User-defined struct cast takes exactly one argument")
return None return None
@ -643,12 +615,10 @@ def _handle_user_defined_struct_cast(
# an address/pointer value) # an address/pointer value)
arg_result = eval_expr( arg_result = eval_expr(
func, func,
module, compilation_context,
builder, builder,
expr.args[0], expr.args[0],
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if arg_result is None: if arg_result is None:
@ -683,30 +653,28 @@ def _handle_user_defined_struct_cast(
def eval_expr( def eval_expr(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
): ):
structs_sym_tab = compilation_context.structs_sym_tab
logger.info(f"Evaluating expression: {ast.dump(expr)}") logger.info(f"Evaluating expression: {ast.dump(expr)}")
if isinstance(expr, ast.Name): if isinstance(expr, ast.Name):
return _handle_name_expr(expr, local_sym_tab, builder) return _handle_name_expr(expr, local_sym_tab, builder)
elif isinstance(expr, ast.Constant): elif isinstance(expr, ast.Constant):
return _handle_constant_expr(module, builder, expr) return _handle_constant_expr(compilation_context, builder, expr)
elif isinstance(expr, ast.Call): elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct( if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
expr.func.id expr.func.id
): ):
return _handle_vmlinux_cast( return _handle_vmlinux_cast(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if isinstance(expr.func, ast.Name) and expr.func.id == "deref": if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder) return _handle_deref_call(expr, local_sym_tab, builder)
@ -714,26 +682,23 @@ def eval_expr(
if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id): if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
return _handle_ctypes_call( return _handle_ctypes_call(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab): if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab):
return _handle_user_defined_struct_cast( return _handle_user_defined_struct_cast(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
# NOTE: Updated handle_call signature
result = CallHandlerRegistry.handle_call( result = CallHandlerRegistry.handle_call(
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab expr, compilation_context, builder, func, local_sym_tab
) )
if result is not None: if result is not None:
return result return result
@ -742,30 +707,24 @@ def eval_expr(
return None return None
elif isinstance(expr, ast.Attribute): elif isinstance(expr, ast.Attribute):
return _handle_attribute_expr( return _handle_attribute_expr(
func, expr, local_sym_tab, structs_sym_tab, builder func, expr, local_sym_tab, compilation_context, builder
) )
elif isinstance(expr, ast.BinOp): elif isinstance(expr, ast.BinOp):
return _handle_binary_op( return _handle_binary_op(
func, func,
module, compilation_context,
expr, expr,
builder, builder,
None, None,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
elif isinstance(expr, ast.Compare): elif isinstance(expr, ast.Compare):
return _handle_compare( return _handle_compare(func, compilation_context, builder, expr, local_sym_tab)
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
elif isinstance(expr, ast.UnaryOp): elif isinstance(expr, ast.UnaryOp):
return _handle_unary_op( return _handle_unary_op(func, compilation_context, builder, expr, local_sym_tab)
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
)
elif isinstance(expr, ast.BoolOp): elif isinstance(expr, ast.BoolOp):
return _handle_boolean_op( return _handle_boolean_op(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab func, compilation_context, builder, expr, local_sym_tab
) )
logger.info("Unsupported expression evaluation") logger.info("Unsupported expression evaluation")
return None return None
@ -773,12 +732,10 @@ def eval_expr(
def handle_expr( def handle_expr(
func, func,
module, compilation_context,
builder, builder,
expr, expr,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
): ):
"""Handle expression statements in the function body.""" """Handle expression statements in the function body."""
logger.info(f"Handling expression: {ast.dump(expr)}") logger.info(f"Handling expression: {ast.dump(expr)}")
@ -786,12 +743,10 @@ def handle_expr(
if isinstance(call, ast.Call): if isinstance(call, ast.Call):
eval_expr( eval_expr(
func, func,
module, compilation_context,
builder, builder,
call, call,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
else: else:
logger.info("Unsupported expression type") logger.info("Unsupported expression type")

View File

@ -4,7 +4,6 @@ import logging
from pythonbpf.helper import ( from pythonbpf.helper import (
HelperHandlerRegistry, HelperHandlerRegistry,
reset_scratch_pool,
) )
from pythonbpf.type_deducer import ctypes_to_ir from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.expr import ( from pythonbpf.expr import (
@ -76,36 +75,30 @@ def count_temps_in_call(call_node, local_sym_tab):
def handle_if_allocation( def handle_if_allocation(
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab compilation_context, builder, stmt, func, ret_type, local_sym_tab
): ):
"""Recursively handle allocations in if/else branches.""" """Recursively handle allocations in if/else branches."""
if stmt.body: if stmt.body:
allocate_mem( allocate_mem(
module, compilation_context,
builder, builder,
stmt.body, stmt.body,
func, func,
ret_type, ret_type,
map_sym_tab,
local_sym_tab, local_sym_tab,
structs_sym_tab,
) )
if stmt.orelse: if stmt.orelse:
allocate_mem( allocate_mem(
module, compilation_context,
builder, builder,
stmt.orelse, stmt.orelse,
func, func,
ret_type, ret_type,
map_sym_tab,
local_sym_tab, local_sym_tab,
structs_sym_tab,
) )
def allocate_mem( def allocate_mem(compilation_context, builder, body, func, ret_type, local_sym_tab):
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
max_temps_needed = {} max_temps_needed = {}
def merge_type_counts(count_dict): def merge_type_counts(count_dict):
@ -137,19 +130,15 @@ def allocate_mem(
# Handle allocations # Handle allocations
if isinstance(stmt, ast.If): if isinstance(stmt, ast.If):
handle_if_allocation( handle_if_allocation(
module, compilation_context,
builder, builder,
stmt, stmt,
func, func,
ret_type, ret_type,
map_sym_tab,
local_sym_tab, local_sym_tab,
structs_sym_tab,
) )
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
handle_assign_allocation( handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab)
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab) allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
@ -161,9 +150,7 @@ def allocate_mem(
# ============================================================================ # ============================================================================
def handle_assign( def handle_assign(func, compilation_context, builder, stmt, local_sym_tab):
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Handle assignment statements in the function body.""" """Handle assignment statements in the function body."""
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2) # NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
@ -175,13 +162,11 @@ def handle_assign(
var_name = target.id var_name = target.id
result = handle_variable_assignment( result = handle_variable_assignment(
func, func,
module, compilation_context,
builder, builder,
var_name, var_name,
rval, rval,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
if not result: if not result:
logger.error(f"Failed to handle assignment to {var_name}") logger.error(f"Failed to handle assignment to {var_name}")
@ -191,13 +176,11 @@ def handle_assign(
# NOTE: Struct field assignment case: pkt.field = value # NOTE: Struct field assignment case: pkt.field = value
handle_struct_field_assignment( handle_struct_field_assignment(
func, func,
module, compilation_context,
builder, builder,
target, target,
rval, rval,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
continue continue
@ -205,18 +188,12 @@ def handle_assign(
logger.error(f"Unsupported assignment target: {ast.dump(target)}") logger.error(f"Unsupported assignment target: {ast.dump(target)}")
def handle_cond( def handle_cond(func, compilation_context, builder, cond, local_sym_tab):
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None val = eval_expr(func, compilation_context, builder, cond, local_sym_tab)[0]
):
val = eval_expr(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab
)[0]
return convert_to_bool(builder, val) return convert_to_bool(builder, val)
def handle_if( def handle_if(func, compilation_context, builder, stmt, local_sym_tab):
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None
):
"""Handle if statements in the function body.""" """Handle if statements in the function body."""
logger.info("Handling if statement") logger.info("Handling if statement")
# start = builder.block.parent # start = builder.block.parent
@ -227,9 +204,7 @@ def handle_if(
else: else:
else_block = None else_block = None
cond = handle_cond( cond = handle_cond(func, compilation_context, builder, stmt.test, local_sym_tab)
func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab
)
if else_block: if else_block:
builder.cbranch(cond, then_block, else_block) builder.cbranch(cond, then_block, else_block)
else: else:
@ -237,9 +212,7 @@ def handle_if(
builder.position_at_end(then_block) builder.position_at_end(then_block)
for s in stmt.body: for s in stmt.body:
process_stmt( process_stmt(func, compilation_context, builder, s, local_sym_tab, False)
func, module, builder, s, local_sym_tab, map_sym_tab, structs_sym_tab, False
)
if not builder.block.is_terminated: if not builder.block.is_terminated:
builder.branch(merge_block) builder.branch(merge_block)
@ -248,12 +221,10 @@ def handle_if(
for s in stmt.orelse: for s in stmt.orelse:
process_stmt( process_stmt(
func, func,
module, compilation_context,
builder, builder,
s, s,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
False, False,
) )
if not builder.block.is_terminated: if not builder.block.is_terminated:
@ -262,21 +233,25 @@ def handle_if(
builder.position_at_end(merge_block) builder.position_at_end(merge_block)
def handle_return(builder, stmt, local_sym_tab, ret_type): def handle_return(builder, stmt, local_sym_tab, ret_type, compilation_context=None):
logger.info(f"Handling return statement: {ast.dump(stmt)}") logger.info(f"Handling return statement: {ast.dump(stmt)}")
if stmt.value is None: if stmt.value is None:
return handle_none_return(builder) return handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id): elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
return handle_xdp_return(stmt, builder, ret_type) return handle_xdp_return(stmt, builder, ret_type)
else: else:
# Fallback for now if ctx not passed, but caller should pass it
if compilation_context is None:
raise RuntimeError(
"CompilationContext required for return statement evaluation"
)
val = eval_expr( val = eval_expr(
func=None, func=None,
module=None, compilation_context=compilation_context,
builder=builder, builder=builder,
expr=stmt.value, expr=stmt.value,
local_sym_tab=local_sym_tab, local_sym_tab=local_sym_tab,
map_sym_tab={},
structs_sym_tab={},
) )
logger.info(f"Evaluated return expression to {val}") logger.info(f"Evaluated return expression to {val}")
builder.ret(val[0]) builder.ret(val[0])
@ -285,43 +260,34 @@ def handle_return(builder, stmt, local_sym_tab, ret_type):
def process_stmt( def process_stmt(
func, func,
module, compilation_context,
builder, builder,
stmt, stmt,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
did_return, did_return,
ret_type=ir.IntType(64), ret_type=ir.IntType(64),
): ):
logger.info(f"Processing statement: {ast.dump(stmt)}") logger.info(f"Processing statement: {ast.dump(stmt)}")
reset_scratch_pool() # Use context scratch pool
compilation_context.scratch_pool.reset()
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
handle_expr( handle_expr(
func, func,
module, compilation_context,
builder, builder,
stmt, stmt,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
handle_assign( handle_assign(func, compilation_context, builder, stmt, local_sym_tab)
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
elif isinstance(stmt, ast.AugAssign): elif isinstance(stmt, ast.AugAssign):
raise SyntaxError("Augmented assignment not supported") raise SyntaxError("Augmented assignment not supported")
elif isinstance(stmt, ast.If): elif isinstance(stmt, ast.If):
handle_if( handle_if(func, compilation_context, builder, stmt, local_sym_tab)
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
elif isinstance(stmt, ast.Return): elif isinstance(stmt, ast.Return):
did_return = handle_return( did_return = handle_return(
builder, builder, stmt, local_sym_tab, ret_type, compilation_context
stmt,
local_sym_tab,
ret_type,
) )
return did_return return did_return
@ -332,13 +298,11 @@ def process_stmt(
def process_func_body( def process_func_body(
module, compilation_context,
builder, builder,
func_node, func_node,
func, func,
ret_type, ret_type,
map_sym_tab,
structs_sym_tab,
): ):
"""Process the body of a bpf function""" """Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now # TODO: A lot. We just have print -> bpf_trace_printk for now
@ -360,6 +324,9 @@ def process_func_body(
raise TypeError( raise TypeError(
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}" f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
) )
# Use context's handler if available, else usage of VmlinuxHandlerRegistry
# For now relying on VmlinuxHandlerRegistry which relies on codegen setting it
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name): if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
resolved_type = VmlinuxHandlerRegistry.get_struct_type( resolved_type = VmlinuxHandlerRegistry.get_struct_type(
context_type_name context_type_name
@ -370,14 +337,12 @@ def process_func_body(
# pre-allocate dynamic variables # pre-allocate dynamic variables
local_sym_tab = allocate_mem( local_sym_tab = allocate_mem(
module, compilation_context,
builder, builder,
func_node.body, func_node.body,
func, func,
ret_type, ret_type,
map_sym_tab,
local_sym_tab, local_sym_tab,
structs_sym_tab,
) )
logger.info(f"Local symbol table: {local_sym_tab.keys()}") logger.info(f"Local symbol table: {local_sym_tab.keys()}")
@ -385,12 +350,10 @@ def process_func_body(
for stmt in func_node.body: for stmt in func_node.body:
did_return = process_stmt( did_return = process_stmt(
func, func,
module, compilation_context,
builder, builder,
stmt, stmt,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
did_return, did_return,
ret_type, ret_type,
) )
@ -399,9 +362,12 @@ def process_func_body(
builder.ret(ir.Constant(ir.IntType(64), 0)) builder.ret(ir.Constant(ir.IntType(64), 0))
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab): def process_bpf_chunk(func_node, compilation_context, return_type):
"""Process a single BPF chunk (function) and emit corresponding LLVM IR.""" """Process a single BPF chunk (function) and emit corresponding LLVM IR."""
# Set current function in context (optional but good for future)
compilation_context.current_func = func_node
func_name = func_node.name func_name = func_node.name
ret_type = return_type ret_type = return_type
@ -413,7 +379,7 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
param_types.append(ir.PointerType()) param_types.append(ir.PointerType())
func_ty = ir.FunctionType(ret_type, param_types) func_ty = ir.FunctionType(ret_type, param_types)
func = ir.Function(module, func_ty, func_name) func = ir.Function(compilation_context.module, func_ty, func_name)
func.linkage = "dso_local" func.linkage = "dso_local"
func.attributes.add("nounwind") func.attributes.add("nounwind")
@ -433,13 +399,11 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
builder = ir.IRBuilder(block) builder = ir.IRBuilder(block)
process_func_body( process_func_body(
module, compilation_context,
builder, builder,
func_node, func_node,
func, func,
ret_type, ret_type,
map_sym_tab,
structs_sym_tab,
) )
return func return func
@ -449,23 +413,32 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
# ============================================================================ # ============================================================================
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab): def func_proc(tree, compilation_context, chunks):
"""Process all functions decorated with @bpf and @bpfglobal"""
for func_node in chunks: for func_node in chunks:
# Ignore structs and maps
# Check against the lists
if (
func_node.name in compilation_context.structs_sym_tab
or func_node.name in compilation_context.map_sym_tab
):
continue
# Also check decorators to be sure
decorators = [d.id for d in func_node.decorator_list if isinstance(d, ast.Name)]
if "struct" in decorators or "map" in decorators:
continue
if is_global_function(func_node): if is_global_function(func_node):
continue continue
func_type = get_probe_string(func_node) func_type = get_probe_string(func_node)
logger.info(f"Found probe_string of {func_node.name}: {func_type}") logger.info(f"Found probe_string of {func_node.name}: {func_type}")
func = process_bpf_chunk( return_type = ctypes_to_ir(infer_return_type(func_node))
func_node, func = process_bpf_chunk(func_node, compilation_context, return_type)
module,
ctypes_to_ir(infer_return_type(func_node)),
map_sym_tab,
structs_sym_tab,
)
logger.info(f"Generating Debug Info for Function {func_node.name}") logger.info(f"Generating Debug Info for Function {func_node.name}")
generate_function_debug_info(func_node, module, func) generate_function_debug_info(func_node, compilation_context.module, func)
# TODO: WIP, for string assignment to fixed-size arrays # TODO: WIP, for string assignment to fixed-size arrays

View File

@ -7,11 +7,11 @@ from .type_deducer import ctypes_to_ir
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
# TODO: this is going to be a huge fuck of a headache in the future.
global_sym_tab = []
def populate_global_symbol_table(tree, compilation_context):
def populate_global_symbol_table(tree, module: ir.Module): """
compilation_context: CompilationContext
"""
for node in tree.body: for node in tree.body:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
for dec in node.decorator_list: for dec in node.decorator_list:
@ -23,12 +23,12 @@ def populate_global_symbol_table(tree, module: ir.Module):
and isinstance(dec.args[0], ast.Constant) and isinstance(dec.args[0], ast.Constant)
and isinstance(dec.args[0].value, str) and isinstance(dec.args[0].value, str)
): ):
global_sym_tab.append(node) compilation_context.global_sym_tab.append(node)
elif isinstance(dec, ast.Name) and dec.id == "bpfglobal": elif isinstance(dec, ast.Name) and dec.id == "bpfglobal":
global_sym_tab.append(node) compilation_context.global_sym_tab.append(node)
elif isinstance(dec, ast.Name) and dec.id == "map": elif isinstance(dec, ast.Name) and dec.id == "map":
global_sym_tab.append(node) compilation_context.global_sym_tab.append(node)
return False return False
@ -74,9 +74,12 @@ def emit_global(module: ir.Module, node, name):
return gvar return gvar
def globals_processing(tree, module): def globals_processing(tree, compilation_context):
"""Process stuff decorated with @bpf and @bpfglobal except license and return the section name""" """Process stuff decorated with @bpf and @bpfglobal except license and return the section name"""
globals_sym_tab = [] # Local tracking for duplicate checking if needed, or we can iterate context
# But for now, we process specific nodes
current_globals = []
for node in tree.body: for node in tree.body:
# Skip non-assignment and non-function nodes # Skip non-assignment and non-function nodes
@ -90,10 +93,10 @@ def globals_processing(tree, module):
continue continue
# Check for duplicate names # Check for duplicate names
if name in globals_sym_tab: if name in current_globals:
raise SyntaxError(f"ERROR: Global name '{name}' previously defined") raise SyntaxError(f"ERROR: Global name '{name}' previously defined")
else: else:
globals_sym_tab.append(name) current_globals.append(name)
if isinstance(node, ast.FunctionDef) and node.name != "LICENSE": if isinstance(node, ast.FunctionDef) and node.name != "LICENSE":
decorators = [ decorators = [
@ -108,7 +111,7 @@ def globals_processing(tree, module):
node.body[0].value, (ast.Constant, ast.Name, ast.Call) node.body[0].value, (ast.Constant, ast.Name, ast.Call)
) )
): ):
emit_global(module, node, name) emit_global(compilation_context.module, node, name)
else: else:
raise SyntaxError(f"ERROR: Invalid syntax for {name} global") raise SyntaxError(f"ERROR: Invalid syntax for {name} global")
@ -137,8 +140,9 @@ def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
gv.section = "llvm.metadata" gv.section = "llvm.metadata"
def globals_list_creation(tree, module: ir.Module): def globals_list_creation(tree, compilation_context):
collected = ["LICENSE"] collected = ["LICENSE"]
module = compilation_context.module
for node in tree.body: for node in tree.body:
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):

View File

@ -1,5 +1,5 @@
from .helper_registry import HelperHandlerRegistry from .helper_registry import HelperHandlerRegistry
from .helper_utils import reset_scratch_pool
from .bpf_helper_handler import ( from .bpf_helper_handler import (
handle_helper_call, handle_helper_call,
emit_probe_read_kernel_str_call, emit_probe_read_kernel_str_call,
@ -28,9 +28,7 @@ def _register_helper_handler():
"""Register helper call handler with the expression evaluator""" """Register helper call handler with the expression evaluator"""
from pythonbpf.expr.expr_pass import CallHandlerRegistry from pythonbpf.expr.expr_pass import CallHandlerRegistry
def helper_call_handler( def helper_call_handler(call, compilation_context, builder, func, local_sym_tab):
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Check if call is a helper and handle it""" """Check if call is a helper and handle it"""
import ast import ast
@ -39,17 +37,16 @@ def _register_helper_handler():
if HelperHandlerRegistry.has_handler(call.func.id): if HelperHandlerRegistry.has_handler(call.func.id):
return handle_helper_call( return handle_helper_call(
call, call,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
# Check for method calls (e.g., map.lookup()) # Check for method calls (e.g., map.lookup())
elif isinstance(call.func, ast.Attribute): elif isinstance(call.func, ast.Attribute):
method_name = call.func.attr method_name = call.func.attr
map_sym_tab = compilation_context.map_sym_tab
# Handle: my_map.lookup(key) # Handle: my_map.lookup(key)
if isinstance(call.func.value, ast.Name): if isinstance(call.func.value, ast.Name):
@ -58,12 +55,10 @@ def _register_helper_handler():
if HelperHandlerRegistry.has_handler(method_name): if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call( return handle_helper_call(
call, call,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab, local_sym_tab,
map_sym_tab,
structs_sym_tab,
) )
return None return None
@ -76,7 +71,6 @@ _register_helper_handler()
__all__ = [ __all__ = [
"HelperHandlerRegistry", "HelperHandlerRegistry",
"reset_scratch_pool",
"handle_helper_call", "handle_helper_call",
"emit_probe_read_kernel_str_call", "emit_probe_read_kernel_str_call",
"emit_probe_read_kernel_call", "emit_probe_read_kernel_call",

View File

@ -50,12 +50,10 @@ class BPFHelperID(Enum):
def bpf_ktime_get_ns_emitter( def bpf_ktime_get_ns_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_ktime_get_ns helper function call. Emit LLVM IR for bpf_ktime_get_ns helper function call.
@ -77,12 +75,10 @@ def bpf_ktime_get_ns_emitter(
def bpf_get_current_cgroup_id( def bpf_get_current_cgroup_id(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_current_cgroup_id helper function call. Emit LLVM IR for bpf_get_current_cgroup_id helper function call.
@ -104,12 +100,10 @@ def bpf_get_current_cgroup_id(
def bpf_map_lookup_elem_emitter( def bpf_map_lookup_elem_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_map_lookup_elem helper function call. Emit LLVM IR for bpf_map_lookup_elem helper function call.
@ -119,7 +113,7 @@ def bpf_map_lookup_elem_emitter(
f"Map lookup expects exactly one argument (key), got {len(call.args)}" f"Map lookup expects exactly one argument (key), got {len(call.args)}"
) )
key_ptr = get_or_create_ptr_from_arg( key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab func, compilation_context, call.args[0], builder, local_sym_tab
) )
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -147,12 +141,10 @@ def bpf_map_lookup_elem_emitter(
def bpf_printk_emitter( def bpf_printk_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
"""Emit LLVM IR for bpf_printk helper function call.""" """Emit LLVM IR for bpf_printk helper function call."""
if not hasattr(func, "_fmt_counter"): if not hasattr(func, "_fmt_counter"):
@ -165,16 +157,18 @@ def bpf_printk_emitter(
if isinstance(call.args[0], ast.JoinedStr): if isinstance(call.args[0], ast.JoinedStr):
args = handle_fstring_print( args = handle_fstring_print(
call.args[0], call.args[0],
module, compilation_context.module,
builder, builder,
func, func,
local_sym_tab, local_sym_tab,
struct_sym_tab, compilation_context.structs_sym_tab,
) )
elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str): elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str):
# TODO: We are only supporting single arguments for now. # TODO: We are only supporting single arguments for now.
# In case of multiple args, the first one will be taken. # In case of multiple args, the first one will be taken.
args = simple_string_print(call.args[0].value, module, builder, func) args = simple_string_print(
call.args[0].value, compilation_context.module, builder, func
)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple strings or f-strings are supported in bpf_printk." "Only simple strings or f-strings are supported in bpf_printk."
@ -203,12 +197,10 @@ def bpf_printk_emitter(
def bpf_map_update_elem_emitter( def bpf_map_update_elem_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_map_update_elem helper function call. Emit LLVM IR for bpf_map_update_elem helper function call.
@ -224,10 +216,10 @@ def bpf_map_update_elem_emitter(
flags_arg = call.args[2] if len(call.args) > 2 else None flags_arg = call.args[2] if len(call.args) > 2 else None
key_ptr = get_or_create_ptr_from_arg( key_ptr = get_or_create_ptr_from_arg(
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab func, compilation_context, key_arg, builder, local_sym_tab
) )
value_ptr = get_or_create_ptr_from_arg( value_ptr = get_or_create_ptr_from_arg(
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab func, compilation_context, value_arg, builder, local_sym_tab
) )
flags_val = get_flags_val(flags_arg, builder, local_sym_tab) flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
@ -262,12 +254,10 @@ def bpf_map_update_elem_emitter(
def bpf_map_delete_elem_emitter( def bpf_map_delete_elem_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_map_delete_elem helper function call. Emit LLVM IR for bpf_map_delete_elem helper function call.
@ -278,7 +268,7 @@ def bpf_map_delete_elem_emitter(
f"Map delete expects exactly one argument (key), got {len(call.args)}" f"Map delete expects exactly one argument (key), got {len(call.args)}"
) )
key_ptr = get_or_create_ptr_from_arg( key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab func, compilation_context, call.args[0], builder, local_sym_tab
) )
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -306,12 +296,10 @@ def bpf_map_delete_elem_emitter(
def bpf_get_current_comm_emitter( def bpf_get_current_comm_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_current_comm helper function call. Emit LLVM IR for bpf_get_current_comm helper function call.
@ -327,7 +315,7 @@ def bpf_get_current_comm_emitter(
# Extract buffer pointer and size # Extract buffer pointer and size
buf_ptr, buf_size = get_buffer_ptr_and_size( buf_ptr, buf_size = get_buffer_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab buf_arg, builder, local_sym_tab, compilation_context
) )
# Validate it's a char array # Validate it's a char array
@ -367,12 +355,10 @@ def bpf_get_current_comm_emitter(
def bpf_get_current_pid_tgid_emitter( def bpf_get_current_pid_tgid_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_current_pid_tgid helper function call. Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
@ -394,12 +380,10 @@ def bpf_get_current_pid_tgid_emitter(
def bpf_perf_event_output_handler( def bpf_perf_event_output_handler(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_perf_event_output helper function call. Emit LLVM IR for bpf_perf_event_output helper function call.
@ -412,7 +396,9 @@ def bpf_perf_event_output_handler(
data_arg = call.args[0] data_arg = call.args[0]
ctx_ptr = func.args[0] # First argument to the function is ctx ctx_ptr = func.args[0] # First argument to the function is ctx
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab) data_ptr, size_val = get_data_ptr_and_size(
data_arg, local_sym_tab, compilation_context.structs_sym_tab
)
# BPF_F_CURRENT_CPU is -1 in 32 bit # BPF_F_CURRENT_CPU is -1 in 32 bit
flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF) flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF)
@ -445,12 +431,10 @@ def bpf_perf_event_output_handler(
def bpf_ringbuf_output_emitter( def bpf_ringbuf_output_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_ringbuf_output helper function call. Emit LLVM IR for bpf_ringbuf_output helper function call.
@ -461,7 +445,9 @@ def bpf_ringbuf_output_emitter(
f"Ringbuf output expects exactly one argument, got {len(call.args)}" f"Ringbuf output expects exactly one argument, got {len(call.args)}"
) )
data_arg = call.args[0] data_arg = call.args[0]
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab) data_ptr, size_val = get_data_ptr_and_size(
data_arg, local_sym_tab, compilation_context.structs_sym_tab
)
flags_val = ir.Constant(ir.IntType(64), 0) flags_val = ir.Constant(ir.IntType(64), 0)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -496,38 +482,32 @@ def bpf_ringbuf_output_emitter(
def handle_output_helper( def handle_output_helper(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Route output helper to the appropriate emitter based on map type. Route output helper to the appropriate emitter based on map type.
""" """
match map_sym_tab[map_ptr.name].type: match compilation_context.map_sym_tab[map_ptr.name].type:
case BPFMapType.PERF_EVENT_ARRAY: case BPFMapType.PERF_EVENT_ARRAY:
return bpf_perf_event_output_handler( return bpf_perf_event_output_handler(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab, local_sym_tab,
struct_sym_tab,
map_sym_tab,
) )
case BPFMapType.RINGBUF: case BPFMapType.RINGBUF:
return bpf_ringbuf_output_emitter( return bpf_ringbuf_output_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab, local_sym_tab,
struct_sym_tab,
map_sym_tab,
) )
case _: case _:
logger.error("Unsupported map type for output helper.") logger.error("Unsupported map type for output helper.")
@ -572,12 +552,10 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr):
def bpf_probe_read_kernel_str_emitter( def bpf_probe_read_kernel_str_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
"""Emit LLVM IR for bpf_probe_read_kernel_str helper.""" """Emit LLVM IR for bpf_probe_read_kernel_str helper."""
@ -588,12 +566,12 @@ def bpf_probe_read_kernel_str_emitter(
# Get destination buffer (char array -> i8*) # Get destination buffer (char array -> i8*)
dst_ptr, dst_size = get_or_create_ptr_from_arg( dst_ptr, dst_size = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab func, compilation_context, call.args[0], builder, local_sym_tab
) )
# Get source pointer (evaluate expression) # Get source pointer (evaluate expression)
src_ptr, src_type = get_ptr_from_arg( src_ptr, src_type = get_ptr_from_arg(
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab call.args[1], func, compilation_context, builder, local_sym_tab
) )
# Emit the helper call # Emit the helper call
@ -641,12 +619,10 @@ def emit_probe_read_kernel_call(builder, dst_ptr, dst_size, src_ptr):
def bpf_probe_read_kernel_emitter( def bpf_probe_read_kernel_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
"""Emit LLVM IR for bpf_probe_read_kernel helper.""" """Emit LLVM IR for bpf_probe_read_kernel helper."""
@ -657,12 +633,12 @@ def bpf_probe_read_kernel_emitter(
# Get destination buffer (char array -> i8*) # Get destination buffer (char array -> i8*)
dst_ptr, dst_size = get_or_create_ptr_from_arg( dst_ptr, dst_size = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab func, compilation_context, call.args[0], builder, local_sym_tab
) )
# Get source pointer (evaluate expression) # Get source pointer (evaluate expression)
src_ptr, src_type = get_ptr_from_arg( src_ptr, src_type = get_ptr_from_arg(
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab call.args[1], func, compilation_context, builder, local_sym_tab
) )
# Emit the helper call # Emit the helper call
@ -680,12 +656,10 @@ def bpf_probe_read_kernel_emitter(
def bpf_get_prandom_u32_emitter( def bpf_get_prandom_u32_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_prandom_u32 helper function call. Emit LLVM IR for bpf_get_prandom_u32 helper function call.
@ -710,12 +684,10 @@ def bpf_get_prandom_u32_emitter(
def bpf_probe_read_emitter( def bpf_probe_read_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_probe_read helper function Emit LLVM IR for bpf_probe_read helper function
@ -726,31 +698,25 @@ def bpf_probe_read_emitter(
return return
dst_ptr = get_or_create_ptr_from_arg( dst_ptr = get_or_create_ptr_from_arg(
func, func,
module, compilation_context,
call.args[0], call.args[0],
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.IntType(8), ir.IntType(8),
) )
size_val = get_int_value_from_arg( size_val = get_int_value_from_arg(
call.args[1], call.args[1],
func, func,
module, compilation_context,
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
) )
src_ptr = get_or_create_ptr_from_arg( src_ptr = get_or_create_ptr_from_arg(
func, func,
module, compilation_context,
call.args[2], call.args[2],
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.IntType(8), ir.IntType(8),
) )
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
@ -783,12 +749,10 @@ def bpf_probe_read_emitter(
def bpf_get_smp_processor_id_emitter( def bpf_get_smp_processor_id_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_smp_processor_id helper function call. Emit LLVM IR for bpf_get_smp_processor_id helper function call.
@ -810,12 +774,10 @@ def bpf_get_smp_processor_id_emitter(
def bpf_get_current_uid_gid_emitter( def bpf_get_current_uid_gid_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_current_uid_gid helper function call. Emit LLVM IR for bpf_get_current_uid_gid helper function call.
@ -846,12 +808,10 @@ def bpf_get_current_uid_gid_emitter(
def bpf_skb_store_bytes_emitter( def bpf_skb_store_bytes_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_skb_store_bytes helper function call. Emit LLVM IR for bpf_skb_store_bytes helper function call.
@ -875,30 +835,24 @@ def bpf_skb_store_bytes_emitter(
offset_val = get_int_value_from_arg( offset_val = get_int_value_from_arg(
call.args[0], call.args[0],
func, func,
module, compilation_context,
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
) )
from_ptr = get_or_create_ptr_from_arg( from_ptr = get_or_create_ptr_from_arg(
func, func,
module, compilation_context,
call.args[1], call.args[1],
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
args_signature[2], args_signature[2],
) )
len_val = get_int_value_from_arg( len_val = get_int_value_from_arg(
call.args[2], call.args[2],
func, func,
module, compilation_context,
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
) )
if len(call.args) == 4: if len(call.args) == 4:
flags_val = get_flags_val(call.args[3], builder, local_sym_tab) flags_val = get_flags_val(call.args[3], builder, local_sym_tab)
@ -940,12 +894,10 @@ def bpf_skb_store_bytes_emitter(
def bpf_ringbuf_reserve_emitter( def bpf_ringbuf_reserve_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_ringbuf_reserve helper function call. Emit LLVM IR for bpf_ringbuf_reserve helper function call.
@ -960,11 +912,9 @@ def bpf_ringbuf_reserve_emitter(
size_val = get_int_value_from_arg( size_val = get_int_value_from_arg(
call.args[0], call.args[0],
func, func,
module, compilation_context,
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
) )
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -991,12 +941,10 @@ def bpf_ringbuf_reserve_emitter(
def bpf_ringbuf_submit_emitter( def bpf_ringbuf_submit_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_ringbuf_submit helper function call. Emit LLVM IR for bpf_ringbuf_submit helper function call.
@ -1013,12 +961,10 @@ def bpf_ringbuf_submit_emitter(
data_ptr = get_or_create_ptr_from_arg( data_ptr = get_or_create_ptr_from_arg(
func, func,
module, compilation_context,
data_arg, data_arg,
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.PointerType(ir.IntType(8)), ir.PointerType(ir.IntType(8)),
) )
@ -1050,12 +996,10 @@ def bpf_ringbuf_submit_emitter(
def bpf_get_stack_emitter( def bpf_get_stack_emitter(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
): ):
""" """
Emit LLVM IR for bpf_get_stack helper function call. Emit LLVM IR for bpf_get_stack helper function call.
@ -1068,7 +1012,7 @@ def bpf_get_stack_emitter(
buf_arg = call.args[0] buf_arg = call.args[0]
flags_arg = call.args[1] if len(call.args) == 2 else None flags_arg = call.args[1] if len(call.args) == 2 else None
buf_ptr, buf_size = get_buffer_ptr_and_size( buf_ptr, buf_size = get_buffer_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab buf_arg, builder, local_sym_tab, compilation_context
) )
flags_val = get_flags_val(flags_arg, builder, local_sym_tab) flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
if isinstance(flags_val, int): if isinstance(flags_val, int):
@ -1098,12 +1042,10 @@ def bpf_get_stack_emitter(
def handle_helper_call( def handle_helper_call(
call, call,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab=None, local_sym_tab=None,
map_sym_tab=None,
struct_sym_tab=None,
): ):
"""Process a BPF helper function call and emit the appropriate LLVM IR.""" """Process a BPF helper function call and emit the appropriate LLVM IR."""
@ -1117,14 +1059,14 @@ def handle_helper_call(
return handler( return handler(
call, call,
map_ptr, map_ptr,
module, compilation_context,
builder, builder,
func, func,
local_sym_tab, local_sym_tab,
struct_sym_tab,
map_sym_tab,
) )
map_sym_tab = compilation_context.map_sym_tab
# Handle direct function calls (e.g., print(), ktime()) # Handle direct function calls (e.g., print(), ktime())
if isinstance(call.func, ast.Name): if isinstance(call.func, ast.Name):
return invoke_helper(call.func.id) return invoke_helper(call.func.id)

View File

@ -3,7 +3,6 @@ import logging
from llvmlite import ir from llvmlite import ir
from pythonbpf.expr import ( from pythonbpf.expr import (
get_operand_value,
eval_expr, eval_expr,
access_struct_field, access_struct_field,
) )
@ -11,56 +10,38 @@ from pythonbpf.expr import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ScratchPoolManager: # NOTE: ScratchPoolManager is now in context.py
"""Manage the temporary helper variables in local_sym_tab"""
def __init__(self):
self._counters = {}
@property
def counter(self):
return sum(self._counters.values())
def reset(self):
self._counters.clear()
logger.debug("Scratch pool counter reset to 0")
def _get_type_name(self, ir_type):
if isinstance(ir_type, ir.PointerType):
return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0
counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1
if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Type: {type_name} Counter: {counter}"
)
logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name
_temp_pool_manager = ScratchPoolManager() # Singleton instance def get_ptr_from_arg(arg, compilation_context, builder, local_sym_tab):
"""Helper to get a pointer value from an argument."""
# This is a bit duplicative of logic in eval_expr but simplified for helpers
# We might need to handle more cases here or defer to eval_expr
# Simple check for name
if isinstance(arg, ast.Name):
if arg.id in local_sym_tab:
sym = local_sym_tab[arg.id]
if isinstance(sym.ir_type, ir.PointerType):
return builder.load(sym.var)
# If it's an array/struct we might need GEP depending on how it was allocated
# For now assume load returns the pointer/value
return builder.load(sym.var)
def reset_scratch_pool(): # Use eval_expr for general case
"""Reset the scratch pool counter""" val = eval_expr(
_temp_pool_manager.reset() None,
compilation_context.module,
builder,
arg,
local_sym_tab,
compilation_context.map_sym_tab,
compilation_context.structs_sym_tab,
)
if val and isinstance(val[0].type, ir.PointerType):
return val[0]
return None
# ============================================================================ # ============================================================================
@ -75,11 +56,15 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found in local symbol table") raise ValueError(f"Variable '{var_name}' not found in local symbol table")
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): def create_int_constant_ptr(
value, builder, compilation_context, local_sym_tab, int_width=64
):
"""Create a pointer to an integer constant.""" """Create a pointer to an integer constant."""
int_type = ir.IntType(int_width) int_type = ir.IntType(int_width)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type) ptr, temp_name = compilation_context.scratch_pool.get_next_temp(
local_sym_tab, int_type
)
logger.info(f"Using temp variable '{temp_name}' for int constant {value}") logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(int_type, value) const_val = ir.Constant(int_type, value)
builder.store(const_val, ptr) builder.store(const_val, ptr)
@ -88,12 +73,10 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
def get_or_create_ptr_from_arg( def get_or_create_ptr_from_arg(
func, func,
module, compilation_context,
arg, arg,
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab,
struct_sym_tab=None,
expected_type=None, expected_type=None,
): ):
"""Extract or create pointer from the call arguments.""" """Extract or create pointer from the call arguments."""
@ -102,16 +85,22 @@ def get_or_create_ptr_from_arg(
sz = None sz = None
if isinstance(arg, ast.Name): if isinstance(arg, ast.Name):
# Stack space is already allocated # Stack space is already allocated
ptr = get_var_ptr_from_name(arg.id, local_sym_tab) if arg.id in local_sym_tab:
ptr = local_sym_tab[arg.id].var
else:
raise ValueError(f"Variable '{arg.id}' not found")
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
int_width = 64 # Default to i64 int_width = 64 # Default to i64
if expected_type and isinstance(expected_type, ir.IntType): if expected_type and isinstance(expected_type, ir.IntType):
int_width = expected_type.width int_width = expected_type.width
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width) ptr = create_int_constant_ptr(
arg.value, builder, compilation_context, local_sym_tab, int_width
)
elif isinstance(arg, ast.Attribute): elif isinstance(arg, ast.Attribute):
# A struct field # A struct field
struct_name = arg.value.id struct_name = arg.value.id
field_name = arg.attr field_name = arg.attr
struct_sym_tab = compilation_context.structs_sym_tab
if not local_sym_tab or struct_name not in local_sym_tab: if not local_sym_tab or struct_name not in local_sym_tab:
raise ValueError(f"Struct '{struct_name}' not found") raise ValueError(f"Struct '{struct_name}' not found")
@ -136,7 +125,7 @@ def get_or_create_ptr_from_arg(
and field_type.element.width == 8 and field_type.element.width == 8
): ):
ptr, sz = get_char_array_ptr_and_size( ptr, sz = get_char_array_ptr_and_size(
arg, builder, local_sym_tab, struct_sym_tab, func arg, builder, local_sym_tab, compilation_context, func
) )
if not ptr: if not ptr:
raise ValueError("Failed to get char array pointer from struct field") raise ValueError("Failed to get char array pointer from struct field")
@ -146,13 +135,15 @@ def get_or_create_ptr_from_arg(
else: else:
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop # NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
# Evaluate the expression and store the result in a temp variable # Evaluate the expression and store the result in a temp variable
val = get_operand_value( val = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab if val:
) val = val[0]
if val is None: if val is None:
raise ValueError("Failed to evaluate expression for helper arg.") raise ValueError("Failed to evaluate expression for helper arg.")
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type) ptr, temp_name = compilation_context.scratch_pool.get_next_temp(
local_sym_tab, expected_type
)
logger.info(f"Using temp variable '{temp_name}' for expression result") logger.info(f"Using temp variable '{temp_name}' for expression result")
if ( if (
isinstance(val.type, ir.IntType) isinstance(val.type, ir.IntType)
@ -188,8 +179,9 @@ def get_flags_val(arg, builder, local_sym_tab):
) )
def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab): def get_data_ptr_and_size(data_arg, local_sym_tab, compilation_context):
"""Extract data pointer and size information for perf event output.""" """Extract data pointer and size information for perf event output."""
struct_sym_tab = compilation_context.structs_sym_tab
if isinstance(data_arg, ast.Name): if isinstance(data_arg, ast.Name):
data_name = data_arg.id data_name = data_arg.id
if local_sym_tab and data_name in local_sym_tab: if local_sym_tab and data_name in local_sym_tab:
@ -213,8 +205,9 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
) )
def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab): def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, compilation_context):
"""Extract buffer pointer and size from either a struct field or variable.""" """Extract buffer pointer and size from either a struct field or variable."""
struct_sym_tab = compilation_context.structs_sym_tab
# Case 1: Struct field (obj.field) # Case 1: Struct field (obj.field)
if isinstance(buf_arg, ast.Attribute): if isinstance(buf_arg, ast.Attribute):
@ -268,9 +261,10 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
def get_char_array_ptr_and_size( def get_char_array_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab, func=None buf_arg, builder, local_sym_tab, compilation_context, func=None
): ):
"""Get pointer to char array and its size.""" """Get pointer to char array and its size."""
struct_sym_tab = compilation_context.structs_sym_tab
# Struct field: obj.field # Struct field: obj.field
if isinstance(buf_arg, ast.Attribute) and isinstance(buf_arg.value, ast.Name): if isinstance(buf_arg, ast.Attribute) and isinstance(buf_arg.value, ast.Name):
@ -351,34 +345,10 @@ def _is_char_array(ir_type):
) )
def get_ptr_from_arg( def get_int_value_from_arg(arg, func, compilation_context, builder, local_sym_tab):
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
):
"""Evaluate argument and return pointer value"""
result = eval_expr(
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
)
if not result:
raise ValueError("Failed to evaluate argument")
val, val_type = result
if not isinstance(val_type, ir.PointerType):
raise ValueError(f"Expected pointer type, got {val_type}")
return val, val_type
def get_int_value_from_arg(
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
):
"""Evaluate argument and return integer value""" """Evaluate argument and return integer value"""
result = eval_expr( result = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
)
if not result: if not result:
raise ValueError("Failed to evaluate argument") raise ValueError("Failed to evaluate argument")

View File

@ -23,7 +23,7 @@ def emit_license(module: ir.Module, license_str: str):
return gvar return gvar
def license_processing(tree, module): def license_processing(tree, compilation_context):
"""Process the LICENSE function decorated with @bpf and @bpfglobal and return the section name""" """Process the LICENSE function decorated with @bpf and @bpfglobal and return the section name"""
count = 0 count = 0
for node in tree.body: for node in tree.body:
@ -42,12 +42,14 @@ def license_processing(tree, module):
and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value, ast.Constant)
and isinstance(node.body[0].value.value, str) and isinstance(node.body[0].value.value, str)
): ):
emit_license(module, node.body[0].value.value) emit_license(
compilation_context.module, node.body[0].value.value
)
return "LICENSE" return "LICENSE"
else: else:
logger.info("ERROR: LICENSE() must return a string literal") raise SyntaxError(
return None "ERROR: LICENSE() must return a string literal"
)
else: else:
logger.info("ERROR: LICENSE already defined") raise SyntaxError("ERROR: Multiple LICENSE globals defined")
return None
return None return None

View File

@ -12,14 +12,14 @@ from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
def maps_proc(tree, module, chunks, structs_sym_tab): def maps_proc(tree, compilation_context, chunks):
"""Process all functions decorated with @map to find BPF maps""" """Process all functions decorated with @map to find BPF maps"""
map_sym_tab = {} map_sym_tab = compilation_context.map_sym_tab
for func_node in chunks: for func_node in chunks:
if is_map(func_node): if is_map(func_node):
logger.info(f"Found BPF map: {func_node.name}") logger.info(f"Found BPF map: {func_node.name}")
map_sym_tab[func_node.name] = process_bpf_map( map_sym_tab[func_node.name] = process_bpf_map(
func_node, module, structs_sym_tab func_node, compilation_context
) )
return map_sym_tab return map_sym_tab
@ -51,11 +51,11 @@ def create_bpf_map(module, map_name, map_params):
return MapSymbol(type=map_params["type"], sym=map_global, params=map_params) return MapSymbol(type=map_params["type"], sym=map_global, params=map_params)
def _parse_map_params(rval, expected_args=None): def _parse_map_params(rval, compilation_context, expected_args=None):
"""Parse map parameters from call arguments and keywords.""" """Parse map parameters from call arguments and keywords."""
params = {} params = {}
handler = VmlinuxHandlerRegistry.get_handler() handler = compilation_context.vmlinux_handler
# Parse positional arguments # Parse positional arguments
if expected_args: if expected_args:
for i, arg_name in enumerate(expected_args): for i, arg_name in enumerate(expected_args):
@ -83,12 +83,23 @@ def _get_vmlinux_enum(handler, name):
if handler and handler.is_vmlinux_enum(name): if handler and handler.is_vmlinux_enum(name):
return handler.get_vmlinux_enum_value(name) return handler.get_vmlinux_enum_value(name)
# Fallback to VmlinuxHandlerRegistry if handler invalid
# This is for backward compatibility or if refactoring isn't complete
if (
VmlinuxHandlerRegistry.get_handler()
and VmlinuxHandlerRegistry.get_handler().is_vmlinux_enum(name)
):
return VmlinuxHandlerRegistry.get_handler().get_vmlinux_enum_value(name)
return None
@MapProcessorRegistry.register("RingBuffer") @MapProcessorRegistry.register("RingBuffer")
def process_ringbuf_map(map_name, rval, module, structs_sym_tab): def process_ringbuf_map(map_name, rval, compilation_context):
"""Process a BPF_RINGBUF map declaration""" """Process a BPF_RINGBUF map declaration"""
logger.info(f"Processing Ringbuf: {map_name}") logger.info(f"Processing Ringbuf: {map_name}")
map_params = _parse_map_params(rval, expected_args=["max_entries"]) map_params = _parse_map_params(
rval, compilation_context, expected_args=["max_entries"]
)
map_params["type"] = BPFMapType.RINGBUF map_params["type"] = BPFMapType.RINGBUF
# NOTE: constraints borrowed from https://docs.ebpf.io/linux/map-type/BPF_MAP_TYPE_RINGBUF/ # NOTE: constraints borrowed from https://docs.ebpf.io/linux/map-type/BPF_MAP_TYPE_RINGBUF/
@ -104,42 +115,62 @@ def process_ringbuf_map(map_name, rval, module, structs_sym_tab):
logger.info(f"Ringbuf map parameters: {map_params}") logger.info(f"Ringbuf map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params) map_global = create_bpf_map(compilation_context.module, map_name, map_params)
create_ringbuf_debug_info( create_ringbuf_debug_info(
module, map_global.sym, map_name, map_params, structs_sym_tab compilation_context.module,
map_global.sym,
map_name,
map_params,
compilation_context.structs_sym_tab,
) )
return map_global return map_global
@MapProcessorRegistry.register("HashMap") @MapProcessorRegistry.register("HashMap")
def process_hash_map(map_name, rval, module, structs_sym_tab): def process_hash_map(map_name, rval, compilation_context):
"""Process a BPF_HASH map declaration""" """Process a BPF_HASH map declaration"""
logger.info(f"Processing HashMap: {map_name}") logger.info(f"Processing HashMap: {map_name}")
map_params = _parse_map_params(rval, expected_args=["key", "value", "max_entries"]) map_params = _parse_map_params(
rval, compilation_context, expected_args=["key", "value", "max_entries"]
)
map_params["type"] = BPFMapType.HASH map_params["type"] = BPFMapType.HASH
logger.info(f"Map parameters: {map_params}") logger.info(f"Map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params) map_global = create_bpf_map(compilation_context.module, map_name, map_params)
# Generate debug info for BTF # Generate debug info for BTF
create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab) create_map_debug_info(
compilation_context.module,
map_global.sym,
map_name,
map_params,
compilation_context.structs_sym_tab,
)
return map_global return map_global
@MapProcessorRegistry.register("PerfEventArray") @MapProcessorRegistry.register("PerfEventArray")
def process_perf_event_map(map_name, rval, module, structs_sym_tab): def process_perf_event_map(map_name, rval, compilation_context):
"""Process a BPF_PERF_EVENT_ARRAY map declaration""" """Process a BPF_PERF_EVENT_ARRAY map declaration"""
logger.info(f"Processing PerfEventArray: {map_name}") logger.info(f"Processing PerfEventArray: {map_name}")
map_params = _parse_map_params(rval, expected_args=["key_size", "value_size"]) map_params = _parse_map_params(
rval, compilation_context, expected_args=["key_size", "value_size"]
)
map_params["type"] = BPFMapType.PERF_EVENT_ARRAY map_params["type"] = BPFMapType.PERF_EVENT_ARRAY
logger.info(f"Map parameters: {map_params}") logger.info(f"Map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params) map_global = create_bpf_map(compilation_context.module, map_name, map_params)
# Generate debug info for BTF # Generate debug info for BTF
create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab) create_map_debug_info(
compilation_context.module,
map_global.sym,
map_name,
map_params,
compilation_context.structs_sym_tab,
)
return map_global return map_global
def process_bpf_map(func_node, module, structs_sym_tab): def process_bpf_map(func_node, compilation_context):
"""Process a BPF map (a function decorated with @map)""" """Process a BPF map (a function decorated with @map)"""
map_name = func_node.name map_name = func_node.name
logger.info(f"Processing BPF map: {map_name}") logger.info(f"Processing BPF map: {map_name}")
@ -158,9 +189,9 @@ def process_bpf_map(func_node, module, structs_sym_tab):
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name): if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
handler = MapProcessorRegistry.get_processor(rval.func.id) handler = MapProcessorRegistry.get_processor(rval.func.id)
if handler: if handler:
return handler(map_name, rval, module, structs_sym_tab) return handler(map_name, rval, compilation_context)
else: else:
logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap") logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap")
return process_hash_map(map_name, rval, module) return process_hash_map(map_name, rval, compilation_context)
else: else:
raise ValueError("Function under @map must return a map") raise ValueError("Function under @map must return a map")

View File

@ -14,14 +14,17 @@ logger = logging.getLogger(__name__)
# Shall we just int64, int32 and uint32 similarly? # Shall we just int64, int32 and uint32 similarly?
def structs_proc(tree, module, chunks): def structs_proc(tree, compilation_context, chunks):
"""Process all class definitions to find BPF structs""" """Process all class definitions to find BPF structs"""
structs_sym_tab = {} # Use the context's symbol table
structs_sym_tab = compilation_context.structs_sym_tab
for cls_node in chunks: for cls_node in chunks:
if is_bpf_struct(cls_node): if is_bpf_struct(cls_node):
logger.info(f"Found BPF struct: {cls_node.name}") logger.info(f"Found BPF struct: {cls_node.name}")
struct_info = process_bpf_struct(cls_node, module) struct_info = process_bpf_struct(cls_node, compilation_context)
structs_sym_tab[cls_node.name] = struct_info structs_sym_tab[cls_node.name] = struct_info
return structs_sym_tab return structs_sym_tab
@ -32,7 +35,7 @@ def is_bpf_struct(cls_node):
) )
def process_bpf_struct(cls_node, module): def process_bpf_struct(cls_node, compilation_context):
"""Process a single BPF struct definition""" """Process a single BPF struct definition"""
fields = parse_struct_fields(cls_node) fields = parse_struct_fields(cls_node)