diff --git a/pythonbpf/functions/func_registry_handlers.py b/pythonbpf/functions/func_registry_handlers.py deleted file mode 100644 index afe54f6..0000000 --- a/pythonbpf/functions/func_registry_handlers.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Dict - - -class StatementHandlerRegistry: - """Registry for statement handlers.""" - - _handlers: Dict = {} - - @classmethod - def register(cls, stmt_type): - """Register a handler for a specific statement type.""" - - def decorator(handler): - cls._handlers[stmt_type] = handler - return handler - - return decorator - - @classmethod - def __getitem__(cls, stmt_type): - """Get the handler for a specific statement type.""" - return cls._handlers.get(stmt_type, None) diff --git a/pythonbpf/functions/function_metadata.py b/pythonbpf/functions/function_metadata.py new file mode 100644 index 0000000..4597581 --- /dev/null +++ b/pythonbpf/functions/function_metadata.py @@ -0,0 +1,88 @@ +import ast + + +def get_probe_string(func_node): + """Extract the probe string from the decorator of the function node""" + # TODO: right now we have the whole string in the section decorator + # But later we can implement typed tuples for tracepoints and kprobes + # For helper functions, we return "helper" + + for decorator in func_node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal": + return None + if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name): + if decorator.func.id == "section" and len(decorator.args) == 1: + arg = decorator.args[0] + if isinstance(arg, ast.Constant) and isinstance(arg.value, str): + return arg.value + return "helper" + + +def is_global_function(func_node): + """Check if the function is a global""" + for decorator in func_node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id in ( + "map", + "bpfglobal", + "struct", + ): + return True + return False + + +def infer_return_type(func_node: ast.FunctionDef): + if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)): + raise TypeError("Expected ast.FunctionDef") + if func_node.returns is not None: + try: + return ast.unparse(func_node.returns) + except Exception: + node = func_node.returns + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return getattr(node, "attr", type(node).__name__) + try: + return str(node) + except Exception: + return type(node).__name__ + found_type = None + + def _expr_type(e): + if e is None: + return "None" + if isinstance(e, ast.Constant): + return type(e.value).__name__ + if isinstance(e, ast.Name): + return e.id + if isinstance(e, ast.Call): + f = e.func + if isinstance(f, ast.Name): + return f.id + if isinstance(f, ast.Attribute): + try: + return ast.unparse(f) + except Exception: + return getattr(f, "attr", type(f).__name__) + try: + return ast.unparse(f) + except Exception: + return type(f).__name__ + if isinstance(e, ast.Attribute): + try: + return ast.unparse(e) + except Exception: + return getattr(e, "attr", type(e).__name__) + try: + return ast.unparse(e) + except Exception: + return type(e).__name__ + + for walked_node in ast.walk(func_node): + if isinstance(walked_node, ast.Return): + t = _expr_type(walked_node.value) + if found_type is None: + found_type = t + elif found_type != t: + raise ValueError(f"Conflicting return types: {found_type} vs {t}") + return found_type or "None" diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index f706a84..8d0bce1 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -15,26 +15,124 @@ from pythonbpf.assign_pass import ( from pythonbpf.allocation_pass import handle_assign_allocation, allocate_temp_pool from .return_utils import handle_none_return, handle_xdp_return, is_xdp_name +from .function_metadata import get_probe_string, is_global_function, infer_return_type logger = logging.getLogger(__name__) -def get_probe_string(func_node): - """Extract the probe string from the decorator of the function node.""" - # TODO: right now we have the whole string in the section decorator - # But later we can implement typed tuples for tracepoints and kprobes - # For helper functions, we return "helper" +# ============================================================================ +# SECTION 1: Memory Allocation +# ============================================================================ - for decorator in func_node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id == "bpfglobal": - return None - if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name): - if decorator.func.id == "section" and len(decorator.args) == 1: - arg = decorator.args[0] - if isinstance(arg, ast.Constant) and isinstance(arg.value, str): - return arg.value - return "helper" + +def count_temps_in_call(call_node, local_sym_tab): + """Count the number of temporary variables needed for a function call.""" + + count = 0 + is_helper = False + + # NOTE: We exclude print calls for now + if isinstance(call_node.func, ast.Name): + if ( + HelperHandlerRegistry.has_handler(call_node.func.id) + and call_node.func.id != "print" + ): + is_helper = True + elif isinstance(call_node.func, ast.Attribute): + if HelperHandlerRegistry.has_handler(call_node.func.attr): + is_helper = True + + if not is_helper: + return 0 + + for arg in call_node.args: + # NOTE: Count all non-name arguments + # For struct fields, if it is being passed as an argument, + # The struct object should already exist in the local_sym_tab + if not isinstance(arg, ast.Name) and not ( + isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab + ): + count += 1 + + return count + + +def handle_if_allocation( + module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab +): + """Recursively handle allocations in if/else branches.""" + if stmt.body: + allocate_mem( + module, + builder, + stmt.body, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + if stmt.orelse: + allocate_mem( + module, + builder, + stmt.orelse, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + + +def allocate_mem( + module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab +): + max_temps_needed = 0 + + def update_max_temps_for_stmt(stmt): + nonlocal max_temps_needed + temps_needed = 0 + + if isinstance(stmt, ast.If): + for s in stmt.body: + update_max_temps_for_stmt(s) + for s in stmt.orelse: + update_max_temps_for_stmt(s) + return + + for node in ast.walk(stmt): + if isinstance(node, ast.Call): + temps_needed += count_temps_in_call(node, local_sym_tab) + max_temps_needed = max(max_temps_needed, temps_needed) + + for stmt in body: + update_max_temps_for_stmt(stmt) + + # Handle allocations + if isinstance(stmt, ast.If): + handle_if_allocation( + module, + builder, + stmt, + func, + ret_type, + map_sym_tab, + local_sym_tab, + structs_sym_tab, + ) + elif isinstance(stmt, ast.Assign): + handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) + + allocate_temp_pool(builder, max_temps_needed, local_sym_tab) + + return local_sym_tab + + +# ============================================================================ +# SECTION 2: Statement Handlers +# ============================================================================ def handle_assign( @@ -207,108 +305,9 @@ def process_stmt( return did_return -def handle_if_allocation( - module, builder, stmt, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab -): - """Recursively handle allocations in if/else branches.""" - if stmt.body: - allocate_mem( - module, - builder, - stmt.body, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) - if stmt.orelse: - allocate_mem( - module, - builder, - stmt.orelse, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) - - -def count_temps_in_call(call_node, local_sym_tab): - """Count the number of temporary variables needed for a function call.""" - - count = 0 - is_helper = False - - # NOTE: We exclude print calls for now - if isinstance(call_node.func, ast.Name): - if ( - HelperHandlerRegistry.has_handler(call_node.func.id) - and call_node.func.id != "print" - ): - is_helper = True - elif isinstance(call_node.func, ast.Attribute): - if HelperHandlerRegistry.has_handler(call_node.func.attr): - is_helper = True - - if not is_helper: - return 0 - - for arg in call_node.args: - # NOTE: Count all non-name arguments - # For struct fields, if it is being passed as an argument, - # The struct object should already exist in the local_sym_tab - if not isinstance(arg, ast.Name) and not ( - isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab - ): - count += 1 - - return count - - -def allocate_mem( - module, builder, body, func, ret_type, map_sym_tab, local_sym_tab, structs_sym_tab -): - max_temps_needed = 0 - - def update_max_temps_for_stmt(stmt): - nonlocal max_temps_needed - temps_needed = 0 - - if isinstance(stmt, ast.If): - for s in stmt.body: - update_max_temps_for_stmt(s) - for s in stmt.orelse: - update_max_temps_for_stmt(s) - return - - for node in ast.walk(stmt): - if isinstance(node, ast.Call): - temps_needed += count_temps_in_call(node, local_sym_tab) - max_temps_needed = max(max_temps_needed, temps_needed) - - for stmt in body: - update_max_temps_for_stmt(stmt) - - # Handle allocations - if isinstance(stmt, ast.If): - handle_if_allocation( - module, - builder, - stmt, - func, - ret_type, - map_sym_tab, - local_sym_tab, - structs_sym_tab, - ) - elif isinstance(stmt, ast.Assign): - handle_assign_allocation(builder, stmt, local_sym_tab, structs_sym_tab) - - allocate_temp_pool(builder, max_temps_needed, local_sym_tab) - - return local_sym_tab +# ============================================================================ +# SECTION 3: Function Body Processing +# ============================================================================ def process_func_body( @@ -390,18 +389,14 @@ def process_bpf_chunk(func_node, module, return_type, map_sym_tab, structs_sym_t return func +# ============================================================================ +# SECTION 4: Top-Level Function Processor +# ============================================================================ + + def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab): for func_node in chunks: - is_global = False - for decorator in func_node.decorator_list: - if isinstance(decorator, ast.Name) and decorator.id in ( - "map", - "bpfglobal", - "struct", - ): - is_global = True - break - if is_global: + if is_global_function(func_node): continue func_type = get_probe_string(func_node) logger.info(f"Found probe_string of {func_node.name}: {func_type}") @@ -415,67 +410,7 @@ def func_proc(tree, module, chunks, map_sym_tab, structs_sym_tab): ) -def infer_return_type(func_node: ast.FunctionDef): - if not isinstance(func_node, (ast.FunctionDef, ast.AsyncFunctionDef)): - raise TypeError("Expected ast.FunctionDef") - if func_node.returns is not None: - try: - return ast.unparse(func_node.returns) - except Exception: - node = func_node.returns - if isinstance(node, ast.Name): - return node.id - if isinstance(node, ast.Attribute): - return getattr(node, "attr", type(node).__name__) - try: - return str(node) - except Exception: - return type(node).__name__ - found_type = None - - def _expr_type(e): - if e is None: - return "None" - if isinstance(e, ast.Constant): - return type(e.value).__name__ - if isinstance(e, ast.Name): - return e.id - if isinstance(e, ast.Call): - f = e.func - if isinstance(f, ast.Name): - return f.id - if isinstance(f, ast.Attribute): - try: - return ast.unparse(f) - except Exception: - return getattr(f, "attr", type(f).__name__) - try: - return ast.unparse(f) - except Exception: - return type(f).__name__ - if isinstance(e, ast.Attribute): - try: - return ast.unparse(e) - except Exception: - return getattr(e, "attr", type(e).__name__) - try: - return ast.unparse(e) - except Exception: - return type(e).__name__ - - for walked_node in ast.walk(func_node): - if isinstance(walked_node, ast.Return): - t = _expr_type(walked_node.value) - if found_type is None: - found_type = t - elif found_type != t: - raise ValueError(f"Conflicting return types: {found_type} vs {t}") - return found_type or "None" - - -# For string assignment to fixed-size arrays - - +# TODO: WIP, for string assignment to fixed-size arrays def assign_string_to_array(builder, target_array_ptr, source_string_ptr, array_length): """ Copy a string (i8*) to a fixed-size array ([N x i8]*)