Skip to content
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Literal

TTSModels = Literal[
"lightning-v2",
"lightning-v3.1",
"lightning_v3.1",
"lightning_v3.1_pro",
]

TTSEncoding = Literal[
"pcm",
"mp3",
"wav",
"mulaw",
"ulaw",
"alaw",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
from dataclasses import dataclass, replace
from typing import Any
Expand All @@ -39,6 +41,7 @@

NUM_CHANNELS = 1
SMALLEST_BASE_URL = "https://api.smallest.ai/waves/v1"
SMALLEST_WS_URL = "wss://api.smallest.ai/waves/v1/tts/live"


@dataclass
Expand All @@ -48,52 +51,51 @@ class _TTSOptions:
voice_id: str
sample_rate: int
speed: float
consistency: float
similarity: float
enhancement: float
language: LanguageCode
output_format: TTSEncoding | str
base_url: str
ws_url: str


class TTS(tts.TTS):
def __init__(
self,
*,
api_key: str | None = None,
model: TTSModels | str = "lightning-v3.1",
voice_id: str = "sophia",
model: TTSModels | str = "lightning_v3.1_pro",
voice_id: str | None = None,
sample_rate: int = 24000,
speed: float = 1.0,
consistency: float = 0.5,
similarity: float = 0,
enhancement: float = 1,
language: str = "en",
output_format: TTSEncoding | str = "pcm",
base_url: str = SMALLEST_BASE_URL,
ws_url: str = SMALLEST_WS_URL,
http_session: aiohttp.ClientSession | None = None,
) -> None:
"""
Create a new instance of Smallest AI Lightning TTS.

Args:
api_key: Your Smallest AI API key.
model: The TTS model to use. Use "lightning-v3.1" (default) for the latest
model with 80+ voices and ~100ms latency, or "lightning-v2" for the
previous generation.
voice_id: The voice ID to use for synthesis.
sample_rate: Sample rate for the audio output.
speed: Speed of the speech synthesis.
consistency: Consistency of the speech synthesis.
similarity: Similarity of the speech synthesis.
enhancement: Enhancement level for the speech synthesis.
language: Language of the text to be synthesized.
output_format: Output format of the audio.
base_url: Base URL for the Smallest AI API.
model: The TTS model to use. Use "lightning_v3.1" for the standard model with
217 voices across 12 languages, or "lightning_v3.1_pro" (default) for the
premium pool with curated American, British, and Indian voices at 44.1 kHz.
voice_id: The voice ID to use for synthesis. Defaults to "meher" for
"lightning_v3.1_pro" and "sophia" for all other models. Pro voices must be
paired with "lightning_v3.1_pro"; standard voices with "lightning_v3.1".
sample_rate: Sample rate for the audio output. Both models are natively 44.1 kHz;
supported rates are 8000, 16000, 24000, and 44100.
speed: Speed of the speech synthesis (0.5–2.0).
language: Language of the text to be synthesized. Use "auto" for automatic
detection and code-switching. Pro supports "en", "hi", and "auto" only.
output_format: Output format for HTTP synthesize() calls ("pcm", "mp3", "wav",
"ulaw", "alaw"). WebSocket streaming always returns PCM.
base_url: Base URL for the Smallest AI HTTP API.
ws_url: WebSocket URL for low-latency streaming synthesis.
http_session: An existing aiohttp ClientSession to use.
"""

super().__init__(
capabilities=tts.TTSCapabilities(streaming=False),
capabilities=tts.TTSCapabilities(streaming=True),
sample_rate=sample_rate,
num_channels=NUM_CHANNELS,
)
Expand All @@ -105,20 +107,27 @@ def __init__(
" SMALLEST_API_KEY environment variable"
)

if voice_id is None:
voice_id = "meher" if model == "lightning_v3.1_pro" else "sophia"

self._opts = _TTSOptions(
model=model,
api_key=api_key,
voice_id=voice_id,
sample_rate=sample_rate,
speed=speed,
consistency=consistency,
similarity=similarity,
enhancement=enhancement,
language=LanguageCode(language),
output_format=output_format,
base_url=base_url,
ws_url=ws_url,
)
self._session = http_session
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
connect_cb=self._connect_ws,
close_cb=self._close_ws,
max_session_duration=3600,
mark_refreshed_on_get=False,
)

@property
def model(self) -> str:
Expand All @@ -131,19 +140,31 @@ def provider(self) -> str:
def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()

return self._session

async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
return await asyncio.wait_for(
self._ensure_session().ws_connect(
self._opts.ws_url,
headers={
"Authorization": f"Bearer {self._opts.api_key}",
"X-Source": "livekit",
"X-LiveKit-Version": __version__,
},
),
timeout,
)

async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
await ws.close()

