mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
add map update function support
This commit is contained in:
@ -37,7 +37,7 @@ def hello_again(ctx: c_void_p) -> c_int64:
|
|||||||
if x:
|
if x:
|
||||||
print("we did not prevail")
|
print("we did not prevail")
|
||||||
ts = ktime()
|
ts = ktime()
|
||||||
# last().update(key, ts)
|
last().update(key, ts, 0)
|
||||||
return c_int64(0)
|
return c_int64(0)
|
||||||
|
|
||||||
|
|
||||||
@ -46,5 +46,4 @@ def hello_again(ctx: c_void_p) -> c_int64:
|
|||||||
def LICENSE() -> str:
|
def LICENSE() -> str:
|
||||||
return "GPL"
|
return "GPL"
|
||||||
|
|
||||||
|
|
||||||
compile()
|
compile()
|
||||||
|
|||||||
@ -19,7 +19,6 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No
|
|||||||
"""
|
"""
|
||||||
Emit LLVM IR for bpf_map_lookup_elem helper function call.
|
Emit LLVM IR for bpf_map_lookup_elem helper function call.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if call.args and len(call.args) != 1:
|
if call.args and len(call.args) != 1:
|
||||||
raise ValueError("Map lookup expects exactly one argument, got "
|
raise ValueError("Map lookup expects exactly one argument, got "
|
||||||
f"{len(call.args)}")
|
f"{len(call.args)}")
|
||||||
@ -94,11 +93,105 @@ def bpf_printk_emitter(call, module, builder, func):
|
|||||||
builder.call(fn_ptr, [fmt_ptr, ir.Constant(
|
builder.call(fn_ptr, [fmt_ptr, ir.Constant(
|
||||||
ir.IntType(32), len(fmt_str))], tail=True)
|
ir.IntType(32), len(fmt_str))], tail=True)
|
||||||
|
|
||||||
|
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None):
|
||||||
|
"""
|
||||||
|
Emit LLVM IR for bpf_map_update_elem helper function call.
|
||||||
|
Expected call signature: map.update(key, value, flags=0)
|
||||||
|
"""
|
||||||
|
if not call.args or len(call.args) < 2 or len(call.args) > 3:
|
||||||
|
raise ValueError("Map update expects 2 or 3 arguments (key, value, flags), got "
|
||||||
|
f"{len(call.args)}")
|
||||||
|
|
||||||
|
key_arg = call.args[0]
|
||||||
|
value_arg = call.args[1]
|
||||||
|
flags_arg = call.args[2] if len(call.args) > 2 else None
|
||||||
|
|
||||||
|
# Handle key
|
||||||
|
if isinstance(key_arg, ast.Name):
|
||||||
|
key_name = key_arg.id
|
||||||
|
if local_sym_tab and key_name in local_sym_tab:
|
||||||
|
key_ptr = local_sym_tab[key_name]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Key variable {key_name} not found in local symbol table.")
|
||||||
|
elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int):
|
||||||
|
# Handle constant integer keys
|
||||||
|
key_val = key_arg.value
|
||||||
|
key_type = ir.IntType(64)
|
||||||
|
key_ptr = builder.alloca(key_type)
|
||||||
|
key_ptr.align = key_type.width // 8
|
||||||
|
builder.store(ir.Constant(key_type, key_val), key_ptr)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Only simple variable names and integer constants are supported as keys in map update.")
|
||||||
|
|
||||||
|
# Handle value
|
||||||
|
if isinstance(value_arg, ast.Name):
|
||||||
|
value_name = value_arg.id
|
||||||
|
if local_sym_tab and value_name in local_sym_tab:
|
||||||
|
value_ptr = local_sym_tab[value_name]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Value variable {value_name} not found in local symbol table.")
|
||||||
|
elif isinstance(value_arg, ast.Constant) and isinstance(value_arg.value, int):
|
||||||
|
# Handle constant integers
|
||||||
|
value_val = value_arg.value
|
||||||
|
value_type = ir.IntType(64)
|
||||||
|
value_ptr = builder.alloca(value_type)
|
||||||
|
value_ptr.align = value_type.width // 8
|
||||||
|
builder.store(ir.Constant(value_type, value_val), value_ptr)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Only simple variable names and integer constants are supported as values in map update.")
|
||||||
|
|
||||||
|
# Handle flags argument (defaults to 0)
|
||||||
|
if flags_arg is not None:
|
||||||
|
if isinstance(flags_arg, ast.Constant) and isinstance(flags_arg.value, int):
|
||||||
|
flags_val = flags_arg.value
|
||||||
|
elif isinstance(flags_arg, ast.Name):
|
||||||
|
flags_name = flags_arg.id
|
||||||
|
if local_sym_tab and flags_name in local_sym_tab:
|
||||||
|
# Assume it's a stored integer value, load it
|
||||||
|
flags_ptr = local_sym_tab[flags_name]
|
||||||
|
flags_val = builder.load(flags_ptr)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Flags variable {flags_name} not found in local symbol table.")
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Only integer constants and simple variable names are supported as flags in map update.")
|
||||||
|
else:
|
||||||
|
flags_val = 0
|
||||||
|
|
||||||
|
if key_ptr is None or value_ptr is None:
|
||||||
|
raise ValueError("Key pointer or value pointer is None.")
|
||||||
|
|
||||||
|
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||||
|
fn_type = ir.FunctionType(
|
||||||
|
ir.IntType(64),
|
||||||
|
[ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)],
|
||||||
|
var_arg=False
|
||||||
|
)
|
||||||
|
fn_ptr_type = ir.PointerType(fn_type)
|
||||||
|
|
||||||
|
# helper id
|
||||||
|
fn_addr = ir.Constant(ir.IntType(64), 2)
|
||||||
|
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
|
||||||
|
|
||||||
|
if isinstance(flags_val, int):
|
||||||
|
flags_const = ir.Constant(ir.IntType(64), flags_val)
|
||||||
|
else:
|
||||||
|
flags_const = flags_val
|
||||||
|
|
||||||
|
result = builder.call(fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
helper_func_list = {
|
helper_func_list = {
|
||||||
"lookup": bpf_map_lookup_elem_emitter,
|
"lookup": bpf_map_lookup_elem_emitter,
|
||||||
"print": bpf_printk_emitter,
|
"print": bpf_printk_emitter,
|
||||||
"ktime": bpf_ktime_get_ns_emitter,
|
"ktime": bpf_ktime_get_ns_emitter,
|
||||||
|
"update": bpf_map_update_elem_emitter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -106,6 +106,18 @@ def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
|
|||||||
handle_helper_call(
|
handle_helper_call(
|
||||||
call, module, builder, func, local_sym_tab, map_sym_tab)
|
call, module, builder, func, local_sym_tab, map_sym_tab)
|
||||||
return
|
return
|
||||||
|
elif isinstance(call.func, ast.Attribute):
|
||||||
|
if isinstance(call.func.value, ast.Call) and isinstance(call.func.value.func, ast.Name):
|
||||||
|
method_name = call.func.attr
|
||||||
|
if method_name in helper_func_list:
|
||||||
|
handle_helper_call(
|
||||||
|
call, module, builder, func, local_sym_tab, map_sym_tab)
|
||||||
|
return
|
||||||
|
# I VIBED THIS WITHOUT UNDERSTANDING THIS PART>>>> TODO: check this later
|
||||||
|
if call.func.id in helper_func_list:
|
||||||
|
handle_helper_call(
|
||||||
|
call, module, builder, func, local_sym_tab, map_sym_tab)
|
||||||
|
return
|
||||||
elif isinstance(call, ast.Name):
|
elif isinstance(call, ast.Name):
|
||||||
if call.id in local_sym_tab:
|
if call.id in local_sym_tab:
|
||||||
var = local_sym_tab[call.id]
|
var = local_sym_tab[call.id]
|
||||||
|
|||||||
@ -17,7 +17,8 @@ class HashMap:
|
|||||||
else:
|
else:
|
||||||
raise KeyError(f"Key {key} not found in map")
|
raise KeyError(f"Key {key} not found in map")
|
||||||
|
|
||||||
def update(self, key, value):
|
# TODO: define the flags that can be added
|
||||||
|
def update(self, key, value, flags=None):
|
||||||
if key in self.entries:
|
if key in self.entries:
|
||||||
self.entries[key] = value
|
self.entries[key] = value
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user