diff --git a/pythonbpf/debuginfo/debug_info_generator.py b/pythonbpf/debuginfo/debug_info_generator.py index 62f0cc3..5d10855 100644 --- a/pythonbpf/debuginfo/debug_info_generator.py +++ b/pythonbpf/debuginfo/debug_info_generator.py @@ -184,3 +184,20 @@ class DebugInfoGenerator: "DIGlobalVariableExpression", {"var": global_var, "expr": self.module.add_debug_info("DIExpression", {})}, ) + + def get_int64_type(self): + return self.get_basic_type("long", 64, dc.DW_ATE_signed) + + def create_subroutine_type(self, return_type, param_types): + """ + Create a DISubroutineType given return type and list of parameter types. + Equivalent to: !DISubroutineType(types: !{ret, args...}) + """ + type_array = [return_type] + if isinstance(param_types, (list, tuple)): + type_array.extend(param_types) + else: + type_array.append(param_types) + return self.module.add_debug_info("DISubroutineType", {"types": type_array}) + + diff --git a/pythonbpf/functions/function_debug_info.py b/pythonbpf/functions/function_debug_info.py index b508076..e7f49ef 100644 --- a/pythonbpf/functions/function_debug_info.py +++ b/pythonbpf/functions/function_debug_info.py @@ -1,9 +1,12 @@ import ast import llvmlite.ir as ir - +import logging from pythonbpf.debuginfo import DebugInfoGenerator from pythonbpf.expr import VmlinuxHandlerRegistry +import ctypes + +logger = logging.getLogger(__name__) def generate_function_debug_info( @@ -12,10 +15,42 @@ def generate_function_debug_info( generator = DebugInfoGenerator(module) leading_argument = func_node.args.args[0] leading_argument_name = leading_argument.arg - # TODO: add ctypes handling as well here - print(leading_argument.arg, leading_argument.annotation.id) - context_debug_info = VmlinuxHandlerRegistry.get_struct_debug_info( - name=leading_argument.annotation.id - ) - print(context_debug_info) - pass + annotation = leading_argument.annotation + if func_node.returns is None: + # TODO: should check if this logic is consistent with function return type handling elsewhere + return_type = ctypes.c_int64() + elif hasattr(func_node.returns, "id"): + return_type = func_node.returns.id + if return_type == "c_int32": + return_type = generator.get_int32_type() + elif return_type == "c_int64": + return_type = generator.get_int64_type() + elif return_type == "c_uint32": + return_type = generator.get_uint32_type() + elif return_type == "c_uint64": + return_type = generator.get_uint64_type() + else: + logger.warning( + "Return type should be int32, int64, uint32 or uint64 only. Falling back to int64" + ) + return_type = generator.get_int64_type() + else: + return_type = ctypes.c_int64() + # context processing + if annotation is None: + logger.warning("Type of context of function not found.") + return + if hasattr(annotation, "id"): + ctype_name = annotation.id + if ctype_name == "c_void_p": + return + elif ctype_name.startswith("ctypes"): + raise SyntaxError( + "The first argument should always be a pointer to a struct or a void pointer" + ) + context_debug_info = VmlinuxHandlerRegistry.get_struct_debug_info(annotation.id) + pointer_to_context_debug_info = generator.create_pointer_type(context_debug_info, 64) + subroutine_type = generator.create_subroutine_type(return_type, pointer_to_context_debug_info) + print(subroutine_type) + else: + logger.error(f"Invalid annotation type for argument '{leading_argument_name}'") diff --git a/tests/failing_tests/vmlinux/struct_field_access.py b/tests/failing_tests/vmlinux/struct_field_access.py index 0a6d68d..64405fa 100644 --- a/tests/failing_tests/vmlinux/struct_field_access.py +++ b/tests/failing_tests/vmlinux/struct_field_access.py @@ -4,7 +4,7 @@ from pythonbpf import bpf, section, bpfglobal, compile_to_ir from pythonbpf import compile # noqa: F401 from vmlinux import TASK_COMM_LEN # noqa: F401 from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 -from ctypes import c_int64, c_void_p # noqa: F401 +from ctypes import c_int64, c_int32, c_void_p # noqa: F401 # from vmlinux import struct_uinput_device @@ -13,10 +13,10 @@ from ctypes import c_int64, c_void_p # noqa: F401 @bpf @section("tracepoint/syscalls/sys_enter_execve") -def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: +def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int32: b = ctx.id print(f"This is context field {b}") - return c_int64(0) + return c_int32(0) @bpf