enforced f-strings everywhere, %s on logging

extended _multiaddr_from_socket to support UDP and IPv6 automatically
changed TCPListener to use _ip4_or_6_from_multiaddr to get host, and not ip4 only

enforced `from error` everywhere with raises
added call braces to exceptions
This commit is contained in:
Jonathan de Jong
2019-12-19 17:31:18 +01:00
parent 6cf1b98a88
commit b1248ff315
16 changed files with 136 additions and 74 deletions

View File

@ -32,7 +32,7 @@ class ECCPublicKey(PublicKey):
return KeyType.ECC_P256 return KeyType.ECC_P256
def verify(self, data: bytes, signature: bytes) -> bool: def verify(self, data: bytes, signature: bytes) -> bool:
raise NotImplementedError raise NotImplementedError()
class ECCPrivateKey(PrivateKey): class ECCPrivateKey(PrivateKey):
@ -53,7 +53,7 @@ class ECCPrivateKey(PrivateKey):
return KeyType.ECC_P256 return KeyType.ECC_P256
def sign(self, data: bytes) -> bytes: def sign(self, data: bytes) -> bytes:
raise NotImplementedError raise NotImplementedError()
def get_public_key(self) -> PublicKey: def get_public_key(self) -> PublicKey:
public_key_impl = keys.get_public_key(self.impl, self.curve) public_key_impl = keys.get_public_key(self.impl, self.curve)

View File

@ -20,8 +20,10 @@ def deserialize_public_key(data: bytes) -> PublicKey:
f = PublicKey.deserialize_from_protobuf(data) f = PublicKey.deserialize_from_protobuf(data)
try: try:
deserializer = key_type_to_public_key_deserializer[f.key_type] deserializer = key_type_to_public_key_deserializer[f.key_type]
except KeyError: except KeyError as e:
raise MissingDeserializerError({"key_type": f.key_type, "key": "public_key"}) raise MissingDeserializerError(
{"key_type": f.key_type, "key": "public_key"}
) from e
return deserializer(f.data) return deserializer(f.data)
@ -29,6 +31,8 @@ def deserialize_private_key(data: bytes) -> PrivateKey:
f = PrivateKey.deserialize_from_protobuf(data) f = PrivateKey.deserialize_from_protobuf(data)
try: try:
deserializer = key_type_to_private_key_deserializer[f.key_type] deserializer = key_type_to_private_key_deserializer[f.key_type]
except KeyError: except KeyError as e:
raise MissingDeserializerError({"key_type": f.key_type, "key": "private_key"}) raise MissingDeserializerError(
{"key_type": f.key_type, "key": "private_key"}
) from e
return deserializer(f.data) return deserializer(f.data)

View File

@ -93,7 +93,7 @@ class BasicHost(IHost):
:return: all the multiaddr addresses this host is listening to :return: all the multiaddr addresses this host is listening to
""" """
# TODO: We don't need "/p2p/{peer_id}" postfix actually. # TODO: We don't need "/p2p/{peer_id}" postfix actually.
p2p_part = multiaddr.Multiaddr("/p2p/{}".format(self.get_id().pretty())) p2p_part = multiaddr.Multiaddr(f"/p2p/{self.get_id()!s}")
addrs: List[multiaddr.Multiaddr] = [] addrs: List[multiaddr.Multiaddr] = []
for transport in self._network.listeners.values(): for transport in self._network.listeners.values():
@ -131,7 +131,7 @@ class BasicHost(IHost):
except MultiselectClientError as error: except MultiselectClientError as error:
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error) logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)
await net_stream.reset() await net_stream.reset()
raise StreamFailure("failt to open a stream to peer %s", peer_id) from error raise StreamFailure(f"failed to open a stream to peer {peer_id}") from error
net_stream.set_protocol(selected_protocol) net_stream.set_protocol(selected_protocol)
return net_stream return net_stream

View File

