-
Notifications
You must be signed in to change notification settings - Fork 3.2k
feat: add Gemini 3.1 flash TTS support, implement streaming response … #6134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fc07ea8
277d6b6
17f2b58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| import logging | ||
| import os | ||
| from dotenv import load_dotenv | ||
|
|
||
| from livekit.agents import ( | ||
| Agent, | ||
| AgentServer, | ||
| AgentSession, | ||
| JobContext, | ||
| cli, | ||
| ) | ||
| from livekit.plugins import deepgram, google | ||
| from livekit.plugins.google.beta import GeminiTTS | ||
|
|
||
| logger = logging.getLogger("gemini-tts-agent") | ||
| load_dotenv() | ||
|
|
||
class GeminiTTSAgent(Agent):
    """Example voice agent with a fixed persona, used to demo Gemini TTS output."""

    def __init__(self) -> None:
        """Initialise the base agent with the persona's system instructions."""
        persona = "Your name is Kelly. Respond briefly and concisely using voice conversation."
        super().__init__(instructions=persona)

    async def on_enter(self) -> None:
        """Request an opening reply from the session that greets the user."""
        greeting_prompt = "greet the user and introduce yourself"
        self.session.generate_reply(instructions=greeting_prompt)
|
|
||
| server = AgentServer() | ||
|
|
||
@server.rtc_session()
async def entrypoint(ctx: JobContext) -> None:
    """Build the STT/LLM/TTS pipeline and start the agent in the job's room.

    Args:
        ctx: Job context; only ``ctx.room`` is read here.
    """
    # Components are created in the same order the session receives them.
    speech_to_text = deepgram.STT()
    language_model = google.LLM(model="gemini-2.5-flash")
    text_to_speech = GeminiTTS(
        api_key=os.environ.get("GOOGLE_API_KEY"),
        voice_name="Kore",
        model="gemini-3.1-flash-tts-preview",
    )

    session = AgentSession(
        stt=speech_to_text,
        llm=language_model,
        tts=text_to_speech,
    )
    await session.start(agent=GeminiTTSAgent(), room=ctx.room)
|
|
||
| if __name__ == "__main__": | ||
| cli.run_app(server) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| ) | ||
| from livekit.agents.utils import is_given | ||
|
|
||
| GEMINI_TTS_MODELS = Literal["gemini-2.5-flash-preview-tts", "gemini-2.5-pro-preview-tts"] | ||
| GEMINI_TTS_MODELS = Literal["gemini-2.5-flash-preview-tts", "gemini-2.5-pro-preview-tts", "gemini-3.1-flash-tts-preview"] | ||
| GEMINI_VOICES = Literal[ | ||
| "Zephyr", | ||
| "Puck", | ||
|
|
@@ -49,7 +49,7 @@ | |
| "Sulafat", | ||
| ] | ||
|
|
||
| DEFAULT_MODEL = "gemini-2.5-flash-preview-tts" | ||
| DEFAULT_MODEL = "gemini-3.1-flash-tts-preview" | ||
|
Contributor
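There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🚩 Default model changed to 'gemini-3.1-flash-tts-preview' — breaking change for existing users. `DEFAULT_MODEL` now points at "gemini-3.1-flash-tts-preview" instead of "gemini-2.5-flash-preview-tts", so callers that do not pass `model` explicitly will switch models when they upgrade. Was this helpful? React with 👍 or 👎 to provide feedback. |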
||
| DEFAULT_VOICE = "Kore" | ||
| DEFAULT_SAMPLE_RATE = 24000 # not configurable | ||
| NUM_CHANNELS = 1 | ||
|
|
@@ -87,7 +87,7 @@ def __init__( | |
| - For Google Gemini API: Set the `api_key` argument or the `GOOGLE_API_KEY` environment variable. | ||
|
|
||
| Args: | ||
| model (str, optional): The Gemini TTS model to use. Defaults to "gemini-2.5-flash-preview-tts". | ||
| model (str, optional): The Gemini TTS model to use. Defaults to "gemini-3.1-flash-tts-preview". | ||
| voice_name (str, optional): The voice to use for synthesis. Defaults to "Kore". | ||
| api_key (str, optional): The API key for Google Gemini. If not provided, it attempts to read from the `GOOGLE_API_KEY` environment variable. | ||
| vertexai (bool, optional): Whether to use VertexAI. Defaults to False. | ||
|
|
@@ -200,7 +200,7 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None: | |
| if self._tts._opts.instructions is not None: | ||
| input_text = f'{self._tts._opts.instructions}:\n"{input_text}"' | ||
|
|
||
| response = await self._tts._client.aio.models.generate_content( | ||
| response = await self._tts._client.aio.models.generate_content_stream( | ||
| model=self._tts._opts.model, | ||
| contents=input_text, | ||
| config=config, | ||
|
|
@@ -213,22 +213,21 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None: | |
| mime_type="audio/pcm", | ||
| ) | ||
|
|
||
| if ( | ||
| not response.candidates | ||
| or not (content := response.candidates[0].content) | ||
| or not content.parts | ||
| ): | ||
| raise APIStatusError("No audio content generated") | ||
|
|
||
| for part in content.parts: | ||
| async for chunk in response: | ||
| if ( | ||
| (inline_data := part.inline_data) | ||
| and inline_data.data | ||
| and inline_data.mime_type | ||
| and inline_data.mime_type.startswith("audio/") | ||
| chunk.candidates | ||
| and chunk.candidates[0].content | ||
| and chunk.candidates[0].content.parts | ||
| ): | ||
| # mime_type: audio/L16;codec=pcm;rate=24000 | ||
| output_emitter.push(inline_data.data) | ||
| for part in chunk.candidates[0].content.parts: | ||
| if ( | ||
| (inline_data := part.inline_data) | ||
| and inline_data.data | ||
| and inline_data.mime_type | ||
| and inline_data.mime_type.startswith("audio/") | ||
| ): | ||
| # mime_type: audio/L16;codec=pcm;rate=24000 | ||
| output_emitter.push(inline_data.data) | ||
|
|
||
| except ClientError as e: | ||
| raise APIStatusError( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
| import pytest | ||
| from google.genai import types | ||
| from livekit.agents.types import APIConnectOptions | ||
|
Check failure on line 6 in tests/test_plugin_google_gemini_tts.py
|
||
| from livekit.plugins.google.beta.gemini_tts import TTS | ||
| from livekit.agents import tts | ||
|
|
||
| pytestmark = pytest.mark.plugin("google") | ||
|
|
||
@pytest.mark.asyncio
@patch("livekit.plugins.google.beta.gemini_tts.Client")
async def test_gemini_tts_success(mock_genai_client_class: MagicMock) -> None:
    """Every streamed audio chunk is pushed to the emitter, unchanged and in order.

    Args:
        mock_genai_client_class: Patched ``Client`` class injected by ``@patch``;
            its return value stands in for the GenAI client the TTS constructs.
    """
    # Function-scope import: at module level the name `types` is already bound
    # to `google.genai.types`, so the stdlib module is pulled in locally.
    from types import SimpleNamespace

    def make_chunk(data: bytes) -> SimpleNamespace:
        # Models only the attribute path the streaming loop reads:
        # chunk.candidates[0].content.parts[i].inline_data.{data, mime_type}
        inline_data = SimpleNamespace(data=data, mime_type="audio/pcm")
        content = SimpleNamespace(parts=[SimpleNamespace(inline_data=inline_data)])
        return SimpleNamespace(candidates=[SimpleNamespace(content=content)])

    expected_audio = [b"\x00" * 4800, b"\x01" * 4800]

    async def fake_stream(*args: object, **kwargs: object):
        for payload in expected_audio:
            yield make_chunk(payload)

    mock_client = MagicMock()
    mock_genai_client_class.return_value = mock_client

    # Awaiting an AsyncMock yields whatever a non-coroutine side_effect returns.
    # An async generator function is not a coroutine function, so the awaited
    # result is the async generator itself, which the code under test iterates.
    mock_stream = AsyncMock(side_effect=fake_stream)
    mock_client.aio.models.generate_content_stream = mock_stream

    google_tts = TTS(api_key="test-api-key")
    mock_emitter = MagicMock(spec=tts.AudioEmitter)

    stream = google_tts.synthesize("Hello world")
    await stream._run(mock_emitter)

    mock_stream.assert_called_once()
    mock_emitter.initialize.assert_called_once()

    # Check the actual bytes and their order, not just the number of pushes, so
    # that dropped, duplicated, or reordered audio is caught.
    pushed_audio = [recorded.args[0] for recorded in mock_emitter.push.call_args_list]
    assert pushed_audio == expected_audio
Uh oh!
There was an error while loading. Please reload this page.