From 57722463ff6b6cbe11f056d7f253ed5161f25356 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Thu, 28 Dec 2023 07:57:15 -0800 Subject: [PATCH] Fix tls.py assertion issues. (#435) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1) Some assertions had side effects and would cause a loss of framing if python was run with -O or -OO 2) Some assertions should have been error checks. Co-authored-by: Jeremy Lainé --- src/aioquic/tls.py | 146 +++++++++++++++++++++--------- tests/test_tls.py | 221 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 272 insertions(+), 95 deletions(-) diff --git a/src/aioquic/tls.py b/src/aioquic/tls.py index 57b1904c3..7d61db8e7 100644 --- a/src/aioquic/tls.py +++ b/src/aioquic/tls.py @@ -100,6 +100,10 @@ class AlertCertificateExpired(Alert): description = AlertDescription.certificate_expired +class AlertDecodeError(Alert): + description = AlertDescription.decode_error + + class AlertDecryptError(Alert): description = AlertDescription.decrypt_error @@ -330,6 +334,10 @@ class HandshakeType(IntEnum): MESSAGE_HASH = 254 +class NameType(IntEnum): + HOST_NAME = 0 + + class PskKeyExchangeMode(IntEnum): PSK_KE = 0 PSK_DHE_KE = 1 @@ -365,7 +373,9 @@ def pull_block(buf: Buffer, capacity: int) -> Generator: length = int.from_bytes(buf.pull_bytes(capacity), byteorder="big") end = buf.tell() + length yield length - assert buf.tell() == end + if buf.tell() != end: + # There was trailing garbage or our parsing was bad. + raise AlertDecodeError("extra bytes at the end of a block") @contextmanager @@ -433,6 +443,26 @@ def push_extension(buf: Buffer, extension_type: int) -> Generator: yield +# ServerName + + +def pull_server_name(buf: Buffer) -> str: + with pull_block(buf, 2): + name_type = buf.pull_uint8() + if name_type != NameType.HOST_NAME: + # We don't know this name_type. + raise AlertIllegalParameter( + f"ServerName has an unknown name type {name_type}" + ) + return pull_opaque(buf, 2).decode("ascii") + + +def push_server_name(buf: Buffer, server_name: str) -> None: + with push_block(buf, 2): + buf.push_uint8(NameType.HOST_NAME) + push_opaque(buf, 2, server_name.encode("ascii")) + + # KeyShareEntry @@ -466,6 +496,12 @@ def push_alpn_protocol(buf: Buffer, protocol: str) -> None: PskIdentity = Tuple[bytes, int] +@dataclass +class OfferedPsks: + identities: List[PskIdentity] + binders: List[bytes] + + def pull_psk_identity(buf: Buffer) -> PskIdentity: identity = pull_opaque(buf, 2) obfuscated_ticket_age = buf.pull_uint32() @@ -485,15 +521,31 @@ def push_psk_binder(buf: Buffer, binder: bytes) -> None: push_opaque(buf, 1, binder) -# MESSAGES +def pull_offered_psks(buf: Buffer) -> OfferedPsks: + return OfferedPsks( + identities=pull_list(buf, 2, partial(pull_psk_identity, buf)), + binders=pull_list(buf, 2, partial(pull_psk_binder, buf)), + ) -Extension = Tuple[int, bytes] +def push_offered_psks(buf: Buffer, pre_shared_key: OfferedPsks) -> None: + push_list( + buf, + 2, + partial(push_psk_identity, buf), + pre_shared_key.identities, + ) + push_list( + buf, + 2, + partial(push_psk_binder, buf), + pre_shared_key.binders, + ) -@dataclass -class OfferedPsks: - identities: List[PskIdentity] - binders: List[bytes] + +# MESSAGES + +Extension = Tuple[int, bytes] @dataclass @@ -517,10 +569,21 @@ class ClientHello: other_extensions: List[Extension] = field(default_factory=list) +def pull_handshake_type(buf: Buffer, expected_type: HandshakeType) -> None: + """ + Pull the message type and assert it is the expected one. + + If it is not, we have a programming error. + """ + message_type = buf.pull_uint8() + assert message_type == expected_type + + def pull_client_hello(buf: Buffer) -> ClientHello: - assert buf.pull_uint8() == HandshakeType.CLIENT_HELLO + pull_handshake_type(buf, HandshakeType.CLIENT_HELLO) with pull_block(buf, 3): - assert buf.pull_uint16() == TLS_VERSION_1_2 + if buf.pull_uint16() != TLS_VERSION_1_2: + raise AlertDecodeError("ClientHello version is not 1.2") hello = ClientHello( random=buf.pull_bytes(32), @@ -535,7 +598,9 @@ def pull_client_hello(buf: Buffer) -> ClientHello: def pull_extension() -> None: # pre_shared_key MUST be last nonlocal after_psk - assert not after_psk + if after_psk: + # the alert is Illegal Parameter per RFC 8446 section 4.2.11. + raise AlertIllegalParameter("PreSharedKey is not the last extension") extension_type = buf.pull_uint16() extension_length = buf.pull_uint16() @@ -550,9 +615,7 @@ def pull_extension() -> None: elif extension_type == ExtensionType.PSK_KEY_EXCHANGE_MODES: hello.psk_key_exchange_modes = pull_list(buf, 1, buf.pull_uint8) elif extension_type == ExtensionType.SERVER_NAME: - with pull_block(buf, 2): - assert buf.pull_uint8() == 0 - hello.server_name = pull_opaque(buf, 2).decode("ascii") + hello.server_name = pull_server_name(buf) elif extension_type == ExtensionType.ALPN: hello.alpn_protocols = pull_list( buf, 2, partial(pull_alpn_protocol, buf) @@ -560,10 +623,7 @@ def pull_extension() -> None: elif extension_type == ExtensionType.EARLY_DATA: hello.early_data = True elif extension_type == ExtensionType.PRE_SHARED_KEY: - hello.pre_shared_key = OfferedPsks( - identities=pull_list(buf, 2, partial(pull_psk_identity, buf)), - binders=pull_list(buf, 2, partial(pull_psk_binder, buf)), - ) + hello.pre_shared_key = pull_offered_psks(buf) after_psk = True else: hello.other_extensions.append( @@ -604,9 +664,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None: if hello.server_name is not None: with push_extension(buf, ExtensionType.SERVER_NAME): - with push_block(buf, 2): - buf.push_uint8(0) - push_opaque(buf, 2, hello.server_name.encode("ascii")) + push_server_name(buf, hello.server_name) if hello.alpn_protocols is not None: with push_extension(buf, ExtensionType.ALPN): @@ -625,18 +683,7 @@ def push_client_hello(buf: Buffer, hello: ClientHello) -> None: # pre_shared_key MUST be last if hello.pre_shared_key is not None: with push_extension(buf, ExtensionType.PRE_SHARED_KEY): - push_list( - buf, - 2, - partial(push_psk_identity, buf), - hello.pre_shared_key.identities, - ) - push_list( - buf, - 2, - partial(push_psk_binder, buf), - hello.pre_shared_key.binders, - ) + push_offered_psks(buf, hello.pre_shared_key) @dataclass @@ -654,9 +701,10 @@ class ServerHello: def pull_server_hello(buf: Buffer) -> ServerHello: - assert buf.pull_uint8() == HandshakeType.SERVER_HELLO + pull_handshake_type(buf, HandshakeType.SERVER_HELLO) with pull_block(buf, 3): - assert buf.pull_uint16() == TLS_VERSION_1_2 + if buf.pull_uint16() != TLS_VERSION_1_2: + raise AlertDecodeError("ServerHello version is not 1.2") hello = ServerHello( random=buf.pull_bytes(32), @@ -729,7 +777,7 @@ class NewSessionTicket: def pull_new_session_ticket(buf: Buffer) -> NewSessionTicket: new_session_ticket = NewSessionTicket() - assert buf.pull_uint8() == HandshakeType.NEW_SESSION_TICKET + pull_handshake_type(buf, HandshakeType.NEW_SESSION_TICKET) with pull_block(buf, 3): new_session_ticket.ticket_lifetime = buf.pull_uint32() new_session_ticket.ticket_age_add = buf.pull_uint32() @@ -780,7 +828,7 @@ class EncryptedExtensions: def pull_encrypted_extensions(buf: Buffer) -> EncryptedExtensions: extensions = EncryptedExtensions() - assert buf.pull_uint8() == HandshakeType.ENCRYPTED_EXTENSIONS + pull_handshake_type(buf, HandshakeType.ENCRYPTED_EXTENSIONS) with pull_block(buf, 3): def pull_extension() -> None: @@ -836,7 +884,7 @@ class Certificate: def pull_certificate(buf: Buffer) -> Certificate: certificate = Certificate() - assert buf.pull_uint8() == HandshakeType.CERTIFICATE + pull_handshake_type(buf, HandshakeType.CERTIFICATE) with pull_block(buf, 3): certificate.request_context = pull_opaque(buf, 1) @@ -876,7 +924,7 @@ class CertificateRequest: def pull_certificate_request(buf: Buffer) -> CertificateRequest: certificate_request = CertificateRequest() - assert buf.pull_uint8() == HandshakeType.CERTIFICATE_REQUEST + pull_handshake_type(buf, HandshakeType.CERTIFICATE_REQUEST) with pull_block(buf, 3): certificate_request.request_context = pull_opaque(buf, 1) @@ -922,7 +970,7 @@ class CertificateVerify: def pull_certificate_verify(buf: Buffer) -> CertificateVerify: - assert buf.pull_uint8() == HandshakeType.CERTIFICATE_VERIFY + pull_handshake_type(buf, HandshakeType.CERTIFICATE_VERIFY) with pull_block(buf, 3): algorithm = buf.pull_uint16() signature = pull_opaque(buf, 2) @@ -945,7 +993,7 @@ class Finished: def pull_finished(buf: Buffer) -> Finished: finished = Finished() - assert buf.pull_uint8() == HandshakeType.FINISHED + pull_handshake_type(buf, HandshakeType.FINISHED) finished.verify_data = pull_opaque(buf, 3) return finished @@ -1373,6 +1421,9 @@ def handle_message( elif self.state == State.SERVER_POST_HANDSHAKE: raise AlertUnexpectedMessage + # This condition should never be reached, because if the message + # contains any extra bytes, the `pull_block` inside the message + # parser will raise `AlertDecodeError`. assert input_buf.eof() def _build_session_ticket( @@ -1402,7 +1453,10 @@ def _build_session_ticket( ) def _check_certificate_verify_signature(self, verify: CertificateVerify) -> None: - assert verify.algorithm in self._signature_algorithms + if verify.algorithm not in self._signature_algorithms: + raise AlertDecryptError( + "CertificateVerify has a signature algorithm we did not advertise" + ) try: self._peer_certificate.public_key().verify( @@ -1524,8 +1578,14 @@ def _client_handle_hello(self, input_buf: Buffer, output_buf: Buffer) -> None: [peer_hello.cipher_suite], AlertHandshakeFailure("Unsupported cipher suite"), ) - assert peer_hello.compression_method in self._legacy_compression_methods - assert peer_hello.supported_version in self._supported_versions + if peer_hello.compression_method not in self._legacy_compression_methods: + raise AlertIllegalParameter( + "ServerHello has a compression method we did not advertise" + ) + if peer_hello.supported_version not in self._supported_versions: + raise AlertIllegalParameter( + "ServerHello has a version we did not advertise" + ) # select key schedule if peer_hello.pre_shared_key is not None: diff --git a/tests/test_tls.py b/tests/test_tls.py index 7bb2c83d0..77ff7f476 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -28,6 +28,7 @@ pull_finished, pull_new_session_ticket, pull_server_hello, + pull_server_name, push_certificate, push_certificate_request, push_certificate_verify, @@ -36,6 +37,7 @@ push_finished, push_new_session_ticket, push_server_hello, + push_server_name, verify_certificate, ) from cryptography.exceptions import UnsupportedAlgorithm @@ -87,6 +89,13 @@ def test_pull_block_truncated(self): pass +def corrupt_hello_version(data: bytes) -> bytes: + """ + Corrupt a ClientHello or ServerHello's protocol version. + """ + return data[:4] + b"\xff\xff" + data[6:] + + def create_buffers(): return { tls.Epoch.INITIAL: Buffer(capacity=4096), @@ -145,6 +154,40 @@ def create_server(self, alpn_protocols=None, **kwargs): self.assertEqual(server.state, State.SERVER_EXPECT_CLIENT_HELLO) return server + def handshake_with_client_input_corruption( + self, + corrupt_client_input, + expected_exception, + ): + client = self.create_client() + server = self.create_server() + + # Send client hello. + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # Handle client hello. + # + # send server hello, encrypted extensions, certificate, certificate verify, + # finished. + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = merge_buffers(server_buf) + reset_buffers(server_buf) + + # Mess with compression method. + client_input = corrupt_client_input(client_input) + + # Handle server hello, encrypted extensions, certificate, certificate verify, + # finished. + with self.assertRaises(expected_exception.__class__) as cm: + client.handle_message(client_input, client_buf) + self.assertEqual(str(cm.exception), str(expected_exception)) + def test_client_unexpected_message(self): client = self.create_client() @@ -176,63 +219,44 @@ def test_client_unexpected_message(self): with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) - def test_client_bad_certificate_verify_data(self): - client = self.create_client() - server = self.create_server() - - # Send client hello. - client_buf = create_buffers() - client.handle_message(b"", client_buf) - self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) - server_input = merge_buffers(client_buf) - reset_buffers(client_buf) + def test_client_bad_hello_compression_method(self): + self.handshake_with_client_input_corruption( + # Mess with compression method. + lambda x: x[:41] + b"\xff" + x[42:], + tls.AlertIllegalParameter( + "ServerHello has a compression method we did not advertise" + ), + ) - # Handle client hello. - # - # send server hello, encrypted extensions, certificate, certificate verify, - # finished. - server_buf = create_buffers() - server.handle_message(server_input, server_buf) - self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) - client_input = merge_buffers(server_buf) - reset_buffers(server_buf) + def test_client_bad_hello_version(self): + self.handshake_with_client_input_corruption( + # Mess with supported version. + lambda x: x[:48] + b"\xff\xff" + x[50:], + tls.AlertIllegalParameter("ServerHello has a version we did not advertise"), + ) - # Mess with certificate verify. - client_input = client_input[:-56] + bytes(4) + client_input[-52:] + def test_client_bad_certificate_verify_algorithm(self): + self.handshake_with_client_input_corruption( + # Mess with certificate verify. + lambda x: x[:-440] + b"\xff\xff" + x[-438:], + tls.AlertDecryptError( + "CertificateVerify has a signature algorithm we did not advertise" + ), + ) - # Handle server hello, encrypted extensions, certificate, certificate verify, - # finished. - with self.assertRaises(tls.AlertDecryptError): - client.handle_message(client_input, client_buf) + def test_client_bad_certificate_verify_data(self): + self.handshake_with_client_input_corruption( + # Mess with certificate verify. + lambda x: x[:-56] + bytes(4) + x[-52:], + tls.AlertDecryptError(), + ) def test_client_bad_finished_verify_data(self): - client = self.create_client() - server = self.create_server() - - # Send client hello. - client_buf = create_buffers() - client.handle_message(b"", client_buf) - self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) - server_input = merge_buffers(client_buf) - reset_buffers(client_buf) - - # Handle client hello. - # - # Send server hello, encrypted extensions, certificate, certificate verify, - # finished. - server_buf = create_buffers() - server.handle_message(server_input, server_buf) - self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) - client_input = merge_buffers(server_buf) - reset_buffers(server_buf) - - # Mess with finished verify data. - client_input = client_input[:-4] + bytes(4) - - # Handle server hello, encrypted extensions, certificate, certificate verify, - # finished. - with self.assertRaises(tls.AlertDecryptError): - client.handle_message(client_input, client_buf) + self.handshake_with_client_input_corruption( + # Mess with finished verify data. + lambda x: x[:-4] + bytes(4), + tls.AlertDecryptError(), + ) def test_server_unexpected_message(self): server = self.create_server() @@ -765,6 +789,16 @@ def second_handshake_bad_pre_shared_key(): class TlsTest(TestCase): + def test_pull_block_incomplete_read(self): + """ + If a block is not read until its end, an alert should be raised. + """ + buf = Buffer(data=bytes([2, 0, 0])) + with self.assertRaises(tls.AlertDecodeError) as cm: + with pull_block(buf, 1): + buf.pull_bytes(1) + self.assertEqual(str(cm.exception), "extra bytes at the end of a block") + def test_pull_client_hello(self): buf = Buffer(data=load("tls_client_hello.bin")) hello = pull_client_hello(buf) @@ -954,6 +988,62 @@ def test_pull_client_hello_with_psk(self): push_client_hello(buf, hello) self.assertEqual(buf.data, load("tls_client_hello_with_psk.bin")) + def test_pull_client_hello_with_psk_and_other_extension(self): + buf = Buffer(capacity=1000) + + # Prepare PSK. + psk_buf = Buffer(capacity=100) + tls.push_offered_psks( + psk_buf, + tls.OfferedPsks( + identities=[], + binders=[], + ), + ) + + # Write a ClientHello with an extension *after* PSK. + hello = ClientHello( + random=binascii.unhexlify( + "18b2b23bf3e44b5d52ccfe7aecbc5ff14eadc3d349fabf804d71f165ae76e7d5" + ), + legacy_session_id=binascii.unhexlify( + "9aee82a2d186c1cb32a329d9dcfe004a1a438ad0485a53c6bfcf55c132a23235" + ), + cipher_suites=[tls.CipherSuite.AES_256_GCM_SHA384], + legacy_compression_methods=[tls.CompressionMethod.NULL], + key_share=[ + ( + tls.Group.SECP256R1, + binascii.unhexlify( + "047bfea344467535054263b75def60cffa82405a211b68d1eb8d1d944e67aef8" + "93c7665a5473d032cfaf22a73da28eb4aacae0017ed12557b5791f98a1e84f15" + "b0" + ), + ) + ], + psk_key_exchange_modes=[tls.PskKeyExchangeMode.PSK_DHE_KE], + signature_algorithms=[tls.SignatureAlgorithm.RSA_PSS_RSAE_SHA256], + supported_groups=[tls.Group.SECP256R1], + supported_versions=[tls.TLS_VERSION_1_3], + other_extensions=[ + ( + tls.ExtensionType.PRE_SHARED_KEY, + psk_buf.data, + ), + ( + tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS_DRAFT, + CLIENT_QUIC_TRANSPORT_PARAMETERS, + ), + ], + ) + push_client_hello(buf, hello) + + # Try reading it back. + buf.seek(0) + with self.assertRaises(tls.AlertIllegalParameter) as cm: + pull_client_hello(buf) + self.assertEqual(str(cm.exception), "PreSharedKey is not the last extension") + def test_pull_client_hello_with_sni(self): buf = Buffer(data=load("tls_client_hello_with_sni.bin")) hello = pull_client_hello(buf) @@ -1035,6 +1125,12 @@ def test_pull_client_hello_with_sni(self): push_client_hello(buf, hello) self.assertEqual(buf.data, load("tls_client_hello_with_sni.bin")) + def test_pull_client_hello_with_unexpected_version(self): + buf = Buffer(data=corrupt_hello_version(load("tls_client_hello.bin"))) + with self.assertRaises(tls.AlertDecodeError) as cm: + pull_client_hello(buf) + self.assertEqual(str(cm.exception), "ClientHello version is not 1.2") + def test_push_client_hello(self): hello = ClientHello( random=binascii.unhexlify( @@ -1156,6 +1252,12 @@ def test_pull_server_hello_with_psk(self): push_server_hello(buf, hello) self.assertEqual(buf.data, load("tls_server_hello_with_psk.bin")) + def test_pull_server_hello_with_unexpected_version(self): + buf = Buffer(data=corrupt_hello_version(load("tls_server_hello.bin"))) + with self.assertRaises(tls.AlertDecodeError) as cm: + pull_server_hello(buf) + self.assertEqual(str(cm.exception), "ServerHello version is not 1.2") + def test_pull_server_hello_with_unknown_extension(self): buf = Buffer(data=load("tls_server_hello_with_unknown_extension.bin")) hello = pull_server_hello(buf) @@ -1433,6 +1535,21 @@ def test_push_finished(self): push_finished(buf, finished) self.assertEqual(buf.data, load("tls_finished.bin")) + def test_pull_server_name(self): + buf = Buffer(data=b"\x00\x12\x00\x00\x0fwww.example.com") + self.assertEqual(pull_server_name(buf), "www.example.com") + + def test_pull_server_name_with_bad_name_type(self): + buf = Buffer(data=b"\x00\x12\xff\x00\x0fwww.example.com") + with self.assertRaises(tls.AlertIllegalParameter) as cm: + pull_server_name(buf) + self.assertEqual(str(cm.exception), "ServerName has an unknown name type 255") + + def test_push_server_name(self): + buf = Buffer(128) + push_server_name(buf, "www.example.com") + self.assertEqual(buf.data, b"\x00\x12\x00\x00\x0fwww.example.com") + class VerifyCertificateTest(TestCase): def test_verify_certificate_chain(self):