Fix typecheck errors and improve WebSocket transport implementation

- Fix INotifee interface compliance in WebSocket demo
- Fix handler function signatures to be async (THandler compatibility)
- Fix is_closed method usage with proper type checking
- Fix pytest.raises multiple exception type issue
- Fix line length violations (E501) across multiple files
- Add debugging logging to Noise security module for troubleshooting
- Update WebSocket transport examples and tests
- Improve transport registry error handling
This commit is contained in:
acul71
2025-08-11 01:25:49 +02:00
parent 64107b4648
commit fe4c17e8d1
16 changed files with 845 additions and 488 deletions

View File

@ -19,6 +19,7 @@ from libp2p.abc import (
IPeerRouting,
IPeerStore,
ISecureTransport,
ITransport,
)
from libp2p.crypto.keys import (
KeyPair,
@ -231,14 +232,15 @@ def new_swarm(
)
# Create transport based on listen_addrs or default to TCP
transport: ITransport
if listen_addrs is None:
transport = TCP()
else:
# Use the first address to determine transport type
addr = listen_addrs[0]
transport = create_transport_for_multiaddr(addr, upgrader)
if transport is None:
transport_maybe = create_transport_for_multiaddr(addr, upgrader)
if transport_maybe is None:
# Fallback to TCP if no specific transport found
if addr.__contains__("tcp"):
transport = TCP()
@ -250,20 +252,8 @@ def new_swarm(
f"Unknown transport in listen_addrs: {listen_addrs}. "
f"Supported protocols: {supported_protocols}"
)
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
# Default security transports (using Noise as primary)
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or {
NOISE_PROTOCOL_ID: NoiseTransport(
key_pair, noise_privkey=noise_key_pair.private_key
),
TProtocol(secio.ID): secio.Transport(key_pair),
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
key_pair, peerstore=peerstore_opt
),
}
else:
transport = transport_maybe
# Use given muxer preference if provided, otherwise use global default
if muxer_preference is not None:

View File

@ -1,3 +1,4 @@
import logging
from typing import (
cast,
)
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
FixedSizeLenMsgReadWriter,
)
logger = logging.getLogger(__name__)
SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
self.noise_state = noise_state
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
data_encrypted = self.encrypt(msg)
if prefix_encoded:
# Manually add the prefix if needed
data_encrypted = self.prefix + data_encrypted
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
await self.read_writer.write_msg(data_encrypted)
logger.debug("Noise write_msg: write completed successfully")
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
logger.debug("Noise read_msg: reading encrypted message")
noise_msg_encrypted = await self.read_writer.read_msg()
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
if prefix_encoded:
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
else:
return self.decrypt(noise_msg_encrypted)
result = self.decrypt(noise_msg_encrypted)
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
return result
async def close(self) -> None:
await self.read_writer.close()

View File

@ -1,6 +1,7 @@
from dataclasses import (
dataclass,
)
import logging
from libp2p.crypto.keys import (
PrivateKey,
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
from .pb import noise_pb2 as noise_pb
logger = logging.getLogger(__name__)
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
) -> bytes:
data = make_data_to_be_signed(noise_static_pubkey)
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
return id_privkey.sign(data)
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
2. signed by the private key corresponding to `id_pubkey`
"""
expected_data = make_data_to_be_signed(noise_static_pubkey)
return payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(
f"verify_handshake_payload_sig: payload.id_pubkey type: "
f"{type(payload.id_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: noise_static_pubkey type: "
f"{type(noise_static_pubkey)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
)
logger.debug(
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
)
logger.debug(
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
)
try:
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
return result
except Exception as e:
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
return False

View File

@ -2,6 +2,7 @@ from abc import (
ABC,
abstractmethod,
)
import logging
from cryptography.hazmat.primitives import (
serialization,
@ -46,6 +47,8 @@ from .messages import (
verify_handshake_payload_sig,
)
logger = logging.getLogger(__name__)
class IPattern(ABC):
@abstractmethod
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
self.early_data = early_data
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
noise_state = self.create_noise_state()
noise_state.set_as_responder()
noise_state.start_handshake()
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
# Consume msg#1.
logger.debug("Noise XX handshake_inbound: reading msg#1")
await read_writer.read_msg()
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
# Send msg#2, which should include our handshake payload.
logger.debug("Noise XX handshake_inbound: preparing msg#2")
our_payload = self.make_handshake_payload()
msg_2 = our_payload.serialize()
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
await read_writer.write_msg(msg_2)
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
# Receive and consume msg#3.
logger.debug("Noise XX handshake_inbound: reading msg#3")
msg_3 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
if handshake_state.rs is None:
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
async def handshake_outbound(
self, conn: IRawConnection, remote_peer: ID
) -> ISecureConn:
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
noise_state = self.create_noise_state()
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
raise NoiseStateError("Handshake state is not initialized")
# Send msg#1, which is *not* encrypted.
logger.debug("Noise XX handshake_outbound: sending msg#1")
msg_1 = b""
await read_writer.write_msg(msg_1)
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
# Read msg#2 from the remote, which contains the public key of the peer.
logger.debug("Noise XX handshake_outbound: reading msg#2")
msg_2 = await read_writer.read_msg()
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
if handshake_state.rs is None:
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
)
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
logger.debug(
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
)
logger.debug(
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
)
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
logger.debug(
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
f"{id_pubkey_repr}"
)
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
logger.error(
f"Noise XX handshake_outbound: signature verification failed for peer "
f"{remote_peer}"
)
raise InvalidSignature
logger.debug(
f"Noise XX handshake_outbound: signature verification successful for peer "
f"{remote_peer}"
)
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
if remote_peer_id_from_pubkey != remote_peer:
raise PeerIDMismatchesPubkey(

View File

@ -1,17 +1,19 @@
from .tcp.tcp import TCP
from .websocket.transport import WebsocketTransport
from .transport_registry import (
TransportRegistry,
TransportRegistry,
create_transport_for_multiaddr,
get_transport_registry,
register_transport,
get_supported_transport_protocols,
)
from .upgrader import TransportUpgrader
from libp2p.abc import ITransport
def create_transport(protocol: str, upgrader=None):
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None) -> ITransport:
"""
Convenience function to create a transport instance.
:param protocol: The transport protocol ("tcp", "ws", or custom)
:param upgrader: Optional transport upgrader (required for WebSocket)
:return: Transport instance
@ -28,7 +30,10 @@ def create_transport(protocol: str, upgrader=None):
registry = get_transport_registry()
transport_class = registry.get_transport(protocol)
if transport_class:
return registry.create_transport(protocol, upgrader)
transport = registry.create_transport(protocol, upgrader)
if transport is None:
raise ValueError(f"Failed to create transport for protocol: {protocol}")
return transport
else:
raise ValueError(f"Unsupported transport protocol: {protocol}")

View File

