Compare commits

..

8 Commits

Author SHA1 Message Date
a2b1a8baff Clarify compilation target in package documentation
Updated documentation to clarify that the compilation targets LLVM IR instead of eBPF bytecode directly.
2025-10-09 10:10:34 +05:30
22289821f9 format chore 2025-10-09 10:06:17 +05:30
d86dd683f4 ignore vmlinux and extend example 2025-10-09 10:04:35 +05:30
6881d2e960 Complete documentation coverage: add final module docstrings
Co-authored-by: varun-r-mallya <100590632+varun-r-mallya@users.noreply.github.com>
2025-10-08 17:30:03 +00:00
d9dfb61000 Add remaining docstrings to complete documentation coverage
Co-authored-by: varun-r-mallya <100590632+varun-r-mallya@users.noreply.github.com>
2025-10-08 17:25:29 +00:00
cdf4f3e885 Add module-level docstrings and helper utility docstrings
Co-authored-by: varun-r-mallya <100590632+varun-r-mallya@users.noreply.github.com>
2025-10-08 17:20:45 +00:00
5b20b08d9f Add docstrings to core modules and helper functions
Co-authored-by: varun-r-mallya <100590632+varun-r-mallya@users.noreply.github.com>
2025-10-08 17:16:05 +00:00
9f103c34a0 Initial plan 2025-10-08 17:04:52 +00:00
89 changed files with 250446 additions and 5549 deletions

2
.gitignore vendored
View File

@ -8,5 +8,3 @@ __pycache__/
*.o
.ipynb_checkpoints/
vmlinux.py
~*
vmlinux.h

View File

@ -12,7 +12,7 @@
#
# See https://github.com/pre-commit/pre-commit
exclude: 'vmlinux.py'
exclude: 'vmlinux.*\.py$'
ci:
autoupdate_commit_msg: "chore: update pre-commit hooks"
@ -41,7 +41,7 @@ repos:
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format
# exclude: ^(docs)|^(tests)|^(examples)
exclude: ^(docs)|^(tests)|^(examples)
# Checking static types
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@ -1,34 +0,0 @@
from pythonbpf import bpf, section, bpfglobal, BPF, trace_fields
from ctypes import c_void_p, c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_clone")
def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
# header
print(f"{'TIME(s)':<18} {'COMM':<16} {'PID':<6} {'MESSAGE'}")
# format output
while True:
try:
(task, pid, cpu, flags, ts, msg) = trace_fields()
except ValueError:
continue
except KeyboardInterrupt:
exit()
print(f"{ts:<18} {task:<16} {pid:<6} {msg}")

View File

@ -1,61 +0,0 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, BPF
from pythonbpf.helper import ktime, pid, comm
from pythonbpf.maps import PerfEventArray
from ctypes import c_void_p, c_int64
@bpf
@struct
class data_t:
pid: c_int64
ts: c_int64
comm: str(16) # type: ignore [valid-type]
@bpf
@map
def events() -> PerfEventArray:
return PerfEventArray(key_size=c_int64, value_size=c_int64)
@bpf
@section("tracepoint/syscalls/sys_enter_clone")
def hello(ctx: c_void_p) -> c_int64:
dataobj = data_t()
dataobj.pid, dataobj.ts = pid(), ktime()
comm(dataobj.comm)
events.output(dataobj)
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
start = 0
def callback(cpu, event):
global start
if start == 0:
start = event.ts
ts = (event.ts - start) / 1e9
print(f"[CPU {cpu}] PID: {event.pid}, TS: {ts}, COMM: {event.comm.decode()}")
perf = b["events"].open_perf_buffer(callback, struct_name="data_t")
print("Starting to poll... (Ctrl+C to stop)")
print("Try running: fork() or clone() system calls to trigger events")
try:
while True:
b["events"].poll(1000)
except KeyboardInterrupt:
print("Stopping...")

View File

@ -1,23 +0,0 @@
from pythonbpf import bpf, section, bpfglobal, BPF, trace_pipe
from ctypes import c_void_p, c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_clone")
def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
trace_pipe()

View File

@ -1,58 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, BPF, trace_fields
from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=2)
@bpf
@section("tracepoint/syscalls/sys_enter_sync")
def do_trace(ctx: c_void_p) -> c_int64:
ts_key, cnt_key = 0, 1
tsp, cntp = last.lookup(ts_key), last.lookup(cnt_key)
if not cntp:
last.update(cnt_key, 0)
cntp = last.lookup(cnt_key)
if tsp:
delta = ktime() - tsp
if delta < 1000000000:
time_ms = delta // 1000000
print(f"{time_ms} {cntp}")
last.delete(ts_key)
else:
last.update(ts_key, ktime())
last.update(cnt_key, cntp + 1)
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
print("Tracing for quick sync's... Ctrl-C to end")
# format output
start = 0
while True:
try:
task, pid, cpu, flags, ts, msg = trace_fields()
if start == 0:
start = ts
ts -= start
ms, cnt = msg.split()
print(f"At time {ts} s: Multiple syncs detected, last {ms} ms ago. Count {cnt}")
except KeyboardInterrupt:
exit()

View File

@ -1,78 +0,0 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, BPF
from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap
from pythonbpf.maps import PerfEventArray
from ctypes import c_void_p, c_int64
@bpf
@struct
class data_t:
ts: c_int64
ms: c_int64
@bpf
@map
def events() -> PerfEventArray:
return PerfEventArray(key_size=c_int64, value_size=c_int64)
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=1)
@bpf
@section("tracepoint/syscalls/sys_enter_sync")
def do_trace(ctx: c_void_p) -> c_int64:
dat, dat.ts, key = data_t(), ktime(), 0
tsp = last.lookup(key)
if tsp:
delta = ktime() - tsp
if delta < 1000000000:
dat.ms = delta // 1000000
events.output(dat)
last.delete(key)
else:
last.update(key, ktime())
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
print("Tracing for quick sync's... Ctrl-C to end")
# format output
start = 0
def callback(cpu, event):
global start
if start == 0:
start = event.ts
event.ts -= start
print(
f"At time {event.ts / 1e9} s: Multiple sync detected, Last sync: {event.ms} ms ago"
)
perf = b["events"].open_perf_buffer(callback, struct_name="data_t")
print("Starting to poll... (Ctrl+C to stop)")
print("Try running: fork() or clone() system calls to trigger events")
try:
while True:
b["events"].poll(1000)
except KeyboardInterrupt:
print("Stopping...")

View File

@ -1,53 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, BPF, trace_fields
from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=1)
@bpf
@section("tracepoint/syscalls/sys_enter_sync")
def do_trace(ctx: c_void_p) -> c_int64:
key = 0
tsp = last.lookup(key)
if tsp:
delta = ktime() - tsp
if delta < 1000000000:
time_ms = delta // 1000000
print(f"{time_ms}")
last.delete(key)
else:
last.update(key, ktime())
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
print("Tracing for quick sync's... Ctrl-C to end")
# format output
start = 0
while True:
try:
task, pid, cpu, flags, ts, ms = trace_fields()
if start == 0:
start = ts
ts -= start
print(f"At time {ts} s: Multiple syncs detected, last {ms} ms ago")
except KeyboardInterrupt:
exit()

View File

@ -1,23 +0,0 @@
from pythonbpf import bpf, section, bpfglobal, BPF, trace_pipe
from ctypes import c_void_p, c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_sync")
def hello_world(ctx: c_void_p) -> c_int64:
print("sys_sync() called")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Compile and load
b = BPF()
b.load()
b.attach_all()
print("Tracing sys_sync()... Ctrl-C to end.")
trace_pipe()

View File

@ -1,127 +0,0 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, BPF
from pythonbpf.helper import ktime, pid
from pythonbpf.maps import HashMap, PerfEventArray
from ctypes import c_void_p, c_uint64
import matplotlib.pyplot as plt
import numpy as np
@bpf
@struct
class latency_event:
pid: c_uint64
delta_us: c_uint64 # Latency in microseconds
@bpf
@map
def start() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=10240)
@bpf
@map
def events() -> PerfEventArray:
return PerfEventArray(key_size=c_uint64, value_size=c_uint64)
@bpf
@section("kprobe/vfs_read")
def do_entry(ctx: c_void_p) -> c_uint64:
p, ts = pid(), ktime()
start.update(p, ts)
return 0 # type: ignore [return-value]
@bpf
@section("kretprobe/vfs_read")
def do_return(ctx: c_void_p) -> c_uint64:
p = pid()
tsp = start.lookup(p)
if tsp:
delta_ns = ktime() - tsp
# Only track if latency > 1 microsecond
if delta_ns > 1000:
evt = latency_event()
evt.pid, evt.delta_us = p, delta_ns // 1000
events.output(evt)
start.delete(p)
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
# Load BPF
print("Loading BPF program...")
b = BPF()
b.load()
b.attach_all()
# Collect latencies
latencies = []
def callback(cpu, event):
latencies.append(event.delta_us)
b["events"].open_perf_buffer(callback, struct_name="latency_event")
print("Tracing vfs_read latency... Hit Ctrl-C to end.")
try:
while True:
b["events"].poll(1000)
if len(latencies) > 0 and len(latencies) % 1000 == 0:
print(f"Collected {len(latencies)} samples...")
except KeyboardInterrupt:
print(f"Collected {len(latencies)} samples. Generating histogram...")
# Create histogram with matplotlib
if latencies:
# Use log scale for better visualization
log_latencies = np.log2(latencies)
plt.figure(figsize=(12, 6))
# Plot 1: Linear histogram
plt.subplot(1, 2, 1)
plt.hist(latencies, bins=50, edgecolor="black", alpha=0.7)
plt.xlabel("Latency (microseconds)")
plt.ylabel("Count")
plt.title("VFS Read Latency Distribution (Linear)")
plt.grid(True, alpha=0.3)
# Plot 2: Log2 histogram (like BCC)
plt.subplot(1, 2, 2)
plt.hist(log_latencies, bins=50, edgecolor="black", alpha=0.7, color="orange")
plt.xlabel("log2(Latency in µs)")
plt.ylabel("Count")
plt.title("VFS Read Latency Distribution (Log2)")
plt.grid(True, alpha=0.3)
# Add statistics
print("Statistics:")
print(f" Count: {len(latencies)}")
print(f" Min: {min(latencies)} µs")
print(f" Max: {max(latencies)} µs")
print(f" Mean: {np.mean(latencies):.2f} µs")
print(f" Median: {np.median(latencies):.2f} µs")
print(f" P95: {np.percentile(latencies, 95):.2f} µs")
print(f" P99: {np.percentile(latencies, 99):.2f} µs")
plt.tight_layout()
plt.savefig("vfs_read_latency.png", dpi=150)
print("Histogram saved to vfs_read_latency.png")
plt.show()
else:
print("No samples collected!")

View File

@ -1,101 +0,0 @@
"""BPF program for tracing VFS read latency."""
from pythonbpf import bpf, map, struct, section, bpfglobal, BPF
from pythonbpf.helper import ktime, pid
from pythonbpf.maps import HashMap, PerfEventArray
from ctypes import c_void_p, c_uint64
import argparse
from data_collector import LatencyCollector
from dashboard import LatencyDashboard
@bpf
@struct
class latency_event:
pid: c_uint64
delta_us: c_uint64
@bpf
@map
def start() -> HashMap:
"""Map to store start timestamps by PID."""
return HashMap(key=c_uint64, value=c_uint64, max_entries=10240)
@bpf
@map
def events() -> PerfEventArray:
"""Perf event array for sending latency events to userspace."""
return PerfEventArray(key_size=c_uint64, value_size=c_uint64)
@bpf
@section("kprobe/vfs_read")
def do_entry(ctx: c_void_p) -> c_uint64:
"""Record start time when vfs_read is called."""
p, ts = pid(), ktime()
start.update(p, ts)
return 0 # type: ignore [return-value]
@bpf
@section("kretprobe/vfs_read")
def do_return(ctx: c_void_p) -> c_uint64:
"""Calculate and record latency when vfs_read returns."""
p = pid()
tsp = start.lookup(p)
if tsp:
delta_ns = ktime() - tsp
# Only track latencies > 1 microsecond
if delta_ns > 1000:
evt = latency_event()
evt.pid, evt.delta_us = p, delta_ns // 1000
events.output(evt)
start.delete(p)
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Monitor VFS read latency with live dashboard"
)
parser.add_argument(
"--host", default="0.0.0.0", help="Dashboard host (default: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=8050, help="Dashboard port (default: 8050)"
)
parser.add_argument(
"--buffer", type=int, default=10000, help="Recent data buffer size"
)
return parser.parse_args()
args = parse_args()
# Load BPF program
print("Loading BPF program...")
b = BPF()
b.load()
b.attach_all()
print("✅ BPF program loaded and attached")
# Setup data collector
collector = LatencyCollector(b, buffer_size=args.buffer)
collector.start()
# Create and run dashboard
dashboard = LatencyDashboard(collector)
dashboard.run(host=args.host, port=args.port)

View File

@ -1,282 +0,0 @@
"""Plotly Dash dashboard for visualizing latency data."""
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
class LatencyDashboard:
"""Interactive dashboard for latency visualization."""
def __init__(self, collector, title: str = "VFS Read Latency Monitor"):
self.collector = collector
self.app = dash.Dash(__name__)
self.app.title = title
self._setup_layout()
self._setup_callbacks()
def _setup_layout(self):
"""Create dashboard layout."""
self.app.layout = html.Div(
[
html.H1(
"🔥 VFS Read Latency Dashboard",
style={
"textAlign": "center",
"color": "#2c3e50",
"marginBottom": 20,
},
),
# Stats cards
html.Div(
[
self._create_stat_card(
"total-samples", "📊 Total Samples", "#3498db"
),
self._create_stat_card(
"mean-latency", "⚡ Mean Latency", "#e74c3c"
),
self._create_stat_card(
"p99-latency", "🔥 P99 Latency", "#f39c12"
),
],
style={
"display": "flex",
"justifyContent": "space-around",
"marginBottom": 30,
},
),
# Graphs - ✅ Make sure these IDs match the callback outputs
dcc.Graph(id="dual-histogram", style={"height": "450px"}),
dcc.Graph(id="log2-buckets", style={"height": "350px"}),
dcc.Graph(id="timeseries-graph", style={"height": "300px"}),
# Auto-update
dcc.Interval(id="interval-component", interval=1000, n_intervals=0),
],
style={"padding": 20, "fontFamily": "Arial, sans-serif"},
)
def _create_stat_card(self, id_name: str, title: str, color: str):
"""Create a statistics card."""
return html.Div(
[
html.H3(title, style={"color": color}),
html.H2(id=id_name, style={"fontSize": 48, "color": "#2c3e50"}),
],
className="stat-box",
style={
"background": "white",
"padding": 20,
"borderRadius": 10,
"boxShadow": "0 4px 6px rgba(0,0,0,0.1)",
"textAlign": "center",
"flex": 1,
"margin": "0 10px",
},
)
def _setup_callbacks(self):
"""Setup dashboard callbacks."""
@self.app.callback(
[
Output("total-samples", "children"),
Output("mean-latency", "children"),
Output("p99-latency", "children"),
Output("dual-histogram", "figure"), # ✅ Match layout IDs
Output("log2-buckets", "figure"), # ✅ Match layout IDs
Output("timeseries-graph", "figure"), # ✅ Match layout IDs
],
[Input("interval-component", "n_intervals")],
)
def update_dashboard(n):
stats = self.collector.get_stats()
if stats.total == 0:
return self._empty_state()
return (
f"{stats.total:,}",
f"{stats.mean:.1f} µs",
f"{stats.p99:.1f} µs",
self._create_dual_histogram(),
self._create_log2_buckets(),
self._create_timeseries(),
)
def _empty_state(self):
"""Return empty state for dashboard."""
empty_fig = go.Figure()
empty_fig.update_layout(
title="Waiting for data... Generate some disk I/O!", template="plotly_white"
)
# ✅ Return 6 values (3 stats + 3 figures)
return "0", "0 µs", "0 µs", empty_fig, empty_fig, empty_fig
def _create_dual_histogram(self) -> go.Figure:
"""Create side-by-side linear and log2 histograms."""
latencies = self.collector.get_all_latencies()
# Create subplots
fig = make_subplots(
rows=1,
cols=2,
subplot_titles=("Linear Scale", "Log2 Scale"),
horizontal_spacing=0.12,
)
# Linear histogram
fig.add_trace(
go.Histogram(
x=latencies,
nbinsx=50,
marker_color="rgb(55, 83, 109)",
opacity=0.75,
name="Linear",
),
row=1,
col=1,
)
# Log2 histogram
log2_latencies = np.log2(latencies + 1) # +1 to avoid log2(0)
fig.add_trace(
go.Histogram(
x=log2_latencies,
nbinsx=30,
marker_color="rgb(243, 156, 18)",
opacity=0.75,
name="Log2",
),
row=1,
col=2,
)
# Update axes
fig.update_xaxes(title_text="Latency (µs)", row=1, col=1)
fig.update_xaxes(title_text="log2(Latency in µs)", row=1, col=2)
fig.update_yaxes(title_text="Count", row=1, col=1)
fig.update_yaxes(title_text="Count", row=1, col=2)
fig.update_layout(
title_text="📊 Latency Distribution (Linear vs Log2)",
template="plotly_white",
showlegend=False,
height=450,
)
return fig
def _create_log2_buckets(self) -> go.Figure:
"""Create bar chart of log2 buckets (like BCC histogram)."""
buckets = self.collector.get_histogram_buckets()
if not buckets:
fig = go.Figure()
fig.update_layout(
title="🔥 Log2 Histogram - Waiting for data...", template="plotly_white"
)
return fig
# Sort buckets
sorted_buckets = sorted(buckets.keys())
counts = [buckets[b] for b in sorted_buckets]
# Create labels (e.g., "8-16µs", "16-32µs")
labels = []
hover_text = []
for bucket in sorted_buckets:
lower = 2**bucket
upper = 2 ** (bucket + 1)
labels.append(f"{lower}-{upper}")
# Calculate percentage
total = sum(counts)
pct = (buckets[bucket] / total) * 100 if total > 0 else 0
hover_text.append(
f"Range: {lower}-{upper} µs<br>"
f"Count: {buckets[bucket]:,}<br>"
f"Percentage: {pct:.2f}%"
)
# Create bar chart
fig = go.Figure()
fig.add_trace(
go.Bar(
x=labels,
y=counts,
marker=dict(
color=counts,
colorscale="YlOrRd",
showscale=True,
colorbar=dict(title="Count"),
),
text=counts,
textposition="outside",
hovertext=hover_text,
hoverinfo="text",
)
)
fig.update_layout(
title="🔥 Log2 Histogram (BCC-style buckets)",
xaxis_title="Latency Range (µs)",
yaxis_title="Count",
template="plotly_white",
height=350,
xaxis=dict(tickangle=-45),
)
return fig
def _create_timeseries(self) -> go.Figure:
"""Create time series figure."""
recent = self.collector.get_recent_latencies()
if not recent:
fig = go.Figure()
fig.update_layout(
title="⏱️ Real-time Latency - Waiting for data...",
template="plotly_white",
)
return fig
times = [d["time"] for d in recent]
lats = [d["latency"] for d in recent]
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=times,
y=lats,
mode="lines",
line=dict(color="rgb(231, 76, 60)", width=2),
fill="tozeroy",
fillcolor="rgba(231, 76, 60, 0.2)",
)
)
fig.update_layout(
title="⏱️ Real-time Latency (Last 10,000 samples)",
xaxis_title="Time (seconds)",
yaxis_title="Latency (µs)",
template="plotly_white",
height=300,
)
return fig
def run(self, host: str = "0.0.0.0", port: int = 8050, debug: bool = False):
"""Run the dashboard server."""
print(f"\n{'=' * 60}")
print(f"🚀 Dashboard running at: http://{host}:{port}")
print(" Access from your browser to see live graphs")
print(
" Generate disk I/O to see data: dd if=/dev/zero of=/tmp/test bs=1M count=100"
)
print(f"{'=' * 60}\n")
self.app.run(debug=debug, host=host, port=port)

View File

@ -1,96 +0,0 @@
"""Data collection and management."""
import threading
import time
import numpy as np
from collections import deque
from dataclasses import dataclass
from typing import List, Dict
@dataclass
class LatencyStats:
"""Statistics computed from latency data."""
total: int = 0
mean: float = 0.0
median: float = 0.0
min: float = 0.0
max: float = 0.0
p95: float = 0.0
p99: float = 0.0
@classmethod
def from_array(cls, data: np.ndarray) -> "LatencyStats":
"""Compute stats from numpy array."""
if len(data) == 0:
return cls()
return cls(
total=len(data),
mean=float(np.mean(data)),
median=float(np.median(data)),
min=float(np.min(data)),
max=float(np.max(data)),
p95=float(np.percentile(data, 95)),
p99=float(np.percentile(data, 99)),
)
class LatencyCollector:
"""Collects and manages latency data from BPF."""
def __init__(self, bpf_object, buffer_size: int = 10000):
self.bpf = bpf_object
self.all_latencies: List[float] = []
self.recent_latencies = deque(maxlen=buffer_size) # type: ignore [var-annotated]
self.start_time = time.time()
self._lock = threading.Lock()
self._poll_thread = None
def callback(self, cpu: int, event):
"""Callback for BPF events."""
with self._lock:
self.all_latencies.append(event.delta_us)
self.recent_latencies.append(
{"time": time.time() - self.start_time, "latency": event.delta_us}
)
def start(self):
"""Start collecting data."""
self.bpf["events"].open_perf_buffer(self.callback, struct_name="latency_event")
def poll_loop():
while True:
self.bpf["events"].poll(100)
self._poll_thread = threading.Thread(target=poll_loop, daemon=True)
self._poll_thread.start()
print("✅ Data collection started")
def get_all_latencies(self) -> np.ndarray:
"""Get all latencies as numpy array."""
with self._lock:
return np.array(self.all_latencies) if self.all_latencies else np.array([])
def get_recent_latencies(self) -> List[Dict]:
"""Get recent latencies with timestamps."""
with self._lock:
return list(self.recent_latencies)
def get_stats(self) -> LatencyStats:
"""Compute current statistics."""
return LatencyStats.from_array(self.get_all_latencies())
def get_histogram_buckets(self) -> Dict[int, int]:
"""Get log2 histogram buckets."""
latencies = self.get_all_latencies()
if len(latencies) == 0:
return {}
log_buckets = np.floor(np.log2(latencies + 1)).astype(int)
buckets = {} # type: ignore [var-annotated]
for bucket in log_buckets:
buckets[bucket] = buckets.get(bucket, 0) + 1
return buckets

View File

@ -1,178 +0,0 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, BPF
from pythonbpf.helper import ktime, pid
from pythonbpf.maps import HashMap, PerfEventArray
from ctypes import c_void_p, c_uint64
from rich.console import Console
from rich.live import Live
from rich.table import Table
from rich.panel import Panel
from rich.layout import Layout
import numpy as np
import threading
import time
from collections import Counter
# ==================== BPF Setup ====================
@bpf
@struct
class latency_event:
pid: c_uint64
delta_us: c_uint64
@bpf
@map
def start() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=10240)
@bpf
@map
def events() -> PerfEventArray:
return PerfEventArray(key_size=c_uint64, value_size=c_uint64)
@bpf
@section("kprobe/vfs_read")
def do_entry(ctx: c_void_p) -> c_uint64:
p, ts = pid(), ktime()
start.update(p, ts)
return 0 # type: ignore [return-value]
@bpf
@section("kretprobe/vfs_read")
def do_return(ctx: c_void_p) -> c_uint64:
p = pid()
tsp = start.lookup(p)
if tsp:
delta_ns = ktime() - tsp
if delta_ns > 1000:
evt = latency_event()
evt.pid, evt.delta_us = p, delta_ns // 1000
events.output(evt)
start.delete(p)
return 0 # type: ignore [return-value]
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
console = Console()
console.print("[bold green]Loading BPF program...[/]")
b = BPF()
b.load()
b.attach_all()
# ==================== Data Collection ====================
all_latencies = []
histogram_buckets = Counter() # type: ignore [var-annotated]
def callback(cpu, event):
all_latencies.append(event.delta_us)
# Create log2 bucket
bucket = int(np.floor(np.log2(event.delta_us + 1)))
histogram_buckets[bucket] += 1
b["events"].open_perf_buffer(callback, struct_name="latency_event")
def poll_events():
while True:
b["events"].poll(100)
poll_thread = threading.Thread(target=poll_events, daemon=True)
poll_thread.start()
# ==================== Live Display ====================
def generate_display():
layout = Layout()
layout.split_column(
Layout(name="header", size=3),
Layout(name="stats", size=8),
Layout(name="histogram", size=20),
)
# Header
layout["header"].update(
Panel("[bold cyan]🔥 VFS Read Latency Monitor[/]", style="bold white on blue")
)
# Stats
if len(all_latencies) > 0:
lats = np.array(all_latencies)
stats_table = Table(show_header=False, box=None, padding=(0, 2))
stats_table.add_column(style="bold cyan")
stats_table.add_column(style="bold yellow")
stats_table.add_row("📊 Total Samples:", f"{len(lats):,}")
stats_table.add_row("⚡ Mean Latency:", f"{np.mean(lats):.2f} µs")
stats_table.add_row("📉 Min Latency:", f"{np.min(lats):.2f} µs")
stats_table.add_row("📈 Max Latency:", f"{np.max(lats):.2f} µs")
stats_table.add_row("🎯 P95 Latency:", f"{np.percentile(lats, 95):.2f} µs")
stats_table.add_row("🔥 P99 Latency:", f"{np.percentile(lats, 99):.2f} µs")
layout["stats"].update(
Panel(stats_table, title="Statistics", border_style="green")
)
else:
layout["stats"].update(
Panel("[yellow]Waiting for data...[/]", border_style="yellow")
)
# Histogram
if histogram_buckets:
hist_table = Table(title="Latency Distribution", box=None)
hist_table.add_column("Range", style="cyan", no_wrap=True)
hist_table.add_column("Count", justify="right", style="yellow")
hist_table.add_column("Distribution", style="green")
max_count = max(histogram_buckets.values())
for bucket in sorted(histogram_buckets.keys()):
count = histogram_buckets[bucket]
lower = 2**bucket
upper = 2 ** (bucket + 1)
# Create bar
bar_width = int((count / max_count) * 40)
bar = "" * bar_width
hist_table.add_row(
f"{lower:5d}-{upper:5d} µs",
f"{count:6d}",
f"[green]{bar}[/] {count / len(all_latencies) * 100:.1f}%",
)
layout["histogram"].update(Panel(hist_table, border_style="green"))
return layout
try:
with Live(generate_display(), refresh_per_second=2, console=console) as live:
while True:
time.sleep(0.5)
live.update(generate_display())
except KeyboardInterrupt:
console.print("\n[bold red]Stopping...[/]")
if all_latencies:
console.print(f"\n[bold green]✅ Collected {len(all_latencies):,} samples[/]")

View File

@ -40,12 +40,6 @@ Python-BPF is an LLVM IR generator for eBPF programs written in Python. It uses
---
## Try It Out!
Run
```bash
curl -s https://raw.githubusercontent.com/pythonbpf/Python-BPF/refs/heads/master/tools/setup.sh | sudo bash
```
## Installation
Dependencies:
@ -89,14 +83,14 @@ def hist() -> HashMap:
def hello(ctx: c_void_p) -> c_int64:
process_id = pid()
one = 1
prev = hist.lookup(process_id)
prev = hist().lookup(process_id)
if prev:
previous_value = prev + 1
print(f"count: {previous_value} with {process_id}")
hist.update(process_id, previous_value)
hist().update(process_id, previous_value)
return c_int64(0)
else:
hist.update(process_id, one)
hist().update(process_id, one)
return c_int64(0)

