Compare commits

...

7 Commits

4 changed files with 154 additions and 8 deletions

View File

@ -4,7 +4,11 @@ import logging
from typing import Any
from dataclasses import dataclass
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
from pythonbpf.helper import (
HelperHandlerRegistry,
handle_helper_call,
reset_scratch_pool,
)
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.binary_ops import handle_binary_op
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
@ -353,6 +357,7 @@ def process_stmt(
ret_type=ir.IntType(64),
):
logger.info(f"Processing statement: {ast.dump(stmt)}")
reset_scratch_pool()
if isinstance(stmt, ast.Expr):
handle_expr(
func,
@ -383,11 +388,49 @@ def process_stmt(
return did_return
def count_temps_in_call(call_node):
"""Count the number of temporary variables needed for a function call."""
count = 0
is_helper = False
if isinstance(call_node.func, ast.Name):
if HelperHandlerRegistry.has_handler(call_node.func.id):
is_helper = True
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True
if not is_helper:
return 0
for arg in call_node.args:
if (
isinstance(arg, ast.BinOp)
or isinstance(arg, ast.Constant)
or isinstance(arg, ast.UnaryOp)
):
count += 1
return count
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
double_alloc = False
max_temps_needed = 0
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
temps_needed = count_temps_in_call(node)
max_temps_needed = max(max_temps_needed, temps_needed)
for stmt in body:
update_max_temps_for_stmt(stmt)
has_metadata = False
if isinstance(stmt, ast.If):
if stmt.body:
@ -508,6 +551,13 @@ def allocate_mem(
if double_alloc:
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
logger.info(f"Temporary scratch space needed for calls: {max_temps_needed}")
for i in range(max_temps_needed):
temp_var = builder.alloca(ir.IntType(64), name=f"__helper_temp_{i}")
temp_var.align = 8
local_sym_tab[f"__helper_temp_{i}"] = LocalSymbol(temp_var, ir.IntType(64))
return local_sym_tab

View File

@ -1,9 +1,10 @@
from .helper_utils import HelperHandlerRegistry
from .helper_utils import HelperHandlerRegistry, reset_scratch_pool
from .bpf_helper_handler import handle_helper_call
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
__all__ = [
"HelperHandlerRegistry",
"reset_scratch_pool",
"handle_helper_call",
"ktime",
"pid",

View File

@ -34,6 +34,41 @@ class HelperHandlerRegistry:
return helper_name in cls._handlers
class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab"""
def __init__(self):
self._counter = 0
@property
def counter(self):
return self._counter
def reset(self):
self._counter = 0
logger.debug("Scratch pool counter reset to 0")
def get_next_temp(self, local_sym_tab):
temp_name = f"__helper_temp_{self._counter}"
self._counter += 1
if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}"
)
return local_sym_tab[temp_name].var, temp_name
_temp_pool_manager = ScratchPoolManager() # Singleton instance
def reset_scratch_pool():
"""Reset the scratch pool counter"""
_temp_pool_manager.reset()
def get_var_ptr_from_name(var_name, local_sym_tab):
"""Get a pointer to a variable from the symbol table."""
if local_sym_tab and var_name in local_sym_tab:
@ -41,13 +76,14 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
def create_int_constant_ptr(value, builder, int_width=64):
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
"""Create a pointer to an integer constant."""
# Default to 64-bit integer
int_type = ir.IntType(int_width)
ptr = builder.alloca(int_type)
ptr.align = int_type.width // 8
builder.store(ir.Constant(int_type, value), ptr)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.debug(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(ir.IntType(int_width), value)
builder.store(const_val, ptr)
return ptr
@ -57,7 +93,26 @@ def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
if isinstance(arg, ast.Name):
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
ptr = create_int_constant_ptr(arg.value, builder)
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
elif isinstance(arg, ast.BinOp):
# Evaluate the expression and store the result in a temp variable
val, _ = eval_expr(
None,
None,
builder,
arg,
local_sym_tab,
None,
None,
)
if val is None:
raise ValueError("Failed to evaluate expression for helper arg.")
# NOTE: We assume the result is an int64 for now
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.debug(f"Using temp variable '{temp_name}' for expression result")
builder.store(val, ptr)
else:
raise NotImplementedError(
"Only simple variable names are supported as args in map helpers."

View File

@ -0,0 +1,40 @@
from pythonbpf import bpf, section, bpfglobal, compile, struct
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.helper import ktime
@bpf
@struct
class data_t:
pid: c_uint64
ts: c_uint64
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
dat = data_t()
dat.pid = 123
dat.pid = dat.pid + 1
print(f"pid is {dat.pid}")
x = ktime() - 121
print(f"ktime is {x}")
x = 1
x = x + 1
print(f"x is {x}")
if x == 2:
jat = data_t()
jat.ts = 456
print(f"Hello, World!, ts is {jat.ts}")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()