Files
py-libp2p/tests/core/transport/test_websocket.py
acul71 f4d5a44521 Fix type errors and linting issues
- Fix type annotation errors in transport_registry.py and __init__.py
- Fix line length violations in test files (E501 errors)
- Fix missing return type annotations
- Fix cryptography NameAttribute type errors with type: ignore
- Fix ExceptionGroup import for cross-version compatibility
- Fix test failure in test_wss_listen_without_tls_config by handling ExceptionGroup
- Fix len() calls with None arguments in test_tcp_data_transfer.py
- Fix missing attribute access errors on interface types
- Fix boolean type expectation errors in test_js_ws_ping.py
- Fix nursery context manager type errors

All tests now pass and linting is clean.
2025-09-08 04:18:10 +02:00

1618 lines
52 KiB
Python

from collections.abc import Sequence
import logging
from typing import Any
import pytest
from exceptiongroup import ExceptionGroup
from multiaddr import Multiaddr
import trio
from libp2p.crypto.secp256k1 import create_new_key_pair
from libp2p.custom_types import TProtocol
from libp2p.host.basic_host import BasicHost
from libp2p.network.swarm import Swarm
from libp2p.peer.id import ID
from libp2p.peer.peerstore import PeerStore
from libp2p.security.insecure.transport import InsecureTransport
from libp2p.stream_muxer.yamux.yamux import Yamux
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import (
is_valid_websocket_multiaddr,
parse_websocket_multiaddr,
)
from libp2p.transport.websocket.transport import WebsocketTransport
# Module-level logger, named after this test module.
logger = logging.getLogger(name=__name__)
# Protocol ID under which the plaintext (insecure) security transport is
# registered by the upgraders built in this file.
PLAINTEXT_PROTOCOL_ID: str = "/plaintext/1.0.0"
async def make_host(
    listen_addrs: Sequence[Multiaddr] | None = None,
) -> tuple[BasicHost, Any | None]:
    """
    Build a BasicHost on top of a WebsocketTransport (plaintext + yamux).

    One freshly generated secp256k1 key pair serves both as the host identity
    (registered in a new PeerStore) and as the InsecureTransport's key.

    :param listen_addrs: when non-empty, ``host.run(listen_addrs)`` is entered
        before returning; when ``None`` or empty the host is not started.
    :return: ``(host, ctx)`` where ``ctx`` is the already-entered
        ``host.run(...)`` context manager, or ``None`` if it was not started.
        The context is left entered; exiting it is up to the caller.
    """
    keys = create_new_key_pair()
    local_id = ID.from_pubkey(keys.public_key)
    store = PeerStore()
    store.add_key_pair(local_id, keys)

    security = {TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(keys)}
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    upgrader = TransportUpgrader(
        secure_transports_by_protocol=security,
        muxer_transports_by_protocol=muxers,
    )

    host = BasicHost(
        Swarm(local_id, store, upgrader, WebsocketTransport(upgrader))
    )

    if not listen_addrs:
        return host, None

    run_ctx = host.run(listen_addrs)
    await run_ctx.__aenter__()
    return host, run_ctx
def create_upgrader() -> TransportUpgrader:
    """
    Build a TransportUpgrader for tests: plaintext security plus yamux muxing.

    A new secp256k1 key pair is generated on every call and handed to the
    InsecureTransport, so each upgrader gets its own local key.

    :return: the configured ``TransportUpgrader``.
    """
    key_pair = create_new_key_pair()
    return TransportUpgrader(
        secure_transports_by_protocol={
            TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
        },
        muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
    )
# 2. Listener Basic Functionality Tests
@pytest.mark.trio
async def test_listener_basic_listen():
    """Create a WS listener, check its API surface, parse an IPv4 /ws addr, close."""
    transport = WebsocketTransport(create_upgrader())
    ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    # The listener must expose the three core listener methods.
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)
    assert ma.value_for_protocol("ip4") == "127.0.0.1"
    assert ma.value_for_protocol("tcp") == "0"
    await listener.close()
@pytest.mark.trio
async def test_listener_port_0_handling():
    """
    Check that a port-0 ``/ws`` address parses and a listener closes cleanly.

    NOTE(review): no socket is bound here, so the OS-assigned port is never
    observed; only the literal ``"0"`` in the multiaddr is checked.
    """
    transport = WebsocketTransport(create_upgrader())
    ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    assert ma.value_for_protocol("tcp") == "0"
    await listener.close()
@pytest.mark.trio
async def test_listener_any_interface():
    """Parse a wildcard (0.0.0.0) /ws address and close a listener cleanly."""
    transport = WebsocketTransport(create_upgrader())
    wildcard = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    assert wildcard.value_for_protocol("ip4") == "0.0.0.0"
    await listener.close()
@pytest.mark.trio
async def test_listener_address_preservation():
    """The /p2p/<peer-id> suffix survives in a /ws multiaddr's string form."""
    transport = WebsocketTransport(create_upgrader())
    peer_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
    ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{peer_id}")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    assert peer_id in str(ma)
    await listener.close()
# 3. Dial Basic Functionality Tests
@pytest.mark.trio
async def test_dial_basic():
    """A /ws dial target parses to its ip4/tcp parts and ``dial`` is callable."""
    transport = WebsocketTransport(create_upgrader())
    target = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    parts = (target.value_for_protocol("ip4"), target.value_for_protocol("tcp"))
    assert parts == ("127.0.0.1", "8080")
    assert callable(getattr(transport, "dial", None))
@pytest.mark.trio
async def test_dial_with_p2p_id():
    """A dial target with a /p2p suffix keeps its peer ID; ``dial`` is callable."""
    transport = WebsocketTransport(create_upgrader())
    peer_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
    target = Multiaddr(f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{peer_id}")
    assert peer_id in str(target)
    assert callable(getattr(transport, "dial", None))
@pytest.mark.trio
async def test_dial_port_0_resolution():
    """
    A port-0 /ws address parses and the transport exposes a callable ``dial``.

    NOTE(review): no port is actually resolved or dialled here.
    """
    transport = WebsocketTransport(create_upgrader())
    zero_port_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert zero_port_addr.value_for_protocol("tcp") == "0"
    assert callable(getattr(transport, "dial", None))
# 4. Address Validation Tests (CRITICAL)
def test_address_validation_ipv4():
    """IPv4 /ws multiaddrs (with and without a /p2p suffix) round-trip via str()."""
    peer_id = "Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw"
    candidates = (
        "/ip4/127.0.0.1/tcp/8080/ws",
        "/ip4/0.0.0.0/tcp/0/ws",
        "/ip4/192.168.1.1/tcp/443/ws",
        f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{peer_id}",
    )
    for text in candidates:
        # Constructing the Multiaddr is itself part of the check.
        rendered = str(Multiaddr(text))
        assert "/ws" in rendered
def test_address_validation_ipv6():
    """IPv6 /ws multiaddrs can be constructed and keep /ws in their string form."""
    for text in ("/ip6/::1/tcp/8080/ws", "/ip6/2001:db8::1/tcp/443/ws"):
        assert "/ws" in str(Multiaddr(text))
def test_address_validation_dns():
    """dns4/dns6/dnsaddr /ws multiaddrs can be constructed and keep /ws."""
    dns_variants = (
        "/dns4/example.com/tcp/80/ws",
        "/dns6/example.com/tcp/443/ws",
        "/dnsaddr/example.com/tcp/8080/ws",
    )
    for text in dns_variants:
        rendered = str(Multiaddr(text))
        assert "/ws" in rendered
def test_address_validation_mixed():
    """
    Validate a mix of well-formed and malformed WebSocket multiaddrs.

    Uses the transport's own ``is_valid_websocket_multiaddr`` rather than
    substring heuristics, so the test exercises real validation logic. The
    previous ``try/except Exception: pass`` around ``str()`` could only hide
    errors and has been removed.
    """
    # address -> whether it should be accepted as a WebSocket multiaddr
    expected = {
        "/ip4/127.0.0.1/tcp/8080/ws": True,
        "/ip4/127.0.0.1/tcp/8080": False,  # Invalid (no /ws)
        "/ip6/::1/tcp/8080/ws": True,
        "/ip4/127.0.0.1/ws": False,  # Invalid (no tcp)
        "/dns4/example.com/tcp/80/ws": True,
    }
    actual = {
        addr: bool(is_valid_websocket_multiaddr(Multiaddr(addr)))
        for addr in expected
    }
    assert actual == expected
    assert sum(actual.values()) == 3  # Should have 3 valid addresses
# 5. Error Handling Tests
@pytest.mark.trio
async def test_dial_invalid_address():
    """Dialling a non-WebSocket or TCP-less multiaddr must raise."""
    transport = WebsocketTransport(create_upgrader())
    bad_targets = (
        Multiaddr("/ip4/127.0.0.1/tcp/8080"),  # No /ws
        Multiaddr("/ip4/127.0.0.1/ws"),  # No tcp
    )
    for target in bad_targets:
        with pytest.raises(Exception):
            await transport.dial(target)
@pytest.mark.trio
async def test_listen_invalid_address():
    """
    Addresses a WebSocket listener cannot serve must fail validation.

    Each address is checked with ``is_valid_websocket_multiaddr``, the same
    validator exercised by the WSS validation tests, so the test fails if
    either malformed form is ever accepted.
    """
    invalid_addresses = [
        Multiaddr("/ip4/127.0.0.1/tcp/8080"),  # No /ws
        Multiaddr("/ip4/127.0.0.1/ws"),  # No tcp
    ]
    for ma in invalid_addresses:
        assert not is_valid_websocket_multiaddr(ma), f"Address {ma} should be invalid"
@pytest.mark.trio
async def test_listen_port_in_use():
    """
    Two identical fixed-port /ws addresses parse; ``create_listener`` is callable.

    NOTE(review): nothing is bound, so an actual port conflict is not exercised.
    """
    transport = WebsocketTransport(create_upgrader())
    same_port = [Multiaddr("/ip4/127.0.0.1/tcp/8080/ws") for _ in range(2)]
    for ma in same_port:
        assert ma.value_for_protocol("tcp") == "8080"
    assert callable(getattr(transport, "create_listener", None))
# 6. Connection Lifecycle Tests
@pytest.mark.trio
async def test_connection_close():
    """The transport can dial, and a freshly created listener can be closed."""
    transport = WebsocketTransport(create_upgrader())
    assert callable(getattr(transport, "dial", None))

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    assert callable(getattr(listener, "close", None))
    await listener.close()
@pytest.mark.trio
async def test_multiple_connections():
    """
    Several /ws addresses on distinct ports parse correctly; ``dial`` exists.

    NOTE(review): no connections are opened; only address parsing is checked.
    """
    transport = WebsocketTransport(create_upgrader())
    ports = ("8080", "8081", "8082")
    targets = [Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/ws") for port in ports]
    for target in targets:
        assert target.value_for_protocol("ip4") == "127.0.0.1"
        assert target.value_for_protocol("tcp") in ports
    assert callable(getattr(transport, "dial", None))
# Original test (kept for compatibility)
@pytest.mark.trio
async def test_websocket_dial_and_listen():
    """
    Smoke-test the dial/listen surface of WebsocketTransport.

    No connection is opened and no data is transferred: the test checks that a
    listener can be created and closed, that ``dial`` is callable, and that
    ``/ws`` multiaddrs expose the expected ip4/tcp values. Real data exchange
    is covered by ``test_websocket_data_exchange``.
    """
    transport = WebsocketTransport(create_upgrader())

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    try:
        assert listener is not None
        assert hasattr(listener, "listen")
        assert hasattr(listener, "close")
        assert hasattr(listener, "get_addrs")

        # Port-0 listen-style address.
        ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
        assert ma.value_for_protocol("ip4") == "127.0.0.1"
        assert ma.value_for_protocol("tcp") == "0"
        assert "ws" in str(ma)

        assert hasattr(transport, "dial")
        assert callable(transport.dial)

        # Fixed-port dial-style address.
        ws_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
        assert ws_addr.value_for_protocol("ip4") == "127.0.0.1"
        assert ws_addr.value_for_protocol("tcp") == "8080"
        assert "ws" in str(ws_addr)
    finally:
        # Close the listener even when an assertion above fails.
        await listener.close()
@pytest.mark.trio
async def test_websocket_transport_basic():
    """
    Test basic WebSocket transport functionality without full libp2p stack.

    Builds the transport from the shared ``create_upgrader()`` helper (the same
    plaintext + yamux upgrader this test previously assembled by hand), checks
    the transport and listener APIs, and parses a port-0 ``/ws`` address.
    """
    transport = WebsocketTransport(create_upgrader())
    assert transport is not None
    assert hasattr(transport, "dial")
    assert hasattr(transport, "create_listener")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    assert listener is not None
    assert hasattr(listener, "listen")
    assert hasattr(listener, "close")
    assert hasattr(listener, "get_addrs")

    valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert valid_addr.value_for_protocol("ip4") == "127.0.0.1"
    assert valid_addr.value_for_protocol("tcp") == "0"
    assert "ws" in str(valid_addr)

    await listener.close()
@pytest.mark.trio
async def test_websocket_simple_connection():
    """
    Test WebSocket transport creation and basic functionality without real conn.

    Uses the shared ``create_upgrader()`` helper instead of re-building the
    plaintext + yamux upgrader inline. The handler would close any connection
    it received, but no connection is made here.
    """
    transport = WebsocketTransport(create_upgrader())
    assert transport is not None
    assert hasattr(transport, "dial")
    assert hasattr(transport, "create_listener")

    async def simple_handler(conn):
        await conn.close()

    listener = transport.create_listener(simple_handler)
    assert listener is not None
    assert hasattr(listener, "listen")
    assert hasattr(listener, "close")
    assert hasattr(listener, "get_addrs")

    test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
    assert test_addr.value_for_protocol("tcp") == "0"
    assert "ws" in str(test_addr)

    await listener.close()
@pytest.mark.trio
async def test_websocket_real_connection():
    """
    Test WebSocket transport creation and basic functionality.

    NOTE(review): despite the name, no connection is established; this only
    creates a transport (via the shared ``create_upgrader()`` helper) and a
    listener, checks their APIs, and closes the listener.
    """
    transport = WebsocketTransport(create_upgrader())
    assert transport is not None
    assert hasattr(transport, "dial")
    assert hasattr(transport, "create_listener")

    async def handler(conn):
        await conn.close()

    listener = transport.create_listener(handler)
    assert listener is not None
    assert hasattr(listener, "listen")
    assert hasattr(listener, "close")
    assert hasattr(listener, "get_addrs")

    await listener.close()
@pytest.mark.trio
async def test_websocket_with_tcp_fallback():
    """
    Stream round-trip between two hosts produced by ``host_pair_factory``.

    NOTE(review): despite the name, this runs over the factory's default
    transport (TCP, per the payload strings), not over WebSocket.
    """
    from tests.utils.factories import host_pair_factory

    async with host_pair_factory() as (server, client):
        assert len(server.get_network().connections) > 0
        assert len(client.get_network().connections) > 0

        proto = TProtocol("/test/protocol/1.0.0")
        inbox: dict[str, bytes] = {}

        async def reply_handler(stream):
            inbox["request"] = await stream.read(1024)
            await stream.write(b"Response from TCP")
            await stream.close()

        server.set_stream_handler(proto, reply_handler)
        stream = await client.new_stream(server.get_id(), [proto])
        request = b"TCP protocol test"
        await stream.write(request)
        reply = await stream.read(1024)
        assert inbox.get("request") == request
        assert reply == b"Response from TCP"
        await stream.close()
@pytest.mark.trio
async def test_websocket_data_exchange():
    """
    Echo a payload between two hosts over a WebSocket connection.

    Both hosts use plaintext security and yamux and are given a ``/ws`` listen
    address. Host B connects to host A's ``/ws`` address, opens a stream, and
    both the server-side received bytes and the echoed reply are verified.
    """
    from libp2p import create_yamux_muxer_option, new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair
    from libp2p.custom_types import TProtocol
    from libp2p.peer.peerinfo import info_from_p2p_addr
    from libp2p.security.insecure.transport import (
        PLAINTEXT_PROTOCOL_ID,
        InsecureTransport,
    )

    ws_listen = "/ip4/127.0.0.1/tcp/0/ws"

    def build_ws_host(keys):
        # Plaintext security + yamux; the /ws listen address selects the
        # WebSocket transport for both the listener and the dialer.
        return new_host(
            key_pair=keys,
            sec_opt={
                PLAINTEXT_PROTOCOL_ID: InsecureTransport(
                    local_key_pair=keys, secure_bytes_provider=None, peerstore=None
                )
            },
            muxer_opt=create_yamux_muxer_option(),
            listen_addrs=[Multiaddr(ws_listen)],
        )

    key_a = create_new_key_pair()
    key_b = create_new_key_pair()
    host_a = build_ws_host(key_a)  # listener
    host_b = build_ws_host(key_b)  # dialer

    payload = b"Hello WebSocket Data Exchange!"
    captured: dict[str, bytes] = {}
    proto = TProtocol("/test/websocket/data/1.0.0")

    async def echo_handler(stream):
        captured["data"] = await stream.read(len(payload))
        await stream.write(captured["data"])  # Echo back
        await stream.close()

    host_a.set_stream_handler(proto, echo_handler)

    async with (
        host_a.run(listen_addrs=[Multiaddr(ws_listen)]),
        host_b.run(listen_addrs=[]),
    ):
        addrs = host_a.get_addrs()
        assert len(addrs) > 0
        ws_addr = next((a for a in addrs if "/ws" in str(a)), None)
        assert ws_addr is not None, "No WebSocket listen address found"

        await host_b.connect(info_from_p2p_addr(ws_addr))

        stream = await host_b.new_stream(host_a.get_id(), [proto])
        await stream.write(payload)
        echoed = await stream.read(len(payload))
        await stream.close()

        received = captured.get("data")
        assert received == payload, f"Expected {payload}, got {received}"
        assert echoed == payload, f"Expected echo {payload}, got {echoed}"
@pytest.mark.trio
async def test_websocket_host_pair_data_exchange():
    """
    Echo data between two WebSocket hosts connected host_pair_factory-style.

    Both hosts use plaintext security and yamux. Host B connects to host A's
    ``/ws`` listen address. After a short settle delay both swarms must report
    a connection, and then a stream round-trip is verified.
    """
    from libp2p import create_yamux_muxer_option, new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair
    from libp2p.custom_types import TProtocol
    from libp2p.peer.peerinfo import info_from_p2p_addr
    from libp2p.security.insecure.transport import (
        PLAINTEXT_PROTOCOL_ID,
        InsecureTransport,
    )

    ws_listen = "/ip4/127.0.0.1/tcp/0/ws"

    def build_ws_host(keys):
        # The /ws listen address selects the WebSocket transport, even for
        # the host that only dials.
        plaintext = InsecureTransport(
            local_key_pair=keys, secure_bytes_provider=None, peerstore=None
        )
        return new_host(
            key_pair=keys,
            sec_opt={PLAINTEXT_PROTOCOL_ID: plaintext},
            muxer_opt=create_yamux_muxer_option(),
            listen_addrs=[Multiaddr(ws_listen)],
        )

    key_a = create_new_key_pair()
    key_b = create_new_key_pair()
    host_a = build_ws_host(key_a)  # listener
    host_b = build_ws_host(key_b)  # dialer

    payload = b"Hello WebSocket Host Pair Data Exchange!"
    seen: dict[str, bytes] = {}
    proto = TProtocol("/test/websocket/hostpair/1.0.0")

    async def echo(stream):
        seen["data"] = await stream.read(len(payload))
        await stream.write(seen["data"])  # Echo back
        await stream.close()

    host_a.set_stream_handler(proto, echo)

    async with (
        host_a.run(listen_addrs=[Multiaddr(ws_listen)]),
        host_b.run(listen_addrs=[]),
    ):
        addrs = host_a.get_addrs()
        assert len(addrs) > 0
        ws_addr = next((a for a in addrs if "/ws" in str(a)), None)
        assert ws_addr is not None, "No WebSocket listen address found"

        await host_b.connect(info_from_p2p_addr(ws_addr))
        # Give the connection a moment to settle on both sides.
        await trio.sleep(0.1)
        assert len(host_a.get_network().connections) > 0
        assert len(host_b.get_network().connections) > 0

        stream = await host_b.new_stream(host_a.get_id(), [proto])
        await stream.write(payload)
        echoed = await stream.read(len(payload))
        await stream.close()

        received = seen.get("data")
        assert received == payload, f"Expected {payload}, got {received}"
        assert echoed == payload, f"Expected echo {payload}, got {echoed}"
@pytest.mark.trio
async def test_wss_host_pair_data_exchange():
    """
    Echo data between two hosts over WSS (WebSocket over TLS).

    A throwaway self-signed certificate for ``localhost`` / ``127.0.0.1`` is
    generated with ``cryptography``. The test is skipped if that package is
    missing or certificate setup fails. Host A listens on ``/wss`` with a
    server TLS context; host B dials it with a client context that disables
    certificate verification. A stream round-trip is then verified.
    """
    import datetime
    import ipaddress
    import os
    import ssl
    import tempfile

    from libp2p import create_yamux_muxer_option, new_host
    from libp2p.peer.peerinfo import info_from_p2p_addr
    from libp2p.security.insecure.transport import (
        PLAINTEXT_PROTOCOL_ID,
        InsecureTransport,
    )

    try:
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes, serialization
        from cryptography.hazmat.primitives.asymmetric import rsa
        from cryptography.x509.oid import NameOID
    except ImportError:
        pytest.skip("cryptography package required for WSS tests")

    try:
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),  # type: ignore
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"),  # type: ignore
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),  # type: ignore
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"),  # type: ignore
                x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),  # type: ignore
            ]
        )
        # datetime.timezone.utc instead of datetime.UTC: the UTC alias only
        # exists on Python 3.11+. On older interpreters the AttributeError was
        # swallowed by the skip handler below, silently disabling this test.
        now = datetime.datetime.now(datetime.timezone.utc)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + datetime.timedelta(days=1))
            .add_extension(
                x509.SubjectAlternativeName(
                    [
                        x509.DNSName("localhost"),
                        x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
                    ]
                ),
                critical=False,
            )
            .sign(private_key, hashes.SHA256())
        )

        # Server context for the listener (host A). load_cert_chain() reads the
        # PEM files immediately, so they only need to exist inside a temporary
        # directory that is removed on exit (previously the delete=False temp
        # files were never cleaned up).
        server_tls_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        with tempfile.TemporaryDirectory() as cert_dir:
            cert_path = os.path.join(cert_dir, "server.crt")
            key_path = os.path.join(cert_dir, "server.key")
            with open(cert_path, "wb") as cert_out:
                cert_out.write(cert.public_bytes(serialization.Encoding.PEM))
            with open(key_path, "wb") as key_out:
                key_out.write(
                    private_key.private_bytes(
                        encoding=serialization.Encoding.PEM,
                        format=serialization.PrivateFormat.PKCS8,
                        encryption_algorithm=serialization.NoEncryption(),
                    )
                )
            server_tls_context.load_cert_chain(cert_path, key_path)

        # Client context for the dialer (host B): the certificate is
        # self-signed, so hostname checks and verification are disabled.
        client_tls_context = ssl.create_default_context()
        client_tls_context.check_hostname = False
        client_tls_context.verify_mode = ssl.CERT_NONE
    except Exception as e:
        pytest.skip(f"Failed to create test certificates: {e}")

    # Create two hosts with WSS transport and plaintext security
    key_pair_a = create_new_key_pair()
    key_pair_b = create_new_key_pair()

    # Host A (listener) - WSS transport with server TLS config
    security_options_a = {
        PLAINTEXT_PROTOCOL_ID: InsecureTransport(
            local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None
        )
    }
    host_a = new_host(
        key_pair=key_pair_a,
        sec_opt=security_options_a,
        muxer_opt=create_yamux_muxer_option(),
        listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")],
        tls_server_config=server_tls_context,
    )

    # Host B (dialer) - WSS transport with client TLS config
    security_options_b = {
        PLAINTEXT_PROTOCOL_ID: InsecureTransport(
            local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None
        )
    }
    host_b = new_host(
        key_pair=key_pair_b,
        sec_opt=security_options_b,
        muxer_opt=create_yamux_muxer_option(),
        listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")],  # Ensure WSS transport
        tls_client_config=client_tls_context,
    )

    test_data = b"Hello WSS Host Pair Data Exchange!"
    received_data = None
    test_protocol = TProtocol("/test/wss/hostpair/1.0.0")

    async def data_handler(stream):
        nonlocal received_data
        received_data = await stream.read(len(test_data))
        await stream.write(received_data)  # Echo back
        await stream.close()

    host_a.set_stream_handler(test_protocol, data_handler)

    # Start both hosts and connect them (following host_pair_factory pattern)
    async with (
        host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")]),
        host_b.run(listen_addrs=[]),
    ):
        listen_addrs = host_a.get_addrs()
        assert len(listen_addrs) > 0

        wss_addr = next((a for a in listen_addrs if "/wss" in str(a)), None)
        assert wss_addr is not None, "No WSS listen address found"

        peer_info = info_from_p2p_addr(wss_addr)
        await host_b.connect(peer_info)
        # Allow time for connection to establish (following host_pair_factory)
        await trio.sleep(0.1)

        assert len(host_a.get_network().connections) > 0
        assert len(host_b.get_network().connections) > 0

        stream = await host_b.new_stream(host_a.get_id(), [test_protocol])
        await stream.write(test_data)
        response = await stream.read(len(test_data))
        await stream.close()

        assert received_data == test_data, f"Expected {test_data}, got {received_data}"
        assert response == test_data, f"Expected echo {test_data}, got {response}"
