From 0e7cec95a0500ad392a57f179c0e3598ce284a13 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Fri, 13 Sep 2024 15:39:22 -0400 Subject: [PATCH 01/12] Add translation-related files and utilities --- .github/workflows/ci.yaml | 95 ++ .gitignore | 24 + CMakeLists.txt | 118 +++ README.md | 94 +- cmake/BuildCTranslate2.cmake | 127 +++ cmake/BuildICU.cmake | 101 +++ cmake/BuildMyCurl.cmake | 73 ++ cmake/BuildSDL.cmake | 66 ++ cmake/BuildSentencepiece.cmake | 61 ++ cmake/BuildWhispercpp.cmake | 150 ++++ cmake/FetchOnnxruntime.cmake | 97 +++ examples/CMakeLists.txt | 0 examples/audio_capture.cpp | 100 +++ examples/audio_capture.h | 33 + examples/realtime_transcription.cpp | 57 ++ include/locaal.h | 0 src/model-utils/model-downloader-types.h | 28 + src/model-utils/model-downloader-ui.cpp | 256 ++++++ src/model-utils/model-downloader-ui.h | 61 ++ src/model-utils/model-downloader.cpp | 91 ++ src/model-utils/model-downloader.h | 15 + src/model-utils/model-find-utils.cpp | 50 ++ src/model-utils/model-find-utils.h | 14 + src/model-utils/model-infos.cpp | 234 +++++ src/transcription-filter-data.h | 158 ++++ src/transcription-utils.cpp | 162 ++++ src/transcription-utils.h | 52 ++ src/translation/language_codes.cpp | 256 ++++++ src/translation/language_codes.h | 12 + src/translation/translation-includes.h | 8 + .../translation-language-utils.cpp | 33 + src/translation/translation-language-utils.h | 8 + src/translation/translation-utils.cpp | 44 + src/translation/translation-utils.h | 8 + src/translation/translation.cpp | 212 +++++ src/translation/translation.h | 48 ++ src/whisper-utils/silero-vad-onnx.cpp | 353 ++++++++ src/whisper-utils/silero-vad-onnx.h | 115 +++ src/whisper-utils/token-buffer-thread.cpp | 413 +++++++++ src/whisper-utils/token-buffer-thread.h | 105 +++ src/whisper-utils/vad-processing.cpp | 377 ++++++++ src/whisper-utils/vad-processing.h | 18 + src/whisper-utils/whisper-language.h | 814 ++++++++++++++++++ src/whisper-utils/whisper-model-utils.cpp | 142 +++ 
src/whisper-utils/whisper-model-utils.h | 10 + src/whisper-utils/whisper-processing.cpp | 407 +++++++++ src/whisper-utils/whisper-processing.h | 38 + src/whisper-utils/whisper-utils.cpp | 161 ++++ src/whisper-utils/whisper-utils.h | 25 + 49 files changed, 5922 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/ci.yaml create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 cmake/BuildCTranslate2.cmake create mode 100644 cmake/BuildICU.cmake create mode 100644 cmake/BuildMyCurl.cmake create mode 100644 cmake/BuildSDL.cmake create mode 100644 cmake/BuildSentencepiece.cmake create mode 100644 cmake/BuildWhispercpp.cmake create mode 100644 cmake/FetchOnnxruntime.cmake create mode 100644 examples/CMakeLists.txt create mode 100644 examples/audio_capture.cpp create mode 100644 examples/audio_capture.h create mode 100644 examples/realtime_transcription.cpp create mode 100644 include/locaal.h create mode 100644 src/model-utils/model-downloader-types.h create mode 100644 src/model-utils/model-downloader-ui.cpp create mode 100644 src/model-utils/model-downloader-ui.h create mode 100644 src/model-utils/model-downloader.cpp create mode 100644 src/model-utils/model-downloader.h create mode 100644 src/model-utils/model-find-utils.cpp create mode 100644 src/model-utils/model-find-utils.h create mode 100644 src/model-utils/model-infos.cpp create mode 100644 src/transcription-filter-data.h create mode 100644 src/transcription-utils.cpp create mode 100644 src/transcription-utils.h create mode 100644 src/translation/language_codes.cpp create mode 100644 src/translation/language_codes.h create mode 100644 src/translation/translation-includes.h create mode 100644 src/translation/translation-language-utils.cpp create mode 100644 src/translation/translation-language-utils.h create mode 100644 src/translation/translation-utils.cpp create mode 100644 src/translation/translation-utils.h create mode 100644 src/translation/translation.cpp create mode 
100644 src/translation/translation.h create mode 100644 src/whisper-utils/silero-vad-onnx.cpp create mode 100644 src/whisper-utils/silero-vad-onnx.h create mode 100644 src/whisper-utils/token-buffer-thread.cpp create mode 100644 src/whisper-utils/token-buffer-thread.h create mode 100644 src/whisper-utils/vad-processing.cpp create mode 100644 src/whisper-utils/vad-processing.h create mode 100644 src/whisper-utils/whisper-language.h create mode 100644 src/whisper-utils/whisper-model-utils.cpp create mode 100644 src/whisper-utils/whisper-model-utils.h create mode 100644 src/whisper-utils/whisper-processing.cpp create mode 100644 src/whisper-utils/whisper-processing.h create mode 100644 src/whisper-utils/whisper-utils.cpp create mode 100644 src/whisper-utils/whisper-utils.h diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..4508fb8 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,95 @@ +name: CI and Release + +on: + push: + branches: [ main ] + tags: + - 'v*' + pull_request: + branches: [ main ] + +jobs: + build: + name: ${{ matrix.os }}-build + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + include: + - os: ubuntu-latest + cmake_generator: "Unix Makefiles" + - os: macos-latest + cmake_generator: "Unix Makefiles" + - os: windows-latest + cmake_generator: "Visual Studio 17 2022" + + steps: + - uses: actions/checkout@v2 + + - name: Create Build Environment + run: cmake -E make_directory ${{runner.workspace}}/build + + - name: Configure CMake + working-directory: ${{runner.workspace}}/build + run: cmake $GITHUB_WORKSPACE -G "${{ matrix.cmake_generator }}" + + - name: Build + working-directory: ${{runner.workspace}}/build + run: cmake --build . --config Release + + - name: Package + working-directory: ${{runner.workspace}}/build + run: | + cmake --install . 
--prefix installed + 7z a ${{ matrix.os }}-package.zip ./installed/* + + - name: Upload Artifact + uses: actions/upload-artifact@v2 + with: + name: ${{ matrix.os }}-package + path: ${{runner.workspace}}/build/${{ matrix.os }}-package.zip + + release: + needs: build + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/v') + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Download all artifacts + uses: actions/download-artifact@v2 + + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref }} + release_name: Release ${{ github.ref }} + draft: false + prerelease: false + + - name: Upload Release Assets + uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const fs = require('fs').promises; + const { repo: { owner, repo }, sha } = context; + + for (const file of await fs.readdir('.')) { + if (file.endsWith('-package')) { + const path = `${file}/${file}.zip`; + console.log(`Uploading ${path}`); + await github.repos.uploadReleaseAsset({ + owner, + repo, + release_id: ${{ steps.create_release.outputs.id }}, + name: `${file}.zip`, + data: await fs.readFile(path) + }); + } + } diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..081b432 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# C++ and CMake gitignore + +# Compiled object files +*.o +*.obj + +# Compiled dynamic libraries +*.so +*.dylib +*.dll + +# Executables +*.exe +*.out + +# CMake build directory +build/ + +# CMake generated files +CMakeCache.txt +CMakeFiles/ +CMakeScripts/ +cmake_install.cmake +Makefile diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..db26651 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,118 @@ +cmake_minimum_required(VERSION 3.12) +project(locaal) + +set(CMAKE_CXX_STANDARD 11) + +set(USE_SYSTEM_CURL + OFF + CACHE STRING "Use system cURL") + 
+if(USE_SYSTEM_CURL) + find_package(CURL REQUIRED) + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${CURL_LIBRARIES}") + target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${CURL_INCLUDE_DIRS}") +else() + include(cmake/BuildMyCurl.cmake) + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE libcurl) +endif() + +if(WIN32) + if(DEFINED ENV{ACCELERATION}) + set(ACCELERATION + $ENV{ACCELERATION} + CACHE STRING "Acceleration to use" FORCE) + endif() + if(NOT DEFINED ACCELERATION) + set(ACCELERATION + "cpu" + CACHE STRING "Acceleration to use") + endif() + set_property(CACHE ACCELERATION PROPERTY STRINGS "cpu" "hipblas" "cuda") +endif() + +include(cmake/BuildWhispercpp.cmake) +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp) + +include(cmake/BuildCTranslate2.cmake) +include(cmake/BuildSentencepiece.cmake) +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ct2 sentencepiece) + +set(USE_SYSTEM_ONNXRUNTIME + OFF + CACHE STRING "Use system ONNX Runtime") + +set(DISABLE_ONNXRUNTIME_GPU + OFF + CACHE STRING "Disables GPU support of ONNX Runtime (Only valid on Linux)") + +if(DISABLE_ONNXRUNTIME_GPU) + target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE DISABLE_ONNXRUNTIME_GPU) +endif() + +if(USE_SYSTEM_ONNXRUNTIME) + if(OS_LINUX) + find_package(Onnxruntime 1.16.3 REQUIRED) + set(Onnxruntime_INCLUDE_PATH + ${Onnxruntime_INCLUDE_DIR} ${Onnxruntime_INCLUDE_DIR}/onnxruntime + ${Onnxruntime_INCLUDE_DIR}/onnxruntime/core/session ${Onnxruntime_INCLUDE_DIR}/onnxruntime/core/providers/cpu) + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIBRARIES}") + target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${Onnxruntime_INCLUDE_PATH}") + else() + message(FATAL_ERROR "System ONNX Runtime is only supported on Linux!") + endif() +else() + include(cmake/FetchOnnxruntime.cmake) +endif() + +include(cmake/BuildICU.cmake) +# Add ICU to the target +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ICU) 
+target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC ${ICU_INCLUDE_DIR}) + +# Add your source files here +set(SOURCES + src/locaal.cpp +) + +# Add your header files here +set(HEADERS + include/locaal.h +) + +# Create the shared library +add_library(${CMAKE_PROJECT_NAME} SHARED ${SOURCES} ${HEADERS}) + +# Set the include directories +target_include_directories(locaal PUBLIC include) + +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) + +# Install the library +install(TARGETS ${CMAKE_PROJECT_NAME} + EXPORT ${CMAKE_PROJECT_NAME}Targets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +# Install the headers +install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${CMAKE_PROJECT_NAME}) + +# Install the cmake config files +install(EXPORT ${CMAKE_PROJECT_NAME}Targets + FILE ${CMAKE_PROJECT_NAME}Targets.cmake + NAMESPACE ${CMAKE_PROJECT_NAME}:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${CMAKE_PROJECT_NAME}) + +configure_package_config_file(cmake/${CMAKE_PROJECT_NAME}Config.cmake.in + ${CMAKE_PROJECT_NAME}Config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${CMAKE_PROJECT_NAME}) + +write_basic_package_version_file(${CMAKE_PROJECT_NAME}ConfigVersion.cmake + VERSION 1.0 + COMPATIBILITY AnyNewerVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_PROJECT_NAME}Config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_PROJECT_NAME}ConfigVersion.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${CMAKE_PROJECT_NAME}) diff --git a/README.md b/README.md index 2898813..6218573 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,92 @@ -# locaal-sdk -SDK for agnostic on-device AI +# Real-time Transcription and Translation Library + +## Overview + +This C++ library provides real-time transcription and translation capabilities using Whisper.cpp and CTranslate2. 
It's designed to work on-device without relying on cloud services, making it suitable for applications requiring privacy and offline functionality.
+
+Key features:
+- Cross-platform support (macOS, Windows, Linux)
+- Real-time speech-to-text transcription
+- On-device translation
+- Built with CMake for easy integration and compilation
+
+## Prerequisites
+
+Before building the library, ensure you have the following installed:
+- C++ compiler with C++17 support
+- CMake (version 3.12 or higher)
+- Git
+
+## Building the Library
+
+### macOS
+
+1. Open Terminal and navigate to the project directory.
+2. Run the following commands:
+
+```bash
+mkdir build && cd build
+cmake ..
+make
+```
+
+### Windows
+
+1. Open Command Prompt or PowerShell and navigate to the project directory.
+2. Run the following commands:
+
+```cmd
+mkdir build
+cd build
+cmake .. -G "Visual Studio 16 2019" -A x64
+cmake --build . --config Release
+```
+
+Note: Adjust the Visual Studio version as needed.
+
+### Linux
+
+1. Open a terminal and navigate to the project directory.
+2. Run the following commands:
+
+```bash
+mkdir build && cd build
+cmake ..
+make
+```
+
+## Usage
+
+After building the library, you can include it in your C++ project. Here's a basic example of how to use the library:
+
+```cpp
+#include <locaal.h>
+
+int main() {
+    // Initialize the library
+    locaal::TranscriptionTranslation tt;
+
+    // Start real-time transcription
+    tt.startTranscription();
+
+    // Translate text
+    std::string translated = tt.translate("Hello, world!", "en", "fr");
+
+    return 0;
+}
+```
+
+For more detailed usage instructions and API documentation, please refer to the `docs` folder and the `examples` folder.
+
+## Contributing
+
+Contributions are welcome! Please feel free to submit a Pull Request.
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
+ +## Acknowledgments + +- [Whisper.cpp](https://github.com/ggerganov/whisper.cpp) +- [CTranslate2](https://github.com/OpenNMT/CTranslate2) diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake new file mode 100644 index 0000000..0d60561 --- /dev/null +++ b/cmake/BuildCTranslate2.cmake @@ -0,0 +1,127 @@ +# build the CTranslate2 library from source https://github.com/OpenNMT/CTranslate2.git + +include(ExternalProject) +include(FetchContent) + +if(APPLE) + + FetchContent_Declare( + ctranslate2_fetch + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-macos-Release-1.2.0.tar.gz + URL_HASH SHA256=9029F19B0F50E5EDC14473479EDF0A983F7D6FA00BE61DC1B01BF8AA7F1CDB1B) + FetchContent_MakeAvailable(ctranslate2_fetch) + + add_library(ct2 INTERFACE) + target_link_libraries(ct2 INTERFACE "-framework Accelerate" ${ctranslate2_fetch_SOURCE_DIR}/lib/libctranslate2.a + ${ctranslate2_fetch_SOURCE_DIR}/lib/libcpu_features.a) + set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include) + target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32 -Wno-comma) + +elseif(WIN32) + + # check ACCELERATION environment variable + if(NOT DEFINED ACCELERATION) + message(FATAL_ERROR "Please set ACCELERATION to either `cpu`, `hipblas`, or `cuda`") + endif() + + if(${ACCELERATION} STREQUAL "cpu" OR ${ACCELERATION} STREQUAL "hipblas") + FetchContent_Declare( + ctranslate2_fetch + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cpu.zip + URL_HASH SHA256=30ff8b2499b8d3b5a6c4d6f7f8ddbc89e745ff06e0050b645e3b7c9b369451a3) + else() + # add compile definitions for CUDA + add_compile_definitions(POLYGLOT_WITH_CUDA) + add_compile_definitions(POLYGLOT_CUDA_VERSION="12.2.0") + + FetchContent_Declare( + ctranslate2_fetch + URL 
https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cuda12.2.0.zip + URL_HASH SHA256=131724d510f9f2829970953a1bc9e4e8fb7b4cbc8218e32270dcfe6172a51558) + endif() + + FetchContent_MakeAvailable(ctranslate2_fetch) + + add_library(ct2 INTERFACE) + target_link_libraries(ct2 INTERFACE ${ctranslate2_fetch_SOURCE_DIR}/lib/ctranslate2.lib) + set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include) + target_compile_options(ct2 INTERFACE /wd4267 /wd4244 /wd4305 /wd4996 /wd4099) + + file(GLOB CT2_DLLS ${ctranslate2_fetch_SOURCE_DIR}/bin/*.dll) + install(FILES ${CT2_DLLS} DESTINATION "obs-plugins/64bit") +else() + # build cpu_features from source + set(CPU_FEATURES_VERSION "0.9.0") + set(CPU_FEATURES_URL "https://github.com/google/cpu_features.git") + set(CPU_FEATURES_CMAKE_ARGS -DBUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF) + ExternalProject_Add( + cpu_features_build + GIT_REPOSITORY ${CPU_FEATURES_URL} + GIT_TAG v${CPU_FEATURES_VERSION} + GIT_PROGRESS 1 + BUILD_COMMAND ${CMAKE_COMMAND} --build --config ${CMAKE_BUILD_TYPE} + CMAKE_GENERATOR ${CMAKE_GENERATOR} + INSTALL_COMMAND ${CMAKE_COMMAND} --install --config ${CMAKE_BUILD_TYPE} + BUILD_BYPRODUCTS /lib/${CMAKE_STATIC_LIBRARY_PREFIX}cpu_features${CMAKE_STATIC_LIBRARY_SUFFIX} + CMAKE_ARGS -DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM} -DCMAKE_INSTALL_PREFIX= + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} ${CPU_FEATURES_CMAKE_ARGS} + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON) + ExternalProject_Get_Property(cpu_features_build INSTALL_DIR) + + add_library(cpu_features STATIC IMPORTED GLOBAL) + add_dependencies(cpu_features cpu_features_build) + set_target_properties( + cpu_features PROPERTIES IMPORTED_LOCATION + ${INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}cpu_features${CMAKE_STATIC_LIBRARY_SUFFIX}) + set_target_properties(cpu_features PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include) + + # build 
CTranslate2 from source + set(CT2_VERSION "4.1.1") + set(CT2_URL "https://github.com/OpenNMT/CTranslate2.git") + set(CT2_OPENBLAS_CMAKE_ARGS -DWITH_OPENBLAS=OFF) + + set(CT2_CMAKE_PLATFORM_OPTIONS -DBUILD_SHARED_LIBS=OFF -DOPENMP_RUNTIME=NONE -DCMAKE_POSITION_INDEPENDENT_CODE=ON) + set(CT2_LIB_INSTALL_LOCATION lib/${CMAKE_SHARED_LIBRARY_PREFIX}ctranslate2${CMAKE_STATIC_LIBRARY_SUFFIX}) + + ExternalProject_Add( + ct2_build + GIT_REPOSITORY ${CT2_URL} + GIT_TAG v${CT2_VERSION} + GIT_PROGRESS 1 + BUILD_COMMAND ${CMAKE_COMMAND} --build --config ${CMAKE_BUILD_TYPE} + CMAKE_GENERATOR ${CMAKE_GENERATOR} + INSTALL_COMMAND ${CMAKE_COMMAND} --install --config ${CMAKE_BUILD_TYPE} + BUILD_BYPRODUCTS /${CT2_LIB_INSTALL_LOCATION} + CMAKE_ARGS -DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM} + -DCMAKE_INSTALL_PREFIX= + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DWITH_CUDA=OFF + -DWITH_MKL=OFF + -DWITH_TESTS=OFF + -DWITH_EXAMPLES=OFF + -DWITH_TFLITE=OFF + -DWITH_TRT=OFF + -DWITH_PYTHON=OFF + -DWITH_SERVER=OFF + -DWITH_COVERAGE=OFF + -DWITH_PROFILING=OFF + -DBUILD_CLI=OFF + ${CT2_OPENBLAS_CMAKE_ARGS} + ${CT2_CMAKE_PLATFORM_OPTIONS} + LOG_CONFIGURE ON + LOG_BUILD ON + LOG_INSTALL ON) + + ExternalProject_Get_Property(ct2_build INSTALL_DIR) + + add_library(ct2::ct2 STATIC IMPORTED GLOBAL) + add_dependencies(ct2::ct2 ct2_build cpu_features_build) + set_target_properties(ct2::ct2 PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/${CT2_LIB_INSTALL_LOCATION}) + set_target_properties(ct2::ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include) + + add_library(ct2 INTERFACE) + target_link_libraries(ct2 INTERFACE ct2::ct2 cpu_features) + +endif() diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake new file mode 100644 index 0000000..a3c575d --- /dev/null +++ b/cmake/BuildICU.cmake @@ -0,0 +1,101 @@ +include(FetchContent) +include(ExternalProject) + +set(ICU_VERSION "75.1") +set(ICU_VERSION_UNDERSCORE "75_1") +set(ICU_VERSION_DASH "75-1") +set(ICU_VERSION_NO_MINOR "75") + 
+if(WIN32) + set(ICU_URL + "https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-Win64-MSVC2022.zip" + ) + set(ICU_HASH "SHA256=7ac9c0dc6ccc1ec809c7d5689b8d831c5b8f6b11ecf70fdccc55f7ae8731ac8f") + + FetchContent_Declare( + ICU_build + URL ${ICU_URL} + URL_HASH ${ICU_HASH}) + + FetchContent_MakeAvailable(ICU_build) + + # Assuming the ZIP structure, adjust paths as necessary + set(ICU_INCLUDE_DIR "${icu_build_SOURCE_DIR}/include") + set(ICU_LIBRARY_DIR "${icu_build_SOURCE_DIR}/lib64") + set(ICU_BINARY_DIR "${icu_build_SOURCE_DIR}/bin64") + + # Define the library names + set(ICU_LIBRARIES icudt icuuc icuin) + + foreach(lib ${ICU_LIBRARIES}) + # Add ICU library + find_library( + ICU_LIB_${lib} + NAMES ${lib} + PATHS ${ICU_LIBRARY_DIR} + NO_DEFAULT_PATH REQUIRED) + # find the dll + find_file( + ICU_DLL_${lib} + NAMES ${lib}${ICU_VERSION_NO_MINOR}.dll + PATHS ${ICU_BINARY_DIR} + NO_DEFAULT_PATH) + # Copy the DLLs to the output directory + install(FILES ${ICU_DLL_${lib}} DESTINATION "obs-plugins/64bit") + # add the library + add_library(ICU::${lib} SHARED IMPORTED GLOBAL) + set_target_properties(ICU::${lib} PROPERTIES IMPORTED_LOCATION "${ICU_LIB_${lib}}" IMPORTED_IMPLIB + "${ICU_LIB_${lib}}") + endforeach() +else() + set(ICU_URL + "https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-src.tgz" + ) + set(ICU_HASH "SHA256=cb968df3e4d2e87e8b11c49a5d01c787bd13b9545280fc6642f826527618caef") + if(APPLE) + set(ICU_PLATFORM "MacOSX") + set(TARGET_ARCH -arch\ $ENV{MACOS_ARCH}) + set(ICU_BUILD_ENV_VARS CFLAGS=${TARGET_ARCH} CXXFLAGS=${TARGET_ARCH} LDFLAGS=${TARGET_ARCH}) + else() + set(ICU_PLATFORM "Linux") + set(ICU_BUILD_ENV_VARS CFLAGS=-fPIC CXXFLAGS=-fPIC LDFLAGS=-fPIC) + endif() + + ExternalProject_Add( + ICU_build + DOWNLOAD_EXTRACT_TIMESTAMP true + GIT_REPOSITORY "https://github.com/unicode-org/icu.git" + GIT_TAG "release-${ICU_VERSION_DASH}" + 
CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${ICU_BUILD_ENV_VARS} <SOURCE_DIR>/icu4c/source/runConfigureICU
+                      ${ICU_PLATFORM} --prefix=<INSTALL_DIR> --enable-static --disable-shared
+    BUILD_COMMAND make -j4
+    BUILD_BYPRODUCTS
+      <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}icudata${CMAKE_STATIC_LIBRARY_SUFFIX}
+      <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}icuuc${CMAKE_STATIC_LIBRARY_SUFFIX}
+      <INSTALL_DIR>/lib/${CMAKE_STATIC_LIBRARY_PREFIX}icui18n${CMAKE_STATIC_LIBRARY_SUFFIX}
+    INSTALL_COMMAND make install
+    BUILD_IN_SOURCE 1)
+
+  ExternalProject_Get_Property(ICU_build INSTALL_DIR)
+
+  set(ICU_INCLUDE_DIR "${INSTALL_DIR}/include")
+  set(ICU_LIBRARY_DIR "${INSTALL_DIR}/lib")
+
+  set(ICU_LIBRARIES icudata icuuc icui18n)
+
+  foreach(lib ${ICU_LIBRARIES})
+    add_library(ICU::${lib} STATIC IMPORTED GLOBAL)
+    add_dependencies(ICU::${lib} ICU_build)
+    set(ICU_LIBRARY "${ICU_LIBRARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${lib}${CMAKE_STATIC_LIBRARY_SUFFIX}")
+    set_target_properties(ICU::${lib} PROPERTIES IMPORTED_LOCATION "${ICU_LIBRARY}" INTERFACE_INCLUDE_DIRECTORIES
+                                                 "${ICU_INCLUDE_DIR}")
+  endforeach(lib ${ICU_LIBRARIES})
+endif()
+
+# Create an interface target for ICU
+add_library(ICU INTERFACE)
+add_dependencies(ICU ICU_build)
+foreach(lib ${ICU_LIBRARIES})
+  target_link_libraries(ICU INTERFACE ICU::${lib})
+endforeach()
+target_include_directories(ICU SYSTEM INTERFACE $<BUILD_INTERFACE:${ICU_INCLUDE_DIR}>)
diff --git a/cmake/BuildMyCurl.cmake b/cmake/BuildMyCurl.cmake
new file mode 100644
index 0000000..10d3e05
--- /dev/null
+++ b/cmake/BuildMyCurl.cmake
@@ -0,0 +1,73 @@
+include(FetchContent)
+
+set(LibCurl_VERSION "8.4.0-3")
+set(LibCurl_BASEURL "https://github.com/occ-ai/obs-ai-libcurl-dep/releases/download/${LibCurl_VERSION}")
+
+if(${CMAKE_BUILD_TYPE} STREQUAL Release OR ${CMAKE_BUILD_TYPE} STREQUAL RelWithDebInfo)
+  set(LibCurl_BUILD_TYPE Release)
+else()
+  set(LibCurl_BUILD_TYPE Debug)
+endif()
+
+if(APPLE)
+  if(LibCurl_BUILD_TYPE STREQUAL Release)
+    set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-macos-${LibCurl_VERSION}-Release.tar.gz")
+    set(LibCurl_HASH
SHA256=5ef7bfed2c2bca17ba562aede6a3c3eb465b8d7516cff86ca0f0d0337de951e1) + else() + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-macos-${LibCurl_VERSION}-Debug.tar.gz") + set(LibCurl_HASH SHA256=da0801168eac5103e6b27bfd0f56f82e0617f85e4e6c69f476071dbba273403b) + endif() +elseif(MSVC) + if(LibCurl_BUILD_TYPE STREQUAL Release) + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-windows-${LibCurl_VERSION}-Release.zip") + set(LibCurl_HASH SHA256=bf4d4cd7d741712a2913df0994258d11aabe22c9a305c9f336ed59e76f351adf) + else() + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-windows-${LibCurl_VERSION}-Debug.zip") + set(LibCurl_HASH SHA256=9fe20e677ffb0d7dd927b978d532e23574cdb1923e2d2ca7c5e42f1fff2ec529) + endif() +else() + if(LibCurl_BUILD_TYPE STREQUAL Release) + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-linux-${LibCurl_VERSION}-Release.tar.gz") + set(LibCurl_HASH SHA256=f2cd80b7d3288fe5b4c90833bcbf0bde7c9574bc60eddb13015df19c5a09f56b) + else() + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-linux-${LibCurl_VERSION}-Debug.tar.gz") + set(LibCurl_HASH SHA256=6a41d3daef98acc3172b3702118dcf1cccbde923f3836ed2f4f3ed7301e47b8b) + endif() +endif() + +FetchContent_Declare( + libcurl_fetch + URL ${LibCurl_URL} + URL_HASH ${LibCurl_HASH}) +FetchContent_MakeAvailable(libcurl_fetch) + +if(MSVC) + set(libcurl_fetch_lib_location "${libcurl_fetch_SOURCE_DIR}/lib/libcurl.lib") + set(libcurl_fetch_link_libs "\$;\$;\$;\$") +else() + find_package(ZLIB REQUIRED) + set(libcurl_fetch_lib_location "${libcurl_fetch_SOURCE_DIR}/lib/libcurl.a") + if(UNIX AND NOT APPLE) + find_package(OpenSSL REQUIRED) + set(libcurl_fetch_link_libs "\$;\$;\$") + else() + set(libcurl_fetch_link_libs + "-framework SystemConfiguration;-framework Security;-framework CoreFoundation;-framework CoreServices;ZLIB::ZLIB" + ) + endif() +endif() + +# Create imported target +add_library(libcurl STATIC IMPORTED) + +set_target_properties( + libcurl + PROPERTIES INTERFACE_COMPILE_DEFINITIONS "CURL_STATICLIB" + INTERFACE_INCLUDE_DIRECTORIES 
"${libcurl_fetch_SOURCE_DIR}/include" + INTERFACE_LINK_LIBRARIES "${libcurl_fetch_link_libs}") +set_property( + TARGET libcurl + APPEND + PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(libcurl PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "C" IMPORTED_LOCATION_RELEASE + ${libcurl_fetch_lib_location}) diff --git a/cmake/BuildSDL.cmake b/cmake/BuildSDL.cmake new file mode 100644 index 0000000..772d757 --- /dev/null +++ b/cmake/BuildSDL.cmake @@ -0,0 +1,66 @@ +# Include ExternalProject module +include(ExternalProject) + +# Set SDL version +set(SDL_VERSION "2.28.2") + +# Define SDL installation directory +set(SDL_INSTALL_DIR "${CMAKE_BINARY_DIR}/sdl_install") + +# ExternalProject for SDL2 +ExternalProject_Add( + SDL2_external + GIT_REPOSITORY https://github.com/libsdl-org/SDL.git + GIT_TAG release-${SDL_VERSION} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${SDL_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + -DSDL_STATIC=ON + -DSDL_SHARED=OFF + BUILD_BYPRODUCTS + "${SDL_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}SDL2${CMAKE_STATIC_LIBRARY_SUFFIX}" + "${SDL_INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}SDL2main${CMAKE_STATIC_LIBRARY_SUFFIX}" +) + +# Create interface library for SDL2 +add_library(SDL2 INTERFACE) +add_dependencies(SDL2 SDL2_external) + +# Set include directories for the interface library +target_include_directories(SDL2 INTERFACE + $ + $ +) + +# Link SDL2 and SDL2main libraries +target_link_libraries(SDL2 INTERFACE + $ + $ +) + +# Platform-specific configurations +if(WIN32) + target_link_libraries(SDL2 INTERFACE imm32 version winmm setupapi) +elseif(APPLE) + target_link_libraries(SDL2 INTERFACE "-framework Cocoa" "-framework IOKit" "-framework CoreAudio" "-framework CoreVideo") +else() + # Linux + find_package(Threads REQUIRED) + target_link_libraries(SDL2 INTERFACE Threads::Threads dl) +endif() + +# Export the targets +install(TARGETS SDL2 + EXPORT SDL2Targets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME 
DESTINATION bin + INCLUDES DESTINATION include +) + +# Export the targets file +install(EXPORT SDL2Targets + FILE SDL2Targets.cmake + NAMESPACE SDL2:: + DESTINATION lib/cmake/SDL2 +) diff --git a/cmake/BuildSentencepiece.cmake b/cmake/BuildSentencepiece.cmake new file mode 100644 index 0000000..024283e --- /dev/null +++ b/cmake/BuildSentencepiece.cmake @@ -0,0 +1,61 @@ +# build sentencepiece from "https://github.com/google/sentencepiece.git" + +if(APPLE) + + include(FetchContent) + + FetchContent_Declare( + sentencepiece_fetch + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libsentencepiece-macos-Release-1.1.1.tar.gz + URL_HASH SHA256=c911f1e84ea94925a8bc3fd3257185b2e18395075509c8659cc7003a979e0b32) + FetchContent_MakeAvailable(sentencepiece_fetch) + add_library(sentencepiece INTERFACE) + target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/libsentencepiece.a) + set_target_properties(sentencepiece PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ${sentencepiece_fetch_SOURCE_DIR}/include) +elseif(WIN32) + + FetchContent_Declare( + sentencepiece_fetch + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/sentencepiece-windows-0.2.0-Release.zip + URL_HASH SHA256=846699c7fa1e8918b71ed7f2bd5cd60e47e51105e1d84e3192919b4f0f10fdeb) + FetchContent_MakeAvailable(sentencepiece_fetch) + add_library(sentencepiece INTERFACE) + target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/sentencepiece.lib) + set_target_properties(sentencepiece PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ${sentencepiece_fetch_SOURCE_DIR}/include) + +else() + + set(SP_URL + "https://github.com/google/sentencepiece.git" + CACHE STRING "URL of sentencepiece repository") + + set(SP_CMAKE_OPTIONS -DSPM_ENABLE_SHARED=OFF) + set(SENTENCEPIECE_INSTALL_LIB_LOCATION lib/${CMAKE_STATIC_LIBRARY_PREFIX}sentencepiece${CMAKE_STATIC_LIBRARY_SUFFIX}) + + include(ExternalProject) + + ExternalProject_Add( + 
sentencepiece_build + GIT_REPOSITORY ${SP_URL} + GIT_TAG v0.1.99 + BUILD_COMMAND ${CMAKE_COMMAND} --build --config ${CMAKE_BUILD_TYPE} + CMAKE_GENERATOR ${CMAKE_GENERATOR} + INSTALL_COMMAND ${CMAKE_COMMAND} --install --config ${CMAKE_BUILD_TYPE} + BUILD_BYPRODUCTS /${SENTENCEPIECE_INSTALL_LIB_LOCATION} + CMAKE_ARGS -DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM} -DCMAKE_INSTALL_PREFIX= + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} ${SP_CMAKE_OPTIONS}) + ExternalProject_Get_Property(sentencepiece_build INSTALL_DIR) + + add_library(libsentencepiece STATIC IMPORTED GLOBAL) + add_dependencies(libsentencepiece sentencepiece_build) + set_target_properties(libsentencepiece PROPERTIES IMPORTED_LOCATION + ${INSTALL_DIR}/${SENTENCEPIECE_INSTALL_LIB_LOCATION}) + + add_library(sentencepiece INTERFACE) + add_dependencies(sentencepiece libsentencepiece) + target_link_libraries(sentencepiece INTERFACE libsentencepiece) + target_include_directories(sentencepiece INTERFACE ${INSTALL_DIR}/include) + +endif() diff --git a/cmake/BuildWhispercpp.cmake b/cmake/BuildWhispercpp.cmake new file mode 100644 index 0000000..66e0f0b --- /dev/null +++ b/cmake/BuildWhispercpp.cmake @@ -0,0 +1,150 @@ +include(ExternalProject) +include(FetchContent) + +set(PREBUILT_WHISPERCPP_VERSION "0.0.6") +set(PREBUILT_WHISPERCPP_URL_BASE + "https://github.com/occ-ai/occ-ai-dep-whispercpp/releases/download/${PREBUILT_WHISPERCPP_VERSION}") + +if(APPLE) + # check the "MACOS_ARCH" env var to figure out if this is x86 or arm64 + if($ENV{MACOS_ARCH} STREQUAL "x86_64") + set(WHISPER_CPP_HASH "454abee900a96a0a10a91f631ff797bdbdf2df0d2a819479a409634c9be1e12c") + elseif($ENV{MACOS_ARCH} STREQUAL "arm64") + set(WHISPER_CPP_HASH "f726388cc494f6fca864c860af6c1bc2932c3dc823ef92197b1e29f088425668") + else() + message( + FATAL_ERROR + "The MACOS_ARCH environment variable is not set to a valid value. 
Please set it to either `x86_64` or `arm64`") + endif() + set(WHISPER_CPP_URL + "${PREBUILT_WHISPERCPP_URL_BASE}/whispercpp-macos-$ENV{MACOS_ARCH}-${PREBUILT_WHISPERCPP_VERSION}.tar.gz") + + FetchContent_Declare( + whispercpp_fetch + URL ${WHISPER_CPP_URL} + URL_HASH SHA256=${WHISPER_CPP_HASH}) + FetchContent_MakeAvailable(whispercpp_fetch) + + add_library(Whispercpp::Whisper STATIC IMPORTED) + set_target_properties( + Whispercpp::Whisper + PROPERTIES IMPORTED_LOCATION + ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX}) + set_target_properties(Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ${whispercpp_fetch_SOURCE_DIR}/include) + add_library(Whispercpp::GGML STATIC IMPORTED) + set_target_properties( + Whispercpp::GGML + PROPERTIES IMPORTED_LOCATION + ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}ggml${CMAKE_STATIC_LIBRARY_SUFFIX}) + + add_library(Whispercpp::CoreML STATIC IMPORTED) + set_target_properties( + Whispercpp::CoreML + PROPERTIES + IMPORTED_LOCATION + ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}whisper.coreml${CMAKE_STATIC_LIBRARY_SUFFIX}) + +elseif(WIN32) + if(NOT DEFINED ACCELERATION) + message(FATAL_ERROR "ACCELERATION is not set. 
Please set it to either `cpu`, `cuda` or `hipblas`") + endif() + + set(ARCH_PREFIX ${ACCELERATION}) + set(WHISPER_CPP_URL + "${PREBUILT_WHISPERCPP_URL_BASE}/whispercpp-windows-${ARCH_PREFIX}-${PREBUILT_WHISPERCPP_VERSION}.zip") + if(${ACCELERATION} STREQUAL "cpu") + set(WHISPER_CPP_HASH "126c5d859e902b4cd0f2cd09304a68750f1dbc6a7aa62e280cfd56c51a6a1c95") + add_compile_definitions("LOCALVOCAL_WITH_CPU") + elseif(${ACCELERATION} STREQUAL "cuda") + set(WHISPER_CPP_HASH "5b9592c311a7f1612894ca0b36f6bd4effb6a46acd03d33924df56c52f566779") + add_compile_definitions("LOCALVOCAL_WITH_CUDA") + elseif(${ACCELERATION} STREQUAL "hipblas") + set(WHISPER_CPP_HASH "c306ecce16cd10f377fdefbf7bb252abac8e6638a2637f82b1f1f32dd2cb4e39") + add_compile_definitions("LOCALVOCAL_WITH_HIPBLAS") + else() + message( + FATAL_ERROR + "The ACCELERATION environment variable is not set to a valid value. Please set it to either `cpu` or `cuda` or `hipblas`" + ) + endif() + + FetchContent_Declare( + whispercpp_fetch + URL ${WHISPER_CPP_URL} + URL_HASH SHA256=${WHISPER_CPP_HASH} + DOWNLOAD_EXTRACT_TIMESTAMP TRUE) + FetchContent_MakeAvailable(whispercpp_fetch) + + add_library(Whispercpp::Whisper SHARED IMPORTED) + set_target_properties( + Whispercpp::Whisper + PROPERTIES IMPORTED_LOCATION + ${whispercpp_fetch_SOURCE_DIR}/bin/${CMAKE_SHARED_LIBRARY_PREFIX}whisper${CMAKE_SHARED_LIBRARY_SUFFIX}) + set_target_properties( + Whispercpp::Whisper + PROPERTIES IMPORTED_IMPLIB + ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX}) + set_target_properties(Whispercpp::Whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ${whispercpp_fetch_SOURCE_DIR}/include) + + if(${ACCELERATION} STREQUAL "cpu") + # add openblas to the link line + add_library(Whispercpp::OpenBLAS STATIC IMPORTED) + set_target_properties(Whispercpp::OpenBLAS PROPERTIES IMPORTED_LOCATION + ${whispercpp_fetch_SOURCE_DIR}/lib/libopenblas.dll.a) + endif() + + # glob all dlls in the bin directory and 
install them + file(GLOB WHISPER_DLLS ${whispercpp_fetch_SOURCE_DIR}/bin/*.dll) + install(FILES ${WHISPER_DLLS} DESTINATION "obs-plugins/64bit") +else() + if(${CMAKE_BUILD_TYPE} STREQUAL Release OR ${CMAKE_BUILD_TYPE} STREQUAL RelWithDebInfo) + set(Whispercpp_BUILD_TYPE Release) + else() + set(Whispercpp_BUILD_TYPE Debug) + endif() + set(Whispercpp_Build_GIT_TAG "v1.6.2") + set(WHISPER_EXTRA_CXX_FLAGS "-fPIC") + set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_BLAS=OFF -DWHISPER_CUBLAS=OFF -DWHISPER_OPENBLAS=OFF) + + # On Linux build a static Whisper library + ExternalProject_Add( + Whispercpp_Build + DOWNLOAD_EXTRACT_TIMESTAMP true + GIT_REPOSITORY https://github.com/ggerganov/whisper.cpp.git + GIT_TAG ${Whispercpp_Build_GIT_TAG} + BUILD_COMMAND ${CMAKE_COMMAND} --build --config ${Whispercpp_BUILD_TYPE} + BUILD_BYPRODUCTS /lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX} + CMAKE_GENERATOR ${CMAKE_GENERATOR} + INSTALL_COMMAND ${CMAKE_COMMAND} --install --config ${Whispercpp_BUILD_TYPE} && ${CMAKE_COMMAND} -E + copy /ggml.h /include + CONFIGURE_COMMAND + ${CMAKE_COMMAND} -E env ${WHISPER_ADDITIONAL_ENV} ${CMAKE_COMMAND} -B -G + ${CMAKE_GENERATOR} -DCMAKE_INSTALL_PREFIX= -DCMAKE_BUILD_TYPE=${Whispercpp_BUILD_TYPE} + -DCMAKE_GENERATOR_PLATFORM=${CMAKE_GENERATOR_PLATFORM} -DCMAKE_OSX_DEPLOYMENT_TARGET=10.13 + -DCMAKE_OSX_ARCHITECTURES=${CMAKE_OSX_ARCHITECTURES_} -DCMAKE_CXX_FLAGS=${WHISPER_EXTRA_CXX_FLAGS} + -DCMAKE_C_FLAGS=${WHISPER_EXTRA_CXX_FLAGS} -DBUILD_SHARED_LIBS=OFF -DWHISPER_BUILD_TESTS=OFF + -DWHISPER_BUILD_EXAMPLES=OFF ${WHISPER_ADDITIONAL_CMAKE_ARGS}) + + ExternalProject_Get_Property(Whispercpp_Build INSTALL_DIR) + + # add the static Whisper library to the link line + add_library(Whispercpp::Whisper STATIC IMPORTED) + set_target_properties( + Whispercpp::Whisper + PROPERTIES IMPORTED_LOCATION + ${INSTALL_DIR}/lib/static/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX}) + set_target_properties(Whispercpp::Whisper 
PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include) +endif() + +add_library(Whispercpp INTERFACE) +add_dependencies(Whispercpp Whispercpp_Build) +target_link_libraries(Whispercpp INTERFACE Whispercpp::Whisper) +if(WIN32 AND "${ACCELERATION}" STREQUAL "cpu") + target_link_libraries(Whispercpp INTERFACE Whispercpp::OpenBLAS) +endif() +if(APPLE) + target_link_libraries(Whispercpp INTERFACE "-framework Accelerate -framework CoreML -framework Metal") + target_link_libraries(Whispercpp INTERFACE Whispercpp::GGML Whispercpp::CoreML) +endif(APPLE) diff --git a/cmake/FetchOnnxruntime.cmake b/cmake/FetchOnnxruntime.cmake new file mode 100644 index 0000000..0ed2975 --- /dev/null +++ b/cmake/FetchOnnxruntime.cmake @@ -0,0 +1,97 @@ +include(FetchContent) + +set(CUSTOM_ONNXRUNTIME_URL + "" + CACHE STRING "URL of a downloaded ONNX Runtime tarball") + +set(CUSTOM_ONNXRUNTIME_HASH + "" + CACHE STRING "Hash of a downloaded ONNX Runtime tarball") + +set(Onnxruntime_VERSION "1.17.1") + +if(CUSTOM_ONNXRUNTIME_URL STREQUAL "") + set(USE_PREDEFINED_ONNXRUNTIME ON) +else() + if(CUSTOM_ONNXRUNTIME_HASH STREQUAL "") + message(FATAL_ERROR "Both of CUSTOM_ONNXRUNTIME_URL and CUSTOM_ONNXRUNTIME_HASH must be present!") + else() + set(USE_PREDEFINED_ONNXRUNTIME OFF) + endif() +endif() + +if(USE_PREDEFINED_ONNXRUNTIME) + set(Onnxruntime_BASEURL "https://github.com/microsoft/onnxruntime/releases/download/v${Onnxruntime_VERSION}") + + if(APPLE) + set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-osx-universal2-${Onnxruntime_VERSION}.tgz") + set(Onnxruntime_HASH SHA256=9FA57FA6F202A373599377EF75064AE568FDA8DA838632B26A86024C7378D306) + elseif(MSVC) + set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-win-x64-${Onnxruntime_VERSION}.zip") + set(Onnxruntime_HASH SHA256=4802AF9598DB02153D7DA39432A48823FF69B2FB4B59155461937F20782AA91C) + else() + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(Onnxruntime_URL 
"${Onnxruntime_BASEURL}/onnxruntime-linux-aarch64-${Onnxruntime_VERSION}.tgz") + set(Onnxruntime_HASH SHA256=70B6F536BB7AB5961D128E9DBD192368AC1513BFFB74FE92F97AAC342FBD0AC1) + else() + set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-linux-x64-gpu-${Onnxruntime_VERSION}.tgz") + set(Onnxruntime_HASH SHA256=613C53745EA4960ED368F6B3AB673558BB8561C84A8FA781B4EA7FB4A4340BE4) + endif() + endif() +else() + set(Onnxruntime_URL "${CUSTOM_ONNXRUNTIME_URL}") + set(Onnxruntime_HASH "${CUSTOM_ONNXRUNTIME_HASH}") +endif() + +FetchContent_Declare( + onnxruntime + URL ${Onnxruntime_URL} + URL_HASH ${Onnxruntime_HASH}) +FetchContent_MakeAvailable(onnxruntime) + +if(APPLE) + set(Onnxruntime_LIB "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.${Onnxruntime_VERSION}.dylib") + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIB}") + target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${onnxruntime_SOURCE_DIR}/include") + target_sources(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIB}") + set_property(SOURCE "${Onnxruntime_LIB}" PROPERTY MACOSX_PACKAGE_LOCATION Frameworks) + source_group("Frameworks" FILES "${Onnxruntime_LIB}") + # add a codesigning step + add_custom_command( + TARGET "${CMAKE_PROJECT_NAME}" + PRE_BUILD VERBATIM + COMMAND /usr/bin/codesign --force --verify --verbose --sign "${CODESIGN_IDENTITY}" "${Onnxruntime_LIB}") + add_custom_command( + TARGET "${CMAKE_PROJECT_NAME}" + POST_BUILD + COMMAND + ${CMAKE_INSTALL_NAME_TOOL} -change "@rpath/libonnxruntime.${Onnxruntime_VERSION}.dylib" + "@loader_path/../Frameworks/libonnxruntime.${Onnxruntime_VERSION}.dylib" $) +elseif(MSVC) + add_library(Ort INTERFACE) + set(Onnxruntime_LIB_NAMES onnxruntime;onnxruntime_providers_shared) + foreach(lib_name IN LISTS Onnxruntime_LIB_NAMES) + add_library(Ort::${lib_name} SHARED IMPORTED) + set_target_properties(Ort::${lib_name} PROPERTIES IMPORTED_IMPLIB ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.lib) + set_target_properties(Ort::${lib_name} PROPERTIES 
IMPORTED_LOCATION ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll) + set_target_properties(Ort::${lib_name} PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${onnxruntime_SOURCE_DIR}/include) + target_link_libraries(Ort INTERFACE Ort::${lib_name}) + install(FILES ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll DESTINATION "obs-plugins/64bit") + endforeach() + + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Ort) + +else() + if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + set(Onnxruntime_LINK_LIBS "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so.${Onnxruntime_VERSION}") + set(Onnxruntime_INSTALL_LIBS ${Onnxruntime_LINK_LIBS}) + else() + set(Onnxruntime_LINK_LIBS "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so.${Onnxruntime_VERSION}") + set(Onnxruntime_INSTALL_LIBS ${Onnxruntime_LINK_LIBS} + "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime_providers_shared.so") + endif() + target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ${Onnxruntime_LINK_LIBS}) + target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${onnxruntime_SOURCE_DIR}/include") + install(FILES ${Onnxruntime_INSTALL_LIBS} DESTINATION "${CMAKE_INSTALL_LIBDIR}/obs-plugins/${CMAKE_PROJECT_NAME}") + set_target_properties(${CMAKE_PROJECT_NAME} PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_PROJECT_NAME}") +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 0000000..e69de29 diff --git a/examples/audio_capture.cpp b/examples/audio_capture.cpp new file mode 100644 index 0000000..0d18470 --- /dev/null +++ b/examples/audio_capture.cpp @@ -0,0 +1,100 @@ +#include "audio_capture.h" +#include +#include + +AudioCapture::AudioCapture(int buffer_duration_ms) + : buffer_duration_ms(buffer_duration_ms), is_capturing(false) {} + +AudioCapture::~AudioCapture() { + if (device_id != 0) { + SDL_CloseAudioDevice(device_id); + } +} + +bool AudioCapture::initialize(int device_index, int requested_sample_rate) { + SDL_AudioSpec desired_spec, obtained_spec; + + SDL_zero(desired_spec); + 
desired_spec.freq = requested_sample_rate; + desired_spec.format = AUDIO_F32; + desired_spec.channels = 1; + desired_spec.samples = 1024; + desired_spec.callback = [](void* userdata, uint8_t* stream, int len) { + static_cast(userdata)->processAudio(stream, len); + }; + desired_spec.userdata = this; + + device_id = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(device_index, 1), 1, &desired_spec, &obtained_spec, 0); + if (device_id == 0) { + SDL_Log("Failed to open audio device: %s", SDL_GetError()); + return false; + } + + sample_rate = obtained_spec.freq; + audio_buffer.resize(sample_rate * buffer_duration_ms / 1000); + write_position = 0; + buffer_size = 0; + + return true; +} + +bool AudioCapture::startCapture() { + if (device_id == 0) return false; + is_capturing = true; + SDL_PauseAudioDevice(device_id, 0); + return true; +} + +bool AudioCapture::stopCapture() { + if (device_id == 0) return false; + is_capturing = false; + SDL_PauseAudioDevice(device_id, 1); + return true; +} + +bool AudioCapture::resetBuffer() { + std::lock_guard lock(buffer_mutex); + write_position = 0; + buffer_size = 0; + std::fill(audio_buffer.begin(), audio_buffer.end(), 0.0f); + return true; +} + +void AudioCapture::processAudio(uint8_t* stream, int length) { + if (!is_capturing) return; + + std::lock_guard lock(buffer_mutex); + int sample_count = length / sizeof(float); + const float* input = reinterpret_cast(stream); + + for (int i = 0; i < sample_count; ++i) { + audio_buffer[write_position] = input[i]; + write_position = (write_position + 1) % audio_buffer.size(); + if (buffer_size < audio_buffer.size()) { + ++buffer_size; + } + } +} + +void AudioCapture::getAudioData(int duration_ms, std::vector& output) { + std::lock_guard lock(buffer_mutex); + int requested_samples = (duration_ms * sample_rate) / 1000; + requested_samples = std::min(requested_samples, static_cast(buffer_size)); + + output.resize(requested_samples); + size_t start_pos = (write_position - requested_samples + 
audio_buffer.size()) % audio_buffer.size(); + + for (int i = 0; i < requested_samples; ++i) { + output[i] = audio_buffer[(start_pos + i) % audio_buffer.size()]; + } +} + +bool handleSDLEvents() { + SDL_Event event; + while (SDL_PollEvent(&event)) { + if (event.type == SDL_QUIT) { + return false; + } + } + return true; +} diff --git a/examples/audio_capture.h b/examples/audio_capture.h new file mode 100644 index 0000000..87e9cd8 --- /dev/null +++ b/examples/audio_capture.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include +#include +#include + +class AudioCapture { +public: + AudioCapture(int buffer_duration_ms); + ~AudioCapture(); + + bool initialize(int device_index, int sample_rate); + bool startCapture(); + bool stopCapture(); + bool resetBuffer(); + + void processAudio(uint8_t* stream, int length); + void getAudioData(int duration_ms, std::vector& output); + +private: + SDL_AudioDeviceID device_id = 0; + int buffer_duration_ms = 0; + int sample_rate = 0; + std::atomic_bool is_capturing; + std::mutex buffer_mutex; + std::vector audio_buffer; + size_t write_position = 0; + size_t buffer_size = 0; +}; + +bool handleSDLEvents(); diff --git a/examples/realtime_transcription.cpp b/examples/realtime_transcription.cpp new file mode 100644 index 0000000..896a18f --- /dev/null +++ b/examples/realtime_transcription.cpp @@ -0,0 +1,57 @@ +#include +#include "audio_capture.h" + +#include + +int main() { + // Initialize the library + locaal::Transcription tt; + + // Set the transcription parameters (language, model, etc.) 
+ tt.setTranscriptionParams("en-US"); + + tt.setModelDownloadCallbacks([](const std::string &model_name, const std::string &model_path) { + std::cout << "Model downloaded: " << model_name << " at " << model_path << std::endl; + }, [](const std::string &model_name, const std::string &model_path) { + std::cout << "Model download failed: " << model_name << " at " << model_path << std::endl; + }, [](const std::string &model_name, const std::string &model_path) { + std::cout << "Model download progress: " << model_name << " at " << model_path << std::endl; + }); + + // Set the callbacks for the transcription + tt.setTranscriptionCallback([](const locaal::DetectionResultWithText &result) { + // Print the transcription result + std::cout << "Transcription: " << result.text << std::endl; + }); + + // Start real-time transcription background thread + tt.startTranscription(); + + // Start capturing audio from the microphone + AudioCapture audio_capture(1000); + if (!audio_capture.initialize(0, 16000)) { + std::cerr << "Failed to initialize audio capture" << std::endl; + return 1; + } + if (!audio_capture.startCapture()) { + std::cerr << "Failed to start audio capture" << std::endl; + return 1; + } + + // Main loop + while (true) { + // Handle SDL events + if (!handleSDLEvents()) { + break; + } + + // Get audio data from the audio capture + std::vector audio_data; + audio_capture.getAudioData(1000, audio_data); + + // Process the audio data for transcription + tt.processAudio(audio_data.data(), audio_data.size()); + } + + return 0; +} diff --git a/include/locaal.h b/include/locaal.h new file mode 100644 index 0000000..e69de29 diff --git a/src/model-utils/model-downloader-types.h b/src/model-utils/model-downloader-types.h new file mode 100644 index 0000000..3d24d96 --- /dev/null +++ b/src/model-utils/model-downloader-types.h @@ -0,0 +1,28 @@ +#ifndef MODEL_DOWNLOADER_TYPES_H +#define MODEL_DOWNLOADER_TYPES_H + +#include +#include +#include +#include + +typedef std::function + 
download_finished_callback_t; + +struct ModelFileDownloadInfo { + std::string url; + std::string sha256; +}; + +enum ModelType { MODEL_TYPE_TRANSCRIPTION, MODEL_TYPE_TRANSLATION }; + +struct ModelInfo { + std::string friendly_name; + std::string local_folder_name; + ModelType type; + std::vector files; +}; + +extern std::map models_info; + +#endif /* MODEL_DOWNLOADER_TYPES_H */ diff --git a/src/model-utils/model-downloader-ui.cpp b/src/model-utils/model-downloader-ui.cpp new file mode 100644 index 0000000..a428e20 --- /dev/null +++ b/src/model-utils/model-downloader-ui.cpp @@ -0,0 +1,256 @@ +#include "model-downloader-ui.h" +#include "plugin-support.h" + +#include + +#include + +size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) +{ + size_t written = fwrite(ptr, size, nmemb, stream); + return written; +} + +ModelDownloader::ModelDownloader(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback_, + QWidget *parent) + : QDialog(parent), + download_finished_callback(download_finished_callback_) +{ + this->setWindowTitle("LocalVocal: Downloading model..."); + this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint); + this->setFixedSize(300, 100); + // Bring the dialog to the front + this->activateWindow(); + this->raise(); + + this->layout = new QVBoxLayout(this); + + // Add a label for the model name + QLabel *model_name_label = new QLabel(this); + model_name_label->setText(QString::fromStdString(model_info.friendly_name)); + model_name_label->setAlignment(Qt::AlignCenter); + this->layout->addWidget(model_name_label); + + this->progress_bar = new QProgressBar(this); + this->progress_bar->setRange(0, 100); + this->progress_bar->setValue(0); + this->progress_bar->setAlignment(Qt::AlignCenter); + // Show progress as a percentage + this->progress_bar->setFormat("%p%"); + this->layout->addWidget(this->progress_bar); + + this->download_thread = new QThread(); + this->download_worker = new 
ModelDownloadWorker(model_info); + this->download_worker->moveToThread(this->download_thread); + + connect(this->download_thread, &QThread::started, this->download_worker, + &ModelDownloadWorker::download_model); + connect(this->download_worker, &ModelDownloadWorker::download_progress, this, + &ModelDownloader::update_progress); + connect(this->download_worker, &ModelDownloadWorker::download_finished, this, + &ModelDownloader::download_finished); + connect(this->download_worker, &ModelDownloadWorker::download_finished, + this->download_thread, &QThread::quit); + connect(this->download_worker, &ModelDownloadWorker::download_finished, + this->download_worker, &ModelDownloadWorker::deleteLater); + connect(this->download_worker, &ModelDownloadWorker::download_error, this, + &ModelDownloader::show_error); + connect(this->download_thread, &QThread::finished, this->download_thread, + &QThread::deleteLater); + + this->download_thread->start(); +} + +void ModelDownloader::closeEvent(QCloseEvent *e) +{ + if (!this->mPrepareToClose) + e->ignore(); + else { + QDialog::closeEvent(e); + deleteLater(); + } +} + +void ModelDownloader::close() +{ + this->mPrepareToClose = true; + + QDialog::close(); +} + +void ModelDownloader::update_progress(int progress) +{ + this->progress_bar->setValue(progress); +} + +void ModelDownloader::download_finished(const std::string &path) +{ + // Call the callback with the path to the downloaded model + this->download_finished_callback(0, path); + // Close the dialog + this->close(); +} + +void ModelDownloader::show_error(const std::string &reason) +{ + this->setWindowTitle("Download failed!"); + this->progress_bar->setFormat("Download failed!"); + this->progress_bar->setAlignment(Qt::AlignCenter); + this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }"); + // Add a label to show the error + QLabel *error_label = new QLabel(this); + error_label->setText(QString::fromStdString(reason)); + 
error_label->setAlignment(Qt::AlignCenter); + // Color red + error_label->setStyleSheet("QLabel { color : red; }"); + this->layout->addWidget(error_label); + // Add a button to close the dialog + QPushButton *close_button = new QPushButton("Close", this); + this->layout->addWidget(close_button); + connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); + this->download_finished_callback(1, ""); +} + +ModelDownloadWorker::ModelDownloadWorker(const ModelInfo &model_info_) : model_info(model_info_) {} + +std::string get_filename_from_url(const std::string &url) +{ + auto lastSlashPos = url.find_last_of("/"); + auto queryPos = url.find("?", lastSlashPos); + if (queryPos == std::string::npos) { + return url.substr(lastSlashPos + 1); + } else { + return url.substr(lastSlashPos + 1, queryPos - lastSlashPos - 1); + } +} + +void ModelDownloadWorker::download_model() +{ + char *config_folder = obs_module_config_path("models"); +#ifdef _WIN32 + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), NULL, 0); + std::wstring config_folder_str(count, 0); + MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), &config_folder_str[0], + count); + obs_log(LOG_INFO, "Download: Config models folder: %S", config_folder_str.c_str()); +#else + std::string config_folder_str = config_folder; + obs_log(LOG_INFO, "Download: Config models folder: %s", config_folder_str.c_str()); +#endif + bfree(config_folder); + + const std::filesystem::path module_config_models_folder = + std::filesystem::absolute(config_folder_str); + + // Check if the config folder exists + if (!std::filesystem::exists(module_config_models_folder)) { + obs_log(LOG_WARNING, "Config folder does not exist: %s", + module_config_models_folder.string().c_str()); + // Create the config folder + if (!std::filesystem::create_directories(module_config_models_folder)) { + obs_log(LOG_ERROR, "Failed to create config folder: %s", + 
module_config_models_folder.string().c_str()); + emit download_error("Failed to create config folder."); + return; + } + } + + const std::string model_local_config_path = + (module_config_models_folder / model_info.local_folder_name).string(); + + obs_log(LOG_INFO, "Model save path: %s", model_local_config_path.c_str()); + + if (!std::filesystem::exists(model_local_config_path)) { + // model folder does not exist, create it + if (!std::filesystem::create_directories(model_local_config_path)) { + obs_log(LOG_ERROR, "Failed to create model folder: %s", + model_local_config_path.c_str()); + emit download_error("Failed to create model folder."); + return; + } + } + + CURL *curl = curl_easy_init(); + if (curl) { + for (auto &model_download_file : this->model_info.files) { + obs_log(LOG_INFO, "Model URL: %s", model_download_file.url.c_str()); + + const std::string model_filename = + get_filename_from_url(model_download_file.url); + const std::string model_file_save_path = + (std::filesystem::path(model_local_config_path) / model_filename) + .string(); + if (std::filesystem::exists(model_file_save_path)) { + obs_log(LOG_INFO, "Model file already exists: %s", + model_file_save_path.c_str()); + continue; + } + + FILE *fp = fopen(model_file_save_path.c_str(), "wb"); + if (fp == nullptr) { + obs_log(LOG_ERROR, "Failed to open model file for writing %s.", + model_file_save_path.c_str()); + emit download_error("Failed to open file."); + return; + } + curl_easy_setopt(curl, CURLOPT_URL, model_download_file.url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, + ModelDownloadWorker::progress_callback); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this); + // Follow redirects + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + 
obs_log(LOG_ERROR, "Failed to download model file %s.", + model_filename.c_str()); + emit download_error("Failed to download model file."); + } + fclose(fp); + } + curl_easy_cleanup(curl); + emit download_finished(model_local_config_path); + } else { + obs_log(LOG_ERROR, "Failed to initialize curl."); + emit download_error("Failed to initialize curl."); + } +} + +int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, + curl_off_t, curl_off_t) +{ + if (dltotal == 0) { + return 0; // Unknown progress + } + ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp; + if (worker == nullptr) { + obs_log(LOG_ERROR, "Worker is null."); + return 1; + } + int progress = (int)(dlnow * 100l / dltotal); + emit worker->download_progress(progress); + return 0; +} + +ModelDownloader::~ModelDownloader() +{ + if (this->download_thread != nullptr) { + if (this->download_thread->isRunning()) { + this->download_thread->quit(); + this->download_thread->wait(); + } + delete this->download_thread; + } + if (this->download_worker != nullptr) { + delete this->download_worker; + } +} + +ModelDownloadWorker::~ModelDownloadWorker() +{ + // Do nothing +} diff --git a/src/model-utils/model-downloader-ui.h b/src/model-utils/model-downloader-ui.h new file mode 100644 index 0000000..aaa0752 --- /dev/null +++ b/src/model-utils/model-downloader-ui.h @@ -0,0 +1,61 @@ +#ifndef MODEL_DOWNLOADER_UI_H +#define MODEL_DOWNLOADER_UI_H + +#include +#include + +#include +#include + +#include + +#include "model-downloader-types.h" + +class ModelDownloadWorker : public QObject { + Q_OBJECT +public: + ModelDownloadWorker(const ModelInfo &model_info_); + ~ModelDownloadWorker(); + +public slots: + void download_model(); + +signals: + void download_progress(int progress); + void download_finished(const std::string &path); + void download_error(const std::string &reason); + +private: + static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, + 
curl_off_t ultotal, curl_off_t ulnow); + ModelInfo model_info; +}; + +class ModelDownloader : public QDialog { + Q_OBJECT +public: + ModelDownloader(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback, + QWidget *parent = nullptr); + ~ModelDownloader(); + +public slots: + void update_progress(int progress); + void download_finished(const std::string &path); + void show_error(const std::string &reason); + +protected: + void closeEvent(QCloseEvent *e) override; + +private: + QVBoxLayout *layout; + QProgressBar *progress_bar; + QPointer download_thread; + QPointer download_worker; + // Callback for when the download is finished + download_finished_callback_t download_finished_callback; + bool mPrepareToClose; + void close(); +}; + +#endif // MODEL_DOWNLOADER_UI_H diff --git a/src/model-utils/model-downloader.cpp b/src/model-utils/model-downloader.cpp new file mode 100644 index 0000000..7f7f04d --- /dev/null +++ b/src/model-utils/model-downloader.cpp @@ -0,0 +1,91 @@ +#include "model-downloader.h" +#include "plugin-support.h" +#include "model-downloader-ui.h" +#include "model-find-utils.h" + +#include +#include + +std::string find_model_folder(const ModelInfo &model_info) +{ + if (model_info.friendly_name.empty() || model_info.local_folder_name.empty() || + model_info.files.empty()) { + obs_log(LOG_ERROR, "Model info is invalid."); + return ""; + } + + char *data_folder_models = obs_module_file("models"); + const std::filesystem::path module_data_models_folder = + std::filesystem::absolute(data_folder_models); + bfree(data_folder_models); + + const std::string model_local_data_path = + (module_data_models_folder / model_info.local_folder_name).string(); + + obs_log(LOG_INFO, "Checking if model '%s' exists in data...", + model_info.friendly_name.c_str()); + + if (!std::filesystem::exists(model_local_data_path)) { + obs_log(LOG_INFO, "Model not found in data: %s", model_local_data_path.c_str()); + } else { + obs_log(LOG_INFO, "Model 
folder found in data: %s", model_local_data_path.c_str()); + return model_local_data_path; + } + + // Check if model exists in the config folder + char *config_folder = obs_module_config_path("models"); + if (!config_folder) { + obs_log(LOG_INFO, "Config folder not set."); + return ""; + } +#ifdef _WIN32 + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), NULL, 0); + std::wstring config_folder_str(count, 0); + MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), &config_folder_str[0], + count); + obs_log(LOG_INFO, "Config models folder: %S", config_folder_str.c_str()); +#else + std::string config_folder_str = config_folder; + obs_log(LOG_INFO, "Config models folder: %s", config_folder_str.c_str()); +#endif + + const std::filesystem::path module_config_models_folder = + std::filesystem::absolute(config_folder_str); + bfree(config_folder); + + obs_log(LOG_INFO, "Checking if model '%s' exists in config...", + model_info.friendly_name.c_str()); + + const std::string model_local_config_path = + (module_config_models_folder / model_info.local_folder_name).string(); + + obs_log(LOG_INFO, "Looking for model in config: %s", model_local_config_path.c_str()); + if (std::filesystem::exists(model_local_config_path)) { + obs_log(LOG_INFO, "Model folder exists in config folder: %s", + model_local_config_path.c_str()); + return model_local_config_path; + } + + obs_log(LOG_INFO, "Model '%s' not found.", model_info.friendly_name.c_str()); + return ""; +} + +std::string find_model_bin_file(const ModelInfo &model_info) +{ + const std::string model_local_folder_path = find_model_folder(model_info); + if (model_local_folder_path.empty()) { + return ""; + } + + return find_bin_file_in_folder(model_local_folder_path); +} + +void download_model_with_ui_dialog(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback) +{ + // Start the model downloader UI + ModelDownloader 
*model_downloader = new ModelDownloader( + model_info, download_finished_callback, (QWidget *)obs_frontend_get_main_window()); + model_downloader->show(); +} diff --git a/src/model-utils/model-downloader.h b/src/model-utils/model-downloader.h new file mode 100644 index 0000000..3af9450 --- /dev/null +++ b/src/model-utils/model-downloader.h @@ -0,0 +1,15 @@ +#ifndef MODEL_DOWNLOADER_H +#define MODEL_DOWNLOADER_H + +#include + +#include "model-downloader-types.h" + +std::string find_model_folder(const ModelInfo &model_info); +std::string find_model_bin_file(const ModelInfo &model_info); + +// Start the model downloader UI dialog with a callback for when the download is finished +void download_model_with_ui_dialog(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback); + +#endif // MODEL_DOWNLOADER_H diff --git a/src/model-utils/model-find-utils.cpp b/src/model-utils/model-find-utils.cpp new file mode 100644 index 0000000..d2bb48f --- /dev/null +++ b/src/model-utils/model-find-utils.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include +#include + +#include + +#include "model-find-utils.h" +#include "plugin-support.h" + +std::string find_file_in_folder_by_name(const std::string &folder_path, + const std::string &file_name) +{ + for (const auto &entry : std::filesystem::directory_iterator(folder_path)) { + if (entry.path().filename() == file_name) { + return entry.path().string(); + } + } + return ""; +} + +// Find a file in a folder by expression +std::string find_file_in_folder_by_regex_expression(const std::string &folder_path, + const std::string &file_name_regex) +{ + for (const auto &entry : std::filesystem::directory_iterator(folder_path)) { + if (std::regex_match(entry.path().filename().string(), + std::regex(file_name_regex))) { + return entry.path().string(); + } + } + return ""; +} + +std::string find_bin_file_in_folder(const std::string &model_local_folder_path) +{ + // find .bin file in folder + for (const auto 
&entry : std::filesystem::directory_iterator(model_local_folder_path)) { + if (entry.path().extension() == ".bin") { + const std::string bin_file_path = entry.path().string(); + obs_log(LOG_INFO, "Model bin file found in folder: %s", + bin_file_path.c_str()); + return bin_file_path; + } + } + obs_log(LOG_ERROR, "Model bin file not found in folder: %s", + model_local_folder_path.c_str()); + return ""; +} diff --git a/src/model-utils/model-find-utils.h b/src/model-utils/model-find-utils.h new file mode 100644 index 0000000..72a3a6f --- /dev/null +++ b/src/model-utils/model-find-utils.h @@ -0,0 +1,14 @@ +#ifndef MODEL_FIND_UTILS_H +#define MODEL_FIND_UTILS_H + +#include + +#include "model-downloader-types.h" + +std::string find_file_in_folder_by_name(const std::string &folder_path, + const std::string &file_name); +std::string find_bin_file_in_folder(const std::string &path); +std::string find_file_in_folder_by_regex_expression(const std::string &folder_path, + const std::string &file_name_regex); + +#endif // MODEL_FIND_UTILS_H diff --git a/src/model-utils/model-infos.cpp b/src/model-utils/model-infos.cpp new file mode 100644 index 0000000..e978002 --- /dev/null +++ b/src/model-utils/model-infos.cpp @@ -0,0 +1,234 @@ +#include "model-downloader-types.h" + +std::map models_info = {{ + {"M2M-100 418M (495Mb)", + {"M2M-100 418M", + "m2m-100-418M", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/model.bin?download=true", + "D6703DD9F920FF896E45C3D97B490761BED5944937B90BBE6A7245F5652542D4"}, + { + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/config.json?download=true", + "4244772990E30069563E3DDFB4AD6DC95BDFD2AC3DE667EA8858C9B0A8433FA8", + }, + {"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/generation_config.json?download=true", + "AED76366507333DDBB8BD49960F23C82FE6446B3319A46A54BEFDB45324CCF61"}, + 
{"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/shared_vocabulary.json?download=true", + "7EB5D0FF184C6095C7C10F9911C0AEA492250ABD12854F9C3D787C64B1C6397E"}, + {"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "C1A4F86C3874D279AE1B2A05162858DB5DD6C61665D84223ED886CBCFF08FDA6"}, + {"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/tokenizer_config.json?download=true", + "AE54F15F0649BB05041CDADAD8485BA1FAF40BC33E6B4C2A74AE2D1AE5710FA2"}, + {"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/vocab.json?download=true", + "B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"}, + {"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}}, + {"M2M-100 1.2B (1.25Gb)", + {"M2M-100 1.2BM", + "m2m-100-1_2B", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/model.bin?download=true", + "C97DF052A558895317312470E1FF7CB8EAE5416F7AE16214A2983C6853DD3CE5"}, + { + "https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/config.json?download=true", + "4244772990E30069563E3DDFB4AD6DC95BDFD2AC3DE667EA8858C9B0A8433FA8", + }, + {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/generation_config.json?download=true", + "AED76366507333DDBB8BD49960F23C82FE6446B3319A46A54BEFDB45324CCF61"}, + {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/shared_vocabulary.json?download=true", + "7EB5D0FF184C6095C7C10F9911C0AEA492250ABD12854F9C3D787C64B1C6397E"}, + {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "C1A4F86C3874D279AE1B2A05162858DB5DD6C61665D84223ED886CBCFF08FDA6"}, + {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/tokenizer_config.json?download=true", + 
"1566A6CFA4F541A55594C9D5E090F530812D5DE7C94882EA3AF156962D9933AE"}, + {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/vocab.json?download=true", + "B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"}, + {"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}}, + {"NLLB 200 1.3B (1.4Gb)", + {"NLLB 200 1.3B", + "nllb-200-1.3b", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/model.bin?download=true", + "72D7533DC7A0E8F10F19A650D4E90FAF9CBFA899DB5411AD124BD5802BD91263"}, + { + "https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/config.json?download=true", + "0C2F6FA2057C7264D052FB4A62BA3476EEAE70487ACDDFA8E779A53A00CBF44C", + }, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/tokenizer.json?download=true", + "E316B82DE11D0F951F370943B3C438311629547285129B0B81DADABD01BCA665"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/shared_vocabulary.txt?download=true", + "A132A83330F45514C2476EB81D1D69B3C41762264D16CE0A7EA982E5D6C728E5"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "992BD4ED610D644D6823081937BCC91BB8878DD556CEA4AE5327F2480361330E"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/tokenizer_config.json?download=true", + "D1AA8C3697D3E35674F97B5B7E9C99D22B010F528E80140257D97316BE90D044"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-1.3B-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "14BB8DFB35C0FFDEA7BC01E56CEA38B9E3D5EFCDCB9C251D6B40538E1AAB555A"}}}}, + {"NLLB 200 600M (650Mb)", + {"NLLB 200 600M", + "nllb-200-600m", + MODEL_TYPE_TRANSLATION, + 
{{"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/model.bin?download=true", + "ED1BEAF75134DE7505315A5223162F56ACFF397EFF6B50638A500D3936FE707B"}, + { + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/config.json?download=true", + "0C2F6FA2057C7264D052FB4A62BA3476EEAE70487ACDDFA8E779A53A00CBF44C", + }, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/tokenizer.json?download=true", + "E316B82DE11D0F951F370943B3C438311629547285129B0B81DADABD01BCA665"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/shared_vocabulary.txt?download=true", + "A132A83330F45514C2476EB81D1D69B3C41762264D16CE0A7EA982E5D6C728E5"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "992BD4ED610D644D6823081937BCC91BB8878DD556CEA4AE5327F2480361330E"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/tokenizer_config.json?download=true", + "D1AA8C3697D3E35674F97B5B7E9C99D22B010F528E80140257D97316BE90D044"}, + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "14BB8DFB35C0FFDEA7BC01E56CEA38B9E3D5EFCDCB9C251D6B40538E1AAB555A"}}}}, + {"MADLAD 400 3B (2.9Gb)", + {"MADLAD 400 3B", + "madlad-400-3b", + MODEL_TYPE_TRANSLATION, + {{"https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/model.bin?download=true", + "F3C87256A2C888100C179D7DCD7F41DF17C767469546C59D32C7DDE86C740A6B"}, + { + "https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/config.json?download=true", + "A428C51CD35517554523B3C6B6974A5928BC35E82B130869A543566A34A83B93", + }, + {"https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/shared_vocabulary.txt?download=true", + "C327551CE3CA6EFC7B437E11A267F79979893332DDA8A1D146E2C950815193F8"}, + 
{"https://huggingface.co/santhosh/madlad400-3b-ct2/resolve/main/sentencepiece.model?download=true", + "EF11AC9A22C7503492F56D48DCE53BE20E339B63605983E9F27D2CD0E0F3922C"}}}}, + {"Whisper Base q5 (57Mb)", + {"Whisper Base q5", + "whisper-base-q5", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-base-q5_1.bin", + "422F1AE452ADE6F30A004D7E5C6A43195E4433BC370BF23FAC9CC591F01A8898"}}}}, + {"Whisper Base English q5 (57Mb)", + {"Whisper Base En q5", + "ggml-model-whisper-base-en-q5_1", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-base.en-q5_1.bin", + "4BAF70DD0D7C4247BA2B81FAFD9C01005AC77C2F9EF064E00DCF195D0E2FDD2F"}}}}, + {"Whisper Base (141Mb)", + {"Whisper Base", + "ggml-model-whisper-base", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-base.bin", + "60ED5BC3DD14EEA856493D334349B405782DDCAF0028D4B5DF4088345FBA2EFE"}}}}, + {"Whisper Base English (141Mb)", + {"Whisper Base En", + "ggml-model-whisper-base-en", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-base.en.bin", + "A03779C86DF3323075F5E796CB2CE5029F00EC8869EEE3FDFB897AFE36C6D002"}}}}, + {"Whisper Large v1 q5 (1Gb)", + {"Whisper Large v1 q5", + "ggml-model-whisper-large-q5_0", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-large-q5_0.bin", + "3A214837221E4530DBC1FE8D734F302AF393EB30BD0ED046042EBF4BAF70F6F2"}}}}, + {"Whisper Medium q5 (514Mb)", + {"Whisper Medium q5", + "ggml-model-whisper-medium-q5_0", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-medium-q5_0.bin", + "19FEA4B380C3A618EC4723C3EEF2EB785FFBA0D0538CF43F8F235E7B3B34220F"}}}}, + {"Whisper Medium English q5 (514Mb)", + {"Whisper Medium En q5", + "ggml-model-whisper-medium-en-q5_0", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-medium.en-q5_0.bin", + "76733E26AD8FE1C7A5BF7531A9D41917B2ADC0F20F2E4F5531688A8C6CD88EB0"}}}}, + {"Whisper 
Small q5 (181Mb)", + {"Whisper Small q5", + "ggml-model-whisper-small-q5_1", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-small-q5_1.bin", + "AE85E4A935D7A567BD102FE55AFC16BB595BDB618E11B2FC7591BC08120411BB"}}}}, + {"Whisper Small English q5 (181Mb)", + {"Whisper Small En q5", + "ggml-model-whisper-small-en-q5_1", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-small.en-q5_1.bin", + "BFDFF4894DCB76BBF647D56263EA2A96645423F1669176F4844A1BF8E478AD30"}}}}, + {"Whisper Small (465Mb)", + {"Whisper Small", + "ggml-model-whisper-small", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-small.bin", + "1BE3A9B2063867B937E64E2EC7483364A79917E157FA98C5D94B5C1FFFEA987B"}}}}, + {"Whisper Small English (465Mb)", + {"Whisper Small En", + "ggml-model-whisper-small-en", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-small.en.bin", + "C6138D6D58ECC8322097E0F987C32F1BE8BB0A18532A3F88F734D1BBF9C41E5D"}}}}, + {"Whisper Tiny (74Mb)", + {"Whisper Tiny", + "ggml-model-whisper-tiny", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-tiny.bin", + "BE07E048E1E599AD46341C8D2A135645097A538221678B7ACDD1B1919C6E1B21"}}}}, + {"Whisper Tiny q5 (31Mb)", + {"Whisper Tiny q5", + "ggml-model-whisper-tiny-q5_1", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-tiny-q5_1.bin", + "818710568DA3CA15689E31A743197B520007872FF9576237BDA97BD1B469C3D7"}}}}, + {"Whisper Tiny English q5 (31Mb)", + {"Whisper Tiny En q5", + "ggml-model-whisper-tiny-en-q5_1", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-tiny.en-q5_1.bin", + "C77C5766F1CEF09B6B7D47F21B546CBDDD4157886B3B5D6D4F709E91E66C7C2B"}}}}, + {"Whisper Tiny English q8 (42Mb)", + {"Whisper Tiny En q8", + "ggml-model-whisper-tiny-en-q8_0", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-tiny.en-q8_0.bin", + 
"5BC2B3860AA151A4C6E7BB095E1FCCE7CF12C7B020CA08DCEC0C6D018BB7DD94"}}}}, + {"Whisper Tiny English (74Mb)", + {"Whisper Tiny En", + "ggml-model-whisper-tiny-en", + MODEL_TYPE_TRANSCRIPTION, + {{"https://ggml.ggerganov.com/ggml-model-whisper-tiny.en.bin", + "921E4CF8686FDD993DCD081A5DA5B6C365BFDE1162E72B08D75AC75289920B1F"}}}}, + {"Whisper Large v3 (3Gb)", + {"Whisper Large v3", + "ggml-large-v3", + MODEL_TYPE_TRANSCRIPTION, + {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin", + "64d182b440b98d5203c4f9bd541544d84c605196c4f7b845dfa11fb23594d1e2"}}}}, + {"Whisper Large v3 q5 (1Gb)", + {"Whisper Large v3 q5", + "ggml-large-v3-q5_0", + MODEL_TYPE_TRANSCRIPTION, + {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-q5_0.bin", + "d75795ecff3f83b5faa89d1900604ad8c780abd5739fae406de19f23ecd98ad1"}}}}, + {"Whisper Large v2 (3Gb)", + {"Whisper Large v2", + "ggml-large-v2", + MODEL_TYPE_TRANSCRIPTION, + {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v2.bin", + "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487"}}}}, + {"Whisper Large v1 (3Gb)", + {"Whisper Large v1", + "ggml-large-v1", + MODEL_TYPE_TRANSCRIPTION, + {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin", + "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d"}}}}, + {"Whisper Medium English (1.5Gb)", + {"Whisper Medium English", + "ggml-meduim-en", + MODEL_TYPE_TRANSCRIPTION, + {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin", + "cc37e93478338ec7700281a7ac30a10128929eb8f427dda2e865faa8f6da4356"}}}}, + {"Whisper Medium (1.5Gb)", + {"Whisper Medium", + "ggml-meduim", + MODEL_TYPE_TRANSCRIPTION, + {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin", + "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208"}}}}, +}}; diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h new file mode 100644 index 
0000000..205bbf0 --- /dev/null +++ b/src/transcription-filter-data.h @@ -0,0 +1,158 @@ +#ifndef TRANSCRIPTION_FILTER_DATA_H +#define TRANSCRIPTION_FILTER_DATA_H + +#include + +#include +#include +#include +#include +#include +#include + +#include "translation/translation.h" +#include "translation/translation-includes.h" +#include "whisper-utils/silero-vad-onnx.h" +#include "whisper-utils/whisper-processing.h" +#include "whisper-utils/token-buffer-thread.h" + +#define MAX_PREPROC_CHANNELS 10 + +struct transcription_filter_data { + obs_source_t *context; // obs filter source (this filter) + size_t channels; // number of channels + uint32_t sample_rate; // input sample rate + // How many input frames (in input sample rate) are needed for the next whisper frame + size_t frames; + // How many frames were processed in the last whisper frame (this is dynamic) + size_t last_num_frames; + // Start begining timestamp in ms since epoch + uint64_t start_timestamp_ms; + // Sentence counter for srt + size_t sentence_number; + // Minimal subtitle duration in ms + size_t min_sub_duration; + // Maximal subtitle duration in ms + size_t max_sub_duration; + // Last time a subtitle was rendered + uint64_t last_sub_render_time; + bool cleared_last_sub; + + /* PCM buffers */ + float *copy_buffers[MAX_PREPROC_CHANNELS]; + struct circlebuf info_buffer; + struct circlebuf input_buffers[MAX_PREPROC_CHANNELS]; + struct circlebuf whisper_buffer; + + /* Resampler */ + audio_resampler_t *resampler_to_whisper; + struct circlebuf resampled_buffer; + + /* whisper */ + std::string whisper_model_path; + struct whisper_context *whisper_context; + whisper_full_params whisper_params; + + /* Silero VAD */ + std::unique_ptr vad; + + float filler_p_threshold; + float sentence_psum_accept_thresh; + + bool do_silence; + int vad_mode; + int log_level = LOG_DEBUG; + bool log_words; + bool caption_to_stream; + bool active = false; + bool save_to_file = false; + bool save_srt = false; + bool truncate_output_file 
= false; + bool save_only_while_recording = false; + bool process_while_muted = false; + bool rename_file_to_match_recording = false; + bool translate = false; + std::string target_lang; + std::string translation_output; + bool enable_token_ts_dtw = false; + std::vector> filter_words_replace; + bool fix_utf8 = true; + bool enable_audio_chunks_callback = false; + bool source_signals_set = false; + bool initial_creation = true; + bool partial_transcription = false; + int partial_latency = 1000; + float duration_filter_threshold = 2.25f; + int segment_duration = 7000; + + // Last transcription result + std::string last_text_for_translation; + std::string last_text_translation; + + // Transcription context sentences + int n_context_sentences; + std::deque last_transcription_sentence; + + // Text source to output the subtitles + std::string text_source_name; + // Callback to set the text in the output text source (subtitles) + std::function setTextCallback; + // Output file path to write the subtitles + std::string output_file_path; + std::string whisper_model_file_currently_loaded; + bool whisper_model_loaded_new; + + // Use std for thread and mutex + std::thread whisper_thread; + + std::mutex whisper_buf_mutex; + std::mutex whisper_ctx_mutex; + std::condition_variable wshiper_thread_cv; + std::optional input_cv; + + // translation context + struct translation_context translation_ctx; + std::string translation_model_index; + std::string translation_model_path_external; + bool translate_only_full_sentences; + + bool buffered_output = false; + TokenBufferThread captions_monitor; + TokenBufferThread translation_monitor; + int buffered_output_num_lines = 2; + int buffered_output_num_chars = 30; + TokenBufferSegmentation buffered_output_output_type = + TokenBufferSegmentation::SEGMENTATION_TOKEN; + + // ctor + transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() + { + // initialize all pointers to nullptr + for (size_t i = 0; i < 
MAX_PREPROC_CHANNELS; i++) { + copy_buffers[i] = nullptr; + } + context = nullptr; + resampler_to_whisper = nullptr; + whisper_model_path = ""; + whisper_context = nullptr; + output_file_path = ""; + whisper_model_file_currently_loaded = ""; + } +}; + +// Audio packet info +struct transcription_filter_audio_info { + uint32_t frames; + uint64_t timestamp_offset_ns; // offset (since start of processing) timestamp in ns +}; + +// Callback sent when the transcription has a new result +void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &str); +void clear_current_caption(transcription_filter_data *gf_); + +// Callback sent when the VAD finds an audio chunk. Sample rate = WHISPER_SAMPLE_RATE, channels = 1 +// The audio chunk is in 32-bit float format +void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data, + size_t frames, int vad_state, const DetectionResultWithText &result); + +#endif /* TRANSCRIPTION_FILTER_DATA_H */ diff --git a/src/transcription-utils.cpp b/src/transcription-utils.cpp new file mode 100644 index 0000000..727d3df --- /dev/null +++ b/src/transcription-utils.cpp @@ -0,0 +1,162 @@ +#include "transcription-utils.h" + +#include +#include +#include + +// clang-format off +#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0) +// clang-format off +#define is_trail_byte(c) (((c)&0xc0) == 0x80) + +inline int lead_byte_length(const uint8_t c) +{ + if ((c & 0xe0) == 0xc0) { + return 2; + } else if ((c & 0xf0) == 0xe0) { + return 3; + } else if ((c & 0xf8) == 0xf0) { + return 4; + } else { + return 1; + } +} + +inline bool is_valid_lead_byte(const uint8_t *c) +{ + const int length = lead_byte_length(c[0]); + if (length == 1) { + return true; + } + if (length == 2 && is_trail_byte(c[1])) { + return true; + } + if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) { + return true; + } + if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && 
is_trail_byte(c[3])) { + return true; + } + return false; +} + +std::string fix_utf8(const std::string &str) +{ +#ifdef _WIN32 + // Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs + // 0xf?, and 0xc? becomes 0xe?, so we need to fix it. + std::stringstream ss; + uint8_t *c_str = (uint8_t *)str.c_str(); + for (size_t i = 0; i < str.size(); ++i) { + if (is_lead_byte(c_str[i])) { + // this is a unicode leading byte + // if the next char is 0xff - it's a bug char, replace it with 0x9f + if (c_str[i + 1] == 0xff) { + c_str[i + 1] = 0x9f; + } + if (!is_valid_lead_byte(c_str + i)) { + // This is a bug lead byte, because it's length 3 and the i+2 byte is also + // a lead byte + c_str[i] = c_str[i] - 0x20; + } + } else { + if (c_str[i] >= 0xf8) { + // this may be a malformed lead byte. + // lets see if it becomes a valid lead byte if we "fix" it + uint8_t buf_[4]; + buf_[0] = c_str[i] - 0x20; + buf_[1] = c_str[i + 1]; + buf_[2] = c_str[i + 2]; + buf_[3] = c_str[i + 3]; + if (is_valid_lead_byte(buf_)) { + // this is a malformed lead byte, fix it + c_str[i] = c_str[i] - 0x20; + } + } + } + } + + return std::string((char *)c_str); +#else + return str; +#endif +} + +/* +* Remove leading and trailing non-alphabetic characters from a string. +* This function is used to remove leading and trailing spaces, newlines, tabs or punctuation. +* @param str: the string to remove leading and trailing non-alphabetic characters from. +* @return: the string with leading and trailing non-alphabetic characters removed. 
+*/ +std::string remove_leading_trailing_nonalpha(const std::string &str) +{ + if (str.size() == 0) { + return str; + } + if (str.size() == 1) { + if (std::isalpha(str[0])) { + return str; + } else { + return ""; + } + } + if (str.size() == 2) { + if (std::isalpha(str[0]) && std::isalpha(str[1])) { + return str; + } else if (std::isalpha(str[0])) { + return std::string(1, str[0]); + } else if (std::isalpha(str[1])) { + return std::string(1, str[1]); + } else { + return ""; + } + } + std::string str_copy = str; + // remove trailing spaces, newlines, tabs or punctuation + auto last_non_space = + std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) { + return !std::isspace(ch) || !std::ispunct(ch); + }).base(); + str_copy.erase(last_non_space, str_copy.end()); + // remove leading spaces, newlines, tabs or punctuation + auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(), + [](unsigned char ch) { + return !std::isspace(ch) || !std::ispunct(ch); + }) + + 1; + str_copy.erase(str_copy.begin(), first_non_space); + return str_copy; +} + +std::vector split(const std::string &string, char delimiter) +{ + std::vector tokens; + std::string token; + std::istringstream tokenStream(string); + while (std::getline(tokenStream, token, delimiter)) { + if (!token.empty()) { + tokens.push_back(token); + } + } + return tokens; +} + +std::vector split_words(const std::string &str_copy) +{ + std::vector words; + std::string word; + for (char c : str_copy) { + if (std::isspace(c)) { + if (!word.empty()) { + words.push_back(word); + word.clear(); + } + } else { + word += c; + } + } + if (!word.empty()) { + words.push_back(word); + } + return words; +} diff --git a/src/transcription-utils.h b/src/transcription-utils.h new file mode 100644 index 0000000..5fdd0cf --- /dev/null +++ b/src/transcription-utils.h @@ -0,0 +1,52 @@ +#ifndef TRANSCRIPTION_UTILS_H +#define TRANSCRIPTION_UTILS_H + +#include +#include +#include +#include +#include + +// Fix UTF8 string 
for Windows +std::string fix_utf8(const std::string &str); + +// Remove leading and trailing non-alphabetic characters +std::string remove_leading_trailing_nonalpha(const std::string &str); + +// Split a string by a delimiter +std::vector split(const std::string &string, char delimiter); + +// Get the current timestamp in milliseconds since epoch +inline uint64_t now_ms() +{ + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +// Get the current timestamp in nano seconds since epoch +inline uint64_t now_ns() +{ + return std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + +// Split a string into words based on spaces +std::vector split_words(const std::string &str_copy); + +// trim (strip) string from leading and trailing whitespaces +template StringLike trim(const StringLike &str) +{ + StringLike str_copy = str; + str_copy.erase(str_copy.begin(), + std::find_if(str_copy.begin(), str_copy.end(), + [](unsigned char ch) { return !std::isspace(ch); })); + str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + str_copy.end()); + return str_copy; +} + +#endif // TRANSCRIPTION_UTILS_H diff --git a/src/translation/language_codes.cpp b/src/translation/language_codes.cpp new file mode 100644 index 0000000..e4ab557 --- /dev/null +++ b/src/translation/language_codes.cpp @@ -0,0 +1,256 @@ +#include "language_codes.h" + +std::map language_codes = {{"__af__", "Afrikaans"}, + {"__am__", "Amharic"}, + {"__ar__", "Arabic"}, + {"__ast__", "Asturian"}, + {"__az__", "Azerbai"}, + {"__ba__", "Bashkir"}, + {"__be__", "Belarusian"}, + {"__bg__", "Bulgarian"}, + {"__bn__", "Bengali"}, + {"__br__", "Breton"}, + {"__bs__", "Bosnian"}, + {"__ca__", "Catalan"}, + {"__ceb__", "Cebuano"}, + {"__cs__", "Czech"}, + {"__cy__", "Welsh"}, + {"__da__", "Danish"}, + {"__de__", "German"}, + {"__el__", "Greek"}, + {"__en__", 
"English"}, + {"__es__", "Spanish"}, + {"__et__", "Estonian"}, + {"__fa__", "Persian"}, + {"__ff__", "Fulah"}, + {"__fi__", "Finnish"}, + {"__fr__", "French"}, + {"__fy__", "Frisian"}, + {"__ga__", "Irish"}, + {"__gd__", "Scottish Gaelic"}, + {"__gl__", "Galician"}, + {"__gu__", "Gujarati"}, + {"__ha__", "Hausa"}, + {"__he__", "Hebrew"}, + {"__hi__", "Hindi"}, + {"__hr__", "Croatian"}, + {"__ht__", "Haitian Creole"}, + {"__hu__", "Hungarian"}, + {"__hy__", "Armenian"}, + {"__id__", "Indonesian"}, + {"__ig__", "Igbo"}, + {"__ilo__", "Ilokano"}, + {"__is__", "Icelandic"}, + {"__it__", "Italian"}, + {"__ja__", "Japanese"}, + {"__jv__", "Javanese"}, + {"__ka__", "Georgian"}, + {"__kk__", "Kazakh"}, + {"__km__", "Khmer"}, + {"__kn__", "Kannada"}, + {"__ko__", "Korean"}, + {"__lb__", "Luxembourgish"}, + {"__lg__", "Ganda"}, + {"__ln__", "Lingala"}, + {"__lo__", "Lao"}, + {"__lt__", "Lithuanian"}, + {"__lv__", "Latvian"}, + {"__mg__", "Malagasy"}, + {"__mk__", "Macedonian"}, + {"__ml__", "Malayalam"}, + {"__mn__", "Mongolian"}, + {"__mr__", "Marathi"}, + {"__ms__", "Malay"}, + {"__my__", "Burmese"}, + {"__ne__", "Nepali"}, + {"__nl__", "Dutch"}, + {"__no__", "Norwegian"}, + {"__ns__", "Northern Sotho"}, + {"__oc__", "Occitan"}, + {"__or__", "Oriya"}, + {"__pa__", "Punjabi"}, + {"__pl__", "Polish"}, + {"__ps__", "Pashto"}, + {"__pt__", "Portuguese"}, + {"__ro__", "Romanian"}, + {"__ru__", "Russian"}, + {"__sd__", "Sindhi"}, + {"__si__", "Sinhala"}, + {"__sk__", "Slovak"}, + {"__sl__", "Slovenian"}, + {"__so__", "Somali"}, + {"__sq__", "Albanian"}, + {"__sr__", "Serbian"}, + {"__ss__", "Swati"}, + {"__su__", "Sundanese"}, + {"__sv__", "Swedish"}, + {"__sw__", "Swahili"}, + {"__ta__", "Tamil"}, + {"__th__", "Thai"}, + {"__tl__", "Tagalog"}, + {"__tn__", "Tswana"}, + {"__tr__", "Turkish"}, + {"__uk__", "Ukrainian"}, + {"__ur__", "Urdu"}, + {"__uz__", "Uzbek"}, + {"__vi__", "Vietnamese"}, + {"__wo__", "Wolof"}, + {"__xh__", "Xhosa"}, + {"__yi__", "Yiddish"}, + {"__yo__", 
"Yoruba"}, + {"__zh__", "Chinese"}, + {"__zu__", "Zulu"}}; + +std::map language_codes_reverse = {{"Afrikaans", "__af__"}, + {"Amharic", "__am__"}, + {"Arabic", "__ar__"}, + {"Asturian", "__ast__"}, + {"Azerbai", "__az__"}, + {"Bashkir", "__ba__"}, + {"Belarusian", "__be__"}, + {"Bengali", "__bn__"}, + {"Breton", "__br__"}, + {"Bosnian", "__bs__"}, + {"Catalan", "__ca__"}, + {"Cebuano", "__ceb__"}, + {"Czech", "__cs__"}, + {"Welsh", "__cy__"}, + {"Danish", "__da__"}, + {"German", "__de__"}, + {"Greek", "__el__"}, + {"English", "__en__"}, + {"Spanish", "__es__"}, + {"Estonian", "__et__"}, + {"Persian", "__fa__"}, + {"Fulah", "__ff__"}, + {"Finnish", "__fi__"}, + {"French", "__fr__"}, + {"Frisian", "__fy__"}, + {"Irish", "__ga__"}, + {"Scottish Gaelic", "__gd__"}, + {"Galician", "__gl__"}, + {"Gujarati", "__gu__"}, + {"Hausa", "__ha__"}, + {"Hebrew", "__he__"}, + {"Hindi", "__hi__"}, + {"Croatian", "__hr__"}, + {"Haitian Creole", "__ht__"}, + {"Hungarian", "__hu__"}, + {"Armenian", "__hy__"}, + {"Indonesian", "__id__"}, + {"Igbo", "__ig__"}, + {"Ilokano", "__ilo__"}, + {"Icelandic", "__is__"}, + {"Italian", "__it__"}, + {"Japanese", "__ja__"}, + {"Javanese", "__jv__"}, + {"Georgian", "__ka__"}, + {"Kazakh", "__kk__"}, + {"Khmer", "__km__"}, + {"Kannada", "__kn__"}, + {"Korean", "__ko__"}, + {"Luxembourgish", "__lb__"}, + {"Ganda", "__lg__"}, + {"Lingala", "__ln__"}, + {"Lao", "__lo__"}, + {"Lithuanian", "__lt__"}, + {"Latvian", "__lv__"}, + {"Malagasy", "__mg__"}, + {"Macedonian", "__mk__"}, + {"Malayalam", "__ml__"}, + {"Mongolian", "__mn__"}, + {"Marathi", "__mr__"}, + {"Malay", "__ms__"}, + {"Burmese", "__my__"}, + {"Nepali", "__ne__"}, + {"Dutch", "__nl__"}, + {"Norwegian", "__no__"}, + {"Northern Sotho", "__ns__"}, + {"Occitan", "__oc__"}, + {"Oriya", "__or__"}, + {"Punjabi", "__pa__"}, + {"Polish", "__pl__"}, + {"Pashto", "__ps__"}, + {"Portuguese", "__pt__"}, + {"Romanian", "__ro__"}, + {"Russian", "__ru__"}, + {"Sindhi", "__sd__"}, + {"Sinhala", "__si__"}, + 
{"Slovak", "__sk__"}, + {"Slovenian", "__sl__"}, + {"Somali", "__so__"}, + {"Albanian", "__sq__"}, + {"Serbian", "__sr__"}, + {"Swati", "__ss__"}, + {"Sundanese", "__su__"}, + {"Swedish", "__sv__"}, + {"Swahili", "__sw__"}, + {"Tamil", "__ta__"}, + {"Thai", "__th__"}, + {"Tagalog", "__tl__"}, + {"Tswana", "__tn__"}, + {"Turkish", "__tr__"}, + {"Ukrainian", "__uk__"}, + {"Urdu", "__ur__"}, + {"Uzbek", "__uz__"}, + {"Vietnamese", "__vi__"}, + {"Wolof", "__wo__"}, + {"Xhosa", "__xh__"}, + {"Yiddish", "__yi__"}, + {"Yoruba", "__yo__"}, + {"Chinese", "__zh__"}, + {"Zulu", "__zu__"}}; + +std::map language_codes_from_whisper = { + {"af", "__af__"}, {"am", "__am__"}, {"ar", "__ar__"}, {"ast", "__ast__"}, + {"az", "__az__"}, {"ba", "__ba__"}, {"be", "__be__"}, {"bg", "__bg__"}, + {"bn", "__bn__"}, {"br", "__br__"}, {"bs", "__bs__"}, {"ca", "__ca__"}, + {"ceb", "__ceb__"}, {"cs", "__cs__"}, {"cy", "__cy__"}, {"da", "__da__"}, + {"de", "__de__"}, {"el", "__el__"}, {"en", "__en__"}, {"es", "__es__"}, + {"et", "__et__"}, {"fa", "__fa__"}, {"ff", "__ff__"}, {"fi", "__fi__"}, + {"fr", "__fr__"}, {"fy", "__fy__"}, {"ga", "__ga__"}, {"gd", "__gd__"}, + {"gl", "__gl__"}, {"gu", "__gu__"}, {"ha", "__ha__"}, {"he", "__he__"}, + {"hi", "__hi__"}, {"hr", "__hr__"}, {"ht", "__ht__"}, {"hu", "__hu__"}, + {"hy", "__hy__"}, {"id", "__id__"}, {"ig", "__ig__"}, {"ilo", "__ilo__"}, + {"is", "__is__"}, {"it", "__it__"}, {"ja", "__ja__"}, {"jv", "__jv__"}, + {"ka", "__ka__"}, {"kk", "__kk__"}, {"km", "__km__"}, {"kn", "__kn__"}, + {"ko", "__ko__"}, {"lb", "__lb__"}, {"lg", "__lg__"}, {"ln", "__ln__"}, + {"lo", "__lo__"}, {"lt", "__lt__"}, {"lv", "__lv__"}, {"mg", "__mg__"}, + {"mk", "__mk__"}, {"ml", "__ml__"}, {"mn", "__mn__"}, {"mr", "__mr__"}, + {"ms", "__ms__"}, {"my", "__my__"}, {"ne", "__ne__"}, {"nl", "__nl__"}, + {"no", "__no__"}, {"ns", "__ns__"}, {"oc", "__oc__"}, {"or", "__or__"}, + {"pa", "__pa__"}, {"pl", "__pl__"}, {"ps", "__ps__"}, {"pt", "__pt__"}, + {"ro", "__ro__"}, {"ru", 
"__ru__"}, {"sd", "__sd__"}, {"si", "__si__"}, + {"sk", "__sk__"}, {"sl", "__sl__"}, {"so", "__so__"}, {"sq", "__sq__"}, + {"sr", "__sr__"}, {"ss", "__ss__"}, {"su", "__su__"}, {"sv", "__sv__"}, + {"sw", "__sw__"}, {"ta", "__ta__"}, {"th", "__th__"}, {"tl", "__tl__"}, + {"tn", "__tn__"}, {"tr", "__tr__"}, {"uk", "__uk__"}, {"ur", "__ur__"}, + {"uz", "__uz__"}, {"vi", "__vi__"}, {"wo", "__wo__"}, {"xh", "__xh__"}, + {"yi", "__yi__"}, {"yo", "__yo__"}, {"zh", "__zh__"}, {"zu", "__zu__"}}; + +std::map language_codes_to_whisper = { + {"__af__", "af"}, {"__am__", "am"}, {"__ar__", "ar"}, {"__ast__", "ast"}, + {"__az__", "az"}, {"__ba__", "ba"}, {"__be__", "be"}, {"__bg__", "bg"}, + {"__bn__", "bn"}, {"__br__", "br"}, {"__bs__", "bs"}, {"__ca__", "ca"}, + {"__ceb__", "ceb"}, {"__cs__", "cs"}, {"__cy__", "cy"}, {"__da__", "da"}, + {"__de__", "de"}, {"__el__", "el"}, {"__en__", "en"}, {"__es__", "es"}, + {"__et__", "et"}, {"__fa__", "fa"}, {"__ff__", "ff"}, {"__fi__", "fi"}, + {"__fr__", "fr"}, {"__fy__", "fy"}, {"__ga__", "ga"}, {"__gd__", "gd"}, + {"__gl__", "gl"}, {"__gu__", "gu"}, {"__ha__", "ha"}, {"__he__", "he"}, + {"__hi__", "hi"}, {"__hr__", "hr"}, {"__ht__", "ht"}, {"__hu__", "hu"}, + {"__hy__", "hy"}, {"__id__", "id"}, {"__ig__", "ig"}, {"__ilo__", "ilo"}, + {"__is__", "is"}, {"__it__", "it"}, {"__ja__", "ja"}, {"__jv__", "jv"}, + {"__ka__", "ka"}, {"__kk__", "kk"}, {"__km__", "km"}, {"__kn__", "kn"}, + {"__ko__", "ko"}, {"__lb__", "lb"}, {"__lg__", "lg"}, {"__ln__", "ln"}, + {"__lo__", "lo"}, {"__lt__", "lt"}, {"__lv__", "lv"}, {"__mg__", "mg"}, + {"__mk__", "mk"}, {"__ml__", "ml"}, {"__mn__", "mn"}, {"__mr__", "mr"}, + {"__ms__", "ms"}, {"__my__", "my"}, {"__ne__", "ne"}, {"__nl__", "nl"}, + {"__no__", "no"}, {"__ns__", "ns"}, {"__oc__", "oc"}, {"__or__", "or"}, + {"__pa__", "pa"}, {"__pl__", "pl"}, {"__ps__", "ps"}, {"__pt__", "pt"}, + {"__ro__", "ro"}, {"__ru__", "ru"}, {"__sd__", "sd"}, {"__si__", "si"}, + {"__sk__", "sk"}, {"__sl__", "sl"}, {"__so__", 
"so"}, {"__sq__", "sq"}, + {"__sr__", "sr"}, {"__ss__", "ss"}, {"__su__", "su"}, {"__sv__", "sv"}, + {"__sw__", "sw"}, {"__ta__", "ta"}, {"__th__", "th"}, {"__tl__", "tl"}, + {"__tn__", "tn"}, {"__tr__", "tr"}, {"__uk__", "uk"}, {"__ur__", "ur"}, + {"__uz__", "uz"}, {"__vi__", "vi"}, {"__wo__", "wo"}, {"__xh__", "xh"}, + {"__yi__", "yi"}, {"__yo__", "yo"}, {"__zh__", "zh"}, {"__zu__", "zu"}}; diff --git a/src/translation/language_codes.h b/src/translation/language_codes.h new file mode 100644 index 0000000..fb4890e --- /dev/null +++ b/src/translation/language_codes.h @@ -0,0 +1,12 @@ +#ifndef LANGUAGE_CODES_H +#define LANGUAGE_CODES_H + +#include +#include + +extern std::map language_codes; +extern std::map language_codes_reverse; +extern std::map language_codes_from_whisper; +extern std::map language_codes_to_whisper; + +#endif // LANGUAGE_CODES_H diff --git a/src/translation/translation-includes.h b/src/translation/translation-includes.h new file mode 100644 index 0000000..6520389 --- /dev/null +++ b/src/translation/translation-includes.h @@ -0,0 +1,8 @@ +#ifndef TRANSLATION_INCLUDES_H +#define TRANSLATION_INCLUDES_H + +#include +#include +#include + +#endif // TRANSLATION_INCLUDES_H diff --git a/src/translation/translation-language-utils.cpp b/src/translation/translation-language-utils.cpp new file mode 100644 index 0000000..685ca1a --- /dev/null +++ b/src/translation/translation-language-utils.cpp @@ -0,0 +1,33 @@ +#include "translation-language-utils.h" + +#include +#include + +std::string remove_start_punctuation(const std::string &text) +{ + if (text.empty()) { + return text; + } + + // Convert the input string to ICU's UnicodeString + icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(text); + + // Find the index of the first non-punctuation character + int32_t start = 0; + while (start < ustr.length()) { + UChar32 ch = ustr.char32At(start); + if (!u_ispunct(ch)) { + break; + } + start += U16_LENGTH(ch); + } + + // Create a new UnicodeString with 
punctuation removed from the start + icu::UnicodeString result = ustr.tempSubString(start); + + // Convert the result back to UTF-8 + std::string output; + result.toUTF8String(output); + + return output; +} diff --git a/src/translation/translation-language-utils.h b/src/translation/translation-language-utils.h new file mode 100644 index 0000000..d2f4c47 --- /dev/null +++ b/src/translation/translation-language-utils.h @@ -0,0 +1,8 @@ +#ifndef TRANSLATION_LANGUAGE_UTILS_H +#define TRANSLATION_LANGUAGE_UTILS_H + +#include + +std::string remove_start_punctuation(const std::string &text); + +#endif // TRANSLATION_LANGUAGE_UTILS_H diff --git a/src/translation/translation-utils.cpp b/src/translation/translation-utils.cpp new file mode 100644 index 0000000..07ca268 --- /dev/null +++ b/src/translation/translation-utils.cpp @@ -0,0 +1,44 @@ +#include + +#include "translation-includes.h" +#include "translation.h" +#include "translation-utils.h" +#include "plugin-support.h" +#include "model-utils/model-downloader.h" + +void start_translation(struct transcription_filter_data *gf) +{ + obs_log(LOG_INFO, "Starting translation..."); + + if (gf->translation_model_index == "!!!external!!!") { + obs_log(LOG_INFO, "External model selected."); + if (gf->translation_model_path_external.empty()) { + obs_log(LOG_ERROR, "External model path is empty."); + gf->translate = false; + return; + } + std::string model_file_found = gf->translation_model_path_external; + build_and_enable_translation(gf, model_file_found); + return; + } + + const ModelInfo &translation_model_info = models_info[gf->translation_model_index]; + std::string model_file_found = find_model_folder(translation_model_info); + if (model_file_found == "") { + obs_log(LOG_INFO, "Translation CT2 model does not exist. 
Downloading..."); + download_model_with_ui_dialog( + translation_model_info, + [gf, model_file_found](int download_status, const std::string &path) { + if (download_status == 0) { + obs_log(LOG_INFO, "CT2 model download complete"); + build_and_enable_translation(gf, path); + } else { + obs_log(LOG_ERROR, "Model download failed"); + gf->translate = false; + } + }); + } else { + // Model exists, just load it + build_and_enable_translation(gf, model_file_found); + } +} diff --git a/src/translation/translation-utils.h b/src/translation/translation-utils.h new file mode 100644 index 0000000..8a06ab4 --- /dev/null +++ b/src/translation/translation-utils.h @@ -0,0 +1,8 @@ +#ifndef TRANSLATION_UTILS_H +#define TRANSLATION_UTILS_H + +#include "transcription-filter-data.h" + +void start_translation(struct transcription_filter_data *gf); + +#endif // TRANSLATION_UTILS_H diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp new file mode 100644 index 0000000..0701d95 --- /dev/null +++ b/src/translation/translation.cpp @@ -0,0 +1,212 @@ +#include "translation.h" +#include "plugin-support.h" +#include "model-utils/model-find-utils.h" +#include "transcription-filter-data.h" +#include "language_codes.h" +#include "translation-language-utils.h" + +#include +#include +#include +#include + +void build_and_enable_translation(struct transcription_filter_data *gf, + const std::string &model_file_path) +{ + std::lock_guard lock(gf->whisper_ctx_mutex); + + gf->translation_ctx.local_model_folder_path = model_file_path; + if (build_translation_context(gf->translation_ctx) == + OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS) { + obs_log(LOG_INFO, "Enable translation"); + gf->translate = true; + } else { + obs_log(LOG_ERROR, "Failed to load CT2 model"); + gf->translate = false; + } +} + +int build_translation_context(struct translation_context &translation_ctx) +{ + std::string local_model_path = translation_ctx.local_model_folder_path; + obs_log(LOG_INFO, "Building 
translation context from '%s'...", local_model_path.c_str()); + // find the SPM file in the model folder + std::string local_spm_path = find_file_in_folder_by_regex_expression( + local_model_path, "(sentencepiece|spm|spiece|source).*?\\.(model|spm)"); + std::string target_spm_path = + find_file_in_folder_by_regex_expression(local_model_path, "target.*?\\.spm"); + + try { + obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str()); + translation_ctx.processor.reset(new sentencepiece::SentencePieceProcessor()); + const auto status = translation_ctx.processor->Load(local_spm_path); + if (!status.ok()) { + obs_log(LOG_ERROR, "Failed to load SPM: %s", status.ToString().c_str()); + return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + } + + if (!target_spm_path.empty()) { + obs_log(LOG_INFO, "Loading target SPM from %s", target_spm_path.c_str()); + translation_ctx.target_processor.reset( + new sentencepiece::SentencePieceProcessor()); + const auto target_status = + translation_ctx.target_processor->Load(target_spm_path); + if (!target_status.ok()) { + obs_log(LOG_ERROR, "Failed to load target SPM: %s", + target_status.ToString().c_str()); + return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + } + } else { + obs_log(LOG_INFO, "Target SPM not found, using source SPM for target"); + translation_ctx.target_processor.release(); + } + + translation_ctx.tokenizer = [&translation_ctx](const std::string &text) { + std::vector tokens; + translation_ctx.processor->Encode(text, &tokens); + return tokens; + }; + translation_ctx.detokenizer = + [&translation_ctx](const std::vector &tokens) { + std::string text; + if (translation_ctx.target_processor) { + translation_ctx.target_processor->Decode(tokens, &text); + } else { + translation_ctx.processor->Decode(tokens, &text); + } + return std::regex_replace(text, std::regex(""), "UNK"); + }; + + obs_log(LOG_INFO, "Loading CT2 model from %s", local_model_path.c_str()); + +#ifdef POLYGLOT_WITH_CUDA + ctranslate2::Device device = 
ctranslate2::Device::CUDA; + obs_log(LOG_INFO, "CT2 Using CUDA"); +#else + ctranslate2::Device device = ctranslate2::Device::CPU; + obs_log(LOG_INFO, "CT2 Using CPU"); +#endif + + translation_ctx.translator.reset(new ctranslate2::Translator( + local_model_path, device, ctranslate2::ComputeType::AUTO)); + obs_log(LOG_INFO, "CT2 Model loaded"); + + translation_ctx.options.reset(new ctranslate2::TranslationOptions); + translation_ctx.options->beam_size = 1; + translation_ctx.options->max_decoding_length = 64; + translation_ctx.options->repetition_penalty = 2.0f; + translation_ctx.options->no_repeat_ngram_size = 1; + translation_ctx.options->max_input_length = 64; + translation_ctx.options->sampling_temperature = 0.1f; + } catch (std::exception &e) { + obs_log(LOG_ERROR, "Failed to load CT2 model: %s", e.what()); + return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + } + return OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS; +} + +int translate(struct translation_context &translation_ctx, const std::string &text, + const std::string &source_lang, const std::string &target_lang, std::string &result) +{ + try { + std::vector results; + std::vector target_prefix; + + if (translation_ctx.input_tokenization_style == INPUT_TOKENIZAION_M2M100) { + // set input tokens + std::vector input_tokens = {source_lang, ""}; + if (translation_ctx.add_context > 0 && + translation_ctx.last_input_tokens.size() > 0) { + // add the last input tokens sentences to the input tokens + for (const auto &tokens : translation_ctx.last_input_tokens) { + input_tokens.insert(input_tokens.end(), tokens.begin(), + tokens.end()); + } + } + std::vector new_input_tokens = translation_ctx.tokenizer(text); + input_tokens.insert(input_tokens.end(), new_input_tokens.begin(), + new_input_tokens.end()); + input_tokens.push_back(""); + + // log the input tokens + std::string input_tokens_str; + for (const auto &token : input_tokens) { + input_tokens_str += token + ", "; + } + obs_log(LOG_INFO, "Input tokens: %s", 
input_tokens_str.c_str()); + + translation_ctx.last_input_tokens.push_back(new_input_tokens); + // remove the oldest input tokens + while (translation_ctx.last_input_tokens.size() > + (size_t)translation_ctx.add_context) { + translation_ctx.last_input_tokens.pop_front(); + } + + const std::vector> batch = {input_tokens}; + + // get target prefix + target_prefix = {target_lang}; + // add the last translation tokens to the target prefix + if (translation_ctx.add_context > 0 && + translation_ctx.last_translation_tokens.size() > 0) { + for (const auto &tokens : translation_ctx.last_translation_tokens) { + target_prefix.insert(target_prefix.end(), tokens.begin(), + tokens.end()); + } + } + + // log the target prefix + std::string target_prefix_str; + for (const auto &token : target_prefix) { + target_prefix_str += token + ","; + } + obs_log(LOG_INFO, "Target prefix: %s", target_prefix_str.c_str()); + + const std::vector> target_prefix_batch = { + target_prefix}; + results = translation_ctx.translator->translate_batch( + batch, target_prefix_batch, *translation_ctx.options); + } else { + // set input tokens + std::vector input_tokens = {}; + std::vector new_input_tokens = translation_ctx.tokenizer( + "<2" + language_codes_to_whisper[target_lang] + "> " + text); + input_tokens.insert(input_tokens.end(), new_input_tokens.begin(), + new_input_tokens.end()); + const std::vector> batch = {input_tokens}; + + results = translation_ctx.translator->translate_batch( + batch, {}, *translation_ctx.options); + } + + const auto &tokens_result = results[0].output(); + // take the tokens from the target_prefix length to the end + std::vector translation_tokens( + tokens_result.begin() + target_prefix.size(), tokens_result.end()); + + // log the translation tokens + std::string translation_tokens_str; + for (const auto &token : translation_tokens) { + translation_tokens_str += token + ", "; + } + obs_log(LOG_INFO, "Translation tokens: %s", translation_tokens_str.c_str()); + + // save the 
translation tokens + translation_ctx.last_translation_tokens.push_back(translation_tokens); + // remove the oldest translation tokens + while (translation_ctx.last_translation_tokens.size() > + (size_t)translation_ctx.add_context) { + translation_ctx.last_translation_tokens.pop_front(); + } + obs_log(LOG_INFO, "Last translation tokens deque size: %d", + (int)translation_ctx.last_translation_tokens.size()); + + // detokenize + const std::string result_ = translation_ctx.detokenizer(translation_tokens); + result = remove_start_punctuation(result_); + } catch (std::exception &e) { + obs_log(LOG_ERROR, "Error: %s", e.what()); + return OBS_POLYGLOT_TRANSLATION_FAIL; + } + return OBS_POLYGLOT_TRANSLATION_SUCCESS; +} diff --git a/src/translation/translation.h b/src/translation/translation.h new file mode 100644 index 0000000..c740726 --- /dev/null +++ b/src/translation/translation.h @@ -0,0 +1,48 @@ +#ifndef TRANSLATION_H +#define TRANSLATION_H + +#include +#include +#include +#include +#include + +enum InputTokenizationStyle { INPUT_TOKENIZAION_M2M100 = 0, INPUT_TOKENIZAION_T5 }; + +namespace ctranslate2 { +class Translator; +class TranslationOptions; +} // namespace ctranslate2 + +namespace sentencepiece { +class SentencePieceProcessor; +} // namespace sentencepiece + +struct translation_context { + std::string local_model_folder_path; + std::unique_ptr processor; + std::unique_ptr target_processor; + std::unique_ptr translator; + std::unique_ptr options; + std::function(const std::string &)> tokenizer; + std::function &)> detokenizer; + std::deque> last_input_tokens; + std::deque> last_translation_tokens; + // How many sentences to use as context for the next translation + int add_context; + InputTokenizationStyle input_tokenization_style; +}; + +int build_translation_context(struct translation_context &translation_ctx); +void build_and_enable_translation(struct transcription_filter_data *gf, + const std::string &model_file_path); + +int translate(struct 
translation_context &translation_ctx, const std::string &text, + const std::string &source_lang, const std::string &target_lang, std::string &result); + +#define OBS_POLYGLOT_TRANSLATION_INIT_FAIL -1 +#define OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS 0 +#define OBS_POLYGLOT_TRANSLATION_SUCCESS 0 +#define OBS_POLYGLOT_TRANSLATION_FAIL -1 + +#endif // TRANSLATION_H diff --git a/src/whisper-utils/silero-vad-onnx.cpp b/src/whisper-utils/silero-vad-onnx.cpp new file mode 100644 index 0000000..078e47c --- /dev/null +++ b/src/whisper-utils/silero-vad-onnx.cpp @@ -0,0 +1,353 @@ +#include "silero-vad-onnx.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "plugin-support.h" + +// #define __DEBUG_SPEECH_PROB___ + +// prevent clang-format from reformatting the code +// clang-format off +timestamp_t::timestamp_t(int start_, int end_) : start(start_), end(end_){}; + +// assignment operator modifies object, therefore non-const +timestamp_t &timestamp_t::operator=(const timestamp_t &a) +{ + start = a.start; + end = a.end; + return *this; +}; + +// equality comparison. doesn't modify object. therefore const. +bool timestamp_t::operator==(const timestamp_t &a) const +{ + return (start == a.start && end == a.end); +}; +std::string timestamp_t::string() +{ + std::stringstream ss; + ss << "timestamp " << start << ", " << end; + return ss.str(); +}; + +std::string timestamp_t::format(const char *fmt, ...) 
+{ + char buf[256]; + + va_list args; + va_start(args, fmt); + const auto r = std::vsnprintf(buf, sizeof buf, fmt, args); + va_end(args); + + if (r < 0) + // conversion failed + return {}; + + const size_t len = r; + if (len < sizeof buf) + // we fit in the buffer + return {buf, len}; + +#if __cplusplus >= 201703L + // C++17: Create a string and write to its underlying array + std::string s(len, '\0'); + va_start(args, fmt); + std::vsnprintf(s.data(), len + 1, fmt, args); + va_end(args); + + return s; +#else + // C++11 or C++14: We need to allocate scratch memory + auto vbuf = std::unique_ptr(new char[len + 1]); + va_start(args, fmt); + std::vsnprintf(vbuf.get(), len + 1, fmt, args); + va_end(args); + + return {vbuf.get(), len}; +#endif +}; + +void VadIterator::init_engine_threads(int inter_threads, int intra_threads) +{ + // The method should be called in each thread/proc in multi-thread/proc work + session_options.SetIntraOpNumThreads(intra_threads); + session_options.SetInterOpNumThreads(inter_threads); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); +}; + +void VadIterator::init_onnx_model(const SileroString &model_path) +{ + // Init threads = 1 for + init_engine_threads(1, 1); + // Load model + session = std::make_shared(env, model_path.c_str(), session_options); +}; + +void VadIterator::reset_states(bool reset_state) +{ + if (reset_state) { + // Call reset before each audio start + std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); + triggered = false; + } + temp_end = 0; + current_sample = 0; + + prev_end = next_start = 0; + + speeches.clear(); + current_speech = timestamp_t(); +}; + +float VadIterator::predict_one(const std::vector &data) +{ + // Infer + // Create ort tensors + input.assign(data.begin(), data.end()); + Ort::Value input_ort = Ort::Value::CreateTensor(memory_info, input.data(), + input.size(), input_node_dims, 2); + Ort::Value state_ort = Ort::Value::CreateTensor( + memory_info, 
_state.data(), _state.size(), state_node_dims, 3); + Ort::Value sr_ort = Ort::Value::CreateTensor(memory_info, sr.data(), sr.size(), + sr_node_dims, 1); + + // Clear and add inputs + ort_inputs.clear(); + ort_inputs.emplace_back(std::move(input_ort)); + ort_inputs.emplace_back(std::move(state_ort)); + ort_inputs.emplace_back(std::move(sr_ort)); + + // Infer + ort_outputs = session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), + ort_inputs.data(), ort_inputs.size(), output_node_names.data(), + output_node_names.size()); + + // Output probability & update h,c recursively + float speech_prob = ort_outputs[0].GetTensorMutableData()[0]; + float *stateN = ort_outputs[1].GetTensorMutableData(); + std::memcpy(_state.data(), stateN, size_state * sizeof(float)); + + return speech_prob; +} + +void VadIterator::predict(const std::vector &data) +{ + const float speech_prob = predict_one(data); + + // Push forward sample index + current_sample += (unsigned int)window_size_samples; + + // Reset temp_end when > threshold + if ((speech_prob >= threshold)) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = + current_sample - + window_size_samples; // minus window_size_samples to get precise start time point. 
+ obs_log(LOG_INFO, "{ start: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, + speech_prob, current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + if (temp_end != 0) { + temp_end = 0; + if (next_start < prev_end) + next_start = current_sample - (unsigned int)window_size_samples; + } + if (triggered == false) { + triggered = true; + + current_speech.start = current_sample - (unsigned int)window_size_samples; + } + return; + } + + if ((triggered == true) && + ((float)(current_sample - current_speech.start) > max_speech_samples)) { + if (prev_end > 0) { + current_speech.end = prev_end; + speeches.push_back(current_speech); + current_speech = timestamp_t(); + + // previously reached silence(< neg_thres) and is still not speech(< thres) + if (next_start < prev_end) + triggered = false; + else { + current_speech.start = next_start; + } + prev_end = 0; + next_start = 0; + temp_end = 0; + + } else { + current_speech.end = current_sample; + speeches.push_back(current_speech); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + return; + } + if ((speech_prob >= (threshold - 0.15)) && (speech_prob < threshold)) { + if (triggered) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = + current_sample - + window_size_samples; // minus window_size_samples to get precise start time point. + obs_log(LOG_INFO, "{ speaking: %.3f s (%.3f) %08d}", + 1.0 * speech / sample_rate, speech_prob, + current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + } else { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = + current_sample - + window_size_samples; // minus window_size_samples to get precise start time point. 
+ obs_log(LOG_INFO, "{ silence: %.3f s (%.3f) %08d}", + 1.0 * speech / sample_rate, speech_prob, + current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + } + return; + } + + // 4) End + if ((speech_prob < (threshold - 0.15))) { +#ifdef __DEBUG_SPEECH_PROB___ + float speech = + current_sample - window_size_samples - + speech_pad_samples; // minus window_size_samples to get precise start time point. + obs_log(LOG_INFO, "{ end: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, + speech_prob, current_sample - window_size_samples); +#endif //__DEBUG_SPEECH_PROB___ + if (triggered == true) { + if (temp_end == 0) { + temp_end = current_sample; + } + if (current_sample - temp_end > + (unsigned int)min_silence_samples_at_max_speech) + prev_end = temp_end; + // a. silence < min_silence_samples, continue speaking + if ((current_sample - temp_end) < (unsigned int)min_silence_samples) { + + } + // b. silence >= min_silence_samples, end speaking + else { + current_speech.end = temp_end; + if (current_speech.end - current_speech.start > + min_speech_samples) { + speeches.push_back(current_speech); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } + } + } else { + // the first window may already see the end state. 
+ } + return; + } +}; + +void VadIterator::process(const std::vector &input_wav, bool reset_state) +{ + reset_states(reset_state); + + audio_length_samples = (int)input_wav.size(); + + for (int j = 0; j < audio_length_samples; j += (int)window_size_samples) { + if (j + (int)window_size_samples > audio_length_samples) + break; + std::vector r{&input_wav[0] + j, &input_wav[0] + j + window_size_samples}; + predict(r); + } + + if (current_speech.start >= 0) { + current_speech.end = audio_length_samples; + speeches.push_back(current_speech); + current_speech = timestamp_t(); + prev_end = 0; + next_start = 0; + temp_end = 0; + triggered = false; + } +}; + +void VadIterator::process(const std::vector &input_wav, std::vector &output_wav) +{ + process(input_wav); + collect_chunks(input_wav, output_wav); +} + +void VadIterator::collect_chunks(const std::vector &input_wav, + std::vector &output_wav) +{ + output_wav.clear(); + for (size_t i = 0; i < speeches.size(); i++) { +#ifdef __DEBUG_SPEECH_PROB___ + obs_log(LOG_INFO, "%s", speeches[i].string().c_str()); +#endif //#ifdef __DEBUG_SPEECH_PROB___ + std::vector slice(&input_wav[speeches[i].start], + &input_wav[speeches[i].end]); + output_wav.insert(output_wav.end(), slice.begin(), slice.end()); + } +}; + +const std::vector VadIterator::get_speech_timestamps() const +{ + return speeches; +} + +void VadIterator::drop_chunks(const std::vector &input_wav, std::vector &output_wav) +{ + output_wav.clear(); + int current_start = 0; + for (size_t i = 0; i < speeches.size(); i++) { + + std::vector slice(&input_wav[current_start], &input_wav[speeches[i].start]); + output_wav.insert(output_wav.end(), slice.begin(), slice.end()); + current_start = speeches[i].end; + } + + std::vector slice(&input_wav[current_start], &input_wav[input_wav.size()]); + output_wav.insert(output_wav.end(), slice.begin(), slice.end()); +}; + +VadIterator::VadIterator(const SileroString &ModelPath, int Sample_rate, int windows_frame_size, + float Threshold, int 
min_silence_duration_ms, int speech_pad_ms, + int min_speech_duration_ms, float max_speech_duration_s) +{ + init_onnx_model(ModelPath); + threshold = Threshold; + sample_rate = Sample_rate; + sr_per_ms = sample_rate / 1000; + + window_size_samples = windows_frame_size * sr_per_ms; + + min_speech_samples = sr_per_ms * min_speech_duration_ms; + speech_pad_samples = sr_per_ms * speech_pad_ms; + + max_speech_samples = ((float)sample_rate * max_speech_duration_s - + (float)window_size_samples - 2.0f * (float)speech_pad_samples); + + min_silence_samples = sr_per_ms * min_silence_duration_ms; + min_silence_samples_at_max_speech = sr_per_ms * 98; + + input.resize(window_size_samples); + input_node_dims[0] = 1; + input_node_dims[1] = window_size_samples; + + _state.resize(size_state); + sr.resize(1); + sr[0] = sample_rate; +}; diff --git a/src/whisper-utils/silero-vad-onnx.h b/src/whisper-utils/silero-vad-onnx.h new file mode 100644 index 0000000..cb284a4 --- /dev/null +++ b/src/whisper-utils/silero-vad-onnx.h @@ -0,0 +1,115 @@ +#ifndef SILERO_VAD_ONNX_H +#define SILERO_VAD_ONNX_H + +#include +#include +#include +#include + +#ifdef _WIN32 +typedef std::wstring SileroString; +#else +typedef std::string SileroString; +#endif + +class timestamp_t { +public: + int start; + int end; + + // default + parameterized constructor + timestamp_t(int start = -1, int end = -1); + + // assignment operator modifies object, therefore non-const + timestamp_t &operator=(const timestamp_t &a); + + // equality comparison. doesn't modify object. therefore const. 
+ bool operator==(const timestamp_t &a) const; + std::string string(); + +private: + std::string format(const char *fmt, ...); +}; + +class VadIterator { +private: + // OnnxRuntime resources + Ort::Env env; + Ort::SessionOptions session_options; + std::shared_ptr session = nullptr; + Ort::AllocatorWithDefaultOptions allocator; + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); + +private: + void init_engine_threads(int inter_threads, int intra_threads); + void init_onnx_model(const SileroString &model_path); + void reset_states(bool reset_state); + float predict_one(const std::vector &data); + void predict(const std::vector &data); + +public: + void process(const std::vector &input_wav, bool reset_state = true); + void process(const std::vector &input_wav, std::vector &output_wav); + void collect_chunks(const std::vector &input_wav, std::vector &output_wav); + const std::vector get_speech_timestamps() const; + void drop_chunks(const std::vector &input_wav, std::vector &output_wav); + void set_threshold(float threshold_) { this->threshold = threshold_; } + + int64_t get_window_size_samples() const { return window_size_samples; } + +private: + // model config + int64_t window_size_samples; // Assign when init, support 256 512 768 for 8k; 512 1024 1536 for 16k. 
+ int sample_rate; // Assign when init support 16000 or 8000 + int sr_per_ms; // Assign when init, support 8 or 16 + float threshold; + int min_silence_samples; // sr_per_ms * #ms + int min_silence_samples_at_max_speech; // sr_per_ms * #98 + int min_speech_samples; // sr_per_ms * #ms + float max_speech_samples; + int speech_pad_samples; // usually a + int audio_length_samples; + + // model states + bool triggered = false; + unsigned int temp_end = 0; + unsigned int current_sample = 0; + // MAX 4294967295 samples / 8sample per ms / 1000 / 60 = 8947 minutes + int prev_end; + int next_start = 0; + + //Output timestamp + std::vector speeches; + timestamp_t current_speech; + + // Onnx model + // Inputs + std::vector ort_inputs; + + std::vector input_node_names = {"input", "state", "sr"}; + std::vector input; + unsigned int size_state = 2 * 1 * 128; // It's FIXED. + std::vector _state; + std::vector sr; + + int64_t input_node_dims[2] = {}; + const int64_t state_node_dims[3] = {2, 1, 128}; + const int64_t sr_node_dims[1] = {1}; + + // Outputs + std::vector ort_outputs; + std::vector output_node_names = {"output", "stateN"}; + +public: + // Construction + VadIterator(const SileroString &ModelPath, int Sample_rate = 16000, + int windows_frame_size = 32, float Threshold = 0.5, + int min_silence_duration_ms = 0, int speech_pad_ms = 32, + int min_speech_duration_ms = 32, + float max_speech_duration_s = std::numeric_limits::infinity()); + + // Default constructor + VadIterator() = default; +}; + +#endif // SILERO_VAD_ONNX_H diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp new file mode 100644 index 0000000..3e3b002 --- /dev/null +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -0,0 +1,413 @@ +#include +#include +#include + +#include "token-buffer-thread.h" +#include "whisper-utils.h" +#include "transcription-utils.h" + +#include +#include + +#include + +#ifdef _WIN32 +#include +#define SPACE L" " +#define NEWLINE L"\n" +#else 
+#define SPACE " " +#define NEWLINE "\n" +#endif + +TokenBufferThread::TokenBufferThread() noexcept + : gf(nullptr), + numSentences(2), + numPerSentence(30), + maxTime(0), + stop(true), + presentationQueueMutex(), + inputQueueMutex(), + segmentation(SEGMENTATION_TOKEN) +{ +} + +TokenBufferThread::~TokenBufferThread() +{ + stopThread(); +} + +void TokenBufferThread::initialize( + struct transcription_filter_data *gf_, + std::function captionPresentationCallback_, + std::function sentenceOutputCallback_, size_t numSentences_, + size_t numPerSentence_, std::chrono::seconds maxTime_, + TokenBufferSegmentation segmentation_) +{ + this->gf = gf_; + this->captionPresentationCallback = captionPresentationCallback_; + this->sentenceOutputCallback = sentenceOutputCallback_; + this->numSentences = numSentences_; + this->numPerSentence = numPerSentence_; + this->segmentation = segmentation_; + this->maxTime = maxTime_; + this->stop = false; + this->workerThread = std::thread(&TokenBufferThread::monitor, this); + this->lastContributionTime = std::chrono::steady_clock::now(); + this->lastCaptionTime = std::chrono::steady_clock::now(); +} + +void TokenBufferThread::stopThread() +{ + { + std::lock_guard lock(presentationQueueMutex); + stop = true; + } + cv.notify_all(); + if (workerThread.joinable()) { + workerThread.join(); + } +} + +void TokenBufferThread::log_token_vector(const std::vector &tokens) +{ + std::string output; + for (const auto &token : tokens) { + output += token; + } + obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str()); +} + +void TokenBufferThread::addSentenceFromStdString(const std::string &sentence, + TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial) +{ + if (sentence.empty()) { + return; + } +#ifdef _WIN32 + // on windows convert from multibyte to wide char + int count = + MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), NULL, 0); + TokenBufferString sentence_ws(count, 0); + 
MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), &sentence_ws[0], + count); +#else + TokenBufferString sentence_ws = sentence; +#endif + + TokenBufferSentence sentence_for_add; + sentence_for_add.start_time = start_time; + sentence_for_add.end_time = end_time; + + if (this->segmentation == SEGMENTATION_WORD) { + // split the sentence to words + std::vector words; + std::basic_istringstream iss(sentence_ws); + TokenBufferString word_token; + while (iss >> word_token) { + words.push_back(word_token); + } + // add the words to a sentence + for (const auto &word : words) { + sentence_for_add.tokens.push_back({word, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); + } + } else if (this->segmentation == SEGMENTATION_TOKEN) { + // split to characters + std::vector characters; + for (const auto &c : sentence_ws) { + characters.push_back(TokenBufferString(1, c)); + } + // add the characters to a sentence + for (const auto &character : characters) { + sentence_for_add.tokens.push_back({character, is_partial}); + } + } else { + // add the whole sentence as a single token + sentence_for_add.tokens.push_back({sentence_ws, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); + } + addSentence(sentence_for_add); +} + +void TokenBufferThread::addSentence(const TokenBufferSentence &sentence) +{ + std::lock_guard lock(this->inputQueueMutex); + + // add the tokens to the inputQueue + for (const auto &character : sentence.tokens) { + inputQueue.push_back(character); + } + inputQueue.push_back({SPACE, sentence.tokens.back().is_partial}); + + // add to the contribution queue as well + for (const auto &character : sentence.tokens) { + contributionQueue.push_back(character); + } + contributionQueue.push_back({SPACE, sentence.tokens.back().is_partial}); + this->lastContributionTime = std::chrono::steady_clock::now(); +} + +void TokenBufferThread::clear() +{ + { + std::lock_guard lock(inputQueueMutex); + inputQueue.clear(); + 
} + { + std::lock_guard lock(presentationQueueMutex); + presentationQueue.clear(); + } + this->lastCaption = ""; + this->lastCaptionTime = std::chrono::steady_clock::now(); + this->captionPresentationCallback(""); +} + +void TokenBufferThread::monitor() +{ + obs_log(LOG_INFO, "TokenBufferThread::monitor"); + + this->captionPresentationCallback(""); + + while (true) { + std::string caption_out; + + { + std::lock_guard lockPresentation(presentationQueueMutex); + if (stop) { + break; + } + + // condition presentation queue + if (presentationQueue.size() == this->numSentences * this->numPerSentence) { + // pop a whole sentence from the presentation queue front + for (size_t i = 0; i < this->numPerSentence; i++) { + presentationQueue.pop_front(); + } + if (this->segmentation == SEGMENTATION_TOKEN) { + // pop tokens until a space is found + while (!presentationQueue.empty() && + presentationQueue.front().token != SPACE) { + presentationQueue.pop_front(); + } + } + } + + { + std::lock_guard lock(inputQueueMutex); + + if (!inputQueue.empty()) { + // if the input on the inputQueue is partial - first remove all partials + // from the end of the presentation queue + while (!presentationQueue.empty() && + presentationQueue.back().is_partial) { + presentationQueue.pop_back(); + } + + // if there are token on the input queue + // then add to the presentation queue based on the segmentation + if (this->segmentation == SEGMENTATION_SENTENCE) { + // add all the tokens from the input queue to the presentation queue + for (const auto &token : inputQueue) { + presentationQueue.push_back(token); + } + inputQueue.clear(); + } else if (this->segmentation == SEGMENTATION_TOKEN) { + // add one token to the presentation queue + presentationQueue.push_back(inputQueue.front()); + inputQueue.pop_front(); + } else { + // SEGMENTATION_WORD + // skip spaces in the beginning of the input queue + while (!inputQueue.empty() && + inputQueue.front().token == SPACE) { + inputQueue.pop_front(); + } + // 
add one word to the presentation queue + TokenBufferToken word; + while (!inputQueue.empty() && + inputQueue.front().token != SPACE) { + word = inputQueue.front(); + inputQueue.pop_front(); + } + presentationQueue.push_back(word); + } + } + } + + if (presentationQueue.size() > 0) { + // build a caption from the presentation queue in sentences + // with a maximum of numPerSentence tokens/words per sentence + // and a newline between sentences + std::vector sentences(1); + + if (this->segmentation == SEGMENTATION_WORD) { + // add words from the presentation queue to the sentences + // if a sentence is full - start a new one + size_t wordsInSentence = 0; + for (size_t i = 0; i < presentationQueue.size(); i++) { + const auto &word = presentationQueue[i]; + sentences.back() += word.token + SPACE; + wordsInSentence++; + if (wordsInSentence == this->numPerSentence) { + sentences.push_back(TokenBufferString()); + } + } + } else { + // iterate through the presentation queue tokens and build a caption + for (size_t i = 0; i < presentationQueue.size(); i++) { + const auto &token = presentationQueue[i]; + // skip spaces in the beginning of a sentence (tokensInSentence == 0) + if (token.token == SPACE && + sentences.back().length() == 0) { + continue; + } + + sentences.back() += token.token; + if (sentences.back().length() == + this->numPerSentence) { + // if the next character is not a space - this is a broken word + // roll back to the last space, replace it with a newline + size_t lastSpace = + sentences.back().find_last_of( + SPACE); + sentences.push_back(sentences.back().substr( + lastSpace + 1)); + sentences[sentences.size() - 2] = + sentences[sentences.size() - 2] + .substr(0, lastSpace); + } + } + } + + TokenBufferString caption; + // if there are more sentences than numSentences - remove the oldest ones + while (sentences.size() > this->numSentences) { + sentences.erase(sentences.begin()); + } + // if there are less sentences than numSentences - add empty sentences + 
while (sentences.size() < this->numSentences) { + sentences.push_back(TokenBufferString()); + } + // build the caption from the sentences + for (const auto &sentence : sentences) { + if (!sentence.empty()) { + caption += trim(sentence); + } + caption += NEWLINE; + } + +#ifdef _WIN32 + // convert caption to multibyte for obs + int count = WideCharToMultiByte(CP_UTF8, 0, caption.c_str(), + (int)caption.length(), NULL, 0, + NULL, NULL); + caption_out = std::string(count, 0); + WideCharToMultiByte(CP_UTF8, 0, caption.c_str(), + (int)caption.length(), &caption_out[0], count, + NULL, NULL); +#else + caption_out = std::string(caption.begin(), caption.end()); +#endif + } + } + + if (this->stop) { + break; + } + + const auto now = std::chrono::steady_clock::now(); + + // check if enough time passed since last contribution (debounce) + const auto durationSinceLastContribution = + std::chrono::duration_cast( + now - this->lastContributionTime); + if (durationSinceLastContribution > std::chrono::milliseconds(500)) { + if (!lastContributionIsSent) { + // take the contribution queue and send it to the output + TokenBufferString contribution; + for (const auto &token : contributionQueue) { + contribution += token.token; + } + contributionQueue.clear(); +#ifdef _WIN32 + // convert caption to multibyte for obs + int count = WideCharToMultiByte(CP_UTF8, 0, contribution.c_str(), + (int)contribution.length(), NULL, 0, + NULL, NULL); + std::string contribution_out = std::string(count, 0); + WideCharToMultiByte(CP_UTF8, 0, contribution.c_str(), + (int)contribution.length(), + &contribution_out[0], count, NULL, NULL); +#else + std::string contribution_out(contribution.begin(), + contribution.end()); +#endif + + obs_log(gf->log_level, "TokenBufferThread::monitor: output '%s'", + contribution_out.c_str()); + this->sentenceOutputCallback(contribution_out); + lastContributionIsSent = true; + } + } else { + lastContributionIsSent = false; + } + + if (caption_out.empty()) { + // if no caption 
was built, sleep for a while + this->lastCaption = ""; + this->lastCaptionTime = now; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + if (caption_out == lastCaption) { + // if it has been max_time since the last caption - clear the presentation queue + const auto duration = std::chrono::duration_cast( + now - this->lastCaptionTime); + if (this->maxTime.count() > 0) { + if (duration > this->maxTime) { + this->clear(); + } + } + } else { + // emit the caption + this->captionPresentationCallback(caption_out); + this->lastCaption = caption_out; + this->lastCaptionTime = now; + } + + // check the input queue size (iqs), if it's big - sleep less + std::this_thread::sleep_for(std::chrono::milliseconds( + inputQueue.size() > 30 ? getWaitTime(SPEED_FAST) + : inputQueue.size() > 15 ? getWaitTime(SPEED_NORMAL) + : getWaitTime(SPEED_SLOW))); + } + + obs_log(LOG_INFO, "TokenBufferThread::monitor: done"); +} + +int TokenBufferThread::getWaitTime(TokenBufferSpeed speed) const +{ + if (this->segmentation == SEGMENTATION_WORD) { + switch (speed) { + case SPEED_SLOW: + return 200; + case SPEED_NORMAL: + return 150; + case SPEED_FAST: + return 100; + } + } else if (this->segmentation == SEGMENTATION_TOKEN) { + switch (speed) { + case SPEED_SLOW: + return 100; + case SPEED_NORMAL: + return 66; + case SPEED_FAST: + return 33; + } + } + return 1000; +} diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h new file mode 100644 index 0000000..7666669 --- /dev/null +++ b/src/whisper-utils/token-buffer-thread.h @@ -0,0 +1,105 @@ +#ifndef TOKEN_BUFFER_THREAD_H +#define TOKEN_BUFFER_THREAD_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "plugin-support.h" + +#ifdef _WIN32 +typedef std::wstring TokenBufferString; +typedef wchar_t TokenBufferChar; +#else +typedef std::string TokenBufferString; +typedef char TokenBufferChar; +#endif + +struct 
transcription_filter_data; + +enum TokenBufferSegmentation { SEGMENTATION_WORD = 0, SEGMENTATION_TOKEN, SEGMENTATION_SENTENCE }; +enum TokenBufferSpeed { SPEED_SLOW = 0, SPEED_NORMAL, SPEED_FAST }; + +typedef std::chrono::time_point TokenBufferTimePoint; + +inline std::chrono::time_point get_time_point_from_ms(uint64_t ms) +{ + return std::chrono::time_point(std::chrono::milliseconds(ms)); +} + +struct TokenBufferToken { + TokenBufferString token; + bool is_partial; +}; + +struct TokenBufferSentence { + std::vector tokens; + TokenBufferTimePoint start_time; + TokenBufferTimePoint end_time; +}; + +class TokenBufferThread { +public: + // default constructor + TokenBufferThread() noexcept; + + ~TokenBufferThread(); + void initialize(struct transcription_filter_data *gf, + std::function captionPresentationCallback_, + std::function sentenceOutputCallback_, + size_t numSentences_, size_t numTokensPerSentence_, + std::chrono::seconds maxTime_, + TokenBufferSegmentation segmentation_ = SEGMENTATION_TOKEN); + + void addSentenceFromStdString(const std::string &sentence, TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial = false); + void addSentence(const TokenBufferSentence &sentence); + void clear(); + void stopThread(); + + bool isEnabled() const { return !stop; } + + void setNumSentences(size_t numSentences_) { numSentences = numSentences_; } + void setNumPerSentence(size_t numPerSentence_) { numPerSentence = numPerSentence_; } + void setMaxTime(std::chrono::seconds maxTime_) { maxTime = maxTime_; } + void setSegmentation(TokenBufferSegmentation segmentation_) + { + segmentation = segmentation_; + } + +private: + void monitor(); + void log_token_vector(const std::vector &tokens); + int getWaitTime(TokenBufferSpeed speed) const; + struct transcription_filter_data *gf; + std::deque inputQueue; + std::deque presentationQueue; + std::deque contributionQueue; + std::thread workerThread; + std::mutex inputQueueMutex; + std::mutex 
presentationQueueMutex; + std::function captionPresentationCallback; + std::function sentenceOutputCallback; + std::condition_variable cv; + std::chrono::seconds maxTime; + std::atomic stop; + bool newDataAvailable = false; + size_t numSentences; + size_t numPerSentence; + TokenBufferSegmentation segmentation; + // timestamp of the last caption + TokenBufferTimePoint lastCaptionTime; + // timestamp of the last contribution + TokenBufferTimePoint lastContributionTime; + bool lastContributionIsSent = false; + std::string lastCaption; +}; + +#endif diff --git a/src/whisper-utils/vad-processing.cpp b/src/whisper-utils/vad-processing.cpp new file mode 100644 index 0000000..0e9c744 --- /dev/null +++ b/src/whisper-utils/vad-processing.cpp @@ -0,0 +1,377 @@ + +#include + +#include "transcription-filter-data.h" + +#include "vad-processing.h" + +#ifdef _WIN32 +#define NOMINMAX +#include +#endif + +int get_data_from_buf_and_resample(transcription_filter_data *gf, + uint64_t &start_timestamp_offset_ns, + uint64_t &end_timestamp_offset_ns) +{ + uint32_t num_frames_from_infos = 0; + + { + // scoped lock the buffer mutex + std::lock_guard lock(gf->whisper_buf_mutex); + + if (gf->input_buffers[0].size == 0) { + return 1; + } + + obs_log(gf->log_level, + "segmentation: currently %lu bytes in the audio input buffer", + gf->input_buffers[0].size); + + // max number of frames is 10 seconds worth of audio + const size_t max_num_frames = gf->sample_rate * 10; + + // pop all infos from the info buffer and mark the beginning timestamp from the first + // info as the beginning timestamp of the segment + struct transcription_filter_audio_info info_from_buf = {0}; + const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); + while (gf->info_buffer.size >= size_of_audio_info) { + circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); + num_frames_from_infos += info_from_buf.frames; + if (start_timestamp_offset_ns == 0) { + start_timestamp_offset_ns = 
info_from_buf.timestamp_offset_ns; + } + // Check if we're within the needed segment length + if (num_frames_from_infos > max_num_frames) { + // too big, push the last info into the buffer's front where it was + num_frames_from_infos -= info_from_buf.frames; + circlebuf_push_front(&gf->info_buffer, &info_from_buf, + size_of_audio_info); + break; + } + } + // calculate the end timestamp from the info plus the number of frames in the packet + end_timestamp_offset_ns = info_from_buf.timestamp_offset_ns + + info_from_buf.frames * 1000000000 / gf->sample_rate; + + if (start_timestamp_offset_ns > end_timestamp_offset_ns) { + // this may happen when the incoming media has a timestamp reset + // in this case, we should figure out the start timestamp from the end timestamp + // and the number of frames + start_timestamp_offset_ns = + end_timestamp_offset_ns - + num_frames_from_infos * 1000000000 / gf->sample_rate; + } + + for (size_t c = 0; c < gf->channels; c++) { + // zero the rest of copy_buffers + memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); + } + + /* Pop from input circlebuf */ + for (size_t c = 0; c < gf->channels; c++) { + // Push the new data to copy_buffers[c] + circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c], + num_frames_from_infos * sizeof(float)); + } + } + + obs_log(gf->log_level, "found %d frames from info buffer.", num_frames_from_infos); + gf->last_num_frames = num_frames_from_infos; + + { + // resample to 16kHz + float *resampled_16khz[MAX_PREPROC_CHANNELS]; + uint32_t resampled_16khz_frames; + uint64_t ts_offset; + { + ProfileScope("resample"); + audio_resampler_resample(gf->resampler_to_whisper, + (uint8_t **)resampled_16khz, + &resampled_16khz_frames, &ts_offset, + (const uint8_t **)gf->copy_buffers, + (uint32_t)num_frames_from_infos); + } + + circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], + resampled_16khz_frames * sizeof(float)); + obs_log(gf->log_level, + "resampled: %d channels, %d frames, %f ms, 
current size: %lu bytes", + (int)gf->channels, (int)resampled_16khz_frames, + (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, + gf->resampled_buffer.size); + } + + return 0; +} + +vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +{ + // get data from buffer and resample + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + const int ret = get_data_from_buf_and_resample(gf, start_timestamp_offset_ns, + end_timestamp_offset_ns); + if (ret != 0) { + return last_vad_state; + } + + const size_t vad_window_size_samples = gf->vad->get_window_size_samples() * sizeof(float); + const size_t min_vad_buffer_size = vad_window_size_samples * 8; + if (gf->resampled_buffer.size < min_vad_buffer_size) + return last_vad_state; + + size_t vad_num_windows = gf->resampled_buffer.size / vad_window_size_samples; + + std::vector vad_input; + vad_input.resize(vad_num_windows * gf->vad->get_window_size_samples()); + circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad, %d windows, reset state? %s", + vad_input.size(), vad_num_windows, (!last_vad_state.vad_on) ? 
"yes" : "no"); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, !last_vad_state.vad_on); + } + + const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; + const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; + + vad_state current_vad_state = {false, start_ts_offset_ms, end_ts_offset_ms, + last_vad_state.last_partial_segment_end_ts}; + + std::vector stamps = gf->vad->get_speech_timestamps(); + if (stamps.size() == 0) { + obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); + if (last_vad_state.vad_on) { + obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, + VAD_STATE_WAS_ON); + current_vad_state.last_partial_segment_end_ts = 0; + } + + if (gf->enable_audio_chunks_callback) { + audio_chunk_callback(gf, vad_input.data(), vad_input.size(), + VAD_STATE_IS_OFF, + {DETECTION_RESULT_SILENCE, + "[silence]", + current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms, + {}}); + } + + return current_vad_state; + } + + // process vad segments + for (size_t i = 0; i < stamps.size(); i++) { + int start_frame = stamps[i].start; + if (i > 0) { + // if this is not the first segment, start from the end of the previous segment + start_frame = stamps[i - 1].end; + } else { + // take at least 100ms of audio before the first speech segment, if available + start_frame = std::max(0, start_frame - WHISPER_SAMPLE_RATE / 10); + } + + int end_frame = stamps[i].end; + // if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { + // // take at least 100ms of audio after the last speech segment, if available + // end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, + // (int)vad_input.size()); + // } + + const int number_of_frames = end_frame - start_frame; + + // push the data into gf-whisper_buffer + circlebuf_push_back(&gf->whisper_buffer, 
vad_input.data() + start_frame, + number_of_frames * sizeof(float)); + + obs_log(gf->log_level, + "VAD segment %d/%d. pushed %d to %d (%d frames / %lu ms). current size: %lu bytes / %lu frames / %lu ms", + i, (stamps.size() - 1), start_frame, end_frame, number_of_frames, + number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size, + gf->whisper_buffer.size / sizeof(float), + gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE); + + // segment "end" is in the middle of the buffer, send it to inference + if (stamps[i].end < (int)vad_input.size()) { + // new "ending" segment (not up to the end of the buffer) + obs_log(gf->log_level, "VAD segment end -> send to inference"); + // find the end timestamp of the segment + const uint64_t segment_end_ts = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + run_inference_and_callbacks( + gf, last_vad_state.start_ts_offest_ms, segment_end_ts, + last_vad_state.vad_on ? VAD_STATE_WAS_ON : VAD_STATE_WAS_OFF); + current_vad_state.vad_on = false; + current_vad_state.start_ts_offest_ms = current_vad_state.end_ts_offset_ms; + current_vad_state.end_ts_offset_ms = 0; + current_vad_state.last_partial_segment_end_ts = 0; + last_vad_state = current_vad_state; + continue; + } + + // end not reached - speech is ongoing + current_vad_state.vad_on = true; + if (last_vad_state.vad_on) { + obs_log(gf->log_level, + "last vad state was: ON, start ts: %llu, end ts: %llu", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms); + current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms; + } else { + obs_log(gf->log_level, + "last vad state was: OFF, start ts: %llu, end ts: %llu. 
start_ts_offset_ms: %llu, start_frame: %d", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, + start_ts_offset_ms, start_frame); + current_vad_state.start_ts_offest_ms = + start_ts_offset_ms + start_frame * 1000 / WHISPER_SAMPLE_RATE; + } + current_vad_state.end_ts_offset_ms = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + obs_log(gf->log_level, + "end not reached. vad state: ON, start ts: %llu, end ts: %llu", + current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms); + + last_vad_state = current_vad_state; + + // if partial transcription is enabled, check if we should send a partial segment + if (!gf->partial_transcription) { + continue; + } + + // current length of audio in buffer + const uint64_t current_length_ms = + (current_vad_state.end_ts_offset_ms > 0 + ? current_vad_state.end_ts_offset_ms + : current_vad_state.start_ts_offest_ms) - + (current_vad_state.last_partial_segment_end_ts > 0 + ? current_vad_state.last_partial_segment_end_ts + : current_vad_state.start_ts_offest_ms); + obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + current_vad_state.last_partial_segment_end_ts, current_length_ms); + + if (current_length_ms > (uint64_t)gf->partial_latency) { + current_vad_state.last_partial_segment_end_ts = + current_vad_state.end_ts_offset_ms; + // send partial segment to inference + obs_log(gf->log_level, "Partial segment -> send to inference"); + run_inference_and_callbacks(gf, current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms, + VAD_STATE_PARTIAL); + } + } + + return current_vad_state; +} + +vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +{ + // get data from buffer and resample + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + if (get_data_from_buf_and_resample(gf, start_timestamp_offset_ns, + end_timestamp_offset_ns) != 0) { + return last_vad_state; + } + + 
last_vad_state.end_ts_offset_ms = end_timestamp_offset_ns / 1000000; + + // extract the data from the resampled buffer with circlebuf_pop_front into a temp buffer + // and then push it into the whisper buffer + const size_t resampled_buffer_size = gf->resampled_buffer.size; + std::vector temp_buffer; + temp_buffer.resize(resampled_buffer_size); + circlebuf_pop_front(&gf->resampled_buffer, temp_buffer.data(), resampled_buffer_size); + circlebuf_push_back(&gf->whisper_buffer, temp_buffer.data(), resampled_buffer_size); + + obs_log(gf->log_level, "whisper buffer size: %lu bytes", gf->whisper_buffer.size); + + // use last_vad_state timestamps to calculate the duration of the current segment + if (last_vad_state.end_ts_offset_ms - last_vad_state.start_ts_offest_ms >= + (uint64_t)gf->segment_duration) { + obs_log(gf->log_level, "%d seconds worth of audio -> send to inference", + gf->segment_duration); + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, VAD_STATE_WAS_ON); + last_vad_state.start_ts_offest_ms = end_timestamp_offset_ns / 1000000; + last_vad_state.last_partial_segment_end_ts = 0; + return last_vad_state; + } + + // if partial transcription is enabled, check if we should send a partial segment + if (gf->partial_transcription) { + // current length of audio in buffer + const uint64_t current_length_ms = + (last_vad_state.end_ts_offset_ms > 0 ? last_vad_state.end_ts_offset_ms + : last_vad_state.start_ts_offest_ms) - + (last_vad_state.last_partial_segment_end_ts > 0 + ? 
last_vad_state.last_partial_segment_end_ts + : last_vad_state.start_ts_offest_ms); + obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + last_vad_state.last_partial_segment_end_ts, current_length_ms); + + if (current_length_ms > (uint64_t)gf->partial_latency) { + // send partial segment to inference + obs_log(gf->log_level, "Partial segment -> send to inference"); + last_vad_state.last_partial_segment_end_ts = + last_vad_state.end_ts_offset_ms; + + // run vad on the current buffer + std::vector vad_input; + vad_input.resize(gf->whisper_buffer.size / sizeof(float)); + circlebuf_peek_front(&gf->whisper_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad, %.1f ms", + vad_input.size(), + (float)vad_input.size() * 1000.0f / (float)WHISPER_SAMPLE_RATE); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, true); + } + + if (gf->vad->get_speech_timestamps().size() > 0) { + // VAD detected speech in the partial segment + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, + VAD_STATE_PARTIAL); + } else { + // VAD detected silence in the partial segment + obs_log(gf->log_level, "VAD detected silence in partial segment"); + // pop the partial segment from the whisper buffer, save some audio for the next segment + const size_t num_bytes_to_keep = + (WHISPER_SAMPLE_RATE / 4) * sizeof(float); + circlebuf_pop_front(&gf->whisper_buffer, nullptr, + gf->whisper_buffer.size - num_bytes_to_keep); + } + } + } + + return last_vad_state; +} + +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file) +{ + // initialize Silero VAD +#ifdef _WIN32 + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, + strlen(silero_vad_model_file), NULL, 0); + std::wstring silero_vad_model_path(count, 0); + MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, 
strlen(silero_vad_model_file), + &silero_vad_model_path[0], count); + obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); +#else + std::string silero_vad_model_path = silero_vad_model_file; + obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); +#endif + // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py + // for silero vad parameters + gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 32, 0.5f, 100, + 100, 100)); +} diff --git a/src/whisper-utils/vad-processing.h b/src/whisper-utils/vad-processing.h new file mode 100644 index 0000000..996002b --- /dev/null +++ b/src/whisper-utils/vad-processing.h @@ -0,0 +1,18 @@ +#ifndef VAD_PROCESSING_H +#define VAD_PROCESSING_H + +enum VadState { VAD_STATE_WAS_ON = 0, VAD_STATE_WAS_OFF, VAD_STATE_IS_OFF, VAD_STATE_PARTIAL }; +enum VadMode { VAD_MODE_ACTIVE = 0, VAD_MODE_HYBRID, VAD_MODE_DISABLED }; + +struct vad_state { + bool vad_on; + uint64_t start_ts_offest_ms; + uint64_t end_ts_offset_ms; + uint64_t last_partial_segment_end_ts; +}; + +vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state); +vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state); +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file); + +#endif // VAD_PROCESSING_H diff --git a/src/whisper-utils/whisper-language.h b/src/whisper-utils/whisper-language.h new file mode 100644 index 0000000..f9f349c --- /dev/null +++ b/src/whisper-utils/whisper-language.h @@ -0,0 +1,814 @@ +#ifndef WHISPER_LANGUAGE_H +#define WHISPER_LANGUAGE_H + +#include +#include + +static const std::map whisper_available_lang{ + { + "auto", + "Auto detect", + }, + { + "en", + "English", + }, + { + "zh", + "Chinese", + }, + { + "de", + "German", + }, + { + "es", + "Spanish", + }, + { + "ru", + "Russian", + }, + { + "ko", + "Korean", + }, + { + "fr", + "French", + }, + 
{ + "ja", + "Japanese", + }, + { + "pt", + "Portuguese", + }, + { + "tr", + "Turkish", + }, + { + "pl", + "Polish", + }, + { + "ca", + "Catalan", + }, + { + "nl", + "Dutch", + }, + { + "ar", + "Arabic", + }, + { + "sv", + "Swedish", + }, + { + "it", + "Italian", + }, + { + "id", + "Indonesian", + }, + { + "hi", + "Hindi", + }, + { + "fi", + "Finnish", + }, + { + "vi", + "Vietnamese", + }, + { + "he", + "Hebrew", + }, + { + "uk", + "Ukrainian", + }, + { + "el", + "Greek", + }, + { + "ms", + "Malay", + }, + { + "cs", + "Czech", + }, + { + "ro", + "Romanian", + }, + { + "da", + "Danish", + }, + { + "hu", + "Hungarian", + }, + { + "ta", + "Tamil", + }, + { + "no", + "Norwegian", + }, + { + "th", + "Thai", + }, + { + "ur", + "Urdu", + }, + { + "hr", + "Croatian", + }, + { + "bg", + "Bulgarian", + }, + { + "lt", + "Lithuanian", + }, + { + "la", + "Latin", + }, + { + "mi", + "Maori", + }, + { + "ml", + "Malayalam", + }, + { + "cy", + "Welsh", + }, + { + "sk", + "Slovak", + }, + { + "te", + "Telugu", + }, + { + "fa", + "Persian", + }, + { + "lv", + "Latvian", + }, + { + "bn", + "Bengali", + }, + { + "sr", + "Serbian", + }, + { + "az", + "Azerbaijani", + }, + { + "sl", + "Slovenian", + }, + { + "kn", + "Kannada", + }, + { + "et", + "Estonian", + }, + { + "mk", + "Macedonian", + }, + { + "br", + "Breton", + }, + { + "eu", + "Basque", + }, + { + "is", + "Icelandic", + }, + { + "hy", + "Armenian", + }, + { + "ne", + "Nepali", + }, + { + "mn", + "Mongolian", + }, + { + "bs", + "Bosnian", + }, + { + "kk", + "Kazakh", + }, + { + "sq", + "Albanian", + }, + { + "sw", + "Swahili", + }, + { + "gl", + "Galician", + }, + { + "mr", + "Marathi", + }, + { + "pa", + "Punjabi", + }, + { + "si", + "Sinhala", + }, + { + "km", + "Khmer", + }, + { + "sn", + "Shona", + }, + { + "yo", + "Yoruba", + }, + { + "so", + "Somali", + }, + { + "af", + "Afrikaans", + }, + { + "oc", + "Occitan", + }, + { + "ka", + "Georgian", + }, + { + "be", + "Belarusian", + }, + { + "tg", + "Tajik", + }, + { + "sd", + 
"Sindhi", + }, + { + "gu", + "Gujarati", + }, + { + "am", + "Amharic", + }, + { + "yi", + "Yiddish", + }, + { + "lo", + "Lao", + }, + { + "uz", + "Uzbek", + }, + { + "fo", + "Faroese", + }, + { + "ht", + "Haitian", + }, + { + "ps", + "Pashto", + }, + { + "tk", + "Turkmen", + }, + { + "nn", + "Nynorsk", + }, + { + "mt", + "Maltese", + }, + { + "sa", + "Sanskrit", + }, + { + "lb", + "Luxembourgish", + }, + { + "my", + "Myanmar", + }, + { + "bo", + "Tibetan", + }, + { + "tl", + "Tagalog", + }, + { + "mg", + "Malagasy", + }, + { + "as", + "Assamese", + }, + { + "tt", + "Tatar", + }, + { + "haw", + "Hawaiian", + }, + { + "ln", + "Lingala", + }, + { + "ha", + "Hausa", + }, + { + "ba", + "Bashkir", + }, + { + "jw", + "Javanese", + }, + { + "su", + "Sundanese", + }, +}; + +// the reverse map of whisper_available_lang +static const std::map whisper_available_lang_reverse{ + { + "Auto detect", + "auto", + }, + { + "English", + "en", + }, + { + "Chinese", + "zh", + }, + { + "German", + "de", + }, + { + "Spanish", + "es", + }, + { + "Russian", + "ru", + }, + { + "Korean", + "ko", + }, + { + "French", + "fr", + }, + { + "Japanese", + "ja", + }, + { + "Portuguese", + "pt", + }, + { + "Turkish", + "tr", + }, + { + "Polish", + "pl", + }, + { + "Catalan", + "ca", + }, + { + "Dutch", + "nl", + }, + { + "Arabic", + "ar", + }, + { + "Swedish", + "sv", + }, + { + "Italian", + "it", + }, + { + "Indonesian", + "id", + }, + { + "Hindi", + "hi", + }, + { + "Finnish", + "fi", + }, + { + "Vietnamese", + "vi", + }, + { + "Hebrew", + "he", + }, + { + "Ukrainian", + "uk", + }, + { + "Greek", + "el", + }, + { + "Malay", + "ms", + }, + { + "Czech", + "cs", + }, + { + "Romanian", + "ro", + }, + { + "Danish", + "da", + }, + { + "Hungarian", + "hu", + }, + { + "Tamil", + "ta", + }, + { + "Norwegian", + "no", + }, + { + "Thai", + "th", + }, + { + "Urdu", + "ur", + }, + { + "Croatian", + "hr", + }, + { + "Bulgarian", + "bg", + }, + { + "Lithuanian", + "lt", + }, + { + "Latin", + "la", + }, + { + 
"Maori", + "mi", + }, + { + "Malayalam", + "ml", + }, + { + "Welsh", + "cy", + }, + { + "Slovak", + "sk", + }, + { + "Telugu", + "te", + }, + { + "Persian", + "fa", + }, + { + "Latvian", + "lv", + }, + { + "Bengali", + "bn", + }, + { + "Serbian", + "sr", + }, + { + "Azerbaijani", + "az", + }, + { + "Slovenian", + "sl", + }, + { + "Kannada", + "kn", + }, + { + "Estonian", + "et", + }, + { + "Macedonian", + "mk", + }, + { + "Breton", + "br", + }, + { + "Basque", + "eu", + }, + { + "Icelandic", + "is", + }, + { + "Armenian", + "hy", + }, + { + "Nepali", + "ne", + }, + { + "Mongolian", + "mn", + }, + { + "Bosnian", + "bs", + }, + { + "Kazakh", + "kk", + }, + { + "Albanian", + "sq", + }, + { + "Swahili", + "sw", + }, + { + "Galician", + "gl", + }, + { + "Marathi", + "mr", + }, + { + "Punjabi", + "pa", + }, + { + "Sinhala", + "si", + }, + { + "Khmer", + "km", + }, + { + "Shona", + "sn", + }, + { + "Yoruba", + "yo", + }, + { + "Somali", + "so", + }, + { + "Afrikaans", + "af", + }, + { + "Occitan", + "oc", + }, + { + "Georgian", + "ka", + }, + { + "Belarusian", + "be", + }, + { + "Tajik", + "tg", + }, + { + "Sindhi", + "sd", + }, + { + "Gujarati", + "gu", + }, + { + "Amharic", + "am", + }, + { + "Yiddish", + "yi", + }, + { + "Lao", + "lo", + }, + { + "Uzbek", + "uz", + }, + { + "Faroese", + "fo", + }, + { + "Haitian", + "ht", + }, + { + "Pashto", + "ps", + }, + { + "Turkmen", + "tk", + }, + { + "Nynorsk", + "nn", + }, + { + "Maltese", + "mt", + }, + { + "Sanskrit", + "sa", + }, + { + "Luxembourgish", + "lb", + }, + { + "Myanmar", + "my", + }, + { + "Tibetan", + "bo", + }, + { + "Tagalog", + "tl", + }, + { + "Malagasy", + "mg", + }, + { + "Assamese", + "as", + }, + { + "Tatar", + "tt", + }, + { + "Hawaiian", + "haw", + }, + { + "Lingala", + "ln", + }, + { + "Hausa", + "ha", + }, + { + "Bashkir", + "ba", + }, + { + "Javanese", + "jw", + }, + { + "Sundanese", + "su", + }, +}; + +#endif // WHISPER_LANGUAGE_H diff --git a/src/whisper-utils/whisper-model-utils.cpp 
b/src/whisper-utils/whisper-model-utils.cpp new file mode 100644 index 0000000..8985a30 --- /dev/null +++ b/src/whisper-utils/whisper-model-utils.cpp @@ -0,0 +1,142 @@ +#ifdef _WIN32 +#define NOMINMAX +#endif + +#include + +#include "whisper-utils.h" +#include "whisper-processing.h" +#include "plugin-support.h" +#include "model-utils/model-downloader.h" + +void update_whisper_model(struct transcription_filter_data *gf) +{ + if (gf->context == nullptr) { + obs_log(LOG_ERROR, "obs_source_t context is null"); + return; + } + + obs_data_t *s = obs_source_get_settings(gf->context); + if (s == nullptr) { + obs_log(LOG_ERROR, "obs_data_t settings is null"); + return; + } + + // Get settings from context + std::string new_model_path = obs_data_get_string(s, "whisper_model_path") != nullptr + ? obs_data_get_string(s, "whisper_model_path") + : ""; + std::string external_model_file_path = + obs_data_get_string(s, "whisper_model_path_external") != nullptr + ? obs_data_get_string(s, "whisper_model_path_external") + : ""; + const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps"); + obs_data_release(s); + + // update the whisper model path + + const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos; + + if (!is_external_model && new_model_path.empty()) { + obs_log(LOG_WARNING, "Whisper model path is empty"); + return; + } + if (is_external_model && external_model_file_path.empty()) { + obs_log(LOG_WARNING, "External model file path is empty"); + return; + } + + char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx"); + if (silero_vad_model_file == nullptr) { + obs_log(LOG_ERROR, "Cannot find Silero VAD model file"); + return; + } + obs_log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file); + std::string silero_vad_model_file_str = std::string(silero_vad_model_file); + bfree(silero_vad_model_file); + + if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path || 
+ is_external_model) { + + if (gf->whisper_model_path != new_model_path) { + // model path changed + obs_log(gf->log_level, "model path changed from %s to %s", + gf->whisper_model_path.c_str(), new_model_path.c_str()); + + // check if this is loading the initial model or a switch + gf->whisper_model_loaded_new = !gf->whisper_model_path.empty(); + } + + // check if the new model is external file + if (!is_external_model) { + // new model is not external file + shutdown_whisper_thread(gf); + + if (models_info.count(new_model_path) == 0) { + obs_log(LOG_WARNING, "Model '%s' does not exist", + new_model_path.c_str()); + return; + } + + const ModelInfo &model_info = models_info[new_model_path]; + + // check if the model exists, if not, download it + std::string model_file_found = find_model_bin_file(model_info); + if (model_file_found == "") { + obs_log(LOG_WARNING, "Whisper model does not exist"); + download_model_with_ui_dialog( + model_info, + [gf, new_model_path, silero_vad_model_file_str]( + int download_status, const std::string &path) { + if (download_status == 0) { + obs_log(LOG_INFO, + "Model download complete"); + gf->whisper_model_path = new_model_path; + start_whisper_thread_with_path( + gf, path, + silero_vad_model_file_str.c_str()); + } else { + obs_log(LOG_ERROR, "Model download failed"); + } + }); + } else { + // Model exists, just load it + gf->whisper_model_path = new_model_path; + start_whisper_thread_with_path(gf, model_file_found, + silero_vad_model_file_str.c_str()); + } + } else { + // new model is external file, get file location from file property + if (external_model_file_path.empty()) { + obs_log(LOG_WARNING, "External model file path is empty"); + } else { + // check if the external model file is not currently loaded + if (gf->whisper_model_file_currently_loaded == + external_model_file_path) { + obs_log(LOG_INFO, "External model file is already loaded"); + return; + } else { + shutdown_whisper_thread(gf); + gf->whisper_model_path = 
new_model_path; + start_whisper_thread_with_path( + gf, external_model_file_path, + silero_vad_model_file_str.c_str()); + } + } + } + } else { + // model path did not change + obs_log(gf->log_level, "Model path did not change: %s == %s", + gf->whisper_model_path.c_str(), new_model_path.c_str()); + } + + if (new_dtw_timestamps != gf->enable_token_ts_dtw) { + // dtw_token_timestamps changed + obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d", + gf->enable_token_ts_dtw, new_dtw_timestamps); + gf->enable_token_ts_dtw = new_dtw_timestamps; + shutdown_whisper_thread(gf); + start_whisper_thread_with_path(gf, gf->whisper_model_path, + silero_vad_model_file_str.c_str()); + } +} diff --git a/src/whisper-utils/whisper-model-utils.h b/src/whisper-utils/whisper-model-utils.h new file mode 100644 index 0000000..68c649c --- /dev/null +++ b/src/whisper-utils/whisper-model-utils.h @@ -0,0 +1,10 @@ +#ifndef WHISPER_MODEL_UTILS_H +#define WHISPER_MODEL_UTILS_H + +#include + +#include "transcription-filter-data.h" + +void update_whisper_model(struct transcription_filter_data *gf); + +#endif // WHISPER_MODEL_UTILS_H diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp new file mode 100644 index 0000000..3518edf --- /dev/null +++ b/src/whisper-utils/whisper-processing.cpp @@ -0,0 +1,407 @@ +#include + +#include + +#include + +#include "plugin-support.h" +#include "transcription-filter-data.h" +#include "whisper-processing.h" +#include "whisper-utils.h" +#include "transcription-utils.h" + +#ifdef _WIN32 +#include +#define NOMINMAX +#include +#endif + +#include "model-utils/model-find-utils.h" +#include "vad-processing.h" + +#include +#include +#include + +struct whisper_context *init_whisper_context(const std::string &model_path_in, + struct transcription_filter_data *gf) +{ + std::string model_path = model_path_in; + + obs_log(LOG_INFO, "Loading whisper model from %s", model_path.c_str()); + + if 
(std::filesystem::is_directory(model_path)) { + obs_log(LOG_INFO, + "Model path is a directory, not a file, looking for .bin file in folder"); + // look for .bin file + const std::string model_bin_file = find_bin_file_in_folder(model_path); + if (model_bin_file.empty()) { + obs_log(LOG_ERROR, "Model bin file not found in folder: %s", + model_path.c_str()); + return nullptr; + } + model_path = model_bin_file; + } + + whisper_log_set( + [](enum ggml_log_level level, const char *text, void *user_data) { + UNUSED_PARAMETER(level); + struct transcription_filter_data *ctx = + static_cast(user_data); + // remove trailing newline + char *text_copy = bstrdup(text); + text_copy[strcspn(text_copy, "\n")] = 0; + obs_log(ctx->log_level, "Whisper: %s", text_copy); + bfree(text_copy); + }, + gf); + + struct whisper_context_params cparams = whisper_context_default_params(); +#ifdef LOCALVOCAL_WITH_CUDA + cparams.use_gpu = true; + obs_log(LOG_INFO, "Using CUDA GPU for inference, device %d", cparams.gpu_device); +#elif defined(LOCALVOCAL_WITH_HIPBLAS) + cparams.use_gpu = true; + obs_log(LOG_INFO, "Using hipBLAS for inference"); +#elif defined(__APPLE__) + cparams.use_gpu = true; + obs_log(LOG_INFO, "Using Metal/CoreML for inference"); +#else + cparams.use_gpu = false; + obs_log(LOG_INFO, "Using CPU for inference"); +#endif + + cparams.dtw_token_timestamps = gf->enable_token_ts_dtw; + if (gf->enable_token_ts_dtw) { + obs_log(LOG_INFO, "DTW token timestamps enabled"); + cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN; + // cparams.dtw_n_top = 4; + } else { + obs_log(LOG_INFO, "DTW token timestamps disabled"); + cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; + } + + struct whisper_context *ctx = nullptr; + try { +#ifdef _WIN32 + // convert model path UTF8 to wstring (wchar_t) for whisper + int count = MultiByteToWideChar(CP_UTF8, 0, model_path.c_str(), + (int)model_path.length(), NULL, 0); + std::wstring model_path_ws(count, 0); + MultiByteToWideChar(CP_UTF8, 0, 
model_path.c_str(), (int)model_path.length(), + &model_path_ws[0], count); + + // Read model into buffer + std::ifstream modelFile(model_path_ws, std::ios::binary); + if (!modelFile.is_open()) { + obs_log(LOG_ERROR, "Failed to open whisper model file %s", + model_path.c_str()); + return nullptr; + } + modelFile.seekg(0, std::ios::end); + const size_t modelFileSize = modelFile.tellg(); + modelFile.seekg(0, std::ios::beg); + std::vector modelBuffer(modelFileSize); + modelFile.read(modelBuffer.data(), modelFileSize); + modelFile.close(); + + // Initialize whisper + ctx = whisper_init_from_buffer_with_params(modelBuffer.data(), modelFileSize, + cparams); +#else + ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); +#endif + } catch (const std::exception &e) { + obs_log(LOG_ERROR, "Exception while loading whisper model: %s", e.what()); + return nullptr; + } + if (ctx == nullptr) { + obs_log(LOG_ERROR, "Failed to load whisper model"); + return nullptr; + } + + obs_log(LOG_INFO, "Whisper model loaded: %s", whisper_print_system_info()); + return ctx; +} + +struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf, + const float *pcm32f_data_, + size_t pcm32f_num_samples, uint64_t t0 = 0, + uint64_t t1 = 0, + int vad_state = VAD_STATE_WAS_OFF) +{ + if (gf == nullptr) { + obs_log(LOG_ERROR, "run_whisper_inference: gf is null"); + return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; + } + + if (pcm32f_data_ == nullptr || pcm32f_num_samples == 0) { + obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0"); + return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; + } + + // if the time difference between t0 and t1 is less than 50 ms - skip + if (t1 - t0 < 50) { + obs_log(gf->log_level, + "Time difference between t0 and t1 is less than 50 ms, skipping"); + return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; + } + + obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__, + 
int(pcm32f_num_samples), float(pcm32f_num_samples) / WHISPER_SAMPLE_RATE, + gf->whisper_params.n_threads); + + bool should_free_buffer = false; + float *pcm32f_data = (float *)pcm32f_data_; + size_t pcm32f_size = pcm32f_num_samples; + + // incoming duration in ms + const uint64_t incoming_duration_ms = + (uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE); + + if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) { + obs_log(gf->log_level, + "Speech segment is less than 1 second, padding with white noise to 1 second"); + const size_t new_size = (size_t)(1.01f * (float)(WHISPER_SAMPLE_RATE)); + // create a new buffer and copy the data to it in the middle + pcm32f_data = (float *)bzalloc(new_size * sizeof(float)); + + // add low volume white noise + const float noise_level = 0.01f; + for (size_t i = 0; i < new_size; ++i) { + pcm32f_data[i] = + noise_level * ((float)rand() / (float)RAND_MAX * 2.0f - 1.0f); + } + + memcpy(pcm32f_data + (new_size - pcm32f_num_samples) / 2, pcm32f_data_, + pcm32f_num_samples * sizeof(float)); + pcm32f_size = new_size; + should_free_buffer = true; + } + + // duration in ms + const uint64_t whisper_duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE); + + std::lock_guard lock(gf->whisper_ctx_mutex); + if (gf->whisper_context == nullptr) { + obs_log(LOG_WARNING, "whisper context is null"); + return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; + } + + if (gf->n_context_sentences > 0 && !gf->last_transcription_sentence.empty()) { + // set the initial prompt to the last transcription sentences (concatenated) + std::string initial_prompt = gf->last_transcription_sentence[0]; + for (size_t i = 1; i < gf->last_transcription_sentence.size(); ++i) { + initial_prompt += " " + gf->last_transcription_sentence[i]; + } + gf->whisper_params.initial_prompt = initial_prompt.c_str(); + obs_log(gf->log_level, "Initial prompt: %s", gf->whisper_params.initial_prompt); + } + + // run the inference + int whisper_full_result = -1; + 
gf->whisper_params.duration_ms = (int)(whisper_duration_ms); + try { + whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params, + pcm32f_data, (int)pcm32f_size); + } catch (const std::exception &e) { + obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what()); + whisper_free(gf->whisper_context); + gf->whisper_context = nullptr; + if (should_free_buffer) { + bfree(pcm32f_data); + } + return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; + } + if (should_free_buffer) { + bfree(pcm32f_data); + } + + std::string language = gf->whisper_params.language; + if (gf->whisper_params.language == nullptr || strlen(gf->whisper_params.language) == 0 || + strcmp(gf->whisper_params.language, "auto") == 0) { + int lang_id = whisper_lang_auto_detect(gf->whisper_context, 0, 1, nullptr); + language = whisper_lang_str(lang_id); + obs_log(gf->log_level, "Detected language: %s", language.c_str()); + } + + if (whisper_full_result != 0) { + obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result); + return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; + } + + float sentence_p = 0.0f; + std::string text = ""; + std::string tokenIds = ""; + std::vector tokens; + for (int n_segment = 0; n_segment < whisper_full_n_segments(gf->whisper_context); + ++n_segment) { + const int n_tokens = whisper_full_n_tokens(gf->whisper_context, n_segment); + for (int j = 0; j < n_tokens; ++j) { + // get token + whisper_token_data token = + whisper_full_get_token_data(gf->whisper_context, n_segment, j); + const std::string token_str = + whisper_token_to_str(gf->whisper_context, token.id); + bool keep = true; + // if the token starts with '[' and ends with ']', don't keep it + if (token_str[0] == '[' && token_str[token_str.size() - 1] == ']') { + keep = false; + } + // if this is a special token, don't keep it + if (token.id >= 50256) { + keep = false; + } + // if the second to last token is .id == 13 ('.'), don't keep it + if (j == n_tokens - 2 
&& token.id == 13) { + keep = false; + } + // token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json + if (token.id > 50365 && token.id <= 51865) { + const float time = ((float)token.id - 50365.0f) * 0.02f; + const float duration_s = (float)incoming_duration_ms / 1000.0f; + const float ratio = time / duration_s; + obs_log(gf->log_level, + "Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f. Threshold %.2f", + token.id, time, duration_s, ratio, + gf->duration_filter_threshold); + if (ratio > gf->duration_filter_threshold) { + // ratio is too high, skip this detection + obs_log(gf->log_level, + "Time token ratio too high, skipping"); + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; + } + keep = false; + } + + if (keep) { + sentence_p += token.p; + text += token_str; + tokens.push_back(token); + } + obs_log(gf->log_level, "S %d, T %2d: %5d\t%s\tp: %.3f [keep: %d]", + n_segment, j, token.id, token_str.c_str(), token.p, keep); + } + } + sentence_p /= (float)tokens.size(); + if (sentence_p < gf->sentence_psum_accept_thresh) { + obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping", + sentence_p, gf->sentence_psum_accept_thresh); + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; + } + + obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str()); + + if (gf->log_words) { + obs_log(LOG_INFO, "[%s --> %s]%s(%.3f) %s", to_timestamp(t0).c_str(), + to_timestamp(t1).c_str(), vad_state == VAD_STATE_PARTIAL ? "P" : " ", + sentence_p, text.c_str()); + } + + if (text.empty() || text == "." || text == " " || text == "\n") { + return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; + } + + return {vad_state == VAD_STATE_PARTIAL ? 
DETECTION_RESULT_PARTIAL : DETECTION_RESULT_SPEECH, + text, + t0, + t1, + tokens, + language}; +} + +void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms, + uint64_t end_offset_ms, int vad_state) +{ + // get the data from the entire whisper buffer + // add 50ms of silence to the beginning and end of the buffer + const size_t pcm32f_size = gf->whisper_buffer.size / sizeof(float); + const size_t pcm32f_size_with_silence = pcm32f_size + 2 * WHISPER_SAMPLE_RATE / 100; + // allocate a new buffer and copy the data to it + float *pcm32f_data = (float *)bzalloc(pcm32f_size_with_silence * sizeof(float)); + if (vad_state == VAD_STATE_PARTIAL) { + // peek instead of pop, since this is a partial run that keeps the data in the buffer + circlebuf_peek_front(&gf->whisper_buffer, pcm32f_data + WHISPER_SAMPLE_RATE / 100, + pcm32f_size * sizeof(float)); + } else { + circlebuf_pop_front(&gf->whisper_buffer, pcm32f_data + WHISPER_SAMPLE_RATE / 100, + pcm32f_size * sizeof(float)); + } + + struct DetectionResultWithText inference_result = + run_whisper_inference(gf, pcm32f_data, pcm32f_size_with_silence, start_offset_ms, + end_offset_ms, vad_state); + // output inference result to a text source + set_text_callback(gf, inference_result); + + if (gf->enable_audio_chunks_callback && vad_state != VAD_STATE_PARTIAL) { + audio_chunk_callback(gf, pcm32f_data, pcm32f_size_with_silence, vad_state, + inference_result); + } + + // free the buffer + bfree(pcm32f_data); +} + +void whisper_loop(void *data) +{ + if (data == nullptr) { + obs_log(LOG_ERROR, "whisper_loop: data is null"); + return; + } + + struct transcription_filter_data *gf = + static_cast(data); + + obs_log(gf->log_level, "Starting whisper thread"); + + vad_state current_vad_state = {false, now_ms(), 0, 0}; + + const char *whisper_loop_name = "Whisper loop"; + profile_register_root(whisper_loop_name, 50 * 1000 * 1000); + + // Thread main loop + while (true) { + ProfileScope(whisper_loop_name); + { + 
ProfileScope("lock whisper ctx"); + std::lock_guard lock(gf->whisper_ctx_mutex); + ProfileScope("locked whisper ctx"); + if (gf->whisper_context == nullptr) { + obs_log(LOG_WARNING, "Whisper context is null, exiting thread"); + break; + } + } + + if (gf->vad_mode == VAD_MODE_HYBRID) { + current_vad_state = hybrid_vad_segmentation(gf, current_vad_state); + } else if (gf->vad_mode == VAD_MODE_ACTIVE) { + current_vad_state = vad_based_segmentation(gf, current_vad_state); + } + + if (!gf->cleared_last_sub) { + // check if we should clear the current sub depending on the minimum subtitle duration + uint64_t now = now_ms(); + if ((now - gf->last_sub_render_time) > gf->max_sub_duration) { + // clear the current sub, call the callback with an empty string + obs_log(gf->log_level, + "Clearing current subtitle. now: %lu ms, last: %lu ms", now, + gf->last_sub_render_time); + clear_current_caption(gf); + } + } + + if (gf->input_cv.has_value()) + gf->input_cv->notify_one(); + + // Sleep using the condition variable wshiper_thread_cv + // This will wake up the thread if there is new data in the input buffer + // or if the whisper context is null + std::unique_lock lock(gf->whisper_ctx_mutex); + if (gf->input_buffers->size == 0) { + gf->wshiper_thread_cv.wait_for(lock, std::chrono::milliseconds(50)); + } + } + + obs_log(gf->log_level, "Exiting whisper thread"); +} diff --git a/src/whisper-utils/whisper-processing.h b/src/whisper-utils/whisper-processing.h new file mode 100644 index 0000000..a00f7cb --- /dev/null +++ b/src/whisper-utils/whisper-processing.h @@ -0,0 +1,38 @@ +#ifndef WHISPER_PROCESSING_H +#define WHISPER_PROCESSING_H + +#include + +// buffer size in msec +#define DEFAULT_BUFFER_SIZE_MSEC 3000 +// overlap in msec +#define DEFAULT_OVERLAP_SIZE_MSEC 125 +#define MAX_OVERLAP_SIZE_MSEC 1000 +#define MIN_OVERLAP_SIZE_MSEC 125 +#define MAX_MS_WORK_BUFFER 11000 + +enum DetectionResult { + DETECTION_RESULT_UNKNOWN = 0, + DETECTION_RESULT_SILENCE = 1, + 
DETECTION_RESULT_SPEECH = 2, + DETECTION_RESULT_SUPPRESSED = 3, + DETECTION_RESULT_NO_INFERENCE = 4, + DETECTION_RESULT_PARTIAL = 5, +}; + +struct DetectionResultWithText { + DetectionResult result; + std::string text; + uint64_t start_timestamp_ms; + uint64_t end_timestamp_ms; + std::vector tokens; + std::string language; +}; + +void whisper_loop(void *data); +struct whisper_context *init_whisper_context(const std::string &model_path, + struct transcription_filter_data *gf); +void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms, + uint64_t end_offset_ms, int vad_state); + +#endif // WHISPER_PROCESSING_H diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp new file mode 100644 index 0000000..84f3b0a --- /dev/null +++ b/src/whisper-utils/whisper-utils.cpp @@ -0,0 +1,161 @@ +#include "whisper-utils.h" +#include "plugin-support.h" +#include "model-utils/model-downloader.h" +#include "whisper-processing.h" +#include "vad-processing.h" + +#include + +void shutdown_whisper_thread(struct transcription_filter_data *gf) +{ + obs_log(gf->log_level, "shutdown_whisper_thread"); + if (gf->whisper_context != nullptr) { + // acquire the mutex before freeing the context + std::lock_guard lock(gf->whisper_ctx_mutex); + whisper_free(gf->whisper_context); + gf->whisper_context = nullptr; + gf->wshiper_thread_cv.notify_all(); + } + if (gf->whisper_thread.joinable()) { + gf->whisper_thread.join(); + } + if (!gf->whisper_model_path.empty()) { + gf->whisper_model_path = ""; + } +} + +void start_whisper_thread_with_path(struct transcription_filter_data *gf, + const std::string &whisper_model_path, + const char *silero_vad_model_file) +{ + obs_log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s", + whisper_model_path.c_str(), silero_vad_model_file); + std::lock_guard lock(gf->whisper_ctx_mutex); + if (gf->whisper_context != nullptr) { + obs_log(LOG_ERROR, "cannot init whisper: whisper_context is 
not null"); + return; + } + + // initialize Silero VAD + initialize_vad(gf, silero_vad_model_file); + + obs_log(gf->log_level, "Create whisper context"); + gf->whisper_context = init_whisper_context(whisper_model_path, gf); + if (gf->whisper_context == nullptr) { + obs_log(LOG_ERROR, "Failed to initialize whisper context"); + return; + } + gf->whisper_model_file_currently_loaded = whisper_model_path; + std::thread new_whisper_thread(whisper_loop, gf); + gf->whisper_thread.swap(new_whisper_thread); +} + +// Finds start of 2-token overlap between two sequences of tokens +// Returns a pair of indices of the first overlapping tokens in the two sequences +// If no overlap is found, the function returns {-1, -1} +// Allows for a single token mismatch in the overlap +std::pair findStartOfOverlap(const std::vector &seq1, + const std::vector &seq2) +{ + if (seq1.empty() || seq2.empty() || seq1.size() == 1 || seq2.size() == 1) { + return {-1, -1}; + } + for (size_t i = seq1.size() - 2; i >= seq1.size() / 2; --i) { + for (size_t j = 0; j < seq2.size() - 1; ++j) { + if (seq1[i].id == seq2[j].id) { + // Check if the next token in both sequences is the same + if (seq1[i + 1].id == seq2[j + 1].id) { + return {i, j}; + } + // 1-skip check on seq1 + if (i + 2 < seq1.size() && seq1[i + 2].id == seq2[j + 1].id) { + return {i, j}; + } + // 1-skip check on seq2 + if (j + 2 < seq2.size() && seq1[i + 1].id == seq2[j + 2].id) { + return {i, j}; + } + } + } + } + return {-1, -1}; +} + +// Function to reconstruct a whole sentence from two sentences using overlap info +// If no overlap is found, the function returns the concatenation of the two sequences +std::vector reconstructSentence(const std::vector &seq1, + const std::vector &seq2) +{ + auto overlap = findStartOfOverlap(seq1, seq2); + std::vector reconstructed; + + if (overlap.first == -1 || overlap.second == -1) { + if (seq1.empty() && seq2.empty()) { + return reconstructed; + } + if (seq1.empty()) { + return seq2; + } + if 
(seq2.empty()) { + return seq1; + } + + // Return concat of seq1 and seq2 if no overlap found + // check if the last token of seq1 == the first token of seq2 + if (seq1.back().id == seq2.front().id) { + // don't add the last token of seq1 + reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1); + reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end()); + } else if (seq2.size() > 1ull && seq1.back().id == seq2[1].id) { + // check if the last token of seq1 == the second token of seq2 + // don't add the last token of seq1 + reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 1); + // don't add the first token of seq2 + reconstructed.insert(reconstructed.end(), seq2.begin() + 1, seq2.end()); + } else if (seq1.size() > 1ull && seq1[seq1.size() - 2].id == seq2.front().id) { + // check if the second to last token of seq1 == the first token of seq2 + // don't add the last two tokens of seq1 + reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end() - 2); + reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end()); + } else { + // add all tokens of seq1 + reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.end()); + reconstructed.insert(reconstructed.end(), seq2.begin(), seq2.end()); + } + return reconstructed; + } + + // Add tokens from the first sequence up to the overlap + reconstructed.insert(reconstructed.end(), seq1.begin(), seq1.begin() + overlap.first); + + // Determine the length of the overlap + size_t overlapLength = 0; + while (overlap.first + overlapLength < seq1.size() && + overlap.second + overlapLength < seq2.size() && + seq1[overlap.first + overlapLength].id == seq2[overlap.second + overlapLength].id) { + overlapLength++; + } + + // Add overlapping tokens + reconstructed.insert(reconstructed.end(), seq1.begin() + overlap.first, + seq1.begin() + overlap.first + overlapLength); + + // Add remaining tokens from the second sequence + reconstructed.insert(reconstructed.end(), 
seq2.begin() + overlap.second + overlapLength, + seq2.end()); + + return reconstructed; +} + +std::string to_timestamp(uint64_t t_ms_offset) +{ + uint64_t sec = t_ms_offset / 1000; + uint64_t msec = t_ms_offset - sec * 1000; + uint64_t min = sec / 60; + sec = sec - min * 60; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d.%03d", (int)min, (int)sec, (int)msec); + + return std::string(buf); +} diff --git a/src/whisper-utils/whisper-utils.h b/src/whisper-utils/whisper-utils.h new file mode 100644 index 0000000..c62168b --- /dev/null +++ b/src/whisper-utils/whisper-utils.h @@ -0,0 +1,25 @@ +#ifndef WHISPER_UTILS_H +#define WHISPER_UTILS_H + +#include "transcription-filter-data.h" + +#include + +void shutdown_whisper_thread(struct transcription_filter_data *gf); +void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path, + const char *silero_vad_model_file); + +std::pair findStartOfOverlap(const std::vector &seq1, + const std::vector &seq2); +std::vector reconstructSentence(const std::vector &seq1, + const std::vector &seq2); + +/** + * @brief Convert a timestamp in milliseconds to a string in the format "MM:SS.sss" . 
+ * Taken from https://github.com/ggerganov/whisper.cpp/blob/master/examples/stream/stream.cpp + * @param t_ms_offset Timestamp in milliseconds (offset from the beginning of the stream) + * @return std::string Timestamp in the format "MM:SS.sss" + */ +std::string to_timestamp(uint64_t t_ms_offset); + +#endif /* WHISPER_UTILS_H */ From 674b1d3ecba6a1197969e424698f922a84318405 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Fri, 13 Sep 2024 22:05:14 -0400 Subject: [PATCH 02/12] Add build_x64/ directory to .gitignore --- .gitignore | 1 + CMakeLists.txt | 125 +++++---- README.md | 172 ++++++++---- cmake/BuildCTranslate2.cmake | 5 + cmake/BuildICU.cmake | 5 + cmake/BuildMyCurl.cmake | 1 + cmake/BuildSentencepiece.cmake | 6 + cmake/BuildWhispercpp.cmake | 10 + cmake/FetchLibav.cmake | 80 ++++++ cmake/FetchOnnxruntime.cmake | 18 +- cmake/LocaalSDKConfig.cmake.in | 23 ++ examples/CMakeLists.txt | 5 + src/model-utils/model-downloader-ui.cpp | 256 ------------------ src/model-utils/model-downloader-ui.h | 61 ----- src/modules/core/CMakeLists.txt | 34 +++ .../core/include}/model-downloader-types.h | 0 .../core/include}/model-downloader.h | 0 .../core/include}/model-find-utils.h | 0 .../core/src}/model-downloader.cpp | 0 .../core/src}/model-find-utils.cpp | 0 .../core/src}/model-infos.cpp | 0 src/modules/transcription/CMakeLists.txt | 38 +++ .../transcription/include}/silero-vad-onnx.h | 0 .../include}/token-buffer-thread.h | 6 +- .../include/transcription-context.h} | 10 +- .../include}/transcription-utils.h | 0 .../transcription/include}/vad-processing.h | 6 +- .../transcription/include}/whisper-language.h | 0 .../include}/whisper-model-utils.h | 2 +- .../include}/whisper-processing.h | 4 +- .../transcription/include}/whisper-utils.h | 4 +- .../transcription/src}/silero-vad-onnx.cpp | 0 .../src}/token-buffer-thread.cpp | 2 +- .../src}/transcription-utils.cpp | 0 .../transcription/src}/vad-processing.cpp | 0 .../src}/whisper-model-utils.cpp | 2 +- 
.../transcription/src}/whisper-processing.cpp | 14 +- .../transcription/src}/whisper-utils.cpp | 4 +- src/modules/translation/CMakeLists.txt | 35 +++ .../translation/include}/language_codes.h | 0 .../include}/translation-includes.h | 0 .../include}/translation-language-utils.h | 0 .../translation/include}/translation-utils.h | 2 +- .../translation/include}/translation.h | 10 +- .../translation/src}/language_codes.cpp | 0 .../src}/translation-language-utils.cpp | 0 .../translation/src}/translation-utils.cpp | 2 +- .../translation/src}/translation.cpp | 2 +- 48 files changed, 489 insertions(+), 456 deletions(-) create mode 100644 cmake/FetchLibav.cmake create mode 100644 cmake/LocaalSDKConfig.cmake.in delete mode 100644 src/model-utils/model-downloader-ui.cpp delete mode 100644 src/model-utils/model-downloader-ui.h create mode 100644 src/modules/core/CMakeLists.txt rename src/{model-utils => modules/core/include}/model-downloader-types.h (100%) rename src/{model-utils => modules/core/include}/model-downloader.h (100%) rename src/{model-utils => modules/core/include}/model-find-utils.h (100%) rename src/{model-utils => modules/core/src}/model-downloader.cpp (100%) rename src/{model-utils => modules/core/src}/model-find-utils.cpp (100%) rename src/{model-utils => modules/core/src}/model-infos.cpp (100%) create mode 100644 src/modules/transcription/CMakeLists.txt rename src/{whisper-utils => modules/transcription/include}/silero-vad-onnx.h (100%) rename src/{whisper-utils => modules/transcription/include}/token-buffer-thread.h (95%) rename src/{transcription-filter-data.h => modules/transcription/include/transcription-context.h} (92%) rename src/{ => modules/transcription/include}/transcription-utils.h (100%) rename src/{whisper-utils => modules/transcription/include}/vad-processing.h (57%) rename src/{whisper-utils => modules/transcription/include}/whisper-language.h (100%) rename src/{whisper-utils => modules/transcription/include}/whisper-model-utils.h (70%) rename 
src/{whisper-utils => modules/transcription/include}/whisper-processing.h (86%) rename src/{whisper-utils => modules/transcription/include}/whisper-utils.h (83%) rename src/{whisper-utils => modules/transcription/src}/silero-vad-onnx.cpp (100%) rename src/{whisper-utils => modules/transcription/src}/token-buffer-thread.cpp (96%) rename src/{ => modules/transcription/src}/transcription-utils.cpp (100%) rename src/{whisper-utils => modules/transcription/src}/vad-processing.cpp (100%) rename src/{whisper-utils => modules/transcription/src}/whisper-model-utils.cpp (98%) rename src/{whisper-utils => modules/transcription/src}/whisper-processing.cpp (97%) rename src/{whisper-utils => modules/transcription/src}/whisper-utils.cpp (97%) create mode 100644 src/modules/translation/CMakeLists.txt rename src/{translation => modules/translation/include}/language_codes.h (100%) rename src/{translation => modules/translation/include}/translation-includes.h (100%) rename src/{translation => modules/translation/include}/translation-language-utils.h (100%) rename src/{translation => modules/translation/include}/translation-utils.h (67%) rename src/{translation => modules/translation/include}/translation.h (85%) rename src/{translation => modules/translation/src}/language_codes.cpp (100%) rename src/{translation => modules/translation/src}/translation-language-utils.cpp (100%) rename src/{translation => modules/translation/src}/translation-utils.cpp (95%) rename src/{translation => modules/translation/src}/translation.cpp (99%) diff --git a/.gitignore b/.gitignore index 081b432..8594f3a 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ # CMake build directory build/ +build_x64/ # CMake generated files CMakeCache.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index db26651..761b70c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,19 +1,29 @@ cmake_minimum_required(VERSION 3.12) -project(locaal) +project(LocaalSDK VERSION 1.0.0 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 
17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD 11) +# Option to build shared libraries +option(BUILD_SHARED_LIBS "Build shared libraries" OFF) +option(BUILD_EXAMPLES "Build examples" OFF) + +# Create the LocaalSDK target +add_library(${CMAKE_PROJECT_NAME} INTERFACE) + set(USE_SYSTEM_CURL OFF CACHE STRING "Use system cURL") if(USE_SYSTEM_CURL) find_package(CURL REQUIRED) - target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${CURL_LIBRARIES}") + target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE "${CURL_LIBRARIES}") target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${CURL_INCLUDE_DIRS}") else() include(cmake/BuildMyCurl.cmake) - target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE libcurl) + target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE libcurl) endif() if(WIN32) @@ -31,11 +41,11 @@ if(WIN32) endif() include(cmake/BuildWhispercpp.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp) +target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE Whispercpp) include(cmake/BuildCTranslate2.cmake) include(cmake/BuildSentencepiece.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ct2 sentencepiece) +target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ct2 sentencepiece) set(USE_SYSTEM_ONNXRUNTIME OFF @@ -46,7 +56,7 @@ set(DISABLE_ONNXRUNTIME_GPU CACHE STRING "Disables GPU support of ONNX Runtime (Only valid on Linux)") if(DISABLE_ONNXRUNTIME_GPU) - target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE DISABLE_ONNXRUNTIME_GPU) + target_compile_definitions(${CMAKE_PROJECT_NAME} INTERFACE DISABLE_ONNXRUNTIME_GPU) endif() if(USE_SYSTEM_ONNXRUNTIME) @@ -55,8 +65,8 @@ if(USE_SYSTEM_ONNXRUNTIME) set(Onnxruntime_INCLUDE_PATH ${Onnxruntime_INCLUDE_DIR} ${Onnxruntime_INCLUDE_DIR}/onnxruntime ${Onnxruntime_INCLUDE_DIR}/onnxruntime/core/session ${Onnxruntime_INCLUDE_DIR}/onnxruntime/core/providers/cpu) - target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIBRARIES}") - 
target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${Onnxruntime_INCLUDE_PATH}") + target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE "${Onnxruntime_LIBRARIES}") + target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM INTERFACE "${Onnxruntime_INCLUDE_PATH}") else() message(FATAL_ERROR "System ONNX Runtime is only supported on Linux!") endif() @@ -66,53 +76,66 @@ endif() include(cmake/BuildICU.cmake) # Add ICU to the target -target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ICU) -target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC ${ICU_INCLUDE_DIR}) - -# Add your source files here -set(SOURCES - src/locaal.cpp +target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ICU) +target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM INTERFACE ${ICU_INCLUDE_DIR}) + +include(cmake/FetchLibav.cmake) +target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE FFmpeg) + + +# List of all available modules +set(LOCAAL_MODULES + Core + Transcription + Translation + # OCR + # DocumentAnalysis + # SpeechSynthesis + # ImageSegmentation ) -# Add your header files here -set(HEADERS - include/locaal.h -) +# Function to add a module +function(add_locaal_module MODULE_NAME) + add_subdirectory(src/modules/${MODULE_NAME}) + list(APPEND LOCAAL_ENABLED_MODULES ${MODULE_NAME}) + set(LOCAAL_ENABLED_MODULES ${LOCAAL_ENABLED_MODULES} PARENT_SCOPE) +endfunction() + -# Create the shared library -add_library(${CMAKE_PROJECT_NAME} SHARED ${SOURCES} ${HEADERS}) +# Add requested modules +foreach(MODULE ${LocaalSDK_FIND_COMPONENTS}) + if(${MODULE} IN_LIST LOCAAL_MODULES) + add_locaal_module(${MODULE}) + else() + message(FATAL_ERROR "Unknown module: ${MODULE}") + endif() +endforeach() -# Set the include directories -target_include_directories(locaal PUBLIC include) +target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ${LOCAAL_ENABLED_MODULES}) -include(GNUInstallDirs) +# Generate and install package configuration files include(CMakePackageConfigHelpers) 
+write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/LocaalSDKConfigVersion.cmake" + VERSION ${PROJECT_VERSION} + COMPATIBILITY SameMajorVersion +) + +configure_package_config_file( + "${CMAKE_CURRENT_SOURCE_DIR}/cmake/LocaalSDKConfig.cmake.in" + "${CMAKE_CURRENT_BINARY_DIR}/LocaalSDKConfig.cmake" + INSTALL_DESTINATION lib/cmake/LocaalSDK +) -# Install the library -install(TARGETS ${CMAKE_PROJECT_NAME} - EXPORT ${CMAKE_PROJECT_NAME}Targets - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) - -# Install the headers -install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${CMAKE_PROJECT_NAME}) - -# Install the cmake config files -install(EXPORT ${CMAKE_PROJECT_NAME}Targets - FILE ${CMAKE_PROJECT_NAME}Targets.cmake - NAMESPACE ${CMAKE_PROJECT_NAME}:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${CMAKE_PROJECT_NAME}) - -configure_package_config_file(cmake/${CMAKE_PROJECT_NAME}Config.cmake.in - ${CMAKE_PROJECT_NAME}Config.cmake - INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${CMAKE_PROJECT_NAME}) - -write_basic_package_version_file(${CMAKE_PROJECT_NAME}ConfigVersion.cmake - VERSION 1.0 - COMPATIBILITY AnyNewerVersion) - -install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_PROJECT_NAME}Config.cmake - ${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_PROJECT_NAME}ConfigVersion.cmake - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${CMAKE_PROJECT_NAME}) +install(TARGETS ${CMAKE_PROJECT_NAME} EXPORT LocaalSDKTargets) +install(EXPORT LocaalSDKTargets + FILE LocaalSDKTargets.cmake + NAMESPACE LocaalSDK:: + DESTINATION lib/cmake/LocaalSDK +) + +install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/LocaalSDKConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/LocaalSDKConfigVersion.cmake" + DESTINATION lib/cmake/LocaalSDK +) diff --git a/README.md b/README.md index 6218573..03fb46e 100644 --- a/README.md +++ b/README.md @@ -1,86 +1,155 @@ -# 
Real-time Transcription and Translation Library +# Locaal SDK ## Overview -This C++ library provides real-time transcription and translation capabilities using Whisper.cpp and CTranslate2. It's designed to work on-device without relying on cloud services, making it suitable for applications requiring privacy and offline functionality. +Locaal SDK is a comprehensive, modular on-device AI toolkit designed to bring advanced AI capabilities to various platforms including desktop, mobile devices, and web applications. Our focus is on providing powerful, efficient, and privacy-preserving AI features that run locally on the device, eliminating the need for cloud-based processing. -Key features: -- Cross-platform support (macOS, Windows, Linux) -- Real-time speech-to-text transcription -- On-device translation -- Built with CMake for easy integration and compilation +## Key Features + +Locaal SDK offers a wide range of AI capabilities through its modular architecture: + +- **Transcription & Translation**: Real-time speech-to-text and language translation +- **Optical Character Recognition (OCR)**: Extract text from images and documents +- **Document Analysis**: Understand and extract information from structured documents +- **Speech Synthesis**: Convert text to natural-sounding speech +- **Image Segmentation**: Identify and separate different objects within images +- **Core Module**: Common utilities and shared functionalities + +Each feature is implemented as a separate module, allowing developers to include only the capabilities they need, optimizing for performance and resource usage. 
+ +## Supported Platforms + +- Desktop: Windows, macOS, Linux +- Mobile: iOS, Android +- Web: WebAssembly-compatible browsers ## Prerequisites -Before building the library, ensure you have the following installed: +Before building the SDK, ensure you have the following installed: - C++ compiler with C++17 support - CMake (version 3.12 or higher) - Git -## Building the Library +Additional dependencies may be required for specific modules. Refer to each module's documentation for details. -### macOS +## Building the SDK -1. Open Terminal and navigate to the project directory. -2. Run the following commands: +### Desktop (Windows, macOS, Linux) -```bash -mkdir build && cd build -cmake .. -make -``` +1. Clone the repository: + ``` + git clone https://github.com/your-repo/locaal-sdk.git + cd locaal-sdk + ``` + +2. Create a build directory: + ``` + mkdir build && cd build + ``` + +3. Configure and build: + ``` + cmake .. + cmake --build . --config Release + ``` + +### Mobile and Web Platforms + +For mobile and web platforms, please refer to the platform-specific build instructions in the `docs/` directory. + + +## Usage -### Windows +### Including Locaal SDK in Your CMake Project -1. Open Command Prompt or PowerShell and navigate to the project directory. -2. Run the following commands: +To use Locaal SDK in your CMake project, follow these steps: -```cmd -mkdir build -cd build -cmake .. -G "Visual Studio 16 2019" -A x64 -cmake --build . --config Release +1. First, make sure you have the Locaal SDK installed on your system or available as a subdirectory in your project. + +2. 
In your `CMakeLists.txt` file, add the following lines to find and link the Locaal SDK: + +```cmake +# Find the Locaal SDK package +find_package(LocaalSDK REQUIRED) + +# Create your executable or library +add_executable(your_app main.cpp) + +# Link against the Locaal SDK modules you need +target_link_libraries(your_app + PRIVATE + LocaalSDK::Core + LocaalSDK::Transcription + LocaalSDK::Translation + # Add other modules as needed +) ``` -Note: Adjust the Visual Studio version as needed. +If you're using Locaal SDK as a subdirectory in your project: -### Linux +```cmake +# Add the Locaal SDK subdirectory +add_subdirectory(path/to/locaal-sdk) -1. Open a terminal and navigate to the project directory. -2. Run the following commands: +# Create your executable or library +add_executable(your_app main.cpp) -```bash -mkdir build && cd build -cmake .. -make +# Link against the Locaal SDK modules you need +target_link_libraries(your_app + PRIVATE + LocaalSDK::Core + LocaalSDK::Transcription + LocaalSDK::Translation + # Add other modules as needed +) ``` -## Usage +### Code Example -After building the library, you can include it in your C++ project. Here's a basic example of how to use the library: +Here's a basic example of how to use the Locaal SDK in your C++ project: ```cpp -#include +#include +#include +#include int main() { - // Initialize the library - locaal::TranscriptionTranslation tt; - + locaal::Core core; + + // Initialize transcription module + locaal::Transcription transcription(core); + // Start real-time transcription - tt.startTranscription(); - + transcription.start(); + + // Initialize translation module + locaal::Translation translation(core); + // Translate text - std::string translated = tt.translate("Hello, world!", "en", "fr"); - + std::string translated = translation.translate("Hello, world!", "en", "fr"); + return 0; } ``` -For more detailed usage instructions and API documentation, please refer to the `docs` folder and the `examples` folder. 
+For more detailed usage instructions and API documentation for each module, please refer to the `docs/` folder. + + +## Modules + +- **Core**: Provides common utilities and shared functionalities used across other modules. +- **Transcription**: Enables real-time speech-to-text capabilities. +- **Translation**: Offers text translation between multiple languages. +- **OCR**: Extracts text from images and documents. +- **Document Analysis**: Analyzes and extracts information from structured documents. +- **Speech Synthesis**: Converts text to natural-sounding speech. +- **Image Segmentation**: Identifies and separates different objects within images. + +Each module can be used independently or in combination with others, depending on your application's needs. ## Contributing -Contributions are welcome! Please feel free to submit a Pull Request. +We welcome contributions to the Locaal SDK! Please feel free to submit issues, feature requests, or pull requests. ## License @@ -88,5 +157,14 @@ This project is licensed under the MIT License - see the LICENSE file for detail ## Acknowledgments -- [Whisper.cpp](https://github.com/ggerganov/whisper.cpp) -- [CTranslate2](https://github.com/OpenNMT/CTranslate2) +Locaal SDK leverages several amazing open-source projects, including: +- [Whisper.cpp](https://github.com/ggerganov/whisper.cpp) for transcription +- [CTranslate2](https://github.com/OpenNMT/CTranslate2) for translation +- [GGML](https://github.com/ggerganov/ggml) for on-device execution +- [OpenCV](https://opencv.org/) for image processing +- [onnxruntime](https://github.com/microsoft/onnxruntime) for on-device execution +- [cUrl](https://github.com/curl/curl) for networking +- [SDL](https://github.com/libsdl-org/SDL) for media processing and access +- [ICU](https://github.com/unicode-org/icu) for unicode text processing + +We are grateful to the developers and contributors of these projects. 
diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake index 0d60561..4eb42c2 100644 --- a/cmake/BuildCTranslate2.cmake +++ b/cmake/BuildCTranslate2.cmake @@ -27,6 +27,7 @@ elseif(WIN32) if(${ACCELERATION} STREQUAL "cpu" OR ${ACCELERATION} STREQUAL "hipblas") FetchContent_Declare( ctranslate2_fetch + DOWNLOAD_EXTRACT_TIMESTAMP URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cpu.zip URL_HASH SHA256=30ff8b2499b8d3b5a6c4d6f7f8ddbc89e745ff06e0050b645e3b7c9b369451a3) else() @@ -125,3 +126,7 @@ else() target_link_libraries(ct2 INTERFACE ct2::ct2 cpu_features) endif() + +# add exported target install +install(TARGETS ct2 EXPORT ct2Targets) +install(EXPORT ct2Targets NAMESPACE ct2:: DESTINATION "lib/cmake/ct2") diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake index a3c575d..954939c 100644 --- a/cmake/BuildICU.cmake +++ b/cmake/BuildICU.cmake @@ -14,6 +14,7 @@ if(WIN32) FetchContent_Declare( ICU_build + DOWNLOAD_EXTRACT_TIMESTAMP URL ${ICU_URL} URL_HASH ${ICU_HASH}) @@ -99,3 +100,7 @@ foreach(lib ${ICU_LIBRARIES}) target_link_libraries(ICU INTERFACE ICU::${lib}) endforeach() target_include_directories(ICU SYSTEM INTERFACE $) + +# add exported target +install(TARGETS ICU EXPORT ICUTargets) +install(EXPORT ICUTargets NAMESPACE ICU:: DESTINATION "lib/cmake/ICU") diff --git a/cmake/BuildMyCurl.cmake b/cmake/BuildMyCurl.cmake index 10d3e05..7c38ea1 100644 --- a/cmake/BuildMyCurl.cmake +++ b/cmake/BuildMyCurl.cmake @@ -37,6 +37,7 @@ endif() FetchContent_Declare( libcurl_fetch + DOWNLOAD_EXTRACT_TIMESTAMP URL ${LibCurl_URL} URL_HASH ${LibCurl_HASH}) FetchContent_MakeAvailable(libcurl_fetch) diff --git a/cmake/BuildSentencepiece.cmake b/cmake/BuildSentencepiece.cmake index 024283e..cb0db14 100644 --- a/cmake/BuildSentencepiece.cmake +++ b/cmake/BuildSentencepiece.cmake @@ -17,6 +17,7 @@ elseif(WIN32) FetchContent_Declare( sentencepiece_fetch + DOWNLOAD_EXTRACT_TIMESTAMP URL 
https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/sentencepiece-windows-0.2.0-Release.zip URL_HASH SHA256=846699c7fa1e8918b71ed7f2bd5cd60e47e51105e1d84e3192919b4f0f10fdeb) FetchContent_MakeAvailable(sentencepiece_fetch) @@ -59,3 +60,8 @@ else() target_include_directories(sentencepiece INTERFACE ${INSTALL_DIR}/include) endif() + +# add exported target install +install(TARGETS sentencepiece EXPORT sentencepiece-targets) +install(EXPORT sentencepiece-targets NAMESPACE sentencepiece:: DESTINATION lib/cmake/sentencepiece) +install(FILES ${sentencepiece_fetch_SOURCE_DIR}/lib/libsentencepiece.a DESTINATION "bin") diff --git a/cmake/BuildWhispercpp.cmake b/cmake/BuildWhispercpp.cmake index 66e0f0b..a936d49 100644 --- a/cmake/BuildWhispercpp.cmake +++ b/cmake/BuildWhispercpp.cmake @@ -148,3 +148,13 @@ if(APPLE) target_link_libraries(Whispercpp INTERFACE "-framework Accelerate -framework CoreML -framework Metal") target_link_libraries(Whispercpp INTERFACE Whispercpp::GGML Whispercpp::CoreML) endif(APPLE) + +# add exported target install +install(TARGETS Whispercpp EXPORT WhispercppTargets) +install(EXPORT WhispercppTargets NAMESPACE Whispercpp:: DESTINATION lib/cmake/Whispercpp) +install(FILES ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}whisper${CMAKE_STATIC_LIBRARY_SUFFIX} + DESTINATION "bin") +install(FILES ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}ggml${CMAKE_STATIC_LIBRARY_SUFFIX} + DESTINATION "bin") +install(FILES ${whispercpp_fetch_SOURCE_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}whisper.coreml${CMAKE_STATIC_LIBRARY_SUFFIX} + DESTINATION "bin") diff --git a/cmake/FetchLibav.cmake b/cmake/FetchLibav.cmake new file mode 100644 index 0000000..b0fae7a --- /dev/null +++ b/cmake/FetchLibav.cmake @@ -0,0 +1,80 @@ +include(FetchContent) + +if(WIN32) + include(FetchContent) + + set(FFMPEG_URL "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-full-shared.7z") + set(FFMPEG_HASH 
"SHA256=a0b6e8c7978b95d019a93dcf4b4ab74b17d9e53e0a87cfd463e1376c5927e30b") + + FetchContent_Declare( + FFmpeg_fetch + DOWNLOAD_EXTRACT_TIMESTAMP + URL ${FFMPEG_URL} + URL_HASH ${FFMPEG_HASH} + ) + + FetchContent_MakeAvailable(FFmpeg_fetch) + + set(FFMPEG_ROOT ${ffmpeg_fetch_SOURCE_DIR}) + + find_path(FFMPEG_INCLUDE_DIR libavcodec/avcodec.h + PATHS ${FFMPEG_ROOT}/include + NO_DEFAULT_PATH + ) + + find_library(AVCODEC_LIBRARY avcodec + PATHS ${FFMPEG_ROOT}/lib + NO_DEFAULT_PATH + ) + find_library(AVFORMAT_LIBRARY avformat + PATHS ${FFMPEG_ROOT}/lib + NO_DEFAULT_PATH + ) + find_library(AVUTIL_LIBRARY avutil + PATHS ${FFMPEG_ROOT}/lib + NO_DEFAULT_PATH + ) + find_library(SWRESAMPLE_LIBRARY swresample + PATHS ${FFMPEG_ROOT}/lib + NO_DEFAULT_PATH + ) + + set(FFMPEG_LIBRARIES + ${AVCODEC_LIBRARY} + ${AVFORMAT_LIBRARY} + ${AVUTIL_LIBRARY} + ${SWRESAMPLE_LIBRARY} + ) +else() + # For Linux and macOS, use pkg-config + find_package(PkgConfig REQUIRED) + pkg_check_modules(FFMPEG REQUIRED IMPORTED_TARGET + libavcodec + libavformat + libavutil + libswresample + ) + set(FFMPEG_INCLUDE_DIR ${FFMPEG_INCLUDE_DIRS}) + set(FFMPEG_LIBRARIES PkgConfig::FFMPEG) +endif() + +if(WIN32) + # Add FFmpeg bin directory to PATH for runtime + set_property(TARGET ${PROJECT_NAME} PROPERTY VS_DEBUGGER_ENVIRONMENT "PATH=${FFMPEG_ROOT}/bin;$ENV{PATH}") + + # Copy DLLs to output directory using install + install(DIRECTORY ${FFMPEG_ROOT}/bin/ DESTINATION bin) +endif() + +# Create FFmpeg interface library +add_library(FFmpeg INTERFACE) +target_include_directories(FFmpeg INTERFACE ${FFMPEG_INCLUDE_DIR}) +target_link_libraries(FFmpeg INTERFACE ${FFMPEG_LIBRARIES}) + +# add exported target +install(TARGETS FFmpeg EXPORT FFmpegTargets) +install(EXPORT FFmpegTargets + FILE FFmpegTargets.cmake + NAMESPACE FFmpeg:: + DESTINATION lib/cmake/FFmpeg +) diff --git a/cmake/FetchOnnxruntime.cmake b/cmake/FetchOnnxruntime.cmake index 0ed2975..8cf3908 100644 --- a/cmake/FetchOnnxruntime.cmake +++ 
b/cmake/FetchOnnxruntime.cmake @@ -8,7 +8,7 @@ set(CUSTOM_ONNXRUNTIME_HASH "" CACHE STRING "Hash of a downloaded ONNX Runtime tarball") -set(Onnxruntime_VERSION "1.17.1") +set(Onnxruntime_VERSION "1.19.2") if(CUSTOM_ONNXRUNTIME_URL STREQUAL "") set(USE_PREDEFINED_ONNXRUNTIME ON) @@ -25,17 +25,17 @@ if(USE_PREDEFINED_ONNXRUNTIME) if(APPLE) set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-osx-universal2-${Onnxruntime_VERSION}.tgz") - set(Onnxruntime_HASH SHA256=9FA57FA6F202A373599377EF75064AE568FDA8DA838632B26A86024C7378D306) + set(Onnxruntime_HASH SHA256=b0289ddbc32f76e5d385abc7b74cc7c2c51cdf2285b7d118bf9d71206e5aee3a) elseif(MSVC) set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-win-x64-${Onnxruntime_VERSION}.zip") - set(OOnnxruntime_HASH SHA256=4802AF9598DB02153D7DA39432A48823FF69B2FB4B59155461937F20782AA91C) + set(Onnxruntime_HASH SHA256=dc4f841e511977c0a4f02e5066c3d9a58427644010ab4f89b918614a1cd4c2b0) else() if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-linux-aarch64-${Onnxruntime_VERSION}.tgz") - set(Onnxruntime_HASH SHA256=70B6F536BB7AB5961D128E9DBD192368AC1513BFFB74FE92F97AAC342FBD0AC1) + set(Onnxruntime_HASH SHA256=dc4f841e511977c0a4f02e5066c3d9a58427644010ab4f89b918614a1cd4c2b0) else() set(Onnxruntime_URL "${Onnxruntime_BASEURL}/onnxruntime-linux-x64-gpu-${Onnxruntime_VERSION}.tgz") - set(Onnxruntime_HASH SHA256=613C53745EA4960ED368F6B3AB673558BB8561C84A8FA781B4EA7FB4A4340BE4) + set(Onnxruntime_HASH SHA256=4d1c10f0b410b67261302c6e18bb1b05ba924ca9081e3a26959e0d12ab69f534) endif() endif() else() @@ -45,6 +45,7 @@ endif() FetchContent_Declare( onnxruntime + DOWNLOAD_EXTRACT_TIMESTAMP URL ${Onnxruntime_URL} URL_HASH ${Onnxruntime_HASH}) FetchContent_MakeAvailable(onnxruntime) @@ -79,7 +80,12 @@ elseif(MSVC) install(FILES ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll DESTINATION "obs-plugins/64bit") endforeach() - target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Ort) +
target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE Ort) + + # add exported target install + install(TARGETS Ort EXPORT OrtTargets) + install(EXPORT OrtTargets NAMESPACE Ort:: DESTINATION "lib/cmake/Ort") + install(FILES ${onnxruntime_SOURCE_DIR}/lib/onnxruntime.dll DESTINATION "bin") else() if(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") diff --git a/cmake/LocaalSDKConfig.cmake.in b/cmake/LocaalSDKConfig.cmake.in new file mode 100644 index 0000000..fe3b990 --- /dev/null +++ b/cmake/LocaalSDKConfig.cmake.in @@ -0,0 +1,23 @@ +@PACKAGE_INIT@ + +include("${CMAKE_CURRENT_LIST_DIR}/LocaalSDKTargets.cmake") + +# Check and include enabled components +set(_supported_components Core Transcription Translation) + +foreach(_comp ${LocaalSDK_FIND_COMPONENTS}) + if (NOT _comp IN_LIST _supported_components) + set(LocaalSDK_FOUND False) + set(LocaalSDK_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") + endif() + include("${CMAKE_CURRENT_LIST_DIR}/LocaalSDK${_comp}Config.cmake" OPTIONAL) + if(TARGET LocaalSDK::${_comp}) + set(LocaalSDK_${_comp}_FOUND TRUE) + else() + set(LocaalSDK_${_comp}_FOUND FALSE) + set(LocaalSDK_FOUND False) + set(LocaalSDK_NOT_FOUND_MESSAGE "LocaalSDK component not found: ${_comp}") + endif() +endforeach() + +check_required_components(LocaalSDK) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e69de29..73d443e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -0,0 +1,5 @@ + +find_package(LocaalSDK REQUIRED COMPONENTS Core Transcription Translation) + +add_executable(RealtimeTranscription realtime_transcription.cpp) +target_link_libraries(RealtimeTranscription PRIVATE LocaalSDK::Core LocaalSDK::Transcription LocaalSDK::Translation) diff --git a/src/model-utils/model-downloader-ui.cpp b/src/model-utils/model-downloader-ui.cpp deleted file mode 100644 index a428e20..0000000 --- a/src/model-utils/model-downloader-ui.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include "model-downloader-ui.h" -#include "plugin-support.h" -#include -
-#include - -size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) -{ - size_t written = fwrite(ptr, size, nmemb, stream); - return written; -} - -ModelDownloader::ModelDownloader(const ModelInfo &model_info, - download_finished_callback_t download_finished_callback_, - QWidget *parent) - : QDialog(parent), - download_finished_callback(download_finished_callback_) -{ - this->setWindowTitle("LocalVocal: Downloading model..."); - this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint); - this->setFixedSize(300, 100); - // Bring the dialog to the front - this->activateWindow(); - this->raise(); - - this->layout = new QVBoxLayout(this); - - // Add a label for the model name - QLabel *model_name_label = new QLabel(this); - model_name_label->setText(QString::fromStdString(model_info.friendly_name)); - model_name_label->setAlignment(Qt::AlignCenter); - this->layout->addWidget(model_name_label); - - this->progress_bar = new QProgressBar(this); - this->progress_bar->setRange(0, 100); - this->progress_bar->setValue(0); - this->progress_bar->setAlignment(Qt::AlignCenter); - // Show progress as a percentage - this->progress_bar->setFormat("%p%"); - this->layout->addWidget(this->progress_bar); - - this->download_thread = new QThread(); - this->download_worker = new ModelDownloadWorker(model_info); - this->download_worker->moveToThread(this->download_thread); - - connect(this->download_thread, &QThread::started, this->download_worker, - &ModelDownloadWorker::download_model); - connect(this->download_worker, &ModelDownloadWorker::download_progress, this, - &ModelDownloader::update_progress); - connect(this->download_worker, &ModelDownloadWorker::download_finished, this, - &ModelDownloader::download_finished); - connect(this->download_worker, &ModelDownloadWorker::download_finished, - this->download_thread, &QThread::quit); - connect(this->download_worker, &ModelDownloadWorker::download_finished, - this->download_worker, 
&ModelDownloadWorker::deleteLater); - connect(this->download_worker, &ModelDownloadWorker::download_error, this, - &ModelDownloader::show_error); - connect(this->download_thread, &QThread::finished, this->download_thread, - &QThread::deleteLater); - - this->download_thread->start(); -} - -void ModelDownloader::closeEvent(QCloseEvent *e) -{ - if (!this->mPrepareToClose) - e->ignore(); - else { - QDialog::closeEvent(e); - deleteLater(); - } -} - -void ModelDownloader::close() -{ - this->mPrepareToClose = true; - - QDialog::close(); -} - -void ModelDownloader::update_progress(int progress) -{ - this->progress_bar->setValue(progress); -} - -void ModelDownloader::download_finished(const std::string &path) -{ - // Call the callback with the path to the downloaded model - this->download_finished_callback(0, path); - // Close the dialog - this->close(); -} - -void ModelDownloader::show_error(const std::string &reason) -{ - this->setWindowTitle("Download failed!"); - this->progress_bar->setFormat("Download failed!"); - this->progress_bar->setAlignment(Qt::AlignCenter); - this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }"); - // Add a label to show the error - QLabel *error_label = new QLabel(this); - error_label->setText(QString::fromStdString(reason)); - error_label->setAlignment(Qt::AlignCenter); - // Color red - error_label->setStyleSheet("QLabel { color : red; }"); - this->layout->addWidget(error_label); - // Add a button to close the dialog - QPushButton *close_button = new QPushButton("Close", this); - this->layout->addWidget(close_button); - connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); - this->download_finished_callback(1, ""); -} - -ModelDownloadWorker::ModelDownloadWorker(const ModelInfo &model_info_) : model_info(model_info_) {} - -std::string get_filename_from_url(const std::string &url) -{ - auto lastSlashPos = url.find_last_of("/"); - auto queryPos = url.find("?", lastSlashPos); - if (queryPos 
== std::string::npos) { - return url.substr(lastSlashPos + 1); - } else { - return url.substr(lastSlashPos + 1, queryPos - lastSlashPos - 1); - } -} - -void ModelDownloadWorker::download_model() -{ - char *config_folder = obs_module_config_path("models"); -#ifdef _WIN32 - // convert mbstring to wstring - int count = MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), NULL, 0); - std::wstring config_folder_str(count, 0); - MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), &config_folder_str[0], - count); - obs_log(LOG_INFO, "Download: Config models folder: %S", config_folder_str.c_str()); -#else - std::string config_folder_str = config_folder; - obs_log(LOG_INFO, "Download: Config models folder: %s", config_folder_str.c_str()); -#endif - bfree(config_folder); - - const std::filesystem::path module_config_models_folder = - std::filesystem::absolute(config_folder_str); - - // Check if the config folder exists - if (!std::filesystem::exists(module_config_models_folder)) { - obs_log(LOG_WARNING, "Config folder does not exist: %s", - module_config_models_folder.string().c_str()); - // Create the config folder - if (!std::filesystem::create_directories(module_config_models_folder)) { - obs_log(LOG_ERROR, "Failed to create config folder: %s", - module_config_models_folder.string().c_str()); - emit download_error("Failed to create config folder."); - return; - } - } - - const std::string model_local_config_path = - (module_config_models_folder / model_info.local_folder_name).string(); - - obs_log(LOG_INFO, "Model save path: %s", model_local_config_path.c_str()); - - if (!std::filesystem::exists(model_local_config_path)) { - // model folder does not exist, create it - if (!std::filesystem::create_directories(model_local_config_path)) { - obs_log(LOG_ERROR, "Failed to create model folder: %s", - model_local_config_path.c_str()); - emit download_error("Failed to create model folder."); - return; - } - } - - CURL *curl = curl_easy_init(); - 
if (curl) { - for (auto &model_download_file : this->model_info.files) { - obs_log(LOG_INFO, "Model URL: %s", model_download_file.url.c_str()); - - const std::string model_filename = - get_filename_from_url(model_download_file.url); - const std::string model_file_save_path = - (std::filesystem::path(model_local_config_path) / model_filename) - .string(); - if (std::filesystem::exists(model_file_save_path)) { - obs_log(LOG_INFO, "Model file already exists: %s", - model_file_save_path.c_str()); - continue; - } - - FILE *fp = fopen(model_file_save_path.c_str(), "wb"); - if (fp == nullptr) { - obs_log(LOG_ERROR, "Failed to open model file for writing %s.", - model_file_save_path.c_str()); - emit download_error("Failed to open file."); - return; - } - curl_easy_setopt(curl, CURLOPT_URL, model_download_file.url.c_str()); - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp); - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, - ModelDownloadWorker::progress_callback); - curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this); - // Follow redirects - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - CURLcode res = curl_easy_perform(curl); - if (res != CURLE_OK) { - obs_log(LOG_ERROR, "Failed to download model file %s.", - model_filename.c_str()); - emit download_error("Failed to download model file."); - } - fclose(fp); - } - curl_easy_cleanup(curl); - emit download_finished(model_local_config_path); - } else { - obs_log(LOG_ERROR, "Failed to initialize curl."); - emit download_error("Failed to initialize curl."); - } -} - -int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, - curl_off_t, curl_off_t) -{ - if (dltotal == 0) { - return 0; // Unknown progress - } - ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp; - if (worker == nullptr) { - obs_log(LOG_ERROR, "Worker is null."); - return 1; - } - int progress = 
(int)(dlnow * 100l / dltotal); - emit worker->download_progress(progress); - return 0; -} - -ModelDownloader::~ModelDownloader() -{ - if (this->download_thread != nullptr) { - if (this->download_thread->isRunning()) { - this->download_thread->quit(); - this->download_thread->wait(); - } - delete this->download_thread; - } - if (this->download_worker != nullptr) { - delete this->download_worker; - } -} - -ModelDownloadWorker::~ModelDownloadWorker() -{ - // Do nothing -} diff --git a/src/model-utils/model-downloader-ui.h b/src/model-utils/model-downloader-ui.h deleted file mode 100644 index aaa0752..0000000 --- a/src/model-utils/model-downloader-ui.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef MODEL_DOWNLOADER_UI_H -#define MODEL_DOWNLOADER_UI_H - -#include -#include - -#include -#include - -#include - -#include "model-downloader-types.h" - -class ModelDownloadWorker : public QObject { - Q_OBJECT -public: - ModelDownloadWorker(const ModelInfo &model_info_); - ~ModelDownloadWorker(); - -public slots: - void download_model(); - -signals: - void download_progress(int progress); - void download_finished(const std::string &path); - void download_error(const std::string &reason); - -private: - static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, - curl_off_t ultotal, curl_off_t ulnow); - ModelInfo model_info; -}; - -class ModelDownloader : public QDialog { - Q_OBJECT -public: - ModelDownloader(const ModelInfo &model_info, - download_finished_callback_t download_finished_callback, - QWidget *parent = nullptr); - ~ModelDownloader(); - -public slots: - void update_progress(int progress); - void download_finished(const std::string &path); - void show_error(const std::string &reason); - -protected: - void closeEvent(QCloseEvent *e) override; - -private: - QVBoxLayout *layout; - QProgressBar *progress_bar; - QPointer download_thread; - QPointer download_worker; - // Callback for when the download is finished - download_finished_callback_t 
download_finished_callback; - bool mPrepareToClose; - void close(); -}; - -#endif // MODEL_DOWNLOADER_UI_H diff --git a/src/modules/core/CMakeLists.txt b/src/modules/core/CMakeLists.txt new file mode 100644 index 0000000..13c0e26 --- /dev/null +++ b/src/modules/core/CMakeLists.txt @@ -0,0 +1,34 @@ +add_library(Core + src/model-downloader.cpp + src/model-infos.cpp + src/model-find-utils.cpp +) + +target_include_directories(Core + PUBLIC + $ + $ +) + +# If you have any dependencies for the Core module, link them here +# target_link_libraries(Core PUBLIC SomeDependency) + +set_target_properties(Core PROPERTIES + OUTPUT_NAME locaal_core + EXPORT_NAME Core +) + +# Install the target and create export +install(TARGETS Core + EXPORT LocaalSDKTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include +) + +# Install public headers +install(DIRECTORY include/locaal + DESTINATION include + FILES_MATCHING PATTERN "*.h" +) diff --git a/src/model-utils/model-downloader-types.h b/src/modules/core/include/model-downloader-types.h similarity index 100% rename from src/model-utils/model-downloader-types.h rename to src/modules/core/include/model-downloader-types.h diff --git a/src/model-utils/model-downloader.h b/src/modules/core/include/model-downloader.h similarity index 100% rename from src/model-utils/model-downloader.h rename to src/modules/core/include/model-downloader.h diff --git a/src/model-utils/model-find-utils.h b/src/modules/core/include/model-find-utils.h similarity index 100% rename from src/model-utils/model-find-utils.h rename to src/modules/core/include/model-find-utils.h diff --git a/src/model-utils/model-downloader.cpp b/src/modules/core/src/model-downloader.cpp similarity index 100% rename from src/model-utils/model-downloader.cpp rename to src/modules/core/src/model-downloader.cpp diff --git a/src/model-utils/model-find-utils.cpp b/src/modules/core/src/model-find-utils.cpp similarity index 100% rename 
from src/model-utils/model-find-utils.cpp rename to src/modules/core/src/model-find-utils.cpp diff --git a/src/model-utils/model-infos.cpp b/src/modules/core/src/model-infos.cpp similarity index 100% rename from src/model-utils/model-infos.cpp rename to src/modules/core/src/model-infos.cpp diff --git a/src/modules/transcription/CMakeLists.txt b/src/modules/transcription/CMakeLists.txt new file mode 100644 index 0000000..9410f4a --- /dev/null +++ b/src/modules/transcription/CMakeLists.txt @@ -0,0 +1,38 @@ +add_library(Transcription + src/silero-vad-onnx.cpp + src/token-buffer-thread.cpp + src/transcription-utils.cpp + src/vad-processing.cpp + src/whisper-model-utils.cpp + src/whisper-processing.cpp + src/whisper-utils.cpp +) + +target_include_directories(Transcription + PUBLIC + $ + $ +) + +# If you have any dependencies for the Core module, link them here +# target_link_libraries(Core PUBLIC SomeDependency) + +set_target_properties(Transcription PROPERTIES + OUTPUT_NAME locaal_transcription + EXPORT_NAME Transcription +) + +# Install the target and create export +install(TARGETS Transcription + EXPORT LocaalSDKTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include +) + +# Install public headers +install(DIRECTORY include/locaal + DESTINATION include + FILES_MATCHING PATTERN "*.h" +) diff --git a/src/whisper-utils/silero-vad-onnx.h b/src/modules/transcription/include/silero-vad-onnx.h similarity index 100% rename from src/whisper-utils/silero-vad-onnx.h rename to src/modules/transcription/include/silero-vad-onnx.h diff --git a/src/whisper-utils/token-buffer-thread.h b/src/modules/transcription/include/token-buffer-thread.h similarity index 95% rename from src/whisper-utils/token-buffer-thread.h rename to src/modules/transcription/include/token-buffer-thread.h index 7666669..96e2f75 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/modules/transcription/include/token-buffer-thread.h @@ -22,7 
+22,7 @@ typedef std::string TokenBufferString; typedef char TokenBufferChar; #endif -struct transcription_filter_data; +struct transcription_context; enum TokenBufferSegmentation { SEGMENTATION_WORD = 0, SEGMENTATION_TOKEN, SEGMENTATION_SENTENCE }; enum TokenBufferSpeed { SPEED_SLOW = 0, SPEED_NORMAL, SPEED_FAST }; @@ -51,7 +51,7 @@ class TokenBufferThread { TokenBufferThread() noexcept; ~TokenBufferThread(); - void initialize(struct transcription_filter_data *gf, + void initialize(struct transcription_context *gf, std::function captionPresentationCallback_, std::function sentenceOutputCallback_, size_t numSentences_, size_t numTokensPerSentence_, @@ -78,7 +78,7 @@ class TokenBufferThread { void monitor(); void log_token_vector(const std::vector &tokens); int getWaitTime(TokenBufferSpeed speed) const; - struct transcription_filter_data *gf; + struct transcription_context *gf; std::deque inputQueue; std::deque presentationQueue; std::deque contributionQueue; diff --git a/src/transcription-filter-data.h b/src/modules/transcription/include/transcription-context.h similarity index 92% rename from src/transcription-filter-data.h rename to src/modules/transcription/include/transcription-context.h index 205bbf0..060ae9d 100644 --- a/src/transcription-filter-data.h +++ b/src/modules/transcription/include/transcription-context.h @@ -18,7 +18,7 @@ #define MAX_PREPROC_CHANNELS 10 -struct transcription_filter_data { +struct transcription_context { obs_source_t *context; // obs filter source (this filter) size_t channels; // number of channels uint32_t sample_rate; // input sample rate @@ -125,7 +125,7 @@ struct transcription_filter_data { TokenBufferSegmentation::SEGMENTATION_TOKEN; // ctor - transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() + transcription_context() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() { // initialize all pointers to nullptr for (size_t i = 0; i < MAX_PREPROC_CHANNELS; i++) { @@ -147,12 
+147,12 @@ struct transcription_filter_audio_info { }; // Callback sent when the transcription has a new result -void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &str); -void clear_current_caption(transcription_filter_data *gf_); +void set_text_callback(struct transcription_context *gf, const DetectionResultWithText &str); +void clear_current_caption(transcription_context *gf_); // Callback sent when the VAD finds an audio chunk. Sample rate = WHISPER_SAMPLE_RATE, channels = 1 // The audio chunk is in 32-bit float format -void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data, +void audio_chunk_callback(struct transcription_context *gf, const float *pcm32f_data, size_t frames, int vad_state, const DetectionResultWithText &result); #endif /* TRANSCRIPTION_FILTER_DATA_H */ diff --git a/src/transcription-utils.h b/src/modules/transcription/include/transcription-utils.h similarity index 100% rename from src/transcription-utils.h rename to src/modules/transcription/include/transcription-utils.h diff --git a/src/whisper-utils/vad-processing.h b/src/modules/transcription/include/vad-processing.h similarity index 57% rename from src/whisper-utils/vad-processing.h rename to src/modules/transcription/include/vad-processing.h index 996002b..e8b4e52 100644 --- a/src/whisper-utils/vad-processing.h +++ b/src/modules/transcription/include/vad-processing.h @@ -11,8 +11,8 @@ struct vad_state { uint64_t last_partial_segment_end_ts; }; -vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state); -vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state); -void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file); +vad_state vad_based_segmentation(transcription_context *gf, vad_state last_vad_state); +vad_state hybrid_vad_segmentation(transcription_context *gf, vad_state last_vad_state); +void 
initialize_vad(transcription_context *gf, const char *silero_vad_model_file); #endif // VAD_PROCESSING_H diff --git a/src/whisper-utils/whisper-language.h b/src/modules/transcription/include/whisper-language.h similarity index 100% rename from src/whisper-utils/whisper-language.h rename to src/modules/transcription/include/whisper-language.h diff --git a/src/whisper-utils/whisper-model-utils.h b/src/modules/transcription/include/whisper-model-utils.h similarity index 70% rename from src/whisper-utils/whisper-model-utils.h rename to src/modules/transcription/include/whisper-model-utils.h index 68c649c..ae8487e 100644 --- a/src/whisper-utils/whisper-model-utils.h +++ b/src/modules/transcription/include/whisper-model-utils.h @@ -5,6 +5,6 @@ #include "transcription-filter-data.h" -void update_whisper_model(struct transcription_filter_data *gf); +void update_whisper_model(struct transcription_context *gf); #endif // WHISPER_MODEL_UTILS_H diff --git a/src/whisper-utils/whisper-processing.h b/src/modules/transcription/include/whisper-processing.h similarity index 86% rename from src/whisper-utils/whisper-processing.h rename to src/modules/transcription/include/whisper-processing.h index a00f7cb..ef645d7 100644 --- a/src/whisper-utils/whisper-processing.h +++ b/src/modules/transcription/include/whisper-processing.h @@ -31,8 +31,8 @@ struct DetectionResultWithText { void whisper_loop(void *data); struct whisper_context *init_whisper_context(const std::string &model_path, - struct transcription_filter_data *gf); -void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms, + struct transcription_context *gf); +void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offset_ms, uint64_t end_offset_ms, int vad_state); #endif // WHISPER_PROCESSING_H diff --git a/src/whisper-utils/whisper-utils.h b/src/modules/transcription/include/whisper-utils.h similarity index 83% rename from src/whisper-utils/whisper-utils.h rename to 
src/modules/transcription/include/whisper-utils.h index c62168b..225c0d3 100644 --- a/src/whisper-utils/whisper-utils.h +++ b/src/modules/transcription/include/whisper-utils.h @@ -5,8 +5,8 @@ #include -void shutdown_whisper_thread(struct transcription_filter_data *gf); -void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path, +void shutdown_whisper_thread(struct transcription_context *gf); +void start_whisper_thread_with_path(struct transcription_context *gf, const std::string &path, const char *silero_vad_model_file); std::pair findStartOfOverlap(const std::vector &seq1, diff --git a/src/whisper-utils/silero-vad-onnx.cpp b/src/modules/transcription/src/silero-vad-onnx.cpp similarity index 100% rename from src/whisper-utils/silero-vad-onnx.cpp rename to src/modules/transcription/src/silero-vad-onnx.cpp diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/modules/transcription/src/token-buffer-thread.cpp similarity index 96% rename from src/whisper-utils/token-buffer-thread.cpp rename to src/modules/transcription/src/token-buffer-thread.cpp index 3e3b002..dfe6bcc 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/modules/transcription/src/token-buffer-thread.cpp @@ -38,7 +38,7 @@ TokenBufferThread::~TokenBufferThread() } void TokenBufferThread::initialize( - struct transcription_filter_data *gf_, + struct transcription_context *gf_, std::function captionPresentationCallback_, std::function sentenceOutputCallback_, size_t numSentences_, size_t numPerSentence_, std::chrono::seconds maxTime_, diff --git a/src/transcription-utils.cpp b/src/modules/transcription/src/transcription-utils.cpp similarity index 100% rename from src/transcription-utils.cpp rename to src/modules/transcription/src/transcription-utils.cpp diff --git a/src/whisper-utils/vad-processing.cpp b/src/modules/transcription/src/vad-processing.cpp similarity index 100% rename from src/whisper-utils/vad-processing.cpp rename to 
src/modules/transcription/src/vad-processing.cpp diff --git a/src/whisper-utils/whisper-model-utils.cpp b/src/modules/transcription/src/whisper-model-utils.cpp similarity index 98% rename from src/whisper-utils/whisper-model-utils.cpp rename to src/modules/transcription/src/whisper-model-utils.cpp index 8985a30..fde3590 100644 --- a/src/whisper-utils/whisper-model-utils.cpp +++ b/src/modules/transcription/src/whisper-model-utils.cpp @@ -9,7 +9,7 @@ #include "plugin-support.h" #include "model-utils/model-downloader.h" -void update_whisper_model(struct transcription_filter_data *gf) +void update_whisper_model(struct transcription_context *gf) { if (gf->context == nullptr) { obs_log(LOG_ERROR, "obs_source_t context is null"); diff --git a/src/whisper-utils/whisper-processing.cpp b/src/modules/transcription/src/whisper-processing.cpp similarity index 97% rename from src/whisper-utils/whisper-processing.cpp rename to src/modules/transcription/src/whisper-processing.cpp index 3518edf..c02649a 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/modules/transcription/src/whisper-processing.cpp @@ -24,7 +24,7 @@ #include struct whisper_context *init_whisper_context(const std::string &model_path_in, - struct transcription_filter_data *gf) + struct transcription_context *gf) { std::string model_path = model_path_in; @@ -46,8 +46,8 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, whisper_log_set( [](enum ggml_log_level level, const char *text, void *user_data) { UNUSED_PARAMETER(level); - struct transcription_filter_data *ctx = - static_cast(user_data); + struct transcription_context *ctx = + static_cast(user_data); // remove trailing newline char *text_copy = bstrdup(text); text_copy[strcspn(text_copy, "\n")] = 0; @@ -124,7 +124,7 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, return ctx; } -struct DetectionResultWithText run_whisper_inference(struct transcription_filter_data *gf, +struct 
DetectionResultWithText run_whisper_inference(struct transcription_context *gf, const float *pcm32f_data_, size_t pcm32f_num_samples, uint64_t t0 = 0, uint64_t t1 = 0, @@ -310,7 +310,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter language}; } -void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms, +void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offset_ms, uint64_t end_offset_ms, int vad_state) { // get the data from the entire whisper buffer @@ -350,8 +350,8 @@ void whisper_loop(void *data) return; } - struct transcription_filter_data *gf = - static_cast(data); + struct transcription_context *gf = + static_cast(data); obs_log(gf->log_level, "Starting whisper thread"); diff --git a/src/whisper-utils/whisper-utils.cpp b/src/modules/transcription/src/whisper-utils.cpp similarity index 97% rename from src/whisper-utils/whisper-utils.cpp rename to src/modules/transcription/src/whisper-utils.cpp index 84f3b0a..069e0ac 100644 --- a/src/whisper-utils/whisper-utils.cpp +++ b/src/modules/transcription/src/whisper-utils.cpp @@ -6,7 +6,7 @@ #include -void shutdown_whisper_thread(struct transcription_filter_data *gf) +void shutdown_whisper_thread(struct transcription_context *gf) { obs_log(gf->log_level, "shutdown_whisper_thread"); if (gf->whisper_context != nullptr) { @@ -24,7 +24,7 @@ void shutdown_whisper_thread(struct transcription_filter_data *gf) } } -void start_whisper_thread_with_path(struct transcription_filter_data *gf, +void start_whisper_thread_with_path(struct transcription_context *gf, const std::string &whisper_model_path, const char *silero_vad_model_file) { diff --git a/src/modules/translation/CMakeLists.txt b/src/modules/translation/CMakeLists.txt new file mode 100644 index 0000000..e30dbcf --- /dev/null +++ b/src/modules/translation/CMakeLists.txt @@ -0,0 +1,35 @@ +add_library(Translation + src/language_codes.cpp + src/translation-language-utils.cpp + 
src/translation-utils.cpp + src/translation.cpp +) + +target_include_directories(Translation + PUBLIC + $ + $ +) + +# If you have any dependencies for the Core module, link them here +# target_link_libraries(Core PUBLIC SomeDependency) + +set_target_properties(Translation PROPERTIES + OUTPUT_NAME locaal_translation + EXPORT_NAME Translation +) + +# Install the target and create export +install(TARGETS Translation + EXPORT LocaalSDKTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include +) + +# Install public headers +install(DIRECTORY include/locaal + DESTINATION include + FILES_MATCHING PATTERN "*.h" +) diff --git a/src/translation/language_codes.h b/src/modules/translation/include/language_codes.h similarity index 100% rename from src/translation/language_codes.h rename to src/modules/translation/include/language_codes.h diff --git a/src/translation/translation-includes.h b/src/modules/translation/include/translation-includes.h similarity index 100% rename from src/translation/translation-includes.h rename to src/modules/translation/include/translation-includes.h diff --git a/src/translation/translation-language-utils.h b/src/modules/translation/include/translation-language-utils.h similarity index 100% rename from src/translation/translation-language-utils.h rename to src/modules/translation/include/translation-language-utils.h diff --git a/src/translation/translation-utils.h b/src/modules/translation/include/translation-utils.h similarity index 67% rename from src/translation/translation-utils.h rename to src/modules/translation/include/translation-utils.h index 8a06ab4..a2f71d9 100644 --- a/src/translation/translation-utils.h +++ b/src/modules/translation/include/translation-utils.h @@ -3,6 +3,6 @@ #include "transcription-filter-data.h" -void start_translation(struct transcription_filter_data *gf); +void start_translation(struct transcription_context *gf); #endif // TRANSLATION_UTILS_H diff --git 
a/src/translation/translation.h b/src/modules/translation/include/translation.h similarity index 85% rename from src/translation/translation.h rename to src/modules/translation/include/translation.h index c740726..a631964 100644 --- a/src/translation/translation.h +++ b/src/modules/translation/include/translation.h @@ -34,15 +34,15 @@ struct translation_context { }; int build_translation_context(struct translation_context &translation_ctx); -void build_and_enable_translation(struct transcription_filter_data *gf, +void build_and_enable_translation(struct transcription_context *gf, const std::string &model_file_path); int translate(struct translation_context &translation_ctx, const std::string &text, const std::string &source_lang, const std::string &target_lang, std::string &result); -#define OBS_POLYGLOT_TRANSLATION_INIT_FAIL -1 -#define OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS 0 -#define OBS_POLYGLOT_TRANSLATION_SUCCESS 0 -#define OBS_POLYGLOT_TRANSLATION_FAIL -1 +#define LOCAAL_TRANSLATION_INIT_FAIL -1 +#define LOCAAL_TRANSLATION_INIT_SUCCESS 0 +#define LOCAAL_TRANSLATION_SUCCESS 0 +#define LOCAAL_TRANSLATION_FAIL -1 #endif // TRANSLATION_H diff --git a/src/translation/language_codes.cpp b/src/modules/translation/src/language_codes.cpp similarity index 100% rename from src/translation/language_codes.cpp rename to src/modules/translation/src/language_codes.cpp diff --git a/src/translation/translation-language-utils.cpp b/src/modules/translation/src/translation-language-utils.cpp similarity index 100% rename from src/translation/translation-language-utils.cpp rename to src/modules/translation/src/translation-language-utils.cpp diff --git a/src/translation/translation-utils.cpp b/src/modules/translation/src/translation-utils.cpp similarity index 95% rename from src/translation/translation-utils.cpp rename to src/modules/translation/src/translation-utils.cpp index 07ca268..3abeaef 100644 --- a/src/translation/translation-utils.cpp +++ 
b/src/modules/translation/src/translation-utils.cpp @@ -6,7 +6,7 @@ #include "plugin-support.h" #include "model-utils/model-downloader.h" -void start_translation(struct transcription_filter_data *gf) +void start_translation(struct transcription_context *gf) { obs_log(LOG_INFO, "Starting translation..."); diff --git a/src/translation/translation.cpp b/src/modules/translation/src/translation.cpp similarity index 99% rename from src/translation/translation.cpp rename to src/modules/translation/src/translation.cpp index 0701d95..d7a2dec 100644 --- a/src/translation/translation.cpp +++ b/src/modules/translation/src/translation.cpp @@ -10,7 +10,7 @@ #include #include -void build_and_enable_translation(struct transcription_filter_data *gf, +void build_and_enable_translation(struct transcription_context *gf, const std::string &model_file_path) { std::lock_guard lock(gf->whisper_ctx_mutex); From 6f035408c869545ea8c6ef771ae77a89324efa88 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 16 Sep 2024 10:14:06 -0400 Subject: [PATCH 03/12] Add build_x64/ directory and Visual Studio Code related files to .gitignore --- .clang-format | 208 ++++++++++++++++++ .cmake-format.json | 40 ++++ .gitignore | 5 + CMakeLists.txt | 3 +- cmake/BuildCTranslate2.cmake | 7 +- cmake/BuildICU.cmake | 3 +- cmake/BuildPlatformdirs.cmake | 38 ++++ cmake/BuildSentencepiece.cmake | 8 +- cmake/FetchLibav.cmake | 5 +- scripts/build-windows.ps1 | 4 + src/modules/core/include/logger.h | 21 ++ src/modules/core/include/model-downloader.h | 4 - src/modules/core/src/logger.cpp | 46 ++++ src/modules/core/src/model-downloader.cpp | 73 ++---- src/modules/core/src/model-find-utils.cpp | 12 +- .../include/token-buffer-thread.h | 4 +- .../include/transcription-context.h | 49 ++--- .../include/whisper-model-utils.h | 2 - .../transcription/include/whisper-utils.h | 2 +- .../transcription/src/silero-vad-onnx.cpp | 13 +- .../transcription/src/token-buffer-thread.cpp | 14 +- .../transcription/src/vad-processing.cpp | 
80 +++---- .../transcription/src/whisper-model-utils.cpp | 47 ++-- .../transcription/src/whisper-processing.cpp | 110 ++++----- .../transcription/src/whisper-utils.cpp | 19 +- .../translation/src/translation-utils.cpp | 18 +- src/modules/translation/src/translation.cpp | 71 +++--- 27 files changed, 621 insertions(+), 285 deletions(-) create mode 100644 .clang-format create mode 100644 .cmake-format.json create mode 100644 cmake/BuildPlatformdirs.cmake create mode 100644 scripts/build-windows.ps1 create mode 100644 src/modules/core/include/logger.h create mode 100644 src/modules/core/src/logger.cpp diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..afc3dcc --- /dev/null +++ b/.clang-format @@ -0,0 +1,208 @@ +# please use clang-format version 16 or later + +Standard: c++17 +AccessModifierOffset: -8 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +AllowAllArgumentsOnNextLine: false +AllowAllConstructorInitializersOnNextLine: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: false +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: false +BinPackArguments: true +BinPackParameters: true +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true 
+BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakStringLiterals: false # apparently unpredictable +ColumnLimit: 100 +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 8 +ContinuationIndentWidth: 8 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +FixNamespaceComments: true +ForEachMacros: + - 'json_object_foreach' + - 'json_object_foreach_safe' + - 'json_array_foreach' + - 'HASH_ITER' +IncludeBlocks: Preserve +IndentCaseLabels: false +IndentPPDirectives: None +IndentWidth: 8 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 8 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true + +PenaltyBreakAssignment: 10 +PenaltyBreakBeforeFirstCallParameter: 30 +PenaltyBreakComment: 10 +PenaltyBreakFirstLessLess: 0 +PenaltyBreakString: 10 +PenaltyExcessCharacter: 100 +PenaltyReturnTypeOnItsOwnLine: 60 + +PointerAlignment: Right +ReflowComments: false +SortIncludes: false +SortUsingDeclarations: false +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInCStyleCastParentheses: false +SpacesInContainerLiterals: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +StatementMacros: + - 'Q_OBJECT' +TabWidth: 8 +TypenameMacros: + - 'DARRAY' +UseTab: ForContinuationAndIndentation +--- +Language: ObjC +AccessModifierOffset: 2 +AlignArrayOfStructures: Right +AlignConsecutiveAssignments: None +AlignConsecutiveBitFields: None 
+AlignConsecutiveDeclarations: None +AlignConsecutiveMacros: + Enabled: true + AcrossEmptyLines: false + AcrossComments: true +AllowShortBlocksOnASingleLine: Never +AllowShortEnumsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: None +AttributeMacros: ['__unused', '__autoreleasing', '_Nonnull', '__bridge'] +BitFieldColonSpacing: Both +#BreakBeforeBraces: Webkit +BreakBeforeBraces: Custom +BraceWrapping: + AfterCaseLabel: false + AfterClass: true + AfterControlStatement: Never + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: true +BreakAfterAttributes: Never +BreakArrays: false +BreakBeforeConceptDeclarations: Allowed +BreakBeforeInlineASMColon: OnlyMultiline +BreakConstructorInitializers: AfterColon +BreakInheritanceList: AfterComma +ColumnLimit: 120 +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +IndentAccessModifiers: false +IndentCaseBlocks: false +IndentCaseLabels: true +IndentExternBlock: Indent +IndentGotoLabels: false +IndentRequiresClause: true +IndentWidth: 4 +IndentWrappedFunctionNames: true +InsertBraces: false +InsertNewlineAtEOF: true +KeepEmptyLinesAtTheStartOfBlocks: false +LambdaBodyIndentation: Signature +NamespaceIndentation: All +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCBreakBeforeNestedBlockParam: false +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PPIndentWidth: -1 +PackConstructorInitializers: NextLine +QualifierAlignment: Leave +ReferenceAlignment: Right 
+RemoveSemicolon: false +RequiresClausePosition: WithPreceding +RequiresExpressionIndentation: OuterScope +SeparateDefinitionBlocks: Always +ShortNamespaceLines: 1 +SortIncludes: false +#SortUsingDeclarations: LexicographicNumeric +SortUsingDeclarations: true +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAroundPointerQualifiers: Default +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInConditionalStatement: false +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +Standard: c++17 +TabWidth: 4 +UseTab: Never diff --git a/.cmake-format.json b/.cmake-format.json new file mode 100644 index 0000000..b70e8c5 --- /dev/null +++ b/.cmake-format.json @@ -0,0 +1,40 @@ +{ + "format": { + "line_width": 120, + "tab_size": 2, + "enable_sort": true, + "autosort": true + }, + "additional_commands": { + "find_qt": { + "flags": [], + "kwargs": { + "COMPONENTS": "+", + "COMPONENTS_WIN": "+", + "COMPONENTS_MACOS": "+", + "COMPONENTS_LINUX": "+" + } + }, + "set_target_properties_obs": { + "pargs": 1, + "flags": [], + "kwargs": { + "PROPERTIES": { + "kwargs": { + "PREFIX": 1, + "OUTPUT_NAME": 1, + "FOLDER": 1, + "VERSION": 1, + "SOVERSION": 1, + "AUTOMOC": 1, + "AUTOUIC": 1, + "AUTORCC": 1, + "AUTOUIC_SEARCH_PATHS": 1, + "BUILD_RPATH": 1, + "INSTALL_RPATH": 1 + } + } + } + } + } +} diff --git a/.gitignore b/.gitignore index 8594f3a..7b8729b 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,8 @@ CMakeFiles/ CMakeScripts/ cmake_install.cmake Makefile + +# Visual Studio Code ignores +.vscode/ +.settings/ +*.code-workspace diff --git a/CMakeLists.txt b/CMakeLists.txt index 761b70c..8d8a95d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,11 +77,12 @@ endif() 
include(cmake/BuildICU.cmake) # Add ICU to the target target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ICU) -target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM INTERFACE ${ICU_INCLUDE_DIR}) include(cmake/FetchLibav.cmake) target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE FFmpeg) +include(cmake/BuildPlatformdirs.cmake) +target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE sago_platform_folders_lib) # List of all available modules set(LOCAAL_MODULES diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake index 4eb42c2..0ec206b 100644 --- a/cmake/BuildCTranslate2.cmake +++ b/cmake/BuildCTranslate2.cmake @@ -45,11 +45,14 @@ elseif(WIN32) add_library(ct2 INTERFACE) target_link_libraries(ct2 INTERFACE ${ctranslate2_fetch_SOURCE_DIR}/lib/ctranslate2.lib) - set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include) + target_include_directories(ct2 INTERFACE + $ + $ + ) target_compile_options(ct2 INTERFACE /wd4267 /wd4244 /wd4305 /wd4996 /wd4099) file(GLOB CT2_DLLS ${ctranslate2_fetch_SOURCE_DIR}/bin/*.dll) - install(FILES ${CT2_DLLS} DESTINATION "obs-plugins/64bit") + install(FILES ${CT2_DLLS} DESTINATION "bin") else() # build cpu_features from source set(CPU_FEATURES_VERSION "0.9.0") diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake index 954939c..e328f99 100644 --- a/cmake/BuildICU.cmake +++ b/cmake/BuildICU.cmake @@ -99,7 +99,8 @@ add_dependencies(ICU ICU_build) foreach(lib ${ICU_LIBRARIES}) target_link_libraries(ICU INTERFACE ICU::${lib}) endforeach() -target_include_directories(ICU SYSTEM INTERFACE $) +target_include_directories(ICU INTERFACE $ + $) # add exported target install(TARGETS ICU EXPORT ICUTargets) diff --git a/cmake/BuildPlatformdirs.cmake b/cmake/BuildPlatformdirs.cmake new file mode 100644 index 0000000..649e4be --- /dev/null +++ b/cmake/BuildPlatformdirs.cmake @@ -0,0 +1,38 @@ +include(ExternalProject) + +# Define the sago::platform_folders external project 
+ExternalProject_Add( + sago_platform_folders + GIT_REPOSITORY https://github.com/sago007/PlatformFolders.git + GIT_TAG master # You might want to use a specific tag or commit hash for stability + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/external/sago_platform_folders + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config ${CMAKE_BUILD_TYPE} + INSTALL_COMMAND ${CMAKE_COMMAND} --install . --config ${CMAKE_BUILD_TYPE} +) + +# Create an interface library for sago::platform_folders +add_library(sago_platform_folders_lib INTERFACE) +add_dependencies(sago_platform_folders_lib sago_platform_folders) + +# Set include directories for the interface library +target_include_directories(sago_platform_folders_lib + INTERFACE + $ + $ +) +# add exported target install +install(TARGETS sago_platform_folders_lib + EXPORT sago_platform_folders_lib + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) + +# Export the targets file +install(EXPORT sago_platform_folders_lib + NAMESPACE sago:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/sago_platform_folders_lib +) diff --git a/cmake/BuildSentencepiece.cmake b/cmake/BuildSentencepiece.cmake index cb0db14..ca22a93 100644 --- a/cmake/BuildSentencepiece.cmake +++ b/cmake/BuildSentencepiece.cmake @@ -23,9 +23,11 @@ elseif(WIN32) FetchContent_MakeAvailable(sentencepiece_fetch) add_library(sentencepiece INTERFACE) target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/sentencepiece.lib) - set_target_properties(sentencepiece PROPERTIES INTERFACE_INCLUDE_DIRECTORIES - ${sentencepiece_fetch_SOURCE_DIR}/include) - + target_include_directories(sentencepiece INTERFACE + $ + $ + ) + else() set(SP_URL diff --git a/cmake/FetchLibav.cmake b/cmake/FetchLibav.cmake index b0fae7a..392f79d 100644 --- a/cmake/FetchLibav.cmake +++ 
b/cmake/FetchLibav.cmake @@ -68,7 +68,10 @@ endif() # Create FFmpeg interface library add_library(FFmpeg INTERFACE) -target_include_directories(FFmpeg INTERFACE ${FFMPEG_INCLUDE_DIR}) +target_include_directories(FFmpeg INTERFACE + $<BUILD_INTERFACE:${FFMPEG_INCLUDE_DIR}> + $<INSTALL_INTERFACE:include> +) target_link_libraries(FFmpeg INTERFACE ${FFMPEG_LIBRARIES}) # add exported target diff --git a/scripts/build-windows.ps1 b/scripts/build-windows.ps1 new file mode 100644 index 0000000..d1717e1 --- /dev/null +++ b/scripts/build-windows.ps1 @@ -0,0 +1,4 @@ + +cmake -S . -B build_x64 -DCMAKE_BUILD_TYPE=Release -DLocaalSDK_FIND_COMPONENTS="Core;Transcription;Translation" + +cmake --build build_x64 --config Release diff --git a/src/modules/core/include/logger.h b/src/modules/core/include/logger.h new file mode 100644 index 0000000..9394ecc --- /dev/null +++ b/src/modules/core/include/logger.h @@ -0,0 +1,21 @@ +#ifndef LOGGER_H +#define LOGGER_H + +#include <functional> +#include <string> + +class Logger { +public: + enum class Level { DEBUG, INFO, WARNING, ERROR }; + + using LogCallback = std::function<void(Level, const std::string &message)>; + + static void setLogCallback(LogCallback callback); + static void log(Level level, const std::string &format, ...); + +private: + static LogCallback s_logCallback; + static std::string getLevelString(Level level); +}; + +#endif // LOGGER_H diff --git a/src/modules/core/include/model-downloader.h b/src/modules/core/include/model-downloader.h index 3af9450..b359209 100644 --- a/src/modules/core/include/model-downloader.h +++ b/src/modules/core/include/model-downloader.h @@ -8,8 +8,4 @@ std::string find_model_folder(const ModelInfo &model_info); std::string find_model_bin_file(const ModelInfo &model_info); -// Start the model downloader UI dialog with a callback for when the download is finished -void download_model_with_ui_dialog(const ModelInfo &model_info, - download_finished_callback_t download_finished_callback); - #endif // MODEL_DOWNLOADER_H diff --git a/src/modules/core/src/logger.cpp b/src/modules/core/src/logger.cpp new file mode 100644
index 0000000..5e13589 --- /dev/null +++ b/src/modules/core/src/logger.cpp @@ -0,0 +1,46 @@ +#include "Logger.h" +#include +#include +#include + +Logger::LogCallback Logger::s_logCallback = nullptr; + +void Logger::setLogCallback(LogCallback callback) +{ + s_logCallback = callback; +} + +void Logger::log(Level level, const std::string &format, ...) +{ + // Default logging behavior + va_list args; + va_start(args, format); + char buffer[256]; + vsnprintf(buffer, sizeof(buffer), format.c_str(), args); + va_end(args); + std::stringstream ss; + ss << "[ " << getLevelString(level) << " ] " << buffer; + + if (s_logCallback) { + s_logCallback(level, ss.str()); + } else { + // Default logging behavior + std::cout << ss.str() << std::endl; + } +} + +std::string Logger::getLevelString(Level level) +{ + switch (level) { + case Level::DEBUG: + return "DEBUG"; + case Level::INFO: + return "INFO"; + case Level::WARNING: + return "WARNING"; + case Level::ERROR: + return "ERROR"; + default: + return "UNKNOWN"; + } +} diff --git a/src/modules/core/src/model-downloader.cpp b/src/modules/core/src/model-downloader.cpp index 7f7f04d..cc7a06e 100644 --- a/src/modules/core/src/model-downloader.cpp +++ b/src/modules/core/src/model-downloader.cpp @@ -1,73 +1,37 @@ #include "model-downloader.h" -#include "plugin-support.h" -#include "model-downloader-ui.h" #include "model-find-utils.h" +#include "Logger.h" -#include -#include +#include + +#include std::string find_model_folder(const ModelInfo &model_info) { if (model_info.friendly_name.empty() || model_info.local_folder_name.empty() || model_info.files.empty()) { - obs_log(LOG_ERROR, "Model info is invalid."); + Logger::log(Logger::Level::ERROR, "Model info is invalid."); return ""; } - char *data_folder_models = obs_module_file("models"); + const std::string data_folder_models = sago::getCacheDir(); const std::filesystem::path module_data_models_folder = std::filesystem::absolute(data_folder_models); - bfree(data_folder_models); const 
std::string model_local_data_path = - (module_data_models_folder / model_info.local_folder_name).string(); + (module_data_models_folder / "locaal" / "models" / model_info.local_folder_name) + .string(); - obs_log(LOG_INFO, "Checking if model '%s' exists in data...", - model_info.friendly_name.c_str()); + Logger::log(Logger::Level::INFO, "Checking if model '%s' exists in cache...", + model_info.friendly_name.c_str()); - if (!std::filesystem::exists(model_local_data_path)) { - obs_log(LOG_INFO, "Model not found in data: %s", model_local_data_path.c_str()); - } else { - obs_log(LOG_INFO, "Model folder found in data: %s", model_local_data_path.c_str()); + if (std::filesystem::exists(model_local_data_path)) { + Logger::log(Logger::Level::INFO, "Model folder found in data: %s", + model_local_data_path.c_str()); return model_local_data_path; } - // Check if model exists in the config folder - char *config_folder = obs_module_config_path("models"); - if (!config_folder) { - obs_log(LOG_INFO, "Config folder not set."); - return ""; - } -#ifdef _WIN32 - // convert mbstring to wstring - int count = MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), NULL, 0); - std::wstring config_folder_str(count, 0); - MultiByteToWideChar(CP_UTF8, 0, config_folder, strlen(config_folder), &config_folder_str[0], - count); - obs_log(LOG_INFO, "Config models folder: %S", config_folder_str.c_str()); -#else - std::string config_folder_str = config_folder; - obs_log(LOG_INFO, "Config models folder: %s", config_folder_str.c_str()); -#endif - - const std::filesystem::path module_config_models_folder = - std::filesystem::absolute(config_folder_str); - bfree(config_folder); - - obs_log(LOG_INFO, "Checking if model '%s' exists in config...", - model_info.friendly_name.c_str()); - - const std::string model_local_config_path = - (module_config_models_folder / model_info.local_folder_name).string(); - - obs_log(LOG_INFO, "Lookig for model in config: %s", model_local_config_path.c_str()); 
- if (std::filesystem::exists(model_local_config_path)) { - obs_log(LOG_INFO, "Model folder exists in config folder: %s", - model_local_config_path.c_str()); - return model_local_config_path; - } - - obs_log(LOG_INFO, "Model '%s' not found.", model_info.friendly_name.c_str()); + Logger::log(Logger::Level::INFO, "Model '%s' not found.", model_info.friendly_name.c_str()); return ""; } @@ -80,12 +44,3 @@ std::string find_model_bin_file(const ModelInfo &model_info) return find_bin_file_in_folder(model_local_folder_path); } - -void download_model_with_ui_dialog(const ModelInfo &model_info, - download_finished_callback_t download_finished_callback) -{ - // Start the model downloader UI - ModelDownloader *model_downloader = new ModelDownloader( - model_info, download_finished_callback, (QWidget *)obs_frontend_get_main_window()); - model_downloader->show(); -} diff --git a/src/modules/core/src/model-find-utils.cpp b/src/modules/core/src/model-find-utils.cpp index d2bb48f..7ee93e8 100644 --- a/src/modules/core/src/model-find-utils.cpp +++ b/src/modules/core/src/model-find-utils.cpp @@ -4,10 +4,8 @@ #include #include -#include - #include "model-find-utils.h" -#include "plugin-support.h" +#include "logger.h" std::string find_file_in_folder_by_name(const std::string &folder_path, const std::string &file_name) @@ -39,12 +37,12 @@ std::string find_bin_file_in_folder(const std::string &model_local_folder_path) for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) { if (entry.path().extension() == ".bin") { const std::string bin_file_path = entry.path().string(); - obs_log(LOG_INFO, "Model bin file found in folder: %s", - bin_file_path.c_str()); + Logger::log(Logger::Level::INFO, "Model bin file found in folder: %s", + bin_file_path.c_str()); return bin_file_path; } } - obs_log(LOG_ERROR, "Model bin file not found in folder: %s", - model_local_folder_path.c_str()); + Logger::log(Logger::Level::ERROR, "Model bin file not found in folder: %s", + 
model_local_folder_path.c_str()); return ""; } diff --git a/src/modules/transcription/include/token-buffer-thread.h b/src/modules/transcription/include/token-buffer-thread.h index 96e2f75..c1e0a86 100644 --- a/src/modules/transcription/include/token-buffer-thread.h +++ b/src/modules/transcription/include/token-buffer-thread.h @@ -10,9 +10,7 @@ #include #include -#include - -#include "plugin-support.h" +#include "logger.h" #ifdef _WIN32 typedef std::wstring TokenBufferString; diff --git a/src/modules/transcription/include/transcription-context.h b/src/modules/transcription/include/transcription-context.h index 060ae9d..765eff5 100644 --- a/src/modules/transcription/include/transcription-context.h +++ b/src/modules/transcription/include/transcription-context.h @@ -1,5 +1,5 @@ -#ifndef TRANSCRIPTION_FILTER_DATA_H -#define TRANSCRIPTION_FILTER_DATA_H +#ifndef TRANSCRIPTION_CONTEXT_H +#define TRANSCRIPTION_CONTEXT_H #include @@ -10,18 +10,24 @@ #include #include -#include "translation/translation.h" -#include "translation/translation-includes.h" -#include "whisper-utils/silero-vad-onnx.h" -#include "whisper-utils/whisper-processing.h" -#include "whisper-utils/token-buffer-thread.h" +#include "translation.h" +#include "translation-includes.h" +#include "silero-vad-onnx.h" +#include "whisper-processing.h" +#include "token-buffer-thread.h" +#include "logger.h" #define MAX_PREPROC_CHANNELS 10 +// Audio packet info +struct transcription_filter_audio_info { + uint32_t frames; + uint64_t timestamp_offset_ns; // offset (since start of processing) timestamp in ns +}; + struct transcription_context { - obs_source_t *context; // obs filter source (this filter) - size_t channels; // number of channels - uint32_t sample_rate; // input sample rate + size_t channels; // number of channels + uint32_t sample_rate; // input sample rate // How many input frames (in input sample rate) are needed for the next whisper frame size_t frames; // How many frames were processed in the last whisper 
frame (this is dynamic) @@ -40,13 +46,13 @@ struct transcription_context { /* PCM buffers */ float *copy_buffers[MAX_PREPROC_CHANNELS]; - struct circlebuf info_buffer; - struct circlebuf input_buffers[MAX_PREPROC_CHANNELS]; - struct circlebuf whisper_buffer; + std::deque info_buffer; + std::deque input_buffers[MAX_PREPROC_CHANNELS]; + std::deque whisper_buffer; /* Resampler */ audio_resampler_t *resampler_to_whisper; - struct circlebuf resampled_buffer; + std::deque resampled_buffer; /* whisper */ std::string whisper_model_path; @@ -61,7 +67,7 @@ struct transcription_context { bool do_silence; int vad_mode; - int log_level = LOG_DEBUG; + Logger::Level log_level = Logger::Level::DEBUG; bool log_words; bool caption_to_stream; bool active = false; @@ -131,7 +137,6 @@ struct transcription_context { for (size_t i = 0; i < MAX_PREPROC_CHANNELS; i++) { copy_buffers[i] = nullptr; } - context = nullptr; resampler_to_whisper = nullptr; whisper_model_path = ""; whisper_context = nullptr; @@ -140,19 +145,13 @@ struct transcription_context { } }; -// Audio packet info -struct transcription_filter_audio_info { - uint32_t frames; - uint64_t timestamp_offset_ns; // offset (since start of processing) timestamp in ns -}; - // Callback sent when the transcription has a new result void set_text_callback(struct transcription_context *gf, const DetectionResultWithText &str); void clear_current_caption(transcription_context *gf_); // Callback sent when the VAD finds an audio chunk. 
Sample rate = WHISPER_SAMPLE_RATE, channels = 1 // The audio chunk is in 32-bit float format -void audio_chunk_callback(struct transcription_context *gf, const float *pcm32f_data, - size_t frames, int vad_state, const DetectionResultWithText &result); +void audio_chunk_callback(struct transcription_context *gf, const float *pcm32f_data, size_t frames, + int vad_state, const DetectionResultWithText &result); -#endif /* TRANSCRIPTION_FILTER_DATA_H */ +#endif /* TRANSCRIPTION_CONTEXT_H */ diff --git a/src/modules/transcription/include/whisper-model-utils.h b/src/modules/transcription/include/whisper-model-utils.h index ae8487e..d4c10b0 100644 --- a/src/modules/transcription/include/whisper-model-utils.h +++ b/src/modules/transcription/include/whisper-model-utils.h @@ -1,8 +1,6 @@ #ifndef WHISPER_MODEL_UTILS_H #define WHISPER_MODEL_UTILS_H -#include - #include "transcription-filter-data.h" void update_whisper_model(struct transcription_context *gf); diff --git a/src/modules/transcription/include/whisper-utils.h b/src/modules/transcription/include/whisper-utils.h index 225c0d3..4204355 100644 --- a/src/modules/transcription/include/whisper-utils.h +++ b/src/modules/transcription/include/whisper-utils.h @@ -1,7 +1,7 @@ #ifndef WHISPER_UTILS_H #define WHISPER_UTILS_H -#include "transcription-filter-data.h" +#include "transcription-context.h" #include diff --git a/src/modules/transcription/src/silero-vad-onnx.cpp b/src/modules/transcription/src/silero-vad-onnx.cpp index 078e47c..41c2293 100644 --- a/src/modules/transcription/src/silero-vad-onnx.cpp +++ b/src/modules/transcription/src/silero-vad-onnx.cpp @@ -10,8 +10,7 @@ #include #include -#include -#include "plugin-support.h" +#include "logger.h" // #define __DEBUG_SPEECH_PROB___ @@ -152,7 +151,7 @@ void VadIterator::predict(const std::vector &data) float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. 
- obs_log(LOG_INFO, "{ start: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, + Logger::log(Logger::Level::INFO, "{ start: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); #endif //__DEBUG_SPEECH_PROB___ if (temp_end != 0) { @@ -202,7 +201,7 @@ void VadIterator::predict(const std::vector &data) float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - obs_log(LOG_INFO, "{ speaking: %.3f s (%.3f) %08d}", + Logger::log(Logger::Level::INFO, "{ speaking: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); #endif //__DEBUG_SPEECH_PROB___ @@ -211,7 +210,7 @@ void VadIterator::predict(const std::vector &data) float speech = current_sample - window_size_samples; // minus window_size_samples to get precise start time point. - obs_log(LOG_INFO, "{ silence: %.3f s (%.3f) %08d}", + Logger::log(Logger::Level::INFO, "{ silence: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); #endif //__DEBUG_SPEECH_PROB___ @@ -225,7 +224,7 @@ void VadIterator::predict(const std::vector &data) float speech = current_sample - window_size_samples - speech_pad_samples; // minus window_size_samples to get precise start time point. 
- obs_log(LOG_INFO, "{ end: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, + Logger::log(Logger::Level::INFO, "{ end: %.3f s (%.3f) %08d}", 1.0 * speech / sample_rate, speech_prob, current_sample - window_size_samples); #endif //__DEBUG_SPEECH_PROB___ if (triggered == true) { @@ -295,7 +294,7 @@ void VadIterator::collect_chunks(const std::vector &input_wav, output_wav.clear(); for (size_t i = 0; i < speeches.size(); i++) { #ifdef __DEBUG_SPEECH_PROB___ - obs_log(LOG_INFO, "%s", speeches[i].string().c_str()); + Logger::log(Logger::Level::INFO, "%s", speeches[i].string().c_str()); #endif //#ifdef __DEBUG_SPEECH_PROB___ std::vector slice(&input_wav[speeches[i].start], &input_wav[speeches[i].end]); diff --git a/src/modules/transcription/src/token-buffer-thread.cpp b/src/modules/transcription/src/token-buffer-thread.cpp index dfe6bcc..3c27206 100644 --- a/src/modules/transcription/src/token-buffer-thread.cpp +++ b/src/modules/transcription/src/token-buffer-thread.cpp @@ -9,8 +9,6 @@ #include #include -#include - #ifdef _WIN32 #include #define SPACE L" " @@ -75,7 +73,8 @@ void TokenBufferThread::log_token_vector(const std::vector &tokens) for (const auto &token : tokens) { output += token; } - obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str()); + Logger::log(Logger::Level::INFO, "TokenBufferThread::log_token_vector: '%s'", + output.c_str()); } void TokenBufferThread::addSentenceFromStdString(const std::string &sentence, @@ -166,7 +165,7 @@ void TokenBufferThread::clear() void TokenBufferThread::monitor() { - obs_log(LOG_INFO, "TokenBufferThread::monitor"); + Logger::log(Logger::Level::INFO, "TokenBufferThread::monitor"); this->captionPresentationCallback(""); @@ -345,8 +344,9 @@ void TokenBufferThread::monitor() contribution.end()); #endif - obs_log(gf->log_level, "TokenBufferThread::monitor: output '%s'", - contribution_out.c_str()); + Logger::log(gf->log_level, + "TokenBufferThread::monitor: output '%s'", + contribution_out.c_str()); 
this->sentenceOutputCallback(contribution_out); lastContributionIsSent = true; } @@ -385,7 +385,7 @@ void TokenBufferThread::monitor() : getWaitTime(SPEED_SLOW))); } - obs_log(LOG_INFO, "TokenBufferThread::monitor: done"); + Logger::log(Logger::Level::INFO, "TokenBufferThread::monitor: done"); } int TokenBufferThread::getWaitTime(TokenBufferSpeed speed) const diff --git a/src/modules/transcription/src/vad-processing.cpp b/src/modules/transcription/src/vad-processing.cpp index 0e9c744..29c37c3 100644 --- a/src/modules/transcription/src/vad-processing.cpp +++ b/src/modules/transcription/src/vad-processing.cpp @@ -24,9 +24,9 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, return 1; } - obs_log(gf->log_level, - "segmentation: currently %lu bytes in the audio input buffer", - gf->input_buffers[0].size); + Logger::log(gf->log_level, + "segmentation: currently %lu bytes in the audio input buffer", + gf->input_buffers[0].size); // max number of frames is 10 seconds worth of audio const size_t max_num_frames = gf->sample_rate * 10; @@ -76,7 +76,7 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, } } - obs_log(gf->log_level, "found %d frames from info buffer.", num_frames_from_infos); + Logger::log(gf->log_level, "found %d frames from info buffer.", num_frames_from_infos); gf->last_num_frames = num_frames_from_infos; { @@ -95,11 +95,11 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], resampled_16khz_frames * sizeof(float)); - obs_log(gf->log_level, - "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", - (int)gf->channels, (int)resampled_16khz_frames, - (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, - gf->resampled_buffer.size); + Logger::log(gf->log_level, + "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", + (int)gf->channels, (int)resampled_16khz_frames, + (float)resampled_16khz_frames / 
WHISPER_SAMPLE_RATE * 1000.0f, + gf->resampled_buffer.size); } return 0; @@ -129,8 +129,8 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), vad_input.size() * sizeof(float)); - obs_log(gf->log_level, "sending %d frames to vad, %d windows, reset state? %s", - vad_input.size(), vad_num_windows, (!last_vad_state.vad_on) ? "yes" : "no"); + Logger::log(gf->log_level, "sending %d frames to vad, %d windows, reset state? %s", + vad_input.size(), vad_num_windows, (!last_vad_state.vad_on) ? "yes" : "no"); { ProfileScope("vad->process"); gf->vad->process(vad_input, !last_vad_state.vad_on); @@ -144,9 +144,10 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v std::vector stamps = gf->vad->get_speech_timestamps(); if (stamps.size() == 0) { - obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); + Logger::log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); if (last_vad_state.vad_on) { - obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); + Logger::log(gf->log_level, + "Last VAD was ON: segment end -> send to inference"); run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, VAD_STATE_WAS_ON); @@ -190,7 +191,8 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, number_of_frames * sizeof(float)); - obs_log(gf->log_level, + Logger::log( + gf->log_level, "VAD segment %d/%d. pushed %d to %d (%d frames / %lu ms). 
current size: %lu bytes / %lu frames / %lu ms", i, (stamps.size() - 1), start_frame, end_frame, number_of_frames, number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size, @@ -200,7 +202,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v // segment "end" is in the middle of the buffer, send it to inference if (stamps[i].end < (int)vad_input.size()) { // new "ending" segment (not up to the end of the buffer) - obs_log(gf->log_level, "VAD segment end -> send to inference"); + Logger::log(gf->log_level, "VAD segment end -> send to inference"); // find the end timestamp of the segment const uint64_t segment_end_ts = start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; @@ -218,12 +220,14 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v // end not reached - speech is ongoing current_vad_state.vad_on = true; if (last_vad_state.vad_on) { - obs_log(gf->log_level, - "last vad state was: ON, start ts: %llu, end ts: %llu", - last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms); + Logger::log(gf->log_level, + "last vad state was: ON, start ts: %llu, end ts: %llu", + last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms); current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms; } else { - obs_log(gf->log_level, + Logger::log( + gf->log_level, "last vad state was: OFF, start ts: %llu, end ts: %llu. start_ts_offset_ms: %llu, start_frame: %d", last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, start_ts_offset_ms, start_frame); @@ -232,9 +236,10 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v } current_vad_state.end_ts_offset_ms = start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; - obs_log(gf->log_level, - "end not reached. 
vad state: ON, start ts: %llu, end ts: %llu", - current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms); + Logger::log(gf->log_level, + "end not reached. vad state: ON, start ts: %llu, end ts: %llu", + current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms); last_vad_state = current_vad_state; @@ -251,14 +256,14 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v (current_vad_state.last_partial_segment_end_ts > 0 ? current_vad_state.last_partial_segment_end_ts : current_vad_state.start_ts_offest_ms); - obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", - current_vad_state.last_partial_segment_end_ts, current_length_ms); + Logger::log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + current_vad_state.last_partial_segment_end_ts, current_length_ms); if (current_length_ms > (uint64_t)gf->partial_latency) { current_vad_state.last_partial_segment_end_ts = current_vad_state.end_ts_offset_ms; // send partial segment to inference - obs_log(gf->log_level, "Partial segment -> send to inference"); + Logger::log(gf->log_level, "Partial segment -> send to inference"); run_inference_and_callbacks(gf, current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms, VAD_STATE_PARTIAL); @@ -289,13 +294,13 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ circlebuf_pop_front(&gf->resampled_buffer, temp_buffer.data(), resampled_buffer_size); circlebuf_push_back(&gf->whisper_buffer, temp_buffer.data(), resampled_buffer_size); - obs_log(gf->log_level, "whisper buffer size: %lu bytes", gf->whisper_buffer.size); + Logger::log(gf->log_level, "whisper buffer size: %lu bytes", gf->whisper_buffer.size); // use last_vad_state timestamps to calculate the duration of the current segment if (last_vad_state.end_ts_offset_ms - last_vad_state.start_ts_offest_ms >= (uint64_t)gf->segment_duration) { - obs_log(gf->log_level, "%d 
seconds worth of audio -> send to inference", - gf->segment_duration); + Logger::log(gf->log_level, "%d seconds worth of audio -> send to inference", + gf->segment_duration); run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, VAD_STATE_WAS_ON); last_vad_state.start_ts_offest_ms = end_timestamp_offset_ns / 1000000; @@ -312,12 +317,12 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ (last_vad_state.last_partial_segment_end_ts > 0 ? last_vad_state.last_partial_segment_end_ts : last_vad_state.start_ts_offest_ms); - obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", - last_vad_state.last_partial_segment_end_ts, current_length_ms); + Logger::log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + last_vad_state.last_partial_segment_end_ts, current_length_ms); if (current_length_ms > (uint64_t)gf->partial_latency) { // send partial segment to inference - obs_log(gf->log_level, "Partial segment -> send to inference"); + Logger::log(gf->log_level, "Partial segment -> send to inference"); last_vad_state.last_partial_segment_end_ts = last_vad_state.end_ts_offset_ms; @@ -327,9 +332,9 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ circlebuf_peek_front(&gf->whisper_buffer, vad_input.data(), vad_input.size() * sizeof(float)); - obs_log(gf->log_level, "sending %d frames to vad, %.1f ms", - vad_input.size(), - (float)vad_input.size() * 1000.0f / (float)WHISPER_SAMPLE_RATE); + Logger::log(gf->log_level, "sending %d frames to vad, %.1f ms", + vad_input.size(), + (float)vad_input.size() * 1000.0f / (float)WHISPER_SAMPLE_RATE); { ProfileScope("vad->process"); gf->vad->process(vad_input, true); @@ -342,7 +347,8 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ VAD_STATE_PARTIAL); } else { // VAD detected silence in the partial segment - obs_log(gf->log_level, "VAD detected silence 
in partial segment"); + Logger::log(gf->log_level, + "VAD detected silence in partial segment"); // pop the partial segment from the whisper buffer, save some audio for the next segment const size_t num_bytes_to_keep = (WHISPER_SAMPLE_RATE / 4) * sizeof(float); @@ -365,10 +371,10 @@ void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_ std::wstring silero_vad_model_path(count, 0); MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), &silero_vad_model_path[0], count); - obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); + Logger::log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); #else std::string silero_vad_model_path = silero_vad_model_file; - obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); + Logger::log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); #endif // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py // for silero vad parameters diff --git a/src/modules/transcription/src/whisper-model-utils.cpp b/src/modules/transcription/src/whisper-model-utils.cpp index fde3590..4f5b8da 100644 --- a/src/modules/transcription/src/whisper-model-utils.cpp +++ b/src/modules/transcription/src/whisper-model-utils.cpp @@ -2,23 +2,21 @@ #define NOMINMAX #endif -#include - #include "whisper-utils.h" #include "whisper-processing.h" -#include "plugin-support.h" +#include "logger.h" #include "model-utils/model-downloader.h" void update_whisper_model(struct transcription_context *gf) { if (gf->context == nullptr) { - obs_log(LOG_ERROR, "obs_source_t context is null"); + Logger::log(Logger::Level::ERROR, "obs_source_t context is null"); return; } obs_data_t *s = obs_source_get_settings(gf->context); if (s == nullptr) { - obs_log(LOG_ERROR, "obs_data_t settings is null"); + Logger::log(Logger::Level::ERROR, "obs_data_t settings is null"); return; } @@ -38,20 +36,20 @@ void 
update_whisper_model(struct transcription_context *gf) const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos; if (!is_external_model && new_model_path.empty()) { - obs_log(LOG_WARNING, "Whisper model path is empty"); + Logger::log(Logger::Level::WARNING, "Whisper model path is empty"); return; } if (is_external_model && external_model_file_path.empty()) { - obs_log(LOG_WARNING, "External model file path is empty"); + Logger::log(Logger::Level::WARNING, "External model file path is empty"); return; } char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx"); if (silero_vad_model_file == nullptr) { - obs_log(LOG_ERROR, "Cannot find Silero VAD model file"); + Logger::log(Logger::Level::ERROR, "Cannot find Silero VAD model file"); return; } - obs_log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file); + Logger::log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file); std::string silero_vad_model_file_str = std::string(silero_vad_model_file); bfree(silero_vad_model_file); @@ -60,8 +58,8 @@ void update_whisper_model(struct transcription_context *gf) if (gf->whisper_model_path != new_model_path) { // model path changed - obs_log(gf->log_level, "model path changed from %s to %s", - gf->whisper_model_path.c_str(), new_model_path.c_str()); + Logger::log(gf->log_level, "model path changed from %s to %s", + gf->whisper_model_path.c_str(), new_model_path.c_str()); // check if this is loading the initial model or a switch gf->whisper_model_loaded_new = !gf->whisper_model_path.empty(); @@ -73,8 +71,8 @@ void update_whisper_model(struct transcription_context *gf) shutdown_whisper_thread(gf); if (models_info.count(new_model_path) == 0) { - obs_log(LOG_WARNING, "Model '%s' does not exist", - new_model_path.c_str()); + Logger::log(Logger::Level::WARNING, "Model '%s' does not exist", + new_model_path.c_str()); return; } @@ -83,20 +81,21 @@ void update_whisper_model(struct transcription_context 
*gf) // check if the model exists, if not, download it std::string model_file_found = find_model_bin_file(model_info); if (model_file_found == "") { - obs_log(LOG_WARNING, "Whisper model does not exist"); + Logger::log(Logger::Level::WARNING, "Whisper model does not exist"); download_model_with_ui_dialog( model_info, [gf, new_model_path, silero_vad_model_file_str]( int download_status, const std::string &path) { if (download_status == 0) { - obs_log(LOG_INFO, - "Model download complete"); + Logger::log(Logger::Level::INFO, + "Model download complete"); gf->whisper_model_path = new_model_path; start_whisper_thread_with_path( gf, path, silero_vad_model_file_str.c_str()); } else { - obs_log(LOG_ERROR, "Model download failed"); + Logger::log(Logger::Level::ERROR, + "Model download failed"); } }); } else { @@ -108,12 +107,14 @@ void update_whisper_model(struct transcription_context *gf) } else { // new model is external file, get file location from file property if (external_model_file_path.empty()) { - obs_log(LOG_WARNING, "External model file path is empty"); + Logger::log(Logger::Level::WARNING, + "External model file path is empty"); } else { // check if the external model file is not currently loaded if (gf->whisper_model_file_currently_loaded == external_model_file_path) { - obs_log(LOG_INFO, "External model file is already loaded"); + Logger::log(Logger::Level::INFO, + "External model file is already loaded"); return; } else { shutdown_whisper_thread(gf); @@ -126,14 +127,14 @@ void update_whisper_model(struct transcription_context *gf) } } else { // model path did not change - obs_log(gf->log_level, "Model path did not change: %s == %s", - gf->whisper_model_path.c_str(), new_model_path.c_str()); + Logger::log(gf->log_level, "Model path did not change: %s == %s", + gf->whisper_model_path.c_str(), new_model_path.c_str()); } if (new_dtw_timestamps != gf->enable_token_ts_dtw) { // dtw_token_timestamps changed - obs_log(gf->log_level, "dtw_token_timestamps changed 
from %d to %d", - gf->enable_token_ts_dtw, new_dtw_timestamps); + Logger::log(gf->log_level, "dtw_token_timestamps changed from %d to %d", + gf->enable_token_ts_dtw, new_dtw_timestamps); gf->enable_token_ts_dtw = new_dtw_timestamps; shutdown_whisper_thread(gf); start_whisper_thread_with_path(gf, gf->whisper_model_path, diff --git a/src/modules/transcription/src/whisper-processing.cpp b/src/modules/transcription/src/whisper-processing.cpp index c02649a..90cf27e 100644 --- a/src/modules/transcription/src/whisper-processing.cpp +++ b/src/modules/transcription/src/whisper-processing.cpp @@ -1,10 +1,8 @@ #include -#include - #include -#include "plugin-support.h" +#include "logger.h" #include "transcription-filter-data.h" #include "whisper-processing.h" #include "whisper-utils.h" @@ -28,16 +26,17 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, { std::string model_path = model_path_in; - obs_log(LOG_INFO, "Loading whisper model from %s", model_path.c_str()); + Logger::log(Logger::Level::INFO, "Loading whisper model from %s", model_path.c_str()); if (std::filesystem::is_directory(model_path)) { - obs_log(LOG_INFO, + Logger::log( + Logger::Level::INFO, "Model path is a directory, not a file, looking for .bin file in folder"); // look for .bin file const std::string model_bin_file = find_bin_file_in_folder(model_path); if (model_bin_file.empty()) { - obs_log(LOG_ERROR, "Model bin file not found in folder: %s", - model_path.c_str()); + Logger::log(Logger::Level::ERROR, "Model bin file not found in folder: %s", + model_path.c_str()); return nullptr; } model_path = model_bin_file; @@ -51,7 +50,7 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, // remove trailing newline char *text_copy = bstrdup(text); text_copy[strcspn(text_copy, "\n")] = 0; - obs_log(ctx->log_level, "Whisper: %s", text_copy); + Logger::log(ctx->log_level, "Whisper: %s", text_copy); bfree(text_copy); }, gf); @@ -59,25 +58,26 @@ struct 
whisper_context *init_whisper_context(const std::string &model_path_in, struct whisper_context_params cparams = whisper_context_default_params(); #ifdef LOCALVOCAL_WITH_CUDA cparams.use_gpu = true; - obs_log(LOG_INFO, "Using CUDA GPU for inference, device %d", cparams.gpu_device); + Logger::log(Logger::Level::INFO, "Using CUDA GPU for inference, device %d", + cparams.gpu_device); #elif defined(LOCALVOCAL_WITH_HIPBLAS) cparams.use_gpu = true; - obs_log(LOG_INFO, "Using hipBLAS for inference"); + Logger::log(Logger::Level::INFO, "Using hipBLAS for inference"); #elif defined(__APPLE__) cparams.use_gpu = true; - obs_log(LOG_INFO, "Using Metal/CoreML for inference"); + Logger::log(Logger::Level::INFO, "Using Metal/CoreML for inference"); #else cparams.use_gpu = false; - obs_log(LOG_INFO, "Using CPU for inference"); + Logger::log(Logger::Level::INFO, "Using CPU for inference"); #endif cparams.dtw_token_timestamps = gf->enable_token_ts_dtw; if (gf->enable_token_ts_dtw) { - obs_log(LOG_INFO, "DTW token timestamps enabled"); + Logger::log(Logger::Level::INFO, "DTW token timestamps enabled"); cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN; // cparams.dtw_n_top = 4; } else { - obs_log(LOG_INFO, "DTW token timestamps disabled"); + Logger::log(Logger::Level::INFO, "DTW token timestamps disabled"); cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; } @@ -94,8 +94,8 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, // Read model into buffer std::ifstream modelFile(model_path_ws, std::ios::binary); if (!modelFile.is_open()) { - obs_log(LOG_ERROR, "Failed to open whisper model file %s", - model_path.c_str()); + Logger::log(Logger::Level::ERROR, "Failed to open whisper model file %s", + model_path.c_str()); return nullptr; } modelFile.seekg(0, std::ios::end); @@ -112,15 +112,16 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); #endif } catch (const 
std::exception &e) { - obs_log(LOG_ERROR, "Exception while loading whisper model: %s", e.what()); + Logger::log(Logger::Level::ERROR, "Exception while loading whisper model: %s", + e.what()); return nullptr; } if (ctx == nullptr) { - obs_log(LOG_ERROR, "Failed to load whisper model"); + Logger::log(Logger::Level::ERROR, "Failed to load whisper model"); return nullptr; } - obs_log(LOG_INFO, "Whisper model loaded: %s", whisper_print_system_info()); + Logger::log(Logger::Level::INFO, "Whisper model loaded: %s", whisper_print_system_info()); return ctx; } @@ -131,25 +132,26 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex int vad_state = VAD_STATE_WAS_OFF) { if (gf == nullptr) { - obs_log(LOG_ERROR, "run_whisper_inference: gf is null"); + Logger::log(Logger::Level::ERROR, "run_whisper_inference: gf is null"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } if (pcm32f_data_ == nullptr || pcm32f_num_samples == 0) { - obs_log(LOG_ERROR, "run_whisper_inference: pcm32f_data is null or size is 0"); + Logger::log(Logger::Level::ERROR, + "run_whisper_inference: pcm32f_data is null or size is 0"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } // if the time difference between t0 and t1 is less than 50 ms - skip if (t1 - t0 < 50) { - obs_log(gf->log_level, - "Time difference between t0 and t1 is less than 50 ms, skipping"); + Logger::log(gf->log_level, + "Time difference between t0 and t1 is less than 50 ms, skipping"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } - obs_log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__, - int(pcm32f_num_samples), float(pcm32f_num_samples) / WHISPER_SAMPLE_RATE, - gf->whisper_params.n_threads); + Logger::log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__, + int(pcm32f_num_samples), float(pcm32f_num_samples) / WHISPER_SAMPLE_RATE, + gf->whisper_params.n_threads); bool should_free_buffer = false; float *pcm32f_data = (float 
*)pcm32f_data_; @@ -160,7 +162,8 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex (uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE); if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) { - obs_log(gf->log_level, + Logger::log( + gf->log_level, "Speech segment is less than 1 second, padding with white noise to 1 second"); const size_t new_size = (size_t)(1.01f * (float)(WHISPER_SAMPLE_RATE)); // create a new buffer and copy the data to it in the middle @@ -184,7 +187,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex std::lock_guard lock(gf->whisper_ctx_mutex); if (gf->whisper_context == nullptr) { - obs_log(LOG_WARNING, "whisper context is null"); + Logger::log(Logger::Level::WARNING, "whisper context is null"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } @@ -195,7 +198,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex initial_prompt += " " + gf->last_transcription_sentence[i]; } gf->whisper_params.initial_prompt = initial_prompt.c_str(); - obs_log(gf->log_level, "Initial prompt: %s", gf->whisper_params.initial_prompt); + Logger::log(gf->log_level, "Initial prompt: %s", gf->whisper_params.initial_prompt); } // run the inference @@ -205,7 +208,8 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params, pcm32f_data, (int)pcm32f_size); } catch (const std::exception &e) { - obs_log(LOG_ERROR, "Whisper exception: %s. Filter restart is required", e.what()); + Logger::log(Logger::Level::ERROR, + "Whisper exception: %s. 
Filter restart is required", e.what()); whisper_free(gf->whisper_context); gf->whisper_context = nullptr; if (should_free_buffer) { @@ -222,11 +226,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex strcmp(gf->whisper_params.language, "auto") == 0) { int lang_id = whisper_lang_auto_detect(gf->whisper_context, 0, 1, nullptr); language = whisper_lang_str(lang_id); - obs_log(gf->log_level, "Detected language: %s", language.c_str()); + Logger::log(gf->log_level, "Detected language: %s", language.c_str()); } if (whisper_full_result != 0) { - obs_log(LOG_WARNING, "failed to process audio, error %d", whisper_full_result); + Logger::log(Logger::Level::WARNING, "failed to process audio, error %d", + whisper_full_result); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } @@ -261,14 +266,15 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex const float time = ((float)token.id - 50365.0f) * 0.02f; const float duration_s = (float)incoming_duration_ms / 1000.0f; const float ratio = time / duration_s; - obs_log(gf->log_level, + Logger::log( + gf->log_level, "Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f. 
Threshold %.2f", token.id, time, duration_s, ratio, gf->duration_filter_threshold); if (ratio > gf->duration_filter_threshold) { // ratio is too high, skip this detection - obs_log(gf->log_level, - "Time token ratio too high, skipping"); + Logger::log(gf->log_level, + "Time token ratio too high, skipping"); return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; } keep = false; @@ -279,23 +285,23 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex text += token_str; tokens.push_back(token); } - obs_log(gf->log_level, "S %d, T %2d: %5d\t%s\tp: %.3f [keep: %d]", - n_segment, j, token.id, token_str.c_str(), token.p, keep); + Logger::log(gf->log_level, "S %d, T %2d: %5d\t%s\tp: %.3f [keep: %d]", + n_segment, j, token.id, token_str.c_str(), token.p, keep); } } sentence_p /= (float)tokens.size(); if (sentence_p < gf->sentence_psum_accept_thresh) { - obs_log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping", - sentence_p, gf->sentence_psum_accept_thresh); + Logger::log(gf->log_level, "Sentence psum %.3f below threshold %.3f, skipping", + sentence_p, gf->sentence_psum_accept_thresh); return {DETECTION_RESULT_SILENCE, "", t0, t1, {}, language}; } - obs_log(gf->log_level, "Decoded sentence: '%s'", text.c_str()); + Logger::log(gf->log_level, "Decoded sentence: '%s'", text.c_str()); if (gf->log_words) { - obs_log(LOG_INFO, "[%s --> %s]%s(%.3f) %s", to_timestamp(t0).c_str(), - to_timestamp(t1).c_str(), vad_state == VAD_STATE_PARTIAL ? "P" : " ", - sentence_p, text.c_str()); + Logger::log(Logger::Level::INFO, "[%s --> %s]%s(%.3f) %s", to_timestamp(t0).c_str(), + to_timestamp(t1).c_str(), vad_state == VAD_STATE_PARTIAL ? "P" : " ", + sentence_p, text.c_str()); } if (text.empty() || text == "." 
|| text == " " || text == "\n") { @@ -346,14 +352,13 @@ void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offse void whisper_loop(void *data) { if (data == nullptr) { - obs_log(LOG_ERROR, "whisper_loop: data is null"); + Logger::log(Logger::Level::ERROR, "whisper_loop: data is null"); return; } - struct transcription_context *gf = - static_cast(data); + struct transcription_context *gf = static_cast(data); - obs_log(gf->log_level, "Starting whisper thread"); + Logger::log(gf->log_level, "Starting whisper thread"); vad_state current_vad_state = {false, now_ms(), 0, 0}; @@ -368,7 +373,8 @@ void whisper_loop(void *data) std::lock_guard lock(gf->whisper_ctx_mutex); ProfileScope("locked whisper ctx"); if (gf->whisper_context == nullptr) { - obs_log(LOG_WARNING, "Whisper context is null, exiting thread"); + Logger::log(Logger::Level::WARNING, + "Whisper context is null, exiting thread"); break; } } @@ -384,9 +390,9 @@ void whisper_loop(void *data) uint64_t now = now_ms(); if ((now - gf->last_sub_render_time) > gf->max_sub_duration) { // clear the current sub, call the callback with an empty string - obs_log(gf->log_level, - "Clearing current subtitle. now: %lu ms, last: %lu ms", now, - gf->last_sub_render_time); + Logger::log(gf->log_level, + "Clearing current subtitle. 
now: %lu ms, last: %lu ms", + now, gf->last_sub_render_time); clear_current_caption(gf); } } @@ -403,5 +409,5 @@ void whisper_loop(void *data) } } - obs_log(gf->log_level, "Exiting whisper thread"); + Logger::log(gf->log_level, "Exiting whisper thread"); } diff --git a/src/modules/transcription/src/whisper-utils.cpp b/src/modules/transcription/src/whisper-utils.cpp index 069e0ac..9e787d0 100644 --- a/src/modules/transcription/src/whisper-utils.cpp +++ b/src/modules/transcription/src/whisper-utils.cpp @@ -1,14 +1,12 @@ #include "whisper-utils.h" -#include "plugin-support.h" -#include "model-utils/model-downloader.h" +#include "logger.h" +#include "model-downloader.h" #include "whisper-processing.h" #include "vad-processing.h" -#include - void shutdown_whisper_thread(struct transcription_context *gf) { - obs_log(gf->log_level, "shutdown_whisper_thread"); + Logger::log(gf->log_level, "shutdown_whisper_thread"); if (gf->whisper_context != nullptr) { // acquire the mutex before freeing the context std::lock_guard lock(gf->whisper_ctx_mutex); @@ -28,21 +26,22 @@ void start_whisper_thread_with_path(struct transcription_context *gf, const std::string &whisper_model_path, const char *silero_vad_model_file) { - obs_log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s", - whisper_model_path.c_str(), silero_vad_model_file); + Logger::log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s", + whisper_model_path.c_str(), silero_vad_model_file); std::lock_guard lock(gf->whisper_ctx_mutex); if (gf->whisper_context != nullptr) { - obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null"); + Logger::log(Logger::Level::ERROR, + "cannot init whisper: whisper_context is not null"); return; } // initialize Silero VAD initialize_vad(gf, silero_vad_model_file); - obs_log(gf->log_level, "Create whisper context"); + Logger::log(gf->log_level, "Create whisper context"); gf->whisper_context = init_whisper_context(whisper_model_path, 
gf); if (gf->whisper_context == nullptr) { - obs_log(LOG_ERROR, "Failed to initialize whisper context"); + Logger::log(Logger::Level::ERROR, "Failed to initialize whisper context"); return; } gf->whisper_model_file_currently_loaded = whisper_model_path; diff --git a/src/modules/translation/src/translation-utils.cpp b/src/modules/translation/src/translation-utils.cpp index 3abeaef..ffef31d 100644 --- a/src/modules/translation/src/translation-utils.cpp +++ b/src/modules/translation/src/translation-utils.cpp @@ -1,19 +1,19 @@ -#include + #include "translation-includes.h" #include "translation.h" #include "translation-utils.h" -#include "plugin-support.h" +#include "logger.h" #include "model-utils/model-downloader.h" void start_translation(struct transcription_context *gf) { - obs_log(LOG_INFO, "Starting translation..."); + Logger::log(Logger::Level::INFO, "Starting translation..."); if (gf->translation_model_index == "!!!external!!!") { - obs_log(LOG_INFO, "External model selected."); + Logger::log(Logger::Level::INFO, "External model selected."); if (gf->translation_model_path_external.empty()) { - obs_log(LOG_ERROR, "External model path is empty."); + Logger::log(Logger::Level::ERROR, "External model path is empty."); gf->translate = false; return; } @@ -25,15 +25,17 @@ void start_translation(struct transcription_context *gf) const ModelInfo &translation_model_info = models_info[gf->translation_model_index]; std::string model_file_found = find_model_folder(translation_model_info); if (model_file_found == "") { - obs_log(LOG_INFO, "Translation CT2 model does not exist. Downloading..."); + Logger::log(Logger::Level::INFO, + "Translation CT2 model does not exist. 
Downloading..."); download_model_with_ui_dialog( translation_model_info, [gf, model_file_found](int download_status, const std::string &path) { if (download_status == 0) { - obs_log(LOG_INFO, "CT2 model download complete"); + Logger::log(Logger::Level::INFO, + "CT2 model download complete"); build_and_enable_translation(gf, path); } else { - obs_log(LOG_ERROR, "Model download failed"); + Logger::log(Logger::Level::ERROR, "Model download failed"); gf->translate = false; } }); diff --git a/src/modules/translation/src/translation.cpp b/src/modules/translation/src/translation.cpp index d7a2dec..f1394bc 100644 --- a/src/modules/translation/src/translation.cpp +++ b/src/modules/translation/src/translation.cpp @@ -1,13 +1,13 @@ #include "translation.h" -#include "plugin-support.h" -#include "model-utils/model-find-utils.h" -#include "transcription-filter-data.h" +#include "logger.h" +#include "model-find-utils.h" +#include "transcription-context.h" #include "language_codes.h" #include "translation-language-utils.h" #include #include -#include + #include void build_and_enable_translation(struct transcription_context *gf, @@ -16,12 +16,11 @@ void build_and_enable_translation(struct transcription_context *gf, std::lock_guard lock(gf->whisper_ctx_mutex); gf->translation_ctx.local_model_folder_path = model_file_path; - if (build_translation_context(gf->translation_ctx) == - OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS) { - obs_log(LOG_INFO, "Enable translation"); + if (build_translation_context(gf->translation_ctx) == LOCAAL_TRANSLATION_INIT_SUCCESS) { + Logger::log(Logger::Level::INFO, "Enable translation"); gf->translate = true; } else { - obs_log(LOG_ERROR, "Failed to load CT2 model"); + Logger::log(Logger::Level::ERROR, "Failed to load CT2 model"); gf->translate = false; } } @@ -29,7 +28,8 @@ void build_and_enable_translation(struct transcription_context *gf, int build_translation_context(struct translation_context &translation_ctx) { std::string local_model_path = 
translation_ctx.local_model_folder_path; - obs_log(LOG_INFO, "Building translation context from '%s'...", local_model_path.c_str()); + Logger::log(Logger::Level::INFO, "Building translation context from '%s'...", + local_model_path.c_str()); // find the SPM file in the model folder std::string local_spm_path = find_file_in_folder_by_regex_expression( local_model_path, "(sentencepiece|spm|spiece|source).*?\\.(model|spm)"); @@ -37,27 +37,30 @@ int build_translation_context(struct translation_context &translation_ctx) find_file_in_folder_by_regex_expression(local_model_path, "target.*?\\.spm"); try { - obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str()); + Logger::log(Logger::Level::INFO, "Loading SPM from %s", local_spm_path.c_str()); translation_ctx.processor.reset(new sentencepiece::SentencePieceProcessor()); const auto status = translation_ctx.processor->Load(local_spm_path); if (!status.ok()) { - obs_log(LOG_ERROR, "Failed to load SPM: %s", status.ToString().c_str()); - return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + Logger::log(Logger::Level::ERROR, "Failed to load SPM: %s", + status.ToString().c_str()); + return LOCAAL_TRANSLATION_INIT_FAIL; } if (!target_spm_path.empty()) { - obs_log(LOG_INFO, "Loading target SPM from %s", target_spm_path.c_str()); + Logger::log(Logger::Level::INFO, "Loading target SPM from %s", + target_spm_path.c_str()); translation_ctx.target_processor.reset( new sentencepiece::SentencePieceProcessor()); const auto target_status = translation_ctx.target_processor->Load(target_spm_path); if (!target_status.ok()) { - obs_log(LOG_ERROR, "Failed to load target SPM: %s", - target_status.ToString().c_str()); - return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + Logger::log(Logger::Level::ERROR, "Failed to load target SPM: %s", + target_status.ToString().c_str()); + return LOCAAL_TRANSLATION_INIT_FAIL; } } else { - obs_log(LOG_INFO, "Target SPM not found, using source SPM for target"); + Logger::log(Logger::Level::INFO, + "Target SPM not found, 
using source SPM for target"); translation_ctx.target_processor.release(); } @@ -77,19 +80,20 @@ int build_translation_context(struct translation_context &translation_ctx) return std::regex_replace(text, std::regex(""), "UNK"); }; - obs_log(LOG_INFO, "Loading CT2 model from %s", local_model_path.c_str()); + Logger::log(Logger::Level::INFO, "Loading CT2 model from %s", + local_model_path.c_str()); #ifdef POLYGLOT_WITH_CUDA ctranslate2::Device device = ctranslate2::Device::CUDA; - obs_log(LOG_INFO, "CT2 Using CUDA"); + Logger::log(Logger::Level::INFO, "CT2 Using CUDA"); #else ctranslate2::Device device = ctranslate2::Device::CPU; - obs_log(LOG_INFO, "CT2 Using CPU"); + Logger::log(Logger::Level::INFO, "CT2 Using CPU"); #endif translation_ctx.translator.reset(new ctranslate2::Translator( local_model_path, device, ctranslate2::ComputeType::AUTO)); - obs_log(LOG_INFO, "CT2 Model loaded"); + Logger::log(Logger::Level::INFO, "CT2 Model loaded"); translation_ctx.options.reset(new ctranslate2::TranslationOptions); translation_ctx.options->beam_size = 1; @@ -99,10 +103,10 @@ int build_translation_context(struct translation_context &translation_ctx) translation_ctx.options->max_input_length = 64; translation_ctx.options->sampling_temperature = 0.1f; } catch (std::exception &e) { - obs_log(LOG_ERROR, "Failed to load CT2 model: %s", e.what()); - return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + Logger::log(Logger::Level::ERROR, "Failed to load CT2 model: %s", e.what()); + return LOCAAL_TRANSLATION_INIT_FAIL; } - return OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS; + return LOCAAL_TRANSLATION_INIT_SUCCESS; } int translate(struct translation_context &translation_ctx, const std::string &text, @@ -133,7 +137,8 @@ int translate(struct translation_context &translation_ctx, const std::string &te for (const auto &token : input_tokens) { input_tokens_str += token + ", "; } - obs_log(LOG_INFO, "Input tokens: %s", input_tokens_str.c_str()); + Logger::log(Logger::Level::INFO, "Input tokens: %s", + 
input_tokens_str.c_str()); translation_ctx.last_input_tokens.push_back(new_input_tokens); // remove the oldest input tokens @@ -160,7 +165,8 @@ int translate(struct translation_context &translation_ctx, const std::string &te for (const auto &token : target_prefix) { target_prefix_str += token + ","; } - obs_log(LOG_INFO, "Target prefix: %s", target_prefix_str.c_str()); + Logger::log(Logger::Level::INFO, "Target prefix: %s", + target_prefix_str.c_str()); const std::vector> target_prefix_batch = { target_prefix}; @@ -189,7 +195,8 @@ int translate(struct translation_context &translation_ctx, const std::string &te for (const auto &token : translation_tokens) { translation_tokens_str += token + ", "; } - obs_log(LOG_INFO, "Translation tokens: %s", translation_tokens_str.c_str()); + Logger::log(Logger::Level::INFO, "Translation tokens: %s", + translation_tokens_str.c_str()); // save the translation tokens translation_ctx.last_translation_tokens.push_back(translation_tokens); @@ -198,15 +205,15 @@ int translate(struct translation_context &translation_ctx, const std::string &te (size_t)translation_ctx.add_context) { translation_ctx.last_translation_tokens.pop_front(); } - obs_log(LOG_INFO, "Last translation tokens deque size: %d", - (int)translation_ctx.last_translation_tokens.size()); + Logger::log(Logger::Level::INFO, "Last translation tokens deque size: %d", + (int)translation_ctx.last_translation_tokens.size()); // detokenize const std::string result_ = translation_ctx.detokenizer(translation_tokens); result = remove_start_punctuation(result_); } catch (std::exception &e) { - obs_log(LOG_ERROR, "Error: %s", e.what()); - return OBS_POLYGLOT_TRANSLATION_FAIL; + Logger::log(Logger::Level::ERROR, "Error: %s", e.what()); + return LOCAAL_TRANSLATION_FAIL; } - return OBS_POLYGLOT_TRANSLATION_SUCCESS; + return LOCAAL_TRANSLATION_SUCCESS; } From 5abc35ca4ff71d02b5177dbe9e0ef39ecf3181a2 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 16 Sep 2024 15:51:24 -0400 Subject: 
[PATCH 04/12] Refactor build and dependency management - Remove redundant CMAKE_CXX_STANDARD setting - Include necessary cmake files for building dependencies - Update target link libraries and include directories Related to #123 --- CMakeLists.txt | 38 ++----------------- cmake/BuildPlatformdirs.cmake | 6 ++- cmake/FetchOnnxruntime.cmake | 2 +- src/modules/core/CMakeLists.txt | 9 +++-- src/modules/core/src/logger.cpp | 3 +- src/modules/core/src/model-downloader.cpp | 2 +- src/modules/transcription/CMakeLists.txt | 3 +- .../include/transcription-context.h | 8 ---- src/modules/translation/CMakeLists.txt | 11 +++--- 9 files changed, 24 insertions(+), 58 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d8a95d..7b719c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,8 +4,6 @@ project(LocaalSDK VERSION 1.0.0 LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_STANDARD 11) - # Option to build shared libraries option(BUILD_SHARED_LIBS "Build shared libraries" OFF) option(BUILD_EXAMPLES "Build examples" OFF) @@ -40,17 +38,6 @@ if(WIN32) set_property(CACHE ACCELERATION PROPERTY STRINGS "cpu" "hipblas" "cuda") endif() -include(cmake/BuildWhispercpp.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE Whispercpp) - -include(cmake/BuildCTranslate2.cmake) -include(cmake/BuildSentencepiece.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ct2 sentencepiece) - -set(USE_SYSTEM_ONNXRUNTIME - OFF - CACHE STRING "Use system ONNX Runtime") - set(DISABLE_ONNXRUNTIME_GPU OFF CACHE STRING "Disables GPU support of ONNX Runtime (Only valid on Linux)") @@ -59,30 +46,13 @@ if(DISABLE_ONNXRUNTIME_GPU) target_compile_definitions(${CMAKE_PROJECT_NAME} INTERFACE DISABLE_ONNXRUNTIME_GPU) endif() -if(USE_SYSTEM_ONNXRUNTIME) - if(OS_LINUX) - find_package(Onnxruntime 1.16.3 REQUIRED) - set(Onnxruntime_INCLUDE_PATH - ${Onnxruntime_INCLUDE_DIR} ${Onnxruntime_INCLUDE_DIR}/onnxruntime - 
${Onnxruntime_INCLUDE_DIR}/onnxruntime/core/session ${Onnxruntime_INCLUDE_DIR}/onnxruntime/core/providers/cpu) - target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE "${Onnxruntime_LIBRARIES}") - target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM INTERFACE "${Onnxruntime_INCLUDE_PATH}") - else() - message(FATAL_ERROR "System ONNX Runtime is only supported on Linux!") - endif() -else() - include(cmake/FetchOnnxruntime.cmake) -endif() - +include(cmake/FetchOnnxruntime.cmake) +include(cmake/BuildWhispercpp.cmake) +include(cmake/BuildCTranslate2.cmake) +include(cmake/BuildSentencepiece.cmake) include(cmake/BuildICU.cmake) -# Add ICU to the target -target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ICU) - include(cmake/FetchLibav.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE FFmpeg) - include(cmake/BuildPlatformdirs.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE sago_platform_folders_lib) # List of all available modules set(LOCAAL_MODULES diff --git a/cmake/BuildPlatformdirs.cmake b/cmake/BuildPlatformdirs.cmake index 649e4be..4dbb0b2 100644 --- a/cmake/BuildPlatformdirs.cmake +++ b/cmake/BuildPlatformdirs.cmake @@ -1,12 +1,14 @@ include(ExternalProject) +set(SAGO_INSTALL_DIR ${CMAKE_BINARY_DIR}/external/sago_platform_folders) + # Define the sago::platform_folders external project ExternalProject_Add( sago_platform_folders GIT_REPOSITORY https://github.com/sago007/PlatformFolders.git GIT_TAG master # You might want to use a specific tag or commit hash for stability CMAKE_ARGS - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/external/sago_platform_folders + -DCMAKE_INSTALL_PREFIX=${SAGO_INSTALL_DIR} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} BUILD_COMMAND ${CMAKE_COMMAND} --build . --config ${CMAKE_BUILD_TYPE} INSTALL_COMMAND ${CMAKE_COMMAND} --install . 
--config ${CMAKE_BUILD_TYPE} @@ -19,7 +21,7 @@ add_dependencies(sago_platform_folders_lib sago_platform_folders) # Set include directories for the interface library target_include_directories(sago_platform_folders_lib INTERFACE - $ + $ $ ) # add exported target install diff --git a/cmake/FetchOnnxruntime.cmake b/cmake/FetchOnnxruntime.cmake index 8cf3908..e7ab8b6 100644 --- a/cmake/FetchOnnxruntime.cmake +++ b/cmake/FetchOnnxruntime.cmake @@ -96,8 +96,8 @@ else() set(Onnxruntime_INSTALL_LIBS ${Onnxruntime_LINK_LIBS} "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime_providers_shared.so") endif() + install(FILES ${Onnxruntime_INSTALL_LIBS} DESTINATION "${CMAKE_INSTALL_LIBDIR}/obs-plugins/${CMAKE_PROJECT_NAME}") target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ${Onnxruntime_LINK_LIBS}) target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${onnxruntime_SOURCE_DIR}/include") - install(FILES ${Onnxruntime_INSTALL_LIBS} DESTINATION "${CMAKE_INSTALL_LIBDIR}/obs-plugins/${CMAKE_PROJECT_NAME}") set_target_properties(${CMAKE_PROJECT_NAME} PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_PROJECT_NAME}") endif() diff --git a/src/modules/core/CMakeLists.txt b/src/modules/core/CMakeLists.txt index 13c0e26..cee30a2 100644 --- a/src/modules/core/CMakeLists.txt +++ b/src/modules/core/CMakeLists.txt @@ -4,15 +4,18 @@ add_library(Core src/model-find-utils.cpp ) +target_link_libraries(Core PUBLIC sago_platform_folders_lib) +target_include_directories(Core + PRIVATE + ${SAGO_INSTALL_DIR}/include +) + target_include_directories(Core PUBLIC $ $ ) -# If you have any dependencies for the Core module, link them here -# target_link_libraries(Core PUBLIC SomeDependency) - set_target_properties(Core PROPERTIES OUTPUT_NAME locaal_core EXPORT_NAME Core diff --git a/src/modules/core/src/logger.cpp b/src/modules/core/src/logger.cpp index 5e13589..35ae690 100644 --- a/src/modules/core/src/logger.cpp +++ b/src/modules/core/src/logger.cpp @@ -1,4 +1,5 @@ -#include "Logger.h" +#include "logger.h" + 
#include #include #include diff --git a/src/modules/core/src/model-downloader.cpp b/src/modules/core/src/model-downloader.cpp index cc7a06e..09cb393 100644 --- a/src/modules/core/src/model-downloader.cpp +++ b/src/modules/core/src/model-downloader.cpp @@ -1,6 +1,6 @@ #include "model-downloader.h" #include "model-find-utils.h" -#include "Logger.h" +#include "logger.h" #include diff --git a/src/modules/transcription/CMakeLists.txt b/src/modules/transcription/CMakeLists.txt index 9410f4a..e936898 100644 --- a/src/modules/transcription/CMakeLists.txt +++ b/src/modules/transcription/CMakeLists.txt @@ -14,8 +14,7 @@ target_include_directories(Transcription $ ) -# If you have any dependencies for the Core module, link them here -# target_link_libraries(Core PUBLIC SomeDependency) +target_link_libraries(Transcription PUBLIC FFmpeg Whispercpp) set_target_properties(Transcription PROPERTIES OUTPUT_NAME locaal_transcription diff --git a/src/modules/transcription/include/transcription-context.h b/src/modules/transcription/include/transcription-context.h index 765eff5..28157ae 100644 --- a/src/modules/transcription/include/transcription-context.h +++ b/src/modules/transcription/include/transcription-context.h @@ -10,8 +10,6 @@ #include #include -#include "translation.h" -#include "translation-includes.h" #include "silero-vad-onnx.h" #include "whisper-processing.h" #include "token-buffer-thread.h" @@ -116,12 +114,6 @@ struct transcription_context { std::condition_variable wshiper_thread_cv; std::optional input_cv; - // translation context - struct translation_context translation_ctx; - std::string translation_model_index; - std::string translation_model_path_external; - bool translate_only_full_sentences; - bool buffered_output = false; TokenBufferThread captions_monitor; TokenBufferThread translation_monitor; diff --git a/src/modules/translation/CMakeLists.txt b/src/modules/translation/CMakeLists.txt index e30dbcf..2055fad 100644 --- a/src/modules/translation/CMakeLists.txt +++ 
b/src/modules/translation/CMakeLists.txt @@ -1,8 +1,8 @@ add_library(Translation - src/language_codes.cpp - src/translation-language-utils.cpp - src/translation-utils.cpp - src/translation.cpp + src/language_codes.cpp + src/translation-language-utils.cpp + src/translation-utils.cpp + src/translation.cpp ) target_include_directories(Translation @@ -11,8 +11,7 @@ target_include_directories(Translation $ ) -# If you have any dependencies for the Core module, link them here -# target_link_libraries(Core PUBLIC SomeDependency) +target_link_libraries(Translation INTERFACE ICU ct2 sentencepiece) set_target_properties(Translation PROPERTIES OUTPUT_NAME locaal_translation From 1c02a0c2452b45dfa9bc08f3ac7f8bbf034a6898 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 07:18:30 -0400 Subject: [PATCH 05/12] Refactor build and dependency management --- cmake/BuildCTranslate2.cmake | 5 +- cmake/BuildICU.cmake | 2 +- cmake/BuildMyCurl.cmake | 8 +- cmake/BuildSentencepiece.cmake | 3 +- cmake/FetchLibav.cmake | 2 +- cmake/FetchOnnxruntime.cmake | 26 ++-- src/modules/core/CMakeLists.txt | 2 +- src/modules/core/include/logger.h | 2 +- .../core/include/model-downloader-types.h | 9 +- src/modules/core/include/model-downloader.h | 6 +- src/modules/core/include/model-find-utils.h | 3 +- src/modules/core/src/logger.cpp | 2 +- src/modules/core/src/model-downloader.cpp | 140 ++++++++++++++++- src/modules/core/src/model-find-utils.cpp | 7 +- src/modules/core/src/model-infos.cpp | 10 +- src/modules/transcription/CMakeLists.txt | 2 +- .../include/transcription-context.h | 2 + .../transcription/src/token-buffer-thread.cpp | 3 - .../transcription/src/whisper-model-utils.cpp | 141 +++++------------- .../transcription/src/whisper-processing.cpp | 58 +++---- .../transcription/src/whisper-utils.cpp | 4 +- src/modules/translation/CMakeLists.txt | 2 +- .../translation/include/translation-utils.h | 2 +- .../translation/src/translation-utils.cpp | 7 +- 
src/modules/translation/src/translation.cpp | 11 +- 25 files changed, 266 insertions(+), 193 deletions(-) diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake index 0ec206b..60132a3 100644 --- a/cmake/BuildCTranslate2.cmake +++ b/cmake/BuildCTranslate2.cmake @@ -27,7 +27,7 @@ elseif(WIN32) if(${ACCELERATION} STREQUAL "cpu" OR ${ACCELERATION} STREQUAL "hipblas") FetchContent_Declare( ctranslate2_fetch - DOWNLOAD_EXTRACT_TIMESTAMP + DOWNLOAD_EXTRACT_TIMESTAMP 1 URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cpu.zip URL_HASH SHA256=30ff8b2499b8d3b5a6c4d6f7f8ddbc89e745ff06e0050b645e3b7c9b369451a3) else() @@ -38,7 +38,8 @@ elseif(WIN32) FetchContent_Declare( ctranslate2_fetch URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cuda12.2.0.zip - URL_HASH SHA256=131724d510f9f2829970953a1bc9e4e8fb7b4cbc8218e32270dcfe6172a51558) + URL_HASH SHA256=131724d510f9f2829970953a1bc9e4e8fb7b4cbc8218e32270dcfe6172a51558 + DOWNLOAD_EXTRACT_TIMESTAMP 1) endif() FetchContent_MakeAvailable(ctranslate2_fetch) diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake index e328f99..74ca70c 100644 --- a/cmake/BuildICU.cmake +++ b/cmake/BuildICU.cmake @@ -14,7 +14,7 @@ if(WIN32) FetchContent_Declare( ICU_build - DOWNLOAD_EXTRACT_TIMESTAMP + DOWNLOAD_EXTRACT_TIMESTAMP 1 URL ${ICU_URL} URL_HASH ${ICU_HASH}) diff --git a/cmake/BuildMyCurl.cmake b/cmake/BuildMyCurl.cmake index 7c38ea1..9289d55 100644 --- a/cmake/BuildMyCurl.cmake +++ b/cmake/BuildMyCurl.cmake @@ -1,7 +1,7 @@ include(FetchContent) set(LibCurl_VERSION "8.4.0-3") -set(LibCurl_BASEURL "https://github.com/occ-ai/obs-ai-libcurl-dep/releases/download/${LibCurl_VERSION}") +set(LibCurl_BASEURL "https://github.com/locaal-ai/obs-ai-libcurl-dep/releases/download/${LibCurl_VERSION}") if(${CMAKE_BUILD_TYPE} STREQUAL Release OR ${CMAKE_BUILD_TYPE} STREQUAL RelWithDebInfo) set(LibCurl_BUILD_TYPE 
Release) @@ -35,11 +35,13 @@ else() endif() endif() +message(STATUS "Fetching libcurl from ${LibCurl_URL}") + FetchContent_Declare( libcurl_fetch - DOWNLOAD_EXTRACT_TIMESTAMP URL ${LibCurl_URL} - URL_HASH ${LibCurl_HASH}) + URL_HASH ${LibCurl_HASH} + DOWNLOAD_EXTRACT_TIMESTAMP 1) FetchContent_MakeAvailable(libcurl_fetch) if(MSVC) diff --git a/cmake/BuildSentencepiece.cmake b/cmake/BuildSentencepiece.cmake index ca22a93..29e57bc 100644 --- a/cmake/BuildSentencepiece.cmake +++ b/cmake/BuildSentencepiece.cmake @@ -6,6 +6,7 @@ if(APPLE) FetchContent_Declare( sentencepiece_fetch + DOWNLOAD_EXTRACT_TIMESTAMP 1 URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libsentencepiece-macos-Release-1.1.1.tar.gz URL_HASH SHA256=c911f1e84ea94925a8bc3fd3257185b2e18395075509c8659cc7003a979e0b32) FetchContent_MakeAvailable(sentencepiece_fetch) @@ -17,7 +18,7 @@ elseif(WIN32) FetchContent_Declare( sentencepiece_fetch - DOWNLOAD_EXTRACT_TIMESTAMP + DOWNLOAD_EXTRACT_TIMESTAMP 1 URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/sentencepiece-windows-0.2.0-Release.zip URL_HASH SHA256=846699c7fa1e8918b71ed7f2bd5cd60e47e51105e1d84e3192919b4f0f10fdeb) FetchContent_MakeAvailable(sentencepiece_fetch) diff --git a/cmake/FetchLibav.cmake b/cmake/FetchLibav.cmake index 392f79d..5cd2818 100644 --- a/cmake/FetchLibav.cmake +++ b/cmake/FetchLibav.cmake @@ -8,7 +8,7 @@ if(WIN32) FetchContent_Declare( FFmpeg_fetch - DOWNLOAD_EXTRACT_TIMESTAMP + DOWNLOAD_EXTRACT_TIMESTAMP 1 URL ${FFMPEG_URL} URL_HASH ${FFMPEG_HASH} ) diff --git a/cmake/FetchOnnxruntime.cmake b/cmake/FetchOnnxruntime.cmake index e7ab8b6..0b994d8 100644 --- a/cmake/FetchOnnxruntime.cmake +++ b/cmake/FetchOnnxruntime.cmake @@ -45,16 +45,20 @@ endif() FetchContent_Declare( onnxruntime - DOWNLOAD_EXTRACT_TIMESTAMP + DOWNLOAD_EXTRACT_TIMESTAMP 1 URL ${Onnxruntime_URL} URL_HASH ${Onnxruntime_HASH}) FetchContent_MakeAvailable(onnxruntime) +add_library(Ort INTERFACE) if(APPLE) 
set(Onnxruntime_LIB "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.${Onnxruntime_VERSION}.dylib") - target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIB}") - target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${onnxruntime_SOURCE_DIR}/include") - target_sources(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIB}") + # target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIB}") + # target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${onnxruntime_SOURCE_DIR}/include") + # target_sources(${CMAKE_PROJECT_NAME} PRIVATE "${Onnxruntime_LIB}") + target_link_libraries(Ort INTERFACE "${Onnxruntime_LIB}") + target_include_directories(Ort INTERFACE "${onnxruntime_SOURCE_DIR}/include") + set_property(SOURCE "${Onnxruntime_LIB}" PROPERTY MACOSX_PACKAGE_LOCATION Frameworks) source_group("Frameworks" FILES "${Onnxruntime_LIB}") # add a codesigning step @@ -69,7 +73,7 @@ if(APPLE) ${CMAKE_INSTALL_NAME_TOOL} -change "@rpath/libonnxruntime.${Onnxruntime_VERSION}.dylib" "@loader_path/../Frameworks/libonnxruntime.${Onnxruntime_VERSION}.dylib" $) elseif(MSVC) - add_library(Ort INTERFACE) + set(Onnxruntime_LIB_NAMES onnxruntime;onnxruntime_providers_shared) foreach(lib_name IN LISTS Onnxruntime_LIB_NAMES) add_library(Ort::${lib_name} SHARED IMPORTED) @@ -80,8 +84,6 @@ elseif(MSVC) install(FILES ${onnxruntime_SOURCE_DIR}/lib/${lib_name}.dll DESTINATION "obs-plugins/64bit") endforeach() - target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE Ort) - # add exported target install install(TARGETS Ort EXPORT OrtTargets) install(EXPORT OrtTargets NAMESPACE Ort:: DESTINATION "lib/cmake/Ort") @@ -97,7 +99,11 @@ else() "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime_providers_shared.so") endif() install(FILES ${Onnxruntime_INSTALL_LIBS} DESTINATION "${CMAKE_INSTALL_LIBDIR}/obs-plugins/${CMAKE_PROJECT_NAME}") - target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ${Onnxruntime_LINK_LIBS}) - target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM 
PUBLIC "${onnxruntime_SOURCE_DIR}/include") - set_target_properties(${CMAKE_PROJECT_NAME} PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_PROJECT_NAME}") + + target_link_libraries(Ort INTERFACE ${Onnxruntime_LINK_LIBS}) + target_include_directories(Ort INTERFACE "${onnxruntime_SOURCE_DIR}/include") + + # target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ${Onnxruntime_LINK_LIBS}) + # target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC "${onnxruntime_SOURCE_DIR}/include") + # set_target_properties(${CMAKE_PROJECT_NAME} PROPERTIES INSTALL_RPATH "$ORIGIN/${CMAKE_PROJECT_NAME}") endif() diff --git a/src/modules/core/CMakeLists.txt b/src/modules/core/CMakeLists.txt index cee30a2..0d1038b 100644 --- a/src/modules/core/CMakeLists.txt +++ b/src/modules/core/CMakeLists.txt @@ -4,7 +4,7 @@ add_library(Core src/model-find-utils.cpp ) -target_link_libraries(Core PUBLIC sago_platform_folders_lib) +target_link_libraries(Core PUBLIC sago_platform_folders_lib libcurl) target_include_directories(Core PRIVATE ${SAGO_INSTALL_DIR}/include diff --git a/src/modules/core/include/logger.h b/src/modules/core/include/logger.h index 9394ecc..593b5f0 100644 --- a/src/modules/core/include/logger.h +++ b/src/modules/core/include/logger.h @@ -6,7 +6,7 @@ class Logger { public: - enum class Level { DEBUG, INFO, WARNING, ERROR }; + enum class Level { DEBUG = 0, INFO, WARNING, ERROR_LOG }; using LogCallback = std::function; diff --git a/src/modules/core/include/model-downloader-types.h b/src/modules/core/include/model-downloader-types.h index 3d24d96..774c5fb 100644 --- a/src/modules/core/include/model-downloader-types.h +++ b/src/modules/core/include/model-downloader-types.h @@ -8,13 +8,20 @@ typedef std::function download_finished_callback_t; +typedef std::function download_progress_callback_t; + +enum DownloadStatus { DOWNLOAD_STATUS_OK, DOWNLOAD_STATUS_ERROR }; +enum DownloadError { DOWNLOAD_ERROR_OK, DOWNLOAD_ERROR_NETWORK, DOWNLOAD_ERROR_FILE }; + +typedef std::function + 
download_error_callback_t; struct ModelFileDownloadInfo { std::string url; std::string sha256; }; -enum ModelType { MODEL_TYPE_TRANSCRIPTION, MODEL_TYPE_TRANSLATION }; +enum ModelType { MODEL_TYPE_TRANSCRIPTION, MODEL_TYPE_TRANSLATION, MODEL_TYPE_VAD }; struct ModelInfo { std::string friendly_name; diff --git a/src/modules/core/include/model-downloader.h b/src/modules/core/include/model-downloader.h index b359209..2ef19b6 100644 --- a/src/modules/core/include/model-downloader.h +++ b/src/modules/core/include/model-downloader.h @@ -6,6 +6,10 @@ #include "model-downloader-types.h" std::string find_model_folder(const ModelInfo &model_info); -std::string find_model_bin_file(const ModelInfo &model_info); +std::string find_model_ext_file(const ModelInfo &model_info, const std::string &ext); + +void download_model(const ModelInfo &model_info, download_finished_callback_t finished_callback, + download_progress_callback_t progress_callback, + download_error_callback_t error_callback); #endif // MODEL_DOWNLOADER_H diff --git a/src/modules/core/include/model-find-utils.h b/src/modules/core/include/model-find-utils.h index 72a3a6f..a2af294 100644 --- a/src/modules/core/include/model-find-utils.h +++ b/src/modules/core/include/model-find-utils.h @@ -7,7 +7,8 @@ std::string find_file_in_folder_by_name(const std::string &folder_path, const std::string &file_name); -std::string find_bin_file_in_folder(const std::string &path); +std::string find_ext_file_in_folder(const std::string &model_local_folder_path, + const std::string &ext); std::string find_file_in_folder_by_regex_expression(const std::string &folder_path, const std::string &file_name_regex); diff --git a/src/modules/core/src/logger.cpp b/src/modules/core/src/logger.cpp index 35ae690..7da6f7f 100644 --- a/src/modules/core/src/logger.cpp +++ b/src/modules/core/src/logger.cpp @@ -39,7 +39,7 @@ std::string Logger::getLevelString(Level level) return "INFO"; case Level::WARNING: return "WARNING"; - case Level::ERROR: + case 
Level::ERROR_LOG: return "ERROR"; default: return "UNKNOWN"; diff --git a/src/modules/core/src/model-downloader.cpp b/src/modules/core/src/model-downloader.cpp index 09cb393..58b6873 100644 --- a/src/modules/core/src/model-downloader.cpp +++ b/src/modules/core/src/model-downloader.cpp @@ -6,21 +6,43 @@ #include +#include + +std::filesystem::path get_models_folder() +{ + const std::string cache_folder = sago::getCacheDir(); + const std::filesystem::path absolute_cache_folder = std::filesystem::absolute(cache_folder); + + const std::filesystem::path models_folder = + (absolute_cache_folder / "locaal" / "models").string(); + + // Check if the data folder exists + if (!std::filesystem::exists(models_folder)) { + Logger::log(Logger::Level::INFO, "Creating models folder: %s", + models_folder.c_str()); + // Create the data folder + if (!std::filesystem::create_directories(models_folder)) { + Logger::log(Logger::Level::ERROR_LOG, "Failed to create models folder: %s", + models_folder.c_str()); + return ""; + } + } + + return models_folder; +} + std::string find_model_folder(const ModelInfo &model_info) { if (model_info.friendly_name.empty() || model_info.local_folder_name.empty() || model_info.files.empty()) { - Logger::log(Logger::Level::ERROR, "Model info is invalid."); + Logger::log(Logger::Level::ERROR_LOG, "Model info is invalid."); return ""; } - const std::string data_folder_models = sago::getCacheDir(); - const std::filesystem::path module_data_models_folder = - std::filesystem::absolute(data_folder_models); + const std::filesystem::path data_folder_models = get_models_folder(); const std::string model_local_data_path = - (module_data_models_folder / "locaal" / "models" / model_info.local_folder_name) - .string(); + (data_folder_models / model_info.local_folder_name).string(); Logger::log(Logger::Level::INFO, "Checking if model '%s' exists in cache...", model_info.friendly_name.c_str()); @@ -35,12 +57,114 @@ std::string find_model_folder(const ModelInfo &model_info) 
return ""; } -std::string find_model_bin_file(const ModelInfo &model_info) +std::string find_model_ext_file(const ModelInfo &model_info, const std::string &ext) { const std::string model_local_folder_path = find_model_folder(model_info); if (model_local_folder_path.empty()) { return ""; } - return find_bin_file_in_folder(model_local_folder_path); + return find_ext_file_in_folder(model_local_folder_path, ext); +} + +size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) +{ + size_t written = fwrite(ptr, size, nmemb, stream); + return written; +} + +std::string get_filename_from_url(const std::string &url) +{ + auto lastSlashPos = url.find_last_of("/"); + auto queryPos = url.find("?", lastSlashPos); + if (queryPos == std::string::npos) { + return url.substr(lastSlashPos + 1); + } else { + return url.substr(lastSlashPos + 1, queryPos - lastSlashPos - 1); + } +} + +void download_model(const ModelInfo &model_info, download_finished_callback_t finished_callback, + download_progress_callback_t progress_callback, + download_error_callback_t error_callback) +{ + const std::filesystem::path module_config_models_folder = get_models_folder(); + + const std::string model_local_config_path = + (module_config_models_folder / model_info.local_folder_name).string(); + + Logger::log(Logger::Level::INFO, "Model save path: %s", model_local_config_path.c_str()); + + if (!std::filesystem::exists(model_local_config_path)) { + // model folder does not exist, create it + if (!std::filesystem::create_directories(model_local_config_path)) { + Logger::log(Logger::Level::ERROR_LOG, "Failed to create model folder: %s", + model_local_config_path.c_str()); + error_callback(DownloadError::DOWNLOAD_ERROR_FILE, + "Failed to create model folder."); + return; + } + } + + CURL *curl = curl_easy_init(); + if (curl) { + for (auto &model_download_file : model_info.files) { + Logger::log(Logger::Level::INFO, "Model URL: %s", + model_download_file.url.c_str()); + + const std::string 
model_filename = + get_filename_from_url(model_download_file.url); + const std::string model_file_save_path = + (std::filesystem::path(model_local_config_path) / model_filename) + .string(); + if (std::filesystem::exists(model_file_save_path)) { + Logger::log(Logger::Level::INFO, "Model file already exists: %s", + model_file_save_path.c_str()); + continue; + } + + FILE *fp = fopen(model_file_save_path.c_str(), "wb"); + if (fp == nullptr) { + Logger::log(Logger::Level::ERROR_LOG, + "Failed to open model file for writing %s.", + model_file_save_path.c_str()); + error_callback(DownloadError::DOWNLOAD_ERROR_FILE, + "Failed to open file."); + return; + } + curl_easy_setopt(curl, CURLOPT_URL, model_download_file.url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, + [progress_callback](void *clientp, curl_off_t dltotal, + curl_off_t dlnow, curl_off_t, + curl_off_t) { + if (dltotal == 0) { + return 0; // Unknown progress + } + int progress = (int)(dlnow * 100l / dltotal); + progress_callback(progress); + return 0; + }); + // curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &progress_callback); + // Follow redirects + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + Logger::log(Logger::Level::ERROR_LOG, + "Failed to download model file %s.", + model_filename.c_str()); + error_callback(DownloadError::DOWNLOAD_ERROR_NETWORK, + "Failed to download model file."); + } + fclose(fp); + } + curl_easy_cleanup(curl); + finished_callback(DownloadStatus::DOWNLOAD_STATUS_OK, model_local_config_path); + } else { + Logger::log(Logger::Level::ERROR_LOG, "Failed to initialize curl."); + error_callback(DownloadError::DOWNLOAD_ERROR_NETWORK, "Failed to initialize curl."); + finished_callback(DownloadStatus::DOWNLOAD_STATUS_ERROR, ""); + } } diff --git 
a/src/modules/core/src/model-find-utils.cpp b/src/modules/core/src/model-find-utils.cpp index 7ee93e8..b1f932b 100644 --- a/src/modules/core/src/model-find-utils.cpp +++ b/src/modules/core/src/model-find-utils.cpp @@ -31,18 +31,19 @@ std::string find_file_in_folder_by_regex_expression(const std::string &folder_pa return ""; } -std::string find_bin_file_in_folder(const std::string &model_local_folder_path) +std::string find_ext_file_in_folder(const std::string &model_local_folder_path, + const std::string &ext) { // find .bin file in folder for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) { - if (entry.path().extension() == ".bin") { + if (entry.path().extension() == ext) { const std::string bin_file_path = entry.path().string(); Logger::log(Logger::Level::INFO, "Model bin file found in folder: %s", bin_file_path.c_str()); return bin_file_path; } } - Logger::log(Logger::Level::ERROR, "Model bin file not found in folder: %s", + Logger::log(Logger::Level::ERROR_LOG, "Model bin file not found in folder: %s", model_local_folder_path.c_str()); return ""; } diff --git a/src/modules/core/src/model-infos.cpp b/src/modules/core/src/model-infos.cpp index e978002..a438659 100644 --- a/src/modules/core/src/model-infos.cpp +++ b/src/modules/core/src/model-infos.cpp @@ -221,14 +221,20 @@ std::map models_info = {{ "7d99f41a10525d0206bddadd86760181fa920438b6b33237e3118ff6c83bb53d"}}}}, {"Whisper Medium English (1.5Gb)", {"Whisper Medium English", - "ggml-meduim-en", + "ggml-medium-en", MODEL_TYPE_TRANSCRIPTION, {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.en.bin", "cc37e93478338ec7700281a7ac30a10128929eb8f427dda2e865faa8f6da4356"}}}}, {"Whisper Medium (1.5Gb)", {"Whisper Medium", - "ggml-meduim", + "ggml-medium", MODEL_TYPE_TRANSCRIPTION, {{"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin", "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208"}}}}, + {"Silero VAD v5", + {"Silero 
VAD v5", + "silero-vad", + MODEL_TYPE_VAD, + {{"https://github.com/snakers4/silero-vad/raw/master/src/silero_vad/data/silero_vad.onnx", + "2623a2953f6ff3d2c1e61740c6cdb7168133479b267dfef114a4a3cc5bdd788f"}}}}, }}; diff --git a/src/modules/transcription/CMakeLists.txt b/src/modules/transcription/CMakeLists.txt index e936898..5041f5d 100644 --- a/src/modules/transcription/CMakeLists.txt +++ b/src/modules/transcription/CMakeLists.txt @@ -14,7 +14,7 @@ target_include_directories(Transcription $ ) -target_link_libraries(Transcription PUBLIC FFmpeg Whispercpp) +target_link_libraries(Transcription PUBLIC Core FFmpeg Whispercpp Ort) set_target_properties(Transcription PROPERTIES OUTPUT_NAME locaal_transcription diff --git a/src/modules/transcription/include/transcription-context.h b/src/modules/transcription/include/transcription-context.h index 28157ae..97e7ef4 100644 --- a/src/modules/transcription/include/transcription-context.h +++ b/src/modules/transcription/include/transcription-context.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "silero-vad-onnx.h" #include "whisper-processing.h" diff --git a/src/modules/transcription/src/token-buffer-thread.cpp b/src/modules/transcription/src/token-buffer-thread.cpp index 3c27206..ef46227 100644 --- a/src/modules/transcription/src/token-buffer-thread.cpp +++ b/src/modules/transcription/src/token-buffer-thread.cpp @@ -1,6 +1,3 @@ -#include -#include -#include #include "token-buffer-thread.h" #include "whisper-utils.h" diff --git a/src/modules/transcription/src/whisper-model-utils.cpp b/src/modules/transcription/src/whisper-model-utils.cpp index 4f5b8da..16e4dd7 100644 --- a/src/modules/transcription/src/whisper-model-utils.cpp +++ b/src/modules/transcription/src/whisper-model-utils.cpp @@ -5,56 +5,22 @@ #include "whisper-utils.h" #include "whisper-processing.h" #include "logger.h" -#include "model-utils/model-downloader.h" +#include "model-downloader.h" -void update_whisper_model(struct 
transcription_context *gf) +void update_whisper_model(struct transcription_context *gf, const std::string new_model_path) { - if (gf->context == nullptr) { - Logger::log(Logger::Level::ERROR, "obs_source_t context is null"); - return; - } - - obs_data_t *s = obs_source_get_settings(gf->context); - if (s == nullptr) { - Logger::log(Logger::Level::ERROR, "obs_data_t settings is null"); - return; - } - - // Get settings from context - std::string new_model_path = obs_data_get_string(s, "whisper_model_path") != nullptr - ? obs_data_get_string(s, "whisper_model_path") - : ""; - std::string external_model_file_path = - obs_data_get_string(s, "whisper_model_path_external") != nullptr - ? obs_data_get_string(s, "whisper_model_path_external") - : ""; - const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps"); - obs_data_release(s); - // update the whisper model path - const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos; - - if (!is_external_model && new_model_path.empty()) { - Logger::log(Logger::Level::WARNING, "Whisper model path is empty"); - return; - } - if (is_external_model && external_model_file_path.empty()) { - Logger::log(Logger::Level::WARNING, "External model file path is empty"); - return; - } - - char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx"); - if (silero_vad_model_file == nullptr) { - Logger::log(Logger::Level::ERROR, "Cannot find Silero VAD model file"); + const ModelInfo &silero_vad_model_info = models_info["Silero VAD v5"]; + const std::string silero_vad_model_file = + find_model_ext_file(silero_vad_model_info, ".onnx"); + if (silero_vad_model_file.empty()) { + Logger::log(Logger::Level::ERROR_LOG, "Cannot find Silero VAD model file"); return; } Logger::log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file); - std::string silero_vad_model_file_str = std::string(silero_vad_model_file); - bfree(silero_vad_model_file); - if 
(gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path || - is_external_model) { + if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path) { if (gf->whisper_model_path != new_model_path) { // model path changed @@ -65,79 +31,42 @@ void update_whisper_model(struct transcription_context *gf) gf->whisper_model_loaded_new = !gf->whisper_model_path.empty(); } - // check if the new model is external file - if (!is_external_model) { - // new model is not external file - shutdown_whisper_thread(gf); - - if (models_info.count(new_model_path) == 0) { - Logger::log(Logger::Level::WARNING, "Model '%s' does not exist", - new_model_path.c_str()); - return; - } + shutdown_whisper_thread(gf); - const ModelInfo &model_info = models_info[new_model_path]; + if (models_info.count(new_model_path) == 0) { + Logger::log(Logger::Level::WARNING, "Model '%s' does not exist", + new_model_path.c_str()); + return; + } - // check if the model exists, if not, download it - std::string model_file_found = find_model_bin_file(model_info); - if (model_file_found == "") { - Logger::log(Logger::Level::WARNING, "Whisper model does not exist"); - download_model_with_ui_dialog( - model_info, - [gf, new_model_path, silero_vad_model_file_str]( - int download_status, const std::string &path) { - if (download_status == 0) { - Logger::log(Logger::Level::INFO, - "Model download complete"); - gf->whisper_model_path = new_model_path; - start_whisper_thread_with_path( - gf, path, - silero_vad_model_file_str.c_str()); - } else { - Logger::log(Logger::Level::ERROR, - "Model download failed"); - } - }); - } else { - // Model exists, just load it - gf->whisper_model_path = new_model_path; - start_whisper_thread_with_path(gf, model_file_found, - silero_vad_model_file_str.c_str()); - } - } else { - // new model is external file, get file location from file property - if (external_model_file_path.empty()) { - Logger::log(Logger::Level::WARNING, - "External model file path is 
empty"); - } else { - // check if the external model file is not currently loaded - if (gf->whisper_model_file_currently_loaded == - external_model_file_path) { - Logger::log(Logger::Level::INFO, - "External model file is already loaded"); - return; - } else { - shutdown_whisper_thread(gf); + const ModelInfo &model_info = models_info[new_model_path]; + + // check if the model exists, if not, download it + std::string model_file_found = find_model_ext_file(model_info, ".bin"); + if (model_file_found == "") { + Logger::log(Logger::Level::WARNING, "Whisper model does not exist"); + download_model(model_info, [gf, new_model_path, silero_vad_model_file]( + int download_status, + const std::string &path) { + if (download_status == DownloadStatus::DOWNLOAD_STATUS_OK) { + Logger::log(Logger::Level::INFO, "Model download complete"); gf->whisper_model_path = new_model_path; start_whisper_thread_with_path( - gf, external_model_file_path, - silero_vad_model_file_str.c_str()); + gf, path, silero_vad_model_file.c_str()); + } else { + Logger::log(Logger::Level::ERROR_LOG, + "Model download failed"); } - } + }); + } else { + // Model exists, just load it + gf->whisper_model_path = new_model_path; + start_whisper_thread_with_path(gf, model_file_found, + silero_vad_model_file.c_str()); } } else { // model path did not change Logger::log(gf->log_level, "Model path did not change: %s == %s", gf->whisper_model_path.c_str(), new_model_path.c_str()); } - - if (new_dtw_timestamps != gf->enable_token_ts_dtw) { - // dtw_token_timestamps changed - Logger::log(gf->log_level, "dtw_token_timestamps changed from %d to %d", - gf->enable_token_ts_dtw, new_dtw_timestamps); - gf->enable_token_ts_dtw = new_dtw_timestamps; - shutdown_whisper_thread(gf); - start_whisper_thread_with_path(gf, gf->whisper_model_path, - silero_vad_model_file_str.c_str()); - } } diff --git a/src/modules/transcription/src/whisper-processing.cpp b/src/modules/transcription/src/whisper-processing.cpp index 90cf27e..7cb8ee5 
100644 --- a/src/modules/transcription/src/whisper-processing.cpp +++ b/src/modules/transcription/src/whisper-processing.cpp @@ -1,9 +1,7 @@ #include -#include - #include "logger.h" -#include "transcription-filter-data.h" +#include "transcription-context.h" #include "whisper-processing.h" #include "whisper-utils.h" #include "transcription-utils.h" @@ -14,12 +12,15 @@ #include #endif -#include "model-utils/model-find-utils.h" +#include "model-find-utils.h" #include "vad-processing.h" #include #include #include +#include +#include +#include struct whisper_context *init_whisper_context(const std::string &model_path_in, struct transcription_context *gf) @@ -33,10 +34,10 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, Logger::Level::INFO, "Model path is a directory, not a file, looking for .bin file in folder"); // look for .bin file - const std::string model_bin_file = find_bin_file_in_folder(model_path); + const std::string model_bin_file = find_ext_file_in_folder(model_path, ".bin"); if (model_bin_file.empty()) { - Logger::log(Logger::Level::ERROR, "Model bin file not found in folder: %s", - model_path.c_str()); + Logger::log(Logger::Level::ERROR_LOG, + "Model bin file not found in folder: %s", model_path.c_str()); return nullptr; } model_path = model_bin_file; @@ -44,14 +45,9 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, whisper_log_set( [](enum ggml_log_level level, const char *text, void *user_data) { - UNUSED_PARAMETER(level); struct transcription_context *ctx = static_cast(user_data); - // remove trailing newline - char *text_copy = bstrdup(text); - text_copy[strcspn(text_copy, "\n")] = 0; - Logger::log(ctx->log_level, "Whisper: %s", text_copy); - bfree(text_copy); + Logger::log(ctx->log_level, "Whisper: %s", text); }, gf); @@ -94,8 +90,8 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, // Read model into buffer std::ifstream modelFile(model_path_ws, 
std::ios::binary); if (!modelFile.is_open()) { - Logger::log(Logger::Level::ERROR, "Failed to open whisper model file %s", - model_path.c_str()); + Logger::log(Logger::Level::ERROR_LOG, + "Failed to open whisper model file %s", model_path.c_str()); return nullptr; } modelFile.seekg(0, std::ios::end); @@ -112,12 +108,12 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, ctx = whisper_init_from_file_with_params(model_path.c_str(), cparams); #endif } catch (const std::exception &e) { - Logger::log(Logger::Level::ERROR, "Exception while loading whisper model: %s", + Logger::log(Logger::Level::ERROR_LOG, "Exception while loading whisper model: %s", e.what()); return nullptr; } if (ctx == nullptr) { - Logger::log(Logger::Level::ERROR, "Failed to load whisper model"); + Logger::log(Logger::Level::ERROR_LOG, "Failed to load whisper model"); return nullptr; } @@ -132,12 +128,12 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex int vad_state = VAD_STATE_WAS_OFF) { if (gf == nullptr) { - Logger::log(Logger::Level::ERROR, "run_whisper_inference: gf is null"); + Logger::log(Logger::Level::ERROR_LOG, "run_whisper_inference: gf is null"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } if (pcm32f_data_ == nullptr || pcm32f_num_samples == 0) { - Logger::log(Logger::Level::ERROR, + Logger::log(Logger::Level::ERROR_LOG, "run_whisper_inference: pcm32f_data is null or size is 0"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } @@ -167,7 +163,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex "Speech segment is less than 1 second, padding with white noise to 1 second"); const size_t new_size = (size_t)(1.01f * (float)(WHISPER_SAMPLE_RATE)); // create a new buffer and copy the data to it in the middle - pcm32f_data = (float *)bzalloc(new_size * sizeof(float)); + pcm32f_data = (float *)malloc(new_size * sizeof(float)); // add low volume white noise const float 
noise_level = 0.01f; @@ -208,17 +204,17 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params, pcm32f_data, (int)pcm32f_size); } catch (const std::exception &e) { - Logger::log(Logger::Level::ERROR, + Logger::log(Logger::Level::ERROR_LOG, "Whisper exception: %s. Filter restart is required", e.what()); whisper_free(gf->whisper_context); gf->whisper_context = nullptr; if (should_free_buffer) { - bfree(pcm32f_data); + free(pcm32f_data); } return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } if (should_free_buffer) { - bfree(pcm32f_data); + free(pcm32f_data); } std::string language = gf->whisper_params.language; @@ -321,10 +317,10 @@ void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offse { // get the data from the entire whisper buffer // add 50ms of silence to the beginning and end of the buffer - const size_t pcm32f_size = gf->whisper_buffer.size / sizeof(float); + const size_t pcm32f_size = gf->whisper_buffer.size(); const size_t pcm32f_size_with_silence = pcm32f_size + 2 * WHISPER_SAMPLE_RATE / 100; // allocate a new buffer and copy the data to it - float *pcm32f_data = (float *)bzalloc(pcm32f_size_with_silence * sizeof(float)); + float *pcm32f_data = (float *)malloc(pcm32f_size_with_silence * sizeof(float)); if (vad_state == VAD_STATE_PARTIAL) { // peek instead of pop, since this is a partial run that keeps the data in the buffer circlebuf_peek_front(&gf->whisper_buffer, pcm32f_data + WHISPER_SAMPLE_RATE / 100, @@ -346,13 +342,13 @@ void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offse } // free the buffer - bfree(pcm32f_data); + free(pcm32f_data); } void whisper_loop(void *data) { if (data == nullptr) { - Logger::log(Logger::Level::ERROR, "whisper_loop: data is null"); + Logger::log(Logger::Level::ERROR_LOG, "whisper_loop: data is null"); return; } @@ -362,16 +358,10 @@ void whisper_loop(void *data) 
vad_state current_vad_state = {false, now_ms(), 0, 0}; - const char *whisper_loop_name = "Whisper loop"; - profile_register_root(whisper_loop_name, 50 * 1000 * 1000); - // Thread main loop while (true) { - ProfileScope(whisper_loop_name); { - ProfileScope("lock whisper ctx"); std::lock_guard lock(gf->whisper_ctx_mutex); - ProfileScope("locked whisper ctx"); if (gf->whisper_context == nullptr) { Logger::log(Logger::Level::WARNING, "Whisper context is null, exiting thread"); @@ -404,7 +394,7 @@ void whisper_loop(void *data) // This will wake up the thread if there is new data in the input buffer // or if the whisper context is null std::unique_lock lock(gf->whisper_ctx_mutex); - if (gf->input_buffers->size == 0) { + if (gf->input_buffers[0].empty()) { gf->wshiper_thread_cv.wait_for(lock, std::chrono::milliseconds(50)); } } diff --git a/src/modules/transcription/src/whisper-utils.cpp b/src/modules/transcription/src/whisper-utils.cpp index 9e787d0..d6d7843 100644 --- a/src/modules/transcription/src/whisper-utils.cpp +++ b/src/modules/transcription/src/whisper-utils.cpp @@ -30,7 +30,7 @@ void start_whisper_thread_with_path(struct transcription_context *gf, whisper_model_path.c_str(), silero_vad_model_file); std::lock_guard lock(gf->whisper_ctx_mutex); if (gf->whisper_context != nullptr) { - Logger::log(Logger::Level::ERROR, + Logger::log(Logger::Level::ERROR_LOG, "cannot init whisper: whisper_context is not null"); return; } @@ -41,7 +41,7 @@ void start_whisper_thread_with_path(struct transcription_context *gf, Logger::log(gf->log_level, "Create whisper context"); gf->whisper_context = init_whisper_context(whisper_model_path, gf); if (gf->whisper_context == nullptr) { - Logger::log(Logger::Level::ERROR, "Failed to initialize whisper context"); + Logger::log(Logger::Level::ERROR_LOG, "Failed to initialize whisper context"); return; } gf->whisper_model_file_currently_loaded = whisper_model_path; diff --git a/src/modules/translation/CMakeLists.txt 
b/src/modules/translation/CMakeLists.txt index 2055fad..bca362b 100644 --- a/src/modules/translation/CMakeLists.txt +++ b/src/modules/translation/CMakeLists.txt @@ -11,7 +11,7 @@ target_include_directories(Translation $ ) -target_link_libraries(Translation INTERFACE ICU ct2 sentencepiece) +target_link_libraries(Translation INTERFACE Core ICU ct2 sentencepiece) set_target_properties(Translation PROPERTIES OUTPUT_NAME locaal_translation diff --git a/src/modules/translation/include/translation-utils.h b/src/modules/translation/include/translation-utils.h index a2f71d9..ac98729 100644 --- a/src/modules/translation/include/translation-utils.h +++ b/src/modules/translation/include/translation-utils.h @@ -1,7 +1,7 @@ #ifndef TRANSLATION_UTILS_H #define TRANSLATION_UTILS_H -#include "transcription-filter-data.h" +#include "transcription-context.h" void start_translation(struct transcription_context *gf); diff --git a/src/modules/translation/src/translation-utils.cpp b/src/modules/translation/src/translation-utils.cpp index ffef31d..6f412be 100644 --- a/src/modules/translation/src/translation-utils.cpp +++ b/src/modules/translation/src/translation-utils.cpp @@ -4,7 +4,7 @@ #include "translation.h" #include "translation-utils.h" #include "logger.h" -#include "model-utils/model-downloader.h" +#include "model-downloader.h" void start_translation(struct transcription_context *gf) { @@ -13,7 +13,7 @@ void start_translation(struct transcription_context *gf) if (gf->translation_model_index == "!!!external!!!") { Logger::log(Logger::Level::INFO, "External model selected."); if (gf->translation_model_path_external.empty()) { - Logger::log(Logger::Level::ERROR, "External model path is empty."); + Logger::log(Logger::Level::ERROR_LOG, "External model path is empty."); gf->translate = false; return; } @@ -35,7 +35,8 @@ void start_translation(struct transcription_context *gf) "CT2 model download complete"); build_and_enable_translation(gf, path); } else { - 
Logger::log(Logger::Level::ERROR, "Model download failed"); + Logger::log(Logger::Level::ERROR_LOG, + "Model download failed"); gf->translate = false; } }); diff --git a/src/modules/translation/src/translation.cpp b/src/modules/translation/src/translation.cpp index f1394bc..7184c00 100644 --- a/src/modules/translation/src/translation.cpp +++ b/src/modules/translation/src/translation.cpp @@ -20,7 +20,7 @@ void build_and_enable_translation(struct transcription_context *gf, Logger::log(Logger::Level::INFO, "Enable translation"); gf->translate = true; } else { - Logger::log(Logger::Level::ERROR, "Failed to load CT2 model"); + Logger::log(Logger::Level::ERROR_LOG, "Failed to load CT2 model"); gf->translate = false; } } @@ -41,7 +41,7 @@ int build_translation_context(struct translation_context &translation_ctx) translation_ctx.processor.reset(new sentencepiece::SentencePieceProcessor()); const auto status = translation_ctx.processor->Load(local_spm_path); if (!status.ok()) { - Logger::log(Logger::Level::ERROR, "Failed to load SPM: %s", + Logger::log(Logger::Level::ERROR_LOG, "Failed to load SPM: %s", status.ToString().c_str()); return LOCAAL_TRANSLATION_INIT_FAIL; } @@ -54,7 +54,8 @@ int build_translation_context(struct translation_context &translation_ctx) const auto target_status = translation_ctx.target_processor->Load(target_spm_path); if (!target_status.ok()) { - Logger::log(Logger::Level::ERROR, "Failed to load target SPM: %s", + Logger::log(Logger::Level::ERROR_LOG, + "Failed to load target SPM: %s", target_status.ToString().c_str()); return LOCAAL_TRANSLATION_INIT_FAIL; } @@ -103,7 +104,7 @@ int build_translation_context(struct translation_context &translation_ctx) translation_ctx.options->max_input_length = 64; translation_ctx.options->sampling_temperature = 0.1f; } catch (std::exception &e) { - Logger::log(Logger::Level::ERROR, "Failed to load CT2 model: %s", e.what()); + Logger::log(Logger::Level::ERROR_LOG, "Failed to load CT2 model: %s", e.what()); return 
LOCAAL_TRANSLATION_INIT_FAIL; } return LOCAAL_TRANSLATION_INIT_SUCCESS; @@ -212,7 +213,7 @@ int translate(struct translation_context &translation_ctx, const std::string &te const std::string result_ = translation_ctx.detokenizer(translation_tokens); result = remove_start_punctuation(result_); } catch (std::exception &e) { - Logger::log(Logger::Level::ERROR, "Error: %s", e.what()); + Logger::log(Logger::Level::ERROR_LOG, "Error: %s", e.what()); return LOCAAL_TRANSLATION_FAIL; } return LOCAAL_TRANSLATION_SUCCESS; From 394dea6389e792497eda3e6376f7a8146a381948 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 11:03:05 -0400 Subject: [PATCH 06/12] Refactor build and dependency management --- src/modules/core/include/logger.h | 3 + src/modules/core/src/logger.cpp | 11 ++ src/modules/transcription/CMakeLists.txt | 1 + .../transcription/include/audio-resampler.h | 32 +++++ .../include/transcription-context.h | 12 +- .../transcription/src/audio-resampler.cpp | 112 +++++++++++++++ .../transcription/src/vad-processing.cpp | 133 +++++++++--------- .../transcription/src/whisper-model-utils.cpp | 36 +++-- 8 files changed, 251 insertions(+), 89 deletions(-) create mode 100644 src/modules/transcription/include/audio-resampler.h create mode 100644 src/modules/transcription/src/audio-resampler.cpp diff --git a/src/modules/core/include/logger.h b/src/modules/core/include/logger.h index 593b5f0..2ab8703 100644 --- a/src/modules/core/include/logger.h +++ b/src/modules/core/include/logger.h @@ -11,10 +11,13 @@ class Logger { using LogCallback = std::function; static void setLogCallback(LogCallback callback); + static void setLogLevel(Level level); static void Logger::log(Level level, const std::string &format, ...); + // set log level private: static LogCallback s_logCallback; + static Level s_logLevel; static std::string getLevelString(Level level); }; diff --git a/src/modules/core/src/logger.cpp b/src/modules/core/src/logger.cpp index 7da6f7f..eb2276d 100644 --- 
a/src/modules/core/src/logger.cpp +++ b/src/modules/core/src/logger.cpp @@ -5,14 +5,25 @@ #include Logger::LogCallback Logger::s_logCallback = nullptr; +Logger::Level Logger::s_logLevel = Logger::Level::INFO; void Logger::setLogCallback(LogCallback callback) { s_logCallback = callback; } +void Logger::setLogLevel(Logger::Level level) +{ + // set log level + s_logLevel = level; +} + void Logger::log(Level level, const std::string &format, ...) { + if (level < s_logLevel) { + return; + } + // Default logging behavior va_list args; va_start(args, format); diff --git a/src/modules/transcription/CMakeLists.txt b/src/modules/transcription/CMakeLists.txt index 5041f5d..8487a48 100644 --- a/src/modules/transcription/CMakeLists.txt +++ b/src/modules/transcription/CMakeLists.txt @@ -1,4 +1,5 @@ add_library(Transcription + src/audio-resampler.cpp src/silero-vad-onnx.cpp src/token-buffer-thread.cpp src/transcription-utils.cpp diff --git a/src/modules/transcription/include/audio-resampler.h b/src/modules/transcription/include/audio-resampler.h new file mode 100644 index 0000000..914d2d1 --- /dev/null +++ b/src/modules/transcription/include/audio-resampler.h @@ -0,0 +1,32 @@ +#ifndef AUDIO_RESAMPLER_H +#define AUDIO_RESAMPLER_H + +#include + +extern "C" { +#include +} + +// Forward declarations +struct SwrContext; +struct AVChannelLayout; + +class AudioResampler { +public: + AudioResampler(); + ~AudioResampler(); + + void configure(int in_channels, int in_sample_rate, int out_channels, int out_sample_rate); + std::vector> resample(const std::vector> &input); + +private: + SwrContext *swr_ctx; + AVChannelLayout in_ch_layout; + AVChannelLayout out_ch_layout; + int in_sample_rate; + int out_sample_rate; + int in_channels; + int out_channels; +}; + +#endif // AUDIO_RESAMPLER_H diff --git a/src/modules/transcription/include/transcription-context.h b/src/modules/transcription/include/transcription-context.h index 97e7ef4..0db3b0b 100644 --- 
a/src/modules/transcription/include/transcription-context.h +++ b/src/modules/transcription/include/transcription-context.h @@ -16,6 +16,7 @@ #include "whisper-processing.h" #include "token-buffer-thread.h" #include "logger.h" +#include "audio-resampler.h" #define MAX_PREPROC_CHANNELS 10 @@ -45,13 +46,13 @@ struct transcription_context { bool cleared_last_sub; /* PCM buffers */ - float *copy_buffers[MAX_PREPROC_CHANNELS]; + std::vector> copy_buffers; std::deque info_buffer; - std::deque input_buffers[MAX_PREPROC_CHANNELS]; + std::vector> input_buffers; std::deque whisper_buffer; /* Resampler */ - audio_resampler_t *resampler_to_whisper; + AudioResampler resampler_to_whisper; std::deque resampled_buffer; /* whisper */ @@ -128,10 +129,9 @@ struct transcription_context { transcription_context() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv() { // initialize all pointers to nullptr - for (size_t i = 0; i < MAX_PREPROC_CHANNELS; i++) { - copy_buffers[i] = nullptr; + for (size_t i = 0; i < copy_buffers.size(); i++) { + copy_buffers[i].clear(); } - resampler_to_whisper = nullptr; whisper_model_path = ""; whisper_context = nullptr; output_file_path = ""; diff --git a/src/modules/transcription/src/audio-resampler.cpp b/src/modules/transcription/src/audio-resampler.cpp new file mode 100644 index 0000000..75fd8d0 --- /dev/null +++ b/src/modules/transcription/src/audio-resampler.cpp @@ -0,0 +1,112 @@ +#include "audio-resampler.h" + +#include +#include + +extern "C" { +#include +#include +#include +} + +AudioResampler::AudioResampler() + : swr_ctx(nullptr), + in_sample_rate(0), + out_sample_rate(0), + in_channels(0), + out_channels(0) +{ + av_channel_layout_default(&in_ch_layout, 0); + av_channel_layout_default(&out_ch_layout, 0); +} + +AudioResampler::~AudioResampler() +{ + if (swr_ctx) { + swr_free(&swr_ctx); + } + av_channel_layout_uninit(&in_ch_layout); + av_channel_layout_uninit(&out_ch_layout); +} + +void AudioResampler::configure(int in_channels, int 
in_sample_rate, int out_channels, + int out_sample_rate) +{ + if (swr_ctx) { + swr_free(&swr_ctx); + } + + av_channel_layout_uninit(&in_ch_layout); + av_channel_layout_uninit(&out_ch_layout); + av_channel_layout_default(&in_ch_layout, in_channels); + av_channel_layout_default(&out_ch_layout, out_channels); + + this->in_sample_rate = in_sample_rate; + this->out_sample_rate = out_sample_rate; + this->in_channels = in_channels; + this->out_channels = out_channels; + + swr_ctx = swr_alloc(); + if (!swr_ctx) { + throw std::runtime_error("Could not allocate resampler context"); + } + + av_opt_set_chlayout(swr_ctx, "in_chlayout", &in_ch_layout, 0); + av_opt_set_chlayout(swr_ctx, "out_chlayout", &out_ch_layout, 0); + av_opt_set_int(swr_ctx, "in_sample_rate", in_sample_rate, 0); + av_opt_set_int(swr_ctx, "out_sample_rate", out_sample_rate, 0); + av_opt_set_sample_fmt(swr_ctx, "in_sample_fmt", AV_SAMPLE_FMT_FLTP, 0); + av_opt_set_sample_fmt(swr_ctx, "out_sample_fmt", AV_SAMPLE_FMT_FLTP, 0); + + if (swr_init(swr_ctx) < 0) { + throw std::runtime_error("Failed to initialize the resampling context"); + } +} + +std::vector> +AudioResampler::resample(const std::vector> &input) +{ + if (!swr_ctx) { + throw std::runtime_error("Resampler not configured"); + } + + if (input.size() != in_channels) { + throw std::runtime_error("Input channel count doesn't match configuration"); + } + + int in_samples = (int)input[0].size(); + + // Prepare input data + std::vector in_data(in_channels); + for (int i = 0; i < in_channels; ++i) { + in_data[i] = input[i].data(); + } + + // Calculate output size + int64_t delay = swr_get_delay(swr_ctx, in_sample_rate); + int out_samples = (int)av_rescale_rnd(delay + in_samples, out_sample_rate, in_sample_rate, + AV_ROUND_UP); + + // Prepare output buffer + std::vector> output(out_channels, std::vector(out_samples)); + std::vector out_data(out_channels); + for (int i = 0; i < out_channels; ++i) { + out_data[i] = output[i].data(); + } + + // Perform resampling 
+ int samples_out = + swr_convert(swr_ctx, reinterpret_cast(out_data.data()), out_samples, + reinterpret_cast(in_data.data()), in_samples); + + if (samples_out < 0) { + throw std::runtime_error("Error while converting"); + } + + // Resize output to actual number of samples + for (auto &channel : output) { + channel.resize(samples_out); + } + + return output; +} diff --git a/src/modules/transcription/src/vad-processing.cpp b/src/modules/transcription/src/vad-processing.cpp index 29c37c3..cbb6780 100644 --- a/src/modules/transcription/src/vad-processing.cpp +++ b/src/modules/transcription/src/vad-processing.cpp @@ -1,7 +1,5 @@ -#include - -#include "transcription-filter-data.h" +#include "transcription-context.h" #include "vad-processing.h" @@ -10,8 +8,7 @@ #include #endif -int get_data_from_buf_and_resample(transcription_filter_data *gf, - uint64_t &start_timestamp_offset_ns, +int get_data_from_buf_and_resample(transcription_context *gf, uint64_t &start_timestamp_offset_ns, uint64_t &end_timestamp_offset_ns) { uint32_t num_frames_from_infos = 0; @@ -20,13 +17,13 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, // scoped lock the buffer mutex std::lock_guard lock(gf->whisper_buf_mutex); - if (gf->input_buffers[0].size == 0) { + if (gf->input_buffers[0].empty() || gf->info_buffer.empty()) { return 1; } Logger::log(gf->log_level, "segmentation: currently %lu bytes in the audio input buffer", - gf->input_buffers[0].size); + gf->input_buffers[0].size()); // max number of frames is 10 seconds worth of audio const size_t max_num_frames = gf->sample_rate * 10; @@ -34,20 +31,20 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, // pop all infos from the info buffer and mark the beginning timestamp from the first // info as the beginning timestamp of the segment struct transcription_filter_audio_info info_from_buf = {0}; - const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); - while (gf->info_buffer.size >= 
size_of_audio_info) { - circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); + while (gf->info_buffer.size() > 0) { + info_from_buf = gf->info_buffer.front(); num_frames_from_infos += info_from_buf.frames; if (start_timestamp_offset_ns == 0) { start_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; } // Check if we're within the needed segment length if (num_frames_from_infos > max_num_frames) { - // too big, push the last info into the buffer's front where it was + // too big, keep the last info where it was num_frames_from_infos -= info_from_buf.frames; - circlebuf_push_front(&gf->info_buffer, &info_from_buf, - size_of_audio_info); break; + } else { + // pop the info from the buffer + gf->info_buffer.pop_front(); } } // calculate the end timestamp from the info plus the number of frames in the packet @@ -65,14 +62,19 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, for (size_t c = 0; c < gf->channels; c++) { // zero the rest of copy_buffers - memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); + gf->copy_buffers[c].resize(num_frames_from_infos, 0.0f); } /* Pop from input circlebuf */ for (size_t c = 0; c < gf->channels; c++) { - // Push the new data to copy_buffers[c] - circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c], - num_frames_from_infos * sizeof(float)); + // Pop num_frames_from_infos samples from the input_buffers[c] into copy_buffers[c] + // and then remove the samples from the input_buffers[c] + std::copy(gf->input_buffers[c].begin(), + gf->input_buffers[c].begin() + num_frames_from_infos, + gf->copy_buffers[c].begin()); + gf->input_buffers[c].erase(gf->input_buffers[c].begin(), + gf->input_buffers[c].begin() + + num_frames_from_infos); } } @@ -81,31 +83,25 @@ int get_data_from_buf_and_resample(transcription_filter_data *gf, { // resample to 16kHz - float *resampled_16khz[MAX_PREPROC_CHANNELS]; - uint32_t resampled_16khz_frames; - uint64_t ts_offset; - { - ProfileScope("resample"); - 
audio_resampler_resample(gf->resampler_to_whisper, - (uint8_t **)resampled_16khz, - &resampled_16khz_frames, &ts_offset, - (const uint8_t **)gf->copy_buffers, - (uint32_t)num_frames_from_infos); - } + std::vector> resampled_16khz = + gf->resampler_to_whisper.resample(gf->copy_buffers); + + // push all of data from resampled_16khz[0] to the back of gf->resampled_buffer + const size_t resampled_16khz_frames = resampled_16khz[0].size(); + gf->resampled_buffer.insert(gf->resampled_buffer.end(), resampled_16khz[0].begin(), + resampled_16khz[0].end()); - circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], - resampled_16khz_frames * sizeof(float)); Logger::log(gf->log_level, "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", (int)gf->channels, (int)resampled_16khz_frames, (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, - gf->resampled_buffer.size); + gf->resampled_buffer.size()); } return 0; } -vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +vad_state vad_based_segmentation(transcription_context *gf, vad_state last_vad_state) { // get data from buffer and resample uint64_t start_timestamp_offset_ns = 0; @@ -117,24 +113,27 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v return last_vad_state; } - const size_t vad_window_size_samples = gf->vad->get_window_size_samples() * sizeof(float); + const size_t vad_window_size_samples = gf->vad->get_window_size_samples(); const size_t min_vad_buffer_size = vad_window_size_samples * 8; - if (gf->resampled_buffer.size < min_vad_buffer_size) + if (gf->resampled_buffer.size() < min_vad_buffer_size) return last_vad_state; - size_t vad_num_windows = gf->resampled_buffer.size / vad_window_size_samples; + size_t vad_num_windows = gf->resampled_buffer.size(); std::vector vad_input; vad_input.resize(vad_num_windows * gf->vad->get_window_size_samples()); - circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), - 
vad_input.size() * sizeof(float)); + // pop the data from the resampled buffer + std::copy(gf->resampled_buffer.begin(), gf->resampled_buffer.begin() + vad_input.size(), + vad_input.begin()); + gf->resampled_buffer.erase(gf->resampled_buffer.begin(), + gf->resampled_buffer.begin() + vad_input.size()); + + // send the data to the VAD Logger::log(gf->log_level, "sending %d frames to vad, %d windows, reset state? %s", vad_input.size(), vad_num_windows, (!last_vad_state.vad_on) ? "yes" : "no"); - { - ProfileScope("vad->process"); - gf->vad->process(vad_input, !last_vad_state.vad_on); - } + + gf->vad->process(vad_input, !last_vad_state.vad_on); const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; @@ -188,16 +187,15 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v const int number_of_frames = end_frame - start_frame; // push the data into gf-whisper_buffer - circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, - number_of_frames * sizeof(float)); + gf->whisper_buffer.insert(gf->whisper_buffer.end(), vad_input.begin() + end_frame, + vad_input.end()); Logger::log( gf->log_level, - "VAD segment %d/%d. pushed %d to %d (%d frames / %lu ms). current size: %lu bytes / %lu frames / %lu ms", + "VAD segment %d/%d. pushed %d to %d (%d frames / %lu ms). 
current size: %lu frames / %lu ms", i, (stamps.size() - 1), start_frame, end_frame, number_of_frames, - number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size, - gf->whisper_buffer.size / sizeof(float), - gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE); + number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size(), + gf->whisper_buffer.size() * 1000 / WHISPER_SAMPLE_RATE); // segment "end" is in the middle of the buffer, send it to inference if (stamps[i].end < (int)vad_input.size()) { @@ -273,7 +271,7 @@ vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_v return current_vad_state; } -vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +vad_state hybrid_vad_segmentation(transcription_context *gf, vad_state last_vad_state) { // get data from buffer and resample uint64_t start_timestamp_offset_ns = 0; @@ -286,15 +284,12 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ last_vad_state.end_ts_offset_ms = end_timestamp_offset_ns / 1000000; - // extract the data from the resampled buffer with circlebuf_pop_front into a temp buffer - // and then push it into the whisper buffer - const size_t resampled_buffer_size = gf->resampled_buffer.size; - std::vector temp_buffer; - temp_buffer.resize(resampled_buffer_size); - circlebuf_pop_front(&gf->resampled_buffer, temp_buffer.data(), resampled_buffer_size); - circlebuf_push_back(&gf->whisper_buffer, temp_buffer.data(), resampled_buffer_size); + // extract the data from the resampled buffer and push it into the whisper buffer + gf->whisper_buffer.insert(gf->whisper_buffer.end(), gf->resampled_buffer.begin(), + gf->resampled_buffer.end()); + gf->resampled_buffer.clear(); - Logger::log(gf->log_level, "whisper buffer size: %lu bytes", gf->whisper_buffer.size); + Logger::log(gf->log_level, "whisper buffer size: %lu frames", gf->whisper_buffer.size()); // use last_vad_state timestamps 
to calculate the duration of the current segment if (last_vad_state.end_ts_offset_ms - last_vad_state.start_ts_offest_ms >= @@ -328,17 +323,15 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ // run vad on the current buffer std::vector vad_input; - vad_input.resize(gf->whisper_buffer.size / sizeof(float)); - circlebuf_peek_front(&gf->whisper_buffer, vad_input.data(), - vad_input.size() * sizeof(float)); + vad_input.resize(gf->whisper_buffer.size()); + std::copy(gf->whisper_buffer.begin(), gf->whisper_buffer.end(), + vad_input.begin()); Logger::log(gf->log_level, "sending %d frames to vad, %.1f ms", vad_input.size(), (float)vad_input.size() * 1000.0f / (float)WHISPER_SAMPLE_RATE); - { - ProfileScope("vad->process"); - gf->vad->process(vad_input, true); - } + + gf->vad->process(vad_input, true); if (gf->vad->get_speech_timestamps().size() > 0) { // VAD detected speech in the partial segment @@ -350,10 +343,10 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ Logger::log(gf->log_level, "VAD detected silence in partial segment"); // pop the partial segment from the whisper buffer, save some audio for the next segment - const size_t num_bytes_to_keep = - (WHISPER_SAMPLE_RATE / 4) * sizeof(float); - circlebuf_pop_front(&gf->whisper_buffer, nullptr, - gf->whisper_buffer.size - num_bytes_to_keep); + const size_t num_frames_to_keep = (WHISPER_SAMPLE_RATE / 4); + gf->whisper_buffer.erase(gf->whisper_buffer.begin(), + gf->whisper_buffer.begin() + + num_frames_to_keep); } } } @@ -361,15 +354,15 @@ vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_ return last_vad_state; } -void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file) +void initialize_vad(transcription_context *gf, const char *silero_vad_model_file) { // initialize Silero VAD #ifdef _WIN32 // convert mbstring to wstring int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, - 
strlen(silero_vad_model_file), NULL, 0); + (int)strlen(silero_vad_model_file), NULL, 0); std::wstring silero_vad_model_path(count, 0); - MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), + MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, (int)strlen(silero_vad_model_file), &silero_vad_model_path[0], count); Logger::log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); #else diff --git a/src/modules/transcription/src/whisper-model-utils.cpp b/src/modules/transcription/src/whisper-model-utils.cpp index 16e4dd7..4890a0d 100644 --- a/src/modules/transcription/src/whisper-model-utils.cpp +++ b/src/modules/transcription/src/whisper-model-utils.cpp @@ -45,19 +45,29 @@ void update_whisper_model(struct transcription_context *gf, const std::string ne std::string model_file_found = find_model_ext_file(model_info, ".bin"); if (model_file_found == "") { Logger::log(Logger::Level::WARNING, "Whisper model does not exist"); - download_model(model_info, [gf, new_model_path, silero_vad_model_file]( - int download_status, - const std::string &path) { - if (download_status == DownloadStatus::DOWNLOAD_STATUS_OK) { - Logger::log(Logger::Level::INFO, "Model download complete"); - gf->whisper_model_path = new_model_path; - start_whisper_thread_with_path( - gf, path, silero_vad_model_file.c_str()); - } else { - Logger::log(Logger::Level::ERROR_LOG, - "Model download failed"); - } - }); + download_model( + model_info, + [gf, new_model_path, silero_vad_model_file]( + int download_status, const std::string &path) { + if (download_status == DownloadStatus::DOWNLOAD_STATUS_OK) { + Logger::log(Logger::Level::INFO, + "Model download complete"); + gf->whisper_model_path = new_model_path; + start_whisper_thread_with_path( + gf, path, silero_vad_model_file.c_str()); + } else { + Logger::log(Logger::Level::ERROR_LOG, + "Model download failed"); + } + }, + [](int progress) { + Logger::log(Logger::Level::INFO, "Download progress: %d%%", + 
progress); + }, + [](int error_code, const std::string &error) { + Logger::log(Logger::Level::ERROR_LOG, "Download error: %s", + error.c_str()); + }); } else { // Model exists, just load it gf->whisper_model_path = new_model_path; From 2adef33f8549cdd25b0c87c3a93b0cc70387a2d4 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 15:10:06 -0400 Subject: [PATCH 07/12] Refactor build and dependency management --- cmake/BuildICU.cmake | 4 ++ scripts/build-windows.ps1 | 35 +++++++++++- .../include/transcription-context.h | 2 +- .../include/transcription-utils.h | 11 +++- .../transcription/src/silero-vad-onnx.cpp | 2 +- .../transcription/src/vad-processing.cpp | 3 +- .../transcription/src/whisper-processing.cpp | 57 +++++++------------ .../transcription/src/whisper-utils.cpp | 6 +- src/modules/translation/CMakeLists.txt | 2 +- .../translation/include/translation-utils.h | 4 +- src/modules/translation/include/translation.h | 6 +- .../translation/src/translation-utils.cpp | 33 +++++------ src/modules/translation/src/translation.cpp | 15 +++-- 13 files changed, 101 insertions(+), 79 deletions(-) diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake index 74ca70c..3160bc8 100644 --- a/cmake/BuildICU.cmake +++ b/cmake/BuildICU.cmake @@ -101,6 +101,10 @@ foreach(lib ${ICU_LIBRARIES}) endforeach() target_include_directories(ICU INTERFACE $ $) +set_target_properties(ICU PROPERTIES EXPORT_NAME ICU) + +# Export the targets file +install(DIRECTORY ${ICU_INCLUDE_DIR}/unicode DESTINATION include) # add exported target install(TARGETS ICU EXPORT ICUTargets) diff --git a/scripts/build-windows.ps1 b/scripts/build-windows.ps1 index d1717e1..71c093d 100644 --- a/scripts/build-windows.ps1 +++ b/scripts/build-windows.ps1 @@ -1,4 +1,35 @@ +param( + [switch]$Verbose, + [switch]$Clean +) -cmake -S . 
-B build_x64 -DCMAKE_BUILD_TYPE=Release -DLocaalSDK_FIND_COMPONENTS="Core;Transcription;Translation" +$verboseFlag = "" +$verboseBuildFlag = "" -cmake --build build_x64 --config Release +if ($Verbose) { + $verboseFlag = "-DCMAKE_VERBOSE_MAKEFILE=ON" + $verboseBuildFlag = "--verbose" +} + +$buildDir = "build_x64" + +# Clean build directory if requested +if ($Clean) { + if (Test-Path $buildDir) { + Write-Host "Cleaning build directory: $buildDir" + Remove-Item -Recurse -Force $buildDir + } + else { + Write-Host "Build directory does not exist. Nothing to clean." + } +} + +# Configure step +$configureCommand = "cmake -S . -B $buildDir -DCMAKE_BUILD_TYPE=Release -DLocaalSDK_FIND_COMPONENTS=`"Core;Transcription;Translation`" $verboseFlag" +Write-Host "Executing configure command: $configureCommand" +Invoke-Expression $configureCommand + +# Build step +$buildCommand = "cmake --build $buildDir --config Release $verboseBuildFlag" +Write-Host "Executing build command: $buildCommand" +Invoke-Expression $buildCommand diff --git a/src/modules/transcription/include/transcription-context.h b/src/modules/transcription/include/transcription-context.h index 0db3b0b..3b5169b 100644 --- a/src/modules/transcription/include/transcription-context.h +++ b/src/modules/transcription/include/transcription-context.h @@ -145,7 +145,7 @@ void clear_current_caption(transcription_context *gf_); // Callback sent when the VAD finds an audio chunk. 
Sample rate = WHISPER_SAMPLE_RATE, channels = 1 // The audio chunk is in 32-bit float format -void audio_chunk_callback(struct transcription_context *gf, const float *pcm32f_data, size_t frames, +void audio_chunk_callback(struct transcription_context *gf, const std::vector pcm32f_data, int vad_state, const DetectionResultWithText &result); #endif /* TRANSCRIPTION_CONTEXT_H */ diff --git a/src/modules/transcription/include/transcription-utils.h b/src/modules/transcription/include/transcription-utils.h index 5fdd0cf..cc26736 100644 --- a/src/modules/transcription/include/transcription-utils.h +++ b/src/modules/transcription/include/transcription-utils.h @@ -38,12 +38,17 @@ std::vector split_words(const std::string &str_copy); // trim (strip) string from leading and trailing whitespaces template StringLike trim(const StringLike &str) { + using CharType = typename StringLike::value_type; + StringLike str_copy = str; str_copy.erase(str_copy.begin(), - std::find_if(str_copy.begin(), str_copy.end(), - [](unsigned char ch) { return !std::isspace(ch); })); + std::find_if(str_copy.begin(), str_copy.end(), [](CharType ch) { + return !std::isspace(static_cast(ch)); + })); str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(), - [](unsigned char ch) { return !std::isspace(ch); }) + [](CharType ch) { + return !std::isspace(static_cast(ch)); + }) .base(), str_copy.end()); return str_copy; diff --git a/src/modules/transcription/src/silero-vad-onnx.cpp b/src/modules/transcription/src/silero-vad-onnx.cpp index 41c2293..299a685 100644 --- a/src/modules/transcription/src/silero-vad-onnx.cpp +++ b/src/modules/transcription/src/silero-vad-onnx.cpp @@ -95,7 +95,7 @@ void VadIterator::reset_states(bool reset_state) { if (reset_state) { // Call reset before each audio start - std::memset(_state.data(), 0.0f, _state.size() * sizeof(float)); + std::memset(_state.data(), 0, (int)_state.size() * sizeof(float)); triggered = false; } temp_end = 0; diff --git 
a/src/modules/transcription/src/vad-processing.cpp b/src/modules/transcription/src/vad-processing.cpp index cbb6780..cafc38c 100644 --- a/src/modules/transcription/src/vad-processing.cpp +++ b/src/modules/transcription/src/vad-processing.cpp @@ -154,8 +154,7 @@ vad_state vad_based_segmentation(transcription_context *gf, vad_state last_vad_s } if (gf->enable_audio_chunks_callback) { - audio_chunk_callback(gf, vad_input.data(), vad_input.size(), - VAD_STATE_IS_OFF, + audio_chunk_callback(gf, vad_input, VAD_STATE_IS_OFF, {DETECTION_RESULT_SILENCE, "[silence]", current_vad_state.start_ts_offest_ms, diff --git a/src/modules/transcription/src/whisper-processing.cpp b/src/modules/transcription/src/whisper-processing.cpp index 7cb8ee5..9db0e6d 100644 --- a/src/modules/transcription/src/whisper-processing.cpp +++ b/src/modules/transcription/src/whisper-processing.cpp @@ -122,9 +122,8 @@ struct whisper_context *init_whisper_context(const std::string &model_path_in, } struct DetectionResultWithText run_whisper_inference(struct transcription_context *gf, - const float *pcm32f_data_, - size_t pcm32f_num_samples, uint64_t t0 = 0, - uint64_t t1 = 0, + const std::vector &pcm32f_data_, + uint64_t t0 = 0, uint64_t t1 = 0, int vad_state = VAD_STATE_WAS_OFF) { if (gf == nullptr) { @@ -132,7 +131,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } - if (pcm32f_data_ == nullptr || pcm32f_num_samples == 0) { + if (pcm32f_data_.empty()) { Logger::log(Logger::Level::ERROR_LOG, "run_whisper_inference: pcm32f_data is null or size is 0"); return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; @@ -146,24 +145,24 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex } Logger::log(gf->log_level, "%s: processing %d samples, %.3f sec, %d threads", __func__, - int(pcm32f_num_samples), float(pcm32f_num_samples) / WHISPER_SAMPLE_RATE, + int(pcm32f_data_.size()), 
float(pcm32f_data_.size()) / WHISPER_SAMPLE_RATE, gf->whisper_params.n_threads); bool should_free_buffer = false; - float *pcm32f_data = (float *)pcm32f_data_; - size_t pcm32f_size = pcm32f_num_samples; + std::vector pcm32f_data; + size_t pcm32f_size = pcm32f_data_.size(); // incoming duration in ms const uint64_t incoming_duration_ms = - (uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE); + (uint64_t)(pcm32f_data_.size() * 1000 / WHISPER_SAMPLE_RATE); - if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) { + if (pcm32f_data_.size() < WHISPER_SAMPLE_RATE) { Logger::log( gf->log_level, "Speech segment is less than 1 second, padding with white noise to 1 second"); const size_t new_size = (size_t)(1.01f * (float)(WHISPER_SAMPLE_RATE)); // create a new buffer and copy the data to it in the middle - pcm32f_data = (float *)malloc(new_size * sizeof(float)); + pcm32f_data.resize(new_size); // add low volume white noise const float noise_level = 0.01f; @@ -172,8 +171,8 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex noise_level * ((float)rand() / (float)RAND_MAX * 2.0f - 1.0f); } - memcpy(pcm32f_data + (new_size - pcm32f_num_samples) / 2, pcm32f_data_, - pcm32f_num_samples * sizeof(float)); + std::copy(pcm32f_data_.begin(), pcm32f_data_.end(), + pcm32f_data.begin() + (new_size - pcm32f_size) / 2); pcm32f_size = new_size; should_free_buffer = true; } @@ -202,20 +201,14 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_contex gf->whisper_params.duration_ms = (int)(whisper_duration_ms); try { whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params, - pcm32f_data, (int)pcm32f_size); + pcm32f_data.data(), (int)pcm32f_size); } catch (const std::exception &e) { Logger::log(Logger::Level::ERROR_LOG, "Whisper exception: %s. 
Filter restart is required", e.what()); whisper_free(gf->whisper_context); gf->whisper_context = nullptr; - if (should_free_buffer) { - free(pcm32f_data); - } return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } - if (should_free_buffer) { - free(pcm32f_data); - } std::string language = gf->whisper_params.language; if (gf->whisper_params.language == nullptr || strlen(gf->whisper_params.language) == 0 || @@ -316,33 +309,25 @@ void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offse uint64_t end_offset_ms, int vad_state) { // get the data from the entire whisper buffer - // add 50ms of silence to the beginning and end of the buffer const size_t pcm32f_size = gf->whisper_buffer.size(); - const size_t pcm32f_size_with_silence = pcm32f_size + 2 * WHISPER_SAMPLE_RATE / 100; + // allocate a new buffer and copy the data to it - float *pcm32f_data = (float *)malloc(pcm32f_size_with_silence * sizeof(float)); - if (vad_state == VAD_STATE_PARTIAL) { - // peek instead of pop, since this is a partial run that keeps the data in the buffer - circlebuf_peek_front(&gf->whisper_buffer, pcm32f_data + WHISPER_SAMPLE_RATE / 100, - pcm32f_size * sizeof(float)); - } else { - circlebuf_pop_front(&gf->whisper_buffer, pcm32f_data + WHISPER_SAMPLE_RATE / 100, - pcm32f_size * sizeof(float)); + std::vector pcm32f_data(pcm32f_size); + std::copy(gf->whisper_buffer.begin(), gf->whisper_buffer.end(), pcm32f_data.begin()); + + if (vad_state != VAD_STATE_PARTIAL) { + // clear the whisper buffer if we are not in partial state + gf->whisper_buffer.clear(); } struct DetectionResultWithText inference_result = - run_whisper_inference(gf, pcm32f_data, pcm32f_size_with_silence, start_offset_ms, - end_offset_ms, vad_state); + run_whisper_inference(gf, pcm32f_data, start_offset_ms, end_offset_ms, vad_state); // output inference result to a text source set_text_callback(gf, inference_result); if (gf->enable_audio_chunks_callback && vad_state != VAD_STATE_PARTIAL) { - 
audio_chunk_callback(gf, pcm32f_data, pcm32f_size_with_silence, vad_state, - inference_result); + audio_chunk_callback(gf, pcm32f_data, vad_state, inference_result); } - - // free the buffer - free(pcm32f_data); } void whisper_loop(void *data) diff --git a/src/modules/transcription/src/whisper-utils.cpp b/src/modules/transcription/src/whisper-utils.cpp index d6d7843..dce5edb 100644 --- a/src/modules/transcription/src/whisper-utils.cpp +++ b/src/modules/transcription/src/whisper-utils.cpp @@ -64,15 +64,15 @@ std::pair findStartOfOverlap(const std::vector &se if (seq1[i].id == seq2[j].id) { // Check if the next token in both sequences is the same if (seq1[i + 1].id == seq2[j + 1].id) { - return {i, j}; + return {(int)i, (int)j}; } // 1-skip check on seq1 if (i + 2 < seq1.size() && seq1[i + 2].id == seq2[j + 1].id) { - return {i, j}; + return {(int)i, (int)j}; } // 1-skip check on seq2 if (j + 2 < seq2.size() && seq1[i + 1].id == seq2[j + 2].id) { - return {i, j}; + return {(int)i, (int)j}; } } } diff --git a/src/modules/translation/CMakeLists.txt b/src/modules/translation/CMakeLists.txt index bca362b..cb4ffc6 100644 --- a/src/modules/translation/CMakeLists.txt +++ b/src/modules/translation/CMakeLists.txt @@ -11,7 +11,7 @@ target_include_directories(Translation $ ) -target_link_libraries(Translation INTERFACE Core ICU ct2 sentencepiece) +target_link_libraries(Translation PRIVATE Core ICU ct2 sentencepiece) set_target_properties(Translation PROPERTIES OUTPUT_NAME locaal_translation diff --git a/src/modules/translation/include/translation-utils.h b/src/modules/translation/include/translation-utils.h index ac98729..b83b429 100644 --- a/src/modules/translation/include/translation-utils.h +++ b/src/modules/translation/include/translation-utils.h @@ -1,8 +1,8 @@ #ifndef TRANSLATION_UTILS_H #define TRANSLATION_UTILS_H -#include "transcription-context.h" +#include "translation.h" -void start_translation(struct transcription_context *gf); +void start_translation(struct 
translation_context *ctx, const std::string &translation_model_index); #endif // TRANSLATION_UTILS_H diff --git a/src/modules/translation/include/translation.h b/src/modules/translation/include/translation.h index a631964..4bcacc6 100644 --- a/src/modules/translation/include/translation.h +++ b/src/modules/translation/include/translation.h @@ -6,6 +6,7 @@ #include #include #include +#include enum InputTokenizationStyle { INPUT_TOKENIZAION_M2M100 = 0, INPUT_TOKENIZAION_T5 }; @@ -31,10 +32,13 @@ struct translation_context { // How many sentences to use as context for the next translation int add_context; InputTokenizationStyle input_tokenization_style; + // model mutex + std::mutex model_mutex; + bool model_loaded; }; int build_translation_context(struct translation_context &translation_ctx); -void build_and_enable_translation(struct transcription_context *gf, +void build_and_enable_translation(struct translation_context *gf, const std::string &model_file_path); int translate(struct translation_context &translation_ctx, const std::string &text, diff --git a/src/modules/translation/src/translation-utils.cpp b/src/modules/translation/src/translation-utils.cpp index 6f412be..f8acb53 100644 --- a/src/modules/translation/src/translation-utils.cpp +++ b/src/modules/translation/src/translation-utils.cpp @@ -6,42 +6,37 @@ #include "logger.h" #include "model-downloader.h" -void start_translation(struct transcription_context *gf) +void start_translation(struct translation_context *ctx, const std::string &translation_model_index) { Logger::log(Logger::Level::INFO, "Starting translation..."); - if (gf->translation_model_index == "!!!external!!!") { - Logger::log(Logger::Level::INFO, "External model selected."); - if (gf->translation_model_path_external.empty()) { - Logger::log(Logger::Level::ERROR_LOG, "External model path is empty."); - gf->translate = false; - return; - } - std::string model_file_found = gf->translation_model_path_external; - build_and_enable_translation(gf, 
model_file_found); - return; - } - - const ModelInfo &translation_model_info = models_info[gf->translation_model_index]; + const ModelInfo &translation_model_info = models_info[translation_model_index]; std::string model_file_found = find_model_folder(translation_model_info); if (model_file_found == "") { Logger::log(Logger::Level::INFO, "Translation CT2 model does not exist. Downloading..."); - download_model_with_ui_dialog( + download_model( translation_model_info, - [gf, model_file_found](int download_status, const std::string &path) { + [ctx, model_file_found](int download_status, const std::string &path) { if (download_status == 0) { Logger::log(Logger::Level::INFO, "CT2 model download complete"); - build_and_enable_translation(gf, path); + build_and_enable_translation(ctx, path); } else { Logger::log(Logger::Level::ERROR_LOG, "Model download failed"); - gf->translate = false; + ctx->model_loaded = false; } + }, + [](int progress) { + Logger::log(Logger::Level::INFO, "Download progress: %d", progress); + }, + [](int error, const std::string &message) { + Logger::log(Logger::Level::ERROR_LOG, "Download error: %s", + message.c_str()); }); } else { // Model exists, just load it - build_and_enable_translation(gf, model_file_found); + build_and_enable_translation(ctx, model_file_found); } } diff --git a/src/modules/translation/src/translation.cpp b/src/modules/translation/src/translation.cpp index 7184c00..91af450 100644 --- a/src/modules/translation/src/translation.cpp +++ b/src/modules/translation/src/translation.cpp @@ -1,7 +1,6 @@ #include "translation.h" #include "logger.h" #include "model-find-utils.h" -#include "transcription-context.h" #include "language_codes.h" #include "translation-language-utils.h" @@ -10,18 +9,18 @@ #include -void build_and_enable_translation(struct transcription_context *gf, +void build_and_enable_translation(struct translation_context *ctx, const std::string &model_file_path) { - std::lock_guard lock(gf->whisper_ctx_mutex); + 
std::lock_guard lock(ctx->model_mutex); - gf->translation_ctx.local_model_folder_path = model_file_path; - if (build_translation_context(gf->translation_ctx) == LOCAAL_TRANSLATION_INIT_SUCCESS) { - Logger::log(Logger::Level::INFO, "Enable translation"); - gf->translate = true; + ctx->local_model_folder_path = model_file_path; + if (build_translation_context(*ctx) == LOCAAL_TRANSLATION_INIT_SUCCESS) { + Logger::log(Logger::Level::INFO, "Model loaded"); + ctx->model_loaded = true; } else { Logger::log(Logger::Level::ERROR_LOG, "Failed to load CT2 model"); - gf->translate = false; + ctx->model_loaded = false; } } From 435cc9e0139a21a4631fad1afeffeb489a6ea3d3 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 15:39:18 -0400 Subject: [PATCH 08/12] Refactor build and dependency management, update GitHub Actions and CMake configuration --- .github/workflows/ci.yaml | 8 ++++---- CMakeLists.txt | 4 ++++ examples/CMakeLists.txt | 2 +- scripts/build-windows.ps1 | 4 +++- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4508fb8..0b6db17 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -44,7 +44,7 @@ jobs: 7z a ${{ matrix.os }}-package.zip ./installed/* - name: Upload Artifact - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: ${{ matrix.os }}-package path: ${{runner.workspace}}/build/${{ matrix.os }}-package.zip @@ -56,10 +56,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Download all artifacts - uses: actions/download-artifact@v2 + uses: actions/download-artifact@v4 - name: Create Release id: create_release @@ -73,7 +73,7 @@ jobs: prerelease: false - name: Upload Release Assets - uses: actions/github-script@v3 + uses: actions/github-script@v7 with: github-token: ${{secrets.GITHUB_TOKEN}} script: | diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b719c1..7c3b248 100644 --- 
a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,6 +84,10 @@ endforeach() target_link_libraries(${CMAKE_PROJECT_NAME} INTERFACE ${LOCAAL_ENABLED_MODULES}) +if(BUILD_EXAMPLES) + add_subdirectory(examples) +endif() + # Generate and install package configuration files include(CMakePackageConfigHelpers) write_basic_package_version_file( diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 73d443e..f9d8a65 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,4 +2,4 @@ find_package(LocaalSDK REQUIRED COMPONENTS Core Transcription Translation) add_executable(RealtimeTranscription realtime_transcription.cpp) -target_link_libraries(MyApp PRIVATE LocaalSDK::Core LocaalSDK::Transcription LocaalSDK::Translation) +target_link_libraries(MyApp PRIVATE LocaalSDK::Core LocaalSDK::Transcription) diff --git a/scripts/build-windows.ps1 b/scripts/build-windows.ps1 index 71c093d..4b03ef2 100644 --- a/scripts/build-windows.ps1 +++ b/scripts/build-windows.ps1 @@ -25,7 +25,9 @@ if ($Clean) { } # Configure step -$configureCommand = "cmake -S . -B $buildDir -DCMAKE_BUILD_TYPE=Release -DLocaalSDK_FIND_COMPONENTS=`"Core;Transcription;Translation`" $verboseFlag" +$configureCommand = "cmake -S . 
-B $buildDir -DCMAKE_BUILD_TYPE=Release ` + -DLocaalSDK_FIND_COMPONENTS=`"Core;Transcription;Translation`" $verboseFlag ` + -DBUILD_EXAMPLES=ON" Write-Host "Executing configure command: $configureCommand" Invoke-Expression $configureCommand From 5eb0b59ea33601a5a33c3d896714b4a6bf513011 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 15:45:12 -0400 Subject: [PATCH 09/12] Refactor build and dependency management, update build scripts for Windows and Linux/MacOS --- .github/workflows/ci.yaml | 16 ++++++++------ scripts/build-nix.sh | 44 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 scripts/build-nix.sh diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0b6db17..11fcfa0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,13 +29,17 @@ jobs: - name: Create Build Environment run: cmake -E make_directory ${{runner.workspace}}/build - - name: Configure CMake - working-directory: ${{runner.workspace}}/build - run: cmake $GITHUB_WORKSPACE -G "${{ matrix.cmake_generator }}" + - name: Run Build Script Windows + if: matrix.os == 'windows-latest' + working-directory: ${{runner.workspace}} + run: ./scripts/build-windows.ps1 + shell: pwsh - - name: Build - working-directory: ${{runner.workspace}}/build - run: cmake --build . 
--config Release + - name: Run Build Script Linux and MacOS + if: matrix.os != 'windows-latest' + working-directory: ${{runner.workspace}} + run: ./scripts/build-windows.sh + shell: bash - name: Package working-directory: ${{runner.workspace}}/build diff --git a/scripts/build-nix.sh b/scripts/build-nix.sh new file mode 100755 index 0000000..c2bae72 --- /dev/null +++ b/scripts/build-nix.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +verbose=false +clean=false + +# Parse command line arguments +while [[ "$#" -gt 0 ]]; do + case $1 in + -v|--verbose) verbose=true ;; + -c|--clean) clean=true ;; + *) echo "Unknown parameter passed: $1"; exit 1 ;; + esac + shift +done + +verbose_flag="" +verbose_build_flag="" + +if [ "$verbose" = true ] ; then + verbose_flag="-DCMAKE_VERBOSE_MAKEFILE=ON" + verbose_build_flag="--verbose" +fi + +build_dir="build_x64" + +# Clean build directory if requested +if [ "$clean" = true ] ; then + if [ -d "$build_dir" ] ; then + echo "Cleaning build directory: $build_dir" + rm -rf "$build_dir" + else + echo "Build directory does not exist. Nothing to clean." + fi +fi + +# Configure step +configure_command="cmake -S . 
-B $build_dir -DCMAKE_BUILD_TYPE=Release -DLocaalSDK_FIND_COMPONENTS=\"Core;Transcription;Translation\" $verbose_flag" +echo "Executing configure command: $configure_command" +eval $configure_command + +# Build step +build_command="cmake --build $build_dir --config Release $verbose_build_flag" +echo "Executing build command: $build_command" +eval $build_command From 5523e186b55f50d798f3bebf2995133b70e81427 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 15:45:53 -0400 Subject: [PATCH 10/12] Refactor build script for Linux and MacOS in CI workflow --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 11fcfa0..827dda3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -38,7 +38,7 @@ jobs: - name: Run Build Script Linux and MacOS if: matrix.os != 'windows-latest' working-directory: ${{runner.workspace}} - run: ./scripts/build-windows.sh + run: ./scripts/build-nix.sh shell: bash - name: Package From ab332c21426b5c45d1ccf8c563bddddeacc0a2c3 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 17:52:04 -0400 Subject: [PATCH 11/12] Refactor build script for Linux and MacOS in CI workflow --- cmake/BuildSDL.cmake | 2 +- examples/CMakeLists.txt | 6 +-- examples/realtime_transcription.cpp | 4 +- include/locaal.h | 0 scripts/build-windows.ps1 | 4 +- src/modules/core/include/locaal.h | 1 + src/modules/transcription/CMakeLists.txt | 1 + .../transcription/include/transcription.h | 49 +++++++++++++++++ .../transcription/src/transcription.cpp | 53 +++++++++++++++++++ 9 files changed, 112 insertions(+), 8 deletions(-) delete mode 100644 include/locaal.h create mode 100644 src/modules/core/include/locaal.h create mode 100644 src/modules/transcription/include/transcription.h create mode 100644 src/modules/transcription/src/transcription.cpp diff --git a/cmake/BuildSDL.cmake b/cmake/BuildSDL.cmake index 772d757..5ee5b17 100644 
--- a/cmake/BuildSDL.cmake +++ b/cmake/BuildSDL.cmake @@ -34,7 +34,7 @@ target_include_directories(SDL2 INTERFACE # Link SDL2 and SDL2main libraries target_link_libraries(SDL2 INTERFACE - $ + $ $ ) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f9d8a65..0100938 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,5 +1,5 @@ -find_package(LocaalSDK REQUIRED COMPONENTS Core Transcription Translation) +include(${CMAKE_SOURCE_DIR}/cmake/BuildSDL.cmake) -add_executable(RealtimeTranscription realtime_transcription.cpp) -target_link_libraries(MyApp PRIVATE LocaalSDK::Core LocaalSDK::Transcription) +add_executable(RealtimeTranscription realtime_transcription.cpp audio_capture.cpp) +target_link_libraries(RealtimeTranscription PRIVATE SDL2 Core Transcription) diff --git a/examples/realtime_transcription.cpp b/examples/realtime_transcription.cpp index 896a18f..0aa0a96 100644 --- a/examples/realtime_transcription.cpp +++ b/examples/realtime_transcription.cpp @@ -19,7 +19,7 @@ int main() { }); // Set the callbacks for the transcription - tt.setTranscriptionCallback([](const locaal::DetectionResultWithText &result) { + tt.setTranscriptionCallback([](const locaal::TranscriptionResult &result) { // Print the transcription result std::cout << "Transcription: " << result.text << std::endl; }); @@ -50,7 +50,7 @@ int main() { audio_capture.getAudioData(1000, audio_data); // Process the audio data for transcription - tt.processAudio(audio_data.data(), audio_data.size()); + tt.processAudio(audio_data); } return 0; diff --git a/include/locaal.h b/include/locaal.h deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/build-windows.ps1 b/scripts/build-windows.ps1 index 4b03ef2..4984202 100644 --- a/scripts/build-windows.ps1 +++ b/scripts/build-windows.ps1 @@ -25,8 +25,8 @@ if ($Clean) { } # Configure step -$configureCommand = "cmake -S . 
-B $buildDir -DCMAKE_BUILD_TYPE=Release ` - -DLocaalSDK_FIND_COMPONENTS=`"Core;Transcription;Translation`" $verboseFlag ` +$configureCommand = "cmake -S . -B $buildDir -DCMAKE_BUILD_TYPE=Release `` + -DLocaalSDK_FIND_COMPONENTS=`"Core;Transcription;Translation`" $verboseFlag `` -DBUILD_EXAMPLES=ON" Write-Host "Executing configure command: $configureCommand" Invoke-Expression $configureCommand diff --git a/src/modules/core/include/locaal.h b/src/modules/core/include/locaal.h new file mode 100644 index 0000000..4e29009 --- /dev/null +++ b/src/modules/core/include/locaal.h @@ -0,0 +1 @@ +#include diff --git a/src/modules/transcription/CMakeLists.txt b/src/modules/transcription/CMakeLists.txt index 8487a48..3b2bdbe 100644 --- a/src/modules/transcription/CMakeLists.txt +++ b/src/modules/transcription/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(Transcription src/silero-vad-onnx.cpp src/token-buffer-thread.cpp src/transcription-utils.cpp + src/transcription.cpp src/vad-processing.cpp src/whisper-model-utils.cpp src/whisper-processing.cpp diff --git a/src/modules/transcription/include/transcription.h b/src/modules/transcription/include/transcription.h new file mode 100644 index 0000000..44070e8 --- /dev/null +++ b/src/modules/transcription/include/transcription.h @@ -0,0 +1,49 @@ +#ifndef LOCAAL_TRANSCRIPTION_H +#define LOCAAL_TRANSCRIPTION_H + +#include +#include + +namespace locaal { + +struct TranscriptionResult { + std::string text; + uint64_t start_timestamp_ms; + uint64_t end_timestamp_ms; + std::string language; + bool is_partial; +}; + +class Transcription { +public: + Transcription(); + ~Transcription(); + + void setTranscriptionParams(const std::string& language); + + void setModelDownloadCallbacks( + std::function onSuccess, + std::function onFailure, + std::function onProgress + ); + + void setTranscriptionCallback(std::function callback); + + void startTranscription(); + void stopTranscription(); // Added for completeness + + void processAudio(const 
std::vector& audioData); + +private: + std::string language_; + std::function onModelDownloadSuccess_; + std::function onModelDownloadFailure_; + std::function onModelDownloadProgress_; + std::function transcriptionCallback_; + + // Add any other necessary private members +}; + +} // namespace locaal + +#endif // LOCAAL_TRANSCRIPTION_H diff --git a/src/modules/transcription/src/transcription.cpp b/src/modules/transcription/src/transcription.cpp new file mode 100644 index 0000000..d7bdcd2 --- /dev/null +++ b/src/modules/transcription/src/transcription.cpp @@ -0,0 +1,53 @@ +#include "transcription.h" +#include "logger.h" +#include + +namespace locaal { + +Transcription::Transcription() { + // Constructor implementation +} + +Transcription::~Transcription() { + // Destructor implementation +} + +void Transcription::setTranscriptionParams(const std::string& language) { + language_ = language; + // Add any additional logic for setting transcription parameters +} + +void Transcription::setModelDownloadCallbacks( + std::function onSuccess, + std::function onFailure, + std::function onProgress +) { + onModelDownloadSuccess_ = onSuccess; + onModelDownloadFailure_ = onFailure; + onModelDownloadProgress_ = onProgress; + // Add any additional logic for setting model download callbacks +} + +void Transcription::setTranscriptionCallback(std::function callback) { + transcriptionCallback_ = callback; + // Add any additional logic for setting transcription callback +} + +void Transcription::startTranscription() { + Logger::log(Logger::Level::INFO, "Starting transcription..."); + // Implement the logic to start the transcription process + // This might involve starting a new thread, initializing audio capture, etc. +} + +void Transcription::stopTranscription() { + Logger::log(Logger::Level::INFO, "Stopping transcription..."); + // Implement the logic to stop the transcription process + // This might involve stopping the transcription thread, cleaning up resources, etc. 
+} + +void Transcription::processAudio(const std::vector& audioData) { + Logger::log(Logger::Level::INFO, "Processing audio data..."); +} + + +} // namespace locaal From d709066777d6c1c13341fc06e20efe2b0e997b2e Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 17 Sep 2024 17:53:35 -0400 Subject: [PATCH 12/12] Refactor build script for Linux and MacOS in CI workflow --- .github/workflows/ci.yaml | 2 +- .../transcription/include/transcription.h | 37 +++++------ .../include/whisper-processing.h | 4 +- .../transcription/src/transcription.cpp | 65 ++++++++++--------- 4 files changed, 58 insertions(+), 50 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 827dda3..06701a7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,7 +14,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [windows-latest] # ubuntu-latest, macos-latest, include: - os: ubuntu-latest cmake_generator: "Unix Makefiles" diff --git a/src/modules/transcription/include/transcription.h b/src/modules/transcription/include/transcription.h index 44070e8..4daa89a 100644 --- a/src/modules/transcription/include/transcription.h +++ b/src/modules/transcription/include/transcription.h @@ -11,37 +11,36 @@ struct TranscriptionResult { uint64_t start_timestamp_ms; uint64_t end_timestamp_ms; std::string language; - bool is_partial; + bool is_partial; }; class Transcription { public: - Transcription(); - ~Transcription(); + Transcription(); + ~Transcription(); - void setTranscriptionParams(const std::string& language); + void setTranscriptionParams(const std::string &language); - void setModelDownloadCallbacks( - std::function onSuccess, - std::function onFailure, - std::function onProgress - ); + void setModelDownloadCallbacks( + std::function onSuccess, + std::function onFailure, + std::function onProgress); - void setTranscriptionCallback(std::function callback); + void 
setTranscriptionCallback(std::function callback); - void startTranscription(); - void stopTranscription(); // Added for completeness + void startTranscription(); + void stopTranscription(); // Added for completeness - void processAudio(const std::vector& audioData); + void processAudio(const std::vector &audioData); private: - std::string language_; - std::function onModelDownloadSuccess_; - std::function onModelDownloadFailure_; - std::function onModelDownloadProgress_; - std::function transcriptionCallback_; + std::string language_; + std::function onModelDownloadSuccess_; + std::function onModelDownloadFailure_; + std::function onModelDownloadProgress_; + std::function transcriptionCallback_; - // Add any other necessary private members + // Add any other necessary private members }; } // namespace locaal diff --git a/src/modules/transcription/include/whisper-processing.h b/src/modules/transcription/include/whisper-processing.h index ef645d7..aac0258 100644 --- a/src/modules/transcription/include/whisper-processing.h +++ b/src/modules/transcription/include/whisper-processing.h @@ -28,8 +28,8 @@ struct DetectionResultWithText { std::vector tokens; std::string language; }; -void whisper_loop(void *data); +void whisper_loop(void *data); struct whisper_context *init_whisper_context(const std::string &model_path, struct transcription_context *gf); void run_inference_and_callbacks(transcription_context *gf, uint64_t start_offset_ms, diff --git a/src/modules/transcription/src/transcription.cpp b/src/modules/transcription/src/transcription.cpp index d7bdcd2..92652ee 100644 --- a/src/modules/transcription/src/transcription.cpp +++ b/src/modules/transcription/src/transcription.cpp @@ -4,50 +4,57 @@ namespace locaal { -Transcription::Transcription() { - // Constructor implementation +Transcription::Transcription() +{ + // Constructor implementation } -Transcription::~Transcription() { - // Destructor implementation +Transcription::~Transcription() +{ + // 
Destructor implementation } -void Transcription::setTranscriptionParams(const std::string& language) { - language_ = language; - // Add any additional logic for setting transcription parameters +void Transcription::setTranscriptionParams(const std::string &language) +{ + language_ = language; + // Add any additional logic for setting transcription parameters } void Transcription::setModelDownloadCallbacks( - std::function onSuccess, - std::function onFailure, - std::function onProgress -) { - onModelDownloadSuccess_ = onSuccess; - onModelDownloadFailure_ = onFailure; - onModelDownloadProgress_ = onProgress; - // Add any additional logic for setting model download callbacks + std::function onSuccess, + std::function onFailure, + std::function onProgress) +{ + onModelDownloadSuccess_ = onSuccess; + onModelDownloadFailure_ = onFailure; + onModelDownloadProgress_ = onProgress; + // Add any additional logic for setting model download callbacks } -void Transcription::setTranscriptionCallback(std::function callback) { - transcriptionCallback_ = callback; - // Add any additional logic for setting transcription callback +void Transcription::setTranscriptionCallback( + std::function callback) +{ + transcriptionCallback_ = callback; + // Add any additional logic for setting transcription callback } -void Transcription::startTranscription() { - Logger::log(Logger::Level::INFO, "Starting transcription..."); - // Implement the logic to start the transcription process - // This might involve starting a new thread, initializing audio capture, etc. +void Transcription::startTranscription() +{ + Logger::log(Logger::Level::INFO, "Starting transcription..."); + // Implement the logic to start the transcription process + // This might involve starting a new thread, initializing audio capture, etc. 
} -void Transcription::stopTranscription() { - Logger::log(Logger::Level::INFO, "Stopping transcription..."); - // Implement the logic to stop the transcription process - // This might involve stopping the transcription thread, cleaning up resources, etc. +void Transcription::stopTranscription() +{ + Logger::log(Logger::Level::INFO, "Stopping transcription..."); + // Implement the logic to stop the transcription process + // This might involve stopping the transcription thread, cleaning up resources, etc. } -void Transcription::processAudio(const std::vector& audioData) { - Logger::log(Logger::Level::INFO, "Processing audio data..."); +void Transcription::processAudio(const std::vector &audioData) +{ + Logger::log(Logger::Level::INFO, "Processing audio data..."); } - } // namespace locaal