Merge branch 'main' into async-validators

This commit is contained in:
Manu Sheel Gupta
2025-07-02 10:19:13 -07:00
committed by GitHub
7 changed files with 173 additions and 61 deletions

View File

@ -84,6 +84,8 @@ DEFAULT_MUXER = "YAMUX"
# Multiplexer options # Multiplexer options
MUXER_YAMUX = "YAMUX" MUXER_YAMUX = "YAMUX"
MUXER_MPLEX = "MPLEX" MUXER_MPLEX = "MPLEX"
DEFAULT_NEGOTIATE_TIMEOUT = 5
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None: def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
@ -249,6 +251,7 @@ def new_host(
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None, muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None, listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
enable_mDNS: bool = False, enable_mDNS: bool = False,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> IHost: ) -> IHost:
""" """
Create a new libp2p host based on the given parameters. Create a new libp2p host based on the given parameters.
@ -274,6 +277,6 @@ def new_host(
if disc_opt is not None: if disc_opt is not None:
return RoutedHost(swarm, disc_opt, enable_mDNS) return RoutedHost(swarm, disc_opt, enable_mDNS)
return BasicHost(swarm, enable_mDNS) return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
__version__ = __version("libp2p") __version__ = __version("libp2p")

View File

@ -71,6 +71,7 @@ if TYPE_CHECKING:
logger = logging.getLogger("libp2p.network.basic_host") logger = logging.getLogger("libp2p.network.basic_host")
DEFAULT_NEGOTIATE_TIMEOUT = 5
class BasicHost(IHost): class BasicHost(IHost):
@ -92,10 +93,12 @@ class BasicHost(IHost):
network: INetworkService, network: INetworkService,
enable_mDNS: bool = False, enable_mDNS: bool = False,
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None, default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> None: ) -> None:
self._network = network self._network = network
self._network.set_stream_handler(self._swarm_stream_handler) self._network.set_stream_handler(self._swarm_stream_handler)
self.peerstore = self._network.peerstore self.peerstore = self._network.peerstore
self.negotiate_timeout = negotitate_timeout
# Protocol muxing # Protocol muxing
default_protocols = default_protocols or get_default_protocols(self) default_protocols = default_protocols or get_default_protocols(self)
self.multiselect = Multiselect(dict(default_protocols.items())) self.multiselect = Multiselect(dict(default_protocols.items()))
@ -189,7 +192,10 @@ class BasicHost(IHost):
self.multiselect.add_handler(protocol_id, stream_handler) self.multiselect.add_handler(protocol_id, stream_handler)
async def new_stream( async def new_stream(
self, peer_id: ID, protocol_ids: Sequence[TProtocol] self,
peer_id: ID,
protocol_ids: Sequence[TProtocol],
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> INetStream: ) -> INetStream:
""" """
:param peer_id: peer_id that host is connecting :param peer_id: peer_id that host is connecting
@ -201,7 +207,9 @@ class BasicHost(IHost):
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
try: try:
selected_protocol = await self.multiselect_client.select_one_of( selected_protocol = await self.multiselect_client.select_one_of(
list(protocol_ids), MultiselectCommunicator(net_stream) list(protocol_ids),
MultiselectCommunicator(net_stream),
negotitate_timeout,
) )
except MultiselectClientError as error: except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
@ -211,7 +219,12 @@ class BasicHost(IHost):
net_stream.set_protocol(selected_protocol) net_stream.set_protocol(selected_protocol)
return net_stream return net_stream
async def send_command(self, peer_id: ID, command: str) -> list[str]: async def send_command(
self,
peer_id: ID,
command: str,
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> list[str]:
""" """
Send a multistream-select command to the specified peer and return Send a multistream-select command to the specified peer and return
the response. the response.
@ -225,7 +238,7 @@ class BasicHost(IHost):
try: try:
response = await self.multiselect_client.query_multistream_command( response = await self.multiselect_client.query_multistream_command(
MultiselectCommunicator(new_stream), command MultiselectCommunicator(new_stream), command, response_timeout
) )
except MultiselectClientError as error: except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
@ -264,7 +277,7 @@ class BasicHost(IHost):
# Perform protocol muxing to determine protocol to use # Perform protocol muxing to determine protocol to use
try: try:
protocol, handler = await self.multiselect.negotiate( protocol, handler = await self.multiselect.negotiate(
MultiselectCommunicator(net_stream) MultiselectCommunicator(net_stream), self.negotiate_timeout
) )
except MultiselectError as error: except MultiselectError as error:
peer_id = net_stream.muxed_conn.peer_id peer_id = net_stream.muxed_conn.peer_id

View File

@ -1,3 +1,5 @@
import trio
from libp2p.abc import ( from libp2p.abc import (
IMultiselectCommunicator, IMultiselectCommunicator,
IMultiselectMuxer, IMultiselectMuxer,
@ -14,6 +16,7 @@ from .exceptions import (
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
PROTOCOL_NOT_FOUND_MSG = "na" PROTOCOL_NOT_FOUND_MSG = "na"
DEFAULT_NEGOTIATE_TIMEOUT = 5
class Multiselect(IMultiselectMuxer): class Multiselect(IMultiselectMuxer):
@ -47,47 +50,56 @@ class Multiselect(IMultiselectMuxer):
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent # FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate( async def negotiate(
self, communicator: IMultiselectCommunicator self,
communicator: IMultiselectCommunicator,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> tuple[TProtocol, StreamHandlerFn | None]: ) -> tuple[TProtocol, StreamHandlerFn | None]:
""" """
Negotiate performs protocol selection. Negotiate performs protocol selection.
:param stream: stream to negotiate on :param stream: stream to negotiate on
:param negotiate_timeout: timeout for negotiation
:return: selected protocol name, handler function :return: selected protocol name, handler function
:raise MultiselectError: raised when negotiation failed :raise MultiselectError: raised when negotiation failed
""" """
await self.handshake(communicator) try:
with trio.fail_after(negotiate_timeout):
await self.handshake(communicator)
while True: while True:
try:
command = await communicator.read()
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
if command == "ls":
supported_protocols = [p for p in self.handlers.keys() if p is not None]
response = "\n".join(supported_protocols) + "\n"
try:
await communicator.write(response)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
else:
protocol = TProtocol(command)
if protocol in self.handlers:
try: try:
await communicator.write(protocol) command = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError() from error raise MultiselectError() from error
return protocol, self.handlers[protocol] if command == "ls":
try: supported_protocols = [
await communicator.write(PROTOCOL_NOT_FOUND_MSG) p for p in self.handlers.keys() if p is not None
except MultiselectCommunicatorError as error: ]
raise MultiselectError() from error response = "\n".join(supported_protocols) + "\n"
raise MultiselectError("Negotiation failed: no matching protocol") try:
await communicator.write(response)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
else:
protocol = TProtocol(command)
if protocol in self.handlers:
try:
await communicator.write(protocol)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
return protocol, self.handlers[protocol]
try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
raise MultiselectError("Negotiation failed: no matching protocol")
except trio.TooSlowError:
raise MultiselectError("handshake read timeout")
async def handshake(self, communicator: IMultiselectCommunicator) -> None: async def handshake(self, communicator: IMultiselectCommunicator) -> None:
""" """

View File

@ -2,6 +2,8 @@ from collections.abc import (
Sequence, Sequence,
) )
import trio
from libp2p.abc import ( from libp2p.abc import (
IMultiselectClient, IMultiselectClient,
IMultiselectCommunicator, IMultiselectCommunicator,
@ -17,6 +19,7 @@ from .exceptions import (
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0" MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
PROTOCOL_NOT_FOUND_MSG = "na" PROTOCOL_NOT_FOUND_MSG = "na"
DEFAULT_NEGOTIATE_TIMEOUT = 5
class MultiselectClient(IMultiselectClient): class MultiselectClient(IMultiselectClient):
@ -40,6 +43,7 @@ class MultiselectClient(IMultiselectClient):
try: try:
handshake_contents = await communicator.read() handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error raise MultiselectClientError() from error
@ -47,7 +51,10 @@ class MultiselectClient(IMultiselectClient):
raise MultiselectClientError("multiselect protocol ID mismatch") raise MultiselectClientError("multiselect protocol ID mismatch")
async def select_one_of( async def select_one_of(
self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator self,
protocols: Sequence[TProtocol],
communicator: IMultiselectCommunicator,
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> TProtocol: ) -> TProtocol:
""" """
For each protocol, send message to multiselect selecting protocol and For each protocol, send message to multiselect selecting protocol and
@ -56,22 +63,32 @@ class MultiselectClient(IMultiselectClient):
:param protocol: protocol to select :param protocol: protocol to select
:param communicator: communicator to use to communicate with counterparty :param communicator: communicator to use to communicate with counterparty
:param negotiate_timeout: timeout for negotiation
:return: selected protocol :return: selected protocol
:raise MultiselectClientError: raised when protocol negotiation failed :raise MultiselectClientError: raised when protocol negotiation failed
""" """
await self.handshake(communicator) try:
with trio.fail_after(negotitate_timeout):
await self.handshake(communicator)
for protocol in protocols: for protocol in protocols:
try: try:
selected_protocol = await self.try_select(communicator, protocol) selected_protocol = await self.try_select(
return selected_protocol communicator, protocol
except MultiselectClientError: )
pass return selected_protocol
except MultiselectClientError:
pass
raise MultiselectClientError("protocols not supported") raise MultiselectClientError("protocols not supported")
except trio.TooSlowError:
raise MultiselectClientError("response timed out")
async def query_multistream_command( async def query_multistream_command(
self, communicator: IMultiselectCommunicator, command: str self,
communicator: IMultiselectCommunicator,
command: str,
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> list[str]: ) -> list[str]:
""" """
Send a multistream-select command over the given communicator and return Send a multistream-select command over the given communicator and return
@ -79,26 +96,32 @@ class MultiselectClient(IMultiselectClient):
:param communicator: communicator to use to communicate with counterparty :param communicator: communicator to use to communicate with counterparty
:param command: supported multistream-select command(e.g., ls) :param command: supported multistream-select command(e.g., ls)
:param negotiate_timeout: timeout for negotiation
:raise MultiselectClientError: If the communicator fails to process data. :raise MultiselectClientError: If the communicator fails to process data.
:return: list of strings representing the response from peer. :return: list of strings representing the response from peer.
""" """
await self.handshake(communicator)
if command == "ls":
try:
await communicator.write("ls")
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
else:
raise ValueError("Command not supported")
try: try:
response = await communicator.read() with trio.fail_after(response_timeout):
response_list = response.strip().splitlines() await self.handshake(communicator)
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
return response_list if command == "ls":
try:
await communicator.write("ls")
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
else:
raise ValueError("Command not supported")
try:
response = await communicator.read()
response_list = response.strip().splitlines()
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
return response_list
except trio.TooSlowError:
raise MultiselectClientError("command response timed out")
async def try_select( async def try_select(
self, communicator: IMultiselectCommunicator, protocol: TProtocol self, communicator: IMultiselectCommunicator, protocol: TProtocol
@ -118,6 +141,7 @@ class MultiselectClient(IMultiselectClient):
try: try:
response = await communicator.read() response = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error raise MultiselectClientError() from error

View File

@ -31,9 +31,6 @@ from libp2p.stream_muxer.yamux.yamux import (
Yamux, Yamux,
) )
# FIXME: add negotiate timeout to `MuxerMultistream`
DEFAULT_NEGOTIATE_TIMEOUT = 60
class MuxerMultistream: class MuxerMultistream:
""" """

View File

@ -0,0 +1,4 @@
Add timeout wrappers in:
1. multiselect.py: `negotiate` function
2. multiselect_client.py: `select_one_of` , `query_multistream_command` functions
to prevent indefinite hangs when a remote peer does not respond.

View File

@ -0,0 +1,59 @@
import pytest
import trio
from libp2p.abc import (
IMultiselectCommunicator,
)
from libp2p.custom_types import TProtocol
from libp2p.protocol_muxer.exceptions import (
MultiselectClientError,
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import Multiselect
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
class DummyMultiselectCommunicator(IMultiselectCommunicator):
"""
Dummy MultiSelectCommunicator to test out negotiate timmeout.
"""
def __init__(self) -> None:
return
async def write(self, msg_str: str) -> None:
"""Goes into infinite loop when .write is called"""
await trio.sleep_forever()
async def read(self) -> str:
"""Returns a dummy read"""
return "dummy_read"
@pytest.mark.trio
async def test_select_one_of_timeout():
ECHO = TProtocol("/echo/1.0.0")
communicator = DummyMultiselectCommunicator()
client = MultiselectClient()
with pytest.raises(MultiselectClientError, match="response timed out"):
await client.select_one_of([ECHO], communicator, 2)
@pytest.mark.trio
async def test_query_multistream_command_timeout():
communicator = DummyMultiselectCommunicator()
client = MultiselectClient()
with pytest.raises(MultiselectClientError, match="response timed out"):
await client.query_multistream_command(communicator, "ls", 2)
@pytest.mark.trio
async def test_negotiate_timeout():
communicator = DummyMultiselectCommunicator()
server = Multiselect()
with pytest.raises(MultiselectError, match="handshake read timeout"):
await server.negotiate(communicator, 2)