mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
Fix typecheck errors and improve WebSocket transport implementation
- Fix INotifee interface compliance in WebSocket demo - Fix handler function signatures to be async (THandler compatibility) - Fix is_closed method usage with proper type checking - Fix pytest.raises multiple exception type issue - Fix line length violations (E501) across multiple files - Add debugging logging to Noise security module for troubleshooting - Update WebSocket transport examples and tests - Improve transport registry error handling
This commit is contained in:
@ -11,13 +11,14 @@ This script demonstrates:
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
# Add the libp2p directory to the path so we can import it
|
# Add the libp2p directory to the path so we can import it
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
import multiaddr
|
import multiaddr
|
||||||
|
|
||||||
from libp2p.transport import (
|
from libp2p.transport import (
|
||||||
create_transport,
|
create_transport,
|
||||||
create_transport_for_multiaddr,
|
create_transport_for_multiaddr,
|
||||||
@ -25,9 +26,8 @@ from libp2p.transport import (
|
|||||||
get_transport_registry,
|
get_transport_registry,
|
||||||
register_transport,
|
register_transport,
|
||||||
)
|
)
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
|
||||||
from libp2p.transport.tcp.tcp import TCP
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@ -50,7 +50,8 @@ def demo_transport_registry():
|
|||||||
print("\nRegistered transports:")
|
print("\nRegistered transports:")
|
||||||
for protocol in supported:
|
for protocol in supported:
|
||||||
transport_class = registry.get_transport(protocol)
|
transport_class = registry.get_transport(protocol)
|
||||||
print(f" {protocol}: {transport_class.__name__}")
|
class_name = transport_class.__name__ if transport_class else "None"
|
||||||
|
print(f" {protocol}: {class_name}")
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
|
||||||
@ -114,15 +115,13 @@ def demo_custom_transport_registration():
|
|||||||
print("🔧 Custom Transport Registration Demo")
|
print("🔧 Custom Transport Registration Demo")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
# Create a dummy upgrader
|
|
||||||
upgrader = TransportUpgrader({}, {})
|
|
||||||
|
|
||||||
# Show current supported protocols
|
# Show current supported protocols
|
||||||
print(f"Before registration: {get_supported_transport_protocols()}")
|
print(f"Before registration: {get_supported_transport_protocols()}")
|
||||||
|
|
||||||
# Register a custom transport (using TCP as an example)
|
# Register a custom transport (using TCP as an example)
|
||||||
class CustomTCPTransport(TCP):
|
class CustomTCPTransport(TCP):
|
||||||
"""Custom TCP transport for demonstration."""
|
"""Custom TCP transport for demonstration."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.custom_flag = True
|
self.custom_flag = True
|
||||||
@ -137,7 +136,12 @@ def demo_custom_transport_registration():
|
|||||||
try:
|
try:
|
||||||
custom_transport = create_transport("custom_tcp")
|
custom_transport = create_transport("custom_tcp")
|
||||||
print(f"✅ Created custom transport: {type(custom_transport).__name__}")
|
print(f"✅ Created custom transport: {type(custom_transport).__name__}")
|
||||||
print(f" Custom flag: {custom_transport.custom_flag}")
|
# Check if it has the custom flag (type-safe way)
|
||||||
|
if hasattr(custom_transport, "custom_flag"):
|
||||||
|
flag_value = getattr(custom_transport, "custom_flag", "Not found")
|
||||||
|
print(f" Custom flag: {flag_value}")
|
||||||
|
else:
|
||||||
|
print(" Custom flag: Not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error creating custom transport: {e}")
|
print(f"❌ Error creating custom transport: {e}")
|
||||||
|
|
||||||
@ -202,4 +206,5 @@ if __name__ == "__main__":
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ Demo failed with error: {e}")
|
print(f"\n❌ Demo failed with error: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|||||||
@ -5,7 +5,6 @@ Simple TCP echo demo to verify basic libp2p functionality.
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import multiaddr
|
import multiaddr
|
||||||
@ -18,10 +17,10 @@ from libp2p.network.swarm import Swarm
|
|||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.peer.peerinfo import info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.peer.peerstore import PeerStore
|
from libp2p.peer.peerstore import PeerStore
|
||||||
from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
|
||||||
from libp2p.transport.tcp.tcp import TCP
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
|
||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@ -31,12 +30,13 @@ logger = logging.getLogger("libp2p.tcp-example")
|
|||||||
# Simple echo protocol
|
# Simple echo protocol
|
||||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
|
|
||||||
|
|
||||||
async def echo_handler(stream):
|
async def echo_handler(stream):
|
||||||
"""Simple echo handler that echoes back any data received."""
|
"""Simple echo handler that echoes back any data received."""
|
||||||
try:
|
try:
|
||||||
data = await stream.read(1024)
|
data = await stream.read(1024)
|
||||||
if data:
|
if data:
|
||||||
message = data.decode('utf-8', errors='replace')
|
message = data.decode("utf-8", errors="replace")
|
||||||
print(f"📥 Received: {message}")
|
print(f"📥 Received: {message}")
|
||||||
print(f"📤 Echoing back: {message}")
|
print(f"📤 Echoing back: {message}")
|
||||||
await stream.write(data)
|
await stream.write(data)
|
||||||
@ -45,6 +45,7 @@ async def echo_handler(stream):
|
|||||||
logger.error(f"Echo handler error: {e}")
|
logger.error(f"Echo handler error: {e}")
|
||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
||||||
|
|
||||||
def create_tcp_host():
|
def create_tcp_host():
|
||||||
"""Create a host with TCP transport."""
|
"""Create a host with TCP transport."""
|
||||||
# Create key pair and peer store
|
# Create key pair and peer store
|
||||||
@ -70,6 +71,7 @@ def create_tcp_host():
|
|||||||
|
|
||||||
return host
|
return host
|
||||||
|
|
||||||
|
|
||||||
async def run(port: int, destination: str) -> None:
|
async def run(port: int, destination: str) -> None:
|
||||||
localhost_ip = "0.0.0.0"
|
localhost_ip = "0.0.0.0"
|
||||||
|
|
||||||
@ -84,7 +86,10 @@ async def run(port: int, destination: str) -> None:
|
|||||||
# Set up echo handler
|
# Set up echo handler
|
||||||
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
@ -102,8 +107,8 @@ async def run(port: int, destination: str) -> None:
|
|||||||
print("🌐 TCP Server Started Successfully!")
|
print("🌐 TCP Server Started Successfully!")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
print(f"📍 Server Address: {client_addr}")
|
print(f"📍 Server Address: {client_addr}")
|
||||||
print(f"🔧 Protocol: /echo/1.0.0")
|
print("🔧 Protocol: /echo/1.0.0")
|
||||||
print(f"🚀 Transport: TCP")
|
print("🚀 Transport: TCP")
|
||||||
print()
|
print()
|
||||||
print("📋 To test the connection, run this in another terminal:")
|
print("📋 To test the connection, run this in another terminal:")
|
||||||
print(f" python test_tcp_echo.py -d {client_addr}")
|
print(f" python test_tcp_echo.py -d {client_addr}")
|
||||||
@ -127,7 +132,10 @@ async def run(port: int, destination: str) -> None:
|
|||||||
host = create_tcp_host()
|
host = create_tcp_host()
|
||||||
|
|
||||||
# Start the host for client operations
|
# Start the host for client operations
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
maddr = multiaddr.Multiaddr(destination)
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
@ -144,7 +152,7 @@ async def run(port: int, destination: str) -> None:
|
|||||||
print("✅ Successfully connected to TCP server!")
|
print("✅ Successfully connected to TCP server!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
print(f"\n❌ Connection Failed!")
|
print("\n❌ Connection Failed!")
|
||||||
print(f" Peer ID: {info.peer_id}")
|
print(f" Peer ID: {info.peer_id}")
|
||||||
print(f" Address: {destination}")
|
print(f" Address: {destination}")
|
||||||
print(f" Error: {error_msg}")
|
print(f" Error: {error_msg}")
|
||||||
@ -191,11 +199,14 @@ async def run(port: int, destination: str) -> None:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
description = "Simple TCP echo demo for libp2p"
|
description = "Simple TCP echo demo for libp2p"
|
||||||
parser = argparse.ArgumentParser(description=description)
|
parser = argparse.ArgumentParser(description=description)
|
||||||
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
parser.add_argument("-p", "--port", default=0, type=int, help="source port number")
|
||||||
parser.add_argument("-d", "--destination", type=str, help="destination multiaddr string")
|
parser.add_argument(
|
||||||
|
"-d", "--destination", type=str, help="destination multiaddr string"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -204,5 +215,6 @@ def main() -> None:
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@ -5,16 +5,16 @@ Simple test script to verify WebSocket transport functionality.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
# Add the libp2p directory to the path so we can import it
|
# Add the libp2p directory to the path so we can import it
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
import multiaddr
|
import multiaddr
|
||||||
|
|
||||||
from libp2p.transport import create_transport, create_transport_for_multiaddr
|
from libp2p.transport import create_transport, create_transport_for_multiaddr
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
from libp2p.network.connection.raw_connection import RawConnection
|
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@ -37,7 +37,10 @@ async def test_websocket_transport():
|
|||||||
# Test creating transport from multiaddr
|
# Test creating transport from multiaddr
|
||||||
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
ws_maddr = multiaddr.Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||||
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
|
ws_transport_from_maddr = create_transport_for_multiaddr(ws_maddr, upgrader)
|
||||||
print(f"✅ WebSocket transport from multiaddr: {type(ws_transport_from_maddr).__name__}")
|
print(
|
||||||
|
f"✅ WebSocket transport from multiaddr: "
|
||||||
|
f"{type(ws_transport_from_maddr).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
# Test creating listener
|
# Test creating listener
|
||||||
handler_called = False
|
handler_called = False
|
||||||
@ -52,8 +55,13 @@ async def test_websocket_transport():
|
|||||||
print(f"✅ WebSocket listener created: {type(listener).__name__}")
|
print(f"✅ WebSocket listener created: {type(listener).__name__}")
|
||||||
|
|
||||||
# Test that the transport can be used
|
# Test that the transport can be used
|
||||||
print(f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}")
|
print(
|
||||||
print(f"✅ WebSocket transport supports listening: {hasattr(ws_transport, 'create_listener')}")
|
f"✅ WebSocket transport supports dialing: {hasattr(ws_transport, 'dial')}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"✅ WebSocket transport supports listening: "
|
||||||
|
f"{hasattr(ws_transport, 'create_listener')}"
|
||||||
|
)
|
||||||
|
|
||||||
print("\n🎯 WebSocket Transport Test Results:")
|
print("\n🎯 WebSocket Transport Test Results:")
|
||||||
print("✅ Transport creation: PASS")
|
print("✅ Transport creation: PASS")
|
||||||
@ -64,6 +72,7 @@ async def test_websocket_transport():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ WebSocket transport test failed: {e}")
|
print(f"❌ WebSocket transport test failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -75,7 +84,10 @@ async def test_transport_registry():
|
|||||||
print("\n🔧 Testing Transport Registry")
|
print("\n🔧 Testing Transport Registry")
|
||||||
print("=" * 30)
|
print("=" * 30)
|
||||||
|
|
||||||
from libp2p.transport import get_transport_registry, get_supported_transport_protocols
|
from libp2p.transport import (
|
||||||
|
get_supported_transport_protocols,
|
||||||
|
get_transport_registry,
|
||||||
|
)
|
||||||
|
|
||||||
registry = get_transport_registry()
|
registry = get_transport_registry()
|
||||||
supported = get_supported_transport_protocols()
|
supported = get_supported_transport_protocols()
|
||||||
@ -85,7 +97,8 @@ async def test_transport_registry():
|
|||||||
# Test getting transports
|
# Test getting transports
|
||||||
for protocol in supported:
|
for protocol in supported:
|
||||||
transport_class = registry.get_transport(protocol)
|
transport_class = registry.get_transport(protocol)
|
||||||
print(f" {protocol}: {transport_class.__name__}")
|
class_name = transport_class.__name__ if transport_class else "None"
|
||||||
|
print(f" {protocol}: {class_name}")
|
||||||
|
|
||||||
# Test creating transports through registry
|
# Test creating transports through registry
|
||||||
upgrader = TransportUpgrader({}, {})
|
upgrader = TransportUpgrader({}, {})
|
||||||
@ -128,4 +141,5 @@ if __name__ == "__main__":
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n❌ Test failed with error: {e}")
|
print(f"\n❌ Test failed with error: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -1,21 +1,26 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import multiaddr
|
import multiaddr
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from libp2p.abc import INotifee
|
||||||
|
from libp2p.crypto.ed25519 import create_new_key_pair as create_ed25519_key_pair
|
||||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
from libp2p.custom_types import TProtocol
|
from libp2p.custom_types import TProtocol
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
from libp2p.network.swarm import Swarm
|
from libp2p.network.swarm import Swarm
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.peer.peerinfo import PeerInfo, info_from_p2p_addr
|
from libp2p.peer.peerinfo import info_from_p2p_addr
|
||||||
from libp2p.peer.peerstore import PeerStore
|
from libp2p.peer.peerstore import PeerStore
|
||||||
from libp2p.security.insecure.transport import InsecureTransport, PLAINTEXT_PROTOCOL_ID
|
from libp2p.security.insecure.transport import PLAINTEXT_PROTOCOL_ID, InsecureTransport
|
||||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
from libp2p.security.noise.transport import (
|
||||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
PROTOCOL_ID as NOISE_PROTOCOL_ID,
|
||||||
|
Transport as NoiseTransport,
|
||||||
|
)
|
||||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
@ -25,6 +30,15 @@ logging.basicConfig(level=logging.DEBUG)
|
|||||||
|
|
||||||
logger = logging.getLogger("libp2p.websocket-example")
|
logger = logging.getLogger("libp2p.websocket-example")
|
||||||
|
|
||||||
|
|
||||||
|
# Suppress KeyboardInterrupt by handling SIGINT directly
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
print("✅ Clean exit completed.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
# Simple echo protocol
|
# Simple echo protocol
|
||||||
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
ECHO_PROTOCOL_ID = TProtocol("/echo/1.0.0")
|
||||||
|
|
||||||
@ -34,7 +48,7 @@ async def echo_handler(stream):
|
|||||||
try:
|
try:
|
||||||
data = await stream.read(1024)
|
data = await stream.read(1024)
|
||||||
if data:
|
if data:
|
||||||
message = data.decode('utf-8', errors='replace')
|
message = data.decode("utf-8", errors="replace")
|
||||||
print(f"📥 Received: {message}")
|
print(f"📥 Received: {message}")
|
||||||
print(f"📤 Echoing back: {message}")
|
print(f"📤 Echoing back: {message}")
|
||||||
await stream.write(data)
|
await stream.write(data)
|
||||||
@ -44,7 +58,7 @@ async def echo_handler(stream):
|
|||||||
await stream.close()
|
await stream.close()
|
||||||
|
|
||||||
|
|
||||||
def create_websocket_host(listen_addrs=None, use_noise=False):
|
def create_websocket_host(listen_addrs=None, use_plaintext=False):
|
||||||
"""Create a host with WebSocket transport."""
|
"""Create a host with WebSocket transport."""
|
||||||
# Create key pair and peer store
|
# Create key pair and peer store
|
||||||
key_pair = create_new_key_pair()
|
key_pair = create_new_key_pair()
|
||||||
@ -52,11 +66,22 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
|
|||||||
peer_store = PeerStore()
|
peer_store = PeerStore()
|
||||||
peer_store.add_key_pair(peer_id, key_pair)
|
peer_store.add_key_pair(peer_id, key_pair)
|
||||||
|
|
||||||
if use_noise:
|
if use_plaintext:
|
||||||
|
# Create transport upgrader with plaintext security
|
||||||
|
upgrader = TransportUpgrader(
|
||||||
|
secure_transports_by_protocol={
|
||||||
|
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
||||||
|
},
|
||||||
|
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create separate Ed25519 key for Noise protocol
|
||||||
|
noise_key_pair = create_ed25519_key_pair()
|
||||||
|
|
||||||
# Create Noise transport
|
# Create Noise transport
|
||||||
noise_transport = NoiseTransport(
|
noise_transport = NoiseTransport(
|
||||||
libp2p_keypair=key_pair,
|
libp2p_keypair=key_pair,
|
||||||
noise_privkey=key_pair.private_key,
|
noise_privkey=noise_key_pair.private_key,
|
||||||
early_data=None,
|
early_data=None,
|
||||||
with_noise_pipes=False,
|
with_noise_pipes=False,
|
||||||
)
|
)
|
||||||
@ -68,14 +93,6 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
|
|||||||
},
|
},
|
||||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Create transport upgrader with plaintext security
|
|
||||||
upgrader = TransportUpgrader(
|
|
||||||
secure_transports_by_protocol={
|
|
||||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair)
|
|
||||||
},
|
|
||||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create WebSocket transport
|
# Create WebSocket transport
|
||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
@ -87,7 +104,7 @@ def create_websocket_host(listen_addrs=None, use_noise=False):
|
|||||||
return host
|
return host
|
||||||
|
|
||||||
|
|
||||||
async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
async def run(port: int, destination: str, use_plaintext: bool = False) -> None:
|
||||||
localhost_ip = "0.0.0.0"
|
localhost_ip = "0.0.0.0"
|
||||||
|
|
||||||
if not destination:
|
if not destination:
|
||||||
@ -95,16 +112,66 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
listen_addr = multiaddr.Multiaddr(f"/ip4/{localhost_ip}/tcp/{port}/ws")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
host = create_websocket_host(use_noise=use_noise)
|
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||||
logger.debug(f"Created host with use_noise={use_noise}")
|
logger.debug(f"Created host with use_plaintext={use_plaintext}")
|
||||||
|
|
||||||
# Set up echo handler
|
# Set up echo handler
|
||||||
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
host.set_stream_handler(ECHO_PROTOCOL_ID, echo_handler)
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
# Add connection event handlers for debugging
|
||||||
|
class DebugNotifee(INotifee):
|
||||||
|
async def opened_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def closed_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def connected(self, network, conn):
|
||||||
|
print(
|
||||||
|
f"🔗 New libp2p connection established: "
|
||||||
|
f"{conn.muxed_conn.peer_id}"
|
||||||
|
)
|
||||||
|
if hasattr(conn.muxed_conn, "get_security_protocol"):
|
||||||
|
security = conn.muxed_conn.get_security_protocol()
|
||||||
|
else:
|
||||||
|
security = "Unknown"
|
||||||
|
|
||||||
|
print(f" Security: {security}")
|
||||||
|
|
||||||
|
async def disconnected(self, network, conn):
|
||||||
|
print(f"🔌 libp2p connection closed: {conn.muxed_conn.peer_id}")
|
||||||
|
|
||||||
|
async def listen(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def listen_close(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
host.get_network().register_notifee(DebugNotifee())
|
||||||
|
|
||||||
|
# Create a cancellation token for clean shutdown
|
||||||
|
cancel_scope = trio.CancelScope()
|
||||||
|
|
||||||
|
async def signal_handler():
|
||||||
|
with trio.open_signal_receiver(signal.SIGINT, signal.SIGTERM) as (
|
||||||
|
signal_receiver
|
||||||
|
):
|
||||||
|
async for sig in signal_receiver:
|
||||||
|
print(f"\n🛑 Received signal {sig}")
|
||||||
|
print("✅ Shutting down WebSocket server...")
|
||||||
|
cancel_scope.cancel()
|
||||||
|
return
|
||||||
|
|
||||||
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
|
# Start the signal handler
|
||||||
|
nursery.start_soon(signal_handler)
|
||||||
|
|
||||||
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
# Get the actual address and replace 0.0.0.0 with 127.0.0.1 for client
|
||||||
# connections
|
# connections
|
||||||
addrs = host.get_addrs()
|
addrs = host.get_addrs()
|
||||||
@ -120,11 +187,12 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
print("🌐 WebSocket Server Started Successfully!")
|
print("🌐 WebSocket Server Started Successfully!")
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
print(f"📍 Server Address: {client_addr}")
|
print(f"📍 Server Address: {client_addr}")
|
||||||
print(f"🔧 Protocol: /echo/1.0.0")
|
print("🔧 Protocol: /echo/1.0.0")
|
||||||
print(f"🚀 Transport: WebSocket (/ws)")
|
print("🚀 Transport: WebSocket (/ws)")
|
||||||
print()
|
print()
|
||||||
print("📋 To test the connection, run this in another terminal:")
|
print("📋 To test the connection, run this in another terminal:")
|
||||||
print(f" python websocket_demo.py -d {client_addr}")
|
plaintext_flag = " --plaintext" if use_plaintext else ""
|
||||||
|
print(f" python websocket_demo.py -d {client_addr}{plaintext_flag}")
|
||||||
print()
|
print()
|
||||||
print("⏳ Waiting for incoming WebSocket connections...")
|
print("⏳ Waiting for incoming WebSocket connections...")
|
||||||
print("─" * 50)
|
print("─" * 50)
|
||||||
@ -132,9 +200,9 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
# Add a custom handler to show connection events
|
# Add a custom handler to show connection events
|
||||||
async def custom_echo_handler(stream):
|
async def custom_echo_handler(stream):
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
print(f"\n🔗 New WebSocket Connection!")
|
print("\n🔗 New WebSocket Connection!")
|
||||||
print(f" Peer ID: {peer_id}")
|
print(f" Peer ID: {peer_id}")
|
||||||
print(f" Protocol: /echo/1.0.0")
|
print(" Protocol: /echo/1.0.0")
|
||||||
|
|
||||||
# Show remote address in multiaddr format
|
# Show remote address in multiaddr format
|
||||||
try:
|
try:
|
||||||
@ -142,21 +210,23 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
if remote_address:
|
if remote_address:
|
||||||
print(f" Remote: {remote_address}")
|
print(f" Remote: {remote_address}")
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f" Remote: Unknown")
|
print(" Remote: Unknown")
|
||||||
|
|
||||||
print(f" ─" * 40)
|
print(" ─" * 40)
|
||||||
|
|
||||||
# Call the original handler
|
# Call the original handler
|
||||||
await echo_handler(stream)
|
await echo_handler(stream)
|
||||||
|
|
||||||
print(f" ─" * 40)
|
print(" ─" * 40)
|
||||||
print(f"✅ Echo request completed for peer: {peer_id}")
|
print(f"✅ Echo request completed for peer: {peer_id}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# Replace the handler with our custom one
|
# Replace the handler with our custom one
|
||||||
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
|
host.set_stream_handler(ECHO_PROTOCOL_ID, custom_echo_handler)
|
||||||
|
|
||||||
await trio.sleep_forever()
|
# Wait indefinitely or until cancelled
|
||||||
|
with cancel_scope:
|
||||||
|
await trio.sleep_forever()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error creating WebSocket server: {e}")
|
print(f"❌ Error creating WebSocket server: {e}")
|
||||||
@ -169,12 +239,44 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Create a single host for client operations
|
# Create a single host for client operations
|
||||||
host = create_websocket_host(use_noise=use_noise)
|
host = create_websocket_host(use_plaintext=use_plaintext)
|
||||||
|
|
||||||
# Start the host for client operations
|
# Start the host for client operations
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with (
|
||||||
|
host.run(listen_addrs=[listen_addr]),
|
||||||
|
trio.open_nursery() as (nursery),
|
||||||
|
):
|
||||||
# Start the peer-store cleanup task
|
# Start the peer-store cleanup task
|
||||||
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
nursery.start_soon(host.get_peerstore().start_cleanup_task, 60)
|
||||||
|
|
||||||
|
# Add connection event handlers for debugging
|
||||||
|
class ClientDebugNotifee(INotifee):
|
||||||
|
async def opened_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def closed_stream(self, network, stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def connected(self, network, conn):
|
||||||
|
print(
|
||||||
|
f"🔗 Client: libp2p connection established: "
|
||||||
|
f"{conn.muxed_conn.peer_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def disconnected(self, network, conn):
|
||||||
|
print(
|
||||||
|
f"🔌 Client: libp2p connection closed: "
|
||||||
|
f"{conn.muxed_conn.peer_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def listen(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def listen_close(self, network, multiaddr):
|
||||||
|
pass
|
||||||
|
|
||||||
|
host.get_network().register_notifee(ClientDebugNotifee())
|
||||||
|
|
||||||
maddr = multiaddr.Multiaddr(destination)
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
info = info_from_p2p_addr(maddr)
|
info = info_from_p2p_addr(maddr)
|
||||||
print("🔌 WebSocket Client Starting...")
|
print("🔌 WebSocket Client Starting...")
|
||||||
@ -185,21 +287,34 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
print("🔗 Connecting to WebSocket server...")
|
print("🔗 Connecting to WebSocket server...")
|
||||||
|
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||||
await host.connect(info)
|
await host.connect(info)
|
||||||
print("✅ Successfully connected to WebSocket server!")
|
print("✅ Successfully connected to WebSocket server!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
if "unable to connect" in error_msg or "SwarmException" in error_msg:
|
print("\n❌ Connection Failed!")
|
||||||
print(f"\n❌ Connection Failed!")
|
print(f" Peer ID: {info.peer_id}")
|
||||||
print(f" Peer ID: {info.peer_id}")
|
print(f" Address: {destination}")
|
||||||
print(f" Address: {destination}")
|
print(f" Security: {'Plaintext' if use_plaintext else 'Noise'}")
|
||||||
print(f" Error: {error_msg}")
|
print(f" Error: {error_msg}")
|
||||||
print()
|
print(f" Error type: {type(e).__name__}")
|
||||||
print("💡 Troubleshooting:")
|
|
||||||
print(" • Make sure the WebSocket server is running")
|
# Add more detailed error information for debugging
|
||||||
print(" • Check that the server address is correct")
|
if hasattr(e, "__cause__") and e.__cause__:
|
||||||
print(" • Verify the server is listening on the right port")
|
print(f" Root cause: {e.__cause__}")
|
||||||
return
|
print(f" Root cause type: {type(e.__cause__).__name__}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("💡 Troubleshooting:")
|
||||||
|
print(" • Make sure the WebSocket server is running")
|
||||||
|
print(" • Check that the server address is correct")
|
||||||
|
print(" • Verify the server is listening on the right port")
|
||||||
|
print(
|
||||||
|
" • Ensure both client and server use the same sec protocol"
|
||||||
|
)
|
||||||
|
if not use_plaintext:
|
||||||
|
print(" • Noise over WebSocket may have compatibility issues")
|
||||||
|
return
|
||||||
|
|
||||||
# Create a stream and send test data
|
# Create a stream and send test data
|
||||||
try:
|
try:
|
||||||
@ -242,8 +357,18 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
finally:
|
finally:
|
||||||
# Ensure stream is closed
|
# Ensure stream is closed
|
||||||
try:
|
try:
|
||||||
if stream and not await stream.is_closed():
|
if stream:
|
||||||
await stream.close()
|
# Check if stream has is_closed method and use it
|
||||||
|
has_is_closed = hasattr(stream, "is_closed") and callable(
|
||||||
|
getattr(stream, "is_closed")
|
||||||
|
)
|
||||||
|
if has_is_closed:
|
||||||
|
# type: ignore[attr-defined]
|
||||||
|
if not await stream.is_closed():
|
||||||
|
await stream.close()
|
||||||
|
else:
|
||||||
|
# Fallback: just try to close the stream
|
||||||
|
await stream.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -257,6 +382,9 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
print()
|
print()
|
||||||
print("🚀 Your WebSocket transport is ready for production use!")
|
print("🚀 Your WebSocket transport is ready for production use!")
|
||||||
|
|
||||||
|
# Add a small delay to ensure all cleanup is complete
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error creating WebSocket client: {e}")
|
print(f"❌ Error creating WebSocket client: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -266,12 +394,15 @@ async def run(port: int, destination: str, use_noise: bool = False) -> None:
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
description = """
|
description = """
|
||||||
This program demonstrates the libp2p WebSocket transport.
|
This program demonstrates the libp2p WebSocket transport.
|
||||||
First run 'python websocket_demo.py -p <PORT> [--noise]' to start a WebSocket server.
|
First run
|
||||||
Then run 'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--noise]'
|
'python websocket_demo.py -p <PORT> [--plaintext]' to start a WebSocket server.
|
||||||
|
Then run
|
||||||
|
'python websocket_demo.py <ANOTHER_PORT> -d <DESTINATION> [--plaintext]'
|
||||||
where <DESTINATION> is the multiaddress shown by the server.
|
where <DESTINATION> is the multiaddress shown by the server.
|
||||||
|
|
||||||
By default, this example uses plaintext security for communication.
|
By default, this example uses Noise encryption for secure communication.
|
||||||
Use --noise for testing with Noise encryption (experimental).
|
Use --plaintext for testing with unencrypted communication
|
||||||
|
(not recommended for production).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
example_maddr = (
|
example_maddr = (
|
||||||
@ -287,20 +418,30 @@ def main() -> None:
|
|||||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
help=f"destination multiaddr string, e.g. {example_maddr}",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--noise",
|
"--plaintext",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="use Noise encryption instead of plaintext security",
|
help=(
|
||||||
|
"use plaintext security instead of Noise encryption "
|
||||||
|
"(not recommended for production)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Determine security mode: use plaintext by default, Noise if --noise is specified
|
# Determine security mode: use Noise by default,
|
||||||
use_noise = args.noise
|
# plaintext if --plaintext is specified
|
||||||
|
use_plaintext = args.plaintext
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trio.run(run, args.port, args.destination, use_noise)
|
trio.run(run, args.port, args.destination, use_plaintext)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
# This is expected when Ctrl+C is pressed
|
||||||
|
# The signal handler already printed the shutdown message
|
||||||
|
print("✅ Clean exit completed.")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Unexpected error: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from libp2p.abc import (
|
|||||||
IPeerRouting,
|
IPeerRouting,
|
||||||
IPeerStore,
|
IPeerStore,
|
||||||
ISecureTransport,
|
ISecureTransport,
|
||||||
|
ITransport,
|
||||||
)
|
)
|
||||||
from libp2p.crypto.keys import (
|
from libp2p.crypto.keys import (
|
||||||
KeyPair,
|
KeyPair,
|
||||||
@ -231,14 +232,15 @@ def new_swarm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create transport based on listen_addrs or default to TCP
|
# Create transport based on listen_addrs or default to TCP
|
||||||
|
transport: ITransport
|
||||||
if listen_addrs is None:
|
if listen_addrs is None:
|
||||||
transport = TCP()
|
transport = TCP()
|
||||||
else:
|
else:
|
||||||
# Use the first address to determine transport type
|
# Use the first address to determine transport type
|
||||||
addr = listen_addrs[0]
|
addr = listen_addrs[0]
|
||||||
transport = create_transport_for_multiaddr(addr, upgrader)
|
transport_maybe = create_transport_for_multiaddr(addr, upgrader)
|
||||||
|
|
||||||
if transport is None:
|
if transport_maybe is None:
|
||||||
# Fallback to TCP if no specific transport found
|
# Fallback to TCP if no specific transport found
|
||||||
if addr.__contains__("tcp"):
|
if addr.__contains__("tcp"):
|
||||||
transport = TCP()
|
transport = TCP()
|
||||||
@ -250,20 +252,8 @@ def new_swarm(
|
|||||||
f"Unknown transport in listen_addrs: {listen_addrs}. "
|
f"Unknown transport in listen_addrs: {listen_addrs}. "
|
||||||
f"Supported protocols: {supported_protocols}"
|
f"Supported protocols: {supported_protocols}"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
# Generate X25519 keypair for Noise
|
transport = transport_maybe
|
||||||
noise_key_pair = create_new_x25519_key_pair()
|
|
||||||
|
|
||||||
# Default security transports (using Noise as primary)
|
|
||||||
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or {
|
|
||||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
|
||||||
key_pair, noise_privkey=noise_key_pair.private_key
|
|
||||||
),
|
|
||||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
|
||||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
|
|
||||||
key_pair, peerstore=peerstore_opt
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Use given muxer preference if provided, otherwise use global default
|
# Use given muxer preference if provided, otherwise use global default
|
||||||
if muxer_preference is not None:
|
if muxer_preference is not None:
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
|
|||||||
FixedSizeLenMsgReadWriter,
|
FixedSizeLenMsgReadWriter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SIZE_NOISE_MESSAGE_LEN = 2
|
SIZE_NOISE_MESSAGE_LEN = 2
|
||||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
||||||
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
|||||||
self.noise_state = noise_state
|
self.noise_state = noise_state
|
||||||
|
|
||||||
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
||||||
|
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
|
||||||
data_encrypted = self.encrypt(msg)
|
data_encrypted = self.encrypt(msg)
|
||||||
if prefix_encoded:
|
if prefix_encoded:
|
||||||
# Manually add the prefix if needed
|
# Manually add the prefix if needed
|
||||||
data_encrypted = self.prefix + data_encrypted
|
data_encrypted = self.prefix + data_encrypted
|
||||||
|
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
|
||||||
await self.read_writer.write_msg(data_encrypted)
|
await self.read_writer.write_msg(data_encrypted)
|
||||||
|
logger.debug("Noise write_msg: write completed successfully")
|
||||||
|
|
||||||
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
||||||
|
logger.debug("Noise read_msg: reading encrypted message")
|
||||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||||
|
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
|
||||||
if prefix_encoded:
|
if prefix_encoded:
|
||||||
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||||
else:
|
else:
|
||||||
return self.decrypt(noise_msg_encrypted)
|
result = self.decrypt(noise_msg_encrypted)
|
||||||
|
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
|
||||||
|
return result
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self.read_writer.close()
|
await self.read_writer.close()
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
dataclass,
|
dataclass,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
|
||||||
from libp2p.crypto.keys import (
|
from libp2p.crypto.keys import (
|
||||||
PrivateKey,
|
PrivateKey,
|
||||||
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
|
|||||||
|
|
||||||
from .pb import noise_pb2 as noise_pb
|
from .pb import noise_pb2 as noise_pb
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||||
|
|
||||||
|
|
||||||
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
|
|||||||
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
data = make_data_to_be_signed(noise_static_pubkey)
|
data = make_data_to_be_signed(noise_static_pubkey)
|
||||||
|
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
|
||||||
|
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
|
||||||
return id_privkey.sign(data)
|
return id_privkey.sign(data)
|
||||||
|
|
||||||
|
|
||||||
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
|
|||||||
2. signed by the private key corresponding to `id_pubkey`
|
2. signed by the private key corresponding to `id_pubkey`
|
||||||
"""
|
"""
|
||||||
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||||
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: payload.id_pubkey type: "
|
||||||
|
f"{type(payload.id_pubkey)}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: noise_static_pubkey type: "
|
||||||
|
f"{type(noise_static_pubkey)}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||||
|
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from abc import (
|
|||||||
ABC,
|
ABC,
|
||||||
abstractmethod,
|
abstractmethod,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import (
|
from cryptography.hazmat.primitives import (
|
||||||
serialization,
|
serialization,
|
||||||
@ -46,6 +47,8 @@ from .messages import (
|
|||||||
verify_handshake_payload_sig,
|
verify_handshake_payload_sig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class IPattern(ABC):
|
class IPattern(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
|
|||||||
self.early_data = early_data
|
self.early_data = early_data
|
||||||
|
|
||||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||||
|
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
|
||||||
noise_state = self.create_noise_state()
|
noise_state = self.create_noise_state()
|
||||||
noise_state.set_as_responder()
|
noise_state.set_as_responder()
|
||||||
noise_state.start_handshake()
|
noise_state.start_handshake()
|
||||||
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
|
|||||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||||
|
|
||||||
# Consume msg#1.
|
# Consume msg#1.
|
||||||
|
logger.debug("Noise XX handshake_inbound: reading msg#1")
|
||||||
await read_writer.read_msg()
|
await read_writer.read_msg()
|
||||||
|
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
|
||||||
|
|
||||||
# Send msg#2, which should include our handshake payload.
|
# Send msg#2, which should include our handshake payload.
|
||||||
|
logger.debug("Noise XX handshake_inbound: preparing msg#2")
|
||||||
our_payload = self.make_handshake_payload()
|
our_payload = self.make_handshake_payload()
|
||||||
msg_2 = our_payload.serialize()
|
msg_2 = our_payload.serialize()
|
||||||
|
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
|
||||||
await read_writer.write_msg(msg_2)
|
await read_writer.write_msg(msg_2)
|
||||||
|
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
|
||||||
|
|
||||||
# Receive and consume msg#3.
|
# Receive and consume msg#3.
|
||||||
|
logger.debug("Noise XX handshake_inbound: reading msg#3")
|
||||||
msg_3 = await read_writer.read_msg()
|
msg_3 = await read_writer.read_msg()
|
||||||
|
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
|
||||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||||
|
|
||||||
if handshake_state.rs is None:
|
if handshake_state.rs is None:
|
||||||
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
|
|||||||
async def handshake_outbound(
|
async def handshake_outbound(
|
||||||
self, conn: IRawConnection, remote_peer: ID
|
self, conn: IRawConnection, remote_peer: ID
|
||||||
) -> ISecureConn:
|
) -> ISecureConn:
|
||||||
|
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
|
||||||
noise_state = self.create_noise_state()
|
noise_state = self.create_noise_state()
|
||||||
|
|
||||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||||
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
|
|||||||
raise NoiseStateError("Handshake state is not initialized")
|
raise NoiseStateError("Handshake state is not initialized")
|
||||||
|
|
||||||
# Send msg#1, which is *not* encrypted.
|
# Send msg#1, which is *not* encrypted.
|
||||||
|
logger.debug("Noise XX handshake_outbound: sending msg#1")
|
||||||
msg_1 = b""
|
msg_1 = b""
|
||||||
await read_writer.write_msg(msg_1)
|
await read_writer.write_msg(msg_1)
|
||||||
|
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
|
||||||
|
|
||||||
# Read msg#2 from the remote, which contains the public key of the peer.
|
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||||
|
logger.debug("Noise XX handshake_outbound: reading msg#2")
|
||||||
msg_2 = await read_writer.read_msg()
|
msg_2 = await read_writer.read_msg()
|
||||||
|
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
|
||||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||||
|
|
||||||
if handshake_state.rs is None:
|
if handshake_state.rs is None:
|
||||||
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
|
|||||||
)
|
)
|
||||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
|
||||||
|
)
|
||||||
|
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
|
||||||
|
f"{id_pubkey_repr}"
|
||||||
|
)
|
||||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||||
|
logger.error(
|
||||||
|
f"Noise XX handshake_outbound: signature verification failed for peer "
|
||||||
|
f"{remote_peer}"
|
||||||
|
)
|
||||||
raise InvalidSignature
|
raise InvalidSignature
|
||||||
|
logger.debug(
|
||||||
|
f"Noise XX handshake_outbound: signature verification successful for peer "
|
||||||
|
f"{remote_peer}"
|
||||||
|
)
|
||||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||||
if remote_peer_id_from_pubkey != remote_peer:
|
if remote_peer_id_from_pubkey != remote_peer:
|
||||||
raise PeerIDMismatchesPubkey(
|
raise PeerIDMismatchesPubkey(
|
||||||
|
|||||||
@ -7,8 +7,10 @@ from .transport_registry import (
|
|||||||
register_transport,
|
register_transport,
|
||||||
get_supported_transport_protocols,
|
get_supported_transport_protocols,
|
||||||
)
|
)
|
||||||
|
from .upgrader import TransportUpgrader
|
||||||
|
from libp2p.abc import ITransport
|
||||||
|
|
||||||
def create_transport(protocol: str, upgrader=None):
|
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None) -> ITransport:
|
||||||
"""
|
"""
|
||||||
Convenience function to create a transport instance.
|
Convenience function to create a transport instance.
|
||||||
|
|
||||||
@ -28,7 +30,10 @@ def create_transport(protocol: str, upgrader=None):
|
|||||||
registry = get_transport_registry()
|
registry = get_transport_registry()
|
||||||
transport_class = registry.get_transport(protocol)
|
transport_class = registry.get_transport(protocol)
|
||||||
if transport_class:
|
if transport_class:
|
||||||
return registry.create_transport(protocol, upgrader)
|
transport = registry.create_transport(protocol, upgrader)
|
||||||
|
if transport is None:
|
||||||
|
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||||
|
return transport
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||||
|
|
||||||
|
|||||||
@ -3,13 +3,15 @@ Transport registry for dynamic transport selection based on multiaddr protocols.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Type, Optional
|
from typing import Any
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
|
from multiaddr.protocols import Protocol
|
||||||
|
|
||||||
from libp2p.abc import ITransport
|
from libp2p.abc import ITransport
|
||||||
from libp2p.transport.tcp.tcp import TCP
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.transport.registry")
|
logger = logging.getLogger("libp2p.transport.registry")
|
||||||
|
|
||||||
@ -24,7 +26,7 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
|||||||
try:
|
try:
|
||||||
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||||
# or /ip6/::1/tcp/8080
|
# or /ip6/::1/tcp/8080
|
||||||
protocols = maddr.protocols()
|
protocols: list[Protocol] = list(maddr.protocols())
|
||||||
|
|
||||||
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||||
if len(protocols) < 2:
|
if len(protocols) < 2:
|
||||||
@ -38,7 +40,8 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
|||||||
if protocols[1].name != "tcp":
|
if protocols[1].name != "tcp":
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Should not have any protocols after tcp (unless it's a valid continuation like p2p)
|
# Should not have any protocols after tcp (unless it's a valid
|
||||||
|
# continuation like p2p)
|
||||||
# For now, we'll be strict and only allow network + tcp
|
# For now, we'll be strict and only allow network + tcp
|
||||||
if len(protocols) > 2:
|
if len(protocols) > 2:
|
||||||
# Check if the additional protocols are valid continuations
|
# Check if the additional protocols are valid continuations
|
||||||
@ -63,7 +66,7 @@ def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
|||||||
try:
|
try:
|
||||||
# WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws
|
# WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws
|
||||||
# or /ip6/::1/tcp/8080/ws
|
# or /ip6/::1/tcp/8080/ws
|
||||||
protocols = maddr.protocols()
|
protocols: list[Protocol] = list(maddr.protocols())
|
||||||
|
|
||||||
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws
|
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws
|
||||||
if len(protocols) < 3:
|
if len(protocols) < 3:
|
||||||
@ -100,8 +103,8 @@ class TransportRegistry:
|
|||||||
Registry for mapping multiaddr protocols to transport implementations.
|
Registry for mapping multiaddr protocols to transport implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._transports: Dict[str, Type[ITransport]] = {}
|
self._transports: dict[str, type[ITransport]] = {}
|
||||||
self._register_default_transports()
|
self._register_default_transports()
|
||||||
|
|
||||||
def _register_default_transports(self) -> None:
|
def _register_default_transports(self) -> None:
|
||||||
@ -112,7 +115,9 @@ class TransportRegistry:
|
|||||||
# Register WebSocket transport for /ws protocol
|
# Register WebSocket transport for /ws protocol
|
||||||
self.register_transport("ws", WebsocketTransport)
|
self.register_transport("ws", WebsocketTransport)
|
||||||
|
|
||||||
def register_transport(self, protocol: str, transport_class: Type[ITransport]) -> None:
|
def register_transport(
|
||||||
|
self, protocol: str, transport_class: type[ITransport]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Register a transport class for a specific protocol.
|
Register a transport class for a specific protocol.
|
||||||
|
|
||||||
@ -120,9 +125,11 @@ class TransportRegistry:
|
|||||||
:param transport_class: The transport class to register
|
:param transport_class: The transport class to register
|
||||||
"""
|
"""
|
||||||
self._transports[protocol] = transport_class
|
self._transports[protocol] = transport_class
|
||||||
logger.debug(f"Registered transport {transport_class.__name__} for protocol {protocol}")
|
logger.debug(
|
||||||
|
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||||
|
)
|
||||||
|
|
||||||
def get_transport(self, protocol: str) -> Optional[Type[ITransport]]:
|
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||||
"""
|
"""
|
||||||
Get the transport class for a specific protocol.
|
Get the transport class for a specific protocol.
|
||||||
|
|
||||||
@ -135,7 +142,9 @@ class TransportRegistry:
|
|||||||
"""Get list of supported transport protocols."""
|
"""Get list of supported transport protocols."""
|
||||||
return list(self._transports.keys())
|
return list(self._transports.keys())
|
||||||
|
|
||||||
def create_transport(self, protocol: str, upgrader: Optional[TransportUpgrader] = None, **kwargs) -> Optional[ITransport]:
|
def create_transport(
|
||||||
|
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||||
|
) -> ITransport | None:
|
||||||
"""
|
"""
|
||||||
Create a transport instance for a specific protocol.
|
Create a transport instance for a specific protocol.
|
||||||
|
|
||||||
@ -152,9 +161,12 @@ class TransportRegistry:
|
|||||||
if protocol == "ws":
|
if protocol == "ws":
|
||||||
# WebSocket transport requires upgrader
|
# WebSocket transport requires upgrader
|
||||||
if upgrader is None:
|
if upgrader is None:
|
||||||
logger.warning(f"WebSocket transport '{protocol}' requires upgrader")
|
logger.warning(
|
||||||
|
f"WebSocket transport '{protocol}' requires upgrader"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
return transport_class(upgrader)
|
# Use explicit WebsocketTransport to avoid type issues
|
||||||
|
return WebsocketTransport(upgrader)
|
||||||
else:
|
else:
|
||||||
# TCP transport doesn't require upgrader
|
# TCP transport doesn't require upgrader
|
||||||
return transport_class()
|
return transport_class()
|
||||||
@ -172,12 +184,14 @@ def get_transport_registry() -> TransportRegistry:
|
|||||||
return _global_registry
|
return _global_registry
|
||||||
|
|
||||||
|
|
||||||
def register_transport(protocol: str, transport_class: Type[ITransport]) -> None:
|
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||||
"""Register a transport class in the global registry."""
|
"""Register a transport class in the global registry."""
|
||||||
_global_registry.register_transport(protocol, transport_class)
|
_global_registry.register_transport(protocol, transport_class)
|
||||||
|
|
||||||
|
|
||||||
def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader) -> Optional[ITransport]:
|
def create_transport_for_multiaddr(
|
||||||
|
maddr: Multiaddr, upgrader: TransportUpgrader
|
||||||
|
) -> ITransport | None:
|
||||||
"""
|
"""
|
||||||
Create the appropriate transport for a given multiaddr.
|
Create the appropriate transport for a given multiaddr.
|
||||||
|
|
||||||
@ -203,7 +217,10 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
|
|||||||
return _global_registry.create_transport("tcp", upgrader)
|
return _global_registry.create_transport("tcp", upgrader)
|
||||||
|
|
||||||
# If no supported transport protocol found or structure is invalid, return None
|
# If no supported transport protocol found or structure is invalid, return None
|
||||||
logger.warning(f"No supported transport protocol found or invalid structure in multiaddr: {maddr}")
|
logger.warning(
|
||||||
|
f"No supported transport protocol found or invalid structure in "
|
||||||
|
f"multiaddr: {maddr}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -1,9 +1,13 @@
|
|||||||
from trio.abc import Stream
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p.io.abc import ReadWriteCloser
|
from libp2p.io.abc import ReadWriteCloser
|
||||||
from libp2p.io.exceptions import IOException
|
from libp2p.io.exceptions import IOException
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class P2PWebSocketConnection(ReadWriteCloser):
|
class P2PWebSocketConnection(ReadWriteCloser):
|
||||||
"""
|
"""
|
||||||
@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
|||||||
that libp2p protocols expect.
|
that libp2p protocols expect.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ws_connection, ws_context=None):
|
def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
|
||||||
self._ws_connection = ws_connection
|
self._ws_connection = ws_connection
|
||||||
self._ws_context = ws_context
|
self._ws_context = ws_context
|
||||||
self._read_buffer = b""
|
self._read_buffer = b""
|
||||||
@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
|||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
try:
|
try:
|
||||||
|
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||||
# Send as a binary WebSocket message
|
# Send as a binary WebSocket message
|
||||||
await self._ws_connection.send_message(data)
|
await self._ws_connection.send_message(data)
|
||||||
|
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket write failed: {e}")
|
||||||
raise IOException from e
|
raise IOException from e
|
||||||
|
|
||||||
async def read(self, n: int | None = None) -> bytes:
|
async def read(self, n: int | None = None) -> bytes:
|
||||||
"""
|
"""
|
||||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||||
|
This implementation provides byte-level access to WebSocket messages,
|
||||||
|
which is required for Noise protocol handshake.
|
||||||
"""
|
"""
|
||||||
async with self._read_lock:
|
async with self._read_lock:
|
||||||
try:
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read requested: n={n}, "
|
||||||
|
f"buffer_size={len(self._read_buffer)}"
|
||||||
|
)
|
||||||
|
|
||||||
# If we have buffered data, return it
|
# If we have buffered data, return it
|
||||||
if self._read_buffer:
|
if self._read_buffer:
|
||||||
if n is None:
|
if n is None:
|
||||||
result = self._read_buffer
|
result = self._read_buffer
|
||||||
self._read_buffer = b""
|
self._read_buffer = b""
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read returning all buffered data: "
|
||||||
|
f"{len(result)} bytes"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
if len(self._read_buffer) >= n:
|
if len(self._read_buffer) >= n:
|
||||||
result = self._read_buffer[:n]
|
result = self._read_buffer[:n]
|
||||||
self._read_buffer = self._read_buffer[n:]
|
self._read_buffer = self._read_buffer[n:]
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read returning {len(result)} bytes "
|
||||||
|
f"from buffer"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
result = self._read_buffer
|
# We need more data, but we have some buffered
|
||||||
self._read_buffer = b""
|
# Keep the buffered data and get more
|
||||||
return result
|
logger.debug(
|
||||||
|
f"WebSocket read needs more data: have "
|
||||||
|
f"{len(self._read_buffer)}, need {n}"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
# Get the next WebSocket message
|
# If we need exactly n bytes but don't have enough, get more data
|
||||||
message = await self._ws_connection.get_message()
|
while n is not None and (
|
||||||
if isinstance(message, str):
|
not self._read_buffer or len(self._read_buffer) < n
|
||||||
message = message.encode('utf-8')
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read getting more data: "
|
||||||
|
f"buffer_size={len(self._read_buffer)}, need={n}"
|
||||||
|
)
|
||||||
|
# Get the next WebSocket message and treat it as a byte stream
|
||||||
|
# This mimics the Go implementation's NextReader() approach
|
||||||
|
message = await self._ws_connection.get_message()
|
||||||
|
if isinstance(message, str):
|
||||||
|
message = message.encode("utf-8")
|
||||||
|
|
||||||
# Add to buffer
|
logger.debug(
|
||||||
self._read_buffer = message
|
f"WebSocket read received message: {len(message)} bytes"
|
||||||
|
)
|
||||||
|
# Add to buffer
|
||||||
|
self._read_buffer += message
|
||||||
|
|
||||||
# Return requested amount
|
# Return requested amount
|
||||||
if n is None:
|
if n is None:
|
||||||
result = self._read_buffer
|
result = self._read_buffer
|
||||||
self._read_buffer = b""
|
self._read_buffer = b""
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read returning all data: {len(result)} bytes"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
if len(self._read_buffer) >= n:
|
if len(self._read_buffer) >= n:
|
||||||
result = self._read_buffer[:n]
|
result = self._read_buffer[:n]
|
||||||
self._read_buffer = self._read_buffer[n:]
|
self._read_buffer = self._read_buffer[n:]
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read returning exact {len(result)} bytes"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
|
# This should never happen due to the while loop above
|
||||||
result = self._read_buffer
|
result = self._read_buffer
|
||||||
self._read_buffer = b""
|
self._read_buffer = b""
|
||||||
|
logger.debug(
|
||||||
|
f"WebSocket read returning remaining {len(result)} bytes"
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket read failed: {e}")
|
||||||
raise IOException from e
|
raise IOException from e
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
|||||||
# Try to get remote address from the WebSocket connection
|
# Try to get remote address from the WebSocket connection
|
||||||
try:
|
try:
|
||||||
remote = self._ws_connection.remote
|
remote = self._ws_connection.remote
|
||||||
if hasattr(remote, 'address') and hasattr(remote, 'port'):
|
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||||
return str(remote.address), int(remote.port)
|
return str(remote.address), int(remote.port)
|
||||||
elif isinstance(remote, str):
|
elif isinstance(remote, str):
|
||||||
# Parse address:port format
|
# Parse address:port format
|
||||||
if ':' in remote:
|
if ":" in remote:
|
||||||
host, port = remote.rsplit(':', 1)
|
host, port = remote.rsplit(":", 1)
|
||||||
return host, int(port)
|
return host, int(port)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
|
from collections.abc import Awaitable, Callable
|
||||||
import logging
|
import logging
|
||||||
import socket
|
from typing import Any
|
||||||
from typing import Any, Callable
|
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
import trio
|
import trio
|
||||||
@ -9,7 +9,6 @@ from trio_websocket import serve_websocket
|
|||||||
|
|
||||||
from libp2p.abc import IListener
|
from libp2p.abc import IListener
|
||||||
from libp2p.custom_types import THandler
|
from libp2p.custom_types import THandler
|
||||||
from libp2p.network.connection.raw_connection import RawConnection
|
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
|
||||||
from .connection import P2PWebSocketConnection
|
from .connection import P2PWebSocketConnection
|
||||||
@ -27,7 +26,8 @@ class WebsocketListener(IListener):
|
|||||||
self._upgrader = upgrader
|
self._upgrader = upgrader
|
||||||
self._server = None
|
self._server = None
|
||||||
self._shutdown_event = trio.Event()
|
self._shutdown_event = trio.Event()
|
||||||
self._nursery = None
|
self._nursery: trio.Nursery | None = None
|
||||||
|
self._listeners: Any = None
|
||||||
|
|
||||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||||
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||||
@ -51,15 +51,15 @@ class WebsocketListener(IListener):
|
|||||||
logger.debug(f"WebsocketListener: host={host}, port={port}")
|
logger.debug(f"WebsocketListener: host={host}, port={port}")
|
||||||
|
|
||||||
async def serve_websocket_tcp(
|
async def serve_websocket_tcp(
|
||||||
handler: Callable,
|
handler: Callable[[Any], Awaitable[None]],
|
||||||
port: int,
|
port: int,
|
||||||
host: str,
|
host: str,
|
||||||
task_status: trio.TaskStatus[list],
|
task_status: TaskStatus[Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start TCP server and handle WebSocket connections manually"""
|
"""Start TCP server and handle WebSocket connections manually"""
|
||||||
logger.debug("serve_websocket_tcp %s %s", host, port)
|
logger.debug("serve_websocket_tcp %s %s", host, port)
|
||||||
|
|
||||||
async def websocket_handler(request):
|
async def websocket_handler(request: Any) -> None:
|
||||||
"""Handle WebSocket requests"""
|
"""Handle WebSocket requests"""
|
||||||
logger.debug("WebSocket request received")
|
logger.debug("WebSocket request received")
|
||||||
try:
|
try:
|
||||||
@ -68,7 +68,7 @@ class WebsocketListener(IListener):
|
|||||||
logger.debug("WebSocket handshake successful")
|
logger.debug("WebSocket handshake successful")
|
||||||
|
|
||||||
# Create the WebSocket connection wrapper
|
# Create the WebSocket connection wrapper
|
||||||
conn = P2PWebSocketConnection(ws_connection)
|
conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
# Call the handler function that was passed to create_listener
|
# Call the handler function that was passed to create_listener
|
||||||
# This handler will handle the security and muxing upgrades
|
# This handler will handle the security and muxing upgrades
|
||||||
@ -77,22 +77,26 @@ class WebsocketListener(IListener):
|
|||||||
|
|
||||||
# Don't keep the connection alive indefinitely
|
# Don't keep the connection alive indefinitely
|
||||||
# Let the handler manage the connection lifecycle
|
# Let the handler manage the connection lifecycle
|
||||||
logger.debug("Handler completed, connection will be managed by handler")
|
logger.debug(
|
||||||
|
"Handler completed, connection will be managed by handler"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"WebSocket connection error: {e}")
|
logger.debug(f"WebSocket connection error: {e}")
|
||||||
logger.debug(f"Error type: {type(e)}")
|
logger.debug(f"Error type: {type(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||||
# Reject the connection
|
# Reject the connection
|
||||||
try:
|
try:
|
||||||
await request.reject(400)
|
await request.reject(400)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||||
from trio_websocket import serve_websocket
|
await serve_websocket(
|
||||||
await serve_websocket(websocket_handler, host, port, None, task_status=task_status)
|
websocket_handler, host, port, None, task_status=task_status
|
||||||
|
)
|
||||||
|
|
||||||
# Store the nursery for shutdown
|
# Store the nursery for shutdown
|
||||||
self._nursery = nursery
|
self._nursery = nursery
|
||||||
@ -111,18 +115,21 @@ class WebsocketListener(IListener):
|
|||||||
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Store the listeners for get_addrs() and close() - these are real SocketListener objects
|
# Store the listeners for get_addrs() and close() - these are real
|
||||||
|
# SocketListener objects
|
||||||
self._listeners = started_listeners
|
self._listeners = started_listeners
|
||||||
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object")
|
logger.debug(
|
||||||
|
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||||
if not hasattr(self, '_listeners') or not self._listeners:
|
if not hasattr(self, "_listeners") or not self._listeners:
|
||||||
logger.debug("No listeners available for get_addrs()")
|
logger.debug("No listeners available for get_addrs()")
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
# Handle WebSocketServer objects
|
# Handle WebSocketServer objects
|
||||||
if hasattr(self._listeners, 'port'):
|
if hasattr(self._listeners, "port"):
|
||||||
# This is a WebSocketServer object
|
# This is a WebSocketServer object
|
||||||
port = self._listeners.port
|
port = self._listeners.port
|
||||||
# Create a multiaddr from the port
|
# Create a multiaddr from the port
|
||||||
@ -138,12 +145,12 @@ class WebsocketListener(IListener):
|
|||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the WebSocket listener and stop accepting new connections"""
|
"""Close the WebSocket listener and stop accepting new connections"""
|
||||||
logger.debug("WebsocketListener.close called")
|
logger.debug("WebsocketListener.close called")
|
||||||
if hasattr(self, '_listeners') and self._listeners:
|
if hasattr(self, "_listeners") and self._listeners:
|
||||||
# Signal shutdown
|
# Signal shutdown
|
||||||
self._shutdown_event.set()
|
self._shutdown_event.set()
|
||||||
|
|
||||||
# Close the WebSocket server
|
# Close the WebSocket server
|
||||||
if hasattr(self._listeners, 'aclose'):
|
if hasattr(self._listeners, "aclose"):
|
||||||
# This is a WebSocketServer object
|
# This is a WebSocketServer object
|
||||||
logger.debug("Closing WebSocket server")
|
logger.debug("Closing WebSocket server")
|
||||||
await self._listeners.aclose()
|
await self._listeners.aclose()
|
||||||
@ -152,12 +159,12 @@ class WebsocketListener(IListener):
|
|||||||
# This is a list of listeners (like TCP)
|
# This is a list of listeners (like TCP)
|
||||||
logger.debug("Closing TCP listeners")
|
logger.debug("Closing TCP listeners")
|
||||||
for listener in self._listeners:
|
for listener in self._listeners:
|
||||||
listener.close()
|
await listener.aclose()
|
||||||
logger.debug("TCP listeners closed")
|
logger.debug("TCP listeners closed")
|
||||||
else:
|
else:
|
||||||
# Unknown type, try to close it directly
|
# Unknown type, try to close it directly
|
||||||
logger.debug("Closing unknown listener type")
|
logger.debug("Closing unknown listener type")
|
||||||
if hasattr(self._listeners, 'close'):
|
if hasattr(self._listeners, "close"):
|
||||||
self._listeners.close()
|
self._listeners.close()
|
||||||
logger.debug("Unknown listener closed")
|
logger.debug("Unknown listener closed")
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
from trio_websocket import open_websocket_url
|
|
||||||
|
|
||||||
from libp2p.abc import IListener, ITransport
|
from libp2p.abc import IListener, ITransport
|
||||||
from libp2p.custom_types import THandler
|
from libp2p.custom_types import THandler
|
||||||
@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader
|
|||||||
from .connection import P2PWebSocketConnection
|
from .connection import P2PWebSocketConnection
|
||||||
from .listener import WebsocketListener
|
from .listener import WebsocketListener
|
||||||
|
|
||||||
logger = logging.getLogger("libp2p.transport.websocket")
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WebsocketTransport(ITransport):
|
class WebsocketTransport(ITransport):
|
||||||
@ -45,6 +45,7 @@ class WebsocketTransport(ITransport):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from trio_websocket import open_websocket_url
|
from trio_websocket import open_websocket_url
|
||||||
|
|
||||||
# Use the context manager but don't exit it immediately
|
# Use the context manager but don't exit it immediately
|
||||||
# The connection will be closed when the RawConnection is closed
|
# The connection will be closed when the RawConnection is closed
|
||||||
ws_context = open_websocket_url(ws_url)
|
ws_context = open_websocket_url(ws_url)
|
||||||
|
|||||||
@ -2,20 +2,20 @@
|
|||||||
Tests for the transport registry functionality.
|
Tests for the transport registry functionality.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
|
|
||||||
from libp2p.abc import ITransport
|
from libp2p.abc import IListener, IRawConnection, ITransport
|
||||||
|
from libp2p.custom_types import THandler
|
||||||
from libp2p.transport.tcp.tcp import TCP
|
from libp2p.transport.tcp.tcp import TCP
|
||||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
|
||||||
from libp2p.transport.transport_registry import (
|
from libp2p.transport.transport_registry import (
|
||||||
TransportRegistry,
|
TransportRegistry,
|
||||||
create_transport_for_multiaddr,
|
create_transport_for_multiaddr,
|
||||||
|
get_supported_transport_protocols,
|
||||||
get_transport_registry,
|
get_transport_registry,
|
||||||
register_transport,
|
register_transport,
|
||||||
get_supported_transport_protocols,
|
|
||||||
)
|
)
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
|
|
||||||
|
|
||||||
class TestTransportRegistry:
|
class TestTransportRegistry:
|
||||||
@ -36,8 +36,14 @@ class TestTransportRegistry:
|
|||||||
registry = TransportRegistry()
|
registry = TransportRegistry()
|
||||||
|
|
||||||
# Register a custom transport
|
# Register a custom transport
|
||||||
class CustomTransport:
|
class CustomTransport(ITransport):
|
||||||
pass
|
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||||
|
raise NotImplementedError("CustomTransport dial not implemented")
|
||||||
|
|
||||||
|
def create_listener(self, handler_function: THandler) -> IListener:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"CustomTransport create_listener not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
registry.register_transport("custom", CustomTransport)
|
registry.register_transport("custom", CustomTransport)
|
||||||
assert registry.get_transport("custom") == CustomTransport
|
assert registry.get_transport("custom") == CustomTransport
|
||||||
@ -105,8 +111,15 @@ class TestGlobalRegistry:
|
|||||||
|
|
||||||
def test_register_transport_global(self):
|
def test_register_transport_global(self):
|
||||||
"""Test registering transport globally."""
|
"""Test registering transport globally."""
|
||||||
class GlobalCustomTransport:
|
|
||||||
pass
|
class GlobalCustomTransport(ITransport):
|
||||||
|
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||||
|
raise NotImplementedError("GlobalCustomTransport dial not implemented")
|
||||||
|
|
||||||
|
def create_listener(self, handler_function: THandler) -> IListener:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"GlobalCustomTransport create_listener not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
# Register globally
|
# Register globally
|
||||||
register_transport("global_custom", GlobalCustomTransport)
|
register_transport("global_custom", GlobalCustomTransport)
|
||||||
@ -191,17 +204,18 @@ class TestTransportFactory:
|
|||||||
|
|
||||||
assert transport is None
|
assert transport is None
|
||||||
|
|
||||||
def test_create_transport_for_multiaddr_no_upgrader(self):
|
def test_create_transport_for_multiaddr_with_upgrader(self):
|
||||||
"""Test creating transport without upgrader."""
|
"""Test creating transport with upgrader."""
|
||||||
# This should work for TCP but not WebSocket
|
upgrader = TransportUpgrader({}, {})
|
||||||
|
|
||||||
|
# This should work for both TCP and WebSocket with upgrader
|
||||||
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
|
maddr_tcp = Multiaddr("/ip4/127.0.0.1/tcp/8080")
|
||||||
transport_tcp = create_transport_for_multiaddr(maddr_tcp, None)
|
transport_tcp = create_transport_for_multiaddr(maddr_tcp, upgrader)
|
||||||
assert transport_tcp is not None
|
assert transport_tcp is not None
|
||||||
|
|
||||||
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
maddr_ws = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||||
transport_ws = create_transport_for_multiaddr(maddr_ws, None)
|
transport_ws = create_transport_for_multiaddr(maddr_ws, upgrader)
|
||||||
# WebSocket transport creation should fail gracefully
|
assert transport_ws is not None
|
||||||
assert transport_ws is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestTransportInterfaceCompliance:
|
class TestTransportInterfaceCompliance:
|
||||||
@ -211,8 +225,8 @@ class TestTransportInterfaceCompliance:
|
|||||||
"""Test that TCP transport implements ITransport."""
|
"""Test that TCP transport implements ITransport."""
|
||||||
transport = TCP()
|
transport = TCP()
|
||||||
assert isinstance(transport, ITransport)
|
assert isinstance(transport, ITransport)
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
assert callable(transport.create_listener)
|
assert callable(transport.create_listener)
|
||||||
|
|
||||||
@ -221,8 +235,8 @@ class TestTransportInterfaceCompliance:
|
|||||||
upgrader = TransportUpgrader({}, {})
|
upgrader = TransportUpgrader({}, {})
|
||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
assert isinstance(transport, ITransport)
|
assert isinstance(transport, ITransport)
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
assert callable(transport.create_listener)
|
assert callable(transport.create_listener)
|
||||||
|
|
||||||
@ -236,10 +250,18 @@ class TestErrorHandling:
|
|||||||
upgrader = TransportUpgrader({}, {})
|
upgrader = TransportUpgrader({}, {})
|
||||||
|
|
||||||
# Register a transport that raises an exception
|
# Register a transport that raises an exception
|
||||||
class ExceptionTransport:
|
class ExceptionTransport(ITransport):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
raise RuntimeError("Transport creation failed")
|
raise RuntimeError("Transport creation failed")
|
||||||
|
|
||||||
|
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||||
|
raise NotImplementedError("ExceptionTransport dial not implemented")
|
||||||
|
|
||||||
|
def create_listener(self, handler_function: THandler) -> IListener:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"ExceptionTransport create_listener not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
registry.register_transport("exception", ExceptionTransport)
|
registry.register_transport("exception", ExceptionTransport)
|
||||||
|
|
||||||
# Should handle exception gracefully and return None
|
# Should handle exception gracefully and return None
|
||||||
@ -252,7 +274,8 @@ class TestErrorHandling:
|
|||||||
|
|
||||||
# Test with a multiaddr that has an unsupported transport protocol
|
# Test with a multiaddr that has an unsupported transport protocol
|
||||||
# This should be handled gracefully by our transport registry
|
# This should be handled gracefully by our transport registry
|
||||||
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234") # udp is not a supported transport
|
# udp is not a supported transport
|
||||||
|
maddr = Multiaddr("/ip4/127.0.0.1/tcp/8080/udp/1234")
|
||||||
transport = create_transport_for_multiaddr(maddr, upgrader)
|
transport = create_transport_for_multiaddr(maddr, upgrader)
|
||||||
|
|
||||||
assert transport is None
|
assert transport is None
|
||||||
@ -286,8 +309,14 @@ class TestIntegration:
|
|||||||
assert registry1 is registry2
|
assert registry1 is registry2
|
||||||
|
|
||||||
# Register a transport in one
|
# Register a transport in one
|
||||||
class PersistentTransport:
|
class PersistentTransport(ITransport):
|
||||||
pass
|
async def dial(self, maddr: Multiaddr) -> IRawConnection:
|
||||||
|
raise NotImplementedError("PersistentTransport dial not implemented")
|
||||||
|
|
||||||
|
def create_listener(self, handler_function: THandler) -> IListener:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"PersistentTransport create_listener not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
registry1.register_transport("persistent", PersistentTransport)
|
registry1.register_transport("persistent", PersistentTransport)
|
||||||
|
|
||||||
|
|||||||
@ -1,23 +1,23 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import trio
|
|
||||||
from multiaddr import Multiaddr
|
from multiaddr import Multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.crypto.secp256k1 import create_new_key_pair
|
from libp2p.crypto.secp256k1 import create_new_key_pair
|
||||||
from libp2p.custom_types import TProtocol
|
from libp2p.custom_types import TProtocol
|
||||||
from libp2p.host.basic_host import BasicHost
|
from libp2p.host.basic_host import BasicHost
|
||||||
from libp2p.network.swarm import Swarm
|
from libp2p.network.swarm import Swarm
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.peer.peerinfo import PeerInfo
|
|
||||||
from libp2p.peer.peerstore import PeerStore
|
from libp2p.peer.peerstore import PeerStore
|
||||||
from libp2p.security.insecure.transport import InsecureTransport
|
from libp2p.security.insecure.transport import InsecureTransport
|
||||||
from libp2p.stream_muxer.yamux.yamux import Yamux
|
from libp2p.stream_muxer.yamux.yamux import Yamux
|
||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
from libp2p.transport.websocket.listener import WebsocketListener
|
|
||||||
from libp2p.transport.exceptions import OpenConnectionError
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
|
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
|
||||||
|
|
||||||
@ -64,9 +64,6 @@ def create_upgrader():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 2. Listener Basic Functionality Tests
|
# 2. Listener Basic Functionality Tests
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_listener_basic_listen():
|
async def test_listener_basic_listen():
|
||||||
@ -76,12 +73,16 @@ async def test_listener_basic_listen():
|
|||||||
|
|
||||||
# Test listening on IPv4
|
# Test listening on IPv4
|
||||||
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
||||||
listener = transport.create_listener(lambda conn: None)
|
|
||||||
|
async def dummy_handler(conn):
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
|
||||||
# Test that listener can be created and has required methods
|
# Test that listener can be created and has required methods
|
||||||
assert hasattr(listener, 'listen')
|
assert hasattr(listener, "listen")
|
||||||
assert hasattr(listener, 'close')
|
assert hasattr(listener, "close")
|
||||||
assert hasattr(listener, 'get_addrs')
|
assert hasattr(listener, "get_addrs")
|
||||||
|
|
||||||
# Test that listener can handle the address
|
# Test that listener can handle the address
|
||||||
assert ma.value_for_protocol("ip4") == "127.0.0.1"
|
assert ma.value_for_protocol("ip4") == "127.0.0.1"
|
||||||
@ -98,7 +99,11 @@ async def test_listener_port_0_handling():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
||||||
listener = transport.create_listener(lambda conn: None)
|
|
||||||
|
async def dummy_handler(conn):
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
|
||||||
# Test that the address can be parsed correctly
|
# Test that the address can be parsed correctly
|
||||||
port_str = ma.value_for_protocol("tcp")
|
port_str = ma.value_for_protocol("tcp")
|
||||||
@ -115,7 +120,11 @@ async def test_listener_any_interface():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")
|
ma = Multiaddr("/ip4/0.0.0.0/tcp/0/ws")
|
||||||
listener = transport.create_listener(lambda conn: None)
|
|
||||||
|
async def dummy_handler(conn):
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
|
||||||
# Test that the address can be parsed correctly
|
# Test that the address can be parsed correctly
|
||||||
host = ma.value_for_protocol("ip4")
|
host = ma.value_for_protocol("ip4")
|
||||||
@ -134,7 +143,11 @@ async def test_listener_address_preservation():
|
|||||||
# Create address with p2p ID
|
# Create address with p2p ID
|
||||||
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
|
p2p_id = "12D3KooWL5xtmx8Mgc6tByjVaPPpTKH42QK7PUFQtZLabdSMKHpF"
|
||||||
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}")
|
ma = Multiaddr(f"/ip4/127.0.0.1/tcp/0/ws/p2p/{p2p_id}")
|
||||||
listener = transport.create_listener(lambda conn: None)
|
|
||||||
|
async def dummy_handler(conn):
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
|
||||||
# Test that p2p ID is preserved in the address
|
# Test that p2p ID is preserved in the address
|
||||||
addr_str = str(ma)
|
addr_str = str(ma)
|
||||||
@ -161,7 +174,7 @@ async def test_dial_basic():
|
|||||||
assert port == "8080"
|
assert port == "8080"
|
||||||
|
|
||||||
# Test that transport has the required methods
|
# Test that transport has the required methods
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
|
|
||||||
|
|
||||||
@ -179,7 +192,7 @@ async def test_dial_with_p2p_id():
|
|||||||
assert p2p_id in addr_str
|
assert p2p_id in addr_str
|
||||||
|
|
||||||
# Test that transport can handle addresses with p2p IDs
|
# Test that transport can handle addresses with p2p IDs
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
|
|
||||||
|
|
||||||
@ -197,15 +210,14 @@ async def test_dial_port_0_resolution():
|
|||||||
assert port_str == "0"
|
assert port_str == "0"
|
||||||
|
|
||||||
# Test that transport has the required methods
|
# Test that transport has the required methods
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
|
|
||||||
|
|
||||||
# 4. Address Validation Tests (CRITICAL)
|
# 4. Address Validation Tests (CRITICAL)
|
||||||
def test_address_validation_ipv4():
|
def test_address_validation_ipv4():
|
||||||
"""Test IPv4 address validation"""
|
"""Test IPv4 address validation"""
|
||||||
upgrader = create_upgrader()
|
# upgrader = create_upgrader() # Not used in this test
|
||||||
transport = WebsocketTransport(upgrader)
|
|
||||||
|
|
||||||
# Valid IPv4 WebSocket addresses
|
# Valid IPv4 WebSocket addresses
|
||||||
valid_addresses = [
|
valid_addresses = [
|
||||||
@ -222,7 +234,9 @@ def test_address_validation_ipv4():
|
|||||||
assert "/ws" in transport_addr
|
assert "/ws" in transport_addr
|
||||||
|
|
||||||
# Test that transport can handle addresses with p2p IDs
|
# Test that transport can handle addresses with p2p IDs
|
||||||
p2p_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw")
|
p2p_addr = Multiaddr(
|
||||||
|
"/ip4/127.0.0.1/tcp/8080/ws/p2p/Qmb6owHp6eaWArVbcJJbQSyifyJBttMMjYV76N2hMbf5Vw"
|
||||||
|
)
|
||||||
# Should not raise exception when creating transport address
|
# Should not raise exception when creating transport address
|
||||||
transport_addr = str(p2p_addr)
|
transport_addr = str(p2p_addr)
|
||||||
assert "/ws" in transport_addr
|
assert "/ws" in transport_addr
|
||||||
@ -230,8 +244,7 @@ def test_address_validation_ipv4():
|
|||||||
|
|
||||||
def test_address_validation_ipv6():
|
def test_address_validation_ipv6():
|
||||||
"""Test IPv6 address validation"""
|
"""Test IPv6 address validation"""
|
||||||
upgrader = create_upgrader()
|
# upgrader = create_upgrader() # Not used in this test
|
||||||
transport = WebsocketTransport(upgrader)
|
|
||||||
|
|
||||||
# Valid IPv6 WebSocket addresses
|
# Valid IPv6 WebSocket addresses
|
||||||
valid_addresses = [
|
valid_addresses = [
|
||||||
@ -248,8 +261,7 @@ def test_address_validation_ipv6():
|
|||||||
|
|
||||||
def test_address_validation_dns():
|
def test_address_validation_dns():
|
||||||
"""Test DNS address validation"""
|
"""Test DNS address validation"""
|
||||||
upgrader = create_upgrader()
|
# upgrader = create_upgrader() # Not used in this test
|
||||||
transport = WebsocketTransport(upgrader)
|
|
||||||
|
|
||||||
# Valid DNS WebSocket addresses
|
# Valid DNS WebSocket addresses
|
||||||
valid_addresses = [
|
valid_addresses = [
|
||||||
@ -267,15 +279,14 @@ def test_address_validation_dns():
|
|||||||
|
|
||||||
def test_address_validation_mixed():
|
def test_address_validation_mixed():
|
||||||
"""Test mixed address validation"""
|
"""Test mixed address validation"""
|
||||||
upgrader = create_upgrader()
|
# upgrader = create_upgrader() # Not used in this test
|
||||||
transport = WebsocketTransport(upgrader)
|
|
||||||
|
|
||||||
# Mixed valid and invalid addresses
|
# Mixed valid and invalid addresses
|
||||||
addresses = [
|
addresses = [
|
||||||
"/ip4/127.0.0.1/tcp/8080/ws", # Valid
|
"/ip4/127.0.0.1/tcp/8080/ws", # Valid
|
||||||
"/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws)
|
"/ip4/127.0.0.1/tcp/8080", # Invalid (no /ws)
|
||||||
"/ip6/::1/tcp/8080/ws", # Valid
|
"/ip6/::1/tcp/8080/ws", # Valid
|
||||||
"/ip4/127.0.0.1/ws", # Invalid (no tcp)
|
"/ip4/127.0.0.1/ws", # Invalid (no tcp)
|
||||||
"/dns4/example.com/tcp/80/ws", # Valid
|
"/dns4/example.com/tcp/80/ws", # Valid
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -310,15 +321,14 @@ async def test_dial_invalid_address():
|
|||||||
]
|
]
|
||||||
|
|
||||||
for ma in invalid_addresses:
|
for ma in invalid_addresses:
|
||||||
with pytest.raises((ValueError, OpenConnectionError, Exception)):
|
with pytest.raises(Exception):
|
||||||
await transport.dial(ma)
|
await transport.dial(ma)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_listen_invalid_address():
|
async def test_listen_invalid_address():
|
||||||
"""Test listening on invalid addresses"""
|
"""Test listening on invalid addresses"""
|
||||||
upgrader = create_upgrader()
|
# upgrader = create_upgrader() # Not used in this test
|
||||||
transport = WebsocketTransport(upgrader)
|
|
||||||
|
|
||||||
# Test listening on non-WebSocket addresses
|
# Test listening on non-WebSocket addresses
|
||||||
invalid_addresses = [
|
invalid_addresses = [
|
||||||
@ -352,7 +362,7 @@ async def test_listen_port_in_use():
|
|||||||
assert ma2.value_for_protocol("tcp") == "8080"
|
assert ma2.value_for_protocol("tcp") == "8080"
|
||||||
|
|
||||||
# Test that transport can handle these addresses
|
# Test that transport can handle these addresses
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
assert callable(transport.create_listener)
|
assert callable(transport.create_listener)
|
||||||
|
|
||||||
|
|
||||||
@ -364,12 +374,15 @@ async def test_connection_close():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
# Test that transport has required methods
|
# Test that transport has required methods
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
|
|
||||||
# Test that listener can be created and closed
|
# Test that listener can be created and closed
|
||||||
listener = transport.create_listener(lambda conn: None)
|
async def dummy_handler(conn):
|
||||||
assert hasattr(listener, 'close')
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
assert hasattr(listener, "close")
|
||||||
assert callable(listener.close)
|
assert callable(listener.close)
|
||||||
|
|
||||||
# Test that listener can be closed
|
# Test that listener can be closed
|
||||||
@ -397,16 +410,10 @@ async def test_multiple_connections():
|
|||||||
assert port in ["8080", "8081", "8082"]
|
assert port in ["8080", "8081", "8082"]
|
||||||
|
|
||||||
# Test that transport has required methods
|
# Test that transport has required methods
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Original test (kept for compatibility)
|
# Original test (kept for compatibility)
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_websocket_dial_and_listen():
|
async def test_websocket_dial_and_listen():
|
||||||
@ -416,11 +423,14 @@ async def test_websocket_dial_and_listen():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
# Test that transport can create listeners
|
# Test that transport can create listeners
|
||||||
listener = transport.create_listener(lambda conn: None)
|
async def dummy_handler(conn):
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
assert listener is not None
|
assert listener is not None
|
||||||
assert hasattr(listener, 'listen')
|
assert hasattr(listener, "listen")
|
||||||
assert hasattr(listener, 'close')
|
assert hasattr(listener, "close")
|
||||||
assert hasattr(listener, 'get_addrs')
|
assert hasattr(listener, "get_addrs")
|
||||||
|
|
||||||
# Test that transport can handle WebSocket addresses
|
# Test that transport can handle WebSocket addresses
|
||||||
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
ma = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
||||||
@ -429,7 +439,7 @@ async def test_websocket_dial_and_listen():
|
|||||||
assert "ws" in str(ma)
|
assert "ws" in str(ma)
|
||||||
|
|
||||||
# Test that transport has dial method
|
# Test that transport has dial method
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
|
|
||||||
# Test that transport can handle WebSocket multiaddrs
|
# Test that transport can handle WebSocket multiaddrs
|
||||||
@ -442,14 +452,9 @@ async def test_websocket_dial_and_listen():
|
|||||||
await listener.close()
|
await listener.close()
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_websocket_transport_basic():
|
async def test_websocket_transport_basic():
|
||||||
"""Test basic WebSocket transport functionality without full libp2p stack"""
|
"""Test basic WebSocket transport functionality without full libp2p stack"""
|
||||||
|
|
||||||
# Create WebSocket transport
|
# Create WebSocket transport
|
||||||
key_pair = create_new_key_pair()
|
key_pair = create_new_key_pair()
|
||||||
upgrader = TransportUpgrader(
|
upgrader = TransportUpgrader(
|
||||||
@ -461,14 +466,17 @@ async def test_websocket_transport_basic():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
assert transport is not None
|
assert transport is not None
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
|
|
||||||
listener = transport.create_listener(lambda conn: None)
|
async def dummy_handler(conn):
|
||||||
|
await trio.sleep(0)
|
||||||
|
|
||||||
|
listener = transport.create_listener(dummy_handler)
|
||||||
assert listener is not None
|
assert listener is not None
|
||||||
assert hasattr(listener, 'listen')
|
assert hasattr(listener, "listen")
|
||||||
assert hasattr(listener, 'close')
|
assert hasattr(listener, "close")
|
||||||
assert hasattr(listener, 'get_addrs')
|
assert hasattr(listener, "get_addrs")
|
||||||
|
|
||||||
valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
valid_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
||||||
assert valid_addr.value_for_protocol("ip4") == "127.0.0.1"
|
assert valid_addr.value_for_protocol("ip4") == "127.0.0.1"
|
||||||
@ -480,8 +488,7 @@ async def test_websocket_transport_basic():
|
|||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_websocket_simple_connection():
|
async def test_websocket_simple_connection():
|
||||||
"""Test WebSocket transport creation and basic functionality without real connections"""
|
"""Test WebSocket transport creation and basic functionality without real conn"""
|
||||||
|
|
||||||
# Create WebSocket transport
|
# Create WebSocket transport
|
||||||
key_pair = create_new_key_pair()
|
key_pair = create_new_key_pair()
|
||||||
upgrader = TransportUpgrader(
|
upgrader = TransportUpgrader(
|
||||||
@ -493,17 +500,17 @@ async def test_websocket_simple_connection():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
assert transport is not None
|
assert transport is not None
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
|
|
||||||
async def simple_handler(conn):
|
async def simple_handler(conn):
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
listener = transport.create_listener(simple_handler)
|
listener = transport.create_listener(simple_handler)
|
||||||
assert listener is not None
|
assert listener is not None
|
||||||
assert hasattr(listener, 'listen')
|
assert hasattr(listener, "listen")
|
||||||
assert hasattr(listener, 'close')
|
assert hasattr(listener, "close")
|
||||||
assert hasattr(listener, 'get_addrs')
|
assert hasattr(listener, "get_addrs")
|
||||||
|
|
||||||
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/0/ws")
|
||||||
assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
|
assert test_addr.value_for_protocol("ip4") == "127.0.0.1"
|
||||||
@ -516,7 +523,6 @@ async def test_websocket_simple_connection():
|
|||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_websocket_real_connection():
|
async def test_websocket_real_connection():
|
||||||
"""Test WebSocket transport creation and basic functionality"""
|
"""Test WebSocket transport creation and basic functionality"""
|
||||||
|
|
||||||
# Create WebSocket transport
|
# Create WebSocket transport
|
||||||
key_pair = create_new_key_pair()
|
key_pair = create_new_key_pair()
|
||||||
upgrader = TransportUpgrader(
|
upgrader = TransportUpgrader(
|
||||||
@ -528,17 +534,17 @@ async def test_websocket_real_connection():
|
|||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
assert transport is not None
|
assert transport is not None
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
|
|
||||||
async def handler(conn):
|
async def handler(conn):
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
listener = transport.create_listener(handler)
|
listener = transport.create_listener(handler)
|
||||||
assert listener is not None
|
assert listener is not None
|
||||||
assert hasattr(listener, 'listen')
|
assert hasattr(listener, "listen")
|
||||||
assert hasattr(listener, 'close')
|
assert hasattr(listener, "close")
|
||||||
assert hasattr(listener, 'get_addrs')
|
assert hasattr(listener, "get_addrs")
|
||||||
|
|
||||||
await listener.close()
|
await listener.close()
|
||||||
|
|
||||||
@ -546,7 +552,6 @@ async def test_websocket_real_connection():
|
|||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_websocket_with_tcp_fallback():
|
async def test_websocket_with_tcp_fallback():
|
||||||
"""Test WebSocket functionality using TCP transport as fallback"""
|
"""Test WebSocket functionality using TCP transport as fallback"""
|
||||||
|
|
||||||
from tests.utils.factories import host_pair_factory
|
from tests.utils.factories import host_pair_factory
|
||||||
|
|
||||||
async with host_pair_factory() as (host_a, host_b):
|
async with host_pair_factory() as (host_a, host_b):
|
||||||
@ -578,7 +583,6 @@ async def test_websocket_with_tcp_fallback():
|
|||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_websocket_transport_interface():
|
async def test_websocket_transport_interface():
|
||||||
"""Test WebSocket transport interface compliance"""
|
"""Test WebSocket transport interface compliance"""
|
||||||
|
|
||||||
key_pair = create_new_key_pair()
|
key_pair = create_new_key_pair()
|
||||||
upgrader = TransportUpgrader(
|
upgrader = TransportUpgrader(
|
||||||
secure_transports_by_protocol={
|
secure_transports_by_protocol={
|
||||||
@ -589,15 +593,18 @@ async def test_websocket_transport_interface():
|
|||||||
|
|
||||||
transport = WebsocketTransport(upgrader)
|
transport = WebsocketTransport(upgrader)
|
||||||
|
|
||||||
assert hasattr(transport, 'dial')
|
assert hasattr(transport, "dial")
|
||||||
assert hasattr(transport, 'create_listener')
|
assert hasattr(transport, "create_listener")
|
||||||
assert callable(transport.dial)
|
assert callable(transport.dial)
|
||||||
assert callable(transport.create_listener)
|
assert callable(transport.create_listener)
|
||||||
|
|
||||||
listener = transport.create_listener(lambda conn: None)
|
async def dummy_handler(conn):
|
||||||
assert hasattr(listener, 'listen')
|
await trio.sleep(0)
|
||||||
assert hasattr(listener, 'close')
|
|
||||||
assert hasattr(listener, 'get_addrs')
|
listener = transport.create_listener(dummy_handler)
|
||||||
|
assert hasattr(listener, "listen")
|
||||||
|
assert hasattr(listener, "close")
|
||||||
|
assert hasattr(listener, "get_addrs")
|
||||||
|
|
||||||
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
test_addr = Multiaddr("/ip4/127.0.0.1/tcp/8080/ws")
|
||||||
host = test_addr.value_for_protocol("ip4")
|
host = test_addr.value_for_protocol("ip4")
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from libp2p.stream_muxer.yamux.yamux import Yamux
|
|||||||
from libp2p.transport.upgrader import TransportUpgrader
|
from libp2p.transport.upgrader import TransportUpgrader
|
||||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||||
|
|
||||||
PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0"
|
PLAINTEXT_PROTOCOL_ID = "/plaintext/2.0.0"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
@ -74,6 +74,11 @@ async def test_ping_with_js_node():
|
|||||||
peer_id = ID.from_base58(peer_id_line)
|
peer_id = ID.from_base58(peer_id_line)
|
||||||
maddr = Multiaddr(addr_line)
|
maddr = Multiaddr(addr_line)
|
||||||
|
|
||||||
|
# Debug: Print what we're trying to connect to
|
||||||
|
print(f"JS Node Peer ID: {peer_id_line}")
|
||||||
|
print(f"JS Node Address: {addr_line}")
|
||||||
|
print(f"All JS Node lines: {lines}")
|
||||||
|
|
||||||
# Set up Python host
|
# Set up Python host
|
||||||
key_pair = create_new_key_pair()
|
key_pair = create_new_key_pair()
|
||||||
py_peer_id = ID.from_pubkey(key_pair.public_key)
|
py_peer_id = ID.from_pubkey(key_pair.public_key)
|
||||||
@ -86,13 +91,15 @@ async def test_ping_with_js_node():
|
|||||||
},
|
},
|
||||||
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
muxer_transports_by_protocol={TProtocol("/yamux/1.0.0"): Yamux},
|
||||||
)
|
)
|
||||||
transport = WebsocketTransport()
|
transport = WebsocketTransport(upgrader)
|
||||||
swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
|
swarm = Swarm(py_peer_id, peer_store, upgrader, transport)
|
||||||
host = BasicHost(swarm)
|
host = BasicHost(swarm)
|
||||||
|
|
||||||
# Connect to JS node
|
# Connect to JS node
|
||||||
peer_info = PeerInfo(peer_id, [maddr])
|
peer_info = PeerInfo(peer_id, [maddr])
|
||||||
|
|
||||||
|
print(f"Python trying to connect to: {peer_info}")
|
||||||
|
|
||||||
await trio.sleep(1)
|
await trio.sleep(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user