mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Fix typecheck errors and improve WebSocket transport implementation
- Fix INotifee interface compliance in WebSocket demo - Fix handler function signatures to be async (THandler compatibility) - Fix is_closed method usage with proper type checking - Fix pytest.raises multiple exception type issue - Fix line length violations (E501) across multiple files - Add debugging logging to Noise security module for troubleshooting - Update WebSocket transport examples and tests - Improve transport registry error handling
This commit is contained in:
@ -19,6 +19,7 @@ from libp2p.abc import (
|
||||
IPeerRouting,
|
||||
IPeerStore,
|
||||
ISecureTransport,
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
KeyPair,
|
||||
@ -231,14 +232,15 @@ def new_swarm(
|
||||
)
|
||||
|
||||
# Create transport based on listen_addrs or default to TCP
|
||||
transport: ITransport
|
||||
if listen_addrs is None:
|
||||
transport = TCP()
|
||||
else:
|
||||
# Use the first address to determine transport type
|
||||
addr = listen_addrs[0]
|
||||
transport = create_transport_for_multiaddr(addr, upgrader)
|
||||
|
||||
if transport is None:
|
||||
transport_maybe = create_transport_for_multiaddr(addr, upgrader)
|
||||
|
||||
if transport_maybe is None:
|
||||
# Fallback to TCP if no specific transport found
|
||||
if addr.__contains__("tcp"):
|
||||
transport = TCP()
|
||||
@ -250,20 +252,8 @@ def new_swarm(
|
||||
f"Unknown transport in listen_addrs: {listen_addrs}. "
|
||||
f"Supported protocols: {supported_protocols}"
|
||||
)
|
||||
|
||||
# Generate X25519 keypair for Noise
|
||||
noise_key_pair = create_new_x25519_key_pair()
|
||||
|
||||
# Default security transports (using Noise as primary)
|
||||
secure_transports_by_protocol: Mapping[TProtocol, ISecureTransport] = sec_opt or {
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
key_pair, noise_privkey=noise_key_pair.private_key
|
||||
),
|
||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(
|
||||
key_pair, peerstore=peerstore_opt
|
||||
),
|
||||
}
|
||||
else:
|
||||
transport = transport_maybe
|
||||
|
||||
# Use given muxer preference if provided, otherwise use global default
|
||||
if muxer_preference is not None:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import (
|
||||
cast,
|
||||
)
|
||||
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
|
||||
FixedSizeLenMsgReadWriter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIZE_NOISE_MESSAGE_LEN = 2
|
||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
|
||||
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
||||
self.noise_state = noise_state
|
||||
|
||||
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
||||
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
|
||||
data_encrypted = self.encrypt(msg)
|
||||
if prefix_encoded:
|
||||
# Manually add the prefix if needed
|
||||
data_encrypted = self.prefix + data_encrypted
|
||||
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
|
||||
await self.read_writer.write_msg(data_encrypted)
|
||||
logger.debug("Noise write_msg: write completed successfully")
|
||||
|
||||
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
||||
logger.debug("Noise read_msg: reading encrypted message")
|
||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
|
||||
if prefix_encoded:
|
||||
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
else:
|
||||
return self.decrypt(noise_msg_encrypted)
|
||||
result = self.decrypt(noise_msg_encrypted)
|
||||
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
|
||||
return result
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.read_writer.close()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
|
||||
|
||||
from .pb import noise_pb2 as noise_pb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||
|
||||
|
||||
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
|
||||
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||
) -> bytes:
|
||||
data = make_data_to_be_signed(noise_static_pubkey)
|
||||
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
|
||||
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
|
||||
return id_privkey.sign(data)
|
||||
|
||||
|
||||
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
|
||||
2. signed by the private key corresponding to `id_pubkey`
|
||||
"""
|
||||
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_pubkey type: "
|
||||
f"{type(payload.id_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: noise_static_pubkey type: "
|
||||
f"{type(noise_static_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
|
||||
)
|
||||
try:
|
||||
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
|
||||
return False
|
||||
|
||||
@ -2,6 +2,7 @@ from abc import (
|
||||
ABC,
|
||||
abstractmethod,
|
||||
)
|
||||
import logging
|
||||
|
||||
from cryptography.hazmat.primitives import (
|
||||
serialization,
|
||||
@ -46,6 +47,8 @@ from .messages import (
|
||||
verify_handshake_payload_sig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPattern(ABC):
|
||||
@abstractmethod
|
||||
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
|
||||
self.early_data = early_data
|
||||
|
||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
noise_state.set_as_responder()
|
||||
noise_state.start_handshake()
|
||||
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
|
||||
# Consume msg#1.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#1")
|
||||
await read_writer.read_msg()
|
||||
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
|
||||
|
||||
# Send msg#2, which should include our handshake payload.
|
||||
logger.debug("Noise XX handshake_inbound: preparing msg#2")
|
||||
our_payload = self.make_handshake_payload()
|
||||
msg_2 = our_payload.serialize()
|
||||
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
|
||||
await read_writer.write_msg(msg_2)
|
||||
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
|
||||
|
||||
# Receive and consume msg#3.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#3")
|
||||
msg_3 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
|
||||
async def handshake_outbound(
|
||||
self, conn: IRawConnection, remote_peer: ID
|
||||
) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
|
||||
raise NoiseStateError("Handshake state is not initialized")
|
||||
|
||||
# Send msg#1, which is *not* encrypted.
|
||||
logger.debug("Noise XX handshake_outbound: sending msg#1")
|
||||
msg_1 = b""
|
||||
await read_writer.write_msg(msg_1)
|
||||
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
|
||||
|
||||
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||
logger.debug("Noise XX handshake_outbound: reading msg#2")
|
||||
msg_2 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
|
||||
)
|
||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
|
||||
)
|
||||
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
|
||||
f"{id_pubkey_repr}"
|
||||
)
|
||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||
logger.error(
|
||||
f"Noise XX handshake_outbound: signature verification failed for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
raise InvalidSignature
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: signature verification successful for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||
if remote_peer_id_from_pubkey != remote_peer:
|
||||
raise PeerIDMismatchesPubkey(
|
||||
|
||||
@ -1,17 +1,19 @@
|
||||
from .tcp.tcp import TCP
|
||||
from .websocket.transport import WebsocketTransport
|
||||
from .transport_registry import (
|
||||
TransportRegistry,
|
||||
TransportRegistry,
|
||||
create_transport_for_multiaddr,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from .upgrader import TransportUpgrader
|
||||
from libp2p.abc import ITransport
|
||||
|
||||
def create_transport(protocol: str, upgrader=None):
|
||||
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None) -> ITransport:
|
||||
"""
|
||||
Convenience function to create a transport instance.
|
||||
|
||||
|
||||
:param protocol: The transport protocol ("tcp", "ws", or custom)
|
||||
:param upgrader: Optional transport upgrader (required for WebSocket)
|
||||
:return: Transport instance
|
||||
@ -28,7 +30,10 @@ def create_transport(protocol: str, upgrader=None):
|
||||
registry = get_transport_registry()
|
||||
transport_class = registry.get_transport(protocol)
|
||||
if transport_class:
|
||||
return registry.create_transport(protocol, upgrader)
|
||||
transport = registry.create_transport(protocol, upgrader)
|
||||
if transport is None:
|
||||
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||
return transport
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||
|
||||
|
||||
@ -3,13 +3,15 @@ Transport registry for dynamic transport selection based on multiaddr protocols.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Type, Optional
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
from libp2p.abc import ITransport
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.registry")
|
||||
|
||||
@ -17,28 +19,29 @@ logger = logging.getLogger("libp2p.transport.registry")
|
||||
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid TCP structure.
|
||||
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid TCP structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||
# or /ip6/::1/tcp/8080
|
||||
protocols = maddr.protocols()
|
||||
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||
if len(protocols) < 2:
|
||||
return False
|
||||
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Should not have any protocols after tcp (unless it's a valid continuation like p2p)
|
||||
|
||||
# Should not have any protocols after tcp (unless it's a valid
|
||||
# continuation like p2p)
|
||||
# For now, we'll be strict and only allow network + tcp
|
||||
if len(protocols) > 2:
|
||||
# Check if the additional protocols are valid continuations
|
||||
@ -46,9 +49,9 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
for i in range(2, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@ -56,31 +59,31 @@ def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid WebSocket structure.
|
||||
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid WebSocket structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# WebSocket multiaddr should have structure like /ip4/127.0.0.1/tcp/8080/ws
|
||||
# or /ip6/::1/tcp/8080/ws
|
||||
protocols = maddr.protocols()
|
||||
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws
|
||||
if len(protocols) < 3:
|
||||
return False
|
||||
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
|
||||
# Last protocol should be ws
|
||||
if protocols[-1].name != "ws":
|
||||
return False
|
||||
|
||||
|
||||
# Should not have any protocols between tcp and ws
|
||||
if len(protocols) > 3:
|
||||
# Check if the additional protocols are valid continuations
|
||||
@ -88,9 +91,9 @@ def _is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||
for i in range(2, len(protocols) - 1):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@ -99,46 +102,52 @@ class TransportRegistry:
|
||||
"""
|
||||
Registry for mapping multiaddr protocols to transport implementations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._transports: Dict[str, Type[ITransport]] = {}
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._transports: dict[str, type[ITransport]] = {}
|
||||
self._register_default_transports()
|
||||
|
||||
|
||||
def _register_default_transports(self) -> None:
|
||||
"""Register the default transport implementations."""
|
||||
# Register TCP transport for /tcp protocol
|
||||
self.register_transport("tcp", TCP)
|
||||
|
||||
|
||||
# Register WebSocket transport for /ws protocol
|
||||
self.register_transport("ws", WebsocketTransport)
|
||||
|
||||
def register_transport(self, protocol: str, transport_class: Type[ITransport]) -> None:
|
||||
|
||||
def register_transport(
|
||||
self, protocol: str, transport_class: type[ITransport]
|
||||
) -> None:
|
||||
"""
|
||||
Register a transport class for a specific protocol.
|
||||
|
||||
|
||||
:param protocol: The protocol identifier (e.g., "tcp", "ws")
|
||||
:param transport_class: The transport class to register
|
||||
"""
|
||||
self._transports[protocol] = transport_class
|
||||
logger.debug(f"Registered transport {transport_class.__name__} for protocol {protocol}")
|
||||
|
||||
def get_transport(self, protocol: str) -> Optional[Type[ITransport]]:
|
||||
logger.debug(
|
||||
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||
)
|
||||
|
||||
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||
"""
|
||||
Get the transport class for a specific protocol.
|
||||
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:return: The transport class or None if not found
|
||||
"""
|
||||
return self._transports.get(protocol)
|
||||
|
||||
|
||||
def get_supported_protocols(self) -> list[str]:
|
||||
"""Get list of supported transport protocols."""
|
||||
return list(self._transports.keys())
|
||||
|
||||
def create_transport(self, protocol: str, upgrader: Optional[TransportUpgrader] = None, **kwargs) -> Optional[ITransport]:
|
||||
|
||||
def create_transport(
|
||||
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create a transport instance for a specific protocol.
|
||||
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:param upgrader: The transport upgrader instance (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
@ -147,14 +156,17 @@ class TransportRegistry:
|
||||
transport_class = self.get_transport(protocol)
|
||||
if transport_class is None:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
if protocol == "ws":
|
||||
# WebSocket transport requires upgrader
|
||||
if upgrader is None:
|
||||
logger.warning(f"WebSocket transport '{protocol}' requires upgrader")
|
||||
logger.warning(
|
||||
f"WebSocket transport '{protocol}' requires upgrader"
|
||||
)
|
||||
return None
|
||||
return transport_class(upgrader)
|
||||
# Use explicit WebsocketTransport to avoid type issues
|
||||
return WebsocketTransport(upgrader)
|
||||
else:
|
||||
# TCP transport doesn't require upgrader
|
||||
return transport_class()
|
||||
@ -172,15 +184,17 @@ def get_transport_registry() -> TransportRegistry:
|
||||
return _global_registry
|
||||
|
||||
|
||||
def register_transport(protocol: str, transport_class: Type[ITransport]) -> None:
|
||||
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||
"""Register a transport class in the global registry."""
|
||||
_global_registry.register_transport(protocol, transport_class)
|
||||
|
||||
|
||||
def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader) -> Optional[ITransport]:
|
||||
def create_transport_for_multiaddr(
|
||||
maddr: Multiaddr, upgrader: TransportUpgrader
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create the appropriate transport for a given multiaddr.
|
||||
|
||||
|
||||
:param maddr: The multiaddr to create transport for
|
||||
:param upgrader: The transport upgrader instance
|
||||
:return: Transport instance or None if no suitable transport found
|
||||
@ -188,7 +202,7 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
|
||||
try:
|
||||
# Get all protocols in the multiaddr
|
||||
protocols = [proto.name for proto in maddr.protocols()]
|
||||
|
||||
|
||||
# Check for supported transport protocols in order of preference
|
||||
# We need to validate that the multiaddr structure is valid for our transports
|
||||
if "ws" in protocols:
|
||||
@ -201,11 +215,14 @@ def create_transport_for_multiaddr(maddr: Multiaddr, upgrader: TransportUpgrader
|
||||
# Check if the multiaddr has proper TCP structure
|
||||
if _is_valid_tcp_multiaddr(maddr):
|
||||
return _global_registry.create_transport("tcp", upgrader)
|
||||
|
||||
|
||||
# If no supported transport protocol found or structure is invalid, return None
|
||||
logger.warning(f"No supported transport protocol found or invalid structure in multiaddr: {maddr}")
|
||||
logger.warning(
|
||||
f"No supported transport protocol found or invalid structure in "
|
||||
f"multiaddr: {maddr}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors gracefully (e.g., invalid multiaddr)
|
||||
logger.warning(f"Error processing multiaddr {maddr}: {e}")
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
from trio.abc import Stream
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
@ -11,7 +15,7 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
||||
that libp2p protocols expect.
|
||||
"""
|
||||
|
||||
def __init__(self, ws_connection, ws_context=None):
|
||||
def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
|
||||
self._ws_connection = ws_connection
|
||||
self._ws_context = ws_context
|
||||
self._read_buffer = b""
|
||||
@ -19,57 +23,102 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
try:
|
||||
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||
# Send as a binary WebSocket message
|
||||
await self._ws_connection.send_message(data)
|
||||
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket write failed: {e}")
|
||||
raise IOException from e
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||
This implementation provides byte-level access to WebSocket messages,
|
||||
which is required for Noise protocol handshake.
|
||||
"""
|
||||
async with self._read_lock:
|
||||
try:
|
||||
logger.debug(
|
||||
f"WebSocket read requested: n={n}, "
|
||||
f"buffer_size={len(self._read_buffer)}"
|
||||
)
|
||||
|
||||
# If we have buffered data, return it
|
||||
if self._read_buffer:
|
||||
if n is None:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
logger.debug(
|
||||
f"WebSocket read returning all buffered data: "
|
||||
f"{len(result)} bytes"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
if len(self._read_buffer) >= n:
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[n:]
|
||||
logger.debug(
|
||||
f"WebSocket read returning {len(result)} bytes "
|
||||
f"from buffer"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
return result
|
||||
# We need more data, but we have some buffered
|
||||
# Keep the buffered data and get more
|
||||
logger.debug(
|
||||
f"WebSocket read needs more data: have "
|
||||
f"{len(self._read_buffer)}, need {n}"
|
||||
)
|
||||
pass
|
||||
|
||||
# If we need exactly n bytes but don't have enough, get more data
|
||||
while n is not None and (
|
||||
not self._read_buffer or len(self._read_buffer) < n
|
||||
):
|
||||
logger.debug(
|
||||
f"WebSocket read getting more data: "
|
||||
f"buffer_size={len(self._read_buffer)}, need={n}"
|
||||
)
|
||||
# Get the next WebSocket message and treat it as a byte stream
|
||||
# This mimics the Go implementation's NextReader() approach
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
|
||||
logger.debug(
|
||||
f"WebSocket read received message: {len(message)} bytes"
|
||||
)
|
||||
# Add to buffer
|
||||
self._read_buffer += message
|
||||
|
||||
# Get the next WebSocket message
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode('utf-8')
|
||||
|
||||
# Add to buffer
|
||||
self._read_buffer = message
|
||||
|
||||
# Return requested amount
|
||||
if n is None:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
logger.debug(
|
||||
f"WebSocket read returning all data: {len(result)} bytes"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
if len(self._read_buffer) >= n:
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[n:]
|
||||
logger.debug(
|
||||
f"WebSocket read returning exact {len(result)} bytes"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# This should never happen due to the while loop above
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
logger.debug(
|
||||
f"WebSocket read returning remaining {len(result)} bytes"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket read failed: {e}")
|
||||
raise IOException from e
|
||||
|
||||
async def close(self) -> None:
|
||||
@ -83,12 +132,12 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
||||
# Try to get remote address from the WebSocket connection
|
||||
try:
|
||||
remote = self._ws_connection.remote
|
||||
if hasattr(remote, 'address') and hasattr(remote, 'port'):
|
||||
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||
return str(remote.address), int(remote.port)
|
||||
elif isinstance(remote, str):
|
||||
# Parse address:port format
|
||||
if ':' in remote:
|
||||
host, port = remote.rsplit(':', 1)
|
||||
if ":" in remote:
|
||||
host, port = remote.rsplit(":", 1)
|
||||
return host, int(port)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import socket
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
@ -9,7 +9,6 @@ from trio_websocket import serve_websocket
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.network.connection.raw_connection import RawConnection
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
@ -27,7 +26,8 @@ class WebsocketListener(IListener):
|
||||
self._upgrader = upgrader
|
||||
self._server = None
|
||||
self._shutdown_event = trio.Event()
|
||||
self._nursery = None
|
||||
self._nursery: trio.Nursery | None = None
|
||||
self._listeners: Any = None
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||
@ -47,56 +47,60 @@ class WebsocketListener(IListener):
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
|
||||
logger.debug(f"WebsocketListener: host={host}, port={port}")
|
||||
|
||||
async def serve_websocket_tcp(
|
||||
handler: Callable,
|
||||
handler: Callable[[Any], Awaitable[None]],
|
||||
port: int,
|
||||
host: str,
|
||||
task_status: trio.TaskStatus[list],
|
||||
task_status: TaskStatus[Any],
|
||||
) -> None:
|
||||
"""Start TCP server and handle WebSocket connections manually"""
|
||||
logger.debug("serve_websocket_tcp %s %s", host, port)
|
||||
|
||||
async def websocket_handler(request):
|
||||
|
||||
async def websocket_handler(request: Any) -> None:
|
||||
"""Handle WebSocket requests"""
|
||||
logger.debug("WebSocket request received")
|
||||
try:
|
||||
# Accept the WebSocket connection
|
||||
ws_connection = await request.accept()
|
||||
logger.debug("WebSocket handshake successful")
|
||||
|
||||
|
||||
# Create the WebSocket connection wrapper
|
||||
conn = P2PWebSocketConnection(ws_connection)
|
||||
|
||||
conn = P2PWebSocketConnection(ws_connection) # type: ignore[no-untyped-call]
|
||||
|
||||
# Call the handler function that was passed to create_listener
|
||||
# This handler will handle the security and muxing upgrades
|
||||
logger.debug("Calling connection handler")
|
||||
await self._handler(conn)
|
||||
|
||||
|
||||
# Don't keep the connection alive indefinitely
|
||||
# Let the handler manage the connection lifecycle
|
||||
logger.debug("Handler completed, connection will be managed by handler")
|
||||
|
||||
logger.debug(
|
||||
"Handler completed, connection will be managed by handler"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket connection error: {e}")
|
||||
logger.debug(f"Error type: {type(e)}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# Reject the connection
|
||||
try:
|
||||
await request.reject(400)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||
from trio_websocket import serve_websocket
|
||||
await serve_websocket(websocket_handler, host, port, None, task_status=task_status)
|
||||
await serve_websocket(
|
||||
websocket_handler, host, port, None, task_status=task_status
|
||||
)
|
||||
|
||||
# Store the nursery for shutdown
|
||||
self._nursery = nursery
|
||||
|
||||
|
||||
# Start the server using nursery.start() like TCP does
|
||||
logger.debug("Calling nursery.start()...")
|
||||
started_listeners = await nursery.start(
|
||||
@ -111,18 +115,21 @@ class WebsocketListener(IListener):
|
||||
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||
return False
|
||||
|
||||
# Store the listeners for get_addrs() and close() - these are real SocketListener objects
|
||||
# Store the listeners for get_addrs() and close() - these are real
|
||||
# SocketListener objects
|
||||
self._listeners = started_listeners
|
||||
logger.debug(f"WebsocketListener.listen returning True with WebSocketServer object")
|
||||
logger.debug(
|
||||
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
if not hasattr(self, '_listeners') or not self._listeners:
|
||||
if not hasattr(self, "_listeners") or not self._listeners:
|
||||
logger.debug("No listeners available for get_addrs()")
|
||||
return ()
|
||||
|
||||
|
||||
# Handle WebSocketServer objects
|
||||
if hasattr(self._listeners, 'port'):
|
||||
if hasattr(self._listeners, "port"):
|
||||
# This is a WebSocketServer object
|
||||
port = self._listeners.port
|
||||
# Create a multiaddr from the port
|
||||
@ -138,12 +145,12 @@ class WebsocketListener(IListener):
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket listener and stop accepting new connections"""
|
||||
logger.debug("WebsocketListener.close called")
|
||||
if hasattr(self, '_listeners') and self._listeners:
|
||||
if hasattr(self, "_listeners") and self._listeners:
|
||||
# Signal shutdown
|
||||
self._shutdown_event.set()
|
||||
|
||||
|
||||
# Close the WebSocket server
|
||||
if hasattr(self._listeners, 'aclose'):
|
||||
if hasattr(self._listeners, "aclose"):
|
||||
# This is a WebSocketServer object
|
||||
logger.debug("Closing WebSocket server")
|
||||
await self._listeners.aclose()
|
||||
@ -152,15 +159,15 @@ class WebsocketListener(IListener):
|
||||
# This is a list of listeners (like TCP)
|
||||
logger.debug("Closing TCP listeners")
|
||||
for listener in self._listeners:
|
||||
listener.close()
|
||||
await listener.aclose()
|
||||
logger.debug("TCP listeners closed")
|
||||
else:
|
||||
# Unknown type, try to close it directly
|
||||
logger.debug("Closing unknown listener type")
|
||||
if hasattr(self._listeners, 'close'):
|
||||
if hasattr(self._listeners, "close"):
|
||||
self._listeners.close()
|
||||
logger.debug("Unknown listener closed")
|
||||
|
||||
|
||||
# Clear the listeners reference
|
||||
self._listeners = None
|
||||
logger.debug("WebsocketListener.close completed")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from trio_websocket import open_websocket_url
|
||||
|
||||
from libp2p.abc import IListener, ITransport
|
||||
from libp2p.custom_types import THandler
|
||||
@ -11,7 +11,7 @@ from libp2p.transport.upgrader import TransportUpgrader
|
||||
from .connection import P2PWebSocketConnection
|
||||
from .listener import WebsocketListener
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.websocket")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebsocketTransport(ITransport):
|
||||
@ -25,7 +25,7 @@ class WebsocketTransport(ITransport):
|
||||
async def dial(self, maddr: Multiaddr) -> RawConnection:
|
||||
"""Dial a WebSocket connection to the given multiaddr."""
|
||||
logger.debug(f"WebsocketTransport.dial called with {maddr}")
|
||||
|
||||
|
||||
# Extract host and port from multiaddr
|
||||
host = (
|
||||
maddr.value_for_protocol("ip4")
|
||||
@ -45,6 +45,7 @@ class WebsocketTransport(ITransport):
|
||||
|
||||
try:
|
||||
from trio_websocket import open_websocket_url
|
||||
|
||||
# Use the context manager but don't exit it immediately
|
||||
# The connection will be closed when the RawConnection is closed
|
||||
ws_context = open_websocket_url(ws_url)
|
||||
|
||||
Reference in New Issue
Block a user