Merge pull request #59 from pythonbpf/vmlinux-handler

LGTM
This commit is contained in:
varunrmallya
2025-10-21 05:01:14 +05:30
committed by GitHub
10 changed files with 231 additions and 13 deletions

View File

@ -5,6 +5,7 @@ from llvmlite import ir
from dataclasses import dataclass
from typing import Any
from pythonbpf.helper import HelperHandlerRegistry
from .expr import VmlinuxHandlerRegistry
from pythonbpf.type_deducer import ctypes_to_ir
logger = logging.getLogger(__name__)
@ -49,6 +50,15 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
logger.debug(f"Variable {var_name} already allocated, skipping")
return
# When allocating a variable, check if it's a vmlinux struct type
if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
stmt.value.id
):
# Handle vmlinux struct allocation
# This requires more implementation
print(stmt.value)
pass
# Determine type and allocate based on rval
if isinstance(rval, ast.Call):
_allocate_for_call(builder, var_name, rval, local_sym_tab, structs_sym_tab)

View File

@ -5,6 +5,8 @@ from .functions import func_proc
from .maps import maps_proc
from .structs import structs_proc
from .vmlinux_parser import vmlinux_proc
from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler
from .expr import VmlinuxHandlerRegistry
from .globals_pass import (
globals_list_creation,
globals_processing,
@ -56,10 +58,13 @@ def processor(source_code, filename, module):
logger.info(f"Found BPF function/struct: {func_node.name}")
vmlinux_symtab = vmlinux_proc(tree, module)
if vmlinux_symtab:
handler = VmlinuxHandler.initialize(vmlinux_symtab)
VmlinuxHandlerRegistry.set_handler(handler)
populate_global_symbol_table(tree, module)
license_processing(tree, module)
globals_processing(tree, module)
print("DEBUG:", vmlinux_symtab)
structs_sym_tab = structs_proc(tree, module, bpf_chunks)
map_sym_tab = maps_proc(tree, module, bpf_chunks)
func_proc(tree, module, bpf_chunks, map_sym_tab, structs_sym_tab)

View File

@ -2,6 +2,7 @@ from .expr_pass import eval_expr, handle_expr, get_operand_value
from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth
from .call_registry import CallHandlerRegistry
from .vmlinux_registry import VmlinuxHandlerRegistry
__all__ = [
"eval_expr",
@ -11,4 +12,5 @@ __all__ = [
"deref_to_depth",
"get_operand_value",
"CallHandlerRegistry",
"VmlinuxHandlerRegistry",
]

View File

@ -12,6 +12,7 @@ from .type_normalization import (
get_base_type_and_depth,
deref_to_depth,
)
from .vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__)
@ -27,8 +28,12 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder
val = builder.load(var)
return val, local_sym_tab[expr.id].ir_type
else:
logger.info(f"Undefined variable {expr.id}")
return None
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(expr.id)
if vmlinux_result is not None:
return vmlinux_result
raise SyntaxError(f"Undefined variable {expr.id}")
def _handle_constant_expr(module, builder, expr: ast.Constant):
@ -74,6 +79,13 @@ def _handle_attribute_expr(
val = builder.load(gep)
field_type = metadata.field_type(attr_name)
return val, field_type
# Try vmlinux handler as fallback
vmlinux_result = VmlinuxHandlerRegistry.handle_attribute(
expr, local_sym_tab, None, builder
)
if vmlinux_result is not None:
return vmlinux_result
return None
@ -130,7 +142,12 @@ def get_operand_value(
logger.info(f"var is {var}, base_type is {base_type}, depth is {depth}")
val = deref_to_depth(func, builder, var, depth)
return val
raise ValueError(f"Undefined variable: {operand.id}")
else:
# Check if it's a vmlinux enum/constant
vmlinux_result = VmlinuxHandlerRegistry.handle_name(operand.id)
if vmlinux_result is not None:
val, _ = vmlinux_result
return val
elif isinstance(operand, ast.Constant):
if isinstance(operand.value, int):
cst = ir.Constant(ir.IntType(64), int(operand.value))
@ -332,6 +349,7 @@ def _handle_unary_op(
neg_one = ir.Constant(ir.IntType(64), -1)
result = builder.mul(operand, neg_one)
return result, ir.IntType(64)
return None
# ============================================================================

View File

@ -0,0 +1,45 @@
import ast
class VmlinuxHandlerRegistry:
"""Registry for vmlinux handler operations"""
_handler = None
@classmethod
def set_handler(cls, handler):
"""Set the vmlinux handler"""
cls._handler = handler
@classmethod
def get_handler(cls):
"""Get the vmlinux handler"""
return cls._handler
@classmethod
def handle_name(cls, name):
"""Try to handle a name as vmlinux enum/constant"""
if cls._handler is None:
return None
return cls._handler.handle_vmlinux_enum(name)
@classmethod
def handle_attribute(cls, expr, local_sym_tab, module, builder):
"""Try to handle an attribute access as vmlinux struct field"""
if cls._handler is None:
return None
if isinstance(expr.value, ast.Name):
var_name = expr.value.id
field_name = expr.attr
return cls._handler.handle_vmlinux_struct_field(
var_name, field_name, module, builder, local_sym_tab
)
return None
@classmethod
def is_vmlinux_struct(cls, name):
"""Check if a name refers to a vmlinux struct"""
if cls._handler is None:
return False
return cls._handler.is_vmlinux_struct(name)

View File

@ -311,7 +311,13 @@ def process_stmt(
def process_func_body(
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
):
"""Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now
@ -384,7 +390,13 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t
builder = ir.IRBuilder(block)
process_func_body(
module, builder, func_node, func, ret_type, map_sym_tab, structs_sym_tab
module,
builder,
func_node,
func,
ret_type,
map_sym_tab,
structs_sym_tab,
)
return func

View File

@ -3,6 +3,7 @@ import logging
from llvmlite import ir
from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger = logging.getLogger(__name__)
@ -108,6 +109,16 @@ def _process_name_in_fval(name_node, fmt_parts, exprs, local_sym_tab):
if local_sym_tab and name_node.id in local_sym_tab:
_, var_type, tmp = local_sym_tab[name_node.id]
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
# Try to resolve through vmlinux registry if not in local symbol table
result = VmlinuxHandlerRegistry.handle_name(name_node.id)
if result:
val, var_type = result
_populate_fval(var_type, name_node, fmt_parts, exprs)
else:
raise ValueError(
f"Variable '{name_node.id}' not found in symbol table or vmlinux"
)
def _process_attr_in_fval(attr_node, fmt_parts, exprs, local_sym_tab, struct_sym_tab):

View File

@ -6,6 +6,8 @@ from llvmlite import ir
from .maps_utils import MapProcessorRegistry
from .map_types import BPFMapType
from .map_debug_info import create_map_debug_info, create_ringbuf_debug_info
from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry
logger: Logger = logging.getLogger(__name__)
@ -51,7 +53,7 @@ def _parse_map_params(rval, expected_args=None):
"""Parse map parameters from call arguments and keywords."""
params = {}
handler = VmlinuxHandlerRegistry.get_handler()
# Parse positional arguments
if expected_args:
for i, arg_name in enumerate(expected_args):
@ -65,7 +67,12 @@ def _parse_map_params(rval, expected_args=None):
# Parse keyword arguments (override positional)
for keyword in rval.keywords:
if isinstance(keyword.value, ast.Name):
params[keyword.arg] = keyword.value.id
name = keyword.value.id
if handler and handler.is_vmlinux_enum(name):
result = handler.get_vmlinux_enum_value(name)
params[keyword.arg] = result if result is not None else name
else:
params[keyword.arg] = name
elif isinstance(keyword.value, ast.Constant):
params[keyword.arg] = keyword.value.value

View File

@ -0,0 +1,90 @@
import logging
from llvmlite import ir
from pythonbpf.vmlinux_parser.assignment_info import AssignmentType
logger = logging.getLogger(__name__)
class VmlinuxHandler:
"""Handler for vmlinux-related operations"""
_instance = None
@classmethod
def get_instance(cls):
"""Get the singleton instance"""
if cls._instance is None:
logger.warning("VmlinuxHandler used before initialization")
return None
return cls._instance
@classmethod
def initialize(cls, vmlinux_symtab):
"""Initialize the handler with vmlinux symbol table"""
cls._instance = cls(vmlinux_symtab)
return cls._instance
def __init__(self, vmlinux_symtab):
"""Initialize with vmlinux symbol table"""
self.vmlinux_symtab = vmlinux_symtab
logger.info(
f"VmlinuxHandler initialized with {len(vmlinux_symtab) if vmlinux_symtab else 0} symbols"
)
def is_vmlinux_enum(self, name):
"""Check if name is a vmlinux enum constant"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT
)
def is_vmlinux_struct(self, name):
"""Check if name is a vmlinux struct"""
return (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT
)
def handle_vmlinux_enum(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"Resolving vmlinux enum {name} = {value}")
return ir.Constant(ir.IntType(64), value), ir.IntType(64)
return None
def get_vmlinux_enum_value(self, name):
"""Handle vmlinux enum constants by returning LLVM IR constants"""
if self.is_vmlinux_enum(name):
value = self.vmlinux_symtab[name]["value"]
logger.info(f"The value of vmlinux enum {name} = {value}")
return value
return None
def handle_vmlinux_struct(self, struct_name, module, builder):
"""Handle vmlinux struct initializations"""
if self.is_vmlinux_struct(struct_name):
# TODO: Implement core-specific struct handling
# This will be more complex and depends on the BTF information
logger.info(f"Handling vmlinux struct {struct_name}")
# Return struct type and allocated pointer
# This is a stub, actual implementation will be more complex
return None
return None
def handle_vmlinux_struct_field(
self, struct_var_name, field_name, module, builder, local_sym_tab
):
"""Handle access to vmlinux struct fields"""
# Check if it's a variable of vmlinux struct type
if struct_var_name in local_sym_tab:
var_info = local_sym_tab[struct_var_name] # noqa: F841
# Need to check if this variable is a vmlinux struct
# This will depend on how you track vmlinux struct types in your symbol table
logger.info(
f"Attempting to access field {field_name} of possible vmlinux struct {struct_var_name}"
)
# Return pointer to field and field type
return None
return None

View File

@ -1,10 +1,26 @@
from pythonbpf import bpf, section, bpfglobal, compile_to_ir
import logging
from pythonbpf import bpf, section, bpfglobal, compile_to_ir, map
from pythonbpf import compile # noqa: F401
from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
from ctypes import c_uint64, c_int32, c_int64
from pythonbpf.maps import HashMap
# from vmlinux import struct_uinput_device
# from vmlinux import struct_blk_integrity_iter
from ctypes import c_int64
@bpf
@map
def mymap() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=TASK_COMM_LEN)
@bpf
@map
def mymap2() -> HashMap:
return HashMap(key=c_int32, value=c_uint64, max_entries=18)
# Instructions to how to run this program
@ -16,8 +32,9 @@ from ctypes import c_int64
@bpf
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
print("Hello, World!")
return c_int64(0)
a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
print(f"Hello, World{TASK_COMM_LEN} and {a}")
return c_int64(TASK_COMM_LEN + 2)
@bpf
@ -26,4 +43,5 @@ def LICENSE() -> str:
return "GPL"
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll")
compile_to_ir("simple_struct_test.py", "simple_struct_test.ll", loglevel=logging.DEBUG)
# compile()