15 Commits

Author SHA1 Message Date
83d9f4b34f add failing test 2025-10-02 15:47:35 +05:30
e83215391a add ringbuf submit function. commit does not verify on input, but the mirror C code does not as well. 2025-10-02 06:31:35 +05:30
2a93a325ce add ringbuf reserve function 2025-10-02 06:07:17 +05:30
1a66887f48 move helper annotations to helpers module 2025-10-02 01:55:32 +05:30
23f3cbcea7 add type annotations 2025-10-02 01:43:05 +05:30
429f51437f Merge pull request #15 from pythonbpf/static-type-checks
Static type checks
2025-10-02 01:38:46 +05:30
c92272dd35 workflow update 2025-10-02 01:37:36 +05:30
8792740eb0 workflow update 2025-10-02 01:36:14 +05:30
cf5faaad7f remove pointless type annotation
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-02 01:27:03 +05:30
59b3d6514b fix ruff errors 2025-10-02 01:23:55 +05:30
3c956e671a add static type checking
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-02 01:11:54 +05:30
8650297866 make type checks viable 2025-10-02 00:51:23 +05:30
6831f11179 Fix fstrings in examples, add alternate map attr access 2025-10-02 00:22:59 +05:30
d4e8e1bf73 Fix unterminated fstrings 2025-10-02 00:14:51 +05:30
08f2b283c9 Merge pull request #10 from pythonbpf/helper-refactor
bpf_helper_handler refactor
2025-10-02 00:08:59 +05:30
23 changed files with 486 additions and 223 deletions

View File

@ -5,10 +5,7 @@ name: Format
on: on:
workflow_dispatch: workflow_dispatch:
pull_request:
push: push:
branches:
- master
jobs: jobs:
pre-commit: pre-commit:

View File

@ -41,16 +41,15 @@ repos:
- id: ruff - id: ruff
args: ["--fix", "--show-fixes"] args: ["--fix", "--show-fixes"]
- id: ruff-format - id: ruff-format
exclude: ^(docs) exclude: ^(tests/|examples/|docs/)
## Checking static types # Checking static types
#- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
# rev: "v1.10.0" rev: "v1.10.0"
# hooks: hooks:
# - id: mypy - id: mypy
# files: "setup.py" exclude: ^(tests/|examples/)
# args: [] additional_dependencies: [types-setuptools]
# additional_dependencies: [types-setuptools]
# Changes tabs to spaces # Changes tabs to spaces
- repo: https://github.com/Lucas-C/pre-commit-hooks - repo: https://github.com/Lucas-C/pre-commit-hooks

View File

@ -60,12 +60,13 @@ pip install pythonbpf pylibbpf
```python ```python
import time import time
from pythonbpf import bpf, map, section, bpfglobal, BPF from pythonbpf import bpf, map, section, bpfglobal, BPF
from pythonbpf.helpers import pid from pythonbpf.helper import pid
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from pylibbpf import * from pylibbpf import *
from ctypes import c_void_p, c_int64, c_uint64, c_int32 from ctypes import c_void_p, c_int64, c_uint64, c_int32
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# This program attaches an eBPF tracepoint to sys_enter_clone, # This program attaches an eBPF tracepoint to sys_enter_clone,
# counts per-PID clone syscalls, stores them in a hash map, # counts per-PID clone syscalls, stores them in a hash map,
# and then plots the distribution as a histogram using matplotlib. # and then plots the distribution as a histogram using matplotlib.
@ -76,6 +77,7 @@ import matplotlib.pyplot as plt
def hist() -> HashMap: def hist() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=4096) return HashMap(key=c_int32, value=c_uint64, max_entries=4096)
@bpf @bpf
@section("tracepoint/syscalls/sys_enter_clone") @section("tracepoint/syscalls/sys_enter_clone")
def hello(ctx: c_void_p) -> c_int64: def hello(ctx: c_void_p) -> c_int64:

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import ktime from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64, c_uint64 from ctypes import c_void_p, c_int64, c_uint64

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import ktime from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int32, c_uint64 from ctypes import c_void_p, c_int32, c_uint64

View File

@ -10,7 +10,7 @@
"import time\n", "import time\n",
"\n", "\n",
"from pythonbpf import bpf, map, section, bpfglobal, BPF\n", "from pythonbpf import bpf, map, section, bpfglobal, BPF\n",
"from pythonbpf.helpers import pid\n", "from pythonbpf.helper import pid\n",
"from pythonbpf.maps import HashMap\n", "from pythonbpf.maps import HashMap\n",
"from pylibbpf import *\n", "from pylibbpf import *\n",
"from ctypes import c_void_p, c_int64, c_uint64, c_int32\n", "from ctypes import c_void_p, c_int64, c_uint64, c_int32\n",

View File

@ -1,7 +1,7 @@
import time import time
from pythonbpf import bpf, map, section, bpfglobal, BPF from pythonbpf import bpf, map, section, bpfglobal, BPF
from pythonbpf.helpers import pid from pythonbpf.helper import pid
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from pylibbpf import BpfMap from pylibbpf import BpfMap
from ctypes import c_void_p, c_int64, c_uint64, c_int32 from ctypes import c_void_p, c_int64, c_uint64, c_int32

View File

@ -22,9 +22,5 @@ def LICENSE() -> str:
b = BPF() b = BPF()
b.load_and_attach() b.load_and_attach()
if b.is_loaded() and b.is_attached():
print("Successfully loaded and attached")
else:
print("Could not load successfully")
# Now cat /sys/kernel/debug/tracing/trace_pipe to see results of the execve syscall. # Now cat /sys/kernel/debug/tracing/trace_pipe to see results of the execve syscall.

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, compile from pythonbpf import bpf, map, struct, section, bpfglobal, compile
from pythonbpf.helpers import ktime, pid from pythonbpf.helper import ktime, pid
from pythonbpf.maps import PerfEventArray from pythonbpf.maps import PerfEventArray
from ctypes import c_void_p, c_int32, c_uint64 from ctypes import c_void_p, c_int32, c_uint64
@ -27,10 +27,7 @@ def hello(ctx: c_void_p) -> c_int32:
dataobj.pid = pid() dataobj.pid = pid()
dataobj.ts = ktime() dataobj.ts = ktime()
# dataobj.comm = strobj # dataobj.comm = strobj
print( print(f"clone called at {dataobj.ts} by pid" f"{dataobj.pid}, comm {strobj}")
f"clone called at {dataobj.ts} by pid {
dataobj.pid}, comm {strobj} at time {ts}"
)
events.output(dataobj) events.output(dataobj)
return c_int32(0) return c_int32(0)

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import ktime from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64, c_uint64 from ctypes import c_void_p, c_int64, c_uint64

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import XDP_PASS from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64

