57 Commits

Author SHA1 Message Date
8e3942d38c format chore 2025-10-08 14:31:37 +05:30
3abe07c5b2 add global symbol table populate function 2025-10-05 14:05:10 +05:30
01bd7604ed add global symbol table populate function 2025-10-05 14:04:25 +05:30
7ae84a0d5a add failing test 2025-10-05 00:55:38 +05:30
df3f00261a changer order of passes 2025-10-04 08:17:16 +05:30
ab610147a5 update globals test and todos. 2025-10-04 06:36:51 +05:30
7720fe9f9f format chore 2025-10-04 06:33:09 +05:30
7aeac86bd3 fix broken IR generation logic for globals 2025-10-04 06:32:25 +05:30
ab1c4223d5 fix broken IR generation logic for globals 2025-10-03 22:55:40 +05:30
c3a512d5cf add global support with broken generation function 2025-10-03 22:20:04 +05:30
4a60c42cd0 add global failing test
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-03 21:25:58 +05:30
b35134625b Merge pull request #19 from pythonbpf/fix-expr
Refactor expr_pass
2025-10-03 17:36:31 +05:30
c3db609a90 Revert to using Warning loglevel as default 2025-10-03 17:35:57 +05:30
cc626c38f7 Move binops1 to tests/passing 2025-10-03 17:13:02 +05:30
a8b3f4f86c Fix recursive binops, move failing binops to passing 2025-10-03 17:08:41 +05:30
d593969408 Refactor ugly if-elif chain in handle_binary_op 2025-10-03 14:04:38 +05:30
6d5895ebc2 More fixes to recursive dereferencer, add get_operand value 2025-10-03 13:46:52 +05:30
c9ee6e4f17 Fix recursive_dereferencer in binops 2025-10-03 13:35:15 +05:30
a622c53e0f Add deref 2025-10-03 02:00:01 +05:30
a4f1363aed Add _handle_attribute_expr 2025-10-03 01:50:59 +05:30
3a819dcaee Add _handle_constant_expr 2025-10-02 22:54:38 +05:30
729270b34b Use _handle_name_expr in eval_expr 2025-10-02 22:50:21 +05:30
44cbcccb6c Create _handle_name_expr 2025-10-02 22:43:54 +05:30
253944afd2 Merge pull request #18 from pythonbpf/fix-maps
Fix map calling convention
2025-10-02 22:12:01 +05:30
54993ce5c2 Merge branch 'master' into fix-maps 2025-10-02 22:11:38 +05:30
05083bd513 janitorial nitpicks 2025-10-02 22:10:28 +05:30
6e4c340780 Allow non-call convention for maps 2025-10-02 22:07:28 +05:30
9dbca410c2 Remove calls from map in sys_sync 2025-10-02 21:24:15 +05:30
62ca3b5ffe format errors 2025-10-02 19:07:49 +05:30
f263c35156 move debug cu generation to debug module 2025-10-02 19:05:58 +05:30
0678d70309 bump version 2025-10-02 18:02:36 +05:30
96fa5687f8 Merge pull request #17 from pythonbpf/logging
add logging
2025-10-02 17:59:18 +05:30
4d0dd68d56 fix formatting 2025-10-02 17:58:24 +05:30
89b0a07419 add logging level control 2025-10-02 17:57:37 +05:30
469ca43eaa replace prints with logger.info 2025-10-02 17:46:27 +05:30
dc2b611cbc format errors
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-02 05:17:02 +05:30
0c1acf1420 Fix local_sym_tab usage in binary_ops 2025-10-02 05:08:05 +05:30
71b97e3e20 Add iter to LocalSymbol 2025-10-02 04:56:34 +05:30
12ba3605e9 Fix local_sym_tab usage in helpers 2025-10-02 04:53:04 +05:30
d7427f306f Fix usage of local_sym_tab in expr_pass 2025-10-02 04:50:31 +05:30
0142381ce2 Remove local_var_metadata from expr_pass 2025-10-02 04:44:14 +05:30
9223d7b5c5 Remove local_var_metadata from helpers 2025-10-02 04:40:44 +05:30
3b74ade455 Remove occurences of local_var_metadata from functions_pass, use LocalSymbol.var 2025-10-02 04:35:10 +05:30
dadcb69f1c Store LocalSymbol in allocate_mem 2025-10-02 04:27:10 +05:30
2fd2a46838 Add LocalSymbol dataclass 2025-10-02 04:13:24 +05:30
1a66887f48 move helper annotations to helpers module 2025-10-02 01:55:32 +05:30
23f3cbcea7 add type annotations 2025-10-02 01:43:05 +05:30
429f51437f Merge pull request #15 from pythonbpf/static-type-checks
Static type checks
2025-10-02 01:38:46 +05:30
c92272dd35 workflow update 2025-10-02 01:37:36 +05:30
8792740eb0 workflow update 2025-10-02 01:36:14 +05:30
cf5faaad7f remove pointless type annotation
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-02 01:27:03 +05:30
59b3d6514b fix ruff errors 2025-10-02 01:23:55 +05:30
3c956e671a add static type checking
Signed-off-by: varun-r-mallya <varunrmallya@gmail.com>
2025-10-02 01:11:54 +05:30
8650297866 make type checks viable 2025-10-02 00:51:23 +05:30
6831f11179 Fix fstrings in examples, add alternate map attr access 2025-10-02 00:22:59 +05:30
d4e8e1bf73 Fix unterminated fstrings 2025-10-02 00:14:51 +05:30
08f2b283c9 Merge pull request #10 from pythonbpf/helper-refactor
bpf_helper_handler refactor
2025-10-02 00:08:59 +05:30
32 changed files with 915 additions and 477 deletions

View File

@ -5,10 +5,7 @@ name: Format
on: on:
workflow_dispatch: workflow_dispatch:
pull_request:
push: push:
branches:
- master
jobs: jobs:
pre-commit: pre-commit:

View File

@ -41,16 +41,15 @@ repos:
- id: ruff - id: ruff
args: ["--fix", "--show-fixes"] args: ["--fix", "--show-fixes"]
- id: ruff-format - id: ruff-format
exclude: ^(docs) exclude: ^(docs)|^(tests)|^(examples)
## Checking static types # Checking static types
#- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
# rev: "v1.10.0" rev: "v1.10.0"
# hooks: hooks:
# - id: mypy - id: mypy
# files: "setup.py" exclude: ^(tests)|^(examples)
# args: [] additional_dependencies: [types-setuptools]
# additional_dependencies: [types-setuptools]
# Changes tabs to spaces # Changes tabs to spaces
- repo: https://github.com/Lucas-C/pre-commit-hooks - repo: https://github.com/Lucas-C/pre-commit-hooks

View File

@ -60,12 +60,13 @@ pip install pythonbpf pylibbpf
```python ```python
import time import time
from pythonbpf import bpf, map, section, bpfglobal, BPF from pythonbpf import bpf, map, section, bpfglobal, BPF
from pythonbpf.helpers import pid from pythonbpf.helper import pid
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from pylibbpf import * from pylibbpf import *
from ctypes import c_void_p, c_int64, c_uint64, c_int32 from ctypes import c_void_p, c_int64, c_uint64, c_int32
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# This program attaches an eBPF tracepoint to sys_enter_clone, # This program attaches an eBPF tracepoint to sys_enter_clone,
# counts per-PID clone syscalls, stores them in a hash map, # counts per-PID clone syscalls, stores them in a hash map,
# and then plots the distribution as a histogram using matplotlib. # and then plots the distribution as a histogram using matplotlib.
@ -76,6 +77,7 @@ import matplotlib.pyplot as plt
def hist() -> HashMap: def hist() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=4096) return HashMap(key=c_int32, value=c_uint64, max_entries=4096)
@bpf @bpf
@section("tracepoint/syscalls/sys_enter_clone") @section("tracepoint/syscalls/sys_enter_clone")
def hello(ctx: c_void_p) -> c_int64: def hello(ctx: c_void_p) -> c_int64:

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import ktime from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64, c_uint64 from ctypes import c_void_p, c_int64, c_uint64

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import ktime from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int32, c_uint64 from ctypes import c_void_p, c_int32, c_uint64

View File

@ -10,7 +10,7 @@
"import time\n", "import time\n",
"\n", "\n",
"from pythonbpf import bpf, map, section, bpfglobal, BPF\n", "from pythonbpf import bpf, map, section, bpfglobal, BPF\n",
"from pythonbpf.helpers import pid\n", "from pythonbpf.helper import pid\n",
"from pythonbpf.maps import HashMap\n", "from pythonbpf.maps import HashMap\n",
"from pylibbpf import *\n", "from pylibbpf import *\n",
"from ctypes import c_void_p, c_int64, c_uint64, c_int32\n", "from ctypes import c_void_p, c_int64, c_uint64, c_int32\n",

View File

@ -1,7 +1,7 @@
import time import time
from pythonbpf import bpf, map, section, bpfglobal, BPF from pythonbpf import bpf, map, section, bpfglobal, BPF
from pythonbpf.helpers import pid from pythonbpf.helper import pid
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from pylibbpf import BpfMap from pylibbpf import BpfMap
from ctypes import c_void_p, c_int64, c_uint64, c_int32 from ctypes import c_void_p, c_int64, c_uint64, c_int32

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, compile from pythonbpf import bpf, map, struct, section, bpfglobal, compile
from pythonbpf.helpers import ktime, pid from pythonbpf.helper import ktime, pid
from pythonbpf.maps import PerfEventArray from pythonbpf.maps import PerfEventArray
from ctypes import c_void_p, c_int32, c_uint64 from ctypes import c_void_p, c_int32, c_uint64
@ -27,10 +27,7 @@ def hello(ctx: c_void_p) -> c_int32:
dataobj.pid = pid() dataobj.pid = pid()
dataobj.ts = ktime() dataobj.ts = ktime()
# dataobj.comm = strobj # dataobj.comm = strobj
print( print(f"clone called at {dataobj.ts} by pid" f"{dataobj.pid}, comm {strobj}")
f"clone called at {dataobj.ts} by pid {
dataobj.pid}, comm {strobj} at time {ts}"
)
events.output(dataobj) events.output(dataobj)
return c_int32(0) return c_int32(0)

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import ktime from pythonbpf.helper import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64, c_uint64 from ctypes import c_void_p, c_int64, c_uint64
@ -21,17 +21,17 @@ def last() -> HashMap:
@section("tracepoint/syscalls/sys_enter_sync") @section("tracepoint/syscalls/sys_enter_sync")
def do_trace(ctx: c_void_p) -> c_int64: def do_trace(ctx: c_void_p) -> c_int64:
key = 0 key = 0
tsp = last().lookup(key) tsp = last.lookup(key)
if tsp: if tsp:
kt = ktime() kt = ktime()
delta = kt - tsp delta = kt - tsp
if delta < 1000000000: if delta < 1000000000:
time_ms = delta // 1000000 time_ms = delta // 1000000
print(f"sync called within last second, last {time_ms} ms ago") print(f"sync called within last second, last {time_ms} ms ago")
last().delete(key) last.delete(key)
else: else:
kt = ktime() kt = ktime()
last().update(key, kt) last.update(key, kt)
return c_int64(0) return c_int64(0)

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import XDP_PASS from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "pythonbpf" name = "pythonbpf"
version = "0.1.3" version = "0.1.4"
description = "Reduced Python frontend for eBPF" description = "Reduced Python frontend for eBPF"
authors = [ authors = [
{ name = "r41k0u", email="pragyanshchaturvedi18@gmail.com" }, { name = "r41k0u", email="pragyanshchaturvedi18@gmail.com" },

View File

@ -1,71 +1,66 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from logging import Logger
import logging
logger: Logger = logging.getLogger(__name__)
def recursive_dereferencer(var, builder): def recursive_dereferencer(var, builder):
"""dereference until primitive type comes out""" """dereference until primitive type comes out"""
if var.type == ir.PointerType(ir.PointerType(ir.IntType(64))): # TODO: Not worrying about stack overflow for now
if isinstance(var.type, ir.PointerType):
a = builder.load(var) a = builder.load(var)
return recursive_dereferencer(a, builder) return recursive_dereferencer(a, builder)
elif var.type == ir.PointerType(ir.IntType(64)): elif isinstance(var.type, ir.IntType):
a = builder.load(var)
return recursive_dereferencer(a, builder)
elif var.type == ir.IntType(64):
return var return var
else: else:
raise TypeError(f"Unsupported type for dereferencing: {var.type}") raise TypeError(f"Unsupported type for dereferencing: {var.type}")
def handle_binary_op(rval, module, builder, var_name, local_sym_tab, map_sym_tab, func): def get_operand_value(operand, module, builder, local_sym_tab):
print(module) """Extract the value from an operand, handling variables and constants."""
left = rval.left if isinstance(operand, ast.Name):
right = rval.right if operand.id in local_sym_tab:
return recursive_dereferencer(local_sym_tab[operand.id].var, builder)
raise ValueError(f"Undefined variable: {operand.id}")
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
return ir.Constant(ir.IntType(64), operand.value)
raise TypeError(f"Unsupported constant type: {type(operand.value)}")
elif isinstance(operand, ast.BinOp):
return handle_binary_op_impl(operand, module, builder, local_sym_tab)
raise TypeError(f"Unsupported operand type: {type(operand)}")
def handle_binary_op_impl(rval, module, builder, local_sym_tab):
op = rval.op op = rval.op
left = get_operand_value(rval.left, module, builder, local_sym_tab)
right = get_operand_value(rval.right, module, builder, local_sym_tab)
logger.info(f"left is {left}, right is {right}, op is {op}")
# Handle left operand # Map AST operation nodes to LLVM IR builder methods
if isinstance(left, ast.Name): op_map = {
if left.id in local_sym_tab: ast.Add: builder.add,
left = recursive_dereferencer(local_sym_tab[left.id][0], builder) ast.Sub: builder.sub,
else: ast.Mult: builder.mul,
raise SyntaxError(f"Undefined variable: {left.id}") ast.Div: builder.sdiv,
elif isinstance(left, ast.Constant): ast.Mod: builder.srem,
left = ir.Constant(ir.IntType(64), left.value) ast.LShift: builder.shl,
else: ast.RShift: builder.lshr,
raise SyntaxError("Unsupported left operand type") ast.BitOr: builder.or_,
ast.BitXor: builder.xor,
ast.BitAnd: builder.and_,
ast.FloorDiv: builder.udiv,
}
if isinstance(right, ast.Name): if type(op) in op_map:
if right.id in local_sym_tab: result = op_map[type(op)](left, right)
right = recursive_dereferencer(local_sym_tab[right.id][0], builder) return result
else:
raise SyntaxError(f"Undefined variable: {right.id}")
elif isinstance(right, ast.Constant):
right = ir.Constant(ir.IntType(64), right.value)
else:
raise SyntaxError("Unsupported right operand type")
print(f"left is {left}, right is {right}, op is {op}")
if isinstance(op, ast.Add):
builder.store(builder.add(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.Sub):
builder.store(builder.sub(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.Mult):
builder.store(builder.mul(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.Div):
builder.store(builder.sdiv(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.Mod):
builder.store(builder.srem(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.LShift):
builder.store(builder.shl(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.RShift):
builder.store(builder.lshr(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.BitOr):
builder.store(builder.or_(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.BitXor):
builder.store(builder.xor(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.BitAnd):
builder.store(builder.and_(left, right), local_sym_tab[var_name][0])
elif isinstance(op, ast.FloorDiv):
builder.store(builder.udiv(left, right), local_sym_tab[var_name][0])
else: else:
raise SyntaxError("Unsupported binary operation") raise SyntaxError("Unsupported binary operation")
def handle_binary_op(rval, module, builder, var_name, local_sym_tab):
result = handle_binary_op_impl(rval, module, builder, local_sym_tab)
builder.store(result, local_sym_tab[var_name].var)

View File

@ -4,16 +4,24 @@ from .license_pass import license_processing
from .functions_pass import func_proc from .functions_pass import func_proc
from .maps import maps_proc from .maps import maps_proc
from .structs import structs_proc from .structs import structs_proc
from .globals_pass import globals_processing from .globals_pass import (
from .debuginfo import DW_LANG_C11, DwarfBehaviorEnum globals_list_creation,
globals_processing,
populate_global_symbol_table,
)
from .debuginfo import DW_LANG_C11, DwarfBehaviorEnum, DebugInfoGenerator
import os import os
import subprocess import subprocess
import inspect import inspect
from pathlib import Path from pathlib import Path
from pylibbpf import BpfProgram from pylibbpf import BpfProgram
import tempfile import tempfile
from logging import Logger
import logging
VERSION = "v0.1.3" logger: Logger = logging.getLogger(__name__)
VERSION = "v0.1.4"
def find_bpf_chunks(tree): def find_bpf_chunks(tree):
@ -30,21 +38,27 @@ def find_bpf_chunks(tree):
def processor(source_code, filename, module): def processor(source_code, filename, module):
tree = ast.parse(source_code, filename) tree = ast.parse(source_code, filename)
print(ast.dump(tree, indent=4)) logger.debug(ast.dump(tree, indent=4))
bpf_chunks = find_bpf_chunks(tree) bpf_chunks = find_bpf_chunks(tree)
for func_node in bpf_chunks: for func_node in bpf_chunks:
print(f"Found BPF function/struct: {func_node.name}") logger.info(f"Found BPF function/struct: {func_node.name}")
populate_global_symbol_table(tree, module)
license_processing(tree, module)
globals_processing(tree, module)
structs_sym_tab = structs_proc(tree, module, bpf_chunks) structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks) map_sym_tab = maps_proc(tree, module, bpf_chunks)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab) func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)
license_processing(tree, module) globals_list_creation(tree, module)
globals_processing(tree, module)
def compile_to_ir(filename: str, output: str): def compile_to_ir(filename: str, output: str, loglevel=logging.WARNING):
logging.basicConfig(
level=loglevel, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
with open(filename) as f: with open(filename) as f:
source = f.read() source = f.read()
@ -53,33 +67,17 @@ def compile_to_ir(filename: str, output: str):
module.triple = "bpf" module.triple = "bpf"
if not hasattr(module, "_debug_compile_unit"): if not hasattr(module, "_debug_compile_unit"):
module._file_metadata = module.add_debug_info( debug_generator = DebugInfoGenerator(module)
"DIFile", debug_generator.generate_file_metadata(filename, os.path.dirname(filename))
{ # type: ignore debug_generator.generate_debug_cu(
"filename": filename, DW_LANG_C11,
"directory": os.path.dirname(filename), f"PythonBPF {VERSION}",
}, True, # TODO: This is probably not true
# TODO: add a global field here that keeps track of all the globals. Works without it, but I think it might
# be required for kprobes.
True,
) )
module._debug_compile_unit = module.add_debug_info(
"DICompileUnit",
{ # type: ignore
"language": DW_LANG_C11,
"file": module._file_metadata, # type: ignore
"producer": f"PythonBPF {VERSION}",
"isOptimized": True, # TODO: This is probably not true
# TODO: add a global field here that keeps track of all the globals. Works without it, but I think it might
# be required for kprobes.
"runtimeVersion": 0,
"emissionKind": 1,
"splitDebugInlining": False,
"nameTableKind": 0,
},
is_distinct=True,
)
module.add_named_metadata("llvm.dbg.cu", module._debug_compile_unit) # type: ignore
processor(source, filename, module) processor(source, filename, module)
wchar_size = module.add_metadata( wchar_size = module.add_metadata(
@ -121,7 +119,7 @@ def compile_to_ir(filename: str, output: str):
module.add_named_metadata("llvm.ident", [f"PythonBPF {VERSION}"]) module.add_named_metadata("llvm.ident", [f"PythonBPF {VERSION}"])
print(f"IR written to {output}") logger.info(f"IR written to {output}")
with open(output, "w") as f: with open(output, "w") as f:
f.write(f'source_filename = "{filename}"\n') f.write(f'source_filename = "{filename}"\n')
f.write(str(module)) f.write(str(module))
@ -130,7 +128,7 @@ def compile_to_ir(filename: str, output: str):
return output return output
def compile() -> bool: def compile(loglevel=logging.WARNING) -> bool:
# Look one level up the stack to the caller of this function # Look one level up the stack to the caller of this function
caller_frame = inspect.stack()[1] caller_frame = inspect.stack()[1]
caller_file = Path(caller_frame.filename).resolve() caller_file = Path(caller_frame.filename).resolve()
@ -139,9 +137,11 @@ def compile() -> bool:
o_file = caller_file.with_suffix(".o") o_file = caller_file.with_suffix(".o")
success = True success = True
success = compile_to_ir(str(caller_file), str(ll_file)) and success
success = ( success = (
compile_to_ir(str(caller_file), str(ll_file), loglevel=loglevel) and success
)
success = bool(
subprocess.run( subprocess.run(
[ [
"llc", "llc",
@ -157,11 +157,11 @@ def compile() -> bool:
and success and success
) )
print(f"Object written to {o_file}") logger.info(f"Object written to {o_file}")
return success return success
def BPF() -> BpfProgram: def BPF(loglevel=logging.WARNING) -> BpfProgram:
caller_frame = inspect.stack()[1] caller_frame = inspect.stack()[1]
src = inspect.getsource(caller_frame.frame) src = inspect.getsource(caller_frame.frame)
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(
@ -174,7 +174,7 @@ def BPF() -> BpfProgram:
f.write(src) f.write(src)
f.flush() f.flush()
source = f.name source = f.name
compile_to_ir(source, str(inter.name)) compile_to_ir(source, str(inter.name), loglevel=loglevel)
subprocess.run( subprocess.run(
[ [
"llc", "llc",

View File

@ -12,6 +12,34 @@ class DebugInfoGenerator:
self.module = module self.module = module
self._type_cache = {} # Cache for common debug types self._type_cache = {} # Cache for common debug types
def generate_file_metadata(self, filename, dirname):
self.module._file_metadata = self.module.add_debug_info(
"DIFile",
{ # type: ignore
"filename": filename,
"directory": dirname,
},
)
def generate_debug_cu(
self, language, producer: str, is_optimized: bool, is_distinct: bool
):
self.module._debug_compile_unit = self.module.add_debug_info(
"DICompileUnit",
{ # type: ignore
"language": language,
"file": self.module._file_metadata, # type: ignore
"producer": producer,
"isOptimized": is_optimized,
"runtimeVersion": 0,
"emissionKind": 1,
"splitDebugInlining": False,
"nameTableKind": 0,
},
is_distinct=is_distinct,
)
self.module.add_named_metadata("llvm.dbg.cu", self.module._debug_compile_unit) # type: ignore
def get_basic_type(self, name: str, size: int, encoding: int) -> Any: def get_basic_type(self, name: str, size: int, encoding: int) -> Any:
"""Get or create a basic type with caching""" """Get or create a basic type with caching"""
key = (name, size, encoding) key = (name, size, encoding)

View File

@ -1,5 +1,91 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from logging import Logger
import logging
from typing import Dict
logger: Logger = logging.getLogger(__name__)
def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder):
"""Handle ast.Name expressions."""
if expr.id in local_sym_tab:
var = local_sym_tab[expr.id].var
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type
else:
logger.info(f"Undefined variable {expr.id}")
return None
def _handle_constant_expr(expr: ast.Constant):
"""Handle ast.Constant expressions."""
if isinstance(expr.value, int):
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
else:
logger.info("Unsupported constant type")
return None
def _handle_attribute_expr(
expr: ast.Attribute,
local_sym_tab: Dict,
structs_sym_tab: Dict,
builder: ir.IRBuilder,
):
"""Handle ast.Attribute expressions for struct field access."""
if isinstance(expr.value, ast.Name):
var_name = expr.value.id
attr_name = expr.attr
if var_name in local_sym_tab:
var_ptr, var_type, var_metadata = local_sym_tab[var_name]
logger.info(f"Loading attribute {attr_name} from variable {var_name}")
logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
metadata = structs_sym_tab[var_metadata]
if attr_name in metadata.fields:
gep = metadata.gep(builder, var_ptr, attr_name)
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
return None
def _handle_deref_call(expr: ast.Call, local_sym_tab: Dict, builder: ir.IRBuilder):
"""Handle deref function calls."""
logger.info(f"Handling deref {ast.dump(expr)}")
if len(expr.args) != 1:
logger.info("deref takes exactly one argument")
return None
arg = expr.args[0]
if (
isinstance(arg, ast.Call)
and isinstance(arg.func, ast.Name)
and arg.func.id == "deref"
):
logger.info("Multiple deref not supported")
return None
if isinstance(arg, ast.Name):
if arg.id in local_sym_tab:
arg_ptr = local_sym_tab[arg.id].var
else:
logger.info(f"Undefined variable {arg.id}")
return None
else:
logger.info("Unsupported argument type for deref")
return None
if arg_ptr is None:
logger.info("Failed to evaluate deref argument")
return None
# Load the value from pointer
val = builder.load(arg_ptr)
return val, local_sym_tab[arg.id].ir_type
def eval_expr( def eval_expr(
@ -10,72 +96,33 @@ def eval_expr(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab=None, structs_sym_tab=None,
local_var_metadata=None,
): ):
print(f"Evaluating expression: {ast.dump(expr)}") logger.info(f"Evaluating expression: {ast.dump(expr)}")
print(local_var_metadata)
if isinstance(expr, ast.Name): if isinstance(expr, ast.Name):
if expr.id in local_sym_tab: return _handle_name_expr(expr, local_sym_tab, builder)
var = local_sym_tab[expr.id][0]
val = builder.load(var)
return val, local_sym_tab[expr.id][1] # return value and type
else:
print(f"Undefined variable {expr.id}")
return None
elif isinstance(expr, ast.Constant): elif isinstance(expr, ast.Constant):
if isinstance(expr.value, int): return _handle_constant_expr(expr)
return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64)
elif isinstance(expr.value, bool):
return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1)
else:
print("Unsupported constant type")
return None
elif isinstance(expr, ast.Call): elif isinstance(expr, ast.Call):
if isinstance(expr.func, ast.Name) and expr.func.id == "deref":
return _handle_deref_call(expr, local_sym_tab, builder)
# delayed import to avoid circular dependency # delayed import to avoid circular dependency
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
if isinstance(expr.func, ast.Name): if isinstance(expr.func, ast.Name) and HelperHandlerRegistry.has_handler(
# check deref expr.func.id
if expr.func.id == "deref": ):
print(f"Handling deref {ast.dump(expr)}") return handle_helper_call(
if len(expr.args) != 1: expr,
print("deref takes exactly one argument") module,
return None builder,
arg = expr.args[0] func,
if ( local_sym_tab,
isinstance(arg, ast.Call) map_sym_tab,
and isinstance(arg.func, ast.Name) structs_sym_tab,
and arg.func.id == "deref" )
):
print("Multiple deref not supported")
return None
if isinstance(arg, ast.Name):
if arg.id in local_sym_tab:
arg = local_sym_tab[arg.id][0]
else:
print(f"Undefined variable {arg.id}")
return None
if arg is None:
print("Failed to evaluate deref argument")
return None
# Since we are handling only name case, directly take type from sym tab
val = builder.load(arg)
return val, local_sym_tab[expr.args[0].id][1]
# check for helpers
if HelperHandlerRegistry.has_handler(expr.func.id):
return handle_helper_call(
expr,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
local_var_metadata,
)
elif isinstance(expr.func, ast.Attribute): elif isinstance(expr.func, ast.Attribute):
print(f"Handling method call: {ast.dump(expr.func)}") logger.info(f"Handling method call: {ast.dump(expr.func)}")
if isinstance(expr.func.value, ast.Call) and isinstance( if isinstance(expr.func.value, ast.Call) and isinstance(
expr.func.value.func, ast.Name expr.func.value.func, ast.Name
): ):
@ -89,7 +136,6 @@ def eval_expr(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
) )
elif isinstance(expr.func.value, ast.Name): elif isinstance(expr.func.value, ast.Name):
obj_name = expr.func.value.id obj_name = expr.func.value.id
@ -104,25 +150,10 @@ def eval_expr(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
) )
elif isinstance(expr, ast.Attribute): elif isinstance(expr, ast.Attribute):
if isinstance(expr.value, ast.Name): return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
var_name = expr.value.id logger.info("Unsupported expression evaluation")
attr_name = expr.attr
if var_name in local_sym_tab:
var_ptr, var_type = local_sym_tab[var_name]
print(f"Loading attribute " f"{attr_name} from variable {var_name}")
print(f"Variable type: {var_type}, Variable ptr: {var_ptr}")
print(local_var_metadata)
if local_var_metadata and var_name in local_var_metadata:
metadata = structs_sym_tab[local_var_metadata[var_name]]
if attr_name in metadata.fields:
gep = metadata.gep(builder, var_ptr, attr_name)
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
print("Unsupported expression evaluation")
return None return None
@ -134,11 +165,9 @@ def handle_expr(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
): ):
"""Handle expression statements in the function body.""" """Handle expression statements in the function body."""
print(f"Handling expression: {ast.dump(expr)}") logger.info(f"Handling expression: {ast.dump(expr)}")
print(local_var_metadata)
call = expr.value call = expr.value
if isinstance(call, ast.Call): if isinstance(call, ast.Call):
eval_expr( eval_expr(
@ -149,7 +178,6 @@ def handle_expr(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
) )
else: else:
print("Unsupported expression type") logger.info("Unsupported expression type")

View File

@ -1,13 +1,27 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
import logging
from typing import Any
from dataclasses import dataclass
from .helper import HelperHandlerRegistry, handle_helper_call from .helper import HelperHandlerRegistry, handle_helper_call
from .type_deducer import ctypes_to_ir from .type_deducer import ctypes_to_ir
from .binary_ops import handle_binary_op from .binary_ops import handle_binary_op
from .expr_pass import eval_expr, handle_expr from .expr_pass import eval_expr, handle_expr
local_var_metadata = {} logger = logging.getLogger(__name__)
@dataclass
class LocalSymbol:
var: ir.AllocaInstr
ir_type: ir.Type
metadata: Any = None
def __iter__(self):
yield self.var
yield self.ir_type
yield self.metadata
def get_probe_string(func_node): def get_probe_string(func_node):
@ -32,28 +46,27 @@ def handle_assign(
): ):
"""Handle assignment statements in the function body.""" """Handle assignment statements in the function body."""
if len(stmt.targets) != 1: if len(stmt.targets) != 1:
print("Unsupported multiassignment") logger.info("Unsupported multiassignment")
return return
num_types = ("c_int32", "c_int64", "c_uint32", "c_uint64") num_types = ("c_int32", "c_int64", "c_uint32", "c_uint64")
target = stmt.targets[0] target = stmt.targets[0]
print(f"Handling assignment to {ast.dump(target)}") logger.info(f"Handling assignment to {ast.dump(target)}")
if not isinstance(target, ast.Name) and not isinstance(target, ast.Attribute): if not isinstance(target, ast.Name) and not isinstance(target, ast.Attribute):
print("Unsupported assignment target") logger.info("Unsupported assignment target")
return return
var_name = target.id if isinstance(target, ast.Name) else target.value.id var_name = target.id if isinstance(target, ast.Name) else target.value.id
rval = stmt.value rval = stmt.value
if isinstance(target, ast.Attribute): if isinstance(target, ast.Attribute):
# struct field assignment # struct field assignment
field_name = target.attr field_name = target.attr
if var_name in local_sym_tab and var_name in local_var_metadata: if var_name in local_sym_tab:
struct_type = local_var_metadata[var_name] struct_type = local_sym_tab[var_name].metadata
struct_info = structs_sym_tab[struct_type] struct_info = structs_sym_tab[struct_type]
if field_name in struct_info.fields: if field_name in struct_info.fields:
field_ptr = struct_info.gep( field_ptr = struct_info.gep(
builder, local_sym_tab[var_name][0], field_name builder, local_sym_tab[var_name].var, field_name
) )
val = eval_expr( val = eval_expr(
func, func,
@ -74,31 +87,31 @@ def handle_assign(
# print(f"Assigned to struct field {var_name}.{field_name}") # print(f"Assigned to struct field {var_name}.{field_name}")
pass pass
if val is None: if val is None:
print("Failed to evaluate struct field assignment") logger.info("Failed to evaluate struct field assignment")
return return
print(field_ptr) logger.info(field_ptr)
builder.store(val[0], field_ptr) builder.store(val[0], field_ptr)
print(f"Assigned to struct field {var_name}.{field_name}") logger.info(f"Assigned to struct field {var_name}.{field_name}")
return return
elif isinstance(rval, ast.Constant): elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool): if isinstance(rval.value, bool):
if rval.value: if rval.value:
builder.store(ir.Constant(ir.IntType(1), 1), builder.store(
local_sym_tab[var_name][0]) ir.Constant(ir.IntType(1), 1), local_sym_tab[var_name].var
)
else: else:
builder.store(ir.Constant(ir.IntType(1), 0), builder.store(
local_sym_tab[var_name][0]) ir.Constant(ir.IntType(1), 0), local_sym_tab[var_name].var
print(f"Assigned constant {rval.value} to {var_name}") )
logger.info(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, int): elif isinstance(rval.value, int):
# Assume c_int64 for now # Assume c_int64 for now
# var = builder.alloca(ir.IntType(64), name=var_name) # var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8 # var.align = 8
builder.store( builder.store(
ir.Constant(ir.IntType(64), ir.Constant(ir.IntType(64), rval.value), local_sym_tab[var_name].var
rval.value), local_sym_tab[var_name][0]
) )
# local_sym_tab[var_name] = var logger.info(f"Assigned constant {rval.value} to {var_name}")
print(f"Assigned constant {rval.value} to {var_name}")
elif isinstance(rval.value, str): elif isinstance(rval.value, str):
str_val = rval.value.encode("utf-8") + b"\x00" str_val = rval.value.encode("utf-8") + b"\x00"
str_const = ir.Constant( str_const = ir.Constant(
@ -110,16 +123,15 @@ def handle_assign(
global_str.linkage = "internal" global_str.linkage = "internal"
global_str.global_constant = True global_str.global_constant = True
global_str.initializer = str_const global_str.initializer = str_const
str_ptr = builder.bitcast( str_ptr = builder.bitcast(global_str, ir.PointerType(ir.IntType(8)))
global_str, ir.PointerType(ir.IntType(8))) builder.store(str_ptr, local_sym_tab[var_name].var)
builder.store(str_ptr, local_sym_tab[var_name][0]) logger.info(f"Assigned string constant '{rval.value}' to {var_name}")
print(f"Assigned string constant '{rval.value}' to {var_name}")
else: else:
print("Unsupported constant type") logger.info("Unsupported constant type")
elif isinstance(rval, ast.Call): elif isinstance(rval, ast.Call):
if isinstance(rval.func, ast.Name): if isinstance(rval.func, ast.Name):
call_type = rval.func.id call_type = rval.func.id
print(f"Assignment call type: {call_type}") logger.info(f"Assignment call type: {call_type}")
if ( if (
call_type in num_types call_type in num_types
and len(rval.args) == 1 and len(rval.args) == 1
@ -130,14 +142,13 @@ def handle_assign(
# var = builder.alloca(ir_type, name=var_name) # var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8 # var.align = ir_type.width // 8
builder.store( builder.store(
ir.Constant( ir.Constant(ir_type, rval.args[0].value),
ir_type, rval.args[0].value), local_sym_tab[var_name][0] local_sym_tab[var_name].var,
) )
print( logger.info(
f"Assigned {call_type} constant " f"Assigned {call_type} constant "
f"{rval.args[0].value} to {var_name}" f"{rval.args[0].value} to {var_name}"
) )
# local_sym_tab[var_name] = var
elif HelperHandlerRegistry.has_handler(call_type): elif HelperHandlerRegistry.has_handler(call_type):
# var = builder.alloca(ir.IntType(64), name=var_name) # var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8 # var.align = 8
@ -149,13 +160,11 @@ def handle_assign(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
) )
builder.store(val[0], local_sym_tab[var_name][0]) builder.store(val[0], local_sym_tab[var_name].var)
# local_sym_tab[var_name] = var logger.info(f"Assigned constant {rval.func.id} to {var_name}")
print(f"Assigned constant {rval.func.id} to {var_name}")
elif call_type == "deref" and len(rval.args) == 1: elif call_type == "deref" and len(rval.args) == 1:
print(f"Handling deref assignment {ast.dump(rval)}") logger.info(f"Handling deref assignment {ast.dump(rval)}")
val = eval_expr( val = eval_expr(
func, func,
module, module,
@ -166,29 +175,40 @@ def handle_assign(
structs_sym_tab, structs_sym_tab,
) )
if val is None: if val is None:
print("Failed to evaluate deref argument") logger.info("Failed to evaluate deref argument")
return return
print(f"Dereferenced value: {val}, storing in {var_name}") logger.info(f"Dereferenced value: {val}, storing in {var_name}")
builder.store(val[0], local_sym_tab[var_name][0]) builder.store(val[0], local_sym_tab[var_name].var)
# local_sym_tab[var_name] = var logger.info(f"Dereferenced and assigned to {var_name}")
print(f"Dereferenced and assigned to {var_name}")
elif call_type in structs_sym_tab and len(rval.args) == 0: elif call_type in structs_sym_tab and len(rval.args) == 0:
struct_info = structs_sym_tab[call_type] struct_info = structs_sym_tab[call_type]
ir_type = struct_info.ir_type ir_type = struct_info.ir_type
# var = builder.alloca(ir_type, name=var_name) # var = builder.alloca(ir_type, name=var_name)
# Null init # Null init
builder.store(ir.Constant(ir_type, None), builder.store(ir.Constant(ir_type, None), local_sym_tab[var_name].var)
local_sym_tab[var_name][0]) logger.info(f"Assigned struct {call_type} to {var_name}")
local_var_metadata[var_name] = call_type
print(f"Assigned struct {call_type} to {var_name}")
# local_sym_tab[var_name] = var
else: else:
print(f"Unsupported assignment call type: {call_type}") logger.info(f"Unsupported assignment call type: {call_type}")
elif isinstance(rval.func, ast.Attribute): elif isinstance(rval.func, ast.Attribute):
print(f"Assignment call attribute: {ast.dump(rval.func)}") logger.info(f"Assignment call attribute: {ast.dump(rval.func)}")
if isinstance(rval.func.value, ast.Name): if isinstance(rval.func.value, ast.Name):
# TODO: probably a struct access if rval.func.value.id in map_sym_tab:
print(f"TODO STRUCT ACCESS {ast.dump(rval)}") map_name = rval.func.value.id
method_name = rval.func.attr
if HelperHandlerRegistry.has_handler(method_name):
val = handle_helper_call(
rval,
module,
builder,
func,
local_sym_tab,
map_sym_tab,
structs_sym_tab,
)
builder.store(val[0], local_sym_tab[var_name].var)
else:
# TODO: probably a struct access
logger.info(f"TODO STRUCT ACCESS {ast.dump(rval)}")
elif isinstance(rval.func.value, ast.Call) and isinstance( elif isinstance(rval.func.value, ast.Call) and isinstance(
rval.func.value.func, ast.Name rval.func.value.func, ast.Name
): ):
@ -204,22 +224,18 @@ def handle_assign(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
) )
# var = builder.alloca(ir.IntType(64), name=var_name) # var = builder.alloca(ir.IntType(64), name=var_name)
# var.align = 8 # var.align = 8
builder.store(val[0], local_sym_tab[var_name][0]) builder.store(val[0], local_sym_tab[var_name].var)
# local_sym_tab[var_name] = var
else: else:
print("Unsupported assignment call structure") logger.info("Unsupported assignment call structure")
else: else:
print("Unsupported assignment call function type") logger.info("Unsupported assignment call function type")
elif isinstance(rval, ast.BinOp): elif isinstance(rval, ast.BinOp):
handle_binary_op( handle_binary_op(rval, module, builder, var_name, local_sym_tab)
rval, module, builder, var_name, local_sym_tab, map_sym_tab, func
)
else: else:
print("Unsupported assignment value type") logger.info("Unsupported assignment value type")
def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab): def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
@ -229,11 +245,11 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
elif isinstance(cond.value, int): elif isinstance(cond.value, int):
return ir.Constant(ir.IntType(1), int(bool(cond.value))) return ir.Constant(ir.IntType(1), int(bool(cond.value)))
else: else:
print("Unsupported constant type in condition") logger.info("Unsupported constant type in condition")
return None return None
elif isinstance(cond, ast.Name): elif isinstance(cond, ast.Name):
if cond.id in local_sym_tab: if cond.id in local_sym_tab:
var = local_sym_tab[cond.id][0] var = local_sym_tab[cond.id].var
val = builder.load(var) val = builder.load(var)
if val.type != ir.IntType(1): if val.type != ir.IntType(1):
# Convert nonzero values to true, zero to false # Convert nonzero values to true, zero to false
@ -246,13 +262,12 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
val = builder.icmp_signed("!=", val, zero) val = builder.icmp_signed("!=", val, zero)
return val return val
else: else:
print(f"Undefined variable {cond.id} in condition") logger.info(f"Undefined variable {cond.id} in condition")
return None return None
elif isinstance(cond, ast.Compare): elif isinstance(cond, ast.Compare):
lhs = eval_expr(func, module, builder, cond.left, lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0]
local_sym_tab, map_sym_tab)[0]
if len(cond.ops) != 1 or len(cond.comparators) != 1: if len(cond.ops) != 1 or len(cond.comparators) != 1:
print("Unsupported complex comparison") logger.info("Unsupported complex comparison")
return None return None
rhs = eval_expr( rhs = eval_expr(
func, module, builder, cond.comparators[0], local_sym_tab, map_sym_tab func, module, builder, cond.comparators[0], local_sym_tab, map_sym_tab
@ -267,7 +282,7 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
elif lhs.type.width > rhs.type.width: elif lhs.type.width > rhs.type.width:
rhs = builder.sext(rhs, lhs.type) rhs = builder.sext(rhs, lhs.type)
else: else:
print("Type mismatch in comparison") logger.info("Type mismatch in comparison")
return None return None
if isinstance(op, ast.Eq): if isinstance(op, ast.Eq):
@ -283,10 +298,10 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
elif isinstance(op, ast.GtE): elif isinstance(op, ast.GtE):
return builder.icmp_signed(">=", lhs, rhs) return builder.icmp_signed(">=", lhs, rhs)
else: else:
print("Unsupported comparison operator") logger.info("Unsupported comparison operator")
return None return None
else: else:
print("Unsupported condition expression") logger.info("Unsupported condition expression")
return None return None
@ -294,7 +309,7 @@ def handle_if(
func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None func, module, builder, stmt, map_sym_tab, local_sym_tab, structs_sym_tab=None
): ):
"""Handle if statements in the function body.""" """Handle if statements in the function body."""
print("Handling if statement") logger.info("Handling if statement")
# start = builder.block.parent # start = builder.block.parent
then_block = func.append_basic_block(name="if.then") then_block = func.append_basic_block(name="if.then")
merge_block = func.append_basic_block(name="if.end") merge_block = func.append_basic_block(name="if.end")
@ -303,8 +318,7 @@ def handle_if(
else: else:
else_block = None else_block = None
cond = handle_cond(func, module, builder, stmt.test, cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab)
local_sym_tab, map_sym_tab)
if else_block: if else_block:
builder.cbranch(cond, then_block, else_block) builder.cbranch(cond, then_block, else_block)
else: else:
@ -348,9 +362,8 @@ def process_stmt(
did_return, did_return,
ret_type=ir.IntType(64), ret_type=ir.IntType(64),
): ):
print(f"Processing statement: {ast.dump(stmt)}") logger.info(f"Processing statement: {ast.dump(stmt)}")
if isinstance(stmt, ast.Expr): if isinstance(stmt, ast.Expr):
print(local_var_metadata)
handle_expr( handle_expr(
func, func,
module, module,
@ -359,7 +372,6 @@ def process_stmt(
local_sym_tab, local_sym_tab,
map_sym_tab, map_sym_tab,
structs_sym_tab, structs_sym_tab,
local_var_metadata,
) )
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
handle_assign( handle_assign(
@ -409,6 +421,7 @@ def allocate_mem(
module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab
): ):
for stmt in body: for stmt in body:
has_metadata = False
if isinstance(stmt, ast.If): if isinstance(stmt, ast.If):
if stmt.body: if stmt.body:
local_sym_tab = allocate_mem( local_sym_tab = allocate_mem(
@ -434,11 +447,11 @@ def allocate_mem(
) )
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
if len(stmt.targets) != 1: if len(stmt.targets) != 1:
print("Unsupported multiassignment") logger.info("Unsupported multiassignment")
continue continue
target = stmt.targets[0] target = stmt.targets[0]
if not isinstance(target, ast.Name): if not isinstance(target, ast.Name):
print("Unsupported assignment target") logger.info("Unsupported assignment target")
continue continue
var_name = target.id var_name = target.id
rval = stmt.value rval = stmt.value
@ -449,67 +462,72 @@ def allocate_mem(
ir_type = ctypes_to_ir(call_type) ir_type = ctypes_to_ir(call_type)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print( logger.info(
f"Pre-allocated variable {var_name} of type {call_type}") f"Pre-allocated variable {var_name} of type {call_type}"
)
elif HelperHandlerRegistry.has_handler(call_type): elif HelperHandlerRegistry.has_handler(call_type):
# Assume return type is int64 for now # Assume return type is int64 for now
ir_type = ir.IntType(64) ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} for helper") logger.info(f"Pre-allocated variable {var_name} for helper")
elif call_type == "deref" and len(rval.args) == 1: elif call_type == "deref" and len(rval.args) == 1:
# Assume return type is int64 for now # Assume return type is int64 for now
ir_type = ir.IntType(64) ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} for deref") logger.info(f"Pre-allocated variable {var_name} for deref")
elif call_type in structs_sym_tab: elif call_type in structs_sym_tab:
struct_info = structs_sym_tab[call_type] struct_info = structs_sym_tab[call_type]
ir_type = struct_info.ir_type ir_type = struct_info.ir_type
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
local_var_metadata[var_name] = call_type has_metadata = True
print( logger.info(
f"Pre-allocated variable { f"Pre-allocated variable {var_name} "
var_name} for struct {call_type}" f"for struct {call_type}"
) )
elif isinstance(rval.func, ast.Attribute): elif isinstance(rval.func, ast.Attribute):
ir_type = ir.PointerType(ir.IntType(64)) ir_type = ir.PointerType(ir.IntType(64))
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
# var.align = ir_type.width // 8 # var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} for map") logger.info(f"Pre-allocated variable {var_name} for map")
else: else:
print("Unsupported assignment call function type") logger.info("Unsupported assignment call function type")
continue continue
elif isinstance(rval, ast.Constant): elif isinstance(rval, ast.Constant):
if isinstance(rval.value, bool): if isinstance(rval.value, bool):
ir_type = ir.IntType(1) ir_type = ir.IntType(1)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = 1 var.align = 1
print(f"Pre-allocated variable {var_name} of type c_bool") logger.info(f"Pre-allocated variable {var_name} of type c_bool")
elif isinstance(rval.value, int): elif isinstance(rval.value, int):
# Assume c_int64 for now # Assume c_int64 for now
ir_type = ir.IntType(64) ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} of type c_int64") logger.info(f"Pre-allocated variable {var_name} of type c_int64")
elif isinstance(rval.value, str): elif isinstance(rval.value, str):
ir_type = ir.PointerType(ir.IntType(8)) ir_type = ir.PointerType(ir.IntType(8))
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = 8 var.align = 8
print(f"Pre-allocated variable {var_name} of type string") logger.info(f"Pre-allocated variable {var_name} of type string")
else: else:
print("Unsupported constant type") logger.info("Unsupported constant type")
continue continue
elif isinstance(rval, ast.BinOp): elif isinstance(rval, ast.BinOp):
# Assume c_int64 for now # Assume c_int64 for now
ir_type = ir.IntType(64) ir_type = ir.IntType(64)
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
var.align = ir_type.width // 8 var.align = ir_type.width // 8
print(f"Pre-allocated variable {var_name} of type c_int64") logger.info(f"Pre-allocated variable {var_name} of type c_int64")
else: else:
print("Unsupported assignment value type") logger.info("Unsupported assignment value type")
continue continue
local_sym_tab[var_name] = (var, ir_type)
if has_metadata:
local_sym_tab[var_name] = LocalSymbol(var, ir_type, call_type)
else:
local_sym_tab[var_name] = LocalSymbol(var, ir_type)
return local_sym_tab return local_sym_tab
@ -534,7 +552,7 @@ def process_func_body(
structs_sym_tab, structs_sym_tab,
) )
print(f"Local symbol table: {local_sym_tab.keys()}") logger.info(f"Local symbol table: {local_sym_tab.keys()}")
for stmt in func_node.body: for stmt in func_node.body:
did_return = process_stmt( did_return = process_stmt(
@ -606,7 +624,7 @@ def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab):
if is_global: if is_global:
continue continue
func_type = get_probe_string(func_node) func_type = get_probe_string(func_node)
print(f"Found probe_string of {func_node.name}: {func_type}") logger.info(f"Found probe_string of {func_node.name}: {func_type}")
process_bpf_chunk( process_bpf_chunk(
func_node, func_node,
@ -665,14 +683,13 @@ def infer_return_type(func_node: ast.FunctionDef):
except Exception: except Exception:
return type(e).__name__ return type(e).__name__
for node in ast.walk(func_node): for walked_node in ast.walk(func_node):
if isinstance(node, ast.Return): if isinstance(walked_node, ast.Return):
t = _expr_type(node.value) t = _expr_type(walked_node.value)
if found_type is None: if found_type is None:
found_type = t found_type = t
elif found_type != t: elif found_type != t:
raise ValueError("Conflicting return types:" f"{ raise ValueError(f"Conflicting return types: {found_type} vs {t}")
found_type} vs {t}")
return found_type or "None" return found_type or "None"
@ -709,8 +726,7 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
char = builder.load(src_ptr) char = builder.load(src_ptr)
# Store character in target # Store character in target
dst_ptr = builder.gep( dst_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
target_array_ptr, [ir.Constant(ir.IntType(32), 0), idx])
builder.store(char, dst_ptr) builder.store(char, dst_ptr)
# Increment counter # Increment counter
@ -721,6 +737,5 @@ def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_l
# Ensure null termination # Ensure null termination
last_idx = ir.Constant(ir.IntType(32), array_length - 1) last_idx = ir.Constant(ir.IntType(32), array_length - 1)
null_ptr = builder.gep( null_ptr = builder.gep(target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
target_array_ptr, [ir.Constant(ir.IntType(32), 0), last_idx])
builder.store(ir.Constant(ir.IntType(8), 0), null_ptr) builder.store(ir.Constant(ir.IntType(8), 0), null_ptr)

View File

@ -1,8 +1,121 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
from logging import Logger
import logging
from .type_deducer import ctypes_to_ir
def emit_globals(module: ir.Module, names: list[str]): logger: Logger = logging.getLogger(__name__)
# TODO: this is going to be a huge fuck of a headache in the future.
global_sym_tab = []
def populate_global_symbol_table(tree, module: ir.Module):
for node in tree.body:
if isinstance(node, ast.FunctionDef):
for dec in node.decorator_list:
if (
isinstance(dec, ast.Call)
and isinstance(dec.func, ast.Name)
and dec.func.id == "section"
and len(dec.args) == 1
and isinstance(dec.args[0], ast.Constant)
and isinstance(dec.args[0].value, str)
):
global_sym_tab.append(node)
elif isinstance(dec, ast.Name) and dec.id == "bpfglobal":
global_sym_tab.append(node)
elif isinstance(dec, ast.Name) and dec.id == "map":
global_sym_tab.append(node)
return False
def emit_global(module: ir.Module, node, name):
logger.info(f"global identifier {name} processing")
# deduce LLVM type from the annotated return
if not isinstance(node.returns, ast.Name):
raise ValueError(f"Unsupported return annotation {ast.dump(node.returns)}")
ty = ctypes_to_ir(node.returns.id)
# extract the return expression
# TODO: turn this return extractor into a generic function I can use everywhere.
ret_stmt = node.body[0]
if not isinstance(ret_stmt, ast.Return) or ret_stmt.value is None:
raise ValueError(f"Global '{name}' has no valid return")
init_val = ret_stmt.value
# simple constant like "return 0"
if isinstance(init_val, ast.Constant):
llvm_init = ir.Constant(ty, init_val.value)
# variable reference like "return SOME_CONST"
elif isinstance(init_val, ast.Name):
# need symbol resolution here, stub as 0 for now
raise ValueError(f"Name reference {init_val.id} not yet supported")
# constructor call like "return c_int64(0)" or dataclass(...)
elif isinstance(init_val, ast.Call):
if len(init_val.args) >= 1 and isinstance(init_val.args[0], ast.Constant):
llvm_init = ir.Constant(ty, init_val.args[0].value)
else:
logger.info("Defaulting to zero as no constant argument found")
llvm_init = ir.Constant(ty, 0)
else:
raise ValueError(f"Unsupported return expr {ast.dump(init_val)}")
gvar = ir.GlobalVariable(module, ty, name=name)
gvar.initializer = llvm_init
gvar.align = 8
gvar.linkage = "dso_local"
gvar.global_constant = False
return gvar
def globals_processing(tree, module):
"""Process stuff decorated with @bpf and @bpfglobal except license and return the section name"""
globals_sym_tab = []
for node in tree.body:
# Skip non-assignment and non-function nodes
if not (isinstance(node, ast.FunctionDef)):
continue
# Get the name based on node type
if isinstance(node, ast.FunctionDef):
name = node.name
else:
continue
# Check for duplicate names
if name in globals_sym_tab:
raise SyntaxError(f"ERROR: Global name '{name}' previously defined")
else:
globals_sym_tab.append(name)
if isinstance(node, ast.FunctionDef) and node.name != "LICENSE":
decorators = [
dec.id for dec in node.decorator_list if isinstance(dec, ast.Name)
]
if "bpf" in decorators and "bpfglobal" in decorators:
if (
len(node.body) == 1
and isinstance(node.body[0], ast.Return)
and node.body[0].value is not None
and isinstance(
node.body[0].value, (ast.Constant, ast.Name, ast.Call)
)
):
emit_global(module, node, name)
else:
raise SyntaxError(f"ERROR: Invalid syntax for {name} global")
return None
def emit_llvm_compiler_used(module: ir.Module, names: list[str]):
""" """
Emit the @llvm.compiler.used global given a list of function/global names. Emit the @llvm.compiler.used global given a list of function/global names.
""" """
@ -24,7 +137,7 @@ def emit_globals(module: ir.Module, names: list[str]):
gv.section = "llvm.metadata" gv.section = "llvm.metadata"
def globals_processing(tree, module: ir.Module): def globals_list_creation(tree, module: ir.Module):
collected = ["LICENSE"] collected = ["LICENSE"]
for node in tree.body: for node in tree.body:
@ -40,10 +153,11 @@ def globals_processing(tree, module: ir.Module):
): ):
collected.append(node.name) collected.append(node.name)
elif isinstance(dec, ast.Name) and dec.id == "bpfglobal": # NOTE: all globals other than
collected.append(node.name) # elif isinstance(dec, ast.Name) and dec.id == "bpfglobal":
# collected.append(node.name)
elif isinstance(dec, ast.Name) and dec.id == "map": elif isinstance(dec, ast.Name) and dec.id == "map":
collected.append(node.name) collected.append(node.name)
emit_globals(module, collected) emit_llvm_compiler_used(module, collected)

View File

@ -1,2 +1,13 @@
from .helper_utils import HelperHandlerRegistry from .helper_utils import HelperHandlerRegistry
from .bpf_helper_handler import handle_helper_call from .bpf_helper_handler import handle_helper_call
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
__all__ = [
"HelperHandlerRegistry",
"handle_helper_call",
"ktime",
"pid",
"deref",
"XDP_DROP",
"XDP_PASS",
]

View File

@ -1,10 +1,18 @@
import ast import ast
from llvmlite import ir from llvmlite import ir
from enum import Enum from enum import Enum
from .helper_utils import (HelperHandlerRegistry, from .helper_utils import (
get_or_create_ptr_from_arg, get_flags_val, HelperHandlerRegistry,
handle_fstring_print, simple_string_print, get_or_create_ptr_from_arg,
get_data_ptr_and_size) get_flags_val,
handle_fstring_print,
simple_string_print,
get_data_ptr_and_size,
)
from logging import Logger
import logging
logger: Logger = logging.getLogger(__name__)
class BPFHelperID(Enum): class BPFHelperID(Enum):
@ -18,9 +26,15 @@ class BPFHelperID(Enum):
@HelperHandlerRegistry.register("ktime") @HelperHandlerRegistry.register("ktime")
def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func, def bpf_ktime_get_ns_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
""" """
Emit LLVM IR for bpf_ktime_get_ns helper function call. Emit LLVM IR for bpf_ktime_get_ns helper function call.
""" """
@ -34,27 +48,33 @@ def bpf_ktime_get_ns_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("lookup") @HelperHandlerRegistry.register("lookup")
def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, def bpf_map_lookup_elem_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
""" """
Emit LLVM IR for bpf_map_lookup_elem helper function call. Emit LLVM IR for bpf_map_lookup_elem helper function call.
""" """
if not call.args or len(call.args) != 1: if not call.args or len(call.args) != 1:
raise ValueError("Map lookup expects exactly one argument (key), got " raise ValueError(
f"{len(call.args)}") "Map lookup expects exactly one argument (key), got " f"{len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.PointerType(), # Return type: void* ir.PointerType(), # Return type: void*
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*) [ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
var_arg=False var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType( fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value)
64), BPFHelperID.BPF_MAP_LOOKUP_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False)
@ -63,33 +83,44 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("print") @HelperHandlerRegistry.register("print")
def bpf_printk_emitter(call, map_ptr, module, builder, func, def bpf_printk_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
"""Emit LLVM IR for bpf_printk helper function call.""" """Emit LLVM IR for bpf_printk helper function call."""
if not hasattr(func, "_fmt_counter"): if not hasattr(func, "_fmt_counter"):
func._fmt_counter = 0 func._fmt_counter = 0
if not call.args: if not call.args:
raise ValueError( raise ValueError("bpf_printk expects at least one argument (format string)")
"bpf_printk expects at least one argument (format string)")
args = [] args = []
if isinstance(call.args[0], ast.JoinedStr): if isinstance(call.args[0], ast.JoinedStr):
args = handle_fstring_print(call.args[0], module, builder, func, args = handle_fstring_print(
local_sym_tab, struct_sym_tab, call.args[0],
local_var_metadata) module,
elif (isinstance(call.args[0], ast.Constant) and builder,
isinstance(call.args[0].value, str)): func,
local_sym_tab,
struct_sym_tab,
)
elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str):
# TODO: We are only supporting single arguments for now. # TODO: We are only supporting single arguments for now.
# In case of multiple args, the first one will be taken. # In case of multiple args, the first one will be taken.
args = simple_string_print(call.args[0].value, module, builder, func) args = simple_string_print(call.args[0].value, module, builder, func)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple strings or f-strings are supported in bpf_printk.") "Only simple strings or f-strings are supported in bpf_printk."
)
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True) ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True
)
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value) fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
@ -99,18 +130,24 @@ def bpf_printk_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("update") @HelperHandlerRegistry.register("update")
def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func, def bpf_map_update_elem_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
""" """
Emit LLVM IR for bpf_map_update_elem helper function call. Emit LLVM IR for bpf_map_update_elem helper function call.
Expected call signature: map.update(key, value, flags=0) Expected call signature: map.update(key, value, flags=0)
""" """
if (not call.args or if not call.args or len(call.args) < 2 or len(call.args) > 3:
len(call.args) < 2 or raise ValueError(
len(call.args) > 3): "Map update expects 2 or 3 args (key, value, flags), "
raise ValueError("Map update expects 2 or 3 args (key, value, flags), " f"got {len(call.args)}"
f"got {len(call.args)}") )
key_arg = call.args[0] key_arg = call.args[0]
value_arg = call.args[1] value_arg = call.args[1]
@ -124,12 +161,11 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func,
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
[ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)], [ir.PointerType(), ir.PointerType(), ir.PointerType(), ir.IntType(64)],
var_arg=False var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType( fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value)
64), BPFHelperID.BPF_MAP_UPDATE_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
if isinstance(flags_val, int): if isinstance(flags_val, int):
@ -138,22 +174,30 @@ def bpf_map_update_elem_emitter(call, map_ptr, module, builder, func,
flags_const = flags_val flags_const = flags_val
result = builder.call( result = builder.call(
fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False) fn_ptr, [map_void_ptr, key_ptr, value_ptr, flags_const], tail=False
)
return result, None return result, None
@HelperHandlerRegistry.register("delete") @HelperHandlerRegistry.register("delete")
def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func, def bpf_map_delete_elem_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
""" """
Emit LLVM IR for bpf_map_delete_elem helper function call. Emit LLVM IR for bpf_map_delete_elem helper function call.
Expected call signature: map.delete(key) Expected call signature: map.delete(key)
""" """
if not call.args or len(call.args) != 1: if not call.args or len(call.args) != 1:
raise ValueError("Map delete expects exactly one argument (key), got " raise ValueError(
f"{len(call.args)}") "Map delete expects exactly one argument (key), got " f"{len(call.args)}"
)
key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab) key_ptr = get_or_create_ptr_from_arg(call.args[0], builder, local_sym_tab)
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType()) map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
@ -161,12 +205,11 @@ def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func,
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), # Return type: int64 (status code) ir.IntType(64), # Return type: int64 (status code)
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*) [ir.PointerType(), ir.PointerType()], # Args: (void*, void*)
var_arg=False var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_addr = ir.Constant(ir.IntType( fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_MAP_DELETE_ELEM.value)
64), BPFHelperID.BPF_MAP_DELETE_ELEM.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False) result = builder.call(fn_ptr, [map_void_ptr, key_ptr], tail=False)
@ -175,15 +218,20 @@ def bpf_map_delete_elem_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("pid") @HelperHandlerRegistry.register("pid")
def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func, def bpf_get_current_pid_tgid_emitter(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
""" """
Emit LLVM IR for bpf_get_current_pid_tgid helper function call. Emit LLVM IR for bpf_get_current_pid_tgid helper function call.
""" """
# func is an arg to just have a uniform signature with other emitters # func is an arg to just have a uniform signature with other emitters
helper_id = ir.Constant(ir.IntType( helper_id = ir.Constant(ir.IntType(64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value)
64), BPFHelperID.BPF_GET_CURRENT_PID_TGID.value)
fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False) fn_type = ir.FunctionType(ir.IntType(64), [], var_arg=False)
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
fn_ptr = builder.inttoptr(helper_id, fn_ptr_type) fn_ptr = builder.inttoptr(helper_id, fn_ptr_type)
@ -196,18 +244,23 @@ def bpf_get_current_pid_tgid_emitter(call, map_ptr, module, builder, func,
@HelperHandlerRegistry.register("output") @HelperHandlerRegistry.register("output")
def bpf_perf_event_output_handler(call, map_ptr, module, builder, func, def bpf_perf_event_output_handler(
local_sym_tab=None, struct_sym_tab=None, call,
local_var_metadata=None): map_ptr,
module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
if len(call.args) != 1: if len(call.args) != 1:
raise ValueError("Perf event output expects exactly one argument, " raise ValueError(
f"got {len(call.args)}") "Perf event output expects exactly one argument, " f"got {len(call.args)}"
)
data_arg = call.args[0] data_arg = call.args[0]
ctx_ptr = func.args[0] # First argument to the function is ctx ctx_ptr = func.args[0] # First argument to the function is ctx
data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, data_ptr, size_val = get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab)
struct_sym_tab,
local_var_metadata)
# BPF_F_CURRENT_CPU is -1 in 32 bit # BPF_F_CURRENT_CPU is -1 in 32 bit
flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF) flags_val = ir.Constant(ir.IntType(64), 0xFFFFFFFF)
@ -216,36 +269,54 @@ def bpf_perf_event_output_handler(call, map_ptr, module, builder, func,
data_void_ptr = builder.bitcast(data_ptr, ir.PointerType()) data_void_ptr = builder.bitcast(data_ptr, ir.PointerType())
fn_type = ir.FunctionType( fn_type = ir.FunctionType(
ir.IntType(64), ir.IntType(64),
[ir.PointerType(ir.IntType(8)), ir.PointerType(), ir.IntType(64), [
ir.PointerType(), ir.IntType(64)], ir.PointerType(ir.IntType(8)),
var_arg=False ir.PointerType(),
ir.IntType(64),
ir.PointerType(),
ir.IntType(64),
],
var_arg=False,
) )
fn_ptr_type = ir.PointerType(fn_type) fn_ptr_type = ir.PointerType(fn_type)
# helper id # helper id
fn_addr = ir.Constant(ir.IntType(64), fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PERF_EVENT_OUTPUT.value)
BPFHelperID.BPF_PERF_EVENT_OUTPUT.value)
fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type)
result = builder.call( result = builder.call(
fn_ptr, fn_ptr, [ctx_ptr, map_void_ptr, flags_val, data_void_ptr, size_val], tail=False
[ctx_ptr, map_void_ptr, flags_val, data_void_ptr, size_val], )
tail=False)
return result, None return result, None
def handle_helper_call(call, module, builder, func, def handle_helper_call(
local_sym_tab=None, map_sym_tab=None, call,
struct_sym_tab=None, local_var_metadata=None): module,
builder,
func,
local_sym_tab=None,
map_sym_tab=None,
struct_sym_tab=None,
):
"""Process a BPF helper function call and emit the appropriate LLVM IR.""" """Process a BPF helper function call and emit the appropriate LLVM IR."""
# Helper function to get map pointer and invoke handler # Helper function to get map pointer and invoke handler
def invoke_helper(method_name, map_ptr=None): def invoke_helper(method_name, map_ptr=None):
handler = HelperHandlerRegistry.get_handler(method_name) handler = HelperHandlerRegistry.get_handler(method_name)
if not handler: if not handler:
raise NotImplementedError( raise NotImplementedError(
f"Helper function '{method_name}' is not implemented.") f"Helper function '{method_name}' is not implemented."
return handler(call, map_ptr, module, builder, func, )
local_sym_tab, struct_sym_tab, local_var_metadata) return handler(
call,
map_ptr,
module,
builder,
func,
local_sym_tab,
struct_sym_tab,
)
# Handle direct function calls (e.g., print(), ktime()) # Handle direct function calls (e.g., print(), ktime())
if isinstance(call.func, ast.Name): if isinstance(call.func, ast.Name):
@ -255,14 +326,18 @@ def handle_helper_call(call, module, builder, func,
elif isinstance(call.func, ast.Attribute): elif isinstance(call.func, ast.Attribute):
method_name = call.func.attr method_name = call.func.attr
value = call.func.value value = call.func.value
logger.info(f"Handling method call: {ast.dump(call.func)}")
# Get map pointer from different styles of map access # Get map pointer from different styles of map access
if isinstance(value, ast.Call) and isinstance(value.func, ast.Name): if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
# Variable style: my_map.lookup(key) # Func style: my_map().lookup(key)
map_name = value.func.id map_name = value.func.id
elif isinstance(value, ast.Name):
# Direct style: my_map.lookup(key)
map_name = value.id
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported map access pattern: {ast.dump(value)}") f"Unsupported map access pattern: {ast.dump(value)}"
)
# Verify map exists and get pointer # Verify map exists and get pointer
if not map_sym_tab or map_name not in map_sym_tab: if not map_sym_tab or map_name not in map_sym_tab:

View File

@ -1,5 +1,7 @@
import ast import ast
import logging import logging
from collections.abc import Callable
from llvmlite import ir from llvmlite import ir
from pythonbpf.expr_pass import eval_expr from pythonbpf.expr_pass import eval_expr
@ -8,14 +10,17 @@ logger = logging.getLogger(__name__)
class HelperHandlerRegistry: class HelperHandlerRegistry:
"""Registry for BPF helpers""" """Registry for BPF helpers"""
_handlers = {}
_handlers: dict[str, Callable] = {}
@classmethod @classmethod
def register(cls, helper_name): def register(cls, helper_name):
"""Decorator to register a handler function for a helper""" """Decorator to register a handler function for a helper"""
def decorator(func): def decorator(func):
cls._handlers[helper_name] = func cls._handlers[helper_name] = func
return func return func
return decorator return decorator
@classmethod @classmethod
@ -32,7 +37,7 @@ class HelperHandlerRegistry:
def get_var_ptr_from_name(var_name, local_sym_tab): def get_var_ptr_from_name(var_name, local_sym_tab):
"""Get a pointer to a variable from the symbol table.""" """Get a pointer to a variable from the symbol table."""
if local_sym_tab and var_name in local_sym_tab: if local_sym_tab and var_name in local_sym_tab:
return local_sym_tab[var_name][0] return local_sym_tab[var_name].var
raise ValueError(f"Variable '{var_name}' not found in local symbol table") raise ValueError(f"Variable '{var_name}' not found in local symbol table")
@ -55,7 +60,8 @@ def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
ptr = create_int_constant_ptr(arg.value, builder) ptr = create_int_constant_ptr(arg.value, builder)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple variable names are supported as args in map helpers.") "Only simple variable names are supported as args in map helpers."
)
return ptr return ptr
@ -66,16 +72,16 @@ def get_flags_val(arg, builder, local_sym_tab):
if isinstance(arg, ast.Name): if isinstance(arg, ast.Name):
if local_sym_tab and arg.id in local_sym_tab: if local_sym_tab and arg.id in local_sym_tab:
flags_ptr = local_sym_tab[arg.id][0] flags_ptr = local_sym_tab[arg.id].var
return builder.load(flags_ptr) return builder.load(flags_ptr)
else: else:
raise ValueError( raise ValueError(f"Variable '{arg.id}' not found in local symbol table")
f"Variable '{arg.id}' not found in local symbol table")
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
return arg.value return arg.value
raise NotImplementedError( raise NotImplementedError(
"Only var names or int consts are supported as map helpers flags.") "Only var names or int consts are supported as map helpers flags."
)
def simple_string_print(string_value, module, builder, func): def simple_string_print(string_value, module, builder, func):
@ -87,9 +93,14 @@ def simple_string_print(string_value, module, builder, func):
return args return args
def handle_fstring_print(joined_str, module, builder, func, def handle_fstring_print(
local_sym_tab=None, struct_sym_tab=None, joined_str,
local_var_metadata=None): module,
builder,
func,
local_sym_tab=None,
struct_sym_tab=None,
):
"""Handle f-string formatting for bpf_printk emitter.""" """Handle f-string formatting for bpf_printk emitter."""
fmt_parts = [] fmt_parts = []
exprs = [] exprs = []
@ -100,25 +111,32 @@ def handle_fstring_print(joined_str, module, builder, func,
if isinstance(value, ast.Constant): if isinstance(value, ast.Constant):
_process_constant_in_fstring(value, fmt_parts, exprs) _process_constant_in_fstring(value, fmt_parts, exprs)
elif isinstance(value, ast.FormattedValue): elif isinstance(value, ast.FormattedValue):
_process_fval(value, fmt_parts, exprs, _process_fval(
local_sym_tab, struct_sym_tab, value,
local_var_metadata) fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
)
else: else:
raise NotImplementedError( raise NotImplementedError(f"Unsupported f-string value type: {type(value)}")
f"Unsupported f-string value type: {type(value)}")
fmt_str = "".join(fmt_parts) fmt_str = "".join(fmt_parts)
args = simple_string_print(fmt_str, module, builder, func) args = simple_string_print(fmt_str, module, builder, func)
# NOTE: Process expressions (limited to 3 due to BPF constraints) # NOTE: Process expressions (limited to 3 due to BPF constraints)
if len(exprs) > 3: if len(exprs) > 3:
logger.warning( logger.warning("bpf_printk supports up to 3 args, extra args will be ignored.")
"bpf_printk supports up to 3 args, extra args will be ignored.")
for expr in exprs[:3]: for expr in exprs[:3]:
arg_value = _prepare_expr_args(expr, func, module, builder, arg_value = _prepare_expr_args(
local_sym_tab, struct_sym_tab, expr,
local_var_metadata) func,
module,
builder,
local_sym_tab,
struct_sym_tab,
)
args.append(arg_value) args.append(arg_value)
return args return args
@ -133,61 +151,63 @@ def _process_constant_in_fstring(cst, fmt_parts, exprs):
exprs.append(ir.Constant(ir.IntType(64), cst.value)) exprs.append(ir.Constant(ir.IntType(64), cst.value))
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported constant type in f-string: {type(cst.value)}") f"Unsupported constant type in f-string: {type(cst.value)}"
)
def _process_fval(fval, fmt_parts, exprs, def _process_fval(fval, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
local_sym_tab, struct_sym_tab,
local_var_metadata):
"""Process formatted values in f-string.""" """Process formatted values in f-string."""
logger.debug(f"Processing formatted value: {ast.dump(fval)}") logger.debug(f"Processing formatted value: {ast.dump(fval)}")
if isinstance(fval.value, ast.Name): if isinstance(fval.value, ast.Name):
_process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab) _process_name_in_fval(fval.value, fmt_parts, exprs, local_sym_tab)
elif isinstance(fval.value, ast.Attribute): elif isinstance(fval.value, ast.Attribute):
_process_attr_in_fval(fval.value, fmt_parts, exprs, _process_attr_in_fval(
local_sym_tab, struct_sym_tab, fval.value,
local_var_metadata) fmt_parts,
exprs,
local_sym_tab,
struct_sym_tab,
)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported formatted value in f-string: {type(fval.value)}") f"Unsupported formatted value in f-string: {type(fval.value)}"
)
def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab): def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
"""Process name nodes in formatted values.""" """Process name nodes in formatted values."""
if local_sym_tab and name_node.id in local_sym_tab: if local_sym_tab and name_node.id in local_sym_tab:
_, var_type = local_sym_tab[name_node.id] _, var_type, tmp = local_sym_tab[name_node.id]
_populate_fval(var_type, name_node, fmt_parts, exprs) _populate_fval(var_type, name_node, fmt_parts, exprs)
def _process_attr_in_fval(attr_node, fmt_parts, exprs, def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab):
local_sym_tab, struct_sym_tab,
local_var_metadata):
"""Process attribute nodes in formatted values.""" """Process attribute nodes in formatted values."""
if (isinstance(attr_node.value, ast.Name) and if (
local_sym_tab and attr_node.value.id in local_sym_tab): isinstance(attr_node.value, ast.Name)
and local_sym_tab
and attr_node.value.id in local_sym_tab
):
var_name = attr_node.value.id var_name = attr_node.value.id
field_name = attr_node.attr field_name = attr_node.attr
if not local_var_metadata or var_name not in local_var_metadata: var_type = local_sym_tab[var_name].metadata
raise ValueError(
f"Metadata for '{var_name}' not found in local var metadata")
var_type = local_var_metadata[var_name]
if var_type not in struct_sym_tab: if var_type not in struct_sym_tab:
raise ValueError( raise ValueError(
f"Struct '{var_type}' for '{var_name}' not in symbol table") f"Struct '{var_type}' for '{var_name}' not in symbol table"
)
struct_info = struct_sym_tab[var_type] struct_info = struct_sym_tab[var_type]
if field_name not in struct_info.fields: if field_name not in struct_info.fields:
raise ValueError( raise ValueError(f"Field '{field_name}' not found in struct '{var_type}'")
f"Field '{field_name}' not found in struct '{var_type}'")
field_type = struct_info.field_type(field_name) field_type = struct_info.field_type(field_name)
_populate_fval(field_type, attr_node, fmt_parts, exprs) _populate_fval(field_type, attr_node, fmt_parts, exprs)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple attribute on local vars is supported in f-strings.") "Only simple attribute on local vars is supported in f-strings."
)
def _populate_fval(ftype, node, fmt_parts, exprs): def _populate_fval(ftype, node, fmt_parts, exprs):
@ -202,14 +222,14 @@ def _populate_fval(ftype, node, fmt_parts, exprs):
exprs.append(node) exprs.append(node)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"Unsupported integer width in f-string: {ftype.width}") f"Unsupported integer width in f-string: {ftype.width}"
)
elif ftype == ir.PointerType(ir.IntType(8)): elif ftype == ir.PointerType(ir.IntType(8)):
# NOTE: We assume i8* is a string # NOTE: We assume i8* is a string
fmt_parts.append("%s") fmt_parts.append("%s")
exprs.append(node) exprs.append(node)
else: else:
raise NotImplementedError( raise NotImplementedError(f"Unsupported field type in f-string: {ftype}")
f"Unsupported field type in f-string: {ftype}")
def _create_format_string_global(fmt_str, func, module, builder): def _create_format_string_global(fmt_str, func, module, builder):
@ -218,11 +238,11 @@ def _create_format_string_global(fmt_str, func, module, builder):
func._fmt_counter += 1 func._fmt_counter += 1
fmt_gvar = ir.GlobalVariable( fmt_gvar = ir.GlobalVariable(
module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name) module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name
)
fmt_gvar.global_constant = True fmt_gvar.global_constant = True
fmt_gvar.initializer = ir.Constant( fmt_gvar.initializer = ir.Constant(
ir.ArrayType(ir.IntType(8), len(fmt_str)), ir.ArrayType(ir.IntType(8), len(fmt_str)), bytearray(fmt_str.encode("utf8"))
bytearray(fmt_str.encode("utf8"))
) )
fmt_gvar.linkage = "internal" fmt_gvar.linkage = "internal"
fmt_gvar.align = 1 fmt_gvar.align = 1
@ -230,13 +250,17 @@ def _create_format_string_global(fmt_str, func, module, builder):
return builder.bitcast(fmt_gvar, ir.PointerType()) return builder.bitcast(fmt_gvar, ir.PointerType())
def _prepare_expr_args(expr, func, module, builder, def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_tab):
local_sym_tab, struct_sym_tab,
local_var_metadata):
"""Evaluate and prepare an expression to use as an arg for bpf_printk.""" """Evaluate and prepare an expression to use as an arg for bpf_printk."""
val, _ = eval_expr(func, module, builder, expr, val, _ = eval_expr(
local_sym_tab, None, struct_sym_tab, func,
local_var_metadata) module,
builder,
expr,
local_sym_tab,
None,
struct_sym_tab,
)
if val: if val:
if isinstance(val.type, ir.PointerType): if isinstance(val.type, ir.PointerType):
@ -246,43 +270,38 @@ def _prepare_expr_args(expr, func, module, builder,
val = builder.sext(val, ir.IntType(64)) val = builder.sext(val, ir.IntType(64))
else: else:
logger.warning( logger.warning(
"Only int and ptr supported in bpf_printk args. " "Only int and ptr supported in bpf_printk args. " "Others default to 0."
"Others default to 0.") )
val = ir.Constant(ir.IntType(64), 0) val = ir.Constant(ir.IntType(64), 0)
return val return val
else: else:
logger.warning( logger.warning(
"Failed to evaluate expression for bpf_printk argument. " "Failed to evaluate expression for bpf_printk argument. "
"It will be converted to 0.") "It will be converted to 0."
)
return ir.Constant(ir.IntType(64), 0) return ir.Constant(ir.IntType(64), 0)
def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab, def get_data_ptr_and_size(data_arg, local_sym_tab, struct_sym_tab):
local_var_metadata):
"""Extract data pointer and size information for perf event output.""" """Extract data pointer and size information for perf event output."""
if isinstance(data_arg, ast.Name): if isinstance(data_arg, ast.Name):
data_name = data_arg.id data_name = data_arg.id
if local_sym_tab and data_name in local_sym_tab: if local_sym_tab and data_name in local_sym_tab:
data_ptr = local_sym_tab[data_name][0] data_ptr = local_sym_tab[data_name].var
else: else:
raise ValueError( raise ValueError(
f"Data variable {data_name} not found in local symbol table.") f"Data variable {data_name} not found in local symbol table."
)
# Check if data_name is a struct # Check if data_name is a struct
if local_var_metadata and data_name in local_var_metadata: data_type = local_sym_tab[data_name].metadata
data_type = local_var_metadata[data_name] if data_type in struct_sym_tab:
if data_type in struct_sym_tab: struct_info = struct_sym_tab[data_type]
struct_info = struct_sym_tab[data_type] size_val = ir.Constant(ir.IntType(64), struct_info.size)
size_val = ir.Constant(ir.IntType(64), struct_info.size) return data_ptr, size_val
return data_ptr, size_val
else:
raise ValueError(
f"Struct {data_type} for {data_name} not in symbol table.")
else: else:
raise ValueError( raise ValueError(f"Struct {data_type} for {data_name} not in symbol table.")
f"Metadata for variable {data_name} "
"not found in local variable metadata.")
else: else:
raise NotImplementedError( raise NotImplementedError(
"Only simple object names are supported " "Only simple object names are supported as data in perf event output."
"as data in perf event output.") )

