Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tls.py assertion issues. #435

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 103 additions & 43 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class AlertCertificateExpired(Alert):
description = AlertDescription.certificate_expired


class AlertDecodeError(Alert):
description = AlertDescription.decode_error


class AlertDecryptError(Alert):
description = AlertDescription.decrypt_error

Expand Down Expand Up @@ -330,6 +334,10 @@ class HandshakeType(IntEnum):
MESSAGE_HASH = 254


class NameType(IntEnum):
HOST_NAME = 0


class PskKeyExchangeMode(IntEnum):
PSK_KE = 0
PSK_DHE_KE = 1
Expand Down Expand Up @@ -365,7 +373,9 @@ def pull_block(buf: Buffer, capacity: int) -> Generator:
length = int.from_bytes(buf.pull_bytes(capacity), byteorder="big")
end = buf.tell() + length
yield length
assert buf.tell() == end
if buf.tell() != end:
# There was trailing garbage or our parsing was bad.
raise AlertDecodeError("extra bytes at the end of a block")
jlaine marked this conversation as resolved.
Show resolved Hide resolved


@contextmanager
Expand Down Expand Up @@ -433,6 +443,26 @@ def push_extension(buf: Buffer, extension_type: int) -> Generator:
yield


# ServerName


def pull_server_name(buf: Buffer) -> str:
with pull_block(buf, 2):
name_type = buf.pull_uint8()
if name_type != NameType.HOST_NAME:
# We don't know this name_type.
raise AlertIllegalParameter(
f"ServerName has an unknown name type {name_type}"
)
return pull_opaque(buf, 2).decode("ascii")


def push_server_name(buf: Buffer, server_name: str) -> None:
with push_block(buf, 2):
buf.push_uint8(NameType.HOST_NAME)
push_opaque(buf, 2, server_name.encode("ascii"))


# KeyShareEntry


Expand Down Expand Up @@ -466,6 +496,12 @@ def push_alpn_protocol(buf: Buffer, protocol: str) -> None:
PskIdentity = Tuple[bytes, int]


@dataclass
class OfferedPsks:
identities: List[PskIdentity]
binders: List[bytes]


def pull_psk_identity(buf: Buffer) -> PskIdentity:
identity = pull_opaque(buf, 2)
obfuscated_ticket_age = buf.pull_uint32()
Expand All @@ -485,15 +521,31 @@ def push_psk_binder(buf: Buffer, binder: bytes) -> None:
push_opaque(buf, 1, binder)


# MESSAGES
def pull_offered_psks(buf: Buffer) -> OfferedPsks:
return OfferedPsks(
identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
)

Extension = Tuple[int, bytes]

def push_offered_psks(buf: Buffer, pre_shared_key: OfferedPsks) -> None:
push_list(
buf,
2,
partial(push_psk_identity, buf),
pre_shared_key.identities,
)
push_list(
buf,
2,
partial(push_psk_binder, buf),
pre_shared_key.binders,
)

@dataclass
class OfferedPsks:
identities: List[PskIdentity]
binders: List[bytes]

# MESSAGES

Extension = Tuple[int, bytes]


@dataclass
Expand All @@ -517,10 +569,21 @@ class ClientHello:
other_extensions: List[Extension] = field(default_factory=list)


def pull_handshake_type(buf: Buffer, expected_type: HandshakeType) -> None:
"""
Pull the message type and assert it is the expected one.

If it is not, we have a programming error.
"""
message_type = buf.pull_uint8()
assert message_type == expected_type


def pull_client_hello(buf: Buffer) -> ClientHello:
assert buf.pull_uint8() == HandshakeType.CLIENT_HELLO
pull_handshake_type(buf, HandshakeType.CLIENT_HELLO)
with pull_block(buf, 3):
assert buf.pull_uint16() == TLS_VERSION_1_2
if buf.pull_uint16() != TLS_VERSION_1_2:
raise AlertDecodeError("ClientHello version is not 1.2")

hello = ClientHello(
random=buf.pull_bytes(32),
Expand All @@ -535,7 +598,9 @@ def pull_client_hello(buf: Buffer) -> ClientHello:
def pull_extension() -> None:
# pre_shared_key MUST be last
nonlocal after_psk
assert not after_psk
if after_psk:
# the alert is Illegal Parameter per RFC 8446 section 4.2.11.
raise AlertIllegalParameter("PreSharedKey is not the last extension")

extension_type = buf.pull_uint16()
extension_length = buf.pull_uint16()
Expand All @@ -550,20 +615,15 @@ def pull_extension() -> None:
elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES:
hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8)
elif extension_type == ExtensionType.SERVER_NAME:
with pull_block(buf, 2):
assert buf.pull_uint8() == 0
hello.server_name = pull_opaque(buf, 2).decode("ascii")
hello.server_name = pull_server_name(buf)
elif extension_type == ExtensionType.ALPN:
hello.alpn_protocols = pull_list(
buf, 2, partial(pull_alpn_protocol, buf)
)
elif extension_type == ExtensionType.EARLY_DATA:
hello.early_data = True
elif extension_type == ExtensionType.PRE_SHARED_KEY:
hello.pre_shared_key = OfferedPsks(
identities=pull_list(buf, 2, partial(pull_psk_identity, buf)),
binders=pull_list(buf, 2, partial(pull_psk_binder, buf)),
)
hello.pre_shared_key = pull_offered_psks(buf)
after_psk = True
else:
hello.other_extensions.append(
Expand Down Expand Up @@ -604,9 +664,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:

if hello.server_name is not None:
with push_extension(buf, ExtensionType.SERVER_NAME):
with push_block(buf, 2):
buf.push_uint8(0)
push_opaque(buf, 2, hello.server_name.encode("ascii"))
push_server_name(buf, hello.server_name)

if hello.alpn_protocols is not None:
with push_extension(buf, ExtensionType.ALPN):
Expand All @@ -625,18 +683,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None:
# pre_shared_key MUST be last
if hello.pre_shared_key is not None:
with push_extension(buf, ExtensionType.PRE_SHARED_KEY):
push_list(
buf,
2,
partial(push_psk_identity, buf),
hello.pre_shared_key.identities,
)
push_list(
buf,
2,
partial(push_psk_binder, buf),
hello.pre_shared_key.binders,
)
push_offered_psks(buf, hello.pre_shared_key)


@dataclass
Expand All @@ -654,9 +701,10 @@ class ServerHello:


def pull_server_hello(buf: Buffer) -> ServerHello:
assert buf.pull_uint8() == HandshakeType.SERVER_HELLO
pull_handshake_type(buf, HandshakeType.SERVER_HELLO)
with pull_block(buf, 3):
assert buf.pull_uint16() == TLS_VERSION_1_2
if buf.pull_uint16() != TLS_VERSION_1_2:
raise AlertDecodeError("ServerHello version is not 1.2")

hello = ServerHello(
random=buf.pull_bytes(32),
Expand Down Expand Up @@ -729,7 +777,7 @@ class NewSessionTicket:
def pull_new_session_ticket(buf: Buffer) -> NewSessionTicket:
new_session_ticket = NewSessionTicket()

assert buf.pull_uint8() == HandshakeType.NEW_SESSION_TICKET
pull_handshake_type(buf, HandshakeType.NEW_SESSION_TICKET)
with pull_block(buf, 3):
new_session_ticket.ticket_lifetime = buf.pull_uint32()
new_session_ticket.ticket_age_add = buf.pull_uint32()
Expand Down Expand Up @@ -780,7 +828,7 @@ class EncryptedExtensions:
def pull_encrypted_extensions(buf: Buffer) -> EncryptedExtensions:
extensions = EncryptedExtensions()

assert buf.pull_uint8() == HandshakeType.ENCRYPTED_EXTENSIONS
pull_handshake_type(buf, HandshakeType.ENCRYPTED_EXTENSIONS)
with pull_block(buf, 3):

def pull_extension() -> None:
Expand Down Expand Up @@ -836,7 +884,7 @@ class Certificate:
def pull_certificate(buf: Buffer) -> Certificate:
certificate = Certificate()

assert buf.pull_uint8() == HandshakeType.CERTIFICATE
pull_handshake_type(buf, HandshakeType.CERTIFICATE)
with pull_block(buf, 3):
certificate.request_context = pull_opaque(buf, 1)

Expand Down Expand Up @@ -876,7 +924,7 @@ class CertificateRequest:
def pull_certificate_request(buf: Buffer) -> CertificateRequest:
certificate_request = CertificateRequest()

assert buf.pull_uint8() == HandshakeType.CERTIFICATE_REQUEST
pull_handshake_type(buf, HandshakeType.CERTIFICATE_REQUEST)
with pull_block(buf, 3):
certificate_request.request_context = pull_opaque(buf, 1)

Expand Down Expand Up @@ -922,7 +970,7 @@ class CertificateVerify:


def pull_certificate_verify(buf: Buffer) -> CertificateVerify:
assert buf.pull_uint8() == HandshakeType.CERTIFICATE_VERIFY
pull_handshake_type(buf, HandshakeType.CERTIFICATE_VERIFY)
with pull_block(buf, 3):
algorithm = buf.pull_uint16()
signature = pull_opaque(buf, 2)
Expand All @@ -945,7 +993,7 @@ class Finished:
def pull_finished(buf: Buffer) -> Finished:
finished = Finished()

assert buf.pull_uint8() == HandshakeType.FINISHED
pull_handshake_type(buf, HandshakeType.FINISHED)
finished.verify_data = pull_opaque(buf, 3)

return finished
Expand Down Expand Up @@ -1373,6 +1421,9 @@ def handle_message(
elif self.state == State.SERVER_POST_HANDSHAKE:
raise AlertUnexpectedMessage

# This condition should never be reached, because if the message
# contains any extra bytes, the `pull_block` inside the message
# parser will raise `AlertDecodeError`.
assert input_buf.eof()

def _build_session_ticket(
Expand Down Expand Up @@ -1402,7 +1453,10 @@ def _build_session_ticket(
)

def _check_certificate_verify_signature(self, verify: CertificateVerify) -> None:
assert verify.algorithm in self._signature_algorithms
if verify.algorithm not in self._signature_algorithms:
raise AlertDecryptError(
"CertificateVerify has a signature algorithm we did not advertise"
)

try:
self._peer_certificate.public_key().verify(
Expand Down Expand Up @@ -1524,8 +1578,14 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None:
[peer_hello.cipher_suite],
AlertHandshakeFailure("Unsupported cipher suite"),
)
assert peer_hello.compression_method in self._legacy_compression_methods
assert peer_hello.supported_version in self._supported_versions
if peer_hello.compression_method not in self._legacy_compression_methods:
raise AlertIllegalParameter(
"ServerHello has a compression method we did not advertise"
)
if peer_hello.supported_version not in self._supported_versions:
raise AlertIllegalParameter(
"ServerHello has a version we did not advertise"
)

# select key schedule
if peer_hello.pre_shared_key is not None:
Expand Down
Loading
Loading