-
Notifications
You must be signed in to change notification settings - Fork 488
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support streaming conformer CTC models from wenet (#427)
- Loading branch information
1 parent
b83b3e3
commit fac4f6b
Showing
31 changed files
with
1,212 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -e | ||
|
||
log() { | ||
# This function is from espnet | ||
local fname=${BASH_SOURCE[1]##*/} | ||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" | ||
} | ||
|
||
echo "EXE is $EXE" | ||
echo "PATH: $PATH" | ||
|
||
which $EXE | ||
|
||
log "------------------------------------------------------------" | ||
log "Run streaming Conformer CTC from WeNet" | ||
log "------------------------------------------------------------" | ||
wenet_models=( | ||
sherpa-onnx-zh-wenet-aishell | ||
sherpa-onnx-zh-wenet-aishell2 | ||
sherpa-onnx-zh-wenet-wenetspeech | ||
sherpa-onnx-zh-wenet-multi-cn | ||
sherpa-onnx-en-wenet-librispeech | ||
sherpa-onnx-en-wenet-gigaspeech | ||
) | ||
for name in ${wenet_models[@]}; do | ||
repo_url=https://huggingface.co/csukuangfj/$name | ||
log "Start testing ${repo_url}" | ||
repo=$(basename $repo_url) | ||
log "Download pretrained model and test-data from $repo_url" | ||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url | ||
pushd $repo | ||
git lfs pull --include "*.onnx" | ||
ls -lh *.onnx | ||
popd | ||
|
||
log "test float32 models" | ||
time $EXE \ | ||
--tokens=$repo/tokens.txt \ | ||
--wenet-ctc-model=$repo/model-streaming.onnx \ | ||
$repo/test_wavs/0.wav \ | ||
$repo/test_wavs/1.wav \ | ||
$repo/test_wavs/8k.wav | ||
|
||
log "test int8 models" | ||
time $EXE \ | ||
--tokens=$repo/tokens.txt \ | ||
--wenet-ctc-model=$repo/model-streaming.int8.onnx \ | ||
$repo/test_wavs/0.wav \ | ||
$repo/test_wavs/1.wav \ | ||
$repo/test_wavs/8k.wav | ||
|
||
rm -rf $repo | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// sherpa-onnx/csrc/online-ctc-decoder.h | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ | ||
|
||
#include <vector> | ||
|
||
#include "onnxruntime_cxx_api.h" // NOLINT | ||
|
||
namespace sherpa_onnx { | ||
|
||
struct OnlineCtcDecoderResult { | ||
/// The decoded token IDs | ||
std::vector<int64_t> tokens; | ||
|
||
/// timestamps[i] contains the output frame index where tokens[i] is decoded. | ||
/// Note: The index is after subsampling | ||
std::vector<int32_t> timestamps; | ||
|
||
int32_t num_trailing_blanks = 0; | ||
}; | ||
|
||
class OnlineCtcDecoder { | ||
public: | ||
virtual ~OnlineCtcDecoder() = default; | ||
|
||
/** Run streaming CTC decoding given the output from the encoder model. | ||
* | ||
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing | ||
* lob_probs. | ||
* | ||
* @param results Input & Output parameters.. | ||
*/ | ||
virtual void Decode(Ort::Value log_probs, | ||
std::vector<OnlineCtcDecoderResult> *results) = 0; | ||
}; | ||
|
||
} // namespace sherpa_onnx | ||
|
||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc | ||
// | ||
// Copyright (c) 2023 Xiaomi Corporation | ||
|
||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" | ||
|
||
#include <algorithm> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "sherpa-onnx/csrc/macros.h" | ||
|
||
namespace sherpa_onnx { | ||
|
||
void OnlineCtcGreedySearchDecoder::Decode( | ||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) { | ||
std::vector<int64_t> log_probs_shape = | ||
log_probs.GetTensorTypeAndShapeInfo().GetShape(); | ||
|
||
if (log_probs_shape[0] != results->size()) { | ||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", | ||
static_cast<int32_t>(log_probs_shape[0]), | ||
static_cast<int32_t>(results->size())); | ||
exit(-1); | ||
} | ||
|
||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]); | ||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]); | ||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]); | ||
|
||
const float *p = log_probs.GetTensorData<float>(); | ||
|
||
for (int32_t b = 0; b != batch_size; ++b) { | ||
auto &r = (*results)[b]; | ||
|
||
int32_t prev_id = -1; | ||
|
||
for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) { | ||
int32_t y = static_cast<int32_t>(std::distance( | ||
static_cast<const float *>(p), | ||
std::max_element(static_cast<const float *>(p), | ||
static_cast<const float *>(p) + vocab_size))); | ||
|
||
if (y == blank_id_) { | ||
r.num_trailing_blanks += 1; | ||
} else { | ||
r.num_trailing_blanks = 0; | ||
} | ||
|
||
if (y != blank_id_ && y != prev_id) { | ||
r.tokens.push_back(y); | ||
r.timestamps.push_back(t); | ||
} | ||
|
||
prev_id = y; | ||
} // for (int32_t t = 0; t != num_frames; ++t) { | ||
} // for (int32_t b = 0; b != batch_size; ++b) | ||
} | ||
|
||
} // namespace sherpa_onnx |
Oops, something went wrong.