View File

@ -308,7 +308,6 @@
"def hist() -> HashMap:\n",
" return HashMap(key=c_int32, value=c_uint64, max_entries=4096)\n",
"\n",
"\n",
"@bpf\n",
"@section(\"tracepoint/syscalls/sys_enter_clone\")\n",
"def hello(ctx: c_void_p) -> c_int64:\n",
@ -330,7 +329,6 @@
"def LICENSE() -> str:\n",
" return \"GPL\"\n",
"\n",
"\n",
"b = BPF()"
]
},
@ -359,6 +357,7 @@
}
],
"source": [
"\n",
"b.load_and_attach()\n",
"hist = BpfMap(b, hist)\n",
"print(\"Recording\")\n",

View File

@ -8,14 +8,12 @@ def hello_world(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return c_int64(0)
@bpf
@section("kprobe/do_unlinkat")
def hello_world2(ctx: c_void_p) -> c_int64:
print("Hello, World!")
return c_int64(0)
@bpf
@bpfglobal
def LICENSE() -> str:

View File

@ -27,7 +27,7 @@ def hello(ctx: c_void_p) -> c_int32:
dataobj.pid = pid()
dataobj.ts = ktime()
# dataobj.comm = strobj
print(f"clone called at {dataobj.ts} by pid{dataobj.pid}, comm {strobj}")
print(f"clone called at {dataobj.ts} by pid" f"{dataobj.pid}, comm {strobj}")
events.output(dataobj)
return c_int32(0)

248446
examples/vmlinux.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,8 @@
from pythonbpf import bpf, map, section, bpfglobal, compile, compile_to_ir
from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap
from ctypes import c_int64, c_void_p
from ctypes import c_void_p, c_int64
# Instructions to how to run this program
# 1. Install PythonBPF: pip install pythonbpf
@ -41,5 +41,4 @@ def LICENSE() -> str:
return "GPL"
compile_to_ir("xdp_pass.py", "xdp_pass.ll")
compile()

View File

@ -4,26 +4,12 @@ build-backend = "setuptools.build_meta"
[project]
name = "pythonbpf"
version = "0.1.6"
version = "0.1.4"
description = "Reduced Python frontend for eBPF"
authors = [
{ name = "r41k0u", email="pragyanshchaturvedi18@gmail.com" },
{ name = "varun-r-mallya", email="varunrmallya@gmail.com" }
]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: System :: Operating System Kernels :: Linux",
]
readme = "README.md"
license = {text = "Apache-2.0"}
requires-python = ">=3.8"

View File

@ -1,6 +1,12 @@
"""
PythonBPF - A Python frontend for eBPF programs.
This package provides decorators and compilation tools to write BPF programs
in Python syntax and compile them to LLVM IR that can be compiled to eBPF bytecode.
"""
from .decorators import bpf, map, section, bpfglobal, struct
from .codegen import compile_to_ir, compile, BPF
from .utils import trace_pipe, trace_fields
__all__ = [
"bpf",
@ -11,6 +17,4 @@ __all__ = [
"compile_to_ir",
"compile",
"BPF",
"trace_pipe",
"trace_fields",
]

View File

@ -1,297 +0,0 @@
import ast
import logging
from llvmlite import ir
from dataclasses import dataclass
from typing import Any
from pythonbpf.helper import HelperHandlerRegistry
from .expr import VmlinuxHandlerRegistry
from pythonbpf.type_deducer import ctypes_to_ir
logger = logging.getLogger(__name__)
@dataclass
class LocalSymbol:
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def __iter__(self):
yield self.var
yield self.ir_type
yield self.metadata
def create_targets_and_rvals(stmt):
"""Create lists of targets and right-hand values from an assignment statement."""
if isinstance(stmt.targets[0], ast.Tuple):
if not isinstance(stmt.value, ast.Tuple):
logger.warning("Mismatched multi-target assignment, skipping allocation")
return [], []
targets, rvals = stmt.targets[0].elts, stmt.value.elts
if len(targets) != len(rvals):
logger.warning("length of LHS != length of RHS, skipping allocation")
return [], []
return targets, rvals
return stmt.targets, [stmt.value]
def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
"""Handle memory allocation for assignment statements."""
logger.info(f"Handling assignment for allocation: {ast.dump(stmt)}")
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
targets, rvals = create_targets_and_rvals(stmt)
for target, rval in zip(targets, rvals):
# Skip non-name targets (e.g., struct field assignments)
if isinstance(target, ast.Attribute):
logger.debug(
f"Struct field assignment to {target.attr}, no allocation needed"
)
continue
if not isinstance(target, ast.Name):
logger.warning(
f"Unsupported assignment target type: {type(target).__name__}"
)
continue
var_name = target.id
# Skip if already allocated
if var_name in local_sym_tab:
logger.debug(f"Variable {var_name} already allocated, skipping")
continue
# When allocating a variable, check if it's a vmlinux struct type
if isinstance(
stmt.value, ast.Name
) and VmlinuxHandlerRegistry.is_vmlinux_struct(stmt.value.id):
# Handle vmlinux struct allocation
# This requires more implementation
print(stmt.value)
pass
# Determine type and allocate based on rval
if isinstance(rval, ast.Call):
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)
elif isinstance(rval, ast.Constant):
_allocate_for_constant(builder, var_name, rval, local_sym_tab)
elif isinstance(rval, ast.BinOp):
_allocate_for_binop(builder, var_name, local_sym_tab)
elif isinstance(rval, ast.Name):
# Variable-to-variable assignment (b = a)
_allocate_for_name(builder, var_name, rval, local_sym_tab)
elif isinstance(rval, ast.Attribute):
# Struct field-to-variable assignment (a = dat.fld)
_allocate_for_attribute(
builder, var_name, rval, local_sym_tab, structs_sym_tab
)
else:
logger.warning(
f"Unsupported assignment value type for {var_name}: {type(rval).__name__}"
)
def _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab):
"""Allocate memory for variable assigned from a call."""
if isinstance(rval.func, ast.Name):
call_type = rval.func.id
# C type constructors
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as {call_type}")
# Helper functions
elif HelperHandlerRegistry.has_handler(call_type):
ir_type = ir.IntType(64) # Assume i64 return type
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} for helper {call_type}")
# Deref function
elif call_type == "deref":
ir_type = ir.IntType(64) # Assume i64 return type
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} for deref")
# Struct constructors
elif call_type in structs_sym_tab:
struct_info = structs_sym_tab[call_type]
var = builder.alloca(struct_info.ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, struct_info.ir_type, call_type)
logger.info(f"Pre-allocated {var_name} for struct {call_type}")
else:
logger.warning(f"Unknown call type for allocation: {call_type}")
elif isinstance(rval.func, ast.Attribute):
# Map method calls - need double allocation for ptr handling
_allocate_for_map_method(builder, var_name, local_sym_tab)
else:
logger.warning(f"Unsupported call function type for {var_name}")
def _allocate_for_map_method(builder, var_name, local_sym_tab):
"""Allocate memory for variable assigned from map method (double alloc)."""
# Main variable (pointer to pointer)
ir_type = ir.PointerType(ir.IntType(64))
var = builder.alloca(ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
# Temporary variable for computed values
tmp_ir_type = ir.IntType(64)
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")
local_sym_tab[f"{var_name}_tmp"] = LocalSymbol(var_tmp, tmp_ir_type)
logger.info(f"Pre-allocated {var_name} and {var_name}_tmp for map method")
def _allocate_for_constant(builder, var_name, rval, local_sym_tab):
"""Allocate memory for variable assigned from a constant."""
if isinstance(rval.value, bool):
ir_type = ir.IntType(1)
var = builder.alloca(ir_type, name=var_name)
var.align = 1
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as bool")
elif isinstance(rval.value, int):
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as i64")
elif isinstance(rval.value, str):
ir_type = ir.PointerType(ir.IntType(8))
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} as string")
else:
logger.warning(
f"Unsupported constant type for {var_name}: {type(rval.value).__name__}"
)
def _allocate_for_binop(builder, var_name, local_sym_tab):
"""Allocate memory for variable assigned from a binary operation."""
ir_type = ir.IntType(64) # Assume i64 result
var = builder.alloca(ir_type, name=var_name)
var.align = 8
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
logger.info(f"Pre-allocated {var_name} for binop result")
def allocate_temp_pool(builder, max_temps, local_sym_tab):
"""Allocate the temporary scratch space pool for helper arguments."""
if max_temps == 0:
return
logger.info(f"Allocating temp pool of {max_temps} variables")
for i in range(max_temps):
temp_name = f"__helper_temp_{i}"
temp_var = builder.alloca(ir.IntType(64), name=temp_name)
temp_var.align = 8
local_sym_tab[temp_name] = LocalSymbol(temp_var, ir.IntType(64))
def _allocate_for_name(builder, var_name, rval, local_sym_tab):
"""Allocate memory for variable-to-variable assignment (b = a)."""
source_var = rval.id
if source_var not in local_sym_tab:
logger.error(f"Source variable '{source_var}' not found in symbol table")
return
# Get type and metadata from source variable
source_symbol = local_sym_tab[source_var]
# Allocate with same type and alignment
var = _allocate_with_type(builder, var_name, source_symbol.ir_type)
local_sym_tab[var_name] = LocalSymbol(
var, source_symbol.ir_type, source_symbol.metadata
)
logger.info(
f"Pre-allocated {var_name} from {source_var} with type {source_symbol.ir_type}"
)
def _allocate_for_attribute(builder, var_name, rval, local_sym_tab, structs_sym_tab):
"""Allocate memory for struct field-to-variable assignment (a = dat.fld)."""
if not isinstance(rval.value, ast.Name):
logger.warning(f"Complex attribute access not supported for {var_name}")
return
struct_var = rval.value.id
field_name = rval.attr
# Validate struct and field
if struct_var not in local_sym_tab:
logger.error(f"Struct variable '{struct_var}' not found")
return
struct_type = local_sym_tab[struct_var].metadata
if not struct_type or struct_type not in structs_sym_tab:
logger.error(f"Struct type '{struct_type}' not found")
return
struct_info = structs_sym_tab[struct_type]
if field_name not in struct_info.fields:
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
return
# Get field type
field_type = struct_info.field_type(field_name)
# Special case: char array -> allocate as i8* pointer instead
if (
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
):
alloc_type = ir.PointerType(ir.IntType(8))
logger.info(f"Allocating {var_name} as i8* (pointer to char array)")
else:
alloc_type = field_type
var = _allocate_with_type(builder, var_name, alloc_type)
local_sym_tab[var_name] = LocalSymbol(var, alloc_type)
logger.info(
f"Pre-allocated {var_name} from {struct_var}.{field_name} with type {alloc_type}"
)
def _allocate_with_type(builder, var_name, ir_type):
"""Allocate variable with appropriate alignment for type."""
var = builder.alloca(ir_type, name=var_name)
var.align = _get_alignment(ir_type)
return var
def _get_alignment(ir_type):
"""Get appropriate alignment for IR type."""
if isinstance(ir_type, ir.IntType):
return ir_type.width // 8
elif isinstance(ir_type, ir.ArrayType) and isinstance(ir_type.element, ir.IntType):
return ir_type.element.width // 8
else:
return 8 # Default: pointer size

View File

@ -1,224 +0,0 @@
import ast
import logging
from llvmlite import ir
from pythonbpf.expr import eval_expr
from pythonbpf.helper import emit_probe_read_kernel_str_call
logger = logging.getLogger(__name__)
def handle_struct_field_assignment(
func, module, builder, target, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle struct field assignment (obj.field = value)."""
var_name = target.value.id
field_name = target.attr
if var_name not in local_sym_tab:
logger.error(f"Variable '{var_name}' not found in symbol table")
return
struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type]
if field_name not in struct_info.fields:
logger.error(f"Field '{field_name}' not found in struct '{struct_type}'")
return
# Get field pointer and evaluate value
field_ptr = struct_info.gep(builder, local_sym_tab[var_name].var, field_name)
field_type = struct_info.field_type(field_name)
val_result = eval_expr(
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}.{field_name}")
return
val, val_type = val_result
# Special case: i8* string to [N x i8] char array
if _is_char_array(field_type) and _is_i8_ptr(val_type):
_copy_string_to_char_array(
func,
module,
builder,
val,
field_ptr,
field_type,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
logger.info(f"Copied string to char array {var_name}.{field_name}")
return
# Regular assignment
builder.store(val, field_ptr)
logger.info(f"Assigned to struct field {var_name}.{field_name}")
def _copy_string_to_char_array(
func,
module,
builder,
src_ptr,
dst_ptr,
array_type,
local_sym_tab,
map_sym_tab,
struct_sym_tab,
):
"""Copy string (i8*) to char array ([N x i8]) using bpf_probe_read_kernel_str"""
array_size = array_type.count
# Get pointer to first element: [N x i8]* -> i8*
dst_i8_ptr = builder.gep(
dst_ptr,
[ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)],
inbounds=True,
)
# Use the shared emitter function
emit_probe_read_kernel_str_call(builder, dst_i8_ptr, array_size, src_ptr)
def _is_char_array(ir_type):
"""Check if type is [N x i8]."""
return (
isinstance(ir_type, ir.ArrayType)
and isinstance(ir_type.element, ir.IntType)
and ir_type.element.width == 8
)
def _is_i8_ptr(ir_type):
"""Check if type is i8*."""
return (
isinstance(ir_type, ir.PointerType)
and isinstance(ir_type.pointee, ir.IntType)
and ir_type.pointee.width == 8
)
def handle_variable_assignment(
func, module, builder, var_name, rval, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle single named variable assignment."""
if var_name not in local_sym_tab:
logger.error(f"Variable {var_name} not declared.")
return False
var_ptr = local_sym_tab[var_name].var
var_type = local_sym_tab[var_name].ir_type
# NOTE: Special case for struct initialization
if isinstance(rval, ast.Call) and isinstance(rval.func, ast.Name):
struct_name = rval.func.id
if struct_name in structs_sym_tab and len(rval.args) == 0:
struct_info = structs_sym_tab[struct_name]
ir_struct = struct_info.ir_type
builder.store(ir.Constant(ir_struct, None), var_ptr)
logger.info(f"Initialized struct {struct_name} for variable {var_name}")
return True
# Special case: struct field char array -> pointer
# Handle this before eval_expr to get the pointer, not the value
if isinstance(rval, ast.Attribute) and isinstance(rval.value, ast.Name):
converted_val = _try_convert_char_array_to_ptr(
rval, var_type, builder, local_sym_tab, structs_sym_tab
)
if converted_val is not None:
builder.store(converted_val, var_ptr)
logger.info(f"Assigned char array pointer to {var_name}")
return True
val_result = eval_expr(
func, module, builder, rval, local_sym_tab, map_sym_tab, structs_sym_tab
)
if val_result is None:
logger.error(f"Failed to evaluate value for {var_name}")
return False
val, val_type = val_result
logger.info(f"Evaluated value for {var_name}: {val} of type {val_type}, {var_type}")
if val_type != var_type:
if isinstance(val_type, ir.IntType) and isinstance(var_type, ir.IntType):
# Allow implicit int widening
if val_type.width < var_type.width:
val = builder.sext(val, var_type)
logger.info(f"Implicitly widened int for variable {var_name}")
elif val_type.width > var_type.width:
val = builder.trunc(val, var_type)
logger.info(f"Implicitly truncated int for variable {var_name}")
elif isinstance(val_type, ir.IntType) and isinstance(var_type, ir.PointerType):
# NOTE: This is assignment to a PTR_TO_MAP_VALUE_OR_NULL
logger.info(
f"Creating temporary variable for pointer assignment to {var_name}"
)
var_ptr_tmp = local_sym_tab[f"{var_name}_tmp"].var
builder.store(val, var_ptr_tmp)
val = var_ptr_tmp
else:
logger.error(
f"Type mismatch for variable {var_name}: {val_type} vs {var_type}"
)
return False
builder.store(val, var_ptr)
logger.info(f"Assigned value to variable {var_name}")
return True
def _try_convert_char_array_to_ptr(
rval, var_type, builder, local_sym_tab, structs_sym_tab
):
"""Try to convert char array field to i8* pointer"""
# Only convert if target is i8*
if not (
isinstance(var_type, ir.PointerType)
and isinstance(var_type.pointee, ir.IntType)
and var_type.pointee.width == 8
):
return None
struct_var = rval.value.id
field_name = rval.attr
# Validate struct
if struct_var not in local_sym_tab:
return None
struct_type = local_sym_tab[struct_var].metadata
if not struct_type or struct_type not in structs_sym_tab:
return None
struct_info = structs_sym_tab[struct_type]
if field_name not in struct_info.fields:
return None
field_type = struct_info.field_type(field_name)
# Check if it's a char array
if not (
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
):
return None
# Get pointer to struct field
struct_ptr = local_sym_tab[struct_var].var
field_ptr = struct_info.gep(builder, struct_ptr, field_name)
# GEP to first element: [N x i8]* -> i8*
return builder.gep(
field_ptr,
[ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)],
inbounds=True,
)

102
pythonbpf/binary_ops.py Normal file
View File

@ -0,0 +1,102 @@
"""
Binary operations handling for BPF programs.
This module provides functions to handle binary operations (add, subtract,
multiply, etc.) and emit the corresponding LLVM IR instructions.
"""
import ast
from llvmlite import ir
from logging import Logger
import logging
logger: Logger = logging.getLogger(__name__)
def recursive_dereferencer(var, builder):
"""dereference until primitive type comes out"""
# TODO: Not worrying about stack overflow for now
logger.info(f"Dereferencing {var}, type is {var.type}")
if isinstance(var.type, ir.PointerType):
a = builder.load(var)
return recursive_dereferencer(a, builder)
elif isinstance(var.type, ir.IntType):
return var
else:
raise TypeError(f"Unsupported type for dereferencing: {var.type}")
def get_operand_value(operand, builder, local_sym_tab):
"""Extract the value from an operand, handling variables and constants."""
if isinstance(operand, ast.Name):
if operand.id in local_sym_tab:
return recursive_dereferencer(local_sym_tab[operand.id].var, builder)
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
return ir.Constant(ir.IntType(64), operand.value)
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, builder, local_sym_tab)
raise TypeError(f"Unsupported operand type: {type(operand)}")
def handle_binary_op_impl(rval, builder, local_sym_tab):
"""
Handle binary operations and emit corresponding LLVM IR instructions.
Args:
rval: The AST BinOp node representing the binary operation
builder: LLVM IR builder for emitting instructions
local_sym_tab: Symbol table mapping variable names to their IR representations
Returns:
The LLVM IR value representing the result of the binary operation
"""
op = rval.op
left = get_operand_value(rval.left, builder, local_sym_tab)
right = get_operand_value(rval.right, builder, local_sym_tab)
logger.info(f"left is {left}, right is {right}, op is {op}")
# Map AST operation nodes to LLVM IR builder methods
op_map = {
ast.Add: builder.add,
ast.Sub: builder.sub,
ast.Mult: builder.mul,
ast.Div: builder.sdiv,
ast.Mod: builder.srem,
ast.LShift: builder.shl,
ast.RShift: builder.lshr,
ast.BitOr: builder.or_,
ast.BitXor: builder.xor,
ast.BitAnd: builder.and_,
ast.FloorDiv: builder.udiv,
}
if type(op) in op_map:
result = op_map[type(op)](left, right)
return result
else:
raise SyntaxError("Unsupported binary operation")
def handle_binary_op(rval, builder, var_name, local_sym_tab):
"""
Handle binary operations and optionally store the result to a variable.
Args:
rval: The AST BinOp node representing the binary operation
builder: LLVM IR builder for emitting instructions
var_name: Optional variable name to store the result
local_sym_tab: Symbol table mapping variable names to their IR representations
Returns:
A tuple of (result_value, result_type)
"""
result = handle_binary_op_impl(rval, builder, local_sym_tab)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
)
builder.store(result, local_sym_tab[var_name].var)
return result, result.type

View File

@ -1,12 +1,17 @@
"""
Code generation module for PythonBPF.
This module handles the conversion of Python BPF programs to LLVM IR and
object files. It provides the main compilation pipeline from Python AST
to BPF bytecode.
"""
import ast
from llvmlite import ir
from .license_pass import license_processing
from .functions import func_proc
from .maps import maps_proc
from .structs import structs_proc
from .vmlinux_parser import vmlinux_proc
from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler
from .expr import VmlinuxHandlerRegistry
from .globals_pass import (
globals_list_creation,
globals_processing,
@ -17,24 +22,14 @@ import os
import subprocess
import inspect
from pathlib import Path
from pylibbpf import BpfObject
from pylibbpf import BpfProgram
import tempfile
from logging import Logger
import logging
import re
logger: Logger = logging.getLogger(__name__)
VERSION = "v0.1.6"
def finalize_module(original_str):
"""After all IR generation is complete, we monkey patch btf_ama attribute"""
# Create a string with applied transformation of btf_ama attribute addition to BTF struct field accesses.
pattern = r'(@"llvm\.[^"]+:[^"]*" = external global i64, !llvm\.preserve\.access\.index ![0-9]+)'
replacement = r'\1 "btf_ama"'
return re.sub(pattern, replacement, original_str)
VERSION = "v0.1.4"
def find_bpf_chunks(tree):
@ -50,6 +45,14 @@ def find_bpf_chunks(tree):
def processor(source_code, filename, module):
"""
Process Python source code and convert BPF-decorated functions to LLVM IR.
Args:
source_code: The Python source code to process
filename: The name of the source file
module: The LLVM IR module to populate
"""
tree = ast.parse(source_code, filename)
logger.debug(ast.dump(tree, indent=4))
@ -57,23 +60,29 @@ def processor(source_code, filename, module):
for func_node in bpf_chunks:
logger.info(f"Found BPF function/struct: {func_node.name}")
vmlinux_symtab = vmlinux_proc(tree, module)
if vmlinux_symtab:
handler = VmlinuxHandler.initialize(vmlinux_symtab)
VmlinuxHandlerRegistry.set_handler(handler)
populate_global_symbol_table(tree, module)
license_processing(tree, module)
globals_processing(tree, module)
structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
globals_list_creation(tree, module)
return structs_sym_tab, map_sym_tab
def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
"""
Compile a Python BPF program to LLVM IR.
Args:
filename: Path to the Python source file containing BPF programs
output: Path where the LLVM IR (.ll) file will be written
loglevel: Logging level for compilation messages
Returns:
Path to the generated LLVM IR file
"""
logging.basicConfig(
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
@ -96,7 +105,7 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
True,
)
structs_sym_tab, maps_sym_tab = processor(source, filename, module)
processor(source, filename, module)
wchar_size = module.add_metadata(
[
@ -137,45 +146,28 @@ def compile_to_ir(filename: str, output: str, loglevel=logging.INFO):
module.add_named_metadata("llvm.ident", [f"PythonBPF {VERSION}"])
module_string = finalize_module(str(module))
logger.info(f"IR written to {output}")
with open(output, "w") as f:
f.write(f'source_filename = "{filename}"\n')
f.write(module_string)
f.write(str(module))
f.write("\n")
return output, structs_sym_tab, maps_sym_tab
return output
def _run_llc(ll_file, obj_file):
"""Compile LLVM IR to BPF object file using llc."""
def compile(loglevel=logging.INFO) -> bool:
"""
Compile the calling Python BPF program to an object file.
logger.info(f"Compiling IR to object: {ll_file} -> {obj_file}")
result = subprocess.run(
[
"llc",
"-march=bpf",
"-filetype=obj",
"-O2",
str(ll_file),
"-o",
str(obj_file),
],
check=True,
capture_output=True,
text=True,
)
This function should be called from a Python file containing BPF programs.
It will compile the calling file to LLVM IR and then to a BPF object file.
if result.returncode == 0:
logger.info(f"Object file written to {obj_file}")
return True
else:
logger.error(f"llc compilation failed: {result.stderr}")
return False
Args:
loglevel: Logging level for compilation messages
def compile(loglevel=logging.WARNING) -> bool:
Returns:
True if compilation succeeded, False otherwise
"""
# Look one level up the stack to the caller of this function
caller_frame = inspect.stack()[1]
caller_file = Path(caller_frame.filename).resolve()
@ -183,19 +175,44 @@ def compile(loglevel=logging.WARNING) -> bool:
ll_file = Path("/tmp") / caller_file.with_suffix(".ll").name
o_file = caller_file.with_suffix(".o")
_, structs_sym_tab, maps_sym_tab = compile_to_ir(
str(caller_file), str(ll_file), loglevel=loglevel
success = True
success = (
compile_to_ir(str(caller_file), str(ll_file), loglevel=loglevel) and success
)
if not _run_llc(ll_file, o_file):
logger.error("Compilation to object file failed.")
return False
success = bool(
subprocess.run(
[
"llc",
"-march=bpf",
"-filetype=obj",
"-O2",
str(ll_file),
"-o",
str(o_file),
],
check=True,
)
and success
)
logger.info(f"Object written to {o_file}")
return True
return success
def BPF(loglevel=logging.WARNING) -> BpfObject:
def BPF(loglevel=logging.INFO) -> BpfProgram:
"""
Compile the calling Python BPF program and return a BpfProgram object.
This function compiles the calling file's BPF programs to an object file
and loads it into a BpfProgram object for immediate use.
Args:
loglevel: Logging level for compilation messages
Returns:
A BpfProgram object that can be used to load and attach BPF programs
"""
caller_frame = inspect.stack()[1]
src = inspect.getsource(caller_frame.frame)
with tempfile.NamedTemporaryFile(
@ -208,9 +225,18 @@ def BPF(loglevel=logging.WARNING) -> BpfObject:
f.write(src)
f.flush()
source = f.name
_, structs_sym_tab, maps_sym_tab = compile_to_ir(
source, str(inter.name), loglevel=loglevel
compile_to_ir(source, str(inter.name), loglevel=loglevel)
subprocess.run(
[
"llc",
"-march=bpf",
"-filetype=obj",
"-O2",
str(inter.name),
"-o",
str(obj_file.name),
],
check=True,
)
_run_llc(str(inter.name), str(obj_file.name))
return BpfObject(str(obj_file.name), structs=structs_sym_tab)
return BpfProgram(str(obj_file.name))

View File

@ -1,3 +1,5 @@
"""Debug information generation for BPF programs (DWARF/BTF)."""
from .dwarf_constants import * # noqa: F403
from .dtypes import * # noqa: F403
from .debug_info_generator import DebugInfoGenerator

View File

@ -8,11 +8,31 @@ from typing import Any, List
class DebugInfoGenerator:
"""
Generator for DWARF/BTF debug information in LLVM IR modules.
This class provides methods to create debug metadata for BPF programs,
including types, structs, globals, and compilation units.
"""
def __init__(self, module):
"""
Initialize the debug info generator.
Args:
module: LLVM IR module to attach debug info to
"""
self.module = module
self._type_cache = {} # Cache for common debug types
def generate_file_metadata(self, filename, dirname):
"""
Generate file metadata for debug info.
Args:
filename: Name of the source file
dirname: Directory containing the source file
"""
self.module._file_metadata = self.module.add_debug_info(
"DIFile",
{ # type: ignore
@ -24,6 +44,15 @@ class DebugInfoGenerator:
def generate_debug_cu(
self, language, producer: str, is_optimized: bool, is_distinct: bool
):
"""
Generate debug compile unit metadata.
Args:
language: DWARF language code (e.g., DW_LANG_C11)
producer: Compiler/producer string
is_optimized: Whether the code is optimized
is_distinct: Whether the compile unit should be distinct
"""
self.module._debug_compile_unit = self.module.add_debug_info(
"DICompileUnit",
{ # type: ignore
@ -81,22 +110,18 @@ class DebugInfoGenerator:
},
)
def create_array_type_vmlinux(self, type_info: Any, count: int) -> Any:
"""Create an array type of the given base type with specified count"""
base_type, type_sizing = type_info
subrange = self.module.add_debug_info("DISubrange", {"count": count})
return self.module.add_debug_info(
"DICompositeType",
{
"tag": dc.DW_TAG_array_type,
"baseType": base_type,
"size": type_sizing,
"elements": [subrange],
},
)
@staticmethod
def _compute_array_size(base_type: Any, count: int) -> int:
"""
Compute the size of an array in bits.
Args:
base_type: The base type of the array
count: Number of elements in the array
Returns:
Total size in bits
"""
# Extract size from base_type if possible
# For simplicity, assuming base_type has a size attribute
return getattr(base_type, "size", 32) * count
@ -115,23 +140,6 @@ class DebugInfoGenerator:
},
)
def create_struct_member_vmlinux(
self, name: str, base_type_with_size: Any, offset: int
) -> Any:
"""Create a struct member with the given name, type, and offset"""
base_type, type_size = base_type_with_size
return self.module.add_debug_info(
"DIDerivedType",
{
"tag": dc.DW_TAG_member,
"name": name,
"file": self.module._file_metadata,
"baseType": base_type,
"size": type_size,
"offset": offset,
},
)
def create_struct_type(
self, members: List[Any], size: int, is_distinct: bool
) -> Any:
@ -147,22 +155,6 @@ class DebugInfoGenerator:
is_distinct=is_distinct,
)
def create_struct_type_with_name(
self, name: str, members: List[Any], size: int, is_distinct: bool
) -> Any:
"""Create a struct type with the given members and size"""
return self.module.add_debug_info(
"DICompositeType",
{
"name": name,
"tag": dc.DW_TAG_structure_type,
"file": self.module._file_metadata,
"size": size,
"elements": members,
},
is_distinct=is_distinct,
)
def create_global_var_debug_info(
self, name: str, var_type: Any, is_local: bool = False
) -> Any:

View File

@ -1,7 +1,11 @@
"""Debug information types and constants."""
import llvmlite.ir as ir
class DwarfBehaviorEnum:
"""DWARF module flag behavior constants for LLVM."""
ERROR_IF_MISMATCH = ir.Constant(ir.IntType(32), 1)
WARNING_IF_MISMATCH = ir.Constant(ir.IntType(32), 2)
OVERRIDE_USE_LARGEST = ir.Constant(ir.IntType(32), 7)

View File

@ -1,3 +1,9 @@
"""
DWARF debugging format constants.
Generated constants from dwarf.h for use in debug information generation.
"""
# generated constants from dwarf.h
DW_UT_compile = 0x01

View File

@ -1,3 +1,11 @@
"""
Decorators for marking BPF functions, maps, structs, and globals.
This module provides the core decorators used to annotate Python code
for BPF compilation.
"""
def bpf(func):
"""Decorator to mark a function for BPF compilation."""
func._is_bpf = True
@ -23,7 +31,18 @@ def struct(cls):
def section(name: str):
"""
Decorator to specify the ELF section name for a BPF program.
Args:
name: The section name (e.g., 'xdp', 'tracepoint/syscalls/sys_enter_execve')
Returns:
A decorator function that marks the function with the section name
"""
def wrapper(fn):
"""Decorator that sets the section name on the function."""
fn._section = name
return fn

View File

@ -1,16 +1,6 @@
from .expr_pass import eval_expr, handle_expr, get_operand_value
from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth
from .call_registry import CallHandlerRegistry
from .vmlinux_registry import VmlinuxHandlerRegistry
"""Expression evaluation and processing for BPF programs."""
__all__ = [
"eval_expr",
"handle_expr",
"convert_to_bool",
"get_base_type_and_depth",
"deref_to_depth",
"get_operand_value",
"CallHandlerRegistry",
"VmlinuxHandlerRegistry",
]
from .expr_pass import eval_expr, handle_expr
from .type_normalization import convert_to_bool
__all__ = ["eval_expr", "handle_expr", "convert_to_bool"]

View File

@ -1,20 +0,0 @@
class CallHandlerRegistry:
"""Registry for handling different types of calls (helpers, etc.)"""
_handler = None
@classmethod
def set_handler(cls, handler):
"""Set the handler for unknown calls"""
cls._handler = handler
@classmethod
def handle_call(
cls, call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Handle a call using the registered handler"""
if cls._handler is None:
return None
return cls._handler(
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
)

View File

@ -1,3 +1,11 @@
"""
Expression evaluation and LLVM IR generation.
This module handles the evaluation of Python expressions in BPF programs,
including variables, constants, function calls, comparisons, boolean
operations, and more.
"""
import ast
from llvmlite import ir
from logging import Logger
@ -5,21 +13,10 @@ import logging
from typing import Dict
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
from .call_registry import CallHandlerRegistry
from .type_normalization import (
convert_to_bool,
handle_comparator,
get_base_type_and_depth,
deref_to_depth,
)
from .vmlinux_registry import VmlinuxHandlerRegistry
from .type_normalization import convert_to_bool, handle_comparator
logger: Logger = logging.getLogger(__name__)
# ============================================================================
# Leaf Handlers (No Recursive eval_expr calls)
# ============================================================================
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
"""Handle ast.Name expressions."""
@ -28,34 +25,16 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type
else:
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(expr.id)
if vmlinux_result is not None:
return vmlinux_result
raise SyntaxError(f"Undefined variable {expr.id}")
logger.info(f"Undefined variable {expr.id}")
return None
def _handle_constant_expr(module, builder, expr: ast.Constant):
def _handle_constant_expr(expr: ast.Constant):
"""Handle ast.Constant expressions."""
if isinstance(expr.value, int) or isinstance(expr.value, bool):
return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64)
elif isinstance(expr.value, str):
str_name = f".str.{id(expr)}"
str_bytes = expr.value.encode("utf-8") + b"\x00"
str_type = ir.ArrayType(ir.IntType(8), len(str_bytes))
str_constant = ir.Constant(str_type, bytearray(str_bytes))
# Create global variable
global_str = ir.GlobalVariable(module, str_type, name=str_name)
global_str.linkage = "internal"
global_str.global_constant = True
global_str.initializer = str_constant
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
return str_ptr, ir.PointerType(ir.IntType(8))
else:
logger.error(f"Unsupported constant type {ast.dump(expr)}")
logger.error("Unsupported constant type")
return None
@ -79,13 +58,6 @@ def _handle_attribute_expr(
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
# Try vmlinux handler as fallback
vmlinux_result = VmlinuxHandlerRegistry.handle_attribute(
expr, local_sym_tab, None, builder
)
if vmlinux_result is not None:
return vmlinux_result
return None
@ -124,123 +96,6 @@ def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilde
return val, local_sym_tab[arg.id].ir_type
# ============================================================================
# Binary Operations
# ============================================================================
def get_operand_value(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""Extract the value from an operand, handling variables and constants."""
logger.info(f"Getting operand value for: {ast.dump(operand)}")
if isinstance(operand, ast.Name):
if operand.id in local_sym_tab:
var = local_sym_tab[operand.id].var
var_type = var.type
base_type, depth = get_base_type_and_depth(var_type)
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
else:
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(operand.id)
if vmlinux_result is not None:
val, _ = vmlinux_result
return val
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
cst = ir.Constant(ir.IntType(64), int(operand.value))
return cst
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
res = _handle_binary_op_impl(
func, module, operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
return res
else:
res = eval_expr(
func, module, builder, operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if res is None:
raise ValueError(f"Failed to evaluate call expression: {operand}")
val, _ = res
logger.info(f"Evaluated expr to {val} of type {val.type}")
base_type, depth = get_base_type_and_depth(val.type)
if depth > 0:
val = deref_to_depth(func, builder, val, depth)
return val
raise TypeError(f"Unsupported operand type: {type(operand)}")
def _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
op = rval.op
left = get_operand_value(
func, module, rval.left, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
right = get_operand_value(
func, module, rval.right, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
logger.info(f"left is {left}, right is {right}, op is {op}")
# NOTE: Before doing the operation, if the operands are integers
# we always extend them to i64. The assignment to LHS will take
# care of truncation if needed.
if isinstance(left.type, ir.IntType) and left.type.width < 64:
left = builder.sext(left, ir.IntType(64))
if isinstance(right.type, ir.IntType) and right.type.width < 64:
right = builder.sext(right, ir.IntType(64))
# Map AST operation nodes to LLVM IR builder methods
op_map = {
ast.Add: builder.add,
ast.Sub: builder.sub,
ast.Mult: builder.mul,
ast.Div: builder.sdiv,
ast.Mod: builder.srem,
ast.LShift: builder.shl,
ast.RShift: builder.lshr,
ast.BitOr: builder.or_,
ast.BitXor: builder.xor,
ast.BitAnd: builder.and_,
ast.FloorDiv: builder.udiv,
}
if type(op) in op_map:
result = op_map[type(op)](left, right)
return result
else:
raise SyntaxError("Unsupported binary operation")
def _handle_binary_op(
func,
module,
rval,
builder,
var_name,
local_sym_tab,
map_sym_tab,
structs_sym_tab=None,
):
result = _handle_binary_op_impl(
func, module, rval, builder, local_sym_tab, map_sym_tab, structs_sym_tab
)
if var_name and var_name in local_sym_tab:
logger.info(
f"Storing result {result} into variable {local_sym_tab[var_name].var}"
)
builder.store(result, local_sym_tab[var_name].var)
return result, result.type
# ============================================================================
# Comparison and Unary Operations
# ============================================================================
def _handle_ctypes_call(
func,
module,
@ -329,32 +184,21 @@ def _handle_unary_op(
structs_sym_tab=None,
):
"""Handle ast.UnaryOp expressions."""
if not isinstance(expr.op, ast.Not) and not isinstance(expr.op, ast.USub):
logger.error("Only 'not' and '-' unary operators are supported")
if not isinstance(expr.op, ast.Not):
logger.error("Only 'not' unary operator is supported")
return None
operand = get_operand_value(
func, module, expr.operand, builder, local_sym_tab, map_sym_tab, structs_sym_tab
operand = eval_expr(
func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab
)
if operand is None:
logger.error("Failed to evaluate operand for unary operation")
return None
if isinstance(expr.op, ast.Not):
true_const = ir.Constant(ir.IntType(1), 1)
result = builder.xor(convert_to_bool(builder, operand), true_const)
return result, ir.IntType(1)
elif isinstance(expr.op, ast.USub):
# Multiply by -1
neg_one = ir.Constant(ir.IntType(64), -1)
result = builder.mul(operand, neg_one)
return result, ir.IntType(64)
return None
# ============================================================================
# Boolean Operations
# ============================================================================
operand_val, operand_type = operand
true_const = ir.Constant(ir.IntType(1), 1)
result = builder.xor(convert_to_bool(builder, operand_val), true_const)
return result, ir.IntType(1)
def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab):
@ -487,11 +331,6 @@ def _handle_boolean_op(
return None
# ============================================================================
# Expression Dispatcher
# ============================================================================
def eval_expr(
func,
module,
@ -501,11 +340,26 @@ def eval_expr(
map_sym_tab,
structs_sym_tab=None,
):
"""
Evaluate an expression and return its LLVM IR value and type.
Args:
func: The LLVM IR function being built
module: The LLVM IR module
builder: LLVM IR builder
expr: The AST expression node to evaluate
local_sym_tab: Local symbol table
map_sym_tab: Map symbol table
structs_sym_tab: Struct symbol table
Returns:
A tuple of (value, type) or None if evaluation fails
"""
logger.info(f"Evaluating expression: {ast.dump(expr)}")
if isinstance(expr, ast.Name):
return _handle_name_expr(expr, local_sym_tab, builder)
elif isinstance(expr, ast.Constant):
return _handle_constant_expr(module, builder, expr)
return _handle_constant_expr(expr)
elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder)
@ -521,27 +375,57 @@ def eval_expr(
structs_sym_tab,
)
result = CallHandlerRegistry.handle_call(
expr, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
)
if result is not None:
return result
# delayed import to avoid circular dependency
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
logger.warning(f"Unknown call: {ast.dump(expr)}")
return None
if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler(
expr.func.id
):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr.func, ast.Attribute):
logger.info(f"Handling method call: {ast.dump(expr.func)}")
if isinstance(expr.func.value, ast.Call) and isinstance(
expr.func.value.func, ast.Name
):
method_name = expr.func.attr
if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr.func.value, ast.Name):
obj_name = expr.func.value.id
method_name = expr.func.attr
if obj_name in map_sym_tab:
if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
elif isinstance(expr, ast.Attribute):
return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
elif isinstance(expr, ast.BinOp):
return _handle_binary_op(
func,
module,
expr,
builder,
None,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
from pythonbpf.binary_ops import handle_binary_op
return handle_binary_op(expr, builder, None, local_sym_tab)
elif isinstance(expr, ast.Compare):
return _handle_compare(
func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab

View File

@ -1,50 +0,0 @@
import logging
from llvmlite import ir
logger = logging.getLogger(__name__)
def deref_to_depth(func, builder, val, target_depth):
"""Dereference a pointer to a certain depth."""
cur_val = val
cur_type = val.type
for depth in range(target_depth):
if not isinstance(val.type, ir.PointerType):
logger.error("Cannot dereference further, non-pointer type")
return None
# dereference with null check
pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
null_ptr = ir.Constant(cur_type, None)
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
builder.cbranch(is_not_null, not_null_block, merge_block)
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
)
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type
return cur_val

View File

@ -1,7 +1,13 @@
"""
Type normalization and comparison operations for expressions.
This module provides utilities for normalizing types between expressions,
handling pointer dereferencing, and generating comparison operations.
"""
from llvmlite import ir
import logging
import ast
from llvmlite import ir
from .ir_ops import deref_to_depth
logger = logging.getLogger(__name__)
@ -17,8 +23,16 @@ COMPARISON_OPS = {
}
def get_base_type_and_depth(ir_type):
"""Get the base type for pointer types."""
def _get_base_type_and_depth(ir_type):
"""
Get the base type and pointer depth for an LLVM IR type.
Args:
ir_type: The LLVM IR type to analyze
Returns:
A tuple of (base_type, depth) where depth is the number of pointer levels
"""
cur_type = ir_type
depth = 0
while isinstance(cur_type, ir.PointerType):
@ -27,8 +41,76 @@ def get_base_type_and_depth(ir_type):
return cur_type, depth
def _deref_to_depth(func, builder, val, target_depth):
"""
Dereference a pointer to a certain depth with null checks.
Args:
func: The LLVM IR function being built
builder: LLVM IR builder
val: The pointer value to dereference
target_depth: Number of levels to dereference
Returns:
The dereferenced value, or None if dereferencing fails
"""
cur_val = val
cur_type = val.type
for depth in range(target_depth):
if not isinstance(val.type, ir.PointerType):
logger.error("Cannot dereference further, non-pointer type")
return None
# dereference with null check
pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
null_ptr = ir.Constant(cur_type, None)
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
builder.cbranch(is_not_null, not_null_block, merge_block)
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
)
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type
return cur_val
def _normalize_types(func, builder, lhs, rhs):
"""Normalize types for comparison."""
"""
Normalize types for comparison by casting or dereferencing as needed.
Args:
func: The LLVM IR function being built
builder: LLVM IR builder
lhs: Left-hand side value
rhs: Right-hand side value
Returns:
A tuple of (normalized_lhs, normalized_rhs) or (None, None) on error
"""
logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}")
if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType):
@ -43,18 +125,27 @@ def _normalize_types(func, builder, lhs, rhs):
logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}")
return None, None
else:
lhs_base, lhs_depth = get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = get_base_type_and_depth(rhs.type)
lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type)
rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type)
if lhs_base == rhs_base:
if lhs_depth < rhs_depth:
rhs = deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth)
elif rhs_depth < lhs_depth:
lhs = deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth)
return _normalize_types(func, builder, lhs, rhs)
def convert_to_bool(builder, val):
"""Convert a value to boolean."""
"""
Convert an LLVM IR value to a boolean (i1) type.
Args:
builder: LLVM IR builder
val: The value to convert
Returns:
An i1 boolean value
"""
if val.type == ir.IntType(1):
return val
if isinstance(val.type, ir.PointerType):
@ -65,7 +156,19 @@ def convert_to_bool(builder, val):
def handle_comparator(func, builder, op, lhs, rhs):
"""Handle comparison operations."""
"""
Handle comparison operations between two values.
Args:
func: The LLVM IR function being built
builder: LLVM IR builder
op: The AST comparison operator node
lhs: Left-hand side value
rhs: Right-hand side value
Returns:
A tuple of (result, ir.IntType(1)) or None on error
"""
if lhs.type != rhs.type:
lhs, rhs = _normalize_types(func, builder, lhs, rhs)

View File

@ -1,45 +0,0 @@
import ast
class VmlinuxHandlerRegistry:
"""Registry for vmlinux handler operations"""
_handler = None
@classmethod
def set_handler(cls, handler):
"""Set the vmlinux handler"""
cls._handler = handler
@classmethod
def get_handler(cls):
"""Get the vmlinux handler"""
return cls._handler
@classmethod
def handle_name(cls, name):
"""Try to handle a name as vmlinux enum/constant"""
if cls._handler is None:
return None
return cls._handler.handle_vmlinux_enum(name)
@classmethod
def handle_attribute(cls, expr, local_sym_tab, module, builder):
"""Try to handle an attribute access as vmlinux struct field"""
if cls._handler is None:
return None
if isinstance(expr.value, ast.Name):
var_name = expr.value.id
field_name = expr.attr
return cls._handler.handle_vmlinux_struct_field(
var_name, field_name, module, builder, local_sym_tab
)
return None
@classmethod
def is_vmlinux_struct(cls, name):
"""Check if a name refers to a vmlinux struct"""
if cls._handler is None:
return False
return cls._handler.is_vmlinux_struct(name)

View File

@ -1,3 +1,5 @@
"""BPF function processing and LLVM IR generation."""
from .functions_pass import func_proc
__all__ = ["func_proc"]

View File

@ -0,0 +1,25 @@
"""Registry for statement handler functions."""
from typing import Dict
class StatementHandlerRegistry:
"""Registry for statement handlers."""
_handlers: Dict = {}
@classmethod
def register(cls, stmt_type):
"""Register a handler for a specific statement type."""
def decorator(handler):
"""Decorator that registers the handler."""
cls._handlers[stmt_type] = handler
return handler
return decorator
@classmethod
def __getitem__(cls, stmt_type):
"""Get the handler for a specific statement type."""
return cls._handlers.get(stmt_type, None)

View File

@ -1,88 +0,0 @@
import ast
def get_probe_string(func_node):
"""Extract the probe string from the decorator of the function node"""
# TODO: right now we have the whole string in the section decorator
# But later we can implement typed tuples for tracepoints and kprobes
# For helper functions, we return "helper"
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal":
return None
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name):
if decorator.func.id == "section" and len(decorator.args) == 1:
arg = decorator.args[0]
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
return arg.value
return "helper"
def is_global_function(func_node):
"""Check if the function is a global"""
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id in (
"map",
"bpfglobal",
"struct",
):
return True
return False
def infer_return_type(func_node: ast.FunctionDef):
if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
raise TypeError("Expected ast.FunctionDef")
if func_node.returns is not None:
try:
return ast.unparse(func_node.returns)
except Exception:
node = func_node.returns
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return getattr(node, "attr", type(node).__name__)
try:
return str(node)
except Exception:
return type(node).__name__
found_type = None
def _expr_type(e):
if e is None:
return "None"
if isinstance(e, ast.Constant):
return type(e.value).__name__
if isinstance(e, ast.Name):
return e.id
if isinstance(e, ast.Call):
f = e.func
if isinstance(f, ast.Name):
return f.id
if isinstance(f, ast.Attribute):
try:
return ast.unparse(f)
except Exception:
return getattr(f, "attr", type(f).__name__)
try:
return ast.unparse(f)
except Exception:
return type(f).__name__
if isinstance(e, ast.Attribute):
try:
return ast.unparse(e)
except Exception:
return getattr(e, "attr", type(e).__name__)
try:
return ast.unparse(e)
except Exception:
return type(e).__name__
for walked_node in ast.walk(func_node):
if isinstance(walked_node, ast.Return):
t = _expr_type(walked_node.value)
if found_type is None:
found_type = t
elif found_type != t:
raise ValueError(f"Conflicting return types: {found_type} vs {t}")
return found_type or "None"

View File

@ -1,191 +1,281 @@
"""
BPF function processing and LLVM IR generation.
This module handles the core function processing, converting Python function
definitions into LLVM IR for BPF programs. It manages local variables,
control flow, and statement processing.
"""
from llvmlite import ir
import ast
import logging
from typing import Any
from dataclasses import dataclass
from pythonbpf.helper import (
HelperHandlerRegistry,
reset_scratch_pool,
)
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.binary_ops import handle_binary_op
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
from pythonbpf.assign_pass import (
handle_variable_assignment,
handle_struct_field_assignment,
)
from pythonbpf.allocation_pass import (
handle_assign_allocation,
allocate_temp_pool,
create_targets_and_rvals,
)
from .return_utils import handle_none_return, handle_xdp_return, is_xdp_name
from .function_metadata import get_probe_string, is_global_function, infer_return_type
from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name
logger = logging.getLogger(__name__)
# ============================================================================
# SECTION 1: Memory Allocation
# ============================================================================
@dataclass
class LocalSymbol:
"""
Represents a local variable in a BPF function.
Attributes:
var: LLVM IR alloca instruction for the variable
ir_type: LLVM IR type of the variable
metadata: Optional metadata (e.g., struct type name)
"""
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def __iter__(self):
"""Support tuple unpacking of LocalSymbol."""
yield self.var
yield self.ir_type
yield self.metadata
def count_temps_in_call(call_node, local_sym_tab):
"""Count the number of temporary variables needed for a function call."""
def get_probe_string(func_node):
"""Extract the probe string from the decorator of the function node."""
# TODO: right now we have the whole string in the section decorator
# But later we can implement typed tuples for tracepoints and kprobes
# For helper functions, we return "helper"
count = 0
is_helper = False
# NOTE: We exclude print calls for now
if isinstance(call_node.func, ast.Name):
if (
HelperHandlerRegistry.has_handler(call_node.func.id)
and call_node.func.id != "print"
):
is_helper = True
elif isinstance(call_node.func, ast.Attribute):
if HelperHandlerRegistry.has_handler(call_node.func.attr):
is_helper = True
if not is_helper:
return 0
for arg in call_node.args:
# NOTE: Count all non-name arguments
# For struct fields, if it is being passed as an argument,
# The struct object should already exist in the local_sym_tab
if not isinstance(arg, ast.Name) and not (
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
):
count += 1
return count
def handle_if_allocation(
module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Recursively handle allocations in if/else branches."""
if stmt.body:
allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
max_temps_needed = 0
def update_max_temps_for_stmt(stmt):
nonlocal max_temps_needed
temps_needed = 0
if isinstance(stmt, ast.If):
for s in stmt.body:
update_max_temps_for_stmt(s)
for s in stmt.orelse:
update_max_temps_for_stmt(s)
return
for node in ast.walk(stmt):
if isinstance(node, ast.Call):
temps_needed += count_temps_in_call(node, local_sym_tab)
max_temps_needed = max(max_temps_needed, temps_needed)
for stmt in body:
update_max_temps_for_stmt(stmt)
# Handle allocations
if isinstance(stmt, ast.If):
handle_if_allocation(
module,
builder,
stmt,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab)
allocate_temp_pool(builder, max_temps_needed, local_sym_tab)
return local_sym_tab
# ============================================================================
# SECTION 2: Statement Handlers
# ============================================================================
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal":
return None
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name):
if decorator.func.id == "section" and len(decorator.args) == 1:
arg = decorator.args[0]
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
return arg.value
return "helper"
def handle_assign(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""Handle assignment statements in the function body."""
if len(stmt.targets) != 1:
logger.info("Unsupported multiassignment")
return
# NOTE: Support multi-target assignments (e.g.: a, b = 1, 2)
targets, rvals = create_targets_and_rvals(stmt)
num_types = ("c_int32", "c_int64", "c_uint32", "c_uint64")
for target, rval in zip(targets, rvals):
if isinstance(target, ast.Name):
# NOTE: Simple variable assignment case: x = 5
var_name = target.id
result = handle_variable_assignment(
func,
module,
builder,
var_name,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
target = stmt.targets[0]
logger.info(f"Handling assignment to {ast.dump(target)}")
if not isinstance(target, ast.Name) and not isinstance(target, ast.Attribute):
logger.info("Unsupported assignment target")
return
var_name = target.id if isinstance(target, ast.Name) else target.value.id
rval = stmt.value
if isinstance(target, ast.Attribute):
# struct field assignment
field_name = target.attr
if var_name in local_sym_tab:
struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type]
if field_name in struct_info.fields:
field_ptr = struct_info.gep(
builder, local_sym_tab[var_name].var, field_name
)
val = eval_expr(
func,
module,
builder,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if isinstance(struct_info.field_type(field_name), ir.ArrayType) and val[
1
] == ir.PointerType(ir.IntType(8)):
# TODO: Figure it out, not a priority rn
# Special case for string assignment to char array
# str_len = struct_info["field_types"][field_idx].count
# assign_string_to_array(builder, field_ptr, val[0], str_len)
# print(f"Assigned to struct field {var_name}.{field_name}")
pass
if val is None:
logger.info("Failed to evaluate struct field assignment")
return
logger.info(field_ptr)
builder.store(val[0], field_ptr)
logger.info(f"Assigned to struct field {var_name}.{field_name}")
return
elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool):
if rval.value:
builder.store(
ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name].var
)
else:
builder.store(
ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name].var
)
logger.info(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, int):
# Assume c_int64 for now
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
builder.store(
ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name].var
)
if not result:
logger.error(f"Failed to handle assignment to {var_name}")
continue
if isinstance(target, ast.Attribute):
# NOTE: Struct field assignment case: pkt.field = value
handle_struct_field_assignment(
func,
module,
builder,
target,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
logger.info(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, str):
str_val = rval.value.encode("utf-8") + b"\x00"
str_const = ir.Constant(
ir.ArrayType(ir.IntType(8), len(str_val)), bytearray(str_val)
)
continue
# Unsupported target type
logger.error(f"Unsupported assignment target: {ast.dump(target)}")
global_str = ir.GlobalVariable(
module, str_const.type, name=f"{var_name}_str"
)
global_str.linkage = "internal"
global_str.global_constant = True
global_str.initializer = str_const
str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
builder.store(str_ptr, local_sym_tab[var_name].var)
logger.info(f"Assigned string constant '{rval.value}' to {var_name}")
else:
logger.info("Unsupported constant type")
elif isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name):
call_type = rval.func.id
logger.info(f"Assignment call type: {call_type}")
if (
call_type in num_types
and len(rval.args) == 1
and isinstance(rval.args[0], ast.Constant)
and isinstance(rval.args[0].value, int)
):
ir_type = ctypes_to_ir(call_type)
# var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8
builder.store(
ir.Constant(ir_type, rval.args[0].value),
local_sym_tab[var_name].var,
)
logger.info(
f"Assigned {call_type} constant {rval.args[0].value} to {var_name}"
)
elif HelperHandlerRegistry.has_handler(call_type):
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
builder.store(val[0], local_sym_tab[var_name].var)
logger.info(f"Assigned constant {rval.func.id} to {var_name}")
elif call_type == "deref" and len(rval.args) == 1:
logger.info(f"Handling deref assignment {ast.dump(rval)}")
val = eval_expr(
func,
module,
builder,
rval,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
if val is None:
logger.info("Failed to evaluate deref argument")
return
logger.info(f"Dereferenced value: {val}, storing in {var_name}")
builder.store(val[0], local_sym_tab[var_name].var)
logger.info(f"Dereferenced and assigned to {var_name}")
elif call_type in structs_sym_tab and len(rval.args) == 0:
struct_info = structs_sym_tab[call_type]
ir_type = struct_info.ir_type
# var = builder.alloca(ir_type, name=var_name)
# Null init
builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name].var)
logger.info(f"Assigned struct {call_type} to {var_name}")
else:
logger.info(f"Unsupported assignment call type: {call_type}")
elif isinstance(rval.func, ast.Attribute):
logger.info(f"Assignment call attribute: {ast.dump(rval.func)}")
if isinstance(rval.func.value, ast.Name):
if rval.func.value.id in map_sym_tab:
map_name = rval.func.value.id
method_name = rval.func.attr
if HelperHandlerRegistry.has_handler(method_name):
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
builder.store(val[0], local_sym_tab[var_name].var)
else:
# TODO: probably a struct access
logger.info(f"TODO STRUCT ACCESS {ast.dump(rval)}")
elif isinstance(rval.func.value, ast.Call) and isinstance(
rval.func.value.func, ast.Name
):
map_name = rval.func.value.func.id
method_name = rval.func.attr
if map_name in map_sym_tab:
if HelperHandlerRegistry.has_handler(method_name):
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
# var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8
builder.store(val[0], local_sym_tab[var_name].var)
else:
logger.info("Unsupported assignment call structure")
else:
logger.info("Unsupported assignment call function type")
elif isinstance(rval, ast.BinOp):
handle_binary_op(rval, builder, var_name, local_sym_tab)
else:
logger.info("Unsupported assignment value type")
def handle_cond(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None
):
"""
Evaluate a condition expression and convert it to a boolean value.
Args:
func: The LLVM IR function being built
module: The LLVM IR module
builder: LLVM IR builder
cond: The AST condition node to evaluate
local_sym_tab: Local symbol table
map_sym_tab: Map symbol table
structs_sym_tab: Struct symbol table
Returns:
LLVM IR boolean value representing the condition result
"""
val = eval_expr(
func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab
)[0]
@ -241,11 +331,23 @@ def handle_if(
def handle_return(builder, stmt, local_sym_tab, ret_type):
"""
Handle return statements in BPF functions.
Args:
builder: LLVM IR builder
stmt: The AST Return node
local_sym_tab: Local symbol table
ret_type: Expected return type
Returns:
True if a return was emitted, False otherwise
"""
logger.info(f"Handling return statement: {ast.dump(stmt)}")
if stmt.value is None:
return handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and is_xdp_name(stmt.value.id):
return handle_xdp_return(stmt, builder, ret_type)
return _handle_none_return(builder)
elif isinstance(stmt.value, ast.Name) and _is_xdp_name(stmt.value.id):
return _handle_xdp_return(stmt, builder, ret_type)
else:
val = eval_expr(
func=None,
@ -272,8 +374,24 @@ def process_stmt(
did_return,
ret_type=ir.IntType(64),
):
"""
Process a single statement in a BPF function.
Args:
func: The LLVM IR function being built
module: The LLVM IR module
builder: LLVM IR builder
stmt: The AST statement node to process
local_sym_tab: Local symbol table
map_sym_tab: Map symbol table
structs_sym_tab: Struct symbol table
did_return: Whether a return has been emitted
ret_type: Expected return type
Returns:
True if a return was emitted, False otherwise
"""
logger.info(f"Processing statement: {ast.dump(stmt)}")
reset_scratch_pool()
if isinstance(stmt, ast.Expr):
handle_expr(
func,
@ -304,19 +422,143 @@ def process_stmt(
return did_return
# ============================================================================
# SECTION 3: Function Body Processing
# ============================================================================
def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
):
"""
Pre-allocate stack memory for local variables in a BPF function.
This function scans the function body and creates alloca instructions
for all local variables before processing the function statements.
Args:
module: The LLVM IR module
builder: LLVM IR builder
body: List of AST statements in the function body
func: The LLVM IR function being built
ret_type: Expected return type
map_sym_tab: Map symbol table
local_sym_tab: Local symbol table to populate
structs_sym_tab: Struct symbol table
Returns:
Updated local symbol table
"""
for stmt in body:
has_metadata = False
if isinstance(stmt, ast.If):
if stmt.body:
local_sym_tab = allocate_mem(
module,
builder,
stmt.body,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
if stmt.orelse:
local_sym_tab = allocate_mem(
module,
builder,
stmt.orelse,
func,
ret_type,
map_sym_tab,
local_sym_tab,
structs_sym_tab,
)
elif isinstance(stmt, ast.Assign):
if len(stmt.targets) != 1:
logger.info("Unsupported multiassignment")
continue
target = stmt.targets[0]
if not isinstance(target, ast.Name):
logger.info("Unsupported assignment target")
continue
var_name = target.id
rval = stmt.value
if var_name in local_sym_tab:
logger.info(f"Variable {var_name} already allocated")
continue
if isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name):
call_type = rval.func.id
if call_type in ("c_int32", "c_int64", "c_uint32", "c_uint64"):
ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(
f"Pre-allocated variable {var_name} of type {call_type}"
)
elif HelperHandlerRegistry.has_handler(call_type):
# Assume return type is int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} for helper")
elif call_type == "deref" and len(rval.args) == 1:
# Assume return type is int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} for deref")
elif call_type in structs_sym_tab:
struct_info = structs_sym_tab[call_type]
ir_type = struct_info.ir_type
var = builder.alloca(ir_type, name=var_name)
has_metadata = True
logger.info(
f"Pre-allocated variable {var_name} for struct {call_type}"
)
elif isinstance(rval.func, ast.Attribute):
ir_type = ir.PointerType(ir.IntType(64))
var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} for map")
else:
logger.info("Unsupported assignment call function type")
continue
elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool):
ir_type = ir.IntType(1)
var = builder.alloca(ir_type, name=var_name)
var.align = 1
logger.info(f"Pre-allocated variable {var_name} of type c_bool")
elif isinstance(rval.value, int):
# Assume c_int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
elif isinstance(rval.value, str):
ir_type = ir.PointerType(ir.IntType(8))
var = builder.alloca(ir_type, name=var_name)
var.align = 8
logger.info(f"Pre-allocated variable {var_name} of type string")
else:
logger.info("Unsupported constant type")
continue
elif isinstance(rval, ast.BinOp):
# Assume c_int64 for now
ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8
logger.info(f"Pre-allocated variable {var_name} of type c_int64")
else:
logger.info("Unsupported assignment value type")
continue
if has_metadata:
local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type)
else:
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
return local_sym_tab
def process_func_body(
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab
):
"""Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now
@ -389,25 +631,33 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
builder = ir.IRBuilder(block)
process_func_body(
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab
)
return func
# ============================================================================
# SECTION 4: Top-Level Function Processor
# ============================================================================
def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
"""
Process all BPF function chunks and generate LLVM IR.
Args:
tree: The Python AST (not used in current implementation)
module: The LLVM IR module to add functions to
chunks: List of AST function nodes decorated with @bpf
map_sym_tab: Map symbol table
structs_sym_tab: Struct symbol table
"""
for func_node in chunks:
if is_global_function(func_node):
is_global = False
for decorator in func_node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id in (
"map",
"bpfglobal",
"struct",
):
is_global = True
break
if is_global:
continue
func_type = get_probe_string(func_node)
logger.info(f"Found probe_string of {func_node.name}: {func_type}")
@ -421,7 +671,80 @@ def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
)
# TODO: WIP, for string assignment to fixed-size arrays
def infer_return_type(func_node: ast.FunctionDef):
"""
Infer the return type of a BPF function from annotations or return statements.
Args:
func_node: The AST function node
Returns:
String representation of the return type (e.g., 'c_int64')
Raises:
TypeError: If func_node is not a FunctionDef
"""
if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
raise TypeError("Expected ast.FunctionDef")
if func_node.returns is not None:
try:
return ast.unparse(func_node.returns)
except Exception:
node = func_node.returns
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return getattr(node, "attr", type(node).__name__)
try:
return str(node)
except Exception:
return type(node).__name__
found_type = None
def _expr_type(e):
"""Helper function to extract type from an expression."""
if e is None:
return "None"
if isinstance(e, ast.Constant):
return type(e.value).__name__
if isinstance(e, ast.Name):
return e.id
if isinstance(e, ast.Call):
f = e.func
if isinstance(f, ast.Name):
return f.id
if isinstance(f, ast.Attribute):
try:
return ast.unparse(f)
except Exception:
return getattr(f, "attr", type(f).__name__)
try:
return ast.unparse(f)
except Exception:
return type(f).__name__
if isinstance(e, ast.Attribute):
try:
return ast.unparse(e)
except Exception:
return getattr(e, "attr", type(e).__name__)
try:
return ast.unparse(e)
except Exception:
return type(e).__name__
for walked_node in ast.walk(func_node):
if isinstance(walked_node, ast.Return):
t = _expr_type(walked_node.value)
if found_type is None:
found_type = t
elif found_type != t:
raise ValueError(f"Conflicting return types: {found_type} vs {t}")
return found_type or "None"
# For string assignment to fixed-size arrays
def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length):
"""
Copy a string (i8*) to a fixed-size array ([N x i8]*)

View File

@ -1,3 +1,10 @@
"""
Utility functions for handling return statements in BPF functions.
Provides handlers for different types of returns including XDP actions,
None returns, and standard returns.
"""
import logging
import ast
@ -14,19 +21,19 @@ XDP_ACTIONS = {
}
def handle_none_return(builder) -> bool:
def _handle_none_return(builder) -> bool:
"""Handle return or return None -> returns 0."""
builder.ret(ir.Constant(ir.IntType(64), 0))
logger.debug("Generated default return: 0")
return True
def is_xdp_name(name: str) -> bool:
def _is_xdp_name(name: str) -> bool:
"""Check if a name is an XDP action"""
return name in XDP_ACTIONS
def handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
def _handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
"""Handle XDP returns"""
if not isinstance(stmt.value, ast.Name):
return False
@ -37,6 +44,7 @@ def handle_xdp_return(stmt: ast.Return, builder, ret_type) -> bool:
raise ValueError(
f"Unknown XDP action: {action_name}. Available: {XDP_ACTIONS.keys()}"
)
return False
value = XDP_ACTIONS[action_name]
builder.ret(ir.Constant(ret_type, value))

View File

@ -1,3 +1,10 @@
"""
Global variables and compiler metadata processing.
This module handles BPF global variables and emits the @llvm.compiler.used
metadata to prevent LLVM from optimizing away important symbols.
"""
from llvmlite import ir
import ast
@ -12,6 +19,16 @@ global_sym_tab = []
def populate_global_symbol_table(tree, module: ir.Module):
"""
Populate the global symbol table with BPF functions, maps, and globals.
Args:
tree: The Python AST to scan for global symbols
module: The LLVM IR module (not used in current implementation)
Returns:
False (legacy return value)
"""
for node in tree.body:
if isinstance(node, ast.FunctionDef):
for dec in node.decorator_list:
@ -33,6 +50,17 @@ def populate_global_symbol_table(tree, module: ir.Module):
def emit_global(module: ir.Module, node, name):
"""
Emit a BPF global variable into the LLVM IR module.
Args:
module: The LLVM IR module to add the global variable to
node: The AST function node containing the global definition
name: The name of the global variable
Returns:
The created global variable
"""
logger.info(f"global identifier {name} processing")
# deduce LLVM type from the annotated return
if not isinstance(node.returns, ast.Name):
@ -117,7 +145,11 @@ def globals_processing(tree, module):
def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
"""
Emit the @llvm.compiler.used global given a list of function/global names.
Emit the @llvm.compiler.used global to prevent LLVM from optimizing away symbols.
Args:
module: The LLVM IR module to add the compiler.used metadata to
names: List of function/global names that must be preserved
"""
ptr_ty = ir.PointerType()
used_array_ty = ir.ArrayType(ptr_ty, len(names))
@ -138,6 +170,13 @@ def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
def globals_list_creation(tree, module: ir.Module):
"""
Collect all BPF symbols and emit @llvm.compiler.used metadata.
Args:
tree: The Python AST to scan for symbols
module: The LLVM IR module to add metadata to
"""
collected = ["LICENSE"]
for node in tree.body:

View File