View File

@ -141,7 +141,7 @@ def compile() -> bool:
success = True success = True
success = compile_to_ir(str(caller_file), str(ll_file)) and success success = compile_to_ir(str(caller_file), str(ll_file)) and success
success = ( success = bool(
subprocess.run( subprocess.run(
[ [
"llc", "llc",

View File

@ -1,13 +1,13 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
from typing import Any
from .helper import HelperHandlerRegistry, handle_helper_call from .helper import HelperHandlerRegistry, handle_helper_call
from .type_deducer import ctypes_to_ir from .type_deducer import ctypes_to_ir
from .binary_ops import handle_binary_op from .binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr from .expr_pass import eval_expr, handle_expr
local_var_metadata = {} local_var_metadata: dict[str | Any, Any] = {}
def get_probe_string(func_node): def get_probe_string(func_node):
@ -83,19 +83,16 @@ def handle_assign(
elif isinstance(rval, ast.Constant): elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool): if isinstance(rval.value, bool):
if rval.value: if rval.value:
builder.store(ir.Constant(ir.IntType(1), 1), builder.store(ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name][0])
local_sym_tab[var_name][0])
else: else:
builder.store(ir.Constant(ir.IntType(1), 0), builder.store(ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name][0])
local_sym_tab[var_name][0])
print(f"Assigned constant {rval.value} to {var_name}") print(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, int): elif isinstance(rval.value, int):
# Assume c_int64 for now # Assume c_int64 for now
# var = builder.alloca(ir.IntType(64), name=var_name) # var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8 # var.align = 8
builder.store( builder.store(
ir.Constant(ir.IntType(64), ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name][0]
rval.value), local_sym_tab[var_name][0]
) )
# local_sym_tab[var_name] = var # local_sym_tab[var_name] = var
print(f"Assigned constant {rval.value} to {var_name}") print(f"Assigned constant {rval.value} to {var_name}")
@ -110,8 +107,7 @@ def handle_assign(
global_str.linkage = "internal" global_str.linkage = "internal"
global_str.global_constant = True global_str.global_constant = True
global_str.initializer = str_const global_str.initializer = str_const
str_ptr = builder.bitcast( str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
global_str, ir.PointerType(ir.IntType(8)))
builder.store(str_ptr, local_sym_tab[var_name][0]) builder.store(str_ptr, local_sym_tab[var_name][0])
print(f"Assigned string constant '{rval.value}' to {var_name}") print(f"Assigned string constant '{rval.value}' to {var_name}")
else: else:
@ -130,8 +126,7 @@ def handle_assign(
# var = builder.alloca(ir_type, name=var_name) # var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8 # var.align = ir_type.width // 8
builder.store( builder.store(
ir.Constant( ir.Constant(ir_type, rval.args[0].value), local_sym_tab[var_name][0]
ir_type, rval.args[0].value), local_sym_tab[var_name][0]
) )
print( print(
f"Assigned {call_type} constant " f"Assigned {call_type} constant "
@ -177,8 +172,7 @@ def handle_assign(
ir_type = struct_info.ir_type ir_type = struct_info.ir_type
# var = builder.alloca(ir_type, name=var_name) # var = builder.alloca(ir_type, name=var_name)
# Null init # Null init
builder.store(ir.Constant(ir_type, None), builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name][0])
local_sym_tab[var_name][0])
local_var_metadata[var_name] = call_type local_var_metadata[var_name] = call_type
print(f"Assigned struct {call_type} to {var_name}") print(f"Assigned struct {call_type} to {var_name}")
# local_sym_tab[var_name] = var # local_sym_tab[var_name] = var
@ -249,8 +243,7 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
print(f"Undefined variable {cond.id} in condition") print(f"Undefined variable {cond.id} in condition")
return None return None
elif isinstance(cond, ast.Compare): elif isinstance(cond, ast.Compare):
lhs = eval_expr(func, module, builder, cond.left, lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0]
local_sym_tab, map_sym_tab)[0]
if len(cond.ops) != 1 or len(cond.comparators) != 1: if len(cond.ops) != 1 or len(cond.comparators) != 1:
print("Unsupported complex comparison") print("Unsupported complex comparison")
return None return None
@ -303,8 +296,7 @@ def handle_if(
else: else:
else_block = None else_block = None
cond = handle_cond(func, module, builder, stmt.test, cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab)
local_sym_tab, map_sym_tab)
if else_block: if else_block:
builder.cbranch(cond, then_block, else_block) builder.cbranch(cond, then_block, else_block)
else: else:
@ -449,8 +441,7 @@ def allocate_mem(
ir_type = ctypes_to_ir(call_type) ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print( print(f"Pre-allocated variable {var_name} of type {call_type}")
f"Pre-allocated variable {var_name} of type {call_type}")
elif HelperHandlerRegistry.has_handler(call_type): elif HelperHandlerRegistry.has_handler(call_type):
# Assume return type is int64 for now # Assume return type is int64 for now
ir_type = ir.IntType(64) ir_type = ir.IntType(64)
@ -469,8 +460,8 @@ def allocate_mem(
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
local_var_metadata[var_name] = call_type local_var_metadata[var_name] = call_type
print( print(
f"Pre-allocated variable { f"Pre-allocated variable {var_name} "
var_name} for struct {call_type}" f"for struct {call_type}"
) )
elif isinstance(rval.func, ast.Attribute): elif isinstance(rval.func, ast.Attribute):
ir_type = ir.PointerType(ir.IntType(64)) ir_type = ir.PointerType(ir.IntType(64))
@ -665,14 +656,13 @@ def infer_return_type(func_node: ast.FunctionDef):
except Exception: except Exception:
return type(e).__name__ return type(e).__name__
for node in ast.walk(func_node): for walked_node in ast.walk(func_node):
if isinstance(node, ast.Return): if isinstance(walked_node, ast.Return):
t = _expr_type(node.value) t = _expr_type(walked_node.value)
if found_type is None: if found_type is None:
found_type = t found_type = t
elif found_type != t: elif found_type != t:
raise ValueError("Conflicting return types:" f"{ raise ValueError("Conflicting return types:" f"{found_type} vs {t}")
found_type} vs {t}")
return found_type or "None" return found_type or "None"
@ -709,8 +699,7 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
char = builder.load(src_ptr) char = builder.load(src_ptr)
# Store character in target # Store character in target
dst_ptr = builder.gep( dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
builder.store(char, dst_ptr) builder.store(char, dst_ptr)
# Increment counter # Increment counter
@ -721,6 +710,5 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
# Ensure null termination # Ensure null termination
last_idx = ir.Constant(ir.IntType(32), array_length - 1) last_idx = ir.Constant(ir.IntType(32), array_length - 1)
null_ptr = builder.gep( null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
builder.store(ir.Constant(ir.IntType(8), 0), null_ptr) builder.store(ir.Constant(ir.IntType(8), 0), null_ptr)

View File

@ -1,2 +1,13 @@
from .helper_utils import HelperHandlerRegistry from .helper_utils import HelperHandlerRegistry
from .bpf_helper_handler import handle_helper_call from .bpf_helper_handler import handle_helper_call
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
__all__ = [
"HelperHandlerRegistry",
"handle_helper_call",
"ktime",
"pid",
"deref",
"XDP_DROP",
"XDP_PASS",
]

View File

@ -1,10 +1,14 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from enum import Enum from enum import Enum
from .helper_utils import (HelperHandlerRegistry, from .helper_utils import (
get_or_create_ptr_from_arg, get_flags_val, HelperHandlerRegistry,
handle_fstring_print, simple_string_print, get_or_create_ptr_from_arg,
get_data_ptr_and_size) get_flags_val,
handle_fstring_print,
simple_string_print,
get_data_ptr_and_size,
)
class BPFHelperID(Enum): class BPFHelperID(Enum):
@ -15,12 +19,21 @@ class BPFHelperID(Enum):
BPF_PRINTK = 6 BPF_PRINTK = 6
BPF_GET_CURRENT_PID_TGID = 14 BPF_GET_CURRENT_PID_TGID = 14
BPF_PERF_EVENT_OUTPUT = 25 BPF_PERF_EVENT_OUTPUT = 25
BPF_RINGBUF_RESERVE = 131
BPF_RINGBUF_SUBMIT = 132
@HelperHandlerRegistry.register("ktime") @HelperHandlerRegistry.register("ktime")
def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, def bpf_ktime_get_ns_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
""" """
Emit LLVM IR for bpf_ktime_get_ns helper function call. Emit LLVM IR for bpf_ktime_get_ns helper function call.
""" """
@ -34,27 +47,34 @@ def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("lookup") @HelperHandlerRegistry.register("lookup")
def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, def bpf_map_lookup_elem_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
""" """
Emit LLVM IR for bpf_map_lookup_elem helper function call. Emit LLVM IR for bpf_map_lookup_elem helper function call.
""" """
if not call.args or len(call.args) != 1: if not call.args or len(call.args) != 1:
raise ValueError("Map lookup expects exactly one argument (key), got " raise ValueError(
f"{len(call.args)}") "Map lookup expects exactly one argument (key), got " f"{len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.PointerType(), # Return type: void* ir.PointerType(), # Return type: void*
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*) [ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
var_arg=False var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType( fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value)
64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False)
@ -63,33 +83,46 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("print") @HelperHandlerRegistry.register("print")
def bpf_printk_emitter(call, map_ptr, module, builder, func, def bpf_printk_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
"""Emit LLVM IR for bpf_printk helper function call.""" """Emit LLVM IR for bpf_printk helper function call."""
if not hasattr(func, "_fmt_counter"): if not hasattr(func, "_fmt_counter"):
func._fmt_counter = 0 func._fmt_counter = 0
if not call.args: if not call.args:
raise ValueError( raise ValueError("bpf_printk expects at least one argument (format string)")
"bpf_printk expects at least one argument (format string)")
args = [] args = []
if isinstance(call.args[0], ast.JoinedStr): if isinstance(call.args[0], ast.JoinedStr):
args = handle_fstring_print(call.args[0], module, builder, func, args = handle_fstring_print(
local_sym_tab, struct_sym_tab, call.args[0],
local_var_metadata) module,
elif (isinstance(call.args[0], ast.Constant) and builder,
isinstance(call.args[0].value, str)): func,
local_sym_tab,
struct_sym_tab,
local_var_metadata,
)
elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str):
# TODO: We are only supporting single arguments for now. # TODO: We are only supporting single arguments for now.
# In case of multiple args, the first one will be taken. # In case of multiple args, the first one will be taken.
args = simple_string_print(call.args[0].value, module, builder, func) args = simple_string_print(call.args[0].value, module, builder, func)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple strings or f-strings are supported in bpf_printk.") "Only simple strings or f-strings are supported in bpf_printk."
)
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True) ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True
)
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value) fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
@ -99,18 +132,25 @@ def bpf_printk_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("update") @HelperHandlerRegistry.register("update")
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, def bpf_map_update_elem_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
""" """
Emit LLVM IR for bpf_map_update_elem helper function call. Emit LLVM IR for bpf_map_update_elem helper function call.
Expected call signature: map.update(key, value, flags=0) Expected call signature: map.update(key, value, flags=0)
""" """
if (not call.args or if not call.args or len(call.args) < 2 or len(call.args) > 3:
len(call.args) < 2 or raise ValueError(
len(call.args) > 3): "Map update expects 2 or 3 args (key, value, flags), "
raise ValueError("Map update expects 2 or 3 args (key, value, flags), " f"got {len(call.args)}"
f"got {len(call.args)}") )
key_arg = call.args[0] key_arg = call.args[0]
value_arg = call.args[1] value_arg = call.args[1]
@ -124,12 +164,11 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func,
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
[ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)], [ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)],
var_arg=False var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType( fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value)
64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
if isinstance(flags_val, int): if isinstance(flags_val, int):
@ -138,22 +177,139 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func,
flags_const = flags_val flags_const = flags_val
result = builder.call( result = builder.call(
fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False) fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False
)
return result, None return result, None
@HelperHandlerRegistry.register("submit")
def bpf_ringbuf_submit_emitter(
    call,
    map_ptr,
    module,
    builder,
    func,
    local_sym_tab=None,
    struct_sym_tab=None,
    local_var_metadata=None,
):
    """
    Emit LLVM IR for bpf_ringbuf_submit helper function call.
    Expected call signature: ringbuf.submit(data, flags=0)

    Only ``call``, ``builder`` and ``local_sym_tab`` are used; the remaining
    parameters exist so that every registered emitter shares one signature.

    Returns:
        A ``(call_instruction, None)`` pair, the same shape the other
        emitters in this module return. The helper itself returns void, so
        there is no result type to report.

    Raises:
        ValueError: if the call does not have 1 or 2 positional arguments.
    """
    if not 1 <= len(call.args) <= 2:
        raise ValueError(
            "Ringbuf submit expects 1 or 2 args (data, flags), "
            f"got {len(call.args)}"
        )

    data_arg = call.args[0]
    # NOTE(review): if `data` is a variable holding the pointer returned by
    # reserve(), this presumably yields the address of that variable rather
    # than the reserved record itself -- confirm against
    # get_or_create_ptr_from_arg.
    data_ptr = get_or_create_ptr_from_arg(data_arg, builder, local_sym_tab)

    # Flags are optional; None is forwarded when they are omitted and
    # get_flags_val decides the default.
    flags_arg = call.args[1] if len(call.args) > 1 else None
    flags_val = get_flags_val(flags_arg, builder, local_sym_tab)

    # Returns: void
    # Args: (void* data, u64 flags)
    fn_type = ir.FunctionType(
        ir.VoidType(),
        [ir.PointerType(), ir.IntType(64)],
        var_arg=False,
    )
    fn_ptr_type = ir.PointerType(fn_type)

    # BPF helpers are invoked by casting the helper id to a function pointer.
    fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_RINGBUF_SUBMIT.value)
    fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)

    # get_flags_val may hand back either a Python int or an IR value.
    if isinstance(flags_val, int):
        flags_const = ir.Constant(ir.IntType(64), flags_val)
    else:
        flags_const = flags_val

    # tail=False, like the other emitters: the data pointer may refer to a
    # stack slot of the calling function, which a `tail`-marked call must not
    # access.
    result = builder.call(fn_ptr, [data_ptr, flags_const], tail=False)

    # Return the same (value, type) pair shape as every other emitter so that
    # callers which unpack or index the result keep working.
    return result, None
