diff --git a/src/aioquic/asyncio/client.py b/src/aioquic/asyncio/client.py index 63e736554..a8e41702b 100644 --- a/src/aioquic/asyncio/client.py +++ b/src/aioquic/asyncio/client.py @@ -4,7 +4,7 @@ from typing import AsyncGenerator, Callable, Optional, cast from ..quic.configuration import QuicConfiguration -from ..quic.connection import QuicConnection +from ..quic.connection import QuicConnection, QuicTokenHandler from ..tls import SessionTicketHandler from .protocol import QuicConnectionProtocol, QuicStreamHandler @@ -20,6 +20,7 @@ async def connect( create_protocol: Optional[Callable] = QuicConnectionProtocol, session_ticket_handler: Optional[SessionTicketHandler] = None, stream_handler: Optional[QuicStreamHandler] = None, + token_handler: Optional[QuicTokenHandler] = None, wait_connected: bool = True, local_port: int = 0, ) -> AsyncGenerator[QuicConnectionProtocol, None]: @@ -60,7 +61,9 @@ async def connect( if configuration.server_name is None: configuration.server_name = host connection = QuicConnection( - configuration=configuration, session_ticket_handler=session_ticket_handler + configuration=configuration, + session_ticket_handler=session_ticket_handler, + token_handler=token_handler, ) # explicitly enable IPv4/IPv6 dual stack diff --git a/src/aioquic/quic/configuration.py b/src/aioquic/quic/configuration.py index faba75d27..d1f184c66 100644 --- a/src/aioquic/quic/configuration.py +++ b/src/aioquic/quic/configuration.py @@ -86,6 +86,13 @@ class QuicConfiguration: The TLS session ticket which should be used for session resumption. """ + token: bytes = b"" + """ + The address validation token that can be used to validate future connections. + + .. note:: This is only used by clients. + """ + cadata: Optional[bytes] = None cafile: Optional[str] = None capath: Optional[str] = None diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index 238db458a..8dfcca8c6 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -5,7 +5,18 @@ from dataclasses import dataclass from enum import Enum from functools import partial -from typing import Any, Deque, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple +from typing import ( + Any, + Callable, + Deque, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Set, + Tuple, +) from .. import tls from ..buffer import ( @@ -207,6 +218,8 @@ class QuicReceiveContext: time: float +QuicTokenHandler = Callable[[bytes], None] + END_STATES = frozenset( [ QuicConnectionState.CLOSING, @@ -239,6 +252,7 @@ def __init__( retry_source_connection_id: Optional[bytes] = None, session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None, session_ticket_handler: Optional[tls.SessionTicketHandler] = None, + token_handler: Optional[QuicTokenHandler] = None, ) -> None: assert configuration.max_datagram_size >= SMALLEST_MAX_DATAGRAM_SIZE, ( "The smallest allowed maximum datagram size is " @@ -252,6 +266,10 @@ def __init__( retry_source_connection_id is None ), "Cannot set retry_source_connection_id for a client" else: + assert token_handler is None, "Cannot set `token_handler` for a server" + assert ( + configuration.token == b"" + ), "Cannot set `configuration.token` for a server" assert ( configuration.certificate is not None ), "SSL certificate is required for a server" @@ -319,7 +337,7 @@ def __init__( ) self._peer_cid_available: List[QuicConnectionId] = [] self._peer_cid_sequence_numbers: Set[int] = set([0]) - self._peer_token = b"" + self._peer_token = configuration.token self._quic_logger: Optional[QuicLoggerTrace] = None self._remote_ack_delay_exponent = 3 self._remote_active_connection_id_limit = 2 @@ -385,6 +403,7 @@ def __init__( # callbacks self._session_ticket_fetcher = session_ticket_fetcher self._session_ticket_handler = session_ticket_handler + self._token_handler = token_handler # frame handlers self.__frame_handlers = { @@ -1912,6 +1931,9 @@ def _handle_new_token_frame( reason_phrase="Clients must not send NEW_TOKEN frames", ) + if self._token_handler is not None: + self._token_handler(token) + def _handle_padding_frame( self, context: QuicReceiveContext, frame_type: int, buf: Buffer ) -> None: diff --git a/tests/test_connection.py b/tests/test_connection.py index b037b809d..7ec9b6645 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1630,7 +1630,16 @@ def test_handle_new_connection_id_with_retire_prior_to_invalid(self): ) def test_handle_new_token_frame(self): - with client_and_server() as (client, server): + new_token = None + + def token_handler(token): + nonlocal new_token + new_token = token + + with client_and_server(client_kwargs={"token_handler": token_handler}) as ( + client, + server, + ): # client receives NEW_TOKEN client._handle_new_token_frame( client_receive_context(client), @@ -1638,6 +1647,8 @@ def test_handle_new_token_frame(self): Buffer(data=binascii.unhexlify("080102030405060708")), ) + self.assertEqual(new_token, binascii.unhexlify("0102030405060708")) + def test_handle_new_token_frame_from_client(self): with client_and_server() as (client, server): # server receives NEW_TOKEN