Skip to content

Commit ba2a8dc

Browse files
committed
feat: update default values for advanced parameters to None in TTS configuration
1 parent daf00fe commit ba2a8dc

3 files changed

Lines changed: 51 additions & 41 deletions

File tree

src/fishaudio/types/tts.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ class TTSConfig(BaseModel):
7575
top_p: Nucleus sampling parameter for token selection. Range: 0.0-1.0. Default: 0.7
7676
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7.
7777
Higher = more varied, lower = more consistent
78-
max_new_tokens: Maximum number of tokens to generate. Default: 1024
79-
repetition_penalty: Penalty for repeated tokens. Default: 1.2
80-
min_chunk_length: Minimum chunk length for generation. Default: 50
81-
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
82-
early_stop_threshold: Threshold for early stopping. Default: 1.0
78+
max_new_tokens: Maximum number of tokens to generate. Default: None (server decides)
79+
repetition_penalty: Penalty for repeated tokens. Default: None (server decides)
80+
min_chunk_length: Minimum chunk length for generation. Default: None (server decides)
81+
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: None (server decides)
82+
early_stop_threshold: Threshold for early stopping. Default: None (server decides)
8383
"""
8484

8585
# Audio output settings
@@ -103,11 +103,11 @@ class TTSConfig(BaseModel):
103103
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
104104

105105
# Advanced generation parameters
106-
max_new_tokens: int = 1024
107-
repetition_penalty: float = 1.2
108-
min_chunk_length: int = 50
109-
condition_on_previous_chunks: bool = True
110-
early_stop_threshold: float = 1.0
106+
max_new_tokens: Optional[int] = None
107+
repetition_penalty: Optional[float] = None
108+
min_chunk_length: Optional[int] = None
109+
condition_on_previous_chunks: Optional[bool] = None
110+
early_stop_threshold: Optional[float] = None
111111

112112

113113
class TTSRequest(BaseModel):
@@ -131,11 +131,11 @@ class TTSRequest(BaseModel):
131131
prosody: Speech speed and volume settings. Default: None
132132
top_p: Nucleus sampling for token selection. Range: 0.0-1.0. Default: 0.7
133133
temperature: Randomness in generation. Range: 0.0-1.0. Default: 0.7
134-
max_new_tokens: Maximum number of tokens to generate. Default: 1024
135-
repetition_penalty: Penalty for repeated tokens. Default: 1.2
136-
min_chunk_length: Minimum chunk length for generation. Default: 50
137-
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: True
138-
early_stop_threshold: Threshold for early stopping. Default: 1.0
134+
max_new_tokens: Maximum number of tokens to generate. Default: None (server decides)
135+
repetition_penalty: Penalty for repeated tokens. Default: None (server decides)
136+
min_chunk_length: Minimum chunk length for generation. Default: None (server decides)
137+
condition_on_previous_chunks: Whether to condition generation on previous chunks. Default: None (server decides)
138+
early_stop_threshold: Threshold for early stopping. Default: None (server decides)
139139
"""
140140

141141
text: str
@@ -151,11 +151,11 @@ class TTSRequest(BaseModel):
151151
prosody: Optional[Prosody] = None
152152
top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
153153
temperature: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
154-
max_new_tokens: int = 1024
155-
repetition_penalty: float = 1.2
156-
min_chunk_length: int = 50
157-
condition_on_previous_chunks: bool = True
158-
early_stop_threshold: float = 1.0
154+
max_new_tokens: Optional[int] = None
155+
repetition_penalty: Optional[float] = None
156+
min_chunk_length: Optional[int] = None
157+
condition_on_previous_chunks: Optional[bool] = None
158+
early_stop_threshold: Optional[float] = None
159159

160160

161161
# WebSocket event types for streaming TTS

tests/unit/test_tts.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,11 @@ def test_convert_omit_parameters_not_sent(self, tts_client, mock_client_wrapper)
266266
assert "reference_id" not in payload
267267
assert "sample_rate" not in payload
268268
assert "prosody" not in payload
269+
assert "max_new_tokens" not in payload
270+
assert "repetition_penalty" not in payload
271+
assert "min_chunk_length" not in payload
272+
assert "condition_on_previous_chunks" not in payload
273+
assert "early_stop_threshold" not in payload
269274

270275
# references is an empty list by default, so it IS included
271276
assert payload["references"] == []
@@ -445,24 +450,24 @@ def test_convert_with_new_advanced_parameters(
445450
assert payload["condition_on_previous_chunks"] is False
446451
assert payload["early_stop_threshold"] == 0.8
447452

448-
def test_convert_new_parameters_have_defaults(
453+
def test_convert_advanced_parameters_not_sent_by_default(
449454
self, tts_client, mock_client_wrapper
450455
):
451-
"""Test TTS default values for new advanced parameters."""
456+
"""Test that advanced parameters are not sent when not explicitly set."""
452457
mock_response = Mock()
453458
mock_response.iter_bytes.return_value = iter([b"audio"])
454459
mock_client_wrapper.request.return_value = mock_response
455460

456461
tts_client.convert(text="Hello")
457462

458-
# Verify default values for new parameters in payload
463+
# Verify advanced parameters are NOT in payload by default
459464
call_args = mock_client_wrapper.request.call_args
460465
payload = ormsgpack.unpackb(call_args[1]["content"])
461-
assert payload["max_new_tokens"] == 1024
462-
assert payload["repetition_penalty"] == 1.2
463-
assert payload["min_chunk_length"] == 50
464-
assert payload["condition_on_previous_chunks"] is True
465-
assert payload["early_stop_threshold"] == 1.0
466+
assert "max_new_tokens" not in payload
467+
assert "repetition_penalty" not in payload
468+
assert "min_chunk_length" not in payload
469+
assert "condition_on_previous_chunks" not in payload
470+
assert "early_stop_threshold" not in payload
466471

467472

468473
class TestAsyncTTSClient:
@@ -676,13 +681,18 @@ async def async_iter_bytes():
676681

677682
await async_tts_client.convert(text="Hello")
678683

679-
# Verify OMIT params not in payload
684+
# Verify that parameters left as None are omitted from the payload
680685
call_args = async_mock_client_wrapper.request.call_args
681686
payload = ormsgpack.unpackb(call_args[1]["content"])
682687

683688
assert "reference_id" not in payload
684689
assert "sample_rate" not in payload
685690
assert "prosody" not in payload
691+
assert "max_new_tokens" not in payload
692+
assert "repetition_penalty" not in payload
693+
assert "min_chunk_length" not in payload
694+
assert "condition_on_previous_chunks" not in payload
695+
assert "early_stop_threshold" not in payload
686696

687697
@pytest.mark.asyncio
688698
async def test_convert_empty_response(

tests/unit/test_types.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,12 @@ def test_tts_config_defaults(self):
111111
assert config.latency == "balanced"
112112
assert config.top_p == 0.7
113113
assert config.temperature == 0.7
114-
# New parameter defaults
115-
assert config.max_new_tokens == 1024
116-
assert config.repetition_penalty == 1.2
117-
assert config.min_chunk_length == 50
118-
assert config.condition_on_previous_chunks is True
119-
assert config.early_stop_threshold == 1.0
114+
# Advanced parameters default to None (server decides)
115+
assert config.max_new_tokens is None
116+
assert config.repetition_penalty is None
117+
assert config.min_chunk_length is None
118+
assert config.condition_on_previous_chunks is None
119+
assert config.early_stop_threshold is None
120120

121121
def test_tts_config_custom_new_parameters(self):
122122
"""Test TTSConfig with custom values for new parameters."""
@@ -141,12 +141,12 @@ def test_tts_request_defaults(self):
141141
assert request.format == "mp3"
142142
assert request.chunk_length == 200
143143
assert request.latency == "balanced"
144-
# New parameter defaults
145-
assert request.max_new_tokens == 1024
146-
assert request.repetition_penalty == 1.2
147-
assert request.min_chunk_length == 50
148-
assert request.condition_on_previous_chunks is True
149-
assert request.early_stop_threshold == 1.0
144+
# Advanced parameters default to None (server decides)
145+
assert request.max_new_tokens is None
146+
assert request.repetition_penalty is None
147+
assert request.min_chunk_length is None
148+
assert request.condition_on_previous_chunks is None
149+
assert request.early_stop_threshold is None
150150

151151
def test_tts_request_custom_new_parameters(self):
152152
"""Test TTSRequest with custom values for new parameters."""

0 commit comments

Comments
 (0)