Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def __init__(

self._request_id = ""
self._reconnect_event = asyncio.Event()
self._ws: aiohttp.ClientWebSocketResponse | None = None

def update_options(
self,
Expand All @@ -313,6 +314,16 @@ def update_options(
# deprecated
keyterms: NotGivenOr[list[str]] = NOT_GIVEN,
) -> None:
if is_given(keyterms):
logger.warning(
"`keyterms` is deprecated, use `keyterm` instead for consistency with Deepgram API."
)
keyterm = keyterms

requires_reconnect = (
is_given(model) or is_given(sample_rate) or is_given(mip_opt_out) or is_given(endpoint_url) or is_given(tags)
)

if is_given(model):
self._opts.model = model
if is_given(sample_rate):
Expand All @@ -321,11 +332,6 @@ def update_options(
self._opts.eot_threshold = eot_threshold
if is_given(eot_timeout_ms):
self._opts.eot_timeout_ms = eot_timeout_ms
if is_given(keyterms):
logger.warning(
"`keyterms` is deprecated, use `keyterm` instead for consistency with Deepgram API."
)
keyterm = keyterms
if is_given(keyterm):
self._opts.keyterm = keyterm
if is_given(mip_opt_out):
Expand All @@ -339,7 +345,59 @@ def update_options(
if is_given(eager_eot_threshold):
self._opts.eager_eot_threshold = eager_eot_threshold

self._reconnect_event.set()
if requires_reconnect:
self._reconnect_event.set()
elif self._ws is not None and not self._ws.closed:
self._send_configure(
keyterm=keyterm,
eot_threshold=eot_threshold,
eot_timeout_ms=eot_timeout_ms,
eager_eot_threshold=eager_eot_threshold,
language_hint=language_hint,
)
else:
self._reconnect_event.set()

def _send_configure(
self,
*,
keyterm: NotGivenOr[str | list[str]] = NOT_GIVEN,
eot_threshold: NotGivenOr[float] = NOT_GIVEN,
eot_timeout_ms: NotGivenOr[int] = NOT_GIVEN,
eager_eot_threshold: NotGivenOr[float] = NOT_GIVEN,
language_hint: NotGivenOr[list[str]] = NOT_GIVEN,
) -> None:
"""Send a Configure control message to update settings mid-stream without reconnecting."""
configure_msg: dict[str, Any] = {"type": "Configure"}

if is_given(keyterm):
terms = [keyterm] if isinstance(keyterm, str) else list(keyterm)
configure_msg["keyterms"] = terms

thresholds: dict[str, Any] = {}
if is_given(eot_threshold):
thresholds["eot_threshold"] = eot_threshold
if is_given(eot_timeout_ms):
thresholds["eot_timeout_ms"] = eot_timeout_ms
if is_given(eager_eot_threshold):
thresholds["eager_eot_threshold"] = eager_eot_threshold
if thresholds:
configure_msg["thresholds"] = thresholds

if is_given(language_hint):
configure_msg["language_hints"] = language_hint

if len(configure_msg) <= 1:
return

asyncio.ensure_future(self._do_send_configure(json.dumps(configure_msg)))

async def _do_send_configure(self, msg_str: str) -> None:
try:
if self._ws is not None and not self._ws.closed:
await self._ws.send_str(msg_str)
except Exception:
logger.debug("failed to send Configure message, ws may be closing")

async def _run(self) -> None:
closing_ws = False
Expand Down Expand Up @@ -424,6 +482,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
while True:
try:
ws = await self._connect_ws()
self._ws = ws
tasks = [
asyncio.create_task(send_task(ws)),
asyncio.create_task(recv_task(ws)),
Expand Down Expand Up @@ -451,6 +510,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
tasks_group.cancel()
tasks_group.exception() # retrieve the exception
finally:
self._ws = None
if ws is not None:
await ws.close()

Expand Down
Loading