Fix typecheck errors and improve WebSocket transport implementation

- Fix INotifee interface compliance in WebSocket demo
- Fix handler function signatures to be async (THandler compatibility)
- Fix is_closed method usage with proper type checking
- Fix pytest.raises multiple exception type issue
- Fix line length violations (E501) across multiple files
- Add debugging logging to Noise security module for troubleshooting
- Update WebSocket transport examples and tests
- Improve transport registry error handling
This commit is contained in:
acul71
2025-08-11 01:25:49 +02:00
parent 64107b4648
commit fe4c17e8d1
16 changed files with 845 additions and 488 deletions

View File

@ -11,13 +11,14 @@ This script demonstrates:
import asyncio import asyncio
import logging import logging
import sys
from pathlib import Path from pathlib import Path
import sys
# Add the libp2p directory to the path so we can import it # Add the libp2p directory to the path so we can import it
sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
import multiaddr import multiaddr
from libp2p.transport import ( from libp2p.transport import (
create_transport, create_transport,
create_transport_for_multiaddr, create_transport_for_multiaddr,
@ -25,9 +26,8 @@ from libp2p.transport import (
get_transport_registry, get_transport_registry,
register_transport, register_transport,
) )
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport from libp2p.transport.upgrader import TransportUpgrader
# Set up logging # Set up logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -50,7 +50,8 @@ def demo_transport_registry():
print("\nRegistered transports:") print("\nRegistered transports:")
for protocol in supported: for protocol in supported:
transport_class = registry.get_transport(protocol) transport_class = registry.get_transport(protocol)
print(f" {protocol}: {transport_class.__name__}") class_name = transport_class.__name__ if transport_class else "None"
print(f" {protocol}: {class_name}")
print() print()
@ -114,15 +115,13 @@ def demo_custom_transport_registration():
print("🔧 Custom Transport Registration Demo") print("🔧 Custom Transport Registration Demo")
print("=" * 50) print("=" * 50)
# Create a dummy upgrader
upgrader = TransportUpgrader({}, {})
# Show current supported protocols # Show current supported protocols
print(f"Before registration: {get_supported_transport_protocols()}") print(f"Before registration: {get_supported_transport_protocols()}")
# Register a custom transport (using TCP as an example) # Register a custom transport (using TCP as an example)
class CustomTCPTransport(TCP): class CustomTCPTransport(TCP):
"""Custom TCP transport for demonstration.""" """Custom TCP transport for demonstration."""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.custom_flag = True self.custom_flag = True
@ -137,7 +136,12 @@ def demo_custom_transport_registration():
try: try:
custom_transport = create_transport("custom_tcp") custom_transport = create_transport("custom_tcp")
print(f"✅ Created custom transport: {type(custom_transport).__name__}") print(f"✅ Created custom transport: {type(custom_transport).__name__}")
print(f" Custom flag: {custom_transport.custom_flag}") # Check if it has the custom flag (type-safe way)
if hasattr(custom_transport, "custom_flag"):
flag_value = getattr(custom_transport, "custom_flag", "Not found")
print(f" Custom flag: {flag_value}")
else:
print(" Custom flag: Not found")
except Exception as e: except Exception as e:
print(f"❌ Error creating custom transport: {e}") print(f"❌ Error creating custom transport: {e}")
@ -202,4 +206,5 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"\n❌ Demo failed with error: {e}") print(f"\n❌ Demo failed with error: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View File

@ -5,7 +5,6 @@ Simple TCP echo demo to verify basic libp2p functionality.
import argparse import argparse
import logging import logging
import sys
import traceback import traceback
import multiaddr import multiaddr
@ -18,10 +17,10 @@ from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PeerStore from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.upgrader import TransportUpgrader
# Enable debug logging # Enable debug logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -31,12 +30,13 @@ logger = logging.getLogger("libp2p.tcp-example")
# Simple echo protocol # Simple echo protocol
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def echo_handler(stream): async def echo_handler(stream):
"""Simple echo handler that echoes back any data received.""" """Simple echo handler that echoes back any data received."""
try: try:
data = await stream.read(1024) data = await stream.read(1024)
if data: if data:
message = data.decode('utf-8', errors='replace') message = data.decode("utf-8", errors="replace")
print(f"📥 Received: {message}") print(f"📥 Received: {message}")
print(f"📤 Echoing back: {message}") print(f"📤 Echoing back: {message}")
await stream.write(data) await stream.write(data)
@ -45,6 +45,7 @@ async def echo_handler(stream):
logger.error(f"Echo handler error: {e}") logger.error(f"Echo handler error: {e}")
await stream.close() await stream.close()
def create_tcp_host(): def create_tcp_host():
"""Create a host with TCP transport.""" """Create a host with TCP transport."""
# Create key pair and peer store # Create key pair and peer store
@ -70,6 +71,7 @@ def create_tcp_host():
return host return host
async def run(port: int, destination: str) -> None: async def run(port: int, destination: str) -> None:
localhost_ip = "0.0.0.0" localhost_ip = "0.0.0.0"
@ -84,7 +86,10 @@ async def run(port: int, destination: str) -> None:
# Set up echo handler # Set up echo handler
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler) host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task # Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
@ -102,8 +107,8 @@ async def run(port: int, destination: str) -> None:
print("🌐 TCP Server Started Successfully!") print("🌐 TCP Server Started Successfully!")
print("=" * 50) print("=" * 50)
print(f"📍 Server Address: {client_addr}") print(f"📍 Server Address: {client_addr}")
print(f"🔧 Protocol: /echo/1.0.0") print("🔧 Protocol: /echo/1.0.0")
print(f"🚀 Transport: TCP") print("🚀 Transport: TCP")
print() print()
print("📋 To test the connection, run this in another terminal:") print("📋 To test the connection, run this in another terminal:")
print(f" python test_tcp_echo.py -d {client_addr}") print(f" python test_tcp_echo.py -d {client_addr}")
@ -127,7 +132,10 @@ async def run(port: int, destination: str) -> None:
host = create_tcp_host() host = create_tcp_host()
# Start the host for client operations # Start the host for client operations
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task # Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
maddr = multiaddr.Multiaddr(destination) maddr = multiaddr.Multiaddr(destination)
@ -144,7 +152,7 @@ async def run(port: int, destination: str) -> None:
print("✅ Successfully connected to TCP server!") print("✅ Successfully connected to TCP server!")
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
print(f"\n❌ Connection Failed!") print("\n❌ Connection Failed!")
print(f" Peer ID: {info.peer_id}") print(f" Peer ID: {info.peer_id}")
print(f" Address: {destination}") print(f" Address: {destination}")
print(f" Error: {error_msg}") print(f" Error: {error_msg}")
@ -191,11 +199,14 @@ async def run(port: int, destination: str) -> None:
traceback.print_exc() traceback.print_exc()
return return
def main() -> None: def main() -> None:
description = "Simple TCP echo demo for libp2p" description = "Simple TCP echo demo for libp2p"
parser = argparse.ArgumentParser(description=description) parser = argparse.ArgumentParser(description=description)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number") parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument("-d", "--destination", type=str, help="destination multiaddr string") parser.add_argument(
"-d", "--destination", type=str, help="destination multiaddr string"
)
args = parser.parse_args() args = parser.parse_args()
@ -204,5 +215,6 @@ def main() -> None:
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -5,16 +5,16 @@ Simple test script to verify WebSocket transport functionality.
import asyncio import asyncio
import logging import logging
import sys
from pathlib import Path from pathlib import Path
import sys
# Add the libp2p directory to the path so we can import it # Add the libp2p directory to the path so we can import it
sys.path.insert(0, str(Path(__file__).parent)) sys.path.insert(0, str(Path(__file__).parent))
import multiaddr import multiaddr
from libp2p.transport import create_transport, create_transport_for_multiaddr from libp2p.transport import create_transport, create_transport_for_multiaddr
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.network.connection.raw_connection import RawConnection
# Set up logging # Set up logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
@ -37,7 +37,10 @@ async def test_websocket_transport():
# Test creating transport from multiaddr # Test creating transport from multiaddr
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader) ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
print(f"✅ WebSocket transport from multiaddr: {type(ws_transport_from_maddr).__name__}") print(
f"✅ WebSocket transport from multiaddr: "
f"{type(ws_transport_from_maddr).__name__}"
)
# Test creating listener # Test creating listener
handler_called = False handler_called = False
@ -52,8 +55,13 @@ async def test_websocket_transport():
print(f"✅ WebSocket listener created: {type(listener).__name__}") print(f"✅ WebSocket listener created: {type(listener).__name__}")
# Test that the transport can be used # Test that the transport can be used
print(f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}") print(
print(f"✅ WebSocket transport supports listening: {hasattr(ws_transport, 'create_listener')}") f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}"
)
print(
f"✅ WebSocket transport supports listening: "
f"{hasattr(ws_transport, 'create_listener')}"
)
print("\n🎯 WebSocket Transport Test Results:") print("\n🎯 WebSocket Transport Test Results:")
print("✅ Transport creation: PASS") print("✅ Transport creation: PASS")
@ -64,6 +72,7 @@ async def test_websocket_transport():
except Exception as e: except Exception as e:
print(f"❌ WebSocket transport test failed: {e}") print(f"❌ WebSocket transport test failed: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return False return False
@ -75,7 +84,10 @@ async def test_transport_registry():
print("\n🔧 Testing Transport Registry") print("\n🔧 Testing Transport Registry")
print("=" * 30) print("=" * 30)
from libp2p.transport import get_transport_registry, get_supported_transport_protocols from libp2p.transport import (
get_supported_transport_protocols,
get_transport_registry,
)
registry = get_transport_registry() registry = get_transport_registry()
supported = get_supported_transport_protocols() supported = get_supported_transport_protocols()
@ -85,7 +97,8 @@ async def test_transport_registry():
# Test getting transports # Test getting transports
for protocol in supported: for protocol in supported:
transport_class = registry.get_transport(protocol) transport_class = registry.get_transport(protocol)
print(f" {protocol}: {transport_class.__name__}") class_name = transport_class.__name__ if transport_class else "None"
print(f" {protocol}: {class_name}")
# Test creating transports through registry # Test creating transports through registry
upgrader = TransportUpgrader({}, {}) upgrader = TransportUpgrader({}, {})
@ -128,4 +141,5 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"\n❌ Test failed with error: {e}") print(f"\n❌ Test failed with error: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View File

@ -1,21 +1,26 @@
import argparse import argparse
import logging import logging
import signal
import sys import sys
import traceback import traceback
import multiaddr import multiaddr
import trio import trio
from libp2p.abc import INotifee
from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair
from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PeerStore from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.security.noise.transport import Transport as NoiseTransport from libp2p.security.noise.transport import (
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
from libp2p.stream_muxer.yamux.yamux import Yamux from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport from libp2p.transport.websocket.transport import WebsocketTransport
@ -25,6 +30,15 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("libp2p.websocket-example") logger = logging.getLogger("libp2p.websocket-example")
# Suppress KeyboardInterrupt by handling SIGINT directly
def signal_handler(signum, frame):
print("✅ Clean exit completed.")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
# Simple echo protocol # Simple echo protocol
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0") ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
@ -34,7 +48,7 @@ async def echo_handler(stream):
try: try:
data = await stream.read(1024) data = await stream.read(1024)
if data: if data:
message = data.decode('utf-8', errors='replace') message = data.decode("utf-8", errors="replace")
print(f"📥 Received: {message}") print(f"📥 Received: {message}")
print(f"📤 Echoing back: {message}") print(f"📤 Echoing back: {message}")
await stream.write(data) await stream.write(data)
@ -44,7 +58,7 @@ async def echo_handler(stream):
await stream.close() await stream.close()
def create_websocket_host(listen_addrs=None, use_noise=False): def create_websocket_host(listen_addrs=None, use_plaintext=False):
"""Create a host with WebSocket transport.""" """Create a host with WebSocket transport."""
# Create key pair and peer store # Create key pair and peer store
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
@ -52,11 +66,22 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
peer_store = PeerStore() peer_store = PeerStore()
peer_store.add_key_pair(peer_id, key_pair) peer_store.add_key_pair(peer_id, key_pair)
if use_noise: if use_plaintext:
# Create transport upgrader with plaintext security
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
else:
# Create separate Ed25519 key for Noise protocol
noise_key_pair = create_ed25519_key_pair()
# Create Noise transport # Create Noise transport
noise_transport = NoiseTransport( noise_transport = NoiseTransport(
libp2p_keypair=key_pair, libp2p_keypair=key_pair,
noise_privkey=key_pair.private_key, noise_privkey=noise_key_pair.private_key,
early_data=None, early_data=None,
with_noise_pipes=False, with_noise_pipes=False,
) )
@ -68,14 +93,6 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
}, },
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
) )
else:
# Create transport upgrader with plaintext security
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
# Create WebSocket transport # Create WebSocket transport
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
@ -87,7 +104,7 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
return host return host
async def run(port: int, destination: str, use_noise: bool = False) -> None: async def run(port: int, destination: str, use_plaintext: bool = False) -> None:
localhost_ip = "0.0.0.0" localhost_ip = "0.0.0.0"
if not destination: if not destination:
@ -95,16 +112,66 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws") listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
try: try:
host = create_websocket_host(use_noise=use_noise) host = create_websocket_host(use_plaintext=use_plaintext)
logger.debug(f"Created host with use_noise={use_noise}") logger.debug(f"Created host with use_plaintext={use_plaintext}")
# Set up echo handler # Set up echo handler
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler) host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: # Add connection event handlers for debugging
class DebugNotifee(INotifee):
async def opened_stream(self, network, stream):
pass
async def closed_stream(self, network, stream):
pass
async def connected(self, network, conn):
print(
f"🔗 New libp2p connection established: "
f"{conn.muxed_conn.peer_id}"
)
if hasattr(conn.muxed_conn, "get_security_protocol"):
security = conn.muxed_conn.get_security_protocol()
else:
security = "Unknown"
print(f" Security: {security}")
async def disconnected(self, network, conn):
print(f"🔌 libp2p connection closed: {conn.muxed_conn.peer_id}")
async def listen(self, network, multiaddr):
pass
async def listen_close(self, network, multiaddr):
pass
host.get_network().register_notifee(DebugNotifee())
# Create a cancellation token for clean shutdown
cancel_scope = trio.CancelScope()
async def signal_handler():
with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as (
signal_receiver
):
async for sig in signal_receiver:
print(f"\n🛑 Received signal {sig}")
print("✅ Shutting down WebSocket server...")
cancel_scope.cancel()
return
async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task # Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
# Start the signal handler
nursery.start_soon(signal_handler)
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client # Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
# connections # connections
addrs = host.get_addrs() addrs = host.get_addrs()
@ -120,11 +187,12 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
print("🌐 WebSocket Server Started Successfully!") print("🌐 WebSocket Server Started Successfully!")
print("=" * 50) print("=" * 50)
print(f"📍 Server Address: {client_addr}") print(f"📍 Server Address: {client_addr}")
print(f"🔧 Protocol: /echo/1.0.0") print("🔧 Protocol: /echo/1.0.0")
print(f"🚀 Transport: WebSocket (/ws)") print("🚀 Transport: WebSocket (/ws)")
print() print()
print("📋 To test the connection, run this in another terminal:") print("📋 To test the connection, run this in another terminal:")
print(f" python websocket_demo.py -d {client_addr}") plaintext_flag = " --plaintext" if use_plaintext else ""
print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}")
print() print()
print("⏳ Waiting for incoming WebSocket connections...") print("⏳ Waiting for incoming WebSocket connections...")
print("" * 50) print("" * 50)
@ -132,9 +200,9 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
# Add a custom handler to show connection events # Add a custom handler to show connection events
async def custom_echo_handler(stream): async def custom_echo_handler(stream):
peer_id = stream.muxed_conn.peer_id peer_id = stream.muxed_conn.peer_id
print(f"\n🔗 New WebSocket Connection!") print("\n🔗 New WebSocket Connection!")
print(f" Peer ID: {peer_id}") print(f" Peer ID: {peer_id}")
print(f" Protocol: /echo/1.0.0") print(" Protocol: /echo/1.0.0")
# Show remote address in multiaddr format # Show remote address in multiaddr format
try: try:
@ -142,20 +210,22 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
if remote_address: if remote_address:
print(f" Remote: {remote_address}") print(f" Remote: {remote_address}")
except Exception: except Exception:
print(f" Remote: Unknown") print(" Remote: Unknown")
print(f"" * 40) print("" * 40)
# Call the original handler # Call the original handler
await echo_handler(stream) await echo_handler(stream)
print(f"" * 40) print("" * 40)
print(f"✅ Echo request completed for peer: {peer_id}") print(f"✅ Echo request completed for peer: {peer_id}")
print() print()
# Replace the handler with our custom one # Replace the handler with our custom one
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler) host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
# Wait indefinitely or until cancelled
with cancel_scope:
await trio.sleep_forever() await trio.sleep_forever()
except Exception as e: except Exception as e:
@ -169,12 +239,44 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
try: try:
# Create a single host for client operations # Create a single host for client operations
host = create_websocket_host(use_noise=use_noise) host = create_websocket_host(use_plaintext=use_plaintext)
# Start the host for client operations # Start the host for client operations
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery: async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task # Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60) nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
# Add connection event handlers for debugging
class ClientDebugNotifee(INotifee):
async def opened_stream(self, network, stream):
pass
async def closed_stream(self, network, stream):
pass
async def connected(self, network, conn):
print(
f"🔗 Client: libp2p connection established: "
f"{conn.muxed_conn.peer_id}"
)
async def disconnected(self, network, conn):
print(
f"🔌 Client: libp2p connection closed: "
f"{conn.muxed_conn.peer_id}"
)
async def listen(self, network, multiaddr):
pass
async def listen_close(self, network, multiaddr):
pass
host.get_network().register_notifee(ClientDebugNotifee())
maddr = multiaddr.Multiaddr(destination) maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr) info = info_from_p2p_addr(maddr)
print("🔌 WebSocket Client Starting...") print("🔌 WebSocket Client Starting...")
@ -185,20 +287,33 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
try: try:
print("🔗 Connecting to WebSocket server...") print("🔗 Connecting to WebSocket server...")
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
await host.connect(info) await host.connect(info)
print("✅ Successfully connected to WebSocket server!") print("✅ Successfully connected to WebSocket server!")
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg: print("\n❌ Connection Failed!")
print(f"\n❌ Connection Failed!")
print(f" Peer ID: {info.peer_id}") print(f" Peer ID: {info.peer_id}")
print(f" Address: {destination}") print(f" Address: {destination}")
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
print(f" Error: {error_msg}") print(f" Error: {error_msg}")
print(f" Error type: {type(e).__name__}")
# Add more detailed error information for debugging
if hasattr(e, "__cause__") and e.__cause__:
print(f" Root cause: {e.__cause__}")
print(f" Root cause type: {type(e.__cause__).__name__}")
print() print()
print("💡 Troubleshooting:") print("💡 Troubleshooting:")
print(" • Make sure the WebSocket server is running") print(" • Make sure the WebSocket server is running")
print(" • Check that the server address is correct") print(" • Check that the server address is correct")
print(" • Verify the server is listening on the right port") print(" • Verify the server is listening on the right port")
print(
" • Ensure both client and server use the same sec protocol"
)
if not use_plaintext:
print(" • Noise over WebSocket may have compatibility issues")
return return
# Create a stream and send test data # Create a stream and send test data
@ -242,7 +357,17 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
finally: finally:
# Ensure stream is closed # Ensure stream is closed
try: try:
if stream and not await stream.is_closed(): if stream:
# Check if stream has is_closed method and use it
has_is_closed = hasattr(stream, "is_closed") and callable(
getattr(stream, "is_closed")
)
if has_is_closed:
# type: ignore[attr-defined]
if not await stream.is_closed():
await stream.close()
else:
# Fallback: just try to close the stream
await stream.close() await stream.close()
except Exception: except Exception:
pass pass
@ -257,6 +382,9 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
print() print()
print("🚀 Your WebSocket transport is ready for production use!") print("🚀 Your WebSocket transport is ready for production use!")
# Add a small delay to ensure all cleanup is complete
await trio.sleep(0.1)
except Exception as e: except Exception as e:
print(f"❌ Error creating WebSocket client: {e}") print(f"❌ Error creating WebSocket client: {e}")
traceback.print_exc() traceback.print_exc()
@ -266,12 +394,15 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
def main() -> None: def main() -> None:
description = """ description = """
This program demonstrates the libp2p WebSocket transport. This program demonstrates the libp2p WebSocket transport.
First run 'python websocket_demo.py -p <PORT> [--noise]' to start a WebSocket server. First run
Then run 'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--noise]' 'python websocket_demo.py -p <PORT> [--plaintext]' to start a WebSocket server.
Then run
'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--plaintext]'
where <DESTINATION> is the multiaddress shown by the server. where <DESTINATION> is the multiaddress shown by the server.
By default, this example uses plaintext security for communication. By default, this example uses Noise encryption for secure communication.
Use --noise for testing with Noise encryption (experimental). Use --plaintext for testing with unencrypted communication
(not recommended for production).
""" """
example_maddr = ( example_maddr = (
@ -287,20 +418,30 @@ def main() -> None:
help=f"destination multiaddr string, e.g. {example_maddr}", help=f"destination multiaddr string, e.g. {example_maddr}",
) )
parser.add_argument( parser.add_argument(
"--noise", "--plaintext",
action="store_true", action="store_true",
help="use Noise encryption instead of plaintext security", help=(
"use plaintext security instead of Noise encryption "
"(not recommended for production)"
),
) )
args = parser.parse_args() args = parser.parse_args()
# Determine security mode: use plaintext by default, Noise if --noise is specified # Determine security mode: use Noise by default,
use_noise = args.noise # plaintext if --plaintext is specified
use_plaintext = args.plaintext
try: try:
trio.run(run, args.port, args.destination, use_noise) trio.run(run, args.port, args.destination, use_plaintext)
except KeyboardInterrupt: except KeyboardInterrupt:
pass # This is expected when Ctrl+C is pressed
# The signal handler already printed the shutdown message
print("✅ Clean exit completed.")
return
except Exception as e:
print(f"❌ Unexpected error: {e}")
return
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -19,6 +19,7 @@ from libp2p.abc import (
IPeerRouting, IPeerRouting,
IPeerStore, IPeerStore,
ISecureTransport, ISecureTransport,
ITransport,
) )
from libp2p.crypto.keys import ( from libp2p.crypto.keys import (
KeyPair, KeyPair,
@ -231,14 +232,15 @@ def new_swarm(
) )
# Create transport based on listen_addrs or default to TCP # Create transport based on listen_addrs or default to TCP
transport: ITransport
if listen_addrs is None: if listen_addrs is None:
transport = TCP() transport = TCP()
else: else:
# Use the first address to determine transport type # Use the first address to determine transport type
addr = listen_addrs[0] addr = listen_addrs[0]
transport = create_transport_for_multiaddr(addr, upgrader) transport_maybe = create_transport_for_multiaddr(addr, upgrader)
if transport is None: if transport_maybe is None:
# Fallback to TCP if no specific transport found # Fallback to TCP if no specific transport found
if addr.__contains__("tcp"): if addr.__contains__("tcp"):
transport = TCP() transport = TCP()
@ -250,20 +252,8 @@ def new_swarm(
f"Unknown transport in listen_addrs: {listen_addrs}. " f"Unknown transport in listen_addrs: {listen_addrs}. "
f"Supported protocols: {supported_protocols}" f"Supported protocols: {supported_protocols}"
) )
else:
# Generate X25519 keypair for Noise transport = transport_maybe
noise_key_pair = create_new_x25519_key_pair()
# Default security transports (using Noise as primary)
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or {
NOISE_PROTOCOL_ID: NoiseTransport(
key_pair, noise_privkey=noise_key_pair.private_key
),
TProtocol(secio.ID): secio.Transport(key_pair),
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
key_pair, peerstore=peerstore_opt
),
}
# Use given muxer preference if provided, otherwise use global default # Use given muxer preference if provided, otherwise use global default
if muxer_preference is not None: if muxer_preference is not None:

View File

@ -1,3 +1,4 @@
import logging
from typing import ( from typing import (
cast, cast,
) )
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
FixedSizeLenMsgReadWriter, FixedSizeLenMsgReadWriter,
) )
logger = logging.getLogger(__name__)
SIZE_NOISE_MESSAGE_LEN = 2 SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1 MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2 SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.noise_state = noise_state self.noise_state = noise_state
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None: async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
data_encrypted = self.encrypt(msg) data_encrypted = self.encrypt(msg)
if prefix_encoded: if prefix_encoded:
# Manually add the prefix if needed # Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted data_encrypted = self.prefix + data_encrypted
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
await self.read_writer.write_msg(data_encrypted) await self.read_writer.write_msg(data_encrypted)
logger.debug("Noise write_msg: write completed successfully")
async def read_msg(self, prefix_encoded: bool = False) -> bytes: async def read_msg(self, prefix_encoded: bool = False) -> bytes:
logger.debug("Noise read_msg: reading encrypted message")
noise_msg_encrypted = await self.read_writer.read_msg() noise_msg_encrypted = await self.read_writer.read_msg()
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
if prefix_encoded: if prefix_encoded:
return self.decrypt(noise_msg_encrypted[len(self.prefix) :]) result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
else: else:
return self.decrypt(noise_msg_encrypted) result = self.decrypt(noise_msg_encrypted)
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
return result
async def close(self) -> None: async def close(self) -> None:
await self.read_writer.close() await self.read_writer.close()

