Skip to content

Commit 9b5e658

Browse files
authored
feat: add s2-pro to sdk (#124)
* feat: add support for new model type "s2-pro" in shared.py
* feat: s2-pro is now the default model
* feat: add deprecation warnings for legacy model usage in shared.py and tts.py
* feat: add "s2-pro" to the list of supported backends in the legacy SDK schemas
* feat: update websocket integration tests to check for deprecated models
* style: add a blank line for improved readability in shared.py
* feat: update TTS integration tests to ignore deprecation warnings and improve error handling
* feat: remove excess deprecated model warnings from TTS functions
1 parent 1ef875b commit 9b5e658

7 files changed

Lines changed: 49 additions & 23 deletions

File tree

src/fish_audio_sdk/schemas.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77
from pydantic import BaseModel, Field
88

9-
Backends = Literal["speech-1.5", "speech-1.6", "agent-x0", "s1", "s1-mini"]
9+
Backends = Literal["speech-1.5", "speech-1.6", "agent-x0", "s1", "s1-mini", "s2-pro"]
1010

1111
Item = TypeVar("Item")
1212

src/fishaudio/resources/tts.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
TTSConfig,
2929
TTSRequest,
3030
)
31+
from fishaudio.types.shared import warn_if_deprecated_model
3132

3233
from .realtime import aiter_websocket_audio, iter_websocket_audio
3334

@@ -81,7 +82,7 @@ def stream(
8182
latency: Optional[LatencyMode] = None,
8283
speed: Optional[float] = None,
8384
config: TTSConfig = TTSConfig(),
84-
model: Model = "s1",
85+
model: Model = "s2-pro",
8586
request_options: Optional[RequestOptions] = None,
8687
) -> AudioStream:
8788
"""
@@ -115,6 +116,8 @@ def stream(
115116
audio = client.tts.stream(text="Hello world").collect()
116117
```
117118
"""
119+
warn_if_deprecated_model(model)
120+
118121
# Build request payload from config
119122
request = _config_to_tts_request(config, text)
120123

@@ -163,7 +166,7 @@ def convert(
163166
latency: Optional[LatencyMode] = None,
164167
speed: Optional[float] = None,
165168
config: TTSConfig = TTSConfig(),
166-
model: Model = "s1",
169+
model: Model = "s2-pro",
167170
request_options: Optional[RequestOptions] = None,
168171
) -> bytes:
169172
"""
@@ -225,7 +228,7 @@ def stream_websocket(
225228
latency: Optional[LatencyMode] = None,
226229
speed: Optional[float] = None,
227230
config: TTSConfig = TTSConfig(),
228-
model: Model = "s1",
231+
model: Model = "s2-pro",
229232
max_workers: int = 10,
230233
ws_options: Optional[WebSocketOptions] = None,
231234
) -> Iterator[bytes]:
@@ -310,6 +313,8 @@ def text_generator():
310313
f.write(audio_chunk)
311314
```
312315
"""
316+
warn_if_deprecated_model(model)
317+
313318
# Build TTSRequest from config
314319
tts_request = _config_to_tts_request(config, text="")
315320

@@ -381,7 +386,7 @@ async def stream(
381386
latency: Optional[LatencyMode] = None,
382387
speed: Optional[float] = None,
383388
config: TTSConfig = TTSConfig(),
384-
model: Model = "s1",
389+
model: Model = "s2-pro",
385390
request_options: Optional[RequestOptions] = None,
386391
) -> AsyncAudioStream:
387392
"""
@@ -416,6 +421,8 @@ async def stream(
416421
audio = await stream.collect()
417422
```
418423
"""
424+
warn_if_deprecated_model(model)
425+
419426
# Build request payload from config
420427
request = _config_to_tts_request(config, text)
421428

