Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Update speaker_id to uint32_t and add speed parameter for gene… #4

Merged
merged 2 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions data/locale/en-US.ini
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
none_no_input=No input
Phonetic_Transcription=Phonetic Transcription
File=File
Text=Text
Generate_Audio=Generate Audio
Speaker_ID=Speaker ID
Model=Model
Delete_Cached_Models=Delete Cached Models
none_no_input="No input"
Phonetic_Transcription="Phonetic Transcription"
File="File"
Text="Text"
Generate_Audio="Generate Audio"
Speaker_ID="Speaker ID"
Model="Model"
Delete_Cached_Models="Delete Cached Models"
Speed="Speed"
19 changes: 13 additions & 6 deletions src/input-thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ void InputThread::run()
while (running) {
obs_log(LOG_DEBUG, "Input thread checking for changes");

std::string new_content_for_generation;

// Monitor files for changes
if (!file.empty()) {
// Check if file has changed
Expand All @@ -37,9 +39,7 @@ void InputThread::run()
}
if (fileContents != lastFileValue) {
// Invoke speech generation if it has changed
if (speechGenerationCallback) {
speechGenerationCallback(fileContents);
}
new_content_for_generation = fileContents;
lastFileValue = fileContents;
}
}
Expand All @@ -58,16 +58,23 @@ void InputThread::run()
obs_data_release(sourceSettings);
if (text && lastOBSTextSourceValue != text) {
// Invoke speech generation if it has changed
if (speechGenerationCallback) {
speechGenerationCallback(text);
}
new_content_for_generation = text;
lastOBSTextSourceValue = text;
}
}
obs_source_release(source);
}
}

if (!new_content_for_generation.empty() && speechGenerationCallback) {
std::thread generationThread([this, new_content_for_generation]() {
obs_log(LOG_DEBUG, "Generating speech from input: %s",
new_content_for_generation.c_str());
speechGenerationCallback(new_content_for_generation);
});
generationThread.detach();
}

// Sleep for a certain interval before checking again
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
}
Expand Down
33 changes: 26 additions & 7 deletions src/sherpa-tts/sherpa-tts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,36 @@

#include <obs-module.h>

