Skip to content

Commit

Permalink
Explicitly track whether FIN has been acknowledged (fixes: #438)
Browse files Browse the repository at this point in the history
A stream's sender part is only finished if all the data and the FIN have
been acknowledged. Tracking the acknowledged data ranges is not
sufficient to know whether the FIN has been acknowledged too, as the FIN
may have been sent on its own. To fix this we need to explicitly keep
track of whether the frame containing the FIN was acknowledged.

We also:

- Remove a bogus assignment of `QuicStreamSender.send_buffer_empty`,
  this should have been `QuicStreamSender.buffer_is_empty`.
- Only set `buffer_is_empty` to `False` when data or a FIN have been
  lost. This saves useless calls to `QuicStreamSender.get_frame()`.

Co-authored-by: Maximilian Hils <[email protected]>
  • Loading branch information
jlaine and mhils committed Jan 7, 2024
1 parent 5772246 commit 84d4d48
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,7 +3019,7 @@ def _write_crypto_frame(
QuicFrameType.CRYPTO,
capacity=frame_overhead,
handler=stream.sender.on_data_delivery,
handler_args=(frame.offset, frame.offset + len(frame.data)),
handler_args=(frame.offset, frame.offset + len(frame.data), False),
)
buf.push_uint_var(frame.offset)
buf.push_uint16(len(frame.data) | 0x4000)
Expand Down Expand Up @@ -3246,7 +3246,7 @@ def _write_stream_frame(
frame_type,
capacity=frame_overhead,
handler=stream.sender.on_data_delivery,
handler_args=(frame.offset, frame.offset + len(frame.data)),
handler_args=(frame.offset, frame.offset + len(frame.data), frame.fin),
)
buf.push_uint_var(stream.stream_id)
if frame.offset:
Expand Down
25 changes: 19 additions & 6 deletions src/aioquic/quic/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def __init__(self, stream_id: Optional[int], writable: bool) -> None:
self.reset_pending = False

self._acked = RangeSet()
self._acked_fin = False
self._buffer = bytearray()
self._buffer_fin: Optional[int] = None
self._buffer_start = 0 # the offset for the start of the buffer
Expand Down Expand Up @@ -249,14 +250,18 @@ def get_reset_frame(self) -> QuicResetStreamFrame:
)

def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int
self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
) -> None:
"""
Callback when sent data is ACK'd.
"""
self.buffer_is_empty = False
# If the frame had the FIN bit set, its end MUST match otherwise
# we have a programming error.
assert not fin or stop == self._buffer_fin

if delivery == QuicDeliveryState.ACKED:
if stop > start:
# Some data has been ACK'd, discard it.
self._acked.add(start, stop)
first_range = self._acked[0]
if first_range.start == self._buffer_start:
Expand All @@ -265,14 +270,22 @@ def on_data_delivery(
self._buffer_start += size
del self._buffer[:size]

if self._buffer_start == self._buffer_fin:
# all date up to the FIN has been ACK'd, we're done sending
if fin:
# The FIN has been ACK'd.
self._acked_fin = True

if self._buffer_start == self._buffer_fin and self._acked_fin:
# All data and the FIN have been ACK'd, we're done sending.
self.is_finished = True
else:
if stop > start:
# Some data has been lost, reschedule it.
self.buffer_is_empty = False
self._pending.add(start, stop)
if stop == self._buffer_fin:
self.send_buffer_empty = False

if fin:
# The FIN has been lost, reschedule it.
self.buffer_is_empty = False
self._pending_eof = True

def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
Expand Down
52 changes: 42 additions & 10 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,11 @@ def test_sender_data(self):
self.assertEqual(stream.sender.next_offset, 16)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, False)
self.assertFalse(stream.sender.is_finished)

def test_sender_data_and_fin(self):
Expand Down Expand Up @@ -409,11 +409,11 @@ def test_sender_data_and_fin(self):
self.assertEqual(stream.sender.next_offset, 16)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_data_and_fin_ack_out_of_order(self):
Expand Down Expand Up @@ -448,11 +448,11 @@ def test_sender_data_and_fin_ack_out_of_order(self):
self.assertEqual(stream.sender.next_offset, 16)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertFalse(stream.sender.is_finished)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertTrue(stream.sender.is_finished)

def test_sender_data_lost(self):
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_sender_data_lost(self):
self.assertEqual(stream.sender.next_offset, 16)

# a chunk gets lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8, False)
self.assertEqual(list(stream.sender._pending), [range(0, 8)])
self.assertEqual(stream.sender.next_offset, 0)

Expand Down Expand Up @@ -535,7 +535,7 @@ def test_sender_data_lost_fin(self):
self.assertEqual(stream.sender.next_offset, 16)

# a chunk gets lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16, True)
self.assertEqual(list(stream.sender._pending), [range(8, 16)])
self.assertEqual(stream.sender.next_offset, 8)

Expand All @@ -547,8 +547,12 @@ def test_sender_data_lost_fin(self):
self.assertEqual(list(stream.sender._pending), [])
self.assertEqual(stream.sender.next_offset, 16)

# both chunks gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 16)
# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_blocked(self):
Expand Down Expand Up @@ -635,6 +639,10 @@ def test_sender_fin_only(self):
self.assertIsNone(frame)
self.assertTrue(stream.sender.buffer_is_empty)

# EOF is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 0, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_fin_only_despite_blocked(self):
stream = QuicStream()

Expand All @@ -657,6 +665,30 @@ def test_sender_fin_only_despite_blocked(self):
self.assertIsNone(frame)
self.assertTrue(stream.sender.buffer_is_empty)

def test_sender_fin_then_ack(self):
stream = QuicStream()

# send some data
stream.sender.write(b"data")
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"data")

# data is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, False)
self.assertFalse(stream.sender.is_finished)

# write EOF
stream.sender.write(b"", end_stream=True)
self.assertFalse(stream.sender.buffer_is_empty)
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 4)

# EOF is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 4, 4, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_reset(self):
stream = QuicStream()

Expand Down

0 comments on commit 84d4d48

Please sign in to comment.