1 Commits

Author SHA1 Message Date
3f66c4f53f Initial plan 2025-11-22 07:59:21 +00:00
13 changed files with 53 additions and 156 deletions

View File

@ -12,7 +12,7 @@ jobs:
name: Format name: Format
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v5
- uses: actions/setup-python@v6 - uses: actions/setup-python@v6
with: with:
python-version: "3.x" python-version: "3.x"

View File

@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v5
- uses: actions/setup-python@v6 - uses: actions/setup-python@v6
with: with:

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "pythonbpf" name = "pythonbpf"
version = "0.1.7" version = "0.1.6"
description = "Reduced Python frontend for eBPF" description = "Reduced Python frontend for eBPF"
authors = [ authors = [
{ name = "r41k0u", email="pragyanshchaturvedi18@gmail.com" }, { name = "r41k0u", email="pragyanshchaturvedi18@gmail.com" },
@ -29,7 +29,7 @@ license = {text = "Apache-2.0"}
requires-python = ">=3.10" requires-python = ">=3.10"
dependencies = [ dependencies = [
"llvmlite>=0.45", "llvmlite",
"astpretty", "astpretty",
"pylibbpf" "pylibbpf"
] ]

View File

@ -190,7 +190,7 @@ def _allocate_for_map_method(
# Main variable (pointer to pointer) # Main variable (pointer to pointer)
ir_type = ir.PointerType(ir.IntType(64)) ir_type = ir.PointerType(ir.IntType(64))
var = builder.alloca(ir_type, name=var_name) var = builder.alloca(ir_type, name=var_name)
local_sym_tab[var_name] = LocalSymbol(var, ir_type, value_type) local_sym_tab[var_name] = LocalSymbol(var, ir_type)
# Temporary variable for computed values # Temporary variable for computed values
tmp_ir_type = value_ir_type tmp_ir_type = value_ir_type
var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp") var_tmp = builder.alloca(tmp_ir_type, name=f"{var_name}_tmp")

View File

@ -25,7 +25,7 @@ import re
logger: Logger = logging.getLogger(__name__) logger: Logger = logging.getLogger(__name__)
VERSION = "v0.1.7" VERSION = "v0.1.6"
def finalize_module(original_str): def finalize_module(original_str):

View File

@ -1,6 +1,6 @@
from .expr_pass import eval_expr, handle_expr, get_operand_value from .expr_pass import eval_expr, handle_expr, get_operand_value
from .type_normalization import convert_to_bool, get_base_type_and_depth from .type_normalization import convert_to_bool, get_base_type_and_depth
from .ir_ops import deref_to_depth, access_struct_field from .ir_ops import deref_to_depth
from .call_registry import CallHandlerRegistry from .call_registry import CallHandlerRegistry
from .vmlinux_registry import VmlinuxHandlerRegistry from .vmlinux_registry import VmlinuxHandlerRegistry
@ -10,7 +10,6 @@ __all__ = [
"convert_to_bool", "convert_to_bool",
"get_base_type_and_depth", "get_base_type_and_depth",
"deref_to_depth", "deref_to_depth",
"access_struct_field",
"get_operand_value", "get_operand_value",
"CallHandlerRegistry", "CallHandlerRegistry",
"VmlinuxHandlerRegistry", "VmlinuxHandlerRegistry",

View File

@ -6,11 +6,11 @@ from typing import Dict
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
from .call_registry import CallHandlerRegistry from .call_registry import CallHandlerRegistry
from .ir_ops import deref_to_depth, access_struct_field
from .type_normalization import ( from .type_normalization import (
convert_to_bool, convert_to_bool,
handle_comparator, handle_comparator,
get_base_type_and_depth, get_base_type_and_depth,
deref_to_depth,
) )
from .vmlinux_registry import VmlinuxHandlerRegistry from .vmlinux_registry import VmlinuxHandlerRegistry
from ..vmlinux_parser.dependency_node import Field from ..vmlinux_parser.dependency_node import Field
@ -61,7 +61,6 @@ def _handle_constant_expr(module, builder, expr: ast.Constant):
def _handle_attribute_expr( def _handle_attribute_expr(
func,
expr: ast.Attribute, expr: ast.Attribute,
local_sym_tab: Dict, local_sym_tab: Dict,
structs_sym_tab: Dict, structs_sym_tab: Dict,
@ -97,23 +96,13 @@ def _handle_attribute_expr(
) )
return None return None
if var_metadata in structs_sym_tab: # Regular user-defined struct
return access_struct_field( metadata = structs_sym_tab.get(var_metadata)
builder, if metadata and attr_name in metadata.fields:
var_ptr, gep = metadata.gep(builder, var_ptr, attr_name)
var_type, val = builder.load(gep)
var_metadata, field_type = metadata.field_type(attr_name)
expr.attr, return val, field_type
structs_sym_tab,
func,
)
else:
logger.error(f"Struct metadata for '{var_name}' not found")
else:
logger.error(f"Undefined variable '{var_name}' for attribute access")
else:
logger.error("Unsupported attribute base expression type")
return None return None
@ -659,9 +648,7 @@ def eval_expr(
logger.warning(f"Unknown call: {ast.dump(expr)}") logger.warning(f"Unknown call: {ast.dump(expr)}")
return None return None
elif isinstance(expr, ast.Attribute): elif isinstance(expr, ast.Attribute):
return _handle_attribute_expr( return _handle_attribute_expr(expr, local_sym_tab, structs_sym_tab, builder)
func, expr, local_sym_tab, structs_sym_tab, builder
)
elif isinstance(expr, ast.BinOp): elif isinstance(expr, ast.BinOp):
return _handle_binary_op( return _handle_binary_op(
func, func,

View File

@ -17,100 +17,34 @@ def deref_to_depth(func, builder, val, target_depth):
# dereference with null check # dereference with null check
pointee_type = cur_type.pointee pointee_type = cur_type.pointee
null_check_block = builder.block
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
def load_op(builder, ptr): null_ptr = ir.Constant(cur_type, None)
return builder.load(ptr) is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
logger.debug(f"Inserted null check for pointer at depth {depth}")
cur_val = _null_checked_operation( builder.cbranch(is_not_null, not_null_block, merge_block)
func, builder, cur_val, load_op, pointee_type, f"deref_{depth}"
builder.position_at_end(not_null_block)
dereferenced_val = builder.load(cur_val)
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
zero_value = (
ir.Constant(pointee_type, 0)
if isinstance(pointee_type, ir.IntType)
else ir.Constant(pointee_type, None)
) )
phi.add_incoming(zero_value, null_check_block)
phi.add_incoming(dereferenced_val, not_null_block)
# Continue with phi result
cur_val = phi
cur_type = pointee_type cur_type = pointee_type
logger.debug(f"Dereferenced to depth {depth}, type: {pointee_type}")
return cur_val return cur_val
def _null_checked_operation(func, builder, ptr, operation, result_type, name_prefix):
"""
Generic null-checked operation on a pointer.
"""
curr_block = builder.block
not_null_block = func.append_basic_block(name=f"{name_prefix}_not_null")
merge_block = func.append_basic_block(name=f"{name_prefix}_merge")
null_ptr = ir.Constant(ptr.type, None)
is_not_null = builder.icmp_signed("!=", ptr, null_ptr)
builder.cbranch(is_not_null, not_null_block, merge_block)
builder.position_at_end(not_null_block)
result = operation(builder, ptr)
not_null_after = builder.block
builder.branch(merge_block)
builder.position_at_end(merge_block)
phi = builder.phi(result_type, name=f"{name_prefix}_result")
if isinstance(result_type, ir.IntType):
null_val = ir.Constant(result_type, 0)
elif isinstance(result_type, ir.PointerType):
null_val = ir.Constant(result_type, None)
else:
null_val = ir.Constant(result_type, ir.Undefined)
phi.add_incoming(null_val, curr_block)
phi.add_incoming(result, not_null_after)
return phi
def access_struct_field(
builder, var_ptr, var_type, var_metadata, field_name, structs_sym_tab, func=None
):
"""
Access a struct field - automatically returns value or pointer based on field type.
"""
metadata = (
structs_sym_tab.get(var_metadata)
if isinstance(var_metadata, str)
else var_metadata
)
if not metadata or field_name not in metadata.fields:
raise ValueError(f"Field '{field_name}' not found in struct")
field_type = metadata.field_type(field_name)
is_ptr_to_struct = isinstance(var_type, ir.PointerType) and isinstance(
var_metadata, str
)
# Get struct pointer
struct_ptr = builder.load(var_ptr) if is_ptr_to_struct else var_ptr
should_load = not isinstance(field_type, ir.ArrayType)
def field_access_op(builder, ptr):
typed_ptr = builder.bitcast(ptr, metadata.ir_type.as_pointer())
field_ptr = metadata.gep(builder, typed_ptr, field_name)
return builder.load(field_ptr) if should_load else field_ptr
# Handle null check for pointer-to-struct
if is_ptr_to_struct:
if func is None:
raise ValueError("func required for null-safe struct pointer access")
if should_load:
result_type = field_type
else:
result_type = field_type.as_pointer()
result = _null_checked_operation(
func,
builder,
struct_ptr,
field_access_op,
result_type,
f"field_{field_name}",
)
return result, field_type
field_ptr = metadata.gep(builder, struct_ptr, field_name)
result = builder.load(field_ptr) if should_load else field_ptr
return result, field_type

View File

@ -5,7 +5,6 @@ from llvmlite import ir
from pythonbpf.expr import ( from pythonbpf.expr import (
get_operand_value, get_operand_value,
eval_expr, eval_expr,
access_struct_field,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -136,7 +135,7 @@ def get_or_create_ptr_from_arg(
and field_type.element.width == 8 and field_type.element.width == 8
): ):
ptr, sz = get_char_array_ptr_and_size( ptr, sz = get_char_array_ptr_and_size(
arg, builder, local_sym_tab, struct_sym_tab, func arg, builder, local_sym_tab, struct_sym_tab
) )
if not ptr: if not ptr:
raise ValueError("Failed to get char array pointer from struct field") raise ValueError("Failed to get char array pointer from struct field")
@ -267,9 +266,7 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
) )
def get_char_array_ptr_and_size( def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
buf_arg, builder, local_sym_tab, struct_sym_tab, func=None
):
"""Get pointer to char array and its size.""" """Get pointer to char array and its size."""
# Struct field: obj.field # Struct field: obj.field
@ -280,11 +277,11 @@ def get_char_array_ptr_and_size(
if not (local_sym_tab and var_name in local_sym_tab): if not (local_sym_tab and var_name in local_sym_tab):
raise ValueError(f"Variable '{var_name}' not found") raise ValueError(f"Variable '{var_name}' not found")
struct_ptr, struct_type, struct_metadata = local_sym_tab[var_name] struct_type = local_sym_tab[var_name].metadata
if not (struct_sym_tab and struct_metadata in struct_sym_tab): if not (struct_sym_tab and struct_type in struct_sym_tab):
raise ValueError(f"Struct type '{struct_metadata}' not found") raise ValueError(f"Struct type '{struct_type}' not found")
struct_info = struct_sym_tab[struct_metadata] struct_info = struct_sym_tab[struct_type]
if field_name not in struct_info.fields: if field_name not in struct_info.fields:
raise ValueError(f"Field '{field_name}' not found") raise ValueError(f"Field '{field_name}' not found")
@ -295,24 +292,8 @@ def get_char_array_ptr_and_size(
) )
return None, 0 return None, 0
# Check if char array struct_ptr = local_sym_tab[var_name].var
if not ( field_ptr = struct_info.gep(builder, struct_ptr, field_name)
isinstance(field_type, ir.ArrayType)
and isinstance(field_type.element, ir.IntType)
and field_type.element.width == 8
):
logger.warning("Field is not a char array")
return None, 0
field_ptr, _ = access_struct_field(
builder,
struct_ptr,
struct_type,
struct_metadata,
field_name,
struct_sym_tab,
func,
)
# GEP to first element: [N x i8]* -> i8* # GEP to first element: [N x i8]* -> i8*
buf_ptr = builder.gep( buf_ptr = builder.gep(

View File

@ -222,7 +222,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
# Special case: struct field char array needs pointer to first element # Special case: struct field char array needs pointer to first element
if isinstance(expr, ast.Attribute): if isinstance(expr, ast.Attribute):
char_array_ptr, _ = get_char_array_ptr_and_size( char_array_ptr, _ = get_char_array_ptr_and_size(
expr, builder, local_sym_tab, struct_sym_tab, func expr, builder, local_sym_tab, struct_sym_tab
) )
if char_array_ptr: if char_array_ptr:
return char_array_ptr return char_array_ptr

View File

@ -117,7 +117,6 @@ def _get_key_val_dbg_type(name, generator, structs_sym_tab):
type_obj = structs_sym_tab.get(name) type_obj = structs_sym_tab.get(name)
if type_obj: if type_obj:
logger.info(f"Found struct named {name}, generating debug type")
return _get_struct_debug_type(type_obj, generator, structs_sym_tab) return _get_struct_debug_type(type_obj, generator, structs_sym_tab)
# Fallback to basic types # Fallback to basic types
@ -166,6 +165,6 @@ def _get_struct_debug_type(struct_obj, generator, structs_sym_tab):
) )
elements_arr.append(member) elements_arr.append(member)
struct_type = generator.create_struct_type( struct_type = generator.create_struct_type(
elements_arr, struct_obj.size * 8, is_distinct=True elements_arr, struct_obj.size, is_distinct=True
) )
return struct_type return struct_type

View File

@ -17,7 +17,6 @@ mapping = {
"c_ulong": ir.IntType(64), "c_ulong": ir.IntType(64),
"c_longlong": ir.IntType(64), "c_longlong": ir.IntType(64),
"c_uint": ir.IntType(32), "c_uint": ir.IntType(32),
"c_int": ir.IntType(32),
# Not so sure about this one # Not so sure about this one
"str": ir.PointerType(ir.IntType(8)), "str": ir.PointerType(ir.IntType(8)),
} }

View File

@ -1,6 +1,6 @@
from pythonbpf import bpf, section, struct, bpfglobal, compile, map from pythonbpf import bpf, section, struct, bpfglobal, compile, map
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from pythonbpf.helper import pid, comm from pythonbpf.helper import pid
from ctypes import c_void_p, c_int64 from ctypes import c_void_p, c_int64
@ -9,7 +9,6 @@ from ctypes import c_void_p, c_int64
class val_type: class val_type:
counter: c_int64 counter: c_int64
shizzle: c_int64 shizzle: c_int64
comm: str(16)
@bpf @bpf
@ -23,7 +22,6 @@ def last() -> HashMap:
def hello_world(ctx: c_void_p) -> c_int64: def hello_world(ctx: c_void_p) -> c_int64:
obj = val_type() obj = val_type()
obj.counter, obj.shizzle = 42, 96 obj.counter, obj.shizzle = 42, 96
comm(obj.comm)
t = last.lookup(obj) t = last.lookup(obj)
if t: if t:
print(f"Found existing entry: counter={obj.counter}, pid={t}") print(f"Found existing entry: counter={obj.counter}, pid={t}")