Skip to content

Commit

Permalink
[quic] Implement token for future connections for client
Browse files Browse the repository at this point in the history
- Add `token` to `QuicConfiguration`.
- Add `token_handler` to `aioquic.asyncio.connect`.

See: https://www.rfc-editor.org/rfc/rfc9000.html#name-address-validation-for-futu

Co-authored-by: Jeremy Lainé <[email protected]>
  • Loading branch information
msoxzw and jlaine committed Dec 12, 2023
1 parent 094b01d commit a12f9ec
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
7 changes: 5 additions & 2 deletions src/aioquic/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import AsyncGenerator, Callable, Optional, cast

from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection
from ..quic.connection import QuicConnection, QuicTokenHandler
from ..tls import SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler

Expand All @@ -20,6 +20,7 @@ async def connect(
create_protocol: Optional[Callable] = QuicConnectionProtocol,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
token_handler: Optional[QuicTokenHandler] = None,
wait_connected: bool = True,
local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
Expand Down Expand Up @@ -60,7 +61,9 @@ async def connect(
if configuration.server_name is None:
configuration.server_name = host
connection = QuicConnection(
configuration=configuration, session_ticket_handler=session_ticket_handler
configuration=configuration,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)

# explicitly enable IPv4/IPv6 dual stack
Expand Down
7 changes: 7 additions & 0 deletions src/aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ class QuicConfiguration:
The TLS session ticket which should be used for session resumption.
"""

token: bytes = b""
"""
The address validation token that can be used to validate future connections.
.. note:: This is only used by clients.
"""

cadata: Optional[bytes] = None
cafile: Optional[str] = None
capath: Optional[str] = None
Expand Down
26 changes: 24 additions & 2 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Deque, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple
from typing import (
Any,
Callable,
Deque,
Dict,
FrozenSet,
List,
Optional,
Sequence,
Set,
Tuple,
)

from .. import tls
from ..buffer import (
Expand Down Expand Up @@ -207,6 +218,8 @@ class QuicReceiveContext:
time: float


QuicTokenHandler = Callable[[bytes], None]

END_STATES = frozenset(
[
QuicConnectionState.CLOSING,
Expand Down Expand Up @@ -239,6 +252,7 @@ def __init__(
retry_source_connection_id: Optional[bytes] = None,
session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None,
session_ticket_handler: Optional[tls.SessionTicketHandler] = None,
token_handler: Optional[QuicTokenHandler] = None,
) -> None:
assert configuration.max_datagram_size >= SMALLEST_MAX_DATAGRAM_SIZE, (
"The smallest allowed maximum datagram size is "
Expand All @@ -252,6 +266,10 @@ def __init__(
retry_source_connection_id is None
), "Cannot set retry_source_connection_id for a client"
else:
assert token_handler is None, "Cannot set `token_handler` for a server"
assert (
configuration.token == b""
), "Cannot set `configuration.token` for a server"
assert (
configuration.certificate is not None
), "SSL certificate is required for a server"
Expand Down Expand Up @@ -319,7 +337,7 @@ def __init__(
)
self._peer_cid_available: List[QuicConnectionId] = []
self._peer_cid_sequence_numbers: Set[int] = set([0])
self._peer_token = b""
self._peer_token = configuration.token
self._quic_logger: Optional[QuicLoggerTrace] = None
self._remote_ack_delay_exponent = 3
self._remote_active_connection_id_limit = 2
Expand Down Expand Up @@ -385,6 +403,7 @@ def __init__(
# callbacks
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._token_handler = token_handler

# frame handlers
self.__frame_handlers = {
Expand Down Expand Up @@ -1912,6 +1931,9 @@ def _handle_new_token_frame(
reason_phrase="Clients must not send NEW_TOKEN frames",
)

if self._token_handler is not None:
self._token_handler(token)

def _handle_padding_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,7 +1630,13 @@ def test_handle_new_connection_id_with_retire_prior_to_invalid(self):
)

def test_handle_new_token_frame(self):
with client_and_server() as (client, server):
def token_handler(token):
self.assertEqual(token, binascii.unhexlify("0102030405060708"))

with client_and_server(client_kwargs={"token_handler": token_handler}) as (
client,
server,
):
# client receives NEW_TOKEN
client._handle_new_token_frame(
client_receive_context(client),
Expand Down

0 comments on commit a12f9ec

Please sign in to comment.