Allow map-based helpers to be used as helper args / within binops which are helper args

This commit is contained in:
Pragyansh Chaturvedi
2025-10-12 07:57:55 +05:30
parent d66e6a6aff
commit 2cf68f6473
4 changed files with 56 additions and 18 deletions

View File

@ -388,14 +388,18 @@ def process_stmt(
return did_return
def count_temps_in_call(call_node):
def count_temps_in_call(call_node, local_sym_tab):
"""Count the number of temporary variables needed for a function call."""
count = 0
is_helper = False
# NOTE: We exclude print calls for now
if isinstance(call_node.func, ast.Name):
if HelperHandlerRegistry.has_handler(call_node.func.id):
if (
HelperHandlerRegistry.has_handler(call_node.func.id)
and call_node.func.id != "print"
):
is_helper = True
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
@ -405,10 +409,11 @@ def count_temps_in_call(call_node):
return 0
for arg in call_node.args:
if (
isinstance(arg, ast.BinOp)
or isinstance(arg, ast.Constant)
or isinstance(arg, ast.UnaryOp)
# NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab
if not isinstance(arg, ast.Name) and not (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
):
count += 1
@ -423,11 +428,19 @@ def allocate_mem(
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
temps_needed = 0
if isinstance(stmt, ast.If):
for s in stmt.body:
update_max_temps_for_stmt(s)
for s in stmt.orelse:
update_max_temps_for_stmt(s)
return
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
temps_needed = count_temps_in_call(node)
max_temps_needed = max(max_temps_needed, temps_needed)
temps_needed += count_temps_in_call(node, local_sym_tab)
max_temps_needed = max(max_temps_needed, temps_needed)
for stmt in body:
update_max_temps_for_stmt(stmt)
@ -460,9 +473,16 @@ def allocate_mem(
logger.info("Unsupported multiassignment")
continue
target = stmt.targets[0]
if not isinstance(target, ast.Name):
if not isinstance(target, ast.Name) and not isinstance(
target, ast.Attribute
):
logger.info("Unsupported assignment target")
continue
if isinstance(target, ast.Attribute):
logger.info(
f"Struct field {target.attr} assignment, will be handled later"
)
continue
var_name = target.id
rval = stmt.value
if var_name in local_sym_tab: