diff --git a/examples/execve3.py b/examples/execve3.py index 7ec4039..faf524a 100644 --- a/examples/execve3.py +++ b/examples/execve3.py @@ -30,10 +30,10 @@ def hello_again(ctx: c_void_p) -> c_int64: # if delta < 1000000000: # print("execve called within last second") # last().delete(key) - x = True + x = 1 y = False - if x: - if y: + if x > 0: + if x < 2: print("we prevailed") else: print("we did not prevail") diff --git a/pythonbpf/functions_pass.py b/pythonbpf/functions_pass.py index 19a5161..aa91de5 100644 --- a/pythonbpf/functions_pass.py +++ b/pythonbpf/functions_pass.py @@ -101,8 +101,26 @@ def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab): def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab): """Handle expression statements in the function body.""" + print(f"Handling expression: {ast.dump(expr)}") + + if isinstance(expr, ast.Name): + if expr.id in local_sym_tab: + var = local_sym_tab[expr.id] + val = builder.load(var) + return val + else: + print(f"Undefined variable {expr.id}") + return None + elif isinstance(expr, ast.Constant): + if isinstance(expr.value, int): + return ir.Constant(ir.IntType(64), expr.value) + elif isinstance(expr.value, bool): + return ir.Constant(ir.IntType(1), int(expr.value)) + else: + print("Unsupported constant type") + return None + call = expr.value - print(f"Handling expression: {ast.dump(call)}") if isinstance(call, ast.Call): if isinstance(call.func, ast.Name): # check for helpers first @@ -154,6 +172,42 @@ def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab): else: print(f"Undefined variable {cond.id} in condition") return None + elif isinstance(cond, ast.Compare): + lhs = handle_expr(func, module, builder, cond.left, + local_sym_tab, map_sym_tab) + if len(cond.ops) != 1 or len(cond.comparators) != 1: + print("Unsupported complex comparison") + return None + rhs = handle_expr(func, module, builder, + cond.comparators[0], local_sym_tab, map_sym_tab) + op = cond.ops[0] + + if lhs.type != rhs.type: + if isinstance(lhs.type, ir.IntType) and isinstance(rhs.type, ir.IntType): + # Extend the smaller type to the larger type + if lhs.type.width < rhs.type.width: + lhs = builder.sext(lhs, rhs.type) + elif lhs.type.width > rhs.type.width: + rhs = builder.sext(rhs, lhs.type) + else: + print("Type mismatch in comparison") + return None + + if isinstance(op, ast.Eq): + return builder.icmp_signed("==", lhs, rhs) + elif isinstance(op, ast.NotEq): + return builder.icmp_signed("!=", lhs, rhs) + elif isinstance(op, ast.Lt): + return builder.icmp_signed("<", lhs, rhs) + elif isinstance(op, ast.LtE): + return builder.icmp_signed("<=", lhs, rhs) + elif isinstance(op, ast.Gt): + return builder.icmp_signed(">", lhs, rhs) + elif isinstance(op, ast.GtE): + return builder.icmp_signed(">=", lhs, rhs) + else: + print("Unsupported comparison operator") + return None else: print("Unsupported condition expression") return None