Skip to content

Commit

Permalink
Ensure no data is sent after a stream reset
Browse files Browse the repository at this point in the history
Once a RESET has been requested on a stream, the stream's state is
exclusively determined by the RESET being acknowledged. Ensure that we
never send out any more data after a RESET being acknowledged, even if
data is lost.
  • Loading branch information
jlaine committed Jan 7, 2024
1 parent 0c38321 commit 984db6d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
18 changes: 16 additions & 2 deletions src/aioquic/quic/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def get_frame(
"""
Get a frame of data to send.
"""
assert self._reset_error_code is None, "cannot call get_frame() after reset()"

# get the first pending data range
try:
r = self._pending[0]
Expand Down Expand Up @@ -257,7 +259,15 @@ def on_data_delivery(
"""
# If the frame had the FIN bit set, its end MUST match otherwise
# we have a programming error.
assert not fin or stop == self._buffer_fin
assert (
not fin or stop == self._buffer_fin
), "on_data_delivered() was called with inconsistent fin / stop"

# If a reset has been requested, stop processing data delivery.
# The transition to the finished state only depends on the reset
# being acknowledged.
if self._reset_error_code is not None:
return

if delivery == QuicDeliveryState.ACKED:
if stop > start:
Expand Down Expand Up @@ -293,9 +303,10 @@ def on_reset_delivery(self, delivery: QuicDeliveryState) -> None:
Callback when a reset is ACK'd.
"""
if delivery == QuicDeliveryState.ACKED:
# the reset has been ACK'd, we're done sending
# The reset has been ACK'd, we're done sending.
self.is_finished = True
else:
# The reset has been lost, reschedule it.
self.reset_pending = True

def reset(self, error_code: int) -> None:
Expand All @@ -306,6 +317,9 @@ def reset(self, error_code: int) -> None:
self._reset_error_code = error_code
self.reset_pending = True

# Prevent any more data from being sent or re-sent.
self.buffer_is_empty = True

def write(self, data: bytes, end_stream: bool = False) -> None:
"""
Write some data bytes to the QUIC stream.
Expand Down
50 changes: 49 additions & 1 deletion tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,17 +692,30 @@ def test_sender_fin_then_ack(self):
def test_sender_reset(self):
stream = QuicStream()

# send some data and EOF
stream.sender.write(b"data", end_stream=True)
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"data")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 0)

# reset is requested
stream.sender.reset(QuicErrorCode.NO_ERROR)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertTrue(stream.sender.reset_pending)

# reset is sent
reset = stream.sender.get_reset_frame()
self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR)
self.assertEqual(reset.final_size, 0)
self.assertEqual(reset.final_size, 4)
self.assertFalse(stream.sender.reset_pending)
self.assertFalse(stream.sender.is_finished)

# data and EOF are acknowledged
stream.sender.on_data_delivery(QuicDeliveryState.ACKED, 0, 4, True)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertFalse(stream.sender.is_finished)

# reset is acklowledged
stream.sender.on_reset_delivery(QuicDeliveryState.ACKED)
self.assertFalse(stream.sender.reset_pending)
Expand All @@ -713,6 +726,7 @@ def test_sender_reset_lost(self):

# reset is requested
stream.sender.reset(QuicErrorCode.NO_ERROR)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertTrue(stream.sender.reset_pending)

# reset is sent
Expand All @@ -735,4 +749,38 @@ def test_sender_reset_lost(self):
# reset is acklowledged
stream.sender.on_reset_delivery(QuicDeliveryState.ACKED)
self.assertFalse(stream.sender.reset_pending)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertTrue(stream.sender.is_finished)

def test_sender_reset_with_data_lost(self):
stream = QuicStream()

# send some data and EOF
stream.sender.write(b"data", end_stream=True)
frame = stream.sender.get_frame(8)
self.assertEqual(frame.data, b"data")
self.assertTrue(frame.fin)
self.assertEqual(frame.offset, 0)

# reset is requested
stream.sender.reset(QuicErrorCode.NO_ERROR)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertTrue(stream.sender.reset_pending)

# reset is sent
reset = stream.sender.get_reset_frame()
self.assertEqual(reset.error_code, QuicErrorCode.NO_ERROR)
self.assertEqual(reset.final_size, 4)
self.assertFalse(stream.sender.reset_pending)
self.assertFalse(stream.sender.is_finished)

# data and EOF are lost
stream.sender.on_data_delivery(QuicDeliveryState.LOST, 0, 4, True)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertFalse(stream.sender.is_finished)

# reset is acklowledged
stream.sender.on_reset_delivery(QuicDeliveryState.ACKED)
self.assertFalse(stream.sender.reset_pending)
self.assertTrue(stream.sender.buffer_is_empty)
self.assertTrue(stream.sender.is_finished)

0 comments on commit 984db6d

Please sign in to comment.