14 Commits

Author SHA1 Message Date
5d0a888542 Remove deadcode and seperate modules for pythonbpf.functions 2025-10-13 04:41:46 +05:30
0042280ff1 Rename public API and remove deadcode in return_utils 2025-10-13 04:23:58 +05:30
7a67041ea3 Move CallHandlerRegistry to expr/call_registry.py, annotate eval_expr 2025-10-13 04:16:22 +05:30
45e6ce5e5c Move deref_to_depth to expr/ir_ops.py 2025-10-13 04:01:27 +05:30
c5f0a2806f Make printk_emiiter return True to prevent bogus logger warnings in eval_expr 2025-10-13 02:40:34 +05:30
b0ea93a786 Merge pull request #50 from pythonbpf/dep_inv
Use dependency inversion to remove handler delayed import in eval_expr
2025-10-13 02:32:18 +05:30
fc058c4341 Use dependency inversion to remove handler delayed import in eval_expr 2025-10-13 02:28:00 +05:30
158cc42e1e Move binop handling logic to expr_pass, remove delayed imports of get_operand_value 2025-10-13 00:36:42 +05:30
2a1eabc10d Fix regression in struct_perf_output 2025-10-13 00:00:43 +05:30
31645f0316 Merge pull request #40 from pythonbpf/refactor_assign
Refactor assignment statement handling and the typing mechanism around it
2025-10-12 12:19:51 +05:30
21ce041353 Refactor hist() calls to use dot notation 2025-10-10 20:45:07 +05:30
6402cf7be5 remove todos and move to projects on github. 2025-10-08 22:27:51 +05:30
9a96e1247b Merge pull request #29 from pythonbpf/smol_pp
add patch for Kernel 6.14 BTF in transpiler
2025-10-08 21:47:49 +05:30
989134f4be add patch for Kernel 6.14 BTF
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-08 21:47:02 +05:30
16 changed files with 525 additions and 447 deletions

View File

@ -83,14 +83,14 @@ def hist() -> HashMap:
def hello(ctx: c_void_p) -> c_int64: def hello(ctx: c_void_p) -> c_int64:
process_id = pid() process_id = pid()
one = 1 one = 1
prev = hist().lookup(process_id) prev = hist.lookup(process_id)
if prev: if prev:
previous_value = prev + 1 previous_value = prev + 1
print(f"count: {previous_value} with {process_id}") print(f"count: {previous_value} with {process_id}")
hist().update(process_id, previous_value) hist.update(process_id, previous_value)
return c_int64(0) return c_int64(0)
else: else:
hist().update(process_id, one) hist.update(process_id, one)
return c_int64(0) return c_int64(0)

13
TODO.md
View File

@ -1,13 +0,0 @@
## Short term
- Implement enough functionality to port the BCC tutorial examples in PythonBPF
- Add all maps
- XDP support in pylibbpf
- ringbuf support
- Add oneline IfExpr conditionals (wishlist)
## Long term
- Refactor the codebase to be better than a hackathon project
- Port to C++ and use actual LLVM?
- Fix struct_kioctx issue in the vmlinux transpiler

View File

