From fe4c17e8d12579a92580a6895c0ca278e8cc76bf Mon Sep 17 00:00:00 2001 From: acul71 Date: Mon, 11 Aug 2025 01:25:49 +0200 Subject: [PATCH] Fix typecheck errors and improve WebSocket transport implementation - Fix INotifee interface compliance in WebSocket demo - Fix handler function signatures to be async (THandler compatibility) - Fix is_closed method usage with proper type checking - Fix pytest.raises multiple exception type issue - Fix line length violations (E501) across multiple files - Add debugging logging to Noise security module for troubleshooting - Update WebSocket transport examples and tests - Improve transport registry error handling --- examples/transport_integration_demo.py | 73 ++-- examples/websocket/test_tcp_echo.py | 54 +-- .../websocket/test_websocket_transport.py | 66 ++-- examples/websocket/websocket_demo.py | 275 +++++++++++---- libp2p/__init__.py | 24 +- libp2p/security/noise/io.py | 14 +- libp2p/security/noise/messages.py | 30 +- libp2p/security/noise/patterns.py | 35 ++ libp2p/transport/__init__.py | 13 +- libp2p/transport/transport_registry.py | 109 +++--- libp2p/transport/websocket/connection.py | 83 ++++- libp2p/transport/websocket/listener.py | 71 ++-- libp2p/transport/websocket/transport.py | 7 +- .../core/transport/test_transport_registry.py | 149 ++++---- tests/core/transport/test_websocket.py | 319 +++++++++--------- tests/interop/test_js_ws_ping.py | 11 +- 16 files changed, 845 insertions(+), 488 deletions(-) rename test_websocket_transport.py => examples/websocket/test_websocket_transport.py (85%) diff --git a/examples/transport_integration_demo.py b/examples/transport_integration_demo.py index a7138e55..424979e9 100644 --- a/examples/transport_integration_demo.py +++ b/examples/transport_integration_demo.py @@ -11,13 +11,14 @@ This script demonstrates: import asyncio import logging -import sys from pathlib import Path +import sys # Add the libp2p directory to the path so we can import it sys.path.insert(0, str(Path(__file__).parent.parent)) import multiaddr + from libp2p.transport import ( create_transport, create_transport_for_multiaddr, @@ -25,9 +26,8 @@ from libp2p.transport import ( get_transport_registry, register_transport, ) -from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.websocket.transport import WebsocketTransport +from libp2p.transport.upgrader import TransportUpgrader # Set up logging logging.basicConfig(level=logging.INFO) @@ -38,20 +38,21 @@ def demo_transport_registry(): """Demonstrate the transport registry functionality.""" print("๐Ÿ”ง Transport Registry Demo") print("=" * 50) - + # Get the global registry registry = get_transport_registry() - + # Show supported protocols supported = get_supported_transport_protocols() print(f"Supported transport protocols: {supported}") - + # Show registered transports print("\nRegistered transports:") for protocol in supported: transport_class = registry.get_transport(protocol) - print(f" {protocol}: {transport_class.__name__}") - + class_name = transport_class.__name__ if transport_class else "None" + print(f" {protocol}: {class_name}") + print() @@ -59,21 +60,21 @@ def demo_transport_factory(): """Demonstrate the transport factory functions.""" print("๐Ÿญ Transport Factory Demo") print("=" * 50) - + # Create a dummy upgrader for WebSocket transport upgrader = TransportUpgrader({}, {}) - + # Create transports using the factory function try: tcp_transport = create_transport("tcp") print(f"โœ… Created TCP transport: {type(tcp_transport).__name__}") - + ws_transport = create_transport("ws", upgrader) print(f"โœ… Created WebSocket transport: {type(ws_transport).__name__}") - + except Exception as e: print(f"โŒ Error creating transport: {e}") - + print() @@ -81,10 +82,10 @@ def demo_multiaddr_transport_selection(): """Demonstrate automatic transport selection based on multiaddrs.""" print("๐ŸŽฏ Multiaddr Transport Selection Demo") print("=" * 50) - + # Create a dummy upgrader upgrader = TransportUpgrader({}, {}) - + # Test different multiaddr types test_addrs = [ "/ip4/127.0.0.1/tcp/8080", @@ -92,20 +93,20 @@ def demo_multiaddr_transport_selection(): "/ip6/::1/tcp/8080/ws", "/dns4/example.com/tcp/443/ws", ] - + for addr_str in test_addrs: try: maddr = multiaddr.Multiaddr(addr_str) transport = create_transport_for_multiaddr(maddr, upgrader) - + if transport: print(f"โœ… {addr_str} -> {type(transport).__name__}") else: print(f"โŒ {addr_str} -> No transport found") - + except Exception as e: print(f"โŒ {addr_str} -> Error: {e}") - + print() @@ -113,34 +114,37 @@ def demo_custom_transport_registration(): """Demonstrate how to register custom transports.""" print("๐Ÿ”ง Custom Transport Registration Demo") print("=" * 50) - - # Create a dummy upgrader - upgrader = TransportUpgrader({}, {}) - + # Show current supported protocols print(f"Before registration: {get_supported_transport_protocols()}") - + # Register a custom transport (using TCP as an example) class CustomTCPTransport(TCP): """Custom TCP transport for demonstration.""" + def __init__(self): super().__init__() self.custom_flag = True - + # Register the custom transport register_transport("custom_tcp", CustomTCPTransport) - + # Show updated supported protocols print(f"After registration: {get_supported_transport_protocols()}") - + # Test creating the custom transport try: custom_transport = create_transport("custom_tcp") print(f"โœ… Created custom transport: {type(custom_transport).__name__}") - print(f" Custom flag: {custom_transport.custom_flag}") + # Check if it has the custom flag (type-safe way) + if hasattr(custom_transport, "custom_flag"): + flag_value = getattr(custom_transport, "custom_flag", "Not found") + print(f" Custom flag: {flag_value}") + else: + print(" Custom flag: Not found") except Exception as e: print(f"โŒ Error creating custom transport: {e}") - + print() @@ -148,7 +152,7 @@ def demo_integration_with_libp2p(): """Demonstrate how the new system integrates with libp2p.""" print("๐Ÿš€ Libp2p Integration Demo") print("=" * 50) - + print("The new transport system integrates seamlessly with libp2p:") print() print("1. โœ… Automatic transport selection based on multiaddr") @@ -157,7 +161,7 @@ def demo_integration_with_libp2p(): print("4. โœ… Easy registration of new transport protocols") print("5. โœ… No changes needed to existing libp2p code") print() - + print("Example usage in libp2p:") print(" # This will automatically use WebSocket transport") print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])") @@ -165,7 +169,7 @@ def demo_integration_with_libp2p(): print(" # This will automatically use TCP transport") print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])") print() - + print() @@ -174,14 +178,14 @@ async def main(): print("๐ŸŽ‰ Py-libp2p Transport Integration Demo") print("=" * 60) print() - + # Run all demos demo_transport_registry() demo_transport_factory() demo_multiaddr_transport_selection() demo_custom_transport_registration() demo_integration_with_libp2p() - + print("๐ŸŽฏ Summary of New Features:") print("=" * 40) print("โœ… Transport Registry: Central registry for all transport implementations") @@ -202,4 +206,5 @@ if __name__ == "__main__": except Exception as e: print(f"\nโŒ Demo failed with error: {e}") import traceback + traceback.print_exc() diff --git a/examples/websocket/test_tcp_echo.py b/examples/websocket/test_tcp_echo.py index b9d4ef09..20728bf6 100644 --- a/examples/websocket/test_tcp_echo.py +++ b/examples/websocket/test_tcp_echo.py @@ -5,7 +5,6 @@ Simple TCP echo demo to verify basic libp2p functionality. import argparse import logging -import sys import traceback import multiaddr @@ -18,10 +17,10 @@ from libp2p.network.swarm import Swarm from libp2p.peer.id import ID from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerstore import PeerStore -from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport from libp2p.stream_muxer.yamux.yamux import Yamux -from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.tcp.tcp import TCP +from libp2p.transport.upgrader import TransportUpgrader # Enable debug logging logging.basicConfig(level=logging.DEBUG) @@ -31,12 +30,13 @@ logger = logging.getLogger("libp2p.tcp-example") # Simple echo protocol ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") + async def echo_handler(stream): """Simple echo handler that echoes back any data received.""" try: data = await stream.read(1024) if data: - message = data.decode('utf-8', errors='replace') + message = data.decode("utf-8", errors="replace") print(f"๐Ÿ“ฅ Received: {message}") print(f"๐Ÿ“ค Echoing back: {message}") await stream.write(data) @@ -45,6 +45,7 @@ async def echo_handler(stream): logger.error(f"Echo handler error: {e}") await stream.close() + def create_tcp_host(): """Create a host with TCP transport.""" # Create key pair and peer store @@ -60,31 +61,35 @@ def create_tcp_host(): }, muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) - + # Create TCP transport transport = TCP() - + # Create swarm and host swarm = Swarm(peer_id, peer_store, upgrader, transport) host = BasicHost(swarm) - + return host + async def run(port: int, destination: str) -> None: localhost_ip = "0.0.0.0" if not destination: # Create first host (listener) with TCP transport listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}") - + try: host = create_tcp_host() logger.debug("Created TCP host") - + # Set up echo handler host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) @@ -95,15 +100,15 @@ async def run(port: int, destination: str) -> None: if not addrs: print("โŒ Error: No addresses found for the host") return - + server_addr = str(addrs[0]) client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") print("๐ŸŒ TCP Server Started Successfully!") print("=" * 50) print(f"๐Ÿ“ Server Address: {client_addr}") - print(f"๐Ÿ”ง Protocol: /echo/1.0.0") - print(f"๐Ÿš€ Transport: TCP") + print("๐Ÿ”ง Protocol: /echo/1.0.0") + print("๐Ÿš€ Transport: TCP") print() print("๐Ÿ“‹ To test the connection, run this in another terminal:") print(f" python test_tcp_echo.py -d {client_addr}") @@ -112,7 +117,7 @@ async def run(port: int, destination: str) -> None: print("โ”€" * 50) await trio.sleep_forever() - + except Exception as e: print(f"โŒ Error creating TCP server: {e}") traceback.print_exc() @@ -121,13 +126,16 @@ async def run(port: int, destination: str) -> None: else: # Create second host (dialer) with TCP transport listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}") - + try: # Create a single host for client operations host = create_tcp_host() - + # Start the host for client operations - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) maddr = multiaddr.Multiaddr(destination) @@ -144,7 +152,7 @@ async def run(port: int, destination: str) -> None: print("โœ… Successfully connected to TCP server!") except Exception as e: error_msg = str(e) - print(f"\nโŒ Connection Failed!") + print("\nโŒ Connection Failed!") print(f" Peer ID: {info.peer_id}") print(f" Address: {destination}") print(f" Error: {error_msg}") @@ -185,24 +193,28 @@ async def run(port: int, destination: str) -> None: traceback.print_exc() print("โœ… TCP demo completed successfully!") - + except Exception as e: print(f"โŒ Error creating TCP client: {e}") traceback.print_exc() return + def main() -> None: description = "Simple TCP echo demo for libp2p" parser = argparse.ArgumentParser(description=description) parser.add_argument("-p", "--port", default=0, type=int, help="source port number") - parser.add_argument("-d", "--destination", type=str, help="destination multiaddr string") + parser.add_argument( + "-d", "--destination", type=str, help="destination multiaddr string" + ) args = parser.parse_args() - + try: trio.run(run, args.port, args.destination) except KeyboardInterrupt: pass + if __name__ == "__main__": main() diff --git a/test_websocket_transport.py b/examples/websocket/test_websocket_transport.py similarity index 85% rename from test_websocket_transport.py rename to examples/websocket/test_websocket_transport.py index b0bca17e..86353ef9 100644 --- a/test_websocket_transport.py +++ b/examples/websocket/test_websocket_transport.py @@ -5,16 +5,16 @@ Simple test script to verify WebSocket transport functionality. import asyncio import logging -import sys from pathlib import Path +import sys # Add the libp2p directory to the path so we can import it sys.path.insert(0, str(Path(__file__).parent)) import multiaddr + from libp2p.transport import create_transport, create_transport_for_multiaddr from libp2p.transport.upgrader import TransportUpgrader -from libp2p.network.connection.raw_connection import RawConnection # Set up logging logging.basicConfig(level=logging.DEBUG) @@ -25,48 +25,57 @@ async def test_websocket_transport(): """Test basic WebSocket transport functionality.""" print("๐Ÿงช Testing WebSocket Transport Functionality") print("=" * 50) - + # Create a dummy upgrader upgrader = TransportUpgrader({}, {}) - + # Test creating WebSocket transport try: ws_transport = create_transport("ws", upgrader) print(f"โœ… WebSocket transport created: {type(ws_transport).__name__}") - + # Test creating transport from multiaddr ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader) - print(f"โœ… WebSocket transport from multiaddr: {type(ws_transport_from_maddr).__name__}") - + print( + f"โœ… WebSocket transport from multiaddr: " + f"{type(ws_transport_from_maddr).__name__}" + ) + # Test creating listener handler_called = False - + async def test_handler(conn): nonlocal handler_called handler_called = True print(f"โœ… Connection handler called with: {type(conn).__name__}") await conn.close() - + listener = ws_transport.create_listener(test_handler) print(f"โœ… WebSocket listener created: {type(listener).__name__}") - + # Test that the transport can be used - print(f"โœ… WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}") - print(f"โœ… WebSocket transport supports listening: {hasattr(ws_transport, 'create_listener')}") - + print( + f"โœ… WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}" + ) + print( + f"โœ… WebSocket transport supports listening: " + f"{hasattr(ws_transport, 'create_listener')}" + ) + print("\n๐ŸŽฏ WebSocket Transport Test Results:") print("โœ… Transport creation: PASS") print("โœ… Multiaddr parsing: PASS") print("โœ… Listener creation: PASS") print("โœ… Interface compliance: PASS") - + except Exception as e: print(f"โŒ WebSocket transport test failed: {e}") import traceback + traceback.print_exc() return False - + return True @@ -74,22 +83,26 @@ async def test_transport_registry(): """Test the transport registry functionality.""" print("\n๐Ÿ”ง Testing Transport Registry") print("=" * 30) - - from libp2p.transport import get_transport_registry, get_supported_transport_protocols - + + from libp2p.transport import ( + get_supported_transport_protocols, + get_transport_registry, + ) + registry = get_transport_registry() supported = get_supported_transport_protocols() - + print(f"Supported protocols: {supported}") - + # Test getting transports for protocol in supported: transport_class = registry.get_transport(protocol) - print(f" {protocol}: {transport_class.__name__}") - + class_name = transport_class.__name__ if transport_class else "None" + print(f" {protocol}: {class_name}") + # Test creating transports through registry upgrader = TransportUpgrader({}, {}) - + for protocol in supported: try: transport = registry.create_transport(protocol, upgrader) @@ -106,17 +119,17 @@ async def main(): print("๐Ÿš€ WebSocket Transport Integration Test Suite") print("=" * 60) print() - + # Run tests success = await test_websocket_transport() await test_transport_registry() - + print("\n" + "=" * 60) if success: print("๐ŸŽ‰ All tests passed! WebSocket transport is working correctly.") else: print("โŒ Some tests failed. Check the output above for details.") - + print("\n๐Ÿš€ WebSocket transport is ready for use in py-libp2p!") @@ -128,4 +141,5 @@ if __name__ == "__main__": except Exception as e: print(f"\nโŒ Test failed with error: {e}") import traceback + traceback.print_exc() diff --git a/examples/websocket/websocket_demo.py b/examples/websocket/websocket_demo.py index 2e2e0477..bd13a881 100644 --- a/examples/websocket/websocket_demo.py +++ b/examples/websocket/websocket_demo.py @@ -1,21 +1,26 @@ import argparse import logging +import signal import sys import traceback import multiaddr import trio +from libp2p.abc import INotifee +from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.custom_types import TProtocol from libp2p.host.basic_host import BasicHost from libp2p.network.swarm import Swarm from libp2p.peer.id import ID -from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr +from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerstore import PeerStore -from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID -from libp2p.security.noise.transport import Transport as NoiseTransport -from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID +from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport +from libp2p.security.noise.transport import ( + PROTOCOL_ID as NOISE_PROTOCOL_ID, + Transport as NoiseTransport, +) from libp2p.stream_muxer.yamux.yamux import Yamux from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.websocket.transport import WebsocketTransport @@ -25,6 +30,15 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("libp2p.websocket-example") + +# Suppress KeyboardInterrupt by handling SIGINT directly +def signal_handler(signum, frame): + print("โœ… Clean exit completed.") + sys.exit(0) + + +signal.signal(signal.SIGINT, signal_handler) + # Simple echo protocol ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") @@ -34,7 +48,7 @@ async def echo_handler(stream): try: data = await stream.read(1024) if data: - message = data.decode('utf-8', errors='replace') + message = data.decode("utf-8", errors="replace") print(f"๐Ÿ“ฅ Received: {message}") print(f"๐Ÿ“ค Echoing back: {message}") await stream.write(data) @@ -44,7 +58,7 @@ async def echo_handler(stream): await stream.close() -def create_websocket_host(listen_addrs=None, use_noise=False): +def create_websocket_host(listen_addrs=None, use_plaintext=False): """Create a host with WebSocket transport.""" # Create key pair and peer store key_pair = create_new_key_pair() @@ -52,11 +66,22 @@ def create_websocket_host(listen_addrs=None, use_noise=False): peer_store = PeerStore() peer_store.add_key_pair(peer_id, key_pair) - if use_noise: + if use_plaintext: + # Create transport upgrader with plaintext security + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, + ) + else: + # Create separate Ed25519 key for Noise protocol + noise_key_pair = create_ed25519_key_pair() + # Create Noise transport noise_transport = NoiseTransport( libp2p_keypair=key_pair, - noise_privkey=key_pair.private_key, + noise_privkey=noise_key_pair.private_key, early_data=None, with_noise_pipes=False, ) @@ -68,43 +93,85 @@ def create_websocket_host(listen_addrs=None, use_noise=False): }, muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) - else: - # Create transport upgrader with plaintext security - upgrader = TransportUpgrader( - secure_transports_by_protocol={ - TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) - }, - muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, - ) - + # Create WebSocket transport transport = WebsocketTransport(upgrader) - + # Create swarm and host swarm = Swarm(peer_id, peer_store, upgrader, transport) host = BasicHost(swarm) - + return host -async def run(port: int, destination: str, use_noise: bool = False) -> None: +async def run(port: int, destination: str, use_plaintext: bool = False) -> None: localhost_ip = "0.0.0.0" if not destination: # Create first host (listener) with WebSocket transport listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws") - + try: - host = create_websocket_host(use_noise=use_noise) - logger.debug(f"Created host with use_noise={use_noise}") - + host = create_websocket_host(use_plaintext=use_plaintext) + logger.debug(f"Created host with use_plaintext={use_plaintext}") + # Set up echo handler host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler) - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + # Add connection event handlers for debugging + class DebugNotifee(INotifee): + async def opened_stream(self, network, stream): + pass + + async def closed_stream(self, network, stream): + pass + + async def connected(self, network, conn): + print( + f"๐Ÿ”— New libp2p connection established: " + f"{conn.muxed_conn.peer_id}" + ) + if hasattr(conn.muxed_conn, "get_security_protocol"): + security = conn.muxed_conn.get_security_protocol() + else: + security = "Unknown" + + print(f" Security: {security}") + + async def disconnected(self, network, conn): + print(f"๐Ÿ”Œ libp2p connection closed: {conn.muxed_conn.peer_id}") + + async def listen(self, network, multiaddr): + pass + + async def listen_close(self, network, multiaddr): + pass + + host.get_network().register_notifee(DebugNotifee()) + + # Create a cancellation token for clean shutdown + cancel_scope = trio.CancelScope() + + async def signal_handler(): + with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as ( + signal_receiver + ): + async for sig in signal_receiver: + print(f"\n๐Ÿ›‘ Received signal {sig}") + print("โœ… Shutting down WebSocket server...") + cancel_scope.cancel() + return + + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + # Start the signal handler + nursery.start_soon(signal_handler) + # Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client # connections addrs = host.get_addrs() @@ -113,18 +180,19 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: print("โŒ Error: No addresses found for the host") print("Debug: host.get_addrs() returned empty list") return - + server_addr = str(addrs[0]) client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/") print("๐ŸŒ WebSocket Server Started Successfully!") print("=" * 50) print(f"๐Ÿ“ Server Address: {client_addr}") - print(f"๐Ÿ”ง Protocol: /echo/1.0.0") - print(f"๐Ÿš€ Transport: WebSocket (/ws)") + print("๐Ÿ”ง Protocol: /echo/1.0.0") + print("๐Ÿš€ Transport: WebSocket (/ws)") print() print("๐Ÿ“‹ To test the connection, run this in another terminal:") - print(f" python websocket_demo.py -d {client_addr}") + plaintext_flag = " --plaintext" if use_plaintext else "" + print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}") print() print("โณ Waiting for incoming WebSocket connections...") print("โ”€" * 50) @@ -132,32 +200,34 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: # Add a custom handler to show connection events async def custom_echo_handler(stream): peer_id = stream.muxed_conn.peer_id - print(f"\n๐Ÿ”— New WebSocket Connection!") + print("\n๐Ÿ”— New WebSocket Connection!") print(f" Peer ID: {peer_id}") - print(f" Protocol: /echo/1.0.0") - + print(" Protocol: /echo/1.0.0") + # Show remote address in multiaddr format try: remote_address = stream.get_remote_address() if remote_address: print(f" Remote: {remote_address}") except Exception: - print(f" Remote: Unknown") - - print(f" โ”€" * 40) + print(" Remote: Unknown") + + print(" โ”€" * 40) # Call the original handler await echo_handler(stream) - print(f" โ”€" * 40) + print(" โ”€" * 40) print(f"โœ… Echo request completed for peer: {peer_id}") print() # Replace the handler with our custom one host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler) - await trio.sleep_forever() - + # Wait indefinitely or until cancelled + with cancel_scope: + await trio.sleep_forever() + except Exception as e: print(f"โŒ Error creating WebSocket server: {e}") traceback.print_exc() @@ -166,15 +236,47 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: else: # Create second host (dialer) with WebSocket transport listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws") - + try: # Create a single host for client operations - host = create_websocket_host(use_noise=use_noise) - + host = create_websocket_host(use_plaintext=use_plaintext) + # Start the host for client operations - async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: + async with ( + host.run(listen_addrs=[listen_addr]), + trio.open_nursery() as (nursery), + ): # Start the peer-store cleanup task nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) + + # Add connection event handlers for debugging + class ClientDebugNotifee(INotifee): + async def opened_stream(self, network, stream): + pass + + async def closed_stream(self, network, stream): + pass + + async def connected(self, network, conn): + print( + f"๐Ÿ”— Client: libp2p connection established: " + f"{conn.muxed_conn.peer_id}" + ) + + async def disconnected(self, network, conn): + print( + f"๐Ÿ”Œ Client: libp2p connection closed: " + f"{conn.muxed_conn.peer_id}" + ) + + async def listen(self, network, multiaddr): + pass + + async def listen_close(self, network, multiaddr): + pass + + host.get_network().register_notifee(ClientDebugNotifee()) + maddr = multiaddr.Multiaddr(destination) info = info_from_p2p_addr(maddr) print("๐Ÿ”Œ WebSocket Client Starting...") @@ -185,21 +287,34 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: try: print("๐Ÿ”— Connecting to WebSocket server...") + print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}") await host.connect(info) print("โœ… Successfully connected to WebSocket server!") except Exception as e: error_msg = str(e) - if "unable to connect" in error_msg or "SwarmException" in error_msg: - print(f"\nโŒ Connection Failed!") - print(f" Peer ID: {info.peer_id}") - print(f" Address: {destination}") - print(f" Error: {error_msg}") - print() - print("๐Ÿ’ก Troubleshooting:") - print(" โ€ข Make sure the WebSocket server is running") - print(" โ€ข Check that the server address is correct") - print(" โ€ข Verify the server is listening on the right port") - return + print("\nโŒ Connection Failed!") + print(f" Peer ID: {info.peer_id}") + print(f" Address: {destination}") + print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}") + print(f" Error: {error_msg}") + print(f" Error type: {type(e).__name__}") + + # Add more detailed error information for debugging + if hasattr(e, "__cause__") and e.__cause__: + print(f" Root cause: {e.__cause__}") + print(f" Root cause type: {type(e.__cause__).__name__}") + + print() + print("๐Ÿ’ก Troubleshooting:") + print(" โ€ข Make sure the WebSocket server is running") + print(" โ€ข Check that the server address is correct") + print(" โ€ข Verify the server is listening on the right port") + print( + " โ€ข Ensure both client and server use the same sec protocol" + ) + if not use_plaintext: + print(" โ€ข Noise over WebSocket may have compatibility issues") + return # Create a stream and send test data try: @@ -242,8 +357,18 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: finally: # Ensure stream is closed try: - if stream and not await stream.is_closed(): - await stream.close() + if stream: + # Check if stream has is_closed method and use it + has_is_closed = hasattr(stream, "is_closed") and callable( + getattr(stream, "is_closed") + ) + if has_is_closed: + # type: ignore[attr-defined] + if not await stream.is_closed(): + await stream.close() + else: + # Fallback: just try to close the stream + await stream.close() except Exception: pass @@ -256,7 +381,10 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: print("โœ… libp2p integration verified!") print() print("๐Ÿš€ Your WebSocket transport is ready for production use!") - + + # Add a small delay to ensure all cleanup is complete + await trio.sleep(0.1) + except Exception as e: print(f"โŒ Error creating WebSocket client: {e}") traceback.print_exc() @@ -266,12 +394,15 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None: def main() -> None: description = """ This program demonstrates the libp2p WebSocket transport. - First run 'python websocket_demo.py -p [--noise]' to start a WebSocket server. - Then run 'python websocket_demo.py -d [--noise]' + First run + 'python websocket_demo.py -p [--plaintext]' to start a WebSocket server. + Then run + 'python websocket_demo.py -d [--plaintext]' where is the multiaddress shown by the server. - By default, this example uses plaintext security for communication. - Use --noise for testing with Noise encryption (experimental). + By default, this example uses Noise encryption for secure communication. + Use --plaintext for testing with unencrypted communication + (not recommended for production). """ example_maddr = ( @@ -287,20 +418,30 @@ def main() -> None: help=f"destination multiaddr string, e.g. {example_maddr}", ) parser.add_argument( - "--noise", + "--plaintext", action="store_true", - help="use Noise encryption instead of plaintext security", + help=( + "use plaintext security instead of Noise encryption " + "(not recommended for production)" + ), ) args = parser.parse_args() - # Determine security mode: use plaintext by default, Noise if --noise is specified - use_noise = args.noise - + # Determine security mode: use Noise by default, + # plaintext if --plaintext is specified + use_plaintext = args.plaintext + try: - trio.run(run, args.port, args.destination, use_noise) + trio.run(run, args.port, args.destination, use_plaintext) except KeyboardInterrupt: - pass + # This is expected when Ctrl+C is pressed + # The signal handler already printed the shutdown message + print("โœ… Clean exit completed.") + return + except Exception as e: + print(f"โŒ Unexpected error: {e}") + return if __name__ == "__main__": diff --git a/libp2p/__init__.py b/libp2p/__init__.py index d9c24960..91d60ae5 100644 --- a/libp2p/__init__.py +++ b/libp2p/__init__.py @@ -19,6 +19,7 @@ from libp2p.abc import ( IPeerRouting, IPeerStore, ISecureTransport, + ITransport, ) from libp2p.crypto.keys import ( KeyPair, @@ -231,14 +232,15 @@ def new_swarm( ) # Create transport based on listen_addrs or default to TCP + transport: ITransport if listen_addrs is None: transport = TCP() else: # Use the first address to determine transport type addr = listen_addrs[0] - transport = create_transport_for_multiaddr(addr, upgrader) - - if transport is None: + transport_maybe = create_transport_for_multiaddr(addr, upgrader) + + if transport_maybe is None: # Fallback to TCP if no specific transport found if addr.__contains__("tcp"): transport = TCP() @@ -250,20 +252,8 @@ def new_swarm( f"Unknown transport in listen_addrs: {listen_addrs}. " f"Supported protocols: {supported_protocols}" ) - - # Generate X25519 keypair for Noise - noise_key_pair = create_new_x25519_key_pair() - - # Default security transports (using Noise as primary) - secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or { - NOISE_PROTOCOL_ID: NoiseTransport( - key_pair, noise_privkey=noise_key_pair.private_key - ), - TProtocol(secio.ID): secio.Transport(key_pair), - TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport( - key_pair, peerstore=peerstore_opt - ), - } + else: + transport = transport_maybe # Use given muxer preference if provided, otherwise use global default if muxer_preference is not None: diff --git a/libp2p/security/noise/io.py b/libp2p/security/noise/io.py index a24b6c74..18fbbcd5 100644 --- a/libp2p/security/noise/io.py +++ b/libp2p/security/noise/io.py @@ -1,3 +1,4 @@ +import logging from typing import ( cast, ) @@ -15,6 +16,8 @@ from libp2p.io.msgio import ( FixedSizeLenMsgReadWriter, ) +logger = logging.getLogger(__name__) + SIZE_NOISE_MESSAGE_LEN = 2 MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 SIZE_NOISE_MESSAGE_BODY_LEN = 2 @@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter): self.noise_state = noise_state async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None: + logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes") data_encrypted = self.encrypt(msg) if prefix_encoded: # Manually add the prefix if needed data_encrypted = self.prefix + data_encrypted + logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes") await self.read_writer.write_msg(data_encrypted) + logger.debug("Noise write_msg: write completed successfully") async def read_msg(self, prefix_encoded: bool = False) -> bytes: + logger.debug("Noise read_msg: reading encrypted message") noise_msg_encrypted = await self.read_writer.read_msg() + logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes") if prefix_encoded: - return self.decrypt(noise_msg_encrypted[len(self.prefix) :]) + result = self.decrypt(noise_msg_encrypted[len(self.prefix) :]) else: - return self.decrypt(noise_msg_encrypted) + result = self.decrypt(noise_msg_encrypted) + logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes") + return result async def close(self) -> None: await self.read_writer.close() diff --git a/libp2p/security/noise/messages.py b/libp2p/security/noise/messages.py index 309b24b0..f7e2dceb 100644 --- a/libp2p/security/noise/messages.py +++ b/libp2p/security/noise/messages.py @@ -1,6 +1,7 @@ from dataclasses import ( dataclass, ) +import logging from libp2p.crypto.keys import ( PrivateKey, @@ -12,6 +13,8 @@ from libp2p.crypto.serialization import ( from .pb import noise_pb2 as noise_pb +logger = logging.getLogger(__name__) + SIGNED_DATA_PREFIX = "noise-libp2p-static-key:" @@ -48,6 +51,8 @@ def make_handshake_payload_sig( id_privkey: PrivateKey, noise_static_pubkey: PublicKey ) -> bytes: data = make_data_to_be_signed(noise_static_pubkey) + logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}") + logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}") return id_privkey.sign(data) @@ -60,4 +65,27 @@ def verify_handshake_payload_sig( 2. signed by the private key corresponding to `id_pubkey` """ expected_data = make_data_to_be_signed(noise_static_pubkey) - return payload.id_pubkey.verify(expected_data, payload.id_sig) + logger.debug( + f"verify_handshake_payload_sig: payload.id_pubkey type: " + f"{type(payload.id_pubkey)}" + ) + logger.debug( + f"verify_handshake_payload_sig: noise_static_pubkey type: " + f"{type(noise_static_pubkey)}" + ) + logger.debug( + f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}" + ) + logger.debug( + f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}" + ) + logger.debug( + f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}" + ) + try: + result = payload.id_pubkey.verify(expected_data, payload.id_sig) + logger.debug(f"verify_handshake_payload_sig: verification result: {result}") + return result + except Exception as e: + logger.error(f"verify_handshake_payload_sig: verification exception: {e}") + return False diff --git a/libp2p/security/noise/patterns.py b/libp2p/security/noise/patterns.py index 00f51d06..d51332a4 100644 --- a/libp2p/security/noise/patterns.py +++ b/libp2p/security/noise/patterns.py @@ -2,6 +2,7 @@ from abc import ( ABC, abstractmethod, ) +import logging from cryptography.hazmat.primitives import ( serialization, @@ -46,6 +47,8 @@ from .messages import ( verify_handshake_payload_sig, ) +logger = logging.getLogger(__name__) + class IPattern(ABC): @abstractmethod @@ -95,6 +98,7 @@ class PatternXX(BasePattern): self.early_data = early_data async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: + logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}") noise_state = self.create_noise_state() noise_state.set_as_responder() noise_state.start_handshake() @@ -107,15 +111,22 @@ class PatternXX(BasePattern): read_writer = NoiseHandshakeReadWriter(conn, noise_state) # Consume msg#1. + logger.debug("Noise XX handshake_inbound: reading msg#1") await read_writer.read_msg() + logger.debug("Noise XX handshake_inbound: read msg#1 successfully") # Send msg#2, which should include our handshake payload. + logger.debug("Noise XX handshake_inbound: preparing msg#2") our_payload = self.make_handshake_payload() msg_2 = our_payload.serialize() + logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)") await read_writer.write_msg(msg_2) + logger.debug("Noise XX handshake_inbound: sent msg#2 successfully") # Receive and consume msg#3. + logger.debug("Noise XX handshake_inbound: reading msg#3") msg_3 = await read_writer.read_msg() + logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)") peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3) if handshake_state.rs is None: @@ -147,6 +158,7 @@ class PatternXX(BasePattern): async def handshake_outbound( self, conn: IRawConnection, remote_peer: ID ) -> ISecureConn: + logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}") noise_state = self.create_noise_state() read_writer = NoiseHandshakeReadWriter(conn, noise_state) @@ -159,11 +171,15 @@ class PatternXX(BasePattern): raise NoiseStateError("Handshake state is not initialized") # Send msg#1, which is *not* encrypted. + logger.debug("Noise XX handshake_outbound: sending msg#1") msg_1 = b"" await read_writer.write_msg(msg_1) + logger.debug("Noise XX handshake_outbound: sent msg#1 successfully") # Read msg#2 from the remote, which contains the public key of the peer. + logger.debug("Noise XX handshake_outbound: reading msg#2") msg_2 = await read_writer.read_msg() + logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)") peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2) if handshake_state.rs is None: @@ -174,8 +190,27 @@ class PatternXX(BasePattern): ) remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs) + logger.debug( + f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}" + ) + logger.debug( + f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}" + ) + id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex() + logger.debug( + f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: " + f"{id_pubkey_repr}" + ) if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey): + logger.error( + f"Noise XX handshake_outbound: signature verification failed for peer " + f"{remote_peer}" + ) raise InvalidSignature + logger.debug( + f"Noise XX handshake_outbound: signature verification successful for peer " + f"{remote_peer}" + ) remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey) if remote_peer_id_from_pubkey != remote_peer: raise PeerIDMismatchesPubkey( diff --git a/libp2p/transport/__init__.py b/libp2p/transport/__init__.py index aa58d051..67ea6a74 100644 --- a/libp2p/transport/__init__.py +++ b/libp2p/transport/__init__.py @@ -1,17 +1,19 @@ from .tcp.tcp import TCP from .websocket.transport import WebsocketTransport from .transport_registry import ( - TransportRegistry, + TransportRegistry, create_transport_for_multiaddr, get_transport_registry, register_transport, get_supported_transport_protocols, ) +from .upgrader import TransportUpgrader +from libp2p.abc import ITransport -def create_transport(protocol: str, upgrader=None): +def create_transport(protocol: str, upgrader: TransportUpgrader | None = None) -> ITransport: """ Convenience function to create a transport instance. - + :param protocol: The transport protocol ("tcp", "ws", or custom) :param upgrader: Optional transport upgrader (required for WebSocket) :return: Transport instance @@ -28,7 +30,10 @@ def create_transport(protocol: str, upgrader=None): registry = get_transport_registry() transport_class = registry.get_transport(protocol) if transport_class: - return registry.create_transport(protocol, upgrader) + transport = registry.create_transport(protocol, upgrader) + if transport is None: + raise ValueError(f"Failed to create transport for protocol: {protocol}") + return transport else: raise ValueError(f"Unsupported transport protocol: {protocol}") diff --git a/libp2p/transport/transport_registry.py b/libp2p/transport/transport_registry.py index ffa2a8fa..a6228d4e 100644 --- a/libp2p/transport/transport_registry.py +++ b/libp2p/transport/transport_registry.py @@ -3,13 +3,15 @@ Transport registry for dynamic transport selection based on multiaddr protocols. """ import logging -from typing import Dict, Type, Optional +from typing import Any + from multiaddr import Multiaddr +from multiaddr.protocols import Protocol from libp2p.abc import ITransport from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.websocket.transport import WebsocketTransport from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport logger = logging.getLogger("libp2p.transport.registry") @@ -17,28 +19,29 @@ logger = logging.getLogger("libp2p.transport.registry") def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool: """ Validate that a multiaddr has a valid TCP structure. - + :param maddr: The multiaddr to validate :return: True if valid TCP structure, False otherwise """ try: # TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080 # or /ip6/::1/tcp/8080 - protocols = maddr.protocols() - + protocols: list[Protocol] = list(maddr.protocols()) + # Must have at least 2 protocols: network (ip4/ip6) + tcp if len(protocols) < 2: return False - + # First protocol should be a network protocol (ip4, ip6, dns4, dns6) if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]: return False - + # Second protocol should be tcp if protocols[1].name != "tcp": return False - - # Should not have any protocols after tcp (unless it's a valid continuation like p2p) + + # Should not have any protocols after tcp (unless it's a valid + # continuation like p2p) # For now, we'll be strict and only allow network + tcp if len(protocols) > 2: # Check if the additional protocols are valid continuations @@ -46,9 +49,9 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool: for i in range(2, len(protocols)): if protocols[i].name not in valid_continuations: return False - + return True - + except Exception: return False @@ -56,31 +59,31 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool: def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool: """ Validate that a multiaddr has a valid WebSocket structure. - + :param maddr: The multiaddr to validate :return: True if valid WebSocket structure, False otherwise """ try: # WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws # or /ip6/::1/tcp/8080/ws - protocols = maddr.protocols() - + protocols: list[Protocol] = list(maddr.protocols()) + # Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws if len(protocols) < 3: return False - + # First protocol should be a network protocol (ip4, ip6, dns4, dns6) if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]: return False - + # Second protocol should be tcp if protocols[1].name != "tcp": return False - + # Last protocol should be ws if protocols[-1].name != "ws": return False - + # Should not have any protocols between tcp and ws if len(protocols) > 3: # Check if the additional protocols are valid continuations @@ -88,9 +91,9 @@ def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool: for i in range(2, len(protocols) - 1): if protocols[i].name not in valid_continuations: return False - + return True - + except Exception: return False @@ -99,46 +102,52 @@ class TransportRegistry: """ Registry for mapping multiaddr protocols to transport implementations. """ - - def __init__(self): - self._transports: Dict[str, Type[ITransport]] = {} + + def __init__(self) -> None: + self._transports: dict[str, type[ITransport]] = {} self._register_default_transports() - + def _register_default_transports(self) -> None: """Register the default transport implementations.""" # Register TCP transport for /tcp protocol self.register_transport("tcp", TCP) - + # Register WebSocket transport for /ws protocol self.register_transport("ws", WebsocketTransport) - - def register_transport(self, protocol: str, transport_class: Type[ITransport]) -> None: + + def register_transport( + self, protocol: str, transport_class: type[ITransport] + ) -> None: """ Register a transport class for a specific protocol. - + :param protocol: The protocol identifier (e.g., "tcp", "ws") :param transport_class: The transport class to register """ self._transports[protocol] = transport_class - logger.debug(f"Registered transport {transport_class.__name__} for protocol {protocol}") - - def get_transport(self, protocol: str) -> Optional[Type[ITransport]]: + logger.debug( + f"Registered transport {transport_class.__name__} for protocol {protocol}" + ) + + def get_transport(self, protocol: str) -> type[ITransport] | None: """ Get the transport class for a specific protocol. - + :param protocol: The protocol identifier :return: The transport class or None if not found """ return self._transports.get(protocol) - + def get_supported_protocols(self) -> list[str]: """Get list of supported transport protocols.""" return list(self._transports.keys()) - - def create_transport(self, protocol: str, upgrader: Optional[TransportUpgrader] = None, **kwargs) -> Optional[ITransport]: + + def create_transport( + self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any + ) -> ITransport | None: """ Create a transport instance for a specific protocol. - + :param protocol: The protocol identifier :param upgrader: The transport upgrader instance (required for WebSocket) :param kwargs: Additional arguments for transport construction @@ -147,14 +156,17 @@ class TransportRegistry: transport_class = self.get_transport(protocol) if transport_class is None: return None - + try: if protocol == "ws": # WebSocket transport requires upgrader if upgrader is None: - logger.warning(f"WebSocket transport '{protocol}' requires upgrader") + logger.warning( + f"WebSocket transport '{protocol}' requires upgrader" + ) return None - return transport_class(upgrader) + # Use explicit WebsocketTransport to avoid type issues + return WebsocketTransport(upgrader) else: # TCP transport doesn't require upgrader return transport_class() @@ -172,15 +184,17 @@ def get_transport_registry() -> TransportRegistry: return _global_registry -def register_transport(protocol: str, transport_class: Type[ITransport]) -> None: +def register_transport(protocol: str, transport_class: type[ITransport]) -> None: """Register a transport class in the global registry.""" _global_registry.register_transport(protocol, transport_class) -def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader) -> Optional[ITransport]: +def create_transport_for_multiaddr( + maddr: Multiaddr, upgrader: TransportUpgrader +) -> ITransport | None: """ Create the appropriate transport for a given multiaddr. - + :param maddr: The multiaddr to create transport for :param upgrader: The transport upgrader instance :return: Transport instance or None if no suitable transport found @@ -188,7 +202,7 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader try: # Get all protocols in the multiaddr protocols = [proto.name for proto in maddr.protocols()] - + # Check for supported transport protocols in order of preference # We need to validate that the multiaddr structure is valid for our transports if "ws" in protocols: @@ -201,11 +215,14 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader # Check if the multiaddr has proper TCP structure if _is_valid_tcp_multiaddr(maddr): return _global_registry.create_transport("tcp", upgrader) - + # If no supported transport protocol found or structure is invalid, return None - logger.warning(f"No supported transport protocol found or invalid structure in multiaddr: {maddr}") + logger.warning( + f"No supported transport protocol found or invalid structure in " + f"multiaddr: {maddr}" + ) return None - + except Exception as e: # Handle any errors gracefully (e.g., invalid multiaddr) logger.warning(f"Error processing multiaddr {maddr}: {e}") diff --git a/libp2p/transport/websocket/connection.py b/libp2p/transport/websocket/connection.py index 7188ae8c..3051339d 100644 --- a/libp2p/transport/websocket/connection.py +++ b/libp2p/transport/websocket/connection.py @@ -1,9 +1,13 @@ -from trio.abc import Stream +import logging +from typing import Any + import trio from libp2p.io.abc import ReadWriteCloser from libp2p.io.exceptions import IOException +logger = logging.getLogger(__name__) + class P2PWebSocketConnection(ReadWriteCloser): """ @@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser): that libp2p protocols expect. """ - def __init__(self, ws_connection, ws_context=None): + def __init__(self, ws_connection: Any, ws_context: Any = None) -> None: self._ws_connection = ws_connection self._ws_context = ws_context self._read_buffer = b"" @@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser): async def write(self, data: bytes) -> None: try: + logger.debug(f"WebSocket writing {len(data)} bytes") # Send as a binary WebSocket message await self._ws_connection.send_message(data) + logger.debug(f"WebSocket wrote {len(data)} bytes successfully") except Exception as e: + logger.error(f"WebSocket write failed: {e}") raise IOException from e async def read(self, n: int | None = None) -> bytes: """ Read up to n bytes (if n is given), else read up to 64KiB. + This implementation provides byte-level access to WebSocket messages, + which is required for Noise protocol handshake. """ async with self._read_lock: try: + logger.debug( + f"WebSocket read requested: n={n}, " + f"buffer_size={len(self._read_buffer)}" + ) + # If we have buffered data, return it if self._read_buffer: if n is None: result = self._read_buffer self._read_buffer = b"" + logger.debug( + f"WebSocket read returning all buffered data: " + f"{len(result)} bytes" + ) return result else: if len(self._read_buffer) >= n: result = self._read_buffer[:n] self._read_buffer = self._read_buffer[n:] + logger.debug( + f"WebSocket read returning {len(result)} bytes " + f"from buffer" + ) return result else: - result = self._read_buffer - self._read_buffer = b"" - return result + # We need more data, but we have some buffered + # Keep the buffered data and get more + logger.debug( + f"WebSocket read needs more data: have " + f"{len(self._read_buffer)}, need {n}" + ) + pass + + # If we need exactly n bytes but don't have enough, get more data + while n is not None and ( + not self._read_buffer or len(self._read_buffer) < n + ): + logger.debug( + f"WebSocket read getting more data: " + f"buffer_size={len(self._read_buffer)}, need={n}" + ) + # Get the next WebSocket message and treat it as a byte stream + # This mimics the Go implementation's NextReader() approach + message = await self._ws_connection.get_message() + if isinstance(message, str): + message = message.encode("utf-8") + + logger.debug( + f"WebSocket read received message: {len(message)} bytes" + ) + # Add to buffer + self._read_buffer += message - # Get the next WebSocket message - message = await self._ws_connection.get_message() - if isinstance(message, str): - message = message.encode('utf-8') - - # Add to buffer - self._read_buffer = message - # Return requested amount if n is None: result = self._read_buffer self._read_buffer = b"" + logger.debug( + f"WebSocket read returning all data: {len(result)} bytes" + ) return result else: if len(self._read_buffer) >= n: result = self._read_buffer[:n] self._read_buffer = self._read_buffer[n:] + logger.debug( + f"WebSocket read returning exact {len(result)} bytes" + ) return result else: + # This should never happen due to the while loop above result = self._read_buffer self._read_buffer = b"" + logger.debug( + f"WebSocket read returning remaining {len(result)} bytes" + ) return result - + except Exception as e: + logger.error(f"WebSocket read failed: {e}") raise IOException from e async def close(self) -> None: @@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser): # Try to get remote address from the WebSocket connection try: remote = self._ws_connection.remote - if hasattr(remote, 'address') and hasattr(remote, 'port'): + if hasattr(remote, "address") and hasattr(remote, "port"): return str(remote.address), int(remote.port) elif isinstance(remote, str): # Parse address:port format - if ':' in remote: - host, port = remote.rsplit(':', 1) + if ":" in remote: + host, port = remote.rsplit(":", 1) return host, int(port) except Exception: pass diff --git a/libp2p/transport/websocket/listener.py b/libp2p/transport/websocket/listener.py index 33194e3f..b8dffc93 100644 --- a/libp2p/transport/websocket/listener.py +++ b/libp2p/transport/websocket/listener.py @@ -1,6 +1,6 @@ +from collections.abc import Awaitable, Callable import logging -import socket -from typing import Any, Callable +from typing import Any from multiaddr import Multiaddr import trio @@ -9,7 +9,6 @@ from trio_websocket import serve_websocket from libp2p.abc import IListener from libp2p.custom_types import THandler -from libp2p.network.connection.raw_connection import RawConnection from libp2p.transport.upgrader import TransportUpgrader from .connection import P2PWebSocketConnection @@ -27,7 +26,8 @@ class WebsocketListener(IListener): self._upgrader = upgrader self._server = None self._shutdown_event = trio.Event() - self._nursery = None + self._nursery: trio.Nursery | None = None + self._listeners: Any = None async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: logger.debug(f"WebsocketListener.listen called with {maddr}") @@ -47,56 +47,60 @@ class WebsocketListener(IListener): if port_str is None: raise ValueError(f"No TCP port found in multiaddr: {maddr}") port = int(port_str) - + logger.debug(f"WebsocketListener: host={host}, port={port}") async def serve_websocket_tcp( - handler: Callable, + handler: Callable[[Any], Awaitable[None]], port: int, host: str, - task_status: trio.TaskStatus[list], + task_status: TaskStatus[Any], ) -> None: """Start TCP server and handle WebSocket connections manually""" logger.debug("serve_websocket_tcp %s %s", host, port) - - async def websocket_handler(request): + + async def websocket_handler(request: Any) -> None: """Handle WebSocket requests""" logger.debug("WebSocket request received") try: # Accept the WebSocket connection ws_connection = await request.accept() logger.debug("WebSocket handshake successful") - + # Create the WebSocket connection wrapper - conn = P2PWebSocketConnection(ws_connection) - + conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call] + # Call the handler function that was passed to create_listener # This handler will handle the security and muxing upgrades logger.debug("Calling connection handler") await self._handler(conn) - + # Don't keep the connection alive indefinitely # Let the handler manage the connection lifecycle - logger.debug("Handler completed, connection will be managed by handler") - + logger.debug( + "Handler completed, connection will be managed by handler" + ) + except Exception as e: logger.debug(f"WebSocket connection error: {e}") logger.debug(f"Error type: {type(e)}") import traceback + logger.debug(f"Traceback: {traceback.format_exc()}") # Reject the connection try: await request.reject(400) - except: + except Exception: pass - + # Use trio_websocket.serve_websocket for proper WebSocket handling - from trio_websocket import serve_websocket - await serve_websocket(websocket_handler, host, port, None, task_status=task_status) + await serve_websocket( + websocket_handler, host, port, None, task_status=task_status + ) # Store the nursery for shutdown self._nursery = nursery - + # Start the server using nursery.start() like TCP does logger.debug("Calling nursery.start()...") started_listeners = await nursery.start( @@ -111,18 +115,21 @@ class WebsocketListener(IListener): logger.error(f"Failed to start WebSocket listener for {maddr}") return False - # Store the listeners for get_addrs() and close() - these are real SocketListener objects + # Store the listeners for get_addrs() and close() - these are real + # SocketListener objects self._listeners = started_listeners - logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object") + logger.debug( + "WebsocketListener.listen returning True with WebSocketServer object" + ) return True - + def get_addrs(self) -> tuple[Multiaddr, ...]: - if not hasattr(self, '_listeners') or not self._listeners: + if not hasattr(self, "_listeners") or not self._listeners: logger.debug("No listeners available for get_addrs()") return () - + # Handle WebSocketServer objects - if hasattr(self._listeners, 'port'): + if hasattr(self._listeners, "port"): # This is a WebSocketServer object port = self._listeners.port # Create a multiaddr from the port @@ -138,12 +145,12 @@ class WebsocketListener(IListener): async def close(self) -> None: """Close the WebSocket listener and stop accepting new connections""" logger.debug("WebsocketListener.close called") - if hasattr(self, '_listeners') and self._listeners: + if hasattr(self, "_listeners") and self._listeners: # Signal shutdown self._shutdown_event.set() - + # Close the WebSocket server - if hasattr(self._listeners, 'aclose'): + if hasattr(self._listeners, "aclose"): # This is a WebSocketServer object logger.debug("Closing WebSocket server") await self._listeners.aclose() @@ -152,15 +159,15 @@ class WebsocketListener(IListener): # This is a list of listeners (like TCP) logger.debug("Closing TCP listeners") for listener in self._listeners: - listener.close() + await listener.aclose() logger.debug("TCP listeners closed") else: # Unknown type, try to close it directly logger.debug("Closing unknown listener type") - if hasattr(self._listeners, 'close'): + if hasattr(self._listeners, "close"): self._listeners.close() logger.debug("Unknown listener closed") - + # Clear the listeners reference self._listeners = None logger.debug("WebsocketListener.close completed") diff --git a/libp2p/transport/websocket/transport.py b/libp2p/transport/websocket/transport.py index adf04504..98c983d0 100644 --- a/libp2p/transport/websocket/transport.py +++ b/libp2p/transport/websocket/transport.py @@ -1,6 +1,6 @@ import logging + from multiaddr import Multiaddr -from trio_websocket import open_websocket_url from libp2p.abc import IListener, ITransport from libp2p.custom_types import THandler @@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader from .connection import P2PWebSocketConnection from .listener import WebsocketListener -logger = logging.getLogger("libp2p.transport.websocket") +logger = logging.getLogger(__name__) class WebsocketTransport(ITransport): @@ -25,7 +25,7 @@ class WebsocketTransport(ITransport): async def dial(self, maddr: Multiaddr) -> RawConnection: """Dial a WebSocket connection to the given multiaddr.""" logger.debug(f"WebsocketTransport.dial called with {maddr}") - + # Extract host and port from multiaddr host = ( maddr.value_for_protocol("ip4") @@ -45,6 +45,7 @@ class WebsocketTransport(ITransport): try: from trio_websocket import open_websocket_url + # Use the context manager but don't exit it immediately # The connection will be closed when the RawConnection is closed ws_context = open_websocket_url(ws_url) diff --git a/tests/core/transport/test_transport_registry.py b/tests/core/transport/test_transport_registry.py index b357ebe2..ff2fb234 100644 --- a/tests/core/transport/test_transport_registry.py +++ b/tests/core/transport/test_transport_registry.py @@ -2,20 +2,20 @@ Tests for the transport registry functionality. """ -import pytest from multiaddr import Multiaddr -from libp2p.abc import ITransport +from libp2p.abc import IListener, IRawConnection, ITransport +from libp2p.custom_types import THandler from libp2p.transport.tcp.tcp import TCP -from libp2p.transport.websocket.transport import WebsocketTransport from libp2p.transport.transport_registry import ( TransportRegistry, create_transport_for_multiaddr, + get_supported_transport_protocols, get_transport_registry, register_transport, - get_supported_transport_protocols, ) from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport class TestTransportRegistry: @@ -25,7 +25,7 @@ class TestTransportRegistry: """Test registry initialization.""" registry = TransportRegistry() assert isinstance(registry, TransportRegistry) - + # Check that default transports are registered supported = registry.get_supported_protocols() assert "tcp" in supported @@ -34,22 +34,28 @@ class TestTransportRegistry: def test_register_transport(self): """Test transport registration.""" registry = TransportRegistry() - + # Register a custom transport - class CustomTransport: - pass - + class CustomTransport(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("CustomTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "CustomTransport create_listener not implemented" + ) + registry.register_transport("custom", CustomTransport) assert registry.get_transport("custom") == CustomTransport def test_get_transport(self): """Test getting registered transports.""" registry = TransportRegistry() - + # Test existing transports assert registry.get_transport("tcp") == TCP assert registry.get_transport("ws") == WebsocketTransport - + # Test non-existent transport assert registry.get_transport("nonexistent") is None @@ -57,7 +63,7 @@ class TestTransportRegistry: """Test getting supported protocols.""" registry = TransportRegistry() protocols = registry.get_supported_protocols() - + assert isinstance(protocols, list) assert "tcp" in protocols assert "ws" in protocols @@ -66,7 +72,7 @@ class TestTransportRegistry: """Test creating TCP transport.""" registry = TransportRegistry() upgrader = TransportUpgrader({}, {}) - + transport = registry.create_transport("tcp", upgrader) assert isinstance(transport, TCP) @@ -74,7 +80,7 @@ class TestTransportRegistry: """Test creating WebSocket transport.""" registry = TransportRegistry() upgrader = TransportUpgrader({}, {}) - + transport = registry.create_transport("ws", upgrader) assert isinstance(transport, WebsocketTransport) @@ -82,14 +88,14 @@ class TestTransportRegistry: """Test creating transport with invalid protocol.""" registry = TransportRegistry() upgrader = TransportUpgrader({}, {}) - + transport = registry.create_transport("invalid", upgrader) assert transport is None def test_create_transport_websocket_no_upgrader(self): """Test that WebSocket transport requires upgrader.""" registry = TransportRegistry() - + # This should fail gracefully and return None transport = registry.create_transport("ws", None) assert transport is None @@ -105,12 +111,19 @@ class TestGlobalRegistry: def test_register_transport_global(self): """Test registering transport globally.""" - class GlobalCustomTransport: - pass - + + class GlobalCustomTransport(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("GlobalCustomTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "GlobalCustomTransport create_listener not implemented" + ) + # Register globally register_transport("global_custom", GlobalCustomTransport) - + # Check that it's available registry = get_transport_registry() assert registry.get_transport("global_custom") == GlobalCustomTransport @@ -129,79 +142,80 @@ class TestTransportFactory: def test_create_transport_for_multiaddr_tcp(self): """Test creating transport for TCP multiaddr.""" upgrader = TransportUpgrader({}, {}) - + # TCP multiaddr maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is not None assert isinstance(transport, TCP) def test_create_transport_for_multiaddr_websocket(self): """Test creating transport for WebSocket multiaddr.""" upgrader = TransportUpgrader({}, {}) - + # WebSocket multiaddr maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is not None assert isinstance(transport, WebsocketTransport) def test_create_transport_for_multiaddr_websocket_secure(self): """Test creating transport for WebSocket multiaddr.""" upgrader = TransportUpgrader({}, {}) - + # WebSocket multiaddr maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is not None assert isinstance(transport, WebsocketTransport) def test_create_transport_for_multiaddr_ipv6(self): """Test creating transport for IPv6 multiaddr.""" upgrader = TransportUpgrader({}, {}) - + # IPv6 WebSocket multiaddr maddr = Multiaddr("/ip6/::1/tcp/8080/ws") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is not None assert isinstance(transport, WebsocketTransport) def test_create_transport_for_multiaddr_dns(self): """Test creating transport for DNS multiaddr.""" upgrader = TransportUpgrader({}, {}) - + # DNS WebSocket multiaddr maddr = Multiaddr("/dns4/example.com/tcp/443/ws") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is not None assert isinstance(transport, WebsocketTransport) def test_create_transport_for_multiaddr_unknown(self): """Test creating transport for unknown multiaddr.""" upgrader = TransportUpgrader({}, {}) - + # Unknown multiaddr maddr = Multiaddr("/ip4/127.0.0.1/udp/8080") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is None - def test_create_transport_for_multiaddr_no_upgrader(self): - """Test creating transport without upgrader.""" - # This should work for TCP but not WebSocket + def test_create_transport_for_multiaddr_with_upgrader(self): + """Test creating transport with upgrader.""" + upgrader = TransportUpgrader({}, {}) + + # This should work for both TCP and WebSocket with upgrader maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080") - transport_tcp = create_transport_for_multiaddr(maddr_tcp, None) + transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader) assert transport_tcp is not None - + maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - transport_ws = create_transport_for_multiaddr(maddr_ws, None) - # WebSocket transport creation should fail gracefully - assert transport_ws is None + transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader) + assert transport_ws is not None class TestTransportInterfaceCompliance: @@ -211,8 +225,8 @@ class TestTransportInterfaceCompliance: """Test that TCP transport implements ITransport.""" transport = TCP() assert isinstance(transport, ITransport) - assert hasattr(transport, 'dial') - assert hasattr(transport, 'create_listener') + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") assert callable(transport.dial) assert callable(transport.create_listener) @@ -221,8 +235,8 @@ class TestTransportInterfaceCompliance: upgrader = TransportUpgrader({}, {}) transport = WebsocketTransport(upgrader) assert isinstance(transport, ITransport) - assert hasattr(transport, 'dial') - assert hasattr(transport, 'create_listener') + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") assert callable(transport.dial) assert callable(transport.create_listener) @@ -234,14 +248,22 @@ class TestErrorHandling: """Test handling of transport creation exceptions.""" registry = TransportRegistry() upgrader = TransportUpgrader({}, {}) - + # Register a transport that raises an exception - class ExceptionTransport: + class ExceptionTransport(ITransport): def __init__(self, *args, **kwargs): raise RuntimeError("Transport creation failed") - + + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("ExceptionTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "ExceptionTransport create_listener not implemented" + ) + registry.register_transport("exception", ExceptionTransport) - + # Should handle exception gracefully and return None transport = registry.create_transport("exception", upgrader) assert transport is None @@ -249,12 +271,13 @@ class TestErrorHandling: def test_invalid_multiaddr_handling(self): """Test handling of invalid multiaddrs.""" upgrader = TransportUpgrader({}, {}) - + # Test with a multiaddr that has an unsupported transport protocol # This should be handled gracefully by our transport registry - maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") # udp is not a supported transport + # udp is not a supported transport + maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") transport = create_transport_for_multiaddr(maddr, upgrader) - + assert transport is None @@ -265,15 +288,15 @@ class TestIntegration: """Test using multiple transport types in the same registry.""" registry = TransportRegistry() upgrader = TransportUpgrader({}, {}) - + # Create different transport types tcp_transport = registry.create_transport("tcp", upgrader) ws_transport = registry.create_transport("ws", upgrader) - + # All should be different types assert isinstance(tcp_transport, TCP) assert isinstance(ws_transport, WebsocketTransport) - + # All should be different instances assert tcp_transport is not ws_transport @@ -281,15 +304,21 @@ class TestIntegration: """Test that transport registry persists across calls.""" registry1 = get_transport_registry() registry2 = get_transport_registry() - + # Should be the same instance assert registry1 is registry2 - + # Register a transport in one - class PersistentTransport: - pass - + class PersistentTransport(ITransport): + async def dial(self, maddr: Multiaddr) -> IRawConnection: + raise NotImplementedError("PersistentTransport dial not implemented") + + def create_listener(self, handler_function: THandler) -> IListener: + raise NotImplementedError( + "PersistentTransport create_listener not implemented" + ) + registry1.register_transport("persistent", PersistentTransport) - + # Should be available in the other assert registry2.get_transport("persistent") == PersistentTransport diff --git a/tests/core/transport/test_websocket.py b/tests/core/transport/test_websocket.py index 1df85256..56051a15 100644 --- a/tests/core/transport/test_websocket.py +++ b/tests/core/transport/test_websocket.py @@ -1,23 +1,23 @@ from collections.abc import Sequence +import logging from typing import Any import pytest -import trio from multiaddr import Multiaddr +import trio from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.custom_types import TProtocol from libp2p.host.basic_host import BasicHost from libp2p.network.swarm import Swarm from libp2p.peer.id import ID -from libp2p.peer.peerinfo import PeerInfo from libp2p.peer.peerstore import PeerStore from libp2p.security.insecure.transport import InsecureTransport from libp2p.stream_muxer.yamux.yamux import Yamux from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.websocket.transport import WebsocketTransport -from libp2p.transport.websocket.listener import WebsocketListener -from libp2p.transport.exceptions import OpenConnectionError + +logger = logging.getLogger(__name__) PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" @@ -64,29 +64,30 @@ def create_upgrader(): ) - - - # 2. Listener Basic Functionality Tests @pytest.mark.trio async def test_listener_basic_listen(): """Test basic listen functionality""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test listening on IPv4 ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") - listener = transport.create_listener(lambda conn: None) - + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + # Test that listener can be created and has required methods - assert hasattr(listener, 'listen') - assert hasattr(listener, 'close') - assert hasattr(listener, 'get_addrs') - + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + # Test that listener can handle the address assert ma.value_for_protocol("ip4") == "127.0.0.1" assert ma.value_for_protocol("tcp") == "0" - + # Test that listener can be closed await listener.close() @@ -96,14 +97,18 @@ async def test_listener_port_0_handling(): """Test listening on port 0 gets actual port""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") - listener = transport.create_listener(lambda conn: None) - + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + # Test that the address can be parsed correctly port_str = ma.value_for_protocol("tcp") assert port_str == "0" - + # Test that listener can be closed await listener.close() @@ -113,14 +118,18 @@ async def test_listener_any_interface(): """Test listening on 0.0.0.0""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws") - listener = transport.create_listener(lambda conn: None) - + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + # Test that the address can be parsed correctly host = ma.value_for_protocol("ip4") assert host == "0.0.0.0" - + # Test that listener can be closed await listener.close() @@ -130,16 +139,20 @@ async def test_listener_address_preservation(): """Test that p2p IDs are preserved in addresses""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Create address with p2p ID p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF" ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}") - listener = transport.create_listener(lambda conn: None) - + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + # Test that p2p ID is preserved in the address addr_str = str(ma) assert p2p_id in addr_str - + # Test that listener can be closed await listener.close() @@ -150,18 +163,18 @@ async def test_dial_basic(): """Test basic dial functionality""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test that transport can parse addresses for dialing ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - + # Test that the address can be parsed correctly host = ma.value_for_protocol("ip4") port = ma.value_for_protocol("tcp") assert host == "127.0.0.1" assert port == "8080" - + # Test that transport has the required methods - assert hasattr(transport, 'dial') + assert hasattr(transport, "dial") assert callable(transport.dial) @@ -170,16 +183,16 @@ async def test_dial_with_p2p_id(): """Test dialing with p2p ID suffix""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF" ma = Multiaddr(f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{p2p_id}") - + # Test that p2p ID is preserved in the address addr_str = str(ma) assert p2p_id in addr_str - + # Test that transport can handle addresses with p2p IDs - assert hasattr(transport, 'dial') + assert hasattr(transport, "dial") assert callable(transport.dial) @@ -188,41 +201,42 @@ async def test_dial_port_0_resolution(): """Test dialing to resolved port 0 addresses""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test that transport can handle port 0 addresses ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") - + # Test that the address can be parsed correctly port_str = ma.value_for_protocol("tcp") assert port_str == "0" - + # Test that transport has the required methods - assert hasattr(transport, 'dial') + assert hasattr(transport, "dial") assert callable(transport.dial) # 4. Address Validation Tests (CRITICAL) def test_address_validation_ipv4(): """Test IPv4 address validation""" - upgrader = create_upgrader() - transport = WebsocketTransport(upgrader) - + # upgrader = create_upgrader() # Not used in this test + # Valid IPv4 WebSocket addresses valid_addresses = [ "/ip4/127.0.0.1/tcp/8080/ws", "/ip4/0.0.0.0/tcp/0/ws", "/ip4/192.168.1.1/tcp/443/ws", ] - + # Test valid addresses can be parsed for addr_str in valid_addresses: ma = Multiaddr(addr_str) # Should not raise exception when creating transport address transport_addr = str(ma) assert "/ws" in transport_addr - + # Test that transport can handle addresses with p2p IDs - p2p_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw") + p2p_addr = Multiaddr( + "/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw" + ) # Should not raise exception when creating transport address transport_addr = str(p2p_addr) assert "/ws" in transport_addr @@ -230,15 +244,14 @@ def test_address_validation_ipv4(): def test_address_validation_ipv6(): """Test IPv6 address validation""" - upgrader = create_upgrader() - transport = WebsocketTransport(upgrader) - + # upgrader = create_upgrader() # Not used in this test + # Valid IPv6 WebSocket addresses valid_addresses = [ "/ip6/::1/tcp/8080/ws", "/ip6/2001:db8::1/tcp/443/ws", ] - + # Test valid addresses can be parsed for addr_str in valid_addresses: ma = Multiaddr(addr_str) @@ -248,16 +261,15 @@ def test_address_validation_ipv6(): def test_address_validation_dns(): """Test DNS address validation""" - upgrader = create_upgrader() - transport = WebsocketTransport(upgrader) - + # upgrader = create_upgrader() # Not used in this test + # Valid DNS WebSocket addresses valid_addresses = [ "/dns4/example.com/tcp/80/ws", "/dns6/example.com/tcp/443/ws", "/dnsaddr/example.com/tcp/8080/ws", ] - + # Test valid addresses can be parsed for addr_str in valid_addresses: ma = Multiaddr(addr_str) @@ -267,21 +279,20 @@ def test_address_validation_dns(): def test_address_validation_mixed(): """Test mixed address validation""" - upgrader = create_upgrader() - transport = WebsocketTransport(upgrader) - + # upgrader = create_upgrader() # Not used in this test + # Mixed valid and invalid addresses addresses = [ "/ip4/127.0.0.1/tcp/8080/ws", # Valid - "/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws) - "/ip6/::1/tcp/8080/ws", # Valid - "/ip4/127.0.0.1/ws", # Invalid (no tcp) + "/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws) + "/ip6/::1/tcp/8080/ws", # Valid + "/ip4/127.0.0.1/ws", # Invalid (no tcp) "/dns4/example.com/tcp/80/ws", # Valid ] - + # Convert to Multiaddr objects multiaddrs = [Multiaddr(addr) for addr in addresses] - + # Test that valid addresses can be processed valid_count = 0 for ma in multiaddrs: @@ -292,7 +303,7 @@ def test_address_validation_mixed(): valid_count += 1 except Exception: pass - + assert valid_count == 3 # Should have 3 valid addresses @@ -302,30 +313,29 @@ async def test_dial_invalid_address(): """Test dialing invalid addresses""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test dialing non-WebSocket addresses invalid_addresses = [ Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws Multiaddr("/ip4/127.0.0.1/ws"), # No tcp ] - + for ma in invalid_addresses: - with pytest.raises((ValueError, OpenConnectionError, Exception)): + with pytest.raises(Exception): await transport.dial(ma) @pytest.mark.trio async def test_listen_invalid_address(): """Test listening on invalid addresses""" - upgrader = create_upgrader() - transport = WebsocketTransport(upgrader) - + # upgrader = create_upgrader() # Not used in this test + # Test listening on non-WebSocket addresses invalid_addresses = [ Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws Multiaddr("/ip4/127.0.0.1/ws"), # No tcp ] - + # Test that invalid addresses are properly identified for ma in invalid_addresses: # Test that the address parsing works correctly @@ -342,17 +352,17 @@ async def test_listen_port_in_use(): """Test listening on port that's in use""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test that transport can handle port conflicts ma1 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") ma2 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") - + # Test that both addresses can be parsed assert ma1.value_for_protocol("tcp") == "8080" assert ma2.value_for_protocol("tcp") == "8080" - + # Test that transport can handle these addresses - assert hasattr(transport, 'create_listener') + assert hasattr(transport, "create_listener") assert callable(transport.create_listener) @@ -362,16 +372,19 @@ async def test_connection_close(): """Test connection closing""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test that transport has required methods - assert hasattr(transport, 'dial') + assert hasattr(transport, "dial") assert callable(transport.dial) - + # Test that listener can be created and closed - listener = transport.create_listener(lambda conn: None) - assert hasattr(listener, 'close') + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + assert hasattr(listener, "close") assert callable(listener.close) - + # Test that listener can be closed await listener.close() @@ -381,32 +394,26 @@ async def test_multiple_connections(): """Test multiple concurrent connections""" upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test that transport can handle multiple addresses addresses = [ Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"), Multiaddr("/ip4/127.0.0.1/tcp/8081/ws"), Multiaddr("/ip4/127.0.0.1/tcp/8082/ws"), ] - + # Test that all addresses can be parsed for addr in addresses: host = addr.value_for_protocol("ip4") port = addr.value_for_protocol("tcp") assert host == "127.0.0.1" assert port in ["8080", "8081", "8082"] - + # Test that transport has required methods - assert hasattr(transport, 'dial') + assert hasattr(transport, "dial") assert callable(transport.dial) - - - - - - # Original test (kept for compatibility) @pytest.mark.trio async def test_websocket_dial_and_listen(): @@ -414,42 +421,40 @@ async def test_websocket_dial_and_listen(): # Test that WebSocket transport can handle basic operations upgrader = create_upgrader() transport = WebsocketTransport(upgrader) - + # Test that transport can create listeners - listener = transport.create_listener(lambda conn: None) + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) assert listener is not None - assert hasattr(listener, 'listen') - assert hasattr(listener, 'close') - assert hasattr(listener, 'get_addrs') - + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + # Test that transport can handle WebSocket addresses ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") assert ma.value_for_protocol("ip4") == "127.0.0.1" assert ma.value_for_protocol("tcp") == "0" assert "ws" in str(ma) - + # Test that transport has dial method - assert hasattr(transport, 'dial') + assert hasattr(transport, "dial") assert callable(transport.dial) - + # Test that transport can handle WebSocket multiaddrs ws_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") assert ws_addr.value_for_protocol("ip4") == "127.0.0.1" assert ws_addr.value_for_protocol("tcp") == "8080" assert "ws" in str(ws_addr) - + # Cleanup await listener.close() -import logging -logger = logging.getLogger(__name__) - - @pytest.mark.trio async def test_websocket_transport_basic(): """Test basic WebSocket transport functionality without full libp2p stack""" - # Create WebSocket transport key_pair = create_new_key_pair() upgrader = TransportUpgrader( @@ -459,29 +464,31 @@ async def test_websocket_transport_basic(): muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) transport = WebsocketTransport(upgrader) - + assert transport is not None - assert hasattr(transport, 'dial') - assert hasattr(transport, 'create_listener') - - listener = transport.create_listener(lambda conn: None) + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) assert listener is not None - assert hasattr(listener, 'listen') - assert hasattr(listener, 'close') - assert hasattr(listener, 'get_addrs') - + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") assert valid_addr.value_for_protocol("ip4") == "127.0.0.1" assert valid_addr.value_for_protocol("tcp") == "0" assert "ws" in str(valid_addr) - + await listener.close() @pytest.mark.trio async def test_websocket_simple_connection(): - """Test WebSocket transport creation and basic functionality without real connections""" - + """Test WebSocket transport creation and basic functionality without real conn""" # Create WebSocket transport key_pair = create_new_key_pair() upgrader = TransportUpgrader( @@ -491,32 +498,31 @@ async def test_websocket_simple_connection(): muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) transport = WebsocketTransport(upgrader) - + assert transport is not None - assert hasattr(transport, 'dial') - assert hasattr(transport, 'create_listener') - + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + async def simple_handler(conn): await conn.close() - + listener = transport.create_listener(simple_handler) assert listener is not None - assert hasattr(listener, 'listen') - assert hasattr(listener, 'close') - assert hasattr(listener, 'get_addrs') - + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") assert test_addr.value_for_protocol("ip4") == "127.0.0.1" assert test_addr.value_for_protocol("tcp") == "0" assert "ws" in str(test_addr) - + await listener.close() @pytest.mark.trio async def test_websocket_real_connection(): """Test WebSocket transport creation and basic functionality""" - # Create WebSocket transport key_pair = create_new_key_pair() upgrader = TransportUpgrader( @@ -526,59 +532,57 @@ async def test_websocket_real_connection(): muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) transport = WebsocketTransport(upgrader) - + assert transport is not None - assert hasattr(transport, 'dial') - assert hasattr(transport, 'create_listener') - + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") + async def handler(conn): await conn.close() - + listener = transport.create_listener(handler) assert listener is not None - assert hasattr(listener, 'listen') - assert hasattr(listener, 'close') - assert hasattr(listener, 'get_addrs') - + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + await listener.close() @pytest.mark.trio async def test_websocket_with_tcp_fallback(): """Test WebSocket functionality using TCP transport as fallback""" - from tests.utils.factories import host_pair_factory - + async with host_pair_factory() as (host_a, host_b): assert len(host_a.get_network().connections) > 0 assert len(host_b.get_network().connections) > 0 - + test_protocol = TProtocol("/test/protocol/1.0.0") received_data = None - + async def test_handler(stream): nonlocal received_data received_data = await stream.read(1024) await stream.write(b"Response from TCP") await stream.close() - + host_a.set_stream_handler(test_protocol, test_handler) stream = await host_b.new_stream(host_a.get_id(), [test_protocol]) - + test_data = b"TCP protocol test" await stream.write(test_data) response = await stream.read(1024) - + assert received_data == test_data assert response == b"Response from TCP" - + await stream.close() @pytest.mark.trio async def test_websocket_transport_interface(): """Test WebSocket transport interface compliance""" - key_pair = create_new_key_pair() upgrader = TransportUpgrader( secure_transports_by_protocol={ @@ -586,23 +590,26 @@ async def test_websocket_transport_interface(): }, muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) - + transport = WebsocketTransport(upgrader) - - assert hasattr(transport, 'dial') - assert hasattr(transport, 'create_listener') + + assert hasattr(transport, "dial") + assert hasattr(transport, "create_listener") assert callable(transport.dial) assert callable(transport.create_listener) - - listener = transport.create_listener(lambda conn: None) - assert hasattr(listener, 'listen') - assert hasattr(listener, 'close') - assert hasattr(listener, 'get_addrs') - + + async def dummy_handler(conn): + await trio.sleep(0) + + listener = transport.create_listener(dummy_handler) + assert hasattr(listener, "listen") + assert hasattr(listener, "close") + assert hasattr(listener, "get_addrs") + test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") host = test_addr.value_for_protocol("ip4") port = test_addr.value_for_protocol("tcp") assert host == "127.0.0.1" assert port == "8080" - + await listener.close() diff --git a/tests/interop/test_js_ws_ping.py b/tests/interop/test_js_ws_ping.py index b2cf248d..b0e73a36 100644 --- a/tests/interop/test_js_ws_ping.py +++ b/tests/interop/test_js_ws_ping.py @@ -20,7 +20,7 @@ from libp2p.stream_muxer.yamux.yamux import Yamux from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.websocket.transport import WebsocketTransport -PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" +PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0" @pytest.mark.trio @@ -74,6 +74,11 @@ async def test_ping_with_js_node(): peer_id = ID.from_base58(peer_id_line) maddr = Multiaddr(addr_line) + # Debug: Print what we're trying to connect to + print(f"JS Node Peer ID: {peer_id_line}") + print(f"JS Node Address: {addr_line}") + print(f"All JS Node lines: {lines}") + # Set up Python host key_pair = create_new_key_pair() py_peer_id = ID.from_pubkey(key_pair.public_key) @@ -86,13 +91,15 @@ async def test_ping_with_js_node(): }, muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, ) - transport = WebsocketTransport() + transport = WebsocketTransport(upgrader) swarm = Swarm(py_peer_id, peer_store, upgrader, transport) host = BasicHost(swarm) # Connect to JS node peer_info = PeerInfo(peer_id, [maddr]) + print(f"Python trying to connect to: {peer_info}") + await trio.sleep(1) try: