mirror of
https://github.com/varun-r-mallya/Python-BPF.git
synced 2025-12-31 21:06:25 +00:00
Remove redundant functions from struct_pass
This commit is contained in:
@ -5,6 +5,13 @@ from .type_deducer import ctypes_to_ir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO: Shall we allow the following syntax:
|
||||
# struct MyStruct:
|
||||
# field1: int
|
||||
# field2: str(32)
|
||||
# Where int is mapped to c_uint64?
|
||||
# Shall we just int64, int32 and uint32 similarly?
|
||||
|
||||
|
||||
def structs_proc(tree, module, chunks):
|
||||
""" Process all class definitions to find BPF structs """
|
||||
@ -27,62 +34,16 @@ def is_bpf_struct(cls_node):
|
||||
def process_bpf_struct(cls_node, module):
|
||||
""" Process a single BPF struct definition """
|
||||
|
||||
field_names = []
|
||||
field_types = []
|
||||
|
||||
for item in cls_node.body:
|
||||
#
|
||||
# field syntax:
|
||||
# class struct_example:
|
||||
# num: c_uint64
|
||||
#
|
||||
if isinstance(item, ast.AnnAssign):
|
||||
if isinstance(item.target, ast.Name):
|
||||
print(f"Field: {item.target.id}, Type: "
|
||||
f"{ast.dump(item.annotation)}")
|
||||
field_names.append(item.target.id)
|
||||
if isinstance(item.annotation, ast.Call):
|
||||
if isinstance(item.annotation.func, ast.Name):
|
||||
if item.annotation.func.id == "str":
|
||||
# This is a char array with fixed length
|
||||
# TODO: For now assume str is always with constant
|
||||
field_types.append(ir.ArrayType(
|
||||
ir.IntType(8), item.annotation.args[0].value))
|
||||
else:
|
||||
field_types.append(
|
||||
ctypes_to_ir(item.annotation.id))
|
||||
else:
|
||||
print(f"Unsupported struct field: {ast.dump(item)}")
|
||||
return
|
||||
|
||||
curr_offset = 0
|
||||
for ftype in field_types:
|
||||
if isinstance(ftype, ir.IntType):
|
||||
fsize = ftype.width // 8
|
||||
alignment = fsize
|
||||
elif isinstance(ftype, ir.ArrayType):
|
||||
fsize = ftype.count * (ftype.element.width // 8)
|
||||
alignment = ftype.element.width // 8
|
||||
elif isinstance(ftype, ir.PointerType):
|
||||
fsize = 8
|
||||
alignment = 8
|
||||
else:
|
||||
print(f"Unsupported field type in struct {cls_node.name}")
|
||||
return
|
||||
padding = (alignment - (curr_offset % alignment)) % alignment
|
||||
curr_offset += padding
|
||||
curr_offset += fsize
|
||||
final_padding = (8 - (curr_offset % 8)) % 8
|
||||
total_size = curr_offset + final_padding
|
||||
|
||||
field_names, field_types = parse_struct_fields(cls_node)
|
||||
total_size = calc_struct_size(field_types)
|
||||
struct_type = ir.LiteralStructType(field_types)
|
||||
structs_sym_tab[cls_node.name] = {
|
||||
logger.info(f"Created struct {cls_node.name} with fields {field_names}")
|
||||
return {
|
||||
"type": struct_type,
|
||||
"fields": {name: idx for idx, name in enumerate(field_names)},
|
||||
"size": total_size,
|
||||
"field_types": field_types,
|
||||
}
|
||||
print(f"Created struct {cls_node.name} with fields {field_names}")
|
||||
|
||||
|
||||
def parse_struct_fields(cls_node):
|
||||
@ -91,7 +52,8 @@ def parse_struct_fields(cls_node):
|
||||
field_types = []
|
||||
|
||||
for item in cls_node.body:
|
||||
if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
|
||||
if isinstance(item, ast.AnnAssign) and \
|
||||
isinstance(item.target, ast.Name):
|
||||
field_names.append(item.target.id)
|
||||
field_types.append(get_type_from_ann(item.annotation))
|
||||
else:
|
||||
@ -112,3 +74,26 @@ def get_type_from_ann(annotation):
|
||||
return ctypes_to_ir(annotation.id)
|
||||
|
||||
raise TypeError(f"Unsupported annotation type: {ast.dump(annotation)}")
|
||||
|
||||
|
||||
def calc_struct_size(field_types):
|
||||
""" Calculate total size of the struct with alignment and padding """
|
||||
curr_offset = 0
|
||||
for ftype in field_types:
|
||||
if isinstance(ftype, ir.IntType):
|
||||
fsize = ftype.width // 8
|
||||
alignment = fsize
|
||||
elif isinstance(ftype, ir.ArrayType):
|
||||
fsize = ftype.count * (ftype.element.width // 8)
|
||||
alignment = ftype.element.width // 8
|
||||
elif isinstance(ftype, ir.PointerType):
|
||||
fsize = 8
|
||||
alignment = 8
|
||||
else:
|
||||
raise TypeError(f"Unsupported field type: {ftype}")
|
||||
|
||||
padding = (alignment - (curr_offset % alignment)) % alignment
|
||||
curr_offset += padding + fsize
|
||||
|
||||
final_padding = (8 - (curr_offset % 8)) % 8
|
||||
return curr_offset + final_padding
|
||||
|
||||
Reference in New Issue
Block a user