From 9355f33da80374957638909d8425d365dc039adf Mon Sep 17 00:00:00 2001 From: Alex Stokes Date: Sat, 24 Aug 2019 19:58:56 +0200 Subject: [PATCH] Add basic test for `secio` Two peers in-memory can create a secure, bidirectional channel --- tests/security/test_secio.py | 120 +++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/security/test_secio.py diff --git a/tests/security/test_secio.py b/tests/security/test_secio.py new file mode 100644 index 00000000..27a00c5e --- /dev/null +++ b/tests/security/test_secio.py @@ -0,0 +1,120 @@ +import asyncio + +import pytest + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.network.connection.raw_connection_interface import IRawConnection +from libp2p.peer.id import ID +from libp2p.security.secio.transport import NONCE_SIZE, create_secure_session + + +class InMemoryConnection(IRawConnection): + def __init__(self, peer, initiator=False): + self.peer = peer + self.recv_queue = asyncio.Queue() + self.send_queue = asyncio.Queue() + self.initiator = initiator + + self.current_msg = None + self.current_position = 0 + + self.closed = False + self.stream_counter = 0 + + @property + def writer(self): + return self + + @property + def reader(self): + return self + + async def write(self, data: bytes) -> None: + if self.closed: + raise Exception("InMemoryConnection is closed for writing") + + await self.send_queue.put(data) + + async def drain(self): + return + + async def readexactly(self, n): + return await self.read(n) + + async def read(self, n: int = -1) -> bytes: + """ + NOTE: have to buffer the current message and juggle packets + off the recv queue to satisfy the semantics of this function. + """ + if self.closed: + raise Exception("InMemoryConnection is closed for reading") + + if not self.current_msg: + self.current_msg = await self.recv_queue.get() + self.current_position = 0 + + if n < 0: + msg = self.current_msg + self.current_msg = None + return msg + + next_msg = self.current_msg[self.current_position : self.current_position + n] + self.current_position += n + if self.current_position == len(self.current_msg): + self.current_msg = None + return next_msg + + def close(self) -> None: + self.closed = True + + def next_stream_id(self) -> int: + self.stream_counter += 1 + return self.stream_counter + + +async def create_pipe(local_conn, remote_conn): + try: + while True: + next_msg = await local_conn.send_queue.get() + await remote_conn.recv_queue.put(next_msg) + except asyncio.CancelledError: + return + + +@pytest.mark.asyncio +async def test_create_secure_session(): + local_nonce = b"\x01" * NONCE_SIZE + local_key_pair = create_new_key_pair(b"a") + local_peer = ID.from_pubkey(local_key_pair.public_key) + + remote_nonce = b"\x02" * NONCE_SIZE + remote_key_pair = create_new_key_pair(b"b") + remote_peer = ID.from_pubkey(remote_key_pair.public_key) + + local_conn = InMemoryConnection(local_peer, initiator=True) + remote_conn = InMemoryConnection(remote_peer) + + local_pipe_task = asyncio.create_task(create_pipe(local_conn, remote_conn)) + remote_pipe_task = asyncio.create_task(create_pipe(remote_conn, local_conn)) + + local_session_builder = create_secure_session( + local_nonce, local_peer, local_key_pair.private_key, local_conn, remote_peer + ) + remote_session_builder = create_secure_session( + remote_nonce, remote_peer, remote_key_pair.private_key, remote_conn + ) + local_secure_conn, remote_secure_conn = await asyncio.gather( + local_session_builder, remote_session_builder + ) + + local_pipe_task.cancel() + remote_pipe_task.cancel() + await local_pipe_task + await remote_pipe_task + + assert local_secure_conn + assert remote_secure_conn + + +if __name__ == "__main__": + asyncio.run(test_create_secure_session())