mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2026-02-12 16:10:59 +00:00
Compare commits
3 Commits
123a92af1d
...
3078d4224d
| Author | SHA1 | Date | |
|---|---|---|---|
| 3078d4224d | |||
| 7d29790f00 | |||
| 963e2a8171 |
@ -50,7 +50,7 @@ def count_temps_in_call(call_node, local_sym_tab):
|
|||||||
func_name = call_node.func.attr
|
func_name = call_node.func.attr
|
||||||
|
|
||||||
if not is_helper:
|
if not is_helper:
|
||||||
return 0
|
return {} # No temps needed
|
||||||
|
|
||||||
for arg_idx in range(len(call_node.args)):
|
for arg_idx in range(len(call_node.args)):
|
||||||
# NOTE: Count all non-name arguments
|
# NOTE: Count all non-name arguments
|
||||||
|
|||||||
@ -567,7 +567,11 @@ def bpf_get_current_uid_gid_emitter(
|
|||||||
return pid, ir.IntType(64)
|
return pid, ir.IntType(64)
|
||||||
|
|
||||||
|
|
||||||
@HelperHandlerRegistry.register("skb_store_bytes")
|
@HelperHandlerRegistry.register(
|
||||||
|
"skb_store_bytes",
|
||||||
|
param_types=[ir.IntType(32), ir.PointerType(), ir.IntType(32), ir.IntType(64)],
|
||||||
|
return_type=ir.IntType(64),
|
||||||
|
)
|
||||||
def bpf_skb_store_bytes_emitter(
|
def bpf_skb_store_bytes_emitter(
|
||||||
call,
|
call,
|
||||||
map_ptr,
|
map_ptr,
|
||||||
@ -583,6 +587,14 @@ def bpf_skb_store_bytes_emitter(
|
|||||||
Expected call signature: skb_store_bytes(skb, offset, from, len, flags)
|
Expected call signature: skb_store_bytes(skb, offset, from, len, flags)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
args_signature = [
|
||||||
|
ir.PointerType(), # skb pointer
|
||||||
|
ir.IntType(32), # offset
|
||||||
|
ir.PointerType(), # from
|
||||||
|
ir.IntType(32), # len
|
||||||
|
ir.IntType(64), # flags
|
||||||
|
]
|
||||||
|
|
||||||
if len(call.args) not in (3, 4):
|
if len(call.args) not in (3, 4):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}"
|
f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}"
|
||||||
@ -596,10 +608,18 @@ def bpf_skb_store_bytes_emitter(
|
|||||||
builder,
|
builder,
|
||||||
local_sym_tab,
|
local_sym_tab,
|
||||||
map_sym_tab,
|
map_sym_tab,
|
||||||
|
args_signature[1],
|
||||||
struct_sym_tab,
|
struct_sym_tab,
|
||||||
)
|
)
|
||||||
from_ptr = get_or_create_ptr_from_arg(
|
from_ptr = get_or_create_ptr_from_arg(
|
||||||
func, module, call.args[1], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
func,
|
||||||
|
module,
|
||||||
|
call.args[1],
|
||||||
|
builder,
|
||||||
|
local_sym_tab,
|
||||||
|
map_sym_tab,
|
||||||
|
args_signature[2],
|
||||||
|
struct_sym_tab,
|
||||||
)
|
)
|
||||||
len_val = get_int_value_from_arg(
|
len_val = get_int_value_from_arg(
|
||||||
call.args[2],
|
call.args[2],
|
||||||
@ -608,6 +628,7 @@ def bpf_skb_store_bytes_emitter(
|
|||||||
builder,
|
builder,
|
||||||
local_sym_tab,
|
local_sym_tab,
|
||||||
map_sym_tab,
|
map_sym_tab,
|
||||||
|
args_signature[3],
|
||||||
struct_sym_tab,
|
struct_sym_tab,
|
||||||
)
|
)
|
||||||
if len(call.args) == 4:
|
if len(call.args) == 4:
|
||||||
@ -617,13 +638,7 @@ def bpf_skb_store_bytes_emitter(
|
|||||||
flags = ir.Constant(ir.IntType(64), flags_val)
|
flags = ir.Constant(ir.IntType(64), flags_val)
|
||||||
fn_type = ir.FunctionType(
|
fn_type = ir.FunctionType(
|
||||||
ir.IntType(64),
|
ir.IntType(64),
|
||||||
[
|
args_signature,
|
||||||
ir.PointerType(), # skb
|
|
||||||
ir.IntType(32), # offset
|
|
||||||
ir.PointerType(), # from
|
|
||||||
ir.IntType(32), # len
|
|
||||||
ir.IntType(64), # flags
|
|
||||||
],
|
|
||||||
var_arg=False,
|
var_arg=False,
|
||||||
)
|
)
|
||||||
fn_ptr = builder.inttoptr(
|
fn_ptr = builder.inttoptr(
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class HelperHandlerRegistry:
|
|||||||
def get_param_type(cls, helper_name, index):
|
def get_param_type(cls, helper_name, index):
|
||||||
"""Get the type of a parameter of a helper function by the index"""
|
"""Get the type of a parameter of a helper function by the index"""
|
||||||
signature = cls.get_signature(helper_name)
|
signature = cls.get_signature(helper_name)
|
||||||
if signature and 0 <= index < len(signature.arg_types):
|
if signature and signature.arg_types and 0 <= index < len(signature.arg_types):
|
||||||
return signature.arg_types[index]
|
return signature.arg_types[index]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -14,26 +14,43 @@ class ScratchPoolManager:
|
|||||||
"""Manage the temporary helper variables in local_sym_tab"""
|
"""Manage the temporary helper variables in local_sym_tab"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._counter = 0
|
self._counters = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def counter(self):
|
def counter(self):
|
||||||
return self._counter
|
return sum(self._counter.values())
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._counter = 0
|
self._counters.clear()
|
||||||
logger.debug("Scratch pool counter reset to 0")
|
logger.debug("Scratch pool counter reset to 0")
|
||||||
|
|
||||||
def get_next_temp(self, local_sym_tab):
|
def _get_type_name(self, ir_type):
|
||||||
temp_name = f"__helper_temp_{self._counter}"
|
if isinstance(ir_type, ir.PointerType):
|
||||||
self._counter += 1
|
return "ptr"
|
||||||
|
elif isinstance(ir_type, ir.IntType):
|
||||||
|
return f"i{ir_type.width}"
|
||||||
|
elif isinstance(ir_type, ir.ArrayType):
|
||||||
|
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
|
||||||
|
else:
|
||||||
|
return str(ir_type).replace(" ", "")
|
||||||
|
|
||||||
|
def get_next_temp(self, local_sym_tab, expected_type=None):
|
||||||
|
# Default to i64 if no expected type provided
|
||||||
|
type_name = self._get_type_name(expected_type) if expected_type else "i64"
|
||||||
|
if type_name not in self._counters:
|
||||||
|
self._counters[type_name] = 0
|
||||||
|
|
||||||
|
counter = self._counters[type_name]
|
||||||
|
temp_name = f"__helper_temp_{type_name}_{counter}"
|
||||||
|
self._counters[type_name] += 1
|
||||||
|
|
||||||
if temp_name not in local_sym_tab:
|
if temp_name not in local_sym_tab:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Scratch pool exhausted or inadequate: {temp_name}. "
|
f"Scratch pool exhausted or inadequate: {temp_name}. "
|
||||||
f"Current counter: {self._counter}"
|
f"Type: {type_name} Counter: {counter}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Using {temp_name} for type {type_name}")
|
||||||
return local_sym_tab[temp_name].var, temp_name
|
return local_sym_tab[temp_name].var, temp_name
|
||||||
|
|
||||||
|
|
||||||
@ -60,24 +77,35 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
|
|||||||
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||||
"""Create a pointer to an integer constant."""
|
"""Create a pointer to an integer constant."""
|
||||||
|
|
||||||
# Default to 64-bit integer
|
int_type = ir.IntType(int_width)
|
||||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type)
|
||||||
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
|
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
|
||||||
const_val = ir.Constant(ir.IntType(int_width), value)
|
const_val = ir.Constant(int_type, value)
|
||||||
builder.store(const_val, ptr)
|
builder.store(const_val, ptr)
|
||||||
return ptr
|
return ptr
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_ptr_from_arg(
|
def get_or_create_ptr_from_arg(
|
||||||
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None
|
func,
|
||||||
|
module,
|
||||||
|
arg,
|
||||||
|
builder,
|
||||||
|
local_sym_tab,
|
||||||
|
map_sym_tab,
|
||||||
|
expected_type=None,
|
||||||
|
struct_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""Extract or create pointer from the call arguments."""
|
"""Extract or create pointer from the call arguments."""
|
||||||
|
|
||||||
if isinstance(arg, ast.Name):
|
if isinstance(arg, ast.Name):
|
||||||
|
# Stack space is already allocated
|
||||||
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
||||||
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
||||||
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
|
if expected_type and isinstance(expected_type, ir.IntType):
|
||||||
|
int_width = expected_type.width
|
||||||
|
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width)
|
||||||
else:
|
else:
|
||||||
|
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
|
||||||
# Evaluate the expression and store the result in a temp variable
|
# Evaluate the expression and store the result in a temp variable
|
||||||
val = get_operand_value(
|
val = get_operand_value(
|
||||||
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||||
@ -85,11 +113,14 @@ def get_or_create_ptr_from_arg(
|
|||||||
if val is None:
|
if val is None:
|
||||||
raise ValueError("Failed to evaluate expression for helper arg.")
|
raise ValueError("Failed to evaluate expression for helper arg.")
|
||||||
|
|
||||||
# NOTE: We assume the result is an int64 for now
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type)
|
||||||
# if isinstance(arg, ast.Attribute):
|
|
||||||
# return val
|
|
||||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
|
||||||
logger.info(f"Using temp variable '{temp_name}' for expression result")
|
logger.info(f"Using temp variable '{temp_name}' for expression result")
|
||||||
|
if (
|
||||||
|
isinstance(val.type, ir.IntType)
|
||||||
|
and expected_type
|
||||||
|
and val.type.width > expected_type.width
|
||||||
|
):
|
||||||
|
val = builder.trunc(val, expected_type)
|
||||||
builder.store(val, ptr)
|
builder.store(val, ptr)
|
||||||
|
|
||||||
return ptr
|
return ptr
|
||||||
|
|||||||
Reference in New Issue
Block a user