@pytest.mark.trio
async def test_websocket_transport_interface():
    """
    Test WebSocket transport interface compliance.

    Builds the transport from the shared ``create_upgrader()`` helper (the same
    plaintext + yamux configuration previously duplicated here). Then it checks
    that ``dial``/``create_listener`` are callable, that the listener exposes
    ``listen``/``close``/``get_addrs``, and that a fixed-port ``/ws`` address
    parses.
    """
    transport = WebsocketTransport(create_upgrader())
    assert hasattr(transport, "dial")
    assert hasattr(transport, "create_listener")
    assert callable(transport.dial)
    assert callable(transport.create_listener)

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)
    assert hasattr(listener, "listen")
    assert hasattr(listener, "close")
    assert hasattr(listener, "get_addrs")

    test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
    assert test_addr.value_for_protocol("tcp") == "8080"

    await listener.close()
# ============================================================================
# WSS (WebSocket Secure) Tests
# ============================================================================
def test_wss_multiaddr_validation():
    """
    Check WSS detection and validation across ``/wss``, ``/tls/ws`` and ``/ws``.

    WSS forms must validate and parse with ``is_wss`` set. A plain ``/ws``
    address must validate but not be WSS. Addresses lacking a WebSocket
    protocol or TCP must fail validation.
    """
    wss_forms = (
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip6/::1/tcp/8080/wss",
        "/dns/localhost/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
        "/ip6/::1/tcp/8080/tls/ws",
    )
    plain_ws_forms = ("/ip4/127.0.0.1/tcp/8080/ws",)  # Regular WS, not WSS
    malformed_forms = (
        "/ip4/127.0.0.1/tcp/8080",  # No WebSocket protocol
        "/ip4/127.0.0.1/wss",  # No TCP
    )

    for text in wss_forms:
        maddr = Multiaddr(text)
        assert is_valid_websocket_multiaddr(maddr), f"Address {text} should be valid"
        assert parse_websocket_multiaddr(maddr).is_wss, (
            f"Address {text} should be parsed as WSS"
        )

    for text in plain_ws_forms:
        maddr = Multiaddr(text)
        assert is_valid_websocket_multiaddr(maddr), f"Address {text} should be valid"
        assert not parse_websocket_multiaddr(maddr).is_wss, (
            f"Address {text} should not be parsed as WSS"
        )

    for text in malformed_forms:
        assert not is_valid_websocket_multiaddr(Multiaddr(text)), (
            f"Address {text} should be invalid"
        )
