Skip to content

Commit

Permalink
Raise a TLS alert if parsing a message causes a BufferReadError
Browse files Browse the repository at this point in the history
While parsing TLS messages, if a `BufferReadError` occurs, it is a sign
we are unable to parse the message because it is malformed. In such an
event raise a `AlertDecodeError` exception so that the connection gets
shut down with a TLS error.
  • Loading branch information
jlaine committed Dec 28, 2023
1 parent 5772246 commit 9616e13
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 79 deletions.
167 changes: 88 additions & 79 deletions src/aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
from OpenSSL import crypto

from .buffer import Buffer
from .buffer import Buffer, BufferReadError

TLS_VERSION_1_2 = 0x0303
TLS_VERSION_1_3 = 0x0304
Expand Down Expand Up @@ -1345,86 +1345,95 @@ def handle_message(
message = self._receive_buffer[:message_length]
self._receive_buffer = self._receive_buffer[message_length:]

input_buf = Buffer(data=message)

# client states

if self.state == State.CLIENT_EXPECT_SERVER_HELLO:
if message_type == HandshakeType.SERVER_HELLO:
self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL])
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS:
if message_type == HandshakeType.ENCRYPTED_EXTENSIONS:
self._client_handle_encrypted_extensions(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE:
if message_type == HandshakeType.CERTIFICATE:
self._client_handle_certificate(input_buf)
elif message_type == HandshakeType.CERTIFICATE_REQUEST:
self._client_handle_certificate_request(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_CERTIFICATE:
if message_type == HandshakeType.CERTIFICATE:
self._client_handle_certificate(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY:
if message_type == HandshakeType.CERTIFICATE_VERIFY:
self._client_handle_certificate_verify(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_FINISHED:
if message_type == HandshakeType.FINISHED:
self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE])
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_POST_HANDSHAKE:
if message_type == HandshakeType.NEW_SESSION_TICKET:
self._client_handle_new_session_ticket(input_buf)
else:
raise AlertUnexpectedMessage

# server states

elif self.state == State.SERVER_EXPECT_CLIENT_HELLO:
if message_type == HandshakeType.CLIENT_HELLO:
self._server_handle_hello(
input_buf,
output_buf[Epoch.INITIAL],
output_buf[Epoch.HANDSHAKE],
output_buf[Epoch.ONE_RTT],
)
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_EXPECT_CERTIFICATE:
if message_type == HandshakeType.CERTIFICATE:
self._server_handle_certificate(
input_buf, output_buf[Epoch.ONE_RTT]
)
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_EXPECT_CERTIFICATE_VERIFY:
if message_type == HandshakeType.CERTIFICATE_VERIFY:
self._server_handle_certificate_verify(
input_buf, output_buf[Epoch.ONE_RTT]
)
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_EXPECT_FINISHED:
if message_type == HandshakeType.FINISHED:
self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT])
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_POST_HANDSHAKE:
# process the message
try:
self._handle_reassembled_message(
message_type=message_type,
input_buf=Buffer(data=message),
output_buf=output_buf,
)
except BufferReadError:
raise AlertDecodeError("Could not parse TLS message")

def _handle_reassembled_message(
self, message_type: int, input_buf: Buffer, output_buf: Dict[Epoch, Buffer]
) -> None:
# client states

if self.state == State.CLIENT_EXPECT_SERVER_HELLO:
if message_type == HandshakeType.SERVER_HELLO:
self._client_handle_hello(input_buf, output_buf[Epoch.INITIAL])
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_ENCRYPTED_EXTENSIONS:
if message_type == HandshakeType.ENCRYPTED_EXTENSIONS:
self._client_handle_encrypted_extensions(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_CERTIFICATE_REQUEST_OR_CERTIFICATE:
if message_type == HandshakeType.CERTIFICATE:
self._client_handle_certificate(input_buf)
elif message_type == HandshakeType.CERTIFICATE_REQUEST:
self._client_handle_certificate_request(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_CERTIFICATE:
if message_type == HandshakeType.CERTIFICATE:
self._client_handle_certificate(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_CERTIFICATE_VERIFY:
if message_type == HandshakeType.CERTIFICATE_VERIFY:
self._client_handle_certificate_verify(input_buf)
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_EXPECT_FINISHED:
if message_type == HandshakeType.FINISHED:
self._client_handle_finished(input_buf, output_buf[Epoch.HANDSHAKE])
else:
raise AlertUnexpectedMessage
elif self.state == State.CLIENT_POST_HANDSHAKE:
if message_type == HandshakeType.NEW_SESSION_TICKET:
self._client_handle_new_session_ticket(input_buf)
else:
raise AlertUnexpectedMessage

# server states

elif self.state == State.SERVER_EXPECT_CLIENT_HELLO:
if message_type == HandshakeType.CLIENT_HELLO:
self._server_handle_hello(
input_buf,
output_buf[Epoch.INITIAL],
output_buf[Epoch.HANDSHAKE],
output_buf[Epoch.ONE_RTT],
)
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_EXPECT_CERTIFICATE:
if message_type == HandshakeType.CERTIFICATE:
self._server_handle_certificate(input_buf, output_buf[Epoch.ONE_RTT])
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_EXPECT_CERTIFICATE_VERIFY:
if message_type == HandshakeType.CERTIFICATE_VERIFY:
self._server_handle_certificate_verify(
input_buf, output_buf[Epoch.ONE_RTT]
)
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_EXPECT_FINISHED:
if message_type == HandshakeType.FINISHED:
self._server_handle_finished(input_buf, output_buf[Epoch.ONE_RTT])
else:
raise AlertUnexpectedMessage
elif self.state == State.SERVER_POST_HANDSHAKE:
raise AlertUnexpectedMessage

# This condition should never be reached, because if the message
# contains any extra bytes, the `pull_block` inside the message
# parser will raise `AlertDecodeError`.
assert input_buf.eof()
# This condition should never be reached, because if the message
# contains any extra bytes, the `pull_block` inside the message
# parser will raise `AlertDecodeError`.
assert input_buf.eof()

def _build_session_ticket(
self, new_session_ticket: NewSessionTicket, other_extensions: List[Extension]
Expand Down
12 changes: 12 additions & 0 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ def test_client_unexpected_message(self):
with self.assertRaises(tls.AlertUnexpectedMessage):
client.handle_message(b"\x00\x00\x00\x00", create_buffers())

def test_client_bad_hello_buffer_read_error(self):
buf = Buffer(capacity=100)
buf.push_uint8(tls.HandshakeType.SERVER_HELLO)
with tls.push_block(buf, 3):
pass

self.handshake_with_client_input_corruption(
# Receive a malformed ServerHello
lambda x: buf.data,
tls.AlertDecodeError("Could not parse TLS message"),
)

def test_client_bad_hello_compression_method(self):
self.handshake_with_client_input_corruption(
# Mess with compression method.
Expand Down

0 comments on commit 9616e13

Please sign in to comment.