void generate_audio_from_text(sherpa_tts_context &ctx, const char *text, int speaker_id)
void generate_audio_from_text(sherpa_tts_context &ctx, const std::string &text, uint32_t speaker_id,
float speed)
{
if (ctx.tts == nullptr) {
if (ctx.tts == nullptr || !ctx.initialized || text.empty() ||
ctx.audio_callback == nullptr || speed <= 0.0f) {
return;
}

const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerate(ctx.tts, text, speaker_id, 1.0);
if (ctx.num_speakers == 0) {
obs_log(LOG_WARNING, "No speakers found in the model. Assuming speaker id 0.");
speaker_id = 0;
} else if (speaker_id >= ctx.num_speakers) {
obs_log(LOG_WARNING, "Speaker id %d is out of range (0 -> %d), using speaker id 0",
speaker_id, ctx.num_speakers - 1);
speaker_id = 0;
}

try {
obs_log(LOG_DEBUG, "Generating audio from text: %s, speaker_id: %d, speed: %f",
text.c_str(), speaker_id, speed);
const SherpaOnnxGeneratedAudio *audio =
SherpaOnnxOfflineTtsGenerate(ctx.tts, text.c_str(), speaker_id, speed);

// Call the audio callback function with the generated audio samples
ctx.audio_callback(ctx.callback_data, audio->samples, audio->n, audio->sample_rate);
// Call the audio callback function with the generated audio samples
ctx.audio_callback(ctx.callback_data, audio->samples, audio->n, audio->sample_rate);

SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
SherpaOnnxDestroyOfflineTtsGeneratedAudio(audio);
} catch (const std::exception &e) {
obs_log(LOG_ERROR, "Error generating audio from text: %s", e.what());
}
}

void init_sherpa_tts_context(sherpa_tts_context &context,
Expand Down Expand Up @@ -66,6 +83,8 @@ void init_sherpa_tts_context(sherpa_tts_context &context,
context.tts = SherpaOnnxCreateOfflineTts(&config);
context.audio_callback = audio_callback;
context.callback_data = data;
context.num_speakers = std::max(1, SherpaOnnxOfflineTtsNumSpeakers(context.tts));
context.initialized = true;
}

void destroy_sherpa_tts_context(sherpa_tts_context &context)
Expand Down
5 changes: 4 additions & 1 deletion src/sherpa-tts/sherpa-tts.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ struct sherpa_tts_context {
std::string model_name;
void *callback_data = nullptr;
void (*audio_callback)(void *data, const float *samples, int num_samples, int sample_rate);
uint32_t num_speakers = 0;
bool initialized = false;
};

void init_sherpa_tts_context(sherpa_tts_context &context,
Expand All @@ -18,6 +20,7 @@ void init_sherpa_tts_context(sherpa_tts_context &context,
void *data);
void destroy_sherpa_tts_context(sherpa_tts_context &context);

void generate_audio_from_text(sherpa_tts_context &ctx, const char *text, int speaker_id);
void generate_audio_from_text(sherpa_tts_context &ctx, const std::string &text, uint32_t speaker_id,
float speed);

#endif
3 changes: 2 additions & 1 deletion src/squawk-source-data.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ struct squawk_source_data {
sherpa_tts_context tts_context;
std::unique_ptr<InputThread> inputThread;

int speaker_id;
uint32_t speaker_id;
float speed;
bool phonetic_transcription;

squawk_source_data() { context = nullptr; }
Expand Down
15 changes: 10 additions & 5 deletions src/squawk-source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ void *squawk_source_create(obs_data_t *settings, obs_source_t *source)
if (squawk_data->phonetic_transcription) {
transformed_text = phonetic_transcription(text);
}
generate_audio_from_text(squawk_data->tts_context, transformed_text.c_str(),
squawk_data->speaker_id);
generate_audio_from_text(squawk_data->tts_context, transformed_text,
squawk_data->speaker_id, squawk_data->speed);
});
squawk_data->inputThread->start();

Expand All @@ -70,6 +70,7 @@ void squawk_source_destroy(void *data)
void squawk_source_defaults(obs_data_t *settings)
{
obs_data_set_default_int(settings, "speaker_id", 0);
obs_data_set_default_double(settings, "speed", 1.0);
obs_data_set_default_string(settings, "text", "Hello, World!");
obs_data_set_default_string(settings, "model", "vits-coqui-en-vctk");
obs_data_set_default_string(settings, "input_source", "none");
Expand Down Expand Up @@ -143,6 +144,9 @@ obs_properties_t *squawk_source_properties(void *data)
// add speaker id property
obs_properties_add_int(ppts, "speaker_id", MT_("Speaker_ID"), 0, 100, 1);

// add a speed slider between 0.1 and 2.5
obs_properties_add_float_slider(ppts, "speed", MT_("Speed"), 0.1, 2.5, 0.1);

// add input source selection dropdown property
obs_property_t *input_source = obs_properties_add_list(
ppts, "input_source", "Input Source", OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING);
Expand Down Expand Up @@ -179,8 +183,8 @@ obs_properties_t *squawk_source_properties(void *data)
original_text.c_str(), text.c_str());
}

generate_audio_from_text(squawk_data_->tts_context, text.c_str(),
speaker_id);
generate_audio_from_text(squawk_data_->tts_context, text, speaker_id,
squawk_data_->speed);

return true;
});
Expand Down Expand Up @@ -221,7 +225,8 @@ void squawk_source_update(void *data, obs_data_t *settings)

squawk_source_data *squawk_data = (squawk_source_data *)data;

squawk_data->speaker_id = (int)obs_data_get_int(settings, "speaker_id");
squawk_data->speaker_id = (uint32_t)obs_data_get_int(settings, "speaker_id");
squawk_data->speed = (float)obs_data_get_double(settings, "speed");
squawk_data->phonetic_transcription = obs_data_get_bool(settings, "phonetic_transcription");

std::string source = obs_data_get_string(settings, "input_source");
Expand Down