parse context from first function argument to local symbol table

This commit is contained in:
2025-10-22 11:38:52 +05:30
parent adf32560a0
commit b3921c424d
8 changed files with 64 additions and 8 deletions

View File

@ -51,7 +51,7 @@ def handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab):
return
# When allocating a variable, check if it's a vmlinux struct type
if isinstance(stmt.value, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
if isinstance(rval, ast.Name) and VmlinuxHandlerRegistry.is_vmlinux_struct(
stmt.value.id
):
# Handle vmlinux struct allocation

View File

@ -36,6 +36,7 @@ def finalize_module(original_str):
replacement = r'\1 "btf_ama"'
return re.sub(pattern, replacement, original_str)
def bpf_passthrough_gen(module):
i32_ty = ir.IntType(32)
ptr_ty = ir.PointerType(ir.IntType(8))

View File

@ -1,5 +1,7 @@
import ast
from pythonbpf.vmlinux_parser.vmlinux_exports_handler import VmlinuxHandler
class VmlinuxHandlerRegistry:
"""Registry for vmlinux handler operations"""
@ -7,7 +9,7 @@ class VmlinuxHandlerRegistry:
_handler = None
@classmethod
def set_handler(cls, handler):
def set_handler(cls, handler: VmlinuxHandler):
"""Set the vmlinux handler"""
cls._handler = handler
@ -43,3 +45,10 @@ class VmlinuxHandlerRegistry:
if cls._handler is None:
return False
return cls._handler.is_vmlinux_struct(name)
@classmethod
def get_struct_type(cls, name):
"""Try to handle a struct name as vmlinux struct"""
if cls._handler is None:
return None
return cls._handler.get_vmlinux_struct_type(name)

View File

@ -7,7 +7,12 @@ from pythonbpf.helper import (
reset_scratch_pool,
)
from pythonbpf.type_deducer import ctypes_to_ir
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
from pythonbpf.expr import (
eval_expr,
handle_expr,
convert_to_bool,
VmlinuxHandlerRegistry,
)
from pythonbpf.assign_pass import (
handle_variable_assignment,
handle_struct_field_assignment,
@ -337,6 +342,35 @@ def process_func_body(
structs_sym_tab,
)
# Add the context parameter (first function argument) to the local symbol table
if func_node.args.args and len(func_node.args.args) > 0:
context_arg = func_node.args.args[0]
context_name = context_arg.arg
if hasattr(context_arg, "annotation") and context_arg.annotation:
if isinstance(context_arg.annotation, ast.Name):
context_type_name = context_arg.annotation.id
elif isinstance(context_arg.annotation, ast.Attribute):
context_type_name = context_arg.annotation.attr
else:
raise TypeError(
f"Unsupported annotation type: {ast.dump(context_arg.annotation)}"
)
if VmlinuxHandlerRegistry.is_vmlinux_struct(context_type_name):
resolved_type = VmlinuxHandlerRegistry.get_struct_type(
context_type_name
)
context_type = {"type": ir.PointerType(resolved_type), "ptr": True}
else:
try:
resolved_type = ctypes_to_ir(context_type_name)
context_type = {"type": ir.PointerType(resolved_type), "ptr": True}
except Exception:
raise TypeError(f"Type '{context_type_name}' not declared")
local_sym_tab[context_name] = context_type
logger.info(f"Added argument '{context_name}' to local symbol table")
logger.info(f"Local symbol table: {local_sym_tab.keys()}")
for stmt in func_node.body:

View File

@ -73,9 +73,9 @@ def bpf_map_lookup_elem_emitter(
map_void_ptr = builder.bitcast(map_ptr, ir.PointerType())
# TODO: I have changed the return type to i64*, as we are
# allocating space for that type in allocate_mem. This is
# temporary, and we will honour other widths later. But this
# allows us to have cool binary ops on the returned value.
# allocating space for that type in allocate_mem. This is
# temporary, and we will honour other widths later. But this
# allows us to have cool binary ops on the returned value.
fn_type = ir.FunctionType(
ir.PointerType(ir.IntType(64)), # Return type: void*
[ir.PointerType(), ir.PointerType()], # Args: (void*, void*)

View File

@ -2,7 +2,6 @@ import ast
import logging
import importlib
import inspect
import llvmlite.ir as ir
from .assignment_info import AssignmentInfo, AssignmentType
from .dependency_handler import DependencyHandler

View File

@ -39,6 +39,16 @@ class VmlinuxHandler:
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.CONSTANT
)
def get_vmlinux_struct_type(self, name):
"""Check if name is a vmlinux struct type"""
if (
name in self.vmlinux_symtab
and self.vmlinux_symtab[name]["value_type"] == AssignmentType.STRUCT
):
return self.vmlinux_symtab[name]["python_type"]
else:
raise ValueError(f"{name} is not a vmlinux struct type")
def is_vmlinux_struct(self, name):
"""Check if name is a vmlinux struct"""
return (

View File

@ -4,7 +4,8 @@ from pythonbpf import bpf, section, bpfglobal, compile_to_ir
from pythonbpf import compile # noqa: F401
from vmlinux import TASK_COMM_LEN # noqa: F401
from vmlinux import struct_trace_event_raw_sys_enter # noqa: F401
from ctypes import c_int64
from ctypes import c_int64, c_void_p # noqa: F401
# from vmlinux import struct_uinput_device
# from vmlinux import struct_blk_integrity_iter
@ -14,7 +15,9 @@ from ctypes import c_int64
@section("tracepoint/syscalls/sys_enter_execve")
def hello_world(ctx: struct_trace_event_raw_sys_enter) -> c_int64:
a = 2 + TASK_COMM_LEN + TASK_COMM_LEN
# b = ctx
print(f"Hello, World{TASK_COMM_LEN} and {a}")
# print(f"This is context field {b}")
return c_int64(TASK_COMM_LEN + 2)