@ -3,13 +3,15 @@ Transport registry for dynamic transport selection based on multiaddr protocols.
"""
import logging
from typing import Dict, Type, Optional
from typing import Any
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
from libp2p.abc import ITransport
from libp2p.transport.tcp.tcp import TCP
from libp2p.transport.websocket.transport import WebsocketTransport
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport
logger = logging.getLogger("libp2p.transport.registry")
@ -17,28 +19,29 @@ logger = logging.getLogger("libp2p.transport.registry")
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid TCP structure.
:param maddr: The multiaddr to validate
:return: True if valid TCP structure, False otherwise
"""
try:
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
# or /ip6/::1/tcp/8080
protocols = maddr.protocols()
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 2 protocols: network (ip4/ip6) + tcp
if len(protocols) < 2:
return False
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Should not have any protocols after tcp (unless it's a valid continuation like p2p)
# Should not have any protocols after tcp (unless it's a valid
# continuation like p2p)
# For now, we'll be strict and only allow network + tcp
if len(protocols) > 2:
# Check if the additional protocols are valid continuations
@ -46,9 +49,9 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
for i in range(2, len(protocols)):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False
@ -56,31 +59,31 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
"""
Validate that a multiaddr has a valid WebSocket structure.
:param maddr: The multiaddr to validate
:return: True if valid WebSocket structure, False otherwise
"""
try:
# WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws
# or /ip6/::1/tcp/8080/ws
protocols = maddr.protocols()
protocols: list[Protocol] = list(maddr.protocols())
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws
if len(protocols) < 3:
return False
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
return False
# Second protocol should be tcp
if protocols[1].name != "tcp":
return False
# Last protocol should be ws
if protocols[-1].name != "ws":
return False
# Should not have any protocols between tcp and ws
if len(protocols) > 3:
# Check if the additional protocols are valid continuations
@ -88,9 +91,9 @@ def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
for i in range(2, len(protocols) - 1):
if protocols[i].name not in valid_continuations:
return False
return True
except Exception:
return False
@ -99,46 +102,52 @@ class TransportRegistry:
"""
Registry for mapping multiaddr protocols to transport implementations.
"""
def __init__(self):
self._transports: Dict[str, Type[ITransport]] = {}
def __init__(self) -> None:
self._transports: dict[str, type[ITransport]] = {}
self._register_default_transports()
def _register_default_transports(self) -> None:
"""Register the default transport implementations."""
# Register TCP transport for /tcp protocol
self.register_transport("tcp", TCP)
# Register WebSocket transport for /ws protocol
self.register_transport("ws", WebsocketTransport)
def register_transport(self, protocol: str, transport_class: Type[ITransport]) -> None:
def register_transport(
self, protocol: str, transport_class: type[ITransport]
) -> None:
"""
Register a transport class for a specific protocol.
:param protocol: The protocol identifier (e.g., "tcp", "ws")
:param transport_class: The transport class to register
"""
self._transports[protocol] = transport_class
logger.debug(f"Registered transport {transport_class.__name__} for protocol {protocol}")
def get_transport(self, protocol: str) -> Optional[Type[ITransport]]:
logger.debug(
f"Registered transport {transport_class.__name__} for protocol {protocol}"
)
def get_transport(self, protocol: str) -> type[ITransport] | None:
"""
Get the transport class for a specific protocol.
:param protocol: The protocol identifier
:return: The transport class or None if not found
"""
return self._transports.get(protocol)
def get_supported_protocols(self) -> list[str]:
"""Get list of supported transport protocols."""
return list(self._transports.keys())
def create_transport(self, protocol: str, upgrader: Optional[TransportUpgrader] = None, **kwargs) -> Optional[ITransport]:
def create_transport(
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
) -> ITransport | None:
"""
Create a transport instance for a specific protocol.
:param protocol: The protocol identifier
:param upgrader: The transport upgrader instance (required for WebSocket)
:param kwargs: Additional arguments for transport construction
@ -147,14 +156,17 @@ class TransportRegistry:
transport_class = self.get_transport(protocol)
if transport_class is None:
return None
try:
if protocol == "ws":
# WebSocket transport requires upgrader
if upgrader is None:
logger.warning(f"WebSocket transport '{protocol}' requires upgrader")
logger.warning(
f"WebSocket transport '{protocol}' requires upgrader"
)
return None
return transport_class(upgrader)
# Use explicit WebsocketTransport to avoid type issues
return WebsocketTransport(upgrader)
else:
# TCP transport doesn't require upgrader
return transport_class()
@ -172,15 +184,17 @@ def get_transport_registry() -> TransportRegistry:
return _global_registry
def register_transport(protocol: str, transport_class: Type[ITransport]) -> None:
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
"""Register a transport class in the global registry."""
_global_registry.register_transport(protocol, transport_class)
def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader) -> Optional[ITransport]:
def create_transport_for_multiaddr(
maddr: Multiaddr, upgrader: TransportUpgrader
) -> ITransport | None:
"""
Create the appropriate transport for a given multiaddr.
:param maddr: The multiaddr to create transport for
:param upgrader: The transport upgrader instance
:return: Transport instance or None if no suitable transport found
@ -188,7 +202,7 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
try:
# Get all protocols in the multiaddr
protocols = [proto.name for proto in maddr.protocols()]
# Check for supported transport protocols in order of preference
# We need to validate that the multiaddr structure is valid for our transports
if "ws" in protocols:
@ -201,11 +215,14 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
# Check if the multiaddr has proper TCP structure
if _is_valid_tcp_multiaddr(maddr):
return _global_registry.create_transport("tcp", upgrader)
# If no supported transport protocol found or structure is invalid, return None
logger.warning(f"No supported transport protocol found or invalid structure in multiaddr: {maddr}")
logger.warning(
f"No supported transport protocol found or invalid structure in "
f"multiaddr: {maddr}"
)
return None
except Exception as e:
# Handle any errors gracefully (e.g., invalid multiaddr)
logger.warning(f"Error processing multiaddr {maddr}: {e}")

View File

@ -1,9 +1,13 @@
from trio.abc import Stream
import logging
from typing import Any
import trio
from libp2p.io.abc import ReadWriteCloser
from libp2p.io.exceptions import IOException
logger = logging.getLogger(__name__)
class P2PWebSocketConnection(ReadWriteCloser):
"""
@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser):
that libp2p protocols expect.
"""
def __init__(self, ws_connection, ws_context=None):
def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
self._ws_connection = ws_connection
self._ws_context = ws_context
self._read_buffer = b""
@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser):
async def write(self, data: bytes) -> None:
try:
logger.debug(f"WebSocket writing {len(data)} bytes")
# Send as a binary WebSocket message
await self._ws_connection.send_message(data)
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
except Exception as e:
logger.error(f"WebSocket write failed: {e}")
raise IOException from e
async def read(self, n: int | None = None) -> bytes:
"""
Read up to n bytes (if n is given), else read up to 64KiB.
This implementation provides byte-level access to WebSocket messages,
which is required for Noise protocol handshake.
"""
async with self._read_lock:
try:
logger.debug(
f"WebSocket read requested: n={n}, "
f"buffer_size={len(self._read_buffer)}"
)
# If we have buffered data, return it
if self._read_buffer:
if n is None:
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning all buffered data: "
f"{len(result)} bytes"
)
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning {len(result)} bytes "
f"from buffer"
)
return result
else:
result = self._read_buffer
self._read_buffer = b""
return result
# We need more data, but we have some buffered
# Keep the buffered data and get more
logger.debug(
f"WebSocket read needs more data: have "
f"{len(self._read_buffer)}, need {n}"
)
pass
# If we need exactly n bytes but don't have enough, get more data
while n is not None and (
not self._read_buffer or len(self._read_buffer) < n
):
logger.debug(
f"WebSocket read getting more data: "
f"buffer_size={len(self._read_buffer)}, need={n}"
)
# Get the next WebSocket message and treat it as a byte stream
# This mimics the Go implementation's NextReader() approach
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode("utf-8")
logger.debug(
f"WebSocket read received message: {len(message)} bytes"
)
# Add to buffer
self._read_buffer += message
# Get the next WebSocket message
message = await self._ws_connection.get_message()
if isinstance(message, str):
message = message.encode('utf-8')
# Add to buffer
self._read_buffer = message
# Return requested amount
if n is None:
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning all data: {len(result)} bytes"
)
return result
else:
if len(self._read_buffer) >= n:
result = self._read_buffer[:n]
self._read_buffer = self._read_buffer[n:]
logger.debug(
f"WebSocket read returning exact {len(result)} bytes"
)
return result
else:
# This should never happen due to the while loop above
result = self._read_buffer
self._read_buffer = b""
logger.debug(
f"WebSocket read returning remaining {len(result)} bytes"
)
return result
except Exception as e:
logger.error(f"WebSocket read failed: {e}")
raise IOException from e
async def close(self) -> None:
@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser):
# Try to get remote address from the WebSocket connection
try:
remote = self._ws_connection.remote
if hasattr(remote, 'address') and hasattr(remote, 'port'):
if hasattr(remote, "address") and hasattr(remote, "port"):
return str(remote.address), int(remote.port)
elif isinstance(remote, str):
# Parse address:port format
if ':' in remote:
host, port = remote.rsplit(':', 1)
if ":" in remote:
host, port = remote.rsplit(":", 1)
return host, int(port)
except Exception:
pass

