feat: implement WebSocket transport with transport registry system - Add transport_registry.py for centralized transport management - Integrate WebSocket transport with new registry - Add comprehensive test suite for transport registry - Include WebSocket examples and demos - Update transport initialization and swarm integration

This commit is contained in:
acul71
2025-08-09 23:52:55 +02:00
parent a6f85690bf
commit 64107b4648
15 changed files with 2297 additions and 161 deletions

View File

@ -1,4 +1,5 @@
from trio.abc import Stream
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
@ -6,19 +7,20 @@ from libp2p.io.exceptions import IOException
class P2PWebSocketConnection(ReadWriteCloser):
"""
Wraps a raw trio.abc.Stream from an established websocket connection.
This bypasses message-framing issues and provides the raw stream
Wraps a WebSocketConnection to provide the raw stream interface
that libp2p protocols expect.
"""
_stream: Stream
def __init__(self, stream: Stream):
self._stream = stream
def __init__(self, ws_connection, ws_context=None):
self._ws_connection = ws_connection
self._ws_context = ws_context
self._read_buffer = b""
self._read_lock = trio.Lock()
async def write(self, data: bytes) -> None:
try:
await self._stream.send_all(data)
# Send as a binary WebSocket message
await self._ws_connection.send_message(data)
except Exception as e:
raise IOException from e
@ -26,24 +28,68 @@ class P2PWebSocketConnection(ReadWriteCloser):
"""
Read up to n bytes (if n is given), else read up to 64KiB.
"""
try:
if n is None:
# read a reasonable chunk
return await self._stream.receive_some(2**16)
return await self._stream.receive_some(n)
except Exception as e:
raise IOException from e
async with self._read_lock:
try:
# If we have buffered data, return it
if self._read_buffer:
if n is None:
result = self._read_buffer
self._read_buffer = b""
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
return result
else:
result = self._read_buffer
self._read_buffer = b""
return result
# Get the next WebSocket message
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode('utf-8')
# Add to buffer
self._read_buffer = message
# Return requested amount
if n is None:
result = self._read_buffer
self._read_buffer = b""
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
return result
else:
result = self._read_buffer
self._read_buffer = b""
return result
except Exception as e:
raise IOException from e
async def close(self) -> None:
await self._stream.aclose()
# Close the WebSocket connection
await self._ws_connection.aclose()
# Exit the context manager if we have one
if self._ws_context is not None:
await self._ws_context.__aexit__(None, None, None)
def get_remote_address(self) -> tuple[str, int] | None:
sock = getattr(self._stream, "socket", None)
if sock:
try:
addr = sock.getpeername()
if isinstance(addr, tuple) and len(addr) >= 2:
return str(addr[0]), int(addr[1])
except OSError:
return None
# Try to get remote address from the WebSocket connection
try:
remote = self._ws_connection.remote
if hasattr(remote, 'address') and hasattr(remote, 'port'):
return str(remote.address), int(remote.port)
elif isinstance(remote, str):
# Parse address:port format
if ':' in remote:
host, port = remote.rsplit(':', 1)
return host, int(port)
except Exception:
pass
return None

View File

@ -1,6 +1,6 @@
import logging
import socket
from typing import Any
from typing import Any, Callable
from multiaddr import Multiaddr
import trio
@ -10,6 +10,7 @@ from trio_websocket import serve_websocket
from libp2p.abc import IListener
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
@ -21,11 +22,15 @@ class WebsocketListener(IListener):
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
"""
def __init__(self, handler: THandler) -> None:
def __init__(self, handler: THandler, upgrader: TransportUpgrader) -> None:
self._handler = handler
self._upgrader = upgrader
self._server = None
self._shutdown_event = trio.Event()
self._nursery = None
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}")
addr_str = str(maddr)
if addr_str.endswith("/wss"):
raise NotImplementedError("/wss (TLS) not yet supported")
@ -42,43 +47,126 @@ class WebsocketListener(IListener):
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
logger.debug(f"WebsocketListener: host={host}, port={port}")
async def serve(
task_status: TaskStatus[Any] = trio.TASK_STATUS_IGNORED,
async def serve_websocket_tcp(
handler: Callable,
port: int,
host: str,
task_status: trio.TaskStatus[list],
) -> None:
# positional ssl_context=None
self._server = await serve_websocket(
self._handle_connection, host, port, None
)
task_status.started()
await self._server.wait_closed()
"""Start TCP server and handle WebSocket connections manually"""
logger.debug("serve_websocket_tcp %s %s", host, port)
async def websocket_handler(request):
"""Handle WebSocket requests"""
logger.debug("WebSocket request received")
try:
# Accept the WebSocket connection
ws_connection = await request.accept()
logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(ws_connection)
# Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades
logger.debug("Calling connection handler")
await self._handler(conn)
# Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle
logger.debug("Handler completed, connection will be managed by handler")
except Exception as e:
logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}")
import traceback
logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection
try:
await request.reject(400)
except:
pass
# Use trio_websocket.serve_websocket for proper WebSocket handling
from trio_websocket import serve_websocket
await serve_websocket(websocket_handler, host, port, None, task_status=task_status)
await nursery.start(serve)
# Store the nursery for shutdown
self._nursery = nursery
# Start the server using nursery.start() like TCP does
logger.debug("Calling nursery.start()...")
started_listeners = await nursery.start(
serve_websocket_tcp,
None, # No handler needed since it's defined inside serve_websocket_tcp
port,
host,
)
logger.debug(f"nursery.start() returned: {started_listeners}")
if started_listeners is None:
logger.error(f"Failed to start WebSocket listener for {maddr}")
return False
# Store the listeners for get_addrs() and close() - these are real SocketListener objects
self._listeners = started_listeners
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object")
return True
async def _handle_connection(self, websocket: Any) -> None:
try:
# use raw transport_stream
conn = P2PWebSocketConnection(websocket.stream)
raw = RawConnection(conn, initiator=False)
await self._handler(raw)
except Exception as e:
logger.debug("WebSocket connection error: %s", e)
def get_addrs(self) -> tuple[Multiaddr, ...]:
if not self._server or not self._server.sockets:
if not hasattr(self, '_listeners') or not self._listeners:
logger.debug("No listeners available for get_addrs()")
return ()
addrs = []
for sock in self._server.sockets:
host, port = sock.getsockname()[:2]
if sock.family == socket.AF_INET6:
addr = Multiaddr(f"/ip6/{host}/tcp/{port}/ws")
else:
addr = Multiaddr(f"/ip4/{host}/tcp/{port}/ws")
addrs.append(addr)
return tuple(addrs)
# Handle WebSocketServer objects
if hasattr(self._listeners, 'port'):
# This is a WebSocketServer object
port = self._listeners.port
# Create a multiaddr from the port
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/ws"),)
else:
# This is a list of listeners (like TCP)
listeners = self._listeners
# Get addresses from listeners like TCP does
return tuple(
_multiaddr_from_socket(listener.socket) for listener in listeners
)
async def close(self) -> None:
if self._server:
self._server.close()
await self._server.wait_closed()
"""Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called")
if hasattr(self, '_listeners') and self._listeners:
# Signal shutdown
self._shutdown_event.set()
# Close the WebSocket server
if hasattr(self._listeners, 'aclose'):
# This is a WebSocketServer object
logger.debug("Closing WebSocket server")
await self._listeners.aclose()
logger.debug("WebSocket server closed")
elif isinstance(self._listeners, (list, tuple)):
# This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners")
for listener in self._listeners:
listener.close()
logger.debug("TCP listeners closed")
else:
# Unknown type, try to close it directly
logger.debug("Closing unknown listener type")
if hasattr(self._listeners, 'close'):
self._listeners.close()
logger.debug("Unknown listener closed")
# Clear the listeners reference
self._listeners = None
logger.debug("WebsocketListener.close completed")
def _multiaddr_from_socket(socket: trio.socket.SocketType) -> Multiaddr:
"""Convert socket to multiaddr"""
ip, port = socket.getsockname()
return Multiaddr(f"/ip4/{ip}/tcp/{port}/ws")

View File

@ -1,3 +1,4 @@
import logging
from multiaddr import Multiaddr
from trio_websocket import open_websocket_url
@ -5,54 +6,51 @@ from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
from .listener import WebsocketListener
logger = logging.getLogger("libp2p.transport.websocket")
class WebsocketTransport(ITransport):
"""
Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws
"""
def __init__(self, upgrader: TransportUpgrader):
self._upgrader = upgrader
async def dial(self, maddr: Multiaddr) -> RawConnection:
# Handle addresses with /p2p/ PeerID suffix by truncating them at /ws
addr_text = str(maddr)
try:
ws_part_index = addr_text.index("/ws")
# Create a new Multiaddr containing only the transport part
transport_maddr = Multiaddr(addr_text[: ws_part_index + 3])
except ValueError:
raise ValueError(
f"WebsocketTransport requires a /ws protocol, not found in {maddr}"
) from None
# Check for /wss, which is not supported yet
if str(transport_maddr).endswith("/wss"):
raise NotImplementedError("/wss (TLS) not yet supported")
"""Dial a WebSocket connection to the given multiaddr."""
logger.debug(f"WebsocketTransport.dial called with {maddr}")
# Extract host and port from multiaddr
host = (
transport_maddr.value_for_protocol("ip4")
or transport_maddr.value_for_protocol("ip6")
or transport_maddr.value_for_protocol("dns")
or transport_maddr.value_for_protocol("dns4")
or transport_maddr.value_for_protocol("dns6")
maddr.value_for_protocol("ip4")
or maddr.value_for_protocol("ip6")
or maddr.value_for_protocol("dns")
or maddr.value_for_protocol("dns4")
or maddr.value_for_protocol("dns6")
)
if host is None:
raise ValueError(f"No host protocol found in {transport_maddr}")
port_str = transport_maddr.value_for_protocol("tcp")
port_str = maddr.value_for_protocol("tcp")
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {transport_maddr}")
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
host_str = f"[{host}]" if ":" in host else host
uri = f"ws://{host_str}:{port}"
# Build WebSocket URL
ws_url = f"ws://{host}:{port}/"
logger.debug(f"WebsocketTransport.dial connecting to {ws_url}")
try:
async with open_websocket_url(uri, ssl_context=None) as ws:
conn = P2PWebSocketConnection(ws.stream) # type: ignore[attr-defined]
return RawConnection(conn, initiator=True)
from trio_websocket import open_websocket_url
# Use the context manager but don't exit it immediately
# The connection will be closed when the RawConnection is closed
ws_context = open_websocket_url(ws_url)
ws = await ws_context.__aenter__()
conn = P2PWebSocketConnection(ws, ws_context) # type: ignore[attr-defined]
return RawConnection(conn, initiator=True)
except Exception as e:
raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
@ -60,4 +58,5 @@ class WebsocketTransport(ITransport):
"""
The type checker is incorrectly reporting this as an inconsistent override.
"""
return WebsocketListener(handler)
logger.debug("WebsocketTransport.create_listener called")
return WebsocketListener(handler, self._upgrader)