diff --git a/src/aioquic/quic/stream.py b/src/aioquic/quic/stream.py index 4076a19d4..95f578b58 100644 --- a/src/aioquic/quic/stream.py +++ b/src/aioquic/quic/stream.py @@ -201,6 +201,8 @@ def get_frame( """ Get a frame of data to send. """ + assert self._reset_error_code is None, "cannot call get_frame() after reset()" + # get the first pending data range try: r = self._pending[0] @@ -257,7 +259,15 @@ def on_data_delivery( """ # If the frame had the FIN bit set, its end MUST match otherwise # we have a programming error. - assert not fin or stop == self._buffer_fin + assert ( + not fin or stop == self._buffer_fin + ), "on_data_delivered() was called with inconsistent fin / stop" + + # If a reset has been requested, stop processing data delivery. + # The transition to the finished state only depends on the reset + # being acknowledged. + if self._reset_error_code is not None: + return if delivery == QuicDeliveryState.ACKED: if stop > start: @@ -293,9 +303,10 @@ def on_reset_delivery(self, delivery: QuicDeliveryState) -> None: Callback when a reset is ACK'd. """ if delivery == QuicDeliveryState.ACKED: - # the reset has been ACK'd, we're done sending + # The reset has been ACK'd, we're done sending. self.is_finished = True else: + # The reset has been lost, reschedule it. self.reset_pending = True def reset(self, error_code: int) -> None: @@ -306,6 +317,9 @@ def reset(self, error_code: int) -> None: self._reset_error_code = error_code self.reset_pending = True + # Prevent any more data from being sent or re-sent. + self.buffer_is_empty = True + def write(self, data: bytes, end_stream: bool = False) -> None: """ Write some data bytes to the QUIC stream. diff --git a/tests/test_stream.py b/tests/test_stream.py index 03aa8b964..b6d593063 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -692,17 +692,30 @@ def test_sender_fin_then_ack(self): def test_sender_reset(self): stream = QuicStream() + # send some data and EOF + stream.sender.write(b"data", end_stream=True) + frame = stream.sender.get_frame(8) + self.assertEqual(frame.data, b"data") + self.assertTrue(frame.fin) + self.assertEqual(frame.offset, 0) + # reset is requested stream.sender.reset(QuicErrorCode.NO_ERROR) + self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.reset_pending) # reset is sent reset = stream.sender.get_reset_frame() self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR) - self.assertEqual(reset.final_size, 0) + self.assertEqual(reset.final_size, 4) self.assertFalse(stream.sender.reset_pending) self.assertFalse(stream.sender.is_finished) + # data and EOF are acknowledged + stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, True) + self.assertTrue(stream.sender.buffer_is_empty) + self.assertFalse(stream.sender.is_finished) + # reset is acklowledged stream.sender.on_reset_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.sender.reset_pending) @@ -713,6 +726,7 @@ def test_sender_reset_lost(self): # reset is requested stream.sender.reset(QuicErrorCode.NO_ERROR) + self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.reset_pending) # reset is sent @@ -735,4 +749,38 @@ def test_sender_reset_lost(self): # reset is acklowledged stream.sender.on_reset_delivery(QuicDeliveryState.ACKED) self.assertFalse(stream.sender.reset_pending) + self.assertTrue(stream.sender.buffer_is_empty) + self.assertTrue(stream.sender.is_finished) + + def test_sender_reset_with_data_lost(self): + stream = QuicStream() + + # send some data and EOF + stream.sender.write(b"data", end_stream=True) + frame = stream.sender.get_frame(8) + self.assertEqual(frame.data, b"data") + self.assertTrue(frame.fin) + self.assertEqual(frame.offset, 0) + + # reset is requested + stream.sender.reset(QuicErrorCode.NO_ERROR) + self.assertTrue(stream.sender.buffer_is_empty) + self.assertTrue(stream.sender.reset_pending) + + # reset is sent + reset = stream.sender.get_reset_frame() + self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR) + self.assertEqual(reset.final_size, 4) + self.assertFalse(stream.sender.reset_pending) + self.assertFalse(stream.sender.is_finished) + + # data and EOF are lost + stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 4, True) + self.assertTrue(stream.sender.buffer_is_empty) + self.assertFalse(stream.sender.is_finished) + + # reset is acklowledged + stream.sender.on_reset_delivery(QuicDeliveryState.ACKED) + self.assertFalse(stream.sender.reset_pending) + self.assertTrue(stream.sender.buffer_is_empty) self.assertTrue(stream.sender.is_finished)