Add LocalSymbol dataclass

This commit is contained in:
Pragyansh Chaturvedi
2025-10-02 04:13:24 +05:30
parent 1a66887f48
commit 2fd2a46838
2 changed files with 39 additions and 15 deletions

View File

@ -1,6 +1,8 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
import logging
from typing import Any from typing import Any
from dataclasses import dataclass
from .helper import HelperHandlerRegistry, handle_helper_call from .helper import HelperHandlerRegistry, handle_helper_call
from .type_deducer import ctypes_to_ir from .type_deducer import ctypes_to_ir
@ -8,6 +10,14 @@ from .binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr from .expr_pass import eval_expr, handle_expr
local_var_metadata: dict[str | Any, Any] = {} local_var_metadata: dict[str | Any, Any] = {}
logger = logging.getLogger(__name__)
@dataclass
class LocalSymbol:
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def get_probe_string(func_node): def get_probe_string(func_node):
@ -83,16 +93,19 @@ def handle_assign(
elif isinstance(rval, ast.Constant): elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool): if isinstance(rval.value, bool):
if rval.value: if rval.value:
builder.store(ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name][0]) builder.store(ir.Constant(ir.IntType(1), 1),
local_sym_tab[var_name][0])
else: else:
builder.store(ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name][0]) builder.store(ir.Constant(ir.IntType(1), 0),
local_sym_tab[var_name][0])
print(f"Assigned constant {rval.value} to {var_name}") print(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, int): elif isinstance(rval.value, int):
# Assume c_int64 for now # Assume c_int64 for now
# var = builder.alloca(ir.IntType(64), name=var_name) # var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8 # var.align = 8
builder.store( builder.store(
ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name][0] ir.Constant(ir.IntType(64),
rval.value), local_sym_tab[var_name][0]
) )
# local_sym_tab[var_name] = var # local_sym_tab[var_name] = var
print(f"Assigned constant {rval.value} to {var_name}") print(f"Assigned constant {rval.value} to {var_name}")
@ -107,7 +120,8 @@ def handle_assign(
global_str.linkage = "internal" global_str.linkage = "internal"
global_str.global_constant = True global_str.global_constant = True
global_str.initializer = str_const global_str.initializer = str_const
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8))) str_ptr = builder.bitcast(
global_str, ir.PointerType(ir.IntType(8)))
builder.store(str_ptr, local_sym_tab[var_name][0]) builder.store(str_ptr, local_sym_tab[var_name][0])
print(f"Assigned string constant '{rval.value}' to {var_name}") print(f"Assigned string constant '{rval.value}' to {var_name}")
else: else:
@ -126,7 +140,8 @@ def handle_assign(
# var = builder.alloca(ir_type, name=var_name) # var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8 # var.align = ir_type.width // 8
builder.store( builder.store(
ir.Constant(ir_type, rval.args[0].value), local_sym_tab[var_name][0] ir.Constant(
ir_type, rval.args[0].value), local_sym_tab[var_name][0]
) )
print( print(
f"Assigned {call_type} constant " f"Assigned {call_type} constant "
@ -172,7 +187,8 @@ def handle_assign(
ir_type = struct_info.ir_type ir_type = struct_info.ir_type
# var = builder.alloca(ir_type, name=var_name) # var = builder.alloca(ir_type, name=var_name)
# Null init # Null init
builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name][0]) builder.store(ir.Constant(ir_type, None),
local_sym_tab[var_name][0])
local_var_metadata[var_name] = call_type local_var_metadata[var_name] = call_type
print(f"Assigned struct {call_type} to {var_name}") print(f"Assigned struct {call_type} to {var_name}")
# local_sym_tab[var_name] = var # local_sym_tab[var_name] = var
@ -243,7 +259,8 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
print(f"Undefined variable {cond.id} in condition") print(f"Undefined variable {cond.id} in condition")
return None return None
elif isinstance(cond, ast.Compare): elif isinstance(cond, ast.Compare):
lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0] lhs = eval_expr(func, module, builder, cond.left,
local_sym_tab, map_sym_tab)[0]
if len(cond.ops) != 1 or len(cond.comparators) != 1: if len(cond.ops) != 1 or len(cond.comparators) != 1:
print("Unsupported complex comparison") print("Unsupported complex comparison")
return None return None
@ -296,7 +313,8 @@ def handle_if(
else: else:
else_block = None else_block = None
cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab) cond = handle_cond(func, module, builder, stmt.test,
local_sym_tab, map_sym_tab)
if else_block: if else_block:
builder.cbranch(cond, then_block, else_block) builder.cbranch(cond, then_block, else_block)
else: else:
@ -441,7 +459,8 @@ def allocate_mem(
ir_type = ctypes_to_ir(call_type) ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} of type {call_type}") print(
f"Pre-allocated variable {var_name} of type {call_type}")
elif HelperHandlerRegistry.has_handler(call_type): elif HelperHandlerRegistry.has_handler(call_type):
# Assume return type is int64 for now # Assume return type is int64 for now
ir_type = ir.IntType(64) ir_type = ir.IntType(64)
@ -662,7 +681,8 @@ def infer_return_type(func_node: ast.FunctionDef):
if found_type is None: if found_type is None:
found_type = t found_type = t
elif found_type != t: elif found_type != t:
raise ValueError("Conflicting return types:" f"{found_type} vs {t}") raise ValueError("Conflicting return types:" f"{
found_type} vs {t}")
return found_type or "None" return found_type or "None"
@ -699,7 +719,8 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
char = builder.load(src_ptr) char = builder.load(src_ptr)
# Store character in target # Store character in target
dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx]) dst_ptr = builder.gep(
target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
builder.store(char, dst_ptr) builder.store(char, dst_ptr)
# Increment counter # Increment counter
@ -710,5 +731,6 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
# Ensure null termination # Ensure null termination
last_idx = ir.Constant(ir.IntType(32), array_length - 1) last_idx = ir.Constant(ir.IntType(32), array_length - 1)
null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx]) null_ptr = builder.gep(
target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
builder.store(ir.Constant(ir.IntType(8), 0), null_ptr) builder.store(ir.Constant(ir.IntType(8), 0), null_ptr)

View File

@ -85,7 +85,7 @@ def create_bpf_map(module, map_name, map_params):
def create_map_debug_info(module, map_global, map_name, map_params): def create_map_debug_info(module, map_global, map_name, map_params):
"""Generate debug information metadata for BPF maps HASH and PERF_EVENT_ARRAY""" """Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
generator = DebugInfoGenerator(module) generator = DebugInfoGenerator(module)
uint_type = generator.get_uint32_type() uint_type = generator.get_uint32_type()
@ -158,7 +158,8 @@ def create_ringbuf_debug_info(module, map_global, map_name, map_params):
type_ptr = generator.create_pointer_type(type_array, 64) type_ptr = generator.create_pointer_type(type_array, 64)
type_member = generator.create_struct_member("type", type_ptr, 0) type_member = generator.create_struct_member("type", type_ptr, 0)
max_entries_array = generator.create_array_type(int_type, map_params["max_entries"]) max_entries_array = generator.create_array_type(
int_type, map_params["max_entries"])
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64) max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
max_entries_member = generator.create_struct_member( max_entries_member = generator.create_struct_member(
"max_entries", max_entries_ptr, 64 "max_entries", max_entries_ptr, 64
@ -166,7 +167,8 @@ def create_ringbuf_debug_info(module, map_global, map_name, map_params):
elements_arr = [type_member, max_entries_member] elements_arr = [type_member, max_entries_member]
struct_type = generator.create_struct_type(elements_arr, 128, is_distinct=True) struct_type = generator.create_struct_type(
elements_arr, 128, is_distinct=True)
global_var = generator.create_global_var_debug_info( global_var = generator.create_global_var_debug_info(
map_name, struct_type, is_local=False map_name, struct_type, is_local=False