Fix typecheck errors and improve WebSocket transport implementation

- Fix INotifee interface compliance in WebSocket demo
- Fix handler function signatures to be async (THandler compatibility)
- Fix is_closed method usage with proper type checking
- Fix pytest.raises multiple exception type issue
- Fix line length violations (E501) across multiple files
- Add debugging logging to Noise security module for troubleshooting
- Update WebSocket transport examples and tests
- Improve transport registry error handling
This commit is contained in:
acul71
2025-08-11 01:25:49 +02:00
parent 64107b4648
commit fe4c17e8d1
16 changed files with 845 additions and 488 deletions

View File

@ -11,13 +11,14 @@ This script demonstrates:
import asyncio
import logging
import sys
from pathlib import Path
import sys
# Add the libp2p directory to the path so we can import it
sys.path.insert(0, str(Path(__file__).parent.parent))
import multiaddr
from libp2p.transport import (
create_transport,
create_transport_for_multiaddr,
@ -25,9 +26,8 @@ from libp2p.transport import (
get_transport_registry,
register_transport,
)
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.upgrader import TransportUpgrader
# Set up logging
logging.basicConfig(level=logging.INFO)
@ -38,20 +38,21 @@ def demo_transport_registry():
"""Demonstrate the transport registry functionality."""
print("🔧 Transport Registry Demo")
print("=" * 50)
# Get the global registry
registry = get_transport_registry()
# Show supported protocols
supported = get_supported_transport_protocols()
print(f"Supported transport protocols: {supported}")
# Show registered transports
print("\nRegistered transports:")
for protocol in supported:
transport_class = registry.get_transport(protocol)
print(f" {protocol}: {transport_class.__name__}")
class_name = transport_class.__name__ if transport_class else "None"
print(f" {protocol}: {class_name}")
print()
@ -59,21 +60,21 @@ def demo_transport_factory():
"""Demonstrate the transport factory functions."""
print("🏭 Transport Factory Demo")
print("=" * 50)
# Create a dummy upgrader for WebSocket transport
upgrader = TransportUpgrader({}, {})
# Create transports using the factory function
try:
tcp_transport = create_transport("tcp")
print(f"✅ Created TCP transport: {type(tcp_transport).__name__}")
ws_transport = create_transport("ws", upgrader)
print(f"✅ Created WebSocket transport: {type(ws_transport).__name__}")
except Exception as e:
print(f"❌ Error creating transport: {e}")
print()
@ -81,10 +82,10 @@ def demo_multiaddr_transport_selection():
"""Demonstrate automatic transport selection based on multiaddrs."""
print("🎯 Multiaddr Transport Selection Demo")
print("=" * 50)
# Create a dummy upgrader
upgrader = TransportUpgrader({}, {})
# Test different multiaddr types
test_addrs = [
"/ip4/127.0.0.1/tcp/8080",
@ -92,20 +93,20 @@ def demo_multiaddr_transport_selection():
"/ip6/::1/tcp/8080/ws",
"/dns4/example.com/tcp/443/ws",
]
for addr_str in test_addrs:
try:
maddr = multiaddr.Multiaddr(addr_str)
transport = create_transport_for_multiaddr(maddr, upgrader)
if transport:
print(f"{addr_str} -> {type(transport).__name__}")
else:
print(f"{addr_str} -> No transport found")
except Exception as e:
print(f"{addr_str} -> Error: {e}")
print()
@ -113,34 +114,37 @@ def demo_custom_transport_registration():
"""Demonstrate how to register custom transports."""
print("🔧 Custom Transport Registration Demo")
print("=" * 50)
# Create a dummy upgrader
upgrader = TransportUpgrader({}, {})
# Show current supported protocols
print(f"Before registration: {get_supported_transport_protocols()}")
# Register a custom transport (using TCP as an example)
class CustomTCPTransport(TCP):
"""Custom TCP transport for demonstration."""
def __init__(self):
super().__init__()
self.custom_flag = True
# Register the custom transport
register_transport("custom_tcp", CustomTCPTransport)
# Show updated supported protocols
print(f"After registration: {get_supported_transport_protocols()}")
# Test creating the custom transport
try:
custom_transport = create_transport("custom_tcp")
print(f"✅ Created custom transport: {type(custom_transport).__name__}")
print(f" Custom flag: {custom_transport.custom_flag}")
# Check if it has the custom flag (type-safe way)
if hasattr(custom_transport, "custom_flag"):
flag_value = getattr(custom_transport, "custom_flag", "Not found")
print(f" Custom flag: {flag_value}")
else:
print(" Custom flag: Not found")
except Exception as e:
print(f"❌ Error creating custom transport: {e}")
print()
@ -148,7 +152,7 @@ def demo_integration_with_libp2p():
"""Demonstrate how the new system integrates with libp2p."""
print("🚀 Libp2p Integration Demo")
print("=" * 50)
print("The new transport system integrates seamlessly with libp2p:")
print()
print("1. ✅ Automatic transport selection based on multiaddr")
@ -157,7 +161,7 @@ def demo_integration_with_libp2p():
print("4. ✅ Easy registration of new transport protocols")
print("5. ✅ No changes needed to existing libp2p code")
print()
print("Example usage in libp2p:")
print(" # This will automatically use WebSocket transport")
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])")
@ -165,7 +169,7 @@ def demo_integration_with_libp2p():
print(" # This will automatically use TCP transport")
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])")
print()
print()
@ -174,14 +178,14 @@ async def main():
print("🎉 Py-libp2p Transport Integration Demo")
print("=" * 60)
print()
# Run all demos
demo_transport_registry()
demo_transport_factory()
demo_multiaddr_transport_selection()
demo_custom_transport_registration()
demo_integration_with_libp2p()
print("🎯 Summary of New Features:")
print("=" * 40)
print("✅ Transport Registry: Central registry for all transport implementations")
@ -202,4 +206,5 @@ if __name__ == "__main__":
except Exception as e:
print(f"\n❌ Demo failed with error: {e}")
import traceback
traceback.print_exc()

View File

