2 Commits

3 changed files with 31 additions and 17 deletions

View File

@ -583,16 +583,14 @@ def bpf_skb_store_bytes_emitter(
Expected call signature: skb_store_bytes(skb, offset, from, len, flags)
"""
if len(call.args) not in (4, 5):
if len(call.args) not in (3, 4):
raise ValueError(
f"skb_store_bytes expects 4 or 5 args (skb, offset, from, len, flags), got {len(call.args)}"
f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}"
)
skb_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
skb_ptr = func.args[0] # First argument to the function is skb
offset_val = get_int_value_from_arg(
call.args[1],
call.args[0],
func,
module,
builder,
@ -601,10 +599,10 @@ def bpf_skb_store_bytes_emitter(
struct_sym_tab,
)
from_ptr = get_or_create_ptr_from_arg(
func, module, call.args[2], builder, local_sym_tab, map_sym_tab, struct_sym_tab
func, module, call.args[1], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
len_val = get_int_value_from_arg(
call.args[3],
call.args[2],
func,
module,
builder,
@ -612,10 +610,11 @@ def bpf_skb_store_bytes_emitter(
map_sym_tab,
struct_sym_tab,
)
if len(call.args) == 5:
flags_val = get_flags_val(call.args[4], builder, local_sym_tab)
if len(call.args) == 4:
flags_val = get_flags_val(call.args[3], builder, local_sym_tab)
else:
flags_val = ir.Constant(ir.IntType(64), 0)
flags_val = 0
flags = ir.Constant(ir.IntType(64), flags_val)
fn_type = ir.FunctionType(
ir.IntType(64),
[
@ -638,7 +637,7 @@ def bpf_skb_store_bytes_emitter(
builder.trunc(offset_val, ir.IntType(32)),
builder.bitcast(from_ptr, ir.PointerType()),
builder.trunc(len_val, ir.IntType(32)),
flags_val,
flags,
],
tail=False,
)

View File

@ -1,17 +1,31 @@
from dataclasses import dataclass
from llvmlite import ir
from typing import Callable
@dataclass
class HelperSignature:
"""Signature of a BPF helper function"""
arg_types: list[ir.Type]
return_type: ir.Type
func: Callable
class HelperHandlerRegistry:
"""Registry for BPF helpers"""
_handlers: dict[str, Callable] = {}
_handlers: dict[str, HelperSignature] = {}
@classmethod
def register(cls, helper_name):
def register(cls, helper_name, param_types=None, return_type=None):
"""Decorator to register a handler function for a helper"""
def decorator(func):
cls._handlers[helper_name] = func
helper_sig = HelperSignature(
arg_types=param_types, return_type=return_type, func=func
)
cls._handlers[helper_name] = helper_sig
return func
return decorator
@ -19,7 +33,8 @@ class HelperHandlerRegistry:
@classmethod
def get_handler(cls, helper_name):
"""Get the handler function for a helper"""
return cls._handlers.get(helper_name)
handler = cls._handlers.get(helper_name)
return handler.func if handler else None
@classmethod
def has_handler(cls, helper_name):

View File

@ -47,7 +47,7 @@ def uid():
return ctypes.c_int32(0)
def skb_store_bytes(skb, offset, from_buf, size, flags=0):
def skb_store_bytes(offset, from_buf, size, flags=0):
"""store bytes into a socket buffer"""
return ctypes.c_int64(0)