4 Commits

View File

@ -6,38 +6,57 @@ import logging
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
def recursive_dereferencer(var, builder): def deref_to_val(var, builder):
"""dereference until primitive type comes out""" """Dereference a variable to get its value and pointer chain."""
# TODO: Not worrying about stack overflow for now
logger.info(f"Dereferencing {var}, type is {var.type}") logger.info(f"Dereferencing {var}, type is {var.type}")
if isinstance(var.type, ir.PointerType):
a = builder.load(var) chain = [var]
return recursive_dereferencer(a, builder) cur = var
elif isinstance(var.type, ir.IntType):
return var while isinstance(cur.type, ir.PointerType):
cur = builder.load(cur)
chain.append(cur)
if isinstance(cur.type, ir.IntType):
logger.info(f"dereference chain: {chain}")
return cur, chain
else: else:
raise TypeError(f"Unsupported type for dereferencing: {var.type}") raise TypeError(f"Unsupported type for dereferencing: {cur.type}")
def get_operand_value(operand, builder, local_sym_tab): def get_operand_value(operand, builder, local_sym_tab):
"""Extract the value from an operand, handling variables and constants.""" """Extract the value from an operand, handling variables and constants."""
if isinstance(operand, ast.Name): if isinstance(operand, ast.Name):
if operand.id in local_sym_tab: if operand.id in local_sym_tab:
return recursive_dereferencer(local_sym_tab[operand.id].var, builder) var = local_sym_tab[operand.id].var
val, chain = deref_to_val(var, builder)
return val, chain, var
raise ValueError(f"Undefined variable: {operand.id}") raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant): elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int): if isinstance(operand.value, int):
return ir.Constant(ir.IntType(64), operand.value) cst = ir.Constant(ir.IntType(64), operand.value)
return cst, [cst], None
raise TypeError(f"Unsupported constant type: {type(operand.value)}") raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp): elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, builder, local_sym_tab) res = handle_binary_op_impl(operand, builder, local_sym_tab)
return res, [res], None
raise TypeError(f"Unsupported operand type: {type(operand)}") raise TypeError(f"Unsupported operand type: {type(operand)}")
def store_through_chain(value, chain, builder):
"""Store a value through a pointer chain."""
if not chain or len(chain) < 2:
raise ValueError("Pointer chain must have at least two elements")
for ptr in reversed(chain[1:]):
builder.store(value, ptr)
value = ptr
def handle_binary_op_impl(rval, builder, local_sym_tab): def handle_binary_op_impl(rval, builder, local_sym_tab):
op = rval.op op = rval.op
left = get_operand_value(rval.left, builder, local_sym_tab) left, _, _ = get_operand_value(rval.left, builder, local_sym_tab)
right = get_operand_value(rval.right, builder, local_sym_tab) right, _, _ = get_operand_value(rval.right, builder, local_sym_tab)
logger.info(f"left is {left}, right is {right}, op is {op}") logger.info(f"left is {left}, right is {right}, op is {op}")
# Map AST operation nodes to LLVM IR builder methods # Map AST operation nodes to LLVM IR builder methods