diff --git a/pythonbpf/structs_pass.py b/pythonbpf/structs_pass.py index 8714b01..4ba3869 100644 --- a/pythonbpf/structs_pass.py +++ b/pythonbpf/structs_pass.py @@ -5,6 +5,13 @@ from .type_deducer import ctypes_to_ir logger = logging.getLogger(__name__) +# TODO: Shall we allow the following syntax: +# struct MyStruct: +# field1: int +# field2: str(32) +# Where int is mapped to c_uint64? +# Shall we just int64, int32 and uint32 similarly? + def structs_proc(tree, module, chunks): """ Process all class definitions to find BPF structs """ @@ -27,62 +34,16 @@ def is_bpf_struct(cls_node): def process_bpf_struct(cls_node, module): """ Process a single BPF struct definition """ - field_names = [] - field_types = [] - - for item in cls_node.body: - # - # field syntax: - # class struct_example: - # num: c_uint64 - # - if isinstance(item, ast.AnnAssign): - if isinstance(item.target, ast.Name): - print(f"Field: {item.target.id}, Type: " - f"{ast.dump(item.annotation)}") - field_names.append(item.target.id) - if isinstance(item.annotation, ast.Call): - if isinstance(item.annotation.func, ast.Name): - if item.annotation.func.id == "str": - # This is a char array with fixed length - # TODO: For now assume str is always with constant - field_types.append(ir.ArrayType( - ir.IntType(8), item.annotation.args[0].value)) - else: - field_types.append( - ctypes_to_ir(item.annotation.id)) - else: - print(f"Unsupported struct field: {ast.dump(item)}") - return - - curr_offset = 0 - for ftype in field_types: - if isinstance(ftype, ir.IntType): - fsize = ftype.width // 8 - alignment = fsize - elif isinstance(ftype, ir.ArrayType): - fsize = ftype.count * (ftype.element.width // 8) - alignment = ftype.element.width // 8 - elif isinstance(ftype, ir.PointerType): - fsize = 8 - alignment = 8 - else: - print(f"Unsupported field type in struct {cls_node.name}") - return - padding = (alignment - (curr_offset % alignment)) % alignment - curr_offset += padding - curr_offset += fsize - final_padding = (8 - (curr_offset % 8)) % 8 - total_size = curr_offset + final_padding - + field_names, field_types = parse_struct_fields(cls_node) + total_size = calc_struct_size(field_types) struct_type = ir.LiteralStructType(field_types) - structs_sym_tab[cls_node.name] = { + logger.info(f"Created struct {cls_node.name} with fields {field_names}") + return { "type": struct_type, "fields": {name: idx for idx, name in enumerate(field_names)}, "size": total_size, "field_types": field_types, } - print(f"Created struct {cls_node.name} with fields {field_names}") def parse_struct_fields(cls_node): @@ -91,7 +52,8 @@ def parse_struct_fields(cls_node): field_types = [] for item in cls_node.body: - if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name): + if isinstance(item, ast.AnnAssign) and \ + isinstance(item.target, ast.Name): field_names.append(item.target.id) field_types.append(get_type_from_ann(item.annotation)) else: @@ -112,3 +74,26 @@ def get_type_from_ann(annotation): return ctypes_to_ir(annotation.id) raise TypeError(f"Unsupported annotation type: {ast.dump(annotation)}") + + +def calc_struct_size(field_types): + """ Calculate total size of the struct with alignment and padding """ + curr_offset = 0 + for ftype in field_types: + if isinstance(ftype, ir.IntType): + fsize = ftype.width // 8 + alignment = fsize + elif isinstance(ftype, ir.ArrayType): + fsize = ftype.count * (ftype.element.width // 8) + alignment = ftype.element.width // 8 + elif isinstance(ftype, ir.PointerType): + fsize = 8 + alignment = 8 + else: + raise TypeError(f"Unsupported field type: {ftype}") + + padding = (alignment - (curr_offset % alignment)) % alignment + curr_offset += padding + fsize + + final_padding = (8 - (curr_offset % 8)) % 8 + return curr_offset + final_padding