diff --git a/examples/execve3.py b/examples/execve3.py index 4f3d753..cfa4dc3 100644 --- a/examples/execve3.py +++ b/examples/execve3.py @@ -37,7 +37,7 @@ def hello_again(ctx: c_void_p) -> c_int64: if x: print("we did not prevail") ts = ktime() -# last().update(key, ts) + last().update(key, ts, 0) return c_int64(0) @@ -46,5 +46,4 @@ def hello_again(ctx: c_void_p) -> c_int64: def LICENSE() -> str: return "GPL" - compile() diff --git a/pythonbpf/bpf_helper_handler.py b/pythonbpf/bpf_helper_handler.py index abefd3f..165c8a6 100644 --- a/pythonbpf/bpf_helper_handler.py +++ b/pythonbpf/bpf_helper_handler.py @@ -19,7 +19,6 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, local_sym_tab=No """ Emit LLVM IR for bpf_map_lookup_elem helper function call. """ - if call.args and len(call.args) != 1: raise ValueError("Map lookup expects exactly one argument, got " f"{len(call.args)}") @@ -94,11 +93,105 @@ def bpf_printk_emitter(call, module, builder, func): builder.call(fn_ptr, [fmt_ptr, ir.Constant( ir.IntType(32), len(fmt_str))], tail=True) +def bpf_map_update_elem_emitter(call, map_ptr, module, builder, local_sym_tab=None): + """ + Emit LLVM IR for bpf_map_update_elem helper function call. + Expected call signature: map.update(key, value, flags=0) + """ + if not call.args or len(call.args) < 2 or len(call.args) > 3: + raise ValueError("Map update expects 2 or 3 arguments (key, value, flags), got " + f"{len(call.args)}") + + key_arg = call.args[0] + value_arg = call.args[1] + flags_arg = call.args[2] if len(call.args) > 2 else None + + # Handle key + if isinstance(key_arg, ast.Name): + key_name = key_arg.id + if local_sym_tab and key_name in local_sym_tab: + key_ptr = local_sym_tab[key_name] + else: + raise ValueError( + f"Key variable {key_name} not found in local symbol table.") + elif isinstance(key_arg, ast.Constant) and isinstance(key_arg.value, int): + # Handle constant integer keys + key_val = key_arg.value + key_type = ir.IntType(64) + key_ptr = builder.alloca(key_type) + key_ptr.align = key_type.width // 8 + builder.store(ir.Constant(key_type, key_val), key_ptr) + else: + raise NotImplementedError( + "Only simple variable names and integer constants are supported as keys in map update.") + + # Handle value + if isinstance(value_arg, ast.Name): + value_name = value_arg.id + if local_sym_tab and value_name in local_sym_tab: + value_ptr = local_sym_tab[value_name] + else: + raise ValueError( + f"Value variable {value_name} not found in local symbol table.") + elif isinstance(value_arg, ast.Constant) and isinstance(value_arg.value, int): + # Handle constant integers + value_val = value_arg.value + value_type = ir.IntType(64) + value_ptr = builder.alloca(value_type) + value_ptr.align = value_type.width // 8 + builder.store(ir.Constant(value_type, value_val), value_ptr) + else: + raise NotImplementedError( + "Only simple variable names and integer constants are supported as values in map update.") + + # Handle flags argument (defaults to 0) + if flags_arg is not None: + if isinstance(flags_arg, ast.Constant) and isinstance(flags_arg.value, int): + flags_val = flags_arg.value + elif isinstance(flags_arg, ast.Name): + flags_name = flags_arg.id + if local_sym_tab and flags_name in local_sym_tab: + # Assume it's a stored integer value, load it + flags_ptr = local_sym_tab[flags_name] + flags_val = builder.load(flags_ptr) + else: + raise ValueError( + f"Flags variable {flags_name} not found in local symbol table.") + else: + raise NotImplementedError( + "Only integer constants and simple variable names are supported as flags in map update.") + else: + flags_val = 0 + + if key_ptr is None or value_ptr is None: + raise ValueError("Key pointer or value pointer is None.") + + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.IntType(64), + [ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)], + var_arg=False + ) + fn_ptr_type = ir.PointerType(fn_type) + + # helper id + fn_addr = ir.Constant(ir.IntType(64), 2) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + if isinstance(flags_val, int): + flags_const = ir.Constant(ir.IntType(64), flags_val) + else: + flags_const = flags_val + + result = builder.call(fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False) + + return result helper_func_list = { "lookup": bpf_map_lookup_elem_emitter, "print": bpf_printk_emitter, "ktime": bpf_ktime_get_ns_emitter, + "update": bpf_map_update_elem_emitter, } diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index 2713028..57aee45 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -106,6 +106,18 @@ def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab): handle_helper_call( call, module, builder, func, local_sym_tab, map_sym_tab) return + elif isinstance(call.func, ast.Attribute): + if isinstance(call.func.value, ast.Call) and isinstance(call.func.value.func, ast.Name): + method_name = call.func.attr + if method_name in helper_func_list: + handle_helper_call( + call, module, builder, func, local_sym_tab, map_sym_tab) + return + # I VIBED THIS WITHOUT UNDERSTANDING THIS PART>>>> TODO: check this later + if call.func.id in helper_func_list: + handle_helper_call( + call, module, builder, func, local_sym_tab, map_sym_tab) + return elif isinstance(call, ast.Name): if call.id in local_sym_tab: var = local_sym_tab[call.id] diff --git a/pythonbpf/maps.py b/pythonbpf/maps.py index ac245d2..02a8439 100644 --- a/pythonbpf/maps.py +++ b/pythonbpf/maps.py @@ -16,8 +16,9 @@ class HashMap: del self.entries[key] else: raise KeyError(f"Key {key} not found in map") - - def update(self, key, value): + + # TODO: define the flags that can be added + def update(self, key, value, flags=None): if key in self.entries: self.entries[key] = value else: