Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions livekit-agents/livekit/agents/utils/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,76 @@ def on_connection_state_changed(state: int) -> None:
room.off("connection_state_changed", on_connection_state_changed)


class ParticipantAttributeWaitAborted(RuntimeError):
"""Raised by :func:`wait_for_participant_attribute` when the wait cannot
complete (room/participant disconnected or never present)."""
Comment on lines +71 to +73
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just use RuntimeError?



async def wait_for_participant_attribute(
room: rtc.Room,
*,
identity: str,
attribute: str,
value: str,
) -> None:
"""Wait until a remote participant's attribute equals ``value``.

Returns immediately if the attribute is already set. Raises
:class:`ParticipantAttributeWaitAborted` if the room is not connected, the
participant is not present, the participant disconnects, or the room
disconnects before the attribute is set.
"""
if not room.isconnected():
raise ParticipantAttributeWaitAborted("room is not connected")
if identity not in room.remote_participants:
raise ParticipantAttributeWaitAborted(f"participant {identity!r} is not in the room")

fut: asyncio.Future[None] = asyncio.Future()

def _is_match(p: rtc.Participant) -> bool:
return (
isinstance(p, rtc.RemoteParticipant)
and p.identity == identity
and p.attributes.get(attribute) == value
)

def _on_attributes_changed(_changed: list[str], p: rtc.Participant) -> None:
if _is_match(p) and not fut.done():
fut.set_result(None)

def _on_participant_disconnected(p: rtc.RemoteParticipant) -> None:
if p.identity == identity and not fut.done():
fut.set_exception(
ParticipantAttributeWaitAborted(
f"participant {identity!r} disconnected while waiting for {attribute}"
)
)

def _on_connection_state_changed(state: int) -> None:
if state == rtc.ConnectionState.CONN_DISCONNECTED and not fut.done():
fut.set_exception(
ParticipantAttributeWaitAborted(
f"room disconnected while waiting for {identity!r} {attribute}"
)
)

room.on("participant_attributes_changed", _on_attributes_changed)
room.on("participant_disconnected", _on_participant_disconnected)
room.on("connection_state_changed", _on_connection_state_changed)

try:
# check after registering so an attribute set between check and subscribe
# cannot slip past
existing = room.remote_participants.get(identity)
if existing and existing.attributes.get(attribute) == value:
return
await fut
finally:
room.off("participant_attributes_changed", _on_attributes_changed)
room.off("participant_disconnected", _on_participant_disconnected)
room.off("connection_state_changed", _on_connection_state_changed)


@overload
async def wait_for_participant(
room: rtc.Room,
Expand Down
6 changes: 6 additions & 0 deletions livekit-agents/livekit/agents/voice/agent_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,12 @@ def on_end_of_turn(self, info: _EndOfTurnInfo) -> bool:
# IMPORTANT: This method is sync to avoid it being cancelled by the AudioRecognition
# We explicitly create a new task here

# TODO: @chenghao-mou replace this direct call with the public `eot_prediction`
# event once feat/AGT-2520-multimodal-EOU lands.
# amd can consume the turn if it detects machine and interrupt_on_machien is True
if (amd := self._session._amd) is not None and amd._on_end_of_turn(info):
return True

if self._scheduling_paused or self._new_turns_blocked:
self._cancel_preemptive_generation()
logger.warning(
Expand Down
164 changes: 108 additions & 56 deletions livekit-agents/livekit/agents/voice/amd/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from ...llm.tool_context import Tool, ToolContext, function_tool
from ...llm.utils import execute_function_call
from ...log import logger
from ...types import NOT_GIVEN, NotGivenOr
from ...utils import EventEmitter, aio, is_given, log_exceptions
from ...utils import EventEmitter, aio, log_exceptions

HUMAN_SPEECH_THRESHOLD = 2.5
HUMAN_SILENCE_THRESHOLD = 0.5
Expand Down Expand Up @@ -87,15 +86,15 @@ def is_machine(self) -> bool:
# endregion


def _state_guard(method: Callable[..., Any]) -> Callable[..., Any]:
def _listening_guard(method: Callable[..., Any]) -> Callable[..., Any]:
"""Drop inputs that arrive outside the listening window.

Pre-listen audio (ringback, dialtone) and post-verdict transcripts are silently dropped.
"""

@functools.wraps(method)
def wrapper(self: "_AMDClassifier", *args: Any, **kwargs: Any) -> Any:
if self.closed or not self.started:
logger.warning(
"AMD state is invalid: started=%s, closed=%s",
self.started,
self.closed,
)
if self._closed or not self._listening:
return
return method(self, *args, **kwargs)

Expand All @@ -113,6 +112,7 @@ def __init__(
timeout: float = TIMEOUT,
prompt: str = AMD_PROMPT,
source: str = "stt",
wait_until_finished: bool = False,
):
super().__init__()
self._human_speech_threshold = human_speech_threshold
Expand All @@ -121,6 +121,7 @@ def __init__(
self._no_speech_threshold = no_speech_threshold
self._timeout = timeout
self._source = source
self._wait_until_finished = wait_until_finished

self._input_ch: aio.Chan[str] = aio.Chan()
self._classify_task: asyncio.Task[None] | None = None
Expand All @@ -136,41 +137,47 @@ def __init__(
self._prompt = prompt
self._speech_started_at: float | None = None
self._speech_ended_at: float | None = None
self._started = False
self._listening = False
self._closed = False
self._machine_silence_reached = False
self._silence_reached = False
self._eot_reached = False
self._emitted = False
self._transcript = ""
self._extension_count = 0

def start(self) -> None:
"""Mark classifier as started (enables state guard). Call start_timers() separately."""
if self._started:
def start_detection_timer(self) -> None:
"""Arm the overall detection-timeout budget."""
if self._closed or self._detection_timeout_timer is not None:
return
self._started = True

def start_timers(self) -> None:
"""Start the no-speech and detection-timeout timers. Call after start()."""
if not self._started or self._closed:
return
self._no_speech_timer = asyncio.get_running_loop().call_later(
self._no_speech_threshold,
functools.partial(
self._silence_timer_callback,
category=AMDCategory.MACHINE_UNAVAILABLE,
reason="no_speech_timeout",
),
)
self._detection_timeout_timer = asyncio.get_running_loop().call_later(
self._timeout,
functools.partial(
self._silence_timer_callback,
self._on_timeout,
category=AMDCategory.UNCERTAIN,
reason="detection_timeout",
),
)

@_state_guard
def start_listening(self) -> None:
"""Open the input gate and arm the no-speech timer.

Call once we expect audible speech to begin (e.g. after sip answer
for outbound calls). Until this fires, all input methods are no-ops.
"""
if self._closed or self._listening:
return
self._listening = True
if self._no_speech_timer is None:
self._no_speech_timer = asyncio.get_running_loop().call_later(
self._no_speech_threshold,
functools.partial(
self._on_timeout,
category=AMDCategory.UNCERTAIN,
reason="no_speech_timeout",
),
)

@_listening_guard
def on_user_speech_started(self) -> None:
if self._silence_timer is not None:
self._silence_timer.cancel()
Expand All @@ -181,9 +188,10 @@ def on_user_speech_started(self) -> None:
self._no_speech_timer = None
if self._speech_started_at is None:
self._speech_started_at = time.time()
self._machine_silence_reached = False
self._silence_reached = False
self._eot_reached = False

@_state_guard
@_listening_guard
def on_user_speech_ended(self, silence_duration: float) -> None:
if self._speech_started_at is None:
logger.warning("on_user_speech_ended called before on_user_speech_started")
Expand All @@ -200,7 +208,7 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
self._silence_timer = asyncio.get_running_loop().call_later(
max(0, self._human_silence_threshold - silence_duration),
functools.partial(
self._silence_timer_callback,
self._on_timeout,
category=AMDCategory.HUMAN,
reason="short_greeting",
speech_duration=speech_duration,
Expand All @@ -210,10 +218,7 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
else:
self._silence_timer = asyncio.get_running_loop().call_later(
max(0, self._machine_silence_threshold - silence_duration),
functools.partial(
self._silence_timer_callback,
speech_duration=speech_duration,
),
self._on_silence_reached,
)
self._silence_timer_trigger = "long_speech"
return
Expand All @@ -227,42 +232,94 @@ def on_user_speech_ended(self, silence_duration: float) -> None:
self._silence_timer_trigger = None
self._silence_timer = asyncio.get_running_loop().call_later(
max(0, self._machine_silence_threshold - silence_duration),
functools.partial(self._silence_timer_callback, speech_duration=speech_duration),
self._on_silence_reached,
)
self._silence_timer_trigger = "long_speech"

def _set_verdict(self, result: AMDPredictionEvent) -> None:
self._verdict_result = result
self._try_emit_result()

def on_end_of_turn(self) -> None:
"""Signal that the turn detector has committed a positive end-of-turn.

For machine verdicts, both this signal and the post-speech silence
timer must fire before the verdict is emitted (whichever lands last
unblocks the wait).

For human/uncertain verdicts only the silence timer is required.
"""
if self._closed:
return
self._eot_reached = True
self._try_emit_result()

def _can_emit(self, verdict: AMDPredictionEvent) -> bool:
"""Both gates: post-speech silence is required for every verdict; eot
is additionally required for machine verdicts (humans emit as soon as
silence is confirmed so we can respond quickly)."""
if not self._silence_reached:
return False
return self._eot_reached if verdict.is_machine else True

def _try_emit_result(self) -> None:
if self._verdict_result is None:
return
if not self._machine_silence_reached:
return
if self._closed or self._emitted:
return
if not self._can_emit(self._verdict_result):
return
self._verdict_ready.set()
if self._detection_timeout_timer is not None:
self._detection_timeout_timer.cancel()
self._detection_timeout_timer = None
if self._no_speech_timer is not None:
self._no_speech_timer.cancel()
self._no_speech_timer = None

self._listening = False
self.emit("amd_prediction", self._verdict_result)
self._emitted = True

@log_exceptions(logger=logger)
@_state_guard
def _silence_timer_callback(
def _on_silence_reached(self) -> None:
"""Post-speech silence window elapsed. Flip the silence gate and try
to emit any verdict the classifier has already produced."""
if self._closed:
return
self._silence_timer = None
self._silence_timer_trigger = None
self._silence_reached = True
self._try_emit_result()

@log_exceptions(logger=logger)
def _on_timeout(
self,
category: NotGivenOr[AMDCategory] = NOT_GIVEN,
reason: NotGivenOr[str] = NOT_GIVEN,
category: AMDCategory,
reason: str,
speech_duration: float | None = None,
) -> None:
"""A timeout (detection budget, no-speech, short greeting) fired.
Synthesize a verdict if none exists and try to emit.

Not gated by ``_listening_guard``: detection_timeout must still fire
when the call never reaches listening (e.g. sip never answered).
"""
if self._closed:
return
if self._silence_timer:
self._silence_timer.cancel()
self._silence_timer = None
self._silence_timer_trigger = None

if is_given(category) and is_given(reason) and self._verdict_result is None:
self._silence_reached = True
has_speech = self._speech_started_at is not None or bool(self._transcript)
# if there is no speech, force eot so that both eot and timeout are satisfied
# to emit verdict
if not (self._wait_until_finished and has_speech):
self._eot_reached = True

if self._verdict_result is None:
self._set_verdict(
AMDPredictionEvent(
speech_duration=speech_duration or self.speech_duration,
Expand All @@ -272,11 +329,9 @@ def _silence_timer_callback(
delay=(time.time() - self._speech_ended_at) if self._speech_ended_at else 0.0,
)
)

self._machine_silence_reached = True
self._try_emit_result()

@_state_guard
@_listening_guard
def push_text(self, text: str, source: str = "stt") -> None:
"""Push transcript text to the AMD classifier."""
if self._input_ch.closed:
Expand All @@ -296,10 +351,7 @@ def push_text(self, text: str, source: str = "stt") -> None:
remaining = (self._speech_ended_at + self._machine_silence_threshold) - time.time()
self._silence_timer = asyncio.get_running_loop().call_later(
max(0, remaining),
functools.partial(
self._silence_timer_callback,
speech_duration=self.speech_duration,
),
self._on_silence_reached,
)
self._silence_timer_trigger = "long_speech"

Expand Down Expand Up @@ -357,7 +409,7 @@ def _on_postpone_elapsed() -> None:
# silence reached so any pending verdict (or one produced by the
# re-classification below) can emit instead of waiting on the
# detection timeout.
self._machine_silence_reached = True
self._silence_reached = True
if not self._input_ch.closed:
# re-trigger classification with the latest transcript; on the
# next run, postpone is unavailable once extensions are
Expand Down Expand Up @@ -422,11 +474,11 @@ async def close(self) -> None:
await aio.cancel_and_wait(self._classify_task)

self._closed = True
self._started = False
self._listening = False

@property
def started(self) -> bool:
return self._started
def listening(self) -> bool:
return self._listening

@property
def closed(self) -> bool:
Expand Down
Loading
Loading