Skip to content

Commit

Permalink
Add some unit tests for new code branches
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Dec 28, 2023
1 parent 76a4ff9 commit ddfe4a1
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 87 deletions.
95 changes: 60 additions & 35 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ class HandshakeType(IntEnum):
MESSAGE_HASH = 254


class NameType(IntEnum):
HOST_NAME = 0


class PskKeyExchangeMode(IntEnum):
PSK_KE = 0
PSK_DHE_KE = 1
Expand Down Expand Up @@ -439,6 +443,26 @@ def push_extension(buf: Buffer, extension_type: int) -> Generator:
yield


# ServerName


def pull_server_name(buf: Buffer) -> str:
with pull_block(buf, 2):
name_type = buf.pull_uint8()
if name_type != NameType.HOST_NAME:
# We don't know this name_type.
raise AlertIllegalParameter(
f"ServerName has an unknown name type {name_type}"
)
return pull_opaque(buf, 2).decode("ascii")


def push_server_name(buf: Buffer, server_name: str) -> None:
with push_block(buf, 2):
buf.push_uint8(NameType.HOST_NAME)
push_opaque(buf, 2, server_name.encode("ascii"))


# KeyShareEntry


Expand Down Expand Up @@ -472,6 +496,12 @@ def push_alpn_protocol(buf: Buffer, protocol: str) -> None:
PskIdentity = Tuple[bytes, int]


@dataclass
class OfferedPsks:
identities: List[PskIdentity]
binders: List[bytes]


def pull_psk_identity(buf: Buffer) -> PskIdentity:
identity = pull_opaque(buf, 2)
obfuscated_ticket_age = buf.pull_uint32()
Expand All @@ -491,15 +521,31 @@ def push_psk_binder(buf: Buffer, binder: bytes) -> None:
push_opaque(buf, 1, binder)


# MESSAGES
def pull_offered_psks(buf: Buffer) -> OfferedPsks:
return OfferedPsks(
identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
)

Extension = Tuple[int, bytes]

def push_offered_psks(buf: Buffer, pre_shared_key: OfferedPsks) -> None:
push_list(
buf,
2,
partial(push_psk_identity, buf),
pre_shared_key.identities,
)
push_list(
buf,
2,
partial(push_psk_binder, buf),
pre_shared_key.binders,
)

@dataclass
class OfferedPsks:
identities: List[PskIdentity]
binders: List[bytes]

# MESSAGES

Extension = Tuple[int, bytes]


@dataclass
Expand Down Expand Up @@ -554,7 +600,7 @@ def pull_extension() -> None:
nonlocal after_psk
if after_psk:
# the alert is Illegal Parameter per RFC 8446 section 4.2.11.
raise AlertIllegalParameter("pre-shared key was not last")
raise AlertIllegalParameter("PreSharedKey is not the last extension")

extension_type = buf.pull_uint16()
extension_length = buf.pull_uint16()
Expand All @@ -569,25 +615,15 @@ def pull_extension() -> None:
elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES:
hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8)
elif extension_type == ExtensionType.SERVER_NAME:
with pull_block(buf, 2):
name_type = buf.pull_uint8()
if name_type != 0:
# We don't know this name_type.
raise AlertIllegalParameter(
f"unknown server name type {name_type}"
)
hello.server_name = pull_opaque(buf, 2).decode("ascii")
hello.server_name = pull_server_name(buf)
elif extension_type == ExtensionType.ALPN:
hello.alpn_protocols = pull_list(
buf, 2, partial(pull_alpn_protocol, buf)
)
elif extension_type == ExtensionType.EARLY_DATA:
hello.early_data = True
elif extension_type == ExtensionType.PRE_SHARED_KEY:
hello.pre_shared_key = OfferedPsks(
identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
)
hello.pre_shared_key = pull_offered_psks(buf)
after_psk = True
else:
hello.other_extensions.append(
Expand Down Expand Up @@ -628,9 +664,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:

if hello.server_name is not None:
with push_extension(buf, ExtensionType.SERVER_NAME):
with push_block(buf, 2):
buf.push_uint8(0)
push_opaque(buf, 2, hello.server_name.encode("ascii"))
push_server_name(buf, hello.server_name)

if hello.alpn_protocols is not None:
with push_extension(buf, ExtensionType.ALPN):
Expand All @@ -649,18 +683,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:
# pre_shared_key MUST be last
if hello.pre_shared_key is not None:
with push_extension(buf, ExtensionType.PRE_SHARED_KEY):
push_list(
buf,
2,
partial(push_psk_identity, buf),
hello.pre_shared_key.identities,
)
push_list(
buf,
2,
partial(push_psk_binder, buf),
hello.pre_shared_key.binders,
)
push_offered_psks(hello.pre_shared_key)


@dataclass
Expand Down Expand Up @@ -1430,7 +1453,9 @@ def _build_session_ticket(

def _check_certificate_verify_signature(self, verify: CertificateVerify) -> None:
if verify.algorithm not in self._signature_algorithms:
raise AlertDecryptError
raise AlertDecryptError(
"CertificateVerify has a signature algorithm we did not advertise"
)

try:
self._peer_certificate.public_key().verify(
Expand Down
Loading

0 comments on commit ddfe4a1

Please sign in to comment.