mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Add null checks for pointer derefs to avoid map_value_or_null verifier errors
This commit is contained in:
@ -139,38 +139,84 @@ def _get_base_type_and_depth(ir_type):
|
|||||||
return cur_type, depth
|
return cur_type, depth
|
||||||
|
|
||||||
|
|
||||||
def _deref_to_depth(builder, val, target_depth):
|
def _deref_to_depth(func, builder, val, target_depth):
|
||||||
"""Dereference a pointer to a certain depth."""
|
"""Dereference a pointer to a certain depth."""
|
||||||
|
|
||||||
cur_val = val
|
cur_val = val
|
||||||
for _ in range(target_depth):
|
cur_type = val.type
|
||||||
|
|
||||||
|
for depth in range(target_depth):
|
||||||
if not isinstance(val.type, ir.PointerType):
|
if not isinstance(val.type, ir.PointerType):
|
||||||
logger.error("Cannot dereference further, non-pointer type")
|
logger.error("Cannot dereference further, non-pointer type")
|
||||||
return None
|
return None
|
||||||
cur_val = builder.load(cur_val)
|
|
||||||
|
# dereference with null check
|
||||||
|
pointee_type = cur_type.pointee
|
||||||
|
null_check_block = builder.block
|
||||||
|
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
|
||||||
|
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
|
||||||
|
|
||||||
|
null_ptr = ir.Constant(cur_type, None)
|
||||||
|
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
|
||||||
|
logger.debug(f"Inserted null check for pointer at depth {depth}")
|
||||||
|
|
||||||
|
builder.cbranch(is_not_null, not_null_block, merge_block)
|
||||||
|
|
||||||
|
builder.position_at_end(not_null_block)
|
||||||
|
dereferenced_val = builder.load(cur_val)
|
||||||
|
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
|
||||||
|
builder.branch(merge_block)
|
||||||
|
|
||||||
|
builder.position_at_end(merge_block)
|
||||||
|
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
|
||||||
|
|
||||||
|
zero_value = (
|
||||||
|
ir.Constant(pointee_type, 0)
|
||||||
|
if isinstance(pointee_type, ir.IntType)
|
||||||
|
else ir.Constant(pointee_type, None)
|
||||||
|
)
|
||||||
|
phi.add_incoming(zero_value, null_check_block)
|
||||||
|
|
||||||
|
phi.add_incoming(dereferenced_val, not_null_block)
|
||||||
|
|
||||||
|
# Continue with phi result
|
||||||
|
cur_val = phi
|
||||||
|
cur_type = pointee_type
|
||||||
return cur_val
|
return cur_val
|
||||||
|
|
||||||
|
|
||||||
def _normalize_types(builder, lhs, rhs):
|
def _normalize_types(func, builder, lhs, rhs):
|
||||||
"""Normalize types for comparison."""
|
"""Normalize types for comparison."""
|
||||||
|
|
||||||
|
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
|
||||||
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
|
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
|
||||||
if lhs.type.width < rhs.type.width:
|
if lhs.type.width < rhs.type.width:
|
||||||
lhs = builder.sext(lhs, rhs.type)
|
lhs = builder.sext(lhs, rhs.type)
|
||||||
else:
|
else:
|
||||||
rhs = builder.sext(rhs, lhs.type)
|
rhs = builder.sext(rhs, lhs.type)
|
||||||
return lhs, rhs
|
return lhs, rhs
|
||||||
|
elif not isinstance(lhs.type, ir.PointerType) and not isinstance(
|
||||||
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
|
rhs.type, ir.PointerType
|
||||||
return None, None
|
):
|
||||||
|
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
|
||||||
|
return None, None
|
||||||
|
else:
|
||||||
|
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
|
||||||
|
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
|
||||||
|
if lhs_base == rhs_base:
|
||||||
|
if lhs_depth < rhs_depth:
|
||||||
|
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
|
||||||
|
elif rhs_depth < lhs_depth:
|
||||||
|
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
|
||||||
|
return _normalize_types(func, builder, lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
def _handle_comparator(builder, op, lhs, rhs):
|
def _handle_comparator(func, builder, op, lhs, rhs):
|
||||||
"""Handle comparison operations."""
|
"""Handle comparison operations."""
|
||||||
|
|
||||||
# NOTE: For now assume same types
|
# NOTE: For now assume same types
|
||||||
if lhs.type != rhs.type:
|
if lhs.type != rhs.type:
|
||||||
lhs, rhs = _normalize_types(builder, lhs, rhs)
|
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
|
||||||
|
|
||||||
if lhs is None or rhs is None:
|
if lhs is None or rhs is None:
|
||||||
return None
|
return None
|
||||||
@ -227,7 +273,7 @@ def _handle_compare(
|
|||||||
|
|
||||||
lhs, _ = lhs
|
lhs, _ = lhs
|
||||||
rhs, _ = rhs
|
rhs, _ = rhs
|
||||||
return _handle_comparator(builder, cond.ops[0], lhs, rhs)
|
return _handle_comparator(func, builder, cond.ops[0], lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bool(builder, val):
|
def convert_to_bool(builder, val):
|
||||||
|
|||||||
Reference in New Issue
Block a user