18 Commits

Author SHA1 Message Date
cf99b3bb9a Fix call to get_or_create_ptr_from_arg for probe_read_str 2025-11-07 19:16:48 +05:30
6c85b248ce Init sz in get_or_create_ptr_from_arg 2025-11-07 19:03:21 +05:30
b5a3494cc6 Fix typo in get_or_create_ptr_from_arg 2025-11-07 19:01:40 +05:30
be62972974 Fix ScratchPoolManager::counter 2025-11-07 19:00:57 +05:30
2f4a7d2f90 Remove get_struct_char_array_ptr in favour of get_char_array_ptr_and_size, wrap it in get_or_crate_ptr_from_arg to use in bpf_helper_handler 2025-11-07 18:54:59 +05:30
3ccd3f767e Add expected types for pointer creation of args in probe_read handler 2025-11-06 19:59:04 +05:30
2e37726922 Add signature relection for all helper handlers except print 2025-11-06 19:47:57 +05:30
5b36726b7d Make bpf_skb_store_bytes work 2025-11-05 20:02:39 +05:30
3e6cea2b67 Move get_struct_char_array_ptr from helper/printk_formatter to helper/helper_utils, enable array to ptr conversion in skb_store_bytes 2025-11-05 19:10:58 +05:30
338d4994d8 Fix count_temps_in_call to only look for Pointer args of a helper_sig 2025-11-05 17:36:37 +05:30
3078d4224d Add typed scratch space support to the bpf_skb_store_bytes helper 2025-11-04 16:09:11 +05:30
7d29790f00 Make use of new get_next_temp in helpers 2025-11-04 16:02:56 +05:30
963e2a8171 Change ScratchPoolManager to use typed scratch space 2025-11-04 14:16:44 +05:30
123a92af1d Change allocation pass to generate typed temp variables 2025-11-04 06:20:39 +05:30
752f564d3f Change count_temps_in_call to return hashmap of types 2025-11-04 05:40:22 +05:30
d8cddb9799 Add signature extraction to HelperHandlerRegistry 2025-11-04 05:19:22 +05:30
33e18f6d6d Introduce HelperSignature in HelperHandlerRegistry 2025-11-03 21:21:13 +05:30
5e371787eb Fix the number of args for skb_store_bytes by making the first arg implicit 2025-11-03 21:11:16 +05:30
7 changed files with 303 additions and 121 deletions

View File

@ -199,17 +199,33 @@ def _allocate_for_binop(builder, var_name, local_sym_tab):
logger.info(f"Pre-allocated {var_name} for binop result") logger.info(f"Pre-allocated {var_name} for binop result")
def _get_type_name(ir_type):
"""Get a string representation of an IR type."""
if isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.PointerType):
return "ptr"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{_get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def allocate_temp_pool(builder, max_temps, local_sym_tab): def allocate_temp_pool(builder, max_temps, local_sym_tab):
"""Allocate the temporary scratch space pool for helper arguments.""" """Allocate the temporary scratch space pool for helper arguments."""
if max_temps == 0: if not max_temps:
logger.info("No temp pool allocation needed")
return return
logger.info(f"Allocating temp pool of {max_temps} variables") for tmp_type, cnt in max_temps.items():
for i in range(max_temps): type_name = _get_type_name(tmp_type)
temp_name = f"__helper_temp_{i}" logger.info(f"Allocating temp pool of {cnt} variables of type {type_name}")
temp_var = builder.alloca(ir.IntType(64), name=temp_name) for i in range(cnt):
temp_var.align = 8 temp_name = f"__helper_temp_{type_name}_{i}"
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) temp_var = builder.alloca(tmp_type, name=temp_name)
temp_var.align = _get_alignment(tmp_type)
local_sym_tab[temp_name] = LocalSymbol(temp_var, tmp_type)
logger.debug(f"Allocated temp variable: {temp_name}")
def _allocate_for_name(builder, var_name, rval, local_sym_tab): def _allocate_for_name(builder, var_name, rval, local_sym_tab):

View File

@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
def count_temps_in_call(call_node, local_sym_tab): def count_temps_in_call(call_node, local_sym_tab):
"""Count the number of temporary variables needed for a function call.""" """Count the number of temporary variables needed for a function call."""
count = 0 count = {}
is_helper = False is_helper = False
# NOTE: We exclude print calls for now # NOTE: We exclude print calls for now
@ -43,21 +43,28 @@ def count_temps_in_call(call_node, local_sym_tab):
and call_node.func.id != "print" and call_node.func.id != "print"
): ):
is_helper = True is_helper = True
func_name = call_node.func.id
elif isinstance(call_node.func, ast.Attribute): elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr): if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True is_helper = True
func_name = call_node.func.attr
if not is_helper: if not is_helper:
return 0 return {} # No temps needed
for arg in call_node.args: for arg_idx in range(len(call_node.args)):
# NOTE: Count all non-name arguments # NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument, # For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab # The struct object should already exist in the local_sym_tab
if not isinstance(arg, ast.Name) and not ( arg = call_node.args[arg_idx]
if isinstance(arg, ast.Name) or (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
): ):
count += 1 continue
param_type = HelperHandlerRegistry.get_param_type(func_name, arg_idx)
if isinstance(param_type, ir.PointerType):
pointee_type = param_type.pointee
count[pointee_type] = count.get(pointee_type, 0) + 1
return count return count
@ -93,11 +100,15 @@ def handle_if_allocation(
def allocate_mem( def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
): ):
max_temps_needed = 0 max_temps_needed = {}
def merge_type_counts(count_dict):
nonlocal max_temps_needed
for typ, cnt in count_dict.items():
max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt)
def update_max_temps_for_stmt(stmt): def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed nonlocal max_temps_needed
temps_needed = 0
if isinstance(stmt, ast.If): if isinstance(stmt, ast.If):
for s in stmt.body: for s in stmt.body:
@ -106,10 +117,13 @@ def allocate_mem(
update_max_temps_for_stmt(s) update_max_temps_for_stmt(s)
return return
stmt_temps = {}
for node in ast.walk(stmt): for node in ast.walk(stmt):
if isinstance(node, ast.Call): if isinstance(node, ast.Call):
temps_needed += count_temps_in_call(node, local_sym_tab) call_temps = count_temps_in_call(node, local_sym_tab)
max_temps_needed = max(max_temps_needed, temps_needed) for typ, cnt in call_temps.items():
stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt
merge_type_counts(stmt_temps)
for stmt in body: for stmt in body:
update_max_temps_for_stmt(stmt) update_max_temps_for_stmt(stmt)

View File

@ -8,7 +8,6 @@ from .helper_utils import (
get_flags_val, get_flags_val,
get_data_ptr_and_size, get_data_ptr_and_size,
get_buffer_ptr_and_size, get_buffer_ptr_and_size,
get_char_array_ptr_and_size,
get_ptr_from_arg, get_ptr_from_arg,
get_int_value_from_arg, get_int_value_from_arg,
) )
@ -37,7 +36,11 @@ class BPFHelperID(Enum):
BPF_PROBE_READ_KERNEL_STR = 115 BPF_PROBE_READ_KERNEL_STR = 115
@HelperHandlerRegistry.register("ktime") @HelperHandlerRegistry.register(
"ktime",
param_types=[],
return_type=ir.IntType(64),
)
def bpf_ktime_get_ns_emitter( def bpf_ktime_get_ns_emitter(
call, call,
map_ptr, map_ptr,
@ -60,7 +63,11 @@ def bpf_ktime_get_ns_emitter(
return result, ir.IntType(64) return result, ir.IntType(64)
@HelperHandlerRegistry.register("lookup") @HelperHandlerRegistry.register(
"lookup",
param_types=[ir.PointerType(ir.IntType(64))],
return_type=ir.PointerType(ir.IntType(64)),
)
def bpf_map_lookup_elem_emitter( def bpf_map_lookup_elem_emitter(
call, call,
map_ptr, map_ptr,
@ -102,6 +109,7 @@ def bpf_map_lookup_elem_emitter(
return result, ir.PointerType() return result, ir.PointerType()
# NOTE: This has special handling so we won't reflect the signature here.
@HelperHandlerRegistry.register("print") @HelperHandlerRegistry.register("print")
def bpf_printk_emitter( def bpf_printk_emitter(
call, call,
@ -150,7 +158,15 @@ def bpf_printk_emitter(
return True return True
@HelperHandlerRegistry.register("update") @HelperHandlerRegistry.register(
"update",
param_types=[
ir.PointerType(ir.IntType(64)),
ir.PointerType(ir.IntType(64)),
ir.IntType(64),
],
return_type=ir.PointerType(ir.IntType(64)),
)
def bpf_map_update_elem_emitter( def bpf_map_update_elem_emitter(
call, call,
map_ptr, map_ptr,
@ -205,7 +221,11 @@ def bpf_map_update_elem_emitter(
return result, None return result, None
@HelperHandlerRegistry.register("delete") @HelperHandlerRegistry.register(
"delete",
param_types=[ir.PointerType(ir.IntType(64))],
return_type=ir.PointerType(ir.IntType(64)),
)
def bpf_map_delete_elem_emitter( def bpf_map_delete_elem_emitter(
call, call,
map_ptr, map_ptr,
@ -245,7 +265,11 @@ def bpf_map_delete_elem_emitter(
return result, None return result, None
@HelperHandlerRegistry.register("comm") @HelperHandlerRegistry.register(
"comm",
param_types=[ir.PointerType(ir.IntType(8))],
return_type=ir.IntType(64),
)
def bpf_get_current_comm_emitter( def bpf_get_current_comm_emitter(
call, call,
map_ptr, map_ptr,
@ -302,7 +326,11 @@ def bpf_get_current_comm_emitter(
return result, None return result, None
@HelperHandlerRegistry.register("pid") @HelperHandlerRegistry.register(
"pid",
param_types=[],
return_type=ir.IntType(64),
)
def bpf_get_current_pid_tgid_emitter( def bpf_get_current_pid_tgid_emitter(
call, call,
map_ptr, map_ptr,
@ -330,7 +358,11 @@ def bpf_get_current_pid_tgid_emitter(
return pid, ir.IntType(64) return pid, ir.IntType(64)
@HelperHandlerRegistry.register("output") @HelperHandlerRegistry.register(
"output",
param_types=[ir.PointerType(ir.IntType(8))],
return_type=ir.IntType(64),
)
def bpf_perf_event_output_handler( def bpf_perf_event_output_handler(
call, call,
map_ptr, map_ptr,
@ -405,7 +437,14 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr):
return result return result
@HelperHandlerRegistry.register("probe_read_str") @HelperHandlerRegistry.register(
"probe_read_str",
param_types=[
ir.PointerType(ir.IntType(8)),
ir.PointerType(ir.IntType(8)),
],
return_type=ir.IntType(64),
)
def bpf_probe_read_kernel_str_emitter( def bpf_probe_read_kernel_str_emitter(
call, call,
map_ptr, map_ptr,
@ -424,8 +463,8 @@ def bpf_probe_read_kernel_str_emitter(
) )
# Get destination buffer (char array -> i8*) # Get destination buffer (char array -> i8*)
dst_ptr, dst_size = get_char_array_ptr_and_size( dst_ptr, dst_size = get_or_create_ptr_from_arg(
call.args[0], builder, local_sym_tab, struct_sym_tab func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
) )
# Get source pointer (evaluate expression) # Get source pointer (evaluate expression)
@ -440,7 +479,11 @@ def bpf_probe_read_kernel_str_emitter(
return result, ir.IntType(64) return result, ir.IntType(64)
@HelperHandlerRegistry.register("random") @HelperHandlerRegistry.register(
"random",
param_types=[],
return_type=ir.IntType(32),
)
def bpf_get_prandom_u32_emitter( def bpf_get_prandom_u32_emitter(
call, call,
map_ptr, map_ptr,
@ -462,7 +505,15 @@ def bpf_get_prandom_u32_emitter(
return result, ir.IntType(32) return result, ir.IntType(32)
@HelperHandlerRegistry.register("probe_read") @HelperHandlerRegistry.register(
"probe_read",
param_types=[
ir.PointerType(ir.IntType(8)),
ir.IntType(32),
ir.PointerType(ir.IntType(8)),
],
return_type=ir.IntType(64),
)
def bpf_probe_read_emitter( def bpf_probe_read_emitter(
call, call,
map_ptr, map_ptr,
@ -481,7 +532,14 @@ def bpf_probe_read_emitter(
logger.warn("Expected 3 args for probe_read helper") logger.warn("Expected 3 args for probe_read helper")
return return
dst_ptr = get_or_create_ptr_from_arg( dst_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab func,
module,
call.args[0],
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.IntType(8),
) )
size_val = get_int_value_from_arg( size_val = get_int_value_from_arg(
call.args[1], call.args[1],
@ -493,7 +551,14 @@ def bpf_probe_read_emitter(
struct_sym_tab, struct_sym_tab,
) )
src_ptr = get_or_create_ptr_from_arg( src_ptr = get_or_create_ptr_from_arg(
func, module, call.args[2], builder, local_sym_tab, map_sym_tab, struct_sym_tab func,
module,
call.args[2],
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
ir.IntType(8),
) )
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
@ -517,7 +582,11 @@ def bpf_probe_read_emitter(
return result, ir.IntType(64) return result, ir.IntType(64)
@HelperHandlerRegistry.register("smp_processor_id") @HelperHandlerRegistry.register(
"smp_processor_id",
param_types=[],
return_type=ir.IntType(32),
)
def bpf_get_smp_processor_id_emitter( def bpf_get_smp_processor_id_emitter(
call, call,
map_ptr, map_ptr,
@ -540,7 +609,11 @@ def bpf_get_smp_processor_id_emitter(
return result, ir.IntType(32) return result, ir.IntType(32)
@HelperHandlerRegistry.register("uid") @HelperHandlerRegistry.register(
"uid",
param_types=[],
return_type=ir.IntType(64),
)
def bpf_get_current_uid_gid_emitter( def bpf_get_current_uid_gid_emitter(
call, call,
map_ptr, map_ptr,
@ -567,7 +640,16 @@ def bpf_get_current_uid_gid_emitter(
return pid, ir.IntType(64) return pid, ir.IntType(64)
@HelperHandlerRegistry.register("skb_store_bytes") @HelperHandlerRegistry.register(
"skb_store_bytes",
param_types=[
ir.IntType(32),
ir.PointerType(ir.IntType(8)),
ir.IntType(32),
ir.IntType(64),
],
return_type=ir.IntType(64),
)
def bpf_skb_store_bytes_emitter( def bpf_skb_store_bytes_emitter(
call, call,
map_ptr, map_ptr,
@ -583,16 +665,22 @@ def bpf_skb_store_bytes_emitter(
Expected call signature: skb_store_bytes(skb, offset, from, len, flags) Expected call signature: skb_store_bytes(skb, offset, from, len, flags)
""" """
if len(call.args) not in (4, 5): args_signature = [
ir.PointerType(), # skb pointer
ir.IntType(32), # offset
ir.PointerType(), # from
ir.IntType(32), # len
ir.IntType(64), # flags
]
if len(call.args) not in (3, 4):
raise ValueError( raise ValueError(
f"skb_store_bytes expects 4 or 5 args (skb, offset, from, len, flags), got {len(call.args)}" f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}"
) )
skb_ptr = get_or_create_ptr_from_arg( skb_ptr = func.args[0] # First argument to the function is skb
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
offset_val = get_int_value_from_arg( offset_val = get_int_value_from_arg(
call.args[1], call.args[0],
func, func,
module, module,
builder, builder,
@ -601,10 +689,17 @@ def bpf_skb_store_bytes_emitter(
struct_sym_tab, struct_sym_tab,
) )
from_ptr = get_or_create_ptr_from_arg( from_ptr = get_or_create_ptr_from_arg(
func, module, call.args[2], builder, local_sym_tab, map_sym_tab, struct_sym_tab func,
module,
call.args[1],
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
args_signature[2],
) )
len_val = get_int_value_from_arg( len_val = get_int_value_from_arg(
call.args[3], call.args[2],
func, func,
module, module,
builder, builder,
@ -612,19 +707,14 @@ def bpf_skb_store_bytes_emitter(
map_sym_tab, map_sym_tab,
struct_sym_tab, struct_sym_tab,
) )
if len(call.args) == 5: if len(call.args) == 4:
flags_val = get_flags_val(call.args[4], builder, local_sym_tab) flags_val = get_flags_val(call.args[3], builder, local_sym_tab)
else: else:
flags_val = ir.Constant(ir.IntType(64), 0) flags_val = 0
flags = ir.Constant(ir.IntType(64), flags_val)
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
[ args_signature,
ir.PointerType(), # skb
ir.IntType(32), # offset
ir.PointerType(), # from
ir.IntType(32), # len
ir.IntType(64), # flags
],
var_arg=False, var_arg=False,
) )
fn_ptr = builder.inttoptr( fn_ptr = builder.inttoptr(
@ -638,7 +728,7 @@ def bpf_skb_store_bytes_emitter(
builder.trunc(offset_val, ir.IntType(32)), builder.trunc(offset_val, ir.IntType(32)),
builder.bitcast(from_ptr, ir.PointerType()), builder.bitcast(from_ptr, ir.PointerType()),
builder.trunc(len_val, ir.IntType(32)), builder.trunc(len_val, ir.IntType(32)),
flags_val, flags,
], ],
tail=False, tail=False,
) )

View File

@ -1,17 +1,31 @@
from dataclasses import dataclass
from llvmlite import ir
from typing import Callable from typing import Callable
@dataclass
class HelperSignature:
"""Signature of a BPF helper function"""
arg_types: list[ir.Type]
return_type: ir.Type
func: Callable
class HelperHandlerRegistry: class HelperHandlerRegistry:
"""Registry for BPF helpers""" """Registry for BPF helpers"""
_handlers: dict[str, Callable] = {} _handlers: dict[str, HelperSignature] = {}
@classmethod @classmethod
def register(cls, helper_name): def register(cls, helper_name, param_types=None, return_type=None):
"""Decorator to register a handler function for a helper""" """Decorator to register a handler function for a helper"""
def decorator(func): def decorator(func):
cls._handlers[helper_name] = func helper_sig = HelperSignature(
arg_types=param_types, return_type=return_type, func=func
)
cls._handlers[helper_name] = helper_sig
return func return func
return decorator return decorator
@ -19,9 +33,29 @@ class HelperHandlerRegistry:
@classmethod @classmethod
def get_handler(cls, helper_name): def get_handler(cls, helper_name):
"""Get the handler function for a helper""" """Get the handler function for a helper"""
return cls._handlers.get(helper_name) handler = cls._handlers.get(helper_name)
return handler.func if handler else None
@classmethod @classmethod
def has_handler(cls, helper_name): def has_handler(cls, helper_name):
"""Check if a handler function is registered for a helper""" """Check if a handler function is registered for a helper"""
return helper_name in cls._handlers return helper_name in cls._handlers
@classmethod
def get_signature(cls, helper_name):
"""Get the signature of a helper function"""
return cls._handlers.get(helper_name)
@classmethod
def get_param_type(cls, helper_name, index):
"""Get the type of a parameter of a helper function by the index"""
signature = cls.get_signature(helper_name)
if signature and signature.arg_types and 0 <= index < len(signature.arg_types):
return signature.arg_types[index]
return None
@classmethod
def get_return_type(cls, helper_name):
"""Get the return type of a helper function"""
signature = cls.get_signature(helper_name)
return signature.return_type if signature else None

View File

@ -14,26 +14,43 @@ class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab""" """Manage the temporary helper variables in local_sym_tab"""
def __init__(self): def __init__(self):
self._counter = 0 self._counters = {}
@property @property
def counter(self): def counter(self):
return self._counter return sum(self._counters.values())
def reset(self): def reset(self):
self._counter = 0 self._counters.clear()
logger.debug("Scratch pool counter reset to 0") logger.debug("Scratch pool counter reset to 0")
def get_next_temp(self, local_sym_tab): def _get_type_name(self, ir_type):
temp_name = f"__helper_temp_{self._counter}" if isinstance(ir_type, ir.PointerType):
self._counter += 1 return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0
counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1
if temp_name not in local_sym_tab: if temp_name not in local_sym_tab:
raise ValueError( raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. " f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}" f"Type: {type_name} Counter: {counter}"
) )
logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name return local_sym_tab[temp_name].var, temp_name
@ -60,24 +77,73 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
"""Create a pointer to an integer constant.""" """Create a pointer to an integer constant."""
# Default to 64-bit integer int_type = ir.IntType(int_width)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type)
logger.info(f"Using temp variable '{temp_name}' for int constant {value}") logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(ir.IntType(int_width), value) const_val = ir.Constant(int_type, value)
builder.store(const_val, ptr) builder.store(const_val, ptr)
return ptr return ptr
def get_or_create_ptr_from_arg( def get_or_create_ptr_from_arg(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None func,
module,
arg,
builder,
local_sym_tab,
map_sym_tab,
struct_sym_tab=None,
expected_type=None,
): ):
"""Extract or create pointer from the call arguments.""" """Extract or create pointer from the call arguments."""
logger.info(f"Getting pointer from arg: {ast.dump(arg)}")
sz = None
if isinstance(arg, ast.Name): if isinstance(arg, ast.Name):
# Stack space is already allocated
ptr = get_var_ptr_from_name(arg.id, local_sym_tab) ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab) int_width = 64 # Default to i64
if expected_type and isinstance(expected_type, ir.IntType):
int_width = expected_type.width
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width)
elif isinstance(arg, ast.Attribute):
# A struct field
struct_name = arg.value.id
field_name = arg.attr
if not local_sym_tab or struct_name not in local_sym_tab:
raise ValueError(f"Struct '{struct_name}' not found")
struct_type = local_sym_tab[struct_name].metadata
if not struct_sym_tab or struct_type not in struct_sym_tab:
raise ValueError(f"Struct type '{struct_type}' not found")
struct_info = struct_sym_tab[struct_type]
if field_name not in struct_info.fields:
raise ValueError(
f"Field '{field_name}' not found in struct '{struct_name}'"
)
field_type = struct_info.field_type(field_name)
struct_ptr = local_sym_tab[struct_name].var
# Special handling for char arrays
if (
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
):
ptr, sz = get_char_array_ptr_and_size(
arg, builder, local_sym_tab, struct_sym_tab
)
if not ptr:
raise ValueError("Failed to get char array pointer from struct field")
else:
ptr = struct_info.gep(builder, struct_ptr, field_name)
else: else:
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
# Evaluate the expression and store the result in a temp variable # Evaluate the expression and store the result in a temp variable
val = get_operand_value( val = get_operand_value(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
@ -85,13 +151,20 @@ def get_or_create_ptr_from_arg(
if val is None: if val is None:
raise ValueError("Failed to evaluate expression for helper arg.") raise ValueError("Failed to evaluate expression for helper arg.")
# NOTE: We assume the result is an int64 for now ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type)
# if isinstance(arg, ast.Attribute):
# return val
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.info(f"Using temp variable '{temp_name}' for expression result") logger.info(f"Using temp variable '{temp_name}' for expression result")
if (
isinstance(val.type, ir.IntType)
and expected_type
and val.type.width > expected_type.width
):
val = builder.trunc(val, expected_type)
builder.store(val, ptr) builder.store(val, ptr)
# NOTE: For char arrays, also return size
if sz:
return ptr, sz
return ptr return ptr

View File

@ -47,7 +47,7 @@ def uid():
return ctypes.c_int32(0) return ctypes.c_int32(0)
def skb_store_bytes(skb, offset, from_buf, size, flags=0): def skb_store_bytes(offset, from_buf, size, flags=0):
"""store bytes into a socket buffer""" """store bytes into a socket buffer"""
return ctypes.c_int64(0) return ctypes.c_int64(0)

View File

@ -4,6 +4,7 @@ import logging
from llvmlite import ir from llvmlite import ir
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
from pythonbpf.helper.helper_utils import get_char_array_ptr_and_size
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -219,7 +220,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
"""Evaluate and prepare an expression to use as an arg for bpf_printk.""" """Evaluate and prepare an expression to use as an arg for bpf_printk."""
# Special case: struct field char array needs pointer to first element # Special case: struct field char array needs pointer to first element
char_array_ptr = _get_struct_char_array_ptr( char_array_ptr, _ = get_char_array_ptr_and_size(
expr, builder, local_sym_tab, struct_sym_tab expr, builder, local_sym_tab, struct_sym_tab
) )
if char_array_ptr: if char_array_ptr:
@ -242,52 +243,6 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
return ir.Constant(ir.IntType(64), 0) return ir.Constant(ir.IntType(64), 0)
def _get_struct_char_array_ptr(expr, builder, local_sym_tab, struct_sym_tab):
"""Get pointer to first element of char array in struct field, or None."""
if not (isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name)):
return None
var_name = expr.value.id
field_name = expr.attr
# Check if it's a valid struct field
if not (
local_sym_tab
and var_name in local_sym_tab
and struct_sym_tab
and local_sym_tab[var_name].metadata in struct_sym_tab
):
return None
struct_type = local_sym_tab[var_name].metadata
struct_info = struct_sym_tab[struct_type]
if field_name not in struct_info.fields:
return None
field_type = struct_info.field_type(field_name)
# Check if it's a char array
is_char_array = (
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
)
if not is_char_array:
return None
# Get field pointer and GEP to first element: [N x i8]* -> i8*
struct_ptr = local_sym_tab[var_name].var
field_ptr = struct_info.gep(builder, struct_ptr, field_name)
return builder.gep(
field_ptr,
[ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)],
inbounds=True,
)
def _handle_pointer_arg(val, func, builder): def _handle_pointer_arg(val, func, builder):
"""Convert pointer type for bpf_printk.""" """Convert pointer type for bpf_printk."""
target, depth = get_base_type_and_depth(val.type) target, depth = get_base_type_and_depth(val.type)