mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Merge pull request #372 from ralexstokes/add-py36-compatibility
Add py36 compatibility
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from .exceptions import RawConnError
|
||||
from .raw_connection_interface import IRawConnection
|
||||
@ -52,4 +53,6 @@ class RawConnection(IRawConnection):
|
||||
|
||||
async def close(self) -> None:
|
||||
self.writer.close()
|
||||
if sys.version_info < (3, 7):
|
||||
return
|
||||
await self.writer.wait_closed()
|
||||
|
||||
@ -149,7 +149,7 @@ class Pubsub:
|
||||
# Map of topic to topic validator
|
||||
self.topic_validators = {}
|
||||
|
||||
self.counter = time.time_ns()
|
||||
self.counter = int(time.time())
|
||||
|
||||
self._tasks = []
|
||||
# Call handle peer to keep waiting for updates to peer queue
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Dict, Tuple, cast
|
||||
|
||||
# NOTE: import ``asynccontextmanager`` from ``contextlib`` when support for python 3.6 is dropped.
|
||||
from async_generator import asynccontextmanager
|
||||
import factory
|
||||
|
||||
from libp2p import generate_new_rsa_identity, generate_peer_id_from
|
||||
@ -173,7 +174,7 @@ async def host_pair_factory(is_secure: bool) -> Tuple[BasicHost, BasicHost]:
|
||||
return hosts[0], hosts[1]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@asynccontextmanager # type: ignore
|
||||
async def pair_of_connected_hosts(
|
||||
is_secure: bool = True
|
||||
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
|
||||
|
||||
@ -143,6 +143,14 @@ floodsub_protocol_pytest_params = [
|
||||
]
|
||||
|
||||
|
||||
def _collect_node_ids(adj_list):
|
||||
node_ids = set()
|
||||
for node, neighbors in adj_list.items():
|
||||
node_ids.add(node)
|
||||
node_ids.update(set(neighbors))
|
||||
return node_ids
|
||||
|
||||
|
||||
async def perform_test_from_obj(obj, router_factory) -> None:
|
||||
"""
|
||||
Perform pubsub tests from a test object, which is composed as follows:
|
||||
@ -180,59 +188,43 @@ async def perform_test_from_obj(obj, router_factory) -> None:
|
||||
node_map = {}
|
||||
pubsub_map = {}
|
||||
|
||||
async def add_node(node_id_str: str) -> None:
|
||||
async def add_node(node_id_str: str):
|
||||
pubsub_router = router_factory(protocols=obj["supported_protocols"])
|
||||
pubsub = PubsubFactory(router=pubsub_router)
|
||||
await pubsub.host.get_network().listen(LISTEN_MADDR)
|
||||
node_map[node_id_str] = pubsub.host
|
||||
pubsub_map[node_id_str] = pubsub
|
||||
|
||||
tasks_connect = []
|
||||
for start_node_id in adj_list:
|
||||
# Create node if node does not yet exist
|
||||
if start_node_id not in node_map:
|
||||
await add_node(start_node_id)
|
||||
all_node_ids = _collect_node_ids(adj_list)
|
||||
|
||||
# For each neighbor of start_node, create if does not yet exist,
|
||||
# then connect start_node to neighbor
|
||||
for neighbor_id in adj_list[start_node_id]:
|
||||
# Create neighbor if neighbor does not yet exist
|
||||
if neighbor_id not in node_map:
|
||||
await add_node(neighbor_id)
|
||||
tasks_connect.append(
|
||||
connect(node_map[start_node_id], node_map[neighbor_id])
|
||||
)
|
||||
# Connect nodes and wait at least for 2 seconds
|
||||
await asyncio.gather(*tasks_connect, asyncio.sleep(2))
|
||||
for node in all_node_ids:
|
||||
await add_node(node)
|
||||
|
||||
for node, neighbors in adj_list.items():
|
||||
for neighbor_id in neighbors:
|
||||
await connect(node_map[node], node_map[neighbor_id])
|
||||
|
||||
# NOTE: the test using this routine will fail w/o these sleeps...
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Step 2) Subscribe to topics
|
||||
queues_map = {}
|
||||
topic_map = obj["topic_map"]
|
||||
|
||||
tasks_topic = []
|
||||
tasks_topic_data = []
|
||||
for topic, node_ids in topic_map.items():
|
||||
for node_id in node_ids:
|
||||
tasks_topic.append(pubsub_map[node_id].subscribe(topic))
|
||||
tasks_topic_data.append((node_id, topic))
|
||||
tasks_topic.append(asyncio.sleep(2))
|
||||
queue = await pubsub_map[node_id].subscribe(topic)
|
||||
if node_id not in queues_map:
|
||||
queues_map[node_id] = {}
|
||||
# Store queue in topic-queue map for node
|
||||
queues_map[node_id][topic] = queue
|
||||
|
||||
# Gather is like Promise.all
|
||||
responses = await asyncio.gather(*tasks_topic)
|
||||
for i in range(len(responses) - 1):
|
||||
node_id, topic = tasks_topic_data[i]
|
||||
if node_id not in queues_map:
|
||||
queues_map[node_id] = {}
|
||||
# Store queue in topic-queue map for node
|
||||
queues_map[node_id][topic] = responses[i]
|
||||
|
||||
# Allow time for subscribing before continuing
|
||||
await asyncio.sleep(0.01)
|
||||
# NOTE: the test using this routine will fail w/o these sleeps...
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Step 3) Publish messages
|
||||
topics_in_msgs_ordered = []
|
||||
messages = obj["messages"]
|
||||
tasks_publish = []
|
||||
|
||||
for msg in messages:
|
||||
topics = msg["topics"]
|
||||
@ -242,21 +234,17 @@ async def perform_test_from_obj(obj, router_factory) -> None:
|
||||
# Publish message
|
||||
# TODO: Should be single RPC package with several topics
|
||||
for topic in topics:
|
||||
tasks_publish.append(pubsub_map[node_id].publish(topic, data))
|
||||
|
||||
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
|
||||
for topic in topics:
|
||||
await pubsub_map[node_id].publish(topic, data)
|
||||
# For each topic in topics, add (topic, node_id, data) tuple to ordered test list
|
||||
topics_in_msgs_ordered.append((topic, node_id, data))
|
||||
|
||||
# Allow time for publishing before continuing
|
||||
await asyncio.gather(*tasks_publish, asyncio.sleep(2))
|
||||
|
||||
# Step 4) Check that all messages were received correctly.
|
||||
for topic, origin_node_id, data in topics_in_msgs_ordered:
|
||||
# Look at each node in each topic
|
||||
for node_id in topic_map[topic]:
|
||||
# Get message from subscription queue
|
||||
msg = await queues_map[node_id][topic].get()
|
||||
queue = queues_map[node_id][topic]
|
||||
msg = await queue.get()
|
||||
assert data == msg.data
|
||||
# Check the message origin
|
||||
assert node_map[origin_node_id].get_id().to_bytes() == msg.from_id
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from socket import socket
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from multiaddr import Multiaddr
|
||||
@ -53,8 +54,11 @@ class TCPListener(IListener):
|
||||
if self.server is None:
|
||||
return
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
server = self.server
|
||||
self.server = None
|
||||
if sys.version_info < (3, 7):
|
||||
return
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
class TCP(ITransport):
|
||||
|
||||
Reference in New Issue
Block a user