From 1118e4fcd62a179ca7060d13558a500bd58825b7 Mon Sep 17 00:00:00 2001 From: varun-r-mallya Date: Sun, 7 Sep 2025 20:12:39 +0530 Subject: [PATCH] add naive unpythonic return type inference to function parsing --- examples/execve2.py | 10 ++- pythonbpf/functions_pass.py | 156 ++++++++++++++++-------------------- pythonbpf/type_deducer.py | 44 +++++----- 3 files changed, 96 insertions(+), 114 deletions(-) diff --git a/examples/execve2.py b/examples/execve2.py index 165a191..39d1de2 100644 --- a/examples/execve2.py +++ b/examples/execve2.py @@ -1,11 +1,17 @@ from pythonbpf.decorators import bpf, section -from ctypes import c_void_p, c_int32 +from ctypes import c_void_p, c_int64, c_int32 @bpf @section("tracepoint/syscalls/sys_enter_execve") def hello(ctx: c_void_p) -> c_int32: - print("Hello, World!") + print("entered") return c_int32(0) +@bpf +@section("tracepoint/syscalls/sys_exit_execve") +def hello_again(ctx: c_void_p) -> c_int64: + print("exited") + return c_int64(0) + LICENSE = "GPL" diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index abe8466..c88be8c 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -1,47 +1,6 @@ from llvmlite import ir import ast - - -def emit_function(module: ir.Module, name: str): - ret_type = ir.IntType(32) - ptr_type = ir.PointerType() - func_ty = ir.FunctionType(ret_type, [ptr_type]) - - func = ir.Function(module, func_ty, name) - - param = func.args[0] - param.add_attribute("nocapture") - - func.attributes.add("nounwind") - # func.attributes.add("\"frame-pointer\"=\"all\"") - # func.attributes.add("no-trapping-math", "true") - # func.attributes.add("stack-protector-buffer-size", "8") - - block = func.append_basic_block(name="entry") - builder = ir.IRBuilder(block) - fmt_gvar = module.get_global("hello.____fmt") - - if fmt_gvar is None: - # If you haven't created the format string global yet - print("Warning: Format string global not found") - else: - # Cast integer 6 to function pointer 
type - fn_type = ir.FunctionType(ir.IntType( - 64), [ptr_type, ir.IntType(32)], var_arg=True) - fn_ptr_type = ir.PointerType(fn_type) - fn_addr = ir.Constant(ir.IntType(64), 6) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - # Call the function - builder.call(fn_ptr, [fmt_gvar, ir.Constant(ir.IntType(32), 14)]) - - builder.ret(ir.Constant(ret_type, 0)) - - func.return_value.add_attribute("noundef") - func.linkage = "dso_local" - func.section = "kprobe/sys_clone" - print("function emitted:", name) - return func - +from .type_deducer import ctypes_to_ir def get_probe_string(func_node): """Extract the probe string from the decorator of the function node.""" @@ -58,7 +17,7 @@ def get_probe_string(func_node): return "helper" -def process_func_body(module, builder, func_node, func): +def process_func_body(module, builder, func_node, func, ret_type): """Process the body of a bpf function""" # TODO: A lot. We just have print -> bpf_trace_printk for now did_return = False @@ -100,23 +59,28 @@ def process_func_body(module, builder, func_node, func): if stmt.value is None: builder.ret(ir.Constant(ir.IntType(32), 0)) did_return = True - elif isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == "c_int32" and len(stmt.value.args) == 1 and isinstance(stmt.value.args[0], ast.Constant) and isinstance(stmt.value.args[0].value, int): - builder.ret(ir.Constant(ir.IntType( - 32), stmt.value.args[0].value)) - did_return = True + elif isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and len(stmt.value.args) == 1 and isinstance(stmt.value.args[0], ast.Constant) and isinstance(stmt.value.args[0].value, int): + call_type = stmt.value.func.id + if ctypes_to_ir(call_type) != ret_type: + raise ValueError(f"Return type mismatch: expected {ctypes_to_ir(call_type)}, got {call_type}") + else: + builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) + did_return = True else: print("Unsupported return value") if not 
did_return: builder.ret(ir.Constant(ir.IntType(32), 0)) -def process_bpf_chunk(func_node, module): +def process_bpf_chunk(func_node, module, return_type): """Process a single BPF chunk (function) and emit corresponding LLVM IR.""" func_name = func_node.name - # TODO: parse return type - ret_type = ir.IntType(32) + #TODO: The function actual arg return type is parsed, + # but the actual output is not. It's still very wrong. Try uncommenting the + # code in execve2.py once + ret_type = return_type # TODO: parse parameters param_types = [] @@ -142,7 +106,7 @@ def process_bpf_chunk(func_node, module): block = func.append_basic_block(name="entry") builder = ir.IRBuilder(block) - process_func_body(module, builder, func_node, func) + process_func_body(module, builder, func_node, func, ret_type) print(func) print(module) @@ -154,39 +118,59 @@ def func_proc(tree, module, chunks): func_type = get_probe_string(func_node) print(f"Found probe_string of {func_node.name}: {func_type}") - process_bpf_chunk(func_node, module) + process_bpf_chunk(func_node, module, ctypes_to_ir(infer_return_type(func_node))) - -def functions_processing(tree, module): - bpf_functions = [] - helper_functions = [] - for node in tree.body: - section_name = "" - if isinstance(node, ast.FunctionDef): - if len(node.decorator_list) == 1: - bpf_functions.append(node) - node.end_lineno - else: - # IDK why this check is needed, but whatever - if 'helper_functions' not in locals(): - helper_functions.append(node) - - # TODO: implement helpers first - - for func in bpf_functions: - dec = func.decorator_list[0] - if ( - isinstance(dec, ast.Call) - and isinstance(dec.func, ast.Name) - and dec.func.id == "section" - and len(dec.args) == 1 - and isinstance(dec.args[0], ast.Constant) - and isinstance(dec.args[0].value, str) - ): - section_name = dec.args[0].value - else: - print(f"ERROR: Invalid decorator for function {func.name}") - continue - - # TODO: parse arguments and return type - emit_function(module,
func.name + "func") +def infer_return_type(func_node: ast.FunctionDef): + if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + raise TypeError("Expected ast.FunctionDef") + if func_node.returns is not None: + try: + return ast.unparse(func_node.returns) + except Exception: + node = func_node.returns + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return getattr(node, "attr", type(node).__name__) + try: + return str(node) + except Exception: + return type(node).__name__ + found_type = None + def _expr_type(e): + if e is None: + return "None" + if isinstance(e, ast.Constant): + return type(e.value).__name__ + if isinstance(e, ast.Name): + return e.id + if isinstance(e, ast.Call): + f = e.func + if isinstance(f, ast.Name): + return f.id + if isinstance(f, ast.Attribute): + try: + return ast.unparse(f) + except Exception: + return getattr(f, "attr", type(f).__name__) + try: + return ast.unparse(f) + except Exception: + return type(f).__name__ + if isinstance(e, ast.Attribute): + try: + return ast.unparse(e) + except Exception: + return getattr(e, "attr", type(e).__name__) + try: + return ast.unparse(e) + except Exception: + return type(e).__name__ + for node in ast.walk(func_node): + if isinstance(node, ast.Return): + t = _expr_type(node.value) + if found_type is None: + found_type = t + elif found_type != t: + raise ValueError(f"Conflicting return types: {found_type} vs {t}") + return found_type or "None" diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index 2f58a1d..30b6b5f 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -1,29 +1,21 @@ -import ctypes from llvmlite import ir -def ctypes_to_ir(ctype): - if ctype is ctypes.c_int32: - return ir.IntType(32) - if ctype is ctypes.c_int64: - return ir.IntType(64) - if ctype is ctypes.c_uint8: - return ir.IntType(8) - if ctype is ctypes.c_double: - return ir.DoubleType() - if ctype is ctypes.c_float: - return 
ir.FloatType() - - # pointers - if hasattr(ctype, "_type_") and hasattr(ctype, "_length_"): - # ctypes array - return ir.ArrayType(ctypes_to_ir(ctype._type_), ctype._length_) - - # if hasattr(ctype, "_type_") and issubclass(ctype, ctypes._Pointer): - # return ir.PointerType(ctypes_to_ir(ctype._type_)) - - # structs - if issubclass(ctype, ctypes.Structure): - fields = [ctypes_to_ir(f[1]) for f in ctype._fields_] - return ir.LiteralStructType(fields) - +#TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull: +def ctypes_to_ir(ctype: str): + print("CTYPE", ctype) + mapping = { + "c_int8": ir.IntType(8), + "c_uint8": ir.IntType(8), + "c_int16": ir.IntType(16), + "c_uint16": ir.IntType(16), + "c_int32": ir.IntType(32), + "c_uint32": ir.IntType(32), + "c_int64": ir.IntType(64), + "c_uint64": ir.IntType(64), + "c_float": ir.FloatType(), + "c_double": ir.DoubleType(), + "c_void_p": ir.IntType(64), + } + if ctype in mapping: + return mapping[ctype] raise NotImplementedError(f"No mapping for {ctype}")