add map update function support

This commit is contained in:
2025-09-10 23:44:29 +05:30
parent f830fbe8ba
commit 7de3a381b0
4 changed files with 110 additions and 5 deletions

View File

@ -37,7 +37,7 @@ def hello_again(ctx: c_void_p) -> c_int64:
if x:
print("we did not prevail")
ts = ktime()
# last().update(key, ts)
last().update(key, ts, 0)
return c_int64(0)
@ -46,5 +46,4 @@ def hello_again(ctx: c_void_p) -> c_int64:
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -19,7 +19,6 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
"""
Emit LLVM IR for bpf_map_lookup_elem helper function call.
"""
if call.args and len(call.args) != 1:
raise ValueError("Map lookup expects exactly one argument, got "
f"{len(call.args)}")
@ -94,11 +93,105 @@ def bpf_printk_emitter(call, module, builder, func):
builder.call(fn_ptr, [fmt_ptr, ir.Constant(
ir.IntType(32), len(fmt_str))], tail=True)
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None):
"""
Emit LLVM IR for bpf_map_update_elem helper function call.
Expected call signature: map.update(key, value, flags=0)
"""
if not call.args or len(call.args) < 2 or len(call.args) > 3:
raise ValueError("Map update expects 2 or 3 arguments (key, value, flags), got "
f"{len(call.args)}")
key_arg = call.args[0]
value_arg = call.args[1]
flags_arg = call.args[2] if len(call.args) > 2 else None
# Handle key
if isinstance(key_arg, ast.Name):
key_name = key_arg.id
if local_sym_tab and key_name in local_sym_tab:
key_ptr = local_sym_tab[key_name]
else:
raise ValueError(
f"Key variable {key_name} not found in local symbol table.")
elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int):
# Handle constant integer keys
key_val = key_arg.value
key_type = ir.IntType(64)
key_ptr = builder.alloca(key_type)
key_ptr.align = key_type.width // 8
builder.store(ir.Constant(key_type, key_val), key_ptr)
else:
raise NotImplementedError(
"Only simple variable names and integer constants are supported as keys in map update.")
# Handle value
if isinstance(value_arg, ast.Name):
value_name = value_arg.id
if local_sym_tab and value_name in local_sym_tab:
value_ptr = local_sym_tab[value_name]
else:
raise ValueError(
f"Value variable {value_name} not found in local symbol table.")
elif isinstance(value_arg, ast.Constant) and isinstance(value_arg.value, int):
# Handle constant integers
value_val = value_arg.value
value_type = ir.IntType(64)
value_ptr = builder.alloca(value_type)
value_ptr.align = value_type.width // 8
builder.store(ir.Constant(value_type, value_val), value_ptr)
else:
raise NotImplementedError(
"Only simple variable names and integer constants are supported as values in map update.")
# Handle flags argument (defaults to 0)
if flags_arg is not None:
if isinstance(flags_arg, ast.Constant) and isinstance(flags_arg.value, int):
flags_val = flags_arg.value
elif isinstance(flags_arg, ast.Name):
flags_name = flags_arg.id
if local_sym_tab and flags_name in local_sym_tab:
# Assume it's a stored integer value, load it
flags_ptr = local_sym_tab[flags_name]
flags_val = builder.load(flags_ptr)
else:
raise ValueError(
f"Flags variable {flags_name} not found in local symbol table.")
else:
raise NotImplementedError(
"Only integer constants and simple variable names are supported as flags in map update.")
else:
flags_val = 0
if key_ptr is None or value_ptr is None:
raise ValueError("Key pointer or value pointer is None.")
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
fn_type = ir.FunctionType(
ir.IntType(64),
[ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)],
var_arg=False
)
fn_ptr_type = ir.PointerType(fn_type)
# helper id
fn_addr = ir.Constant(ir.IntType(64), 2)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
if isinstance(flags_val, int):
flags_const = ir.Constant(ir.IntType(64), flags_val)
else:
flags_const = flags_val
result = builder.call(fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False)
return result
helper_func_list = {
"lookup": bpf_map_lookup_elem_emitter,
"print": bpf_printk_emitter,
"ktime": bpf_ktime_get_ns_emitter,
"update": bpf_map_update_elem_emitter,
}

View File

@ -106,6 +106,18 @@ def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
handle_helper_call(
call, module, builder, func, local_sym_tab, map_sym_tab)
return
elif isinstance(call.func, ast.Attribute):
if isinstance(call.func.value, ast.Call) and isinstance(call.func.value.func, ast.Name):
method_name = call.func.attr
if method_name in helper_func_list:
handle_helper_call(
call, module, builder, func, local_sym_tab, map_sym_tab)
return
# I VIBED THIS WITHOUT UNDERSTANDING THIS PART>>>> TODO: check this later
if call.func.id in helper_func_list:
handle_helper_call(
call, module, builder, func, local_sym_tab, map_sym_tab)
return
elif isinstance(call, ast.Name):
if call.id in local_sym_tab:
var = local_sym_tab[call.id]

View File

@ -16,8 +16,9 @@ class HashMap:
del self.entries[key]
else:
raise KeyError(f"Key {key} not found in map")
def update(self, key, value):
# TODO: define the flags that can be added
def update(self, key, value, flags=None):
if key in self.entries:
self.entries[key] = value
else: