mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
- Fix type annotation errors in transport_registry.py and __init__.py - Fix line length violations in test files (E501 errors) - Fix missing return type annotations - Fix cryptography NameAttribute type errors with type: ignore - Fix ExceptionGroup import for cross-version compatibility - Fix test failure in test_wss_listen_without_tls_config by handling ExceptionGroup - Fix len() calls with None arguments in test_tcp_data_transfer.py - Fix missing attribute access errors on interface types - Fix boolean type expectation errors in test_js_ws_ping.py - Fix nursery context manager type errors All tests now pass and linting is clean.
262 lines
9.8 KiB
Python
262 lines
9.8 KiB
Python
"""
|
|
Transport registry for dynamic transport selection based on multiaddr protocols.
|
|
"""
|
|
|
|
from collections.abc import Callable
|
|
import logging
|
|
from typing import Any
|
|
|
|
from multiaddr import Multiaddr
|
|
from multiaddr.protocols import Protocol
|
|
|
|
from libp2p.abc import ITransport
|
|
from libp2p.transport.tcp.tcp import TCP
|
|
from libp2p.transport.upgrader import TransportUpgrader
|
|
from libp2p.transport.websocket.multiaddr_utils import (
|
|
is_valid_websocket_multiaddr,
|
|
)
|
|
|
|
|
|
# Import QUIC utilities here to avoid circular imports
|
|
def _get_quic_transport() -> Any:
    """
    Import and return the ``QUICTransport`` class.

    The import happens inside the function body instead of at module level
    so that loading this module does not trigger a circular import.

    :return: The ``QUICTransport`` class object (not an instance)
    """
    from libp2p.transport.quic.transport import QUICTransport as quic_transport_cls

    return quic_transport_cls
|
|
|
|
|
|
def _get_quic_validation() -> Callable[[Multiaddr], bool]:
    """
    Import and return the ``is_quic_multiaddr`` predicate.

    The import is deferred to call time rather than done at module level.

    :return: The ``is_quic_multiaddr`` function from the QUIC utilities module
    """
    from libp2p.transport.quic.utils import is_quic_multiaddr as quic_predicate

    return quic_predicate
|
|
|
|
|
|
# Import WebsocketTransport here to avoid circular imports
|
|
def _get_websocket_transport() -> Any:
    """
    Import and return the ``WebsocketTransport`` class.

    The import happens inside the function body instead of at module level
    so that loading this module does not trigger a circular import.

    :return: The ``WebsocketTransport`` class object (not an instance)
    """
    from libp2p.transport.websocket.transport import (
        WebsocketTransport as websocket_transport_cls,
    )

    return websocket_transport_cls
|
|
|
|
|
|
# Module-level logger used by the registry class and the helper functions in
# this module; it is registered under an explicit dotted name.
logger = logging.getLogger("libp2p.transport.registry")
|
|
|
|
|
|
def _is_valid_tcp_multiaddr(maddr: Multiaddr) -> bool:
    """
    Check whether a multiaddr has the shape expected of a plain TCP address.

    Accepted shape: a network component (``ip4``, ``ip6``, ``dns4`` or
    ``dns6``), immediately followed by ``tcp``, optionally followed by any
    number of ``p2p`` components. Examples: ``/ip4/127.0.0.1/tcp/8080`` and
    ``/ip6/::1/tcp/8080``.

    :param maddr: The multiaddr to validate
    :return: True if the structure matches, False otherwise (including when
        inspecting the multiaddr raises any exception)
    """
    network_names = ("ip4", "ip6", "dns4", "dns6")
    # Components tolerated after the tcp component; extend as needed.
    allowed_suffix_names = ("p2p",)

    try:
        components: list[Protocol] = list(maddr.protocols())

        # A network component plus tcp is the minimum.
        if len(components) < 2:
            return False

        network, transport, *suffix = components

        # Checks short-circuit left to right: network, then tcp, then suffix.
        return (
            network.name in network_names
            and transport.name == "tcp"
            and all(extra.name in allowed_suffix_names for extra in suffix)
        )
    except Exception:
        # Anything that cannot be inspected is treated as not-a-TCP-address.
        return False
|
|
|
|
|
|
class TransportRegistry:
    """
    Lookup table from multiaddr protocol names to transport classes.

    A freshly constructed registry is pre-populated with the TCP, WebSocket
    (``ws``/``wss``) and QUIC (``quic``/``quic-v1``) transports.
    """

    def __init__(self) -> None:
        # Maps protocol name -> transport class.
        self._transports: dict[str, type[ITransport]] = {}
        self._register_default_transports()

    def _register_default_transports(self) -> None:
        """Populate the registry with the built-in transports."""
        self.register_transport("tcp", TCP)

        # WebSocket and QUIC classes are obtained through the lazy-import
        # helpers; each class serves two protocol names.
        websocket_cls = _get_websocket_transport()
        for ws_name in ("ws", "wss"):
            self.register_transport(ws_name, websocket_cls)

        quic_cls = _get_quic_transport()
        for quic_name in ("quic", "quic-v1"):
            self.register_transport(quic_name, quic_cls)

    def register_transport(
        self, protocol: str, transport_class: type[ITransport]
    ) -> None:
        """
        Associate a transport class with a protocol name.

        An existing entry for the same protocol is overwritten.

        :param protocol: The protocol identifier (e.g., "tcp", "ws")
        :param transport_class: The transport class to register
        """
        self._transports[protocol] = transport_class
        message = (
            f"Registered transport {transport_class.__name__} for protocol {protocol}"
        )
        logger.debug(message)

    def get_transport(self, protocol: str) -> type[ITransport] | None:
        """
        Look up the transport class registered for a protocol.

        :param protocol: The protocol identifier
        :return: The transport class or None if not found
        """
        try:
            return self._transports[protocol]
        except KeyError:
            return None

    def get_supported_protocols(self) -> list[str]:
        """Return the registered protocol names, in registration order."""
        return list(self._transports)

    def create_transport(
        self, protocol: str, upgrader: TransportUpgrader | None = None, **kwargs: Any
    ) -> ITransport | None:
        """
        Build a transport instance for a protocol.

        Construction depends on the protocol name:

        * ``ws`` / ``wss``: needs ``upgrader``; builds the lazily imported
          ``WebsocketTransport`` with ``tls_client_config``,
          ``tls_server_config`` and ``handshake_timeout`` (default 15.0)
          taken from ``kwargs``.
        * ``quic`` / ``quic-v1``: needs ``private_key`` in ``kwargs``; builds
          the lazily imported ``QUICTransport`` with that key and the
          optional ``config`` entry.
        * anything else: the registered class is called with no arguments.

        Note that for the WebSocket and QUIC names the registered class is
        only used to decide whether the protocol is known; the instance is
        built from the lazily imported class.

        :param protocol: The protocol identifier
        :param upgrader: The transport upgrader instance (required for WebSocket)
        :param kwargs: Additional arguments for transport construction
        :return: Transport instance or None if protocol not supported or creation fails
        """
        registered_cls = self.get_transport(protocol)
        if registered_cls is None:
            return None

        try:
            if protocol in ("ws", "wss"):
                if upgrader is None:
                    logger.warning(
                        f"WebSocket transport '{protocol}' requires upgrader"
                    )
                    return None
                websocket_cls = _get_websocket_transport()
                websocket_options = {
                    "tls_client_config": kwargs.get("tls_client_config"),
                    "tls_server_config": kwargs.get("tls_server_config"),
                    "handshake_timeout": kwargs.get("handshake_timeout", 15.0),
                }
                return websocket_cls(upgrader, **websocket_options)

            if protocol in ("quic", "quic-v1"):
                private_key = kwargs.get("private_key")
                if private_key is None:
                    logger.warning(f"QUIC transport '{protocol}' requires private_key")
                    return None
                quic_cls = _get_quic_transport()
                return quic_cls(private_key, kwargs.get("config"))

            # Remaining transports (e.g. TCP) take no constructor arguments.
            return registered_cls()
        except Exception as e:
            # Construction failures are reported to the caller as None.
            logger.error(f"Failed to create transport for protocol {protocol}: {e}")
            return None
