Skip to content

Commit

Permalink
Add some unit tests for new code branches
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Dec 28, 2023
1 parent 76a4ff9 commit 11f92bf
Showing 1 changed file with 79 additions and 53 deletions.
132 changes: 79 additions & 53 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import binascii
import datetime
import ssl
from typing import Callable
from unittest import TestCase
from unittest.mock import patch

Expand Down Expand Up @@ -87,6 +88,13 @@ def test_pull_block_truncated(self):
pass


def corrupt_hello_version(data: bytes) -> bytes:
"""
Corrupt a ClientHello or ServerHello's protocol version.
"""
return data[:4] + b"\xff\xff" + data[6:]


def create_buffers():
return {
tls.Epoch.INITIAL: Buffer(capacity=4096),
Expand Down Expand Up @@ -145,6 +153,40 @@ def create_server(self, alpn_protocols=None, **kwargs):
self.assertEqual(server.state, State.SERVER_EXPECT_CLIENT_HELLO)
return server

def handshake_with_client_input_corruption(
self,
corrupt_client_input: Callable[[bytes], bytes],
expected_exception: Exception,
):
client = self.create_client()
server = self.create_server()

# Send client hello.
client_buf = create_buffers()
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
reset_buffers(client_buf)

# Handle client hello.
#
# send server hello, encrypted extensions, certificate, certificate verify,
# finished.
server_buf = create_buffers()
server.handle_message(server_input, server_buf)
self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED)
client_input = merge_buffers(server_buf)
reset_buffers(server_buf)

# Mess with compression method.
client_input = corrupt_client_input(client_input)

# Handle server hello, encrypted extensions, certificate, certificate verify,
# finished.
with self.assertRaises(expected_exception.__class__) as cm:
client.handle_message(client_input, client_buf)
self.assertEqual(str(cm.exception), str(expected_exception))

def test_client_unexpected_message(self):
client = self.create_client()

Expand Down Expand Up @@ -176,63 +218,35 @@ def test_client_unexpected_message(self):
with self.assertRaises(tls.AlertUnexpectedMessage):
client.handle_message(b"\x00\x00\x00\x00", create_buffers())

def test_client_bad_certificate_verify_data(self):
client = self.create_client()
server = self.create_server()

# Send client hello.
client_buf = create_buffers()
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
reset_buffers(client_buf)

# Handle client hello.
#
# send server hello, encrypted extensions, certificate, certificate verify,
# finished.
server_buf = create_buffers()
server.handle_message(server_input, server_buf)
self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED)
client_input = merge_buffers(server_buf)
reset_buffers(server_buf)
def test_client_bad_hello_compression_method(self):
self.handshake_with_client_input_corruption(
# Mess with compression method.
lambda x: x[:41] + b"\xff" + x[42:],
tls.AlertIllegalParameter(
"ServerHello has a compression method we did not advertise"
),
)

# Mess with certificate verify.
client_input = client_input[:-56] + bytes(4) + client_input[-52:]
def test_client_bad_hello_version(self):
self.handshake_with_client_input_corruption(
# Mess with supported version.
lambda x: x[:48] + b"\xff\xff" + x[50:],
tls.AlertIllegalParameter("ServerHello has a version we did not advertise"),
)

# Handle server hello, encrypted extensions, certificate, certificate verify,
# finished.
with self.assertRaises(tls.AlertDecryptError):
client.handle_message(client_input, client_buf)
def test_client_bad_certificate_verify_data(self):
self.handshake_with_client_input_corruption(
# Mess with certificate verify.
lambda x: x[:-56] + bytes(4) + x[-52:],
tls.AlertDecryptError(),
)

def test_client_bad_finished_verify_data(self):
client = self.create_client()
server = self.create_server()

# Send client hello.
client_buf = create_buffers()
client.handle_message(b"", client_buf)
self.assertEqual(client.state, State.CLIENT_EXPECT_SERVER_HELLO)
server_input = merge_buffers(client_buf)
reset_buffers(client_buf)

# Handle client hello.
#
# Send server hello, encrypted extensions, certificate, certificate verify,
# finished.
server_buf = create_buffers()
server.handle_message(server_input, server_buf)
self.assertEqual(server.state, State.SERVER_EXPECT_FINISHED)
client_input = merge_buffers(server_buf)
reset_buffers(server_buf)

# Mess with finished verify data.
client_input = client_input[:-4] + bytes(4)

# Handle server hello, encrypted extensions, certificate, certificate verify,
# finished.
with self.assertRaises(tls.AlertDecryptError):
client.handle_message(client_input, client_buf)
self.handshake_with_client_input_corruption(
# Mess with finished verify data.
lambda x: x[:-4] + bytes(4),
tls.AlertDecryptError(),
)

def test_server_unexpected_message(self):
server = self.create_server()
Expand Down Expand Up @@ -1035,6 +1049,12 @@ def test_pull_client_hello_with_sni(self):
push_client_hello(buf, hello)
self.assertEqual(buf.data, load("tls_client_hello_with_sni.bin"))

def test_pull_client_hello_with_unexpected_version(self):
buf = Buffer(data=corrupt_hello_version(load("tls_client_hello.bin")))
with self.assertRaises(tls.AlertDecodeError) as cm:
pull_client_hello(buf)
self.assertEqual(str(cm.exception), "ClientHello version is not 1.2")

def test_push_client_hello(self):
hello = ClientHello(
random=binascii.unhexlify(
Expand Down Expand Up @@ -1156,6 +1176,12 @@ def test_pull_server_hello_with_psk(self):
push_server_hello(buf, hello)
self.assertEqual(buf.data, load("tls_server_hello_with_psk.bin"))

def test_pull_server_hello_with_unexpected_version(self):
buf = Buffer(data=corrupt_hello_version(load("tls_server_hello.bin")))
with self.assertRaises(tls.AlertDecodeError) as cm:
pull_server_hello(buf)
self.assertEqual(str(cm.exception), "ServerHello version is not 1.2")

def test_pull_server_hello_with_unknown_extension(self):
buf = Buffer(data=load("tls_server_hello_with_unknown_extension.bin"))
hello = pull_server_hello(buf)
Expand Down

0 comments on commit 11f92bf

Please sign in to comment.