Skip to content

Commit c65fcaa

Browse files
committed
feat: add support for references parameter in TTS conversion methods
Signed-off-by: James Ding <jamesding365@gmail.com>
1 parent b1d0129 commit c65fcaa

2 files changed

Lines changed: 158 additions & 5 deletions

File tree

src/fishaudio/resources/tts.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import AsyncIterable, Iterable, Iterator, Optional, Union
5+
from typing import AsyncIterable, Iterable, Iterator, List, Optional, Union
66

77
import ormsgpack
88
from httpx_ws import AsyncWebSocketSession, WebSocketSession, aconnect_ws, connect_ws
@@ -13,6 +13,7 @@
1313
CloseEvent,
1414
FlushEvent,
1515
Model,
16+
ReferenceAudio,
1617
StartEvent,
1718
TextEvent,
1819
TTSConfig,
@@ -59,6 +60,7 @@ def convert(
5960
*,
6061
text: str,
6162
reference_id: Optional[str] = None,
63+
references: List[ReferenceAudio] = [],
6264
config: TTSConfig = TTSConfig(),
6365
model: Model = "s1",
6466
request_options: Optional[RequestOptions] = None,
@@ -69,6 +71,7 @@ def convert(
6971
Args:
7072
text: Text to synthesize
7173
reference_id: Voice reference ID (overridden by config.reference_id if set)
74+
references: Reference audio samples (overridden by config.references if set)
7275
config: TTS configuration (audio settings, voice, model parameters)
7376
model: TTS model to use
7477
request_options: Request-level overrides
@@ -78,7 +81,7 @@ def convert(
7881
7982
Example:
8083
```python
81-
from fishaudio import FishAudio, TTSConfig
84+
from fishaudio import FishAudio, TTSConfig, ReferenceAudio
8285
8386
client = FishAudio(api_key="...")
8487
@@ -88,6 +91,12 @@ def convert(
8891
# With reference_id parameter
8992
audio = client.tts.convert(text="Hello world", reference_id="your_model_id")
9093
94+
# With references parameter
95+
audio = client.tts.convert(
96+
text="Hello world",
97+
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
98+
)
99+
91100
# Custom configuration
92101
config = TTSConfig(format="wav", mp3_bitrate=192)
93102
audio = client.tts.convert(text="Hello world", config=config)
@@ -104,6 +113,10 @@ def convert(
104113
if request.reference_id is None and reference_id is not None:
105114
request.reference_id = reference_id
106115

116+
# Use parameter references only if config doesn't have any
117+
if not request.references and references:
118+
request.references = references
119+
107120
payload = request.model_dump(exclude_none=True)
108121

109122
# Make request with streaming
@@ -125,6 +138,7 @@ def stream_websocket(
125138
text_stream: Iterable[Union[str, TextEvent, FlushEvent]],
126139
*,
127140
reference_id: Optional[str] = None,
141+
references: List[ReferenceAudio] = [],
128142
config: TTSConfig = TTSConfig(),
129143
model: Model = "s1",
130144
max_workers: int = 10,
@@ -137,6 +151,7 @@ def stream_websocket(
137151
Args:
138152
text_stream: Iterator of text chunks to stream
139153
reference_id: Voice reference ID (overridden by config.reference_id if set)
154+
references: Reference audio samples (overridden by config.references if set)
140155
config: TTS configuration (audio settings, voice, model parameters)
141156
model: TTS model to use
142157
max_workers: ThreadPoolExecutor workers for concurrent sender
@@ -146,7 +161,7 @@ def stream_websocket(
146161
147162
Example:
148163
```python
149-
from fishaudio import FishAudio, TTSConfig
164+
from fishaudio import FishAudio, TTSConfig, ReferenceAudio
150165
151166
client = FishAudio(api_key="...")
152167
@@ -165,6 +180,14 @@ def text_generator():
165180
for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
166181
f.write(audio_chunk)
167182
183+
# With references parameter
184+
with open("output.mp3", "wb") as f:
185+
for audio_chunk in client.tts.stream_websocket(
186+
text_generator(),
187+
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
188+
):
189+
f.write(audio_chunk)
190+
168191
# Custom configuration
169192
config = TTSConfig(format="wav", latency="normal")
170193
with open("output.wav", "wb") as f:
@@ -179,6 +202,10 @@ def text_generator():
179202
if tts_request.reference_id is None and reference_id is not None:
180203
tts_request.reference_id = reference_id
181204

205+
# Use parameter references only if config doesn't have any
206+
if not tts_request.references and references:
207+
tts_request.references = references
208+
182209
executor = ThreadPoolExecutor(max_workers=max_workers)
183210

184211
try:
@@ -224,6 +251,7 @@ async def convert(
224251
*,
225252
text: str,
226253
reference_id: Optional[str] = None,
254+
references: List[ReferenceAudio] = [],
227255
config: TTSConfig = TTSConfig(),
228256
model: Model = "s1",
229257
request_options: Optional[RequestOptions] = None,
@@ -234,6 +262,7 @@ async def convert(
234262
Args:
235263
text: Text to synthesize
236264
reference_id: Voice reference ID (overridden by config.reference_id if set)
265+
references: Reference audio samples (overridden by config.references if set)
237266
config: TTS configuration (audio settings, voice, model parameters)
238267
model: TTS model to use
239268
request_options: Request-level overrides
@@ -243,7 +272,7 @@ async def convert(
243272
244273
Example:
245274
```python
246-
from fishaudio import AsyncFishAudio, TTSConfig
275+
from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio
247276
248277
client = AsyncFishAudio(api_key="...")
249278
@@ -253,6 +282,12 @@ async def convert(
253282
# With reference_id parameter
254283
audio = await client.tts.convert(text="Hello world", reference_id="your_model_id")
255284
285+
# With references parameter
286+
audio = await client.tts.convert(
287+
text="Hello world",
288+
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
289+
)
290+
256291
# Custom configuration
257292
config = TTSConfig(format="wav", mp3_bitrate=192)
258293
audio = await client.tts.convert(text="Hello world", config=config)
@@ -269,6 +304,10 @@ async def convert(
269304
if request.reference_id is None and reference_id is not None:
270305
request.reference_id = reference_id
271306

307+
# Use parameter references only if config doesn't have any
308+
if not request.references and references:
309+
request.references = references
310+
272311
payload = request.model_dump(exclude_none=True)
273312

274313
# Make request with streaming
@@ -290,6 +329,7 @@ async def stream_websocket(
290329
text_stream: AsyncIterable[Union[str, TextEvent, FlushEvent]],
291330
*,
292331
reference_id: Optional[str] = None,
332+
references: List[ReferenceAudio] = [],
293333
config: TTSConfig = TTSConfig(),
294334
model: Model = "s1",
295335
):
@@ -301,6 +341,7 @@ async def stream_websocket(
301341
Args:
302342
text_stream: Async iterator of text chunks to stream
303343
reference_id: Voice reference ID (overridden by config.reference_id if set)
344+
references: Reference audio samples (overridden by config.references if set)
304345
config: TTS configuration (audio settings, voice, model parameters)
305346
model: TTS model to use
306347
@@ -309,7 +350,7 @@ async def stream_websocket(
309350
310351
Example:
311352
```python
312-
from fishaudio import AsyncFishAudio, TTSConfig
353+
from fishaudio import AsyncFishAudio, TTSConfig, ReferenceAudio
313354
314355
client = AsyncFishAudio(api_key="...")
315356
@@ -328,6 +369,14 @@ async def text_generator():
328369
async for audio_chunk in client.tts.stream_websocket(text_generator(), reference_id="your_model_id"):
329370
await f.write(audio_chunk)
330371
372+
# With references parameter
373+
async with aiofiles.open("output.mp3", "wb") as f:
374+
async for audio_chunk in client.tts.stream_websocket(
375+
text_generator(),
376+
references=[ReferenceAudio(audio=audio_bytes, text="sample")]
377+
):
378+
await f.write(audio_chunk)
379+
331380
# Custom configuration
332381
config = TTSConfig(format="wav", latency="normal")
333382
async with aiofiles.open("output.wav", "wb") as f:
@@ -342,6 +391,10 @@ async def text_generator():
342391
if tts_request.reference_id is None and reference_id is not None:
343392
tts_request.reference_id = reference_id
344393

394+
# Use parameter references only if config doesn't have any
395+
if not tts_request.references and references:
396+
tts_request.references = references
397+
345398
ws: AsyncWebSocketSession
346399
async with aconnect_ws(
347400
"/v1/tts/live",

tests/unit/test_tts.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,46 @@ def test_convert_with_references(self, tts_client, mock_client_wrapper):
135135
assert payload["references"][0]["text"] == "Sample 1"
136136
assert payload["references"][1]["text"] == "Sample 2"
137137

138+
def test_convert_with_references_parameter(self, tts_client, mock_client_wrapper):
139+
"""Test TTS with references as direct parameter."""
140+
mock_response = Mock()
141+
mock_response.iter_bytes.return_value = iter([b"audio"])
142+
mock_client_wrapper.request.return_value = mock_response
143+
144+
references = [
145+
ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"),
146+
ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"),
147+
]
148+
149+
list(tts_client.convert(text="Hello", references=references))
150+
151+
# Verify references in payload
152+
call_args = mock_client_wrapper.request.call_args
153+
payload = ormsgpack.unpackb(call_args[1]["content"])
154+
assert len(payload["references"]) == 2
155+
assert payload["references"][0]["text"] == "Sample 1"
156+
assert payload["references"][1]["text"] == "Sample 2"
157+
158+
def test_convert_config_references_overrides_parameter(
159+
self, tts_client, mock_client_wrapper
160+
):
161+
"""Test that config.references overrides parameter references."""
162+
mock_response = Mock()
163+
mock_response.iter_bytes.return_value = iter([b"audio"])
164+
mock_client_wrapper.request.return_value = mock_response
165+
166+
config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")]
167+
param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")]
168+
169+
config = TTSConfig(references=config_refs)
170+
list(tts_client.convert(text="Hello", references=param_refs, config=config))
171+
172+
# Verify config references take precedence
173+
call_args = mock_client_wrapper.request.call_args
174+
payload = ormsgpack.unpackb(call_args[1]["content"])
175+
assert len(payload["references"]) == 1
176+
assert payload["references"][0]["text"] == "Config"
177+
138178
def test_convert_with_different_backend(self, tts_client, mock_client_wrapper):
139179
"""Test TTS with different backend/model."""
140180
mock_response = Mock()
@@ -364,6 +404,66 @@ async def async_iter_bytes():
364404
payload = ormsgpack.unpackb(call_args[1]["content"])
365405
assert payload["reference_id"] == "voice_from_config"
366406

407+
@pytest.mark.asyncio
408+
async def test_convert_with_references_parameter(
409+
self, async_tts_client, async_mock_client_wrapper
410+
):
411+
"""Test async TTS with references as direct parameter."""
412+
mock_response = Mock()
413+
414+
async def async_iter_bytes():
415+
yield b"audio"
416+
417+
mock_response.aiter_bytes = async_iter_bytes
418+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
419+
420+
references = [
421+
ReferenceAudio(audio=b"ref_audio_1", text="Sample 1"),
422+
ReferenceAudio(audio=b"ref_audio_2", text="Sample 2"),
423+
]
424+
425+
audio_chunks = []
426+
async for chunk in async_tts_client.convert(
427+
text="Hello", references=references
428+
):
429+
audio_chunks.append(chunk)
430+
431+
# Verify references in payload
432+
call_args = async_mock_client_wrapper.request.call_args
433+
payload = ormsgpack.unpackb(call_args[1]["content"])
434+
assert len(payload["references"]) == 2
435+
assert payload["references"][0]["text"] == "Sample 1"
436+
assert payload["references"][1]["text"] == "Sample 2"
437+
438+
@pytest.mark.asyncio
439+
async def test_convert_config_references_overrides_parameter(
440+
self, async_tts_client, async_mock_client_wrapper
441+
):
442+
"""Test that config.references overrides parameter references (async)."""
443+
mock_response = Mock()
444+
445+
async def async_iter_bytes():
446+
yield b"audio"
447+
448+
mock_response.aiter_bytes = async_iter_bytes
449+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
450+
451+
config_refs = [ReferenceAudio(audio=b"config_audio", text="Config")]
452+
param_refs = [ReferenceAudio(audio=b"param_audio", text="Param")]
453+
454+
config = TTSConfig(references=config_refs)
455+
audio_chunks = []
456+
async for chunk in async_tts_client.convert(
457+
text="Hello", references=param_refs, config=config
458+
):
459+
audio_chunks.append(chunk)
460+
461+
# Verify config references take precedence
462+
call_args = async_mock_client_wrapper.request.call_args
463+
payload = ormsgpack.unpackb(call_args[1]["content"])
464+
assert len(payload["references"]) == 1
465+
assert payload["references"][0]["text"] == "Config"
466+
367467
@pytest.mark.asyncio
368468
async def test_convert_with_prosody(
369469
self, async_tts_client, async_mock_client_wrapper

0 commit comments

Comments
 (0)