From 1be1bdf9ebfdedc82470f84c9f738312dbbdc7b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jeremy=20Lain=C3=A9?= Date: Thu, 28 Dec 2023 17:37:07 +0100 Subject: [PATCH] Raise a TLS alert if parsing a message causes a BufferReadError While parsing TLS messages, if a `BufferReadError` occurs, it is a sign we are unable to parse the message because it is malformed. In such an event raise a `AlertDecodeError` exception so that the connection gets shut down with a TLS error. --- src/aioquic/tls.py | 167 ++++++++++++++++++++++++--------------------- tests/test_tls.py | 14 ++++ 2 files changed, 102 insertions(+), 79 deletions(-) diff --git a/src/aioquic/tls.py b/src/aioquic/tls.py index 7d61db8e7..18a402266 100644 --- a/src/aioquic/tls.py +++ b/src/aioquic/tls.py @@ -41,7 +41,7 @@ from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from OpenSSL import crypto -from .buffer import Buffer +from .buffer import Buffer, BufferReadError TLS_VERSION_1_2 = 0x0303 TLS_VERSION_1_3 = 0x0304 @@ -1345,86 +1345,95 @@ def handle_message( message = self._receive_buffer[:message_length] self._receive_buffer = self._receive_buffer[message_length:] - input_buf = Buffer(data=message) - - # client states - - if self.state == State.CLIENT_EXPECT_SERVER_HELLO: - if message_type == HandshakeType.SERVER_HELLO: - self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL]) - else: - raise AlertUnexpectedMessage - elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS: - if message_type == HandshakeType.ENCRYPTED_EXTENSIONS: - self._client_handle_encrypted_extensions(input_buf) - else: - raise AlertUnexpectedMessage - elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE: - if message_type == HandshakeType.CERTIFICATE: - self._client_handle_certificate(input_buf) - elif message_type == HandshakeType.CERTIFICATE_REQUEST: - self._client_handle_certificate_request(input_buf) - else: - raise AlertUnexpectedMessage - elif self.state == State.CLIENT_EXPECT_CERTIFICATE: - if message_type == HandshakeType.CERTIFICATE: - self._client_handle_certificate(input_buf) - else: - raise AlertUnexpectedMessage - elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY: - if message_type == HandshakeType.CERTIFICATE_VERIFY: - self._client_handle_certificate_verify(input_buf) - else: - raise AlertUnexpectedMessage - elif self.state == State.CLIENT_EXPECT_FINISHED: - if message_type == HandshakeType.FINISHED: - self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE]) - else: - raise AlertUnexpectedMessage - elif self.state == State.CLIENT_POST_HANDSHAKE: - if message_type == HandshakeType.NEW_SESSION_TICKET: - self._client_handle_new_session_ticket(input_buf) - else: - raise AlertUnexpectedMessage - - # server states - - elif self.state == State.SERVER_EXPECT_CLIENT_HELLO: - if message_type == HandshakeType.CLIENT_HELLO: - self._server_handle_hello( - input_buf, - output_buf[Epoch.INITIAL], - output_buf[Epoch.HANDSHAKE], - output_buf[Epoch.ONE_RTT], - ) - else: - raise AlertUnexpectedMessage - elif self.state == State.SERVER_EXPECT_CERTIFICATE: - if message_type == HandshakeType.CERTIFICATE: - self._server_handle_certificate( - input_buf, output_buf[Epoch.ONE_RTT] - ) - else: - raise AlertUnexpectedMessage - elif self.state == State.SERVER_EXPECT_CERTIFICATE_VERIFY: - if message_type == HandshakeType.CERTIFICATE_VERIFY: - self._server_handle_certificate_verify( - input_buf, output_buf[Epoch.ONE_RTT] - ) - else: - raise AlertUnexpectedMessage - elif self.state == State.SERVER_EXPECT_FINISHED: - if message_type == HandshakeType.FINISHED: - self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT]) - else: - raise AlertUnexpectedMessage - elif self.state == State.SERVER_POST_HANDSHAKE: + # process the message + try: + self._handle_reassembled_message( + message_type=message_type, + input_buf=Buffer(data=message), + output_buf=output_buf, + ) + except BufferReadError: + raise AlertDecodeError("Could not parse TLS message") + + def _handle_reassembled_message( + self, message_type: int, input_buf: Buffer, output_buf: Dict[Epoch, Buffer] + ) -> None: + # client states + + if self.state == State.CLIENT_EXPECT_SERVER_HELLO: + if message_type == HandshakeType.SERVER_HELLO: + self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL]) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS: + if message_type == HandshakeType.ENCRYPTED_EXTENSIONS: + self._client_handle_encrypted_extensions(input_buf) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE: + if message_type == HandshakeType.CERTIFICATE: + self._client_handle_certificate(input_buf) + elif message_type == HandshakeType.CERTIFICATE_REQUEST: + self._client_handle_certificate_request(input_buf) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_CERTIFICATE: + if message_type == HandshakeType.CERTIFICATE: + self._client_handle_certificate(input_buf) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY: + if message_type == HandshakeType.CERTIFICATE_VERIFY: + self._client_handle_certificate_verify(input_buf) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_EXPECT_FINISHED: + if message_type == HandshakeType.FINISHED: + self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE]) + else: + raise AlertUnexpectedMessage + elif self.state == State.CLIENT_POST_HANDSHAKE: + if message_type == HandshakeType.NEW_SESSION_TICKET: + self._client_handle_new_session_ticket(input_buf) + else: + raise AlertUnexpectedMessage + + # server states + + elif self.state == State.SERVER_EXPECT_CLIENT_HELLO: + if message_type == HandshakeType.CLIENT_HELLO: + self._server_handle_hello( + input_buf, + output_buf[Epoch.INITIAL], + output_buf[Epoch.HANDSHAKE], + output_buf[Epoch.ONE_RTT], + ) + else: + raise AlertUnexpectedMessage + elif self.state == State.SERVER_EXPECT_CERTIFICATE: + if message_type == HandshakeType.CERTIFICATE: + self._server_handle_certificate(input_buf, output_buf[Epoch.ONE_RTT]) + else: + raise AlertUnexpectedMessage + elif self.state == State.SERVER_EXPECT_CERTIFICATE_VERIFY: + if message_type == HandshakeType.CERTIFICATE_VERIFY: + self._server_handle_certificate_verify( + input_buf, output_buf[Epoch.ONE_RTT] + ) + else: + raise AlertUnexpectedMessage + elif self.state == State.SERVER_EXPECT_FINISHED: + if message_type == HandshakeType.FINISHED: + self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT]) + else: raise AlertUnexpectedMessage + elif self.state == State.SERVER_POST_HANDSHAKE: + raise AlertUnexpectedMessage - # This condition should never be reached, because if the message - # contains any extra bytes, the `pull_block` inside the message - # parser will raise `AlertDecodeError`. - assert input_buf.eof() + # This condition should never be reached, because if the message + # contains any extra bytes, the `pull_block` inside the message + # parser will raise `AlertDecodeError`. + assert input_buf.eof() def _build_session_ticket( self, new_session_ticket: NewSessionTicket, other_extensions: List[Extension] diff --git a/tests/test_tls.py b/tests/test_tls.py index 77ff7f476..2f9bd630e 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -219,6 +219,20 @@ def test_client_unexpected_message(self): with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) + def test_client_bad_hello_buffer_read_error(self): + buf = Buffer(capacity=100) + buf.push_uint8(tls.HandshakeType.SERVER_HELLO) + with tls.push_block(buf, 3): + pass + + self.handshake_with_client_input_corruption( + # Receive a malformed ServerHello + lambda x: buf.data, + tls.AlertDecodeError( + "Could not parse TLS message" + ), + ) + def test_client_bad_hello_compression_method(self): self.handshake_with_client_input_corruption( # Mess with compression method.