mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into async-validators
This commit is contained in:
@ -84,6 +84,8 @@ DEFAULT_MUXER = "YAMUX"
|
|||||||
# Multiplexer options
|
# Multiplexer options
|
||||||
MUXER_YAMUX = "YAMUX"
|
MUXER_YAMUX = "YAMUX"
|
||||||
MUXER_MPLEX = "MPLEX"
|
MUXER_MPLEX = "MPLEX"
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
def set_default_muxer(muxer_name: Literal["YAMUX", "MPLEX"]) -> None:
|
||||||
@ -249,6 +251,7 @@ def new_host(
|
|||||||
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
muxer_preference: Literal["YAMUX", "MPLEX"] | None = None,
|
||||||
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
listen_addrs: Sequence[multiaddr.Multiaddr] | None = None,
|
||||||
enable_mDNS: bool = False,
|
enable_mDNS: bool = False,
|
||||||
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> IHost:
|
) -> IHost:
|
||||||
"""
|
"""
|
||||||
Create a new libp2p host based on the given parameters.
|
Create a new libp2p host based on the given parameters.
|
||||||
@ -274,6 +277,6 @@ def new_host(
|
|||||||
|
|
||||||
if disc_opt is not None:
|
if disc_opt is not None:
|
||||||
return RoutedHost(swarm, disc_opt, enable_mDNS)
|
return RoutedHost(swarm, disc_opt, enable_mDNS)
|
||||||
return BasicHost(swarm, enable_mDNS)
|
return BasicHost(network=swarm,enable_mDNS=enable_mDNS , negotitate_timeout=negotiate_timeout)
|
||||||
|
|
||||||
__version__ = __version("libp2p")
|
__version__ = __version("libp2p")
|
||||||
|
|||||||
@ -71,6 +71,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.network.basic_host")
|
logger = logging.getLogger("libp2p.network.basic_host")
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||||
|
|
||||||
|
|
||||||
class BasicHost(IHost):
|
class BasicHost(IHost):
|
||||||
@ -92,10 +93,12 @@ class BasicHost(IHost):
|
|||||||
network: INetworkService,
|
network: INetworkService,
|
||||||
enable_mDNS: bool = False,
|
enable_mDNS: bool = False,
|
||||||
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
|
default_protocols: Optional["OrderedDict[TProtocol, StreamHandlerFn]"] = None,
|
||||||
|
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._network = network
|
self._network = network
|
||||||
self._network.set_stream_handler(self._swarm_stream_handler)
|
self._network.set_stream_handler(self._swarm_stream_handler)
|
||||||
self.peerstore = self._network.peerstore
|
self.peerstore = self._network.peerstore
|
||||||
|
self.negotiate_timeout = negotitate_timeout
|
||||||
# Protocol muxing
|
# Protocol muxing
|
||||||
default_protocols = default_protocols or get_default_protocols(self)
|
default_protocols = default_protocols or get_default_protocols(self)
|
||||||
self.multiselect = Multiselect(dict(default_protocols.items()))
|
self.multiselect = Multiselect(dict(default_protocols.items()))
|
||||||
@ -189,7 +192,10 @@ class BasicHost(IHost):
|
|||||||
self.multiselect.add_handler(protocol_id, stream_handler)
|
self.multiselect.add_handler(protocol_id, stream_handler)
|
||||||
|
|
||||||
async def new_stream(
|
async def new_stream(
|
||||||
self, peer_id: ID, protocol_ids: Sequence[TProtocol]
|
self,
|
||||||
|
peer_id: ID,
|
||||||
|
protocol_ids: Sequence[TProtocol],
|
||||||
|
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> INetStream:
|
) -> INetStream:
|
||||||
"""
|
"""
|
||||||
:param peer_id: peer_id that host is connecting
|
:param peer_id: peer_id that host is connecting
|
||||||
@ -201,7 +207,9 @@ class BasicHost(IHost):
|
|||||||
# Perform protocol muxing to determine protocol to use
|
# Perform protocol muxing to determine protocol to use
|
||||||
try:
|
try:
|
||||||
selected_protocol = await self.multiselect_client.select_one_of(
|
selected_protocol = await self.multiselect_client.select_one_of(
|
||||||
list(protocol_ids), MultiselectCommunicator(net_stream)
|
list(protocol_ids),
|
||||||
|
MultiselectCommunicator(net_stream),
|
||||||
|
negotitate_timeout,
|
||||||
)
|
)
|
||||||
except MultiselectClientError as error:
|
except MultiselectClientError as error:
|
||||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||||
@ -211,7 +219,12 @@ class BasicHost(IHost):
|
|||||||
net_stream.set_protocol(selected_protocol)
|
net_stream.set_protocol(selected_protocol)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
|
||||||
async def send_command(self, peer_id: ID, command: str) -> list[str]:
|
async def send_command(
|
||||||
|
self,
|
||||||
|
peer_id: ID,
|
||||||
|
command: str,
|
||||||
|
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Send a multistream-select command to the specified peer and return
|
Send a multistream-select command to the specified peer and return
|
||||||
the response.
|
the response.
|
||||||
@ -225,7 +238,7 @@ class BasicHost(IHost):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.multiselect_client.query_multistream_command(
|
response = await self.multiselect_client.query_multistream_command(
|
||||||
MultiselectCommunicator(new_stream), command
|
MultiselectCommunicator(new_stream), command, response_timeout
|
||||||
)
|
)
|
||||||
except MultiselectClientError as error:
|
except MultiselectClientError as error:
|
||||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||||
@ -264,7 +277,7 @@ class BasicHost(IHost):
|
|||||||
# Perform protocol muxing to determine protocol to use
|
# Perform protocol muxing to determine protocol to use
|
||||||
try:
|
try:
|
||||||
protocol, handler = await self.multiselect.negotiate(
|
protocol, handler = await self.multiselect.negotiate(
|
||||||
MultiselectCommunicator(net_stream)
|
MultiselectCommunicator(net_stream), self.negotiate_timeout
|
||||||
)
|
)
|
||||||
except MultiselectError as error:
|
except MultiselectError as error:
|
||||||
peer_id = net_stream.muxed_conn.peer_id
|
peer_id = net_stream.muxed_conn.peer_id
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.abc import (
|
from libp2p.abc import (
|
||||||
IMultiselectCommunicator,
|
IMultiselectCommunicator,
|
||||||
IMultiselectMuxer,
|
IMultiselectMuxer,
|
||||||
@ -14,6 +16,7 @@ from .exceptions import (
|
|||||||
|
|
||||||
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
||||||
PROTOCOL_NOT_FOUND_MSG = "na"
|
PROTOCOL_NOT_FOUND_MSG = "na"
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||||
|
|
||||||
|
|
||||||
class Multiselect(IMultiselectMuxer):
|
class Multiselect(IMultiselectMuxer):
|
||||||
@ -47,47 +50,56 @@ class Multiselect(IMultiselectMuxer):
|
|||||||
|
|
||||||
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
|
||||||
async def negotiate(
|
async def negotiate(
|
||||||
self, communicator: IMultiselectCommunicator
|
self,
|
||||||
|
communicator: IMultiselectCommunicator,
|
||||||
|
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
) -> tuple[TProtocol, StreamHandlerFn | None]:
|
||||||
"""
|
"""
|
||||||
Negotiate performs protocol selection.
|
Negotiate performs protocol selection.
|
||||||
|
|
||||||
:param stream: stream to negotiate on
|
:param stream: stream to negotiate on
|
||||||
|
:param negotiate_timeout: timeout for negotiation
|
||||||
:return: selected protocol name, handler function
|
:return: selected protocol name, handler function
|
||||||
:raise MultiselectError: raised when negotiation failed
|
:raise MultiselectError: raised when negotiation failed
|
||||||
"""
|
"""
|
||||||
await self.handshake(communicator)
|
try:
|
||||||
|
with trio.fail_after(negotiate_timeout):
|
||||||
|
await self.handshake(communicator)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
|
||||||
command = await communicator.read()
|
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectError() from error
|
|
||||||
|
|
||||||
if command == "ls":
|
|
||||||
supported_protocols = [p for p in self.handlers.keys() if p is not None]
|
|
||||||
response = "\n".join(supported_protocols) + "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
await communicator.write(response)
|
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectError() from error
|
|
||||||
|
|
||||||
else:
|
|
||||||
protocol = TProtocol(command)
|
|
||||||
if protocol in self.handlers:
|
|
||||||
try:
|
try:
|
||||||
await communicator.write(protocol)
|
command = await communicator.read()
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectError() from error
|
raise MultiselectError() from error
|
||||||
|
|
||||||
return protocol, self.handlers[protocol]
|
if command == "ls":
|
||||||
try:
|
supported_protocols = [
|
||||||
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
p for p in self.handlers.keys() if p is not None
|
||||||
except MultiselectCommunicatorError as error:
|
]
|
||||||
raise MultiselectError() from error
|
response = "\n".join(supported_protocols) + "\n"
|
||||||
|
|
||||||
raise MultiselectError("Negotiation failed: no matching protocol")
|
try:
|
||||||
|
await communicator.write(response)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError() from error
|
||||||
|
|
||||||
|
else:
|
||||||
|
protocol = TProtocol(command)
|
||||||
|
if protocol in self.handlers:
|
||||||
|
try:
|
||||||
|
await communicator.write(protocol)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError() from error
|
||||||
|
|
||||||
|
return protocol, self.handlers[protocol]
|
||||||
|
try:
|
||||||
|
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectError() from error
|
||||||
|
|
||||||
|
raise MultiselectError("Negotiation failed: no matching protocol")
|
||||||
|
except trio.TooSlowError:
|
||||||
|
raise MultiselectError("handshake read timeout")
|
||||||
|
|
||||||
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
async def handshake(self, communicator: IMultiselectCommunicator) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,6 +2,8 @@ from collections.abc import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.abc import (
|
from libp2p.abc import (
|
||||||
IMultiselectClient,
|
IMultiselectClient,
|
||||||
IMultiselectCommunicator,
|
IMultiselectCommunicator,
|
||||||
@ -17,6 +19,7 @@ from .exceptions import (
|
|||||||
|
|
||||||
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
MULTISELECT_PROTOCOL_ID = "/multistream/1.0.0"
|
||||||
PROTOCOL_NOT_FOUND_MSG = "na"
|
PROTOCOL_NOT_FOUND_MSG = "na"
|
||||||
|
DEFAULT_NEGOTIATE_TIMEOUT = 5
|
||||||
|
|
||||||
|
|
||||||
class MultiselectClient(IMultiselectClient):
|
class MultiselectClient(IMultiselectClient):
|
||||||
@ -40,6 +43,7 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
handshake_contents = await communicator.read()
|
handshake_contents = await communicator.read()
|
||||||
|
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
@ -47,7 +51,10 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
raise MultiselectClientError("multiselect protocol ID mismatch")
|
raise MultiselectClientError("multiselect protocol ID mismatch")
|
||||||
|
|
||||||
async def select_one_of(
|
async def select_one_of(
|
||||||
self, protocols: Sequence[TProtocol], communicator: IMultiselectCommunicator
|
self,
|
||||||
|
protocols: Sequence[TProtocol],
|
||||||
|
communicator: IMultiselectCommunicator,
|
||||||
|
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> TProtocol:
|
) -> TProtocol:
|
||||||
"""
|
"""
|
||||||
For each protocol, send message to multiselect selecting protocol and
|
For each protocol, send message to multiselect selecting protocol and
|
||||||
@ -56,22 +63,32 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
|
|
||||||
:param protocol: protocol to select
|
:param protocol: protocol to select
|
||||||
:param communicator: communicator to use to communicate with counterparty
|
:param communicator: communicator to use to communicate with counterparty
|
||||||
|
:param negotiate_timeout: timeout for negotiation
|
||||||
:return: selected protocol
|
:return: selected protocol
|
||||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||||
"""
|
"""
|
||||||
await self.handshake(communicator)
|
try:
|
||||||
|
with trio.fail_after(negotitate_timeout):
|
||||||
|
await self.handshake(communicator)
|
||||||
|
|
||||||
for protocol in protocols:
|
for protocol in protocols:
|
||||||
try:
|
try:
|
||||||
selected_protocol = await self.try_select(communicator, protocol)
|
selected_protocol = await self.try_select(
|
||||||
return selected_protocol
|
communicator, protocol
|
||||||
except MultiselectClientError:
|
)
|
||||||
pass
|
return selected_protocol
|
||||||
|
except MultiselectClientError:
|
||||||
|
pass
|
||||||
|
|
||||||
raise MultiselectClientError("protocols not supported")
|
raise MultiselectClientError("protocols not supported")
|
||||||
|
except trio.TooSlowError:
|
||||||
|
raise MultiselectClientError("response timed out")
|
||||||
|
|
||||||
async def query_multistream_command(
|
async def query_multistream_command(
|
||||||
self, communicator: IMultiselectCommunicator, command: str
|
self,
|
||||||
|
communicator: IMultiselectCommunicator,
|
||||||
|
command: str,
|
||||||
|
response_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Send a multistream-select command over the given communicator and return
|
Send a multistream-select command over the given communicator and return
|
||||||
@ -79,26 +96,32 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
|
|
||||||
:param communicator: communicator to use to communicate with counterparty
|
:param communicator: communicator to use to communicate with counterparty
|
||||||
:param command: supported multistream-select command(e.g., ls)
|
:param command: supported multistream-select command(e.g., ls)
|
||||||
|
:param negotiate_timeout: timeout for negotiation
|
||||||
:raise MultiselectClientError: If the communicator fails to process data.
|
:raise MultiselectClientError: If the communicator fails to process data.
|
||||||
:return: list of strings representing the response from peer.
|
:return: list of strings representing the response from peer.
|
||||||
"""
|
"""
|
||||||
await self.handshake(communicator)
|
|
||||||
|
|
||||||
if command == "ls":
|
|
||||||
try:
|
|
||||||
await communicator.write("ls")
|
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectClientError() from error
|
|
||||||
else:
|
|
||||||
raise ValueError("Command not supported")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await communicator.read()
|
with trio.fail_after(response_timeout):
|
||||||
response_list = response.strip().splitlines()
|
await self.handshake(communicator)
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectClientError() from error
|
|
||||||
|
|
||||||
return response_list
|
if command == "ls":
|
||||||
|
try:
|
||||||
|
await communicator.write("ls")
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectClientError() from error
|
||||||
|
else:
|
||||||
|
raise ValueError("Command not supported")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await communicator.read()
|
||||||
|
response_list = response.strip().splitlines()
|
||||||
|
|
||||||
|
except MultiselectCommunicatorError as error:
|
||||||
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
|
return response_list
|
||||||
|
except trio.TooSlowError:
|
||||||
|
raise MultiselectClientError("command response timed out")
|
||||||
|
|
||||||
async def try_select(
|
async def try_select(
|
||||||
self, communicator: IMultiselectCommunicator, protocol: TProtocol
|
self, communicator: IMultiselectCommunicator, protocol: TProtocol
|
||||||
@ -118,6 +141,7 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response = await communicator.read()
|
response = await communicator.read()
|
||||||
|
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
|
|||||||
@ -31,9 +31,6 @@ from libp2p.stream_muxer.yamux.yamux import (
|
|||||||
Yamux,
|
Yamux,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: add negotiate timeout to `MuxerMultistream`
|
|
||||||
DEFAULT_NEGOTIATE_TIMEOUT = 60
|
|
||||||
|
|
||||||
|
|
||||||
class MuxerMultistream:
|
class MuxerMultistream:
|
||||||
"""
|
"""
|
||||||
|
|||||||
4
newsfragments/696.bugfix.rst
Normal file
4
newsfragments/696.bugfix.rst
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
Add timeout wrappers in:
|
||||||
|
1. multiselect.py: `negotiate` function
|
||||||
|
2. multiselect_client.py: `select_one_of` , `query_multistream_command` functions
|
||||||
|
to prevent indefinite hangs when a remote peer does not respond.
|
||||||
59
tests/core/protocol_muxer/test_negotiate_timeout.py
Normal file
59
tests/core/protocol_muxer/test_negotiate_timeout.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import pytest
|
||||||
|
import trio
|
||||||
|
|
||||||
|
from libp2p.abc import (
|
||||||
|
IMultiselectCommunicator,
|
||||||
|
)
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
from libp2p.protocol_muxer.exceptions import (
|
||||||
|
MultiselectClientError,
|
||||||
|
MultiselectError,
|
||||||
|
)
|
||||||
|
from libp2p.protocol_muxer.multiselect import Multiselect
|
||||||
|
from libp2p.protocol_muxer.multiselect_client import MultiselectClient
|
||||||
|
|
||||||
|
|
||||||
|
class DummyMultiselectCommunicator(IMultiselectCommunicator):
|
||||||
|
"""
|
||||||
|
Dummy MultiSelectCommunicator to test out negotiate timmeout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def write(self, msg_str: str) -> None:
|
||||||
|
"""Goes into infinite loop when .write is called"""
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
|
async def read(self) -> str:
|
||||||
|
"""Returns a dummy read"""
|
||||||
|
return "dummy_read"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_select_one_of_timeout():
|
||||||
|
ECHO = TProtocol("/echo/1.0.0")
|
||||||
|
communicator = DummyMultiselectCommunicator()
|
||||||
|
|
||||||
|
client = MultiselectClient()
|
||||||
|
|
||||||
|
with pytest.raises(MultiselectClientError, match="response timed out"):
|
||||||
|
await client.select_one_of([ECHO], communicator, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_query_multistream_command_timeout():
|
||||||
|
communicator = DummyMultiselectCommunicator()
|
||||||
|
client = MultiselectClient()
|
||||||
|
|
||||||
|
with pytest.raises(MultiselectClientError, match="response timed out"):
|
||||||
|
await client.query_multistream_command(communicator, "ls", 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_negotiate_timeout():
|
||||||
|
communicator = DummyMultiselectCommunicator()
|
||||||
|
server = Multiselect()
|
||||||
|
|
||||||
|
with pytest.raises(MultiselectError, match="handshake read timeout"):
|
||||||
|
await server.negotiate(communicator, 2)
|
||||||
Reference in New Issue
Block a user