From d8809b520ee3824b8dea139270c14d98f50ad0b9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 20 Sep 2024 19:04:21 +0800 Subject: [PATCH] Fix CI errors introduced by supporting loading keywords from buffers (#1366) --- .../sherpa_onnx/lib/src/keyword_spotter.dart | 9 +++++++- .../lib/src/sherpa_onnx_bindings.dart | 5 ++++ scripts/dotnet/KeywordSpotterConfig.cs | 7 ++++++ scripts/go/_internal/build_linux_arm64.go | 2 +- .../node-addon-api/src/keyword-spotting.cc | 6 +++++ swift-api-examples/SherpaOnnx.swift | 8 +++++-- wasm/kws/sherpa-onnx-kws.js | 23 +++++++++++++++++-- wasm/kws/sherpa-onnx-wasm-main-kws.cc | 2 +- 8 files changed, 55 insertions(+), 7 deletions(-) diff --git a/flutter/sherpa_onnx/lib/src/keyword_spotter.dart b/flutter/sherpa_onnx/lib/src/keyword_spotter.dart index 724acd6f1..c09867995 100644 --- a/flutter/sherpa_onnx/lib/src/keyword_spotter.dart +++ b/flutter/sherpa_onnx/lib/src/keyword_spotter.dart @@ -19,11 +19,13 @@ class KeywordSpotterConfig { this.keywordsScore = 1.0, this.keywordsThreshold = 0.25, this.keywordsFile = '', + this.keywordsBuf = '', + this.keywordsBufSize = 0, }); @override String toString() { - return 'KeywordSpotterConfig(feat: $feat, model: $model, maxActivePaths: $maxActivePaths, numTrailingBlanks: $numTrailingBlanks, keywordsScore: $keywordsScore, keywordsThreshold: $keywordsThreshold, keywordsFile: $keywordsFile)'; + return 'KeywordSpotterConfig(feat: $feat, model: $model, maxActivePaths: $maxActivePaths, numTrailingBlanks: $numTrailingBlanks, keywordsScore: $keywordsScore, keywordsThreshold: $keywordsThreshold, keywordsFile: $keywordsFile, keywordsBuf: $keywordsBuf, keywordsBufSize: $keywordsBufSize)'; } final FeatureConfig feat; @@ -35,6 +37,8 @@ class KeywordSpotterConfig { final double keywordsScore; final double keywordsThreshold; final String keywordsFile; + final String keywordsBuf; + final int keywordsBufSize; } class KeywordResult { @@ -89,9 +93,12 @@ class KeywordSpotter { c.ref.keywordsScore = config.keywordsScore; c.ref.keywordsThreshold = config.keywordsThreshold; c.ref.keywordsFile = config.keywordsFile.toNativeUtf8(); + c.ref.keywordsBuf = config.keywordsBuf.toNativeUtf8(); + c.ref.keywordsBufSize = config.keywordsBufSize; final ptr = SherpaOnnxBindings.createKeywordSpotter?.call(c) ?? nullptr; + calloc.free(c.ref.keywordsBuf); calloc.free(c.ref.keywordsFile); calloc.free(c.ref.model.bpeVocab); calloc.free(c.ref.model.modelingUnit); diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index 207160087..42294c2d4 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -371,6 +371,11 @@ final class SherpaOnnxKeywordSpotterConfig extends Struct { external double keywordsThreshold; external Pointer keywordsFile; + + external Pointer keywordsBuf; + + @Int32() + external int keywordsBufSize; } final class SherpaOnnxOfflinePunctuation extends Opaque {} diff --git a/scripts/dotnet/KeywordSpotterConfig.cs b/scripts/dotnet/KeywordSpotterConfig.cs index 125afb716..13f3df5eb 100644 --- a/scripts/dotnet/KeywordSpotterConfig.cs +++ b/scripts/dotnet/KeywordSpotterConfig.cs @@ -17,6 +17,8 @@ public KeywordSpotterConfig() KeywordsScore = 1.0F; KeywordsThreshold = 0.25F; KeywordsFile = ""; + KeywordsBuf= ""; + KeywordsBufSize= 0; } public FeatureConfig FeatConfig; public OnlineModelConfig ModelConfig; @@ -28,5 +30,10 @@ public KeywordSpotterConfig() [MarshalAs(UnmanagedType.LPStr)] public string KeywordsFile; + + [MarshalAs(UnmanagedType.LPStr)] + public string KeywordsBuf; + + public int KeywordsBufSize; } } diff --git a/scripts/go/_internal/build_linux_arm64.go b/scripts/go/_internal/build_linux_arm64.go index 0bcb60b86..f25f147d7 100644 --- a/scripts/go/_internal/build_linux_arm64.go +++ b/scripts/go/_internal/build_linux_arm64.go @@ -2,5 +2,5 @@ package sherpa_onnx -// #cgo LDFLAGS: -L ${SRCDIR}/lib/aarch64-unknown-linux-gnu -lsherpa-onnx-c-api -lsherpa-onnx-core -lkaldi-native-fbank-core -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fstfar -lsherpa-onnx-fst -lpiper_phonemize -lespeak-ng -lucd -lonnxruntime -lssentencepiece_core -Wl,-rpath,${SRCDIR}/lib/aarch64-unknown-linux-gnu +// #cgo LDFLAGS: -L ${SRCDIR}/lib/aarch64-unknown-linux-gnu -lsherpa-onnx-c-api -lonnxruntime -Wl,-rpath,${SRCDIR}/lib/aarch64-unknown-linux-gnu import "C" diff --git a/scripts/node-addon-api/src/keyword-spotting.cc b/scripts/node-addon-api/src/keyword-spotting.cc index 1e43190b5..2b5a24100 100644 --- a/scripts/node-addon-api/src/keyword-spotting.cc +++ b/scripts/node-addon-api/src/keyword-spotting.cc @@ -43,6 +43,8 @@ static Napi::External CreateKeywordSpotterWrapper( SHERPA_ONNX_ASSIGN_ATTR_FLOAT(keywords_score, keywordsScore); SHERPA_ONNX_ASSIGN_ATTR_FLOAT(keywords_threshold, keywordsThreshold); SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_file, keywordsFile); + SHERPA_ONNX_ASSIGN_ATTR_STR(keywords_buf, keywordsBuf); + SHERPA_ONNX_ASSIGN_ATTR_INT32(keywords_buf_size, keywordsBufSize); SherpaOnnxKeywordSpotter *kws = SherpaOnnxCreateKeywordSpotter(&c); @@ -86,6 +88,10 @@ static Napi::External CreateKeywordSpotterWrapper( delete[] c.keywords_file; } + if (c.keywords_buf) { + delete[] c.keywords_buf; + } + if (!kws) { Napi::TypeError::New(env, "Please check your config!") .ThrowAsJavaScriptException(); diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index e24819306..778bccb9b 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -966,7 +966,9 @@ func sherpaOnnxKeywordSpotterConfig( maxActivePaths: Int = 4, numTrailingBlanks: Int = 1, keywordsScore: Float = 1.0, - keywordsThreshold: Float = 0.25 + keywordsThreshold: Float = 0.25, + keywordsBuf: String = "", + keywordsBufSize: Int = 0 ) -> SherpaOnnxKeywordSpotterConfig { return SherpaOnnxKeywordSpotterConfig( feat_config: featConfig, @@ -975,7 +977,9 @@ func sherpaOnnxKeywordSpotterConfig( num_trailing_blanks: Int32(numTrailingBlanks), keywords_score: keywordsScore, keywords_threshold: keywordsThreshold, - keywords_file: toCPointer(keywordsFile) + keywords_file: toCPointer(keywordsFile), + keywords_buf: toCPointer(keywordsBuf), + keywords_buf_size: Int32(keywordsBufSize) ) } diff --git a/wasm/kws/sherpa-onnx-kws.js b/wasm/kws/sherpa-onnx-kws.js index dc1712bc9..b7c023356 100644 --- a/wasm/kws/sherpa-onnx-kws.js +++ b/wasm/kws/sherpa-onnx-kws.js @@ -172,10 +172,18 @@ function initKwsConfig(config, Module) { }; } + if (!('keywordsBuf' in config)) { + config.keywordsBuf = ''; + } + + if (!('keywordsBufSize' in config)) { + config.keywordsBufSize = 0; + } + let featConfig = initFeatureExtractorConfig(config.featConfig, Module); let modelConfig = initModelConfig(config.modelConfig, Module); - let numBytes = featConfig.len + modelConfig.len + 4 * 5; + let numBytes = featConfig.len + modelConfig.len + 4 * 7; let ptr = Module._malloc(numBytes); let offset = 0; @@ -198,11 +206,22 @@ function initKwsConfig(config, Module) { offset += 4; let keywordsLen = Module.lengthBytesUTF8(config.keywords) + 1; - let keywordsBuffer = Module._malloc(keywordsLen); + let keywordsBufLen = Module.lengthBytesUTF8(config.keywordsBuf) + 1; + + let keywordsBuffer = Module._malloc(keywordsLen + keywordsBufLen); Module.stringToUTF8(config.keywords, keywordsBuffer, keywordsLen); + Module.stringToUTF8( + config.keywordsBuf, keywordsBuffer + keywordsLen, keywordsBufLen); + Module.setValue(ptr + offset, keywordsBuffer, 'i8*'); offset += 4; + Module.setValue(ptr + offset, keywordsBuffer + keywordsLen, 'i8*'); + offset += 4; + + Module.setValue(ptr + offset, config.keywordsBufLen, 'i32'); + offset += 4; + return { ptr: ptr, len: numBytes, featConfig: featConfig, modelConfig: modelConfig, keywordsBuffer: keywordsBuffer diff --git a/wasm/kws/sherpa-onnx-wasm-main-kws.cc b/wasm/kws/sherpa-onnx-wasm-main-kws.cc index cb3627955..39f295d90 100644 --- a/wasm/kws/sherpa-onnx-wasm-main-kws.cc +++ b/wasm/kws/sherpa-onnx-wasm-main-kws.cc @@ -24,7 +24,7 @@ static_assert(sizeof(SherpaOnnxOnlineModelConfig) == static_assert(sizeof(SherpaOnnxFeatureConfig) == 2 * 4, ""); static_assert(sizeof(SherpaOnnxKeywordSpotterConfig) == sizeof(SherpaOnnxFeatureConfig) + - sizeof(SherpaOnnxOnlineModelConfig) + 5 * 4, + sizeof(SherpaOnnxOnlineModelConfig) + 7 * 4, ""); void CopyHeap(const char *src, int32_t num_bytes, char *dst) {