@ -5,7 +5,6 @@ Simple TCP echo demo to verify basic libp2p functionality.
import argparse
import logging
import sys
import traceback
import multiaddr
@ -18,10 +17,10 @@ from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.upgrader import TransportUpgrader
# Enable debug logging
logging.basicConfig(level=logging.DEBUG)
@ -31,12 +30,13 @@ logger = logging.getLogger("libp2p.tcp-example")
# Simple echo protocol
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
async def echo_handler(stream):
"""Simple echo handler that echoes back any data received."""
try:
data = await stream.read(1024)
if data:
message = data.decode('utf-8', errors='replace')
message = data.decode("utf-8", errors="replace")
print(f"📥 Received: {message}")
print(f"📤 Echoing back: {message}")
await stream.write(data)
@ -45,6 +45,7 @@ async def echo_handler(stream):
logger.error(f"Echo handler error: {e}")
await stream.close()
def create_tcp_host():
"""Create a host with TCP transport."""
# Create key pair and peer store
@ -60,31 +61,35 @@ def create_tcp_host():
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
# Create TCP transport
transport = TCP()
# Create swarm and host
swarm = Swarm(peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm)
return host
async def run(port: int, destination: str) -> None:
localhost_ip = "0.0.0.0"
if not destination:
# Create first host (listener) with TCP transport
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
try:
host = create_tcp_host()
logger.debug("Created TCP host")
# Set up echo handler
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
@ -95,15 +100,15 @@ async def run(port: int, destination: str) -> None:
if not addrs:
print("❌ Error: No addresses found for the host")
return
server_addr = str(addrs[0])
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
print("🌐 TCP Server Started Successfully!")
print("=" * 50)
print(f"📍 Server Address: {client_addr}")
print(f"🔧 Protocol: /echo/1.0.0")
print(f"🚀 Transport: TCP")
print("🔧 Protocol: /echo/1.0.0")
print("🚀 Transport: TCP")
print()
print("📋 To test the connection, run this in another terminal:")
print(f" python test_tcp_echo.py -d {client_addr}")
@ -112,7 +117,7 @@ async def run(port: int, destination: str) -> None:
print("" * 50)
await trio.sleep_forever()
except Exception as e:
print(f"❌ Error creating TCP server: {e}")
traceback.print_exc()
@ -121,13 +126,16 @@ async def run(port: int, destination: str) -> None:
else:
# Create second host (dialer) with TCP transport
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
try:
# Create a single host for client operations
host = create_tcp_host()
# Start the host for client operations
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
maddr = multiaddr.Multiaddr(destination)
@ -144,7 +152,7 @@ async def run(port: int, destination: str) -> None:
print("✅ Successfully connected to TCP server!")
except Exception as e:
error_msg = str(e)
print(f"\n❌ Connection Failed!")
print("\n❌ Connection Failed!")
print(f" Peer ID: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
@ -185,24 +193,28 @@ async def run(port: int, destination: str) -> None:
traceback.print_exc()
print("✅ TCP demo completed successfully!")
except Exception as e:
print(f"❌ Error creating TCP client: {e}")
traceback.print_exc()
return
def main() -> None:
description = "Simple TCP echo demo for libp2p"
parser = argparse.ArgumentParser(description=description)
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
parser.add_argument("-d", "--destination", type=str, help="destination multiaddr string")
parser.add_argument(
"-d", "--destination", type=str, help="destination multiaddr string"
)
args = parser.parse_args()
try:
trio.run(run, args.port, args.destination)
except KeyboardInterrupt:
pass
if __name__ == "__main__":
main()

View File

@ -5,16 +5,16 @@ Simple test script to verify WebSocket transport functionality.
import asyncio
import logging
import sys
from pathlib import Path
import sys
# Add the libp2p directory to the path so we can import it
sys.path.insert(0, str(Path(__file__).parent))
import multiaddr
from libp2p.transport import create_transport, create_transport_for_multiaddr
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.network.connection.raw_connection import RawConnection
# Set up logging
logging.basicConfig(level=logging.DEBUG)
@ -25,48 +25,57 @@ async def test_websocket_transport():
"""Test basic WebSocket transport functionality."""
print("🧪 Testing WebSocket Transport Functionality")
print("=" * 50)
# Create a dummy upgrader
upgrader = TransportUpgrader({}, {})
# Test creating WebSocket transport
try:
ws_transport = create_transport("ws", upgrader)
print(f"✅ WebSocket transport created: {type(ws_transport).__name__}")
# Test creating transport from multiaddr
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
print(f"✅ WebSocket transport from multiaddr: {type(ws_transport_from_maddr).__name__}")
print(
f"✅ WebSocket transport from multiaddr: "
f"{type(ws_transport_from_maddr).__name__}"
)
# Test creating listener
handler_called = False
async def test_handler(conn):
nonlocal handler_called
handler_called = True
print(f"✅ Connection handler called with: {type(conn).__name__}")
await conn.close()
listener = ws_transport.create_listener(test_handler)
print(f"✅ WebSocket listener created: {type(listener).__name__}")
# Test that the transport can be used
print(f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}")
print(f"✅ WebSocket transport supports listening: {hasattr(ws_transport, 'create_listener')}")
print(
f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}"
)
print(
f"✅ WebSocket transport supports listening: "
f"{hasattr(ws_transport, 'create_listener')}"
)
print("\n🎯 WebSocket Transport Test Results:")
print("✅ Transport creation: PASS")
print("✅ Multiaddr parsing: PASS")
print("✅ Listener creation: PASS")
print("✅ Interface compliance: PASS")
except Exception as e:
print(f"❌ WebSocket transport test failed: {e}")
import traceback
traceback.print_exc()
return False
return True
@ -74,22 +83,26 @@ async def test_transport_registry():
"""Test the transport registry functionality."""
print("\n🔧 Testing Transport Registry")
print("=" * 30)
from libp2p.transport import get_transport_registry, get_supported_transport_protocols
from libp2p.transport import (
get_supported_transport_protocols,
get_transport_registry,
)
registry = get_transport_registry()
supported = get_supported_transport_protocols()
print(f"Supported protocols: {supported}")
# Test getting transports
for protocol in supported:
transport_class = registry.get_transport(protocol)
print(f" {protocol}: {transport_class.__name__}")
class_name = transport_class.__name__ if transport_class else "None"
print(f" {protocol}: {class_name}")
# Test creating transports through registry
upgrader = TransportUpgrader({}, {})
for protocol in supported:
try:
transport = registry.create_transport(protocol, upgrader)
@ -106,17 +119,17 @@ async def main():
print("🚀 WebSocket Transport Integration Test Suite")
print("=" * 60)
print()
# Run tests
success = await test_websocket_transport()
await test_transport_registry()
print("\n" + "=" * 60)
if success:
print("🎉 All tests passed! WebSocket transport is working correctly.")
else:
print("❌ Some tests failed. Check the output above for details.")
print("\n🚀 WebSocket transport is ready for use in py-libp2p!")
@ -128,4 +141,5 @@ if __name__ == "__main__":
except Exception as e:
print(f"\n❌ Test failed with error: {e}")
import traceback
traceback.print_exc()

View File