def test_wss_multiaddr_parsing():
    """
    Check that ``/wss`` and ``/tls/ws`` both parse as WSS with no SNI.

    Both forms must also keep their ip4/tcp parts in ``rest_multiaddr``. A plain
    ``/ws`` address must parse as non-WSS with no SNI.
    """
    for text in ("/ip4/127.0.0.1/tcp/8080/wss", "/ip4/127.0.0.1/tcp/8080/tls/ws"):
        info = parse_websocket_multiaddr(Multiaddr(text))
        assert info.is_wss
        assert info.sni is None
        assert info.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1"
        assert info.rest_multiaddr.value_for_protocol("tcp") == "8080"

    plain = parse_websocket_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"))
    assert not plain.is_wss
    assert plain.sni is None
@pytest.mark.trio
async def test_wss_transport_creation():
    """
    Check that WebsocketTransport stores the client and server TLS contexts.

    The server context is built with ``ssl.Purpose.CLIENT_AUTH`` (a server-side
    context), the same way the WSS host-pair test configures its listener. A
    default (client-side) context is not appropriate for a TLS server.
    """
    import ssl

    client_ssl_context = ssl.create_default_context()
    # Purpose.CLIENT_AUTH yields a server-side context, which already defaults
    # to check_hostname=False and verify_mode=CERT_NONE.
    server_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

    wss_transport = WebsocketTransport(
        create_upgrader(),
        tls_client_config=client_ssl_context,
        tls_server_config=server_ssl_context,
    )
    assert wss_transport is not None
    assert hasattr(wss_transport, "dial")
    assert hasattr(wss_transport, "create_listener")
    # The transport should keep exactly the contexts it was handed.
    assert wss_transport._tls_client_config is client_ssl_context
    assert wss_transport._tls_server_config is server_ssl_context
