Skip to content

Commit 43ff6f5

Browse files
zkleb-aai and AssemblyAI authored
chore: sync sdk code with DeepLearning repo (#162)
Co-authored-by: AssemblyAI <engineering.sdk@assemblyai.com>
1 parent fa9c5ab commit 43ff6f5

9 files changed

Lines changed: 247 additions & 52 deletions

File tree

.github/workflows/lint.yml

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,8 +86,6 @@ jobs:
8686
if: ${{ steps.counter.outputs.count > 0 }}
8787
with:
8888
python-version: '3.9'
89-
cache: 'pip'
90-
cache-dependency-path: 'setup.py'
9189
- run: pip install mypy==1.5.1
9290
if: ${{ steps.counter.outputs.count > 0 }}
9391
- run: mypy ${{ steps.filter.outputs.python_files }} --follow-imports=silent --ignore-missing-imports

.github/workflows/test.yml

Lines changed: 6 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,33 +22,19 @@ jobs:
2222
os:
2323
- ubuntu-22.04
2424
steps:
25+
- name: Setup python for tox
26+
uses: actions/setup-python@v4
27+
with:
28+
python-version: ${{ matrix.py }}
29+
- name: Install tox
30+
run: python -m pip install tox
2531
- uses: actions/checkout@v3
2632
with:
2733
fetch-depth: 0
2834
- name: Setup python for test ${{ matrix.py }}
2935
uses: actions/setup-python@v4
3036
with:
3137
python-version: ${{ matrix.py }}
32-
cache: 'pip'
33-
cache-dependency-path: 'setup.py'
34-
- name: Cache apt packages
35-
uses: actions/cache@v3
36-
with:
37-
path: |
38-
/var/cache/apt/archives
39-
/var/lib/apt/lists
40-
key: apt-${{ runner.os }}-portaudio
41-
restore-keys: |
42-
apt-${{ runner.os }}-
43-
- name: Cache tox environments
44-
uses: actions/cache@v3
45-
with:
46-
path: .tox
47-
key: tox-${{ matrix.os }}-${{ matrix.py }}-${{ hashFiles('tox.ini', 'setup.py') }}
48-
restore-keys: |
49-
tox-${{ matrix.os }}-${{ matrix.py }}-
50-
- name: Install tox
51-
run: python -m pip install tox
5238
- name: Setup test suite
5339
run: |
5440
sudo apt-get update && sudo apt-get install -y portaudio19-dev

assemblyai/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "0.48.4"
1+
__version__ = "0.49.0"

assemblyai/transcriber.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -523,6 +523,22 @@ def webhook_auth(self) -> Optional[bool]:
523523

524524
return self._impl.transcript.webhook_auth
525525

526+
@property
527+
def language_code(self) -> Optional[Union[str, types.LanguageCode]]:
528+
"The language code of the transcript"
529+
if not self._impl.transcript:
530+
raise ValueError("The internal Transcript object is None.")
531+
532+
return self._impl.transcript.language_code
533+
534+
@property
535+
def language_codes(self) -> Optional[List[Union[str, types.LanguageCode]]]:
536+
"The list of language codes for multilingual/code-switching audio"
537+
if not self._impl.transcript:
538+
raise ValueError("The internal Transcript object is None.")
539+
540+
return self._impl.transcript.language_codes
541+
526542
@property
527543
def lemur(self) -> lemur.Lemur:
528544
"""

assemblyai/types.py

Lines changed: 87 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
Any,
88
Dict,
99
List,
10+
Literal,
1011
Optional,
1112
Sequence,
1213
Tuple,
@@ -296,6 +297,9 @@ class EntityType(str, Enum):
296297
filename = "filename"
297298
"Names of computer files, including the extension or filepath (e.g., Taxes/2012/brad-tax-returns.pdf)"
298299

300+
gender = "gender"
301+
"Terms indicating gender identity (e.g., female, male, non-binary)"
302+
299303
gender_sexuality = "gender_sexuality"
300304
"Terms indicating gender identity or sexual orientation, including slang terms (e.g., female; bisexual; trans)"
301305

@@ -314,6 +318,27 @@ class EntityType(str, Enum):
314318
location = "location"
315319
"Any Location reference including mailing address, postal code, city, state, province, country, or coordinates (e.g., Lake Victoria, 145 Windsor St., 90210)"
316320

321+
location_address = "location_address"
322+
"Mailing address (e.g., 123 Main Street, Apartment 4B)"
323+
324+
location_address_street = "location_address_street"
325+
"Street address (e.g., 123 Main Street)"
326+
327+
location_city = "location_city"
328+
"City name (e.g., San Francisco, New York)"
329+
330+
location_coordinate = "location_coordinate"
331+
"Geographic coordinates (e.g., 37.7749° N, 122.4194° W)"
332+
333+
location_country = "location_country"
334+
"Country name (e.g., United States, Canada)"
335+
336+
location_state = "location_state"
337+
"State or province name (e.g., California, Ontario)"
338+
339+
location_zip = "location_zip"
340+
"Postal or ZIP code (e.g., 94102, M5V 3A8)"
341+
317342
marital_status = "marital_status"
318343
"Terms indicating marital status (e.g., Single, common-law, ex-wife, married)"
319344

@@ -338,6 +363,8 @@ class EntityType(str, Enum):
338363
organization = "organization"
339364
"Name of an organization (e.g., CNN, McDonalds, University of Alaska, Northwest General Hospital)"
340365

366+
organization_medical_facility = "organization_medical_facility"
367+
341368
passport_number = "passport_number"
342369
"Passport numbers, issued by any country (e.g., PA4568332; NU3C6L86S12)"
343370

@@ -362,6 +389,9 @@ class EntityType(str, Enum):
362389
religion = "religion"
363390
"Terms indicating religious affiliation (e.g., Hindu, Catholic)"
364391

392+
sexuality = "sexuality"
393+
"Terms indicating sexual orientation (e.g., heterosexual, gay, bisexual)"
394+
365395
statistics = "statistics"
366396
"Medical statistics (e.g., 18%, 18 percent)"
367397

@@ -383,6 +413,40 @@ class EntityType(str, Enum):
383413
zodiac_sign = "zodiac_sign"
384414
"Names of Zodiac signs (e.g., Aries, Taurus)"
385415

416+
# BETA - only english
417+
corporate_action = "corporate_action"
418+
"Corporate actions (e.g., merger, acquisition, IPO)"
419+
420+
day = "day"
421+
"Day reference (e.g., Monday, Friday)"
422+
423+
effect = "effect"
424+
"Effect or result (e.g., increase, decrease)"
425+
426+
financial_metric = "financial_metric"
427+
"Financial metrics (e.g., revenue, profit margin, EBITDA)"
428+
429+
medical_code = "medical_code"
430+
"Medical codes (e.g., ICD-10, CPT codes)"
431+
432+
month = "month"
433+
"Month reference (e.g., January, February)"
434+
435+
organization_id = "organization_id"
436+
"Organization identification numbers (e.g., EIN, company registration number)"
437+
438+
product = "product"
439+
"Product names (e.g., iPhone, Tesla Model 3)"
440+
441+
project = "project"
442+
"Project names (e.g., Project Apollo, Manhattan Project)"
443+
444+
trend = "trend"
445+
"Trend indicators (e.g., upward trend, downward trend)"
446+
447+
year = "year"
448+
"Year reference (e.g., 2023, 1999)"
449+
386450

387451
# EntityType and PIIRedactionPolicy share the same values
388452
PIIRedactionPolicy = EntityType
@@ -704,6 +768,10 @@ class SpeakerOptions(BaseModel):
704768
None,
705769
description="Enable or disable two-stage clustering for speaker diarization",
706770
)
771+
long_file_diarization_method: Optional[Literal["standard", "experimental"]] = Field(
772+
None,
773+
description="Diarization method for long files. Options: standard (default), experimental",
774+
)
707775

708776
if pydantic_v2:
709777

@@ -861,7 +929,13 @@ class RawTranscriptionConfig(BaseModel):
861929
"The list of key terms used to generate the transcript with the Slam-1 speech model. Can't be used together with `prompt`."
862930

863931
language_codes: Optional[List[Union[str, LanguageCode]]] = None
864-
"List of language codes detected in the audio file when language detection is enabled"
932+
"""
933+
A list of language codes associated with the transcript.
934+
935+
When submitting a transcript request, this can be used to provide multiple language codes
936+
for multilingual/code-switching audio (equivalent to passing `language_codes` in the
937+
`/v2/transcript` API request body).
938+
"""
865939

866940
language_detection_results: Optional[LanguageDetectionResults] = None
867941
"Language detection results including code switching languages"
@@ -876,6 +950,7 @@ class TranscriptionConfig:
876950
def __init__(
877951
self,
878952
language_code: Optional[Union[str, LanguageCode]] = None,
953+
language_codes: Optional[List[Union[str, LanguageCode]]] = None,
879954
punctuate: Optional[bool] = None,
880955
format_text: Optional[bool] = None,
881956
dual_channel: Optional[bool] = None,
@@ -922,6 +997,7 @@ def __init__(
922997
"""
923998
Args:
924999
language_code: The language of your audio file. Possible values are found in Supported Languages.
1000+
language_codes: A list of language codes for multilingual/code-switching audio.
9251001
punctuate: Enable Automatic Punctuation
9261002
format_text: Enable Text Formatting
9271003
dual_channel: Enable Dual Channel transcription
@@ -969,6 +1045,7 @@ def __init__(
9691045

9701046
# explicit configurations have higher priority if `raw_transcription_config` has been passed as well
9711047
self.language_code = language_code
1048+
self.language_codes = language_codes
9721049
self.punctuate = punctuate
9731050
self.format_text = format_text
9741051
self.dual_channel = dual_channel
@@ -1455,10 +1532,17 @@ def speech_threshold(self, threshold: Optional[float]) -> None:
14551532

14561533
@property
14571534
def language_codes(self) -> Optional[List[Union[str, LanguageCode]]]:
1458-
"Returns the list of language codes detected in the audio file when language detection is enabled."
1535+
"Returns the list of language codes associated with this transcript/config."
14591536

14601537
return self._raw_transcription_config.language_codes
14611538

1539+
@language_codes.setter
1540+
def language_codes(
1541+
self, language_codes: Optional[List[Union[str, LanguageCode]]]
1542+
) -> None:
1543+
"Sets the list of language codes for multilingual/code-switching audio."
1544+
self._raw_transcription_config.language_codes = language_codes
1545+
14621546
@property
14631547
def language_detection_results(self) -> Optional[LanguageDetectionResults]:
14641548
"Returns the language detection results including code switching languages."
@@ -1878,7 +1962,7 @@ class Utterance(UtteranceWord):
18781962
class Chapter(BaseModel):
18791963
summary: str
18801964
headline: str
1881-
gist: str
1965+
gist: Optional[str] = None
18821966
start: int
18831967
end: int
18841968

tests/unit/test_auto_chapters.py

Lines changed: 38 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,3 +81,41 @@ def test_auto_chapters_enabled(httpx_mock: HTTPXMock):
8181
assert transcript_chapter.gist == response_chapter["gist"]
8282
assert transcript_chapter.start == response_chapter["start"]
8383
assert transcript_chapter.end == response_chapter["end"]
84+
85+
86+
def test_auto_chapters_with_missing_gist(httpx_mock: HTTPXMock):
87+
"""
88+
Tests that the SDK can handle Chapter responses where the `gist` field is missing.
89+
The `gist` field is optional in the Chapter model and should default to None.
90+
"""
91+
# Create a mock response with chapters that have missing gist fields
92+
mock_response = factories.generate_dict_factory(AutoChaptersResponseFactory)()
93+
94+
# Remove the gist field from all chapters to simulate backend response without gist
95+
for chapter in mock_response["chapters"]:
96+
del chapter["gist"]
97+
98+
request_body, transcript = unit_test_utils.submit_mock_transcription_request(
99+
httpx_mock,
100+
mock_response=mock_response,
101+
config=aai.TranscriptionConfig(auto_chapters=True),
102+
)
103+
104+
# Check that request body was properly defined
105+
assert request_body.get("auto_chapters") is True
106+
107+
# Check that transcript was properly parsed from JSON response
108+
assert transcript.error is None
109+
assert transcript.chapters is not None
110+
assert len(transcript.chapters) > 0
111+
assert len(transcript.chapters) == len(mock_response["chapters"])
112+
113+
# Verify that chapters can be parsed without gist field
114+
for response_chapter, transcript_chapter in zip(
115+
mock_response["chapters"], transcript.chapters
116+
):
117+
assert transcript_chapter.summary == response_chapter["summary"]
118+
assert transcript_chapter.headline == response_chapter["headline"]
119+
assert transcript_chapter.gist is None # Should be None when missing
120+
assert transcript_chapter.start == response_chapter["start"]
121+
assert transcript_chapter.end == response_chapter["end"]

tests/unit/test_speaker_options.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,3 +127,52 @@ def test_transcription_config_with_two_stage_clustering():
127127
assert config.speaker_labels is True
128128
assert config.speaker_options == speaker_options
129129
assert config.speaker_options.use_two_stage_clustering is False
130+
131+
132+
def test_speaker_options_long_file_diarization_method():
133+
"""Test that SpeakerOptions can be created with long_file_diarization_method."""
134+
speaker_options = aai.SpeakerOptions(long_file_diarization_method="experimental")
135+
assert speaker_options.long_file_diarization_method == "experimental"
136+
137+
138+
def test_speaker_options_long_file_diarization_all_methods():
139+
"""Test all valid values for long_file_diarization_method."""
140+
methods = ["standard", "experimental"]
141+
for method in methods:
142+
speaker_options = aai.SpeakerOptions(long_file_diarization_method=method)
143+
assert speaker_options.long_file_diarization_method == method
144+
145+
146+
def test_transcription_config_with_long_file_experimental_diarization():
147+
"""Test the issue scenario: TranscriptionConfig with experimental diarization."""
148+
speaker_options = aai.SpeakerOptions(long_file_diarization_method="experimental")
149+
150+
config = aai.TranscriptionConfig(
151+
speaker_labels=True,
152+
speaker_options=speaker_options,
153+
)
154+
155+
assert config.speaker_labels is True
156+
assert config.speaker_options == speaker_options
157+
assert config.speaker_options.long_file_diarization_method == "experimental"
158+
assert config.raw.speaker_options.long_file_diarization_method == "experimental"
159+
160+
161+
def test_transcription_config_with_all_speaker_options():
162+
"""Test TranscriptionConfig with all speaker options fields."""
163+
speaker_options = aai.SpeakerOptions(
164+
min_speakers_expected=2,
165+
max_speakers_expected=5,
166+
use_two_stage_clustering=False,
167+
long_file_diarization_method="experimental",
168+
)
169+
170+
config = aai.TranscriptionConfig(
171+
speaker_labels=True,
172+
speaker_options=speaker_options,
173+
)
174+
175+
assert config.speaker_options.min_speakers_expected == 2
176+
assert config.speaker_options.max_speakers_expected == 5
177+
assert config.speaker_options.use_two_stage_clustering is False
178+
assert config.speaker_options.long_file_diarization_method == "experimental"

tests/unit/test_transcriber.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -644,6 +644,36 @@ def test_language_detection(httpx_mock: HTTPXMock):
644644
assert request.get("language_code") is None
645645

646646

647+
def test_language_codes_request(httpx_mock: HTTPXMock):
648+
mock_completed_json = factories.generate_dict_factory(
649+
factories.TranscriptCompletedResponseFactory
650+
)()
651+
652+
httpx_mock.add_response(
653+
url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}",
654+
status_code=httpx.codes.OK,
655+
method="POST",
656+
json=mock_completed_json,
657+
)
658+
659+
httpx_mock.add_response(
660+
url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_completed_json['id']}",
661+
status_code=httpx.codes.OK,
662+
method="GET",
663+
json=mock_completed_json,
664+
)
665+
666+
aai.Transcriber().transcribe(
667+
"https://example.org/audio.wav",
668+
config=aai.TranscriptionConfig(
669+
language_codes=["en", "es"],
670+
),
671+
)
672+
673+
request = json.loads(httpx_mock.get_requests()[0].content.decode())
674+
assert request.get("language_codes") == ["en", "es"]
675+
676+
647677
def test_language_code_string(httpx_mock: HTTPXMock):
648678
mock_completed_json = factories.generate_dict_factory(
649679
factories.TranscriptCompletedResponseFactory

0 commit comments

Comments
 (0)