diff --git a/libp2p/security/security_multistream.py b/libp2p/security/security_multistream.py index 81db4cf2..94d29348 100644 --- a/libp2p/security/security_multistream.py +++ b/libp2p/security/security_multistream.py @@ -40,7 +40,7 @@ class SecurityMultistream(ABC): """ # Select a secure transport - transport = await self.select_transport(conn, True) + transport = await self.select_transport(conn, False) # Create secured connection secure_conn = await transport.secure_inbound(conn) @@ -81,7 +81,6 @@ class SecurityMultistream(ABC): protocol = await self.multiselect_client.select_one_of(list(self.transports.keys()), conn) else: # Select protocol if non-initiator - protocol = await self.multiselect.negotiate(conn) - + protocol, _ = await self.multiselect.negotiate(conn) # Return transport from protocol return self.transports[protocol] diff --git a/tests/security/test_security_multistream.py b/tests/security/test_security_multistream.py index aae8ce3a..1fcc8bca 100644 --- a/tests/security/test_security_multistream.py +++ b/tests/security/test_security_multistream.py @@ -32,8 +32,8 @@ async def perform_simple_test(assertion_func, transports_for_initiator, transpor # TODO: implement -- note we need to introduce the notion of communicating over a raw connection # for testing, we do NOT want to communicate over a stream so we can't just create two nodes # and use their conn because our mplex will internally relay messages to a stream - sec_opt1 = dict((str(i), transport) for i, transport in enumerate(transports_for_initiator)) - sec_opt2 = dict((str(i), transport) for i, transport in enumerate(transports_for_noninitiator)) + sec_opt1 = transports_for_initiator + sec_opt2 = transports_for_noninitiator node1 = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"], sec_opt=sec_opt1) node2 = await new_node(transport_opt=["/ip4/127.0.0.1/tcp/0"], sec_opt=sec_opt2) @@ -62,8 +62,8 @@ async def perform_simple_test(assertion_func, transports_for_initiator, transpor @pytest.mark.asyncio async def test_single_insecure_security_transport_succeeds(): - transports_for_initiator = [InsecureTransport("foo")] - transports_for_noninitiator = [InsecureTransport("foo")] + transports_for_initiator = {"foo": InsecureTransport("foo")} + transports_for_noninitiator = {"foo": InsecureTransport("foo")} def assertion_func(details): assert details["id"] == "foo" @@ -73,8 +73,8 @@ async def test_single_insecure_security_transport_succeeds(): @pytest.mark.asyncio async def test_single_simple_test_security_transport_succeeds(): - transports_for_initiator = [SimpleSecurityTransport("tacos")] - transports_for_noninitiator = [SimpleSecurityTransport("tacos")] + transports_for_initiator = {"tacos": SimpleSecurityTransport("tacos")} + transports_for_noninitiator = {"tacos": SimpleSecurityTransport("tacos")} def assertion_func(details): assert details["key_phrase"] == "tacos" @@ -82,3 +82,15 @@ async def test_single_simple_test_security_transport_succeeds(): await perform_simple_test(assertion_func, transports_for_initiator, transports_for_noninitiator) +@pytest.mark.asyncio +async def test_two_simple_test_security_transport_for_initiator_succeeds(): + transports_for_initiator = {"tacos": SimpleSecurityTransport("tacos"), + "shleep": SimpleSecurityTransport("shleep")} + transports_for_noninitiator = {"shleep": SimpleSecurityTransport("shleep")} + + def assertion_func(details): + assert details["key_phrase"] == "shleep" + + await perform_simple_test(assertion_func, + transports_for_initiator, transports_for_noninitiator) +