Add condition eval and basic if example - workin

This commit is contained in:
Pragyansh Chaturvedi
2025-09-10 04:05:07 +05:30
parent 357ad7cb99
commit aeb9a45175
3 changed files with 95 additions and 43 deletions

View File

@ -1,6 +1,6 @@
compile: compile:
chmod +x ./tools/compile.py chmod +x ./tools/compile.py
./tools/compile.py ./examples/execve2.py ./tools/compile.py ./examples/execve3.py
install: install:
pip install -e . pip install -e .

View File

@ -1,9 +1,10 @@
from pythonbpf import bpf, map, section, bpfglobal, compile from pythonbpf import bpf, map, section, bpfglobal, compile
from pythonbpf.helpers import bpf_ktime_get_ns from pythonbpf.helpers import ktime
from pythonbpf.maps import HashMap from pythonbpf.maps import HashMap
from ctypes import c_void_p, c_int64, c_int32, c_uint64 from ctypes import c_void_p, c_int64, c_int32, c_uint64
@bpf @bpf
@map @map
def last() -> HashMap: def last() -> HashMap:
@ -24,13 +25,16 @@ def hello_again(ctx: c_void_p) -> c_int64:
print("exited") print("exited")
key = 0 key = 0
tsp = last().lookup(key) tsp = last().lookup(key)
if tsp: # if tsp:
delta = (bpf_ktime_get_ns() - tsp.value) # delta = (bpf_ktime_get_ns() - tsp.value)
if delta < 1000000000: # if delta < 1000000000:
print("execve called within last second") # print("execve called within last second")
last().delete(key) # last().delete(key)
ts = bpf_ktime_get_ns() if True:
last().update(key, ts) print("we prevailed")
# ts = ktime()
ktime()
# last().update(key, ts)
return c_int64(0) return c_int64(0)
@ -39,4 +43,5 @@ def hello_again(ctx: c_void_p) -> c_int64:
def LICENSE() -> str: def LICENSE() -> str:
return "GPL" return "GPL"
compile() compile()

View File

@ -22,7 +22,7 @@ def get_probe_string(func_node):
return "helper" return "helper"
def handle_assign(module, builder, stmt, map_sym_tab, local_sym_tab): def handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab):
"""Handle assignment statements in the function body.""" """Handle assignment statements in the function body."""
if len(stmt.targets) != 1: if len(stmt.targets) != 1:
print("Unsupported multiassignment") print("Unsupported multiassignment")
@ -74,48 +74,113 @@ def handle_assign(module, builder, stmt, map_sym_tab, local_sym_tab):
map_ptr = map_sym_tab[map_name] map_ptr = map_sym_tab[map_name]
if method_name in helper_func_list: if method_name in helper_func_list:
handle_helper_call( handle_helper_call(
rval, module, builder, None, local_sym_tab, map_sym_tab) rval, module, builder, func, local_sym_tab, map_sym_tab)
else: else:
print("Unsupported assignment call structure") print("Unsupported assignment call structure")
else: else:
print("Unsupported assignment call function type") print("Unsupported assignment call function type")
def handle_if_statement(module, builder, stmt, map_sym_tab, local_sym_tab): def handle_expr(func, module, builder, expr, local_sym_tab, map_sym_tab):
pass
def handle_expr(module, builder, expr, local_sym_tab, map_sym_tab):
"""Handle expression statements in the function body.""" """Handle expression statements in the function body."""
call = expr.value call = expr.value
print(f"Handling expression: {ast.dump(call)}")
if isinstance(call, ast.Call): if isinstance(call, ast.Call):
if isinstance(call.func, ast.Name): if isinstance(call.func, ast.Name):
# check for helpers first # check for helpers first
if call.func.id in helper_func_list: if call.func.id in helper_func_list:
handle_helper_call( handle_helper_call(
call, module, builder, None, local_sym_tab, map_sym_tab) call, module, builder, func, local_sym_tab, map_sym_tab)
return return
print("Unsupported expression statement") elif isinstance(call, ast.Name):
if call.id in local_sym_tab:
var = local_sym_tab[call.id]
val = builder.load(var)
return val
else:
print(f"Undefined variable {call.id}")
return None
elif isinstance(call, ast.Constant):
if isinstance(call.value, int):
return ir.Constant(ir.IntType(64), call.value)
elif isinstance(call.value, bool):
return ir.Constant(ir.IntType(1), int(call.value))
else:
print("Unsupported constant type")
return None
else:
print("Unsupported expression statement")
def handle_if(module, builder, stmt, map_sym_tab, local_sym_tab): def handle_cond(func, module, builder, cond, local_sym_tab, map_sym_tab):
if isinstance(cond, ast.Constant):
if isinstance(cond.value, bool):
return ir.Constant(ir.IntType(1), int(cond.value))
elif isinstance(cond.value, int):
return ir.Constant(ir.IntType(1), int(bool(cond.value)))
else:
print("Unsupported constant type in condition")
return None
elif isinstance(cond, ast.Name):
if cond.id in local_sym_tab:
var = local_sym_tab[cond.id]
val = builder.load(var)
return val
else:
print(f"Undefined variable {cond.id} in condition")
return None
else:
print("Unsupported condition expression")
return None
def handle_if(func, module, builder, stmt, map_sym_tab, local_sym_tab):
"""Handle if statements in the function body.""" """Handle if statements in the function body."""
func = builder.block.parent print("Handling if statement")
start = builder.block.parent
then_block = func.append_basic_block(name="if.then") then_block = func.append_basic_block(name="if.then")
merge_block = func.append_basic_block(name="if.end") merge_block = func.append_basic_block(name="if.end")
cond = stmt.test cond = handle_cond(func, module, builder, stmt.test,
local_sym_tab, map_sym_tab)
builder.cbranch(cond, then_block, merge_block) builder.cbranch(cond, then_block, merge_block)
builder.position_at_end(then_block) builder.position_at_end(then_block)
for s in stmt.body: for s in stmt.body:
pass process_stmt(func, module, builder, s,
local_sym_tab, map_sym_tab, False)
if not builder.block.is_terminated: if not builder.block.is_terminated:
builder.branch(merge_block) builder.branch(merge_block)
builder.position_at_end(merge_block) builder.position_at_end(merge_block)
def process_stmt(func, module, builder, stmt, local_sym_tab, map_sym_tab, did_return, ret_type=ir.IntType(64)):
print(f"Processing statement: {ast.dump(stmt)}")
if isinstance(stmt, ast.Expr):
handle_expr(func, module, builder, stmt, local_sym_tab, map_sym_tab)
elif isinstance(stmt, ast.Assign):
handle_assign(func, module, builder, stmt, map_sym_tab, local_sym_tab)
elif isinstance(stmt, ast.If):
handle_if(func, module, builder, stmt, map_sym_tab, local_sym_tab)
elif isinstance(stmt, ast.Return):
if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(32), 0))
did_return = True
elif isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and len(stmt.value.args) == 1 and isinstance(stmt.value.args[0], ast.Constant) and isinstance(stmt.value.args[0].value, int):
call_type = stmt.value.func.id
if ctypes_to_ir(call_type) != ret_type:
raise ValueError("Return type mismatch: expected"
f"{ctypes_to_ir(call_type)}, got {call_type}")
else:
builder.ret(ir.Constant(
ret_type, stmt.value.args[0].value))
did_return = True
else:
print("Unsupported return value")
return did_return
def process_func_body(module, builder, func_node, func, ret_type, map_sym_tab): def process_func_body(module, builder, func_node, func, ret_type, map_sym_tab):
"""Process the body of a bpf function""" """Process the body of a bpf function"""
# TODO: A lot. We just have print -> bpf_trace_printk for now # TODO: A lot. We just have print -> bpf_trace_printk for now
@ -124,27 +189,9 @@ def process_func_body(module, builder, func_node, func, ret_type, map_sym_tab):
local_sym_tab = {} local_sym_tab = {}
for stmt in func_node.body: for stmt in func_node.body:
if isinstance(stmt, ast.Expr): did_return = process_stmt(func, module, builder, stmt, local_sym_tab,
handle_expr(module, builder, stmt, local_sym_tab, map_sym_tab) map_sym_tab, did_return, ret_type)
elif isinstance(stmt, ast.Assign):
handle_assign(module, builder, stmt, map_sym_tab, local_sym_tab)
elif isinstance(stmt, ast.If):
handle_if(module, builder, stmt, map_sym_tab, local_sym_tab)
elif isinstance(stmt, ast.Return):
if stmt.value is None:
builder.ret(ir.Constant(ir.IntType(32), 0))
did_return = True
elif isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and len(stmt.value.args) == 1 and isinstance(stmt.value.args[0], ast.Constant) and isinstance(stmt.value.args[0].value, int):
call_type = stmt.value.func.id
if ctypes_to_ir(call_type) != ret_type:
raise ValueError("Return type mismatch: expected"
f"{ctypes_to_ir(call_type)}, got {call_type}")
else:
builder.ret(ir.Constant(
ret_type, stmt.value.args[0].value))
did_return = True
else:
print("Unsupported return value")
if not did_return: if not did_return:
builder.ret(ir.Constant(ir.IntType(32), 0)) builder.ret(ir.Constant(ir.IntType(32), 0))