mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
- Fix type annotation errors in transport_registry.py and __init__.py - Fix line length violations in test files (E501 errors) - Fix missing return type annotations - Fix cryptography NameAttribute type errors with type: ignore - Fix ExceptionGroup import for cross-version compatibility - Fix test failure in test_wss_listen_without_tls_config by handling ExceptionGroup - Fix len() calls with None arguments in test_tcp_data_transfer.py - Fix missing attribute access errors on interface types - Fix boolean type expectation errors in test_js_ws_ping.py - Fix nursery context manager type errors All tests now pass and linting is clean.
167 lines
6.6 KiB
Python
167 lines
6.6 KiB
Python
import logging
|
|
import time
|
|
from typing import Any
|
|
|
|
import trio
|
|
|
|
from libp2p.io.abc import ReadWriteCloser
|
|
from libp2p.io.exceptions import IOException
|
|
|
|
# Module-level logger named after this module, so log records from this file
# can be filtered or configured independently.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class P2PWebSocketConnection(ReadWriteCloser):
    """
    Wraps a WebSocketConnection to provide the raw stream interface
    that libp2p protocols expect.

    WebSocket messages are flattened into a byte stream: each incoming
    message is buffered and handed out by :meth:`read`, and each
    :meth:`write` call is sent as one WebSocket message.

    :param ws_connection: the underlying WebSocket connection; it is used via
        ``get_message()``, ``send_message()`` and ``aclose()`` coroutines and,
        optionally, a ``remote`` attribute.
    :param ws_context: optional context object stored alongside the
        connection. It is not used by this class (``close`` does not exit it).
    :param is_secure: flag reported as ``"secure"`` by :meth:`conn_state`.
    :param read_timeout: seconds :meth:`read` waits for a new message before
        giving up and returning ``b""``. Defaults to 1.0, the value that was
        previously hard-coded.
    """

    def __init__(
        self,
        ws_connection: Any,
        ws_context: Any = None,
        is_secure: bool = False,
        read_timeout: float = 1.0,
    ) -> None:
        self._ws_connection = ws_connection
        self._ws_context = ws_context
        self._is_secure = is_secure
        self._read_timeout = read_timeout
        self._read_buffer = b""
        self._read_lock = trio.Lock()
        # Wall-clock start time, kept for callers that inspect it directly.
        self._connection_start_time = time.time()
        # Monotonic start time used for durations; immune to clock changes.
        self._connection_start_monotonic = time.monotonic()
        self._bytes_read = 0
        self._bytes_written = 0
        self._closed = False
        self._close_lock = trio.Lock()

    async def write(self, data: bytes) -> None:
        """
        Send ``data`` to the peer as one binary WebSocket message.

        :raises IOException: if the connection is closed or the send fails.
        """
        if self._closed:
            raise IOException("Connection is closed")

        # NOTE(review): the underlying send_message is assumed to accept only
        # ``bytes``/``str``, so other bytes-like buffers are copied to bytes.
        if isinstance(data, (bytearray, memoryview)):
            payload = bytes(data)
        else:
            payload = data
        try:
            await self._ws_connection.send_message(payload)
        except Exception as e:
            logger.error("WebSocket write failed: %s", e)
            raise IOException(f"WebSocket write failed: {e}") from e
        self._bytes_written += len(payload)

    async def read(self, n: int | None = None) -> bytes:
        """
        Read up to ``n`` bytes from the connection.

        Incoming WebSocket messages are buffered and served with TCP-like
        semantics:

        1. When the buffer is empty, wait up to ``read_timeout`` seconds for
           one new non-empty message; at most one message is received per
           call.
        2. If ``n`` is None or negative, return everything currently buffered.
        3. Otherwise return at most ``n`` bytes; any remainder stays buffered
           for the next call.
        4. Return ``b""`` (rather than raising) when ``n`` is 0, when no
           message arrives in time, or when receiving fails (e.g. the
           connection was closed by the peer).

        :raises IOException: if this wrapper has already been closed, or on
            an unexpected error while reading.
        """
        if self._closed:
            raise IOException("Connection is closed")

        # Nothing was asked for: answer at once instead of waiting on the socket.
        if n == 0:
            return b""

        async with self._read_lock:
            try:
                if not self._read_buffer and not await self._fill_buffer():
                    return b""

                # Negative n previously sliced off the buffer's tail; treat it
                # like None ("everything available"), as file objects do.
                if n is None or n < 0:
                    result = self._read_buffer
                else:
                    result = self._read_buffer[:n]
                self._read_buffer = self._read_buffer[len(result) :]
                self._bytes_read += len(result)
                return result
            except Exception as e:
                logger.error("WebSocket read failed: %s", e)
                raise IOException(f"WebSocket read failed: {e}") from e

    async def _fill_buffer(self) -> bool:
        """
        Wait for the next non-empty WebSocket message and store it in the
        (currently empty) read buffer.

        :return: True if data was buffered; False if no message arrived within
            ``read_timeout`` or ``get_message`` raised.
        """
        try:
            with trio.fail_after(self._read_timeout):
                message = await self._ws_connection.get_message()
                # An empty message carries no stream data; returning it would
                # look like EOF to the caller, so wait for the next one.
                while not message:
                    message = await self._ws_connection.get_message()
                if isinstance(message, str):
                    message = message.encode("utf-8")
        except trio.TooSlowError:
            # No message within the timeout.
            return False
        except Exception:
            # Deliberate best-effort: a failed receive (typically a closed
            # connection) is reported to read() as "no data".
            return False
        self._read_buffer = message
        return True

    async def close(self) -> None:
        """
        Close the WebSocket connection. Idempotent: later calls return at once.

        Errors from the underlying ``aclose()`` are logged, not raised, and the
        wrapper is marked closed regardless.
        """
        async with self._close_lock:
            if self._closed:
                return  # Already closed

            logger.debug("WebSocket connection closing")
            try:
                # Close the connection directly instead of exiting
                # ``self._ws_context``: the context manager was observed to
                # cause cancel scope corruption.
                # NOTE(review): ws_context is therefore never exited here.
                logger.debug("WebSocket closing connection directly")
                await self._ws_connection.aclose()
            except Exception as e:
                logger.error("WebSocket close error: %s", e)
            finally:
                self._closed = True
                logger.debug("WebSocket connection closed")

    def conn_state(self) -> dict[str, Any]:
        """
        Return connection state information similar to Go's ConnState() method.

        :return: dict with ``transport``, ``secure``, ``connection_duration``
            (seconds since construction, measured monotonically),
            ``bytes_read``, ``bytes_written`` and ``total_bytes``.
        """
        elapsed = time.monotonic() - self._connection_start_monotonic
        return {
            "transport": "websocket",
            "secure": self._is_secure,
            "connection_duration": elapsed,
            "bytes_read": self._bytes_read,
            "bytes_written": self._bytes_written,
            "total_bytes": self._bytes_read + self._bytes_written,
        }

    def get_remote_address(self) -> tuple[str, int] | None:
        """
        Best-effort lookup of the peer's ``(host, port)``.

        Reads ``remote`` from the wrapped connection, accepting either an
        object with ``address``/``port`` attributes or a ``"host:port"``
        string (an IPv6 host may be bracketed, e.g. ``"[::1]:443"``).

        :return: ``(host, port)``, or None if it cannot be determined.
        """
        try:
            remote = self._ws_connection.remote
            if hasattr(remote, "address") and hasattr(remote, "port"):
                return str(remote.address), int(remote.port)
            if isinstance(remote, str) and ":" in remote:
                host, port = remote.rsplit(":", 1)
                # Strip the brackets of an IPv6 literal such as "[::1]".
                if host.startswith("[") and host.endswith("]"):
                    host = host[1:-1]
                return host, int(port)
        except Exception:
            # Deliberate best-effort: any failure means "address unknown".
            return None
        return None
|