"""
Expression evaluation and LLVM IR generation.

This module handles the evaluation of Python expressions in BPF programs,
including variables, constants, function calls, comparisons, boolean
operations, and more.

Convention used throughout: every handler returns a ``(value, ir_type)``
tuple on success and ``None`` (after logging) when the expression cannot be
evaluated.
"""

import ast
from llvmlite import ir
from logging import Logger
import logging
from typing import Dict

from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
from .type_normalization import convert_to_bool, handle_comparator

logger: Logger = logging.getLogger(__name__)


def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
    """Handle ast.Name expressions.

    Emits a load from the symbol's ``var`` and returns the loaded value
    together with the symbol's ``ir_type``. Returns None when the name is not
    present in ``local_sym_tab``.
    """
    if expr.id not in local_sym_tab:
        logger.info(f"Undefined variable {expr.id}")
        return None

    symbol = local_sym_tab[expr.id]
    val = builder.load(symbol.var)
    return val, symbol.ir_type


def _handle_constant_expr(expr: ast.Constant):
    """Handle ast.Constant expressions.

    Integer and boolean literals become 64-bit integer constants; any other
    constant type is rejected with None.
    """
    # bool is a subclass of int, so a single check covers both literal kinds.
    if isinstance(expr.value, int):
        return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)

    logger.error("Unsupported constant type")
    return None


def _handle_attribute_expr(
    expr: ast.Attribute,
    local_sym_tab: Dict,
    structs_sym_tab: Dict,
    builder: ir.IRBuilder,
):
    """Handle ast.Attribute expressions for struct field access.

    Only ``<name>.<field>`` is handled, where ``<name>`` is a local variable
    whose symbol metadata is a key of ``structs_sym_tab``. Returns the loaded
    field value and the field's type, or None when the access cannot be
    resolved.
    """
    if not isinstance(expr.value, ast.Name):
        return None

    var_name = expr.value.id
    attr_name = expr.attr
    if var_name not in local_sym_tab:
        return None

    var_ptr, var_type, var_metadata = local_sym_tab[var_name]
    logger.info(f"Loading attribute {attr_name} from variable {var_name}")
    logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")

    # eval_expr defaults structs_sym_tab to None, and a non-struct variable
    # has no entry in the table; report both instead of raising.
    if not structs_sym_tab or var_metadata not in structs_sym_tab:
        logger.error(
            f"Variable {var_name} is not a known struct; "
            f"cannot load attribute {attr_name}"
        )
        return None

    metadata = structs_sym_tab[var_metadata]
    if attr_name not in metadata.fields:
        return None

    gep = metadata.gep(builder, var_ptr, attr_name)
    val = builder.load(gep)
    field_type = metadata.field_type(attr_name)
    return val, field_type


def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder):
    """Handle deref function calls.

    Accepts exactly one argument, which must be the name of a local variable.
    Nested ``deref(deref(x))`` is rejected. Returns the value loaded from the
    symbol's ``var`` together with the symbol's ``ir_type``, or None on any
    unsupported form.
    """
    logger.info(f"Handling deref {ast.dump(expr)}")
    if len(expr.args) != 1:
        logger.info("deref takes exactly one argument")
        return None

    arg = expr.args[0]
    if (
        isinstance(arg, ast.Call)
        and isinstance(arg.func, ast.Name)
        and arg.func.id == "deref"
    ):
        logger.info("Multiple deref not supported")
        return None

    if not isinstance(arg, ast.Name):
        logger.info("Unsupported argument type for deref")
        return None

    if arg.id not in local_sym_tab:
        logger.info(f"Undefined variable {arg.id}")
        return None

    arg_ptr = local_sym_tab[arg.id].var
    if arg_ptr is None:
        logger.info("Failed to evaluate deref argument")
        return None

    # Load the value from pointer
    val = builder.load(arg_ptr)
    return val, local_sym_tab[arg.id].ir_type


def _handle_ctypes_call(
    func,
    module,
    builder,
    expr,
    local_sym_tab,
    map_sym_tab,
    structs_sym_tab=None,
):
    """Handle ctypes type constructor calls.

    Evaluates the single argument and, if its type differs from the IR type
    that ``ctypes_to_ir`` reports for the constructor name, converts it.

    Raises:
        ValueError: if the types differ and either one is not an integer type.
    """
    if len(expr.args) != 1:
        logger.info("ctypes constructor takes exactly one argument")
        return None

    arg = expr.args[0]
    val = eval_expr(
        func,
        module,
        builder,
        arg,
        local_sym_tab,
        map_sym_tab,
        structs_sym_tab,
    )
    if val is None:
        logger.info("Failed to evaluate argument to ctypes constructor")
        return None

    call_type = expr.func.id
    expected_type = ctypes_to_ir(call_type)

    if val[1] != expected_type:
        # NOTE: We are only considering casting to and from int types for now
        if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType):
            if val[1].width < expected_type.width:
                # Widening is done with sign extension.
                val = (builder.sext(val[0], expected_type), expected_type)
            else:
                val = (builder.trunc(val[0], expected_type), expected_type)
        else:
            raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}")
    return val


