Skip to content

Commit

Permalink
support for ML-KEM hybrid key exchange groups
Browse files Browse the repository at this point in the history
  • Loading branch information
tomato42 committed Oct 8, 2024
1 parent 0156727 commit 2617bdf
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 25 deletions.
29 changes: 24 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,26 @@ jobs:
os: ubuntu-latest
python-version: '3.11'
opt-deps: ['brotli', 'zstd']
- name: py3.12with brotli and zstandard
- name: py3.12 with brotli and zstandard
os: ubuntu-latest
python-version: '3.12'
opt-deps: ['brotli', 'zstd']
- name: py3.9 with kyber-py
os: ubuntu-latest
python-version: "3.9"
opt-deps: ["kyber_py"]
- name: py3.10 with kyber-py
os: ubuntu-latest
python-version: "3.10"
opt-deps: ["kyber_py"]
- name: py3.11 with kyber-py
os: ubuntu-latest
python-version: "3.11"
opt-deps: ["kyber_py"]
- name: py3.12 with kyber-py
os: ubuntu-latest
python-version: "3.12"
opt-deps: ["kyber_py"]
# finally test with multiple dependencies installed at the same time
- name: py2.7 with m2crypto, pycrypto, gmpy, gmpy2, and brotli
os: ubuntu-20.04
Expand All @@ -204,22 +220,22 @@ jobs:
- name: py3.9 with m2crypto, gmpy, gmpy2, brotli, and zstandard
os: ubuntu-latest
python-version: 3.9
opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd']
opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd', 'kyber_py']
- name: py3.10 with m2crypto, gmpy, gmpy2, brotli, and zstandard
os: ubuntu-latest
python-version: '3.10'
opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd']
opt-deps: ['m2crypto', 'gmpy', 'gmpy2', 'brotli', 'zstd', 'kyber_py']
- name: py3.11 with m2crypto, gmpy, gmpy2, brotli, and zstandard
os: ubuntu-latest
python-version: '3.11'
# gmpy doesn't build with 3.11
opt-deps: ['m2crypto', 'gmpy2', 'brotli', 'zstd']
opt-deps: ['m2crypto', 'gmpy2', 'brotli', 'zstd', 'kyber_py']
- name: py3.12 with m2crypto, gmpy, gmpy2, brotli, and zstandard
os: ubuntu-latest
python-version: '3.12'
# gmpy doesn't build with 3.12
# coverage to codeclimate can be submitted just once
opt-deps: ['m2crypto', 'gmpy2', 'codeclimate', 'brotli', 'zstd']
opt-deps: ['m2crypto', 'gmpy2', 'codeclimate', 'brotli', 'zstd', 'kyber_py']
steps:
- uses: actions/checkout@v2
if: ${{ !matrix.container }}
Expand Down Expand Up @@ -346,6 +362,9 @@ jobs:
- name: Install zstandard for py3.8 and after
if: ${{ contains(matrix.opt-deps, 'zstd') }}
run: pip install zstandard
- name: Install kyber_py
if: ${{ contains(matrix.opt-deps, 'kyber_py') }}
run: pip install "https://github.com/GiacomoPope/kyber-py/archive/b187189a514b3327578928c1d4c901d34592678e.zip"
- name: Install build dependencies (2.6)
if: ${{ matrix.python-version == '2.6' }}
run: |
Expand Down
8 changes: 7 additions & 1 deletion tlslite/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,13 @@ class GroupName(TLSEnum):
brainpoolP512r1tls13 = 33
allEC.extend(list(range(31, 34)))

all = allEC + allFF
# draft-kwiatkowski-tls-ecdhe-mlkem
secp256r1mlkem768 = 0x11EB
x25519mlkem768 = 0x11EC
secp384r1mlkem1024 = 0x11ED
allKEM = [0x11EB, 0x11EC, 0x11ED]

all = allEC + allFF + allKEM

