From f53ca3bd5b091910ca08f5832d7cf7c4e6ceb016 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Mon, 6 Oct 2025 04:43:04 +0530 Subject: [PATCH] Add ctypes in eval_expr --- pythonbpf/expr_pass.py | 57 +++++++++++++++++++++++++++ pythonbpf/functions/functions_pass.py | 5 +++ pythonbpf/type_deducer.py | 34 +++++++++------- 3 files changed, 81 insertions(+), 15 deletions(-) diff --git a/pythonbpf/expr_pass.py b/pythonbpf/expr_pass.py index 9ceb77f..bcdb018 100644 --- a/pythonbpf/expr_pass.py +++ b/pythonbpf/expr_pass.py @@ -4,6 +4,8 @@ from logging import Logger import logging from typing import Dict +from .type_deducer import ctypes_to_ir, is_ctypes + logger: Logger = logging.getLogger(__name__) @@ -88,6 +90,50 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde return val, local_sym_tab[arg.id].ir_type +def _handle_ctypes_call( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + """Handle ctypes type constructor calls.""" + if len(expr.args) != 1: + logger.info("ctypes constructor takes exactly one argument") + return None + + arg = expr.args[0] + val = eval_expr( + func, + module, + builder, + arg, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + if val is None: + logger.info("Failed to evaluate argument to ctypes constructor") + return None + call_type = expr.func.id + expected_type = ctypes_to_ir(call_type) + if expected_type is None: + logger.info(f"Unsupported ctypes type: {call_type}") + return None + if val[1] != expected_type: + # NOTE: We are only considering casting to and from int types for now + if isinstance(val[1], ir.IntType) and isinstance(expected_type, ir.IntType): + if val[1].width < expected_type.width: + val = (builder.sext(val[0], expected_type), expected_type) + else: + val = (builder.trunc(val[0], expected_type), expected_type) + else: + raise ValueError(f"Type mismatch: expected {expected_type}, got {val[1]}") + return val + + def eval_expr( func, module, @@ -106,6 +152,17 @@ def eval_expr( if isinstance(expr.func, ast.Name) and expr.func.id == "deref": return _handle_deref_call(expr, local_sym_tab, builder) + if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id): + return _handle_ctypes_call( + func, + module, + builder, + expr, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + # delayed import to avoid circular dependency from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 16554f5..8d4a559 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -359,6 +359,11 @@ def handle_return(builder, stmt, local_sym_tab, ret_type): return _handle_none_return(builder) elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id): return _handle_xdp_return(stmt, builder, ret_type) + elif True: + val = eval_expr(None, None, builder, stmt.value, local_sym_tab, {}, {}) + logger.info(f"Evaluated return expression to {val}") + builder.ret(val[0]) + return True elif ( isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) diff --git a/pythonbpf/type_deducer.py b/pythonbpf/type_deducer.py index 909d33c..9867cc6 100644 --- a/pythonbpf/type_deducer.py +++ b/pythonbpf/type_deducer.py @@ -1,24 +1,28 @@ from llvmlite import ir # TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull: +mapping = { + "c_int8": ir.IntType(8), + "c_uint8": ir.IntType(8), + "c_int16": ir.IntType(16), + "c_uint16": ir.IntType(16), + "c_int32": ir.IntType(32), + "c_uint32": ir.IntType(32), + "c_int64": ir.IntType(64), + "c_uint64": ir.IntType(64), + "c_float": ir.FloatType(), + "c_double": ir.DoubleType(), + "c_void_p": ir.IntType(64), + # Not so sure about this one + "str": ir.PointerType(ir.IntType(8)), +} def ctypes_to_ir(ctype: str): - mapping = { - "c_int8": ir.IntType(8), - "c_uint8": ir.IntType(8), - "c_int16": ir.IntType(16), - "c_uint16": ir.IntType(16), - "c_int32": ir.IntType(32), - "c_uint32": ir.IntType(32), - "c_int64": ir.IntType(64), - "c_uint64": ir.IntType(64), - "c_float": ir.FloatType(), - "c_double": ir.DoubleType(), - "c_void_p": ir.IntType(64), - # Not so sure about this one - "str": ir.PointerType(ir.IntType(8)), - } if ctype in mapping: return mapping[ctype] raise NotImplementedError(f"No mapping for {ctype}") + + +def is_ctypes(ctype: str) -> bool: + return ctype in mapping