From 4a2d3df6078b547282ba51f7c44b0ffc54b8ed02 Mon Sep 17 00:00:00 2001 From: gstarovo Date: Wed, 17 Apr 2024 11:19:54 +0200 Subject: [PATCH] changes in point extension format --- .github/workflows/ci.yml | 8 +- .gitignore | 2 +- scripts/tls.py | 3 +- test | 0 tests/tlstest.py | 157 +++++++++++++++++++++++-- tlslite/handshakesettings.py | 16 ++- tlslite/keyexchange.py | 121 +++++++++++++------ tlslite/session.py | 7 +- tlslite/tlsconnection.py | 70 ++++++++--- unit_tests/test_tlslite_keyexchange.py | 10 +- 10 files changed, 320 insertions(+), 74 deletions(-) delete mode 100644 test diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6581b6fea..64d4eb2b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -433,7 +433,13 @@ jobs: COVERALLS_FLAG_NAME: ${{ matrix.name }} COVERALLS_PARALLEL: true COVERALLS_SERVICE_NAME: github - run: coveralls + PY_VERSION: ${{ matrix.python-version }} + run: | + if [[ $PY_VERSION == "2.6" ]]; then + COVERALLS_SKIP_SSL_VERIFY=1 coveralls + else + coveralls + fi - name: Publish coverage to Codeclimate if: ${{ contains(matrix.opt-deps, 'codeclimate') }} env: diff --git a/.gitignore b/.gitignore index 564337542..daedfe767 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ coverage.xml pylint_report.txt build/ docs/_build/ -htmlcov/ +htmlcov/ \ No newline at end of file diff --git a/scripts/tls.py b/scripts/tls.py index a3f27ebe1..83ac7a59e 100755 --- a/scripts/tls.py +++ b/scripts/tls.py @@ -367,6 +367,7 @@ def printGoodConnection(connection, seconds): print(" Extended Master Secret: {0}".format( connection.extendedMasterSecret)) print(" Session Resumed: {0}".format(connection.resumed)) + print(" Session used ec point format extension: {0}".format(connection.session.ec_point_format)) def printExporter(connection, expLabel, expLength): if expLabel is None: @@ -424,7 +425,7 @@ def clientCmd(argv): connection.handshakeClientCert(cert_chain, privateKey, settings=settings, serverName=address[0], alpn=alpn) stop = time_stamp() - print("Handshake success") + print("Handshake success") except TLSLocalAlert as a: if a.description == AlertDescription.user_canceled: print(str(a)) diff --git a/test b/test deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/tlstest.py b/tests/tlstest.py index 18a64b73e..0a42ca1c2 100755 --- a/tests/tlstest.py +++ b/tests/tlstest.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Authors: +# Authors: # Trevor Perrin # Kees Bos - Added tests for XML-RPC # Dimitris Moraitis - Anon ciphersuites @@ -44,11 +44,11 @@ from xmlrpc import client as xmlrpclib import ssl from tlslite import * -from tlslite.constants import KeyUpdateMessageType +from tlslite.constants import KeyUpdateMessageType, ECPointFormat try: from tack.structures.Tack import Tack - + except ImportError: pass @@ -56,10 +56,10 @@ def printUsage(s=None): if m2cryptoLoaded: crypto = "M2Crypto/OpenSSL" else: - crypto = "Python crypto" + crypto = "Python crypto" if s: print("ERROR: %s" % s) - print("""\ntls.py version %s (using %s) + print("""\ntls.py version %s (using %s) Commands: server HOST:PORT DIRECTORY @@ -67,7 +67,7 @@ def printUsage(s=None): client HOST:PORT DIRECTORY """ % (__version__, crypto)) sys.exit(-1) - + def testConnClient(conn): b1 = os.urandom(1) @@ -92,9 +92,9 @@ def testConnClient(conn): assert r1000 == b1000 def clientTestCmd(argv): - + address = argv[0] - dir = argv[1] + dir = argv[1] #Split address into hostname/port tuple address = address.split(":") @@ -235,7 +235,7 @@ def connect(): settings.minVersion = (3,0) settings.maxVersion = (3,0) connection.handshakeClientCert(settings=settings) - testConnClient(connection) + testConnClient(connection) assert(isinstance(connection.session.serverCertChain, X509CertChain)) connection.close() @@ -286,6 +286,72 @@ def connect(): test_no += 1 + print("Test {0} - client compressed/uncompressed - uncompressed, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.ec_point_format == ECPointFormat.uncompressed + connection.close() + + test_no += 1 + + print("Test {0} - client compressed - compressed, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_prime + connection.close() + + test_no += 1 + + print("Test {0} - client uncompressed - error, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.ec_point_formats = [ECPointFormat.uncompressed] + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + try: + connection.handshakeClientCert(settings=settings) + assert False + except TLSIllegalParameterException as e: + assert "No common EC point format" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no += 1 + + print("Test {0} - client comppressed char2 - error, TLSv1.2".format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.ec_point_formats = [ECPointFormat.ansiX962_compressed_char2] + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + try: + connection.handshakeClientCert(settings=settings) + assert False + except ValueError as e: + assert "Unknown EC point format provided: [2]" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no += 1 + print("Test {0} - mismatched ECDSA curve, TLSv1.2".format(test_no)) synchro.recv(1) connection = connect() @@ -2162,6 +2228,76 @@ def connect(): test_no += 1 + print("Test {0} - server uncompressed ec format - uncompressed, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + settings.ec_point_formats = [ECPointFormat.uncompressed] + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + testConnServer(connection) + assert connection.session.ec_point_format == ECPointFormat.uncompressed + connection.close() + + test_no += 1 + + print("Test {0} - server compressed ec format - compressed, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + testConnServer(connection) + assert connection.session.ec_point_format == ECPointFormat.ansiX962_compressed_prime + connection.close() + + test_no +=1 + + print("Test {0} - server compressed ec format - error, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.ec_point_formats = [ECPointFormat.ansiX962_compressed_prime] + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + try: + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + assert False + except TLSIllegalParameterException as e: + assert "No common EC point format" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no +=1 + + print("Test {0} - client compressed char2 - error, TLSv1.2".format(test_no)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 1) + settings.maxVersion = (3, 3) + settings.eccCurves = ["secp256r1", "secp384r1", "secp521r1", "x25519", "x448"] + try: + connection.handshakeServer(certChain=x509ecdsaChain, + privateKey=x509ecdsaKey, settings=settings) + assert False + except ValueError as e: + assert "Unknown EC point format provided: [2]" in str(e) + except TLSAbruptCloseError as e: + pass + connection.close() + + test_no +=1 + print("Test {0} - mismatched ECDSA curve, TLSv1.2".format(test_no)) synchro.send(b'R') connection = connect() @@ -3416,7 +3552,7 @@ def heartbeat_response_check(message): assert synchro.recv(1) == b'R' connection.close() - test_no += 1 + test_no +=1 print("Tests {0}-{1} - XMLRPXC server".format(test_no, test_no + 2)) @@ -3449,6 +3585,7 @@ def add(self, x, y): return x + y synchro.close() synchroSocket.close() + test_no += 2 print("Test succeeded") diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 38e560a2b..ae48b50a8 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -7,7 +7,7 @@ """Class for setting handshake parameters.""" -from .constants import CertificateType +from .constants import CertificateType, ECPointFormat from .utils import cryptomath from .utils import cipherfactory from .utils.compat import ecdsaAllCurves, int_types @@ -61,6 +61,8 @@ TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm", "aes128ccm_8", "aes256ccm", "aes256ccm_8"] PSK_MODES = ["psk_dhe_ke", "psk_ke"] +EC_POINT_FORMATS = [ECPointFormat.ansiX962_compressed_prime, + ECPointFormat.uncompressed] class Keypair(object): @@ -353,6 +355,10 @@ class HandshakeSettings(object): :vartype keyExchangeNames: list :ivar keyExchangeNames: Enabled key exchange types for the connection, influences selected cipher suites. + + :vartype ec_point_formats: list + :ivar ec_point_formats: Enabled point format extension for + elliptic curves. """ def _init_key_settings(self): @@ -396,6 +402,7 @@ def _init_misc_extensions(self): # resumed connections (as tickets are single-use in TLS 1.3 self.ticket_count = 2 self.record_size_limit = 2**14 + 1 # TLS 1.3 includes content type + self.ec_point_formats = list(EC_POINT_FORMATS) def __init__(self): """Initialise default values for settings.""" @@ -599,6 +606,12 @@ def _sanityCheckExtensions(other): not 64 <= other.record_size_limit <= 2**14 + 1: raise ValueError("record_size_limit cannot exceed 2**14+1 bytes") + bad_ec_ext = [i for i in other.ec_point_formats if + i not in EC_POINT_FORMATS] + if bad_ec_ext: + raise ValueError("Unknown EC point format provided: " + "{0}".format(bad_ec_ext)) + HandshakeSettings._sanityCheckEMSExtension(other) @staticmethod @@ -667,6 +680,7 @@ def _copy_extension_settings(self, other): other.sendFallbackSCSV = self.sendFallbackSCSV other.useEncryptThenMAC = self.useEncryptThenMAC other.usePaddingExtension = self.usePaddingExtension + other.ec_point_formats = self.ec_point_formats # session tickets other.padding_cb = self.padding_cb other.ticketKeys = self.ticketKeys diff --git a/tlslite/keyexchange.py b/tlslite/keyexchange.py index 2242aad3e..190427450 100644 --- a/tlslite/keyexchange.py +++ b/tlslite/keyexchange.py @@ -12,7 +12,7 @@ TLSDecodeError from .messages import ServerKeyExchange, ClientKeyExchange, CertificateVerify from .constants import SignatureAlgorithm, HashAlgorithm, CipherSuite, \ - ExtensionType, GroupName, ECCurveType, SignatureScheme + ExtensionType, GroupName, ECCurveType, SignatureScheme, ECPointFormat from .utils.ecc import getCurveByName, getPointByteSize from .utils.rsakey import RSAKey from .utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \ @@ -705,14 +705,17 @@ def makeServerKeyExchange(self, sigHash=None): kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) self.ecdhXs = kex.get_random_private_key() - if isinstance(self.ecdhXs, ecdsa.keys.SigningKey): - ecdhYs = bytearray( - self.ecdhXs.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - ecdhYs = kex.calc_public_value(self.ecdhXs) + ext_negotiated = ECPointFormat.uncompressed + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_negotiated = next((i for i in ext_c.formats \ + if i in ext_s.formats)) + except StopIteration: + raise TLSIllegalParameterException("No common EC point format") + + ecdhYs = kex.calc_public_value(self.ecdhXs, ext_negotiated) version = self.serverHello.server_version serverKeyExchange = ServerKeyExchange(self.cipherSuite, version) @@ -730,7 +733,16 @@ def processClientKeyExchange(self, clientKeyExchange): raise TLSDecodeError("No key share") kex = ECDHKeyExchange(self.group_id, self.serverHello.server_version) - return kex.calc_shared_key(self.ecdhXs, ecdhYc) + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + ext_supported = [ + ext for ext in ext_c.formats if ext in ext_s.formats + ] + if not ext_supported: + raise TLSIllegalParameterException("No common EC point format") + return kex.calc_shared_key(self.ecdhXs, ecdhYc, ext_supported) def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): """Process the server key exchange, return premaster secret""" @@ -748,15 +760,19 @@ def processServerKeyExchange(self, srvPublicKey, serverKeyExchange): kex = ECDHKeyExchange(serverKeyExchange.named_curve, self.serverHello.server_version) ecdhXc = kex.get_random_private_key() - if isinstance(ecdhXc, ecdsa.keys.SigningKey): - self.ecdhYc = bytearray( - ecdhXc.get_verifying_key().to_string( - encoding = 'uncompressed' - ) - ) - else: - self.ecdhYc = kex.calc_public_value(ecdhXc) - return kex.calc_shared_key(ecdhXc, ecdh_Ys) + ext_negotiated = ECPointFormat.uncompressed + ext_supported = [ECPointFormat.uncompressed] + ext_c = self.clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = self.serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_supported = [i for i in ext_c.formats if i in ext_s.formats] + ext_negotiated = ext_supported[0] + except IndexError: + raise TLSIllegalParameterException("No common EC point format") + + self.ecdhYc = kex.calc_public_value(ecdhXc, ext_negotiated) + return kex.calc_shared_key(ecdhXc, ecdh_Ys, ext_supported) def makeClientKeyExchange(self): """Make client key exchange for ECDHE""" @@ -903,11 +919,11 @@ def get_random_private_key(self): """ raise NotImplementedError("Abstract class") - def calc_public_value(self, private): + def calc_public_value(self, private, frm_negotiated=None): """Calculate the public value from the provided private value.""" raise NotImplementedError("Abstract class") - def calc_shared_key(self, private, peer_share): + def calc_shared_key(self, private, peer_share, frm_supported=None): """Calcualte the shared key given our private and remote share value""" raise NotImplementedError("Abstract class") @@ -940,9 +956,10 @@ def get_random_private_key(self): needed_bytes = divceil(paramStrength(self.prime) * 2, 8) return bytesToNumber(getRandomBytes(needed_bytes)) - def calc_public_value(self, private): + def calc_public_value(self, private, frm_negotiated=None): """ Calculate the public value for given private value. + Frm_negotiated added for API compatibility, not needed for FFDH. :rtype: int """ @@ -964,8 +981,11 @@ def _normalise_peer_share(self, peer_share): "Key share does not match FFDH prime") return bytesToNumber(peer_share) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key.""" + def calc_shared_key(self, private, peer_share, frm_supported=None): + """Calculate the shared key. + Frm_supported added for API compatibility, not needed for FFDH. + + :rtype: bytearray""" peer_share = self._normalise_peer_share(peer_share) # First half of RFC 2631, Section 2.1.5. Validate the client's public # key. @@ -984,7 +1004,6 @@ def calc_shared_key(self, private, peer_share): class ECDHKeyExchange(RawDHKeyExchange): """Implementation of the Elliptic Curve Diffie-Hellman key exchange.""" - _x_groups = set((GroupName.x25519, GroupName.x448)) @staticmethod @@ -1021,20 +1040,50 @@ def _get_fun_gen_size(self): else: return x448, bytearray(X448_G), X448_ORDER_SIZE - def calc_public_value(self, private): + @staticmethod + def _get_point_format(ext): + """Get extension name from the numeric value.""" + transform = {ECPointFormat.uncompressed: 'uncompressed', + ECPointFormat.ansiX962_compressed_char2: 'compressed', + ECPointFormat.ansiX962_compressed_prime: 'compressed'} + return transform[ext] + + def calc_public_value(self, + private, + frm_negotiated=ECPointFormat.uncompressed): """Calculate public value for given private key.""" + point_fmt = self._get_point_format(frm_negotiated) if isinstance(private, ecdsa.keys.SigningKey): - return private.verifying_key.to_string('uncompressed') + return private.verifying_key.to_string(point_fmt) if self.group in self._x_groups: fun, generator, _ = self._get_fun_gen_size() return fun(private, generator) - else: - curve = getCurveByName(GroupName.toStr(self.group)) - point = curve.generator * private - return bytearray(point.to_bytes('uncompressed')) - def calc_shared_key(self, private, peer_share): - """Calculate the shared key,""" + curve = getCurveByName(GroupName.toStr(self.group)) + point = curve.generator * private + return bytearray(point.to_bytes(encoding=point_fmt)) + + def calc_shared_key(self, private, peer_share, + frm_supported=set([ECPointFormat.uncompressed])): + """Calculate the shared key. + + :type private: bytearray | SigningKey + :param private: private value + + :type peer_share: bytearray + :param peer_share: public value + + :type frm_supported: set(ECPointFormat) + :param frm_supported: acceptable point formats for public value + + :rtype: bytearray + :returns: shared key + + :raises TLSIllegalParameterException + when the paramentrs for point are invalid + """ + valid_encodings = set([self._get_point_format(i) \ + for i in frm_supported]) if self.group in self._x_groups: fun, _, size = self._get_fun_gen_size() @@ -1049,7 +1098,8 @@ def calc_shared_key(self, private, peer_share): curve = getCurveByName(GroupName.toRepr(self.group)) try: abstractPoint = ecdsa.ellipticcurve.AbstractPoint() - point = abstractPoint.from_bytes(curve.curve, peer_share) + point = abstractPoint.from_bytes(curve.curve, peer_share, + valid_encodings=valid_encodings) ecdhYc = ecdsa.ellipticcurve.Point( curve.curve, point[0], point[1]) @@ -1057,7 +1107,8 @@ def calc_shared_key(self, private, peer_share): raise TLSIllegalParameterException("Invalid ECC point") if isinstance(private, ecdsa.keys.SigningKey): ecdh = ecdsa.ecdh.ECDH(curve=curve, private_key=private) - ecdh.load_received_public_key_bytes(peer_share) + ecdh.load_received_public_key_bytes(peer_share, + valid_encodings=valid_encodings) return bytearray(ecdh.generate_sharedsecret_bytes()) S = ecdhYc * private diff --git a/tlslite/session.py b/tlslite/session.py index 0e310b716..372f3168c 100644 --- a/tlslite/session.py +++ b/tlslite/session.py @@ -72,6 +72,9 @@ class Session(object): :vartype tls_1_0_tickets: list :ivar tls_1_0_tickets: list of TLS 1.2 and earlier session tickets received from the server + + :vartype ec_point_format: int + :ivar ec_point_format: used EC point format for the ECDH key exchange; """ def __init__(self): @@ -94,6 +97,7 @@ def __init__(self): self.resumptionMasterSecret = bytearray(0) self.tickets = None self.tls_1_0_tickets = None + self.ec_point_format = None def create(self, masterSecret, sessionID, cipherSuite, srpUsername, clientCertChain, serverCertChain, @@ -102,7 +106,7 @@ def create(self, masterSecret, sessionID, cipherSuite, appProto=bytearray(0), cl_app_secret=bytearray(0), sr_app_secret=bytearray(0), exporterMasterSecret=bytearray(0), resumptionMasterSecret=bytearray(0), tickets=None, - tls_1_0_tickets=None): + tls_1_0_tickets=None, ec_point_format=None): self.masterSecret = masterSecret self.sessionID = sessionID self.cipherSuite = cipherSuite @@ -123,6 +127,7 @@ def create(self, masterSecret, sessionID, cipherSuite, # NOTE we need a reference copy not a copy of object here! self.tickets = tickets self.tls_1_0_tickets = tls_1_0_tickets + self.ec_point_format = ec_point_format def _clone(self): other = Session() diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 582097a71..229b029b6 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -656,6 +656,21 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams, if alpnExt: alpnProto = alpnExt.protocol_names[0] + ext_ec_point = ECPointFormat.uncompressed + if self.version < (3, 4): + ext_c = clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_ec_point = next((i for i in ext_c.formats \ + if i in ext_s.formats)) + + except StopIteration as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result + # Create the session object which is used for resumptions self.session = Session() self.session.create(masterSecret, serverHello.session_id, cipherSuite, @@ -667,7 +682,8 @@ def _handshakeClientAsyncHelper(self, srpParams, certParams, anonParams, appProto=alpnProto, # NOTE it must be a reference not a copy tickets=self.tickets, - tls_1_0_tickets=self.tls_1_0_tickets) + tls_1_0_tickets=self.tls_1_0_tickets, + ec_point_format=ext_ec_point) self._handshakeDone(resumed=False) self._serverRandom = serverHello.random self._clientRandom = clientHello.random @@ -745,7 +761,6 @@ def _clientSendClientHello(self, settings, session, srpUsername, for group_name in settings.keyShares: group_id = getattr(GroupName, group_name) key_share = self._genKeyShareEntry(group_id, (3, 4)) - shares.append(key_share) # if TLS 1.3 is enabled, key_share must always be sent # (unless only static PSK is used) @@ -762,8 +777,9 @@ def _clientSendClientHello(self, settings, session, srpUsername, if next((cipher for cipher in cipherSuites \ if cipher in CipherSuite.ecdhAllSuites), None) is not None: groups.extend(self._curveNamesToList(settings)) - extensions.append(ECPointFormatsExtension().\ - create([ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension().\ + create(settings.ec_point_formats)) # Advertise FFDHE groups if we have DHE ciphers if next((cipher for cipher in cipherSuites if cipher in CipherSuite.dhAllSuites), None) is not None: @@ -838,7 +854,7 @@ def _clientSendClientHello(self, settings, session, srpUsername, session_id, wireCipherSuites, certificateTypes, srpUsername, - reqTack, nextProtos is not None, + reqTack, nextProtos is not None, serverName, extensions=extensions) @@ -915,6 +931,7 @@ def _clientGetServerHello(self, settings, session, clientHello): hello_retry = None ext = result.getExtension(ExtensionType.supported_versions) + if result.random == TLS_1_3_HRR and ext and ext.version > (3, 3): self.version = ext.version hello_retry = result @@ -974,7 +991,6 @@ def _clientGetServerHello(self, settings, session, clientHello): "did sent the key share " "for"): yield result - key_share = self._genKeyShareEntry(group_id, (3, 4)) # old key shares need to be removed @@ -1212,7 +1228,6 @@ def _clientTLS13Handshake(self, settings, session, clientHello, raise TLSIllegalParameterException("Server selected not " "advertised group.") kex = self._getKEX(sr_kex.group, self.version) - shared_sec = kex.calc_shared_key(cl_kex.private, sr_kex.key_exchange) else: @@ -1855,8 +1870,8 @@ def _clientFinished(self, premasterSecret, clientRandom, serverRandom, cipherSuite, clientRandom, serverRandom) - self._calcPendingStates(cipherSuite, masterSecret, - clientRandom, serverRandom, + self._calcPendingStates(cipherSuite, masterSecret, + clientRandom, serverRandom, cipherImplementations) #Exchange ChangeCipherSpec and Finished messages @@ -1989,7 +2004,7 @@ def _clientGetKeyFromChain(self, certificate, settings, tack_ext=None): def handshakeServer(self, verifierDB=None, certChain=None, privateKey=None, reqCert=False, sessionCache=None, settings=None, checker=None, - reqCAs = None, + reqCAs = None, tacks=None, activationFlags=0, nextProtos=None, anon=False, alpn=None, sni=None): """Perform a handshake in the role of server. @@ -2090,7 +2105,7 @@ def handshakeServer(self, verifierDB=None, def handshakeServerAsync(self, verifierDB=None, certChain=None, privateKey=None, reqCert=False, sessionCache=None, settings=None, checker=None, - reqCAs=None, + reqCAs=None, tacks=None, activationFlags=0, nextProtos=None, anon=False, alpn=None, sni=None ): @@ -2108,9 +2123,9 @@ def handshakeServerAsync(self, verifierDB=None, handshaker = self._handshakeServerAsyncHelper(\ verifierDB=verifierDB, cert_chain=certChain, privateKey=privateKey, reqCert=reqCert, - sessionCache=sessionCache, settings=settings, - reqCAs=reqCAs, - tacks=tacks, activationFlags=activationFlags, + sessionCache=sessionCache, settings=settings, + reqCAs=reqCAs, + tacks=tacks, activationFlags=activationFlags, nextProtos=nextProtos, anon=anon, alpn=alpn, sni=sni) for result in self._handshakeWrapperAsync(handshaker, checker): yield result @@ -2270,8 +2285,9 @@ def _handshakeServerAsyncHelper(self, verifierDB, if clientHello.getExtension(ExtensionType.ec_point_formats): # even though the selected cipher may not use ECC, client may want # to send a CA certificate with ECDSA... - extensions.append(ECPointFormatsExtension().create( - [ECPointFormat.uncompressed])) + if settings.ec_point_formats: + extensions.append(ECPointFormatsExtension(). + create(settings.ec_point_formats)) # if client sent Heartbeat extension if clientHello.getExtension(ExtensionType.heartbeat): @@ -2413,6 +2429,21 @@ def _handshakeServerAsyncHelper(self, verifierDB, if clientHello.server_name: serverName = clientHello.server_name.decode("utf-8") + ext_ec_point = ECPointFormat.uncompressed + if version < (3, 4): + ext_c = clientHello.getExtension(ExtensionType.ec_point_formats) + ext_s = serverHello.getExtension(ExtensionType.ec_point_formats) + if ext_c and ext_s: + try: + ext_ec_point = next((i for i in ext_c.formats \ + if i in ext_s.formats)) + + except StopIteration as alert: + for result in self._sendError( + AlertDescription.illegal_parameter, + str(alert)): + yield result + # We'll update the session master secret once it is calculated # in _serverFinished self.session.create(b"", serverHello.session_id, cipherSuite, @@ -2424,7 +2455,8 @@ def _handshakeServerAsyncHelper(self, verifierDB, extendedMasterSecret=self.extendedMasterSecret, appProto=selectedALPN, # NOTE it must be a reference, not a copy! - tickets=self.tickets) + tickets=self.tickets, + ec_point_format=ext_ec_point) # Exchange Finished messages for result in self._serverFinished(premasterSecret, @@ -2709,8 +2741,8 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite, (psk is None and privateKey): self.ecdhCurve = selected_group kex = self._getKEX(selected_group, version) - key_share = self._genKeyShareEntry(selected_group, version) - + key_share = self._genKeyShareEntry(selected_group, + version) try: shared_sec = kex.calc_shared_key(key_share.private, cl_key_share.key_exchange) diff --git a/unit_tests/test_tlslite_keyexchange.py b/unit_tests/test_tlslite_keyexchange.py index cfc02aa4f..a3215f383 100644 --- a/unit_tests/test_tlslite_keyexchange.py +++ b/unit_tests/test_tlslite_keyexchange.py @@ -18,16 +18,16 @@ from tlslite.handshakesettings import HandshakeSettings from tlslite.messages import ServerHello, ClientHello, ServerKeyExchange,\ CertificateRequest, ClientKeyExchange -from tlslite.constants import CipherSuite, CertificateType, AlertDescription, \ +from tlslite.constants import CipherSuite, CertificateType, \ HashAlgorithm, SignatureAlgorithm, GroupName, ECCurveType, \ SignatureScheme -from tlslite.errors import TLSLocalAlert, TLSIllegalParameterException, \ +from tlslite.errors import TLSIllegalParameterException, \ TLSDecryptionFailed, TLSInsufficientSecurity, TLSUnknownPSKIdentity, \ TLSInternalError, TLSDecodeError from tlslite.x509 import X509 from tlslite.x509certchain import X509CertChain from tlslite.utils.keyfactory import parsePEMKey -from tlslite.utils.codec import Parser, Writer +from tlslite.utils.codec import Parser from tlslite.utils.cryptomath import bytesToNumber, getRandomBytes, powMod, \ numberToByteArray, isPrime, numBytes from tlslite.mathtls import makeX, makeU, makeK, goodGroupParameters @@ -2523,13 +2523,13 @@ def test_calc_public_value(self): kex = RawDHKeyExchange(None, None) with self.assertRaises(NotImplementedError): - kex.calc_public_value(None) + kex.calc_public_value(None, None) def test_calc_shared_value(self): kex = RawDHKeyExchange(None, None) with self.assertRaises(NotImplementedError): - kex.calc_shared_key(None, None) + kex.calc_shared_key(None, None, None) class TestFFDHKeyExchange(unittest.TestCase):