Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new API to cleanup OpenSSL threads. #1959

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/common/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,12 @@ OQS_API void OQS_init(void) {
#endif
}

OQS_API void OQS_thread_stop(void) {
#if defined(OQS_USE_OPENSSL)
oqs_thread_stop();
#endif
}

OQS_API const char *OQS_version(void) {
return OQS_VERSION_TEXT;
}
Expand Down
101 changes: 57 additions & 44 deletions src/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
* SPDX-License-Identifier: MIT
*/


#ifndef OQS_COMMON_H
#define OQS_COMMON_H

#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <stdlib.h>

#include <oqs/oqsconfig.h>

Expand All @@ -27,14 +26,15 @@ extern "C" {
* using OpenSSL functions when OQS_USE_OPENSSL is defined, and
* standard C library functions otherwise.
*/
#if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && defined(OPENSSL_VERSION_NUMBER)
#if (defined(OQS_USE_OPENSSL) || defined(OQS_DLOPEN_OPENSSL)) && \
defined(OPENSSL_VERSION_NUMBER)
#include <openssl/crypto.h>

/**
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_malloc(size) OPENSSL_malloc(size)

/**
Expand All @@ -43,7 +43,8 @@ extern "C" {
* @param element_size The size of each element in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_calloc(num_elements, element_size) OPENSSL_zalloc((num_elements) * (element_size))
#define OQS_MEM_calloc(num_elements, element_size) \
OPENSSL_zalloc((num_elements) * (element_size))
/**
* Duplicates a string.
* @param str The string to be duplicated.
Expand All @@ -52,10 +53,10 @@ extern "C" {
#define OQS_MEM_strdup(str) OPENSSL_strdup(str)
#else
/**
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
* Allocates memory of a given size.
* @param size The size of the memory to be allocated in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_malloc(size) malloc(size) // IGNORE memory-check

/**
Expand All @@ -64,7 +65,8 @@ extern "C" {
* @param element_size The size of each element in bytes.
* @return A pointer to the allocated memory.
*/
#define OQS_MEM_calloc(num_elements, element_size) calloc(num_elements, element_size) // IGNORE memory-check
#define OQS_MEM_calloc(num_elements, element_size) \
calloc(num_elements, element_size) // IGNORE memory-check
/**
* Duplicates a string.
* @param str The string to be duplicated.
Expand All @@ -77,13 +79,14 @@ extern "C" {
* Macro for terminating the program if x is
* a null pointer.
*/
#define OQS_EXIT_IF_NULLPTR(x, loc) \
do { \
if ( (x) == (void*)0 ) { \
fprintf(stderr, "Unexpected NULL returned from %s API. Exiting.\n", loc); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define OQS_EXIT_IF_NULLPTR(x, loc) \
do { \
if ((x) == (void *)0) { \
fprintf(stderr, "Unexpected NULL returned from %s API. Exiting.\n", \
loc); \
exit(EXIT_FAILURE); \
} \
} while (0)

/**
* This macro is intended to replace those assert()s
Expand All @@ -98,22 +101,24 @@ extern "C" {
*/
#ifdef OQS_USE_OPENSSL
#ifdef OPENSSL_NO_STDIO
#define OQS_OPENSSL_GUARD(x) \
do { \
if( 1 != (x) ) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", x); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define OQS_OPENSSL_GUARD(x) \
do { \
if (1 != (x)) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", \
x); \
exit(EXIT_FAILURE); \
} \
} while (0)
#else // OPENSSL_NO_STDIO
#define OQS_OPENSSL_GUARD(x) \
do { \
if( 1 != (x) ) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", x); \
OSSL_FUNC(ERR_print_errors_fp)(stderr); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define OQS_OPENSSL_GUARD(x) \
do { \
if (1 != (x)) { \
fprintf(stderr, "Error return value from OpenSSL API: %d. Exiting.\n", \
x); \
OSSL_FUNC(ERR_print_errors_fp)(stderr); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif // OPENSSL_NO_STDIO
#endif // OQS_USE_OPENSSL

Expand All @@ -123,13 +128,13 @@ extern "C" {
* only handle values up to INT_MAX for those parameters.
* This macro is a temporary workaround for such functions.
*/
#define SIZE_T_TO_INT_OR_EXIT(size_t_var_name, int_var_name) \
int int_var_name = 0; \
if (size_t_var_name <= INT_MAX) { \
int_var_name = (int)size_t_var_name; \
} else { \
exit(EXIT_FAILURE); \
}
#define SIZE_T_TO_INT_OR_EXIT(size_t_var_name, int_var_name) \
int int_var_name = 0; \
if (size_t_var_name <= INT_MAX) { \
int_var_name = (int)size_t_var_name; \
} else { \
exit(EXIT_FAILURE); \
}

/**
* Defines which functions should be exposed outside the LibOQS library
Expand Down Expand Up @@ -213,6 +218,14 @@ OQS_API int OQS_CPU_has_extension(OQS_CPU_EXT ext);
*/
OQS_API void OQS_init(void);

/**
* This function stops OpenSSL threads, which allows resources
* to be cleaned up in the correct order.
* @note When liboqs is used in a multithreaded application,
* each thread should call this function prior to stopping.
*/
OQS_API void OQS_thread_stop(void);

/**
* This function frees prefetched OpenSSL objects
*/
Expand Down Expand Up @@ -277,8 +290,8 @@ OQS_API void OQS_MEM_insecure_free(void *ptr);
* Allocates size bytes of uninitialized memory with a base pointer that is
* a multiple of alignment. Alignment must be a power of two and a multiple
* of sizeof(void *). Size must be a multiple of alignment.
* @note The allocated memory should be freed with `OQS_MEM_aligned_free` when it
* is no longer needed.
* @note The allocated memory should be freed with `OQS_MEM_aligned_free` when
* it is no longer needed.
*/
void *OQS_MEM_aligned_alloc(size_t alignment, size_t size);

Expand Down
36 changes: 23 additions & 13 deletions src/common/ossl_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,26 @@ VOID_FUNC(void, ERR_print_errors_fp, (FILE *fp), (fp))
VOID_FUNC(void, EVP_CIPHER_CTX_free, (EVP_CIPHER_CTX *c), (c))
FUNC(EVP_CIPHER_CTX *, EVP_CIPHER_CTX_new, (void), ())
FUNC(int, EVP_CIPHER_CTX_set_padding, (EVP_CIPHER_CTX *c, int pad), (c, pad))
FUNC(int, EVP_DigestFinalXOF, (EVP_MD_CTX *ctx, unsigned char *md, size_t len), (ctx, md, len))
FUNC(int, EVP_DigestFinal_ex, (EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s), (ctx, md, s))
FUNC(int, EVP_DigestInit_ex, (EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl), (ctx, type, impl))
FUNC(int, EVP_DigestUpdate, (EVP_MD_CTX *ctx, const void *d, size_t cnt), (ctx, d, cnt))
FUNC(int, EVP_EncryptFinal_ex, (EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl), (ctx, out, outl))
FUNC(int, EVP_EncryptInit_ex, (EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl, const unsigned char *key, const unsigned char *iv),
FUNC(int, EVP_DigestFinalXOF, (EVP_MD_CTX *ctx, unsigned char *md, size_t len),
(ctx, md, len))
FUNC(int, EVP_DigestFinal_ex,
(EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s), (ctx, md, s))
FUNC(int, EVP_DigestInit_ex,
(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl), (ctx, type, impl))
FUNC(int, EVP_DigestUpdate, (EVP_MD_CTX *ctx, const void *d, size_t cnt),
(ctx, d, cnt))
FUNC(int, EVP_EncryptFinal_ex,
(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl), (ctx, out, outl))
FUNC(int, EVP_EncryptInit_ex,
(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher, ENGINE *impl,
const unsigned char *key, const unsigned char *iv),
(ctx, cipher, impl, key, iv))
FUNC(int, EVP_EncryptUpdate, (EVP_CIPHER_CTX *ctx, unsigned char *out,
int *outl, const unsigned char *in, int inl),
FUNC(int, EVP_EncryptUpdate,
(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
const unsigned char *in, int inl),
(ctx, out, outl, in, inl))
FUNC(int, EVP_MD_CTX_copy_ex, (EVP_MD_CTX *out, const EVP_MD_CTX *in), (out, in))
FUNC(int, EVP_MD_CTX_copy_ex, (EVP_MD_CTX *out, const EVP_MD_CTX *in),
(out, in))
VOID_FUNC(void, EVP_MD_CTX_free, (EVP_MD_CTX *ctx), (ctx))
FUNC(EVP_MD_CTX *, EVP_MD_CTX_new, (void), ())
FUNC(int, EVP_MD_CTX_reset, (EVP_MD_CTX *ctx), (ctx))
Expand All @@ -29,12 +38,12 @@ FUNC(const EVP_CIPHER *, EVP_aes_128_ctr, (void), ())
FUNC(const EVP_CIPHER *, EVP_aes_256_ecb, (void), ())
FUNC(const EVP_CIPHER *, EVP_aes_256_ctr, (void), ())
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
FUNC(EVP_CIPHER *, EVP_CIPHER_fetch, (OSSL_LIB_CTX *ctx, const char *algorithm,
const char *properties),
FUNC(EVP_CIPHER *, EVP_CIPHER_fetch,
(OSSL_LIB_CTX *ctx, const char *algorithm, const char *properties),
(ctx, algorithm, properties))
VOID_FUNC(void, EVP_CIPHER_free, (EVP_CIPHER *cipher), (cipher))
FUNC(EVP_MD *, EVP_MD_fetch, (OSSL_LIB_CTX *ctx, const char *algorithm,
const char *properties),
FUNC(EVP_MD *, EVP_MD_fetch,
(OSSL_LIB_CTX *ctx, const char *algorithm, const char *properties),
(ctx, algorithm, properties))
VOID_FUNC(void, EVP_MD_free, (EVP_MD *md), (md))
#else
Expand All @@ -51,3 +60,4 @@ VOID_FUNC(void, OPENSSL_cleanse, (void *ptr, size_t len), (ptr, len))
FUNC(int, RAND_bytes, (unsigned char *buf, int num), (buf, num))
FUNC(int, RAND_poll, (void), ())
FUNC(int, RAND_status, (void), ())
VOID_FUNC(void, OPENSSL_thread_stop, (void), ())
4 changes: 4 additions & 0 deletions src/common/ossl_helpers.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ void oqs_ossl_destroy(void) {
#endif
}

void oqs_thread_stop(void) {
OSSL_FUNC(OPENSSL_thread_stop)();
}

const EVP_MD *oqs_sha256(void) {
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
#if defined(OQS_USE_PTHREADS)
Expand Down
5 changes: 3 additions & 2 deletions src/common/ossl_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ extern "C" {
#if defined(OQS_USE_OPENSSL)
void oqs_ossl_destroy(void);

void oqs_thread_stop(void);

const EVP_MD *oqs_sha256(void);

const EVP_MD *oqs_sha384(void);
Expand All @@ -39,8 +41,7 @@ const EVP_CIPHER *oqs_aes_256_ctr(void);

#ifdef OQS_DLOPEN_OPENSSL

#define FUNC(ret, name, args, cargs) \
ret _oqs_ossl_##name args;
#define FUNC(ret, name, args, cargs) ret _oqs_ossl_##name args;
#define VOID_FUNC FUNC
#include "ossl_functions.h"
#undef VOID_FUNC
Expand Down
1 change: 1 addition & 0 deletions tests/test_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ struct thread_data {
void *test_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = kem_test_correctness(td->alg_name);
OQS_thread_stop();
return NULL;
}
#endif
Expand Down
1 change: 1 addition & 0 deletions tests/test_sig.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ struct thread_data {
void *test_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = sig_test_correctness(td->alg_name);
OQS_thread_stop();
return NULL;
}
#endif
Expand Down
6 changes: 5 additions & 1 deletion tests/test_sig_stfl.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#if OQS_USE_PTHREADS_IN_TESTS
#include <pthread.h>

static pthread_mutex_t *test_sk_lock = NULL;
static pthread_mutex_t *sk_lock = NULL;
#endif
Expand Down Expand Up @@ -990,6 +989,7 @@ void *test_query_key(void *arg) {
struct lock_test_data *td = arg;
printf("\n%s: Start Query Stateful Key info\n", __func__);
td->rc = sig_stfl_test_query_key(td->alg_name);
OQS_thread_stop();
printf("%s: End Query Stateful Key info\n\n", __func__);
return NULL;
}
Expand All @@ -998,6 +998,7 @@ void *test_sig_gen(void *arg) {
struct lock_test_data *td = arg;
printf("\n%s: Start Generate Stateful Signature\n", __func__);
td->rc = sig_stfl_test_sig_gen(td->alg_name);
OQS_thread_stop();
printf("%s: End Generate Stateful Signature\n\n", __func__);
return NULL;
}
Expand All @@ -1006,19 +1007,22 @@ void *test_create_keys(void *arg) {
struct lock_test_data *td = arg;
printf("\n%s: Start Generate Keys\n", __func__);
td->rc = sig_stfl_test_secret_key_lock(td->alg_name, td->katfile);
OQS_thread_stop();
printf("%s: End Generate Stateful Keys\n\n", __func__);
return NULL;
}

void *test_correctness_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = sig_stfl_test_correctness(td->alg_name, td->katfile);
OQS_thread_stop();
return NULL;
}

void *test_secret_key_wrapper(void *arg) {
struct thread_data *td = arg;
td->rc = sig_stfl_test_secret_key(td->alg_name, td->katfile);
OQS_thread_stop();
return NULL;
}
#endif
Expand Down
Loading