diff --git a/packages/_common/src/models.ts b/packages/_common/src/models.ts index 45c0c0f..b2136b9 100644 --- a/packages/_common/src/models.ts +++ b/packages/_common/src/models.ts @@ -22,20 +22,22 @@ export class Silero { constructor( private ort: ONNXRuntimeAPI, - private modelFetcher: ModelFetcher + private modelFetcher: ModelFetcher, + private nativeSampleRate: number, ) {} - static new = async (ort: ONNXRuntimeAPI, modelFetcher: ModelFetcher) => { - const model = new Silero(ort, modelFetcher) + static new = async (ort: ONNXRuntimeAPI, modelFetcher: ModelFetcher, nativeSampleRate: number) => { + const model = new Silero(ort, modelFetcher, nativeSampleRate) await model.init() return model } init = async () => { - log.debug("initializing vad") + log.debug(`initializing vad`) const modelArrayBuffer = await this.modelFetcher() this._session = await this.ort.InferenceSession.create(modelArrayBuffer) - this._sr = new this.ort.Tensor("int64", [16000n]) + const tensorSize = BigInt(this.nativeSampleRate) + this._sr = new this.ort.Tensor("int64", [tensorSize]) this.reset_state() log.debug("vad is initialized") } diff --git a/packages/_common/src/non-real-time-vad.ts b/packages/_common/src/non-real-time-vad.ts index cd026f8..a8330b4 100644 --- a/packages/_common/src/non-real-time-vad.ts +++ b/packages/_common/src/non-real-time-vad.ts @@ -15,10 +15,15 @@ interface NonRealTimeVADSpeechData { end: number } -export interface NonRealTimeVADOptions extends FrameProcessorOptions {} +export interface NonRealTimeVADOptions extends FrameProcessorOptions { + nativeSampleRate: number + targetSampleRate: number +} export const defaultNonRealTimeVADOptions: NonRealTimeVADOptions = { ...defaultFrameProcessorOptions, + nativeSampleRate: 16000, + targetSampleRate: 16000, } export class PlatformAgnosticNonRealTimeVAD { @@ -46,7 +51,7 @@ export class PlatformAgnosticNonRealTimeVAD { } init = async () => { - const model = await Silero.new(this.ort, this.modelFetcher) + const model = await Silero.new(this.ort, this.modelFetcher, this.options.nativeSampleRate) this.frameProcessor = new FrameProcessor(model.process, model.reset_state, { frameSamples: this.options.frameSamples, @@ -61,26 +66,30 @@ export class PlatformAgnosticNonRealTimeVAD { run = async function* ( inputAudio: Float32Array, - sampleRate: number + sampleRate?: number, ): AsyncGenerator { + + const targetSampleRate = this.options.targetSampleRate ?? 16000 const resamplerOptions = { - nativeSampleRate: sampleRate, - targetSampleRate: 16000, + nativeSampleRate: sampleRate ?? this.options.nativeSampleRate, + targetSampleRate: targetSampleRate, targetFrameSize: this.options.frameSamples, } + const resampler = new Resampler(resamplerOptions) const frames = resampler.process(inputAudio) + const framesDivisor = (targetSampleRate / 1000); let start: number, end: number for (const i of [...Array(frames.length)].keys()) { const f = frames[i] const { msg, audio } = await this.frameProcessor.process(f) switch (msg) { case Message.SpeechStart: - start = (i * this.options.frameSamples) / 16 + start = (i * this.options.frameSamples) / framesDivisor break case Message.SpeechEnd: - end = ((i + 1) * this.options.frameSamples) / 16 + end = ((i + 1) * this.options.frameSamples) / framesDivisor // @ts-ignore yield { audio, start, end } break @@ -95,7 +104,7 @@ export class PlatformAgnosticNonRealTimeVAD { audio, // @ts-ignore start, - end: (frames.length * this.options.frameSamples) / 16, + end: (frames.length * this.options.frameSamples) / framesDivisor, } } } diff --git a/packages/_common/src/resampler.ts b/packages/_common/src/resampler.ts index 4972c6e..567d5cb 100644 --- a/packages/_common/src/resampler.ts +++ b/packages/_common/src/resampler.ts @@ -10,9 +10,9 @@ export class Resampler { inputBuffer: Array constructor(public options: ResamplerOptions) { - if (options.nativeSampleRate < 16000) { + if (options.nativeSampleRate < 8000) { log.error( - "nativeSampleRate is too low. Should have 16000 = targetSampleRate <= nativeSampleRate" + "nativeSampleRate is too low. Should have 8000 = targetSampleRate <= nativeSampleRate" ) } this.inputBuffer = []