mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Compare commits
14 Commits
refactor_a
...
5d0a888542
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d0a888542 | |||
| 0042280ff1 | |||
| 7a67041ea3 | |||
| 45e6ce5e5c | |||
| c5f0a2806f | |||
| b0ea93a786 | |||
| fc058c4341 | |||
| 158cc42e1e | |||
| 2a1eabc10d | |||
| 31645f0316 | |||
| 21ce041353 | |||
| 6402cf7be5 | |||
| 9a96e1247b | |||
| 989134f4be |
@ -83,14 +83,14 @@ def hist() -> HashMap:
|
|||||||
def hello(ctx: c_void_p) -> c_int64:
|
def hello(ctx: c_void_p) -> c_int64:
|
||||||
process_id = pid()
|
process_id = pid()
|
||||||
one = 1
|
one = 1
|
||||||
prev = hist().lookup(process_id)
|
prev = hist.lookup(process_id)
|
||||||
if prev:
|
if prev:
|
||||||
previous_value = prev + 1
|
previous_value = prev + 1
|
||||||
print(f"count: {previous_value} with {process_id}")
|
print(f"count: {previous_value} with {process_id}")
|
||||||
hist().update(process_id, previous_value)
|
hist.update(process_id, previous_value)
|
||||||
return c_int64(0)
|
return c_int64(0)
|
||||||
else:
|
else:
|
||||||
hist().update(process_id, one)
|
hist.update(process_id, one)
|
||||||
return c_int64(0)
|
return c_int64(0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
13
TODO.md
13
TODO.md
@ -1,13 +0,0 @@
|
|||||||
## Short term
|
|
||||||
|
|
||||||
- Implement enough functionality to port the BCC tutorial examples in PythonBPF
|
|
||||||
- Add all maps
|
|
||||||
- XDP support in pylibbpf
|
|
||||||
- ringbuf support
|
|
||||||
- Add oneline IfExpr conditionals (wishlist)
|
|
||||||
|
|
||||||
## Long term
|
|
||||||
|
|
||||||
- Refactor the codebase to be better than a hackathon project
|
|
||||||
- Port to C++ and use actual LLVM?
|
|
||||||
- Fix struct_kioctx issue in the vmlinux transpiler
|
|
||||||
@ -1,110 +0,0 @@
|
|||||||
import ast
|
|
||||||
from llvmlite import ir
|
|
||||||
from logging import Logger
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from pythonbpf.expr import get_base_type_and_depth, deref_to_depth, eval_expr
|
|
||||||
|
|
||||||
logger: Logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_operand_value(
|
|
||||||
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
|
||||||
):
|
|
||||||
"""Extract the value from an operand, handling variables and constants."""
|
|
||||||
logger.info(f"Getting operand value for: {ast.dump(operand)}")
|
|
||||||
if isinstance(operand, ast.Name):
|
|
||||||
if operand.id in local_sym_tab:
|
|
||||||
var = local_sym_tab[operand.id].var
|
|
||||||
var_type = var.type
|
|
||||||
base_type, depth = get_base_type_and_depth(var_type)
|
|
||||||
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
|
|
||||||
val = deref_to_depth(func, builder, var, depth)
|
|
||||||
return val
|
|
||||||
raise ValueError(f"Undefined variable: {operand.id}")
|
|
||||||
elif isinstance(operand, ast.Constant):
|
|
||||||
if isinstance(operand.value, int):
|
|
||||||
cst = ir.Constant(ir.IntType(64), int(operand.value))
|
|
||||||
return cst
|
|
||||||
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
|
|
||||||
elif isinstance(operand, ast.BinOp):
|
|
||||||
res = handle_binary_op_impl(
|
|
||||||
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
|
||||||
)
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
res = eval_expr(
|
|
||||||
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
|
|
||||||
)
|
|
||||||
if res is None:
|
|
||||||
raise ValueError(f"Failed to evaluate call expression: {operand}")
|
|
||||||
val, _ = res
|
|
||||||
logger.info(f"Evaluated expr to {val} of type {val.type}")
|
|
||||||
base_type, depth = get_base_type_and_depth(val.type)
|
|
||||||
if depth > 0:
|
|
||||||
val = deref_to_depth(func, builder, val, depth)
|
|
||||||
return val
|
|
||||||
raise TypeError(f"Unsupported operand type: {type(operand)}")
|
|
||||||
|
|
||||||
|
|
||||||
def handle_binary_op_impl(
|
|
||||||
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
|
||||||
):
|
|
||||||
op = rval.op
|
|
||||||
left = get_operand_value(
|
|
||||||
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
|
||||||
)
|
|
||||||
right = get_operand_value(
|
|
||||||
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
|
||||||
)
|
|
||||||
logger.info(f"left is {left}, right is {right}, op is {op}")
|
|
||||||
|
|
||||||
# NOTE: Before doing the operation, if the operands are integers
|
|
||||||
# we always extend them to i64. The assignment to LHS will take
|
|
||||||
# care of truncation if needed.
|
|
||||||
if isinstance(left.type, ir.IntType) and left.type.width < 64:
|
|
||||||
left = builder.sext(left, ir.IntType(64))
|
|
||||||
if isinstance(right.type, ir.IntType) and right.type.width < 64:
|
|
||||||
right = builder.sext(right, ir.IntType(64))
|
|
||||||
|
|
||||||
# Map AST operation nodes to LLVM IR builder methods
|
|
||||||
op_map = {
|
|
||||||
ast.Add: builder.add,
|
|
||||||
ast.Sub: builder.sub,
|
|
||||||
ast.Mult: builder.mul,
|
|
||||||
ast.Div: builder.sdiv,
|
|
||||||
ast.Mod: builder.srem,
|
|
||||||
ast.LShift: builder.shl,
|
|
||||||
ast.RShift: builder.lshr,
|
|
||||||
ast.BitOr: builder.or_,
|
|
||||||
ast.BitXor: builder.xor,
|
|
||||||
ast.BitAnd: builder.and_,
|
|
||||||
ast.FloorDiv: builder.udiv,
|
|
||||||
}
|
|
||||||
|
|
||||||
if type(op) in op_map:
|
|
||||||
result = op_map[type(op)](left, right)
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
raise SyntaxError("Unsupported binary operation")
|
|
||||||
|
|
||||||
|
|
||||||
def handle_binary_op(
|
|
||||||
func,
|
|
||||||
module,
|
|
||||||
rval,
|
|
||||||
builder,
|
|
||||||
var_name,
|
|
||||||
local_sym_tab,
|
|
||||||
map_sym_tab,
|
|
||||||
structs_sym_tab=None,
|
|
||||||
):
|
|
||||||
result = handle_binary_op_impl(
|
|
||||||
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
|
||||||
)
|
|
||||||
if var_name and var_name in local_sym_tab:
|
|
||||||
logger.info(
|
|
||||||
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
|
|
||||||
)
|
|
||||||
builder.store(result, local_sym_tab[var_name].var)
|
|
||||||
return result, result.type
|
|
||||||
@ -1,5 +1,7 @@
|
|||||||
from .expr_pass import eval_expr, handle_expr
|
from .expr_pass import eval_expr, handle_expr, get_operand_value
|
||||||
from .type_normalization import convert_to_bool, get_base_type_and_depth, deref_to_depth
|
from .type_normalization import convert_to_bool, get_base_type_and_depth
|
||||||
|
from .ir_ops import deref_to_depth
|
||||||
|
from .call_registry import CallHandlerRegistry
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"eval_expr",
|
"eval_expr",
|
||||||
@ -7,4 +9,6 @@ __all__ = [
|
|||||||
"convert_to_bool",
|
"convert_to_bool",
|
||||||
"get_base_type_and_depth",
|
"get_base_type_and_depth",
|
||||||
"deref_to_depth",
|
"deref_to_depth",
|
||||||
|
"get_operand_value",
|
||||||
|
"CallHandlerRegistry",
|
||||||
]
|
]
|
||||||
|
|||||||
20
pythonbpf/expr/call_registry.py
Normal file
20
pythonbpf/expr/call_registry.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
class CallHandlerRegistry:
|
||||||
|
"""Registry for handling different types of calls (helpers, etc.)"""
|
||||||
|
|
||||||
|
_handler = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_handler(cls, handler):
|
||||||
|
"""Set the handler for unknown calls"""
|
||||||
|
cls._handler = handler
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_call(
|
||||||
|
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
):
|
||||||
|
"""Handle a call using the registered handler"""
|
||||||
|
if cls._handler is None:
|
||||||
|
return None
|
||||||
|
return cls._handler(
|
||||||
|
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
@ -5,10 +5,20 @@ import logging
|
|||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
|
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
|
||||||
from .type_normalization import convert_to_bool, handle_comparator
|
from .call_registry import CallHandlerRegistry
|
||||||
|
from .type_normalization import (
|
||||||
|
convert_to_bool,
|
||||||
|
handle_comparator,
|
||||||
|
get_base_type_and_depth,
|
||||||
|
deref_to_depth,
|
||||||
|
)
|
||||||
|
|
||||||
logger: Logger = logging.getLogger(__name__)
|
logger: Logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Leaf Handlers (No Recursive eval_expr calls)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
|
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
|
||||||
"""Handle ast.Name expressions."""
|
"""Handle ast.Name expressions."""
|
||||||
@ -21,10 +31,24 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _handle_constant_expr(expr: ast.Constant):
|
def _handle_constant_expr(module, builder, expr: ast.Constant):
|
||||||
"""Handle ast.Constant expressions."""
|
"""Handle ast.Constant expressions."""
|
||||||
if isinstance(expr.value, int) or isinstance(expr.value, bool):
|
if isinstance(expr.value, int) or isinstance(expr.value, bool):
|
||||||
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
|
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
|
||||||
|
elif isinstance(expr.value, str):
|
||||||
|
str_name = f".str.{id(expr)}"
|
||||||
|
str_bytes = expr.value.encode("utf-8") + b"\x00"
|
||||||
|
str_type = ir.ArrayType(ir.IntType(8), len(str_bytes))
|
||||||
|
str_constant = ir.Constant(str_type, bytearray(str_bytes))
|
||||||
|
|
||||||
|
# Create global variable
|
||||||
|
global_str = ir.GlobalVariable(module, str_type, name=str_name)
|
||||||
|
global_str.linkage = "internal"
|
||||||
|
global_str.global_constant = True
|
||||||
|
global_str.initializer = str_constant
|
||||||
|
|
||||||
|
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
|
||||||
|
return str_ptr, ir.PointerType(ir.IntType(8))
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unsupported constant type {ast.dump(expr)}")
|
logger.error(f"Unsupported constant type {ast.dump(expr)}")
|
||||||
return None
|
return None
|
||||||
@ -88,6 +112,118 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
|
|||||||
return val, local_sym_tab[arg.id].ir_type
|
return val, local_sym_tab[arg.id].ir_type
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Binary Operations
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def get_operand_value(
|
||||||
|
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||||
|
):
|
||||||
|
"""Extract the value from an operand, handling variables and constants."""
|
||||||
|
logger.info(f"Getting operand value for: {ast.dump(operand)}")
|
||||||
|
if isinstance(operand, ast.Name):
|
||||||
|
if operand.id in local_sym_tab:
|
||||||
|
var = local_sym_tab[operand.id].var
|
||||||
|
var_type = var.type
|
||||||
|
base_type, depth = get_base_type_and_depth(var_type)
|
||||||
|
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
|
||||||
|
val = deref_to_depth(func, builder, var, depth)
|
||||||
|
return val
|
||||||
|
raise ValueError(f"Undefined variable: {operand.id}")
|
||||||
|
elif isinstance(operand, ast.Constant):
|
||||||
|
if isinstance(operand.value, int):
|
||||||
|
cst = ir.Constant(ir.IntType(64), int(operand.value))
|
||||||
|
return cst
|
||||||
|
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
|
||||||
|
elif isinstance(operand, ast.BinOp):
|
||||||
|
res = _handle_binary_op_impl(
|
||||||
|
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
res = eval_expr(
|
||||||
|
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
|
if res is None:
|
||||||
|
raise ValueError(f"Failed to evaluate call expression: {operand}")
|
||||||
|
val, _ = res
|
||||||
|
logger.info(f"Evaluated expr to {val} of type {val.type}")
|
||||||
|
base_type, depth = get_base_type_and_depth(val.type)
|
||||||
|
if depth > 0:
|
||||||
|
val = deref_to_depth(func, builder, val, depth)
|
||||||
|
return val
|
||||||
|
raise TypeError(f"Unsupported operand type: {type(operand)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_binary_op_impl(
|
||||||
|
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||||
|
):
|
||||||
|
op = rval.op
|
||||||
|
left = get_operand_value(
|
||||||
|
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
|
right = get_operand_value(
|
||||||
|
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
|
logger.info(f"left is {left}, right is {right}, op is {op}")
|
||||||
|
|
||||||
|
# NOTE: Before doing the operation, if the operands are integers
|
||||||
|
# we always extend them to i64. The assignment to LHS will take
|
||||||
|
# care of truncation if needed.
|
||||||
|
if isinstance(left.type, ir.IntType) and left.type.width < 64:
|
||||||
|
left = builder.sext(left, ir.IntType(64))
|
||||||
|
if isinstance(right.type, ir.IntType) and right.type.width < 64:
|
||||||
|
right = builder.sext(right, ir.IntType(64))
|
||||||
|
|
||||||
|
# Map AST operation nodes to LLVM IR builder methods
|
||||||
|
op_map = {
|
||||||
|
ast.Add: builder.add,
|
||||||
|
ast.Sub: builder.sub,
|
||||||
|
ast.Mult: builder.mul,
|
||||||
|
ast.Div: builder.sdiv,
|
||||||
|
ast.Mod: builder.srem,
|
||||||
|
ast.LShift: builder.shl,
|
||||||
|
ast.RShift: builder.lshr,
|
||||||
|
ast.BitOr: builder.or_,
|
||||||
|
ast.BitXor: builder.xor,
|
||||||
|
ast.BitAnd: builder.and_,
|
||||||
|
ast.FloorDiv: builder.udiv,
|
||||||
|
}
|
||||||
|
|
||||||
|
if type(op) in op_map:
|
||||||
|
result = op_map[type(op)](left, right)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
raise SyntaxError("Unsupported binary operation")
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_binary_op(
|
||||||
|
func,
|
||||||
|
module,
|
||||||
|
rval,
|
||||||
|
builder,
|
||||||
|
var_name,
|
||||||
|
local_sym_tab,
|
||||||
|
map_sym_tab,
|
||||||
|
structs_sym_tab=None,
|
||||||
|
):
|
||||||
|
result = _handle_binary_op_impl(
|
||||||
|
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
|
if var_name and var_name in local_sym_tab:
|
||||||
|
logger.info(
|
||||||
|
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
|
||||||
|
)
|
||||||
|
builder.store(result, local_sym_tab[var_name].var)
|
||||||
|
return result, result.type
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Comparison and Unary Operations
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def _handle_ctypes_call(
|
def _handle_ctypes_call(
|
||||||
func,
|
func,
|
||||||
module,
|
module,
|
||||||
@ -180,8 +316,6 @@ def _handle_unary_op(
|
|||||||
logger.error("Only 'not' and '-' unary operators are supported")
|
logger.error("Only 'not' and '-' unary operators are supported")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from pythonbpf.binary_ops import get_operand_value
|
|
||||||
|
|
||||||
operand = get_operand_value(
|
operand = get_operand_value(
|
||||||
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
)
|
)
|
||||||
@ -200,6 +334,11 @@ def _handle_unary_op(
|
|||||||
return result, ir.IntType(64)
|
return result, ir.IntType(64)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Boolean Operations
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
|
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
|
||||||
"""Handle `and` boolean operations."""
|
"""Handle `and` boolean operations."""
|
||||||
|
|
||||||
@ -330,6 +469,11 @@ def _handle_boolean_op(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Expression Dispatcher
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def eval_expr(
|
def eval_expr(
|
||||||
func,
|
func,
|
||||||
module,
|
module,
|
||||||
@ -343,7 +487,7 @@ def eval_expr(
|
|||||||
if isinstance(expr, ast.Name):
|
if isinstance(expr, ast.Name):
|
||||||
return _handle_name_expr(expr, local_sym_tab, builder)
|
return _handle_name_expr(expr, local_sym_tab, builder)
|
||||||
elif isinstance(expr, ast.Constant):
|
elif isinstance(expr, ast.Constant):
|
||||||
return _handle_constant_expr(expr)
|
return _handle_constant_expr(module, builder, expr)
|
||||||
elif isinstance(expr, ast.Call):
|
elif isinstance(expr, ast.Call):
|
||||||
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
|
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
|
||||||
return _handle_deref_call(expr, local_sym_tab, builder)
|
return _handle_deref_call(expr, local_sym_tab, builder)
|
||||||
@ -359,57 +503,18 @@ def eval_expr(
|
|||||||
structs_sym_tab,
|
structs_sym_tab,
|
||||||
)
|
)
|
||||||
|
|
||||||
# delayed import to avoid circular dependency
|
result = CallHandlerRegistry.handle_call(
|
||||||
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
|
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler(
|
logger.warning(f"Unknown call: {ast.dump(expr)}")
|
||||||
expr.func.id
|
return None
|
||||||
):
|
|
||||||
return handle_helper_call(
|
|
||||||
expr,
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
func,
|
|
||||||
local_sym_tab,
|
|
||||||
map_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
elif isinstance(expr.func, ast.Attribute):
|
|
||||||
logger.info(f"Handling method call: {ast.dump(expr.func)}")
|
|
||||||
if isinstance(expr.func.value, ast.Call) and isinstance(
|
|
||||||
expr.func.value.func, ast.Name
|
|
||||||
):
|
|
||||||
method_name = expr.func.attr
|
|
||||||
if HelperHandlerRegistry.has_handler(method_name):
|
|
||||||
return handle_helper_call(
|
|
||||||
expr,
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
func,
|
|
||||||
local_sym_tab,
|
|
||||||
map_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
elif isinstance(expr.func.value, ast.Name):
|
|
||||||
obj_name = expr.func.value.id
|
|
||||||
method_name = expr.func.attr
|
|
||||||
if obj_name in map_sym_tab:
|
|
||||||
if HelperHandlerRegistry.has_handler(method_name):
|
|
||||||
return handle_helper_call(
|
|
||||||
expr,
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
func,
|
|
||||||
local_sym_tab,
|
|
||||||
map_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
elif isinstance(expr, ast.Attribute):
|
elif isinstance(expr, ast.Attribute):
|
||||||
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
|
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
|
||||||
elif isinstance(expr, ast.BinOp):
|
elif isinstance(expr, ast.BinOp):
|
||||||
from pythonbpf.binary_ops import handle_binary_op
|
return _handle_binary_op(
|
||||||
|
|
||||||
return handle_binary_op(
|
|
||||||
func,
|
func,
|
||||||
module,
|
module,
|
||||||
expr,
|
expr,
|
||||||
|
|||||||
50
pythonbpf/expr/ir_ops.py
Normal file
50
pythonbpf/expr/ir_ops.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import logging
|
||||||
|
from llvmlite import ir
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def deref_to_depth(func, builder, val, target_depth):
|
||||||
|
"""Dereference a pointer to a certain depth."""
|
||||||
|
|
||||||
|
cur_val = val
|
||||||
|
cur_type = val.type
|
||||||
|
|
||||||
|
for depth in range(target_depth):
|
||||||
|
if not isinstance(val.type, ir.PointerType):
|
||||||
|
logger.error("Cannot dereference further, non-pointer type")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# dereference with null check
|
||||||
|
pointee_type = cur_type.pointee
|
||||||
|
null_check_block = builder.block
|
||||||
|
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
|
||||||
|
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
|
||||||
|
|
||||||
|
null_ptr = ir.Constant(cur_type, None)
|
||||||
|
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
|
||||||
|
logger.debug(f"Inserted null check for pointer at depth {depth}")
|
||||||
|
|
||||||
|
builder.cbranch(is_not_null, not_null_block, merge_block)
|
||||||
|
|
||||||
|
builder.position_at_end(not_null_block)
|
||||||
|
dereferenced_val = builder.load(cur_val)
|
||||||
|
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
|
||||||
|
builder.branch(merge_block)
|
||||||
|
|
||||||
|
builder.position_at_end(merge_block)
|
||||||
|
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
|
||||||
|
|
||||||
|
zero_value = (
|
||||||
|
ir.Constant(pointee_type, 0)
|
||||||
|
if isinstance(pointee_type, ir.IntType)
|
||||||
|
else ir.Constant(pointee_type, None)
|
||||||
|
)
|
||||||
|
phi.add_incoming(zero_value, null_check_block)
|
||||||
|
|
||||||
|
phi.add_incoming(dereferenced_val, not_null_block)
|
||||||
|
|
||||||
|
# Continue with phi result
|
||||||
|
cur_val = phi
|
||||||
|
cur_type = pointee_type
|
||||||
|
return cur_val
|
||||||
@ -1,6 +1,7 @@
|
|||||||
from llvmlite import ir
|
|
||||||
import logging
|
import logging
|
||||||
import ast
|
import ast
|
||||||
|
from llvmlite import ir
|
||||||
|
from .ir_ops import deref_to_depth
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -26,52 +27,6 @@ def get_base_type_and_depth(ir_type):
|
|||||||
return cur_type, depth
|
return cur_type, depth
|
||||||
|
|
||||||
|
|
||||||
def deref_to_depth(func, builder, val, target_depth):
|
|
||||||
"""Dereference a pointer to a certain depth."""
|
|
||||||
|
|
||||||
cur_val = val
|
|
||||||
cur_type = val.type
|
|
||||||
|
|
||||||
for depth in range(target_depth):
|
|
||||||
if not isinstance(val.type, ir.PointerType):
|
|
||||||
logger.error("Cannot dereference further, non-pointer type")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# dereference with null check
|
|
||||||
pointee_type = cur_type.pointee
|
|
||||||
null_check_block = builder.block
|
|
||||||
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
|
|
||||||
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
|
|
||||||
|
|
||||||
null_ptr = ir.Constant(cur_type, None)
|
|
||||||
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
|
|
||||||
logger.debug(f"Inserted null check for pointer at depth {depth}")
|
|
||||||
|
|
||||||
builder.cbranch(is_not_null, not_null_block, merge_block)
|
|
||||||
|
|
||||||
builder.position_at_end(not_null_block)
|
|
||||||
dereferenced_val = builder.load(cur_val)
|
|
||||||
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
|
|
||||||
builder.branch(merge_block)
|
|
||||||
|
|
||||||
builder.position_at_end(merge_block)
|
|
||||||
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
|
|
||||||
|
|
||||||
zero_value = (
|
|
||||||
ir.Constant(pointee_type, 0)
|
|
||||||
if isinstance(pointee_type, ir.IntType)
|
|
||||||
else ir.Constant(pointee_type, None)
|
|
||||||
)
|
|
||||||
phi.add_incoming(zero_value, null_check_block)
|
|
||||||
|
|
||||||
phi.add_incoming(dereferenced_val, not_null_block)
|
|
||||||
|
|
||||||
# Continue with phi result
|
|
||||||
cur_val = phi
|
|
||||||
cur_type = pointee_type
|
|
||||||
return cur_val
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_types(func, builder, lhs, rhs):
|
def _normalize_types(func, builder, lhs, rhs):
|
||||||
"""Normalize types for comparison."""
|
"""Normalize types for comparison."""
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +0,0 @@
|
|||||||
from typing import Dict
|
|
||||||
|
|
||||||
|
|
||||||
class StatementHandlerRegistry:
|
|
||||||
"""Registry for statement handlers."""
|
|
||||||
|
|
||||||
_handlers: Dict = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, stmt_type):
|
|
||||||
"""Register a handler for a specific statement type."""
|
|
||||||
|
|
||||||
def decorator(handler):
|
|
||||||
cls._handlers[stmt_type] = handler
|
|
||||||
return handler
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __getitem__(cls, stmt_type):
|
|
||||||
"""Get the handler for a specific statement type."""
|
|
||||||
return cls._handlers.get(stmt_type, None)
|
|
||||||
88
pythonbpf/functions/function_metadata.py
Normal file
88
pythonbpf/functions/function_metadata.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import ast
|
||||||
|
|
||||||
|
|
||||||
|
def get_probe_string(func_node):
|
||||||
|
"""Extract the probe string from the decorator of the function node"""
|
||||||
|
# TODO: right now we have the whole string in the section decorator
|
||||||
|
# But later we can implement typed tuples for tracepoints and kprobes
|
||||||
|
# For helper functions, we return "helper"
|
||||||
|
|
||||||
|
for decorator in func_node.decorator_list:
|
||||||
|
if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal":
|
||||||
|
return None
|
||||||
|
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name):
|
||||||
|
if decorator.func.id == "section" and len(decorator.args) == 1:
|
||||||
|
arg = decorator.args[0]
|
||||||
|
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
|
||||||
|
return arg.value
|
||||||
|
return "helper"
|
||||||
|
|
||||||
|
|
||||||
|
def is_global_function(func_node):
|
||||||
|
"""Check if the function is a global"""
|
||||||
|
for decorator in func_node.decorator_list:
|
||||||
|
if isinstance(decorator, ast.Name) and decorator.id in (
|
||||||
|
"map",
|
||||||
|
"bpfglobal",
|
||||||
|
"struct",
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def infer_return_type(func_node: ast.FunctionDef):
|
||||||
|
if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||||
|
raise TypeError("Expected ast.FunctionDef")
|
||||||
|
if func_node.returns is not None:
|
||||||
|
try:
|
||||||
|
return ast.unparse(func_node.returns)
|
||||||
|
except Exception:
|
||||||
|
node = func_node.returns
|
||||||
|
if isinstance(node, ast.Name):
|
||||||
|
return node.id
|
||||||
|
if isinstance(node, ast.Attribute):
|
||||||
|
return getattr(node, "attr", type(node).__name__)
|
||||||
|
try:
|
||||||
|
return str(node)
|
||||||
|
except Exception:
|
||||||
|
return type(node).__name__
|
||||||
|
found_type = None
|
||||||
|
|
||||||
|
def _expr_type(e):
|
||||||
|
if e is None:
|
||||||
|
return "None"
|
||||||
|
if isinstance(e, ast.Constant):
|
||||||
|
return type(e.value).__name__
|
||||||
|
if isinstance(e, ast.Name):
|
||||||
|
return e.id
|
||||||
|
if isinstance(e, ast.Call):
|
||||||
|
f = e.func
|
||||||
|
if isinstance(f, ast.Name):
|
||||||
|
return f.id
|
||||||
|
if isinstance(f, ast.Attribute):
|
||||||
|
try:
|
||||||
|
return ast.unparse(f)
|
||||||
|
except Exception:
|
||||||
|
return getattr(f, "attr", type(f).__name__)
|
||||||
|
try:
|
||||||
|
return ast.unparse(f)
|
||||||
|
except Exception:
|
||||||
|
return type(f).__name__
|
||||||
|
if isinstance(e, ast.Attribute):
|
||||||
|
try:
|
||||||
|
return ast.unparse(e)
|
||||||
|
except Exception:
|
||||||
|
return getattr(e, "attr", type(e).__name__)
|
||||||
|
try:
|
||||||
|
return ast.unparse(e)
|
||||||
|
except Exception:
|
||||||
|
return type(e).__name__
|
||||||
|
|
||||||
|
for walked_node in ast.walk(func_node):
|
||||||
|
if isinstance(walked_node, ast.Return):
|
||||||
|
t = _expr_type(walked_node.value)
|
||||||
|
if found_type is None:
|
||||||
|
found_type = t
|
||||||
|
elif found_type != t:
|
||||||
|
raise ValueError(f"Conflicting return types: {found_type} vs {t}")
|
||||||
|
return found_type or "None"
|
||||||
@ -14,27 +14,125 @@ from pythonbpf.assign_pass import (
|
|||||||
)
|
)
|
||||||
from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool
|
from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool
|
||||||
|
|
||||||
from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name
|
from .return_utils import handle_none_return, handle_xdp_return, is_xdp_name
|
||||||
|
from .function_metadata import get_probe_string, is_global_function, infer_return_type
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_probe_string(func_node):
|
# ============================================================================
|
||||||
"""Extract the probe string from the decorator of the function node."""
|
# SECTION 1: Memory Allocation
|
||||||
# TODO: right now we have the whole string in the section decorator
|
# ============================================================================
|
||||||
# But later we can implement typed tuples for tracepoints and kprobes
|
|
||||||
# For helper functions, we return "helper"
|
|
||||||
|
|
||||||
for decorator in func_node.decorator_list:
|
|
||||||
if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal":
|
def count_temps_in_call(call_node, local_sym_tab):
|
||||||
return None
|
"""Count the number of temporary variables needed for a function call."""
|
||||||
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name):
|
|
||||||
if decorator.func.id == "section" and len(decorator.args) == 1:
|
count = 0
|
||||||
arg = decorator.args[0]
|
is_helper = False
|
||||||
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
|
|
||||||
return arg.value
|
# NOTE: We exclude print calls for now
|
||||||
return "helper"
|
if isinstance(call_node.func, ast.Name):
|
||||||
|
if (
|
||||||
|
HelperHandlerRegistry.has_handler(call_node.func.id)
|
||||||
|
and call_node.func.id != "print"
|
||||||
|
):
|
||||||
|
is_helper = True
|
||||||
|
elif isinstance(call_node.func, ast.Attribute):
|
||||||
|
if HelperHandlerRegistry.has_handler(call_node.func.attr):
|
||||||
|
is_helper = True
|
||||||
|
|
||||||
|
if not is_helper:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
for arg in call_node.args:
|
||||||
|
# NOTE: Count all non-name arguments
|
||||||
|
# For struct fields, if it is being passed as an argument,
|
||||||
|
# The struct object should already exist in the local_sym_tab
|
||||||
|
if not isinstance(arg, ast.Name) and not (
|
||||||
|
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
|
||||||
|
):
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def handle_if_allocation(
|
||||||
|
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||||
|
):
|
||||||
|
"""Recursively handle allocations in if/else branches."""
|
||||||
|
if stmt.body:
|
||||||
|
allocate_mem(
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
stmt.body,
|
||||||
|
func,
|
||||||
|
ret_type,
|
||||||
|
map_sym_tab,
|
||||||
|
local_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
if stmt.orelse:
|
||||||
|
allocate_mem(
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
stmt.orelse,
|
||||||
|
func,
|
||||||
|
ret_type,
|
||||||
|
map_sym_tab,
|
||||||
|
local_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def allocate_mem(
|
||||||
|
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||||
|
):
|
||||||
|
max_temps_needed = 0
|
||||||
|
|
||||||
|
def update_max_temps_for_stmt(stmt):
|
||||||
|
nonlocal max_temps_needed
|
||||||
|
temps_needed = 0
|
||||||
|
|
||||||
|
if isinstance(stmt, ast.If):
|
||||||
|
for s in stmt.body:
|
||||||
|
update_max_temps_for_stmt(s)
|
||||||
|
for s in stmt.orelse:
|
||||||
|
update_max_temps_for_stmt(s)
|
||||||
|
return
|
||||||
|
|
||||||
|
for node in ast.walk(stmt):
|
||||||
|
if isinstance(node, ast.Call):
|
||||||
|
temps_needed += count_temps_in_call(node, local_sym_tab)
|
||||||
|
max_temps_needed = max(max_temps_needed, temps_needed)
|
||||||
|
|
||||||
|
for stmt in body:
|
||||||
|
update_max_temps_for_stmt(stmt)
|
||||||
|
|
||||||
|
# Handle allocations
|
||||||
|
if isinstance(stmt, ast.If):
|
||||||
|
handle_if_allocation(
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
stmt,
|
||||||
|
func,
|
||||||
|
ret_type,
|
||||||
|
map_sym_tab,
|
||||||
|
local_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
elif isinstance(stmt, ast.Assign):
|
||||||
|
handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
|
||||||
|
|
||||||
|
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
|
||||||
|
|
||||||
|
return local_sym_tab
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SECTION 2: Statement Handlers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def handle_assign(
|
def handle_assign(
|
||||||
@ -146,9 +244,9 @@ def handle_if(
|
|||||||
def handle_return(builder, stmt, local_sym_tab, ret_type):
|
def handle_return(builder, stmt, local_sym_tab, ret_type):
|
||||||
logger.info(f"Handling return statement: {ast.dump(stmt)}")
|
logger.info(f"Handling return statement: {ast.dump(stmt)}")
|
||||||
if stmt.value is None:
|
if stmt.value is None:
|
||||||
return _handle_none_return(builder)
|
return handle_none_return(builder)
|
||||||
elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id):
|
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
|
||||||
return _handle_xdp_return(stmt, builder, ret_type)
|
return handle_xdp_return(stmt, builder, ret_type)
|
||||||
else:
|
else:
|
||||||
val = eval_expr(
|
val = eval_expr(
|
||||||
func=None,
|
func=None,
|
||||||
@ -207,108 +305,9 @@ def process_stmt(
|
|||||||
return did_return
|
return did_return
|
||||||
|
|
||||||
|
|
||||||
def handle_if_allocation(
|
# ============================================================================
|
||||||
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
# SECTION 3: Function Body Processing
|
||||||
):
|
# ============================================================================
|
||||||
"""Recursively handle allocations in if/else branches."""
|
|
||||||
if stmt.body:
|
|
||||||
allocate_mem(
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
stmt.body,
|
|
||||||
func,
|
|
||||||
ret_type,
|
|
||||||
map_sym_tab,
|
|
||||||
local_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
if stmt.orelse:
|
|
||||||
allocate_mem(
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
stmt.orelse,
|
|
||||||
func,
|
|
||||||
ret_type,
|
|
||||||
map_sym_tab,
|
|
||||||
local_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def count_temps_in_call(call_node, local_sym_tab):
|
|
||||||
"""Count the number of temporary variables needed for a function call."""
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
is_helper = False
|
|
||||||
|
|
||||||
# NOTE: We exclude print calls for now
|
|
||||||
if isinstance(call_node.func, ast.Name):
|
|
||||||
if (
|
|
||||||
HelperHandlerRegistry.has_handler(call_node.func.id)
|
|
||||||
and call_node.func.id != "print"
|
|
||||||
):
|
|
||||||
is_helper = True
|
|
||||||
elif isinstance(call_node.func, ast.Attribute):
|
|
||||||
if HelperHandlerRegistry.has_handler(call_node.func.attr):
|
|
||||||
is_helper = True
|
|
||||||
|
|
||||||
if not is_helper:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
for arg in call_node.args:
|
|
||||||
# NOTE: Count all non-name arguments
|
|
||||||
# For struct fields, if it is being passed as an argument,
|
|
||||||
# The struct object should already exist in the local_sym_tab
|
|
||||||
if not isinstance(arg, ast.Name) and not (
|
|
||||||
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
|
|
||||||
):
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
|
|
||||||
def allocate_mem(
|
|
||||||
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
|
||||||
):
|
|
||||||
max_temps_needed = 0
|
|
||||||
|
|
||||||
def update_max_temps_for_stmt(stmt):
|
|
||||||
nonlocal max_temps_needed
|
|
||||||
temps_needed = 0
|
|
||||||
|
|
||||||
if isinstance(stmt, ast.If):
|
|
||||||
for s in stmt.body:
|
|
||||||
update_max_temps_for_stmt(s)
|
|
||||||
for s in stmt.orelse:
|
|
||||||
update_max_temps_for_stmt(s)
|
|
||||||
return
|
|
||||||
|
|
||||||
for node in ast.walk(stmt):
|
|
||||||
if isinstance(node, ast.Call):
|
|
||||||
temps_needed += count_temps_in_call(node, local_sym_tab)
|
|
||||||
max_temps_needed = max(max_temps_needed, temps_needed)
|
|
||||||
|
|
||||||
for stmt in body:
|
|
||||||
update_max_temps_for_stmt(stmt)
|
|
||||||
|
|
||||||
# Handle allocations
|
|
||||||
if isinstance(stmt, ast.If):
|
|
||||||
handle_if_allocation(
|
|
||||||
module,
|
|
||||||
builder,
|
|
||||||
stmt,
|
|
||||||
func,
|
|
||||||
ret_type,
|
|
||||||
map_sym_tab,
|
|
||||||
local_sym_tab,
|
|
||||||
structs_sym_tab,
|
|
||||||
)
|
|
||||||
elif isinstance(stmt, ast.Assign):
|
|
||||||
handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
|
|
||||||
|
|
||||||
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
|
|
||||||
|
|
||||||
return local_sym_tab
|
|
||||||
|
|
||||||
|
|
||||||
def process_func_body(
|
def process_func_body(
|
||||||
@ -390,18 +389,14 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# SECTION 4: Top-Level Function Processor
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
|
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
|
||||||
for func_node in chunks:
|
for func_node in chunks:
|
||||||
is_global = False
|
if is_global_function(func_node):
|
||||||
for decorator in func_node.decorator_list:
|
|
||||||
if isinstance(decorator, ast.Name) and decorator.id in (
|
|
||||||
"map",
|
|
||||||
"bpfglobal",
|
|
||||||
"struct",
|
|
||||||
):
|
|
||||||
is_global = True
|
|
||||||
break
|
|
||||||
if is_global:
|
|
||||||
continue
|
continue
|
||||||
func_type = get_probe_string(func_node)
|
func_type = get_probe_string(func_node)
|
||||||
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
|
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
|
||||||
@ -415,67 +410,7 @@ def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def infer_return_type(func_node: ast.FunctionDef):
|
# TODO: WIP, for string assignment to fixed-size arrays
|
||||||
if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
||||||
raise TypeError("Expected ast.FunctionDef")
|
|
||||||
if func_node.returns is not None:
|
|
||||||
try:
|
|
||||||
return ast.unparse(func_node.returns)
|
|
||||||
except Exception:
|
|
||||||
node = func_node.returns
|
|
||||||
if isinstance(node, ast.Name):
|
|
||||||
return node.id
|
|
||||||
if isinstance(node, ast.Attribute):
|
|
||||||
return getattr(node, "attr", type(node).__name__)
|
|
||||||
try:
|
|
||||||
return str(node)
|
|
||||||
except Exception:
|
|
||||||
return type(node).__name__
|
|
||||||
found_type = None
|
|
||||||
|
|
||||||
def _expr_type(e):
|
|
||||||
if e is None:
|
|
||||||
return "None"
|
|
||||||
if isinstance(e, ast.Constant):
|
|
||||||
return type(e.value).__name__
|
|
||||||
if isinstance(e, ast.Name):
|
|
||||||
return e.id
|
|
||||||
if isinstance(e, ast.Call):
|
|
||||||
f = e.func
|
|
||||||
if isinstance(f, ast.Name):
|
|
||||||
return f.id
|
|
||||||
if isinstance(f, ast.Attribute):
|
|
||||||
try:
|
|
||||||
return ast.unparse(f)
|
|
||||||
except Exception:
|
|
||||||
return getattr(f, "attr", type(f).__name__)
|
|
||||||
try:
|
|
||||||
return ast.unparse(f)
|
|
||||||
except Exception:
|
|
||||||
return type(f).__name__
|
|
||||||
if isinstance(e, ast.Attribute):
|
|
||||||
try:
|
|
||||||
return ast.unparse(e)
|
|
||||||
except Exception:
|
|
||||||
return getattr(e, "attr", type(e).__name__)
|
|
||||||
try:
|
|
||||||
return ast.unparse(e)
|
|
||||||
except Exception:
|
|
||||||
return type(e).__name__
|
|
||||||
|
|
||||||
for walked_node in ast.walk(func_node):
|
|
||||||
if isinstance(walked_node, ast.Return):
|
|
||||||
t = _expr_type(walked_node.value)
|
|
||||||
if found_type is None:
|
|
||||||
found_type = t
|
|
||||||
elif found_type != t:
|
|
||||||
raise ValueError(f"Conflicting return types: {found_type} vs {t}")
|
|
||||||
return found_type or "None"
|
|
||||||
|
|
||||||
|
|
||||||
# For string assignment to fixed-size arrays
|
|
||||||
|
|
||||||
|
|
||||||
def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length):
|
def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length):
|
||||||
"""
|
"""
|
||||||
Copy a string (i8*) to a fixed-size array ([N x i8]*)
|
Copy a string (i8*) to a fixed-size array ([N x i8]*)
|
||||||
|
|||||||
@ -14,19 +14,19 @@ XDP_ACTIONS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _handle_none_return(builder) -> bool:
|
def handle_none_return(builder) -> bool:
|
||||||
"""Handle return or return None -> returns 0."""
|
"""Handle return or return None -> returns 0."""
|
||||||
builder.ret(ir.Constant(ir.IntType(64), 0))
|
builder.ret(ir.Constant(ir.IntType(64), 0))
|
||||||
logger.debug("Generated default return: 0")
|
logger.debug("Generated default return: 0")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _is_xdp_name(name: str) -> bool:
|
def is_xdp_name(name: str) -> bool:
|
||||||
"""Check if a name is an XDP action"""
|
"""Check if a name is an XDP action"""
|
||||||
return name in XDP_ACTIONS
|
return name in XDP_ACTIONS
|
||||||
|
|
||||||
|
|
||||||
def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
|
def handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
|
||||||
"""Handle XDP returns"""
|
"""Handle XDP returns"""
|
||||||
if not isinstance(stmt.value, ast.Name):
|
if not isinstance(stmt.value, ast.Name):
|
||||||
return False
|
return False
|
||||||
@ -37,7 +37,6 @@ def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}"
|
f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}"
|
||||||
)
|
)
|
||||||
return False
|
|
||||||
|
|
||||||
value = XDP_ACTIONS[action_name]
|
value = XDP_ACTIONS[action_name]
|
||||||
builder.ret(ir.Constant(ret_type, value))
|
builder.ret(ir.Constant(ret_type, value))
|
||||||
|
|||||||
@ -2,6 +2,58 @@ from .helper_utils import HelperHandlerRegistry, reset_scratch_pool
|
|||||||
from .bpf_helper_handler import handle_helper_call
|
from .bpf_helper_handler import handle_helper_call
|
||||||
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
|
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
|
||||||
|
|
||||||
|
|
||||||
|
# Register the helper handler with expr module
|
||||||
|
def _register_helper_handler():
|
||||||
|
"""Register helper call handler with the expression evaluator"""
|
||||||
|
from pythonbpf.expr.expr_pass import CallHandlerRegistry
|
||||||
|
|
||||||
|
def helper_call_handler(
|
||||||
|
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||||
|
):
|
||||||
|
"""Check if call is a helper and handle it"""
|
||||||
|
import ast
|
||||||
|
|
||||||
|
# Check for direct helper calls (e.g., ktime(), print())
|
||||||
|
if isinstance(call.func, ast.Name):
|
||||||
|
if HelperHandlerRegistry.has_handler(call.func.id):
|
||||||
|
return handle_helper_call(
|
||||||
|
call,
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
func,
|
||||||
|
local_sym_tab,
|
||||||
|
map_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for method calls (e.g., map.lookup())
|
||||||
|
elif isinstance(call.func, ast.Attribute):
|
||||||
|
method_name = call.func.attr
|
||||||
|
|
||||||
|
# Handle: my_map.lookup(key)
|
||||||
|
if isinstance(call.func.value, ast.Name):
|
||||||
|
obj_name = call.func.value.id
|
||||||
|
if map_sym_tab and obj_name in map_sym_tab:
|
||||||
|
if HelperHandlerRegistry.has_handler(method_name):
|
||||||
|
return handle_helper_call(
|
||||||
|
call,
|
||||||
|
module,
|
||||||
|
builder,
|
||||||
|
func,
|
||||||
|
local_sym_tab,
|
||||||
|
map_sym_tab,
|
||||||
|
structs_sym_tab,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
CallHandlerRegistry.set_handler(helper_call_handler)
|
||||||
|
|
||||||
|
|
||||||
|
# Register on module import
|
||||||
|
_register_helper_handler()
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"HelperHandlerRegistry",
|
"HelperHandlerRegistry",
|
||||||
"reset_scratch_pool",
|
"reset_scratch_pool",
|
||||||
|
|||||||
@ -135,7 +135,7 @@ def bpf_printk_emitter(
|
|||||||
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
|
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
|
||||||
|
|
||||||
builder.call(fn_ptr, args, tail=True)
|
builder.call(fn_ptr, args, tail=True)
|
||||||
return None
|
return True
|
||||||
|
|
||||||
|
|
||||||
@HelperHandlerRegistry.register("update")
|
@HelperHandlerRegistry.register("update")
|
||||||
|
|||||||
@ -3,8 +3,12 @@ import logging
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
from llvmlite import ir
|
from llvmlite import ir
|
||||||
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
|
from pythonbpf.expr import (
|
||||||
from pythonbpf.binary_ops import get_operand_value
|
eval_expr,
|
||||||
|
get_base_type_and_depth,
|
||||||
|
deref_to_depth,
|
||||||
|
get_operand_value,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -243,6 +243,17 @@ class BTFConverter:
|
|||||||
data
|
data
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# below to replace those c_bool with bitfield greater than 8
|
||||||
|
def repl(m):
|
||||||
|
name, bits = m.groups()
|
||||||
|
return f"('{name}', ctypes.c_uint32, {bits})" if int(bits) > 8 else m.group(0)
|
||||||
|
|
||||||
|
data = re.sub(
|
||||||
|
r"\('([^']+)',\s*ctypes\.c_bool,\s*(\d+)\)",
|
||||||
|
repl,
|
||||||
|
data
|
||||||
|
)
|
||||||
|
|
||||||
# Remove ctypes. prefix from invalid entries
|
# Remove ctypes. prefix from invalid entries
|
||||||
invalid_ctypes = ["bpf_iter_state", "_cache_type", "fs_context_purpose"]
|
invalid_ctypes = ["bpf_iter_state", "_cache_type", "fs_context_purpose"]
|
||||||
for name in invalid_ctypes:
|
for name in invalid_ctypes:
|
||||||
|
|||||||
Reference in New Issue
Block a user