From c0eaf86dbd4b7c842852215d5418e065a64e6190 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Thu, 11 Jul 2024 04:38:06 +0300 Subject: [PATCH] feat: find best embedding matches (#1102) --- sherpa-onnx/c-api/c-api.cc | 38 ++++++++++++++++++ sherpa-onnx/c-api/c-api.h | 33 ++++++++++++++++ sherpa-onnx/csrc/speaker-embedding-manager.cc | 39 +++++++++++++++++++ sherpa-onnx/csrc/speaker-embedding-manager.h | 24 ++++++++++++ 4 files changed, 134 insertions(+) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index e23305fb7..eb9ec8752 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) { delete[] name; } +const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult * +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold, + int32_t n) { + auto matches = p->impl->GetBestMatches(v, threshold, n); + + if (matches.empty()) { + return nullptr; + } + + auto resultMatches = + new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()]; + for (int i = 0; i < matches.size(); ++i) { + resultMatches[i].score = matches[i].score; + + char *name = new char[matches[i].name.size() + 1]; + std::copy(matches[i].name.begin(), matches[i].name.end(), name); + name[matches[i].name.size()] = '\0'; + + resultMatches[i].name = name; + } + + auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult(); + result->count = matches.size(); + result->matches = resultMatches; + + return result; +} + +void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) { + for (int32_t i = 0; i < r->count; ++i) { + delete[] r->matches[i].name; + } + delete[] r->matches; + delete r; +}; + int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, const float *v, float threshold) { diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 2bfba98c7..4beba2a73 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch( SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch( const char *name); +SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch { + float score; + const char *name; +} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch; + +SHERPA_ONNX_API typedef struct + SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult { + const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches; + int32_t count; +} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult; + +// Get the best matching speakers whose embeddings match the given +// embedding. +// +// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance. +// @param v Pointer to an array containing the embedding vector. +// @param threshold Minimum similarity score required for a match (between 0 and +// 1). +// @param n Number of best matches to retrieve. +// @return Returns a pointer to +// SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult +// containing the best matches found. Returns NULL if no matches are +// found. The caller is responsible for freeing the returned pointer +// using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to +// avoid memory leaks. +SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult * +SherpaOnnxSpeakerEmbeddingManagerGetBestMatches( + const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold, + int32_t n); + +SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( + const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r); + // Check whether the input embedding matches the embedding of the input // speaker. // diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc index 6c90c1953..701fa6e18 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl { return row2name_.at(max_index); } + std::vector GetBestMatches(const float *p, float threshold, + int32_t n) { + std::vector matches; + + if (embedding_matrix_.rows() == 0) { + return matches; + } + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + Eigen::VectorXf scores = embedding_matrix_ * v; + + std::vector> score_indices; + for (int i = 0; i < scores.size(); ++i) { + if (scores[i] >= threshold) { + score_indices.emplace_back(scores[i], i); + } + } + + std::sort(score_indices.rbegin(), score_indices.rend(), + [](const auto &a, const auto &b) { return a.first < b.first; }); + + matches.reserve(score_indices.size()); + for (int i = 0; i < std::min(n, static_cast(score_indices.size())); + ++i) { + const auto &pair = score_indices[i]; + matches.push_back({row2name_.at(pair.second), pair.first}); + } + + return matches; + } + bool Verify(const std::string &name, const float *p, float threshold) { if (!name2row_.count(name)) { return false; @@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p, return impl_->Search(p, threshold); } +std::vector SpeakerEmbeddingManager::GetBestMatches( + const float *p, float threshold, int32_t n) const { + return impl_->GetBestMatches(p, threshold, n); +} + bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, float threshold) const { return impl_->Verify(name, p, threshold); diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.h b/sherpa-onnx/csrc/speaker-embedding-manager.h index ae8728b13..9490765ca 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.h +++ b/sherpa-onnx/csrc/speaker-embedding-manager.h @@ -9,6 +9,11 @@ #include #include +struct SpeakerMatch { + const std::string name; + float score; +}; + namespace sherpa_onnx { class SpeakerEmbeddingManager { @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager { */ std::string Search(const float *p, float threshold) const; + /** + * It is for speaker identification. + * + * It computes the cosine similarity between a given embedding and all + * other embeddings and finds the embeddings that have the largest scores + * and the scores are above or equal to the threshold. Returns a vector of + * SpeakerMatch structures containing the speaker names and scores for the + * embeddings if found; otherwise, returns an empty vector. + * + * @param p A pointer to the input embedding. + * @param threshold A value between 0 and 1. + * @param n The number of top matches to return. + * @return A vector of SpeakerMatch structures. If matches are found, the + * vector contains the names and scores of the speakers. Otherwise, + * it returns an empty vector. + */ + std::vector GetBestMatches(const float *p, float threshold, + int32_t n) const; + /* Check whether the input embedding matches the embedding of the input * speaker. *