From c4100007442e0839b06bfd129817be6b7f9bb376 Mon Sep 17 00:00:00 2001 From: chenghao-mou Date: Thu, 21 May 2026 15:30:12 +0000 Subject: [PATCH] fix(vad): reset streams on flush --- .changeset/vad-stream-flush-reset.md | 6 +++++ agents/src/vad.ts | 6 +++++ agents/src/voice/audio_recognition.ts | 39 ++++++++++++++++++++++----- plugins/silero/src/onnx_model.ts | 6 +++++ plugins/silero/src/vad.ts | 38 ++++++++++++++++++++++++-- 5 files changed, 87 insertions(+), 8 deletions(-) create mode 100644 .changeset/vad-stream-flush-reset.md diff --git a/.changeset/vad-stream-flush-reset.md b/.changeset/vad-stream-flush-reset.md new file mode 100644 index 000000000..3ccd5efa8 --- /dev/null +++ b/.changeset/vad-stream-flush-reset.md @@ -0,0 +1,6 @@ +--- +'@livekit/agents': patch +'@livekit/agents-plugin-silero': patch +--- + +Reset active VAD streams on flush so STT end-of-speech can recover without recreating streams. diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 422c2654c..e946ae213 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -221,6 +221,12 @@ export abstract class VADStream implements AsyncIterableIterator { this.inputWriter.write(frame); } + /** + * Mark the end of the current segment. + * + * Implementations must treat this as a hard segment boundary: drop any accumulated + * speech/silence state so the next pushed frame starts a fresh segment. + */ flush() { if (this.inputClosed) { throw new Error('Input is closed'); diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index 5b45b3143..d4a3ac90d 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -34,7 +34,7 @@ import { type SpeechEvent, SpeechEventType } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; import { splitWords } from '../tokenize/basic/word.js'; import { Task, cancelAndWait, delay, readStream, waitForAbort } from '../utils.js'; -import { type VAD, type VADEvent, VADEventType } from '../vad.js'; +import { type VAD, type VADEvent, VADEventType, type VADStream } from '../vad.js'; import type { TurnDetectionMode } from './agent_session.js'; import { type UserTurnExceededEvent, createUserTurnExceededEvent } from './events.js'; import type { STTNode } from './io.js'; @@ -218,6 +218,7 @@ export class AudioRecognition { private userTurnStart: number | undefined; private userTurnCommitted = false; private speaking = false; + private vadSpeechStarted = false; private sampleRate?: number; private userTurnSpan?: Span; @@ -251,6 +252,7 @@ export class AudioRecognition { private commitUserTurnTask?: Task; private sttForwardTask?: Task; private vadTask?: Task; + private vadStream?: VADStream; private sttConsumerTask?: Task; private interruptionTask?: Task; @@ -1002,9 +1004,20 @@ export class AudioRecognition { // and user state won't be updated until a new VAD SOS is received. // Reset VAD so that incorrect end of turn from STT can be corrected by VAD interruption. // If user is still speaking (an immediate VAD SOS will interrupt the agent). - if (this.vad && this.speaking) { - this.logger.warn('stt end of speech received while user is speaking, resetting vad'); - this.resetVad(); + if (this.vad && this.vadSpeechStarted) { + if (this.vadStream) { + this.vadStream.flush(); + } else { + this.resetVad(); + } + + this.logger.warn( + { + vadSpeechStartTime: this.speechStartTime, + flushed: this.vadStream !== undefined, + }, + 'stt end of speech received while vad is still in a speech segment, flushing vad', + ); } this.speaking = false; this.userTurnCommitted = true; @@ -1168,9 +1181,13 @@ export class AudioRecognition { // clear the transcript if the user turn was committed this.audioTranscript = ''; this.finalTranscriptConfidence = []; - this.lastSpeakingTime = undefined; this.lastFinalTranscriptTime = 0; - this.speechStartTime = undefined; + // Concurrent user speech might have changed it; only reset if there is no new speech. + if (this.lastSpeakingTime === lastSpeakingTime) { + this.speechStartTime = undefined; + this.vadSpeechStarted = false; + this.lastSpeakingTime = undefined; + } } this.userTurnCommitted = false; @@ -1304,6 +1321,7 @@ export class AudioRecognition { if (!vad) return; const vadStream = vad.stream(); + this.vadStream = vadStream; vadStream.updateInputStream(this.vadInputStream); const abortHandler = () => { @@ -1322,6 +1340,10 @@ export class AudioRecognition { this.logger.debug('VAD task: START_OF_SPEECH'); { const startTime = Date.now() - ev.speechDuration - ev.inferenceDuration; + if (!this.vadSpeechStarted) { + this.speechStartTime = startTime; + this.vadSpeechStarted = true; + } const span = this.ensureUserTurnSpan(startTime); const ctx = this.userTurnContext(span); this.endpointing.onStartOfSpeech(startTime, this.isAgentSpeaking); @@ -1366,6 +1388,7 @@ export class AudioRecognition { } // when VAD fires END_OF_SPEECH, it already waited for the silence_duration + this.vadSpeechStarted = false; this.speaking = false; if ( @@ -1382,6 +1405,9 @@ export class AudioRecognition { this.logger.error(e, 'Error in VAD task'); } finally { this.logger.debug('VAD task closed'); + if (this.vadStream === vadStream) { + this.vadStream = undefined; + } } } @@ -1563,6 +1589,7 @@ export class AudioRecognition { this.speechStartTime = undefined; this.userTurnStart = undefined; this.lastSpeakingTime = undefined; + this.vadSpeechStarted = false; this.speaking = false; this.userTurnCommitted = false; diff --git a/plugins/silero/src/onnx_model.ts b/plugins/silero/src/onnx_model.ts index 287052c7b..b83b8647b 100644 --- a/plugins/silero/src/onnx_model.ts +++ b/plugins/silero/src/onnx_model.ts @@ -58,6 +58,12 @@ export class OnnxModel { return this.#contextSize; } + reset(): void { + this.#context.fill(0); + this.#rnnState.fill(0); + this.#inputBuffer.fill(0); + } + async run(x: Float32Array): Promise { this.#inputBuffer.set(this.#context, 0); this.#inputBuffer.set(x, this.#contextSize); diff --git a/plugins/silero/src/vad.ts b/plugins/silero/src/vad.ts index c1529720e..70f0beb1c 100644 --- a/plugins/silero/src/vad.ts +++ b/plugins/silero/src/vad.ts @@ -150,13 +150,46 @@ export class VADStream extends baseStream { let speechThresholdDuration = 0; let silenceThresholdDuration = 0; - let inputFrames = []; + let inputFrames: AudioFrame[] = []; let inferenceFrames: AudioFrame[] = []; let resampler: AudioResampler | null = null; // used to avoid drift when the sampleRate ratio is not an integer let inputCopyRemainingFrac = 0.0; + const resetState = () => { + this.#model.reset(); + this.#expFilter = new ExpFilter(0.35); + + speechBufferIndex = 0; + this.#speechBufferMaxReached = false; + this.#speechBuffer?.fill(0); + + pubSpeaking = false; + pubSpeechDuration = 0; + pubSilenceDuration = 0; + pubCurrentSample = 0; + pubTimestamp = 0; + speechThresholdDuration = 0; + silenceThresholdDuration = 0; + + inputFrames = []; + inferenceFrames = []; + inputCopyRemainingFrac = 0.0; + this.#extraInferenceTime = 0; + + resampler?.close(); + resampler = + this.#inputSampleRate && this.#opts.sampleRate !== this.#inputSampleRate + ? new AudioResampler( + this.#inputSampleRate, + this.#opts.sampleRate, + 1, + AudioResamplerQuality.QUICK, + ) + : null; + }; + while (!this.closed) { const { done, value: frame } = await this.inputReader.read(); if (done) { @@ -164,7 +197,8 @@ export class VADStream extends baseStream { } if (typeof frame === 'symbol') { - continue; // ignore flush sentinel for now + resetState(); + continue; } if (!this.#inputSampleRate || !this.#speechBuffer) {