mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Compare commits
2 Commits
interop/py
...
interop/ru
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b331d96a7 | |||
| 1d9849cb43 |
@ -1,9 +1,16 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import (
|
||||||
|
x25519,
|
||||||
|
)
|
||||||
import multiaddr
|
import multiaddr
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
from libp2p import (
|
from libp2p import (
|
||||||
|
generate_new_rsa_identity,
|
||||||
new_host,
|
new_host,
|
||||||
)
|
)
|
||||||
from libp2p.custom_types import (
|
from libp2p.custom_types import (
|
||||||
@ -12,109 +19,496 @@ from libp2p.custom_types import (
|
|||||||
from libp2p.network.stream.net_stream import (
|
from libp2p.network.stream.net_stream import (
|
||||||
INetStream,
|
INetStream,
|
||||||
)
|
)
|
||||||
from libp2p.peer.peerinfo import (
|
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||||
info_from_p2p_addr,
|
from libp2p.stream_muxer.yamux.yamux import (
|
||||||
|
Yamux,
|
||||||
|
YamuxStream,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
||||||
|
|
||||||
|
# Configure detailed logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler("ping_debug.log", mode="w", encoding="utf-8"),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Protocol constants - must match rust-libp2p exactly
|
||||||
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
||||||
PING_LENGTH = 32
|
PING_LENGTH = 32
|
||||||
RESP_TIMEOUT = 60
|
RESP_TIMEOUT = 30
|
||||||
|
MAX_FRAME_SIZE = 1024 * 1024 # 1MB max frame size
|
||||||
|
|
||||||
|
|
||||||
|
class InteropYamux(Yamux):
|
||||||
|
"""Enhanced Yamux with proper rust-libp2p interoperability"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
logging.info("InteropYamux.__init__ called")
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.frame_count = 0
|
||||||
|
self.debug_frames = True
|
||||||
|
|
||||||
|
async def _read_exact_bytes(self, n):
|
||||||
|
"""Read exactly n bytes from the connection with proper error handling"""
|
||||||
|
if n == 0:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
if n > MAX_FRAME_SIZE:
|
||||||
|
logging.error(f"Requested read size {n} exceeds maximum {MAX_FRAME_SIZE}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = b""
|
||||||
|
while len(data) < n:
|
||||||
|
try:
|
||||||
|
remaining = n - len(data)
|
||||||
|
chunk = await self.secured_conn.read(remaining)
|
||||||
|
except (trio.ClosedResourceError, trio.BrokenResourceError):
|
||||||
|
logging.debug(
|
||||||
|
f"Connection closed while reading {n}"
|
||||||
|
f"bytes (got {len(data)}) for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error reading {n} bytes: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not chunk:
|
||||||
|
logging.debug(
|
||||||
|
f"Connection closed while reading {n}"
|
||||||
|
f"bytes (got {len(data)}) for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
data += chunk
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def handle_incoming(self):
|
||||||
|
"""Enhanced incoming frame handler with better error recovery"""
|
||||||
|
logging.info(f"Starting Yamux for {self.peer_id}")
|
||||||
|
|
||||||
|
consecutive_errors = 0
|
||||||
|
max_consecutive_errors = 3
|
||||||
|
|
||||||
|
while not self.event_shutting_down.is_set():
|
||||||
|
try:
|
||||||
|
# Read frame header (12 bytes)
|
||||||
|
header_data = await self._read_exact_bytes(12)
|
||||||
|
if header_data is None:
|
||||||
|
logging.debug(
|
||||||
|
f"Connection closed or incomplete"
|
||||||
|
f"header for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Quick sanity check for protocol data leakage
|
||||||
|
if (
|
||||||
|
b"/ipfs" in header_data
|
||||||
|
or b"multi" in header_data
|
||||||
|
or b"noise" in header_data
|
||||||
|
):
|
||||||
|
logging.error(
|
||||||
|
f"Protocol data in header position: {header_data.hex()}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Unpack header: version, type, flags, stream_id, length
|
||||||
|
version, msg_type, flags, stream_id, length = struct.unpack(
|
||||||
|
">BBHII", header_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate header values strictly
|
||||||
|
if version != 0:
|
||||||
|
logging.error(f"Invalid yamux version {version}, expected 0")
|
||||||
|
break
|
||||||
|
|
||||||
|
if msg_type not in [0, 1, 2, 3]:
|
||||||
|
logging.error(f"Invalid message type {msg_type}, expected 0-3")
|
||||||
|
break
|
||||||
|
|
||||||
|
if length > MAX_FRAME_SIZE:
|
||||||
|
logging.error(f"Frame too large: {length} > {MAX_FRAME_SIZE}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Additional validation for ping frames
|
||||||
|
if msg_type == 2 and length != 4:
|
||||||
|
logging.error(
|
||||||
|
f"Invalid ping frame length: {length}, expected 4"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Log frame details
|
||||||
|
logging.debug(
|
||||||
|
f"Received header for peer {self.peer_id}"
|
||||||
|
f": type={msg_type}, flags={flags},"
|
||||||
|
f"stream_id={stream_id}, length={length}"
|
||||||
|
)
|
||||||
|
|
||||||
|
consecutive_errors = 0 # Reset error counter on successful parse
|
||||||
|
|
||||||
|
except struct.error as e:
|
||||||
|
consecutive_errors += 1
|
||||||
|
logging.error(
|
||||||
|
f"Header parse error #{consecutive_errors}"
|
||||||
|
f": {e}, data: {header_data.hex()}"
|
||||||
|
)
|
||||||
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
|
logging.error("Too many consecutive header parse errors")
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Read payload if present
|
||||||
|
payload = b""
|
||||||
|
if length > 0:
|
||||||
|
payload = await self._read_exact_bytes(length)
|
||||||
|
if payload is None:
|
||||||
|
logging.debug(
|
||||||
|
f"Failed to read payload of"
|
||||||
|
f"{length} bytes for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if len(payload) != length:
|
||||||
|
logging.error(
|
||||||
|
f"Payload length mismatch:"
|
||||||
|
f"got {len(payload)}, expected {length}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Process frame by type
|
||||||
|
if msg_type == 0: # Data frame
|
||||||
|
await self._handle_data_frame(stream_id, flags, payload)
|
||||||
|
|
||||||
|
elif msg_type == 1: # Window update
|
||||||
|
await self._handle_window_update(stream_id, payload)
|
||||||
|
|
||||||
|
elif msg_type == 2: # Ping frame
|
||||||
|
await self._handle_ping_frame(stream_id, flags, payload)
|
||||||
|
|
||||||
|
elif msg_type == 3: # GoAway frame
|
||||||
|
await self._handle_goaway_frame(payload)
|
||||||
|
break
|
||||||
|
|
||||||
|
except (trio.ClosedResourceError, trio.BrokenResourceError):
|
||||||
|
logging.debug(
|
||||||
|
f"Connection closed during frame processing for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
consecutive_errors += 1
|
||||||
|
logging.error(f"Frame processing error #{consecutive_errors}: {e}")
|
||||||
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
|
logging.error("Too many consecutive frame processing errors")
|
||||||
|
break
|
||||||
|
|
||||||
|
await self._cleanup_on_error()
|
||||||
|
|
||||||
|
async def _handle_data_frame(self, stream_id, flags, payload):
|
||||||
|
"""Handle data frames with proper stream lifecycle"""
|
||||||
|
if stream_id == 0:
|
||||||
|
logging.warning("Received data frame for stream 0 (control stream)")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle SYN flag - new stream creation
|
||||||
|
if flags & 0x1: # SYN flag
|
||||||
|
if stream_id in self.streams:
|
||||||
|
logging.warning(f"SYN received for existing stream {stream_id}")
|
||||||
|
else:
|
||||||
|
logging.debug(
|
||||||
|
f"Creating new stream {stream_id} for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
stream = YamuxStream(self, stream_id, is_outbound=False)
|
||||||
|
async with self.streams_lock:
|
||||||
|
self.streams[stream_id] = stream
|
||||||
|
|
||||||
|
# Send the new stream to the handler
|
||||||
|
await self.new_stream_send_channel.send(stream)
|
||||||
|
logging.debug(f"Sent stream {stream_id} to handler")
|
||||||
|
|
||||||
|
# Add data to stream buffer if stream exists
|
||||||
|
if stream_id in self.streams:
|
||||||
|
stream = self.streams[stream_id]
|
||||||
|
if payload:
|
||||||
|
# Add to stream's receive buffer
|
||||||
|
async with self.streams_lock:
|
||||||
|
if not hasattr(stream, "_receive_buffer"):
|
||||||
|
stream._receive_buffer = bytearray()
|
||||||
|
if not hasattr(stream, "_receive_event"):
|
||||||
|
stream._receive_event = trio.Event()
|
||||||
|
stream._receive_buffer.extend(payload)
|
||||||
|
stream._receive_event.set()
|
||||||
|
|
||||||
|
# Handle stream closure flags
|
||||||
|
if flags & 0x2: # FIN flag
|
||||||
|
stream.recv_closed = True
|
||||||
|
logging.debug(f"Stream {stream_id} received FIN")
|
||||||
|
if flags & 0x4: # RST flag
|
||||||
|
stream.reset_received = True
|
||||||
|
logging.debug(f"Stream {stream_id} received RST")
|
||||||
|
else:
|
||||||
|
if payload:
|
||||||
|
logging.warning(f"Received data for unknown stream {stream_id}")
|
||||||
|
|
||||||
|
async def _handle_window_update(self, stream_id, payload):
|
||||||
|
"""Handle window update frames"""
|
||||||
|
if len(payload) != 4:
|
||||||
|
logging.warning(f"Invalid window update payload length: {len(payload)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
delta = struct.unpack(">I", payload)[0]
|
||||||
|
logging.debug(f"Window update: stream={stream_id}, delta={delta}")
|
||||||
|
|
||||||
|
async with self.streams_lock:
|
||||||
|
if stream_id in self.streams:
|
||||||
|
if not hasattr(self.streams[stream_id], "send_window"):
|
||||||
|
self.streams[stream_id].send_window = 256 * 1024 # Default window
|
||||||
|
self.streams[stream_id].send_window += delta
|
||||||
|
|
||||||
|
async def _handle_ping_frame(self, stream_id, flags, payload):
|
||||||
|
"""Handle ping/pong frames with proper validation"""
|
||||||
|
if len(payload) != 4:
|
||||||
|
logging.warning(f"Invalid ping payload length: {len(payload)} (expected 4)")
|
||||||
|
return
|
||||||
|
|
||||||
|
ping_value = struct.unpack(">I", payload)[0]
|
||||||
|
|
||||||
|
if flags & 0x1: # SYN flag - ping request
|
||||||
|
logging.debug(
|
||||||
|
f"Received ping request with value {ping_value} for peer {self.peer_id}"
|
||||||
|
)
|
||||||
|
# Send pong response (ACK flag = 0x2)
|
||||||
|
try:
|
||||||
|
pong_header = struct.pack(
|
||||||
|
">BBHII", 0, 2, 0x2, 0, 4
|
||||||
|
) # Version=0, Type=2, Flags=ACK, StreamID=0, Length=4
|
||||||
|
pong_payload = struct.pack(">I", ping_value)
|
||||||
|
await self.secured_conn.write(pong_header + pong_payload)
|
||||||
|
logging.debug(f"Sent pong response with value {ping_value}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to send pong response: {e}")
|
||||||
|
else:
|
||||||
|
# Pong response
|
||||||
|
logging.debug(f"Received pong response with value {ping_value}")
|
||||||
|
|
||||||
|
async def _handle_goaway_frame(self, payload):
|
||||||
|
"""Handle GoAway frames"""
|
||||||
|
if len(payload) != 4:
|
||||||
|
logging.warning(f"Invalid GoAway payload length: {len(payload)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
code = struct.unpack(">I", payload)[0]
|
||||||
|
logging.info(f"Received GoAway frame with code {code}")
|
||||||
|
self.event_shutting_down.set()
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(stream: INetStream) -> None:
|
async def handle_ping(stream: INetStream) -> None:
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
payload = await stream.read(PING_LENGTH)
|
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
if payload is not None:
|
logging.info(f"Handling ping stream from {peer_id}")
|
||||||
print(f"received ping from {peer_id}")
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
with trio.fail_after(RESP_TIMEOUT):
|
||||||
|
# Read initial protocol negotiation
|
||||||
|
initial_data = await stream.read(1024)
|
||||||
|
logging.debug(
|
||||||
|
f"Received initial stream data from {peer_id}"
|
||||||
|
f": {initial_data.hex()} (length={len(initial_data)})"
|
||||||
|
)
|
||||||
|
if initial_data == b"/ipfs/ping/1.0.0\n":
|
||||||
|
logging.debug(
|
||||||
|
f"Confirmed /ipfs/ping/1.0.0 protocol negotiation from {peer_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.warning(f"Unexpected initial data: {initial_data!r}")
|
||||||
|
|
||||||
|
# Read ping payload
|
||||||
|
payload = await stream.read(PING_LENGTH)
|
||||||
|
if not payload:
|
||||||
|
logging.info(f"Stream closed by {peer_id}")
|
||||||
|
return
|
||||||
|
if len(payload) != PING_LENGTH:
|
||||||
|
logging.warning(
|
||||||
|
f"Unexpected payload length"
|
||||||
|
f" {len(payload)} from {peer_id}: {payload.hex()}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
logging.info(
|
||||||
|
f"Received ping from {peer_id}:"
|
||||||
|
f" {payload[:8].hex()}... (length={len(payload)})"
|
||||||
|
)
|
||||||
await stream.write(payload)
|
await stream.write(payload)
|
||||||
print(f"responded with pong to {peer_id}")
|
logging.info(f"Sent pong to {peer_id}: {payload[:8].hex()}...")
|
||||||
|
|
||||||
|
except trio.TooSlowError:
|
||||||
|
logging.warning(f"Ping timeout with {peer_id}")
|
||||||
|
except trio.BrokenResourceError:
|
||||||
|
logging.info(f"Connection broken with {peer_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error handling ping from {peer_id}: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await stream.close()
|
||||||
|
logging.debug(f"Closed ping stream with {peer_id}")
|
||||||
except Exception:
|
except Exception:
|
||||||
await stream.reset()
|
pass
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
async def send_ping(stream: INetStream) -> None:
|
async def send_ping(stream: INetStream) -> None:
|
||||||
|
peer_id = stream.muxed_conn.peer_id
|
||||||
try:
|
try:
|
||||||
payload = b"\x01" * PING_LENGTH
|
payload = os.urandom(PING_LENGTH)
|
||||||
print(f"sending ping to {stream.muxed_conn.peer_id}")
|
logging.info(f"Sending ping to {peer_id}: {payload[:8].hex()}...")
|
||||||
|
|
||||||
await stream.write(payload)
|
|
||||||
|
|
||||||
with trio.fail_after(RESP_TIMEOUT):
|
with trio.fail_after(RESP_TIMEOUT):
|
||||||
|
await stream.write(payload)
|
||||||
|
logging.debug(f"Ping sent to {peer_id}")
|
||||||
response = await stream.read(PING_LENGTH)
|
response = await stream.read(PING_LENGTH)
|
||||||
|
if not response:
|
||||||
|
logging.error(f"No pong response from {peer_id}")
|
||||||
|
return
|
||||||
|
if len(response) != PING_LENGTH:
|
||||||
|
logging.warning(
|
||||||
|
f"Pong length mismatch: got {len(response)}, expected {PING_LENGTH}"
|
||||||
|
)
|
||||||
if response == payload:
|
if response == payload:
|
||||||
print(f"received pong from {stream.muxed_conn.peer_id}")
|
logging.info(f"Ping successful! Pong matches from {peer_id}")
|
||||||
|
else:
|
||||||
|
logging.warning(f"Pong mismatch from {peer_id}")
|
||||||
|
except trio.TooSlowError:
|
||||||
|
logging.error(f"Ping timeout to {peer_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"error occurred : {e}")
|
logging.error(f"Error sending ping to {peer_id}: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await stream.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_noise_keypair():
|
||||||
|
try:
|
||||||
|
x25519_private_key = x25519.X25519PrivateKey.generate()
|
||||||
|
|
||||||
|
class NoisePrivateKey:
|
||||||
|
def __init__(self, key):
|
||||||
|
self._key = key
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
return self._key.private_bytes_raw()
|
||||||
|
|
||||||
|
def public_key(self):
|
||||||
|
return NoisePublicKey(self._key.public_key())
|
||||||
|
|
||||||
|
def get_public_key(self):
|
||||||
|
return NoisePublicKey(self._key.public_key())
|
||||||
|
|
||||||
|
class NoisePublicKey:
|
||||||
|
def __init__(self, key):
|
||||||
|
self._key = key
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
return self._key.public_bytes_raw()
|
||||||
|
|
||||||
|
return NoisePrivateKey(x25519_private_key)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to create Noise keypair: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def info_from_p2p_addr(addr):
|
||||||
|
"""Extract peer info from multiaddr - you'll need to implement this"""
|
||||||
|
# This is a placeholder - you need to implement the actual parsing
|
||||||
|
# based on your libp2p implementation
|
||||||
|
|
||||||
|
|
||||||
async def run(port: int, destination: str) -> None:
|
async def run(port: int, destination: str) -> None:
|
||||||
localhost_ip = "127.0.0.1"
|
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||||
host = new_host(listen_addrs=[listen_addr])
|
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
try:
|
||||||
|
key_pair = generate_new_rsa_identity()
|
||||||
|
logging.debug("Generated RSA keypair")
|
||||||
|
|
||||||
|
noise_privkey = create_noise_keypair()
|
||||||
|
logging.debug("Generated Noise keypair")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Key generation failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
noise_transport = NoiseTransport(key_pair, noise_privkey=noise_privkey)
|
||||||
|
logging.debug(f"Noise transport initialized: {noise_transport}")
|
||||||
|
sec_opt = {TProtocol("/noise"): noise_transport}
|
||||||
|
muxer_opt = {TProtocol(YAMUX_PROTOCOL_ID): InteropYamux}
|
||||||
|
|
||||||
|
logging.info(f"Using muxer: {muxer_opt}")
|
||||||
|
|
||||||
|
host = new_host(key_pair=key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
|
||||||
|
|
||||||
|
peer_id = host.get_id().pretty()
|
||||||
|
logging.info(f"Host peer ID: {peer_id}")
|
||||||
|
|
||||||
|
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery():
|
||||||
if not destination:
|
if not destination:
|
||||||
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||||
|
|
||||||
print(
|
logging.info(f"Server listening on {listen_addr}")
|
||||||
"Run this from the same folder in another console:\n\n"
|
logging.info(f"Full address: {listen_addr}/p2p/{peer_id}")
|
||||||
f"ping-demo -p {int(port) + 1} "
|
logging.info("Waiting for connections...")
|
||||||
f"-d /ip4/{localhost_ip}/tcp/{port}/p2p/{host.get_id().pretty()}\n"
|
|
||||||
)
|
|
||||||
print("Waiting for incoming connection...")
|
|
||||||
|
|
||||||
|
await trio.sleep_forever()
|
||||||
else:
|
else:
|
||||||
maddr = multiaddr.Multiaddr(destination)
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
info = info_from_p2p_addr(maddr)
|
info = info_from_p2p_addr(maddr)
|
||||||
await host.connect(info)
|
|
||||||
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
|
|
||||||
|
|
||||||
nursery.start_soon(send_ping, stream)
|
logging.info(f"Connecting to {info.peer_id}")
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
await trio.sleep_forever()
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
description = """
|
|
||||||
This program demonstrates a simple p2p ping application using libp2p.
|
|
||||||
To use it, first run 'python ping.py -p <PORT>', where <PORT> is the port number.
|
|
||||||
Then, run another instance with 'python ping.py -p <ANOTHER_PORT> -d <DESTINATION>',
|
|
||||||
where <DESTINATION> is the multiaddress of the previous listener host.
|
|
||||||
"""
|
|
||||||
|
|
||||||
example_maddr = (
|
|
||||||
"/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q"
|
|
||||||
)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description=description)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"-p", "--port", default=8000, type=int, help="source port number"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-d",
|
|
||||||
"--destination",
|
|
||||||
type=str,
|
|
||||||
help=f"destination multiaddr string, e.g. {example_maddr}",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not args.port:
|
|
||||||
raise RuntimeError("failed to determine local port")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trio.run(run, *(args.port, args.destination))
|
with trio.fail_after(30):
|
||||||
|
await host.connect(info)
|
||||||
|
logging.info(f"Connected to {info.peer_id}")
|
||||||
|
|
||||||
|
await trio.sleep(2.0)
|
||||||
|
|
||||||
|
logging.info(f"Opening ping stream to {info.peer_id}")
|
||||||
|
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
|
||||||
|
logging.info(f"Opened ping stream to {info.peer_id}")
|
||||||
|
|
||||||
|
await trio.sleep(0.5)
|
||||||
|
|
||||||
|
await send_ping(stream)
|
||||||
|
|
||||||
|
logging.info("Ping completed successfully")
|
||||||
|
|
||||||
|
logging.info("Keeping connection alive for 5 seconds...")
|
||||||
|
await trio.sleep(5.0)
|
||||||
|
except trio.TooSlowError:
|
||||||
|
logging.error(f"Connection timeout to {info.peer_id}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Connection failed to {info.peer_id}: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="libp2p ping with Rust interoperability"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--port", default=8000, type=int, help="Port to listen on"
|
||||||
|
)
|
||||||
|
parser.add_argument("-d", "--destination", type=str, help="Destination multiaddr")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
trio.run(run, args.port, args.destination)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
logging.info("Terminated by user")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Fatal error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1,19 +0,0 @@
|
|||||||
These commands are to be run in `./interop/exec`
|
|
||||||
|
|
||||||
## Redis
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run -p 6379:6379 -it redis:latest
|
|
||||||
```
|
|
||||||
|
|
||||||
## Listener
|
|
||||||
|
|
||||||
```bash
|
|
||||||
transport=tcp ip=0.0.0.0 is_dialer=false redis_addr=6379 test_timeout_seconds=180 security=insecure muxer=mplex python3 native_ping.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Dialer
|
|
||||||
|
|
||||||
```bash
|
|
||||||
transport=tcp ip=0.0.0.0 is_dialer=true port=8001 redis_addr=6379 port=8001 test_timeout_seconds=180 security=insecure muxer=mplex python3 native_ping.py
|
|
||||||
```
|
|
||||||
107
interop/arch.py
107
interop/arch.py
@ -1,107 +0,0 @@
|
|||||||
from dataclasses import (
|
|
||||||
dataclass,
|
|
||||||
)
|
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
import redis
|
|
||||||
import trio
|
|
||||||
|
|
||||||
from libp2p import (
|
|
||||||
new_host,
|
|
||||||
)
|
|
||||||
from libp2p.crypto.keys import (
|
|
||||||
KeyPair,
|
|
||||||
)
|
|
||||||
from libp2p.crypto.rsa import (
|
|
||||||
create_new_key_pair,
|
|
||||||
)
|
|
||||||
from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair
|
|
||||||
from libp2p.custom_types import (
|
|
||||||
TProtocol,
|
|
||||||
)
|
|
||||||
from libp2p.security.insecure.transport import (
|
|
||||||
PLAINTEXT_PROTOCOL_ID,
|
|
||||||
InsecureTransport,
|
|
||||||
)
|
|
||||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
|
||||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
|
||||||
import libp2p.security.secio.transport as secio
|
|
||||||
from libp2p.stream_muxer.mplex.mplex import (
|
|
||||||
MPLEX_PROTOCOL_ID,
|
|
||||||
Mplex,
|
|
||||||
)
|
|
||||||
from libp2p.stream_muxer.yamux.yamux import (
|
|
||||||
Yamux,
|
|
||||||
)
|
|
||||||
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
|
||||||
|
|
||||||
|
|
||||||
def generate_new_rsa_identity() -> KeyPair:
|
|
||||||
return create_new_key_pair()
|
|
||||||
|
|
||||||
|
|
||||||
async def build_host(transport: str, ip: str, port: str, sec_protocol: str, muxer: str):
|
|
||||||
match (sec_protocol, muxer):
|
|
||||||
case ("insecure", "mplex"):
|
|
||||||
key_pair = create_new_key_pair()
|
|
||||||
host = new_host(
|
|
||||||
key_pair,
|
|
||||||
{TProtocol(MPLEX_PROTOCOL_ID): Mplex},
|
|
||||||
{
|
|
||||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
|
|
||||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
muladdr = multiaddr.Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
|
||||||
return (host, muladdr)
|
|
||||||
case ("insecure", "yamux"):
|
|
||||||
key_pair = create_new_key_pair()
|
|
||||||
host = new_host(
|
|
||||||
key_pair,
|
|
||||||
{TProtocol(YAMUX_PROTOCOL_ID): Yamux},
|
|
||||||
{
|
|
||||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
|
|
||||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
muladdr = multiaddr.Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
|
||||||
return (host, muladdr)
|
|
||||||
case ("noise", "yamux"):
|
|
||||||
key_pair = create_new_key_pair()
|
|
||||||
noise_key_pair = create_new_x25519_key_pair()
|
|
||||||
|
|
||||||
host = new_host(
|
|
||||||
key_pair,
|
|
||||||
{TProtocol(YAMUX_PROTOCOL_ID): Yamux},
|
|
||||||
{
|
|
||||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
|
||||||
key_pair, noise_privkey=noise_key_pair.private_key
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
muladdr = multiaddr.Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
|
||||||
return (host, muladdr)
|
|
||||||
case _:
|
|
||||||
raise ValueError("Protocols not supported")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RedisClient:
|
|
||||||
client: redis.Redis
|
|
||||||
|
|
||||||
def brpop(self, key: str, timeout: float) -> list[str]:
|
|
||||||
result = self.client.brpop([key], timeout)
|
|
||||||
return [result[1]] if result else []
|
|
||||||
|
|
||||||
def rpush(self, key: str, value: str) -> None:
|
|
||||||
self.client.rpush(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
client = RedisClient(redis.Redis(host="localhost", port=6379, db=0))
|
|
||||||
client.rpush("test", "hello")
|
|
||||||
print(client.blpop("test", timeout=5))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
trio.run(main)
|
|
||||||
@ -1,57 +0,0 @@
|
|||||||
from dataclasses import (
|
|
||||||
dataclass,
|
|
||||||
)
|
|
||||||
import os
|
|
||||||
from typing import (
|
|
||||||
Optional,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def str_to_bool(val: str) -> bool:
|
|
||||||
return val.lower() in ("true", "1")
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigError(Exception):
|
|
||||||
"""Raised when the required environment variables are missing or invalid"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Config:
|
|
||||||
transport: str
|
|
||||||
sec_protocol: Optional[str]
|
|
||||||
muxer: Optional[str]
|
|
||||||
ip: str
|
|
||||||
is_dialer: bool
|
|
||||||
test_timeout: int
|
|
||||||
redis_addr: str
|
|
||||||
port: str
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_env(cls) -> "Config":
|
|
||||||
try:
|
|
||||||
transport = os.environ["transport"]
|
|
||||||
ip = os.environ["ip"]
|
|
||||||
except KeyError as e:
|
|
||||||
raise ConfigError(f"{e.args[0]} env variable not set") from None
|
|
||||||
|
|
||||||
try:
|
|
||||||
is_dialer = str_to_bool(os.environ.get("is_dialer", "true"))
|
|
||||||
test_timeout = int(os.environ.get("test_timeout", "180"))
|
|
||||||
except ValueError as e:
|
|
||||||
raise ConfigError(f"Invalid value in env: {e}") from None
|
|
||||||
|
|
||||||
redis_addr = os.environ.get("redis_addr", 6379)
|
|
||||||
sec_protocol = os.environ.get("security")
|
|
||||||
muxer = os.environ.get("muxer")
|
|
||||||
port = os.environ.get("port", "8000")
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
transport=transport,
|
|
||||||
sec_protocol=sec_protocol,
|
|
||||||
muxer=muxer,
|
|
||||||
ip=ip,
|
|
||||||
is_dialer=is_dialer,
|
|
||||||
test_timeout=test_timeout,
|
|
||||||
redis_addr=redis_addr,
|
|
||||||
port=port,
|
|
||||||
)
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
import trio
|
|
||||||
|
|
||||||
from interop.exec.config.mod import (
|
|
||||||
Config,
|
|
||||||
ConfigError,
|
|
||||||
)
|
|
||||||
from interop.lib import (
|
|
||||||
run_test,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
|
||||||
try:
|
|
||||||
config = Config.from_env()
|
|
||||||
except ConfigError as e:
|
|
||||||
print(f"Config error: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Uncomment and implement when ready
|
|
||||||
_ = await run_test(
|
|
||||||
config.transport,
|
|
||||||
config.ip,
|
|
||||||
config.port,
|
|
||||||
config.is_dialer,
|
|
||||||
config.test_timeout,
|
|
||||||
config.redis_addr,
|
|
||||||
config.sec_protocol,
|
|
||||||
config.muxer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
trio.run(main)
|
|
||||||
120
interop/lib.py
120
interop/lib.py
@ -1,120 +0,0 @@
|
|||||||
from dataclasses import (
|
|
||||||
dataclass,
|
|
||||||
)
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
|
|
||||||
from loguru import (
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
import multiaddr
|
|
||||||
import redis
|
|
||||||
import trio
|
|
||||||
|
|
||||||
from interop.arch import (
|
|
||||||
RedisClient,
|
|
||||||
build_host,
|
|
||||||
)
|
|
||||||
from libp2p.custom_types import (
|
|
||||||
TProtocol,
|
|
||||||
)
|
|
||||||
from libp2p.network.stream.net_stream import (
|
|
||||||
INetStream,
|
|
||||||
)
|
|
||||||
from libp2p.peer.peerinfo import (
|
|
||||||
info_from_p2p_addr,
|
|
||||||
)
|
|
||||||
|
|
||||||
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
|
||||||
PING_LENGTH = 32
|
|
||||||
RESP_TIMEOUT = 60
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(stream: INetStream) -> None:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
payload = await stream.read(PING_LENGTH)
|
|
||||||
peer_id = stream.muxed_conn.peer_id
|
|
||||||
if payload is not None:
|
|
||||||
print(f"received ping from {peer_id}")
|
|
||||||
|
|
||||||
await stream.write(payload)
|
|
||||||
print(f"responded with pong to {peer_id}")
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
await stream.reset()
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
async def send_ping(stream: INetStream) -> None:
|
|
||||||
try:
|
|
||||||
payload = b"\x01" * PING_LENGTH
|
|
||||||
print(f"sending ping to {stream.muxed_conn.peer_id}")
|
|
||||||
|
|
||||||
await stream.write(payload)
|
|
||||||
|
|
||||||
with trio.fail_after(RESP_TIMEOUT):
|
|
||||||
response = await stream.read(PING_LENGTH)
|
|
||||||
|
|
||||||
if response == payload:
|
|
||||||
print(f"received pong from {stream.muxed_conn.peer_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"error occurred: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def run_test(
|
|
||||||
transport, ip, port, is_dialer, test_timeout, redis_addr, sec_protocol, muxer
|
|
||||||
):
|
|
||||||
logger.info("Starting run_test")
|
|
||||||
|
|
||||||
redis_client = RedisClient(
|
|
||||||
redis.Redis(host="localhost", port=int(redis_addr), db=0)
|
|
||||||
)
|
|
||||||
(host, listen_addr) = await build_host(transport, ip, port, sec_protocol, muxer)
|
|
||||||
logger.info(f"Running ping test local_peer={host.get_id()}")
|
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
|
||||||
if not is_dialer:
|
|
||||||
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
|
||||||
ma = f"{listen_addr}/p2p/{host.get_id().pretty()}"
|
|
||||||
redis_client.rpush("listenerAddr", ma)
|
|
||||||
|
|
||||||
logger.info(f"Test instance, listening: {ma}")
|
|
||||||
else:
|
|
||||||
redis_addr = redis_client.brpop("listenerAddr", timeout=5)
|
|
||||||
destination = redis_addr[0].decode()
|
|
||||||
maddr = multiaddr.Multiaddr(destination)
|
|
||||||
info = info_from_p2p_addr(maddr)
|
|
||||||
|
|
||||||
handshake_start = time.perf_counter()
|
|
||||||
|
|
||||||
logger.info("GETTING READY FOR CONNECTION")
|
|
||||||
await host.connect(info)
|
|
||||||
logger.info("HOST CONNECTED")
|
|
||||||
|
|
||||||
# TILL HERE EVERYTHING IS FINE
|
|
||||||
|
|
||||||
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
|
|
||||||
logger.info("CREATED NEW STREAM")
|
|
||||||
|
|
||||||
# DOES NOT MORE FORWARD FROM THIS
|
|
||||||
logger.info("Remote conection established")
|
|
||||||
|
|
||||||
nursery.start_soon(send_ping, stream)
|
|
||||||
|
|
||||||
handshake_plus_ping = (time.perf_counter() - handshake_start) * 1000.0
|
|
||||||
|
|
||||||
logger.info(f"handshake time: {handshake_plus_ping:.2f}ms")
|
|
||||||
return
|
|
||||||
|
|
||||||
await trio.sleep_forever()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Report:
|
|
||||||
handshake_plus_one_rtt_millis: float
|
|
||||||
ping_rtt_millis: float
|
|
||||||
|
|
||||||
def gen_report(self):
|
|
||||||
return json.dumps(self.__dict__)
|
|
||||||
@ -1,6 +1,5 @@
|
|||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
Mapping,
|
Mapping,
|
||||||
Sequence,
|
|
||||||
)
|
)
|
||||||
from importlib.metadata import version as __version
|
from importlib.metadata import version as __version
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -10,8 +9,6 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
import multiaddr
|
|
||||||
|
|
||||||
from libp2p.abc import (
|
from libp2p.abc import (
|
||||||
IHost,
|
IHost,
|
||||||
IMuxedConn,
|
IMuxedConn,
|
||||||
@ -157,7 +154,6 @@ def new_swarm(
|
|||||||
sec_opt: Optional[TSecurityOptions] = None,
|
sec_opt: Optional[TSecurityOptions] = None,
|
||||||
peerstore_opt: Optional[IPeerStore] = None,
|
peerstore_opt: Optional[IPeerStore] = None,
|
||||||
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
|
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
|
||||||
listen_addrs: Optional[Sequence[multiaddr.Multiaddr]] = None,
|
|
||||||
) -> INetworkService:
|
) -> INetworkService:
|
||||||
"""
|
"""
|
||||||
Create a swarm instance based on the parameters.
|
Create a swarm instance based on the parameters.
|
||||||
@ -167,7 +163,6 @@ def new_swarm(
|
|||||||
:param sec_opt: optional choice of security upgrade
|
:param sec_opt: optional choice of security upgrade
|
||||||
:param peerstore_opt: optional peerstore
|
:param peerstore_opt: optional peerstore
|
||||||
:param muxer_preference: optional explicit muxer preference
|
:param muxer_preference: optional explicit muxer preference
|
||||||
:param listen_addrs: optional list of multiaddrs to listen on
|
|
||||||
:return: return a default swarm instance
|
:return: return a default swarm instance
|
||||||
|
|
||||||
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
|
Note: Yamux (/yamux/1.0.0) is the preferred stream multiplexer
|
||||||
@ -180,16 +175,8 @@ def new_swarm(
|
|||||||
|
|
||||||
id_opt = generate_peer_id_from(key_pair)
|
id_opt = generate_peer_id_from(key_pair)
|
||||||
|
|
||||||
if listen_addrs is None:
|
# TODO: Parse `listen_addrs` to determine transport
|
||||||
transport = TCP()
|
transport = TCP()
|
||||||
else:
|
|
||||||
addr = listen_addrs[0]
|
|
||||||
if addr.__contains__("tcp"):
|
|
||||||
transport = TCP()
|
|
||||||
elif addr.__contains__("quic"):
|
|
||||||
raise ValueError("QUIC not yet supported")
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown transport in listen_addrs: {listen_addrs}")
|
|
||||||
|
|
||||||
# Generate X25519 keypair for Noise
|
# Generate X25519 keypair for Noise
|
||||||
noise_key_pair = create_new_x25519_key_pair()
|
noise_key_pair = create_new_x25519_key_pair()
|
||||||
@ -242,7 +229,6 @@ def new_host(
|
|||||||
peerstore_opt: Optional[IPeerStore] = None,
|
peerstore_opt: Optional[IPeerStore] = None,
|
||||||
disc_opt: Optional[IPeerRouting] = None,
|
disc_opt: Optional[IPeerRouting] = None,
|
||||||
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
|
muxer_preference: Optional[Literal["YAMUX", "MPLEX"]] = None,
|
||||||
listen_addrs: Sequence[multiaddr.Multiaddr] = None,
|
|
||||||
) -> IHost:
|
) -> IHost:
|
||||||
"""
|
"""
|
||||||
Create a new libp2p host based on the given parameters.
|
Create a new libp2p host based on the given parameters.
|
||||||
@ -253,7 +239,6 @@ def new_host(
|
|||||||
:param peerstore_opt: optional peerstore
|
:param peerstore_opt: optional peerstore
|
||||||
:param disc_opt: optional discovery
|
:param disc_opt: optional discovery
|
||||||
:param muxer_preference: optional explicit muxer preference
|
:param muxer_preference: optional explicit muxer preference
|
||||||
:param listen_addrs: optional list of multiaddrs to listen on
|
|
||||||
:return: return a host instance
|
:return: return a host instance
|
||||||
"""
|
"""
|
||||||
swarm = new_swarm(
|
swarm = new_swarm(
|
||||||
@ -262,7 +247,6 @@ def new_host(
|
|||||||
sec_opt=sec_opt,
|
sec_opt=sec_opt,
|
||||||
peerstore_opt=peerstore_opt,
|
peerstore_opt=peerstore_opt,
|
||||||
muxer_preference=muxer_preference,
|
muxer_preference=muxer_preference,
|
||||||
listen_addrs=listen_addrs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if disc_opt is not None:
|
if disc_opt is not None:
|
||||||
|
|||||||
@ -8,14 +8,10 @@ from collections.abc import (
|
|||||||
KeysView,
|
KeysView,
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
from types import (
|
|
||||||
TracebackType,
|
|
||||||
)
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncContextManager,
|
AsyncContextManager,
|
||||||
Optional,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from multiaddr import (
|
from multiaddr import (
|
||||||
@ -219,7 +215,7 @@ class IMuxedConn(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
|
class IMuxedStream(ReadWriteCloser):
|
||||||
"""
|
"""
|
||||||
Interface for a multiplexed stream.
|
Interface for a multiplexed stream.
|
||||||
|
|
||||||
@ -253,20 +249,6 @@ class IMuxedStream(ReadWriteCloser, AsyncContextManager["IMuxedStream"]):
|
|||||||
otherwise False.
|
otherwise False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def __aenter__(self) -> "IMuxedStream":
|
|
||||||
"""Enter the async context manager."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[type[BaseException]],
|
|
||||||
exc_val: Optional[BaseException],
|
|
||||||
exc_tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
"""Exit the async context manager and close the stream."""
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------- net_stream interface.py --------------------------
|
# -------------------------- net_stream interface.py --------------------------
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from collections.abc import (
|
|||||||
from contextlib import (
|
from contextlib import (
|
||||||
asynccontextmanager,
|
asynccontextmanager,
|
||||||
)
|
)
|
||||||
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Optional,
|
Optional,
|
||||||
@ -67,10 +68,7 @@ if TYPE_CHECKING:
|
|||||||
# telling it to listen on the given listen addresses.
|
# telling it to listen on the given listen addresses.
|
||||||
|
|
||||||
|
|
||||||
# logger = logging.getLogger("libp2p.network.basic_host")
|
logger = logging.getLogger("libp2p.network.basic_host")
|
||||||
from loguru import (
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicHost(IHost):
|
class BasicHost(IHost):
|
||||||
@ -183,15 +181,12 @@ class BasicHost(IHost):
|
|||||||
:return: stream: new stream created
|
:return: stream: new stream created
|
||||||
"""
|
"""
|
||||||
net_stream = await self._network.new_stream(peer_id)
|
net_stream = await self._network.new_stream(peer_id)
|
||||||
logger.info("INETSTREAM CHECKING IN")
|
|
||||||
logger.info(protocol_ids)
|
|
||||||
# Perform protocol muxing to determine protocol to use
|
# Perform protocol muxing to determine protocol to use
|
||||||
try:
|
try:
|
||||||
logger.debug("PROTOCOLS TRYING TO GET SENT")
|
|
||||||
selected_protocol = await self.multiselect_client.select_one_of(
|
selected_protocol = await self.multiselect_client.select_one_of(
|
||||||
list(protocol_ids), MultiselectCommunicator(net_stream)
|
list(protocol_ids), MultiselectCommunicator(net_stream)
|
||||||
)
|
)
|
||||||
logger.info("PROTOCOLS GOT SENT")
|
|
||||||
except MultiselectClientError as error:
|
except MultiselectClientError as error:
|
||||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||||
await net_stream.reset()
|
await net_stream.reset()
|
||||||
@ -200,29 +195,6 @@ class BasicHost(IHost):
|
|||||||
net_stream.set_protocol(selected_protocol)
|
net_stream.set_protocol(selected_protocol)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
|
||||||
async def send_command(self, peer_id: ID, command: str) -> list[str]:
|
|
||||||
"""
|
|
||||||
Send a multistream-select command to the specified peer and return
|
|
||||||
the response.
|
|
||||||
|
|
||||||
:param peer_id: peer_id that host is connecting
|
|
||||||
:param command: supported multistream-select command (e.g., "ls)
|
|
||||||
:raise StreamFailure: If the stream cannot be opened or negotiation fails
|
|
||||||
:return: list of strings representing the response from peer.
|
|
||||||
"""
|
|
||||||
new_stream = await self._network.new_stream(peer_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.multiselect_client.query_multistream_command(
|
|
||||||
MultiselectCommunicator(new_stream), command
|
|
||||||
)
|
|
||||||
except MultiselectClientError as error:
|
|
||||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
|
||||||
await new_stream.reset()
|
|
||||||
raise StreamFailure(f"failed to open a stream to peer {peer_id}") from error
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def connect(self, peer_info: PeerInfo) -> None:
|
async def connect(self, peer_info: PeerInfo) -> None:
|
||||||
"""
|
"""
|
||||||
Ensure there is a connection between this host and the peer
|
Ensure there is a connection between this host and the peer
|
||||||
|
|||||||
@ -1,11 +1,8 @@
|
|||||||
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Optional,
|
Optional,
|
||||||
)
|
)
|
||||||
|
|
||||||
# logger = logging.getLogger("libp2p.network.swarm")
|
|
||||||
from loguru import (
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
from multiaddr import (
|
from multiaddr import (
|
||||||
Multiaddr,
|
Multiaddr,
|
||||||
)
|
)
|
||||||
@ -58,6 +55,8 @@ from .exceptions import (
|
|||||||
SwarmException,
|
SwarmException,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger("libp2p.network.swarm")
|
||||||
|
|
||||||
|
|
||||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||||
async def stream_handler(stream: INetStream) -> None:
|
async def stream_handler(stream: INetStream) -> None:
|
||||||
@ -131,7 +130,6 @@ class Swarm(Service, INetworkService):
|
|||||||
:return: muxed connection
|
:return: muxed connection
|
||||||
"""
|
"""
|
||||||
if peer_id in self.connections:
|
if peer_id in self.connections:
|
||||||
logger.info("WE ARE RETURNING, PEER ALREADAY EXISTS")
|
|
||||||
# If muxed connection already exists for peer_id,
|
# If muxed connection already exists for peer_id,
|
||||||
# set muxed connection equal to existing muxed connection
|
# set muxed connection equal to existing muxed connection
|
||||||
return self.connections[peer_id]
|
return self.connections[peer_id]
|
||||||
@ -152,7 +150,6 @@ class Swarm(Service, INetworkService):
|
|||||||
# Try all known addresses
|
# Try all known addresses
|
||||||
for multiaddr in addrs:
|
for multiaddr in addrs:
|
||||||
try:
|
try:
|
||||||
logger.info("HANDSHAKE GOING TO HAPPEN")
|
|
||||||
return await self.dial_addr(multiaddr, peer_id)
|
return await self.dial_addr(multiaddr, peer_id)
|
||||||
except SwarmException as e:
|
except SwarmException as e:
|
||||||
exceptions.append(e)
|
exceptions.append(e)
|
||||||
@ -227,11 +224,8 @@ class Swarm(Service, INetworkService):
|
|||||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||||
|
|
||||||
swarm_conn = await self.dial_peer(peer_id)
|
swarm_conn = await self.dial_peer(peer_id)
|
||||||
logger.info("INETCONN CREATED")
|
|
||||||
|
|
||||||
net_stream = await swarm_conn.new_stream()
|
net_stream = await swarm_conn.new_stream()
|
||||||
logger.info("INETSTREAM CREATED")
|
|
||||||
|
|
||||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||||
return net_stream
|
return net_stream
|
||||||
|
|
||||||
|
|||||||
@ -60,14 +60,8 @@ class Multiselect(IMultiselectMuxer):
|
|||||||
raise MultiselectError() from error
|
raise MultiselectError() from error
|
||||||
|
|
||||||
if command == "ls":
|
if command == "ls":
|
||||||
supported_protocols = list(self.handlers.keys())
|
# TODO: handle ls command
|
||||||
response = "\n".join(supported_protocols) + "\n"
|
pass
|
||||||
|
|
||||||
try:
|
|
||||||
await communicator.write(response)
|
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectError() from error
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
protocol = TProtocol(command)
|
protocol = TProtocol(command)
|
||||||
if protocol in self.handlers:
|
if protocol in self.handlers:
|
||||||
|
|||||||
@ -2,10 +2,6 @@ from collections.abc import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
|
|
||||||
from loguru import (
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
|
|
||||||
from libp2p.abc import (
|
from libp2p.abc import (
|
||||||
IMultiselectClient,
|
IMultiselectClient,
|
||||||
IMultiselectCommunicator,
|
IMultiselectCommunicator,
|
||||||
@ -40,15 +36,11 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
try:
|
try:
|
||||||
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
logger.error("WROTE FAIL")
|
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
logger.info(f"WROTE SUC, {MULTISELECT_PROTOCOL_ID}")
|
|
||||||
try:
|
try:
|
||||||
handshake_contents = await communicator.read()
|
handshake_contents = await communicator.read()
|
||||||
logger.info(f"READ SUC, {handshake_contents}")
|
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
logger.error(f"READ FAIL, {error}")
|
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
if not is_valid_handshake(handshake_contents):
|
if not is_valid_handshake(handshake_contents):
|
||||||
@ -67,12 +59,9 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
:return: selected protocol
|
:return: selected protocol
|
||||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||||
"""
|
"""
|
||||||
logger.info("TRYING TO GET THE HANDSHAKE HAPPENED")
|
|
||||||
await self.handshake(communicator)
|
await self.handshake(communicator)
|
||||||
logger.info("HANDSHAKE HAPPENED")
|
|
||||||
|
|
||||||
for protocol in protocols:
|
for protocol in protocols:
|
||||||
logger.info(protocol)
|
|
||||||
try:
|
try:
|
||||||
selected_protocol = await self.try_select(communicator, protocol)
|
selected_protocol = await self.try_select(communicator, protocol)
|
||||||
return selected_protocol
|
return selected_protocol
|
||||||
@ -81,36 +70,6 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
|
|
||||||
raise MultiselectClientError("protocols not supported")
|
raise MultiselectClientError("protocols not supported")
|
||||||
|
|
||||||
async def query_multistream_command(
|
|
||||||
self, communicator: IMultiselectCommunicator, command: str
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Send a multistream-select command over the given communicator and return
|
|
||||||
parsed response.
|
|
||||||
|
|
||||||
:param communicator: communicator to use to communicate with counterparty
|
|
||||||
:param command: supported multistream-select command(e.g., ls)
|
|
||||||
:raise MultiselectClientError: If the communicator fails to process data.
|
|
||||||
:return: list of strings representing the response from peer.
|
|
||||||
"""
|
|
||||||
await self.handshake(communicator)
|
|
||||||
|
|
||||||
if command == "ls":
|
|
||||||
try:
|
|
||||||
await communicator.write("ls")
|
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectClientError() from error
|
|
||||||
else:
|
|
||||||
raise ValueError("Command not supported")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await communicator.read()
|
|
||||||
response_list = response.strip().splitlines()
|
|
||||||
except MultiselectCommunicatorError as error:
|
|
||||||
raise MultiselectClientError() from error
|
|
||||||
|
|
||||||
return response_list
|
|
||||||
|
|
||||||
async def try_select(
|
async def try_select(
|
||||||
self, communicator: IMultiselectCommunicator, protocol: TProtocol
|
self, communicator: IMultiselectCommunicator, protocol: TProtocol
|
||||||
) -> TProtocol:
|
) -> TProtocol:
|
||||||
@ -124,17 +83,11 @@ class MultiselectClient(IMultiselectClient):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await communicator.write(protocol)
|
await communicator.write(protocol)
|
||||||
from loguru import (
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(protocol)
|
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await communicator.read()
|
response = await communicator.read()
|
||||||
logger.info("Response: ", response)
|
|
||||||
except MultiselectCommunicatorError as error:
|
except MultiselectCommunicatorError as error:
|
||||||
raise MultiselectClientError() from error
|
raise MultiselectClientError() from error
|
||||||
|
|
||||||
|
|||||||
@ -122,9 +122,6 @@ class Pubsub(Service, IPubsub):
|
|||||||
strict_signing: bool
|
strict_signing: bool
|
||||||
sign_key: PrivateKey
|
sign_key: PrivateKey
|
||||||
|
|
||||||
# Set of blacklisted peer IDs
|
|
||||||
blacklisted_peers: set[ID]
|
|
||||||
|
|
||||||
event_handle_peer_queue_started: trio.Event
|
event_handle_peer_queue_started: trio.Event
|
||||||
event_handle_dead_peer_queue_started: trio.Event
|
event_handle_dead_peer_queue_started: trio.Event
|
||||||
|
|
||||||
@ -204,9 +201,6 @@ class Pubsub(Service, IPubsub):
|
|||||||
|
|
||||||
self.counter = int(time.time())
|
self.counter = int(time.time())
|
||||||
|
|
||||||
# Set of blacklisted peer IDs
|
|
||||||
self.blacklisted_peers = set()
|
|
||||||
|
|
||||||
self.event_handle_peer_queue_started = trio.Event()
|
self.event_handle_peer_queue_started = trio.Event()
|
||||||
self.event_handle_dead_peer_queue_started = trio.Event()
|
self.event_handle_dead_peer_queue_started = trio.Event()
|
||||||
|
|
||||||
@ -326,82 +320,6 @@ class Pubsub(Service, IPubsub):
|
|||||||
if topic in self.topic_validators
|
if topic in self.topic_validators
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_to_blacklist(self, peer_id: ID) -> None:
|
|
||||||
"""
|
|
||||||
Add a peer to the blacklist.
|
|
||||||
When a peer is blacklisted:
|
|
||||||
- Any existing connection to that peer is immediately closed and removed
|
|
||||||
- The peer is removed from all topic subscription mappings
|
|
||||||
- Future connection attempts from this peer will be rejected
|
|
||||||
- Messages forwarded by or originating from this peer will be dropped
|
|
||||||
- The peer will not be able to participate in pubsub communication
|
|
||||||
|
|
||||||
:param peer_id: the peer ID to blacklist
|
|
||||||
"""
|
|
||||||
self.blacklisted_peers.add(peer_id)
|
|
||||||
logger.debug("Added peer %s to blacklist", peer_id)
|
|
||||||
self.manager.run_task(self._teardown_if_connected, peer_id)
|
|
||||||
|
|
||||||
async def _teardown_if_connected(self, peer_id: ID) -> None:
|
|
||||||
"""Close their stream and remove them if connected"""
|
|
||||||
stream = self.peers.get(peer_id)
|
|
||||||
if stream is not None:
|
|
||||||
try:
|
|
||||||
await stream.reset()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
del self.peers[peer_id]
|
|
||||||
# Also remove from any subscription maps:
|
|
||||||
for _topic, peerset in self.peer_topics.items():
|
|
||||||
if peer_id in peerset:
|
|
||||||
peerset.discard(peer_id)
|
|
||||||
|
|
||||||
def remove_from_blacklist(self, peer_id: ID) -> None:
|
|
||||||
"""
|
|
||||||
Remove a peer from the blacklist.
|
|
||||||
Once removed from the blacklist:
|
|
||||||
- The peer can establish new connections to this node
|
|
||||||
- Messages from this peer will be processed normally
|
|
||||||
- The peer can participate in topic subscriptions and message forwarding
|
|
||||||
|
|
||||||
:param peer_id: the peer ID to remove from blacklist
|
|
||||||
"""
|
|
||||||
self.blacklisted_peers.discard(peer_id)
|
|
||||||
logger.debug("Removed peer %s from blacklist", peer_id)
|
|
||||||
|
|
||||||
def is_peer_blacklisted(self, peer_id: ID) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a peer is blacklisted.
|
|
||||||
|
|
||||||
:param peer_id: the peer ID to check
|
|
||||||
:return: True if peer is blacklisted, False otherwise
|
|
||||||
"""
|
|
||||||
return peer_id in self.blacklisted_peers
|
|
||||||
|
|
||||||
def clear_blacklist(self) -> None:
|
|
||||||
"""
|
|
||||||
Clear all peers from the blacklist.
|
|
||||||
This removes all blacklist restrictions, allowing previously blacklisted
|
|
||||||
peers to:
|
|
||||||
- Establish new connections
|
|
||||||
- Send and forward messages
|
|
||||||
- Participate in topic subscriptions
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.blacklisted_peers.clear()
|
|
||||||
logger.debug("Cleared all peers from blacklist")
|
|
||||||
|
|
||||||
def get_blacklisted_peers(self) -> set[ID]:
|
|
||||||
"""
|
|
||||||
Get a copy of the current blacklisted peers.
|
|
||||||
Returns a snapshot of all currently blacklisted peer IDs. These peers
|
|
||||||
are completely isolated from pubsub communication - their connections
|
|
||||||
are rejected and their messages are dropped.
|
|
||||||
|
|
||||||
:return: a set containing all blacklisted peer IDs
|
|
||||||
"""
|
|
||||||
return self.blacklisted_peers.copy()
|
|
||||||
|
|
||||||
async def stream_handler(self, stream: INetStream) -> None:
|
async def stream_handler(self, stream: INetStream) -> None:
|
||||||
"""
|
"""
|
||||||
Stream handler for pubsub. Gets invoked whenever a new stream is
|
Stream handler for pubsub. Gets invoked whenever a new stream is
|
||||||
@ -428,10 +346,6 @@ class Pubsub(Service, IPubsub):
|
|||||||
await self.event_handle_dead_peer_queue_started.wait()
|
await self.event_handle_dead_peer_queue_started.wait()
|
||||||
|
|
||||||
async def _handle_new_peer(self, peer_id: ID) -> None:
|
async def _handle_new_peer(self, peer_id: ID) -> None:
|
||||||
if self.is_peer_blacklisted(peer_id):
|
|
||||||
logger.debug("Rejecting blacklisted peer %s", peer_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
||||||
except SwarmException as error:
|
except SwarmException as error:
|
||||||
@ -445,6 +359,7 @@ class Pubsub(Service, IPubsub):
|
|||||||
except StreamClosed:
|
except StreamClosed:
|
||||||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||||
return
|
return
|
||||||
|
# TODO: Check if the peer in black list.
|
||||||
try:
|
try:
|
||||||
self.router.add_peer(peer_id, stream.get_protocol())
|
self.router.add_peer(peer_id, stream.get_protocol())
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
@ -694,20 +609,9 @@ class Pubsub(Service, IPubsub):
|
|||||||
"""
|
"""
|
||||||
logger.debug("attempting to publish message %s", msg)
|
logger.debug("attempting to publish message %s", msg)
|
||||||
|
|
||||||
# Check if the message forwarder (source) is in the blacklist. If yes, reject.
|
# TODO: Check if the `source` is in the blacklist. If yes, reject.
|
||||||
if self.is_peer_blacklisted(msg_forwarder):
|
|
||||||
logger.debug(
|
|
||||||
"Rejecting message from blacklisted source peer %s", msg_forwarder
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if the message originator (from) is in the blacklist. If yes, reject.
|
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||||
msg_from_peer = ID(msg.from_id)
|
|
||||||
if self.is_peer_blacklisted(msg_from_peer):
|
|
||||||
logger.debug(
|
|
||||||
"Rejecting message from blacklisted originator peer %s", msg_from_peer
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If the message is processed before, return(i.e., don't further process the message) # noqa: E501
|
# If the message is processed before, return(i.e., don't further process the message) # noqa: E501
|
||||||
if self._is_msg_seen(msg):
|
if self._is_msg_seen(msg):
|
||||||
|
|||||||
@ -1,6 +1,3 @@
|
|||||||
from types import (
|
|
||||||
TracebackType,
|
|
||||||
)
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Optional,
|
Optional,
|
||||||
@ -260,16 +257,3 @@ class MplexStream(IMuxedStream):
|
|||||||
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
def get_remote_address(self) -> Optional[tuple[str, int]]:
|
||||||
"""Delegate to the parent Mplex connection."""
|
"""Delegate to the parent Mplex connection."""
|
||||||
return self.muxed_conn.get_remote_address()
|
return self.muxed_conn.get_remote_address()
|
||||||
|
|
||||||
async def __aenter__(self) -> "MplexStream":
|
|
||||||
"""Enter the async context manager."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[type[BaseException]],
|
|
||||||
exc_val: Optional[BaseException],
|
|
||||||
exc_tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
"""Exit the async context manager and close the stream."""
|
|
||||||
await self.close()
|
|
||||||
|
|||||||
@ -9,9 +9,6 @@ from collections.abc import (
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from types import (
|
|
||||||
TracebackType,
|
|
||||||
)
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Callable,
|
Callable,
|
||||||
Optional,
|
Optional,
|
||||||
@ -77,19 +74,6 @@ class YamuxStream(IMuxedStream):
|
|||||||
self.recv_window = DEFAULT_WINDOW_SIZE
|
self.recv_window = DEFAULT_WINDOW_SIZE
|
||||||
self.window_lock = trio.Lock()
|
self.window_lock = trio.Lock()
|
||||||
|
|
||||||
async def __aenter__(self) -> "YamuxStream":
|
|
||||||
"""Enter the async context manager."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(
|
|
||||||
self,
|
|
||||||
exc_type: Optional[type[BaseException]],
|
|
||||||
exc_val: Optional[BaseException],
|
|
||||||
exc_tb: Optional[TracebackType],
|
|
||||||
) -> None:
|
|
||||||
"""Exit the async context manager and close the stream."""
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def write(self, data: bytes) -> None:
|
||||||
if self.send_closed:
|
if self.send_closed:
|
||||||
raise MuxedStreamError("Stream is closed for sending")
|
raise MuxedStreamError("Stream is closed for sending")
|
||||||
|
|||||||
@ -1 +0,0 @@
|
|||||||
Allow passing `listen_addrs` to `new_swarm` to customize swarm listening behavior.
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
Feature: Support for sending `ls` command over `multistream-select` to list supported protocols from remote peer.
|
|
||||||
This allows inspecting which protocol handlers a peer supports at runtime.
|
|
||||||
@ -1 +0,0 @@
|
|||||||
implement AsyncContextManager for IMuxedStream to support async with
|
|
||||||
@ -1 +0,0 @@
|
|||||||
implement blacklist management for `pubsub.Pubsub` with methods to get, add, remove, check, and clear blacklisted peer IDs.
|
|
||||||
6
setup.py
6
setup.py
@ -37,14 +37,10 @@ extras_require = {
|
|||||||
"pytest-trio>=0.5.2",
|
"pytest-trio>=0.5.2",
|
||||||
"factory-boy>=2.12.0,<3.0.0",
|
"factory-boy>=2.12.0,<3.0.0",
|
||||||
],
|
],
|
||||||
"interop": ["redis==6.1.0", "logging==0.4.9.6" "loguru==0.7.3"],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extras_require["dev"] = (
|
extras_require["dev"] = (
|
||||||
extras_require["dev"]
|
extras_require["dev"] + extras_require["docs"] + extras_require["test"]
|
||||||
+ extras_require["docs"]
|
|
||||||
+ extras_require["test"]
|
|
||||||
+ extras_require["interop"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -209,8 +209,8 @@ async def ping_demo(host_a, host_b):
|
|||||||
|
|
||||||
|
|
||||||
async def pubsub_demo(host_a, host_b):
|
async def pubsub_demo(host_a, host_b):
|
||||||
gossipsub_a = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1)
|
gossipsub_a = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 0.1, 1)
|
||||||
gossipsub_b = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 1, 1)
|
gossipsub_b = GossipSub([GOSSIPSUB_PROTOCOL_ID], 3, 2, 4, None, 0.1, 1)
|
||||||
pubsub_a = Pubsub(host_a, gossipsub_a)
|
pubsub_a = Pubsub(host_a, gossipsub_a)
|
||||||
pubsub_b = Pubsub(host_b, gossipsub_b)
|
pubsub_b = Pubsub(host_b, gossipsub_b)
|
||||||
message_a_to_b = "Hello from A to B"
|
message_a_to_b = "Hello from A to B"
|
||||||
|
|||||||
@ -7,18 +7,12 @@ from trio.testing import (
|
|||||||
wait_all_tasks_blocked,
|
wait_all_tasks_blocked,
|
||||||
)
|
)
|
||||||
|
|
||||||
from libp2p import (
|
|
||||||
new_swarm,
|
|
||||||
)
|
|
||||||
from libp2p.network.exceptions import (
|
from libp2p.network.exceptions import (
|
||||||
SwarmException,
|
SwarmException,
|
||||||
)
|
)
|
||||||
from libp2p.tools.utils import (
|
from libp2p.tools.utils import (
|
||||||
connect_swarm,
|
connect_swarm,
|
||||||
)
|
)
|
||||||
from libp2p.transport.tcp.tcp import (
|
|
||||||
TCP,
|
|
||||||
)
|
|
||||||
from tests.utils.factories import (
|
from tests.utils.factories import (
|
||||||
SwarmFactory,
|
SwarmFactory,
|
||||||
)
|
)
|
||||||
@ -162,20 +156,3 @@ async def test_swarm_multiaddr(security_protocol):
|
|||||||
|
|
||||||
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
|
swarms[0].peerstore.add_addrs(swarms[1].get_peer_id(), addrs + addrs, 10000)
|
||||||
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
await swarms[0].dial_peer(swarms[1].get_peer_id())
|
||||||
|
|
||||||
|
|
||||||
def test_new_swarm_defaults_to_tcp():
|
|
||||||
swarm = new_swarm()
|
|
||||||
assert isinstance(swarm.transport, TCP)
|
|
||||||
|
|
||||||
|
|
||||||
def test_new_swarm_tcp_multiaddr_supported():
|
|
||||||
addr = Multiaddr("/ip4/127.0.0.1/tcp/9999")
|
|
||||||
swarm = new_swarm(listen_addrs=[addr])
|
|
||||||
assert isinstance(swarm.transport, TCP)
|
|
||||||
|
|
||||||
|
|
||||||
def test_new_swarm_quic_multiaddr_raises():
|
|
||||||
addr = Multiaddr("/ip4/127.0.0.1/udp/9999/quic")
|
|
||||||
with pytest.raises(ValueError, match="QUIC not yet supported"):
|
|
||||||
new_swarm(listen_addrs=[addr])
|
|
||||||
|
|||||||
@ -116,35 +116,3 @@ async def test_multiple_protocol_fails(security_protocol):
|
|||||||
await perform_simple_test(
|
await perform_simple_test(
|
||||||
"", protocols_for_client, protocols_for_listener, security_protocol
|
"", protocols_for_client, protocols_for_listener, security_protocol
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_multistream_command(security_protocol):
|
|
||||||
supported_protocols = [PROTOCOL_ECHO, PROTOCOL_FOO, PROTOCOL_POTATO, PROTOCOL_ROCK]
|
|
||||||
|
|
||||||
async with HostFactory.create_batch_and_listen(
|
|
||||||
2, security_protocol=security_protocol
|
|
||||||
) as hosts:
|
|
||||||
listener, dialer = hosts[1], hosts[0]
|
|
||||||
|
|
||||||
for protocol in supported_protocols:
|
|
||||||
listener.set_stream_handler(
|
|
||||||
protocol, create_echo_stream_handler(ACK_PREFIX)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure dialer knows how to reach the listener
|
|
||||||
dialer.get_peerstore().add_addrs(listener.get_id(), listener.get_addrs(), 10)
|
|
||||||
|
|
||||||
# Dialer asks peer to list the supported protocols using `ls`
|
|
||||||
response = await dialer.send_command(listener.get_id(), "ls")
|
|
||||||
|
|
||||||
# We expect all supported protocols to show up
|
|
||||||
for protocol in supported_protocols:
|
|
||||||
assert protocol in response
|
|
||||||
|
|
||||||
assert "/does/not/exist" not in response
|
|
||||||
assert "/foo/bar/1.2.3" not in response
|
|
||||||
|
|
||||||
# Dialer asks for unspoorted command
|
|
||||||
with pytest.raises(ValueError, match="Command not supported"):
|
|
||||||
await dialer.send_command(listener.get_id(), "random")
|
|
||||||
|
|||||||
@ -702,369 +702,3 @@ async def test_strict_signing_failed_validation(monkeypatch):
|
|||||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||||
await trio.sleep(0.01)
|
await trio.sleep(0.01)
|
||||||
assert event.is_set()
|
assert event.is_set()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_basic_operations():
|
|
||||||
"""Test basic blacklist operations: add, remove, check, clear."""
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
|
||||||
pubsub = pubsubs_fsub[0]
|
|
||||||
|
|
||||||
# Create test peer IDs
|
|
||||||
peer1 = IDFactory()
|
|
||||||
peer2 = IDFactory()
|
|
||||||
peer3 = IDFactory()
|
|
||||||
|
|
||||||
# Initially no peers should be blacklisted
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 0
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer1)
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer2)
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer3)
|
|
||||||
|
|
||||||
# Add peers to blacklist
|
|
||||||
pubsub.add_to_blacklist(peer1)
|
|
||||||
pubsub.add_to_blacklist(peer2)
|
|
||||||
|
|
||||||
# Check blacklist state
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 2
|
|
||||||
assert pubsub.is_peer_blacklisted(peer1)
|
|
||||||
assert pubsub.is_peer_blacklisted(peer2)
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer3)
|
|
||||||
|
|
||||||
# Remove one peer from blacklist
|
|
||||||
pubsub.remove_from_blacklist(peer1)
|
|
||||||
|
|
||||||
# Check state after removal
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 1
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer1)
|
|
||||||
assert pubsub.is_peer_blacklisted(peer2)
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer3)
|
|
||||||
|
|
||||||
# Add peer3 and then clear all
|
|
||||||
pubsub.add_to_blacklist(peer3)
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 2
|
|
||||||
|
|
||||||
pubsub.clear_blacklist()
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 0
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer1)
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer2)
|
|
||||||
assert not pubsub.is_peer_blacklisted(peer3)
|
|
||||||
|
|
||||||
# Test duplicate additions (should not increase size)
|
|
||||||
pubsub.add_to_blacklist(peer1)
|
|
||||||
pubsub.add_to_blacklist(peer1)
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 1
|
|
||||||
|
|
||||||
# Test removing non-blacklisted peer (should not cause errors)
|
|
||||||
pubsub.remove_from_blacklist(peer2)
|
|
||||||
assert len(pubsub.get_blacklisted_peers()) == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_blocks_new_peer_connections(monkeypatch):
|
|
||||||
"""Test that blacklisted peers are rejected when trying to connect."""
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
|
||||||
pubsub = pubsubs_fsub[0]
|
|
||||||
|
|
||||||
# Create a blacklisted peer ID
|
|
||||||
blacklisted_peer = IDFactory()
|
|
||||||
|
|
||||||
# Add peer to blacklist
|
|
||||||
pubsub.add_to_blacklist(blacklisted_peer)
|
|
||||||
|
|
||||||
new_stream_called = False
|
|
||||||
|
|
||||||
async def mock_new_stream(*args, **kwargs):
|
|
||||||
nonlocal new_stream_called
|
|
||||||
new_stream_called = True
|
|
||||||
# Create a mock stream
|
|
||||||
from unittest.mock import (
|
|
||||||
AsyncMock,
|
|
||||||
Mock,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_stream = Mock()
|
|
||||||
mock_stream.write = AsyncMock()
|
|
||||||
mock_stream.reset = AsyncMock()
|
|
||||||
mock_stream.get_protocol = Mock(return_value="test_protocol")
|
|
||||||
return mock_stream
|
|
||||||
|
|
||||||
router_add_peer_called = False
|
|
||||||
|
|
||||||
def mock_add_peer(*args, **kwargs):
|
|
||||||
nonlocal router_add_peer_called
|
|
||||||
router_add_peer_called = True
|
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
|
||||||
m.setattr(pubsub.host, "new_stream", mock_new_stream)
|
|
||||||
m.setattr(pubsub.router, "add_peer", mock_add_peer)
|
|
||||||
|
|
||||||
# Attempt to handle the blacklisted peer
|
|
||||||
await pubsub._handle_new_peer(blacklisted_peer)
|
|
||||||
|
|
||||||
# Verify that both new_stream and router.add_peer was not called
|
|
||||||
assert (
|
|
||||||
not new_stream_called
|
|
||||||
), "new_stream should be not be called to get hello packet"
|
|
||||||
assert (
|
|
||||||
not router_add_peer_called
|
|
||||||
), "Router.add_peer should not be called for blacklisted peer"
|
|
||||||
assert (
|
|
||||||
blacklisted_peer not in pubsub.peers
|
|
||||||
), "Blacklisted peer should not be in peers dict"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_blocks_messages_from_blacklisted_originator():
|
|
||||||
"""Test that messages from blacklisted originator (from field) are rejected."""
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
|
||||||
pubsub = pubsubs_fsub[0]
|
|
||||||
blacklisted_originator = pubsubs_fsub[1].my_id # Use existing peer ID
|
|
||||||
|
|
||||||
# Add the originator to blacklist
|
|
||||||
pubsub.add_to_blacklist(blacklisted_originator)
|
|
||||||
|
|
||||||
# Create a message with blacklisted originator
|
|
||||||
msg = make_pubsub_msg(
|
|
||||||
origin_id=blacklisted_originator,
|
|
||||||
topic_ids=[TESTING_TOPIC],
|
|
||||||
data=TESTING_DATA,
|
|
||||||
seqno=b"\x00" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Subscribe to the topic
|
|
||||||
await pubsub.subscribe(TESTING_TOPIC)
|
|
||||||
|
|
||||||
# Track if router.publish is called
|
|
||||||
router_publish_called = False
|
|
||||||
|
|
||||||
async def mock_router_publish(*args, **kwargs):
|
|
||||||
nonlocal router_publish_called
|
|
||||||
router_publish_called = True
|
|
||||||
await trio.lowlevel.checkpoint()
|
|
||||||
|
|
||||||
original_router_publish = pubsub.router.publish
|
|
||||||
pubsub.router.publish = mock_router_publish
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Attempt to push message from blacklisted originator
|
|
||||||
await pubsub.push_msg(blacklisted_originator, msg)
|
|
||||||
|
|
||||||
# Verify message was rejected
|
|
||||||
assert (
|
|
||||||
not router_publish_called
|
|
||||||
), "Router.publish should not be called for blacklisted originator"
|
|
||||||
assert not pubsub._is_msg_seen(
|
|
||||||
msg
|
|
||||||
), "Message from blacklisted originator should not be marked as seen"
|
|
||||||
|
|
||||||
finally:
|
|
||||||
pubsub.router.publish = original_router_publish
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_allows_non_blacklisted_peers():
|
|
||||||
"""Test that non-blacklisted peers can send messages normally."""
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(3) as pubsubs_fsub:
|
|
||||||
pubsub = pubsubs_fsub[0]
|
|
||||||
allowed_peer = pubsubs_fsub[1].my_id
|
|
||||||
blacklisted_peer = pubsubs_fsub[2].my_id
|
|
||||||
|
|
||||||
# Blacklist one peer but not the other
|
|
||||||
pubsub.add_to_blacklist(blacklisted_peer)
|
|
||||||
|
|
||||||
# Create messages from both peers
|
|
||||||
msg_from_allowed = make_pubsub_msg(
|
|
||||||
origin_id=allowed_peer,
|
|
||||||
topic_ids=[TESTING_TOPIC],
|
|
||||||
data=b"allowed_data",
|
|
||||||
seqno=b"\x00" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
msg_from_blacklisted = make_pubsub_msg(
|
|
||||||
origin_id=blacklisted_peer,
|
|
||||||
topic_ids=[TESTING_TOPIC],
|
|
||||||
data=b"blacklisted_data",
|
|
||||||
seqno=b"\x11" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Subscribe to the topic
|
|
||||||
sub = await pubsub.subscribe(TESTING_TOPIC)
|
|
||||||
|
|
||||||
# Track router.publish calls
|
|
||||||
router_publish_calls = []
|
|
||||||
|
|
||||||
async def mock_router_publish(*args, **kwargs):
|
|
||||||
router_publish_calls.append(args)
|
|
||||||
await trio.lowlevel.checkpoint()
|
|
||||||
|
|
||||||
original_router_publish = pubsub.router.publish
|
|
||||||
pubsub.router.publish = mock_router_publish
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Send message from allowed peer (should succeed)
|
|
||||||
await pubsub.push_msg(allowed_peer, msg_from_allowed)
|
|
||||||
|
|
||||||
# Send message from blacklisted peer (should be rejected)
|
|
||||||
await pubsub.push_msg(allowed_peer, msg_from_blacklisted)
|
|
||||||
|
|
||||||
# Verify only allowed message was processed
|
|
||||||
assert (
|
|
||||||
len(router_publish_calls) == 1
|
|
||||||
), "Only one message should be processed"
|
|
||||||
assert pubsub._is_msg_seen(
|
|
||||||
msg_from_allowed
|
|
||||||
), "Allowed message should be marked as seen"
|
|
||||||
assert not pubsub._is_msg_seen(
|
|
||||||
msg_from_blacklisted
|
|
||||||
), "Blacklisted message should not be marked as seen"
|
|
||||||
|
|
||||||
# Verify subscription received the allowed message
|
|
||||||
received_msg = await sub.get()
|
|
||||||
assert received_msg.data == b"allowed_data"
|
|
||||||
|
|
||||||
finally:
|
|
||||||
pubsub.router.publish = original_router_publish
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_integration_with_existing_functionality():
|
|
||||||
"""Test that blacklisting works correctly with existing pubsub functionality."""
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
|
||||||
pubsub = pubsubs_fsub[0]
|
|
||||||
other_peer = pubsubs_fsub[1].my_id
|
|
||||||
|
|
||||||
# Test that seen messages cache still works with blacklisting
|
|
||||||
pubsub.add_to_blacklist(other_peer)
|
|
||||||
|
|
||||||
msg = make_pubsub_msg(
|
|
||||||
origin_id=other_peer,
|
|
||||||
topic_ids=[TESTING_TOPIC],
|
|
||||||
data=TESTING_DATA,
|
|
||||||
seqno=b"\x00" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
# First attempt - should be rejected due to blacklist
|
|
||||||
await pubsub.push_msg(other_peer, msg)
|
|
||||||
assert not pubsub._is_msg_seen(msg)
|
|
||||||
|
|
||||||
# Remove from blacklist
|
|
||||||
pubsub.remove_from_blacklist(other_peer)
|
|
||||||
|
|
||||||
# Now the message should be processed
|
|
||||||
await pubsub.subscribe(TESTING_TOPIC)
|
|
||||||
await pubsub.push_msg(other_peer, msg)
|
|
||||||
assert pubsub._is_msg_seen(msg)
|
|
||||||
|
|
||||||
# If we try to send the same message again, it should be rejected
|
|
||||||
# due to seen cache (not blacklist)
|
|
||||||
router_publish_called = False
|
|
||||||
|
|
||||||
async def mock_router_publish(*args, **kwargs):
|
|
||||||
nonlocal router_publish_called
|
|
||||||
router_publish_called = True
|
|
||||||
await trio.lowlevel.checkpoint()
|
|
||||||
|
|
||||||
original_router_publish = pubsub.router.publish
|
|
||||||
pubsub.router.publish = mock_router_publish
|
|
||||||
|
|
||||||
try:
|
|
||||||
await pubsub.push_msg(other_peer, msg)
|
|
||||||
assert (
|
|
||||||
not router_publish_called
|
|
||||||
), "Duplicate message should be rejected by seen cache"
|
|
||||||
finally:
|
|
||||||
pubsub.router.publish = original_router_publish
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_blocks_messages_from_blacklisted_source():
|
|
||||||
"""Test that messages from blacklisted source (forwarder) are rejected."""
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
|
||||||
pubsub = pubsubs_fsub[0]
|
|
||||||
blacklisted_forwarder = pubsubs_fsub[1].my_id
|
|
||||||
|
|
||||||
# Add the forwarder to blacklist
|
|
||||||
pubsub.add_to_blacklist(blacklisted_forwarder)
|
|
||||||
|
|
||||||
# Create a message
|
|
||||||
msg = make_pubsub_msg(
|
|
||||||
origin_id=pubsubs_fsub[1].my_id,
|
|
||||||
topic_ids=[TESTING_TOPIC],
|
|
||||||
data=TESTING_DATA,
|
|
||||||
seqno=b"\x00" * 8,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Subscribe to the topic so we can check if message is processed
|
|
||||||
await pubsub.subscribe(TESTING_TOPIC)
|
|
||||||
|
|
||||||
# Track if router.publish is called (it shouldn't be for blacklisted forwarder)
|
|
||||||
router_publish_called = False
|
|
||||||
|
|
||||||
async def mock_router_publish(*args, **kwargs):
|
|
||||||
nonlocal router_publish_called
|
|
||||||
router_publish_called = True
|
|
||||||
await trio.lowlevel.checkpoint()
|
|
||||||
|
|
||||||
original_router_publish = pubsub.router.publish
|
|
||||||
pubsub.router.publish = mock_router_publish
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Attempt to push message from blacklisted forwarder
|
|
||||||
await pubsub.push_msg(blacklisted_forwarder, msg)
|
|
||||||
|
|
||||||
# Verify message was rejected
|
|
||||||
assert (
|
|
||||||
not router_publish_called
|
|
||||||
), "Router.publish should not be called for blacklisted forwarder"
|
|
||||||
assert not pubsub._is_msg_seen(
|
|
||||||
msg
|
|
||||||
), "Message from blacklisted forwarder should not be marked as seen"
|
|
||||||
|
|
||||||
finally:
|
|
||||||
pubsub.router.publish = original_router_publish
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_blacklist_tears_down_existing_connection():
|
|
||||||
"""
|
|
||||||
Verify that if a peer is already in pubsub.peers and pubsub.peer_topics,
|
|
||||||
calling add_to_blacklist(peer_id) immediately resets its stream and
|
|
||||||
removes it from both places.
|
|
||||||
"""
|
|
||||||
# Create two pubsub instances (floodsub), so they can connect to each other
|
|
||||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
|
||||||
pubsub0, pubsub1 = pubsubs_fsub
|
|
||||||
|
|
||||||
# 1) Connect peer1 to peer0
|
|
||||||
await connect(pubsub0.host, pubsub1.host)
|
|
||||||
# Give handle_peer_queue some time to run
|
|
||||||
await trio.sleep(0.1)
|
|
||||||
|
|
||||||
# After connect, pubsub0.peers should contain pubsub1.my_id
|
|
||||||
assert pubsub1.my_id in pubsub0.peers
|
|
||||||
|
|
||||||
# 2) Manually record a subscription from peer1 under TESTING_TOPIC,
|
|
||||||
# so that peer1 shows up in pubsub0.peer_topics[TESTING_TOPIC].
|
|
||||||
sub_msg = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
|
|
||||||
pubsub0.handle_subscription(pubsub1.my_id, sub_msg)
|
|
||||||
|
|
||||||
assert TESTING_TOPIC in pubsub0.peer_topics
|
|
||||||
assert pubsub1.my_id in pubsub0.peer_topics[TESTING_TOPIC]
|
|
||||||
|
|
||||||
# 3) Now blacklist peer1
|
|
||||||
pubsub0.add_to_blacklist(pubsub1.my_id)
|
|
||||||
|
|
||||||
# Allow the asynchronous teardown task (_teardown_if_connected) to run
|
|
||||||
await trio.sleep(0.1)
|
|
||||||
|
|
||||||
# 4a) pubsub0.peers should no longer contain peer1
|
|
||||||
assert pubsub1.my_id not in pubsub0.peers
|
|
||||||
|
|
||||||
# 4b) pubsub0.peer_topics[TESTING_TOPIC] should no longer contain peer1
|
|
||||||
# (or TESTING_TOPIC may have been removed entirely if no other peers remain)
|
|
||||||
if TESTING_TOPIC in pubsub0.peer_topics:
|
|
||||||
assert pubsub1.my_id not in pubsub0.peer_topics[TESTING_TOPIC]
|
|
||||||
else:
|
|
||||||
# It’s also fine if the entire topic entry was pruned
|
|
||||||
assert TESTING_TOPIC not in pubsub0.peer_topics
|
|
||||||
|
|||||||
@ -1,127 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import trio
|
|
||||||
|
|
||||||
from libp2p.stream_muxer.exceptions import (
|
|
||||||
MuxedStreamClosed,
|
|
||||||
MuxedStreamError,
|
|
||||||
)
|
|
||||||
from libp2p.stream_muxer.mplex.datastructures import (
|
|
||||||
StreamID,
|
|
||||||
)
|
|
||||||
from libp2p.stream_muxer.mplex.mplex_stream import (
|
|
||||||
MplexStream,
|
|
||||||
)
|
|
||||||
from libp2p.stream_muxer.yamux.yamux import (
|
|
||||||
YamuxStream,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DummySecuredConn:
|
|
||||||
async def write(self, data):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MockMuxedConn:
|
|
||||||
def __init__(self):
|
|
||||||
self.streams = {}
|
|
||||||
self.streams_lock = trio.Lock()
|
|
||||||
self.event_shutting_down = trio.Event()
|
|
||||||
self.event_closed = trio.Event()
|
|
||||||
self.event_started = trio.Event()
|
|
||||||
self.secured_conn = DummySecuredConn() # For YamuxStream
|
|
||||||
|
|
||||||
async def send_message(self, flag, data, stream_id):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def get_remote_address(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_mplex_stream_async_context_manager():
|
|
||||||
muxed_conn = MockMuxedConn()
|
|
||||||
stream_id = StreamID(1, True) # Use real StreamID
|
|
||||||
stream = MplexStream(
|
|
||||||
name="test_stream",
|
|
||||||
stream_id=stream_id,
|
|
||||||
muxed_conn=muxed_conn,
|
|
||||||
incoming_data_channel=trio.open_memory_channel(8)[1],
|
|
||||||
)
|
|
||||||
async with stream as s:
|
|
||||||
assert s is stream
|
|
||||||
assert not stream.event_local_closed.is_set()
|
|
||||||
assert not stream.event_remote_closed.is_set()
|
|
||||||
assert not stream.event_reset.is_set()
|
|
||||||
assert stream.event_local_closed.is_set()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_yamux_stream_async_context_manager():
|
|
||||||
muxed_conn = MockMuxedConn()
|
|
||||||
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
|
||||||
async with stream as s:
|
|
||||||
assert s is stream
|
|
||||||
assert not stream.closed
|
|
||||||
assert not stream.send_closed
|
|
||||||
assert not stream.recv_closed
|
|
||||||
assert stream.send_closed
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_mplex_stream_async_context_manager_with_error():
|
|
||||||
muxed_conn = MockMuxedConn()
|
|
||||||
stream_id = StreamID(1, True)
|
|
||||||
stream = MplexStream(
|
|
||||||
name="test_stream",
|
|
||||||
stream_id=stream_id,
|
|
||||||
muxed_conn=muxed_conn,
|
|
||||||
incoming_data_channel=trio.open_memory_channel(8)[1],
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
async with stream as s:
|
|
||||||
assert s is stream
|
|
||||||
assert not stream.event_local_closed.is_set()
|
|
||||||
assert not stream.event_remote_closed.is_set()
|
|
||||||
assert not stream.event_reset.is_set()
|
|
||||||
raise ValueError("Test error")
|
|
||||||
assert stream.event_local_closed.is_set()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_yamux_stream_async_context_manager_with_error():
|
|
||||||
muxed_conn = MockMuxedConn()
|
|
||||||
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
async with stream as s:
|
|
||||||
assert s is stream
|
|
||||||
assert not stream.closed
|
|
||||||
assert not stream.send_closed
|
|
||||||
assert not stream.recv_closed
|
|
||||||
raise ValueError("Test error")
|
|
||||||
assert stream.send_closed
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_mplex_stream_async_context_manager_write_after_close():
|
|
||||||
muxed_conn = MockMuxedConn()
|
|
||||||
stream_id = StreamID(1, True)
|
|
||||||
stream = MplexStream(
|
|
||||||
name="test_stream",
|
|
||||||
stream_id=stream_id,
|
|
||||||
muxed_conn=muxed_conn,
|
|
||||||
incoming_data_channel=trio.open_memory_channel(8)[1],
|
|
||||||
)
|
|
||||||
async with stream as s:
|
|
||||||
assert s is stream
|
|
||||||
with pytest.raises(MuxedStreamClosed):
|
|
||||||
await stream.write(b"test data")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.trio
|
|
||||||
async def test_yamux_stream_async_context_manager_write_after_close():
|
|
||||||
muxed_conn = MockMuxedConn()
|
|
||||||
stream = YamuxStream(stream_id=1, conn=muxed_conn, is_initiator=True)
|
|
||||||
async with stream as s:
|
|
||||||
assert s is stream
|
|
||||||
with pytest.raises(MuxedStreamError):
|
|
||||||
await stream.write(b"test data")
|
|
||||||
@ -32,25 +32,18 @@ class BaseInteractiveProcess(AbstractInterativeProcess):
|
|||||||
|
|
||||||
async def wait_until_ready(self) -> None:
|
async def wait_until_ready(self) -> None:
|
||||||
patterns_occurred = {pat: False for pat in self.patterns}
|
patterns_occurred = {pat: False for pat in self.patterns}
|
||||||
buffers = {pat: bytearray() for pat in self.patterns}
|
|
||||||
|
|
||||||
async def read_from_daemon_and_check() -> None:
|
async def read_from_daemon_and_check() -> None:
|
||||||
async for data in self.proc.stdout:
|
async for data in self.proc.stdout:
|
||||||
|
# TODO: It takes O(n^2), which is quite bad.
|
||||||
|
# But it should succeed in a few seconds.
|
||||||
self.bytes_read.extend(data)
|
self.bytes_read.extend(data)
|
||||||
for pat, occurred in patterns_occurred.items():
|
for pat, occurred in patterns_occurred.items():
|
||||||
if occurred:
|
if occurred:
|
||||||
continue
|
continue
|
||||||
|
if pat in self.bytes_read:
|
||||||
# Check if pattern is in new data or spans across chunks
|
|
||||||
buf = buffers[pat]
|
|
||||||
buf.extend(data)
|
|
||||||
if pat in buf:
|
|
||||||
patterns_occurred[pat] = True
|
patterns_occurred[pat] = True
|
||||||
else:
|
if all([value for value in patterns_occurred.values()]):
|
||||||
keep = min(len(pat) - 1, len(buf))
|
|
||||||
buffers[pat] = buf[-keep:] if keep > 0 else bytearray()
|
|
||||||
|
|
||||||
if all(patterns_occurred.values()):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
with trio.fail_after(TIMEOUT_DURATION):
|
with trio.fail_after(TIMEOUT_DURATION):
|
||||||
|
|||||||
Reference in New Issue
Block a user