39 Commits

Author SHA1 Message Date
b09dc815fc Add StatementHandlerRegistry 2025-10-05 15:19:16 +05:30
ceaac78633 Janitorial: fix lint 2025-10-05 15:12:01 +05:30
dc7a127fa6 Restructure dir for functions 2025-10-05 15:09:39 +05:30
552cd352f2 Merge pull request #20 from pythonbpf/fix-failing-tests
Fix failing tests in tests/
2025-10-05 14:04:14 +05:30
c7f2955ee9 Fix typo in process_stmt 2025-10-05 14:03:19 +05:30
ef36ea1e03 Add nullcheck for var_name in handle_binary_ops 2025-10-05 14:02:08 +05:30
d341cb24c0 Update explanation for named_arg 2025-10-05 04:27:37 +05:30
2fabb67942 Add note for faling test named_arg 2025-10-05 03:15:17 +05:30
a0b0ad370e Merge pull request #23 from pythonbpf/formatter
update formatter and pre-commit
2025-10-05 01:15:01 +05:30
283b947fc5 Add named_arg failing test 2025-10-04 19:50:33 +05:30
bf78ac21fe Remove 'Static Typing' from short term tasks 2025-10-04 07:30:11 +05:30
ac49cd8b1c Fix hashmap access in direct_assign.py 2025-10-04 02:14:33 +05:30
af44bd063c Add explanation for direct_assign.py failing test 2025-10-04 02:13:46 +05:30
1239d1c35f Fix handle_binary_ops calls in functions_pass 2025-10-04 02:09:11 +05:30
f41a9ccf26 Remove unnecessary args from binary_ops 2025-10-04 02:07:31 +05:30
be05b5d102 Allow local symbols to be used within return 2025-10-03 19:50:56 +05:30
3f061750cf fix return value error 2025-10-03 19:11:11 +05:30
6d5d6345e2 Add var_rval failing test 2025-10-03 18:01:15 +05:30
6fea580693 Fix t/f/return.py, tweak handle_binary_ops 2025-10-03 17:56:21 +05:30
b35134625b Merge pull request #19 from pythonbpf/fix-expr
Refactor expr_pass
2025-10-03 17:36:31 +05:30
c3db609a90 Revert to using Warning loglevel as default 2025-10-03 17:35:57 +05:30
cc626c38f7 Move binops1 to tests/passing 2025-10-03 17:13:02 +05:30
a8b3f4f86c Fix recursive binops, move failing binops to passing 2025-10-03 17:08:41 +05:30
d593969408 Refactor ugly if-elif chain in handle_binary_op 2025-10-03 14:04:38 +05:30
6d5895ebc2 More fixes to recursive dereferencer, add get_operand value 2025-10-03 13:46:52 +05:30
c9ee6e4f17 Fix recursive_dereferencer in binops 2025-10-03 13:35:15 +05:30
a622c53e0f Add deref 2025-10-03 02:00:01 +05:30
a4f1363aed Add _handle_attribute_expr 2025-10-03 01:50:59 +05:30
3a819dcaee Add _handle_constant_expr 2025-10-02 22:54:38 +05:30
729270b34b Use _handle_name_expr in eval_expr 2025-10-02 22:50:21 +05:30
44cbcccb6c Create _handle_name_expr 2025-10-02 22:43:54 +05:30
86b9ec56d7 update formatter and pre-commit
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-02 22:43:05 +05:30
253944afd2 Merge pull request #18 from pythonbpf/fix-maps
Fix map calling convention
2025-10-02 22:12:01 +05:30
54993ce5c2 Merge branch 'master' into fix-maps 2025-10-02 22:11:38 +05:30
05083bd513 janitorial nitpicks 2025-10-02 22:10:28 +05:30
6e4c340780 Allow non-call convention for maps 2025-10-02 22:07:28 +05:30
9dbca410c2 Remove calls from map in sys_sync 2025-10-02 21:24:15 +05:30
62ca3b5ffe format errors 2025-10-02 19:07:49 +05:30
f263c35156 move debug cu generation to debug module 2025-10-02 19:05:58 +05:30
20 changed files with 374 additions and 199 deletions

View File

