From 6ccbab402f3f4320581ab1bce18f6755d9f50bd7 Mon Sep 17 00:00:00 2001 From: Pragyansh Chaturvedi Date: Wed, 1 Oct 2025 22:12:30 +0530 Subject: [PATCH] Complete printk refactor --- pythonbpf/helper/bpf_helper_handler.py | 171 +++---------------------- pythonbpf/helper/helper_utils.py | 18 ++- 2 files changed, 33 insertions(+), 156 deletions(-) diff --git a/pythonbpf/helper/bpf_helper_handler.py b/pythonbpf/helper/bpf_helper_handler.py index 2c763b5..6a0cf14 100644 --- a/pythonbpf/helper/bpf_helper_handler.py +++ b/pythonbpf/helper/bpf_helper_handler.py @@ -2,7 +2,7 @@ import ast from llvmlite import ir from pythonbpf.expr_pass import eval_expr from enum import Enum -from .helper_utils import HelperHandlerRegistry, get_or_create_ptr_from_arg, get_flags_val +from .helper_utils import HelperHandlerRegistry, get_or_create_ptr_from_arg, get_flags_val, _handle_fstring_print, _simple_string_print class BPFHelperID(Enum): @@ -64,163 +64,34 @@ def bpf_map_lookup_elem_emitter(call, map_ptr, module, builder, func, def bpf_printk_emitter(call, map_ptr, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None): + """Emit LLVM IR for bpf_printk helper function call.""" if not hasattr(func, "_fmt_counter"): func._fmt_counter = 0 if not call.args: - raise ValueError("print expects at least one argument") + raise ValueError( + "bpf_printk expects at least one argument (format string)") + args = [] if isinstance(call.args[0], ast.JoinedStr): - fmt_parts = [] - exprs = [] + args = _handle_fstring_print(call.args[0], module, builder, func, + local_sym_tab, struct_sym_tab, + local_var_metadata) + elif isinstance(call.args[0], ast.Constant) and isinstance(call.args[0].value, str): + # TODO: We are onbly supporting single arguments for now. + # In case of multiple args, the first one will be taken. + args = _simple_string_print(call.args[0], module, builder, func) + else: + raise NotImplementedError( + "Only simple string literals or f-strings are supported in bpf_printk.") - for value in call.args[0].values: - print("Value in f-string:", ast.dump(value)) - if isinstance(value, ast.Constant): - if isinstance(value.value, str): - fmt_parts.append(value.value) - elif isinstance(value.value, int): - fmt_parts.append("%lld") - exprs.append(ir.Constant(ir.IntType(64), value.value)) - else: - raise NotImplementedError( - "Only string and integer constants are supported in f-string.") - elif isinstance(value, ast.FormattedValue): - print("Formatted value:", ast.dump(value)) - # TODO: Dirty handling here, only checks for int or str - if isinstance(value.value, ast.Name): - if local_sym_tab and value.value.id in local_sym_tab: - var_ptr, var_type = local_sym_tab[value.value.id] - if isinstance(var_type, ir.IntType): - fmt_parts.append("%lld") - exprs.append(value.value) - elif var_type == ir.PointerType(ir.IntType(8)): - # Case with string - fmt_parts.append("%s") - exprs.append(value.value) - else: - raise NotImplementedError( - "Only integer and pointer types are supported in formatted values.") - else: - raise ValueError( - f"Variable {value.value.id} not found in local symbol table.") - elif isinstance(value.value, ast.Attribute): - # object field access from struct - if (isinstance(value.value.value, ast.Name) and - local_sym_tab and - value.value.value.id in local_sym_tab): - var_name = value.value.value.id - field_name = value.value.attr - if local_var_metadata and var_name in local_var_metadata: - var_type = local_var_metadata[var_name] - if var_type in struct_sym_tab: - struct_info = struct_sym_tab[var_type] - if field_name in struct_info.fields: - field_type = struct_info.field_type( - field_name) - if isinstance(field_type, ir.IntType): - fmt_parts.append("%lld") - exprs.append(value.value) - elif field_type == ir.PointerType(ir.IntType(8)): - fmt_parts.append("%s") - exprs.append(value.value) - else: - raise NotImplementedError( - "Only integer and pointer types are supported in formatted values.") - else: - raise ValueError( - f"Field {field_name} not found in struct {var_type}.") - else: - raise ValueError( - f"Struct type {var_type} for variable {var_name} not found in struct symbol table.") - else: - raise ValueError( - f"Metadata for variable {var_name} not found in local variable metadata.") - else: - raise ValueError( - f"Variable {value.value.value.id} not found in local symbol table.") - else: - raise NotImplementedError( - "Only simple variable names are supported in formatted values.") - else: - raise NotImplementedError( - "Unsupported value type in f-string.") + fn_type = ir.FunctionType( + ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True) + fn_ptr_type = ir.PointerType(fn_type) + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - fmt_str = "".join(fmt_parts) + "\n" + "\0" - fmt_name = f"{func.name}____fmt{func._fmt_counter}" - func._fmt_counter += 1 - - fmt_gvar = ir.GlobalVariable( - module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name) - fmt_gvar.global_constant = True - fmt_gvar.initializer = ir.Constant( # type: ignore - ir.ArrayType(ir.IntType(8), len(fmt_str)), - bytearray(fmt_str.encode("utf8")) - ) - fmt_gvar.linkage = "internal" - fmt_gvar.align = 1 # type: ignore - - fmt_ptr = builder.bitcast(fmt_gvar, ir.PointerType()) - - args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))] - - # Only 3 args supported in bpf_printk - if len(exprs) > 3: - print( - "Warning: bpf_printk supports up to 3 arguments, extra arguments will be ignored.") - - for expr in exprs[:3]: - print(f"{ast.dump(expr)}") - val, _ = eval_expr(func, module, builder, - expr, local_sym_tab, None, struct_sym_tab, local_var_metadata) - if val: - if isinstance(val.type, ir.PointerType): - val = builder.ptrtoint(val, ir.IntType(64)) - elif isinstance(val.type, ir.IntType): - if val.type.width < 64: - val = builder.sext(val, ir.IntType(64)) - else: - print( - "Warning: Only integer and pointer types are supported in bpf_printk arguments. Others will be converted to 0.") - val = ir.Constant(ir.IntType(64), 0) - args.append(val) - else: - print( - "Warning: Failed to evaluate expression for bpf_printk argument. It will be converted to 0.") - args.append(ir.Constant(ir.IntType(64), 0)) - fn_type = ir.FunctionType(ir.IntType( - 64), [ir.PointerType(), ir.IntType(32)], var_arg=True) - fn_ptr_type = ir.PointerType(fn_type) - fn_addr = ir.Constant(ir.IntType(64), 6) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - return builder.call(fn_ptr, args, tail=True) - - for arg in call.args: - if isinstance(arg, ast.Constant) and isinstance(arg.value, str): - fmt_str = arg.value + "\n" + "\0" - fmt_name = f"{func.name}____fmt{func._fmt_counter}" - func._fmt_counter += 1 - - fmt_gvar = ir.GlobalVariable( - module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name) - fmt_gvar.global_constant = True - fmt_gvar.initializer = ir.Constant( # type: ignore - ir.ArrayType(ir.IntType(8), len(fmt_str)), - bytearray(fmt_str.encode("utf8")) - ) - fmt_gvar.linkage = "internal" - fmt_gvar.align = 1 # type: ignore - - fmt_ptr = builder.bitcast(fmt_gvar, ir.PointerType()) - - fn_type = ir.FunctionType(ir.IntType( - 64), [ir.PointerType(), ir.IntType(32)], var_arg=True) - fn_ptr_type = ir.PointerType(fn_type) - fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value) - fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) - - builder.call(fn_ptr, [fmt_ptr, ir.Constant( - ir.IntType(32), len(fmt_str))], tail=True) + builder.call(fn_ptr, args, tail=True) return None diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 20d152a..1d0fb73 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -73,6 +73,15 @@ def get_flags_val(arg, builder, local_sym_tab): "Only simple variable names or integer constants are supported as flags in map helpers.") +def _simple_string_print(string_value, module, builder, func): + """Emit code for a simple string print statement.""" + fmt_str = string_value + "\n\0" + fmt_ptr = _create_format_string_global(fmt_str, func, module, builder) + + args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))] + return args + + def _handle_fstring_print(joined_str, module, builder, func, local_sym_tab=None, struct_sym_tab=None, local_var_metadata=None): @@ -93,10 +102,8 @@ def _handle_fstring_print(joined_str, module, builder, func, raise NotImplementedError( f"Unsupported f-string value type: {type(value)}") - fmt_str = "".join(fmt_parts) + "\n\0" - fmt_ptr = _create_format_string_global(fmt_str, func, module, builder) - - args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))] + fmt_str = "".join(fmt_parts) + args = _simple_string_print(fmt_str, module, builder, func) # NOTE: Process expressions (limited to 3 due to BPF constraints) if len(exprs) > 3: @@ -109,8 +116,7 @@ def _handle_fstring_print(joined_str, module, builder, func, local_var_metadata) args.append(arg_value) - # Call the BPF_PRINTK helper - return _call_bpf_printk_helper(args, builder) + return args def _process_constant_in_fstring(cst, fmt_parts, exprs):