resolved recv_window updates, added support for read_EOF

kaneki003
2025-06-21 13:39:03 +05:30
parent 0a7e13f0ed
commit 209deffc8a
2 changed files with 52 additions and 68 deletions


@@ -141,9 +141,7 @@ class YamuxStream(IMuxedStream):
             await self.conn.secured_conn.write(header + chunk)
             sent += to_send
 
-    async def send_window_update(
-        self, increment: int | None, skip_lock: bool = False
-    ) -> None:
+    async def send_window_update(self, increment: int, skip_lock: bool = False) -> None:
         """
         Send a window update to peer.
@@ -154,12 +152,7 @@ class YamuxStream(IMuxedStream):
         This should only be used when calling from a context
         that already holds the lock.
         """
-        increment_value = 0
-        if increment is None:
-            increment_value = DEFAULT_WINDOW_SIZE - self.recv_window
-        else:
-            increment_value = increment
-        if increment_value <= 0:
+        if increment <= 0:
             # If increment is zero or negative, skip sending update
             logging.debug(
                 f"Stream {self.stream_id}: Skipping window update"
@@ -171,14 +164,13 @@ class YamuxStream(IMuxedStream):
             )
 
         async def _do_window_update() -> None:
-            self.recv_window += increment_value
             header = struct.pack(
                 YAMUX_HEADER_FORMAT,
                 0,
                 TYPE_WINDOW_UPDATE,
                 0,
                 self.stream_id,
-                increment_value,
+                increment,
             )
             await self.conn.secured_conn.write(header)
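
In plain terms, the new contract is that `send_window_update` no longer recomputes the increment from `DEFAULT_WINDOW_SIZE`: the caller credits `recv_window` itself and passes the exact number of bytes it consumed. A minimal, self-contained sketch of that accounting is shown below; it is not the actual class, and the 256 KiB initial window is an assumption based on the usual Yamux default.

```python
DEFAULT_WINDOW_SIZE = 256 * 1024  # assumed initial receive window (Yamux default)


class WindowAccounting:
    """Simplified model of the receive-window bookkeeping after this change."""

    def __init__(self) -> None:
        self.recv_window = DEFAULT_WINDOW_SIZE

    def on_data_received(self, n: int) -> None:
        # The peer consumes our advertised window as it sends data frames.
        self.recv_window -= n

    def on_data_consumed(self, n: int) -> int:
        # After the application reads n bytes, credit the peer with exactly n.
        # Mirrors send_window_update: zero/negative increments are skipped.
        if n <= 0:
            return 0
        self.recv_window += n
        return n  # increment carried in the WINDOW_UPDATE frame


accounting = WindowAccounting()
accounting.on_data_received(4096)
assert accounting.on_data_consumed(4096) == 4096
assert accounting.recv_window == DEFAULT_WINDOW_SIZE
```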
@@ -188,6 +180,22 @@ class YamuxStream(IMuxedStream):
             async with self.window_lock:
                 await _do_window_update()
 
+    async def read_EOF(self) -> bytes:
+        """
+        To read data from stream until it is closed.
+        """
+        data = b""
+        try:
+            while True:
+                recv = await self.read()
+                if recv:
+                    data += recv
+        except MuxedStreamEOF:
+            logging.debug(
+                f"Stream {self.stream_id}:EOF reached,total data read:{len(data)} bytes"
+            )
+        return data
+
     async def read(self, n: int | None = -1) -> bytes:
         # Handle None value for n by converting it to -1
         if n is None:
@@ -202,61 +210,34 @@ class YamuxStream(IMuxedStream):
         # If reading until EOF (n == -1), block until stream is closed
         if n == -1:
-            while not self.recv_closed and not self.conn.event_shutting_down.is_set():
-                # Check if there's data in the buffer
-                buffer = self.conn.stream_buffers.get(self.stream_id)
-                if buffer and len(buffer) > 0:
-                    # Wait for closure even if data is available
-                    logging.debug(
-                        f"Stream {self.stream_id}:Waiting for FIN before returning data"
-                    )
-                    await self.conn.stream_events[self.stream_id].wait()
-                    self.conn.stream_events[self.stream_id] = trio.Event()
-                else:
-                    # No data, wait for data or closure
-                    logging.debug(f"Stream {self.stream_id}: Waiting for data or FIN")
-                    await self.conn.stream_events[self.stream_id].wait()
-                    self.conn.stream_events[self.stream_id] = trio.Event()
-            # After loop, check if stream is closed or shutting down
-            async with self.conn.streams_lock:
-                if self.conn.event_shutting_down.is_set():
-                    logging.debug(f"Stream {self.stream_id}: Connection shutting down")
-                    raise MuxedStreamEOF("Connection shut down")
-                if self.closed:
-                    if self.reset_received:
-                        logging.debug(f"Stream {self.stream_id}: Stream was reset")
-                        raise MuxedStreamReset("Stream was reset")
-                    else:
-                        logging.debug(
-                            f"Stream {self.stream_id}: Stream closed cleanly (EOF)"
-                        )
-                        raise MuxedStreamEOF("Stream closed cleanly (EOF)")
-                buffer = self.conn.stream_buffers.get(self.stream_id)
-                if buffer is None:
-                    logging.debug(
-                        f"Stream {self.stream_id}: Buffer gone, assuming closed"
-                    )
-                    raise MuxedStreamEOF("Stream buffer closed")
-                if self.recv_closed and len(buffer) == 0:
-                    logging.debug(f"Stream {self.stream_id}: EOF reached")
-                    raise MuxedStreamEOF("Stream is closed for receiving")
-                # Return all buffered data
-                data = bytes(buffer)
-                buffer.clear()
-                return data
-
-        data = await self.conn.read_stream(self.stream_id, n)
-        async with self.window_lock:
-            self.recv_window -= len(data)
-            # Automatically send a window update if recv_window is low
-            if self.recv_window <= DEFAULT_WINDOW_SIZE // 2:
-                logging.debug(
-                    f"Stream {self.stream_id}: "
-                    f"Low recv_window ({self.recv_window}), sending update"
-                )
-                await self.send_window_update(None, skip_lock=True)
-        return data
+            # Check if there's data in the buffer
+            buffer = self.conn.stream_buffers.get(self.stream_id)
+            size = len(buffer) if buffer else 0
+            if size > 0:
+                # If any data is available,return it immediately
+                assert buffer is not None
+                data = bytes(buffer)
+                buffer.clear()
+                async with self.window_lock:
+                    self.recv_window += len(data)
+                    await self.send_window_update(len(data), skip_lock=True)
+                return data
+            # Otherwise,wait for data or FIN
+            if self.recv_closed:
+                raise MuxedStreamEOF("Stream is closed for receiving")
+            await self.conn.stream_events[self.stream_id].wait()
+            self.conn.stream_events[self.stream_id] = trio.Event()
+            return b""
+        else:
+            data = await self.conn.read_stream(self.stream_id, n)
+            async with self.window_lock:
+                self.recv_window += len(data)
+                logging.debug(
+                    f"Stream {self.stream_id}: Sending window update after read, "
+                    f"increment={len(data)}"
+                )
+                await self.send_window_update(len(data), skip_lock=True)
+            return data
 
     async def close(self) -> None:
         if not self.send_closed:
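
A minimal caller-side sketch of the new `read_EOF` helper follows. The handler name and the `stream` argument are hypothetical; the only assumption taken from the diff is that `read_EOF()` keeps reading until the remote side half-closes the stream and then returns everything received.

```python
async def handle_stream(stream) -> None:
    # Hypothetical protocol handler: read_EOF() loops over read() and returns
    # the accumulated payload once MuxedStreamEOF is raised internally.
    payload = await stream.read_EOF()
    print(f"received {len(payload)} bytes")
```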


@@ -1,7 +1,7 @@
 import pytest
 import trio
 
-from libp2p.abc import ISecureConn
+from libp2p.abc import IMuxedStream, ISecureConn
 from libp2p.crypto.keys import PrivateKey, PublicKey
 from libp2p.peer.id import ID
 from libp2p.stream_muxer.mplex.constants import (
@@ -59,13 +59,15 @@ class DummyMuxedConn(Mplex):
         self.event_started = trio.Event()
         self.stream_backlog_limit = 256
         self.stream_backlog_semaphore = trio.Semaphore(256)
-        channels = trio.open_memory_channel[MplexStream](0)
+        # Use IMuxedStream for type consistency with Mplex
+        channels = trio.open_memory_channel[IMuxedStream](0)
         self.new_stream_send_channel, self.new_stream_receive_channel = channels
 
     async def send_message(
-        self, flag: HeaderTags, data: bytes, stream_id: StreamID
-    ) -> None:
+        self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
+    ) -> int:
         await trio.sleep(0.01)
+        return 0
 
 
 @pytest.mark.trio
@@ -75,10 +77,11 @@ async def test_concurrent_writes_are_serialized():
     class LoggingMuxedConn(DummyMuxedConn):
         async def send_message(
-            self, flag: HeaderTags, data: bytes, stream_id: StreamID
-        ) -> None:
+            self, flag: HeaderTags, data: bytes | None, stream_id: StreamID
+        ) -> int:
             send_log.append(data)
             await trio.sleep(0.01)
+            return 0
 
     memory_send, memory_recv = trio.open_memory_channel(8)
     stream = MplexStream(