3 Commits

4 changed files with 73 additions and 27 deletions

View File

@ -50,7 +50,7 @@ def count_temps_in_call(call_node, local_sym_tab):
func_name = call_node.func.attr func_name = call_node.func.attr
if not is_helper: if not is_helper:
return 0 return {} # No temps needed
for arg_idx in range(len(call_node.args)): for arg_idx in range(len(call_node.args)):
# NOTE: Count all non-name arguments # NOTE: Count all non-name arguments

View File

@ -567,7 +567,11 @@ def bpf_get_current_uid_gid_emitter(
return pid, ir.IntType(64) return pid, ir.IntType(64)
@HelperHandlerRegistry.register("skb_store_bytes") @HelperHandlerRegistry.register(
"skb_store_bytes",
param_types=[ir.IntType(32), ir.PointerType(), ir.IntType(32), ir.IntType(64)],
return_type=ir.IntType(64),
)
def bpf_skb_store_bytes_emitter( def bpf_skb_store_bytes_emitter(
call, call,
map_ptr, map_ptr,
@ -583,6 +587,14 @@ def bpf_skb_store_bytes_emitter(
Expected call signature: skb_store_bytes(skb, offset, from, len, flags) Expected call signature: skb_store_bytes(skb, offset, from, len, flags)
""" """
args_signature = [
ir.PointerType(), # skb pointer
ir.IntType(32), # offset
ir.PointerType(), # from
ir.IntType(32), # len
ir.IntType(64), # flags
]
if len(call.args) not in (3, 4): if len(call.args) not in (3, 4):
raise ValueError( raise ValueError(
f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}" f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}"
@ -596,10 +608,18 @@ def bpf_skb_store_bytes_emitter(
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
args_signature[1],
struct_sym_tab, struct_sym_tab,
) )
from_ptr = get_or_create_ptr_from_arg( from_ptr = get_or_create_ptr_from_arg(
func, module, call.args[1], builder, local_sym_tab, map_sym_tab, struct_sym_tab func,
module,
call.args[1],
builder,
local_sym_tab,
map_sym_tab,
args_signature[2],
struct_sym_tab,
) )
len_val = get_int_value_from_arg( len_val = get_int_value_from_arg(
call.args[2], call.args[2],
@ -608,6 +628,7 @@ def bpf_skb_store_bytes_emitter(
builder, builder,
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
args_signature[3],
struct_sym_tab, struct_sym_tab,
) )
if len(call.args) == 4: if len(call.args) == 4:
@ -617,13 +638,7 @@ def bpf_skb_store_bytes_emitter(
flags = ir.Constant(ir.IntType(64), flags_val) flags = ir.Constant(ir.IntType(64), flags_val)
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
[ args_signature,
ir.PointerType(), # skb
ir.IntType(32), # offset
ir.PointerType(), # from
ir.IntType(32), # len
ir.IntType(64), # flags
],
var_arg=False, var_arg=False,
) )
fn_ptr = builder.inttoptr( fn_ptr = builder.inttoptr(

View File

@ -50,7 +50,7 @@ class HelperHandlerRegistry:
def get_param_type(cls, helper_name, index): def get_param_type(cls, helper_name, index):
"""Get the type of a parameter of a helper function by the index""" """Get the type of a parameter of a helper function by the index"""
signature = cls.get_signature(helper_name) signature = cls.get_signature(helper_name)
if signature and 0 <= index < len(signature.arg_types): if signature and signature.arg_types and 0 <= index < len(signature.arg_types):
return signature.arg_types[index] return signature.arg_types[index]
return None return None

View File

@ -14,26 +14,43 @@ class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab""" """Manage the temporary helper variables in local_sym_tab"""
def __init__(self): def __init__(self):
self._counter = 0 self._counters = {}
@property @property
def counter(self): def counter(self):
return self._counter return sum(self._counter.values())
def reset(self): def reset(self):
self._counter = 0 self._counters.clear()
logger.debug("Scratch pool counter reset to 0") logger.debug("Scratch pool counter reset to 0")
def get_next_temp(self, local_sym_tab): def _get_type_name(self, ir_type):
temp_name = f"__helper_temp_{self._counter}" if isinstance(ir_type, ir.PointerType):
self._counter += 1 return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0
counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1
if temp_name not in local_sym_tab: if temp_name not in local_sym_tab:
raise ValueError( raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. " f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}" f"Type: {type_name} Counter: {counter}"
) )
logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name return local_sym_tab[temp_name].var, temp_name
@ -60,24 +77,35 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
"""Create a pointer to an integer constant.""" """Create a pointer to an integer constant."""
# Default to 64-bit integer int_type = ir.IntType(int_width)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type)
logger.info(f"Using temp variable '{temp_name}' for int constant {value}") logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(ir.IntType(int_width), value) const_val = ir.Constant(int_type, value)
builder.store(const_val, ptr) builder.store(const_val, ptr)
return ptr return ptr
def get_or_create_ptr_from_arg( def get_or_create_ptr_from_arg(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None func,
module,
arg,
builder,
local_sym_tab,
map_sym_tab,
expected_type=None,
struct_sym_tab=None,
): ):
"""Extract or create pointer from the call arguments.""" """Extract or create pointer from the call arguments."""
if isinstance(arg, ast.Name): if isinstance(arg, ast.Name):
# Stack space is already allocated
ptr = get_var_ptr_from_name(arg.id, local_sym_tab) ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab) if expected_type and isinstance(expected_type, ir.IntType):
int_width = expected_type.width
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width)
else: else:
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
# Evaluate the expression and store the result in a temp variable # Evaluate the expression and store the result in a temp variable
val = get_operand_value( val = get_operand_value(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
@ -85,11 +113,14 @@ def get_or_create_ptr_from_arg(
if val is None: if val is None:
raise ValueError("Failed to evaluate expression for helper arg.") raise ValueError("Failed to evaluate expression for helper arg.")
# NOTE: We assume the result is an int64 for now ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type)
# if isinstance(arg, ast.Attribute):
# return val
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.info(f"Using temp variable '{temp_name}' for expression result") logger.info(f"Using temp variable '{temp_name}' for expression result")
if (
isinstance(val.type, ir.IntType)
and expected_type
and val.type.width > expected_type.width
):
val = builder.trunc(val, expected_type)
builder.store(val, ptr) builder.store(val, ptr)
return ptr return ptr