mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
Compare commits
1 Commits
py-rust-in
...
test/utils
| Author | SHA1 | Date | |
|---|---|---|---|
| 3035cdc56b |
2
Makefile
2
Makefile
@ -38,7 +38,7 @@ lint:
|
||||
)
|
||||
|
||||
test:
|
||||
python -m pytest tests -n auto
|
||||
python -m pytest tests
|
||||
|
||||
# protobufs management
|
||||
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
These commands are to be run in `./interop/exec`
|
||||
|
||||
## Redis
|
||||
|
||||
```bash
|
||||
docker run -p 6379:6379 -it redis:latest
|
||||
```
|
||||
|
||||
## Listener
|
||||
|
||||
```bash
|
||||
transport=tcp ip=0.0.0.0 is_dialer=false redis_addr=6379 test_timeout_seconds=180 security=insecure muxer=mplex python3 native_ping.py
|
||||
```
|
||||
|
||||
## Dialer
|
||||
|
||||
```bash
|
||||
transport=tcp ip=0.0.0.0 is_dialer=true port=8001 redis_addr=6379 port=8001 test_timeout_seconds=180 security=insecure muxer=mplex python3 native_ping.py
|
||||
```
|
||||
107
interop/arch.py
107
interop/arch.py
@ -1,107 +0,0 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
|
||||
import multiaddr
|
||||
import redis
|
||||
import trio
|
||||
|
||||
from libp2p import (
|
||||
new_host,
|
||||
)
|
||||
from libp2p.crypto.keys import (
|
||||
KeyPair,
|
||||
)
|
||||
from libp2p.crypto.rsa import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.crypto.x25519 import create_new_key_pair as create_new_x25519_key_pair
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.security.insecure.transport import (
|
||||
PLAINTEXT_PROTOCOL_ID,
|
||||
InsecureTransport,
|
||||
)
|
||||
from libp2p.security.noise.transport import PROTOCOL_ID as NOISE_PROTOCOL_ID
|
||||
from libp2p.security.noise.transport import Transport as NoiseTransport
|
||||
import libp2p.security.secio.transport as secio
|
||||
from libp2p.stream_muxer.mplex.mplex import (
|
||||
MPLEX_PROTOCOL_ID,
|
||||
Mplex,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import (
|
||||
Yamux,
|
||||
)
|
||||
from libp2p.stream_muxer.yamux.yamux import PROTOCOL_ID as YAMUX_PROTOCOL_ID
|
||||
|
||||
|
||||
def generate_new_rsa_identity() -> KeyPair:
|
||||
return create_new_key_pair()
|
||||
|
||||
|
||||
async def build_host(transport: str, ip: str, port: str, sec_protocol: str, muxer: str):
|
||||
match (sec_protocol, muxer):
|
||||
case ("insecure", "mplex"):
|
||||
key_pair = create_new_key_pair()
|
||||
host = new_host(
|
||||
key_pair,
|
||||
{TProtocol(MPLEX_PROTOCOL_ID): Mplex},
|
||||
{
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
|
||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
||||
},
|
||||
)
|
||||
muladdr = multiaddr.Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
||||
return (host, muladdr)
|
||||
case ("insecure", "yamux"):
|
||||
key_pair = create_new_key_pair()
|
||||
host = new_host(
|
||||
key_pair,
|
||||
{TProtocol(YAMUX_PROTOCOL_ID): Yamux},
|
||||
{
|
||||
TProtocol(PLAINTEXT_PROTOCOL_ID): InsecureTransport(key_pair),
|
||||
TProtocol(secio.ID): secio.Transport(key_pair),
|
||||
},
|
||||
)
|
||||
muladdr = multiaddr.Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
||||
return (host, muladdr)
|
||||
case ("noise", "yamux"):
|
||||
key_pair = create_new_key_pair()
|
||||
noise_key_pair = create_new_x25519_key_pair()
|
||||
|
||||
host = new_host(
|
||||
key_pair,
|
||||
{TProtocol(YAMUX_PROTOCOL_ID): Yamux},
|
||||
{
|
||||
NOISE_PROTOCOL_ID: NoiseTransport(
|
||||
key_pair, noise_privkey=noise_key_pair.private_key
|
||||
)
|
||||
},
|
||||
)
|
||||
muladdr = multiaddr.Multiaddr(f"/ip4/{ip}/tcp/{port}")
|
||||
return (host, muladdr)
|
||||
case _:
|
||||
raise ValueError("Protocols not supported")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedisClient:
|
||||
client: redis.Redis
|
||||
|
||||
def brpop(self, key: str, timeout: float) -> list[str]:
|
||||
result = self.client.brpop([key], timeout)
|
||||
return [result[1]] if result else []
|
||||
|
||||
def rpush(self, key: str, value: str) -> None:
|
||||
self.client.rpush(key, value)
|
||||
|
||||
|
||||
async def main():
|
||||
client = RedisClient(redis.Redis(host="localhost", port=6379, db=0))
|
||||
client.rpush("test", "hello")
|
||||
print(client.blpop("test", timeout=5))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
@ -1,57 +0,0 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import os
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
|
||||
def str_to_bool(val: str) -> bool:
|
||||
return val.lower() in ("true", "1")
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
"""Raised when the required environment variables are missing or invalid"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
transport: str
|
||||
sec_protocol: Optional[str]
|
||||
muxer: Optional[str]
|
||||
ip: str
|
||||
is_dialer: bool
|
||||
test_timeout: int
|
||||
redis_addr: str
|
||||
port: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "Config":
|
||||
try:
|
||||
transport = os.environ["transport"]
|
||||
ip = os.environ["ip"]
|
||||
except KeyError as e:
|
||||
raise ConfigError(f"{e.args[0]} env variable not set") from None
|
||||
|
||||
try:
|
||||
is_dialer = str_to_bool(os.environ.get("is_dialer", "true"))
|
||||
test_timeout = int(os.environ.get("test_timeout", "180"))
|
||||
except ValueError as e:
|
||||
raise ConfigError(f"Invalid value in env: {e}") from None
|
||||
|
||||
redis_addr = os.environ.get("redis_addr", 6379)
|
||||
sec_protocol = os.environ.get("security")
|
||||
muxer = os.environ.get("muxer")
|
||||
port = os.environ.get("port", "8000")
|
||||
|
||||
return cls(
|
||||
transport=transport,
|
||||
sec_protocol=sec_protocol,
|
||||
muxer=muxer,
|
||||
ip=ip,
|
||||
is_dialer=is_dialer,
|
||||
test_timeout=test_timeout,
|
||||
redis_addr=redis_addr,
|
||||
port=port,
|
||||
)
|
||||
@ -1,33 +0,0 @@
|
||||
import trio
|
||||
|
||||
from interop.exec.config.mod import (
|
||||
Config,
|
||||
ConfigError,
|
||||
)
|
||||
from interop.lib import (
|
||||
run_test,
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
try:
|
||||
config = Config.from_env()
|
||||
except ConfigError as e:
|
||||
print(f"Config error: {e}")
|
||||
return
|
||||
|
||||
# Uncomment and implement when ready
|
||||
_ = await run_test(
|
||||
config.transport,
|
||||
config.ip,
|
||||
config.port,
|
||||
config.is_dialer,
|
||||
config.test_timeout,
|
||||
config.redis_addr,
|
||||
config.sec_protocol,
|
||||
config.muxer,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
120
interop/lib.py
120
interop/lib.py
@ -1,120 +0,0 @@
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
)
|
||||
import json
|
||||
import time
|
||||
|
||||
from loguru import (
|
||||
logger,
|
||||
)
|
||||
import multiaddr
|
||||
import redis
|
||||
import trio
|
||||
|
||||
from interop.arch import (
|
||||
RedisClient,
|
||||
build_host,
|
||||
)
|
||||
from libp2p.custom_types import (
|
||||
TProtocol,
|
||||
)
|
||||
from libp2p.network.stream.net_stream import (
|
||||
INetStream,
|
||||
)
|
||||
from libp2p.peer.peerinfo import (
|
||||
info_from_p2p_addr,
|
||||
)
|
||||
|
||||
PING_PROTOCOL_ID = TProtocol("/ipfs/ping/1.0.0")
|
||||
PING_LENGTH = 32
|
||||
RESP_TIMEOUT = 60
|
||||
|
||||
|
||||
async def handle_ping(stream: INetStream) -> None:
|
||||
while True:
|
||||
try:
|
||||
payload = await stream.read(PING_LENGTH)
|
||||
peer_id = stream.muxed_conn.peer_id
|
||||
if payload is not None:
|
||||
print(f"received ping from {peer_id}")
|
||||
|
||||
await stream.write(payload)
|
||||
print(f"responded with pong to {peer_id}")
|
||||
|
||||
except Exception:
|
||||
await stream.reset()
|
||||
break
|
||||
|
||||
|
||||
async def send_ping(stream: INetStream) -> None:
|
||||
try:
|
||||
payload = b"\x01" * PING_LENGTH
|
||||
print(f"sending ping to {stream.muxed_conn.peer_id}")
|
||||
|
||||
await stream.write(payload)
|
||||
|
||||
with trio.fail_after(RESP_TIMEOUT):
|
||||
response = await stream.read(PING_LENGTH)
|
||||
|
||||
if response == payload:
|
||||
print(f"received pong from {stream.muxed_conn.peer_id}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"error occurred: {e}")
|
||||
|
||||
|
||||
async def run_test(
|
||||
transport, ip, port, is_dialer, test_timeout, redis_addr, sec_protocol, muxer
|
||||
):
|
||||
logger.info("Starting run_test")
|
||||
|
||||
redis_client = RedisClient(
|
||||
redis.Redis(host="localhost", port=int(redis_addr), db=0)
|
||||
)
|
||||
(host, listen_addr) = await build_host(transport, ip, port, sec_protocol, muxer)
|
||||
logger.info(f"Running ping test local_peer={host.get_id()}")
|
||||
|
||||
async with host.run(listen_addrs=[listen_addr]), trio.open_nursery() as nursery:
|
||||
if not is_dialer:
|
||||
host.set_stream_handler(PING_PROTOCOL_ID, handle_ping)
|
||||
ma = f"{listen_addr}/p2p/{host.get_id().pretty()}"
|
||||
redis_client.rpush("listenerAddr", ma)
|
||||
|
||||
logger.info(f"Test instance, listening: {ma}")
|
||||
else:
|
||||
redis_addr = redis_client.brpop("listenerAddr", timeout=5)
|
||||
destination = redis_addr[0].decode()
|
||||
maddr = multiaddr.Multiaddr(destination)
|
||||
info = info_from_p2p_addr(maddr)
|
||||
|
||||
handshake_start = time.perf_counter()
|
||||
|
||||
logger.info("GETTING READY FOR CONNECTION")
|
||||
await host.connect(info)
|
||||
logger.info("HOST CONNECTED")
|
||||
|
||||
# TILL HERE EVERYTHING IS FINE
|
||||
|
||||
stream = await host.new_stream(info.peer_id, [PING_PROTOCOL_ID])
|
||||
logger.info("CREATED NEW STREAM")
|
||||
|
||||
# DOES NOT MORE FORWARD FROM THIS
|
||||
logger.info("Remote conection established")
|
||||
|
||||
nursery.start_soon(send_ping, stream)
|
||||
|
||||
handshake_plus_ping = (time.perf_counter() - handshake_start) * 1000.0
|
||||
|
||||
logger.info(f"handshake time: {handshake_plus_ping:.2f}ms")
|
||||
return
|
||||
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Report:
|
||||
handshake_plus_one_rtt_millis: float
|
||||
ping_rtt_millis: float
|
||||
|
||||
def gen_report(self):
|
||||
return json.dumps(self.__dict__)
|
||||
@ -5,6 +5,7 @@ from collections.abc import (
|
||||
from contextlib import (
|
||||
asynccontextmanager,
|
||||
)
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Optional,
|
||||
@ -67,10 +68,7 @@ if TYPE_CHECKING:
|
||||
# telling it to listen on the given listen addresses.
|
||||
|
||||
|
||||
# logger = logging.getLogger("libp2p.network.basic_host")
|
||||
from loguru import (
|
||||
logger,
|
||||
)
|
||||
logger = logging.getLogger("libp2p.network.basic_host")
|
||||
|
||||
|
||||
class BasicHost(IHost):
|
||||
@ -183,15 +181,12 @@ class BasicHost(IHost):
|
||||
:return: stream: new stream created
|
||||
"""
|
||||
net_stream = await self._network.new_stream(peer_id)
|
||||
logger.info("INETSTREAM CHECKING IN")
|
||||
logger.info(protocol_ids)
|
||||
|
||||
# Perform protocol muxing to determine protocol to use
|
||||
try:
|
||||
logger.debug("PROTOCOLS TRYING TO GET SENT")
|
||||
selected_protocol = await self.multiselect_client.select_one_of(
|
||||
list(protocol_ids), MultiselectCommunicator(net_stream)
|
||||
)
|
||||
logger.info("PROTOCOLS GOT SENT")
|
||||
except MultiselectClientError as error:
|
||||
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
|
||||
await net_stream.reset()
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import logging
|
||||
from typing import (
|
||||
Optional,
|
||||
)
|
||||
|
||||
# logger = logging.getLogger("libp2p.network.swarm")
|
||||
from loguru import (
|
||||
logger,
|
||||
)
|
||||
from multiaddr import (
|
||||
Multiaddr,
|
||||
)
|
||||
@ -58,6 +55,8 @@ from .exceptions import (
|
||||
SwarmException,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("libp2p.network.swarm")
|
||||
|
||||
|
||||
def create_default_stream_handler(network: INetworkService) -> StreamHandlerFn:
|
||||
async def stream_handler(stream: INetStream) -> None:
|
||||
@ -131,7 +130,6 @@ class Swarm(Service, INetworkService):
|
||||
:return: muxed connection
|
||||
"""
|
||||
if peer_id in self.connections:
|
||||
logger.info("WE ARE RETURNING, PEER ALREADAY EXISTS")
|
||||
# If muxed connection already exists for peer_id,
|
||||
# set muxed connection equal to existing muxed connection
|
||||
return self.connections[peer_id]
|
||||
@ -152,7 +150,6 @@ class Swarm(Service, INetworkService):
|
||||
# Try all known addresses
|
||||
for multiaddr in addrs:
|
||||
try:
|
||||
logger.info("HANDSHAKE GOING TO HAPPEN")
|
||||
return await self.dial_addr(multiaddr, peer_id)
|
||||
except SwarmException as e:
|
||||
exceptions.append(e)
|
||||
@ -227,11 +224,8 @@ class Swarm(Service, INetworkService):
|
||||
logger.debug("attempting to open a stream to peer %s", peer_id)
|
||||
|
||||
swarm_conn = await self.dial_peer(peer_id)
|
||||
logger.info("INETCONN CREATED")
|
||||
|
||||
net_stream = await swarm_conn.new_stream()
|
||||
logger.info("INETSTREAM CREATED")
|
||||
|
||||
logger.debug("successfully opened a stream to peer %s", peer_id)
|
||||
return net_stream
|
||||
|
||||
|
||||
@ -2,10 +2,6 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
|
||||
from loguru import (
|
||||
logger,
|
||||
)
|
||||
|
||||
from libp2p.abc import (
|
||||
IMultiselectClient,
|
||||
IMultiselectCommunicator,
|
||||
@ -40,15 +36,11 @@ class MultiselectClient(IMultiselectClient):
|
||||
try:
|
||||
await communicator.write(MULTISELECT_PROTOCOL_ID)
|
||||
except MultiselectCommunicatorError as error:
|
||||
logger.error("WROTE FAIL")
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
logger.info(f"WROTE SUC, {MULTISELECT_PROTOCOL_ID}")
|
||||
try:
|
||||
handshake_contents = await communicator.read()
|
||||
logger.info(f"READ SUC, {handshake_contents}")
|
||||
except MultiselectCommunicatorError as error:
|
||||
logger.error(f"READ FAIL, {error}")
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
if not is_valid_handshake(handshake_contents):
|
||||
@ -67,12 +59,9 @@ class MultiselectClient(IMultiselectClient):
|
||||
:return: selected protocol
|
||||
:raise MultiselectClientError: raised when protocol negotiation failed
|
||||
"""
|
||||
logger.info("TRYING TO GET THE HANDSHAKE HAPPENED")
|
||||
await self.handshake(communicator)
|
||||
logger.info("HANDSHAKE HAPPENED")
|
||||
|
||||
for protocol in protocols:
|
||||
logger.info(protocol)
|
||||
try:
|
||||
selected_protocol = await self.try_select(communicator, protocol)
|
||||
return selected_protocol
|
||||
@ -124,17 +113,11 @@ class MultiselectClient(IMultiselectClient):
|
||||
"""
|
||||
try:
|
||||
await communicator.write(protocol)
|
||||
from loguru import (
|
||||
logger,
|
||||
)
|
||||
|
||||
logger.info(protocol)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
try:
|
||||
response = await communicator.read()
|
||||
logger.info("Response: ", response)
|
||||
except MultiselectCommunicatorError as error:
|
||||
raise MultiselectClientError() from error
|
||||
|
||||
|
||||
@ -122,9 +122,6 @@ class Pubsub(Service, IPubsub):
|
||||
strict_signing: bool
|
||||
sign_key: PrivateKey
|
||||
|
||||
# Set of blacklisted peer IDs
|
||||
blacklisted_peers: set[ID]
|
||||
|
||||
event_handle_peer_queue_started: trio.Event
|
||||
event_handle_dead_peer_queue_started: trio.Event
|
||||
|
||||
@ -204,9 +201,6 @@ class Pubsub(Service, IPubsub):
|
||||
|
||||
self.counter = int(time.time())
|
||||
|
||||
# Set of blacklisted peer IDs
|
||||
self.blacklisted_peers = set()
|
||||
|
||||
self.event_handle_peer_queue_started = trio.Event()
|
||||
self.event_handle_dead_peer_queue_started = trio.Event()
|
||||
|
||||
@ -326,82 +320,6 @@ class Pubsub(Service, IPubsub):
|
||||
if topic in self.topic_validators
|
||||
)
|
||||
|
||||
def add_to_blacklist(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Add a peer to the blacklist.
|
||||
When a peer is blacklisted:
|
||||
- Any existing connection to that peer is immediately closed and removed
|
||||
- The peer is removed from all topic subscription mappings
|
||||
- Future connection attempts from this peer will be rejected
|
||||
- Messages forwarded by or originating from this peer will be dropped
|
||||
- The peer will not be able to participate in pubsub communication
|
||||
|
||||
:param peer_id: the peer ID to blacklist
|
||||
"""
|
||||
self.blacklisted_peers.add(peer_id)
|
||||
logger.debug("Added peer %s to blacklist", peer_id)
|
||||
self.manager.run_task(self._teardown_if_connected, peer_id)
|
||||
|
||||
async def _teardown_if_connected(self, peer_id: ID) -> None:
|
||||
"""Close their stream and remove them if connected"""
|
||||
stream = self.peers.get(peer_id)
|
||||
if stream is not None:
|
||||
try:
|
||||
await stream.reset()
|
||||
except Exception:
|
||||
pass
|
||||
del self.peers[peer_id]
|
||||
# Also remove from any subscription maps:
|
||||
for _topic, peerset in self.peer_topics.items():
|
||||
if peer_id in peerset:
|
||||
peerset.discard(peer_id)
|
||||
|
||||
def remove_from_blacklist(self, peer_id: ID) -> None:
|
||||
"""
|
||||
Remove a peer from the blacklist.
|
||||
Once removed from the blacklist:
|
||||
- The peer can establish new connections to this node
|
||||
- Messages from this peer will be processed normally
|
||||
- The peer can participate in topic subscriptions and message forwarding
|
||||
|
||||
:param peer_id: the peer ID to remove from blacklist
|
||||
"""
|
||||
self.blacklisted_peers.discard(peer_id)
|
||||
logger.debug("Removed peer %s from blacklist", peer_id)
|
||||
|
||||
def is_peer_blacklisted(self, peer_id: ID) -> bool:
|
||||
"""
|
||||
Check if a peer is blacklisted.
|
||||
|
||||
:param peer_id: the peer ID to check
|
||||
:return: True if peer is blacklisted, False otherwise
|
||||
"""
|
||||
return peer_id in self.blacklisted_peers
|
||||
|
||||
def clear_blacklist(self) -> None:
|
||||
"""
|
||||
Clear all peers from the blacklist.
|
||||
This removes all blacklist restrictions, allowing previously blacklisted
|
||||
peers to:
|
||||
- Establish new connections
|
||||
- Send and forward messages
|
||||
- Participate in topic subscriptions
|
||||
|
||||
"""
|
||||
self.blacklisted_peers.clear()
|
||||
logger.debug("Cleared all peers from blacklist")
|
||||
|
||||
def get_blacklisted_peers(self) -> set[ID]:
|
||||
"""
|
||||
Get a copy of the current blacklisted peers.
|
||||
Returns a snapshot of all currently blacklisted peer IDs. These peers
|
||||
are completely isolated from pubsub communication - their connections
|
||||
are rejected and their messages are dropped.
|
||||
|
||||
:return: a set containing all blacklisted peer IDs
|
||||
"""
|
||||
return self.blacklisted_peers.copy()
|
||||
|
||||
async def stream_handler(self, stream: INetStream) -> None:
|
||||
"""
|
||||
Stream handler for pubsub. Gets invoked whenever a new stream is
|
||||
@ -428,10 +346,6 @@ class Pubsub(Service, IPubsub):
|
||||
await self.event_handle_dead_peer_queue_started.wait()
|
||||
|
||||
async def _handle_new_peer(self, peer_id: ID) -> None:
|
||||
if self.is_peer_blacklisted(peer_id):
|
||||
logger.debug("Rejecting blacklisted peer %s", peer_id)
|
||||
return
|
||||
|
||||
try:
|
||||
stream: INetStream = await self.host.new_stream(peer_id, self.protocols)
|
||||
except SwarmException as error:
|
||||
@ -445,6 +359,7 @@ class Pubsub(Service, IPubsub):
|
||||
except StreamClosed:
|
||||
logger.debug("Fail to add new peer %s: stream closed", peer_id)
|
||||
return
|
||||
# TODO: Check if the peer in black list.
|
||||
try:
|
||||
self.router.add_peer(peer_id, stream.get_protocol())
|
||||
except Exception as error:
|
||||
@ -694,20 +609,9 @@ class Pubsub(Service, IPubsub):
|
||||
"""
|
||||
logger.debug("attempting to publish message %s", msg)
|
||||
|
||||
# Check if the message forwarder (source) is in the blacklist. If yes, reject.
|
||||
if self.is_peer_blacklisted(msg_forwarder):
|
||||
logger.debug(
|
||||
"Rejecting message from blacklisted source peer %s", msg_forwarder
|
||||
)
|
||||
return
|
||||
# TODO: Check if the `source` is in the blacklist. If yes, reject.
|
||||
|
||||
# Check if the message originator (from) is in the blacklist. If yes, reject.
|
||||
msg_from_peer = ID(msg.from_id)
|
||||
if self.is_peer_blacklisted(msg_from_peer):
|
||||
logger.debug(
|
||||
"Rejecting message from blacklisted originator peer %s", msg_from_peer
|
||||
)
|
||||
return
|
||||
# TODO: Check if the `from` is in the blacklist. If yes, reject.
|
||||
|
||||
# If the message is processed before, return(i.e., don't further process the message) # noqa: E501
|
||||
if self._is_msg_seen(msg):
|
||||
|
||||
@ -1 +0,0 @@
|
||||
implement blacklist management for `pubsub.Pubsub` with methods to get, add, remove, check, and clear blacklisted peer IDs.
|
||||
7
setup.py
7
setup.py
@ -37,14 +37,10 @@ extras_require = {
|
||||
"pytest-trio>=0.5.2",
|
||||
"factory-boy>=2.12.0,<3.0.0",
|
||||
],
|
||||
"interop": ["redis==6.1.0", "logging==0.4.9.6", "loguru==0.7.3"],
|
||||
}
|
||||
|
||||
extras_require["dev"] = (
|
||||
extras_require["dev"]
|
||||
+ extras_require["docs"]
|
||||
+ extras_require["test"]
|
||||
+ extras_require["interop"]
|
||||
extras_require["dev"] + extras_require["docs"] + extras_require["test"]
|
||||
)
|
||||
|
||||
try:
|
||||
@ -69,7 +65,6 @@ install_requires = [
|
||||
"rpcudp>=3.0.0",
|
||||
"trio-typing>=0.0.4",
|
||||
"trio>=0.26.0",
|
||||
"loguru>=0.7.3",
|
||||
]
|
||||
|
||||
# Add platform-specific dependencies
|
||||
|
||||
@ -702,369 +702,3 @@ async def test_strict_signing_failed_validation(monkeypatch):
|
||||
await pubsubs_fsub[0].push_msg(pubsubs_fsub[0].my_id, msg)
|
||||
await trio.sleep(0.01)
|
||||
assert event.is_set()
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_basic_operations():
|
||||
"""Test basic blacklist operations: add, remove, check, clear."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
|
||||
# Create test peer IDs
|
||||
peer1 = IDFactory()
|
||||
peer2 = IDFactory()
|
||||
peer3 = IDFactory()
|
||||
|
||||
# Initially no peers should be blacklisted
|
||||
assert len(pubsub.get_blacklisted_peers()) == 0
|
||||
assert not pubsub.is_peer_blacklisted(peer1)
|
||||
assert not pubsub.is_peer_blacklisted(peer2)
|
||||
assert not pubsub.is_peer_blacklisted(peer3)
|
||||
|
||||
# Add peers to blacklist
|
||||
pubsub.add_to_blacklist(peer1)
|
||||
pubsub.add_to_blacklist(peer2)
|
||||
|
||||
# Check blacklist state
|
||||
assert len(pubsub.get_blacklisted_peers()) == 2
|
||||
assert pubsub.is_peer_blacklisted(peer1)
|
||||
assert pubsub.is_peer_blacklisted(peer2)
|
||||
assert not pubsub.is_peer_blacklisted(peer3)
|
||||
|
||||
# Remove one peer from blacklist
|
||||
pubsub.remove_from_blacklist(peer1)
|
||||
|
||||
# Check state after removal
|
||||
assert len(pubsub.get_blacklisted_peers()) == 1
|
||||
assert not pubsub.is_peer_blacklisted(peer1)
|
||||
assert pubsub.is_peer_blacklisted(peer2)
|
||||
assert not pubsub.is_peer_blacklisted(peer3)
|
||||
|
||||
# Add peer3 and then clear all
|
||||
pubsub.add_to_blacklist(peer3)
|
||||
assert len(pubsub.get_blacklisted_peers()) == 2
|
||||
|
||||
pubsub.clear_blacklist()
|
||||
assert len(pubsub.get_blacklisted_peers()) == 0
|
||||
assert not pubsub.is_peer_blacklisted(peer1)
|
||||
assert not pubsub.is_peer_blacklisted(peer2)
|
||||
assert not pubsub.is_peer_blacklisted(peer3)
|
||||
|
||||
# Test duplicate additions (should not increase size)
|
||||
pubsub.add_to_blacklist(peer1)
|
||||
pubsub.add_to_blacklist(peer1)
|
||||
assert len(pubsub.get_blacklisted_peers()) == 1
|
||||
|
||||
# Test removing non-blacklisted peer (should not cause errors)
|
||||
pubsub.remove_from_blacklist(peer2)
|
||||
assert len(pubsub.get_blacklisted_peers()) == 1
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_blocks_new_peer_connections(monkeypatch):
|
||||
"""Test that blacklisted peers are rejected when trying to connect."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(1) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
|
||||
# Create a blacklisted peer ID
|
||||
blacklisted_peer = IDFactory()
|
||||
|
||||
# Add peer to blacklist
|
||||
pubsub.add_to_blacklist(blacklisted_peer)
|
||||
|
||||
new_stream_called = False
|
||||
|
||||
async def mock_new_stream(*args, **kwargs):
|
||||
nonlocal new_stream_called
|
||||
new_stream_called = True
|
||||
# Create a mock stream
|
||||
from unittest.mock import (
|
||||
AsyncMock,
|
||||
Mock,
|
||||
)
|
||||
|
||||
mock_stream = Mock()
|
||||
mock_stream.write = AsyncMock()
|
||||
mock_stream.reset = AsyncMock()
|
||||
mock_stream.get_protocol = Mock(return_value="test_protocol")
|
||||
return mock_stream
|
||||
|
||||
router_add_peer_called = False
|
||||
|
||||
def mock_add_peer(*args, **kwargs):
|
||||
nonlocal router_add_peer_called
|
||||
router_add_peer_called = True
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(pubsub.host, "new_stream", mock_new_stream)
|
||||
m.setattr(pubsub.router, "add_peer", mock_add_peer)
|
||||
|
||||
# Attempt to handle the blacklisted peer
|
||||
await pubsub._handle_new_peer(blacklisted_peer)
|
||||
|
||||
# Verify that both new_stream and router.add_peer was not called
|
||||
assert (
|
||||
not new_stream_called
|
||||
), "new_stream should be not be called to get hello packet"
|
||||
assert (
|
||||
not router_add_peer_called
|
||||
), "Router.add_peer should not be called for blacklisted peer"
|
||||
assert (
|
||||
blacklisted_peer not in pubsub.peers
|
||||
), "Blacklisted peer should not be in peers dict"
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_blocks_messages_from_blacklisted_originator():
|
||||
"""Test that messages from blacklisted originator (from field) are rejected."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
blacklisted_originator = pubsubs_fsub[1].my_id # Use existing peer ID
|
||||
|
||||
# Add the originator to blacklist
|
||||
pubsub.add_to_blacklist(blacklisted_originator)
|
||||
|
||||
# Create a message with blacklisted originator
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=blacklisted_originator,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
# Subscribe to the topic
|
||||
await pubsub.subscribe(TESTING_TOPIC)
|
||||
|
||||
# Track if router.publish is called
|
||||
router_publish_called = False
|
||||
|
||||
async def mock_router_publish(*args, **kwargs):
|
||||
nonlocal router_publish_called
|
||||
router_publish_called = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
original_router_publish = pubsub.router.publish
|
||||
pubsub.router.publish = mock_router_publish
|
||||
|
||||
try:
|
||||
# Attempt to push message from blacklisted originator
|
||||
await pubsub.push_msg(blacklisted_originator, msg)
|
||||
|
||||
# Verify message was rejected
|
||||
assert (
|
||||
not router_publish_called
|
||||
), "Router.publish should not be called for blacklisted originator"
|
||||
assert not pubsub._is_msg_seen(
|
||||
msg
|
||||
), "Message from blacklisted originator should not be marked as seen"
|
||||
|
||||
finally:
|
||||
pubsub.router.publish = original_router_publish
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_allows_non_blacklisted_peers():
|
||||
"""Test that non-blacklisted peers can send messages normally."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(3) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
allowed_peer = pubsubs_fsub[1].my_id
|
||||
blacklisted_peer = pubsubs_fsub[2].my_id
|
||||
|
||||
# Blacklist one peer but not the other
|
||||
pubsub.add_to_blacklist(blacklisted_peer)
|
||||
|
||||
# Create messages from both peers
|
||||
msg_from_allowed = make_pubsub_msg(
|
||||
origin_id=allowed_peer,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=b"allowed_data",
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
msg_from_blacklisted = make_pubsub_msg(
|
||||
origin_id=blacklisted_peer,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=b"blacklisted_data",
|
||||
seqno=b"\x11" * 8,
|
||||
)
|
||||
|
||||
# Subscribe to the topic
|
||||
sub = await pubsub.subscribe(TESTING_TOPIC)
|
||||
|
||||
# Track router.publish calls
|
||||
router_publish_calls = []
|
||||
|
||||
async def mock_router_publish(*args, **kwargs):
|
||||
router_publish_calls.append(args)
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
original_router_publish = pubsub.router.publish
|
||||
pubsub.router.publish = mock_router_publish
|
||||
|
||||
try:
|
||||
# Send message from allowed peer (should succeed)
|
||||
await pubsub.push_msg(allowed_peer, msg_from_allowed)
|
||||
|
||||
# Send message from blacklisted peer (should be rejected)
|
||||
await pubsub.push_msg(allowed_peer, msg_from_blacklisted)
|
||||
|
||||
# Verify only allowed message was processed
|
||||
assert (
|
||||
len(router_publish_calls) == 1
|
||||
), "Only one message should be processed"
|
||||
assert pubsub._is_msg_seen(
|
||||
msg_from_allowed
|
||||
), "Allowed message should be marked as seen"
|
||||
assert not pubsub._is_msg_seen(
|
||||
msg_from_blacklisted
|
||||
), "Blacklisted message should not be marked as seen"
|
||||
|
||||
# Verify subscription received the allowed message
|
||||
received_msg = await sub.get()
|
||||
assert received_msg.data == b"allowed_data"
|
||||
|
||||
finally:
|
||||
pubsub.router.publish = original_router_publish
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_integration_with_existing_functionality():
|
||||
"""Test that blacklisting works correctly with existing pubsub functionality."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
other_peer = pubsubs_fsub[1].my_id
|
||||
|
||||
# Test that seen messages cache still works with blacklisting
|
||||
pubsub.add_to_blacklist(other_peer)
|
||||
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=other_peer,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
# First attempt - should be rejected due to blacklist
|
||||
await pubsub.push_msg(other_peer, msg)
|
||||
assert not pubsub._is_msg_seen(msg)
|
||||
|
||||
# Remove from blacklist
|
||||
pubsub.remove_from_blacklist(other_peer)
|
||||
|
||||
# Now the message should be processed
|
||||
await pubsub.subscribe(TESTING_TOPIC)
|
||||
await pubsub.push_msg(other_peer, msg)
|
||||
assert pubsub._is_msg_seen(msg)
|
||||
|
||||
# If we try to send the same message again, it should be rejected
|
||||
# due to seen cache (not blacklist)
|
||||
router_publish_called = False
|
||||
|
||||
async def mock_router_publish(*args, **kwargs):
|
||||
nonlocal router_publish_called
|
||||
router_publish_called = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
original_router_publish = pubsub.router.publish
|
||||
pubsub.router.publish = mock_router_publish
|
||||
|
||||
try:
|
||||
await pubsub.push_msg(other_peer, msg)
|
||||
assert (
|
||||
not router_publish_called
|
||||
), "Duplicate message should be rejected by seen cache"
|
||||
finally:
|
||||
pubsub.router.publish = original_router_publish
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_blocks_messages_from_blacklisted_source():
|
||||
"""Test that messages from blacklisted source (forwarder) are rejected."""
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
pubsub = pubsubs_fsub[0]
|
||||
blacklisted_forwarder = pubsubs_fsub[1].my_id
|
||||
|
||||
# Add the forwarder to blacklist
|
||||
pubsub.add_to_blacklist(blacklisted_forwarder)
|
||||
|
||||
# Create a message
|
||||
msg = make_pubsub_msg(
|
||||
origin_id=pubsubs_fsub[1].my_id,
|
||||
topic_ids=[TESTING_TOPIC],
|
||||
data=TESTING_DATA,
|
||||
seqno=b"\x00" * 8,
|
||||
)
|
||||
|
||||
# Subscribe to the topic so we can check if message is processed
|
||||
await pubsub.subscribe(TESTING_TOPIC)
|
||||
|
||||
# Track if router.publish is called (it shouldn't be for blacklisted forwarder)
|
||||
router_publish_called = False
|
||||
|
||||
async def mock_router_publish(*args, **kwargs):
|
||||
nonlocal router_publish_called
|
||||
router_publish_called = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
original_router_publish = pubsub.router.publish
|
||||
pubsub.router.publish = mock_router_publish
|
||||
|
||||
try:
|
||||
# Attempt to push message from blacklisted forwarder
|
||||
await pubsub.push_msg(blacklisted_forwarder, msg)
|
||||
|
||||
# Verify message was rejected
|
||||
assert (
|
||||
not router_publish_called
|
||||
), "Router.publish should not be called for blacklisted forwarder"
|
||||
assert not pubsub._is_msg_seen(
|
||||
msg
|
||||
), "Message from blacklisted forwarder should not be marked as seen"
|
||||
|
||||
finally:
|
||||
pubsub.router.publish = original_router_publish
|
||||
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_blacklist_tears_down_existing_connection():
|
||||
"""
|
||||
Verify that if a peer is already in pubsub.peers and pubsub.peer_topics,
|
||||
calling add_to_blacklist(peer_id) immediately resets its stream and
|
||||
removes it from both places.
|
||||
"""
|
||||
# Create two pubsub instances (floodsub), so they can connect to each other
|
||||
async with PubsubFactory.create_batch_with_floodsub(2) as pubsubs_fsub:
|
||||
pubsub0, pubsub1 = pubsubs_fsub
|
||||
|
||||
# 1) Connect peer1 to peer0
|
||||
await connect(pubsub0.host, pubsub1.host)
|
||||
# Give handle_peer_queue some time to run
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# After connect, pubsub0.peers should contain pubsub1.my_id
|
||||
assert pubsub1.my_id in pubsub0.peers
|
||||
|
||||
# 2) Manually record a subscription from peer1 under TESTING_TOPIC,
|
||||
# so that peer1 shows up in pubsub0.peer_topics[TESTING_TOPIC].
|
||||
sub_msg = rpc_pb2.RPC.SubOpts(subscribe=True, topicid=TESTING_TOPIC)
|
||||
pubsub0.handle_subscription(pubsub1.my_id, sub_msg)
|
||||
|
||||
assert TESTING_TOPIC in pubsub0.peer_topics
|
||||
assert pubsub1.my_id in pubsub0.peer_topics[TESTING_TOPIC]
|
||||
|
||||
# 3) Now blacklist peer1
|
||||
pubsub0.add_to_blacklist(pubsub1.my_id)
|
||||
|
||||
# Allow the asynchronous teardown task (_teardown_if_connected) to run
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# 4a) pubsub0.peers should no longer contain peer1
|
||||
assert pubsub1.my_id not in pubsub0.peers
|
||||
|
||||
# 4b) pubsub0.peer_topics[TESTING_TOPIC] should no longer contain peer1
|
||||
# (or TESTING_TOPIC may have been removed entirely if no other peers remain)
|
||||
if TESTING_TOPIC in pubsub0.peer_topics:
|
||||
assert pubsub1.my_id not in pubsub0.peer_topics[TESTING_TOPIC]
|
||||
else:
|
||||
# It’s also fine if the entire topic entry was pruned
|
||||
assert TESTING_TOPIC not in pubsub0.peer_topics
|
||||
|
||||
Reference in New Issue
Block a user