2 Commits

Author SHA1 Message Date
4b331d96a7 fix: pre-commit issues 2025-05-31 12:47:46 +01:00
1d9849cb43 feat: achieve ping interop py-libp2p - rust-libp2p
WORKING: Connection, handshake, Yamux setup, initial ping/pong. ISSUE: Frame parser corruption after 2-3 frames (boundary sync)
2025-05-31 12:25:12 +01:00
16 changed files with 466 additions and 390 deletions

View File

@ -1,9 +1,16 @@
import argparse
import logging
import os
import struct
from cryptography.hazmat.primitives.asymmetric import (
x25519,
)
import multiaddr
import trio
from libp2p import (
generate_new_rsa_identity,
new_host,
)
from libp2p.custom_types import (
@ -12,109 +19,496 @@ from libp2p.custom_types import (
from libp2p.network.stream.net_stream import (
INetStream,
)
from libp2p.peer.peerinfo import (
info_from_p2p_addr,
from libp2p.security.noise.transport import Transport as NoiseTransport
from libp2p.stream_muxer.yamux.yamux import (
Yamux,
YamuxStream,
)
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
# Configure detailed logging
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler("ping_debug.log", mode="w", encoding="utf-8"),
],
)
# Protocol constants - must match rust-libp2p exactly
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
PING_LENGTH = 32
RESP_TIMEOUT = 60
RESP_TIMEOUT = 30
MAX_FRAME_SIZE = 1024 * 1024 # 1MB max frame size
class InteropYamux(Yamux):
"""Enhanced Yamux with proper rust-libp2p interoperability"""
def __init__(self, *args, **kwargs):
logging.info("InteropYamux.__init__ called")
super().__init__(*args, **kwargs)
self.frame_count = 0
self.debug_frames = True
async def _read_exact_bytes(self, n):
"""Read exactly n bytes from the connection with proper error handling"""
if n == 0:
return b""
if n > MAX_FRAME_SIZE:
logging.error(f"Requested read size {n} exceeds maximum {MAX_FRAME_SIZE}")
return None
data = b""
while len(data) < n:
try:
remaining = n - len(data)
chunk = await self.secured_conn.read(remaining)
except (trio.ClosedResourceError, trio.BrokenResourceError):
logging.debug(
f"Connection closed while reading {n}"
f"bytes (got {len(data)}) for peer {self.peer_id}"
)
return None
except Exception as e:
logging.error(f"Error reading {n} bytes: {e}")
return None
if not chunk:
logging.debug(
f"Connection closed while reading {n}"
f"bytes (got {len(data)}) for peer {self.peer_id}"
)
return None
data += chunk
return data
async def handle_incoming(self):
"""Enhanced incoming frame handler with better error recovery"""
logging.info(f"Starting Yamux for {self.peer_id}")
consecutive_errors = 0
max_consecutive_errors = 3
while not self.event_shutting_down.is_set():
try:
# Read frame header (12 bytes)
header_data = await self._read_exact_bytes(12)
if header_data is None:
logging.debug(
f"Connection closed or incomplete"
f"header for peer {self.peer_id}"
)
break
# Quick sanity check for protocol data leakage
if (
b"/ipfs" in header_data
or b"multi" in header_data
or b"noise" in header_data
):
logging.error(
f"Protocol data in header position: {header_data.hex()}"
)
break
try:
# Unpack header: version, type, flags, stream_id, length
version, msg_type, flags, stream_id, length = struct.unpack(
">BBHII", header_data
)
# Validate header values strictly
if version != 0:
logging.error(f"Invalid yamux version {version}, expected 0")
break
if msg_type not in [0, 1, 2, 3]:
logging.error(f"Invalid message type {msg_type}, expected 0-3")
break
if length > MAX_FRAME_SIZE:
logging.error(f"Frame too large: {length} > {MAX_FRAME_SIZE}")
break
# Additional validation for ping frames
if msg_type == 2 and length != 4:
logging.error(
f"Invalid ping frame length: {length}, expected 4"
)
break
# Log frame details
logging.debug(
f"Received header for peer {self.peer_id}"
f": type={msg_type}, flags={flags},"
f"stream_id={stream_id}, length={length}"
)
consecutive_errors = 0 # Reset error counter on successful parse
except struct.error as e:
consecutive_errors += 1
logging.error(
f"Header parse error #{consecutive_errors}"
f": {e}, data: {header_data.hex()}"
)
if consecutive_errors >= max_consecutive_errors:
logging.error("Too many consecutive header parse errors")
break
continue
# Read payload if present
payload = b""
if length > 0:
payload = await self._read_exact_bytes(length)
if payload is None:
logging.debug(
f"Failed to read payload of"
f"{length} bytes for peer {self.peer_id}"
)
break
if len(payload) != length:
logging.error(
f"Payload length mismatch:"
f"got {len(payload)}, expected {length}"
)
break
# Process frame by type
if msg_type == 0: # Data frame
await self._handle_data_frame(stream_id, flags, payload)
elif msg_type == 1: # Window update
await self._handle_window_update(stream_id, payload)
elif msg_type == 2: # Ping frame
await self._handle_ping_frame(stream_id, flags, payload)
elif msg_type == 3: # GoAway frame
await self._handle_goaway_frame(payload)
break
except (trio.ClosedResourceError, trio.BrokenResourceError):
logging.debug(
f"Connection closed during frame processing for peer {self.peer_id}"
)
break
except Exception as e:
consecutive_errors += 1
logging.error(f"Frame processing error #{consecutive_errors}: {e}")
if consecutive_errors >= max_consecutive_errors:
logging.error("Too many consecutive frame processing errors")
break
await self._cleanup_on_error()
async def _handle_data_frame(self, stream_id, flags, payload):
"""Handle data frames with proper stream lifecycle"""
if stream_id == 0:
logging.warning("Received data frame for stream 0 (control stream)")
return
# Handle SYN flag - new stream creation
if flags & 0x1: # SYN flag
if stream_id in self.streams:
logging.warning(f"SYN received for existing stream {stream_id}")
else:
logging.debug(
f"Creating new stream {stream_id} for peer {self.peer_id}"
)
stream = YamuxStream(self, stream_id, is_outbound=False)
async with self.streams_lock:
self.streams[stream_id] = stream
# Send the new stream to the handler
await self.new_stream_send_channel.send(stream)
logging.debug(f"Sent stream {stream_id} to handler")
# Add data to stream buffer if stream exists
if stream_id in self.streams:
stream = self.streams[stream_id]
if payload:
# Add to stream's receive buffer
async with self.streams_lock:
if not hasattr(stream, "_receive_buffer"):
stream._receive_buffer = bytearray()
if not hasattr(stream, "_receive_event"):
stream._receive_event = trio.Event()
stream._receive_buffer.extend(payload)
stream._receive_event.set()
# Handle stream closure flags
if flags & 0x2: # FIN flag
stream.recv_closed = True
logging.debug(f"Stream {stream_id} received FIN")
if flags & 0x4: # RST flag
stream.reset_received = True
logging.debug(f"Stream {stream_id} received RST")
else:
if payload:
logging.warning(f"Received data for unknown stream {stream_id}")
async def _handle_window_update(self, stream_id, payload):
"""Handle window update frames"""
if len(payload) != 4:
logging.warning(f"Invalid window update payload length: {len(payload)}")
return
delta = struct.unpack(">I", payload)[0]
logging.debug(f"Window update: stream={stream_id}, delta={delta}")
async with self.streams_lock:
if stream_id in self.streams:
if not hasattr(self.streams[stream_id], "send_window"):
self.streams[stream_id].send_window = 256 * 1024 # Default window
self.streams[stream_id].send_window += delta
async def _handle_ping_frame(self, stream_id, flags, payload):
"""Handle ping/pong frames with proper validation"""
if len(payload) != 4:
logging.warning(f"Invalid ping payload length: {len(payload)} (expected 4)")
return
ping_value = struct.unpack(">I", payload)[0]
if flags & 0x1: # SYN flag - ping request
logging.debug(
f"Received ping request with value {ping_value} for peer {self.peer_id}"
)
# Send pong response (ACK flag = 0x2)
try:
pong_header = struct.pack(
">BBHII", 0, 2, 0x2, 0, 4
) # Version=0, Type=2, Flags=ACK, StreamID=0, Length=4
pong_payload = struct.pack(">I", ping_value)
await self.secured_conn.write(pong_header + pong_payload)
logging.debug(f"Sent pong response with value {ping_value}")
except Exception as e:
logging.error(f"Failed to send pong response: {e}")
else:
# Pong response
logging.debug(f"Received pong response with value {ping_value}")
async def _handle_goaway_frame(self, payload):
"""Handle GoAway frames"""
if len(payload) != 4:
logging.warning(f"Invalid GoAway payload length: {len(payload)}")
return
code = struct.unpack(">I", payload)[0]
logging.info(f"Received GoAway frame with code {code}")
self.event_shutting_down.set()
async def handle_ping(stream: INetStream) -> None:
while True:
try:
payload = await stream.read(PING_LENGTH)
peer_id = stream.muxed_conn.peer_id
if payload is not None:
print(f"received ping from {peer_id}")
logging.info(f"Handling ping stream from {peer_id}")
try:
with trio.fail_after(RESP_TIMEOUT):
# Read initial protocol negotiation
initial_data = await stream.read(1024)
logging.debug(
f"Received initial stream data from {peer_id}"
f": {initial_data.hex()} (length={len(initial_data)})"
)
if initial_data == b"/ipfs/ping/1.0.0\n":
logging.debug(
f"Confirmed /ipfs/ping/1.0.0 protocol negotiation from {peer_id}"
)
else:
logging.warning(f"Unexpected initial data: {initial_data!r}")
# Read ping payload
payload = await stream.read(PING_LENGTH)
if not payload:
logging.info(f"Stream closed by {peer_id}")
return
if len(payload) != PING_LENGTH:
logging.warning(
f"Unexpected payload length"
f" {len(payload)} from {peer_id}: {payload.hex()}"
)
return
logging.info(
f"Received ping from {peer_id}:"
f" {payload[:8].hex()}... (length={len(payload)})"
)
await stream.write(payload)
print(f"responded with pong to {peer_id}")
logging.info(f"Sent pong to {peer_id}: {payload[:8].hex()}...")
except trio.TooSlowError:
logging.warning(f"Ping timeout with {peer_id}")
except trio.BrokenResourceError:
logging.info(f"Connection broken with {peer_id}")
except Exception as e:
logging.error(f"Error handling ping from {peer_id}: {e}")
finally:
try:
await stream.close()
logging.debug(f"Closed ping stream with {peer_id}")
except Exception:
await stream.reset()
break
pass
async def send_ping(stream: INetStream) -> None:
peer_id = stream.muxed_conn.peer_id
try:
payload = b"\x01" * PING_LENGTH
print(f"sending ping to {stream.muxed_conn.peer_id}")
await stream.write(payload)
payload = os.urandom(PING_LENGTH)
logging.info(f"Sending ping to {peer_id}: {payload[:8].hex()}...")
with trio.fail_after(RESP_TIMEOUT):
await stream.write(payload)
logging.debug(f"Ping sent to {peer_id}")
response = await stream.read(PING_LENGTH)
if not response:
logging.error(f"No pong response from {peer_id}")
return
if len(response) != PING_LENGTH:
logging.warning(
f"Pong length mismatch: got {len(response)}, expected {PING_LENGTH}"
)
if response == payload:
print(f"received pong from {stream.muxed_conn.peer_id}")
logging.info(f"Ping successful! Pong matches from {peer_id}")
else:
logging.warning(f"Pong mismatch from {peer_id}")
except trio.TooSlowError:
logging.error(f"Ping timeout to {peer_id}")
except Exception as e:
print(f"error occurred : {e}")
logging.error(f"Error sending ping to {peer_id}: {e}")
finally:
try:
await stream.close()
except Exception:
pass
def create_noise_keypair():
try:
x25519_private_key = x25519.X25519PrivateKey.generate()
class NoisePrivateKey:
def __init__(self, key):
self._key = key
def to_bytes(self):
return self._key.private_bytes_raw()
def public_key(self):
return NoisePublicKey(self._key.public_key())
def get_public_key(self):
return NoisePublicKey(self._key.public_key())
class NoisePublicKey:
def __init__(self, key):
self._key = key
def to_bytes(self):
return self._key.public_bytes_raw()
return NoisePrivateKey(x25519_private_key)
except Exception as e:
logging.error(f"Failed to create Noise keypair: {e}")
return None
def info_from_p2p_addr(addr):
"""Extract peer info from multiaddr - you'll need to implement this"""
# This is a placeholder - you need to implement the actual parsing
# based on your libp2p implementation
async def run(port: int, destination: str) -> None:
localhost_ip = "127.0.0.1"
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
host = new_host(listen_addrs=[listen_addr])
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
try:
key_pair = generate_new_rsa_identity()
logging.debug("Generated RSA keypair")
noise_privkey = create_noise_keypair()
logging.debug("Generated Noise keypair")
except Exception as e:
logging.error(f"Key generation failed: {e}")
raise
noise_transport = NoiseTransport(key_pair, noise_privkey=noise_privkey)
logging.debug(f"Noise transport initialized: {noise_transport}")
sec_opt = {TProtocol("/noise"): noise_transport}
muxer_opt = {TProtocol(YAMUX_PROTOCOL_ID): InteropYamux}
logging.info(f"Using muxer: {muxer_opt}")
host = new_host(key_pair=key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
peer_id = host.get_id().pretty()
logging.info(f"Host peer ID: {peer_id}")
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery():
if not destination:
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
print(
"Run this from the same folder in another console:\n\n"
f"ping-demo -p {int(port) + 1} "
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
)
print("Waiting for incoming connection...")
logging.info(f"Server listening on {listen_addr}")
logging.info(f"Full address: {listen_addr}/p2p/{peer_id}")
logging.info("Waiting for connections...")
await trio.sleep_forever()
else:
maddr = multiaddr.Multiaddr(destination)
info = info_from_p2p_addr(maddr)
await host.connect(info)
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
nursery.start_soon(send_ping, stream)
return
await trio.sleep_forever()
def main() -> None:
description = """
This program demonstrates a simple p2p ping application using libp2p.
To use it, first run 'python ping.py -p <PORT>', where <PORT> is the port number.
Then, run another instance with 'python ping.py -p <ANOTHER_PORT> -d <DESTINATION>',
where <DESTINATION> is the multiaddress of the previous listener host.
"""
example_maddr = (
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
)
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="source port number"
)
parser.add_argument(
"-d",
"--destination",
type=str,
help=f"destination multiaddr string, e.g. {example_maddr}",
)
args = parser.parse_args()
if not args.port:
raise RuntimeError("failed to determine local port")
logging.info(f"Connecting to {info.peer_id}")
try:
trio.run(run, *(args.port, args.destination))
with trio.fail_after(30):
await host.connect(info)
logging.info(f"Connected to {info.peer_id}")
await trio.sleep(2.0)
logging.info(f"Opening ping stream to {info.peer_id}")
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
logging.info(f"Opened ping stream to {info.peer_id}")
await trio.sleep(0.5)
await send_ping(stream)
logging.info("Ping completed successfully")
logging.info("Keeping connection alive for 5 seconds...")
await trio.sleep(5.0)
except trio.TooSlowError:
logging.error(f"Connection timeout to {info.peer_id}")
raise
except Exception as e:
logging.error(f"Connection failed to {info.peer_id}: {e}")
import traceback
traceback.print_exc()
raise
def main():
parser = argparse.ArgumentParser(
description="libp2p ping with Rust interoperability"
)
parser.add_argument(
"-p", "--port", default=8000, type=int, help="Port to listen on"
)
parser.add_argument("-d", "--destination", type=str, help="Destination multiaddr")
args = parser.parse_args()
try:
trio.run(run, args.port, args.destination)
except KeyboardInterrupt:
pass
logging.info("Terminated by user")
except Exception as e:
logging.error(f"Fatal error: {e}")
raise
if __name__ == "__main__":

View File

@ -1,6 +1,5 @@
from collections.abc import (
Mapping,
Sequence,
)
from importlib.metadata import version as __version
from typing import (
@ -10,8 +9,6 @@ from typing import (
cast,
)
import multiaddr
from libp2p.abc import (
IHost,
IMuxedConn,
@ -157,7 +154,6 @@ def new_swarm(
sec_opt: Optional[TSecurityOptions] = None,
peerstore_opt: Optional[IPeerStore] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
listen_addrs: Optional[Sequence[multiaddr.Multiaddr]] = None,
) -> INetworkService:
"""
Create a swarm instance based on the parameters.
@ -167,7 +163,6 @@ def new_swarm(
:param sec_opt: optional choice of security upgrade
:param peerstore_opt: optional peerstore
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:return: return a default swarm instance
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
@ -180,16 +175,8 @@ def new_swarm(
id_opt = generate_peer_id_from(key_pair)
if listen_addrs is None:
# TODO: Parse `listen_addrs` to determine transport
transport = TCP()
else:
addr = listen_addrs[0]
if addr.__contains__("tcp"):
transport = TCP()
elif addr.__contains__("quic"):
raise ValueError("QUIC not yet supported")
else:
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
# Generate X25519 keypair for Noise
noise_key_pair = create_new_x25519_key_pair()
@ -242,7 +229,6 @@ def new_host(
peerstore_opt: Optional[IPeerStore] = None,
disc_opt: Optional[IPeerRouting] = None,
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
listen_addrs: Sequence[multiaddr.Multiaddr] = None,
) -> IHost:
"""
Create a new libp2p host based on the given parameters.
@ -253,7 +239,6 @@ def new_host(
:param peerstore_opt: optional peerstore
:param disc_opt: optional discovery
:param muxer_preference: optional explicit muxer preference
:param listen_addrs: optional list of multiaddrs to listen on
:return: return a host instance
"""
swarm = new_swarm(
@ -262,7 +247,6 @@ def new_host(
sec_opt=sec_opt,
peerstore_opt=peerstore_opt,
muxer_preference=muxer_preference,
listen_addrs=listen_addrs,
)
if disc_opt is not None:

View File

@ -8,14 +8,10 @@ from collections.abc import (
KeysView,
Sequence,
)
from types import (
TracebackType,
)
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Optional,
)
from multiaddr import (
@ -219,7 +215,7 @@ class IMuxedConn(ABC):
"""
class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
class IMuxedStream(ReadWriteCloser):
"""
Interface for a multiplexed stream.
@ -253,20 +249,6 @@ class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
otherwise False.
"""
@abstractmethod
async def __aenter__(self) -> "IMuxedStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
# -------------------------- net_stream interface.py --------------------------

View File

@ -195,29 +195,6 @@ class BasicHost(IHost):
net_stream.set_protocol(selected_protocol)
return net_stream
async def send_command(self, peer_id: ID, command: str) -> list[str]:
"""
Send a multistream-select command to the specified peer and return
the response.
:param peer_id: peer_id that host is connecting
:param command: supported multistream-select command (e.g., "ls)
:raise StreamFailure: If the stream cannot be opened or negotiation fails
:return: list of strings representing the response from peer.
"""
new_stream = await self._network.new_stream(peer_id)
try:
response = await self.multiselect_client.query_multistream_command(
MultiselectCommunicator(new_stream), command
)
except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
await new_stream.reset()
raise StreamFailure(f"failed to open a stream to peer {peer_id}") from error
return response
async def connect(self, peer_info: PeerInfo) -> None:
"""
Ensure there is a connection between this host and the peer

View File

@ -60,14 +60,8 @@ class Multiselect(IMultiselectMuxer):
raise MultiselectError() from error
if command == "ls":
supported_protocols = list(self.handlers.keys())
response = "\n".join(supported_protocols) + "\n"
try:
await communicator.write(response)
except MultiselectCommunicatorError as error:
raise MultiselectError() from error
# TODO: handle ls command
pass
else:
protocol = TProtocol(command)
if protocol in self.handlers:

View File

@ -70,36 +70,6 @@ class MultiselectClient(IMultiselectClient):
raise MultiselectClientError("protocols not supported")
async def query_multistream_command(
self, communicator: IMultiselectCommunicator, command: str
) -> list[str]:
"""
Send a multistream-select command over the given communicator and return
parsed response.
:param communicator: communicator to use to communicate with counterparty
:param command: supported multistream-select command(e.g., ls)
:raise MultiselectClientError: If the communicator fails to process data.
:return: list of strings representing the response from peer.
"""
await self.handshake(communicator)
if command == "ls":
try:
await communicator.write("ls")
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
else:
raise ValueError("Command not supported")
try:
response = await communicator.read()
response_list = response.strip().splitlines()
except MultiselectCommunicatorError as error:
raise MultiselectClientError() from error
return response_list
async def try_select(
self, communicator: IMultiselectCommunicator, protocol: TProtocol
) -> TProtocol:

View File

@ -1,6 +1,3 @@
from types import (
TracebackType,
)
from typing import (
TYPE_CHECKING,
Optional,
@ -260,16 +257,3 @@ class MplexStream(IMuxedStream):
def get_remote_address(self) -> Optional[tuple[str, int]]:
"""Delegate to the parent Mplex connection."""
return self.muxed_conn.get_remote_address()
async def __aenter__(self) -> "MplexStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()

View File

@ -9,9 +9,6 @@ from collections.abc import (
import inspect
import logging
import struct
from types import (
TracebackType,
)
from typing import (
Callable,
Optional,
@ -77,19 +74,6 @@ class YamuxStream(IMuxedStream):
self.recv_window = DEFAULT_WINDOW_SIZE
self.window_lock = trio.Lock()
async def __aenter__(self) -> "YamuxStream":
"""Enter the async context manager."""
return self
async def __aexit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
"""Exit the async context manager and close the stream."""
await self.close()
async def write(self, data: bytes) -> None:
if self.send_closed:
raise MuxedStreamError("Stream is closed for sending")

View File

@ -1 +0,0 @@
Allow passing `listen_addrs` to `new_swarm` to customize swarm listening behavior.

View File

@ -1,2 +0,0 @@
Feature: Support for sending `ls` command over `multistream-select` to list supported protocols from remote peer.
This allows inspecting which protocol handlers a peer supports at runtime.

View File

@ -1 +0,0 @@
implement AsyncContextManager for IMuxedStream to support async with

View File

@ -209,8 +209,8 @@ async def ping_demo(host_a, host_b):
async def pubsub_demo(host_a, host_b):
gossipsub_a = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1)
gossipsub_b = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1)
gossipsub_a = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 0.1, 1)
gossipsub_b = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 0.1, 1)
pubsub_a = Pubsub(host_a, gossipsub_a)
pubsub_b = Pubsub(host_b, gossipsub_b)
message_a_to_b = "Hello from A to B"

View File

@ -7,18 +7,12 @@ from trio.testing import (
wait_all_tasks_blocked,
)
from libp2p import (
new_swarm,
)
from libp2p.network.exceptions import (
SwarmException,
)
from libp2p.tools.utils import (
connect_swarm,
)
from libp2p.transport.tcp.tcp import (
TCP,
)
from tests.utils.factories import (
SwarmFactory,
)
@ -162,20 +156,3 @@ async def test_swarm_multiaddr(security_protocol):
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
await swarms[0].dial_peer(swarms[1].get_peer_id())
def test_new_swarm_defaults_to_tcp():
swarm = new_swarm()
assert isinstance(swarm.transport, TCP)
def test_new_swarm_tcp_multiaddr_supported():
addr = Multiaddr("/ip4/127.0.0.1/tcp/9999")
swarm = new_swarm(listen_addrs=[addr])
assert isinstance(swarm.transport, TCP)
def test_new_swarm_quic_multiaddr_raises():
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
with pytest.raises(ValueError, match="QUIC not yet supported"):
new_swarm(listen_addrs=[addr])

View File

@ -116,35 +116,3 @@ async def test_multiple_protocol_fails(security_protocol):
await perform_simple_test(
"", protocols_for_client, protocols_for_listener, security_protocol
)
@pytest.mark.trio
async def test_multistream_command(security_protocol):
supported_protocols = [PROTOCOL_ECHO, PROTOCOL_FOO, PROTOCOL_POTATO, PROTOCOL_ROCK]
async with HostFactory.create_batch_and_listen(
2, security_protocol=security_protocol
) as hosts:
listener, dialer = hosts[1], hosts[0]
for protocol in supported_protocols:
listener.set_stream_handler(
protocol, create_echo_stream_handler(ACK_PREFIX)
)
# Ensure dialer knows how to reach the listener
dialer.get_peerstore().add_addrs(listener.get_id(), listener.get_addrs(), 10)
# Dialer asks peer to list the supported protocols using `ls`
response = await dialer.send_command(listener.get_id(), "ls")
# We expect all supported protocols to show up
for protocol in supported_protocols:
assert protocol in response
assert "/does/not/exist" not in response
assert "/foo/bar/1.2.3" not in response
# Dialer asks for unspoorted command
with pytest.raises(ValueError, match="Command not supported"):
await dialer.send_command(listener.get_id(), "random")

View File

@ -1,127 +0,0 @@
import pytest
import trio
from libp2p.stream_muxer.exceptions import (
MuxedStreamClosed,
MuxedStreamError,
)
from libp2p.stream_muxer.mplex.datastructures import (
StreamID,
)
from libp2p.stream_muxer.mplex.mplex_stream import (
MplexStream,
)
from libp2p.stream_muxer.yamux.yamux import (
YamuxStream,
)
class DummySecuredConn:
async def write(self, data):
pass
class MockMuxedConn:
def __init__(self):
self.streams = {}
self.streams_lock = trio.Lock()
self.event_shutting_down = trio.Event()
self.event_closed = trio.Event()
self.event_started = trio.Event()
self.secured_conn = DummySecuredConn() # For YamuxStream
async def send_message(self, flag, data, stream_id):
pass
def get_remote_address(self):
return None
@pytest.mark.trio
async def test_mplex_stream_async_context_manager():
muxed_conn = MockMuxedConn()
stream_id = StreamID(1, True) # Use real StreamID
stream = MplexStream(
name="test_stream",
stream_id=stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=trio.open_memory_channel(8)[1],
)
async with stream as s:
assert s is stream
assert not stream.event_local_closed.is_set()
assert not stream.event_remote_closed.is_set()
assert not stream.event_reset.is_set()
assert stream.event_local_closed.is_set()
@pytest.mark.trio
async def test_yamux_stream_async_context_manager():
muxed_conn = MockMuxedConn()
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
async with stream as s:
assert s is stream
assert not stream.closed
assert not stream.send_closed
assert not stream.recv_closed
assert stream.send_closed
@pytest.mark.trio
async def test_mplex_stream_async_context_manager_with_error():
muxed_conn = MockMuxedConn()
stream_id = StreamID(1, True)
stream = MplexStream(
name="test_stream",
stream_id=stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=trio.open_memory_channel(8)[1],
)
with pytest.raises(ValueError):
async with stream as s:
assert s is stream
assert not stream.event_local_closed.is_set()
assert not stream.event_remote_closed.is_set()
assert not stream.event_reset.is_set()
raise ValueError("Test error")
assert stream.event_local_closed.is_set()
@pytest.mark.trio
async def test_yamux_stream_async_context_manager_with_error():
muxed_conn = MockMuxedConn()
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
with pytest.raises(ValueError):
async with stream as s:
assert s is stream
assert not stream.closed
assert not stream.send_closed
assert not stream.recv_closed
raise ValueError("Test error")
assert stream.send_closed
@pytest.mark.trio
async def test_mplex_stream_async_context_manager_write_after_close():
muxed_conn = MockMuxedConn()
stream_id = StreamID(1, True)
stream = MplexStream(
name="test_stream",
stream_id=stream_id,
muxed_conn=muxed_conn,
incoming_data_channel=trio.open_memory_channel(8)[1],
)
async with stream as s:
assert s is stream
with pytest.raises(MuxedStreamClosed):
await stream.write(b"test data")
@pytest.mark.trio
async def test_yamux_stream_async_context_manager_write_after_close():
muxed_conn = MockMuxedConn()
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
async with stream as s:
assert s is stream
with pytest.raises(MuxedStreamError):
await stream.write(b"test data")

View File

@ -32,25 +32,18 @@ class BaseInteractiveProcess(AbstractInterativeProcess):
async def wait_until_ready(self) -> None:
patterns_occurred = {pat: False for pat in self.patterns}
buffers = {pat: bytearray() for pat in self.patterns}
async def read_from_daemon_and_check() -> None:
async for data in self.proc.stdout:
# TODO: It takes O(n^2), which is quite bad.
# But it should succeed in a few seconds.
self.bytes_read.extend(data)
for pat, occurred in patterns_occurred.items():
if occurred:
continue
# Check if pattern is in new data or spans across chunks
buf = buffers[pat]
buf.extend(data)
if pat in buf:
if pat in self.bytes_read:
patterns_occurred[pat] = True
else:
keep = min(len(pat) - 1, len(buf))
buffers[pat] = buf[-keep:] if keep > 0 else bytearray()
if all(patterns_occurred.values()):
if all([value for value in patterns_occurred.values()]):
return
with trio.fail_after(TIMEOUT_DURATION):