mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Allow map-based helpers to be used as helper args / within binops which are helper args
This commit is contained in:
@ -39,6 +39,10 @@ def get_operand_value(
|
|||||||
if res is None:
|
if res is None:
|
||||||
raise ValueError(f"Failed to evaluate call expression: {operand}")
|
raise ValueError(f"Failed to evaluate call expression: {operand}")
|
||||||
val, _ = res
|
val, _ = res
|
||||||
|
logger.info(f"Evaluated expr to {val} of type {val.type}")
|
||||||
|
base_type, depth = get_base_type_and_depth(val.type)
|
||||||
|
if depth > 0:
|
||||||
|
val = deref_to_depth(func, builder, val, depth)
|
||||||
return val
|
return val
|
||||||
raise TypeError(f"Unsupported operand type: {type(operand)}")
|
raise TypeError(f"Unsupported operand type: {type(operand)}")
|
||||||
|
|
||||||
|
|||||||
@ -388,14 +388,18 @@ def process_stmt(
|
|||||||
return did_return
|
return did_return
|
||||||
|
|
||||||
|
|
||||||
def count_temps_in_call(call_node):
|
def count_temps_in_call(call_node, local_sym_tab):
|
||||||
"""Count the number of temporary variables needed for a function call."""
|
"""Count the number of temporary variables needed for a function call."""
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
is_helper = False
|
is_helper = False
|
||||||
|
|
||||||
|
# NOTE: We exclude print calls for now
|
||||||
if isinstance(call_node.func, ast.Name):
|
if isinstance(call_node.func, ast.Name):
|
||||||
if HelperHandlerRegistry.has_handler(call_node.func.id):
|
if (
|
||||||
|
HelperHandlerRegistry.has_handler(call_node.func.id)
|
||||||
|
and call_node.func.id != "print"
|
||||||
|
):
|
||||||
is_helper = True
|
is_helper = True
|
||||||
elif isinstance(call_node.func, ast.Attribute):
|
elif isinstance(call_node.func, ast.Attribute):
|
||||||
if HelperHandlerRegistry.has_handler(call_node.func.attr):
|
if HelperHandlerRegistry.has_handler(call_node.func.attr):
|
||||||
@ -405,10 +409,11 @@ def count_temps_in_call(call_node):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
for arg in call_node.args:
|
for arg in call_node.args:
|
||||||
if (
|
# NOTE: Count all non-name arguments
|
||||||
isinstance(arg, ast.BinOp)
|
# For struct fields, if it is being passed as an argument,
|
||||||
or isinstance(arg, ast.Constant)
|
# The struct object should already exist in the local_sym_tab
|
||||||
or isinstance(arg, ast.UnaryOp)
|
if not isinstance(arg, ast.Name) and not (
|
||||||
|
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
|
||||||
):
|
):
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@ -423,11 +428,19 @@ def allocate_mem(
|
|||||||
|
|
||||||
def update_max_temps_for_stmt(stmt):
|
def update_max_temps_for_stmt(stmt):
|
||||||
nonlocal max_temps_needed
|
nonlocal max_temps_needed
|
||||||
|
temps_needed = 0
|
||||||
|
|
||||||
|
if isinstance(stmt, ast.If):
|
||||||
|
for s in stmt.body:
|
||||||
|
update_max_temps_for_stmt(s)
|
||||||
|
for s in stmt.orelse:
|
||||||
|
update_max_temps_for_stmt(s)
|
||||||
|
return
|
||||||
|
|
||||||
for node in ast.walk(stmt):
|
for node in ast.walk(stmt):
|
||||||
if isinstance(node, ast.Call):
|
if isinstance(node, ast.Call):
|
||||||
temps_needed = count_temps_in_call(node)
|
temps_needed += count_temps_in_call(node, local_sym_tab)
|
||||||
max_temps_needed = max(max_temps_needed, temps_needed)
|
max_temps_needed = max(max_temps_needed, temps_needed)
|
||||||
|
|
||||||
for stmt in body:
|
for stmt in body:
|
||||||
update_max_temps_for_stmt(stmt)
|
update_max_temps_for_stmt(stmt)
|
||||||
@ -460,9 +473,16 @@ def allocate_mem(
|
|||||||
logger.info("Unsupported multiassignment")
|
logger.info("Unsupported multiassignment")
|
||||||
continue
|
continue
|
||||||
target = stmt.targets[0]
|
target = stmt.targets[0]
|
||||||
if not isinstance(target, ast.Name):
|
if not isinstance(target, ast.Name) and not isinstance(
|
||||||
|
target, ast.Attribute
|
||||||
|
):
|
||||||
logger.info("Unsupported assignment target")
|
logger.info("Unsupported assignment target")
|
||||||
continue
|
continue
|
||||||
|
if isinstance(target, ast.Attribute):
|
||||||
|
logger.info(
|
||||||
|
f"Struct field {target.attr} assignment, will be handled later"
|
||||||
|
)
|
||||||
|
continue
|
||||||
var_name = target.id
|
var_name = target.id
|
||||||
rval = stmt.value
|
rval = stmt.value
|
||||||
if var_name in local_sym_tab:
|
if var_name in local_sym_tab:
|
||||||
|
|||||||
@ -34,6 +34,7 @@ def bpf_ktime_get_ns_emitter(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Emit LLVM IR for bpf_ktime_get_ns helper function call.
|
Emit LLVM IR for bpf_ktime_get_ns helper function call.
|
||||||
@ -56,6 +57,7 @@ def bpf_map_lookup_elem_emitter(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Emit LLVM IR for bpf_map_lookup_elem helper function call.
|
Emit LLVM IR for bpf_map_lookup_elem helper function call.
|
||||||
@ -65,12 +67,16 @@ def bpf_map_lookup_elem_emitter(
|
|||||||
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
|
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
|
||||||
)
|
)
|
||||||
key_ptr = get_or_create_ptr_from_arg(
|
key_ptr = get_or_create_ptr_from_arg(
|
||||||
func, module, call.args[0], builder, local_sym_tab, struct_sym_tab
|
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||||
)
|
)
|
||||||
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||||
|
|
||||||
|
# TODO: I have changed the return typr to i64*, as we are
|
||||||
|
# allocating space for that type in allocate_mem. This is
|
||||||
|
# temporary, and we will honour other widths later. But this
|
||||||
|
# allows us to have cool binary ops on the returned value.
|
||||||
fn_type = ir.FunctionType(
|
fn_type = ir.FunctionType(
|
||||||
ir.PointerType(), # Return type: void*
|
ir.PointerType(ir.IntType(64)), # Return type: void*
|
||||||
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
|
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
|
||||||
var_arg=False,
|
var_arg=False,
|
||||||
)
|
)
|
||||||
@ -93,6 +99,7 @@ def bpf_printk_emitter(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""Emit LLVM IR for bpf_printk helper function call."""
|
"""Emit LLVM IR for bpf_printk helper function call."""
|
||||||
if not hasattr(func, "_fmt_counter"):
|
if not hasattr(func, "_fmt_counter"):
|
||||||
@ -140,6 +147,7 @@ def bpf_map_update_elem_emitter(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Emit LLVM IR for bpf_map_update_elem helper function call.
|
Emit LLVM IR for bpf_map_update_elem helper function call.
|
||||||
@ -155,10 +163,10 @@ def bpf_map_update_elem_emitter(
|
|||||||
flags_arg = call.args[2] if len(call.args) > 2 else None
|
flags_arg = call.args[2] if len(call.args) > 2 else None
|
||||||
|
|
||||||
key_ptr = get_or_create_ptr_from_arg(
|
key_ptr = get_or_create_ptr_from_arg(
|
||||||
func, module, key_arg, builder, local_sym_tab, struct_sym_tab
|
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||||
)
|
)
|
||||||
value_ptr = get_or_create_ptr_from_arg(
|
value_ptr = get_or_create_ptr_from_arg(
|
||||||
func, module, value_arg, builder, local_sym_tab, struct_sym_tab
|
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||||
)
|
)
|
||||||
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
|
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
|
||||||
|
|
||||||
@ -194,6 +202,7 @@ def bpf_map_delete_elem_emitter(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Emit LLVM IR for bpf_map_delete_elem helper function call.
|
Emit LLVM IR for bpf_map_delete_elem helper function call.
|
||||||
@ -204,7 +213,7 @@ def bpf_map_delete_elem_emitter(
|
|||||||
f"Map delete expects exactly one argument (key), got {len(call.args)}"
|
f"Map delete expects exactly one argument (key), got {len(call.args)}"
|
||||||
)
|
)
|
||||||
key_ptr = get_or_create_ptr_from_arg(
|
key_ptr = get_or_create_ptr_from_arg(
|
||||||
func, module, call.args[0], builder, local_sym_tab, struct_sym_tab
|
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||||
)
|
)
|
||||||
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||||
|
|
||||||
@ -233,6 +242,7 @@ def bpf_get_current_pid_tgid_emitter(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
|
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
|
||||||
@ -259,6 +269,7 @@ def bpf_perf_event_output_handler(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab=None,
|
local_sym_tab=None,
|
||||||
struct_sym_tab=None,
|
struct_sym_tab=None,
|
||||||
|
map_sym_tab=None,
|
||||||
):
|
):
|
||||||
if len(call.args) != 1:
|
if len(call.args) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -323,6 +334,7 @@ def handle_helper_call(
|
|||||||
func,
|
func,
|
||||||
local_sym_tab,
|
local_sym_tab,
|
||||||
struct_sym_tab,
|
struct_sym_tab,
|
||||||
|
map_sym_tab,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle direct function calls (e.g., print(), ktime())
|
# Handle direct function calls (e.g., print(), ktime())
|
||||||
|
|||||||
@ -81,14 +81,14 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
|||||||
|
|
||||||
# Default to 64-bit integer
|
# Default to 64-bit integer
|
||||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
||||||
logger.debug(f"Using temp variable '{temp_name}' for int constant {value}")
|
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
|
||||||
const_val = ir.Constant(ir.IntType(int_width), value)
|
const_val = ir.Constant(ir.IntType(int_width), value)
|
||||||
builder.store(const_val, ptr)
|
builder.store(const_val, ptr)
|
||||||
return ptr
|
return ptr
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_ptr_from_arg(
|
def get_or_create_ptr_from_arg(
|
||||||
func, module, arg, builder, local_sym_tab, struct_sym_tab=None
|
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None
|
||||||
):
|
):
|
||||||
"""Extract or create pointer from the call arguments."""
|
"""Extract or create pointer from the call arguments."""
|
||||||
|
|
||||||
@ -104,15 +104,17 @@ def get_or_create_ptr_from_arg(
|
|||||||
builder,
|
builder,
|
||||||
arg,
|
arg,
|
||||||
local_sym_tab,
|
local_sym_tab,
|
||||||
None,
|
map_sym_tab,
|
||||||
struct_sym_tab,
|
struct_sym_tab,
|
||||||
)
|
)
|
||||||
if val is None:
|
if val is None:
|
||||||
raise ValueError("Failed to evaluate expression for helper arg.")
|
raise ValueError("Failed to evaluate expression for helper arg.")
|
||||||
|
|
||||||
# NOTE: We assume the result is an int64 for now
|
# NOTE: We assume the result is an int64 for now
|
||||||
|
# if isinstance(arg, ast.Attribute):
|
||||||
|
# return val
|
||||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
||||||
logger.debug(f"Using temp variable '{temp_name}' for expression result")
|
logger.info(f"Using temp variable '{temp_name}' for expression result")
|
||||||
builder.store(val, ptr)
|
builder.store(val, ptr)
|
||||||
|
|
||||||
return ptr
|
return ptr
|
||||||
|
|||||||
Reference in New Issue
Block a user