Mirror of https://github.com/varun-r-mallya/py-libp2p.git, synced 2025-12-31 20:36:24 +00:00
Merge remote changes with local WebSocket improvements
- Combined yashksaini-coder's flow control improvements with luca's WSS features
- Preserved comprehensive WSS support, TLS configuration, and handshake timeout
- Added production-ready buffer management and connection limits
- Maintained backward compatibility with existing WebSocket functionality
- Integrated both approaches for an optimal WebSocket transport implementation (see the configuration sketch below)
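
For reviewers, a minimal configuration sketch of the knobs this merge touches. It only uses parameter names that appear in the WebsocketTransport.__init__ and TransportUpgrader.__init__ signatures in the diff below; the TLS context and certificate paths are illustrative placeholders, not part of this change.

    import ssl

    from libp2p.transport.upgrader import TransportUpgrader
    from libp2p.transport.websocket.transport import WebsocketTransport

    # Empty option maps keep the sketch self-contained; a real host wires in
    # its security and muxer transports here.
    upgrader = TransportUpgrader(
        secure_transports_by_protocol={},
        muxer_transports_by_protocol={},
    )

    # Illustrative TLS context for /wss listeners (placeholder certificate paths).
    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_ctx.load_cert_chain("server.crt", "server.key")

    transport = WebsocketTransport(
        upgrader,
        tls_server_config=server_ctx,           # WSS support preserved by this merge
        handshake_timeout=15.0,                 # seconds allowed for the WS handshake
        max_buffered_amount=4 * 1024 * 1024,    # flow-control threshold added here
    )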
.gitignore (vendored, 2 changes)
@@ -178,8 +178,6 @@ env.bak/
 #lockfiles
 uv.lock
 poetry.lock
-
-# JavaScript interop test files
 tests/interop/js_libp2p/js_node/node_modules/
 tests/interop/js_libp2p/js_node/package-lock.json
 tests/interop/js_libp2p/js_node/src/node_modules/
@@ -1,7 +1,6 @@
 """Libp2p Python implementation."""

 import logging
-import ssl

 from libp2p.transport.quic.utils import is_quic_multiaddr
 from typing import Any

@@ -180,8 +179,6 @@ def new_swarm(
     enable_quic: bool = False,
     retry_config: Optional["RetryConfig"] = None,
     connection_config: ConnectionConfig | QUICTransportConfig | None = None,
-    tls_client_config: ssl.SSLContext | None = None,
-    tls_server_config: ssl.SSLContext | None = None,
 ) -> INetworkService:
     """
     Create a swarm instance based on the parameters.

@@ -193,9 +190,7 @@ def new_swarm(
     :param muxer_preference: optional explicit muxer preference
     :param listen_addrs: optional list of multiaddrs to listen on
     :param enable_quic: enable quic for transport
-    :param connection_config: options for transport configuration
-    :param tls_client_config: optional TLS configuration for WebSocket client connections (WSS)
-    :param tls_server_config: optional TLS configuration for WebSocket server connections (WSS)
+    :param quic_transport_opt: options for transport
     :return: return a default swarm instance

     Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@@ -208,6 +203,24 @@ def new_swarm(

     id_opt = generate_peer_id_from(key_pair)

+    transport: TCP | QUICTransport | ITransport
+    quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
+
+    if listen_addrs is None:
+        if enable_quic:
+            transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
+        else:
+            transport = TCP()
+    else:
+        addr = listen_addrs[0]
+        is_quic = is_quic_multiaddr(addr)
+        if addr.__contains__("tcp"):
+            transport = TCP()
+        elif is_quic:
+            transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
+        else:
+            raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
+
     # Generate X25519 keypair for Noise
     noise_key_pair = create_new_x25519_key_pair()

@@ -248,24 +261,19 @@ def new_swarm(
     )

     # Create transport based on listen_addrs or default to TCP
-    transport: ITransport
     if listen_addrs is None:
         transport = TCP()
     else:
         # Use the first address to determine transport type
         addr = listen_addrs[0]
-        transport_maybe = create_transport_for_multiaddr(
-            addr,
-            upgrader,
-            private_key=key_pair.private_key,
-            tls_client_config=tls_client_config,
-            tls_server_config=tls_server_config
-        )
+        transport_maybe = create_transport_for_multiaddr(addr, upgrader)

         if transport_maybe is None:
             # Fallback to TCP if no specific transport found
             if addr.__contains__("tcp"):
                 transport = TCP()
+            elif addr.__contains__("quic"):
+                transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
             else:
                 supported_protocols = get_supported_transport_protocols()
                 raise ValueError(
@@ -275,6 +283,31 @@ def new_swarm(
             else:
                 transport = transport_maybe

+    # Use given muxer preference if provided, otherwise use global default
+    if muxer_preference is not None:
+        temp_pref = muxer_preference.upper()
+        if temp_pref not in [MUXER_YAMUX, MUXER_MPLEX]:
+            raise ValueError(
+                f"Unknown muxer: {muxer_preference}. Use 'YAMUX' or 'MPLEX'."
+            )
+        active_preference = temp_pref
+    else:
+        active_preference = DEFAULT_MUXER
+
+    # Use provided muxer options if given, otherwise create based on preference
+    if muxer_opt is not None:
+        muxer_transports_by_protocol = muxer_opt
+    else:
+        if active_preference == MUXER_MPLEX:
+            muxer_transports_by_protocol = create_mplex_muxer_option()
+        else:  # YAMUX is default
+            muxer_transports_by_protocol = create_yamux_muxer_option()
+
+    upgrader = TransportUpgrader(
+        secure_transports_by_protocol=secure_transports_by_protocol,
+        muxer_transports_by_protocol=muxer_transports_by_protocol,
+    )
+
     peerstore = peerstore_opt or PeerStore()
     # Store our key pair in peerstore
     peerstore.add_key_pair(id_opt, key_pair)
@@ -302,8 +335,6 @@ def new_host(
     negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
     enable_quic: bool = False,
     quic_transport_opt: QUICTransportConfig | None = None,
-    tls_client_config: ssl.SSLContext | None = None,
-    tls_server_config: ssl.SSLContext | None = None,
 ) -> IHost:
     """
     Create a new libp2p host based on the given parameters.

@@ -318,9 +349,7 @@ def new_host(
     :param enable_mDNS: whether to enable mDNS discovery
     :param bootstrap: optional list of bootstrap peer addresses as strings
     :param enable_quic: optinal choice to use QUIC for transport
-    :param quic_transport_opt: optional configuration for quic transport
-    :param tls_client_config: optional TLS configuration for WebSocket client connections (WSS)
-    :param tls_server_config: optional TLS configuration for WebSocket server connections (WSS)
+    :param transport_opt: optional configuration for quic transport
     :return: return a host instance
     """

@@ -335,9 +364,7 @@ def new_host(
         peerstore_opt=peerstore_opt,
         muxer_preference=muxer_preference,
         listen_addrs=listen_addrs,
-        connection_config=quic_transport_opt if enable_quic else None,
-        tls_client_config=tls_client_config,
-        tls_server_config=tls_server_config
+        connection_config=quic_transport_opt if enable_quic else None
     )

     if disc_opt is not None:
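
A short sketch of how the merged new_swarm/new_host factory now chooses a transport from listen_addrs, per the hunks above: TCP when the first address contains "tcp", QUIC when is_quic_multiaddr() matches, and a ValueError otherwise. The multiaddrs below are illustrative only.

    from multiaddr import Multiaddr

    from libp2p import new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair

    key_pair = create_new_key_pair()

    # First listen address contains "tcp" -> the factory selects TCP().
    tcp_host = new_host(
        key_pair=key_pair,
        listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
    )

    # QUIC-style address -> QUICTransport(key_pair.private_key, ...) is selected
    # via is_quic_multiaddr(); the exact QUIC multiaddr format is illustrative.
    quic_host = new_host(
        key_pair=key_pair,
        enable_quic=True,
        listen_addrs=[Multiaddr("/ip4/127.0.0.1/udp/0/quic-v1")],
    )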
@@ -213,7 +213,6 @@ class BasicHost(IHost):
         self,
         peer_id: ID,
         protocol_ids: Sequence[TProtocol],
-        negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
     ) -> INetStream:
         """
         :param peer_id: peer_id that host is connecting

@@ -227,7 +226,7 @@ class BasicHost(IHost):
             selected_protocol = await self.multiselect_client.select_one_of(
                 list(protocol_ids),
                 MultiselectCommunicator(net_stream),
-                negotitate_timeout,
+                self.negotiate_timeout,
             )
         except MultiselectClientError as error:
             logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
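
With the per-call negotitate_timeout parameter gone, new_stream now negotiates with the host-wide self.negotiate_timeout. A sketch, assuming the value is supplied through new_host's negotiate_timeout parameter shown in the new_host hunk above:

    from libp2p import new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair

    # 10s cap applied to every protocol negotiation started by new_stream().
    host = new_host(
        key_pair=create_new_key_pair(),
        negotiate_timeout=10,
    )
    # Inside a running host context one would then call, without a per-call timeout:
    #     stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")])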
@@ -490,6 +490,7 @@ class Swarm(Service, INetworkService):
         for maddr in multiaddrs:
             logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
             if str(maddr) in self.listeners:
+                logger.debug(f"Swarm.listen: listener already exists for {maddr}")
                 success_count += 1
                 continue

@@ -555,6 +556,7 @@ class Swarm(Service, INetworkService):
             # I/O agnostic, we should change the API.
             if self.listener_nursery is None:
                 raise SwarmException("swarm instance hasn't been run")
+            assert self.listener_nursery is not None  # For type checker
             logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
             await listener.listen(maddr, self.listener_nursery)
             logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
@@ -21,6 +21,7 @@ from libp2p.protocol_muxer.exceptions import (
     MultiselectError,
 )
 from libp2p.protocol_muxer.multiselect import (
+    DEFAULT_NEGOTIATE_TIMEOUT,
     Multiselect,
 )
 from libp2p.protocol_muxer.multiselect_client import (

@@ -46,11 +47,17 @@ class MuxerMultistream:
     transports: "OrderedDict[TProtocol, TMuxerClass]"
     multiselect: Multiselect
     multiselect_client: MultiselectClient
+    negotiate_timeout: int

-    def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
+    def __init__(
+        self,
+        muxer_transports_by_protocol: TMuxerOptions,
+        negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
+    ) -> None:
         self.transports = OrderedDict()
         self.multiselect = Multiselect()
         self.multistream_client = MultiselectClient()
+        self.negotiate_timeout = negotiate_timeout
         for protocol, transport in muxer_transports_by_protocol.items():
             self.add_transport(protocol, transport)

@@ -80,10 +87,12 @@ class MuxerMultistream:
         communicator = MultiselectCommunicator(conn)
         if conn.is_initiator:
             protocol = await self.multiselect_client.select_one_of(
-                tuple(self.transports.keys()), communicator
+                tuple(self.transports.keys()), communicator, self.negotiate_timeout
             )
         else:
-            protocol, _ = await self.multiselect.negotiate(communicator)
+            protocol, _ = await self.multiselect.negotiate(
+                communicator, self.negotiate_timeout
+            )
         if protocol is None:
             raise MultiselectError(
                 "Fail to negotiate a stream muxer protocol: no protocol selected"

@@ -93,7 +102,7 @@ class MuxerMultistream:
     async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
         communicator = MultiselectCommunicator(conn)
         protocol = await self.multistream_client.select_one_of(
-            tuple(self.transports.keys()), communicator
+            tuple(self.transports.keys()), communicator, self.negotiate_timeout
         )
         transport_class = self.transports[protocol]
         if protocol == PROTOCOL_ID:
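
A small sketch of the new timeout plumbing, using the Yamux mapping that the interop test in this diff also uses; the constructor signature matches the hunk above.

    from libp2p.custom_types import TProtocol
    from libp2p.stream_muxer.muxer_multistream import MuxerMultistream
    from libp2p.stream_muxer.yamux.yamux import Yamux

    # The same timeout now bounds both the initiator (select_one_of) and the
    # responder (negotiate) paths in select_transport() and new_conn().
    muxer = MuxerMultistream(
        {TProtocol("/yamux/1.0.0"): Yamux},
        negotiate_timeout=30,
    )
    assert muxer.negotiate_timeout == 30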
@@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
 import logging
 import socket
 import time
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional

 from aioquic.quic import events
 from aioquic.quic.connection import QuicConnection

@@ -871,9 +871,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
         # Process events by type
         for event_type, event_list in events_by_type.items():
             if event_type == type(events.StreamDataReceived).__name__:
-                await self._handle_stream_data_batch(
-                    cast(list[events.StreamDataReceived], event_list)
-                )
+                # Filter to only StreamDataReceived events
+                stream_data_events = [
+                    e for e in event_list if isinstance(e, events.StreamDataReceived)
+                ]
+                await self._handle_stream_data_batch(stream_data_events)
             else:
                 # Process other events individually
                 for event in event_list:
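
The cast() call is replaced by an isinstance filter, which both satisfies the type checker and drops any event of another type that ended up in the bucket. A standalone sketch of the same pattern, assuming aioquic's QuicEvent base class:

    from aioquic.quic import events

    def stream_data_only(
        event_list: list[events.QuicEvent],
    ) -> list[events.StreamDataReceived]:
        # Keep only StreamDataReceived events instead of casting the whole list.
        return [e for e in event_list if isinstance(e, events.StreamDataReceived)]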
@@ -14,6 +14,9 @@ from libp2p.protocol_muxer.exceptions import (
     MultiselectClientError,
     MultiselectError,
 )
+from libp2p.protocol_muxer.multiselect import (
+    DEFAULT_NEGOTIATE_TIMEOUT,
+)
 from libp2p.security.exceptions import (
     HandshakeFailure,
 )

@@ -37,9 +40,12 @@ class TransportUpgrader:
         self,
         secure_transports_by_protocol: TSecurityOptions,
         muxer_transports_by_protocol: TMuxerOptions,
+        negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
     ):
         self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
-        self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
+        self.muxer_multistream = MuxerMultistream(
+            muxer_transports_by_protocol, negotiate_timeout
+        )

     async def upgrade_security(
         self,
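
A sketch of the timeout propagation, mirroring the new tests/core/transport/test_upgrader.py added later in this diff; the empty option maps are just to keep it minimal.

    from libp2p.transport.upgrader import TransportUpgrader

    upgrader = TransportUpgrader(
        secure_transports_by_protocol={},
        muxer_transports_by_protocol={},
        negotiate_timeout=15,
    )
    # The value is forwarded straight into MuxerMultistream.
    assert upgrader.muxer_multistream.negotiate_timeout == 15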
@@ -14,10 +14,17 @@ class P2PWebSocketConnection(ReadWriteCloser):
     """
     Wraps a WebSocketConnection to provide the raw stream interface
     that libp2p protocols expect.
+
+    Implements production-ready buffer management and flow control
+    as recommended in the libp2p WebSocket specification.
     """

     def __init__(
-        self, ws_connection: Any, ws_context: Any = None, is_secure: bool = False
+        self,
+        ws_connection: Any,
+        ws_context: Any = None,
+        is_secure: bool = False,
+        max_buffered_amount: int = 4 * 1024 * 1024,
     ) -> None:
         self._ws_connection = ws_connection
         self._ws_context = ws_context
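
A sketch of the widened constructor; ws stands for an already-established trio-websocket connection (as produced in dial() further down), and the import path for P2PWebSocketConnection is assumed, since the diff does not show the module name.

    # Import path assumed; the class lives in the WebSocket transport package.
    from libp2p.transport.websocket.connection import P2PWebSocketConnection

    conn = P2PWebSocketConnection(
        ws,                                    # open trio-websocket connection
        None,                                  # no context manager to unwind on close
        is_secure=True,                        # True for /wss dials
        max_buffered_amount=4 * 1024 * 1024,   # new flow-control threshold
    )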
@@ -29,18 +36,36 @@ class P2PWebSocketConnection(ReadWriteCloser):
         self._bytes_written = 0
         self._closed = False
         self._close_lock = trio.Lock()
+        self._max_buffered_amount = max_buffered_amount
+        self._write_lock = trio.Lock()

     async def write(self, data: bytes) -> None:
+        """Write data with flow control and buffer management"""
         if self._closed:
             raise IOException("Connection is closed")

-        try:
-            # Send as a binary WebSocket message
-            await self._ws_connection.send_message(data)
-            self._bytes_written += len(data)
-        except Exception as e:
-            logger.error(f"WebSocket write failed: {e}")
-            raise IOException from e
+        async with self._write_lock:
+            try:
+                logger.debug(f"WebSocket writing {len(data)} bytes")
+
+                # Check buffer amount for flow control
+                if hasattr(self._ws_connection, "bufferedAmount"):
+                    buffered = self._ws_connection.bufferedAmount
+                    if buffered > self._max_buffered_amount:
+                        logger.warning(f"WebSocket buffer full: {buffered} bytes")
+                        # In production, you might want to
+                        # wait or implement backpressure
+                        # For now, we'll continue but log the warning
+
+                # Send as a binary WebSocket message
+                await self._ws_connection.send_message(data)
+                self._bytes_written += len(data)
+                logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
+
+            except Exception as e:
+                logger.error(f"WebSocket write failed: {e}")
+                self._closed = True
+                raise IOException from e

     async def read(self, n: int | None = None) -> bytes:
         """
@@ -122,18 +147,25 @@ class P2PWebSocketConnection(ReadWriteCloser):
                 return  # Already closed

             logger.debug("WebSocket connection closing")
+            self._closed = True
             try:
                 # Always close the connection directly, avoid context manager issues
                 # The context manager may be causing cancel scope corruption
                 logger.debug("WebSocket closing connection directly")
                 await self._ws_connection.aclose()
+                # Exit the context manager if we have one
+                if self._ws_context is not None:
+                    await self._ws_context.__aexit__(None, None, None)
             except Exception as e:
                 logger.error(f"WebSocket close error: {e}")
                 # Don't raise here, as close() should be idempotent
             finally:
-                self._closed = True
                 logger.debug("WebSocket connection closed")

+    def is_closed(self) -> bool:
+        """Check if the connection is closed"""
+        return self._closed
+
     def conn_state(self) -> dict[str, Any]:
         """
         Return connection state information similar to Go's ConnState() method.
@@ -19,6 +19,13 @@ logger = logging.getLogger(__name__)
 class WebsocketTransport(ITransport):
     """
     Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
+
+    Implements production-ready WebSocket transport with:
+    - Flow control and buffer management
+    - Connection limits and rate limiting
+    - Proper error handling and cleanup
+    - Support for both WS and WSS protocols
+    - TLS configuration and handshake timeout
     """

     def __init__(

@@ -27,11 +34,15 @@ class WebsocketTransport(ITransport):
         tls_client_config: ssl.SSLContext | None = None,
         tls_server_config: ssl.SSLContext | None = None,
         handshake_timeout: float = 15.0,
+        max_buffered_amount: int = 4 * 1024 * 1024,
     ):
         self._upgrader = upgrader
         self._tls_client_config = tls_client_config
         self._tls_server_config = tls_server_config
         self._handshake_timeout = handshake_timeout
+        self._max_buffered_amount = max_buffered_amount
+        self._connection_count = 0
+        self._max_connections = 1000  # Production limit

     async def dial(self, maddr: Multiaddr) -> RawConnection:
         """Dial a WebSocket connection to the given multiaddr."""
@@ -67,6 +78,12 @@ class WebsocketTransport(ITransport):
         )

         try:
+            # Check connection limits
+            if self._connection_count >= self._max_connections:
+                raise OpenConnectionError(
+                    f"Maximum connections reached: {self._max_connections}"
+                )
+
             # Prepare SSL context for WSS connections
             ssl_context = None
             if parsed.is_wss:

@@ -100,10 +117,6 @@ class WebsocketTransport(ITransport):
                 f"port={ws_port}, resource={ws_resource}"
             )

-            # Instead of fighting trio-websocket's lifecycle, let's try using
-            # a persistent task that will keep the WebSocket alive
-            # This mimics what trio-websocket does internally but with our control
-
             # Create a background task manager for this connection
             import trio

@@ -127,11 +140,18 @@ class WebsocketTransport(ITransport):
             )
             logger.debug("WebsocketTransport.dial WebSocket connection established")

-            # Create our connection wrapper
-            # Pass None for nursery since we're using the parent nursery
-            conn = P2PWebSocketConnection(ws, None, is_secure=parsed.is_wss)
+            # Create our connection wrapper with both WSS support and flow control
+            conn = P2PWebSocketConnection(
+                ws,
+                None,
+                is_secure=parsed.is_wss,
+                max_buffered_amount=self._max_buffered_amount
+            )
             logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")

+            self._connection_count += 1
+            logger.debug(f"Total connections: {self._connection_count}")
+
             return RawConnection(conn, initiator=True)
         except trio.TooSlowError as e:
             raise OpenConnectionError(

@@ -139,6 +159,7 @@ class WebsocketTransport(ITransport):
                 f"for {maddr}"
             ) from e
         except Exception as e:
+            logger.error(f"Failed to dial WebSocket {maddr}: {e}")
             raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e

     def create_listener(self, handler: THandler) -> IListener:  # type: ignore[override]
newsfragments/896.bugfix.rst (new file, 1 line)
@@ -0,0 +1 @@
+Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly

newsfragments/927.bugfix.rst (new file, 1 line)
@@ -0,0 +1 @@
+Fix multiaddr dependency to use the last py-multiaddr commit hash to resolve installation issues
pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "grpcio>=1.41.0",
     "lru-dict>=1.1.6",
     # "multiaddr (>=0.0.9,<0.0.10)",
-    "multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@3ea7f866fda9268ee92506edf9d8e975274bf941",
+    "multiaddr @ git+https://github.com/multiformats/py-multiaddr.git@b186e2ccadc22545dec4069ff313787bf29265e0",
    "mypy-protobuf>=3.0.0",
    "noiseprotocol>=0.3.0",
    "protobuf>=4.25.0,<5.0.0",
tests/core/stream_muxer/test_muxer_multistream.py (new file, 108 lines)
@@ -0,0 +1,108 @@
+from unittest.mock import (
+    AsyncMock,
+    MagicMock,
+)
+
+import pytest
+
+from libp2p.custom_types import (
+    TMuxerClass,
+    TProtocol,
+)
+from libp2p.peer.id import (
+    ID,
+)
+from libp2p.protocol_muxer.exceptions import (
+    MultiselectError,
+)
+from libp2p.stream_muxer.muxer_multistream import (
+    MuxerMultistream,
+)
+
+
+@pytest.mark.trio
+async def test_muxer_timeout_configuration():
+    """Test that muxer respects timeout configuration."""
+    muxer = MuxerMultistream({}, negotiate_timeout=1)
+    assert muxer.negotiate_timeout == 1
+
+
+@pytest.mark.trio
+async def test_select_transport_passes_timeout_to_multiselect():
+    """Test that timeout is passed to multiselect client in select_transport."""
+    # Mock dependencies
+    mock_conn = MagicMock()
+    mock_conn.is_initiator = False
+
+    # Mock MultiselectClient
+    muxer = MuxerMultistream({}, negotiate_timeout=10)
+    muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
+    muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
+
+    # Call select_transport
+    await muxer.select_transport(mock_conn)
+
+    # Verify that select_one_of was called with the correct timeout
+    args, _ = muxer.multiselect.negotiate.call_args
+    assert args[1] == 10
+
+
+@pytest.mark.trio
+async def test_new_conn_passes_timeout_to_multistream_client():
+    """Test that timeout is passed to multistream client in new_conn."""
+    # Mock dependencies
+    mock_conn = MagicMock()
+    mock_conn.is_initiator = True
+    mock_peer_id = ID(b"test_peer")
+    mock_communicator = MagicMock()
+
+    # Mock MultistreamClient and transports
+    muxer = MuxerMultistream({}, negotiate_timeout=30)
+    muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
+    muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
+
+    # Call new_conn
+    await muxer.new_conn(mock_conn, mock_peer_id)
+
+    # Verify that select_one_of was called with the correct timeout
+    muxer.multistream_client.select_one_of(
+        tuple(muxer.transports.keys()), mock_communicator, 30
+    )
+
+
+@pytest.mark.trio
+async def test_select_transport_no_protocol_selected():
+    """
+    Test that select_transport raises MultiselectError when no protocol is selected.
+    """
+    # Mock dependencies
+    mock_conn = MagicMock()
+    mock_conn.is_initiator = False
+
+    # Mock Multiselect to return None
+    muxer = MuxerMultistream({}, negotiate_timeout=30)
+    muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
+
+    # Expect MultiselectError to be raised
+    with pytest.raises(MultiselectError, match="no protocol selected"):
+        await muxer.select_transport(mock_conn)
+
+
+@pytest.mark.trio
+async def test_add_transport_updates_precedence():
+    """Test that adding a transport updates protocol precedence."""
+    # Mock transport classes
+    mock_transport1 = MagicMock(spec=TMuxerClass)
+    mock_transport2 = MagicMock(spec=TMuxerClass)
+
+    # Initialize muxer and add transports
+    muxer = MuxerMultistream({}, negotiate_timeout=30)
+    muxer.add_transport(TProtocol("proto1"), mock_transport1)
+    muxer.add_transport(TProtocol("proto2"), mock_transport2)
+
+    # Verify transport order
+    assert list(muxer.transports.keys()) == ["proto1", "proto2"]
+
+    # Re-add proto1 to check if it moves to the end
+    muxer.add_transport(TProtocol("proto1"), mock_transport1)
+    assert list(muxer.transports.keys()) == ["proto2", "proto1"]
tests/core/transport/test_upgrader.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+import pytest
+
+from libp2p.custom_types import (
+    TMuxerOptions,
+    TSecurityOptions,
+)
+from libp2p.transport.upgrader import (
+    TransportUpgrader,
+)
+
+
+@pytest.mark.trio
+async def test_transport_upgrader_security_and_muxer_initialization():
+    """Test TransportUpgrader initializes security and muxer multistreams correctly."""
+    secure_transports: TSecurityOptions = {}
+    muxer_transports: TMuxerOptions = {}
+    negotiate_timeout = 15
+
+    upgrader = TransportUpgrader(
+        secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
+    )
+
+    # Verify security multistream initialization
+    assert upgrader.security_multistream.transports == secure_transports
+    # Verify muxer multistream initialization and timeout
+    assert upgrader.muxer_multistream.transports == muxer_transports
+    assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout
@@ -10,12 +10,11 @@
   "license": "ISC",
   "description": "",
   "dependencies": {
-    "@libp2p/ping": "^2.0.36",
-    "@libp2p/websockets": "^9.2.18",
+    "@chainsafe/libp2p-noise": "^9.0.0",
     "@chainsafe/libp2p-yamux": "^5.0.1",
-    "@chainsafe/libp2p-noise": "^16.0.1",
-    "@libp2p/plaintext": "^2.0.7",
-    "@libp2p/identify": "^3.0.39",
+    "@libp2p/ping": "^2.0.36",
+    "@libp2p/plaintext": "^2.0.29",
+    "@libp2p/websockets": "^9.2.18",
     "libp2p": "^2.9.0",
     "multiaddr": "^10.0.1"
   }
@@ -9,8 +9,16 @@ from trio.lowlevel import open_process

 from libp2p.crypto.secp256k1 import create_new_key_pair
 from libp2p.custom_types import TProtocol
+from libp2p.host.basic_host import BasicHost
 from libp2p.network.exceptions import SwarmException
+from libp2p.network.swarm import Swarm
 from libp2p.peer.id import ID
+from libp2p.peer.peerinfo import PeerInfo
+from libp2p.peer.peerstore import PeerStore
+from libp2p.security.insecure.transport import InsecureTransport
+from libp2p.stream_muxer.yamux.yamux import Yamux
+from libp2p.transport.upgrader import TransportUpgrader
+from libp2p.transport.websocket.transport import WebsocketTransport

 PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0"

@@ -20,254 +28,98 @@ async def test_ping_with_js_node():
     js_node_dir = os.path.join(os.path.dirname(__file__), "js_libp2p", "js_node", "src")
     script_name = "./ws_ping_node.mjs"

-    # Debug: Check if JS node directory exists
-    print(f"JS Node Directory: {js_node_dir}")
-    print(f"JS Node Directory exists: {os.path.exists(js_node_dir)}")
-
-    if os.path.exists(js_node_dir):
-        print(f"JS Node Directory contents: {os.listdir(js_node_dir)}")
-        script_path = os.path.join(js_node_dir, script_name)
-        print(f"Script path: {script_path}")
-        print(f"Script exists: {os.path.exists(script_path)}")
-
-        if os.path.exists(script_path):
-            with open(script_path) as f:
-                script_content = f.read()
-                print(f"Script content (first 500 chars): {script_content[:500]}...")
-
-    # Debug: Check if npm is available
     try:
-        npm_version = subprocess.run(
-            ["npm", "--version"],
-            capture_output=True,
-            text=True,
-            check=True,
-        )
-        print(f"NPM version: {npm_version.stdout.strip()}")
-    except (subprocess.CalledProcessError, FileNotFoundError) as e:
-        print(f"NPM not available: {e}")
-
-    # Debug: Check if node is available
-    try:
-        node_version = subprocess.run(
-            ["node", "--version"],
-            capture_output=True,
-            text=True,
-            check=True,
-        )
-        print(f"Node version: {node_version.stdout.strip()}")
-    except (subprocess.CalledProcessError, FileNotFoundError) as e:
-        print(f"Node not available: {e}")
-
-    try:
-        print(f"Running npm install in {js_node_dir}...")
-        npm_install_result = subprocess.run(
+        subprocess.run(
             ["npm", "install"],
             cwd=js_node_dir,
             check=True,
             capture_output=True,
             text=True,
         )
-        print(f"NPM install stdout: {npm_install_result.stdout}")
-        print(f"NPM install stderr: {npm_install_result.stderr}")
     except (subprocess.CalledProcessError, FileNotFoundError) as e:
-        print(f"NPM install failed: {e}")
         pytest.fail(f"Failed to run 'npm install': {e}")

     # Launch the JS libp2p node (long-running)
-    print(f"Launching JS node: node {script_name} in {js_node_dir}")
     proc = await open_process(
         ["node", script_name],
         stdout=subprocess.PIPE,
         stderr=subprocess.PIPE,
         cwd=js_node_dir,
     )
-    print(f"JS node process started with PID: {proc.pid}")
     assert proc.stdout is not None, "stdout pipe missing"
     assert proc.stderr is not None, "stderr pipe missing"
     stdout = proc.stdout
     stderr = proc.stderr

     try:
-        # Read JS node output until we get peer ID and multiaddrs
-        print("Waiting for JS node to output PeerID and multiaddrs...")
+        # Read first two lines (PeerID and multiaddr)
         buffer = b""
-        peer_id_found: str | bool = False
-        multiaddrs_found = []
-
         with trio.fail_after(30):
-            while True:
+            while buffer.count(b"\n") < 2:
                 chunk = await stdout.receive_some(1024)
                 if not chunk:
-                    print("No more data from JS node stdout")
                     break
                 buffer += chunk
-                print(f"Received chunk: {chunk}")
-
-                # Parse lines as we receive them
-                lines = buffer.decode().splitlines()
-                for line in lines:
-                    line = line.strip()
-                    if not line:
-                        continue
-
-                    # Look for peer ID (starts with "12D3Koo")
-                    if line.startswith("12D3Koo") and not peer_id_found:
-                        peer_id_found = line
-                        print(f"Found peer ID: {peer_id_found}")
-
-                    # Look for multiaddrs (start with "/ip4/" or "/ip6/")
-                    elif line.startswith("/ip4/") or line.startswith("/ip6/"):
-                        if line not in multiaddrs_found:
-                            multiaddrs_found.append(line)
-                            print(f"Found multiaddr: {line}")
-
-                # Stop when we have peer ID and at least one multiaddr
-                if peer_id_found and multiaddrs_found:
-                    print(f"✅ Collected: Peer ID + {len(multiaddrs_found)} multiaddrs")
-                    break
-
-        print(f"Total buffer received: {buffer}")
-        all_lines = [line for line in buffer.decode().splitlines() if line.strip()]
-        print(f"All JS Node lines: {all_lines}")
-
-        if not peer_id_found or not multiaddrs_found:
-            print("Missing peer ID or multiaddrs from JS node, checking stderr...")
+
+        lines = [line for line in buffer.decode().splitlines() if line.strip()]
+        if len(lines) < 2:
             stderr_output = await stderr.receive_some(2048)
             stderr_output = stderr_output.decode()
-            print(f"JS node stderr: {stderr_output}")
             pytest.fail(
                 "JS node did not produce expected PeerID and multiaddr.\n"
-                f"Found peer ID: {peer_id_found}\n"
-                f"Found multiaddrs: {multiaddrs_found}\n"
                 f"Stdout: {buffer.decode()!r}\n"
                 f"Stderr: {stderr_output!r}"
             )
-        # peer_id = ID.from_base58(peer_id_found)  # Not used currently
-        # Use the first localhost multiaddr preferentially, or fallback to first
-        # available
-        maddr = None
-        for addr_str in multiaddrs_found:
-            if "127.0.0.1" in addr_str:
-                maddr = Multiaddr(addr_str)
-                break
-        if not maddr:
-            maddr = Multiaddr(multiaddrs_found[0])
+
+        peer_id_line, addr_line = lines[0], lines[1]
+        peer_id = ID.from_base58(peer_id_line)
+        maddr = Multiaddr(addr_line)

         # Debug: Print what we're trying to connect to
-        print(f"JS Node Peer ID: {peer_id_found}")
-        print(f"JS Node Address: {maddr}")
-        print(f"All found multiaddrs: {multiaddrs_found}")
-        print(f"Selected multiaddr: {maddr}")
+        print(f"JS Node Peer ID: {peer_id_line}")
+        print(f"JS Node Address: {addr_line}")
+        print(f"All JS Node lines: {lines}")

-        # Set up Python host using new_host API with Noise security
-        print("Setting up Python host...")
-        from libp2p import create_yamux_muxer_option, new_host
-
+        # Set up Python host
         key_pair = create_new_key_pair()
-        # noise_key_pair = create_new_x25519_key_pair()  # Not used currently
-        print(f"Python Peer ID: {ID.from_pubkey(key_pair.public_key)}")
-
-        # Use default security options (includes Noise, SecIO, and plaintext)
-        # This will allow protocol negotiation to choose the best match
-        host = new_host(
-            key_pair=key_pair,
-            muxer_opt=create_yamux_muxer_option(),
-            listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
+        py_peer_id = ID.from_pubkey(key_pair.public_key)
+        peer_store = PeerStore()
+        peer_store.add_key_pair(py_peer_id, key_pair)
+
+        upgrader = TransportUpgrader(
+            secure_transports_by_protocol={
+                TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
+            },
+            muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
         )
-        print(f"Python host created: {host}")
+        transport = WebsocketTransport(upgrader)
+        swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
+        host = BasicHost(swarm)

-        # Connect to JS node using modern peer info
-        from libp2p.peer.peerinfo import info_from_p2p_addr
-
-        peer_info = info_from_p2p_addr(maddr)
+        # Connect to JS node
+        peer_info = PeerInfo(peer_id, [maddr])
         print(f"Python trying to connect to: {peer_info}")
-        print(f"Peer info addresses: {peer_info.addrs}")
-
-        # Test WebSocket multiaddr validation
-        from libp2p.transport.websocket.multiaddr_utils import (
-            is_valid_websocket_multiaddr,
-            parse_websocket_multiaddr,
-        )
-
-        print(f"Is valid WebSocket multiaddr: {is_valid_websocket_multiaddr(maddr)}")
-        try:
-            parsed = parse_websocket_multiaddr(maddr)
-            print(
-                f"Parsed WebSocket multiaddr: is_wss={parsed.is_wss}, "
-                f"sni={parsed.sni}, rest_multiaddr={parsed.rest_multiaddr}"
-            )
-        except Exception as e:
-            print(f"Failed to parse WebSocket multiaddr: {e}")
-
-        # Use proper host.run() context manager
+
+        # Use the host as a context manager
         async with host.run(listen_addrs=[]):
             await trio.sleep(1)

             try:
-                print("Attempting to connect to JS node...")
                 await host.connect(peer_info)
-                print("Successfully connected to JS node!")
             except SwarmException as e:
                 underlying_error = e.__cause__
-                print(f"Connection failed with SwarmException: {e}")
-                print(f"Underlying error: {underlying_error}")
                 pytest.fail(
                     "Connection failed with SwarmException.\n"
                     f"THE REAL ERROR IS: {underlying_error!r}\n"
                 )

-            # Verify connection was established
-            assert host.get_network().connections.get(peer_info.peer_id) is not None
+            assert host.get_network().connections.get(peer_id) is not None

-            # Try to ping the JS node
-            ping_protocol = TProtocol("/ipfs/ping/1.0.0")
-            try:
-                print("Opening ping stream...")
-                stream = await host.new_stream(peer_info.peer_id, [ping_protocol])
-                print("Ping stream opened successfully!")
-
-                # Send ping data (32 bytes as per libp2p ping protocol)
-                ping_data = b"\x00" * 32
-                await stream.write(ping_data)
-                print(f"Sent ping: {len(ping_data)} bytes")
-
-                # Wait for pong response
-                pong_data = await stream.read(32)
-                print(f"Received pong: {len(pong_data)} bytes")
-
-                # Verify the pong matches the ping
-                assert pong_data == ping_data, (
-                    f"Ping/pong mismatch: {ping_data!r} != {pong_data!r}"
-                )
-                print("✅ Ping/pong successful!")
-
-                await stream.close()
-                print("Stream closed successfully!")
-
-            except Exception as e:
-                print(f"Ping failed: {e}")
-                pytest.fail(f"Ping failed: {e}")
-
-        print("🎉 JavaScript WebSocket interop test completed successfully!")
+            # Ping protocol
+            stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")])
+            await stream.write(b"ping")
+            data = await stream.read(4)
+            assert data == b"pong"
     finally:
-        print(f"Terminating JS node process (PID: {proc.pid})...")
-        try:
-            proc.send_signal(signal.SIGTERM)
-            print("SIGTERM sent to JS node process")
-            await trio.sleep(1)  # Give it time to terminate gracefully
-            if proc.poll() is None:
-                print("JS node process still running, sending SIGKILL...")
-                proc.send_signal(signal.SIGKILL)
-                await trio.sleep(0.5)
-        except Exception as e:
-            print(f"Error terminating JS node process: {e}")
-
-        # Check if process is still running
-        if proc.poll() is None:
-            print("WARNING: JS node process is still running!")
-        else:
-            print(f"JS node process terminated with exit code: {proc.poll()}")
-
+        proc.send_signal(signal.SIGTERM)
         await trio.sleep(0)