overhaul handle_helper_calls

This commit is contained in:
Pragyansh Chaturvedi
2025-09-21 16:10:29 +05:30
parent 3c976b88d3
commit a1371697cc
3 changed files with 65 additions and 17 deletions

View File

@ -16,7 +16,7 @@ def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, local_sym_tab
return result
def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None, struct_sym_tab=None):
def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None):
"""
Emit LLVM IR for bpf_map_lookup_elem helper function call.
"""
@ -172,7 +172,7 @@ def bpf_printk_emitter(call, map_ptr, module, builder, func, local_sym_tab=None)
ir.IntType(32), len(fmt_str))], tail=True)
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None, struct_sym_tab=None):
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None):
"""
Emit LLVM IR for bpf_map_update_elem helper function call.
Expected call signature: map.update(key, value, flags=0)
@ -268,7 +268,7 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
return result
def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None, struct_sym_tab=None):
def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None):
"""
Emit LLVM IR for bpf_map_delete_elem helper function call.
Expected call signature: map.delete(key)
@ -340,8 +340,31 @@ def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func, local
return pid
def bpf_perf_event_output_handler(call, map_ptr, module, builder, local_sym_tab=None, struct_sym_tab=None):
pass
def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
if len(call.args) != 1:
raise ValueError("Perf event output expects exactly one argument (data), got "
f"{len(call.args)}")
data_arg = call.args[0]
ctx_ptr = func.args[0] # First argument to the function is ctx
if isinstance(data_arg, ast.Name):
data_name = data_arg.id
if local_sym_tab and data_name in local_sym_tab:
data_ptr = local_sym_tab[data_name]
else:
raise ValueError(
f"Data variable {data_name} not found in local symbol table.")
# Check is data_name is a struct
if local_var_metadata and data_name in local_var_metadata:
data_type = local_var_metadata[data_name]
if data_type in struct_sym_tab:
struct_info = struct_sym_tab[data_type]
data_size = 0
for field_type in struct_info["type"].elements:
if isinstance(field_type, ir.IntType):
data_size += field_type.width // 8
elif isinstance(field_type, ir.PointerType):
data_size += 8
helper_func_list = {
@ -355,7 +378,7 @@ helper_func_list = {
}
def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_tab=None, struct_sym_tab=None):
def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_tab=None, struct_sym_tab=None, local_var_metadata=None):
if isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in helper_func_list:
@ -372,14 +395,29 @@ def handle_helper_call(call, module, builder, func, local_sym_tab=None, map_sym_
if map_sym_tab and map_name in map_sym_tab:
map_ptr = map_sym_tab[map_name]
if method_name in helper_func_list:
print(local_var_metadata)
return helper_func_list[method_name](
call, map_ptr, module, builder, local_sym_tab, struct_sym_tab)
call, map_ptr, module, builder, func, local_sym_tab, struct_sym_tab, local_var_metadata)
else:
raise NotImplementedError(
f"Map method {method_name} is not implemented as a helper function.")
else:
raise ValueError(
f"Map variable {map_name} not found in symbol tables.")
elif isinstance(call.func.value, ast.Name):
obj_name = call.func.value.id
method_name = call.func.attr
if map_sym_tab and obj_name in map_sym_tab:
map_ptr = map_sym_tab[obj_name]
if method_name in helper_func_list:
return helper_func_list[method_name](
call, map_ptr, module, builder, func, local_sym_tab, struct_sym_tab, local_var_metadata)
else:
raise NotImplementedError(
f"Map method {method_name} is not implemented as a helper function.")
else:
raise ValueError(
f"Map variable {obj_name} not found in symbol tables.")
else:
raise NotImplementedError(
"Attribute not supported for map method calls.")

View File

@ -2,7 +2,7 @@ import ast
from llvmlite import ir
def eval_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
def eval_expr(func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab=None, local_var_metadata=None):
print(f"Evaluating expression: {expr}")
if isinstance(expr, ast.Name):
if expr.id in local_sym_tab:
@ -50,22 +50,31 @@ def eval_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
# check for helpers
if expr.func.id in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab)
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab, local_var_metadata)
elif isinstance(expr.func, ast.Attribute):
print(f"Handling method call: {ast.dump(expr.func)}")
if isinstance(expr.func.value, ast.Call) and isinstance(expr.func.value.func, ast.Name):
method_name = expr.func.attr
if method_name in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab)
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab, local_var_metadata)
elif isinstance(expr.func.value, ast.Name):
obj_name = expr.func.value.id
method_name = expr.func.attr
if obj_name in map_sym_tab:
if method_name in helper_func_list:
return handle_helper_call(
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab, local_var_metadata)
print("Unsupported expression evaluation")
return None
def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab, local_var_metadata):
"""Handle expression statements in the function body."""
print(f"Handling expression: {ast.dump(expr)}")
call = expr.value
if isinstance(call, ast.Call):
eval_expr(func, module, builder, call, local_sym_tab, map_sym_tab)
eval_expr(func, module, builder, call, local_sym_tab,
map_sym_tab, structs_sym_tab, local_var_metadata)
else:
print("Unsupported expression type")

View File

@ -57,7 +57,7 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab, struc
ir.Constant(ir.IntType(32), field_idx)],
inbounds=True)
val = eval_expr(func, module, builder, rval,
local_sym_tab, map_sym_tab)
local_sym_tab, map_sym_tab, structs_sym_tab)
if val is None:
print("Failed to evaluate struct field assignment")
return
@ -100,14 +100,14 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab, struc
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
val = handle_helper_call(
rval, module, builder, None, local_sym_tab, map_sym_tab, structs_sym_tab)
rval, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab, local_var_metadata)
builder.store(val, local_sym_tab[var_name])
# local_sym_tab[var_name] = var
print(f"Assigned constant {rval.func.id} to {var_name}")
elif call_type == "deref" and len(rval.args) == 1:
print(f"Handling deref assignment {ast.dump(rval)}")
val = eval_expr(func, module, builder, rval,
local_sym_tab, map_sym_tab)
local_sym_tab, map_sym_tab, structs_sym_tab)
if val is None:
print("Failed to evaluate deref argument")
return
@ -139,7 +139,7 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab, struc
map_ptr = map_sym_tab[map_name]
if method_name in helper_func_list:
val = handle_helper_call(
rval, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab)
rval, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab, local_var_metadata)
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
builder.store(val, local_sym_tab[var_name])
@ -261,7 +261,8 @@ def handle_if(func, module, builder, stmt, map_sym_tab, local_sym_tab):
def process_stmt(func, module, builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab, did_return, ret_type=ir.IntType(64)):
print(f"Processing statement: {ast.dump(stmt)}")
if isinstance(stmt, ast.Expr):
handle_expr(func, module, builder, stmt, local_sym_tab, map_sym_tab)
handle_expr(func, module, builder, stmt, local_sym_tab,
map_sym_tab, structs_sym_tab, local_var_metadata)
elif isinstance(stmt, ast.Assign):
handle_assign(func, module, builder, stmt, map_sym_tab,
local_sym_tab, structs_sym_tab)