mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: pre-commit issues
This commit is contained in:
@ -1,27 +1,39 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import (
|
||||||
|
x25519,
|
||||||
|
)
|
||||||
import multiaddr
|
import multiaddr
|
||||||
import trio
|
import trio
|
||||||
import logging
|
|
||||||
import struct
|
|
||||||
import os
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
from libp2p import new_host, generate_new_rsa_identity
|
from libp2p import (
|
||||||
from libp2p.custom_types import TProtocol
|
generate_new_rsa_identity,
|
||||||
from libp2p.network.stream.net_stream import INetStream
|
new_host,
|
||||||
|
)
|
||||||
|
from libp2p.custom_types import (
|
||||||
|
TProtocol,
|
||||||
|
)
|
||||||
|
from libp2p.network.stream.net_stream import (
|
||||||
|
INetStream,
|
||||||
|
)
|
||||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||||
from libp2p.stream_muxer.yamux.yamux import Yamux, YamuxStream, PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
from libp2p.stream_muxer.yamux.yamux import (
|
||||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
Yamux,
|
||||||
from libp2p.crypto.keys import KeyPair
|
YamuxStream,
|
||||||
|
)
|
||||||
|
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
||||||
|
|
||||||
# Configure detailed logging
|
# Configure detailed logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.DEBUG,
|
level=logging.DEBUG,
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
handlers=[
|
handlers=[
|
||||||
logging.StreamHandler(),
|
logging.StreamHandler(),
|
||||||
logging.FileHandler('ping_debug.log', mode='w', encoding='utf-8')
|
logging.FileHandler("ping_debug.log", mode="w", encoding="utf-8"),
|
||||||
]
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Protocol constants - must match rust-libp2p exactly
|
# Protocol constants - must match rust-libp2p exactly
|
||||||
@ -30,9 +42,10 @@ PING_LENGTH = 32
|
|||||||
RESP_TIMEOUT = 30
|
RESP_TIMEOUT = 30
|
||||||
MAX_FRAME_SIZE = 1024 * 1024 # 1MB max frame size
|
MAX_FRAME_SIZE = 1024 * 1024 # 1MB max frame size
|
||||||
|
|
||||||
|
|
||||||
class InteropYamux(Yamux):
|
class InteropYamux(Yamux):
|
||||||
"""Enhanced Yamux with proper rust-libp2p interoperability"""
|
"""Enhanced Yamux with proper rust-libp2p interoperability"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
logging.info("InteropYamux.__init__ called")
|
logging.info("InteropYamux.__init__ called")
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -43,25 +56,31 @@ class InteropYamux(Yamux):
|
|||||||
"""Read exactly n bytes from the connection with proper error handling"""
|
"""Read exactly n bytes from the connection with proper error handling"""
|
||||||
if n == 0:
|
if n == 0:
|
||||||
return b""
|
return b""
|
||||||
|
|
||||||
if n > MAX_FRAME_SIZE:
|
if n > MAX_FRAME_SIZE:
|
||||||
logging.error(f"Requested read size {n} exceeds maximum {MAX_FRAME_SIZE}")
|
logging.error(f"Requested read size {n} exceeds maximum {MAX_FRAME_SIZE}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
data = b""
|
data = b""
|
||||||
while len(data) < n:
|
while len(data) < n:
|
||||||
try:
|
try:
|
||||||
remaining = n - len(data)
|
remaining = n - len(data)
|
||||||
chunk = await self.secured_conn.read(remaining)
|
chunk = await self.secured_conn.read(remaining)
|
||||||
except (trio.ClosedResourceError, trio.BrokenResourceError):
|
except (trio.ClosedResourceError, trio.BrokenResourceError):
|
||||||
logging.debug(f"Connection closed while reading {n} bytes (got {len(data)}) for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Connection closed while reading {n}"
|
||||||
|
f"bytes (got {len(data)}) for peer {self.peer_id}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error reading {n} bytes: {e}")
|
logging.error(f"Error reading {n} bytes: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not chunk:
|
if not chunk:
|
||||||
logging.debug(f"Connection closed while reading {n} bytes (got {len(data)}) for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Connection closed while reading {n}"
|
||||||
|
f"bytes (got {len(data)}) for peer {self.peer_id}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
data += chunk
|
data += chunk
|
||||||
return data
|
return data
|
||||||
@ -69,55 +88,75 @@ class InteropYamux(Yamux):
|
|||||||
async def handle_incoming(self):
|
async def handle_incoming(self):
|
||||||
"""Enhanced incoming frame handler with better error recovery"""
|
"""Enhanced incoming frame handler with better error recovery"""
|
||||||
logging.info(f"Starting Yamux for {self.peer_id}")
|
logging.info(f"Starting Yamux for {self.peer_id}")
|
||||||
|
|
||||||
consecutive_errors = 0
|
consecutive_errors = 0
|
||||||
max_consecutive_errors = 3
|
max_consecutive_errors = 3
|
||||||
|
|
||||||
while not self.event_shutting_down.is_set():
|
while not self.event_shutting_down.is_set():
|
||||||
try:
|
try:
|
||||||
# Read frame header (12 bytes)
|
# Read frame header (12 bytes)
|
||||||
header_data = await self._read_exact_bytes(12)
|
header_data = await self._read_exact_bytes(12)
|
||||||
if header_data is None:
|
if header_data is None:
|
||||||
logging.debug(f"Connection closed or incomplete header for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Connection closed or incomplete"
|
||||||
|
f"header for peer {self.peer_id}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Quick sanity check for protocol data leakage
|
# Quick sanity check for protocol data leakage
|
||||||
if b'/ipfs' in header_data or b'multi' in header_data or b'noise' in header_data:
|
if (
|
||||||
logging.error(f"Protocol data in header position: {header_data.hex()}")
|
b"/ipfs" in header_data
|
||||||
|
or b"multi" in header_data
|
||||||
|
or b"noise" in header_data
|
||||||
|
):
|
||||||
|
logging.error(
|
||||||
|
f"Protocol data in header position: {header_data.hex()}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Unpack header: version, type, flags, stream_id, length
|
# Unpack header: version, type, flags, stream_id, length
|
||||||
version, msg_type, flags, stream_id, length = struct.unpack(">BBHII", header_data)
|
version, msg_type, flags, stream_id, length = struct.unpack(
|
||||||
|
">BBHII", header_data
|
||||||
|
)
|
||||||
|
|
||||||
# Validate header values strictly
|
# Validate header values strictly
|
||||||
if version != 0:
|
if version != 0:
|
||||||
logging.error(f"Invalid yamux version {version}, expected 0")
|
logging.error(f"Invalid yamux version {version}, expected 0")
|
||||||
break
|
break
|
||||||
|
|
||||||
if msg_type not in [0, 1, 2, 3]:
|
if msg_type not in [0, 1, 2, 3]:
|
||||||
logging.error(f"Invalid message type {msg_type}, expected 0-3")
|
logging.error(f"Invalid message type {msg_type}, expected 0-3")
|
||||||
break
|
break
|
||||||
|
|
||||||
if length > MAX_FRAME_SIZE:
|
if length > MAX_FRAME_SIZE:
|
||||||
logging.error(f"Frame too large: {length} > {MAX_FRAME_SIZE}")
|
logging.error(f"Frame too large: {length} > {MAX_FRAME_SIZE}")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Additional validation for ping frames
|
# Additional validation for ping frames
|
||||||
if msg_type == 2 and length != 4:
|
if msg_type == 2 and length != 4:
|
||||||
logging.error(f"Invalid ping frame length: {length}, expected 4")
|
logging.error(
|
||||||
|
f"Invalid ping frame length: {length}, expected 4"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Log frame details
|
# Log frame details
|
||||||
logging.debug(f"Received header for peer {self.peer_id}: type={msg_type}, flags={flags}, stream_id={stream_id}, length={length}")
|
logging.debug(
|
||||||
|
f"Received header for peer {self.peer_id}"
|
||||||
|
f": type={msg_type}, flags={flags},"
|
||||||
|
f"stream_id={stream_id}, length={length}"
|
||||||
|
)
|
||||||
|
|
||||||
consecutive_errors = 0 # Reset error counter on successful parse
|
consecutive_errors = 0 # Reset error counter on successful parse
|
||||||
|
|
||||||
except struct.error as e:
|
except struct.error as e:
|
||||||
consecutive_errors += 1
|
consecutive_errors += 1
|
||||||
logging.error(f"Header parse error #{consecutive_errors}: {e}, data: {header_data.hex()}")
|
logging.error(
|
||||||
|
f"Header parse error #{consecutive_errors}"
|
||||||
|
f": {e}, data: {header_data.hex()}"
|
||||||
|
)
|
||||||
if consecutive_errors >= max_consecutive_errors:
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
logging.error("Too many consecutive header parse errors, closing connection")
|
logging.error("Too many consecutive header parse errors")
|
||||||
break
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -126,34 +165,42 @@ class InteropYamux(Yamux):
|
|||||||
if length > 0:
|
if length > 0:
|
||||||
payload = await self._read_exact_bytes(length)
|
payload = await self._read_exact_bytes(length)
|
||||||
if payload is None:
|
if payload is None:
|
||||||
logging.debug(f"Failed to read payload of {length} bytes for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Failed to read payload of"
|
||||||
|
f"{length} bytes for peer {self.peer_id}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
if len(payload) != length:
|
if len(payload) != length:
|
||||||
logging.error(f"Payload length mismatch: got {len(payload)}, expected {length}")
|
logging.error(
|
||||||
|
f"Payload length mismatch:"
|
||||||
|
f"got {len(payload)}, expected {length}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Process frame by type
|
# Process frame by type
|
||||||
if msg_type == 0: # Data frame
|
if msg_type == 0: # Data frame
|
||||||
await self._handle_data_frame(stream_id, flags, payload)
|
await self._handle_data_frame(stream_id, flags, payload)
|
||||||
|
|
||||||
elif msg_type == 1: # Window update
|
elif msg_type == 1: # Window update
|
||||||
await self._handle_window_update(stream_id, payload)
|
await self._handle_window_update(stream_id, payload)
|
||||||
|
|
||||||
elif msg_type == 2: # Ping frame
|
elif msg_type == 2: # Ping frame
|
||||||
await self._handle_ping_frame(stream_id, flags, payload)
|
await self._handle_ping_frame(stream_id, flags, payload)
|
||||||
|
|
||||||
elif msg_type == 3: # GoAway frame
|
elif msg_type == 3: # GoAway frame
|
||||||
await self._handle_goaway_frame(payload)
|
await self._handle_goaway_frame(payload)
|
||||||
break
|
break
|
||||||
|
|
||||||
except (trio.ClosedResourceError, trio.BrokenResourceError):
|
except (trio.ClosedResourceError, trio.BrokenResourceError):
|
||||||
logging.debug(f"Connection closed during frame processing for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Connection closed during frame processing for peer {self.peer_id}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
consecutive_errors += 1
|
consecutive_errors += 1
|
||||||
logging.error(f"Frame processing error #{consecutive_errors}: {e}")
|
logging.error(f"Frame processing error #{consecutive_errors}: {e}")
|
||||||
if consecutive_errors >= max_consecutive_errors:
|
if consecutive_errors >= max_consecutive_errors:
|
||||||
logging.error("Too many consecutive frame processing errors, closing connection")
|
logging.error("Too many consecutive frame processing errors")
|
||||||
break
|
break
|
||||||
|
|
||||||
await self._cleanup_on_error()
|
await self._cleanup_on_error()
|
||||||
@ -169,28 +216,30 @@ class InteropYamux(Yamux):
|
|||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
logging.warning(f"SYN received for existing stream {stream_id}")
|
logging.warning(f"SYN received for existing stream {stream_id}")
|
||||||
else:
|
else:
|
||||||
logging.debug(f"Creating new stream {stream_id} for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Creating new stream {stream_id} for peer {self.peer_id}"
|
||||||
|
)
|
||||||
stream = YamuxStream(self, stream_id, is_outbound=False)
|
stream = YamuxStream(self, stream_id, is_outbound=False)
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
self.streams[stream_id] = stream
|
self.streams[stream_id] = stream
|
||||||
|
|
||||||
# Send the new stream to the handler
|
# Send the new stream to the handler
|
||||||
await self.new_stream_send_channel.send(stream)
|
await self.new_stream_send_channel.send(stream)
|
||||||
logging.debug(f"Sent stream {stream_id} to handler")
|
logging.debug(f"Sent stream {stream_id} to handler")
|
||||||
|
|
||||||
# Add data to stream buffer if stream exists
|
# Add data to stream buffer if stream exists
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
stream = self.streams[stream_id]
|
stream = self.streams[stream_id]
|
||||||
if payload:
|
if payload:
|
||||||
# Add to stream's receive buffer
|
# Add to stream's receive buffer
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
if not hasattr(stream, '_receive_buffer'):
|
if not hasattr(stream, "_receive_buffer"):
|
||||||
stream._receive_buffer = bytearray()
|
stream._receive_buffer = bytearray()
|
||||||
if not hasattr(stream, '_receive_event'):
|
if not hasattr(stream, "_receive_event"):
|
||||||
stream._receive_event = trio.Event()
|
stream._receive_event = trio.Event()
|
||||||
stream._receive_buffer.extend(payload)
|
stream._receive_buffer.extend(payload)
|
||||||
stream._receive_event.set()
|
stream._receive_event.set()
|
||||||
|
|
||||||
# Handle stream closure flags
|
# Handle stream closure flags
|
||||||
if flags & 0x2: # FIN flag
|
if flags & 0x2: # FIN flag
|
||||||
stream.recv_closed = True
|
stream.recv_closed = True
|
||||||
@ -207,13 +256,13 @@ class InteropYamux(Yamux):
|
|||||||
if len(payload) != 4:
|
if len(payload) != 4:
|
||||||
logging.warning(f"Invalid window update payload length: {len(payload)}")
|
logging.warning(f"Invalid window update payload length: {len(payload)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
delta = struct.unpack(">I", payload)[0]
|
delta = struct.unpack(">I", payload)[0]
|
||||||
logging.debug(f"Window update: stream={stream_id}, delta={delta}")
|
logging.debug(f"Window update: stream={stream_id}, delta={delta}")
|
||||||
|
|
||||||
async with self.streams_lock:
|
async with self.streams_lock:
|
||||||
if stream_id in self.streams:
|
if stream_id in self.streams:
|
||||||
if not hasattr(self.streams[stream_id], 'send_window'):
|
if not hasattr(self.streams[stream_id], "send_window"):
|
||||||
self.streams[stream_id].send_window = 256 * 1024 # Default window
|
self.streams[stream_id].send_window = 256 * 1024 # Default window
|
||||||
self.streams[stream_id].send_window += delta
|
self.streams[stream_id].send_window += delta
|
||||||
|
|
||||||
@ -222,14 +271,18 @@ class InteropYamux(Yamux):
|
|||||||
if len(payload) != 4:
|
if len(payload) != 4:
|
||||||
logging.warning(f"Invalid ping payload length: {len(payload)} (expected 4)")
|
logging.warning(f"Invalid ping payload length: {len(payload)} (expected 4)")
|
||||||
return
|
return
|
||||||
|
|
||||||
ping_value = struct.unpack(">I", payload)[0]
|
ping_value = struct.unpack(">I", payload)[0]
|
||||||
|
|
||||||
if flags & 0x1: # SYN flag - ping request
|
if flags & 0x1: # SYN flag - ping request
|
||||||
logging.debug(f"Received ping request with value {ping_value} for peer {self.peer_id}")
|
logging.debug(
|
||||||
|
f"Received ping request with value {ping_value} for peer {self.peer_id}"
|
||||||
|
)
|
||||||
# Send pong response (ACK flag = 0x2)
|
# Send pong response (ACK flag = 0x2)
|
||||||
try:
|
try:
|
||||||
pong_header = struct.pack(">BBHII", 0, 2, 0x2, 0, 4) # Version=0, Type=2, Flags=ACK, StreamID=0, Length=4
|
pong_header = struct.pack(
|
||||||
|
">BBHII", 0, 2, 0x2, 0, 4
|
||||||
|
) # Version=0, Type=2, Flags=ACK, StreamID=0, Length=4
|
||||||
pong_payload = struct.pack(">I", ping_value)
|
pong_payload = struct.pack(">I", ping_value)
|
||||||
await self.secured_conn.write(pong_header + pong_payload)
|
await self.secured_conn.write(pong_header + pong_payload)
|
||||||
logging.debug(f"Sent pong response with value {ping_value}")
|
logging.debug(f"Sent pong response with value {ping_value}")
|
||||||
@ -244,11 +297,12 @@ class InteropYamux(Yamux):
|
|||||||
if len(payload) != 4:
|
if len(payload) != 4:
|
||||||
logging.warning(f"Invalid GoAway payload length: {len(payload)}")
|
logging.warning(f"Invalid GoAway payload length: {len(payload)}")
|
||||||
return
|
return
|
||||||
|
|
||||||
code = struct.unpack(">I", payload)[0]
|
code = struct.unpack(">I", payload)[0]
|
||||||
logging.info(f"Received GoAway frame with code {code}")
|
logging.info(f"Received GoAway frame with code {code}")
|
||||||
self.event_shutting_down.set()
|
self.event_shutting_down.set()
|
||||||
|
|
||||||
|
|
||||||
async def handle_ping(stream: INetStream) -> None:
|
async def handle_ping(stream: INetStream) -> None:
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
logging.info(f"Handling ping stream from {peer_id}")
|
logging.info(f"Handling ping stream from {peer_id}")
|
||||||
@ -257,9 +311,14 @@ async def handle_ping(stream: INetStream) -> None:
|
|||||||
with trio.fail_after(RESP_TIMEOUT):
|
with trio.fail_after(RESP_TIMEOUT):
|
||||||
# Read initial protocol negotiation
|
# Read initial protocol negotiation
|
||||||
initial_data = await stream.read(1024)
|
initial_data = await stream.read(1024)
|
||||||
logging.debug(f"Received initial stream data from {peer_id}: {initial_data.hex()} (length={len(initial_data)})")
|
logging.debug(
|
||||||
|
f"Received initial stream data from {peer_id}"
|
||||||
|
f": {initial_data.hex()} (length={len(initial_data)})"
|
||||||
|
)
|
||||||
if initial_data == b"/ipfs/ping/1.0.0\n":
|
if initial_data == b"/ipfs/ping/1.0.0\n":
|
||||||
logging.debug(f"Confirmed /ipfs/ping/1.0.0 protocol negotiation from {peer_id}")
|
logging.debug(
|
||||||
|
f"Confirmed /ipfs/ping/1.0.0 protocol negotiation from {peer_id}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logging.warning(f"Unexpected initial data: {initial_data!r}")
|
logging.warning(f"Unexpected initial data: {initial_data!r}")
|
||||||
|
|
||||||
@ -269,9 +328,15 @@ async def handle_ping(stream: INetStream) -> None:
|
|||||||
logging.info(f"Stream closed by {peer_id}")
|
logging.info(f"Stream closed by {peer_id}")
|
||||||
return
|
return
|
||||||
if len(payload) != PING_LENGTH:
|
if len(payload) != PING_LENGTH:
|
||||||
logging.warning(f"Unexpected payload length {len(payload)} from {peer_id}: {payload.hex()}")
|
logging.warning(
|
||||||
|
f"Unexpected payload length"
|
||||||
|
f" {len(payload)} from {peer_id}: {payload.hex()}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
logging.info(f"Received ping from {peer_id}: {payload[:8].hex()}... (length={len(payload)})")
|
logging.info(
|
||||||
|
f"Received ping from {peer_id}:"
|
||||||
|
f" {payload[:8].hex()}... (length={len(payload)})"
|
||||||
|
)
|
||||||
await stream.write(payload)
|
await stream.write(payload)
|
||||||
logging.info(f"Sent pong to {peer_id}: {payload[:8].hex()}...")
|
logging.info(f"Sent pong to {peer_id}: {payload[:8].hex()}...")
|
||||||
|
|
||||||
@ -285,9 +350,10 @@ async def handle_ping(stream: INetStream) -> None:
|
|||||||
try:
|
try:
|
||||||
await stream.close()
|
await stream.close()
|
||||||
logging.debug(f"Closed ping stream with {peer_id}")
|
logging.debug(f"Closed ping stream with {peer_id}")
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def send_ping(stream: INetStream) -> None:
|
async def send_ping(stream: INetStream) -> None:
|
||||||
peer_id = stream.muxed_conn.peer_id
|
peer_id = stream.muxed_conn.peer_id
|
||||||
try:
|
try:
|
||||||
@ -301,7 +367,9 @@ async def send_ping(stream: INetStream) -> None:
|
|||||||
logging.error(f"No pong response from {peer_id}")
|
logging.error(f"No pong response from {peer_id}")
|
||||||
return
|
return
|
||||||
if len(response) != PING_LENGTH:
|
if len(response) != PING_LENGTH:
|
||||||
logging.warning(f"Pong length mismatch: got {len(response)}, expected {PING_LENGTH}")
|
logging.warning(
|
||||||
|
f"Pong length mismatch: got {len(response)}, expected {PING_LENGTH}"
|
||||||
|
)
|
||||||
if response == payload:
|
if response == payload:
|
||||||
logging.info(f"Ping successful! Pong matches from {peer_id}")
|
logging.info(f"Ping successful! Pong matches from {peer_id}")
|
||||||
else:
|
else:
|
||||||
@ -313,30 +381,31 @@ async def send_ping(stream: INetStream) -> None:
|
|||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
await stream.close()
|
await stream.close()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def create_noise_keypair():
|
def create_noise_keypair():
|
||||||
try:
|
try:
|
||||||
x25519_private_key = x25519.X25519PrivateKey.generate()
|
x25519_private_key = x25519.X25519PrivateKey.generate()
|
||||||
|
|
||||||
class NoisePrivateKey:
|
class NoisePrivateKey:
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
self._key = key
|
self._key = key
|
||||||
|
|
||||||
def to_bytes(self):
|
def to_bytes(self):
|
||||||
return self._key.private_bytes_raw()
|
return self._key.private_bytes_raw()
|
||||||
|
|
||||||
def public_key(self):
|
def public_key(self):
|
||||||
return NoisePublicKey(self._key.public_key())
|
return NoisePublicKey(self._key.public_key())
|
||||||
|
|
||||||
def get_public_key(self):
|
def get_public_key(self):
|
||||||
return NoisePublicKey(self._key.public_key())
|
return NoisePublicKey(self._key.public_key())
|
||||||
|
|
||||||
class NoisePublicKey:
|
class NoisePublicKey:
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
self._key = key
|
self._key = key
|
||||||
|
|
||||||
def to_bytes(self):
|
def to_bytes(self):
|
||||||
return self._key.public_bytes_raw()
|
return self._key.public_bytes_raw()
|
||||||
|
|
||||||
@ -345,73 +414,70 @@ def create_noise_keypair():
|
|||||||
logging.error(f"Failed to create Noise keypair: {e}")
|
logging.error(f"Failed to create Noise keypair: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def info_from_p2p_addr(addr):
|
def info_from_p2p_addr(addr):
|
||||||
"""Extract peer info from multiaddr - you'll need to implement this"""
|
"""Extract peer info from multiaddr - you'll need to implement this"""
|
||||||
# This is a placeholder - you need to implement the actual parsing
|
# This is a placeholder - you need to implement the actual parsing
|
||||||
# based on your libp2p implementation
|
# based on your libp2p implementation
|
||||||
pass
|
|
||||||
|
|
||||||
async def run(port: int, destination: str) -> None:
|
async def run(port: int, destination: str) -> None:
|
||||||
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
key_pair = generate_new_rsa_identity()
|
key_pair = generate_new_rsa_identity()
|
||||||
logging.debug("Generated RSA keypair")
|
logging.debug("Generated RSA keypair")
|
||||||
|
|
||||||
noise_privkey = create_noise_keypair()
|
noise_privkey = create_noise_keypair()
|
||||||
logging.debug("Generated Noise keypair")
|
logging.debug("Generated Noise keypair")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Key generation failed: {e}")
|
logging.error(f"Key generation failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
noise_transport = NoiseTransport(key_pair, noise_privkey=noise_privkey)
|
noise_transport = NoiseTransport(key_pair, noise_privkey=noise_privkey)
|
||||||
logging.debug(f"Noise transport initialized: {noise_transport}")
|
logging.debug(f"Noise transport initialized: {noise_transport}")
|
||||||
sec_opt = {TProtocol("/noise"): noise_transport}
|
sec_opt = {TProtocol("/noise"): noise_transport}
|
||||||
muxer_opt = {TProtocol(YAMUX_PROTOCOL_ID): InteropYamux}
|
muxer_opt = {TProtocol(YAMUX_PROTOCOL_ID): InteropYamux}
|
||||||
|
|
||||||
logging.info(f"Using muxer: {muxer_opt}")
|
logging.info(f"Using muxer: {muxer_opt}")
|
||||||
|
|
||||||
host = new_host(
|
host = new_host(key_pair=key_pair, sec_opt=sec_opt, muxer_opt=muxer_opt)
|
||||||
key_pair=key_pair,
|
|
||||||
sec_opt=sec_opt,
|
|
||||||
muxer_opt=muxer_opt
|
|
||||||
)
|
|
||||||
|
|
||||||
peer_id = host.get_id().pretty()
|
peer_id = host.get_id().pretty()
|
||||||
logging.info(f"Host peer ID: {peer_id}")
|
logging.info(f"Host peer ID: {peer_id}")
|
||||||
|
|
||||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery():
|
||||||
if not destination:
|
if not destination:
|
||||||
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||||
|
|
||||||
logging.info(f"Server listening on {listen_addr}")
|
logging.info(f"Server listening on {listen_addr}")
|
||||||
logging.info(f"Full address: {listen_addr}/p2p/{peer_id}")
|
logging.info(f"Full address: {listen_addr}/p2p/{peer_id}")
|
||||||
logging.info("Waiting for connections...")
|
logging.info("Waiting for connections...")
|
||||||
|
|
||||||
await trio.sleep_forever()
|
await trio.sleep_forever()
|
||||||
else:
|
else:
|
||||||
maddr = multiaddr.Multiaddr(destination)
|
maddr = multiaddr.Multiaddr(destination)
|
||||||
info = info_from_p2p_addr(maddr)
|
info = info_from_p2p_addr(maddr)
|
||||||
|
|
||||||
logging.info(f"Connecting to {info.peer_id}")
|
logging.info(f"Connecting to {info.peer_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with trio.fail_after(30):
|
with trio.fail_after(30):
|
||||||
await host.connect(info)
|
await host.connect(info)
|
||||||
logging.info(f"Connected to {info.peer_id}")
|
logging.info(f"Connected to {info.peer_id}")
|
||||||
|
|
||||||
await trio.sleep(2.0)
|
await trio.sleep(2.0)
|
||||||
|
|
||||||
logging.info(f"Opening ping stream to {info.peer_id}")
|
logging.info(f"Opening ping stream to {info.peer_id}")
|
||||||
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
|
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
|
||||||
logging.info(f"Opened ping stream to {info.peer_id}")
|
logging.info(f"Opened ping stream to {info.peer_id}")
|
||||||
|
|
||||||
await trio.sleep(0.5)
|
await trio.sleep(0.5)
|
||||||
|
|
||||||
await send_ping(stream)
|
await send_ping(stream)
|
||||||
|
|
||||||
logging.info("Ping completed successfully")
|
logging.info("Ping completed successfully")
|
||||||
|
|
||||||
logging.info("Keeping connection alive for 5 seconds...")
|
logging.info("Keeping connection alive for 5 seconds...")
|
||||||
await trio.sleep(5.0)
|
await trio.sleep(5.0)
|
||||||
except trio.TooSlowError:
|
except trio.TooSlowError:
|
||||||
@ -420,39 +486,22 @@ async def run(port: int, destination: str) -> None:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Connection failed to {info.peer_id}: {e}")
|
logging.error(f"Connection failed to {info.peer_id}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def info_from_p2p_addr(addr):
|
|
||||||
"""Extract peer info from multiaddr"""
|
|
||||||
from libp2p.peer.id import ID
|
|
||||||
|
|
||||||
# Parse the multiaddr to extract peer ID
|
|
||||||
protocols = addr.protocols()
|
|
||||||
peer_id_str = None
|
|
||||||
|
|
||||||
for proto in protocols:
|
|
||||||
if proto.code == 421: # p2p protocol
|
|
||||||
peer_id_str = proto.value
|
|
||||||
break
|
|
||||||
|
|
||||||
if not peer_id_str:
|
|
||||||
raise ValueError("No peer ID found in multiaddr")
|
|
||||||
|
|
||||||
class PeerInfo:
|
|
||||||
def __init__(self, peer_id, addrs):
|
|
||||||
self.peer_id = peer_id
|
|
||||||
self.addrs = addrs
|
|
||||||
|
|
||||||
return PeerInfo(ID.from_base58(peer_id_str), [addr])
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="libp2p ping with Rust interoperability")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("-p", "--port", default=8000, type=int, help="Port to listen on")
|
description="libp2p ping with Rust interoperability"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-p", "--port", default=8000, type=int, help="Port to listen on"
|
||||||
|
)
|
||||||
parser.add_argument("-d", "--destination", type=str, help="Destination multiaddr")
|
parser.add_argument("-d", "--destination", type=str, help="Destination multiaddr")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trio.run(run, args.port, args.destination)
|
trio.run(run, args.port, args.destination)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -461,5 +510,6 @@ def main():
|
|||||||
logging.error(f"Fatal error: {e}")
|
logging.error(f"Fatal error: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user