mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge pull request #781 from GautamBytes/add-ws-transport
Add WebSocket transport support
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@ -178,6 +178,10 @@ env.bak/
|
||||
#lockfiles
|
||||
uv.lock
|
||||
poetry.lock
|
||||
tests/interop/js_libp2p/js_node/node_modules/
|
||||
tests/interop/js_libp2p/js_node/package-lock.json
|
||||
tests/interop/js_libp2p/js_node/src/node_modules/
|
||||
tests/interop/js_libp2p/js_node/src/package-lock.json
|
||||
|
||||
# Sphinx documentation build
|
||||
_build/
|
||||
|
||||
446
examples/test_tcp_data_transfer.py
Normal file
446
examples/test_tcp_data_transfer.py
Normal file
@ -0,0 +1,446 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TCP P2P Data Transfer Test
|
||||
|
||||
This test proves that TCP peer-to-peer data transfer works correctly in libp2p.
|
||||
This serves as a baseline to compare with WebSocket tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import create_yamux_muxer_option, new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
|
||||
# Test protocol for data exchange
|
||||
TCP_DATA_PROTOCOL = TProtocol("/test/tcp-data-exchange/1.0.0")
|
||||
|
||||
|
||||
async def create_tcp_host_pair():
|
||||
"""Create a pair of hosts configured for TCP communication."""
|
||||
# Create key pairs
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
|
||||
# Create security options (using plaintext for simplicity)
|
||||
def security_options(kp):
|
||||
return {
|
||||
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
|
||||
local_key_pair=kp, secure_bytes_provider=None, peerstore=None
|
||||
)
|
||||
}
|
||||
|
||||
# Host A (listener) - TCP transport (default)
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options(key_pair_a),
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
|
||||
)
|
||||
|
||||
# Host B (dialer) - TCP transport (default)
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options(key_pair_b),
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
|
||||
)
|
||||
|
||||
return host_a, host_b
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_basic_connection():
|
||||
"""Test basic TCP connection establishment."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
connection_established = False
|
||||
|
||||
async def connection_handler(stream):
|
||||
nonlocal connection_established
|
||||
connection_established = True
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(TCP_DATA_PROTOCOL, connection_handler)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert listen_addrs, "Host A should have listen addresses"
|
||||
|
||||
# Extract TCP address
|
||||
tcp_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||
tcp_addr = addr
|
||||
break
|
||||
|
||||
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||
|
||||
# Create peer info for host A
|
||||
peer_info = info_from_p2p_addr(tcp_addr)
|
||||
|
||||
# Host B connects to host A
|
||||
await host_b.connect(peer_info)
|
||||
print("✅ TCP connection established")
|
||||
|
||||
# Open a stream to test the connection
|
||||
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||
await stream.close()
|
||||
|
||||
# Wait a bit for the handler to be called
|
||||
await trio.sleep(0.1)
|
||||
|
||||
assert connection_established, "TCP connection handler should have been called"
|
||||
print("✅ TCP basic connection test successful!")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_data_transfer():
|
||||
"""Test TCP peer-to-peer data transfer."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
# Test data
|
||||
test_data = b"Hello TCP P2P Data Transfer! This is a test message."
|
||||
received_data = None
|
||||
transfer_complete = trio.Event()
|
||||
|
||||
async def data_handler(stream):
|
||||
nonlocal received_data
|
||||
try:
|
||||
# Read the incoming data
|
||||
received_data = await stream.read(len(test_data))
|
||||
# Echo it back to confirm successful transfer
|
||||
await stream.write(received_data)
|
||||
await stream.close()
|
||||
transfer_complete.set()
|
||||
except Exception as e:
|
||||
print(f"Handler error: {e}")
|
||||
transfer_complete.set()
|
||||
|
||||
host_a.set_stream_handler(TCP_DATA_PROTOCOL, data_handler)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert listen_addrs, "Host A should have listen addresses"
|
||||
|
||||
# Extract TCP address
|
||||
tcp_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||
tcp_addr = addr
|
||||
break
|
||||
|
||||
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||
|
||||
# Create peer info for host A
|
||||
peer_info = info_from_p2p_addr(tcp_addr)
|
||||
|
||||
# Host B connects to host A
|
||||
await host_b.connect(peer_info)
|
||||
print("✅ TCP connection established")
|
||||
|
||||
# Open a stream for data transfer
|
||||
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||
print("✅ TCP stream opened")
|
||||
|
||||
# Send test data
|
||||
await stream.write(test_data)
|
||||
print(f"📤 Sent data: {test_data}")
|
||||
|
||||
# Read echoed data back
|
||||
echoed_data = await stream.read(len(test_data))
|
||||
print(f"📥 Received echo: {echoed_data}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
# Wait for transfer to complete
|
||||
with trio.fail_after(5.0): # 5 second timeout
|
||||
await transfer_complete.wait()
|
||||
|
||||
# Verify data transfer
|
||||
assert received_data == test_data, (
|
||||
f"Data mismatch: {received_data} != {test_data}"
|
||||
)
|
||||
assert echoed_data == test_data, f"Echo mismatch: {echoed_data} != {test_data}"
|
||||
|
||||
print("✅ TCP P2P data transfer successful!")
|
||||
print(f" Original: {test_data}")
|
||||
print(f" Received: {received_data}")
|
||||
print(f" Echoed: {echoed_data}")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_large_data_transfer():
|
||||
"""Test TCP with larger data payloads."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
# Large test data (10KB)
|
||||
test_data = b"TCP Large Data Test! " * 500 # ~10KB
|
||||
received_data = None
|
||||
transfer_complete = trio.Event()
|
||||
|
||||
async def large_data_handler(stream):
|
||||
nonlocal received_data
|
||||
try:
|
||||
# Read data in chunks
|
||||
chunks = []
|
||||
total_received = 0
|
||||
expected_size = len(test_data)
|
||||
|
||||
while total_received < expected_size:
|
||||
chunk = await stream.read(min(1024, expected_size - total_received))
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
total_received += len(chunk)
|
||||
|
||||
received_data = b"".join(chunks)
|
||||
|
||||
# Send back confirmation
|
||||
await stream.write(b"RECEIVED_OK")
|
||||
await stream.close()
|
||||
transfer_complete.set()
|
||||
except Exception as e:
|
||||
print(f"Large data handler error: {e}")
|
||||
transfer_complete.set()
|
||||
|
||||
host_a.set_stream_handler(TCP_DATA_PROTOCOL, large_data_handler)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert listen_addrs, "Host A should have listen addresses"
|
||||
|
||||
# Extract TCP address
|
||||
tcp_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||
tcp_addr = addr
|
||||
break
|
||||
|
||||
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||
print(f"📊 Test data size: {len(test_data)} bytes")
|
||||
|
||||
# Create peer info for host A
|
||||
peer_info = info_from_p2p_addr(tcp_addr)
|
||||
|
||||
# Host B connects to host A
|
||||
await host_b.connect(peer_info)
|
||||
print("✅ TCP connection established")
|
||||
|
||||
# Open a stream for data transfer
|
||||
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||
print("✅ TCP stream opened")
|
||||
|
||||
# Send large test data in chunks
|
||||
chunk_size = 1024
|
||||
sent_bytes = 0
|
||||
for i in range(0, len(test_data), chunk_size):
|
||||
chunk = test_data[i : i + chunk_size]
|
||||
await stream.write(chunk)
|
||||
sent_bytes += len(chunk)
|
||||
if sent_bytes % (chunk_size * 4) == 0: # Progress every 4KB
|
||||
print(f"📤 Sent {sent_bytes}/{len(test_data)} bytes")
|
||||
|
||||
print(f"📤 Sent all {len(test_data)} bytes")
|
||||
|
||||
# Read confirmation
|
||||
confirmation = await stream.read(1024)
|
||||
print(f"📥 Received confirmation: {confirmation}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
# Wait for transfer to complete
|
||||
with trio.fail_after(10.0): # 10 second timeout for large data
|
||||
await transfer_complete.wait()
|
||||
|
||||
# Verify data transfer
|
||||
assert received_data is not None, "No data was received"
|
||||
assert received_data == test_data, (
|
||||
"Large data transfer failed:"
|
||||
+ f" sizes {len(received_data)} != {len(test_data)}"
|
||||
)
|
||||
assert confirmation == b"RECEIVED_OK", f"Confirmation failed: {confirmation}"
|
||||
|
||||
print("✅ TCP large data transfer successful!")
|
||||
print(f" Data size: {len(test_data)} bytes")
|
||||
print(f" Received: {len(received_data)} bytes")
|
||||
print(f" Match: {received_data == test_data}")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_bidirectional_transfer():
|
||||
"""Test bidirectional data transfer over TCP."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
# Test data
|
||||
data_a_to_b = b"Message from Host A to Host B via TCP"
|
||||
data_b_to_a = b"Response from Host B to Host A via TCP"
|
||||
|
||||
received_on_a = None
|
||||
received_on_b = None
|
||||
transfer_complete_a = trio.Event()
|
||||
transfer_complete_b = trio.Event()
|
||||
|
||||
async def handler_a(stream):
|
||||
nonlocal received_on_a
|
||||
try:
|
||||
# Read data from B
|
||||
received_on_a = await stream.read(len(data_b_to_a))
|
||||
print(f"🅰️ Host A received: {received_on_a}")
|
||||
await stream.close()
|
||||
transfer_complete_a.set()
|
||||
except Exception as e:
|
||||
print(f"Handler A error: {e}")
|
||||
transfer_complete_a.set()
|
||||
|
||||
async def handler_b(stream):
|
||||
nonlocal received_on_b
|
||||
try:
|
||||
# Read data from A
|
||||
received_on_b = await stream.read(len(data_a_to_b))
|
||||
print(f"🅱️ Host B received: {received_on_b}")
|
||||
await stream.close()
|
||||
transfer_complete_b.set()
|
||||
except Exception as e:
|
||||
print(f"Handler B error: {e}")
|
||||
transfer_complete_b.set()
|
||||
|
||||
# Set up handlers on both hosts
|
||||
protocol_a_to_b = TProtocol("/test/tcp-a-to-b/1.0.0")
|
||||
protocol_b_to_a = TProtocol("/test/tcp-b-to-a/1.0.0")
|
||||
|
||||
host_a.set_stream_handler(protocol_b_to_a, handler_a)
|
||||
host_b.set_stream_handler(protocol_a_to_b, handler_b)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
):
|
||||
# Get addresses
|
||||
addrs_a = host_a.get_addrs()
|
||||
addrs_b = host_b.get_addrs()
|
||||
|
||||
assert addrs_a and addrs_b, "Both hosts should have addresses"
|
||||
|
||||
# Extract TCP addresses
|
||||
tcp_addr_a = next(
|
||||
(
|
||||
addr
|
||||
for addr in addrs_a
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr)
|
||||
),
|
||||
None,
|
||||
)
|
||||
tcp_addr_b = next(
|
||||
(
|
||||
addr
|
||||
for addr in addrs_b
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
assert tcp_addr_a and tcp_addr_b, (
|
||||
f"TCP addresses not found: A={addrs_a}, B={addrs_b}"
|
||||
)
|
||||
print(f"🔗 Host A listening on: {tcp_addr_a}")
|
||||
print(f"🔗 Host B listening on: {tcp_addr_b}")
|
||||
|
||||
# Create peer infos
|
||||
peer_info_a = info_from_p2p_addr(tcp_addr_a)
|
||||
peer_info_b = info_from_p2p_addr(tcp_addr_b)
|
||||
|
||||
# Establish connections
|
||||
await host_b.connect(peer_info_a)
|
||||
await host_a.connect(peer_info_b)
|
||||
print("✅ Bidirectional TCP connections established")
|
||||
|
||||
# Send data A -> B
|
||||
stream_a_to_b = await host_a.new_stream(peer_info_b.peer_id, [protocol_a_to_b])
|
||||
await stream_a_to_b.write(data_a_to_b)
|
||||
print(f"📤 A->B: {data_a_to_b}")
|
||||
await stream_a_to_b.close()
|
||||
|
||||
# Send data B -> A
|
||||
stream_b_to_a = await host_b.new_stream(peer_info_a.peer_id, [protocol_b_to_a])
|
||||
await stream_b_to_a.write(data_b_to_a)
|
||||
print(f"📤 B->A: {data_b_to_a}")
|
||||
await stream_b_to_a.close()
|
||||
|
||||
# Wait for both transfers to complete
|
||||
with trio.fail_after(5.0):
|
||||
await transfer_complete_a.wait()
|
||||
await transfer_complete_b.wait()
|
||||
|
||||
# Verify bidirectional transfer
|
||||
assert received_on_a == data_b_to_a, f"A received wrong data: {received_on_a}"
|
||||
assert received_on_b == data_a_to_b, f"B received wrong data: {received_on_b}"
|
||||
|
||||
print("✅ TCP bidirectional data transfer successful!")
|
||||
print(f" A->B: {data_a_to_b}")
|
||||
print(f" B->A: {data_b_to_a}")
|
||||
print(f" ✓ A got: {received_on_a}")
|
||||
print(f" ✓ B got: {received_on_b}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
print("🧪 Running TCP P2P Data Transfer Tests")
|
||||
print("=" * 50)
|
||||
|
||||
async def run_all_tcp_tests():
|
||||
try:
|
||||
print("\n1. Testing basic TCP connection...")
|
||||
await test_tcp_basic_connection()
|
||||
except Exception as e:
|
||||
print(f"❌ Basic TCP connection test failed: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("\n2. Testing TCP data transfer...")
|
||||
await test_tcp_data_transfer()
|
||||
except Exception as e:
|
||||
print(f"❌ TCP data transfer test failed: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("\n3. Testing TCP large data transfer...")
|
||||
await test_tcp_large_data_transfer()
|
||||
except Exception as e:
|
||||
print(f"❌ TCP large data transfer test failed: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("\n4. Testing TCP bidirectional transfer...")
|
||||
await test_tcp_bidirectional_transfer()
|
||||
except Exception as e:
|
||||
print(f"❌ TCP bidirectional transfer test failed: {e}")
|
||||
return
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("🏁 TCP P2P Tests Complete - All Tests PASSED!")
|
||||
|
||||
trio.run(run_all_tcp_tests)
|
||||
210
examples/transport_integration_demo.py
Normal file
210
examples/transport_integration_demo.py
Normal file
@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Demo script showing the new transport integration capabilities in py-libp2p.
|
||||
|
||||
This script demonstrates:
|
||||
1. How to use the transport registry
|
||||
2. How to create transports dynamically based on multiaddrs
|
||||
3. How to register custom transports
|
||||
4. How the new system automatically selects the right transport
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add the libp2p directory to the path so we can import it
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.transport import (
|
||||
create_transport,
|
||||
create_transport_for_multiaddr,
|
||||
get_supported_transport_protocols,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
)
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def demo_transport_registry():
|
||||
"""Demonstrate the transport registry functionality."""
|
||||
print("🔧 Transport Registry Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Get the global registry
|
||||
registry = get_transport_registry()
|
||||
|
||||
# Show supported protocols
|
||||
supported = get_supported_transport_protocols()
|
||||
print(f"Supported transport protocols: {supported}")
|
||||
|
||||
# Show registered transports
|
||||
print("\nRegistered transports:")
|
||||
for protocol in supported:
|
||||
transport_class = registry.get_transport(protocol)
|
||||
class_name = transport_class.__name__ if transport_class else "None"
|
||||
print(f" {protocol}: {class_name}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_transport_factory():
|
||||
"""Demonstrate the transport factory functions."""
|
||||
print("🏭 Transport Factory Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Create a dummy upgrader for WebSocket transport
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Create transports using the factory function
|
||||
try:
|
||||
tcp_transport = create_transport("tcp")
|
||||
print(f"✅ Created TCP transport: {type(tcp_transport).__name__}")
|
||||
|
||||
ws_transport = create_transport("ws", upgrader)
|
||||
print(f"✅ Created WebSocket transport: {type(ws_transport).__name__}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating transport: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_multiaddr_transport_selection():
|
||||
"""Demonstrate automatic transport selection based on multiaddrs."""
|
||||
print("🎯 Multiaddr Transport Selection Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Create a dummy upgrader
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Test different multiaddr types
|
||||
test_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/8080",
|
||||
"/ip4/127.0.0.1/tcp/8080/ws",
|
||||
"/ip6/::1/tcp/8080/ws",
|
||||
"/dns4/example.com/tcp/443/ws",
|
||||
]
|
||||
|
||||
for addr_str in test_addrs:
|
||||
try:
|
||||
maddr = multiaddr.Multiaddr(addr_str)
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
if transport:
|
||||
print(f"✅ {addr_str} -> {type(transport).__name__}")
|
||||
else:
|
||||
print(f"❌ {addr_str} -> No transport found")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {addr_str} -> Error: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_custom_transport_registration():
|
||||
"""Demonstrate how to register custom transports."""
|
||||
print("🔧 Custom Transport Registration Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Show current supported protocols
|
||||
print(f"Before registration: {get_supported_transport_protocols()}")
|
||||
|
||||
# Register a custom transport (using TCP as an example)
|
||||
class CustomTCPTransport(TCP):
|
||||
"""Custom TCP transport for demonstration."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.custom_flag = True
|
||||
|
||||
# Register the custom transport
|
||||
register_transport("custom_tcp", CustomTCPTransport)
|
||||
|
||||
# Show updated supported protocols
|
||||
print(f"After registration: {get_supported_transport_protocols()}")
|
||||
|
||||
# Test creating the custom transport
|
||||
try:
|
||||
custom_transport = create_transport("custom_tcp")
|
||||
print(f"✅ Created custom transport: {type(custom_transport).__name__}")
|
||||
# Check if it has the custom flag (type-safe way)
|
||||
if hasattr(custom_transport, "custom_flag"):
|
||||
flag_value = getattr(custom_transport, "custom_flag", "Not found")
|
||||
print(f" Custom flag: {flag_value}")
|
||||
else:
|
||||
print(" Custom flag: Not found")
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating custom transport: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_integration_with_libp2p():
|
||||
"""Demonstrate how the new system integrates with libp2p."""
|
||||
print("🚀 Libp2p Integration Demo")
|
||||
print("=" * 50)
|
||||
|
||||
print("The new transport system integrates seamlessly with libp2p:")
|
||||
print()
|
||||
print("1. ✅ Automatic transport selection based on multiaddr")
|
||||
print("2. ✅ Support for WebSocket (/ws) protocol")
|
||||
print("3. ✅ Fallback to TCP for backward compatibility")
|
||||
print("4. ✅ Easy registration of new transport protocols")
|
||||
print("5. ✅ No changes needed to existing libp2p code")
|
||||
print()
|
||||
|
||||
print("Example usage in libp2p:")
|
||||
print(" # This will automatically use WebSocket transport")
|
||||
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])")
|
||||
print()
|
||||
print(" # This will automatically use TCP transport")
|
||||
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])")
|
||||
print()
|
||||
|
||||
print()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all demos."""
|
||||
print("🎉 Py-libp2p Transport Integration Demo")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Run all demos
|
||||
demo_transport_registry()
|
||||
demo_transport_factory()
|
||||
demo_multiaddr_transport_selection()
|
||||
demo_custom_transport_registration()
|
||||
demo_integration_with_libp2p()
|
||||
|
||||
print("🎯 Summary of New Features:")
|
||||
print("=" * 40)
|
||||
print("✅ Transport Registry: Central registry for all transport implementations")
|
||||
print("✅ Dynamic Transport Selection: Automatic selection based on multiaddr")
|
||||
print("✅ WebSocket Support: Full /ws protocol support")
|
||||
print("✅ Extensible Architecture: Easy to add new transport protocols")
|
||||
print("✅ Backward Compatibility: Existing TCP code continues to work")
|
||||
print("✅ Factory Functions: Simple API for creating transports")
|
||||
print()
|
||||
print("🚀 The transport system is now ready for production use!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Demo interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Demo failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
220
examples/websocket/test_tcp_echo.py
Normal file
220
examples/websocket/test_tcp_echo.py
Normal file
@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple TCP echo demo to verify basic libp2p functionality.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
# Enable debug logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger("libp2p.tcp-example")
|
||||
|
||||
# Simple echo protocol
|
||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def echo_handler(stream):
|
||||
"""Simple echo handler that echoes back any data received."""
|
||||
try:
|
||||
data = await stream.read(1024)
|
||||
if data:
|
||||
message = data.decode("utf-8", errors="replace")
|
||||
print(f"📥 Received: {message}")
|
||||
print(f"📤 Echoing back: {message}")
|
||||
await stream.write(data)
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Echo handler error: {e}")
|
||||
await stream.close()
|
||||
|
||||
|
||||
def create_tcp_host():
|
||||
"""Create a host with TCP transport."""
|
||||
# Create key pair and peer store
|
||||
key_pair = create_new_key_pair()
|
||||
peer_id = ID.from_pubkey(key_pair.public_key)
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(peer_id, key_pair)
|
||||
|
||||
# Create transport upgrader with plaintext security
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
|
||||
# Create TCP transport
|
||||
transport = TCP()
|
||||
|
||||
# Create swarm and host
|
||||
swarm = Swarm(peer_id, peer_store, upgrader, transport)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
return host
|
||||
|
||||
|
||||
async def run(port: int, destination: str) -> None:
|
||||
localhost_ip = "0.0.0.0"
|
||||
|
||||
if not destination:
|
||||
# Create first host (listener) with TCP transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||
|
||||
try:
|
||||
host = create_tcp_host()
|
||||
logger.debug("Created TCP host")
|
||||
|
||||
# Set up echo handler
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||
# connections
|
||||
addrs = host.get_addrs()
|
||||
logger.debug(f"Host addresses: {addrs}")
|
||||
if not addrs:
|
||||
print("❌ Error: No addresses found for the host")
|
||||
return
|
||||
|
||||
server_addr = str(addrs[0])
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
print("🌐 TCP Server Started Successfully!")
|
||||
print("=" * 50)
|
||||
print(f"📍 Server Address: {client_addr}")
|
||||
print("🔧 Protocol: /echo/1.0.0")
|
||||
print("🚀 Transport: TCP")
|
||||
print()
|
||||
print("📋 To test the connection, run this in another terminal:")
|
||||
print(f" python test_tcp_echo.py -d {client_addr}")
|
||||
print()
|
||||
print("⏳ Waiting for incoming TCP connections...")
|
||||
print("─" * 50)
|
||||
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating TCP server: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
# Create second host (dialer) with TCP transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||
|
||||
try:
|
||||
# Create a single host for client operations
|
||||
host = create_tcp_host()
|
||||
|
||||
# Start the host for client operations
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
print("🔌 TCP Client Starting...")
|
||||
print("=" * 40)
|
||||
print(f"🎯 Target Peer: {info.peer_id}")
|
||||
print(f"📍 Target Address: {destination}")
|
||||
print()
|
||||
|
||||
try:
|
||||
print("🔗 Connecting to TCP server...")
|
||||
await host.connect(info)
|
||||
print("✅ Successfully connected to TCP server!")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print("\n❌ Connection Failed!")
|
||||
print(f" Peer ID: {info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Error: {error_msg}")
|
||||
return
|
||||
|
||||
# Create a stream and send test data
|
||||
try:
|
||||
stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID])
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to create stream: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("🚀 Starting Echo Protocol Test...")
|
||||
print("─" * 40)
|
||||
|
||||
# Send test data
|
||||
test_message = b"Hello TCP Transport!"
|
||||
print(f"📤 Sending message: {test_message.decode('utf-8')}")
|
||||
await stream.write(test_message)
|
||||
|
||||
# Read response
|
||||
print("⏳ Waiting for server response...")
|
||||
response = await stream.read(1024)
|
||||
print(f"📥 Received response: {response.decode('utf-8')}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
print("─" * 40)
|
||||
if response == test_message:
|
||||
print("🎉 Echo test successful!")
|
||||
print("✅ TCP transport is working perfectly!")
|
||||
else:
|
||||
print("❌ Echo test failed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Echo protocol error: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
print("✅ TCP demo completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating TCP client: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
def main() -> None:
|
||||
description = "Simple TCP echo demo for libp2p"
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-d", "--destination", type=str, help="destination multiaddr string"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, args.destination)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
examples/websocket/test_websocket_transport.py
Normal file
145
examples/websocket/test_websocket_transport.py
Normal file
@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple test script to verify WebSocket transport functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add the libp2p directory to the path so we can import it
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.transport import create_transport, create_transport_for_multiaddr
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def test_websocket_transport():
|
||||
"""Test basic WebSocket transport functionality."""
|
||||
print("🧪 Testing WebSocket Transport Functionality")
|
||||
print("=" * 50)
|
||||
|
||||
# Create a dummy upgrader
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Test creating WebSocket transport
|
||||
try:
|
||||
ws_transport = create_transport("ws", upgrader)
|
||||
print(f"✅ WebSocket transport created: {type(ws_transport).__name__}")
|
||||
|
||||
# Test creating transport from multiaddr
|
||||
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
|
||||
print(
|
||||
f"✅ WebSocket transport from multiaddr: "
|
||||
f"{type(ws_transport_from_maddr).__name__}"
|
||||
)
|
||||
|
||||
# Test creating listener
|
||||
handler_called = False
|
||||
|
||||
async def test_handler(conn):
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
print(f"✅ Connection handler called with: {type(conn).__name__}")
|
||||
await conn.close()
|
||||
|
||||
listener = ws_transport.create_listener(test_handler)
|
||||
print(f"✅ WebSocket listener created: {type(listener).__name__}")
|
||||
|
||||
# Test that the transport can be used
|
||||
print(
|
||||
f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}"
|
||||
)
|
||||
print(
|
||||
f"✅ WebSocket transport supports listening: "
|
||||
f"{hasattr(ws_transport, 'create_listener')}"
|
||||
)
|
||||
|
||||
print("\n🎯 WebSocket Transport Test Results:")
|
||||
print("✅ Transport creation: PASS")
|
||||
print("✅ Multiaddr parsing: PASS")
|
||||
print("✅ Listener creation: PASS")
|
||||
print("✅ Interface compliance: PASS")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ WebSocket transport test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def test_transport_registry():
|
||||
"""Test the transport registry functionality."""
|
||||
print("\n🔧 Testing Transport Registry")
|
||||
print("=" * 30)
|
||||
|
||||
from libp2p.transport import (
|
||||
get_supported_transport_protocols,
|
||||
get_transport_registry,
|
||||
)
|
||||
|
||||
registry = get_transport_registry()
|
||||
supported = get_supported_transport_protocols()
|
||||
|
||||
print(f"Supported protocols: {supported}")
|
||||
|
||||
# Test getting transports
|
||||
for protocol in supported:
|
||||
transport_class = registry.get_transport(protocol)
|
||||
class_name = transport_class.__name__ if transport_class else "None"
|
||||
print(f" {protocol}: {class_name}")
|
||||
|
||||
# Test creating transports through registry
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
for protocol in supported:
|
||||
try:
|
||||
transport = registry.create_transport(protocol, upgrader)
|
||||
if transport:
|
||||
print(f"✅ {protocol}: Created successfully")
|
||||
else:
|
||||
print(f"❌ {protocol}: Failed to create")
|
||||
except Exception as e:
|
||||
print(f"❌ {protocol}: Error - {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 WebSocket Transport Integration Test Suite")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Run tests
|
||||
success = await test_websocket_transport()
|
||||
await test_transport_registry()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if success:
|
||||
print("🎉 All tests passed! WebSocket transport is working correctly.")
|
||||
else:
|
||||
print("❌ Some tests failed. Check the output above for details.")
|
||||
|
||||
print("\n🚀 WebSocket transport is ready for use in py-libp2p!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Test interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
448
examples/websocket/websocket_demo.py
Normal file
448
examples/websocket/websocket_demo.py
Normal file
@ -0,0 +1,448 @@
|
||||
import argparse
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import INotifee
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
# Enable debug logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger("libp2p.websocket-example")
|
||||
|
||||
|
||||
# Suppress KeyboardInterrupt by handling SIGINT directly
|
||||
def signal_handler(signum, frame):
|
||||
print("✅ Clean exit completed.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Simple echo protocol
|
||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def echo_handler(stream):
|
||||
"""Simple echo handler that echoes back any data received."""
|
||||
try:
|
||||
data = await stream.read(1024)
|
||||
if data:
|
||||
message = data.decode("utf-8", errors="replace")
|
||||
print(f"📥 Received: {message}")
|
||||
print(f"📤 Echoing back: {message}")
|
||||
await stream.write(data)
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Echo handler error: {e}")
|
||||
await stream.close()
|
||||
|
||||
|
||||
def create_websocket_host(listen_addrs=None, use_plaintext=False):
|
||||
"""Create a host with WebSocket transport."""
|
||||
# Create key pair and peer store
|
||||
key_pair = create_new_key_pair()
|
||||
peer_id = ID.from_pubkey(key_pair.public_key)
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(peer_id, key_pair)
|
||||
|
||||
if use_plaintext:
|
||||
# Create transport upgrader with plaintext security
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
else:
|
||||
# Create separate Ed25519 key for Noise protocol
|
||||
noise_key_pair = create_ed25519_key_pair()
|
||||
|
||||
# Create Noise transport
|
||||
noise_transport = NoiseTransport(
|
||||
libp2p_keypair=key_pair,
|
||||
noise_privkey=noise_key_pair.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
|
||||
# Create transport upgrader with Noise security
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(NOISE_PROTOCOL_ID): noise_transport
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
|
||||
# Create WebSocket transport
|
||||
transport = WebsocketTransport(upgrader)
|
||||
|
||||
# Create swarm and host
|
||||
swarm = Swarm(peer_id, peer_store, upgrader, transport)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
return host
|
||||
|
||||
|
||||
async def run(port: int, destination: str, use_plaintext: bool = False) -> None:
|
||||
localhost_ip = "0.0.0.0"
|
||||
|
||||
if not destination:
|
||||
# Create first host (listener) with WebSocket transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||
|
||||
try:
|
||||
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||
logger.debug(f"Created host with use_plaintext={use_plaintext}")
|
||||
|
||||
# Set up echo handler
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||
|
||||
# Add connection event handlers for debugging
|
||||
class DebugNotifee(INotifee):
|
||||
async def opened_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def connected(self, network, conn):
|
||||
print(
|
||||
f"🔗 New libp2p connection established: "
|
||||
f"{conn.muxed_conn.peer_id}"
|
||||
)
|
||||
if hasattr(conn.muxed_conn, "get_security_protocol"):
|
||||
security = conn.muxed_conn.get_security_protocol()
|
||||
else:
|
||||
security = "Unknown"
|
||||
|
||||
print(f" Security: {security}")
|
||||
|
||||
async def disconnected(self, network, conn):
|
||||
print(f"🔌 libp2p connection closed: {conn.muxed_conn.peer_id}")
|
||||
|
||||
async def listen(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
async def listen_close(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
host.get_network().register_notifee(DebugNotifee())
|
||||
|
||||
# Create a cancellation token for clean shutdown
|
||||
cancel_scope = trio.CancelScope()
|
||||
|
||||
async def signal_handler():
|
||||
with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as (
|
||||
signal_receiver
|
||||
):
|
||||
async for sig in signal_receiver:
|
||||
print(f"\n🛑 Received signal {sig}")
|
||||
print("✅ Shutting down WebSocket server...")
|
||||
cancel_scope.cancel()
|
||||
return
|
||||
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Start the signal handler
|
||||
nursery.start_soon(signal_handler)
|
||||
|
||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||
# connections
|
||||
addrs = host.get_addrs()
|
||||
logger.debug(f"Host addresses: {addrs}")
|
||||
if not addrs:
|
||||
print("❌ Error: No addresses found for the host")
|
||||
print("Debug: host.get_addrs() returned empty list")
|
||||
return
|
||||
|
||||
server_addr = str(addrs[0])
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
print("🌐 WebSocket Server Started Successfully!")
|
||||
print("=" * 50)
|
||||
print(f"📍 Server Address: {client_addr}")
|
||||
print("🔧 Protocol: /echo/1.0.0")
|
||||
print("🚀 Transport: WebSocket (/ws)")
|
||||
print()
|
||||
print("📋 To test the connection, run this in another terminal:")
|
||||
plaintext_flag = " --plaintext" if use_plaintext else ""
|
||||
print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}")
|
||||
print()
|
||||
print("⏳ Waiting for incoming WebSocket connections...")
|
||||
print("─" * 50)
|
||||
|
||||
# Add a custom handler to show connection events
|
||||
async def custom_echo_handler(stream):
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print("\n🔗 New WebSocket Connection!")
|
||||
print(f" Peer ID: {peer_id}")
|
||||
print(" Protocol: /echo/1.0.0")
|
||||
|
||||
# Show remote address in multiaddr format
|
||||
try:
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
print(f" Remote: {remote_address}")
|
||||
except Exception:
|
||||
print(" Remote: Unknown")
|
||||
|
||||
print(" ─" * 40)
|
||||
|
||||
# Call the original handler
|
||||
await echo_handler(stream)
|
||||
|
||||
print(" ─" * 40)
|
||||
print(f"✅ Echo request completed for peer: {peer_id}")
|
||||
print()
|
||||
|
||||
# Replace the handler with our custom one
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
|
||||
|
||||
# Wait indefinitely or until cancelled
|
||||
with cancel_scope:
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating WebSocket server: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
# Create second host (dialer) with WebSocket transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||
|
||||
try:
|
||||
# Create a single host for client operations
|
||||
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||
|
||||
# Start the host for client operations
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Add connection event handlers for debugging
|
||||
class ClientDebugNotifee(INotifee):
|
||||
async def opened_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def connected(self, network, conn):
|
||||
print(
|
||||
f"🔗 Client: libp2p connection established: "
|
||||
f"{conn.muxed_conn.peer_id}"
|
||||
)
|
||||
|
||||
async def disconnected(self, network, conn):
|
||||
print(
|
||||
f"🔌 Client: libp2p connection closed: "
|
||||
f"{conn.muxed_conn.peer_id}"
|
||||
)
|
||||
|
||||
async def listen(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
async def listen_close(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
host.get_network().register_notifee(ClientDebugNotifee())
|
||||
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
print("🔌 WebSocket Client Starting...")
|
||||
print("=" * 40)
|
||||
print(f"🎯 Target Peer: {info.peer_id}")
|
||||
print(f"📍 Target Address: {destination}")
|
||||
print()
|
||||
|
||||
try:
|
||||
print("🔗 Connecting to WebSocket server...")
|
||||
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||
await host.connect(info)
|
||||
print("✅ Successfully connected to WebSocket server!")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print("\n❌ Connection Failed!")
|
||||
print(f" Peer ID: {info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||
print(f" Error: {error_msg}")
|
||||
print(f" Error type: {type(e).__name__}")
|
||||
|
||||
# Add more detailed error information for debugging
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
print(f" Root cause: {e.__cause__}")
|
||||
print(f" Root cause type: {type(e.__cause__).__name__}")
|
||||
|
||||
print()
|
||||
print("💡 Troubleshooting:")
|
||||
print(" • Make sure the WebSocket server is running")
|
||||
print(" • Check that the server address is correct")
|
||||
print(" • Verify the server is listening on the right port")
|
||||
print(
|
||||
" • Ensure both client and server use the same sec protocol"
|
||||
)
|
||||
if not use_plaintext:
|
||||
print(" • Noise over WebSocket may have compatibility issues")
|
||||
return
|
||||
|
||||
# Create a stream and send test data
|
||||
try:
|
||||
stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID])
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to create stream: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("🚀 Starting Echo Protocol Test...")
|
||||
print("─" * 40)
|
||||
|
||||
# Send test data
|
||||
test_message = b"Hello WebSocket Transport!"
|
||||
print(f"📤 Sending message: {test_message.decode('utf-8')}")
|
||||
await stream.write(test_message)
|
||||
|
||||
# Read response
|
||||
print("⏳ Waiting for server response...")
|
||||
response = await stream.read(1024)
|
||||
print(f"📥 Received response: {response.decode('utf-8')}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
print("─" * 40)
|
||||
if response == test_message:
|
||||
print("🎉 Echo test successful!")
|
||||
print("✅ WebSocket transport is working perfectly!")
|
||||
print("✅ Client completed successfully, exiting.")
|
||||
else:
|
||||
print("❌ Echo test failed!")
|
||||
print(" Response doesn't match sent data.")
|
||||
print(f" Sent: {test_message}")
|
||||
print(f" Received: {response}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"Echo protocol error: {error_msg}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Ensure stream is closed
|
||||
try:
|
||||
if stream:
|
||||
# Check if stream has is_closed method and use it
|
||||
has_is_closed = hasattr(stream, "is_closed") and callable(
|
||||
getattr(stream, "is_closed")
|
||||
)
|
||||
if has_is_closed:
|
||||
# type: ignore[attr-defined]
|
||||
if not await stream.is_closed():
|
||||
await stream.close()
|
||||
else:
|
||||
# Fallback: just try to close the stream
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# host.run() context manager handles cleanup automatically
|
||||
print()
|
||||
print("🎉 WebSocket Demo Completed Successfully!")
|
||||
print("=" * 50)
|
||||
print("✅ WebSocket transport is working perfectly!")
|
||||
print("✅ Echo protocol communication successful!")
|
||||
print("✅ libp2p integration verified!")
|
||||
print()
|
||||
print("🚀 Your WebSocket transport is ready for production use!")
|
||||
|
||||
# Add a small delay to ensure all cleanup is complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating WebSocket client: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
def main() -> None:
|
||||
description = """
|
||||
This program demonstrates the libp2p WebSocket transport.
|
||||
First run
|
||||
'python websocket_demo.py -p <PORT> [--plaintext]' to start a WebSocket server.
|
||||
Then run
|
||||
'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--plaintext]'
|
||||
where <DESTINATION> is the multiaddress shown by the server.
|
||||
|
||||
By default, this example uses Noise encryption for secure communication.
|
||||
Use --plaintext for testing with unencrypted communication
|
||||
(not recommended for production).
|
||||
"""
|
||||
|
||||
example_maddr = (
|
||||
"/ip4/127.0.0.1/tcp/8888/ws/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plaintext",
|
||||
action="store_true",
|
||||
help=(
|
||||
"use plaintext security instead of Noise encryption "
|
||||
"(not recommended for production)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine security mode: use Noise by default,
|
||||
# plaintext if --plaintext is specified
|
||||
use_plaintext = args.plaintext
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, args.destination, use_plaintext)
|
||||
except KeyboardInterrupt:
|
||||
# This is expected when Ctrl+C is pressed
|
||||
# The signal handler already printed the shutdown message
|
||||
print("✅ Clean exit completed.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {e}")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,6 +1,7 @@
|
||||
"""Libp2p Python implementation."""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
from typing import Any
|
||||
@ -24,6 +25,7 @@ from libp2p.abc import (
|
||||
IPeerRouting,
|
||||
IPeerStore,
|
||||
ISecureTransport,
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
KeyPair,
|
||||
@ -80,6 +82,10 @@ from libp2p.transport.tcp.tcp import (
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
from libp2p.transport.transport_registry import (
|
||||
create_transport_for_multiaddr,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from libp2p.utils.logging import (
|
||||
setup_logging,
|
||||
)
|
||||
@ -174,7 +180,10 @@ def new_swarm(
|
||||
enable_quic: bool = False,
|
||||
retry_config: Optional["RetryConfig"] = None,
|
||||
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
) -> INetworkService:
|
||||
logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}")
|
||||
"""
|
||||
Create a swarm instance based on the parameters.
|
||||
|
||||
@ -198,7 +207,7 @@ def new_swarm(
|
||||
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
transport: TCP | QUICTransport
|
||||
transport: TCP | QUICTransport | ITransport
|
||||
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
|
||||
|
||||
if listen_addrs is None:
|
||||
@ -207,14 +216,39 @@ def new_swarm(
|
||||
else:
|
||||
transport = TCP()
|
||||
else:
|
||||
# Use transport registry to select the appropriate transport
|
||||
from libp2p.transport.transport_registry import create_transport_for_multiaddr
|
||||
|
||||
# Create a temporary upgrader for transport selection
|
||||
# We'll create the real upgrader later with the proper configuration
|
||||
temp_upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={},
|
||||
muxer_transports_by_protocol={}
|
||||
)
|
||||
|
||||
addr = listen_addrs[0]
|
||||
is_quic = is_quic_multiaddr(addr)
|
||||
if addr.__contains__("tcp"):
|
||||
transport = TCP()
|
||||
elif is_quic:
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
else:
|
||||
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
|
||||
logger.debug(f"new_swarm: Creating transport for address: {addr}")
|
||||
transport_maybe = create_transport_for_multiaddr(
|
||||
addr,
|
||||
temp_upgrader,
|
||||
private_key=key_pair.private_key,
|
||||
config=quic_transport_opt,
|
||||
tls_client_config=tls_client_config,
|
||||
tls_server_config=tls_server_config
|
||||
)
|
||||
|
||||
if transport_maybe is None:
|
||||
raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}")
|
||||
|
||||
transport = transport_maybe
|
||||
logger.debug(f"new_swarm: Created transport: {type(transport)}")
|
||||
|
||||
# If enable_quic is True but we didn't get a QUIC transport, force QUIC
|
||||
if enable_quic and not isinstance(transport, QUICTransport):
|
||||
logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})")
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
|
||||
logger.debug(f"new_swarm: Final transport type: {type(transport)}")
|
||||
|
||||
# Generate X25519 keypair for Noise
|
||||
noise_key_pair = create_new_x25519_key_pair()
|
||||
@ -255,6 +289,7 @@ def new_swarm(
|
||||
muxer_transports_by_protocol=muxer_transports_by_protocol,
|
||||
)
|
||||
|
||||
|
||||
peerstore = peerstore_opt or PeerStore()
|
||||
# Store our key pair in peerstore
|
||||
peerstore.add_key_pair(id_opt, key_pair)
|
||||
@ -282,6 +317,8 @@ def new_host(
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
enable_quic: bool = False,
|
||||
quic_transport_opt: QUICTransportConfig | None = None,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
) -> IHost:
|
||||
"""
|
||||
Create a new libp2p host based on the given parameters.
|
||||
@ -296,7 +333,9 @@ def new_host(
|
||||
:param enable_mDNS: whether to enable mDNS discovery
|
||||
:param bootstrap: optional list of bootstrap peer addresses as strings
|
||||
:param enable_quic: optinal choice to use QUIC for transport
|
||||
:param transport_opt: optional configuration for quic transport
|
||||
:param quic_transport_opt: optional configuration for quic transport
|
||||
:param tls_client_config: optional TLS client configuration for WebSocket transport
|
||||
:param tls_server_config: optional TLS server configuration for WebSocket transport
|
||||
:return: return a host instance
|
||||
"""
|
||||
|
||||
@ -311,7 +350,9 @@ def new_host(
|
||||
peerstore_opt=peerstore_opt,
|
||||
muxer_preference=muxer_preference,
|
||||
listen_addrs=listen_addrs,
|
||||
connection_config=quic_transport_opt if enable_quic else None
|
||||
connection_config=quic_transport_opt if enable_quic else None,
|
||||
tls_client_config=tls_client_config,
|
||||
tls_server_config=tls_server_config
|
||||
)
|
||||
|
||||
if disc_opt is not None:
|
||||
|
||||
@ -481,13 +481,16 @@ class Swarm(Service, INetworkService):
|
||||
- Call listener listen with the multiaddr
|
||||
- Map multiaddr to listener
|
||||
"""
|
||||
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
logger.debug("Starting to listen")
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
for maddr in multiaddrs:
|
||||
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
|
||||
if str(maddr) in self.listeners:
|
||||
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
|
||||
success_count += 1
|
||||
continue
|
||||
|
||||
@ -545,13 +548,18 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
try:
|
||||
# Success
|
||||
logger.debug(f"Swarm.listen: creating listener for {maddr}")
|
||||
listener = self.transport.create_listener(conn_handler)
|
||||
logger.debug(f"Swarm.listen: listener created for {maddr}")
|
||||
self.listeners[str(maddr)] = listener
|
||||
# TODO: `listener.listen` is not bounded with nursery. If we want to be
|
||||
# I/O agnostic, we should change the API.
|
||||
if self.listener_nursery is None:
|
||||
raise SwarmException("swarm instance hasn't been run")
|
||||
assert self.listener_nursery is not None # For type checker
|
||||
logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
|
||||
await listener.listen(maddr, self.listener_nursery)
|
||||
logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
|
||||
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_listen(maddr)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import (
|
||||
cast,
|
||||
)
|
||||
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
|
||||
FixedSizeLenMsgReadWriter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIZE_NOISE_MESSAGE_LEN = 2
|
||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
||||
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
||||
self.noise_state = noise_state
|
||||
|
||||
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
||||
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
|
||||
data_encrypted = self.encrypt(msg)
|
||||
if prefix_encoded:
|
||||
# Manually add the prefix if needed
|
||||
data_encrypted = self.prefix + data_encrypted
|
||||
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
|
||||
await self.read_writer.write_msg(data_encrypted)
|
||||
logger.debug("Noise write_msg: write completed successfully")
|
||||
|
||||
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
||||
logger.debug("Noise read_msg: reading encrypted message")
|
||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
|
||||
if prefix_encoded:
|
||||
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
else:
|
||||
return self.decrypt(noise_msg_encrypted)
|
||||
result = self.decrypt(noise_msg_encrypted)
|
||||
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
|
||||
return result
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.read_writer.close()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
|
||||
|
||||
from .pb import noise_pb2 as noise_pb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||
|
||||
|
||||
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
|
||||
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||
) -> bytes:
|
||||
data = make_data_to_be_signed(noise_static_pubkey)
|
||||
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
|
||||
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
|
||||
return id_privkey.sign(data)
|
||||
|
||||
|
||||
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
|
||||
2. signed by the private key corresponding to `id_pubkey`
|
||||
"""
|
||||
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_pubkey type: "
|
||||
f"{type(payload.id_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: noise_static_pubkey type: "
|
||||
f"{type(noise_static_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
|
||||
)
|
||||
try:
|
||||
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
|
||||
return False
|
||||
|
||||
@ -2,6 +2,7 @@ from abc import (
|
||||
ABC,
|
||||
abstractmethod,
|
||||
)
|
||||
import logging
|
||||
|
||||
from cryptography.hazmat.primitives import (
|
||||
serialization,
|
||||
@ -46,6 +47,8 @@ from .messages import (
|
||||
verify_handshake_payload_sig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPattern(ABC):
|
||||
@abstractmethod
|
||||
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
|
||||
self.early_data = early_data
|
||||
|
||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
noise_state.set_as_responder()
|
||||
noise_state.start_handshake()
|
||||
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
|
||||
# Consume msg#1.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#1")
|
||||
await read_writer.read_msg()
|
||||
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
|
||||
|
||||
# Send msg#2, which should include our handshake payload.
|
||||
logger.debug("Noise XX handshake_inbound: preparing msg#2")
|
||||
our_payload = self.make_handshake_payload()
|
||||
msg_2 = our_payload.serialize()
|
||||
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
|
||||
await read_writer.write_msg(msg_2)
|
||||
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
|
||||
|
||||
# Receive and consume msg#3.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#3")
|
||||
msg_3 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
|
||||
async def handshake_outbound(
|
||||
self, conn: IRawConnection, remote_peer: ID
|
||||
) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
|
||||
raise NoiseStateError("Handshake state is not initialized")
|
||||
|
||||
# Send msg#1, which is *not* encrypted.
|
||||
logger.debug("Noise XX handshake_outbound: sending msg#1")
|
||||
msg_1 = b""
|
||||
await read_writer.write_msg(msg_1)
|
||||
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
|
||||
|
||||
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||
logger.debug("Noise XX handshake_outbound: reading msg#2")
|
||||
msg_2 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
|
||||
)
|
||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
|
||||
)
|
||||
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
|
||||
f"{id_pubkey_repr}"
|
||||
)
|
||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||
logger.error(
|
||||
f"Noise XX handshake_outbound: signature verification failed for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
raise InvalidSignature
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: signature verification successful for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||
if remote_peer_id_from_pubkey != remote_peer:
|
||||
raise PeerIDMismatchesPubkey(
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
|
||||
from .tcp.tcp import TCP
|
||||
from .websocket.transport import WebsocketTransport
|
||||
from .transport_registry import (
|
||||
TransportRegistry,
|
||||
create_transport_for_multiaddr,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from .upgrader import TransportUpgrader
|
||||
from libp2p.abc import ITransport
|
||||
|
||||
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
|
||||
"""
|
||||
Convenience function to create a transport instance.
|
||||
|
||||
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
|
||||
:param upgrader: Optional transport upgrader (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
|
||||
:return: Transport instance
|
||||
"""
|
||||
# First check if it's a built-in protocol
|
||||
if protocol in ["ws", "wss"]:
|
||||
if upgrader is None:
|
||||
raise ValueError(f"WebSocket transport requires an upgrader")
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
|
||||
)
|
||||
elif protocol == "tcp":
|
||||
return TCP()
|
||||
else:
|
||||
# Check if it's a custom registered transport
|
||||
registry = get_transport_registry()
|
||||
transport_class = registry.get_transport(protocol)
|
||||
if transport_class:
|
||||
transport = registry.create_transport(protocol, upgrader, **kwargs)
|
||||
if transport is None:
|
||||
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||
return transport
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||
|
||||
__all__ = [
|
||||
"TCP",
|
||||
"WebsocketTransport",
|
||||
"TransportRegistry",
|
||||
"create_transport_for_multiaddr",
|
||||
"create_transport",
|
||||
"get_transport_registry",
|
||||
"register_transport",
|
||||
"get_supported_transport_protocols",
|
||||
]
|
||||
|
||||
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from aioquic.quic import events
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
@ -871,9 +871,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Process events by type
|
||||
for event_type, event_list in events_by_type.items():
|
||||
if event_type == type(events.StreamDataReceived).__name__:
|
||||
await self._handle_stream_data_batch(
|
||||
cast(list[events.StreamDataReceived], event_list)
|
||||
)
|
||||
# Filter to only StreamDataReceived events
|
||||
stream_data_events = [
|
||||
e for e in event_list if isinstance(e, events.StreamDataReceived)
|
||||
]
|
||||
await self._handle_stream_data_batch(stream_data_events)
|
||||
else:
|
||||
# Process other events individually
|
||||
for event in event_list:
|
||||
|
||||
267
libp2p/transport/transport_registry.py
Normal file
267
libp2p/transport/transport_registry.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Transport registry for dynamic transport selection based on multiaddr protocols.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
from libp2p.abc import ITransport
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import (
|
||||
is_valid_websocket_multiaddr,
|
||||
)
|
||||
|
||||
|
||||
# Import QUIC utilities here to avoid circular imports
|
||||
def _get_quic_transport() -> Any:
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
return QUICTransport
|
||||
|
||||
|
||||
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
|
||||
return is_quic_multiaddr
|
||||
|
||||
|
||||
# Import WebsocketTransport here to avoid circular imports
|
||||
def _get_websocket_transport() -> Any:
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
return WebsocketTransport
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.registry")
|
||||
|
||||
|
||||
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid TCP structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid TCP structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||
# or /ip6/::1/tcp/8080
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||
if len(protocols) < 2:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Should not have any protocols after tcp (unless it's a valid
|
||||
# continuation like p2p)
|
||||
# For now, we'll be strict and only allow network + tcp
|
||||
if len(protocols) > 2:
|
||||
# Check if the additional protocols are valid continuations
|
||||
valid_continuations = ["p2p"] # Add more as needed
|
||||
for i in range(2, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class TransportRegistry:
|
||||
"""
|
||||
Registry for mapping multiaddr protocols to transport implementations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._transports: dict[str, type[ITransport]] = {}
|
||||
self._register_default_transports()
|
||||
|
||||
def _register_default_transports(self) -> None:
|
||||
"""Register the default transport implementations."""
|
||||
# Register TCP transport for /tcp protocol
|
||||
self.register_transport("tcp", TCP)
|
||||
|
||||
# Register WebSocket transport for /ws and /wss protocols
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
self.register_transport("ws", WebsocketTransport)
|
||||
self.register_transport("wss", WebsocketTransport)
|
||||
|
||||
# Register QUIC transport for /quic and /quic-v1 protocols
|
||||
QUICTransport = _get_quic_transport()
|
||||
self.register_transport("quic", QUICTransport)
|
||||
self.register_transport("quic-v1", QUICTransport)
|
||||
|
||||
def register_transport(
|
||||
self, protocol: str, transport_class: type[ITransport]
|
||||
) -> None:
|
||||
"""
|
||||
Register a transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier (e.g., "tcp", "ws")
|
||||
:param transport_class: The transport class to register
|
||||
"""
|
||||
self._transports[protocol] = transport_class
|
||||
logger.debug(
|
||||
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||
)
|
||||
|
||||
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||
"""
|
||||
Get the transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:return: The transport class or None if not found
|
||||
"""
|
||||
return self._transports.get(protocol)
|
||||
|
||||
def get_supported_protocols(self) -> list[str]:
|
||||
"""Get list of supported transport protocols."""
|
||||
return list(self._transports.keys())
|
||||
|
||||
def create_transport(
|
||||
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create a transport instance for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:param upgrader: The transport upgrader instance (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
:return: Transport instance or None if protocol not supported or creation fails
|
||||
"""
|
||||
transport_class = self.get_transport(protocol)
|
||||
if transport_class is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if protocol in ["ws", "wss"]:
|
||||
# WebSocket transport requires upgrader
|
||||
if upgrader is None:
|
||||
logger.warning(
|
||||
f"WebSocket transport '{protocol}' requires upgrader"
|
||||
)
|
||||
return None
|
||||
# Use explicit WebsocketTransport to avoid type issues
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
|
||||
)
|
||||
elif protocol in ["quic", "quic-v1"]:
|
||||
# QUIC transport requires private_key
|
||||
private_key = kwargs.get("private_key")
|
||||
if private_key is None:
|
||||
logger.warning(f"QUIC transport '{protocol}' requires private_key")
|
||||
return None
|
||||
# Use explicit QUICTransport to avoid type issues
|
||||
QUICTransport = _get_quic_transport()
|
||||
config = kwargs.get("config")
|
||||
return QUICTransport(private_key, config)
|
||||
else:
|
||||
# TCP transport doesn't require upgrader
|
||||
return transport_class()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global transport registry instance (lazy initialization)
|
||||
_global_registry: TransportRegistry | None = None
|
||||
|
||||
|
||||
def get_transport_registry() -> TransportRegistry:
|
||||
"""Get the global transport registry instance."""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = TransportRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||
"""Register a transport class in the global registry."""
|
||||
registry = get_transport_registry()
|
||||
registry.register_transport(protocol, transport_class)
|
||||
|
||||
|
||||
def create_transport_for_multiaddr(
|
||||
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create the appropriate transport for a given multiaddr.
|
||||
|
||||
:param maddr: The multiaddr to create transport for
|
||||
:param upgrader: The transport upgrader instance
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
(e.g., private_key for QUIC)
|
||||
:return: Transport instance or None if no suitable transport found
|
||||
"""
|
||||
try:
|
||||
# Get all protocols in the multiaddr
|
||||
protocols = [proto.name for proto in maddr.protocols()]
|
||||
|
||||
# Check for supported transport protocols in order of preference
|
||||
# We need to validate that the multiaddr structure is valid for our transports
|
||||
if "quic" in protocols or "quic-v1" in protocols:
|
||||
# For QUIC, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/udp/4001/quic
|
||||
# /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
is_quic_multiaddr = _get_quic_validation()
|
||||
if is_quic_multiaddr(maddr):
|
||||
# Determine QUIC version
|
||||
registry = get_transport_registry()
|
||||
if "quic-v1" in protocols:
|
||||
return registry.create_transport("quic-v1", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("quic", upgrader, **kwargs)
|
||||
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
|
||||
# For WebSocket, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
if is_valid_websocket_multiaddr(maddr):
|
||||
# Determine if this is a secure WebSocket connection
|
||||
registry = get_transport_registry()
|
||||
if "wss" in protocols or "tls" in protocols:
|
||||
return registry.create_transport("wss", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("ws", upgrader, **kwargs)
|
||||
elif "tcp" in protocols:
|
||||
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
|
||||
# Check if the multiaddr has proper TCP structure
|
||||
if _is_valid_tcp_multiaddr(maddr):
|
||||
registry = get_transport_registry()
|
||||
return registry.create_transport("tcp", upgrader)
|
||||
|
||||
# If no supported transport protocol found or structure is invalid, return None
|
||||
logger.warning(
|
||||
f"No supported transport protocol found or invalid structure in "
|
||||
f"multiaddr: {maddr}"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors gracefully (e.g., invalid multiaddr)
|
||||
logger.warning(f"Error processing multiaddr {maddr}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_supported_transport_protocols() -> list[str]:
|
||||
"""Get list of supported transport protocols from the global registry."""
|
||||
registry = get_transport_registry()
|
||||
return registry.get_supported_protocols()
|
||||
198
libp2p/transport/websocket/connection.py
Normal file
198
libp2p/transport/websocket/connection.py
Normal file
@ -0,0 +1,198 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
Wraps a WebSocketConnection to provide the raw stream interface
|
||||
that libp2p protocols expect.
|
||||
|
||||
Implements production-ready buffer management and flow control
|
||||
as recommended in the libp2p WebSocket specification.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ws_connection: Any,
|
||||
ws_context: Any = None,
|
||||
is_secure: bool = False,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
) -> None:
|
||||
self._ws_connection = ws_connection
|
||||
self._ws_context = ws_context
|
||||
self._is_secure = is_secure
|
||||
self._read_buffer = b""
|
||||
self._read_lock = trio.Lock()
|
||||
self._connection_start_time = time.time()
|
||||
self._bytes_read = 0
|
||||
self._bytes_written = 0
|
||||
self._closed = False
|
||||
self._close_lock = trio.Lock()
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._write_lock = trio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write data with flow control and buffer management"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._write_lock:
|
||||
try:
|
||||
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||
|
||||
# Check buffer amount for flow control
|
||||
if hasattr(self._ws_connection, "bufferedAmount"):
|
||||
buffered = self._ws_connection.bufferedAmount
|
||||
if buffered > self._max_buffered_amount:
|
||||
logger.warning(f"WebSocket buffer full: {buffered} bytes")
|
||||
# In production, you might want to
|
||||
# wait or implement backpressure
|
||||
# For now, we'll continue but log the warning
|
||||
|
||||
# Send as a binary WebSocket message
|
||||
await self._ws_connection.send_message(data)
|
||||
self._bytes_written += len(data)
|
||||
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket write failed: {e}")
|
||||
self._closed = True
|
||||
raise IOException from e
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||
This implementation provides byte-level access to WebSocket messages,
|
||||
which is required for libp2p protocol compatibility.
|
||||
|
||||
For WebSocket compatibility with libp2p protocols, this method:
|
||||
1. Buffers incoming WebSocket messages
|
||||
2. Returns exactly the requested number of bytes when n is specified
|
||||
3. Accumulates multiple WebSocket messages if needed to satisfy the request
|
||||
4. Returns empty bytes (not raises) when connection is closed and no data
|
||||
available
|
||||
"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._read_lock:
|
||||
try:
|
||||
# If n is None, read at least one message and return all buffered data
|
||||
if n is None:
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
# No message available within timeout
|
||||
return b""
|
||||
except Exception:
|
||||
# Return empty bytes if no data available
|
||||
# (connection closed)
|
||||
return b""
|
||||
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
# For specific byte count requests, return UP TO n bytes (not exactly n)
|
||||
# This matches TCP semantics where read(1024) returns available data
|
||||
# up to 1024 bytes
|
||||
|
||||
# If we don't have any data buffered, try to get at least one message
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
return b"" # No data available
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
# Now return up to n bytes from the buffer (TCP-like semantics)
|
||||
if len(self._read_buffer) == 0:
|
||||
return b""
|
||||
|
||||
# Return up to n bytes (like TCP read())
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[len(result) :]
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket read failed: {e}")
|
||||
raise IOException from e
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection. This method is idempotent."""
|
||||
async with self._close_lock:
|
||||
if self._closed:
|
||||
return # Already closed
|
||||
|
||||
logger.debug("WebSocket connection closing")
|
||||
self._closed = True
|
||||
try:
|
||||
# Always close the connection directly, avoid context manager issues
|
||||
# The context manager may be causing cancel scope corruption
|
||||
logger.debug("WebSocket closing connection directly")
|
||||
await self._ws_connection.aclose()
|
||||
# Exit the context manager if we have one
|
||||
if self._ws_context is not None:
|
||||
await self._ws_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket close error: {e}")
|
||||
# Don't raise here, as close() should be idempotent
|
||||
finally:
|
||||
logger.debug("WebSocket connection closed")
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the connection is closed"""
|
||||
return self._closed
|
||||
|
||||
def conn_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
Return connection state information similar to Go's ConnState() method.
|
||||
|
||||
:return: Dictionary containing connection state information
|
||||
"""
|
||||
current_time = time.time()
|
||||
return {
|
||||
"transport": "websocket",
|
||||
"secure": self._is_secure,
|
||||
"connection_duration": current_time - self._connection_start_time,
|
||||
"bytes_read": self._bytes_read,
|
||||
"bytes_written": self._bytes_written,
|
||||
"total_bytes": self._bytes_read + self._bytes_written,
|
||||
}
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
# Try to get remote address from the WebSocket connection
|
||||
try:
|
||||
remote = self._ws_connection.remote
|
||||
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||
return str(remote.address), int(remote.port)
|
||||
elif isinstance(remote, str):
|
||||
# Parse address:port format
|
||||
if ":" in remote:
|
||||
host, port = remote.rsplit(":", 1)
|
||||
return host, int(port)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
225
libp2p/transport/websocket/listener.py
Normal file
225
libp2p/transport/websocket/listener.py
Normal file
@ -0,0 +1,225 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
from trio_typing import TaskStatus
|
||||
from trio_websocket import serve_websocket
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.websocket.listener")
|
||||
|
||||
|
||||
class WebsocketListener(IListener):
|
||||
"""
|
||||
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: THandler,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
) -> None:
|
||||
self._handler = handler
|
||||
self._upgrader = upgrader
|
||||
self._tls_config = tls_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._server = None
|
||||
self._shutdown_event = trio.Event()
|
||||
self._nursery: trio.Nursery | None = None
|
||||
self._listeners: Any = None
|
||||
self._is_wss = False # Track whether this is a WSS listener
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Check if WSS is requested but no TLS config provided
|
||||
if parsed.is_wss and self._tls_config is None:
|
||||
raise ValueError(
|
||||
f"Cannot listen on WSS address {maddr} without TLS configuration"
|
||||
)
|
||||
|
||||
# Store whether this is a WSS listener
|
||||
self._is_wss = parsed.is_wss
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
or "0.0.0.0"
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
|
||||
)
|
||||
|
||||
async def serve_websocket_tcp(
|
||||
handler: Callable[[Any], Awaitable[None]],
|
||||
port: int,
|
||||
host: str,
|
||||
task_status: TaskStatus[Any],
|
||||
) -> None:
|
||||
"""Start TCP server and handle WebSocket connections manually"""
|
||||
logger.debug(
|
||||
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
|
||||
)
|
||||
|
||||
async def websocket_handler(request: Any) -> None:
|
||||
"""Handle WebSocket requests"""
|
||||
logger.debug("WebSocket request received")
|
||||
try:
|
||||
# Apply handshake timeout
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
# Accept the WebSocket connection
|
||||
ws_connection = await request.accept()
|
||||
logger.debug("WebSocket handshake successful")
|
||||
|
||||
# Create the WebSocket connection wrapper
|
||||
conn = P2PWebSocketConnection(
|
||||
ws_connection, is_secure=parsed.is_wss
|
||||
) # type: ignore[no-untyped-call]
|
||||
|
||||
# Call the handler function that was passed to create_listener
|
||||
# This handler will handle the security and muxing upgrades
|
||||
logger.debug("Calling connection handler")
|
||||
await self._handler(conn)
|
||||
|
||||
# Don't keep the connection alive indefinitely
|
||||
# Let the handler manage the connection lifecycle
|
||||
logger.debug(
|
||||
"Handler completed, connection will be managed by handler"
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s"
|
||||
)
|
||||
try:
|
||||
await request.reject(408) # Request Timeout
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket connection error: {e}")
|
||||
logger.debug(f"Error type: {type(e)}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# Reject the connection
|
||||
try:
|
||||
await request.reject(400)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||
ssl_context = self._tls_config if parsed.is_wss else None
|
||||
await serve_websocket(
|
||||
websocket_handler, host, port, ssl_context, task_status=task_status
|
||||
)
|
||||
|
||||
# Store the nursery for shutdown
|
||||
self._nursery = nursery
|
||||
|
||||
# Start the server using nursery.start() like TCP does
|
||||
logger.debug("Calling nursery.start()...")
|
||||
started_listeners = await nursery.start(
|
||||
serve_websocket_tcp,
|
||||
None, # No handler needed since it's defined inside serve_websocket_tcp
|
||||
port,
|
||||
host,
|
||||
)
|
||||
logger.debug(f"nursery.start() returned: {started_listeners}")
|
||||
|
||||
if started_listeners is None:
|
||||
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||
return False
|
||||
|
||||
# Store the listeners for get_addrs() and close() - these are real
|
||||
# SocketListener objects
|
||||
self._listeners = started_listeners
|
||||
logger.debug(
|
||||
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||
)
|
||||
return True
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
if not hasattr(self, "_listeners") or not self._listeners:
|
||||
logger.debug("No listeners available for get_addrs()")
|
||||
return ()
|
||||
|
||||
# Handle WebSocketServer objects
|
||||
if hasattr(self._listeners, "port"):
|
||||
# This is a WebSocketServer object
|
||||
port = self._listeners.port
|
||||
# Create a multiaddr from the port with correct WSS/WS protocol
|
||||
protocol = "wss" if self._is_wss else "ws"
|
||||
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
|
||||
else:
|
||||
# This is a list of listeners (like TCP)
|
||||
listeners = self._listeners
|
||||
# Get addresses from listeners like TCP does
|
||||
return tuple(
|
||||
_multiaddr_from_socket(listener.socket, self._is_wss)
|
||||
for listener in listeners
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket listener and stop accepting new connections"""
|
||||
logger.debug("WebsocketListener.close called")
|
||||
if hasattr(self, "_listeners") and self._listeners:
|
||||
# Signal shutdown
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Close the WebSocket server
|
||||
if hasattr(self._listeners, "aclose"):
|
||||
# This is a WebSocketServer object
|
||||
logger.debug("Closing WebSocket server")
|
||||
await self._listeners.aclose()
|
||||
logger.debug("WebSocket server closed")
|
||||
elif isinstance(self._listeners, (list, tuple)):
|
||||
# This is a list of listeners (like TCP)
|
||||
logger.debug("Closing TCP listeners")
|
||||
for listener in self._listeners:
|
||||
await listener.aclose()
|
||||
logger.debug("TCP listeners closed")
|
||||
else:
|
||||
# Unknown type, try to close it directly
|
||||
logger.debug("Closing unknown listener type")
|
||||
if hasattr(self._listeners, "close"):
|
||||
self._listeners.close()
|
||||
logger.debug("Unknown listener closed")
|
||||
|
||||
# Clear the listeners reference
|
||||
self._listeners = None
|
||||
logger.debug("WebsocketListener.close completed")
|
||||
|
||||
|
||||
def _multiaddr_from_socket(
|
||||
socket: trio.socket.SocketType, is_wss: bool = False
|
||||
) -> Multiaddr:
|
||||
"""Convert socket to multiaddr"""
|
||||
ip, port = socket.getsockname()
|
||||
protocol = "wss" if is_wss else "ws"
|
||||
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")
|
||||
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
202
libp2p/transport/websocket/multiaddr_utils.py
Normal file
@ -0,0 +1,202 @@
|
||||
"""
|
||||
WebSocket multiaddr parsing utilities.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
|
||||
class ParsedWebSocketMultiaddr(NamedTuple):
|
||||
"""Parsed WebSocket multiaddr information."""
|
||||
|
||||
is_wss: bool
|
||||
sni: str | None
|
||||
rest_multiaddr: Multiaddr
|
||||
|
||||
|
||||
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
|
||||
"""
|
||||
Parse a WebSocket multiaddr and extract security information.
|
||||
|
||||
:param maddr: The multiaddr to parse
|
||||
:return: Parsed WebSocket multiaddr information
|
||||
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
|
||||
"""
|
||||
# First validate that this is a valid WebSocket multiaddr
|
||||
if not is_valid_websocket_multiaddr(maddr):
|
||||
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
|
||||
|
||||
protocols = list(maddr.protocols())
|
||||
|
||||
# Find the WebSocket protocol and check for security
|
||||
is_wss = False
|
||||
sni = None
|
||||
ws_index = -1
|
||||
tls_index = -1
|
||||
sni_index = -1
|
||||
|
||||
# Find protocol indices
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name == "ws":
|
||||
ws_index = i
|
||||
elif protocol.name == "wss":
|
||||
ws_index = i
|
||||
is_wss = True
|
||||
elif protocol.name == "tls":
|
||||
tls_index = i
|
||||
elif protocol.name == "sni":
|
||||
sni_index = i
|
||||
sni = protocol.value
|
||||
|
||||
if ws_index == -1:
|
||||
raise ValueError("Not a WebSocket multiaddr")
|
||||
|
||||
# Handle /wss protocol (convert to /tls/ws internally)
|
||||
if is_wss and tls_index == -1:
|
||||
# Convert /wss to /tls/ws format
|
||||
# Remove /wss to get the base multiaddr
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=True, sni=None, rest_multiaddr=without_wss
|
||||
)
|
||||
|
||||
# Handle /tls/ws and /tls/sni/.../ws formats
|
||||
if tls_index != -1:
|
||||
is_wss = True
|
||||
# Extract the base multiaddr (everything before /tls)
|
||||
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
|
||||
# Use multiaddr methods to properly extract the base
|
||||
rest_multiaddr = maddr
|
||||
# Remove /tls/ws or /tls/sni/.../ws from the end
|
||||
if sni_index != -1:
|
||||
# /tls/sni/example.com/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
else:
|
||||
# /tls/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
# Regular /ws multiaddr - remove /ws and any additional protocols
|
||||
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
|
||||
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid WebSocket structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid WebSocket structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# WebSocket multiaddr should have structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
|
||||
if len(protocols) < 3:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Check for valid WebSocket protocols
|
||||
ws_protocols = ["ws", "wss"]
|
||||
tls_protocols = ["tls"]
|
||||
sni_protocols = ["sni"]
|
||||
|
||||
# Find the WebSocket protocol
|
||||
ws_protocol_found = False
|
||||
tls_found = False
|
||||
# sni_found = False # Not used currently
|
||||
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name in ws_protocols:
|
||||
ws_protocol_found = True
|
||||
break
|
||||
elif protocol.name in tls_protocols:
|
||||
tls_found = True
|
||||
elif protocol.name in sni_protocols:
|
||||
pass # sni_found = True # Not used in current implementation
|
||||
|
||||
if not ws_protocol_found:
|
||||
return False
|
||||
|
||||
# Validate protocol sequence
|
||||
# For /ws: network + tcp + ws
|
||||
# For /wss: network + tcp + wss
|
||||
# For /tls/ws: network + tcp + tls + ws
|
||||
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
|
||||
|
||||
# Check if it's a simple /ws or /wss
|
||||
if len(protocols) == 3:
|
||||
return protocols[2].name in ["ws", "wss"]
|
||||
|
||||
# Check for /tls/ws or /tls/sni/.../ws patterns
|
||||
if tls_found:
|
||||
# Must end with /ws (not /wss when using /tls)
|
||||
if protocols[-1].name != "ws":
|
||||
return False
|
||||
|
||||
# Check for valid TLS sequence
|
||||
tls_index = None
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name == "tls":
|
||||
tls_index = i
|
||||
break
|
||||
|
||||
if tls_index is None:
|
||||
return False
|
||||
|
||||
# After tls, we can have sni, then ws
|
||||
remaining_protocols = protocols[tls_index + 1 :]
|
||||
if len(remaining_protocols) == 1:
|
||||
# /tls/ws
|
||||
return remaining_protocols[0].name == "ws"
|
||||
elif len(remaining_protocols) == 2:
|
||||
# /tls/sni/example.com/ws
|
||||
return (
|
||||
remaining_protocols[0].name == "sni"
|
||||
and remaining_protocols[1].name == "ws"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
# If we have more than 3 protocols but no TLS, check for valid continuations
|
||||
# Allow additional protocols after the WebSocket protocol (like /p2p)
|
||||
valid_continuations = ["p2p"]
|
||||
|
||||
# Find the WebSocket protocol index
|
||||
ws_index = None
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name in ["ws", "wss"]:
|
||||
ws_index = i
|
||||
break
|
||||
|
||||
if ws_index is not None:
|
||||
# Check protocols after the WebSocket protocol
|
||||
for i in range(ws_index + 1, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
229
libp2p/transport/websocket/transport.py
Normal file
229
libp2p/transport/websocket/transport.py
Normal file
@ -0,0 +1,229 @@
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.abc import IListener, ITransport
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.transport.exceptions import OpenConnectionError
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
from .listener import WebsocketListener
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketTransport(ITransport):
|
||||
"""
|
||||
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss
|
||||
|
||||
Implements production-ready WebSocket transport with:
|
||||
- Flow control and buffer management
|
||||
- Connection limits and rate limiting
|
||||
- Proper error handling and cleanup
|
||||
- Support for both WS and WSS protocols
|
||||
- TLS configuration and handshake timeout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
):
|
||||
self._upgrader = upgrader
|
||||
self._tls_client_config = tls_client_config
|
||||
self._tls_server_config = tls_server_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._connection_count = 0
|
||||
self._max_connections = 1000 # Production limit
|
||||
|
||||
async def dial(self, maddr: Multiaddr) -> RawConnection:
|
||||
"""Dial a WebSocket connection to the given multiaddr."""
|
||||
logger.debug(f"WebsocketTransport.dial called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
# Build WebSocket URL based on security
|
||||
if parsed.is_wss:
|
||||
ws_url = f"wss://{host}:{port}/"
|
||||
else:
|
||||
ws_url = f"ws://{host}:{port}/"
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
|
||||
)
|
||||
|
||||
try:
|
||||
# Check connection limits
|
||||
if self._connection_count >= self._max_connections:
|
||||
raise OpenConnectionError(
|
||||
f"Maximum connections reached: {self._max_connections}"
|
||||
)
|
||||
|
||||
# Prepare SSL context for WSS connections
|
||||
ssl_context = None
|
||||
if parsed.is_wss:
|
||||
if self._tls_client_config:
|
||||
ssl_context = self._tls_client_config
|
||||
else:
|
||||
# Create default SSL context for client
|
||||
ssl_context = ssl.create_default_context()
|
||||
# Set SNI if available
|
||||
if parsed.sni:
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")
|
||||
|
||||
# Use a different approach: start background nursery that will persist
|
||||
logger.debug("WebsocketTransport.dial establishing connection")
|
||||
|
||||
# Import trio-websocket functions
|
||||
from trio_websocket import connect_websocket
|
||||
from trio_websocket._impl import _url_to_host
|
||||
|
||||
# Parse the WebSocket URL to get host, port, resource
|
||||
# like trio-websocket does
|
||||
ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
|
||||
ws_url, ssl_context
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketTransport.dial parsed URL: host={ws_host}, "
|
||||
f"port={ws_port}, resource={ws_resource}"
|
||||
)
|
||||
|
||||
# Create a background task manager for this connection
|
||||
import trio
|
||||
|
||||
nursery_manager = trio.lowlevel.current_task().parent_nursery
|
||||
if nursery_manager is None:
|
||||
raise OpenConnectionError(
|
||||
f"No parent nursery available for WebSocket connection to {maddr}"
|
||||
)
|
||||
|
||||
# Apply timeout to the connection process
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
logger.debug("WebsocketTransport.dial connecting WebSocket")
|
||||
ws = await connect_websocket(
|
||||
nursery_manager, # Use the existing nursery from libp2p
|
||||
ws_host,
|
||||
ws_port,
|
||||
ws_resource,
|
||||
use_ssl=ws_ssl_context,
|
||||
message_queue_size=1024, # Reasonable defaults
|
||||
max_message_size=16 * 1024 * 1024, # 16MB max message
|
||||
)
|
||||
logger.debug("WebsocketTransport.dial WebSocket connection established")
|
||||
|
||||
# Create our connection wrapper with both WSS support and flow control
|
||||
conn = P2PWebSocketConnection(
|
||||
ws,
|
||||
None,
|
||||
is_secure=parsed.is_wss,
|
||||
max_buffered_amount=self._max_buffered_amount,
|
||||
)
|
||||
logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")
|
||||
|
||||
self._connection_count += 1
|
||||
logger.debug(f"Total connections: {self._connection_count}")
|
||||
|
||||
return RawConnection(conn, initiator=True)
|
||||
except trio.TooSlowError as e:
|
||||
raise OpenConnectionError(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s "
|
||||
f"for {maddr}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to dial WebSocket {maddr}: {e}")
|
||||
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
|
||||
|
||||
def create_listener(self, handler: THandler) -> IListener: # type: ignore[override]
|
||||
"""
|
||||
The type checker is incorrectly reporting this as an inconsistent override.
|
||||
"""
|
||||
logger.debug("WebsocketTransport.create_listener called")
|
||||
return WebsocketListener(
|
||||
handler, self._upgrader, self._tls_server_config, self._handshake_timeout
|
||||
)
|
||||
|
||||
def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
|
||||
"""
|
||||
Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
|
||||
Similar to Go's Resolve() method.
|
||||
|
||||
:param maddr: The multiaddr to resolve
|
||||
:return: List of resolved multiaddrs
|
||||
"""
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
|
||||
return [maddr] # Return original if not a valid WebSocket multiaddr
|
||||
|
||||
logger.debug(
|
||||
f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
|
||||
)
|
||||
|
||||
if not parsed.is_wss:
|
||||
# No /tls/ws component, this isn't a secure websocket multiaddr
|
||||
return [maddr]
|
||||
|
||||
if parsed.sni is not None:
|
||||
# Already has SNI, return as-is
|
||||
return [maddr]
|
||||
|
||||
# Try to extract DNS name from the base multiaddr
|
||||
dns_name = None
|
||||
for protocol_name in ["dns", "dns4", "dns6"]:
|
||||
try:
|
||||
dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if dns_name is None:
|
||||
# No DNS name found, return original
|
||||
return [maddr]
|
||||
|
||||
# Create new multiaddr with SNI
|
||||
# For /dns/example.com/tcp/8080/wss ->
|
||||
# /dns/example.com/tcp/8080/tls/sni/example.com/ws
|
||||
try:
|
||||
# Remove /wss and add /tls/sni/example.com/ws
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
sni_component = Multiaddr(f"/sni/{dns_name}")
|
||||
resolved = (
|
||||
without_wss.encapsulate(Multiaddr("/tls"))
|
||||
.encapsulate(sni_component)
|
||||
.encapsulate(Multiaddr("/ws"))
|
||||
)
|
||||
logger.debug(f"Resolved {maddr} to {resolved}")
|
||||
return [resolved]
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
|
||||
return [maddr]
|
||||
12
newsfragments/585.feature.rst
Normal file
12
newsfragments/585.feature.rst
Normal file
@ -0,0 +1,12 @@
|
||||
Added experimental WebSocket transport support with basic WS and WSS functionality. This includes:
|
||||
|
||||
- WebSocket transport implementation with trio-websocket backend
|
||||
- Support for both WS (WebSocket) and WSS (WebSocket Secure) protocols
|
||||
- Basic connection management and stream handling
|
||||
- TLS configuration support for WSS connections
|
||||
- Multiaddr parsing for WebSocket addresses
|
||||
- Integration with libp2p host and peer discovery
|
||||
|
||||
**Note**: This is experimental functionality. Advanced features like proxy support,
|
||||
interop testing, and production examples are still in development. See
|
||||
https://github.com/libp2p/py-libp2p/discussions/937 for the complete roadmap of missing features.
|
||||
@ -33,6 +33,8 @@ dependencies = [
|
||||
"rpcudp>=3.0.0",
|
||||
"trio-typing>=0.0.4",
|
||||
"trio>=0.26.0",
|
||||
"fastecdsa==2.3.2; sys_platform != 'win32'",
|
||||
"trio-websocket>=0.11.0",
|
||||
"zeroconf (>=0.147.0,<0.148.0)",
|
||||
]
|
||||
classifiers = [
|
||||
|
||||
324
tests/core/transport/test_transport_registry.py
Normal file
324
tests/core/transport/test_transport_registry.py
Normal file
@ -0,0 +1,324 @@
|
||||
"""
|
||||
Tests for the transport registry functionality.
|
||||
"""
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.abc import IListener, IRawConnection, ITransport
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.transport_registry import (
|
||||
TransportRegistry,
|
||||
create_transport_for_multiaddr,
|
||||
get_supported_transport_protocols,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
)
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
|
||||
class TestTransportRegistry:
|
||||
"""Test the TransportRegistry class."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test registry initialization."""
|
||||
registry = TransportRegistry()
|
||||
assert isinstance(registry, TransportRegistry)
|
||||
|
||||
# Check that default transports are registered
|
||||
supported = registry.get_supported_protocols()
|
||||
assert "tcp" in supported
|
||||
assert "ws" in supported
|
||||
|
||||
def test_register_transport(self):
|
||||
"""Test transport registration."""
|
||||
registry = TransportRegistry()
|
||||
|
||||
# Register a custom transport
|
||||
class CustomTransport(ITransport):
|
||||
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||
raise NotImplementedError("CustomTransport dial not implemented")
|
||||
|
||||
def create_listener(self, handler_function: THandler) -> IListener:
|
||||
raise NotImplementedError(
|
||||
"CustomTransport create_listener not implemented"
|
||||
)
|
||||
|
||||
registry.register_transport("custom", CustomTransport)
|
||||
assert registry.get_transport("custom") == CustomTransport
|
||||
|
||||
def test_get_transport(self):
|
||||
"""Test getting registered transports."""
|
||||
registry = TransportRegistry()
|
||||
|
||||
# Test existing transports
|
||||
assert registry.get_transport("tcp") == TCP
|
||||
assert registry.get_transport("ws") == WebsocketTransport
|
||||
|
||||
# Test non-existent transport
|
||||
assert registry.get_transport("nonexistent") is None
|
||||
|
||||
def test_get_supported_protocols(self):
|
||||
"""Test getting supported protocols."""
|
||||
registry = TransportRegistry()
|
||||
protocols = registry.get_supported_protocols()
|
||||
|
||||
assert isinstance(protocols, list)
|
||||
assert "tcp" in protocols
|
||||
assert "ws" in protocols
|
||||
|
||||
def test_create_transport_tcp(self):
|
||||
"""Test creating TCP transport."""
|
||||
registry = TransportRegistry()
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
transport = registry.create_transport("tcp", upgrader)
|
||||
assert isinstance(transport, TCP)
|
||||
|
||||
def test_create_transport_websocket(self):
|
||||
"""Test creating WebSocket transport."""
|
||||
registry = TransportRegistry()
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
transport = registry.create_transport("ws", upgrader)
|
||||
assert isinstance(transport, WebsocketTransport)
|
||||
|
||||
def test_create_transport_invalid_protocol(self):
|
||||
"""Test creating transport with invalid protocol."""
|
||||
registry = TransportRegistry()
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
transport = registry.create_transport("invalid", upgrader)
|
||||
assert transport is None
|
||||
|
||||
def test_create_transport_websocket_no_upgrader(self):
|
||||
"""Test that WebSocket transport requires upgrader."""
|
||||
registry = TransportRegistry()
|
||||
|
||||
# This should fail gracefully and return None
|
||||
transport = registry.create_transport("ws", None)
|
||||
assert transport is None
|
||||
|
||||
|
||||
class TestGlobalRegistry:
|
||||
"""Test the global registry functions."""
|
||||
|
||||
def test_get_transport_registry(self):
|
||||
"""Test getting the global registry."""
|
||||
registry = get_transport_registry()
|
||||
assert isinstance(registry, TransportRegistry)
|
||||
|
||||
def test_register_transport_global(self):
|
||||
"""Test registering transport globally."""
|
||||
|
||||
class GlobalCustomTransport(ITransport):
|
||||
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||
raise NotImplementedError("GlobalCustomTransport dial not implemented")
|
||||
|
||||
def create_listener(self, handler_function: THandler) -> IListener:
|
||||
raise NotImplementedError(
|
||||
"GlobalCustomTransport create_listener not implemented"
|
||||
)
|
||||
|
||||
# Register globally
|
||||
register_transport("global_custom", GlobalCustomTransport)
|
||||
|
||||
# Check that it's available
|
||||
registry = get_transport_registry()
|
||||
assert registry.get_transport("global_custom") == GlobalCustomTransport
|
||||
|
||||
def test_get_supported_transport_protocols_global(self):
|
||||
"""Test getting supported protocols from global registry."""
|
||||
protocols = get_supported_transport_protocols()
|
||||
assert isinstance(protocols, list)
|
||||
assert "tcp" in protocols
|
||||
assert "ws" in protocols
|
||||
|
||||
|
||||
class TestTransportFactory:
|
||||
"""Test the transport factory functions."""
|
||||
|
||||
def test_create_transport_for_multiaddr_tcp(self):
|
||||
"""Test creating transport for TCP multiaddr."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# TCP multiaddr
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is not None
|
||||
assert isinstance(transport, TCP)
|
||||
|
||||
def test_create_transport_for_multiaddr_websocket(self):
|
||||
"""Test creating transport for WebSocket multiaddr."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# WebSocket multiaddr
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is not None
|
||||
assert isinstance(transport, WebsocketTransport)
|
||||
|
||||
def test_create_transport_for_multiaddr_websocket_secure(self):
|
||||
"""Test creating transport for WebSocket multiaddr."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# WebSocket multiaddr
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is not None
|
||||
assert isinstance(transport, WebsocketTransport)
|
||||
|
||||
def test_create_transport_for_multiaddr_ipv6(self):
|
||||
"""Test creating transport for IPv6 multiaddr."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# IPv6 WebSocket multiaddr
|
||||
maddr = Multiaddr("/ip6/::1/tcp/8080/ws")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is not None
|
||||
assert isinstance(transport, WebsocketTransport)
|
||||
|
||||
def test_create_transport_for_multiaddr_dns(self):
|
||||
"""Test creating transport for DNS multiaddr."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# DNS WebSocket multiaddr
|
||||
maddr = Multiaddr("/dns4/example.com/tcp/443/ws")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is not None
|
||||
assert isinstance(transport, WebsocketTransport)
|
||||
|
||||
def test_create_transport_for_multiaddr_unknown(self):
|
||||
"""Test creating transport for unknown multiaddr."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Unknown multiaddr
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/udp/8080")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is None
|
||||
|
||||
def test_create_transport_for_multiaddr_with_upgrader(self):
|
||||
"""Test creating transport with upgrader."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# This should work for both TCP and WebSocket with upgrader
|
||||
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
|
||||
transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader)
|
||||
assert transport_tcp is not None
|
||||
|
||||
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||
transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader)
|
||||
assert transport_ws is not None
|
||||
|
||||
|
||||
class TestTransportInterfaceCompliance:
|
||||
"""Test that all transports implement the required interface."""
|
||||
|
||||
def test_tcp_implements_itransport(self):
|
||||
"""Test that TCP transport implements ITransport."""
|
||||
transport = TCP()
|
||||
assert isinstance(transport, ITransport)
|
||||
assert hasattr(transport, "dial")
|
||||
assert hasattr(transport, "create_listener")
|
||||
assert callable(transport.dial)
|
||||
assert callable(transport.create_listener)
|
||||
|
||||
def test_websocket_implements_itransport(self):
|
||||
"""Test that WebSocket transport implements ITransport."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
transport = WebsocketTransport(upgrader)
|
||||
assert isinstance(transport, ITransport)
|
||||
assert hasattr(transport, "dial")
|
||||
assert hasattr(transport, "create_listener")
|
||||
assert callable(transport.dial)
|
||||
assert callable(transport.create_listener)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in the transport registry."""
|
||||
|
||||
def test_create_transport_with_exception(self):
|
||||
"""Test handling of transport creation exceptions."""
|
||||
registry = TransportRegistry()
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Register a transport that raises an exception
|
||||
class ExceptionTransport(ITransport):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise RuntimeError("Transport creation failed")
|
||||
|
||||
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||
raise NotImplementedError("ExceptionTransport dial not implemented")
|
||||
|
||||
def create_listener(self, handler_function: THandler) -> IListener:
|
||||
raise NotImplementedError(
|
||||
"ExceptionTransport create_listener not implemented"
|
||||
)
|
||||
|
||||
registry.register_transport("exception", ExceptionTransport)
|
||||
|
||||
# Should handle exception gracefully and return None
|
||||
transport = registry.create_transport("exception", upgrader)
|
||||
assert transport is None
|
||||
|
||||
def test_invalid_multiaddr_handling(self):
|
||||
"""Test handling of invalid multiaddrs."""
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Test with a multiaddr that has an unsupported transport protocol
|
||||
# This should be handled gracefully by our transport registry
|
||||
# udp is not a supported transport
|
||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234")
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
assert transport is None
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Test integration scenarios."""
|
||||
|
||||
def test_multiple_transport_types(self):
|
||||
"""Test using multiple transport types in the same registry."""
|
||||
registry = TransportRegistry()
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Create different transport types
|
||||
tcp_transport = registry.create_transport("tcp", upgrader)
|
||||
ws_transport = registry.create_transport("ws", upgrader)
|
||||
|
||||
# All should be different types
|
||||
assert isinstance(tcp_transport, TCP)
|
||||
assert isinstance(ws_transport, WebsocketTransport)
|
||||
|
||||
# All should be different instances
|
||||
assert tcp_transport is not ws_transport
|
||||
|
||||
def test_transport_registry_persistence(self):
|
||||
"""Test that transport registry persists across calls."""
|
||||
registry1 = get_transport_registry()
|
||||
registry2 = get_transport_registry()
|
||||
|
||||
# Should be the same instance
|
||||
assert registry1 is registry2
|
||||
|
||||
# Register a transport in one
|
||||
class PersistentTransport(ITransport):
|
||||
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||
raise NotImplementedError("PersistentTransport dial not implemented")
|
||||
|
||||
def create_listener(self, handler_function: THandler) -> IListener:
|
||||
raise NotImplementedError(
|
||||
"PersistentTransport create_listener not implemented"
|
||||
)
|
||||
|
||||
registry1.register_transport("persistent", PersistentTransport)
|
||||
|
||||
# Should be available in the other
|
||||
assert registry2.get_transport("persistent") == PersistentTransport
|
||||
1631
tests/core/transport/test_websocket.py
Normal file
1631
tests/core/transport/test_websocket.py
Normal file
File diff suppressed because it is too large
Load Diff
532
tests/core/transport/test_websocket_p2p.py
Normal file
532
tests/core/transport/test_websocket_p2p.py
Normal file
@ -0,0 +1,532 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Python-to-Python WebSocket peer-to-peer tests.
|
||||
|
||||
This module tests real WebSocket communication between two Python libp2p hosts,
|
||||
including both WS and WSS (WebSocket Secure) scenarios.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p import create_yamux_muxer_option, new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
)
|
||||
from libp2p.transport.websocket.multiaddr_utils import (
|
||||
is_valid_websocket_multiaddr,
|
||||
parse_websocket_multiaddr,
|
||||
)
|
||||
|
||||
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
||||
PING_LENGTH = 32
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_websocket_p2p_plaintext():
|
||||
"""Test Python-to-Python WebSocket communication with plaintext security."""
|
||||
# Create two hosts with plaintext security
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
|
||||
# Host A (listener) - use only plaintext security
|
||||
security_options_a = {
|
||||
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
|
||||
local_key_pair=key_pair_a, secure_bytes_provider=None, peerstore=None
|
||||
)
|
||||
}
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options_a,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
|
||||
)
|
||||
|
||||
# Host B (dialer) - use only plaintext security
|
||||
security_options_b = {
|
||||
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
|
||||
local_key_pair=key_pair_b, secure_bytes_provider=None, peerstore=None
|
||||
)
|
||||
}
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options_b,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
|
||||
# transport
|
||||
)
|
||||
|
||||
# Test data
|
||||
test_data = b"Hello WebSocket P2P!"
|
||||
received_data = None
|
||||
|
||||
# Set up ping handler on host A
|
||||
async def ping_handler(stream):
|
||||
nonlocal received_data
|
||||
received_data = await stream.read(len(test_data))
|
||||
await stream.write(received_data) # Echo back
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler)
|
||||
|
||||
# Start both hosts
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert len(listen_addrs) > 0
|
||||
|
||||
# Find the WebSocket address
|
||||
ws_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/ws" in str(addr):
|
||||
ws_addr = addr
|
||||
break
|
||||
|
||||
assert ws_addr is not None, "No WebSocket listen address found"
|
||||
assert is_valid_websocket_multiaddr(ws_addr), "Invalid WebSocket multiaddr"
|
||||
|
||||
# Parse the WebSocket multiaddr
|
||||
parsed = parse_websocket_multiaddr(ws_addr)
|
||||
assert not parsed.is_wss, "Should be plain WebSocket, not WSS"
|
||||
assert parsed.sni is None, "SNI should be None for plain WebSocket"
|
||||
|
||||
# Connect host B to host A
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
peer_info = info_from_p2p_addr(ws_addr)
|
||||
await host_b.connect(peer_info)
|
||||
|
||||
# Create stream and test communication
|
||||
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
|
||||
await stream.write(test_data)
|
||||
response = await stream.read(len(test_data))
|
||||
await stream.close()
|
||||
|
||||
# Verify communication
|
||||
assert received_data == test_data, f"Expected {test_data}, got {received_data}"
|
||||
assert response == test_data, f"Expected echo {test_data}, got {response}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_websocket_p2p_noise():
|
||||
"""Test Python-to-Python WebSocket communication with Noise security."""
|
||||
# Create two hosts with Noise security
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
noise_key_pair_a = create_new_x25519_key_pair()
|
||||
noise_key_pair_b = create_new_x25519_key_pair()
|
||||
|
||||
# Host A (listener)
|
||||
security_options_a = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_a,
|
||||
noise_privkey=noise_key_pair_a.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options_a,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
|
||||
)
|
||||
|
||||
# Host B (dialer)
|
||||
security_options_b = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_b,
|
||||
noise_privkey=noise_key_pair_b.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options_b,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
|
||||
# transport
|
||||
)
|
||||
|
||||
# Test data
|
||||
test_data = b"Hello WebSocket P2P with Noise!"
|
||||
received_data = None
|
||||
|
||||
# Set up ping handler on host A
|
||||
async def ping_handler(stream):
|
||||
nonlocal received_data
|
||||
received_data = await stream.read(len(test_data))
|
||||
await stream.write(received_data) # Echo back
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler)
|
||||
|
||||
# Start both hosts
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert len(listen_addrs) > 0
|
||||
|
||||
# Find the WebSocket address
|
||||
ws_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/ws" in str(addr):
|
||||
ws_addr = addr
|
||||
break
|
||||
|
||||
assert ws_addr is not None, "No WebSocket listen address found"
|
||||
assert is_valid_websocket_multiaddr(ws_addr), "Invalid WebSocket multiaddr"
|
||||
|
||||
# Parse the WebSocket multiaddr
|
||||
parsed = parse_websocket_multiaddr(ws_addr)
|
||||
assert not parsed.is_wss, "Should be plain WebSocket, not WSS"
|
||||
assert parsed.sni is None, "SNI should be None for plain WebSocket"
|
||||
|
||||
# Connect host B to host A
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
peer_info = info_from_p2p_addr(ws_addr)
|
||||
await host_b.connect(peer_info)
|
||||
|
||||
# Create stream and test communication
|
||||
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
|
||||
await stream.write(test_data)
|
||||
response = await stream.read(len(test_data))
|
||||
await stream.close()
|
||||
|
||||
# Verify communication
|
||||
assert received_data == test_data, f"Expected {test_data}, got {received_data}"
|
||||
assert response == test_data, f"Expected echo {test_data}, got {response}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_websocket_p2p_libp2p_ping():
|
||||
"""Test Python-to-Python WebSocket communication using libp2p ping protocol."""
|
||||
# Create two hosts with Noise security
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
noise_key_pair_a = create_new_x25519_key_pair()
|
||||
noise_key_pair_b = create_new_x25519_key_pair()
|
||||
|
||||
# Host A (listener)
|
||||
security_options_a = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_a,
|
||||
noise_privkey=noise_key_pair_a.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options_a,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
|
||||
)
|
||||
|
||||
# Host B (dialer)
|
||||
security_options_b = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_b,
|
||||
noise_privkey=noise_key_pair_b.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options_b,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
|
||||
# transport
|
||||
)
|
||||
|
||||
# Set up ping handler on host A (standard libp2p ping protocol)
|
||||
async def ping_handler(stream):
|
||||
# Read ping data (32 bytes)
|
||||
ping_data = await stream.read(PING_LENGTH)
|
||||
# Echo back the same data (pong)
|
||||
await stream.write(ping_data)
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(PING_PROTOCOL_ID, ping_handler)
|
||||
|
||||
# Start both hosts
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert len(listen_addrs) > 0
|
||||
|
||||
# Find the WebSocket address
|
||||
ws_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/ws" in str(addr):
|
||||
ws_addr = addr
|
||||
break
|
||||
|
||||
assert ws_addr is not None, "No WebSocket listen address found"
|
||||
|
||||
# Connect host B to host A
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
peer_info = info_from_p2p_addr(ws_addr)
|
||||
await host_b.connect(peer_info)
|
||||
|
||||
# Create stream and test libp2p ping protocol
|
||||
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
|
||||
|
||||
# Send ping (32 bytes as per libp2p ping protocol)
|
||||
ping_data = b"\x01" * PING_LENGTH
|
||||
await stream.write(ping_data)
|
||||
|
||||
# Receive pong (should be same 32 bytes)
|
||||
pong_data = await stream.read(PING_LENGTH)
|
||||
await stream.close()
|
||||
|
||||
# Verify ping-pong
|
||||
assert pong_data == ping_data, (
|
||||
f"Expected ping {ping_data}, got pong {pong_data}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_websocket_p2p_multiple_streams():
|
||||
"""
|
||||
Test Python-to-Python WebSocket communication with multiple concurrent
|
||||
streams.
|
||||
"""
|
||||
# Create two hosts with Noise security
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
noise_key_pair_a = create_new_x25519_key_pair()
|
||||
noise_key_pair_b = create_new_x25519_key_pair()
|
||||
|
||||
# Host A (listener)
|
||||
security_options_a = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_a,
|
||||
noise_privkey=noise_key_pair_a.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options_a,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
|
||||
)
|
||||
|
||||
# Host B (dialer)
|
||||
security_options_b = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_b,
|
||||
noise_privkey=noise_key_pair_b.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options_b,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
|
||||
# transport
|
||||
)
|
||||
|
||||
# Test protocol
|
||||
test_protocol = TProtocol("/test/multiple/streams/1.0.0")
|
||||
received_data = []
|
||||
|
||||
# Set up handler on host A
|
||||
async def test_handler(stream):
|
||||
data = await stream.read(1024)
|
||||
received_data.append(data)
|
||||
await stream.write(data) # Echo back
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(test_protocol, test_handler)
|
||||
|
||||
# Start both hosts
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
ws_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/ws" in str(addr):
|
||||
ws_addr = addr
|
||||
break
|
||||
|
||||
assert ws_addr is not None, "No WebSocket listen address found"
|
||||
|
||||
# Connect host B to host A
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
peer_info = info_from_p2p_addr(ws_addr)
|
||||
await host_b.connect(peer_info)
|
||||
|
||||
# Create multiple concurrent streams
|
||||
num_streams = 5
|
||||
test_data_list = [f"Stream {i} data".encode() for i in range(num_streams)]
|
||||
|
||||
async def create_stream_and_test(stream_id: int, data: bytes):
|
||||
stream = await host_b.new_stream(host_a.get_id(), [test_protocol])
|
||||
await stream.write(data)
|
||||
response = await stream.read(len(data))
|
||||
await stream.close()
|
||||
return response
|
||||
|
||||
# Run all streams concurrently
|
||||
tasks = [
|
||||
create_stream_and_test(i, test_data_list[i]) for i in range(num_streams)
|
||||
]
|
||||
responses = []
|
||||
for task in tasks:
|
||||
responses.append(await task)
|
||||
|
||||
# Verify all communications
|
||||
assert len(received_data) == num_streams, (
|
||||
f"Expected {num_streams} received messages, got {len(received_data)}"
|
||||
)
|
||||
for i, (sent, received, response) in enumerate(
|
||||
zip(test_data_list, received_data, responses)
|
||||
):
|
||||
assert received == sent, f"Stream {i}: Expected {sent}, got {received}"
|
||||
assert response == sent, f"Stream {i}: Expected echo {sent}, got {response}"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_websocket_p2p_connection_state():
|
||||
"""Test WebSocket connection state tracking and metadata."""
|
||||
# Create two hosts with Noise security
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
noise_key_pair_a = create_new_x25519_key_pair()
|
||||
noise_key_pair_b = create_new_x25519_key_pair()
|
||||
|
||||
# Host A (listener)
|
||||
security_options_a = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_a,
|
||||
noise_privkey=noise_key_pair_a.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options_a,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")],
|
||||
)
|
||||
|
||||
# Host B (dialer)
|
||||
security_options_b = {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
libp2p_keypair=key_pair_b,
|
||||
noise_privkey=noise_key_pair_b.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
}
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options_b,
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")], # Ensure WebSocket
|
||||
# transport
|
||||
)
|
||||
|
||||
# Set up handler on host A
|
||||
async def test_handler(stream):
|
||||
# Read some data
|
||||
await stream.read(1024)
|
||||
# Write some data back
|
||||
await stream.write(b"Response data")
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(PING_PROTOCOL_ID, test_handler)
|
||||
|
||||
# Start both hosts
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
ws_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/ws" in str(addr):
|
||||
ws_addr = addr
|
||||
break
|
||||
|
||||
assert ws_addr is not None, "No WebSocket listen address found"
|
||||
|
||||
# Connect host B to host A
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
|
||||
peer_info = info_from_p2p_addr(ws_addr)
|
||||
await host_b.connect(peer_info)
|
||||
|
||||
# Create stream and test communication
|
||||
stream = await host_b.new_stream(host_a.get_id(), [PING_PROTOCOL_ID])
|
||||
await stream.write(b"Test data for connection state")
|
||||
response = await stream.read(1024)
|
||||
await stream.close()
|
||||
|
||||
# Verify response
|
||||
assert response == b"Response data", f"Expected 'Response data', got {response}"
|
||||
|
||||
# Test connection state (if available)
|
||||
# Note: This tests the connection state tracking we implemented
|
||||
connections = host_b.get_network().connections
|
||||
assert len(connections) > 0, "Should have at least one connection"
|
||||
|
||||
# Get the connection to host A
|
||||
conn_to_a = None
|
||||
for peer_id, conn_list in connections.items():
|
||||
if peer_id == host_a.get_id():
|
||||
# connections maps peer_id to list of connections, get the first one
|
||||
conn_to_a = conn_list[0] if conn_list else None
|
||||
break
|
||||
|
||||
assert conn_to_a is not None, "Should have connection to host A"
|
||||
|
||||
# Test that the connection has the expected properties
|
||||
assert hasattr(conn_to_a, "muxed_conn"), "Connection should have muxed_conn"
|
||||
assert hasattr(conn_to_a.muxed_conn, "secured_conn"), (
|
||||
"Muxed connection should have underlying secured_conn"
|
||||
)
|
||||
|
||||
# If the underlying connection is our WebSocket connection, test its state
|
||||
# Type assertion to access private attribute for testing
|
||||
underlying_conn = getattr(conn_to_a.muxed_conn, "secured_conn")
|
||||
if hasattr(underlying_conn, "conn_state"):
|
||||
state = underlying_conn.conn_state()
|
||||
assert "connection_start_time" in state, (
|
||||
"Connection state should include start time"
|
||||
)
|
||||
assert "bytes_read" in state, "Connection state should include bytes read"
|
||||
assert "bytes_written" in state, (
|
||||
"Connection state should include bytes written"
|
||||
)
|
||||
assert state["bytes_read"] > 0, "Should have read some bytes"
|
||||
assert state["bytes_written"] > 0, "Should have written some bytes"
|
||||
0
tests/interop/__init__.py
Normal file
0
tests/interop/__init__.py
Normal file
21
tests/interop/js_libp2p/js_node/src/package.json
Normal file
21
tests/interop/js_libp2p/js_node/src/package.json
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "src",
|
||||
"version": "1.0.0",
|
||||
"main": "ping.js",
|
||||
"scripts": {
|
||||
"test": "echo \"Error: no test specified\" && exit 1"
|
||||
},
|
||||
"keywords": [],
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"description": "",
|
||||
"dependencies": {
|
||||
"@chainsafe/libp2p-noise": "^9.0.0",
|
||||
"@chainsafe/libp2p-yamux": "^5.0.1",
|
||||
"@libp2p/ping": "^2.0.36",
|
||||
"@libp2p/plaintext": "^2.0.29",
|
||||
"@libp2p/websockets": "^9.2.18",
|
||||
"libp2p": "^2.9.0",
|
||||
"multiaddr": "^10.0.1"
|
||||
}
|
||||
}
|
||||
122
tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs
Normal file
122
tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs
Normal file
@ -0,0 +1,122 @@
|
||||
import { createLibp2p } from 'libp2p'
|
||||
import { webSockets } from '@libp2p/websockets'
|
||||
import { ping } from '@libp2p/ping'
|
||||
import { noise } from '@chainsafe/libp2p-noise'
|
||||
import { plaintext } from '@libp2p/plaintext'
|
||||
import { yamux } from '@chainsafe/libp2p-yamux'
|
||||
// import { identify } from '@libp2p/identify' // Commented out for compatibility
|
||||
|
||||
// Configuration from environment (with defaults for compatibility)
|
||||
const TRANSPORT = process.env.transport || 'ws'
|
||||
const SECURITY = process.env.security || 'noise'
|
||||
const MUXER = process.env.muxer || 'yamux'
|
||||
const IP = process.env.ip || '0.0.0.0'
|
||||
|
||||
async function main() {
|
||||
console.log(`🔧 Configuration: transport=${TRANSPORT}, security=${SECURITY}, muxer=${MUXER}`)
|
||||
|
||||
// Build options following the proven pattern from test-plans-fork
|
||||
const options = {
|
||||
start: true,
|
||||
connectionGater: {
|
||||
denyDialMultiaddr: async () => false
|
||||
},
|
||||
connectionMonitor: {
|
||||
enabled: false
|
||||
},
|
||||
services: {
|
||||
ping: ping()
|
||||
}
|
||||
}
|
||||
|
||||
// Transport configuration (following get-libp2p.ts pattern)
|
||||
switch (TRANSPORT) {
|
||||
case 'ws':
|
||||
options.transports = [webSockets()]
|
||||
options.addresses = {
|
||||
listen: [`/ip4/${IP}/tcp/0/ws`]
|
||||
}
|
||||
break
|
||||
case 'wss':
|
||||
process.env.NODE_TLS_REJECT_UNAUTHORIZED = '0'
|
||||
options.transports = [webSockets()]
|
||||
options.addresses = {
|
||||
listen: [`/ip4/${IP}/tcp/0/wss`]
|
||||
}
|
||||
break
|
||||
default:
|
||||
throw new Error(`Unknown transport: ${TRANSPORT}`)
|
||||
}
|
||||
|
||||
// Security configuration
|
||||
switch (SECURITY) {
|
||||
case 'noise':
|
||||
options.connectionEncryption = [noise()]
|
||||
break
|
||||
case 'plaintext':
|
||||
options.connectionEncryption = [plaintext()]
|
||||
break
|
||||
default:
|
||||
throw new Error(`Unknown security: ${SECURITY}`)
|
||||
}
|
||||
|
||||
// Muxer configuration
|
||||
switch (MUXER) {
|
||||
case 'yamux':
|
||||
options.streamMuxers = [yamux()]
|
||||
break
|
||||
default:
|
||||
throw new Error(`Unknown muxer: ${MUXER}`)
|
||||
}
|
||||
|
||||
console.log('🔧 Creating libp2p node with proven interop configuration...')
|
||||
const node = await createLibp2p(options)
|
||||
|
||||
await node.start()
|
||||
|
||||
console.log(node.peerId.toString())
|
||||
for (const addr of node.getMultiaddrs()) {
|
||||
console.log(addr.toString())
|
||||
}
|
||||
|
||||
// Debug: Print supported protocols
|
||||
console.log('DEBUG: Supported protocols:')
|
||||
if (node.services && node.services.registrar) {
|
||||
const protocols = node.services.registrar.getProtocols()
|
||||
for (const protocol of protocols) {
|
||||
console.log('DEBUG: Protocol:', protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// Debug: Print connection encryption protocols
|
||||
console.log('DEBUG: Connection encryption protocols:')
|
||||
try {
|
||||
if (node.components && node.components.connectionEncryption) {
|
||||
for (const encrypter of node.components.connectionEncryption) {
|
||||
console.log('DEBUG: Encrypter:', encrypter.protocol)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.log('DEBUG: Could not access connectionEncryption:', e.message)
|
||||
}
|
||||
|
||||
// Debug: Print stream muxer protocols
|
||||
console.log('DEBUG: Stream muxer protocols:')
|
||||
try {
|
||||
if (node.components && node.components.streamMuxers) {
|
||||
for (const muxer of node.components.streamMuxers) {
|
||||
console.log('DEBUG: Muxer:', muxer.protocol)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.log('DEBUG: Could not access streamMuxers:', e.message)
|
||||
}
|
||||
|
||||
// Keep the process alive
|
||||
await new Promise(() => {})
|
||||
}
|
||||
|
||||
main().catch(err => {
|
||||
console.error(err)
|
||||
process.exit(1)
|
||||
})
|
||||
127
tests/interop/test_js_ws_ping.py
Normal file
127
tests/interop/test_js_ws_ping.py
Normal file
@ -0,0 +1,127 @@
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
from trio.lowlevel import open_process
|
||||
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.network.exceptions import SwarmException
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import PeerInfo
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.security.insecure.transport import InsecureTransport
|
||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_ping_with_js_node():
|
||||
# Skip this test due to JavaScript dependency issues
|
||||
pytest.skip("Skipping JS interop test due to dependency issues")
|
||||
js_node_dir = os.path.join(os.path.dirname(__file__), "js_libp2p", "js_node", "src")
|
||||
script_name = "./ws_ping_node.mjs"
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=js_node_dir,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as e:
|
||||
pytest.fail(f"Failed to run 'npm install': {e}")
|
||||
|
||||
# Launch the JS libp2p node (long-running)
|
||||
proc = await open_process(
|
||||
["node", script_name],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=js_node_dir,
|
||||
)
|
||||
assert proc.stdout is not None, "stdout pipe missing"
|
||||
assert proc.stderr is not None, "stderr pipe missing"
|
||||
stdout = proc.stdout
|
||||
stderr = proc.stderr
|
||||
|
||||
try:
|
||||
# Read first two lines (PeerID and multiaddr)
|
||||
buffer = b""
|
||||
with trio.fail_after(30):
|
||||
while buffer.count(b"\n") < 2:
|
||||
chunk = await stdout.receive_some(1024)
|
||||
if not chunk:
|
||||
break
|
||||
buffer += chunk
|
||||
|
||||
lines = [line for line in buffer.decode().splitlines() if line.strip()]
|
||||
if len(lines) < 2:
|
||||
stderr_output = await stderr.receive_some(2048)
|
||||
stderr_output = stderr_output.decode()
|
||||
pytest.fail(
|
||||
"JS node did not produce expected PeerID and multiaddr.\n"
|
||||
f"Stdout: {buffer.decode()!r}\n"
|
||||
f"Stderr: {stderr_output!r}"
|
||||
)
|
||||
peer_id_line, addr_line = lines[0], lines[1]
|
||||
peer_id = ID.from_base58(peer_id_line)
|
||||
maddr = Multiaddr(addr_line)
|
||||
|
||||
# Debug: Print what we're trying to connect to
|
||||
print(f"JS Node Peer ID: {peer_id_line}")
|
||||
print(f"JS Node Address: {addr_line}")
|
||||
print(f"All JS Node lines: {lines}")
|
||||
|
||||
# Set up Python host
|
||||
key_pair = create_new_key_pair()
|
||||
py_peer_id = ID.from_pubkey(key_pair.public_key)
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(py_peer_id, key_pair)
|
||||
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
transport = WebsocketTransport(upgrader)
|
||||
swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
# Connect to JS node
|
||||
peer_info = PeerInfo(peer_id, [maddr])
|
||||
|
||||
print(f"Python trying to connect to: {peer_info}")
|
||||
|
||||
# Use the host as a context manager
|
||||
async with host.run(listen_addrs=[]):
|
||||
await trio.sleep(1)
|
||||
|
||||
try:
|
||||
await host.connect(peer_info)
|
||||
except SwarmException as e:
|
||||
underlying_error = e.__cause__
|
||||
pytest.fail(
|
||||
"Connection failed with SwarmException.\n"
|
||||
f"THE REAL ERROR IS: {underlying_error!r}\n"
|
||||
)
|
||||
|
||||
assert host.get_network().connections.get(peer_id) is not None
|
||||
|
||||
# Ping protocol
|
||||
stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")])
|
||||
await stream.write(b"ping")
|
||||
data = await stream.read(4)
|
||||
assert data == b"pong"
|
||||
finally:
|
||||
proc.send_signal(signal.SIGTERM)
|
||||
await trio.sleep(0)
|
||||
Reference in New Issue
Block a user