@ -21,7 +21,7 @@ ci:
repos: repos:
# Standard hooks # Standard hooks
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0 rev: v6.0.0
hooks: hooks:
- id: check-added-large-files - id: check-added-large-files
- id: check-case-conflict - id: check-case-conflict
@ -36,7 +36,7 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.4.2" rev: "v0.13.2"
hooks: hooks:
- id: ruff - id: ruff
args: ["--fix", "--show-fixes"] args: ["--fix", "--show-fixes"]
@ -45,7 +45,7 @@ repos:
# Checking static types # Checking static types
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.10.0" rev: "v1.18.2"
hooks: hooks:
- id: mypy - id: mypy
exclude: ^(tests)|^(examples) exclude: ^(tests)|^(examples)

View File

@ -1,7 +1,6 @@
## Short term ## Short term
- Implement enough functionality to port the BCC tutorial examples in PythonBPF - Implement enough functionality to port the BCC tutorial examples in PythonBPF
- Static Typing
- Add all maps - Add all maps
- XDP support in pylibbpf - XDP support in pylibbpf
- ringbuf support - ringbuf support

View File

@ -12,7 +12,7 @@
"from pythonbpf import bpf, map, section, bpfglobal, BPF\n", "from pythonbpf import bpf, map, section, bpfglobal, BPF\n",
"from pythonbpf.helper import pid\n", "from pythonbpf.helper import pid\n",
"from pythonbpf.maps import HashMap\n", "from pythonbpf.maps import HashMap\n",
"from pylibbpf import *\n", "from pylibbpf import BpfMap\n",
"from ctypes import c_void_p, c_int64, c_uint64, c_int32\n", "from ctypes import c_void_p, c_int64, c_uint64, c_int32\n",
"import matplotlib.pyplot as plt" "import matplotlib.pyplot as plt"
] ]

View File

@ -21,17 +21,17 @@ def last() -> HashMap:
@section("tracepoint/syscalls/sys_enter_sync") @section("tracepoint/syscalls/sys_enter_sync")
def do_trace(ctx: c_void_p) -> c_int64: def do_trace(ctx: c_void_p) -> c_int64:
key = 0 key = 0
tsp = last().lookup(key) tsp = last.lookup(key)
if tsp: if tsp:
kt = ktime() kt = ktime()
delta = kt - tsp delta = kt - tsp
if delta < 1000000000: if delta < 1000000000:
time_ms = delta // 1000000 time_ms = delta // 1000000
print(f"sync called within last second, last {time_ms} ms ago") print(f"sync called within last second, last {time_ms} ms ago")
last().delete(key) last.delete(key)
else: else:
kt = ktime() kt = ktime()
last().update(key, kt) last.update(key, kt)
return c_int64(0) return c_int64(0)

View File

