From fabf2cefc4e88a85d6f3943c0c41c60a572ec69b Mon Sep 17 00:00:00 2001 From: acul71 Date: Thu, 13 Mar 2025 14:08:13 +0100 Subject: [PATCH] feat: implement get_remote_address via delegation pattern --- libp2p/io/abc.py | 18 +++++++++++++++++- libp2p/io/trio.py | 11 +++++++++++ libp2p/network/connection/raw_connection.py | 8 ++++++++ libp2p/network/stream/net_stream.py | 4 ++++ libp2p/security/insecure/transport.py | 16 ++++++++++++++++ libp2p/security/secure_session.py | 7 +++++++ libp2p/stream_muxer/mplex/mplex.py | 4 ++++ libp2p/stream_muxer/mplex/mplex_stream.py | 5 +++++ 8 files changed, 72 insertions(+), 1 deletion(-) diff --git a/libp2p/io/abc.py b/libp2p/io/abc.py index d87da7af..75125fd8 100644 --- a/libp2p/io/abc.py +++ b/libp2p/io/abc.py @@ -2,6 +2,9 @@ from abc import ( ABC, abstractmethod, ) +from typing import ( + Optional, +) class Closer(ABC): @@ -35,7 +38,14 @@ class ReadWriter(Reader, Writer): class ReadWriteCloser(Reader, Writer, Closer): - pass + @abstractmethod + def get_remote_address(self) -> Optional[tuple[str, int]]: + """ + Return the remote address of the connected peer. + + :return: A tuple of (host, port) or None if not available + """ + ... class MsgReader(ABC): @@ -66,3 +76,9 @@ class Encrypter(ABC): class EncryptedMsgReadWriter(MsgReadWriteCloser, Encrypter): """Read/write message with encryption/decryption.""" + + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Get remote address if supported by the underlying connection.""" + if hasattr(self, "conn") and hasattr(self.conn, "get_remote_address"): + return self.conn.get_remote_address() + return None diff --git a/libp2p/io/trio.py b/libp2p/io/trio.py index 3998dbef..f0301b90 100644 --- a/libp2p/io/trio.py +++ b/libp2p/io/trio.py @@ -1,4 +1,7 @@ import logging +from typing import ( + Optional, +) import trio @@ -42,3 +45,11 @@ class TrioTCPStream(ReadWriteCloser): async def close(self) -> None: await self.stream.aclose() + + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Return the remote address as (host, port) tuple.""" + try: + return self.stream.socket.getpeername() + except (AttributeError, OSError) as e: + logger.error("Error getting remote address: %s", e) + return None diff --git a/libp2p/network/connection/raw_connection.py b/libp2p/network/connection/raw_connection.py index fc2ea61b..2c6dd5d7 100644 --- a/libp2p/network/connection/raw_connection.py +++ b/libp2p/network/connection/raw_connection.py @@ -1,3 +1,7 @@ +from typing import ( + Optional, +) + from libp2p.abc import ( IRawConnection, ) @@ -42,3 +46,7 @@ class RawConnection(IRawConnection): async def close(self) -> None: await self.stream.close() + + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Delegate to the underlying stream's get_remote_address method.""" + return self.stream.get_remote_address() diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index 5dc053c4..694b302b 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -78,6 +78,10 @@ class NetStream(INetStream): async def reset(self) -> None: await self.muxed_stream.reset() + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Delegate to the underlying muxed stream.""" + return self.muxed_stream.get_remote_address() + # TODO: `remove`: Called by close and write when the stream is in specific states. # It notifies `ClosedStream` after `SwarmConn.remove_stream` is called. # Reference: https://github.com/libp2p/go-libp2p-swarm/blob/99831444e78c8f23c9335c17d8f7c700ba25ca14/swarm_stream.go # noqa: E501 diff --git a/libp2p/security/insecure/transport.py b/libp2p/security/insecure/transport.py index 8a0e3939..4666cc78 100644 --- a/libp2p/security/insecure/transport.py +++ b/libp2p/security/insecure/transport.py @@ -1,3 +1,7 @@ +from typing import ( + Optional, +) + from libp2p.abc import ( IRawConnection, ISecureConn, @@ -73,6 +77,12 @@ class InsecureSession(BaseSession): is_initiator=is_initiator, ) self.conn = conn + # Cache the remote address to avoid repeated lookups + # through the delegation chain + try: + self.remote_peer_addr = conn.get_remote_address() + except AttributeError: + self.remote_peer_addr = None async def write(self, data: bytes) -> None: await self.conn.write(data) @@ -83,6 +93,12 @@ class InsecureSession(BaseSession): async def close(self) -> None: await self.conn.close() + def get_remote_address(self) -> Optional[tuple[str, int]]: + """ + Delegate to the underlying connection's get_remote_address method. + """ + return self.conn.get_remote_address() + async def run_handshake( local_peer: ID, diff --git a/libp2p/security/secure_session.py b/libp2p/security/secure_session.py index 7c727619..7551bfee 100644 --- a/libp2p/security/secure_session.py +++ b/libp2p/security/secure_session.py @@ -1,4 +1,7 @@ import io +from typing import ( + Optional, +) from libp2p.crypto.keys import ( PrivateKey, @@ -41,6 +44,10 @@ class SecureSession(BaseSession): self._reset_internal_buffer() + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Delegate to the underlying connection's get_remote_address method.""" + return self.conn.get_remote_address() + def _reset_internal_buffer(self) -> None: self.buf = io.BytesIO() self.low_watermark = 0 diff --git a/libp2p/stream_muxer/mplex/mplex.py b/libp2p/stream_muxer/mplex/mplex.py index 332a84ae..e21e0768 100644 --- a/libp2p/stream_muxer/mplex/mplex.py +++ b/libp2p/stream_muxer/mplex/mplex.py @@ -365,3 +365,7 @@ class Mplex(IMuxedConn): await send_channel.aclose() self.event_closed.set() await self.new_stream_send_channel.aclose() + + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Delegate to the underlying Mplex connection's secured_conn.""" + return self.secured_conn.get_remote_address() diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 3026b824..9b876a55 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,5 +1,6 @@ from typing import ( TYPE_CHECKING, + Optional, ) import trio @@ -252,3 +253,7 @@ class MplexStream(IMuxedStream): """ self.write_deadline = ttl return True + + def get_remote_address(self) -> Optional[tuple[str, int]]: + """Delegate to the parent Mplex connection.""" + return self.muxed_conn.get_remote_address()