diff --git a/BCC-Examples/sync_perf_output.py b/BCC-Examples/sync_perf_output.py index 4b91b21..f778fd5 100644 --- a/BCC-Examples/sync_perf_output.py +++ b/BCC-Examples/sync_perf_output.py @@ -68,8 +68,6 @@ def callback(cpu, event): perf = b["events"].open_perf_buffer(callback, struct_name="data_t") print("Starting to poll... (Ctrl+C to stop)") -print("Try running: fork() or clone() system calls to trigger events") - try: while True: b["events"].poll(1000) diff --git a/pyproject.toml b/pyproject.toml index 3d0f11d..9ab259d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ ] readme = "README.md" license = {text = "Apache-2.0"} -requires-python = ">=3.8" +requires-python = ">=3.10" dependencies = [ "llvmlite", diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index b96a9cf..b5fa37c 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -1,12 +1,13 @@ import ast import logging - +import ctypes from llvmlite import ir from .local_symbol import LocalSymbol from pythonbpf.helper import HelperHandlerRegistry from pythonbpf.vmlinux_parser.dependency_node import Field from .expr import VmlinuxHandlerRegistry from pythonbpf.type_deducer import ctypes_to_ir +from pythonbpf.maps import BPFMapType logger = logging.getLogger(__name__) @@ -25,7 +26,9 @@ def create_targets_and_rvals(stmt): return stmt.targets, [stmt.value] -def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): +def handle_assign_allocation( + builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab +): """Handle memory allocation for assignment statements.""" logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}") @@ -55,7 +58,9 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): # Determine type and allocate based on rval if isinstance(rval, ast.Call): - _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) + _allocate_for_call( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + ) elif isinstance(rval, ast.Constant): _allocate_for_constant(builder, var_name, rval, local_sym_tab) elif isinstance(rval, ast.BinOp): @@ -74,14 +79,16 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab): ) -def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): +def _allocate_for_call( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab +): """Allocate memory for variable assigned from a call.""" if isinstance(rval.func, ast.Name): call_type = rval.func.id # C type constructors - if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"): + if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64", "c_void_p"): ir_type = ctypes_to_ir(call_type) var = builder.alloca(ir_type, name=var_name) var.align = ir_type.width // 8 @@ -116,15 +123,74 @@ def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab): elif isinstance(rval.func, ast.Attribute): # Map method calls - need double allocation for ptr handling - _allocate_for_map_method(builder, var_name, local_sym_tab) + _allocate_for_map_method( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab + ) else: logger.warning(f"Unsupported call function type for {var_name}") -def _allocate_for_map_method(builder, var_name, local_sym_tab): +def _allocate_for_map_method( + builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab +): """Allocate memory for variable assigned from map method (double alloc).""" + map_name = rval.func.value.id + method_name = rval.func.attr + + # NOTE: We will have to special case HashMap.lookup which returns a pointer to value type + # The value type can be a struct as well, so we need to handle that properly + # This special casing is not ideal, as over time other map methods may need similar handling + # But for now, we will just handle lookup specifically + if map_name not in map_sym_tab: + logger.error(f"Map '{map_name}' not found for allocation") + return + + if method_name != "lookup": + # Fallback allocation for other map methods + _allocate_for_map_method_fallback(builder, var_name, local_sym_tab) + return + + map_params = map_sym_tab[map_name].params + if map_params["type"] != BPFMapType.HASH: + logger.warning( + "Map method lookup used on non-hash map, using fallback allocation" + ) + _allocate_for_map_method_fallback(builder, var_name, local_sym_tab) + return + + value_type = map_params["value"] + # Determine IR type for value + if isinstance(value_type, str) and value_type in structs_sym_tab: + struct_info = structs_sym_tab[value_type] + value_ir_type = struct_info.ir_type + else: + value_ir_type = ctypes_to_ir(value_type) + + if value_ir_type is None: + logger.warning( + f"Could not determine IR type for map value '{value_type}', using fallback allocation" + ) + _allocate_for_map_method_fallback(builder, var_name, local_sym_tab) + return + + # Main variable (pointer to pointer) + ir_type = ir.PointerType(ir.IntType(64)) + var = builder.alloca(ir_type, name=var_name) + local_sym_tab[var_name] = LocalSymbol(var, ir_type) + # Temporary variable for computed values + tmp_ir_type = value_ir_type + var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") + local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) + logger.info( + f"Pre-allocated {var_name} and {var_name}_tmp for map method lookup of type {value_ir_type}" + ) + + +def _allocate_for_map_method_fallback(builder, var_name, local_sym_tab): + """Fallback allocation for map method variable (i64* and i64**).""" + # Main variable (pointer to pointer) ir_type = ir.PointerType(ir.IntType(64)) var = builder.alloca(ir_type, name=var_name) @@ -135,7 +201,9 @@ def _allocate_for_map_method(builder, var_name, local_sym_tab): var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type) - logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method") + logger.info( + f"Pre-allocated {var_name} and {var_name}_tmp for map method (fallback)" + ) def _allocate_for_constant(builder, var_name, rval, local_sym_tab): @@ -177,17 +245,33 @@ def _allocate_for_binop(builder, var_name, local_sym_tab): logger.info(f"Pre-allocated {var_name} for binop result") +def _get_type_name(ir_type): + """Get a string representation of an IR type.""" + if isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{_get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def allocate_temp_pool(builder, max_temps, local_sym_tab): """Allocate the temporary scratch space pool for helper arguments.""" - if max_temps == 0: + if not max_temps: + logger.info("No temp pool allocation needed") return - logger.info(f"Allocating temp pool of {max_temps} variables") - for i in range(max_temps): - temp_name = f"__helper_temp_{i}" - temp_var = builder.alloca(ir.IntType(64), name=temp_name) - temp_var.align = 8 - local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) + for tmp_type, cnt in max_temps.items(): + type_name = _get_type_name(tmp_type) + logger.info(f"Allocating temp pool of {cnt} variables of type {type_name}") + for i in range(cnt): + temp_name = f"__helper_temp_{type_name}_{i}" + temp_var = builder.alloca(tmp_type, name=temp_name) + temp_var.align = _get_alignment(tmp_type) + local_sym_tab[temp_name] = LocalSymbol(temp_var, tmp_type) + logger.debug(f"Allocated temp variable: {temp_name}") def _allocate_for_name(builder, var_name, rval, local_sym_tab): @@ -249,7 +333,58 @@ def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_ ].var = base_ptr # This is repurposing of var to store the pointer of the base type local_sym_tab[struct_var].ir_type = field_ir - actual_ir_type = ir.IntType(64) + # Determine the actual IR type based on the field's type + actual_ir_type = None + + # Check if it's a ctypes primitive + if field.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + # Special case: struct_xdp_md i32 fields should allocate as i64 + # because load_ctx_field will zero-extend them to i64 + if ( + vmlinux_struct_name == "struct_xdp_md" + and field_size_bits == 32 + ): + actual_ir_type = ir.IntType(64) + logger.info( + f"Allocating {var_name} as i64 for i32 field from struct_xdp_md.{field_name} " + "(will be zero-extended during load)" + ) + else: + actual_ir_type = ir.IntType(field_size_bits) + else: + logger.warning( + f"Unusual field size {field_size_bits} bits for {field_name}" + ) + actual_ir_type = ir.IntType(64) + except Exception as e: + logger.warning( + f"Could not determine size for ctypes field {field_name}: {e}" + ) + actual_ir_type = ir.IntType(64) + + # Check if it's a nested vmlinux struct or complex type + elif field.type.__module__ == "vmlinux": + # For pointers to structs, use pointer type (64-bit) + if field.ctype_complex_type is not None and issubclass( + field.ctype_complex_type, ctypes._Pointer + ): + actual_ir_type = ir.IntType(64) # Pointer is always 64-bit + # For embedded structs, this is more complex - might need different handling + else: + logger.warning( + f"Field {field_name} is a nested vmlinux struct, using i64 for now" + ) + actual_ir_type = ir.IntType(64) + else: + logger.warning( + f"Unknown field type module {field.type.__module__} for {field_name}" + ) + actual_ir_type = ir.IntType(64) # Allocate with the actual IR type, not the GlobalVariable var = _allocate_with_type(builder, var_name, actual_ir_type) diff --git a/pythonbpf/assign_pass.py b/pythonbpf/assign_pass.py index a1c2798..0bd48c6 100644 --- a/pythonbpf/assign_pass.py +++ b/pythonbpf/assign_pass.py @@ -152,15 +152,30 @@ def handle_variable_assignment( if val_type != var_type: if isinstance(val_type, Field): logger.info("Handling assignment to struct field") + # Special handling for struct_xdp_md i32 fields that are zero-extended to i64 + # The load_ctx_field already extended them, so val is i64 but val_type.type shows c_uint + if ( + hasattr(val_type, "type") + and val_type.type.__name__ == "c_uint" + and isinstance(var_type, ir.IntType) + and var_type.width == 64 + ): + # This is the struct_xdp_md case - value is already i64 + builder.store(val, var_ptr) + logger.info( + f"Assigned zero-extended struct_xdp_md i32 field to {var_name} (i64)" + ) + return True # TODO: handling only ctype struct fields for now. Handle other stuff too later. - if var_type == ctypes_to_ir(val_type.type.__name__): + elif var_type == ctypes_to_ir(val_type.type.__name__): builder.store(val, var_ptr) logger.info(f"Assigned ctype struct field to {var_name}") return True - logger.error( - f"Failed to assign ctype struct field to {var_name}: {val_type} != {var_type}" - ) - return False + else: + logger.error( + f"Failed to assign ctype struct field to {var_name}: {val_type} != {var_type}" + ) + return False elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType): # Allow implicit int widening if val_type.width < var_type.width: diff --git a/pythonbpf/codegen.py b/pythonbpf/codegen.py index 60628e6..e22b7bd 100644 --- a/pythonbpf/codegen.py +++ b/pythonbpf/codegen.py @@ -86,7 +86,7 @@ def processor(source_code, filename, module): license_processing(tree, module) globals_processing(tree, module) structs_sym_tab = structs_proc(tree, module, bpf_chunks) - map_sym_tab = maps_proc(tree, module, bpf_chunks) + map_sym_tab = maps_proc(tree, module, bpf_chunks, structs_sym_tab) func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab) globals_list_creation(tree, module) @@ -218,13 +218,11 @@ def compile(loglevel=logging.WARNING) -> bool: def BPF(loglevel=logging.WARNING) -> BpfObject: caller_frame = inspect.stack()[1] src = inspect.getsource(caller_frame.frame) - with tempfile.NamedTemporaryFile( - mode="w+", delete=True, suffix=".py" - ) as f, tempfile.NamedTemporaryFile( - mode="w+", delete=True, suffix=".ll" - ) as inter, tempfile.NamedTemporaryFile( - mode="w+", delete=False, suffix=".o" - ) as obj_file: + with ( + tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".py") as f, + tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix=".ll") as inter, + tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".o") as obj_file, + ): f.write(src) f.flush() source = f.name diff --git a/pythonbpf/debuginfo/debug_info_generator.py b/pythonbpf/debuginfo/debug_info_generator.py index cca467e..4b96d22 100644 --- a/pythonbpf/debuginfo/debug_info_generator.py +++ b/pythonbpf/debuginfo/debug_info_generator.py @@ -49,6 +49,10 @@ class DebugInfoGenerator: ) return self._type_cache[key] + def get_uint8_type(self) -> Any: + """Get debug info for signed 8-bit integer""" + return self.get_basic_type("char", 8, dc.DW_ATE_unsigned) + def get_int32_type(self) -> Any: """Get debug info for signed 32-bit integer""" return self.get_basic_type("int", 32, dc.DW_ATE_signed) diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 1d10fcb..a9eab98 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -12,6 +12,7 @@ from .type_normalization import ( get_base_type_and_depth, deref_to_depth, ) +from pythonbpf.vmlinux_parser.assignment_info import Field from .vmlinux_registry import VmlinuxHandlerRegistry logger: Logger = logging.getLogger(__name__) @@ -279,16 +280,45 @@ def _handle_ctypes_call( call_type = expr.func.id expected_type = ctypes_to_ir(call_type) - if val[1] != expected_type: + # Extract the actual IR value and type + # val could be (value, ir_type) or (value, Field) + value, val_type = val + + # If val_type is a Field object (from vmlinux struct), get the actual IR type of the value + if isinstance(val_type, Field): + # The value is already the correct IR value (potentially zero-extended) + # Get the IR type from the value itself + actual_ir_type = value.type + logger.info( + f"Converting vmlinux field {val_type.name} (IR type: {actual_ir_type}) to {call_type}" + ) + else: + actual_ir_type = val_type + + if actual_ir_type != expected_type: # NOTE: We are only considering casting to and from int types for now - if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType): - if val[1].width < expected_type.width: - val = (builder.sext(val[0], expected_type), expected_type) + if isinstance(actual_ir_type, ir.IntType) and isinstance( + expected_type, ir.IntType + ): + if actual_ir_type.width < expected_type.width: + value = builder.sext(value, expected_type) + logger.info( + f"Sign-extended from i{actual_ir_type.width} to i{expected_type.width}" + ) + elif actual_ir_type.width > expected_type.width: + value = builder.trunc(value, expected_type) + logger.info( + f"Truncated from i{actual_ir_type.width} to i{expected_type.width}" + ) else: - val = (builder.trunc(val[0], expected_type), expected_type) + # Same width, just use as-is (e.g., both i64) + pass else: - raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}") - return val + raise ValueError( + f"Type mismatch: expected {expected_type}, got {actual_ir_type} (original type: {val_type})" + ) + + return value, expected_type def _handle_compare( diff --git a/pythonbpf/functions/function_debug_info.py b/pythonbpf/functions/function_debug_info.py index f924ebc..985eb92 100644 --- a/pythonbpf/functions/function_debug_info.py +++ b/pythonbpf/functions/function_debug_info.py @@ -49,17 +49,27 @@ def generate_function_debug_info( "The first argument should always be a pointer to a struct or a void pointer" ) context_debug_info = VmlinuxHandlerRegistry.get_struct_debug_info(annotation.id) + + # Create pointer to context this must be created fresh for each function + # to avoid circular reference issues when the same struct is used in multiple functions pointer_to_context_debug_info = generator.create_pointer_type( context_debug_info, 64 ) + + # Create subroutine type - also fresh for each function subroutine_type = generator.create_subroutine_type( return_type, pointer_to_context_debug_info ) + + # Create local variable - fresh for each function with unique name context_local_variable = generator.create_local_variable_debug_info( leading_argument_name, 1, pointer_to_context_debug_info ) + retained_nodes = [context_local_variable] - print("function name", func_node.name) + logger.info(f"Generating debug info for function {func_node.name}") + + # Create subprogram with is_distinct=True to ensure each function gets unique debug info subprogram_debug_info = generator.create_subprogram( func_node.name, subroutine_type, retained_nodes ) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index f300e12..f78ed92 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) def count_temps_in_call(call_node, local_sym_tab): """Count the number of temporary variables needed for a function call.""" - count = 0 + count = {} is_helper = False # NOTE: We exclude print calls for now @@ -49,21 +49,28 @@ def count_temps_in_call(call_node, local_sym_tab): and call_node.func.id != "print" ): is_helper = True + func_name = call_node.func.id elif isinstance(call_node.func, ast.Attribute): if HelperHandlerRegistry.has_handler(call_node.func.attr): is_helper = True + func_name = call_node.func.attr if not is_helper: - return 0 + return {} # No temps needed - for arg in call_node.args: + for arg_idx in range(len(call_node.args)): # NOTE: Count all non-name arguments # For struct fields, if it is being passed as an argument, # The struct object should already exist in the local_sym_tab - if not isinstance(arg, ast.Name) and not ( + arg = call_node.args[arg_idx] + if isinstance(arg, ast.Name) or ( isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab ): - count += 1 + continue + param_type = HelperHandlerRegistry.get_param_type(func_name, arg_idx) + if isinstance(param_type, ir.PointerType): + pointee_type = param_type.pointee + count[pointee_type] = count.get(pointee_type, 0) + 1 return count @@ -99,11 +106,15 @@ def handle_if_allocation( def allocate_mem( module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab ): - max_temps_needed = 0 + max_temps_needed = {} + + def merge_type_counts(count_dict): + nonlocal max_temps_needed + for typ, cnt in count_dict.items(): + max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt) def update_max_temps_for_stmt(stmt): nonlocal max_temps_needed - temps_needed = 0 if isinstance(stmt, ast.If): for s in stmt.body: @@ -112,10 +123,13 @@ def allocate_mem( update_max_temps_for_stmt(s) return + stmt_temps = {} for node in ast.walk(stmt): if isinstance(node, ast.Call): - temps_needed += count_temps_in_call(node, local_sym_tab) - max_temps_needed = max(max_temps_needed, temps_needed) + call_temps = count_temps_in_call(node, local_sym_tab) + for typ, cnt in call_temps.items(): + stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt + merge_type_counts(stmt_temps) for stmt in body: update_max_temps_for_stmt(stmt) @@ -133,7 +147,9 @@ def allocate_mem( structs_sym_tab, ) elif isinstance(stmt, ast.Assign): - handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) + handle_assign_allocation( + builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab + ) allocate_temp_pool(builder, max_temps_needed, local_sym_tab) diff --git a/pythonbpf/helper/__init__.py b/pythonbpf/helper/__init__.py index 2f9c347..6d38e79 100644 --- a/pythonbpf/helper/__init__.py +++ b/pythonbpf/helper/__init__.py @@ -1,7 +1,21 @@ from .helper_registry import HelperHandlerRegistry from .helper_utils import reset_scratch_pool from .bpf_helper_handler import handle_helper_call, emit_probe_read_kernel_str_call -from .helpers import ktime, pid, deref, comm, probe_read_str, XDP_DROP, XDP_PASS +from .helpers import ( + ktime, + pid, + deref, + comm, + probe_read_str, + random, + probe_read, + smp_processor_id, + uid, + skb_store_bytes, + get_stack, + XDP_DROP, + XDP_PASS, +) # Register the helper handler with expr module @@ -65,6 +79,12 @@ __all__ = [ "deref", "comm", "probe_read_str", + "random", + "probe_read", + "smp_processor_id", + "uid", + "skb_store_bytes", + "get_stack", "XDP_DROP", "XDP_PASS", ] diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index 4adaed2..ba35cc4 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -8,30 +8,43 @@ from .helper_utils import ( get_flags_val, get_data_ptr_and_size, get_buffer_ptr_and_size, - get_char_array_ptr_and_size, get_ptr_from_arg, + get_int_value_from_arg, ) from .printk_formatter import simple_string_print, handle_fstring_print - -from logging import Logger +from pythonbpf.maps import BPFMapType import logging -logger: Logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class BPFHelperID(Enum): BPF_MAP_LOOKUP_ELEM = 1 BPF_MAP_UPDATE_ELEM = 2 BPF_MAP_DELETE_ELEM = 3 + BPF_PROBE_READ = 4 BPF_KTIME_GET_NS = 5 BPF_PRINTK = 6 + BPF_GET_PRANDOM_U32 = 7 + BPF_GET_SMP_PROCESSOR_ID = 8 + BPF_SKB_STORE_BYTES = 9 BPF_GET_CURRENT_PID_TGID = 14 + BPF_GET_CURRENT_UID_GID = 15 BPF_GET_CURRENT_COMM = 16 BPF_PERF_EVENT_OUTPUT = 25 + BPF_GET_STACK = 67 BPF_PROBE_READ_KERNEL_STR = 115 + BPF_RINGBUF_OUTPUT = 130 + BPF_RINGBUF_RESERVE = 131 + BPF_RINGBUF_SUBMIT = 132 + BPF_RINGBUF_DISCARD = 133 -@HelperHandlerRegistry.register("ktime") +@HelperHandlerRegistry.register( + "ktime", + param_types=[], + return_type=ir.IntType(64), +) def bpf_ktime_get_ns_emitter( call, map_ptr, @@ -54,7 +67,11 @@ def bpf_ktime_get_ns_emitter( return result, ir.IntType(64) -@HelperHandlerRegistry.register("lookup") +@HelperHandlerRegistry.register( + "lookup", + param_types=[ir.PointerType(ir.IntType(64))], + return_type=ir.PointerType(ir.IntType(64)), +) def bpf_map_lookup_elem_emitter( call, map_ptr, @@ -96,6 +113,7 @@ def bpf_map_lookup_elem_emitter( return result, ir.PointerType() +# NOTE: This has special handling so we won't reflect the signature here. @HelperHandlerRegistry.register("print") def bpf_printk_emitter( call, @@ -144,7 +162,15 @@ def bpf_printk_emitter( return True -@HelperHandlerRegistry.register("update") +@HelperHandlerRegistry.register( + "update", + param_types=[ + ir.PointerType(ir.IntType(64)), + ir.PointerType(ir.IntType(64)), + ir.IntType(64), + ], + return_type=ir.PointerType(ir.IntType(64)), +) def bpf_map_update_elem_emitter( call, map_ptr, @@ -199,7 +225,11 @@ def bpf_map_update_elem_emitter( return result, None -@HelperHandlerRegistry.register("delete") +@HelperHandlerRegistry.register( + "delete", + param_types=[ir.PointerType(ir.IntType(64))], + return_type=ir.PointerType(ir.IntType(64)), +) def bpf_map_delete_elem_emitter( call, map_ptr, @@ -239,7 +269,11 @@ def bpf_map_delete_elem_emitter( return result, None -@HelperHandlerRegistry.register("comm") +@HelperHandlerRegistry.register( + "comm", + param_types=[ir.PointerType(ir.IntType(8))], + return_type=ir.IntType(64), +) def bpf_get_current_comm_emitter( call, map_ptr, @@ -296,7 +330,11 @@ def bpf_get_current_comm_emitter( return result, None -@HelperHandlerRegistry.register("pid") +@HelperHandlerRegistry.register( + "pid", + param_types=[], + return_type=ir.IntType(64), +) def bpf_get_current_pid_tgid_emitter( call, map_ptr, @@ -318,12 +356,12 @@ def bpf_get_current_pid_tgid_emitter( result = builder.call(fn_ptr, [], tail=False) # Extract the lower 32 bits (PID) using bitwise AND with 0xFFFFFFFF + # TODO: return both PID and TGID if we end up needing TGID somewhere mask = ir.Constant(ir.IntType(64), 0xFFFFFFFF) pid = builder.and_(result, mask) return pid, ir.IntType(64) -@HelperHandlerRegistry.register("output") def bpf_perf_event_output_handler( call, map_ptr, @@ -334,6 +372,10 @@ def bpf_perf_event_output_handler( struct_sym_tab=None, map_sym_tab=None, ): + """ + Emit LLVM IR for bpf_perf_event_output helper function call. + """ + if len(call.args) != 1: raise ValueError( f"Perf event output expects exactly one argument, got {len(call.args)}" @@ -371,6 +413,98 @@ def bpf_perf_event_output_handler( return result, None +def bpf_ringbuf_output_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_ringbuf_output helper function call. + """ + + if len(call.args) != 1: + raise ValueError( + f"Ringbuf output expects exactly one argument, got {len(call.args)}" + ) + data_arg = call.args[0] + data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab) + flags_val = ir.Constant(ir.IntType(64), 0) + + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + data_void_ptr = builder.bitcast(data_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.IntType(64), + [ + ir.PointerType(), + ir.PointerType(), + ir.IntType(64), + ir.IntType(64), + ], + var_arg=False, + ) + fn_ptr_type = ir.PointerType(fn_type) + + # helper id + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_RINGBUF_OUTPUT.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + result = builder.call( + fn_ptr, [map_void_ptr, data_void_ptr, size_val, flags_val], tail=False + ) + return result, None + + +@HelperHandlerRegistry.register( + "output", + param_types=[ir.PointerType(ir.IntType(8))], + return_type=ir.IntType(64), +) +def handle_output_helper( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Route output helper to the appropriate emitter based on map type. + """ + match map_sym_tab[map_ptr.name].type: + case BPFMapType.PERF_EVENT_ARRAY: + return bpf_perf_event_output_handler( + call, + map_ptr, + module, + builder, + func, + local_sym_tab, + struct_sym_tab, + map_sym_tab, + ) + case BPFMapType.RINGBUF: + return bpf_ringbuf_output_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab, + struct_sym_tab, + map_sym_tab, + ) + case _: + logger.error("Unsupported map type for output helper.") + raise NotImplementedError("Output helper for this map type is not implemented.") + + def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr): """Emit LLVM IR call to bpf_probe_read_kernel_str""" @@ -398,7 +532,14 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr): return result -@HelperHandlerRegistry.register("probe_read_str") +@HelperHandlerRegistry.register( + "probe_read_str", + param_types=[ + ir.PointerType(ir.IntType(8)), + ir.PointerType(ir.IntType(8)), + ], + return_type=ir.IntType(64), +) def bpf_probe_read_kernel_str_emitter( call, map_ptr, @@ -417,8 +558,8 @@ def bpf_probe_read_kernel_str_emitter( ) # Get destination buffer (char array -> i8*) - dst_ptr, dst_size = get_char_array_ptr_and_size( - call.args[0], builder, local_sym_tab, struct_sym_tab + dst_ptr, dst_size = get_or_create_ptr_from_arg( + func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab ) # Get source pointer (evaluate expression) @@ -433,6 +574,430 @@ def bpf_probe_read_kernel_str_emitter( return result, ir.IntType(64) +@HelperHandlerRegistry.register( + "random", + param_types=[], + return_type=ir.IntType(32), +) +def bpf_get_prandom_u32_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_prandom_u32 helper function call. + """ + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_PRANDOM_U32.value) + fn_type = ir.FunctionType(ir.IntType(32), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + return result, ir.IntType(32) + + +@HelperHandlerRegistry.register( + "probe_read", + param_types=[ + ir.PointerType(ir.IntType(8)), + ir.IntType(32), + ir.PointerType(ir.IntType(8)), + ], + return_type=ir.IntType(64), +) +def bpf_probe_read_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_probe_read helper function + """ + + if len(call.args) != 3: + logger.warn("Expected 3 args for probe_read helper") + return + dst_ptr = get_or_create_ptr_from_arg( + func, + module, + call.args[0], + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ir.IntType(8), + ) + size_val = get_int_value_from_arg( + call.args[1], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + src_ptr = get_or_create_ptr_from_arg( + func, + module, + call.args[2], + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ir.IntType(8), + ) + fn_type = ir.FunctionType( + ir.IntType(64), + [ir.PointerType(), ir.IntType(32), ir.PointerType()], + var_arg=False, + ) + fn_ptr = builder.inttoptr( + ir.Constant(ir.IntType(64), BPFHelperID.BPF_PROBE_READ.value), + ir.PointerType(fn_type), + ) + result = builder.call( + fn_ptr, + [ + builder.bitcast(dst_ptr, ir.PointerType()), + builder.trunc(size_val, ir.IntType(32)), + builder.bitcast(src_ptr, ir.PointerType()), + ], + tail=False, + ) + logger.info(f"Emitted bpf_probe_read (size={size_val})") + return result, ir.IntType(64) + + +@HelperHandlerRegistry.register( + "smp_processor_id", + param_types=[], + return_type=ir.IntType(32), +) +def bpf_get_smp_processor_id_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_smp_processor_id helper function call. + """ + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_SMP_PROCESSOR_ID.value) + fn_type = ir.FunctionType(ir.IntType(32), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + logger.info("Emitted bpf_get_smp_processor_id call") + return result, ir.IntType(32) + + +@HelperHandlerRegistry.register( + "uid", + param_types=[], + return_type=ir.IntType(64), +) +def bpf_get_current_uid_gid_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_current_uid_gid helper function call. + """ + helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_UID_GID.value) + fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) + fn_ptr_type = ir.PointerType(fn_type) + fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) + result = builder.call(fn_ptr, [], tail=False) + + # Extract the lower 32 bits (UID) using bitwise AND with 0xFFFFFFFF + # TODO: return both UID and GID if we end up needing GID somewhere + mask = ir.Constant(ir.IntType(64), 0xFFFFFFFF) + pid = builder.and_(result, mask) + return pid, ir.IntType(64) + + +@HelperHandlerRegistry.register( + "skb_store_bytes", + param_types=[ + ir.IntType(32), + ir.PointerType(ir.IntType(8)), + ir.IntType(32), + ir.IntType(64), + ], + return_type=ir.IntType(64), +) +def bpf_skb_store_bytes_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_skb_store_bytes helper function call. + Expected call signature: skb_store_bytes(skb, offset, from, len, flags) + """ + + args_signature = [ + ir.PointerType(), # skb pointer + ir.IntType(32), # offset + ir.PointerType(), # from + ir.IntType(32), # len + ir.IntType(64), # flags + ] + + if len(call.args) not in (3, 4): + raise ValueError( + f"skb_store_bytes expects 3 or 4 args (offset, from, len, flags), got {len(call.args)}" + ) + + skb_ptr = func.args[0] # First argument to the function is skb + offset_val = get_int_value_from_arg( + call.args[0], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + from_ptr = get_or_create_ptr_from_arg( + func, + module, + call.args[1], + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + args_signature[2], + ) + len_val = get_int_value_from_arg( + call.args[2], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + if len(call.args) == 4: + flags_val = get_flags_val(call.args[3], builder, local_sym_tab) + else: + flags_val = 0 + if isinstance(flags_val, int): + flags = ir.Constant(ir.IntType(64), flags_val) + else: + flags = flags_val + fn_type = ir.FunctionType( + ir.IntType(64), + args_signature, + var_arg=False, + ) + fn_ptr = builder.inttoptr( + ir.Constant(ir.IntType(64), BPFHelperID.BPF_SKB_STORE_BYTES.value), + ir.PointerType(fn_type), + ) + result = builder.call( + fn_ptr, + [ + builder.bitcast(skb_ptr, ir.PointerType()), + builder.trunc(offset_val, ir.IntType(32)), + builder.bitcast(from_ptr, ir.PointerType()), + builder.trunc(len_val, ir.IntType(32)), + flags, + ], + tail=False, + ) + logger.info("Emitted bpf_skb_store_bytes call") + return result, ir.IntType(64) + + +@HelperHandlerRegistry.register( + "reserve", + param_types=[ir.IntType(64)], + return_type=ir.PointerType(ir.IntType(8)), +) +def bpf_ringbuf_reserve_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_ringbuf_reserve helper function call. + Expected call signature: ringbuf.reserve(size) + """ + + if len(call.args) != 1: + raise ValueError( + f"ringbuf.reserve expects exactly one argument (size), got {len(call.args)}" + ) + + size_val = get_int_value_from_arg( + call.args[0], + func, + module, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ) + + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.PointerType(ir.IntType(8)), + [ir.PointerType(), ir.IntType(64)], + var_arg=False, + ) + fn_ptr_type = ir.PointerType(fn_type) + + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_RINGBUF_RESERVE.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + result = builder.call(fn_ptr, [map_void_ptr, size_val], tail=False) + + return result, ir.PointerType(ir.IntType(8)) + + +@HelperHandlerRegistry.register( + "submit", + param_types=[ir.PointerType(ir.IntType(8)), ir.IntType(64)], + return_type=ir.VoidType(), +) +def bpf_ringbuf_submit_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_ringbuf_submit helper function call. + Expected call signature: ringbuf.submit(data, flags=0) + """ + + if len(call.args) not in (1, 2): + raise ValueError( + f"ringbuf.submit expects 1 or 2 args (data, flags), got {len(call.args)}" + ) + + data_arg = call.args[0] + flags_arg = call.args[1] if len(call.args) == 2 else None + + data_ptr = get_or_create_ptr_from_arg( + func, + module, + data_arg, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab, + ir.PointerType(ir.IntType(8)), + ) + + flags_const = get_flags_val(flags_arg, builder, local_sym_tab) + if isinstance(flags_const, int): + flags_const = ir.Constant(ir.IntType(64), flags_const) + + map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.VoidType(), + [ir.PointerType(), ir.PointerType(), ir.IntType(64)], + var_arg=False, + ) + fn_ptr_type = ir.PointerType(fn_type) + + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_RINGBUF_SUBMIT.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + result = builder.call(fn_ptr, [map_void_ptr, data_ptr, flags_const], tail=False) + + return result, None + + +@HelperHandlerRegistry.register( + "get_stack", + param_types=[ir.PointerType(ir.IntType(8)), ir.IntType(64)], + return_type=ir.IntType(64), +) +def bpf_get_stack_emitter( + call, + map_ptr, + module, + builder, + func, + local_sym_tab=None, + struct_sym_tab=None, + map_sym_tab=None, +): + """ + Emit LLVM IR for bpf_get_stack helper function call. + """ + if len(call.args) not in (1, 2): + raise ValueError( + f"get_stack expects atmost two arguments (buf, flags), got {len(call.args)}" + ) + ctx_ptr = func.args[0] # First argument to the function is ctx + buf_arg = call.args[0] + flags_arg = call.args[1] if len(call.args) == 2 else None + buf_ptr, buf_size = get_buffer_ptr_and_size( + buf_arg, builder, local_sym_tab, struct_sym_tab + ) + flags_val = get_flags_val(flags_arg, builder, local_sym_tab) + if isinstance(flags_val, int): + flags_val = ir.Constant(ir.IntType(64), flags_val) + + buf_void_ptr = builder.bitcast(buf_ptr, ir.PointerType()) + fn_type = ir.FunctionType( + ir.IntType(64), + [ + ir.PointerType(ir.IntType(8)), + ir.PointerType(), + ir.IntType(64), + ir.IntType(64), + ], + var_arg=False, + ) + fn_ptr_type = ir.PointerType(fn_type) + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_STACK.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + result = builder.call( + fn_ptr, + [ctx_ptr, buf_void_ptr, ir.Constant(ir.IntType(64), buf_size), flags_val], + tail=False, + ) + return result, ir.IntType(64) + + def handle_helper_call( call, module, @@ -487,6 +1052,6 @@ def handle_helper_call( if not map_sym_tab or map_name not in map_sym_tab: raise ValueError(f"Map '{map_name}' not found in symbol table") - return invoke_helper(method_name, map_sym_tab[map_name]) + return invoke_helper(method_name, map_sym_tab[map_name].sym) return None diff --git a/pythonbpf/helper/helper_registry.py b/pythonbpf/helper/helper_registry.py index 476e3b6..0e09d70 100644 --- a/pythonbpf/helper/helper_registry.py +++ b/pythonbpf/helper/helper_registry.py @@ -1,17 +1,31 @@ +from dataclasses import dataclass +from llvmlite import ir from typing import Callable +@dataclass +class HelperSignature: + """Signature of a BPF helper function""" + + arg_types: list[ir.Type] + return_type: ir.Type + func: Callable + + class HelperHandlerRegistry: """Registry for BPF helpers""" - _handlers: dict[str, Callable] = {} + _handlers: dict[str, HelperSignature] = {} @classmethod - def register(cls, helper_name): + def register(cls, helper_name, param_types=None, return_type=None): """Decorator to register a handler function for a helper""" def decorator(func): - cls._handlers[helper_name] = func + helper_sig = HelperSignature( + arg_types=param_types, return_type=return_type, func=func + ) + cls._handlers[helper_name] = helper_sig return func return decorator @@ -19,9 +33,29 @@ class HelperHandlerRegistry: @classmethod def get_handler(cls, helper_name): """Get the handler function for a helper""" - return cls._handlers.get(helper_name) + handler = cls._handlers.get(helper_name) + return handler.func if handler else None @classmethod def has_handler(cls, helper_name): """Check if a handler function is registered for a helper""" return helper_name in cls._handlers + + @classmethod + def get_signature(cls, helper_name): + """Get the signature of a helper function""" + return cls._handlers.get(helper_name) + + @classmethod + def get_param_type(cls, helper_name, index): + """Get the type of a parameter of a helper function by the index""" + signature = cls.get_signature(helper_name) + if signature and signature.arg_types and 0 <= index < len(signature.arg_types): + return signature.arg_types[index] + return None + + @classmethod + def get_return_type(cls, helper_name): + """Get the return type of a helper function""" + signature = cls.get_signature(helper_name) + return signature.return_type if signature else None diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index fdfd452..aecb5e9 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -14,26 +14,43 @@ class ScratchPoolManager: """Manage the temporary helper variables in local_sym_tab""" def __init__(self): - self._counter = 0 + self._counters = {} @property def counter(self): - return self._counter + return sum(self._counters.values()) def reset(self): - self._counter = 0 + self._counters.clear() logger.debug("Scratch pool counter reset to 0") - def get_next_temp(self, local_sym_tab): - temp_name = f"__helper_temp_{self._counter}" - self._counter += 1 + def _get_type_name(self, ir_type): + if isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def get_next_temp(self, local_sym_tab, expected_type=None): + # Default to i64 if no expected type provided + type_name = self._get_type_name(expected_type) if expected_type else "i64" + if type_name not in self._counters: + self._counters[type_name] = 0 + + counter = self._counters[type_name] + temp_name = f"__helper_temp_{type_name}_{counter}" + self._counters[type_name] += 1 if temp_name not in local_sym_tab: raise ValueError( f"Scratch pool exhausted or inadequate: {temp_name}. " - f"Current counter: {self._counter}" + f"Type: {type_name} Counter: {counter}" ) + logger.debug(f"Using {temp_name} for type {type_name}") return local_sym_tab[temp_name].var, temp_name @@ -60,24 +77,73 @@ def get_var_ptr_from_name(var_name, local_sym_tab): def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): """Create a pointer to an integer constant.""" - # Default to 64-bit integer - ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) + int_type = ir.IntType(int_width) + ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type) logger.info(f"Using temp variable '{temp_name}' for int constant {value}") - const_val = ir.Constant(ir.IntType(int_width), value) + const_val = ir.Constant(int_type, value) builder.store(const_val, ptr) return ptr def get_or_create_ptr_from_arg( - func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None + func, + module, + arg, + builder, + local_sym_tab, + map_sym_tab, + struct_sym_tab=None, + expected_type=None, ): """Extract or create pointer from the call arguments.""" + logger.info(f"Getting pointer from arg: {ast.dump(arg)}") + sz = None if isinstance(arg, ast.Name): + # Stack space is already allocated ptr = get_var_ptr_from_name(arg.id, local_sym_tab) elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): - ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab) + int_width = 64 # Default to i64 + if expected_type and isinstance(expected_type, ir.IntType): + int_width = expected_type.width + ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width) + elif isinstance(arg, ast.Attribute): + # A struct field + struct_name = arg.value.id + field_name = arg.attr + + if not local_sym_tab or struct_name not in local_sym_tab: + raise ValueError(f"Struct '{struct_name}' not found") + + struct_type = local_sym_tab[struct_name].metadata + if not struct_sym_tab or struct_type not in struct_sym_tab: + raise ValueError(f"Struct type '{struct_type}' not found") + + struct_info = struct_sym_tab[struct_type] + if field_name not in struct_info.fields: + raise ValueError( + f"Field '{field_name}' not found in struct '{struct_name}'" + ) + + field_type = struct_info.field_type(field_name) + struct_ptr = local_sym_tab[struct_name].var + + # Special handling for char arrays + if ( + isinstance(field_type, ir.ArrayType) + and isinstance(field_type.element, ir.IntType) + and field_type.element.width == 8 + ): + ptr, sz = get_char_array_ptr_and_size( + arg, builder, local_sym_tab, struct_sym_tab + ) + if not ptr: + raise ValueError("Failed to get char array pointer from struct field") + else: + ptr = struct_info.gep(builder, struct_ptr, field_name) + else: + # NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop # Evaluate the expression and store the result in a temp variable val = get_operand_value( func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab @@ -85,13 +151,20 @@ def get_or_create_ptr_from_arg( if val is None: raise ValueError("Failed to evaluate expression for helper arg.") - # NOTE: We assume the result is an int64 for now - # if isinstance(arg, ast.Attribute): - # return val - ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab) + ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type) logger.info(f"Using temp variable '{temp_name}' for expression result") + if ( + isinstance(val.type, ir.IntType) + and expected_type + and val.type.width > expected_type.width + ): + val = builder.trunc(val, expected_type) builder.store(val, ptr) + # NOTE: For char arrays, also return size + if sz: + return ptr, sz + return ptr @@ -214,7 +287,10 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab) field_type = struct_info.field_type(field_name) if not _is_char_array(field_type): - raise ValueError("Expected char array field") + logger.info( + "Field is not a char array, falling back to int or ptr detection" + ) + return None, 0 struct_ptr = local_sym_tab[var_name].var field_ptr = struct_info.gep(builder, struct_ptr, field_name) @@ -274,3 +350,23 @@ def get_ptr_from_arg( raise ValueError(f"Expected pointer type, got {val_type}") return val, val_type + + +def get_int_value_from_arg( + arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab +): + """Evaluate argument and return integer value""" + + result = eval_expr( + func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab + ) + + if not result: + raise ValueError("Failed to evaluate argument") + + val, val_type = result + + if not isinstance(val_type, ir.IntType): + raise ValueError(f"Expected integer type, got {val_type}") + + return val diff --git a/pythonbpf/helper/helpers.py b/pythonbpf/helper/helpers.py index cb1a8e1..c80d57d 100644 --- a/pythonbpf/helper/helpers.py +++ b/pythonbpf/helper/helpers.py @@ -27,6 +27,36 @@ def probe_read_str(dst, src): return ctypes.c_int64(0) +def random(): + """get a pseudorandom u32 number""" + return ctypes.c_int32(0) + + +def probe_read(dst, size, src): + """Safely read data from kernel memory""" + return ctypes.c_int64(0) + + +def smp_processor_id(): + """get the current CPU id""" + return ctypes.c_int32(0) + + +def uid(): + """get current user id""" + return ctypes.c_int32(0) + + +def skb_store_bytes(offset, from_buf, size, flags=0): + """store bytes into a socket buffer""" + return ctypes.c_int64(0) + + +def get_stack(buf, flags=0): + """get the current stack trace""" + return ctypes.c_int64(0) + + XDP_ABORTED = ctypes.c_int64(0) XDP_DROP = ctypes.c_int64(1) XDP_PASS = ctypes.c_int64(2) diff --git a/pythonbpf/helper/printk_formatter.py b/pythonbpf/helper/printk_formatter.py index a18f135..721213e 100644 --- a/pythonbpf/helper/printk_formatter.py +++ b/pythonbpf/helper/printk_formatter.py @@ -4,6 +4,7 @@ import logging from llvmlite import ir from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry +from pythonbpf.helper.helper_utils import get_char_array_ptr_and_size logger = logging.getLogger(__name__) @@ -219,11 +220,12 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta """Evaluate and prepare an expression to use as an arg for bpf_printk.""" # Special case: struct field char array needs pointer to first element - char_array_ptr = _get_struct_char_array_ptr( - expr, builder, local_sym_tab, struct_sym_tab - ) - if char_array_ptr: - return char_array_ptr + if isinstance(expr, ast.Attribute): + char_array_ptr, _ = get_char_array_ptr_and_size( + expr, builder, local_sym_tab, struct_sym_tab + ) + if char_array_ptr: + return char_array_ptr # Regular expression evaluation val, _ = eval_expr(func, module, builder, expr, local_sym_tab, None, struct_sym_tab) @@ -242,52 +244,6 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta return ir.Constant(ir.IntType(64), 0) -def _get_struct_char_array_ptr(expr, builder, local_sym_tab, struct_sym_tab): - """Get pointer to first element of char array in struct field, or None.""" - if not (isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name)): - return None - - var_name = expr.value.id - field_name = expr.attr - - # Check if it's a valid struct field - if not ( - local_sym_tab - and var_name in local_sym_tab - and struct_sym_tab - and local_sym_tab[var_name].metadata in struct_sym_tab - ): - return None - - struct_type = local_sym_tab[var_name].metadata - struct_info = struct_sym_tab[struct_type] - - if field_name not in struct_info.fields: - return None - - field_type = struct_info.field_type(field_name) - - # Check if it's a char array - is_char_array = ( - isinstance(field_type, ir.ArrayType) - and isinstance(field_type.element, ir.IntType) - and field_type.element.width == 8 - ) - - if not is_char_array: - return None - - # Get field pointer and GEP to first element: [N x i8]* -> i8* - struct_ptr = local_sym_tab[var_name].var - field_ptr = struct_info.gep(builder, struct_ptr, field_name) - - return builder.gep( - field_ptr, - [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)], - inbounds=True, - ) - - def _handle_pointer_arg(val, func, builder): """Convert pointer type for bpf_printk.""" target, depth = get_base_type_and_depth(val.type) diff --git a/pythonbpf/maps/__init__.py b/pythonbpf/maps/__init__.py index 48fc9ff..eb2007d 100644 --- a/pythonbpf/maps/__init__.py +++ b/pythonbpf/maps/__init__.py @@ -1,4 +1,5 @@ -from .maps import HashMap, PerfEventArray, RingBuf +from .maps import HashMap, PerfEventArray, RingBuffer from .maps_pass import maps_proc +from .map_types import BPFMapType -__all__ = ["HashMap", "PerfEventArray", "maps_proc", "RingBuf"] +__all__ = ["HashMap", "PerfEventArray", "maps_proc", "RingBuffer", "BPFMapType"] diff --git a/pythonbpf/maps/map_debug_info.py b/pythonbpf/maps/map_debug_info.py index b1a25d9..77a5888 100644 --- a/pythonbpf/maps/map_debug_info.py +++ b/pythonbpf/maps/map_debug_info.py @@ -1,22 +1,31 @@ +import logging +from llvmlite import ir from pythonbpf.debuginfo import DebugInfoGenerator from .map_types import BPFMapType +logger: logging.Logger = logging.getLogger(__name__) -def create_map_debug_info(module, map_global, map_name, map_params): + +def create_map_debug_info(module, map_global, map_name, map_params, structs_sym_tab): """Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY""" generator = DebugInfoGenerator(module) - + logger.info(f"Creating debug info for map {map_name} with params {map_params}") uint_type = generator.get_uint32_type() - ulong_type = generator.get_uint64_type() array_type = generator.create_array_type( uint_type, map_params.get("type", BPFMapType.UNSPEC).value ) type_ptr = generator.create_pointer_type(array_type, 64) key_ptr = generator.create_pointer_type( - array_type if "key_size" in map_params else ulong_type, 64 + array_type + if "key_size" in map_params + else _get_key_val_dbg_type(map_params.get("key"), generator, structs_sym_tab), + 64, ) value_ptr = generator.create_pointer_type( - array_type if "value_size" in map_params else ulong_type, 64 + array_type + if "value_size" in map_params + else _get_key_val_dbg_type(map_params.get("value"), generator, structs_sym_tab), + 64, ) elements_arr = [] @@ -64,7 +73,13 @@ def create_map_debug_info(module, map_global, map_name, map_params): return global_var -def create_ringbuf_debug_info(module, map_global, map_name, map_params): +# TODO: This should not be exposed outside of the module. +# Ideally we should expose a single create_map_debug_info function that handles all map types. +# We can probably use a registry pattern to register different map types and their debug info generators. +# map_params["type"] will be used to determine which generator to use. +def create_ringbuf_debug_info( + module, map_global, map_name, map_params, structs_sym_tab +): """Generate debug information metadata for BPF RINGBUF map""" generator = DebugInfoGenerator(module) @@ -91,3 +106,65 @@ def create_ringbuf_debug_info(module, map_global, map_name, map_params): ) map_global.set_metadata("dbg", global_var) return global_var + + +def _get_key_val_dbg_type(name, generator, structs_sym_tab): + """Get the debug type for key/value based on type object""" + + if not name: + logger.warn("No name provided for key/value type, defaulting to uint64") + return generator.get_uint64_type() + + type_obj = structs_sym_tab.get(name) + if type_obj: + return _get_struct_debug_type(type_obj, generator, structs_sym_tab) + + # Fallback to basic types + logger.info(f"No struct named {name}, falling back to basic type") + + # NOTE: Only handling int and long for now + if name in ["c_int32", "c_uint32"]: + return generator.get_uint32_type() + + # Default fallback for now + return generator.get_uint64_type() + + +def _get_struct_debug_type(struct_obj, generator, structs_sym_tab): + """Recursively create debug type for struct""" + elements_arr = [] + for fld in struct_obj.fields.keys(): + fld_type = struct_obj.field_type(fld) + if isinstance(fld_type, ir.IntType): + if fld_type.width == 32: + fld_dbg_type = generator.get_uint32_type() + else: + # NOTE: Assuming 64-bit for all other int types + fld_dbg_type = generator.get_uint64_type() + elif isinstance(fld_type, ir.ArrayType): + # NOTE: Array types have u8 elements only for now + # Debug info generation should fail for other types + elem_type = fld_type.element + if isinstance(elem_type, ir.IntType) and elem_type.width == 8: + char_type = generator.get_uint8_type() + fld_dbg_type = generator.create_array_type(char_type, fld_type.count) + else: + logger.warning( + f"Array element type {str(elem_type)} not supported for debug info, skipping" + ) + continue + else: + # NOTE: Only handling int and char arrays for now + logger.warning( + f"Field type {str(fld_type)} not supported for debug info, skipping" + ) + continue + + member = generator.create_struct_member( + fld, fld_dbg_type, struct_obj.field_size(fld) + ) + elements_arr.append(member) + struct_type = generator.create_struct_type( + elements_arr, struct_obj.size, is_distinct=True + ) + return struct_type diff --git a/pythonbpf/maps/maps.py b/pythonbpf/maps/maps.py index a2d7c21..583e957 100644 --- a/pythonbpf/maps/maps.py +++ b/pythonbpf/maps/maps.py @@ -36,11 +36,14 @@ class PerfEventArray: pass # Placeholder for output method -class RingBuf: +class RingBuffer: def __init__(self, max_entries): self.max_entries = max_entries - def reserve(self, size: int, flags=0): + def output(self, data, flags=0): + pass + + def reserve(self, size: int): if size > self.max_entries: raise ValueError("size cannot be greater than set maximum entries") return 0 @@ -48,4 +51,7 @@ class RingBuf: def submit(self, data, flags=0): pass + def discard(self, data, flags=0): + pass + # add discard, output and also give names to flags and stuff diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index 85837d7..ac498dc 100644 --- a/pythonbpf/maps/maps_pass.py +++ b/pythonbpf/maps/maps_pass.py @@ -3,7 +3,7 @@ import logging from logging import Logger from llvmlite import ir -from .maps_utils import MapProcessorRegistry +from .maps_utils import MapProcessorRegistry, MapSymbol from .map_types import BPFMapType from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry @@ -12,13 +12,15 @@ from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry logger: Logger = logging.getLogger(__name__) -def maps_proc(tree, module, chunks): +def maps_proc(tree, module, chunks, structs_sym_tab): """Process all functions decorated with @map to find BPF maps""" map_sym_tab = {} for func_node in chunks: if is_map(func_node): logger.info(f"Found BPF map: {func_node.name}") - map_sym_tab[func_node.name] = process_bpf_map(func_node, module) + map_sym_tab[func_node.name] = process_bpf_map( + func_node, module, structs_sym_tab + ) return map_sym_tab @@ -46,7 +48,7 @@ def create_bpf_map(module, map_name, map_params): map_global.align = 8 logger.info(f"Created BPF map: {map_name} with params {map_params}") - return map_global + return MapSymbol(type=map_params["type"], sym=map_global, params=map_params) def _parse_map_params(rval, expected_args=None): @@ -60,7 +62,8 @@ def _parse_map_params(rval, expected_args=None): if i < len(rval.args): arg = rval.args[i] if isinstance(arg, ast.Name): - params[arg_name] = arg.id + result = _get_vmlinux_enum(handler, arg.id) + params[arg_name] = result if result is not None else arg.id elif isinstance(arg, ast.Constant): params[arg_name] = arg.value @@ -68,33 +71,48 @@ def _parse_map_params(rval, expected_args=None): for keyword in rval.keywords: if isinstance(keyword.value, ast.Name): name = keyword.value.id - if handler and handler.is_vmlinux_enum(name): - result = handler.get_vmlinux_enum_value(name) - params[keyword.arg] = result if result is not None else name - else: - params[keyword.arg] = name + result = _get_vmlinux_enum(handler, name) + params[keyword.arg] = result if result is not None else name elif isinstance(keyword.value, ast.Constant): params[keyword.arg] = keyword.value.value return params -@MapProcessorRegistry.register("RingBuf") -def process_ringbuf_map(map_name, rval, module): +def _get_vmlinux_enum(handler, name): + if handler and handler.is_vmlinux_enum(name): + return handler.get_vmlinux_enum_value(name) + + +@MapProcessorRegistry.register("RingBuffer") +def process_ringbuf_map(map_name, rval, module, structs_sym_tab): """Process a BPF_RINGBUF map declaration""" logger.info(f"Processing Ringbuf: {map_name}") map_params = _parse_map_params(rval, expected_args=["max_entries"]) map_params["type"] = BPFMapType.RINGBUF + # NOTE: constraints borrowed from https://docs.ebpf.io/linux/map-type/BPF_MAP_TYPE_RINGBUF/ + max_entries = map_params.get("max_entries") + if ( + not isinstance(max_entries, int) + or max_entries < 4096 + or (max_entries & (max_entries - 1)) != 0 + ): + raise ValueError( + "Ringbuf max_entries must be a power of two greater than or equal to the page size (4096)" + ) + logger.info(f"Ringbuf map parameters: {map_params}") map_global = create_bpf_map(module, map_name, map_params) - create_ringbuf_debug_info(module, map_global, map_name, map_params) + create_ringbuf_debug_info( + module, map_global.sym, map_name, map_params, structs_sym_tab + ) return map_global @MapProcessorRegistry.register("HashMap") -def process_hash_map(map_name, rval, module): +def process_hash_map(map_name, rval, module, structs_sym_tab): """Process a BPF_HASH map declaration""" logger.info(f"Processing HashMap: {map_name}") map_params = _parse_map_params(rval, expected_args=["key", "value", "max_entries"]) @@ -103,12 +121,12 @@ def process_hash_map(map_name, rval, module): logger.info(f"Map parameters: {map_params}") map_global = create_bpf_map(module, map_name, map_params) # Generate debug info for BTF - create_map_debug_info(module, map_global, map_name, map_params) + create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab) return map_global @MapProcessorRegistry.register("PerfEventArray") -def process_perf_event_map(map_name, rval, module): +def process_perf_event_map(map_name, rval, module, structs_sym_tab): """Process a BPF_PERF_EVENT_ARRAY map declaration""" logger.info(f"Processing PerfEventArray: {map_name}") map_params = _parse_map_params(rval, expected_args=["key_size", "value_size"]) @@ -117,11 +135,11 @@ def process_perf_event_map(map_name, rval, module): logger.info(f"Map parameters: {map_params}") map_global = create_bpf_map(module, map_name, map_params) # Generate debug info for BTF - create_map_debug_info(module, map_global, map_name, map_params) + create_map_debug_info(module, map_global.sym, map_name, map_params) return map_global -def process_bpf_map(func_node, module): +def process_bpf_map(func_node, module, structs_sym_tab): """Process a BPF map (a function decorated with @map)""" map_name = func_node.name logger.info(f"Processing BPF map: {map_name}") @@ -140,7 +158,7 @@ def process_bpf_map(func_node, module): if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name): handler = MapProcessorRegistry.get_processor(rval.func.id) if handler: - return handler(map_name, rval, module) + return handler(map_name, rval, module, structs_sym_tab) else: logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap") return process_hash_map(map_name, rval, module) diff --git a/pythonbpf/maps/maps_utils.py b/pythonbpf/maps/maps_utils.py index ee3ad08..194b408 100644 --- a/pythonbpf/maps/maps_utils.py +++ b/pythonbpf/maps/maps_utils.py @@ -1,5 +1,17 @@ from collections.abc import Callable +from dataclasses import dataclass +from llvmlite import ir from typing import Any +from .map_types import BPFMapType + + +@dataclass +class MapSymbol: + """Class representing a symbol on the map""" + + type: BPFMapType + sym: ir.GlobalVariable + params: dict[str, Any] | None = None class MapProcessorRegistry: diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index a6834a9..fd589ae 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -16,6 +16,8 @@ mapping = { "c_long": ir.IntType(64), "c_ulong": ir.IntType(64), "c_longlong": ir.IntType(64), + "c_uint": ir.IntType(32), + "c_int": ir.IntType(32), # Not so sure about this one "str": ir.PointerType(ir.IntType(8)), } diff --git a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py index f641e80..30f3058 100644 --- a/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py +++ b/pythonbpf/vmlinux_parser/vmlinux_exports_handler.py @@ -1,6 +1,6 @@ import logging from typing import Any - +import ctypes from llvmlite import ir from pythonbpf.local_symbol import LocalSymbol @@ -94,12 +94,11 @@ class VmlinuxHandler: f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}" ) python_type: type = var_info.metadata - globvar_ir, field_data = self.get_field_type( - python_type.__name__, field_name - ) + struct_name = python_type.__name__ + globvar_ir, field_data = self.get_field_type(struct_name, field_name) builder.function.args[0].type = ir.PointerType(ir.IntType(8)) field_ptr = self.load_ctx_field( - builder, builder.function.args[0], globvar_ir + builder, builder.function.args[0], globvar_ir, field_data, struct_name ) # Return pointer to field and field type return field_ptr, field_data @@ -107,7 +106,7 @@ class VmlinuxHandler: raise RuntimeError("Variable accessed not found in symbol table") @staticmethod - def load_ctx_field(builder, ctx_arg, offset_global): + def load_ctx_field(builder, ctx_arg, offset_global, field_data, struct_name=None): """ Generate LLVM IR to load a field from BPF context using offset. @@ -115,9 +114,10 @@ class VmlinuxHandler: builder: llvmlite IRBuilder instance ctx_arg: The context pointer argument (ptr/i8*) offset_global: Global variable containing the field offset (i64) - + field_data: contains data about the field + struct_name: Name of the struct being accessed (optional) Returns: - The loaded value (i64 register) + The loaded value (i64 register or appropriately sized) """ # Load the offset value @@ -162,13 +162,61 @@ class VmlinuxHandler: passthrough_fn, [ir.Constant(ir.IntType(32), 0), field_ptr], tail=True ) - # Bitcast to i64* (assuming field is 64-bit, adjust if needed) - i64_ptr_type = ir.PointerType(ir.IntType(64)) - typed_ptr = builder.bitcast(verified_ptr, i64_ptr_type) + # Determine the appropriate IR type based on field information + int_width = 64 # Default to 64-bit + needs_zext = False # Track if we need zero-extension for xdp_md + + if field_data is not None: + # Try to determine the size from field metadata + if field_data.type.__module__ == ctypes.__name__: + try: + field_size_bytes = ctypes.sizeof(field_data.type) + field_size_bits = field_size_bytes * 8 + + if field_size_bits in [8, 16, 32, 64]: + int_width = field_size_bits + logger.info(f"Determined field size: {int_width} bits") + + # Special handling for struct_xdp_md i32 fields + # Load as i32 but extend to i64 before storing + if struct_name == "struct_xdp_md" and int_width == 32: + needs_zext = True + logger.info( + "struct_xdp_md i32 field detected, will zero-extend to i64" + ) + else: + logger.warning( + f"Unusual field size {field_size_bits} bits, using default 64" + ) + except Exception as e: + logger.warning( + f"Could not determine field size: {e}, using default 64" + ) + + elif field_data.type.__module__ == "vmlinux": + # For pointers to structs or complex vmlinux types + if field_data.ctype_complex_type is not None and issubclass( + field_data.ctype_complex_type, ctypes._Pointer + ): + int_width = 64 # Pointers are always 64-bit + logger.info("Field is a pointer type, using 64 bits") + # TODO: Add handling for other complex types (arrays, embedded structs, etc.) + else: + logger.warning("Complex vmlinux field type, using default 64 bits") + + # Bitcast to appropriate pointer type based on determined width + ptr_type = ir.PointerType(ir.IntType(int_width)) + + typed_ptr = builder.bitcast(verified_ptr, ptr_type) # Load and return the value value = builder.load(typed_ptr) + # Zero-extend i32 to i64 for struct_xdp_md fields + if needs_zext: + value = builder.zext(value, ir.IntType(64)) + logger.info("Zero-extended i32 value to i64 for struct_xdp_md field") + return value def has_field(self, struct_name, field_name): diff --git a/tests/c-form/i32test.bpf.c b/tests/c-form/i32test.bpf.c new file mode 100644 index 0000000..8457bab --- /dev/null +++ b/tests/c-form/i32test.bpf.c @@ -0,0 +1,15 @@ +#include +#include + +SEC("xdp") +int print_xdp_data(struct xdp_md *ctx) +{ + // 'data' is a pointer to the start of packet data + long data = (long)ctx->data; + + bpf_printk("ctx->data = %lld\n", data); + + return XDP_PASS; +} + +char LICENSE[] SEC("license") = "GPL"; diff --git a/tests/failing_tests/vmlinux/args_test.py b/tests/failing_tests/vmlinux/args_test.py new file mode 100644 index 0000000..7acca25 --- /dev/null +++ b/tests/failing_tests/vmlinux/args_test.py @@ -0,0 +1,30 @@ +import logging + +from pythonbpf import bpf, section, bpfglobal, compile_to_ir +from pythonbpf import compile # noqa: F401 +from vmlinux import TASK_COMM_LEN # noqa: F401 +from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 +from ctypes import c_int64, c_int32, c_void_p # noqa: F401 + + +# from vmlinux import struct_uinput_device +# from vmlinux import struct_blk_integrity_iter + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: + b = ctx.args + c = b[0] + print(f"This is context args field {c}") + return c_int64(0) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("args_test.py", "args_test.ll", loglevel=logging.INFO) +compile() diff --git a/tests/passing_tests/hash_map_struct.py b/tests/passing_tests/hash_map_struct.py new file mode 100644 index 0000000..9f6cbac --- /dev/null +++ b/tests/passing_tests/hash_map_struct.py @@ -0,0 +1,42 @@ +from pythonbpf import bpf, section, struct, bpfglobal, compile, map +from pythonbpf.maps import HashMap +from pythonbpf.helper import pid +from ctypes import c_void_p, c_int64 + + +@bpf +@struct +class val_type: + counter: c_int64 + shizzle: c_int64 + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=val_type, value=c_int64, max_entries=16) + + +@bpf +@section("tracepoint/syscalls/sys_enter_clone") +def hello_world(ctx: c_void_p) -> c_int64: + obj = val_type() + obj.counter, obj.shizzle = 42, 96 + t = last.lookup(obj) + if t: + print(f"Found existing entry: counter={obj.counter}, pid={t}") + last.delete(obj) + return 0 # type: ignore [return-value] + val = pid() + last.update(obj, val) + print(f"Map updated!, {obj.counter}, {obj.shizzle}, {val}") + return 0 # type: ignore [return-value] + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/helpers/bpf_probe_read.py b/tests/passing_tests/helpers/bpf_probe_read.py new file mode 100644 index 0000000..fcece4d --- /dev/null +++ b/tests/passing_tests/helpers/bpf_probe_read.py @@ -0,0 +1,29 @@ +from pythonbpf import bpf, section, bpfglobal, compile, struct +from ctypes import c_void_p, c_int64, c_uint64, c_uint32 +from pythonbpf.helper import probe_read + + +@bpf +@struct +class data_t: + pid: c_uint32 + value: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def test_probe_read(ctx: c_void_p) -> c_int64: + """Test bpf_probe_read helper function""" + data = data_t() + probe_read(data.value, 8, ctx) + probe_read(data.pid, 4, ctx) + return 0 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/helpers/prandom.py b/tests/passing_tests/helpers/prandom.py new file mode 100644 index 0000000..396927b --- /dev/null +++ b/tests/passing_tests/helpers/prandom.py @@ -0,0 +1,25 @@ +from pythonbpf import bpf, bpfglobal, section, BPF, trace_pipe +from ctypes import c_void_p, c_int64 +from pythonbpf.helper import random + + +@bpf +@section("tracepoint/syscalls/sys_enter_clone") +def hello_world(ctx: c_void_p) -> c_int64: + r = random() + print(f"Hello, World!, {r}") + return 0 # type: ignore [return-value] + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +# Compile and load +b = BPF() +b.load() +b.attach_all() + +trace_pipe() diff --git a/tests/passing_tests/helpers/smp_processor_id.py b/tests/passing_tests/helpers/smp_processor_id.py new file mode 100644 index 0000000..8c17a75 --- /dev/null +++ b/tests/passing_tests/helpers/smp_processor_id.py @@ -0,0 +1,40 @@ +from pythonbpf import bpf, section, bpfglobal, compile, struct +from ctypes import c_void_p, c_int64, c_uint32, c_uint64 +from pythonbpf.helper import smp_processor_id, ktime + + +@bpf +@struct +class cpu_event_t: + cpu_id: c_uint32 + timestamp: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def trace_with_cpu(ctx: c_void_p) -> c_int64: + """Test bpf_get_smp_processor_id helper function""" + + # Get the current CPU ID + cpu = smp_processor_id() + + # Print it + print(f"Running on CPU {cpu}") + + # Use it in a struct + event = cpu_event_t() + event.cpu_id = smp_processor_id() + event.timestamp = ktime() + + print(f"Event on CPU {event.cpu_id} at time {event.timestamp}") + + return 0 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/helpers/uid_gid.py b/tests/passing_tests/helpers/uid_gid.py new file mode 100644 index 0000000..e4e50b4 --- /dev/null +++ b/tests/passing_tests/helpers/uid_gid.py @@ -0,0 +1,31 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 +from pythonbpf.helper import uid, pid + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def filter_by_user(ctx: c_void_p) -> c_int64: + """Filter events by specific user ID""" + + current_uid = uid() + + # Only trace root user (UID 0) + if current_uid == 0: + process_id = pid() + print(f"Root process {process_id} executed") + + # Or trace specific user (e.g., UID 1000) + if current_uid == 1002: + print("User 1002 executed something") + + return 0 + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/vmlinux/i32_test.py b/tests/passing_tests/vmlinux/i32_test.py new file mode 100644 index 0000000..4ba0969 --- /dev/null +++ b/tests/passing_tests/vmlinux/i32_test.py @@ -0,0 +1,31 @@ +from ctypes import c_int64, c_void_p +from pythonbpf import bpf, section, bpfglobal, compile_to_ir, compile +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS + + +@bpf +@section("xdp") +def print_xdp_dat2a(ct2x: struct_xdp_md) -> c_int64: + data = ct2x.data # 32-bit field: packet start pointer + print(f"ct2x->data = {data}") + return c_int64(XDP_PASS) + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = ctx.data # 32-bit field: packet start pointer + something = c_void_p(data) + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("i32_test.py", "i32_test.ll") +compile() diff --git a/tests/passing_tests/vmlinux/i32_test_fail_1.py b/tests/passing_tests/vmlinux/i32_test_fail_1.py new file mode 100644 index 0000000..3f6d3c1 --- /dev/null +++ b/tests/passing_tests/vmlinux/i32_test_fail_1.py @@ -0,0 +1,24 @@ +from ctypes import c_int64 +from pythonbpf import bpf, section, bpfglobal, compile +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS +import logging + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = 0 + data = ctx.data # 32-bit field: packet start pointer + something = 2 + data + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile(logging.INFO) diff --git a/tests/passing_tests/vmlinux/i32_test_fail_2.py b/tests/passing_tests/vmlinux/i32_test_fail_2.py new file mode 100644 index 0000000..4792fd6 --- /dev/null +++ b/tests/passing_tests/vmlinux/i32_test_fail_2.py @@ -0,0 +1,24 @@ +from ctypes import c_int64 +from pythonbpf import bpf, section, bpfglobal, compile, compile_to_ir +from vmlinux import struct_xdp_md +from vmlinux import XDP_PASS +import logging + + +@bpf +@section("xdp") +def print_xdp_data(ctx: struct_xdp_md) -> c_int64: + data = c_int64(ctx.data) # 32-bit field: packet start pointer + something = 2 + data + print(f"ctx->data = {something}") + return c_int64(XDP_PASS) + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile_to_ir("i32_test_fail_2.py", "i32_test_fail_2.ll") +compile(logging.INFO) diff --git a/tests/passing_tests/vmlinux/simple_struct_test.py b/tests/passing_tests/vmlinux/simple_struct_test.py index 97ab54a..2f34ba4 100644 --- a/tests/passing_tests/vmlinux/simple_struct_test.py +++ b/tests/passing_tests/vmlinux/simple_struct_test.py @@ -44,4 +44,4 @@ def LICENSE() -> str: compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG) -# compile() +compile()