@ -29,7 +29,7 @@ class RawConnection(IRawConnection):
try: try:
self.writer.write(data) self.writer.write(data)
except ConnectionResetError as error: except ConnectionResetError as error:
raise RawConnError(error) raise RawConnError() from error
# Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501 # Reference: https://github.com/ethereum/lahja/blob/93610b2eb46969ff1797e0748c7ac2595e130aef/lahja/asyncio/endpoint.py#L99-L102 # noqa: E501
# Use a lock to serialize drain() calls. Circumvents this bug: # Use a lock to serialize drain() calls. Circumvents this bug:
# https://bugs.python.org/issue29930 # https://bugs.python.org/issue29930
@ -37,7 +37,7 @@ class RawConnection(IRawConnection):
try: try:
await self.writer.drain() await self.writer.drain()
except ConnectionResetError as error: except ConnectionResetError as error:
raise RawConnError(error) raise RawConnError() from error
async def read(self, n: int = -1) -> bytes: async def read(self, n: int = -1) -> bytes:
""" """
@ -49,7 +49,7 @@ class RawConnection(IRawConnection):
try: try:
return await self.reader.read(n) return await self.reader.read(n)
except ConnectionResetError as error: except ConnectionResetError as error:
raise RawConnError(error) raise RawConnError() from error
async def close(self) -> None: async def close(self) -> None:
self.writer.close() self.writer.close()

View File

@ -47,9 +47,9 @@ class NetStream(INetStream):
try: try:
return await self.muxed_stream.read(n) return await self.muxed_stream.read(n)
except MuxedStreamEOF as error: except MuxedStreamEOF as error:
raise StreamEOF from error raise StreamEOF() from error
except MuxedStreamReset as error: except MuxedStreamReset as error:
raise StreamReset from error raise StreamReset() from error
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
""" """
@ -60,7 +60,7 @@ class NetStream(INetStream):
try: try:
return await self.muxed_stream.write(data) return await self.muxed_stream.write(data)
except MuxedStreamClosed as error: except MuxedStreamClosed as error:
raise StreamClosed from error raise StreamClosed() from error
async def close(self) -> None: async def close(self) -> None:
"""close stream.""" """close stream."""

View File

@ -87,8 +87,8 @@ class Swarm(INetwork):
try: try:
# Get peer info from peer store # Get peer info from peer store
addrs = self.peerstore.addrs(peer_id) addrs = self.peerstore.addrs(peer_id)
except PeerStoreError: except PeerStoreError as e:
raise SwarmException(f"No known addresses to peer {peer_id}") raise SwarmException(f"No known addresses to peer {peer_id}") from e
if not addrs: if not addrs:
raise SwarmException(f"No known addresses to peer {peer_id}") raise SwarmException(f"No known addresses to peer {peer_id}")
@ -101,7 +101,7 @@ class Swarm(INetwork):
except OpenConnectionError as error: except OpenConnectionError as error:
logger.debug("fail to dial peer %s over base transport", peer_id) logger.debug("fail to dial peer %s over base transport", peer_id)
raise SwarmException( raise SwarmException(
"fail to open connection to peer %s", peer_id f"fail to open connection to peer {peer_id}"
) from error ) from error
logger.debug("dialed peer %s over base transport", peer_id) logger.debug("dialed peer %s over base transport", peer_id)
@ -111,20 +111,20 @@ class Swarm(INetwork):
try: try:
secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True) secured_conn = await self.upgrader.upgrade_security(raw_conn, peer_id, True)
except SecurityUpgradeFailure as error: except SecurityUpgradeFailure as error:
error_msg = "fail to upgrade security for peer %s" logger.debug("failed to upgrade security for peer %s", peer_id)
logger.debug(error_msg, peer_id)
await raw_conn.close() await raw_conn.close()
raise SwarmException(error_msg % peer_id) from error # wip raise SwarmException(
f"failed to upgrade security for peer {peer_id}"
) from error
logger.debug("upgraded security for peer %s", peer_id) logger.debug("upgraded security for peer %s", peer_id)
try: try:
muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id) muxed_conn = await self.upgrader.upgrade_connection(secured_conn, peer_id)
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" logger.debug("failed to upgrade mux for peer %s", peer_id)
logger.debug(error_msg, peer_id)
await secured_conn.close() await secured_conn.close()
raise SwarmException(error_msg % peer_id) from error # wip raise SwarmException(f"failed to upgrade mux for peer {peer_id}") from error
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
@ -187,10 +187,11 @@ class Swarm(INetwork):
raw_conn, ID(b""), False raw_conn, ID(b""), False
) )
except SecurityUpgradeFailure as error: except SecurityUpgradeFailure as error:
error_msg = "fail to upgrade security for peer at %s" logger.debug("failed to upgrade security for peer at %s", peer_addr)
logger.debug(error_msg, peer_addr)
await raw_conn.close() await raw_conn.close()
raise SwarmException(error_msg % peer_addr) from error # wip raise SwarmException(
f"failed to upgrade security for peer at {peer_addr}"
) from error
peer_id = secured_conn.get_remote_peer() peer_id = secured_conn.get_remote_peer()
logger.debug("upgraded security for peer at %s", peer_addr) logger.debug("upgraded security for peer at %s", peer_addr)
@ -201,10 +202,11 @@ class Swarm(INetwork):
secured_conn, peer_id secured_conn, peer_id
) )
except MuxerUpgradeFailure as error: except MuxerUpgradeFailure as error:
error_msg = "fail to upgrade mux for peer %s" logger.debug("fail to upgrade mux for peer %s", peer_id)
logger.debug(error_msg, peer_id)
await secured_conn.close() await secured_conn.close()
raise SwarmException(error_msg % peer_id) from error # wip raise SwarmException(
f"fail to upgrade mux for peer {peer_id}"
) from error
logger.debug("upgraded mux for peer %s", peer_id) logger.debug("upgraded mux for peer %s", peer_id)
await self.add_conn(muxed_conn) await self.add_conn(muxed_conn)
@ -223,7 +225,7 @@ class Swarm(INetwork):
return True return True
except IOError: except IOError:
# Failed. Continue looping. # Failed. Continue looping.
logger.debug("fail to listen on: " + str(maddr)) logger.debug("fail to listen on: %s", maddr)
# No maddr succeeded # No maddr succeeded
return False return False

View File

@ -71,7 +71,7 @@ class PeerStore(IPeerStore):
try: try:
val = self.peer_data_map[peer_id].get_metadata(key) val = self.peer_data_map[peer_id].get_metadata(key)
except PeerDataError as error: except PeerDataError as error:
raise PeerStoreError(error) raise PeerStoreError() from error
return val return val
raise PeerStoreError("peer ID not found") raise PeerStoreError("peer ID not found")
@ -153,8 +153,8 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id] peer_data = self.peer_data_map[peer_id]
try: try:
pubkey = peer_data.get_pubkey() pubkey = peer_data.get_pubkey()
except PeerDataError: except PeerDataError as e:
raise PeerStoreError("peer pubkey not found") raise PeerStoreError("peer pubkey not found") from e
return pubkey return pubkey
raise PeerStoreError("peer ID not found") raise PeerStoreError("peer ID not found")
@ -179,8 +179,8 @@ class PeerStore(IPeerStore):
peer_data = self.peer_data_map[peer_id] peer_data = self.peer_data_map[peer_id]
try: try:
privkey = peer_data.get_privkey() privkey = peer_data.get_privkey()
except PeerDataError: except PeerDataError as e:
raise PeerStoreError("peer privkey not found") raise PeerStoreError("peer privkey not found") from e
return privkey return privkey
raise PeerStoreError("peer ID not found") raise PeerStoreError("peer ID not found")

View File

