mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
- Fix INotifee interface compliance in WebSocket demo
- Fix handler function signatures to be async (THandler compatibility)
- Fix is_closed method usage with proper type checking
- Fix pytest.raises multiple exception type issue
- Fix line length violations (E501) across multiple files
- Add debugging logging to Noise security module for troubleshooting
- Update WebSocket transport examples and tests
- Improve transport registry error handling
145 lines
5.8 KiB
Python
145 lines
5.8 KiB
Python
import logging
|
|
from typing import Any
|
|
|
|
import trio
|
|
|
|
from libp2p.io.abc import ReadWriteCloser
|
|
from libp2p.io.exceptions import IOException
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class P2PWebSocketConnection(ReadWriteCloser):
    """
    Wraps a WebSocketConnection to provide the raw stream interface
    that libp2p protocols expect.

    WebSocket delivers whole messages while the callers of this class ask
    for byte counts, so received messages are appended to an internal
    buffer and handed out in caller-sized slices.
    """

    def __init__(self, ws_connection: Any, ws_context: Any = None) -> None:
        """
        Store the wrapped connection and set up the read buffer.

        :param ws_connection: object providing the awaitable methods
            ``send_message``, ``get_message`` and ``aclose``; its ``remote``
            attribute is consulted by :meth:`get_remote_address`.
        :param ws_context: optional async context manager whose
            ``__aexit__`` is awaited by :meth:`close`.
        """
        self._ws_connection = ws_connection
        self._ws_context = ws_context
        # Bytes received from the peer but not yet handed to a reader.
        self._read_buffer = b""
        # Serializes readers so buffer slicing is never interleaved.
        self._read_lock = trio.Lock()

    async def write(self, data: bytes) -> None:
        """
        Send ``data`` as a single WebSocket message.

        :param data: payload to send.
        :raises IOException: if the underlying connection fails to send.
        """
        try:
            logger.debug(f"WebSocket writing {len(data)} bytes")
            # Send as a binary WebSocket message
            await self._ws_connection.send_message(data)
            logger.debug(f"WebSocket wrote {len(data)} bytes successfully")
        except Exception as e:
            logger.error(f"WebSocket write failed: {e}")
            raise IOException(f"WebSocket write failed: {e}") from e

    async def _receive_into_buffer(self) -> None:
        """Await the next WebSocket message and append it to the read buffer."""
        # Get the next WebSocket message and treat it as a byte stream
        # This mimics the Go implementation's NextReader() approach
        message = await self._ws_connection.get_message()
        if isinstance(message, str):
            # Text frames are folded into the byte stream as UTF-8.
            message = message.encode("utf-8")
        logger.debug(f"WebSocket read received message: {len(message)} bytes")
        self._read_buffer += message

    async def read(self, n: int | None = None) -> bytes:
        """
        Read bytes from the connection.

        This implementation provides byte-level access to WebSocket messages,
        which is required for Noise protocol handshake.

        :param n: number of bytes to read. When given, the call waits until
            exactly ``n`` bytes are buffered and returns exactly ``n`` bytes
            (``n == 0`` returns ``b""`` without touching the connection).
            When ``None``, everything currently buffered is returned; if the
            buffer is empty the call first waits for incoming data.
        :return: the bytes read.
        :raises ValueError: if ``n`` is negative.
        :raises IOException: if receiving from the underlying connection fails.
        """
        if n is not None and n < 0:
            raise ValueError(
                "the number of bytes to read `n` must be non-negative or "
                f"`None`, got n={n}"
            )

        async with self._read_lock:
            try:
                logger.debug(
                    f"WebSocket read requested: n={n}, "
                    f"buffer_size={len(self._read_buffer)}"
                )

                if n is None:
                    # Returning b"" here would look like EOF to stream
                    # consumers, so wait until some data has arrived.
                    while not self._read_buffer:
                        await self._receive_into_buffer()
                    result = self._read_buffer
                    self._read_buffer = b""
                    logger.debug(
                        f"WebSocket read returning all data: {len(result)} bytes"
                    )
                    return result

                # A request may span several WebSocket messages, so keep
                # receiving until the buffer can satisfy it in full.
                while len(self._read_buffer) < n:
                    logger.debug(
                        f"WebSocket read getting more data: "
                        f"buffer_size={len(self._read_buffer)}, need={n}"
                    )
                    await self._receive_into_buffer()

                result = self._read_buffer[:n]
                # Keep the surplus for the next read.
                self._read_buffer = self._read_buffer[n:]
                logger.debug(f"WebSocket read returning exact {len(result)} bytes")
                return result

            except Exception as e:
                logger.error(f"WebSocket read failed: {e}")
                raise IOException(f"WebSocket read failed: {e}") from e

    async def close(self) -> None:
        """
        Close the WebSocket connection and exit the stored context manager.

        The context manager, if any, is exited even when closing the
        connection raises, and it is exited at most once across repeated
        calls.
        """
        # Detach the context first so a second close() cannot re-exit it.
        ws_context, self._ws_context = self._ws_context, None
        try:
            # Close the WebSocket connection
            await self._ws_connection.aclose()
        finally:
            # Exit the context manager if we have one
            if ws_context is not None:
                await ws_context.__aexit__(None, None, None)

    def get_remote_address(self) -> tuple[str, int] | None:
        """
        Return the peer's ``(host, port)`` if it can be determined.

        The connection's ``remote`` attribute is accepted either as an
        object with ``address`` and ``port`` attributes or as a
        ``"host:port"`` string. This is best-effort: any failure yields
        ``None`` rather than an exception.

        :return: ``(host, port)`` or ``None`` when unavailable.
        """
        # Try to get remote address from the WebSocket connection
        try:
            remote = self._ws_connection.remote
            if hasattr(remote, "address") and hasattr(remote, "port"):
                return str(remote.address), int(remote.port)
            if isinstance(remote, str) and ":" in remote:
                # Parse address:port format; split on the last colon so
                # hosts that themselves contain colons stay intact.
                host, port = remote.rsplit(":", 1)
                return host, int(port)
        except Exception as e:
            # Deliberately non-fatal, but leave a trace for debugging.
            logger.debug(f"WebSocket remote address unavailable: {e}")
        return None
|