Security: SecureSession

Make security sessions(secio, noise) share the same implementation
`BaseSession` to avoid duplicate implementation of buffered read.
This commit is contained in:
mhchia
2020-02-17 23:33:45 +08:00
parent 2df47a943c
commit 3c2e835725
8 changed files with 150 additions and 196 deletions

View File

@ -1,88 +0,0 @@
import io
from noise.connection import NoiseConnection as NoiseState
from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.security.base_session import BaseSession
from libp2p.security.noise.io import MsgReadWriter, NoiseTransportReadWriter
class NoiseConnection(BaseSession):
buf: io.BytesIO
low_watermark: int
high_watermark: int
read_writer: IRawConnection
noise_state: NoiseState
def __init__(
self,
local_peer: ID,
local_private_key: PrivateKey,
remote_peer: ID,
conn: IRawConnection,
is_initiator: bool,
noise_state: NoiseState,
# remote_permanent_pubkey
) -> None:
super().__init__(local_peer, local_private_key, is_initiator, remote_peer)
self.conn = conn
self.noise_state = noise_state
self._reset_internal_buffer()
def get_msg_read_writer(self) -> MsgReadWriter:
return NoiseTransportReadWriter(self.conn, self.noise_state)
async def close(self) -> None:
await self.conn.close()
def _reset_internal_buffer(self) -> None:
self.buf = io.BytesIO()
self.low_watermark = 0
self.high_watermark = 0
def _drain(self, n: int) -> bytes:
if self.low_watermark == self.high_watermark:
return bytes()
data = self.buf.getbuffer()[self.low_watermark : self.high_watermark]
if n is None:
n = len(data)
result = data[:n].tobytes()
self.low_watermark += len(result)
if self.low_watermark == self.high_watermark:
del data # free the memoryview so we can free the underlying BytesIO
self.buf.close()
self._reset_internal_buffer()
return result
async def read(self, n: int = None) -> bytes:
if n == 0:
return bytes()
data_from_buffer = self._drain(n)
if len(data_from_buffer) > 0:
return data_from_buffer
msg = await self.read_msg()
if n < len(msg):
self.buf.write(msg)
self.low_watermark = 0
self.high_watermark = len(msg)
return self._drain(n)
else:
return msg
async def read_msg(self) -> bytes:
return await self.get_msg_read_writer().read_msg()
async def write(self, data: bytes) -> None:
await self.write_msg(data)
async def write_msg(self, msg: bytes) -> None:
await self.get_msg_read_writer().write_msg(msg)

View File

@ -1,14 +1,11 @@
from abc import ABC, abstractmethod
from typing import cast
from noise.connection import NoiseConnection as NoiseState
from libp2p.io.abc import ReadWriteCloser, MsgReadWriter, EncryptedMsgReadWriter
from libp2p.io.abc import EncryptedMsgReadWriter, MsgReadWriteCloser, ReadWriteCloser
from libp2p.io.msgio import BaseMsgReadWriter, encode_msg_with_length
from libp2p.io.utils import read_exactly
from libp2p.network.connection.raw_connection_interface import IRawConnection
SIZE_NOISE_MESSAGE_LEN = 2
MAX_NOISE_MESSAGE_LEN = 2 ** (8 * SIZE_NOISE_MESSAGE_LEN) - 1
SIZE_NOISE_MESSAGE_BODY_LEN = 2
@ -50,7 +47,14 @@ def decode_msg_body(noise_msg: bytes) -> bytes:
class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
read_writer: MsgReadWriter
"""
The base implementation of noise message reader/writer.
`encrypt` and `decrypt` are not implemented here, which should be
implemented by the subclasses.
"""
read_writer: MsgReadWriteCloser
noise_state: NoiseState
def __init__(self, conn: IRawConnection, noise_state: NoiseState) -> None:
@ -67,6 +71,9 @@ class BaseNoiseMsgReadWriter(EncryptedMsgReadWriter):
noise_msg = self.decrypt(noise_msg_encrypted)
return decode_msg_body(noise_msg)
async def close(self) -> None:
await self.read_writer.close()
class NoiseHandshakeReadWriter(BaseNoiseMsgReadWriter):
def encrypt(self, data: bytes) -> bytes:

View File

@ -8,15 +8,15 @@ from libp2p.crypto.keys import PrivateKey
from libp2p.network.connection.raw_connection_interface import IRawConnection
from libp2p.peer.id import ID
from libp2p.security.secure_conn_interface import ISecureConn
from libp2p.security.secure_session import SecureSession
from .connection import NoiseConnection
from .exceptions import (
HandshakeHasNotFinished,
InvalidSignature,
NoiseStateError,
PeerIDMismatchesPubkey,
)
from .io import encode_msg_body, decode_msg_body, NoiseHandshakeReadWriter
from .io import NoiseHandshakeReadWriter, NoiseTransportReadWriter
from .messages import (
NoiseHandshakePayload,
make_handshake_payload_sig,
@ -56,16 +56,6 @@ class BasePattern(IPattern):
)
return NoiseHandshakePayload(self.libp2p_privkey.get_public_key(), signature)
async def write_msg(self, conn: IRawConnection, data: bytes) -> None:
noise_msg = encode_msg_body(data)
data_encrypted = self.noise_state.write_message(noise_msg)
await self.read_writer.write_msg(data_encrypted)
async def read_msg(self) -> bytes:
noise_msg_encrypted = await self.read_writer.read_msg()
noise_msg = self.noise_state.read_message(noise_msg_encrypted)
return decode_msg_body(noise_msg)
class PatternXX(BasePattern):
def __init__(
@ -116,14 +106,13 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
return NoiseConnection(
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return SecureSession(
self.local_peer,
self.libp2p_privkey,
remote_peer_id_from_pubkey,
conn,
transport_read_writer,
False,
noise_state,
)
async def handshake_outbound(
@ -171,7 +160,12 @@ class PatternXX(BasePattern):
raise HandshakeHasNotFinished(
"handshake is done but it is not marked as finished in `noise_state`"
)
transport_read_writer = NoiseTransportReadWriter(conn, noise_state)
return NoiseConnection(
self.local_peer, self.libp2p_privkey, remote_peer, conn, False, noise_state
return SecureSession(
self.local_peer,
self.libp2p_privkey,
remote_peer,
transport_read_writer,
False,
)