@ -49,7 +49,7 @@ class Multiselect(IMultiselectMuxer):
try: try:
command = await communicator.read() command = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError(error) raise MultiselectError() from error
if command == "ls": if command == "ls":
# TODO: handle ls command # TODO: handle ls command
@ -60,13 +60,13 @@ class Multiselect(IMultiselectMuxer):
try: try:
await communicator.write(protocol) await communicator.write(protocol)
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError(error) raise MultiselectError() from error
return protocol, self.handlers[protocol] return protocol, self.handlers[protocol]
try: try:
await communicator.write(PROTOCOL_NOT_FOUND_MSG) await communicator.write(PROTOCOL_NOT_FOUND_MSG)
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError(error) raise MultiselectError() from error
async def handshake(self, communicator: IMultiselectCommunicator) -> None: async def handshake(self, communicator: IMultiselectCommunicator) -> None:
""" """
@ -78,12 +78,12 @@ class Multiselect(IMultiselectMuxer):
try: try:
await communicator.write(MULTISELECT_PROTOCOL_ID) await communicator.write(MULTISELECT_PROTOCOL_ID)
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError(error) raise MultiselectError() from error
try: try:
handshake_contents = await communicator.read() handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectError(error) raise MultiselectError() from error
if not is_valid_handshake(handshake_contents): if not is_valid_handshake(handshake_contents):
raise MultiselectError( raise MultiselectError(

View File

@ -25,12 +25,12 @@ class MultiselectClient(IMultiselectClient):
try: try:
await communicator.write(MULTISELECT_PROTOCOL_ID) await communicator.write(MULTISELECT_PROTOCOL_ID)
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError(error) raise MultiselectClientError() from error
try: try:
handshake_contents = await communicator.read() handshake_contents = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error)) raise MultiselectClientError() from error
if not is_valid_handshake(handshake_contents): if not is_valid_handshake(handshake_contents):
raise MultiselectClientError("multiselect protocol ID mismatch") raise MultiselectClientError("multiselect protocol ID mismatch")
@ -73,18 +73,18 @@ class MultiselectClient(IMultiselectClient):
try: try:
await communicator.write(protocol) await communicator.write(protocol)
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError(error) raise MultiselectClientError() from error
try: try:
response = await communicator.read() response = await communicator.read()
except MultiselectCommunicatorError as error: except MultiselectCommunicatorError as error:
raise MultiselectClientError(str(error)) raise MultiselectClientError() from error
if response == protocol: if response == protocol:
return protocol return protocol
if response == PROTOCOL_NOT_FOUND_MSG: if response == PROTOCOL_NOT_FOUND_MSG:
raise MultiselectClientError("protocol not supported") raise MultiselectClientError("protocol not supported")
raise MultiselectClientError("unrecognized response: " + response) raise MultiselectClientError(f"unrecognized response: {response}")
def is_valid_handshake(handshake_contents: str) -> bool: def is_valid_handshake(handshake_contents: str) -> bool:

View File

