Make handle_return (crude for now)

This commit is contained in:
Pragyansh Chaturvedi
2025-10-05 23:19:06 +05:30
parent d0a8e96b70
commit e9f3aa25d2

View File

@ -9,6 +9,7 @@ from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.binary_ops import handle_binary_op
from pythonbpf.expr_pass import eval_expr, handle_expr
logger = logging.getLogger(__name__)
@ -350,6 +351,65 @@ def handle_if(
builder.position_at_end(merge_block)
def handle_return(
func, module, builder, stmt, map_sym_tab, local_sym_tab, struct_sym_tab, ret_type
):
if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(64), 0))
return True
elif (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and len(stmt.value.args) == 1
):
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
stmt.value.args[0].value, int
):
call_type = stmt.value.func.id
if ctypes_to_ir(call_type) != ret_type:
raise ValueError(
"Return type mismatch: expected"
f"{ctypes_to_ir(call_type)}, got {call_type}"
)
else:
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
return True
elif isinstance(stmt.value.args[0], ast.BinOp):
# TODO: Should be routed through eval_expr
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
if val is None:
raise ValueError("Failed to evaluate return expression")
if val[1] != ret_type:
raise ValueError(
f"Return type mismatch: expected {ret_type}, got {val[1]}"
)
builder.ret(val[0])
return True
elif isinstance(stmt.value.args[0], ast.Name):
if stmt.value.args[0].id in local_sym_tab:
var = local_sym_tab[stmt.value.args[0].id].var
val = builder.load(var)
if val.type != ret_type:
raise ValueError(
f"Return type mismatch: expected {ret_type}, got {val.type}"
)
builder.ret(val)
return True
else:
raise ValueError("Failed to evaluate return expression")
elif isinstance(stmt.value, ast.Name):
if stmt.value.id == "XDP_PASS":
builder.ret(ir.Constant(ret_type, 2))
return True
elif stmt.value.id == "XDP_DROP":
builder.ret(ir.Constant(ret_type, 1))
return True
else:
raise ValueError("Failed to evaluate return expression")
else:
raise ValueError("Unsupported return value")
def process_stmt(
func,
module,
@ -383,60 +443,16 @@ def process_stmt(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
)
elif isinstance(stmt, ast.Return):
if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(64), 0))
did_return = True
elif (
isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and len(stmt.value.args) == 1
):
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
stmt.value.args[0].value, int
):
call_type = stmt.value.func.id
if ctypes_to_ir(call_type) != ret_type:
raise ValueError(
"Return type mismatch: expected"
f"{ctypes_to_ir(call_type)}, got {call_type}"
)
else:
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
did_return = True
elif isinstance(stmt.value.args[0], ast.BinOp):
# TODO: Should be routed through eval_expr
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
if val is None:
raise ValueError("Failed to evaluate return expression")
if val[1] != ret_type:
raise ValueError(
f"Return type mismatch: expected {ret_type}, got {val[1]}"
)
builder.ret(val[0])
did_return = True
elif isinstance(stmt.value.args[0], ast.Name):
if stmt.value.args[0].id in local_sym_tab:
var = local_sym_tab[stmt.value.args[0].id].var
val = builder.load(var)
if val.type != ret_type:
raise ValueError(
f"Return type mismatch: expected {ret_type}, got {val.type}"
)
builder.ret(val)
did_return = True
else:
raise ValueError("Failed to evaluate return expression")
elif isinstance(stmt.value, ast.Name):
if stmt.value.id == "XDP_PASS":
builder.ret(ir.Constant(ret_type, 2))
did_return = True
elif stmt.value.id == "XDP_DROP":
builder.ret(ir.Constant(ret_type, 1))
did_return = True
else:
raise ValueError("Failed to evaluate return expression")
else:
raise ValueError("Unsupported return value")
did_return = handle_return(
func,
module,
builder,
stmt,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
ret_type,
)
return did_return