View File

@ -1,6 +1,7 @@
from dataclasses import ( from dataclasses import (
dataclass, dataclass,
) )
import logging
from libp2p.crypto.keys import ( from libp2p.crypto.keys import (
PrivateKey, PrivateKey,
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
from .pb import noise_pb2 as noise_pb from .pb import noise_pb2 as noise_pb
logger = logging.getLogger(__name__)
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:" SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
id_privkey: PrivateKey, noise_static_pubkey: PublicKey id_privkey: PrivateKey, noise_static_pubkey: PublicKey
) -> bytes: ) -> bytes:
data = make_data_to_be_signed(noise_static_pubkey) data = make_data_to_be_signed(noise_static_pubkey)
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
return id_privkey.sign(data) return id_privkey.sign(data)
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
2. signed by the private key corresponding to `id_pubkey` 2. signed by the private key corresponding to `id_pubkey`
""" """
expected_data = make_data_to_be_signed(noise_static_pubkey) expected_data = make_data_to_be_signed(noise_static_pubkey)
return payload.id_pubkey.verify(expected_data, payload.id_sig) logger.debug(
f"verify_handshake_payload_sig: payload.id_pubkey type: "
f"{type(payload.id_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: noise_static_pubkey type: "
f"{type(noise_static_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
)
logger.debug(
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
)
try:
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
return result
except Exception as e:
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
return False

View File

@ -2,6 +2,7 @@ from abc import (
ABC, ABC,
abstractmethod, abstractmethod,
) )
import logging
from cryptography.hazmat.primitives import ( from cryptography.hazmat.primitives import (
serialization, serialization,
@ -46,6 +47,8 @@ from .messages import (
verify_handshake_payload_sig, verify_handshake_payload_sig,
) )
logger = logging.getLogger(__name__)
class IPattern(ABC): class IPattern(ABC):
@abstractmethod @abstractmethod
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
self.early_data = early_data self.early_data = early_data
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn: async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
noise_state = self.create_noise_state() noise_state = self.create_noise_state()
noise_state.set_as_responder() noise_state.set_as_responder()
noise_state.start_handshake() noise_state.start_handshake()
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state) read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1. # Consume msg#1.
logger.debug("Noise XX handshake_inbound: reading msg#1")
await read_writer.read_msg() await read_writer.read_msg()
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
# Send msg#2, which should include our handshake payload. # Send msg#2, which should include our handshake payload.
logger.debug("Noise XX handshake_inbound: preparing msg#2")
our_payload = self.make_handshake_payload() our_payload = self.make_handshake_payload()
msg_2 = our_payload.serialize() msg_2 = our_payload.serialize()
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
await read_writer.write_msg(msg_2) await read_writer.write_msg(msg_2)
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
# Receive and consume msg#3. # Receive and consume msg#3.
logger.debug("Noise XX handshake_inbound: reading msg#3")
msg_3 = await read_writer.read_msg() msg_3 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3) peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
if handshake_state.rs is None: if handshake_state.rs is None:
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
async def handshake_outbound( async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn: ) -> ISecureConn:
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
noise_state = self.create_noise_state() noise_state = self.create_noise_state()
read_writer = NoiseHandshakeReadWriter(conn, noise_state) read_writer = NoiseHandshakeReadWriter(conn, noise_state)
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
raise NoiseStateError("Handshake state is not initialized") raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted. # Send msg#1, which is *not* encrypted.
logger.debug("Noise XX handshake_outbound: sending msg#1")
msg_1 = b"" msg_1 = b""
await read_writer.write_msg(msg_1) await read_writer.write_msg(msg_1)
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
# Read msg#2 from the remote, which contains the public key of the peer. # Read msg#2 from the remote, which contains the public key of the peer.
logger.debug("Noise XX handshake_outbound: reading msg#2")
msg_2 = await read_writer.read_msg() msg_2 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2) peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
if handshake_state.rs is None: if handshake_state.rs is None:
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
) )
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs) remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
logger.debug(
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
)
logger.debug(
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
)
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
logger.debug(
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
f"{id_pubkey_repr}"
)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey): if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
logger.error(
f"Noise XX handshake_outbound: signature verification failed for peer "
f"{remote_peer}"
)
raise InvalidSignature raise InvalidSignature
logger.debug(
f"Noise XX handshake_outbound: signature verification successful for peer "
f"{remote_peer}"
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey) remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer: if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey( raise PeerIDMismatchesPubkey(

View File

@ -7,8 +7,10 @@ from .transport_registry import (
register_transport, register_transport,
get_supported_transport_protocols, get_supported_transport_protocols,
) )
from .upgrader import TransportUpgrader
from libp2p.abc import ITransport
def create_transport(protocol: str, upgrader=None): def create_transport(protocol: str, upgrader: TransportUpgrader | None = None) -> ITransport:
""" """
Convenience function to create a transport instance. Convenience function to create a transport instance.
@ -28,7 +30,10 @@ def create_transport(protocol: str, upgrader=None):
registry = get_transport_registry() registry = get_transport_registry()
transport_class = registry.get_transport(protocol) transport_class = registry.get_transport(protocol)
if transport_class: if transport_class:
return registry.create_transport(protocol, upgrader) transport = registry.create_transport(protocol, upgrader)
if transport is None:
raise ValueError(f"Failed to create transport for protocol: {protocol}")
return transport
else: else:
raise ValueError(f"Unsupported transport protocol: {protocol}") raise ValueError(f"Unsupported transport protocol: {protocol}")

View File

@ -3,13 +3,15 @@ Transport registry for dynamic transport selection based on multiaddr protocols.
""" """
import logging import logging
from typing import Dict, Type, Optional from typing import Any
from multiaddr import Multiaddr from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
from libp2p.abc import ITransport from libp2p.abc import ITransport
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
logger = logging.getLogger("libp2p.transport.registry") logger = logging.getLogger("libp2p.transport.registry")
@ -24,7 +26,7 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
try: try:
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080 # TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
# or /ip6/::1/tcp/8080 # or /ip6/::1/tcp/8080
protocols = maddr.protocols() protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 2 protocols: network (ip4/ip6) + tcp # Must have at least 2 protocols: network (ip4/ip6) + tcp
if len(protocols) < 2: if len(protocols) < 2:
@ -38,7 +40,8 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
if protocols[1].name != "tcp": if protocols[1].name != "tcp":
return False return False
# Should not have any protocols after tcp (unless it's a valid continuation like p2p) # Should not have any protocols after tcp (unless it's a valid
# continuation like p2p)
# For now, we'll be strict and only allow network + tcp # For now, we'll be strict and only allow network + tcp
if len(protocols) > 2: if len(protocols) > 2:
# Check if the additional protocols are valid continuations # Check if the additional protocols are valid continuations
@ -63,7 +66,7 @@ def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
try: try:
# WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws # WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws
# or /ip6/::1/tcp/8080/ws # or /ip6/::1/tcp/8080/ws
protocols = maddr.protocols() protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws # Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws
if len(protocols) < 3: if len(protocols) < 3:
@ -100,8 +103,8 @@ class TransportRegistry:
Registry for mapping multiaddr protocols to transport implementations. Registry for mapping multiaddr protocols to transport implementations.
""" """
def __init__(self): def __init__(self) -> None:
self._transports: Dict[str, Type[ITransport]] = {} self._transports: dict[str, type[ITransport]] = {}
self._register_default_transports() self._register_default_transports()
def _register_default_transports(self) -> None: def _register_default_transports(self) -> None:
@ -112,7 +115,9 @@ class TransportRegistry:
# Register WebSocket transport for /ws protocol # Register WebSocket transport for /ws protocol
self.register_transport("ws", WebsocketTransport) self.register_transport("ws", WebsocketTransport)
def register_transport(self, protocol: str, transport_class: Type[ITransport]) -> None: def register_transport(
self, protocol: str, transport_class: type[ITransport]
) -> None:
""" """
Register a transport class for a specific protocol. Register a transport class for a specific protocol.
@ -120,9 +125,11 @@ class TransportRegistry:
:param transport_class: The transport class to register :param transport_class: The transport class to register
""" """
self._transports[protocol] = transport_class self._transports[protocol] = transport_class
logger.debug(f"Registered transport {transport_class.__name__} for protocol {protocol}") logger.debug(
f"Registered transport {transport_class.__name__} for protocol {protocol}"
)
def get_transport(self, protocol: str) -> Optional[Type[ITransport]]: def get_transport(self, protocol: str) -> type[ITransport] | None:
""" """
Get the transport class for a specific protocol. Get the transport class for a specific protocol.
@ -135,7 +142,9 @@ class TransportRegistry:
"""Get list of supported transport protocols.""" """Get list of supported transport protocols."""
return list(self._transports.keys()) return list(self._transports.keys())
def create_transport(self, protocol: str, upgrader: Optional[TransportUpgrader] = None, **kwargs) -> Optional[ITransport]: def create_transport(
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
) -> ITransport | None:
""" """
Create a transport instance for a specific protocol. Create a transport instance for a specific protocol.
@ -152,9 +161,12 @@ class TransportRegistry:
if protocol == "ws": if protocol == "ws":
# WebSocket transport requires upgrader # WebSocket transport requires upgrader
if upgrader is None: if upgrader is None:
logger.warning(f"WebSocket transport '{protocol}' requires upgrader") logger.warning(
f"WebSocket transport '{protocol}' requires upgrader"
)
return None return None
return transport_class(upgrader) # Use explicit WebsocketTransport to avoid type issues
return WebsocketTransport(upgrader)
else: else:
# TCP transport doesn't require upgrader # TCP transport doesn't require upgrader
return transport_class() return transport_class()
@ -172,12 +184,14 @@ def get_transport_registry() -> TransportRegistry:
return _global_registry return _global_registry
def register_transport(protocol: str, transport_class: Type[ITransport]) -> None: def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
"""Register a transport class in the global registry.""" """Register a transport class in the global registry."""
_global_registry.register_transport(protocol, transport_class) _global_registry.register_transport(protocol, transport_class)
def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader) -> Optional[ITransport]: def create_transport_for_multiaddr(
maddr: Multiaddr, upgrader: TransportUpgrader
) -> ITransport | None:
""" """
Create the appropriate transport for a given multiaddr. Create the appropriate transport for a given multiaddr.
@ -203,7 +217,10 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
return _global_registry.create_transport("tcp", upgrader) return _global_registry.create_transport("tcp", upgrader)
# If no supported transport protocol found or structure is invalid, return None # If no supported transport protocol found or structure is invalid, return None
logger.warning(f"No supported transport protocol found or invalid structure in multiaddr: {maddr}") logger.warning(
f"No supported transport protocol found or invalid structure in "
f"multiaddr: {maddr}"
)
return None return None
except Exception as e: except Exception as e:

