mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Fix t/f/return.py, tweak handle_binary_ops
This commit is contained in:
@ -63,4 +63,6 @@ def handle_binary_op_impl(rval, module, builder, local_sym_tab):
|
|||||||
|
|
||||||
def handle_binary_op(rval, module, builder, var_name, local_sym_tab):
|
def handle_binary_op(rval, module, builder, var_name, local_sym_tab):
|
||||||
result = handle_binary_op_impl(rval, module, builder, local_sym_tab)
|
result = handle_binary_op_impl(rval, module, builder, local_sym_tab)
|
||||||
builder.store(result, local_sym_tab[var_name].var)
|
if var_name in local_sym_tab:
|
||||||
|
builder.store(result, local_sym_tab[var_name].var)
|
||||||
|
return result, result.type
|
||||||
|
|||||||
@ -48,7 +48,7 @@ def processor(source_code, filename, module):
|
|||||||
globals_processing(tree, module)
|
globals_processing(tree, module)
|
||||||
|
|
||||||
|
|
||||||
def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
|
def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||||
)
|
)
|
||||||
@ -121,7 +121,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def compile(loglevel=logging.WARNING) -> bool:
|
def compile(loglevel=logging.INFO) -> bool:
|
||||||
# Look one level up the stack to the caller of this function
|
# Look one level up the stack to the caller of this function
|
||||||
caller_frame = inspect.stack()[1]
|
caller_frame = inspect.stack()[1]
|
||||||
caller_file = Path(caller_frame.filename).resolve()
|
caller_file = Path(caller_frame.filename).resolve()
|
||||||
@ -154,7 +154,7 @@ def compile(loglevel=logging.WARNING) -> bool:
|
|||||||
return success
|
return success
|
||||||
|
|
||||||
|
|
||||||
def BPF(loglevel=logging.WARNING) -> BpfProgram:
|
def BPF(loglevel=logging.INFO) -> BpfProgram:
|
||||||
caller_frame = inspect.stack()[1]
|
caller_frame = inspect.stack()[1]
|
||||||
src = inspect.getsource(caller_frame.frame)
|
src = inspect.getsource(caller_frame.frame)
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(
|
||||||
|
|||||||
@ -391,17 +391,31 @@ def process_stmt(
|
|||||||
isinstance(stmt.value, ast.Call)
|
isinstance(stmt.value, ast.Call)
|
||||||
and isinstance(stmt.value.func, ast.Name)
|
and isinstance(stmt.value.func, ast.Name)
|
||||||
and len(stmt.value.args) == 1
|
and len(stmt.value.args) == 1
|
||||||
and isinstance(stmt.value.args[0], ast.Constant)
|
|
||||||
and isinstance(stmt.value.args[0].value, int)
|
|
||||||
):
|
):
|
||||||
call_type = stmt.value.func.id
|
if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
|
||||||
if ctypes_to_ir(call_type) != ret_type:
|
stmt.value.args[0].value, int
|
||||||
raise ValueError(
|
):
|
||||||
"Return type mismatch: expected"
|
call_type = stmt.value.func.id
|
||||||
f"{ctypes_to_ir(call_type)}, got {call_type}"
|
if ctypes_to_ir(call_type) != ret_type:
|
||||||
|
raise ValueError(
|
||||||
|
"Return type mismatch: expected"
|
||||||
|
f"{ctypes_to_ir(call_type)}, got {call_type}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
|
||||||
|
did_return = True
|
||||||
|
elif isinstance(stmt.value.args[0], ast.BinOp):
|
||||||
|
# TODO: Should be routed through eval_expr
|
||||||
|
val = handle_binary_op(
|
||||||
|
stmt.value.args[0], module, builder, None, local_sym_tab
|
||||||
)
|
)
|
||||||
else:
|
if val is None:
|
||||||
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
|
raise ValueError("Failed to evaluate return expression")
|
||||||
|
if val[1] != ret_type:
|
||||||
|
raise ValueError(
|
||||||
|
"Return type mismatch: expected" f"{ret_type}, got {val[1]}"
|
||||||
|
)
|
||||||
|
builder.ret(val[0])
|
||||||
did_return = True
|
did_return = True
|
||||||
elif isinstance(stmt.value, ast.Name):
|
elif isinstance(stmt.value, ast.Name):
|
||||||
if stmt.value.id == "XDP_PASS":
|
if stmt.value.id == "XDP_PASS":
|
||||||
|
|||||||
0
tests/failing_tests/var_rval.py
Normal file
0
tests/failing_tests/var_rval.py
Normal file
Reference in New Issue
Block a user