mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: update conn and transport for security
This commit is contained in:
@ -76,7 +76,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
resource_scope: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize enhanced QUIC connection with security integration.
|
||||
Initialize QUIC connection with security integration.
|
||||
|
||||
Args:
|
||||
quic_connection: aioquic QuicConnection instance
|
||||
@ -105,7 +105,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._connected_event = trio.Event()
|
||||
self._closed_event = trio.Event()
|
||||
|
||||
# Enhanced stream management
|
||||
# Stream management
|
||||
self._streams: dict[int, QUICStream] = {}
|
||||
self._next_stream_id: int = self._calculate_initial_stream_id()
|
||||
self._stream_handler: TQUICStreamHandlerFn | None = None
|
||||
@ -129,8 +129,8 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._peer_verified = False
|
||||
|
||||
# Security state
|
||||
self._peer_certificate: Optional[x509.Certificate] = None
|
||||
self._handshake_events = []
|
||||
self._peer_certificate: x509.Certificate | None = None
|
||||
self._handshake_events: list[events.HandshakeCompleted] = []
|
||||
|
||||
# Background task management
|
||||
self._background_tasks_started = False
|
||||
@ -466,7 +466,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
f"Alternative certificate extraction also failed: {inner_e}"
|
||||
)
|
||||
|
||||
async def get_peer_certificate(self) -> Optional[x509.Certificate]:
|
||||
async def get_peer_certificate(self) -> x509.Certificate | None:
|
||||
"""
|
||||
Get the peer's TLS certificate.
|
||||
|
||||
@ -511,7 +511,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
def get_security_info(self) -> dict[str, Any]:
|
||||
"""Get security-related information about the connection."""
|
||||
info: dict[str, bool | Any | None]= {
|
||||
info: dict[str, bool | Any | None] = {
|
||||
"peer_verified": self._peer_verified,
|
||||
"handshake_complete": self._handshake_completed,
|
||||
"peer_id": str(self._peer_id) if self._peer_id else None,
|
||||
@ -534,7 +534,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
),
|
||||
"certificate_not_after": (
|
||||
self._peer_certificate.not_valid_after.isoformat()
|
||||
),
|
||||
),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
@ -574,7 +574,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
|
||||
async def open_stream(self, timeout: float = 5.0) -> QUICStream:
|
||||
"""
|
||||
Open a new outbound stream with enhanced error handling and resource management.
|
||||
Open a new outbound stream
|
||||
|
||||
Args:
|
||||
timeout: Timeout for stream creation
|
||||
@ -607,7 +607,6 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
stream_id = self._next_stream_id
|
||||
self._next_stream_id += 4 # Increment by 4 for bidirectional streams
|
||||
|
||||
# Create enhanced stream
|
||||
stream = QUICStream(
|
||||
connection=self,
|
||||
stream_id=stream_id,
|
||||
@ -766,7 +765,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
self._closed_event.set()
|
||||
|
||||
async def _handle_stream_data(self, event: events.StreamDataReceived) -> None:
|
||||
"""Enhanced stream data handling with proper error management."""
|
||||
"""Stream data handling with proper error management."""
|
||||
stream_id = event.stream_id
|
||||
self._stats["bytes_received"] += len(event.data)
|
||||
|
||||
@ -858,7 +857,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
return stream_id % 2 == 0
|
||||
|
||||
async def _handle_stream_reset(self, event: events.StreamReset) -> None:
|
||||
"""Enhanced stream reset handling."""
|
||||
"""Stream reset handling."""
|
||||
stream_id = event.stream_id
|
||||
self._stats["streams_reset"] += 1
|
||||
|
||||
@ -925,7 +924,7 @@ class QUICConnection(IRawConnection, IMuxedConn):
|
||||
# Connection close
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Enhanced connection close with proper stream cleanup."""
|
||||
"""Connection close with proper stream cleanup."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import copy
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from aioquic.quic import events
|
||||
from aioquic.quic.configuration import QuicConfiguration
|
||||
@ -18,6 +18,7 @@ import trio
|
||||
|
||||
from libp2p.abc import IListener
|
||||
from libp2p.custom_types import THandler, TProtocol
|
||||
from libp2p.transport.quic.security import QUICTLSConfigManager
|
||||
|
||||
from .config import QUICTransportConfig
|
||||
from .connection import QUICConnection
|
||||
@ -51,6 +52,7 @@ class QUICListener(IListener):
|
||||
handler_function: THandler,
|
||||
quic_configs: dict[TProtocol, QuicConfiguration],
|
||||
config: QUICTransportConfig,
|
||||
security_manager: QUICTLSConfigManager | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC listener.
|
||||
@ -60,12 +62,14 @@ class QUICListener(IListener):
|
||||
handler_function: Function to handle new connections
|
||||
quic_configs: QUIC configurations for different versions
|
||||
config: QUIC transport configuration
|
||||
security_manager: Security manager for TLS/certificate handling
|
||||
|
||||
"""
|
||||
self._transport = transport
|
||||
self._handler = handler_function
|
||||
self._quic_configs = quic_configs
|
||||
self._config = config
|
||||
self._security_manager = security_manager
|
||||
|
||||
# Network components
|
||||
self._socket: trio.socket.SocketType | None = None
|
||||
@ -117,8 +121,10 @@ class QUICListener(IListener):
|
||||
host, port = quic_multiaddr_to_endpoint(maddr)
|
||||
quic_version = multiaddr_to_quic_version(maddr)
|
||||
|
||||
protocol = f"{quic_version}_server"
|
||||
|
||||
# Validate QUIC version support
|
||||
if quic_version not in self._quic_configs:
|
||||
if protocol not in self._quic_configs:
|
||||
raise QUICListenError(f"Unsupported QUIC version: {quic_version}")
|
||||
|
||||
# Create and bind UDP socket
|
||||
@ -379,6 +385,7 @@ class QUICListener(IListener):
|
||||
is_initiator=False, # We're the server
|
||||
maddr=remote_maddr,
|
||||
transport=self._transport,
|
||||
security_manager=self._security_manager,
|
||||
)
|
||||
|
||||
# Store the connection
|
||||
@ -389,8 +396,16 @@ class QUICListener(IListener):
|
||||
self._nursery.start_soon(connection._handle_datagram_received)
|
||||
self._nursery.start_soon(connection._handle_timer_events)
|
||||
|
||||
# TODO: Verify peer identity
|
||||
# await connection.verify_peer_identity()
|
||||
if self._security_manager:
|
||||
try:
|
||||
await connection._verify_peer_identity_with_security()
|
||||
logger.info(f"Security verification successful for {addr}")
|
||||
except Exception as e:
|
||||
logger.error(f"Security verification failed for {addr}: {e}")
|
||||
self._stats["security_failures"] += 1
|
||||
# Close the connection due to security failure
|
||||
await connection.close()
|
||||
return
|
||||
|
||||
# Call the connection handler
|
||||
if self._nursery:
|
||||
@ -569,6 +584,16 @@ class QUICListener(IListener):
|
||||
)
|
||||
return stats
|
||||
|
||||
def get_security_manager(self) -> Optional["QUICTLSConfigManager"]:
|
||||
"""
|
||||
Get the security manager for this listener.
|
||||
|
||||
Returns:
|
||||
The QUIC TLS configuration manager, or None if not configured
|
||||
|
||||
"""
|
||||
return self._security_manager
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the listener."""
|
||||
addr = self._bound_addresses
|
||||
|
||||
@ -5,18 +5,19 @@ Based on go-libp2p and js-libp2p security patterns.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, rsa
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
||||
from cryptography.x509.base import Certificate
|
||||
from cryptography.x509.oid import NameOID
|
||||
|
||||
from libp2p.crypto.ed25519 import Ed25519PublicKey
|
||||
from libp2p.crypto.keys import PrivateKey, PublicKey
|
||||
from libp2p.crypto.secp256k1 import Secp256k1PublicKey
|
||||
from libp2p.crypto.serialization import deserialize_public_key
|
||||
from libp2p.peer.id import ID
|
||||
|
||||
from .exceptions import (
|
||||
@ -24,6 +25,11 @@ from .exceptions import (
|
||||
QUICPeerVerificationError,
|
||||
)
|
||||
|
||||
TSecurityConfig = dict[
|
||||
str,
|
||||
Certificate | EllipticCurvePrivateKey | RSAPrivateKey | bool | list[str],
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# libp2p TLS Extension OID - Official libp2p specification
|
||||
@ -34,6 +40,7 @@ CERTIFICATE_VALIDITY_DAYS = 365
|
||||
CERTIFICATE_NOT_BEFORE_BUFFER = 3600 # 1 hour before now
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass
|
||||
class TLSConfig:
|
||||
"""TLS configuration for QUIC transport with libp2p extensions."""
|
||||
@ -43,17 +50,29 @@ class TLSConfig:
|
||||
peer_id: ID
|
||||
|
||||
def get_certificate_der(self) -> bytes:
|
||||
"""Get certificate in DER format for aioquic."""
|
||||
"""Get certificate in DER format for external use."""
|
||||
return self.certificate.public_bytes(serialization.Encoding.DER)
|
||||
|
||||
def get_private_key_der(self) -> bytes:
|
||||
"""Get private key in DER format for aioquic."""
|
||||
"""Get private key in DER format for external use."""
|
||||
return self.private_key.private_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
def get_certificate_pem(self) -> bytes:
|
||||
"""Get certificate in PEM format."""
|
||||
return self.certificate.public_bytes(serialization.Encoding.PEM)
|
||||
|
||||
def get_private_key_pem(self) -> bytes:
|
||||
"""Get private key in PEM format."""
|
||||
return self.private_key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
|
||||
|
||||
class LibP2PExtensionHandler:
|
||||
"""
|
||||
@ -96,7 +115,8 @@ class LibP2PExtensionHandler:
|
||||
# In a full implementation, this would use proper ASN.1 encoding
|
||||
public_key_bytes = libp2p_public_key.serialize()
|
||||
|
||||
# Simple encoding: [public_key_length][public_key][signature_length][signature]
|
||||
# Simple encoding:
|
||||
# [public_key_length][public_key][signature_length][signature]
|
||||
extension_data = (
|
||||
len(public_key_bytes).to_bytes(4, byteorder="big")
|
||||
+ public_key_bytes
|
||||
@ -112,7 +132,7 @@ class LibP2PExtensionHandler:
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def parse_signed_key_extension(extension_data: bytes) -> Tuple[PublicKey, bytes]:
|
||||
def parse_signed_key_extension(extension_data: bytes) -> tuple[PublicKey, bytes]:
|
||||
"""
|
||||
Parse the libp2p Public Key Extension to extract public key and signature.
|
||||
|
||||
@ -158,8 +178,6 @@ class LibP2PExtensionHandler:
|
||||
|
||||
signature = extension_data[offset : offset + signature_length]
|
||||
|
||||
# Deserialize the public key
|
||||
# This is a simplified approach - full implementation would handle all key types
|
||||
public_key = LibP2PKeyConverter.deserialize_public_key(public_key_bytes)
|
||||
|
||||
return public_key, signature
|
||||
@ -199,21 +217,20 @@ class LibP2PKeyConverter:
|
||||
@staticmethod
|
||||
def deserialize_public_key(key_bytes: bytes) -> PublicKey:
|
||||
"""
|
||||
Deserialize libp2p public key from bytes.
|
||||
Deserialize libp2p public key from protobuf bytes.
|
||||
|
||||
Args:
|
||||
key_bytes: Protobuf-serialized public key bytes
|
||||
|
||||
Returns:
|
||||
Deserialized PublicKey instance
|
||||
|
||||
This is a simplified implementation - full version would handle
|
||||
all libp2p key types and proper deserialization.
|
||||
"""
|
||||
# For now, assume Ed25519 keys (most common in libp2p)
|
||||
# Full implementation would detect key type from bytes
|
||||
try:
|
||||
return Ed25519PublicKey.deserialize(key_bytes)
|
||||
except Exception:
|
||||
# Fallback to other key types
|
||||
try:
|
||||
return Secp256k1PublicKey.deserialize(key_bytes)
|
||||
except Exception:
|
||||
raise QUICCertificateError("Unsupported key type in extension")
|
||||
# Use the official libp2p deserialization function
|
||||
return deserialize_public_key(key_bytes)
|
||||
except Exception as e:
|
||||
raise QUICCertificateError(f"Failed to deserialize public key: {e}") from e
|
||||
|
||||
|
||||
class CertificateGenerator:
|
||||
@ -222,7 +239,7 @@ class CertificateGenerator:
|
||||
Follows libp2p TLS specification for QUIC transport.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.extension_handler = LibP2PExtensionHandler()
|
||||
self.key_converter = LibP2PKeyConverter()
|
||||
|
||||
@ -234,6 +251,7 @@ class CertificateGenerator:
|
||||
) -> TLSConfig:
|
||||
"""
|
||||
Generate a TLS certificate with embedded libp2p peer identity.
|
||||
Fixed to use datetime objects for validity periods.
|
||||
|
||||
Args:
|
||||
libp2p_private_key: The libp2p identity private key
|
||||
@ -265,24 +283,31 @@ class CertificateGenerator:
|
||||
libp2p_private_key, cert_public_key_bytes
|
||||
)
|
||||
|
||||
# Set validity period
|
||||
now = time.time()
|
||||
not_before = time.gmtime(now - CERTIFICATE_NOT_BEFORE_BUFFER)
|
||||
not_after = time.gmtime(now + (validity_days * 24 * 3600))
|
||||
# Set validity period using datetime objects (FIXED)
|
||||
now = datetime.utcnow() # Use datetime instead of time.time()
|
||||
not_before = now - timedelta(seconds=CERTIFICATE_NOT_BEFORE_BUFFER)
|
||||
not_after = now + timedelta(days=validity_days)
|
||||
|
||||
# Build certificate
|
||||
# Generate serial number
|
||||
serial_number = int(now.timestamp()) # Convert datetime to timestamp
|
||||
|
||||
# Build certificate with proper datetime objects
|
||||
certificate = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(
|
||||
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, str(peer_id))])
|
||||
x509.Name(
|
||||
[x509.NameAttribute(NameOID.COMMON_NAME, peer_id.to_base58())] # type: ignore
|
||||
)
|
||||
)
|
||||
.issuer_name(
|
||||
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, str(peer_id))])
|
||||
x509.Name(
|
||||
[x509.NameAttribute(NameOID.COMMON_NAME, peer_id.to_base58())] # type: ignore
|
||||
)
|
||||
)
|
||||
.public_key(cert_public_key)
|
||||
.serial_number(int(now)) # Use timestamp as serial number
|
||||
.not_valid_before(time.struct_time(not_before))
|
||||
.not_valid_after(time.struct_time(not_after))
|
||||
.serial_number(serial_number)
|
||||
.not_valid_before(not_before)
|
||||
.not_valid_after(not_after)
|
||||
.add_extension(
|
||||
x509.UnrecognizedExtension(
|
||||
oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data
|
||||
@ -293,6 +318,7 @@ class CertificateGenerator:
|
||||
)
|
||||
|
||||
logger.info(f"Generated libp2p TLS certificate for peer {peer_id}")
|
||||
logger.debug(f"Certificate valid from {not_before} to {not_after}")
|
||||
|
||||
return TLSConfig(
|
||||
certificate=certificate, private_key=cert_private_key, peer_id=peer_id
|
||||
@ -308,11 +334,11 @@ class PeerAuthenticator:
|
||||
Validates both TLS certificate integrity and libp2p peer identity.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.extension_handler = LibP2PExtensionHandler()
|
||||
|
||||
def verify_peer_certificate(
|
||||
self, certificate: x509.Certificate, expected_peer_id: Optional[ID] = None
|
||||
self, certificate: x509.Certificate, expected_peer_id: ID | None = None
|
||||
) -> ID:
|
||||
"""
|
||||
Verify a peer's TLS certificate and extract/validate peer identity.
|
||||
@ -366,7 +392,8 @@ class PeerAuthenticator:
|
||||
# Verify against expected peer ID if provided
|
||||
if expected_peer_id and derived_peer_id != expected_peer_id:
|
||||
raise QUICPeerVerificationError(
|
||||
f"Peer ID mismatch: expected {expected_peer_id}, got {derived_peer_id}"
|
||||
f"Peer ID mismatch: expected {expected_peer_id}, "
|
||||
f"got {derived_peer_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Successfully verified peer certificate for {derived_peer_id}")
|
||||
@ -397,38 +424,46 @@ class QUICTLSConfigManager:
|
||||
libp2p_private_key, peer_id
|
||||
)
|
||||
|
||||
def create_server_config(self) -> dict:
|
||||
def create_server_config(
|
||||
self,
|
||||
) -> TSecurityConfig:
|
||||
"""
|
||||
Create aioquic server configuration with libp2p TLS settings.
|
||||
Returns cryptography objects instead of DER bytes.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for aioquic QuicConfiguration
|
||||
|
||||
"""
|
||||
return {
|
||||
"certificate": self.tls_config.get_certificate_der(),
|
||||
"private_key": self.tls_config.get_private_key_der(),
|
||||
"alpn_protocols": ["libp2p"], # Required ALPN protocol
|
||||
"verify_mode": True, # Require client certificates
|
||||
config: TSecurityConfig = {
|
||||
"certificate": self.tls_config.certificate,
|
||||
"private_key": self.tls_config.private_key,
|
||||
"certificate_chain": [],
|
||||
"alpn_protocols": ["libp2p"],
|
||||
"verify_mode": True,
|
||||
}
|
||||
return config
|
||||
|
||||
def create_client_config(self) -> dict:
|
||||
def create_client_config(self) -> TSecurityConfig:
|
||||
"""
|
||||
Create aioquic client configuration with libp2p TLS settings.
|
||||
Returns cryptography objects instead of DER bytes.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for aioquic QuicConfiguration
|
||||
|
||||
"""
|
||||
return {
|
||||
"certificate": self.tls_config.get_certificate_der(),
|
||||
"private_key": self.tls_config.get_private_key_der(),
|
||||
"alpn_protocols": ["libp2p"], # Required ALPN protocol
|
||||
"verify_mode": True, # Verify server certificate
|
||||
config: TSecurityConfig = {
|
||||
"certificate": self.tls_config.certificate,
|
||||
"private_key": self.tls_config.private_key,
|
||||
"certificate_chain": [],
|
||||
"alpn_protocols": ["libp2p"],
|
||||
"verify_mode": True,
|
||||
}
|
||||
return config
|
||||
|
||||
def verify_peer_identity(
|
||||
self, peer_certificate: x509.Certificate, expected_peer_id: Optional[ID] = None
|
||||
self, peer_certificate: x509.Certificate, expected_peer_id: ID | None = None
|
||||
) -> ID:
|
||||
"""
|
||||
Verify remote peer's identity from their TLS certificate.
|
||||
|
||||
@ -5,6 +5,7 @@ Based on aioquic library with interface consistency to go-libp2p and js-libp2p.
|
||||
Updated to include Module 5 security integration.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
import copy
|
||||
import logging
|
||||
|
||||
@ -16,7 +17,6 @@ from aioquic.quic.connection import (
|
||||
)
|
||||
import multiaddr
|
||||
import trio
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from libp2p.abc import (
|
||||
IRawConnection,
|
||||
@ -29,13 +29,13 @@ from libp2p.custom_types import THandler, TProtocol
|
||||
from libp2p.peer.id import (
|
||||
ID,
|
||||
)
|
||||
from libp2p.transport.quic.config import QUICTransportKwargs
|
||||
from libp2p.transport.quic.security import TSecurityConfig
|
||||
from libp2p.transport.quic.utils import (
|
||||
get_alpn_protocols,
|
||||
is_quic_multiaddr,
|
||||
multiaddr_to_quic_version,
|
||||
quic_multiaddr_to_endpoint,
|
||||
quic_version_to_wire_format,
|
||||
get_alpn_protocols,
|
||||
)
|
||||
|
||||
from .config import (
|
||||
@ -111,7 +111,7 @@ class QUICTransport(ITransport):
|
||||
)
|
||||
|
||||
def _setup_quic_configurations(self) -> None:
|
||||
"""Setup QUIC configurations for supported protocol versions with TLS security."""
|
||||
"""Setup QUIC configurations."""
|
||||
try:
|
||||
# Get TLS configuration from security manager
|
||||
server_tls_config = self._security_manager.create_server_config()
|
||||
@ -140,12 +140,12 @@ class QUICTransport(ITransport):
|
||||
self._apply_tls_configuration(base_client_config, client_tls_config)
|
||||
|
||||
# QUIC v1 (RFC 9000) configurations
|
||||
quic_v1_server_config = copy.deepcopy(base_server_config)
|
||||
quic_v1_server_config = copy.copy(base_server_config)
|
||||
quic_v1_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
|
||||
quic_v1_client_config = copy.deepcopy(base_client_config)
|
||||
quic_v1_client_config = copy.copy(base_client_config)
|
||||
quic_v1_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_V1_PROTOCOL)
|
||||
]
|
||||
@ -160,12 +160,12 @@ class QUICTransport(ITransport):
|
||||
|
||||
# QUIC draft-29 configurations for compatibility
|
||||
if self._config.enable_draft29:
|
||||
draft29_server_config = copy.deepcopy(base_server_config)
|
||||
draft29_server_config: QuicConfiguration = copy.copy(base_server_config)
|
||||
draft29_server_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
|
||||
draft29_client_config = copy.deepcopy(base_client_config)
|
||||
draft29_client_config = copy.copy(base_client_config)
|
||||
draft29_client_config.supported_versions = [
|
||||
quic_version_to_wire_format(QUIC_DRAFT29_PROTOCOL)
|
||||
]
|
||||
@ -185,10 +185,10 @@ class QUICTransport(ITransport):
|
||||
) from e
|
||||
|
||||
def _apply_tls_configuration(
|
||||
self, config: QuicConfiguration, tls_config: dict
|
||||
self, config: QuicConfiguration, tls_config: TSecurityConfig
|
||||
) -> None:
|
||||
"""
|
||||
Apply TLS configuration to QuicConfiguration.
|
||||
Apply TLS configuration to a QUIC configuration using aioquic's actual API.
|
||||
|
||||
Args:
|
||||
config: QuicConfiguration to update
|
||||
@ -196,22 +196,54 @@ class QUICTransport(ITransport):
|
||||
|
||||
"""
|
||||
try:
|
||||
# Set certificate and private key
|
||||
# Set certificate and private key directly on the configuration
|
||||
# aioquic expects cryptography objects, not DER bytes
|
||||
if "certificate" in tls_config and "private_key" in tls_config:
|
||||
# aioquic expects certificate and private key in specific formats
|
||||
# This is a simplified approach - full implementation would handle
|
||||
# proper certificate chain setup
|
||||
config.load_cert_chain_from_der(
|
||||
tls_config["certificate"], tls_config["private_key"]
|
||||
)
|
||||
# The security manager should return cryptography objects
|
||||
# not DER bytes, but if it returns DER bytes, we need to handle that
|
||||
certificate = tls_config["certificate"]
|
||||
private_key = tls_config["private_key"]
|
||||
|
||||
# Check if we received DER bytes and need
|
||||
# to convert to cryptography objects
|
||||
if isinstance(certificate, bytes):
|
||||
from cryptography import x509
|
||||
|
||||
certificate = x509.load_der_x509_certificate(certificate)
|
||||
|
||||
if isinstance(private_key, bytes):
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
private_key = serialization.load_der_private_key( # type: ignore
|
||||
private_key, password=None
|
||||
)
|
||||
|
||||
# Set directly on the configuration object
|
||||
config.certificate = certificate
|
||||
config.private_key = private_key
|
||||
|
||||
# Handle certificate chain if provided
|
||||
certificate_chain = tls_config.get("certificate_chain", [])
|
||||
if certificate_chain and isinstance(certificate_chain, Iterable):
|
||||
# Convert DER bytes to cryptography objects if needed
|
||||
chain_objects = []
|
||||
for cert in certificate_chain:
|
||||
if isinstance(cert, bytes):
|
||||
from cryptography import x509
|
||||
|
||||
cert = x509.load_der_x509_certificate(cert)
|
||||
chain_objects.append(cert)
|
||||
config.certificate_chain = chain_objects
|
||||
|
||||
# Set ALPN protocols
|
||||
if "alpn_protocols" in tls_config:
|
||||
config.alpn_protocols = tls_config["alpn_protocols"]
|
||||
config.alpn_protocols = tls_config["alpn_protocols"] # type: ignore
|
||||
|
||||
# Set certificate verification
|
||||
# Set certificate verification mode
|
||||
if "verify_mode" in tls_config:
|
||||
config.verify_mode = tls_config["verify_mode"]
|
||||
config.verify_mode = tls_config["verify_mode"] # type: ignore
|
||||
|
||||
logger.debug("Successfully applied TLS configuration to QUIC config")
|
||||
|
||||
except Exception as e:
|
||||
raise QUICSecurityError(f"Failed to apply TLS configuration: {e}") from e
|
||||
@ -301,6 +333,7 @@ class QUICTransport(ITransport):
|
||||
|
||||
Raises:
|
||||
QUICSecurityError: If peer verification fails
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get peer certificate from the connection
|
||||
@ -316,7 +349,8 @@ class QUICTransport(ITransport):
|
||||
|
||||
if verified_peer_id != expected_peer_id:
|
||||
raise QUICSecurityError(
|
||||
f"Peer ID verification failed: expected {expected_peer_id}, got {verified_peer_id}"
|
||||
"Peer ID verification failed: expected "
|
||||
f"{expected_peer_id}, got {verified_peer_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Peer identity verified: {verified_peer_id}")
|
||||
@ -437,5 +471,6 @@ class QUICTransport(ITransport):
|
||||
|
||||
Returns:
|
||||
The QUIC TLS configuration manager
|
||||
|
||||
"""
|
||||
return self._security_manager
|
||||
|
||||
@ -184,7 +184,8 @@ def create_quic_multiaddr(
|
||||
if version == "quic-v1" or version == "/quic-v1":
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
elif version == "quic" or version == "/quic":
|
||||
quic_proto = QUIC_DRAFT29_PROTOCOL
|
||||
# This is DRAFT Protocol
|
||||
quic_proto = QUIC_V1_PROTOCOL
|
||||
else:
|
||||
raise QUICInvalidMultiaddrError(f"Invalid QUIC version: {version}")
|
||||
|
||||
|
||||
@ -36,8 +36,8 @@ class MockResourceScope:
|
||||
self.memory_reserved = max(0, self.memory_reserved - size)
|
||||
|
||||
|
||||
class TestQUICConnectionEnhanced:
|
||||
"""Enhanced test suite for QUIC connection functionality."""
|
||||
class TestQUICConnection:
|
||||
"""Test suite for QUIC connection functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
@ -58,10 +58,13 @@ class TestQUICConnectionEnhanced:
|
||||
return MockResourceScope()
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(self, mock_quic_connection, mock_resource_scope):
|
||||
def quic_connection(
|
||||
self, mock_quic_connection: Mock, mock_resource_scope: MockResourceScope
|
||||
):
|
||||
"""Create test QUIC connection with enhanced features."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
mock_security_manager = Mock()
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
@ -72,6 +75,7 @@ class TestQUICConnectionEnhanced:
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
resource_scope=mock_resource_scope,
|
||||
security_manager=mock_security_manager,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
@ -267,7 +271,9 @@ class TestQUICConnectionEnhanced:
|
||||
await quic_connection.start()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_connect_with_nursery(self, quic_connection):
|
||||
async def test_connection_connect_with_nursery(
|
||||
self, quic_connection: QUICConnection
|
||||
):
|
||||
"""Test connection establishment with nursery."""
|
||||
quic_connection._started = True
|
||||
quic_connection._established = True
|
||||
@ -277,7 +283,9 @@ class TestQUICConnectionEnhanced:
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
) as mock_start_tasks:
|
||||
with patch.object(
|
||||
quic_connection, "verify_peer_identity", new_callable=AsyncMock
|
||||
quic_connection,
|
||||
"_verify_peer_identity_with_security",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_verify:
|
||||
async with trio.open_nursery() as nursery:
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
@ -66,7 +66,8 @@ Focused tests covering essential functionality required for QUIC transport.
|
||||
|
||||
# for addr_str in invalid_addrs:
|
||||
# maddr = Multiaddr(addr_str)
|
||||
# assert not is_quic_multiaddr(maddr), f"Should not detect {addr_str} as QUIC"
|
||||
# assert not is_quic_multiaddr(maddr),
|
||||
# f"Should not detect {addr_str} as QUIC"
|
||||
|
||||
# def test_malformed_multiaddrs(self):
|
||||
# """Test malformed multiaddrs don't crash."""
|
||||
|
||||
Reference in New Issue
Block a user