diff --git a/tests/test_tls.py b/tests/test_tls.py index 7bb2c83d0..5cfafaf85 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -1,6 +1,7 @@ import binascii import datetime import ssl +from typing import Callable from unittest import TestCase from unittest.mock import patch @@ -87,6 +88,15 @@ def test_pull_block_truncated(self): pass +def corrupt_hello_version(data: bytes) -> bytes: + """ + Corrupt a ClientHello or ServerHello's protocol version. + """ + octets = list(data) + octets[4:6] = [0, 0] + return bytes(octets) + + def create_buffers(): return { tls.Epoch.INITIAL: Buffer(capacity=4096), @@ -145,6 +155,40 @@ def create_server(self, alpn_protocols=None, **kwargs): self.assertEqual(server.state, State.SERVER_EXPECT_CLIENT_HELLO) return server + def handshake_with_client_input_corruption( + self, + corrupt_client_input: Callable[[bytes], bytes], + expected_exception: Exception, + ): + client = self.create_client() + server = self.create_server() + + # Send client hello. + client_buf = create_buffers() + client.handle_message(b"", client_buf) + self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) + server_input = merge_buffers(client_buf) + reset_buffers(client_buf) + + # Handle client hello. + # + # send server hello, encrypted extensions, certificate, certificate verify, + # finished. + server_buf = create_buffers() + server.handle_message(server_input, server_buf) + self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) + client_input = merge_buffers(server_buf) + reset_buffers(server_buf) + + # Mess with compression method. + client_input = corrupt_client_input(client_input) + + # Handle server hello, encrypted extensions, certificate, certificate verify, + # finished. + with self.assertRaises(expected_exception.__class__) as cm: + client.handle_message(client_input, client_buf) + self.assertEqual(str(cm.exception), str(expected_exception)) + def test_client_unexpected_message(self): client = self.create_client() @@ -176,63 +220,35 @@ def test_client_unexpected_message(self): with self.assertRaises(tls.AlertUnexpectedMessage): client.handle_message(b"\x00\x00\x00\x00", create_buffers()) - def test_client_bad_certificate_verify_data(self): - client = self.create_client() - server = self.create_server() - - # Send client hello. - client_buf = create_buffers() - client.handle_message(b"", client_buf) - self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) - server_input = merge_buffers(client_buf) - reset_buffers(client_buf) - - # Handle client hello. - # - # send server hello, encrypted extensions, certificate, certificate verify, - # finished. - server_buf = create_buffers() - server.handle_message(server_input, server_buf) - self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) - client_input = merge_buffers(server_buf) - reset_buffers(server_buf) + def test_client_bad_hello_compression_method(self): + self.handshake_with_client_input_corruption( + # Mess with compression method. + lambda x: x[:41] + b"\xff" + x[42:], + tls.AlertIllegalParameter( + "ServerHello has a compression method we did not advertise" + ), + ) - # Mess with certificate verify. - client_input = client_input[:-56] + bytes(4) + client_input[-52:] + def test_client_bad_hello_version(self): + self.handshake_with_client_input_corruption( + # Mess with compression method. + lambda x: x[:48] + b"\xff\xff" + x[50:], + tls.AlertIllegalParameter("ServerHello has a version we did not advertise"), + ) - # Handle server hello, encrypted extensions, certificate, certificate verify, - # finished. - with self.assertRaises(tls.AlertDecryptError): - client.handle_message(client_input, client_buf) + def test_client_bad_certificate_verify_data(self): + self.handshake_with_client_input_corruption( + # Mess with certificate verify. + lambda x: x[:-56] + bytes(4) + x[-52:], + tls.AlertDecryptError(), + ) def test_client_bad_finished_verify_data(self): - client = self.create_client() - server = self.create_server() - - # Send client hello. - client_buf = create_buffers() - client.handle_message(b"", client_buf) - self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO) - server_input = merge_buffers(client_buf) - reset_buffers(client_buf) - - # Handle client hello. - # - # Send server hello, encrypted extensions, certificate, certificate verify, - # finished. - server_buf = create_buffers() - server.handle_message(server_input, server_buf) - self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED) - client_input = merge_buffers(server_buf) - reset_buffers(server_buf) - - # Mess with finished verify data. - client_input = client_input[:-4] + bytes(4) - - # Handle server hello, encrypted extensions, certificate, certificate verify, - # finished. - with self.assertRaises(tls.AlertDecryptError): - client.handle_message(client_input, client_buf) + self.handshake_with_client_input_corruption( + # Mess with finished verify data. + lambda x: x[:-4] + bytes(4), + tls.AlertDecryptError(), + ) def test_server_unexpected_message(self): server = self.create_server() @@ -1035,6 +1051,12 @@ def test_pull_client_hello_with_sni(self): push_client_hello(buf, hello) self.assertEqual(buf.data, load("tls_client_hello_with_sni.bin")) + def test_pull_client_hello_with_unexpected_version(self): + buf = Buffer(data=corrupt_hello_version(load("tls_client_hello.bin"))) + with self.assertRaises(tls.AlertDecodeError) as cm: + pull_client_hello(buf) + self.assertEqual(str(cm.exception), "ClientHello version is not 1.2") + def test_push_client_hello(self): hello = ClientHello( random=binascii.unhexlify( @@ -1156,6 +1178,12 @@ def test_pull_server_hello_with_psk(self): push_server_hello(buf, hello) self.assertEqual(buf.data, load("tls_server_hello_with_psk.bin")) + def test_pull_server_hello_with_unexpected_version(self): + buf = Buffer(data=corrupt_hello_version(load("tls_server_hello.bin"))) + with self.assertRaises(tls.AlertDecodeError) as cm: + pull_server_hello(buf) + self.assertEqual(str(cm.exception), "ServerHello version is not 1.2") + def test_pull_server_hello_with_unknown_extension(self): buf = Buffer(data=load("tls_server_hello_with_unknown_extension.bin")) hello = pull_server_hello(buf)