@ -1,21 +1,26 @@
import argparse
import logging
import signal
import sys
import traceback
import multiaddr
import trio
from libp2p.abc import INotifee
from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
from libp2p.peer.peerinfo import info_from_p2p_addr
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
from libp2p.security.noise.transport import (
PROTOCOL_ID as NOISE_PROTOCOL_ID,
Transport as NoiseTransport,
)
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
@ -25,6 +30,15 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("libp2p.websocket-example")
# Suppress KeyboardInterrupt by handling SIGINT directly
def signal_handler(signum, frame):
print("✅ Clean exit completed.")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
# Simple echo protocol
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
@ -34,7 +48,7 @@ async def echo_handler(stream):
try:
data = await stream.read(1024)
if data:
message = data.decode('utf-8', errors='replace')
message = data.decode("utf-8", errors="replace")
print(f"📥 Received: {message}")
print(f"📤 Echoing back: {message}")
await stream.write(data)
@ -44,7 +58,7 @@ async def echo_handler(stream):
await stream.close()
def create_websocket_host(listen_addrs=None, use_noise=False):
def create_websocket_host(listen_addrs=None, use_plaintext=False):
"""Create a host with WebSocket transport."""
# Create key pair and peer store
key_pair = create_new_key_pair()
@ -52,11 +66,22 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
peer_store = PeerStore()
peer_store.add_key_pair(peer_id, key_pair)
if use_noise:
if use_plaintext:
# Create transport upgrader with plaintext security
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
else:
# Create separate Ed25519 key for Noise protocol
noise_key_pair = create_ed25519_key_pair()
# Create Noise transport
noise_transport = NoiseTransport(
libp2p_keypair=key_pair,
noise_privkey=key_pair.private_key,
noise_privkey=noise_key_pair.private_key,
early_data=None,
with_noise_pipes=False,
)
@ -68,43 +93,85 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
else:
# Create transport upgrader with plaintext security
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
# Create WebSocket transport
transport = WebsocketTransport(upgrader)
# Create swarm and host
swarm = Swarm(peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm)
return host
async def run(port: int, destination: str, use_noise: bool = False) -> None:
async def run(port: int, destination: str, use_plaintext: bool = False) -> None:
localhost_ip = "0.0.0.0"
if not destination:
# Create first host (listener) with WebSocket transport
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
try:
host = create_websocket_host(use_noise=use_noise)
logger.debug(f"Created host with use_noise={use_noise}")
host = create_websocket_host(use_plaintext=use_plaintext)
logger.debug(f"Created host with use_plaintext={use_plaintext}")
# Set up echo handler
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
# Add connection event handlers for debugging
class DebugNotifee(INotifee):
async def opened_stream(self, network, stream):
pass
async def closed_stream(self, network, stream):
pass
async def connected(self, network, conn):
print(
f"🔗 New libp2p connection established: "
f"{conn.muxed_conn.peer_id}"
)
if hasattr(conn.muxed_conn, "get_security_protocol"):
security = conn.muxed_conn.get_security_protocol()
else:
security = "Unknown"
print(f" Security: {security}")
async def disconnected(self, network, conn):
print(f"🔌 libp2p connection closed: {conn.muxed_conn.peer_id}")
async def listen(self, network, multiaddr):
pass
async def listen_close(self, network, multiaddr):
pass
host.get_network().register_notifee(DebugNotifee())
# Create a cancellation token for clean shutdown
cancel_scope = trio.CancelScope()
async def signal_handler():
with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as (
signal_receiver
):
async for sig in signal_receiver:
print(f"\n🛑 Received signal {sig}")
print("✅ Shutting down WebSocket server...")
cancel_scope.cancel()
return
async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
# Start the signal handler
nursery.start_soon(signal_handler)
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
# connections
addrs = host.get_addrs()
@ -113,18 +180,19 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
print("❌ Error: No addresses found for the host")
print("Debug: host.get_addrs() returned empty list")
return
server_addr = str(addrs[0])
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
print("🌐 WebSocket Server Started Successfully!")
print("=" * 50)
print(f"📍 Server Address: {client_addr}")
print(f"🔧 Protocol: /echo/1.0.0")
print(f"🚀 Transport: WebSocket (/ws)")
print("🔧 Protocol: /echo/1.0.0")
print("🚀 Transport: WebSocket (/ws)")
print()
print("📋 To test the connection, run this in another terminal:")
print(f" python websocket_demo.py -d {client_addr}")
plaintext_flag = " --plaintext" if use_plaintext else ""
print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}")
print()
print("⏳ Waiting for incoming WebSocket connections...")
print("" * 50)
@ -132,32 +200,34 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
# Add a custom handler to show connection events
async def custom_echo_handler(stream):
peer_id = stream.muxed_conn.peer_id
print(f"\n🔗 New WebSocket Connection!")
print("\n🔗 New WebSocket Connection!")
print(f" Peer ID: {peer_id}")
print(f" Protocol: /echo/1.0.0")
print(" Protocol: /echo/1.0.0")
# Show remote address in multiaddr format
try:
remote_address = stream.get_remote_address()
if remote_address:
print(f" Remote: {remote_address}")
except Exception:
print(f" Remote: Unknown")
print(f"" * 40)
print(" Remote: Unknown")
print("" * 40)
# Call the original handler
await echo_handler(stream)
print(f"" * 40)
print("" * 40)
print(f"✅ Echo request completed for peer: {peer_id}")
print()
# Replace the handler with our custom one
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
await trio.sleep_forever()
# Wait indefinitely or until cancelled
with cancel_scope:
await trio.sleep_forever()
except Exception as e:
print(f"❌ Error creating WebSocket server: {e}")
traceback.print_exc()
@ -166,15 +236,47 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
else:
# Create second host (dialer) with WebSocket transport
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
try:
# Create a single host for client operations
host = create_websocket_host(use_noise=use_noise)
host = create_websocket_host(use_plaintext=use_plaintext)
# Start the host for client operations
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
async with (
host.run(listen_addrs=[listen_addr]),
trio.open_nursery() as (nursery),
):
# Start the peer-store cleanup task
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
# Add connection event handlers for debugging
class ClientDebugNotifee(INotifee):
async def opened_stream(self, network, stream):
pass
async def closed_stream(self, network, stream):
pass
async def connected(self, network, conn):
print(
f"🔗 Client: libp2p connection established: "
f"{conn.muxed_conn.peer_id}"
)
async def disconnected(self, network, conn):
print(
f"🔌 Client: libp2p connection closed: "
f"{conn.muxed_conn.peer_id}"
)
async def listen(self, network, multiaddr):
pass
async def listen_close(self, network, multiaddr):
pass
host.get_network().register_notifee(ClientDebugNotifee())
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
print("🔌 WebSocket Client Starting...")
@ -185,21 +287,34 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
try:
print("🔗 Connecting to WebSocket server...")
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
await host.connect(info)
print("✅ Successfully connected to WebSocket server!")
except Exception as e:
error_msg = str(e)
if "unable to connect" in error_msg or "SwarmException" in error_msg:
print(f"\n❌ Connection Failed!")
print(f" Peer ID: {info.peer_id}")
print(f" Address: {destination}")
print(f" Error: {error_msg}")
print()
print("💡 Troubleshooting:")
print(" • Make sure the WebSocket server is running")
print(" • Check that the server address is correct")
print(" • Verify the server is listening on the right port")
return
print("\n❌ Connection Failed!")
print(f" Peer ID: {info.peer_id}")
print(f" Address: {destination}")
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
print(f" Error: {error_msg}")
print(f" Error type: {type(e).__name__}")
# Add more detailed error information for debugging
if hasattr(e, "__cause__") and e.__cause__:
print(f" Root cause: {e.__cause__}")
print(f" Root cause type: {type(e.__cause__).__name__}")
print()
print("💡 Troubleshooting:")
print(" • Make sure the WebSocket server is running")
print(" • Check that the server address is correct")
print(" • Verify the server is listening on the right port")
print(
" • Ensure both client and server use the same sec protocol"
)
if not use_plaintext:
print(" • Noise over WebSocket may have compatibility issues")
return
# Create a stream and send test data
try:
@ -242,8 +357,18 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
finally:
# Ensure stream is closed
try:
if stream and not await stream.is_closed():
await stream.close()
if stream:
# Check if stream has is_closed method and use it
has_is_closed = hasattr(stream, "is_closed") and callable(
getattr(stream, "is_closed")
)
if has_is_closed:
# type: ignore[attr-defined]
if not await stream.is_closed():
await stream.close()
else:
# Fallback: just try to close the stream
await stream.close()
except Exception:
pass
@ -256,7 +381,10 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
print("✅ libp2p integration verified!")
print()
print("🚀 Your WebSocket transport is ready for production use!")
# Add a small delay to ensure all cleanup is complete
await trio.sleep(0.1)
except Exception as e:
print(f"❌ Error creating WebSocket client: {e}")
traceback.print_exc()
@ -266,12 +394,15 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
def main() -> None:
description = """
This program demonstrates the libp2p WebSocket transport.
First run 'python websocket_demo.py -p <PORT> [--noise]' to start a WebSocket server.
Then run 'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--noise]'
First run
'python websocket_demo.py -p <PORT> [--plaintext]' to start a WebSocket server.
Then run
'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--plaintext]'
where <DESTINATION> is the multiaddress shown by the server.
By default, this example uses plaintext security for communication.
Use --noise for testing with Noise encryption (experimental).
By default, this example uses Noise encryption for secure communication.
Use --plaintext for testing with unencrypted communication
(not recommended for production).
"""
example_maddr = (
@ -287,20 +418,30 @@ def main() -> None:
help=f"destination multiaddr string, e.g. {example_maddr}",
)
parser.add_argument(
"--noise",
"--plaintext",
action="store_true",
help="use Noise encryption instead of plaintext security",
help=(
"use plaintext security instead of Noise encryption "
"(not recommended for production)"
),
)
args = parser.parse_args()
# Determine security mode: use plaintext by default, Noise if --noise is specified
use_noise = args.noise
# Determine security mode: use Noise by default,
# plaintext if --plaintext is specified
use_plaintext = args.plaintext
try:
trio.run(run, args.port, args.destination, use_noise)
trio.run(run, args.port, args.destination, use_plaintext)
except KeyboardInterrupt:
pass
# This is expected when Ctrl+C is pressed
# The signal handler already printed the shutdown message
print("✅ Clean exit completed.")
return
except Exception as e:
print(f"❌ Unexpected error: {e}")
return
if __name__ == "__main__":

View File

@ -19,6 +19,7 @@ from libp2p.abc import (
IPeerRouting,
IPeerStore,
ISecureTransport,
ITransport,
)
from libp2p.crypto.keys import (
KeyPair,
@ -231,14 +232,15 @@ def new_swarm(
)
# Create transport based on listen_addrs or default to TCP
transport: ITransport
if listen_addrs is None:
transport = TCP()
else:
# Use the first address to determine transport type
addr = listen_addrs[0]
transport = create_transport_for_multiaddr(addr, upgrader)
if transport is None:
transport_maybe = create_transport_for_multiaddr(addr, upgrader)
if transport_maybe is None:
# Fallback to TCP if no specific transport found
if addr.__contains__("tcp"):
transport = TCP()
@ -250,20 +252,8 @@ def new_swarm(
f"Unknown transport in listen_addrs: {listen_addrs}. "
f"Supported protocols: {supported_protocols}"
)
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
# Default security transports (using Noise as primary)
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or {
NOISE_PROTOCOL_ID: NoiseTransport(
key_pair, noise_privkey=noise_key_pair.private_key
),
TProtocol(secio.ID): secio.Transport(key_pair),
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
key_pair, peerstore=peerstore_opt
),
}
else:
transport = transport_maybe
# Use given muxer preference if provided, otherwise use global default
if muxer_preference is not None:

View File

@ -1,3 +1,4 @@
import logging
from typing import (
cast,
)
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
FixedSizeLenMsgReadWriter,
)
logger = logging.getLogger(__name__)
SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.noise_state = noise_state
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
data_encrypted = self.encrypt(msg)
if prefix_encoded:
# Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
await self.read_writer.write_msg(data_encrypted)
logger.debug("Noise write_msg: write completed successfully")
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
logger.debug("Noise read_msg: reading encrypted message")
noise_msg_encrypted = await self.read_writer.read_msg()
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
if prefix_encoded:
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
else:
return self.decrypt(noise_msg_encrypted)
result = self.decrypt(noise_msg_encrypted)
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
return result
async def close(self) -> None:
await self.read_writer.close()

View File

@ -1,6 +1,7 @@
from dataclasses import (
dataclass,
)
import logging
from libp2p.crypto.keys import (
PrivateKey,
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
from .pb import noise_pb2 as noise_pb
logger = logging.getLogger(__name__)
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
) -> bytes:
data = make_data_to_be_signed(noise_static_pubkey)
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
return id_privkey.sign(data)
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
2. signed by the private key corresponding to `id_pubkey`
"""
expected_data = make_data_to_be_signed(noise_static_pubkey)
return payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(
f"verify_handshake_payload_sig: payload.id_pubkey type: "
f"{type(payload.id_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: noise_static_pubkey type: "
f"{type(noise_static_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
)
logger.debug(
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
)
try:
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
return result
except Exception as e:
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
return False

View File

@ -2,6 +2,7 @@ from abc import (
ABC,
abstractmethod,
)
import logging
from cryptography.hazmat.primitives import (
serialization,
@ -46,6 +47,8 @@ from .messages import (
verify_handshake_payload_sig,
)
logger = logging.getLogger(__name__)
class IPattern(ABC):
@abstractmethod
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
self.early_data = early_data
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
noise_state = self.create_noise_state()
noise_state.set_as_responder()
noise_state.start_handshake()
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
logger.debug("Noise XX handshake_inbound: reading msg#1")
await read_writer.read_msg()
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
# Send msg#2, which should include our handshake payload.
logger.debug("Noise XX handshake_inbound: preparing msg#2")
our_payload = self.make_handshake_payload()
msg_2 = our_payload.serialize()
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
await read_writer.write_msg(msg_2)
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
# Receive and consume msg#3.
logger.debug("Noise XX handshake_inbound: reading msg#3")
msg_3 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
if handshake_state.rs is None:
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn:
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
noise_state = self.create_noise_state()
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
logger.debug("Noise XX handshake_outbound: sending msg#1")
msg_1 = b""
await read_writer.write_msg(msg_1)
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
# Read msg#2 from the remote, which contains the public key of the peer.
logger.debug("Noise XX handshake_outbound: reading msg#2")
msg_2 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
if handshake_state.rs is None:
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
logger.debug(
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
)
logger.debug(
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
)
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
logger.debug(
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
f"{id_pubkey_repr}"
)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
logger.error(
f"Noise XX handshake_outbound: signature verification failed for peer "
f"{remote_peer}"
)
raise InvalidSignature
logger.debug(
f"Noise XX handshake_outbound: signature verification successful for peer "
f"{remote_peer}"
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(

View File

@ -1,17 +1,19 @@
from .tcp.tcp import TCP
from .websocket.transport import WebsocketTransport
from .transport_registry import (
TransportRegistry,
TransportRegistry,
create_transport_for_multiaddr,
get_transport_registry,
register_transport,
get_supported_transport_protocols,
)
from .upgrader import TransportUpgrader
from libp2p.abc import ITransport
def create_transport(protocol: str, upgrader=None):
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None) -> ITransport:
"""
Convenience function to create a transport instance.
:param protocol: The transport protocol ("tcp", "ws", or custom)
:param upgrader: Optional transport upgrader (required for WebSocket)
:return: Transport instance
@ -28,7 +30,10 @@ def create_transport(protocol: str, upgrader=None):
registry = get_transport_registry()
transport_class = registry.get_transport(protocol)
if transport_class:
return registry.create_transport(protocol, upgrader)
transport = registry.create_transport(protocol, upgrader)
if transport is None:
raise ValueError(f"Failed to create transport for protocol: {protocol}")
return transport
else:
raise ValueError(f"Unsupported transport protocol: {protocol}")

View File

@ -3,13 +3,15 @@ Transport registry for dynamic transport selection based on multiaddr protocols.
"""
import logging
from typing import Dict, Type, Optional
from typing import Any
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
from libp2p.abc import ITransport
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
logger = logging.getLogger("libp2p.transport.registry")
@ -17,28 +19,29 @@ logger = logging.getLogger("libp2p.transport.registry")
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid TCP structure.
:param maddr: The multiaddr to validate
:return: True if valid TCP structure, False otherwise
"""
try:
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
# or /ip6/::1/tcp/8080
protocols = maddr.protocols()
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 2 protocols: network (ip4/ip6) + tcp
if len(protocols) < 2:
return False
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Should not have any protocols after tcp (unless it's a valid continuation like p2p)
# Should not have any protocols after tcp (unless it's a valid
# continuation like p2p)
# For now, we'll be strict and only allow network + tcp
if len(protocols) > 2:
# Check if the additional protocols are valid continuations
@ -46,9 +49,9 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
for i in range(2, len(protocols)):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False
@ -56,31 +59,31 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid WebSocket structure.
:param maddr: The multiaddr to validate
:return: True if valid WebSocket structure, False otherwise
"""
try:
# WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws
# or /ip6/::1/tcp/8080/ws
protocols = maddr.protocols()
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws
if len(protocols) < 3:
return False
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Last protocol should be ws
if protocols[-1].name != "ws":
return False
# Should not have any protocols between tcp and ws
if len(protocols) > 3:
# Check if the additional protocols are valid continuations
@ -88,9 +91,9 @@ def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
for i in range(2, len(protocols) - 1):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False
@ -99,46 +102,52 @@ class TransportRegistry:
"""
Registry for mapping multiaddr protocols to transport implementations.
"""
def __init__(self):
self._transports: Dict[str, Type[ITransport]] = {}
def __init__(self) -> None:
self._transports: dict[str, type[ITransport]] = {}
self._register_default_transports()
def _register_default_transports(self) -> None:
"""Register the default transport implementations."""
# Register TCP transport for /tcp protocol
self.register_transport("tcp", TCP)
# Register WebSocket transport for /ws protocol
self.register_transport("ws", WebsocketTransport)
def register_transport(self, protocol: str, transport_class: Type[ITransport]) -> None:
def register_transport(
self, protocol: str, transport_class: type[ITransport]
) -> None:
"""
Register a transport class for a specific protocol.
:param protocol: The protocol identifier (e.g., "tcp", "ws")
:param transport_class: The transport class to register
"""
self._transports[protocol] = transport_class
logger.debug(f"Registered transport {transport_class.__name__} for protocol {protocol}")
def get_transport(self, protocol: str) -> Optional[Type[ITransport]]:
logger.debug(
f"Registered transport {transport_class.__name__} for protocol {protocol}"
)
def get_transport(self, protocol: str) -> type[ITransport] | None:
"""
Get the transport class for a specific protocol.
:param protocol: The protocol identifier
:return: The transport class or None if not found
"""
return self._transports.get(protocol)
def get_supported_protocols(self) -> list[str]:
"""Get list of supported transport protocols."""
return list(self._transports.keys())
def create_transport(self, protocol: str, upgrader: Optional[TransportUpgrader] = None, **kwargs) -> Optional[ITransport]:
def create_transport(
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
) -> ITransport | None:
"""
Create a transport instance for a specific protocol.
:param protocol: The protocol identifier
:param upgrader: The transport upgrader instance (required for WebSocket)
:param kwargs: Additional arguments for transport construction
@ -147,14 +156,17 @@ class TransportRegistry:
transport_class = self.get_transport(protocol)
if transport_class is None:
return None
try:
if protocol == "ws":
# WebSocket transport requires upgrader
if upgrader is None:
logger.warning(f"WebSocket transport '{protocol}' requires upgrader")
logger.warning(
f"WebSocket transport '{protocol}' requires upgrader"
)
return None
return transport_class(upgrader)
# Use explicit WebsocketTransport to avoid type issues
return WebsocketTransport(upgrader)
else:
# TCP transport doesn't require upgrader
return transport_class()
@ -172,15 +184,17 @@ def get_transport_registry() -> TransportRegistry:
return _global_registry
def register_transport(protocol: str, transport_class: Type[ITransport]) -> None:
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
"""Register a transport class in the global registry."""
_global_registry.register_transport(protocol, transport_class)
def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader) -> Optional[ITransport]:
def create_transport_for_multiaddr(
maddr: Multiaddr, upgrader: TransportUpgrader
) -> ITransport | None:
"""
Create the appropriate transport for a given multiaddr.
:param maddr: The multiaddr to create transport for
:param upgrader: The transport upgrader instance
:return: Transport instance or None if no suitable transport found
@ -188,7 +202,7 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
try:
# Get all protocols in the multiaddr
protocols = [proto.name for proto in maddr.protocols()]
# Check for supported transport protocols in order of preference
# We need to validate that the multiaddr structure is valid for our transports
if "ws" in protocols:
@ -201,11 +215,14 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
# Check if the multiaddr has proper TCP structure
if _is_valid_tcp_multiaddr(maddr):
return _global_registry.create_transport("tcp", upgrader)
# If no supported transport protocol found or structure is invalid, return None
logger.warning(f"No supported transport protocol found or invalid structure in multiaddr: {maddr}")
logger.warning(
f"No supported transport protocol found or invalid structure in "
f"multiaddr: {maddr}"
)
return None
except Exception as e:
# Handle any errors gracefully (e.g., invalid multiaddr)
logger.warning(f"Error processing multiaddr {maddr}: {e}")

View File

@ -1,9 +1,13 @@
from trio.abc import Stream
import logging
from typing import Any
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger(__name__)
class P2PWebSocketConnection(ReadWriteCloser):
"""
@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser):
that libp2p protocols expect.
"""
def __init__(self, ws_connection, ws_context=None):
def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
self._ws_connection = ws_connection
self._ws_context = ws_context
self._read_buffer = b""
@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser):
async def write(self, data: bytes) -> None:
try:
logger.debug(f"WebSocket writing {len(data)} bytes")
# Send as a binary WebSocket message
await self._ws_connection.send_message(data)
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
except Exception as e:
logger.error(f"WebSocket write failed: {e}")
raise IOException from e
async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes (if n is given), else read up to 64KiB.
This implementation provides byte-level access to WebSocket messages,
which is required for Noise protocol handshake.
"""
async with self._read_lock:
try:
logger.debug(
f"WebSocket read requested: n={n}, "
f"buffer_size={len(self._read_buffer)}"
)
# If we have buffered data, return it
if self._read_buffer:
if n is None:
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning all buffered data: "
f"{len(result)} bytes"
)
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning {len(result)} bytes "
f"from buffer"
)
return result
else:
result = self._read_buffer
self._read_buffer = b""
return result
# We need more data, but we have some buffered
# Keep the buffered data and get more
logger.debug(
f"WebSocket read needs more data: have "
f"{len(self._read_buffer)}, need {n}"
)
pass
# If we need exactly n bytes but don't have enough, get more data
while n is not None and (
not self._read_buffer or len(self._read_buffer) < n
):
logger.debug(
f"WebSocket read getting more data: "
f"buffer_size={len(self._read_buffer)}, need={n}"
)
# Get the next WebSocket message and treat it as a byte stream
# This mimics the Go implementation's NextReader() approach
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
logger.debug(
f"WebSocket read received message: {len(message)} bytes"
)
# Add to buffer
self._read_buffer += message
# Get the next WebSocket message
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode('utf-8')
# Add to buffer
self._read_buffer = message
# Return requested amount
if n is None:
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning all data: {len(result)} bytes"
)
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning exact {len(result)} bytes"
)
return result
else:
# This should never happen due to the while loop above
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning remaining {len(result)} bytes"
)
return result
except Exception as e:
logger.error(f"WebSocket read failed: {e}")
raise IOException from e
async def close(self) -> None:
@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser):
# Try to get remote address from the WebSocket connection
try:
remote = self._ws_connection.remote
if hasattr(remote, 'address') and hasattr(remote, 'port'):
if hasattr(remote, "address") and hasattr(remote, "port"):
return str(remote.address), int(remote.port)
elif isinstance(remote, str):
# Parse address:port format
if ':' in remote:
host, port = remote.rsplit(':', 1)
if ":" in remote:
host, port = remote.rsplit(":", 1)
return host, int(port)
except Exception:
pass

View File

