mirror of https://github.com/varun-r-mallya/py-libp2p.git (synced 2025-12-31 20:36:24 +00:00)
Merge branch 'main' into refactor/replace-magic-numbers-with-named-constants
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
"""Libp2p Python implementation."""
|
||||
|
||||
import logging
|
||||
import ssl
|
||||
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
from typing import Any
|
||||
@ -24,6 +25,7 @@ from libp2p.abc import (
|
||||
IPeerRouting,
|
||||
IPeerStore,
|
||||
ISecureTransport,
|
||||
ITransport,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
KeyPair,
|
||||
@ -80,6 +82,10 @@ from libp2p.transport.tcp.tcp import (
|
||||
from libp2p.transport.upgrader import (
|
||||
TransportUpgrader,
|
||||
)
|
||||
from libp2p.transport.transport_registry import (
|
||||
create_transport_for_multiaddr,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from libp2p.utils.logging import (
|
||||
setup_logging,
|
||||
)
|
||||
@ -174,7 +180,10 @@ def new_swarm(
|
||||
enable_quic: bool = False,
|
||||
retry_config: Optional["RetryConfig"] = None,
|
||||
connection_config: ConnectionConfig | QUICTransportConfig | None = None,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
) -> INetworkService:
|
||||
logger.debug(f"new_swarm: enable_quic={enable_quic}, listen_addrs={listen_addrs}")
|
||||
"""
|
||||
Create a swarm instance based on the parameters.
|
||||
|
||||
@ -198,7 +207,7 @@ def new_swarm(
|
||||
|
||||
id_opt = generate_peer_id_from(key_pair)
|
||||
|
||||
transport: TCP | QUICTransport
|
||||
transport: TCP | QUICTransport | ITransport
|
||||
quic_transport_opt = connection_config if isinstance(connection_config, QUICTransportConfig) else None
|
||||
|
||||
if listen_addrs is None:
|
||||
@ -207,14 +216,39 @@ def new_swarm(
|
||||
else:
|
||||
transport = TCP()
|
||||
else:
|
||||
# Use transport registry to select the appropriate transport
|
||||
from libp2p.transport.transport_registry import create_transport_for_multiaddr
|
||||
|
||||
# Create a temporary upgrader for transport selection
|
||||
# We'll create the real upgrader later with the proper configuration
|
||||
temp_upgrader = TransportUpgrader(
|
||||
secure_transports_by_protocol={},
|
||||
muxer_transports_by_protocol={}
|
||||
)
|
||||
|
||||
addr = listen_addrs[0]
|
||||
is_quic = is_quic_multiaddr(addr)
|
||||
if addr.__contains__("tcp"):
|
||||
transport = TCP()
|
||||
elif is_quic:
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
else:
|
||||
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
|
||||
logger.debug(f"new_swarm: Creating transport for address: {addr}")
|
||||
transport_maybe = create_transport_for_multiaddr(
|
||||
addr,
|
||||
temp_upgrader,
|
||||
private_key=key_pair.private_key,
|
||||
config=quic_transport_opt,
|
||||
tls_client_config=tls_client_config,
|
||||
tls_server_config=tls_server_config
|
||||
)
|
||||
|
||||
if transport_maybe is None:
|
||||
raise ValueError(f"Unsupported transport for listen_addrs: {listen_addrs}")
|
||||
|
||||
transport = transport_maybe
|
||||
logger.debug(f"new_swarm: Created transport: {type(transport)}")
|
||||
|
||||
# If enable_quic is True but we didn't get a QUIC transport, force QUIC
|
||||
if enable_quic and not isinstance(transport, QUICTransport):
|
||||
logger.debug(f"new_swarm: Forcing QUIC transport (enable_quic=True but got {type(transport)})")
|
||||
transport = QUICTransport(key_pair.private_key, config=quic_transport_opt)
|
||||
|
||||
logger.debug(f"new_swarm: Final transport type: {type(transport)}")
|
||||
|
||||
# Generate X25519 keypair for Noise
|
||||
noise_key_pair = create_new_x25519_key_pair()
|
||||
@ -255,6 +289,7 @@ def new_swarm(
|
||||
muxer_transports_by_protocol=muxer_transports_by_protocol,
|
||||
)
|
||||
|
||||
|
||||
peerstore = peerstore_opt or PeerStore()
|
||||
# Store our key pair in peerstore
|
||||
peerstore.add_key_pair(id_opt, key_pair)
|
||||
@ -282,6 +317,8 @@ def new_host(
|
||||
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
|
||||
enable_quic: bool = False,
|
||||
quic_transport_opt: QUICTransportConfig | None = None,
|
||||
tls_client_config: ssl.SSLContext | None = None,
|
||||
tls_server_config: ssl.SSLContext | None = None,
|
||||
) -> IHost:
|
||||
"""
|
||||
Create a new libp2p host based on the given parameters.
|
||||
@ -296,7 +333,9 @@ def new_host(
|
||||
:param enable_mDNS: whether to enable mDNS discovery
|
||||
:param bootstrap: optional list of bootstrap peer addresses as strings
|
||||
:param enable_quic: optional choice to use QUIC for transport
|
||||
:param transport_opt: optional configuration for quic transport
|
||||
:param quic_transport_opt: optional configuration for quic transport
|
||||
:param tls_client_config: optional TLS client configuration for WebSocket transport
|
||||
:param tls_server_config: optional TLS server configuration for WebSocket transport
|
||||
:return: return a host instance
|
||||
"""
|
||||
|
||||
@ -311,7 +350,9 @@ def new_host(
|
||||
peerstore_opt=peerstore_opt,
|
||||
muxer_preference=muxer_preference,
|
||||
listen_addrs=listen_addrs,
|
||||
connection_config=quic_transport_opt if enable_quic else None
|
||||
connection_config=quic_transport_opt if enable_quic else None,
|
||||
tls_client_config=tls_client_config,
|
||||
tls_server_config=tls_server_config
|
||||
)
|
||||
|
||||
if disc_opt is not None:
|
||||
|
||||
@ -481,13 +481,16 @@ class Swarm(Service, INetworkService):
|
||||
- Call listener listen with the multiaddr
|
||||
- Map multiaddr to listener
|
||||
"""
|
||||
logger.debug(f"Swarm.listen called with multiaddrs: {multiaddrs}")
|
||||
# We need to wait until `self.listener_nursery` is created.
|
||||
logger.debug("Starting to listen")
|
||||
await self.event_listener_nursery_created.wait()
|
||||
|
||||
success_count = 0
|
||||
for maddr in multiaddrs:
|
||||
logger.debug(f"Swarm.listen processing multiaddr: {maddr}")
|
||||
if str(maddr) in self.listeners:
|
||||
logger.debug(f"Swarm.listen: listener already exists for {maddr}")
|
||||
success_count += 1
|
||||
continue
|
||||
|
||||
@ -545,13 +548,18 @@ class Swarm(Service, INetworkService):
|
||||
|
||||
try:
|
||||
# Success
|
||||
logger.debug(f"Swarm.listen: creating listener for {maddr}")
|
||||
listener = self.transport.create_listener(conn_handler)
|
||||
logger.debug(f"Swarm.listen: listener created for {maddr}")
|
||||
self.listeners[str(maddr)] = listener
|
||||
# TODO: `listener.listen` is not bounded with nursery. If we want to be
|
||||
# I/O agnostic, we should change the API.
|
||||
if self.listener_nursery is None:
|
||||
raise SwarmException("swarm instance hasn't been run")
|
||||
assert self.listener_nursery is not None # For type checker
|
||||
logger.debug(f"Swarm.listen: calling listener.listen for {maddr}")
|
||||
await listener.listen(maddr, self.listener_nursery)
|
||||
logger.debug(f"Swarm.listen: listener.listen completed for {maddr}")
|
||||
|
||||
# Call notifiers since event occurred
|
||||
await self.notify_listen(maddr)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import (
|
||||
cast,
|
||||
)
|
||||
@ -15,6 +16,8 @@ from libp2p.io.msgio import (
|
||||
FixedSizeLenMsgReadWriter,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIZE_NOISE_MESSAGE_LEN = 2
|
||||
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
|
||||
SIZE_NOISE_MESSAGE_BODY_LEN = 2
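# Quick arithmetic check of the constants above (a sketch, not part of the diff):
# a 2-byte length prefix caps an encrypted noise message at 65535 bytes.
assert MAX_NOISE_MESSAGE_LEN == 2 ** 16 - 1 == 65535
# A length prefix of that width, assuming big-endian encoding, would look like:
example_prefix = (1234).to_bytes(SIZE_NOISE_MESSAGE_LEN, "big")  # b'\x04\xd2'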
|
||||
@ -50,18 +53,25 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
|
||||
self.noise_state = noise_state
|
||||
|
||||
async def write_msg(self, msg: bytes, prefix_encoded: bool = False) -> None:
|
||||
logger.debug(f"Noise write_msg: encrypting {len(msg)} bytes")
|
||||
data_encrypted = self.encrypt(msg)
|
||||
if prefix_encoded:
|
||||
# Manually add the prefix if needed
|
||||
data_encrypted = self.prefix + data_encrypted
|
||||
logger.debug(f"Noise write_msg: writing {len(data_encrypted)} encrypted bytes")
|
||||
await self.read_writer.write_msg(data_encrypted)
|
||||
logger.debug("Noise write_msg: write completed successfully")
|
||||
|
||||
async def read_msg(self, prefix_encoded: bool = False) -> bytes:
|
||||
logger.debug("Noise read_msg: reading encrypted message")
|
||||
noise_msg_encrypted = await self.read_writer.read_msg()
|
||||
logger.debug(f"Noise read_msg: read {len(noise_msg_encrypted)} encrypted bytes")
|
||||
if prefix_encoded:
|
||||
return self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
result = self.decrypt(noise_msg_encrypted[len(self.prefix) :])
|
||||
else:
|
||||
return self.decrypt(noise_msg_encrypted)
|
||||
result = self.decrypt(noise_msg_encrypted)
|
||||
logger.debug(f"Noise read_msg: decrypted to {len(result)} bytes")
|
||||
return result
|
||||
|
||||
async def close(self) -> None:
|
||||
await self.read_writer.close()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import logging
|
||||
|
||||
from libp2p.crypto.keys import (
|
||||
PrivateKey,
|
||||
@ -12,6 +13,8 @@ from libp2p.crypto.serialization import (
|
||||
|
||||
from .pb import noise_pb2 as noise_pb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SIGNED_DATA_PREFIX = "noise-libp2p-static-key:"
|
||||
|
||||
|
||||
@ -48,6 +51,8 @@ def make_handshake_payload_sig(
|
||||
id_privkey: PrivateKey, noise_static_pubkey: PublicKey
|
||||
) -> bytes:
|
||||
data = make_data_to_be_signed(noise_static_pubkey)
|
||||
logger.debug(f"make_handshake_payload_sig: signing data length: {len(data)}")
|
||||
logger.debug(f"make_handshake_payload_sig: signing data hex: {data.hex()}")
|
||||
return id_privkey.sign(data)
|
||||
|
||||
|
||||
@ -60,4 +65,27 @@ def verify_handshake_payload_sig(
|
||||
2. signed by the private key corresponding to `id_pubkey`
|
||||
"""
|
||||
expected_data = make_data_to_be_signed(noise_static_pubkey)
|
||||
return payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_pubkey type: "
|
||||
f"{type(payload.id_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: noise_static_pubkey type: "
|
||||
f"{type(noise_static_pubkey)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data length: {len(expected_data)}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: expected_data hex: {expected_data.hex()}"
|
||||
)
|
||||
logger.debug(
|
||||
f"verify_handshake_payload_sig: payload.id_sig length: {len(payload.id_sig)}"
|
||||
)
|
||||
try:
|
||||
result = payload.id_pubkey.verify(expected_data, payload.id_sig)
|
||||
logger.debug(f"verify_handshake_payload_sig: verification result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"verify_handshake_payload_sig: verification exception: {e}")
|
||||
return False
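# Hedged sketch of how the two helpers above pair up: both operate on the same
# derived bytes from make_data_to_be_signed(). The key objects (id_privkey,
# id_pubkey, noise_static_pubkey) are assumed to come from libp2p.crypto.
data = make_data_to_be_signed(noise_static_pubkey)
sig = id_privkey.sign(data)       # what make_handshake_payload_sig returns
ok = id_pubkey.verify(data, sig)  # what verify_handshake_payload_sig checks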
|
||||
|
||||
@ -2,6 +2,7 @@ from abc import (
|
||||
ABC,
|
||||
abstractmethod,
|
||||
)
|
||||
import logging
|
||||
|
||||
from cryptography.hazmat.primitives import (
|
||||
serialization,
|
||||
@ -46,6 +47,8 @@ from .messages import (
|
||||
verify_handshake_payload_sig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IPattern(ABC):
|
||||
@abstractmethod
|
||||
@ -95,6 +98,7 @@ class PatternXX(BasePattern):
|
||||
self.early_data = early_data
|
||||
|
||||
async def handshake_inbound(self, conn: IRawConnection) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_inbound started for peer {self.local_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
noise_state.set_as_responder()
|
||||
noise_state.start_handshake()
|
||||
@ -107,15 +111,22 @@ class PatternXX(BasePattern):
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
|
||||
# Consume msg#1.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#1")
|
||||
await read_writer.read_msg()
|
||||
logger.debug("Noise XX handshake_inbound: read msg#1 successfully")
|
||||
|
||||
# Send msg#2, which should include our handshake payload.
|
||||
logger.debug("Noise XX handshake_inbound: preparing msg#2")
|
||||
our_payload = self.make_handshake_payload()
|
||||
msg_2 = our_payload.serialize()
|
||||
logger.debug(f"Noise XX handshake_inbound: sending msg#2 ({len(msg_2)} bytes)")
|
||||
await read_writer.write_msg(msg_2)
|
||||
logger.debug("Noise XX handshake_inbound: sent msg#2 successfully")
|
||||
|
||||
# Receive and consume msg#3.
|
||||
logger.debug("Noise XX handshake_inbound: reading msg#3")
|
||||
msg_3 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_inbound: read msg#3 ({len(msg_3)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_3)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -147,6 +158,7 @@ class PatternXX(BasePattern):
|
||||
async def handshake_outbound(
|
||||
self, conn: IRawConnection, remote_peer: ID
|
||||
) -> ISecureConn:
|
||||
logger.debug(f"Noise XX handshake_outbound started to peer {remote_peer}")
|
||||
noise_state = self.create_noise_state()
|
||||
|
||||
read_writer = NoiseHandshakeReadWriter(conn, noise_state)
|
||||
@ -159,11 +171,15 @@ class PatternXX(BasePattern):
|
||||
raise NoiseStateError("Handshake state is not initialized")
|
||||
|
||||
# Send msg#1, which is *not* encrypted.
|
||||
logger.debug("Noise XX handshake_outbound: sending msg#1")
|
||||
msg_1 = b""
|
||||
await read_writer.write_msg(msg_1)
|
||||
logger.debug("Noise XX handshake_outbound: sent msg#1 successfully")
|
||||
|
||||
# Read msg#2 from the remote, which contains the public key of the peer.
|
||||
logger.debug("Noise XX handshake_outbound: reading msg#2")
|
||||
msg_2 = await read_writer.read_msg()
|
||||
logger.debug(f"Noise XX handshake_outbound: read msg#2 ({len(msg_2)} bytes)")
|
||||
peer_handshake_payload = NoiseHandshakePayload.deserialize(msg_2)
|
||||
|
||||
if handshake_state.rs is None:
|
||||
@ -174,8 +190,27 @@ class PatternXX(BasePattern):
|
||||
)
|
||||
remote_pubkey = self._get_pubkey_from_noise_keypair(handshake_state.rs)
|
||||
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: verifying signature for peer {remote_peer}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: remote_pubkey type: {type(remote_pubkey)}"
|
||||
)
|
||||
id_pubkey_repr = peer_handshake_payload.id_pubkey.to_bytes().hex()
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: peer_handshake_payload.id_pubkey: "
|
||||
f"{id_pubkey_repr}"
|
||||
)
|
||||
if not verify_handshake_payload_sig(peer_handshake_payload, remote_pubkey):
|
||||
logger.error(
|
||||
f"Noise XX handshake_outbound: signature verification failed for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
raise InvalidSignature
|
||||
logger.debug(
|
||||
f"Noise XX handshake_outbound: signature verification successful for peer "
|
||||
f"{remote_peer}"
|
||||
)
|
||||
remote_peer_id_from_pubkey = ID.from_pubkey(peer_handshake_payload.id_pubkey)
|
||||
if remote_peer_id_from_pubkey != remote_peer:
|
||||
raise PeerIDMismatchesPubkey(
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
@ -15,6 +13,7 @@ from libp2p.abc import (
|
||||
from libp2p.stream_muxer.exceptions import (
|
||||
MuxedConnUnavailable,
|
||||
)
|
||||
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||
|
||||
from .constants import (
|
||||
HeaderTags,
|
||||
@ -34,72 +33,6 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
|
||||
|
||||
|
||||
class MplexStream(IMuxedStream):
|
||||
"""
|
||||
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
|
||||
|
||||
libp2p/stream_muxer/rw_lock.py (new file, 70 lines)
@ -0,0 +1,70 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
class ReadWriteLock:
|
||||
"""
|
||||
A read-write lock that allows multiple concurrent readers
|
||||
or one exclusive writer, implemented using Trio primitives.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._readers = 0
|
||||
self._readers_lock = trio.Lock() # Protects access to _readers count
|
||||
self._writer_lock = trio.Semaphore(1) # Allows only one writer at a time
|
||||
|
||||
async def acquire_read(self) -> None:
|
||||
"""Acquire a read lock. Multiple readers can hold it simultaneously."""
|
||||
try:
|
||||
async with self._readers_lock:
|
||||
if self._readers == 0:
|
||||
await self._writer_lock.acquire()
|
||||
self._readers += 1
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
async def release_read(self) -> None:
|
||||
"""Release a read lock."""
|
||||
async with self._readers_lock:
|
||||
if self._readers == 1:
|
||||
self._writer_lock.release()
|
||||
self._readers -= 1
|
||||
|
||||
async def acquire_write(self) -> None:
|
||||
"""Acquire an exclusive write lock."""
|
||||
try:
|
||||
await self._writer_lock.acquire()
|
||||
except trio.Cancelled:
|
||||
raise
|
||||
|
||||
def release_write(self) -> None:
|
||||
"""Release the exclusive write lock."""
|
||||
self._writer_lock.release()
|
||||
|
||||
@asynccontextmanager
|
||||
async def read_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a read lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_read()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
with trio.CancelScope() as scope:
|
||||
scope.shield = True
|
||||
await self.release_read()
|
||||
|
||||
@asynccontextmanager
|
||||
async def write_lock(self) -> AsyncGenerator[None, None]:
|
||||
"""Context manager for acquiring and releasing a write lock safely."""
|
||||
acquire = False
|
||||
try:
|
||||
await self.acquire_write()
|
||||
acquire = True
|
||||
yield
|
||||
finally:
|
||||
if acquire:
|
||||
self.release_write()
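# Usage sketch for the ReadWriteLock above (not part of the diff): multiple
# readers may hold the lock at once, a writer holds it exclusively.
import trio
from libp2p.stream_muxer.rw_lock import ReadWriteLock

async def demo() -> None:
    lock = ReadWriteLock()

    async def reader() -> None:
        async with lock.read_lock():
            await trio.sleep(0.01)  # runs concurrently with other readers

    async def writer() -> None:
        async with lock.write_lock():
            await trio.sleep(0.01)  # runs alone

    async with trio.open_nursery() as nursery:
        for _ in range(3):
            nursery.start_soon(reader)
        nursery.start_soon(writer)

trio.run(demo)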
|
||||
@ -44,6 +44,7 @@ from libp2p.stream_muxer.exceptions import (
|
||||
MuxedStreamError,
|
||||
MuxedStreamReset,
|
||||
)
|
||||
from libp2p.stream_muxer.rw_lock import ReadWriteLock
|
||||
|
||||
# Configure logger for this module
|
||||
logger = logging.getLogger("libp2p.stream_muxer.yamux")
|
||||
@ -80,6 +81,8 @@ class YamuxStream(IMuxedStream):
|
||||
self.send_window = DEFAULT_WINDOW_SIZE
|
||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||
self.window_lock = trio.Lock()
|
||||
self.rw_lock = ReadWriteLock()
|
||||
self.close_lock = trio.Lock()
|
||||
|
||||
async def __aenter__(self) -> "YamuxStream":
|
||||
"""Enter the async context manager."""
|
||||
@ -95,52 +98,54 @@ class YamuxStream(IMuxedStream):
|
||||
await self.close()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
async with self.rw_lock.write_lock():
|
||||
if self.send_closed:
|
||||
raise MuxedStreamError("Stream is closed for sending")
|
||||
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: Window is zero, waiting for update"
|
||||
# Flow control: Check if we have enough send window
|
||||
total_len = len(data)
|
||||
sent = 0
|
||||
logger.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
|
||||
while sent < total_len:
|
||||
# Wait for available window with timeout
|
||||
timeout = False
|
||||
async with self.window_lock:
|
||||
if self.send_window == 0:
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id}: "
|
||||
"Window is zero, waiting for update"
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
# Calculate how much we can send now
|
||||
to_send = min(self.send_window, total_len - sent)
|
||||
chunk = data[sent : sent + to_send]
|
||||
self.send_window -= to_send
|
||||
|
||||
# Send the data
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||
)
|
||||
# Release lock and wait with timeout
|
||||
self.window_lock.release()
|
||||
# To avoid re-acquiring the lock immediately,
|
||||
with trio.move_on_after(5.0) as cancel_scope:
|
||||
while self.send_window == 0 and not self.closed:
|
||||
await trio.sleep(0.01)
|
||||
# If we timed out, cancel the scope
|
||||
timeout = cancel_scope.cancelled_caught
|
||||
# Re-acquire lock
|
||||
await self.window_lock.acquire()
|
||||
|
||||
# If we timed out waiting for window update, raise an error
|
||||
if timeout:
|
||||
raise MuxedStreamError(
|
||||
"Timed out waiting for window update after 5 seconds."
|
||||
)
|
||||
|
||||
if self.closed:
|
||||
raise MuxedStreamError("Stream is closed")
|
||||
|
||||
# Calculate how much we can send now
|
||||
to_send = min(self.send_window, total_len - sent)
|
||||
chunk = data[sent : sent + to_send]
|
||||
self.send_window -= to_send
|
||||
|
||||
# Send the data
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, 0, self.stream_id, len(chunk)
|
||||
)
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
|
||||
await self.conn.secured_conn.write(header + chunk)
|
||||
sent += to_send
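# Illustration (hedged) of the window-limited chunking performed in write()
# above, with the send window held constant for simplicity; the real loop also
# blocks when the window drops to zero and re-reads it under window_lock.
def chunk_by_window(data: bytes, send_window: int) -> list[bytes]:
    chunks, sent = [], 0
    while sent < len(data):
        to_send = min(send_window, len(data) - sent)
        chunks.append(data[sent : sent + to_send])
        sent += to_send
    return chunks

assert chunk_by_window(b"x" * 10, 4) == [b"xxxx", b"xxxx", b"xx"]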
|
||||
|
||||
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
|
||||
"""
|
||||
@ -257,30 +262,32 @@ class YamuxStream(IMuxedStream):
|
||||
return data
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self.send_closed:
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.send_closed = True
|
||||
async with self.close_lock:
|
||||
if not self.send_closed:
|
||||
logger.debug(f"Half-closing stream {self.stream_id} (local end)")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_FIN, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.send_closed = True
|
||||
|
||||
# Only set fully closed if both directions are closed
|
||||
if self.send_closed and self.recv_closed:
|
||||
self.closed = True
|
||||
else:
|
||||
# Stream is half-closed but not fully closed
|
||||
self.closed = False
|
||||
# Only set fully closed if both directions are closed
|
||||
if self.send_closed and self.recv_closed:
|
||||
self.closed = True
|
||||
else:
|
||||
# Stream is half-closed but not fully closed
|
||||
self.closed = False
|
||||
|
||||
async def reset(self) -> None:
|
||||
if not self.closed:
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.closed = True
|
||||
self.reset_received = True # Mark as reset
|
||||
async with self.close_lock:
|
||||
logger.debug(f"Resetting stream {self.stream_id}")
|
||||
header = struct.pack(
|
||||
YAMUX_HEADER_FORMAT, 0, TYPE_DATA, FLAG_RST, self.stream_id, 0
|
||||
)
|
||||
await self.conn.secured_conn.write(header)
|
||||
self.closed = True
|
||||
self.reset_received = True # Mark as reset
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
|
||||
from .tcp.tcp import TCP
|
||||
from .websocket.transport import WebsocketTransport
|
||||
from .transport_registry import (
|
||||
TransportRegistry,
|
||||
create_transport_for_multiaddr,
|
||||
get_transport_registry,
|
||||
register_transport,
|
||||
get_supported_transport_protocols,
|
||||
)
|
||||
from .upgrader import TransportUpgrader
|
||||
from libp2p.abc import ITransport
|
||||
|
||||
def create_transport(protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any) -> ITransport:
|
||||
"""
|
||||
Convenience function to create a transport instance.
|
||||
|
||||
:param protocol: The transport protocol ("tcp", "ws", "wss", or custom)
|
||||
:param upgrader: Optional transport upgrader (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction (e.g., tls_client_config, tls_server_config)
|
||||
:return: Transport instance
|
||||
"""
|
||||
# First check if it's a built-in protocol
|
||||
if protocol in ["ws", "wss"]:
|
||||
if upgrader is None:
|
||||
raise ValueError(f"WebSocket transport requires an upgrader")
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0)
|
||||
)
|
||||
elif protocol == "tcp":
|
||||
return TCP()
|
||||
else:
|
||||
# Check if it's a custom registered transport
|
||||
registry = get_transport_registry()
|
||||
transport_class = registry.get_transport(protocol)
|
||||
if transport_class:
|
||||
transport = registry.create_transport(protocol, upgrader, **kwargs)
|
||||
if transport is None:
|
||||
raise ValueError(f"Failed to create transport for protocol: {protocol}")
|
||||
return transport
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport protocol: {protocol}")
|
||||
|
||||
__all__ = [
|
||||
"TCP",
|
||||
"WebsocketTransport",
|
||||
"TransportRegistry",
|
||||
"create_transport_for_multiaddr",
|
||||
"create_transport",
|
||||
"get_transport_registry",
|
||||
"register_transport",
|
||||
"get_supported_transport_protocols",
|
||||
]
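# Usage sketch for create_transport() above (not part of the diff). TCP needs
# no upgrader; "ws"/"wss" additionally require a TransportUpgrader (and TLS
# contexts for wss). The bare upgrader mirrors the temporary one in new_swarm().
from libp2p.transport import create_transport
from libp2p.transport.upgrader import TransportUpgrader

tcp_transport = create_transport("tcp")
ws_transport = create_transport(
    "ws",
    upgrader=TransportUpgrader(
        secure_transports_by_protocol={}, muxer_transports_by_protocol={}
    ),
)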
|
||||
|
||||
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from aioquic.quic import events
|
||||
from aioquic.quic.connection import QuicConnection
|
||||
@ -871,9 +871,11 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Process events by type
|
||||
for event_type, event_list in events_by_type.items():
|
||||
if event_type == type(events.StreamDataReceived).__name__:
|
||||
await self._handle_stream_data_batch(
|
||||
cast(list[events.StreamDataReceived], event_list)
|
||||
)
|
||||
# Filter to only StreamDataReceived events
|
||||
stream_data_events = [
|
||||
e for e in event_list if isinstance(e, events.StreamDataReceived)
|
||||
]
|
||||
await self._handle_stream_data_batch(stream_data_events)
|
||||
else:
|
||||
# Process other events individually
|
||||
for event in event_list:
|
||||
|
||||
@ -901,7 +901,7 @@ class QUICListener(IListener):
|
||||
# Set socket options
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
if hasattr(socket, "SO_REUSEPORT"):
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # type: ignore[attr-defined]
|
||||
|
||||
# Bind to address
|
||||
await sock.bind((host, port))
|
||||
|
||||
libp2p/transport/transport_registry.py (new file, 267 lines)
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Transport registry for dynamic transport selection based on multiaddr protocols.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
from libp2p.abc import ITransport
|
||||
from libp2p.transport.tcp.tcp import TCP
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import (
|
||||
is_valid_websocket_multiaddr,
|
||||
)
|
||||
|
||||
|
||||
# Import QUIC utilities here to avoid circular imports
|
||||
def _get_quic_transport() -> Any:
|
||||
from libp2p.transport.quic.transport import QUICTransport
|
||||
|
||||
return QUICTransport
|
||||
|
||||
|
||||
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
|
||||
from libp2p.transport.quic.utils import is_quic_multiaddr
|
||||
|
||||
return is_quic_multiaddr
|
||||
|
||||
|
||||
# Import WebsocketTransport here to avoid circular imports
|
||||
def _get_websocket_transport() -> Any:
|
||||
from libp2p.transport.websocket.transport import WebsocketTransport
|
||||
|
||||
return WebsocketTransport
|
||||
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.registry")
|
||||
|
||||
|
||||
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid TCP structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid TCP structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# TCP multiaddr should have structure like /ip4/127.0.0.1/tcp/8080
|
||||
# or /ip6/::1/tcp/8080
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 2 protocols: network (ip4/ip6) + tcp
|
||||
if len(protocols) < 2:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Should not have any protocols after tcp (unless it's a valid
|
||||
# continuation like p2p)
|
||||
# For now, we'll be strict and only allow network + tcp
|
||||
if len(protocols) > 2:
|
||||
# Check if the additional protocols are valid continuations
|
||||
valid_continuations = ["p2p"] # Add more as needed
|
||||
for i in range(2, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
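# Hedged examples of what the validator above is meant to accept and reject:
from multiaddr import Multiaddr

_is_valid_tcp_multiaddr(Multiaddr("/ip4/127.0.0.1/tcp/8080"))       # expected True
_is_valid_tcp_multiaddr(Multiaddr("/ip4/127.0.0.1/udp/4001/quic"))  # expected False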
|
||||
|
||||
|
||||
class TransportRegistry:
|
||||
"""
|
||||
Registry for mapping multiaddr protocols to transport implementations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._transports: dict[str, type[ITransport]] = {}
|
||||
self._register_default_transports()
|
||||
|
||||
def _register_default_transports(self) -> None:
|
||||
"""Register the default transport implementations."""
|
||||
# Register TCP transport for /tcp protocol
|
||||
self.register_transport("tcp", TCP)
|
||||
|
||||
# Register WebSocket transport for /ws and /wss protocols
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
self.register_transport("ws", WebsocketTransport)
|
||||
self.register_transport("wss", WebsocketTransport)
|
||||
|
||||
# Register QUIC transport for /quic and /quic-v1 protocols
|
||||
QUICTransport = _get_quic_transport()
|
||||
self.register_transport("quic", QUICTransport)
|
||||
self.register_transport("quic-v1", QUICTransport)
|
||||
|
||||
def register_transport(
|
||||
self, protocol: str, transport_class: type[ITransport]
|
||||
) -> None:
|
||||
"""
|
||||
Register a transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier (e.g., "tcp", "ws")
|
||||
:param transport_class: The transport class to register
|
||||
"""
|
||||
self._transports[protocol] = transport_class
|
||||
logger.debug(
|
||||
f"Registered transport {transport_class.__name__} for protocol {protocol}"
|
||||
)
|
||||
|
||||
def get_transport(self, protocol: str) -> type[ITransport] | None:
|
||||
"""
|
||||
Get the transport class for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:return: The transport class or None if not found
|
||||
"""
|
||||
return self._transports.get(protocol)
|
||||
|
||||
def get_supported_protocols(self) -> list[str]:
|
||||
"""Get list of supported transport protocols."""
|
||||
return list(self._transports.keys())
|
||||
|
||||
def create_transport(
|
||||
self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create a transport instance for a specific protocol.
|
||||
|
||||
:param protocol: The protocol identifier
|
||||
:param upgrader: The transport upgrader instance (required for WebSocket)
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
:return: Transport instance or None if protocol not supported or creation fails
|
||||
"""
|
||||
transport_class = self.get_transport(protocol)
|
||||
if transport_class is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if protocol in ["ws", "wss"]:
|
||||
# WebSocket transport requires upgrader
|
||||
if upgrader is None:
|
||||
logger.warning(
|
||||
f"WebSocket transport '{protocol}' requires upgrader"
|
||||
)
|
||||
return None
|
||||
# Use explicit WebsocketTransport to avoid type issues
|
||||
WebsocketTransport = _get_websocket_transport()
|
||||
return WebsocketTransport(
|
||||
upgrader,
|
||||
tls_client_config=kwargs.get("tls_client_config"),
|
||||
tls_server_config=kwargs.get("tls_server_config"),
|
||||
handshake_timeout=kwargs.get("handshake_timeout", 15.0),
|
||||
)
|
||||
elif protocol in ["quic", "quic-v1"]:
|
||||
# QUIC transport requires private_key
|
||||
private_key = kwargs.get("private_key")
|
||||
if private_key is None:
|
||||
logger.warning(f"QUIC transport '{protocol}' requires private_key")
|
||||
return None
|
||||
# Use explicit QUICTransport to avoid type issues
|
||||
QUICTransport = _get_quic_transport()
|
||||
config = kwargs.get("config")
|
||||
return QUICTransport(private_key, config)
|
||||
else:
|
||||
# TCP transport doesn't require upgrader
|
||||
return transport_class()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create transport for protocol {protocol}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global transport registry instance (lazy initialization)
|
||||
_global_registry: TransportRegistry | None = None
|
||||
|
||||
|
||||
def get_transport_registry() -> TransportRegistry:
|
||||
"""Get the global transport registry instance."""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = TransportRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
|
||||
"""Register a transport class in the global registry."""
|
||||
registry = get_transport_registry()
|
||||
registry.register_transport(protocol, transport_class)
|
||||
|
||||
|
||||
def create_transport_for_multiaddr(
|
||||
maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
|
||||
) -> ITransport | None:
|
||||
"""
|
||||
Create the appropriate transport for a given multiaddr.
|
||||
|
||||
:param maddr: The multiaddr to create transport for
|
||||
:param upgrader: The transport upgrader instance
|
||||
:param kwargs: Additional arguments for transport construction
|
||||
(e.g., private_key for QUIC)
|
||||
:return: Transport instance or None if no suitable transport found
|
||||
"""
|
||||
try:
|
||||
# Get all protocols in the multiaddr
|
||||
protocols = [proto.name for proto in maddr.protocols()]
|
||||
|
||||
# Check for supported transport protocols in order of preference
|
||||
# We need to validate that the multiaddr structure is valid for our transports
|
||||
if "quic" in protocols or "quic-v1" in protocols:
|
||||
# For QUIC, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/udp/4001/quic
|
||||
# /ip4/127.0.0.1/udp/4001/quic-v1
|
||||
is_quic_multiaddr = _get_quic_validation()
|
||||
if is_quic_multiaddr(maddr):
|
||||
# Determine QUIC version
|
||||
registry = get_transport_registry()
|
||||
if "quic-v1" in protocols:
|
||||
return registry.create_transport("quic-v1", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("quic", upgrader, **kwargs)
|
||||
elif "ws" in protocols or "wss" in protocols or "tls" in protocols:
|
||||
# For WebSocket, we need a valid structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
if is_valid_websocket_multiaddr(maddr):
|
||||
# Determine if this is a secure WebSocket connection
|
||||
registry = get_transport_registry()
|
||||
if "wss" in protocols or "tls" in protocols:
|
||||
return registry.create_transport("wss", upgrader, **kwargs)
|
||||
else:
|
||||
return registry.create_transport("ws", upgrader, **kwargs)
|
||||
elif "tcp" in protocols:
|
||||
# For TCP, we need a valid structure like /ip4/127.0.0.1/tcp/8080
|
||||
# Check if the multiaddr has proper TCP structure
|
||||
if _is_valid_tcp_multiaddr(maddr):
|
||||
registry = get_transport_registry()
|
||||
return registry.create_transport("tcp", upgrader)
|
||||
|
||||
# If no supported transport protocol found or structure is invalid, return None
|
||||
logger.warning(
|
||||
f"No supported transport protocol found or invalid structure in "
|
||||
f"multiaddr: {maddr}"
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
# Handle any errors gracefully (e.g., invalid multiaddr)
|
||||
logger.warning(f"Error processing multiaddr {maddr}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_supported_transport_protocols() -> list[str]:
|
||||
"""Get list of supported transport protocols from the global registry."""
|
||||
registry = get_transport_registry()
|
||||
return registry.get_supported_protocols()
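# Hedged usage sketch for the module-level helpers above; the bare upgrader
# mirrors the temporary one built in new_swarm().
from multiaddr import Multiaddr
from libp2p.transport.transport_registry import (
    create_transport_for_multiaddr,
    get_supported_transport_protocols,
)
from libp2p.transport.upgrader import TransportUpgrader

upgrader = TransportUpgrader(
    secure_transports_by_protocol={}, muxer_transports_by_protocol={}
)
transport = create_transport_for_multiaddr(
    Multiaddr("/ip4/127.0.0.1/tcp/8080/ws"), upgrader
)
print(type(transport), get_supported_transport_protocols())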
|
||||
libp2p/transport/websocket/connection.py (new file, 198 lines)
@ -0,0 +1,198 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
Wraps a WebSocketConnection to provide the raw stream interface
|
||||
that libp2p protocols expect.
|
||||
|
||||
Implements production-ready buffer management and flow control
|
||||
as recommended in the libp2p WebSocket specification.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ws_connection: Any,
|
||||
ws_context: Any = None,
|
||||
is_secure: bool = False,
|
||||
max_buffered_amount: int = 4 * 1024 * 1024,
|
||||
) -> None:
|
||||
self._ws_connection = ws_connection
|
||||
self._ws_context = ws_context
|
||||
self._is_secure = is_secure
|
||||
self._read_buffer = b""
|
||||
self._read_lock = trio.Lock()
|
||||
self._connection_start_time = time.time()
|
||||
self._bytes_read = 0
|
||||
self._bytes_written = 0
|
||||
self._closed = False
|
||||
self._close_lock = trio.Lock()
|
||||
self._max_buffered_amount = max_buffered_amount
|
||||
self._write_lock = trio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write data with flow control and buffer management"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._write_lock:
|
||||
try:
|
||||
logger.debug(f"WebSocket writing {len(data)} bytes")
|
||||
|
||||
# Check buffer amount for flow control
|
||||
if hasattr(self._ws_connection, "bufferedAmount"):
|
||||
buffered = self._ws_connection.bufferedAmount
|
||||
if buffered > self._max_buffered_amount:
|
||||
logger.warning(f"WebSocket buffer full: {buffered} bytes")
|
||||
# In production, you might want to
|
||||
# wait or implement backpressure
|
||||
# For now, we'll continue but log the warning
|
||||
|
||||
# Send as a binary WebSocket message
|
||||
await self._ws_connection.send_message(data)
|
||||
self._bytes_written += len(data)
|
||||
logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket write failed: {e}")
|
||||
self._closed = True
|
||||
raise IOException from e
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||
This implementation provides byte-level access to WebSocket messages,
|
||||
which is required for libp2p protocol compatibility.
|
||||
|
||||
For WebSocket compatibility with libp2p protocols, this method:
|
||||
1. Buffers incoming WebSocket messages
|
||||
2. Returns exactly the requested number of bytes when n is specified
|
||||
3. Accumulates multiple WebSocket messages if needed to satisfy the request
|
||||
4. Returns empty bytes (not raises) when connection is closed and no data
|
||||
available
|
||||
"""
|
||||
if self._closed:
|
||||
raise IOException("Connection is closed")
|
||||
|
||||
async with self._read_lock:
|
||||
try:
|
||||
# If n is None, read at least one message and return all buffered data
|
||||
if n is None:
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
# No message available within timeout
|
||||
return b""
|
||||
except Exception:
|
||||
# Return empty bytes if no data available
|
||||
# (connection closed)
|
||||
return b""
|
||||
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
# For specific byte count requests, return UP TO n bytes (not exactly n)
|
||||
# This matches TCP semantics where read(1024) returns available data
|
||||
# up to 1024 bytes
|
||||
|
||||
# If we don't have any data buffered, try to get at least one message
|
||||
if not self._read_buffer:
|
||||
try:
|
||||
# Use a short timeout to avoid blocking indefinitely
|
||||
with trio.fail_after(1.0): # 1 second timeout
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
self._read_buffer = message
|
||||
except trio.TooSlowError:
|
||||
return b"" # No data available
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
# Now return up to n bytes from the buffer (TCP-like semantics)
|
||||
if len(self._read_buffer) == 0:
|
||||
return b""
|
||||
|
||||
# Return up to n bytes (like TCP read())
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[len(result) :]
|
||||
self._bytes_read += len(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket read failed: {e}")
|
||||
raise IOException from e
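# Hedged illustration of the "return up to n bytes" buffering above: one
# WebSocket message is buffered, and read(n) hands back at most n bytes,
# keeping the remainder for the next call.
buffer = b"hello world"                  # one buffered WebSocket message
n = 5
result, buffer = buffer[:n], buffer[n:]  # read(5) -> b"hello"; b" world" stays buffered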
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket connection. This method is idempotent."""
|
||||
async with self._close_lock:
|
||||
if self._closed:
|
||||
return # Already closed
|
||||
|
||||
logger.debug("WebSocket connection closing")
|
||||
self._closed = True
|
||||
try:
|
||||
# Always close the connection directly, avoid context manager issues
|
||||
# The context manager may be causing cancel scope corruption
|
||||
logger.debug("WebSocket closing connection directly")
|
||||
await self._ws_connection.aclose()
|
||||
# Exit the context manager if we have one
|
||||
if self._ws_context is not None:
|
||||
await self._ws_context.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket close error: {e}")
|
||||
# Don't raise here, as close() should be idempotent
|
||||
finally:
|
||||
logger.debug("WebSocket connection closed")
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if the connection is closed"""
|
||||
return self._closed
|
||||
|
||||
def conn_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
Return connection state information similar to Go's ConnState() method.
|
||||
|
||||
:return: Dictionary containing connection state information
|
||||
"""
|
||||
current_time = time.time()
|
||||
return {
|
||||
"transport": "websocket",
|
||||
"secure": self._is_secure,
|
||||
"connection_duration": current_time - self._connection_start_time,
|
||||
"bytes_read": self._bytes_read,
|
||||
"bytes_written": self._bytes_written,
|
||||
"total_bytes": self._bytes_read + self._bytes_written,
|
||||
}
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
# Try to get remote address from the WebSocket connection
|
||||
try:
|
||||
remote = self._ws_connection.remote
|
||||
if hasattr(remote, "address") and hasattr(remote, "port"):
|
||||
return str(remote.address), int(remote.port)
|
||||
elif isinstance(remote, str):
|
||||
# Parse address:port format
|
||||
if ":" in remote:
|
||||
host, port = remote.rsplit(":", 1)
|
||||
return host, int(port)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
libp2p/transport/websocket/listener.py (new file, 225 lines)
@ -0,0 +1,225 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
import trio
|
||||
from trio_typing import TaskStatus
|
||||
from trio_websocket import serve_websocket
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler
|
||||
from libp2p.transport.upgrader import TransportUpgrader
|
||||
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr
|
||||
|
||||
from .connection import P2PWebSocketConnection
|
||||
|
||||
logger = logging.getLogger("libp2p.transport.websocket.listener")
|
||||
|
||||
|
||||
class WebsocketListener(IListener):
|
||||
"""
|
||||
Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handler: THandler,
|
||||
upgrader: TransportUpgrader,
|
||||
tls_config: ssl.SSLContext | None = None,
|
||||
handshake_timeout: float = 15.0,
|
||||
) -> None:
|
||||
self._handler = handler
|
||||
self._upgrader = upgrader
|
||||
self._tls_config = tls_config
|
||||
self._handshake_timeout = handshake_timeout
|
||||
self._server = None
|
||||
self._shutdown_event = trio.Event()
|
||||
self._nursery: trio.Nursery | None = None
|
||||
self._listeners: Any = None
|
||||
self._is_wss = False # Track whether this is a WSS listener
|
||||
|
||||
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
|
||||
logger.debug(f"WebsocketListener.listen called with {maddr}")
|
||||
|
||||
# Parse the WebSocket multiaddr to determine if it's secure
|
||||
try:
|
||||
parsed = parse_websocket_multiaddr(maddr)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e
|
||||
|
||||
# Check if WSS is requested but no TLS config provided
|
||||
if parsed.is_wss and self._tls_config is None:
|
||||
raise ValueError(
|
||||
f"Cannot listen on WSS address {maddr} without TLS configuration"
|
||||
)
|
||||
|
||||
# Store whether this is a WSS listener
|
||||
self._is_wss = parsed.is_wss
|
||||
|
||||
# Extract host and port from the base multiaddr
|
||||
host = (
|
||||
parsed.rest_multiaddr.value_for_protocol("ip4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("ip6")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns4")
|
||||
or parsed.rest_multiaddr.value_for_protocol("dns6")
|
||||
or "0.0.0.0"
|
||||
)
|
||||
port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
|
||||
if port_str is None:
|
||||
raise ValueError(f"No TCP port found in multiaddr: {maddr}")
|
||||
port = int(port_str)
|
||||
|
||||
logger.debug(
|
||||
f"WebsocketListener: host={host}, port={port}, secure={parsed.is_wss}"
|
||||
)
|
||||
|
||||
async def serve_websocket_tcp(
|
||||
handler: Callable[[Any], Awaitable[None]],
|
||||
port: int,
|
||||
host: str,
|
||||
task_status: TaskStatus[Any],
|
||||
) -> None:
|
||||
"""Start TCP server and handle WebSocket connections manually"""
|
||||
logger.debug(
|
||||
"serve_websocket_tcp %s %s (secure=%s)", host, port, parsed.is_wss
|
||||
)
|
||||
|
||||
async def websocket_handler(request: Any) -> None:
|
||||
"""Handle WebSocket requests"""
|
||||
logger.debug("WebSocket request received")
|
||||
try:
|
||||
# Apply handshake timeout
|
||||
with trio.fail_after(self._handshake_timeout):
|
||||
# Accept the WebSocket connection
|
||||
ws_connection = await request.accept()
|
||||
logger.debug("WebSocket handshake successful")
|
||||
|
||||
# Create the WebSocket connection wrapper
|
||||
conn = P2PWebSocketConnection(
|
||||
ws_connection, is_secure=parsed.is_wss
|
||||
) # type: ignore[no-untyped-call]
|
||||
|
||||
# Call the handler function that was passed to create_listener
|
||||
# This handler will handle the security and muxing upgrades
|
||||
logger.debug("Calling connection handler")
|
||||
await self._handler(conn)
|
||||
|
||||
# Don't keep the connection alive indefinitely
|
||||
# Let the handler manage the connection lifecycle
|
||||
logger.debug(
|
||||
"Handler completed, connection will be managed by handler"
|
||||
)
|
||||
|
||||
except trio.TooSlowError:
|
||||
logger.debug(
|
||||
f"WebSocket handshake timeout after {self._handshake_timeout}s"
|
||||
)
|
||||
try:
|
||||
await request.reject(408) # Request Timeout
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"WebSocket connection error: {e}")
|
||||
logger.debug(f"Error type: {type(e)}")
|
||||
import traceback
|
||||
|
||||
logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
# Reject the connection
|
||||
try:
|
||||
await request.reject(400)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Use trio_websocket.serve_websocket for proper WebSocket handling
|
||||
ssl_context = self._tls_config if parsed.is_wss else None
|
||||
await serve_websocket(
|
||||
websocket_handler, host, port, ssl_context, task_status=task_status
|
||||
)
|
||||
|
||||
# Store the nursery for shutdown
|
||||
self._nursery = nursery
|
||||
|
||||
# Start the server using nursery.start() like TCP does
|
||||
logger.debug("Calling nursery.start()...")
|
||||
started_listeners = await nursery.start(
|
||||
serve_websocket_tcp,
|
||||
None, # No handler needed since it's defined inside serve_websocket_tcp
|
||||
port,
|
||||
host,
|
||||
)
|
||||
logger.debug(f"nursery.start() returned: {started_listeners}")
|
||||
|
||||
if started_listeners is None:
|
||||
logger.error(f"Failed to start WebSocket listener for {maddr}")
|
||||
return False
|
||||
|
||||
# Store the listeners for get_addrs() and close() - these are real
|
||||
# SocketListener objects
|
||||
self._listeners = started_listeners
|
||||
logger.debug(
|
||||
"WebsocketListener.listen returning True with WebSocketServer object"
|
||||
)
|
||||
return True
|
||||
|
||||
def get_addrs(self) -> tuple[Multiaddr, ...]:
|
||||
if not hasattr(self, "_listeners") or not self._listeners:
|
||||
logger.debug("No listeners available for get_addrs()")
|
||||
return ()
|
||||
|
||||
# Handle WebSocketServer objects
|
||||
if hasattr(self._listeners, "port"):
|
||||
# This is a WebSocketServer object
|
||||
port = self._listeners.port
|
||||
# Create a multiaddr from the port with correct WSS/WS protocol
|
||||
protocol = "wss" if self._is_wss else "ws"
|
||||
return (Multiaddr(f"/ip4/127.0.0.1/tcp/{port}/{protocol}"),)
|
||||
else:
|
||||
# This is a list of listeners (like TCP)
|
||||
listeners = self._listeners
|
||||
# Get addresses from listeners like TCP does
|
||||
return tuple(
|
||||
_multiaddr_from_socket(listener.socket, self._is_wss)
|
||||
for listener in listeners
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the WebSocket listener and stop accepting new connections"""
|
||||
logger.debug("WebsocketListener.close called")
|
||||
if hasattr(self, "_listeners") and self._listeners:
|
||||
# Signal shutdown
|
||||
self._shutdown_event.set()
|
||||
|
||||
# Close the WebSocket server
|
||||
if hasattr(self._listeners, "aclose"):
|
||||
# This is a WebSocketServer object
|
||||
logger.debug("Closing WebSocket server")
|
||||
await self._listeners.aclose()
|
||||
logger.debug("WebSocket server closed")
|
||||
elif isinstance(self._listeners, (list, tuple)):
|
||||
# This is a list of listeners (like TCP)
|
||||
logger.debug("Closing TCP listeners")
|
||||
for listener in self._listeners:
|
||||
await listener.aclose()
|
||||
logger.debug("TCP listeners closed")
|
||||
else:
|
||||
# Unknown type, try to close it directly
|
||||
logger.debug("Closing unknown listener type")
|
||||
if hasattr(self._listeners, "close"):
|
||||
self._listeners.close()
|
||||
logger.debug("Unknown listener closed")
|
||||
|
||||
# Clear the listeners reference
|
||||
self._listeners = None
|
||||
logger.debug("WebsocketListener.close completed")
|
||||
|
||||
|
||||
def _multiaddr_from_socket(
|
||||
socket: trio.socket.SocketType, is_wss: bool = False
|
||||
) -> Multiaddr:
|
||||
"""Convert socket to multiaddr"""
|
||||
ip, port = socket.getsockname()
|
||||
protocol = "wss" if is_wss else "ws"
|
||||
return Multiaddr(f"/ip4/{ip}/tcp/{port}/{protocol}")
|
||||
libp2p/transport/websocket/multiaddr_utils.py (new file, 202 lines)
@ -0,0 +1,202 @@
|
||||
"""
|
||||
WebSocket multiaddr parsing utilities.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
from multiaddr.protocols import Protocol
|
||||
|
||||
|
||||
class ParsedWebSocketMultiaddr(NamedTuple):
|
||||
"""Parsed WebSocket multiaddr information."""
|
||||
|
||||
is_wss: bool
|
||||
sni: str | None
|
||||
rest_multiaddr: Multiaddr
|
||||
|
||||
|
||||
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
|
||||
"""
|
||||
Parse a WebSocket multiaddr and extract security information.
|
||||
|
||||
:param maddr: The multiaddr to parse
|
||||
:return: Parsed WebSocket multiaddr information
|
||||
:raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
|
||||
"""
|
||||
# First validate that this is a valid WebSocket multiaddr
|
||||
if not is_valid_websocket_multiaddr(maddr):
|
||||
raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")
|
||||
|
||||
protocols = list(maddr.protocols())
|
||||
|
||||
# Find the WebSocket protocol and check for security
|
||||
is_wss = False
|
||||
sni = None
|
||||
ws_index = -1
|
||||
tls_index = -1
|
||||
sni_index = -1
|
||||
|
||||
# Find protocol indices
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name == "ws":
|
||||
ws_index = i
|
||||
elif protocol.name == "wss":
|
||||
ws_index = i
|
||||
is_wss = True
|
||||
elif protocol.name == "tls":
|
||||
tls_index = i
|
||||
elif protocol.name == "sni":
|
||||
sni_index = i
|
||||
sni = protocol.value
|
||||
|
||||
if ws_index == -1:
|
||||
raise ValueError("Not a WebSocket multiaddr")
|
||||
|
||||
# Handle /wss protocol (convert to /tls/ws internally)
|
||||
if is_wss and tls_index == -1:
|
||||
# Convert /wss to /tls/ws format
|
||||
# Remove /wss to get the base multiaddr
|
||||
without_wss = maddr.decapsulate(Multiaddr("/wss"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=True, sni=None, rest_multiaddr=without_wss
|
||||
)
|
||||
|
||||
# Handle /tls/ws and /tls/sni/.../ws formats
|
||||
if tls_index != -1:
|
||||
is_wss = True
|
||||
# Extract the base multiaddr (everything before /tls)
|
||||
# For /ip4/127.0.0.1/tcp/8080/tls/ws, we want /ip4/127.0.0.1/tcp/8080
|
||||
# Use multiaddr methods to properly extract the base
|
||||
rest_multiaddr = maddr
|
||||
# Remove /tls/ws or /tls/sni/.../ws from the end
|
||||
if sni_index != -1:
|
||||
# /tls/sni/example.com/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
else:
|
||||
# /tls/ws format
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/ws"))
|
||||
rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=is_wss, sni=sni, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
# Regular /ws multiaddr - remove /ws and any additional protocols
|
||||
rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
|
||||
return ParsedWebSocketMultiaddr(
|
||||
is_wss=False, sni=None, rest_multiaddr=rest_multiaddr
|
||||
)
|
||||
|
||||
|
||||
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
|
||||
"""
|
||||
Validate that a multiaddr has a valid WebSocket structure.
|
||||
|
||||
:param maddr: The multiaddr to validate
|
||||
:return: True if valid WebSocket structure, False otherwise
|
||||
"""
|
||||
try:
|
||||
# WebSocket multiaddr should have structure like:
|
||||
# /ip4/127.0.0.1/tcp/8080/ws (insecure)
|
||||
# /ip4/127.0.0.1/tcp/8080/wss (secure)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/ws (secure with TLS)
|
||||
# /ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws (secure with SNI)
|
||||
protocols: list[Protocol] = list(maddr.protocols())
|
||||
|
||||
# Must have at least 3 protocols: network (ip4/ip6/dns4/dns6) + tcp + ws/wss
|
||||
if len(protocols) < 3:
|
||||
return False
|
||||
|
||||
# First protocol should be a network protocol (ip4, ip6, dns, dns4, dns6)
|
||||
if protocols[0].name not in ["ip4", "ip6", "dns", "dns4", "dns6"]:
|
||||
return False
|
||||
|
||||
# Second protocol should be tcp
|
||||
if protocols[1].name != "tcp":
|
||||
return False
|
||||
|
||||
# Check for valid WebSocket protocols
|
||||
ws_protocols = ["ws", "wss"]
|
||||
tls_protocols = ["tls"]
|
||||
sni_protocols = ["sni"]
|
||||
|
||||
# Find the WebSocket protocol
|
||||
ws_protocol_found = False
|
||||
tls_found = False
|
||||
# sni_found = False # Not used currently
|
||||
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name in ws_protocols:
|
||||
ws_protocol_found = True
|
||||
break
|
||||
elif protocol.name in tls_protocols:
|
||||
tls_found = True
|
||||
elif protocol.name in sni_protocols:
|
||||
pass # sni_found = True # Not used in current implementation
|
||||
|
||||
if not ws_protocol_found:
|
||||
return False
|
||||
|
||||
# Validate protocol sequence
|
||||
# For /ws: network + tcp + ws
|
||||
# For /wss: network + tcp + wss
|
||||
# For /tls/ws: network + tcp + tls + ws
|
||||
# For /tls/sni/example.com/ws: network + tcp + tls + sni + ws
|
||||
|
||||
# Check if it's a simple /ws or /wss
|
||||
if len(protocols) == 3:
|
||||
return protocols[2].name in ["ws", "wss"]
|
||||
|
||||
# Check for /tls/ws or /tls/sni/.../ws patterns
|
||||
if tls_found:
|
||||
# Must end with /ws (not /wss when using /tls)
|
||||
if protocols[-1].name != "ws":
|
||||
return False
|
||||
|
||||
# Check for valid TLS sequence
|
||||
tls_index = None
|
||||
for i, protocol in enumerate(protocols[2:], start=2):
|
||||
if protocol.name == "tls":
|
||||
tls_index = i
|
||||
break
|
||||
|
||||
if tls_index is None:
|
||||
return False
|
||||
|
||||
# After tls, we can have sni, then ws
|
||||
remaining_protocols = protocols[tls_index + 1 :]
|
||||
if len(remaining_protocols) == 1:
|
||||
# /tls/ws
|
||||
return remaining_protocols[0].name == "ws"
|
||||
elif len(remaining_protocols) == 2:
|
||||
# /tls/sni/example.com/ws
|
||||
return (
|
||||
remaining_protocols[0].name == "sni"
|
||||
and remaining_protocols[1].name == "ws"
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
# If we have more than 3 protocols but no TLS, check for valid continuations
|
||||
# Allow additional protocols after the WebSocket protocol (like /p2p)
|
||||
valid_continuations = ["p2p"]
|
||||
|
||||
# Find the WebSocket protocol index
|
||||
ws_index = None
|
||||
for i, protocol in enumerate(protocols):
|
||||
if protocol.name in ["ws", "wss"]:
|
||||
ws_index = i
|
||||
break
|
||||
|
||||
if ws_index is not None:
|
||||
# Check protocols after the WebSocket protocol
|
||||
for i in range(ws_index + 1, len(protocols)):
|
||||
if protocols[i].name not in valid_continuations:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
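A brief usage sketch of the two helpers above (the address is illustrative); it exercises only behavior defined in this file:

from multiaddr import Multiaddr

from libp2p.transport.websocket.multiaddr_utils import (
    is_valid_websocket_multiaddr,
    parse_websocket_multiaddr,
)

wss_addr = Multiaddr("/dns4/example.com/tcp/443/wss")
assert is_valid_websocket_multiaddr(wss_addr)
parsed = parse_websocket_multiaddr(wss_addr)
# /wss is treated as /tls/ws internally: secure, no explicit SNI,
# with the transport portion preserved
assert parsed.is_wss and parsed.sni is None
assert str(parsed.rest_multiaddr) == "/dns4/example.com/tcp/443"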
229 libp2p/transport/websocket/transport.py Normal file
@ -0,0 +1,229 @@
import logging
import ssl

from multiaddr import Multiaddr

from libp2p.abc import IListener, ITransport
from libp2p.custom_types import THandler
from libp2p.network.connection.raw_connection import RawConnection
from libp2p.transport.exceptions import OpenConnectionError
from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.multiaddr_utils import parse_websocket_multiaddr

from .connection import P2PWebSocketConnection
from .listener import WebsocketListener

logger = logging.getLogger(__name__)


class WebsocketTransport(ITransport):
    """
    Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws and /wss.

    Implements a production-ready WebSocket transport with:
    - Flow control and buffer management
    - Connection limits and rate limiting
    - Proper error handling and cleanup
    - Support for both WS and WSS protocols
    - TLS configuration and handshake timeout
    """

    def __init__(
        self,
        upgrader: TransportUpgrader,
        tls_client_config: ssl.SSLContext | None = None,
        tls_server_config: ssl.SSLContext | None = None,
        handshake_timeout: float = 15.0,
        max_buffered_amount: int = 4 * 1024 * 1024,
    ):
        self._upgrader = upgrader
        self._tls_client_config = tls_client_config
        self._tls_server_config = tls_server_config
        self._handshake_timeout = handshake_timeout
        self._max_buffered_amount = max_buffered_amount
        self._connection_count = 0
        self._max_connections = 1000  # Production limit

    async def dial(self, maddr: Multiaddr) -> RawConnection:
        """Dial a WebSocket connection to the given multiaddr."""
        logger.debug(f"WebsocketTransport.dial called with {maddr}")

        # Parse the WebSocket multiaddr to determine if it's secure
        try:
            parsed = parse_websocket_multiaddr(maddr)
        except ValueError as e:
            raise ValueError(f"Invalid WebSocket multiaddr: {e}") from e

        # Extract host and port from the base multiaddr
        host = (
            parsed.rest_multiaddr.value_for_protocol("ip4")
            or parsed.rest_multiaddr.value_for_protocol("ip6")
            or parsed.rest_multiaddr.value_for_protocol("dns")
            or parsed.rest_multiaddr.value_for_protocol("dns4")
            or parsed.rest_multiaddr.value_for_protocol("dns6")
        )
        port_str = parsed.rest_multiaddr.value_for_protocol("tcp")
        if port_str is None:
            raise ValueError(f"No TCP port found in multiaddr: {maddr}")
        port = int(port_str)

        # Build the WebSocket URL based on security
        if parsed.is_wss:
            ws_url = f"wss://{host}:{port}/"
        else:
            ws_url = f"ws://{host}:{port}/"

        logger.debug(
            f"WebsocketTransport.dial connecting to {ws_url} (secure={parsed.is_wss})"
        )

        try:
            # Check connection limits
            if self._connection_count >= self._max_connections:
                raise OpenConnectionError(
                    f"Maximum connections reached: {self._max_connections}"
                )

            # Prepare SSL context for WSS connections
            ssl_context = None
            if parsed.is_wss:
                if self._tls_client_config:
                    ssl_context = self._tls_client_config
                else:
                    # Create a default SSL context for the client
                    ssl_context = ssl.create_default_context()
                    # Set SNI if available
                    if parsed.sni:
                        ssl_context.check_hostname = False
                        ssl_context.verify_mode = ssl.CERT_NONE

            logger.debug(f"WebsocketTransport.dial opening connection to {ws_url}")

            # Run the WebSocket on an existing (parent) nursery so its
            # background tasks persist beyond this call
            logger.debug("WebsocketTransport.dial establishing connection")

            # Import trio-websocket functions
            from trio_websocket import connect_websocket
            from trio_websocket._impl import _url_to_host

            # Parse the WebSocket URL to get host, port, resource
            # like trio-websocket does
            ws_host, ws_port, ws_resource, ws_ssl_context = _url_to_host(
                ws_url, ssl_context
            )

            logger.debug(
                f"WebsocketTransport.dial parsed URL: host={ws_host}, "
                f"port={ws_port}, resource={ws_resource}"
            )

            # Use the caller's nursery as the background task manager
            # for this connection
            import trio

            nursery_manager = trio.lowlevel.current_task().parent_nursery
            if nursery_manager is None:
                raise OpenConnectionError(
                    f"No parent nursery available for WebSocket connection to {maddr}"
                )

            # Apply a timeout to the connection process
            with trio.fail_after(self._handshake_timeout):
                logger.debug("WebsocketTransport.dial connecting WebSocket")
                ws = await connect_websocket(
                    nursery_manager,  # Use the existing nursery from libp2p
                    ws_host,
                    ws_port,
                    ws_resource,
                    use_ssl=ws_ssl_context,
                    message_queue_size=1024,  # Reasonable defaults
                    max_message_size=16 * 1024 * 1024,  # 16MB max message
                )
                logger.debug("WebsocketTransport.dial WebSocket connection established")

            # Create our connection wrapper with both WSS support and flow control
            conn = P2PWebSocketConnection(
                ws,
                None,
                is_secure=parsed.is_wss,
                max_buffered_amount=self._max_buffered_amount,
            )
            logger.debug("WebsocketTransport.dial created P2PWebSocketConnection")

            self._connection_count += 1
            logger.debug(f"Total connections: {self._connection_count}")

            return RawConnection(conn, initiator=True)
        except trio.TooSlowError as e:
            raise OpenConnectionError(
                f"WebSocket handshake timeout after {self._handshake_timeout}s "
                f"for {maddr}"
            ) from e
        except Exception as e:
            logger.error(f"Failed to dial WebSocket {maddr}: {e}")
            raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e
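    # Note (illustrative, not part of this file): dial() attaches the
    # WebSocket's reader/writer tasks to
    # trio.lowlevel.current_task().parent_nursery, so the connection lives as
    # long as the nursery that started the calling task, and dial() raises
    # OpenConnectionError when no parent nursery exists. A hypothetical driver
    # therefore dials from a spawned task, e.g.:
    #
    #     async with trio.open_nursery() as nursery:
    #         nursery.start_soon(run_client, ws_transport)  # run_client awaits dial()
    #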
    def create_listener(self, handler: THandler) -> IListener:  # type: ignore[override]
        """
        Create a WebSocket listener.

        (The type checker incorrectly reports this as an inconsistent override,
        hence the ignore above.)
        """
        logger.debug("WebsocketTransport.create_listener called")
        return WebsocketListener(
            handler, self._upgrader, self._tls_server_config, self._handshake_timeout
        )

    def resolve(self, maddr: Multiaddr) -> list[Multiaddr]:
        """
        Resolve a WebSocket multiaddr, automatically adding SNI for DNS names.
        Similar to Go's Resolve() method.

        :param maddr: The multiaddr to resolve
        :return: List of resolved multiaddrs
        """
        try:
            parsed = parse_websocket_multiaddr(maddr)
        except ValueError as e:
            logger.debug(f"Invalid WebSocket multiaddr for resolution: {e}")
            return [maddr]  # Return original if not a valid WebSocket multiaddr

        logger.debug(
            f"Parsed multiaddr {maddr}: is_wss={parsed.is_wss}, sni={parsed.sni}"
        )

        if not parsed.is_wss:
            # No /tls/ws component, so this isn't a secure WebSocket multiaddr
            return [maddr]

        if parsed.sni is not None:
            # Already has SNI, return as-is
            return [maddr]

        # Try to extract a DNS name from the base multiaddr
        dns_name = None
        for protocol_name in ["dns", "dns4", "dns6"]:
            try:
                dns_name = parsed.rest_multiaddr.value_for_protocol(protocol_name)
                break
            except Exception:
                continue

        if dns_name is None:
            # No DNS name found, return original
            return [maddr]

        # Create a new multiaddr with SNI:
        # /dns/example.com/tcp/8080/wss ->
        # /dns/example.com/tcp/8080/tls/sni/example.com/ws
        try:
            # Remove /wss and add /tls/sni/example.com/ws
            without_wss = maddr.decapsulate(Multiaddr("/wss"))
            sni_component = Multiaddr(f"/sni/{dns_name}")
            resolved = (
                without_wss.encapsulate(Multiaddr("/tls"))
                .encapsulate(sni_component)
                .encapsulate(Multiaddr("/ws"))
            )
            logger.debug(f"Resolved {maddr} to {resolved}")
            return [resolved]
        except Exception as e:
            logger.debug(f"Failed to resolve multiaddr {maddr}: {e}")
            return [maddr]
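A hedged sketch of how this transport can be constructed and how resolve() expands a DNS /wss address; the empty upgrader is a placeholder and the address is illustrative:

from multiaddr import Multiaddr

from libp2p.transport.upgrader import TransportUpgrader
from libp2p.transport.websocket.transport import WebsocketTransport

upgrader = TransportUpgrader(
    secure_transports_by_protocol={}, muxer_transports_by_protocol={}
)
ws_transport = WebsocketTransport(upgrader, handshake_timeout=15.0)

resolved = ws_transport.resolve(Multiaddr("/dns4/example.com/tcp/443/wss"))
# SNI is filled in from the DNS name:
# [Multiaddr("/dns4/example.com/tcp/443/tls/sni/example.com/ws")]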
@ -3,38 +3,24 @@ from __future__ import annotations

import socket

from multiaddr import Multiaddr

try:
    from multiaddr.utils import (  # type: ignore
        get_network_addrs,
        get_thin_waist_addresses,
    )

    _HAS_THIN_WAIST = True
except ImportError:  # pragma: no cover - only executed in older environments
    _HAS_THIN_WAIST = False
    get_thin_waist_addresses = None  # type: ignore
    get_network_addrs = None  # type: ignore
from multiaddr.utils import get_network_addrs, get_thin_waist_addresses


def _safe_get_network_addrs(ip_version: int) -> list[str]:
    """
    Internal safe wrapper. Returns a list of IP addresses for the requested IP version.
    Falls back to minimal defaults when Thin Waist helpers are missing.

    :param ip_version: 4 or 6
    """
    if _HAS_THIN_WAIST and get_network_addrs:
        try:
            return get_network_addrs(ip_version) or []
        except Exception:  # pragma: no cover - defensive
            return []
    # Fallback behavior (very conservative)
    if ip_version == 4:
        return ["127.0.0.1"]
    if ip_version == 6:
        return ["::1"]
    return []
    try:
        return get_network_addrs(ip_version) or []
    except Exception:  # pragma: no cover - defensive
        # Fallback behavior (very conservative)
        if ip_version == 4:
            return ["127.0.0.1"]
        if ip_version == 6:
            return ["::1"]
        return []


def find_free_port() -> int:
@ -47,16 +33,13 @@ def find_free_port() -> int:
def _safe_expand(addr: Multiaddr, port: int | None = None) -> list[Multiaddr]:
    """
    Internal safe expansion wrapper. Returns a list of Multiaddr objects.
    If Thin Waist isn't available, returns [addr] (identity).
    """
    if _HAS_THIN_WAIST and get_thin_waist_addresses:
        try:
            if port is not None:
                return get_thin_waist_addresses(addr, port=port) or []
            return get_thin_waist_addresses(addr) or []
        except Exception:  # pragma: no cover - defensive
            return [addr]
    return [addr]
    try:
        if port is not None:
            return get_thin_waist_addresses(addr, port=port) or []
        return get_thin_waist_addresses(addr) or []
    except Exception:  # pragma: no cover - defensive
        return [addr]


def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr]:
@ -73,8 +56,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
    seen_v4: set[str] = set()

    for ip in _safe_get_network_addrs(4):
        seen_v4.add(ip)
        addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))
        if ip not in seen_v4:  # Avoid duplicates
            seen_v4.add(ip)
            addrs.append(Multiaddr(f"/ip4/{ip}/{protocol}/{port}"))

    # Ensure IPv4 loopback is always included when IPv4 interfaces are discovered
    if seen_v4 and "127.0.0.1" not in seen_v4:
@ -89,8 +73,9 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr
    #
    # seen_v6: set[str] = set()
    # for ip in _safe_get_network_addrs(6):
    #     seen_v6.add(ip)
    #     addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
    #     if ip not in seen_v6:  # Avoid duplicates
    #         seen_v6.add(ip)
    #         addrs.append(Multiaddr(f"/ip6/{ip}/{protocol}/{port}"))
    #
    # # Always include IPv6 loopback for testing purposes when IPv6 is available
    # # This ensures IPv6 functionality can be tested even without global IPv6 addresses
@ -99,7 +84,7 @@ def get_available_interfaces(port: int, protocol: str = "tcp") -> list[Multiaddr

    # Fallback if nothing discovered
    if not addrs:
        addrs.append(Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}"))
        addrs.append(Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}"))

    return addrs

@ -120,6 +105,20 @@ def expand_wildcard_address(
    return expanded


def get_wildcard_address(port: int, protocol: str = "tcp") -> Multiaddr:
    """
    Get a wildcard address (0.0.0.0) when explicitly needed.

    This function provides access to wildcard binding as a feature when
    explicitly required, preserving the ability to bind to all interfaces.

    :param port: Port number.
    :param protocol: Transport protocol.
    :return: A Multiaddr with wildcard binding (0.0.0.0).
    """
    return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")


def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
    """
    Choose an optimal address for an example to bind to:
@ -148,13 +147,14 @@ def get_optimal_binding_address(port: int, protocol: str = "tcp") -> Multiaddr:
        if "/ip4/127." in str(c) or "/ip6/::1" in str(c):
            return c

    # As a final fallback, produce a wildcard
    return Multiaddr(f"/ip4/0.0.0.0/{protocol}/{port}")
    # As a final fallback, produce a loopback address
    return Multiaddr(f"/ip4/127.0.0.1/{protocol}/{port}")


__all__ = [
    "get_available_interfaces",
    "get_optimal_binding_address",
    "get_wildcard_address",
    "expand_wildcard_address",
    "find_free_port",
]
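For context, a short sketch of how the helpers touched above are typically used; the port is illustrative, and the functions are assumed to be imported from wherever this utility module lives in the tree (the excerpt does not show its path):

port = 9000
addrs = get_available_interfaces(port)          # one /ip4/<addr>/tcp/9000 per discovered interface
bind_addr = get_optimal_binding_address(port)   # prefers a concrete (often loopback) address
wildcard = get_wildcard_address(port)           # explicit /ip4/0.0.0.0/tcp/9000 when wildcard binding is wanted
assert str(wildcard) == f"/ip4/0.0.0.0/tcp/{port}"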