diff --git a/debug_websocket_url.py b/debug_websocket_url.py deleted file mode 100644 index 328ddbd5..00000000 --- a/debug_websocket_url.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python3 -""" -Debug script to test WebSocket URL construction and basic connection. -""" - -import logging - -from multiaddr import Multiaddr - -from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr - -# Configure logging -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - - -async def test_websocket_url(): - """Test WebSocket URL construction.""" - # Test multiaddr from your JS node - maddr_str = "/ip4/127.0.0.1/tcp/35391/ws/p2p/12D3KooWQh7p5xP2ppr3CrhUFsawmsKNe9jgDbacQdWCYpuGfMVN" - maddr = Multiaddr(maddr_str) - - logger.info(f"Testing multiaddr: {maddr}") - - # Parse WebSocket multiaddr - parsed = parse_websocket_multiaddr(maddr) - logger.info( - f"Parsed: is_wss={parsed.is_wss}, sni={parsed.sni}, rest_multiaddr={parsed.rest_multiaddr}" - ) - - # Construct WebSocket URL - if parsed.is_wss: - protocol = "wss" - else: - protocol = "ws" - - # Extract host and port from rest_multiaddr - host = parsed.rest_multiaddr.value_for_protocol("ip4") - port = parsed.rest_multiaddr.value_for_protocol("tcp") - - websocket_url = f"{protocol}://{host}:{port}/" - logger.info(f"WebSocket URL: {websocket_url}") - - # Test basic WebSocket connection - try: - from trio_websocket import open_websocket_url - - logger.info("Testing basic WebSocket connection...") - async with open_websocket_url(websocket_url) as ws: - logger.info("โœ… WebSocket connection successful!") - # Send a simple message - await ws.send_message(b"test") - logger.info("โœ… Message sent successfully!") - - except Exception as e: - logger.error(f"โŒ WebSocket connection failed: {e}") - import traceback - - logger.error(f"Traceback: {traceback.format_exc()}") - - -if __name__ == "__main__": - import trio - - trio.run(test_websocket_url) diff --git a/examples/test_tcp_data_transfer.py b/examples/test_tcp_data_transfer.py new file mode 100644 index 00000000..634386bd --- /dev/null +++ b/examples/test_tcp_data_transfer.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +""" +TCP P2P Data Transfer Test + +This test proves that TCP peer-to-peer data transfer works correctly in libp2p. +This serves as a baseline to compare with WebSocket tests. +""" + +import pytest +from multiaddr import Multiaddr +import trio + +from libp2p import create_yamux_muxer_option, new_host +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.peer.peerinfo import info_from_p2p_addr +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport + +# Test protocol for data exchange +TCP_DATA_PROTOCOL = TProtocol("/test/tcp-data-exchange/1.0.0") + + +async def create_tcp_host_pair(): + """Create a pair of hosts configured for TCP communication.""" + # Create key pairs + key_pair_a = create_new_key_pair() + key_pair_b = create_new_key_pair() + + # Create security options (using plaintext for simplicity) + def security_options(kp): + return { + PLAINTEXT_PROTOCOL_ID: InsecureTransport( + local_key_pair=kp, secure_bytes_provider=None, peerstore=None + ) + } + + # Host A (listener) - TCP transport (default) + host_a = new_host( + key_pair=key_pair_a, + sec_opt=security_options(key_pair_a), + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")], + ) + + # Host B (dialer) - TCP transport (default) + host_b = new_host( + key_pair=key_pair_b, + sec_opt=security_options(key_pair_b), + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")], + ) + + return host_a, host_b + + +@pytest.mark.trio +async def test_tcp_basic_connection(): + """Test basic TCP connection establishment.""" + host_a, host_b = await create_tcp_host_pair() + + connection_established = False + + async def connection_handler(stream): + nonlocal connection_established + connection_established = True + await stream.close() + + host_a.set_stream_handler(TCP_DATA_PROTOCOL, connection_handler) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert listen_addrs, "Host A should have listen addresses" + + # Extract TCP address + tcp_addr = None + for addr in listen_addrs: + if "/tcp/" in str(addr) and "/ws" not in str(addr): + tcp_addr = addr + break + + assert tcp_addr, f"No TCP address found in {listen_addrs}" + print(f"๐Ÿ”— Host A listening on: {tcp_addr}") + + # Create peer info for host A + peer_info = info_from_p2p_addr(tcp_addr) + + # Host B connects to host A + await host_b.connect(peer_info) + print("โœ… TCP connection established") + + # Open a stream to test the connection + stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL]) + await stream.close() + + # Wait a bit for the handler to be called + await trio.sleep(0.1) + + assert connection_established, "TCP connection handler should have been called" + print("โœ… TCP basic connection test successful!") + + +@pytest.mark.trio +async def test_tcp_data_transfer(): + """Test TCP peer-to-peer data transfer.""" + host_a, host_b = await create_tcp_host_pair() + + # Test data + test_data = b"Hello TCP P2P Data Transfer! This is a test message." + received_data = None + transfer_complete = trio.Event() + + async def data_handler(stream): + nonlocal received_data + try: + # Read the incoming data + received_data = await stream.read(len(test_data)) + # Echo it back to confirm successful transfer + await stream.write(received_data) + await stream.close() + transfer_complete.set() + except Exception as e: + print(f"Handler error: {e}") + transfer_complete.set() + + host_a.set_stream_handler(TCP_DATA_PROTOCOL, data_handler) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert listen_addrs, "Host A should have listen addresses" + + # Extract TCP address + tcp_addr = None + for addr in listen_addrs: + if "/tcp/" in str(addr) and "/ws" not in str(addr): + tcp_addr = addr + break + + assert tcp_addr, f"No TCP address found in {listen_addrs}" + print(f"๐Ÿ”— Host A listening on: {tcp_addr}") + + # Create peer info for host A + peer_info = info_from_p2p_addr(tcp_addr) + + # Host B connects to host A + await host_b.connect(peer_info) + print("โœ… TCP connection established") + + # Open a stream for data transfer + stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL]) + print("โœ… TCP stream opened") + + # Send test data + await stream.write(test_data) + print(f"๐Ÿ“ค Sent data: {test_data}") + + # Read echoed data back + echoed_data = await stream.read(len(test_data)) + print(f"๐Ÿ“ฅ Received echo: {echoed_data}") + + await stream.close() + + # Wait for transfer to complete + with trio.fail_after(5.0): # 5 second timeout + await transfer_complete.wait() + + # Verify data transfer + assert received_data == test_data, ( + f"Data mismatch: {received_data} != {test_data}" + ) + assert echoed_data == test_data, f"Echo mismatch: {echoed_data} != {test_data}" + + print("โœ… TCP P2P data transfer successful!") + print(f" Original: {test_data}") + print(f" Received: {received_data}") + print(f" Echoed: {echoed_data}") + + +@pytest.mark.trio +async def test_tcp_large_data_transfer(): + """Test TCP with larger data payloads.""" + host_a, host_b = await create_tcp_host_pair() + + # Large test data (10KB) + test_data = b"TCP Large Data Test! " * 500 # ~10KB + received_data = None + transfer_complete = trio.Event() + + async def large_data_handler(stream): + nonlocal received_data + try: + # Read data in chunks + chunks = [] + total_received = 0 + expected_size = len(test_data) + + while total_received < expected_size: + chunk = await stream.read(min(1024, expected_size - total_received)) + if not chunk: + break + chunks.append(chunk) + total_received += len(chunk) + + received_data = b"".join(chunks) + + # Send back confirmation + await stream.write(b"RECEIVED_OK") + await stream.close() + transfer_complete.set() + except Exception as e: + print(f"Large data handler error: {e}") + transfer_complete.set() + + host_a.set_stream_handler(TCP_DATA_PROTOCOL, large_data_handler) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[]), + ): + # Get host A's listen address + listen_addrs = host_a.get_addrs() + assert listen_addrs, "Host A should have listen addresses" + + # Extract TCP address + tcp_addr = None + for addr in listen_addrs: + if "/tcp/" in str(addr) and "/ws" not in str(addr): + tcp_addr = addr + break + + assert tcp_addr, f"No TCP address found in {listen_addrs}" + print(f"๐Ÿ”— Host A listening on: {tcp_addr}") + print(f"๐Ÿ“Š Test data size: {len(test_data)} bytes") + + # Create peer info for host A + peer_info = info_from_p2p_addr(tcp_addr) + + # Host B connects to host A + await host_b.connect(peer_info) + print("โœ… TCP connection established") + + # Open a stream for data transfer + stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL]) + print("โœ… TCP stream opened") + + # Send large test data in chunks + chunk_size = 1024 + sent_bytes = 0 + for i in range(0, len(test_data), chunk_size): + chunk = test_data[i : i + chunk_size] + await stream.write(chunk) + sent_bytes += len(chunk) + if sent_bytes % (chunk_size * 4) == 0: # Progress every 4KB + print(f"๐Ÿ“ค Sent {sent_bytes}/{len(test_data)} bytes") + + print(f"๐Ÿ“ค Sent all {len(test_data)} bytes") + + # Read confirmation + confirmation = await stream.read(1024) + print(f"๐Ÿ“ฅ Received confirmation: {confirmation}") + + await stream.close() + + # Wait for transfer to complete + with trio.fail_after(10.0): # 10 second timeout for large data + await transfer_complete.wait() + + # Verify data transfer + assert received_data is not None, "No data was received" + assert received_data == test_data, ( + "Large data transfer failed:" + + f" sizes {len(received_data)} != {len(test_data)}" + ) + assert confirmation == b"RECEIVED_OK", f"Confirmation failed: {confirmation}" + + print("โœ… TCP large data transfer successful!") + print(f" Data size: {len(test_data)} bytes") + print(f" Received: {len(received_data)} bytes") + print(f" Match: {received_data == test_data}") + + +@pytest.mark.trio +async def test_tcp_bidirectional_transfer(): + """Test bidirectional data transfer over TCP.""" + host_a, host_b = await create_tcp_host_pair() + + # Test data + data_a_to_b = b"Message from Host A to Host B via TCP" + data_b_to_a = b"Response from Host B to Host A via TCP" + + received_on_a = None + received_on_b = None + transfer_complete_a = trio.Event() + transfer_complete_b = trio.Event() + + async def handler_a(stream): + nonlocal received_on_a + try: + # Read data from B + received_on_a = await stream.read(len(data_b_to_a)) + print(f"๐Ÿ…ฐ๏ธ Host A received: {received_on_a}") + await stream.close() + transfer_complete_a.set() + except Exception as e: + print(f"Handler A error: {e}") + transfer_complete_a.set() + + async def handler_b(stream): + nonlocal received_on_b + try: + # Read data from A + received_on_b = await stream.read(len(data_a_to_b)) + print(f"๐Ÿ…ฑ๏ธ Host B received: {received_on_b}") + await stream.close() + transfer_complete_b.set() + except Exception as e: + print(f"Handler B error: {e}") + transfer_complete_b.set() + + # Set up handlers on both hosts + protocol_a_to_b = TProtocol("/test/tcp-a-to-b/1.0.0") + protocol_b_to_a = TProtocol("/test/tcp-b-to-a/1.0.0") + + host_a.set_stream_handler(protocol_b_to_a, handler_a) + host_b.set_stream_handler(protocol_a_to_b, handler_b) + + async with ( + host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + host_b.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]), + ): + # Get addresses + addrs_a = host_a.get_addrs() + addrs_b = host_b.get_addrs() + + assert addrs_a and addrs_b, "Both hosts should have addresses" + + # Extract TCP addresses + tcp_addr_a = next( + ( + addr + for addr in addrs_a + if "/tcp/" in str(addr) and "/ws" not in str(addr) + ), + None, + ) + tcp_addr_b = next( + ( + addr + for addr in addrs_b + if "/tcp/" in str(addr) and "/ws" not in str(addr) + ), + None, + ) + + assert tcp_addr_a and tcp_addr_b, ( + f"TCP addresses not found: A={addrs_a}, B={addrs_b}" + ) + print(f"๐Ÿ”— Host A listening on: {tcp_addr_a}") + print(f"๐Ÿ”— Host B listening on: {tcp_addr_b}") + + # Create peer infos + peer_info_a = info_from_p2p_addr(tcp_addr_a) + peer_info_b = info_from_p2p_addr(tcp_addr_b) + + # Establish connections + await host_b.connect(peer_info_a) + await host_a.connect(peer_info_b) + print("โœ… Bidirectional TCP connections established") + + # Send data A -> B + stream_a_to_b = await host_a.new_stream(peer_info_b.peer_id, [protocol_a_to_b]) + await stream_a_to_b.write(data_a_to_b) + print(f"๐Ÿ“ค A->B: {data_a_to_b}") + await stream_a_to_b.close() + + # Send data B -> A + stream_b_to_a = await host_b.new_stream(peer_info_a.peer_id, [protocol_b_to_a]) + await stream_b_to_a.write(data_b_to_a) + print(f"๐Ÿ“ค B->A: {data_b_to_a}") + await stream_b_to_a.close() + + # Wait for both transfers to complete + with trio.fail_after(5.0): + await transfer_complete_a.wait() + await transfer_complete_b.wait() + + # Verify bidirectional transfer + assert received_on_a == data_b_to_a, f"A received wrong data: {received_on_a}" + assert received_on_b == data_a_to_b, f"B received wrong data: {received_on_b}" + + print("โœ… TCP bidirectional data transfer successful!") + print(f" A->B: {data_a_to_b}") + print(f" B->A: {data_b_to_a}") + print(f" โœ“ A got: {received_on_a}") + print(f" โœ“ B got: {received_on_b}") + + +if __name__ == "__main__": + # Run tests directly + import logging + + logging.basicConfig(level=logging.INFO) + + print("๐Ÿงช Running TCP P2P Data Transfer Tests") + print("=" * 50) + + async def run_all_tcp_tests(): + try: + print("\n1. Testing basic TCP connection...") + await test_tcp_basic_connection() + except Exception as e: + print(f"โŒ Basic TCP connection test failed: {e}") + return + + try: + print("\n2. Testing TCP data transfer...") + await test_tcp_data_transfer() + except Exception as e: + print(f"โŒ TCP data transfer test failed: {e}") + return + + try: + print("\n3. Testing TCP large data transfer...") + await test_tcp_large_data_transfer() + except Exception as e: + print(f"โŒ TCP large data transfer test failed: {e}") + return + + try: + print("\n4. Testing TCP bidirectional transfer...") + await test_tcp_bidirectional_transfer() + except Exception as e: + print(f"โŒ TCP bidirectional transfer test failed: {e}") + return + + print("\n" + "=" * 50) + print("๐Ÿ TCP P2P Tests Complete - All Tests PASSED!") + + trio.run(run_all_tcp_tests) diff --git a/libp2p/__init__.py b/libp2p/__init__.py index 3679409f..73180915 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -1,6 +1,7 @@ """Libp2p Python implementation.""" import logging +import ssl from libp2p.transport.quic.utils import is_quic_multiaddr from typing import Any @@ -179,6 +180,8 @@ def new_swarm( enable_quic: bool = False, retry_config: Optional["RetryConfig"] = None, connection_config: ConnectionConfig | QUICTransportConfig | None = None, + tls_client_config: ssl.SSLContext | None = None, + tls_server_config: ssl.SSLContext | None = None, ) -> INetworkService: """ Create a swarm instance based on the parameters. @@ -190,7 +193,9 @@ def new_swarm( :param muxer_preference: optional explicit muxer preference :param listen_addrs: optional list of multiaddrs to listen on :param enable_quic: enable quic for transport - :param quic_transport_opt: options for transport + :param connection_config: options for transport configuration + :param tls_client_config: optional TLS configuration for WebSocket client connections (WSS) + :param tls_server_config: optional TLS configuration for WebSocket server connections (WSS) :return: return a default swarm instance Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer @@ -249,14 +254,18 @@ def new_swarm( else: # Use the first address to determine transport type addr = listen_addrs[0] - transport_maybe = create_transport_for_multiaddr(addr, upgrader) + transport_maybe = create_transport_for_multiaddr( + addr, + upgrader, + private_key=key_pair.private_key, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config + ) if transport_maybe is None: # Fallback to TCP if no specific transport found if addr.__contains__("tcp"): transport = TCP() - elif addr.__contains__("quic"): - raise ValueError("QUIC not yet supported") else: supported_protocols = get_supported_transport_protocols() raise ValueError( @@ -293,6 +302,8 @@ def new_host( negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT, enable_quic: bool = False, quic_transport_opt: QUICTransportConfig | None = None, + tls_client_config: ssl.SSLContext | None = None, + tls_server_config: ssl.SSLContext | None = None, ) -> IHost: """ Create a new libp2p host based on the given parameters. @@ -307,7 +318,9 @@ def new_host( :param enable_mDNS: whether to enable mDNS discovery :param bootstrap: optional list of bootstrap peer addresses as strings :param enable_quic: optinal choice to use QUIC for transport - :param transport_opt: optional configuration for quic transport + :param quic_transport_opt: optional configuration for quic transport + :param tls_client_config: optional TLS configuration for WebSocket client connections (WSS) + :param tls_server_config: optional TLS configuration for WebSocket server connections (WSS) :return: return a host instance """ @@ -322,7 +335,9 @@ def new_host( peerstore_opt=peerstore_opt, muxer_preference=muxer_preference, listen_addrs=listen_addrs, - connection_config=quic_transport_opt if enable_quic else None + connection_config=quic_transport_opt if enable_quic else None, + tls_client_config=tls_client_config, + tls_server_config=tls_server_config ) if disc_opt is not None: diff --git a/libp2p/transport/__init__.py b/libp2p/transport/__init__.py index 29b3e63b..ebc587e5 100644 --- a/libp2p/transport/__init__.py +++ b/libp2p/transport/__init__.py @@ -1,3 +1,5 @@ +from typing import Any + from .tcp.tcp import TCP from .websocket.transport import WebsocketTransport from .transport_registry import ( @@ -10,7 +12,7 @@ from .transport_registry import ( from .upgrader import TransportUpgrader from libp2p.abc import ITransport -def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs) -> ITransport: +def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport: """ Convenience function to create a transport instance. diff --git a/libp2p/transport/transport_registry.py b/libp2p/transport/transport_registry.py index db783395..eb965655 100644 --- a/libp2p/transport/transport_registry.py +++ b/libp2p/transport/transport_registry.py @@ -2,6 +2,7 @@ Transport registry for dynamic transport selection based on multiaddr protocols. """ +from collections.abc import Callable import logging from typing import Any @@ -16,8 +17,21 @@ from libp2p.transport.websocket.multiaddr_utils import ( ) +# Import QUIC utilities here to avoid circular imports +def _get_quic_transport() -> Any: + from libp2p.transport.quic.transport import QUICTransport + + return QUICTransport + + +def _get_quic_validation() -> Callable[[Multiaddr], bool]: + from libp2p.transport.quic.utils import is_quic_multiaddr + + return is_quic_multiaddr + + # Import WebsocketTransport here to avoid circular imports -def _get_websocket_transport(): +def _get_websocket_transport() -> Any: from libp2p.transport.websocket.transport import WebsocketTransport return WebsocketTransport @@ -85,6 +99,11 @@ class TransportRegistry: self.register_transport("ws", WebsocketTransport) self.register_transport("wss", WebsocketTransport) + # Register QUIC transport for /quic and /quic-v1 protocols + QUICTransport = _get_quic_transport() + self.register_transport("quic", QUICTransport) + self.register_transport("quic-v1", QUICTransport) + def register_transport( self, protocol: str, transport_class: type[ITransport] ) -> None: @@ -137,7 +156,22 @@ class TransportRegistry: return None # Use explicit WebsocketTransport to avoid type issues WebsocketTransport = _get_websocket_transport() - return WebsocketTransport(upgrader) + return WebsocketTransport( + upgrader, + tls_client_config=kwargs.get("tls_client_config"), + tls_server_config=kwargs.get("tls_server_config"), + handshake_timeout=kwargs.get("handshake_timeout", 15.0), + ) + elif protocol in ["quic", "quic-v1"]: + # QUIC transport requires private_key + private_key = kwargs.get("private_key") + if private_key is None: + logger.warning(f"QUIC transport '{protocol}' requires private_key") + return None + # Use explicit QUICTransport to avoid type issues + QUICTransport = _get_quic_transport() + config = kwargs.get("config") + return QUICTransport(private_key, config) else: # TCP transport doesn't require upgrader return transport_class() @@ -161,13 +195,15 @@ def register_transport(protocol: str, transport_class: type[ITransport]) -> None def create_transport_for_multiaddr( - maddr: Multiaddr, upgrader: TransportUpgrader + maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any ) -> ITransport | None: """ Create the appropriate transport for a given multiaddr. :param maddr: The multiaddr to create transport for :param upgrader: The transport upgrader instance + :param kwargs: Additional arguments for transport construction + (e.g., private_key for QUIC) :return: Transport instance or None if no suitable transport found """ try: @@ -176,7 +212,20 @@ def create_transport_for_multiaddr( # Check for supported transport protocols in order of preference # We need to validate that the multiaddr structure is valid for our transports - if "ws" in protocols or "wss" in protocols or "tls" in protocols: + if "quic" in protocols or "quic-v1" in protocols: + # For QUIC, we need a valid structure like: + # /ip4/127.0.0.1/udp/4001/quic + # /ip4/127.0.0.1/udp/4001/quic-v1 + is_quic_multiaddr = _get_quic_validation() + if is_quic_multiaddr(maddr): + # Determine QUIC version + if "quic-v1" in protocols: + return _global_registry.create_transport( + "quic-v1", upgrader, **kwargs + ) + else: + return _global_registry.create_transport("quic", upgrader, **kwargs) + elif "ws" in protocols or "wss" in protocols or "tls" in protocols: # For WebSocket, we need a valid structure like: # /ip4/127.0.0.1/tcp/8080/ws (insecure) # /ip4/127.0.0.1/tcp/8080/wss (secure) @@ -185,9 +234,9 @@ def create_transport_for_multiaddr( if is_valid_websocket_multiaddr(maddr): # Determine if this is a secure WebSocket connection if "wss" in protocols or "tls" in protocols: - return _global_registry.create_transport("wss", upgrader) + return _global_registry.create_transport("wss", upgrader, **kwargs) else: - return _global_registry.create_transport("ws", upgrader) + return _global_registry.create_transport("ws", upgrader, **kwargs) elif "tcp" in protocols: # For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080 # Check if the multiaddr has proper TCP structure diff --git a/libp2p/transport/websocket/connection.py b/libp2p/transport/websocket/connection.py index f5a99b7e..68c1eb76 100644 --- a/libp2p/transport/websocket/connection.py +++ b/libp2p/transport/websocket/connection.py @@ -35,11 +35,9 @@ class P2PWebSocketConnection(ReadWriteCloser): raise IOException("Connection is closed") try: - logger.debug(f"WebSocket writing {len(data)} bytes") # Send as a binary WebSocket message await self._ws_connection.send_message(data) self._bytes_written += len(data) - logger.debug(f"WebSocket wrote {len(data)} bytes successfully") except Exception as e: logger.error(f"WebSocket write failed: {e}") raise IOException from e @@ -48,95 +46,70 @@ class P2PWebSocketConnection(ReadWriteCloser): """ Read up to n bytes (if n is given), else read up to 64KiB. This implementation provides byte-level access to WebSocket messages, - which is required for Noise protocol handshake. + which is required for libp2p protocol compatibility. + + For WebSocket compatibility with libp2p protocols, this method: + 1. Buffers incoming WebSocket messages + 2. Returns exactly the requested number of bytes when n is specified + 3. Accumulates multiple WebSocket messages if needed to satisfy the request + 4. Returns empty bytes (not raises) when connection is closed and no data + available """ if self._closed: raise IOException("Connection is closed") async with self._read_lock: try: - logger.debug( - f"WebSocket read requested: n={n}, " - f"buffer_size={len(self._read_buffer)}" - ) - - # If we have buffered data, return it - if self._read_buffer: - if n is None: - result = self._read_buffer - self._read_buffer = b"" - self._bytes_read += len(result) - logger.debug( - f"WebSocket read returning all buffered data: " - f"{len(result)} bytes" - ) - return result - else: - if len(self._read_buffer) >= n: - result = self._read_buffer[:n] - self._read_buffer = self._read_buffer[n:] - self._bytes_read += len(result) - logger.debug( - f"WebSocket read returning {len(result)} bytes " - f"from buffer" - ) - return result - else: - # We need more data, but we have some buffered - # Keep the buffered data and get more - logger.debug( - f"WebSocket read needs more data: have " - f"{len(self._read_buffer)}, need {n}" - ) - pass - - # If we need exactly n bytes but don't have enough, get more data - while n is not None and ( - not self._read_buffer or len(self._read_buffer) < n - ): - logger.debug( - f"WebSocket read getting more data: " - f"buffer_size={len(self._read_buffer)}, need={n}" - ) - # Get the next WebSocket message and treat it as a byte stream - # This mimics the Go implementation's NextReader() approach - message = await self._ws_connection.get_message() - if isinstance(message, str): - message = message.encode("utf-8") - - logger.debug( - f"WebSocket read received message: {len(message)} bytes" - ) - # Add to buffer - self._read_buffer += message - - # Return requested amount + # If n is None, read at least one message and return all buffered data if n is None: + if not self._read_buffer: + try: + # Use a short timeout to avoid blocking indefinitely + with trio.fail_after(1.0): # 1 second timeout + message = await self._ws_connection.get_message() + if isinstance(message, str): + message = message.encode("utf-8") + self._read_buffer = message + except trio.TooSlowError: + # No message available within timeout + return b"" + except Exception: + # Return empty bytes if no data available + # (connection closed) + return b"" + result = self._read_buffer self._read_buffer = b"" self._bytes_read += len(result) - logger.debug( - f"WebSocket read returning all data: {len(result)} bytes" - ) return result - else: - if len(self._read_buffer) >= n: - result = self._read_buffer[:n] - self._read_buffer = self._read_buffer[n:] - self._bytes_read += len(result) - logger.debug( - f"WebSocket read returning exact {len(result)} bytes" - ) - return result - else: - # This should never happen due to the while loop above - result = self._read_buffer - self._read_buffer = b"" - self._bytes_read += len(result) - logger.debug( - f"WebSocket read returning remaining {len(result)} bytes" - ) - return result + + # For specific byte count requests, return UP TO n bytes (not exactly n) + # This matches TCP semantics where read(1024) returns available data + # up to 1024 bytes + + # If we don't have any data buffered, try to get at least one message + if not self._read_buffer: + try: + # Use a short timeout to avoid blocking indefinitely + with trio.fail_after(1.0): # 1 second timeout + message = await self._ws_connection.get_message() + if isinstance(message, str): + message = message.encode("utf-8") + self._read_buffer = message + except trio.TooSlowError: + return b"" # No data available + except Exception: + return b"" + + # Now return up to n bytes from the buffer (TCP-like semantics) + if len(self._read_buffer) == 0: + return b"" + + # Return up to n bytes (like TCP read()) + result = self._read_buffer[:n] + self._read_buffer = self._read_buffer[len(result) :] + self._bytes_read += len(result) + return result except Exception as e: logger.error(f"WebSocket read failed: {e}") @@ -148,17 +121,18 @@ class P2PWebSocketConnection(ReadWriteCloser): if self._closed: return # Already closed + logger.debug("WebSocket connection closing") try: - # Close the WebSocket connection + # Always close the connection directly, avoid context manager issues + # The context manager may be causing cancel scope corruption + logger.debug("WebSocket closing connection directly") await self._ws_connection.aclose() - # Exit the context manager if we have one - if self._ws_context is not None: - await self._ws_context.__aexit__(None, None, None) except Exception as e: logger.error(f"WebSocket close error: {e}") # Don't raise here, as close() should be idempotent finally: self._closed = True + logger.debug("WebSocket connection closed") def conn_state(self) -> dict[str, Any]: """ diff --git a/libp2p/transport/websocket/listener.py b/libp2p/transport/websocket/listener.py index 5f5cf106..1ea3bc9b 100644 --- a/libp2p/transport/websocket/listener.py +++ b/libp2p/transport/websocket/listener.py @@ -38,6 +38,7 @@ class WebsocketListener(IListener): self._shutdown_event = trio.Event() self._nursery: trio.Nursery | None = None self._listeners: Any = None + self._is_wss = False # Track whether this is a WSS listener async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: logger.debug(f"WebsocketListener.listen called with {maddr}") @@ -54,6 +55,9 @@ class WebsocketListener(IListener): f"Cannot listen on WSS address {maddr} without TLS configuration" ) + # Store whether this is a WSS listener + self._is_wss = parsed.is_wss + # Extract host and port from the base multiaddr host = ( parsed.rest_multiaddr.value_for_protocol("ip4") @@ -169,16 +173,16 @@ class WebsocketListener(IListener): if hasattr(self._listeners, "port"): # This is a WebSocketServer object port = self._listeners.port - # Create a multiaddr from the port - # Note: We don't know if this is WS or WSS from the server object - # For now, assume WS - this could be improved by storing the original multiaddr - return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/ws"),) + # Create a multiaddr from the port with correct WSS/WS protocol + protocol = "wss" if self._is_wss else "ws" + return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),) else: # This is a list of listeners (like TCP) listeners = self._listeners # Get addresses from listeners like TCP does return tuple( - _multiaddr_from_socket(listener.socket) for listener in listeners + _multiaddr_from_socket(listener.socket, self._is_wss) + for listener in listeners ) async def close(self) -> None: @@ -212,7 +216,10 @@ class WebsocketListener(IListener): logger.debug("WebsocketListener.close completed") -def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr: +def _multiaddr_from_socket( + socket: trio.socket.SocketType, is_wss: bool = False +) -> Multiaddr: """Convert socket to multiaddr""" ip, port = socket.getsockname() - return Multiaddr(f"/ip4/{ip}/tcp/{port}/ws") + protocol = "wss" if is_wss else "ws" + return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}") diff --git a/libp2p/transport/websocket/multiaddr_utils.py b/libp2p/transport/websocket/multiaddr_utils.py index 57030c11..16a38073 100644 --- a/libp2p/transport/websocket/multiaddr_utils.py +++ b/libp2p/transport/websocket/multiaddr_utils.py @@ -125,7 +125,7 @@ def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool: # Find the WebSocket protocol ws_protocol_found = False tls_found = False - sni_found = False + # sni_found = False # Not used currently for i, protocol in enumerate(protocols[2:], start=2): if protocol.name in ws_protocols: @@ -134,7 +134,7 @@ def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool: elif protocol.name in tls_protocols: tls_found = True elif protocol.name in sni_protocols: - # sni_found = True # Not used in current implementation + pass # sni_found = True # Not used in current implementation if not ws_protocol_found: return False diff --git a/libp2p/transport/websocket/transport.py b/libp2p/transport/websocket/transport.py index fc8867a5..d9253c3f 100644 --- a/libp2p/transport/websocket/transport.py +++ b/libp2p/transport/websocket/transport.py @@ -2,7 +2,6 @@ import logging import ssl from multiaddr import Multiaddr -import trio from libp2p.abc import IListener, ITransport from libp2p.custom_types import THandler @@ -68,8 +67,6 @@ class WebsocketTransport(ITransport): ) try: - from trio_websocket import open_websocket_url - # Prepare SSL context for WSS connections ssl_context = None if parsed.is_wss: @@ -83,19 +80,63 @@ class WebsocketTransport(ITransport): ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - # Use the context manager but don't exit it immediately - # The connection will be closed when the RawConnection is closed - ws_context = open_websocket_url(ws_url, ssl_context=ssl_context) + logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}") - # Apply handshake timeout + # Use a different approach: start background nursery that will persist + logger.debug("WebsocketTransport.dial establishing connection") + + # Import trio-websocket functions + from trio_websocket import connect_websocket + from trio_websocket._impl import _url_to_host + + # Parse the WebSocket URL to get host, port, resource + # like trio-websocket does + ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host( + ws_url, ssl_context + ) + + logger.debug( + f"WebsocketTransport.dial parsed URL: host={ws_host}, " + f"port={ws_port}, resource={ws_resource}" + ) + + # Instead of fighting trio-websocket's lifecycle, let's try using + # a persistent task that will keep the WebSocket alive + # This mimics what trio-websocket does internally but with our control + + # Create a background task manager for this connection + import trio + + nursery_manager = trio.lowlevel.current_task().parent_nursery + if nursery_manager is None: + raise OpenConnectionError( + f"No parent nursery available for WebSocket connection to {maddr}" + ) + + # Apply timeout to the connection process with trio.fail_after(self._handshake_timeout): - ws = await ws_context.__aenter__() + logger.debug("WebsocketTransport.dial connecting WebSocket") + ws = await connect_websocket( + nursery_manager, # Use the existing nursery from libp2p + ws_host, + ws_port, + ws_resource, + use_ssl=ws_ssl_context, + message_queue_size=1024, # Reasonable defaults + max_message_size=16 * 1024 * 1024, # 16MB max message + ) + logger.debug("WebsocketTransport.dial WebSocket connection established") - conn = P2PWebSocketConnection(ws, ws_context, is_secure=parsed.is_wss) # type: ignore[attr-defined] - return RawConnection(conn, initiator=True) + # Create our connection wrapper + # Pass None for nursery since we're using the parent nursery + conn = P2PWebSocketConnection(ws, None, is_secure=parsed.is_wss) + logger.debug("WebsocketTransport.dial created P2PWebSocketConnection") + + return RawConnection(conn, initiator=True) except trio.TooSlowError as e: raise OpenConnectionError( - f"WebSocket handshake timeout after {self._handshake_timeout}s for {maddr}" + f"WebSocket handshake timeout after {self._handshake_timeout}s " + f"for {maddr}" ) from e except Exception as e: raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e @@ -149,7 +190,8 @@ class WebsocketTransport(ITransport): return [maddr] # Create new multiaddr with SNI - # For /dns/example.com/tcp/8080/wss -> /dns/example.com/tcp/8080/tls/sni/example.com/ws + # For /dns/example.com/tcp/8080/wss -> + # /dns/example.com/tcp/8080/tls/sni/example.com/ws try: # Remove /wss and add /tls/sni/example.com/ws without_wss = maddr.decapsulate(Multiaddr("/wss")) diff --git a/test_websocket_client.py b/test_websocket_client.py deleted file mode 100755 index 984a93ef..00000000 --- a/test_websocket_client.py +++ /dev/null @@ -1,243 +0,0 @@ -#!/usr/bin/env python3 -""" -Standalone WebSocket client for testing py-libp2p WebSocket transport. -This script allows you to test the Python WebSocket client independently. -""" - -import argparse -import logging -import sys - -from multiaddr import Multiaddr -import trio - -from libp2p import create_yamux_muxer_option, new_host -from libp2p.crypto.secp256k1 import create_new_key_pair -from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair -from libp2p.custom_types import TProtocol -from libp2p.network.exceptions import SwarmException -from libp2p.peer.id import ID -from libp2p.peer.peerinfo import info_from_p2p_addr -from libp2p.security.noise.transport import ( - PROTOCOL_ID as NOISE_PROTOCOL_ID, - Transport as NoiseTransport, -) -from libp2p.transport.websocket.multiaddr_utils import ( - is_valid_websocket_multiaddr, - parse_websocket_multiaddr, -) - -# Configure logging -logging.basicConfig( - level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - -# Enable debug logging for WebSocket transport -logging.getLogger("libp2p.transport.websocket").setLevel(logging.DEBUG) -logging.getLogger("libp2p.network.swarm").setLevel(logging.DEBUG) - -PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0") - - -async def test_websocket_connection(destination: str, timeout: int = 30) -> bool: - """ - Test WebSocket connection to a destination multiaddr. - - Args: - destination: Multiaddr string (e.g., /ip4/127.0.0.1/tcp/8080/ws/p2p/...) - timeout: Connection timeout in seconds - - Returns: - True if connection successful, False otherwise - - """ - try: - # Parse the destination multiaddr - maddr = Multiaddr(destination) - logger.info(f"Testing connection to: {maddr}") - - # Validate WebSocket multiaddr - if not is_valid_websocket_multiaddr(maddr): - logger.error(f"Invalid WebSocket multiaddr: {maddr}") - return False - - # Parse WebSocket multiaddr - try: - parsed = parse_websocket_multiaddr(maddr) - logger.info( - f"Parsed WebSocket multiaddr: is_wss={parsed.is_wss}, sni={parsed.sni}, rest_multiaddr={parsed.rest_multiaddr}" - ) - except Exception as e: - logger.error(f"Failed to parse WebSocket multiaddr: {e}") - return False - - # Extract peer ID from multiaddr - try: - peer_id = ID.from_base58(maddr.value_for_protocol("p2p")) - logger.info(f"Target peer ID: {peer_id}") - except Exception as e: - logger.error(f"Failed to extract peer ID from multiaddr: {e}") - return False - - # Create Python host using professional pattern - logger.info("Creating Python host...") - key_pair = create_new_key_pair() - py_peer_id = ID.from_pubkey(key_pair.public_key) - logger.info(f"Python Peer ID: {py_peer_id}") - - # Generate X25519 keypair for Noise - noise_key_pair = create_new_x25519_key_pair() - - # Create security options (following professional pattern) - security_options = { - NOISE_PROTOCOL_ID: NoiseTransport( - libp2p_keypair=key_pair, - noise_privkey=noise_key_pair.private_key, - early_data=None, - with_noise_pipes=False, - ) - } - - # Create muxer options - muxer_options = create_yamux_muxer_option() - - # Create host with proper configuration - host = new_host( - key_pair=key_pair, - sec_opt=security_options, - muxer_opt=muxer_options, - listen_addrs=[ - Multiaddr("/ip4/0.0.0.0/tcp/0/ws") - ], # WebSocket listen address - ) - logger.info(f"Python host created: {host}") - - # Create peer info using professional helper - peer_info = info_from_p2p_addr(maddr) - logger.info(f"Connecting to: {peer_info}") - - # Start the host - logger.info("Starting host...") - async with host.run(listen_addrs=[]): - # Wait a moment for host to be ready - await trio.sleep(1) - - # Attempt connection with timeout - logger.info("Attempting to connect...") - try: - with trio.fail_after(timeout): - await host.connect(peer_info) - logger.info("โœ… Successfully connected to peer!") - - # Test ping protocol (following professional pattern) - logger.info("Testing ping protocol...") - try: - stream = await host.new_stream( - peer_info.peer_id, [PING_PROTOCOL_ID] - ) - logger.info("โœ… Successfully created ping stream!") - - # Send ping (32 bytes as per libp2p ping protocol) - ping_data = b"\x01" * 32 - await stream.write(ping_data) - logger.info(f"โœ… Sent ping: {len(ping_data)} bytes") - - # Wait for pong (should be same 32 bytes) - pong_data = await stream.read(32) - logger.info(f"โœ… Received pong: {len(pong_data)} bytes") - - if pong_data == ping_data: - logger.info("โœ… Ping-pong test successful!") - return True - else: - logger.error( - f"โŒ Unexpected pong data: expected {len(ping_data)} bytes, got {len(pong_data)} bytes" - ) - return False - - except Exception as e: - logger.error(f"โŒ Ping protocol test failed: {e}") - return False - - except trio.TooSlowError: - logger.error(f"โŒ Connection timeout after {timeout} seconds") - return False - except SwarmException as e: - logger.error(f"โŒ Connection failed with SwarmException: {e}") - # Log the underlying error details - if hasattr(e, "__cause__") and e.__cause__: - logger.error(f"Underlying error: {e.__cause__}") - return False - except Exception as e: - logger.error(f"โŒ Connection failed with unexpected error: {e}") - import traceback - - logger.error(f"Full traceback: {traceback.format_exc()}") - return False - - except Exception as e: - logger.error(f"โŒ Test failed with error: {e}") - return False - - -async def main(): - """Main function to run the WebSocket client test.""" - parser = argparse.ArgumentParser( - description="Test py-libp2p WebSocket client connection", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Test connection to a WebSocket peer - python test_websocket_client.py /ip4/127.0.0.1/tcp/8080/ws/p2p/12D3KooW... - - # Test with custom timeout - python test_websocket_client.py /ip4/127.0.0.1/tcp/8080/ws/p2p/12D3KooW... --timeout 60 - - # Test WSS connection - python test_websocket_client.py /ip4/127.0.0.1/tcp/8080/wss/p2p/12D3KooW... - """, - ) - - parser.add_argument( - "destination", - help="Destination multiaddr (e.g., /ip4/127.0.0.1/tcp/8080/ws/p2p/12D3KooW...)", - ) - - parser.add_argument( - "--timeout", - type=int, - default=30, - help="Connection timeout in seconds (default: 30)", - ) - - parser.add_argument( - "--verbose", "-v", action="store_true", help="Enable verbose logging" - ) - - args = parser.parse_args() - - # Set logging level - if args.verbose: - logging.getLogger().setLevel(logging.DEBUG) - else: - logging.getLogger().setLevel(logging.INFO) - - logger.info("๐Ÿš€ Starting WebSocket client test...") - logger.info(f"Destination: {args.destination}") - logger.info(f"Timeout: {args.timeout}s") - - # Run the test - success = await test_websocket_connection(args.destination, args.timeout) - - if success: - logger.info("๐ŸŽ‰ WebSocket client test completed successfully!") - sys.exit(0) - else: - logger.error("๐Ÿ’ฅ WebSocket client test failed!") - sys.exit(1) - - -if __name__ == "__main__": - # Run with trio - trio.run(main) diff --git a/tests/core/transport/test_websocket.py b/tests/core/transport/test_websocket.py index cf2e2d5e..53f78aac 100644 --- a/tests/core/transport/test_websocket.py +++ b/tests/core/transport/test_websocket.py @@ -3,6 +3,7 @@ import logging from typing import Any import pytest +from exceptiongroup import ExceptionGroup from multiaddr import Multiaddr import trio @@ -623,6 +624,7 @@ async def test_websocket_data_exchange(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # WebSocket transport ) # Test data @@ -675,7 +677,10 @@ async def test_websocket_data_exchange(): @pytest.mark.trio async def test_websocket_host_pair_data_exchange(): - """Test WebSocket host pair with actual data exchange using host_pair_factory pattern""" + """ + Test WebSocket host pair with actual data exchange using host_pair_factory + pattern. + """ from libp2p import create_yamux_muxer_option, new_host from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.custom_types import TProtocol @@ -712,6 +717,7 @@ async def test_websocket_host_pair_data_exchange(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # WebSocket transport ) # Test data @@ -784,16 +790,102 @@ async def test_wss_host_pair_data_exchange(): InsecureTransport, ) - # Create TLS context for WSS - tls_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - tls_context.check_hostname = False - tls_context.verify_mode = ssl.CERT_NONE + # Create TLS contexts for WSS (separate for client and server) + # For testing, we need to create a self-signed certificate + try: + import datetime + import ipaddress + import os + import tempfile + + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + + # Generate private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + # Create certificate + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), # type: ignore + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"), # type: ignore + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"), # type: ignore + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"), # type: ignore + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), # type: ignore + ] + ) + + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.UTC)) + .not_valid_after( + datetime.datetime.now(datetime.UTC) + datetime.timedelta(days=1) + ) + .add_extension( + x509.SubjectAlternativeName( + [ + x509.DNSName("localhost"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ] + ), + critical=False, + ) + .sign(private_key, hashes.SHA256()) + ) + + # Create temporary files for cert and key + cert_file = tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".crt") + key_file = tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".key") + + # Write certificate and key to files + cert_file.write(cert.public_bytes(serialization.Encoding.PEM)) + key_file.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + + cert_file.close() + key_file.close() + + # Server context for listener (Host A) + server_tls_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + server_tls_context.load_cert_chain(cert_file.name, key_file.name) + + # Client context for dialer (Host B) + client_tls_context = ssl.create_default_context() + client_tls_context.check_hostname = False + client_tls_context.verify_mode = ssl.CERT_NONE + + # Clean up temp files after use + def cleanup_certs(): + try: + os.unlink(cert_file.name) + os.unlink(key_file.name) + except Exception: + pass + + except ImportError: + pytest.skip("cryptography package required for WSS tests") + except Exception as e: + pytest.skip(f"Failed to create test certificates: {e}") # Create two hosts with WSS transport and plaintext security key_pair_a = create_new_key_pair() key_pair_b = create_new_key_pair() - # Host A (listener) - WSS transport + # Host A (listener) - WSS transport with server TLS config security_options_a = { PLAINTEXT_PROTOCOL_ID: InsecureTransport( local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None @@ -804,9 +896,10 @@ async def test_wss_host_pair_data_exchange(): sec_opt=security_options_a, muxer_opt=create_yamux_muxer_option(), listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], + tls_server_config=server_tls_context, ) - # Host B (dialer) - WSS transport + # Host B (dialer) - WSS transport with client TLS config security_options_b = { PLAINTEXT_PROTOCOL_ID: InsecureTransport( local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None @@ -816,6 +909,8 @@ async def test_wss_host_pair_data_exchange(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")], # Ensure WSS transport + tls_client_config=client_tls_context, ) # Test data @@ -1028,7 +1123,7 @@ async def test_wss_transport_without_tls_config(): @pytest.mark.trio async def test_wss_dial_parsing(): """Test WSS dial functionality with multiaddr parsing.""" - upgrader = create_upgrader() + # upgrader = create_upgrader() # Not used in this test # transport = WebsocketTransport(upgrader) # Not used in this test # Test WSS multiaddr parsing in dial @@ -1085,10 +1180,15 @@ async def test_wss_listen_without_tls_config(): listener = transport.create_listener(dummy_handler) # This should raise an error when trying to listen on WSS without TLS config - with pytest.raises( - ValueError, match="Cannot listen on WSS address.*without TLS configuration" - ): - await listener.listen(wss_maddr, trio.open_nursery()) + with pytest.raises(ExceptionGroup) as exc_info: + async with trio.open_nursery() as nursery: + await listener.listen(wss_maddr, nursery) + + # Check that the ExceptionGroup contains the expected ValueError + assert len(exc_info.value.exceptions) == 1 + assert isinstance(exc_info.value.exceptions[0], ValueError) + assert "Cannot listen on WSS address" in str(exc_info.value.exceptions[0]) + assert "without TLS configuration" in str(exc_info.value.exceptions[0]) @pytest.mark.trio @@ -1213,7 +1313,7 @@ def test_wss_vs_ws_distinction(): @pytest.mark.trio async def test_wss_connection_handling(): """Test WSS connection handling with security flag.""" - upgrader = create_upgrader() + # upgrader = create_upgrader() # Not used in this test # transport = WebsocketTransport(upgrader) # Not used in this test # Test that WSS connections are marked as secure @@ -1263,7 +1363,9 @@ async def test_handshake_timeout(): await trio.sleep(0) listener = transport.create_listener(dummy_handler) - assert listener._handshake_timeout == 0.1 + # Type assertion to access private attribute for testing + assert hasattr(listener, "_handshake_timeout") + assert getattr(listener, "_handshake_timeout") == 0.1 @pytest.mark.trio @@ -1275,11 +1377,14 @@ async def test_handshake_timeout_creation(): from libp2p.transport import create_transport transport = create_transport("ws", upgrader, handshake_timeout=5.0) - assert transport._handshake_timeout == 5.0 + # Type assertion to access private attribute for testing + assert hasattr(transport, "_handshake_timeout") + assert getattr(transport, "_handshake_timeout") == 5.0 # Test default timeout transport_default = create_transport("ws", upgrader) - assert transport_default._handshake_timeout == 15.0 + assert hasattr(transport_default, "_handshake_timeout") + assert getattr(transport_default, "_handshake_timeout") == 15.0 @pytest.mark.trio @@ -1310,7 +1415,8 @@ async def test_connection_state_tracking(): assert state["total_bytes"] == 0 assert state["connection_duration"] >= 0 - # Test byte tracking (we can't actually read/write with mock, but we can test the method) + # Test byte tracking (we can't actually read/write with mock, but we can test + # the method) # The actual byte tracking will be tested in integration tests assert hasattr(conn, "_bytes_read") assert hasattr(conn, "_bytes_written") @@ -1396,7 +1502,7 @@ async def test_zero_byte_write_handling(): @pytest.mark.trio async def test_websocket_transport_protocols(): """Test that WebSocket transport reports correct protocols.""" - upgrader = create_upgrader() + # upgrader = create_upgrader() # Not used in this test # transport = WebsocketTransport(upgrader) # Not used in this test # Test that the transport can handle both WS and WSS protocols @@ -1427,7 +1533,9 @@ async def test_websocket_listener_addr_format(): await trio.sleep(0) listener_ws = transport_ws.create_listener(dummy_handler_ws) - assert listener_ws._handshake_timeout == 15.0 # Default timeout + # Type assertion to access private attribute for testing + assert hasattr(listener_ws, "_handshake_timeout") + assert getattr(listener_ws, "_handshake_timeout") == 15.0 # Default timeout # Test WSS listener with TLS config import ssl @@ -1439,13 +1547,19 @@ async def test_websocket_listener_addr_format(): await trio.sleep(0) listener_wss = transport_wss.create_listener(dummy_handler_wss) - assert listener_wss._tls_config is not None - assert listener_wss._handshake_timeout == 15.0 + # Type assertion to access private attributes for testing + assert hasattr(listener_wss, "_tls_config") + assert getattr(listener_wss, "_tls_config") is not None + assert hasattr(listener_wss, "_handshake_timeout") + assert getattr(listener_wss, "_handshake_timeout") == 15.0 @pytest.mark.trio async def test_sni_resolution_limitation(): - """Test SNI resolution limitation - Python multiaddr library doesn't support SNI protocol.""" + """ + Test SNI resolution limitation - Python multiaddr library doesn't support + SNI protocol. + """ upgrader = create_upgrader() transport = WebsocketTransport(upgrader) @@ -1471,7 +1585,7 @@ async def test_sni_resolution_limitation(): @pytest.mark.trio async def test_websocket_transport_can_dial(): """Test WebSocket transport CanDial functionality similar to Go implementation.""" - upgrader = create_upgrader() + # upgrader = create_upgrader() # Not used in this test # transport = WebsocketTransport(upgrader) # Not used in this test # Test valid WebSocket addresses that should be dialable diff --git a/tests/core/transport/test_websocket_p2p.py b/tests/core/transport/test_websocket_p2p.py index 35867ace..2744bb34 100644 --- a/tests/core/transport/test_websocket_p2p.py +++ b/tests/core/transport/test_websocket_p2p.py @@ -8,7 +8,6 @@ including both WS and WSS (WebSocket Secure) scenarios. import pytest from multiaddr import Multiaddr -import trio from libp2p import create_yamux_muxer_option, new_host from libp2p.crypto.secp256k1 import create_new_key_pair @@ -58,6 +57,8 @@ async def test_websocket_p2p_plaintext(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport ) # Test data @@ -152,6 +153,8 @@ async def test_websocket_p2p_noise(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport ) # Test data @@ -246,6 +249,8 @@ async def test_websocket_p2p_libp2p_ping(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport ) # Set up ping handler on host A (standard libp2p ping protocol) @@ -301,7 +306,10 @@ async def test_websocket_p2p_libp2p_ping(): @pytest.mark.trio async def test_websocket_p2p_multiple_streams(): - """Test Python-to-Python WebSocket communication with multiple concurrent streams.""" + """ + Test Python-to-Python WebSocket communication with multiple concurrent + streams. + """ # Create two hosts with Noise security key_pair_a = create_new_key_pair() key_pair_b = create_new_key_pair() @@ -337,6 +345,8 @@ async def test_websocket_p2p_multiple_streams(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport ) # Test protocol @@ -385,7 +395,9 @@ async def test_websocket_p2p_multiple_streams(): return response # Run all streams concurrently - tasks = [create_stream_and_test(i, test_data_list[i]) for i in range(num_streams)] + tasks = [ + create_stream_and_test(i, test_data_list[i]) for i in range(num_streams) + ] responses = [] for task in tasks: responses.append(await task) @@ -439,6 +451,8 @@ async def test_websocket_p2p_connection_state(): key_pair=key_pair_b, sec_opt=security_options_b, muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket + # transport ) # Set up handler on host A @@ -488,21 +502,23 @@ async def test_websocket_p2p_connection_state(): # Get the connection to host A conn_to_a = None - for peer_id, conn in connections.items(): + for peer_id, conn_list in connections.items(): if peer_id == host_a.get_id(): - conn_to_a = conn + # connections maps peer_id to list of connections, get the first one + conn_to_a = conn_list[0] if conn_list else None break assert conn_to_a is not None, "Should have connection to host A" # Test that the connection has the expected properties assert hasattr(conn_to_a, "muxed_conn"), "Connection should have muxed_conn" - assert hasattr(conn_to_a.muxed_conn, "conn"), ( - "Muxed connection should have underlying conn" + assert hasattr(conn_to_a.muxed_conn, "secured_conn"), ( + "Muxed connection should have underlying secured_conn" ) # If the underlying connection is our WebSocket connection, test its state - underlying_conn = conn_to_a.muxed_conn.conn + # Type assertion to access private attribute for testing + underlying_conn = getattr(conn_to_a.muxed_conn, "secured_conn") if hasattr(underlying_conn, "conn_state"): state = underlying_conn.conn_state() assert "connection_start_time" in state, ( diff --git a/tests/interop/js_libp2p/js_node/src/package.json b/tests/interop/js_libp2p/js_node/src/package.json index e029c434..e5b1498f 100644 --- a/tests/interop/js_libp2p/js_node/src/package.json +++ b/tests/interop/js_libp2p/js_node/src/package.json @@ -13,7 +13,9 @@ "@libp2p/ping": "^2.0.36", "@libp2p/websockets": "^9.2.18", "@chainsafe/libp2p-yamux": "^5.0.1", + "@chainsafe/libp2p-noise": "^16.0.1", "@libp2p/plaintext": "^2.0.7", + "@libp2p/identify": "^3.0.39", "libp2p": "^2.9.0", "multiaddr": "^10.0.1" } diff --git a/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs b/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs index bff7b514..3951fc02 100644 --- a/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs +++ b/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs @@ -1,22 +1,76 @@ import { createLibp2p } from 'libp2p' import { webSockets } from '@libp2p/websockets' import { ping } from '@libp2p/ping' +import { noise } from '@chainsafe/libp2p-noise' import { plaintext } from '@libp2p/plaintext' import { yamux } from '@chainsafe/libp2p-yamux' +// import { identify } from '@libp2p/identify' // Commented out for compatibility + +// Configuration from environment (with defaults for compatibility) +const TRANSPORT = process.env.transport || 'ws' +const SECURITY = process.env.security || 'noise' +const MUXER = process.env.muxer || 'yamux' +const IP = process.env.ip || '0.0.0.0' async function main() { - const node = await createLibp2p({ - transports: [ webSockets() ], - connectionEncryption: [ plaintext() ], - streamMuxers: [ yamux() ], - services: { - // installs /ipfs/ping/1.0.0 handler - ping: ping() + console.log(`๐Ÿ”ง Configuration: transport=${TRANSPORT}, security=${SECURITY}, muxer=${MUXER}`) + + // Build options following the proven pattern from test-plans-fork + const options = { + start: true, + connectionGater: { + denyDialMultiaddr: async () => false }, - addresses: { - listen: ['/ip4/0.0.0.0/tcp/0/ws'] + connectionMonitor: { + enabled: false + }, + services: { + ping: ping() } - }) + } + + // Transport configuration (following get-libp2p.ts pattern) + switch (TRANSPORT) { + case 'ws': + options.transports = [webSockets()] + options.addresses = { + listen: [`/ip4/${IP}/tcp/0/ws`] + } + break + case 'wss': + process.env.NODE_TLS_REJECT_UNAUTHORIZED = '0' + options.transports = [webSockets()] + options.addresses = { + listen: [`/ip4/${IP}/tcp/0/wss`] + } + break + default: + throw new Error(`Unknown transport: ${TRANSPORT}`) + } + + // Security configuration + switch (SECURITY) { + case 'noise': + options.connectionEncryption = [noise()] + break + case 'plaintext': + options.connectionEncryption = [plaintext()] + break + default: + throw new Error(`Unknown security: ${SECURITY}`) + } + + // Muxer configuration + switch (MUXER) { + case 'yamux': + options.streamMuxers = [yamux()] + break + default: + throw new Error(`Unknown muxer: ${MUXER}`) + } + + console.log('๐Ÿ”ง Creating libp2p node with proven interop configuration...') + const node = await createLibp2p(options) await node.start() @@ -25,6 +79,39 @@ async function main() { console.log(addr.toString()) } + // Debug: Print supported protocols + console.log('DEBUG: Supported protocols:') + if (node.services && node.services.registrar) { + const protocols = node.services.registrar.getProtocols() + for (const protocol of protocols) { + console.log('DEBUG: Protocol:', protocol) + } + } + + // Debug: Print connection encryption protocols + console.log('DEBUG: Connection encryption protocols:') + try { + if (node.components && node.components.connectionEncryption) { + for (const encrypter of node.components.connectionEncryption) { + console.log('DEBUG: Encrypter:', encrypter.protocol) + } + } + } catch (e) { + console.log('DEBUG: Could not access connectionEncryption:', e.message) + } + + // Debug: Print stream muxer protocols + console.log('DEBUG: Stream muxer protocols:') + try { + if (node.components && node.components.streamMuxers) { + for (const muxer of node.components.streamMuxers) { + console.log('DEBUG: Muxer:', muxer.protocol) + } + } + } catch (e) { + console.log('DEBUG: Could not access streamMuxers:', e.message) + } + // Keep the process alive await new Promise(() => {}) } diff --git a/tests/interop/test_js_ws_ping.py b/tests/interop/test_js_ws_ping.py index 7f0f0660..700caed3 100644 --- a/tests/interop/test_js_ws_ping.py +++ b/tests/interop/test_js_ws_ping.py @@ -9,16 +9,8 @@ from trio.lowlevel import open_process from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.custom_types import TProtocol -from libp2p.host.basic_host import BasicHost from libp2p.network.exceptions import SwarmException -from libp2p.network.swarm import Swarm from libp2p.peer.id import ID -from libp2p.peer.peerinfo import PeerInfo -from libp2p.peer.peerstore import PeerStore -from libp2p.security.insecure.transport import InsecureTransport -from libp2p.stream_muxer.yamux.yamux import Yamux -from libp2p.transport.upgrader import TransportUpgrader -from libp2p.transport.websocket.transport import WebsocketTransport PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0" @@ -97,11 +89,14 @@ async def test_ping_with_js_node(): stderr = proc.stderr try: - # Read first two lines (PeerID and multiaddr) - print("Waiting for JS node to output PeerID and multiaddr...") + # Read JS node output until we get peer ID and multiaddrs + print("Waiting for JS node to output PeerID and multiaddrs...") buffer = b"" + peer_id_found: str | bool = False + multiaddrs_found = [] + with trio.fail_after(30): - while buffer.count(b"\n") < 2: + while True: chunk = await stdout.receive_some(1024) if not chunk: print("No more data from JS node stdout") @@ -109,53 +104,84 @@ async def test_ping_with_js_node(): buffer += chunk print(f"Received chunk: {chunk}") - print(f"Total buffer received: {buffer}") - lines = [line for line in buffer.decode().splitlines() if line.strip()] - print(f"Parsed lines: {lines}") + # Parse lines as we receive them + lines = buffer.decode().splitlines() + for line in lines: + line = line.strip() + if not line: + continue - if len(lines) < 2: - print("Not enough lines from JS node, checking stderr...") + # Look for peer ID (starts with "12D3Koo") + if line.startswith("12D3Koo") and not peer_id_found: + peer_id_found = line + print(f"Found peer ID: {peer_id_found}") + + # Look for multiaddrs (start with "/ip4/" or "/ip6/") + elif line.startswith("/ip4/") or line.startswith("/ip6/"): + if line not in multiaddrs_found: + multiaddrs_found.append(line) + print(f"Found multiaddr: {line}") + + # Stop when we have peer ID and at least one multiaddr + if peer_id_found and multiaddrs_found: + print(f"โœ… Collected: Peer ID + {len(multiaddrs_found)} multiaddrs") + break + + print(f"Total buffer received: {buffer}") + all_lines = [line for line in buffer.decode().splitlines() if line.strip()] + print(f"All JS Node lines: {all_lines}") + + if not peer_id_found or not multiaddrs_found: + print("Missing peer ID or multiaddrs from JS node, checking stderr...") stderr_output = await stderr.receive_some(2048) stderr_output = stderr_output.decode() print(f"JS node stderr: {stderr_output}") pytest.fail( "JS node did not produce expected PeerID and multiaddr.\n" + f"Found peer ID: {peer_id_found}\n" + f"Found multiaddrs: {multiaddrs_found}\n" f"Stdout: {buffer.decode()!r}\n" f"Stderr: {stderr_output!r}" ) - peer_id_line, addr_line = lines[0], lines[1] - peer_id = ID.from_base58(peer_id_line) - maddr = Multiaddr(addr_line) + + # peer_id = ID.from_base58(peer_id_found) # Not used currently + # Use the first localhost multiaddr preferentially, or fallback to first + # available + maddr = None + for addr_str in multiaddrs_found: + if "127.0.0.1" in addr_str: + maddr = Multiaddr(addr_str) + break + if not maddr: + maddr = Multiaddr(multiaddrs_found[0]) # Debug: Print what we're trying to connect to - print(f"JS Node Peer ID: {peer_id_line}") - print(f"JS Node Address: {addr_line}") - print(f"All JS Node lines: {lines}") - print(f"Parsed multiaddr: {maddr}") + print(f"JS Node Peer ID: {peer_id_found}") + print(f"JS Node Address: {maddr}") + print(f"All found multiaddrs: {multiaddrs_found}") + print(f"Selected multiaddr: {maddr}") - # Set up Python host + # Set up Python host using new_host API with Noise security print("Setting up Python host...") - key_pair = create_new_key_pair() - py_peer_id = ID.from_pubkey(key_pair.public_key) - peer_store = PeerStore() - peer_store.add_key_pair(py_peer_id, key_pair) - print(f"Python Peer ID: {py_peer_id}") + from libp2p import create_yamux_muxer_option, new_host - # Use only plaintext security to match the JavaScript node - upgrader = TransportUpgrader( - secure_transports_by_protocol={ - TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) - }, - muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + key_pair = create_new_key_pair() + # noise_key_pair = create_new_x25519_key_pair() # Not used currently + print(f"Python Peer ID: {ID.from_pubkey(key_pair.public_key)}") + + # Use default security options (includes Noise, SecIO, and plaintext) + # This will allow protocol negotiation to choose the best match + host = new_host( + key_pair=key_pair, + muxer_opt=create_yamux_muxer_option(), + listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], ) - transport = WebsocketTransport(upgrader) - print(f"WebSocket transport created: {transport}") - swarm = Swarm(py_peer_id, peer_store, upgrader, transport) - host = BasicHost(swarm) print(f"Python host created: {host}") - # Connect to JS node - peer_info = PeerInfo(peer_id, [maddr]) + # Connect to JS node using modern peer info + from libp2p.peer.peerinfo import info_from_p2p_addr + + peer_info = info_from_p2p_addr(maddr) print(f"Python trying to connect to: {peer_info}") print(f"Peer info addresses: {peer_info.addrs}") @@ -169,37 +195,62 @@ async def test_ping_with_js_node(): try: parsed = parse_websocket_multiaddr(maddr) print( - f"Parsed WebSocket multiaddr: is_wss={parsed.is_wss}, sni={parsed.sni}, rest_multiaddr={parsed.rest_multiaddr}" + f"Parsed WebSocket multiaddr: is_wss={parsed.is_wss}, " + f"sni={parsed.sni}, rest_multiaddr={parsed.rest_multiaddr}" ) except Exception as e: print(f"Failed to parse WebSocket multiaddr: {e}") - await trio.sleep(1) + # Use proper host.run() context manager + async with host.run(listen_addrs=[]): + await trio.sleep(1) - try: - print("Attempting to connect to JS node...") - await host.connect(peer_info) - print("Successfully connected to JS node!") - except SwarmException as e: - underlying_error = e.__cause__ - print(f"Connection failed with SwarmException: {e}") - print(f"Underlying error: {underlying_error}") - pytest.fail( - "Connection failed with SwarmException.\n" - f"THE REAL ERROR IS: {underlying_error!r}\n" - ) + try: + print("Attempting to connect to JS node...") + await host.connect(peer_info) + print("Successfully connected to JS node!") + except SwarmException as e: + underlying_error = e.__cause__ + print(f"Connection failed with SwarmException: {e}") + print(f"Underlying error: {underlying_error}") + pytest.fail( + "Connection failed with SwarmException.\n" + f"THE REAL ERROR IS: {underlying_error!r}\n" + ) - assert host.get_network().connections.get(peer_id) is not None + # Verify connection was established + assert host.get_network().connections.get(peer_info.peer_id) is not None - # Ping protocol - stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")]) - await stream.write(b"ping") - data = await stream.read(4) - assert data == b"pong" + # Try to ping the JS node + ping_protocol = TProtocol("/ipfs/ping/1.0.0") + try: + print("Opening ping stream...") + stream = await host.new_stream(peer_info.peer_id, [ping_protocol]) + print("Ping stream opened successfully!") - print("Closing Python host...") - await host.close() - print("Python host closed successfully") + # Send ping data (32 bytes as per libp2p ping protocol) + ping_data = b"\x00" * 32 + await stream.write(ping_data) + print(f"Sent ping: {len(ping_data)} bytes") + + # Wait for pong response + pong_data = await stream.read(32) + print(f"Received pong: {len(pong_data)} bytes") + + # Verify the pong matches the ping + assert pong_data == ping_data, ( + f"Ping/pong mismatch: {ping_data!r} != {pong_data!r}" + ) + print("โœ… Ping/pong successful!") + + await stream.close() + print("Stream closed successfully!") + + except Exception as e: + print(f"Ping failed: {e}") + pytest.fail(f"Ping failed: {e}") + + print("๐ŸŽ‰ JavaScript WebSocket interop test completed successfully!") finally: print(f"Terminating JS node process (PID: {proc.pid})...") try: