Merge pull request #59 from pythonbpf/vmlinux-handler

LGTM
This commit is contained in:
varunrmallya
2025-10-21 05:01:14 +05:30
committed by GitHub
10 changed files with 231 additions and 13 deletions

View File

@ -5,6 +5,7 @@ from llvmlite import ir
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from pythonbpf.helper import HelperHandlerRegistry from pythonbpf.helper import HelperHandlerRegistry
from .expr import VmlinuxHandlerRegistry
from pythonbpf.type_deducer import ctypes_to_ir from pythonbpf.type_deducer import ctypes_to_ir
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,6 +50,15 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
logger.debug(f"Variable {var_name} already allocated, skipping") logger.debug(f"Variable {var_name} already allocated, skipping")
return return
# When allocating a variable, check if it's a vmlinux struct type
if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
stmt.value.id
):
# Handle vmlinux struct allocation
# This requires more implementation
print(stmt.value)
pass
# Determine type and allocate based on rval # Determine type and allocate based on rval
if isinstance(rval, ast.Call): if isinstance(rval, ast.Call):
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab) _allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)

View File

@ -5,6 +5,8 @@ from .functions import func_proc
from .maps import maps_proc from .maps import maps_proc
from .structs import structs_proc from .structs import structs_proc
from .vmlinux_parser import vmlinux_proc from .vmlinux_parser import vmlinux_proc
from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler
from .expr import VmlinuxHandlerRegistry
from .globals_pass import ( from .globals_pass import (
globals_list_creation, globals_list_creation,
globals_processing, globals_processing,
@ -56,10 +58,13 @@ def processor(source_code, filename, module):
logger.info(f"Found BPF function/struct: {func_node.name}") logger.info(f"Found BPF function/struct: {func_node.name}")
vmlinux_symtab = vmlinux_proc(tree, module) vmlinux_symtab = vmlinux_proc(tree, module)
if vmlinux_symtab:
handler = VmlinuxHandler.initialize(vmlinux_symtab)
VmlinuxHandlerRegistry.set_handler(handler)
populate_global_symbol_table(tree, module) populate_global_symbol_table(tree, module)
license_processing(tree, module) license_processing(tree, module)
globals_processing(tree, module) globals_processing(tree, module)
print("DEBUG:", vmlinux_symtab)
structs_sym_tab = structs_proc(tree, module, bpf_chunks) structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks) map_sym_tab = maps_proc(tree, module, bpf_chunks)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab) func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)

View File

