From 559f419b4e01dec7f5748afb9d727b8d65851a77 Mon Sep 17 00:00:00 2001 From: NIC619 Date: Tue, 17 Sep 2019 15:42:18 +0800 Subject: [PATCH] Fix stream registration in `accept_stream` --- libp2p/stream_muxer/mplex/mplex.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index d85a7c21..7a5323dc 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -103,6 +103,12 @@ class Mplex(IMuxedConn): self.next_channel_id += 1 return next_id + async def _initialize_stream(self, stream_id: StreamID, name: str) -> MplexStream: + async with self.streams_lock: + stream = MplexStream(name, stream_id, self) + self.streams[stream_id] = stream + return stream + async def open_stream(self) -> IMuxedStream: """ creates a new muxed_stream @@ -112,29 +118,24 @@ class Mplex(IMuxedConn): stream_id = StreamID(channel_id=channel_id, is_initiator=True) # Default stream name is the `channel_id` name = str(channel_id) - async with self.streams_lock: - stream = MplexStream(name, stream_id, self) + stream = await self._initialize_stream(stream_id, name) await self.send_message(HeaderTags.NewStream, name.encode(), stream_id) - # TODO: is there a way to know if the peer accepted the stream? - # then we can safely register the stream - self.streams[stream_id] = stream return stream async def accept_stream(self, stream_id: StreamID, name: str) -> None: """ accepts a muxed stream opened by the other end """ - async with self.streams_lock: - stream = MplexStream(name, stream_id, self) + stream = await self._initialize_stream(stream_id, name) # Perform protocol negotiation for the stream. try: await self.generic_protocol_handler(stream) except MultiselectError: - # TODO: what to do when stream protocol negotiation fail? + # Un-register and reset the stream + del self.streams[stream_id] + await stream.reset() return - self.streams[stream_id] = stream - async def send_message( self, flag: HeaderTags, data: Optional[bytes], stream_id: StreamID ) -> int: