From e9f3aa25d27c475ba4e93062671df438e2ce4fc0 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Sun, 5 Oct 2025 23:19:06 +0530 Subject: [PATCH] Make handle_return (crude for now) --- pythonbpf/functions/functions_pass.py | 124 +++++++++++++++----------- 1 file changed, 70 insertions(+), 54 deletions(-) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index c6a15a9..5143a93 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -9,6 +9,7 @@ from pythonbpf.type_deducer import ctypes_to_ir from pythonbpf.binary_ops import handle_binary_op from pythonbpf.expr_pass import eval_expr, handle_expr + logger = logging.getLogger(__name__) @@ -350,6 +351,65 @@ def handle_if( builder.position_at_end(merge_block) +def handle_return( + func, module, builder, stmt, map_sym_tab, local_sym_tab, struct_sym_tab, ret_type +): + if stmt.value is None: + builder.ret(ir.Constant(ir.IntType(64), 0)) + return True + elif ( + isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and len(stmt.value.args) == 1 + ): + if isinstance(stmt.value.args[0], ast.Constant) and isinstance( + stmt.value.args[0].value, int + ): + call_type = stmt.value.func.id + if ctypes_to_ir(call_type) != ret_type: + raise ValueError( + "Return type mismatch: expected" + f"{ctypes_to_ir(call_type)}, got {call_type}" + ) + else: + builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) + return True + elif isinstance(stmt.value.args[0], ast.BinOp): + # TODO: Should be routed through eval_expr + val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) + if val is None: + raise ValueError("Failed to evaluate return expression") + if val[1] != ret_type: + raise ValueError( + f"Return type mismatch: expected {ret_type}, got {val[1]}" + ) + builder.ret(val[0]) + return True + elif isinstance(stmt.value.args[0], ast.Name): + if stmt.value.args[0].id in local_sym_tab: + var = local_sym_tab[stmt.value.args[0].id].var + val = builder.load(var) + if val.type != ret_type: + raise ValueError( + f"Return type mismatch: expected {ret_type}, got {val.type}" + ) + builder.ret(val) + return True + else: + raise ValueError("Failed to evaluate return expression") + elif isinstance(stmt.value, ast.Name): + if stmt.value.id == "XDP_PASS": + builder.ret(ir.Constant(ret_type, 2)) + return True + elif stmt.value.id == "XDP_DROP": + builder.ret(ir.Constant(ret_type, 1)) + return True + else: + raise ValueError("Failed to evaluate return expression") + else: + raise ValueError("Unsupported return value") + + def process_stmt( func, module, @@ -383,60 +443,16 @@ def process_stmt( func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab ) elif isinstance(stmt, ast.Return): - if stmt.value is None: - builder.ret(ir.Constant(ir.IntType(64), 0)) - did_return = True - elif ( - isinstance(stmt.value, ast.Call) - and isinstance(stmt.value.func, ast.Name) - and len(stmt.value.args) == 1 - ): - if isinstance(stmt.value.args[0], ast.Constant) and isinstance( - stmt.value.args[0].value, int - ): - call_type = stmt.value.func.id - if ctypes_to_ir(call_type) != ret_type: - raise ValueError( - "Return type mismatch: expected" - f"{ctypes_to_ir(call_type)}, got {call_type}" - ) - else: - builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) - did_return = True - elif isinstance(stmt.value.args[0], ast.BinOp): - # TODO: Should be routed through eval_expr - val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab) - if val is None: - raise ValueError("Failed to evaluate return expression") - if val[1] != ret_type: - raise ValueError( - f"Return type mismatch: expected {ret_type}, got {val[1]}" - ) - builder.ret(val[0]) - did_return = True - elif isinstance(stmt.value.args[0], ast.Name): - if stmt.value.args[0].id in local_sym_tab: - var = local_sym_tab[stmt.value.args[0].id].var - val = builder.load(var) - if val.type != ret_type: - raise ValueError( - f"Return type mismatch: expected {ret_type}, got {val.type}" - ) - builder.ret(val) - did_return = True - else: - raise ValueError("Failed to evaluate return expression") - elif isinstance(stmt.value, ast.Name): - if stmt.value.id == "XDP_PASS": - builder.ret(ir.Constant(ret_type, 2)) - did_return = True - elif stmt.value.id == "XDP_DROP": - builder.ret(ir.Constant(ret_type, 1)) - did_return = True - else: - raise ValueError("Failed to evaluate return expression") - else: - raise ValueError("Unsupported return value") + did_return = handle_return( + func, + module, + builder, + stmt, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ret_type, + ) return did_return