From 123a92af1d624e6256e693976585b592727f0688 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Tue, 4 Nov 2025 06:20:39 +0530 Subject: [PATCH] Change allocation pass to generate typed temp variables --- pythonbpf/allocation_pass.py | 30 ++++++++++++++++++++------- pythonbpf/functions/functions_pass.py | 15 ++++++++++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/pythonbpf/allocation_pass.py b/pythonbpf/allocation_pass.py index 49c787f..ae748dd 100644 --- a/pythonbpf/allocation_pass.py +++ b/pythonbpf/allocation_pass.py @@ -199,17 +199,33 @@ def _allocate_for_binop(builder, var_name, local_sym_tab): logger.info(f"Pre-allocated {var_name} for binop result") +def _get_type_name(ir_type): + """Get a string representation of an IR type.""" + if isinstance(ir_type, ir.IntType): + return f"i{ir_type.width}" + elif isinstance(ir_type, ir.PointerType): + return "ptr" + elif isinstance(ir_type, ir.ArrayType): + return f"[{ir_type.count}x{_get_type_name(ir_type.element)}]" + else: + return str(ir_type).replace(" ", "") + + def allocate_temp_pool(builder, max_temps, local_sym_tab): """Allocate the temporary scratch space pool for helper arguments.""" - if max_temps == 0: + if not max_temps: + logger.info("No temp pool allocation needed") return - logger.info(f"Allocating temp pool of {max_temps} variables") - for i in range(max_temps): - temp_name = f"__helper_temp_{i}" - temp_var = builder.alloca(ir.IntType(64), name=temp_name) - temp_var.align = 8 - local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64)) + for tmp_type, cnt in max_temps.items(): + type_name = _get_type_name(tmp_type) + logger.info(f"Allocating temp pool of {cnt} variables of type {type_name}") + for i in range(cnt): + temp_name = f"__helper_temp_{type_name}_{i}" + temp_var = builder.alloca(tmp_type, name=temp_name) + temp_var.align = _get_alignment(tmp_type) + local_sym_tab[temp_name] = LocalSymbol(temp_var, tmp_type) + logger.debug(f"Allocated temp variable: {temp_name}") def _allocate_for_name(builder, var_name, rval, local_sym_tab): diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 47d83b5..c652d3e 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -98,11 +98,15 @@ def handle_if_allocation( def allocate_mem( module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab ): - max_temps_needed = 0 + max_temps_needed = {} + + def merge_type_counts(count_dict): + nonlocal max_temps_needed + for typ, cnt in count_dict.items(): + max_temps_needed[typ] = max(max_temps_needed.get(typ, 0), cnt) def update_max_temps_for_stmt(stmt): nonlocal max_temps_needed - temps_needed = 0 if isinstance(stmt, ast.If): for s in stmt.body: @@ -111,10 +115,13 @@ def allocate_mem( update_max_temps_for_stmt(s) return + stmt_temps = {} for node in ast.walk(stmt): if isinstance(node, ast.Call): - temps_needed += count_temps_in_call(node, local_sym_tab) - max_temps_needed = max(max_temps_needed, temps_needed) + call_temps = count_temps_in_call(node, local_sym_tab) + for typ, cnt in call_temps.items(): + stmt_temps[typ] = stmt_temps.get(typ, 0) + cnt + merge_type_counts(stmt_temps) for stmt in body: update_max_temps_for_stmt(stmt)