Add comparison ops

This commit is contained in:
Pragyansh Chaturvedi
2025-09-11 01:52:30 +05:30
parent 3dd3784ec4
commit 4f726a7a1a
2 changed files with 58 additions and 4 deletions

View File

@ -101,8 +101,26 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab):
def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
"""Handle expression statements in the function body."""
print(f"Handling expression: {ast.dump(expr)}")
if isinstance(expr, ast.Name):
if expr.id in local_sym_tab:
var = local_sym_tab[expr.id]
val = builder.load(var)
return val
else:
print(f"Undefined variable {expr.id}")
return None
elif isinstance(expr, ast.Constant):
if isinstance(expr.value, int):
return ir.Constant(ir.IntType(64), expr.value)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value))
else:
print("Unsupported constant type")
return None
call = expr.value
print(f"Handling expression: {ast.dump(call)}")
if isinstance(call, ast.Call):
if isinstance(call.func, ast.Name):
# check for helpers first
@ -154,6 +172,42 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
else:
print(f"Undefined variable {cond.id} in condition")
return None
elif isinstance(cond, ast.Compare):
lhs = handle_expr(func, module, builder, cond.left,
local_sym_tab, map_sym_tab)
if len(cond.ops) != 1 or len(cond.comparators) != 1:
print("Unsupported complex comparison")
return None
rhs = handle_expr(func, module, builder,
cond.comparators[0], local_sym_tab, map_sym_tab)
op = cond.ops[0]
if lhs.type != rhs.type:
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
# Extend the smaller type to the larger type
if lhs.type.width < rhs.type.width:
lhs = builder.sext(lhs, rhs.type)
elif lhs.type.width > rhs.type.width:
rhs = builder.sext(rhs, lhs.type)
else:
print("Type mismatch in comparison")
return None
if isinstance(op, ast.Eq):
return builder.icmp_signed("==", lhs, rhs)
elif isinstance(op, ast.NotEq):
return builder.icmp_signed("!=", lhs, rhs)
elif isinstance(op, ast.Lt):
return builder.icmp_signed("<", lhs, rhs)
elif isinstance(op, ast.LtE):
return builder.icmp_signed("<=", lhs, rhs)
elif isinstance(op, ast.Gt):
return builder.icmp_signed(">", lhs, rhs)
elif isinstance(op, ast.GtE):
return builder.icmp_signed(">=", lhs, rhs)
else:
print("Unsupported comparison operator")
return None
else:
print("Unsupported condition expression")
return None