diff --git a/anomaly-detection/lib/__init__.py b/anomaly-detection/lib/__init__.py new file mode 100644 index 0000000..f02a48a --- /dev/null +++ b/anomaly-detection/lib/__init__.py @@ -0,0 +1,22 @@ +""" +Process Anomaly Detection - Constants and Utilities +""" + +import logging + +logger = logging.getLogger(__name__) +MAX_SYSCALLS = 548 + + +def comm_for_pid(pid: int) -> bytes | None: + """Get process name from /proc.""" + try: + with open(f"/proc/{pid}/comm", "rb") as f: + return f.read().strip() + except FileNotFoundError: + logger.warning(f"Process with PID {pid} not found.") + except PermissionError: + logger.warning(f"Permission denied when accessing /proc/{pid}/comm.") + except Exception as e: + logger.warning(f"Error reading /proc/{pid}/comm: {e}") + return None diff --git a/anomaly-detection/lib/ml.py b/anomaly-detection/lib/ml.py new file mode 100644 index 0000000..9b0c395 --- /dev/null +++ b/anomaly-detection/lib/ml.py @@ -0,0 +1,173 @@ +""" +Autoencoder for Process Behavior Anomaly Detection + +Uses Keras/TensorFlow to train an autoencoder on syscall patterns. +Anomalies are detected when reconstruction error exceeds threshold. +""" + +import logging +import os + +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from tensorflow import keras + +from lib import MAX_SYSCALLS + +logger = logging.getLogger(__name__) + + +def create_autoencoder(n_inputs: int = MAX_SYSCALLS) -> keras.Model: + """ + Create the autoencoder architecture. + + Architecture: input → encoder → bottleneck → decoder → output + """ + inp = keras.Input(shape=(n_inputs,)) + + # Encoder + encoder = keras.layers.Dense(n_inputs)(inp) + encoder = keras.layers.ReLU()(encoder) + + # Bottleneck (compressed representation) + bottleneck = keras.layers.Dense(n_inputs // 2)(encoder) + + # Decoder + decoder = keras.layers.Dense(n_inputs)(bottleneck) + decoder = keras.layers.ReLU()(decoder) + output = keras.layers.Dense(n_inputs, activation="linear")(decoder) + + model = keras.Model(inp, output) + model.compile(optimizer="adam", loss="mse") + + return model + + +class AutoEncoder: + """ + Autoencoder for syscall pattern anomaly detection. + + Usage: + # Training + ae = AutoEncoder('model.keras') + model, threshold = ae.train('data.csv', epochs=200) + + # Inference + ae = AutoEncoder('model.keras', load=True) + _, errors, total_error = ae.predict([features]) + """ + + def __init__(self, filename: str, load: bool = False): + self.filename = filename + self.model = None + + if load: + self._load_model() + + def _load_model(self) -> None: + """Load a trained model from disk.""" + if not os.path.exists(self.filename): + raise FileNotFoundError(f"Model file not found: {self.filename}") + + logger.info(f"Loading model from {self.filename}") + self.model = keras.models.load_model(self.filename) + + def train( + self, + datafile: str, + epochs: int, + batch_size: int, + test_size: float = 0.1, + ) -> tuple[keras.Model, float]: + """ + Train the autoencoder on collected data. + + Args: + datafile: Path to CSV file with training data + epochs: Number of training epochs + batch_size: Training batch size + test_size: Fraction of data to use for validation + + Returns: + Tuple of (trained model, error threshold) + """ + if not os.path.exists(datafile): + raise FileNotFoundError(f"Data file not found: {datafile}") + + logger.info(f"Loading training data from {datafile}") + + # Load and prepare data + df = pd.read_csv(datafile) + features = df.drop(["sample_time"], axis=1).values + + logger.info(f"Loaded {len(features)} samples with {features.shape[1]} features") + + # Split train/test + train_data, test_data = train_test_split( + features, + test_size=test_size, + random_state=42, + ) + + logger.info(f"Training set: {len(train_data)} samples") + logger.info(f"Test set: {len(test_data)} samples") + + # Create and train model + self.model = create_autoencoder() + + if self.model is None: + raise RuntimeError("Failed to create the autoencoder model.") + + logger.info("Training autoencoder...") + self.model.fit( + train_data, + train_data, + validation_data=(test_data, test_data), + epochs=epochs, + batch_size=batch_size, + verbose=1, + ) + + # Save model (use .keras format for Keras 3.x compatibility) + self.model.save(self.filename) + logger.info(f"Model saved to {self.filename}") + + # Calculate error threshold from test data + threshold = self._calculate_threshold(test_data) + + return self.model, threshold + + def _calculate_threshold(self, test_data: np.ndarray) -> float: + """Calculate error threshold from test data.""" + logger.info(f"Calculating error threshold from {len(test_data)} test samples") + + if self.model is None: + raise RuntimeError("Model not loaded. Use load=True or train first.") + + predictions = self.model.predict(test_data, verbose=0) + errors = np.abs(test_data - predictions).sum(axis=1) + + return float(errors.max()) + + def predict(self, X: list | np.ndarray) -> tuple[np.ndarray, np.ndarray, float]: + """ + Run prediction and return reconstruction error. + + Args: + X: Input data (list of feature vectors) + + Returns: + Tuple of (reconstructed, per_feature_errors, total_error) + """ + if self.model is None: + raise RuntimeError("Model not loaded. Use load=True or train first.") + + X = np.asarray(X, dtype=np.float32) + y = self.model.predict(X, verbose=0) + + # Per-feature reconstruction error + errors = np.abs(X[0] - y[0]) + total_error = float(errors.sum()) + + return y, errors, total_error diff --git a/anomaly-detection/lib/platform.py b/anomaly-detection/lib/platform.py new file mode 100644 index 0000000..5e996ee --- /dev/null +++ b/anomaly-detection/lib/platform.py @@ -0,0 +1,448 @@ +# Copyright 2017 Sasha Goldshtein +# Copyright 2018 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +syscall.py contains functions useful for mapping between syscall names and numbers +""" + +# Syscall table for Linux x86_64, not very recent. Automatically generated from +# https://git.kernel.org/pub/scm/linux/kernel/git/stable/linux.git/tree/arch/x86/entry/syscalls/syscall_64.tbl?h=linux-6.17.y +# using the following command: +# +# cat arch/x86/entry/syscalls/syscall_64.tbl \ +# | awk 'BEGIN { print "syscalls = {" } +# /^[0-9]/ { print " "$1": b\""$3"\"," } +# END { print "}" }' + +SYSCALLS = { + 0: b"read", + 1: b"write", + 2: b"open", + 3: b"close", + 4: b"stat", + 5: b"fstat", + 6: b"lstat", + 7: b"poll", + 8: b"lseek", + 9: b"mmap", + 10: b"mprotect", + 11: b"munmap", + 12: b"brk", + 13: b"rt_sigaction", + 14: b"rt_sigprocmask", + 15: b"rt_sigreturn", + 16: b"ioctl", + 17: b"pread64", + 18: b"pwrite64", + 19: b"readv", + 20: b"writev", + 21: b"access", + 22: b"pipe", + 23: b"select", + 24: b"sched_yield", + 25: b"mremap", + 26: b"msync", + 27: b"mincore", + 28: b"madvise", + 29: b"shmget", + 30: b"shmat", + 31: b"shmctl", + 32: b"dup", + 33: b"dup2", + 34: b"pause", + 35: b"nanosleep", + 36: b"getitimer", + 37: b"alarm", + 38: b"setitimer", + 39: b"getpid", + 40: b"sendfile", + 41: b"socket", + 42: b"connect", + 43: b"accept", + 44: b"sendto", + 45: b"recvfrom", + 46: b"sendmsg", + 47: b"recvmsg", + 48: b"shutdown", + 49: b"bind", + 50: b"listen", + 51: b"getsockname", + 52: b"getpeername", + 53: b"socketpair", + 54: b"setsockopt", + 55: b"getsockopt", + 56: b"clone", + 57: b"fork", + 58: b"vfork", + 59: b"execve", + 60: b"exit", + 61: b"wait4", + 62: b"kill", + 63: b"uname", + 64: b"semget", + 65: b"semop", + 66: b"semctl", + 67: b"shmdt", + 68: b"msgget", + 69: b"msgsnd", + 70: b"msgrcv", + 71: b"msgctl", + 72: b"fcntl", + 73: b"flock", + 74: b"fsync", + 75: b"fdatasync", + 76: b"truncate", + 77: b"ftruncate", + 78: b"getdents", + 79: b"getcwd", + 80: b"chdir", + 81: b"fchdir", + 82: b"rename", + 83: b"mkdir", + 84: b"rmdir", + 85: b"creat", + 86: b"link", + 87: b"unlink", + 88: b"symlink", + 89: b"readlink", + 90: b"chmod", + 91: b"fchmod", + 92: b"chown", + 93: b"fchown", + 94: b"lchown", + 95: b"umask", + 96: b"gettimeofday", + 97: b"getrlimit", + 98: b"getrusage", + 99: b"sysinfo", + 100: b"times", + 101: b"ptrace", + 102: b"getuid", + 103: b"syslog", + 104: b"getgid", + 105: b"setuid", + 106: b"setgid", + 107: b"geteuid", + 108: b"getegid", + 109: b"setpgid", + 110: b"getppid", + 111: b"getpgrp", + 112: b"setsid", + 113: b"setreuid", + 114: b"setregid", + 115: b"getgroups", + 116: b"setgroups", + 117: b"setresuid", + 118: b"getresuid", + 119: b"setresgid", + 120: b"getresgid", + 121: b"getpgid", + 122: b"setfsuid", + 123: b"setfsgid", + 124: b"getsid", + 125: b"capget", + 126: b"capset", + 127: b"rt_sigpending", + 128: b"rt_sigtimedwait", + 129: b"rt_sigqueueinfo", + 130: b"rt_sigsuspend", + 131: b"sigaltstack", + 132: b"utime", + 133: b"mknod", + 134: b"uselib", + 135: b"personality", + 136: b"ustat", + 137: b"statfs", + 138: b"fstatfs", + 139: b"sysfs", + 140: b"getpriority", + 141: b"setpriority", + 142: b"sched_setparam", + 143: b"sched_getparam", + 144: b"sched_setscheduler", + 145: b"sched_getscheduler", + 146: b"sched_get_priority_max", + 147: b"sched_get_priority_min", + 148: b"sched_rr_get_interval", + 149: b"mlock", + 150: b"munlock", + 151: b"mlockall", + 152: b"munlockall", + 153: b"vhangup", + 154: b"modify_ldt", + 155: b"pivot_root", + 156: b"_sysctl", + 157: b"prctl", + 158: b"arch_prctl", + 159: b"adjtimex", + 160: b"setrlimit", + 161: b"chroot", + 162: b"sync", + 163: b"acct", + 164: b"settimeofday", + 165: b"mount", + 166: b"umount2", + 167: b"swapon", + 168: b"swapoff", + 169: b"reboot", + 170: b"sethostname", + 171: b"setdomainname", + 172: b"iopl", + 173: b"ioperm", + 174: b"create_module", + 175: b"init_module", + 176: b"delete_module", + 177: b"get_kernel_syms", + 178: b"query_module", + 179: b"quotactl", + 180: b"nfsservctl", + 181: b"getpmsg", + 182: b"putpmsg", + 183: b"afs_syscall", + 184: b"tuxcall", + 185: b"security", + 186: b"gettid", + 187: b"readahead", + 188: b"setxattr", + 189: b"lsetxattr", + 190: b"fsetxattr", + 191: b"getxattr", + 192: b"lgetxattr", + 193: b"fgetxattr", + 194: b"listxattr", + 195: b"llistxattr", + 196: b"flistxattr", + 197: b"removexattr", + 198: b"lremovexattr", + 199: b"fremovexattr", + 200: b"tkill", + 201: b"time", + 202: b"futex", + 203: b"sched_setaffinity", + 204: b"sched_getaffinity", + 205: b"set_thread_area", + 206: b"io_setup", + 207: b"io_destroy", + 208: b"io_getevents", + 209: b"io_submit", + 210: b"io_cancel", + 211: b"get_thread_area", + 212: b"lookup_dcookie", + 213: b"epoll_create", + 214: b"epoll_ctl_old", + 215: b"epoll_wait_old", + 216: b"remap_file_pages", + 217: b"getdents64", + 218: b"set_tid_address", + 219: b"restart_syscall", + 220: b"semtimedop", + 221: b"fadvise64", + 222: b"timer_create", + 223: b"timer_settime", + 224: b"timer_gettime", + 225: b"timer_getoverrun", + 226: b"timer_delete", + 227: b"clock_settime", + 228: b"clock_gettime", + 229: b"clock_getres", + 230: b"clock_nanosleep", + 231: b"exit_group", + 232: b"epoll_wait", + 233: b"epoll_ctl", + 234: b"tgkill", + 235: b"utimes", + 236: b"vserver", + 237: b"mbind", + 238: b"set_mempolicy", + 239: b"get_mempolicy", + 240: b"mq_open", + 241: b"mq_unlink", + 242: b"mq_timedsend", + 243: b"mq_timedreceive", + 244: b"mq_notify", + 245: b"mq_getsetattr", + 246: b"kexec_load", + 247: b"waitid", + 248: b"add_key", + 249: b"request_key", + 250: b"keyctl", + 251: b"ioprio_set", + 252: b"ioprio_get", + 253: b"inotify_init", + 254: b"inotify_add_watch", + 255: b"inotify_rm_watch", + 256: b"migrate_pages", + 257: b"openat", + 258: b"mkdirat", + 259: b"mknodat", + 260: b"fchownat", + 261: b"futimesat", + 262: b"newfstatat", + 263: b"unlinkat", + 264: b"renameat", + 265: b"linkat", + 266: b"symlinkat", + 267: b"readlinkat", + 268: b"fchmodat", + 269: b"faccessat", + 270: b"pselect6", + 271: b"ppoll", + 272: b"unshare", + 273: b"set_robust_list", + 274: b"get_robust_list", + 275: b"splice", + 276: b"tee", + 277: b"sync_file_range", + 278: b"vmsplice", + 279: b"move_pages", + 280: b"utimensat", + 281: b"epoll_pwait", + 282: b"signalfd", + 283: b"timerfd_create", + 284: b"eventfd", + 285: b"fallocate", + 286: b"timerfd_settime", + 287: b"timerfd_gettime", + 288: b"accept4", + 289: b"signalfd4", + 290: b"eventfd2", + 291: b"epoll_create1", + 292: b"dup3", + 293: b"pipe2", + 294: b"inotify_init1", + 295: b"preadv", + 296: b"pwritev", + 297: b"rt_tgsigqueueinfo", + 298: b"perf_event_open", + 299: b"recvmmsg", + 300: b"fanotify_init", + 301: b"fanotify_mark", + 302: b"prlimit64", + 303: b"name_to_handle_at", + 304: b"open_by_handle_at", + 305: b"clock_adjtime", + 306: b"syncfs", + 307: b"sendmmsg", + 308: b"setns", + 309: b"getcpu", + 310: b"process_vm_readv", + 311: b"process_vm_writev", + 312: b"kcmp", + 313: b"finit_module", + 314: b"sched_setattr", + 315: b"sched_getattr", + 316: b"renameat2", + 317: b"seccomp", + 318: b"getrandom", + 319: b"memfd_create", + 320: b"kexec_file_load", + 321: b"bpf", + 322: b"execveat", + 323: b"userfaultfd", + 324: b"membarrier", + 325: b"mlock2", + 326: b"copy_file_range", + 327: b"preadv2", + 328: b"pwritev2", + 329: b"pkey_mprotect", + 330: b"pkey_alloc", + 331: b"pkey_free", + 332: b"statx", + 333: b"io_pgetevents", + 334: b"rseq", + 335: b"uretprobe", + 424: b"pidfd_send_signal", + 425: b"io_uring_setup", + 426: b"io_uring_enter", + 427: b"io_uring_register", + 428: b"open_tree", + 429: b"move_mount", + 430: b"fsopen", + 431: b"fsconfig", + 432: b"fsmount", + 433: b"fspick", + 434: b"pidfd_open", + 435: b"clone3", + 436: b"close_range", + 437: b"openat2", + 438: b"pidfd_getfd", + 439: b"faccessat2", + 440: b"process_madvise", + 441: b"epoll_pwait2", + 442: b"mount_setattr", + 443: b"quotactl_fd", + 444: b"landlock_create_ruleset", + 445: b"landlock_add_rule", + 446: b"landlock_restrict_self", + 447: b"memfd_secret", + 448: b"process_mrelease", + 449: b"futex_waitv", + 450: b"set_mempolicy_home_node", + 451: b"cachestat", + 452: b"fchmodat2", + 453: b"map_shadow_stack", + 454: b"futex_wake", + 455: b"futex_wait", + 456: b"futex_requeue", + 457: b"statmount", + 458: b"listmount", + 459: b"lsm_get_self_attr", + 460: b"lsm_set_self_attr", + 461: b"lsm_list_modules", + 462: b"mseal", + 463: b"setxattrat", + 464: b"getxattrat", + 465: b"listxattrat", + 466: b"removexattrat", + 467: b"open_tree_attr", + 468: b"file_getattr", + 469: b"file_setattr", + 512: b"rt_sigaction", + 513: b"rt_sigreturn", + 514: b"ioctl", + 515: b"readv", + 516: b"writev", + 517: b"recvfrom", + 518: b"sendmsg", + 519: b"recvmsg", + 520: b"execve", + 521: b"ptrace", + 522: b"rt_sigpending", + 523: b"rt_sigtimedwait", + 524: b"rt_sigqueueinfo", + 525: b"sigaltstack", + 526: b"timer_create", + 527: b"mq_notify", + 528: b"kexec_load", + 529: b"waitid", + 530: b"set_robust_list", + 531: b"get_robust_list", + 532: b"vmsplice", + 533: b"move_pages", + 534: b"preadv", + 535: b"pwritev", + 536: b"rt_tgsigqueueinfo", + 537: b"recvmmsg", + 538: b"sendmmsg", + 539: b"process_vm_readv", + 540: b"process_vm_writev", + 541: b"setsockopt", + 542: b"getsockopt", + 543: b"io_setup", + 544: b"io_submit", + 545: b"execveat", + 546: b"preadv2", + 547: b"pwritev2", +} diff --git a/anomaly-detection/lib/probe.py b/anomaly-detection/lib/probe.py new file mode 100644 index 0000000..84353d9 --- /dev/null +++ b/anomaly-detection/lib/probe.py @@ -0,0 +1,117 @@ +""" +PythonBPF eBPF Probe for Syscall Histogram Collection +""" + +from vmlinux import struct_trace_event_raw_sys_enter +from pythonbpf import bpf, map, section, bpfglobal, BPF +from pythonbpf.helper import pid +from pythonbpf.maps import HashMap +from ctypes import c_int64 +from lib import MAX_SYSCALLS, comm_for_pid + + +@bpf +@map +def histogram() -> HashMap: + return HashMap(key=c_int64, value=c_int64, max_entries=1024) + + +@bpf +@map +def target_pid_map() -> HashMap: + return HashMap(key=c_int64, value=c_int64, max_entries=1) + + +@bpf +@section("tracepoint/raw_syscalls/sys_enter") +def trace_syscall(ctx: struct_trace_event_raw_sys_enter) -> c_int64: + syscall_id = ctx.id + current_pid = pid() + target = target_pid_map.lookup(0) + if target: + if current_pid != target: + return 0 # type: ignore + if syscall_id < 0 or syscall_id >= 548: + return 0 # type: ignore + count = histogram.lookup(syscall_id) + if count: + histogram.update(syscall_id, count + 1) + else: + histogram.update(syscall_id, 1) + return 0 # type: ignore + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +ebpf_prog = BPF() + + +class Probe: + """ + Syscall histogram probe for a target process. + + Usage: + probe = Probe(target_pid=1234) + probe.start() + histogram = probe.get_histogram() + """ + + def __init__(self, target_pid: int, max_syscalls: int = MAX_SYSCALLS): + self.target_pid = target_pid + self.max_syscalls = max_syscalls + self.comm = comm_for_pid(target_pid) + + if self.comm is None: + raise ValueError(f"Cannot find process with PID {target_pid}") + + self._bpf = None + self._histogram_map = None + self._target_map = None + + def start(self): + """Compile, load, and attach the BPF probe.""" + # Compile and load + self._bpf = ebpf_prog + self._bpf.load() + self._bpf.attach_all() + + # Get map references + self._histogram_map = self._bpf["histogram"] + self._target_map = self._bpf["target_pid_map"] + + # Set target PID in the map + self._target_map.update(0, self.target_pid) + + return self + + def get_histogram(self) -> list: + """Read current histogram values as a list.""" + if self._histogram_map is None: + raise RuntimeError("Probe not started. Call start() first.") + + result = [0] * self.max_syscalls + + for syscall_id in range(self.max_syscalls): + try: + count = self._histogram_map.lookup(syscall_id) + if count is not None: + result[syscall_id] = int(count) + except Exception: + pass + + return result + + def __getitem__(self, syscall_id: int) -> int: + """Allow indexing: probe[syscall_id]""" + if self._histogram_map is None: + raise RuntimeError("Probe not started") + + try: + count = self._histogram_map.lookup(syscall_id) + return int(count) if count is not None else 0 + except Exception: + return 0 diff --git a/anomaly-detection/main.py b/anomaly-detection/main.py new file mode 100644 index 0000000..fbcb2ae --- /dev/null +++ b/anomaly-detection/main.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +""" +Process Behavior Anomaly Detection using PythonBPF and Autoencoders + +Ported from evilsocket's BCC implementation to PythonBPF. +https://github.com/evilsocket/ebpf-process-anomaly-detection + +Usage: + # 1.Learn normal behavior from a process + sudo python main.py --learn --pid 1234 --data normal.csv + + # 2.Train the autoencoder (no sudo needed) + python main.py --train --data normal.csv --model model.h5 + + # 3.Monitor for anomalies + sudo python main.py --run --pid 1234 --model model.h5 +""" + +import argparse +import logging +import os +import sys +import time +from collections import Counter + +from lib import MAX_SYSCALLS +from lib.ml import AutoEncoder +from lib.platform import SYSCALLS +from lib.probe import Probe + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + + +def learn(pid: int, data_path: str, poll_interval_ms: int) -> None: + """ + Capture syscall patterns from target process. + + Args: + pid: Target process ID + data_path: Path to save CSV data + poll_interval_ms: Polling interval in milliseconds + """ + if os.path.exists(data_path): + logger.error( + f"{data_path} already exists.Delete it or use a different filename." + ) + sys.exit(1) + + try: + probe = Probe(pid) + except ValueError as e: + logger.error(str(e)) + sys.exit(1) + + probe_comm = probe.comm.decode() if probe.comm else "unknown" + + print(f"šŸ“Š Learning from process {pid} ({probe_comm})") + print(f"šŸ“ Saving data to {data_path}") + print(f"ā±ļø Polling interval: {poll_interval_ms}ms") + print("Press Ctrl+C to stop...\n") + + probe.start() + + prev_histogram = [0.0] * MAX_SYSCALLS + prev_report_time = time.time() + sample_count = 0 + poll_interval_sec = poll_interval_ms / 1000.0 + + header = "sample_time," + ",".join(f"sys_{i}" for i in range(MAX_SYSCALLS)) + + with open(data_path, "w") as fp: + fp.write(header + "\n") + + try: + while True: + histogram = [float(x) for x in probe.get_histogram()] + + if histogram != prev_histogram: + deltas = _compute_deltas(prev_histogram, histogram) + prev_histogram = histogram.copy() + + row = f"{time.time()},{','.join(map(str, deltas))}" + fp.write(row + "\n") + fp.flush() + sample_count += 1 + + now = time.time() + if now - prev_report_time >= 1.0: + print(f" {sample_count} samples saved...") + prev_report_time = now + + time.sleep(poll_interval_sec) + + except KeyboardInterrupt: + print(f"\nāœ… Stopped. Saved {sample_count} samples to {data_path}") + + +def train(data_path: str, model_path: str, epochs: int, batch_size: int) -> None: + """ + Train autoencoder on captured data. + + Args: + data_path: Path to training CSV data + model_path: Path to save trained model + epochs: Number of training epochs + batch_size: Training batch size + """ + if not os.path.exists(data_path): + logger.error(f"Data file {data_path} not found.Run --learn first.") + sys.exit(1) + + print(f"🧠 Training autoencoder on {data_path}") + print(f" Epochs: {epochs}") + print(f" Batch size: {batch_size}") + print() + + ae = AutoEncoder(model_path) + _, threshold = ae.train(data_path, epochs, batch_size) + + print() + print("=" * 50) + print("āœ… Training complete!") + print(f" Model saved to: {model_path}") + print(f" Error threshold: {threshold:.6f}") + print() + print(f"šŸ’” Use --max-error {threshold:.4f} when running detection") + print("=" * 50) + + +def run(pid: int, model_path: str, max_error: float, poll_interval_ms: int) -> None: + """ + Monitor process and detect anomalies. + + Args: + pid: Target process ID + model_path: Path to trained model + max_error: Anomaly detection threshold + poll_interval_ms: Polling interval in milliseconds + """ + if not os.path.exists(model_path): + logger.error(f"Model file {model_path} not found. Run --train first.") + sys.exit(1) + + try: + probe = Probe(pid) + except ValueError as e: + logger.error(str(e)) + sys.exit(1) + + ae = AutoEncoder(model_path, load=True) + probe_comm = probe.comm.decode() if probe.comm else "unknown" + + print(f"šŸ” Monitoring process {pid} ({probe_comm}) for anomalies") + print(f" Error threshold: {max_error}") + print(f" Polling interval: {poll_interval_ms}ms") + print("Press Ctrl+C to stop...\n") + + probe.start() + + prev_histogram = [0.0] * MAX_SYSCALLS + anomaly_count = 0 + check_count = 0 + poll_interval_sec = poll_interval_ms / 1000.0 + + try: + while True: + histogram = [float(x) for x in probe.get_histogram()] + + if histogram != prev_histogram: + deltas = _compute_deltas(prev_histogram, histogram) + prev_histogram = histogram.copy() + check_count += 1 + + _, feat_errors, total_error = ae.predict([deltas]) + + if total_error > max_error: + anomaly_count += 1 + _report_anomaly(anomaly_count, total_error, max_error, feat_errors) + + time.sleep(poll_interval_sec) + + except KeyboardInterrupt: + print("\nāœ… Stopped.") + print(f" Checks performed: {check_count}") + print(f" Anomalies detected: {anomaly_count}") + + +def _compute_deltas(prev: list[float], current: list[float]) -> list[float]: + """Compute rate of change between two histograms.""" + deltas = [] + for p, c in zip(prev, current): + if c != 0.0: + delta = 1.0 - (p / c) + else: + delta = 0.0 + deltas.append(delta) + return deltas + + +def _report_anomaly( + count: int, + total_error: float, + threshold: float, + feat_errors: list[float], +) -> None: + """Print anomaly report with top offending syscalls.""" + print(f"🚨 ANOMALY #{count} detected!") + print(f" Total error: {total_error:.4f} (threshold: {threshold})") + + errors_by_syscall = {idx: err for idx, err in enumerate(feat_errors)} + top3 = Counter(errors_by_syscall).most_common(3) + + print(" Top anomalous syscalls:") + for idx, err in top3: + name = SYSCALLS.get(idx, f"syscall_{idx}") + print(f" • {name!r}: {err:.4f}") + print() + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Process anomaly detection with PythonBPF and Autoencoders", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Learn from a process (e.g., Firefox) for a few minutes + sudo python main.py --learn --pid $(pgrep -o firefox) --data firefox.csv + + # Train the model (no sudo needed) + python main.py --train --data firefox.csv --model firefox.h5 + + # Monitor the same process for anomalies + sudo python main.py --run --pid $(pgrep -o firefox) --model firefox.h5 + + # Full workflow for nginx: + sudo python main.py --learn --pid $(pgrep -o nginx) --data nginx_normal.csv + python main.py --train --data nginx_normal.csv --model nginx.h5 --epochs 100 + sudo python main.py --run --pid $(pgrep -o nginx) --model nginx.h5 --max-error 0.05 + """, + ) + + actions = parser.add_mutually_exclusive_group() + actions.add_argument( + "--learn", + action="store_true", + help="Capture syscall patterns from a process", + ) + actions.add_argument( + "--train", + action="store_true", + help="Train autoencoder on captured data", + ) + actions.add_argument( + "--run", + action="store_true", + help="Monitor process for anomalies", + ) + + parser.add_argument( + "--pid", + type=int, + default=0, + help="Target process ID", + ) + parser.add_argument( + "--data", + default="data.csv", + help="CSV file for training data (default: data.csv)", + ) + parser.add_argument( + "--model", + default="model.keras", + help="Model file path (default: model.h5)", + ) + parser.add_argument( + "--time", + type=int, + default=100, + help="Polling interval in milliseconds (default: 100)", + ) + parser.add_argument( + "--epochs", + type=int, + default=200, + help="Training epochs (default: 200)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=16, + help="Training batch size (default: 16)", + ) + parser.add_argument( + "--max-error", + type=float, + default=0.09, + help="Anomaly detection threshold (default: 0.09)", + ) + + return parser.parse_args() + + +def main() -> None: + """Main entry point.""" + args = parse_args() + + if not any([args.learn, args.train, args.run]): + print("No action specified.Use --learn, --train, or --run.") + print("Run with --help for usage information.") + sys.exit(0) + + if args.learn: + if args.pid == 0: + logger.error("--pid required for --learn") + sys.exit(1) + learn(args.pid, args.data, args.time) + + elif args.train: + train(args.data, args.model, args.epochs, args.batch_size) + + elif args.run: + if args.pid == 0: + logger.error("--pid required for --run") + sys.exit(1) + run(args.pid, args.model, args.max_error, args.time) + + +if __name__ == "__main__": + main()