@ -52,13 +52,13 @@ class InsecureSession(BaseSession):
encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes) encoded_msg_bytes = encode_fixedint_prefixed(msg_bytes)
try: try:
await self.write(encoded_msg_bytes) await self.write(encoded_msg_bytes)
except RawConnError: except RawConnError as e:
raise HandshakeFailure("connection closed") raise HandshakeFailure("connection closed") from e
try: try:
remote_msg_bytes = await read_fixedint_prefixed(self.conn) remote_msg_bytes = await read_fixedint_prefixed(self.conn)
except RawConnError: except RawConnError as e:
raise HandshakeFailure("connection closed") raise HandshakeFailure("connection closed") from e
remote_msg = plaintext_pb2.Exchange() remote_msg = plaintext_pb2.Exchange()
remote_msg.ParseFromString(remote_msg_bytes) remote_msg.ParseFromString(remote_msg_bytes)
received_peer_id = ID(remote_msg.id) received_peer_id = ID(remote_msg.id)
@ -77,12 +77,12 @@ class InsecureSession(BaseSession):
received_pubkey = deserialize_public_key( received_pubkey = deserialize_public_key(
remote_msg.pubkey.SerializeToString() remote_msg.pubkey.SerializeToString()
) )
except ValueError: except ValueError as e:
raise HandshakeFailure( raise HandshakeFailure(
f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}" f"unknown `key_type` of remote_msg.pubkey={remote_msg.pubkey}"
) ) from e
except MissingDeserializerError as error: except MissingDeserializerError as error:
raise HandshakeFailure(error) raise HandshakeFailure() from error
peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey) peer_id_from_received_pubkey = ID.from_pubkey(received_pubkey)
if peer_id_from_received_pubkey != received_peer_id: if peer_id_from_received_pubkey != received_peer_id:
raise HandshakeFailure( raise HandshakeFailure(

View File

@ -131,8 +131,8 @@ class SecureSession(BaseSession):
msg = await self.conn.read_msg() msg = await self.conn.read_msg()
try: try:
decrypted_msg = self.remote_encrypter.decrypt_if_valid(msg) decrypted_msg = self.remote_encrypter.decrypt_if_valid(msg)
except InvalidMACException: except InvalidMACException as e:
raise DecryptionFailedException raise DecryptionFailedException() from e
return decrypted_msg return decrypted_msg
async def write(self, data: bytes) -> int: async def write(self, data: bytes) -> int:
@ -175,7 +175,7 @@ class Proposal:
try: try:
public_key = deserialize_public_key(public_key_protobuf_bytes) public_key = deserialize_public_key(public_key_protobuf_bytes)
except MissingDeserializerError as error: except MissingDeserializerError as error:
raise SedesException(error) raise SedesException() from error
exchanges = protobuf.exchanges exchanges = protobuf.exchanges
ciphers = protobuf.ciphers ciphers = protobuf.ciphers
hashes = protobuf.hashes hashes = protobuf.hashes
@ -424,8 +424,8 @@ async def create_secure_session(
await conn.close() await conn.close()
raise e raise e
# `IOException` includes errors raised while read from/write to raw connection # `IOException` includes errors raised while read from/write to raw connection
except IOException: except IOException as e:
raise SecioException("connection closed") raise SecioException("connection closed") from e
is_initiator = remote_peer is not None is_initiator = remote_peer is not None
session = _mk_session_from( session = _mk_session_from(
@ -435,8 +435,8 @@ async def create_secure_session(
try: try:
received_nonce = await _finish_handshake(session, remote_nonce) received_nonce = await _finish_handshake(session, remote_nonce)
# `IOException` includes errors raised while read from/write to raw connection # `IOException` includes errors raised while read from/write to raw connection
except IOException: except IOException as e:
raise SecioException("connection closed") raise SecioException("connection closed") from e
if received_nonce != local_nonce: if received_nonce != local_nonce:
await conn.close() await conn.close()
raise InconsistentNonce() raise InconsistentNonce()

View File

@ -75,7 +75,7 @@ class MplexStream(IMuxedStream):
if task_event_reset in done: if task_event_reset in done:
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset raise MplexStreamReset()
else: else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag # However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled. # is set. The task is probably cancelled.
@ -91,7 +91,7 @@ class MplexStream(IMuxedStream):
if task_event_remote_closed in done: if task_event_remote_closed in done:
if self.event_remote_closed.is_set(): if self.event_remote_closed.is_set():
raise MplexStreamEOF raise MplexStreamEOF()
else: else:
# However, it is abnormal that `Event.wait` is unblocked without any of the flag # However, it is abnormal that `Event.wait` is unblocked without any of the flag
# is set. The task is probably cancelled. # is set. The task is probably cancelled.
@ -126,7 +126,7 @@ class MplexStream(IMuxedStream):
f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF" f"the number of bytes to read `n` must be positive or -1 to indicate read until EOF"
) )
if self.event_reset.is_set(): if self.event_reset.is_set():
raise MplexStreamReset raise MplexStreamReset()
if n == -1: if n == -1:
return await self._read_until_eof() return await self._read_until_eof()
if len(self._buf) == 0 and self.incoming_data.empty(): if len(self._buf) == 0 and self.incoming_data.empty():

View File

@ -56,7 +56,7 @@ class P2PDProcess:
is_pubsub_signing: bool = False, is_pubsub_signing: bool = False,
is_pubsub_signing_strict: bool = False, is_pubsub_signing_strict: bool = False,
) -> None: ) -> None:
args = [f"-listen={str(control_maddr)}"] args = [f"-listen={control_maddr!s}"]
# NOTE: To support `-insecure`, we need to hack `go-libp2p-daemon`. # NOTE: To support `-insecure`, we need to hack `go-libp2p-daemon`.
if not is_secure: if not is_secure:
args.append("-insecure=true") args.append("-insecure=true")

View File

@ -81,7 +81,7 @@ class DummyAccountNode:
:param dest_user: user to send crypto to :param dest_user: user to send crypto to
:param amount: amount of crypto to send :param amount: amount of crypto to send
""" """
msg_contents = "send," + source_user + "," + dest_user + "," + str(amount) msg_contents = f"send,{source_user},{dest_user},{amount!s}"
await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode()) await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())
async def publish_set_crypto(self, user: str, amount: int) -> None: async def publish_set_crypto(self, user: str, amount: int) -> None:
@ -92,7 +92,7 @@ class DummyAccountNode:
:param user: user to set crypto for :param user: user to set crypto for
:param amount: amount of crypto :param amount: amount of crypto
""" """
msg_contents = "set," + user + "," + str(amount) msg_contents = f"set,{user},{amount!s}"
await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode()) await self.pubsub.publish(CRYPTO_TOPIC, msg_contents.encode())
def handle_send_crypto(self, source_user: str, dest_user: str, amount: int) -> None: def handle_send_crypto(self, source_user: str, dest_user: str, amount: int) -> None:

View File