@ -8,68 +8,65 @@ logger: Logger = logging.getLogger(__name__)
def recursive_dereferencer(var, builder): def recursive_dereferencer(var, builder):
"""dereference until primitive type comes out""" """dereference until primitive type comes out"""
if var.type == ir.PointerType(ir.PointerType(ir.IntType(64))): # TODO: Not worrying about stack overflow for now
logger.info(f"Dereferencing {var}, type is {var.type}")
if isinstance(var.type, ir.PointerType):
a = builder.load(var) a = builder.load(var)
return recursive_dereferencer(a, builder) return recursive_dereferencer(a, builder)
elif var.type == ir.PointerType(ir.IntType(64)): elif isinstance(var.type, ir.IntType):
a = builder.load(var)
return recursive_dereferencer(a, builder)
elif var.type == ir.IntType(64):
return var return var
else: else:
raise TypeError(f"Unsupported type for dereferencing: {var.type}") raise TypeError(f"Unsupported type for dereferencing: {var.type}")
def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab, func): def get_operand_value(operand, builder, local_sym_tab):
logger.info(f"module {module}") """Extract the value from an operand, handling variables and constants."""
left = rval.left if isinstance(operand, ast.Name):
right = rval.right if operand.id in local_sym_tab:
return recursive_dereferencer(local_sym_tab[operand.id].var, builder)
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
return ir.Constant(ir.IntType(64), operand.value)
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, builder, local_sym_tab)
raise TypeError(f"Unsupported operand type: {type(operand)}")
def handle_binary_op_impl(rval, builder, local_sym_tab):
op = rval.op op = rval.op
left = get_operand_value(rval.left, builder, local_sym_tab)
# Handle left operand right = get_operand_value(rval.right, builder, local_sym_tab)
if isinstance(left, ast.Name):
if left.id in local_sym_tab:
left = recursive_dereferencer(local_sym_tab[left.id].var, builder)
else:
raise SyntaxError(f"Undefined variable: {left.id}")
elif isinstance(left, ast.Constant):
left = ir.Constant(ir.IntType(64), left.value)
else:
raise SyntaxError("Unsupported left operand type")
if isinstance(right, ast.Name):
if right.id in local_sym_tab:
right = recursive_dereferencer(local_sym_tab[right.id].var, builder)
else:
raise SyntaxError(f"Undefined variable: {right.id}")
elif isinstance(right, ast.Constant):
right = ir.Constant(ir.IntType(64), right.value)
else:
raise SyntaxError("Unsupported right operand type")
logger.info(f"left is {left}, right is {right}, op is {op}") logger.info(f"left is {left}, right is {right}, op is {op}")
if isinstance(op, ast.Add): # Map AST operation nodes to LLVM IR builder methods
builder.store(builder.add(left, right), local_sym_tab[var_name].var) op_map = {
elif isinstance(op, ast.Sub): ast.Add: builder.add,
builder.store(builder.sub(left, right), local_sym_tab[var_name].var) ast.Sub: builder.sub,
elif isinstance(op, ast.Mult): ast.Mult: builder.mul,
builder.store(builder.mul(left, right), local_sym_tab[var_name].var) ast.Div: builder.sdiv,
elif isinstance(op, ast.Div): ast.Mod: builder.srem,
builder.store(builder.sdiv(left, right), local_sym_tab[var_name].var) ast.LShift: builder.shl,
elif isinstance(op, ast.Mod): ast.RShift: builder.lshr,
builder.store(builder.srem(left, right), local_sym_tab[var_name].var) ast.BitOr: builder.or_,
elif isinstance(op, ast.LShift): ast.BitXor: builder.xor,
builder.store(builder.shl(left, right), local_sym_tab[var_name].var) ast.BitAnd: builder.and_,
elif isinstance(op, ast.RShift): ast.FloorDiv: builder.udiv,
builder.store(builder.lshr(left, right), local_sym_tab[var_name].var) }
elif isinstance(op, ast.BitOr):
builder.store(builder.or_(left, right), local_sym_tab[var_name].var) if type(op) in op_map:
elif isinstance(op, ast.BitXor): result = op_map[type(op)](left, right)
builder.store(builder.xor(left, right), local_sym_tab[var_name].var) return result
elif isinstance(op, ast.BitAnd):
builder.store(builder.and_(left, right), local_sym_tab[var_name].var)
elif isinstance(op, ast.FloorDiv):
builder.store(builder.udiv(left, right), local_sym_tab[var_name].var)
else: else:
raise SyntaxError("Unsupported binary operation") raise SyntaxError("Unsupported binary operation")
def handle_binary_op(rval, builder, var_name, local_sym_tab):
result = handle_binary_op_impl(rval, builder, local_sym_tab)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
)
builder.store(result, local_sym_tab[var_name].var)
return result, result.type

View File

