Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement address validation token for future connections for client #340

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/aioquic/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import AsyncGenerator, Callable, Optional, cast

from ..quic.configuration import QuicConfiguration
from ..quic.connection import QuicConnection
from ..quic.connection import QuicConnection, QuicTokenHandler
from ..tls import SessionTicketHandler
from .protocol import QuicConnectionProtocol, QuicStreamHandler

Expand All @@ -20,6 +20,7 @@ async def connect(
create_protocol: Optional[Callable] = QuicConnectionProtocol,
session_ticket_handler: Optional[SessionTicketHandler] = None,
stream_handler: Optional[QuicStreamHandler] = None,
token_handler: Optional[QuicTokenHandler] = None,
wait_connected: bool = True,
local_port: int = 0,
) -> AsyncGenerator[QuicConnectionProtocol, None]:
Expand Down Expand Up @@ -60,7 +61,9 @@ async def connect(
if configuration.server_name is None:
configuration.server_name = host
connection = QuicConnection(
configuration=configuration, session_ticket_handler=session_ticket_handler
configuration=configuration,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)

# explicitly enable IPv4/IPv6 dual stack
Expand Down
7 changes: 7 additions & 0 deletions src/aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ class QuicConfiguration:
The TLS session ticket which should be used for session resumption.
"""

token: bytes = b""
"""
The address validation token that can be used to validate future connections.

.. note:: This is only used by clients.
"""

cadata: Optional[bytes] = None
cafile: Optional[str] = None
capath: Optional[str] = None
Expand Down
26 changes: 24 additions & 2 deletions src/aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,18 @@
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Deque, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple
from typing import (
Any,
Callable,
Deque,
Dict,
FrozenSet,
List,
Optional,
Sequence,
Set,
Tuple,
)

from .. import tls
from ..buffer import (
Expand Down Expand Up @@ -207,6 +218,8 @@ class QuicReceiveContext:
time: float


QuicTokenHandler = Callable[[bytes], None]

END_STATES = frozenset(
[
QuicConnectionState.CLOSING,
Expand Down Expand Up @@ -239,6 +252,7 @@ def __init__(
retry_source_connection_id: Optional[bytes] = None,
session_ticket_fetcher: Optional[tls.SessionTicketFetcher] = None,
session_ticket_handler: Optional[tls.SessionTicketHandler] = None,
token_handler: Optional[QuicTokenHandler] = None,
) -> None:
assert configuration.max_datagram_size >= SMALLEST_MAX_DATAGRAM_SIZE, (
"The smallest allowed maximum datagram size is "
Expand All @@ -252,6 +266,10 @@ def __init__(
retry_source_connection_id is None
), "Cannot set retry_source_connection_id for a client"
else:
assert token_handler is None, "Cannot set `token_handler` for a server"
assert (
configuration.token == b""
), "Cannot set `configuration.token` for a server"
assert (
configuration.certificate is not None
), "SSL certificate is required for a server"
Expand Down Expand Up @@ -319,7 +337,7 @@ def __init__(
)
self._peer_cid_available: List[QuicConnectionId] = []
self._peer_cid_sequence_numbers: Set[int] = set([0])
self._peer_token = b""
self._peer_token = configuration.token
self._quic_logger: Optional[QuicLoggerTrace] = None
self._remote_ack_delay_exponent = 3
self._remote_active_connection_id_limit = 2
Expand Down Expand Up @@ -385,6 +403,7 @@ def __init__(
# callbacks
self._session_ticket_fetcher = session_ticket_fetcher
self._session_ticket_handler = session_ticket_handler
self._token_handler = token_handler

# frame handlers
self.__frame_handlers = {
Expand Down Expand Up @@ -1912,6 +1931,9 @@ def _handle_new_token_frame(
reason_phrase="Clients must not send NEW_TOKEN frames",
)

if self._token_handler is not None:
self._token_handler(token)

def _handle_padding_frame(
self, context: QuicReceiveContext, frame_type: int, buf: Buffer
) -> None:
Expand Down
13 changes: 12 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,14 +1630,25 @@ def test_handle_new_connection_id_with_retire_prior_to_invalid(self):
)

def test_handle_new_token_frame(self):
with client_and_server() as (client, server):
new_token = None

def token_handler(token):
nonlocal new_token
new_token = token

with client_and_server(client_kwargs={"token_handler": token_handler}) as (
client,
server,
):
# client receives NEW_TOKEN
client._handle_new_token_frame(
client_receive_context(client),
QuicFrameType.NEW_TOKEN,
Buffer(data=binascii.unhexlify("080102030405060708")),
)

self.assertEqual(new_token, binascii.unhexlify("0102030405060708"))

def test_handle_new_token_frame_from_client(self):
with client_and_server() as (client, server):
# server receives NEW_TOKEN
Expand Down
Loading