mirror of
https://github.com/varun-r-mallya/py-libp2p.git
synced 2025-12-31 20:36:24 +00:00
fix: add basic quic stream and associated tests
This commit is contained in:
@ -7,7 +7,7 @@ from dataclasses import (
|
||||
field,
|
||||
)
|
||||
import ssl
|
||||
from typing import TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
@ -76,6 +76,101 @@ class QUICTransportConfig:
|
||||
max_connections: int = 1000 # Maximum number of connections
|
||||
connection_timeout: float = 10.0 # Connection establishment timeout
|
||||
|
||||
MAX_CONCURRENT_STREAMS: int = 1000
|
||||
"""Maximum number of concurrent streams per connection."""
|
||||
|
||||
MAX_INCOMING_STREAMS: int = 1000
|
||||
"""Maximum number of incoming streams per connection."""
|
||||
|
||||
MAX_OUTGOING_STREAMS: int = 1000
|
||||
"""Maximum number of outgoing streams per connection."""
|
||||
|
||||
# Stream timeouts
|
||||
STREAM_OPEN_TIMEOUT: float = 5.0
|
||||
"""Timeout for opening new streams (seconds)."""
|
||||
|
||||
STREAM_ACCEPT_TIMEOUT: float = 30.0
|
||||
"""Timeout for accepting incoming streams (seconds)."""
|
||||
|
||||
STREAM_READ_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream read operations (seconds)."""
|
||||
|
||||
STREAM_WRITE_TIMEOUT: float = 30.0
|
||||
"""Default timeout for stream write operations (seconds)."""
|
||||
|
||||
STREAM_CLOSE_TIMEOUT: float = 10.0
|
||||
"""Timeout for graceful stream close (seconds)."""
|
||||
|
||||
# Flow control configuration
|
||||
STREAM_FLOW_CONTROL_WINDOW: int = 512 * 1024 # 512KB
|
||||
"""Per-stream flow control window size."""
|
||||
|
||||
CONNECTION_FLOW_CONTROL_WINDOW: int = 768 * 1024 # 768KB
|
||||
"""Connection-wide flow control window size."""
|
||||
|
||||
# Buffer management
|
||||
MAX_STREAM_RECEIVE_BUFFER: int = 1024 * 1024 # 1MB
|
||||
"""Maximum receive buffer size per stream."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_LOW_WATERMARK: int = 64 * 1024 # 64KB
|
||||
"""Low watermark for stream receive buffer."""
|
||||
|
||||
STREAM_RECEIVE_BUFFER_HIGH_WATERMARK: int = 512 * 1024 # 512KB
|
||||
"""High watermark for stream receive buffer."""
|
||||
|
||||
# Stream lifecycle configuration
|
||||
ENABLE_STREAM_RESET_ON_ERROR: bool = True
|
||||
"""Whether to automatically reset streams on errors."""
|
||||
|
||||
STREAM_RESET_ERROR_CODE: int = 1
|
||||
"""Default error code for stream resets."""
|
||||
|
||||
ENABLE_STREAM_KEEP_ALIVE: bool = False
|
||||
"""Whether to enable stream keep-alive mechanisms."""
|
||||
|
||||
STREAM_KEEP_ALIVE_INTERVAL: float = 30.0
|
||||
"""Interval for stream keep-alive pings (seconds)."""
|
||||
|
||||
# Resource management
|
||||
ENABLE_STREAM_RESOURCE_TRACKING: bool = True
|
||||
"""Whether to track stream resource usage."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_STREAM: int = 2 * 1024 * 1024 # 2MB
|
||||
"""Memory limit per individual stream."""
|
||||
|
||||
STREAM_MEMORY_LIMIT_PER_CONNECTION: int = 100 * 1024 * 1024 # 100MB
|
||||
"""Total memory limit for all streams per connection."""
|
||||
|
||||
# Concurrency and performance
|
||||
ENABLE_STREAM_BATCHING: bool = True
|
||||
"""Whether to batch multiple stream operations."""
|
||||
|
||||
STREAM_BATCH_SIZE: int = 10
|
||||
"""Number of streams to process in a batch."""
|
||||
|
||||
STREAM_PROCESSING_CONCURRENCY: int = 100
|
||||
"""Maximum concurrent stream processing tasks."""
|
||||
|
||||
# Debugging and monitoring
|
||||
ENABLE_STREAM_METRICS: bool = True
|
||||
"""Whether to collect stream metrics."""
|
||||
|
||||
ENABLE_STREAM_TIMELINE_TRACKING: bool = True
|
||||
"""Whether to track stream lifecycle timelines."""
|
||||
|
||||
STREAM_METRICS_COLLECTION_INTERVAL: float = 60.0
|
||||
"""Interval for collecting stream metrics (seconds)."""
|
||||
|
||||
# Error handling configuration
|
||||
STREAM_ERROR_RETRY_ATTEMPTS: int = 3
|
||||
"""Number of retry attempts for recoverable stream errors."""
|
||||
|
||||
STREAM_ERROR_RETRY_DELAY: float = 1.0
|
||||
"""Initial delay between stream error retries (seconds)."""
|
||||
|
||||
STREAM_ERROR_RETRY_BACKOFF_FACTOR: float = 2.0
|
||||
"""Backoff factor for stream error retries."""
|
||||
|
||||
# Protocol identifiers matching go-libp2p
|
||||
# TODO: UNTIL MUITIADDR REPO IS UPDATED
|
||||
# PROTOCOL_QUIC_V1: TProtocol = TProtocol("/quic-v1") # RFC 9000
|
||||
@ -92,3 +187,167 @@ class QUICTransportConfig:
|
||||
|
||||
if self.max_datagram_size < 1200:
|
||||
raise ValueError("Max datagram size must be at least 1200 bytes")
|
||||
|
||||
# Validate timeouts
|
||||
timeout_fields = [
|
||||
"STREAM_OPEN_TIMEOUT",
|
||||
"STREAM_ACCEPT_TIMEOUT",
|
||||
"STREAM_READ_TIMEOUT",
|
||||
"STREAM_WRITE_TIMEOUT",
|
||||
"STREAM_CLOSE_TIMEOUT",
|
||||
]
|
||||
for timeout_field in timeout_fields:
|
||||
if getattr(self, timeout_field) <= 0:
|
||||
raise ValueError(f"{timeout_field} must be positive")
|
||||
|
||||
# Validate flow control windows
|
||||
if self.STREAM_FLOW_CONTROL_WINDOW <= 0:
|
||||
raise ValueError("STREAM_FLOW_CONTROL_WINDOW must be positive")
|
||||
|
||||
if self.CONNECTION_FLOW_CONTROL_WINDOW < self.STREAM_FLOW_CONTROL_WINDOW:
|
||||
raise ValueError(
|
||||
"CONNECTION_FLOW_CONTROL_WINDOW must be >= STREAM_FLOW_CONTROL_WINDOW"
|
||||
)
|
||||
|
||||
# Validate buffer sizes
|
||||
if self.MAX_STREAM_RECEIVE_BUFFER <= 0:
|
||||
raise ValueError("MAX_STREAM_RECEIVE_BUFFER must be positive")
|
||||
|
||||
if self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK > self.MAX_STREAM_RECEIVE_BUFFER:
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_HIGH_WATERMARK cannot".__add__(
|
||||
"exceed MAX_STREAM_RECEIVE_BUFFER"
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
self.STREAM_RECEIVE_BUFFER_LOW_WATERMARK
|
||||
>= self.STREAM_RECEIVE_BUFFER_HIGH_WATERMARK
|
||||
):
|
||||
raise ValueError(
|
||||
"STREAM_RECEIVE_BUFFER_LOW_WATERMARK must be < HIGH_WATERMARK"
|
||||
)
|
||||
|
||||
# Validate memory limits
|
||||
if self.STREAM_MEMORY_LIMIT_PER_STREAM <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_STREAM must be positive")
|
||||
|
||||
if self.STREAM_MEMORY_LIMIT_PER_CONNECTION <= 0:
|
||||
raise ValueError("STREAM_MEMORY_LIMIT_PER_CONNECTION must be positive")
|
||||
|
||||
expected_stream_memory = (
|
||||
self.MAX_CONCURRENT_STREAMS * self.STREAM_MEMORY_LIMIT_PER_STREAM
|
||||
)
|
||||
if expected_stream_memory > self.STREAM_MEMORY_LIMIT_PER_CONNECTION * 2:
|
||||
# Allow some headroom, but warn if configuration seems inconsistent
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
"Stream memory configuration may be inconsistent: "
|
||||
f"{self.MAX_CONCURRENT_STREAMS} streams ×"
|
||||
"{self.STREAM_MEMORY_LIMIT_PER_STREAM} bytes "
|
||||
"could exceed connection limit of"
|
||||
f"{self.STREAM_MEMORY_LIMIT_PER_CONNECTION} bytes"
|
||||
)
|
||||
|
||||
def get_stream_config_dict(self) -> dict[str, Any]:
|
||||
"""Get stream-specific configuration as dictionary."""
|
||||
stream_config = {}
|
||||
for attr_name in dir(self):
|
||||
if attr_name.startswith(
|
||||
("STREAM_", "MAX_", "ENABLE_STREAM", "CONNECTION_FLOW")
|
||||
):
|
||||
stream_config[attr_name.lower()] = getattr(self, attr_name)
|
||||
return stream_config
|
||||
|
||||
|
||||
# Additional configuration classes for specific stream features
|
||||
|
||||
|
||||
class QUICStreamFlowControlConfig:
|
||||
"""Configuration for QUIC stream flow control."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_window_size: int = 512 * 1024,
|
||||
max_window_size: int = 2 * 1024 * 1024,
|
||||
window_update_threshold: float = 0.5,
|
||||
enable_auto_tuning: bool = True,
|
||||
):
|
||||
self.initial_window_size = initial_window_size
|
||||
self.max_window_size = max_window_size
|
||||
self.window_update_threshold = window_update_threshold
|
||||
self.enable_auto_tuning = enable_auto_tuning
|
||||
|
||||
|
||||
class QUICStreamMetricsConfig:
|
||||
"""Configuration for QUIC stream metrics collection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enable_latency_tracking: bool = True,
|
||||
enable_throughput_tracking: bool = True,
|
||||
enable_error_tracking: bool = True,
|
||||
metrics_retention_duration: float = 3600.0, # 1 hour
|
||||
metrics_aggregation_interval: float = 60.0, # 1 minute
|
||||
):
|
||||
self.enable_latency_tracking = enable_latency_tracking
|
||||
self.enable_throughput_tracking = enable_throughput_tracking
|
||||
self.enable_error_tracking = enable_error_tracking
|
||||
self.metrics_retention_duration = metrics_retention_duration
|
||||
self.metrics_aggregation_interval = metrics_aggregation_interval
|
||||
|
||||
|
||||
# Factory function for creating optimized configurations
|
||||
|
||||
|
||||
def create_stream_config_for_use_case(use_case: str) -> QUICTransportConfig:
|
||||
"""
|
||||
Create optimized stream configuration for specific use cases.
|
||||
|
||||
Args:
|
||||
use_case: One of "high_throughput", "low_latency", "many_streams","
|
||||
"memory_constrained"
|
||||
|
||||
Returns:
|
||||
Optimized QUICTransportConfig
|
||||
|
||||
"""
|
||||
base_config = QUICTransportConfig()
|
||||
|
||||
if use_case == "high_throughput":
|
||||
# Optimize for high throughput
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 2 * 1024 * 1024 # 2MB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 10 * 1024 * 1024 # 10MB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 4 * 1024 * 1024 # 4MB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 200
|
||||
|
||||
elif use_case == "low_latency":
|
||||
# Optimize for low latency
|
||||
base_config.STREAM_OPEN_TIMEOUT = 1.0
|
||||
base_config.STREAM_READ_TIMEOUT = 5.0
|
||||
base_config.STREAM_WRITE_TIMEOUT = 5.0
|
||||
base_config.ENABLE_STREAM_BATCHING = False
|
||||
base_config.STREAM_BATCH_SIZE = 1
|
||||
|
||||
elif use_case == "many_streams":
|
||||
# Optimize for many concurrent streams
|
||||
base_config.MAX_CONCURRENT_STREAMS = 5000
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 128 * 1024 # 128KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 256 * 1024 # 256KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 500
|
||||
|
||||
elif use_case == "memory_constrained":
|
||||
# Optimize for low memory usage
|
||||
base_config.MAX_CONCURRENT_STREAMS = 100
|
||||
base_config.STREAM_FLOW_CONTROL_WINDOW = 64 * 1024 # 64KB
|
||||
base_config.CONNECTION_FLOW_CONTROL_WINDOW = 256 * 1024 # 256KB
|
||||
base_config.MAX_STREAM_RECEIVE_BUFFER = 128 * 1024 # 128KB
|
||||
base_config.STREAM_MEMORY_LIMIT_PER_STREAM = 512 * 1024 # 512KB
|
||||
base_config.STREAM_PROCESSING_CONCURRENCY = 50
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown use case: {use_case}")
|
||||
|
||||
return base_config
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,35 +1,393 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
"""
|
||||
QUIC transport specific exceptions.
|
||||
QUIC Transport exceptions for py-libp2p.
|
||||
Comprehensive error handling for QUIC transport, connection, and stream operations.
|
||||
Based on patterns from go-libp2p and js-libp2p implementations.
|
||||
"""
|
||||
|
||||
from libp2p.exceptions import (
|
||||
BaseLibp2pError,
|
||||
)
|
||||
|
||||
class QUICError(Exception):
|
||||
"""Base exception for all QUIC transport errors."""
|
||||
|
||||
def __init__(self, message: str, error_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class QUICError(BaseLibp2pError):
|
||||
"""Base exception for QUIC transport errors."""
|
||||
# Transport-level exceptions
|
||||
|
||||
|
||||
class QUICDialError(QUICError):
|
||||
"""Exception raised when QUIC dial operation fails."""
|
||||
class QUICTransportError(QUICError):
|
||||
"""Base exception for QUIC transport operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICListenError(QUICError):
|
||||
"""Exception raised when QUIC listen operation fails."""
|
||||
class QUICDialError(QUICTransportError):
|
||||
"""Error occurred during QUIC connection establishment."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICListenError(QUICTransportError):
|
||||
"""Error occurred during QUIC listener operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICSecurityError(QUICTransportError):
|
||||
"""Error related to QUIC security/TLS operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Connection-level exceptions
|
||||
|
||||
|
||||
class QUICConnectionError(QUICError):
|
||||
"""Exception raised for QUIC connection errors."""
|
||||
"""Base exception for QUIC connection operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionClosedError(QUICConnectionError):
|
||||
"""QUIC connection has been closed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionTimeoutError(QUICConnectionError):
|
||||
"""QUIC connection operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICHandshakeError(QUICConnectionError):
|
||||
"""Error during QUIC handshake process."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICPeerVerificationError(QUICConnectionError):
|
||||
"""Error verifying peer identity during handshake."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Stream-level exceptions
|
||||
|
||||
|
||||
class QUICStreamError(QUICError):
|
||||
"""Exception raised for QUIC stream errors."""
|
||||
"""Base exception for QUIC stream operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
):
|
||||
super().__init__(message, error_code)
|
||||
self.stream_id = stream_id
|
||||
|
||||
|
||||
class QUICStreamClosedError(QUICStreamError):
|
||||
"""Stream is closed and cannot be used for I/O operations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamResetError(QUICStreamError):
|
||||
"""Stream was reset by local or remote peer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
reset_by_peer: bool = False,
|
||||
):
|
||||
super().__init__(message, stream_id, error_code)
|
||||
self.reset_by_peer = reset_by_peer
|
||||
|
||||
|
||||
class QUICStreamTimeoutError(QUICStreamError):
|
||||
"""Stream operation timed out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamBackpressureError(QUICStreamError):
|
||||
"""Stream write blocked due to flow control."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamLimitError(QUICStreamError):
|
||||
"""Stream limit reached (too many concurrent streams)."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICStreamStateError(QUICStreamError):
|
||||
"""Invalid operation for current stream state."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
current_state: str | None = None,
|
||||
attempted_operation: str | None = None,
|
||||
):
|
||||
super().__init__(message, stream_id)
|
||||
self.current_state = current_state
|
||||
self.attempted_operation = attempted_operation
|
||||
|
||||
|
||||
# Flow control exceptions
|
||||
|
||||
|
||||
class QUICFlowControlError(QUICError):
|
||||
"""Base exception for flow control related errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlViolationError(QUICFlowControlError):
|
||||
"""Flow control limits were violated."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICFlowControlDeadlockError(QUICFlowControlError):
|
||||
"""Flow control deadlock detected."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Resource management exceptions
|
||||
|
||||
|
||||
class QUICResourceError(QUICError):
|
||||
"""Base exception for resource management errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICMemoryLimitError(QUICResourceError):
|
||||
"""Memory limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICConnectionLimitError(QUICResourceError):
|
||||
"""Connection limit exceeded."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Multiaddr and addressing exceptions
|
||||
|
||||
|
||||
class QUICAddressError(QUICError):
|
||||
"""Base exception for QUIC addressing errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICInvalidMultiaddrError(QUICAddressError):
|
||||
"""Invalid multiaddr format for QUIC transport."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICAddressResolutionError(QUICAddressError):
|
||||
"""Failed to resolve QUIC address."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICProtocolError(QUICError):
|
||||
"""Base exception for QUIC protocol errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICVersionNegotiationError(QUICProtocolError):
|
||||
"""QUIC version negotiation failed."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICUnsupportedVersionError(QUICProtocolError):
|
||||
"""Unsupported QUIC version."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Configuration exceptions
|
||||
|
||||
|
||||
class QUICConfigurationError(QUICError):
|
||||
"""Exception raised for QUIC configuration errors."""
|
||||
"""Base exception for QUIC configuration errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICSecurityError(QUICError):
|
||||
"""Exception raised for QUIC security/TLS errors."""
|
||||
class QUICInvalidConfigError(QUICConfigurationError):
|
||||
"""Invalid QUIC configuration parameters."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QUICCertificateError(QUICConfigurationError):
|
||||
"""Error with TLS certificate configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def map_quic_error_code(error_code: int) -> str:
|
||||
"""
|
||||
Map QUIC error codes to human-readable descriptions.
|
||||
Based on RFC 9000 Transport Error Codes.
|
||||
"""
|
||||
error_codes = {
|
||||
0x00: "NO_ERROR",
|
||||
0x01: "INTERNAL_ERROR",
|
||||
0x02: "CONNECTION_REFUSED",
|
||||
0x03: "FLOW_CONTROL_ERROR",
|
||||
0x04: "STREAM_LIMIT_ERROR",
|
||||
0x05: "STREAM_STATE_ERROR",
|
||||
0x06: "FINAL_SIZE_ERROR",
|
||||
0x07: "FRAME_ENCODING_ERROR",
|
||||
0x08: "TRANSPORT_PARAMETER_ERROR",
|
||||
0x09: "CONNECTION_ID_LIMIT_ERROR",
|
||||
0x0A: "PROTOCOL_VIOLATION",
|
||||
0x0B: "INVALID_TOKEN",
|
||||
0x0C: "APPLICATION_ERROR",
|
||||
0x0D: "CRYPTO_BUFFER_EXCEEDED",
|
||||
0x0E: "KEY_UPDATE_ERROR",
|
||||
0x0F: "AEAD_LIMIT_REACHED",
|
||||
0x10: "NO_VIABLE_PATH",
|
||||
}
|
||||
|
||||
return error_codes.get(error_code, f"UNKNOWN_ERROR_{error_code:02X}")
|
||||
|
||||
|
||||
def create_stream_error(
|
||||
error_type: str,
|
||||
message: str,
|
||||
stream_id: str | None = None,
|
||||
error_code: int | None = None,
|
||||
) -> QUICStreamError:
|
||||
"""
|
||||
Factory function to create appropriate stream error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "reset", "timeout", "backpressure", etc.)
|
||||
message: Error message
|
||||
stream_id: Stream identifier
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICStreamError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICStreamClosedError(message, stream_id, error_code)
|
||||
elif error_type == "reset":
|
||||
return QUICStreamResetError(message, stream_id, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICStreamTimeoutError(message, stream_id, error_code)
|
||||
elif error_type in ("backpressure", "flow_control"):
|
||||
return QUICStreamBackpressureError(message, stream_id, error_code)
|
||||
elif error_type in ("limit", "stream_limit"):
|
||||
return QUICStreamLimitError(message, stream_id, error_code)
|
||||
elif error_type == "state":
|
||||
return QUICStreamStateError(message, stream_id)
|
||||
else:
|
||||
return QUICStreamError(message, stream_id, error_code)
|
||||
|
||||
|
||||
def create_connection_error(
|
||||
error_type: str, message: str, error_code: int | None = None
|
||||
) -> QUICConnectionError:
|
||||
"""
|
||||
Factory function to create appropriate connection error based on type.
|
||||
|
||||
Args:
|
||||
error_type: Type of error ("closed", "timeout", "handshake", etc.)
|
||||
message: Error message
|
||||
error_code: QUIC error code
|
||||
|
||||
Returns:
|
||||
Appropriate QUICConnectionError subclass
|
||||
|
||||
"""
|
||||
error_type = error_type.lower()
|
||||
|
||||
if error_type in ("closed", "close"):
|
||||
return QUICConnectionClosedError(message, error_code)
|
||||
elif error_type == "timeout":
|
||||
return QUICConnectionTimeoutError(message, error_code)
|
||||
elif error_type == "handshake":
|
||||
return QUICHandshakeError(message, error_code)
|
||||
elif error_type in ("peer_verification", "verification"):
|
||||
return QUICPeerVerificationError(message, error_code)
|
||||
else:
|
||||
return QUICConnectionError(message, error_code)
|
||||
|
||||
|
||||
class QUICErrorContext:
|
||||
"""
|
||||
Context manager for handling QUIC errors with automatic error mapping.
|
||||
Useful for converting low-level aioquic errors to py-libp2p QUIC errors.
|
||||
"""
|
||||
|
||||
def __init__(self, operation: str, component: str = "quic") -> None:
|
||||
self.operation = operation
|
||||
self.component = component
|
||||
|
||||
def __enter__(self) -> "QUICErrorContext":
|
||||
return self
|
||||
|
||||
# TODO: Fix types for exc_type
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: Any,
|
||||
) -> Literal[False]:
|
||||
if exc_type is None:
|
||||
return False
|
||||
|
||||
if exc_val is None:
|
||||
return False
|
||||
|
||||
# Map common aioquic exceptions to our exceptions
|
||||
if "ConnectionClosed" in str(exc_type):
|
||||
raise QUICConnectionClosedError(
|
||||
f"Connection closed during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "StreamReset" in str(exc_type):
|
||||
raise QUICStreamResetError(
|
||||
f"Stream reset during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "timeout" in str(exc_val).lower():
|
||||
if "stream" in self.component.lower():
|
||||
raise QUICStreamTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
else:
|
||||
raise QUICConnectionTimeoutError(
|
||||
f"Timeout during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
elif "flow control" in str(exc_val).lower():
|
||||
raise QUICStreamBackpressureError(
|
||||
f"Flow control error during {self.operation}: {exc_val}"
|
||||
) from exc_val
|
||||
|
||||
# Let other exceptions propagate
|
||||
return False
|
||||
|
||||
@ -251,7 +251,7 @@ class QUICListener(IListener):
|
||||
connection._quic.receive_datagram(data, addr, now=time.time())
|
||||
|
||||
# Process events and handle responses
|
||||
await connection._process_events()
|
||||
await connection._process_quic_events()
|
||||
await connection._transmit()
|
||||
|
||||
except Exception as e:
|
||||
@ -386,8 +386,8 @@ class QUICListener(IListener):
|
||||
|
||||
# Start connection management tasks
|
||||
if self._nursery:
|
||||
self._nursery.start_soon(connection._handle_incoming_data)
|
||||
self._nursery.start_soon(connection._handle_timer)
|
||||
self._nursery.start_soon(connection._handle_datagram_received)
|
||||
self._nursery.start_soon(connection._handle_timer_events)
|
||||
|
||||
# TODO: Verify peer identity
|
||||
# await connection.verify_peer_identity()
|
||||
|
||||
@ -1,126 +1,583 @@
|
||||
"""
|
||||
QUIC Stream implementation
|
||||
QUIC Stream implementation for py-libp2p Module 3.
|
||||
Based on patterns from go-libp2p and js-libp2p QUIC implementations.
|
||||
Uses aioquic's native stream capabilities with libp2p interface compliance.
|
||||
"""
|
||||
|
||||
from types import (
|
||||
TracebackType,
|
||||
)
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from enum import Enum
|
||||
import logging
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import trio
|
||||
|
||||
from .exceptions import (
|
||||
QUICStreamBackpressureError,
|
||||
QUICStreamClosedError,
|
||||
QUICStreamResetError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libp2p.abc import IMuxedStream
|
||||
from libp2p.custom_types import TProtocol
|
||||
|
||||
from .connection import QUICConnection
|
||||
else:
|
||||
IMuxedStream = cast(type, object)
|
||||
TProtocol = cast(type, object)
|
||||
|
||||
from .exceptions import (
|
||||
QUICStreamError,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamState(Enum):
|
||||
"""Stream lifecycle states following libp2p patterns."""
|
||||
|
||||
OPEN = "open"
|
||||
WRITE_CLOSED = "write_closed"
|
||||
READ_CLOSED = "read_closed"
|
||||
CLOSED = "closed"
|
||||
RESET = "reset"
|
||||
|
||||
|
||||
class StreamDirection(Enum):
|
||||
"""Stream direction for tracking initiator."""
|
||||
|
||||
INBOUND = "inbound"
|
||||
OUTBOUND = "outbound"
|
||||
|
||||
|
||||
class StreamTimeline:
|
||||
"""Track stream lifecycle events for debugging and monitoring."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.created_at = time.time()
|
||||
self.opened_at: float | None = None
|
||||
self.first_data_at: float | None = None
|
||||
self.closed_at: float | None = None
|
||||
self.reset_at: float | None = None
|
||||
self.error_code: int | None = None
|
||||
|
||||
def record_open(self) -> None:
|
||||
self.opened_at = time.time()
|
||||
|
||||
def record_first_data(self) -> None:
|
||||
if self.first_data_at is None:
|
||||
self.first_data_at = time.time()
|
||||
|
||||
def record_close(self) -> None:
|
||||
self.closed_at = time.time()
|
||||
|
||||
def record_reset(self, error_code: int) -> None:
|
||||
self.reset_at = time.time()
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class QUICStream(IMuxedStream):
|
||||
"""
|
||||
Basic QUIC stream implementation for Module 1.
|
||||
QUIC Stream implementation following libp2p IMuxedStream interface.
|
||||
|
||||
This is a minimal implementation to make Module 1 self-contained.
|
||||
Will be moved to a separate stream.py module in Module 3.
|
||||
Based on patterns from go-libp2p and js-libp2p, this implementation:
|
||||
- Leverages QUIC's native multiplexing and flow control
|
||||
- Integrates with libp2p resource management
|
||||
- Provides comprehensive error handling with QUIC-specific codes
|
||||
- Supports bidirectional communication with independent close semantics
|
||||
- Implements proper stream lifecycle management
|
||||
"""
|
||||
|
||||
# Configuration constants based on research
|
||||
DEFAULT_READ_TIMEOUT = 30.0 # 30 seconds
|
||||
DEFAULT_WRITE_TIMEOUT = 30.0 # 30 seconds
|
||||
FLOW_CONTROL_WINDOW_SIZE = 512 * 1024 # 512KB per stream
|
||||
MAX_RECEIVE_BUFFER_SIZE = 1024 * 1024 # 1MB max buffering
|
||||
|
||||
def __init__(
|
||||
self, connection: "QUICConnection", stream_id: int, is_initiator: bool
|
||||
self,
|
||||
connection: "QUICConnection",
|
||||
stream_id: int,
|
||||
direction: StreamDirection,
|
||||
remote_addr: tuple[str, int],
|
||||
resource_scope: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize QUIC stream.
|
||||
|
||||
Args:
|
||||
connection: Parent QUIC connection
|
||||
stream_id: QUIC stream identifier
|
||||
direction: Stream direction (inbound/outbound)
|
||||
resource_scope: Resource manager scope for memory accounting
|
||||
remote_addr: Remote addr stream is connected to
|
||||
|
||||
"""
|
||||
self._connection = connection
|
||||
self._stream_id = stream_id
|
||||
self._is_initiator = is_initiator
|
||||
self._closed = False
|
||||
self._direction = direction
|
||||
self._resource_scope = resource_scope
|
||||
|
||||
# Trio synchronization
|
||||
# libp2p interface compliance
|
||||
self._protocol: TProtocol | None = None
|
||||
self._metadata: dict[str, Any] = {}
|
||||
self._remote_addr = remote_addr
|
||||
|
||||
# Stream state management
|
||||
self._state = StreamState.OPEN
|
||||
self._state_lock = trio.Lock()
|
||||
|
||||
# Flow control and buffering
|
||||
self._receive_buffer = bytearray()
|
||||
self._receive_buffer_lock = trio.Lock()
|
||||
self._receive_event = trio.Event()
|
||||
self._backpressure_event = trio.Event()
|
||||
self._backpressure_event.set() # Initially no backpressure
|
||||
|
||||
# Close/reset state
|
||||
self._write_closed = False
|
||||
self._read_closed = False
|
||||
self._close_event = trio.Event()
|
||||
self._reset_error_code: int | None = None
|
||||
|
||||
async def read(self, n: int | None = -1) -> bytes:
|
||||
"""Read data from the stream."""
|
||||
if self._closed:
|
||||
raise QUICStreamError("Stream is closed")
|
||||
# Lifecycle tracking
|
||||
self._timeline = StreamTimeline()
|
||||
self._timeline.record_open()
|
||||
|
||||
# Wait for data if buffer is empty
|
||||
while not self._receive_buffer and not self._closed:
|
||||
await self._receive_event.wait()
|
||||
self._receive_event = trio.Event() # Reset for next read
|
||||
# Resource accounting
|
||||
self._memory_reserved = 0
|
||||
if self._resource_scope:
|
||||
self._reserve_memory(self.FLOW_CONTROL_WINDOW_SIZE)
|
||||
|
||||
logger.debug(
|
||||
f"Created QUIC stream {stream_id} "
|
||||
f"({direction.value}, connection: {connection.remote_peer_id()})"
|
||||
)
|
||||
|
||||
# Properties for libp2p interface compliance
|
||||
|
||||
@property
|
||||
def protocol(self) -> TProtocol | None:
|
||||
"""Get the protocol identifier for this stream."""
|
||||
return self._protocol
|
||||
|
||||
@protocol.setter
|
||||
def protocol(self, protocol_id: TProtocol) -> None:
|
||||
"""Set the protocol identifier for this stream."""
|
||||
self._protocol = protocol_id
|
||||
self._metadata["protocol"] = protocol_id
|
||||
logger.debug(f"Stream {self.stream_id} protocol set to: {protocol_id}")
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str:
|
||||
"""Get stream ID as string for libp2p compatibility."""
|
||||
return str(self._stream_id)
|
||||
|
||||
@property
|
||||
def muxed_conn(self) -> "QUICConnection": # type: ignore
|
||||
"""Get the parent muxed connection."""
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
def state(self) -> StreamState:
|
||||
"""Get current stream state."""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def direction(self) -> StreamDirection:
|
||||
"""Get stream direction."""
|
||||
return self._direction
|
||||
|
||||
@property
|
||||
def is_initiator(self) -> bool:
|
||||
"""Check if this stream was locally initiated."""
|
||||
return self._direction == StreamDirection.OUTBOUND
|
||||
|
||||
# Core stream operations
|
||||
|
||||
async def read(self, n: int | None = None) -> bytes:
|
||||
"""
|
||||
Read data from the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
n: Maximum number of bytes to read. If None or -1, read all available.
|
||||
|
||||
Returns:
|
||||
Data read from stream
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed
|
||||
QUICStreamResetError: Stream was reset
|
||||
QUICStreamTimeoutError: Read timeout exceeded
|
||||
|
||||
"""
|
||||
if n is None:
|
||||
n = -1
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._read_closed:
|
||||
# Return any remaining buffered data, then EOF
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
return b""
|
||||
|
||||
# Wait for data with timeout
|
||||
timeout = self.DEFAULT_READ_TIMEOUT
|
||||
try:
|
||||
with trio.move_on_after(timeout) as cancel_scope:
|
||||
while True:
|
||||
async with self._receive_buffer_lock:
|
||||
if self._receive_buffer:
|
||||
data = self._extract_data_from_buffer(n)
|
||||
self._timeline.record_first_data()
|
||||
return data
|
||||
|
||||
# Check if stream was closed while waiting
|
||||
if self._read_closed:
|
||||
return b""
|
||||
|
||||
# Wait for more data
|
||||
await self._receive_event.wait()
|
||||
self._receive_event = trio.Event() # Reset for next wait
|
||||
|
||||
if cancel_scope.cancelled_caught:
|
||||
raise QUICStreamTimeoutError(f"Read timeout on stream {self.stream_id}")
|
||||
|
||||
return b""
|
||||
except QUICStreamResetError:
|
||||
# Stream was reset while reading
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading from stream {self.stream_id}: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""
|
||||
Write data to the stream with QUIC flow control.
|
||||
|
||||
Args:
|
||||
data: Data to write
|
||||
|
||||
Raises:
|
||||
QUICStreamClosedError: Stream is closed for writing
|
||||
QUICStreamBackpressureError: Flow control window exhausted
|
||||
QUICStreamResetError: Stream was reset
|
||||
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
raise QUICStreamClosedError(f"Stream {self.stream_id} is closed")
|
||||
|
||||
if self._write_closed:
|
||||
raise QUICStreamClosedError(
|
||||
f"Stream {self.stream_id} write side is closed"
|
||||
)
|
||||
|
||||
try:
|
||||
# Handle flow control backpressure
|
||||
await self._backpressure_event.wait()
|
||||
|
||||
# Send data through QUIC connection
|
||||
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._timeline.record_first_data()
|
||||
logger.debug(f"Wrote {len(data)} bytes to stream {self.stream_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to stream {self.stream_id}: {e}")
|
||||
# Convert QUIC-specific errors
|
||||
if "flow control" in str(e).lower():
|
||||
raise QUICStreamBackpressureError(f"Flow control limit reached: {e}")
|
||||
await self._handle_stream_error(e)
|
||||
raise
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the stream gracefully (both read and write sides).
|
||||
|
||||
This implements proper close semantics where both sides
|
||||
are closed and resources are cleaned up.
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state in (StreamState.CLOSED, StreamState.RESET):
|
||||
return
|
||||
|
||||
logger.debug(f"Closing stream {self.stream_id}")
|
||||
|
||||
# Close both sides
|
||||
if not self._write_closed:
|
||||
await self.close_write()
|
||||
if not self._read_closed:
|
||||
await self.close_read()
|
||||
|
||||
# Update state and cleanup
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.CLOSED
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_close()
|
||||
self._close_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} closed")
|
||||
|
||||
async def close_write(self) -> None:
|
||||
"""Close the write side of the stream."""
|
||||
if self._write_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
# Send FIN to close write side
|
||||
self._connection._quic.send_stream_data(
|
||||
self._stream_id, b"", end_stream=True
|
||||
)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._write_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._read_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.WRITE_CLOSED
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} write side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing write side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def close_read(self) -> None:
|
||||
"""Close the read side of the stream."""
|
||||
if self._read_closed:
|
||||
return
|
||||
|
||||
try:
|
||||
# Signal read closure to QUIC layer
|
||||
self._connection._quic.reset_stream(self._stream_id, error_code=0)
|
||||
await self._connection._transmit()
|
||||
|
||||
self._read_closed = True
|
||||
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up any pending reads
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} read side closed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing read side of stream {self.stream_id}: {e}")
|
||||
|
||||
async def reset(self, error_code: int = 0) -> None:
|
||||
"""
|
||||
Reset the stream with the given error code.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code for the reset
|
||||
|
||||
"""
|
||||
async with self._state_lock:
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Resetting stream {self.stream_id} with error code {error_code}"
|
||||
)
|
||||
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
try:
|
||||
# Send QUIC reset frame
|
||||
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||
await self._connection._transmit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending reset for stream {self.stream_id}: {e}")
|
||||
finally:
|
||||
# Always cleanup resources
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if stream is completely closed."""
|
||||
return self._state in (StreamState.CLOSED, StreamState.RESET)
|
||||
|
||||
def is_reset(self) -> bool:
|
||||
"""Check if stream was reset."""
|
||||
return self._state == StreamState.RESET
|
||||
|
||||
def can_read(self) -> bool:
|
||||
"""Check if stream can be read from."""
|
||||
return not self._read_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
def can_write(self) -> bool:
|
||||
"""Check if stream can be written to."""
|
||||
return not self._write_closed and self._state not in (
|
||||
StreamState.CLOSED,
|
||||
StreamState.RESET,
|
||||
)
|
||||
|
||||
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||
"""
|
||||
Handle data received from the QUIC connection.
|
||||
|
||||
Args:
|
||||
data: Received data
|
||||
end_stream: Whether this is the last data (FIN received)
|
||||
|
||||
"""
|
||||
if self._state == StreamState.RESET:
|
||||
return
|
||||
|
||||
if data:
|
||||
async with self._receive_buffer_lock:
|
||||
if len(self._receive_buffer) + len(data) > self.MAX_RECEIVE_BUFFER_SIZE:
|
||||
logger.warning(
|
||||
f"Stream {self.stream_id} receive buffer overflow, "
|
||||
f"dropping {len(data)} bytes"
|
||||
)
|
||||
return
|
||||
|
||||
self._receive_buffer.extend(data)
|
||||
self._timeline.record_first_data()
|
||||
|
||||
# Notify waiting readers
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received {len(data)} bytes")
|
||||
|
||||
if end_stream:
|
||||
self._read_closed = True
|
||||
async with self._state_lock:
|
||||
if self._write_closed:
|
||||
self._state = StreamState.CLOSED
|
||||
else:
|
||||
self._state = StreamState.READ_CLOSED
|
||||
|
||||
# Wake up readers to process remaining data and EOF
|
||||
self._receive_event.set()
|
||||
|
||||
logger.debug(f"Stream {self.stream_id} received FIN")
|
||||
|
||||
async def handle_reset(self, error_code: int) -> None:
|
||||
"""
|
||||
Handle stream reset from remote peer.
|
||||
|
||||
Args:
|
||||
error_code: QUIC error code from reset frame
|
||||
|
||||
"""
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} reset by peer with error code {error_code}"
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
self._state = StreamState.RESET
|
||||
self._reset_error_code = error_code
|
||||
|
||||
await self._cleanup_resources()
|
||||
self._timeline.record_reset(error_code)
|
||||
self._close_event.set()
|
||||
|
||||
# Wake up any pending operations
|
||||
self._receive_event.set()
|
||||
self._backpressure_event.set()
|
||||
|
||||
async def handle_flow_control_update(self, available_window: int) -> None:
|
||||
"""
|
||||
Handle flow control window updates.
|
||||
|
||||
Args:
|
||||
available_window: Available flow control window size
|
||||
|
||||
"""
|
||||
if available_window > 0:
|
||||
self._backpressure_event.set()
|
||||
logger.debug(
|
||||
f"Stream {self.stream_id} flow control".__add__(
|
||||
f"window updated: {available_window}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._backpressure_event = trio.Event() # Reset to blocking state
|
||||
logger.debug(f"Stream {self.stream_id} flow control window exhausted")
|
||||
|
||||
def _extract_data_from_buffer(self, n: int) -> bytes:
|
||||
"""Extract data from receive buffer with specified limit."""
|
||||
if n == -1:
|
||||
# Read all available data
|
||||
data = bytes(self._receive_buffer)
|
||||
self._receive_buffer.clear()
|
||||
else:
|
||||
# Read up to n bytes
|
||||
data = bytes(self._receive_buffer[:n])
|
||||
self._receive_buffer = self._receive_buffer[n:]
|
||||
|
||||
return data
|
||||
|
||||
async def write(self, data: bytes) -> None:
|
||||
"""Write data to the stream."""
|
||||
if self._closed:
|
||||
raise QUICStreamError("Stream is closed")
|
||||
async def _handle_stream_error(self, error: Exception) -> None:
|
||||
"""Handle errors by resetting the stream."""
|
||||
logger.error(f"Stream {self.stream_id} error: {error}")
|
||||
await self.reset(error_code=1) # Generic error code
|
||||
|
||||
# Send data using the underlying QUIC connection
|
||||
self._connection._quic.send_stream_data(self._stream_id, data)
|
||||
await self._connection._transmit()
|
||||
def _reserve_memory(self, size: int) -> None:
|
||||
"""Reserve memory with resource manager."""
|
||||
if self._resource_scope:
|
||||
try:
|
||||
self._resource_scope.reserve_memory(size)
|
||||
self._memory_reserved += size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to reserve memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
async def close(self, error_code: int = 0) -> None:
|
||||
"""Close the stream."""
|
||||
if self._closed:
|
||||
return
|
||||
def _release_memory(self, size: int) -> None:
|
||||
"""Release memory with resource manager."""
|
||||
if self._resource_scope and size > 0:
|
||||
try:
|
||||
self._resource_scope.release_memory(size)
|
||||
self._memory_reserved = max(0, self._memory_reserved - size)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to release memory for stream {self.stream_id}: {e}"
|
||||
)
|
||||
|
||||
self._closed = True
|
||||
async def _cleanup_resources(self) -> None:
|
||||
"""Clean up stream resources."""
|
||||
# Release all reserved memory
|
||||
if self._memory_reserved > 0:
|
||||
self._release_memory(self._memory_reserved)
|
||||
|
||||
# Close the QUIC stream
|
||||
self._connection._quic.reset_stream(self._stream_id, error_code)
|
||||
await self._connection._transmit()
|
||||
# Clear receive buffer
|
||||
async with self._receive_buffer_lock:
|
||||
self._receive_buffer.clear()
|
||||
|
||||
# Remove from connection's stream list
|
||||
self._connection._streams.pop(self._stream_id, None)
|
||||
# Remove from connection's stream registry
|
||||
self._connection._remove_stream(self._stream_id)
|
||||
|
||||
self._close_event.set()
|
||||
logger.debug(f"Stream {self.stream_id} resources cleaned up")
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
"""Check if stream is closed."""
|
||||
return self._closed
|
||||
# Abstact implementations
|
||||
|
||||
async def handle_data_received(self, data: bytes, end_stream: bool) -> None:
|
||||
"""Handle data received from the QUIC connection."""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._receive_buffer.extend(data)
|
||||
self._receive_event.set()
|
||||
|
||||
if end_stream:
|
||||
await self.close()
|
||||
|
||||
async def handle_reset(self, error_code: int) -> None:
|
||||
"""Handle stream reset."""
|
||||
self._closed = True
|
||||
self._close_event.set()
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
Set the deadline
|
||||
"""
|
||||
raise NotImplementedError("Yamux does not support setting read deadlines")
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""
|
||||
Reset the stream
|
||||
"""
|
||||
await self.handle_reset(0)
|
||||
return
|
||||
|
||||
def get_remote_address(self) -> tuple[str, int] | None:
|
||||
return self._connection._remote_addr
|
||||
def get_remote_address(self) -> tuple[str, int]:
|
||||
return self._remote_addr
|
||||
|
||||
async def __aenter__(self) -> "QUICStream":
|
||||
"""Enter the async context manager."""
|
||||
@ -134,3 +591,26 @@ class QUICStream(IMuxedStream):
|
||||
) -> None:
|
||||
"""Exit the async context manager and close the stream."""
|
||||
await self.close()
|
||||
|
||||
def set_deadline(self, ttl: int) -> bool:
|
||||
"""
|
||||
Set a deadline for the stream. QUIC does not support deadlines natively,
|
||||
so this method always returns False to indicate the operation is unsupported.
|
||||
|
||||
:param ttl: Time-to-live in seconds (ignored).
|
||||
:return: False, as deadlines are not supported.
|
||||
"""
|
||||
raise NotImplementedError("QUIC does not support setting read deadlines")
|
||||
|
||||
# String representation for debugging
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"QUICStream(id={self.stream_id}, "
|
||||
f"state={self._state.value}, "
|
||||
f"direction={self._direction.value}, "
|
||||
f"protocol={self._protocol})"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"QUICStream({self.stream_id})"
|
||||
|
||||
@ -1,20 +1,43 @@
|
||||
from unittest.mock import (
|
||||
Mock,
|
||||
)
|
||||
"""
|
||||
Enhanced tests for QUIC connection functionality - Module 3.
|
||||
Tests all new features including advanced stream management, resource management,
|
||||
error handling, and concurrent operations.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from multiaddr.multiaddr import Multiaddr
|
||||
import trio
|
||||
|
||||
from libp2p.crypto.ed25519 import (
|
||||
create_new_key_pair,
|
||||
)
|
||||
from libp2p.crypto.ed25519 import create_new_key_pair
|
||||
from libp2p.peer.id import ID
|
||||
from libp2p.transport.quic.connection import QUICConnection
|
||||
from libp2p.transport.quic.exceptions import QUICStreamError
|
||||
from libp2p.transport.quic.exceptions import (
|
||||
QUICConnectionClosedError,
|
||||
QUICConnectionError,
|
||||
QUICConnectionTimeoutError,
|
||||
QUICStreamLimitError,
|
||||
QUICStreamTimeoutError,
|
||||
)
|
||||
from libp2p.transport.quic.stream import QUICStream, StreamDirection
|
||||
|
||||
|
||||
class TestQUICConnection:
|
||||
"""Test suite for QUIC connection functionality."""
|
||||
class MockResourceScope:
|
||||
"""Mock resource scope for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.memory_reserved = 0
|
||||
|
||||
def reserve_memory(self, size):
|
||||
self.memory_reserved += size
|
||||
|
||||
def release_memory(self, size):
|
||||
self.memory_reserved = max(0, self.memory_reserved - size)
|
||||
|
||||
|
||||
class TestQUICConnectionEnhanced:
|
||||
"""Enhanced test suite for QUIC connection functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_quic_connection(self):
|
||||
@ -23,11 +46,20 @@ class TestQUICConnection:
|
||||
mock.next_event.return_value = None
|
||||
mock.datagrams_to_send.return_value = []
|
||||
mock.get_timer.return_value = None
|
||||
mock.connect = Mock()
|
||||
mock.close = Mock()
|
||||
mock.send_stream_data = Mock()
|
||||
mock.reset_stream = Mock()
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(self, mock_quic_connection):
|
||||
"""Create test QUIC connection."""
|
||||
def mock_resource_scope(self):
|
||||
"""Create mock resource scope."""
|
||||
return MockResourceScope()
|
||||
|
||||
@pytest.fixture
|
||||
def quic_connection(self, mock_quic_connection, mock_resource_scope):
|
||||
"""Create test QUIC connection with enhanced features."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
@ -39,18 +71,44 @@ class TestQUICConnection:
|
||||
is_initiator=True,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
resource_scope=mock_resource_scope,
|
||||
)
|
||||
|
||||
def test_connection_initialization(self, quic_connection):
|
||||
"""Test connection initialization."""
|
||||
@pytest.fixture
|
||||
def server_connection(self, mock_quic_connection, mock_resource_scope):
|
||||
"""Create server-side QUIC connection."""
|
||||
private_key = create_new_key_pair().private_key
|
||||
peer_id = ID.from_pubkey(private_key.get_public_key())
|
||||
|
||||
return QUICConnection(
|
||||
quic_connection=mock_quic_connection,
|
||||
remote_addr=("127.0.0.1", 4001),
|
||||
peer_id=peer_id,
|
||||
local_peer_id=peer_id,
|
||||
is_initiator=False,
|
||||
maddr=Multiaddr("/ip4/127.0.0.1/udp/4001/quic"),
|
||||
transport=Mock(),
|
||||
resource_scope=mock_resource_scope,
|
||||
)
|
||||
|
||||
# Basic functionality tests
|
||||
|
||||
def test_connection_initialization_enhanced(
|
||||
self, quic_connection, mock_resource_scope
|
||||
):
|
||||
"""Test enhanced connection initialization."""
|
||||
assert quic_connection._remote_addr == ("127.0.0.1", 4001)
|
||||
assert quic_connection.is_initiator is True
|
||||
assert not quic_connection.is_closed
|
||||
assert not quic_connection.is_established
|
||||
assert len(quic_connection._streams) == 0
|
||||
assert quic_connection._resource_scope == mock_resource_scope
|
||||
assert quic_connection._outbound_stream_count == 0
|
||||
assert quic_connection._inbound_stream_count == 0
|
||||
assert len(quic_connection._stream_accept_queue) == 0
|
||||
|
||||
def test_stream_id_calculation(self):
|
||||
"""Test stream ID calculation for client/server."""
|
||||
def test_stream_id_calculation_enhanced(self):
|
||||
"""Test enhanced stream ID calculation for client/server."""
|
||||
# Client connection (initiator)
|
||||
client_conn = QUICConnection(
|
||||
quic_connection=Mock(),
|
||||
@ -75,45 +133,364 @@ class TestQUICConnection:
|
||||
)
|
||||
assert server_conn._next_stream_id == 1 # Server starts with 1
|
||||
|
||||
def test_incoming_stream_detection(self, quic_connection):
|
||||
"""Test incoming stream detection logic."""
|
||||
def test_incoming_stream_detection_enhanced(self, quic_connection):
|
||||
"""Test enhanced incoming stream detection logic."""
|
||||
# For client (initiator), odd stream IDs are incoming
|
||||
assert quic_connection._is_incoming_stream(1) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(0) is False # Client-initiated
|
||||
assert quic_connection._is_incoming_stream(5) is True # Server-initiated
|
||||
assert quic_connection._is_incoming_stream(4) is False # Client-initiated
|
||||
|
||||
# Stream management tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_stats(self, quic_connection):
|
||||
"""Test connection statistics."""
|
||||
stats = quic_connection.get_stats()
|
||||
async def test_open_stream_basic(self, quic_connection):
|
||||
"""Test basic stream opening."""
|
||||
quic_connection._started = True
|
||||
|
||||
stream = await quic_connection.open_stream()
|
||||
|
||||
assert isinstance(stream, QUICStream)
|
||||
assert stream.stream_id == "0"
|
||||
assert stream.direction == StreamDirection.OUTBOUND
|
||||
assert 0 in quic_connection._streams
|
||||
assert quic_connection._outbound_stream_count == 1
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_limit_reached(self, quic_connection):
|
||||
"""Test stream limit enforcement."""
|
||||
quic_connection._started = True
|
||||
quic_connection._outbound_stream_count = quic_connection.MAX_OUTGOING_STREAMS
|
||||
|
||||
with pytest.raises(QUICStreamLimitError, match="Maximum outbound streams"):
|
||||
await quic_connection.open_stream()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_open_stream_timeout(self, quic_connection: QUICConnection):
|
||||
"""Test stream opening timeout."""
|
||||
quic_connection._started = True
|
||||
return
|
||||
|
||||
# Mock the stream ID lock to simulate slow operation
|
||||
async def slow_acquire():
|
||||
await trio.sleep(10) # Longer than timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection._stream_id_lock, "acquire", side_effect=slow_acquire
|
||||
):
|
||||
with pytest.raises(
|
||||
QUICStreamTimeoutError, match="Stream creation timed out"
|
||||
):
|
||||
await quic_connection.open_stream(timeout=0.1)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_basic(self, quic_connection):
|
||||
"""Test basic stream acceptance."""
|
||||
# Create a mock inbound stream
|
||||
mock_stream = Mock(spec=QUICStream)
|
||||
mock_stream.stream_id = "1"
|
||||
|
||||
# Add to accept queue
|
||||
quic_connection._stream_accept_queue.append(mock_stream)
|
||||
quic_connection._stream_accept_event.set()
|
||||
|
||||
accepted_stream = await quic_connection.accept_stream(timeout=0.1)
|
||||
|
||||
assert accepted_stream == mock_stream
|
||||
assert len(quic_connection._stream_accept_queue) == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_timeout(self, quic_connection):
|
||||
"""Test stream acceptance timeout."""
|
||||
with pytest.raises(QUICStreamTimeoutError, match="Stream accept timed out"):
|
||||
await quic_connection.accept_stream(timeout=0.1)
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_accept_stream_on_closed_connection(self, quic_connection):
|
||||
"""Test stream acceptance on closed connection."""
|
||||
await quic_connection.close()
|
||||
|
||||
with pytest.raises(QUICConnectionClosedError, match="Connection is closed"):
|
||||
await quic_connection.accept_stream()
|
||||
|
||||
# Stream handler tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_handler_setting(self, quic_connection):
|
||||
"""Test setting stream handler."""
|
||||
|
||||
async def mock_handler(stream):
|
||||
pass
|
||||
|
||||
quic_connection.set_stream_handler(mock_handler)
|
||||
assert quic_connection._stream_handler == mock_handler
|
||||
|
||||
# Connection lifecycle tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_client(self, quic_connection):
|
||||
"""Test client connection start."""
|
||||
with patch.object(
|
||||
quic_connection, "_initiate_connection", new_callable=AsyncMock
|
||||
) as mock_initiate:
|
||||
await quic_connection.start()
|
||||
|
||||
assert quic_connection._started
|
||||
mock_initiate.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_server(self, server_connection):
|
||||
"""Test server connection start."""
|
||||
await server_connection.start()
|
||||
|
||||
assert server_connection._started
|
||||
assert server_connection._established
|
||||
assert server_connection._connected_event.is_set()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_already_started(self, quic_connection):
|
||||
"""Test starting already started connection."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Should not raise error, just log warning
|
||||
await quic_connection.start()
|
||||
assert quic_connection._started
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_start_closed(self, quic_connection):
|
||||
"""Test starting closed connection."""
|
||||
quic_connection._closed = True
|
||||
|
||||
with pytest.raises(
|
||||
QUICConnectionError, match="Cannot start a closed connection"
|
||||
):
|
||||
await quic_connection.start()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_connect_with_nursery(self, quic_connection):
|
||||
"""Test connection establishment with nursery."""
|
||||
quic_connection._started = True
|
||||
quic_connection._established = True
|
||||
quic_connection._connected_event.set()
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
) as mock_start_tasks:
|
||||
with patch.object(
|
||||
quic_connection, "verify_peer_identity", new_callable=AsyncMock
|
||||
) as mock_verify:
|
||||
async with trio.open_nursery() as nursery:
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
assert quic_connection._nursery == nursery
|
||||
mock_start_tasks.assert_called_once()
|
||||
mock_verify.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_connect_timeout(self, quic_connection: QUICConnection):
|
||||
"""Test connection establishment timeout."""
|
||||
quic_connection._started = True
|
||||
# Don't set connected event to simulate timeout
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "_start_background_tasks", new_callable=AsyncMock
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
with pytest.raises(
|
||||
QUICConnectionTimeoutError, match="Connection handshake timed out"
|
||||
):
|
||||
await quic_connection.connect(nursery)
|
||||
|
||||
# Resource management tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_removal_resource_cleanup(
|
||||
self, quic_connection: QUICConnection, mock_resource_scope
|
||||
):
|
||||
"""Test stream removal and resource cleanup."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create a stream
|
||||
stream = await quic_connection.open_stream()
|
||||
|
||||
# Remove the stream
|
||||
quic_connection._remove_stream(int(stream.stream_id))
|
||||
|
||||
assert int(stream.stream_id) not in quic_connection._streams
|
||||
# Note: Count updates is async, so we can't test it directly here
|
||||
|
||||
# Error handling tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_error_handling(self, quic_connection):
|
||||
"""Test connection error handling."""
|
||||
error = Exception("Test error")
|
||||
|
||||
with patch.object(
|
||||
quic_connection, "close", new_callable=AsyncMock
|
||||
) as mock_close:
|
||||
await quic_connection._handle_connection_error(error)
|
||||
mock_close.assert_called_once()
|
||||
|
||||
# Statistics and monitoring tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_stats_enhanced(self, quic_connection):
|
||||
"""Test enhanced connection statistics."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create some streams
|
||||
_stream1 = await quic_connection.open_stream()
|
||||
_stream2 = await quic_connection.open_stream()
|
||||
|
||||
stats = quic_connection.get_stream_stats()
|
||||
|
||||
expected_keys = [
|
||||
"peer_id",
|
||||
"remote_addr",
|
||||
"is_initiator",
|
||||
"is_established",
|
||||
"is_closed",
|
||||
"active_streams",
|
||||
"next_stream_id",
|
||||
"total_streams",
|
||||
"outbound_streams",
|
||||
"inbound_streams",
|
||||
"max_streams",
|
||||
"stream_utilization",
|
||||
"stats",
|
||||
]
|
||||
|
||||
for key in expected_keys:
|
||||
assert key in stats
|
||||
|
||||
assert stats["total_streams"] == 2
|
||||
assert stats["outbound_streams"] == 2
|
||||
assert stats["inbound_streams"] == 0
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_close(self, quic_connection):
|
||||
"""Test connection close functionality."""
|
||||
assert not quic_connection.is_closed
|
||||
async def test_get_active_streams(self, quic_connection):
|
||||
"""Test getting active streams."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create streams
|
||||
stream1 = await quic_connection.open_stream()
|
||||
stream2 = await quic_connection.open_stream()
|
||||
|
||||
active_streams = quic_connection.get_active_streams()
|
||||
|
||||
assert len(active_streams) == 2
|
||||
assert stream1 in active_streams
|
||||
assert stream2 in active_streams
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_get_streams_by_protocol(self, quic_connection):
|
||||
"""Test getting streams by protocol."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create streams with different protocols
|
||||
stream1 = await quic_connection.open_stream()
|
||||
stream1.protocol = "/test/1.0.0"
|
||||
|
||||
stream2 = await quic_connection.open_stream()
|
||||
stream2.protocol = "/other/1.0.0"
|
||||
|
||||
test_streams = quic_connection.get_streams_by_protocol("/test/1.0.0")
|
||||
other_streams = quic_connection.get_streams_by_protocol("/other/1.0.0")
|
||||
|
||||
assert len(test_streams) == 1
|
||||
assert len(other_streams) == 1
|
||||
assert stream1 in test_streams
|
||||
assert stream2 in other_streams
|
||||
|
||||
# Enhanced close tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_connection_close_enhanced(self, quic_connection: QUICConnection):
|
||||
"""Test enhanced connection close with stream cleanup."""
|
||||
quic_connection._started = True
|
||||
|
||||
# Create some streams
|
||||
_stream1 = await quic_connection.open_stream()
|
||||
_stream2 = await quic_connection.open_stream()
|
||||
|
||||
await quic_connection.close()
|
||||
|
||||
assert quic_connection.is_closed
|
||||
assert len(quic_connection._streams) == 0
|
||||
|
||||
# Concurrent operations tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_stream_operations_on_closed_connection(self, quic_connection):
|
||||
"""Test stream operations on closed connection."""
|
||||
await quic_connection.close()
|
||||
async def test_concurrent_stream_operations(self, quic_connection):
|
||||
"""Test concurrent stream operations."""
|
||||
quic_connection._started = True
|
||||
|
||||
with pytest.raises(QUICStreamError, match="Connection is closed"):
|
||||
await quic_connection.open_stream()
|
||||
async def create_stream():
|
||||
return await quic_connection.open_stream()
|
||||
|
||||
# Create multiple streams concurrently
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(10):
|
||||
nursery.start_soon(create_stream)
|
||||
|
||||
# Wait a bit for all to start
|
||||
await trio.sleep(0.1)
|
||||
|
||||
# Should have created streams without conflicts
|
||||
assert quic_connection._outbound_stream_count == 10
|
||||
assert len(quic_connection._streams) == 10
|
||||
|
||||
# Connection properties tests
|
||||
|
||||
def test_connection_properties(self, quic_connection):
|
||||
"""Test connection property accessors."""
|
||||
assert quic_connection.multiaddr() == quic_connection._maddr
|
||||
assert quic_connection.local_peer_id() == quic_connection._local_peer_id
|
||||
assert quic_connection.remote_peer_id() == quic_connection._peer_id
|
||||
|
||||
# IRawConnection interface tests
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_raw_connection_write(self, quic_connection):
|
||||
"""Test raw connection write interface."""
|
||||
quic_connection._started = True
|
||||
|
||||
with patch.object(quic_connection, "open_stream") as mock_open:
|
||||
mock_stream = AsyncMock()
|
||||
mock_open.return_value = mock_stream
|
||||
|
||||
await quic_connection.write(b"test data")
|
||||
|
||||
mock_open.assert_called_once()
|
||||
mock_stream.write.assert_called_once_with(b"test data")
|
||||
mock_stream.close_write.assert_called_once()
|
||||
|
||||
@pytest.mark.trio
|
||||
async def test_raw_connection_read_not_implemented(self, quic_connection):
|
||||
"""Test raw connection read raises NotImplementedError."""
|
||||
with pytest.raises(NotImplementedError, match="Use muxed connection interface"):
|
||||
await quic_connection.read()
|
||||
|
||||
# String representation tests
|
||||
|
||||
def test_connection_string_representation(self, quic_connection):
|
||||
"""Test connection string representations."""
|
||||
repr_str = repr(quic_connection)
|
||||
str_str = str(quic_connection)
|
||||
|
||||
assert "QUICConnection" in repr_str
|
||||
assert str(quic_connection._peer_id) in repr_str
|
||||
assert str(quic_connection._remote_addr) in repr_str
|
||||
assert str(quic_connection._peer_id) in str_str
|
||||
|
||||
# Mock verification helpers
|
||||
|
||||
def test_mock_resource_scope_functionality(self, mock_resource_scope):
|
||||
"""Test mock resource scope works correctly."""
|
||||
assert mock_resource_scope.memory_reserved == 0
|
||||
|
||||
mock_resource_scope.reserve_memory(1000)
|
||||
assert mock_resource_scope.memory_reserved == 1000
|
||||
|
||||
mock_resource_scope.reserve_memory(500)
|
||||
assert mock_resource_scope.memory_reserved == 1500
|
||||
|
||||
mock_resource_scope.release_memory(600)
|
||||
assert mock_resource_scope.memory_reserved == 900
|
||||
|
||||
mock_resource_scope.release_memory(2000) # Should not go negative
|
||||
assert mock_resource_scope.memory_reserved == 0
|
||||
|
||||
Reference in New Issue
Block a user