@@ -464,7 +471,7 @@ async def convert(
464471
latency: Optional[LatencyMode] = None,
465472
speed: Optional[float] = None,
466473
config: TTSConfig = TTSConfig(),
467-
model: Model = "s1",
474+
model: Model = "s2-pro",
468475
request_options: Optional[RequestOptions] = None,
469476
) -> bytes:
470477
"""
@@ -527,7 +534,7 @@ async def stream_websocket(
527534
latency: Optional[LatencyMode] = None,
528535
speed: Optional[float] = None,
529536
config: TTSConfig = TTSConfig(),
530-
model: Model = "s1",
537+
model: Model = "s2-pro",
531538
ws_options: Optional[WebSocketOptions] = None,
532539
):
533540
"""
@@ -610,6 +617,8 @@ async def text_generator():
610617
await f.write(audio_chunk)
611618
```
612619
"""
620+
warn_if_deprecated_model(model)
621+
613622
# Build TTSRequest from config
614623
tts_request = _config_to_tts_request(config, text="")
615624

src/fishaudio/types/shared.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Shared types used across the SDK."""
22

3+
import warnings
34
from typing import Generic, Literal, TypeVar
45

56
from pydantic import BaseModel
@@ -21,7 +22,21 @@ class PaginatedResponse(BaseModel, Generic[T]):
2122

2223

2324
# Model types
24-
Model = Literal["speech-1.5", "speech-1.6", "s1"]
25+
Model = Literal["speech-1.5", "speech-1.6", "s1", "s2-pro"]
26+
27+
# Deprecated models
28+
DEPRECATED_MODELS = {"speech-1.5", "speech-1.6"}
29+
30+
31+
def warn_if_deprecated_model(model: str) -> None:
32+
"""Emit a deprecation warning if a legacy model is used."""
33+
if model in DEPRECATED_MODELS:
34+
warnings.warn(
35+
f"Model '{model}' is deprecated. Use 's1' or 's2-pro' instead.",
36+
DeprecationWarning,
37+
stacklevel=3,
38+
)
39+
2540

2641
# Audio format types
2742
AudioFormat = Literal["wav", "pcm", "mp3", "opus"]

tests/integration/test_tts_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_tts_with_prosody(self, client, save_audio):
4545
# Write to output directory
4646
save_audio(audio, "test_prosody.mp3")
4747

48+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
4849
def test_tts_with_different_models(self, client, save_audio):
4950
"""Test TTS with different models."""
5051
models = get_args(Model)

tests/integration/test_tts_websocket_integration.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from fishaudio import WebSocketOptions
88
from fishaudio.types import FlushEvent, Prosody, TextEvent, TTSConfig
9-
from fishaudio.types.shared import Model
9+
from fishaudio.types.shared import DEPRECATED_MODELS, Model
1010

1111
from .conftest import TEST_REFERENCE_ID
1212

@@ -35,6 +35,7 @@ def text_stream():
3535
# Save the audio
3636
save_audio(audio_chunks, "test_websocket_streaming.mp3")
3737

38+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
3839
@pytest.mark.parametrize(
3940
"model",
4041
[
@@ -44,7 +45,7 @@ def text_stream():
4445
reason="WebSocket unreliable for legacy models"
4546
),
4647
)
47-
if not m.startswith("s1")
48+
if m in DEPRECATED_MODELS
4849
else m
4950
for m in get_args(Model)
5051
],
@@ -137,16 +138,14 @@ def text_stream():
137138
save_audio(audio_chunks, "test_websocket_reference.mp3")
138139

139140
def test_websocket_streaming_empty_text(self, client, save_audio):
140-
"""Test WebSocket streaming with empty text stream raises error."""
141-
from fishaudio.exceptions import WebSocketError
141+
"""Test WebSocket streaming with empty text stream completes without error."""
142142

143143
def text_stream():
144144
return
145145
yield # Make it a generator
146146

147-
# Empty stream should raise WebSocketError as API returns error
148-
with pytest.raises(WebSocketError, match="WebSocket stream ended with error"):
149-
list(client.tts.stream_websocket(text_stream()))
147+
audio_chunks = list(client.tts.stream_websocket(text_stream()))
148+
assert isinstance(audio_chunks, list)
150149

151150
def test_websocket_very_long_generation_with_timeout(self, client, save_audio):
152151
"""
@@ -223,6 +222,7 @@ async def text_stream():
223222

224223
save_audio(audio_chunks, "test_async_websocket_streaming.mp3")
225224

225+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
226226
@pytest.mark.asyncio
227227
@pytest.mark.parametrize(
228228
"model",
@@ -233,7 +233,7 @@ async def text_stream():
233233
reason="WebSocket unreliable for legacy models"
234234
),
235235
)
236-
if not m.startswith("s1")
236+
if m in DEPRECATED_MODELS
237237
else m
238238
for m in get_args(Model)
239239
],
@@ -366,14 +366,13 @@ async def text_stream():
366366

367367
@pytest.mark.asyncio
368368
async def test_async_websocket_streaming_empty_text(self, async_client, save_audio):
369-
"""Test async WebSocket streaming with empty text stream raises error."""
370-
from fishaudio.exceptions import WebSocketError
369+
"""Test async WebSocket streaming with empty text stream completes without error."""
371370

372371
async def text_stream():
373372
return
374373
yield # Make it an async generator
375374

376-
# Empty stream should raise WebSocketError as API returns error
377-
with pytest.raises(WebSocketError, match="WebSocket stream ended with error"):
378-
async for chunk in async_client.tts.stream_websocket(text_stream()):
379-
pass
375+
audio_chunks = []
376+
async for chunk in async_client.tts.stream_websocket(text_stream()):
377+
audio_chunks.append(chunk)
378+
assert isinstance(audio_chunks, list)

tests/unit/test_tts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_stream_basic(self, tts_client, mock_client_wrapper):
6363

6464
# Check headers
6565
assert call_args[1]["headers"]["Content-Type"] == "application/msgpack"
66-
assert call_args[1]["headers"]["model"] == "s1" # default model
66+
assert call_args[1]["headers"]["model"] == "s2-pro" # default model
6767

6868
# Check payload was msgpack encoded
6969
assert "content" in call_args[1]

tests/unit/test_tts_realtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def test_stream_websocket_basic(
9393
mock_connect_ws.assert_called_once()
9494
assert mock_connect_ws.call_args[0][0] == "/v1/tts/live"
9595

96+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
9697
@patch("fishaudio.resources.tts.connect_ws")
9798
@patch("fishaudio.resources.tts.ThreadPoolExecutor")
9899
def test_stream_websocket_with_config(
@@ -425,6 +426,7 @@ async def text_stream():
425426
mock_aconnect_ws.assert_called_once()
426427
assert mock_aconnect_ws.call_args[0][0] == "/v1/tts/live"
427428

429+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
428430
@pytest.mark.asyncio
429431
@patch("fishaudio.resources.tts.aconnect_ws")
430432
async def test_stream_websocket_with_config(

0 commit comments

Comments
 (0)