@ -2,6 +2,7 @@ from .expr_pass import eval_expr, handle_expr, get_operand_value
from .type_normalization import convert_to_bool, get_base_type_and_depth from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth from .ir_ops import deref_to_depth
from .call_registry import CallHandlerRegistry from .call_registry import CallHandlerRegistry
from .vmlinux_registry import VmlinuxHandlerRegistry
__all__ = [ __all__ = [
"eval_expr", "eval_expr",
@ -11,4 +12,5 @@ __all__ = [
"deref_to_depth", "deref_to_depth",
"get_operand_value", "get_operand_value",
"CallHandlerRegistry", "CallHandlerRegistry",
"VmlinuxHandlerRegistry",
] ]

View File

@ -12,6 +12,7 @@ from .type_normalization import (
get_base_type_and_depth, get_base_type_and_depth,
deref_to_depth, deref_to_depth,
) )
from .vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
@ -27,8 +28,12 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
val = builder.load(var) val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type return val, local_sym_tab[expr.id].ir_type
else: else:
logger.info(f"Undefined variable {expr.id}") # Check if it's a vmlinux enum/constant
return None vmlinux_result = VmlinuxHandlerRegistry.handle_name(expr.id)
if vmlinux_result is not None:
return vmlinux_result
raise SyntaxError(f"Undefined variable {expr.id}")
def _handle_constant_expr(module, builder, expr: ast.Constant): def _handle_constant_expr(module, builder, expr: ast.Constant):
@ -74,6 +79,13 @@ def _handle_attribute_expr(
val = builder.load(gep) val = builder.load(gep)
field_type = metadata.field_type(attr_name) field_type = metadata.field_type(attr_name)
return val, field_type return val, field_type
# Try vmlinux handler as fallback
vmlinux_result = VmlinuxHandlerRegistry.handle_attribute(
expr, local_sym_tab, None, builder
)
if vmlinux_result is not None:
return vmlinux_result
return None return None
@ -130,7 +142,12 @@ def get_operand_value(
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}") logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth) val = deref_to_depth(func, builder, var, depth)
return val return val
raise ValueError(f"Undefined variable: {operand.id}") else:
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(operand.id)
if vmlinux_result is not None:
val, _ = vmlinux_result
return val
elif isinstance(operand, ast.Constant): elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int): if isinstance(operand.value, int):
cst = ir.Constant(ir.IntType(64), int(operand.value)) cst = ir.Constant(ir.IntType(64), int(operand.value))
@ -332,6 +349,7 @@ def _handle_unary_op(
neg_one = ir.Constant(ir.IntType(64), -1) neg_one = ir.Constant(ir.IntType(64), -1)
result = builder.mul(operand, neg_one) result = builder.mul(operand, neg_one)
return result, ir.IntType(64) return result, ir.IntType(64)
return None
# ============================================================================ # ============================================================================

View File

@ -0,0 +1,45 @@
import ast
class VmlinuxHandlerRegistry:
"""Registry for vmlinux handler operations"""
_handler = None
@classmethod
def set_handler(cls, handler):
"""Set the vmlinux handler"""
cls._handler = handler
@classmethod
def get_handler(cls):
"""Get the vmlinux handler"""
return cls._handler
@classmethod
def handle_name(cls, name):
"""Try to handle a name as vmlinux enum/constant"""
if cls._handler is None:
return None
return cls._handler.handle_vmlinux_enum(name)
@classmethod
def handle_attribute(cls, expr, local_sym_tab, module, builder):
"""Try to handle an attribute access as vmlinux struct field"""
if cls._handler is None:
return None
if isinstance(expr.value, ast.Name):
var_name = expr.value.id
field_name = expr.attr
return cls._handler.handle_vmlinux_struct_field(
var_name, field_name, module, builder, local_sym_tab
)
return None
@classmethod
def is_vmlinux_struct(cls, name):
"""Check if a name refers to a vmlinux struct"""
if cls._handler is None:
return False
return cls._handler.is_vmlinux_struct(name)

View File

@ -311,7 +311,13 @@ def process_stmt(
def process_func_body( def process_func_body(
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
): ):
"""Process the body of a bpf function""" """Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now # TODO: A lot. We just have print -> bpf_trace_printk for now
@ -384,7 +390,13 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
builder = ir.IRBuilder(block) builder = ir.IRBuilder(block)
process_func_body( process_func_body(
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
) )
return func return func

View File

@ -3,6 +3,7 @@ import logging
from llvmlite import ir from llvmlite import ir
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -108,6 +109,16 @@ def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
if local_sym_tab and name_node.id in local_sym_tab: if local_sym_tab and name_node.id in local_sym_tab:
_, var_type, tmp = local_sym_tab[name_node.id] _, var_type, tmp = local_sym_tab[name_node.id]
_populate_fval(var_type, name_node, fmt_parts, exprs) _populate_fval(var_type, name_node, fmt_parts, exprs)
else:
# Try to resolve through vmlinux registry if not in local symbol table
result = VmlinuxHandlerRegistry.handle_name(name_node.id)
if result:
val, var_type = result
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
raise ValueError(
f"Variable '{name_node.id}' not found in symbol table or vmlinux"
)
def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab): def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab):

View File

@ -6,6 +6,8 @@ from llvmlite import ir
from .maps_utils import MapProcessorRegistry from .maps_utils import MapProcessorRegistry
from .map_types import BPFMapType from .map_types import BPFMapType
from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
@ -51,7 +53,7 @@ def _parse_map_params(rval, expected_args=None):
"""Parse map parameters from call arguments and keywords.""" """Parse map parameters from call arguments and keywords."""
params = {} params = {}
handler = VmlinuxHandlerRegistry.get_handler()
# Parse positional arguments # Parse positional arguments
if expected_args: if expected_args:
for i, arg_name in enumerate(expected_args): for i, arg_name in enumerate(expected_args):
@ -65,7 +67,12 @@ def _parse_map_params(rval, expected_args=None):
# Parse keyword arguments (override positional) # Parse keyword arguments (override positional)
for keyword in rval.keywords: for keyword in rval.keywords:
if isinstance(keyword.value, ast.Name): if isinstance(keyword.value, ast.Name):
params[keyword.arg] = keyword.value.id name = keyword.value.id
if handler and handler.is_vmlinux_enum(name):
result = handler.get_vmlinux_enum_value(name)
params[keyword.arg] = result if result is not None else name
else:
params[keyword.arg] = name
elif isinstance(keyword.value, ast.Constant): elif isinstance(keyword.value, ast.Constant):
params[keyword.arg] = keyword.value.value params[keyword.arg] = keyword.value.value

View File

@ -0,0 +1,90 @@
import logging
from llvmlite import ir
from pythonbpf.vmlinux_parser.assignment_info import AssignmentType
logger = logging.getLogger(__name__)
class VmlinuxHandler:
"""Handler for vmlinux-related operations"""
_instance = None
@classmethod
def get_instance(cls):
"""Get the singleton instance"""
if cls._instance is None:
logger.warning("VmlinuxHandler used before initialization")
return None
return cls._instance
@classmethod
def initialize(cls, vmlinux_symtab):
"""Initialize the handler with vmlinux symbol table"""
cls._instance = cls(vmlinux_symtab)
return cls._instance
def __init__(self, vmlinux_symtab):
"""Initialize with vmlinux symbol table"""
self.vmlinux_symtab = vmlinux_symtab
logger.info(
f"VmlinuxHandler initialized with {len(vmlinux_symtab) if vmlinux_symtab else 0} symbols"
)
def is_vmlinux_enum(self, name):
"""Check if name is a vmlinux enum constant"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT
)
def is_vmlinux_struct(self, name):
"""Check if name is a vmlinux struct"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT
)
def handle_vmlinux_enum(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"Resolving vmlinux enum {name} = {value}")
return ir.Constant(ir.IntType(64), value), ir.IntType(64)
return None
def get_vmlinux_enum_value(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"The value of vmlinux enum {name} = {value}")
return value
return None
def handle_vmlinux_struct(self, struct_name, module, builder):
"""Handle vmlinux struct initializations"""
if self.is_vmlinux_struct(struct_name):
# TODO: Implement core-specific struct handling
# This will be more complex and depends on the BTF information
logger.info(f"Handling vmlinux struct {struct_name}")
# Return struct type and allocated pointer
# This is a stub, actual implementation will be more complex
return None
return None
def handle_vmlinux_struct_field(
self, struct_var_name, field_name, module, builder, local_sym_tab
):
"""Handle access to vmlinux struct fields"""
# Check if it's a variable of vmlinux struct type
if struct_var_name in local_sym_tab:
var_info = local_sym_tab[struct_var_name] # noqa: F841
# Need to check if this variable is a vmlinux struct
# This will depend on how you track vmlinux struct types in your symbol table
logger.info(
f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}"
)
# Return pointer to field and field type
return None
return None

View File

@ -1,10 +1,26 @@
from pythonbpf import bpf, section, bpfglobal, compile_to_ir import logging
from pythonbpf import bpf, section, bpfglobal, compile_to_ir, map
from pythonbpf import compile # noqa: F401
from vmlinux import TASK_COMM_LEN # noqa: F401 from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401 from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
from ctypes import c_uint64, c_int32, c_int64
from pythonbpf.maps import HashMap
# from vmlinux import struct_uinput_device # from vmlinux import struct_uinput_device
# from vmlinux import struct_blk_integrity_iter # from vmlinux import struct_blk_integrity_iter
from ctypes import c_int64
@bpf
@map
def mymap() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=TASK_COMM_LEN)
@bpf
@map
def mymap2() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=18)
# Instructions to how to run this program # Instructions to how to run this program
@ -16,8 +32,9 @@ from ctypes import c_int64
@bpf @bpf
@section("tracepoint/syscalls/sys_enter_execve") @section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64: def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
print("Hello, World!") a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
return c_int64(0) print(f"Hello, World{TASK_COMM_LEN} and {a}")
return c_int64(TASK_COMM_LEN + 2)
@bpf @bpf
@ -26,4 +43,5 @@ def LICENSE() -> str:
return "GPL" return "GPL"
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll") compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG)
# compile()