fix broken IR generation logic for globals

This commit is contained in:
2025-10-03 22:55:40 +05:30
parent c3a512d5cf
commit ab1c4223d5
2 changed files with 65 additions and 28 deletions

View File

@ -1,8 +1,6 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
from llvmlite import ir
import ast
from logging import Logger from logging import Logger
import logging import logging
from .type_deducer import ctypes_to_ir from .type_deducer import ctypes_to_ir
@ -11,11 +9,41 @@ logger: Logger = logging.getLogger(__name__)
def emit_global(module: ir.Module, node, name): def emit_global(module: ir.Module, node, name):
print("global", node.returns.id) logger.info(f"global identifier {name} processing")
# TODO: the part below is LLM-generated; check its logic.
# deduce LLVM type from the annotated return
if not isinstance(node.returns, ast.Name):
raise ValueError(f"Unsupported return annotation {ast.dump(node.returns)}")
ty = ctypes_to_ir(node.returns.id) ty = ctypes_to_ir(node.returns.id)
# extract the return expression
ret_stmt = node.body[0]
if not isinstance(ret_stmt, ast.Return) or ret_stmt.value is None:
raise ValueError(f"Global '{name}' has no valid return")
init_val = ret_stmt.value
# simple constant like "return 0"
if isinstance(init_val, ast.Constant):
llvm_init = ir.Constant(ty, init_val.value)
# variable reference like "return SOME_CONST"
elif isinstance(init_val, ast.Name):
# symbol resolution is not implemented yet, so name references are rejected for now
raise ValueError(f"Name reference {init_val.id} not yet supported")
# constructor call like "return c_int64(0)" or dataclass(...)
elif isinstance(init_val, ast.Call):
if len(init_val.args) == 1 and isinstance(init_val.args[0], ast.Constant):
llvm_init = ir.Constant(ty, init_val.args[0].value)
else:
raise ValueError(f"Complex constructor not supported: {ast.dump(init_val)}")
else:
raise ValueError(f"Unsupported return expr {ast.dump(init_val)}")
gvar = ir.GlobalVariable(module, ty, name=name) gvar = ir.GlobalVariable(module, ty, name=name)
gvar.initializer = ir.Constant(ty, initial_value) gvar.initializer = llvm_init
gvar.align = 8 gvar.align = 8
gvar.linkage = "dso_local" gvar.linkage = "dso_local"
gvar.global_constant = False gvar.global_constant = False
@ -24,11 +52,11 @@ def emit_global(module: ir.Module, node, name):
def globals_processing(tree, module): def globals_processing(tree, module):
"""Process stuff decorated with @bpf and @bpfglobal except license and return the section name""" """Process stuff decorated with @bpf and @bpfglobal except license and return the section name"""
global_sym_tab = [] globals_sym_tab = []
for node in tree.body: for node in tree.body:
# Skip non-assignment and non-function nodes # Skip non-assignment and non-function nodes
if not (isinstance(node, (ast.FunctionDef, ast.AnnAssign, ast.Assign))): if not (isinstance(node, ast.FunctionDef)):
continue continue
# Get the name based on node type # Get the name based on node type
@ -38,33 +66,31 @@ def globals_processing(tree, module):
continue continue
# Check for duplicate names # Check for duplicate names
if name in global_sym_tab: if name in globals_sym_tab:
raise SyntaxError(f"ERROR: Global name '{name}' previously defined") raise SyntaxError(f"ERROR: Global name '{name}' previously defined")
else: else:
global_sym_tab.append(name) globals_sym_tab.append(name)
# Process decorated functions
if isinstance(node, ast.FunctionDef) and node.name != "LICENSE": if isinstance(node, ast.FunctionDef) and node.name != "LICENSE":
# Check decorators
decorators = [ decorators = [
dec.id for dec in node.decorator_list if isinstance(dec, ast.Name) dec.id for dec in node.decorator_list if isinstance(dec, ast.Name)
] ]
if "bpf" in decorators and "bpfglobal" in decorators: if "bpf" in decorators and "bpfglobal" in decorators:
if ( if (
len(node.body) == 1 len(node.body) == 1
and isinstance(node.body[0], ast.Return) and isinstance(node.body[0], ast.Return)
and node.body[0].value is not None and node.body[0].value is not None
and isinstance(node.body[0].value, (ast.Constant, ast.Name)) and isinstance(
node.body[0].value, (ast.Constant, ast.Name, ast.Call)
)
): ):
emit_global(module, node, name) emit_global(module, node, name)
return node.name
else: else:
logger.info(f"Invalid global expression for '{node.name}'") raise SyntaxError(f"ERROR: Invalid syntax for {name} global")
return None
return None return None
def emit_llvm_compiler_used(module: ir.Module, names: list[str]): def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
""" """
Emit the @llvm.compiler.used global given a list of function/global names. Emit the @llvm.compiler.used global given a list of function/global names.
@ -94,12 +120,12 @@ def globals_list_creation(tree, module: ir.Module):
if isinstance(node, ast.FunctionDef): if isinstance(node, ast.FunctionDef):
for dec in node.decorator_list: for dec in node.decorator_list:
if ( if (
isinstance(dec, ast.Call) isinstance(dec, ast.Call)
and isinstance(dec.func, ast.Name) and isinstance(dec.func, ast.Name)
and dec.func.id == "section" and dec.func.id == "section"
and len(dec.args) == 1 and len(dec.args) == 1
and isinstance(dec.args[0], ast.Constant) and isinstance(dec.args[0], ast.Constant)
and isinstance(dec.args[0].value, str) and isinstance(dec.args[0].value, str)
): ):
collected.append(node.name) collected.append(node.name)

View File

@ -1,17 +1,28 @@
import logging import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64, c_int32
@bpf @bpf
@bpfglobal @bpfglobal
def somevalue() -> c_int64: def somevalue() -> c_int32:
return c_int32(0)
@bpf
@bpfglobal
def somevalue2() -> c_int64:
return c_int64(0) return c_int64(0)
@bpf @bpf
@section("sometag1") @bpfglobal
def somevalue1() -> c_int32:
return c_int32(0)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def sometag(ctx: c_void_p) -> c_int64: def sometag(ctx: c_void_p) -> c_int64:
return c_int64(0) print("test")
return c_int64(1)
@bpf @bpf
@bpfglobal @bpfglobal