@pytest.mark.trio
async def test_wss_transport_without_tls_config():
    """A transport built without TLS contexts still works and stores ``None``."""
    wss_transport = WebsocketTransport(create_upgrader())
    assert wss_transport is not None
    for method_name in ("dial", "create_listener"):
        assert hasattr(wss_transport, method_name)
    assert wss_transport._tls_client_config is None
    assert wss_transport._tls_server_config is None
@pytest.mark.trio
async def test_wss_dial_parsing():
    """
    Test that a ``/wss`` dial address parses into a secure WebSocket target.
    No server is running, so only multiaddr parsing is checked. The assertions
    are deliberately not wrapped in ``try/except`` so that pytest's assertion
    rewriting reports the actual failing values.
    """
    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss")
    parsed = parse_websocket_multiaddr(wss_maddr)
    assert parsed.is_wss
    assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1"
    assert parsed.rest_multiaddr.value_for_protocol("tcp") == "8080"
@pytest.mark.trio
async def test_wss_listen_parsing():
    """
    Test that a ``/wss`` listen address parses into a secure WebSocket target.
    A listener is created but never started. It is closed in ``finally`` so
    that a failing assertion does not leak it. The assertions are not wrapped
    in ``try/except`` so pytest reports the real failing values.
    """
    upgrader = create_upgrader()
    transport = WebsocketTransport(upgrader)
    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss")
    async def dummy_handler(conn):
        await trio.sleep(0)
    listener = transport.create_listener(dummy_handler)
    try:
        parsed = parse_websocket_multiaddr(wss_maddr)
        assert parsed.is_wss
        assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1"
        # Port 0 is preserved verbatim by parsing.
        assert parsed.rest_multiaddr.value_for_protocol("tcp") == "0"
    finally:
        await listener.close()
@pytest.mark.trio
async def test_wss_listen_without_tls_config():
    """Listening on a /wss address with no server TLS context must be rejected."""
    transport = WebsocketTransport(create_upgrader())  # deliberately no TLS config
    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss")
    async def dummy_handler(conn):
        await trio.sleep(0)
    listener = transport.create_listener(dummy_handler)
    # The error is raised inside the nursery block, which wraps it in an
    # ExceptionGroup; that group is what this test expects to catch.
    with pytest.raises(ExceptionGroup) as exc_info:
        async with trio.open_nursery() as nursery:
            await listener.listen(wss_maddr, nursery)
    group = exc_info.value
    assert len(group.exceptions) == 1
    inner = group.exceptions[0]
    assert isinstance(inner, ValueError)
    message = str(inner)
    assert "Cannot listen on WSS address" in message
    assert "without TLS configuration" in message
@pytest.mark.trio
async def test_wss_listen_with_tls_config():
    """
    Test that a transport with a server TLS context accepts /wss addresses.
    No certificate is loaded, so nothing is actually bound. The test checks
    that the address parses as WSS and that the server TLS context is stored.
    The context uses ``ssl.Purpose.CLIENT_AUTH`` so it is server-capable
    (``PROTOCOL_TLS_SERVER``) rather than a client-only context. The listener
    is closed in ``finally`` so a failing assertion does not leak it.
    """
    import ssl
    # Create a server-side TLS context (CLIENT_AUTH purpose => TLS server).
    server_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    server_ssl_context.check_hostname = False
    server_ssl_context.verify_mode = ssl.CERT_NONE
    assert server_ssl_context.protocol == ssl.PROTOCOL_TLS_SERVER
    upgrader = create_upgrader()
    transport = WebsocketTransport(upgrader, tls_server_config=server_ssl_context)
    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss")
    async def dummy_handler(conn):
        await trio.sleep(0)
    listener = transport.create_listener(dummy_handler)
    try:
        parsed = parse_websocket_multiaddr(wss_maddr)
        assert parsed.is_wss
        assert transport._tls_server_config is not None
    finally:
        await listener.close()
def test_wss_transport_registry():
    """The transport registry advertises ws/wss and builds WebSocket transports."""
    from libp2p.transport.transport_registry import (
        create_transport_for_multiaddr,
        get_supported_transport_protocols,
    )
    supported_protocols = get_supported_transport_protocols()
    for proto in ("ws", "wss"):
        assert proto in supported_protocols
    upgrader = create_upgrader()
    # Plain WS, WSS, and the equivalent /tls/ws spelling must all map to a
    # WebsocketTransport instance.
    for addr in (
        "/ip4/127.0.0.1/tcp/8080/ws",
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
    ):
        built = create_transport_for_multiaddr(Multiaddr(addr), upgrader)
        assert built is not None
        assert isinstance(built, WebsocketTransport)
def test_wss_multiaddr_formats():
    """Both /wss and /tls/ws spellings are valid and parse as WSS."""
    secure_addrs = (
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip6/::1/tcp/8080/wss",
        "/dns/localhost/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
        "/ip6/::1/tcp/8080/tls/ws",
        "/dns/example.com/tcp/443/tls/ws",
    )
    for text in secure_addrs:
        maddr = Multiaddr(text)
        assert is_valid_websocket_multiaddr(maddr), f"Address {text} should be valid"
        info = parse_websocket_multiaddr(maddr)
        assert info.is_wss, f"Address {text} should be parsed as WSS"
        # The stripped base address must still carry its TCP port.
        assert info.rest_multiaddr.value_for_protocol("tcp") is not None
def test_wss_vs_ws_distinction():
    """Plain /ws addresses are not WSS; /wss and /tls/ws addresses are."""
    # (address, expected is_wss) pairs, plain WS first, then secure forms.
    cases = [
        ("/ip4/127.0.0.1/tcp/8080/ws", False),
        ("/ip6/::1/tcp/8080/ws", False),
        ("/dns/localhost/tcp/8080/ws", False),
        ("/ip4/127.0.0.1/tcp/8080/wss", True),
        ("/ip4/127.0.0.1/tcp/8080/tls/ws", True),
    ]
    for text, expect_secure in cases:
        info = parse_websocket_multiaddr(Multiaddr(text))
        if expect_secure:
            assert info.is_wss, f"Address {text} should be WSS"
        else:
            assert not info.is_wss, f"Address {text} should not be WSS"
@pytest.mark.trio
async def test_wss_connection_handling():
    """/wss addresses are flagged secure by the parser; /ws addresses are not."""
    secure = parse_websocket_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/wss"))
    assert secure.is_wss
    plain = parse_websocket_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"))
    assert not plain.is_wss
def test_wss_error_handling():
    """Malformed WebSocket multiaddrs are rejected by validator and parser."""
    bad_addrs = (
        "/ip4/127.0.0.1/tcp/8080",  # missing the ws/wss protocol
        "/ip4/127.0.0.1/wss",  # missing TCP
        "/tcp/8080/wss",  # missing the network (ip/dns) protocol
    )
    for text in bad_addrs:
        maddr = Multiaddr(text)
        assert not is_valid_websocket_multiaddr(maddr), (
            f"Address {text} should be invalid"
        )
        # The parser signals the same problem with ValueError.
        with pytest.raises(ValueError):
            parse_websocket_multiaddr(maddr)
@pytest.mark.trio
async def test_handshake_timeout():
    """A custom handshake timeout is kept by the transport and its listener."""
    timeout_s = 0.1  # 100ms
    transport = WebsocketTransport(create_upgrader(), handshake_timeout=timeout_s)
    assert transport._handshake_timeout == 0.1
    async def dummy_handler(conn):
        await trio.sleep(0)
    listener = transport.create_listener(dummy_handler)
    # getattr() is used to reach the private attribute through the listener's
    # interface type for testing purposes.
    assert hasattr(listener, "_handshake_timeout")
    assert getattr(listener, "_handshake_timeout") == 0.1
@pytest.mark.trio
async def test_handshake_timeout_creation():
    """create_transport forwards handshake_timeout and otherwise uses 15.0."""
    upgrader = create_upgrader()
    from libp2p.transport import create_transport
    custom = create_transport("ws", upgrader, handshake_timeout=5.0)
    assert hasattr(custom, "_handshake_timeout")
    assert getattr(custom, "_handshake_timeout") == 5.0
    # Without an explicit value, the transport falls back to its default.
    fallback = create_transport("ws", upgrader)
    assert hasattr(fallback, "_handshake_timeout")
    assert getattr(fallback, "_handshake_timeout") == 15.0
@pytest.mark.trio
async def test_connection_state_tracking():
    """A fresh P2PWebSocketConnection reports zeroed counters in conn_state()."""
    from libp2p.transport.websocket.connection import P2PWebSocketConnection
    class MockWebSocketConnection:
        """Minimal stand-in exposing the send/receive/close coroutines."""
        async def send_message(self, data: bytes) -> None:
            return None
        async def get_message(self) -> bytes:
            return b"test message"
        async def aclose(self) -> None:
            return None
    conn = P2PWebSocketConnection(MockWebSocketConnection(), is_secure=True)
    state = conn.conn_state()
    assert state["transport"] == "websocket"
    assert state["secure"] is True
    for counter in ("bytes_read", "bytes_written", "total_bytes"):
        assert state[counter] == 0
    assert state["connection_duration"] >= 0
    # Real read/write byte accounting is left to integration tests; here we
    # only confirm the tracking attributes exist.
    for attr in ("_bytes_read", "_bytes_written", "_connection_start_time"):
        assert hasattr(conn, attr)
