mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Use scratch space to store consts passed to helpers
This commit is contained in:
@ -4,7 +4,11 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call
|
from pythonbpf.helper import (
|
||||||
|
HelperHandlerRegistry,
|
||||||
|
handle_helper_call,
|
||||||
|
reset_scratch_pool,
|
||||||
|
)
|
||||||
from pythonbpf.type_deducer import ctypes_to_ir
|
from pythonbpf.type_deducer import ctypes_to_ir
|
||||||
from pythonbpf.binary_ops import handle_binary_op
|
from pythonbpf.binary_ops import handle_binary_op
|
||||||
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
|
from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool
|
||||||
@ -353,6 +357,7 @@ def process_stmt(
|
|||||||
ret_type=ir.IntType(64),
|
ret_type=ir.IntType(64),
|
||||||
):
|
):
|
||||||
logger.info(f"Processing statement: {ast.dump(stmt)}")
|
logger.info(f"Processing statement: {ast.dump(stmt)}")
|
||||||
|
reset_scratch_pool()
|
||||||
if isinstance(stmt, ast.Expr):
|
if isinstance(stmt, ast.Expr):
|
||||||
handle_expr(
|
handle_expr(
|
||||||
func,
|
func,
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from .helper_utils import HelperHandlerRegistry
|
from .helper_utils import HelperHandlerRegistry, reset_scratch_pool
|
||||||
from .bpf_helper_handler import handle_helper_call
|
from .bpf_helper_handler import handle_helper_call
|
||||||
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
|
from .helpers import ktime, pid, deref, XDP_DROP, XDP_PASS
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"HelperHandlerRegistry",
|
"HelperHandlerRegistry",
|
||||||
|
"reset_scratch_pool",
|
||||||
"handle_helper_call",
|
"handle_helper_call",
|
||||||
"ktime",
|
"ktime",
|
||||||
"pid",
|
"pid",
|
||||||
|
|||||||
@ -58,6 +58,8 @@ class ScratchPoolManager:
|
|||||||
f"Current counter: {self._counter}"
|
f"Current counter: {self._counter}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return local_sym_tab[temp_name].var, temp_name
|
||||||
|
|
||||||
|
|
||||||
_temp_pool_manager = ScratchPoolManager() # Singleton instance
|
_temp_pool_manager = ScratchPoolManager() # Singleton instance
|
||||||
|
|
||||||
@ -67,11 +69,6 @@ def reset_scratch_pool():
|
|||||||
_temp_pool_manager.reset()
|
_temp_pool_manager.reset()
|
||||||
|
|
||||||
|
|
||||||
def get_next_scratch_temp(local_sym_tab):
|
|
||||||
"""Get the next temporary variable name from the scratch pool"""
|
|
||||||
return _temp_pool_manager.get_next_temp(local_sym_tab)
|
|
||||||
|
|
||||||
|
|
||||||
def get_var_ptr_from_name(var_name, local_sym_tab):
|
def get_var_ptr_from_name(var_name, local_sym_tab):
|
||||||
"""Get a pointer to a variable from the symbol table."""
|
"""Get a pointer to a variable from the symbol table."""
|
||||||
if local_sym_tab and var_name in local_sym_tab:
|
if local_sym_tab and var_name in local_sym_tab:
|
||||||
@ -79,13 +76,14 @@ def get_var_ptr_from_name(var_name, local_sym_tab):
|
|||||||
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
|
raise ValueError(f"Variable '{var_name}' not found in local symbol table")
|
||||||
|
|
||||||
|
|
||||||
def create_int_constant_ptr(value, builder, int_width=64):
|
def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64):
|
||||||
"""Create a pointer to an integer constant."""
|
"""Create a pointer to an integer constant."""
|
||||||
|
|
||||||
# Default to 64-bit integer
|
# Default to 64-bit integer
|
||||||
int_type = ir.IntType(int_width)
|
ptr, temp_name = _temp_pool_manager.get_next_temp(local_sym_tab)
|
||||||
ptr = builder.alloca(int_type)
|
logger.debug(f"Using temp variable '{temp_name}' for int constant {value}")
|
||||||
ptr.align = int_type.width // 8
|
const_val = ir.Constant(ir.IntType(int_width), value)
|
||||||
builder.store(ir.Constant(int_type, value), ptr)
|
builder.store(const_val, ptr)
|
||||||
return ptr
|
return ptr
|
||||||
|
|
||||||
|
|
||||||
@ -95,7 +93,7 @@ def get_or_create_ptr_from_arg(arg, builder, local_sym_tab):
|
|||||||
if isinstance(arg, ast.Name):
|
if isinstance(arg, ast.Name):
|
||||||
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
ptr = get_var_ptr_from_name(arg.id, local_sym_tab)
|
||||||
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
elif isinstance(arg, ast.Constant) and isinstance(arg.value, int):
|
||||||
ptr = create_int_constant_ptr(arg.value, builder)
|
ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Only simple variable names are supported as args in map helpers."
|
"Only simple variable names are supported as args in map helpers."
|
||||||
|
|||||||
Reference in New Issue
Block a user