Skip to content

Commit

Permalink
Add swift online punctuation
Browse files Browse the repository at this point in the history
  • Loading branch information
yujinqiu committed Dec 30, 2024
1 parent 38d64a6 commit e3895a2
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 0 deletions.
48 changes: 48 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-punctuation.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/resample.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
Expand Down Expand Up @@ -1717,6 +1718,53 @@ const char *SherpaOfflinePunctuationAddPunct(

void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; }

struct SherpaOnnxOnlinePunctuation {
std::unique_ptr<sherpa_onnx::OnlinePunctuation> impl;
};

const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
const SherpaOnnxOnlinePunctuationConfig *config) {
auto p = new SherpaOnnxOnlinePunctuation;
try {
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
punctuation_config.model.cnn_bilstm = config->model.cnn_bilstm;
punctuation_config.model.bpe_vocab = config->model.bpe_vocab;
punctuation_config.model.num_threads = config->model.num_threads;
punctuation_config.model.debug = config->model.debug;
punctuation_config.model.provider = config->model.provider;

p->impl =
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
} catch (const std::exception &e) {
SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what());
delete p;
return nullptr;
}
return p;
}

void SherpaOnnxDestroyOnlinePunctuation(const SherpaOnnxOnlinePunctuation *p) {
delete p;
}

const char *SherpaOnnxOnlinePunctuationAddPunct(
const SherpaOnnxOnlinePunctuation *punctuation, const char *text) {
if (!punctuation || !text) return nullptr;

try {
std::string s = punctuation->impl->AddPunctuationWithCase(text);
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = '\0';
return p;
} catch (const std::exception &e) {
SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what());
return nullptr;
}
}

void SherpaOnnxOnlinePunctuationFreeText(const char *text) { delete[] text; }

struct SherpaOnnxLinearResampler {
std::unique_ptr<sherpa_onnx::LinearResample> impl;
};
Expand Down
33 changes: 33 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,39 @@ SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct(

SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text);

SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationModelConfig {
const char *cnn_bilstm;
const char *bpe_vocab;
int32_t num_threads;
int32_t debug;
const char *provider;
} SherpaOnnxOnlinePunctuationModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
SherpaOnnxOnlinePunctuationModelConfig model;
} SherpaOnnxOnlinePunctuationConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation SherpaOnnxOnlinePunctuation;

// Create an online punctuation processor. The user has to invoke
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
// to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
const SherpaOnnxOnlinePunctuationConfig *config);

// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
SHERPA_ONNX_API void SherpaOnnxDestroyOnlinePunctuation(
const SherpaOnnxOnlinePunctuation *punctuation);

// Add punctuations to the input text. The user has to invoke
// SherpaOnnxOnlinePunctuationFreeText() to free the returned pointer
// to avoid memory leak
SHERPA_ONNX_API const char *SherpaOnnxOnlinePunctuationAddPunct(
const SherpaOnnxOnlinePunctuation *punctuation, const char *text);

// Free a pointer returned by SherpaOnnxOnlinePunctuationAddPunct()
SHERPA_ONNX_API void SherpaOnnxOnlinePunctuationFreeText(const char *text);

// for resampling
SHERPA_ONNX_API typedef struct SherpaOnnxLinearResampler
SherpaOnnxLinearResampler;
Expand Down
46 changes: 46 additions & 0 deletions swift-api-examples/SherpaOnnx.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,52 @@ class SherpaOnnxOfflinePunctuationWrapper {
}
}

func sherpaOnnxOnlinePunctuationModelConfig(
cnnBiLstm: String,
bpeVocab: String,
numThreads: Int = 1,
debug: Int = 0,
provider: String = "cpu"
) -> SherpaOnnxOnlinePunctuationModelConfig {
return SherpaOnnxOnlinePunctuationModelConfig(
cnn_bilstm: toCPointer(cnnBiLstm),
bpe_vocab: toCPointer(bpeVocab),
num_threads: Int32(numThreads),
debug: Int32(debug),
provider: toCPointer(provider))
}

func sherpaOnnxOnlinePunctuationConfig(
model: SherpaOnnxOnlinePunctuationModelConfig
) -> SherpaOnnxOnlinePunctuationConfig {
return SherpaOnnxOnlinePunctuationConfig(model: model)
}

class SherpaOnnxOnlinePunctuationWrapper {
/// A pointer to the underlying counterpart in C
let ptr: OpaquePointer!

/// Constructor taking a model config
init(
config: UnsafePointer<SherpaOnnxOnlinePunctuationConfig>!
) {
ptr = SherpaOnnxCreateOnlinePunctuation(config)
}

deinit {
if let ptr {
SherpaOnnxDestroyOnlinePunctuation(ptr)
}
}

func addPunct(text: String) -> String {
let cText = SherpaOnnxOnlinePunctuationAddPunct(ptr, toCPointer(text))
let ans = String(cString: cText!)
SherpaOnnxOnlinePunctuationFreeText(cText)
return ans
}
}

func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String)
-> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig
{
Expand Down
35 changes: 35 additions & 0 deletions swift-api-examples/add-punctuation-online.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
func run() {
let model = "./sherpa-onnx-online-punct-en-2024-08-06/model.onnx"
let bpe = "./sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab"

// Create model config
let modelConfig = sherpaOnnxOnlinePunctuationModelConfig(
cnnBiLstm: model,
bpeVocab: bpe
)

// Create punctuation config
var config = sherpaOnnxOnlinePunctuationConfig(model: modelConfig)

// Create punctuation instance
let punct = SherpaOnnxOnlinePunctuationWrapper(config: &config)

// Test texts
let textList = [
"how are you i am fine thank you",
"The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
]

// Process each text
for i in 0..<textList.count {
let t = punct.addPunct(text: textList[i])
print("\nresult is:\n\(t)")
}
}

@main
struct App {
static func main() {
run()
}
}
36 changes: 36 additions & 0 deletions swift-api-examples/run-add-punctuations-online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env bash

set -ex

if [ ! -d ../build-swift-macos ]; then
echo "Please run ../build-swift-macos.sh first!"
exit 1
fi

# Download and extract the online punctuation model if not exists
if [ ! -d ./sherpa-onnx-online-punct-en-2024-08-06 ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
fi

if [ ! -e ./add-punctuation-online ]; then
# Note: We use -lc++ to link against libc++ instead of libstdc++
swiftc \
-lc++ \
-I ../build-swift-macos/install/include \
-import-objc-header ./SherpaOnnx-Bridging-Header.h \
./add-punctuation-online.swift ./SherpaOnnx.swift \
-L ../build-swift-macos/install/lib/ \
-l sherpa-onnx \
-l onnxruntime \
-o ./add-punctuation-online

strip ./add-punctuation-online
else
echo "./add-punctuation-online exists - skip building"
fi

# Set library path and run the executable
export DYLD_LIBRARY_PATH=$PWD/../build-swift-macos/install/lib:$DYLD_LIBRARY_PATH
./add-punctuation-online

0 comments on commit e3895a2

Please sign in to comment.