mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Refactor allocate_mem
This commit is contained in:
@ -220,6 +220,203 @@ def process_stmt(
|
|||||||
return did_return
|
return did_return
|
||||||
|
|
||||||
|
|
||||||
|
def _is_helper_call(call_node):
|
||||||
|
"""Check if a call node is a BPF helper function call."""
|
||||||
|
if isinstance(call_node.func, ast.Name):
|
||||||
|
# Exclude print from requiring temps (handles f-strings differently)
|
||||||
|
func_name = call_node.func.id
|
||||||
|
return HelperHandlerRegistry.has_handler(func_name) and func_name != "print"
|
||||||
|
|
||||||
|
elif isinstance(call_node.func, ast.Attribute):
|
||||||
|
return HelperHandlerRegistry.has_handler(call_node.func.attr)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_if_allocation(
|
||||||
|
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||||
|
):
|
||||||
|
"""Recursively handle allocations in if/else branches."""
|
||||||
|
if stmt.body:
|
||||||
|
allocate_mem(
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
stmt.body,
|
||||||
|
func,
|
||||||
|
ret_type,
|
||||||
|
map_sym_tab,
|
||||||
|
local_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
if stmt.orelse:
|
||||||
|
allocate_mem(
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
stmt.orelse,
|
||||||
|
func,
|
||||||
|
ret_type,
|
||||||
|
map_sym_tab,
|
||||||
|
local_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
|
||||||
|
"""Handle memory allocation for assignment statements."""
|
||||||
|
|
||||||
|
# Validate assignment
|
||||||
|
if len(stmt.targets) != 1:
|
||||||
|
logger.warning("Multi-target assignment not supported, skipping allocation")
|
||||||
|
return
|
||||||
|
|
||||||
|
target = stmt.targets[0]
|
||||||
|
|
||||||
|
# Skip non-name targets (e.g., struct field assignments)
|
||||||
|
if isinstance(target, ast.Attribute):
|
||||||
|
logger.debug(f"Struct field assignment to {target.attr}, no allocation needed")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(target, ast.Name):
|
||||||
|
logger.warning(f"Unsupported assignment target type: {type(target).__name__}")
|
||||||
|
return
|
||||||
|
|
||||||
|
var_name = target.id
|
||||||
|
rval = stmt.value
|
||||||
|
|
||||||
|
# Skip if already allocated
|
||||||
|
if var_name in local_sym_tab:
|
||||||
|
logger.debug(f"Variable {var_name} already allocated, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Determine type and allocate based on rval
|
||||||
|
if isinstance(rval, ast.Call):
|
||||||
|
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)
|
||||||
|
elif isinstance(rval, ast.Constant):
|
||||||
|
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
|
||||||
|
elif isinstance(rval, ast.BinOp):
|
||||||
|
_allocate_for_binop(builder, var_name, local_sym_tab)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Unsupported assignment value type for {var_name}: {type(rval).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab):
|
||||||
|
"""Allocate memory for variable assigned from a call."""
|
||||||
|
|
||||||
|
if isinstance(rval.func, ast.Name):
|
||||||
|
call_type = rval.func.id
|
||||||
|
|
||||||
|
# C type constructors
|
||||||
|
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
|
||||||
|
ir_type = ctypes_to_ir(call_type)
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = ir_type.width // 8
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} as {call_type}")
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
elif HelperHandlerRegistry.has_handler(call_type):
|
||||||
|
ir_type = ir.IntType(64) # Assume i64 return type
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = 8
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} for helper {call_type}")
|
||||||
|
|
||||||
|
# Deref function
|
||||||
|
elif call_type == "deref":
|
||||||
|
ir_type = ir.IntType(64) # Assume i64 return type
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = 8
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} for deref")
|
||||||
|
|
||||||
|
# Struct constructors
|
||||||
|
elif call_type in structs_sym_tab:
|
||||||
|
struct_info = structs_sym_tab[call_type]
|
||||||
|
var = builder.alloca(struct_info.ir_type, name=var_name)
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unknown call type for allocation: {call_type}")
|
||||||
|
|
||||||
|
elif isinstance(rval.func, ast.Attribute):
|
||||||
|
# Map method calls - need double allocation for ptr handling
|
||||||
|
_allocate_for_map_method(builder, var_name, local_sym_tab)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported call function type for {var_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_for_map_method(builder, var_name, local_sym_tab):
|
||||||
|
"""Allocate memory for variable assigned from map method (double alloc)."""
|
||||||
|
|
||||||
|
# Main variable (pointer to pointer)
|
||||||
|
ir_type = ir.PointerType(ir.IntType(64))
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
|
||||||
|
# Temporary variable for computed values
|
||||||
|
tmp_ir_type = ir.IntType(64)
|
||||||
|
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
|
||||||
|
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
|
||||||
|
|
||||||
|
logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method")
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_for_constant(builder, var_name, rval, local_sym_tab):
|
||||||
|
"""Allocate memory for variable assigned from a constant."""
|
||||||
|
|
||||||
|
if isinstance(rval.value, bool):
|
||||||
|
ir_type = ir.IntType(1)
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = 1
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} as bool")
|
||||||
|
|
||||||
|
elif isinstance(rval.value, int):
|
||||||
|
ir_type = ir.IntType(64)
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = 8
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} as i64")
|
||||||
|
|
||||||
|
elif isinstance(rval.value, str):
|
||||||
|
ir_type = ir.PointerType(ir.IntType(8))
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = 8
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} as string")
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Unsupported constant type for {var_name}: {type(rval.value).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_for_binop(builder, var_name, local_sym_tab):
|
||||||
|
"""Allocate memory for variable assigned from a binary operation."""
|
||||||
|
ir_type = ir.IntType(64) # Assume i64 result
|
||||||
|
var = builder.alloca(ir_type, name=var_name)
|
||||||
|
var.align = 8
|
||||||
|
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
||||||
|
logger.info(f"Pre-allocated {var_name} for binop result")
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_temp_pool(builder, max_temps, local_sym_tab):
|
||||||
|
"""Allocate the temporary scratch space pool for helper arguments."""
|
||||||
|
if max_temps == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Allocating temp pool of {max_temps} variables")
|
||||||
|
for i in range(max_temps):
|
||||||
|
temp_name = f"__helper_temp_{i}"
|
||||||
|
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
|
||||||
|
temp_var.align = 8
|
||||||
|
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))
|
||||||
|
|
||||||
|
|
||||||
def count_temps_in_call(call_node, local_sym_tab):
|
def count_temps_in_call(call_node, local_sym_tab):
|
||||||
"""Count the number of temporary variables needed for a function call."""
|
"""Count the number of temporary variables needed for a function call."""
|
||||||
|
|
||||||
@ -255,7 +452,6 @@ def count_temps_in_call(call_node, local_sym_tab):
|
|||||||
def allocate_mem(
|
def allocate_mem(
|
||||||
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||||
):
|
):
|
||||||
double_alloc = False
|
|
||||||
max_temps_needed = 0
|
max_temps_needed = 0
|
||||||
|
|
||||||
def update_max_temps_for_stmt(stmt):
|
def update_max_temps_for_stmt(stmt):
|
||||||
@ -276,139 +472,23 @@ def allocate_mem(
|
|||||||
|
|
||||||
for stmt in body:
|
for stmt in body:
|
||||||
update_max_temps_for_stmt(stmt)
|
update_max_temps_for_stmt(stmt)
|
||||||
has_metadata = False
|
|
||||||
|
# Handle allocations
|
||||||
if isinstance(stmt, ast.If):
|
if isinstance(stmt, ast.If):
|
||||||
if stmt.body:
|
_handle_if_allocation(
|
||||||
local_sym_tab = allocate_mem(
|
module,
|
||||||
module,
|
builder,
|
||||||
builder,
|
stmt,
|
||||||
stmt.body,
|
func,
|
||||||
func,
|
ret_type,
|
||||||
ret_type,
|
map_sym_tab,
|
||||||
map_sym_tab,
|
local_sym_tab,
|
||||||
local_sym_tab,
|
structs_sym_tab,
|
||||||
structs_sym_tab,
|
)
|
||||||
)
|
|
||||||
if stmt.orelse:
|
|
||||||
local_sym_tab = allocate_mem(
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
stmt.orelse,
|
|
||||||
func,
|
|
||||||
ret_type,
|
|
||||||
map_sym_tab,
|
|
||||||
local_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
elif isinstance(stmt, ast.Assign):
|
elif isinstance(stmt, ast.Assign):
|
||||||
if len(stmt.targets) != 1:
|
_handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
|
||||||
logger.info("Unsupported multiassignment")
|
|
||||||
continue
|
|
||||||
target = stmt.targets[0]
|
|
||||||
if not isinstance(target, ast.Name) and not isinstance(
|
|
||||||
target, ast.Attribute
|
|
||||||
):
|
|
||||||
logger.info("Unsupported assignment target")
|
|
||||||
continue
|
|
||||||
if isinstance(target, ast.Attribute):
|
|
||||||
logger.info(
|
|
||||||
f"Struct field {target.attr} assignment, will be handled later"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
var_name = target.id
|
|
||||||
rval = stmt.value
|
|
||||||
if var_name in local_sym_tab:
|
|
||||||
logger.info(f"Variable {var_name} already allocated")
|
|
||||||
continue
|
|
||||||
if isinstance(rval, ast.Call):
|
|
||||||
if isinstance(rval.func, ast.Name):
|
|
||||||
call_type = rval.func.id
|
|
||||||
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
|
|
||||||
ir_type = ctypes_to_ir(call_type)
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = ir_type.width // 8
|
|
||||||
logger.info(
|
|
||||||
f"Pre-allocated variable {var_name} of type {call_type}"
|
|
||||||
)
|
|
||||||
elif HelperHandlerRegistry.has_handler(call_type):
|
|
||||||
# Assume return type is int64 for now
|
|
||||||
ir_type = ir.IntType(64)
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = ir_type.width // 8
|
|
||||||
logger.info(f"Pre-allocated variable {var_name} for helper")
|
|
||||||
elif call_type == "deref" and len(rval.args) == 1:
|
|
||||||
# Assume return type is int64 for now
|
|
||||||
ir_type = ir.IntType(64)
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = ir_type.width // 8
|
|
||||||
logger.info(f"Pre-allocated variable {var_name} for deref")
|
|
||||||
elif call_type in structs_sym_tab:
|
|
||||||
struct_info = structs_sym_tab[call_type]
|
|
||||||
ir_type = struct_info.ir_type
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
has_metadata = True
|
|
||||||
logger.info(
|
|
||||||
f"Pre-allocated variable {var_name} for struct {call_type}"
|
|
||||||
)
|
|
||||||
elif isinstance(rval.func, ast.Attribute):
|
|
||||||
# Map method call
|
|
||||||
ir_type = ir.PointerType(ir.IntType(64))
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
|
|
||||||
# declare an intermediate ptr type for map lookup
|
_allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
|
||||||
tmp_ir_type = ir.IntType(64)
|
|
||||||
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
|
|
||||||
double_alloc = True
|
|
||||||
# var.align = ir_type.width // 8
|
|
||||||
logger.info(
|
|
||||||
f"Pre-allocated variable {var_name} and {var_name}_tmp for map"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("Unsupported assignment call function type")
|
|
||||||
continue
|
|
||||||
elif isinstance(rval, ast.Constant):
|
|
||||||
if isinstance(rval.value, bool):
|
|
||||||
ir_type = ir.IntType(1)
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = 1
|
|
||||||
logger.info(f"Pre-allocated variable {var_name} of type c_bool")
|
|
||||||
elif isinstance(rval.value, int):
|
|
||||||
# Assume c_int64 for now
|
|
||||||
ir_type = ir.IntType(64)
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = ir_type.width // 8
|
|
||||||
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
|
|
||||||
elif isinstance(rval.value, str):
|
|
||||||
ir_type = ir.PointerType(ir.IntType(8))
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = 8
|
|
||||||
logger.info(f"Pre-allocated variable {var_name} of type string")
|
|
||||||
else:
|
|
||||||
logger.info("Unsupported constant type")
|
|
||||||
continue
|
|
||||||
elif isinstance(rval, ast.BinOp):
|
|
||||||
# Assume c_int64 for now
|
|
||||||
ir_type = ir.IntType(64)
|
|
||||||
var = builder.alloca(ir_type, name=var_name)
|
|
||||||
var.align = ir_type.width // 8
|
|
||||||
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
|
|
||||||
else:
|
|
||||||
logger.info("Unsupported assignment value type")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if has_metadata:
|
|
||||||
local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type)
|
|
||||||
else:
|
|
||||||
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
|
|
||||||
|
|
||||||
if double_alloc:
|
|
||||||
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
|
|
||||||
|
|
||||||
logger.info(f"Temporary scratch space needed for calls: {max_temps_needed}")
|
|
||||||
for i in range(max_temps_needed):
|
|
||||||
temp_var = builder.alloca(ir.IntType(64), name=f"__helper_temp_{i}")
|
|
||||||
temp_var.align = 8
|
|
||||||
local_sym_tab[f"__helper_temp_{i}"] = LocalSymbol(temp_var, ir.IntType(64))
|
|
||||||
|
|
||||||
return local_sym_tab
|
return local_sym_tab
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user