mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2026-04-02 11:51:27 +00:00
Compare commits
21 Commits
copilot/cr
...
testing-fr
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d555ddd47 | |||
| ee444447b9 | |||
| da57911122 | |||
| 0498885f71 | |||
| 3f4f95115f | |||
| f2b9767098 | |||
| 0e087b9ea5 | |||
| ccbdfee9de | |||
| 61bca6bad9 | |||
| 305a8ba9e3 | |||
| bdcfe47601 | |||
| 3396d84e26 | |||
| c22911daaf | |||
| c04e32bd24 | |||
| b7f917c3c2 | |||
| b025ae7158 | |||
| ec4a6852ec | |||
| 45d85c416f | |||
| cb18ab67d9 | |||
| d1872bc868 | |||
| 2c2ed473d8 |
4
.github/workflows/python-publish.yml
vendored
4
.github/workflows/python-publish.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
||||
python -m build
|
||||
|
||||
- name: Upload distributions
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Retrieve release distributions
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@v8
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
21
.readthedocs.yaml
Normal file
21
.readthedocs.yaml
Normal file
@ -0,0 +1,21 @@
|
||||
# Read the Docs configuration file for Sphinx
|
||||
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
||||
|
||||
version: 2
|
||||
|
||||
# Set the OS, Python version, and other tools you might need
|
||||
build:
|
||||
os: ubuntu-24.04
|
||||
tools:
|
||||
python: "3.12"
|
||||
|
||||
# Build documentation in the "docs/" directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
|
||||
# Optionally, but recommended, declare the Python requirements
|
||||
python:
|
||||
install:
|
||||
- requirements: docs/requirements.txt
|
||||
- method: pip
|
||||
path: .
|
||||
@ -27,7 +27,7 @@ def trace_completion(ctx: struct_pt_regs) -> c_int64:
|
||||
print(f"{data_len} {cmd_flags:x} {delta_us}\n")
|
||||
start.delete(req_ptr)
|
||||
|
||||
return c_int64(0)
|
||||
return 0 # type: ignore [return-value]
|
||||
|
||||
|
||||
@bpf
|
||||
@ -36,7 +36,7 @@ def trace_start(ctx1: struct_pt_regs) -> c_int32:
|
||||
req = ctx1.di
|
||||
ts = ktime()
|
||||
start.update(req, ts)
|
||||
return c_int32(0)
|
||||
return 0 # type: ignore [return-value]
|
||||
|
||||
|
||||
@bpf
|
||||
|
||||
16
Makefile
16
Makefile
@ -1,10 +1,22 @@
|
||||
install:
|
||||
pip install -e .
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
clean:
|
||||
rm -rf build dist *.egg-info
|
||||
rm -rf examples/*.ll examples/*.o
|
||||
rm -rf htmlcov .coverage
|
||||
|
||||
test:
|
||||
pytest tests/ -v --tb=short -m "not verifier"
|
||||
|
||||
test-cov:
|
||||
pytest tests/ -v --tb=short -m "not verifier" \
|
||||
--cov=pythonbpf --cov-report=term-missing --cov-report=html
|
||||
|
||||
test-verifier:
|
||||
@echo "NOTE: verifier tests require sudo and bpftool. Uses sudo .venv/bin/python3."
|
||||
pytest tests/test_verifier.py -v --tb=short -m verifier
|
||||
|
||||
all: clean install
|
||||
|
||||
.PHONY: all clean
|
||||
.PHONY: all clean install test test-cov test-verifier
|
||||
|
||||
@ -19,6 +19,8 @@
|
||||
<a href="https://pepy.tech/project/pythonbpf"><img src="https://pepy.tech/badge/pythonbpf" alt="Downloads"></a>
|
||||
<!-- Build & CI -->
|
||||
<a href="https://github.com/pythonbpf/python-bpf/actions"><img src="https://github.com/pythonbpf/python-bpf/actions/workflows/python-publish.yml/badge.svg" alt="Build Status"></a>
|
||||
<!-- Documentation -->
|
||||
<a href="https://python-bpf.readthedocs.io/en/latest/?badge=latest"><img src="https://readthedocs.org/projects/python-bpf/badge/?version=latest" alt="Documentation Status"></a>
|
||||
<!-- Meta -->
|
||||
<a href="https://github.com/pythonbpf/python-bpf/blob/main/LICENSE"><img src="https://img.shields.io/github/license/pythonbpf/python-bpf" alt="License"></a>
|
||||
</p>
|
||||
|
||||
@ -41,7 +41,30 @@ docs = [
|
||||
"sphinx-rtd-theme>=2.0",
|
||||
"sphinx-copybutton",
|
||||
]
|
||||
test = [
|
||||
"pytest>=8.0",
|
||||
"pytest-cov>=5.0",
|
||||
"tomli>=2.0; python_version < '3.11'",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["pythonbpf*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
python_files = ["test_*.py"]
|
||||
markers = [
|
||||
"verifier: requires sudo/root for kernel verifier tests (not run by default)",
|
||||
"vmlinux: requires vmlinux.py for current kernel",
|
||||
]
|
||||
log_cli = false
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["pythonbpf"]
|
||||
omit = ["*/vmlinux*", "*/__pycache__/*"]
|
||||
|
||||
[tool.coverage.report]
|
||||
show_missing = true
|
||||
skip_covered = false
|
||||
|
||||
@ -26,9 +26,7 @@ def create_targets_and_rvals(stmt):
|
||||
return stmt.targets, [stmt.value]
|
||||
|
||||
|
||||
def handle_assign_allocation(
|
||||
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
):
|
||||
def handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab):
|
||||
"""Handle memory allocation for assignment statements."""
|
||||
|
||||
logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}")
|
||||
@ -59,7 +57,7 @@ def handle_assign_allocation(
|
||||
# Determine type and allocate based on rval
|
||||
if isinstance(rval, ast.Call):
|
||||
_allocate_for_call(
|
||||
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
builder, var_name, rval, local_sym_tab, compilation_context
|
||||
)
|
||||
elif isinstance(rval, ast.Constant):
|
||||
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
|
||||
@ -71,7 +69,7 @@ def handle_assign_allocation(
|
||||
elif isinstance(rval, ast.Attribute):
|
||||
# Struct field-to-variable assignment (a = dat.fld)
|
||||
_allocate_for_attribute(
|
||||
builder, var_name, rval, local_sym_tab, structs_sym_tab
|
||||
builder, var_name, rval, local_sym_tab, compilation_context
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@ -79,10 +77,9 @@ def handle_assign_allocation(
|
||||
)
|
||||
|
||||
|
||||
def _allocate_for_call(
|
||||
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
):
|
||||
def _allocate_for_call(builder, var_name, rval, local_sym_tab, compilation_context):
|
||||
"""Allocate memory for variable assigned from a call."""
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
if isinstance(rval.func, ast.Name):
|
||||
call_type = rval.func.id
|
||||
@ -149,7 +146,7 @@ def _allocate_for_call(
|
||||
elif isinstance(rval.func, ast.Attribute):
|
||||
# Map method calls - need double allocation for ptr handling
|
||||
_allocate_for_map_method(
|
||||
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
builder, var_name, rval, local_sym_tab, compilation_context
|
||||
)
|
||||
|
||||
else:
|
||||
@ -157,9 +154,11 @@ def _allocate_for_call(
|
||||
|
||||
|
||||
def _allocate_for_map_method(
|
||||
builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
builder, var_name, rval, local_sym_tab, compilation_context
|
||||
):
|
||||
"""Allocate memory for variable assigned from map method (double alloc)."""
|
||||
map_sym_tab = compilation_context.map_sym_tab
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
map_name = rval.func.value.id
|
||||
method_name = rval.func.attr
|
||||
@ -321,8 +320,12 @@ def _allocate_for_name(builder, var_name, rval, local_sym_tab):
|
||||
)
|
||||
|
||||
|
||||
def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_tab):
|
||||
def _allocate_for_attribute(
|
||||
builder, var_name, rval, local_sym_tab, compilation_context
|
||||
):
|
||||
"""Allocate memory for struct field-to-variable assignment (a = dat.fld)."""
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
if not isinstance(rval.value, ast.Name):
|
||||
logger.warning(f"Complex attribute access not supported for {var_name}")
|
||||
return
|
||||
|
||||
@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def handle_struct_field_assignment(
|
||||
func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, builder, target, rval, local_sym_tab
|
||||
):
|
||||
"""Handle struct field assignment (obj.field = value)."""
|
||||
|
||||
@ -24,7 +24,7 @@ def handle_struct_field_assignment(
|
||||
return
|
||||
|
||||
struct_type = local_sym_tab[var_name].metadata
|
||||
struct_info = structs_sym_tab[struct_type]
|
||||
struct_info = compilation_context.structs_sym_tab[struct_type]
|
||||
|
||||
if field_name not in struct_info.fields:
|
||||
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
|
||||
@ -33,9 +33,7 @@ def handle_struct_field_assignment(
|
||||
# Get field pointer and evaluate value
|
||||
field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
|
||||
field_type = struct_info.field_type(field_name)
|
||||
val_result = eval_expr(
|
||||
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab)
|
||||
|
||||
if val_result is None:
|
||||
logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
|
||||
@ -47,14 +45,11 @@ def handle_struct_field_assignment(
|
||||
if _is_char_array(field_type) and _is_i8_ptr(val_type):
|
||||
_copy_string_to_char_array(
|
||||
func,
|
||||
module,
|
||||
builder,
|
||||
val,
|
||||
field_ptr,
|
||||
field_type,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
logger.info(f"Copied string to char array {var_name}.{field_name}")
|
||||
return
|
||||
@ -66,14 +61,11 @@ def handle_struct_field_assignment(
|
||||
|
||||
def _copy_string_to_char_array(
|
||||
func,
|
||||
module,
|
||||
builder,
|
||||
src_ptr,
|
||||
dst_ptr,
|
||||
array_type,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
):
|
||||
"""Copy string (i8*) to char array ([N x i8]) using bpf_probe_read_kernel_str"""
|
||||
|
||||
@ -109,7 +101,7 @@ def _is_i8_ptr(ir_type):
|
||||
|
||||
|
||||
def handle_variable_assignment(
|
||||
func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, builder, var_name, rval, local_sym_tab
|
||||
):
|
||||
"""Handle single named variable assignment."""
|
||||
|
||||
@ -120,6 +112,8 @@ def handle_variable_assignment(
|
||||
var_ptr = local_sym_tab[var_name].var
|
||||
var_type = local_sym_tab[var_name].ir_type
|
||||
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
# NOTE: Special case for struct initialization
|
||||
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
|
||||
struct_name = rval.func.id
|
||||
@ -142,9 +136,7 @@ def handle_variable_assignment(
|
||||
logger.info(f"Assigned char array pointer to {var_name}")
|
||||
return True
|
||||
|
||||
val_result = eval_expr(
|
||||
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
val_result = eval_expr(func, compilation_context, builder, rval, local_sym_tab)
|
||||
if val_result is None:
|
||||
logger.error(f"Failed to evaluate value for {var_name}")
|
||||
return False
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import ast
|
||||
from llvmlite import ir
|
||||
from .context import CompilationContext
|
||||
from .license_pass import license_processing
|
||||
from .functions import func_proc
|
||||
from .maps import maps_proc
|
||||
@ -67,9 +68,10 @@ def find_bpf_chunks(tree):
|
||||
return bpf_functions
|
||||
|
||||
|
||||
def processor(source_code, filename, module):
|
||||
def processor(source_code, filename, compilation_context):
|
||||
tree = ast.parse(source_code, filename)
|
||||
logger.debug(ast.dump(tree, indent=4))
|
||||
module = compilation_context.module
|
||||
|
||||
bpf_chunks = find_bpf_chunks(tree)
|
||||
for func_node in bpf_chunks:
|
||||
@ -81,15 +83,18 @@ def processor(source_code, filename, module):
|
||||
if vmlinux_symtab:
|
||||
handler = VmlinuxHandler.initialize(vmlinux_symtab)
|
||||
VmlinuxHandlerRegistry.set_handler(handler)
|
||||
compilation_context.vmlinux_handler = handler
|
||||
|
||||
populate_global_symbol_table(tree, module)
|
||||
license_processing(tree, module)
|
||||
globals_processing(tree, module)
|
||||
structs_sym_tab = structs_proc(tree, module, bpf_chunks)
|
||||
map_sym_tab = maps_proc(tree, module, bpf_chunks, structs_sym_tab)
|
||||
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
|
||||
populate_global_symbol_table(tree, compilation_context)
|
||||
license_processing(tree, compilation_context)
|
||||
globals_processing(tree, compilation_context)
|
||||
structs_sym_tab = structs_proc(tree, compilation_context, bpf_chunks)
|
||||
|
||||
globals_list_creation(tree, module)
|
||||
map_sym_tab = maps_proc(tree, compilation_context, bpf_chunks)
|
||||
|
||||
func_proc(tree, compilation_context, bpf_chunks)
|
||||
|
||||
globals_list_creation(tree, compilation_context)
|
||||
return structs_sym_tab, map_sym_tab
|
||||
|
||||
|
||||
@ -104,6 +109,8 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
|
||||
module.data_layout = "e-m:e-p:64:64-i64:64-i128:128-n32:64-S128"
|
||||
module.triple = "bpf"
|
||||
|
||||
compilation_context = CompilationContext(module)
|
||||
|
||||
if not hasattr(module, "_debug_compile_unit"):
|
||||
debug_generator = DebugInfoGenerator(module)
|
||||
debug_generator.generate_file_metadata(filename, os.path.dirname(filename))
|
||||
@ -116,7 +123,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
|
||||
True,
|
||||
)
|
||||
|
||||
structs_sym_tab, maps_sym_tab = processor(source, filename, module)
|
||||
structs_sym_tab, maps_sym_tab = processor(source, filename, compilation_context)
|
||||
|
||||
wchar_size = module.add_metadata(
|
||||
[
|
||||
|
||||
82
pythonbpf/context.py
Normal file
82
pythonbpf/context.py
Normal file
@ -0,0 +1,82 @@
|
||||
from llvmlite import ir
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pythonbpf.structs.struct_type import StructType
|
||||
from pythonbpf.maps.maps_utils import MapSymbol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScratchPoolManager:
|
||||
"""Manage the temporary helper variables in local_sym_tab"""
|
||||
|
||||
def __init__(self):
|
||||
self._counters = {}
|
||||
|
||||
@property
|
||||
def counter(self):
|
||||
return sum(self._counters.values())
|
||||
|
||||
def reset(self):
|
||||
self._counters.clear()
|
||||
logger.debug("Scratch pool counter reset to 0")
|
||||
|
||||
def _get_type_name(self, ir_type):
|
||||
if isinstance(ir_type, ir.PointerType):
|
||||
return "ptr"
|
||||
elif isinstance(ir_type, ir.IntType):
|
||||
return f"i{ir_type.width}"
|
||||
elif isinstance(ir_type, ir.ArrayType):
|
||||
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
|
||||
else:
|
||||
return str(ir_type).replace(" ", "")
|
||||
|
||||
def get_next_temp(self, local_sym_tab, expected_type=None):
|
||||
# Default to i64 if no expected type provided
|
||||
type_name = self._get_type_name(expected_type) if expected_type else "i64"
|
||||
if type_name not in self._counters:
|
||||
self._counters[type_name] = 0
|
||||
|
||||
counter = self._counters[type_name]
|
||||
temp_name = f"__helper_temp_{type_name}_{counter}"
|
||||
self._counters[type_name] += 1
|
||||
|
||||
if temp_name not in local_sym_tab:
|
||||
raise ValueError(
|
||||
f"Scratch pool exhausted or inadequate: {temp_name}. "
|
||||
f"Type: {type_name} Counter: {counter}"
|
||||
)
|
||||
|
||||
logger.debug(f"Using {temp_name} for type {type_name}")
|
||||
return local_sym_tab[temp_name].var, temp_name
|
||||
|
||||
|
||||
class CompilationContext:
|
||||
"""
|
||||
Holds the state for a single compilation run.
|
||||
This replaces global mutable state modules.
|
||||
"""
|
||||
|
||||
def __init__(self, module: ir.Module):
|
||||
self.module = module
|
||||
|
||||
# Symbol tables
|
||||
self.global_sym_tab: list[ir.GlobalVariable] = []
|
||||
self.structs_sym_tab: dict[str, "StructType"] = {}
|
||||
self.map_sym_tab: dict[str, "MapSymbol"] = {}
|
||||
|
||||
# Helper management
|
||||
self.scratch_pool = ScratchPoolManager()
|
||||
|
||||
# Vmlinux handling (optional, specialized)
|
||||
self.vmlinux_handler = None # Can be VmlinuxHandler instance
|
||||
|
||||
# Current function context (optional, if needed globally during function processing)
|
||||
self.current_func = None
|
||||
|
||||
def reset(self):
|
||||
"""Reset state between functions if necessary, though new context per compile is preferred."""
|
||||
self.scratch_pool.reset()
|
||||
self.current_func = None
|
||||
@ -9,12 +9,8 @@ class CallHandlerRegistry:
|
||||
cls._handler = handler
|
||||
|
||||
@classmethod
|
||||
def handle_call(
|
||||
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
):
|
||||
def handle_call(cls, call, compilation_context, builder, func, local_sym_tab):
|
||||
"""Handle a call using the registered handler"""
|
||||
if cls._handler is None:
|
||||
return None
|
||||
return cls._handler(
|
||||
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
return cls._handler(call, compilation_context, builder, func, local_sym_tab)
|
||||
|
||||
@ -37,7 +37,7 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
|
||||
raise SyntaxError(f"Undefined variable {expr.id}")
|
||||
|
||||
|
||||
def _handle_constant_expr(module, builder, expr: ast.Constant):
|
||||
def _handle_constant_expr(compilation_context, builder, expr: ast.Constant):
|
||||
"""Handle ast.Constant expressions."""
|
||||
if isinstance(expr.value, int) or isinstance(expr.value, bool):
|
||||
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
|
||||
@ -48,7 +48,9 @@ def _handle_constant_expr(module, builder, expr: ast.Constant):
|
||||
str_constant = ir.Constant(str_type, bytearray(str_bytes))
|
||||
|
||||
# Create global variable
|
||||
global_str = ir.GlobalVariable(module, str_type, name=str_name)
|
||||
global_str = ir.GlobalVariable(
|
||||
compilation_context.module, str_type, name=str_name
|
||||
)
|
||||
global_str.linkage = "internal"
|
||||
global_str.global_constant = True
|
||||
global_str.initializer = str_constant
|
||||
@ -64,10 +66,11 @@ def _handle_attribute_expr(
|
||||
func,
|
||||
expr: ast.Attribute,
|
||||
local_sym_tab: Dict,
|
||||
structs_sym_tab: Dict,
|
||||
compilation_context,
|
||||
builder: ir.IRBuilder,
|
||||
):
|
||||
"""Handle ast.Attribute expressions for struct field access."""
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
if isinstance(expr.value, ast.Name):
|
||||
var_name = expr.value.id
|
||||
attr_name = expr.attr
|
||||
@ -157,9 +160,7 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_operand_value(
|
||||
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
def get_operand_value(func, compilation_context, operand, builder, local_sym_tab):
|
||||
"""Extract the value from an operand, handling variables and constants."""
|
||||
logger.info(f"Getting operand value for: {ast.dump(operand)}")
|
||||
if isinstance(operand, ast.Name):
|
||||
@ -187,13 +188,11 @@ def get_operand_value(
|
||||
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
|
||||
elif isinstance(operand, ast.BinOp):
|
||||
res = _handle_binary_op_impl(
|
||||
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, operand, builder, local_sym_tab
|
||||
)
|
||||
return res
|
||||
else:
|
||||
res = eval_expr(
|
||||
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
res = eval_expr(func, compilation_context, builder, operand, local_sym_tab)
|
||||
if res is None:
|
||||
raise ValueError(f"Failed to evaluate call expression: {operand}")
|
||||
val, _ = res
|
||||
@ -205,15 +204,13 @@ def get_operand_value(
|
||||
raise TypeError(f"Unsupported operand type: {type(operand)}")
|
||||
|
||||
|
||||
def _handle_binary_op_impl(
|
||||
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
def _handle_binary_op_impl(func, compilation_context, rval, builder, local_sym_tab):
|
||||
op = rval.op
|
||||
left = get_operand_value(
|
||||
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, rval.left, builder, local_sym_tab
|
||||
)
|
||||
right = get_operand_value(
|
||||
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, rval.right, builder, local_sym_tab
|
||||
)
|
||||
logger.info(f"left is {left}, right is {right}, op is {op}")
|
||||
|
||||
@ -249,16 +246,14 @@ def _handle_binary_op_impl(
|
||||
|
||||
def _handle_binary_op(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
rval,
|
||||
builder,
|
||||
var_name,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab=None,
|
||||
):
|
||||
result = _handle_binary_op_impl(
|
||||
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, rval, builder, local_sym_tab
|
||||
)
|
||||
if var_name and var_name in local_sym_tab:
|
||||
logger.info(
|
||||
@ -275,12 +270,10 @@ def _handle_binary_op(
|
||||
|
||||
def _handle_ctypes_call(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab=None,
|
||||
):
|
||||
"""Handle ctypes type constructor calls."""
|
||||
if len(expr.args) != 1:
|
||||
@ -290,12 +283,10 @@ def _handle_ctypes_call(
|
||||
arg = expr.args[0]
|
||||
val = eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
arg,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if val is None:
|
||||
logger.info("Failed to evaluate argument to ctypes constructor")
|
||||
@ -344,9 +335,7 @@ def _handle_ctypes_call(
|
||||
return value, expected_type
|
||||
|
||||
|
||||
def _handle_compare(
|
||||
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
def _handle_compare(func, compilation_context, builder, cond, local_sym_tab):
|
||||
"""Handle ast.Compare expressions."""
|
||||
|
||||
if len(cond.ops) != 1 or len(cond.comparators) != 1:
|
||||
@ -354,21 +343,17 @@ def _handle_compare(
|
||||
return None
|
||||
lhs = eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
cond.left,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
rhs = eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
cond.comparators[0],
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
if lhs is None or rhs is None:
|
||||
@ -382,12 +367,10 @@ def _handle_compare(
|
||||
|
||||
def _handle_unary_op(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr: ast.UnaryOp,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab=None,
|
||||
):
|
||||
"""Handle ast.UnaryOp expressions."""
|
||||
if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub):
|
||||
@ -395,7 +378,7 @@ def _handle_unary_op(
|
||||
return None
|
||||
|
||||
operand = get_operand_value(
|
||||
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, expr.operand, builder, local_sym_tab
|
||||
)
|
||||
if operand is None:
|
||||
logger.error("Failed to evaluate operand for unary operation")
|
||||
@ -418,7 +401,7 @@ def _handle_unary_op(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
|
||||
def _handle_and_op(func, builder, expr, local_sym_tab, compilation_context):
|
||||
"""Handle `and` boolean operations."""
|
||||
|
||||
logger.debug(f"Handling 'and' operator with {len(expr.values)} operands")
|
||||
@ -433,7 +416,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_
|
||||
|
||||
# Evaluate current operand
|
||||
operand_result = eval_expr(
|
||||
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, builder, value, local_sym_tab
|
||||
)
|
||||
if operand_result is None:
|
||||
logger.error(f"Failed to evaluate operand {i} in 'and' expression")
|
||||
@ -471,7 +454,7 @@ def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_
|
||||
return phi, ir.IntType(1)
|
||||
|
||||
|
||||
def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
|
||||
def _handle_or_op(func, builder, expr, local_sym_tab, compilation_context):
|
||||
"""Handle `or` boolean operations."""
|
||||
|
||||
logger.debug(f"Handling 'or' operator with {len(expr.values)} operands")
|
||||
@ -486,7 +469,7 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t
|
||||
|
||||
# Evaluate current operand
|
||||
operand_result = eval_expr(
|
||||
func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, builder, value, local_sym_tab
|
||||
)
|
||||
if operand_result is None:
|
||||
logger.error(f"Failed to evaluate operand {i} in 'or' expression")
|
||||
@ -526,23 +509,17 @@ def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_t
|
||||
|
||||
def _handle_boolean_op(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr: ast.BoolOp,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab=None,
|
||||
):
|
||||
"""Handle `and` and `or` boolean operations."""
|
||||
|
||||
if isinstance(expr.op, ast.And):
|
||||
return _handle_and_op(
|
||||
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
return _handle_and_op(func, builder, expr, local_sym_tab, compilation_context)
|
||||
elif isinstance(expr.op, ast.Or):
|
||||
return _handle_or_op(
|
||||
func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
return _handle_or_op(func, builder, expr, local_sym_tab, compilation_context)
|
||||
else:
|
||||
logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}")
|
||||
return None
|
||||
@ -555,12 +532,10 @@ def _handle_boolean_op(
|
||||
|
||||
def _handle_vmlinux_cast(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab=None,
|
||||
):
|
||||
# handle expressions such as struct_request(ctx.di) where struct_request is a vmlinux
|
||||
# struct and ctx.di is a pointer to a struct but is actually represented as a c_uint64
|
||||
@ -576,12 +551,10 @@ def _handle_vmlinux_cast(
|
||||
# Evaluate the argument (e.g., ctx.di which is a c_uint64)
|
||||
arg_result = eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr.args[0],
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
if arg_result is None:
|
||||
@ -614,18 +587,17 @@ def _handle_vmlinux_cast(
|
||||
|
||||
def _handle_user_defined_struct_cast(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
):
|
||||
"""Handle user-defined struct cast expressions like iphdr(nh).
|
||||
|
||||
This casts a pointer/integer value to a pointer to the user-defined struct,
|
||||
similar to how vmlinux struct casts work but for user-defined @struct types.
|
||||
"""
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
if len(expr.args) != 1:
|
||||
logger.info("User-defined struct cast takes exactly one argument")
|
||||
return None
|
||||
@ -643,12 +615,10 @@ def _handle_user_defined_struct_cast(
|
||||
# an address/pointer value)
|
||||
arg_result = eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr.args[0],
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
if arg_result is None:
|
||||
@ -683,30 +653,28 @@ def _handle_user_defined_struct_cast(
|
||||
|
||||
def eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab=None,
|
||||
):
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
logger.info(f"Evaluating expression: {ast.dump(expr)}")
|
||||
if isinstance(expr, ast.Name):
|
||||
return _handle_name_expr(expr, local_sym_tab, builder)
|
||||
elif isinstance(expr, ast.Constant):
|
||||
return _handle_constant_expr(module, builder, expr)
|
||||
return _handle_constant_expr(compilation_context, builder, expr)
|
||||
elif isinstance(expr, ast.Call):
|
||||
if isinstance(expr.func, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
|
||||
expr.func.id
|
||||
):
|
||||
return _handle_vmlinux_cast(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
|
||||
return _handle_deref_call(expr, local_sym_tab, builder)
|
||||
@ -714,26 +682,23 @@ def eval_expr(
|
||||
if isinstance(expr.func, ast.Name) and is_ctypes(expr.func.id):
|
||||
return _handle_ctypes_call(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if isinstance(expr.func, ast.Name) and (expr.func.id in structs_sym_tab):
|
||||
return _handle_user_defined_struct_cast(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
# NOTE: Updated handle_call signature
|
||||
result = CallHandlerRegistry.handle_call(
|
||||
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
expr, compilation_context, builder, func, local_sym_tab
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
@ -742,30 +707,24 @@ def eval_expr(
|
||||
return None
|
||||
elif isinstance(expr, ast.Attribute):
|
||||
return _handle_attribute_expr(
|
||||
func, expr, local_sym_tab, structs_sym_tab, builder
|
||||
func, expr, local_sym_tab, compilation_context, builder
|
||||
)
|
||||
elif isinstance(expr, ast.BinOp):
|
||||
return _handle_binary_op(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
expr,
|
||||
builder,
|
||||
None,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
elif isinstance(expr, ast.Compare):
|
||||
return _handle_compare(
|
||||
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
return _handle_compare(func, compilation_context, builder, expr, local_sym_tab)
|
||||
elif isinstance(expr, ast.UnaryOp):
|
||||
return _handle_unary_op(
|
||||
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
return _handle_unary_op(func, compilation_context, builder, expr, local_sym_tab)
|
||||
elif isinstance(expr, ast.BoolOp):
|
||||
return _handle_boolean_op(
|
||||
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
func, compilation_context, builder, expr, local_sym_tab
|
||||
)
|
||||
logger.info("Unsupported expression evaluation")
|
||||
return None
|
||||
@ -773,12 +732,10 @@ def eval_expr(
|
||||
|
||||
def handle_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
expr,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
):
|
||||
"""Handle expression statements in the function body."""
|
||||
logger.info(f"Handling expression: {ast.dump(expr)}")
|
||||
@ -786,12 +743,10 @@ def handle_expr(
|
||||
if isinstance(call, ast.Call):
|
||||
eval_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
call,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
else:
|
||||
logger.info("Unsupported expression type")
|
||||
|
||||
@ -4,7 +4,6 @@ import logging
|
||||
|
||||
from pythonbpf.helper import (
|
||||
HelperHandlerRegistry,
|
||||
reset_scratch_pool,
|
||||
)
|
||||
from pythonbpf.type_deducer import ctypes_to_ir
|
||||
from pythonbpf.expr import (
|
||||
@ -76,36 +75,30 @@ def count_temps_in_call(call_node, local_sym_tab):
|
||||
|
||||
|
||||
def handle_if_allocation(
|
||||
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
compilation_context, builder, stmt, func, ret_type, local_sym_tab
|
||||
):
|
||||
"""Recursively handle allocations in if/else branches."""
|
||||
if stmt.body:
|
||||
allocate_mem(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt.body,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if stmt.orelse:
|
||||
allocate_mem(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt.orelse,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
|
||||
def allocate_mem(
|
||||
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
):
|
||||
def allocate_mem(compilation_context, builder, body, func, ret_type, local_sym_tab):
|
||||
max_temps_needed = {}
|
||||
|
||||
def merge_type_counts(count_dict):
|
||||
@ -137,19 +130,15 @@ def allocate_mem(
|
||||
# Handle allocations
|
||||
if isinstance(stmt, ast.If):
|
||||
handle_if_allocation(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
handle_assign_allocation(
|
||||
builder, stmt, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
handle_assign_allocation(compilation_context, builder, stmt, local_sym_tab)
|
||||
|
||||
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
|
||||
|
||||
@ -161,9 +150,7 @@ def allocate_mem(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def handle_assign(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
):
|
||||
def handle_assign(func, compilation_context, builder, stmt, local_sym_tab):
|
||||
"""Handle assignment statements in the function body."""
|
||||
|
||||
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
|
||||
@ -175,13 +162,11 @@ def handle_assign(
|
||||
var_name = target.id
|
||||
result = handle_variable_assignment(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
var_name,
|
||||
rval,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
if not result:
|
||||
logger.error(f"Failed to handle assignment to {var_name}")
|
||||
@ -191,13 +176,11 @@ def handle_assign(
|
||||
# NOTE: Struct field assignment case: pkt.field = value
|
||||
handle_struct_field_assignment(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
target,
|
||||
rval,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
continue
|
||||
|
||||
@ -205,18 +188,12 @@ def handle_assign(
|
||||
logger.error(f"Unsupported assignment target: {ast.dump(target)}")
|
||||
|
||||
|
||||
def handle_cond(
|
||||
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
val = eval_expr(
|
||||
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)[0]
|
||||
def handle_cond(func, compilation_context, builder, cond, local_sym_tab):
|
||||
val = eval_expr(func, compilation_context, builder, cond, local_sym_tab)[0]
|
||||
return convert_to_bool(builder, val)
|
||||
|
||||
|
||||
def handle_if(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None
|
||||
):
|
||||
def handle_if(func, compilation_context, builder, stmt, local_sym_tab):
|
||||
"""Handle if statements in the function body."""
|
||||
logger.info("Handling if statement")
|
||||
# start = builder.block.parent
|
||||
@ -227,9 +204,7 @@ def handle_if(
|
||||
else:
|
||||
else_block = None
|
||||
|
||||
cond = handle_cond(
|
||||
func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
)
|
||||
cond = handle_cond(func, compilation_context, builder, stmt.test, local_sym_tab)
|
||||
if else_block:
|
||||
builder.cbranch(cond, then_block, else_block)
|
||||
else:
|
||||
@ -237,9 +212,7 @@ def handle_if(
|
||||
|
||||
builder.position_at_end(then_block)
|
||||
for s in stmt.body:
|
||||
process_stmt(
|
||||
func, module, builder, s, local_sym_tab, map_sym_tab, structs_sym_tab, False
|
||||
)
|
||||
process_stmt(func, compilation_context, builder, s, local_sym_tab, False)
|
||||
if not builder.block.is_terminated:
|
||||
builder.branch(merge_block)
|
||||
|
||||
@ -248,12 +221,10 @@ def handle_if(
|
||||
for s in stmt.orelse:
|
||||
process_stmt(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
s,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
False,
|
||||
)
|
||||
if not builder.block.is_terminated:
|
||||
@ -262,21 +233,25 @@ def handle_if(
|
||||
builder.position_at_end(merge_block)
|
||||
|
||||
|
||||
def handle_return(builder, stmt, local_sym_tab, ret_type):
|
||||
def handle_return(builder, stmt, local_sym_tab, ret_type, compilation_context=None):
|
||||
logger.info(f"Handling return statement: {ast.dump(stmt)}")
|
||||
if stmt.value is None:
|
||||
return handle_none_return(builder)
|
||||
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
|
||||
return handle_xdp_return(stmt, builder, ret_type)
|
||||
else:
|
||||
# Fallback for now if ctx not passed, but caller should pass it
|
||||
if compilation_context is None:
|
||||
raise RuntimeError(
|
||||
"CompilationContext required for return statement evaluation"
|
||||
)
|
||||
|
||||
val = eval_expr(
|
||||
func=None,
|
||||
module=None,
|
||||
compilation_context=compilation_context,
|
||||
builder=builder,
|
||||
expr=stmt.value,
|
||||
local_sym_tab=local_sym_tab,
|
||||
map_sym_tab={},
|
||||
structs_sym_tab={},
|
||||
)
|
||||
logger.info(f"Evaluated return expression to {val}")
|
||||
builder.ret(val[0])
|
||||
@ -285,43 +260,34 @@ def handle_return(builder, stmt, local_sym_tab, ret_type):
|
||||
|
||||
def process_stmt(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
did_return,
|
||||
ret_type=ir.IntType(64),
|
||||
):
|
||||
logger.info(f"Processing statement: {ast.dump(stmt)}")
|
||||
reset_scratch_pool()
|
||||
# Use context scratch pool
|
||||
compilation_context.scratch_pool.reset()
|
||||
|
||||
if isinstance(stmt, ast.Expr):
|
||||
handle_expr(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
elif isinstance(stmt, ast.Assign):
|
||||
handle_assign(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
)
|
||||
handle_assign(func, compilation_context, builder, stmt, local_sym_tab)
|
||||
elif isinstance(stmt, ast.AugAssign):
|
||||
raise SyntaxError("Augmented assignment not supported")
|
||||
elif isinstance(stmt, ast.If):
|
||||
handle_if(
|
||||
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
|
||||
)
|
||||
handle_if(func, compilation_context, builder, stmt, local_sym_tab)
|
||||
elif isinstance(stmt, ast.Return):
|
||||
did_return = handle_return(
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
ret_type,
|
||||
builder, stmt, local_sym_tab, ret_type, compilation_context
|
||||
)
|
||||
return did_return
|
||||
|
||||
@ -332,13 +298,11 @@ def process_stmt(
|
||||
|
||||
|
||||
def process_func_body(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func_node,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
):
|
||||
"""Process the body of a bpf function"""
|
||||
# TODO: A lot. We just have print -> bpf_trace_printk for now
|
||||
@ -360,6 +324,7 @@ def process_func_body(
|
||||
raise TypeError(
|
||||
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
|
||||
)
|
||||
|
||||
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
|
||||
resolved_type = VmlinuxHandlerRegistry.get_struct_type(
|
||||
context_type_name
|
||||
@ -370,14 +335,12 @@ def process_func_body(
|
||||
|
||||
# pre-allocate dynamic variables
|
||||
local_sym_tab = allocate_mem(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func_node.body,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
local_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
logger.info(f"Local symbol table: {local_sym_tab.keys()}")
|
||||
@ -385,12 +348,10 @@ def process_func_body(
|
||||
for stmt in func_node.body:
|
||||
did_return = process_stmt(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
stmt,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
did_return,
|
||||
ret_type,
|
||||
)
|
||||
@ -399,9 +360,12 @@ def process_func_body(
|
||||
builder.ret(ir.Constant(ir.IntType(64), 0))
|
||||
|
||||
|
||||
def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_tab):
|
||||
def process_bpf_chunk(func_node, compilation_context, return_type):
|
||||
"""Process a single BPF chunk (function) and emit corresponding LLVM IR."""
|
||||
|
||||
# Set current function in context (optional but good for future)
|
||||
compilation_context.current_func = func_node
|
||||
|
||||
func_name = func_node.name
|
||||
|
||||
ret_type = return_type
|
||||
@ -413,7 +377,7 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
||||
param_types.append(ir.PointerType())
|
||||
|
||||
func_ty = ir.FunctionType(ret_type, param_types)
|
||||
func = ir.Function(module, func_ty, func_name)
|
||||
func = ir.Function(compilation_context.module, func_ty, func_name)
|
||||
|
||||
func.linkage = "dso_local"
|
||||
func.attributes.add("nounwind")
|
||||
@ -433,13 +397,11 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
||||
builder = ir.IRBuilder(block)
|
||||
|
||||
process_func_body(
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func_node,
|
||||
func,
|
||||
ret_type,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
return func
|
||||
|
||||
@ -449,23 +411,32 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
|
||||
def func_proc(tree, compilation_context, chunks):
|
||||
"""Process all functions decorated with @bpf and @bpfglobal"""
|
||||
for func_node in chunks:
|
||||
# Ignore structs and maps
|
||||
# Check against the lists
|
||||
if (
|
||||
func_node.name in compilation_context.structs_sym_tab
|
||||
or func_node.name in compilation_context.map_sym_tab
|
||||
):
|
||||
continue
|
||||
|
||||
# Also check decorators to be sure
|
||||
decorators = [d.id for d in func_node.decorator_list if isinstance(d, ast.Name)]
|
||||
if "struct" in decorators or "map" in decorators:
|
||||
continue
|
||||
|
||||
if is_global_function(func_node):
|
||||
continue
|
||||
func_type = get_probe_string(func_node)
|
||||
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
|
||||
|
||||
func = process_bpf_chunk(
|
||||
func_node,
|
||||
module,
|
||||
ctypes_to_ir(infer_return_type(func_node)),
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
return_type = ctypes_to_ir(infer_return_type(func_node))
|
||||
func = process_bpf_chunk(func_node, compilation_context, return_type)
|
||||
|
||||
logger.info(f"Generating Debug Info for Function {func_node.name}")
|
||||
generate_function_debug_info(func_node, module, func)
|
||||
generate_function_debug_info(func_node, compilation_context.module, func)
|
||||
|
||||
|
||||
# TODO: WIP, for string assignment to fixed-size arrays
|
||||
|
||||
@ -7,11 +7,11 @@ from .type_deducer import ctypes_to_ir
|
||||
|
||||
logger: Logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: this is going to be a huge fuck of a headache in the future.
|
||||
global_sym_tab = []
|
||||
|
||||
|
||||
def populate_global_symbol_table(tree, module: ir.Module):
|
||||
def populate_global_symbol_table(tree, compilation_context):
|
||||
"""
|
||||
compilation_context: CompilationContext
|
||||
"""
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
for dec in node.decorator_list:
|
||||
@ -23,16 +23,16 @@ def populate_global_symbol_table(tree, module: ir.Module):
|
||||
and isinstance(dec.args[0], ast.Constant)
|
||||
and isinstance(dec.args[0].value, str)
|
||||
):
|
||||
global_sym_tab.append(node)
|
||||
compilation_context.global_sym_tab.append(node)
|
||||
elif isinstance(dec, ast.Name) and dec.id == "bpfglobal":
|
||||
global_sym_tab.append(node)
|
||||
compilation_context.global_sym_tab.append(node)
|
||||
|
||||
elif isinstance(dec, ast.Name) and dec.id == "map":
|
||||
global_sym_tab.append(node)
|
||||
compilation_context.global_sym_tab.append(node)
|
||||
return False
|
||||
|
||||
|
||||
def emit_global(module: ir.Module, node, name):
|
||||
def _emit_global(module: ir.Module, node, name):
|
||||
logger.info(f"global identifier {name} processing")
|
||||
# deduce LLVM type from the annotated return
|
||||
if not isinstance(node.returns, ast.Name):
|
||||
@ -74,9 +74,12 @@ def emit_global(module: ir.Module, node, name):
|
||||
return gvar
|
||||
|
||||
|
||||
def globals_processing(tree, module):
|
||||
def globals_processing(tree, compilation_context):
|
||||
"""Process stuff decorated with @bpf and @bpfglobal except license and return the section name"""
|
||||
globals_sym_tab = []
|
||||
# Local tracking for duplicate checking if needed, or we can iterate context
|
||||
# But for now, we process specific nodes
|
||||
|
||||
current_globals = []
|
||||
|
||||
for node in tree.body:
|
||||
# Skip non-assignment and non-function nodes
|
||||
@ -90,10 +93,10 @@ def globals_processing(tree, module):
|
||||
continue
|
||||
|
||||
# Check for duplicate names
|
||||
if name in globals_sym_tab:
|
||||
if name in current_globals:
|
||||
raise SyntaxError(f"ERROR: Global name '{name}' previously defined")
|
||||
else:
|
||||
globals_sym_tab.append(name)
|
||||
current_globals.append(name)
|
||||
|
||||
if isinstance(node, ast.FunctionDef) and node.name != "LICENSE":
|
||||
decorators = [
|
||||
@ -108,14 +111,14 @@ def globals_processing(tree, module):
|
||||
node.body[0].value, (ast.Constant, ast.Name, ast.Call)
|
||||
)
|
||||
):
|
||||
emit_global(module, node, name)
|
||||
_emit_global(compilation_context.module, node, name)
|
||||
else:
|
||||
raise SyntaxError(f"ERROR: Invalid syntax for {name} global")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
|
||||
def _emit_llvm_compiler_used(module: ir.Module, names: list[str]):
|
||||
"""
|
||||
Emit the @llvm.compiler.used global given a list of function/global names.
|
||||
"""
|
||||
@ -137,8 +140,9 @@ def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
|
||||
gv.section = "llvm.metadata"
|
||||
|
||||
|
||||
def globals_list_creation(tree, module: ir.Module):
|
||||
def globals_list_creation(tree, compilation_context):
|
||||
collected = ["LICENSE"]
|
||||
module = compilation_context.module
|
||||
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
@ -160,4 +164,4 @@ def globals_list_creation(tree, module: ir.Module):
|
||||
elif isinstance(dec, ast.Name) and dec.id == "map":
|
||||
collected.append(node.name)
|
||||
|
||||
emit_llvm_compiler_used(module, collected)
|
||||
_emit_llvm_compiler_used(module, collected)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from .helper_registry import HelperHandlerRegistry
|
||||
from .helper_utils import reset_scratch_pool
|
||||
|
||||
from .bpf_helper_handler import (
|
||||
handle_helper_call,
|
||||
emit_probe_read_kernel_str_call,
|
||||
@ -28,9 +28,7 @@ def _register_helper_handler():
|
||||
"""Register helper call handler with the expression evaluator"""
|
||||
from pythonbpf.expr.expr_pass import CallHandlerRegistry
|
||||
|
||||
def helper_call_handler(
|
||||
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
|
||||
):
|
||||
def helper_call_handler(call, compilation_context, builder, func, local_sym_tab):
|
||||
"""Check if call is a helper and handle it"""
|
||||
import ast
|
||||
|
||||
@ -39,17 +37,16 @@ def _register_helper_handler():
|
||||
if HelperHandlerRegistry.has_handler(call.func.id):
|
||||
return handle_helper_call(
|
||||
call,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
# Check for method calls (e.g., map.lookup())
|
||||
elif isinstance(call.func, ast.Attribute):
|
||||
method_name = call.func.attr
|
||||
map_sym_tab = compilation_context.map_sym_tab
|
||||
|
||||
# Handle: my_map.lookup(key)
|
||||
if isinstance(call.func.value, ast.Name):
|
||||
@ -58,12 +55,10 @@ def _register_helper_handler():
|
||||
if HelperHandlerRegistry.has_handler(method_name):
|
||||
return handle_helper_call(
|
||||
call,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
structs_sym_tab,
|
||||
)
|
||||
|
||||
return None
|
||||
@ -76,7 +71,6 @@ _register_helper_handler()
|
||||
|
||||
__all__ = [
|
||||
"HelperHandlerRegistry",
|
||||
"reset_scratch_pool",
|
||||
"handle_helper_call",
|
||||
"emit_probe_read_kernel_str_call",
|
||||
"emit_probe_read_kernel_call",
|
||||
|
||||
@ -50,12 +50,10 @@ class BPFHelperID(Enum):
|
||||
def bpf_ktime_get_ns_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_ktime_get_ns helper function call.
|
||||
@ -77,12 +75,10 @@ def bpf_ktime_get_ns_emitter(
|
||||
def bpf_get_current_cgroup_id(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_current_cgroup_id helper function call.
|
||||
@ -104,12 +100,10 @@ def bpf_get_current_cgroup_id(
|
||||
def bpf_map_lookup_elem_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_map_lookup_elem helper function call.
|
||||
@ -119,7 +113,7 @@ def bpf_map_lookup_elem_emitter(
|
||||
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
|
||||
)
|
||||
key_ptr = get_or_create_ptr_from_arg(
|
||||
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
func, compilation_context, call.args[0], builder, local_sym_tab
|
||||
)
|
||||
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||
|
||||
@ -147,12 +141,10 @@ def bpf_map_lookup_elem_emitter(
|
||||
def bpf_printk_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""Emit LLVM IR for bpf_printk helper function call."""
|
||||
if not hasattr(func, "_fmt_counter"):
|
||||
@ -165,16 +157,17 @@ def bpf_printk_emitter(
|
||||
if isinstance(call.args[0], ast.JoinedStr):
|
||||
args = handle_fstring_print(
|
||||
call.args[0],
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab,
|
||||
struct_sym_tab,
|
||||
)
|
||||
elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str):
|
||||
# TODO: We are only supporting single arguments for now.
|
||||
# In case of multiple args, the first one will be taken.
|
||||
args = simple_string_print(call.args[0].value, module, builder, func)
|
||||
args = simple_string_print(
|
||||
call.args[0].value, compilation_context, builder, func
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only simple strings or f-strings are supported in bpf_printk."
|
||||
@ -203,12 +196,10 @@ def bpf_printk_emitter(
|
||||
def bpf_map_update_elem_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_map_update_elem helper function call.
|
||||
@ -224,10 +215,10 @@ def bpf_map_update_elem_emitter(
|
||||
flags_arg = call.args[2] if len(call.args) > 2 else None
|
||||
|
||||
key_ptr = get_or_create_ptr_from_arg(
|
||||
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
func, compilation_context, key_arg, builder, local_sym_tab
|
||||
)
|
||||
value_ptr = get_or_create_ptr_from_arg(
|
||||
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
func, compilation_context, value_arg, builder, local_sym_tab
|
||||
)
|
||||
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
|
||||
|
||||
@ -262,12 +253,10 @@ def bpf_map_update_elem_emitter(
|
||||
def bpf_map_delete_elem_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_map_delete_elem helper function call.
|
||||
@ -278,7 +267,7 @@ def bpf_map_delete_elem_emitter(
|
||||
f"Map delete expects exactly one argument (key), got {len(call.args)}"
|
||||
)
|
||||
key_ptr = get_or_create_ptr_from_arg(
|
||||
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
func, compilation_context, call.args[0], builder, local_sym_tab
|
||||
)
|
||||
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||
|
||||
@ -306,12 +295,10 @@ def bpf_map_delete_elem_emitter(
|
||||
def bpf_get_current_comm_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_current_comm helper function call.
|
||||
@ -327,7 +314,7 @@ def bpf_get_current_comm_emitter(
|
||||
|
||||
# Extract buffer pointer and size
|
||||
buf_ptr, buf_size = get_buffer_ptr_and_size(
|
||||
buf_arg, builder, local_sym_tab, struct_sym_tab
|
||||
buf_arg, builder, local_sym_tab, compilation_context
|
||||
)
|
||||
|
||||
# Validate it's a char array
|
||||
@ -367,12 +354,10 @@ def bpf_get_current_comm_emitter(
|
||||
def bpf_get_current_pid_tgid_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
|
||||
@ -394,12 +379,10 @@ def bpf_get_current_pid_tgid_emitter(
|
||||
def bpf_perf_event_output_handler(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_perf_event_output helper function call.
|
||||
@ -412,7 +395,9 @@ def bpf_perf_event_output_handler(
|
||||
data_arg = call.args[0]
|
||||
ctx_ptr = func.args[0] # First argument to the function is ctx
|
||||
|
||||
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab)
|
||||
data_ptr, size_val = get_data_ptr_and_size(
|
||||
data_arg, local_sym_tab, compilation_context
|
||||
)
|
||||
|
||||
# BPF_F_CURRENT_CPU is -1 in 32 bit
|
||||
flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF)
|
||||
@ -445,12 +430,10 @@ def bpf_perf_event_output_handler(
|
||||
def bpf_ringbuf_output_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_ringbuf_output helper function call.
|
||||
@ -461,7 +444,9 @@ def bpf_ringbuf_output_emitter(
|
||||
f"Ringbuf output expects exactly one argument, got {len(call.args)}"
|
||||
)
|
||||
data_arg = call.args[0]
|
||||
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab)
|
||||
data_ptr, size_val = get_data_ptr_and_size(
|
||||
data_arg, local_sym_tab, compilation_context
|
||||
)
|
||||
flags_val = ir.Constant(ir.IntType(64), 0)
|
||||
|
||||
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||
@ -496,38 +481,32 @@ def bpf_ringbuf_output_emitter(
|
||||
def handle_output_helper(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Route output helper to the appropriate emitter based on map type.
|
||||
"""
|
||||
match map_sym_tab[map_ptr.name].type:
|
||||
match compilation_context.map_sym_tab[map_ptr.name].type:
|
||||
case BPFMapType.PERF_EVENT_ARRAY:
|
||||
return bpf_perf_event_output_handler(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab,
|
||||
struct_sym_tab,
|
||||
map_sym_tab,
|
||||
)
|
||||
case BPFMapType.RINGBUF:
|
||||
return bpf_ringbuf_output_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab,
|
||||
struct_sym_tab,
|
||||
map_sym_tab,
|
||||
)
|
||||
case _:
|
||||
logger.error("Unsupported map type for output helper.")
|
||||
@ -572,12 +551,10 @@ def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr):
|
||||
def bpf_probe_read_kernel_str_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""Emit LLVM IR for bpf_probe_read_kernel_str helper."""
|
||||
|
||||
@ -588,12 +565,12 @@ def bpf_probe_read_kernel_str_emitter(
|
||||
|
||||
# Get destination buffer (char array -> i8*)
|
||||
dst_ptr, dst_size = get_or_create_ptr_from_arg(
|
||||
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
func, compilation_context, call.args[0], builder, local_sym_tab
|
||||
)
|
||||
|
||||
# Get source pointer (evaluate expression)
|
||||
src_ptr, src_type = get_ptr_from_arg(
|
||||
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
call.args[1], func, compilation_context, builder, local_sym_tab
|
||||
)
|
||||
|
||||
# Emit the helper call
|
||||
@ -641,12 +618,10 @@ def emit_probe_read_kernel_call(builder, dst_ptr, dst_size, src_ptr):
|
||||
def bpf_probe_read_kernel_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""Emit LLVM IR for bpf_probe_read_kernel helper."""
|
||||
|
||||
@ -657,12 +632,12 @@ def bpf_probe_read_kernel_emitter(
|
||||
|
||||
# Get destination buffer (char array -> i8*)
|
||||
dst_ptr, dst_size = get_or_create_ptr_from_arg(
|
||||
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
func, compilation_context, call.args[0], builder, local_sym_tab
|
||||
)
|
||||
|
||||
# Get source pointer (evaluate expression)
|
||||
src_ptr, src_type = get_ptr_from_arg(
|
||||
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
call.args[1], func, compilation_context, builder, local_sym_tab
|
||||
)
|
||||
|
||||
# Emit the helper call
|
||||
@ -680,12 +655,10 @@ def bpf_probe_read_kernel_emitter(
|
||||
def bpf_get_prandom_u32_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_prandom_u32 helper function call.
|
||||
@ -710,12 +683,10 @@ def bpf_get_prandom_u32_emitter(
|
||||
def bpf_probe_read_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_probe_read helper function
|
||||
@ -726,31 +697,25 @@ def bpf_probe_read_emitter(
|
||||
return
|
||||
dst_ptr = get_or_create_ptr_from_arg(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
call.args[0],
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
ir.IntType(8),
|
||||
)
|
||||
size_val = get_int_value_from_arg(
|
||||
call.args[1],
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
)
|
||||
src_ptr = get_or_create_ptr_from_arg(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
call.args[2],
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
ir.IntType(8),
|
||||
)
|
||||
fn_type = ir.FunctionType(
|
||||
@ -783,12 +748,10 @@ def bpf_probe_read_emitter(
|
||||
def bpf_get_smp_processor_id_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_smp_processor_id helper function call.
|
||||
@ -810,12 +773,10 @@ def bpf_get_smp_processor_id_emitter(
|
||||
def bpf_get_current_uid_gid_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_current_uid_gid helper function call.
|
||||
@ -846,12 +807,10 @@ def bpf_get_current_uid_gid_emitter(
|
||||
def bpf_skb_store_bytes_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_skb_store_bytes helper function call.
|
||||
@ -875,30 +834,24 @@ def bpf_skb_store_bytes_emitter(
|
||||
offset_val = get_int_value_from_arg(
|
||||
call.args[0],
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
)
|
||||
from_ptr = get_or_create_ptr_from_arg(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
call.args[1],
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
args_signature[2],
|
||||
)
|
||||
len_val = get_int_value_from_arg(
|
||||
call.args[2],
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
)
|
||||
if len(call.args) == 4:
|
||||
flags_val = get_flags_val(call.args[3], builder, local_sym_tab)
|
||||
@ -940,12 +893,10 @@ def bpf_skb_store_bytes_emitter(
|
||||
def bpf_ringbuf_reserve_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_ringbuf_reserve helper function call.
|
||||
@ -960,11 +911,9 @@ def bpf_ringbuf_reserve_emitter(
|
||||
size_val = get_int_value_from_arg(
|
||||
call.args[0],
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
)
|
||||
|
||||
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
|
||||
@ -991,12 +940,10 @@ def bpf_ringbuf_reserve_emitter(
|
||||
def bpf_ringbuf_submit_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_ringbuf_submit helper function call.
|
||||
@ -1013,12 +960,10 @@ def bpf_ringbuf_submit_emitter(
|
||||
|
||||
data_ptr = get_or_create_ptr_from_arg(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
data_arg,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab,
|
||||
ir.PointerType(ir.IntType(8)),
|
||||
)
|
||||
|
||||
@ -1050,12 +995,10 @@ def bpf_ringbuf_submit_emitter(
|
||||
def bpf_get_stack_emitter(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
):
|
||||
"""
|
||||
Emit LLVM IR for bpf_get_stack helper function call.
|
||||
@ -1068,7 +1011,7 @@ def bpf_get_stack_emitter(
|
||||
buf_arg = call.args[0]
|
||||
flags_arg = call.args[1] if len(call.args) == 2 else None
|
||||
buf_ptr, buf_size = get_buffer_ptr_and_size(
|
||||
buf_arg, builder, local_sym_tab, struct_sym_tab
|
||||
buf_arg, builder, local_sym_tab, compilation_context
|
||||
)
|
||||
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
|
||||
if isinstance(flags_val, int):
|
||||
@ -1098,12 +1041,10 @@ def bpf_get_stack_emitter(
|
||||
|
||||
def handle_helper_call(
|
||||
call,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
map_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
):
|
||||
"""Process a BPF helper function call and emit the appropriate LLVM IR."""
|
||||
|
||||
@ -1117,14 +1058,14 @@ def handle_helper_call(
|
||||
return handler(
|
||||
call,
|
||||
map_ptr,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab,
|
||||
struct_sym_tab,
|
||||
map_sym_tab,
|
||||
)
|
||||
|
||||
map_sym_tab = compilation_context.map_sym_tab
|
||||
|
||||
# Handle direct function calls (e.g., print(), ktime())
|
||||
if isinstance(call.func, ast.Name):
|
||||
return invoke_helper(call.func.id)
|
||||
|
||||
@ -3,7 +3,6 @@ import logging
|
||||
|
||||
from llvmlite import ir
|
||||
from pythonbpf.expr import (
|
||||
get_operand_value,
|
||||
eval_expr,
|
||||
access_struct_field,
|
||||
)
|
||||
@ -11,58 +10,6 @@ from pythonbpf.expr import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ScratchPoolManager:
|
||||
"""Manage the temporary helper variables in local_sym_tab"""
|
||||
|
||||
def __init__(self):
|
||||
self._counters = {}
|
||||
|
||||
@property
|
||||
def counter(self):
|
||||
return sum(self._counters.values())
|
||||
|
||||
def reset(self):
|
||||
self._counters.clear()
|
||||
logger.debug("Scratch pool counter reset to 0")
|
||||
|
||||
def _get_type_name(self, ir_type):
|
||||
if isinstance(ir_type, ir.PointerType):
|
||||
return "ptr"
|
||||
elif isinstance(ir_type, ir.IntType):
|
||||
return f"i{ir_type.width}"
|
||||
elif isinstance(ir_type, ir.ArrayType):
|
||||
return f"[{ir_type.count}x{self._get_type_name(ir_type.element)}]"
|
||||
else:
|
||||
return str(ir_type).replace(" ", "")
|
||||
|
||||
def get_next_temp(self, local_sym_tab, expected_type=None):
|
||||
# Default to i64 if no expected type provided
|
||||
type_name = self._get_type_name(expected_type) if expected_type else "i64"
|
||||
if type_name not in self._counters:
|
||||
self._counters[type_name] = 0
|
||||
|
||||
counter = self._counters[type_name]
|
||||
temp_name = f"__helper_temp_{type_name}_{counter}"
|
||||
self._counters[type_name] += 1
|
||||
|
||||
if temp_name not in local_sym_tab:
|
||||
raise ValueError(
|
||||
f"Scratch pool exhausted or inadequate: {temp_name}. "
|
||||
f"Type: {type_name} Counter: {counter}"
|
||||
)
|
||||
|
||||
logger.debug(f"Using {temp_name} for type {type_name}")
|
||||
return local_sym_tab[temp_name].var, temp_name
|
||||
|
||||
|
||||
_temp_pool_manager = ScratchPoolManager() # Singleton instance
|
||||
|
||||
|
||||
def reset_scratch_pool():
|
||||
"""Reset the scratch pool counter"""
|
||||
_temp_pool_manager.reset()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Argument Preparation
|
||||
# ============================================================================
|
||||
@ -75,11 +22,15 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
|
||||
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
|
||||
|
||||
|
||||
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||
def create_int_constant_ptr(
|
||||
value, builder, compilation_context, local_sym_tab, int_width=64
|
||||
):
|
||||
"""Create a pointer to an integer constant."""
|
||||
|
||||
int_type = ir.IntType(int_width)
|
||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, int_type)
|
||||
ptr, temp_name = compilation_context.scratch_pool.get_next_temp(
|
||||
local_sym_tab, int_type
|
||||
)
|
||||
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
|
||||
const_val = ir.Constant(int_type, value)
|
||||
builder.store(const_val, ptr)
|
||||
@ -88,12 +39,10 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||
|
||||
def get_or_create_ptr_from_arg(
|
||||
func,
|
||||
module,
|
||||
compilation_context,
|
||||
arg,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
map_sym_tab,
|
||||
struct_sym_tab=None,
|
||||
expected_type=None,
|
||||
):
|
||||
"""Extract or create pointer from the call arguments."""
|
||||
@ -107,11 +56,14 @@ def get_or_create_ptr_from_arg(
|
||||
int_width = 64 # Default to i64
|
||||
if expected_type and isinstance(expected_type, ir.IntType):
|
||||
int_width = expected_type.width
|
||||
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width)
|
||||
ptr = create_int_constant_ptr(
|
||||
arg.value, builder, compilation_context, local_sym_tab, int_width
|
||||
)
|
||||
elif isinstance(arg, ast.Attribute):
|
||||
# A struct field
|
||||
struct_name = arg.value.id
|
||||
field_name = arg.attr
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
if not local_sym_tab or struct_name not in local_sym_tab:
|
||||
raise ValueError(f"Struct '{struct_name}' not found")
|
||||
@ -136,7 +88,7 @@ def get_or_create_ptr_from_arg(
|
||||
and field_type.element.width == 8
|
||||
):
|
||||
ptr, sz = get_char_array_ptr_and_size(
|
||||
arg, builder, local_sym_tab, struct_sym_tab, func
|
||||
arg, builder, local_sym_tab, compilation_context, func
|
||||
)
|
||||
if not ptr:
|
||||
raise ValueError("Failed to get char array pointer from struct field")
|
||||
@ -146,13 +98,15 @@ def get_or_create_ptr_from_arg(
|
||||
else:
|
||||
# NOTE: For any integer expression reaching this branch, it is probably a struct field or a binop
|
||||
# Evaluate the expression and store the result in a temp variable
|
||||
val = get_operand_value(
|
||||
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
)
|
||||
val = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
|
||||
if val:
|
||||
val = val[0]
|
||||
if val is None:
|
||||
raise ValueError("Failed to evaluate expression for helper arg.")
|
||||
|
||||
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab, expected_type)
|
||||
ptr, temp_name = compilation_context.scratch_pool.get_next_temp(
|
||||
local_sym_tab, expected_type
|
||||
)
|
||||
logger.info(f"Using temp variable '{temp_name}' for expression result")
|
||||
if (
|
||||
isinstance(val.type, ir.IntType)
|
||||
@ -188,8 +142,9 @@ def get_flags_val(arg, builder, local_sym_tab):
|
||||
)
|
||||
|
||||
|
||||
def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
|
||||
def get_data_ptr_and_size(data_arg, local_sym_tab, compilation_context):
|
||||
"""Extract data pointer and size information for perf event output."""
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
if isinstance(data_arg, ast.Name):
|
||||
data_name = data_arg.id
|
||||
if local_sym_tab and data_name in local_sym_tab:
|
||||
@ -213,8 +168,9 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
|
||||
)
|
||||
|
||||
|
||||
def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
|
||||
def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, compilation_context):
|
||||
"""Extract buffer pointer and size from either a struct field or variable."""
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
# Case 1: Struct field (obj.field)
|
||||
if isinstance(buf_arg, ast.Attribute):
|
||||
@ -268,9 +224,10 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
|
||||
|
||||
|
||||
def get_char_array_ptr_and_size(
|
||||
buf_arg, builder, local_sym_tab, struct_sym_tab, func=None
|
||||
buf_arg, builder, local_sym_tab, compilation_context, func=None
|
||||
):
|
||||
"""Get pointer to char array and its size."""
|
||||
struct_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
# Struct field: obj.field
|
||||
if isinstance(buf_arg, ast.Attribute) and isinstance(buf_arg.value, ast.Name):
|
||||
@ -351,14 +308,10 @@ def _is_char_array(ir_type):
|
||||
)
|
||||
|
||||
|
||||
def get_ptr_from_arg(
|
||||
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
):
|
||||
def get_ptr_from_arg(arg, func, compilation_context, builder, local_sym_tab):
|
||||
"""Evaluate argument and return pointer value"""
|
||||
|
||||
result = eval_expr(
|
||||
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
)
|
||||
result = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
|
||||
|
||||
if not result:
|
||||
raise ValueError("Failed to evaluate argument")
|
||||
@ -371,14 +324,10 @@ def get_ptr_from_arg(
|
||||
return val, val_type
|
||||
|
||||
|
||||
def get_int_value_from_arg(
|
||||
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
):
|
||||
def get_int_value_from_arg(arg, func, compilation_context, builder, local_sym_tab):
|
||||
"""Evaluate argument and return integer value"""
|
||||
|
||||
result = eval_expr(
|
||||
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
|
||||
)
|
||||
result = eval_expr(func, compilation_context, builder, arg, local_sym_tab)
|
||||
|
||||
if not result:
|
||||
raise ValueError("Failed to evaluate argument")
|
||||
|
||||
@ -9,8 +9,9 @@ from pythonbpf.helper.helper_utils import get_char_array_ptr_and_size
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def simple_string_print(string_value, module, builder, func):
|
||||
def simple_string_print(string_value, compilation_context, builder, func):
|
||||
"""Prepare arguments for bpf_printk from a simple string value"""
|
||||
module = compilation_context.module
|
||||
fmt_str = string_value + "\n\0"
|
||||
fmt_ptr = _create_format_string_global(fmt_str, func, module, builder)
|
||||
|
||||
@ -20,11 +21,10 @@ def simple_string_print(string_value, module, builder, func):
|
||||
|
||||
def handle_fstring_print(
|
||||
joined_str,
|
||||
module,
|
||||
compilation_context,
|
||||
builder,
|
||||
func,
|
||||
local_sym_tab=None,
|
||||
struct_sym_tab=None,
|
||||
):
|
||||
"""Handle f-string formatting for bpf_printk emitter."""
|
||||
fmt_parts = []
|
||||
@ -41,13 +41,13 @@ def handle_fstring_print(
|
||||
fmt_parts,
|
||||
exprs,
|
||||
local_sym_tab,
|
||||
struct_sym_tab,
|
||||
compilation_context.structs_sym_tab,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported f-string value type: {type(value)}")
|
||||
|
||||
fmt_str = "".join(fmt_parts)
|
||||
args = simple_string_print(fmt_str, module, builder, func)
|
||||
args = simple_string_print(fmt_str, compilation_context, builder, func)
|
||||
|
||||
# NOTE: Process expressions (limited to 3 due to BPF constraints)
|
||||
if len(exprs) > 3:
|
||||
@ -55,12 +55,7 @@ def handle_fstring_print(
|
||||
|
||||
for expr in exprs[:3]:
|
||||
arg_value = _prepare_expr_args(
|
||||
expr,
|
||||
func,
|
||||
module,
|
||||
builder,
|
||||
local_sym_tab,
|
||||
struct_sym_tab,
|
||||
expr, func, compilation_context, builder, local_sym_tab
|
||||
)
|
||||
args.append(arg_value)
|
||||
|
||||
@ -216,19 +211,19 @@ def _create_format_string_global(fmt_str, func, module, builder):
|
||||
return builder.bitcast(fmt_gvar, ir.PointerType())
|
||||
|
||||
|
||||
def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_tab):
|
||||
def _prepare_expr_args(expr, func, compilation_context, builder, local_sym_tab):
|
||||
"""Evaluate and prepare an expression to use as an arg for bpf_printk."""
|
||||
|
||||
# Special case: struct field char array needs pointer to first element
|
||||
if isinstance(expr, ast.Attribute):
|
||||
char_array_ptr, _ = get_char_array_ptr_and_size(
|
||||
expr, builder, local_sym_tab, struct_sym_tab, func
|
||||
expr, builder, local_sym_tab, compilation_context, func
|
||||
)
|
||||
if char_array_ptr:
|
||||
return char_array_ptr
|
||||
|
||||
# Regular expression evaluation
|
||||
val, _ = eval_expr(func, module, builder, expr, local_sym_tab, None, struct_sym_tab)
|
||||
val, _ = eval_expr(func, compilation_context, builder, expr, local_sym_tab)
|
||||
|
||||
if not val:
|
||||
logger.warning("Failed to evaluate expression for bpf_printk, defaulting to 0")
|
||||
|
||||
@ -23,7 +23,7 @@ def emit_license(module: ir.Module, license_str: str):
|
||||
return gvar
|
||||
|
||||
|
||||
def license_processing(tree, module):
|
||||
def license_processing(tree, compilation_context):
|
||||
"""Process the LICENSE function decorated with @bpf and @bpfglobal and return the section name"""
|
||||
count = 0
|
||||
for node in tree.body:
|
||||
@ -42,12 +42,14 @@ def license_processing(tree, module):
|
||||
and isinstance(node.body[0].value, ast.Constant)
|
||||
and isinstance(node.body[0].value.value, str)
|
||||
):
|
||||
emit_license(module, node.body[0].value.value)
|
||||
emit_license(
|
||||
compilation_context.module, node.body[0].value.value
|
||||
)
|
||||
return "LICENSE"
|
||||
else:
|
||||
logger.info("ERROR: LICENSE() must return a string literal")
|
||||
return None
|
||||
raise SyntaxError(
|
||||
"ERROR: LICENSE() must return a string literal"
|
||||
)
|
||||
else:
|
||||
logger.info("ERROR: LICENSE already defined")
|
||||
return None
|
||||
raise SyntaxError("ERROR: Multiple LICENSE globals defined")
|
||||
return None
|
||||
|
||||
@ -6,9 +6,10 @@ from .map_types import BPFMapType
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_map_debug_info(module, map_global, map_name, map_params, structs_sym_tab):
|
||||
def create_map_debug_info(compilation_context, map_global, map_name, map_params):
|
||||
"""Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
|
||||
generator = DebugInfoGenerator(module)
|
||||
generator = DebugInfoGenerator(compilation_context.module)
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
logger.info(f"Creating debug info for map {map_name} with params {map_params}")
|
||||
uint_type = generator.get_uint32_type()
|
||||
array_type = generator.create_array_type(
|
||||
@ -77,11 +78,9 @@ def create_map_debug_info(module, map_global, map_name, map_params, structs_sym_
|
||||
# Ideally we should expose a single create_map_debug_info function that handles all map types.
|
||||
# We can probably use a registry pattern to register different map types and their debug info generators.
|
||||
# map_params["type"] will be used to determine which generator to use.
|
||||
def create_ringbuf_debug_info(
|
||||
module, map_global, map_name, map_params, structs_sym_tab
|
||||
):
|
||||
def create_ringbuf_debug_info(compilation_context, map_global, map_name, map_params):
|
||||
"""Generate debug information metadata for BPF RINGBUF map"""
|
||||
generator = DebugInfoGenerator(module)
|
||||
generator = DebugInfoGenerator(compilation_context.module)
|
||||
|
||||
int_type = generator.get_int32_type()
|
||||
|
||||
|
||||
@ -12,14 +12,14 @@ from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
|
||||
logger: Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def maps_proc(tree, module, chunks, structs_sym_tab):
|
||||
def maps_proc(tree, compilation_context, chunks):
|
||||
"""Process all functions decorated with @map to find BPF maps"""
|
||||
map_sym_tab = {}
|
||||
map_sym_tab = compilation_context.map_sym_tab
|
||||
for func_node in chunks:
|
||||
if is_map(func_node):
|
||||
logger.info(f"Found BPF map: {func_node.name}")
|
||||
map_sym_tab[func_node.name] = process_bpf_map(
|
||||
func_node, module, structs_sym_tab
|
||||
func_node, compilation_context
|
||||
)
|
||||
return map_sym_tab
|
||||
|
||||
@ -31,7 +31,7 @@ def is_map(func_node):
|
||||
)
|
||||
|
||||
|
||||
def create_bpf_map(module, map_name, map_params):
|
||||
def create_bpf_map(compilation_context, map_name, map_params):
|
||||
"""Create a BPF map in the module with given parameters and debug info"""
|
||||
|
||||
# Create the anonymous struct type for BPF map
|
||||
@ -40,7 +40,9 @@ def create_bpf_map(module, map_name, map_params):
|
||||
)
|
||||
|
||||
# Create the global variable
|
||||
map_global = ir.GlobalVariable(module, map_struct_type, name=map_name)
|
||||
map_global = ir.GlobalVariable(
|
||||
compilation_context.module, map_struct_type, name=map_name
|
||||
)
|
||||
map_global.linkage = "dso_local"
|
||||
map_global.global_constant = False
|
||||
map_global.initializer = ir.Constant(map_struct_type, None)
|
||||
@ -55,6 +57,8 @@ def _parse_map_params(rval, expected_args=None):
|
||||
"""Parse map parameters from call arguments and keywords."""
|
||||
|
||||
params = {}
|
||||
|
||||
# TODO: Replace it with compilation_context.vmlinux_handler someday?
|
||||
handler = VmlinuxHandlerRegistry.get_handler()
|
||||
# Parse positional arguments
|
||||
if expected_args:
|
||||
@ -82,10 +86,11 @@ def _parse_map_params(rval, expected_args=None):
|
||||
def _get_vmlinux_enum(handler, name):
|
||||
if handler and handler.is_vmlinux_enum(name):
|
||||
return handler.get_vmlinux_enum_value(name)
|
||||
return None
|
||||
|
||||
|
||||
@MapProcessorRegistry.register("RingBuffer")
|
||||
def process_ringbuf_map(map_name, rval, module, structs_sym_tab):
|
||||
def process_ringbuf_map(map_name, rval, compilation_context):
|
||||
"""Process a BPF_RINGBUF map declaration"""
|
||||
logger.info(f"Processing Ringbuf: {map_name}")
|
||||
map_params = _parse_map_params(rval, expected_args=["max_entries"])
|
||||
@ -104,42 +109,55 @@ def process_ringbuf_map(map_name, rval, module, structs_sym_tab):
|
||||
|
||||
logger.info(f"Ringbuf map parameters: {map_params}")
|
||||
|
||||
map_global = create_bpf_map(module, map_name, map_params)
|
||||
map_global = create_bpf_map(compilation_context, map_name, map_params)
|
||||
create_ringbuf_debug_info(
|
||||
module, map_global.sym, map_name, map_params, structs_sym_tab
|
||||
compilation_context,
|
||||
map_global.sym,
|
||||
map_name,
|
||||
map_params,
|
||||
)
|
||||
return map_global
|
||||
|
||||
|
||||
@MapProcessorRegistry.register("HashMap")
|
||||
def process_hash_map(map_name, rval, module, structs_sym_tab):
|
||||
def process_hash_map(map_name, rval, compilation_context):
|
||||
"""Process a BPF_HASH map declaration"""
|
||||
logger.info(f"Processing HashMap: {map_name}")
|
||||
map_params = _parse_map_params(rval, expected_args=["key", "value", "max_entries"])
|
||||
map_params["type"] = BPFMapType.HASH
|
||||
|
||||
logger.info(f"Map parameters: {map_params}")
|
||||
map_global = create_bpf_map(module, map_name, map_params)
|
||||
map_global = create_bpf_map(compilation_context, map_name, map_params)
|
||||
# Generate debug info for BTF
|
||||
create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab)
|
||||
create_map_debug_info(
|
||||
compilation_context,
|
||||
map_global.sym,
|
||||
map_name,
|
||||
map_params,
|
||||
)
|
||||
return map_global
|
||||
|
||||
|
||||
@MapProcessorRegistry.register("PerfEventArray")
|
||||
def process_perf_event_map(map_name, rval, module, structs_sym_tab):
|
||||
def process_perf_event_map(map_name, rval, compilation_context):
|
||||
"""Process a BPF_PERF_EVENT_ARRAY map declaration"""
|
||||
logger.info(f"Processing PerfEventArray: {map_name}")
|
||||
map_params = _parse_map_params(rval, expected_args=["key_size", "value_size"])
|
||||
map_params["type"] = BPFMapType.PERF_EVENT_ARRAY
|
||||
|
||||
logger.info(f"Map parameters: {map_params}")
|
||||
map_global = create_bpf_map(module, map_name, map_params)
|
||||
map_global = create_bpf_map(compilation_context, map_name, map_params)
|
||||
# Generate debug info for BTF
|
||||
create_map_debug_info(module, map_global.sym, map_name, map_params, structs_sym_tab)
|
||||
create_map_debug_info(
|
||||
compilation_context,
|
||||
map_global.sym,
|
||||
map_name,
|
||||
map_params,
|
||||
)
|
||||
return map_global
|
||||
|
||||
|
||||
def process_bpf_map(func_node, module, structs_sym_tab):
|
||||
def process_bpf_map(func_node, compilation_context):
|
||||
"""Process a BPF map (a function decorated with @map)"""
|
||||
map_name = func_node.name
|
||||
logger.info(f"Processing BPF map: {map_name}")
|
||||
@ -158,9 +176,9 @@ def process_bpf_map(func_node, module, structs_sym_tab):
|
||||
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
|
||||
handler = MapProcessorRegistry.get_processor(rval.func.id)
|
||||
if handler:
|
||||
return handler(map_name, rval, module, structs_sym_tab)
|
||||
return handler(map_name, rval, compilation_context)
|
||||
else:
|
||||
logger.warning(f"Unknown map type {rval.func.id}, defaulting to HashMap")
|
||||
return process_hash_map(map_name, rval, module)
|
||||
return process_hash_map(map_name, rval, compilation_context)
|
||||
else:
|
||||
raise ValueError("Function under @map must return a map")
|
||||
|
||||
@ -14,14 +14,17 @@ logger = logging.getLogger(__name__)
|
||||
# Shall we just int64, int32 and uint32 similarly?
|
||||
|
||||
|
||||
def structs_proc(tree, module, chunks):
|
||||
def structs_proc(tree, compilation_context, chunks):
|
||||
"""Process all class definitions to find BPF structs"""
|
||||
structs_sym_tab = {}
|
||||
# Use the context's symbol table
|
||||
structs_sym_tab = compilation_context.structs_sym_tab
|
||||
|
||||
for cls_node in chunks:
|
||||
if is_bpf_struct(cls_node):
|
||||
logger.info(f"Found BPF struct: {cls_node.name}")
|
||||
struct_info = process_bpf_struct(cls_node, module)
|
||||
struct_info = process_bpf_struct(cls_node)
|
||||
structs_sym_tab[cls_node.name] = struct_info
|
||||
|
||||
return structs_sym_tab
|
||||
|
||||
|
||||
@ -32,7 +35,7 @@ def is_bpf_struct(cls_node):
|
||||
)
|
||||
|
||||
|
||||
def process_bpf_struct(cls_node, module):
|
||||
def process_bpf_struct(cls_node):
|
||||
"""Process a single BPF struct definition"""
|
||||
|
||||
fields = parse_struct_fields(cls_node)
|
||||
|
||||
116
tests/README.md
Normal file
116
tests/README.md
Normal file
@ -0,0 +1,116 @@
|
||||
# PythonBPF Test Suite
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
# Activate the venv and install test deps (once)
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[test]"
|
||||
|
||||
# Run the full suite (IR + LLC levels, no sudo required)
|
||||
make test
|
||||
|
||||
# Run with coverage report
|
||||
make test-cov
|
||||
```
|
||||
|
||||
## Test levels
|
||||
|
||||
Tests are split into three levels, each in a separate file:
|
||||
|
||||
| Level | File | What it checks | Needs sudo? |
|
||||
|---|---|---|---|
|
||||
| 1 — IR generation | `test_ir_generation.py` | `compile_to_ir()` completes without exception or `logging.ERROR` | No |
|
||||
| 2 — LLC compilation | `test_llc_compilation.py` | Level 1 + `llc` produces a non-empty `.o` file | No |
|
||||
| 3 — Kernel verifier | `test_verifier.py` | `bpftool prog load -d` exits 0 | Yes |
|
||||
|
||||
Levels 1 and 2 run together with `make test`. Level 3 is opt-in:
|
||||
|
||||
```bash
|
||||
make test-verifier # requires bpftool and sudo
|
||||
```
|
||||
|
||||
## Running a single test
|
||||
|
||||
Tests are parametrized by file path. Use `-k` to filter:
|
||||
|
||||
```bash
|
||||
# By file name
|
||||
pytest tests/ -v -k "and.py" -m "not verifier"
|
||||
|
||||
# By category
|
||||
pytest tests/ -v -k "conditionals" -m "not verifier"
|
||||
|
||||
# One specific level only
|
||||
pytest tests/test_ir_generation.py -v -k "hash_map.py"
|
||||
```
|
||||
|
||||
## Coverage report
|
||||
|
||||
```bash
|
||||
make test-cov
|
||||
```
|
||||
|
||||
- **Terminal**: shows per-file coverage with missing lines after the test run.
|
||||
- **HTML**: written to `htmlcov/index.html` — open in a browser for line-by-line detail.
|
||||
|
||||
```bash
|
||||
xdg-open htmlcov/index.html
|
||||
```
|
||||
|
||||
`htmlcov/` and `.coverage` are excluded from git (listed in `.gitignore` if not already).
|
||||
|
||||
## Expected failures (`test_config.toml`)
|
||||
|
||||
Known-broken tests are declared in `tests/test_config.toml`:
|
||||
|
||||
```toml
|
||||
[xfail]
|
||||
"failing_tests/my_test.py" = {reason = "...", level = "ir"}
|
||||
```
|
||||
|
||||
- `level = "ir"` — fails during IR generation; both IR and LLC tests are marked xfail.
|
||||
- `level = "llc"` — IR generates fine but `llc` rejects it; only the LLC test is marked xfail.
|
||||
|
||||
All xfails use `strict = True`: if a test starts **passing** it shows up as **XPASS** and is treated as a test failure. This is intentional — it means the bug was fixed and the test should be promoted to `passing_tests/`.
|
||||
|
||||
## Adding a new test
|
||||
|
||||
1. Create a `.py` file in `tests/passing_tests/<category>/` with the usual `@bpf` decorators and a `compile()` call at the bottom.
|
||||
2. Run `make test` — the file is discovered and tested automatically at all levels.
|
||||
3. If the test is expected to fail, add it to `tests/test_config.toml` instead of `passing_tests/`.
|
||||
|
||||
## Directory structure
|
||||
|
||||
```
|
||||
tests/
|
||||
├── README.md ← you are here
|
||||
├── conftest.py ← pytest config: discovery, xfail/skip injection, fixtures
|
||||
├── test_config.toml ← expected-failure list
|
||||
├── test_ir_generation.py ← Level 1
|
||||
├── test_llc_compilation.py ← Level 2
|
||||
├── test_verifier.py ← Level 3 (opt-in, sudo)
|
||||
├── framework/
|
||||
│ ├── bpf_test_case.py ← BpfTestCase dataclass
|
||||
│ ├── collector.py ← discovers test files, reads test_config.toml
|
||||
│ ├── compiler.py ← wrappers around compile_to_ir() + _run_llc()
|
||||
│ └── verifier.py ← bpftool subprocess wrapper
|
||||
├── passing_tests/ ← programs that should compile and verify cleanly
|
||||
└── failing_tests/ ← programs with known issues (declared in test_config.toml)
|
||||
```
|
||||
|
||||
## Known regressions (as of compilation-context PR merge)
|
||||
|
||||
Three tests in `passing_tests/` currently fail — these are bugs to fix, not xfails:
|
||||
|
||||
| Test | Error |
|
||||
|---|---|
|
||||
| `passing_tests/assign/comprehensive.py` | `TypeError: cannot store i64* to i64*` |
|
||||
| `passing_tests/helpers/bpf_probe_read.py` | `ValueError: 'ctx' not in local symbol table` |
|
||||
| `passing_tests/vmlinux/register_state_dump.py` | `KeyError: 'cs'` |
|
||||
|
||||
Nine tests in `failing_tests/` were fixed by the compilation-context PR (they show as XPASS). They can be moved to `passing_tests/` when convenient:
|
||||
|
||||
`assign/retype.py`, `conditionals/helper_cond.py`, `conditionals/oneline.py`,
|
||||
`direct_assign.py`, `globals.py`, `if.py`, `license.py` (IR only), `named_arg.py`,
|
||||
`xdp/xdp_test_1.py`
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
103
tests/conftest.py
Normal file
103
tests/conftest.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""
|
||||
pytest configuration for the PythonBPF test suite.
|
||||
|
||||
Test discovery:
|
||||
All .py files under tests/passing_tests/ and tests/failing_tests/ are
|
||||
collected as parametrized BPF test cases.
|
||||
|
||||
Markers applied automatically from test_config.toml:
|
||||
- xfail (strict=True): failing_tests/ entries that are expected to fail
|
||||
- skip: vmlinux tests when vmlinux.py is not importable
|
||||
|
||||
Run the suite:
|
||||
pytest tests/ -v -m "not verifier" # IR + LLC only (no sudo)
|
||||
pytest tests/ -v --cov=pythonbpf # with coverage
|
||||
pytest tests/test_verifier.py -m verifier # kernel verifier (sudo required)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.framework.collector import collect_all_test_files
|
||||
|
||||
# ── vmlinux availability ────────────────────────────────────────────────────
|
||||
|
||||
try:
|
||||
import vmlinux # noqa: F401
|
||||
|
||||
VMLINUX_AVAILABLE = True
|
||||
except ImportError:
|
||||
VMLINUX_AVAILABLE = False
|
||||
|
||||
|
||||
# ── shared fixture: collected test cases ───────────────────────────────────
|
||||
|
||||
|
||||
def _all_cases():
|
||||
return collect_all_test_files()
|
||||
|
||||
|
||||
# ── pytest_generate_tests: parametrize on bpf_test_file ───────────────────
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
if "bpf_test_file" in metafunc.fixturenames:
|
||||
cases = _all_cases()
|
||||
metafunc.parametrize(
|
||||
"bpf_test_file",
|
||||
[c.path for c in cases],
|
||||
ids=[c.rel_path for c in cases],
|
||||
)
|
||||
|
||||
|
||||
# ── pytest_collection_modifyitems: apply xfail / skip markers ─────────────
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items):
|
||||
case_map = {c.rel_path: c for c in _all_cases()}
|
||||
|
||||
for item in items:
|
||||
# Resolve the test case from the parametrize ID embedded in the node id.
|
||||
# Node id format: tests/test_foo.py::test_bar[passing_tests/helpers/pid.py]
|
||||
case = None
|
||||
for bracket in (item.callspec.id,) if hasattr(item, "callspec") else ():
|
||||
case = case_map.get(bracket)
|
||||
break
|
||||
|
||||
if case is None:
|
||||
continue
|
||||
|
||||
# vmlinux skip
|
||||
if case.needs_vmlinux and not VMLINUX_AVAILABLE:
|
||||
item.add_marker(
|
||||
pytest.mark.skip(reason="vmlinux.py not available for current kernel")
|
||||
)
|
||||
continue
|
||||
|
||||
# xfail (strict: XPASS counts as a test failure, alerting us to fixed bugs)
|
||||
if case.is_expected_fail:
|
||||
# Level "ir" → fails at IR generation: xfail both IR and LLC tests
|
||||
# Level "llc" → IR succeeds but LLC fails: only xfail the LLC test
|
||||
is_llc_test = item.nodeid.startswith("tests/test_llc_compilation.py")
|
||||
|
||||
apply_xfail = (case.xfail_level == "ir") or (
|
||||
case.xfail_level == "llc" and is_llc_test
|
||||
)
|
||||
if apply_xfail:
|
||||
item.add_marker(
|
||||
pytest.mark.xfail(
|
||||
reason=case.xfail_reason,
|
||||
strict=True,
|
||||
raises=Exception,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── caplog level fixture: capture ERROR+ from pythonbpf ───────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_log_level(caplog):
|
||||
with caplog.at_level(logging.ERROR, logger="pythonbpf"):
|
||||
yield
|
||||
0
tests/framework/__init__.py
Normal file
0
tests/framework/__init__.py
Normal file
17
tests/framework/bpf_test_case.py
Normal file
17
tests/framework/bpf_test_case.py
Normal file
@ -0,0 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class BpfTestCase:
|
||||
path: Path
|
||||
rel_path: str
|
||||
is_expected_fail: bool = False
|
||||
xfail_reason: str = ""
|
||||
xfail_level: str = "ir" # "ir" or "llc"
|
||||
needs_vmlinux: bool = False
|
||||
skip_reason: str = ""
|
||||
|
||||
@property
|
||||
def test_id(self) -> str:
|
||||
return self.rel_path.replace("/", "::")
|
||||
60
tests/framework/collector.py
Normal file
60
tests/framework/collector.py
Normal file
@ -0,0 +1,60 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
import tomli as tomllib
|
||||
|
||||
from .bpf_test_case import BpfTestCase
|
||||
|
||||
TESTS_DIR = Path(__file__).parent.parent
|
||||
CONFIG_FILE = TESTS_DIR / "test_config.toml"
|
||||
|
||||
VMLINUX_TEST_DIRS = {"passing_tests/vmlinux"}
|
||||
VMLINUX_TEST_PREFIXES = {
|
||||
"failing_tests/vmlinux",
|
||||
"failing_tests/xdp",
|
||||
}
|
||||
|
||||
|
||||
def _is_vmlinux_test(rel_path: str) -> bool:
|
||||
for prefix in VMLINUX_TEST_DIRS | VMLINUX_TEST_PREFIXES:
|
||||
if rel_path.startswith(prefix):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
if not CONFIG_FILE.exists():
|
||||
return {}
|
||||
with open(CONFIG_FILE, "rb") as f:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def collect_all_test_files() -> list[BpfTestCase]:
|
||||
config = _load_config()
|
||||
xfail_map: dict = config.get("xfail", {})
|
||||
|
||||
cases = []
|
||||
for subdir in ("passing_tests", "failing_tests"):
|
||||
for py_file in sorted((TESTS_DIR / subdir).rglob("*.py")):
|
||||
rel = str(py_file.relative_to(TESTS_DIR))
|
||||
needs_vmlinux = _is_vmlinux_test(rel)
|
||||
|
||||
xfail_entry = xfail_map.get(rel)
|
||||
is_expected_fail = xfail_entry is not None
|
||||
xfail_reason = xfail_entry.get("reason", "") if xfail_entry else ""
|
||||
xfail_level = xfail_entry.get("level", "ir") if xfail_entry else "ir"
|
||||
|
||||
cases.append(
|
||||
BpfTestCase(
|
||||
path=py_file,
|
||||
rel_path=rel,
|
||||
is_expected_fail=is_expected_fail,
|
||||
xfail_reason=xfail_reason,
|
||||
xfail_level=xfail_level,
|
||||
needs_vmlinux=needs_vmlinux,
|
||||
)
|
||||
)
|
||||
return cases
|
||||
23
tests/framework/compiler.py
Normal file
23
tests/framework/compiler.py
Normal file
@ -0,0 +1,23 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from pythonbpf.codegen import compile_to_ir, _run_llc
|
||||
|
||||
|
||||
def run_ir_generation(test_path: Path, output_ll: Path):
|
||||
"""Run compile_to_ir on a BPF test file.
|
||||
|
||||
Returns the (output, structs_sym_tab, maps_sym_tab) tuple from compile_to_ir.
|
||||
Raises on exception. Any logging.ERROR records captured by pytest caplog
|
||||
indicate a compile failure even when no exception is raised.
|
||||
"""
|
||||
return compile_to_ir(str(test_path), str(output_ll), loglevel=logging.WARNING)
|
||||
|
||||
|
||||
def run_llc(ll_path: Path, obj_path: Path) -> bool:
|
||||
"""Compile a .ll file to a BPF .o using llc.
|
||||
|
||||
Raises subprocess.CalledProcessError on failure (llc uses check=True).
|
||||
Returns True on success.
|
||||
"""
|
||||
return _run_llc(str(ll_path), str(obj_path))
|
||||
25
tests/framework/verifier.py
Normal file
25
tests/framework/verifier.py
Normal file
@ -0,0 +1,25 @@
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def verify_object(obj_path: Path) -> tuple[bool, str]:
|
||||
"""Run bpftool prog load -d to verify a BPF object file against the kernel verifier.
|
||||
|
||||
Pins the program temporarily at /sys/fs/bpf/bpf_prog_test_<uuid>, then removes it.
|
||||
Returns (success, combined_output). Requires sudo / root.
|
||||
"""
|
||||
pin_path = f"/sys/fs/bpf/bpf_prog_test_{uuid.uuid4().hex[:8]}"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["sudo", "bpftool", "prog", "load", "-d", str(obj_path), pin_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
output = result.stdout + result.stderr
|
||||
return result.returncode == 0, output
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "bpftool timed out after 30s"
|
||||
finally:
|
||||
subprocess.run(["sudo", "rm", "-f", pin_path], check=False, capture_output=True)
|
||||
33
tests/test_config.toml
Normal file
33
tests/test_config.toml
Normal file
@ -0,0 +1,33 @@
|
||||
# test_config.toml
|
||||
#
|
||||
# [xfail] — tests expected to fail.
|
||||
# key = path relative to tests/
|
||||
# value = {reason = "...", level = "ir" | "llc"}
|
||||
# level "ir" = fails during pythonbpf IR generation (exception or ERROR log)
|
||||
# level "llc" = IR generates but llc rejects it
|
||||
#
|
||||
# Tests removed from this list because they now pass (fixed by compilation-context PR):
|
||||
# failing_tests/assign/retype.py
|
||||
# failing_tests/conditionals/helper_cond.py
|
||||
# failing_tests/conditionals/oneline.py
|
||||
# failing_tests/direct_assign.py
|
||||
# failing_tests/globals.py
|
||||
# failing_tests/if.py
|
||||
# failing_tests/license.py
|
||||
# failing_tests/named_arg.py
|
||||
# failing_tests/xdp/xdp_test_1.py
|
||||
# These files can be moved to passing_tests/ when convenient.
|
||||
|
||||
[xfail]
|
||||
|
||||
"failing_tests/conditionals/struct_ptr.py" = {reason = "Struct pointer used directly as boolean condition not supported", level = "ir"}
|
||||
|
||||
"failing_tests/license.py" = {reason = "Missing LICENSE global produces IR that llc rejects — should be caught earlier with a clear error message", level = "llc"}
|
||||
|
||||
"failing_tests/undeclared_values.py" = {reason = "Undeclared variable used in f-string — should raise SyntaxError (correct behaviour, test documents it)", level = "ir"}
|
||||
|
||||
"failing_tests/vmlinux/args_test.py" = {reason = "struct_trace_event_raw_sys_enter args field access not supported", level = "ir"}
|
||||
|
||||
"failing_tests/vmlinux/assignment_handling.py" = {reason = "Assigning vmlinux enum value (XDP_PASS) to a local variable not yet supported", level = "ir"}
|
||||
|
||||
"failing_tests/xdp_pass.py" = {reason = "XDP program using vmlinux structs (struct_xdp_md) and complex map/struct interaction not yet supported", level = "ir"}
|
||||
29
tests/test_ir_generation.py
Normal file
29
tests/test_ir_generation.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""
|
||||
Level 1 — IR Generation tests.
|
||||
|
||||
For every BPF test file, calls compile_to_ir() and asserts:
|
||||
1. No exception is raised by the pythonbpf compiler.
|
||||
2. No logging.ERROR records are emitted during compilation.
|
||||
3. A .ll file is produced.
|
||||
|
||||
Tests in failing_tests/ are marked xfail (strict=True) by conftest.py —
|
||||
they must raise an exception or produce an ERROR log to pass the suite.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from tests.framework.compiler import run_ir_generation
|
||||
|
||||
|
||||
def test_ir_generation(bpf_test_file: Path, tmp_path, caplog):
|
||||
ll_path = tmp_path / "output.ll"
|
||||
|
||||
run_ir_generation(bpf_test_file, ll_path)
|
||||
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert not error_records, "IR generation produced ERROR log(s):\n" + "\n".join(
|
||||
f" [{r.name}] {r.getMessage()}" for r in error_records
|
||||
)
|
||||
assert ll_path.exists(), "compile_to_ir() returned without writing a .ll file"
|
||||
32
tests/test_llc_compilation.py
Normal file
32
tests/test_llc_compilation.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""
|
||||
Level 2 — LLC compilation tests.
|
||||
|
||||
For every BPF test file, runs the full compile_to_ir() + _run_llc() pipeline
|
||||
and asserts a non-empty .o file is produced.
|
||||
|
||||
Tests in failing_tests/ are marked xfail (strict=True) by conftest.py.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from tests.framework.compiler import run_ir_generation, run_llc
|
||||
|
||||
|
||||
def test_llc_compilation(bpf_test_file: Path, tmp_path, caplog):
|
||||
ll_path = tmp_path / "output.ll"
|
||||
obj_path = tmp_path / "output.o"
|
||||
|
||||
run_ir_generation(bpf_test_file, ll_path)
|
||||
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert not error_records, "IR generation produced ERROR log(s):\n" + "\n".join(
|
||||
f" [{r.name}] {r.getMessage()}" for r in error_records
|
||||
)
|
||||
|
||||
run_llc(ll_path, obj_path)
|
||||
|
||||
assert obj_path.exists() and obj_path.stat().st_size > 0, (
|
||||
"llc did not produce a non-empty .o file"
|
||||
)
|
||||
65
tests/test_verifier.py
Normal file
65
tests/test_verifier.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""
|
||||
Level 3 — Kernel verifier tests.
|
||||
|
||||
For every passing BPF test file, compiles to a .o and runs:
|
||||
sudo bpftool prog load -d <file.o> /sys/fs/bpf/bpf_prog_test_<id>
|
||||
|
||||
These tests are opt-in: they require sudo and kernel access, and are gated
|
||||
behind the `verifier` pytest mark. Run with:
|
||||
|
||||
pytest tests/test_verifier.py -m verifier -v
|
||||
|
||||
Note: uses the venv Python binary for any in-process calls, but bpftool
|
||||
itself is invoked via subprocess with sudo. Ensure bpftool is installed
|
||||
and the user can sudo.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.framework.collector import collect_all_test_files
|
||||
from tests.framework.compiler import run_ir_generation, run_llc
|
||||
from tests.framework.verifier import verify_object
|
||||
|
||||
|
||||
def _passing_test_files():
|
||||
return [
|
||||
c.path
|
||||
for c in collect_all_test_files()
|
||||
if not c.is_expected_fail and not c.needs_vmlinux
|
||||
]
|
||||
|
||||
|
||||
def _passing_test_ids():
|
||||
return [
|
||||
c.rel_path
|
||||
for c in collect_all_test_files()
|
||||
if not c.is_expected_fail and not c.needs_vmlinux
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.verifier
|
||||
@pytest.mark.parametrize(
|
||||
"verifier_test_file",
|
||||
_passing_test_files(),
|
||||
ids=_passing_test_ids(),
|
||||
)
|
||||
def test_kernel_verifier(verifier_test_file: Path, tmp_path, caplog):
|
||||
"""Compile the BPF test and verify it passes the kernel verifier."""
|
||||
ll_path = tmp_path / "output.ll"
|
||||
obj_path = tmp_path / "output.o"
|
||||
|
||||
run_ir_generation(verifier_test_file, ll_path)
|
||||
|
||||
error_records = [r for r in caplog.records if r.levelno >= logging.ERROR]
|
||||
assert not error_records, "IR generation produced ERROR log(s):\n" + "\n".join(
|
||||
f" [{r.name}] {r.getMessage()}" for r in error_records
|
||||
)
|
||||
|
||||
run_llc(ll_path, obj_path)
|
||||
assert obj_path.exists() and obj_path.stat().st_size > 0
|
||||
|
||||
ok, output = verify_object(obj_path)
|
||||
assert ok, f"Kernel verifier rejected {verifier_test_file.name}:\n{output}"
|
||||
Reference in New Issue
Block a user