From 7a67041ea31534ba8fafa2b410c11365e0f2065e Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Mon, 13 Oct 2025 04:16:22 +0530 Subject: [PATCH] Move CallHandlerRegistry to expr/call_registry.py, annotate eval_expr --- pythonbpf/expr/__init__.py | 3 +- pythonbpf/expr/call_registry.py | 20 +++ pythonbpf/expr/expr_pass.py | 213 ++++++++++++++++---------------- 3 files changed, 130 insertions(+), 106 deletions(-) create mode 100644 pythonbpf/expr/call_registry.py diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py index 4e8f82d..3c403dd 100644 --- a/pythonbpf/expr/__init__.py +++ b/pythonbpf/expr/__init__.py @@ -1,6 +1,7 @@ -from .expr_pass import eval_expr, handle_expr, get_operand_value, CallHandlerRegistry +from .expr_pass import eval_expr, handle_expr, get_operand_value from .type_normalization import convert_to_bool, get_base_type_and_depth from .ir_ops import deref_to_depth +from .call_registry import CallHandlerRegistry __all__ = [ "eval_expr", diff --git a/pythonbpf/expr/call_registry.py b/pythonbpf/expr/call_registry.py new file mode 100644 index 0000000..858e23c --- /dev/null +++ b/pythonbpf/expr/call_registry.py @@ -0,0 +1,20 @@ +class CallHandlerRegistry: + """Registry for handling different types of calls (helpers, etc.)""" + + _handler = None + + @classmethod + def set_handler(cls, handler): + """Set the handler for unknown calls""" + cls._handler = handler + + @classmethod + def handle_call( + cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + ): + """Handle a call using the registered handler""" + if cls._handler is None: + return None + return cls._handler( + call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab + ) diff --git a/pythonbpf/expr/expr_pass.py b/pythonbpf/expr/expr_pass.py index e662984..8bbd524 100644 --- a/pythonbpf/expr/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -5,6 +5,7 @@ import logging from typing import Dict from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes +from .call_registry import CallHandlerRegistry from .type_normalization import ( convert_to_bool, handle_comparator, @@ -14,27 +15,106 @@ from .type_normalization import ( logger: Logger = logging.getLogger(__name__) +# ============================================================================ +# Leaf Handlers (No Recursive eval_expr calls) +# ============================================================================ -class CallHandlerRegistry: - """Registry for handling different types of calls (helpers, etc.)""" - _handler = None +def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): + """Handle ast.Name expressions.""" + if expr.id in local_sym_tab: + var = local_sym_tab[expr.id].var + val = builder.load(var) + return val, local_sym_tab[expr.id].ir_type + else: + logger.info(f"Undefined variable {expr.id}") + return None - @classmethod - def set_handler(cls, handler): - """Set the handler for unknown calls""" - cls._handler = handler - @classmethod - def handle_call( - cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab +def _handle_constant_expr(module, builder, expr: ast.Constant): + """Handle ast.Constant expressions.""" + if isinstance(expr.value, int) or isinstance(expr.value, bool): + return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) + elif isinstance(expr.value, str): + str_name = f".str.{id(expr)}" + str_bytes = expr.value.encode("utf-8") + b"\x00" + str_type = ir.ArrayType(ir.IntType(8), len(str_bytes)) + str_constant = ir.Constant(str_type, bytearray(str_bytes)) + + # Create global variable + global_str = ir.GlobalVariable(module, str_type, name=str_name) + global_str.linkage = "internal" + global_str.global_constant = True + global_str.initializer = str_constant + + str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) + return str_ptr, ir.PointerType(ir.IntType(8)) + else: + logger.error(f"Unsupported constant type {ast.dump(expr)}") + return None + + +def _handle_attribute_expr( + expr: ast.Attribute, + local_sym_tab: Dict, + structs_sym_tab: Dict, + builder: ir.IRBuilder, +): + """Handle ast.Attribute expressions for struct field access.""" + if isinstance(expr.value, ast.Name): + var_name = expr.value.id + attr_name = expr.attr + if var_name in local_sym_tab: + var_ptr, var_type, var_metadata = local_sym_tab[var_name] + logger.info(f"Loading attribute {attr_name} from variable {var_name}") + logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") + metadata = structs_sym_tab[var_metadata] + if attr_name in metadata.fields: + gep = metadata.gep(builder, var_ptr, attr_name) + val = builder.load(gep) + field_type = metadata.field_type(attr_name) + return val, field_type + return None + + +def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder): + """Handle deref function calls.""" + logger.info(f"Handling deref {ast.dump(expr)}") + if len(expr.args) != 1: + logger.info("deref takes exactly one argument") + return None + + arg = expr.args[0] + if ( + isinstance(arg, ast.Call) + and isinstance(arg.func, ast.Name) + and arg.func.id == "deref" ): - """Handle a call using the registered handler""" - if cls._handler is None: + logger.info("Multiple deref not supported") + return None + + if isinstance(arg, ast.Name): + if arg.id in local_sym_tab: + arg_ptr = local_sym_tab[arg.id].var + else: + logger.info(f"Undefined variable {arg.id}") return None - return cls._handler( - call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab - ) + else: + logger.info("Unsupported argument type for deref") + return None + + if arg_ptr is None: + logger.info("Failed to evaluate deref argument") + return None + + # Load the value from pointer + val = builder.load(arg_ptr) + return val, local_sym_tab[arg.id].ir_type + + +# ============================================================================ +# Binary Operations +# ============================================================================ def get_operand_value( @@ -139,96 +219,9 @@ def _handle_binary_op( return result, result.type -def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): - """Handle ast.Name expressions.""" - if expr.id in local_sym_tab: - var = local_sym_tab[expr.id].var - val = builder.load(var) - return val, local_sym_tab[expr.id].ir_type - else: - logger.info(f"Undefined variable {expr.id}") - return None - - -def _handle_constant_expr(module, builder, expr: ast.Constant): - """Handle ast.Constant expressions.""" - if isinstance(expr.value, int) or isinstance(expr.value, bool): - return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) - elif isinstance(expr.value, str): - str_name = f".str.{id(expr)}" - str_bytes = expr.value.encode("utf-8") + b"\x00" - str_type = ir.ArrayType(ir.IntType(8), len(str_bytes)) - str_constant = ir.Constant(str_type, bytearray(str_bytes)) - - # Create global variable - global_str = ir.GlobalVariable(module, str_type, name=str_name) - global_str.linkage = "internal" - global_str.global_constant = True - global_str.initializer = str_constant - - str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) - return str_ptr, ir.PointerType(ir.IntType(8)) - else: - logger.error(f"Unsupported constant type {ast.dump(expr)}") - return None - - -def _handle_attribute_expr( - expr: ast.Attribute, - local_sym_tab: Dict, - structs_sym_tab: Dict, - builder: ir.IRBuilder, -): - """Handle ast.Attribute expressions for struct field access.""" - if isinstance(expr.value, ast.Name): - var_name = expr.value.id - attr_name = expr.attr - if var_name in local_sym_tab: - var_ptr, var_type, var_metadata = local_sym_tab[var_name] - logger.info(f"Loading attribute {attr_name} from variable {var_name}") - logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") - metadata = structs_sym_tab[var_metadata] - if attr_name in metadata.fields: - gep = metadata.gep(builder, var_ptr, attr_name) - val = builder.load(gep) - field_type = metadata.field_type(attr_name) - return val, field_type - return None - - -def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder): - """Handle deref function calls.""" - logger.info(f"Handling deref {ast.dump(expr)}") - if len(expr.args) != 1: - logger.info("deref takes exactly one argument") - return None - - arg = expr.args[0] - if ( - isinstance(arg, ast.Call) - and isinstance(arg.func, ast.Name) - and arg.func.id == "deref" - ): - logger.info("Multiple deref not supported") - return None - - if isinstance(arg, ast.Name): - if arg.id in local_sym_tab: - arg_ptr = local_sym_tab[arg.id].var - else: - logger.info(f"Undefined variable {arg.id}") - return None - else: - logger.info("Unsupported argument type for deref") - return None - - if arg_ptr is None: - logger.info("Failed to evaluate deref argument") - return None - - # Load the value from pointer - val = builder.load(arg_ptr) - return val, local_sym_tab[arg.id].ir_type +# ============================================================================ +# Comparison and Unary Operations +# ============================================================================ def _handle_ctypes_call( @@ -341,6 +334,11 @@ def _handle_unary_op( return result, ir.IntType(64) +# ============================================================================ +# Boolean Operations +# ============================================================================ + + def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): """Handle `and` boolean operations.""" @@ -471,6 +469,11 @@ def _handle_boolean_op( return None +# ============================================================================ +# Expression Dispatcher +# ============================================================================ + + def eval_expr( func, module,