# File: python-bpf/BCC-Examples/container-monitor/net_stats.bpf.py
# 193 lines, 5.4 KiB, Python

import logging
import time
import os
from pathlib import Path
from pythonbpf import bpf, map, section, bpfglobal, struct, compile, BPF
from pythonbpf.maps import HashMap
from pythonbpf.helper import get_current_cgroup_id
from ctypes import c_int32, c_uint64, c_void_p, c_uint32
from vmlinux import struct_sk_buff, struct_pt_regs
# Per-cgroup network counters. This struct is the value type stored in
# net_stats_map; every field is an unsigned 64-bit counter.
@bpf
@struct
class net_stats:
    rx_packets: c_uint64  # incremented by the __netif_receive_skb probe
    tx_packets: c_uint64  # incremented by the __dev_queue_xmit probe
    rx_bytes: c_uint64  # running sum of skb.len added by the receive probe
    tx_bytes: c_uint64  # running sum of skb.len added by the transmit probe
# Hash map from a c_uint64 key (the cgroup id used by the probes) to that
# cgroup's net_stats record, limited to 1024 entries. Both probes write to it;
# the user-space code reads it under the name "net_stats_map".
@bpf
@map
def net_stats_map() -> HashMap:
    return HashMap(key=c_uint64, value=net_stats, max_entries=1024)
# Receive-side probe, placed in section "kprobe/__netif_receive_skb".
# Adds one packet and skb.len bytes to the RX counters of the entry keyed by
# get_current_cgroup_id(), carrying the TX counters over unchanged.
# Always returns 0.
# NOTE(review): the id belongs to whichever task is current when the probe
# fires; presumably that is not always the task that will consume the packet
# on the receive path - confirm before relying on per-cgroup RX attribution.
@bpf
@section("kprobe/__netif_receive_skb")
def trace_netif_rx(ctx: struct_pt_regs) -> c_int32:
    cgroup_id = get_current_cgroup_id()
    # Read skb pointer from first argument (PT_REGS_PARM1)
    skb = struct_sk_buff(ctx.di)  # x86_64: first arg in rdi
    # Read skb->len
    pkt_len = c_uint64(skb.len)
    stats_ptr = net_stats_map.lookup(cgroup_id)
    if stats_ptr:
        # Entry exists: build an updated copy and write it back.
        # NOTE(review): lookup followed by update is a read-modify-write, not
        # an atomic add; concurrent updates for one cgroup may be lost - verify.
        stats = net_stats()
        stats.rx_packets = stats_ptr.rx_packets + 1
        stats.tx_packets = stats_ptr.tx_packets
        stats.rx_bytes = stats_ptr.rx_bytes + pkt_len
        stats.tx_bytes = stats_ptr.tx_bytes
        net_stats_map.update(cgroup_id, stats)
    else:
        # No entry yet for this cgroup id: create one seeded with this packet.
        stats = net_stats()
        stats.rx_packets = c_uint64(1)
        stats.tx_packets = c_uint64(0)
        stats.rx_bytes = pkt_len
        stats.tx_bytes = c_uint64(0)
        net_stats_map.update(cgroup_id, stats)
    return c_int32(0)
# Transmit-side probe, placed in section "kprobe/__dev_queue_xmit".
# Mirror image of the receive probe: adds one packet and skb.len bytes to the
# TX counters of the entry keyed by get_current_cgroup_id(), carrying the RX
# counters over unchanged. Always returns 0.
@bpf
@section("kprobe/__dev_queue_xmit")
def trace_dev_xmit(ctx1: struct_pt_regs) -> c_int32:
    cgroup_id = get_current_cgroup_id()
    # Read skb pointer from first argument
    # (taken from the di register, as in the receive probe: x86_64 only)
    skb = struct_sk_buff(ctx1.di)
    pkt_len = c_uint64(skb.len)
    stats_ptr = net_stats_map.lookup(cgroup_id)
    if stats_ptr:
        # Entry exists: build an updated copy and write it back.
        # NOTE(review): lookup followed by update is a read-modify-write, not
        # an atomic add; concurrent updates for one cgroup may be lost - verify.
        stats = net_stats()
        stats.rx_packets = stats_ptr.rx_packets
        stats.tx_packets = stats_ptr.tx_packets + 1
        stats.rx_bytes = stats_ptr.rx_bytes
        stats.tx_bytes = stats_ptr.tx_bytes + pkt_len
        net_stats_map.update(cgroup_id, stats)
    else:
        # No entry yet for this cgroup id: create one seeded with this packet.
        stats = net_stats()
        stats.rx_packets = c_uint64(0)
        stats.tx_packets = c_uint64(1)
        stats.rx_bytes = c_uint64(0)
        stats.tx_bytes = pkt_len
        net_stats_map.update(cgroup_id, stats)
    return c_int32(0)