@ -1,6 +1,6 @@
from collections.abc import Awaitable, Callable
import logging
import socket
from typing import Any, Callable
from typing import Any
from multiaddr import Multiaddr
import trio
@ -9,7 +9,6 @@ from trio_websocket import serve_websocket
from libp2p.abc import IListener
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
@ -27,7 +26,8 @@ class WebsocketListener(IListener):
self._upgrader = upgrader
self._server = None
self._shutdown_event = trio.Event()
self._nursery = None
self._nursery: trio.Nursery | None = None
self._listeners: Any = None
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}")
@ -47,56 +47,60 @@ class WebsocketListener(IListener):
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
logger.debug(f"WebsocketListener: host={host}, port={port}")
async def serve_websocket_tcp(
handler: Callable,
handler: Callable[[Any], Awaitable[None]],
port: int,
host: str,
task_status: trio.TaskStatus[list],
task_status: TaskStatus[Any],
) -> None:
"""Start TCP server and handle WebSocket connections manually"""
logger.debug("serve_websocket_tcp %s %s", host, port)
async def websocket_handler(request):
async def websocket_handler(request: Any) -> None:
"""Handle WebSocket requests"""
logger.debug("WebSocket request received")
try:
# Accept the WebSocket connection
ws_connection = await request.accept()
logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(ws_connection)
conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call]
# Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades
logger.debug("Calling connection handler")
await self._handler(conn)
# Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle
logger.debug("Handler completed, connection will be managed by handler")
logger.debug(
"Handler completed, connection will be managed by handler"
)
except Exception as e:
logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}")
import traceback
logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection
try:
await request.reject(400)
except:
except Exception:
pass
# Use trio_websocket.serve_websocket for proper WebSocket handling
from trio_websocket import serve_websocket
await serve_websocket(websocket_handler, host, port, None, task_status=task_status)
await serve_websocket(
websocket_handler, host, port, None, task_status=task_status
)
# Store the nursery for shutdown
self._nursery = nursery
# Start the server using nursery.start() like TCP does
logger.debug("Calling nursery.start()...")
started_listeners = await nursery.start(
@ -111,18 +115,21 @@ class WebsocketListener(IListener):
logger.error(f"Failed to start WebSocket listener for {maddr}")
return False
# Store the listeners for get_addrs() and close() - these are real SocketListener objects
# Store the listeners for get_addrs() and close() - these are real
# SocketListener objects
self._listeners = started_listeners
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object")
logger.debug(
"WebsocketListener.listen returning True with WebSocketServer object"
)
return True
def get_addrs(self) -> tuple[Multiaddr, ...]:
if not hasattr(self, '_listeners') or not self._listeners:
if not hasattr(self, "_listeners") or not self._listeners:
logger.debug("No listeners available for get_addrs()")
return ()
# Handle WebSocketServer objects
if hasattr(self._listeners, 'port'):
if hasattr(self._listeners, "port"):
# This is a WebSocketServer object
port = self._listeners.port
# Create a multiaddr from the port
@ -138,12 +145,12 @@ class WebsocketListener(IListener):
async def close(self) -> None:
"""Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called")
if hasattr(self, '_listeners') and self._listeners:
if hasattr(self, "_listeners") and self._listeners:
# Signal shutdown
self._shutdown_event.set()
# Close the WebSocket server
if hasattr(self._listeners, 'aclose'):
if hasattr(self._listeners, "aclose"):
# This is a WebSocketServer object
logger.debug("Closing WebSocket server")
await self._listeners.aclose()
@ -152,15 +159,15 @@ class WebsocketListener(IListener):
# This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners")
for listener in self._listeners:
listener.close()
await listener.aclose()
logger.debug("TCP listeners closed")
else:
# Unknown type, try to close it directly
logger.debug("Closing unknown listener type")
if hasattr(self._listeners, 'close'):
if hasattr(self._listeners, "close"):
self._listeners.close()
logger.debug("Unknown listener closed")
# Clear the listeners reference
self._listeners = None
logger.debug("WebsocketListener.close completed")

View File

@ -1,6 +1,6 @@
import logging
from multiaddr import Multiaddr
from trio_websocket import open_websocket_url
from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
from .listener import WebsocketListener
logger = logging.getLogger("libp2p.transport.websocket")
logger = logging.getLogger(__name__)
class WebsocketTransport(ITransport):
@ -25,7 +25,7 @@ class WebsocketTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> RawConnection:
"""Dial a WebSocket connection to the given multiaddr."""
logger.debug(f"WebsocketTransport.dial called with {maddr}")
# Extract host and port from multiaddr
host = (
maddr.value_for_protocol("ip4")
@ -45,6 +45,7 @@ class WebsocketTransport(ITransport):
try:
from trio_websocket import open_websocket_url
# Use the context manager but don't exit it immediately
# The connection will be closed when the RawConnection is closed
ws_context = open_websocket_url(ws_url)

View File

@ -2,20 +2,20 @@
Tests for the transport registry functionality.
"""
import pytest
from multiaddr import Multiaddr
from libp2p.abc import ITransport
from libp2p.abc import IListener, IRawConnection, ITransport
from libp2p.custom_types import THandler
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.transport_registry import (
TransportRegistry,
create_transport_for_multiaddr,
get_supported_transport_protocols,
get_transport_registry,
register_transport,
get_supported_transport_protocols,
)
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
class TestTransportRegistry:
@ -25,7 +25,7 @@ class TestTransportRegistry:
"""Test registry initialization."""
registry = TransportRegistry()
assert isinstance(registry, TransportRegistry)
# Check that default transports are registered
supported = registry.get_supported_protocols()
assert "tcp" in supported
@ -34,22 +34,28 @@ class TestTransportRegistry:
def test_register_transport(self):
"""Test transport registration."""
registry = TransportRegistry()
# Register a custom transport
class CustomTransport:
pass
class CustomTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("CustomTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"CustomTransport create_listener not implemented"
)
registry.register_transport("custom", CustomTransport)
assert registry.get_transport("custom") == CustomTransport
def test_get_transport(self):
"""Test getting registered transports."""
registry = TransportRegistry()
# Test existing transports
assert registry.get_transport("tcp") == TCP
assert registry.get_transport("ws") == WebsocketTransport
# Test non-existent transport
assert registry.get_transport("nonexistent") is None
@ -57,7 +63,7 @@ class TestTransportRegistry:
"""Test getting supported protocols."""
registry = TransportRegistry()
protocols = registry.get_supported_protocols()
assert isinstance(protocols, list)
assert "tcp" in protocols
assert "ws" in protocols
@ -66,7 +72,7 @@ class TestTransportRegistry:
"""Test creating TCP transport."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("tcp", upgrader)
assert isinstance(transport, TCP)
@ -74,7 +80,7 @@ class TestTransportRegistry:
"""Test creating WebSocket transport."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("ws", upgrader)
assert isinstance(transport, WebsocketTransport)
@ -82,14 +88,14 @@ class TestTransportRegistry:
"""Test creating transport with invalid protocol."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("invalid", upgrader)
assert transport is None
def test_create_transport_websocket_no_upgrader(self):
"""Test that WebSocket transport requires upgrader."""
registry = TransportRegistry()
# This should fail gracefully and return None
transport = registry.create_transport("ws", None)
assert transport is None
@ -105,12 +111,19 @@ class TestGlobalRegistry:
def test_register_transport_global(self):
"""Test registering transport globally."""
class GlobalCustomTransport:
pass
class GlobalCustomTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("GlobalCustomTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"GlobalCustomTransport create_listener not implemented"
)
# Register globally
register_transport("global_custom", GlobalCustomTransport)
# Check that it's available
registry = get_transport_registry()
assert registry.get_transport("global_custom") == GlobalCustomTransport
@ -129,79 +142,80 @@ class TestTransportFactory:
def test_create_transport_for_multiaddr_tcp(self):
"""Test creating transport for TCP multiaddr."""
upgrader = TransportUpgrader({}, {})
# TCP multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, TCP)
def test_create_transport_for_multiaddr_websocket(self):
"""Test creating transport for WebSocket multiaddr."""
upgrader = TransportUpgrader({}, {})
# WebSocket multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_websocket_secure(self):
"""Test creating transport for WebSocket multiaddr."""
upgrader = TransportUpgrader({}, {})
# WebSocket multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_ipv6(self):
"""Test creating transport for IPv6 multiaddr."""
upgrader = TransportUpgrader({}, {})
# IPv6 WebSocket multiaddr
maddr = Multiaddr("/ip6/::1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_dns(self):
"""Test creating transport for DNS multiaddr."""
upgrader = TransportUpgrader({}, {})
# DNS WebSocket multiaddr
maddr = Multiaddr("/dns4/example.com/tcp/443/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_unknown(self):
"""Test creating transport for unknown multiaddr."""
upgrader = TransportUpgrader({}, {})
# Unknown multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/udp/8080")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None
def test_create_transport_for_multiaddr_no_upgrader(self):
"""Test creating transport without upgrader."""
# This should work for TCP but not WebSocket
def test_create_transport_for_multiaddr_with_upgrader(self):
"""Test creating transport with upgrader."""
upgrader = TransportUpgrader({}, {})
# This should work for both TCP and WebSocket with upgrader
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport_tcp = create_transport_for_multiaddr(maddr_tcp, None)
transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader)
assert transport_tcp is not None
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport_ws = create_transport_for_multiaddr(maddr_ws, None)
# WebSocket transport creation should fail gracefully
assert transport_ws is None
transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader)
assert transport_ws is not None
class TestTransportInterfaceCompliance:
@ -211,8 +225,8 @@ class TestTransportInterfaceCompliance:
"""Test that TCP transport implements ITransport."""
transport = TCP()
assert isinstance(transport, ITransport)
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
assert callable(transport.dial)
assert callable(transport.create_listener)
@ -221,8 +235,8 @@ class TestTransportInterfaceCompliance:
upgrader = TransportUpgrader({}, {})
transport = WebsocketTransport(upgrader)
assert isinstance(transport, ITransport)
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
assert callable(transport.dial)
assert callable(transport.create_listener)
@ -234,14 +248,22 @@ class TestErrorHandling:
"""Test handling of transport creation exceptions."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
# Register a transport that raises an exception
class ExceptionTransport:
class ExceptionTransport(ITransport):
def __init__(self, *args, **kwargs):
raise RuntimeError("Transport creation failed")
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("ExceptionTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"ExceptionTransport create_listener not implemented"
)
registry.register_transport("exception", ExceptionTransport)
# Should handle exception gracefully and return None
transport = registry.create_transport("exception", upgrader)
assert transport is None
@ -249,12 +271,13 @@ class TestErrorHandling:
def test_invalid_multiaddr_handling(self):
"""Test handling of invalid multiaddrs."""
upgrader = TransportUpgrader({}, {})
# Test with a multiaddr that has an unsupported transport protocol
# This should be handled gracefully by our transport registry
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") # udp is not a supported transport
# udp is not a supported transport
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None
@ -265,15 +288,15 @@ class TestIntegration:
"""Test using multiple transport types in the same registry."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
# Create different transport types
tcp_transport = registry.create_transport("tcp", upgrader)
ws_transport = registry.create_transport("ws", upgrader)
# All should be different types
assert isinstance(tcp_transport, TCP)
assert isinstance(ws_transport, WebsocketTransport)
# All should be different instances
assert tcp_transport is not ws_transport
@ -281,15 +304,21 @@ class TestIntegration:
"""Test that transport registry persists across calls."""
registry1 = get_transport_registry()
registry2 = get_transport_registry()
# Should be the same instance
assert registry1 is registry2
# Register a transport in one
class PersistentTransport:
pass
class PersistentTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> IRawConnection:
raise NotImplementedError("PersistentTransport dial not implemented")
def create_listener(self, handler_function: THandler) -> IListener:
raise NotImplementedError(
"PersistentTransport create_listener not implemented"
)
registry1.register_transport("persistent", PersistentTransport)
# Should be available in the other
assert registry2.get_transport("persistent") == PersistentTransport

View File

@ -1,23 +1,23 @@
from collections.abc import Sequence
import logging
from typing import Any
import pytest
import trio
from multiaddr import Multiaddr
import trio
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.websocket.listener import WebsocketListener
from libp2p.transport.exceptions import OpenConnectionError
logger = logging.getLogger(__name__)
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
@ -64,29 +64,30 @@ def create_upgrader():
)
# 2. Listener Basic Functionality Tests
@pytest.mark.trio
async def test_listener_basic_listen():
"""Test basic listen functionality"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test listening on IPv4
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that listener can be created and has required methods
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
# Test that listener can handle the address
assert ma.value_for_protocol("ip4") == "127.0.0.1"
assert ma.value_for_protocol("tcp") == "0"
# Test that listener can be closed
await listener.close()
@ -96,14 +97,18 @@ async def test_listener_port_0_handling():
"""Test listening on port 0 gets actual port"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that the address can be parsed correctly
port_str = ma.value_for_protocol("tcp")
assert port_str == "0"
# Test that listener can be closed
await listener.close()
@ -113,14 +118,18 @@ async def test_listener_any_interface():
"""Test listening on 0.0.0.0"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that the address can be parsed correctly
host = ma.value_for_protocol("ip4")
assert host == "0.0.0.0"
# Test that listener can be closed
await listener.close()
@ -130,16 +139,20 @@ async def test_listener_address_preservation():
"""Test that p2p IDs are preserved in addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Create address with p2p ID
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}")
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
# Test that p2p ID is preserved in the address
addr_str = str(ma)
assert p2p_id in addr_str
# Test that listener can be closed
await listener.close()
@ -150,18 +163,18 @@ async def test_dial_basic():
"""Test basic dial functionality"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can parse addresses for dialing
ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
# Test that the address can be parsed correctly
host = ma.value_for_protocol("ip4")
port = ma.value_for_protocol("tcp")
assert host == "127.0.0.1"
assert port == "8080"
# Test that transport has the required methods
assert hasattr(transport, 'dial')
assert hasattr(transport, "dial")
assert callable(transport.dial)
@ -170,16 +183,16 @@ async def test_dial_with_p2p_id():
"""Test dialing with p2p ID suffix"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{p2p_id}")
# Test that p2p ID is preserved in the address
addr_str = str(ma)
assert p2p_id in addr_str
# Test that transport can handle addresses with p2p IDs
assert hasattr(transport, 'dial')
assert hasattr(transport, "dial")
assert callable(transport.dial)
@ -188,41 +201,42 @@ async def test_dial_port_0_resolution():
"""Test dialing to resolved port 0 addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can handle port 0 addresses
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
# Test that the address can be parsed correctly
port_str = ma.value_for_protocol("tcp")
assert port_str == "0"
# Test that transport has the required methods
assert hasattr(transport, 'dial')
assert hasattr(transport, "dial")
assert callable(transport.dial)
# 4. Address Validation Tests (CRITICAL)
def test_address_validation_ipv4():
"""Test IPv4 address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# upgrader = create_upgrader() # Not used in this test
# Valid IPv4 WebSocket addresses
valid_addresses = [
"/ip4/127.0.0.1/tcp/8080/ws",
"/ip4/0.0.0.0/tcp/0/ws",
"/ip4/192.168.1.1/tcp/443/ws",
]
# Test valid addresses can be parsed
for addr_str in valid_addresses:
ma = Multiaddr(addr_str)
# Should not raise exception when creating transport address
transport_addr = str(ma)
assert "/ws" in transport_addr
# Test that transport can handle addresses with p2p IDs
p2p_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw")
p2p_addr = Multiaddr(
"/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw"
)
# Should not raise exception when creating transport address
transport_addr = str(p2p_addr)
assert "/ws" in transport_addr
@ -230,15 +244,14 @@ def test_address_validation_ipv4():
def test_address_validation_ipv6():
"""Test IPv6 address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# upgrader = create_upgrader() # Not used in this test
# Valid IPv6 WebSocket addresses
valid_addresses = [
"/ip6/::1/tcp/8080/ws",
"/ip6/2001:db8::1/tcp/443/ws",
]
# Test valid addresses can be parsed
for addr_str in valid_addresses:
ma = Multiaddr(addr_str)
@ -248,16 +261,15 @@ def test_address_validation_ipv6():
def test_address_validation_dns():
"""Test DNS address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# upgrader = create_upgrader() # Not used in this test
# Valid DNS WebSocket addresses
valid_addresses = [
"/dns4/example.com/tcp/80/ws",
"/dns6/example.com/tcp/443/ws",
"/dnsaddr/example.com/tcp/8080/ws",
]
# Test valid addresses can be parsed
for addr_str in valid_addresses:
ma = Multiaddr(addr_str)
@ -267,21 +279,20 @@ def test_address_validation_dns():
def test_address_validation_mixed():
"""Test mixed address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# upgrader = create_upgrader() # Not used in this test
# Mixed valid and invalid addresses
addresses = [
"/ip4/127.0.0.1/tcp/8080/ws", # Valid
"/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws)
"/ip6/::1/tcp/8080/ws", # Valid
"/ip4/127.0.0.1/ws", # Invalid (no tcp)
"/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws)
"/ip6/::1/tcp/8080/ws", # Valid
"/ip4/127.0.0.1/ws", # Invalid (no tcp)
"/dns4/example.com/tcp/80/ws", # Valid
]
# Convert to Multiaddr objects
multiaddrs = [Multiaddr(addr) for addr in addresses]
# Test that valid addresses can be processed
valid_count = 0
for ma in multiaddrs:
@ -292,7 +303,7 @@ def test_address_validation_mixed():
valid_count += 1
except Exception:
pass
assert valid_count == 3 # Should have 3 valid addresses
@ -302,30 +313,29 @@ async def test_dial_invalid_address():
"""Test dialing invalid addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test dialing non-WebSocket addresses
invalid_addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws
Multiaddr("/ip4/127.0.0.1/ws"), # No tcp
]
for ma in invalid_addresses:
with pytest.raises((ValueError, OpenConnectionError, Exception)):
with pytest.raises(Exception):
await transport.dial(ma)
@pytest.mark.trio
async def test_listen_invalid_address():
"""Test listening on invalid addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# upgrader = create_upgrader() # Not used in this test
# Test listening on non-WebSocket addresses
invalid_addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws
Multiaddr("/ip4/127.0.0.1/ws"), # No tcp
]
# Test that invalid addresses are properly identified
for ma in invalid_addresses:
# Test that the address parsing works correctly
@ -342,17 +352,17 @@ async def test_listen_port_in_use():
"""Test listening on port that's in use"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can handle port conflicts
ma1 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
ma2 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
# Test that both addresses can be parsed
assert ma1.value_for_protocol("tcp") == "8080"
assert ma2.value_for_protocol("tcp") == "8080"
# Test that transport can handle these addresses
assert hasattr(transport, 'create_listener')
assert hasattr(transport, "create_listener")
assert callable(transport.create_listener)
@ -362,16 +372,19 @@ async def test_connection_close():
"""Test connection closing"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport has required methods
assert hasattr(transport, 'dial')
assert hasattr(transport, "dial")
assert callable(transport.dial)
# Test that listener can be created and closed
listener = transport.create_listener(lambda conn: None)
assert hasattr(listener, 'close')
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert hasattr(listener, "close")
assert callable(listener.close)
# Test that listener can be closed
await listener.close()
@ -381,32 +394,26 @@ async def test_multiple_connections():
"""Test multiple concurrent connections"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can handle multiple addresses
addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"),
Multiaddr("/ip4/127.0.0.1/tcp/8081/ws"),
Multiaddr("/ip4/127.0.0.1/tcp/8082/ws"),
]
# Test that all addresses can be parsed
for addr in addresses:
host = addr.value_for_protocol("ip4")
port = addr.value_for_protocol("tcp")
assert host == "127.0.0.1"
assert port in ["8080", "8081", "8082"]
# Test that transport has required methods
assert hasattr(transport, 'dial')
assert hasattr(transport, "dial")
assert callable(transport.dial)
# Original test (kept for compatibility)
@pytest.mark.trio
async def test_websocket_dial_and_listen():
@ -414,42 +421,40 @@ async def test_websocket_dial_and_listen():
# Test that WebSocket transport can handle basic operations
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can create listeners
listener = transport.create_listener(lambda conn: None)
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
# Test that transport can handle WebSocket addresses
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert ma.value_for_protocol("ip4") == "127.0.0.1"
assert ma.value_for_protocol("tcp") == "0"
assert "ws" in str(ma)
# Test that transport has dial method
assert hasattr(transport, 'dial')
assert hasattr(transport, "dial")
assert callable(transport.dial)
# Test that transport can handle WebSocket multiaddrs
ws_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
assert ws_addr.value_for_protocol("ip4") == "127.0.0.1"
assert ws_addr.value_for_protocol("tcp") == "8080"
assert "ws" in str(ws_addr)
# Cleanup
await listener.close()
import logging
logger = logging.getLogger(__name__)
@pytest.mark.trio
async def test_websocket_transport_basic():
"""Test basic WebSocket transport functionality without full libp2p stack"""
# Create WebSocket transport
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
@ -459,29 +464,31 @@ async def test_websocket_transport_basic():
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert transport is not None
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
listener = transport.create_listener(lambda conn: None)
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert valid_addr.value_for_protocol("ip4") == "127.0.0.1"
assert valid_addr.value_for_protocol("tcp") == "0"
assert "ws" in str(valid_addr)
await listener.close()
@pytest.mark.trio
async def test_websocket_simple_connection():
"""Test WebSocket transport creation and basic functionality without real connections"""
"""Test WebSocket transport creation and basic functionality without real conn"""
# Create WebSocket transport
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
@ -491,32 +498,31 @@ async def test_websocket_simple_connection():
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert transport is not None
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
async def simple_handler(conn):
await conn.close()
listener = transport.create_listener(simple_handler)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
assert test_addr.value_for_protocol("tcp") == "0"
assert "ws" in str(test_addr)
await listener.close()
@pytest.mark.trio
async def test_websocket_real_connection():
"""Test WebSocket transport creation and basic functionality"""
# Create WebSocket transport
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
@ -526,59 +532,57 @@ async def test_websocket_real_connection():
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert transport is not None
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
async def handler(conn):
await conn.close()
listener = transport.create_listener(handler)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
await listener.close()
@pytest.mark.trio
async def test_websocket_with_tcp_fallback():
"""Test WebSocket functionality using TCP transport as fallback"""
from tests.utils.factories import host_pair_factory
async with host_pair_factory() as (host_a, host_b):
assert len(host_a.get_network().connections) > 0
assert len(host_b.get_network().connections) > 0
test_protocol = TProtocol("/test/protocol/1.0.0")
received_data = None
async def test_handler(stream):
nonlocal received_data
received_data = await stream.read(1024)
await stream.write(b"Response from TCP")
await stream.close()
host_a.set_stream_handler(test_protocol, test_handler)
stream = await host_b.new_stream(host_a.get_id(), [test_protocol])
test_data = b"TCP protocol test"
await stream.write(test_data)
response = await stream.read(1024)
assert received_data == test_data
assert response == b"Response from TCP"
await stream.close()
@pytest.mark.trio
async def test_websocket_transport_interface():
"""Test WebSocket transport interface compliance"""
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
secure_transports_by_protocol={
@ -586,23 +590,26 @@ async def test_websocket_transport_interface():
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert hasattr(transport, "dial")
assert hasattr(transport, "create_listener")
assert callable(transport.dial)
assert callable(transport.create_listener)
listener = transport.create_listener(lambda conn: None)
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
async def dummy_handler(conn):
await trio.sleep(0)
listener = transport.create_listener(dummy_handler)
assert hasattr(listener, "listen")
assert hasattr(listener, "close")
assert hasattr(listener, "get_addrs")
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
host = test_addr.value_for_protocol("ip4")
port = test_addr.value_for_protocol("tcp")
assert host == "127.0.0.1"
assert port == "8080"
await listener.close()

View File

@ -20,7 +20,7 @@ from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0"
@pytest.mark.trio
@ -74,6 +74,11 @@ async def test_ping_with_js_node():
peer_id = ID.from_base58(peer_id_line)
maddr = Multiaddr(addr_line)
# Debug: Print what we're trying to connect to
print(f"JS Node Peer ID: {peer_id_line}")
print(f"JS Node Address: {addr_line}")
print(f"All JS Node lines: {lines}")
# Set up Python host
key_pair = create_new_key_pair()
py_peer_id = ID.from_pubkey(key_pair.public_key)
@ -86,13 +91,15 @@ async def test_ping_with_js_node():
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport()
transport = WebsocketTransport(upgrader)
swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm)
# Connect to JS node
peer_info = PeerInfo(peer_id, [maddr])
print(f"Python trying to connect to: {peer_info}")
await trio.sleep(1)
try: