mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
- Fix type annotation errors in transport_registry.py and __init__.py.
- Fix line length violations in test files (E501 errors).
- Fix missing return type annotations.
- Fix cryptography NameAttribute type errors with `type: ignore`.
- Fix ExceptionGroup import for cross-version compatibility.
- Fix test failure in test_wss_listen_without_tls_config by handling ExceptionGroup.
- Fix len() calls with None arguments in test_tcp_data_transfer.py.
- Fix missing attribute access errors on interface types.
- Fix boolean type expectation errors in test_js_ws_ping.py.
- Fix nursery context manager type errors.
All tests now pass and linting is clean.
1618 lines
52 KiB
Python
1618 lines
52 KiB
Python
from collections.abc import Sequence
|
|
import logging
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from exceptiongroup import ExceptionGroup
|
|
from multiaddr import Multiaddr
|
|
import trio
|
|
|
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
|
from libp2p.custom_types import TProtocol
|
|
from libp2p.host.basic_host import BasicHost
|
|
from libp2p.network.swarm import Swarm
|
|
from libp2p.peer.id import ID
|
|
from libp2p.peer.peerstore import PeerStore
|
|
from libp2p.security.insecure.transport import InsecureTransport
|
|
from libp2p.stream_muxer.yamux.yamux import Yamux
|
|
from libp2p.transport.upgrader import TransportUpgrader
|
|
from libp2p.transport.websocket.multiaddr_utils import (
|
|
is_valid_websocket_multiaddr,
|
|
parse_websocket_multiaddr,
|
|
)
|
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
|
|
|
|
|
|
async def make_host(
    listen_addrs: Sequence[Multiaddr] | None = None,
) -> tuple[BasicHost, Any | None]:
    """
    Build a BasicHost backed by a WebSocket transport.

    The host gets a fresh secp256k1 identity, plaintext security and yamux
    stream multiplexing.

    :param listen_addrs: addresses to listen on; when ``None`` or empty the
        host is built but not started.
    :return: ``(host, ctx)`` where ``ctx`` is the ``host.run(...)`` context
        manager, already entered and not exited by this function, or
        ``None`` when the host was not started.
    """
    # Fresh identity, registered in its own peer store
    identity = create_new_key_pair()
    local_id = ID.from_pubkey(identity.public_key)
    store = PeerStore()
    store.add_key_pair(local_id, identity)

    # Upgrader: plaintext security + yamux
    security = {TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(identity)}
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    upgrader = TransportUpgrader(
        secure_transports_by_protocol=security,
        muxer_transports_by_protocol=muxers,
    )

    # Transport -> Swarm -> Host
    ws_transport = WebsocketTransport(upgrader)
    host = BasicHost(Swarm(local_id, store, upgrader, ws_transport))

    if not listen_addrs:
        return host, None

    run_ctx = host.run(listen_addrs)
    await run_ctx.__aenter__()
    return host, run_ctx
|
|
|
|
|
|
def create_upgrader():
    """Return a TransportUpgrader using plaintext security and yamux muxing"""
    # Every call gets its own freshly generated key pair
    security = {
        TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(create_new_key_pair())
    }
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    return TransportUpgrader(
        secure_transports_by_protocol=security,
        muxer_transports_by_protocol=muxers,
    )
|
|
|
|
|
|
# 2. Listener Basic Functionality Tests
|
|
@pytest.mark.trio
async def test_listener_basic_listen():
    """Check that a listener exposes its core methods and can be closed"""

    async def noop_handler(conn):
        await trio.sleep(0)

    ws_transport = WebsocketTransport(create_upgrader())
    listen_ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    listener = ws_transport.create_listener(noop_handler)

    # Required listener methods
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)

    # The IPv4 test address yields its host and port components
    assert listen_ma.value_for_protocol("ip4") == "127.0.0.1"
    assert listen_ma.value_for_protocol("tcp") == "0"

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_listener_port_0_handling():
    """Check that a port-0 address parses and a listener can be closed"""

    async def noop_handler(conn):
        await trio.sleep(0)

    ws_transport = WebsocketTransport(create_upgrader())
    listen_ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    listener = ws_transport.create_listener(noop_handler)

    # Port 0 is carried in the multiaddr as the string "0"
    assert listen_ma.value_for_protocol("tcp") == "0"

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_listener_any_interface():
    """Check that a 0.0.0.0 address parses and a listener can be closed"""

    async def noop_handler(conn):
        await trio.sleep(0)

    ws_transport = WebsocketTransport(create_upgrader())
    wildcard_ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")
    listener = ws_transport.create_listener(noop_handler)

    # The wildcard interface survives parsing unchanged
    assert wildcard_ma.value_for_protocol("ip4") == "0.0.0.0"

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_listener_address_preservation():
    """Check that a /p2p peer ID survives in the multiaddr string form"""

    async def noop_handler(conn):
        await trio.sleep(0)

    ws_transport = WebsocketTransport(create_upgrader())

    peer_id_text = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
    full_ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{peer_id_text}")

    listener = ws_transport.create_listener(noop_handler)

    # The peer ID must still be present after the multiaddr round trip
    assert peer_id_text in str(full_ma)

    await listener.close()
|
|
|
|
|
|
# 3. Dial Basic Functionality Tests
|
|
@pytest.mark.trio
async def test_dial_basic():
    """Check dial-address parsing and that the transport exposes dial()"""
    ws_transport = WebsocketTransport(create_upgrader())

    # Pull host and port out of a typical dial target
    target = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    parsed = (target.value_for_protocol("ip4"), target.value_for_protocol("tcp"))
    assert parsed[0] == "127.0.0.1"
    assert parsed[1] == "8080"

    # dial must exist and be callable
    assert hasattr(ws_transport, "dial")
    assert callable(ws_transport.dial)
|
|
|
|
|
|
@pytest.mark.trio
async def test_dial_with_p2p_id():
    """Check that a dial address keeps its /p2p suffix and dial() exists"""
    ws_transport = WebsocketTransport(create_upgrader())

    peer_id_text = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
    target = Multiaddr(f"/ip4/127.0.0.1/tcp/8080/ws/p2p/{peer_id_text}")

    # The peer ID must still be present in the string form
    assert peer_id_text in str(target)

    # dial must exist and be callable
    assert hasattr(ws_transport, "dial")
    assert callable(ws_transport.dial)
|
|
|
|
|
|
@pytest.mark.trio
async def test_dial_port_0_resolution():
    """Check that a port-0 dial address parses and dial() exists"""
    ws_transport = WebsocketTransport(create_upgrader())

    # Port 0 is carried in the multiaddr as the string "0"
    target = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert target.value_for_protocol("tcp") == "0"

    # dial must exist and be callable
    assert hasattr(ws_transport, "dial")
    assert callable(ws_transport.dial)
|
|
|
|
|
|
# 4. Address Validation Tests (CRITICAL)
|
|
def test_address_validation_ipv4():
    """Check that IPv4 WebSocket addresses keep /ws in their string form"""
    ipv4_ws_addresses = (
        "/ip4/127.0.0.1/tcp/8080/ws",
        "/ip4/0.0.0.0/tcp/0/ws",
        "/ip4/192.168.1.1/tcp/443/ws",
    )

    # Constructing the Multiaddr must not raise, and /ws must survive
    for text in ipv4_ws_addresses:
        assert "/ws" in str(Multiaddr(text))

    # Same check for an address that carries a trailing peer ID
    with_peer_id = Multiaddr(
        "/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw"
    )
    assert "/ws" in str(with_peer_id)
|
|
|
|
|
|
def test_address_validation_ipv6():
    """Check that IPv6 WebSocket addresses keep /ws in their string form"""
    ipv6_ws_addresses = (
        "/ip6/::1/tcp/8080/ws",
        "/ip6/2001:db8::1/tcp/443/ws",
    )

    # Constructing the Multiaddr must not raise, and /ws must survive
    for text in ipv6_ws_addresses:
        assert "/ws" in str(Multiaddr(text))
|
|
|
|
|
|
def test_address_validation_dns():
    """Check that DNS-based WebSocket addresses keep /ws in their string form"""
    dns_ws_addresses = (
        "/dns4/example.com/tcp/80/ws",
        "/dns6/example.com/tcp/443/ws",
        "/dnsaddr/example.com/tcp/8080/ws",
    )

    # Constructing the Multiaddr must not raise, and /ws must survive
    for text in dns_ws_addresses:
        assert "/ws" in str(Multiaddr(text))
|
|
|
|
|
|
def test_address_validation_mixed():
    """
    Test that a mix of addresses yields the expected number of WebSocket ones.

    An address counts as a WebSocket address here when its string form
    contains both a ``/tcp/`` component and ``/ws``.
    """
    addresses = [
        "/ip4/127.0.0.1/tcp/8080/ws",  # Valid
        "/ip4/127.0.0.1/tcp/8080",  # Invalid (no /ws)
        "/ip6/::1/tcp/8080/ws",  # Valid
        "/ip4/127.0.0.1/ws",  # Invalid (no tcp)
        "/dns4/example.com/tcp/80/ws",  # Valid
    ]

    # Convert to Multiaddr objects and back to text
    addr_texts = [str(Multiaddr(addr)) for addr in addresses]

    # No blanket try/except here: an unexpected error while handling an
    # address should fail the test loudly instead of being swallowed.
    valid_count = sum(1 for text in addr_texts if "/ws" in text and "/tcp/" in text)

    assert valid_count == 3  # Should have 3 valid addresses
|
|
|
|
|
|
# 5. Error Handling Tests
|
|
@pytest.mark.trio
async def test_dial_invalid_address():
    """Check that dialing a non-WebSocket multiaddr raises"""
    ws_transport = WebsocketTransport(create_upgrader())

    # One address lacks /ws, the other lacks tcp
    bad_targets = [
        Multiaddr(text)
        for text in ("/ip4/127.0.0.1/tcp/8080", "/ip4/127.0.0.1/ws")
    ]

    for target in bad_targets:
        with pytest.raises(Exception):
            await ws_transport.dial(target)
|
|
|
|
|
|
@pytest.mark.trio
async def test_listen_invalid_address():
    """
    Test that non-WebSocket listen addresses are rejected by validation.

    The previous version only re-asserted the condition of the ``if`` branch
    it was already inside, so it could never fail. The addresses are now
    checked with ``is_valid_websocket_multiaddr``.
    """
    invalid_addresses = [
        Multiaddr("/ip4/127.0.0.1/tcp/8080"),  # No /ws
        Multiaddr("/ip4/127.0.0.1/ws"),  # No tcp
    ]

    for ma in invalid_addresses:
        assert not is_valid_websocket_multiaddr(ma), (
            f"Address {ma} should be invalid"
        )
|
|
|
|
|
|
@pytest.mark.trio
async def test_listen_port_in_use():
    """Check parsing of two addresses on one port and create_listener()"""
    ws_transport = WebsocketTransport(create_upgrader())

    # Two independent multiaddrs that name the same TCP port
    first = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    second = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    for ma in (first, second):
        assert ma.value_for_protocol("tcp") == "8080"

    # create_listener must exist and be callable
    assert hasattr(ws_transport, "create_listener")
    assert callable(ws_transport.create_listener)
|
|
|
|
|
|
# 6. Connection Lifecycle Tests
|
|
@pytest.mark.trio
async def test_connection_close():
    """Check that a listener exposes a callable close() and can be closed"""

    async def noop_handler(conn):
        await trio.sleep(0)

    ws_transport = WebsocketTransport(create_upgrader())

    # dial must exist and be callable
    assert hasattr(ws_transport, "dial")
    assert callable(ws_transport.dial)

    # A freshly created listener must offer close()
    listener = ws_transport.create_listener(noop_handler)
    assert hasattr(listener, "close")
    assert callable(listener.close)

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_multiple_connections():
    """Check parsing of several WebSocket addresses and that dial() exists"""
    ws_transport = WebsocketTransport(create_upgrader())

    # Three loopback addresses on consecutive ports
    expected_ports = ["8080", "8081", "8082"]
    targets = [
        Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/ws") for port in expected_ports
    ]

    for target in targets:
        ip_text = target.value_for_protocol("ip4")
        port_text = target.value_for_protocol("tcp")
        assert ip_text == "127.0.0.1"
        assert port_text in expected_ports

    # dial must exist and be callable
    assert hasattr(ws_transport, "dial")
    assert callable(ws_transport.dial)
|
|
|
|
|
|
# Original test (kept for compatibility)
|
|
@pytest.mark.trio
async def test_websocket_dial_and_listen():
    """Check listener creation, address parsing and dial() on the transport"""

    async def noop_handler(conn):
        await trio.sleep(0)

    ws_transport = WebsocketTransport(create_upgrader())

    # A listener can be created and offers the required methods
    listener = ws_transport.create_listener(noop_handler)
    assert listener is not None
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)

    # Listen-style address (port 0)
    listen_ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert listen_ma.value_for_protocol("ip4") == "127.0.0.1"
    assert listen_ma.value_for_protocol("tcp") == "0"
    assert "ws" in str(listen_ma)

    # dial must exist and be callable
    assert hasattr(ws_transport, "dial")
    assert callable(ws_transport.dial)

    # Dial-style address (fixed port)
    dial_ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    assert dial_ma.value_for_protocol("ip4") == "127.0.0.1"
    assert dial_ma.value_for_protocol("tcp") == "8080"
    assert "ws" in str(dial_ma)

    # Cleanup
    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_transport_basic():
    """Check transport and listener attributes without the full libp2p stack"""

    async def noop_handler(conn):
        await trio.sleep(0)

    # Upgrader with plaintext security and yamux, built from a fresh key pair
    identity = create_new_key_pair()
    security = {TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(identity)}
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    ws_transport = WebsocketTransport(
        TransportUpgrader(
            secure_transports_by_protocol=security,
            muxer_transports_by_protocol=muxers,
        )
    )

    assert ws_transport is not None
    for method_name in ("dial", "create_listener"):
        assert hasattr(ws_transport, method_name)

    listener = ws_transport.create_listener(noop_handler)
    assert listener is not None
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)

    sample_ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert sample_ma.value_for_protocol("ip4") == "127.0.0.1"
    assert sample_ma.value_for_protocol("tcp") == "0"
    assert "ws" in str(sample_ma)

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_simple_connection():
    """Check transport and listener attributes without opening a connection"""

    async def closing_handler(conn):
        await conn.close()

    # Upgrader with plaintext security and yamux, built from a fresh key pair
    identity = create_new_key_pair()
    security = {TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(identity)}
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    ws_transport = WebsocketTransport(
        TransportUpgrader(
            secure_transports_by_protocol=security,
            muxer_transports_by_protocol=muxers,
        )
    )

    assert ws_transport is not None
    for method_name in ("dial", "create_listener"):
        assert hasattr(ws_transport, method_name)

    listener = ws_transport.create_listener(closing_handler)
    assert listener is not None
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)

    sample_ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
    assert sample_ma.value_for_protocol("ip4") == "127.0.0.1"
    assert sample_ma.value_for_protocol("tcp") == "0"
    assert "ws" in str(sample_ma)

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_real_connection():
    """Check that the transport and a listener built from it expose their API"""

    async def closing_handler(conn):
        await conn.close()

    # Upgrader with plaintext security and yamux, built from a fresh key pair
    identity = create_new_key_pair()
    security = {TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(identity)}
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    ws_transport = WebsocketTransport(
        TransportUpgrader(
            secure_transports_by_protocol=security,
            muxer_transports_by_protocol=muxers,
        )
    )

    assert ws_transport is not None
    for method_name in ("dial", "create_listener"):
        assert hasattr(ws_transport, method_name)

    listener = ws_transport.create_listener(closing_handler)
    assert listener is not None
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)

    await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_with_tcp_fallback():
    """Check a request/response stream round trip on a host_pair_factory pair"""
    from tests.utils.factories import host_pair_factory

    async with host_pair_factory() as (host_a, host_b):
        # Both sides must already hold at least one connection
        for host in (host_a, host_b):
            assert len(host.get_network().connections) > 0

        protocol_id = TProtocol("/test/protocol/1.0.0")
        seen_by_handler = None

        async def reply_handler(stream):
            # Record the request, answer with a fixed reply, then close
            nonlocal seen_by_handler
            seen_by_handler = await stream.read(1024)
            await stream.write(b"Response from TCP")
            await stream.close()

        host_a.set_stream_handler(protocol_id, reply_handler)
        stream = await host_b.new_stream(host_a.get_id(), [protocol_id])

        request = b"TCP protocol test"
        await stream.write(request)
        reply = await stream.read(1024)

        assert seen_by_handler == request
        assert reply == b"Response from TCP"

        await stream.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_data_exchange():
    """Check an echo round trip between two hosts connected over WebSocket"""
    from libp2p import create_yamux_muxer_option, new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair
    from libp2p.custom_types import TProtocol
    from libp2p.peer.peerinfo import info_from_p2p_addr
    from libp2p.security.insecure.transport import (
        PLAINTEXT_PROTOCOL_ID,
        InsecureTransport,
    )

    def ws_listen_addrs():
        # Each call site gets its own freshly built address list
        return [Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]

    def plaintext_ws_host(keys):
        # Host with plaintext security, yamux muxing and a WebSocket address
        sec_opt = {
            PLAINTEXT_PROTOCOL_ID: InsecureTransport(
                local_key_pair=keys, secure_bytes_provider=None, peerstore=None
            )
        }
        return new_host(
            key_pair=keys,
            sec_opt=sec_opt,
            muxer_opt=create_yamux_muxer_option(),
            listen_addrs=ws_listen_addrs(),
        )

    keys_a = create_new_key_pair()
    keys_b = create_new_key_pair()
    host_a = plaintext_ws_host(keys_a)  # listener
    host_b = plaintext_ws_host(keys_b)  # dialer

    payload = b"Hello WebSocket Data Exchange!"
    received_data = None
    protocol_id = TProtocol("/test/websocket/data/1.0.0")

    async def echo_handler(stream):
        nonlocal received_data
        received_data = await stream.read(len(payload))
        await stream.write(received_data)  # Echo back
        await stream.close()

    host_a.set_stream_handler(protocol_id, echo_handler)

    # Only host A listens; host B runs without listen addresses
    async with (
        host_a.run(listen_addrs=ws_listen_addrs()),
        host_b.run(listen_addrs=[]),
    ):
        addrs = host_a.get_addrs()
        assert len(addrs) > 0

        # First advertised address that mentions /ws
        ws_addr = next((addr for addr in addrs if "/ws" in str(addr)), None)
        assert ws_addr is not None, "No WebSocket listen address found"

        await host_b.connect(info_from_p2p_addr(ws_addr))

        stream = await host_b.new_stream(host_a.get_id(), [protocol_id])
        await stream.write(payload)
        response = await stream.read(len(payload))
        await stream.close()

        assert received_data == payload, f"Expected {payload}, got {received_data}"
        assert response == payload, f"Expected echo {payload}, got {response}"
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_host_pair_data_exchange():
    """
    Check an echo round trip between a manually connected WebSocket host pair.

    The hosts are started and connected by hand, and the connection is
    verified on both sides before any data is sent.
    """
    from libp2p import create_yamux_muxer_option, new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair
    from libp2p.custom_types import TProtocol
    from libp2p.peer.peerinfo import info_from_p2p_addr
    from libp2p.security.insecure.transport import (
        PLAINTEXT_PROTOCOL_ID,
        InsecureTransport,
    )

    def ws_listen_addrs():
        # Each call site gets its own freshly built address list
        return [Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]

    def plaintext_ws_host(keys):
        # Host with plaintext security, yamux muxing and a WebSocket address
        sec_opt = {
            PLAINTEXT_PROTOCOL_ID: InsecureTransport(
                local_key_pair=keys, secure_bytes_provider=None, peerstore=None
            )
        }
        return new_host(
            key_pair=keys,
            sec_opt=sec_opt,
            muxer_opt=create_yamux_muxer_option(),
            listen_addrs=ws_listen_addrs(),
        )

    keys_a = create_new_key_pair()
    keys_b = create_new_key_pair()
    host_a = plaintext_ws_host(keys_a)  # listener
    host_b = plaintext_ws_host(keys_b)  # dialer

    payload = b"Hello WebSocket Host Pair Data Exchange!"
    received_data = None
    protocol_id = TProtocol("/test/websocket/hostpair/1.0.0")

    async def echo_handler(stream):
        nonlocal received_data
        received_data = await stream.read(len(payload))
        await stream.write(received_data)  # Echo back
        await stream.close()

    host_a.set_stream_handler(protocol_id, echo_handler)

    # Only host A listens; host B runs without listen addresses
    async with (
        host_a.run(listen_addrs=ws_listen_addrs()),
        host_b.run(listen_addrs=[]),
    ):
        addrs = host_a.get_addrs()
        assert len(addrs) > 0

        # First advertised address that mentions /ws
        ws_addr = next((addr for addr in addrs if "/ws" in str(addr)), None)
        assert ws_addr is not None, "No WebSocket listen address found"

        await host_b.connect(info_from_p2p_addr(ws_addr))

        # Give the connection a moment to settle before inspecting it
        await trio.sleep(0.1)

        for host in (host_a, host_b):
            assert len(host.get_network().connections) > 0

        stream = await host_b.new_stream(host_a.get_id(), [protocol_id])
        await stream.write(payload)
        response = await stream.read(len(payload))
        await stream.close()

        assert received_data == payload, f"Expected {payload}, got {received_data}"
        assert response == payload, f"Expected echo {payload}, got {response}"
|
|
|
|
|
|
def _build_self_signed_tls_contexts() -> tuple[Any, Any]:
    """
    Create server and client ``ssl.SSLContext`` objects for WSS tests.

    A throwaway RSA key and a one-day self-signed certificate for
    ``localhost`` / ``127.0.0.1`` are generated. They are written to a
    temporary directory only long enough to be loaded into the server
    context; the directory is removed before this function returns.

    :return: ``(server_context, client_context)``. The client context has
        hostname checking and certificate verification disabled.
    :raises ImportError: if the ``cryptography`` package is not installed.
    """
    import datetime
    import ipaddress
    from pathlib import Path
    import ssl
    import tempfile

    from cryptography import x509
    from cryptography.hazmat.primitives import hashes, serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.x509.oid import NameOID

    # Generate private key
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
    )

    # Self-signed: subject and issuer are the same name
    subject = issuer = x509.Name(
        [
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),  # type: ignore
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"),  # type: ignore
            x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),  # type: ignore
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"),  # type: ignore
            x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),  # type: ignore
        ]
    )

    # timezone.utc instead of datetime.UTC: the alias only exists on Python
    # 3.11+, and its AttributeError used to turn into a silent skip.
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=1))
        .add_extension(
            x509.SubjectAlternativeName(
                [
                    x509.DNSName("localhost"),
                    x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
                ]
            ),
            critical=False,
        )
        .sign(private_key, hashes.SHA256())
    )

    # Server context for the listener
    server_tls_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

    # load_cert_chain needs file paths; the temporary directory (and both
    # files) is deleted as soon as the material has been loaded.
    with tempfile.TemporaryDirectory() as cert_dir:
        cert_path = Path(cert_dir) / "test.crt"
        key_path = Path(cert_dir) / "test.key"
        cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))
        key_path.write_bytes(
            private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )
        server_tls_context.load_cert_chain(str(cert_path), str(key_path))

    # Client context for the dialer: accept the self-signed certificate
    client_tls_context = ssl.create_default_context()
    client_tls_context.check_hostname = False
    client_tls_context.verify_mode = ssl.CERT_NONE

    return server_tls_context, client_tls_context


@pytest.mark.trio
async def test_wss_host_pair_data_exchange():
    """Test WSS host pair with actual data exchange using host_pair_factory pattern"""
    from libp2p import create_yamux_muxer_option, new_host
    from libp2p.crypto.secp256k1 import create_new_key_pair
    from libp2p.custom_types import TProtocol
    from libp2p.peer.peerinfo import info_from_p2p_addr
    from libp2p.security.insecure.transport import (
        PLAINTEXT_PROTOCOL_ID,
        InsecureTransport,
    )

    # Create TLS contexts for WSS (separate for client and server); skip the
    # test when the certificate material cannot be produced.
    try:
        server_tls_context, client_tls_context = _build_self_signed_tls_contexts()
    except ImportError:
        pytest.skip("cryptography package required for WSS tests")
    except Exception as e:
        pytest.skip(f"Failed to create test certificates: {e}")

    # Create two hosts with WSS transport and plaintext security
    key_pair_a = create_new_key_pair()
    key_pair_b = create_new_key_pair()

    # Host A (listener) - WSS transport with server TLS config
    security_options_a = {
        PLAINTEXT_PROTOCOL_ID: InsecureTransport(
            local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None
        )
    }
    host_a = new_host(
        key_pair=key_pair_a,
        sec_opt=security_options_a,
        muxer_opt=create_yamux_muxer_option(),
        listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")],
        tls_server_config=server_tls_context,
    )

    # Host B (dialer) - WSS transport with client TLS config
    security_options_b = {
        PLAINTEXT_PROTOCOL_ID: InsecureTransport(
            local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None
        )
    }
    host_b = new_host(
        key_pair=key_pair_b,
        sec_opt=security_options_b,
        muxer_opt=create_yamux_muxer_option(),
        listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")],  # Ensure WSS transport
        tls_client_config=client_tls_context,
    )

    # Test data
    test_data = b"Hello WSS Host Pair Data Exchange!"
    received_data = None

    # Set up handler on host A
    test_protocol = TProtocol("/test/wss/hostpair/1.0.0")

    async def data_handler(stream):
        nonlocal received_data
        received_data = await stream.read(len(test_data))
        await stream.write(received_data)  # Echo back
        await stream.close()

    host_a.set_stream_handler(test_protocol, data_handler)

    # Start both hosts and connect them (following host_pair_factory pattern)
    async with (
        host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/wss")]),
        host_b.run(listen_addrs=[]),
    ):
        # Get host A's listen address and create peer info
        listen_addrs = host_a.get_addrs()
        assert len(listen_addrs) > 0

        # Find the WSS address
        wss_addr = next(
            (addr for addr in listen_addrs if "/wss" in str(addr)), None
        )
        assert wss_addr is not None, "No WSS listen address found"

        # Connect host B to host A
        peer_info = info_from_p2p_addr(wss_addr)
        await host_b.connect(peer_info)

        # Allow time for connection to establish
        await trio.sleep(0.1)

        # Verify connection is established
        assert len(host_a.get_network().connections) > 0
        assert len(host_b.get_network().connections) > 0

        # Test data exchange
        stream = await host_b.new_stream(host_a.get_id(), [test_protocol])
        await stream.write(test_data)
        response = await stream.read(len(test_data))
        await stream.close()

        # Verify data exchange
        assert received_data == test_data, f"Expected {test_data}, got {received_data}"
        assert response == test_data, f"Expected echo {test_data}, got {response}"
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_transport_interface():
    """Check that the transport and its listener expose the expected interface"""

    async def noop_handler(conn):
        await trio.sleep(0)

    # Upgrader with plaintext security and yamux, built from a fresh key pair
    identity = create_new_key_pair()
    security = {TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(identity)}
    muxers = {TProtocol("/yamux/1.0.0"): Yamux}
    ws_transport = WebsocketTransport(
        TransportUpgrader(
            secure_transports_by_protocol=security,
            muxer_transports_by_protocol=muxers,
        )
    )

    # Both entry points must be present, and both must be callable
    assert hasattr(ws_transport, "dial")
    assert hasattr(ws_transport, "create_listener")
    assert callable(ws_transport.dial)
    assert callable(ws_transport.create_listener)

    listener = ws_transport.create_listener(noop_handler)
    for method_name in ("listen", "close", "get_addrs"):
        assert hasattr(listener, method_name)

    sample_ma = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    ip_text = sample_ma.value_for_protocol("ip4")
    port_text = sample_ma.value_for_protocol("tcp")
    assert ip_text == "127.0.0.1"
    assert port_text == "8080"

    await listener.close()
|
|
|
|
|
|
# ============================================================================
|
|
# WSS (WebSocket Secure) Tests
|
|
# ============================================================================
|
|
|
|
|
|
def test_wss_multiaddr_validation():
    """Check validation and WSS detection for WebSocket multiaddrs."""
    secure_addrs = (
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip6/::1/tcp/8080/wss",
        "/dns/localhost/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
        "/ip6/::1/tcp/8080/tls/ws",
    )
    # Plain WS: a valid WebSocket multiaddr that must not be reported as WSS.
    plain_ws_addrs = ("/ip4/127.0.0.1/tcp/8080/ws",)
    # Addresses that must be rejected by validation altogether.
    rejected_addrs = (
        "/ip4/127.0.0.1/tcp/8080",  # No WebSocket protocol
        "/ip4/127.0.0.1/wss",  # No TCP
    )

    for text in secure_addrs:
        candidate = Multiaddr(text)
        assert is_valid_websocket_multiaddr(candidate), (
            f"Address {text} should be valid"
        )
        info = parse_websocket_multiaddr(candidate)
        assert info.is_wss, f"Address {text} should be parsed as WSS"

    for text in plain_ws_addrs:
        candidate = Multiaddr(text)
        assert is_valid_websocket_multiaddr(candidate), (
            f"Address {text} should be valid"
        )
        info = parse_websocket_multiaddr(candidate)
        assert not info.is_wss, f"Address {text} should not be parsed as WSS"

    for text in rejected_addrs:
        candidate = Multiaddr(text)
        assert not is_valid_websocket_multiaddr(candidate), (
            f"Address {text} should be invalid"
        )
|
|
|
|
|
|
def test_wss_multiaddr_parsing():
    """Parse /wss, /tls/ws and /ws multiaddrs and inspect the parsed fields."""
    # Both secure spellings are expected to produce the same parsed fields.
    for secure_text in (
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
    ):
        info = parse_websocket_multiaddr(Multiaddr(secure_text))
        assert info.is_wss
        assert info.sni is None
        assert info.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1"
        assert info.rest_multiaddr.value_for_protocol("tcp") == "8080"

    # A regular /ws address parses, but not as WSS.
    plain_info = parse_websocket_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"))
    assert not plain_info.is_wss
    assert plain_info.sni is None
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_transport_creation():
    """
    Test WSS transport creation with TLS configuration.

    Builds one client-side and one server-side ``ssl.SSLContext``, passes both
    to ``WebsocketTransport`` and checks that the transport exposes ``dial`` and
    ``create_listener`` and has both TLS configs set.
    """
    import ssl

    # The default purpose (SERVER_AUTH) produces a client-side context.
    client_ssl_context = ssl.create_default_context()
    # Purpose.CLIENT_AUTH produces a server-side context. With the default
    # purpose this would be a second client context, which cannot be used to
    # accept TLS connections.
    server_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    server_ssl_context.check_hostname = False
    server_ssl_context.verify_mode = ssl.CERT_NONE

    upgrader = create_upgrader()

    # Test creating WSS transport with TLS configs
    wss_transport = WebsocketTransport(
        upgrader,
        tls_client_config=client_ssl_context,
        tls_server_config=server_ssl_context,
    )

    assert wss_transport is not None
    assert hasattr(wss_transport, "dial")
    assert hasattr(wss_transport, "create_listener")
    assert wss_transport._tls_client_config is not None
    assert wss_transport._tls_server_config is not None
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_transport_without_tls_config():
    """A transport built without TLS settings stores no TLS configuration."""
    # Omitting both TLS configs must still yield a transport object.
    plain_transport = WebsocketTransport(create_upgrader())

    assert plain_transport is not None
    for method_name in ("dial", "create_listener"):
        assert hasattr(plain_transport, method_name)
    for config_attr in ("_tls_client_config", "_tls_server_config"):
        assert getattr(plain_transport, config_attr) is None
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_dial_parsing():
    """
    Test WSS multiaddr parsing as needed for dialing.

    No server is started and nothing is dialed; only the parsing of a ``/wss``
    multiaddr is exercised.
    """
    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss")

    # Let parsing errors and assertion failures propagate unchanged: pytest then
    # reports the real traceback and the compared values, which a blanket
    # ``except Exception`` followed by ``pytest.fail`` would replace with a bare
    # message.
    parsed = parse_websocket_multiaddr(wss_maddr)
    assert parsed.is_wss
    assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1"
    assert parsed.rest_multiaddr.value_for_protocol("tcp") == "8080"
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_listen_parsing():
    """
    Test WSS multiaddr parsing as needed for listening.

    A listener is created but never started; only the parsing of a ``/wss``
    multiaddr with port 0 is exercised.
    """
    upgrader = create_upgrader()
    transport = WebsocketTransport(upgrader)

    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)

    try:
        # Failures propagate as-is so pytest can show the original traceback
        # and assertion details instead of a generic failure message.
        parsed = parse_websocket_multiaddr(wss_maddr)
        assert parsed.is_wss
        assert parsed.rest_multiaddr.value_for_protocol("ip4") == "127.0.0.1"
        assert parsed.rest_multiaddr.value_for_protocol("tcp") == "0"
    finally:
        # Close the listener even when one of the checks above fails.
        await listener.close()
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_listen_without_tls_config():
    """Listening on a WSS address with no TLS configuration must be rejected."""

    async def noop_handler(conn):
        await trio.sleep(0)

    tls_less_transport = WebsocketTransport(create_upgrader())  # No TLS config
    listener = tls_less_transport.create_listener(noop_handler)
    secure_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss")

    # The failure is expected to surface as an ExceptionGroup.
    with pytest.raises(ExceptionGroup) as exc_info:
        async with trio.open_nursery() as nursery:
            await listener.listen(secure_addr, nursery)

    # Exactly one wrapped error, and it must be the descriptive ValueError.
    wrapped = exc_info.value.exceptions
    assert len(wrapped) == 1
    first_error = wrapped[0]
    assert isinstance(first_error, ValueError)
    message = str(first_error)
    assert "Cannot listen on WSS address" in message
    assert "without TLS configuration" in message
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_listen_with_tls_config():
    """
    Test WSS listener creation when a server TLS configuration is supplied.

    The listener is created but never started, so only multiaddr parsing and
    the presence of the server TLS config on the transport are checked.
    """
    import ssl

    # Purpose.CLIENT_AUTH produces a server-side context; the default purpose
    # would produce a client-side one that cannot accept TLS connections.
    server_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    server_ssl_context.check_hostname = False
    server_ssl_context.verify_mode = ssl.CERT_NONE

    upgrader = create_upgrader()
    transport = WebsocketTransport(upgrader, tls_server_config=server_ssl_context)

    wss_maddr = Multiaddr("/ip4/127.0.0.1/tcp/0/wss")

    async def dummy_handler(conn):
        await trio.sleep(0)

    listener = transport.create_listener(dummy_handler)

    # Note: We can't actually start listening without proper certificates,
    # but we can test that the validation passes.
    try:
        # Failures propagate as-is so pytest can show the original traceback
        # and assertion details instead of a generic failure message.
        parsed = parse_websocket_multiaddr(wss_maddr)
        assert parsed.is_wss
        assert transport._tls_server_config is not None
    finally:
        # Close the listener even when one of the checks above fails.
        await listener.close()
|
|
|
|
|
|
def test_wss_transport_registry():
    """The registry lists ws/wss and builds a WebsocketTransport for each form."""
    from libp2p.transport.transport_registry import (
        create_transport_for_multiaddr,
        get_supported_transport_protocols,
    )

    protocols = get_supported_transport_protocols()
    for protocol_name in ("ws", "wss"):
        assert protocol_name in protocols

    upgrader = create_upgrader()

    # Plain WS, WSS and the TLS/WS spelling, checked in that order.
    for text in (
        "/ip4/127.0.0.1/tcp/8080/ws",
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
    ):
        built = create_transport_for_multiaddr(Multiaddr(text), upgrader)
        assert built is not None
        assert isinstance(built, WebsocketTransport)
|
|
|
|
|
|
def test_wss_multiaddr_formats():
    """Every supported WSS spelling validates, parses as WSS and keeps its TCP part."""
    secure_spellings = (
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip6/::1/tcp/8080/wss",
        "/dns/localhost/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
        "/ip6/::1/tcp/8080/tls/ws",
        "/dns/example.com/tcp/443/tls/ws",
    )

    for text in secure_spellings:
        candidate = Multiaddr(text)

        # Accepted as a WebsocketTransport-compatible multiaddr.
        assert is_valid_websocket_multiaddr(candidate), (
            f"Address {text} should be valid"
        )

        # Recognised as the secure variant.
        info = parse_websocket_multiaddr(candidate)
        assert info.is_wss, f"Address {text} should be parsed as WSS"

        # The remaining multiaddr still carries a TCP port.
        tcp_value = info.rest_multiaddr.value_for_protocol("tcp")
        assert tcp_value is not None
|
|
|
|
|
|
def test_wss_vs_ws_distinction():
    """Plain /ws addresses parse as non-WSS; /wss and /tls/ws parse as WSS."""
    plain_addrs = (
        "/ip4/127.0.0.1/tcp/8080/ws",
        "/ip6/::1/tcp/8080/ws",
        "/dns/localhost/tcp/8080/ws",
    )
    for text in plain_addrs:
        info = parse_websocket_multiaddr(Multiaddr(text))
        assert not info.is_wss, f"Address {text} should not be WSS"

    secure_addrs = (
        "/ip4/127.0.0.1/tcp/8080/wss",
        "/ip4/127.0.0.1/tcp/8080/tls/ws",
    )
    for text in secure_addrs:
        info = parse_websocket_multiaddr(Multiaddr(text))
        assert info.is_wss, f"Address {text} should be WSS"
|
|
|
|
|
|
@pytest.mark.trio
async def test_wss_connection_handling():
    """The parsed ``is_wss`` flag is set for /wss and unset for /ws."""
    # No transport is constructed here; only multiaddr parsing is exercised.
    secure_info = parse_websocket_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/wss"))
    assert secure_info.is_wss

    plain_info = parse_websocket_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"))
    assert not plain_info.is_wss
|
|
|
|
|
|
def test_wss_error_handling():
    """Malformed WebSocket multiaddrs fail validation and raise on parsing."""
    malformed = (
        "/ip4/127.0.0.1/tcp/8080",  # No WebSocket protocol
        "/ip4/127.0.0.1/wss",  # No TCP
        "/tcp/8080/wss",  # No network protocol
    )

    for text in malformed:
        candidate = Multiaddr(text)
        accepted = is_valid_websocket_multiaddr(candidate)
        assert not accepted, f"Address {text} should be invalid"

        # Parsing the same address must raise ValueError.
        with pytest.raises(ValueError):
            parse_websocket_multiaddr(candidate)
|
|
|
|
|
|
@pytest.mark.trio
async def test_handshake_timeout():
    """A custom handshake timeout is stored on the transport and on its listener."""
    transport = WebsocketTransport(
        create_upgrader(),
        handshake_timeout=0.1,  # 100ms timeout
    )
    assert transport._handshake_timeout == 0.1

    async def noop_handler(conn):
        await trio.sleep(0)

    # The listener created by the transport must carry the same timeout.
    listener = transport.create_listener(noop_handler)
    timeout_attr = "_handshake_timeout"
    assert hasattr(listener, timeout_attr)
    assert getattr(listener, timeout_attr) == 0.1
|
|
|
|
|
|
@pytest.mark.trio
async def test_handshake_timeout_creation():
    """``create_transport`` honours an explicit handshake timeout and a default."""
    from libp2p.transport import create_transport

    upgrader = create_upgrader()
    timeout_attr = "_handshake_timeout"

    # Explicit timeout passed through the factory.
    custom = create_transport("ws", upgrader, handshake_timeout=5.0)
    assert hasattr(custom, timeout_attr)
    assert getattr(custom, timeout_attr) == 5.0

    # No timeout given: the factory's default applies.
    fallback = create_transport("ws", upgrader)
    assert hasattr(fallback, timeout_attr)
    assert getattr(fallback, timeout_attr) == 15.0
|
|
|
|
|
|
@pytest.mark.trio
async def test_connection_state_tracking():
    """A fresh P2PWebSocketConnection reports zeroed counters in ``conn_state()``."""
    from libp2p.transport.websocket.connection import P2PWebSocketConnection

    class MockWebSocketConnection:
        """Minimal stand-in for the wrapped WebSocket object."""

        async def send_message(self, data: bytes) -> None:
            return None

        async def get_message(self) -> bytes:
            return b"test message"

        async def aclose(self) -> None:
            return None

    conn = P2PWebSocketConnection(MockWebSocketConnection(), is_secure=True)

    # Snapshot taken before any read or write happened.
    snapshot = conn.conn_state()
    assert snapshot["transport"] == "websocket"
    assert snapshot["secure"] is True
    for counter in ("bytes_read", "bytes_written", "total_bytes"):
        assert snapshot[counter] == 0
    assert snapshot["connection_duration"] >= 0

    # Only the presence of the bookkeeping attributes is checked here; no data
    # is read or written through the mock.
    for attr in ("_bytes_read", "_bytes_written", "_connection_start_time"):
        assert hasattr(conn, attr)
|
|
|
|
|
|
@pytest.mark.trio
async def test_concurrent_close_handling():
    """Closing a connection twice closes the underlying WebSocket only once."""
    from libp2p.transport.websocket.connection import P2PWebSocketConnection

    class MockWebSocketConnection:
        """WebSocket stand-in that counts how many times it was closed."""

        def __init__(self):
            self.closed = False
            self.close_calls = 0

        async def send_message(self, data: bytes) -> None:
            if self.closed:
                raise Exception("Connection closed")

        async def get_message(self) -> bytes:
            if not self.closed:
                return b"test message"
            raise Exception("Connection closed")

        async def aclose(self) -> None:
            self.closed = True
            self.close_calls += 1

    mock_ws = MockWebSocketConnection()
    conn = P2PWebSocketConnection(mock_ws, is_secure=False)

    # Two back-to-back closes; the second one must not raise.
    for _ in range(2):
        await conn.close()

    # The wrapped socket saw exactly one close.
    assert mock_ws.close_calls == 1
    assert mock_ws.closed is True
|
|
|
|
|
|
@pytest.mark.trio
async def test_zero_byte_write_handling():
    """Empty writes are forwarded to the WebSocket just like non-empty ones."""
    from libp2p.transport.websocket.connection import P2PWebSocketConnection

    class MockWebSocketConnection:
        """WebSocket stand-in that records the size of every sent message."""

        def __init__(self):
            self.write_calls = []

        async def send_message(self, data: bytes) -> None:
            self.write_calls.append(len(data))

        async def get_message(self) -> bytes:
            return b"test message"

        async def aclose(self) -> None:
            return None

    mock_ws = MockWebSocketConnection()
    conn = P2PWebSocketConnection(mock_ws, is_secure=False)

    # A single empty write reaches the socket.
    await conn.write(b"")
    assert 0 in mock_ws.write_calls

    # So does a regular five-byte write.
    await conn.write(b"hello")
    assert 5 in mock_ws.write_calls

    # Ten more empty writes: 1 initial + 10 here = 11 zero-length sends.
    for _ in range(10):
        await conn.write(b"")
    assert mock_ws.write_calls.count(0) == 11
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_transport_protocols():
    """Both /ws and /wss multiaddrs validate and parse with the right WSS flag."""
    # No transport is constructed here; only the multiaddr helpers are used.
    plain_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
    secure_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/wss")

    for candidate in (plain_addr, secure_addr):
        assert is_valid_websocket_multiaddr(candidate)

    # Parse both before asserting on either result.
    plain_info, secure_info = (
        parse_websocket_multiaddr(candidate)
        for candidate in (plain_addr, secure_addr)
    )
    assert not plain_info.is_wss
    assert secure_info.is_wss
|
|
|
|
|
|
@pytest.mark.trio
async def test_websocket_listener_addr_format():
    """
    Test the settings carried by WS and WSS listeners.

    Despite its name, this test does not start a listener or inspect listen
    addresses: it checks the default handshake timeout on a plain listener and
    the TLS config plus handshake timeout on a listener built with a server
    TLS configuration.
    """
    import ssl

    upgrader = create_upgrader()

    # Test WS listener
    transport_ws = WebsocketTransport(upgrader)

    async def dummy_handler_ws(conn):
        await trio.sleep(0)

    listener_ws = transport_ws.create_listener(dummy_handler_ws)
    # getattr is used because the attribute is private to the listener
    assert hasattr(listener_ws, "_handshake_timeout")
    assert getattr(listener_ws, "_handshake_timeout") == 15.0  # Default timeout

    # Test WSS listener with TLS config.
    # Purpose.CLIENT_AUTH produces a server-side context; the default purpose
    # would produce a client-side one that cannot accept TLS connections.
    tls_config = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    transport_wss = WebsocketTransport(upgrader, tls_server_config=tls_config)

    async def dummy_handler_wss(conn):
        await trio.sleep(0)

    listener_wss = transport_wss.create_listener(dummy_handler_wss)
    # getattr is used because the attributes are private to the listener
    assert hasattr(listener_wss, "_tls_config")
    assert getattr(listener_wss, "_tls_config") is not None
    assert hasattr(listener_wss, "_handshake_timeout")
    assert getattr(listener_wss, "_handshake_timeout") == 15.0
|
|
|
|
|
|
@pytest.mark.trio
async def test_sni_resolution_limitation():
    """
    ``resolve`` hands every address back unchanged.

    Python multiaddr library doesn't support SNI protocol, so no SNI
    resolution takes place.
    """
    transport = WebsocketTransport(create_upgrader())

    # DNS-based WSS, DNS-based WS and IP-based WSS, checked in that order.
    for text in (
        "/dns/example.com/tcp/1234/wss",
        "/dns/example.com/tcp/1234/ws",
        "/ip4/127.0.0.1/tcp/1234/wss",
    ):
        original = Multiaddr(text)
        resolved = transport.resolve(original)
        assert len(resolved) == 1
        assert resolved[0] == original
|
|
|
|
|
|
@pytest.mark.trio
|
|
async def test_websocket_transport_can_dial():
|
|
"""Test WebSocket transport CanDial functionality similar to Go implementation."""
|
|
# upgrader = create_upgrader() # Not used in this test
|
|
# transport = WebsocketTransport(upgrader) # Not used in this test
|
|
|
|
# Test valid WebSocket addresses that should be dialable
|
|
valid_addresses = [
|
|
"/ip4/127.0.0.1/tcp/5555/ws",
|
|
"/ip4/127.0.0.1/tcp/5555/wss",
|
|
"/ip4/127.0.0.1/tcp/5555/tls/ws",
|
|
# Note: SNI addresses not supported by Python multiaddr library
|
|
]
|
|
|
|
for addr_str in valid_addresses:
|
|
maddr = Multiaddr(addr_str)
|
|
# All these should be valid WebSocket multiaddrs
|
|
assert is_valid_websocket_multiaddr(maddr), (
|
|
f"Address {addr_str} should be valid"
|
|
)
|
|
|
|
# Test invalid addresses that should not be dialable
|
|
invalid_addresses = [
|
|
"/ip4/127.0.0.1/tcp/5555", # No WebSocket protocol
|
|
"/ip4/127.0.0.1/udp/5555/ws", # Wrong transport protocol
|
|
]
|
|
|
|
for addr_str in invalid_addresses:
|
|
maddr = Multiaddr(addr_str)
|
|
# These should not be valid WebSocket multiaddrs
|
|
assert not is_valid_websocket_multiaddr(maddr), (
|
|
f"Address {addr_str} should be invalid"
|
|
)
|