diff --git a/pythonbpf/binary_ops.py b/pythonbpf/binary_ops.py deleted file mode 100644 index 6ea534b..0000000 --- a/pythonbpf/binary_ops.py +++ /dev/null @@ -1,110 +0,0 @@ -import ast -from llvmlite import ir -from logging import Logger -import logging - -from pythonbpf.expr import get_base_type_and_depth, deref_to_depth, eval_expr - -logger: Logger = logging.getLogger(__name__) - - -def get_operand_value( - func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None -): - """Extract the value from an operand, handling variables and constants.""" - logger.info(f"Getting operand value for: {ast.dump(operand)}") - if isinstance(operand, ast.Name): - if operand.id in local_sym_tab: - var = local_sym_tab[operand.id].var - var_type = var.type - base_type, depth = get_base_type_and_depth(var_type) - logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}") - val = deref_to_depth(func, builder, var, depth) - return val - raise ValueError(f"Undefined variable: {operand.id}") - elif isinstance(operand, ast.Constant): - if isinstance(operand.value, int): - cst = ir.Constant(ir.IntType(64), int(operand.value)) - return cst - raise TypeError(f"Unsupported constant type: {type(operand.value)}") - elif isinstance(operand, ast.BinOp): - res = handle_binary_op_impl( - func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab - ) - return res - else: - res = eval_expr( - func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab - ) - if res is None: - raise ValueError(f"Failed to evaluate call expression: {operand}") - val, _ = res - logger.info(f"Evaluated expr to {val} of type {val.type}") - base_type, depth = get_base_type_and_depth(val.type) - if depth > 0: - val = deref_to_depth(func, builder, val, depth) - return val - raise TypeError(f"Unsupported operand type: {type(operand)}") - - -def handle_binary_op_impl( - func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None -): - op = rval.op - left = get_operand_value( - func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab - ) - right = get_operand_value( - func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab - ) - logger.info(f"left is {left}, right is {right}, op is {op}") - - # NOTE: Before doing the operation, if the operands are integers - # we always extend them to i64. The assignment to LHS will take - # care of truncation if needed. - if isinstance(left.type, ir.IntType) and left.type.width < 64: - left = builder.sext(left, ir.IntType(64)) - if isinstance(right.type, ir.IntType) and right.type.width < 64: - right = builder.sext(right, ir.IntType(64)) - - # Map AST operation nodes to LLVM IR builder methods - op_map = { - ast.Add: builder.add, - ast.Sub: builder.sub, - ast.Mult: builder.mul, - ast.Div: builder.sdiv, - ast.Mod: builder.srem, - ast.LShift: builder.shl, - ast.RShift: builder.lshr, - ast.BitOr: builder.or_, - ast.BitXor: builder.xor, - ast.BitAnd: builder.and_, - ast.FloorDiv: builder.udiv, - } - - if type(op) in op_map: - result = op_map[type(op)](left, right) - return result - else: - raise SyntaxError("Unsupported binary operation") - - -def handle_binary_op( - func, - module, - rval, - builder, - var_name, - local_sym_tab, - map_sym_tab, - structs_sym_tab=None, -): - result = handle_binary_op_impl( - func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab - ) - if var_name and var_name in local_sym_tab: - logger.info( - f"Storing result {result} into variable {local_sym_tab[var_name].var}" - ) - builder.store(result, local_sym_tab[var_name].var) - return result, result.type diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py index dd5b480..94cf330 100644 --- a/pythonbpf/expr/__init__.py +++ b/pythonbpf/expr/__init__.py @@ -1,4 +1,4 @@ -from .expr_pass import eval_expr, handle_expr +from .expr_pass import eval_expr, handle_expr, get_operand_value from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth __all__ = [ @@ -7,4 +7,5 @@ __all__ = [ "convert_to_bool", "get_base_type_and_depth", "deref_to_depth", + "get_operand_value", ] diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index 85645b0..f16fd46 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -5,11 +5,118 @@ import logging from typing import Dict from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes -from .type_normalization import convert_to_bool, handle_comparator +from .type_normalization import ( + convert_to_bool, + handle_comparator, + get_base_type_and_depth, + deref_to_depth, +) logger: Logger = logging.getLogger(__name__) +def get_operand_value( + func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None +): + """Extract the value from an operand, handling variables and constants.""" + logger.info(f"Getting operand value for: {ast.dump(operand)}") + if isinstance(operand, ast.Name): + if operand.id in local_sym_tab: + var = local_sym_tab[operand.id].var + var_type = var.type + base_type, depth = get_base_type_and_depth(var_type) + logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}") + val = deref_to_depth(func, builder, var, depth) + return val + raise ValueError(f"Undefined variable: {operand.id}") + elif isinstance(operand, ast.Constant): + if isinstance(operand.value, int): + cst = ir.Constant(ir.IntType(64), int(operand.value)) + return cst + raise TypeError(f"Unsupported constant type: {type(operand.value)}") + elif isinstance(operand, ast.BinOp): + res = _handle_binary_op_impl( + func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) + return res + else: + res = eval_expr( + func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if res is None: + raise ValueError(f"Failed to evaluate call expression: {operand}") + val, _ = res + logger.info(f"Evaluated expr to {val} of type {val.type}") + base_type, depth = get_base_type_and_depth(val.type) + if depth > 0: + val = deref_to_depth(func, builder, val, depth) + return val + raise TypeError(f"Unsupported operand type: {type(operand)}") + + +def _handle_binary_op_impl( + func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None +): + op = rval.op + left = get_operand_value( + func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) + right = get_operand_value( + func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) + logger.info(f"left is {left}, right is {right}, op is {op}") + + # NOTE: Before doing the operation, if the operands are integers + # we always extend them to i64. The assignment to LHS will take + # care of truncation if needed. + if isinstance(left.type, ir.IntType) and left.type.width < 64: + left = builder.sext(left, ir.IntType(64)) + if isinstance(right.type, ir.IntType) and right.type.width < 64: + right = builder.sext(right, ir.IntType(64)) + + # Map AST operation nodes to LLVM IR builder methods + op_map = { + ast.Add: builder.add, + ast.Sub: builder.sub, + ast.Mult: builder.mul, + ast.Div: builder.sdiv, + ast.Mod: builder.srem, + ast.LShift: builder.shl, + ast.RShift: builder.lshr, + ast.BitOr: builder.or_, + ast.BitXor: builder.xor, + ast.BitAnd: builder.and_, + ast.FloorDiv: builder.udiv, + } + + if type(op) in op_map: + result = op_map[type(op)](left, right) + return result + else: + raise SyntaxError("Unsupported binary operation") + + +def _handle_binary_op( + func, + module, + rval, + builder, + var_name, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + result = _handle_binary_op_impl( + func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if var_name and var_name in local_sym_tab: + logger.info( + f"Storing result {result} into variable {local_sym_tab[var_name].var}" + ) + builder.store(result, local_sym_tab[var_name].var) + return result, result.type + + def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): """Handle ast.Name expressions.""" if expr.id in local_sym_tab: @@ -194,8 +301,6 @@ def _handle_unary_op( logger.error("Only 'not' and '-' unary operators are supported") return None - from pythonbpf.binary_ops import get_operand_value - operand = get_operand_value( func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab ) @@ -421,9 +526,7 @@ def eval_expr( elif isinstance(expr, ast.Attribute): return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder) elif isinstance(expr, ast.BinOp): - from pythonbpf.binary_ops import handle_binary_op - - return handle_binary_op( + return _handle_binary_op( func, module, expr, diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 284aa68..b67058b 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -3,8 +3,12 @@ import logging from collections.abc import Callable from llvmlite import ir -from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth -from pythonbpf.binary_ops import get_operand_value +from pythonbpf.expr import ( + eval_expr, + get_base_type_and_depth, + deref_to_depth, + get_operand_value, +) logger = logging.getLogger(__name__)