diff --git a/pythonbpf/binary_ops.py b/pythonbpf/binary_ops.py index eaf2e42..5594856 100644 --- a/pythonbpf/binary_ops.py +++ b/pythonbpf/binary_ops.py @@ -8,42 +8,35 @@ logger: Logger = logging.getLogger(__name__) def recursive_dereferencer(var, builder): """dereference until primitive type comes out""" + # TODO: Not worrying about stack overflow for now if isinstance(var.type, ir.PointerType): a = builder.load(var) return recursive_dereferencer(a, builder) - elif var.type == ir.IntType(64): + elif isinstance(var.type, ir.IntType): return var else: raise TypeError(f"Unsupported type for dereferencing: {var.type}") +def get_operand_value(operand, builder, local_sym_tab): + """Extract the value from an operand, handling variables and constants.""" + if isinstance(operand, ast.Name): + if operand.id in local_sym_tab: + return recursive_dereferencer(local_sym_tab[operand.id].var, builder) + raise ValueError(f"Undefined variable: {operand.id}") + elif isinstance(operand, ast.Constant): + if isinstance(operand.value, int): + return ir.Constant(ir.IntType(64), operand.value) + raise TypeError(f"Unsupported constant type: {type(operand.value)}") + raise TypeError(f"Unsupported operand type: {type(operand)}") + + def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab, func): logger.info(f"module {module}") - left = rval.left - right = rval.right op = rval.op - # Handle left operand - if isinstance(left, ast.Name): - if left.id in local_sym_tab: - left = recursive_dereferencer(local_sym_tab[left.id].var, builder) - else: - raise SyntaxError(f"Undefined variable: {left.id}") - elif isinstance(left, ast.Constant): - left = ir.Constant(ir.IntType(64), left.value) - else: - raise SyntaxError("Unsupported left operand type") - - if isinstance(right, ast.Name): - if right.id in local_sym_tab: - right = recursive_dereferencer(local_sym_tab[right.id].var, builder) - else: - raise SyntaxError(f"Undefined variable: {right.id}") - elif isinstance(right, ast.Constant): - right = ir.Constant(ir.IntType(64), right.value) - else: - raise SyntaxError("Unsupported right operand type") - + left = get_operand_value(rval.left, builder, local_sym_tab) + right = get_operand_value(rval.right, builder, local_sym_tab) logger.info(f"left is {left}, right is {right}, op is {op}") if isinstance(op, ast.Add):