View File

@ -1,9 +1,13 @@
from trio.abc import Stream import logging
from typing import Any
import trio import trio
from libp2p.io.abc import ReadWriteCloser from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException from libp2p.io.exceptions import IOException
logger = logging.getLogger(__name__)
class P2PWebSocketConnection(ReadWriteCloser): class P2PWebSocketConnection(ReadWriteCloser):
""" """
@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser):
that libp2p protocols expect. that libp2p protocols expect.
""" """
def __init__(self, ws_connection, ws_context=None): def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
self._ws_connection = ws_connection self._ws_connection = ws_connection
self._ws_context = ws_context self._ws_context = ws_context
self._read_buffer = b"" self._read_buffer = b""
@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser):
async def write(self, data: bytes) -> None: async def write(self, data: bytes) -> None:
try: try:
logger.debug(f"WebSocket writing {len(data)} bytes")
# Send as a binary WebSocket message # Send as a binary WebSocket message
await self._ws_connection.send_message(data) await self._ws_connection.send_message(data)
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
except Exception as e: except Exception as e:
logger.error(f"WebSocket write failed: {e}")
raise IOException from e raise IOException from e
async def read(self, n: int | None = None) -> bytes: async def read(self, n: int | None = None) -> bytes:
""" """
Read up to n bytes (if n is given), else read up to 64KiB. Read up to n bytes (if n is given), else read up to 64KiB.
This implementation provides byte-level access to WebSocket messages,
which is required for Noise protocol handshake.
""" """
async with self._read_lock: async with self._read_lock:
try: try:
logger.debug(
f"WebSocket read requested: n={n}, "
f"buffer_size={len(self._read_buffer)}"
)
# If we have buffered data, return it # If we have buffered data, return it
if self._read_buffer: if self._read_buffer:
if n is None: if n is None:
result = self._read_buffer result = self._read_buffer
self._read_buffer = b"" self._read_buffer = b""
logger.debug(
f"WebSocket read returning all buffered data: "
f"{len(result)} bytes"
)
return result return result
else: else:
if len(self._read_buffer) >= n: if len(self._read_buffer) >= n:
result = self._read_buffer[:n] result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:] self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning {len(result)} bytes "
f"from buffer"
)
return result return result
else: else:
result = self._read_buffer # We need more data, but we have some buffered
self._read_buffer = b"" # Keep the buffered data and get more
return result logger.debug(
f"WebSocket read needs more data: have "
f"{len(self._read_buffer)}, need {n}"
)
pass
# Get the next WebSocket message # If we need exactly n bytes but don't have enough, get more data
while n is not None and (
not self._read_buffer or len(self._read_buffer) < n
):
logger.debug(
f"WebSocket read getting more data: "
f"buffer_size={len(self._read_buffer)}, need={n}"
)
# Get the next WebSocket message and treat it as a byte stream
# This mimics the Go implementation's NextReader() approach
message = await self._ws_connection.get_message() message = await self._ws_connection.get_message()
if isinstance(message, str): if isinstance(message, str):
message = message.encode('utf-8') message = message.encode("utf-8")
logger.debug(
f"WebSocket read received message: {len(message)} bytes"
)
# Add to buffer # Add to buffer
self._read_buffer = message self._read_buffer += message
# Return requested amount # Return requested amount
if n is None: if n is None:
result = self._read_buffer result = self._read_buffer
self._read_buffer = b"" self._read_buffer = b""
logger.debug(
f"WebSocket read returning all data: {len(result)} bytes"
)
return result return result
else: else:
if len(self._read_buffer) >= n: if len(self._read_buffer) >= n:
result = self._read_buffer[:n] result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:] self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning exact {len(result)} bytes"
)
return result return result
else: else:
# This should never happen due to the while loop above
result = self._read_buffer result = self._read_buffer
self._read_buffer = b"" self._read_buffer = b""
logger.debug(
f"WebSocket read returning remaining {len(result)} bytes"
)
return result return result
except Exception as e: except Exception as e:
logger.error(f"WebSocket read failed: {e}")
raise IOException from e raise IOException from e
async def close(self) -> None: async def close(self) -> None:
@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser):
# Try to get remote address from the WebSocket connection # Try to get remote address from the WebSocket connection
try: try:
remote = self._ws_connection.remote remote = self._ws_connection.remote
if hasattr(remote, 'address') and hasattr(remote, 'port'): if hasattr(remote, "address") and hasattr(remote, "port"):
return str(remote.address), int(remote.port) return str(remote.address), int(remote.port)
elif isinstance(remote, str): elif isinstance(remote, str):
# Parse address:port format # Parse address:port format
if ':' in remote: if ":" in remote:
host, port = remote.rsplit(':', 1) host, port = remote.rsplit(":", 1)
return host, int(port) return host, int(port)
except Exception: except Exception:
pass pass

View File

