Add LocalSymbol dataclass

This commit is contained in:
Pragyansh Chaturvedi
2025-10-02 04:13:24 +05:30
parent 1a66887f48
commit 2fd2a46838
2 changed files with 39 additions and 15 deletions

View File

@ -1,6 +1,8 @@
from llvmlite import ir
import ast
import logging
from typing import Any
from dataclasses import dataclass
from .helper import HelperHandlerRegistry, handle_helper_call
from .type_deducer import ctypes_to_ir
@ -8,6 +10,14 @@ from .binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr
local_var_metadata: dict[str | Any, Any] = {}
logger = logging.getLogger(__name__)
@dataclass
class LocalSymbol:
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def get_probe_string(func_node):
@ -83,16 +93,19 @@ def handle_assign(
elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool):
if rval.value:
builder.store(ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name][0])
builder.store(ir.Constant(ir.IntType(1), 1),
local_sym_tab[var_name][0])
else:
builder.store(ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name][0])
builder.store(ir.Constant(ir.IntType(1), 0),
local_sym_tab[var_name][0])
print(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, int):
# Assume c_int64 for now
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
builder.store(
ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name][0]
ir.Constant(ir.IntType(64),
rval.value), local_sym_tab[var_name][0]
)
# local_sym_tab[var_name] = var
print(f"Assigned constant {rval.value} to {var_name}")
@ -107,7 +120,8 @@ def handle_assign(
global_str.linkage = "internal"
global_str.global_constant = True
global_str.initializer = str_const
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
str_ptr = builder.bitcast(
global_str, ir.PointerType(ir.IntType(8)))
builder.store(str_ptr, local_sym_tab[var_name][0])
print(f"Assigned string constant '{rval.value}' to {var_name}")
else:
@ -126,7 +140,8 @@ def handle_assign(
# var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8
builder.store(
ir.Constant(ir_type, rval.args[0].value), local_sym_tab[var_name][0]
ir.Constant(
ir_type, rval.args[0].value), local_sym_tab[var_name][0]
)
print(
f"Assigned {call_type} constant "
@ -172,7 +187,8 @@ def handle_assign(
ir_type = struct_info.ir_type
# var = builder.alloca(ir_type, name=var_name)
# Null init
builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name][0])
builder.store(ir.Constant(ir_type, None),
local_sym_tab[var_name][0])
local_var_metadata[var_name] = call_type
print(f"Assigned struct {call_type} to {var_name}")
# local_sym_tab[var_name] = var
@ -243,7 +259,8 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
print(f"Undefined variable {cond.id} in condition")
return None
elif isinstance(cond, ast.Compare):
lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0]
lhs = eval_expr(func, module, builder, cond.left,
local_sym_tab, map_sym_tab)[0]
if len(cond.ops) != 1 or len(cond.comparators) != 1:
print("Unsupported complex comparison")
return None
@ -296,7 +313,8 @@ def handle_if(
else:
else_block = None
cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab)
cond = handle_cond(func, module, builder, stmt.test,
local_sym_tab, map_sym_tab)
if else_block:
builder.cbranch(cond, then_block, else_block)
else:
@ -441,7 +459,8 @@ def allocate_mem(
ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} of type {call_type}")
print(
f"Pre-allocated variable {var_name} of type {call_type}")
elif HelperHandlerRegistry.has_handler(call_type):
# Assume return type is int64 for now
ir_type = ir.IntType(64)
@ -662,7 +681,8 @@ def infer_return_type(func_node: ast.FunctionDef):
if found_type is None:
found_type = t
elif found_type != t:
raise ValueError("Conflicting return types:" f"{found_type} vs {t}")
raise ValueError("Conflicting return types:" f"{
found_type} vs {t}")
return found_type or "None"
@ -699,7 +719,8 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
char = builder.load(src_ptr)
# Store character in target
dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
dst_ptr = builder.gep(
target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
builder.store(char, dst_ptr)
# Increment counter
@ -710,5 +731,6 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
# Ensure null termination
last_idx = ir.Constant(ir.IntType(32), array_length - 1)
null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
null_ptr = builder.gep(
target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
builder.store(ir.Constant(ir.IntType(8), 0), null_ptr)

View File

@ -85,7 +85,7 @@ def create_bpf_map(module, map_name, map_params):
def create_map_debug_info(module, map_global, map_name, map_params):
"""Generate debug information metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
"""Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
generator = DebugInfoGenerator(module)
uint_type = generator.get_uint32_type()
@ -158,7 +158,8 @@ def create_ringbuf_debug_info(module, map_global, map_name, map_params):
type_ptr = generator.create_pointer_type(type_array, 64)
type_member = generator.create_struct_member("type", type_ptr, 0)
max_entries_array = generator.create_array_type(int_type, map_params["max_entries"])
max_entries_array = generator.create_array_type(
int_type, map_params["max_entries"])
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
max_entries_member = generator.create_struct_member(
"max_entries", max_entries_ptr, 64
@ -166,7 +167,8 @@ def create_ringbuf_debug_info(module, map_global, map_name, map_params):
elements_arr = [type_member, max_entries_member]
struct_type = generator.create_struct_type(elements_arr, 128, is_distinct=True)
struct_type = generator.create_struct_type(
elements_arr, 128, is_distinct=True)
global_var = generator.create_global_var_debug_info(
map_name, struct_type, is_local=False