def update_options(
self,
*,
model: NotGivenOr[TTSModels | str] = NOT_GIVEN,
voice_id: NotGivenOr[str] = NOT_GIVEN,
speed: NotGivenOr[float] = NOT_GIVEN,
sample_rate: NotGivenOr[int] = NOT_GIVEN,
consistency: NotGivenOr[float] = NOT_GIVEN,
similarity: NotGivenOr[float] = NOT_GIVEN,
enhancement: NotGivenOr[float] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
output_format: NotGivenOr[TTSEncoding | str] = NOT_GIVEN,
) -> None:
Expand All @@ -156,12 +177,6 @@ def update_options(
self._opts.speed = speed
if is_given(sample_rate):
self._opts.sample_rate = sample_rate
if is_given(consistency):
self._opts.consistency = consistency
if is_given(similarity):
self._opts.similarity = similarity
if is_given(enhancement):
self._opts.enhancement = enhancement
if is_given(language):
self._opts.language = LanguageCode(language)
if is_given(output_format):
Expand All @@ -173,37 +188,43 @@ def synthesize(
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> ChunkedStream:
return ChunkedStream(
tts=self,
input_text=text,
conn_options=conn_options,
)
return ChunkedStream(tts=self, input_text=text, conn_options=conn_options)

def stream(
self,
*,
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> SynthesizeStream:
return SynthesizeStream(tts=self, conn_options=conn_options)

def prewarm(self) -> None:
self._pool.prewarm()

async def aclose(self) -> None:
await self._pool.aclose()


class ChunkedStream(tts.ChunkedStream):
"""Synthesize chunked text using the Smallest AI TTS endpoint."""
"""HTTP-based synthesis — used when synthesize() is called directly."""

def __init__(self, *, tts: TTS, input_text: str, conn_options: APIConnectOptions) -> None:
super().__init__(tts=tts, input_text=input_text, conn_options=conn_options)
self._tts: TTS = tts
self._opts = replace(tts._opts)

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
"""Run the chunked synthesis process."""
try:
data = _to_smallest_options(self._opts)
data["text"] = self._input_text

url = f"{self._opts.base_url}/{self._opts.model}/get_speech"

headers = {
"Authorization": f"Bearer {self._opts.api_key}",
"Content-Type": "application/json",
"X-Source": "livekit",
"X-LiveKit-Version": __version__,
}
async with self._tts._ensure_session().post(
url,
f"{self._opts.base_url}/tts",
headers=headers,
json=data,
timeout=aiohttp.ClientTimeout(total=self._conn_options.timeout),
Expand Down Expand Up @@ -234,13 +255,105 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
raise APIConnectionError() from e


class SynthesizeStream(tts.SynthesizeStream):
"""WebSocket-based streaming synthesis — primary path used by the agent pipeline."""

def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
super().__init__(tts=tts, conn_options=conn_options)
self._tts: TTS = tts
self._opts = replace(tts._opts)

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
output_emitter.initialize(
request_id=utils.shortuuid(),
sample_rate=self._opts.sample_rate,
num_channels=NUM_CHANNELS,
mime_type="audio/pcm",
stream=True,
)

try:
text_buffer = ""
async for data in self._input_ch:
if isinstance(data, str):
text_buffer += data
elif isinstance(data, self._FlushSentinel):
if text_buffer.strip():
await self._run_ws(text_buffer.strip(), output_emitter)
text_buffer = ""
except asyncio.TimeoutError:
raise APITimeoutError() from None
except aiohttp.ClientResponseError as e:
raise APIStatusError(
message=e.message, status_code=e.status, request_id=None, body=None
) from None
except APIStatusError:
raise
except Exception as e:
Comment on lines +289 to +292
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 SynthesizeStream._run catches APIStatusError but not APIConnectionError, causing error message loss

The exception handler at lines 289-292 catches APIStatusError and re-raises it, but does not catch APIConnectionError (or its parent APIError). When _run_ws raises APIConnectionError at livekit-plugins/livekit-plugins-smallestai/livekit/plugins/smallestai/tts.py:343-345 (e.g., "SmallestAI TTS error: {message}"), this falls through to the generic except Exception as e: raise APIConnectionError() from e handler, which wraps it in a new APIConnectionError("Connection error.") — discarding the specific error message. The established pattern (used by Deepgram at livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/tts.py:295) is to catch APIError (the common base class) rather than just APIStatusError.

Suggested change
) from None
except APIStatusError:
raise
except Exception as e:
except APIStatusError:
raise
except APIConnectionError:
raise
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

raise APIConnectionError() from e

async def _run_ws(self, text: str, output_emitter: tts.AudioEmitter) -> None:
segment_id = utils.shortuuid()
output_emitter.start_segment(segment_id=segment_id)

payload: dict[str, Any] = {
"model": self._opts.model,
"voice_id": self._opts.voice_id,
"text": text,
"sample_rate": self._opts.sample_rate,
"speed": self._opts.speed,
"language": self._opts.language.language
if isinstance(self._opts.language, LanguageCode)
else self._opts.language,
}

async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
self._acquire_time = self._tts._pool.last_acquire_time
self._connection_reused = self._tts._pool.last_connection_reused
self._mark_started()
await ws.send_str(json.dumps(payload))

while True:
msg = await ws.receive(timeout=self._conn_options.timeout)

if msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
):
raise APIStatusError(
"SmallestAI WebSocket closed unexpectedly",
status_code=ws.close_code or -1,
body=str(msg.data),
)

if msg.type != aiohttp.WSMsgType.TEXT:
continue

event = json.loads(msg.data)
status = event.get("status")

if status == "chunk":
audio_b64 = event.get("data", {}).get("audio")
if audio_b64:
output_emitter.push(base64.b64decode(audio_b64))
elif status == "complete":
output_emitter.end_segment()
break
elif status == "error":
raise APIConnectionError(
f"SmallestAI TTS error: {event.get('message', 'unknown error')}"
)


def _to_smallest_options(opts: _TTSOptions) -> dict[str, Any]:
base_keys = ["voice_id", "sample_rate", "speed", "language", "output_format"]
# consistency, similarity, enhancement are lightning-v2 only params
extra_keys = ["consistency", "similarity", "enhancement"]

keys = base_keys + extra_keys if opts.model == "lightning-v2" else base_keys
result = {key: getattr(opts, key) for key in keys}
if "language" in result and isinstance(result["language"], LanguageCode):
result["language"] = result["language"].language
return result
return {
"model": opts.model,
"voice_id": opts.voice_id,
"sample_rate": opts.sample_rate,
"speed": opts.speed,
"language": opts.language.language
if isinstance(opts.language, LanguageCode)
else opts.language,
"output_format": opts.output_format,
}
Loading