From 963e2a81718ca65571e6077763118c9ceab6dae2 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Tue, 4 Nov 2025 14:16:44 +0530 Subject: [PATCH] Change ScratchPoolManager to use typed scratch space --- pythonbpf/functions/functions_pass.py | 2 +- pythonbpf/helper/helper_registry.py | 2 +- pythonbpf/helper/helper_utils.py | 31 +++++++++++++++++++++------ 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index c652d3e..33342b7 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -50,7 +50,7 @@ def count_temps_in_call(call_node, local_sym_tab): func_name = call_node.func.attr if not is_helper: - return 0 + return {} # No temps needed for arg_idx in range(len(call_node.args)): # NOTE: Count all non-name arguments diff --git a/pythonbpf/helper/helper_registry.py b/pythonbpf/helper/helper_registry.py index dccb8c2..0e09d70 100644 --- a/pythonbpf/helper/helper_registry.py +++ b/pythonbpf/helper/helper_registry.py @@ -50,7 +50,7 @@ class HelperHandlerRegistry: def get_param_type(cls, helper_name, index): """Get the type of a parameter of a helper function by the index""" signature = cls.get_signature(helper_name) - if signature and 0 <= index < len(signature.arg_types): + if signature and signature.arg_types and 0 <= index < len(signature.arg_types): return signature.arg_types[index] return None diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 841698c..4ec901d 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -14,26 +14,43 @@ class ScratchPoolManager: """Manage the temporary helper variables in local_sym_tab""" def __init__(self): - self._counter = 0 + self._counters = {} @property def counter(self): - return self._counter + return sum(self._counter.values()) def reset(self): - self._counter = 0 + self._counters.clear() logger.debug("Scratch pool counter reset to 0") - def get_next_temp(self, local_sym_tab): - temp_name = f"__helper_temp_{self._counter}" - self._counter += 1 + def _get_type_name(self, ir_type): + if isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def get_next_temp(self, local_sym_tab, expected_type=None): + # Default to i64 if no expected type provided + type_name = self._get_type_name(expected_type) if expected_type else "i64" + if type_name not in self._counters: + self._counters[type_name] = 0 + + counter = self._counters[type_name] + temp_name = f"__helper_temp_{type_name}_{counter}" + self._counters[type_name] += 1 if temp_name not in local_sym_tab: raise ValueError( f"Scratch pool exhausted or inadequate: {temp_name}. " - f"Current counter: {self._counter}" + f"Type: {type_name} Counter: {counter}" ) + logger.debug(f"Using {temp_name} for type {type_name}") return local_sym_tab[temp_name].var, temp_name