Merge pull request #20 from pythonbpf/fix-failing-tests

Fix failing tests in tests/
This commit is contained in:
Pragyansh Chaturvedi
2025-10-05 14:04:14 +05:30
committed by GitHub
7 changed files with 133 additions and 27 deletions

View File

@ -9,6 +9,7 @@ logger: Logger = logging.getLogger(__name__)
def recursive_dereferencer(var, builder): def recursive_dereferencer(var, builder):
"""dereference until primitive type comes out""" """dereference until primitive type comes out"""
# TODO: Not worrying about stack overflow for now # TODO: Not worrying about stack overflow for now
logger.info(f"Dereferencing {var}, type is {var.type}")
if isinstance(var.type, ir.PointerType): if isinstance(var.type, ir.PointerType):
a = builder.load(var) a = builder.load(var)
return recursive_dereferencer(a, builder) return recursive_dereferencer(a, builder)
@ -18,7 +19,7 @@ def recursive_dereferencer(var, builder):
raise TypeError(f"Unsupported type for dereferencing: {var.type}") raise TypeError(f"Unsupported type for dereferencing: {var.type}")
def get_operand_value(operand, module, builder, local_sym_tab): def get_operand_value(operand, builder, local_sym_tab):
"""Extract the value from an operand, handling variables and constants.""" """Extract the value from an operand, handling variables and constants."""
if isinstance(operand, ast.Name): if isinstance(operand, ast.Name):
if operand.id in local_sym_tab: if operand.id in local_sym_tab:
@ -29,14 +30,14 @@ def get_operand_value(operand, module, builder, local_sym_tab):
return ir.Constant(ir.IntType(64), operand.value) return ir.Constant(ir.IntType(64), operand.value)
raise TypeError(f"Unsupported constant type: {type(operand.value)}") raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp): elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, module, builder, local_sym_tab) return handle_binary_op_impl(operand, builder, local_sym_tab)
raise TypeError(f"Unsupported operand type: {type(operand)}") raise TypeError(f"Unsupported operand type: {type(operand)}")
def handle_binary_op_impl(rval, module, builder, local_sym_tab): def handle_binary_op_impl(rval, builder, local_sym_tab):
op = rval.op op = rval.op
left = get_operand_value(rval.left, module, builder, local_sym_tab) left = get_operand_value(rval.left, builder, local_sym_tab)
right = get_operand_value(rval.right, module, builder, local_sym_tab) right = get_operand_value(rval.right, builder, local_sym_tab)
logger.info(f"left is {left}, right is {right}, op is {op}") logger.info(f"left is {left}, right is {right}, op is {op}")
# Map AST operation nodes to LLVM IR builder methods # Map AST operation nodes to LLVM IR builder methods
@ -61,6 +62,11 @@ def handle_binary_op_impl(rval, module, builder, local_sym_tab):
raise SyntaxError("Unsupported binary operation") raise SyntaxError("Unsupported binary operation")
def handle_binary_op(rval, module, builder, var_name, local_sym_tab): def handle_binary_op(rval, builder, var_name, local_sym_tab):
result = handle_binary_op_impl(rval, module, builder, local_sym_tab) result = handle_binary_op_impl(rval, builder, local_sym_tab)
builder.store(result, local_sym_tab[var_name].var) if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
)
builder.store(result, local_sym_tab[var_name].var)
return result, result.type

View File