@ -1,6 +1,6 @@
from collections.abc import Awaitable, Callable
import logging import logging
import socket from typing import Any
from typing import Any, Callable
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio import trio
@ -9,7 +9,6 @@ from trio_websocket import serve_websocket
from libp2p.abc import IListener from libp2p.abc import IListener
from libp2p.custom_types import THandler from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection from .connection import P2PWebSocketConnection
@ -27,7 +26,8 @@ class WebsocketListener(IListener):
self._upgrader = upgrader self._upgrader = upgrader
self._server = None self._server = None
self._shutdown_event = trio.Event() self._shutdown_event = trio.Event()
self._nursery = None self._nursery: trio.Nursery | None = None
self._listeners: Any = None
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}") logger.debug(f"WebsocketListener.listen called with {maddr}")
@ -51,15 +51,15 @@ class WebsocketListener(IListener):
logger.debug(f"WebsocketListener: host={host}, port={port}") logger.debug(f"WebsocketListener: host={host}, port={port}")
async def serve_websocket_tcp( async def serve_websocket_tcp(
handler: Callable, handler: Callable[[Any], Awaitable[None]],
port: int, port: int,
host: str, host: str,
task_status: trio.TaskStatus[list], task_status: TaskStatus[Any],
) -> None: ) -> None:
"""Start TCP server and handle WebSocket connections manually""" """Start TCP server and handle WebSocket connections manually"""
logger.debug("serve_websocket_tcp %s %s", host, port) logger.debug("serve_websocket_tcp %s %s", host, port)
async def websocket_handler(request): async def websocket_handler(request: Any) -> None:
"""Handle WebSocket requests""" """Handle WebSocket requests"""
logger.debug("WebSocket request received") logger.debug("WebSocket request received")
try: try:
@ -68,7 +68,7 @@ class WebsocketListener(IListener):
logger.debug("WebSocket handshake successful") logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper # Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(ws_connection) conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call]
# Call the handler function that was passed to create_listener # Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades # This handler will handle the security and muxing upgrades
@ -77,22 +77,26 @@ class WebsocketListener(IListener):
# Don't keep the connection alive indefinitely # Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle # Let the handler manage the connection lifecycle
logger.debug("Handler completed, connection will be managed by handler") logger.debug(
"Handler completed, connection will be managed by handler"
)
except Exception as e: except Exception as e:
logger.debug(f"WebSocket connection error: {e}") logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}") logger.debug(f"Error type: {type(e)}")
import traceback import traceback
logger.debug(f"Traceback: {traceback.format_exc()}") logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection # Reject the connection
try: try:
await request.reject(400) await request.reject(400)
except: except Exception:
pass pass
# Use trio_websocket.serve_websocket for proper WebSocket handling # Use trio_websocket.serve_websocket for proper WebSocket handling
from trio_websocket import serve_websocket await serve_websocket(
await serve_websocket(websocket_handler, host, port, None, task_status=task_status) websocket_handler, host, port, None, task_status=task_status
)
# Store the nursery for shutdown # Store the nursery for shutdown
self._nursery = nursery self._nursery = nursery
@ -111,18 +115,21 @@ class WebsocketListener(IListener):
logger.error(f"Failed to start WebSocket listener for {maddr}") logger.error(f"Failed to start WebSocket listener for {maddr}")
return False return False
# Store the listeners for get_addrs() and close() - these are real SocketListener objects # Store the listeners for get_addrs() and close() - these are real
# SocketListener objects
self._listeners = started_listeners self._listeners = started_listeners
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object") logger.debug(
"WebsocketListener.listen returning True with WebSocketServer object"
)
return True return True
def get_addrs(self) -> tuple[Multiaddr, ...]: def get_addrs(self) -> tuple[Multiaddr, ...]:
if not hasattr(self, '_listeners') or not self._listeners: if not hasattr(self, "_listeners") or not self._listeners:
logger.debug("No listeners available for get_addrs()") logger.debug("No listeners available for get_addrs()")
return () return ()
# Handle WebSocketServer objects # Handle WebSocketServer objects
if hasattr(self._listeners, 'port'): if hasattr(self._listeners, "port"):
# This is a WebSocketServer object # This is a WebSocketServer object
port = self._listeners.port port = self._listeners.port
# Create a multiaddr from the port # Create a multiaddr from the port
@ -138,12 +145,12 @@ class WebsocketListener(IListener):
async def close(self) -> None: async def close(self) -> None:
"""Close the WebSocket listener and stop accepting new connections""" """Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called") logger.debug("WebsocketListener.close called")
if hasattr(self, '_listeners') and self._listeners: if hasattr(self, "_listeners") and self._listeners:
# Signal shutdown # Signal shutdown
self._shutdown_event.set() self._shutdown_event.set()
# Close the WebSocket server # Close the WebSocket server
if hasattr(self._listeners, 'aclose'): if hasattr(self._listeners, "aclose"):
# This is a WebSocketServer object # This is a WebSocketServer object
logger.debug("Closing WebSocket server") logger.debug("Closing WebSocket server")
await self._listeners.aclose() await self._listeners.aclose()
@ -152,12 +159,12 @@ class WebsocketListener(IListener):
# This is a list of listeners (like TCP) # This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners") logger.debug("Closing TCP listeners")
for listener in self._listeners: for listener in self._listeners:
listener.close() await listener.aclose()
logger.debug("TCP listeners closed") logger.debug("TCP listeners closed")
else: else:
# Unknown type, try to close it directly # Unknown type, try to close it directly
logger.debug("Closing unknown listener type") logger.debug("Closing unknown listener type")
if hasattr(self._listeners, 'close'): if hasattr(self._listeners, "close"):
self._listeners.close() self._listeners.close()
logger.debug("Unknown listener closed") logger.debug("Unknown listener closed")

View File

@ -1,6 +1,6 @@
import logging import logging
from multiaddr import Multiaddr from multiaddr import Multiaddr
from trio_websocket import open_websocket_url
from libp2p.abc import IListener, ITransport from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler from libp2p.custom_types import THandler
@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection from .connection import P2PWebSocketConnection
from .listener import WebsocketListener from .listener import WebsocketListener
logger = logging.getLogger("libp2p.transport.websocket") logger = logging.getLogger(__name__)
class WebsocketTransport(ITransport): class WebsocketTransport(ITransport):
@ -45,6 +45,7 @@ class WebsocketTransport(ITransport):
try: try:
from trio_websocket import open_websocket_url from trio_websocket import open_websocket_url
# Use the context manager but don't exit it immediately # Use the context manager but don't exit it immediately
# The connection will be closed when the RawConnection is closed # The connection will be closed when the RawConnection is closed
ws_context = open_websocket_url(ws_url) ws_context = open_websocket_url(ws_url)

View File

@ -2,20 +2,20 @@
Tests for the transport registry functionality. Tests for the transport registry functionality.
""" """
import pytest
from multiaddr import Multiaddr from multiaddr import Multiaddr
from libp2p.abc import ITransport from libp2p.abc import IListener, IRawConnection, ITransport
from libp2p.custom_types import THandler
from libp2p.transport.tcp.tcp import TCP from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.transport_registry import ( from libp2p.transport.transport_registry import (
TransportRegistry, TransportRegistry,
create_transport_for_multiaddr, create_transport_for_multiaddr,
get_supported_transport_protocols,
get_transport_registry, get_transport_registry,
register_transport, register_transport,
get_supported_transport_protocols,
) )
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
class TestTransportRegistry: class TestTransportRegistry:
@ -36,8 +36,14 @@ class TestTransportRegistry:
registry = TransportRegistry() registry = TransportRegistry()
# Register a custom transport # Register a custom transport
class CustomTransport: class CustomTransport(ITransport):
pass async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("CustomTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"CustomTransport create_listener not implemented"
)
registry.register_transport("custom", CustomTransport) registry.register_transport("custom", CustomTransport)
assert registry.get_transport("custom") == CustomTransport assert registry.get_transport("custom") == CustomTransport
@ -105,8 +111,15 @@ class TestGlobalRegistry:
def test_register_transport_global(self): def test_register_transport_global(self):
"""Test registering transport globally.""" """Test registering transport globally."""
class GlobalCustomTransport:
pass class GlobalCustomTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("GlobalCustomTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"GlobalCustomTransport create_listener not implemented"
)
# Register globally # Register globally
register_transport("global_custom", GlobalCustomTransport) register_transport("global_custom", GlobalCustomTransport)
@ -191,17 +204,18 @@ class TestTransportFactory:
assert transport is None assert transport is None
def test_create_transport_for_multiaddr_no_upgrader(self): def test_create_transport_for_multiaddr_with_upgrader(self):
"""Test creating transport without upgrader.""" """Test creating transport with upgrader."""
# This should work for TCP but not WebSocket upgrader = TransportUpgrader({}, {})
# This should work for both TCP and WebSocket with upgrader
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080") maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport_tcp = create_transport_for_multiaddr(maddr_tcp, None) transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader)
assert transport_tcp is not None assert transport_tcp is not None
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport_ws = create_transport_for_multiaddr(maddr_ws, None) transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader)
# WebSocket transport creation should fail gracefully assert transport_ws is not None
assert transport_ws is None
class TestTransportInterfaceCompliance: class TestTransportInterfaceCompliance:
@ -211,8 +225,8 @@ class TestTransportInterfaceCompliance:
"""Test that TCP transport implements ITransport.""" """Test that TCP transport implements ITransport."""
transport = TCP() transport = TCP()
assert isinstance(transport, ITransport) assert isinstance(transport, ITransport)
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
assert callable(transport.dial) assert callable(transport.dial)
assert callable(transport.create_listener) assert callable(transport.create_listener)
@ -221,8 +235,8 @@ class TestTransportInterfaceCompliance:
upgrader = TransportUpgrader({}, {}) upgrader = TransportUpgrader({}, {})
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
assert isinstance(transport, ITransport) assert isinstance(transport, ITransport)
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
assert callable(transport.dial) assert callable(transport.dial)
assert callable(transport.create_listener) assert callable(transport.create_listener)
@ -236,10 +250,18 @@ class TestErrorHandling:
upgrader = TransportUpgrader({}, {}) upgrader = TransportUpgrader({}, {})
# Register a transport that raises an exception # Register a transport that raises an exception
class ExceptionTransport: class ExceptionTransport(ITransport):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise RuntimeError("Transport creation failed") raise RuntimeError("Transport creation failed")
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("ExceptionTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"ExceptionTransport create_listener not implemented"
)
registry.register_transport("exception", ExceptionTransport) registry.register_transport("exception", ExceptionTransport)
# Should handle exception gracefully and return None # Should handle exception gracefully and return None
@ -252,7 +274,8 @@ class TestErrorHandling:
# Test with a multiaddr that has an unsupported transport protocol # Test with a multiaddr that has an unsupported transport protocol
# This should be handled gracefully by our transport registry # This should be handled gracefully by our transport registry
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") # udp is not a supported transport # udp is not a supported transport
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234")
transport = create_transport_for_multiaddr(maddr, upgrader) transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None assert transport is None
@ -286,8 +309,14 @@ class TestIntegration:
assert registry1 is registry2 assert registry1 is registry2
# Register a transport in one # Register a transport in one
class PersistentTransport: class PersistentTransport(ITransport):
pass async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("PersistentTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"PersistentTransport create_listener not implemented"
)
registry1.register_transport("persistent", PersistentTransport) registry1.register_transport("persistent", PersistentTransport)

View File

@ -1,23 +1,23 @@
from collections.abc import Sequence from collections.abc import Sequence
import logging
from typing import Any from typing import Any
import pytest import pytest
import trio
from multiaddr import Multiaddr from multiaddr import Multiaddr
import trio
from libp2p.crypto.secp256k1 import create_new_key_pair from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport from libp2p.security.insecure.transport import InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.websocket.listener import WebsocketListener
from libp2p.transport.exceptions import OpenConnectionError logger = logging.getLogger(__name__)
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
@ -64,9 +64,6 @@ def create_upgrader():
) )
# 2. Listener Basic Functionality Tests # 2. Listener Basic Functionality Tests
@pytest.mark.trio @pytest.mark.trio
async def test_listener_basic_listen(): async def test_listener_basic_listen():
@ -76,12 +73,16 @@ async def test_listener_basic_listen():
# Test listening on IPv4 # Test listening on IPv4
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that listener can be created and has required methods # Test that listener can be created and has required methods
assert hasattr(listener, 'listen') assert hasattr(listener, "listen")
assert hasattr(listener, 'close') assert hasattr(listener, "close")
assert hasattr(listener, 'get_addrs') assert hasattr(listener, "get_addrs")
# Test that listener can handle the address # Test that listener can handle the address
assert ma.value_for_protocol("ip4") == "127.0.0.1" assert ma.value_for_protocol("ip4") == "127.0.0.1"
@ -98,7 +99,11 @@ async def test_listener_port_0_handling():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that the address can be parsed correctly # Test that the address can be parsed correctly
port_str = ma.value_for_protocol("tcp") port_str = ma.value_for_protocol("tcp")
@ -115,7 +120,11 @@ async def test_listener_any_interface():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws") ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that the address can be parsed correctly # Test that the address can be parsed correctly
host = ma.value_for_protocol("ip4") host = ma.value_for_protocol("ip4")
@ -134,7 +143,11 @@ async def test_listener_address_preservation():
# Create address with p2p ID # Create address with p2p ID
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF" p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}") ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that p2p ID is preserved in the address # Test that p2p ID is preserved in the address
addr_str = str(ma) addr_str = str(ma)
@ -161,7 +174,7 @@ async def test_dial_basic():
assert port == "8080" assert port == "8080"
# Test that transport has the required methods # Test that transport has the required methods
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert callable(transport.dial) assert callable(transport.dial)
@ -179,7 +192,7 @@ async def test_dial_with_p2p_id():
assert p2p_id in addr_str assert p2p_id in addr_str
# Test that transport can handle addresses with p2p IDs # Test that transport can handle addresses with p2p IDs
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert callable(transport.dial) assert callable(transport.dial)
@ -197,15 +210,14 @@ async def test_dial_port_0_resolution():
assert port_str == "0" assert port_str == "0"
# Test that transport has the required methods # Test that transport has the required methods
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert callable(transport.dial) assert callable(transport.dial)
# 4. Address Validation Tests (CRITICAL) # 4. Address Validation Tests (CRITICAL)
def test_address_validation_ipv4(): def test_address_validation_ipv4():
"""Test IPv4 address validation""" """Test IPv4 address validation"""
upgrader = create_upgrader() # upgrader = create_upgrader() # Not used in this test
transport = WebsocketTransport(upgrader)
# Valid IPv4 WebSocket addresses # Valid IPv4 WebSocket addresses
valid_addresses = [ valid_addresses = [
@ -222,7 +234,9 @@ def test_address_validation_ipv4():
assert "/ws" in transport_addr assert "/ws" in transport_addr
# Test that transport can handle addresses with p2p IDs # Test that transport can handle addresses with p2p IDs
p2p_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw") p2p_addr = Multiaddr(
"/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw"
)
# Should not raise exception when creating transport address # Should not raise exception when creating transport address
transport_addr = str(p2p_addr) transport_addr = str(p2p_addr)
assert "/ws" in transport_addr assert "/ws" in transport_addr
@ -230,8 +244,7 @@ def test_address_validation_ipv4():
def test_address_validation_ipv6(): def test_address_validation_ipv6():
"""Test IPv6 address validation""" """Test IPv6 address validation"""
upgrader = create_upgrader() # upgrader = create_upgrader() # Not used in this test
transport = WebsocketTransport(upgrader)
# Valid IPv6 WebSocket addresses # Valid IPv6 WebSocket addresses
valid_addresses = [ valid_addresses = [
@ -248,8 +261,7 @@ def test_address_validation_ipv6():
def test_address_validation_dns(): def test_address_validation_dns():
"""Test DNS address validation""" """Test DNS address validation"""
upgrader = create_upgrader() # upgrader = create_upgrader() # Not used in this test
transport = WebsocketTransport(upgrader)
# Valid DNS WebSocket addresses # Valid DNS WebSocket addresses
valid_addresses = [ valid_addresses = [
@ -267,8 +279,7 @@ def test_address_validation_dns():
def test_address_validation_mixed(): def test_address_validation_mixed():
"""Test mixed address validation""" """Test mixed address validation"""
upgrader = create_upgrader() # upgrader = create_upgrader() # Not used in this test
transport = WebsocketTransport(upgrader)
# Mixed valid and invalid addresses # Mixed valid and invalid addresses
addresses = [ addresses = [
@ -310,15 +321,14 @@ async def test_dial_invalid_address():
] ]
for ma in invalid_addresses: for ma in invalid_addresses:
with pytest.raises((ValueError, OpenConnectionError, Exception)): with pytest.raises(Exception):
await transport.dial(ma) await transport.dial(ma)
@pytest.mark.trio @pytest.mark.trio
async def test_listen_invalid_address(): async def test_listen_invalid_address():
"""Test listening on invalid addresses""" """Test listening on invalid addresses"""
upgrader = create_upgrader() # upgrader = create_upgrader() # Not used in this test
transport = WebsocketTransport(upgrader)
# Test listening on non-WebSocket addresses # Test listening on non-WebSocket addresses
invalid_addresses = [ invalid_addresses = [
@ -352,7 +362,7 @@ async def test_listen_port_in_use():
assert ma2.value_for_protocol("tcp") == "8080" assert ma2.value_for_protocol("tcp") == "8080"
# Test that transport can handle these addresses # Test that transport can handle these addresses
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
assert callable(transport.create_listener) assert callable(transport.create_listener)
@ -364,12 +374,15 @@ async def test_connection_close():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
# Test that transport has required methods # Test that transport has required methods
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert callable(transport.dial) assert callable(transport.dial)
# Test that listener can be created and closed # Test that listener can be created and closed
listener = transport.create_listener(lambda conn: None) async def dummy_handler(conn):
assert hasattr(listener, 'close') await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert hasattr(listener, "close")
assert callable(listener.close) assert callable(listener.close)
# Test that listener can be closed # Test that listener can be closed
@ -397,16 +410,10 @@ async def test_multiple_connections():
assert port in ["8080", "8081", "8082"] assert port in ["8080", "8081", "8082"]
# Test that transport has required methods # Test that transport has required methods
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert callable(transport.dial) assert callable(transport.dial)
# Original test (kept for compatibility) # Original test (kept for compatibility)
@pytest.mark.trio @pytest.mark.trio
async def test_websocket_dial_and_listen(): async def test_websocket_dial_and_listen():
@ -416,11 +423,14 @@ async def test_websocket_dial_and_listen():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
# Test that transport can create listeners # Test that transport can create listeners
listener = transport.create_listener(lambda conn: None) async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert listener is not None assert listener is not None
assert hasattr(listener, 'listen') assert hasattr(listener, "listen")
assert hasattr(listener, 'close') assert hasattr(listener, "close")
assert hasattr(listener, 'get_addrs') assert hasattr(listener, "get_addrs")
# Test that transport can handle WebSocket addresses # Test that transport can handle WebSocket addresses
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
@ -429,7 +439,7 @@ async def test_websocket_dial_and_listen():
assert "ws" in str(ma) assert "ws" in str(ma)
# Test that transport has dial method # Test that transport has dial method
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert callable(transport.dial) assert callable(transport.dial)
# Test that transport can handle WebSocket multiaddrs # Test that transport can handle WebSocket multiaddrs
@ -442,14 +452,9 @@ async def test_websocket_dial_and_listen():
await listener.close() await listener.close()
import logging
logger = logging.getLogger(__name__)
@pytest.mark.trio @pytest.mark.trio
async def test_websocket_transport_basic(): async def test_websocket_transport_basic():
"""Test basic WebSocket transport functionality without full libp2p stack""" """Test basic WebSocket transport functionality without full libp2p stack"""
# Create WebSocket transport # Create WebSocket transport
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
upgrader = TransportUpgrader( upgrader = TransportUpgrader(
@ -461,14 +466,17 @@ async def test_websocket_transport_basic():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
assert transport is not None assert transport is not None
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
listener = transport.create_listener(lambda conn: None) async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert listener is not None assert listener is not None
assert hasattr(listener, 'listen') assert hasattr(listener, "listen")
assert hasattr(listener, 'close') assert hasattr(listener, "close")
assert hasattr(listener, 'get_addrs') assert hasattr(listener, "get_addrs")
valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert valid_addr.value_for_protocol("ip4") == "127.0.0.1" assert valid_addr.value_for_protocol("ip4") == "127.0.0.1"
@ -480,8 +488,7 @@ async def test_websocket_transport_basic():
@pytest.mark.trio @pytest.mark.trio
async def test_websocket_simple_connection(): async def test_websocket_simple_connection():
"""Test WebSocket transport creation and basic functionality without real connections""" """Test WebSocket transport creation and basic functionality without real conn"""
# Create WebSocket transport # Create WebSocket transport
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
upgrader = TransportUpgrader( upgrader = TransportUpgrader(
@ -493,17 +500,17 @@ async def test_websocket_simple_connection():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
assert transport is not None assert transport is not None
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
async def simple_handler(conn): async def simple_handler(conn):
await conn.close() await conn.close()
listener = transport.create_listener(simple_handler) listener = transport.create_listener(simple_handler)
assert listener is not None assert listener is not None
assert hasattr(listener, 'listen') assert hasattr(listener, "listen")
assert hasattr(listener, 'close') assert hasattr(listener, "close")
assert hasattr(listener, 'get_addrs') assert hasattr(listener, "get_addrs")
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws") test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert test_addr.value_for_protocol("ip4") == "127.0.0.1" assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
@ -516,7 +523,6 @@ async def test_websocket_simple_connection():
@pytest.mark.trio @pytest.mark.trio
async def test_websocket_real_connection(): async def test_websocket_real_connection():
"""Test WebSocket transport creation and basic functionality""" """Test WebSocket transport creation and basic functionality"""
# Create WebSocket transport # Create WebSocket transport
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
upgrader = TransportUpgrader( upgrader = TransportUpgrader(
@ -528,17 +534,17 @@ async def test_websocket_real_connection():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
assert transport is not None assert transport is not None
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
async def handler(conn): async def handler(conn):
await conn.close() await conn.close()
listener = transport.create_listener(handler) listener = transport.create_listener(handler)
assert listener is not None assert listener is not None
assert hasattr(listener, 'listen') assert hasattr(listener, "listen")
assert hasattr(listener, 'close') assert hasattr(listener, "close")
assert hasattr(listener, 'get_addrs') assert hasattr(listener, "get_addrs")
await listener.close() await listener.close()
@ -546,7 +552,6 @@ async def test_websocket_real_connection():
@pytest.mark.trio @pytest.mark.trio
async def test_websocket_with_tcp_fallback(): async def test_websocket_with_tcp_fallback():
"""Test WebSocket functionality using TCP transport as fallback""" """Test WebSocket functionality using TCP transport as fallback"""
from tests.utils.factories import host_pair_factory from tests.utils.factories import host_pair_factory
async with host_pair_factory() as (host_a, host_b): async with host_pair_factory() as (host_a, host_b):
@ -578,7 +583,6 @@ async def test_websocket_with_tcp_fallback():
@pytest.mark.trio @pytest.mark.trio
async def test_websocket_transport_interface(): async def test_websocket_transport_interface():
"""Test WebSocket transport interface compliance""" """Test WebSocket transport interface compliance"""
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
upgrader = TransportUpgrader( upgrader = TransportUpgrader(
secure_transports_by_protocol={ secure_transports_by_protocol={
@ -589,15 +593,18 @@ async def test_websocket_transport_interface():
transport = WebsocketTransport(upgrader) transport = WebsocketTransport(upgrader)
assert hasattr(transport, 'dial') assert hasattr(transport, "dial")
assert hasattr(transport, 'create_listener') assert hasattr(transport, "create_listener")
assert callable(transport.dial) assert callable(transport.dial)
assert callable(transport.create_listener) assert callable(transport.create_listener)
listener = transport.create_listener(lambda conn: None) async def dummy_handler(conn):
assert hasattr(listener, 'listen') await trio.sleep(0)
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs') listener = transport.create_listener(dummy_handler)
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
host = test_addr.value_for_protocol("ip4") host = test_addr.value_for_protocol("ip4")

View File

@ -20,7 +20,7 @@ from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport from libp2p.transport.websocket.transport import WebsocketTransport
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0"
@pytest.mark.trio @pytest.mark.trio
@ -74,6 +74,11 @@ async def test_ping_with_js_node():
peer_id = ID.from_base58(peer_id_line) peer_id = ID.from_base58(peer_id_line)
maddr = Multiaddr(addr_line) maddr = Multiaddr(addr_line)
# Debug: Print what we're trying to connect to
print(f"JS Node Peer ID: {peer_id_line}")
print(f"JS Node Address: {addr_line}")
print(f"All JS Node lines: {lines}")
# Set up Python host # Set up Python host
key_pair = create_new_key_pair() key_pair = create_new_key_pair()
py_peer_id = ID.from_pubkey(key_pair.public_key) py_peer_id = ID.from_pubkey(key_pair.public_key)
@ -86,13 +91,15 @@ async def test_ping_with_js_node():
}, },
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux}, muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
) )
transport = WebsocketTransport() transport = WebsocketTransport(upgrader)
swarm = Swarm(py_peer_id, peer_store, upgrader, transport) swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm) host = BasicHost(swarm)
# Connect to JS node # Connect to JS node
peer_info = PeerInfo(peer_id, [maddr]) peer_info = PeerInfo(peer_id, [maddr])
print(f"Python trying to connect to: {peer_info}")
await trio.sleep(1) await trio.sleep(1)
try: try: