Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions livekit-agents/livekit/agents/utils/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,76 @@ def on_connection_state_changed(state: int) -> None:
room.off("connection_state_changed", on_connection_state_changed)


class ParticipantAttributeWaitAborted(RuntimeError):
"""Raised by :func:`wait_for_participant_attribute` when the wait cannot
complete (room/participant disconnected or never present)."""
Comment on lines +71 to +73
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just use RuntimeError?



async def wait_for_participant_attribute(
room: rtc.Room,
*,
identity: str,
attribute: str,
value: str,
) -> None:
"""Wait until a remote participant's attribute equals ``value``.

Returns immediately if the attribute is already set. Raises
:class:`ParticipantAttributeWaitAborted` if the room is not connected, the
participant is not present, the participant disconnects, or the room
disconnects before the attribute is set.
"""
if not room.isconnected():
raise ParticipantAttributeWaitAborted("room is not connected")
if identity not in room.remote_participants:
raise ParticipantAttributeWaitAborted(f"participant {identity!r} is not in the room")

fut: asyncio.Future[None] = asyncio.Future()

def _is_match(p: rtc.Participant) -> bool:
return (
isinstance(p, rtc.RemoteParticipant)
and p.identity == identity
and p.attributes.get(attribute) == value
)

def _on_attributes_changed(_changed: list[str], p: rtc.Participant) -> None:
if _is_match(p) and not fut.done():
fut.set_result(None)

def _on_participant_disconnected(p: rtc.RemoteParticipant) -> None:
if p.identity == identity and not fut.done():
fut.set_exception(
ParticipantAttributeWaitAborted(
f"participant {identity!r} disconnected while waiting for {attribute}"
)
)

def _on_connection_state_changed(state: int) -> None:
if state == rtc.ConnectionState.CONN_DISCONNECTED and not fut.done():
fut.set_exception(
ParticipantAttributeWaitAborted(
f"room disconnected while waiting for {identity!r} {attribute}"
)
)

room.on("participant_attributes_changed", _on_attributes_changed)
room.on("participant_disconnected", _on_participant_disconnected)
room.on("connection_state_changed", _on_connection_state_changed)

try:
# check after registering so an attribute set between check and subscribe
# cannot slip past
existing = room.remote_participants.get(identity)
if existing and existing.attributes.get(attribute) == value:
return
await fut
finally:
room.off("participant_attributes_changed", _on_attributes_changed)
room.off("participant_disconnected", _on_participant_disconnected)
room.off("connection_state_changed", _on_connection_state_changed)


@overload
async def wait_for_participant(
room: rtc.Room,
Expand Down
34 changes: 23 additions & 11 deletions livekit-agents/livekit/agents/voice/amd/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,18 @@ def __init__(
self._extension_count = 0

def start(self) -> None:
"""Mark classifier as started (enables state guard). Call start_timers() separately."""
"""Mark classifier as started (enables state guard). Arm timers via
:meth:`start_detection_timer` and :meth:`start_no_speech_timer`."""
if self._started:
return
self._started = True

def start_timers(self) -> None:
"""Start the no-speech and detection-timeout timers. Call after start()."""
def start_detection_timer(self) -> None:
"""Arm the overall detection-timeout budget. Call after start()."""
if not self._started or self._closed:
return
self._no_speech_timer = asyncio.get_running_loop().call_later(
self._no_speech_threshold,
functools.partial(
self._silence_timer_callback,
category=AMDCategory.MACHINE_UNAVAILABLE,
reason="no_speech_timeout",
),
)
if self._detection_timeout_timer is not None:
return
self._detection_timeout_timer = asyncio.get_running_loop().call_later(
self._timeout,
functools.partial(
Expand All @@ -170,6 +165,23 @@ def start_timers(self) -> None:
),
)

def start_no_speech_timer(self) -> None:
"""Arm the no-speech timer. Call once we expect audible speech to begin
(e.g. after SIP answer for outbound calls). Typically paired with
:meth:`start_detection_timer` to bound overall detection latency."""
if not self._started or self._closed:
return
if self._no_speech_timer is not None:
return
self._no_speech_timer = asyncio.get_running_loop().call_later(
self._no_speech_threshold,
functools.partial(
self._silence_timer_callback,
category=AMDCategory.MACHINE_UNAVAILABLE,
reason="no_speech_timeout",
),
)

@_state_guard
def on_user_speech_started(self) -> None:
if self._silence_timer is not None:
Expand Down
115 changes: 94 additions & 21 deletions livekit-agents/livekit/agents/voice/amd/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from ...types import NOT_GIVEN, NotGivenOr
from ...utils import EventEmitter, aio, is_given
from ...utils.misc import is_cloud
from ...utils.participant import wait_for_track_publication
from ...utils.participant import (
ParticipantAttributeWaitAborted,
wait_for_participant_attribute,
wait_for_track_publication,
)
from .classifier import (
AMD_PROMPT,
HUMAN_SILENCE_THRESHOLD,
Expand Down Expand Up @@ -57,6 +61,9 @@
"cartesia/ink-whisper",
}

_SIP_CALL_STATUS_ATTR = "sip.callStatus"
_SIP_CALL_STATUS_ACTIVE = "active"


class DetectionOptions(TypedDict, total=False):
human_speech_threshold: float
Expand Down Expand Up @@ -91,8 +98,13 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
- ``machine-unavailable``: the mailbox is full or not set up; leaving a message is not possible.
- ``uncertain``: the transcript is ambiguous and could not be classified.

AMD should be started before the SIP participant is created so no audio is
missed. Timers begin when the participant's audio track is subscribed.
amd should be started before the SIP participant is created so no audio
is missed. The overall detection-timeout budget starts when the
participant's audio track is subscribed (so amd cannot hang if the call
never connects). For SIP participants, the no-speech timer and
audio/transcript processing are deferred until ``sip.callStatus ==
"active"`` so pre-answer audio (ringback, carrier early media, dialtone)
does not poison the classifier or burn the no-speech budget.

The recommended pattern is the async context manager::

Expand All @@ -114,9 +126,9 @@ class AMD(EventEmitter[Literal["amd_prediction"]]):
agent speech immediately when a machine is detected.
ivr_detection: If ``True`` (default), automatically start IVR
navigation when a ``machine-ivr`` result is returned.
participant_identity: If set, only this participant's audio track
subscription triggers the detection timers. If omitted, the first
remote audio track wins.
participant_identity: If set, amd listens only to this participant's
audio track. If omitted, the first remote audio track wins and
the publisher is resolved from the track sid.
stt: STT used for transcript generation. Accepts an :class:`STT`
instance or an inference model string (e.g.
``"cartesia/ink-whisper"``). When omitted, AMD auto-selects:
Expand Down Expand Up @@ -188,8 +200,10 @@ def __init__(
model_kind="stt",
)

self._stt_task: asyncio.Task[None] | None = None
self._setup_task: asyncio.Task[None] | None = None
self._sip_answer_task: asyncio.Task[None] | None = None
self._audio_ch: aio.Chan[rtc.AudioFrame] | None = None
self._listening: bool = False

@property
def enabled(self) -> bool:
Expand Down Expand Up @@ -248,18 +262,26 @@ async def __aexit__(
# region: lifecycle hooks (called by AudioRecognition)

def push_audio(self, frame: rtc.AudioFrame) -> None:
if not self._listening:
return
if self._audio_ch and not self._audio_ch.closed and self._classifier:
self._audio_ch.send_nowait(frame)

def _on_user_speech_started(self) -> None:
if not self._listening:
return
if self._classifier:
self._classifier.on_user_speech_started()

def _on_user_speech_ended(self, silence_duration: float) -> None:
if not self._listening:
return
if self._classifier:
self._classifier.on_user_speech_ended(silence_duration)

def _on_transcript(self, text: str) -> None:
if not self._listening:
return
if self._classifier:
self._classifier.push_text(text)

Expand All @@ -268,13 +290,14 @@ async def aclose(self) -> None:
return
self._closed = True

if self._stt_task:
self._stt_task.cancel()
try:
await self._stt_task
except asyncio.CancelledError:
pass
self._stt_task = None
pending = [t for t in (self._sip_answer_task, self._setup_task) if t is not None]
if pending:
await aio.cancel_and_wait(*pending)
self._sip_answer_task = None
self._setup_task = None

if self._audio_ch and not self._audio_ch.closed:
self._audio_ch.close()

if self._classifier:
self._classifier.off("amd_prediction", self._on_amd_prediction)
Expand Down Expand Up @@ -321,13 +344,13 @@ async def _run(self, session: AgentSession) -> None:

session._amd = self

# start the classifier first and the timers later when the track is subscribed
# start the classifier first; timers are armed later in _setup
self._classifier.start()
self._start_span()
if session._activity:
session._activity._pause_authorization()

self._stt_task = asyncio.create_task(self._setup(session), name="amd_setup")
self._setup_task = asyncio.create_task(self._setup(session), name="amd_setup")

async def _setup(self, session: AgentSession) -> None:
if self._closed:
Expand All @@ -337,21 +360,71 @@ async def _setup(self, session: AgentSession) -> None:
"session room_io unavailable, starting amd timers immediately as fallback"
)
if self._classifier:
self._classifier.start_timers()
self._classifier.start_detection_timer()
self._classifier.start_no_speech_timer()
self._listening = True
else:
await wait_for_track_publication(
room=session._room_io.room,
room = session._room_io.room
publication = await wait_for_track_publication(
room=room,
identity=self._participant_identity or None,
kind=rtc.TrackKind.KIND_AUDIO,
wait_for_subscription=True,
)
if not self._closed and self._classifier:
self._classifier.start_timers()
if self._closed or not self._classifier:
return
# outer budget runs from track-up so amd bails out even if the
# call never reaches the active state
self._classifier.start_detection_timer()

if self._participant_identity:
publisher = room.remote_participants.get(self._participant_identity)
else:
publisher = next(
(
p
for p in room.remote_participants.values()
if publication.sid in p.track_publications
),
None,
)
if publisher is None:
# publisher already gone; outer detection_timeout will resolve amd
return
if publisher.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP:
self._sip_answer_task = asyncio.create_task(
self._wait_for_sip_answer(room, publisher.identity),
name="amd_sip_answer",
)
else:
self._start_listening()

if is_given(self._stt) and not self._closed:
logger.debug("starting amd stt pipeline")
await self._run_stt()

def _start_listening(self) -> None:
if self._closed or not self._classifier:
return
self._classifier.start_no_speech_timer()
self._listening = True

async def _wait_for_sip_answer(self, room: rtc.Room, identity: str) -> None:
try:
await wait_for_participant_attribute(
room,
identity=identity,
attribute=_SIP_CALL_STATUS_ATTR,
value=_SIP_CALL_STATUS_ACTIVE,
)
except ParticipantAttributeWaitAborted as e:
# participant dropped or room disconnected before answer — outer
# timeout will resolve amd with detection_timeout
logger.debug("amd: sip answer wait aborted", extra={"reason": str(e)})
return
if not self._closed:
self._start_listening()

async def _run_stt(self) -> None:
assert is_given(self._stt)
assert self._classifier
Expand Down