mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
feat: implement WebSocket transport with transport registry system - Add transport_registry.py for centralized transport management - Integrate WebSocket transport with new registry - Add comprehensive test suite for transport registry - Include WebSocket examples and demos - Update transport initialization and swarm integration
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
from trio.abc import Stream
|
||||
import trio
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.io.exceptions import IOException
|
||||
@ -6,19 +7,20 @@ from libp2p.io.exceptions import IOException
|
||||
|
||||
class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
Wraps a raw trio.abc.Stream from an established websocket connection.
|
||||
This bypasses message-framing issues and provides the raw stream
|
||||
Wraps a WebSocketConnection to provide the raw stream interface
|
||||
that libp2p protocols expect.
|
||||
"""
|
||||
|
||||
_stream: Stream
|
||||
|
||||
def __init__(self, stream: Stream):
|
||||
self._stream = stream
|
||||
def __init__(self, ws_connection, ws_context=None):
|
||||
self._ws_connection = ws_connection
|
||||
self._ws_context = ws_context
|
||||
self._read_buffer = b""
|
||||
self._read_lock = trio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
try:
|
||||
await self._stream.send_all(data)
|
||||
# Send as a binary WebSocket message
|
||||
await self._ws_connection.send_message(data)
|
||||
except Exception as e:
|
||||
raise IOException from e
|
||||
|
||||
@ -26,24 +28,68 @@ class P2PWebSocketConnection(ReadWriteCloser):
|
||||
"""
|
||||
Read up to n bytes (if n is given), else read up to 64KiB.
|
||||
"""
|
||||
try:
|
||||
if n is None:
|
||||
# read a reasonable chunk
|
||||
return await self._stream.receive_some(2**16)
|
||||
return await self._stream.receive_some(n)
|
||||
except Exception as e:
|
||||
raise IOException from e
|
||||
async with self._read_lock:
|
||||
try:
|
||||
# If we have buffered data, return it
|
||||
if self._read_buffer:
|
||||
if n is None:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
return result
|
||||
else:
|
||||
if len(self._read_buffer) >= n:
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[n:]
|
||||
return result
|
||||
else:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
return result
|
||||
|
||||
# Get the next WebSocket message
|
||||
message = await self._ws_connection.get_message()
|
||||
if isinstance(message, str):
|
||||
message = message.encode('utf-8')
|
||||
|
||||
# Add to buffer
|
||||
self._read_buffer = message
|
||||
|
||||
# Return requested amount
|
||||
if n is None:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
return result
|
||||
else:
|
||||
if len(self._read_buffer) >= n:
|
||||
result = self._read_buffer[:n]
|
||||
self._read_buffer = self._read_buffer[n:]
|
||||
return result
|
||||
else:
|
||||
result = self._read_buffer
|
||||
self._read_buffer = b""
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
raise IOException from e
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._stream.aclose()
|
||||
# Close the WebSocket connection
|
||||
await self._ws_connection.aclose()
|
||||
# Exit the context manager if we have one
|
||||
if self._ws_context is not None:
|
||||
await self._ws_context.__aexit__(None, None, None)
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
sock = getattr(self._stream, "socket", None)
|
||||
if sock:
|
||||
try:
|
||||
addr = sock.getpeername()
|
||||
if isinstance(addr, tuple) and len(addr) >= 2:
|
||||
return str(addr[0]), int(addr[1])
|
||||
except OSError:
|
||||
return None
|
||||
# Try to get remote address from the WebSocket connection
|
||||
try:
|
||||
remote = self._ws_connection.remote
|
||||
if hasattr(remote, 'address') and hasattr(remote, 'port'):
|
||||
return str(remote.address), int(remote.port)
|
||||
elif isinstance(remote, str):
|
||||
# Parse address:port format
|
||||
if ':' in remote:
|
||||
host, port = remote.rsplit(':', 1)
|
||||
return host, int(port)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user