diff --git a/pythonbpf/expr/__init__.py b/pythonbpf/expr/__init__.py new file mode 100644 index 0000000..d58c543 --- /dev/null +++ b/pythonbpf/expr/__init__.py @@ -0,0 +1,4 @@ +from .expr_pass import eval_expr, handle_expr +from .type_normalization import convert_to_bool + +__all__ = ["eval_expr", "handle_expr", "convert_to_bool"] diff --git a/pythonbpf/expr_pass.py b/pythonbpf/expr/expr_pass.py similarity index 51% rename from pythonbpf/expr_pass.py rename to pythonbpf/expr/expr_pass.py index 56d047e..21be196 100644 --- a/pythonbpf/expr_pass.py +++ b/pythonbpf/expr/expr_pass.py @@ -4,7 +4,8 @@ from logging import Logger import logging from typing import Dict -from .type_deducer import ctypes_to_ir, is_ctypes +from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes +from .type_normalization import convert_to_bool, handle_comparator logger: Logger = logging.getLogger(__name__) @@ -22,12 +23,10 @@ def _handle_name_expr(expr: ast.Name, local_sym_tab: Dict, builder: ir.IRBuilder def _handle_constant_expr(expr: ast.Constant): """Handle ast.Constant expressions.""" - if isinstance(expr.value, int): - return ir.Constant(ir.IntType(64), expr.value), ir.IntType(64) - elif isinstance(expr.value, bool): - return ir.Constant(ir.IntType(1), int(expr.value)), ir.IntType(1) + if isinstance(expr.value, int) or isinstance(expr.value, bool): + return ir.Constant(ir.IntType(64), int(expr.value)), ir.IntType(64) else: - logger.info("Unsupported constant type") + logger.error("Unsupported constant type") return None @@ -45,7 +44,6 @@ def _handle_attribute_expr( var_ptr, var_type, var_metadata = local_sym_tab[var_name] logger.info(f"Loading attribute {attr_name} from variable {var_name}") logger.info(f"Variable type: {var_type}, Variable ptr: {var_ptr}") - metadata = structs_sym_tab[var_metadata] if attr_name in metadata.fields: gep = metadata.gep(builder, var_ptr, attr_name) @@ -132,6 +130,199 @@ def _handle_ctypes_call( return val +def _handle_compare( + func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None +): + """Handle ast.Compare expressions.""" + + if len(cond.ops) != 1 or len(cond.comparators) != 1: + logger.error("Only single comparisons are supported") + return None + lhs = eval_expr( + func, + module, + builder, + cond.left, + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + rhs = eval_expr( + func, + module, + builder, + cond.comparators[0], + local_sym_tab, + map_sym_tab, + structs_sym_tab, + ) + + if lhs is None or rhs is None: + logger.error("Failed to evaluate comparison operands") + return None + + lhs, _ = lhs + rhs, _ = rhs + return handle_comparator(func, builder, cond.ops[0], lhs, rhs) + + +def _handle_unary_op( + func, + module, + builder, + expr: ast.UnaryOp, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + """Handle ast.UnaryOp expressions.""" + if not isinstance(expr.op, ast.Not): + logger.error("Only 'not' unary operator is supported") + return None + + operand = eval_expr( + func, module, builder, expr.operand, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if operand is None: + logger.error("Failed to evaluate operand for unary operation") + return None + + operand_val, operand_type = operand + true_const = ir.Constant(ir.IntType(1), 1) + result = builder.xor(convert_to_bool(builder, operand_val), true_const) + return result, ir.IntType(1) + + +def _handle_and_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): + """Handle `and` boolean operations.""" + + logger.debug(f"Handling 'and' operator with {len(expr.values)} operands") + + merge_block = func.append_basic_block(name="and.merge") + false_block = func.append_basic_block(name="and.false") + + incoming_values = [] + + for i, value in enumerate(expr.values): + is_last = i == len(expr.values) - 1 + + # Evaluate current operand + operand_result = eval_expr( + func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if operand_result is None: + logger.error(f"Failed to evaluate operand {i} in 'and' expression") + return None + + operand_val, operand_type = operand_result + + # Convert to boolean if needed + operand_bool = convert_to_bool(builder, operand_val) + current_block = builder.block + + if is_last: + # Last operand: result is this value + builder.branch(merge_block) + incoming_values.append((operand_bool, current_block)) + else: + # Not last: check if true, continue or short-circuit + next_check = func.append_basic_block(name=f"and.check_{i + 1}") + builder.cbranch(operand_bool, next_check, false_block) + builder.position_at_end(next_check) + + # False block: short-circuit with false + builder.position_at_end(false_block) + builder.branch(merge_block) + false_value = ir.Constant(ir.IntType(1), 0) + incoming_values.append((false_value, false_block)) + + # Merge block: phi node + builder.position_at_end(merge_block) + phi = builder.phi(ir.IntType(1), name="and.result") + for val, block in incoming_values: + phi.add_incoming(val, block) + + logger.debug(f"Generated 'and' with {len(incoming_values)} incoming values") + return phi, ir.IntType(1) + + +def _handle_or_op(func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab): + """Handle `or` boolean operations.""" + + logger.debug(f"Handling 'or' operator with {len(expr.values)} operands") + + merge_block = func.append_basic_block(name="or.merge") + true_block = func.append_basic_block(name="or.true") + + incoming_values = [] + + for i, value in enumerate(expr.values): + is_last = i == len(expr.values) - 1 + + # Evaluate current operand + operand_result = eval_expr( + func, None, builder, value, local_sym_tab, map_sym_tab, structs_sym_tab + ) + if operand_result is None: + logger.error(f"Failed to evaluate operand {i} in 'or' expression") + return None + + operand_val, operand_type = operand_result + + # Convert to boolean if needed + operand_bool = convert_to_bool(builder, operand_val) + current_block = builder.block + + if is_last: + # Last operand: result is this value + builder.branch(merge_block) + incoming_values.append((operand_bool, current_block)) + else: + # Not last: check if false, continue or short-circuit + next_check = func.append_basic_block(name=f"or.check_{i + 1}") + builder.cbranch(operand_bool, true_block, next_check) + builder.position_at_end(next_check) + + # True block: short-circuit with true + builder.position_at_end(true_block) + builder.branch(merge_block) + true_value = ir.Constant(ir.IntType(1), 1) + incoming_values.append((true_value, true_block)) + + # Merge block: phi node + builder.position_at_end(merge_block) + phi = builder.phi(ir.IntType(1), name="or.result") + for val, block in incoming_values: + phi.add_incoming(val, block) + + logger.debug(f"Generated 'or' with {len(incoming_values)} incoming values") + return phi, ir.IntType(1) + + +def _handle_boolean_op( + func, + module, + builder, + expr: ast.BoolOp, + local_sym_tab, + map_sym_tab, + structs_sym_tab=None, +): + """Handle `and` and `or` boolean operations.""" + + if isinstance(expr.op, ast.And): + return _handle_and_op( + func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab + ) + elif isinstance(expr.op, ast.Or): + return _handle_or_op( + func, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab + ) + else: + logger.error(f"Unsupported boolean operator: {type(expr.op).__name__}") + return None + + def eval_expr( func, module, @@ -212,6 +403,18 @@ def eval_expr( from pythonbpf.binary_ops import handle_binary_op return handle_binary_op(expr, builder, None, local_sym_tab) + elif isinstance(expr, ast.Compare): + return _handle_compare( + func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab + ) + elif isinstance(expr, ast.UnaryOp): + return _handle_unary_op( + func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab + ) + elif isinstance(expr, ast.BoolOp): + return _handle_boolean_op( + func, module, builder, expr, local_sym_tab, map_sym_tab, structs_sym_tab + ) logger.info("Unsupported expression evaluation") return None diff --git a/pythonbpf/expr/type_normalization.py b/pythonbpf/expr/type_normalization.py new file mode 100644 index 0000000..7a2fb57 --- /dev/null +++ b/pythonbpf/expr/type_normalization.py @@ -0,0 +1,128 @@ +from llvmlite import ir +import logging +import ast + +logger = logging.getLogger(__name__) + +COMPARISON_OPS = { + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Is: "==", + ast.IsNot: "!=", +} + + +def _get_base_type_and_depth(ir_type): + """Get the base type for pointer types.""" + cur_type = ir_type + depth = 0 + while isinstance(cur_type, ir.PointerType): + depth += 1 + cur_type = cur_type.pointee + return cur_type, depth + + +def _deref_to_depth(func, builder, val, target_depth): + """Dereference a pointer to a certain depth.""" + + cur_val = val + cur_type = val.type + + for depth in range(target_depth): + if not isinstance(val.type, ir.PointerType): + logger.error("Cannot dereference further, non-pointer type") + return None + + # dereference with null check + pointee_type = cur_type.pointee + null_check_block = builder.block + not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}") + merge_block = func.append_basic_block(name=f"deref_merge_{depth}") + + null_ptr = ir.Constant(cur_type, None) + is_not_null = builder.icmp_signed("!=", cur_val, null_ptr) + logger.debug(f"Inserted null check for pointer at depth {depth}") + + builder.cbranch(is_not_null, not_null_block, merge_block) + + builder.position_at_end(not_null_block) + dereferenced_val = builder.load(cur_val) + logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}") + builder.branch(merge_block) + + builder.position_at_end(merge_block) + phi = builder.phi(pointee_type, name=f"deref_result_{depth}") + + zero_value = ( + ir.Constant(pointee_type, 0) + if isinstance(pointee_type, ir.IntType) + else ir.Constant(pointee_type, None) + ) + phi.add_incoming(zero_value, null_check_block) + + phi.add_incoming(dereferenced_val, not_null_block) + + # Continue with phi result + cur_val = phi + cur_type = pointee_type + return cur_val + + +def _normalize_types(func, builder, lhs, rhs): + """Normalize types for comparison.""" + + logger.info(f"Normalizing types: {lhs.type} vs {rhs.type}") + if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType): + if lhs.type.width < rhs.type.width: + lhs = builder.sext(lhs, rhs.type) + else: + rhs = builder.sext(rhs, lhs.type) + return lhs, rhs + elif not isinstance(lhs.type, ir.PointerType) and not isinstance( + rhs.type, ir.PointerType + ): + logger.error(f"Type mismatch: {lhs.type} vs {rhs.type}") + return None, None + else: + lhs_base, lhs_depth = _get_base_type_and_depth(lhs.type) + rhs_base, rhs_depth = _get_base_type_and_depth(rhs.type) + if lhs_base == rhs_base: + if lhs_depth < rhs_depth: + rhs = _deref_to_depth(func, builder, rhs, rhs_depth - lhs_depth) + elif rhs_depth < lhs_depth: + lhs = _deref_to_depth(func, builder, lhs, lhs_depth - rhs_depth) + return _normalize_types(func, builder, lhs, rhs) + + +def convert_to_bool(builder, val): + """Convert a value to boolean.""" + if val.type == ir.IntType(1): + return val + if isinstance(val.type, ir.PointerType): + zero = ir.Constant(val.type, None) + else: + zero = ir.Constant(val.type, 0) + return builder.icmp_signed("!=", val, zero) + + +def handle_comparator(func, builder, op, lhs, rhs): + """Handle comparison operations.""" + + if lhs.type != rhs.type: + lhs, rhs = _normalize_types(func, builder, lhs, rhs) + + if lhs is None or rhs is None: + return None + + if type(op) not in COMPARISON_OPS: + logger.error(f"Unsupported comparison operator: {type(op)}") + return None + + predicate = COMPARISON_OPS[type(op)] + result = builder.icmp_signed(predicate, lhs, rhs) + logger.debug(f"Comparison result: {result}") + return result, ir.IntType(1) diff --git a/pythonbpf/functions/functions_pass.py b/pythonbpf/functions/functions_pass.py index 18904ec..7fc3feb 100644 --- a/pythonbpf/functions/functions_pass.py +++ b/pythonbpf/functions/functions_pass.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from pythonbpf.helper import HelperHandlerRegistry, handle_helper_call from pythonbpf.type_deducer import ctypes_to_ir from pythonbpf.binary_ops import handle_binary_op -from pythonbpf.expr_pass import eval_expr, handle_expr +from pythonbpf.expr import eval_expr, handle_expr, convert_to_bool from .return_utils import _handle_none_return, _handle_xdp_return, _is_xdp_name @@ -240,71 +240,13 @@ def handle_assign( logger.info("Unsupported assignment value type") -def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab): - if isinstance(cond, ast.Constant): - if isinstance(cond.value, bool): - return ir.Constant(ir.IntType(1), int(cond.value)) - elif isinstance(cond.value, int): - return ir.Constant(ir.IntType(1), int(bool(cond.value))) - else: - logger.info("Unsupported constant type in condition") - return None - elif isinstance(cond, ast.Name): - if cond.id in local_sym_tab: - var = local_sym_tab[cond.id].var - val = builder.load(var) - if val.type != ir.IntType(1): - # Convert nonzero values to true, zero to false - if isinstance(val.type, ir.PointerType): - # For pointer types, compare with null pointer - zero = ir.Constant(val.type, None) - else: - # For integer types, compare with zero - zero = ir.Constant(val.type, 0) - val = builder.icmp_signed("!=", val, zero) - return val - else: - logger.info(f"Undefined variable {cond.id} in condition") - return None - elif isinstance(cond, ast.Compare): - lhs = eval_expr(func, module, builder, cond.left, local_sym_tab, map_sym_tab)[0] - if len(cond.ops) != 1 or len(cond.comparators) != 1: - logger.info("Unsupported complex comparison") - return None - rhs = eval_expr( - func, module, builder, cond.comparators[0], local_sym_tab, map_sym_tab - )[0] - op = cond.ops[0] - - if lhs.type != rhs.type: - if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType): - # Extend the smaller type to the larger type - if lhs.type.width < rhs.type.width: - lhs = builder.sext(lhs, rhs.type) - elif lhs.type.width > rhs.type.width: - rhs = builder.sext(rhs, lhs.type) - else: - logger.info("Type mismatch in comparison") - return None - - if isinstance(op, ast.Eq): - return builder.icmp_signed("==", lhs, rhs) - elif isinstance(op, ast.NotEq): - return builder.icmp_signed("!=", lhs, rhs) - elif isinstance(op, ast.Lt): - return builder.icmp_signed("<", lhs, rhs) - elif isinstance(op, ast.LtE): - return builder.icmp_signed("<=", lhs, rhs) - elif isinstance(op, ast.Gt): - return builder.icmp_signed(">", lhs, rhs) - elif isinstance(op, ast.GtE): - return builder.icmp_signed(">=", lhs, rhs) - else: - logger.info("Unsupported comparison operator") - return None - else: - logger.info("Unsupported condition expression") - return None +def handle_cond( + func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab=None +): + val = eval_expr( + func, module, builder, cond, local_sym_tab, map_sym_tab, structs_sym_tab + )[0] + return convert_to_bool(builder, val) def handle_if( @@ -320,7 +262,9 @@ def handle_if( else: else_block = None - cond = handle_cond(func, module, builder, stmt.test, local_sym_tab, map_sym_tab) + cond = handle_cond( + func, module, builder, stmt.test, local_sym_tab, map_sym_tab, structs_sym_tab + ) if else_block: builder.cbranch(cond, then_block, else_block) else: diff --git a/pythonbpf/helper/helper_utils.py b/pythonbpf/helper/helper_utils.py index 0da1e5e..68ab52c 100644 --- a/pythonbpf/helper/helper_utils.py +++ b/pythonbpf/helper/helper_utils.py @@ -3,7 +3,7 @@ import logging from collections.abc import Callable from llvmlite import ir -from pythonbpf.expr_pass import eval_expr +from pythonbpf.expr import eval_expr logger = logging.getLogger(__name__) diff --git a/pythonbpf/maps/maps_pass.py b/pythonbpf/maps/maps_pass.py index cc8dfa6..95748a8 100644 --- a/pythonbpf/maps/maps_pass.py +++ b/pythonbpf/maps/maps_pass.py @@ -3,7 +3,7 @@ from logging import Logger from llvmlite import ir from enum import Enum from .maps_utils import MapProcessorRegistry -from ..debuginfo import DebugInfoGenerator +from pythonbpf.debuginfo import DebugInfoGenerator import logging logger: Logger = logging.getLogger(__name__) diff --git a/tests/failing_tests/conditionals/helper_cond.py b/tests/failing_tests/conditionals/helper_cond.py new file mode 100644 index 0000000..8cf5bdb --- /dev/null +++ b/tests/failing_tests/conditionals/helper_cond.py @@ -0,0 +1,34 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + +# NOTE: Decided against fixing this +# as a workaround is assigning the result of lookup to a variable +# and then using that variable in the if statement. +# Might fix in future. + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + last.update(0, 1) + if last.lookup(0) > 0: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/failing_tests/conditionals/struct_ptr.py b/tests/failing_tests/conditionals/struct_ptr.py new file mode 100644 index 0000000..7085f81 --- /dev/null +++ b/tests/failing_tests/conditionals/struct_ptr.py @@ -0,0 +1,34 @@ +from pythonbpf import bpf, struct, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 + +# NOTE: Decided against fixing this +# as one workaround is to just check any field of the struct +# in the if statement. Ugly but works. +# Might fix in future. + + +@bpf +@struct +class data_t: + pid: c_uint64 + ts: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + dat = data_t() + if dat: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/and.py b/tests/passing_tests/conditionals/and.py new file mode 100644 index 0000000..5cb4824 --- /dev/null +++ b/tests/passing_tests/conditionals/and.py @@ -0,0 +1,32 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + last.update(0, 1) + last.update(1, 2) + x = last.lookup(0) + y = last.lookup(1) + if x and y: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/bool.py b/tests/passing_tests/conditionals/bool.py new file mode 100644 index 0000000..341fa46 --- /dev/null +++ b/tests/passing_tests/conditionals/bool.py @@ -0,0 +1,21 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + if True: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/const_binop.py b/tests/passing_tests/conditionals/const_binop.py new file mode 100644 index 0000000..9dffd30 --- /dev/null +++ b/tests/passing_tests/conditionals/const_binop.py @@ -0,0 +1,21 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + if (0 + 1) * 0: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/const_int.py b/tests/passing_tests/conditionals/const_int.py new file mode 100644 index 0000000..47589c8 --- /dev/null +++ b/tests/passing_tests/conditionals/const_int.py @@ -0,0 +1,21 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + if 0: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/map.py b/tests/passing_tests/conditionals/map.py new file mode 100644 index 0000000..fa490a7 --- /dev/null +++ b/tests/passing_tests/conditionals/map.py @@ -0,0 +1,30 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + # last.update(0, 1) + tsp = last.lookup(0) + if tsp: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/map_comp.py b/tests/passing_tests/conditionals/map_comp.py new file mode 100644 index 0000000..2350f06 --- /dev/null +++ b/tests/passing_tests/conditionals/map_comp.py @@ -0,0 +1,30 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + last.update(0, 1) + tsp = last.lookup(0) + if tsp > 0: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/not.py b/tests/passing_tests/conditionals/not.py new file mode 100644 index 0000000..773291e --- /dev/null +++ b/tests/passing_tests/conditionals/not.py @@ -0,0 +1,30 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + # last.update(0, 1) + tsp = last.lookup(0) + if not tsp: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/or.py b/tests/passing_tests/conditionals/or.py new file mode 100644 index 0000000..5626179 --- /dev/null +++ b/tests/passing_tests/conditionals/or.py @@ -0,0 +1,32 @@ +from pythonbpf import bpf, map, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 +from pythonbpf.maps import HashMap + + +@bpf +@map +def last() -> HashMap: + return HashMap(key=c_uint64, value=c_uint64, max_entries=3) + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + last.update(0, 1) + # last.update(1, 2) + x = last.lookup(0) + y = last.lookup(1) + if x or y: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/struct_access.py b/tests/passing_tests/conditionals/struct_access.py new file mode 100644 index 0000000..5267290 --- /dev/null +++ b/tests/passing_tests/conditionals/struct_access.py @@ -0,0 +1,29 @@ +from pythonbpf import bpf, struct, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_uint64 + + +@bpf +@struct +class data_t: + pid: c_uint64 + ts: c_uint64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + dat = data_t() + if dat.ts: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/type_mismatch.py b/tests/passing_tests/conditionals/type_mismatch.py new file mode 100644 index 0000000..1efc5e2 --- /dev/null +++ b/tests/passing_tests/conditionals/type_mismatch.py @@ -0,0 +1,23 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64, c_int32 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + x = 0 + y = c_int32(0) + if x == y: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/var.py b/tests/passing_tests/conditionals/var.py new file mode 100644 index 0000000..449501e --- /dev/null +++ b/tests/passing_tests/conditionals/var.py @@ -0,0 +1,22 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + x = 0 + if x: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/var_binop.py b/tests/passing_tests/conditionals/var_binop.py new file mode 100644 index 0000000..75a2262 --- /dev/null +++ b/tests/passing_tests/conditionals/var_binop.py @@ -0,0 +1,22 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + x = 0 + if x * 1: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/conditionals/var_comp.py b/tests/passing_tests/conditionals/var_comp.py new file mode 100644 index 0000000..4f12f15 --- /dev/null +++ b/tests/passing_tests/conditionals/var_comp.py @@ -0,0 +1,22 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + x = 2 + if x > 3: + print("Hello, World!") + else: + print("Goodbye, World!") + return + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile() diff --git a/tests/passing_tests/return/bool.py b/tests/passing_tests/return/bool.py new file mode 100644 index 0000000..b5627a6 --- /dev/null +++ b/tests/passing_tests/return/bool.py @@ -0,0 +1,18 @@ +from pythonbpf import bpf, section, bpfglobal, compile +from ctypes import c_void_p, c_int64 + + +@bpf +@section("tracepoint/syscalls/sys_enter_execve") +def hello_world(ctx: c_void_p) -> c_int64: + print("Hello, World!") + return True + + +@bpf +@bpfglobal +def LICENSE() -> str: + return "GPL" + + +compile()