# License declaration exposed as a BPF global; its value is the string "GPL".
# NOTE(review): presumably emitted as the object's license section so that
# GPL-only kernel helpers may be used - confirm against the pythonbpf docs.
@bpf
@bpfglobal
def LICENSE() -> str:
    return "GPL"
# Load and attach BPF program.
# These statements are at module level, so they run on import as well as when
# the file is executed as a script. Order matters: construct, load, attach.
b = BPF()
b.load()
b.attach_all()
# Get map reference and enable struct deserialization.
# display_stats() relies on this: it reads .rx_packets, .rx_bytes,
# .tx_packets and .tx_bytes from the objects returned by lookups on this map.
net_stats_map_ref = b["net_stats_map"]
net_stats_map_ref.set_value_struct("net_stats")
def get_cgroup_ids(proc_root="/proc", cgroup_root="/sys/fs/cgroup"):
    """Collect the inode numbers of the cgroup directories of running processes.

    Every ``<proc_root>/<pid>/cgroup`` file is read. Each line has the form
    ``hierarchy-ID:controller-list:cgroup-path``; the cgroup path is appended
    to ``cgroup_root`` and, when that path can be stat-ed, its inode number is
    recorded. This script uses that inode number as the cgroup ID when it
    looks up entries in the BPF map.

    Collection is best-effort: processes that vanish or cannot be read, and
    cgroup paths that do not exist under ``cgroup_root``, are skipped silently.

    Args:
        proc_root: Directory holding the per-process subdirectories.
        cgroup_root: Mount point of the cgroup filesystem.

    Returns:
        A set of inode numbers (ints); empty when nothing could be resolved.
    """
    cgroup_root = os.fspath(cgroup_root)
    cgroup_ids = set()
    # Many processes share one cgroup, so remember which paths were already
    # resolved and stat each distinct path only once per call.
    seen_paths = set()

    for proc_dir in Path(proc_root).glob("[0-9]*"):
        try:
            with open(proc_dir / "cgroup") as f:
                lines = f.readlines()
        except OSError:
            # Covers missing file, permission problems and processes that
            # exit while being read.
            continue

        for line in lines:
            # Split at most twice: the path is the last field and may itself
            # contain ':' characters, which must not be cut off.
            parts = line.strip().split(":", 2)
            if len(parts) < 3:
                continue

            cgroup_path = parts[2]
            if cgroup_path in seen_paths:
                continue
            seen_paths.add(cgroup_path)

            try:
                # The cgroup directory's inode is what is used as the ID.
                cgroup_ids.add(os.stat(f"{cgroup_root}{cgroup_path}").st_ino)
            except (OSError, ValueError):
                # Path is not present under this mount (or is not a valid
                # file name); nothing to record for it.
                continue

    return cgroup_ids
def display_stats():
    """Print one table of per-cgroup network counters read from the BPF map.

    Cgroup IDs come from get_cgroup_ids(). Only cgroups for which
    net_stats_map_ref.lookup() returns a truthy record get a row; a TOTAL row
    summing the printed rows closes the table.
    """
    rule = "=" * 100

    # Table header
    print("\n" + rule)
    print(f"{'CGROUP ID':<20} {'RX PACKETS':<15} {'RX BYTES':<20} {'TX PACKETS':<15} {'TX BYTES':<20}")
    print(rule)

    cgroup_ids = get_cgroup_ids()
    if not cgroup_ids:
        print("No cgroups found...")
        print(rule)
        return

    # Running sums over every row that gets printed
    total_rx_packets = total_rx_bytes = 0
    total_tx_packets = total_tx_bytes = 0
    rows_printed = 0

    for cgroup_id in sorted(cgroup_ids):
        net_stat = net_stats_map_ref.lookup(cgroup_id)
        if not net_stat:
            # Nothing recorded for this cgroup: no row.
            continue

        rx_packets, rx_bytes = int(net_stat.rx_packets), int(net_stat.rx_bytes)
        tx_packets, tx_bytes = int(net_stat.tx_packets), int(net_stat.tx_bytes)

        total_rx_packets += rx_packets
        total_rx_bytes += rx_bytes
        total_tx_packets += tx_packets
        total_tx_bytes += tx_bytes

        print(f"{cgroup_id:<20} {rx_packets:<15} {rx_bytes:<20} {tx_packets:<15} {tx_bytes:<20}")
        rows_printed += 1

    if rows_printed == 0:
        print("No data collected yet...")

    # Table footer
    print(rule)
    print(f"{'TOTAL':<20} {total_rx_packets:<15} {total_rx_bytes:<20} {total_tx_packets:<15} {total_tx_bytes:<20}")
    print()
# Script entry point: print the table repeatedly until interrupted.
if __name__ == "__main__":
    print("Tracing network I/O operations... Press Ctrl+C to exit\n")
    refresh_seconds = 5  # pause before each refresh, including the first
    try:
        while True:
            time.sleep(refresh_seconds)
            display_stats()
    except KeyboardInterrupt:
        # Ctrl+C is the intended way to stop the tool.
        print("\nStopped")