@ -48,7 +48,7 @@ def processor(source_code, filename, module):
globals_processing(tree, module) globals_processing(tree, module)
def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING): def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
logging.basicConfig( logging.basicConfig(
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
) )
@ -121,7 +121,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
return output return output
def compile(loglevel=logging.WARNING) -> bool: def compile(loglevel=logging.INFO) -> bool:
# Look one level up the stack to the caller of this function # Look one level up the stack to the caller of this function
caller_frame = inspect.stack()[1] caller_frame = inspect.stack()[1]
caller_file = Path(caller_frame.filename).resolve() caller_file = Path(caller_frame.filename).resolve()
@ -154,7 +154,7 @@ def compile(loglevel=logging.WARNING) -> bool:
return success return success
def BPF(loglevel=logging.WARNING) -> BpfProgram: def BPF(loglevel=logging.INFO) -> BpfProgram:
caller_frame = inspect.stack()[1] caller_frame = inspect.stack()[1]
src = inspect.getsource(caller_frame.frame) src = inspect.getsource(caller_frame.frame)
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(

View File

@ -232,7 +232,7 @@ def handle_assign(
else: else:
logger.info("Unsupported assignment call function type") logger.info("Unsupported assignment call function type")
elif isinstance(rval, ast.BinOp): elif isinstance(rval, ast.BinOp):
handle_binary_op(rval, module, builder, var_name, local_sym_tab) handle_binary_op(rval, builder, var_name, local_sym_tab)
else: else:
logger.info("Unsupported assignment value type") logger.info("Unsupported assignment value type")
@ -384,24 +384,49 @@ def process_stmt(
) )
elif isinstance(stmt, ast.Return): elif isinstance(stmt, ast.Return):
if stmt.value is None: if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(32), 0)) builder.ret(ir.Constant(ir.IntType(64), 0))
did_return = True did_return = True
elif ( elif (
isinstance(stmt.value, ast.Call) isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name) and isinstance(stmt.value.func, ast.Name)
and len(stmt.value.args) == 1 and len(stmt.value.args) == 1
and isinstance(stmt.value.args[0], ast.Constant)
and isinstance(stmt.value.args[0].value, int)
): ):
call_type = stmt.value.func.id if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
if ctypes_to_ir(call_type) != ret_type: stmt.value.args[0].value, int
raise ValueError( ):
"Return type mismatch: expected" call_type = stmt.value.func.id
f"{ctypes_to_ir(call_type)}, got {call_type}" if ctypes_to_ir(call_type) != ret_type:
) raise ValueError(
else: "Return type mismatch: expected"
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) f"{ctypes_to_ir(call_type)}, got {call_type}"
)
else:
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
did_return = True
elif isinstance(stmt.value.args[0], ast.BinOp):
# TODO: Should be routed through eval_expr
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
if val is None:
raise ValueError("Failed to evaluate return expression")
if val[1] != ret_type:
raise ValueError(
"Return type mismatch: expected " f"{ret_type}, got {val[1]}"
)
builder.ret(val[0])
did_return = True did_return = True
elif isinstance(stmt.value.args[0], ast.Name):
if stmt.value.args[0].id in local_sym_tab:
var = local_sym_tab[stmt.value.args[0].id].var
val = builder.load(var)
if val.type != ret_type:
raise ValueError(
"Return type mismatch: expected"
f"{ret_type}, got {val.type}"
)
builder.ret(val)
did_return = True
else:
raise ValueError("Failed to evaluate return expression")
elif isinstance(stmt.value, ast.Name): elif isinstance(stmt.value, ast.Name):
if stmt.value.id == "XDP_PASS": if stmt.value.id == "XDP_PASS":
builder.ret(ir.Constant(ret_type, 2)) builder.ret(ir.Constant(ret_type, 2))
@ -454,6 +479,9 @@ def allocate_mem(
continue continue
var_name = target.id var_name = target.id
rval = stmt.value rval = stmt.value
if var_name in local_sym_tab:
logger.info(f"Variable {var_name} already allocated")
continue
if isinstance(rval, ast.Call): if isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name): if isinstance(rval.func, ast.Name):
call_type = rval.func.id call_type = rval.func.id
@ -566,7 +594,7 @@ def process_func_body(
) )
if not did_return: if not did_return:
builder.ret(ir.Constant(ir.IntType(32), 0)) builder.ret(ir.Constant(ir.IntType(64), 0))
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab): def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):

View File

@ -4,6 +4,18 @@ from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64
# NOTE: I have decided to not fix this example for now.
# The issue is in line 31, where we are passing an expression.
# The update helper expects a pointer type. But the problem is
# that we must allocate the space for said pointer in the first
# basic block. As that usage is in a different basic block, we
# are unable to cast the expression to a pointer type. (as we never
# allocated space for it).
# Shall we change our space allocation logic? That allows users to
# spam the same helper with the same args, and still run out of
# stack space. So we consider this usage invalid for now.
# Might fix it later.
@bpf @bpf
@map @map
@ -14,12 +26,12 @@ def count() -> HashMap:
@bpf @bpf
@section("xdp") @section("xdp")
def hello_world(ctx: c_void_p) -> c_int64: def hello_world(ctx: c_void_p) -> c_int64:
prev = count().lookup(0) prev = count.lookup(0)
if prev: if prev:
count().update(0, prev + 1) count.update(0, prev + 1)
return XDP_PASS return XDP_PASS
else: else:
count().update(0, 1) count.update(0, 1)
return XDP_PASS return XDP_PASS

View File

@ -0,0 +1,40 @@
from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64
# NOTE: This example exposes the problems with our typing system.
# We can't do steps on line 25 and 27.
# prev is of type i64**. For prev + 1, we deref it down to i64
# To assign it back to prev, we need to go back to i64**.
# We cannot allocate space for the intermediate type now.
# We probably need to track the ref/deref chain for each variable.
@bpf
@map
def count() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=1)
@bpf
@section("xdp")
def hello_world(ctx: c_void_p) -> c_int64:
prev = count.lookup(0)
if prev:
prev = prev + 1
count.update(0, prev)
return XDP_PASS
else:
count.update(0, 1)
return XDP_PASS
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -0,0 +1,20 @@
import logging
from pythonbpf import compile, bpf, section, bpfglobal
from ctypes import c_void_p, c_int64
@bpf
@section("sometag1")
def sometag(ctx: c_void_p) -> c_int64:
a = 1 - 1
return c_int64(a)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile(loglevel=logging.INFO)