mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2026-03-19 11:41:28 +00:00
PythonBPF: Add Compilation Context to allow parallel compilation of multiple bpf programs
This commit is contained in:
@ -3,7 +3,6 @@ import logging
|
||||
|
||||
from llvmlite import ir
|
||||
from pythonbpf.expr import (
|
||||
get_operand_value,
|
||||
eval_expr,
|
||||
access_struct_field,
|
||||
)
|
||||
@ -11,56 +10,38 @@ from pythonbpf.expr import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScratchPoolManager:
|
||||
"""Manage the temporary helper variables in local_sym_tab"""
|
||||
|
||||
def __init__(self):
|
||||
self._counters = {}
|
||||
|
||||
@property
|
||||
def counter(self):
|
||||
return sum(self._counters.values())
|
||||
|
||||
def reset(self):
|
||||
self._counters.clear()
|
||||
logger.debug("Scratch pool counter reset to 0")
|
||||
|
||||
def _get_type_name(self, ir_type):
|
||||
if isinstance(ir_type, ir.PointerType):
|
||||
return "ptr"
|
||||
elif isinstance(ir_type, ir.IntType):
|
||||
return f"i{ir_type.width}"
|
||||
elif isinstance(ir_type, ir.ArrayType):
|
||||
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
|
||||
else:
|
||||
return str(ir_type).replace(" ", "")
|
||||
|
||||
def get_next_temp(self, local_sym_tab, expected_type=None):
|
||||
# Default to i64 if no expected type provided
|
||||
type_name = self._get_type_name(expected_type) if expected_type else "i64"
|
||||
if type_name not in self._counters:
|
||||
self._counters[type_name] = 0
|
||||
|
||||
counter = self._counters[type_name]
|
||||
temp_name = f"__helper_temp_{type_name}_{counter}"
|
||||
self._counters[type_name] += 1
|
||||
|
||||
if temp_name not in local_sym_tab:
|
||||
raise ValueError(
|
||||
f"Scratch pool exhausted or inadequate: {temp_name}. "
|
||||
f"Type: {type_name} Counter: {counter}"
|
||||
)
|
||||
|
||||
logger.debug(f"Using {temp_name} for type {type_name}")
|
||||
return local_sym_tab[temp_name].var, temp_name
|
||||
# NOTE: ScratchPoolManager is now in context.py
|
||||
|
||||
|
||||
_temp_pool_manager = ScratchPoolManager() # Singleton instance
|
||||
def get_ptr_from_arg(arg, compilation_context, builder, local_sym_tab):
|
||||
"""Helper to get a pointer value from an argument."""
|
||||
# This is a bit duplicative of logic in eval_expr but simplified for helpers
|
||||
# We might need to handle more cases here or defer to eval_expr
|
||||
|
||||
# Simple check for name
|
||||
if isinstance(arg, ast.Name):
|
||||
if arg.id in local_sym_tab:
|
||||
sym = local_sym_tab[arg.id]
|
||||
if isinstance(sym.ir_type, ir.PointerType):
|
||||
return builder.load(sym.var)
|
||||
# If it's an array/struct we might need GEP depending on how it was allocated
|
||||
# For now assume load returns the pointer/value
|
||||
return builder.load(sym.var)
|
||||
|
||||
def reset_scratch_pool():
|
||||
"""Reset the scratch pool counter"""
|
||||
_temp_pool_manager.reset()
|
||||
# Use eval_expr for general case
|
||||
val = eval_expr(
|
||||
None,
|
||||
compilation_context.module,
|
||||
builder,
|
||||
arg,
|
||||
local_sym_tab,
|
||||
compilation_context.map_sym_tab,
|
||||
compilation_context.structs_sym_tab,
|
||||
)
|
||||
if val and isinstance(val[0].type, ir.PointerType):
|
||||
return val[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@ -75,11 +56,15 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
|
||||
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
|
||||
|
||||
|
||||
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||
def create_int_constant_ptr(
|
||||
value, builder, compilation_context, local_sym_tab, int_width=64
|
||||
):
|
||||
"""Create a pointer to an integer constant."""
|
||||
|
||||
int_type = ir.IntType(int_width)
|
||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type)
|
||||
ptr, temp_name = compilation_context.scratch_pool.get_next_temp(
|
||||
local_sym_tab, int_type
|
||||
)
|
||||
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
|
||||
const_val = ir.Constant(int_type, value)
|
||||
builder.store(const_val, ptr)
|
||||
@ -88,12 +73,10 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||
|
||||
def get_or_create_ptr_from_arg(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
arg,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab=None,
|
||||
expected_type=None,
|
||||
):
|
||||
"""Extract or create pointer from the call arguments."""
|
||||
@ -102,16 +85,22 @@ def get_or_create_ptr_from_arg(
|
||||
sz = None
|
||||
if isinstance(arg, ast.Name):
|
||||
# Stack space is already allocated
|
||||
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
||||
if arg.id in local_sym_tab:
|
||||
ptr = local_sym_tab[arg.id].var
|
||||
else:
|
||||
raise ValueError(f"Variable '{arg.id}' not found")
|
||||
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
||||
int_width = 64 # Default to i64
|
||||
if expected_type and isinstance(expected_type, ir.IntType):
|
||||
int_width = expected_type.width
|
||||
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width)
|
||||
ptr = create_int_constant_ptr(
|
||||
arg.value, builder, compilation_context, local_sym_tab, int_width
|
||||
)
|
||||
elif isinstance(arg, ast.Attribute):
|
||||
# A struct field
|
||||
struct_name = arg.value.id
|
||||
field_name = arg.attr
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
if not local_sym_tab or struct_name not in local_sym_tab:
|
||||
raise ValueError(f"Struct '{struct_name}' not found")
|
||||
@ -136,7 +125,7 @@ def get_or_create_ptr_from_arg(
|
||||
and field_type.element.width == 8
|
||||
):
|
||||
ptr, sz = get_char_array_ptr_and_size(
|
||||
arg, builder, local_sym_tab, struct_sym_tab, func
|
||||
arg, builder, local_sym_tab, compilation_context, func
|
||||
)
|
||||
if not ptr:
|
||||
raise ValueError("Failed to get char array pointer from struct field")
|
||||
@ -146,13 +135,15 @@ def get_or_create_ptr_from_arg(
|
||||
else:
|
||||
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
|
||||
# Evaluate the expression and store the result in a temp variable
|
||||
val = get_operand_value(
|
||||
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
)
|
||||
val = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
|
||||
if val:
|
||||
val = val[0]
|
||||
if val is None:
|
||||
raise ValueError("Failed to evaluate expression for helper arg.")
|
||||
|
||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type)
|
||||
ptr, temp_name = compilation_context.scratch_pool.get_next_temp(
|
||||
local_sym_tab, expected_type
|
||||
)
|
||||
logger.info(f"Using temp variable '{temp_name}' for expression result")
|
||||
if (
|
||||
isinstance(val.type, ir.IntType)
|
||||
@ -188,8 +179,9 @@ def get_flags_val(arg, builder, local_sym_tab):
|
||||
)
|
||||
|
||||
|
||||
def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
|
||||
def get_data_ptr_and_size(data_arg, local_sym_tab, compilation_context):
|
||||
"""Extract data pointer and size information for perf event output."""
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
if isinstance(data_arg, ast.Name):
|
||||
data_name = data_arg.id
|
||||
if local_sym_tab and data_name in local_sym_tab:
|
||||
@ -213,8 +205,9 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
|
||||
)
|
||||
|
||||
|
||||
def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
|
||||
def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, compilation_context):
|
||||
"""Extract buffer pointer and size from either a struct field or variable."""
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
# Case 1: Struct field (obj.field)
|
||||
if isinstance(buf_arg, ast.Attribute):
|
||||
@ -268,9 +261,10 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
|
||||
|
||||
|
||||
def get_char_array_ptr_and_size(
|
||||
buf_arg, builder, local_sym_tab, struct_sym_tab, func=None
|
||||
buf_arg, builder, local_sym_tab, compilation_context, func=None
|
||||
):
|
||||
"""Get pointer to char array and its size."""
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
# Struct field: obj.field
|
||||
if isinstance(buf_arg, ast.Attribute) and isinstance(buf_arg.value, ast.Name):
|
||||
@ -351,34 +345,10 @@ def _is_char_array(ir_type):
|
||||
)
|
||||
|
||||
|
||||
def get_ptr_from_arg(
|
||||
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
):
|
||||
"""Evaluate argument and return pointer value"""
|
||||
|
||||
result = eval_expr(
|
||||
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise ValueError("Failed to evaluate argument")
|
||||
|
||||
val, val_type = result
|
||||
|
||||
if not isinstance(val_type, ir.PointerType):
|
||||
raise ValueError(f"Expected pointer type, got {val_type}")
|
||||
|
||||
return val, val_type
|
||||
|
||||
|
||||
def get_int_value_from_arg(
|
||||
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
):
|
||||
def get_int_value_from_arg(arg, func, compilation_context, builder, local_sym_tab):
|
||||
"""Evaluate argument and return integer value"""
|
||||
|
||||
result = eval_expr(
|
||||
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
)
|
||||
result = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
|
||||
|
||||
if not result:
|
||||
raise ValueError("Failed to evaluate argument")
|
||||
|
||||
Reference in New Issue
Block a user