@HelperHandlerRegistry.register("reserve")
def bpf_ringbuf_reserve_emitter(
    call,
    map_ptr,
    module,
    builder,
    func,
    local_sym_tab=None,
    struct_sym_tab=None,
    local_var_metadata=None,
):
    """
    Emit LLVM IR for bpf_ringbuf_reserve helper function call.
    Expected call signature: ringbuf.reserve(size, flags=0)

    ``size`` may be a literal constant or the name of a local variable.
    Only ``call``, ``map_ptr``, ``builder`` and ``local_sym_tab`` are used;
    the remaining parameters exist so that every registered emitter shares
    one signature.

    Returns:
        A ``(call_instruction, ir.PointerType())`` pair: the helper call and
        the pointer type of its result.

    Raises:
        ValueError: if the call does not have 1 or 2 positional arguments, or
            if ``size`` names a variable missing from the local symbol table.
        NotImplementedError: if ``size`` is neither a constant nor a name.
    """
    if not 1 <= len(call.args) <= 2:
        raise ValueError(
            "Ringbuf reserve expects 1 or 2 args (size, flags), "
            f"got {len(call.args)}"
        )

    # TODO: here, getting length of stuff does not actually work. need to fix this.
    size_arg = call.args[0]
    if isinstance(size_arg, ast.Constant):
        size_val = ir.Constant(ir.IntType(64), size_arg.value)
    elif isinstance(size_arg, ast.Name):
        # local_sym_tab defaults to None; treat that like a missing variable
        # instead of failing with a TypeError on the membership test.
        if not local_sym_tab or size_arg.id not in local_sym_tab:
            raise ValueError(
                f"Variable '{size_arg.id}' not found in local symbol table"
            )
        # Symbol-table entries are indexed with [0] to obtain the variable's
        # pointer elsewhere in the code base; load through that pointer
        # rather than through the entry itself.
        size_val = builder.load(local_sym_tab[size_arg.id][0])
    else:
        raise NotImplementedError(f"Unsupported size argument type: {type(size_arg)}")

    # Flags are optional; None is forwarded when they are omitted and
    # get_flags_val decides the default.
    flags_arg = call.args[1] if len(call.args) > 1 else None
    flags_val = get_flags_val(flags_arg, builder, local_sym_tab)

    map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())

    # Returns: void* (typed as an opaque pointer)
    # Args: (void* ringbuf, u64 size, u64 flags)
    fn_type = ir.FunctionType(
        ir.PointerType(),
        [ir.PointerType(), ir.IntType(64), ir.IntType(64)],
        var_arg=False,
    )
    fn_ptr_type = ir.PointerType(fn_type)

    # BPF helpers are invoked by casting the helper id to a function pointer.
    fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_RINGBUF_RESERVE.value)
    fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)

    # get_flags_val may hand back either a Python int or an IR value.
    if isinstance(flags_val, int):
        flags_const = ir.Constant(ir.IntType(64), flags_val)
    else:
        flags_const = flags_val

    # tail=False for consistency with the other helper emitters.
    result = builder.call(fn_ptr, [map_void_ptr, size_val, flags_const], tail=False)
    return result, ir.PointerType()
@HelperHandlerRegistry.register("delete") @HelperHandlerRegistry.register("delete")
def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, def bpf_map_delete_elem_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
""" """
Emit LLVM IR for bpf_map_delete_elem helper function call. Emit LLVM IR for bpf_map_delete_elem helper function call.
Expected call signature: map.delete(key) Expected call signature: map.delete(key)
""" """
if not call.args or len(call.args) != 1: if not call.args or len(call.args) != 1:
raise ValueError("Map delete expects exactly one argument (key), got " raise ValueError(
f"{len(call.args)}") "Map delete expects exactly one argument (key), got " f"{len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -161,12 +317,11 @@ def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func,
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), # Return type: int64 (status code) ir.IntType(64), # Return type: int64 (status code)
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*) [ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
var_arg=False var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType( fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_DELETE_ELEM.value)
64), BPFHelperID.BPF_MAP_DELETE_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False)
@ -175,15 +330,21 @@ def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("pid") @HelperHandlerRegistry.register("pid")
def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func, def bpf_get_current_pid_tgid_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
""" """
Emit LLVM IR for bpf_get_current_pid_tgid helper function call. Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
""" """
# func is an arg to just have a uniform signature with other emitters # func is an arg to just have a uniform signature with other emitters
helper_id = ir.Constant(ir.IntType( helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value)
64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value)
fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False)
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) fn_ptr = builder.inttoptr(helper_id, fn_ptr_type)
@ -196,18 +357,26 @@ def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("output") @HelperHandlerRegistry.register("output")
def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, def bpf_perf_event_output_handler(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
if len(call.args) != 1: if len(call.args) != 1:
raise ValueError("Perf event output expects exactly one argument, " raise ValueError(
f"got {len(call.args)}") "Perf event output expects exactly one argument, " f"got {len(call.args)}"
)
data_arg = call.args[0] data_arg = call.args[0]
ctx_ptr = func.args[0] # First argument to the function is ctx ctx_ptr = func.args[0] # First argument to the function is ctx
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, data_ptr, size_val = get_data_ptr_and_size(
struct_sym_tab, data_arg, local_sym_tab, struct_sym_tab, local_var_metadata
local_var_metadata) )
# BPF_F_CURRENT_CPU is -1 in 32 bit # BPF_F_CURRENT_CPU is -1 in 32 bit
flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF) flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF)
@ -216,36 +385,56 @@ def bpf_perf_event_output_handler(call, map_ptr, module, builder, func,
data_void_ptr = builder.bitcast(data_ptr, ir.PointerType()) data_void_ptr = builder.bitcast(data_ptr, ir.PointerType())
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
[ir.PointerType(ir.IntType(8)), ir.PointerType(), ir.IntType(64), [
ir.PointerType(), ir.IntType(64)], ir.PointerType(ir.IntType(8)),
var_arg=False ir.PointerType(),
ir.IntType(64),
ir.PointerType(),
ir.IntType(64),
],
var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
# helper id # helper id
fn_addr = ir.Constant(ir.IntType(64), fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PERF_EVENT_OUTPUT.value)
BPFHelperID.BPF_PERF_EVENT_OUTPUT.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call( result = builder.call(
fn_ptr, fn_ptr, [ctx_ptr, map_void_ptr, flags_val, data_void_ptr, size_val], tail=False
[ctx_ptr, map_void_ptr, flags_val, data_void_ptr, size_val], )
tail=False)
return result, None return result, None
def handle_helper_call(call, module, builder, func, def handle_helper_call(
local_sym_tab=None, map_sym_tab=None, call,
struct_sym_tab=None, local_var_metadata=None): module,
builder,
func,
local_sym_tab=None,
map_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
"""Process a BPF helper function call and emit the appropriate LLVM IR.""" """Process a BPF helper function call and emit the appropriate LLVM IR."""
# Helper function to get map pointer and invoke handler # Helper function to get map pointer and invoke handler
def invoke_helper(method_name, map_ptr=None): def invoke_helper(method_name, map_ptr=None):
handler = HelperHandlerRegistry.get_handler(method_name) handler = HelperHandlerRegistry.get_handler(method_name)
if not handler: if not handler:
raise NotImplementedError( raise NotImplementedError(
f"Helper function '{method_name}' is not implemented.") f"Helper function '{method_name}' is not implemented."
return handler(call, map_ptr, module, builder, func, )
local_sym_tab, struct_sym_tab, local_var_metadata) return handler(
call,
map_ptr,
module,
builder,
func,
local_sym_tab,
struct_sym_tab,
local_var_metadata,
)
# Handle direct function calls (e.g., print(), ktime()) # Handle direct function calls (e.g., print(), ktime())
if isinstance(call.func, ast.Name): if isinstance(call.func, ast.Name):
@ -255,14 +444,18 @@ def handle_helper_call(call, module, builder, func,
elif isinstance(call.func, ast.Attribute): elif isinstance(call.func, ast.Attribute):
method_name = call.func.attr method_name = call.func.attr
value = call.func.value value = call.func.value
print(f"Handling method call: {ast.dump(call.func)}")
# Get map pointer from different styles of map access # Get map pointer from different styles of map access
if isinstance(value, ast.Call) and isinstance(value.func, ast.Name): if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
# Variable style: my_map.lookup(key) # Func style: my_map().lookup(key)
map_name = value.func.id map_name = value.func.id
elif isinstance(value, ast.Name):
# Direct style: my_map.lookup(key)
map_name = value.id
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported map access pattern: {ast.dump(value)}") f"Unsupported map access pattern: {ast.dump(value)}"
)
# Verify map exists and get pointer # Verify map exists and get pointer
if not map_sym_tab or map_name not in map_sym_tab: if not map_sym_tab or map_name not in map_sym_tab:

View File

@ -1,5 +1,7 @@
import ast import ast
import logging import logging
from collections.abc import Callable
from llvmlite import ir from llvmlite import ir
from pythonbpf.expr_pass import eval_expr from pythonbpf.expr_pass import eval_expr
@ -8,14 +10,17 @@ logger = logging.getLogger(__name__)
class HelperHandlerRegistry: class HelperHandlerRegistry:
"""Registry for BPF helpers""" """Registry for BPF helpers"""
_handlers = {}
_handlers: dict[str, Callable] = {}
@classmethod @classmethod
def register(cls, helper_name): def register(cls, helper_name):
"""Decorator to register a handler function for a helper""" """Decorator to register a handler function for a helper"""
def decorator(func): def decorator(func):
cls._handlers[helper_name] = func cls._handlers[helper_name] = func
return func return func
return decorator return decorator
@classmethod @classmethod
@ -55,7 +60,8 @@ def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
ptr = create_int_constant_ptr(arg.value, builder) ptr = create_int_constant_ptr(arg.value, builder)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple variable names are supported as args in map helpers.") "Only simple variable names are supported as args in map helpers."
)
return ptr return ptr
@ -69,13 +75,13 @@ def get_flags_val(arg, builder, local_sym_tab):
flags_ptr = local_sym_tab[arg.id][0] flags_ptr = local_sym_tab[arg.id][0]
return builder.load(flags_ptr) return builder.load(flags_ptr)
else: else:
raise ValueError( raise ValueError(f"Variable '{arg.id}' not found in local symbol table")
f"Variable '{arg.id}' not found in local symbol table")
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
return arg.value return arg.value
raise NotImplementedError( raise NotImplementedError(
"Only var names or int consts are supported as map helpers flags.") "Only var names or int consts are supported as map helpers flags."
)
def simple_string_print(string_value, module, builder, func): def simple_string_print(string_value, module, builder, func):
@ -87,9 +93,15 @@ def simple_string_print(string_value, module, builder, func):
return args return args
def handle_fstring_print(joined_str, module, builder, func, def handle_fstring_print(
local_sym_tab=None, struct_sym_tab=None, joined_str,
local_var_metadata=None): module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
local_var_metadata=None,
):
"""Handle f-string formatting for bpf_printk emitter.""" """Handle f-string formatting for bpf_printk emitter."""
fmt_parts = [] fmt_parts = []
exprs = [] exprs = []
@ -100,25 +112,34 @@ def handle_fstring_print(joined_str, module, builder, func,
if isinstance(value, ast.Constant): if isinstance(value, ast.Constant):
_process_constant_in_fstring(value, fmt_parts, exprs) _process_constant_in_fstring(value, fmt_parts, exprs)
elif isinstance(value, ast.FormattedValue): elif isinstance(value, ast.FormattedValue):
_process_fval(value, fmt_parts, exprs, _process_fval(
local_sym_tab, struct_sym_tab, value,
local_var_metadata) fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
local_var_metadata,
)
else: else:
raise NotImplementedError( raise NotImplementedError(f"Unsupported f-string value type: {type(value)}")
f"Unsupported f-string value type: {type(value)}")
fmt_str = "".join(fmt_parts) fmt_str = "".join(fmt_parts)
args = simple_string_print(fmt_str, module, builder, func) args = simple_string_print(fmt_str, module, builder, func)
# NOTE: Process expressions (limited to 3 due to BPF constraints) # NOTE: Process expressions (limited to 3 due to BPF constraints)
if len(exprs) > 3: if len(exprs) > 3:
logger.warning( logger.warning("bpf_printk supports up to 3 args, extra args will be ignored.")
"bpf_printk supports up to 3 args, extra args will be ignored.")
for expr in exprs[:3]: for expr in exprs[:3]:
arg_value = _prepare_expr_args(expr, func, module, builder, arg_value = _prepare_expr_args(
local_sym_tab, struct_sym_tab, expr,
local_var_metadata) func,
module,
builder,
local_sym_tab,
struct_sym_tab,
local_var_metadata,
)
args.append(arg_value) args.append(arg_value)
return args return args
@ -133,24 +154,31 @@ def _process_constant_in_fstring(cst, fmt_parts, exprs):
exprs.append(ir.Constant(ir.IntType(64), cst.value)) exprs.append(ir.Constant(ir.IntType(64), cst.value))
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported constant type in f-string: {type(cst.value)}") f"Unsupported constant type in f-string: {type(cst.value)}"
)
def _process_fval(fval, fmt_parts, exprs, def _process_fval(
local_sym_tab, struct_sym_tab, fval, fmt_parts, exprs, local_sym_tab, struct_sym_tab, local_var_metadata
local_var_metadata): ):
"""Process formatted values in f-string.""" """Process formatted values in f-string."""
logger.debug(f"Processing formatted value: {ast.dump(fval)}") logger.debug(f"Processing formatted value: {ast.dump(fval)}")
if isinstance(fval.value, ast.Name): if isinstance(fval.value, ast.Name):
_process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab) _process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab)
elif isinstance(fval.value, ast.Attribute): elif isinstance(fval.value, ast.Attribute):
_process_attr_in_fval(fval.value, fmt_parts, exprs, _process_attr_in_fval(
local_sym_tab, struct_sym_tab, fval.value,
local_var_metadata) fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
local_var_metadata,
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported formatted value in f-string: {type(fval.value)}") f"Unsupported formatted value in f-string: {type(fval.value)}"
)
def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab): def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
@ -160,34 +188,39 @@ def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
_populate_fval(var_type, name_node, fmt_parts, exprs) _populate_fval(var_type, name_node, fmt_parts, exprs)
def _process_attr_in_fval(attr_node, fmt_parts, exprs, def _process_attr_in_fval(
local_sym_tab, struct_sym_tab, attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab, local_var_metadata
local_var_metadata): ):
"""Process attribute nodes in formatted values.""" """Process attribute nodes in formatted values."""
if (isinstance(attr_node.value, ast.Name) and if (
local_sym_tab and attr_node.value.id in local_sym_tab): isinstance(attr_node.value, ast.Name)
and local_sym_tab
and attr_node.value.id in local_sym_tab
):
var_name = attr_node.value.id var_name = attr_node.value.id
field_name = attr_node.attr field_name = attr_node.attr
if not local_var_metadata or var_name not in local_var_metadata: if not local_var_metadata or var_name not in local_var_metadata:
raise ValueError( raise ValueError(
f"Metadata for '{var_name}' not found in local var metadata") f"Metadata for '{var_name}' not found in local var metadata"
)
var_type = local_var_metadata[var_name] var_type = local_var_metadata[var_name]
if var_type not in struct_sym_tab: if var_type not in struct_sym_tab:
raise ValueError( raise ValueError(
f"Struct '{var_type}' for '{var_name}' not in symbol table") f"Struct '{var_type}' for '{var_name}' not in symbol table"
)
struct_info = struct_sym_tab[var_type] struct_info = struct_sym_tab[var_type]
if field_name not in struct_info.fields: if field_name not in struct_info.fields:
raise ValueError( raise ValueError(f"Field '{field_name}' not found in struct '{var_type}'")
f"Field '{field_name}' not found in struct '{var_type}'")
field_type = struct_info.field_type(field_name) field_type = struct_info.field_type(field_name)
_populate_fval(field_type, attr_node, fmt_parts, exprs) _populate_fval(field_type, attr_node, fmt_parts, exprs)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple attribute on local vars is supported in f-strings.") "Only simple attribute on local vars is supported in f-strings."
)
def _populate_fval(ftype, node, fmt_parts, exprs): def _populate_fval(ftype, node, fmt_parts, exprs):
@ -202,14 +235,14 @@ def _populate_fval(ftype, node, fmt_parts, exprs):
exprs.append(node) exprs.append(node)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported integer width in f-string: {ftype.width}") f"Unsupported integer width in f-string: {ftype.width}"
)
elif ftype == ir.PointerType(ir.IntType(8)): elif ftype == ir.PointerType(ir.IntType(8)):
# NOTE: We assume i8* is a string # NOTE: We assume i8* is a string
fmt_parts.append("%s") fmt_parts.append("%s")
exprs.append(node) exprs.append(node)
else: else:
raise NotImplementedError( raise NotImplementedError(f"Unsupported field type in f-string: {ftype}")
f"Unsupported field type in f-string: {ftype}")
def _create_format_string_global(fmt_str, func, module, builder): def _create_format_string_global(fmt_str, func, module, builder):
@ -218,11 +251,11 @@ def _create_format_string_global(fmt_str, func, module, builder):
func._fmt_counter += 1 func._fmt_counter += 1
fmt_gvar = ir.GlobalVariable( fmt_gvar = ir.GlobalVariable(
module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name) module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name
)
fmt_gvar.global_constant = True fmt_gvar.global_constant = True
fmt_gvar.initializer = ir.Constant( fmt_gvar.initializer = ir.Constant(
ir.ArrayType(ir.IntType(8), len(fmt_str)), ir.ArrayType(ir.IntType(8), len(fmt_str)), bytearray(fmt_str.encode("utf8"))
bytearray(fmt_str.encode("utf8"))
) )
fmt_gvar.linkage = "internal" fmt_gvar.linkage = "internal"
fmt_gvar.align = 1 fmt_gvar.align = 1
@ -230,13 +263,20 @@ def _create_format_string_global(fmt_str, func, module, builder):
return builder.bitcast(fmt_gvar, ir.PointerType()) return builder.bitcast(fmt_gvar, ir.PointerType())
def _prepare_expr_args(expr, func, module, builder, def _prepare_expr_args(
local_sym_tab, struct_sym_tab, expr, func, module, builder, local_sym_tab, struct_sym_tab, local_var_metadata
local_var_metadata): ):
"""Evaluate and prepare an expression to use as an arg for bpf_printk.""" """Evaluate and prepare an expression to use as an arg for bpf_printk."""
val, _ = eval_expr(func, module, builder, expr, val, _ = eval_expr(
local_sym_tab, None, struct_sym_tab, func,
local_var_metadata) module,
builder,
expr,
local_sym_tab,
None,
struct_sym_tab,
local_var_metadata,
)
if val: if val:
if isinstance(val.type, ir.PointerType): if isinstance(val.type, ir.PointerType):
@ -246,19 +286,19 @@ def _prepare_expr_args(expr, func, module, builder,
val = builder.sext(val, ir.IntType(64)) val = builder.sext(val, ir.IntType(64))
else: else:
logger.warning( logger.warning(
"Only int and ptr supported in bpf_printk args. " "Only int and ptr supported in bpf_printk args. " "Others default to 0."
"Others default to 0.") )
val = ir.Constant(ir.IntType(64), 0) val = ir.Constant(ir.IntType(64), 0)
return val return val
else: else:
logger.warning( logger.warning(
"Failed to evaluate expression for bpf_printk argument. " "Failed to evaluate expression for bpf_printk argument. "
"It will be converted to 0.") "It will be converted to 0."
)
return ir.Constant(ir.IntType(64), 0) return ir.Constant(ir.IntType(64), 0)
def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab, def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab, local_var_metadata):
local_var_metadata):
"""Extract data pointer and size information for perf event output.""" """Extract data pointer and size information for perf event output."""
if isinstance(data_arg, ast.Name): if isinstance(data_arg, ast.Name):
data_name = data_arg.id data_name = data_arg.id
@ -266,7 +306,8 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab,
data_ptr = local_sym_tab[data_name][0] data_ptr = local_sym_tab[data_name][0]
else: else:
raise ValueError( raise ValueError(
f"Data variable {data_name} not found in local symbol table.") f"Data variable {data_name} not found in local symbol table."
)
# Check if data_name is a struct # Check if data_name is a struct
if local_var_metadata and data_name in local_var_metadata: if local_var_metadata and data_name in local_var_metadata:
@ -277,12 +318,14 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab,
return data_ptr, size_val return data_ptr, size_val
else: else:
raise ValueError( raise ValueError(
f"Struct {data_type} for {data_name} not in symbol table.") f"Struct {data_type} for {data_name} not in symbol table."
)
else: else:
raise ValueError( raise ValueError(
f"Metadata for variable {data_name} " f"Metadata for variable {data_name} "
"not found in local variable metadata.") "not found in local variable metadata."
)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple object names are supported " "Only simple object names are supported " "as data in perf event output."
"as data in perf event output.") )

View File

@ -1,7 +1,11 @@
from collections.abc import Callable
from typing import Any
class MapProcessorRegistry: class MapProcessorRegistry:
"""Registry for map processor functions""" """Registry for map processor functions"""
_processors = {} _processors: dict[str, Callable[..., Any]] = {}
@classmethod @classmethod
def register(cls, map_type_name): def register(cls, map_type_name):

View File

@ -22,27 +22,29 @@ struct {
SEC("tracepoint/syscalls/sys_enter_execve") SEC("tracepoint/syscalls/sys_enter_execve")
int trace_execve(void *ctx) int trace_execve(void *ctx)
{ {
struct event *e; // struct event *e;
__u64 pid_tgid; // __u64 pid_tgid;
__u64 uid_gid; // __u64 uid_gid;
__u32 *e;
// Reserve space in the ringbuffer // Reserve space in the ringbuffer
e = bpf_ringbuf_reserve(&events, sizeof(*e), 0); e = bpf_ringbuf_reserve(&events, sizeof(*e), 0);
if (!e) if (!e)
return 0; return 0;
//
// // Fill the struct with data
// pid_tgid = bpf_get_current_pid_tgid();
// e->pid = pid_tgid >> 32;
//
// uid_gid = bpf_get_current_uid_gid();
// e->uid = uid_gid & 0xFFFFFFFF;
//
// e->timestamp = bpf_ktime_get_ns();
// Fill the struct with data // bpf_get_current_comm(&e->comm, sizeof(e->comm));
pid_tgid = bpf_get_current_pid_tgid(); //
e->pid = pid_tgid >> 32; // // Submit the event to ringbuffer
__u32 temp = 32;
uid_gid = bpf_get_current_uid_gid(); e = &temp;
e->uid = uid_gid & 0xFFFFFFFF;
e->timestamp = bpf_ktime_get_ns();
bpf_get_current_comm(&e->comm, sizeof(e->comm));
// Submit the event to ringbuffer
bpf_ringbuf_submit(e, 0); bpf_ringbuf_submit(e, 0);
return 0; return 0;

View File

@ -0,0 +1,33 @@
from pythonbpf import bpf, map, bpfglobal, section, compile, compile_to_ir, BPF
from pythonbpf.maps import RingBuf
from ctypes import c_int32, c_void_p
# Define a map
# Map definition: returns a RingBuf constructed with max_entries=1024.
# Only `#` comments are used here (no docstring) so the function body stays a
# single return statement.
# NOTE(review): presumably @bpf/@map register this as a BPF ring buffer map —
# confirm against the pythonbpf decorators.
@bpf
@map
def mymap() -> RingBuf:
    return RingBuf(max_entries=(1024))
# Program placed in the "tracepoint/syscalls/sys_enter_clone" section.
# It calls reserve(64) on the ring buffer map, returns early when the result
# compares equal to 0, otherwise passes the result to submit(); it returns
# c_int32(0) on both paths.
@bpf
@section("tracepoint/syscalls/sys_enter_clone")
def random_section(ctx: c_void_p) -> c_int32:
    # NOTE(review): 64 is presumably the reservation size in bytes — confirm
    # against RingBuf.reserve.
    e: c_int32 = mymap().reserve(64)
    if e == 0:  # NOTE(review): the author suspects this comparison is the failing construct — unconfirmed
        return c_int32(0)
    # Hand the reserved entry back to the ring buffer.
    mymap().submit(e)
    return c_int32(0)
# Global declared through @bpfglobal; its value is the string "GPL".
# NOTE(review): presumably emitted as the BPF object's license — confirm
# against the pythonbpf @bpfglobal handling.
@bpf
@bpfglobal
def LICENSE() -> str:
    return "GPL"
# Driver: runs at import time (there is no __main__ guard).
# NOTE(review): argument names suggest this writes IR for ringbuf.py to
# ringbuf.ll — confirm against compile_to_ir.
compile_to_ir("ringbuf.py", "ringbuf.ll")
compile()
b = BPF()
b.load_and_attach()
# Loops forever printing "running"; the script never exits on its own and
# does not sleep between iterations.
while True:
    print("running")

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import XDP_PASS from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, compile, compile_to_ir, BPF from pythonbpf import bpf, map, struct, section, bpfglobal, compile, compile_to_ir, BPF
from pythonbpf.helpers import ktime, pid from pythonbpf.helper import ktime, pid
from pythonbpf.maps import PerfEventArray from pythonbpf.maps import PerfEventArray
from ctypes import c_void_p, c_int32, c_uint64 from ctypes import c_void_p, c_int32, c_uint64

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, BPF, map, bpfglobal, section, compile, compile_to_ir from pythonbpf import bpf, map, bpfglobal, section, compile, compile_to_ir, BPF
from pythonbpf.maps import RingBuf, HashMap from pythonbpf.maps import RingBuf
from ctypes import c_int32, c_void_p from ctypes import c_int32, c_void_p
@ -9,17 +9,13 @@ from ctypes import c_int32, c_void_p
def mymap() -> RingBuf: def mymap() -> RingBuf:
return RingBuf(max_entries=(1024)) return RingBuf(max_entries=(1024))
@bpf
@map
def mymap2() -> HashMap:
return HashMap(key=c_int32, value=c_int32, max_entries=1024)
@bpf @bpf
@section("tracepoint/syscalls/sys_enter_clone") @section("tracepoint/syscalls/sys_enter_clone")
def random_section(ctx: c_void_p) -> c_int32: def random_section(ctx: c_void_p) -> c_int32:
print("Hello") print("Hello")
e = mymap().reserve(6)
if e:
mymap().submit(e)
return c_int32(0) return c_int32(0)
@ -33,3 +29,5 @@ compile_to_ir("ringbuf.py", "ringbuf.ll")
compile() compile()
b = BPF() b = BPF()
b.load_and_attach() b.load_and_attach()
while True:
print("running")