|
|
|
|
|
|
# Global transport registry instance
|
|
# Constructed at import time: the constructor registers the default
# transports, which in turn runs the lazy WebSocket and QUIC imports.
_global_registry = TransportRegistry()
|
|
|
|
|
|
def get_transport_registry() -> TransportRegistry:
    """
    Return the module-wide transport registry.

    :return: The shared ``TransportRegistry`` instance held by this module
    """
    shared_registry = _global_registry
    return shared_registry
|
|
|
|
|
|
def register_transport(protocol: str, transport_class: type[ITransport]) -> None:
    """
    Register a transport class in the module-wide registry.

    :param protocol: The protocol identifier (e.g., "tcp", "ws")
    :param transport_class: The transport class to register
    """
    shared_registry = _global_registry
    shared_registry.register_transport(protocol, transport_class)
|
|
|
|
|
|
def create_transport_for_multiaddr(
    maddr: Multiaddr, upgrader: TransportUpgrader, **kwargs: Any
) -> ITransport | None:
    """
    Pick and build the transport that matches a multiaddr.

    Protocol families are considered in a fixed order of preference, and
    only the first family present in the multiaddr is examined:

    1. QUIC (``quic`` / ``quic-v1``), e.g. ``/ip4/127.0.0.1/udp/4001/quic-v1``
    2. WebSocket (``ws`` / ``wss`` / ``tls``), e.g.
       ``/ip4/127.0.0.1/tcp/8080/ws``, ``/ip4/127.0.0.1/tcp/8080/wss``,
       ``/ip4/127.0.0.1/tcp/8080/tls/ws``,
       ``/ip4/127.0.0.1/tcp/8080/tls/sni/example.com/ws``
    3. TCP, e.g. ``/ip4/127.0.0.1/tcp/8080``

    If that family's structural validation fails, no later family is tried
    and None is returned after a warning is logged.

    :param maddr: The multiaddr to create transport for
    :param upgrader: The transport upgrader instance
    :param kwargs: Additional arguments for transport construction
                   (e.g., private_key for QUIC); not forwarded for TCP
    :return: Transport instance or None if no suitable transport found
    """
    try:
        names = [component.name for component in maddr.protocols()]

        # Name of the registry entry to instantiate, once one is chosen.
        selected: str | None = None
        extra_args: dict[str, Any] = kwargs

        if any(name in names for name in ("quic", "quic-v1")):
            validate_quic = _get_quic_validation()
            if validate_quic(maddr):
                # quic-v1 wins when both names are present.
                selected = "quic-v1" if "quic-v1" in names else "quic"
        elif any(name in names for name in ("ws", "wss", "tls")):
            if is_valid_websocket_multiaddr(maddr):
                # Either wss or a tls component marks the connection secure.
                is_secure = "wss" in names or "tls" in names
                selected = "wss" if is_secure else "ws"
        elif "tcp" in names:
            if _is_valid_tcp_multiaddr(maddr):
                selected = "tcp"
                # The TCP transport is created without the extra kwargs.
                extra_args = {}

        if selected is not None:
            return _global_registry.create_transport(selected, upgrader, **extra_args)

        # No family matched, or the matching family failed validation.
        logger.warning(
            f"No supported transport protocol found or invalid structure in "
            f"multiaddr: {maddr}"
        )
        return None

    except Exception as e:
        # Handle any errors gracefully (e.g., invalid multiaddr)
        logger.warning(f"Error processing multiaddr {maddr}: {e}")
        return None
|
|
|
|
|
|
def get_supported_transport_protocols() -> list[str]:
    """
    List the protocol names known to the module-wide registry.

    :return: Protocol identifiers as returned by the registry's
        ``get_supported_protocols`` method
    """
    shared_registry = _global_registry
    return shared_registry.get_supported_protocols()
|