@ -1,11 +1,11 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from .license_pass import license_processing from .license_pass import license_processing
from .functions_pass import func_proc from .functions import func_proc
from .maps import maps_proc from .maps import maps_proc
from .structs import structs_proc from .structs import structs_proc
from .globals_pass import globals_processing from .globals_pass import globals_processing
from .debuginfo import DW_LANG_C11, DwarfBehaviorEnum from .debuginfo import DW_LANG_C11, DwarfBehaviorEnum, DebugInfoGenerator
import os import os
import subprocess import subprocess
import inspect import inspect
@ -48,7 +48,7 @@ def processor(source_code, filename, module):
globals_processing(tree, module) globals_processing(tree, module)
def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING): def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
logging.basicConfig( logging.basicConfig(
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
) )
@ -60,33 +60,17 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
module.triple = "bpf" module.triple = "bpf"
if not hasattr(module, "_debug_compile_unit"): if not hasattr(module, "_debug_compile_unit"):
module._file_metadata = module.add_debug_info( debug_generator = DebugInfoGenerator(module)
"DIFile", debug_generator.generate_file_metadata(filename, os.path.dirname(filename))
{ # type: ignore debug_generator.generate_debug_cu(
"filename": filename, DW_LANG_C11,
"directory": os.path.dirname(filename), f"PythonBPF {VERSION}",
}, True, # TODO: This is probably not true
# TODO: add a global field here that keeps track of all the globals. Works without it, but I think it might
# be required for kprobes.
True,
) )
module._debug_compile_unit = module.add_debug_info(
"DICompileUnit",
{ # type: ignore
"language": DW_LANG_C11,
"file": module._file_metadata, # type: ignore
"producer": f"PythonBPF {VERSION}",
"isOptimized": True, # TODO: This is probably not true
# TODO: add a global field here that keeps track of all the globals. Works without it, but I think it might
# be required for kprobes.
"runtimeVersion": 0,
"emissionKind": 1,
"splitDebugInlining": False,
"nameTableKind": 0,
},
is_distinct=True,
)
module.add_named_metadata("llvm.dbg.cu", module._debug_compile_unit) # type: ignore
processor(source, filename, module) processor(source, filename, module)
wchar_size = module.add_metadata( wchar_size = module.add_metadata(
@ -137,7 +121,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
return output return output
def compile(loglevel=logging.WARNING) -> bool: def compile(loglevel=logging.INFO) -> bool:
# Look one level up the stack to the caller of this function # Look one level up the stack to the caller of this function
caller_frame = inspect.stack()[1] caller_frame = inspect.stack()[1]
caller_file = Path(caller_frame.filename).resolve() caller_file = Path(caller_frame.filename).resolve()
@ -170,7 +154,7 @@ def compile(loglevel=logging.WARNING) -> bool:
return success return success
def BPF(loglevel=logging.WARNING) -> BpfProgram: def BPF(loglevel=logging.INFO) -> BpfProgram:
caller_frame = inspect.stack()[1] caller_frame = inspect.stack()[1]
src = inspect.getsource(caller_frame.frame) src = inspect.getsource(caller_frame.frame)
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(

View File

@ -12,6 +12,34 @@ class DebugInfoGenerator:
self.module = module self.module = module
self._type_cache = {} # Cache for common debug types self._type_cache = {} # Cache for common debug types
def generate_file_metadata(self, filename, dirname):
self.module._file_metadata = self.module.add_debug_info(
"DIFile",
{ # type: ignore
"filename": filename,
"directory": dirname,
},
)
def generate_debug_cu(
self, language, producer: str, is_optimized: bool, is_distinct: bool
):
self.module._debug_compile_unit = self.module.add_debug_info(
"DICompileUnit",
{ # type: ignore
"language": language,
"file": self.module._file_metadata, # type: ignore
"producer": producer,
"isOptimized": is_optimized,
"runtimeVersion": 0,
"emissionKind": 1,
"splitDebugInlining": False,
"nameTableKind": 0,
},
is_distinct=is_distinct,
)
self.module.add_named_metadata("llvm.dbg.cu", self.module._debug_compile_unit) # type: ignore
def get_basic_type(self, name: str, size: int, encoding: int) -> Any: def get_basic_type(self, name: str, size: int, encoding: int) -> Any:
"""Get or create a basic type with caching""" """Get or create a basic type with caching"""
key = (name, size, encoding) key = (name, size, encoding)

View File

@ -2,10 +2,92 @@ import ast
from llvmlite import ir from llvmlite import ir
from logging import Logger from logging import Logger
import logging import logging
from typing import Dict
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
"""Handle ast.Name expressions."""
if expr.id in local_sym_tab:
var = local_sym_tab[expr.id].var
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type
else:
logger.info(f"Undefined variable {expr.id}")
return None
def _handle_constant_expr(expr: ast.Constant):
"""Handle ast.Constant expressions."""
if isinstance(expr.value, int):
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
else:
logger.info("Unsupported constant type")
return None
def _handle_attribute_expr(
expr: ast.Attribute,
local_sym_tab: Dict,
structs_sym_tab: Dict,
builder: ir.IRBuilder,
):
"""Handle ast.Attribute expressions for struct field access."""
if isinstance(expr.value, ast.Name):
var_name = expr.value.id
attr_name = expr.attr
if var_name in local_sym_tab:
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
metadata = structs_sym_tab[var_metadata]
if attr_name in metadata.fields:
gep = metadata.gep(builder, var_ptr, attr_name)
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
return None
def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder):
"""Handle deref function calls."""
logger.info(f"Handling deref {ast.dump(expr)}")
if len(expr.args) != 1:
logger.info("deref takes exactly one argument")
return None
arg = expr.args[0]
if (
isinstance(arg, ast.Call)
and isinstance(arg.func, ast.Name)
and arg.func.id == "deref"
):
logger.info("Multiple deref not supported")
return None
if isinstance(arg, ast.Name):
if arg.id in local_sym_tab:
arg_ptr = local_sym_tab[arg.id].var
else:
logger.info(f"Undefined variable {arg.id}")
return None
else:
logger.info("Unsupported argument type for deref")
return None
if arg_ptr is None:
logger.info("Failed to evaluate deref argument")
return None
# Load the value from pointer
val = builder.load(arg_ptr)
return val, local_sym_tab[arg.id].ir_type
def eval_expr( def eval_expr(
func, func,
module, module,
@ -17,64 +99,28 @@ def eval_expr(
): ):
logger.info(f"Evaluating expression: {ast.dump(expr)}") logger.info(f"Evaluating expression: {ast.dump(expr)}")
if isinstance(expr, ast.Name): if isinstance(expr, ast.Name):
if expr.id in local_sym_tab: return _handle_name_expr(expr, local_sym_tab, builder)
var = local_sym_tab[expr.id].var
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type # return value and type
else:
logger.info(f"Undefined variable {expr.id}")
return None
elif isinstance(expr, ast.Constant): elif isinstance(expr, ast.Constant):
if isinstance(expr.value, int): return _handle_constant_expr(expr)
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
else:
logger.info("Unsupported constant type")
return None
elif isinstance(expr, ast.Call): elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder)
# delayed import to avoid circular dependency # delayed import to avoid circular dependency
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
if isinstance(expr.func, ast.Name): if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler(
# check deref expr.func.id
if expr.func.id == "deref": ):
logger.info(f"Handling deref {ast.dump(expr)}") return handle_helper_call(
if len(expr.args) != 1: expr,
logger.info("deref takes exactly one argument") module,
return None builder,
arg = expr.args[0] func,
if ( local_sym_tab,
isinstance(arg, ast.Call) map_sym_tab,
and isinstance(arg.func, ast.Name) structs_sym_tab,
and arg.func.id == "deref" )
):
logger.info("Multiple deref not supported")
return None
if isinstance(arg, ast.Name):
if arg.id in local_sym_tab:
arg = local_sym_tab[arg.id].var
else:
logger.info(f"Undefined variable {arg.id}")
return None
if arg is None:
logger.info("Failed to evaluate deref argument")
return None
# Since we are handling only name case, directly take type from sym tab
val = builder.load(arg)
return val, local_sym_tab[expr.args[0].id].ir_type
# check for helpers
if HelperHandlerRegistry.has_handler(expr.func.id):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr.func, ast.Attribute): elif isinstance(expr.func, ast.Attribute):
logger.info(f"Handling method call: {ast.dump(expr.func)}") logger.info(f"Handling method call: {ast.dump(expr.func)}")
if isinstance(expr.func.value, ast.Call) and isinstance( if isinstance(expr.func.value, ast.Call) and isinstance(
@ -106,19 +152,7 @@ def eval_expr(
structs_sym_tab, structs_sym_tab,
) )
elif isinstance(expr, ast.Attribute): elif isinstance(expr, ast.Attribute):
if isinstance(expr.value, ast.Name): return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
var_name = expr.value.id
attr_name = expr.attr
if var_name in local_sym_tab:
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
metadata = structs_sym_tab[var_metadata]
if attr_name in metadata.fields:
gep = metadata.gep(builder, var_ptr, attr_name)
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
logger.info("Unsupported expression evaluation") logger.info("Unsupported expression evaluation")
return None return None

View File

@ -0,0 +1,3 @@
from .functions_pass import func_proc
__all__ = ["func_proc"]

View File

@ -0,0 +1,22 @@
from typing import Dict
class StatementHandlerRegistry:
"""Registry for statement handlers."""
_handlers: Dict = {}
@classmethod
def register(cls, stmt_type):
"""Register a handler for a specific statement type."""
def decorator(handler):
cls._handlers[stmt_type] = handler
return handler
return decorator
@classmethod
def get_handler(cls, stmt_type):
"""Get the handler for a specific statement type."""
return cls._handlers.get(stmt_type, None)

View File

@ -4,10 +4,10 @@ import logging
from typing import Any from typing import Any
from dataclasses import dataclass from dataclasses import dataclass
from .helper import HelperHandlerRegistry, handle_helper_call from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
from .type_deducer import ctypes_to_ir from pythonbpf.type_deducer import ctypes_to_ir
from .binary_ops import handle_binary_op from pythonbpf.binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr from pythonbpf.expr_pass import eval_expr, handle_expr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -146,8 +146,7 @@ def handle_assign(
local_sym_tab[var_name].var, local_sym_tab[var_name].var,
) )
logger.info( logger.info(
f"Assigned {call_type} constant " f"Assigned {call_type} constant {rval.args[0].value} to {var_name}"
f"{rval.args[0].value} to {var_name}"
) )
elif HelperHandlerRegistry.has_handler(call_type): elif HelperHandlerRegistry.has_handler(call_type):
# var = builder.alloca(ir.IntType(64), name=var_name) # var = builder.alloca(ir.IntType(64), name=var_name)
@ -192,8 +191,23 @@ def handle_assign(
elif isinstance(rval.func, ast.Attribute): elif isinstance(rval.func, ast.Attribute):
logger.info(f"Assignment call attribute: {ast.dump(rval.func)}") logger.info(f"Assignment call attribute: {ast.dump(rval.func)}")
if isinstance(rval.func.value, ast.Name): if isinstance(rval.func.value, ast.Name):
# TODO: probably a struct access if rval.func.value.id in map_sym_tab:
logger.info(f"TODO STRUCT ACCESS {ast.dump(rval)}") map_name = rval.func.value.id
method_name = rval.func.attr
if HelperHandlerRegistry.has_handler(method_name):
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
builder.store(val[0], local_sym_tab[var_name].var)
else:
# TODO: probably a struct access
logger.info(f"TODO STRUCT ACCESS {ast.dump(rval)}")
elif isinstance(rval.func.value, ast.Call) and isinstance( elif isinstance(rval.func.value, ast.Call) and isinstance(
rval.func.value.func, ast.Name rval.func.value.func, ast.Name
): ):
@ -218,9 +232,7 @@ def handle_assign(
else: else:
logger.info("Unsupported assignment call function type") logger.info("Unsupported assignment call function type")
elif isinstance(rval, ast.BinOp): elif isinstance(rval, ast.BinOp):
handle_binary_op( handle_binary_op(rval, builder, var_name, local_sym_tab)
rval, module, builder, var_name, local_sym_tab, map_sym_tab, func
)
else: else:
logger.info("Unsupported assignment value type") logger.info("Unsupported assignment value type")
@ -372,24 +384,48 @@ def process_stmt(
) )
elif isinstance(stmt, ast.Return): elif isinstance(stmt, ast.Return):
if stmt.value is None: if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(32), 0)) builder.ret(ir.Constant(ir.IntType(64), 0))
did_return = True did_return = True
elif ( elif (
isinstance(stmt.value, ast.Call) isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name) and isinstance(stmt.value.func, ast.Name)
and len(stmt.value.args) == 1 and len(stmt.value.args) == 1
and isinstance(stmt.value.args[0], ast.Constant)
and isinstance(stmt.value.args[0].value, int)
): ):
call_type = stmt.value.func.id if isinstance(stmt.value.args[0], ast.Constant) and isinstance(
if ctypes_to_ir(call_type) != ret_type: stmt.value.args[0].value, int
raise ValueError( ):
"Return type mismatch: expected" call_type = stmt.value.func.id
f"{ctypes_to_ir(call_type)}, got {call_type}" if ctypes_to_ir(call_type) != ret_type:
) raise ValueError(
else: "Return type mismatch: expected"
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value)) f"{ctypes_to_ir(call_type)}, got {call_type}"
)
else:
builder.ret(ir.Constant(ret_type, stmt.value.args[0].value))
did_return = True
elif isinstance(stmt.value.args[0], ast.BinOp):
# TODO: Should be routed through eval_expr
val = handle_binary_op(stmt.value.args[0], builder, None, local_sym_tab)
if val is None:
raise ValueError("Failed to evaluate return expression")
if val[1] != ret_type:
raise ValueError(
f"Return type mismatch: expected {ret_type}, got {val[1]}"
)
builder.ret(val[0])
did_return = True did_return = True
elif isinstance(stmt.value.args[0], ast.Name):
if stmt.value.args[0].id in local_sym_tab:
var = local_sym_tab[stmt.value.args[0].id].var
val = builder.load(var)
if val.type != ret_type:
raise ValueError(
f"Return type mismatch: expected {ret_type}, got {val.type}"
)
builder.ret(val)
did_return = True
else:
raise ValueError("Failed to evaluate return expression")
elif isinstance(stmt.value, ast.Name): elif isinstance(stmt.value, ast.Name):
if stmt.value.id == "XDP_PASS": if stmt.value.id == "XDP_PASS":
builder.ret(ir.Constant(ret_type, 2)) builder.ret(ir.Constant(ret_type, 2))
@ -442,6 +478,9 @@ def allocate_mem(
continue continue
var_name = target.id var_name = target.id
rval = stmt.value rval = stmt.value
if var_name in local_sym_tab:
logger.info(f"Variable {var_name} already allocated")
continue
if isinstance(rval, ast.Call): if isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name): if isinstance(rval.func, ast.Name):
call_type = rval.func.id call_type = rval.func.id
@ -470,8 +509,7 @@ def allocate_mem(
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
has_metadata = True has_metadata = True
logger.info( logger.info(
f"Pre-allocated variable {var_name} " f"Pre-allocated variable {var_name} for struct {call_type}"
f"for struct {call_type}"
) )
elif isinstance(rval.func, ast.Attribute): elif isinstance(rval.func, ast.Attribute):
ir_type = ir.PointerType(ir.IntType(64)) ir_type = ir.PointerType(ir.IntType(64))
@ -555,7 +593,7 @@ def process_func_body(
) )
if not did_return: if not did_return:
builder.ret(ir.Constant(ir.IntType(32), 0)) builder.ret(ir.Constant(ir.IntType(64), 0))
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab): def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):

