mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-08 06:00:53 +00:00
rewrite tcp reader/writer interface
This commit is contained in:
@ -1,42 +1,25 @@
|
||||
import asyncio
|
||||
import trio
|
||||
|
||||
from libp2p.io.exceptions import IOException
|
||||
from .exceptions import RawConnError
|
||||
from .raw_connection_interface import IRawConnection
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
|
||||
|
||||
class RawConnection(IRawConnection):
|
||||
reader: asyncio.StreamReader
|
||||
writer: asyncio.StreamWriter
|
||||
read_write_closer: ReadWriteCloser
|
||||
is_initiator: bool
|
||||
|
||||
_drain_lock: asyncio.Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
initiator: bool,
|
||||
) -> None:
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
def __init__(self, read_write_closer: ReadWriteCloser, initiator: bool) -> None:
|
||||
self.read_write_closer = read_write_closer
|
||||
self.is_initiator = initiator
|
||||
|
||||
self._drain_lock = asyncio.Lock()
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Raise `RawConnError` if the underlying connection breaks."""
|
||||
try:
|
||||
self.writer.write(data)
|
||||
except ConnectionResetError as error:
|
||||
await self.read_write_closer.write(data)
|
||||
except IOException as error:
|
||||
raise RawConnError(error)
|
||||
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
|
||||
# Use a lock to serialize drain() calls. Circumvents this bug:
|
||||
# https://bugs.python.org/issue29930
|
||||
async with self._drain_lock:
|
||||
try:
|
||||
await self.writer.drain()
|
||||
except ConnectionResetError as error:
|
||||
raise RawConnError(error)
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
"""
|
||||
@ -46,10 +29,9 @@ class RawConnection(IRawConnection):
|
||||
Raise `RawConnError` if the underlying connection breaks
|
||||
"""
|
||||
try:
|
||||
return await self.reader.read(n)
|
||||
except ConnectionResetError as error:
|
||||
return await self.read_write_closer.read(n)
|
||||
except IOException as error:
|
||||
raise RawConnError(error)
|
||||
|
||||
async def close(self) -> None:
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
await self.read_write_closer.close()
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
|
||||
from libp2p.io.abc import ReadWriteCloser
|
||||
from libp2p.network.connection.net_connection_interface import INetConn
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.peer.peerstore import PeerStoreError
|
||||
@ -149,7 +150,7 @@ class Swarm(INetwork):
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
|
||||
async def listen(self, *multiaddrs: Multiaddr) -> bool:
|
||||
async def listen(self, *multiaddrs: Multiaddr, nursery) -> bool:
|
||||
"""
|
||||
:param multiaddrs: one or many multiaddrs to start listening on
|
||||
:return: true if at least one success
|
||||
@ -167,15 +168,8 @@ class Swarm(INetwork):
|
||||
if str(maddr) in self.listeners:
|
||||
return True
|
||||
|
||||
async def conn_handler(
|
||||
reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
connection_info = writer.get_extra_info("peername")
|
||||
# TODO make a proper multiaddr
|
||||
peer_addr = f"/ip4/{connection_info[0]}/tcp/{connection_info[1]}"
|
||||
logger.debug("inbound connection at %s", peer_addr)
|
||||
# logger.debug("inbound connection request", peer_id)
|
||||
raw_conn = RawConnection(reader, writer, False)
|
||||
async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
|
||||
raw_conn = RawConnection(read_write_closer, False)
|
||||
|
||||
# Per, https://discuss.libp2p.io/t/multistream-security/130, we first secure
|
||||
# the conn and then mux the conn
|
||||
@ -185,14 +179,10 @@ class Swarm(INetwork):
|
||||
raw_conn, ID(b""), False
|
||||
)
|
||||
except SecurityUpgradeFailure as error:
|
||||
error_msg = "fail to upgrade security for peer at %s"
|
||||
logger.debug(error_msg, peer_addr)
|
||||
await raw_conn.close()
|
||||
raise SwarmException(error_msg % peer_addr) from error
|
||||
raise SwarmException() from error
|
||||
peer_id = secured_conn.get_remote_peer()
|
||||
|
||||
logger.debug("upgraded security for peer at %s", peer_addr)
|
||||
logger.debug("identified peer at %s as %s", peer_addr, peer_id)
|
||||
|
||||
try:
|
||||
muxed_conn = await self.upgrader.upgrade_connection(
|
||||
@ -213,7 +203,7 @@ class Swarm(INetwork):
|
||||
# Success
|
||||
listener = self.transport.create_listener(conn_handler)
|
||||
self.listeners[str(maddr)] = listener
|
||||
await listener.listen(maddr)
|
||||
await listener.listen(maddr, nursery)
|
||||
|
||||
# Call notifiers since event occurred
|
||||
self.notify_listen(maddr)
|
||||
|
||||
Reference in New Issue
Block a user