mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge branch 'main' into fix/885-Update-default-Bind-address
This commit is contained in:
446
examples/test_tcp_data_transfer.py
Normal file
446
examples/test_tcp_data_transfer.py
Normal file
@ -0,0 +1,446 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TCP P2P Data Transfer Test
|
||||
|
||||
This test proves that TCP peer-to-peer data transfer works correctly in libp2p.
|
||||
This serves as a baseline to compare with WebSocket tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p import create_yamux_muxer_option, new_host
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
|
||||
# Test protocol for data exchange
|
||||
TCP_DATA_PROTOCOL = TProtocol("/test/tcp-data-exchange/1.0.0")
|
||||
|
||||
|
||||
async def create_tcp_host_pair():
|
||||
"""Create a pair of hosts configured for TCP communication."""
|
||||
# Create key pairs
|
||||
key_pair_a = create_new_key_pair()
|
||||
key_pair_b = create_new_key_pair()
|
||||
|
||||
# Create security options (using plaintext for simplicity)
|
||||
def security_options(kp):
|
||||
return {
|
||||
PLAINTEXT_PROTOCOL_ID: InsecureTransport(
|
||||
local_key_pair=kp, secure_bytes_provider=None, peerstore=None
|
||||
)
|
||||
}
|
||||
|
||||
# Host A (listener) - TCP transport (default)
|
||||
host_a = new_host(
|
||||
key_pair=key_pair_a,
|
||||
sec_opt=security_options(key_pair_a),
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
|
||||
)
|
||||
|
||||
# Host B (dialer) - TCP transport (default)
|
||||
host_b = new_host(
|
||||
key_pair=key_pair_b,
|
||||
sec_opt=security_options(key_pair_b),
|
||||
muxer_opt=create_yamux_muxer_option(),
|
||||
listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")],
|
||||
)
|
||||
|
||||
return host_a, host_b
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_basic_connection():
|
||||
"""Test basic TCP connection establishment."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
connection_established = False
|
||||
|
||||
async def connection_handler(stream):
|
||||
nonlocal connection_established
|
||||
connection_established = True
|
||||
await stream.close()
|
||||
|
||||
host_a.set_stream_handler(TCP_DATA_PROTOCOL, connection_handler)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert listen_addrs, "Host A should have listen addresses"
|
||||
|
||||
# Extract TCP address
|
||||
tcp_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||
tcp_addr = addr
|
||||
break
|
||||
|
||||
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||
|
||||
# Create peer info for host A
|
||||
peer_info = info_from_p2p_addr(tcp_addr)
|
||||
|
||||
# Host B connects to host A
|
||||
await host_b.connect(peer_info)
|
||||
print("✅ TCP connection established")
|
||||
|
||||
# Open a stream to test the connection
|
||||
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||
await stream.close()
|
||||
|
||||
# Wait a bit for the handler to be called
|
||||
await trio.sleep(0.1)
|
||||
|
||||
assert connection_established, "TCP connection handler should have been called"
|
||||
print("✅ TCP basic connection test successful!")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_data_transfer():
|
||||
"""Test TCP peer-to-peer data transfer."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
# Test data
|
||||
test_data = b"Hello TCP P2P Data Transfer! This is a test message."
|
||||
received_data = None
|
||||
transfer_complete = trio.Event()
|
||||
|
||||
async def data_handler(stream):
|
||||
nonlocal received_data
|
||||
try:
|
||||
# Read the incoming data
|
||||
received_data = await stream.read(len(test_data))
|
||||
# Echo it back to confirm successful transfer
|
||||
await stream.write(received_data)
|
||||
await stream.close()
|
||||
transfer_complete.set()
|
||||
except Exception as e:
|
||||
print(f"Handler error: {e}")
|
||||
transfer_complete.set()
|
||||
|
||||
host_a.set_stream_handler(TCP_DATA_PROTOCOL, data_handler)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert listen_addrs, "Host A should have listen addresses"
|
||||
|
||||
# Extract TCP address
|
||||
tcp_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||
tcp_addr = addr
|
||||
break
|
||||
|
||||
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||
|
||||
# Create peer info for host A
|
||||
peer_info = info_from_p2p_addr(tcp_addr)
|
||||
|
||||
# Host B connects to host A
|
||||
await host_b.connect(peer_info)
|
||||
print("✅ TCP connection established")
|
||||
|
||||
# Open a stream for data transfer
|
||||
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||
print("✅ TCP stream opened")
|
||||
|
||||
# Send test data
|
||||
await stream.write(test_data)
|
||||
print(f"📤 Sent data: {test_data}")
|
||||
|
||||
# Read echoed data back
|
||||
echoed_data = await stream.read(len(test_data))
|
||||
print(f"📥 Received echo: {echoed_data}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
# Wait for transfer to complete
|
||||
with trio.fail_after(5.0): # 5 second timeout
|
||||
await transfer_complete.wait()
|
||||
|
||||
# Verify data transfer
|
||||
assert received_data == test_data, (
|
||||
f"Data mismatch: {received_data} != {test_data}"
|
||||
)
|
||||
assert echoed_data == test_data, f"Echo mismatch: {echoed_data} != {test_data}"
|
||||
|
||||
print("✅ TCP P2P data transfer successful!")
|
||||
print(f" Original: {test_data}")
|
||||
print(f" Received: {received_data}")
|
||||
print(f" Echoed: {echoed_data}")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_large_data_transfer():
|
||||
"""Test TCP with larger data payloads."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
# Large test data (10KB)
|
||||
test_data = b"TCP Large Data Test! " * 500 # ~10KB
|
||||
received_data = None
|
||||
transfer_complete = trio.Event()
|
||||
|
||||
async def large_data_handler(stream):
|
||||
nonlocal received_data
|
||||
try:
|
||||
# Read data in chunks
|
||||
chunks = []
|
||||
total_received = 0
|
||||
expected_size = len(test_data)
|
||||
|
||||
while total_received < expected_size:
|
||||
chunk = await stream.read(min(1024, expected_size - total_received))
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
total_received += len(chunk)
|
||||
|
||||
received_data = b"".join(chunks)
|
||||
|
||||
# Send back confirmation
|
||||
await stream.write(b"RECEIVED_OK")
|
||||
await stream.close()
|
||||
transfer_complete.set()
|
||||
except Exception as e:
|
||||
print(f"Large data handler error: {e}")
|
||||
transfer_complete.set()
|
||||
|
||||
host_a.set_stream_handler(TCP_DATA_PROTOCOL, large_data_handler)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[]),
|
||||
):
|
||||
# Get host A's listen address
|
||||
listen_addrs = host_a.get_addrs()
|
||||
assert listen_addrs, "Host A should have listen addresses"
|
||||
|
||||
# Extract TCP address
|
||||
tcp_addr = None
|
||||
for addr in listen_addrs:
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr):
|
||||
tcp_addr = addr
|
||||
break
|
||||
|
||||
assert tcp_addr, f"No TCP address found in {listen_addrs}"
|
||||
print(f"🔗 Host A listening on: {tcp_addr}")
|
||||
print(f"📊 Test data size: {len(test_data)} bytes")
|
||||
|
||||
# Create peer info for host A
|
||||
peer_info = info_from_p2p_addr(tcp_addr)
|
||||
|
||||
# Host B connects to host A
|
||||
await host_b.connect(peer_info)
|
||||
print("✅ TCP connection established")
|
||||
|
||||
# Open a stream for data transfer
|
||||
stream = await host_b.new_stream(peer_info.peer_id, [TCP_DATA_PROTOCOL])
|
||||
print("✅ TCP stream opened")
|
||||
|
||||
# Send large test data in chunks
|
||||
chunk_size = 1024
|
||||
sent_bytes = 0
|
||||
for i in range(0, len(test_data), chunk_size):
|
||||
chunk = test_data[i : i + chunk_size]
|
||||
await stream.write(chunk)
|
||||
sent_bytes += len(chunk)
|
||||
if sent_bytes % (chunk_size * 4) == 0: # Progress every 4KB
|
||||
print(f"📤 Sent {sent_bytes}/{len(test_data)} bytes")
|
||||
|
||||
print(f"📤 Sent all {len(test_data)} bytes")
|
||||
|
||||
# Read confirmation
|
||||
confirmation = await stream.read(1024)
|
||||
print(f"📥 Received confirmation: {confirmation}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
# Wait for transfer to complete
|
||||
with trio.fail_after(10.0): # 10 second timeout for large data
|
||||
await transfer_complete.wait()
|
||||
|
||||
# Verify data transfer
|
||||
assert received_data is not None, "No data was received"
|
||||
assert received_data == test_data, (
|
||||
"Large data transfer failed:"
|
||||
+ f" sizes {len(received_data)} != {len(test_data)}"
|
||||
)
|
||||
assert confirmation == b"RECEIVED_OK", f"Confirmation failed: {confirmation}"
|
||||
|
||||
print("✅ TCP large data transfer successful!")
|
||||
print(f" Data size: {len(test_data)} bytes")
|
||||
print(f" Received: {len(received_data)} bytes")
|
||||
print(f" Match: {received_data == test_data}")
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_tcp_bidirectional_transfer():
|
||||
"""Test bidirectional data transfer over TCP."""
|
||||
host_a, host_b = await create_tcp_host_pair()
|
||||
|
||||
# Test data
|
||||
data_a_to_b = b"Message from Host A to Host B via TCP"
|
||||
data_b_to_a = b"Response from Host B to Host A via TCP"
|
||||
|
||||
received_on_a = None
|
||||
received_on_b = None
|
||||
transfer_complete_a = trio.Event()
|
||||
transfer_complete_b = trio.Event()
|
||||
|
||||
async def handler_a(stream):
|
||||
nonlocal received_on_a
|
||||
try:
|
||||
# Read data from B
|
||||
received_on_a = await stream.read(len(data_b_to_a))
|
||||
print(f"🅰️ Host A received: {received_on_a}")
|
||||
await stream.close()
|
||||
transfer_complete_a.set()
|
||||
except Exception as e:
|
||||
print(f"Handler A error: {e}")
|
||||
transfer_complete_a.set()
|
||||
|
||||
async def handler_b(stream):
|
||||
nonlocal received_on_b
|
||||
try:
|
||||
# Read data from A
|
||||
received_on_b = await stream.read(len(data_a_to_b))
|
||||
print(f"🅱️ Host B received: {received_on_b}")
|
||||
await stream.close()
|
||||
transfer_complete_b.set()
|
||||
except Exception as e:
|
||||
print(f"Handler B error: {e}")
|
||||
transfer_complete_b.set()
|
||||
|
||||
# Set up handlers on both hosts
|
||||
protocol_a_to_b = TProtocol("/test/tcp-a-to-b/1.0.0")
|
||||
protocol_b_to_a = TProtocol("/test/tcp-b-to-a/1.0.0")
|
||||
|
||||
host_a.set_stream_handler(protocol_b_to_a, handler_a)
|
||||
host_b.set_stream_handler(protocol_a_to_b, handler_b)
|
||||
|
||||
async with (
|
||||
host_a.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
host_b.run(listen_addrs=[Multiaddr("/ip4/127.0.0.1/tcp/0")]),
|
||||
):
|
||||
# Get addresses
|
||||
addrs_a = host_a.get_addrs()
|
||||
addrs_b = host_b.get_addrs()
|
||||
|
||||
assert addrs_a and addrs_b, "Both hosts should have addresses"
|
||||
|
||||
# Extract TCP addresses
|
||||
tcp_addr_a = next(
|
||||
(
|
||||
addr
|
||||
for addr in addrs_a
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr)
|
||||
),
|
||||
None,
|
||||
)
|
||||
tcp_addr_b = next(
|
||||
(
|
||||
addr
|
||||
for addr in addrs_b
|
||||
if "/tcp/" in str(addr) and "/ws" not in str(addr)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
assert tcp_addr_a and tcp_addr_b, (
|
||||
f"TCP addresses not found: A={addrs_a}, B={addrs_b}"
|
||||
)
|
||||
print(f"🔗 Host A listening on: {tcp_addr_a}")
|
||||
print(f"🔗 Host B listening on: {tcp_addr_b}")
|
||||
|
||||
# Create peer infos
|
||||
peer_info_a = info_from_p2p_addr(tcp_addr_a)
|
||||
peer_info_b = info_from_p2p_addr(tcp_addr_b)
|
||||
|
||||
# Establish connections
|
||||
await host_b.connect(peer_info_a)
|
||||
await host_a.connect(peer_info_b)
|
||||
print("✅ Bidirectional TCP connections established")
|
||||
|
||||
# Send data A -> B
|
||||
stream_a_to_b = await host_a.new_stream(peer_info_b.peer_id, [protocol_a_to_b])
|
||||
await stream_a_to_b.write(data_a_to_b)
|
||||
print(f"📤 A->B: {data_a_to_b}")
|
||||
await stream_a_to_b.close()
|
||||
|
||||
# Send data B -> A
|
||||
stream_b_to_a = await host_b.new_stream(peer_info_a.peer_id, [protocol_b_to_a])
|
||||
await stream_b_to_a.write(data_b_to_a)
|
||||
print(f"📤 B->A: {data_b_to_a}")
|
||||
await stream_b_to_a.close()
|
||||
|
||||
# Wait for both transfers to complete
|
||||
with trio.fail_after(5.0):
|
||||
await transfer_complete_a.wait()
|
||||
await transfer_complete_b.wait()
|
||||
|
||||
# Verify bidirectional transfer
|
||||
assert received_on_a == data_b_to_a, f"A received wrong data: {received_on_a}"
|
||||
assert received_on_b == data_a_to_b, f"B received wrong data: {received_on_b}"
|
||||
|
||||
print("✅ TCP bidirectional data transfer successful!")
|
||||
print(f" A->B: {data_a_to_b}")
|
||||
print(f" B->A: {data_b_to_a}")
|
||||
print(f" ✓ A got: {received_on_a}")
|
||||
print(f" ✓ B got: {received_on_b}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests directly
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
print("🧪 Running TCP P2P Data Transfer Tests")
|
||||
print("=" * 50)
|
||||
|
||||
async def run_all_tcp_tests():
|
||||
try:
|
||||
print("\n1. Testing basic TCP connection...")
|
||||
await test_tcp_basic_connection()
|
||||
except Exception as e:
|
||||
print(f"❌ Basic TCP connection test failed: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("\n2. Testing TCP data transfer...")
|
||||
await test_tcp_data_transfer()
|
||||
except Exception as e:
|
||||
print(f"❌ TCP data transfer test failed: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("\n3. Testing TCP large data transfer...")
|
||||
await test_tcp_large_data_transfer()
|
||||
except Exception as e:
|
||||
print(f"❌ TCP large data transfer test failed: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("\n4. Testing TCP bidirectional transfer...")
|
||||
await test_tcp_bidirectional_transfer()
|
||||
except Exception as e:
|
||||
print(f"❌ TCP bidirectional transfer test failed: {e}")
|
||||
return
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("🏁 TCP P2P Tests Complete - All Tests PASSED!")
|
||||
|
||||
trio.run(run_all_tcp_tests)
|
||||
210
examples/transport_integration_demo.py
Normal file
210
examples/transport_integration_demo.py
Normal file
@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Demo script showing the new transport integration capabilities in py-libp2p.
|
||||
|
||||
This script demonstrates:
|
||||
1. How to use the transport registry
|
||||
2. How to create transports dynamically based on multiaddrs
|
||||
3. How to register custom transports
|
||||
4. How the new system automatically selects the right transport
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add the libp2p directory to the path so we can import it
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.transport import (
|
||||
create_transport,
|
||||
create_transport_for_multiaddr,
|
||||
get_supported_transport_protocols,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
)
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def demo_transport_registry():
|
||||
"""Demonstrate the transport registry functionality."""
|
||||
print("🔧 Transport Registry Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Get the global registry
|
||||
registry = get_transport_registry()
|
||||
|
||||
# Show supported protocols
|
||||
supported = get_supported_transport_protocols()
|
||||
print(f"Supported transport protocols: {supported}")
|
||||
|
||||
# Show registered transports
|
||||
print("\nRegistered transports:")
|
||||
for protocol in supported:
|
||||
transport_class = registry.get_transport(protocol)
|
||||
class_name = transport_class.__name__ if transport_class else "None"
|
||||
print(f" {protocol}: {class_name}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_transport_factory():
|
||||
"""Demonstrate the transport factory functions."""
|
||||
print("🏭 Transport Factory Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Create a dummy upgrader for WebSocket transport
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Create transports using the factory function
|
||||
try:
|
||||
tcp_transport = create_transport("tcp")
|
||||
print(f"✅ Created TCP transport: {type(tcp_transport).__name__}")
|
||||
|
||||
ws_transport = create_transport("ws", upgrader)
|
||||
print(f"✅ Created WebSocket transport: {type(ws_transport).__name__}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating transport: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_multiaddr_transport_selection():
|
||||
"""Demonstrate automatic transport selection based on multiaddrs."""
|
||||
print("🎯 Multiaddr Transport Selection Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Create a dummy upgrader
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Test different multiaddr types
|
||||
test_addrs = [
|
||||
"/ip4/127.0.0.1/tcp/8080",
|
||||
"/ip4/127.0.0.1/tcp/8080/ws",
|
||||
"/ip6/::1/tcp/8080/ws",
|
||||
"/dns4/example.com/tcp/443/ws",
|
||||
]
|
||||
|
||||
for addr_str in test_addrs:
|
||||
try:
|
||||
maddr = multiaddr.Multiaddr(addr_str)
|
||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||
|
||||
if transport:
|
||||
print(f"✅ {addr_str} -> {type(transport).__name__}")
|
||||
else:
|
||||
print(f"❌ {addr_str} -> No transport found")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {addr_str} -> Error: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_custom_transport_registration():
|
||||
"""Demonstrate how to register custom transports."""
|
||||
print("🔧 Custom Transport Registration Demo")
|
||||
print("=" * 50)
|
||||
|
||||
# Show current supported protocols
|
||||
print(f"Before registration: {get_supported_transport_protocols()}")
|
||||
|
||||
# Register a custom transport (using TCP as an example)
|
||||
class CustomTCPTransport(TCP):
|
||||
"""Custom TCP transport for demonstration."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.custom_flag = True
|
||||
|
||||
# Register the custom transport
|
||||
register_transport("custom_tcp", CustomTCPTransport)
|
||||
|
||||
# Show updated supported protocols
|
||||
print(f"After registration: {get_supported_transport_protocols()}")
|
||||
|
||||
# Test creating the custom transport
|
||||
try:
|
||||
custom_transport = create_transport("custom_tcp")
|
||||
print(f"✅ Created custom transport: {type(custom_transport).__name__}")
|
||||
# Check if it has the custom flag (type-safe way)
|
||||
if hasattr(custom_transport, "custom_flag"):
|
||||
flag_value = getattr(custom_transport, "custom_flag", "Not found")
|
||||
print(f" Custom flag: {flag_value}")
|
||||
else:
|
||||
print(" Custom flag: Not found")
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating custom transport: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def demo_integration_with_libp2p():
|
||||
"""Demonstrate how the new system integrates with libp2p."""
|
||||
print("🚀 Libp2p Integration Demo")
|
||||
print("=" * 50)
|
||||
|
||||
print("The new transport system integrates seamlessly with libp2p:")
|
||||
print()
|
||||
print("1. ✅ Automatic transport selection based on multiaddr")
|
||||
print("2. ✅ Support for WebSocket (/ws) protocol")
|
||||
print("3. ✅ Fallback to TCP for backward compatibility")
|
||||
print("4. ✅ Easy registration of new transport protocols")
|
||||
print("5. ✅ No changes needed to existing libp2p code")
|
||||
print()
|
||||
|
||||
print("Example usage in libp2p:")
|
||||
print(" # This will automatically use WebSocket transport")
|
||||
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080/ws'])")
|
||||
print()
|
||||
print(" # This will automatically use TCP transport")
|
||||
print(" host = new_host(listen_addrs=['/ip4/127.0.0.1/tcp/8080'])")
|
||||
print()
|
||||
|
||||
print()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all demos."""
|
||||
print("🎉 Py-libp2p Transport Integration Demo")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Run all demos
|
||||
demo_transport_registry()
|
||||
demo_transport_factory()
|
||||
demo_multiaddr_transport_selection()
|
||||
demo_custom_transport_registration()
|
||||
demo_integration_with_libp2p()
|
||||
|
||||
print("🎯 Summary of New Features:")
|
||||
print("=" * 40)
|
||||
print("✅ Transport Registry: Central registry for all transport implementations")
|
||||
print("✅ Dynamic Transport Selection: Automatic selection based on multiaddr")
|
||||
print("✅ WebSocket Support: Full /ws protocol support")
|
||||
print("✅ Extensible Architecture: Easy to add new transport protocols")
|
||||
print("✅ Backward Compatibility: Existing TCP code continues to work")
|
||||
print("✅ Factory Functions: Simple API for creating transports")
|
||||
print()
|
||||
print("🚀 The transport system is now ready for production use!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Demo interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Demo failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
220
examples/websocket/test_tcp_echo.py
Normal file
220
examples/websocket/test_tcp_echo.py
Normal file
@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple TCP echo demo to verify basic libp2p functionality.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
# Enable debug logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger("libp2p.tcp-example")
|
||||
|
||||
# Simple echo protocol
|
||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def echo_handler(stream):
|
||||
"""Simple echo handler that echoes back any data received."""
|
||||
try:
|
||||
data = await stream.read(1024)
|
||||
if data:
|
||||
message = data.decode("utf-8", errors="replace")
|
||||
print(f"📥 Received: {message}")
|
||||
print(f"📤 Echoing back: {message}")
|
||||
await stream.write(data)
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Echo handler error: {e}")
|
||||
await stream.close()
|
||||
|
||||
|
||||
def create_tcp_host():
|
||||
"""Create a host with TCP transport."""
|
||||
# Create key pair and peer store
|
||||
key_pair = create_new_key_pair()
|
||||
peer_id = ID.from_pubkey(key_pair.public_key)
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(peer_id, key_pair)
|
||||
|
||||
# Create transport upgrader with plaintext security
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
|
||||
# Create TCP transport
|
||||
transport = TCP()
|
||||
|
||||
# Create swarm and host
|
||||
swarm = Swarm(peer_id, peer_store, upgrader, transport)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
return host
|
||||
|
||||
|
||||
async def run(port: int, destination: str) -> None:
|
||||
localhost_ip = "0.0.0.0"
|
||||
|
||||
if not destination:
|
||||
# Create first host (listener) with TCP transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||
|
||||
try:
|
||||
host = create_tcp_host()
|
||||
logger.debug("Created TCP host")
|
||||
|
||||
# Set up echo handler
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||
# connections
|
||||
addrs = host.get_addrs()
|
||||
logger.debug(f"Host addresses: {addrs}")
|
||||
if not addrs:
|
||||
print("❌ Error: No addresses found for the host")
|
||||
return
|
||||
|
||||
server_addr = str(addrs[0])
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
print("🌐 TCP Server Started Successfully!")
|
||||
print("=" * 50)
|
||||
print(f"📍 Server Address: {client_addr}")
|
||||
print("🔧 Protocol: /echo/1.0.0")
|
||||
print("🚀 Transport: TCP")
|
||||
print()
|
||||
print("📋 To test the connection, run this in another terminal:")
|
||||
print(f" python test_tcp_echo.py -d {client_addr}")
|
||||
print()
|
||||
print("⏳ Waiting for incoming TCP connections...")
|
||||
print("─" * 50)
|
||||
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating TCP server: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
# Create second host (dialer) with TCP transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}")
|
||||
|
||||
try:
|
||||
# Create a single host for client operations
|
||||
host = create_tcp_host()
|
||||
|
||||
# Start the host for client operations
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
print("🔌 TCP Client Starting...")
|
||||
print("=" * 40)
|
||||
print(f"🎯 Target Peer: {info.peer_id}")
|
||||
print(f"📍 Target Address: {destination}")
|
||||
print()
|
||||
|
||||
try:
|
||||
print("🔗 Connecting to TCP server...")
|
||||
await host.connect(info)
|
||||
print("✅ Successfully connected to TCP server!")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print("\n❌ Connection Failed!")
|
||||
print(f" Peer ID: {info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Error: {error_msg}")
|
||||
return
|
||||
|
||||
# Create a stream and send test data
|
||||
try:
|
||||
stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID])
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to create stream: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("🚀 Starting Echo Protocol Test...")
|
||||
print("─" * 40)
|
||||
|
||||
# Send test data
|
||||
test_message = b"Hello TCP Transport!"
|
||||
print(f"📤 Sending message: {test_message.decode('utf-8')}")
|
||||
await stream.write(test_message)
|
||||
|
||||
# Read response
|
||||
print("⏳ Waiting for server response...")
|
||||
response = await stream.read(1024)
|
||||
print(f"📥 Received response: {response.decode('utf-8')}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
print("─" * 40)
|
||||
if response == test_message:
|
||||
print("🎉 Echo test successful!")
|
||||
print("✅ TCP transport is working perfectly!")
|
||||
else:
|
||||
print("❌ Echo test failed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Echo protocol error: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
print("✅ TCP demo completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating TCP client: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
def main() -> None:
|
||||
description = "Simple TCP echo demo for libp2p"
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-d", "--destination", type=str, help="destination multiaddr string"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, args.destination)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
145
examples/websocket/test_websocket_transport.py
Normal file
145
examples/websocket/test_websocket_transport.py
Normal file
@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple test script to verify WebSocket transport functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add the libp2p directory to the path so we can import it
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import multiaddr
|
||||
|
||||
from libp2p.transport import create_transport, create_transport_for_multiaddr
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def test_websocket_transport():
|
||||
"""Test basic WebSocket transport functionality."""
|
||||
print("🧪 Testing WebSocket Transport Functionality")
|
||||
print("=" * 50)
|
||||
|
||||
# Create a dummy upgrader
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
# Test creating WebSocket transport
|
||||
try:
|
||||
ws_transport = create_transport("ws", upgrader)
|
||||
print(f"✅ WebSocket transport created: {type(ws_transport).__name__}")
|
||||
|
||||
# Test creating transport from multiaddr
|
||||
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
|
||||
print(
|
||||
f"✅ WebSocket transport from multiaddr: "
|
||||
f"{type(ws_transport_from_maddr).__name__}"
|
||||
)
|
||||
|
||||
# Test creating listener
|
||||
handler_called = False
|
||||
|
||||
async def test_handler(conn):
|
||||
nonlocal handler_called
|
||||
handler_called = True
|
||||
print(f"✅ Connection handler called with: {type(conn).__name__}")
|
||||
await conn.close()
|
||||
|
||||
listener = ws_transport.create_listener(test_handler)
|
||||
print(f"✅ WebSocket listener created: {type(listener).__name__}")
|
||||
|
||||
# Test that the transport can be used
|
||||
print(
|
||||
f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}"
|
||||
)
|
||||
print(
|
||||
f"✅ WebSocket transport supports listening: "
|
||||
f"{hasattr(ws_transport, 'create_listener')}"
|
||||
)
|
||||
|
||||
print("\n🎯 WebSocket Transport Test Results:")
|
||||
print("✅ Transport creation: PASS")
|
||||
print("✅ Multiaddr parsing: PASS")
|
||||
print("✅ Listener creation: PASS")
|
||||
print("✅ Interface compliance: PASS")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ WebSocket transport test failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def test_transport_registry():
|
||||
"""Test the transport registry functionality."""
|
||||
print("\n🔧 Testing Transport Registry")
|
||||
print("=" * 30)
|
||||
|
||||
from libp2p.transport import (
|
||||
get_supported_transport_protocols,
|
||||
get_transport_registry,
|
||||
)
|
||||
|
||||
registry = get_transport_registry()
|
||||
supported = get_supported_transport_protocols()
|
||||
|
||||
print(f"Supported protocols: {supported}")
|
||||
|
||||
# Test getting transports
|
||||
for protocol in supported:
|
||||
transport_class = registry.get_transport(protocol)
|
||||
class_name = transport_class.__name__ if transport_class else "None"
|
||||
print(f" {protocol}: {class_name}")
|
||||
|
||||
# Test creating transports through registry
|
||||
upgrader = TransportUpgrader({}, {})
|
||||
|
||||
for protocol in supported:
|
||||
try:
|
||||
transport = registry.create_transport(protocol, upgrader)
|
||||
if transport:
|
||||
print(f"✅ {protocol}: Created successfully")
|
||||
else:
|
||||
print(f"❌ {protocol}: Failed to create")
|
||||
except Exception as e:
|
||||
print(f"❌ {protocol}: Error - {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all tests."""
|
||||
print("🚀 WebSocket Transport Integration Test Suite")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Run tests
|
||||
success = await test_websocket_transport()
|
||||
await test_transport_registry()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if success:
|
||||
print("🎉 All tests passed! WebSocket transport is working correctly.")
|
||||
else:
|
||||
print("❌ Some tests failed. Check the output above for details.")
|
||||
|
||||
print("\n🚀 WebSocket transport is ready for use in py-libp2p!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Test interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed with error: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
448
examples/websocket/websocket_demo.py
Normal file
448
examples/websocket/websocket_demo.py
Normal file
@ -0,0 +1,448 @@
|
||||
import argparse
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.abc import INotifee
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair
|
||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||
from libp2p.custom_types import TProtocol
|
||||
from libp2p.host.basic_host import BasicHost
|
||||
from libp2p.network.swarm import Swarm
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||
from libp2p.peer.peerstore import PeerStore
|
||||
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||
from libp2p.security.noise.transport import (
|
||||
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||
Transport as NoiseTransport,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
# Enable debug logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger("libp2p.websocket-example")
|
||||
|
||||
|
||||
# Suppress KeyboardInterrupt by handling SIGINT directly
|
||||
def signal_handler(signum, frame):
|
||||
print("✅ Clean exit completed.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Simple echo protocol
|
||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||
|
||||
|
||||
async def echo_handler(stream):
|
||||
"""Simple echo handler that echoes back any data received."""
|
||||
try:
|
||||
data = await stream.read(1024)
|
||||
if data:
|
||||
message = data.decode("utf-8", errors="replace")
|
||||
print(f"📥 Received: {message}")
|
||||
print(f"📤 Echoing back: {message}")
|
||||
await stream.write(data)
|
||||
await stream.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Echo handler error: {e}")
|
||||
await stream.close()
|
||||
|
||||
|
||||
def create_websocket_host(listen_addrs=None, use_plaintext=False):
|
||||
"""Create a host with WebSocket transport."""
|
||||
# Create key pair and peer store
|
||||
key_pair = create_new_key_pair()
|
||||
peer_id = ID.from_pubkey(key_pair.public_key)
|
||||
peer_store = PeerStore()
|
||||
peer_store.add_key_pair(peer_id, key_pair)
|
||||
|
||||
if use_plaintext:
|
||||
# Create transport upgrader with plaintext security
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
else:
|
||||
# Create separate Ed25519 key for Noise protocol
|
||||
noise_key_pair = create_ed25519_key_pair()
|
||||
|
||||
# Create Noise transport
|
||||
noise_transport = NoiseTransport(
|
||||
libp2p_keypair=key_pair,
|
||||
noise_privkey=noise_key_pair.private_key,
|
||||
early_data=None,
|
||||
with_noise_pipes=False,
|
||||
)
|
||||
|
||||
# Create transport upgrader with Noise security
|
||||
upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={
|
||||
TProtocol(NOISE_PROTOCOL_ID): noise_transport
|
||||
},
|
||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||
)
|
||||
|
||||
# Create WebSocket transport
|
||||
transport = WebsocketTransport(upgrader)
|
||||
|
||||
# Create swarm and host
|
||||
swarm = Swarm(peer_id, peer_store, upgrader, transport)
|
||||
host = BasicHost(swarm)
|
||||
|
||||
return host
|
||||
|
||||
|
||||
async def run(port: int, destination: str, use_plaintext: bool = False) -> None:
|
||||
localhost_ip = "0.0.0.0"
|
||||
|
||||
if not destination:
|
||||
# Create first host (listener) with WebSocket transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||
|
||||
try:
|
||||
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||
logger.debug(f"Created host with use_plaintext={use_plaintext}")
|
||||
|
||||
# Set up echo handler
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||
|
||||
# Add connection event handlers for debugging
|
||||
class DebugNotifee(INotifee):
|
||||
async def opened_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def connected(self, network, conn):
|
||||
print(
|
||||
f"🔗 New libp2p connection established: "
|
||||
f"{conn.muxed_conn.peer_id}"
|
||||
)
|
||||
if hasattr(conn.muxed_conn, "get_security_protocol"):
|
||||
security = conn.muxed_conn.get_security_protocol()
|
||||
else:
|
||||
security = "Unknown"
|
||||
|
||||
print(f" Security: {security}")
|
||||
|
||||
async def disconnected(self, network, conn):
|
||||
print(f"🔌 libp2p connection closed: {conn.muxed_conn.peer_id}")
|
||||
|
||||
async def listen(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
async def listen_close(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
host.get_network().register_notifee(DebugNotifee())
|
||||
|
||||
# Create a cancellation token for clean shutdown
|
||||
cancel_scope = trio.CancelScope()
|
||||
|
||||
async def signal_handler():
|
||||
with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as (
|
||||
signal_receiver
|
||||
):
|
||||
async for sig in signal_receiver:
|
||||
print(f"\n🛑 Received signal {sig}")
|
||||
print("✅ Shutting down WebSocket server...")
|
||||
cancel_scope.cancel()
|
||||
return
|
||||
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Start the signal handler
|
||||
nursery.start_soon(signal_handler)
|
||||
|
||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||
# connections
|
||||
addrs = host.get_addrs()
|
||||
logger.debug(f"Host addresses: {addrs}")
|
||||
if not addrs:
|
||||
print("❌ Error: No addresses found for the host")
|
||||
print("Debug: host.get_addrs() returned empty list")
|
||||
return
|
||||
|
||||
server_addr = str(addrs[0])
|
||||
client_addr = server_addr.replace("/ip4/0.0.0.0/", "/ip4/127.0.0.1/")
|
||||
|
||||
print("🌐 WebSocket Server Started Successfully!")
|
||||
print("=" * 50)
|
||||
print(f"📍 Server Address: {client_addr}")
|
||||
print("🔧 Protocol: /echo/1.0.0")
|
||||
print("🚀 Transport: WebSocket (/ws)")
|
||||
print()
|
||||
print("📋 To test the connection, run this in another terminal:")
|
||||
plaintext_flag = " --plaintext" if use_plaintext else ""
|
||||
print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}")
|
||||
print()
|
||||
print("⏳ Waiting for incoming WebSocket connections...")
|
||||
print("─" * 50)
|
||||
|
||||
# Add a custom handler to show connection events
|
||||
async def custom_echo_handler(stream):
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
print("\n🔗 New WebSocket Connection!")
|
||||
print(f" Peer ID: {peer_id}")
|
||||
print(" Protocol: /echo/1.0.0")
|
||||
|
||||
# Show remote address in multiaddr format
|
||||
try:
|
||||
remote_address = stream.get_remote_address()
|
||||
if remote_address:
|
||||
print(f" Remote: {remote_address}")
|
||||
except Exception:
|
||||
print(" Remote: Unknown")
|
||||
|
||||
print(" ─" * 40)
|
||||
|
||||
# Call the original handler
|
||||
await echo_handler(stream)
|
||||
|
||||
print(" ─" * 40)
|
||||
print(f"✅ Echo request completed for peer: {peer_id}")
|
||||
print()
|
||||
|
||||
# Replace the handler with our custom one
|
||||
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
|
||||
|
||||
# Wait indefinitely or until cancelled
|
||||
with cancel_scope:
|
||||
await trio.sleep_forever()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating WebSocket server: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
else:
|
||||
# Create second host (dialer) with WebSocket transport
|
||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||
|
||||
try:
|
||||
# Create a single host for client operations
|
||||
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||
|
||||
# Start the host for client operations
|
||||
async with (
|
||||
host.run(listen_addrs=[listen_addr]),
|
||||
trio.open_nursery() as (nursery),
|
||||
):
|
||||
# Start the peer-store cleanup task
|
||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||
|
||||
# Add connection event handlers for debugging
|
||||
class ClientDebugNotifee(INotifee):
|
||||
async def opened_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def closed_stream(self, network, stream):
|
||||
pass
|
||||
|
||||
async def connected(self, network, conn):
|
||||
print(
|
||||
f"🔗 Client: libp2p connection established: "
|
||||
f"{conn.muxed_conn.peer_id}"
|
||||
)
|
||||
|
||||
async def disconnected(self, network, conn):
|
||||
print(
|
||||
f"🔌 Client: libp2p connection closed: "
|
||||
f"{conn.muxed_conn.peer_id}"
|
||||
)
|
||||
|
||||
async def listen(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
async def listen_close(self, network, multiaddr):
|
||||
pass
|
||||
|
||||
host.get_network().register_notifee(ClientDebugNotifee())
|
||||
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
print("🔌 WebSocket Client Starting...")
|
||||
print("=" * 40)
|
||||
print(f"🎯 Target Peer: {info.peer_id}")
|
||||
print(f"📍 Target Address: {destination}")
|
||||
print()
|
||||
|
||||
try:
|
||||
print("🔗 Connecting to WebSocket server...")
|
||||
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||
await host.connect(info)
|
||||
print("✅ Successfully connected to WebSocket server!")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print("\n❌ Connection Failed!")
|
||||
print(f" Peer ID: {info.peer_id}")
|
||||
print(f" Address: {destination}")
|
||||
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||
print(f" Error: {error_msg}")
|
||||
print(f" Error type: {type(e).__name__}")
|
||||
|
||||
# Add more detailed error information for debugging
|
||||
if hasattr(e, "__cause__") and e.__cause__:
|
||||
print(f" Root cause: {e.__cause__}")
|
||||
print(f" Root cause type: {type(e.__cause__).__name__}")
|
||||
|
||||
print()
|
||||
print("💡 Troubleshooting:")
|
||||
print(" • Make sure the WebSocket server is running")
|
||||
print(" • Check that the server address is correct")
|
||||
print(" • Verify the server is listening on the right port")
|
||||
print(
|
||||
" • Ensure both client and server use the same sec protocol"
|
||||
)
|
||||
if not use_plaintext:
|
||||
print(" • Noise over WebSocket may have compatibility issues")
|
||||
return
|
||||
|
||||
# Create a stream and send test data
|
||||
try:
|
||||
stream = await host.new_stream(info.peer_id, [ECHO_PROTOCOL_ID])
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to create stream: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
print("🚀 Starting Echo Protocol Test...")
|
||||
print("─" * 40)
|
||||
|
||||
# Send test data
|
||||
test_message = b"Hello WebSocket Transport!"
|
||||
print(f"📤 Sending message: {test_message.decode('utf-8')}")
|
||||
await stream.write(test_message)
|
||||
|
||||
# Read response
|
||||
print("⏳ Waiting for server response...")
|
||||
response = await stream.read(1024)
|
||||
print(f"📥 Received response: {response.decode('utf-8')}")
|
||||
|
||||
await stream.close()
|
||||
|
||||
print("─" * 40)
|
||||
if response == test_message:
|
||||
print("🎉 Echo test successful!")
|
||||
print("✅ WebSocket transport is working perfectly!")
|
||||
print("✅ Client completed successfully, exiting.")
|
||||
else:
|
||||
print("❌ Echo test failed!")
|
||||
print(" Response doesn't match sent data.")
|
||||
print(f" Sent: {test_message}")
|
||||
print(f" Received: {response}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"Echo protocol error: {error_msg}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Ensure stream is closed
|
||||
try:
|
||||
if stream:
|
||||
# Check if stream has is_closed method and use it
|
||||
has_is_closed = hasattr(stream, "is_closed") and callable(
|
||||
getattr(stream, "is_closed")
|
||||
)
|
||||
if has_is_closed:
|
||||
# type: ignore[attr-defined]
|
||||
if not await stream.is_closed():
|
||||
await stream.close()
|
||||
else:
|
||||
# Fallback: just try to close the stream
|
||||
await stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# host.run() context manager handles cleanup automatically
|
||||
print()
|
||||
print("🎉 WebSocket Demo Completed Successfully!")
|
||||
print("=" * 50)
|
||||
print("✅ WebSocket transport is working perfectly!")
|
||||
print("✅ Echo protocol communication successful!")
|
||||
print("✅ libp2p integration verified!")
|
||||
print()
|
||||
print("🚀 Your WebSocket transport is ready for production use!")
|
||||
|
||||
# Add a small delay to ensure all cleanup is complete
|
||||
await trio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error creating WebSocket client: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
def main() -> None:
|
||||
description = """
|
||||
This program demonstrates the libp2p WebSocket transport.
|
||||
First run
|
||||
'python websocket_demo.py -p <PORT> [--plaintext]' to start a WebSocket server.
|
||||
Then run
|
||||
'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--plaintext]'
|
||||
where <DESTINATION> is the multiaddress shown by the server.
|
||||
|
||||
By default, this example uses Noise encryption for secure communication.
|
||||
Use --plaintext for testing with unencrypted communication
|
||||
(not recommended for production).
|
||||
"""
|
||||
|
||||
example_maddr = (
|
||||
"/ip4/127.0.0.1/tcp/8888/ws/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--destination",
|
||||
type=str,
|
||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plaintext",
|
||||
action="store_true",
|
||||
help=(
|
||||
"use plaintext security instead of Noise encryption "
|
||||
"(not recommended for production)"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine security mode: use Noise by default,
|
||||
# plaintext if --plaintext is specified
|
||||
use_plaintext = args.plaintext
|
||||
|
||||
try:
|
||||
trio.run(run, args.port, args.destination, use_plaintext)
|
||||
except KeyboardInterrupt:
|
||||
# This is expected when Ctrl+C is pressed
|
||||
# The signal handler already printed the shutdown message
|
||||
print("✅ Clean exit completed.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {e}")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user