diff --git a/.gitignore b/.gitignore index e46cc8aa..e17714b5 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,7 @@ env.bak/ #lockfiles uv.lock poetry.lock +tests/interop/js_libp2p/js_node/node_modules/ +tests/interop/js_libp2p/js_node/package-lock.json +tests/interop/js_libp2p/js_node/src/node_modules/ +tests/interop/js_libp2p/js_node/src/package-lock.json diff --git a/libp2p/transport/__init__.py b/libp2p/transport/__init__.py index e69de29b..62cc5f06 100644 --- a/libp2p/transport/__init__.py +++ b/libp2p/transport/__init__.py @@ -0,0 +1,7 @@ +from .tcp.tcp import TCP +from .websocket.transport import WebsocketTransport + +__all__ = [ + "TCP", + "WebsocketTransport", +] diff --git a/libp2p/transport/websocket/connection.py b/libp2p/transport/websocket/connection.py new file mode 100644 index 00000000..b8c23603 --- /dev/null +++ b/libp2p/transport/websocket/connection.py @@ -0,0 +1,49 @@ +from trio.abc import Stream + +from libp2p.io.abc import ReadWriteCloser +from libp2p.io.exceptions import IOException + + +class P2PWebSocketConnection(ReadWriteCloser): + """ + Wraps a raw trio.abc.Stream from an established websocket connection. + This bypasses message-framing issues and provides the raw stream + that libp2p protocols expect. + """ + + _stream: Stream + + def __init__(self, stream: Stream): + self._stream = stream + + async def write(self, data: bytes) -> None: + try: + await self._stream.send_all(data) + except Exception as e: + raise IOException from e + + async def read(self, n: int | None = None) -> bytes: + """ + Read up to n bytes (if n is given), else read up to 64KiB. + """ + try: + if n is None: + # read a reasonable chunk + return await self._stream.receive_some(2**16) + return await self._stream.receive_some(n) + except Exception as e: + raise IOException from e + + async def close(self) -> None: + await self._stream.aclose() + + def get_remote_address(self) -> tuple[str, int] | None: + sock = getattr(self._stream, "socket", None) + if sock: + try: + addr = sock.getpeername() + if isinstance(addr, tuple) and len(addr) >= 2: + return str(addr[0]), int(addr[1]) + except OSError: + return None + return None diff --git a/libp2p/transport/websocket/listener.py b/libp2p/transport/websocket/listener.py new file mode 100644 index 00000000..be3cc035 --- /dev/null +++ b/libp2p/transport/websocket/listener.py @@ -0,0 +1,81 @@ +import logging +import socket +from typing import Any + +from multiaddr import Multiaddr +import trio +from trio_typing import TaskStatus +from trio_websocket import serve_websocket + +from libp2p.abc import IListener +from libp2p.custom_types import THandler +from libp2p.network.connection.raw_connection import RawConnection + +from .connection import P2PWebSocketConnection + +logger = logging.getLogger("libp2p.transport.websocket.listener") + + +class WebsocketListener(IListener): + """ + Listen on /ip4/.../tcp/.../ws addresses, handshake WS, wrap into RawConnection. + """ + + def __init__(self, handler: THandler) -> None: + self._handler = handler + self._server = None + + async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool: + addr_str = str(maddr) + if addr_str.endswith("/wss"): + raise NotImplementedError("/wss (TLS) not yet supported") + + host = ( + maddr.value_for_protocol("ip4") + or maddr.value_for_protocol("ip6") + or maddr.value_for_protocol("dns") + or maddr.value_for_protocol("dns4") + or maddr.value_for_protocol("dns6") + or "0.0.0.0" + ) + port = int(maddr.value_for_protocol("tcp")) + + async def serve( + task_status: TaskStatus[Any] = trio.TASK_STATUS_IGNORED, + ) -> None: + # positional ssl_context=None + self._server = await serve_websocket( + self._handle_connection, host, port, None + ) + task_status.started() + await self._server.wait_closed() + + await nursery.start(serve) + return True + + async def _handle_connection(self, websocket: Any) -> None: + try: + # use raw transport_stream + conn = P2PWebSocketConnection(websocket.stream) + raw = RawConnection(conn, initiator=False) + await self._handler(raw) + except Exception as e: + logger.debug("WebSocket connection error: %s", e) + + def get_addrs(self) -> tuple[Multiaddr, ...]: + if not self._server or not self._server.sockets: + return () + addrs = [] + for sock in self._server.sockets: + host, port = sock.getsockname()[:2] + if sock.family == socket.AF_INET6: + addr = Multiaddr(f"/ip6/{host}/tcp/{port}/ws") + else: + addr = Multiaddr(f"/ip4/{host}/tcp/{port}/ws") + addrs.append(addr) + return tuple(addrs) + + async def close(self) -> None: + if self._server: + self._server.close() + await self._server.wait_closed() diff --git a/libp2p/transport/websocket/transport.py b/libp2p/transport/websocket/transport.py new file mode 100644 index 00000000..4085b556 --- /dev/null +++ b/libp2p/transport/websocket/transport.py @@ -0,0 +1,49 @@ +from multiaddr import Multiaddr +from trio_websocket import open_websocket_url + +from libp2p.abc import IListener, ITransport +from libp2p.custom_types import THandler +from libp2p.network.connection.raw_connection import RawConnection +from libp2p.transport.exceptions import OpenConnectionError + +from .connection import P2PWebSocketConnection +from .listener import WebsocketListener + + +class WebsocketTransport(ITransport): + """ + Libp2p WebSocket transport: dial and listen on /ip4/.../tcp/.../ws + """ + + async def dial(self, maddr: Multiaddr) -> RawConnection: + text = str(maddr) + if text.endswith("/wss"): + raise NotImplementedError("/wss (TLS) not yet supported") + if not text.endswith("/ws"): + raise ValueError(f"WebsocketTransport only supports /ws, got {maddr}") + + host = ( + maddr.value_for_protocol("ip4") + or maddr.value_for_protocol("ip6") + or maddr.value_for_protocol("dns") + or maddr.value_for_protocol("dns4") + or maddr.value_for_protocol("dns6") + ) + if host is None: + raise ValueError(f"No host protocol found in {maddr}") + + port = int(maddr.value_for_protocol("tcp")) + uri = f"ws://{host}:{port}" + + try: + async with open_websocket_url(uri, ssl_context=None) as ws: + conn = P2PWebSocketConnection(ws.stream) # type: ignore[attr-defined] + return RawConnection(conn, initiator=True) + except Exception as e: + raise OpenConnectionError(f"Failed to dial WebSocket {maddr}: {e}") from e + + def create_listener(self, handler: THandler) -> IListener: # type: ignore[override] + """ + The type checker is incorrectly reporting this as an inconsistent override. + """ + return WebsocketListener(handler) diff --git a/pyproject.toml b/pyproject.toml index 259c6c17..b5feab5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "trio-typing>=0.0.4", "trio>=0.26.0", "fastecdsa==2.3.2; sys_platform != 'win32'", + "trio-websocket>=0.11.0", "zeroconf (>=0.147.0,<0.148.0)", ] classifiers = [ diff --git a/tests/interop/__init__.py b/tests/interop/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/interop/js_libp2p/js_node/src/package.json b/tests/interop/js_libp2p/js_node/src/package.json new file mode 100644 index 00000000..1a7a2547 --- /dev/null +++ b/tests/interop/js_libp2p/js_node/src/package.json @@ -0,0 +1,18 @@ +{ + "name": "src", + "version": "1.0.0", + "main": "ping.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "description": "", + "dependencies": { + "@libp2p/ping": "^2.0.36", + "@libp2p/websockets": "^9.2.18", + "libp2p": "^2.9.0", + "multiaddr": "^10.0.1" + } +} diff --git a/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs b/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs new file mode 100644 index 00000000..18988b43 --- /dev/null +++ b/tests/interop/js_libp2p/js_node/src/ws_ping_node.mjs @@ -0,0 +1,35 @@ +import { createLibp2p } from 'libp2p' +import { webSockets } from '@libp2p/websockets' +import { ping } from '@libp2p/ping' +import { plaintext } from '@libp2p/insecure' +import { mplex } from '@libp2p/mplex' + +async function main() { + const node = await createLibp2p({ + transports: [ webSockets() ], + connectionEncryption: [ plaintext() ], + streamMuxers: [ mplex() ], + services: { + // installs /ipfs/ping/1.0.0 handler + ping: ping() + }, + addresses: { + listen: ['/ip4/127.0.0.1/tcp/0/ws'] + } + }) + + await node.start() + + console.log(node.peerId.toString()) + for (const addr of node.getMultiaddrs()) { + console.log(addr.toString()) + } + + // Keep the process alive + await new Promise(() => {}) +} + +main().catch(err => { + console.error(err) + process.exit(1) +}) diff --git a/tests/interop/test_js_ws_ping.py b/tests/interop/test_js_ws_ping.py new file mode 100644 index 00000000..813c7cf2 --- /dev/null +++ b/tests/interop/test_js_ws_ping.py @@ -0,0 +1,85 @@ +import os +import signal +import subprocess + +import pytest +from multiaddr import Multiaddr +import trio +from trio.lowlevel import open_process + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.host.basic_host import BasicHost +from libp2p.network.swarm import Swarm +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo +from libp2p.peer.peerstore import PeerStore +from libp2p.security.insecure.transport import InsecureTransport +from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport + +PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" + + +@pytest.mark.trio +async def test_ping_with_js_node(): + # 1) Path to the JS node script + js_node_dir = os.path.join(os.path.dirname(__file__), "js_libp2p", "js_node", "src") + script_name = "ws_ping_node.mjs" + + # 2) Launch the JS libp2p node (long-running) + proc = await open_process( + ["node", script_name], + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + cwd=js_node_dir, + ) + try: + # 3) Read first two lines (PeerID and multiaddr) + buffer = b"" + with trio.fail_after(10): + while buffer.count(b"\n") < 2: + chunk = await proc.stdout.receive_some(1024) # type: ignore + if not chunk: + break + buffer += chunk + + lines = buffer.decode().strip().split("\n") + peer_id_line, addr_line = lines[0], lines[1] + peer_id = ID.from_base58(peer_id_line) + maddr = Multiaddr(addr_line) + + # 4) Set up Python host + key_pair = create_new_key_pair() + py_peer_id = ID.from_pubkey(key_pair.public_key) + peer_store = PeerStore() + peer_store.add_key_pair(py_peer_id, key_pair) + + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol(MPLEX_PROTOCOL_ID): Mplex}, + ) + transport = WebsocketTransport() + swarm = Swarm(py_peer_id, peer_store, upgrader, transport) + host = BasicHost(swarm) + + # 5) Connect to JS node + peer_info = PeerInfo(peer_id, [maddr]) + await host.connect(peer_info) + assert host.get_network().connections.get(peer_id) is not None + await trio.sleep(0.1) + + # 6) Ping protocol + stream = await host.new_stream(peer_id, [TProtocol("/ipfs/ping/1.0.0")]) + await stream.write(b"ping") + data = await stream.read(4) + assert data == b"pong" + + # 7) Cleanup + await host.close() + finally: + proc.send_signal(signal.SIGTERM) + await trio.sleep(0) diff --git a/tests/transport/__init__.py b/tests/transport/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/transport/test_websocket.py b/tests/transport/test_websocket.py new file mode 100644 index 00000000..412e1063 --- /dev/null +++ b/tests/transport/test_websocket.py @@ -0,0 +1,72 @@ +from collections.abc import Sequence +from typing import Any + +import pytest +from multiaddr import Multiaddr + +from libp2p.crypto.secp256k1 import create_new_key_pair +from libp2p.custom_types import TProtocol +from libp2p.host.basic_host import BasicHost +from libp2p.network.swarm import Swarm +from libp2p.peer.id import ID +from libp2p.peer.peerinfo import PeerInfo +from libp2p.peer.peerstore import PeerStore +from libp2p.security.insecure.transport import InsecureTransport +from libp2p.stream_muxer.mplex.mplex import MPLEX_PROTOCOL_ID, Mplex +from libp2p.transport.upgrader import TransportUpgrader +from libp2p.transport.websocket.transport import WebsocketTransport + +PLAINTEXT_PROTOCOL_ID = "/plaintext/1.0.0" + + +async def make_host( + listen_addrs: Sequence[Multiaddr] | None = None, +) -> tuple[BasicHost, Any | None]: + # 1) Identity + key_pair = create_new_key_pair() + peer_id = ID.from_pubkey(key_pair.public_key) + peer_store = PeerStore() + peer_store.add_key_pair(peer_id, key_pair) + + # 2) Upgrader + upgrader = TransportUpgrader( + secure_transports_by_protocol={ + TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair) + }, + muxer_transports_by_protocol={TProtocol(MPLEX_PROTOCOL_ID): Mplex}, + ) + + # 3) Transport + Swarm + Host + transport = WebsocketTransport() + swarm = Swarm(peer_id, peer_store, upgrader, transport) + host = BasicHost(swarm) + + # 4) Optionally run/listen + ctx = None + if listen_addrs: + ctx = host.run(listen_addrs) + await ctx.__aenter__() + + return host, ctx + + +@pytest.mark.trio +async def test_websocket_dial_and_listen(): + # Start server + server_host, server_ctx = await make_host([Multiaddr("/ip4/127.0.0.1/tcp/0/ws")]) + # Client + client_host, _ = await make_host(None) + + # Dial + peer_info = PeerInfo(server_host.get_id(), server_host.get_addrs()) + await client_host.connect(peer_info) + + # Verify connections + assert client_host.get_network().connections.get(server_host.get_id()) + assert server_host.get_network().connections.get(client_host.get_id()) + + # Cleanup + await client_host.close() + if server_ctx: + await server_ctx.__aexit__(None, None, None) + await server_host.close()