Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly track whether FIN has been acknowledged #441

Merged
merged 1 commit into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,7 +3019,7 @@ def _write_crypto_frame(
QuicFrameType.CRYPTO,
capacity=frame_overhead,
handler=stream.sender.on_data_delivery,
handler_args=(frame.offset, frame.offset + len(frame.data)),
handler_args=(frame.offset, frame.offset + len(frame.data), False),
)
buf.push_uint_var(frame.offset)
buf.push_uint16(len(frame.data) | 0x4000)
Expand Down Expand Up @@ -3246,7 +3246,7 @@ def _write_stream_frame(
frame_type,
capacity=frame_overhead,
handler=stream.sender.on_data_delivery,
handler_args=(frame.offset, frame.offset + len(frame.data)),
handler_args=(frame.offset, frame.offset + len(frame.data), frame.fin),
)
buf.push_uint_var(stream.stream_id)
if frame.offset:
Expand Down
25 changes: 19 additions & 6 deletions src/aioquic/quic/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def __init__(self, stream_id: Optional[int], writable: bool) -> None:
self.reset_pending = False

self._acked = RangeSet()
self._acked_fin = False
self._buffer = bytearray()
self._buffer_fin: Optional[int] = None
self._buffer_start = 0 # the offset for the start of the buffer
Expand Down Expand Up @@ -249,14 +250,18 @@ def get_reset_frame(self) -> QuicResetStreamFrame:
)

def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int
self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
) -> None:
"""
Callback when sent data is ACK'd.
"""
self.buffer_is_empty = False
# If the frame had the FIN bit set, its end MUST match otherwise
# we have a programming error.
assert not fin or stop == self._buffer_fin

if delivery == QuicDeliveryState.ACKED:
if stop > start:
# Some data has been ACK'd, discard it.
self._acked.add(start, stop)
first_range = self._acked[0]
if first_range.start == self._buffer_start:
Expand All @@ -265,14 +270,22 @@ def on_data_delivery(
self._buffer_start += size
del self._buffer[:size]

if self._buffer_start == self._buffer_fin:
# all date up to the FIN has been ACK'd, we're done sending
if fin:
# The FIN has been ACK'd.
self._acked_fin = True

if self._buffer_start == self._buffer_fin and self._acked_fin:
# All data and the FIN have been ACK'd, we're done sending.
self.is_finished = True
else:
if stop > start:
# Some data has been lost, reschedule it.
self.buffer_is_empty = False
self._pending.add(start, stop)
if stop == self._buffer_fin:
self.send_buffer_empty = False

if fin:
# The FIN has been lost, reschedule it.
self.buffer_is_empty = False
self._pending_eof = True

def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
Expand Down
52 changes: 42 additions & 10 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,11 @@ def test_sender_data(self):
self.assertEqual(stream.sender.next_offset, 16)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, False)
self.assertFalse(stream.sender.is_finished)

def test_sender_data_and_fin(self):
Expand Down Expand Up @@ -409,11 +409,11 @@ def test_sender_data_and_fin(self):
self.assertEqual(stream.sender.next_offset, 16)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_data_and_fin_ack_out_of_order(self):
Expand Down Expand Up @@ -448,11 +448,11 @@ def test_sender_data_and_fin_ack_out_of_order(self):
self.assertEqual(stream.sender.next_offset, 16)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertFalse(stream.sender.is_finished)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertTrue(stream.sender.is_finished)

def test_sender_data_lost(self):
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_sender_data_lost(self):
self.assertEqual(stream.sender.next_offset, 16)

# a chunk gets lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8, False)
self.assertEqual(list(stream.sender._pending), [range(0, 8)])
self.assertEqual(stream.sender.next_offset, 0)

Expand Down Expand Up @@ -535,7 +535,7 @@ def test_sender_data_lost_fin(self):
self.assertEqual(stream.sender.next_offset, 16)

# a chunk gets lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16, True)
self.assertEqual(list(stream.sender._pending), [range(8, 16)])
self.assertEqual(stream.sender.next_offset, 8)

Expand All @@ -547,8 +547,12 @@ def test_sender_data_lost_fin(self):
self.assertEqual(list(stream.sender._pending), [])
self.assertEqual(stream.sender.next_offset, 16)

# both chunks gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 16)
# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_blocked(self):
Expand Down Expand Up @@ -635,6 +639,10 @@ def test_sender_fin_only(self):
self.assertIsNone(frame)
self.assertTrue(stream.sender.buffer_is_empty)

# EOF is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 0, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_fin_only_despite_blocked(self):
stream = QuicStream()

Expand All @@ -657,6 +665,30 @@ def test_sender_fin_only_despite_blocked(self):
self.assertIsNone(frame)
self.assertTrue(stream.sender.buffer_is_empty)

def test_sender_fin_then_ack(self):
stream = QuicStream()

# send some data
stream.sender.write(b"data")
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"data")

# data is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, False)
self.assertFalse(stream.sender.is_finished)

# write EOF
stream.sender.write(b"", end_stream=True)
self.assertFalse(stream.sender.buffer_is_empty)
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 4)

# EOF is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 4, 4, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_reset(self):
stream = QuicStream()

Expand Down
Loading