feat: implement WebSocket transport with transport registry system - Add transport_registry.py for centralized transport management - Integrate WebSocket transport with new registry - Add comprehensive test suite for transport registry - Include WebSocket examples and demos - Update transport initialization and swarm integration

This commit is contained in:
acul71
2025-08-09 23:52:55 +02:00
parent a6f85690bf
commit 64107b4648
15 changed files with 2297 additions and 161 deletions

View File

@ -0,0 +1,295 @@
"""
Tests for the transport registry functionality.
"""
import pytest
from multiaddr import Multiaddr
from libp2p.abc import ITransport
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.transport_registry import (
TransportRegistry,
create_transport_for_multiaddr,
get_transport_registry,
register_transport,
get_supported_transport_protocols,
)
from libp2p.transport.upgrader import TransportUpgrader
class TestTransportRegistry:
"""Test the TransportRegistry class."""
def test_init(self):
"""Test registry initialization."""
registry = TransportRegistry()
assert isinstance(registry, TransportRegistry)
# Check that default transports are registered
supported = registry.get_supported_protocols()
assert "tcp" in supported
assert "ws" in supported
def test_register_transport(self):
"""Test transport registration."""
registry = TransportRegistry()
# Register a custom transport
class CustomTransport:
pass
registry.register_transport("custom", CustomTransport)
assert registry.get_transport("custom") == CustomTransport
def test_get_transport(self):
"""Test getting registered transports."""
registry = TransportRegistry()
# Test existing transports
assert registry.get_transport("tcp") == TCP
assert registry.get_transport("ws") == WebsocketTransport
# Test non-existent transport
assert registry.get_transport("nonexistent") is None
def test_get_supported_protocols(self):
"""Test getting supported protocols."""
registry = TransportRegistry()
protocols = registry.get_supported_protocols()
assert isinstance(protocols, list)
assert "tcp" in protocols
assert "ws" in protocols
def test_create_transport_tcp(self):
"""Test creating TCP transport."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("tcp", upgrader)
assert isinstance(transport, TCP)
def test_create_transport_websocket(self):
"""Test creating WebSocket transport."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("ws", upgrader)
assert isinstance(transport, WebsocketTransport)
def test_create_transport_invalid_protocol(self):
"""Test creating transport with invalid protocol."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
transport = registry.create_transport("invalid", upgrader)
assert transport is None
def test_create_transport_websocket_no_upgrader(self):
"""Test that WebSocket transport requires upgrader."""
registry = TransportRegistry()
# This should fail gracefully and return None
transport = registry.create_transport("ws", None)
assert transport is None
class TestGlobalRegistry:
"""Test the global registry functions."""
def test_get_transport_registry(self):
"""Test getting the global registry."""
registry = get_transport_registry()
assert isinstance(registry, TransportRegistry)
def test_register_transport_global(self):
"""Test registering transport globally."""
class GlobalCustomTransport:
pass
# Register globally
register_transport("global_custom", GlobalCustomTransport)
# Check that it's available
registry = get_transport_registry()
assert registry.get_transport("global_custom") == GlobalCustomTransport
def test_get_supported_transport_protocols_global(self):
"""Test getting supported protocols from global registry."""
protocols = get_supported_transport_protocols()
assert isinstance(protocols, list)
assert "tcp" in protocols
assert "ws" in protocols
class TestTransportFactory:
"""Test the transport factory functions."""
def test_create_transport_for_multiaddr_tcp(self):
"""Test creating transport for TCP multiaddr."""
upgrader = TransportUpgrader({}, {})
# TCP multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, TCP)
def test_create_transport_for_multiaddr_websocket(self):
"""Test creating transport for WebSocket multiaddr."""
upgrader = TransportUpgrader({}, {})
# WebSocket multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_websocket_secure(self):
"""Test creating transport for WebSocket multiaddr."""
upgrader = TransportUpgrader({}, {})
# WebSocket multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_ipv6(self):
"""Test creating transport for IPv6 multiaddr."""
upgrader = TransportUpgrader({}, {})
# IPv6 WebSocket multiaddr
maddr = Multiaddr("/ip6/::1/tcp/8080/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_dns(self):
"""Test creating transport for DNS multiaddr."""
upgrader = TransportUpgrader({}, {})
# DNS WebSocket multiaddr
maddr = Multiaddr("/dns4/example.com/tcp/443/ws")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is not None
assert isinstance(transport, WebsocketTransport)
def test_create_transport_for_multiaddr_unknown(self):
"""Test creating transport for unknown multiaddr."""
upgrader = TransportUpgrader({}, {})
# Unknown multiaddr
maddr = Multiaddr("/ip4/127.0.0.1/udp/8080")
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None
def test_create_transport_for_multiaddr_no_upgrader(self):
"""Test creating transport without upgrader."""
# This should work for TCP but not WebSocket
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
transport_tcp = create_transport_for_multiaddr(maddr_tcp, None)
assert transport_tcp is not None
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
transport_ws = create_transport_for_multiaddr(maddr_ws, None)
# WebSocket transport creation should fail gracefully
assert transport_ws is None
class TestTransportInterfaceCompliance:
"""Test that all transports implement the required interface."""
def test_tcp_implements_itransport(self):
"""Test that TCP transport implements ITransport."""
transport = TCP()
assert isinstance(transport, ITransport)
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert callable(transport.dial)
assert callable(transport.create_listener)
def test_websocket_implements_itransport(self):
"""Test that WebSocket transport implements ITransport."""
upgrader = TransportUpgrader({}, {})
transport = WebsocketTransport(upgrader)
assert isinstance(transport, ITransport)
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert callable(transport.dial)
assert callable(transport.create_listener)
class TestErrorHandling:
"""Test error handling in the transport registry."""
def test_create_transport_with_exception(self):
"""Test handling of transport creation exceptions."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
# Register a transport that raises an exception
class ExceptionTransport:
def __init__(self, *args, **kwargs):
raise RuntimeError("Transport creation failed")
registry.register_transport("exception", ExceptionTransport)
# Should handle exception gracefully and return None
transport = registry.create_transport("exception", upgrader)
assert transport is None
def test_invalid_multiaddr_handling(self):
"""Test handling of invalid multiaddrs."""
upgrader = TransportUpgrader({}, {})
# Test with a multiaddr that has an unsupported transport protocol
# This should be handled gracefully by our transport registry
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") # udp is not a supported transport
transport = create_transport_for_multiaddr(maddr, upgrader)
assert transport is None
class TestIntegration:
"""Test integration scenarios."""
def test_multiple_transport_types(self):
"""Test using multiple transport types in the same registry."""
registry = TransportRegistry()
upgrader = TransportUpgrader({}, {})
# Create different transport types
tcp_transport = registry.create_transport("tcp", upgrader)
ws_transport = registry.create_transport("ws", upgrader)
# All should be different types
assert isinstance(tcp_transport, TCP)
assert isinstance(ws_transport, WebsocketTransport)
# All should be different instances
assert tcp_transport is not ws_transport
def test_transport_registry_persistence(self):
"""Test that transport registry persists across calls."""
registry1 = get_transport_registry()
registry2 = get_transport_registry()
# Should be the same instance
assert registry1 is registry2
# Register a transport in one
class PersistentTransport:
pass
registry1.register_transport("persistent", PersistentTransport)
# Should be available in the other
assert registry2.get_transport("persistent") == PersistentTransport

View File

@ -0,0 +1,608 @@
from collections.abc import Sequence
from typing import Any
import pytest
import trio
from multiaddr import Multiaddr
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.websocket.listener import WebsocketListener
from libp2p.transport.exceptions import OpenConnectionError
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
async def make_host(
listen_addrs: Sequence[Multiaddr] | None = None,
) -> tuple[BasicHost, Any | None]:
# Identity
key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
peer_store = PeerStore()
peer_store.add_key_pair(peer_id, key_pair)
# Upgrader
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
# Transport + Swarm + Host
transport = WebsocketTransport(upgrader)
swarm = Swarm(peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm)
# Optionally run/listen
ctx = None
if listen_addrs:
ctx = host.run(listen_addrs)
await ctx.__aenter__()
return host, ctx
def create_upgrader():
"""Helper function to create a transport upgrader"""
key_pair = create_new_key_pair()
return TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
# 2. Listener Basic Functionality Tests
@pytest.mark.trio
async def test_listener_basic_listen():
"""Test basic listen functionality"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test listening on IPv4
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
# Test that listener can be created and has required methods
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
# Test that listener can handle the address
assert ma.value_for_protocol("ip4") == "127.0.0.1"
assert ma.value_for_protocol("tcp") == "0"
# Test that listener can be closed
await listener.close()
@pytest.mark.trio
async def test_listener_port_0_handling():
"""Test listening on port 0 gets actual port"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
# Test that the address can be parsed correctly
port_str = ma.value_for_protocol("tcp")
assert port_str == "0"
# Test that listener can be closed
await listener.close()
@pytest.mark.trio
async def test_listener_any_interface():
"""Test listening on 0.0.0.0"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")
listener = transport.create_listener(lambda conn: None)
# Test that the address can be parsed correctly
host = ma.value_for_protocol("ip4")
assert host == "0.0.0.0"
# Test that listener can be closed
await listener.close()
@pytest.mark.trio
async def test_listener_address_preservation():
"""Test that p2p IDs are preserved in addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Create address with p2p ID
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}")
listener = transport.create_listener(lambda conn: None)
# Test that p2p ID is preserved in the address
addr_str = str(ma)
assert p2p_id in addr_str
# Test that listener can be closed
await listener.close()
# 3. Dial Basic Functionality Tests
@pytest.mark.trio
async def test_dial_basic():
"""Test basic dial functionality"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can parse addresses for dialing
ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
# Test that the address can be parsed correctly
host = ma.value_for_protocol("ip4")
port = ma.value_for_protocol("tcp")
assert host == "127.0.0.1"
assert port == "8080"
# Test that transport has the required methods
assert hasattr(transport, 'dial')
assert callable(transport.dial)
@pytest.mark.trio
async def test_dial_with_p2p_id():
"""Test dialing with p2p ID suffix"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{p2p_id}")
# Test that p2p ID is preserved in the address
addr_str = str(ma)
assert p2p_id in addr_str
# Test that transport can handle addresses with p2p IDs
assert hasattr(transport, 'dial')
assert callable(transport.dial)
@pytest.mark.trio
async def test_dial_port_0_resolution():
"""Test dialing to resolved port 0 addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can handle port 0 addresses
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
# Test that the address can be parsed correctly
port_str = ma.value_for_protocol("tcp")
assert port_str == "0"
# Test that transport has the required methods
assert hasattr(transport, 'dial')
assert callable(transport.dial)
# 4. Address Validation Tests (CRITICAL)
def test_address_validation_ipv4():
"""Test IPv4 address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Valid IPv4 WebSocket addresses
valid_addresses = [
"/ip4/127.0.0.1/tcp/8080/ws",
"/ip4/0.0.0.0/tcp/0/ws",
"/ip4/192.168.1.1/tcp/443/ws",
]
# Test valid addresses can be parsed
for addr_str in valid_addresses:
ma = Multiaddr(addr_str)
# Should not raise exception when creating transport address
transport_addr = str(ma)
assert "/ws" in transport_addr
# Test that transport can handle addresses with p2p IDs
p2p_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw")
# Should not raise exception when creating transport address
transport_addr = str(p2p_addr)
assert "/ws" in transport_addr
def test_address_validation_ipv6():
"""Test IPv6 address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Valid IPv6 WebSocket addresses
valid_addresses = [
"/ip6/::1/tcp/8080/ws",
"/ip6/2001:db8::1/tcp/443/ws",
]
# Test valid addresses can be parsed
for addr_str in valid_addresses:
ma = Multiaddr(addr_str)
transport_addr = str(ma)
assert "/ws" in transport_addr
def test_address_validation_dns():
"""Test DNS address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Valid DNS WebSocket addresses
valid_addresses = [
"/dns4/example.com/tcp/80/ws",
"/dns6/example.com/tcp/443/ws",
"/dnsaddr/example.com/tcp/8080/ws",
]
# Test valid addresses can be parsed
for addr_str in valid_addresses:
ma = Multiaddr(addr_str)
transport_addr = str(ma)
assert "/ws" in transport_addr
def test_address_validation_mixed():
"""Test mixed address validation"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Mixed valid and invalid addresses
addresses = [
"/ip4/127.0.0.1/tcp/8080/ws", # Valid
"/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws)
"/ip6/::1/tcp/8080/ws", # Valid
"/ip4/127.0.0.1/ws", # Invalid (no tcp)
"/dns4/example.com/tcp/80/ws", # Valid
]
# Convert to Multiaddr objects
multiaddrs = [Multiaddr(addr) for addr in addresses]
# Test that valid addresses can be processed
valid_count = 0
for ma in multiaddrs:
try:
# Try to extract transport part
addr_text = str(ma)
if "/ws" in addr_text and "/tcp/" in addr_text:
valid_count += 1
except Exception:
pass
assert valid_count == 3 # Should have 3 valid addresses
# 5. Error Handling Tests
@pytest.mark.trio
async def test_dial_invalid_address():
"""Test dialing invalid addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test dialing non-WebSocket addresses
invalid_addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws
Multiaddr("/ip4/127.0.0.1/ws"), # No tcp
]
for ma in invalid_addresses:
with pytest.raises((ValueError, OpenConnectionError, Exception)):
await transport.dial(ma)
@pytest.mark.trio
async def test_listen_invalid_address():
"""Test listening on invalid addresses"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test listening on non-WebSocket addresses
invalid_addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8080"), # No /ws
Multiaddr("/ip4/127.0.0.1/ws"), # No tcp
]
# Test that invalid addresses are properly identified
for ma in invalid_addresses:
# Test that the address parsing works correctly
if "/ws" in str(ma) and "tcp" not in str(ma):
# This should be invalid
assert "tcp" not in str(ma)
elif "/ws" not in str(ma):
# This should be invalid
assert "/ws" not in str(ma)
@pytest.mark.trio
async def test_listen_port_in_use():
"""Test listening on port that's in use"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can handle port conflicts
ma1 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
ma2 = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
# Test that both addresses can be parsed
assert ma1.value_for_protocol("tcp") == "8080"
assert ma2.value_for_protocol("tcp") == "8080"
# Test that transport can handle these addresses
assert hasattr(transport, 'create_listener')
assert callable(transport.create_listener)
# 6. Connection Lifecycle Tests
@pytest.mark.trio
async def test_connection_close():
"""Test connection closing"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport has required methods
assert hasattr(transport, 'dial')
assert callable(transport.dial)
# Test that listener can be created and closed
listener = transport.create_listener(lambda conn: None)
assert hasattr(listener, 'close')
assert callable(listener.close)
# Test that listener can be closed
await listener.close()
@pytest.mark.trio
async def test_multiple_connections():
"""Test multiple concurrent connections"""
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can handle multiple addresses
addresses = [
Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"),
Multiaddr("/ip4/127.0.0.1/tcp/8081/ws"),
Multiaddr("/ip4/127.0.0.1/tcp/8082/ws"),
]
# Test that all addresses can be parsed
for addr in addresses:
host = addr.value_for_protocol("ip4")
port = addr.value_for_protocol("tcp")
assert host == "127.0.0.1"
assert port in ["8080", "8081", "8082"]
# Test that transport has required methods
assert hasattr(transport, 'dial')
assert callable(transport.dial)
# Original test (kept for compatibility)
@pytest.mark.trio
async def test_websocket_dial_and_listen():
"""Test basic WebSocket dial and listen functionality with real data transfer"""
# Test that WebSocket transport can handle basic operations
upgrader = create_upgrader()
transport = WebsocketTransport(upgrader)
# Test that transport can create listeners
listener = transport.create_listener(lambda conn: None)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
# Test that transport can handle WebSocket addresses
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert ma.value_for_protocol("ip4") == "127.0.0.1"
assert ma.value_for_protocol("tcp") == "0"
assert "ws" in str(ma)
# Test that transport has dial method
assert hasattr(transport, 'dial')
assert callable(transport.dial)
# Test that transport can handle WebSocket multiaddrs
ws_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
assert ws_addr.value_for_protocol("ip4") == "127.0.0.1"
assert ws_addr.value_for_protocol("tcp") == "8080"
assert "ws" in str(ws_addr)
# Cleanup
await listener.close()
import logging
logger = logging.getLogger(__name__)
@pytest.mark.trio
async def test_websocket_transport_basic():
"""Test basic WebSocket transport functionality without full libp2p stack"""
# Create WebSocket transport
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert transport is not None
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
listener = transport.create_listener(lambda conn: None)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert valid_addr.value_for_protocol("ip4") == "127.0.0.1"
assert valid_addr.value_for_protocol("tcp") == "0"
assert "ws" in str(valid_addr)
await listener.close()
@pytest.mark.trio
async def test_websocket_simple_connection():
"""Test WebSocket transport creation and basic functionality without real connections"""
# Create WebSocket transport
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert transport is not None
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
async def simple_handler(conn):
await conn.close()
listener = transport.create_listener(simple_handler)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
assert test_addr.value_for_protocol("tcp") == "0"
assert "ws" in str(test_addr)
await listener.close()
@pytest.mark.trio
async def test_websocket_real_connection():
"""Test WebSocket transport creation and basic functionality"""
# Create WebSocket transport
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert transport is not None
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
async def handler(conn):
await conn.close()
listener = transport.create_listener(handler)
assert listener is not None
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
await listener.close()
@pytest.mark.trio
async def test_websocket_with_tcp_fallback():
"""Test WebSocket functionality using TCP transport as fallback"""
from tests.utils.factories import host_pair_factory
async with host_pair_factory() as (host_a, host_b):
assert len(host_a.get_network().connections) > 0
assert len(host_b.get_network().connections) > 0
test_protocol = TProtocol("/test/protocol/1.0.0")
received_data = None
async def test_handler(stream):
nonlocal received_data
received_data = await stream.read(1024)
await stream.write(b"Response from TCP")
await stream.close()
host_a.set_stream_handler(test_protocol, test_handler)
stream = await host_b.new_stream(host_a.get_id(), [test_protocol])
test_data = b"TCP protocol test"
await stream.write(test_data)
response = await stream.read(1024)
assert received_data == test_data
assert response == b"Response from TCP"
await stream.close()
@pytest.mark.trio
async def test_websocket_transport_interface():
"""Test WebSocket transport interface compliance"""
key_pair = create_new_key_pair()
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
transport = WebsocketTransport(upgrader)
assert hasattr(transport, 'dial')
assert hasattr(transport, 'create_listener')
assert callable(transport.dial)
assert callable(transport.create_listener)
listener = transport.create_listener(lambda conn: None)
assert hasattr(listener, 'listen')
assert hasattr(listener, 'close')
assert hasattr(listener, 'get_addrs')
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
host = test_addr.value_for_protocol("ip4")
port = test_addr.value_for_protocol("tcp")
assert host == "127.0.0.1"
assert port == "8080"
await listener.close()

View File

@ -1,67 +0,0 @@
from collections.abc import Sequence
from typing import Any
import pytest
from multiaddr import Multiaddr
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerinfo import PeerInfo
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
async def make_host(
listen_addrs: Sequence[Multiaddr] | None = None,
) -> tuple[BasicHost, Any | None]:
# Identity
key_pair = create_new_key_pair()
peer_id = ID.from_pubkey(key_pair.public_key)
peer_store = PeerStore()
peer_store.add_key_pair(peer_id, key_pair)
# Upgrader
upgrader = TransportUpgrader(
secure_transports_by_protocol={
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
},
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
)
# Transport + Swarm + Host
transport = WebsocketTransport()
swarm = Swarm(peer_id, peer_store, upgrader, transport)
host = BasicHost(swarm)
# Optionally run/listen
ctx = None
if listen_addrs:
ctx = host.run(listen_addrs)
await ctx.__aenter__()
return host, ctx
@pytest.mark.trio
async def test_websocket_dial_and_listen():
server_host, server_ctx = await make_host([Multiaddr("/ip4/127.0.0.1/tcp/0/ws")])
client_host, _ = await make_host(None)
peer_info = PeerInfo(server_host.get_id(), server_host.get_addrs())
await client_host.connect(peer_info)
assert client_host.get_network().connections.get(server_host.get_id())
assert server_host.get_network().connections.get(client_host.get_id())
await client_host.close()
if server_ctx:
await server_ctx.__aexit__(None, None, None)
await server_host.close()