FIXME: Make TProtocol Optional[TProtocol] to keep types consistent (#770)

* FIXME: Make TProtocol Optional[TProtocol] to keep types consistent

* correct test case of test_protocol_muxer

* add newsfragment

* unit test added

---------

Co-authored-by: Manu Sheel Gupta <manusheel.edu@gmail.com>
This commit is contained in:
Jinesh Jain
2025-08-20 06:50:37 +05:30
committed by GitHub
parent e20a9a3814
commit dabb3a0962
9 changed files with 167 additions and 17 deletions

View File

@ -295,6 +295,13 @@ class BasicHost(IHost):
)
await net_stream.reset()
return
if protocol is None:
logger.debug(
"no protocol negotiated, closing stream from peer %s",
net_stream.muxed_conn.peer_id,
)
await net_stream.reset()
return
net_stream.set_protocol(protocol)
if handler is None:
logger.debug(

View File

@ -48,12 +48,11 @@ class Multiselect(IMultiselectMuxer):
"""
self.handlers[protocol] = handler
# FIXME: Make TProtocol Optional[TProtocol] to keep types consistent
async def negotiate(
self,
communicator: IMultiselectCommunicator,
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
) -> tuple[TProtocol, StreamHandlerFn | None]:
) -> tuple[TProtocol | None, StreamHandlerFn | None]:
"""
Negotiate performs protocol selection.
@ -84,14 +83,14 @@ class Multiselect(IMultiselectMuxer):
raise MultiselectError() from error
else:
protocol = TProtocol(command)
if protocol in self.handlers:
protocol_to_check = None if not command else TProtocol(command)
if protocol_to_check in self.handlers:
try:
await communicator.write(protocol)
await communicator.write(command)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
return protocol, self.handlers[protocol]
return protocol_to_check, self.handlers[protocol_to_check]
try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error:

View File

@ -134,8 +134,10 @@ class MultiselectClient(IMultiselectClient):
:raise MultiselectClientError: raised when protocol negotiation failed
:return: selected protocol
"""
# Represent `None` protocol as an empty string.
protocol_str = protocol if protocol is not None else ""
try:
await communicator.write(protocol)
await communicator.write(protocol_str)
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
@ -145,7 +147,7 @@ class MultiselectClient(IMultiselectClient):
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
if response == protocol:
if response == protocol_str:
return protocol
if response == PROTOCOL_NOT_FOUND_MSG:
raise MultiselectClientError("protocol not supported")

View File

@ -30,7 +30,10 @@ class MultiselectCommunicator(IMultiselectCommunicator):
"""
:raise MultiselectCommunicatorError: raised when failed to write to underlying reader
""" # noqa: E501
msg_bytes = encode_delim(msg_str.encode())
if msg_str is None:
msg_bytes = encode_delim(b"")
else:
msg_bytes = encode_delim(msg_str.encode())
try:
await self.read_writer.write(msg_bytes)
except IOException as error:

View File

@ -17,6 +17,9 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
@ -104,7 +107,7 @@ class SecurityMultistream(ABC):
:param is_initiator: true if we are the initiator, false otherwise
:return: selected secure transport
"""
protocol: TProtocol
protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if is_initiator:
# Select protocol if initiator
@ -114,5 +117,7 @@ class SecurityMultistream(ABC):
else:
# Select protocol if non-initiator
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError("fail to negotiate a security protocol")
# Return transport from protocol
return self.transports[protocol]

View File

@ -17,6 +17,9 @@ from libp2p.custom_types import (
from libp2p.peer.id import (
ID,
)
from libp2p.protocol_muxer.exceptions import (
MultiselectError,
)
from libp2p.protocol_muxer.multiselect import (
Multiselect,
)
@ -73,7 +76,7 @@ class MuxerMultistream:
:param conn: conn to choose a transport over
:return: selected muxer transport
"""
protocol: TProtocol
protocol: TProtocol | None
communicator = MultiselectCommunicator(conn)
if conn.is_initiator:
protocol = await self.multiselect_client.select_one_of(
@ -81,6 +84,8 @@ class MuxerMultistream:
)
else:
protocol, _ = await self.multiselect.negotiate(communicator)
if protocol is None:
raise MultiselectError("fail to negotiate a stream muxer protocol")
return self.transports[protocol]
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn: