Skip to content

Commit

Permalink
Split server name encoding / parsing to its own methods
Browse files Browse the repository at this point in the history
Add `pull_server_name` and `push_server_name` methods so that they can
be exercised independently.
  • Loading branch information
jlaine committed Dec 28, 2023
1 parent bdaa0ce commit 12964eb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ class HandshakeType(IntEnum):
MESSAGE_HASH = 254


class NameType(IntEnum):
HOST_NAME = 0


class PskKeyExchangeMode(IntEnum):
PSK_KE = 0
PSK_DHE_KE = 1
Expand Down Expand Up @@ -439,6 +443,24 @@ def push_extension(buf: Buffer, extension_type: int) -> Generator:
yield


# ServerName


def pull_server_name(buf: Buffer) -> str:
with pull_block(buf, 2):
name_type = buf.pull_uint8()
if name_type != NameType.HOST_NAME:
# We don't know this name_type.
raise AlertIllegalParameter(f"unknown server name type {name_type}")
return pull_opaque(buf, 2).decode("ascii")


def push_server_name(buf: Buffer, server_name: str) -> None:
with push_block(buf, 2):
buf.push_uint8(NameType.HOST_NAME)
push_opaque(buf, 2, server_name.encode("ascii"))


# KeyShareEntry


Expand Down Expand Up @@ -569,14 +591,7 @@ def pull_extension() -> None:
elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES:
hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8)
elif extension_type == ExtensionType.SERVER_NAME:
with pull_block(buf, 2):
name_type = buf.pull_uint8()
if name_type != 0:
# We don't know this name_type.
raise AlertIllegalParameter(
f"unknown server name type {name_type}"
)
hello.server_name = pull_opaque(buf, 2).decode("ascii")
hello.server_name = pull_server_name(buf)
elif extension_type == ExtensionType.ALPN:
hello.alpn_protocols = pull_list(
buf, 2, partial(pull_alpn_protocol, buf)
Expand Down Expand Up @@ -628,9 +643,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:

if hello.server_name is not None:
with push_extension(buf, ExtensionType.SERVER_NAME):
with push_block(buf, 2):
buf.push_uint8(0)
push_opaque(buf, 2, hello.server_name.encode("ascii"))
push_server_name(buf, hello.server_name)

if hello.alpn_protocols is not None:
with push_extension(buf, ExtensionType.ALPN):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
pull_finished,
pull_new_session_ticket,
pull_server_hello,
pull_server_name,
push_certificate,
push_certificate_request,
push_certificate_verify,
Expand All @@ -36,6 +37,7 @@
push_finished,
push_new_session_ticket,
push_server_hello,
push_server_name,
verify_certificate,
)
from cryptography.exceptions import UnsupportedAlgorithm
Expand Down Expand Up @@ -1458,6 +1460,21 @@ def test_push_finished(self):
push_finished(buf, finished)
self.assertEqual(buf.data, load("tls_finished.bin"))

def test_pull_server_name(self):
buf = Buffer(data=b"\x00\x12\x00\x00\x0fwww.example.com")
self.assertEqual(pull_server_name(buf), "www.example.com")

def test_pull_server_name_with_bad_name_type(self):
buf = Buffer(data=b"\x00\x12\xff\x00\x0fwww.example.com")
with self.assertRaises(tls.AlertIllegalParameter) as cm:
pull_server_name(buf)
self.assertEqual(str(cm.exception), "unknown server name type 255")

def test_push_server_name(self):
buf = Buffer(128)
push_server_name(buf, "www.example.com")
self.assertEqual(buf.data, b"\x00\x12\x00\x00\x0fwww.example.com")


class VerifyCertificateTest(TestCase):
def test_verify_certificate_chain(self):
Expand Down

0 comments on commit 12964eb

Please sign in to comment.