View File

@ -1,5 +1,9 @@
from llvmlite import ir from llvmlite import ir
import ast import ast
from logging import Logger
import logging
logger: Logger = logging.getLogger(__name__)
def emit_license(module: ir.Module, license_str: str): def emit_license(module: ir.Module, license_str: str):
@ -41,9 +45,9 @@ def license_processing(tree, module):
emit_license(module, node.body[0].value.value) emit_license(module, node.body[0].value.value)
return "LICENSE" return "LICENSE"
else: else:
print("ERROR: LICENSE() must return a string literal") logger.info("ERROR: LICENSE() must return a string literal")
return None return None
else: else:
print("ERROR: LICENSE already defined") logger.info("ERROR: LICENSE already defined")
return None return None
return None return None

View File

@ -85,7 +85,7 @@ def create_bpf_map(module, map_name, map_params):
def create_map_debug_info(module, map_global, map_name, map_params): def create_map_debug_info(module, map_global, map_name, map_params):
"""Generate debug information metadata for BPF maps HASH and PERF_EVENT_ARRAY""" """Generate debug info metadata for BPF maps HASH and PERF_EVENT_ARRAY"""
generator = DebugInfoGenerator(module) generator = DebugInfoGenerator(module)
uint_type = generator.get_uint32_type() uint_type = generator.get_uint32_type()

