diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index cf08327..b3bf663 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -8,7 +8,6 @@ from .helper_utils import ( get_flags_val, get_data_ptr_and_size, get_buffer_ptr_and_size, - get_char_array_ptr_and_size, get_ptr_from_arg, get_int_value_from_arg, ) @@ -464,7 +463,7 @@ def bpf_probe_read_kernel_str_emitter( ) # Get destination buffer (char array -> i8*) - dst_ptr, dst_size = get_char_array_ptr_and_size( + dst_ptr, dst_size = get_or_create_ptr_from_arg( call.args[0], builder, local_sym_tab, struct_sym_tab ) diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 613e025..4d7de22 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -85,52 +85,6 @@ def create_int_constant_ptr(value, builder, local_sym_tab, int_width=64): return ptr -def get_struct_char_array_ptr(expr, builder, local_sym_tab, struct_sym_tab): - """Get pointer to first element of char array in struct field, or None.""" - if not (isinstance(expr, ast.Attribute) and isinstance(expr.value, ast.Name)): - return None - - var_name = expr.value.id - field_name = expr.attr - - # Check if it's a valid struct field - if not ( - local_sym_tab - and var_name in local_sym_tab - and struct_sym_tab - and local_sym_tab[var_name].metadata in struct_sym_tab - ): - return None - - struct_type = local_sym_tab[var_name].metadata - struct_info = struct_sym_tab[struct_type] - - if field_name not in struct_info.fields: - return None - - field_type = struct_info.field_type(field_name) - - # Check if it's a char array - is_char_array = ( - isinstance(field_type, ir.ArrayType) - and isinstance(field_type.element, ir.IntType) - and field_type.element.width == 8 - ) - - if not is_char_array: - return None - - # Get field pointer and GEP to first element: [N x i8]* -> i8* - struct_ptr = local_sym_tab[var_name].var - field_ptr = struct_info.gep(builder, struct_ptr, field_name) - - return builder.gep( - field_ptr, - [ir.Constant(ir.IntType(32), 0), ir.Constant(ir.IntType(32), 0)], - inbounds=True, - ) - - def get_or_create_ptr_from_arg( func, module, @@ -148,6 +102,7 @@ def get_or_create_ptr_from_arg( # Stack space is already allocated ptr = get_var_ptr_from_name(arg.id, local_sym_tab) elif isinstance(arg, ast.Constant) and isinstance(arg.value, int): + int_width = 64 # Deafult to i64 if expected_type and isinstance(expected_type, ir.IntType): int_width = expected_type.width ptr = create_int_constant_ptr(arg.value, builder, local_sym_tab, int_width) @@ -178,7 +133,9 @@ def get_or_create_ptr_from_arg( and isinstance(field_type.element, ir.IntType) and field_type.element.width == 8 ): - ptr = get_struct_char_array_ptr(arg, builder, local_sym_tab, struct_sym_tab) + ptr, sz = get_char_array_ptr_and_size( + arg, builder, local_sym_tab, struct_sym_tab + ) if not ptr: raise ValueError("Failed to get char array pointer from struct field") else: @@ -203,6 +160,10 @@ def get_or_create_ptr_from_arg( val = builder.trunc(val, expected_type) builder.store(val, ptr) + # NOTE: For char arrays, also return size + if sz: + return ptr, sz + return ptr diff --git a/pythonbpf/helper/printk_formatter.py b/pythonbpf/helper/printk_formatter.py index 221be10..58990c0 100644 --- a/pythonbpf/helper/printk_formatter.py +++ b/pythonbpf/helper/printk_formatter.py @@ -4,7 +4,7 @@ import logging from llvmlite import ir from pythonbpf.expr import eval_expr, get_base_type_and_depth, deref_to_depth from pythonbpf.expr.vmlinux_registry import VmlinuxHandlerRegistry -from pythonbpf.helper.helper_utils import get_struct_char_array_ptr +from pythonbpf.helper.helper_utils import get_char_array_ptr_and_size logger = logging.getLogger(__name__) @@ -220,7 +220,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta """Evaluate and prepare an expression to use as an arg for bpf_printk.""" # Special case: struct field char array needs pointer to first element - char_array_ptr = get_struct_char_array_ptr( + char_array_ptr, _ = get_char_array_ptr_and_size( expr, builder, local_sym_tab, struct_sym_tab ) if char_array_ptr: