Fix typecheck errors and improve WebSocket transport implementation

- Fix INotifee interface compliance in WebSocket demo
- Fix handler function signatures to be async (THandler compatibility)
- Fix is_closed method usage with proper type checking
- Fix pytest.raises multiple exception type issue
- Fix line length violations (E501) across multiple files
- Add debugging logging to Noise security module for troubleshooting
- Update WebSocket transport examples and tests
- Improve transport registry error handling
This commit is contained in:
acul71
2025-08-11 01:25:49 +02:00
parent 64107b4648
commit fe4c17e8d1
16 changed files with 845 additions and 488 deletions

View File

@ -1,9 +1,13 @@
from trio.abc import Stream
import logging
from typing import Any
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger(__name__)
class P2PWebSocketConnection(ReadWriteCloser):
"""
@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser):
that libp2p protocols expect.
"""
def __init__(self, ws_connection, ws_context=None):
def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
self._ws_connection = ws_connection
self._ws_context = ws_context
self._read_buffer = b""
@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser):
async def write(self, data: bytes) -> None:
try:
logger.debug(f"WebSocket writing {len(data)} bytes")
# Send as a binary WebSocket message
await self._ws_connection.send_message(data)
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
except Exception as e:
logger.error(f"WebSocket write failed: {e}")
raise IOException from e
async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes (if n is given), else read up to 64KiB.
This implementation provides byte-level access to WebSocket messages,
which is required for Noise protocol handshake.
"""
async with self._read_lock:
try:
logger.debug(
f"WebSocket read requested: n={n}, "
f"buffer_size={len(self._read_buffer)}"
)
# If we have buffered data, return it
if self._read_buffer:
if n is None:
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning all buffered data: "
f"{len(result)} bytes"
)
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning {len(result)} bytes "
f"from buffer"
)
return result
else:
result = self._read_buffer
self._read_buffer = b""
return result
# We need more data, but we have some buffered
# Keep the buffered data and get more
logger.debug(
f"WebSocket read needs more data: have "
f"{len(self._read_buffer)}, need {n}"
)
pass
# If we need exactly n bytes but don't have enough, get more data
while n is not None and (
not self._read_buffer or len(self._read_buffer) < n
):
logger.debug(
f"WebSocket read getting more data: "
f"buffer_size={len(self._read_buffer)}, need={n}"
)
# Get the next WebSocket message and treat it as a byte stream
# This mimics the Go implementation's NextReader() approach
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
logger.debug(
f"WebSocket read received message: {len(message)} bytes"
)
# Add to buffer
self._read_buffer += message
# Get the next WebSocket message
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode('utf-8')
# Add to buffer
self._read_buffer = message
# Return requested amount
if n is None:
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning all data: {len(result)} bytes"
)
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning exact {len(result)} bytes"
)
return result
else:
# This should never happen due to the while loop above
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning remaining {len(result)} bytes"
)
return result
except Exception as e:
logger.error(f"WebSocket read failed: {e}")
raise IOException from e
async def close(self) -> None:
@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser):
# Try to get remote address from the WebSocket connection
try:
remote = self._ws_connection.remote
if hasattr(remote, 'address') and hasattr(remote, 'port'):
if hasattr(remote, "address") and hasattr(remote, "port"):
return str(remote.address), int(remote.port)
elif isinstance(remote, str):
# Parse address:port format
if ':' in remote:
host, port = remote.rsplit(':', 1)
if ":" in remote:
host, port = remote.rsplit(":", 1)
return host, int(port)
except Exception:
pass

View File

@ -1,6 +1,6 @@
from collections.abc import Awaitable, Callable
import logging
import socket
from typing import Any, Callable
from typing import Any
from multiaddr import Multiaddr
import trio
@ -9,7 +9,6 @@ from trio_websocket import serve_websocket
from libp2p.abc import IListener
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
@ -27,7 +26,8 @@ class WebsocketListener(IListener):
self._upgrader = upgrader
self._server = None
self._shutdown_event = trio.Event()
self._nursery = None
self._nursery: trio.Nursery | None = None
self._listeners: Any = None
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}")
@ -47,56 +47,60 @@ class WebsocketListener(IListener):
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
logger.debug(f"WebsocketListener: host={host}, port={port}")
async def serve_websocket_tcp(
handler: Callable,
handler: Callable[[Any], Awaitable[None]],
port: int,
host: str,
task_status: trio.TaskStatus[list],
task_status: TaskStatus[Any],
) -> None:
"""Start TCP server and handle WebSocket connections manually"""
logger.debug("serve_websocket_tcp %s %s", host, port)
async def websocket_handler(request):
async def websocket_handler(request: Any) -> None:
"""Handle WebSocket requests"""
logger.debug("WebSocket request received")
try:
# Accept the WebSocket connection
ws_connection = await request.accept()
logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(ws_connection)
conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call]
# Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades
logger.debug("Calling connection handler")
await self._handler(conn)
# Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle
logger.debug("Handler completed, connection will be managed by handler")
logger.debug(
"Handler completed, connection will be managed by handler"
)
except Exception as e:
logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}")
import traceback
logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection
try:
await request.reject(400)
except:
except Exception:
pass
# Use trio_websocket.serve_websocket for proper WebSocket handling
from trio_websocket import serve_websocket
await serve_websocket(websocket_handler, host, port, None, task_status=task_status)
await serve_websocket(
websocket_handler, host, port, None, task_status=task_status
)
# Store the nursery for shutdown
self._nursery = nursery
# Start the server using nursery.start() like TCP does
logger.debug("Calling nursery.start()...")
started_listeners = await nursery.start(
@ -111,18 +115,21 @@ class WebsocketListener(IListener):
logger.error(f"Failed to start WebSocket listener for {maddr}")
return False
# Store the listeners for get_addrs() and close() - these are real SocketListener objects
# Store the listeners for get_addrs() and close() - these are real
# SocketListener objects
self._listeners = started_listeners
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object")
logger.debug(
"WebsocketListener.listen returning True with WebSocketServer object"
)
return True
def get_addrs(self) -> tuple[Multiaddr, ...]:
if not hasattr(self, '_listeners') or not self._listeners:
if not hasattr(self, "_listeners") or not self._listeners:
logger.debug("No listeners available for get_addrs()")
return ()
# Handle WebSocketServer objects
if hasattr(self._listeners, 'port'):
if hasattr(self._listeners, "port"):
# This is a WebSocketServer object
port = self._listeners.port
# Create a multiaddr from the port
@ -138,12 +145,12 @@ class WebsocketListener(IListener):
async def close(self) -> None:
"""Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called")
if hasattr(self, '_listeners') and self._listeners:
if hasattr(self, "_listeners") and self._listeners:
# Signal shutdown
self._shutdown_event.set()
# Close the WebSocket server
if hasattr(self._listeners, 'aclose'):
if hasattr(self._listeners, "aclose"):
# This is a WebSocketServer object
logger.debug("Closing WebSocket server")
await self._listeners.aclose()
@ -152,15 +159,15 @@ class WebsocketListener(IListener):
# This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners")
for listener in self._listeners:
listener.close()
await listener.aclose()
logger.debug("TCP listeners closed")
else:
# Unknown type, try to close it directly
logger.debug("Closing unknown listener type")
if hasattr(self._listeners, 'close'):
if hasattr(self._listeners, "close"):
self._listeners.close()
logger.debug("Unknown listener closed")
# Clear the listeners reference
self._listeners = None
logger.debug("WebsocketListener.close completed")

View File

@ -1,6 +1,6 @@
import logging
from multiaddr import Multiaddr
from trio_websocket import open_websocket_url
from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
from .listener import WebsocketListener
logger = logging.getLogger("libp2p.transport.websocket")
logger = logging.getLogger(__name__)
class WebsocketTransport(ITransport):
@ -25,7 +25,7 @@ class WebsocketTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> RawConnection:
"""Dial a WebSocket connection to the given multiaddr."""
logger.debug(f"WebsocketTransport.dial called with {maddr}")
# Extract host and port from multiaddr
host = (
maddr.value_for_protocol("ip4")
@ -45,6 +45,7 @@ class WebsocketTransport(ITransport):
try:
from trio_websocket import open_websocket_url
# Use the context manager but don't exit it immediately
# The connection will be closed when the RawConnection is closed
ws_context = open_websocket_url(ws_url)