feat: add agent version to identify protocol and improved tests

This commit is contained in:
acul71
2025-02-19 03:43:50 +01:00
committed by Paul Robinson
parent dc903460dc
commit bd8bd953ec
3 changed files with 56 additions and 3 deletions

View File

@ -15,16 +15,22 @@ from libp2p.custom_types import (
from libp2p.network.stream.exceptions import (
StreamClosed,
)
from libp2p.network.stream.net_stream_interface import (
INetStream,
)
from libp2p.utils import (
get_agent_version,
)
from .pb.identify_pb2 import (
Identify,
)
logger = logging.getLogger("libp2p.identity.identify")
ID = TProtocol("/ipfs/id/1.0.0")
PROTOCOL_VERSION = "ipfs/0.1.0"
# TODO dynamically generate the agent version
AGENT_VERSION = "py-libp2p/alpha"
logger = logging.getLogger("libp2p.identity.identify")
AGENT_VERSION = get_agent_version()
def _multiaddr_to_bytes(maddr: Multiaddr) -> bytes:

View File

@ -1,4 +1,8 @@
from importlib.metadata import (
version,
)
import itertools
import logging
import math
from libp2p.exceptions import (
@ -12,6 +16,8 @@ from .io.utils import (
read_exactly,
)
logger = logging.getLogger("libp2p.utils")
# Unsigned LEB128(varint codec)
# Reference: https://github.com/ethereum/py-wasm/blob/master/wasm/parsers/leb128.py
@ -84,3 +90,19 @@ async def read_delim(reader: Reader) -> bytes:
f'`msg_bytes` is not delimited by b"\\n": `msg_bytes`={msg_bytes!r}'
)
return msg_bytes[:-1]
def get_agent_version() -> str:
"""
Return the version of libp2p.
If the version cannot be determined due to an exception, return "py-libp2p/unknown".
:return: The version of libp2p.
:rtype: str
"""
try:
return f"py-libp2p/{version('libp2p')}"
except Exception as e:
logger.warning("Could not fetch libp2p version: %s", e)
return "py-libp2p/unknown"

View File

@ -4,8 +4,11 @@ from libp2p.identity.identify.pb.identify_pb2 import (
Identify,
)
from libp2p.identity.identify.protocol import (
AGENT_VERSION,
ID,
PROTOCOL_VERSION,
_mk_identify_protobuf,
_multiaddr_to_bytes,
)
from tests.factories import (
host_pair_factory,
@ -24,4 +27,26 @@ async def test_identify_protocol(security_protocol):
identify_response = Identify()
identify_response.ParseFromString(response)
# sanity check
assert identify_response == _mk_identify_protobuf(host_a)
# Check protocol version
assert identify_response.protocol_version == PROTOCOL_VERSION
# Check agent version
assert identify_response.agent_version == AGENT_VERSION
# Check public key
assert identify_response.public_key == host_a.get_public_key().serialize()
# Check listen addresses
assert identify_response.listen_addrs == list(
map(_multiaddr_to_bytes, host_a.get_addrs())
)
# TODO: Check observed address
# assert identify_response.observed_addr == host_b.get_addrs()[0]
# Check protocols
assert set(identify_response.protocols) == set(host_a.get_mux().get_protocols())