Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to choose different sample rate #36

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions packages/_common/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,22 @@ export class Silero {

/**
 * Stores the collaborators needed to load and run the model.
 *
 * @param ort - ONNX runtime API; `init` uses its `InferenceSession` and `Tensor` members.
 * @param modelFetcher - Async callback that `init` awaits to obtain the model bytes.
 * @param nativeSampleRate - Sample rate that `init` writes into the `_sr` tensor.
 *   Defaults to 16000, the value that used to be hard-coded, so callers that
 *   pass only two arguments keep working.
 */
constructor(
  private ort: ONNXRuntimeAPI,
  private modelFetcher: ModelFetcher,
  private nativeSampleRate: number = 16000
) {}

/**
 * Creates a `Silero` instance and waits for `init()` to finish before returning it.
 *
 * @param ort - ONNX runtime API handed to the constructor.
 * @param modelFetcher - Async callback handed to the constructor.
 * @param nativeSampleRate - Sample rate handed to the constructor. Defaults to
 *   16000 (the previously hard-coded value) so existing two-argument callers
 *   do not need to change.
 * @returns The initialized model.
 */
static new = async (
  ort: ONNXRuntimeAPI,
  modelFetcher: ModelFetcher,
  nativeSampleRate: number = 16000
) => {
  const model = new Silero(ort, modelFetcher, nativeSampleRate)
  await model.init()
  return model
}

/**
 * Fetches the model, creates the inference session, builds the `_sr` tensor
 * from `nativeSampleRate`, and resets the model state.
 *
 * @throws RangeError if `nativeSampleRate` is not a positive integer.
 */
init = async () => {
  log.debug("initializing vad")

  // BigInt() rejects non-integer input with an unhelpful message; check up
  // front so a bad sample rate fails before the model is fetched and loaded.
  if (!Number.isInteger(this.nativeSampleRate) || this.nativeSampleRate <= 0) {
    throw new RangeError(
      `nativeSampleRate must be a positive integer, got ${this.nativeSampleRate}`
    )
  }

  const modelArrayBuffer = await this.modelFetcher()
  this._session = await this.ort.InferenceSession.create(modelArrayBuffer)
  // The tensor is declared "int64", so the rate has to be supplied as a bigint.
  this._sr = new this.ort.Tensor("int64", [BigInt(this.nativeSampleRate)])
  this.reset_state()
  log.debug("vad is initialized")
}
Expand Down
25 changes: 17 additions & 8 deletions packages/_common/src/non-real-time-vad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@ interface NonRealTimeVADSpeechData {
end: number
}

/**
 * Options for the non-real-time VAD: every frame-processor option plus the
 * two sample rates used when resampling input audio.
 */
export interface NonRealTimeVADOptions extends FrameProcessorOptions {
  /**
   * Sample rate passed to `Silero.new` during `init`, and used as the
   * resampler's input rate when `run` is called without a `sampleRate`.
   */
  nativeSampleRate: number
  /**
   * Rate the resampler converts audio to in `run`; also used there to turn
   * frame offsets into millisecond `start`/`end` values.
   */
  targetSampleRate: number
}

/**
 * Default options for the non-real-time VAD: everything from
 * `defaultFrameProcessorOptions`, plus 16000 for both sample rates.
 */
export const defaultNonRealTimeVADOptions: NonRealTimeVADOptions = {
...defaultFrameProcessorOptions,
// Listed after the spread, so these two values win over any same-named keys it carries.
nativeSampleRate: 16000,
targetSampleRate: 16000,
}

export class PlatformAgnosticNonRealTimeVAD {
Expand Down Expand Up @@ -46,7 +51,7 @@ export class PlatformAgnosticNonRealTimeVAD {
}

init = async () => {
const model = await Silero.new(this.ort, this.modelFetcher)
const model = await Silero.new(this.ort, this.modelFetcher, this.options.nativeSampleRate)

this.frameProcessor = new FrameProcessor(model.process, model.reset_state, {
frameSamples: this.options.frameSamples,
Expand All @@ -61,26 +66,30 @@ export class PlatformAgnosticNonRealTimeVAD {

run = async function* (
inputAudio: Float32Array,
sampleRate: number
sampleRate?: number,
): AsyncGenerator<NonRealTimeVADSpeechData> {

const targetSampleRate = this.options.targetSampleRate ?? 16000
const resamplerOptions = {
nativeSampleRate: sampleRate,
targetSampleRate: 16000,
nativeSampleRate: sampleRate ?? this.options.nativeSampleRate,
targetSampleRate: targetSampleRate,
targetFrameSize: this.options.frameSamples,
}

const resampler = new Resampler(resamplerOptions)
const frames = resampler.process(inputAudio)
const framesDivisor = (targetSampleRate / 1000);
let start: number, end: number
for (const i of [...Array(frames.length)].keys()) {
const f = frames[i]
const { msg, audio } = await this.frameProcessor.process(f)
switch (msg) {
case Message.SpeechStart:
start = (i * this.options.frameSamples) / 16
start = (i * this.options.frameSamples) / framesDivisor
break

case Message.SpeechEnd:
end = ((i + 1) * this.options.frameSamples) / 16
end = ((i + 1) * this.options.frameSamples) / framesDivisor
// @ts-ignore
yield { audio, start, end }
break
Expand All @@ -95,7 +104,7 @@ export class PlatformAgnosticNonRealTimeVAD {
audio,
// @ts-ignore
start,
end: (frames.length * this.options.frameSamples) / 16,
end: (frames.length * this.options.frameSamples) / framesDivisor,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions packages/_common/src/resampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ export class Resampler {
inputBuffer: Array<number>

constructor(public options: ResamplerOptions) {
if (options.nativeSampleRate < 16000) {
if (options.nativeSampleRate < 8000) {
log.error(
"nativeSampleRate is too low. Should have 16000 = targetSampleRate <= nativeSampleRate"
"nativeSampleRate is too low. Should have 8000 = targetSampleRate <= nativeSampleRate"
)
}
this.inputBuffer = []
Expand Down