Change ScratchPoolManager to use typed scratch space

This commit is contained in:
Pragyansh Chaturvedi
2025-11-04 14:16:44 +05:30
parent 123a92af1d
commit 963e2a8171
3 changed files with 26 additions and 9 deletions

View File

@ -50,7 +50,7 @@ def count_temps_in_call(call_node, local_sym_tab):
func_name = call_node.func.attr func_name = call_node.func.attr
if not is_helper: if not is_helper:
return 0 return {} # No temps needed
for arg_idx in range(len(call_node.args)): for arg_idx in range(len(call_node.args)):
# NOTE: Count all non-name arguments # NOTE: Count all non-name arguments

View File

@ -50,7 +50,7 @@ class HelperHandlerRegistry:
def get_param_type(cls, helper_name, index): def get_param_type(cls, helper_name, index):
"""Get the type of a parameter of a helper function by the index""" """Get the type of a parameter of a helper function by the index"""
signature = cls.get_signature(helper_name) signature = cls.get_signature(helper_name)
if signature and 0 <= index < len(signature.arg_types): if signature and signature.arg_types and 0 <= index < len(signature.arg_types):
return signature.arg_types[index] return signature.arg_types[index]
return None return None

View File

@ -14,26 +14,43 @@ class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab""" """Manage the temporary helper variables in local_sym_tab"""
def __init__(self): def __init__(self):
self._counter = 0 self._counters = {}
@property @property
def counter(self): def counter(self):
return self._counter return sum(self._counter.values())
def reset(self): def reset(self):
self._counter = 0 self._counters.clear()
logger.debug("Scratch pool counter reset to 0") logger.debug("Scratch pool counter reset to 0")
def get_next_temp(self, local_sym_tab): def _get_type_name(self, ir_type):
temp_name = f"__helper_temp_{self._counter}" if isinstance(ir_type, ir.PointerType):
self._counter += 1 return "ptr"
elif isinstance(ir_type, ir.IntType):
return f"i{ir_type.width}"
elif isinstance(ir_type, ir.ArrayType):
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
else:
return str(ir_type).replace(" ", "")
def get_next_temp(self, local_sym_tab, expected_type=None):
# Default to i64 if no expected type provided
type_name = self._get_type_name(expected_type) if expected_type else "i64"
if type_name not in self._counters:
self._counters[type_name] = 0
counter = self._counters[type_name]
temp_name = f"__helper_temp_{type_name}_{counter}"
self._counters[type_name] += 1
if temp_name not in local_sym_tab: if temp_name not in local_sym_tab:
raise ValueError( raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. " f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}" f"Type: {type_name} Counter: {counter}"
) )
logger.debug(f"Using {temp_name} for type {type_name}")
return local_sym_tab[temp_name].var, temp_name return local_sym_tab[temp_name].var, temp_name