diff --git a/pythonbpf/functions/return_utils.py b/pythonbpf/functions/return_utils.py index 5637ff0..a3f08af 100644 --- a/pythonbpf/functions/return_utils.py +++ b/pythonbpf/functions/return_utils.py @@ -7,6 +7,11 @@ from pythonbpf.binary_ops import handle_binary_op logger: logging.Logger = logging.getLogger(__name__) +# TODO: Ideally there should be only 3 cases: +# - Return none +# - Return XDP +# - Return expr + XDP_ACTIONS = { "XDP_ABORTED": 0, "XDP_DROP": 1, @@ -26,7 +31,6 @@ def _handle_none_return(builder) -> bool: def _handle_typed_constant_return(call_type, return_value, builder, ret_type) -> bool: """Handle typed constant return like: return c_int64(42)""" - # call_type = stmt.value.func.id expected_type = ctypes_to_ir(call_type) if expected_type != ret_type: @@ -43,7 +47,6 @@ def _handle_typed_constant_return(call_type, return_value, builder, ret_type) -> def _handle_binop_return(arg, builder, ret_type, local_sym_tab) -> bool: """Handle return with binary operation: return c_int64(x + 1)""" - # result = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) result = handle_binary_op(arg, builder, None, local_sym_tab) if result is None: @@ -62,8 +65,6 @@ def _handle_binop_return(arg, builder, ret_type, local_sym_tab) -> bool: def _handle_variable_return(var_name, builder, ret_type, local_sym_tab) -> bool: """Handle return of a variable: return c_int64(my_var)""" - # var_name = stmt.value.args[0].id - if var_name not in local_sym_tab: raise ValueError(f"Undefined variable in return: {var_name}") @@ -78,6 +79,38 @@ def _handle_variable_return(var_name, builder, ret_type, local_sym_tab) -> bool: return True +def _handle_wrapped_return(stmt: ast.Return, builder, ret_type, local_sym_tab) -> bool: + """Handle wrapped returns: return c_int64(42), return c_int64(x + 1), return c_int64(my_var)""" + + if not ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and len(stmt.value.args) == 1 + ): + return False + + arg = stmt.value.args[0] + + # Case 1: Constant value - return c_int64(42) + if isinstance(arg, ast.Constant) and isinstance(arg.value, int): + return _handle_typed_constant_return( + stmt.value.func.id, arg.value, builder, ret_type + ) + + # Case 2: Binary operation - return c_int64(x + 1) + elif isinstance(arg, ast.BinOp): + return _handle_binop_return(arg, builder, ret_type, local_sym_tab) + + # Case 3: Variable - return c_int64(my_var) + elif isinstance(arg, ast.Name): + if not arg.id: + raise ValueError("Variable return must have a type, e.g., c_int64") + return _handle_variable_return(arg.id, builder, ret_type, local_sym_tab) + + else: + raise ValueError(f"Unsupported return argument type: {type(arg).__name__}") + + def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool: """Handle XDP returns""" if not isinstance(stmt.value, ast.Name):