Skip to content

Commit

Permalink
Do not swallow FIN frames without data (fixes: #438)
Browse files Browse the repository at this point in the history
Co-authored-by: Maximilian Hils <[email protected]>
  • Loading branch information
jlaine and mhils committed Jan 7, 2024
1 parent 5772246 commit 98e7bd0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,7 +3019,7 @@ def _write_crypto_frame(
QuicFrameType.CRYPTO,
capacity=frame_overhead,
handler=stream.sender.on_data_delivery,
handler_args=(frame.offset, frame.offset + len(frame.data)),
handler_args=(frame.offset, frame.offset + len(frame.data), False),
)
buf.push_uint_var(frame.offset)
buf.push_uint16(len(frame.data) | 0x4000)
Expand Down Expand Up @@ -3246,7 +3246,7 @@ def _write_stream_frame(
frame_type,
capacity=frame_overhead,
handler=stream.sender.on_data_delivery,
handler_args=(frame.offset, frame.offset + len(frame.data)),
handler_args=(frame.offset, frame.offset + len(frame.data), frame.fin),
)
buf.push_uint_var(stream.stream_id)
if frame.offset:
Expand Down
23 changes: 19 additions & 4 deletions src/aioquic/quic/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def __init__(self, stream_id: Optional[int], writable: bool) -> None:
self.reset_pending = False

self._acked = RangeSet()
self._acked_fin = False
self._buffer = bytearray()
self._buffer_fin: Optional[int] = None
self._buffer_start = 0 # the offset for the start of the buffer
Expand Down Expand Up @@ -249,14 +250,20 @@ def get_reset_frame(self) -> QuicResetStreamFrame:
)

def on_data_delivery(
self, delivery: QuicDeliveryState, start: int, stop: int
self, delivery: QuicDeliveryState, start: int, stop: int, fin: bool
) -> None:
"""
Callback when sent data is ACK'd.
"""
self.buffer_is_empty = False

# If the frame had the FIN bit set, its end MUST match otherwise
# we have a programming error.
assert not fin or stop == self._buffer_fin

if delivery == QuicDeliveryState.ACKED:
if stop > start:
# Some data has been ACK'd, discard it.
self._acked.add(start, stop)
first_range = self._acked[0]
if first_range.start == self._buffer_start:
Expand All @@ -265,13 +272,21 @@ def on_data_delivery(
self._buffer_start += size
del self._buffer[:size]

if self._buffer_start == self._buffer_fin:
# all date up to the FIN has been ACK'd, we're done sending
if fin:
# The FIN has been ACK'd.
self._acked_fin = True

if self._buffer_start == self._buffer_fin and self._acked_fin:
# All data and the FIN have been ACK'd, we're done sending.
self.is_finished = True
else:
if stop > start:
# Some data has been lost, reschedule it.
self.send_buffer_empty = False
self._pending.add(start, stop)
if stop == self._buffer_fin:

if fin:
# The FIN has been lost, reschedule it.
self.send_buffer_empty = False
self._pending_eof = True

Expand Down
52 changes: 42 additions & 10 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,11 @@ def test_sender_data(self):
self.assertEqual(stream.sender.next_offset, 16)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, False)
self.assertFalse(stream.sender.is_finished)

def test_sender_data_and_fin(self):
Expand Down Expand Up @@ -409,11 +409,11 @@ def test_sender_data_and_fin(self):
self.assertEqual(stream.sender.next_offset, 16)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_data_and_fin_ack_out_of_order(self):
Expand Down Expand Up @@ -448,11 +448,11 @@ def test_sender_data_and_fin_ack_out_of_order(self):
self.assertEqual(stream.sender.next_offset, 16)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertFalse(stream.sender.is_finished)

# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertTrue(stream.sender.is_finished)

def test_sender_data_lost(self):
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_sender_data_lost(self):
self.assertEqual(stream.sender.next_offset, 16)

# a chunk gets lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8)
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 8, False)
self.assertEqual(list(stream.sender._pending), [range(0, 8)])
self.assertEqual(stream.sender.next_offset, 0)

Expand Down Expand Up @@ -535,7 +535,7 @@ def test_sender_data_lost_fin(self):
self.assertEqual(stream.sender.next_offset, 16)

# a chunk gets lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16)
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 8, 16, True)
self.assertEqual(list(stream.sender._pending), [range(8, 16)])
self.assertEqual(stream.sender.next_offset, 8)

Expand All @@ -547,8 +547,12 @@ def test_sender_data_lost_fin(self):
self.assertEqual(list(stream.sender._pending), [])
self.assertEqual(stream.sender.next_offset, 16)

# both chunks gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 16)
# first chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 8, False)
self.assertFalse(stream.sender.is_finished)

# second chunk gets acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 8, 16, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_blocked(self):
Expand Down Expand Up @@ -635,6 +639,10 @@ def test_sender_fin_only(self):
self.assertIsNone(frame)
self.assertTrue(stream.sender.buffer_is_empty)

# EOF is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 0, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_fin_only_despite_blocked(self):
stream = QuicStream()

Expand All @@ -657,6 +665,30 @@ def test_sender_fin_only_despite_blocked(self):
self.assertIsNone(frame)
self.assertTrue(stream.sender.buffer_is_empty)

def test_sender_fin_then_ack(self):
stream = QuicStream()

# send some data
stream.sender.write(b"data")
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"data")

# data is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, False)
self.assertFalse(stream.sender.is_finished)

# write EOF
stream.sender.write(b"", end_stream=True)
self.assertFalse(stream.sender.buffer_is_empty)
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 4)

# EOF is acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 4, 4, True)
self.assertTrue(stream.sender.is_finished)

def test_sender_reset(self):
stream = QuicStream()

Expand Down

0 comments on commit 98e7bd0

Please sign in to comment.