Use HelperHandleRegitry

This commit is contained in:
Pragyansh Chaturvedi
2025-10-01 03:53:11 +05:30
parent 6cd07498fe
commit 61f6743f0a
5 changed files with 36 additions and 30 deletions

View File

@ -0,0 +1,2 @@
from .helper_utils import HelperHandlerRegistry
from .bpf_helper_handler import handle_helper_call

View File

@ -2,6 +2,7 @@ import ast
from llvmlite import ir
from pythonbpf.expr_pass import eval_expr
from enum import Enum
from .helper_utils import HelperHandlerRegistry
class BPFHelperID(Enum):
@ -14,6 +15,7 @@ class BPFHelperID(Enum):
BPF_PERF_EVENT_OUTPUT = 25
@HelperHandlerRegistry.register("ktime")
def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
"""
Emit LLVM IR for bpf_ktime_get_ns helper function call.
@ -27,6 +29,7 @@ def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, local_sym_tab
return result, ir.IntType(64)
@HelperHandlerRegistry.register("lookup")
def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
"""
Emit LLVM IR for bpf_map_lookup_elem helper function call.
@ -66,7 +69,8 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, local_sym_
fn_ptr_type = ir.PointerType(fn_type)
# Helper ID 1 is bpf_map_lookup_elem
fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value)
fn_addr = ir.Constant(ir.IntType(
64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False)
@ -74,6 +78,7 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, local_sym_
return result, ir.PointerType()
@HelperHandlerRegistry.register("print")
def bpf_printk_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
if not hasattr(func, "_fmt_counter"):
func._fmt_counter = 0
@ -233,6 +238,7 @@ def bpf_printk_emitter(call, map_ptr, module, builder, func, local_sym_tab=None,
return None
@HelperHandlerRegistry.register("update")
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
"""
Emit LLVM IR for bpf_map_update_elem helper function call.
@ -315,7 +321,8 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, local_sym_
fn_ptr_type = ir.PointerType(fn_type)
# helper id
fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value)
fn_addr = ir.Constant(ir.IntType(
64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
if isinstance(flags_val, int):
@ -329,6 +336,7 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, local_sym_
return result, None
@HelperHandlerRegistry.register("delete")
def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
"""
Emit LLVM IR for bpf_map_delete_elem helper function call.
@ -375,7 +383,8 @@ def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, local_sym_
fn_ptr_type = ir.PointerType(fn_type)
# Helper ID 3 is bpf_map_delete_elem
fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_DELETE_ELEM.value)
fn_addr = ir.Constant(ir.IntType(
64), BPFHelperID.BPF_MAP_DELETE_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
# Call the helper function
@ -384,12 +393,14 @@ def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, local_sym_
return result, None
@HelperHandlerRegistry.register("pid")
def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
"""
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
"""
# func is an arg to just have a uniform signature with other emitters
helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value)
helper_id = ir.Constant(ir.IntType(
64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value)
fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False)
fn_ptr_type = ir.PointerType(fn_type)
fn_ptr = builder.inttoptr(helper_id, fn_ptr_type)
@ -442,7 +453,8 @@ def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, local_sy
fn_ptr_type = ir.PointerType(fn_type)
# helper id
fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PERF_EVENT_OUTPUT.value)
fn_addr = ir.Constant(ir.IntType(
64), BPFHelperID.BPF_PERF_EVENT_OUTPUT.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call(
@ -453,24 +465,14 @@ def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, local_sy
"Only simple object names are supported as data in perf event output.")
helper_func_list = {
"lookup": bpf_map_lookup_elem_emitter,
"print": bpf_printk_emitter,
"ktime": bpf_ktime_get_ns_emitter,
"update": bpf_map_update_elem_emitter,
"delete": bpf_map_delete_elem_emitter,
"pid": bpf_get_current_pid_tgid_emitter,
"output": bpf_perf_event_output_handler,
}
def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
print(local_var_metadata)
if isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in helper_func_list:
hdl_func = HelperHandlerRegistry.get_handler(func_name)
if hdl_func:
# it is not a map method call
return helper_func_list[func_name](call, None, module, builder, func, local_sym_tab, struct_sym_tab, local_var_metadata)
return hdl_func(call, None, module, builder, func, local_sym_tab, struct_sym_tab, local_var_metadata)
else:
raise NotImplementedError(
f"Function {func_name} is not implemented as a helper function.")
@ -481,9 +483,10 @@ def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_
method_name = call.func.attr
if map_sym_tab and map_name in map_sym_tab:
map_ptr = map_sym_tab[map_name]
if method_name in helper_func_list:
hdl_func = HelperHandlerRegistry.get_handler(method_name)
if hdl_func:
print(local_var_metadata)
return helper_func_list[method_name](
return hdl_func(
call, map_ptr, module, builder, func, local_sym_tab, struct_sym_tab, local_var_metadata)
else:
raise NotImplementedError(
@ -496,8 +499,9 @@ def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_
method_name = call.func.attr
if map_sym_tab and obj_name in map_sym_tab:
map_ptr = map_sym_tab[obj_name]
if method_name in helper_func_list:
return helper_func_list[method_name](
hdl_func = HelperHandlerRegistry.get_handler(method_name)
if hdl_func:
return hdl_func(
call, map_ptr, module, builder, func, local_sym_tab, struct_sym_tab, local_var_metadata)
else:
raise NotImplementedError(

View File

@ -1,15 +0,0 @@
import ctypes
def ktime():
return ctypes.c_int64(0)
def pid():
return ctypes.c_int32(0)
def deref(ptr):
"dereference a pointer"
result = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_void_p)).contents.value
return result if result is not None else 0
XDP_DROP = ctypes.c_int64(1)
XDP_PASS = ctypes.c_int64(2)