Files
py-libp2p/libp2p/transport/websocket/multiaddr_utils.py
acul71 396812e84a Experimental: Add comprehensive WebSocket and WSS implementation with tests
- Implemented full WSS support with TLS configuration
- Added handshake timeout and connection state tracking
- Created comprehensive test suite with 13+ WSS unit tests
- Added Python-to-Python WebSocket peer-to-peer tests
- Implemented multiaddr parsing for /ws, /wss, /tls/ws formats
- Added connection state tracking and concurrent close handling
- Created standalone WebSocket client for testing
- Fixed circular import issues with multiaddr utilities
- Added debug tools for WebSocket URL testing

All WebSocket transport functionality is complete and working.
Tests demonstrate WebSocket transport works correctly at the transport layer.
Higher-level libp2p protocol compatibility issues remain (same as JS interop).
2025-09-07 23:44:17 +02:00

203 lines
6.8 KiB
Python

"""
WebSocket multiaddr parsing utilities.
"""
from typing import NamedTuple
from multiaddr import Multiaddr
from multiaddr.protocols import Protocol
class ParsedWebSocketMultiaddr(NamedTuple):
    """
    Result of splitting a WebSocket multiaddr into its security settings
    and the remaining (non-WebSocket) address.
    """

    # True when the address requests a secure WebSocket connection.
    is_wss: bool
    # Value of the /sni component, or None when no SNI is reported.
    sni: str | None
    # The multiaddr left over once the WebSocket/TLS components are removed.
    rest_multiaddr: Multiaddr
def parse_websocket_multiaddr(maddr: Multiaddr) -> ParsedWebSocketMultiaddr:
    """
    Parse a WebSocket multiaddr and extract security information.

    Supported layouts (after the network and tcp components)::

        /ws                      insecure
        /wss                     secure
        /tls/ws                  secure
        /tls/sni/<host>/ws       secure, with SNI host name

    :param maddr: The multiaddr to parse
    :return: Parsed WebSocket multiaddr information; ``rest_multiaddr`` is
        ``maddr`` with the WebSocket/TLS components (and anything following
        them) removed
    :raises ValueError: If the multiaddr is not a valid WebSocket multiaddr
    """
    # Validation guarantees that a ws/wss component is present, so no
    # further "is this a WebSocket address" check is needed below.
    if not is_valid_websocket_multiaddr(maddr):
        raise ValueError(f"Not a valid WebSocket multiaddr: {maddr}")

    has_wss = False
    has_tls = False
    sni = None

    # Iterate (protocol, value) pairs: the protocol object only describes the
    # protocol itself, the component value (the SNI host) comes from the
    # multiaddr.
    for protocol, value in maddr.items():
        if protocol.name == "wss":
            has_wss = True
        elif protocol.name == "tls":
            has_tls = True
        elif protocol.name == "sni":
            sni = value

    if has_tls:
        # /tls/ws or /tls/sni/<host>/ws: peel the layers off from the end so
        # that only the part in front of /tls remains.
        rest_multiaddr = maddr.decapsulate(Multiaddr("/ws"))
        if sni is not None:
            rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr(f"/sni/{sni}"))
        rest_multiaddr = rest_multiaddr.decapsulate(Multiaddr("/tls"))
        return ParsedWebSocketMultiaddr(
            is_wss=True, sni=sni, rest_multiaddr=rest_multiaddr
        )

    # Plain /ws or /wss: drop the WebSocket component and anything after it.
    # SNI is only reported for the /tls/... forms.
    ws_component = "/wss" if has_wss else "/ws"
    return ParsedWebSocketMultiaddr(
        is_wss=has_wss,
        sni=None,
        rest_multiaddr=maddr.decapsulate(Multiaddr(ws_component)),
    )
def is_valid_websocket_multiaddr(maddr: Multiaddr) -> bool:
    """
    Validate that a multiaddr has a valid WebSocket structure.

    Accepted layouts::

        /<net>/.../tcp/<port>/ws                     (insecure)
        /<net>/.../tcp/<port>/wss                    (secure)
        /<net>/.../tcp/<port>/tls/ws                 (secure with TLS)
        /<net>/.../tcp/<port>/tls/sni/<host>/ws      (secure with SNI)

    where ``<net>`` is one of ip4, ip6, dns, dns4, dns6. Each layout may be
    followed by any number of ``/p2p`` components. ``/wss`` is not accepted
    after ``/tls``.

    :param maddr: The multiaddr to validate
    :return: True if valid WebSocket structure, False otherwise
    """
    network_protocols = ("ip4", "ip6", "dns", "dns4", "dns6")
    valid_continuations = ("p2p",)

    try:
        names = [protocol.name for protocol in maddr.protocols()]
    except Exception:
        # Deliberate best effort: a validator reports malformed or
        # non-multiaddr input as "invalid" instead of raising.
        return False

    # Must have at least 3 protocols: network + tcp + ws/wss
    if len(names) < 3:
        return False
    if names[0] not in network_protocols or names[1] != "tcp":
        return False

    # The WebSocket part must start directly after /tcp; stray protocols in
    # between (for example /sni without /tls) make the address invalid.
    transport = names[2:]
    if transport[0] in ("ws", "wss"):
        consumed = 1
    elif transport[:2] == ["tls", "ws"]:
        consumed = 2
    elif transport[:3] == ["tls", "sni", "ws"]:
        consumed = 3
    else:
        return False

    # Anything after the WebSocket protocol must be an allowed continuation.
    return all(name in valid_continuations for name in transport[consumed:])