mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2026-04-23 22:21:28 +00:00
Compare commits
7 Commits
c9bbe1ffd8
...
cd74e896cf
| Author | SHA1 | Date | |
|---|---|---|---|
| cd74e896cf | |||
| 207f714027 | |||
| 5dcf670f49 | |||
| 6bce29b90f | |||
| 321415fa28 | |||
| 8776d7607f | |||
| 8b7b1c08a5 |
@ -4,7 +4,11 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
|
from pythonbpf.helper import (
|
||||||
|
HelperHandlerRegistry,
|
||||||
|
handle_helper_call,
|
||||||
|
reset_scratch_pool,
|
||||||
|
)
|
||||||
from pythonbpf.type_deducer import ctypes_to_ir
|
from pythonbpf.type_deducer import ctypes_to_ir
|
||||||
from pythonbpf.binary_ops import handle_binary_op
|
from pythonbpf.binary_ops import handle_binary_op
|
||||||
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
|
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
|
||||||
@ -353,6 +357,7 @@ def process_stmt(
|
|||||||
ret_type=ir.IntType(64),
|
ret_type=ir.IntType(64),
|
||||||
):
|
):
|
||||||
logger.info(f"Processing statement: {ast.dump(stmt)}")
|
logger.info(f"Processing statement: {ast.dump(stmt)}")
|
||||||
|
reset_scratch_pool()
|
||||||
if isinstance(stmt, ast.Expr):
|
if isinstance(stmt, ast.Expr):
|
||||||
handle_expr(
|
handle_expr(
|
||||||
func,
|
func,
|
||||||
@ -383,11 +388,49 @@ def process_stmt(
|
|||||||
return did_return
|
return did_return
|
||||||
|
|
||||||
|
|
||||||
|
def count_temps_in_call(call_node):
|
||||||
|
"""Count the number of temporary variables needed for a function call."""
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
is_helper = False
|
||||||
|
|
||||||
|
if isinstance(call_node.func, ast.Name):
|
||||||
|
if HelperHandlerRegistry.has_handler(call_node.func.id):
|
||||||
|
is_helper = True
|
||||||
|
elif isinstance(call_node.func, ast.Attribute):
|
||||||
|
if HelperHandlerRegistry.has_handler(call_node.func.attr):
|
||||||
|
is_helper = True
|
||||||
|
|
||||||
|
if not is_helper:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
for arg in call_node.args:
|
||||||
|
if (
|
||||||
|
isinstance(arg, ast.BinOp)
|
||||||
|
or isinstance(arg, ast.Constant)
|
||||||
|
or isinstance(arg, ast.UnaryOp)
|
||||||
|
):
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
def allocate_mem(
|
def allocate_mem(
|
||||||
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||||
):
|
):
|
||||||
double_alloc = False
|
double_alloc = False
|
||||||
|
max_temps_needed = 0
|
||||||
|
|
||||||
|
def update_max_temps_for_stmt(stmt):
|
||||||
|
nonlocal max_temps_needed
|
||||||
|
|
||||||
|
for node in ast.walk(stmt):
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
temps_needed = count_temps_in_call(node)
|
||||||
|
max_temps_needed = max(max_temps_needed, temps_needed)
|
||||||
|
|
||||||
for stmt in body:
|
for stmt in body:
|
||||||
|
update_max_temps_for_stmt(stmt)
|
||||||
has_metadata = False
|
has_metadata = False
|
||||||
if isinstance(stmt, ast.If):
|
if isinstance(stmt, ast.If):
|
||||||
if stmt.body:
|
if stmt.body:
|
||||||
@ -508,6 +551,13 @@ def allocate_mem(
|
|||||||
|
|
||||||
if double_alloc:
|
if double_alloc:
|
||||||
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
|
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
|
||||||
|
|
||||||
|
logger.info(f"Temporary scratch space needed for calls: {max_temps_needed}")
|
||||||
|
for i in range(max_temps_needed):
|
||||||
|
temp_var = builder.alloca(ir.IntType(64), name=f"__helper_temp_{i}")
|
||||||
|
temp_var.align = 8
|
||||||
|
local_sym_tab[f"__helper_temp_{i}"] = LocalSymbol(temp_var, ir.IntType(64))
|
||||||
|
|
||||||
return local_sym_tab
|
return local_sym_tab
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from .helper_utils import HelperHandlerRegistry
|
from .helper_utils import HelperHandlerRegistry, reset_scratch_pool
|
||||||
from .bpf_helper_handler import handle_helper_call
|
from .bpf_helper_handler import handle_helper_call
|
||||||
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
|
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"HelperHandlerRegistry",
|
"HelperHandlerRegistry",
|
||||||
|
"reset_scratch_pool",
|
||||||
"handle_helper_call",
|
"handle_helper_call",
|
||||||
"ktime",
|
"ktime",
|
||||||
"pid",
|
"pid",
|
||||||
|
|||||||
@ -34,6 +34,41 @@ class HelperHandlerRegistry:
|
|||||||
return helper_name in cls._handlers
|
return helper_name in cls._handlers
|
||||||
|
|
||||||
|
|
||||||
|
class ScratchPoolManager:
|
||||||
|
"""Manage the temporary helper variables in local_sym_tab"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._counter = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def counter(self):
|
||||||
|
return self._counter
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._counter = 0
|
||||||
|
logger.debug("Scratch pool counter reset to 0")
|
||||||
|
|
||||||
|
def get_next_temp(self, local_sym_tab):
|
||||||
|
temp_name = f"__helper_temp_{self._counter}"
|
||||||
|
self._counter += 1
|
||||||
|
|
||||||
|
if temp_name not in local_sym_tab:
|
||||||
|
raise ValueError(
|
||||||
|
f"Scratch pool exhausted or inadequate: {temp_name}. "
|
||||||
|
f"Current counter: {self._counter}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return local_sym_tab[temp_name].var, temp_name
|
||||||
|
|
||||||
|
|
||||||
|
_temp_pool_manager = ScratchPoolManager() # Singleton instance
|
||||||
|
|
||||||
|
|
||||||
|
def reset_scratch_pool():
|
||||||
|
"""Reset the scratch pool counter"""
|
||||||
|
_temp_pool_manager.reset()
|
||||||
|
|
||||||
|
|
||||||
def get_var_ptr_from_name(var_name, local_sym_tab):
|
def get_var_ptr_from_name(var_name, local_sym_tab):
|
||||||
"""Get a pointer to a variable from the symbol table."""
|
"""Get a pointer to a variable from the symbol table."""
|
||||||
if local_sym_tab and var_name in local_sym_tab:
|
if local_sym_tab and var_name in local_sym_tab:
|
||||||
@ -41,13 +76,14 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
|
|||||||
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
|
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
|
||||||
|
|
||||||
|
|
||||||
def create_int_constant_ptr(value, builder, int_width=64):
|
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||||
"""Create a pointer to an integer constant."""
|
"""Create a pointer to an integer constant."""
|
||||||
|
|
||||||
# Default to 64-bit integer
|
# Default to 64-bit integer
|
||||||
int_type = ir.IntType(int_width)
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
||||||
ptr = builder.alloca(int_type)
|
logger.debug(f"Using temp variable '{temp_name}' for int constant {value}")
|
||||||
ptr.align = int_type.width // 8
|
const_val = ir.Constant(ir.IntType(int_width), value)
|
||||||
builder.store(ir.Constant(int_type, value), ptr)
|
builder.store(const_val, ptr)
|
||||||
return ptr
|
return ptr
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +93,26 @@ def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
|
|||||||
if isinstance(arg, ast.Name):
|
if isinstance(arg, ast.Name):
|
||||||
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
||||||
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
||||||
ptr = create_int_constant_ptr(arg.value, builder)
|
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
|
||||||
|
elif isinstance(arg, ast.BinOp):
|
||||||
|
# Evaluate the expression and store the result in a temp variable
|
||||||
|
val, _ = eval_expr(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
builder,
|
||||||
|
arg,
|
||||||
|
local_sym_tab,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if val is None:
|
||||||
|
raise ValueError("Failed to evaluate expression for helper arg.")
|
||||||
|
|
||||||
|
# NOTE: We assume the result is an int64 for now
|
||||||
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
||||||
|
logger.debug(f"Using temp variable '{temp_name}' for expression result")
|
||||||
|
builder.store(val, ptr)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Only simple variable names are supported as args in map helpers."
|
"Only simple variable names are supported as args in map helpers."
|
||||||
|
|||||||
40
tests/passing_tests/assign/struct_and_helper_binops.py
Normal file
40
tests/passing_tests/assign/struct_and_helper_binops.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from pythonbpf import bpf, section, bpfglobal, compile, struct
|
||||||
|
from ctypes import c_void_p, c_int64, c_uint64
|
||||||
|
from pythonbpf.helper import ktime
|
||||||
|
|
||||||
|
|
||||||
|
@bpf
|
||||||
|
@struct
|
||||||
|
class data_t:
|
||||||
|
pid: c_uint64
|
||||||
|
ts: c_uint64
|
||||||
|
|
||||||
|
|
||||||
|
@bpf
|
||||||
|
@section("tracepoint/syscalls/sys_enter_execve")
|
||||||
|
def hello_world(ctx: c_void_p) -> c_int64:
|
||||||
|
dat = data_t()
|
||||||
|
dat.pid = 123
|
||||||
|
dat.pid = dat.pid + 1
|
||||||
|
print(f"pid is {dat.pid}")
|
||||||
|
x = ktime() - 121
|
||||||
|
print(f"ktime is {x}")
|
||||||
|
x = 1
|
||||||
|
x = x + 1
|
||||||
|
print(f"x is {x}")
|
||||||
|
if x == 2:
|
||||||
|
jat = data_t()
|
||||||
|
jat.ts = 456
|
||||||
|
print(f"Hello, World!, ts is {jat.ts}")
|
||||||
|
else:
|
||||||
|
print("Goodbye, World!")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@bpf
|
||||||
|
@bpfglobal
|
||||||
|
def LICENSE() -> str:
|
||||||
|
return "GPL"
|
||||||
|
|
||||||
|
|
||||||
|
compile()
|
||||||
Reference in New Issue
Block a user