3 Commits

4 changed files with 73 additions and 27 deletions

View File

@ -50,7 +50,7 @@ def count_temps_in_call(call_node, local_sym_tab):
func_name = call_node.func.attr
if not is_helper:
return 0
return {} # No temps needed
for arg_idx in range(len(call_node.args)):
# NOTE: Count all non-name arguments

View File

@ -567,7 +567,11 @@ def bpf_get_current_uid_gid_emitter(
return pid, ir.IntType(64)
@HelperHandlerRegistry.register("skb_store_bytes")
@HelperHandlerRegistry.register(
"skb_store_bytes",
param_types=[ir.IntType(32), ir.PointerType(), ir.IntType(32), ir.IntType(64)],
return_type=ir.IntType(64),
)
def bpf_skb_store_bytes_emitter(
call,
map_ptr,
@ -583,6 +587,14 @@ def bpf_skb_store_bytes_emitter(
Expected call signature: skb_store_bytes(skb, offset, from, len, flags)
"""
args_signature = [
ir.PointerType(), # skb pointer
ir.IntType(32), # offset
ir.PointerType(), # from
ir.IntType(32), # len
ir.IntType(64), # flags
]
if len(call.args) not in (3, 4):
raise ValueError(
f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}"
@ -596,10 +608,18 @@ def bpf_skb_store_bytes_emitter(
builder,
local_sym_tab,
map_sym_tab,
args_signature[1],
struct_sym_tab,
)
from_ptr = get_or_create_ptr_from_arg(
func, module, call.args[1], builder, local_sym_tab, map_sym_tab, struct_sym_tab
func,
module,
call.args[1],
builder,
local_sym_tab,
map_sym_tab,
args_signature[2],
struct_sym_tab,
)
len_val = get_int_value_from_arg(
call.args[2],
@ -608,6 +628,7 @@ def bpf_skb_store_bytes_emitter(
builder,
local_sym_tab,
map_sym_tab,
args_signature[3],
struct_sym_tab,
)
if len(call.args) == 4:
@ -617,13 +638,7 @@ def bpf_skb_store_bytes_emitter(
flags = ir.Constant(ir.IntType(64), flags_val)
fn_type = ir.FunctionType(
ir.IntType(64),
[
ir.PointerType(), # skb
ir.IntType(32), # offset
ir.PointerType(), # from
ir.IntType(32), # len
ir.IntType(64), # flags
],
args_signature,
var_arg=False,
)
fn_ptr = builder.inttoptr(

View File

@ -50,7 +50,7 @@ class HelperHandlerRegistry:
def get_param_type(cls, helper_name, index):
"""Get the type of a parameter of a helper function by the index"""
signature = cls.get_signature(helper_name)
if signature and 0 <= index < len(signature.arg_types):
if signature and signature.arg_types and 0 <= index < len(signature.arg_types):
return signature.arg_types[index]
return None

View File

@ -14,26 +14,43 @@ class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab"""
def __init__(self):
self._counter = 0
self._counters = {}
@property
def counter(self):
return self._counter
return sum(self._counter.values())
def reset(self):
self._counter = 0
self._counters.clear()
logger.debug("Scratch pool counter reset to 0")
def get_next_temp(self, local_sym_tab):
temp_name = f"__helper_temp_{self._counter}"
self._counter += 1
def _get_type_name(self, ir_type):
if isinstance(ir_type, ir.PointerType):
return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0
counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1
if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}"
f"Type: {type_name} Counter: {counter}"
)
logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name
@ -60,24 +77,35 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
"""Create a pointer to an integer constant."""
# Default to 64-bit integer
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
int_type = ir.IntType(int_width)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type)
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(ir.IntType(int_width), value)
const_val = ir.Constant(int_type, value)
builder.store(const_val, ptr)
return ptr
def get_or_create_ptr_from_arg(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None
func,
module,
arg,
builder,
local_sym_tab,
map_sym_tab,
expected_type=None,
struct_sym_tab=None,
):
"""Extract or create pointer from the call arguments."""
if isinstance(arg, ast.Name):
# Stack space is already allocated
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
if expected_type and isinstance(expected_type, ir.IntType):
int_width = expected_type.width
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width)
else:
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
# Evaluate the expression and store the result in a temp variable
val = get_operand_value(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
@ -85,11 +113,14 @@ def get_or_create_ptr_from_arg(
if val is None:
raise ValueError("Failed to evaluate expression for helper arg.")
# NOTE: We assume the result is an int64 for now
# if isinstance(arg, ast.Attribute):
# return val
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type)
logger.info(f"Using temp variable '{temp_name}' for expression result")
if (
isinstance(val.type, ir.IntType)
and expected_type
and val.type.width > expected_type.width
):
val = builder.trunc(val, expected_type)
builder.store(val, ptr)
return ptr