View File

@ -1,7 +1,11 @@
from collections.abc import Callable
from typing import Any
class MapProcessorRegistry: class MapProcessorRegistry:
"""Registry for map processor functions""" """Registry for map processor functions"""
_processors = {} _processors: dict[str, Callable[..., Any]] = {}
@classmethod @classmethod
def register(cls, map_type_name): def register(cls, map_type_name):

View File

@ -19,7 +19,7 @@ def structs_proc(tree, module, chunks):
structs_sym_tab = {} structs_sym_tab = {}
for cls_node in chunks: for cls_node in chunks:
if is_bpf_struct(cls_node): if is_bpf_struct(cls_node):
print(f"Found BPF struct: {cls_node.name}") logger.info(f"Found BPF struct: {cls_node.name}")
struct_info = process_bpf_struct(cls_node, module) struct_info = process_bpf_struct(cls_node, module)
structs_sym_tab[cls_node.name] = struct_info structs_sym_tab[cls_node.name] = struct_info
return structs_sym_tab return structs_sym_tab

View File

@ -0,0 +1,27 @@
// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include <linux/types.h>
struct test_struct {
__u64 a;
__u64 b;
};
struct test_struct w = {};
volatile __u64 prev_time = 0;
SEC("tracepoint/syscalls/sys_enter_execve")
int trace_execve(void *ctx)
{
bpf_printk("previous %ul now %ul", w.b, w.a);
__u64 ts = bpf_ktime_get_ns();
bpf_printk("prev %ul now %ul", prev_time, ts);
w.a = ts;
w.b = prev_time;
prev_time = ts;
return 0;
}
char LICENSE[] SEC("license") = "GPL";

