mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2026-02-11 07:30:55 +00:00
fix: add basic quic stream and associated tests
This commit is contained in:
@ -7,7 +7,7 @@ from dataclasses import (
|
|||||||
field,
|
field,
|
||||||
)
|
)
|
||||||
import ssl
|
import ssl
|
||||||
from typing import TypedDict
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from libp2p.custom_types import TProtocol
|
from libp2p.custom_types import TProtocol
|
||||||
|
|
||||||
@ -76,6 +76,101 @@ class QUICTransportConfig:
|
|||||||
max_connections: int = 1000 # Maximum number of connections
|
max_connections: int = 1000 # Maximum number of connections
|
||||||
connection_timeout: float = 10.0 # Connection establishment timeout
|
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||||
|
|
||||||
|
MAX_CONCURRENT_STREAMS: int = 1000
|
||||||
|
"""Maximum number of concurrent streams per connection."""
|
||||||
|
|
||||||
|
MAX_INCOMING_STREAMS: int = 1000
|
||||||
|
"""Maximum number of incoming streams per connection."""
|
||||||
|
|
||||||
|
MAX_OUTGOING_STREAMS: int = 1000
|
||||||
|
"""Maximum number of outgoing streams per connection."""
|
||||||
|
|
||||||
|
# Stream timeouts
|
||||||
|
STREAM_OPEN_TIMEOUT: float = 5.0
|
||||||
|
"""Timeout for opening new streams (seconds)."""
|
||||||
|
|
||||||
|
STREAM_ACCEPT_TIMEOUT: float = 30.0
|
||||||
|
"""Timeout for accepting incoming streams (seconds)."""
|
||||||
|
|
||||||
|
STREAM_READ_TIMEOUT: float = 30.0
|
||||||
|
"""Default timeout for stream read operations (seconds)."""
|
||||||
|
|
||||||
|
STREAM_WRITE_TIMEOUT: float = 30.0
|
||||||
|
"""Default timeout for stream write operations (seconds)."""
|
||||||
|
|
||||||
|
STREAM_CLOSE_TIMEOUT: float = 10.0
|
||||||
|
"""Timeout for graceful stream close (seconds)."""
|
||||||
|
|
||||||
|
# Flow control configuration
|
||||||
|
STREAM_FLOW_CONTROL_WINDOW: int = 512 * 1024 # 512KB
|
||||||
|
"""Per-stream flow control window size."""
|
||||||
|
|
||||||
|
CONNECTION_FLOW_CONTROL_WINDOW: int = 768 * 1024 # 768KB
|
||||||
|
"""Connection-wide flow control window size."""
|
||||||
|
|
||||||
|
# Buffer management
|
||||||
|
MAX_STREAM_RECEIVE_BUFFER: int = 1024 * 1024 # 1MB
|
||||||
|
"""Maximum receive buffer size per stream."""
|
||||||
|
|
||||||
|
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
|
||||||
|
"""Low watermark for stream receive buffer."""
|
||||||
|
|
||||||
|
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
|
||||||
|
"""High watermark for stream receive buffer."""
|
||||||
|
|
||||||
|
# Stream lifecycle configuration
|
||||||
|
ENABLE_STREAM_RESET_ON_ERROR: bool = True
|
||||||
|
"""Whether to automatically reset streams on errors."""
|
||||||
|
|
||||||
|
STREAM_RESET_ERROR_CODE: int = 1
|
||||||
|
"""Default error code for stream resets."""
|
||||||
|
|
||||||
|
ENABLE_STREAM_KEEP_ALIVE: bool = False
|
||||||
|
"""Whether to enable stream keep-alive mechanisms."""
|
||||||
|
|
||||||
|
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
|
||||||
|
"""Interval for stream keep-alive pings (seconds)."""
|
||||||
|
|
||||||
|
# Resource management
|
||||||
|
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
|
||||||
|
"""Whether to track stream resource usage."""
|
||||||
|
|
||||||
|
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
|
||||||
|
"""Memory limit per individual stream."""
|
||||||
|
|
||||||
|
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
|
||||||
|
"""Total memory limit for all streams per connection."""
|
||||||
|
|
||||||
|
# Concurrency and performance
|
||||||
|
ENABLE_STREAM_BATCHING: bool = True
|
||||||
|
"""Whether to batch multiple stream operations."""
|
||||||
|
|
||||||
|
STREAM_BATCH_SIZE: int = 10
|
||||||
|
"""Number of streams to process in a batch."""
|
||||||
|
|
||||||
|
STREAM_PROCESSING_CONCURRENCY: int = 100
|
||||||
|
"""Maximum concurrent stream processing tasks."""
|
||||||
|
|
||||||
|
# Debugging and monitoring
|
||||||
|
ENABLE_STREAM_METRICS: bool = True
|
||||||
|
"""Whether to collect stream metrics."""
|
||||||
|
|
||||||
|
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
|
||||||
|
"""Whether to track stream lifecycle timelines."""
|
||||||
|
|
||||||
|
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
|
||||||
|
"""Interval for collecting stream metrics (seconds)."""
|
||||||
|
|
||||||
|
# Error handling configuration
|
||||||
|
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
|
||||||
|
"""Number of retry attempts for recoverable stream errors."""
|
||||||
|
|
||||||
|
STREAM_ERROR_RETRY_DELAY: float = 1.0
|
||||||
|
"""Initial delay between stream error retries (seconds)."""
|
||||||
|
|
||||||
|
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
|
||||||
|
"""Backoff factor for stream error retries."""
|
||||||
|
|
||||||
# Protocol identifiers matching go-libp2p
|
# Protocol identifiers matching go-libp2p
|
||||||
# TODO: UNTIL MUITIADDR REPO IS UPDATED
|
# TODO: UNTIL MUITIADDR REPO IS UPDATED
|
||||||
# PROTOCOL_QUIC_V1: TProtocol = TProtocol("/quic-v1") # RFC 9000
|
# PROTOCOL_QUIC_V1: TProtocol = TProtocol("/quic-v1") # RFC 9000
|
||||||
@ -92,3 +187,167 @@ class QUICTransportConfig:
|
|||||||
|
|
||||||
if self.max_datagram_size < 1200:
|
if self.max_datagram_size < 1200:
|
||||||
raise ValueError("Max datagram size must be at least 1200 bytes")
|
raise ValueError("Max datagram size must be at least 1200 bytes")
|
||||||
|
|
||||||
|
# Validate timeouts
|
||||||
|
timeout_fields = [
|
||||||
|
"STREAM_OPEN_TIMEOUT",
|
||||||
|
"STREAM_ACCEPT_TIMEOUT",
|
||||||
|
"STREAM_READ_TIMEOUT",
|
||||||
|
"STREAM_WRITE_TIMEOUT",
|
||||||
|
"STREAM_CLOSE_TIMEOUT",
|
||||||
|
]
|
||||||
|
for timeout_field in timeout_fields:
|
||||||
|
if getattr(self, timeout_field) <= 0:
|
||||||
|
raise ValueError(f"{timeout_field} must be positive")
|
||||||
|
|
||||||
|
# Validate flow control windows
|
||||||
|
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
|
||||||
|
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
|
||||||
|
|
||||||
|
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
|
||||||
|
raise ValueError(
|
||||||
|
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate buffer sizes
|
||||||
|
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
|
||||||
|
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
|
||||||
|
|
||||||
|
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
|
||||||
|
raise ValueError(
|
||||||
|
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
|
||||||
|
"exceed MAX_STREAM_RECEIVE_BUFFER"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
|
||||||
|
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate memory limits
|
||||||
|
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
|
||||||
|
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
|
||||||
|
|
||||||
|
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
|
||||||
|
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
|
||||||
|
|
||||||
|
expected_stream_memory = (
|
||||||
|
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
|
||||||
|
)
|
||||||
|
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
|
||||||
|
# Allow some headroom, but warn if configuration seems inconsistent
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.warning(
|
||||||
|
"Stream memory configuration may be inconsistent: "
|
||||||
|
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
|
||||||
|
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
|
||||||
|
"could exceed connection limit of"
|
||||||
|
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_stream_config_dict(self) -> dict[str, Any]:
|
||||||
|
"""Get stream-specific configuration as dictionary."""
|
||||||
|
stream_config = {}
|
||||||
|
for attr_name in dir(self):
|
||||||
|
if attr_name.startswith(
|
||||||
|
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
|
||||||
|
):
|
||||||
|
stream_config[attr_name.lower()] = getattr(self, attr_name)
|
||||||
|
return stream_config
|
||||||
|
|
||||||
|
|
||||||
|
# Additional configuration classes for specific stream features
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamFlowControlConfig:
|
||||||
|
"""Configuration for QUIC stream flow control."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_window_size: int = 512 * 1024,
|
||||||
|
max_window_size: int = 2 * 1024 * 1024,
|
||||||
|
window_update_threshold: float = 0.5,
|
||||||
|
enable_auto_tuning: bool = True,
|
||||||
|
):
|
||||||
|
self.initial_window_size = initial_window_size
|
||||||
|
self.max_window_size = max_window_size
|
||||||
|
self.window_update_threshold = window_update_threshold
|
||||||
|
self.enable_auto_tuning = enable_auto_tuning
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamMetricsConfig:
|
||||||
|
"""Configuration for QUIC stream metrics collection."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enable_latency_tracking: bool = True,
|
||||||
|
enable_throughput_tracking: bool = True,
|
||||||
|
enable_error_tracking: bool = True,
|
||||||
|
metrics_retention_duration: float = 3600.0, # 1 hour
|
||||||
|
metrics_aggregation_interval: float = 60.0, # 1 minute
|
||||||
|
):
|
||||||
|
self.enable_latency_tracking = enable_latency_tracking
|
||||||
|
self.enable_throughput_tracking = enable_throughput_tracking
|
||||||
|
self.enable_error_tracking = enable_error_tracking
|
||||||
|
self.metrics_retention_duration = metrics_retention_duration
|
||||||
|
self.metrics_aggregation_interval = metrics_aggregation_interval
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function for creating optimized configurations
|
||||||
|
|
||||||
|
|
||||||
|
def create_stream_config_for_use_case(use_case: str) -> QUICTransportConfig:
|
||||||
|
"""
|
||||||
|
Create optimized stream configuration for specific use cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_case: One of "high_throughput", "low_latency", "many_streams","
|
||||||
|
"memory_constrained"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optimized QUICTransportConfig
|
||||||
|
|
||||||
|
"""
|
||||||
|
base_config = QUICTransportConfig()
|
||||||
|
|
||||||
|
if use_case == "high_throughput":
|
||||||
|
# Optimize for high throughput
|
||||||
|
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
|
||||||
|
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
|
||||||
|
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
|
||||||
|
base_config.STREAM_PROCESSING_CONCURRENCY = 200
|
||||||
|
|
||||||
|
elif use_case == "low_latency":
|
||||||
|
# Optimize for low latency
|
||||||
|
base_config.STREAM_OPEN_TIMEOUT = 1.0
|
||||||
|
base_config.STREAM_READ_TIMEOUT = 5.0
|
||||||
|
base_config.STREAM_WRITE_TIMEOUT = 5.0
|
||||||
|
base_config.ENABLE_STREAM_BATCHING = False
|
||||||
|
base_config.STREAM_BATCH_SIZE = 1
|
||||||
|
|
||||||
|
elif use_case == "many_streams":
|
||||||
|
# Optimize for many concurrent streams
|
||||||
|
base_config.MAX_CONCURRENT_STREAMS = 5000
|
||||||
|
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
|
||||||
|
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
|
||||||
|
base_config.STREAM_PROCESSING_CONCURRENCY = 500
|
||||||
|
|
||||||
|
elif use_case == "memory_constrained":
|
||||||
|
# Optimize for low memory usage
|
||||||
|
base_config.MAX_CONCURRENT_STREAMS = 100
|
||||||
|
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
|
||||||
|
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
|
||||||
|
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
|
||||||
|
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
|
||||||
|
base_config.STREAM_PROCESSING_CONCURRENCY = 50
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown use case: {use_case}")
|
||||||
|
|
||||||
|
return base_config
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,35 +1,393 @@
|
|||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
"""
|
"""
|
||||||
QUIC transport specific exceptions.
|
QUIC Transport exceptions for py-libp2p.
|
||||||
|
Comprehensive error handling for QUIC transport, connection, and stream operations.
|
||||||
|
Based on patterns from go-libp2p and js-libp2p implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from libp2p.exceptions import (
|
|
||||||
BaseLibp2pError,
|
class QUICError(Exception):
|
||||||
)
|
"""Base exception for all QUIC transport errors."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, error_code: int | None = None):
|
||||||
|
super().__init__(message)
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
|
||||||
class QUICError(BaseLibp2pError):
|
# Transport-level exceptions
|
||||||
"""Base exception for QUIC transport errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class QUICDialError(QUICError):
|
class QUICTransportError(QUICError):
|
||||||
"""Exception raised when QUIC dial operation fails."""
|
"""Base exception for QUIC transport operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class QUICListenError(QUICError):
|
class QUICDialError(QUICTransportError):
|
||||||
"""Exception raised when QUIC listen operation fails."""
|
"""Error occurred during QUIC connection establishment."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICListenError(QUICTransportError):
|
||||||
|
"""Error occurred during QUIC listener operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICSecurityError(QUICTransportError):
|
||||||
|
"""Error related to QUIC security/TLS operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Connection-level exceptions
|
||||||
|
|
||||||
|
|
||||||
class QUICConnectionError(QUICError):
|
class QUICConnectionError(QUICError):
|
||||||
"""Exception raised for QUIC connection errors."""
|
"""Base exception for QUIC connection operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionClosedError(QUICConnectionError):
|
||||||
|
"""QUIC connection has been closed."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionTimeoutError(QUICConnectionError):
|
||||||
|
"""QUIC connection operation timed out."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICHandshakeError(QUICConnectionError):
|
||||||
|
"""Error during QUIC handshake process."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICPeerVerificationError(QUICConnectionError):
|
||||||
|
"""Error verifying peer identity during handshake."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Stream-level exceptions
|
||||||
|
|
||||||
|
|
||||||
class QUICStreamError(QUICError):
|
class QUICStreamError(QUICError):
|
||||||
"""Exception raised for QUIC stream errors."""
|
"""Base exception for QUIC stream operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
error_code: int | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(message, error_code)
|
||||||
|
self.stream_id = stream_id
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamClosedError(QUICStreamError):
|
||||||
|
"""Stream is closed and cannot be used for I/O operations."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamResetError(QUICStreamError):
|
||||||
|
"""Stream was reset by local or remote peer."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
error_code: int | None = None,
|
||||||
|
reset_by_peer: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__(message, stream_id, error_code)
|
||||||
|
self.reset_by_peer = reset_by_peer
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamTimeoutError(QUICStreamError):
|
||||||
|
"""Stream operation timed out."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamBackpressureError(QUICStreamError):
|
||||||
|
"""Stream write blocked due to flow control."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamLimitError(QUICStreamError):
|
||||||
|
"""Stream limit reached (too many concurrent streams)."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICStreamStateError(QUICStreamError):
|
||||||
|
"""Invalid operation for current stream state."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
current_state: str | None = None,
|
||||||
|
attempted_operation: str | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(message, stream_id)
|
||||||
|
self.current_state = current_state
|
||||||
|
self.attempted_operation = attempted_operation
|
||||||
|
|
||||||
|
|
||||||
|
# Flow control exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICFlowControlError(QUICError):
|
||||||
|
"""Base exception for flow control related errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICFlowControlViolationError(QUICFlowControlError):
|
||||||
|
"""Flow control limits were violated."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICFlowControlDeadlockError(QUICFlowControlError):
|
||||||
|
"""Flow control deadlock detected."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Resource management exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICResourceError(QUICError):
|
||||||
|
"""Base exception for resource management errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICMemoryLimitError(QUICResourceError):
|
||||||
|
"""Memory limit exceeded."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICConnectionLimitError(QUICResourceError):
|
||||||
|
"""Connection limit exceeded."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Multiaddr and addressing exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class QUICAddressError(QUICError):
|
||||||
|
"""Base exception for QUIC addressing errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICInvalidMultiaddrError(QUICAddressError):
|
||||||
|
"""Invalid multiaddr format for QUIC transport."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICAddressResolutionError(QUICAddressError):
|
||||||
|
"""Failed to resolve QUIC address."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICProtocolError(QUICError):
|
||||||
|
"""Base exception for QUIC protocol errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICVersionNegotiationError(QUICProtocolError):
|
||||||
|
"""QUIC version negotiation failed."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICUnsupportedVersionError(QUICProtocolError):
|
||||||
|
"""Unsupported QUIC version."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration exceptions
|
||||||
|
|
||||||
|
|
||||||
class QUICConfigurationError(QUICError):
|
class QUICConfigurationError(QUICError):
|
||||||
"""Exception raised for QUIC configuration errors."""
|
"""Base exception for QUIC configuration errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class QUICSecurityError(QUICError):
|
class QUICInvalidConfigError(QUICConfigurationError):
|
||||||
"""Exception raised for QUIC security/TLS errors."""
|
"""Invalid QUIC configuration parameters."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class QUICCertificateError(QUICConfigurationError):
|
||||||
|
"""Error with TLS certificate configuration."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def map_quic_error_code(error_code: int) -> str:
|
||||||
|
"""
|
||||||
|
Map QUIC error codes to human-readable descriptions.
|
||||||
|
Based on RFC 9000 Transport Error Codes.
|
||||||
|
"""
|
||||||
|
error_codes = {
|
||||||
|
0x00: "NO_ERROR",
|
||||||
|
0x01: "INTERNAL_ERROR",
|
||||||
|
0x02: "CONNECTION_REFUSED",
|
||||||
|
0x03: "FLOW_CONTROL_ERROR",
|
||||||
|
0x04: "STREAM_LIMIT_ERROR",
|
||||||
|
0x05: "STREAM_STATE_ERROR",
|
||||||
|
0x06: "FINAL_SIZE_ERROR",
|
||||||
|
0x07: "FRAME_ENCODING_ERROR",
|
||||||
|
0x08: "TRANSPORT_PARAMETER_ERROR",
|
||||||
|
0x09: "CONNECTION_ID_LIMIT_ERROR",
|
||||||
|
0x0A: "PROTOCOL_VIOLATION",
|
||||||
|
0x0B: "INVALID_TOKEN",
|
||||||
|
0x0C: "APPLICATION_ERROR",
|
||||||
|
0x0D: "CRYPTO_BUFFER_EXCEEDED",
|
||||||
|
0x0E: "KEY_UPDATE_ERROR",
|
||||||
|
0x0F: "AEAD_LIMIT_REACHED",
|
||||||
|
0x10: "NO_VIABLE_PATH",
|
||||||
|
}
|
||||||
|
|
||||||
|
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_stream_error(
|
||||||
|
error_type: str,
|
||||||
|
message: str,
|
||||||
|
stream_id: str | None = None,
|
||||||
|
error_code: int | None = None,
|
||||||
|
) -> QUICStreamError:
|
||||||
|
"""
|
||||||
|
Factory function to create appropriate stream error based on type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
|
||||||
|
message: Error message
|
||||||
|
stream_id: Stream identifier
|
||||||
|
error_code: QUIC error code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Appropriate QUICStreamError subclass
|
||||||
|
|
||||||
|
"""
|
||||||
|
error_type = error_type.lower()
|
||||||
|
|
||||||
|
if error_type in ("closed", "close"):
|
||||||
|
return QUICStreamClosedError(message, stream_id, error_code)
|
||||||
|
elif error_type == "reset":
|
||||||
|
return QUICStreamResetError(message, stream_id, error_code)
|
||||||
|
elif error_type == "timeout":
|
||||||
|
return QUICStreamTimeoutError(message, stream_id, error_code)
|
||||||
|
elif error_type in ("backpressure", "flow_control"):
|
||||||
|
return QUICStreamBackpressureError(message, stream_id, error_code)
|
||||||
|
elif error_type in ("limit", "stream_limit"):
|
||||||
|
return QUICStreamLimitError(message, stream_id, error_code)
|
||||||
|
elif error_type == "state":
|
||||||
|
return QUICStreamStateError(message, stream_id)
|
||||||
|
else:
|
||||||
|
return QUICStreamError(message, stream_id, error_code)
|
||||||
|
|
||||||
|
|
||||||
|
def create_connection_error(
|
||||||
|
error_type: str, message: str, error_code: int | None = None
|
||||||
|
) -> QUICConnectionError:
|
||||||
|
"""
|
||||||
|
Factory function to create appropriate connection error based on type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_type: Type of error ("closed", "timeout", "handshake", etc.)
|
||||||
|
message: Error message
|
||||||
|
error_code: QUIC error code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Appropriate QUICConnectionError subclass
|
||||||
|
|
||||||
|
"""
|
||||||
|
error_type = error_type.lower()
|
||||||
|
|
||||||
|
if error_type in ("closed", "close"):
|
||||||
|
return QUICConnectionClosedError(message, error_code)
|
||||||
|
elif error_type == "timeout":
|
||||||
|
return QUICConnectionTimeoutError(message, error_code)
|
||||||
|
elif error_type == "handshake":
|
||||||
|
return QUICHandshakeError(message, error_code)
|
||||||
|
elif error_type in ("peer_verification", "verification"):
|
||||||
|
return QUICPeerVerificationError(message, error_code)
|
||||||
|
else:
|
||||||
|
return QUICConnectionError(message, error_code)
|
||||||
|
|
||||||
|
|
||||||
|
class QUICErrorContext:
|
||||||
|
"""
|
||||||
|
Context manager for handling QUIC errors with automatic error mapping.
|
||||||
|
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, operation: str, component: str = "quic") -> None:
|
||||||
|
self.operation = operation
|
||||||
|
self.component = component
|
||||||
|
|
||||||
|
def __enter__(self) -> "QUICErrorContext":
|
||||||
|
return self
|
||||||
|
|
||||||
|
# TODO: Fix types for exc_type
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: Any,
|
||||||
|
) -> Literal[False]:
|
||||||
|
if exc_type is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if exc_val is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Map common aioquic exceptions to our exceptions
|
||||||
|
if "ConnectionClosed" in str(exc_type):
|
||||||
|
raise QUICConnectionClosedError(
|
||||||
|
f"Connection closed during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
elif "StreamReset" in str(exc_type):
|
||||||
|
raise QUICStreamResetError(
|
||||||
|
f"Stream reset during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
elif "timeout" in str(exc_val).lower():
|
||||||
|
if "stream" in self.component.lower():
|
||||||
|
raise QUICStreamTimeoutError(
|
||||||
|
f"Timeout during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
else:
|
||||||
|
raise QUICConnectionTimeoutError(
|
||||||
|
f"Timeout during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
elif "flow control" in str(exc_val).lower():
|
||||||
|
raise QUICStreamBackpressureError(
|
||||||
|
f"Flow control error during {self.operation}: {exc_val}"
|
||||||
|
) from exc_val
|
||||||
|
|
||||||
|
# Let other exceptions propagate
|
||||||
|
return False
|
||||||
|
|||||||
@ -251,7 +251,7 @@ class QUICListener(IListener):
|
|||||||
connection._quic.receive_datagram(data, addr, now=time.time())
|
connection._quic.receive_datagram(data, addr, now=time.time())
|
||||||
|
|
||||||
# Process events and handle responses
|
# Process events and handle responses
|
||||||
await connection._process_events()
|
await connection._process_quic_events()
|
||||||
await connection._transmit()
|
await connection._transmit()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -386,8 +386,8 @@ class QUICListener(IListener):
|
|||||||
|
|
||||||
# Start connection management tasks
|
# Start connection management tasks
|
||||||
if self._nursery:
|
if self._nursery:
|
||||||
self._nursery.start_soon(connection._handle_incoming_data)
|
self._nursery.start_soon(connection._handle_datagram_received)
|
||||||
self._nursery.start_soon(connection._handle_timer)
|
self._nursery.start_soon(connection._handle_timer_events)
|
||||||
|
|
||||||
# TODO: Verify peer identity
|
# TODO: Verify peer identity
|
||||||
# await connection.verify_peer_identity()
|
# await connection.verify_peer_identity()
|
||||||
|
|||||||
@ -1,126 +1,583 @@
|
|||||||
"""
|
"""
|
||||||
QUIC Stream implementation
|
QUIC Stream implementation for py-libp2p Module 3.
|
||||||
|
Based on patterns from go-libp2p and js-libp2p QUIC implementations.
|
||||||
|
Uses aioquic's native stream capabilities with libp2p interface compliance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from types import (
|
from enum import Enum
|
||||||
TracebackType,
|
import logging
|
||||||
)
|
import time
|
||||||
from typing import TYPE_CHECKING, cast
|
from types import TracebackType
|
||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
import trio
|
import trio
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
QUICStreamBackpressureError,
|
||||||
|
QUICStreamClosedError,
|
||||||
|
QUICStreamResetError,
|
||||||
|
QUICStreamTimeoutError,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libp2p.abc import IMuxedStream
|
from libp2p.abc import IMuxedStream
|
||||||
|
from libp2p.custom_types import TProtocol
|
||||||
|
|
||||||
from .connection import QUICConnection
|
from .connection import QUICConnection
|
||||||
else:
|
else:
|
||||||
IMuxedStream = cast(type, object)
|
IMuxedStream = cast(type, object)
|
||||||
|
TProtocol = cast(type, object)
|
||||||
|
|
||||||
from .exceptions import (
|
logger = logging.getLogger(__name__)
|
||||||
QUICStreamError,
|
|
||||||
)
|
|
||||||
|
class StreamState(Enum):
|
||||||
|
"""Stream lifecycle states following libp2p patterns."""
|
||||||
|
|
||||||
|
OPEN = "open"
|
||||||
|
WRITE_CLOSED = "write_closed"
|
||||||
|
READ_CLOSED = "read_closed"
|
||||||
|
CLOSED = "closed"
|
||||||
|
RESET = "reset"
|
||||||
|
|
||||||
|
|
||||||
|
class StreamDirection(Enum):
|
||||||
|
"""Stream direction for tracking initiator."""
|
||||||
|
|
||||||
|
INBOUND = "inbound"
|
||||||
|
OUTBOUND = "outbound"
|
||||||
|
|
||||||
|
|
||||||
|
class StreamTimeline:
|
||||||
|
"""Track stream lifecycle events for debugging and monitoring."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.created_at = time.time()
|
||||||
|
self.opened_at: float | None = None
|
||||||
|
self.first_data_at: float | None = None
|
||||||
|
self.closed_at: float | None = None
|
||||||
|
self.reset_at: float | None = None
|
||||||
|
self.error_code: int | None = None
|
||||||
|
|
||||||
|
def record_open(self) -> None:
|
||||||
|
self.opened_at = time.time()
|
||||||
|
|
||||||
|
def record_first_data(self) -> None:
|
||||||
|
if self.first_data_at is None:
|
||||||
|
self.first_data_at = time.time()
|
||||||
|
|
||||||
|
def record_close(self) -> None:
|
||||||
|
self.closed_at = time.time()
|
||||||
|
|
||||||
|
def record_reset(self, error_code: int) -> None:
|
||||||
|
self.reset_at = time.time()
|
||||||
|
self.error_code = error_code
|
||||||
|
|
||||||
|
|
||||||
class QUICStream(IMuxedStream):
|
class QUICStream(IMuxedStream):
|
||||||
"""
|
"""
|
||||||
Basic QUIC stream implementation for Module 1.
|
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||||
|
|
||||||
This is a minimal implementation to make Module 1 self-contained.
|
Based on patterns from go-libp2p and js-libp2p, this implementation:
|
||||||
Will be moved to a separate stream.py module in Module 3.
|
- Leverages QUIC's native multiplexing and flow control
|
||||||
|
- Integrates with libp2p resource management
|
||||||
|
- Provides comprehensive error handling with QUIC-specific codes
|
||||||
|
- Supports bidirectional communication with independent close semantics
|
||||||
|
- Implements proper stream lifecycle management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Configuration constants based on research
|
||||||
|
DEFAULT_READ_TIMEOUT = 30.0 # 30 seconds
|
||||||
|
DEFAULT_WRITE_TIMEOUT = 30.0 # 30 seconds
|
||||||
|
FLOW_CONTROL_WINDOW_SIZE = 512 * 1024 # 512KB per stream
|
||||||
|
MAX_RECEIVE_BUFFER_SIZE = 1024 * 1024 # 1MB max buffering
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, connection: "QUICConnection", stream_id: int, is_initiator: bool
|
self,
|
||||||
|
connection: "QUICConnection",
|
||||||
|
stream_id: int,
|
||||||
|
direction: StreamDirection,
|
||||||
|
remote_addr: tuple[str, int],
|
||||||
|
resource_scope: Any | None = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize QUIC stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: Parent QUIC connection
|
||||||
|
stream_id: QUIC stream identifier
|
||||||
|
direction: Stream direction (inbound/outbound)
|
||||||
|
resource_scope: Resource manager scope for memory accounting
|
||||||
|
remote_addr: Remote addr stream is connected to
|
||||||
|
|
||||||
|
"""
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
self._stream_id = stream_id
|
self._stream_id = stream_id
|
||||||
self._is_initiator = is_initiator
|
self._direction = direction
|
||||||
self._closed = False
|
self._resource_scope = resource_scope
|
||||||
|
|
||||||
# Trio synchronization
|
# libp2p interface compliance
|
||||||
|
self._protocol: TProtocol | None = None
|
||||||
|
self._metadata: dict[str, Any] = {}
|
||||||
|
self._remote_addr = remote_addr
|
||||||
|
|
||||||
|
# Stream state management
|
||||||
|
self._state = StreamState.OPEN
|
||||||
|
self._state_lock = trio.Lock()
|
||||||
|
|
||||||
|
# Flow control and buffering
|
||||||
self._receive_buffer = bytearray()
|
self._receive_buffer = bytearray()
|
||||||
|
self._receive_buffer_lock = trio.Lock()
|
||||||
self._receive_event = trio.Event()
|
self._receive_event = trio.Event()
|
||||||
|
self._backpressure_event = trio.Event()
|
||||||
|
self._backpressure_event.set() # Initially no backpressure
|
||||||
|
|
||||||
|
# Close/reset state
|
||||||
|
self._write_closed = False
|
||||||
|
self._read_closed = False
|
||||||
self._close_event = trio.Event()
|
self._close_event = trio.Event()
|
||||||
|
self._reset_error_code: int | None = None
|
||||||
|
|
||||||
async def read(self, n: int | None = -1) -> bytes:
|
# Lifecycle tracking
|
||||||
"""Read data from the stream."""
|
self._timeline = StreamTimeline()
|
||||||
if self._closed:
|
self._timeline.record_open()
|
||||||
raise QUICStreamError("Stream is closed")
|
|
||||||
|
|
||||||
# Wait for data if buffer is empty
|
# Resource accounting
|
||||||
while not self._receive_buffer and not self._closed:
|
self._memory_reserved = 0
|
||||||
await self._receive_event.wait()
|
if self._resource_scope:
|
||||||
self._receive_event = trio.Event() # Reset for next read
|
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Created QUIC stream {stream_id} "
|
||||||
|
f"({direction.value}, connection: {connection.remote_peer_id()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Properties for libp2p interface compliance
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol(self) -> TProtocol | None:
|
||||||
|
"""Get the protocol identifier for this stream."""
|
||||||
|
return self._protocol
|
||||||
|
|
||||||
|
@protocol.setter
|
||||||
|
def protocol(self, protocol_id: TProtocol) -> None:
|
||||||
|
"""Set the protocol identifier for this stream."""
|
||||||
|
self._protocol = protocol_id
|
||||||
|
self._metadata["protocol"] = protocol_id
|
||||||
|
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stream_id(self) -> str:
|
||||||
|
"""Get stream ID as string for libp2p compatibility."""
|
||||||
|
return str(self._stream_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def muxed_conn(self) -> "QUICConnection": # type: ignore
|
||||||
|
"""Get the parent muxed connection."""
|
||||||
|
return self._connection
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self) -> StreamState:
|
||||||
|
"""Get current stream state."""
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@property
|
||||||
|
def direction(self) -> StreamDirection:
|
||||||
|
"""Get stream direction."""
|
||||||
|
return self._direction
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_initiator(self) -> bool:
|
||||||
|
"""Check if this stream was locally initiated."""
|
||||||
|
return self._direction == StreamDirection.OUTBOUND
|
||||||
|
|
||||||
|
# Core stream operations
|
||||||
|
|
||||||
|
async def read(self, n: int | None = None) -> bytes:
|
||||||
|
"""
|
||||||
|
Read data from the stream with QUIC flow control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n: Maximum number of bytes to read. If None or -1, read all available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Data read from stream
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICStreamClosedError: Stream is closed
|
||||||
|
QUICStreamResetError: Stream was reset
|
||||||
|
QUICStreamTimeoutError: Read timeout exceeded
|
||||||
|
|
||||||
|
"""
|
||||||
|
if n is None:
|
||||||
|
n = -1
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||||
|
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||||
|
|
||||||
|
if self._read_closed:
|
||||||
|
# Return any remaining buffered data, then EOF
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
if self._receive_buffer:
|
||||||
|
data = self._extract_data_from_buffer(n)
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
return data
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Wait for data with timeout
|
||||||
|
timeout = self.DEFAULT_READ_TIMEOUT
|
||||||
|
try:
|
||||||
|
with trio.move_on_after(timeout) as cancel_scope:
|
||||||
|
while True:
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
if self._receive_buffer:
|
||||||
|
data = self._extract_data_from_buffer(n)
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Check if stream was closed while waiting
|
||||||
|
if self._read_closed:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Wait for more data
|
||||||
|
await self._receive_event.wait()
|
||||||
|
self._receive_event = trio.Event() # Reset for next wait
|
||||||
|
|
||||||
|
if cancel_scope.cancelled_caught:
|
||||||
|
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
|
||||||
|
|
||||||
|
return b""
|
||||||
|
except QUICStreamResetError:
|
||||||
|
# Stream was reset while reading
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading from stream {self.stream_id}: {e}")
|
||||||
|
await self._handle_stream_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def write(self, data: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Write data to the stream with QUIC flow control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Data to write
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
QUICStreamClosedError: Stream is closed for writing
|
||||||
|
QUICStreamBackpressureError: Flow control window exhausted
|
||||||
|
QUICStreamResetError: Stream was reset
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||||
|
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||||
|
|
||||||
|
if self._write_closed:
|
||||||
|
raise QUICStreamClosedError(
|
||||||
|
f"Stream {self.stream_id} write side is closed"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle flow control backpressure
|
||||||
|
await self._backpressure_event.wait()
|
||||||
|
|
||||||
|
# Send data through QUIC connection
|
||||||
|
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error writing to stream {self.stream_id}: {e}")
|
||||||
|
# Convert QUIC-specific errors
|
||||||
|
if "flow control" in str(e).lower():
|
||||||
|
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
|
||||||
|
await self._handle_stream_error(e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Close the stream gracefully (both read and write sides).
|
||||||
|
|
||||||
|
This implements proper close semantics where both sides
|
||||||
|
are closed and resources are cleaned up.
|
||||||
|
"""
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Closing stream {self.stream_id}")
|
||||||
|
|
||||||
|
# Close both sides
|
||||||
|
if not self._write_closed:
|
||||||
|
await self.close_write()
|
||||||
|
if not self._read_closed:
|
||||||
|
await self.close_read()
|
||||||
|
|
||||||
|
# Update state and cleanup
|
||||||
|
async with self._state_lock:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
|
||||||
|
await self._cleanup_resources()
|
||||||
|
self._timeline.record_close()
|
||||||
|
self._close_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} closed")
|
||||||
|
|
||||||
|
async def close_write(self) -> None:
|
||||||
|
"""Close the write side of the stream."""
|
||||||
|
if self._write_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send FIN to close write side
|
||||||
|
self._connection._quic.send_stream_data(
|
||||||
|
self._stream_id, b"", end_stream=True
|
||||||
|
)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
self._write_closed = True
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._read_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
self._state = StreamState.WRITE_CLOSED
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} write side closed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
|
||||||
|
|
||||||
|
async def close_read(self) -> None:
|
||||||
|
"""Close the read side of the stream."""
|
||||||
|
if self._read_closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Signal read closure to QUIC layer
|
||||||
|
self._connection._quic.reset_stream(self._stream_id, error_code=0)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
self._read_closed = True
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._write_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
self._state = StreamState.READ_CLOSED
|
||||||
|
|
||||||
|
# Wake up any pending reads
|
||||||
|
self._receive_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} read side closed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
|
||||||
|
|
||||||
|
async def reset(self, error_code: int = 0) -> None:
|
||||||
|
"""
|
||||||
|
Reset the stream with the given error code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_code: QUIC error code for the reset
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._state == StreamState.RESET:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Resetting stream {self.stream_id} with error code {error_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._state = StreamState.RESET
|
||||||
|
self._reset_error_code = error_code
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send QUIC reset frame
|
||||||
|
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||||
|
await self._connection._transmit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
|
||||||
|
finally:
|
||||||
|
# Always cleanup resources
|
||||||
|
await self._cleanup_resources()
|
||||||
|
self._timeline.record_reset(error_code)
|
||||||
|
self._close_event.set()
|
||||||
|
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""Check if stream is completely closed."""
|
||||||
|
return self._state in (StreamState.CLOSED, StreamState.RESET)
|
||||||
|
|
||||||
|
def is_reset(self) -> bool:
|
||||||
|
"""Check if stream was reset."""
|
||||||
|
return self._state == StreamState.RESET
|
||||||
|
|
||||||
|
def can_read(self) -> bool:
|
||||||
|
"""Check if stream can be read from."""
|
||||||
|
return not self._read_closed and self._state not in (
|
||||||
|
StreamState.CLOSED,
|
||||||
|
StreamState.RESET,
|
||||||
|
)
|
||||||
|
|
||||||
|
def can_write(self) -> bool:
|
||||||
|
"""Check if stream can be written to."""
|
||||||
|
return not self._write_closed and self._state not in (
|
||||||
|
StreamState.CLOSED,
|
||||||
|
StreamState.RESET,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||||
|
"""
|
||||||
|
Handle data received from the QUIC connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Received data
|
||||||
|
end_stream: Whether this is the last data (FIN received)
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._state == StreamState.RESET:
|
||||||
|
return
|
||||||
|
|
||||||
|
if data:
|
||||||
|
async with self._receive_buffer_lock:
|
||||||
|
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
|
||||||
|
logger.warning(
|
||||||
|
f"Stream {self.stream_id} receive buffer overflow, "
|
||||||
|
f"dropping {len(data)} bytes"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._receive_buffer.extend(data)
|
||||||
|
self._timeline.record_first_data()
|
||||||
|
|
||||||
|
# Notify waiting readers
|
||||||
|
self._receive_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
|
||||||
|
|
||||||
|
if end_stream:
|
||||||
|
self._read_closed = True
|
||||||
|
async with self._state_lock:
|
||||||
|
if self._write_closed:
|
||||||
|
self._state = StreamState.CLOSED
|
||||||
|
else:
|
||||||
|
self._state = StreamState.READ_CLOSED
|
||||||
|
|
||||||
|
# Wake up readers to process remaining data and EOF
|
||||||
|
self._receive_event.set()
|
||||||
|
|
||||||
|
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||||
|
|
||||||
|
async def handle_reset(self, error_code: int) -> None:
|
||||||
|
"""
|
||||||
|
Handle stream reset from remote peer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_code: QUIC error code from reset frame
|
||||||
|
|
||||||
|
"""
|
||||||
|
logger.debug(
|
||||||
|
f"Stream {self.stream_id} reset by peer with error code {error_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._state_lock:
|
||||||
|
self._state = StreamState.RESET
|
||||||
|
self._reset_error_code = error_code
|
||||||
|
|
||||||
|
await self._cleanup_resources()
|
||||||
|
self._timeline.record_reset(error_code)
|
||||||
|
self._close_event.set()
|
||||||
|
|
||||||
|
# Wake up any pending operations
|
||||||
|
self._receive_event.set()
|
||||||
|
self._backpressure_event.set()
|
||||||
|
|
||||||
|
async def handle_flow_control_update(self, available_window: int) -> None:
|
||||||
|
"""
|
||||||
|
Handle flow control window updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
available_window: Available flow control window size
|
||||||
|
|
||||||
|
"""
|
||||||
|
if available_window > 0:
|
||||||
|
self._backpressure_event.set()
|
||||||
|
logger.debug(
|
||||||
|
f"Stream {self.stream_id} flow control".__add__(
|
||||||
|
f"window updated: {available_window}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._backpressure_event = trio.Event() # Reset to blocking state
|
||||||
|
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
|
||||||
|
|
||||||
|
def _extract_data_from_buffer(self, n: int) -> bytes:
|
||||||
|
"""Extract data from receive buffer with specified limit."""
|
||||||
if n == -1:
|
if n == -1:
|
||||||
|
# Read all available data
|
||||||
data = bytes(self._receive_buffer)
|
data = bytes(self._receive_buffer)
|
||||||
self._receive_buffer.clear()
|
self._receive_buffer.clear()
|
||||||
else:
|
else:
|
||||||
|
# Read up to n bytes
|
||||||
data = bytes(self._receive_buffer[:n])
|
data = bytes(self._receive_buffer[:n])
|
||||||
self._receive_buffer = self._receive_buffer[n:]
|
self._receive_buffer = self._receive_buffer[n:]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def write(self, data: bytes) -> None:
|
async def _handle_stream_error(self, error: Exception) -> None:
|
||||||
"""Write data to the stream."""
|
"""Handle errors by resetting the stream."""
|
||||||
if self._closed:
|
logger.error(f"Stream {self.stream_id} error: {error}")
|
||||||
raise QUICStreamError("Stream is closed")
|
await self.reset(error_code=1) # Generic error code
|
||||||
|
|
||||||
# Send data using the underlying QUIC connection
|
def _reserve_memory(self, size: int) -> None:
|
||||||
self._connection._quic.send_stream_data(self._stream_id, data)
|
"""Reserve memory with resource manager."""
|
||||||
await self._connection._transmit()
|
if self._resource_scope:
|
||||||
|
try:
|
||||||
|
self._resource_scope.reserve_memory(size)
|
||||||
|
self._memory_reserved += size
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to reserve memory for stream {self.stream_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
async def close(self, error_code: int = 0) -> None:
|
def _release_memory(self, size: int) -> None:
|
||||||
"""Close the stream."""
|
"""Release memory with resource manager."""
|
||||||
if self._closed:
|
if self._resource_scope and size > 0:
|
||||||
return
|
try:
|
||||||
|
self._resource_scope.release_memory(size)
|
||||||
|
self._memory_reserved = max(0, self._memory_reserved - size)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to release memory for stream {self.stream_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
self._closed = True
|
async def _cleanup_resources(self) -> None:
|
||||||
|
"""Clean up stream resources."""
|
||||||
|
# Release all reserved memory
|
||||||
|
if self._memory_reserved > 0:
|
||||||
|
self._release_memory(self._memory_reserved)
|
||||||
|
|
||||||
# Close the QUIC stream
|
# Clear receive buffer
|
||||||
self._connection._quic.reset_stream(self._stream_id, error_code)
|
async with self._receive_buffer_lock:
|
||||||
await self._connection._transmit()
|
self._receive_buffer.clear()
|
||||||
|
|
||||||
# Remove from connection's stream list
|
# Remove from connection's stream registry
|
||||||
self._connection._streams.pop(self._stream_id, None)
|
self._connection._remove_stream(self._stream_id)
|
||||||
|
|
||||||
self._close_event.set()
|
logger.debug(f"Stream {self.stream_id} resources cleaned up")
|
||||||
|
|
||||||
def is_closed(self) -> bool:
|
# Abstact implementations
|
||||||
"""Check if stream is closed."""
|
|
||||||
return self._closed
|
|
||||||
|
|
||||||
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
def get_remote_address(self) -> tuple[str, int]:
|
||||||
"""Handle data received from the QUIC connection."""
|
return self._remote_addr
|
||||||
if self._closed:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._receive_buffer.extend(data)
|
|
||||||
self._receive_event.set()
|
|
||||||
|
|
||||||
if end_stream:
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
async def handle_reset(self, error_code: int) -> None:
|
|
||||||
"""Handle stream reset."""
|
|
||||||
self._closed = True
|
|
||||||
self._close_event.set()
|
|
||||||
|
|
||||||
def set_deadline(self, ttl: int) -> bool:
|
|
||||||
"""
|
|
||||||
Set the deadline
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Yamux does not support setting read deadlines")
|
|
||||||
|
|
||||||
async def reset(self) -> None:
|
|
||||||
"""
|
|
||||||
Reset the stream
|
|
||||||
"""
|
|
||||||
await self.handle_reset(0)
|
|
||||||
return
|
|
||||||
|
|
||||||
def get_remote_address(self) -> tuple[str, int] | None:
|
|
||||||
return self._connection._remote_addr
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "QUICStream":
|
async def __aenter__(self) -> "QUICStream":
|
||||||
"""Enter the async context manager."""
|
"""Enter the async context manager."""
|
||||||
@ -134,3 +591,26 @@ class QUICStream(IMuxedStream):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Exit the async context manager and close the stream."""
|
"""Exit the async context manager and close the stream."""
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
|
def set_deadline(self, ttl: int) -> bool:
|
||||||
|
"""
|
||||||
|
Set a deadline for the stream. QUIC does not support deadlines natively,
|
||||||
|
so this method always returns False to indicate the operation is unsupported.
|
||||||
|
|
||||||
|
:param ttl: Time-to-live in seconds (ignored).
|
||||||
|
:return: False, as deadlines are not supported.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("QUIC does not support setting read deadlines")
|
||||||
|
|
||||||
|
# String representation for debugging
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"QUICStream(id={self.stream_id}, "
|
||||||
|
f"state={self._state.value}, "
|
||||||
|
f"direction={self._direction.value}, "
|
||||||
|
f"protocol={self._protocol})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"QUICStream({self.stream_id})"
|
||||||
|
|||||||
@ -1,20 +1,43 @@
|
|||||||
from unittest.mock import (
|
"""
|
||||||
Mock,
|
Enhanced tests for QUIC connection functionality - Module 3.
|
||||||
)
|
Tests all new features including advanced stream management, resource management,
|
||||||
|
error handling, and concurrent operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from multiaddr.multiaddr import Multiaddr
|
from multiaddr.multiaddr import Multiaddr
|
||||||
|
import trio
|
||||||
|
|
||||||
from libp2p.crypto.ed25519 import (
|
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||||
create_new_key_pair,
|
|
||||||
)
|
|
||||||
from libp2p.peer.id import ID
|
from libp2p.peer.id import ID
|
||||||
from libp2p.transport.quic.connection import QUICConnection
|
from libp2p.transport.quic.connection import QUICConnection
|
||||||
from libp2p.transport.quic.exceptions import QUICStreamError
|
from libp2p.transport.quic.exceptions import (
|
||||||
|
QUICConnectionClosedError,
|
||||||
|
QUICConnectionError,
|
||||||
|
QUICConnectionTimeoutError,
|
||||||
|
QUICStreamLimitError,
|
||||||
|
QUICStreamTimeoutError,
|
||||||
|
)
|
||||||
|
from libp2p.transport.quic.stream import QUICStream, StreamDirection
|
||||||
|
|
||||||
|
|
||||||
class TestQUICConnection:
|
class MockResourceScope:
|
||||||
"""Test suite for QUIC connection functionality."""
|
"""Mock resource scope for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.memory_reserved = 0
|
||||||
|
|
||||||
|
def reserve_memory(self, size):
|
||||||
|
self.memory_reserved += size
|
||||||
|
|
||||||
|
def release_memory(self, size):
|
||||||
|
self.memory_reserved = max(0, self.memory_reserved - size)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQUICConnectionEnhanced:
|
||||||
|
"""Enhanced test suite for QUIC connection functionality."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_quic_connection(self):
|
def mock_quic_connection(self):
|
||||||
@ -23,11 +46,20 @@ class TestQUICConnection:
|
|||||||
mock.next_event.return_value = None
|
mock.next_event.return_value = None
|
||||||
mock.datagrams_to_send.return_value = []
|
mock.datagrams_to_send.return_value = []
|
||||||
mock.get_timer.return_value = None
|
mock.get_timer.return_value = None
|
||||||
|
mock.connect = Mock()
|
||||||
|
mock.close = Mock()
|
||||||
|
mock.send_stream_data = Mock()
|
||||||
|
mock.reset_stream = Mock()
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def quic_connection(self, mock_quic_connection):
|
def mock_resource_scope(self):
|
||||||
"""Create test QUIC connection."""
|
"""Create mock resource scope."""
|
||||||
|
return MockResourceScope()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def quic_connection(self, mock_quic_connection, mock_resource_scope):
|
||||||
|
"""Create test QUIC connection with enhanced features."""
|
||||||
private_key = create_new_key_pair().private_key
|
private_key = create_new_key_pair().private_key
|
||||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||||
|
|
||||||
@ -39,18 +71,44 @@ class TestQUICConnection:
|
|||||||
is_initiator=True,
|
is_initiator=True,
|
||||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||||
transport=Mock(),
|
transport=Mock(),
|
||||||
|
resource_scope=mock_resource_scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_connection_initialization(self, quic_connection):
|
@pytest.fixture
|
||||||
"""Test connection initialization."""
|
def server_connection(self, mock_quic_connection, mock_resource_scope):
|
||||||
|
"""Create server-side QUIC connection."""
|
||||||
|
private_key = create_new_key_pair().private_key
|
||||||
|
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||||
|
|
||||||
|
return QUICConnection(
|
||||||
|
quic_connection=mock_quic_connection,
|
||||||
|
remote_addr=("127.0.0.1", 4001),
|
||||||
|
peer_id=peer_id,
|
||||||
|
local_peer_id=peer_id,
|
||||||
|
is_initiator=False,
|
||||||
|
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||||
|
transport=Mock(),
|
||||||
|
resource_scope=mock_resource_scope,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Basic functionality tests
|
||||||
|
|
||||||
|
def test_connection_initialization_enhanced(
|
||||||
|
self, quic_connection, mock_resource_scope
|
||||||
|
):
|
||||||
|
"""Test enhanced connection initialization."""
|
||||||
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
|
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
|
||||||
assert quic_connection.is_initiator is True
|
assert quic_connection.is_initiator is True
|
||||||
assert not quic_connection.is_closed
|
assert not quic_connection.is_closed
|
||||||
assert not quic_connection.is_established
|
assert not quic_connection.is_established
|
||||||
assert len(quic_connection._streams) == 0
|
assert len(quic_connection._streams) == 0
|
||||||
|
assert quic_connection._resource_scope == mock_resource_scope
|
||||||
|
assert quic_connection._outbound_stream_count == 0
|
||||||
|
assert quic_connection._inbound_stream_count == 0
|
||||||
|
assert len(quic_connection._stream_accept_queue) == 0
|
||||||
|
|
||||||
def test_stream_id_calculation(self):
|
def test_stream_id_calculation_enhanced(self):
|
||||||
"""Test stream ID calculation for client/server."""
|
"""Test enhanced stream ID calculation for client/server."""
|
||||||
# Client connection (initiator)
|
# Client connection (initiator)
|
||||||
client_conn = QUICConnection(
|
client_conn = QUICConnection(
|
||||||
quic_connection=Mock(),
|
quic_connection=Mock(),
|
||||||
@ -75,45 +133,364 @@ class TestQUICConnection:
|
|||||||
)
|
)
|
||||||
assert server_conn._next_stream_id == 1 # Server starts with 1
|
assert server_conn._next_stream_id == 1 # Server starts with 1
|
||||||
|
|
||||||
def test_incoming_stream_detection(self, quic_connection):
|
def test_incoming_stream_detection_enhanced(self, quic_connection):
|
||||||
"""Test incoming stream detection logic."""
|
"""Test enhanced incoming stream detection logic."""
|
||||||
# For client (initiator), odd stream IDs are incoming
|
# For client (initiator), odd stream IDs are incoming
|
||||||
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
|
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
|
||||||
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
|
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
|
||||||
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
|
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
|
||||||
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
|
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
|
||||||
|
|
||||||
|
# Stream management tests
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_connection_stats(self, quic_connection):
|
async def test_open_stream_basic(self, quic_connection):
|
||||||
"""Test connection statistics."""
|
"""Test basic stream opening."""
|
||||||
stats = quic_connection.get_stats()
|
quic_connection._started = True
|
||||||
|
|
||||||
|
stream = await quic_connection.open_stream()
|
||||||
|
|
||||||
|
assert isinstance(stream, QUICStream)
|
||||||
|
assert stream.stream_id == "0"
|
||||||
|
assert stream.direction == StreamDirection.OUTBOUND
|
||||||
|
assert 0 in quic_connection._streams
|
||||||
|
assert quic_connection._outbound_stream_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_open_stream_limit_reached(self, quic_connection):
|
||||||
|
"""Test stream limit enforcement."""
|
||||||
|
quic_connection._started = True
|
||||||
|
quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
|
||||||
|
|
||||||
|
with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
|
||||||
|
await quic_connection.open_stream()
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_open_stream_timeout(self, quic_connection: QUICConnection):
|
||||||
|
"""Test stream opening timeout."""
|
||||||
|
quic_connection._started = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# Mock the stream ID lock to simulate slow operation
|
||||||
|
async def slow_acquire():
|
||||||
|
await trio.sleep(10) # Longer than timeout
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
quic_connection._stream_id_lock, "acquire", side_effect=slow_acquire
|
||||||
|
):
|
||||||
|
with pytest.raises(
|
||||||
|
QUICStreamTimeoutError, match="Stream creation timed out"
|
||||||
|
):
|
||||||
|
await quic_connection.open_stream(timeout=0.1)
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_accept_stream_basic(self, quic_connection):
|
||||||
|
"""Test basic stream acceptance."""
|
||||||
|
# Create a mock inbound stream
|
||||||
|
mock_stream = Mock(spec=QUICStream)
|
||||||
|
mock_stream.stream_id = "1"
|
||||||
|
|
||||||
|
# Add to accept queue
|
||||||
|
quic_connection._stream_accept_queue.append(mock_stream)
|
||||||
|
quic_connection._stream_accept_event.set()
|
||||||
|
|
||||||
|
accepted_stream = await quic_connection.accept_stream(timeout=0.1)
|
||||||
|
|
||||||
|
assert accepted_stream == mock_stream
|
||||||
|
assert len(quic_connection._stream_accept_queue) == 0
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_accept_stream_timeout(self, quic_connection):
|
||||||
|
"""Test stream acceptance timeout."""
|
||||||
|
with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"):
|
||||||
|
await quic_connection.accept_stream(timeout=0.1)
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_accept_stream_on_closed_connection(self, quic_connection):
|
||||||
|
"""Test stream acceptance on closed connection."""
|
||||||
|
await quic_connection.close()
|
||||||
|
|
||||||
|
with pytest.raises(QUICConnectionClosedError, match="Connection is closed"):
|
||||||
|
await quic_connection.accept_stream()
|
||||||
|
|
||||||
|
# Stream handler tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_stream_handler_setting(self, quic_connection):
|
||||||
|
"""Test setting stream handler."""
|
||||||
|
|
||||||
|
async def mock_handler(stream):
|
||||||
|
pass
|
||||||
|
|
||||||
|
quic_connection.set_stream_handler(mock_handler)
|
||||||
|
assert quic_connection._stream_handler == mock_handler
|
||||||
|
|
||||||
|
# Connection lifecycle tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_start_client(self, quic_connection):
|
||||||
|
"""Test client connection start."""
|
||||||
|
with patch.object(
|
||||||
|
quic_connection, "_initiate_connection", new_callable=AsyncMock
|
||||||
|
) as mock_initiate:
|
||||||
|
await quic_connection.start()
|
||||||
|
|
||||||
|
assert quic_connection._started
|
||||||
|
mock_initiate.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_start_server(self, server_connection):
|
||||||
|
"""Test server connection start."""
|
||||||
|
await server_connection.start()
|
||||||
|
|
||||||
|
assert server_connection._started
|
||||||
|
assert server_connection._established
|
||||||
|
assert server_connection._connected_event.is_set()
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_start_already_started(self, quic_connection):
|
||||||
|
"""Test starting already started connection."""
|
||||||
|
quic_connection._started = True
|
||||||
|
|
||||||
|
# Should not raise error, just log warning
|
||||||
|
await quic_connection.start()
|
||||||
|
assert quic_connection._started
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_start_closed(self, quic_connection):
|
||||||
|
"""Test starting closed connection."""
|
||||||
|
quic_connection._closed = True
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
QUICConnectionError, match="Cannot start a closed connection"
|
||||||
|
):
|
||||||
|
await quic_connection.start()
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_connect_with_nursery(self, quic_connection):
|
||||||
|
"""Test connection establishment with nursery."""
|
||||||
|
quic_connection._started = True
|
||||||
|
quic_connection._established = True
|
||||||
|
quic_connection._connected_event.set()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||||
|
) as mock_start_tasks:
|
||||||
|
with patch.object(
|
||||||
|
quic_connection, "verify_peer_identity", new_callable=AsyncMock
|
||||||
|
) as mock_verify:
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
await quic_connection.connect(nursery)
|
||||||
|
|
||||||
|
assert quic_connection._nursery == nursery
|
||||||
|
mock_start_tasks.assert_called_once()
|
||||||
|
mock_verify.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_connect_timeout(self, quic_connection: QUICConnection):
|
||||||
|
"""Test connection establishment timeout."""
|
||||||
|
quic_connection._started = True
|
||||||
|
# Don't set connected event to simulate timeout
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||||
|
):
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
with pytest.raises(
|
||||||
|
QUICConnectionTimeoutError, match="Connection handshake timed out"
|
||||||
|
):
|
||||||
|
await quic_connection.connect(nursery)
|
||||||
|
|
||||||
|
# Resource management tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_stream_removal_resource_cleanup(
|
||||||
|
self, quic_connection: QUICConnection, mock_resource_scope
|
||||||
|
):
|
||||||
|
"""Test stream removal and resource cleanup."""
|
||||||
|
quic_connection._started = True
|
||||||
|
|
||||||
|
# Create a stream
|
||||||
|
stream = await quic_connection.open_stream()
|
||||||
|
|
||||||
|
# Remove the stream
|
||||||
|
quic_connection._remove_stream(int(stream.stream_id))
|
||||||
|
|
||||||
|
assert int(stream.stream_id) not in quic_connection._streams
|
||||||
|
# Note: Count updates is async, so we can't test it directly here
|
||||||
|
|
||||||
|
# Error handling tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_error_handling(self, quic_connection):
|
||||||
|
"""Test connection error handling."""
|
||||||
|
error = Exception("Test error")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
quic_connection, "close", new_callable=AsyncMock
|
||||||
|
) as mock_close:
|
||||||
|
await quic_connection._handle_connection_error(error)
|
||||||
|
mock_close.assert_called_once()
|
||||||
|
|
||||||
|
# Statistics and monitoring tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_stats_enhanced(self, quic_connection):
|
||||||
|
"""Test enhanced connection statistics."""
|
||||||
|
quic_connection._started = True
|
||||||
|
|
||||||
|
# Create some streams
|
||||||
|
_stream1 = await quic_connection.open_stream()
|
||||||
|
_stream2 = await quic_connection.open_stream()
|
||||||
|
|
||||||
|
stats = quic_connection.get_stream_stats()
|
||||||
|
|
||||||
expected_keys = [
|
expected_keys = [
|
||||||
"peer_id",
|
"total_streams",
|
||||||
"remote_addr",
|
"outbound_streams",
|
||||||
"is_initiator",
|
"inbound_streams",
|
||||||
"is_established",
|
"max_streams",
|
||||||
"is_closed",
|
"stream_utilization",
|
||||||
"active_streams",
|
"stats",
|
||||||
"next_stream_id",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
for key in expected_keys:
|
for key in expected_keys:
|
||||||
assert key in stats
|
assert key in stats
|
||||||
|
|
||||||
|
assert stats["total_streams"] == 2
|
||||||
|
assert stats["outbound_streams"] == 2
|
||||||
|
assert stats["inbound_streams"] == 0
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_connection_close(self, quic_connection):
|
async def test_get_active_streams(self, quic_connection):
|
||||||
"""Test connection close functionality."""
|
"""Test getting active streams."""
|
||||||
assert not quic_connection.is_closed
|
quic_connection._started = True
|
||||||
|
|
||||||
|
# Create streams
|
||||||
|
stream1 = await quic_connection.open_stream()
|
||||||
|
stream2 = await quic_connection.open_stream()
|
||||||
|
|
||||||
|
active_streams = quic_connection.get_active_streams()
|
||||||
|
|
||||||
|
assert len(active_streams) == 2
|
||||||
|
assert stream1 in active_streams
|
||||||
|
assert stream2 in active_streams
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_get_streams_by_protocol(self, quic_connection):
|
||||||
|
"""Test getting streams by protocol."""
|
||||||
|
quic_connection._started = True
|
||||||
|
|
||||||
|
# Create streams with different protocols
|
||||||
|
stream1 = await quic_connection.open_stream()
|
||||||
|
stream1.protocol = "/test/1.0.0"
|
||||||
|
|
||||||
|
stream2 = await quic_connection.open_stream()
|
||||||
|
stream2.protocol = "/other/1.0.0"
|
||||||
|
|
||||||
|
test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0")
|
||||||
|
other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0")
|
||||||
|
|
||||||
|
assert len(test_streams) == 1
|
||||||
|
assert len(other_streams) == 1
|
||||||
|
assert stream1 in test_streams
|
||||||
|
assert stream2 in other_streams
|
||||||
|
|
||||||
|
# Enhanced close tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_connection_close_enhanced(self, quic_connection: QUICConnection):
|
||||||
|
"""Test enhanced connection close with stream cleanup."""
|
||||||
|
quic_connection._started = True
|
||||||
|
|
||||||
|
# Create some streams
|
||||||
|
_stream1 = await quic_connection.open_stream()
|
||||||
|
_stream2 = await quic_connection.open_stream()
|
||||||
|
|
||||||
await quic_connection.close()
|
await quic_connection.close()
|
||||||
|
|
||||||
assert quic_connection.is_closed
|
assert quic_connection.is_closed
|
||||||
|
assert len(quic_connection._streams) == 0
|
||||||
|
|
||||||
|
# Concurrent operations tests
|
||||||
|
|
||||||
@pytest.mark.trio
|
@pytest.mark.trio
|
||||||
async def test_stream_operations_on_closed_connection(self, quic_connection):
|
async def test_concurrent_stream_operations(self, quic_connection):
|
||||||
"""Test stream operations on closed connection."""
|
"""Test concurrent stream operations."""
|
||||||
await quic_connection.close()
|
quic_connection._started = True
|
||||||
|
|
||||||
with pytest.raises(QUICStreamError, match="Connection is closed"):
|
async def create_stream():
|
||||||
await quic_connection.open_stream()
|
return await quic_connection.open_stream()
|
||||||
|
|
||||||
|
# Create multiple streams concurrently
|
||||||
|
async with trio.open_nursery() as nursery:
|
||||||
|
for i in range(10):
|
||||||
|
nursery.start_soon(create_stream)
|
||||||
|
|
||||||
|
# Wait a bit for all to start
|
||||||
|
await trio.sleep(0.1)
|
||||||
|
|
||||||
|
# Should have created streams without conflicts
|
||||||
|
assert quic_connection._outbound_stream_count == 10
|
||||||
|
assert len(quic_connection._streams) == 10
|
||||||
|
|
||||||
|
# Connection properties tests
|
||||||
|
|
||||||
|
def test_connection_properties(self, quic_connection):
|
||||||
|
"""Test connection property accessors."""
|
||||||
|
assert quic_connection.multiaddr() == quic_connection._maddr
|
||||||
|
assert quic_connection.local_peer_id() == quic_connection._local_peer_id
|
||||||
|
assert quic_connection.remote_peer_id() == quic_connection._peer_id
|
||||||
|
|
||||||
|
# IRawConnection interface tests
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_raw_connection_write(self, quic_connection):
|
||||||
|
"""Test raw connection write interface."""
|
||||||
|
quic_connection._started = True
|
||||||
|
|
||||||
|
with patch.object(quic_connection, "open_stream") as mock_open:
|
||||||
|
mock_stream = AsyncMock()
|
||||||
|
mock_open.return_value = mock_stream
|
||||||
|
|
||||||
|
await quic_connection.write(b"test data")
|
||||||
|
|
||||||
|
mock_open.assert_called_once()
|
||||||
|
mock_stream.write.assert_called_once_with(b"test data")
|
||||||
|
mock_stream.close_write.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.trio
|
||||||
|
async def test_raw_connection_read_not_implemented(self, quic_connection):
|
||||||
|
"""Test raw connection read raises NotImplementedError."""
|
||||||
|
with pytest.raises(NotImplementedError, match="Use muxed connection interface"):
|
||||||
|
await quic_connection.read()
|
||||||
|
|
||||||
|
# String representation tests
|
||||||
|
|
||||||
|
def test_connection_string_representation(self, quic_connection):
|
||||||
|
"""Test connection string representations."""
|
||||||
|
repr_str = repr(quic_connection)
|
||||||
|
str_str = str(quic_connection)
|
||||||
|
|
||||||
|
assert "QUICConnection" in repr_str
|
||||||
|
assert str(quic_connection._peer_id) in repr_str
|
||||||
|
assert str(quic_connection._remote_addr) in repr_str
|
||||||
|
assert str(quic_connection._peer_id) in str_str
|
||||||
|
|
||||||
|
# Mock verification helpers
|
||||||
|
|
||||||
|
def test_mock_resource_scope_functionality(self, mock_resource_scope):
|
||||||
|
"""Test mock resource scope works correctly."""
|
||||||
|
assert mock_resource_scope.memory_reserved == 0
|
||||||
|
|
||||||
|
mock_resource_scope.reserve_memory(1000)
|
||||||
|
assert mock_resource_scope.memory_reserved == 1000
|
||||||
|
|
||||||
|
mock_resource_scope.reserve_memory(500)
|
||||||
|
assert mock_resource_scope.memory_reserved == 1500
|
||||||
|
|
||||||
|
mock_resource_scope.release_memory(600)
|
||||||
|
assert mock_resource_scope.memory_reserved == 900
|
||||||
|
|
||||||
|
mock_resource_scope.release_memory(2000) # Should not go negative
|
||||||
|
assert mock_resource_scope.memory_reserved == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user