def _handle_compare(
    func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
    """Handle ast.Compare expressions.

    Only a single comparison (one operator, one comparator) is supported;
    chained comparisons such as ``a < b < c`` return None. The operand values
    are handed to ``handle_comparator`` and its result is returned unchanged.
    """
    if len(cond.ops) != 1 or len(cond.comparators) != 1:
        logger.error("Only single comparisons are supported")
        return None

    lhs = eval_expr(
        func,
        module,
        builder,
        cond.left,
        local_sym_tab,
        map_sym_tab,
        structs_sym_tab,
    )
    rhs = eval_expr(
        func,
        module,
        builder,
        cond.comparators[0],
        local_sym_tab,
        map_sym_tab,
        structs_sym_tab,
    )

    if lhs is None or rhs is None:
        logger.error("Failed to evaluate comparison operands")
        return None

    # Only the IR values are needed; the operand types are discarded here.
    lhs, _ = lhs
    rhs, _ = rhs
    return handle_comparator(func, builder, cond.ops[0], lhs, rhs)


def _handle_unary_op(
    func,
    module,
    builder,
    expr: ast.UnaryOp,
    local_sym_tab,
    map_sym_tab,
    structs_sym_tab=None,
):
    """Handle ast.UnaryOp expressions.

    Only ``not`` is supported. The operand is converted with
    ``convert_to_bool`` and inverted by XOR-ing with the i1 constant 1; the
    result type is i1.
    """
    if not isinstance(expr.op, ast.Not):
        logger.error("Only 'not' unary operator is supported")
        return None

    operand = eval_expr(
        func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab
    )
    if operand is None:
        logger.error("Failed to evaluate operand for unary operation")
        return None

    operand_val, _ = operand
    true_const = ir.Constant(ir.IntType(1), 1)
    result = builder.xor(convert_to_bool(builder, operand_val), true_const)
    return result, ir.IntType(1)


def _handle_short_circuit_op(
    func,
    module,
    builder,
    expr,
    local_sym_tab,
    map_sym_tab,
    structs_sym_tab,
    *,
    is_and,
):
    """Emit short-circuit control flow shared by `and` and `or`.

    Each operand except the last ends its block with a conditional branch:
    either on to the next operand's check block or to a shared short-circuit
    block. The last operand branches straight to the merge block. A phi node
    in the merge block selects between the last operand's boolean value and
    the short-circuit constant (0 for `and`, 1 for `or`).

    Returns ``(phi, i1)``, or None if an operand fails to evaluate. In the
    failure case the blocks already appended to ``func`` are left as they are.
    """
    op_name = "and" if is_and else "or"
    short_name = "false" if is_and else "true"
    logger.debug(f"Handling '{op_name}' operator with {len(expr.values)} operands")

    merge_block = func.append_basic_block(name=f"{op_name}.merge")
    short_block = func.append_basic_block(name=f"{op_name}.{short_name}")

    incoming_values = []
    last_index = len(expr.values) - 1

    for i, value in enumerate(expr.values):
        # Evaluate current operand
        operand_result = eval_expr(
            func, module, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
        )
        if operand_result is None:
            logger.error(f"Failed to evaluate operand {i} in '{op_name}' expression")
            return None

        operand_val, _ = operand_result

        # Convert to boolean if needed
        operand_bool = convert_to_bool(builder, operand_val)

        # Read the block only after evaluation: a nested and/or operand moves
        # the builder into its own merge block, and the phi must name the
        # block that actually branches to merge_block.
        current_block = builder.block

        if i == last_index:
            # Last operand: result is this value
            builder.branch(merge_block)
            incoming_values.append((operand_bool, current_block))
        else:
            next_check = func.append_basic_block(name=f"{op_name}.check_{i + 1}")
            if is_and:
                # `and`: keep going while true, short-circuit on false
                builder.cbranch(operand_bool, next_check, short_block)
            else:
                # `or`: short-circuit on true, keep going while false
                builder.cbranch(operand_bool, short_block, next_check)
            builder.position_at_end(next_check)

    # Short-circuit block: contributes the constant result
    builder.position_at_end(short_block)
    builder.branch(merge_block)
    short_value = ir.Constant(ir.IntType(1), 0 if is_and else 1)
    incoming_values.append((short_value, short_block))

    # Merge block: phi node
    builder.position_at_end(merge_block)
    phi = builder.phi(ir.IntType(1), name=f"{op_name}.result")
    for val, block in incoming_values:
        phi.add_incoming(val, block)

    logger.debug(f"Generated '{op_name}' with {len(incoming_values)} incoming values")
    return phi, ir.IntType(1)


def _handle_and_op(
    func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab, module=None
):
    """Handle `and` boolean operations.

    ``module`` is forwarded to ``eval_expr`` for every operand. It defaults to
    None so existing positional callers keep working.
    """
    return _handle_short_circuit_op(
        func,
        module,
        builder,
        expr,
        local_sym_tab,
        map_sym_tab,
        structs_sym_tab,
        is_and=True,
    )


def _handle_or_op(
    func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab, module=None
):
    """Handle `or` boolean operations.

    ``module`` is forwarded to ``eval_expr`` for every operand. It defaults to
    None so existing positional callers keep working.
    """
    return _handle_short_circuit_op(
        func,
        module,
        builder,
        expr,
        local_sym_tab,
        map_sym_tab,
        structs_sym_tab,
        is_and=False,
    )


def _handle_boolean_op(
    func,
    module,
    builder,
    expr: ast.BoolOp,
    local_sym_tab,
    map_sym_tab,
    structs_sym_tab=None,
):
    """Handle `and` and `or` boolean operations.

    Dispatches on the operator and passes ``module`` along so that operands
    are evaluated with the same module as any other expression.
    """
    if isinstance(expr.op, ast.And):
        return _handle_and_op(
            func,
            builder,
            expr,
            local_sym_tab,
            map_sym_tab,
            structs_sym_tab,
            module=module,
        )
    elif isinstance(expr.op, ast.Or):
        return _handle_or_op(
            func,
            builder,
            expr,
            local_sym_tab,
            map_sym_tab,
            structs_sym_tab,
            module=module,
        )
    else:
        logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}")
        return None