@classmethod
def toRepr(cls, value, blacklist=None):
Expand Down
19 changes: 14 additions & 5 deletions tlslite/handshakesettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .constants import CertificateType
from .utils import cryptomath
from .utils import cipherfactory
from .utils.compat import ecdsaAllCurves, int_types
from .utils.compat import ecdsaAllCurves, int_types, ML_KEM_AVAILABLE
from .utils.compression import compression_algo_impls

CIPHER_NAMES = ["chacha20-poly1305",
Expand All @@ -34,10 +34,14 @@
ALL_RSA_SIGNATURE_HASHES = RSA_SIGNATURE_HASHES + ["md5"]
SIGNATURE_SCHEMES = ["Ed25519", "Ed448"]
RSA_SCHEMES = ["pss", "pkcs1"]
CURVE_NAMES = []
if ML_KEM_AVAILABLE:
CURVE_NAMES += ["secp256r1mlkem768", "x25519mlkem768",
"secp384r1mlkem1024"]
# while secp521r1 is the most secure, it's also much slower than the others
# so place it as the last one
CURVE_NAMES = ["x25519", "x448", "secp384r1", "secp256r1",
"secp521r1"]
CURVE_NAMES += ["x25519", "x448", "secp384r1", "secp256r1",
"secp521r1"]
ALL_CURVE_NAMES = CURVE_NAMES + ["secp256k1", "brainpoolP512r1",
"brainpoolP384r1", "brainpoolP256r1"]
if ecdsaAllCurves:
Expand All @@ -57,7 +61,8 @@
TLS13_PERMITTED_GROUPS = ["secp256r1", "secp384r1", "secp521r1",
"x25519", "x448", "ffdhe2048",
"ffdhe3072", "ffdhe4096", "ffdhe6144",
"ffdhe8192"]
"ffdhe8192", "secp256r1mlkem768", "x25519mlkem768",
"secp384r1mlkem1024"]
KNOWN_VERSIONS = ((3, 0), (3, 1), (3, 2), (3, 3), (3, 4))
TICKET_CIPHERS = ["chacha20-poly1305", "aes256gcm", "aes128gcm", "aes128ccm",
"aes128ccm_8", "aes256ccm", "aes256ccm_8"]
Expand Down Expand Up @@ -395,7 +400,11 @@ def _init_key_settings(self):
self.dhParams = None
self.dhGroups = list(ALL_DH_GROUP_NAMES)
self.defaultCurve = "secp256r1"
self.keyShares = ["secp256r1", "x25519"]
if ML_KEM_AVAILABLE:
self.keyShares = ["x25519mlkem768"]
else:
self.keyShares = []
self.keyShares += ["secp256r1", "x25519"]
self.padding_cb = None
self.use_heartbeat_extension = True
self.heartbeat_response_callback = None
Expand Down
174 changes: 173 additions & 1 deletion tlslite/keyexchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
from .utils import tlshashlib as hashlib
from .utils.x25519 import x25519, x448, X25519_G, X448_G, X25519_ORDER_SIZE, \
X448_ORDER_SIZE
from .utils.compat import int_types
from .utils.compat import int_types, ML_KEM_AVAILABLE
from .utils.codec import DecodeError

if ML_KEM_AVAILABLE:
from kyber_py.ml_kem import ML_KEM_768, ML_KEM_1024


class KeyExchange(object):
"""
Common API for calculating Premaster secret
Expand Down Expand Up @@ -1062,3 +1066,171 @@ def calc_shared_key(self, private, peer_share):
S = ecdhYc * private

return numberToByteArray(S.x(), getPointByteSize(ecdhYc))


class KEMKeyExchange(object):
"""
Implementation of the Hybrid KEM key exchange groups.
Caution, KEMs are not symmetric! While they client calls the
same get_random_private_key(), calc_public_value(), and calc_shared_key()
as in FFDH or ECDH, the server calls just the encapsulate_key() method.
"""

def __init__(self, group, version):
if not ML_KEM_AVAILABLE:
raise TLSInternalError("kyber-py library not installed!")
self.group = group
assert version == (3, 4)
del version

if self.group not in GroupName.allKEM:
raise TLSInternalError("called with wrong group")

if self.group == GroupName.secp256r1mlkem768:
self._classic_group = GroupName.secp256r1
elif self.group == GroupName.x25519mlkem768:
self._classic_group = GroupName.x25519
else:
assert self.group == GroupName.secp384r1mlkem1024
self._classic_group = GroupName.secp384r1

def get_random_private_key(self):
"""
Generates a random value to be used as the private key in KEM.
To be used only to generate the KeyShare in ClientHello.
"""

if self.group not in GroupName.allKEM:
raise TLSInternalError("called with wrong group")
if self.group in (GroupName.secp256r1mlkem768,
GroupName.x25519mlkem768):
pqc_pub_key, pqc_priv_key = ML_KEM_768.keygen()
else:
pqc_pub_key, pqc_priv_key = ML_KEM_1024.keygen()

classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
classic_key = classic_kex.get_random_private_key()

return ((pqc_pub_key, pqc_priv_key), classic_key)

def calc_public_value(self, private):
"""
Extract public values for the private key.
To be used only to generate the KeyShare in ClientHello.
"""
classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))

classic_pub_key_share = classic_kex.calc_public_value(private[1])

if self.group == GroupName.x25519mlkem768:
return private[0][0] + classic_pub_key_share
return classic_pub_key_share + private[0][0]

@staticmethod
def _split_key_shares(public, pqc_first, pqc_key_len, classic_key_len):
if len(public) != classic_key_len + pqc_key_len:
raise TLSIllegalParameterException(
"Invalid key size for the selected group. "
"Expected: {0}, received: {1}".format(
classic_key_len + pqc_key_len,
len(public)))

if pqc_first:
pqc_key = public[:pqc_key_len]
classic_key_share = bytearray(public[pqc_key_len:])
else:
classic_key_share = bytearray(public[:classic_key_len])
pqc_key = public[classic_key_len:]

return pqc_key, classic_key_share

def _group_to_params(self):
"""Returns a tuple:
classic_key_len, pqc_ek_key_len, pqc_ciphertext_len, pqc_first, ML_KEM
"""
if self.group == GroupName.secp256r1mlkem768:
classic_key_len = 65
pqc_key_len = 1184
pqc_ciphertext_len = 1088
pqc_first = False
ml_kem = ML_KEM_768
elif self.group == GroupName.x25519mlkem768:
classic_key_len = 32
pqc_key_len = 1184
pqc_ciphertext_len = 1088
pqc_first = True
ml_kem = ML_KEM_768
else:
assert self.group == GroupName.secp384r1mlkem1024
classic_key_len = 97
pqc_key_len = 1568
pqc_ciphertext_len = 1568
pqc_first = False
ml_kem = ML_KEM_1024

return classic_key_len, pqc_key_len, pqc_ciphertext_len, pqc_first, \
ml_kem

def encapsulate_key(self, public):
"""
Generate a random secret, encapsulate it given the public key,
and return both the random secret and encapsulation of it.
To be used for generation of KeyShare in ServerHello.
"""
classic_key_len, pqc_key_len, _, pqc_first, ml_kem = \
self._group_to_params()

pqc_key, classic_key_share = self._split_key_shares(
public, pqc_first, pqc_key_len, classic_key_len)

classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
classic_key = classic_kex.get_random_private_key()
classic_my_key_share = classic_kex.calc_public_value(classic_key)
classic_shared_secret = classic_kex.calc_shared_key(
classic_key, classic_key_share)

try:
pqc_shared_secret, pqc_encaps = ml_kem.encaps(pqc_key)
except ValueError:
raise TLSIllegalParameterException(
"Invalid PQC key from peer")

if pqc_first:
shared_secret = pqc_shared_secret + classic_shared_secret
key_encapsulation = pqc_encaps + classic_my_key_share
else:
shared_secret = classic_shared_secret + pqc_shared_secret
key_encapsulation = classic_my_key_share + pqc_encaps

return shared_secret, key_encapsulation

def calc_shared_key(self, private, key_encaps):
"""
Decapsulate the key share received from server.
"""
classic_key_len, _, pqc_key_len, pqc_first, ml_kem = \
self._group_to_params()

pqc_key, classic_key_share = self._split_key_shares(
key_encaps, pqc_first, pqc_key_len, classic_key_len)

classic_kex = ECDHKeyExchange(self._classic_group, (3, 4))
classic_shared_secret = classic_kex.calc_shared_key(
private[1], classic_key_share)

try:
pqc_shared_secret = ml_kem.decaps(private[0][1], pqc_key)
except ValueError:
raise TLSIllegalParameterException(
"Error in KEM decapsulation")

if pqc_first:
shared_secret = pqc_shared_secret + classic_shared_secret
else:
shared_secret = classic_shared_secret + pqc_shared_secret

return shared_secret
42 changes: 31 additions & 11 deletions tlslite/tlsconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .utils.deprecations import deprecated_params
from .keyexchange import KeyExchange, RSAKeyExchange, DHE_RSAKeyExchange, \
ECDHE_RSAKeyExchange, SRPKeyExchange, ADHKeyExchange, \
AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange
AECDHKeyExchange, FFDHKeyExchange, ECDHKeyExchange, KEMKeyExchange
from .handshakehelpers import HandshakeHelpers
from .utils.cipherfactory import createAESCCM, createAESCCM_8, \
createAESGCM, createCHACHA20
Expand Down Expand Up @@ -1196,6 +1196,8 @@ def _clientGetServerHello(self, settings, session, clientHello):
@staticmethod
def _getKEX(group, version):
"""Get object for performing key exchange."""
if group in GroupName.allKEM:
return KEMKeyExchange(group, version)
if group in GroupName.allFF:
return FFDHKeyExchange(group, version)
return ECDHKeyExchange(group, version)
Expand All @@ -1209,6 +1211,15 @@ def _genKeyShareEntry(cls, group, version):
share = kex.calc_public_value(private)
return KeyShareEntry().create(group, share, private)

@classmethod
def _KEMEncaps(cls, group, public):
"""Generate the server's KeyShareEntry object with encapsulated secret.
"""
kex = cls._getKEX(group, (3, 4))
shared_sec, key_share_value = kex.encapsulate_key(public)
key_share = KeyShareEntry().create(group, key_share_value, None)
return shared_sec, key_share

@staticmethod
def _getPRFParams(cipher_suite):
"""Return name of hash used for PRF and the hash output size."""
Expand Down Expand Up @@ -2803,16 +2814,21 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite,
(psk is None and privateKey):
self.ecdhCurve = selected_group
kex = self._getKEX(selected_group, version)
key_share = self._genKeyShareEntry(selected_group, version)
if selected_group in GroupName.allKEM:
shared_sec, key_share = self._KEMEncaps(
selected_group,
cl_key_share.key_exchange)
else:
key_share = self._genKeyShareEntry(selected_group, version)

try:
shared_sec = kex.calc_shared_key(key_share.private,
cl_key_share.key_exchange)
except TLSIllegalParameterException as alert:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(alert)):
yield result
try:
shared_sec = kex.calc_shared_key(key_share.private,
cl_key_share.key_exchange)
except TLSIllegalParameterException as alert:
for result in self._sendError(
AlertDescription.illegal_parameter,
str(alert)):
yield result

sh_extensions.append(ServerKeyShareExtension().create(key_share))
elif (psk is not None and
Expand Down Expand Up @@ -4915,7 +4931,11 @@ def _sigHashesToList(settings, privateKey=None, certList=None,
@staticmethod
def _curveNamesToList(settings):
"""Convert list of acceptable curves to array identifiers"""
return [getattr(GroupName, val) for val in settings.eccCurves]
ret = [getattr(GroupName, val) for val in settings.eccCurves]
if settings.maxVersion < (3, 4) and (3, 4) not in settings.versions:
# if we don't support TLS 1.3, filter out KEMs
ret = [i for i in ret if i not in GroupName.allKEM]
return ret

@staticmethod
def _groupNamesToList(settings):
Expand Down
Loading

0 comments on commit 2617bdf

Please sign in to comment.