Skip to content

Commit b1d0129

Browse files
committed
feat: add reference_id parameter to TTS conversion methods
Signed-off-by: James Ding <jamesding365@gmail.com>
1 parent c1153d8 commit b1d0129

2 files changed

Lines changed: 124 additions & 0 deletions

File tree

src/fishaudio/resources/tts.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def convert(
5858
self,
5959
*,
6060
text: str,
61+
reference_id: Optional[str] = None,
6162
config: TTSConfig = TTSConfig(),
6263
model: Model = "s1",
6364
request_options: Optional[RequestOptions] = None,
@@ -67,6 +68,7 @@ def convert(
6768
6869
Args:
6970
text: Text to synthesize
71+
reference_id: Voice reference ID (overridden by config.reference_id if set)
7072
config: TTS configuration (audio settings, voice, model parameters)
7173
model: TTS model to use
7274
request_options: Request-level overrides
@@ -83,6 +85,9 @@ def convert(
8385
# Simple usage with defaults
8486
audio = client.tts.convert(text="Hello world")
8587
88+
# With reference_id parameter
89+
audio = client.tts.convert(text="Hello world", reference_id="your_model_id")
90+
8691
# Custom configuration
8792
config = TTSConfig(format="wav", mp3_bitrate=192)
8893
audio = client.tts.convert(text="Hello world", config=config)
@@ -94,6 +99,11 @@ def convert(
9499
"""
95100
# Build request payload from config
96101
request = _config_to_tts_request(config, text)
102+
103+
# Use parameter reference_id only if config doesn't have one
104+
if request.reference_id is None and reference_id is not None:
105+
request.reference_id = reference_id
106+
97107
payload = request.model_dump(exclude_none=True)
98108

99109
# Make request with streaming
@@ -114,6 +124,7 @@ def stream_websocket(
114124
self,
115125
text_stream: Iterable[Union[str, TextEvent, FlushEvent]],
116126
*,
127+
reference_id: Optional[str] = None,
117128
config: TTSConfig = TTSConfig(),
118129
model: Model = "s1",
119130
max_workers: int = 10,
@@ -125,6 +136,7 @@ def stream_websocket(
125136
126137
Args:
127138
text_stream: Iterator of text chunks to stream
139+
reference_id: Voice reference ID (overridden by config.reference_id if set)
128140
config: TTS configuration (audio settings, voice, model parameters)
129141
model: TTS model to use
130142
max_workers: ThreadPoolExecutor workers for concurrent sender
@@ -148,6 +160,11 @@ def text_generator():
148160
for audio_chunk in client.tts.stream_websocket(text_generator()):
149161
f.write(audio_chunk)
150162
163+
# With reference_id parameter
164+
with open("output.mp3", "wb") as f:
165+
for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
166+
f.write(audio_chunk)
167+
151168
# Custom configuration
152169
config = TTSConfig(format="wav", latency="normal")
153170
with open("output.wav", "wb") as f:
@@ -158,6 +175,10 @@ def text_generator():
158175
# Build TTSRequest from config
159176
tts_request = _config_to_tts_request(config, text="")
160177

178+
# Use parameter reference_id only if config doesn't have one
179+
if tts_request.reference_id is None and reference_id is not None:
180+
tts_request.reference_id = reference_id
181+
161182
executor = ThreadPoolExecutor(max_workers=max_workers)
162183

163184
try:
@@ -202,6 +223,7 @@ async def convert(
202223
self,
203224
*,
204225
text: str,
226+
reference_id: Optional[str] = None,
205227
config: TTSConfig = TTSConfig(),
206228
model: Model = "s1",
207229
request_options: Optional[RequestOptions] = None,
@@ -211,6 +233,7 @@ async def convert(
211233
212234
Args:
213235
text: Text to synthesize
236+
reference_id: Voice reference ID (overridden by config.reference_id if set)
214237
config: TTS configuration (audio settings, voice, model parameters)
215238
model: TTS model to use
216239
request_options: Request-level overrides
@@ -227,6 +250,9 @@ async def convert(
227250
# Simple usage with defaults
228251
audio = await client.tts.convert(text="Hello world")
229252
253+
# With reference_id parameter
254+
audio = await client.tts.convert(text="Hello world", reference_id="your_model_id")
255+
230256
# Custom configuration
231257
config = TTSConfig(format="wav", mp3_bitrate=192)
232258
audio = await client.tts.convert(text="Hello world", config=config)
@@ -238,6 +264,11 @@ async def convert(
238264
"""
239265
# Build request payload from config
240266
request = _config_to_tts_request(config, text)
267+
268+
# Use parameter reference_id only if config doesn't have one
269+
if request.reference_id is None and reference_id is not None:
270+
request.reference_id = reference_id
271+
241272
payload = request.model_dump(exclude_none=True)
242273

243274
# Make request with streaming
@@ -258,6 +289,7 @@ async def stream_websocket(
258289
self,
259290
text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]],
260291
*,
292+
reference_id: Optional[str] = None,
261293
config: TTSConfig = TTSConfig(),
262294
model: Model = "s1",
263295
):
@@ -268,6 +300,7 @@ async def stream_websocket(
268300
269301
Args:
270302
text_stream: Async iterator of text chunks to stream
303+
reference_id: Voice reference ID (overridden by config.reference_id if set)
271304
config: TTS configuration (audio settings, voice, model parameters)
272305
model: TTS model to use
273306
@@ -290,6 +323,11 @@ async def text_generator():
290323
async for audio_chunk in client.tts.stream_websocket(text_generator()):
291324
await f.write(audio_chunk)
292325
326+
# With reference_id parameter
327+
async with aiofiles.open("output.mp3", "wb") as f:
328+
async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
329+
await f.write(audio_chunk)
330+
293331
# Custom configuration
294332
config = TTSConfig(format="wav", latency="normal")
295333
async with aiofiles.open("output.wav", "wb") as f:
@@ -300,6 +338,10 @@ async def text_generator():
300338
# Build TTSRequest from config
301339
tts_request = _config_to_tts_request(config, text="")
302340

341+
# Use parameter reference_id only if config doesn't have one
342+
if tts_request.reference_id is None and reference_id is not None:
343+
tts_request.reference_id = reference_id
344+
303345
ws: AsyncWebSocketSession
304346
async with aconnect_ws(
305347
"/v1/tts/live",

tests/unit/test_tts.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,39 @@ def test_convert_with_reference_id(self, tts_client, mock_client_wrapper):
8181
payload = ormsgpack.unpackb(call_args[1]["content"])
8282
assert payload["reference_id"] == "voice_123"
8383

84+
def test_convert_with_reference_id_parameter(self, tts_client, mock_client_wrapper):
85+
"""Test TTS with reference_id as direct parameter."""
86+
mock_response = Mock()
87+
mock_response.iter_bytes.return_value = iter([b"audio"])
88+
mock_client_wrapper.request.return_value = mock_response
89+
90+
list(tts_client.convert(text="Hello", reference_id="voice_456"))
91+
92+
# Verify reference_id in payload
93+
call_args = mock_client_wrapper.request.call_args
94+
payload = ormsgpack.unpackb(call_args[1]["content"])
95+
assert payload["reference_id"] == "voice_456"
96+
97+
def test_convert_config_reference_id_overrides_parameter(
98+
self, tts_client, mock_client_wrapper
99+
):
100+
"""Test that config.reference_id overrides parameter reference_id."""
101+
mock_response = Mock()
102+
mock_response.iter_bytes.return_value = iter([b"audio"])
103+
mock_client_wrapper.request.return_value = mock_response
104+
105+
config = TTSConfig(reference_id="voice_from_config")
106+
list(
107+
tts_client.convert(
108+
text="Hello", reference_id="voice_from_param", config=config
109+
)
110+
)
111+
112+
# Verify config reference_id takes precedence
113+
call_args = mock_client_wrapper.request.call_args
114+
payload = ormsgpack.unpackb(call_args[1]["content"])
115+
assert payload["reference_id"] == "voice_from_config"
116+
84117
def test_convert_with_references(self, tts_client, mock_client_wrapper):
85118
"""Test TTS with reference audio samples."""
86119
mock_response = Mock()
@@ -282,6 +315,55 @@ async def async_iter_bytes():
282315
payload = ormsgpack.unpackb(call_args[1]["content"])
283316
assert payload["reference_id"] == "voice_123"
284317

318+
@pytest.mark.asyncio
319+
async def test_convert_with_reference_id_parameter(
320+
self, async_tts_client, async_mock_client_wrapper
321+
):
322+
"""Test async TTS with reference_id as direct parameter."""
323+
mock_response = Mock()
324+
325+
async def async_iter_bytes():
326+
yield b"audio"
327+
328+
mock_response.aiter_bytes = async_iter_bytes
329+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
330+
331+
audio_chunks = []
332+
async for chunk in async_tts_client.convert(
333+
text="Hello", reference_id="voice_456"
334+
):
335+
audio_chunks.append(chunk)
336+
337+
# Verify reference_id in payload
338+
call_args = async_mock_client_wrapper.request.call_args
339+
payload = ormsgpack.unpackb(call_args[1]["content"])
340+
assert payload["reference_id"] == "voice_456"
341+
342+
@pytest.mark.asyncio
343+
async def test_convert_config_reference_id_overrides_parameter(
344+
self, async_tts_client, async_mock_client_wrapper
345+
):
346+
"""Test that config.reference_id overrides parameter reference_id (async)."""
347+
mock_response = Mock()
348+
349+
async def async_iter_bytes():
350+
yield b"audio"
351+
352+
mock_response.aiter_bytes = async_iter_bytes
353+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
354+
355+
config = TTSConfig(reference_id="voice_from_config")
356+
audio_chunks = []
357+
async for chunk in async_tts_client.convert(
358+
text="Hello", reference_id="voice_from_param", config=config
359+
):
360+
audio_chunks.append(chunk)
361+
362+
# Verify config reference_id takes precedence
363+
call_args = async_mock_client_wrapper.request.call_args
364+
payload = ormsgpack.unpackb(call_args[1]["content"])
365+
assert payload["reference_id"] == "voice_from_config"
366+
285367
@pytest.mark.asyncio
286368
async def test_convert_with_prosody(
287369
self, async_tts_client, async_mock_client_wrapper

0 commit comments

Comments
 (0)