@ -1,70 +1,15 @@
from .helper_registry import HelperHandlerRegistry
from .helper_utils import reset_scratch_pool
from .bpf_helper_handler import handle_helper_call, emit_probe_read_kernel_str_call
from .helpers import ktime, pid, deref, comm, probe_read_str, XDP_DROP, XDP_PASS
"""BPF helper functions and handlers."""
# Register the helper handler with expr module
def _register_helper_handler():
"""Register helper call handler with the expression evaluator"""
from pythonbpf.expr.expr_pass import CallHandlerRegistry
def helper_call_handler(
call, module, builder, func, local_sym_tab, map_sym_tab, structs_sym_tab
):
"""Check if call is a helper and handle it"""
import ast
# Check for direct helper calls (e.g., ktime(), print())
if isinstance(call.func, ast.Name):
if HelperHandlerRegistry.has_handler(call.func.id):
return handle_helper_call(
call,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
# Check for method calls (e.g., map.lookup())
elif isinstance(call.func, ast.Attribute):
method_name = call.func.attr
# Handle: my_map.lookup(key)
if isinstance(call.func.value, ast.Name):
obj_name = call.func.value.id
if map_sym_tab and obj_name in map_sym_tab:
if HelperHandlerRegistry.has_handler(method_name):
return handle_helper_call(
call,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
return None
CallHandlerRegistry.set_handler(helper_call_handler)
# Register on module import
_register_helper_handler()
from .helper_utils import HelperHandlerRegistry
from .bpf_helper_handler import handle_helper_call
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
__all__ = [
"HelperHandlerRegistry",
"reset_scratch_pool",
"handle_helper_call",
"emit_probe_read_kernel_str_call",
"ktime",
"pid",
"deref",
"comm",
"probe_read_str",
"XDP_DROP",
"XDP_PASS",
]

View File

@ -1,18 +1,22 @@
"""
BPF helper function handlers for LLVM IR emission.
This module provides handlers for various BPF helper functions, emitting
the appropriate LLVM IR to call kernel BPF helpers like map operations,
printing, time functions, etc.
"""
import ast
from llvmlite import ir
from enum import Enum
from .helper_registry import HelperHandlerRegistry
from .helper_utils import (
HelperHandlerRegistry,
get_or_create_ptr_from_arg,
get_flags_val,
handle_fstring_print,
simple_string_print,
get_data_ptr_and_size,
get_buffer_ptr_and_size,
get_char_array_ptr_and_size,
get_ptr_from_arg,
)
from .printk_formatter import simple_string_print, handle_fstring_print
from logging import Logger
import logging
@ -20,15 +24,15 @@ logger: Logger = logging.getLogger(__name__)
class BPFHelperID(Enum):
"""Enumeration of BPF helper function IDs."""
BPF_MAP_LOOKUP_ELEM = 1
BPF_MAP_UPDATE_ELEM = 2
BPF_MAP_DELETE_ELEM = 3
BPF_KTIME_GET_NS = 5
BPF_PRINTK = 6
BPF_GET_CURRENT_PID_TGID = 14
BPF_GET_CURRENT_COMM = 16
BPF_PERF_EVENT_OUTPUT = 25
BPF_PROBE_READ_KERNEL_STR = 115
@HelperHandlerRegistry.register("ktime")
@ -40,7 +44,6 @@ def bpf_ktime_get_ns_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_ktime_get_ns helper function call.
@ -63,7 +66,6 @@ def bpf_map_lookup_elem_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_lookup_elem helper function call.
@ -72,17 +74,11 @@ def bpf_map_lookup_elem_emitter(
raise ValueError(
f"Map lookup expects exactly one argument (key), got {len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
# TODO: I have changed the return type to i64*, as we are
# allocating space for that type in allocate_mem. This is
# temporary, and we will honour other widths later. But this
# allows us to have cool binary ops on the returned value.
fn_type = ir.FunctionType(
ir.PointerType(ir.IntType(64)), # Return type: void*
ir.PointerType(), # Return type: void*
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
var_arg=False,
)
@ -105,7 +101,6 @@ def bpf_printk_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""Emit LLVM IR for bpf_printk helper function call."""
if not hasattr(func, "_fmt_counter"):
@ -141,7 +136,7 @@ def bpf_printk_emitter(
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
builder.call(fn_ptr, args, tail=True)
return True
return None
@HelperHandlerRegistry.register("update")
@ -153,7 +148,6 @@ def bpf_map_update_elem_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_update_elem helper function call.
@ -168,12 +162,8 @@ def bpf_map_update_elem_emitter(
value_arg = call.args[1]
flags_arg = call.args[2] if len(call.args) > 2 else None
key_ptr = get_or_create_ptr_from_arg(
func, module, key_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
value_ptr = get_or_create_ptr_from_arg(
func, module, value_arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
key_ptr = get_or_create_ptr_from_arg(key_arg, builder, local_sym_tab)
value_ptr = get_or_create_ptr_from_arg(value_arg, builder, local_sym_tab)
flags_val = get_flags_val(flags_arg, builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -208,7 +198,6 @@ def bpf_map_delete_elem_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_map_delete_elem helper function call.
@ -218,9 +207,7 @@ def bpf_map_delete_elem_emitter(
raise ValueError(
f"Map delete expects exactly one argument (key), got {len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(
func, module, call.args[0], builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
# Define function type for bpf_map_delete_elem
@ -239,63 +226,6 @@ def bpf_map_delete_elem_emitter(
return result, None
@HelperHandlerRegistry.register("comm")
def bpf_get_current_comm_emitter(
call,
map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_comm helper function call.
Accepts: comm(dataobj.field) or comm(my_buffer)
"""
if not call.args or len(call.args) != 1:
raise ValueError(
f"comm expects exactly one argument (buffer), got {len(call.args)}"
)
buf_arg = call.args[0]
# Extract buffer pointer and size
buf_ptr, buf_size = get_buffer_ptr_and_size(
buf_arg, builder, local_sym_tab, struct_sym_tab
)
# Validate it's a char array
if not isinstance(
buf_ptr.type.pointee, ir.ArrayType
) or buf_ptr.type.pointee.element != ir.IntType(8):
raise ValueError(
f"comm expects a char array buffer, got {buf_ptr.type.pointee}"
)
# Cast to void* and call helper
buf_void_ptr = builder.bitcast(buf_ptr, ir.PointerType())
fn_type = ir.FunctionType(
ir.IntType(64),
[ir.PointerType(), ir.IntType(32)],
var_arg=False,
)
fn_ptr = builder.inttoptr(
ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_COMM.value),
ir.PointerType(fn_type),
)
result = builder.call(
fn_ptr, [buf_void_ptr, ir.Constant(ir.IntType(32), buf_size)], tail=False
)
logger.info(f"Emitted bpf_get_current_comm with {buf_size} byte buffer")
return result, None
@HelperHandlerRegistry.register("pid")
def bpf_get_current_pid_tgid_emitter(
call,
@ -305,7 +235,6 @@ def bpf_get_current_pid_tgid_emitter(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
@ -332,8 +261,12 @@ def bpf_perf_event_output_handler(
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""
Emit LLVM IR for bpf_perf_event_output helper function call.
This allows sending data to userspace via a perf event array.
"""
if len(call.args) != 1:
raise ValueError(
f"Perf event output expects exactly one argument, got {len(call.args)}"
@ -371,68 +304,6 @@ def bpf_perf_event_output_handler(
return result, None
def emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr):
"""Emit LLVM IR call to bpf_probe_read_kernel_str"""
fn_type = ir.FunctionType(
ir.IntType(64),
[ir.PointerType(), ir.IntType(32), ir.PointerType()],
var_arg=False,
)
fn_ptr = builder.inttoptr(
ir.Constant(ir.IntType(64), BPFHelperID.BPF_PROBE_READ_KERNEL_STR.value),
ir.PointerType(fn_type),
)
result = builder.call(
fn_ptr,
[
builder.bitcast(dst_ptr, ir.PointerType()),
ir.Constant(ir.IntType(32), dst_size),
builder.bitcast(src_ptr, ir.PointerType()),
],
tail=False,
)
logger.info(f"Emitted bpf_probe_read_kernel_str (size={dst_size})")
return result
@HelperHandlerRegistry.register("probe_read_str")
def bpf_probe_read_kernel_str_emitter(
call,
map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
map_sym_tab=None,
):
"""Emit LLVM IR for bpf_probe_read_kernel_str helper."""
if len(call.args) != 2:
raise ValueError(
f"probe_read_str expects 2 args (dst, src), got {len(call.args)}"
)
# Get destination buffer (char array -> i8*)
dst_ptr, dst_size = get_char_array_ptr_and_size(
call.args[0], builder, local_sym_tab, struct_sym_tab
)
# Get source pointer (evaluate expression)
src_ptr, src_type = get_ptr_from_arg(
call.args[1], func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
)
# Emit the helper call
result = emit_probe_read_kernel_str_call(builder, dst_ptr, dst_size, src_ptr)
logger.info(f"Emitted bpf_probe_read_kernel_str (size={dst_size})")
return result, ir.IntType(64)
def handle_helper_call(
call,
module,
@ -446,6 +317,7 @@ def handle_helper_call(
# Helper function to get map pointer and invoke handler
def invoke_helper(method_name, map_ptr=None):
"""Helper function to look up and invoke a registered handler."""
handler = HelperHandlerRegistry.get_handler(method_name)
if not handler:
raise NotImplementedError(
@ -459,7 +331,6 @@ def handle_helper_call(
func,
local_sym_tab,
struct_sym_tab,
map_sym_tab,
)
# Handle direct function calls (e.g., print(), ktime())

View File

@ -1,27 +0,0 @@
from typing import Callable
class HelperHandlerRegistry:
"""Registry for BPF helpers"""
_handlers: dict[str, Callable] = {}
@classmethod
def register(cls, helper_name):
"""Decorator to register a handler function for a helper"""
def decorator(func):
cls._handlers[helper_name] = func
return func
return decorator
@classmethod
def get_handler(cls, helper_name):
"""Get the handler function for a helper"""
return cls._handlers.get(helper_name)
@classmethod
def has_handler(cls, helper_name):
"""Check if a handler function is registered for a helper"""
return helper_name in cls._handlers

View File

@ -1,102 +1,130 @@
"""
Utility functions for BPF helper function handling.
This module provides utility functions for processing BPF helper function
calls, including argument handling, string formatting for bpf_printk,
and a registry for helper function handlers.
"""
import ast
import logging
from collections.abc import Callable
from llvmlite import ir
from pythonbpf.expr import (
get_operand_value,
eval_expr,
)
from pythonbpf.expr import eval_expr
logger = logging.getLogger(__name__)
class ScratchPoolManager:
"""Manage the temporary helper variables in local_sym_tab"""
class HelperHandlerRegistry:
"""Registry for BPF helpers"""
def __init__(self):
self._counter = 0
_handlers: dict[str, Callable] = {}
@property
def counter(self):
return self._counter
@classmethod
def register(cls, helper_name):
"""Decorator to register a handler function for a helper"""
def reset(self):
self._counter = 0
logger.debug("Scratch pool counter reset to 0")
def decorator(func):
"""Decorator that registers the handler function."""
cls._handlers[helper_name] = func
return func
def get_next_temp(self, local_sym_tab):
temp_name = f"__helper_temp_{self._counter}"
self._counter += 1
return decorator
if temp_name not in local_sym_tab:
raise ValueError(
f"Scratch pool exhausted or inadequate: {temp_name}. "
f"Current counter: {self._counter}"
)
@classmethod
def get_handler(cls, helper_name):
"""Get the handler function for a helper"""
return cls._handlers.get(helper_name)
return local_sym_tab[temp_name].var, temp_name
_temp_pool_manager = ScratchPoolManager() # Singleton instance
def reset_scratch_pool():
"""Reset the scratch pool counter"""
_temp_pool_manager.reset()
# ============================================================================
# Argument Preparation
# ============================================================================
@classmethod
def has_handler(cls, helper_name):
"""Check if a handler function is registered for a helper"""
return helper_name in cls._handlers
def get_var_ptr_from_name(var_name, local_sym_tab):
"""Get a pointer to a variable from the symbol table."""
"""
Get a pointer to a variable from the symbol table.
Args:
var_name: Name of the variable to look up
local_sym_tab: Local symbol table
Returns:
Pointer to the variable
Raises:
ValueError: If the variable is not found
"""
if local_sym_tab and var_name in local_sym_tab:
return local_sym_tab[var_name].var
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
"""Create a pointer to an integer constant."""
def create_int_constant_ptr(value, builder, int_width=64):
"""
Create a pointer to an integer constant.
Args:
value: The integer value
builder: LLVM IR builder
int_width: Width of the integer in bits (default: 64)
Returns:
Pointer to the allocated integer constant
"""
# Default to 64-bit integer
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.info(f"Using temp variable '{temp_name}' for int constant {value}")
const_val = ir.Constant(ir.IntType(int_width), value)
builder.store(const_val, ptr)
int_type = ir.IntType(int_width)
ptr = builder.alloca(int_type)
ptr.align = int_type.width // 8
builder.store(ir.Constant(int_type, value), ptr)
return ptr
def get_or_create_ptr_from_arg(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab=None
):
"""Extract or create pointer from the call arguments."""
def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
"""
Extract or create pointer from call arguments.
Args:
arg: The AST argument node
builder: LLVM IR builder
local_sym_tab: Local symbol table
Returns:
Pointer to the argument value
Raises:
NotImplementedError: If the argument type is not supported
"""
if isinstance(arg, ast.Name):
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
ptr = create_int_constant_ptr(arg.value, builder)
else:
# Evaluate the expression and store the result in a temp variable
val = get_operand_value(
func, module, arg, builder, local_sym_tab, map_sym_tab, struct_sym_tab
raise NotImplementedError(
"Only simple variable names are supported as args in map helpers."
)
if val is None:
raise ValueError("Failed to evaluate expression for helper arg.")
# NOTE: We assume the result is an int64 for now
# if isinstance(arg, ast.Attribute):
# return val
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
logger.info(f"Using temp variable '{temp_name}' for expression result")
builder.store(val, ptr)
return ptr
def get_flags_val(arg, builder, local_sym_tab):
"""Extract or create flags value from the call arguments."""
"""
Extract or create flags value from call arguments.
Args:
arg: The AST argument node for flags
builder: LLVM IR builder
local_sym_tab: Local symbol table
Returns:
Integer flags value or LLVM IR value
Raises:
ValueError: If a variable is not found in symbol table
NotImplementedError: If the argument type is not supported
"""
if not arg:
return 0
@ -114,6 +142,231 @@ def get_flags_val(arg, builder, local_sym_tab):
)
def simple_string_print(string_value, module, builder, func):
"""
Prepare arguments for bpf_printk from a simple string value.
Args:
string_value: The string to print
module: LLVM IR module
builder: LLVM IR builder
func: The LLVM IR function being built
Returns:
List of arguments for bpf_printk
"""
fmt_str = string_value + "\n\0"
fmt_ptr = _create_format_string_global(fmt_str, func, module, builder)
args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))]
return args
def handle_fstring_print(
joined_str,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
"""
Handle f-string formatting for bpf_printk emitter.
Args:
joined_str: AST JoinedStr node representing the f-string
module: LLVM IR module
builder: LLVM IR builder
func: The LLVM IR function being built
local_sym_tab: Local symbol table
struct_sym_tab: Struct symbol table
Returns:
List of arguments for bpf_printk
Raises:
NotImplementedError: If f-string contains unsupported value types
"""
fmt_parts = []
exprs = []
for value in joined_str.values:
logger.debug(f"Processing f-string value: {ast.dump(value)}")
if isinstance(value, ast.Constant):
_process_constant_in_fstring(value, fmt_parts, exprs)
elif isinstance(value, ast.FormattedValue):
_process_fval(
value,
fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
)
else:
raise NotImplementedError(f"Unsupported f-string value type: {type(value)}")
fmt_str = "".join(fmt_parts)
args = simple_string_print(fmt_str, module, builder, func)
# NOTE: Process expressions (limited to 3 due to BPF constraints)
if len(exprs) > 3:
logger.warning("bpf_printk supports up to 3 args, extra args will be ignored.")
for expr in exprs[:3]:
arg_value = _prepare_expr_args(
expr,
func,
module,
builder,
local_sym_tab,
struct_sym_tab,
)
args.append(arg_value)
return args
def _process_constant_in_fstring(cst, fmt_parts, exprs):
"""Process constant values in f-string."""
if isinstance(cst.value, str):
fmt_parts.append(cst.value)
elif isinstance(cst.value, int):
fmt_parts.append("%lld")
exprs.append(ir.Constant(ir.IntType(64), cst.value))
else:
raise NotImplementedError(
f"Unsupported constant type in f-string: {type(cst.value)}"
)
def _process_fval(fval, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
"""Process formatted values in f-string."""
logger.debug(f"Processing formatted value: {ast.dump(fval)}")
if isinstance(fval.value, ast.Name):
_process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab)
elif isinstance(fval.value, ast.Attribute):
_process_attr_in_fval(
fval.value,
fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
)
else:
raise NotImplementedError(
f"Unsupported formatted value in f-string: {type(fval.value)}"
)
def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
"""Process name nodes in formatted values."""
if local_sym_tab and name_node.id in local_sym_tab:
_, var_type, tmp = local_sym_tab[name_node.id]
_populate_fval(var_type, name_node, fmt_parts, exprs)
def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
"""Process attribute nodes in formatted values."""
if (
isinstance(attr_node.value, ast.Name)
and local_sym_tab
and attr_node.value.id in local_sym_tab
):
var_name = attr_node.value.id
field_name = attr_node.attr
var_type = local_sym_tab[var_name].metadata
if var_type not in struct_sym_tab:
raise ValueError(
f"Struct '{var_type}' for '{var_name}' not in symbol table"
)
struct_info = struct_sym_tab[var_type]
if field_name not in struct_info.fields:
raise ValueError(f"Field '{field_name}' not found in struct '{var_type}'")
field_type = struct_info.field_type(field_name)
_populate_fval(field_type, attr_node, fmt_parts, exprs)
else:
raise NotImplementedError(
"Only simple attribute on local vars is supported in f-strings."
)
def _populate_fval(ftype, node, fmt_parts, exprs):
"""Populate format parts and expressions based on field type."""
if isinstance(ftype, ir.IntType):
# TODO: We print as signed integers only for now
if ftype.width == 64:
fmt_parts.append("%lld")
exprs.append(node)
elif ftype.width == 32:
fmt_parts.append("%d")
exprs.append(node)
else:
raise NotImplementedError(
f"Unsupported integer width in f-string: {ftype.width}"
)
elif ftype == ir.PointerType(ir.IntType(8)):
# NOTE: We assume i8* is a string
fmt_parts.append("%s")
exprs.append(node)
else:
raise NotImplementedError(f"Unsupported field type in f-string: {ftype}")
def _create_format_string_global(fmt_str, func, module, builder):
"""Create a global variable for the format string."""
fmt_name = f"{func.name}____fmt{func._fmt_counter}"
func._fmt_counter += 1
fmt_gvar = ir.GlobalVariable(
module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name
)
fmt_gvar.global_constant = True
fmt_gvar.initializer = ir.Constant(
ir.ArrayType(ir.IntType(8), len(fmt_str)), bytearray(fmt_str.encode("utf8"))
)
fmt_gvar.linkage = "internal"
fmt_gvar.align = 1
return builder.bitcast(fmt_gvar, ir.PointerType())
def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_tab):
"""Evaluate and prepare an expression to use as an arg for bpf_printk."""
val, _ = eval_expr(
func,
module,
builder,
expr,
local_sym_tab,
None,
struct_sym_tab,
)
if val:
if isinstance(val.type, ir.PointerType):
val = builder.ptrtoint(val, ir.IntType(64))
elif isinstance(val.type, ir.IntType):
if val.type.width < 64:
val = builder.sext(val, ir.IntType(64))
else:
logger.warning(
"Only int and ptr supported in bpf_printk args. Others default to 0."
)
val = ir.Constant(ir.IntType(64), 0)
return val
else:
logger.warning(
"Failed to evaluate expression for bpf_printk argument. "
"It will be converted to 0."
)
return ir.Constant(ir.IntType(64), 0)
def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
"""Extract data pointer and size information for perf event output."""
if isinstance(data_arg, ast.Name):
@ -137,140 +390,3 @@ def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
raise NotImplementedError(
"Only simple object names are supported as data in perf event output."
)
def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
"""Extract buffer pointer and size from either a struct field or variable."""
# Case 1: Struct field (obj.field)
if isinstance(buf_arg, ast.Attribute):
if not isinstance(buf_arg.value, ast.Name):
raise ValueError(
"Only simple struct field access supported (e.g., obj.field)"
)
struct_name = buf_arg.value.id
field_name = buf_arg.attr
# Lookup struct
if not local_sym_tab or struct_name not in local_sym_tab:
raise ValueError(f"Struct '{struct_name}' not found")
struct_type = local_sym_tab[struct_name].metadata
if not struct_sym_tab or struct_type not in struct_sym_tab:
raise ValueError(f"Struct type '{struct_type}' not found")
struct_info = struct_sym_tab[struct_type]
# Get field pointer and type
struct_ptr = local_sym_tab[struct_name].var
field_ptr = struct_info.gep(builder, struct_ptr, field_name)
field_type = struct_info.field_type(field_name)
if not isinstance(field_type, ir.ArrayType):
raise ValueError(f"Field '{field_name}' must be an array type")
return field_ptr, field_type.count
# Case 2: Variable name
elif isinstance(buf_arg, ast.Name):
var_name = buf_arg.id
if not local_sym_tab or var_name not in local_sym_tab:
raise ValueError(f"Variable '{var_name}' not found")
var_ptr = local_sym_tab[var_name].var
var_type = local_sym_tab[var_name].ir_type
if not isinstance(var_type, ir.ArrayType):
raise ValueError(f"Variable '{var_name}' must be an array type")
return var_ptr, var_type.count
else:
raise ValueError(
"comm expects either a struct field (obj.field) or variable name"
)
def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
"""Get pointer to char array and its size."""
# Struct field: obj.field
if isinstance(buf_arg, ast.Attribute) and isinstance(buf_arg.value, ast.Name):
var_name = buf_arg.value.id
field_name = buf_arg.attr
if not (local_sym_tab and var_name in local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found")
struct_type = local_sym_tab[var_name].metadata
if not (struct_sym_tab and struct_type in struct_sym_tab):
raise ValueError(f"Struct type '{struct_type}' not found")
struct_info = struct_sym_tab[struct_type]
if field_name not in struct_info.fields:
raise ValueError(f"Field '{field_name}' not found")
field_type = struct_info.field_type(field_name)
if not _is_char_array(field_type):
raise ValueError("Expected char array field")
struct_ptr = local_sym_tab[var_name].var
field_ptr = struct_info.gep(builder, struct_ptr, field_name)
# GEP to first element: [N x i8]* -> i8*
buf_ptr = builder.gep(
field_ptr,
[ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)],
inbounds=True,
)
return buf_ptr, field_type.count
elif isinstance(buf_arg, ast.Name):
# NOTE: We shouldn't be doing this as we can't get size info
var_name = buf_arg.id
if not (local_sym_tab and var_name in local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found")
var_ptr = local_sym_tab[var_name].var
var_type = local_sym_tab[var_name].ir_type
if not isinstance(var_type, ir.PointerType) or not isinstance(
var_type.pointee, ir.IntType(8)
):
raise ValueError("Expected str ptr variable")
return var_ptr, 256 # Size unknown for str ptr, using 256 as default
else:
raise ValueError("Expected struct field or variable name")
def _is_char_array(ir_type):
"""Check if IR type is [N x i8]."""
return (
isinstance(ir_type, ir.ArrayType)
and isinstance(ir_type.element, ir.IntType)
and ir_type.element.width == 8
)
def get_ptr_from_arg(
arg, func, module, builder, local_sym_tab, map_sym_tab, struct_sym_tab
):
"""Evaluate argument and return pointer value"""
result = eval_expr(
func, module, builder, arg, local_sym_tab, map_sym_tab, struct_sym_tab
)
if not result:
raise ValueError("Failed to evaluate argument")
val, val_type = result
if not isinstance(val_type, ir.PointerType):
raise ValueError(f"Expected pointer type, got {val_type}")
return val, val_type

View File

@ -1,34 +1,39 @@
"""
BPF helper function stubs for Python type hints.
This module provides Python stub functions that represent BPF helper functions.
These stubs are used for type checking and will be replaced with actual BPF
helper calls during compilation.
"""
import ctypes
def ktime():
"""get current ktime"""
"""
Get the current kernel time in nanoseconds.
Returns:
A c_int64 stub value (actual implementation is in BPF runtime)
"""
return ctypes.c_int64(0)
def pid():
"""get current process id"""
"""
Get the current process ID (PID).
Returns:
A c_int32 stub value (actual implementation is in BPF runtime)
"""
return ctypes.c_int32(0)
def deref(ptr):
"""dereference a pointer"""
"dereference a pointer"
result = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_void_p)).contents.value
return result if result is not None else 0
def comm(buf):
"""get current process command name"""
return ctypes.c_int64(0)
def probe_read_str(dst, src):
"""Safely read a null-terminated string from kernel memory"""
return ctypes.c_int64(0)
XDP_ABORTED = ctypes.c_int64(0)
XDP_DROP = ctypes.c_int64(1)
XDP_PASS = ctypes.c_int64(2)
XDP_TX = ctypes.c_int64(3)
XDP_REDIRECT = ctypes.c_int64(4)

View File

@ -1,316 +0,0 @@
import ast
import logging
from llvmlite import ir
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger = logging.getLogger(__name__)
def simple_string_print(string_value, module, builder, func):
"""Prepare arguments for bpf_printk from a simple string value"""
fmt_str = string_value + "\n\0"
fmt_ptr = _create_format_string_global(fmt_str, func, module, builder)
args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))]
return args
def handle_fstring_print(
joined_str,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
"""Handle f-string formatting for bpf_printk emitter."""
fmt_parts = []
exprs = []
for value in joined_str.values:
logger.debug(f"Processing f-string value: {ast.dump(value)}")
if isinstance(value, ast.Constant):
_process_constant_in_fstring(value, fmt_parts, exprs)
elif isinstance(value, ast.FormattedValue):
_process_fval(
value,
fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
)
else:
raise NotImplementedError(f"Unsupported f-string value type: {type(value)}")
fmt_str = "".join(fmt_parts)
args = simple_string_print(fmt_str, module, builder, func)
# NOTE: Process expressions (limited to 3 due to BPF constraints)
if len(exprs) > 3:
logger.warning("bpf_printk supports up to 3 args, extra args will be ignored.")
for expr in exprs[:3]:
arg_value = _prepare_expr_args(
expr,
func,
module,
builder,
local_sym_tab,
struct_sym_tab,
)
args.append(arg_value)
return args
# ============================================================================
# Internal Helpers
# ============================================================================
def _process_constant_in_fstring(cst, fmt_parts, exprs):
"""Process constant values in f-string."""
if isinstance(cst.value, str):
fmt_parts.append(cst.value)
elif isinstance(cst.value, int):
fmt_parts.append("%lld")
exprs.append(ir.Constant(ir.IntType(64), cst.value))
else:
raise NotImplementedError(
f"Unsupported constant type in f-string: {type(cst.value)}"
)
def _process_fval(fval, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
"""Process formatted values in f-string."""
logger.debug(f"Processing formatted value: {ast.dump(fval)}")
if isinstance(fval.value, ast.Name):
_process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab)
elif isinstance(fval.value, ast.Attribute):
_process_attr_in_fval(
fval.value,
fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
)
else:
raise NotImplementedError(
f"Unsupported formatted value in f-string: {type(fval.value)}"
)
def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
"""Process name nodes in formatted values."""
if local_sym_tab and name_node.id in local_sym_tab:
_, var_type, tmp = local_sym_tab[name_node.id]
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
# Try to resolve through vmlinux registry if not in local symbol table
result = VmlinuxHandlerRegistry.handle_name(name_node.id)
if result:
val, var_type = result
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
raise ValueError(
f"Variable '{name_node.id}' not found in symbol table or vmlinux"
)
def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
"""Process attribute nodes in formatted values."""
if (
isinstance(attr_node.value, ast.Name)
and local_sym_tab
and attr_node.value.id in local_sym_tab
):
var_name = attr_node.value.id
field_name = attr_node.attr
var_type = local_sym_tab[var_name].metadata
if var_type not in struct_sym_tab:
raise ValueError(
f"Struct '{var_type}' for '{var_name}' not in symbol table"
)
struct_info = struct_sym_tab[var_type]
if field_name not in struct_info.fields:
raise ValueError(f"Field '{field_name}' not found in struct '{var_type}'")
field_type = struct_info.field_type(field_name)
_populate_fval(field_type, attr_node, fmt_parts, exprs)
else:
raise NotImplementedError(
"Only simple attribute on local vars is supported in f-strings."
)
def _populate_fval(ftype, node, fmt_parts, exprs):
"""Populate format parts and expressions based on field type."""
if isinstance(ftype, ir.IntType):
# TODO: We print as signed integers only for now
if ftype.width == 64:
fmt_parts.append("%lld")
exprs.append(node)
elif ftype.width == 32:
fmt_parts.append("%d")
exprs.append(node)
else:
raise NotImplementedError(
f"Unsupported integer width in f-string: {ftype.width}"
)
elif isinstance(ftype, ir.PointerType):
target, depth = get_base_type_and_depth(ftype)
if isinstance(target, ir.IntType):
if target.width == 64:
fmt_parts.append("%lld")
exprs.append(node)
elif target.width == 32:
fmt_parts.append("%d")
exprs.append(node)
elif target.width == 8 and depth == 1:
# NOTE: Assume i8* is a string
fmt_parts.append("%s")
exprs.append(node)
else:
raise NotImplementedError(
f"Unsupported pointer target type in f-string: {target}"
)
else:
raise NotImplementedError(
f"Unsupported pointer target type in f-string: {target}"
)
elif isinstance(ftype, ir.ArrayType):
if isinstance(ftype.element, ir.IntType) and ftype.element.width == 8:
# Char array
fmt_parts.append("%s")
exprs.append(node)
else:
raise NotImplementedError(
f"Unsupported array element type in f-string: {ftype.element}"
)
else:
raise NotImplementedError(f"Unsupported field type in f-string: {ftype}")
def _create_format_string_global(fmt_str, func, module, builder):
"""Create a global variable for the format string."""
fmt_name = f"{func.name}____fmt{func._fmt_counter}"
func._fmt_counter += 1
fmt_gvar = ir.GlobalVariable(
module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name
)
fmt_gvar.global_constant = True
fmt_gvar.initializer = ir.Constant(
ir.ArrayType(ir.IntType(8), len(fmt_str)), bytearray(fmt_str.encode("utf8"))
)
fmt_gvar.linkage = "internal"
fmt_gvar.align = 1
return builder.bitcast(fmt_gvar, ir.PointerType())
def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_tab):
"""Evaluate and prepare an expression to use as an arg for bpf_printk."""
# Special case: struct field char array needs pointer to first element
char_array_ptr = _get_struct_char_array_ptr(
expr, builder, local_sym_tab, struct_sym_tab
)
if char_array_ptr:
return char_array_ptr
# Regular expression evaluation
val, _ = eval_expr(func, module, builder, expr, local_sym_tab, None, struct_sym_tab)
if not val:
logger.warning("Failed to evaluate expression for bpf_printk, defaulting to 0")
return ir.Constant(ir.IntType(64), 0)
# Convert value to bpf_printk compatible type
if isinstance(val.type, ir.PointerType):
return _handle_pointer_arg(val, func, builder)
elif isinstance(val.type, ir.IntType):
return _handle_int_arg(val, builder)
else:
logger.warning(f"Unsupported type {val.type} in bpf_printk, defaulting to 0")
return ir.Constant(ir.IntType(64), 0)
def _get_struct_char_array_ptr(expr, builder, local_sym_tab, struct_sym_tab):
"""Get pointer to first element of char array in struct field, or None."""
if not (isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name)):
return None
var_name = expr.value.id
field_name = expr.attr
# Check if it's a valid struct field
if not (
local_sym_tab
and var_name in local_sym_tab
and struct_sym_tab
and local_sym_tab[var_name].metadata in struct_sym_tab
):
return None
struct_type = local_sym_tab[var_name].metadata
struct_info = struct_sym_tab[struct_type]
if field_name not in struct_info.fields:
return None
field_type = struct_info.field_type(field_name)
# Check if it's a char array
is_char_array = (
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
)
if not is_char_array:
return None
# Get field pointer and GEP to first element: [N x i8]* -> i8*
struct_ptr = local_sym_tab[var_name].var
field_ptr = struct_info.gep(builder, struct_ptr, field_name)
return builder.gep(
field_ptr,
[ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)],
inbounds=True,
)
def _handle_pointer_arg(val, func, builder):
"""Convert pointer type for bpf_printk."""
target, depth = get_base_type_and_depth(val.type)
if not isinstance(target, ir.IntType):
logger.warning("Only int pointers supported in bpf_printk, defaulting to 0")
return ir.Constant(ir.IntType(64), 0)
# i8* is string - use as-is
if target.width == 8 and depth == 1:
return val
# Integer pointers: dereference and sign-extend to i64
if target.width >= 32:
val = deref_to_depth(func, builder, val, depth)
return builder.sext(val, ir.IntType(64))
logger.warning("Unsupported pointer width in bpf_printk, defaulting to 0")
return ir.Constant(ir.IntType(64), 0)
def _handle_int_arg(val, builder):
"""Convert integer type for bpf_printk (sign-extend to i64)."""
if val.type.width < 64:
return builder.sext(val, ir.IntType(64))
return val

View File

@ -1,3 +1,10 @@
"""
LICENSE global variable processing for BPF programs.
This module handles the processing of the LICENSE function which is required
for BPF programs to declare their license (typically "GPL").
"""
from llvmlite import ir
import ast
from logging import Logger
@ -7,6 +14,16 @@ logger: Logger = logging.getLogger(__name__)
def emit_license(module: ir.Module, license_str: str):
"""
Emit a LICENSE global variable into the LLVM IR module.
Args:
module: The LLVM IR module to add the LICENSE variable to
license_str: The license string (e.g., 'GPL')
Returns:
The created global variable
"""
license_bytes = license_str.encode("utf8") + b"\x00"
elems = [ir.Constant(ir.IntType(8), b) for b in license_bytes]
ty = ir.ArrayType(ir.IntType(8), len(elems))

View File

@ -1,3 +1,5 @@
"""BPF map types and processing."""
from .maps import HashMap, PerfEventArray, RingBuf
from .maps_pass import maps_proc

View File

@ -1,93 +0,0 @@
from pythonbpf.debuginfo import DebugInfoGenerator
from .map_types import BPFMapType
def create_map_debug_info(module, map_global, map_name, map_params):
"""Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
generator = DebugInfoGenerator(module)
uint_type = generator.get_uint32_type()
ulong_type = generator.get_uint64_type()
array_type = generator.create_array_type(
uint_type, map_params.get("type", BPFMapType.UNSPEC).value
)
type_ptr = generator.create_pointer_type(array_type, 64)
key_ptr = generator.create_pointer_type(
array_type if "key_size" in map_params else ulong_type, 64
)
value_ptr = generator.create_pointer_type(
array_type if "value_size" in map_params else ulong_type, 64
)
elements_arr = []
# Create struct members
# scope field does not appear for some reason
cnt = 0
for elem in map_params:
if elem == "max_entries":
continue
if elem == "type":
ptr = type_ptr
elif "key" in elem:
ptr = key_ptr
else:
ptr = value_ptr
# TODO: the best way to do this is not 64, but get the size each time. this will not work for structs.
member = generator.create_struct_member(elem, ptr, cnt * 64)
elements_arr.append(member)
cnt += 1
if "max_entries" in map_params:
max_entries_array = generator.create_array_type(
uint_type, map_params["max_entries"]
)
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
max_entries_member = generator.create_struct_member(
"max_entries", max_entries_ptr, cnt * 64
)
elements_arr.append(max_entries_member)
# Create the struct type
struct_type = generator.create_struct_type(
elements_arr, 64 * len(elements_arr), is_distinct=True
)
# Create global variable debug info
global_var = generator.create_global_var_debug_info(
map_name, struct_type, is_local=False
)
# Attach debug info to the global variable
map_global.set_metadata("dbg", global_var)
return global_var
def create_ringbuf_debug_info(module, map_global, map_name, map_params):
"""Generate debug information metadata for BPF RINGBUF map"""
generator = DebugInfoGenerator(module)
int_type = generator.get_int32_type()
type_array = generator.create_array_type(
int_type, map_params.get("type", BPFMapType.RINGBUF).value
)
type_ptr = generator.create_pointer_type(type_array, 64)
type_member = generator.create_struct_member("type", type_ptr, 0)
max_entries_array = generator.create_array_type(int_type, map_params["max_entries"])
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
max_entries_member = generator.create_struct_member(
"max_entries", max_entries_ptr, 64
)
elements_arr = [type_member, max_entries_member]
struct_type = generator.create_struct_type(elements_arr, 128, is_distinct=True)
global_var = generator.create_global_var_debug_info(
map_name, struct_type, is_local=False
)
map_global.set_metadata("dbg", global_var)
return global_var

View File

@ -1,39 +0,0 @@
from enum import Enum
class BPFMapType(Enum):
UNSPEC = 0
HASH = 1
ARRAY = 2
PROG_ARRAY = 3
PERF_EVENT_ARRAY = 4
PERCPU_HASH = 5
PERCPU_ARRAY = 6
STACK_TRACE = 7
CGROUP_ARRAY = 8
LRU_HASH = 9
LRU_PERCPU_HASH = 10
LPM_TRIE = 11
ARRAY_OF_MAPS = 12
HASH_OF_MAPS = 13
DEVMAP = 14
SOCKMAP = 15
CPUMAP = 16
XSKMAP = 17
SOCKHASH = 18
CGROUP_STORAGE_DEPRECATED = 19
CGROUP_STORAGE = 19
REUSEPORT_SOCKARRAY = 20
PERCPU_CGROUP_STORAGE_DEPRECATED = 21
PERCPU_CGROUP_STORAGE = 21
QUEUE = 22
STACK = 23
SK_STORAGE = 24
DEVMAP_HASH = 25
STRUCT_OPS = 26
RINGBUF = 27
INODE_STORAGE = 28
TASK_STORAGE = 29
BLOOM_FILTER = 30
USER_RINGBUF = 31
CGRP_STORAGE = 32

View File

@ -1,18 +1,60 @@
"""
BPF map type definitions for Python type hints.
This module provides Python classes that represent BPF map types.
These are used for type checking and map definition; the actual BPF maps
are generated as LLVM IR during compilation.
"""
# This file provides type and function hints only and does not actually give any functionality.
class HashMap:
"""
A BPF hash map for storing key-value pairs.
This is a type hint class used during compilation. The actual BPF map
implementation is generated as LLVM IR.
"""
def __init__(self, key, value, max_entries):
"""
Initialize a HashMap definition.
Args:
key: The ctypes type for keys (e.g., c_int64)
value: The ctypes type for values (e.g., c_int64)
max_entries: Maximum number of entries the map can hold
"""
self.key = key
self.value = value
self.max_entries = max_entries
self.entries = {}
def lookup(self, key):
"""
Look up a value by key in the map.
Args:
key: The key to look up
Returns:
The value if found, None otherwise
"""
if key in self.entries:
return self.entries[key]
else:
return None
def delete(self, key):
"""
Delete an entry from the map by key.
Args:
key: The key to delete
Raises:
KeyError: If the key is not found in the map
"""
if key in self.entries:
del self.entries[key]
else:
@ -20,6 +62,17 @@ class HashMap:
# TODO: define the flags that can be added
def update(self, key, value, flags=None):
"""
Update or insert a key-value pair in the map.
Args:
key: The key to update
value: The new value
flags: Optional flags for update behavior
Raises:
KeyError: If the key is not found in the map
"""
if key in self.entries:
self.entries[key] = value
else:
@ -27,25 +80,76 @@ class HashMap:
class PerfEventArray:
"""
A BPF perf event array for sending data to userspace.
This is a type hint class used during compilation.
"""
def __init__(self, key_size, value_size):
"""
Initialize a PerfEventArray definition.
Args:
key_size: The size/type for keys
value_size: The size/type for values
"""
self.key_type = key_size
self.value_type = value_size
self.entries = {}
def output(self, data):
"""
Output data to the perf event array.
Args:
data: The data to output
"""
pass # Placeholder for output method
class RingBuf:
"""
A BPF ring buffer for efficient data transfer to userspace.
This is a type hint class used during compilation.
"""
def __init__(self, max_entries):
"""
Initialize a RingBuf definition.
Args:
max_entries: Maximum number of entries the ring buffer can hold
"""
self.max_entries = max_entries
def reserve(self, size: int, flags=0):
"""
Reserve space in the ring buffer.
Args:
size: Size in bytes to reserve
flags: Optional reservation flags
Returns:
0 as a placeholder (actual implementation is in BPF runtime)
Raises:
ValueError: If size exceeds max_entries
"""
if size > self.max_entries:
raise ValueError("size cannot be greater than set maximum entries")
return 0
def submit(self, data, flags=0):
"""
Submit data to the ring buffer.
Args:
data: The data to submit
flags: Optional submission flags
"""
pass
# add discard, output and also give names to flags and stuff

View File

@ -1,13 +1,17 @@
"""
BPF map processing and LLVM IR generation.
This module handles the processing of BPF map definitions decorated with @map,
converting them to appropriate LLVM IR global variables with BTF debug info.
"""
import ast
import logging
from logging import Logger
from llvmlite import ir
from enum import Enum
from .maps_utils import MapProcessorRegistry
from .map_types import BPFMapType
from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
from pythonbpf.debuginfo import DebugInfoGenerator
import logging
logger: Logger = logging.getLogger(__name__)
@ -23,14 +27,73 @@ def maps_proc(tree, module, chunks):
def is_map(func_node):
"""
Check if a function node is decorated with @map.
Args:
func_node: The AST function node to check
Returns:
True if the function is decorated with @map, False otherwise
"""
return any(
isinstance(decorator, ast.Name) and decorator.id == "map"
for decorator in func_node.decorator_list
)
class BPFMapType(Enum):
"""Enumeration of BPF map types."""
UNSPEC = 0
HASH = 1
ARRAY = 2
PROG_ARRAY = 3
PERF_EVENT_ARRAY = 4
PERCPU_HASH = 5
PERCPU_ARRAY = 6
STACK_TRACE = 7
CGROUP_ARRAY = 8
LRU_HASH = 9
LRU_PERCPU_HASH = 10
LPM_TRIE = 11
ARRAY_OF_MAPS = 12
HASH_OF_MAPS = 13
DEVMAP = 14
SOCKMAP = 15
CPUMAP = 16
XSKMAP = 17
SOCKHASH = 18
CGROUP_STORAGE_DEPRECATED = 19
CGROUP_STORAGE = 19
REUSEPORT_SOCKARRAY = 20
PERCPU_CGROUP_STORAGE_DEPRECATED = 21
PERCPU_CGROUP_STORAGE = 21
QUEUE = 22
STACK = 23
SK_STORAGE = 24
DEVMAP_HASH = 25
STRUCT_OPS = 26
RINGBUF = 27
INODE_STORAGE = 28
TASK_STORAGE = 29
BLOOM_FILTER = 30
USER_RINGBUF = 31
CGRP_STORAGE = 32
def create_bpf_map(module, map_name, map_params):
"""Create a BPF map in the module with given parameters and debug info"""
"""
Create a BPF map in the module with given parameters and debug info.
Args:
module: The LLVM IR module to add the map to
map_name: The name of the BPF map
map_params: Dictionary of map parameters (type, key_size, value_size, max_entries)
Returns:
The created global variable representing the map
"""
# Create the anonymous struct type for BPF map
map_struct_type = ir.LiteralStructType(
@ -49,42 +112,114 @@ def create_bpf_map(module, map_name, map_params):
return map_global
def _parse_map_params(rval, expected_args=None):
"""Parse map parameters from call arguments and keywords."""
def create_map_debug_info(module, map_global, map_name, map_params):
"""Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
generator = DebugInfoGenerator(module)
params = {}
handler = VmlinuxHandlerRegistry.get_handler()
# Parse positional arguments
if expected_args:
for i, arg_name in enumerate(expected_args):
if i < len(rval.args):
arg = rval.args[i]
if isinstance(arg, ast.Name):
params[arg_name] = arg.id
elif isinstance(arg, ast.Constant):
params[arg_name] = arg.value
uint_type = generator.get_uint32_type()
ulong_type = generator.get_uint64_type()
array_type = generator.create_array_type(
uint_type, map_params.get("type", BPFMapType.UNSPEC).value
)
type_ptr = generator.create_pointer_type(array_type, 64)
key_ptr = generator.create_pointer_type(
array_type if "key_size" in map_params else ulong_type, 64
)
value_ptr = generator.create_pointer_type(
array_type if "value_size" in map_params else ulong_type, 64
)
# Parse keyword arguments (override positional)
for keyword in rval.keywords:
if isinstance(keyword.value, ast.Name):
name = keyword.value.id
if handler and handler.is_vmlinux_enum(name):
result = handler.get_vmlinux_enum_value(name)
params[keyword.arg] = result if result is not None else name
else:
params[keyword.arg] = name
elif isinstance(keyword.value, ast.Constant):
params[keyword.arg] = keyword.value.value
elements_arr = []
return params
# Create struct members
# scope field does not appear for some reason
cnt = 0
for elem in map_params:
if elem == "max_entries":
continue
if elem == "type":
ptr = type_ptr
elif "key" in elem:
ptr = key_ptr
else:
ptr = value_ptr
# TODO: the best way to do this is not 64, but get the size each time. this will not work for structs.
member = generator.create_struct_member(elem, ptr, cnt * 64)
elements_arr.append(member)
cnt += 1
if "max_entries" in map_params:
max_entries_array = generator.create_array_type(
uint_type, map_params["max_entries"]
)
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
max_entries_member = generator.create_struct_member(
"max_entries", max_entries_ptr, cnt * 64
)
elements_arr.append(max_entries_member)
# Create the struct type
struct_type = generator.create_struct_type(
elements_arr, 64 * len(elements_arr), is_distinct=True
)
# Create global variable debug info
global_var = generator.create_global_var_debug_info(
map_name, struct_type, is_local=False
)
# Attach debug info to the global variable
map_global.set_metadata("dbg", global_var)
return global_var
def create_ringbuf_debug_info(module, map_global, map_name, map_params):
"""Generate debug information metadata for BPF RINGBUF map"""
generator = DebugInfoGenerator(module)
int_type = generator.get_int32_type()
type_array = generator.create_array_type(
int_type, map_params.get("type", BPFMapType.RINGBUF).value
)
type_ptr = generator.create_pointer_type(type_array, 64)
type_member = generator.create_struct_member("type", type_ptr, 0)
max_entries_array = generator.create_array_type(int_type, map_params["max_entries"])
max_entries_ptr = generator.create_pointer_type(max_entries_array, 64)
max_entries_member = generator.create_struct_member(
"max_entries", max_entries_ptr, 64
)
elements_arr = [type_member, max_entries_member]
struct_type = generator.create_struct_type(elements_arr, 128, is_distinct=True)
global_var = generator.create_global_var_debug_info(
map_name, struct_type, is_local=False
)
map_global.set_metadata("dbg", global_var)
return global_var
@MapProcessorRegistry.register("RingBuf")
def process_ringbuf_map(map_name, rval, module):
"""Process a BPF_RINGBUF map declaration"""
logger.info(f"Processing Ringbuf: {map_name}")
map_params = _parse_map_params(rval, expected_args=["max_entries"])
map_params["type"] = BPFMapType.RINGBUF
map_params = {"type": BPFMapType.RINGBUF}
# Parse max_entries if present
if len(rval.args) >= 1 and isinstance(rval.args[0], ast.Constant):
const_val = rval.args[0].value
if isinstance(const_val, int):
map_params["max_entries"] = const_val
for keyword in rval.keywords:
if keyword.arg == "max_entries" and isinstance(keyword.value, ast.Constant):
const_val = keyword.value.value
if isinstance(const_val, int):
map_params["max_entries"] = const_val
logger.info(f"Ringbuf map parameters: {map_params}")
@ -97,8 +232,27 @@ def process_ringbuf_map(map_name, rval, module):
def process_hash_map(map_name, rval, module):
"""Process a BPF_HASH map declaration"""
logger.info(f"Processing HashMap: {map_name}")
map_params = _parse_map_params(rval, expected_args=["key", "value", "max_entries"])
map_params["type"] = BPFMapType.HASH
map_params = {"type": BPFMapType.HASH}
# Assuming order: key_type, value_type, max_entries
if len(rval.args) >= 1 and isinstance(rval.args[0], ast.Name):
map_params["key"] = rval.args[0].id
if len(rval.args) >= 2 and isinstance(rval.args[1], ast.Name):
map_params["value"] = rval.args[1].id
if len(rval.args) >= 3 and isinstance(rval.args[2], ast.Constant):
const_val = rval.args[2].value
if isinstance(const_val, (int, str)): # safe check
map_params["max_entries"] = const_val
for keyword in rval.keywords:
if keyword.arg == "key" and isinstance(keyword.value, ast.Name):
map_params["key"] = keyword.value.id
elif keyword.arg == "value" and isinstance(keyword.value, ast.Name):
map_params["value"] = keyword.value.id
elif keyword.arg == "max_entries" and isinstance(keyword.value, ast.Constant):
const_val = keyword.value.value
if isinstance(const_val, (int, str)):
map_params["max_entries"] = const_val
logger.info(f"Map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params)
@ -111,8 +265,18 @@ def process_hash_map(map_name, rval, module):
def process_perf_event_map(map_name, rval, module):
"""Process a BPF_PERF_EVENT_ARRAY map declaration"""
logger.info(f"Processing PerfEventArray: {map_name}")
map_params = _parse_map_params(rval, expected_args=["key_size", "value_size"])
map_params["type"] = BPFMapType.PERF_EVENT_ARRAY
map_params = {"type": BPFMapType.PERF_EVENT_ARRAY}
if len(rval.args) >= 1 and isinstance(rval.args[0], ast.Name):
map_params["key_size"] = rval.args[0].id
if len(rval.args) >= 2 and isinstance(rval.args[1], ast.Name):
map_params["value_size"] = rval.args[1].id
for keyword in rval.keywords:
if keyword.arg == "key_size" and isinstance(keyword.value, ast.Name):
map_params["key_size"] = keyword.value.id
elif keyword.arg == "value_size" and isinstance(keyword.value, ast.Name):
map_params["value_size"] = keyword.value.id
logger.info(f"Map parameters: {map_params}")
map_global = create_bpf_map(module, map_name, map_params)

View File

@ -1,3 +1,5 @@
"""Registry for BPF map processor functions."""
from collections.abc import Callable
from typing import Any
@ -12,6 +14,7 @@ class MapProcessorRegistry:
"""Decorator to register a processor function for a map type"""
def decorator(func):
"""Decorator that registers the processor function."""
cls._processors[map_type_name] = func
return func

View File

@ -1,3 +1,5 @@
"""Struct processing for BPF programs."""
from .structs_pass import structs_proc
__all__ = ["structs_proc"]

View File

@ -1,19 +1,72 @@
"""
Struct type wrapper for BPF structs.
This module provides a wrapper class for LLVM IR struct types with
helper methods for field access and manipulation.
"""
from llvmlite import ir
class StructType:
"""
Wrapper class for LLVM IR struct types with field access helpers.
Attributes:
ir_type: The LLVM IR struct type
fields: Dictionary mapping field names to their types
size: Total size of the struct in bytes
"""
def __init__(self, ir_type, fields, size):
"""
Initialize a StructType.
Args:
ir_type: The LLVM IR struct type
fields: Dictionary mapping field names to their types
size: Total size of the struct in bytes
"""
self.ir_type = ir_type
self.fields = fields
self.size = size
def field_idx(self, field_name):
"""
Get the index of a field in the struct.
Args:
field_name: The name of the field
Returns:
The zero-based index of the field
"""
return list(self.fields.keys()).index(field_name)
def field_type(self, field_name):
"""
Get the LLVM IR type of a field.
Args:
field_name: The name of the field
Returns:
The LLVM IR type of the field
"""
return self.fields[field_name]
def gep(self, builder, ptr, field_name):
"""
Generate a GEP (GetElementPtr) instruction to access a struct field.
Args:
builder: LLVM IR builder
ptr: Pointer to the struct
field_name: Name of the field to access
Returns:
A pointer to the field
"""
idx = self.field_idx(field_name)
return builder.gep(
ptr,
@ -22,6 +75,18 @@ class StructType:
)
def field_size(self, field_name):
"""
Calculate the size of a field in bytes.
Args:
field_name: The name of the field
Returns:
The size of the field in bytes
Raises:
TypeError: If the field type is not supported
"""
fld = self.fields[field_name]
if isinstance(fld, ir.ArrayType):
return fld.count * (fld.element.width // 8)

View File

@ -1,3 +1,10 @@
"""
BPF struct processing and LLVM IR type generation.
This module handles the processing of Python classes decorated with @struct,
converting them to LLVM IR struct types for use in BPF programs.
"""
import ast
import logging
from llvmlite import ir
@ -26,6 +33,15 @@ def structs_proc(tree, module, chunks):
def is_bpf_struct(cls_node):
"""
Check if a class node is decorated with @struct.
Args:
cls_node: The AST class node to check
Returns:
True if the class is decorated with @struct, False otherwise
"""
return any(
isinstance(decorator, ast.Name) and decorator.id == "struct"
for decorator in cls_node.decorator_list
@ -33,7 +49,16 @@ def is_bpf_struct(cls_node):
def process_bpf_struct(cls_node, module):
"""Process a single BPF struct definition"""
"""
Process a single BPF struct definition and create its LLVM IR representation.
Args:
cls_node: The AST class node representing the struct
module: The LLVM IR module (not used in current implementation)
Returns:
A StructType object containing the struct's type information
"""
fields = parse_struct_fields(cls_node)
field_types = list(fields.values())
@ -44,7 +69,18 @@ def process_bpf_struct(cls_node, module):
def parse_struct_fields(cls_node):
"""Parse fields of a struct class node"""
"""
Parse fields of a struct class node.
Args:
cls_node: The AST class node representing the struct
Returns:
A dictionary mapping field names to their LLVM IR types
Raises:
TypeError: If a field has an unsupported type annotation
"""
fields = {}
for item in cls_node.body:
@ -57,7 +93,18 @@ def parse_struct_fields(cls_node):
def get_type_from_ann(annotation):
"""Convert an AST annotation node to an LLVM IR type for struct fields"""
"""
Convert an AST annotation node to an LLVM IR type for struct fields.
Args:
annotation: The AST annotation node (e.g., c_int64, str(32))
Returns:
The corresponding LLVM IR type
Raises:
TypeError: If the annotation type is not supported
"""
if isinstance(annotation, ast.Call) and isinstance(annotation.func, ast.Name):
if annotation.func.id == "str":
# Char array
@ -72,7 +119,15 @@ def get_type_from_ann(annotation):
def calc_struct_size(field_types):
"""Calculate total size of the struct with alignment and padding"""
"""
Calculate total size of the struct with alignment and padding.
Args:
field_types: List of LLVM IR types for each field
Returns:
The total size of the struct in bytes
"""
curr_offset = 0
for ftype in field_types:
if isinstance(ftype, ir.IntType):

View File

@ -1,3 +1,10 @@
"""
Type mapping from Python ctypes to LLVM IR types.
This module provides utilities to convert Python ctypes type names
to their corresponding LLVM IR representations.
"""
from llvmlite import ir
# TODO: THIS IS NOT SUPPOSED TO MATCH STRINGS :skull:
@ -19,10 +26,31 @@ mapping = {
def ctypes_to_ir(ctype: str):
"""
Convert a ctypes type name to its corresponding LLVM IR type.
Args:
ctype: String name of the ctypes type (e.g., 'c_int64', 'c_void_p')
Returns:
The corresponding LLVM IR type
Raises:
NotImplementedError: If the ctype is not supported
"""
if ctype in mapping:
return mapping[ctype]
raise NotImplementedError(f"No mapping for {ctype}")
def is_ctypes(ctype: str) -> bool:
"""
Check if a given type name is a supported ctypes type.
Args:
ctype: String name of the type to check
Returns:
True if the type is a supported ctypes type, False otherwise
"""
return ctype in mapping

View File

@ -1,58 +0,0 @@
import subprocess
def trace_pipe():
"""Util to read from the trace pipe."""
try:
subprocess.run(["cat", "/sys/kernel/tracing/trace_pipe"])
except KeyboardInterrupt:
print("Tracing stopped.")
except (FileNotFoundError, PermissionError) as e:
print(f"Error accessing trace_pipe: {e}. Try running as root.")
def trace_fields():
"""Parse one line from trace_pipe into fields."""
with open("/sys/kernel/tracing/trace_pipe", "rb", buffering=0) as f:
while True:
line = f.readline().rstrip()
if not line:
continue
# Skip lost event lines
if line.startswith(b"CPU:"):
continue
# Parse BCC-style: first 16 bytes = task
task = line[:16].lstrip().decode("utf-8")
line = line[17:] # Skip past task field and space
# Find the colon that ends "pid cpu flags timestamp"
ts_end = line.find(b":")
if ts_end == -1:
raise ValueError("Cannot parse trace line")
# Split "pid [cpu] flags timestamp"
try:
parts = line[:ts_end].split()
if len(parts) < 4:
raise ValueError("Not enough fields")
pid = int(parts[0])
cpu = parts[1][1:-1] # Remove brackets from [cpu]
cpu = int(cpu)
flags = parts[2]
ts = float(parts[3])
except (ValueError, IndexError):
raise ValueError("Cannot parse trace line")
# Get message: skip ": symbol:" part
line = line[ts_end + 1 :] # Skip first ":"
sym_end = line.find(b":")
if sym_end != -1:
msg = line[sym_end + 2 :].decode("utf-8") # Skip ": " after symbol
else:
msg = line.lstrip().decode("utf-8")
return (task, pid, cpu, flags, ts, msg)

View File

@ -1,3 +0,0 @@
from .import_detector import vmlinux_proc
__all__ = ["vmlinux_proc"]

View File

@ -1,36 +0,0 @@
from enum import Enum, auto
from typing import Any, Dict, List, Optional, TypedDict
from dataclasses import dataclass
import llvmlite.ir as ir
from pythonbpf.vmlinux_parser.dependency_node import Field
@dataclass
class AssignmentType(Enum):
CONSTANT = auto()
STRUCT = auto()
ARRAY = auto() # probably won't be used
FUNCTION_POINTER = auto()
POINTER = auto() # again, probably won't be used
@dataclass
class FunctionSignature(TypedDict):
return_type: str
param_types: List[str]
varargs: bool
# Thew name of the assignment will be in the dict that uses this class
@dataclass
class AssignmentInfo(TypedDict):
value_type: AssignmentType
python_type: type
value: Optional[Any]
pointer_level: Optional[int]
signature: Optional[FunctionSignature] # For function pointers
# The key of the dict is the name of the field.
# Value is a tuple that contains the global variable representing that field
# along with all the information about that field as a Field type.
members: Optional[Dict[str, tuple[ir.GlobalVariable, Field]]] # For structs.

View File

@ -1,255 +0,0 @@
import logging
from functools import lru_cache
import importlib
from .dependency_handler import DependencyHandler
from .dependency_node import DependencyNode
import ctypes
from typing import Optional, Any, Dict
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_module_symbols(module_name: str):
imported_module = importlib.import_module(module_name)
return [name for name in dir(imported_module)], imported_module
def process_vmlinux_class(
node,
llvm_module,
handler: DependencyHandler,
):
symbols_in_module, imported_module = get_module_symbols("vmlinux")
if node.name in symbols_in_module:
vmlinux_type = getattr(imported_module, node.name)
process_vmlinux_post_ast(vmlinux_type, llvm_module, handler)
else:
raise ImportError(f"{node.name} not in vmlinux")
def process_vmlinux_post_ast(
elem_type_class,
llvm_handler,
handler: DependencyHandler,
processing_stack=None,
):
# Initialize processing stack on first call
if processing_stack is None:
processing_stack = set()
symbols_in_module, imported_module = get_module_symbols("vmlinux")
current_symbol_name = elem_type_class.__name__
logger.info(f"Begin {current_symbol_name} Processing")
field_table: Dict[str, list] = {}
is_complex_type = False
containing_type: Optional[Any] = None
ctype_complex_type: Optional[Any] = None
type_length: Optional[int] = None
module_name = getattr(elem_type_class, "__module__", None)
# Check if already processed
if handler.has_node(current_symbol_name):
logger.debug(f"Node {current_symbol_name} already processed and ready")
return True
# XXX:Check its use. It's probably not being used.
if current_symbol_name in processing_stack:
logger.debug(
f"Dependency already in processing stack for {current_symbol_name}, skipping"
)
return True
processing_stack.add(current_symbol_name)
if module_name == "vmlinux":
if hasattr(elem_type_class, "_type_"):
pass
else:
new_dep_node = DependencyNode(name=current_symbol_name)
# elem_type_class is the actual vmlinux struct/class
new_dep_node.set_ctype_struct(elem_type_class)
handler.add_node(new_dep_node)
class_obj = getattr(imported_module, current_symbol_name)
# Inspect the class fields
if hasattr(class_obj, "_fields_"):
for field_elem in class_obj._fields_:
field_name: str = ""
field_type: Optional[Any] = None
bitfield_size: Optional[int] = None
if len(field_elem) == 2:
field_name, field_type = field_elem
elif len(field_elem) == 3:
field_name, field_type, bitfield_size = field_elem
field_table[field_name] = [field_type, bitfield_size]
elif hasattr(class_obj, "__annotations__"):
for field_elem in class_obj.__annotations__.items():
if len(field_elem) == 2:
field_name, field_type = field_elem
bitfield_size = None
elif len(field_elem) == 3:
field_name, field_type, bitfield_size = field_elem
else:
raise ValueError(
"Number of fields in items() of class object unexpected"
)
field_table[field_name] = [field_type, bitfield_size]
else:
raise TypeError("Could not get required class and definition")
logger.debug(f"Extracted fields for {current_symbol_name}: {field_table}")
for elem in field_table.items():
elem_name, elem_temp_list = elem
[elem_type, elem_bitfield_size] = elem_temp_list
local_module_name = getattr(elem_type, "__module__", None)
new_dep_node.add_field(elem_name, elem_type, ready=False)
if local_module_name == ctypes.__name__:
# TODO: need to process pointer to ctype and also CFUNCTYPES here recursively. Current processing is a single dereference
new_dep_node.set_field_bitfield_size(elem_name, elem_bitfield_size)
# Process pointer to ctype
if isinstance(elem_type, type) and issubclass(
elem_type, ctypes._Pointer
):
# Get the pointed-to type
pointed_type = elem_type._type_
logger.debug(f"Found pointer to type: {pointed_type}")
new_dep_node.set_field_containing_type(elem_name, pointed_type)
new_dep_node.set_field_ctype_complex_type(
elem_name, ctypes._Pointer
)
new_dep_node.set_field_ready(elem_name, is_ready=True)
# Process function pointers (CFUNCTYPE)
elif hasattr(elem_type, "_restype_") and hasattr(
elem_type, "_argtypes_"
):
# This is a CFUNCTYPE or similar
logger.info(
f"Function pointer detected for {elem_name} with return type {elem_type._restype_} and arguments {elem_type._argtypes_}"
)
# Set the field as ready but mark it with special handling
new_dep_node.set_field_ctype_complex_type(
elem_name, ctypes.CFUNCTYPE
)
new_dep_node.set_field_ready(elem_name, is_ready=True)
logger.warning(
"Blindly processing CFUNCTYPE ctypes to ensure compilation. Unsupported"
)
else:
# Regular ctype
new_dep_node.set_field_ready(elem_name, is_ready=True)
logger.debug(
f"Field {elem_name} is direct ctypes type: {elem_type}"
)
elif local_module_name == "vmlinux":
new_dep_node.set_field_bitfield_size(elem_name, elem_bitfield_size)
logger.debug(
f"Processing vmlinux field: {elem_name}, type: {elem_type}"
)
if hasattr(elem_type, "_type_"):
is_complex_type = True
containing_type = elem_type._type_
if hasattr(elem_type, "_length_") and is_complex_type:
type_length = elem_type._length_
if containing_type.__module__ == "vmlinux":
new_dep_node.add_dependent(
elem_type._type_.__name__
if hasattr(elem_type._type_, "__name__")
else str(elem_type._type_)
)
elif containing_type.__module__ == ctypes.__name__:
if isinstance(elem_type, type):
if issubclass(elem_type, ctypes.Array):
ctype_complex_type = ctypes.Array
elif issubclass(elem_type, ctypes._Pointer):
ctype_complex_type = ctypes._Pointer
else:
raise ImportError(
"Non Array and Pointer type ctype imports not supported in current version"
)
else:
raise TypeError("Unsupported ctypes subclass")
else:
raise ImportError(
f"Unsupported module of {containing_type}"
)
logger.debug(
f"{containing_type} containing type of parent {elem_name} with {elem_type} and ctype {ctype_complex_type} and length {type_length}"
)
new_dep_node.set_field_containing_type(
elem_name, containing_type
)
new_dep_node.set_field_type_size(elem_name, type_length)
new_dep_node.set_field_ctype_complex_type(
elem_name, ctype_complex_type
)
new_dep_node.set_field_type(elem_name, elem_type)
if containing_type.__module__ == "vmlinux":
containing_type_name = (
containing_type.__name__
if hasattr(containing_type, "__name__")
else str(containing_type)
)
# Check for self-reference or already processed
if containing_type_name == current_symbol_name:
# Self-referential pointer
logger.debug(
f"Self-referential pointer in {current_symbol_name}.{elem_name}"
)
new_dep_node.set_field_ready(elem_name, True)
elif handler.has_node(containing_type_name):
# Already processed
logger.debug(
f"Reusing already processed {containing_type_name}"
)
new_dep_node.set_field_ready(elem_name, True)
else:
# Process recursively - THIS WAS MISSING
new_dep_node.add_dependent(containing_type_name)
process_vmlinux_post_ast(
containing_type,
llvm_handler,
handler,
processing_stack,
)
new_dep_node.set_field_ready(elem_name, True)
elif containing_type.__module__ == ctypes.__name__:
logger.debug(f"Processing ctype internal{containing_type}")
new_dep_node.set_field_ready(elem_name, True)
else:
raise TypeError(
"Module not supported in recursive resolution"
)
else:
new_dep_node.add_dependent(
elem_type.__name__
if hasattr(elem_type, "__name__")
else str(elem_type)
)
process_vmlinux_post_ast(
elem_type,
llvm_handler,
handler,
processing_stack,
)
new_dep_node.set_field_ready(elem_name, True)
else:
raise ValueError(
f"{elem_name} with type {elem_type} from module {module_name} not supported in recursive resolver"
)
else:
raise ImportError("UNSUPPORTED Module")
logger.info(
f"{current_symbol_name} processed and handler readiness {handler.is_ready}"
)
return True

View File

@ -1,173 +0,0 @@
from typing import Optional, Dict, List, Iterator
from .dependency_node import DependencyNode
class DependencyHandler:
"""
Manages a collection of DependencyNode objects with no duplicates.
Ensures that no two nodes with the same name can be added and provides
methods to check readiness and retrieve specific nodes.
Example usage:
# Create a handler
handler = DependencyHandler()
# Create some dependency nodes
node1 = DependencyNode(name="node1")
node1.add_field("field1", str)
node1.set_field_value("field1", "value1")
node2 = DependencyNode(name="node2")
node2.add_field("field1", int)
# Add nodes to the handler
handler.add_node(node1)
handler.add_node(node2)
# Check if a specific node exists
print(handler.has_node("node1")) # True
# Get a reference to a node and modify it
node = handler.get_node("node2")
node.set_field_value("field1", 42)
# Check if all nodes are ready
print(handler.is_ready) # False (node2 is ready, but node1 isn't)
"""
def __init__(self):
# Using a dictionary with node names as keys ensures name uniqueness
# and provides efficient lookups
self._nodes: Dict[str, DependencyNode] = {}
def add_node(self, node: DependencyNode) -> bool:
"""
Add a dependency node to the handler.
Args:
node: The DependencyNode to add
Returns:
bool: True if the node was added, False if a node with the same name already exists
Raises:
TypeError: If the provided object is not a DependencyNode
"""
if not isinstance(node, DependencyNode):
raise TypeError(f"Expected DependencyNode, got {type(node).__name__}")
# Check if a node with this name already exists
if node.name in self._nodes:
return False
self._nodes[node.name] = node
return True
@property
def is_ready(self) -> bool:
"""
Check if all nodes are ready.
Returns:
bool: True if all nodes are ready (or if there are no nodes), False otherwise
"""
if not self._nodes:
return True
return all(node.is_ready for node in self._nodes.values())
def has_node(self, name: str) -> bool:
"""
Check if a node with the given name exists.
Args:
name: The name to check
Returns:
bool: True if a node with the given name exists, False otherwise
"""
return name in self._nodes
def get_node(self, name: str) -> Optional[DependencyNode]:
"""
Get a node by name for manipulation.
Args:
name: The name of the node to retrieve
Returns:
Optional[DependencyNode]: The node with the given name, or None if not found
"""
return self._nodes.get(name)
def remove_node(self, node_or_name) -> bool:
"""
Remove a node by name or reference.
Args:
node_or_name: The node to remove or its name
Returns:
bool: True if the node was removed, False if not found
"""
if isinstance(node_or_name, DependencyNode):
name = node_or_name.name
else:
name = node_or_name
if name in self._nodes:
del self._nodes[name]
return True
return False
def get_all_nodes(self) -> List[DependencyNode]:
"""
Get all nodes stored in the handler.
Returns:
List[DependencyNode]: List of all nodes
"""
return list(self._nodes.values())
def __iter__(self) -> Iterator[DependencyNode]:
"""
Iterate over all nodes.
Returns:
Iterator[DependencyNode]: Iterator over all nodes
"""
return iter(self._nodes.values())
def __len__(self) -> int:
"""
Get the number of nodes in the handler.
Returns:
int: The number of nodes
"""
return len(self._nodes)
def __getitem__(self, name: str) -> DependencyNode:
"""
Get a node by name using dictionary-style access.
Args:
name: The name of the node to retrieve
Returns:
DependencyNode: The node with the given name
Raises:
KeyError: If no node with the given name exists
Example:
node = handler["some-dep_node_name"]
"""
if name not in self._nodes:
raise KeyError(f"No node with name '{name}' found")
return self._nodes[name]
@property
def nodes(self):
return self._nodes

View File

@ -1,388 +0,0 @@
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
import ctypes
# TODO: FIX THE FUCKING TYPE NAME CONVENTION.
@dataclass
class Field:
"""Represents a field in a dependency node with its type and readiness state."""
name: str
type: type
ctype_complex_type: Optional[Any]
containing_type: Optional[Any]
type_size: Optional[int]
bitfield_size: Optional[int]
offset: int
value: Any = None
ready: bool = False
def __hash__(self):
"""
Create a hash based on the immutable attributes that define this field's identity.
This allows Field objects to be used as dictionary keys.
"""
# Use a tuple of the fields that uniquely identify this field
identity = (
self.name,
id(self.type), # Use id for non-hashable types
id(self.ctype_complex_type) if self.ctype_complex_type else None,
id(self.containing_type) if self.containing_type else None,
self.type_size,
self.bitfield_size,
self.offset,
self.value if self.value else None,
)
return hash(identity)
def __eq__(self, other):
"""
Define equality consistent with the hash function.
Two fields are equal if they have they are the same
"""
return self is other
def set_ready(self, is_ready: bool = True) -> None:
"""Set the readiness state of this field."""
self.ready = is_ready
def set_value(self, value: Any, mark_ready: bool = False) -> None:
"""Set the value of this field and optionally mark it as ready."""
self.value = value
if mark_ready:
self.ready = True
def set_type(self, given_type, mark_ready: bool = False) -> None:
"""Set value of the type field and mark as ready"""
self.type = given_type
if mark_ready:
self.ready = True
def set_containing_type(
self, containing_type: Optional[Any], mark_ready: bool = False
) -> None:
"""Set the containing_type of this field and optionally mark it as ready."""
self.containing_type = containing_type
if mark_ready:
self.ready = True
def set_type_size(self, type_size: Any, mark_ready: bool = False) -> None:
"""Set the type_size of this field and optionally mark it as ready."""
self.type_size = type_size
if mark_ready:
self.ready = True
def set_ctype_complex_type(
self, ctype_complex_type: Any, mark_ready: bool = False
) -> None:
"""Set the ctype_complex_type of this field and optionally mark it as ready."""
self.ctype_complex_type = ctype_complex_type
if mark_ready:
self.ready = True
def set_bitfield_size(self, bitfield_size: Any, mark_ready: bool = False) -> None:
"""Set the bitfield_size of this field and optionally mark it as ready."""
self.bitfield_size = bitfield_size
if mark_ready:
self.ready = True
def set_offset(self, offset: int) -> None:
"""Set the offset of this field"""
self.offset = offset
@dataclass
class DependencyNode:
"""
A node with typed fields and readiness tracking.
Example usage:
# Create a dependency node for a Person
somestruct = DependencyNode(name="struct_1")
# Add fields with their types
somestruct.add_field("field_1", str)
somestruct.add_field("field_2", int)
somestruct.add_field("field_3", str)
# Check if the node is ready (should be False initially)
print(f"Is node ready? {somestruct.is_ready}") # False
# Set some field values
somestruct.set_field_value("field_1", "someproperty")
somestruct.set_field_value("field_2", 30)
# Check if the node is ready (still False because email is not ready)
print(f"Is node ready? {somestruct.is_ready}") # False
# Set the last field and make the node ready
somestruct.set_field_value("field_3", "anotherproperty")
# Now the node should be ready
print(f"Is node ready? {somestruct.is_ready}") # True
# You can also mark a field as not ready
somestruct.set_field_ready("field_3", False)
# Now the node is not ready again
print(f"Is node ready? {somestruct.is_ready}") # False
# Get all field values
print(somestruct.get_field_values()) # {'field_1': 'someproperty', 'field_2': 30, 'field_3': 'anotherproperty'}
# Get only ready fields
ready_fields = somestruct.get_ready_fields()
print(f"Ready fields: {[field.name for field in ready_fields.values()]}") # ['field_1', 'field_2']
"""
name: str
depends_on: Optional[list[str]] = None
fields: Dict[str, Field] = field(default_factory=dict)
_ready_cache: Optional[bool] = field(default=None, repr=False)
current_offset: int = 0
ctype_struct: Optional[Any] = field(default=None, repr=False)
def add_field(
self,
name: str,
field_type: type,
initial_value: Any = None,
containing_type: Optional[Any] = None,
type_size: Optional[int] = None,
ctype_complex_type: Optional[int] = None,
bitfield_size: Optional[int] = None,
ready: bool = False,
offset: int = 0,
) -> None:
"""Add a field to the node with an optional initial value and readiness state."""
if self.depends_on is None:
self.depends_on = []
self.fields[name] = Field(
name=name,
type=field_type,
value=initial_value,
ready=ready,
containing_type=containing_type,
type_size=type_size,
ctype_complex_type=ctype_complex_type,
bitfield_size=bitfield_size,
offset=offset,
)
# Invalidate readiness cache
self._ready_cache = None
def set_ctype_struct(self, ctype_struct: Any) -> None:
"""Set the ctypes structure for automatic offset calculation."""
self.ctype_struct = ctype_struct
def __sizeof__(self):
# If we have a ctype_struct, use its size
if self.ctype_struct is not None:
return ctypes.sizeof(self.ctype_struct)
return self.current_offset
def get_field(self, name: str) -> Field:
"""Get a field by name."""
return self.fields[name]
def set_field_value(self, name: str, value: Any, mark_ready: bool = False) -> None:
"""Set a field's value and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_value(value, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_type(self, name: str, type: Any, mark_ready: bool = False) -> None:
"""Set a field's type and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_type(type, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_containing_type(
self, name: str, containing_type: Any, mark_ready: bool = False
) -> None:
"""Set a field's containing_type and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_containing_type(containing_type, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_type_size(
self, name: str, type_size: Any, mark_ready: bool = False
) -> None:
"""Set a field's type_size and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_type_size(type_size, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_ctype_complex_type(
self, name: str, ctype_complex_type: Any, mark_ready: bool = False
) -> None:
"""Set a field's ctype_complex_type and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_ctype_complex_type(ctype_complex_type, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_bitfield_size(
self, name: str, bitfield_size: Any, mark_ready: bool = False
) -> None:
"""Set a field's bitfield_size and optionally mark it as ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_bitfield_size(bitfield_size, mark_ready)
# Invalidate readiness cache
self._ready_cache = None
def set_field_ready(
self,
name: str,
is_ready: bool = False,
size_of_containing_type: Optional[int] = None,
) -> None:
"""Mark a field as ready or not ready."""
if name not in self.fields:
raise KeyError(f"Field '{name}' does not exist in node '{self.name}'")
self.fields[name].set_ready(is_ready)
# Use ctypes built-in offset if available
if self.ctype_struct is not None:
try:
self.fields[name].set_offset(getattr(self.ctype_struct, name).offset)
except AttributeError:
# Fallback to manual calculation if field not found in ctype_struct
self.fields[name].set_offset(self.current_offset)
self.current_offset += self._calculate_size(
name, size_of_containing_type
)
else:
# Manual offset calculation when no ctype_struct is available
self.fields[name].set_offset(self.current_offset)
self.current_offset += self._calculate_size(name, size_of_containing_type)
# Invalidate readiness cache
self._ready_cache = None
def _calculate_size(
self, name: str, size_of_containing_type: Optional[int] = None
) -> int:
processing_field = self.fields[name]
# size_of_field will be in bytes
if processing_field.type.__module__ == ctypes.__name__:
size_of_field = ctypes.sizeof(processing_field.type)
return size_of_field
elif processing_field.type.__module__ == "vmlinux":
if processing_field.ctype_complex_type is not None:
if issubclass(processing_field.ctype_complex_type, ctypes.Array):
if processing_field.containing_type.__module__ == ctypes.__name__:
if (
processing_field.containing_type is not None
and processing_field.type_size is not None
):
size_of_field = (
ctypes.sizeof(processing_field.containing_type)
* processing_field.type_size
)
else:
raise RuntimeError(
f"{processing_field} has no containing_type or type_size"
)
return size_of_field
elif processing_field.containing_type.__module__ == "vmlinux":
if (
size_of_containing_type is not None
and processing_field.type_size is not None
):
size_of_field = (
size_of_containing_type * processing_field.type_size
)
else:
raise RuntimeError(
f"{processing_field} has no containing_type or type_size"
)
return size_of_field
elif issubclass(processing_field.ctype_complex_type, ctypes._Pointer):
return ctypes.sizeof(ctypes.c_void_p)
else:
raise NotImplementedError(
"This subclass of ctype not supported yet"
)
elif processing_field.type_size is not None:
# Handle vmlinux types with type_size but no ctype_complex_type
# This means it's a direct vmlinux struct field (not array/pointer wrapped)
# The type_size should already contain the full size of the struct
# But if there's a containing_type from vmlinux, we need that size
if processing_field.containing_type is not None:
if processing_field.containing_type.__module__ == "vmlinux":
# For vmlinux containing types, we need the pre-calculated size
if size_of_containing_type is not None:
return size_of_containing_type * processing_field.type_size
else:
raise RuntimeError(
f"Field {name}: vmlinux containing_type requires size_of_containing_type"
)
else:
raise ModuleNotFoundError(
f"Containing type module {processing_field.containing_type.__module__} not supported"
)
else:
raise RuntimeError("Wrong type found with no containing type")
else:
# No ctype_complex_type and no type_size, must rely on size_of_containing_type
if size_of_containing_type is None:
raise RuntimeError(
f"Size of containing type {size_of_containing_type} is None"
)
return size_of_containing_type
else:
raise ModuleNotFoundError("Module is not supported for the operation")
raise RuntimeError("control should not reach here")
@property
def is_ready(self) -> bool:
"""Check if the node is ready (all fields are ready)."""
# Use cached value if available
if self._ready_cache is not None:
return self._ready_cache
# Calculate readiness only when needed
if not self.fields:
self._ready_cache = True
return True
self._ready_cache = all(elem.ready for elem in self.fields.values())
return self._ready_cache
def get_field_values(self) -> Dict[str, Any]:
"""Get a dictionary of field names to their values."""
return {name: elem.value for name, elem in self.fields.items()}
def get_ready_fields(self) -> Dict[str, Field]:
"""Get all fields that are marked as ready."""
return {name: elem for name, elem in self.fields.items() if elem.ready}
def get_not_ready_fields(self) -> Dict[str, Field]:
"""Get all fields that are marked as not ready."""
return {name: elem for name, elem in self.fields.items() if not elem.ready}
def add_dependent(self, dep_type):
if dep_type in self.depends_on:
return
else:
self.depends_on.append(dep_type)

View File

@ -1,162 +0,0 @@
import ast
import logging
import importlib
import inspect
from .assignment_info import AssignmentInfo, AssignmentType
from .dependency_handler import DependencyHandler
from .ir_gen import IRGenerator
from .class_handler import process_vmlinux_class
logger = logging.getLogger(__name__)
def detect_import_statement(tree: ast.AST) -> list[tuple[str, ast.ImportFrom]]:
"""
Parse AST and detect import statements from vmlinux.
Returns a list of tuples (module_name, imported_item) for vmlinux imports.
Raises SyntaxError for invalid import patterns.
Args:
tree: The AST to parse
Returns:
List of tuples containing (module_name, imported_item) for each vmlinux import
Raises:
SyntaxError: If multiple imports from vmlinux are attempted or import * is used
"""
vmlinux_imports = []
for node in ast.walk(tree):
# Handle "from vmlinux import ..." statements
if isinstance(node, ast.ImportFrom):
if node.module == "vmlinux":
# Check for wildcard import: from vmlinux import *
if any(alias.name == "*" for alias in node.names):
raise SyntaxError(
"Wildcard imports from vmlinux are not supported. "
"Please import specific types explicitly."
)
# Check for multiple imports: from vmlinux import A, B, C
if len(node.names) > 1:
imported_names = [alias.name for alias in node.names]
raise SyntaxError(
f"Multiple imports from vmlinux are not supported. "
f"Found: {', '.join(imported_names)}. "
f"Please use separate import statements for each type."
)
# Check if no specific import is specified (should not happen with valid Python)
if len(node.names) == 0:
raise SyntaxError(
"Import from vmlinux must specify at least one type."
)
# Valid single import
for alias in node.names:
import_name = alias.name
# Use alias if provided, otherwise use the original name (commented)
# as_name = alias.asname if alias.asname else alias.name
vmlinux_imports.append(("vmlinux", node))
logger.info(f"Found vmlinux import: {import_name}")
# Handle "import vmlinux" statements (not typical but should be rejected)
elif isinstance(node, ast.Import):
for alias in node.names:
if alias.name == "vmlinux" or alias.name.startswith("vmlinux."):
raise SyntaxError(
"Direct import of vmlinux module is not supported. "
"Use 'from vmlinux import <type>' instead."
)
logger.info(f"Total vmlinux imports detected: {len(vmlinux_imports)}")
return vmlinux_imports
def vmlinux_proc(tree: ast.AST, module):
import_statements = detect_import_statement(tree)
# initialise dependency handler
handler = DependencyHandler()
# initialise assignment dictionary of name to type
assignments: dict[str, AssignmentInfo] = {}
if not import_statements:
logger.info("No vmlinux imports found")
return
# Import vmlinux module directly
try:
vmlinux_mod = importlib.import_module("vmlinux")
except ImportError:
logger.warning("Could not import vmlinux module")
return
source_file = inspect.getsourcefile(vmlinux_mod)
if source_file is None:
logger.warning("Cannot find source for vmlinux module")
return
with open(source_file, "r") as f:
mod_ast = ast.parse(f.read(), filename=source_file)
for import_mod, import_node in import_statements:
for alias in import_node.names:
imported_name = alias.name
found = False
for mod_node in mod_ast.body:
if (
isinstance(mod_node, ast.ClassDef)
and mod_node.name == imported_name
):
process_vmlinux_class(mod_node, module, handler)
found = True
break
if isinstance(mod_node, ast.Assign):
for target in mod_node.targets:
if isinstance(target, ast.Name) and target.id == imported_name:
process_vmlinux_assign(mod_node, module, assignments)
found = True
break
if found:
break
if not found:
logger.info(
f"{imported_name} not found as ClassDef or Assign in vmlinux"
)
IRGenerator(module, handler, assignments)
return assignments
def process_vmlinux_assign(node, module, assignments: dict[str, AssignmentInfo]):
"""Process assignments from vmlinux module."""
# Only handle single-target assignments
if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
target_name = node.targets[0].id
# Handle constant value assignments
if isinstance(node.value, ast.Constant):
# Fixed: using proper TypedDict creation syntax with named arguments
assignments[target_name] = AssignmentInfo(
value_type=AssignmentType.CONSTANT,
python_type=type(node.value.value),
value=node.value.value,
pointer_level=None,
signature=None,
members=None,
)
logger.info(
f"Added assignment: {target_name} = {node.value.value!r} of type {type(node.value.value)}"
)
# Handle other assignment types that we may need to support
else:
logger.warning(
f"Unsupported assignment type for {target_name}: {ast.dump(node.value)}"
)
else:
raise ValueError("Not a simple assignment")

View File

@ -1,3 +0,0 @@
from .ir_generation import IRGenerator
__all__ = ["IRGenerator"]

View File

@ -1,161 +0,0 @@
from pythonbpf.debuginfo import DebugInfoGenerator, dwarf_constants as dc
from ..dependency_node import DependencyNode
import ctypes
import logging
from typing import List, Any, Tuple
logger = logging.getLogger(__name__)
def debug_info_generation(
struct: DependencyNode,
llvm_module,
generated_debug_info: List[Tuple[DependencyNode, Any]],
) -> Any:
"""
Generate DWARF debug information for a struct defined in a DependencyNode.
Args:
struct: The dependency node containing struct information
llvm_module: The LLVM module to add debug info to
generated_debug_info: List of tuples (struct, debug_info) to track generated debug info
Returns:
The generated global variable debug info
"""
# Set up debug info generator
generator = DebugInfoGenerator(llvm_module)
# Check if debug info for this struct has already been generated
for existing_struct, debug_info in generated_debug_info:
if existing_struct.name == struct.name:
return debug_info
# Process all fields and create members for the struct
members = []
for field_name, field in struct.fields.items():
# Get appropriate debug type for this field
field_type = _get_field_debug_type(
field_name, field, generator, struct, generated_debug_info
)
# Create struct member with proper offset
member = generator.create_struct_member_vmlinux(
field_name, field_type, field.offset * 8
)
members.append(member)
if struct.name.startswith("struct_"):
struct_name = struct.name.removeprefix("struct_")
else:
raise ValueError("Unions are not supported in the current version")
# Create struct type with all members
struct_type = generator.create_struct_type_with_name(
struct_name, members, struct.__sizeof__() * 8, is_distinct=True
)
return struct_type
def _get_field_debug_type(
field_name: str,
field,
generator: DebugInfoGenerator,
parent_struct: DependencyNode,
generated_debug_info: List[Tuple[DependencyNode, Any]],
) -> tuple[Any, int]:
"""
Determine the appropriate debug type for a field based on its Python/ctypes type.
Args:
field_name: Name of the field
field: Field object containing type information
generator: DebugInfoGenerator instance
parent_struct: The parent struct containing this field
generated_debug_info: List of already generated debug info
Returns:
The debug info type for this field
"""
# Handle complex types (arrays, pointers)
if field.ctype_complex_type is not None:
if issubclass(field.ctype_complex_type, ctypes.Array):
# Handle array types
element_type, base_type_size = _get_basic_debug_type(
field.containing_type, generator
)
return generator.create_array_type_vmlinux(
(element_type, base_type_size * field.type_size), field.type_size
), field.type_size * base_type_size
elif issubclass(field.ctype_complex_type, ctypes._Pointer):
# Handle pointer types
pointee_type, _ = _get_basic_debug_type(field.containing_type, generator)
return generator.create_pointer_type(pointee_type), 64
# Handle other vmlinux types (nested structs)
if field.type.__module__ == "vmlinux":
# If it's a struct from vmlinux, check if we've already generated debug info for it
struct_name = field.type.__name__
# Look for existing debug info in the list
for existing_struct, debug_info in generated_debug_info:
if existing_struct.name == struct_name:
# Use existing debug info
return debug_info, existing_struct.__sizeof__()
# If not found, create a forward declaration
# This will be completed when the actual struct is processed
logger.warning("Forward declaration in struct created")
forward_type = generator.create_struct_type([], 0, is_distinct=True)
return forward_type, 0
# Handle basic C types
return _get_basic_debug_type(field.type, generator)
def _get_basic_debug_type(ctype, generator: DebugInfoGenerator) -> Any:
"""
Map a ctypes type to a DWARF debug type.
Args:
ctype: A ctypes type or Python type
generator: DebugInfoGenerator instance
Returns:
The corresponding debug type
"""
# Map ctypes to debug info types
if ctype == ctypes.c_char or ctype == ctypes.c_byte:
return generator.get_basic_type("char", 8, dc.DW_ATE_signed_char), 8
elif ctype == ctypes.c_ubyte or ctype == ctypes.c_uint8:
return generator.get_basic_type("unsigned char", 8, dc.DW_ATE_unsigned_char), 8
elif ctype == ctypes.c_short or ctype == ctypes.c_int16:
return generator.get_basic_type("short", 16, dc.DW_ATE_signed), 16
elif ctype == ctypes.c_ushort or ctype == ctypes.c_uint16:
return generator.get_basic_type("unsigned short", 16, dc.DW_ATE_unsigned), 16
elif ctype == ctypes.c_int or ctype == ctypes.c_int32:
return generator.get_basic_type("int", 32, dc.DW_ATE_signed), 32
elif ctype == ctypes.c_uint or ctype == ctypes.c_uint32:
return generator.get_basic_type("unsigned int", 32, dc.DW_ATE_unsigned), 32
elif ctype == ctypes.c_long:
return generator.get_basic_type("long", 64, dc.DW_ATE_signed), 64
elif ctype == ctypes.c_ulong:
return generator.get_basic_type("unsigned long", 64, dc.DW_ATE_unsigned), 64
elif ctype == ctypes.c_longlong or ctype == ctypes.c_int64:
return generator.get_basic_type("long long", 64, dc.DW_ATE_signed), 64
elif ctype == ctypes.c_ulonglong or ctype == ctypes.c_uint64:
return generator.get_basic_type(
"unsigned long long", 64, dc.DW_ATE_unsigned
), 64
elif ctype == ctypes.c_float:
return generator.get_basic_type("float", 32, dc.DW_ATE_float), 32
elif ctype == ctypes.c_double:
return generator.get_basic_type("double", 64, dc.DW_ATE_float), 64
elif ctype == ctypes.c_bool:
return generator.get_basic_type("bool", 8, dc.DW_ATE_boolean), 8
elif ctype == ctypes.c_char_p:
char_type = generator.get_basic_type("char", 8, dc.DW_ATE_signed_char), 8
return generator.create_pointer_type(char_type)
elif ctype == ctypes.c_void_p:
return generator.create_pointer_type(None), 64
else:
return generator.get_uint64_type(), 64

View File

@ -1,225 +0,0 @@
import ctypes
import logging
from ..assignment_info import AssignmentInfo, AssignmentType
from ..dependency_handler import DependencyHandler
from .debug_info_gen import debug_info_generation
from ..dependency_node import DependencyNode
import llvmlite.ir as ir
logger = logging.getLogger(__name__)
class IRGenerator:
# get the assignments dict and add this stuff to it.
def __init__(self, llvm_module, handler: DependencyHandler, assignments):
self.llvm_module = llvm_module
self.handler: DependencyHandler = handler
self.generated: list[str] = []
self.generated_debug_info: list = []
# Use struct_name and field_name as key instead of Field object
self.generated_field_names: dict[str, dict[str, ir.GlobalVariable]] = {}
self.assignments: dict[str, AssignmentInfo] = assignments
if not handler.is_ready:
raise ImportError(
"Semantic analysis of vmlinux imports failed. Cannot generate IR"
)
for struct in handler:
self.struct_processor(struct)
def struct_processor(self, struct, processing_stack=None):
# Initialize processing stack on first call
if processing_stack is None:
processing_stack = set()
# If already generated, skip
if struct.name in self.generated:
return
# Detect circular dependency
if struct.name in processing_stack:
logger.info(
f"Circular dependency detected for {struct.name}, skipping recursive processing"
)
# For circular dependencies, we can either:
# 1. Use forward declarations (opaque pointers)
# 2. Mark as incomplete and process later
# 3. Generate a placeholder type
# Here we'll just skip and let it be processed in its own call
return
logger.info(f"IR generating for {struct.name}")
# Add to processing stack before processing dependencies
processing_stack.add(struct.name)
try:
# Process all dependencies first
if struct.depends_on is None:
pass
else:
for dependency in struct.depends_on:
if dependency not in self.generated:
# Check if dependency exists in handler
if dependency in self.handler.nodes:
dep_node_from_dependency = self.handler[dependency]
# Pass the processing_stack down to track circular refs
self.struct_processor(
dep_node_from_dependency, processing_stack
)
else:
raise RuntimeError(
f"Warning: Dependency {dependency} not found in handler"
)
# Generate IR first to populate field names
self.generated_debug_info.append(
(struct, self.gen_ir(struct, self.generated_debug_info))
)
# Fill the assignments dictionary with struct information
if struct.name not in self.assignments:
# Create a members dictionary for AssignmentInfo
members_dict = {}
for field_name, field in struct.fields.items():
# Get the generated field name from our dictionary, or use field_name if not found
if (
struct.name in self.generated_field_names
and field_name in self.generated_field_names[struct.name]
):
field_global_variable = self.generated_field_names[struct.name][
field_name
]
members_dict[field_name] = (field_global_variable, field)
else:
raise ValueError(
f"llvm global name not found for struct field {field_name}"
)
# members_dict[field_name] = (field_name, field)
# Add struct to assignments dictionary
self.assignments[struct.name] = AssignmentInfo(
value_type=AssignmentType.STRUCT,
python_type=struct.ctype_struct,
value=None,
pointer_level=None,
signature=None,
members=members_dict,
)
logger.info(f"Added struct assignment info for {struct.name}")
self.generated.append(struct.name)
finally:
# Remove from processing stack after we're done
processing_stack.discard(struct.name)
def gen_ir(self, struct, generated_debug_info):
# TODO: we add the btf_ama attribute by monkey patching in the end of compilation, but once llvmlite
# accepts our issue, we will resort to normal accessed attribute based attribute addition
# currently we generate all possible field accesses for CO-RE and put into the assignment table
debug_info = debug_info_generation(
struct, self.llvm_module, generated_debug_info
)
field_index = 0
# Make sure the struct has an entry in our field names dictionary
if struct.name not in self.generated_field_names:
self.generated_field_names[struct.name] = {}
for field_name, field in struct.fields.items():
# does not take arrays and similar types into consideration yet.
if field.ctype_complex_type is not None and issubclass(
field.ctype_complex_type, ctypes.Array
):
array_size = field.type_size
containing_type = field.containing_type
if containing_type.__module__ == ctypes.__name__:
containing_type_size = ctypes.sizeof(containing_type)
if array_size == 0:
field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, 0, containing_type_size
)
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1
continue
for i in range(0, array_size):
field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, i, containing_type_size
)
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1
elif field.type_size is not None:
array_size = field.type_size
containing_type = field.containing_type
if containing_type.__module__ == "vmlinux":
containing_type_size = self.handler[
containing_type.__name__
].current_offset
for i in range(0, array_size):
field_co_re_name = self._struct_name_generator(
struct, field, field_index, True, i, containing_type_size
)
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
field_index += 1
else:
field_co_re_name = self._struct_name_generator(
struct, field, field_index
)
field_index += 1
globvar = ir.GlobalVariable(
self.llvm_module, ir.IntType(64), name=field_co_re_name
)
globvar.linkage = "external"
globvar.set_metadata("llvm.preserve.access.index", debug_info)
self.generated_field_names[struct.name][field_name] = globvar
return debug_info
def _struct_name_generator(
self,
struct: DependencyNode,
field,
field_index: int,
is_indexed: bool = False,
index: int = 0,
containing_type_size: int = 0,
) -> str:
# TODO: Does not support Unions as well as recursive pointer and array type naming
if is_indexed:
name = (
"llvm."
+ struct.name.removeprefix("struct_")
+ f":0:{field.offset + index * containing_type_size}"
+ "$"
+ f"0:{field_index}:{index}"
)
return name
elif struct.name.startswith("struct_"):
name = (
"llvm."
+ struct.name.removeprefix("struct_")
+ f":0:{field.offset}"
+ "$"
+ f"0:{field_index}"
)
return name
else:
print(self.handler[struct.name])
raise TypeError(
"Name generation cannot occur due to type name not starting with struct"
)

View File

@ -1,90 +0,0 @@
import logging
from llvmlite import ir
from pythonbpf.vmlinux_parser.assignment_info import AssignmentType
logger = logging.getLogger(__name__)
class VmlinuxHandler:
"""Handler for vmlinux-related operations"""
_instance = None
@classmethod
def get_instance(cls):
"""Get the singleton instance"""
if cls._instance is None:
logger.warning("VmlinuxHandler used before initialization")
return None
return cls._instance
@classmethod
def initialize(cls, vmlinux_symtab):
"""Initialize the handler with vmlinux symbol table"""
cls._instance = cls(vmlinux_symtab)
return cls._instance
def __init__(self, vmlinux_symtab):
"""Initialize with vmlinux symbol table"""
self.vmlinux_symtab = vmlinux_symtab
logger.info(
f"VmlinuxHandler initialized with {len(vmlinux_symtab) if vmlinux_symtab else 0} symbols"
)
def is_vmlinux_enum(self, name):
"""Check if name is a vmlinux enum constant"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT
)
def is_vmlinux_struct(self, name):
"""Check if name is a vmlinux struct"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT
)
def handle_vmlinux_enum(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"Resolving vmlinux enum {name} = {value}")
return ir.Constant(ir.IntType(64), value), ir.IntType(64)
return None
def get_vmlinux_enum_value(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"The value of vmlinux enum {name} = {value}")
return value
return None
def handle_vmlinux_struct(self, struct_name, module, builder):
"""Handle vmlinux struct initializations"""
if self.is_vmlinux_struct(struct_name):
# TODO: Implement core-specific struct handling
# This will be more complex and depends on the BTF information
logger.info(f"Handling vmlinux struct {struct_name}")
# Return struct type and allocated pointer
# This is a stub, actual implementation will be more complex
return None
return None
def handle_vmlinux_struct_field(
self, struct_var_name, field_name, module, builder, local_sym_tab
):
"""Handle access to vmlinux struct fields"""
# Check if it's a variable of vmlinux struct type
if struct_var_name in local_sym_tab:
var_info = local_sym_tab[struct_var_name] # noqa: F841
# Need to check if this variable is a vmlinux struct
# This will depend on how you track vmlinux struct types in your symbol table
logger.info(
f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}"
)
# Return pointer to field and field type
return None
return None

View File

@ -1,10 +1,11 @@
#include "vmlinux.h"
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#define u64 unsigned long long
#define u32 unsigned int
SEC("xdp")
int hello(struct xdp_md *ctx) {
bpf_printk("Hello, World! %ud \n", ctx->data);
bpf_printk("Hello, World!\n");
return XDP_PASS;
}

View File

@ -1,9 +1,23 @@
// SPDX-License-Identifier: GPL-2.0
#include "vmlinux.h"
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
struct trace_entry {
short unsigned int type;
unsigned char flags;
unsigned char preempt_count;
int pid;
};
struct trace_event_raw_sys_enter {
struct trace_entry ent;
long int id;
long unsigned int args[6];
char __data[0];
};
struct event {
__u32 pid;
__u32 uid;
@ -19,7 +33,7 @@ struct {
SEC("tp/syscalls/sys_enter_setuid")
int handle_setuid_entry(struct trace_event_raw_sys_enter *ctx) {
struct event data = {};
struct blk_integrity_iter it = {};
// Extract UID from the syscall arguments
data.uid = (unsigned int)ctx->args[0];
data.ts = bpf_ktime_get_ns();

View File

@ -1,40 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, compile
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.maps import HashMap
# NOTE: This example tries to reinterpret the variable `x` to a different type.
# We do not allow this for now, as stack allocations are typed and have to be
# done in the first basic block. Allowing re-interpretation would require
# re-allocation of stack space (possibly in a new basic block), which is not
# supported in eBPF yet.
# We can allow bitcasts in cases where the width of the types is the same in
# the future. But for now, we do not allow any re-interpretation of variables.
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=3)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
last.update(0, 1)
x = last.lookup(0)
x = 20
if x == 2:
print("Hello, World!")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -3,19 +3,16 @@ import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64, c_int32
@bpf
@bpfglobal
def somevalue() -> c_int32:
return c_int32(42)
@bpf
@bpfglobal
def somevalue2() -> c_int64:
return c_int64(69)
@bpf
@bpfglobal
def somevalue1() -> c_int32:
@ -24,14 +21,12 @@ def somevalue1() -> c_int32:
# --- Passing examples ---
# Simple constant return
@bpf
@bpfglobal
def g1() -> c_int64:
return c_int64(42)
# Constructor with one constant argument
@bpf
@bpfglobal
@ -67,17 +62,15 @@ def g2() -> c_int64:
# def g6() -> c_int64:
# return c_int64(CONST)
# Constructor with multiple args
# TODO: this is not working. should it work ?
#TODO: this is not working. should it work ?
@bpf
@bpfglobal
def g7() -> c_int64:
return c_int64(1)
# Dataclass call
# TODO: fails with dataclass
#TODO: fails with dataclass
# @dataclass
# class Point:
# x: c_int64
@ -98,7 +91,6 @@ def sometag(ctx: c_void_p) -> c_int64:
print(f"{somevalue}")
return c_int64(1)
@bpf
@bpfglobal
def LICENSE() -> str:

View File

@ -11,7 +11,6 @@ from ctypes import c_void_p, c_int64
# We cannot allocate space for the intermediate type now.
# We probably need to track the ref/deref chain for each variable.
@bpf
@map
def count() -> HashMap:

View File

@ -3,7 +3,6 @@ import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64
# This should not pass as somevalue is not declared at all.
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
@ -12,7 +11,6 @@ def sometag(ctx: c_void_p) -> c_int64:
print(f"{somevalue}") # noqa: F821
return c_int64(1)
@bpf
@bpfglobal
def LICENSE() -> str:

View File

@ -1,54 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, compile_to_ir
from pythonbpf.maps import HashMap
from pythonbpf.helper import XDP_PASS
from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_qspinlock # noqa: F401
# from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
# from vmlinux import struct_posix_cputimers # noqa: F401
from vmlinux import struct_xdp_md
# from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
# from vmlinux import struct_ring_buffer_per_cpu # noqa: F401
# from vmlinux import struct_request # noqa: F401
from ctypes import c_int64
# Instructions to how to run this program
# 1. Install PythonBPF: pip install pythonbpf
# 2. Run the program: python examples/xdp_pass.py
# 3. Run the program with sudo: sudo tools/check.sh run examples/xdp_pass.o
# 4. Attach object file to any network device with something like ./check.sh xdp examples/xdp_pass.o tailscale0
# 5. send traffic through the device and observe effects
@bpf
@map
def count() -> HashMap:
return HashMap(key=c_int64, value=c_int64, max_entries=1)
@bpf
@section("xdp")
def hello_world(ctx: struct_xdp_md) -> c_int64:
key = 0
one = 1
prev = count().lookup(key)
if prev:
prevval = prev + 1
print(f"count: {prevval}")
count().update(key, prevval)
return XDP_PASS
else:
count().update(key, one)
return XDP_PASS
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("xdp_pass.py", "xdp_pass.ll")

View File

@ -1,74 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, compile, struct
from ctypes import c_void_p, c_int64, c_int32, c_uint64
from pythonbpf.maps import HashMap
from pythonbpf.helper import ktime
# NOTE: This is a comprehensive test combining struct, helper, and map features
# Please note that at line 50, though we have used an absurd expression to test
# the compiler, it is recommended to use named variables to reduce the amount of
# scratch space that needs to be allocated.
@bpf
@struct
class data_t:
pid: c_uint64
ts: c_uint64
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=3)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
dat = data_t()
dat.pid = 123
dat.pid = dat.pid + 1
print(f"pid is {dat.pid}")
tu = 9
last.update(0, tu)
last.update(1, -last.lookup(0))
x = last.lookup(0)
print(f"Map value at index 0: {x}")
x = x + c_int32(1)
print(f"x after adding 32-bit 1 is {x}")
x = ktime() - 121
print(f"ktime - 121 is {x}")
x = last.lookup(0)
x = x + 1
print(f"x is {x}")
if x == 10:
jat = data_t()
jat.ts = 456
print(f"Hello, World!, ts is {jat.ts}")
a = last.lookup(0)
print(f"a is {a}")
last.update(9, 9)
last.update(
0,
last.lookup(last.lookup(0))
+ last.lookup(last.lookup(0))
+ last.lookup(last.lookup(0)),
)
z = last.lookup(0)
print(f"new map val at index 0 is {z}")
else:
a = last.lookup(0)
print("Goodbye, World!")
c = last.lookup(1 - 1)
print(f"c is {c}")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -1,27 +0,0 @@
from pythonbpf import bpf, section, bpfglobal, compile
from ctypes import c_void_p, c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
x = 1
print(f"Initial x: {x}")
a = 20
x = a
print(f"Updated x with a: {x}")
x = (x + x) * 3
if x == 2:
print("Hello, World!")
else:
print(f"Goodbye, World! {x}")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -1,34 +0,0 @@
from pythonbpf import bpf, map, section, bpfglobal, compile
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.maps import HashMap
# NOTE: An example of i64** assignment with binops on the RHS
@bpf
@map
def last() -> HashMap:
return HashMap(key=c_uint64, value=c_uint64, max_entries=3)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
last.update(0, 1)
x = last.lookup(0)
print(f"{x}")
x = x + 1
if x == 2:
print("Hello, World!")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -1,28 +0,0 @@
from pythonbpf import bpf, struct, section, bpfglobal
from pythonbpf.helper import comm
from ctypes import c_void_p, c_int64
@bpf
@struct
class data_t:
comm: str(16) # type: ignore [valid-type]
copp: str(16) # type: ignore [valid-type]
@bpf
@section("tracepoint/syscalls/sys_enter_clone")
def hello(ctx: c_void_p) -> c_int64:
dataobj = data_t()
comm(dataobj.comm)
strobj = dataobj.comm
dataobj.copp = strobj
print(f"clone called by comm {dataobj.copp}")
return 0
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"

View File

@ -1,40 +0,0 @@
from pythonbpf import bpf, section, bpfglobal, compile, struct
from ctypes import c_void_p, c_int64, c_uint64
from pythonbpf.helper import ktime
@bpf
@struct
class data_t:
pid: c_uint64
ts: c_uint64
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int64:
dat = data_t()
dat.pid = 123
dat.pid = dat.pid + 1
print(f"pid is {dat.pid}")
x = ktime() - 121
print(f"ktime is {x}")
x = 1
x = x + 1
print(f"x is {x}")
if x == 2:
jat = data_t()
jat.ts = 456
print(f"Hello, World!, ts is {jat.ts}")
else:
print("Goodbye, World!")
return
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile()

View File

@ -1,26 +0,0 @@
from pythonbpf import bpf, struct, section, bpfglobal
from pythonbpf.helper import comm
from ctypes import c_void_p, c_int64
@bpf
@struct
class data_t:
comm: str(16) # type: ignore [valid-type]
@bpf
@section("tracepoint/syscalls/sys_enter_clone")
def hello(ctx: c_void_p) -> c_int64:
dataobj = data_t()
comm(dataobj.comm)
strobj = dataobj.comm
print(f"clone called by comm {strobj}")
return 0
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"

View File

@ -6,8 +6,8 @@ from ctypes import c_void_p, c_int32
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: c_void_p) -> c_int32:
print("Hello, World!")
a = 1 # int64
return c_int32(a) # typecast to int32
a = 1 # int64
return c_int32(a) # typecast to int32
@bpf

View File

@ -33,3 +33,5 @@ compile_to_ir("ringbuf.py", "ringbuf.ll")
compile()
b = BPF()
b.load_and_attach()
while True:
print("running")

View File

@ -1,47 +0,0 @@
import logging
from pythonbpf import bpf, section, bpfglobal, compile_to_ir, map
from pythonbpf import compile # noqa: F401
from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
from ctypes import c_uint64, c_int32, c_int64
from pythonbpf.maps import HashMap
# from vmlinux import struct_uinput_device
# from vmlinux import struct_blk_integrity_iter
@bpf
@map
def mymap() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=TASK_COMM_LEN)
@bpf
@map
def mymap2() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=18)
# Instructions to how to run this program
# 1. Install PythonBPF: pip install pythonbpf
# 2. Run the program: python examples/simple_struct_test.py
# 3. Run the program with sudo: sudo tools/check.sh run examples/simple_struct_test.o
# 4. Attach object file to any network device with something like ./check.sh run examples/simple_struct_test.o tailscale0
# 5. send traffic through the device and observe effects
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
print(f"Hello, World{TASK_COMM_LEN} and {a}")
return c_int64(TASK_COMM_LEN + 2)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG)
# compile()

View File

@ -1,199 +0,0 @@
#!/bin/bash
print_warning() {
echo -e "\033[1;33m$1\033[0m"
}
print_info() {
echo -e "\033[1;32m$1\033[0m"
}
if [ "$EUID" -ne 0 ]; then
echo "Please run this script with sudo."
exit 1
fi
print_warning "===================================================================="
print_warning " WARNING "
print_warning " This script will run kernel-level BPF programs. "
print_warning " BPF programs run with kernel privileges and could potentially "
print_warning " affect system stability if not used properly. "
print_warning " "
print_warning " This is a non-interactive version for curl piping. "
print_warning " The script will proceed automatically with installation. "
print_warning "===================================================================="
echo
print_info "This script will:"
echo "1. Check and install required dependencies (libelf, clang, python, bpftool)"
echo "2. Download example programs from the Python-BPF GitHub repository"
echo "3. Create a Python virtual environment with necessary packages"
echo "4. Set up a Jupyter notebook server"
echo "Starting in 5 seconds. Press Ctrl+C to cancel..."
sleep 5
WORK_DIR="/tmp/python_bpf_setup"
REAL_USER=$(logname || echo "$SUDO_USER")
echo "Creating temporary directory: $WORK_DIR"
mkdir -p "$WORK_DIR"
cd "$WORK_DIR" || exit 1
if [ -f /etc/os-release ]; then
. /etc/os-release
DISTRO=$ID
else
echo "Cannot determine Linux distribution. Exiting."
exit 1
fi
install_dependencies() {
case $DISTRO in
ubuntu|debian|pop|mint|elementary|zorin)
echo "Detected Ubuntu/Debian-based system"
apt update
# Check and install libelf
if ! dpkg -l libelf-dev >/dev/null 2>&1; then
echo "Installing libelf-dev..."
apt install -y libelf-dev
else
echo "libelf-dev is already installed."
fi
# Check and install clang
if ! command -v clang >/dev/null 2>&1; then
echo "Installing clang..."
apt install -y clang
else
echo "clang is already installed."
fi
# Check and install python
if ! command -v python3 >/dev/null 2>&1; then
echo "Installing python3..."
apt install -y python3 python3-pip python3-venv
else
echo "python3 is already installed."
fi
# Check and install bpftool
if ! command -v bpftool >/dev/null 2>&1; then
echo "Installing bpftool..."
apt install -y linux-tools-common linux-tools-generic
# If bpftool still not found, try installing linux-tools-$(uname -r)
if ! command -v bpftool >/dev/null 2>&1; then
KERNEL_VERSION=$(uname -r)
apt install -y linux-tools-$KERNEL_VERSION
fi
else
echo "bpftool is already installed."
fi
;;
arch|manjaro|endeavouros)
echo "Detected Arch-based Linux system"
# Check and install libelf
if ! pacman -Q libelf >/dev/null 2>&1; then
echo "Installing libelf..."
pacman -S --noconfirm libelf
else
echo "libelf is already installed."
fi
# Check and install clang
if ! command -v clang >/dev/null 2>&1; then
echo "Installing clang..."
pacman -S --noconfirm clang
else
echo "clang is already installed."
fi
# Check and install python
if ! command -v python3 >/dev/null 2>&1; then
echo "Installing python3..."
pacman -S --noconfirm python python-pip
else
echo "python3 is already installed."
fi
# Check and install bpftool
if ! command -v bpftool >/dev/null 2>&1; then
echo "Installing bpftool..."
pacman -S --noconfirm bpf linux-headers
else
echo "bpftool is already installed."
fi
;;
*)
echo "Unsupported distribution: $DISTRO"
echo "This script only supports Ubuntu/Debian and Arch Linux derivatives."
exit 1
;;
esac
}
echo "Checking and installing dependencies..."
install_dependencies
# Download example programs
echo "Downloading example programs from Python-BPF GitHub repository..."
mkdir -p examples
cd examples || exit 1
echo "Fetching example files list..."
FILES=$(curl -s "https://api.github.com/repos/pythonbpf/Python-BPF/contents/examples" | grep -o '"path": "examples/[^"]*"' | awk -F'"' '{print $4}')
if [ -z "$FILES" ]; then
echo "Failed to fetch file list from repository. Using fallback method..."
# Fallback to downloading common example files
EXAMPLES=(
"binops_demo.py"
"blk_request.py"
"clone-matplotlib.ipynb"
"clone_plot.py"
"hello_world.py"
"kprobes.py"
"struct_and_perf.py"
"sys_sync.py"
"xdp_pass.py"
)
for example in "${EXAMPLES[@]}"; do
echo "Downloading: $example"
curl -s -O "https://raw.githubusercontent.com/pythonbpf/Python-BPF/master/examples/$example"
done
else
for file in $FILES; do
filename=$(basename "$file")
echo "Downloading: $filename"
curl -s -o "$filename" "https://raw.githubusercontent.com/pythonbpf/Python-BPF/master/$file"
done
fi
cd "$WORK_DIR" || exit 1
chown -R "$REAL_USER:$(id -gn "$REAL_USER")" .
echo "Creating Python virtual environment..."
su - "$REAL_USER" -c "cd \"$WORK_DIR\" && python3 -m venv venv"
echo "Installing Python packages..."
su - "$REAL_USER" -c "cd \"$WORK_DIR\" && source venv/bin/activate && pip install --upgrade pip && pip install jupyter pythonbpf pylibbpf matplotlib"
cat > "$WORK_DIR/start_jupyter.sh" << EOF
#!/bin/bash
cd "$WORK_DIR"
source venv/bin/activate
cd examples
sudo ../venv/bin/python -m notebook --ip=0.0.0.0 --allow-root
EOF
chmod +x "$WORK_DIR/start_jupyter.sh"
chown "$REAL_USER:$(id -gn "$REAL_USER")" "$WORK_DIR/start_jupyter.sh"
print_info "========================================================"
print_info "Setup complete! To start Jupyter Notebook, run:"
print_info "$ sudo $WORK_DIR/start_jupyter.sh"
print_info "========================================================"

View File

@ -26,13 +26,8 @@ import tempfile
class BTFConverter:
def __init__(
self,
btf_source="/sys/kernel/btf/vmlinux",
output_file="vmlinux.py",
keep_intermediate=False,
verbose=False,
):
def __init__(self, btf_source="/sys/kernel/btf/vmlinux", output_file="vmlinux.py",
keep_intermediate=False, verbose=False):
self.btf_source = btf_source
self.output_file = output_file
self.keep_intermediate = keep_intermediate
@ -49,7 +44,11 @@ class BTFConverter:
self.log(f"{description}...")
try:
result = subprocess.run(
cmd, shell=True, check=True, capture_output=True, text=True
cmd,
shell=True,
check=True,
capture_output=True,
text=True
)
if self.verbose and result.stdout:
print(result.stdout)
@ -70,55 +69,51 @@ class BTFConverter:
"""Step 1.5: Preprocess enum definitions."""
self.log("Preprocessing enum definitions...")
with open(input_file, "r") as f:
with open(input_file, 'r') as f:
original_code = f.read()
# Extract anonymous enums
enums = re.findall(
r"(?<!typedef\s)(enum\s*\{[^}]*\})\s*(\w+)\s*(?::\s*\d+)?\s*;",
original_code,
r'(?<!typedef\s)(enum\s*\{[^}]*\})\s*(\w+)\s*(?::\s*\d+)?\s*;',
original_code
)
enum_defs = [enum_block + ";" for enum_block, _ in enums]
enum_defs = [enum_block + ';' for enum_block, _ in enums]
# Replace anonymous enums with int declarations
processed_code = re.sub(
r"(?<!typedef\s)enum\s*\{[^}]*\}\s*(\w+)\s*(?::\s*\d+)?\s*;",
r"int \1;",
original_code,
r'(?<!typedef\s)enum\s*\{[^}]*\}\s*(\w+)\s*(?::\s*\d+)?\s*;',
r'int \1;',
original_code
)
# Prepend enum definitions
if enum_defs:
enum_text = "\n".join(enum_defs) + "\n\n"
enum_text = '\n'.join(enum_defs) + '\n\n'
processed_code = enum_text + processed_code
output_file = os.path.join(self.temp_dir, "vmlinux_processed.h")
with open(output_file, "w") as f:
with open(output_file, 'w') as f:
f.write(processed_code)
return output_file
def step2_5_process_kioctx(self, input_file):
# TODO: this is a very bad bug and design decision. A single struct has an issue mostly.
#TODO: this is a very bad bug and design decision. A single struct has an issue mostly.
"""Step 2.5: Process struct kioctx to extract nested anonymous structs."""
self.log("Processing struct kioctx nested structs...")
with open(input_file, "r") as f:
with open(input_file, 'r') as f:
content = f.read()
# Pattern to match struct kioctx with its full body (handles multiple nesting levels)
kioctx_pattern = (
r"struct\s+kioctx\s*\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}\s*;"
)
kioctx_pattern = r'struct\s+kioctx\s*\{(?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*\}\s*;'
def process_kioctx_replacement(match):
full_struct = match.group(0)
self.log(f"Found struct kioctx, length: {len(full_struct)} chars")
# Extract the struct body (everything between outermost { and })
body_match = re.search(
r"struct\s+kioctx\s*\{(.*)\}\s*;", full_struct, re.DOTALL
)
body_match = re.search(r'struct\s+kioctx\s*\{(.*)\}\s*;', full_struct, re.DOTALL)
if not body_match:
return full_struct
@ -126,7 +121,7 @@ class BTFConverter:
# Find all anonymous structs within the body
# Pattern: struct { ... } followed by ; (not a member name)
# anon_struct_pattern = r"struct\s*\{[^}]*\}"
anon_struct_pattern = r'struct\s*\{[^}]*\}'
anon_structs = []
anon_counter = 4 # Start from 4, counting down to 1
@ -136,9 +131,7 @@ class BTFConverter:
anon_struct_content = m.group(0)
# Extract the body of the anonymous struct
anon_body_match = re.search(
r"struct\s*\{(.*)\}", anon_struct_content, re.DOTALL
)
anon_body_match = re.search(r'struct\s*\{(.*)\}', anon_struct_content, re.DOTALL)
if not anon_body_match:
return anon_struct_content
@ -161,7 +154,7 @@ class BTFConverter:
processed_body = body
# Find all occurrences and process them
pattern_with_semicolon = r"struct\s*\{([^}]*)\}\s*;"
pattern_with_semicolon = r'struct\s*\{([^}]*)\}\s*;'
matches = list(re.finditer(pattern_with_semicolon, body, re.DOTALL))
if not matches:
@ -185,16 +178,14 @@ class BTFConverter:
# Replace in the body
replacement = f"struct {anon_name} {member_name};"
processed_body = (
processed_body[:start_pos] + replacement + processed_body[end_pos:]
)
processed_body = processed_body[:start_pos] + replacement + processed_body[end_pos:]
anon_counter -= 1
# Rebuild the complete definition
if anon_structs:
# Prepend the anonymous struct definitions
anon_definitions = "\n".join(anon_structs) + "\n\n"
anon_definitions = '\n'.join(anon_structs) + '\n\n'
new_struct = f"struct kioctx {{{processed_body}}};"
return anon_definitions + new_struct
else:
@ -202,11 +193,14 @@ class BTFConverter:
# Apply the transformation
processed_content = re.sub(
kioctx_pattern, process_kioctx_replacement, content, flags=re.DOTALL
kioctx_pattern,
process_kioctx_replacement,
content,
flags=re.DOTALL
)
output_file = os.path.join(self.temp_dir, "vmlinux_kioctx_processed.h")
with open(output_file, "w") as f:
with open(output_file, 'w') as f:
f.write(processed_content)
self.log(f"Saved kioctx-processed output to {output_file}")
@ -224,7 +218,7 @@ class BTFConverter:
output_file = os.path.join(self.temp_dir, "vmlinux_raw.py")
cmd = (
f"clang2py {input_file} -o {output_file} "
f'--clang-args="-fno-ms-extensions -I/usr/include -I/usr/include/linux"'
f"--clang-args=\"-fno-ms-extensions -I/usr/include -I/usr/include/linux\""
)
self.run_command(cmd, "Converting to Python ctypes")
return output_file
@ -240,21 +234,25 @@ class BTFConverter:
data = re.sub(r"\('_[0-9]+',\s*ctypes\.[a-zA-Z0-9_]+,\s*0\),?\s*\n?", "", data)
# Replace ('_20', ctypes.c_uint64, 64) → ('_20', ctypes.c_uint64)
data = re.sub(
r"\('(_[0-9]+)',\s*(ctypes\.[a-zA-Z0-9_]+),\s*[0-9]+\)", r"('\1', \2)", data
)
data = re.sub(r"\('(_[0-9]+)',\s*(ctypes\.[a-zA-Z0-9_]+),\s*[0-9]+\)", r"('\1', \2)", data)
# Replace ('_20', ctypes.c_char, 8) with ('_20', ctypes.c_uint8, 8)
data = re.sub(r"(ctypes\.c_char)(\s*,\s*\d+\))", r"ctypes.c_uint8\2", data)
data = re.sub(
r"(ctypes\.c_char)(\s*,\s*\d+\))",
r"ctypes.c_uint8\2",
data
)
# below to replace those c_bool with bitfield greater than 8
def repl(m):
name, bits = m.groups()
return (
f"('{name}', ctypes.c_uint32, {bits})" if int(bits) > 8 else m.group(0)
)
return f"('{name}', ctypes.c_uint32, {bits})" if int(bits) > 8 else m.group(0)
data = re.sub(r"\('([^']+)',\s*ctypes\.c_bool,\s*(\d+)\)", repl, data)
data = re.sub(
r"\('([^']+)',\s*ctypes\.c_bool,\s*(\d+)\)",
repl,
data
)
# Remove ctypes. prefix from invalid entries
invalid_ctypes = ["bpf_iter_state", "_cache_type", "fs_context_purpose"]
@ -271,7 +269,6 @@ class BTFConverter:
if not self.keep_intermediate and self.temp_dir != ".":
self.log(f"Cleaning up temporary directory: {self.temp_dir}")
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def convert(self):
@ -295,7 +292,6 @@ class BTFConverter:
except Exception as e:
print(f"\n✗ Error during conversion: {e}", file=sys.stderr)
import traceback
traceback.print_exc()
sys.exit(1)
finally:
@ -308,13 +304,18 @@ class BTFConverter:
dependencies = {
"bpftool": "bpftool --version",
"clang": "clang --version",
"clang2py": "clang2py --version",
"clang2py": "clang2py --version"
}
missing = []
for tool, cmd in dependencies.items():
try:
subprocess.run(cmd, shell=True, check=True, capture_output=True)
subprocess.run(
cmd,
shell=True,
check=True,
capture_output=True
)
except subprocess.CalledProcessError:
missing.append(tool)
@ -336,31 +337,31 @@ Examples:
%(prog)s
%(prog)s -o kernel_types.py
%(prog)s --btf-source /sys/kernel/btf/custom_module -k -v
""",
"""
)
parser.add_argument(
"--btf-source",
default="/sys/kernel/btf/vmlinux",
help="Path to BTF source (default: /sys/kernel/btf/vmlinux)",
help="Path to BTF source (default: /sys/kernel/btf/vmlinux)"
)
parser.add_argument(
"-o",
"--output",
"-o", "--output",
default="vmlinux.py",
help="Output Python file (default: vmlinux.py)",
help="Output Python file (default: vmlinux.py)"
)
parser.add_argument(
"-k",
"--keep-intermediate",
"-k", "--keep-intermediate",
action="store_true",
help="Keep intermediate files (vmlinux.h, vmlinux_processed.h, etc.)",
help="Keep intermediate files (vmlinux.h, vmlinux_processed.h, etc.)"
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="Enable verbose output"
"-v", "--verbose",
action="store_true",
help="Enable verbose output"
)
args = parser.parse_args()
@ -369,7 +370,7 @@ Examples:
btf_source=args.btf_source,
output_file=args.output,
keep_intermediate=args.keep_intermediate,
verbose=args.verbose,
verbose=args.verbose
)
converter.convert()