def eval_expr(
    func,
    module,
    builder,
    expr,
    local_sym_tab,
    map_sym_tab,
    structs_sym_tab=None,
):
    """
    Evaluate an expression and return its LLVM IR value and type.

    Args:
        func: The LLVM IR function being built
        module: The LLVM IR module
        builder: LLVM IR builder
        expr: The AST expression node to evaluate
        local_sym_tab: Local symbol table
        map_sym_tab: Map symbol table
        structs_sym_tab: Struct symbol table

    Returns:
        A tuple of (value, type) or None if evaluation fails
    """
    logger.info(f"Evaluating expression: {ast.dump(expr)}")
    if isinstance(expr, ast.Name):
        return _handle_name_expr(expr, local_sym_tab, builder)
    elif isinstance(expr, ast.Constant):
        return _handle_constant_expr(expr)
    elif isinstance(expr, ast.Call):
        if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
            return _handle_deref_call(expr, local_sym_tab, builder)

        if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
            return _handle_ctypes_call(
                func,
                module,
                builder,
                expr,
                local_sym_tab,
                map_sym_tab,
                structs_sym_tab,
            )

        # delayed import to avoid circular dependency
        from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call

        if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler(
            expr.func.id
        ):
            # Plain helper call: helper(...)
            return handle_helper_call(
                expr,
                module,
                builder,
                func,
                local_sym_tab,
                map_sym_tab,
                structs_sym_tab,
            )
        elif isinstance(expr.func, ast.Attribute):
            logger.info(f"Handling method call: {ast.dump(expr.func)}")
            if isinstance(expr.func.value, ast.Call) and isinstance(
                expr.func.value.func, ast.Name
            ):
                # Method on a call result: name(...).method(...)
                method_name = expr.func.attr
                if HelperHandlerRegistry.has_handler(method_name):
                    return handle_helper_call(
                        expr,
                        module,
                        builder,
                        func,
                        local_sym_tab,
                        map_sym_tab,
                        structs_sym_tab,
                    )
            elif isinstance(expr.func.value, ast.Name):
                # Method on a named object: only names in map_sym_tab qualify
                obj_name = expr.func.value.id
                method_name = expr.func.attr
                if obj_name in map_sym_tab:
                    if HelperHandlerRegistry.has_handler(method_name):
                        return handle_helper_call(
                            expr,
                            module,
                            builder,
                            func,
                            local_sym_tab,
                            map_sym_tab,
                            structs_sym_tab,
                        )
        # Any call not dispatched above falls through to the
        # "unsupported" path at the end of the function.
    elif isinstance(expr, ast.Attribute):
        return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
    elif isinstance(expr, ast.BinOp):
        from pythonbpf.binary_ops import handle_binary_op

        return handle_binary_op(expr, builder, None, local_sym_tab)
    elif isinstance(expr, ast.Compare):
        return _handle_compare(
            func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
        )
    elif isinstance(expr, ast.UnaryOp):
        return _handle_unary_op(
            func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
        )
    elif isinstance(expr, ast.BoolOp):
        return _handle_boolean_op(
            func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
        )
    logger.info("Unsupported expression evaluation")
    return None


def handle_expr(
    func,
    module,
    builder,
    expr,
    local_sym_tab,
    map_sym_tab,
    structs_sym_tab,
):
    """Handle expression statements in the function body.

    Only statements whose value is a call are evaluated. The result of the
    evaluation is discarded, and anything else is logged as unsupported.
    """
    logger.info(f"Handling expression: {ast.dump(expr)}")
    call = expr.value
    if isinstance(call, ast.Call):
        eval_expr(
            func,
            module,
            builder,
            call,
            local_sym_tab,
            map_sym_tab,
            structs_sym_tab,
        )
    else:
        logger.info("Unsupported expression type")