Merge branch 'main' into async-validators

This commit is contained in:
Manu Sheel Gupta
2025-07-05 14:50:18 -07:00
committed by GitHub
4 changed files with 505 additions and 56 deletions

View File

@ -98,16 +98,32 @@ class YamuxStream(IMuxedStream):
# Flow control: Check if we have enough send window
total_len = len(data)
sent = 0
logging.debug(f"Stream {self.stream_id}: Starts writing {total_len} bytes ")
while sent < total_len:
# Wait for available window with timeout
timeout = False
async with self.window_lock:
# Wait for available window
while self.send_window == 0 and not self.closed:
# Release lock while waiting
if self.send_window == 0:
logging.debug(
f"Stream {self.stream_id}: Window is zero, waiting for update"
)
# Release lock and wait with timeout
self.window_lock.release()
await trio.sleep(0.01)
# To avoid re-acquiring the lock immediately,
with trio.move_on_after(5.0) as cancel_scope:
while self.send_window == 0 and not self.closed:
await trio.sleep(0.01)
# If we timed out, cancel the scope
timeout = cancel_scope.cancelled_caught
# Re-acquire lock
await self.window_lock.acquire()
# If we timed out waiting for window update, raise an error
if timeout:
raise MuxedStreamError(
"Timed out waiting for window update after 5 seconds."
)
if self.closed:
raise MuxedStreamError("Stream is closed")
@ -123,25 +139,45 @@ class YamuxStream(IMuxedStream):
await self.conn.secured_conn.write(header + chunk)
sent += to_send
# If window is getting low, consider updating
if self.send_window < DEFAULT_WINDOW_SIZE // 2:
await self.send_window_update()
async def send_window_update(self, increment: int | None = None) -> None:
"""Send a window update to peer."""
if increment is None:
increment = DEFAULT_WINDOW_SIZE - self.recv_window
async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
"""
Send a window update to peer.
param:increment: The amount to increment the window size by.
If None, uses the difference between DEFAULT_WINDOW_SIZE
and current receive window.
param:skip_lock (bool): If True, skips acquiring window_lock.
This should only be used when calling from a context
that already holds the lock.
"""
if increment <= 0:
# If increment is zero or negative, skip sending update
logging.debug(
f"Stream {self.stream_id}: Skipping window update"
f"(increment={increment})"
)
return
logging.debug(
f"Stream {self.stream_id}: Sending window update with increment={increment}"
)
async with self.window_lock:
self.recv_window += increment
async def _do_window_update() -> None:
header = struct.pack(
YAMUX_HEADER_FORMAT, 0, TYPE_WINDOW_UPDATE, 0, self.stream_id, increment
YAMUX_HEADER_FORMAT,
0,
TYPE_WINDOW_UPDATE,
0,
self.stream_id,
increment,
)
await self.conn.secured_conn.write(header)
if skip_lock:
await _do_window_update()
else:
async with self.window_lock:
await _do_window_update()
async def read(self, n: int | None = -1) -> bytes:
# Handle None value for n by converting it to -1
if n is None:
@ -154,55 +190,68 @@ class YamuxStream(IMuxedStream):
)
raise MuxedStreamEOF("Stream is closed for receiving")
# If reading until EOF (n == -1), block until stream is closed
if n == -1:
while not self.recv_closed and not self.conn.event_shutting_down.is_set():
data = b""
while not self.conn.event_shutting_down.is_set():
# Check if there's data in the buffer
buffer = self.conn.stream_buffers.get(self.stream_id)
if buffer and len(buffer) > 0:
# Wait for closure even if data is available
logging.debug(
f"Stream {self.stream_id}:Waiting for FIN before returning data"
)
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
else:
# No data, wait for data or closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop, check if stream is closed or shutting down
async with self.conn.streams_lock:
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
if self.closed:
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
else:
logging.debug(
f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
)
raise MuxedStreamEOF("Stream closed cleanly (EOF)")
buffer = self.conn.stream_buffers.get(self.stream_id)
# If buffer is not available, check if stream is closed
if buffer is None:
logging.debug(
f"Stream {self.stream_id}: Buffer gone, assuming closed"
)
logging.debug(f"Stream {self.stream_id}: No buffer available")
raise MuxedStreamEOF("Stream buffer closed")
# If we have data in buffer, process it
if len(buffer) > 0:
chunk = bytes(buffer)
buffer.clear()
data += chunk
# Send window update for the chunk we just read
async with self.window_lock:
self.recv_window += len(chunk)
logging.debug(f"Stream {self.stream_id}: Update {len(chunk)}")
await self.send_window_update(len(chunk), skip_lock=True)
# If stream is closed (FIN received) and buffer is empty, break
if self.recv_closed and len(buffer) == 0:
logging.debug(f"Stream {self.stream_id}: EOF reached")
raise MuxedStreamEOF("Stream is closed for receiving")
# Return all buffered data
data = bytes(buffer)
buffer.clear()
logging.debug(f"Stream {self.stream_id}: Returning {len(data)} bytes")
logging.debug(f"Stream {self.stream_id}: Closed with empty buffer")
break
# If stream was reset, raise reset error
if self.reset_received:
logging.debug(f"Stream {self.stream_id}: Stream was reset")
raise MuxedStreamReset("Stream was reset")
# Wait for more data or stream closure
logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
await self.conn.stream_events[self.stream_id].wait()
self.conn.stream_events[self.stream_id] = trio.Event()
# After loop exit, first check if we have data to return
if data:
logging.debug(
f"Stream {self.stream_id}: Returning {len(data)} bytes after loop"
)
return data
# For specific size read (n > 0), return available data immediately
return await self.conn.read_stream(self.stream_id, n)
# No data accumulated, now check why we exited the loop
if self.conn.event_shutting_down.is_set():
logging.debug(f"Stream {self.stream_id}: Connection shutting down")
raise MuxedStreamEOF("Connection shut down")
# Return empty data
return b""
else:
data = await self.conn.read_stream(self.stream_id, n)
async with self.window_lock:
self.recv_window += len(data)
logging.debug(
f"Stream {self.stream_id}: Sending window update after read, "
f"increment={len(data)}"
)
await self.send_window_update(len(data), skip_lock=True)
return data
async def close(self) -> None:
if not self.send_closed: