diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index c8919c23..c1b42c58 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -23,7 +23,8 @@ if TYPE_CHECKING: """ -Reference: https://github.com/libp2p/go-libp2p-swarm/blob/04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go +Reference: https://github.com/libp2p/go-libp2p-swarm/blob/ +04c86bbdafd390651cb2ee14e334f7caeedad722/swarm_conn.go """ @@ -43,6 +44,21 @@ class SwarmConn(INetConn): self.streams = set() self.event_closed = trio.Event() self.event_started = trio.Event() + # Provide back-references/hooks expected by NetStream + try: + setattr(self.muxed_conn, "swarm", self.swarm) + + # NetStream expects an awaitable remove_stream hook + async def _remove_stream_hook(stream: NetStream) -> None: + self.remove_stream(stream) + + setattr(self.muxed_conn, "remove_stream", _remove_stream_hook) + except Exception as e: + logging.warning( + f"Failed to set optional conveniences on muxed_conn " + f"for peer {muxed_conn.peer_id}: {e}" + ) + # optional conveniences if hasattr(muxed_conn, "on_close"): logging.debug(f"Setting on_close for peer {muxed_conn.peer_id}") setattr(muxed_conn, "on_close", self._on_muxed_conn_closed) diff --git a/libp2p/network/swarm.py b/libp2p/network/swarm.py index 706d649a..0a1ae1cd 100644 --- a/libp2p/network/swarm.py +++ b/libp2p/network/swarm.py @@ -1,3 +1,7 @@ +from collections.abc import ( + Awaitable, + Callable, +) import logging from multiaddr import ( @@ -411,7 +415,15 @@ class Swarm(Service, INetworkService): nursery.start_soon(notifee.listen, self, multiaddr) async def notify_closed_stream(self, stream: INetStream) -> None: - raise NotImplementedError + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifee.closed_stream, self, stream) async def notify_listen_close(self, multiaddr: Multiaddr) -> None: raise NotImplementedError + + # Generic notifier used by NetStream._notify_closed + async def notify_all(self, notifier: Callable[[INotifee], Awaitable[None]]) -> None: + async with trio.open_nursery() as nursery: + for notifee in self.notifees: + nursery.start_soon(notifier, notifee) diff --git a/newsfragments/826.feature.rst b/newsfragments/826.feature.rst new file mode 100644 index 00000000..face9786 --- /dev/null +++ b/newsfragments/826.feature.rst @@ -0,0 +1,6 @@ +Implement closed_stream notification in MyNotifee + +- Add notify_closed_stream method to swarm notification system for proper stream lifecycle management +- Integrate remove_stream hook in SwarmConn to enable stream closure notifications +- Add comprehensive tests for closed_stream functionality in test_notify.py +- Enable stream lifecycle integration for proper cleanup and resource management diff --git a/tests/core/network/test_notify.py b/tests/core/network/test_notify.py index 98caaf86..b19dd961 100644 --- a/tests/core/network/test_notify.py +++ b/tests/core/network/test_notify.py @@ -44,8 +44,11 @@ class MyNotifee(INotifee): self.events.append(Event.OpenedStream) async def closed_stream(self, network: INetwork, stream: INetStream) -> None: - # TODO: It is not implemented yet. - pass + if network is None: + raise ValueError("network parameter cannot be None") + if stream is None: + raise ValueError("stream parameter cannot be None") + self.events.append(Event.ClosedStream) async def connected(self, network: INetwork, conn: INetConn) -> None: self.events.append(Event.Connected) @@ -103,28 +106,20 @@ async def test_notify(security_protocol): # Wait for events assert await wait_for_event(events_0_0, Event.Connected, 1.0) assert await wait_for_event(events_0_0, Event.OpenedStream, 1.0) - # assert await wait_for_event( - # events_0_0, Event.ClosedStream, 1.0 - # ) # Not implemented + assert await wait_for_event(events_0_0, Event.ClosedStream, 1.0) assert await wait_for_event(events_0_0, Event.Disconnected, 1.0) assert await wait_for_event(events_0_1, Event.Connected, 1.0) assert await wait_for_event(events_0_1, Event.OpenedStream, 1.0) - # assert await wait_for_event( - # events_0_1, Event.ClosedStream, 1.0 - # ) # Not implemented + assert await wait_for_event(events_0_1, Event.ClosedStream, 1.0) assert await wait_for_event(events_0_1, Event.Disconnected, 1.0) assert await wait_for_event(events_1_0, Event.Connected, 1.0) assert await wait_for_event(events_1_0, Event.OpenedStream, 1.0) - # assert await wait_for_event( - # events_1_0, Event.ClosedStream, 1.0 - # ) # Not implemented + assert await wait_for_event(events_1_0, Event.ClosedStream, 1.0) assert await wait_for_event(events_1_0, Event.Disconnected, 1.0) assert await wait_for_event(events_1_1, Event.Connected, 1.0) assert await wait_for_event(events_1_1, Event.OpenedStream, 1.0) - # assert await wait_for_event( - # events_1_1, Event.ClosedStream, 1.0 - # ) # Not implemented + assert await wait_for_event(events_1_1, Event.ClosedStream, 1.0) assert await wait_for_event(events_1_1, Event.Disconnected, 1.0)