View File

@ -1,6 +1,6 @@
from collections.abc import Awaitable, Callable
import logging
import socket
from typing import Any, Callable
from typing import Any
from multiaddr import Multiaddr
import trio
@ -9,7 +9,6 @@ from trio_websocket import serve_websocket
from libp2p.abc import IListener
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
@ -27,7 +26,8 @@ class WebsocketListener(IListener):
self._upgrader = upgrader
self._server = None
self._shutdown_event = trio.Event()
self._nursery = None
self._nursery: trio.Nursery | None = None
self._listeners: Any = None
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
logger.debug(f"WebsocketListener.listen called with {maddr}")
@ -47,56 +47,60 @@ class WebsocketListener(IListener):
if port_str is None:
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
port = int(port_str)
logger.debug(f"WebsocketListener: host={host}, port={port}")
async def serve_websocket_tcp(
handler: Callable,
handler: Callable[[Any], Awaitable[None]],
port: int,
host: str,
task_status: trio.TaskStatus[list],
task_status: TaskStatus[Any],
) -> None:
"""Start TCP server and handle WebSocket connections manually"""
logger.debug("serve_websocket_tcp %s %s", host, port)
async def websocket_handler(request):
async def websocket_handler(request: Any) -> None:
"""Handle WebSocket requests"""
logger.debug("WebSocket request received")
try:
# Accept the WebSocket connection
ws_connection = await request.accept()
logger.debug("WebSocket handshake successful")
# Create the WebSocket connection wrapper
conn = P2PWebSocketConnection(ws_connection)
conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call]
# Call the handler function that was passed to create_listener
# This handler will handle the security and muxing upgrades
logger.debug("Calling connection handler")
await self._handler(conn)
# Don't keep the connection alive indefinitely
# Let the handler manage the connection lifecycle
logger.debug("Handler completed, connection will be managed by handler")
logger.debug(
"Handler completed, connection will be managed by handler"
)
except Exception as e:
logger.debug(f"WebSocket connection error: {e}")
logger.debug(f"Error type: {type(e)}")
import traceback
logger.debug(f"Traceback: {traceback.format_exc()}")
# Reject the connection
try:
await request.reject(400)
except:
except Exception:
pass
# Use trio_websocket.serve_websocket for proper WebSocket handling
from trio_websocket import serve_websocket
await serve_websocket(websocket_handler, host, port, None, task_status=task_status)
await serve_websocket(
websocket_handler, host, port, None, task_status=task_status
)
# Store the nursery for shutdown
self._nursery = nursery
# Start the server using nursery.start() like TCP does
logger.debug("Calling nursery.start()...")
started_listeners = await nursery.start(
@ -111,18 +115,21 @@ class WebsocketListener(IListener):
logger.error(f"Failed to start WebSocket listener for {maddr}")
return False
# Store the listeners for get_addrs() and close() - these are real SocketListener objects
# Store the listeners for get_addrs() and close() - these are real
# SocketListener objects
self._listeners = started_listeners
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object")
logger.debug(
"WebsocketListener.listen returning True with WebSocketServer object"
)
return True
def get_addrs(self) -> tuple[Multiaddr, ...]:
if not hasattr(self, '_listeners') or not self._listeners:
if not hasattr(self, "_listeners") or not self._listeners:
logger.debug("No listeners available for get_addrs()")
return ()
# Handle WebSocketServer objects
if hasattr(self._listeners, 'port'):
if hasattr(self._listeners, "port"):
# This is a WebSocketServer object
port = self._listeners.port
# Create a multiaddr from the port
@ -138,12 +145,12 @@ class WebsocketListener(IListener):
async def close(self) -> None:
"""Close the WebSocket listener and stop accepting new connections"""
logger.debug("WebsocketListener.close called")
if hasattr(self, '_listeners') and self._listeners:
if hasattr(self, "_listeners") and self._listeners:
# Signal shutdown
self._shutdown_event.set()
# Close the WebSocket server
if hasattr(self._listeners, 'aclose'):
if hasattr(self._listeners, "aclose"):
# This is a WebSocketServer object
logger.debug("Closing WebSocket server")
await self._listeners.aclose()
@ -152,15 +159,15 @@ class WebsocketListener(IListener):
# This is a list of listeners (like TCP)
logger.debug("Closing TCP listeners")
for listener in self._listeners:
listener.close()
await listener.aclose()
logger.debug("TCP listeners closed")
else:
# Unknown type, try to close it directly
logger.debug("Closing unknown listener type")
if hasattr(self._listeners, 'close'):
if hasattr(self._listeners, "close"):
self._listeners.close()
logger.debug("Unknown listener closed")
# Clear the listeners reference
self._listeners = None
logger.debug("WebsocketListener.close completed")

View File

@ -1,6 +1,6 @@
import logging
from multiaddr import Multiaddr
from trio_websocket import open_websocket_url
from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader
from .connection import P2PWebSocketConnection
from .listener import WebsocketListener
logger = logging.getLogger("libp2p.transport.websocket")
logger = logging.getLogger(__name__)
class WebsocketTransport(ITransport):
@ -25,7 +25,7 @@ class WebsocketTransport(ITransport):
async def dial(self, maddr: Multiaddr) -> RawConnection:
"""Dial a WebSocket connection to the given multiaddr."""
logger.debug(f"WebsocketTransport.dial called with {maddr}")
# Extract host and port from multiaddr
host = (
maddr.value_for_protocol("ip4")
@ -45,6 +45,7 @@ class WebsocketTransport(ITransport):
try:
from trio_websocket import open_websocket_url
# Use the context manager but don't exit it immediately
# The connection will be closed when the RawConnection is closed
ws_context = open_websocket_url(ws_url)