Add null checks for pointer derefs to avoid map_value_or_null verifier errors

This commit is contained in:
Pragyansh Chaturvedi
2025-10-08 05:53:12 +05:30
parent 3f9604a370
commit 6b59980874

View File

@ -139,38 +139,84 @@ def _get_base_type_and_depth(ir_type):
return cur_type, depth
def _deref_to_depth(builder, val, target_depth):
def _deref_to_depth(func, builder, val, target_depth):
"""Dereference a pointer to a certain depth."""
cur_val = val
for _ in range(target_depth):
cur_type = val.type
for depth in range(target_depth):
if not isinstance(val.type, ir.PointerType):
logger.error("Cannot dereference further, non-pointer type")
return None
cur_val = builder.load(cur_val)
# dereference with null check
pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
null_ptr = ir.Constant(cur_type, None)
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
builder.cbranch(is_not_null, not_null_block, merge_block)
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
)
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type
return cur_val
def _normalize_types(builder, lhs, rhs):
def _normalize_types(func, builder, lhs, rhs):
"""Normalize types for comparison."""
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
if lhs.type.width < rhs.type.width:
lhs = builder.sext(lhs, rhs.type)
else:
rhs = builder.sext(rhs, lhs.type)
return lhs, rhs
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
return None, None
elif not isinstance(lhs.type, ir.PointerType) and not isinstance(
rhs.type, ir.PointerType
):
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
return None, None
else:
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
if lhs_base == rhs_base:
if lhs_depth < rhs_depth:
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
elif rhs_depth < lhs_depth:
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
return _normalize_types(func, builder, lhs, rhs)
def _handle_comparator(builder, op, lhs, rhs):
def _handle_comparator(func, builder, op, lhs, rhs):
"""Handle comparison operations."""
# NOTE: For now assume same types
if lhs.type != rhs.type:
lhs, rhs = _normalize_types(builder, lhs, rhs)
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
if lhs is None or rhs is None:
return None
@ -227,7 +273,7 @@ def _handle_compare(
lhs, _ = lhs
rhs, _ = rhs
return _handle_comparator(builder, cond.ops[0], lhs, rhs)
return _handle_comparator(func, builder, cond.ops[0], lhs, rhs)
def convert_to_bool(builder, val):