@ -1,110 +0,0 @@
import ast
from llvmlite import ir
from logging import Logger
import logging
from pythonbpf.expr import get_base_type_and_depth, deref_to_depth, eval_expr
logger: Logger = logging.getLogger(__name__)
def get_operand_value(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name):
if operand.id in local_sym_tab:
var = local_sym_tab[operand.id].var
var_type = var.type
base_type, depth = get_base_type_and_depth(var_type)
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
cst = ir.Constant(ir.IntType(64), int(operand.value))
return cst
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
res = handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
return res
else:
res = eval_expr(
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res
logger.info(f"Evaluated expr to {val} of type {val.type}")
base_type, depth = get_base_type_and_depth(val.type)
if depth > 0:
val = deref_to_depth(func, builder, val, depth)
return val
raise TypeError(f"Unsupported operand type: {type(operand)}")
def handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
op = rval.op
left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
logger.info(f"left is {left}, right is {right}, op is {op}")
# NOTE: Before doing the operation, if the operands are integers
# we always extend them to i64. The assignment to LHS will take
# care of truncation if needed.
if isinstance(left.type, ir.IntType) and left.type.width < 64:
left = builder.sext(left, ir.IntType(64))
if isinstance(right.type, ir.IntType) and right.type.width < 64:
right = builder.sext(right, ir.IntType(64))
# Map AST operation nodes to LLVM IR builder methods
op_map = {
ast.Add: builder.add,
ast.Sub: builder.sub,
ast.Mult: builder.mul,
ast.Div: builder.sdiv,
ast.Mod: builder.srem,
ast.LShift: builder.shl,
ast.RShift: builder.lshr,
ast.BitOr: builder.or_,
ast.BitXor: builder.xor,
ast.BitAnd: builder.and_,
ast.FloorDiv: builder.udiv,
}
if type(op) in op_map:
result = op_map[type(op)](left, right)
return result
else:
raise SyntaxError("Unsupported binary operation")
def handle_binary_op(
func,
module,
rval,
builder,
var_name,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
result = handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
)
builder.store(result, local_sym_tab[var_name].var)
return result, result.type

View File

@ -1,5 +1,7 @@
from .expr_pass import eval_expr, handle_expr from .expr_pass import eval_expr, handle_expr, get_operand_value
from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth
from .call_registry import CallHandlerRegistry
__all__ = [ __all__ = [
"eval_expr", "eval_expr",
@ -7,4 +9,6 @@ __all__ = [
"convert_to_bool", "convert_to_bool",
"get_base_type_and_depth", "get_base_type_and_depth",
"deref_to_depth", "deref_to_depth",
"get_operand_value",
"CallHandlerRegistry",
] ]

View File

@ -0,0 +1,20 @@
class CallHandlerRegistry:
"""Registry for handling different types of calls (helpers, etc.)"""
_handler = None
@classmethod
def set_handler(cls, handler):
"""Set the handler for unknown calls"""
cls._handler = handler
@classmethod
def handle_call(
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle a call using the registered handler"""
if cls._handler is None:
return None
return cls._handler(
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
)

View File

@ -5,10 +5,20 @@ import logging
from typing import Dict from typing import Dict
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
from .type_normalization import convert_to_bool, handle_comparator from .call_registry import CallHandlerRegistry
from .type_normalization import (
convert_to_bool,
handle_comparator,
get_base_type_and_depth,
deref_to_depth,
)
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
# ============================================================================
# Leaf Handlers (No Recursive eval_expr calls)
# ============================================================================
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder): def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
"""Handle ast.Name expressions.""" """Handle ast.Name expressions."""
@ -21,10 +31,24 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
return None return None
def _handle_constant_expr(expr: ast.Constant): def _handle_constant_expr(module, builder, expr: ast.Constant):
"""Handle ast.Constant expressions.""" """Handle ast.Constant expressions."""
if isinstance(expr.value, int) or isinstance(expr.value, bool): if isinstance(expr.value, int) or isinstance(expr.value, bool):
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
elif isinstance(expr.value, str):
str_name = f".str.{id(expr)}"
str_bytes = expr.value.encode("utf-8") + b"\x00"
str_type = ir.ArrayType(ir.IntType(8), len(str_bytes))
str_constant = ir.Constant(str_type, bytearray(str_bytes))
# Create global variable
global_str = ir.GlobalVariable(module, str_type, name=str_name)
global_str.linkage = "internal"
global_str.global_constant = True
global_str.initializer = str_constant
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
return str_ptr, ir.PointerType(ir.IntType(8))
else: else:
logger.error(f"Unsupported constant type {ast.dump(expr)}") logger.error(f"Unsupported constant type {ast.dump(expr)}")
return None return None
@ -88,6 +112,118 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
return val, local_sym_tab[arg.id].ir_type return val, local_sym_tab[arg.id].ir_type
# ============================================================================
# Binary Operations
# ============================================================================
def get_operand_value(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name):
if operand.id in local_sym_tab:
var = local_sym_tab[operand.id].var
var_type = var.type
base_type, depth = get_base_type_and_depth(var_type)
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
cst = ir.Constant(ir.IntType(64), int(operand.value))
return cst
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
res = _handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
return res
else:
res = eval_expr(
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res
logger.info(f"Evaluated expr to {val} of type {val.type}")
base_type, depth = get_base_type_and_depth(val.type)
if depth > 0:
val = deref_to_depth(func, builder, val, depth)
return val
raise TypeError(f"Unsupported operand type: {type(operand)}")
def _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
op = rval.op
left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
logger.info(f"left is {left}, right is {right}, op is {op}")
# NOTE: Before doing the operation, if the operands are integers
# we always extend them to i64. The assignment to LHS will take
# care of truncation if needed.
if isinstance(left.type, ir.IntType) and left.type.width < 64:
left = builder.sext(left, ir.IntType(64))
if isinstance(right.type, ir.IntType) and right.type.width < 64:
right = builder.sext(right, ir.IntType(64))
# Map AST operation nodes to LLVM IR builder methods
op_map = {
ast.Add: builder.add,
ast.Sub: builder.sub,
ast.Mult: builder.mul,
ast.Div: builder.sdiv,
ast.Mod: builder.srem,
ast.LShift: builder.shl,
ast.RShift: builder.lshr,
ast.BitOr: builder.or_,
ast.BitXor: builder.xor,
ast.BitAnd: builder.and_,
ast.FloorDiv: builder.udiv,
}
if type(op) in op_map:
result = op_map[type(op)](left, right)
return result
else:
raise SyntaxError("Unsupported binary operation")
def _handle_binary_op(
func,
module,
rval,
builder,
var_name,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
result = _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
)
builder.store(result, local_sym_tab[var_name].var)
return result, result.type
# ============================================================================
# Comparison and Unary Operations
# ============================================================================
def _handle_ctypes_call( def _handle_ctypes_call(
func, func,
module, module,
@ -180,8 +316,6 @@ def _handle_unary_op(
logger.error("Only 'not' and '-' unary operators are supported") logger.error("Only 'not' and '-' unary operators are supported")
return None return None
from pythonbpf.binary_ops import get_operand_value
operand = get_operand_value( operand = get_operand_value(
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
) )
@ -200,6 +334,11 @@ def _handle_unary_op(
return result, ir.IntType(64) return result, ir.IntType(64)
# ============================================================================
# Boolean Operations
# ============================================================================
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
"""Handle `and` boolean operations.""" """Handle `and` boolean operations."""
@ -330,6 +469,11 @@ def _handle_boolean_op(
return None return None
# ============================================================================
# Expression Dispatcher
# ============================================================================
def eval_expr( def eval_expr(
func, func,
module, module,
@ -343,7 +487,7 @@ def eval_expr(
if isinstance(expr, ast.Name): if isinstance(expr, ast.Name):
return _handle_name_expr(expr, local_sym_tab, builder) return _handle_name_expr(expr, local_sym_tab, builder)
elif isinstance(expr, ast.Constant): elif isinstance(expr, ast.Constant):
return _handle_constant_expr(expr) return _handle_constant_expr(module, builder, expr)
elif isinstance(expr, ast.Call): elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name) and expr.func.id == "deref": if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder) return _handle_deref_call(expr, local_sym_tab, builder)
@ -359,57 +503,18 @@ def eval_expr(
structs_sym_tab, structs_sym_tab,
) )
# delayed import to avoid circular dependency result = CallHandlerRegistry.handle_call(
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
)
if result is not None:
return result
if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler( logger.warning(f"Unknown call: {ast.dump(expr)}")
expr.func.id return None
):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr.func, ast.Attribute):
logger.info(f"Handling method call: {ast.dump(expr.func)}")
if isinstance(expr.func.value, ast.Call) and isinstance(
expr.func.value.func, ast.Name
):
method_name = expr.func.attr
if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr.func.value, ast.Name):
obj_name = expr.func.value.id
method_name = expr.func.attr
if obj_name in map_sym_tab:
if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr, ast.Attribute): elif isinstance(expr, ast.Attribute):
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder) return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
elif isinstance(expr, ast.BinOp): elif isinstance(expr, ast.BinOp):
from pythonbpf.binary_ops import handle_binary_op return _handle_binary_op(
return handle_binary_op(
func, func,
module, module,
expr, expr,

50
pythonbpf/expr/ir_ops.py Normal file
View File

@ -0,0 +1,50 @@
import logging
from llvmlite import ir
logger = logging.getLogger(__name__)
def deref_to_depth(func, builder, val, target_depth):
"""Dereference a pointer to a certain depth."""
cur_val = val
cur_type = val.type
for depth in range(target_depth):
if not isinstance(val.type, ir.PointerType):
logger.error("Cannot dereference further, non-pointer type")
return None
# dereference with null check
pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
null_ptr = ir.Constant(cur_type, None)
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
builder.cbranch(is_not_null, not_null_block, merge_block)
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
)
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type
return cur_val

View File

@ -1,6 +1,7 @@
from llvmlite import ir
import logging import logging
import ast import ast
from llvmlite import ir
from .ir_ops import deref_to_depth
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,52 +27,6 @@ def get_base_type_and_depth(ir_type):
return cur_type, depth return cur_type, depth
def deref_to_depth(func, builder, val, target_depth):
"""Dereference a pointer to a certain depth."""
cur_val = val
cur_type = val.type
for depth in range(target_depth):
if not isinstance(val.type, ir.PointerType):
logger.error("Cannot dereference further, non-pointer type")
return None
# dereference with null check
pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
null_ptr = ir.Constant(cur_type, None)
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
builder.cbranch(is_not_null, not_null_block, merge_block)
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
)
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type
return cur_val
def _normalize_types(func, builder, lhs, rhs): def _normalize_types(func, builder, lhs, rhs):
"""Normalize types for comparison.""" """Normalize types for comparison."""

View File

@ -1,22 +0,0 @@
from typing import Dict
class StatementHandlerRegistry:
"""Registry for statement handlers."""
_handlers: Dict = {}
@classmethod
def register(cls, stmt_type):
"""Register a handler for a specific statement type."""
def decorator(handler):
cls._handlers[stmt_type] = handler
return handler
return decorator
@classmethod
def __getitem__(cls, stmt_type):
"""Get the handler for a specific statement type."""
return cls._handlers.get(stmt_type, None)

View File

@ -0,0 +1,88 @@
import ast
def get_probe_string(func_node):
"""Extract the probe string from the decorator of the function node"""
# TODO: right now we have the whole string in the section decorator
# But later we can implement typed tuples for tracepoints and kprobes
# For helper functions, we return "helper"
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal":
return None
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name):
if decorator.func.id == "section" and len(decorator.args) == 1:
arg = decorator.args[0]
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
return arg.value
return "helper"
def is_global_function(func_node):
"""Check if the function is a global"""
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id in (
"map",
"bpfglobal",
"struct",
):
return True
return False
def infer_return_type(func_node: ast.FunctionDef):
if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
raise TypeError("Expected ast.FunctionDef")
if func_node.returns is not None:
try:
return ast.unparse(func_node.returns)
except Exception:
node = func_node.returns
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return getattr(node, "attr", type(node).__name__)
try:
return str(node)
except Exception:
return type(node).__name__
found_type = None
def _expr_type(e):
if e is None:
return "None"
if isinstance(e, ast.Constant):
return type(e.value).__name__
if isinstance(e, ast.Name):
return e.id
if isinstance(e, ast.Call):
f = e.func
if isinstance(f, ast.Name):
return f.id
if isinstance(f, ast.Attribute):
try:
return ast.unparse(f)
except Exception:
return getattr(f, "attr", type(f).__name__)
try:
return ast.unparse(f)
except Exception:
return type(f).__name__
if isinstance(e, ast.Attribute):
try:
return ast.unparse(e)
except Exception:
return getattr(e, "attr", type(e).__name__)
try:
return ast.unparse(e)
except Exception:
return type(e).__name__
for walked_node in ast.walk(func_node):
if isinstance(walked_node, ast.Return):
t = _expr_type(walked_node.value)
if found_type is None:
found_type = t
elif found_type != t:
raise ValueError(f"Conflicting return types: {found_type} vs {t}")
return found_type or "None"

View File

@ -14,27 +14,125 @@ from pythonbpf.assign_pass import (
) )
from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool
from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name from .return_utils import handle_none_return, handle_xdp_return, is_xdp_name
from .function_metadata import get_probe_string, is_global_function, infer_return_type
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_probe_string(func_node): # ============================================================================
"""Extract the probe string from the decorator of the function node.""" # SECTION 1: Memory Allocation
# TODO: right now we have the whole string in the section decorator # ============================================================================
# But later we can implement typed tuples for tracepoints and kprobes
# For helper functions, we return "helper"
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal": def count_temps_in_call(call_node, local_sym_tab):
return None """Count the number of temporary variables needed for a function call."""
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name):
if decorator.func.id == "section" and len(decorator.args) == 1: count = 0
arg = decorator.args[0] is_helper = False
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
return arg.value # NOTE: We exclude print calls for now
return "helper" if isinstance(call_node.func, ast.Name):
if (
HelperHandlerRegistry.has_handler(call_node.func.id)
and call_node.func.id != "print"
):
is_helper = True
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True
if not is_helper:
return 0
for arg in call_node.args:
# NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab
if not isinstance(arg, ast.Name) and not (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
):
count += 1
return count
def handle_if_allocation(
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Recursively handle allocations in if/else branches."""
if stmt.body:
allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
max_temps_needed = 0
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
temps_needed = 0
if isinstance(stmt, ast.If):
for s in stmt.body:
update_max_temps_for_stmt(s)
for s in stmt.orelse:
update_max_temps_for_stmt(s)
return
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
temps_needed += count_temps_in_call(node, local_sym_tab)
max_temps_needed = max(max_temps_needed, temps_needed)
for stmt in body:
update_max_temps_for_stmt(stmt)
# Handle allocations
if isinstance(stmt, ast.If):
handle_if_allocation(
module,
builder,
stmt,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
return local_sym_tab
# ============================================================================
# SECTION 2: Statement Handlers
# ============================================================================
def handle_assign( def handle_assign(
@ -146,9 +244,9 @@ def handle_if(
def handle_return(builder, stmt, local_sym_tab, ret_type): def handle_return(builder, stmt, local_sym_tab, ret_type):
logger.info(f"Handling return statement: {ast.dump(stmt)}") logger.info(f"Handling return statement: {ast.dump(stmt)}")
if stmt.value is None: if stmt.value is None:
return _handle_none_return(builder) return handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id): elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
return _handle_xdp_return(stmt, builder, ret_type) return handle_xdp_return(stmt, builder, ret_type)
else: else:
val = eval_expr( val = eval_expr(
func=None, func=None,
@ -207,108 +305,9 @@ def process_stmt(
return did_return return did_return
def handle_if_allocation( # ============================================================================
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab # SECTION 3: Function Body Processing
): # ============================================================================
"""Recursively handle allocations in if/else branches."""
if stmt.body:
allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
def count_temps_in_call(call_node, local_sym_tab):
"""Count the number of temporary variables needed for a function call."""
count = 0
is_helper = False
# NOTE: We exclude print calls for now
if isinstance(call_node.func, ast.Name):
if (
HelperHandlerRegistry.has_handler(call_node.func.id)
and call_node.func.id != "print"
):
is_helper = True
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True
if not is_helper:
return 0
for arg in call_node.args:
# NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab
if not isinstance(arg, ast.Name) and not (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
):
count += 1
return count
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
max_temps_needed = 0
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
temps_needed = 0
if isinstance(stmt, ast.If):
for s in stmt.body:
update_max_temps_for_stmt(s)
for s in stmt.orelse:
update_max_temps_for_stmt(s)
return
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
temps_needed += count_temps_in_call(node, local_sym_tab)
max_temps_needed = max(max_temps_needed, temps_needed)
for stmt in body:
update_max_temps_for_stmt(stmt)
# Handle allocations
if isinstance(stmt, ast.If):
handle_if_allocation(
module,
builder,
stmt,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
return local_sym_tab
def process_func_body( def process_func_body(
@ -390,18 +389,14 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
return func return func
# ============================================================================
# SECTION 4: Top-Level Function Processor
# ============================================================================
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab): def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
for func_node in chunks: for func_node in chunks:
is_global = False if is_global_function(func_node):
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id in (
"map",
"bpfglobal",
"struct",
):
is_global = True
break
if is_global:
continue continue
func_type = get_probe_string(func_node) func_type = get_probe_string(func_node)
logger.info(f"Found probe_string of {func_node.name}: {func_type}") logger.info(f"Found probe_string of {func_node.name}: {func_type}")
@ -415,67 +410,7 @@ def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
) )
def infer_return_type(func_node: ast.FunctionDef): # TODO: WIP, for string assignment to fixed-size arrays
if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
raise TypeError("Expected ast.FunctionDef")
if func_node.returns is not None:
try:
return ast.unparse(func_node.returns)
except Exception:
node = func_node.returns
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return getattr(node, "attr", type(node).__name__)
try:
return str(node)
except Exception:
return type(node).__name__
found_type = None
def _expr_type(e):
if e is None:
return "None"
if isinstance(e, ast.Constant):
return type(e.value).__name__
if isinstance(e, ast.Name):
return e.id
if isinstance(e, ast.Call):
f = e.func
if isinstance(f, ast.Name):
return f.id
if isinstance(f, ast.Attribute):
try:
return ast.unparse(f)
except Exception:
return getattr(f, "attr", type(f).__name__)
try:
return ast.unparse(f)
except Exception:
return type(f).__name__
if isinstance(e, ast.Attribute):
try:
return ast.unparse(e)
except Exception:
return getattr(e, "attr", type(e).__name__)
try:
return ast.unparse(e)
except Exception:
return type(e).__name__
for walked_node in ast.walk(func_node):
if isinstance(walked_node, ast.Return):
t = _expr_type(walked_node.value)
if found_type is None:
found_type = t
elif found_type != t:
raise ValueError(f"Conflicting return types: {found_type} vs {t}")
return found_type or "None"
# For string assignment to fixed-size arrays
def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length): def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length):
""" """
Copy a string (i8*) to a fixed-size array ([N x i8]*) Copy a string (i8*) to a fixed-size array ([N x i8]*)

View File

@ -14,19 +14,19 @@ XDP_ACTIONS = {
} }
def _handle_none_return(builder) -> bool: def handle_none_return(builder) -> bool:
"""Handle return or return None -> returns 0.""" """Handle return or return None -> returns 0."""
builder.ret(ir.Constant(ir.IntType(64), 0)) builder.ret(ir.Constant(ir.IntType(64), 0))
logger.debug("Generated default return: 0") logger.debug("Generated default return: 0")
return True return True
def _is_xdp_name(name: str) -> bool: def is_xdp_name(name: str) -> bool:
"""Check if a name is an XDP action""" """Check if a name is an XDP action"""
return name in XDP_ACTIONS return name in XDP_ACTIONS
def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool: def handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
"""Handle XDP returns""" """Handle XDP returns"""
if not isinstance(stmt.value, ast.Name): if not isinstance(stmt.value, ast.Name):
return False return False
@ -37,7 +37,6 @@ def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
raise ValueError( raise ValueError(
f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}" f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}"
) )
return False
value = XDP_ACTIONS[action_name] value = XDP_ACTIONS[action_name]
builder.ret(ir.Constant(ret_type, value)) builder.ret(ir.Constant(ret_type, value))

View File

@ -2,6 +2,58 @@ from .helper_utils import HelperHandlerRegistry, reset_scratch_pool
from .bpf_helper_handler import handle_helper_call from .bpf_helper_handler import handle_helper_call
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
# Register the helper handler with expr module
def _register_helper_handler():
"""Register helper call handler with the expression evaluator"""
from pythonbpf.expr.expr_pass import CallHandlerRegistry
def helper_call_handler(
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Check if call is a helper and handle it"""
import ast
# Check for direct helper calls (e.g., ktime(), print())
if isinstance(call.func, ast.Name):
if HelperHandlerRegistry.has_handler(call.func.id):
return handle_helper_call(
call,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
# Check for method calls (e.g., map.lookup())
elif isinstance(call.func, ast.Attribute):
method_name = call.func.attr
# Handle: my_map.lookup(key)
if isinstance(call.func.value, ast.Name):
obj_name = call.func.value.id
if map_sym_tab and obj_name in map_sym_tab:
if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call(
call,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
return None
CallHandlerRegistry.set_handler(helper_call_handler)
# Register on module import
_register_helper_handler()
__all__ = [ __all__ = [
"HelperHandlerRegistry", "HelperHandlerRegistry",
"reset_scratch_pool", "reset_scratch_pool",

View File

@ -135,7 +135,7 @@ def bpf_printk_emitter(
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
builder.call(fn_ptr, args, tail=True) builder.call(fn_ptr, args, tail=True)
return None return True
@HelperHandlerRegistry.register("update") @HelperHandlerRegistry.register("update")

View File

@ -3,8 +3,12 @@ import logging
from collections.abc import Callable from collections.abc import Callable
from llvmlite import ir from llvmlite import ir
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth from pythonbpf.expr import (
from pythonbpf.binary_ops import get_operand_value eval_expr,
get_base_type_and_depth,
deref_to_depth,
get_operand_value,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -243,6 +243,17 @@ class BTFConverter:
data data
) )
# below to replace those c_bool with bitfield greater than 8
def repl(m):
name, bits = m.groups()
return f"('{name}', ctypes.c_uint32, {bits})" if int(bits) > 8 else m.group(0)
data = re.sub(
r"\('([^']+)',\s*ctypes\.c_bool,\s*(\d+)\)",
repl,
data
)
# Remove ctypes. prefix from invalid entries # Remove ctypes. prefix from invalid entries
invalid_ctypes = ["bpf_iter_state", "_cache_type", "fs_context_purpose"] invalid_ctypes = ["bpf_iter_state", "_cache_type", "fs_context_purpose"]
for name in invalid_ctypes: for name in invalid_ctypes: