From cd1fedaa498e2efb79bf06ac97f17243d526f036 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 29 Jul 2024 11:15:14 +0800 Subject: [PATCH] Add Dart API for audio tagging (#1181) --- .github/scripts/test-dart.sh | 8 + .github/workflows/test-dart.yaml | 1 + CHANGELOG.md | 4 + README.md | 10 +- dart-api-examples/README.md | 3 +- dart-api-examples/audio-tagging/.gitignore | 3 + dart-api-examples/audio-tagging/README.md | 8 + .../audio-tagging/analysis_options.yaml | 30 ++++ dart-api-examples/audio-tagging/bin/ced.dart | 54 +++++++ dart-api-examples/audio-tagging/bin/init.dart | 1 + .../audio-tagging/bin/zipformer.dart | 59 +++++++ dart-api-examples/audio-tagging/pubspec.yaml | 17 +++ dart-api-examples/audio-tagging/run-ced.sh | 19 +++ .../audio-tagging/run-zipformer.sh | 19 +++ .../non-streaming-asr/bin/nemo-ctc.dart | 1 - .../bin/nemo-transducer.dart | 1 - .../non-streaming-asr/bin/paraformer-itn.dart | 1 - .../non-streaming-asr/bin/paraformer.dart | 1 - .../non-streaming-asr/bin/sense-voice.dart | 1 - .../non-streaming-asr/bin/telespeech-ctc.dart | 1 - .../non-streaming-asr/bin/whisper.dart | 1 - .../bin/zipformer-transducer.dart | 1 - dart-api-examples/tts/bin/coqui.dart | 3 +- dart-api-examples/tts/bin/piper.dart | 2 +- dart-api-examples/tts/bin/zh.dart | 3 +- dart-api-examples/vad/bin/vad.dart | 2 +- flutter/sherpa_onnx/lib/sherpa_onnx.dart | 1 + .../sherpa_onnx/lib/src/audio_tagging.dart | 144 ++++++++++++++++++ .../lib/src/sherpa_onnx_bindings.dart | 105 +++++++++++++ scripts/dart/audio-tagging-pubspec.yaml | 18 +++ 30 files changed, 504 insertions(+), 18 deletions(-) create mode 100644 dart-api-examples/audio-tagging/.gitignore create mode 100644 dart-api-examples/audio-tagging/README.md create mode 100644 dart-api-examples/audio-tagging/analysis_options.yaml create mode 100644 dart-api-examples/audio-tagging/bin/ced.dart create mode 120000 dart-api-examples/audio-tagging/bin/init.dart create mode 100644 dart-api-examples/audio-tagging/bin/zipformer.dart create mode 100644 dart-api-examples/audio-tagging/pubspec.yaml create mode 100755 dart-api-examples/audio-tagging/run-ced.sh create mode 100755 dart-api-examples/audio-tagging/run-zipformer.sh create mode 100644 flutter/sherpa_onnx/lib/src/audio_tagging.dart create mode 100644 scripts/dart/audio-tagging-pubspec.yaml diff --git a/.github/scripts/test-dart.sh b/.github/scripts/test-dart.sh index 9da908f48..a35233602 100755 --- a/.github/scripts/test-dart.sh +++ b/.github/scripts/test-dart.sh @@ -4,6 +4,14 @@ set -ex cd dart-api-examples +pushd audio-tagging +echo '----------zipformer----------' +./run-zipformer.sh + +echo '----------ced----------' +./run-ced.sh +popd + pushd vad-with-non-streaming-asr echo '----------TeleSpeech CTC----------' ./run-telespeech-ctc.sh diff --git a/.github/workflows/test-dart.yaml b/.github/workflows/test-dart.yaml index eb5c0b2b0..e90176d09 100644 --- a/.github/workflows/test-dart.yaml +++ b/.github/workflows/test-dart.yaml @@ -110,6 +110,7 @@ jobs: cp scripts/dart/tts-pubspec.yaml dart-api-examples/tts/pubspec.yaml cp scripts/dart/kws-pubspec.yaml dart-api-examples/keyword-spotter/pubspec.yaml cp scripts/dart/vad-non-streaming-asr-pubspec.yaml dart-api-examples/vad-with-non-streaming-asr/pubspec.yaml + cp scripts/dart/audio-tagging-pubspec.yaml dart-api-examples/audio-tagging/pubspec.yaml cp scripts/dart/sherpa-onnx-pubspec.yaml flutter/sherpa_onnx/pubspec.yaml diff --git a/CHANGELOG.md b/CHANGELOG.md index b7106bdb5..1120192b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 1.10.20 + +* Add Dart API for audio tagging + ## 1.10.19 * Prefix all C API functions with SherpaOnnx diff --git a/README.md b/README.md index 26910d80a..bb2a748dd 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,13 @@ |------------------|------------------|----------------------|------------------------| | ✔️ | ✔️ | ✔️ | ✔️ | -| Spoken Language identification | Audio tagging | Voice activity detection | Keyword spotting | -|--------------------------------|---------------|--------------------------|------------------| -| ✔️ | ✔️ | ✔️ | ✔️ | +| Spoken Language identification | Audio tagging | Voice activity detection | +|--------------------------------|---------------|--------------------------| +| ✔️ | ✔️ | ✔️ | + +| Keyword spotting | Add punctuation | +|------------------|-----------------| +| ✔️ | ✔️ | ### Supported platforms diff --git a/dart-api-examples/README.md b/dart-api-examples/README.md index 855691e5b..239313eb3 100644 --- a/dart-api-examples/README.md +++ b/dart-api-examples/README.md @@ -5,7 +5,7 @@ This directory contains examples for Dart API. You can find the package at https://pub.dev/packages/sherpa_onnx -## Descirption +## Description | Directory | Description | |-----------|-------------| @@ -15,6 +15,7 @@ https://pub.dev/packages/sherpa_onnx | [./tts](./tts)| Example for text to speech| | [./vad](./vad)| Example for voice activity detection| | [./vad-with-non-streaming-asr](./vad-with-non-streaming-asr)| Example for voice activity detection with non-streaming speech recognition. You can use it to generate subtitles.| +| [./audio-tagging](./audio-tagging)| Example for audio tagging.| ## How to create an example in this folder diff --git a/dart-api-examples/audio-tagging/.gitignore b/dart-api-examples/audio-tagging/.gitignore new file mode 100644 index 000000000..3a8579040 --- /dev/null +++ b/dart-api-examples/audio-tagging/.gitignore @@ -0,0 +1,3 @@ +# https://dart.dev/guides/libraries/private-files +# Created by `dart pub` +.dart_tool/ diff --git a/dart-api-examples/audio-tagging/README.md b/dart-api-examples/audio-tagging/README.md new file mode 100644 index 000000000..84c2f1f8f --- /dev/null +++ b/dart-api-examples/audio-tagging/README.md @@ -0,0 +1,8 @@ +# Introduction + +This example shows how to use the Dart API from sherpa-onnx for audio tagging. + +| File | Description| +|------|------------| +|[./bin/zipformer.dart](./bin/zipformer.dart)| Use a Zipformer model for audio tagging. See [./run-zipformer.sh](./run-zipformer.sh)| +|[./bin/ced.dart](./bin/ced.dart)| Use a [CED](https://github.com/RicherMans/CED) model for audio tagging. See [./run-ced.sh](./run-ced.sh)| diff --git a/dart-api-examples/audio-tagging/analysis_options.yaml b/dart-api-examples/audio-tagging/analysis_options.yaml new file mode 100644 index 000000000..dee8927aa --- /dev/null +++ b/dart-api-examples/audio-tagging/analysis_options.yaml @@ -0,0 +1,30 @@ +# This file configures the static analysis results for your project (errors, +# warnings, and lints). +# +# This enables the 'recommended' set of lints from `package:lints`. +# This set helps identify many issues that may lead to problems when running +# or consuming Dart code, and enforces writing Dart using a single, idiomatic +# style and format. +# +# If you want a smaller set of lints you can change this to specify +# 'package:lints/core.yaml'. These are just the most critical lints +# (the recommended set includes the core lints). +# The core lints are also what is used by pub.dev for scoring packages. + +include: package:lints/recommended.yaml + +# Uncomment the following section to specify additional rules. + +# linter: +# rules: +# - camel_case_types + +# analyzer: +# exclude: +# - path/to/excluded/files/** + +# For more information about the core and recommended set of lints, see +# https://dart.dev/go/core-lints + +# For additional information about configuring this file, see +# https://dart.dev/guides/language/analysis-options diff --git a/dart-api-examples/audio-tagging/bin/ced.dart b/dart-api-examples/audio-tagging/bin/ced.dart new file mode 100644 index 000000000..0c4b07e37 --- /dev/null +++ b/dart-api-examples/audio-tagging/bin/ced.dart @@ -0,0 +1,54 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:io'; + +import 'package:args/args.dart'; +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; +import './init.dart'; + +void main(List arguments) async { + await initSherpaOnnx(); + + final parser = ArgParser() + ..addOption('model', help: 'Path to the zipformer model') + ..addOption('labels', help: 'Path to class_labels_indices.csv') + ..addOption('top-k', help: 'topK events to be returned', defaultsTo: '5') + ..addOption('wav', help: 'Path to test.wav to be tagged'); + + final res = parser.parse(arguments); + if (res['model'] == null || res['labels'] == null || res['wav'] == null) { + print(parser.usage); + exit(1); + } + + final model = res['model'] as String; + final labels = res['labels'] as String; + final topK = int.tryParse(res['top-k'] as String) ?? 5; + final wav = res['wav'] as String; + + final modelConfig = sherpa_onnx.AudioTaggingModelConfig( + ced: model, + numThreads: 1, + debug: true, + provider: 'cpu', + ); + + final config = sherpa_onnx.AudioTaggingConfig( + model: modelConfig, + labels: labels, + ); + + final at = sherpa_onnx.AudioTagging(config: config); + + final waveData = sherpa_onnx.readWave(wav); + + final stream = at.createStream(); + stream.acceptWaveform( + samples: waveData.samples, sampleRate: waveData.sampleRate); + + final events = at.compute(stream: stream, topK: topK); + + print(events); + + stream.free(); + at.free(); +} diff --git a/dart-api-examples/audio-tagging/bin/init.dart b/dart-api-examples/audio-tagging/bin/init.dart new file mode 120000 index 000000000..48508cfd3 --- /dev/null +++ b/dart-api-examples/audio-tagging/bin/init.dart @@ -0,0 +1 @@ +../../vad/bin/init.dart \ No newline at end of file diff --git a/dart-api-examples/audio-tagging/bin/zipformer.dart b/dart-api-examples/audio-tagging/bin/zipformer.dart new file mode 100644 index 000000000..4021a5efb --- /dev/null +++ b/dart-api-examples/audio-tagging/bin/zipformer.dart @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:io'; + +import 'package:args/args.dart'; +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; +import './init.dart'; + +void main(List arguments) async { + await initSherpaOnnx(); + + final parser = ArgParser() + ..addOption('model', help: 'Path to the zipformer model') + ..addOption('labels', help: 'Path to class_labels_indices.csv') + ..addOption('top-k', help: 'topK events to be returned', defaultsTo: '5') + ..addOption('wav', help: 'Path to test.wav to be tagged'); + + final res = parser.parse(arguments); + if (res['model'] == null || res['labels'] == null || res['wav'] == null) { + print(parser.usage); + exit(1); + } + + final model = res['model'] as String; + final labels = res['labels'] as String; + final topK = int.tryParse(res['top-k'] as String) ?? 5; + final wav = res['wav'] as String; + + final zipformerModelConfig = + sherpa_onnx.OfflineZipformerAudioTaggingModelConfig( + model: model, + ); + + final modelConfig = sherpa_onnx.AudioTaggingModelConfig( + zipformer: zipformerModelConfig, + numThreads: 1, + debug: true, + provider: 'cpu', + ); + + final config = sherpa_onnx.AudioTaggingConfig( + model: modelConfig, + labels: labels, + ); + + final at = sherpa_onnx.AudioTagging(config: config); + + final waveData = sherpa_onnx.readWave(wav); + + final stream = at.createStream(); + stream.acceptWaveform( + samples: waveData.samples, sampleRate: waveData.sampleRate); + + final events = at.compute(stream: stream, topK: topK); + + print(events); + + stream.free(); + at.free(); +} diff --git a/dart-api-examples/audio-tagging/pubspec.yaml b/dart-api-examples/audio-tagging/pubspec.yaml new file mode 100644 index 000000000..15845b8c4 --- /dev/null +++ b/dart-api-examples/audio-tagging/pubspec.yaml @@ -0,0 +1,17 @@ +name: audio_tagging + +description: > + This example demonstrates how to use the Dart API for audio tagging. + +version: 1.0.0 + +environment: + sdk: ^3.4.0 + +dependencies: + sherpa_onnx: ^1.10.19 + path: ^1.9.0 + args: ^2.5.0 + +dev_dependencies: + lints: ^3.0.0 diff --git a/dart-api-examples/audio-tagging/run-ced.sh b/dart-api-examples/audio-tagging/run-ced.sh new file mode 100755 index 000000000..e2fc2c276 --- /dev/null +++ b/dart-api-examples/audio-tagging/run-ced.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -ex + +dart pub get + +if [[ ! -f ./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.onnx ]]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-ced-mini-audio-tagging-2024-04-19.tar.bz2 + tar xvf sherpa-onnx-ced-mini-audio-tagging-2024-04-19.tar.bz2 + rm sherpa-onnx-ced-mini-audio-tagging-2024-04-19.tar.bz2 +fi + +for w in 1 2 3 4 5 6; do + dart run \ + ./bin/ced.dart \ + --model ./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.int8.onnx \ + --labels ./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/class_labels_indices.csv \ + --wav ./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/$w.wav +done diff --git a/dart-api-examples/audio-tagging/run-zipformer.sh b/dart-api-examples/audio-tagging/run-zipformer.sh new file mode 100755 index 000000000..fd1e14f13 --- /dev/null +++ b/dart-api-examples/audio-tagging/run-zipformer.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -ex + +dart pub get + +if [[ ! -f ./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.onnx ]]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + tar xvf sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 + rm sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 +fi + +for w in 1 2 3 4 5 6; do + dart run \ + ./bin/zipformer.dart \ + --model ./sherpa-onnx-zipformer-audio-tagging-2024-04-09/model.int8.onnx \ + --labels ./sherpa-onnx-zipformer-audio-tagging-2024-04-09/class_labels_indices.csv \ + --wav ./sherpa-onnx-zipformer-audio-tagging-2024-04-09/test_wavs/$w.wav +done diff --git a/dart-api-examples/non-streaming-asr/bin/nemo-ctc.dart b/dart-api-examples/non-streaming-asr/bin/nemo-ctc.dart index 2565862bb..d57b81653 100644 --- a/dart-api-examples/non-streaming-asr/bin/nemo-ctc.dart +++ b/dart-api-examples/non-streaming-asr/bin/nemo-ctc.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/nemo-transducer.dart b/dart-api-examples/non-streaming-asr/bin/nemo-transducer.dart index 3df8095c6..49d0bb83f 100644 --- a/dart-api-examples/non-streaming-asr/bin/nemo-transducer.dart +++ b/dart-api-examples/non-streaming-asr/bin/nemo-transducer.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/paraformer-itn.dart b/dart-api-examples/non-streaming-asr/bin/paraformer-itn.dart index c8d2c0801..3347c9e3c 100644 --- a/dart-api-examples/non-streaming-asr/bin/paraformer-itn.dart +++ b/dart-api-examples/non-streaming-asr/bin/paraformer-itn.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/paraformer.dart b/dart-api-examples/non-streaming-asr/bin/paraformer.dart index 15f45a1c5..c0e2af8d1 100644 --- a/dart-api-examples/non-streaming-asr/bin/paraformer.dart +++ b/dart-api-examples/non-streaming-asr/bin/paraformer.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/sense-voice.dart b/dart-api-examples/non-streaming-asr/bin/sense-voice.dart index d02fcaf6d..5af4bc80e 100644 --- a/dart-api-examples/non-streaming-asr/bin/sense-voice.dart +++ b/dart-api-examples/non-streaming-asr/bin/sense-voice.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/telespeech-ctc.dart b/dart-api-examples/non-streaming-asr/bin/telespeech-ctc.dart index 633baabef..62c72a923 100644 --- a/dart-api-examples/non-streaming-asr/bin/telespeech-ctc.dart +++ b/dart-api-examples/non-streaming-asr/bin/telespeech-ctc.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/whisper.dart b/dart-api-examples/non-streaming-asr/bin/whisper.dart index 1fffcd835..b59b8b384 100644 --- a/dart-api-examples/non-streaming-asr/bin/whisper.dart +++ b/dart-api-examples/non-streaming-asr/bin/whisper.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/non-streaming-asr/bin/zipformer-transducer.dart b/dart-api-examples/non-streaming-asr/bin/zipformer-transducer.dart index 3df8095c6..49d0bb83f 100644 --- a/dart-api-examples/non-streaming-asr/bin/zipformer-transducer.dart +++ b/dart-api-examples/non-streaming-asr/bin/zipformer-transducer.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; diff --git a/dart-api-examples/tts/bin/coqui.dart b/dart-api-examples/tts/bin/coqui.dart index 264d671b6..1acc5b6a6 100644 --- a/dart-api-examples/tts/bin/coqui.dart +++ b/dart-api-examples/tts/bin/coqui.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; @@ -65,5 +64,5 @@ void main(List arguments) async { samples: audio.samples, sampleRate: audio.sampleRate, ); - print('Saved to ${outputWav}'); + print('Saved to $outputWav'); } diff --git a/dart-api-examples/tts/bin/piper.dart b/dart-api-examples/tts/bin/piper.dart index 327407324..d457e86c6 100644 --- a/dart-api-examples/tts/bin/piper.dart +++ b/dart-api-examples/tts/bin/piper.dart @@ -80,5 +80,5 @@ void main(List arguments) async { samples: audio.samples, sampleRate: audio.sampleRate, ); - print('Saved to ${outputWav}'); + print('Saved to $outputWav'); } diff --git a/dart-api-examples/tts/bin/zh.dart b/dart-api-examples/tts/bin/zh.dart index 44770ab3f..f8fc0d4b7 100644 --- a/dart-api-examples/tts/bin/zh.dart +++ b/dart-api-examples/tts/bin/zh.dart @@ -1,6 +1,5 @@ // Copyright (c) 2024 Xiaomi Corporation import 'dart:io'; -import 'dart:typed_data'; import 'package:args/args.dart'; import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; @@ -82,5 +81,5 @@ void main(List arguments) async { samples: audio.samples, sampleRate: audio.sampleRate, ); - print('Saved to ${outputWav}'); + print('Saved to $outputWav'); } diff --git a/dart-api-examples/vad/bin/vad.dart b/dart-api-examples/vad/bin/vad.dart index 5baccd2f6..a2416c97b 100644 --- a/dart-api-examples/vad/bin/vad.dart +++ b/dart-api-examples/vad/bin/vad.dart @@ -77,5 +77,5 @@ void main(List arguments) async { sherpa_onnx.writeWave( filename: outputWav, samples: s, sampleRate: waveData.sampleRate); - print('Saved to ${outputWav}'); + print('Saved to $outputWav'); } diff --git a/flutter/sherpa_onnx/lib/sherpa_onnx.dart b/flutter/sherpa_onnx/lib/sherpa_onnx.dart index 4cc1e7621..119e43217 100644 --- a/flutter/sherpa_onnx/lib/sherpa_onnx.dart +++ b/flutter/sherpa_onnx/lib/sherpa_onnx.dart @@ -2,6 +2,7 @@ import 'dart:io'; import 'dart:ffi'; +export 'src/audio_tagging.dart'; export 'src/feature_config.dart'; export 'src/keyword_spotter.dart'; export 'src/offline_recognizer.dart'; diff --git a/flutter/sherpa_onnx/lib/src/audio_tagging.dart b/flutter/sherpa_onnx/lib/src/audio_tagging.dart new file mode 100644 index 000000000..6c650b30c --- /dev/null +++ b/flutter/sherpa_onnx/lib/src/audio_tagging.dart @@ -0,0 +1,144 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:ffi'; +import 'package:ffi/ffi.dart'; + +import './offline_stream.dart'; +import './sherpa_onnx_bindings.dart'; + +class OfflineZipformerAudioTaggingModelConfig { + const OfflineZipformerAudioTaggingModelConfig({this.model = ''}); + + @override + String toString() { + return 'OfflineZipformerAudioTaggingModelConfig(model: $model)'; + } + + final String model; +} + +class AudioTaggingModelConfig { + AudioTaggingModelConfig( + {this.zipformer = const OfflineZipformerAudioTaggingModelConfig(), + this.ced = '', + this.numThreads = 1, + this.provider = 'cpu', + this.debug = true}); + + @override + String toString() { + return 'AudioTaggingModelConfig(zipformer: $zipformer, ced: $ced, numThreads: $numThreads, provider: $provider, debug: $debug)'; + } + + final OfflineZipformerAudioTaggingModelConfig zipformer; + final String ced; + final int numThreads; + final String provider; + final bool debug; +} + +class AudioTaggingConfig { + AudioTaggingConfig({required this.model, this.labels = ''}); + + @override + String toString() { + return 'AudioTaggingConfig(model: $model, labels: $labels)'; + } + + final AudioTaggingModelConfig model; + final String labels; +} + +class AudioEvent { + AudioEvent({required this.name, required this.index, required this.prob}); + + @override + String toString() { + return 'AudioEvent(name: $name, index: $index, prob: $prob)'; + } + + final String name; + final int index; + final double prob; +} + +class AudioTagging { + AudioTagging._({required this.ptr, required this.config}); + + // The user has to invoke AudioTagging.free() to avoid memory leak. + factory AudioTagging({required AudioTaggingConfig config}) { + final c = calloc(); + + final zipformerPtr = config.model.zipformer.model.toNativeUtf8(); + c.ref.model.zipformer.model = zipformerPtr; + + final cedPtr = config.model.ced.toNativeUtf8(); + c.ref.model.ced = cedPtr; + + c.ref.model.numThreads = config.model.numThreads; + + final providerPtr = config.model.provider.toNativeUtf8(); + c.ref.model.provider = providerPtr; + + c.ref.model.debug = config.model.debug ? 1 : 0; + + final labelsPtr = config.labels.toNativeUtf8(); + c.ref.labels = labelsPtr; + + final ptr = + SherpaOnnxBindings.sherpaOnnxCreateAudioTagging?.call(c) ?? nullptr; + + calloc.free(labelsPtr); + calloc.free(providerPtr); + calloc.free(cedPtr); + calloc.free(zipformerPtr); + calloc.free(c); + + return AudioTagging._(ptr: ptr, config: config); + } + + void free() { + SherpaOnnxBindings.sherpaOnnxDestroyAudioTagging?.call(ptr); + ptr = nullptr; + } + + /// The user has to invoke stream.free() on the returned instance + /// to avoid memory leak + OfflineStream createStream() { + final p = SherpaOnnxBindings.sherpaOnnxAudioTaggingCreateOfflineStream + ?.call(ptr) ?? + nullptr; + return OfflineStream(ptr: p); + } + + List compute({required OfflineStream stream, required int topK}) { + final pp = SherpaOnnxBindings.sherpaOnnxAudioTaggingCompute + ?.call(ptr, stream.ptr, topK) ?? + nullptr; + + final ans = []; + + if (pp == nullptr) { + return ans; + } + + var i = 0; + while (pp[i] != nullptr) { + final p = pp[i]; + + final name = p.ref.name.toDartString(); + final index = p.ref.index; + final prob = p.ref.prob; + final e = AudioEvent(name: name, index: index, prob: prob); + ans.add(e); + + i += 1; + } + + SherpaOnnxBindings.sherpaOnnxAudioTaggingFreeResults?.call(pp); + + return ans; + } + + Pointer ptr; + final AudioTaggingConfig config; +} diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index 7685f4f2d..7ed3461ad 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -2,6 +2,41 @@ import 'dart:ffi'; import 'package:ffi/ffi.dart'; +final class SherpaOnnxOfflineZipformerAudioTaggingModelConfig extends Struct { + external Pointer model; +} + +final class SherpaOnnxAudioTaggingModelConfig extends Struct { + external SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer; + external Pointer ced; + + @Int32() + external int numThreads; + + @Int32() + external int debug; + + external Pointer provider; +} + +final class SherpaOnnxAudioTaggingConfig extends Struct { + external SherpaOnnxAudioTaggingModelConfig model; + external Pointer labels; + + @Int32() + external int topK; +} + +final class SherpaOnnxAudioEvent extends Struct { + external Pointer name; + + @Int32() + external int index; + + @Float() + external double prob; +} + final class SherpaOnnxOfflineTtsVitsModelConfig extends Struct { external Pointer model; external Pointer lexicon; @@ -303,6 +338,8 @@ final class SherpaOnnxKeywordSpotterConfig extends Struct { external Pointer keywordsFile; } +final class SherpaOnnxAudioTagging extends Opaque {} + final class SherpaOnnxKeywordSpotter extends Opaque {} final class SherpaOnnxOfflineTts extends Opaque {} @@ -323,6 +360,40 @@ final class SherpaOnnxSpeakerEmbeddingExtractor extends Opaque {} final class SherpaOnnxSpeakerEmbeddingManager extends Opaque {} +typedef SherpaOnnxCreateAudioTaggingNative = Pointer + Function(Pointer); + +typedef SherpaOnnxCreateAudioTagging = SherpaOnnxCreateAudioTaggingNative; + +typedef SherpaOnnxDestroyAudioTaggingNative = Void Function( + Pointer); + +typedef SherpaOnnxDestroyAudioTagging = void Function( + Pointer); + +typedef SherpaOnnxAudioTaggingCreateOfflineStreamNative + = Pointer Function( + Pointer); + +typedef SherpaOnnxAudioTaggingCreateOfflineStream + = SherpaOnnxAudioTaggingCreateOfflineStreamNative; + +typedef SherpaOnnxAudioTaggingComputeNative + = Pointer> Function( + Pointer, + Pointer, + Int32); + +typedef SherpaOnnxAudioTaggingCompute + = Pointer> Function( + Pointer, Pointer, int); + +typedef SherpaOnnxAudioTaggingFreeResultsNative = Void Function( + Pointer>); + +typedef SherpaOnnxAudioTaggingFreeResults = void Function( + Pointer>); + typedef CreateKeywordSpotterNative = Pointer Function( Pointer); @@ -804,6 +875,13 @@ typedef SherpaOnnxFreeWaveNative = Void Function(Pointer); typedef SherpaOnnxFreeWave = void Function(Pointer); class SherpaOnnxBindings { + static SherpaOnnxCreateAudioTagging? sherpaOnnxCreateAudioTagging; + static SherpaOnnxDestroyAudioTagging? sherpaOnnxDestroyAudioTagging; + static SherpaOnnxAudioTaggingCreateOfflineStream? + sherpaOnnxAudioTaggingCreateOfflineStream; + static SherpaOnnxAudioTaggingCompute? sherpaOnnxAudioTaggingCompute; + static SherpaOnnxAudioTaggingFreeResults? sherpaOnnxAudioTaggingFreeResults; + static CreateKeywordSpotter? createKeywordSpotter; static DestroyKeywordSpotter? destroyKeywordSpotter; static CreateKeywordStream? createKeywordStream; @@ -958,6 +1036,33 @@ class SherpaOnnxBindings { static SherpaOnnxFreeWave? freeWave; static void init(DynamicLibrary dynamicLibrary) { + sherpaOnnxCreateAudioTagging ??= dynamicLibrary + .lookup>( + 'SherpaOnnxCreateAudioTagging') + .asFunction(); + + sherpaOnnxDestroyAudioTagging ??= dynamicLibrary + .lookup>( + 'SherpaOnnxDestroyAudioTagging') + .asFunction(); + + sherpaOnnxAudioTaggingCreateOfflineStream ??= dynamicLibrary + .lookup< + NativeFunction< + SherpaOnnxAudioTaggingCreateOfflineStreamNative>>( + 'SherpaOnnxAudioTaggingCreateOfflineStream') + .asFunction(); + + sherpaOnnxAudioTaggingCompute ??= dynamicLibrary + .lookup>( + 'SherpaOnnxAudioTaggingCompute') + .asFunction(); + + sherpaOnnxAudioTaggingFreeResults ??= dynamicLibrary + .lookup>( + 'SherpaOnnxAudioTaggingFreeResults') + .asFunction(); + createKeywordSpotter ??= dynamicLibrary .lookup>( 'SherpaOnnxCreateKeywordSpotter') diff --git a/scripts/dart/audio-tagging-pubspec.yaml b/scripts/dart/audio-tagging-pubspec.yaml new file mode 100644 index 000000000..5b3b0afd1 --- /dev/null +++ b/scripts/dart/audio-tagging-pubspec.yaml @@ -0,0 +1,18 @@ +name: audio_tagging + +description: > + This example demonstrates how to use the Dart API for audio tagging. + +version: 1.0.0 + +environment: + sdk: ^3.4.0 + +dependencies: + sherpa_onnx: + path: ../../flutter/sherpa_onnx + path: ^1.9.0 + args: ^2.5.0 + +dev_dependencies: + lints: ^3.0.0