Skip to content

Commit

Permalink
feat: find best embedding matches (#1102)
Browse files Browse the repository at this point in the history
  • Loading branch information
thewh1teagle authored Jul 11, 2024
1 parent 1c104ea commit c0eaf86
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 0 deletions.
38 changes: 38 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) {
delete[] name;
}

const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
int32_t n) {
auto matches = p->impl->GetBestMatches(v, threshold, n);

if (matches.empty()) {
return nullptr;
}

auto resultMatches =
new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()];
for (int i = 0; i < matches.size(); ++i) {
resultMatches[i].score = matches[i].score;

char *name = new char[matches[i].name.size() + 1];
std::copy(matches[i].name.begin(), matches[i].name.end(), name);
name[matches[i].name.size()] = '\0';

resultMatches[i].name = name;
}

auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult();
result->count = matches.size();
result->matches = resultMatches;

return result;
}

void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) {
for (int32_t i = 0; i < r->count; ++i) {
delete[] r->matches[i].name;
}
delete[] r->matches;
delete r;
};

int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, float threshold) {
Expand Down
33 changes: 33 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(
const char *name);

SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch {
float score;
const char *name;
} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch;

SHERPA_ONNX_API typedef struct
SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult {
const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches;
int32_t count;
} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult;

// Get the best matching speakers whose embeddings match the given
// embedding.
//
// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance.
// @param v Pointer to an array containing the embedding vector.
// @param threshold Minimum similarity score required for a match (between 0 and
// 1).
// @param n Number of best matches to retrieve.
// @return Returns a pointer to
// SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult
// containing the best matches found. Returns NULL if no matches are
// found. The caller is responsible for freeing the returned pointer
// using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to
// avoid memory leaks.
SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
int32_t n);

SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r);

// Check whether the input embedding matches the embedding of the input
// speaker.
//
Expand Down
39 changes: 39 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl {
return row2name_.at(max_index);
}

std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
int32_t n) {
std::vector<SpeakerMatch> matches;

if (embedding_matrix_.rows() == 0) {
return matches;
}

Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();

Eigen::VectorXf scores = embedding_matrix_ * v;

std::vector<std::pair<float, int>> score_indices;
for (int i = 0; i < scores.size(); ++i) {
if (scores[i] >= threshold) {
score_indices.emplace_back(scores[i], i);
}
}

std::sort(score_indices.rbegin(), score_indices.rend(),
[](const auto &a, const auto &b) { return a.first < b.first; });

matches.reserve(score_indices.size());
for (int i = 0; i < std::min(n, static_cast<int32_t>(score_indices.size()));
++i) {
const auto &pair = score_indices[i];
matches.push_back({row2name_.at(pair.second), pair.first});
}

return matches;
}

bool Verify(const std::string &name, const float *p, float threshold) {
if (!name2row_.count(name)) {
return false;
Expand Down Expand Up @@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p,
return impl_->Search(p, threshold);
}

std::vector<SpeakerMatch> SpeakerEmbeddingManager::GetBestMatches(
const float *p, float threshold, int32_t n) const {
return impl_->GetBestMatches(p, threshold, n);
}

bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
float threshold) const {
return impl_->Verify(name, p, threshold);
Expand Down
24 changes: 24 additions & 0 deletions sherpa-onnx/csrc/speaker-embedding-manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
#include <string>
#include <vector>

struct SpeakerMatch {
const std::string name;
float score;
};

namespace sherpa_onnx {

class SpeakerEmbeddingManager {
Expand Down Expand Up @@ -62,6 +67,25 @@ class SpeakerEmbeddingManager {
*/
std::string Search(const float *p, float threshold) const;

/**
* It is for speaker identification.
*
* It computes the cosine similarity between a given embedding and all
* other embeddings and finds the embeddings that have the largest scores
* and the scores are above or equal to the threshold. Returns a vector of
* SpeakerMatch structures containing the speaker names and scores for the
* embeddings if found; otherwise, returns an empty vector.
*
* @param p A pointer to the input embedding.
* @param threshold A value between 0 and 1.
* @param n The number of top matches to return.
* @return A vector of SpeakerMatch structures. If matches are found, the
* vector contains the names and scores of the speakers. Otherwise,
* it returns an empty vector.
*/
std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
int32_t n) const;

/* Check whether the input embedding matches the embedding of the input
* speaker.
*
Expand Down

0 comments on commit c0eaf86

Please sign in to comment.