Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Literal

TTSModels = Literal[
"lightning-v2",
"lightning-v3.1",
"lightning_v3.1",
"lightning_v3.1_pro",
]

TTSEncoding = Literal[
"pcm",
"mp3",
"wav",
"mulaw",
"ulaw",
"alaw",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ class _TTSOptions:
voice_id: str
sample_rate: int
speed: float
consistency: float
similarity: float
enhancement: float
language: LanguageCode
output_format: TTSEncoding | str
base_url: str
Expand All @@ -61,13 +58,10 @@ def __init__(
self,
*,
api_key: str | None = None,
model: TTSModels | str = "lightning-v3.1",
voice_id: str = "sophia",
model: TTSModels | str = "lightning_v3.1",
voice_id: str | None = None,
sample_rate: int = 24000,
speed: float = 1.0,
consistency: float = 0.5,
similarity: float = 0,
enhancement: float = 1,
language: str = "en",
output_format: TTSEncoding | str = "pcm",
base_url: str = SMALLEST_BASE_URL,
Expand All @@ -77,17 +71,18 @@ def __init__(
Create a new instance of Smallest AI Lightning TTS.
Args:
api_key: Your Smallest AI API key.
model: The TTS model to use. Use "lightning-v3.1" (default) for the latest
model with 80+ voices and ~100ms latency, or "lightning-v2" for the
previous generation.
voice_id: The voice ID to use for synthesis.
sample_rate: Sample rate for the audio output.
speed: Speed of the speech synthesis.
consistency: Consistency of the speech synthesis.
similarity: Similarity of the speech synthesis.
enhancement: Enhancement level for the speech synthesis.
language: Language of the text to be synthesized.
output_format: Output format of the audio.
model: The TTS model to use. Use "lightning_v3.1" (default) for the standard
model with 217 voices across 12 languages, or "lightning_v3.1_pro" for the
premium pool with curated American, British, and Indian voices at 44.1 kHz.
Comment on lines +74 to +76
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Docstring incorrectly claims "lightning_v3.1" is the default model

The docstring at tts.py:74 states `Use "lightning_v3.1" (default)`, but the actual parameter default at tts.py:61 is "lightning_v3.1_pro". Users reading the documentation will be misled into thinking the standard model is the default, when it is actually the pro model.

Suggested change
model: The TTS model to use. Use "lightning_v3.1" (default) for the standard
model with 217 voices across 12 languages, or "lightning_v3.1_pro" for the
premium pool with curated American, British, and Indian voices at 44.1 kHz.
model: The TTS model to use. Use "lightning_v3.1_pro" (default) for the
premium pool with curated American, British, and Indian voices at 44.1 kHz,
or "lightning_v3.1" for the standard model with 217 voices across 12 languages.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

voice_id: The voice ID to use for synthesis. Defaults to "meher" for
"lightning_v3.1_pro" and "sophia" for all other models. Pro voices must be
paired with "lightning_v3.1_pro"; standard voices with "lightning_v3.1".
sample_rate: Sample rate for the audio output. Both models are natively 44.1 kHz;
supported rates are 8000, 16000, 24000, and 44100.
speed: Speed of the speech synthesis (0.5–2.0).
language: Language of the text to be synthesized. Use "auto" for automatic
detection and code-switching. Pro supports "en", "hi", and "auto" only.
output_format: Output format of the audio ("pcm", "mp3", "wav", "ulaw", "alaw").
base_url: Base URL for the Smallest AI API.
http_session: An existing aiohttp ClientSession to use.
"""
Expand All @@ -105,15 +100,15 @@ def __init__(
" SMALLEST_API_KEY environment variable"
)

if voice_id is None:
voice_id = "meher" if model == "lightning_v3.1_pro" else "sophia"

self._opts = _TTSOptions(
model=model,
api_key=api_key,
voice_id=voice_id,
sample_rate=sample_rate,
speed=speed,
consistency=consistency,
similarity=similarity,
enhancement=enhancement,
language=LanguageCode(language),
output_format=output_format,
base_url=base_url,
Expand Down Expand Up @@ -141,9 +136,6 @@ def update_options(
voice_id: NotGivenOr[str] = NOT_GIVEN,
speed: NotGivenOr[float] = NOT_GIVEN,
sample_rate: NotGivenOr[int] = NOT_GIVEN,
consistency: NotGivenOr[float] = NOT_GIVEN,
similarity: NotGivenOr[float] = NOT_GIVEN,
enhancement: NotGivenOr[float] = NOT_GIVEN,
language: NotGivenOr[str] = NOT_GIVEN,
output_format: NotGivenOr[TTSEncoding | str] = NOT_GIVEN,
) -> None:
Expand All @@ -156,12 +148,6 @@ def update_options(
self._opts.speed = speed
if is_given(sample_rate):
self._opts.sample_rate = sample_rate
if is_given(consistency):
self._opts.consistency = consistency
if is_given(similarity):
self._opts.similarity = similarity
if is_given(enhancement):
self._opts.enhancement = enhancement
if is_given(language):
self._opts.language = LanguageCode(language)
if is_given(output_format):
Expand Down Expand Up @@ -194,7 +180,7 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
data = _to_smallest_options(self._opts)
data["text"] = self._input_text

url = f"{self._opts.base_url}/{self._opts.model}/get_speech"
url = f"{self._opts.base_url}/tts"

headers = {
"Authorization": f"Bearer {self._opts.api_key}",
Expand Down Expand Up @@ -235,12 +221,12 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:


def _to_smallest_options(opts: _TTSOptions) -> dict[str, Any]:
base_keys = ["voice_id", "sample_rate", "speed", "language", "output_format"]
# consistency, similarity, enhancement are lightning-v2 only params
extra_keys = ["consistency", "similarity", "enhancement"]

keys = base_keys + extra_keys if opts.model == "lightning-v2" else base_keys
result = {key: getattr(opts, key) for key in keys}
if "language" in result and isinstance(result["language"], LanguageCode):
result["language"] = result["language"].language
result = {
"model": opts.model,
"voice_id": opts.voice_id,
"sample_rate": opts.sample_rate,
"speed": opts.speed,
"language": opts.language.language if isinstance(opts.language, LanguageCode) else opts.language,
"output_format": opts.output_format,
}
return result
220 changes: 220 additions & 0 deletions tests/test_plugin_smallestai_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
"""Unit tests for SmallestAI TTS plugin — Lightning v3.1 / v3.1 Pro."""
from __future__ import annotations

import asyncio

Check failure on line 4 in tests/test_plugin_smallestai_tts.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

tests/test_plugin_smallestai_tts.py:4:8: F401 `asyncio` imported but unused help: Remove unused import: `asyncio`
import json

Check failure on line 5 in tests/test_plugin_smallestai_tts.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

tests/test_plugin_smallestai_tts.py:5:8: F401 `json` imported but unused help: Remove unused import: `json`
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp

Check failure on line 8 in tests/test_plugin_smallestai_tts.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (F401)

tests/test_plugin_smallestai_tts.py:8:8: F401 `aiohttp` imported but unused help: Remove unused import: `aiohttp`
import pytest

from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS
from livekit.plugins.smallestai import TTS
from livekit.plugins.smallestai.tts import ChunkedStream, _to_smallest_options


def _make_mock_post(captured: dict):
    """Build a stand-in for ``session.post`` that records what it was called with.

    The returned callable is synchronous. Each call stores the request URL and
    JSON payload in ``captured`` and returns a mock usable with ``async with``:
    its ``__aenter__`` resolves to a fake response whose ``status`` is 200 and
    whose ``content.iter_chunks()`` is an async generator yielding one
    ``(bytes, False)`` pair.

    Args:
        captured: Dict mutated on every call; receives the URL under ``"url"``
            and the ``json`` keyword value under ``"body"``.

    Returns:
        The ``post`` replacement described above.
    """

    async def iter_chunks():
        yield b"\x00\x01", False

    # NOTE(review): the keyword is called ``json`` even though it shadows the
    # module-level ``import json`` (ruff F811). Presumably the code under test
    # passes the payload as ``json=...`` -- confirm against the caller before
    # renaming; dropping the unused import is the safer way to clear the lint.
    def post(url, *, headers, json, timeout):
        captured["url"] = url
        captured["body"] = json

        mock_resp = MagicMock()
        mock_resp.status = 200
        mock_resp.content = MagicMock()
        mock_resp.content.iter_chunks = iter_chunks

        # Async context-manager protocol: entering yields the fake response,
        # exiting returns False so exceptions are not suppressed.
        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=mock_resp)
        ctx.__aexit__ = AsyncMock(return_value=False)
        return ctx

    return post


# ---------------------------------------------------------------------------
# Construction and defaults
# ---------------------------------------------------------------------------


def test_default_model_is_v31():
    """A TTS built without a ``model`` argument reports "lightning_v3.1"."""
    engine = TTS(api_key="test-key")
    chosen_model = engine._opts.model
    assert chosen_model == "lightning_v3.1"


def test_pro_model_accepted():
    """Passing the pro model name stores it unchanged on the options."""
    pro_model = "lightning_v3.1_pro"
    engine = TTS(api_key="test-key", model=pro_model)
    assert engine._opts.model == "lightning_v3.1_pro"


def test_lightning_v2_not_a_default():
    """The default model name must not contain the substring "v2"."""
    default_model = TTS(api_key="test-key")._opts.model
    assert "v2" not in default_model


def test_default_voice_is_sophia_for_standard_model():
    """With the standard model and no voice given, the voice is "sophia"."""
    engine = TTS(api_key="test-key", model="lightning_v3.1")
    picked_voice = engine._opts.voice_id
    assert picked_voice == "sophia"


def test_default_voice_is_meher_for_pro_model():
    """With the pro model and no voice given, the voice is "meher"."""
    engine = TTS(api_key="test-key", model="lightning_v3.1_pro")
    picked_voice = engine._opts.voice_id
    assert picked_voice == "meher"


def test_explicit_voice_overrides_default_for_pro():
    """An explicit ``voice_id`` is kept as given when the pro model is selected."""
    engine = TTS(
        api_key="test-key",
        model="lightning_v3.1_pro",
        voice_id="rhea",
    )
    assert engine._opts.voice_id == "rhea"


def test_explicit_voice_overrides_default_for_standard():
    """An explicit ``voice_id`` is kept as given when the standard model is selected."""
    engine = TTS(
        api_key="test-key",
        model="lightning_v3.1",
        voice_id="aria",
    )
    assert engine._opts.voice_id == "aria"


def test_no_consistency_similarity_enhancement_attrs():
    """The options object exposes none of the three removed tuning attributes."""
    opts = TTS(api_key="test-key")._opts
    # Checked in the same order as the original assertions so the first
    # failure reported is the same.
    for removed_attr in ("consistency", "similarity", "enhancement"):
        assert not hasattr(opts, removed_attr)


def test_constructor_rejects_unknown_kwargs():
    """Passing the removed ``consistency`` keyword to the constructor raises TypeError."""
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        TTS(api_key="test-key", consistency=0.5)  # type: ignore[call-arg]


def test_missing_api_key_raises():
    """Constructing TTS with no key argument and an empty environment raises ValueError.

    The error message is expected to match the pattern "API key".
    """
    # clear=True empties os.environ for the duration of the block (and restores
    # it afterwards), so SMALLEST_API_KEY is already absent -- no separate
    # os.environ.pop is needed.
    with patch.dict("os.environ", {}, clear=True), pytest.raises(ValueError, match="API key"):
        TTS()


def test_api_key_from_env():
    """With no ``api_key`` argument, the key is taken from SMALLEST_API_KEY."""
    fake_env = {"SMALLEST_API_KEY": "env-key"}
    with patch.dict("os.environ", fake_env):
        engine = TTS()
    assert engine._opts.api_key == "env-key"


# ---------------------------------------------------------------------------
# _to_smallest_options — request body shape
# ---------------------------------------------------------------------------


def test_request_body_includes_model():
    """The serialized request body carries the model name under "model"."""
    engine = TTS(api_key="test-key", model="lightning_v3.1_pro")
    payload = _to_smallest_options(engine._opts)
    assert payload["model"] == "lightning_v3.1_pro"


def test_request_body_standard_fields():
    """Voice, sample rate, speed, language and output format reach the request body."""
    engine = TTS(
        api_key="test-key",
        voice_id="meher",
        sample_rate=44100,
        speed=1.2,
        language="hi",
    )
    payload = _to_smallest_options(engine._opts)

    # Insertion order mirrors the original assertion order.
    expected_fields = {
        "voice_id": "meher",
        "sample_rate": 44100,
        "speed": 1.2,
        "language": "hi",
        "output_format": "pcm",
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value


def test_request_body_no_v2_only_fields():
    """The request body contains none of the three removed tuning keys."""
    payload = _to_smallest_options(TTS(api_key="test-key")._opts)
    for removed_key in ("consistency", "similarity", "enhancement"):
        assert removed_key not in payload


def test_language_code_serialized_as_string():
    """The "language" entry of the request body is the plain string "en"."""
    engine = TTS(api_key="test-key", language="en")
    serialized_language = _to_smallest_options(engine._opts)["language"]
    assert isinstance(serialized_language, str)
    assert serialized_language == "en"


# ---------------------------------------------------------------------------
# update_options
# ---------------------------------------------------------------------------


def test_update_options_model():
    """``update_options(model=...)`` replaces the stored model name."""
    engine = TTS(api_key="test-key")
    engine.update_options(model="lightning_v3.1_pro")
    updated_model = engine._opts.model
    assert updated_model == "lightning_v3.1_pro"


def test_update_options_voice_speed_language():
    """Voice, speed and language passed to ``update_options`` are all stored."""
    engine = TTS(api_key="test-key")
    changes = {"voice_id": "cressida", "speed": 1.5, "language": "hi"}
    engine.update_options(**changes)

    assert engine._opts.voice_id == "cressida"
    assert engine._opts.speed == 1.5
    # Language is compared through its ``.language`` attribute, not as a raw string.
    assert engine._opts.language.language == "hi"


def test_update_options_no_consistency_param():
    """Passing the removed ``consistency`` keyword to ``update_options`` raises TypeError."""
    engine = TTS(api_key="test-key")
    expect_type_error = pytest.raises(TypeError)
    with expect_type_error:
        engine.update_options(consistency=0.8)  # type: ignore[call-arg]


# ---------------------------------------------------------------------------
# Endpoint URL — must be /tts, not /{model}/get_speech
# ---------------------------------------------------------------------------


async def test_synthesize_hits_unified_tts_endpoint():
    """Run one synthesis against a fake session and check the URL and JSON body.

    The recorded URL must end in ``/tts`` and must not contain ``get_speech``;
    the recorded body must carry the model, the input text and the voice.
    """
    engine = TTS(api_key="test-key", model="lightning_v3.1_pro", voice_id="meher")

    # Swap in a fake HTTP session whose post() records its arguments.
    captured: dict = {}
    fake_session = MagicMock()
    fake_session.post = _make_mock_post(captured)
    engine._session = fake_session

    chunked = ChunkedStream(
        tts=engine,
        input_text="Hello",
        conn_options=DEFAULT_API_CONNECT_OPTIONS,
    )
    emitter = MagicMock()
    for emitter_method in ("initialize", "push", "flush"):
        setattr(emitter, emitter_method, MagicMock())

    try:
        await chunked._run(emitter)
    finally:
        # Always close the stream, even when _run raises.
        await chunked.aclose()

    assert captured["url"].endswith("/tts"), f"Expected /tts endpoint, got: {captured['url']}"
    assert "get_speech" not in captured["url"], "Old /{model}/get_speech path must not be used"

    sent_body = captured["body"]
    assert sent_body["model"] == "lightning_v3.1_pro"
    assert sent_body["text"] == "Hello"
    assert sent_body["voice_id"] == "meher"


async def test_synthesize_standard_model_also_uses_unified_endpoint():
    """Run one synthesis with the standard model and check the URL ends in ``/tts``."""
    engine = TTS(api_key="test-key", model="lightning_v3.1")

    # Swap in a fake HTTP session whose post() records its arguments.
    captured: dict = {}
    fake_session = MagicMock()
    fake_session.post = _make_mock_post(captured)
    engine._session = fake_session

    chunked = ChunkedStream(
        tts=engine,
        input_text="Test",
        conn_options=DEFAULT_API_CONNECT_OPTIONS,
    )
    emitter = MagicMock()
    for emitter_method in ("initialize", "push", "flush"):
        setattr(emitter, emitter_method, MagicMock())

    try:
        await chunked._run(emitter)
    finally:
        # Always close the stream, even when _run raises.
        await chunked.aclose()

    requested_url = captured["url"]
    assert requested_url.endswith("/tts")
Loading