@pytest.mark.trio
async def test_concurrent_close_handling():
    """Closing a P2PWebSocketConnection twice closes the underlying socket once."""
    from libp2p.transport.websocket.connection import P2PWebSocketConnection
    class MockWebSocketConnection:
        """Stand-in that counts aclose() calls and fails I/O once closed."""
        def __init__(self):
            self.close_calls = 0
            self.closed = False
        def _ensure_open(self) -> None:
            if self.closed:
                raise Exception("Connection closed")
        async def send_message(self, data: bytes) -> None:
            self._ensure_open()
        async def get_message(self) -> bytes:
            self._ensure_open()
            return b"test message"
        async def aclose(self) -> None:
            self.close_calls += 1
            self.closed = True
    mock_ws = MockWebSocketConnection()
    conn = P2PWebSocketConnection(mock_ws, is_secure=False)
    await conn.close()
    await conn.close()  # a repeated close must be a no-op, not an error
    # Only the first close should reach the underlying socket.
    assert mock_ws.close_calls == 1
    assert mock_ws.closed is True
@pytest.mark.trio
async def test_zero_byte_write_handling():
    """Zero-byte writes are forwarded to the socket rather than dropped."""
    from libp2p.transport.websocket.connection import P2PWebSocketConnection
    class MockWebSocketConnection:
        """Stand-in that records the length of every payload it is sent."""
        def __init__(self):
            self.write_calls = []
        async def send_message(self, data: bytes) -> None:
            self.write_calls.append(len(data))
        async def get_message(self) -> bytes:
            return b"test message"
        async def aclose(self) -> None:
            pass
    mock_ws = MockWebSocketConnection()
    conn = P2PWebSocketConnection(mock_ws, is_secure=False)
    await conn.write(b"")
    assert 0 in mock_ws.write_calls
    await conn.write(b"hello")
    assert 5 in mock_ws.write_calls
    # Ten more empty writes: 11 zero-length sends in total.
    for _ in range(10):
        await conn.write(b"")
    assert mock_ws.write_calls.count(0) == 11
@pytest.mark.trio
async def test_websocket_transport_protocols():
    """Both /ws and /wss multiaddrs validate, and only /wss parses as secure."""
    ws_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss")
    for maddr in (ws_maddr, wss_maddr):
        assert is_valid_websocket_multiaddr(maddr)
    ws_info = parse_websocket_multiaddr(ws_maddr)
    wss_info = parse_websocket_multiaddr(wss_maddr)
    assert not ws_info.is_wss
    assert wss_info.is_wss
@pytest.mark.trio
async def test_websocket_listener_addr_format():
    """
    Test that WS and WSS listeners expose the expected configuration.
    Both listeners should carry the default 15.0 handshake timeout, and the
    WSS listener should also carry the server TLS context. That context is
    built with ``ssl.Purpose.CLIENT_AUTH`` so it is server-capable
    (``PROTOCOL_TLS_SERVER``), matching its use as ``tls_server_config``.
    """
    import ssl
    upgrader = create_upgrader()
    # Plain WS listener: no TLS, default handshake timeout.
    transport_ws = WebsocketTransport(upgrader)
    async def dummy_handler_ws(conn):
        await trio.sleep(0)
    listener_ws = transport_ws.create_listener(dummy_handler_ws)
    # getattr() reaches the private attribute for testing purposes.
    assert hasattr(listener_ws, "_handshake_timeout")
    assert getattr(listener_ws, "_handshake_timeout") == 15.0  # Default timeout
    # WSS listener with a server-side TLS context.
    tls_config = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    assert tls_config.protocol == ssl.PROTOCOL_TLS_SERVER
    transport_wss = WebsocketTransport(upgrader, tls_server_config=tls_config)
    async def dummy_handler_wss(conn):
        await trio.sleep(0)
    listener_wss = transport_wss.create_listener(dummy_handler_wss)
    assert hasattr(listener_wss, "_tls_config")
    assert getattr(listener_wss, "_tls_config") is not None
    assert hasattr(listener_wss, "_handshake_timeout")
    assert getattr(listener_wss, "_handshake_timeout") == 15.0
@pytest.mark.trio
async def test_sni_resolution_limitation():
    """
    Test that resolve() returns WebSocket addresses unchanged.
    The Python multiaddr library has no SNI protocol, so WSS, WS and literal-IP
    addresses each come back as a single, unmodified entry.
    """
    transport = WebsocketTransport(create_upgrader())
    for text in (
        "/dns/example.com/tcp/1234/wss",  # WSS: SNI resolution not supported
        "/dns/example.com/tcp/1234/ws",  # non-WSS
        "/ip4/127.0.0.1/tcp/1234/wss",  # literal IP address
    ):
        maddr = Multiaddr(text)
        resolved = transport.resolve(maddr)
        assert len(resolved) == 1
        assert resolved[0] == maddr
@pytest.mark.trio
async def test_websocket_transport_can_dial():
    """WebSocket-over-TCP addresses are dialable; other addresses are not."""
    # Note: SNI addresses are not supported by the Python multiaddr library.
    expectations = {
        "/ip4/127.0.0.1/tcp/5555/ws": True,
        "/ip4/127.0.0.1/tcp/5555/wss": True,
        "/ip4/127.0.0.1/tcp/5555/tls/ws": True,
        "/ip4/127.0.0.1/tcp/5555": False,  # no WebSocket protocol
        "/ip4/127.0.0.1/udp/5555/ws": False,  # wrong transport protocol
    }
    for text, dialable in expectations.items():
        maddr = Multiaddr(text)
        if dialable:
            assert is_valid_websocket_multiaddr(maddr), (
                f"Address {text} should be valid"
            )
        else:
            assert not is_valid_websocket_multiaddr(maddr), (
                f"Address {text} should be invalid"
            )