43 Commits

Author SHA1 Message Date
e0ad1bfb0f Move bulk of allocation logic to allocation_pass 2025-10-12 12:14:46 +05:30
69bee5fee9 Separate LocalSymbol from functions 2025-10-12 12:10:09 +05:30
2f1aaa4834 Fix typos 2025-10-12 11:41:01 +05:30
0f6971bcc2 Refactor allocate_mem 2025-10-12 11:34:40 +05:30
08c0ccf0ac Pass map_sym_tab to handle_struct_field_assign 2025-10-12 10:37:20 +05:30
64e44d0d58 Use handle_struct_field_assignment in handle_assign 2025-10-12 10:30:46 +05:30
3ad1b73c5a Add handle_struct_field_assignment to assign_pass 2025-10-12 10:19:52 +05:30
105c5a7bd0 Cleanup handle_assign 2025-10-12 10:12:45 +05:30
933d2a5c77 Fix comprehensive assignment test 2025-10-12 09:47:57 +05:30
b93f704eb8 Tweak the comprehensive assignment test 2025-10-12 09:46:16 +05:30
fa82dc7ebd Add comprehensive passing test for assignment 2025-10-12 09:39:33 +05:30
e8026a13bf Allow helpers to be called within themselves 2025-10-12 09:30:37 +05:30
a3b4d09652 Fix errorstring in _handle_unary_op 2025-10-12 09:13:04 +05:30
4e33fd4a32 Add negation UnaryOp 2025-10-12 09:11:56 +05:30
2cf68f6473 Allow map-based helpers to be used as helper args / within binops which are helper args 2025-10-12 07:57:55 +05:30
d66e6a6aff Allow struct members as helper args 2025-10-12 06:00:50 +05:30
cd74e896cf Allow binops as args to helpers accepting int* 2025-10-12 04:20:46 +05:30
207f714027 Use scratch space to store consts passed to helpers 2025-10-12 04:17:37 +05:30
5dcf670f49 Add ScratchPoolManager and its singleton 2025-10-12 01:47:11 +05:30
6bce29b90f Allocate scratch space for temp vars at the end of allocate_mem 2025-10-12 00:37:57 +05:30
321415fa28 Add update_max_temps_for_stmt in allocate_mem 2025-10-12 00:33:07 +05:30
8776d7607f Add count_temps_in_call to calc scratch space needed in a helper call 2025-10-12 00:17:10 +05:30
8b7b1c08a5 Add struct_and_helper_binops passing test for assignments 2025-10-11 22:03:32 +05:30
c9bbe1ffd8 Call eval_expr properly within get_operand_value 2025-10-11 03:21:09 +05:30
91a3fe140d Remove unnecessary return artifacts from get_operand_value 2025-10-11 03:06:24 +05:30
c2c17741e5 Remove store_through_chain 2025-10-11 03:04:26 +05:30
cac88d1560 Allow different int widths in binops 2025-10-11 02:44:08 +05:30
317575644f Interpret bools as ints in binops 2025-10-11 00:18:11 +05:30
a756f5e4b7 Add passing helper test for assignment 2025-10-10 23:55:12 +05:30
7529820c0b Allow int** pointers to store binops of type int** op int 2025-10-10 20:36:37 +05:30
9febadffd3 Add pointer handling to helper_utils, finish pointer assignment 2025-10-10 15:01:15 +05:30
99aacca94b WIP: allow pointer assignments to var 2025-10-10 13:48:40 +05:30
1d517d4e09 Add double_alloc in alloc_mem 2025-10-10 12:28:45 +05:30
047f361ea9 Allocate twice for map lookups 2025-10-10 06:09:46 +05:30
489244a015 Add store_through_chain 2025-10-10 02:56:11 +05:30
8bab07ed72 Remove recursive_dereferencer 2025-10-10 00:13:35 +05:30
1253f51ff3 Use deref_to_val instead of recursive_dereferencer in get_operand_value 2025-10-09 23:11:06 +05:30
23afb0bd33 Add deref_to_val to deref into final value and return the chain as well in binops 2025-10-09 21:47:28 +05:30
c596213b2a Add cst_var_binop.py as passing assign test 2025-10-09 03:42:25 +05:30
054a834464 Add failing assign test retype.py, with explanation 2025-10-09 03:28:07 +05:30
d7bfe86524 Add handle_variable_assignment to assign_pass 2025-10-09 03:09:10 +05:30
84ed27f222 Add handle_variable_assignment stub and boilerplate in handle_assign 2025-10-08 22:55:03 +05:30
6008d9841f Change loglevel of multi-assignment warning in handle_assign 2025-10-08 22:45:09 +05:30
40 changed files with 249405 additions and 1182 deletions

2
.gitignore vendored
View File

@ -7,5 +7,3 @@ __pycache__/
*.ll
*.o
.ipynb_checkpoints/
vmlinux.py
~*

View File

@ -12,7 +12,7 @@
#
# See https://github.com/pre-commit/pre-commit
exclude: 'vmlinux.py'
exclude: 'vmlinux.*\.py$'
ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"
@ -41,7 +41,7 @@ repos:
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format
# exclude: ^(docs)|^(tests)|^(examples)
exclude: ^(docs)|^(tests)|^(examples)
# Checking static types
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@ -83,14 +83,14 @@ def hist() -> HashMap:
def hello(ctx: c_void_p) -> c_int64:
process_id = pid()
one = 1
prev = hist.lookup(process_id)
prev = hist().lookup(process_id)
if prev:
previous_value = prev + 1
print(f"count: {previous_value} with {process_id}")
hist.update(process_id, previous_value)
hist().update(process_id, previous_value)
return c_int64(0)
else:
hist.update(process_id, one)
hist().update(process_id, one)
return c_int64(0)

13
TODO.md Normal file
View File

@ -0,0 +1,13 @@
## Short term
- Implement enough functionality to port the BCC tutorial examples to PythonBPF
- Add all maps
- XDP support in pylibbpf
- ringbuf support
- Add one-line IfExpr conditionals (wishlist)
## Long term
- Refactor the codebase to be better than a hackathon project
- Port to C++ and use actual LLVM?
- Fix struct_kioctx issue in the vmlinux transpiler

View File

@ -308,7 +308,6 @@
"def hist() -> HashMap:\n",
" return HashMap(key=c_int32, value=c_uint64, max_entries=4096)\n",
"\n",
"\n",
"@bpf\n",
"@section(\"tracepoint/syscalls/sys_enter_clone\")\n",
"def hello(ctx: c_void_p) -> c_int64:\n",
@ -330,7 +329,6 @@
"def LICENSE() -> str:\n",
" return \"GPL\"\n",
"\n",
"\n",
"b = BPF()"
]
},
@ -359,6 +357,7 @@
}
],
"source": [
"\n",
"b.load_and_attach()\n",
"hist = BpfMap(b, hist)\n",
"print(\"Recording\")\n",

View File

@ -8,14 +8,12 @@ def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return c_int64(0)
@bpf
@section("kprobe/do_unlinkat")
def hello_world2(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:

View File

@ -27,7 +27,7 @@ def hello(ctx: c_void_p) -> c_int32:
dataobj.pid = pid()
dataobj.ts = ktime()
# dataobj.comm = strobj
print(f"clone called at {dataobj.ts} by pid{dataobj.pid}, comm {strobj}")
print(f"clone called at {dataobj.ts} by pid" f"{dataobj.pid}, comm {strobj}")
events.output(dataobj)
return c_int32(0)

248446
examples/vmlinux.py Normal file

File diff suppressed because it is too large.

View File

@ -1,8 +1,8 @@
from pythonbpf import bpf, map, section, bpfglobal, compile, compile_to_ir
from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap
from ctypes import c_int64, c_void_p
from ctypes import c_void_p, c_int64
# Instructions on how to run this program
# 1. Install PythonBPF: pip install pythonbpf
@ -41,5 +41,4 @@ def LICENSE() -> str:
return "GPL"
compile_to_ir("xdp_pass.py", "xdp_pass.ll")
compile()

View File

@ -0,0 +1,191 @@
import ast
import logging
from llvmlite import ir
from dataclasses import dataclass
from typing import Any
from pythonbpf.helper import HelperHandlerRegistry
from pythonbpf.type_deducer import ctypes_to_ir
logger = logging.getLogger(__name__)
@dataclass
class LocalSymbol:
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def __iter__(self):
yield self.var
yield self.ir_type
yield self.metadata
def _is_helper_call(call_node):
"""Check if a call node is a BPF helper function call."""
if isinstance(call_node.func, ast.Name):
# Exclude print from requiring temps (handles f-strings differently)
func_name = call_node.func.id
return HelperHandlerRegistry.has_handler(func_name) and func_name != "print"
elif isinstance(call_node.func, ast.Attribute):
return HelperHandlerRegistry.has_handler(call_node.func.attr)
return False
def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
"""Handle memory allocation for assignment statements."""
# Validate assignment
if len(stmt.targets) != 1:
logger.warning("Multi-target assignment not supported, skipping allocation")
return
target = stmt.targets[0]
# Skip non-name targets (e.g., struct field assignments)
if isinstance(target, ast.Attribute):
logger.debug(f"Struct field assignment to {target.attr}, no allocation needed")
return
if not isinstance(target, ast.Name):
logger.warning(f"Unsupported assignment target type: {type(target).__name__}")
return
var_name = target.id
rval = stmt.value
# Skip if already allocated
if var_name in local_sym_tab:
logger.debug(f"Variable {var_name} already allocated, skipping")
return
# Determine type and allocate based on rval
if isinstance(rval, ast.Call):
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)
elif isinstance(rval, ast.Constant):
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
elif isinstance(rval, ast.BinOp):
_allocate_for_binop(builder, var_name, local_sym_tab)
else:
logger.warning(
f"Unsupported assignment value type for {var_name}: {type(rval).__name__}"
)
def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab):
"""Allocate memory for variable assigned from a call."""
if isinstance(rval.func, ast.Name):
call_type = rval.func.id
# C type constructors
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as {call_type}")
# Helper functions
elif HelperHandlerRegistry.has_handler(call_type):
ir_type = ir.IntType(64) # Assume i64 return type
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} for helper {call_type}")
# Deref function
elif call_type == "deref":
ir_type = ir.IntType(64) # Assume i64 return type
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} for deref")
# Struct constructors
elif call_type in structs_sym_tab:
struct_info = structs_sym_tab[call_type]
var = builder.alloca(struct_info.ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
else:
logger.warning(f"Unknown call type for allocation: {call_type}")
elif isinstance(rval.func, ast.Attribute):
# Map method calls - need double allocation for ptr handling
_allocate_for_map_method(builder, var_name, local_sym_tab)
else:
logger.warning(f"Unsupported call function type for {var_name}")
def _allocate_for_map_method(builder, var_name, local_sym_tab):
"""Allocate memory for variable assigned from map method (double alloc)."""
# Main variable (pointer to pointer)
ir_type = ir.PointerType(ir.IntType(64))
var = builder.alloca(ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
# Temporary variable for computed values
tmp_ir_type = ir.IntType(64)
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method")
def _allocate_for_constant(builder, var_name, rval, local_sym_tab):
"""Allocate memory for variable assigned from a constant."""
if isinstance(rval.value, bool):
ir_type = ir.IntType(1)
var = builder.alloca(ir_type, name=var_name)
var.align = 1
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as bool")
elif isinstance(rval.value, int):
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as i64")
elif isinstance(rval.value, str):
ir_type = ir.PointerType(ir.IntType(8))
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as string")
else:
logger.warning(
f"Unsupported constant type for {var_name}: {type(rval.value).__name__}"
)
def _allocate_for_binop(builder, var_name, local_sym_tab):
"""Allocate memory for variable assigned from a binary operation."""
ir_type = ir.IntType(64) # Assume i64 result
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} for binop result")
def allocate_temp_pool(builder, max_temps, local_sym_tab):
"""Allocate the temporary scratch space pool for helper arguments."""
if max_temps == 0:
return
logger.info(f"Allocating temp pool of {max_temps} variables")
for i in range(max_temps):
temp_name = f"__helper_temp_{i}"
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
temp_var.align = 8
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))
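For orientation, a minimal sketch (not part of this change) of how the pass above could be driven for one function body. The names builder, body, the symbol tables, and max_temps are assumed to come from the surrounding compiler; the real allocate_mem shown later in this diff also recurses into if/else branches.

import ast
from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool

def preallocate_body(builder, body, local_sym_tab, structs_sym_tab, max_temps):
    # Reserve a stack slot for every simple assignment target up front,
    # mirroring handle_assign_allocation's dispatch on the RHS node type.
    for stmt in body:
        if isinstance(stmt, ast.Assign):
            handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
    # Then reserve the shared __helper_temp_<i> scratch slots for helper args.
    allocate_temp_pool(builder, max_temps, local_sym_tab)
    return local_sym_tab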

108
pythonbpf/assign_pass.py Normal file
View File

@ -0,0 +1,108 @@
import ast
import logging
from llvmlite import ir
from pythonbpf.expr import eval_expr
logger = logging.getLogger(__name__)
def handle_struct_field_assignment(
func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle struct field assignment (obj.field = value)."""
var_name = target.value.id
field_name = target.attr
if var_name not in local_sym_tab:
logger.error(f"Variable '{var_name}' not found in symbol table")
return
struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type]
if field_name not in struct_info.fields:
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
return
# Get field pointer and evaluate value
field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
val = eval_expr(
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
if val is None:
logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
return
# TODO: Handle string assignment to char array (not a priority)
field_type = struct_info.field_type(field_name)
if isinstance(field_type, ir.ArrayType) and val[1] == ir.PointerType(ir.IntType(8)):
logger.warning(
f"String to char array assignment not implemented for {var_name}.{field_name}"
)
return
# Store the value
builder.store(val[0], field_ptr)
logger.info(f"Assigned to struct field {var_name}.{field_name}")
def handle_variable_assignment(
func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle single named variable assignment."""
if var_name not in local_sym_tab:
logger.error(f"Variable {var_name} not declared.")
return False
var_ptr = local_sym_tab[var_name].var
var_type = local_sym_tab[var_name].ir_type
# NOTE: Special case for struct initialization
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
struct_name = rval.func.id
if struct_name in structs_sym_tab and len(rval.args) == 0:
struct_info = structs_sym_tab[struct_name]
ir_struct = struct_info.ir_type
builder.store(ir.Constant(ir_struct, None), var_ptr)
logger.info(f"Initialized struct {struct_name} for variable {var_name}")
return True
val_result = eval_expr(
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}")
return False
val, val_type = val_result
logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}")
if val_type != var_type:
if isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType):
# Allow implicit int widening
if val_type.width < var_type.width:
val = builder.sext(val, var_type)
logger.info(f"Implicitly widened int for variable {var_name}")
elif val_type.width > var_type.width:
val = builder.trunc(val, var_type)
logger.info(f"Implicitly truncated int for variable {var_name}")
elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.PointerType):
# NOTE: This is assignment to a PTR_TO_MAP_VALUE_OR_NULL
logger.info(
f"Creating temporary variable for pointer assignment to {var_name}"
)
var_ptr_tmp = local_sym_tab[f"{var_name}_tmp"].var
builder.store(val, var_ptr_tmp)
val = var_ptr_tmp
else:
logger.error(
f"Type mismatch for variable {var_name}: {val_type} vs {var_type}"
)
return False
builder.store(val, var_ptr)
logger.info(f"Assigned value to variable {var_name}")
return True
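A loose llvmlite-only sketch (hypothetical names, no BPF machinery) of the pointer branch above: an i64 value assigned to a variable that holds a map-lookup pointer is parked in the companion <name>_tmp slot, and that slot's address is what ends up in the pointer variable.

from llvmlite import ir

mod = ir.Module(name="mapval_demo")
fn = ir.Function(mod, ir.FunctionType(ir.VoidType(), []), name="demo")
builder = ir.IRBuilder(fn.append_basic_block("entry"))

prev = builder.alloca(ir.PointerType(ir.IntType(64)), name="prev")  # pointer slot
prev_tmp = builder.alloca(ir.IntType(64), name="prev_tmp")          # companion temp
computed = ir.Constant(ir.IntType(64), 7)                           # some i64 result
builder.store(computed, prev_tmp)                                   # park the value
builder.store(prev_tmp, prev)                                       # store its address
builder.ret_void()
print(mod)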

View File

@ -3,43 +3,70 @@ from llvmlite import ir
from logging import Logger
import logging
from pythonbpf.expr import get_base_type_and_depth, deref_to_depth, eval_expr
logger: Logger = logging.getLogger(__name__)
def recursive_dereferencer(var, builder):
"""dereference until primitive type comes out"""
# TODO: Not worrying about stack overflow for now
logger.info(f"Dereferencing {var}, type is {var.type}")
if isinstance(var.type, ir.PointerType):
a = builder.load(var)
return recursive_dereferencer(a, builder)
elif isinstance(var.type, ir.IntType):
return var
else:
raise TypeError(f"Unsupported type for dereferencing: {var.type}")
def get_operand_value(operand, builder, local_sym_tab):
def get_operand_value(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name):
if operand.id in local_sym_tab:
return recursive_dereferencer(local_sym_tab[operand.id].var, builder)
var = local_sym_tab[operand.id].var
var_type = var.type
base_type, depth = get_base_type_and_depth(var_type)
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
return ir.Constant(ir.IntType(64), operand.value)
cst = ir.Constant(ir.IntType(64), int(operand.value))
return cst
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, builder, local_sym_tab)
res = handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
return res
else:
res = eval_expr(
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res
logger.info(f"Evaluated expr to {val} of type {val.type}")
base_type, depth = get_base_type_and_depth(val.type)
if depth > 0:
val = deref_to_depth(func, builder, val, depth)
return val
raise TypeError(f"Unsupported operand type: {type(operand)}")
def handle_binary_op_impl(rval, builder, local_sym_tab):
def handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
op = rval.op
left = get_operand_value(rval.left, builder, local_sym_tab)
right = get_operand_value(rval.right, builder, local_sym_tab)
left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
logger.info(f"left is {left}, right is {right}, op is {op}")
# NOTE: Before doing the operation, if the operands are integers
# we always extend them to i64. The assignment to LHS will take
# care of truncation if needed.
if isinstance(left.type, ir.IntType) and left.type.width < 64:
left = builder.sext(left, ir.IntType(64))
if isinstance(right.type, ir.IntType) and right.type.width < 64:
right = builder.sext(right, ir.IntType(64))
# Map AST operation nodes to LLVM IR builder methods
op_map = {
ast.Add: builder.add,
@ -62,8 +89,19 @@ def handle_binary_op_impl(rval, builder, local_sym_tab):
raise SyntaxError("Unsupported binary operation")
def handle_binary_op(rval, builder, var_name, local_sym_tab):
result = handle_binary_op_impl(rval, builder, local_sym_tab)
def handle_binary_op(
func,
module,
rval,
builder,
var_name,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
result = handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"

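Likewise, a small llvmlite sketch (illustrative only) of the NOTE above: integer operands narrower than 64 bits are sign-extended before the operation, and any truncation happens later when the result is stored to the LHS.

from llvmlite import ir

mod = ir.Module(name="binop_demo")
fn = ir.Function(mod, ir.FunctionType(ir.IntType(64), []), name="demo")
builder = ir.IRBuilder(fn.append_basic_block("entry"))

left = ir.Constant(ir.IntType(32), 40)     # e.g. a c_int32 variable's value
right = ir.Constant(ir.IntType(64), 2)     # e.g. a c_int64 variable's value
left = builder.sext(left, ir.IntType(64))  # widen the narrower operand
result = builder.add(left, right)          # i64 add, as in the op_map dispatch
builder.ret(result)
print(mod)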
View File

@ -4,7 +4,6 @@ from .license_pass import license_processing
from .functions import func_proc
from .maps import maps_proc
from .structs import structs_proc
from .vmlinux_parser import vmlinux_proc
from .globals_pass import (
globals_list_creation,
globals_processing,
@ -45,7 +44,6 @@ def processor(source_code, filename, module):
for func_node in bpf_chunks:
logger.info(f"Found BPF function/struct: {func_node.name}")
vmlinux_proc(tree, module)
populate_global_symbol_table(tree, module)
license_processing(tree, module)
globals_processing(tree, module)

View File

@ -1,4 +1,10 @@
from .expr_pass import eval_expr, handle_expr
from .type_normalization import convert_to_bool
from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth
__all__ = ["eval_expr", "handle_expr", "convert_to_bool"]
__all__ = [
"eval_expr",
"handle_expr",
"convert_to_bool",
"get_base_type_and_depth",
"deref_to_depth",
]

View File

@ -26,7 +26,7 @@ def _handle_constant_expr(expr: ast.Constant):
if isinstance(expr.value, int) or isinstance(expr.value, bool):
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
else:
logger.error("Unsupported constant type")
logger.error(f"Unsupported constant type {ast.dump(expr)}")
return None
@ -176,21 +176,28 @@ def _handle_unary_op(
structs_sym_tab=None,
):
"""Handle ast.UnaryOp expressions."""
if not isinstance(expr.op, ast.Not):
logger.error("Only 'not' unary operator is supported")
if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub):
logger.error("Only 'not' and '-' unary operators are supported")
return None
operand = eval_expr(
func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab
from pythonbpf.binary_ops import get_operand_value
operand = get_operand_value(
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if operand is None:
logger.error("Failed to evaluate operand for unary operation")
return None
operand_val, operand_type = operand
true_const = ir.Constant(ir.IntType(1), 1)
result = builder.xor(convert_to_bool(builder, operand_val), true_const)
return result, ir.IntType(1)
if isinstance(expr.op, ast.Not):
true_const = ir.Constant(ir.IntType(1), 1)
result = builder.xor(convert_to_bool(builder, operand), true_const)
return result, ir.IntType(1)
elif isinstance(expr.op, ast.USub):
# Multiply by -1
neg_one = ir.Constant(ir.IntType(64), -1)
result = builder.mul(operand, neg_one)
return result, ir.IntType(64)
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
@ -402,7 +409,16 @@ def eval_expr(
elif isinstance(expr, ast.BinOp):
from pythonbpf.binary_ops import handle_binary_op
return handle_binary_op(expr, builder, None, local_sym_tab)
return handle_binary_op(
func,
module,
expr,
builder,
None,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr, ast.Compare):
return _handle_compare(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab

View File

@ -16,7 +16,7 @@ COMPARISON_OPS = {
}
def _get_base_type_and_depth(ir_type):
def get_base_type_and_depth(ir_type):
"""Get the base type for pointer types."""
cur_type = ir_type
depth = 0
@ -26,7 +26,7 @@ def _get_base_type_and_depth(ir_type):
return cur_type, depth
def _deref_to_depth(func, builder, val, target_depth):
def deref_to_depth(func, builder, val, target_depth):
"""Dereference a pointer to a certain depth."""
cur_val = val
@ -88,13 +88,13 @@ def _normalize_types(func, builder, lhs, rhs):
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
return None, None
else:
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
lhs_base, lhs_depth = get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = get_base_type_and_depth(rhs.type)
if lhs_base == rhs_base:
if lhs_depth < rhs_depth:
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
rhs = deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
elif rhs_depth < lhs_depth:
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
lhs = deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
return _normalize_types(func, builder, lhs, rhs)
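A tiny usage sketch of the now-public helpers, assuming the exports added to pythonbpf/expr/__init__.py earlier in this diff: get_base_type_and_depth peels pointer layers and reports how many there were, and deref_to_depth emits that many loads.

from llvmlite import ir
from pythonbpf.expr import get_base_type_and_depth

base, depth = get_base_type_and_depth(ir.PointerType(ir.PointerType(ir.IntType(64))))
print(base, depth)  # expected: i64 2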

View File

@ -1,13 +1,18 @@
from llvmlite import ir
import ast
import logging
from typing import Any
from dataclasses import dataclass
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
from pythonbpf.helper import (
HelperHandlerRegistry,
reset_scratch_pool,
)
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.binary_ops import handle_binary_op
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
from pythonbpf.assign_pass import (
handle_variable_assignment,
handle_struct_field_assignment,
)
from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool
from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name
@ -15,18 +20,6 @@ from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name
logger = logging.getLogger(__name__)
@dataclass
class LocalSymbol:
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def __iter__(self):
yield self.var
yield self.ir_type
yield self.metadata
def get_probe_string(func_node):
"""Extract the probe string from the decorator of the function node."""
# TODO: right now we have the whole string in the section decorator
@ -48,196 +41,49 @@ def handle_assign(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Handle assignment statements in the function body."""
if len(stmt.targets) != 1:
logger.info("Unsupported multiassignment")
return
num_types = ("c_int32", "c_int64", "c_uint32", "c_uint64")
# TODO: Support this later
# GH #37
if len(stmt.targets) != 1:
logger.error("Multi-target assignment is not supported for now")
return
target = stmt.targets[0]
logger.info(f"Handling assignment to {ast.dump(target)}")
if not isinstance(target, ast.Name) and not isinstance(target, ast.Attribute):
logger.info("Unsupported assignment target")
return
var_name = target.id if isinstance(target, ast.Name) else target.value.id
rval = stmt.value
if isinstance(target, ast.Name):
# NOTE: Simple variable assignment case: x = 5
var_name = target.id
result = handle_variable_assignment(
func,
module,
builder,
var_name,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if not result:
logger.error(f"Failed to handle assignment to {var_name}")
return
if isinstance(target, ast.Attribute):
# struct field assignment
field_name = target.attr
if var_name in local_sym_tab:
struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type]
if field_name in struct_info.fields:
field_ptr = struct_info.gep(
builder, local_sym_tab[var_name].var, field_name
)
val = eval_expr(
func,
module,
builder,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if isinstance(struct_info.field_type(field_name), ir.ArrayType) and val[
1
] == ir.PointerType(ir.IntType(8)):
# TODO: Figure it out, not a priority rn
# Special case for string assignment to char array
# str_len = struct_info["field_types"][field_idx].count
# assign_string_to_array(builder, field_ptr, val[0], str_len)
# print(f"Assigned to struct field {var_name}.{field_name}")
pass
if val is None:
logger.info("Failed to evaluate struct field assignment")
return
logger.info(field_ptr)
builder.store(val[0], field_ptr)
logger.info(f"Assigned to struct field {var_name}.{field_name}")
return
elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool):
if rval.value:
builder.store(
ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name].var
)
else:
builder.store(
ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name].var
)
logger.info(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, int):
# Assume c_int64 for now
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
builder.store(
ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name].var
)
logger.info(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, str):
str_val = rval.value.encode("utf-8") + b"\x00"
str_const = ir.Constant(
ir.ArrayType(ir.IntType(8), len(str_val)), bytearray(str_val)
)
global_str = ir.GlobalVariable(
module, str_const.type, name=f"{var_name}_str"
)
global_str.linkage = "internal"
global_str.global_constant = True
global_str.initializer = str_const
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
builder.store(str_ptr, local_sym_tab[var_name].var)
logger.info(f"Assigned string constant '{rval.value}' to {var_name}")
else:
logger.info("Unsupported constant type")
elif isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name):
call_type = rval.func.id
logger.info(f"Assignment call type: {call_type}")
if (
call_type in num_types
and len(rval.args) == 1
and isinstance(rval.args[0], ast.Constant)
and isinstance(rval.args[0].value, int)
):
ir_type = ctypes_to_ir(call_type)
# var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8
builder.store(
ir.Constant(ir_type, rval.args[0].value),
local_sym_tab[var_name].var,
)
logger.info(
f"Assigned {call_type} constant {rval.args[0].value} to {var_name}"
)
elif HelperHandlerRegistry.has_handler(call_type):
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
builder.store(val[0], local_sym_tab[var_name].var)
logger.info(f"Assigned constant {rval.func.id} to {var_name}")
elif call_type == "deref" and len(rval.args) == 1:
logger.info(f"Handling deref assignment {ast.dump(rval)}")
val = eval_expr(
func,
module,
builder,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if val is None:
logger.info("Failed to evaluate deref argument")
return
logger.info(f"Dereferenced value: {val}, storing in {var_name}")
builder.store(val[0], local_sym_tab[var_name].var)
logger.info(f"Dereferenced and assigned to {var_name}")
elif call_type in structs_sym_tab and len(rval.args) == 0:
struct_info = structs_sym_tab[call_type]
ir_type = struct_info.ir_type
# var = builder.alloca(ir_type, name=var_name)
# Null init
builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name].var)
logger.info(f"Assigned struct {call_type} to {var_name}")
else:
logger.info(f"Unsupported assignment call type: {call_type}")
elif isinstance(rval.func, ast.Attribute):
logger.info(f"Assignment call attribute: {ast.dump(rval.func)}")
if isinstance(rval.func.value, ast.Name):
if rval.func.value.id in map_sym_tab:
map_name = rval.func.value.id
method_name = rval.func.attr
if HelperHandlerRegistry.has_handler(method_name):
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
builder.store(val[0], local_sym_tab[var_name].var)
else:
# TODO: probably a struct access
logger.info(f"TODO STRUCT ACCESS {ast.dump(rval)}")
elif isinstance(rval.func.value, ast.Call) and isinstance(
rval.func.value.func, ast.Name
):
map_name = rval.func.value.func.id
method_name = rval.func.attr
if map_name in map_sym_tab:
if HelperHandlerRegistry.has_handler(method_name):
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
builder.store(val[0], local_sym_tab[var_name].var)
else:
logger.info("Unsupported assignment call structure")
else:
logger.info("Unsupported assignment call function type")
elif isinstance(rval, ast.BinOp):
handle_binary_op(rval, builder, var_name, local_sym_tab)
else:
logger.info("Unsupported assignment value type")
# NOTE: Struct field assignment case: pkt.field = value
handle_struct_field_assignment(
func,
module,
builder,
target,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
return
# Unsupported target type
logger.error(f"Unsupported assignment target: {ast.dump(target)}")
def handle_cond(
@ -330,6 +176,7 @@ def process_stmt(
ret_type=ir.IntType(64),
):
logger.info(f"Processing statement: {ast.dump(stmt)}")
reset_scratch_pool()
if isinstance(stmt, ast.Expr):
handle_expr(
func,
@ -360,119 +207,107 @@ def process_stmt(
return did_return
def handle_if_allocation(
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Recursively handle allocations in if/else branches."""
if stmt.body:
allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
def count_temps_in_call(call_node, local_sym_tab):
"""Count the number of temporary variables needed for a function call."""
count = 0
is_helper = False
# NOTE: We exclude print calls for now
if isinstance(call_node.func, ast.Name):
if (
HelperHandlerRegistry.has_handler(call_node.func.id)
and call_node.func.id != "print"
):
is_helper = True
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True
if not is_helper:
return 0
for arg in call_node.args:
# NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab
if not isinstance(arg, ast.Name) and not (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
):
count += 1
return count
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
for stmt in body:
has_metadata = False
if isinstance(stmt, ast.If):
if stmt.body:
local_sym_tab = allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
local_sym_tab = allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
if len(stmt.targets) != 1:
logger.info("Unsupported multiassignment")
continue
target = stmt.targets[0]
if not isinstance(target, ast.Name):
logger.info("Unsupported assignment target")
continue
var_name = target.id
rval = stmt.value
if var_name in local_sym_tab:
logger.info(f"Variable {var_name} already allocated")
continue
if isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name):
call_type = rval.func.id
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(
f"Pre-allocated variable {var_name} of type {call_type}"
)
elif HelperHandlerRegistry.has_handler(call_type):
# Assume return type is int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} for helper")
elif call_type == "deref" and len(rval.args) == 1:
# Assume return type is int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} for deref")
elif call_type in structs_sym_tab:
struct_info = structs_sym_tab[call_type]
ir_type = struct_info.ir_type
var = builder.alloca(ir_type, name=var_name)
has_metadata = True
logger.info(
f"Pre-allocated variable {var_name} for struct {call_type}"
)
elif isinstance(rval.func, ast.Attribute):
ir_type = ir.PointerType(ir.IntType(64))
var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} for map")
else:
logger.info("Unsupported assignment call function type")
continue
elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool):
ir_type = ir.IntType(1)
var = builder.alloca(ir_type, name=var_name)
var.align = 1
logger.info(f"Pre-allocated variable {var_name} of type c_bool")
elif isinstance(rval.value, int):
# Assume c_int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
elif isinstance(rval.value, str):
ir_type = ir.PointerType(ir.IntType(8))
var = builder.alloca(ir_type, name=var_name)
var.align = 8
logger.info(f"Pre-allocated variable {var_name} of type string")
else:
logger.info("Unsupported constant type")
continue
elif isinstance(rval, ast.BinOp):
# Assume c_int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
else:
logger.info("Unsupported assignment value type")
continue
max_temps_needed = 0
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
temps_needed = 0
if isinstance(stmt, ast.If):
for s in stmt.body:
update_max_temps_for_stmt(s)
for s in stmt.orelse:
update_max_temps_for_stmt(s)
return
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
temps_needed += count_temps_in_call(node, local_sym_tab)
max_temps_needed = max(max_temps_needed, temps_needed)
for stmt in body:
update_max_temps_for_stmt(stmt)
# Handle allocations
if isinstance(stmt, ast.If):
handle_if_allocation(
module,
builder,
stmt,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
if has_metadata:
local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type)
else:
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
return local_sym_tab
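As a rough worked example of the counting above (standard-library ast only, with the helper-registry check elided): in hist().update(process_id, prev + 1) the BinOp argument is not a plain name, so one scratch slot is needed for that statement.

import ast

stmt = ast.parse("hist().update(process_id, prev + 1)").body[0]
temps_needed = 0
for node in ast.walk(stmt):
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
        # simplified stand-in for count_temps_in_call
        temps_needed += sum(1 for a in node.args if not isinstance(a, ast.Name))
print(temps_needed)  # 1 -> allocate_temp_pool would create __helper_temp_0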

View File

@ -1,9 +1,10 @@
from .helper_utils import HelperHandlerRegistry
from .helper_utils import HelperHandlerRegistry, reset_scratch_pool
from .bpf_helper_handler import handle_helper_call
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
__all__ = [
"HelperHandlerRegistry",
"reset_scratch_pool",
"handle_helper_call",
"ktime",
"pid",

View File

@ -34,6 +34,7 @@ def bpf_ktime_get_ns_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_ktime_get_ns helper function call.
@ -56,6 +57,7 @@ def bpf_map_lookup_elem_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_lookup_elem helper function call.
@ -64,11 +66,17 @@ def bpf_map_lookup_elem_emitter(
raise ValueError(
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
# TODO: I have changed the return type to i64*, as we are
# allocating space for that type in allocate_mem. This is
# temporary, and we will honour other widths later. But this
# allows us to have cool binary ops on the returned value.
fn_type = ir.FunctionType(
ir.PointerType(), # Return type: void*
ir.PointerType(ir.IntType(64)), # Return type: i64*
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
var_arg=False,
)
@ -91,6 +99,7 @@ def bpf_printk_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""Emit LLVM IR for bpf_printk helper function call."""
if not hasattr(func, "_fmt_counter"):
@ -138,6 +147,7 @@ def bpf_map_update_elem_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_update_elem helper function call.
@ -152,8 +162,12 @@ def bpf_map_update_elem_emitter(
value_arg = call.args[1]
flags_arg = call.args[2] if len(call.args) > 2 else None
key_ptr = get_or_create_ptr_from_arg(key_arg, builder, local_sym_tab)
value_ptr = get_or_create_ptr_from_arg(value_arg, builder, local_sym_tab)
key_ptr = get_or_create_ptr_from_arg(
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
value_ptr = get_or_create_ptr_from_arg(
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -188,6 +202,7 @@ def bpf_map_delete_elem_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_delete_elem helper function call.
@ -197,7 +212,9 @@ def bpf_map_delete_elem_emitter(
raise ValueError(
f"Map delete expects exactly one argument (key), got {len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
# Define function type for bpf_map_delete_elem
@ -225,6 +242,7 @@ def bpf_get_current_pid_tgid_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
@ -251,6 +269,7 @@ def bpf_perf_event_output_handler(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
if len(call.args) != 1:
raise ValueError(
@ -315,6 +334,7 @@ def handle_helper_call(
func,
local_sym_tab,
struct_sym_tab,
map_sym_tab,
)
# Handle direct function calls (e.g., print(), ktime())

View File

@ -3,7 +3,8 @@ import logging
from collections.abc import Callable
from llvmlite import ir
from pythonbpf.expr import eval_expr
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
from pythonbpf.binary_ops import get_operand_value
logger = logging.getLogger(__name__)
@ -34,6 +35,41 @@ class HelperHandlerRegistry:
return helper_name in cls._handlers
class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab"""
def __init__(self):
self._counter = 0
@property
def counter(self):
return self._counter
def reset(self):
self._counter = 0
logger.debug("Scratch pool counter reset to 0")
def get_next_temp(self, local_sym_tab):
temp_name = f"__helper_temp_{self._counter}"
self._counter += 1
if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}"
)
return local_sym_tab[temp_name].var, temp_name
_temp_pool_manager = ScratchPoolManager() # Singleton instance
def reset_scratch_pool():
"""Reset the scratch pool counter"""
_temp_pool_manager.reset()
def get_var_ptr_from_name(var_name, local_sym_tab):
"""Get a pointer to a variable from the symbol table."""
if local_sym_tab and var_name in local_sym_tab:
@ -41,27 +77,41 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
def create_int_constant_ptr(value, builder, int_width=64):
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
"""Create a pointer to an integer constant."""
# Default to 64-bit integer
int_type = ir.IntType(int_width)
ptr = builder.alloca(int_type)
ptr.align = int_type.width // 8
builder.store(ir.Constant(int_type, value), ptr)
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(ir.IntType(int_width), value)
builder.store(const_val, ptr)
return ptr
def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
def get_or_create_ptr_from_arg(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None
):
"""Extract or create pointer from the call arguments."""
if isinstance(arg, ast.Name):
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
ptr = create_int_constant_ptr(arg.value, builder)
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
else:
raise NotImplementedError(
"Only simple variable names are supported as args in map helpers."
# Evaluate the expression and store the result in a temp variable
val = get_operand_value(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
if val is None:
raise ValueError("Failed to evaluate expression for helper arg.")
# NOTE: We assume the result is an int64 for now
# if isinstance(arg, ast.Attribute):
# return val
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.info(f"Using temp variable '{temp_name}' for expression result")
builder.store(val, ptr)
return ptr
@ -224,10 +274,27 @@ def _populate_fval(ftype, node, fmt_parts, exprs):
raise NotImplementedError(
f"Unsupported integer width in f-string: {ftype.width}"
)
elif ftype == ir.PointerType(ir.IntType(8)):
# NOTE: We assume i8* is a string
fmt_parts.append("%s")
exprs.append(node)
elif isinstance(ftype, ir.PointerType):
target, depth = get_base_type_and_depth(ftype)
if isinstance(target, ir.IntType):
if target.width == 64:
fmt_parts.append("%lld")
exprs.append(node)
elif target.width == 32:
fmt_parts.append("%d")
exprs.append(node)
elif target.width == 8 and depth == 1:
# NOTE: Assume i8* is a string
fmt_parts.append("%s")
exprs.append(node)
else:
raise NotImplementedError(
f"Unsupported pointer target type in f-string: {target}"
)
else:
raise NotImplementedError(
f"Unsupported pointer target type in f-string: {target}"
)
else:
raise NotImplementedError(f"Unsupported field type in f-string: {ftype}")
@ -264,7 +331,20 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
if val:
if isinstance(val.type, ir.PointerType):
val = builder.ptrtoint(val, ir.IntType(64))
target, depth = get_base_type_and_depth(val.type)
if isinstance(target, ir.IntType):
if target.width >= 32:
val = deref_to_depth(func, builder, val, depth)
val = builder.sext(val, ir.IntType(64))
elif target.width == 8 and depth == 1:
# NOTE: i8* is string, no need to deref
pass
else:
logger.warning(
"Only int and ptr supported in bpf_printk args. Others default to 0."
)
val = ir.Constant(ir.IntType(64), 0)
elif isinstance(val.type, ir.IntType):
if val.type.width < 64:
val = builder.sext(val, ir.IntType(64))

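Putting the pieces together, a condensed sketch of the scratch-pool lifecycle (stubbed symbol-table entries instead of real LocalSymbol/alloca objects): the counter is reset per statement, and each constant or expression argument to a helper then borrows the next pre-allocated __helper_temp_<i> slot.

from pythonbpf.helper.helper_utils import ScratchPoolManager

class FakeSym:
    """Stand-in for a LocalSymbol; only the .var attribute is used here."""
    def __init__(self, name):
        self.var = f"<alloca {name}>"

# What allocate_temp_pool would have pre-created for max_temps == 2
local_sym_tab = {f"__helper_temp_{i}": FakeSym(f"__helper_temp_{i}") for i in range(2)}

pool = ScratchPoolManager()
pool.reset()                                     # start of a new statement
ptr0, name0 = pool.get_next_temp(local_sym_tab)  # first non-Name helper arg
ptr1, name1 = pool.get_next_temp(local_sym_tab)  # second non-Name helper arg
print(name0, name1)                              # __helper_temp_0 __helper_temp_1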
View File

@ -15,8 +15,5 @@ def deref(ptr):
return result if result is not None else 0
XDP_ABORTED = ctypes.c_int64(0)
XDP_DROP = ctypes.c_int64(1)
XDP_PASS = ctypes.c_int64(2)
XDP_TX = ctypes.c_int64(3)
XDP_REDIRECT = ctypes.c_int64(4)

View File

@ -1,3 +0,0 @@
from .import_detector import vmlinux_proc
__all__ = ["vmlinux_proc"]

View File

@ -1,177 +0,0 @@
import logging
from functools import lru_cache
import importlib
from .dependency_handler import DependencyHandler
from .dependency_node import DependencyNode
import ctypes
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_module_symbols(module_name: str):
imported_module = importlib.import_module(module_name)
return [name for name in dir(imported_module)], imported_module
def process_vmlinux_class(node, llvm_module, handler: DependencyHandler):
symbols_in_module, imported_module = get_module_symbols("vmlinux")
if node.name in symbols_in_module:
vmlinux_type = getattr(imported_module, node.name)
process_vmlinux_post_ast(vmlinux_type, llvm_module, handler)
else:
raise ImportError(f"{node.name} not in vmlinux")
# Recursive function that gets all the dependent classes and adds them to handler
def process_vmlinux_post_ast(node, llvm_module, handler: DependencyHandler, processing_stack=None):
"""
Recursively process vmlinux classes and their dependencies.
Args:
node: The class/type to process
llvm_module: The LLVM module context
handler: DependencyHandler to track all nodes
processing_stack: Set of currently processing nodes to detect cycles
"""
# Initialize processing stack on first call
if processing_stack is None:
processing_stack = set()
symbols_in_module, imported_module = get_module_symbols("vmlinux")
# Handle both node objects and type objects
if hasattr(node, "name"):
current_symbol_name = node.name
elif hasattr(node, "__name__"):
current_symbol_name = node.__name__
else:
current_symbol_name = str(node)
if current_symbol_name not in symbols_in_module:
raise ImportError(f"{current_symbol_name} not present in module vmlinux")
# Check if we're already processing this node (circular dependency)
if current_symbol_name in processing_stack:
logger.debug(f"Circular dependency detected for {current_symbol_name}, skipping")
return True
# Check if already processed
if handler.has_node(current_symbol_name):
existing_node = handler.get_node(current_symbol_name)
# If the node exists and is ready, we're done
if existing_node and existing_node.is_ready:
logger.info(f"Node {current_symbol_name} already processed and ready")
return True
logger.info(f"Resolving vmlinux class {current_symbol_name}")
logger.debug(
f"Current handler state: {handler.is_ready} readiness and {handler.get_all_nodes()} all nodes"
)
# Add to processing stack to detect cycles
processing_stack.add(current_symbol_name)
try:
field_table = {} # should contain the field and it's type.
# Get the class object from the module
class_obj = getattr(imported_module, current_symbol_name)
# Inspect the class fields
if hasattr(class_obj, "_fields_"):
for field_name, field_type in class_obj._fields_:
field_table[field_name] = field_type
elif hasattr(class_obj, "__annotations__"):
for field_name, field_type in class_obj.__annotations__.items():
field_table[field_name] = field_type
else:
raise TypeError("Could not get required class and definition")
logger.debug(f"Extracted fields for {current_symbol_name}: {field_table}")
# Create or get the node
if handler.has_node(current_symbol_name):
new_dep_node = handler.get_node(current_symbol_name)
else:
new_dep_node = DependencyNode(name=current_symbol_name)
handler.add_node(new_dep_node)
# Process each field
for elem_name, elem_type in field_table.items():
module_name = getattr(elem_type, "__module__", None)
if module_name == ctypes.__name__:
# Simple ctypes - mark as ready immediately
new_dep_node.add_field(elem_name, elem_type, ready=True)
elif module_name == "vmlinux":
# Complex vmlinux type - needs recursive processing
new_dep_node.add_field(elem_name, elem_type, ready=False)
logger.debug(f"Processing vmlinux field: {elem_name}, type: {elem_type}")
identify_ctypes_type(elem_name, elem_type, new_dep_node)
# Determine the actual symbol to process
symbol_name = (
elem_type.__name__
if hasattr(elem_type, "__name__")
else str(elem_type)
)
vmlinux_symbol = None
# Handle pointers/arrays to other types
if hasattr(elem_type, "_type_"):
containing_module_name = getattr(
(elem_type._type_), "__module__", None
)
if containing_module_name == ctypes.__name__:
# Pointer/Array to ctypes - mark as ready
new_dep_node.set_field_ready(elem_name, True)
continue
elif containing_module_name == "vmlinux":
# Pointer/Array to vmlinux type
symbol_name = (
(elem_type._type_).__name__
if hasattr((elem_type._type_), "__name__")
else str(elem_type._type_)
)
# Self-referential check
if symbol_name == current_symbol_name:
logger.debug(f"Self-referential field {elem_name} in {current_symbol_name}")
# For pointers to self, we can mark as ready since the type is being defined
new_dep_node.set_field_ready(elem_name, True)
continue
vmlinux_symbol = getattr(imported_module, symbol_name)
else:
# Direct vmlinux type (not pointer/array)
vmlinux_symbol = getattr(imported_module, symbol_name)
# Recursively process the dependency
if vmlinux_symbol is not None:
if process_vmlinux_post_ast(vmlinux_symbol, llvm_module, handler, processing_stack):
new_dep_node.set_field_ready(elem_name, True)
else:
raise ValueError(
f"{elem_name} with type {elem_type} not supported in recursive resolver"
)
logger.info(f"Successfully processed node: {current_symbol_name}")
return True
finally:
# Remove from processing stack when done
processing_stack.discard(current_symbol_name)
def identify_ctypes_type(elem_name, elem_type, new_dep_node: DependencyNode):
if isinstance(elem_type, type):
if issubclass(elem_type, ctypes.Array):
new_dep_node.set_field_type(elem_name, ctypes.Array)
new_dep_node.set_field_containing_type(elem_name, elem_type._type_)
new_dep_node.set_field_type_size(elem_name, elem_type._length_)
elif issubclass(elem_type, ctypes._Pointer):
new_dep_node.set_field_type(elem_name, ctypes._Pointer)
new_dep_node.set_field_containing_type(elem_name, elem_type._type_)
else:
raise TypeError("Instance sent instead of Class")

View File

@ -1,149 +0,0 @@
from typing import Optional, Dict, List, Iterator
from .dependency_node import DependencyNode
class DependencyHandler:
"""
Manages a collection of DependencyNode objects with no duplicates.
Ensures that no two nodes with the same name can be added and provides
methods to check readiness and retrieve specific nodes.
Example usage:
# Create a handler
handler = DependencyHandler()
# Create some dependency nodes
node1 = DependencyNode(name="node1")
node1.add_field("field1", str)
node1.set_field_value("field1", "value1")
node2 = DependencyNode(name="node2")
node2.add_field("field1", int)
# Add nodes to the handler
handler.add_node(node1)
handler.add_node(node2)
# Check if a specific node exists
print(handler.has_node("node1")) # True
# Get a reference to a node and modify it
node = handler.get_node("node2")
node.set_field_value("field1", 42)
# Check if all nodes are ready
print(handler.is_ready) # False (node2 is ready, but node1 isn't)
"""
def __init__(self):
# Using a dictionary with node names as keys ensures name uniqueness
# and provides efficient lookups
self._nodes: Dict[str, DependencyNode] = {}
def add_node(self, node: DependencyNode) -> bool:
"""
Add a dependency node to the handler.
Args:
node: The DependencyNode to add
Returns:
bool: True if the node was added, False if a node with the same name already exists
Raises:
TypeError: If the provided object is not a DependencyNode
"""
if not isinstance(node, DependencyNode):
raise TypeError(f"Expected DependencyNode, got {type(node).__name__}")
# Check if a node with this name already exists
if node.name in self._nodes:
return False
self._nodes[node.name] = node
return True
@property
def is_ready(self) -> bool:
"""
Check if all nodes are ready.
Returns:
bool: True if all nodes are ready (or if there are no nodes), False otherwise
"""
if not self._nodes:
return True
return all(node.is_ready for node in self._nodes.values())
def has_node(self, name: str) -> bool:
"""
Check if a node with the given name exists.
Args:
name: The name to check
Returns:
bool: True if a node with the given name exists, False otherwise
"""
return name in self._nodes
def get_node(self, name: str) -> Optional[DependencyNode]:
"""
Get a node by name for manipulation.
Args:
name: The name of the node to retrieve
Returns:
Optional[DependencyNode]: The node with the given name, or None if not found
"""
return self._nodes.get(name)
def remove_node(self, node_or_name) -> bool:
"""
Remove a node by name or reference.
Args:
node_or_name: The node to remove or its name
Returns:
bool: True if the node was removed, False if not found
"""
if isinstance(node_or_name, DependencyNode):
name = node_or_name.name
else:
name = node_or_name
if name in self._nodes:
del self._nodes[name]
return True
return False
def get_all_nodes(self) -> List[DependencyNode]:
"""
Get all nodes stored in the handler.
Returns:
List[DependencyNode]: List of all nodes
"""
return list(self._nodes.values())
def __iter__(self) -> Iterator[DependencyNode]:
"""
Iterate over all nodes.
Returns:
Iterator[DependencyNode]: Iterator over all nodes
"""
return iter(self._nodes.values())
def __len__(self) -> int:
"""
Get the number of nodes in the handler.
Returns:
int: The number of nodes
"""
return len(self._nodes)

View File

@ -1,190 +0,0 @@
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
#TODO: FIX THE FUCKING TYPE NAME CONVENTION.
@dataclass
class Field:
"""Represents a field in a dependency node with its type and readiness state."""
name: str
type: type
containing_type: Optional[Any]
type_size: Optional[int]
value: Any = None
ready: bool = False
def set_ready(self, is_ready: bool = True) -> None:
"""Set the readiness state of this field."""
self.ready = is_ready
def set_value(self, value: Any, mark_ready: bool = True) -> None:
"""Set the value of this field and optionally mark it as ready."""
self.value = value
if mark_ready:
self.ready = True
def set_type(self, given_type, mark_ready: bool = True) -> None:
"""Set value of the type field and mark as ready"""
self.type = given_type
if mark_ready:
self.ready = True
def set_containing_type(
self, containing_type: Optional[Any], mark_ready: bool = True
) -> None:
"""Set the containing_type of this field and optionally mark it as ready."""
self.containing_type = containing_type
if mark_ready:
self.ready = True
def set_type_size(self, type_size: Any, mark_ready: bool = True) -> None:
"""Set the type_size of this field and optionally mark it as ready."""
self.type_size = type_size
if mark_ready:
self.ready = True
@dataclass
class DependencyNode:
"""
A node with typed fields and readiness tracking.
Example usage:
# Create a dependency node for a struct
somestruct = DependencyNode(name="struct_1")
# Add fields with their types
somestruct.add_field("field_1", str)
somestruct.add_field("field_2", int)
somestruct.add_field("field_3", str)
# Check if the node is ready (should be False initially)
print(f"Is node ready? {somestruct.is_ready}") # False
# Set some field values
somestruct.set_field_value("field_1", "someproperty")
somestruct.set_field_value("field_2", 30)
# Check if the node is ready (still False because field_3 is not ready)
print(f"Is node ready? {somestruct.is_ready}") # False
# Set the last field and make the node ready
somestruct.set_field_value("field_3", "anotherproperty")
# Now the node should be ready
print(f"Is node ready? {somestruct.is_ready}") # True
# You can also mark a field as not ready
somestruct.set_field_ready("field_3", False)
# Now the node is not ready again
print(f"Is node ready? {somestruct.is_ready}") # False
# Get all field values
print(somestruct.get_field_values()) # {'field_1': 'someproperty', 'field_2': 30, 'field_3': 'anotherproperty'}
# Get only ready fields
ready_fields = somestruct.get_ready_fields()
print(f"Ready fields: {[field.name for field in ready_fields.values()]}") # ['field_1', 'field_2']
"""
name: str
fields: Dict[str, Field] = field(default_factory=dict)
_ready_cache: Optional[bool] = field(default=None, repr=False)
def add_field(
self,
name: str,
field_type: type,
initial_value: Any = None,
containing_type: Optional[Any] = None,
type_size: Optional[int] = None,
ready: bool = False,
) -> None:
"""Add a field to the node with an optional initial value and readiness state."""
self.fields[name] = Field(
name=name,
type=field_type,
value=initial_value,
ready=ready,
containing_type=containing_type,
type_size=type_size,
)
# Invalidate readiness cache
self._ready_cache = None
def get_field(self, name: str) -> Field:
"""Get a field by name."""
return self.fields[name]
def set_field_value(self, name: str, value: Any, mark_ready: bool = True) -> None:
"""Set a field's value and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_value(value, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_type(self, name: str, type: Any, mark_ready: bool = True) -> None:
"""Set a field's type and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_type(type, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_containing_type(
self, name: str, containing_type: Any, mark_ready: bool = True
) -> None:
"""Set a field's containing_type and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_containing_type(containing_type, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_type_size(
self, name: str, type_size: Any, mark_ready: bool = True
) -> None:
"""Set a field's type_size and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_type_size(type_size, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_ready(self, name: str, is_ready: bool = True) -> None:
"""Mark a field as ready or not ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_ready(is_ready)
# Invalidate readiness cache
self._ready_cache = None
@property
def is_ready(self) -> bool:
"""Check if the node is ready (all fields are ready)."""
# Use cached value if available
if self._ready_cache is not None:
return self._ready_cache
# Calculate readiness only when needed
if not self.fields:
self._ready_cache = False
return False
self._ready_cache = all(elem.ready for elem in self.fields.values())
return self._ready_cache
def get_field_values(self) -> Dict[str, Any]:
"""Get a dictionary of field names to their values."""
return {name: elem.value for name, elem in self.fields.items()}
def get_ready_fields(self) -> Dict[str, Field]:
"""Get all fields that are marked as ready."""
return {name: elem for name, elem in self.fields.items() if elem.ready}
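The docstring example above covers value readiness; the type, containing_type, and type_size setters follow the same pattern. A small illustrative sketch (not from the repository; the byte unit for type_size is an assumption):

node = DependencyNode(name="struct_2")
node.add_field("field_a", field_type=None)        # type not known at creation time

node.set_field_type("field_a", int)               # marks the field ready
node.set_field_type_size("field_a", 8)            # assumed to be a size in bytes
node.set_field_containing_type("field_a", None)

print(node.is_ready)                              # True: the only field is ready
print(node.get_field("field_a").type_size)        # 8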

View File

@ -1,135 +0,0 @@
import ast
import logging
from typing import List, Tuple, Dict
import importlib
import inspect
from .dependency_handler import DependencyHandler
from .ir_generation import IRGenerator
from .class_handler import process_vmlinux_class
logger = logging.getLogger(__name__)
def detect_import_statement(tree: ast.AST) -> List[Tuple[str, ast.ImportFrom]]:
"""
Parse AST and detect import statements from vmlinux.
Returns a list of tuples (module_name, imported_item) for vmlinux imports.
Raises SyntaxError for invalid import patterns.
Args:
tree: The AST to parse
Returns:
List of tuples containing (module_name, imported_item) for each vmlinux import
Raises:
SyntaxError: If multiple imports from vmlinux are attempted or import * is used
"""
vmlinux_imports = []
for node in ast.walk(tree):
# Handle "from vmlinux import ..." statements
if isinstance(node, ast.ImportFrom):
if node.module == "vmlinux":
# Check for wildcard import: from vmlinux import *
if any(alias.name == "*" for alias in node.names):
raise SyntaxError(
"Wildcard imports from vmlinux are not supported. "
"Please import specific types explicitly."
)
# Check for multiple imports: from vmlinux import A, B, C
if len(node.names) > 1:
imported_names = [alias.name for alias in node.names]
raise SyntaxError(
f"Multiple imports from vmlinux are not supported. "
f"Found: {', '.join(imported_names)}. "
f"Please use separate import statements for each type."
)
# Check if no specific import is specified (should not happen with valid Python)
if len(node.names) == 0:
raise SyntaxError(
"Import from vmlinux must specify at least one type."
)
# Valid single import
for alias in node.names:
import_name = alias.name
# Use alias if provided, otherwise use the original name (commented)
# as_name = alias.asname if alias.asname else alias.name
vmlinux_imports.append(("vmlinux", node))
logger.info(f"Found vmlinux import: {import_name}")
# Handle "import vmlinux" statements (not typical but should be rejected)
elif isinstance(node, ast.Import):
for alias in node.names:
if alias.name == "vmlinux" or alias.name.startswith("vmlinux."):
raise SyntaxError(
"Direct import of vmlinux module is not supported. "
"Use 'from vmlinux import <type>' instead."
)
logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}")
return vmlinux_imports
def vmlinux_proc(tree: ast.AST, module):
import_statements = detect_import_statement(tree)
# initialise dependency handler
handler = DependencyHandler()
# initialise assignment dictionary of name to type
assignments: Dict[str, type] = {}
if not import_statements:
logger.info("No vmlinux imports found")
return
# Import vmlinux module directly
try:
vmlinux_mod = importlib.import_module("vmlinux")
except ImportError:
logger.warning("Could not import vmlinux module")
return
source_file = inspect.getsourcefile(vmlinux_mod)
if source_file is None:
logger.warning("Cannot find source for vmlinux module")
return
with open(source_file, "r") as f:
mod_ast = ast.parse(f.read(), filename=source_file)
for import_mod, import_node in import_statements:
for alias in import_node.names:
imported_name = alias.name
found = False
for mod_node in mod_ast.body:
if (
isinstance(mod_node, ast.ClassDef)
and mod_node.name == imported_name
):
process_vmlinux_class(mod_node, module, handler)
found = True
break
if isinstance(mod_node, ast.Assign):
for target in mod_node.targets:
if isinstance(target, ast.Name) and target.id == imported_name:
process_vmlinux_assign(mod_node, module, assignments)
found = True
break
if found:
break
if not found:
logger.info(
f"{imported_name} not found as ClassDef or Assign in vmlinux"
)
IRGenerator(module, handler)
def process_vmlinux_assign(node, module, assignments: Dict[str, type]):
raise NotImplementedError("Assignment handling has not been implemented yet")
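A hedged sketch of how detect_import_statement above behaves on accepted and rejected imports (the snippets below are made up for illustration):

import ast

# Accepted: a single, explicit import from vmlinux.
ok_tree = ast.parse("from vmlinux import struct_xdp_md")
print(detect_import_statement(ok_tree))   # [('vmlinux', <ast.ImportFrom ...>)]

# Rejected: wildcard imports, multi-name imports, and direct module imports
# all raise SyntaxError.
for bad in (
    "from vmlinux import *",
    "from vmlinux import struct_xdp_md, struct_xdp_buff",
    "import vmlinux",
):
    try:
        detect_import_statement(ast.parse(bad))
    except SyntaxError as err:
        print(f"rejected: {bad!r} -> {err}")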

View File

@ -1,8 +0,0 @@
# Here, we iterate through the dependencies and generate IR once all dependencies are fully resolved
from .dependency_handler import DependencyHandler
class IRGenerator:
def __init__(self, module, handler):
self.module = module
self.handler: DependencyHandler = handler
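The IRGenerator above is still a stub; a purely hypothetical sketch of the iteration its comment describes might look like the following (this is an assumption about intent, not code from the repository):

class IRGeneratorSketch:
    # Hypothetical: walk the handler and emit IR only for fully resolved nodes.
    def __init__(self, module, handler: DependencyHandler):
        self.module = module
        self.handler = handler

    def generate(self):
        for node in self.handler:              # DependencyHandler is iterable
            if not node.is_ready:
                continue                       # skip nodes with unresolved fields
            for name, value in node.get_field_values().items():
                # IR emission for (name, value) into self.module would go here.
                pass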

View File

@ -1,10 +1,11 @@
#include "vmlinux.h"
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#define u64 unsigned long long
#define u32 unsigned int
SEC("xdp")
int hello(struct xdp_md *ctx) {
bpf_printk("Hello, World! %ud \n", ctx->data);
bpf_printk("Hello, World!\n");
return XDP_PASS;
}

View File

@ -0,0 +1,39 @@
from pythonbpf import bpf, map, section, bpfglobal, compile
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.maps import HashMap
# NOTE: This example tries to reinterpret the variable `x` as a different type.
# We do not allow this for now, as stack allocations are typed and have to be
# done in the first basic block. Allowing re-interpretation would require
# re-allocation of stack space (possibly in a new basic block), which is not
# supported in eBPF yet.
# In the future, we may allow bitcasts in cases where the two types have the
# same width. But for now, we do not allow any re-interpretation of variables.
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=3)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
last.update(0, 1)
x = last.lookup(0)
x = 20
if x == 2:
print("Hello, World!")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -3,19 +3,16 @@ import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64, c_int32
@bpf
@bpfglobal
def somevalue() -> c_int32:
return c_int32(42)
@bpf
@bpfglobal
def somevalue2() -> c_int64:
return c_int64(69)
@bpf
@bpfglobal
def somevalue1() -> c_int32:
@ -24,14 +21,12 @@ def somevalue1() -> c_int32:
# --- Passing examples ---
# Simple constant return
@bpf
@bpfglobal
def g1() -> c_int64:
return c_int64(42)
# Constructor with one constant argument
@bpf
@bpfglobal
@ -67,17 +62,15 @@ def g2() -> c_int64:
# def g6() -> c_int64:
# return c_int64(CONST)
# Constructor with multiple args
# TODO: this is not working. should it work ?
#TODO: this is not working. should it work ?
@bpf
@bpfglobal
def g7() -> c_int64:
return c_int64(1)
# Dataclass call
# TODO: fails with dataclass
#TODO: fails with dataclass
# @dataclass
# class Point:
# x: c_int64
@ -98,7 +91,6 @@ def sometag(ctx: c_void_p) -> c_int64:
print(f"{somevalue}")
return c_int64(1)
@bpf
@bpfglobal
def LICENSE() -> str:

View File

@ -11,7 +11,6 @@ from ctypes import c_void_p, c_int64
# We cannot allocate space for the intermediate type now.
# We probably need to track the ref/deref chain for each variable.
@bpf
@map
def count() -> HashMap:

View File

@ -3,7 +3,6 @@ import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64
# This should not pass as somevalue is not declared at all.
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
@ -12,7 +11,6 @@ def sometag(ctx: c_void_p) -> c_int64:
print(f"{somevalue}") # noqa: F821
return c_int64(1)
@bpf
@bpfglobal
def LICENSE() -> str:

View File

@ -1,47 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, compile, compile_to_ir
from pythonbpf.maps import HashMap
from pythonbpf.helper import XDP_PASS
from vmlinux import struct_ring_buffer_per_cpu # noqa: F401
from vmlinux import struct_xdp_buff # noqa: F401
from vmlinux import struct_xdp_md
from ctypes import c_int64
# Instructions on how to run this program
# 1. Install PythonBPF: pip install pythonbpf
# 2. Run the program: python examples/xdp_pass.py
# 3. Run the program with sudo: sudo tools/check.sh run examples/xdp_pass.o
# 4. Attach object file to any network device with something like ./check.sh xdp examples/xdp_pass.o tailscale0
# 5. Send traffic through the device and observe the effects
@bpf
@map
def count() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=1)
@bpf
@section("xdp")
def hello_world(ctx: struct_xdp_md) -> c_int64:
key = 0
one = 1
prev = count().lookup(key)
if prev:
prevval = prev + 1
print(f"count: {prevval}")
count().update(key, prevval)
return XDP_PASS
else:
count().update(key, one)
return XDP_PASS
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("xdp_pass.py", "xdp_pass.ll")
compile()

View File

@ -0,0 +1,69 @@
from pythonbpf import bpf, map, section, bpfglobal, compile, struct
from ctypes import c_void_p, c_int64, c_int32, c_uint64
from pythonbpf.maps import HashMap
from pythonbpf.helper import ktime
# NOTE: This is a comprehensive test combining struct, helper, and map features
# Please note that although line 50 uses a deliberately absurd expression to
# stress the compiler, it is recommended to use named variables to reduce the
# amount of scratch space that needs to be allocated.
@bpf
@struct
class data_t:
pid: c_uint64
ts: c_uint64
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=3)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
dat = data_t()
dat.pid = 123
dat.pid = dat.pid + 1
print(f"pid is {dat.pid}")
tu = 9
last.update(0, tu)
last.update(1, -last.lookup(0))
x = last.lookup(0)
print(f"Map value at index 0: {x}")
x = x + c_int32(1)
print(f"x after adding 32-bit 1 is {x}")
x = ktime() - 121
print(f"ktime - 121 is {x}")
x = last.lookup(0)
x = x + 1
print(f"x is {x}")
if x == 10:
jat = data_t()
jat.ts = 456
print(f"Hello, World!, ts is {jat.ts}")
a = last.lookup(0)
print(f"a is {a}")
last.update(9, 9)
last.update(0, last.lookup(last.lookup(0)) +
last.lookup(last.lookup(0)) + last.lookup(last.lookup(0)))
z = last.lookup(0)
print(f"new map val at index 0 is {z}")
else:
a = last.lookup(0)
print("Goodbye, World!")
c = last.lookup(1 - 1)
print(f"c is {c}")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -0,0 +1,27 @@
from pythonbpf import bpf, section, bpfglobal, compile
from ctypes import c_void_p, c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
x = 1
print(f"Initial x: {x}")
a = 20
x = a
print(f"Updated x with a: {x}")
x = (x + x) * 3
if x == 2:
print("Hello, World!")
else:
print(f"Goodbye, World! {x}")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -0,0 +1,34 @@
from pythonbpf import bpf, map, section, bpfglobal, compile
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.maps import HashMap
# NOTE: An example of i64** assignment with binops on the RHS
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=3)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
last.update(0, 1)
x = last.lookup(0)
print(f"{x}")
x = x + 1
if x == 2:
print("Hello, World!")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -0,0 +1,40 @@
from pythonbpf import bpf, section, bpfglobal, compile, struct
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.helper import ktime
@bpf
@struct
class data_t:
pid: c_uint64
ts: c_uint64
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
dat = data_t()
dat.pid = 123
dat.pid = dat.pid + 1
print(f"pid is {dat.pid}")
x = ktime() - 121
print(f"ktime is {x}")
x = 1
x = x + 1
print(f"x is {x}")
if x == 2:
jat = data_t()
jat.ts = 456
print(f"Hello, World!, ts is {jat.ts}")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -6,8 +6,8 @@ from ctypes import c_void_p, c_int32
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int32:
print("Hello, World!")
a = 1 # int64
return c_int32(a) # typecast to int32
a = 1 # int64
return c_int32(a) # typecast to int32
@bpf

View File

@ -26,13 +26,8 @@ import tempfile
class BTFConverter:
def __init__(
self,
btf_source="/sys/kernel/btf/vmlinux",
output_file="vmlinux.py",
keep_intermediate=False,
verbose=False,
):
def __init__(self, btf_source="/sys/kernel/btf/vmlinux", output_file="vmlinux.py",
keep_intermediate=False, verbose=False):
self.btf_source = btf_source
self.output_file = output_file
self.keep_intermediate = keep_intermediate
@ -49,7 +44,11 @@ class BTFConverter:
self.log(f"{description}...")
try:
result = subprocess.run(
cmd, shell=True, check=True, capture_output=True, text=True
cmd,
shell=True,
check=True,
capture_output=True,
text=True
)
if self.verbose and result.stdout:
print(result.stdout)
@ -70,55 +69,51 @@ class BTFConverter:
"""Step 1.5: Preprocess enum definitions."""
self.log("Preprocessing enum definitions...")
with open(input_file, "r") as f:
with open(input_file, 'r') as f:
original_code = f.read()
# Extract anonymous enums
enums = re.findall(
r"(?<!typedef\s)(enum\s*\{[^}]*\})\s*(\w+)\s*(?::\s*\d+)?\s*;",
original_code,
r'(?<!typedef\s)(enum\s*\{[^}]*\})\s*(\w+)\s*(?::\s*\d+)?\s*;',
original_code
)
enum_defs = [enum_block + ";" for enum_block, _ in enums]
enum_defs = [enum_block + ';' for enum_block, _ in enums]
# Replace anonymous enums with int declarations
processed_code = re.sub(
r"(?<!typedef\s)enum\s*\{[^}]*\}\s*(\w+)\s*(?::\s*\d+)?\s*;",
r"int \1;",
original_code,
r'(?<!typedef\s)enum\s*\{[^}]*\}\s*(\w+)\s*(?::\s*\d+)?\s*;',
r'int \1;',
original_code
)
# Prepend enum definitions
if enum_defs:
enum_text = "\n".join(enum_defs) + "\n\n"
enum_text = '\n'.join(enum_defs) + '\n\n'
processed_code = enum_text + processed_code
output_file = os.path.join(self.temp_dir, "vmlinux_processed.h")
with open(output_file, "w") as f:
with open(output_file, 'w') as f:
f.write(processed_code)
return output_file
def step2_5_process_kioctx(self, input_file):
# TODO: this is a very bad bug and design decision. A single struct has an issue mostly.
#TODO: this is a very bad bug and design decision. A single struct has an issue mostly.
"""Step 2.5: Process struct kioctx to extract nested anonymous structs."""
self.log("Processing struct kioctx nested structs...")
with open(input_file, "r") as f:
with open(input_file, 'r') as f:
content = f.read()
# Pattern to match struct kioctx with its full body (handles multiple nesting levels)
kioctx_pattern = (
r"struct\s+kioctx\s*\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}\s*;"
)
kioctx_pattern = r'struct\s+kioctx\s*\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}\s*;'
def process_kioctx_replacement(match):
full_struct = match.group(0)
self.log(f"Found struct kioctx, length: {len(full_struct)} chars")
# Extract the struct body (everything between outermost { and })
body_match = re.search(
r"struct\s+kioctx\s*\{(.*)\}\s*;", full_struct, re.DOTALL
)
body_match = re.search(r'struct\s+kioctx\s*\{(.*)\}\s*;', full_struct, re.DOTALL)
if not body_match:
return full_struct
@ -126,7 +121,7 @@ class BTFConverter:
# Find all anonymous structs within the body
# Pattern: struct { ... } followed by ; (not a member name)
# anon_struct_pattern = r"struct\s*\{[^}]*\}"
anon_struct_pattern = r'struct\s*\{[^}]*\}'
anon_structs = []
anon_counter = 4 # Start from 4, counting down to 1
@ -136,9 +131,7 @@ class BTFConverter:
anon_struct_content = m.group(0)
# Extract the body of the anonymous struct
anon_body_match = re.search(
r"struct\s*\{(.*)\}", anon_struct_content, re.DOTALL
)
anon_body_match = re.search(r'struct\s*\{(.*)\}', anon_struct_content, re.DOTALL)
if not anon_body_match:
return anon_struct_content
@ -161,7 +154,7 @@ class BTFConverter:
processed_body = body
# Find all occurrences and process them
pattern_with_semicolon = r"struct\s*\{([^}]*)\}\s*;"
pattern_with_semicolon = r'struct\s*\{([^}]*)\}\s*;'
matches = list(re.finditer(pattern_with_semicolon, body, re.DOTALL))
if not matches:
@ -185,16 +178,14 @@ class BTFConverter:
# Replace in the body
replacement = f"struct {anon_name} {member_name};"
processed_body = (
processed_body[:start_pos] + replacement + processed_body[end_pos:]
)
processed_body = processed_body[:start_pos] + replacement + processed_body[end_pos:]
anon_counter -= 1
# Rebuild the complete definition
if anon_structs:
# Prepend the anonymous struct definitions
anon_definitions = "\n".join(anon_structs) + "\n\n"
anon_definitions = '\n'.join(anon_structs) + '\n\n'
new_struct = f"struct kioctx {{{processed_body}}};"
return anon_definitions + new_struct
else:
@ -202,11 +193,14 @@ class BTFConverter:
# Apply the transformation
processed_content = re.sub(
kioctx_pattern, process_kioctx_replacement, content, flags=re.DOTALL
kioctx_pattern,
process_kioctx_replacement,
content,
flags=re.DOTALL
)
output_file = os.path.join(self.temp_dir, "vmlinux_kioctx_processed.h")
with open(output_file, "w") as f:
with open(output_file, 'w') as f:
f.write(processed_content)
self.log(f"Saved kioctx-processed output to {output_file}")
@ -224,7 +218,7 @@ class BTFConverter:
output_file = os.path.join(self.temp_dir, "vmlinux_raw.py")
cmd = (
f"clang2py {input_file} -o {output_file} "
f'--clang-args="-fno-ms-extensions -I/usr/include -I/usr/include/linux"'
f"--clang-args=\"-fno-ms-extensions -I/usr/include -I/usr/include/linux\""
)
self.run_command(cmd, "Converting to Python ctypes")
return output_file
@ -240,21 +234,14 @@ class BTFConverter:
data = re.sub(r"\('_[0-9]+',\s*ctypes\.[a-zA-Z0-9_]+,\s*0\),?\s*\n?", "", data)
# Replace ('_20', ctypes.c_uint64, 64) → ('_20', ctypes.c_uint64)
data = re.sub(
r"\('(_[0-9]+)',\s*(ctypes\.[a-zA-Z0-9_]+),\s*[0-9]+\)", r"('\1', \2)", data
)
data = re.sub(r"\('(_[0-9]+)',\s*(ctypes\.[a-zA-Z0-9_]+),\s*[0-9]+\)", r"('\1', \2)", data)
# Replace ('_20', ctypes.c_char, 8) with ('_20', ctypes.c_uint8, 8)
data = re.sub(r"(ctypes\.c_char)(\s*,\s*\d+\))", r"ctypes.c_uint8\2", data)
# below to replace those c_bool with bitfield greater than 8
def repl(m):
name, bits = m.groups()
return (
f"('{name}', ctypes.c_uint32, {bits})" if int(bits) > 8 else m.group(0)
)
data = re.sub(r"\('([^']+)',\s*ctypes\.c_bool,\s*(\d+)\)", repl, data)
data = re.sub(
r"(ctypes\.c_char)(\s*,\s*\d+\))",
r"ctypes.c_uint8\2",
data
)
# Remove ctypes. prefix from invalid entries
invalid_ctypes = ["bpf_iter_state", "_cache_type", "fs_context_purpose"]
@ -271,7 +258,6 @@ class BTFConverter:
if not self.keep_intermediate and self.temp_dir != ".":
self.log(f"Cleaning up temporary directory: {self.temp_dir}")
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def convert(self):
@ -295,7 +281,6 @@ class BTFConverter:
except Exception as e:
print(f"\n✗ Error during conversion: {e}", file=sys.stderr)
import traceback
traceback.print_exc()
sys.exit(1)
finally:
@ -308,13 +293,18 @@ class BTFConverter:
dependencies = {
"bpftool": "bpftool --version",
"clang": "clang --version",
"clang2py": "clang2py --version",
"clang2py": "clang2py --version"
}
missing = []
for tool, cmd in dependencies.items():
try:
subprocess.run(cmd, shell=True, check=True, capture_output=True)
subprocess.run(
cmd,
shell=True,
check=True,
capture_output=True
)
except subprocess.CalledProcessError:
missing.append(tool)
@ -336,31 +326,31 @@ Examples:
%(prog)s
%(prog)s -o kernel_types.py
%(prog)s --btf-source /sys/kernel/btf/custom_module -k -v
""",
"""
)
parser.add_argument(
"--btf-source",
default="/sys/kernel/btf/vmlinux",
help="Path to BTF source (default: /sys/kernel/btf/vmlinux)",
help="Path to BTF source (default: /sys/kernel/btf/vmlinux)"
)
parser.add_argument(
"-o",
"--output",
"-o", "--output",
default="vmlinux.py",
help="Output Python file (default: vmlinux.py)",
help="Output Python file (default: vmlinux.py)"
)
parser.add_argument(
"-k",
"--keep-intermediate",
"-k", "--keep-intermediate",
action="store_true",
help="Keep intermediate files (vmlinux.h, vmlinux_processed.h, etc.)",
help="Keep intermediate files (vmlinux.h, vmlinux_processed.h, etc.)"
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="Enable verbose output"
"-v", "--verbose",
action="store_true",
help="Enable verbose output"
)
args = parser.parse_args()
@ -369,7 +359,7 @@ Examples:
btf_source=args.btf_source,
output_file=args.output,
keep_intermediate=args.keep_intermediate,
verbose=args.verbose,
verbose=args.verbose
)
converter.convert()
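For reference, a minimal, self-contained sketch of the anonymous-enum preprocessing (Step 1.5) shown above; the header snippet is made up, while the regexes are the ones used in the converter:

import re

header = "enum { RUNNING = 0, STOPPED = 1 } state;\nint other;\n"

# Pull out anonymous enums so their definitions can be prepended later.
enums = re.findall(
    r"(?<!typedef\s)(enum\s*\{[^}]*\})\s*(\w+)\s*(?::\s*\d+)?\s*;", header
)
enum_defs = [block + ";" for block, _ in enums]
# enum_defs == ["enum { RUNNING = 0, STOPPED = 1 };"]

# Replace each anonymous-enum member with a plain int declaration.
processed = re.sub(
    r"(?<!typedef\s)enum\s*\{[^}]*\}\s*(\w+)\s*(?::\s*\d+)?\s*;", r"int \1;", header
)
# processed == "int state;\nint other;\n"

print("\n".join(enum_defs) + "\n\n" + processed)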