From ec69a116bfc88de3e426f995484901f208cdfdf8 Mon Sep 17 00:00:00 2001 From: Carl Lei Date: Fri, 15 Sep 2023 12:53:57 +0800 Subject: [PATCH] [connection] Use constant-time stream ID allocation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of running over all the created streams, keep track of the next ID for unidirectional and bidirectional streams. Co-authored-by: Jeremy Lainé --- src/aioquic/quic/connection.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index b99c4ab85..0d2d17fed 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -303,6 +303,8 @@ def __init__( self._local_max_streams_uni = Limit( frame_type=QuicFrameType.MAX_STREAMS_UNI, name="max_streams_uni", value=128 ) + self._local_next_stream_id_bidi = 0 if self._is_client else 1 + self._local_next_stream_id_uni = 2 if self._is_client else 3 self._loss_at: Optional[float] = None self._network_paths: List[QuicNetworkPath] = [] self._pacing_at: Optional[float] = None @@ -623,10 +625,10 @@ def get_next_available_stream_id(self, is_unidirectional=False) -> int: """ Return the stream ID for the next stream created by this endpoint. """ - stream_id = (int(is_unidirectional) << 1) | int(not self._is_client) - while stream_id in self._streams or stream_id in self._streams_finished: - stream_id += 4 - return stream_id + if is_unidirectional: + return self._local_next_stream_id_uni + else: + return self._local_next_stream_id_bidi def get_timer(self) -> Optional[float]: """ @@ -1291,12 +1293,17 @@ def _get_or_create_stream_for_send(self, stream_id: int) -> QuicStream: streams_blocked = self._streams_blocked_bidi # create stream + is_unidirectional = stream_is_unidirectional(stream_id) stream = self._streams[stream_id] = QuicStream( stream_id=stream_id, max_stream_data_local=max_stream_data_local, max_stream_data_remote=max_stream_data_remote, - readable=not stream_is_unidirectional(stream_id), + readable=not is_unidirectional, ) + if is_unidirectional: + self._local_next_stream_id_uni = stream_id + 4 + else: + self._local_next_stream_id_bidi = stream_id + 4 # mark stream as blocked if needed if stream_id // 4 >= max_streams: