diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 4b5534e..20d152a 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -1,6 +1,7 @@ import ast import logging from llvmlite import ir +from pythonbpf.expr_pass import eval_expr logger = logging.getLogger(__name__) @@ -88,6 +89,28 @@ def _handle_fstring_print(joined_str, module, builder, func, _process_fval(value, fmt_parts, exprs, local_sym_tab, struct_sym_tab, local_var_metadata) + else: + raise NotImplementedError( + f"Unsupported f-string value type: {type(value)}") + + fmt_str = "".join(fmt_parts) + "\n\0" + fmt_ptr = _create_format_string_global(fmt_str, func, module, builder) + + args = [fmt_ptr, ir.Constant(ir.IntType(32), len(fmt_str))] + + # NOTE: Process expressions (limited to 3 due to BPF constraints) + if len(exprs) > 3: + logger.warn( + "bpf_printk supports up to 3 arguments, extra arguments will be ignored.") + + for expr in exprs[:3]: + arg_value = _prepare_expr_args(expr, func, module, builder, + local_sym_tab, struct_sym_tab, + local_var_metadata) + args.append(arg_value) + + # Call the BPF_PRINTK helper + return _call_bpf_printk_helper(args, builder) def _process_constant_in_fstring(cst, fmt_parts, exprs): @@ -176,3 +199,58 @@ def _populate_fval(ftype, node, fmt_parts, exprs): else: raise NotImplementedError( f"Unsupported field type in f-string: {ftype}") + + +def _create_format_string_global(fmt_str, func, module, builder): + """Create a global variable for the format string.""" + fmt_name = f"{func.name}____fmt{func._fmt_counter}" + func._fmt_counter += 1 + + fmt_gvar = ir.GlobalVariable( + module, ir.ArrayType(ir.IntType(8), len(fmt_str)), name=fmt_name) + fmt_gvar.global_constant = True + fmt_gvar.initializer = ir.Constant( + ir.ArrayType(ir.IntType(8), len(fmt_str)), + bytearray(fmt_str.encode("utf8")) + ) + fmt_gvar.linkage = "internal" + fmt_gvar.align = 1 + + return builder.bitcast(fmt_gvar, ir.PointerType()) + + +def _prepare_expr_args(expr, func, module, builder, + local_sym_tab, struct_sym_tab, + local_var_metadata): + """Evaluate and prepare an expression to be used as an argument for bpf_printk.""" + print(f"{ast.dump(expr)}") + val, _ = eval_expr(func, module, builder, expr, + local_sym_tab, None, struct_sym_tab, + local_var_metadata) + + if val: + if isinstance(val.type, ir.PointerType): + val = builder.ptrtoint(val, ir.IntType(64)) + elif isinstance(val.type, ir.IntType): + if val.type.width < 64: + val = builder.sext(val, ir.IntType(64)) + else: + logger.warn( + "Only int and ptr supported in bpf_printk arguments. Others default to 0.") + val = ir.Constant(ir.IntType(64), 0) + return val + else: + logger.warn( + "Failed to evaluate expression for bpf_printk argument. It will be converted to 0.") + return ir.Constant(ir.IntType(64), 0) + + +def _call_bpf_printk_helper(args, builder): + """Call the BPF_PRINTK helper function with the provided arguments.""" + fn_type = ir.FunctionType( + ir.IntType(64), [ir.PointerType(), ir.IntType(32)], var_arg=True) + fn_ptr_type = ir.PointerType(fn_type) + fn_addr = ir.Constant(ir.IntType(64), BPFHelperID.BPF_PRINTK.value) + fn_ptr = builder.inttoptr(fn_addr, fn_ptr_type) + + return builder.call(fn_ptr, args, tail=True)