Move handle_comparator to type_normalization

This commit is contained in:
Pragyansh Chaturvedi
2025-10-08 07:20:04 +05:30
parent 0a6571726a
commit d38d73d5c6
2 changed files with 35 additions and 35 deletions

View File

@ -5,7 +5,7 @@ import logging
from typing import Dict
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
from .type_normalization import normalize_types, convert_to_bool
from .type_normalization import convert_to_bool, handle_comparator
logger: Logger = logging.getLogger(__name__)
@ -130,37 +130,6 @@ def _handle_ctypes_call(
return val
def _handle_comparator(func, builder, op, lhs, rhs):
"""Handle comparison operations."""
# NOTE: For now assume same types
if lhs.type != rhs.type:
lhs, rhs = normalize_types(func, builder, lhs, rhs)
if lhs is None or rhs is None:
return None
comparison_ops = {
ast.Eq: "==",
ast.NotEq: "!=",
ast.Lt: "<",
ast.LtE: "<=",
ast.Gt: ">",
ast.GtE: ">=",
ast.Is: "==",
ast.IsNot: "!=",
}
if type(op) not in comparison_ops:
logger.error(f"Unsupported comparison operator: {type(op)}")
return None
predicate = comparison_ops[type(op)]
result = builder.icmp_signed(predicate, lhs, rhs)
logger.debug(f"Comparison result: {result}")
return result, ir.IntType(1)
def _handle_compare(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
@ -194,7 +163,7 @@ def _handle_compare(
lhs, _ = lhs
rhs, _ = rhs
return _handle_comparator(func, builder, cond.ops[0], lhs, rhs)
return handle_comparator(func, builder, cond.ops[0], lhs, rhs)
def _handle_unary_op(

View File

@ -1,8 +1,20 @@
from llvmlite import ir
import logging
import ast
logger = logging.getLogger(__name__)
COMPARISON_OPS = {
ast.Eq: "==",
ast.NotEq: "!=",
ast.Lt: "<",
ast.LtE: "<=",
ast.Gt: ">",
ast.GtE: ">=",
ast.Is: "==",
ast.IsNot: "!=",
}
def _get_base_type_and_depth(ir_type):
"""Get the base type for pointer types."""
@ -60,7 +72,7 @@ def _deref_to_depth(func, builder, val, target_depth):
return cur_val
def normalize_types(func, builder, lhs, rhs):
def _normalize_types(func, builder, lhs, rhs):
"""Normalize types for comparison."""
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
@ -83,7 +95,7 @@ def normalize_types(func, builder, lhs, rhs):
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
elif rhs_depth < lhs_depth:
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
return normalize_types(func, builder, lhs, rhs)
return _normalize_types(func, builder, lhs, rhs)
def convert_to_bool(builder, val):
@ -95,3 +107,22 @@ def convert_to_bool(builder, val):
else:
zero = ir.Constant(val.type, 0)
return builder.icmp_signed("!=", val, zero)
def handle_comparator(func, builder, op, lhs, rhs):
"""Handle comparison operations."""
if lhs.type != rhs.type:
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
if lhs is None or rhs is None:
return None
if type(op) not in COMPARISON_OPS:
logger.error(f"Unsupported comparison operator: {type(op)}")
return None
predicate = COMPARISON_OPS[type(op)]
result = builder.icmp_signed(predicate, lhs, rhs)
logger.debug(f"Comparison result: {result}")
return result, ir.IntType(1)