View File

@ -1,5 +1,5 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import XDP_PASS from pythonbpf.helper import XDP_PASS
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64

View File

@ -0,0 +1,101 @@
import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64, c_int32
@bpf
@bpfglobal
def somevalue() -> c_int32:
return c_int32(42)
@bpf
@bpfglobal
def somevalue2() -> c_int64:
return c_int64(69)
@bpf
@bpfglobal
def somevalue1() -> c_int32:
return c_int32(42)
# --- Passing examples ---
# Simple constant return
@bpf
@bpfglobal
def g1() -> c_int64:
return c_int64(42)
# Constructor with one constant argument
@bpf
@bpfglobal
def g2() -> c_int64:
return c_int64(69)
# --- Failing examples ---
# No return annotation
# @bpf
# @bpfglobal
# def g3():
# return 42
# Return annotation is complex
# @bpf
# @bpfglobal
# def g4() -> List[int]:
# return []
# # Return is missing
# @bpf
# @bpfglobal
# def g5() -> c_int64:
# pass
# # Return is a variable reference
# #TODO: maybe fix this sometime later. It defaults to 0
# CONST = 5
# @bpf
# @bpfglobal
# def g6() -> c_int64:
# return c_int64(CONST)
# Constructor with multiple args
#TODO: this is not working. should it work ?
@bpf
@bpfglobal
def g7() -> c_int64:
return c_int64(1)
# Dataclass call
#TODO: fails with dataclass
# @dataclass
# class Point:
# x: c_int64
# y: c_int64
# @bpf
# @bpfglobal
# def g8() -> Point:
# return Point(1, 2)
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def sometag(ctx: c_void_p) -> c_int64:
print("test")
global somevalue
somevalue = 2
print(f"{somevalue}")
return c_int64(1)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("globals.py", "globals.ll", loglevel=logging.INFO)
compile()

