diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index a72a4bb..21be196 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -5,7 +5,7 @@ import logging from typing import Dict from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes -from .type_normalization import normalize_types, convert_to_bool +from .type_normalization import convert_to_bool, handle_comparator logger: Logger = logging.getLogger(__name__) @@ -130,37 +130,6 @@ def _handle_ctypes_call( return val -def _handle_comparator(func, builder, op, lhs, rhs): - """Handle comparison operations.""" - - # NOTE: For now assume same types - if lhs.type != rhs.type: - lhs, rhs = normalize_types(func, builder, lhs, rhs) - - if lhs is None or rhs is None: - return None - - comparison_ops = { - ast.Eq: "==", - ast.NotEq: "!=", - ast.Lt: "<", - ast.LtE: "<=", - ast.Gt: ">", - ast.GtE: ">=", - ast.Is: "==", - ast.IsNot: "!=", - } - - if type(op) not in comparison_ops: - logger.error(f"Unsupported comparison operator: {type(op)}") - return None - - predicate = comparison_ops[type(op)] - result = builder.icmp_signed(predicate, lhs, rhs) - logger.debug(f"Comparison result: {result}") - return result, ir.IntType(1) - - def _handle_compare( func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None ): @@ -194,7 +163,7 @@ def _handle_compare( lhs, _ = lhs rhs, _ = rhs - return _handle_comparator(func, builder, cond.ops[0], lhs, rhs) + return handle_comparator(func, builder, cond.ops[0], lhs, rhs) def _handle_unary_op( diff --git a/pythonbpf/expr/type_normalization.py b/pythonbpf/expr/type_normalization.py index 2715e1f..7a2fb57 100644 --- a/pythonbpf/expr/type_normalization.py +++ b/pythonbpf/expr/type_normalization.py @@ -1,8 +1,20 @@ from llvmlite import ir import logging +import ast logger = logging.getLogger(__name__) +COMPARISON_OPS = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Is: "==", + ast.IsNot: "!=", +} + def _get_base_type_and_depth(ir_type): """Get the base type for pointer types.""" @@ -60,7 +72,7 @@ def _deref_to_depth(func, builder, val, target_depth): return cur_val -def normalize_types(func, builder, lhs, rhs): +def _normalize_types(func, builder, lhs, rhs): """Normalize types for comparison.""" logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}") @@ -83,7 +95,7 @@ def normalize_types(func, builder, lhs, rhs): rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth) elif rhs_depth < lhs_depth: lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth) - return normalize_types(func, builder, lhs, rhs) + return _normalize_types(func, builder, lhs, rhs) def convert_to_bool(builder, val): @@ -95,3 +107,22 @@ def convert_to_bool(builder, val): else: zero = ir.Constant(val.type, 0) return builder.icmp_signed("!=", val, zero) + + +def handle_comparator(func, builder, op, lhs, rhs): + """Handle comparison operations.""" + + if lhs.type != rhs.type: + lhs, rhs = _normalize_types(func, builder, lhs, rhs) + + if lhs is None or rhs is None: + return None + + if type(op) not in COMPARISON_OPS: + logger.error(f"Unsupported comparison operator: {type(op)}") + return None + + predicate = COMPARISON_OPS[type(op)] + result = builder.icmp_signed(predicate, lhs, rhs) + logger.debug(f"Comparison result: {result}") + return result, ir.IntType(1)