mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-12 16:10:57 +00:00
fix: duplication connection creation for same sessions
This commit is contained in:
@ -4,9 +4,11 @@ Implements libp2p TLS specification for QUIC transport with peer identity integr
|
||||
Based on go-libp2p and js-libp2p security patterns.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import ssl
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
@ -25,11 +27,6 @@ from .exceptions import (
|
||||
QUICPeerVerificationError,
|
||||
)
|
||||
|
||||
TSecurityConfig = dict[
|
||||
str,
|
||||
Certificate | EllipticCurvePrivateKey | RSAPrivateKey | bool | list[str],
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# libp2p TLS Extension OID - Official libp2p specification
|
||||
@ -312,7 +309,7 @@ class CertificateGenerator:
|
||||
x509.UnrecognizedExtension(
|
||||
oid=LIBP2P_TLS_EXTENSION_OID, value=extension_data
|
||||
),
|
||||
critical=True, # This extension is critical for libp2p
|
||||
critical=False,
|
||||
)
|
||||
.sign(cert_private_key, hashes.SHA256())
|
||||
)
|
||||
@ -407,6 +404,269 @@ class PeerAuthenticator:
|
||||
) from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class QUICTLSSecurityConfig:
|
||||
"""
|
||||
Type-safe TLS security configuration for QUIC transport.
|
||||
"""
|
||||
|
||||
# Core TLS components (required)
|
||||
certificate: Certificate
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey]
|
||||
|
||||
# Certificate chain (optional)
|
||||
certificate_chain: List[Certificate] = field(default_factory=list)
|
||||
|
||||
# ALPN protocols
|
||||
alpn_protocols: List[str] = field(default_factory=lambda: ["libp2p"])
|
||||
|
||||
# TLS verification settings
|
||||
verify_mode: Union[bool, ssl.VerifyMode] = False
|
||||
check_hostname: bool = False
|
||||
|
||||
# Optional peer ID for validation
|
||||
peer_id: Optional[ID] = None
|
||||
|
||||
# Configuration metadata
|
||||
is_client_config: bool = False
|
||||
config_name: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate configuration after initialization."""
|
||||
self._validate()
|
||||
|
||||
def _validate(self) -> None:
|
||||
"""Validate the TLS configuration."""
|
||||
if self.certificate is None:
|
||||
raise ValueError("Certificate is required")
|
||||
|
||||
if self.private_key is None:
|
||||
raise ValueError("Private key is required")
|
||||
|
||||
if not isinstance(self.certificate, x509.Certificate):
|
||||
raise TypeError(
|
||||
f"Certificate must be x509.Certificate, got {type(self.certificate)}"
|
||||
)
|
||||
|
||||
if not isinstance(
|
||||
self.private_key, (ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey)
|
||||
):
|
||||
raise TypeError(
|
||||
f"Private key must be EC or RSA key, got {type(self.private_key)}"
|
||||
)
|
||||
|
||||
if not self.alpn_protocols:
|
||||
raise ValueError("At least one ALPN protocol is required")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convert to dictionary format for compatibility with existing code.
|
||||
|
||||
Returns:
|
||||
Dictionary compatible with the original TSecurityConfig format
|
||||
|
||||
"""
|
||||
return {
|
||||
"certificate": self.certificate,
|
||||
"private_key": self.private_key,
|
||||
"certificate_chain": self.certificate_chain.copy(),
|
||||
"alpn_protocols": self.alpn_protocols.copy(),
|
||||
"verify_mode": self.verify_mode,
|
||||
"check_hostname": self.check_hostname,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict, **kwargs) -> "QUICTLSSecurityConfig":
|
||||
"""
|
||||
Create instance from dictionary format.
|
||||
|
||||
Args:
|
||||
config_dict: Dictionary in TSecurityConfig format
|
||||
**kwargs: Additional parameters for the config
|
||||
|
||||
Returns:
|
||||
QUICTLSSecurityConfig instance
|
||||
|
||||
"""
|
||||
return cls(
|
||||
certificate=config_dict["certificate"],
|
||||
private_key=config_dict["private_key"],
|
||||
certificate_chain=config_dict.get("certificate_chain", []),
|
||||
alpn_protocols=config_dict.get("alpn_protocols", ["libp2p"]),
|
||||
verify_mode=config_dict.get("verify_mode", False),
|
||||
check_hostname=config_dict.get("check_hostname", False),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def validate_certificate_key_match(self) -> bool:
|
||||
"""
|
||||
Validate that the certificate and private key match.
|
||||
|
||||
Returns:
|
||||
True if certificate and private key match
|
||||
|
||||
"""
|
||||
try:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
|
||||
# Get public keys from both certificate and private key
|
||||
cert_public_key = self.certificate.public_key()
|
||||
private_public_key = self.private_key.public_key()
|
||||
|
||||
# Compare their PEM representations
|
||||
cert_pub_pem = cert_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
private_pub_pem = private_public_key.public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
return cert_pub_pem == private_pub_pem
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def has_libp2p_extension(self) -> bool:
|
||||
"""
|
||||
Check if the certificate has the required libp2p extension.
|
||||
|
||||
Returns:
|
||||
True if libp2p extension is present
|
||||
|
||||
"""
|
||||
try:
|
||||
libp2p_oid = "1.3.6.1.4.1.53594.1.1"
|
||||
for ext in self.certificate.extensions:
|
||||
if str(ext.oid) == libp2p_oid:
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def is_certificate_valid(self) -> bool:
|
||||
"""
|
||||
Check if the certificate is currently valid (not expired).
|
||||
|
||||
Returns:
|
||||
True if certificate is valid
|
||||
|
||||
"""
|
||||
try:
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.utcnow()
|
||||
return (
|
||||
self.certificate.not_valid_before
|
||||
<= now
|
||||
<= self.certificate.not_valid_after
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_certificate_info(self) -> dict:
|
||||
"""
|
||||
Get certificate information for debugging.
|
||||
|
||||
Returns:
|
||||
Dictionary with certificate details
|
||||
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
"subject": str(self.certificate.subject),
|
||||
"issuer": str(self.certificate.issuer),
|
||||
"serial_number": self.certificate.serial_number,
|
||||
"not_valid_before": self.certificate.not_valid_before,
|
||||
"not_valid_after": self.certificate.not_valid_after,
|
||||
"has_libp2p_extension": self.has_libp2p_extension(),
|
||||
"is_valid": self.is_certificate_valid(),
|
||||
"certificate_key_match": self.validate_certificate_key_match(),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def debug_print(self) -> None:
|
||||
"""Print debugging information about this configuration."""
|
||||
print(f"=== TLS Security Config Debug ({self.config_name or 'unnamed'}) ===")
|
||||
print(f"Is client config: {self.is_client_config}")
|
||||
print(f"ALPN protocols: {self.alpn_protocols}")
|
||||
print(f"Verify mode: {self.verify_mode}")
|
||||
print(f"Check hostname: {self.check_hostname}")
|
||||
print(f"Certificate chain length: {len(self.certificate_chain)}")
|
||||
|
||||
cert_info = self.get_certificate_info()
|
||||
for key, value in cert_info.items():
|
||||
print(f"Certificate {key}: {value}")
|
||||
|
||||
print(f"Private key type: {type(self.private_key).__name__}")
|
||||
if hasattr(self.private_key, "key_size"):
|
||||
print(f"Private key size: {self.private_key.key_size}")
|
||||
|
||||
|
||||
def create_server_tls_config(
|
||||
certificate: Certificate,
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||
peer_id: Optional[ID] = None,
|
||||
**kwargs,
|
||||
) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create a server TLS configuration.
|
||||
|
||||
Args:
|
||||
certificate: X.509 certificate
|
||||
private_key: Private key corresponding to certificate
|
||||
peer_id: Optional peer ID for validation
|
||||
**kwargs: Additional configuration parameters
|
||||
|
||||
Returns:
|
||||
Server TLS configuration
|
||||
|
||||
"""
|
||||
return QUICTLSSecurityConfig(
|
||||
certificate=certificate,
|
||||
private_key=private_key,
|
||||
peer_id=peer_id,
|
||||
is_client_config=False,
|
||||
config_name="server",
|
||||
verify_mode=False, # Server doesn't verify client certs in libp2p
|
||||
check_hostname=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def create_client_tls_config(
|
||||
certificate: Certificate,
|
||||
private_key: Union[EllipticCurvePrivateKey, RSAPrivateKey],
|
||||
peer_id: Optional[ID] = None,
|
||||
**kwargs,
|
||||
) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create a client TLS configuration.
|
||||
|
||||
Args:
|
||||
certificate: X.509 certificate
|
||||
private_key: Private key corresponding to certificate
|
||||
peer_id: Optional peer ID for validation
|
||||
**kwargs: Additional configuration parameters
|
||||
|
||||
Returns:
|
||||
Client TLS configuration
|
||||
|
||||
"""
|
||||
return QUICTLSSecurityConfig(
|
||||
certificate=certificate,
|
||||
private_key=private_key,
|
||||
peer_id=peer_id,
|
||||
is_client_config=True,
|
||||
config_name="client",
|
||||
verify_mode=False, # Client doesn't verify server certs in libp2p
|
||||
check_hostname=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class QUICTLSConfigManager:
|
||||
"""
|
||||
Manages TLS configuration for QUIC transport with libp2p security.
|
||||
@ -424,44 +684,40 @@ class QUICTLSConfigManager:
|
||||
libp2p_private_key, peer_id
|
||||
)
|
||||
|
||||
def create_server_config(
|
||||
self,
|
||||
) -> TSecurityConfig:
|
||||
def create_server_config(self) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create aioquic server configuration with libp2p TLS settings.
|
||||
Returns cryptography objects instead of DER bytes.
|
||||
Create server configuration using the new class-based approach.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for aioquic QuicConfiguration
|
||||
QUICTLSSecurityConfig instance for server
|
||||
|
||||
"""
|
||||
config: TSecurityConfig = {
|
||||
"certificate": self.tls_config.certificate,
|
||||
"private_key": self.tls_config.private_key,
|
||||
"certificate_chain": [],
|
||||
"alpn_protocols": ["libp2p"],
|
||||
"verify_mode": False,
|
||||
"check_hostname": False,
|
||||
}
|
||||
config = create_server_tls_config(
|
||||
certificate=self.tls_config.certificate,
|
||||
private_key=self.tls_config.private_key,
|
||||
peer_id=self.peer_id,
|
||||
)
|
||||
|
||||
print("🔧 SECURITY: Created server config")
|
||||
config.debug_print()
|
||||
return config
|
||||
|
||||
def create_client_config(self) -> TSecurityConfig:
|
||||
def create_client_config(self) -> QUICTLSSecurityConfig:
|
||||
"""
|
||||
Create aioquic client configuration with libp2p TLS settings.
|
||||
Returns cryptography objects instead of DER bytes.
|
||||
Create client configuration using the new class-based approach.
|
||||
|
||||
Returns:
|
||||
Configuration dictionary for aioquic QuicConfiguration
|
||||
QUICTLSSecurityConfig instance for client
|
||||
|
||||
"""
|
||||
config: TSecurityConfig = {
|
||||
"certificate": self.tls_config.certificate,
|
||||
"private_key": self.tls_config.private_key,
|
||||
"certificate_chain": [],
|
||||
"alpn_protocols": ["libp2p"],
|
||||
"verify_mode": False,
|
||||
"check_hostname": False,
|
||||
}
|
||||
config = create_client_tls_config(
|
||||
certificate=self.tls_config.certificate,
|
||||
private_key=self.tls_config.private_key,
|
||||
peer_id=self.peer_id,
|
||||
)
|
||||
|
||||
print("🔧 SECURITY: Created client config")
|
||||
config.debug_print()
|
||||
return config
|
||||
|
||||
def verify_peer_identity(
|
||||
|
||||
Reference in New Issue
Block a user