mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Move handle_comparator to type_normalization
This commit is contained in:
@ -5,7 +5,7 @@ import logging
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
|
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
|
||||||
from .type_normalization import normalize_types, convert_to_bool
|
from .type_normalization import convert_to_bool, handle_comparator
|
||||||
|
|
||||||
logger: Logger = logging.getLogger(__name__)
|
logger: Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -130,37 +130,6 @@ def _handle_ctypes_call(
|
|||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
def _handle_comparator(func, builder, op, lhs, rhs):
|
|
||||||
"""Handle comparison operations."""
|
|
||||||
|
|
||||||
# NOTE: For now assume same types
|
|
||||||
if lhs.type != rhs.type:
|
|
||||||
lhs, rhs = normalize_types(func, builder, lhs, rhs)
|
|
||||||
|
|
||||||
if lhs is None or rhs is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
comparison_ops = {
|
|
||||||
ast.Eq: "==",
|
|
||||||
ast.NotEq: "!=",
|
|
||||||
ast.Lt: "<",
|
|
||||||
ast.LtE: "<=",
|
|
||||||
ast.Gt: ">",
|
|
||||||
ast.GtE: ">=",
|
|
||||||
ast.Is: "==",
|
|
||||||
ast.IsNot: "!=",
|
|
||||||
}
|
|
||||||
|
|
||||||
if type(op) not in comparison_ops:
|
|
||||||
logger.error(f"Unsupported comparison operator: {type(op)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
predicate = comparison_ops[type(op)]
|
|
||||||
result = builder.icmp_signed(predicate, lhs, rhs)
|
|
||||||
logger.debug(f"Comparison result: {result}")
|
|
||||||
return result, ir.IntType(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_compare(
|
def _handle_compare(
|
||||||
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||||
):
|
):
|
||||||
@ -194,7 +163,7 @@ def _handle_compare(
|
|||||||
|
|
||||||
lhs, _ = lhs
|
lhs, _ = lhs
|
||||||
rhs, _ = rhs
|
rhs, _ = rhs
|
||||||
return _handle_comparator(func, builder, cond.ops[0], lhs, rhs)
|
return handle_comparator(func, builder, cond.ops[0], lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
def _handle_unary_op(
|
def _handle_unary_op(
|
||||||
|
|||||||
@ -1,8 +1,20 @@
|
|||||||
from llvmlite import ir
|
from llvmlite import ir
|
||||||
import logging
|
import logging
|
||||||
|
import ast
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
COMPARISON_OPS = {
|
||||||
|
ast.Eq: "==",
|
||||||
|
ast.NotEq: "!=",
|
||||||
|
ast.Lt: "<",
|
||||||
|
ast.LtE: "<=",
|
||||||
|
ast.Gt: ">",
|
||||||
|
ast.GtE: ">=",
|
||||||
|
ast.Is: "==",
|
||||||
|
ast.IsNot: "!=",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_base_type_and_depth(ir_type):
|
def _get_base_type_and_depth(ir_type):
|
||||||
"""Get the base type for pointer types."""
|
"""Get the base type for pointer types."""
|
||||||
@ -60,7 +72,7 @@ def _deref_to_depth(func, builder, val, target_depth):
|
|||||||
return cur_val
|
return cur_val
|
||||||
|
|
||||||
|
|
||||||
def normalize_types(func, builder, lhs, rhs):
|
def _normalize_types(func, builder, lhs, rhs):
|
||||||
"""Normalize types for comparison."""
|
"""Normalize types for comparison."""
|
||||||
|
|
||||||
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
|
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
|
||||||
@ -83,7 +95,7 @@ def normalize_types(func, builder, lhs, rhs):
|
|||||||
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
|
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
|
||||||
elif rhs_depth < lhs_depth:
|
elif rhs_depth < lhs_depth:
|
||||||
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
|
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
|
||||||
return normalize_types(func, builder, lhs, rhs)
|
return _normalize_types(func, builder, lhs, rhs)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_bool(builder, val):
|
def convert_to_bool(builder, val):
|
||||||
@ -95,3 +107,22 @@ def convert_to_bool(builder, val):
|
|||||||
else:
|
else:
|
||||||
zero = ir.Constant(val.type, 0)
|
zero = ir.Constant(val.type, 0)
|
||||||
return builder.icmp_signed("!=", val, zero)
|
return builder.icmp_signed("!=", val, zero)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_comparator(func, builder, op, lhs, rhs):
|
||||||
|
"""Handle comparison operations."""
|
||||||
|
|
||||||
|
if lhs.type != rhs.type:
|
||||||
|
lhs, rhs = _normalize_types(func, builder, lhs, rhs)
|
||||||
|
|
||||||
|
if lhs is None or rhs is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if type(op) not in COMPARISON_OPS:
|
||||||
|
logger.error(f"Unsupported comparison operator: {type(op)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
predicate = COMPARISON_OPS[type(op)]
|
||||||
|
result = builder.icmp_signed(predicate, lhs, rhs)
|
||||||
|
logger.debug(f"Comparison result: {result}")
|
||||||
|
return result, ir.IntType(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user