@ -51,7 +51,7 @@ async def echo_stream_handler(stream: INetStream) -> None:
while True: while True:
read_string = (await stream.read(MAX_READ_LEN)).decode() read_string = (await stream.read(MAX_READ_LEN)).decode()
resp = "ack:" + read_string resp = f"ack:{read_string}"
await stream.write(resp.encode()) await stream.write(resp.encode())

View File

@ -1,9 +1,11 @@
import asyncio import asyncio
from socket import socket import socket
import sys import sys
from typing import List from typing import List, Optional
from multiaddr import Multiaddr from multiaddr import Multiaddr
from multiaddr.protocols import P_IP4, P_IP6, P_TCP, P_UDP
from multiaddr.protocols import protocol_with_code as p_code
from libp2p.network.connection.raw_connection import RawConnection from libp2p.network.connection.raw_connection import RawConnection
from libp2p.network.connection.raw_connection_interface import IRawConnection from libp2p.network.connection.raw_connection_interface import IRawConnection
@ -29,10 +31,14 @@ class TCPListener(IListener):
:param maddr: maddr of peer :param maddr: maddr of peer
:return: return True if successful :return: return True if successful
""" """
listen_addr = _ip4_or_6_from_multiaddr(maddr)
if listen_addr is None:
raise NotImplementedError(
"Can only start TCP Listener with a IPv4 or IPv6 address"
)
self.server = await asyncio.start_server( self.server = await asyncio.start_server(
self.handler, self.handler, listen_addr, maddr.value_for_protocol("tcp")
maddr.value_for_protocol("ip4"),
maddr.value_for_protocol("tcp"),
) )
socket = self.server.sockets[0] socket = self.server.sockets[0]
self.multiaddrs.append(_multiaddr_from_socket(socket)) self.multiaddrs.append(_multiaddr_from_socket(socket))
@ -70,7 +76,10 @@ class TCP(ITransport):
:return: `RawConnection` if successful :return: `RawConnection` if successful
:raise OpenConnectionError: raised when failed to open connection :raise OpenConnectionError: raised when failed to open connection
""" """
self.host = maddr.value_for_protocol("ip4") self.host = _ip4_or_6_from_multiaddr(maddr)
if self.host is None:
raise ValueError("Cannot find ipv4 or ipv6 host in multiaddress")
self.port = int(maddr.value_for_protocol("tcp")) self.port = int(maddr.value_for_protocol("tcp"))
try: try:
@ -91,5 +100,52 @@ class TCP(ITransport):
return TCPListener(handler_function) return TCPListener(handler_function)
def _multiaddr_from_socket(socket: socket) -> Multiaddr: def _ip4_or_6_from_multiaddr(maddr: Multiaddr) -> Optional[str]:
return Multiaddr("/ip4/%s/tcp/%s" % socket.getsockname()) # wip if P_IP4 in maddr.protocols():
return maddr.value_for_protocol(P_IP4)
elif P_IP6 in maddr.protocols():
return maddr.value_for_protocol(P_IP6)
else:
return None
def _multiaddr_from_socket(sock: socket.socket) -> Multiaddr:
# Reference: http://man7.org/linux/man-pages/man2/socket.2.html#DESCRIPTION
# todo: move this to more generic libp2p.transport helper function
# Reference: https://stackoverflow.com/questions/5815675/what-is-sock-dgram-and-sock-stream
# Selects first protocol in sequence if bitwise AND matches, else None
t_proto = next(
(
v
for k, v in {
socket.SOCK_STREAM: p_code(P_TCP).name,
socket.SOCK_DGRAM: p_code(P_UDP).name,
}.items()
if k & sock.type != 0
),
None,
)
if t_proto is None:
raise NotImplementedError(
f"Cannot convert socket to multiaddr, socket type is of {sock.type}"
)
# Reference: https://docs.python.org/3/library/socket.html#socket-families
if sock.family == socket.AF_INET:
# ipv4: (host, port)
addr, port = sock.getsockname()
ip = p_code(P_IP4).name
elif sock.family == socket.AF_INET6:
# ipv6: (host, port, flowinfo, scopeid)
addr, port = sock.getsockname()[:2]
ip = p_code(P_IP6).name
else:
raise NotImplementedError(
f"Cannot convert socket to multiaddr, socket family is of {sock.family}"
)
return Multiaddr(f"/{ip}/{addr}/{t_proto}/{port}")