View File

@ -0,0 +1,21 @@
import logging
from pythonbpf import compile, bpf, section, bpfglobal, compile_to_ir
from ctypes import c_void_p, c_int64
# This should not pass as somevalue is not declared at all.
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def sometag(ctx: c_void_p) -> c_int64:
print("test")
print(f"{somevalue}") # noqa: F821
return c_int64(1)
@bpf
@bpfglobal
def LICENSE() -> str:
return "GPL"
compile_to_ir("globals.py", "globals.ll", loglevel=logging.INFO)
compile()

View File

@ -3,9 +3,9 @@ from ctypes import c_void_p, c_int64
@bpf @bpf
@section("sometag1") @section("tracepoint/syscalls/sys_enter_sync")
def sometag(ctx: c_void_p) -> c_int64: def sometag(ctx: c_void_p) -> c_int64:
a = 1 + 2 + 1 a = 1 + 2 + 1 + 12 + 13
print(f"{a}") print(f"{a}")
return c_int64(0) return c_int64(0)

View File

@ -3,11 +3,12 @@ from ctypes import c_void_p, c_int64
@bpf @bpf
@section("sometag1") @section("tracepoint/syscalls/sys_enter_sync")
def sometag(ctx: c_void_p) -> c_int64: def sometag(ctx: c_void_p) -> c_int64:
b = 1 + 2 b = 1 + 2
a = 1 + b a = 1 + b
return c_int64(a) print(f"{a}")
return c_int64(0)
@bpf @bpf

View File

@ -1,7 +1,7 @@
from pythonbpf import bpf, map, struct, section, bpfglobal, compile, compile_to_ir, BPF from pythonbpf import bpf, map, struct, section, bpfglobal, compile, compile_to_ir, BPF
from pythonbpf.helpers import ktime, pid from pythonbpf.helper import ktime, pid
from pythonbpf.maps import PerfEventArray from pythonbpf.maps import PerfEventArray
import logging
from ctypes import c_void_p, c_int32, c_uint64 from ctypes import c_void_p, c_int32, c_uint64
@ -42,8 +42,8 @@ def LICENSE() -> str:
return "GPL" return "GPL"
compile()
compile_to_ir("perf_buffer_map.py", "perf_buffer_map.ll") compile_to_ir("perf_buffer_map.py", "perf_buffer_map.ll")
compile(loglevel=logging.INFO)
b = BPF() b = BPF()
b.load_and_attach() b.load_and_attach()