View File

@ -62,7 +62,7 @@ def bpf_map_lookup_elem_emitter(
""" """
if not call.args or len(call.args) != 1: if not call.args or len(call.args) != 1:
raise ValueError( raise ValueError(
"Map lookup expects exactly one argument (key), got " f"{len(call.args)}" f"Map lookup expects exactly one argument (key), got {len(call.args)}"
) )
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -145,8 +145,7 @@ def bpf_map_update_elem_emitter(
""" """
if not call.args or len(call.args) < 2 or len(call.args) > 3: if not call.args or len(call.args) < 2 or len(call.args) > 3:
raise ValueError( raise ValueError(
"Map update expects 2 or 3 args (key, value, flags), " f"Map update expects 2 or 3 args (key, value, flags), got {len(call.args)}"
f"got {len(call.args)}"
) )
key_arg = call.args[0] key_arg = call.args[0]
@ -196,7 +195,7 @@ def bpf_map_delete_elem_emitter(
""" """
if not call.args or len(call.args) != 1: if not call.args or len(call.args) != 1:
raise ValueError( raise ValueError(
"Map delete expects exactly one argument (key), got " f"{len(call.args)}" f"Map delete expects exactly one argument (key), got {len(call.args)}"
) )
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -255,7 +254,7 @@ def bpf_perf_event_output_handler(
): ):
if len(call.args) != 1: if len(call.args) != 1:
raise ValueError( raise ValueError(
"Perf event output expects exactly one argument, " f"got {len(call.args)}" f"Perf event output expects exactly one argument, got {len(call.args)}"
) )
data_arg = call.args[0] data_arg = call.args[0]
ctx_ptr = func.args[0] # First argument to the function is ctx ctx_ptr = func.args[0] # First argument to the function is ctx

View File

@ -270,7 +270,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
val = builder.sext(val, ir.IntType(64)) val = builder.sext(val, ir.IntType(64))
else: else:
logger.warning( logger.warning(
"Only int and ptr supported in bpf_printk args. " "Others default to 0." "Only int and ptr supported in bpf_printk args. Others default to 0."
) )
val = ir.Constant(ir.IntType(64), 0) val = ir.Constant(ir.IntType(64), 0)
return val return val

View File

@ -278,9 +278,7 @@ def process_bpf_map(func_node, module):
if handler: if handler:
return handler(map_name, rval, module) return handler(map_name, rval, module)
else: else:
logger.warning( logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap")
f"Unknown map type " f"{rval.func.id}, defaulting to HashMap"
)
return process_hash_map(map_name, rval, module) return process_hash_map(map_name, rval, module)
else: else:
raise ValueError("Function under @map must return a map") raise ValueError("Function under @map must return a map")

View File

@ -4,6 +4,18 @@ from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64
# NOTE: I have decided to not fix this example for now.
# The issue is in line 31, where we are passing an expression.
# The update helper expects a pointer type. But the problem is
# that we must allocate the space for said pointer in the first
# basic block. As that usage is in a different basic block, we
# are unable to cast the expression to a pointer type. (as we never
# allocated space for it).
# Shall we change our space allocation logic? That allows users to
# spam the same helper with the same args, and still run out of
# stack space. So we consider this usage invalid for now.
# Might fix it later.
@bpf @bpf
@map @map
@ -14,12 +26,12 @@ def count() -> HashMap:
@bpf @bpf
@section("xdp") @section("xdp")
def hello_world(ctx: c_void_p) -> c_int64: def hello_world(ctx: c_void_p) -> c_int64:
prev = count().lookup(0) prev = count.lookup(0)
if prev: if prev:
count().update(0, prev + 1) count.update(0, prev + 1)
return XDP_PASS return XDP_PASS
else: else:
count().update(0, 1) count.update(0, 1)
return XDP_PASS return XDP_PASS

View File

@ -0,0 +1,40 @@
from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64
# NOTE: This example exposes the problems with our typing system.
# We can't do steps on line 25 and 27.
# prev is of type i64**. For prev + 1, we deref it down to i64
# To assign it back to prev, we need to go back to i64**.
# We cannot allocate space for the intermediate type now.
# We probably need to track the ref/deref chain for each variable.
@bpf
@map
def count() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=1)
@bpf
@section("xdp")
def hello_world(ctx: c_void_p) -> c_int64:
prev = count.lookup(0)
if prev:
prev = prev + 1
count.update(0, prev)
return XDP_PASS
else:
count.update(0, 1)
return XDP_PASS
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -3,9 +3,9 @@ from ctypes import c_void_p, c_int64
@bpf @bpf
@section("sometag1") @section("tracepoint/syscalls/sys_enter_sync")
def sometag(ctx: c_void_p) -> c_int64: def sometag(ctx: c_void_p) -> c_int64:
a = 1 + 2 + 1 a = 1 + 2 + 1 + 12 + 13
print(f"{a}") print(f"{a}")
return c_int64(0) return c_int64(0)

View File

@ -0,0 +1,20 @@
from pythonbpf import compile, bpf, section, bpfglobal
from ctypes import c_void_p, c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_sync")
def sometag(ctx: c_void_p) -> c_int64:
b = 1 + 2
a = 1 + b
print(f"{a}")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -1,3 +1,5 @@
import logging
from pythonbpf import compile, bpf, section, bpfglobal from pythonbpf import compile, bpf, section, bpfglobal
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64
@ -5,8 +7,7 @@ from ctypes import c_void_p, c_int64
@bpf @bpf
@section("sometag1") @section("sometag1")
def sometag(ctx: c_void_p) -> c_int64: def sometag(ctx: c_void_p) -> c_int64:
b = 1 + 2 a = 1 - 1
a = 1 + b
return c_int64(a) return c_int64(a)
@ -16,4 +17,4 @@ def LICENSE() -> str:
return "GPL" return "GPL"
compile() compile(loglevel=logging.INFO)