├── .github └── workflows │ ├── lint.yml │ ├── publish.yml │ └── test.yml ├── LICENSE ├── README.md ├── assemblyai.png ├── assemblyai ├── __init__.py ├── __version__.py ├── api.py ├── client.py ├── extras.py ├── lemur.py ├── streaming │ ├── __init__.py │ └── v3 │ │ ├── __init__.py │ │ ├── client.py │ │ └── models.py ├── transcriber.py └── types.py ├── ruff.toml ├── setup.py ├── tests ├── __init__.py └── unit │ ├── __init__.py │ ├── conftest.py │ ├── factories.py │ ├── test_auto_chapters.py │ ├── test_auto_highlights.py │ ├── test_client.py │ ├── test_config.py │ ├── test_content_safety.py │ ├── test_custom_spelling.py │ ├── test_domains.py │ ├── test_entity_detection.py │ ├── test_extras.py │ ├── test_iab_categories.py │ ├── test_imports.py │ ├── test_lemur.py │ ├── test_multichannel.py │ ├── test_realtime_transcriber.py │ ├── test_redact_pii.py │ ├── test_sentiment_analysis.py │ ├── test_settings.py │ ├── test_streaming.py │ ├── test_summarization.py │ ├── test_transcriber.py │ ├── test_transcript.py │ ├── test_transcript_group.py │ └── unit_test_utils.py └── tox.ini /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Code Quality 2 | 3 | # Controls when the workflow will run 4 | on: 5 | # Triggers the workflow on push or pull request events but only for the "master" branch 6 | push: 7 | branches: ["master"] 8 | pull_request: 9 | 10 | # Allows you to run this workflow manually from the Actions tab 11 | workflow_dispatch: 12 | 13 | concurrency: 14 | # Cancel previous actions from the same PR or branch except 'master' branch. 15 | # See https://docs.github.com/en/actions/using-jobs/using-concurrency and https://docs.github.com/en/actions/learn-github-actions/contexts for more info. 
16 | group: concurrency-group::${{ github.workflow }}::${{ github.event.pull_request.number > 0 && format('pr-{0}', github.event.pull_request.number) || github.ref_name }}${{ github.ref_name == 'master' && format('::{0}', github.run_id) || ''}} 17 | cancel-in-progress: ${{ github.ref_name != 'master' }} 18 | 19 | jobs: 20 | ruff: 21 | needs: [] 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@v3 25 | # Get all changed and modified files. 26 | - uses: dorny/paths-filter@v2 27 | id: filter 28 | with: 29 | list-files: shell 30 | filters: | 31 | python: 32 | - added|modified: 'assemblyai/**/*.py' 33 | # Get count of filtered files. 34 | - run: | 35 | if [ '${{ steps.filter.outputs.python_files }}' != '' ]; then 36 | echo count=$(ls ${{ steps.filter.outputs.python_files }} | wc -l) >> "$GITHUB_OUTPUT" 37 | else 38 | echo count=0 >> "$GITHUB_OUTPUT" 39 | fi 40 | id: counter 41 | if: ${{ steps.filter.outputs.python == 'true' }} 42 | shell: bash 43 | name: Run count files 44 | # Run ruff on filtered files if there are any. 45 | - uses: chartboost/ruff-action@v1 46 | name: Run 'ruff format --check --config ./ruff.toml' 47 | if: ${{ steps.counter.outputs.count > 0 }} 48 | with: 49 | version: 0.3.5 50 | args: 'format --check --config ./ruff.toml' 51 | src: ${{ steps.filter.outputs.python_files }} 52 | - uses: chartboost/ruff-action@v1 53 | name: Run 'ruff' 54 | if: ${{ steps.counter.outputs.count > 0 }} 55 | with: 56 | version: 0.3.5 57 | args: '--config ./ruff.toml' 58 | src: ${{ steps.filter.outputs.python_files }} 59 | 60 | mypy: 61 | needs: [] 62 | runs-on: ubuntu-latest 63 | steps: 64 | - uses: actions/checkout@v3 65 | # Get all changed and modified files. 66 | - uses: dorny/paths-filter@v2 67 | id: filter 68 | with: 69 | list-files: shell 70 | filters: | 71 | python: 72 | - added|modified: 'assemblyai/**/*.py' 73 | # Get count of filtered files. 
74 | - run: | 75 | if [ '${{ steps.filter.outputs.python_files }}' != '' ]; then 76 | echo count=$(ls ${{ steps.filter.outputs.python_files }} | wc -l) >> "$GITHUB_OUTPUT" 77 | else 78 | echo count=0 >> "$GITHUB_OUTPUT" 79 | fi 80 | id: counter 81 | if: ${{ steps.filter.outputs.python == 'true' }} 82 | shell: bash 83 | name: Run count files 84 | # Run mypy on filtered files if there are any. 85 | - uses: actions/setup-python@v4 86 | if: ${{ steps.counter.outputs.count > 0 }} 87 | with: 88 | python-version: '3.9' 89 | - run: pip install mypy==1.5.1 90 | if: ${{ steps.counter.outputs.count > 0 }} 91 | - run: mypy ${{ steps.filter.outputs.python_files }} --follow-imports=silent --ignore-missing-imports 92 | if: ${{ steps.counter.outputs.count > 0 }} 93 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | build-and-publish: 10 | name: Build and publish Python 🐍 distributions 📦 to PyPI 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Set up Python 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: "3.8" 18 | - name: Install pypa/build 19 | run: >- 20 | python -m 21 | pip install 22 | build 23 | --user 24 | - name: Build a binary wheel and a source tarball 25 | run: >- 26 | python -m 27 | build 28 | --sdist 29 | --wheel 30 | --outdir dist/ 31 | . 
32 | - name: Publish distribution 📦 to Test PyPI 33 | uses: pypa/gh-action-pypi-publish@release/v1 34 | with: 35 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 36 | repository-url: https://test.pypi.org/legacy/ 37 | continue-on-error: true 38 | - name: Publish distribution 📦 to PyPI 39 | if: startsWith(github.ref, 'refs/tags') 40 | uses: pypa/gh-action-pypi-publish@release/v1 41 | with: 42 | password: ${{ secrets.PYPI_API_TOKEN }} 43 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test Python 🐍 Versions + 3rd-party Deps 2 | on: 3 | push: 4 | branches: ["master"] 5 | pull_request: 6 | branches: ["master"] 7 | schedule: 8 | - cron: "0 0 * * *" 9 | workflow_dispatch: 10 | 11 | jobs: 12 | test: 13 | name: Python ${{ matrix.py }} on ${{ matrix.os }} 14 | runs-on: ${{ matrix.os }} 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | py: 19 | - "3.11" 20 | - "3.10" 21 | - "3.9" 22 | os: 23 | - ubuntu-22.04 24 | steps: 25 | - name: Setup python for tox 26 | uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.py }} 29 | - name: Install tox 30 | run: python -m pip install tox 31 | - uses: actions/checkout@v3 32 | with: 33 | fetch-depth: 0 34 | - name: Setup python for test ${{ matrix.py }} 35 | uses: actions/setup-python@v4 36 | with: 37 | python-version: ${{ matrix.py }} 38 | - name: Setup test suite 39 | run: | 40 | sudo apt-get update && sudo apt-get install -y portaudio19-dev 41 | python_version="${{ matrix.py }}" 42 | python_version="${python_version/./}" 43 | tox -f "py$python_version" -vvvv --notest 44 | - name: Run test suite 45 | run: | 46 | python_version="${{ matrix.py }}" 47 | python_version="${python_version/./}" 48 | tox -f "py$python_version" -vvvv --skip-pkg-install 49 | env: 50 | PYTEST_ADDOPTS: "-vv --durations=20" 51 | CI_RUN: "yes" 52 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 AssemblyAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /assemblyai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI/assemblyai-python-sdk/ef8dcc0f300ae09b2b528d65f49e770dcefd6243/assemblyai.png -------------------------------------------------------------------------------- /assemblyai/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import extras
from .__version__ import __version__
from .client import Client
from .lemur import Lemur
from .transcriber import RealtimeTranscriber, Transcriber, Transcript, TranscriptGroup
from .types import (
    AssemblyAIError,
    AudioEncoding,
    AutohighlightResponse,
    AutohighlightResult,
    Chapter,
    ContentSafetyLabel,
    ContentSafetyLabelResult,
    ContentSafetyResponse,
    ContentSafetySeverityScore,
    Entity,
    EntityType,
    IABLabelResult,
    IABResponse,
    IABResult,
    LanguageCode,
    LemurActionItemsResponse,
    LemurError,
    LemurModel,
    LemurPurgeRequest,
    LemurPurgeResponse,
    LemurQuestion,
    LemurQuestionAnswer,
    LemurQuestionResponse,
    LemurSource,
    LemurSourceType,
    LemurStringResponse,
    LemurSummaryResponse,
    LemurTaskResponse,
    LemurTranscriptSource,
    LemurUsage,
    ListTranscriptParameters,
    ListTranscriptResponse,
    PageDetails,
    Paragraph,
    PIIRedactedAudioQuality,
    PIIRedactionPolicy,
    PIISubstitutionPolicy,
    RawTranscriptionConfig,
    RealtimeError,
    RealtimeFinalTranscript,
    RealtimePartialTranscript,
    RealtimeSessionInformation,
    RealtimeSessionOpened,
    RealtimeTranscript,
    RealtimeWord,
    Sentence,
    Sentiment,
    SentimentType,
    Settings,
    SpeechModel,
    StatusResult,
    SummarizationModel,
    SummarizationType,
    Timestamp,
    TranscriptError,
    TranscriptionConfig,
    TranscriptItem,
    TranscriptStatus,
    Utterance,
    UtteranceWord,
    Word,
    WordBoost,
    WordSearchMatch,
)

# Instantiated at import time; reads ASSEMBLYAI_API_KEY et al. from the
# environment (see `types.Settings`).  `Client.get_default()` re-creates its
# cached client whenever this object's values change.
settings = Settings()
"""Global settings object that applies to all classes that use the `Client` class."""


# Public re-export surface of the package; everything users should import
# from `assemblyai` directly is listed here.
__all__ = [
    # types
    "AssemblyAIError",
    "AudioEncoding",
    "AutohighlightResponse",
    "AutohighlightResult",
    "Chapter",
    "Client",
    "ContentSafetyLabel",
    "ContentSafetyLabelResult",
    "ContentSafetyResponse",
    "ContentSafetySeverityScore",
    "Entity",
    "EntityType",
    "IABLabelResult",
    "IABResponse",
    "IABResult",
    "LanguageCode",
    "Lemur",
    "LemurActionItemsResponse",
    "LemurError",
    "LemurModel",
    "LemurPurgeRequest",
    "LemurPurgeResponse",
    "LemurSource",
    "LemurSourceType",
    "LemurTranscriptSource",
    "LemurQuestion",
    "LemurQuestionAnswer",
    "LemurQuestionResponse",
    "LemurStringResponse",
    "LemurSummaryResponse",
    "LemurTaskResponse",
    "LemurUsage",
    "ListTranscriptParameters",
    "ListTranscriptResponse",
    "PageDetails",
    "Sentence",
    "Sentiment",
    "SentimentType",
    "Settings",
    "SpeechModel",
    "StatusResult",
    "SummarizationModel",
    "SummarizationType",
    "Timestamp",
    "Transcriber",
    "TranscriptionConfig",
    "Transcript",
    "TranscriptError",
    "TranscriptGroup",
    "TranscriptItem",
    "TranscriptStatus",
    "Utterance",
    "UtteranceWord",
    "Paragraph",
    "PIIRedactedAudioQuality",
    "PIISubstitutionPolicy",
    "PIIRedactionPolicy",
    "RawTranscriptionConfig",
    "Word",
    "WordBoost",
    "WordSearchMatch",
    "RealtimeTranscriber",
    "RealtimeError",
    "RealtimeFinalTranscript",
    "RealtimePartialTranscript",
    "RealtimeSessionInformation",
    "RealtimeSessionOpened",
    "RealtimeTranscript",
    "RealtimeWord",
    # package globals
    "settings",
    # packages
    "extras",
    # version
    "__version__",
]
--------------------------------------------------------------------------------
/assemblyai/__version__.py:
--------------------------------------------------------------------------------
__version__ = "0.41.1"
--------------------------------------------------------------------------------
/assemblyai/api.py:
-------------------------------------------------------------------------------- 1 | from typing import BinaryIO, List, Optional, Union 2 | from urllib.parse import urlencode 3 | 4 | import httpx 5 | 6 | from . import types 7 | 8 | ENDPOINT_TRANSCRIPT = "/v2/transcript" 9 | ENDPOINT_UPLOAD = "/v2/upload" 10 | ENDPOINT_LEMUR_BASE = "/lemur/v3" 11 | ENDPOINT_LEMUR = f"{ENDPOINT_LEMUR_BASE}/generate" 12 | ENDPOINT_REALTIME_WEBSOCKET = "/v2/realtime/ws" 13 | ENDPOINT_REALTIME_TOKEN = "/v2/realtime/token" 14 | 15 | 16 | def _get_error_message(response: httpx.Response) -> str: 17 | """ 18 | Tries to retrieve the `error` field if the response is JSON, otherwise 19 | returns the response text. 20 | 21 | Args: 22 | `response`: the HTTP response 23 | 24 | Returns: the error message 25 | """ 26 | 27 | try: 28 | return response.json()["error"] 29 | except Exception: 30 | return f"\nReason: {response.text}\nRequest: {response.request}" 31 | 32 | 33 | def create_transcript( 34 | client: httpx.Client, 35 | request: types.TranscriptRequest, 36 | ) -> types.TranscriptResponse: 37 | response = client.post( 38 | ENDPOINT_TRANSCRIPT, 39 | json=request.dict( 40 | exclude_none=True, 41 | by_alias=True, 42 | ), 43 | ) 44 | if response.status_code != httpx.codes.OK: 45 | raise types.TranscriptError( 46 | f"failed to transcribe url {request.audio_url}: {_get_error_message(response)}", 47 | response.status_code, 48 | ) 49 | 50 | return types.TranscriptResponse.parse_obj(response.json()) 51 | 52 | 53 | def get_transcript( 54 | client: httpx.Client, 55 | transcript_id: str, 56 | ) -> types.TranscriptResponse: 57 | response = client.get( 58 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}", 59 | ) 60 | 61 | if response.status_code != httpx.codes.OK: 62 | raise types.TranscriptError( 63 | f"failed to retrieve transcript {transcript_id}: {_get_error_message(response)}", 64 | response.status_code, 65 | ) 66 | 67 | return types.TranscriptResponse.parse_obj(response.json()) 68 | 69 | 70 | def 
delete_transcript( 71 | client: httpx.Client, 72 | transcript_id: str, 73 | ) -> types.TranscriptResponse: 74 | response = client.delete( 75 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}", 76 | ) 77 | 78 | if response.status_code != httpx.codes.OK: 79 | raise types.TranscriptError( 80 | f"failed to delete transcript {transcript_id}: {_get_error_message(response)}", 81 | response.status_code, 82 | ) 83 | 84 | return types.TranscriptResponse.parse_obj(response.json()) 85 | 86 | 87 | def upload_file( 88 | client: httpx.Client, 89 | audio_file: BinaryIO, 90 | ) -> str: 91 | """ 92 | Uploads the given file. 93 | 94 | Args: 95 | `client`: the HTTP client 96 | `audio_file`: an opened file (in binary mode) 97 | 98 | Returns: The URL of the uploaded audio file. 99 | """ 100 | 101 | response = client.post( 102 | ENDPOINT_UPLOAD, 103 | content=audio_file, 104 | ) 105 | 106 | if response.status_code != httpx.codes.OK: 107 | raise types.TranscriptError( 108 | f"Failed to upload audio file: {_get_error_message(response)}", 109 | response.status_code, 110 | ) 111 | 112 | return response.json()["upload_url"] 113 | 114 | 115 | def export_subtitles_srt( 116 | client: httpx.Client, 117 | transcript_id: str, 118 | chars_per_caption: Optional[int], 119 | ) -> str: 120 | params = {} 121 | 122 | if chars_per_caption: 123 | params = { 124 | "chars_per_caption": chars_per_caption, 125 | } 126 | 127 | response = client.get( 128 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/srt", 129 | params=params, 130 | ) 131 | 132 | if response.status_code != httpx.codes.OK: 133 | raise types.TranscriptError( 134 | f"failed to export SRT for transcript {transcript_id}: {_get_error_message(response)}", 135 | response.status_code, 136 | ) 137 | 138 | return response.text 139 | 140 | 141 | def export_subtitles_vtt( 142 | client: httpx.Client, 143 | transcript_id: str, 144 | chars_per_caption: Optional[int], 145 | ) -> str: 146 | params = {} 147 | 148 | if chars_per_caption: 149 | params = { 150 | 
"chars_per_caption": chars_per_caption, 151 | } 152 | 153 | response = client.get( 154 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/vtt", 155 | params=params, 156 | ) 157 | 158 | if response.status_code != httpx.codes.OK: 159 | raise types.TranscriptError( 160 | f"failed to export VTT for transcript {transcript_id}: {_get_error_message(response)}", 161 | response.status_code, 162 | ) 163 | 164 | return response.text 165 | 166 | 167 | def word_search( 168 | client: httpx.Client, 169 | transcript_id: str, 170 | words: List[str], 171 | ) -> types.WordSearchMatchResponse: 172 | response = client.get( 173 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/word-search", 174 | params=urlencode( 175 | { 176 | "words": ",".join(words), 177 | } 178 | ), 179 | ) 180 | 181 | if response.status_code != httpx.codes.OK: 182 | raise types.TranscriptError( 183 | f"failed to search words in transcript {transcript_id}: {_get_error_message(response)}", 184 | response.status_code, 185 | ) 186 | 187 | return types.WordSearchMatchResponse.parse_obj(response.json()) 188 | 189 | 190 | def get_redacted_audio( 191 | client: httpx.Client, transcript_id: str 192 | ) -> types.RedactedAudioResponse: 193 | """ 194 | Retrieves the object containing the redacted audio URL for the given transcript. 
195 | 196 | Raises: 197 | RedactedAudioIncompleteError: If response indicates that the redacted audio is still processing 198 | RedactedAudioUnavailableError: If response indicates that the redacted audio is not available 199 | TranscriptError: If we fail to get a valid response from the API at all 200 | 201 | Returns: 202 | `RedactedAudioResponse`, which contains the URL of the redacted audio 203 | """ 204 | 205 | response = client.get(f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/redacted-audio") 206 | 207 | if response.status_code == httpx.codes.ACCEPTED: 208 | raise types.RedactedAudioIncompleteError( 209 | f"redacted audio for transcript {transcript_id} is not ready yet", 210 | response.status_code, 211 | ) 212 | 213 | if response.status_code == httpx.codes.BAD_REQUEST: 214 | raise types.RedactedAudioExpiredError( 215 | f"redacted audio for transcript {transcript_id} is no longer available", 216 | response.status_code, 217 | ) 218 | 219 | if response.status_code != httpx.codes.OK: 220 | raise types.TranscriptError( 221 | f"failed to retrieve redacted audio for transcript {transcript_id}: {_get_error_message(response)}", 222 | response.status_code, 223 | ) 224 | 225 | return types.RedactedAudioResponse.parse_obj(response.json()) 226 | 227 | 228 | def get_sentences( 229 | client: httpx.Client, 230 | transcript_id: str, 231 | ) -> types.SentencesResponse: 232 | response = client.get( 233 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/sentences", 234 | ) 235 | 236 | if response.status_code != httpx.codes.OK: 237 | raise types.TranscriptError( 238 | f"failed to retrieve sentences for transcript {transcript_id}: {_get_error_message(response)}", 239 | response.status_code, 240 | ) 241 | 242 | return types.SentencesResponse.parse_obj(response.json()) 243 | 244 | 245 | def get_paragraphs( 246 | client: httpx.Client, 247 | transcript_id: str, 248 | ) -> types.ParagraphsResponse: 249 | response = client.get( 250 | f"{ENDPOINT_TRANSCRIPT}/{transcript_id}/paragraphs", 251 | ) 252 | 
253 | if response.status_code != httpx.codes.OK: 254 | raise types.TranscriptError( 255 | f"failed to retrieve paragraphs for transcript {transcript_id}: {_get_error_message(response)}", 256 | response.status_code, 257 | ) 258 | 259 | return types.ParagraphsResponse.parse_obj(response.json()) 260 | 261 | 262 | def list_transcripts( 263 | client: httpx.Client, 264 | params: Optional[types.ListTranscriptParameters], 265 | ) -> types.ListTranscriptResponse: 266 | response = client.get( 267 | ENDPOINT_TRANSCRIPT, 268 | params=( 269 | params.dict( 270 | exclude_none=True, 271 | ) 272 | if params 273 | else None 274 | ), 275 | ) 276 | 277 | if response.status_code != httpx.codes.OK: 278 | raise types.AssemblyAIError( 279 | f"failed to retrieve transcripts: {_get_error_message(response)}", 280 | response.status_code, 281 | ) 282 | 283 | return types.ListTranscriptResponse.parse_obj(response.json()) 284 | 285 | 286 | def lemur_question( 287 | client: httpx.Client, 288 | request: types.LemurQuestionRequest, 289 | http_timeout: Optional[float], 290 | ) -> types.LemurQuestionResponse: 291 | response = client.post( 292 | f"{ENDPOINT_LEMUR}/question-answer", 293 | json=request.dict( 294 | exclude_none=True, 295 | ), 296 | timeout=http_timeout, 297 | ) 298 | 299 | if response.status_code != httpx.codes.OK: 300 | raise types.LemurError( 301 | f"failed to call Lemur questions: {_get_error_message(response)}", 302 | response.status_code, 303 | ) 304 | 305 | return types.LemurQuestionResponse.parse_obj(response.json()) 306 | 307 | 308 | def lemur_summarize( 309 | client: httpx.Client, 310 | request: types.LemurSummaryRequest, 311 | http_timeout: Optional[float], 312 | ) -> types.LemurSummaryResponse: 313 | response = client.post( 314 | f"{ENDPOINT_LEMUR}/summary", 315 | json=request.dict( 316 | exclude_none=True, 317 | ), 318 | timeout=http_timeout, 319 | ) 320 | 321 | if response.status_code != httpx.codes.OK: 322 | raise types.LemurError( 323 | f"failed to call Lemur summary: 
{_get_error_message(response)}", 324 | response.status_code, 325 | ) 326 | 327 | return types.LemurSummaryResponse.parse_obj(response.json()) 328 | 329 | 330 | def lemur_action_items( 331 | client: httpx.Client, 332 | request: types.LemurActionItemsRequest, 333 | http_timeout: Optional[float], 334 | ) -> types.LemurActionItemsResponse: 335 | response = client.post( 336 | f"{ENDPOINT_LEMUR}/action-items", 337 | json=request.dict( 338 | exclude_none=True, 339 | ), 340 | timeout=http_timeout, 341 | ) 342 | 343 | if response.status_code != httpx.codes.OK: 344 | raise types.LemurError( 345 | f"failed to call Lemur action items: {_get_error_message(response)}", 346 | response.status_code, 347 | ) 348 | 349 | return types.LemurActionItemsResponse.parse_obj(response.json()) 350 | 351 | 352 | def lemur_task( 353 | client: httpx.Client, 354 | request: types.LemurTaskRequest, 355 | http_timeout: Optional[float], 356 | ) -> types.LemurTaskResponse: 357 | response = client.post( 358 | f"{ENDPOINT_LEMUR}/task", 359 | json=request.dict( 360 | exclude_none=True, 361 | ), 362 | timeout=http_timeout, 363 | ) 364 | 365 | if response.status_code != httpx.codes.OK: 366 | raise types.LemurError( 367 | f"failed to call Lemur task: {_get_error_message(response)}", 368 | response.status_code, 369 | ) 370 | 371 | return types.LemurTaskResponse.parse_obj(response.json()) 372 | 373 | 374 | def lemur_purge_request_data( 375 | client: httpx.Client, 376 | request: types.LemurPurgeRequest, 377 | http_timeout: Optional[float], 378 | ) -> types.LemurPurgeResponse: 379 | response = client.delete( 380 | f"{ENDPOINT_LEMUR_BASE}/{request.request_id}", 381 | timeout=http_timeout, 382 | ) 383 | 384 | if response.status_code != httpx.codes.OK: 385 | raise types.LemurError( 386 | f"Failed to purge LeMUR request data for provided request ID: {request.request_id}. 
Error: {_get_error_message(response)}", 387 | response.status_code, 388 | ) 389 | 390 | return types.LemurPurgeResponse.parse_obj(response.json()) 391 | 392 | 393 | def lemur_get_response_data( 394 | client: httpx.Client, 395 | request_id: str, 396 | http_timeout: Optional[float], 397 | ) -> Union[ 398 | types.LemurStringResponse, 399 | types.LemurQuestionResponse, 400 | ]: 401 | response = client.get( 402 | f"{ENDPOINT_LEMUR_BASE}/{request_id}", 403 | timeout=http_timeout, 404 | ) 405 | 406 | if response.status_code != httpx.codes.OK: 407 | raise types.LemurError( 408 | f"Failed to get LeMUR response data for provided request ID: {request_id}. Error: {_get_error_message(response)}", 409 | response.status_code, 410 | ) 411 | 412 | json_data = response.json() 413 | 414 | if isinstance(json_data.get("response"), list): 415 | return types.LemurQuestionResponse.parse_obj(json_data) 416 | 417 | return types.LemurStringResponse.parse_obj(json_data) 418 | 419 | 420 | def create_temporary_token( 421 | client: httpx.Client, 422 | request: types.RealtimeCreateTemporaryTokenRequest, 423 | http_timeout: Optional[float], 424 | ) -> str: 425 | response = client.post( 426 | f"{ENDPOINT_REALTIME_TOKEN}", 427 | json=request.dict(exclude_none=True), 428 | timeout=http_timeout, 429 | ) 430 | 431 | if response.status_code != httpx.codes.OK: 432 | raise types.AssemblyAIError( 433 | f"Failed to create temporary token: {_get_error_message(response)}", 434 | response.status_code, 435 | ) 436 | 437 | data = types.RealtimeCreateTemporaryTokenResponse.parse_obj(response.json()) 438 | return data.token 439 | -------------------------------------------------------------------------------- /assemblyai/client.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import threading 3 | from typing import ClassVar, Optional 4 | 5 | import httpx 6 | 7 | from . 
import types 8 | from .__version__ import __version__ 9 | 10 | 11 | class Client: 12 | _default: ClassVar[Optional["Client"]] = None 13 | _lock: ClassVar[threading.Lock] = threading.Lock() 14 | 15 | def __init__( 16 | self, 17 | *, 18 | settings: types.Settings, 19 | api_key_required: bool = True, 20 | ) -> None: 21 | """ 22 | Creates the AssemblyAI client. 23 | 24 | Args: 25 | settings: The settings to use for the client. 26 | api_key_required: If an API key is required (either as environment variable or the global settings). 27 | Can be set to `False` if a different authentication method is used, e.g., a temporary token. 28 | """ 29 | 30 | self._settings = settings.copy() 31 | 32 | if api_key_required and not self._settings.api_key: 33 | raise ValueError( 34 | "Please provide an API key via the ASSEMBLYAI_API_KEY environment variable or the global settings." 35 | ) 36 | 37 | vi = sys.version_info 38 | python_version = f"{vi.major}.{vi.minor}.{vi.micro}" 39 | user_agent = f"{httpx._client.USER_AGENT} AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" 40 | 41 | headers = {"user-agent": user_agent} 42 | if self._settings.api_key: 43 | headers["authorization"] = self._settings.api_key 44 | 45 | self._last_response: Optional[httpx.Response] = None 46 | 47 | def _store_response(response): 48 | self._last_response = response 49 | 50 | self._http_client = httpx.Client( 51 | base_url=self.settings.base_url, 52 | headers=headers, 53 | timeout=self.settings.http_timeout, 54 | event_hooks={"response": [_store_response]}, 55 | ) 56 | 57 | @property 58 | def last_response(self) -> Optional[httpx.Response]: 59 | """ 60 | Get the last HTTP response, corresponding to the last request sent from this client. 61 | 62 | Returns: 63 | The last HTTP response. 64 | """ 65 | return self._last_response 66 | 67 | @property 68 | def settings(self) -> types.Settings: 69 | """ 70 | Get the current settings. 71 | 72 | Returns: 73 | The current settings. 
74 | """ 75 | 76 | return self._settings 77 | 78 | @property 79 | def http_client(self) -> httpx.Client: 80 | """ 81 | Get the current HTTP client. 82 | 83 | Returns: 84 | The current HTTP client. 85 | """ 86 | 87 | return self._http_client 88 | 89 | @classmethod 90 | def get_default(cls, api_key_required: bool = True): 91 | """ 92 | Return the default client. 93 | 94 | Args: 95 | api_key_required: If the default client requires an API key. 96 | 97 | Returns: 98 | The default client with the default settings. 99 | """ 100 | from . import settings as default_settings 101 | 102 | if cls._default is None or cls._default.settings != default_settings: 103 | with cls._lock: 104 | if cls._default is None or cls._default.settings != default_settings: 105 | cls._default = cls( 106 | settings=default_settings, api_key_required=api_key_required 107 | ) 108 | 109 | return cls._default 110 | -------------------------------------------------------------------------------- /assemblyai/extras.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import BinaryIO, Generator, Optional 3 | from warnings import warn 4 | 5 | from . import api 6 | from .client import Client 7 | 8 | 9 | class AssemblyAIExtrasNotInstalledError(ImportError): 10 | def __init__( 11 | self, 12 | msg=""" 13 | You must install the extras for this SDK to use this feature. 14 | Run `pip install "assemblyai[extras]"` to install the extras. 15 | Make sure to install `apt install portaudio19-dev` (Debian/Ubuntu) or 16 | `brew install portaudio` (MacOS) before installing the extras 17 | """, 18 | *args, 19 | **kwargs, 20 | ): 21 | super().__init__(msg, *args, **kwargs) 22 | 23 | 24 | class MicrophoneStream: 25 | def __init__(self, sample_rate: int = 44_100, device_index: Optional[int] = None): 26 | """ 27 | Creates a stream of audio from the microphone. 28 | 29 | Args: 30 | sample_rate: The sample rate to record audio at. 
31 | device_index: The index of the input device to use. If None, uses the default device. 32 | """ 33 | try: 34 | import pyaudio 35 | except ImportError: 36 | raise AssemblyAIExtrasNotInstalledError 37 | 38 | self._pyaudio = pyaudio.PyAudio() 39 | self.sample_rate = sample_rate 40 | 41 | self._chunk_size = int(self.sample_rate * 0.1) 42 | self._stream = self._pyaudio.open( 43 | format=pyaudio.paInt16, 44 | channels=1, 45 | rate=sample_rate, 46 | input=True, 47 | frames_per_buffer=self._chunk_size, 48 | input_device_index=device_index, 49 | ) 50 | 51 | self._open = True 52 | 53 | def __iter__(self): 54 | """ 55 | Returns the iterator object. 56 | """ 57 | 58 | return self 59 | 60 | def __next__(self): 61 | """ 62 | Reads a chunk of audio from the microphone. 63 | """ 64 | if not self._open: 65 | raise StopIteration 66 | 67 | try: 68 | return self._stream.read(self._chunk_size) 69 | except KeyboardInterrupt: 70 | raise StopIteration 71 | 72 | def close(self): 73 | """ 74 | Closes the stream. 75 | """ 76 | 77 | self._open = False 78 | 79 | if self._stream.is_active(): 80 | self._stream.stop_stream() 81 | 82 | self._stream.close() 83 | self._pyaudio.terminate() 84 | 85 | 86 | def stream_file( 87 | filepath: str, 88 | sample_rate: int, 89 | ) -> Generator[bytes, None, None]: 90 | """ 91 | Mimics a stream of audio data by reading it chunk by chunk from a file. 92 | 93 | NOTE: Only supports WAV/PCM16 files as of now. 94 | 95 | Args: 96 | filepath: The path to the file to stream. 97 | sample_rate: The sample rate of the audio file. 98 | 99 | Returns: A generator that yields chunks of audio data. 
100 | """ 101 | chunk_duration = 0.3 102 | with open(filepath, "rb") as f: 103 | while True: 104 | # send in 300ms segments (2 bytes per frame) 105 | data = f.read(int(sample_rate * chunk_duration) * 2) 106 | 107 | if not data: 108 | break 109 | 110 | yield data 111 | 112 | time.sleep(chunk_duration) 113 | 114 | 115 | def file_from_stream(data: BinaryIO) -> str: 116 | """ 117 | DeprecationWarning: `file_from_stream()` is deprecated and will be removed in 1.0.0. Use `Transcriber.upload_file()` instead. 118 | 119 | Uploads the given stream and returns the uploaded audio url. 120 | 121 | This function can be used to transcribe data that's already 122 | available in memory. 123 | 124 | Example: 125 | ``` 126 | upload_url = aai.extras.file_from_stream(data) 127 | 128 | transcriber = aai.Transcriber() 129 | transcript = transcriber.transcribe(upload_url) 130 | ``` 131 | 132 | Args: 133 | `data`: A file-like object (in binary mode) 134 | """ 135 | warn( 136 | "`file_from_stream()` is deprecated and will be removed in 1.0.0. Use `Transcriber.upload_file()` instead.", 137 | DeprecationWarning, 138 | stacklevel=2, 139 | ) 140 | return api.upload_file( 141 | client=Client.get_default().http_client, 142 | audio_file=data, 143 | ) 144 | -------------------------------------------------------------------------------- /assemblyai/lemur.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import concurrent.futures 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | from . import api, types 7 | from . 
class _LemurImpl:
    """
    Internal implementation behind `Lemur`.

    Normalizes the user-supplied sources once and forwards each LeMUR
    operation to the corresponding low-level `api` call.
    """

    def __init__(
        self,
        *,
        client: _client.Client,
        sources: Optional[List[types.LemurSource]],
    ) -> None:
        self._client = client

        # Convert high-level sources into request objects up-front.
        # `None` behaves like "no sources" (e.g. when only `input_text` is used).
        self._sources = (
            [types.LemurSourceRequest.from_lemur_source(s) for s in sources]
            if sources is not None
            else []
        )

    def question(
        self,
        questions: List[types.LemurQuestion],
        context: Optional[Union[str, Dict[str, Any]]],
        timeout: Optional[float],
        final_model: Optional[types.LemurModel],
        max_output_size: Optional[int],
        temperature: Optional[float],
        input_text: Optional[str],
    ) -> types.LemurQuestionResponse:
        """Submit a Question & Answer request and return the response."""
        request = types.LemurQuestionRequest(
            sources=self._sources,
            questions=questions,
            context=context,
            final_model=final_model,
            max_output_size=max_output_size,
            temperature=temperature,
            input_text=input_text,
        )
        return api.lemur_question(
            client=self._client.http_client,
            request=request,
            http_timeout=timeout,
        )

    def summarize(
        self,
        context: Optional[Union[str, Dict[str, Any]]],
        answer_format: Optional[str],
        final_model: Optional[types.LemurModel],
        max_output_size: Optional[int],
        timeout: Optional[float],
        temperature: Optional[float],
        input_text: Optional[str],
    ) -> types.LemurSummaryResponse:
        """Submit a summarization request and return the response."""
        request = types.LemurSummaryRequest(
            sources=self._sources,
            context=context,
            answer_format=answer_format,
            final_model=final_model,
            max_output_size=max_output_size,
            temperature=temperature,
            input_text=input_text,
        )
        return api.lemur_summarize(
            client=self._client.http_client,
            request=request,
            http_timeout=timeout,
        )

    def action_items(
        self,
        context: Optional[Union[str, Dict[str, Any]]],
        answer_format: Optional[str],
        final_model: Optional[types.LemurModel],
        max_output_size: Optional[int],
        timeout: Optional[float],
        temperature: Optional[float],
        input_text: Optional[str],
    ) -> types.LemurActionItemsResponse:
        """Submit an action-items request and return the response."""
        request = types.LemurActionItemsRequest(
            sources=self._sources,
            context=context,
            answer_format=answer_format,
            final_model=final_model,
            max_output_size=max_output_size,
            temperature=temperature,
            input_text=input_text,
        )
        return api.lemur_action_items(
            client=self._client.http_client,
            request=request,
            http_timeout=timeout,
        )

    def task(
        self,
        prompt: str,
        context: Optional[Union[str, Dict[str, Any]]],
        final_model: Optional[types.LemurModel],
        max_output_size: Optional[int],
        timeout: Optional[float],
        temperature: Optional[float],
        input_text: Optional[str],
    ) -> types.LemurTaskResponse:
        """Submit a custom-prompt task request and return the response.

        Fix: added the `-> types.LemurTaskResponse` return annotation that every
        sibling method already declares (was previously unannotated).
        """
        request = types.LemurTaskRequest(
            sources=self._sources,
            prompt=prompt,
            context=context,
            final_model=final_model,
            max_output_size=max_output_size,
            temperature=temperature,
            input_text=input_text,
        )
        return api.lemur_task(
            client=self._client.http_client,
            request=request,
            http_timeout=timeout,
        )

    @classmethod
    def purge_request_data(
        cls,
        request_id: str,
        timeout: Optional[float] = None,
    ) -> types.LemurPurgeResponse:
        """Purge previously sent LeMUR request data by request ID.

        Uses the default client since no instance state is required.
        """
        return api.lemur_purge_request_data(
            client=_client.Client.get_default().http_client,
            request=types.LemurPurgeRequest(
                request_id=request_id,
            ),
            http_timeout=timeout,
        )

    def get_response_data(
        self,
        request_id: str,
        timeout: Optional[float] = None,
    ) -> Union[
        types.LemurStringResponse,
        types.LemurQuestionResponse,
    ]:
        """Retrieve a previously generated LeMUR response by request ID."""
        return api.lemur_get_response_data(
            client=_client.Client.get_default().http_client,
            request_id=request_id,
            http_timeout=timeout,
        )
class Lemur:
    """
    AssemblyAI's LeMUR (Leveraging Large Language Models to Understand Recognized
    Speech) framework to process audio files with an LLM.

    See https://www.assemblyai.com/docs/Models/lemur for more information.
    """

    def __init__(
        self,
        sources: Optional[List[types.LemurSource]] = None,
        client: Optional[_client.Client] = None,
    ) -> None:
        """
        Creates a new LeMUR instance to process audio files with an LLM.

        Args:
            sources: One or a list of sources to process (e.g. a `Transcript` or a `TranscriptGroup`)
            client: The client to use for the LeMUR instance. If not provided, the default client will be used
        """
        self._client = client if client is not None else _client.Client.get_default()
        self._impl = _LemurImpl(client=self._client, sources=sources)
        # Executor backing all *_async variants of this instance.
        self._executor = concurrent.futures.ThreadPoolExecutor()

    def question(
        self,
        questions: Union[types.LemurQuestion, List[types.LemurQuestion]],
        context: Optional[Union[str, Dict[str, Any]]] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> types.LemurQuestionResponse:
        """
        Ask one or many free-form questions about the attached transcripts.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            questions: One or a list of questions to ask.
            context: Context shared among all questions (string or dictionary).
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the answer(s).
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: One or a list of answer objects.
        """
        question_list = questions if isinstance(questions, list) else [questions]
        options = {
            "context": context,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._impl.question(questions=question_list, **options)

    def question_async(
        self,
        questions: Union[types.LemurQuestion, List[types.LemurQuestion]],
        context: Optional[Union[str, Dict[str, Any]]] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> concurrent.futures.Future[types.LemurQuestionResponse]:
        """
        Ask one or many free-form questions about the attached transcripts,
        without blocking.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            questions: One or a list of questions to ask.
            context: Context shared among all questions (string or dictionary).
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the answer(s).
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: A future resolving to one or a list of answer objects.
        """
        question_list = questions if isinstance(questions, list) else [questions]
        options = {
            "context": context,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._executor.submit(
            self._impl.question, questions=question_list, **options
        )

    def summarize(
        self,
        context: Optional[Union[str, Dict[str, Any]]] = None,
        answer_format: Optional[str] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> types.LemurSummaryResponse:
        """
        Distill the attached audio into a few impactful sentences, optionally
        guided by a context and an output format described in human language.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            context: An optional context on the transcript.
            answer_format: The format on how the summary shall be summarized.
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the summary.
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: The summary as a string.
        """
        options = {
            "context": context,
            "answer_format": answer_format,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._impl.summarize(**options)

    def summarize_async(
        self,
        context: Optional[Union[str, Dict[str, Any]]] = None,
        answer_format: Optional[str] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> concurrent.futures.Future[types.LemurSummaryResponse]:
        """
        Distill the attached audio into a few impactful sentences, without blocking.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            context: An optional context on the transcript.
            answer_format: The format on how the summary shall be summarized.
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the summary.
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: A future resolving to the summary as a string.
        """
        options = {
            "context": context,
            "answer_format": answer_format,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._executor.submit(self._impl.summarize, **options)

    def action_items(
        self,
        context: Optional[Union[str, Dict[str, Any]]] = None,
        answer_format: Optional[str] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> types.LemurActionItemsResponse:
        """
        Generate action items from one or many transcripts.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            context: An optional context on the transcript.
            answer_format: The preferred format for the result action items.
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the action items response.
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: The action items as a string.
        """
        options = {
            "context": context,
            "answer_format": answer_format,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._impl.action_items(**options)

    def action_items_async(
        self,
        context: Optional[Union[str, Dict[str, Any]]] = None,
        answer_format: Optional[str] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> concurrent.futures.Future[types.LemurActionItemsResponse]:
        """
        Generate action items from one or many transcripts, without blocking.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            context: An optional context on the transcript.
            answer_format: The preferred format for the result action items.
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the action items response.
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: A future resolving to the action items as a string.
        """
        options = {
            "context": context,
            "answer_format": answer_format,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._executor.submit(self._impl.action_items, **options)

    def task(
        self,
        prompt: str,
        context: Optional[Union[str, Dict[str, Any]]] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> types.LemurTaskResponse:
        """
        Submit a custom prompt to the model.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            prompt: The prompt to use for this task.
            context: An optional context on the transcript.
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the task.
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: A response to a question or task submitted via custom prompt (with source transcripts or other sources taken into the context)
        """
        options = {
            "context": context,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._impl.task(prompt=prompt, **options)

    def task_async(
        self,
        prompt: str,
        context: Optional[Union[str, Dict[str, Any]]] = None,
        final_model: Optional[types.LemurModel] = None,
        max_output_size: Optional[int] = None,
        timeout: Optional[float] = None,
        temperature: Optional[float] = None,
        input_text: Optional[str] = None,
    ) -> concurrent.futures.Future[types.LemurTaskResponse]:
        """
        Submit a custom prompt to the model, without blocking.

        See also Best Practices on LeMUR: https://www.assemblyai.com/docs/Guides/lemur_best_practices

        Args:
            prompt: The prompt to use for this task.
            context: An optional context on the transcript.
            final_model: The model used for the final prompt after compression is performed.
            max_output_size: Max output size in tokens.
            timeout: The timeout in seconds to wait for the task.
            temperature: 0 is the most deterministic response, 1 the least deterministic.
            input_text: Custom formatted transcript data. Use instead of transcript_ids.

        Returns: A future resolving to the response for the custom prompt.
        """
        options = {
            "context": context,
            "final_model": final_model,
            "max_output_size": max_output_size,
            "timeout": timeout,
            "temperature": temperature,
            "input_text": input_text,
        }
        return self._executor.submit(self._impl.task, prompt=prompt, **options)
500 | 501 | Returns: A response to a question or task submitted via custom prompt (with source transcripts or other sources taken into the context) 502 | """ 503 | 504 | return self._executor.submit( 505 | self._impl.task, 506 | prompt=prompt, 507 | context=context, 508 | final_model=final_model, 509 | max_output_size=max_output_size, 510 | timeout=timeout, 511 | temperature=temperature, 512 | input_text=input_text, 513 | ) 514 | 515 | @classmethod 516 | def purge_request_data( 517 | cls, 518 | request_id: str, 519 | timeout: Optional[float] = None, 520 | ) -> types.LemurPurgeResponse: 521 | """ 522 | Purge sent LeMUR request data that was previously sent. 523 | 524 | Args: 525 | request_id: The request ID that was returned to you from the original LeMUR request that should be purged. 526 | 527 | Returns: A response saying whether the LeMUR request data was successfully purged. 528 | """ 529 | return _LemurImpl.purge_request_data( 530 | request_id=request_id, 531 | timeout=timeout, 532 | ) 533 | 534 | @classmethod 535 | def purge_request_data_async( 536 | cls, 537 | request_id: str, 538 | timeout: Optional[float] = None, 539 | ) -> concurrent.futures.Future[types.LemurPurgeResponse]: 540 | """ 541 | Purge sent LeMUR request data that was previously sent. 542 | 543 | Args: 544 | request_id: The request ID that was returned to you from the original LeMUR request that should be purged. 545 | 546 | Returns: A response saying whether the LeMUR request data was successfully purged. 
547 | """ 548 | with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: 549 | response_future = executor.submit( 550 | _LemurImpl.purge_request_data, 551 | request_id=request_id, 552 | timeout=timeout, 553 | ) 554 | return response_future 555 | 556 | def get_response_data( 557 | self, 558 | request_id: str, 559 | timeout: Optional[float] = None, 560 | ) -> Union[ 561 | types.LemurStringResponse, 562 | types.LemurQuestionResponse, 563 | ]: 564 | """ 565 | Retrieve a LeMUR response that was previously generated. 566 | 567 | Args: 568 | request_id: The ID of a previous LeMUR request. 569 | timeout: The timeout in seconds to wait for the task. 570 | 571 | Returns: A LeMUR response that was previously generated. 572 | """ 573 | return self._impl.get_response_data(request_id=request_id, timeout=timeout) 574 | 575 | def get_response_data_async( 576 | self, 577 | request_id: str, 578 | timeout: Optional[float] = None, 579 | ) -> concurrent.futures.Future[ 580 | Union[ 581 | types.LemurStringResponse, 582 | types.LemurQuestionResponse, 583 | ] 584 | ]: 585 | """ 586 | Retrieve a LeMUR response that was previously generated. 587 | 588 | Args: 589 | request_id: The ID of a previous LeMUR request. 590 | timeout: The timeout in seconds to wait for the task. 591 | 592 | Returns: A LeMUR response that was previously generated. 
593 | """ 594 | return self._executor.submit( 595 | self._impl.get_response_data, 596 | request_id=request_id, 597 | timeout=timeout, 598 | ) 599 | -------------------------------------------------------------------------------- /assemblyai/streaming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI/assemblyai-python-sdk/ef8dcc0f300ae09b2b528d65f49e770dcefd6243/assemblyai/streaming/__init__.py -------------------------------------------------------------------------------- /assemblyai/streaming/v3/__init__.py: -------------------------------------------------------------------------------- 1 | from .client import StreamingClient 2 | from .models import ( 3 | BeginEvent, 4 | EventMessage, 5 | StreamingClientOptions, 6 | StreamingError, 7 | StreamingEvents, 8 | StreamingParameters, 9 | StreamingSessionParameters, 10 | TerminationEvent, 11 | TurnEvent, 12 | Word, 13 | ) 14 | 15 | __all__ = [ 16 | "BeginEvent", 17 | "EventMessage", 18 | "StreamingClient", 19 | "StreamingClientOptions", 20 | "StreamingError", 21 | "StreamingEvents", 22 | "StreamingParameters", 23 | "StreamingSessionParameters", 24 | "TerminationEvent", 25 | "TurnEvent", 26 | "Word", 27 | ] 28 | -------------------------------------------------------------------------------- /assemblyai/streaming/v3/client.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import queue 4 | import sys 5 | import threading 6 | from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union 7 | from urllib.parse import urlencode 8 | 9 | import httpx 10 | import websockets 11 | from pydantic import BaseModel 12 | from websockets.sync.client import connect as websocket_connect 13 | 14 | from assemblyai import __version__ 15 | 16 | from .models import ( 17 | BeginEvent, 18 | ErrorEvent, 19 | EventMessage, 20 | OperationMessage, 21 | 
    StreamingClientOptions,
    StreamingError,
    StreamingErrorCodes,
    StreamingEvents,
    StreamingParameters,
    StreamingSessionParameters,
    TerminateSession,
    TerminationEvent,
    TurnEvent,
    UpdateConfiguration,
)

logger = logging.getLogger(__name__)


def _dump_model(model: BaseModel):
    """Serialize a pydantic model to a dict, excluding None fields (supports pydantic v1 and v2)."""
    # Pydantic v2 exposes model_dump(); fall back to the v1 dict() API.
    if hasattr(model, "model_dump"):
        return model.model_dump(exclude_none=True)
    return model.dict(exclude_none=True)


def _dump_model_json(model: BaseModel):
    """Serialize a pydantic model to a JSON string, excluding None fields (supports pydantic v1 and v2)."""
    # Pydantic v2 exposes model_dump_json(); fall back to the v1 json() API.
    if hasattr(model, "model_dump_json"):
        return model.model_dump_json(exclude_none=True)
    return model.json(exclude_none=True)


def _user_agent() -> str:
    """Build the User-Agent string sent on the WebSocket handshake."""
    vi = sys.version_info
    python_version = f"{vi.major}.{vi.minor}.{vi.micro}"
    return (
        f"AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})"
    )


class StreamingClient:
    """
    Client for the AssemblyAI v3 streaming (real-time transcription) WebSocket API.

    Audio is queued via `stream()` and sent by a dedicated writer thread, while a
    reader thread parses incoming events and dispatches them to handlers
    registered with `on()`.
    """

    def __init__(self, options: StreamingClientOptions):
        """Set up handler registry, send queue, and reader/writer threads (not yet started)."""
        self._options = options

        # HTTP client used only for REST calls (e.g. temporary token creation).
        self._client = _HTTPClient(api_host=options.api_host, api_key=options.api_key)

        # One handler list per event type, pre-populated so dispatch never KeyErrors.
        self._handlers: Dict[StreamingEvents, List[Callable]] = {}

        for event in StreamingEvents.__members__.values():
            self._handlers[event] = []

        # Outgoing messages (raw audio bytes or pydantic operation messages).
        self._write_queue: queue.Queue[OperationMessage] = queue.Queue()
        self._write_thread = threading.Thread(target=self._write_message)
        self._read_thread = threading.Thread(target=self._read_message)
        self._stop_event = threading.Event()

    def connect(self, params: StreamingParameters) -> None:
        """Open the WebSocket connection and start the reader/writer threads.

        Session parameters are passed as URL query parameters; authentication
        uses a temporary token if present, otherwise the API key.
        """
        params_dict = _dump_model(params)
        params_encoded = urlencode(params_dict)

        uri = f"wss://{self._options.api_host}/v3/ws?{params_encoded}"
        headers = {
            "Authorization": self._options.token
            if self._options.token
            else self._options.api_key,
            "User-Agent": _user_agent(),
            "AssemblyAI-Version": "2025-05-12",
        }

        # NOTE(review): only ConnectionClosed is handled here; other connect
        # failures (e.g. handshake/DNS errors) propagate to the caller -- confirm
        # that is intended.
        try:
            self._websocket = websocket_connect(
                uri,
                additional_headers=headers,
                open_timeout=15,
            )
        except websockets.exceptions.ConnectionClosed as exc:
            self._handle_error(exc)
            return

        self._write_thread.start()
        self._read_thread.start()

        logger.debug("Connected to WebSocket server")

    def disconnect(self, terminate: bool = False) -> None:
        """Shut the session down.

        Args:
            terminate: If True, enqueue a session-termination message first so the
                server acknowledges with a TerminationEvent (which sets the stop event).
        """
        if terminate and not self._stop_event.is_set():
            self._write_queue.put(TerminateSession())

        # Best-effort teardown: joins may fail if the threads were never started,
        # or if called from the reader thread itself (via _handle_error).
        try:
            self._read_thread.join()
            self._write_thread.join()

            if self._websocket:
                self._websocket.close()
        except Exception:
            pass

    def stream(
        self, data: Union[bytes, Generator[bytes, None, None], Iterable[bytes]]
    ) -> None:
        """Queue one chunk (bytes) or an iterable of chunks of audio for sending."""
        if isinstance(data, bytes):
            self._write_queue.put(data)
            return

        for chunk in data:
            self._write_queue.put(chunk)

    def set_params(self, params: StreamingSessionParameters):
        """Queue a configuration update for the live session."""
        message = UpdateConfiguration(**_dump_model(params))
        self._write_queue.put(message)

    def on(self, event: StreamingEvents, handler: Callable) -> None:
        """Register a callable to be invoked for the given event type.

        Invalid event types or non-callables are silently ignored.
        """
        if event in StreamingEvents.__members__.values() and callable(handler):
            self._handlers[event].append(handler)

    def _write_message(self) -> None:
        """Writer-thread loop: drain the queue and send messages until stopped."""
        while not self._stop_event.is_set():
            if not self._websocket:
                raise ValueError("Not connected to the WebSocket server")

            # Poll with a timeout so the stop event is re-checked periodically.
            try:
                data = self._write_queue.get(timeout=1)
            except queue.Empty:
                continue

            try:
                if isinstance(data, bytes):
                    # Raw audio is sent as a binary frame.
                    self._websocket.send(data)
                elif isinstance(data, BaseModel):
                    # Operation messages are sent as JSON text frames.
                    self._websocket.send(_dump_model_json(data))
                else:
                    raise ValueError(f"Attempted to send invalid message: {type(data)}")
            except websockets.exceptions.ConnectionClosed as exc:
                self._handle_error(exc)
                return

    def _read_message(self) -> None:
        """Reader-thread loop: receive, parse, and dispatch events until stopped."""
        while not self._stop_event.is_set():
            if not self._websocket:
                raise ValueError("Not connected to the WebSocket server")

            # Poll with a timeout so the stop event is re-checked periodically.
            try:
                message_data = self._websocket.recv(timeout=1)
            except TimeoutError:
                continue
            except websockets.exceptions.ConnectionClosed as exc:
                self._handle_error(exc)
                return

            try:
                message_json = json.loads(message_data)
            except json.JSONDecodeError as exc:
                logger.warning(f"Failed to decode message: {exc}")
                continue

            message = self._parse_message(message_json)

            if isinstance(message, ErrorEvent):
                self._handle_error(message)
            elif message:
                self._handle_message(message)
            else:
                # NOTE(review): if the payload contains neither "type" nor "error",
                # this lookup raises KeyError and kills the reader thread -- confirm
                # the server never sends such messages.
                logger.warning(f"Unsupported event type: {message_json['type']}")

    def _handle_message(self, message: EventMessage) -> None:
        """Dispatch a parsed event to its registered handlers; Termination stops the loops."""
        if isinstance(message, TerminationEvent):
            self._stop_event.set()

        event_type = StreamingEvents[message.type]

        for handler in self._handlers[event_type]:
            handler(self, message)

    def _parse_message(self, data: Dict[str, Any]) -> Optional[EventMessage]:
        """Convert a decoded JSON payload into a typed event model, or None if unknown."""
        if "type" in data:
            message_type = data.get("type")

            event_type = self._parse_event_type(message_type)

            if event_type == StreamingEvents.Begin:
                return BeginEvent.model_validate(data)
            elif event_type == StreamingEvents.Termination:
                return TerminationEvent.model_validate(data)
            elif event_type == StreamingEvents.Turn:
                return TurnEvent.model_validate(data)
            else:
                return None
        elif "error" in data:
            return ErrorEvent.model_validate(data)

        return None

    @staticmethod
    def _parse_event_type(message_type: Optional[Any]) -> Optional[StreamingEvents]:
        """Map a raw "type" string to a StreamingEvents member, or None if unrecognized."""
        if not isinstance(message_type, str):
            return None

        try:
            return StreamingEvents[message_type]
        except KeyError:
            return None
def _handle_error( 219 | self, 220 | error: Union[ 221 | ErrorEvent, 222 | websockets.exceptions.ConnectionClosed, 223 | ], 224 | ): 225 | parsed_error = self._parse_error(error) 226 | 227 | for handler in self._handlers[StreamingEvents.Error]: 228 | handler(self, parsed_error) 229 | 230 | self.disconnect() 231 | 232 | def _parse_error( 233 | self, 234 | error: Union[ 235 | ErrorEvent, 236 | websockets.exceptions.ConnectionClosed, 237 | ], 238 | ) -> StreamingError: 239 | if isinstance(error, ErrorEvent): 240 | return StreamingError( 241 | message=error.error, 242 | ) 243 | elif isinstance(error, websockets.exceptions.ConnectionClosed): 244 | if ( 245 | error.code >= 4000 246 | and error.code <= 4999 247 | and error.code in StreamingErrorCodes 248 | ): 249 | error_message = StreamingErrorCodes[error.code] 250 | else: 251 | error_message = error.reason 252 | 253 | if error.code != 1000: 254 | return StreamingError(message=error_message, code=error.code) 255 | 256 | return StreamingError( 257 | message=f"Unknown error: {error}", 258 | ) 259 | 260 | def create_temporary_token( 261 | self, 262 | expires_in_seconds: int, 263 | max_session_duration_seconds: int, 264 | ) -> str: 265 | return self._client.create_temporary_token( 266 | expires_in_seconds=expires_in_seconds, 267 | max_session_duration_seconds=max_session_duration_seconds, 268 | ) 269 | 270 | 271 | class _HTTPClient: 272 | def __init__(self, api_host: str, api_key: Optional[str] = None): 273 | vi = sys.version_info 274 | python_version = f"{vi.major}.{vi.minor}.{vi.micro}" 275 | user_agent = f"{httpx._client.USER_AGENT} AssemblyAI/1.0 (sdk=Python/{__version__} runtime_env=Python/{python_version})" 276 | 277 | headers = {"User-Agent": user_agent} 278 | 279 | if api_key: 280 | headers["Authorization"] = api_key 281 | 282 | self._http_client = httpx.Client( 283 | base_url="https://" + api_host, 284 | headers=headers, 285 | ) 286 | 287 | def create_temporary_token( 288 | self, 289 | expires_in_seconds: 
Optional[int] = None, 290 | max_session_duration_seconds: Optional[int] = None, 291 | ) -> str: 292 | params: Dict[str, Any] = {} 293 | 294 | if expires_in_seconds: 295 | params["expires_in_seconds"] = expires_in_seconds 296 | 297 | if max_session_duration_seconds: 298 | params["max_session_duration_seconds"] = expires_in_seconds 299 | 300 | response = self._http_client.get( 301 | "/v3/token", 302 | params=params, 303 | ) 304 | 305 | response.raise_for_status() 306 | return response.json()["token"] 307 | -------------------------------------------------------------------------------- /assemblyai/streaming/v3/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from enum import Enum 3 | from typing import List, Literal, Optional, Union 4 | 5 | from pydantic import BaseModel 6 | 7 | 8 | class Word(BaseModel): 9 | start: int 10 | end: int 11 | confidence: float 12 | text: str 13 | word_is_final: bool 14 | 15 | 16 | class TurnEvent(BaseModel): 17 | type: Literal["Turn"] 18 | turn_order: int 19 | turn_is_formatted: bool 20 | end_of_turn: bool 21 | transcript: str 22 | end_of_turn_confidence: float 23 | words: List[Word] 24 | 25 | 26 | class BeginEvent(BaseModel): 27 | type: Literal["Begin"] = "Begin" 28 | id: str 29 | expires_at: datetime 30 | 31 | 32 | class TerminationEvent(BaseModel): 33 | type: Literal["Termination"] = "Termination" 34 | audio_duration_seconds: Optional[int] = None 35 | session_duration_seconds: Optional[int] = None 36 | 37 | 38 | class ErrorEvent(BaseModel): 39 | error: str 40 | 41 | 42 | EventMessage = Union[ 43 | BeginEvent, 44 | TerminationEvent, 45 | TurnEvent, 46 | ErrorEvent, 47 | ] 48 | 49 | 50 | class TerminateSession(BaseModel): 51 | type: Literal["Terminate"] = "Terminate" 52 | 53 | 54 | class ForceEndpoint(BaseModel): 55 | type: Literal["ForceEndpoint"] = "ForceEndpoint" 56 | 57 | 58 | class StreamingSessionParameters(BaseModel): 59 | 
class StreamingError(Exception):
    """Error raised for streaming-session failures.

    ``code`` carries the websocket close code when the error originated from
    a closed connection; it is ``None`` for server-sent error payloads.
    """

    def __init__(self, message: str, code: Optional[int] = None):
        # Keep the optional close code on the instance, then let the base
        # Exception own the message (so str(err) == message).
        self.code = code
        super().__init__(message)
class StreamingEvents(Enum):
    """Event categories a streaming client dispatches handlers for.

    Begin/Termination/Turn mirror the server message ``type`` values;
    Error is dispatched for error payloads and abnormal connection closes.
    """

    Begin = "Begin"
    Termination = "Termination"
    Turn = "Turn"
    Error = "Error"
Scientific/Engineering :: Artificial Intelligence", 38 | "Topic :: Software Development :: Libraries", 39 | "Topic :: Software Development :: Libraries :: Python Modules", 40 | "Programming Language :: Python :: 3", 41 | "Programming Language :: Python :: 3.8", 42 | "Programming Language :: Python :: 3.9", 43 | "Programming Language :: Python :: 3.10", 44 | "Programming Language :: Python :: 3.11", 45 | ], 46 | long_description=long_description, 47 | long_description_content_type="text/markdown", 48 | url="https://github.com/AssemblyAI/assemblyai-python-sdk", 49 | license="MIT License", 50 | license_files=["LICENSE"], 51 | python_requires=">=3.8", 52 | project_urls={ 53 | "Code": "https://github.com/AssemblyAI/assemblyai-python-sdk", 54 | "Issues": "https://github.com/AssemblyAI/assemblyai-python-sdk/issues", 55 | "Documentation": "https://github.com/AssemblyAI/assemblyai-python-sdk/blob/master/README.md", 56 | "API Documentation": "https://www.assemblyai.com/docs/", 57 | "Website": "https://assemblyai.com/", 58 | }, 59 | ) 60 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI/assemblyai-python-sdk/ef8dcc0f300ae09b2b528d65f49e770dcefd6243/tests/__init__.py -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AssemblyAI/assemblyai-python-sdk/ef8dcc0f300ae09b2b528d65f49e770dcefd6243/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture(scope="session", autouse=True) 5 | def faker_seed(): 6 | """ 7 | Seeds the faker library with a 
constant value. 8 | 9 | See: https://faker.readthedocs.io/en/master/pytest-fixtures.html 10 | """ 11 | return 12345 12 | -------------------------------------------------------------------------------- /tests/unit/factories.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains factories that are used for mocking certain requests/responses 3 | from AssemblyAI's API. 4 | """ 5 | 6 | from enum import Enum 7 | from functools import partial 8 | from typing import Any, Callable, Dict 9 | 10 | import factory 11 | import factory.base 12 | 13 | import assemblyai as aai 14 | from assemblyai import types 15 | 16 | 17 | class TimestampFactory(factory.Factory): 18 | class Meta: 19 | model = aai.Timestamp 20 | 21 | start = factory.Faker("pyint") 22 | end = factory.Faker("pyint") 23 | 24 | 25 | class WordFactory(factory.Factory): 26 | class Meta: 27 | model = aai.Word 28 | 29 | text = factory.Faker("word") 30 | start = factory.Faker("pyint") 31 | end = factory.Faker("pyint") 32 | confidence = factory.Faker("pyfloat", min_value=0.0, max_value=1.0) 33 | speaker = "1" 34 | channel = "1" 35 | 36 | 37 | class UtteranceWordFactory(WordFactory): 38 | class Meta: 39 | model = aai.UtteranceWord 40 | 41 | speaker = "1" 42 | channel = "1" 43 | 44 | 45 | class UtteranceFactory(UtteranceWordFactory): 46 | class Meta: 47 | model = aai.Utterance 48 | 49 | words = factory.List([factory.SubFactory(UtteranceWordFactory)]) 50 | 51 | 52 | class ChapterFactory(factory.Factory): 53 | class Meta: 54 | model = types.Chapter 55 | 56 | summary = factory.Faker("sentence") 57 | headline = factory.Faker("sentence") 58 | gist = factory.Faker("sentence") 59 | start = factory.Faker("pyint") 60 | end = factory.Faker("pyint") 61 | 62 | 63 | class BaseTranscriptFactory(factory.Factory): 64 | class Meta: 65 | model = types.BaseTranscript 66 | 67 | language_code = "en" 68 | audio_url = factory.Faker("url") 69 | punctuate = True 70 | format_text = True 71 | 
multichannel = None 72 | dual_channel = None 73 | webhook_url = None 74 | webhook_auth_header_name = None 75 | audio_start_from = None 76 | audio_end_at = None 77 | word_boost = None 78 | boost_param = None 79 | filter_profanity = False 80 | redact_pii = False 81 | redact_pii_audio = False 82 | redact_pii_policies = None 83 | redact_pii_sub = None 84 | speaker_labels = False 85 | content_safety = False 86 | iab_categories = False 87 | custom_spelling = None 88 | disfluencies = False 89 | sentiment_analysis = False 90 | auto_chapters = False 91 | entity_detection = False 92 | summarization = False 93 | summary_model = None 94 | summary_type = None 95 | auto_highlights = False 96 | language_detection = False 97 | speech_threshold = None 98 | 99 | 100 | class BaseTranscriptResponseFactory(BaseTranscriptFactory): 101 | class Meta: 102 | model = types.TranscriptResponse 103 | 104 | id = factory.Faker("uuid4") 105 | status = aai.TranscriptStatus.completed 106 | error = None 107 | text = factory.Faker("text") 108 | words = factory.List([factory.SubFactory(WordFactory)]) 109 | utterances = factory.List([factory.SubFactory(UtteranceFactory)]) 110 | confidence = factory.Faker("pyfloat", min_value=0.0, max_value=1.0) 111 | audio_duration = factory.Faker("pyint") 112 | webhook_auth = False 113 | webhook_status_code = None 114 | 115 | 116 | class TranscriptDeletedResponseFactory(BaseTranscriptResponseFactory): 117 | language_code = None 118 | audio_url = "http://deleted_by_user" 119 | text = "Deleted by user." 
class TranscriptDeletedResponseFactory(BaseTranscriptResponseFactory):
    """Mocks the API response for a transcript deleted by the user.

    Every content-bearing field is nulled out and the URLs are replaced with
    the sentinel values the API returns after deletion.
    """

    language_code = None
    audio_url = "http://deleted_by_user"
    text = "Deleted by user."
    words = None
    utterances = None
    confidence = None
    punctuate = None
    format_text = None
    dual_channel = None
    multichannel = None
    webhook_url = "http://deleted_by_user"
    webhook_status_code = None
    webhook_auth = False
    # webhook_auth_header_name = None # not yet supported in SDK
    speed_boost = None
    auto_highlights = None
    audio_start_from = None
    audio_end_at = None
    word_boost = None
    boost_param = None
    filter_profanity = None
    redact_pii_audio = None
    # redact_pii_quality = None # not yet supported in SDK
    redact_pii_policies = None
    redact_pii_sub = None
    speaker_labels = None
    error = None
    content_safety = None
    iab_categories = None  # fix: was assigned twice; duplicate removed
    content_safety_labels = None
    language_detection = None
    custom_spelling = None
    # cluster_id = None # not yet supported in SDK
    # custom_topics = None # not yet supported in SDK
    # topics = None # not yet supported in SDK
    speech_threshold = None
    chapters = None
    entities = None
    speakers_expected = None
    summary = None
    sentiment_analysis = None
Meta: 188 | model = types.TranscriptResponse 189 | 190 | id = factory.Faker("uuid4") 191 | status = aai.TranscriptStatus.processing 192 | text = None 193 | words = None 194 | utterances = None 195 | confidence = None 196 | audio_duration = None 197 | 198 | 199 | class TranscriptErrorResponseFactory(BaseTranscriptFactory): 200 | class Meta: 201 | model = types.TranscriptResponse 202 | 203 | status = aai.TranscriptStatus.error 204 | error = "Aw, snap!" 205 | 206 | 207 | class TranscriptRequestFactory(BaseTranscriptFactory): 208 | class Meta: 209 | model = types.TranscriptRequest 210 | 211 | 212 | class PageDetails(factory.Factory): 213 | class Meta: 214 | model = types.PageDetails 215 | 216 | current_url = factory.Faker("url") 217 | limit = 10 218 | next_url = None 219 | prev_url = None 220 | result_count = 2 221 | 222 | 223 | class TranscriptItem(factory.Factory): 224 | class Meta: 225 | model = types.TranscriptItem 226 | 227 | audio_url = factory.Faker("url") 228 | created = factory.Faker("iso8601") 229 | id = factory.Faker("uuid4") 230 | resource_url = factory.Faker("url") 231 | status = aai.TranscriptStatus.completed 232 | completed = None 233 | error = None 234 | 235 | 236 | class ListTranscriptResponse(factory.Factory): 237 | class Meta: 238 | model = types.ListTranscriptResponse 239 | 240 | page_details = factory.SubFactory(PageDetails) 241 | transcripts = factory.List( 242 | [ 243 | factory.SubFactory(TranscriptItem), 244 | factory.SubFactory(TranscriptItem), 245 | ] 246 | ) 247 | 248 | 249 | class LemurUsage(factory.Factory): 250 | class Meta: 251 | model = types.LemurUsage 252 | 253 | input_tokens = factory.Faker("pyint") 254 | output_tokens = factory.Faker("pyint") 255 | 256 | 257 | class LemurQuestionAnswer(factory.Factory): 258 | class Meta: 259 | model = types.LemurQuestionAnswer 260 | 261 | question = factory.Faker("text") 262 | answer = factory.Faker("text") 263 | 264 | 265 | class LemurQuestionResponse(factory.Factory): 266 | class Meta: 267 | model 
= types.LemurQuestionResponse 268 | 269 | request_id = factory.Faker("uuid4") 270 | usage = factory.SubFactory(LemurUsage) 271 | response = factory.List( 272 | [ 273 | factory.SubFactory(LemurQuestionAnswer), 274 | factory.SubFactory(LemurQuestionAnswer), 275 | ] 276 | ) 277 | 278 | 279 | class LemurSummaryResponse(factory.Factory): 280 | class Meta: 281 | model = types.LemurSummaryResponse 282 | 283 | request_id = factory.Faker("uuid4") 284 | usage = factory.SubFactory(LemurUsage) 285 | response = factory.Faker("text") 286 | 287 | 288 | class LemurActionItemsResponse(factory.Factory): 289 | class Meta: 290 | model = types.LemurActionItemsResponse 291 | 292 | request_id = factory.Faker("uuid4") 293 | usage = factory.SubFactory(LemurUsage) 294 | response = factory.Faker("text") 295 | 296 | 297 | class LemurTaskResponse(factory.Factory): 298 | class Meta: 299 | model = types.LemurTaskResponse 300 | 301 | request_id = factory.Faker("uuid4") 302 | usage = factory.SubFactory(LemurUsage) 303 | response = factory.Faker("text") 304 | 305 | 306 | class LemurStringResponse(factory.Factory): 307 | class Meta: 308 | model = types.LemurStringResponse 309 | 310 | request_id = factory.Faker("uuid4") 311 | usage = factory.SubFactory(LemurUsage) 312 | response = factory.Faker("text") 313 | 314 | 315 | class LemurPurgeResponse(factory.Factory): 316 | class Meta: 317 | model = types.LemurPurgeResponse 318 | 319 | request_id = factory.Faker("uuid4") 320 | request_id_to_purge = factory.Faker("uuid4") 321 | deleted = True 322 | 323 | 324 | class WordSearchMatchFactory(factory.Factory): 325 | class Meta: 326 | model = types.WordSearchMatch 327 | 328 | text = factory.Faker("text") 329 | count = factory.Faker("pyint") 330 | timestamps = [(123, 456)] 331 | indexes = [123, 456] 332 | 333 | 334 | class WordSearchMatchResponseFactory(factory.Factory): 335 | class Meta: 336 | model = types.WordSearchMatchResponse 337 | 338 | total_count = factory.Faker("pyint") 339 | 340 | matches = 
factory.List([factory.SubFactory(WordSearchMatchFactory)]) 341 | 342 | 343 | class SentenceFactory(WordFactory): 344 | class Meta: 345 | model = types.Sentence 346 | 347 | words = factory.List([factory.SubFactory(WordFactory)]) 348 | 349 | 350 | class ParagraphFactory(SentenceFactory): 351 | class Meta: 352 | model = types.Paragraph 353 | 354 | 355 | class SentencesResponseFactory(factory.Factory): 356 | class Meta: 357 | model = types.SentencesResponse 358 | 359 | sentences = factory.List([factory.SubFactory(SentenceFactory)]) 360 | confidence = factory.Faker("pyfloat", min_value=0.0, max_value=1.0) 361 | audio_duration = factory.Faker("pyint") 362 | 363 | 364 | class ParagraphsResponseFactory(factory.Factory): 365 | class Meta: 366 | model = types.ParagraphsResponse 367 | 368 | paragraphs = factory.List([factory.SubFactory(ParagraphFactory)]) 369 | confidence = factory.Faker("pyfloat", min_value=0.0, max_value=1.0) 370 | audio_duration = factory.Faker("pyint") 371 | 372 | 373 | def generate_dict_factory(f: factory.Factory) -> Callable[[], Dict[str, Any]]: 374 | """ 375 | Creates a dict factory from the given *Factory class. 376 | 377 | Args: 378 | f: The factory to create a dict factory from. 
379 | """ 380 | 381 | def stub_is_list(stub: factory.base.StubObject) -> bool: 382 | try: 383 | return all(k.isdigit() for k in stub.__dict__.keys()) 384 | except AttributeError: 385 | return False 386 | 387 | def convert_dict_from_stub(stub: factory.base.StubObject) -> Dict[str, Any]: 388 | stub_dict = stub.__dict__ 389 | for key, value in stub_dict.items(): 390 | if isinstance(value, factory.base.StubObject): 391 | stub_dict[key] = ( 392 | [convert_dict_from_stub(v) for v in value.__dict__.values()] 393 | if stub_is_list(value) 394 | else convert_dict_from_stub(value) 395 | ) 396 | elif isinstance(value, Enum): 397 | stub_dict[key] = value.value 398 | return stub_dict 399 | 400 | def dict_factory(f, **kwargs): 401 | stub = f.stub(**kwargs) 402 | stub_dict = convert_dict_from_stub(stub) 403 | return stub_dict 404 | 405 | return partial(dict_factory, f) 406 | -------------------------------------------------------------------------------- /tests/unit/test_auto_chapters.py: -------------------------------------------------------------------------------- 1 | import factory 2 | import pytest 3 | from pytest_httpx import HTTPXMock 4 | 5 | import tests.unit.unit_test_utils as unit_test_utils 6 | import assemblyai as aai 7 | from tests.unit import factories 8 | 9 | aai.settings.api_key = "test" 10 | 11 | 12 | class AutoChaptersResponseFactory(factories.TranscriptCompletedResponseFactory): 13 | chapters = factory.List([factory.SubFactory(factories.ChapterFactory)]) 14 | 15 | 16 | def test_auto_chapters_fails_without_punctuation(httpx_mock: HTTPXMock): 17 | """ 18 | Tests whether the SDK raises an error before making a request 19 | if `auto_chapters` is enabled and `punctuation` is disabled 20 | """ 21 | 22 | with pytest.raises(ValueError) as error: 23 | unit_test_utils.submit_mock_transcription_request( 24 | httpx_mock, 25 | mock_response={}, # response doesn't matter, since it shouldn't occur 26 | config=aai.TranscriptionConfig( 27 | auto_chapters=True, 28 | 
punctuate=False, 29 | ), 30 | ) 31 | # Check that the error message informs the user of the invalid parameter 32 | assert "punctuate" in str(error) 33 | 34 | # Check that the error was raised before any requests were made 35 | assert len(httpx_mock.get_requests()) == 0 36 | 37 | 38 | def test_auto_chapters_disabled_by_default(httpx_mock: HTTPXMock): 39 | """ 40 | Tests that excluding `auto_chapters` from the `TranscriptionConfig` will 41 | result in the default behavior of it being excluded from the request body 42 | """ 43 | request_body, transcript = unit_test_utils.submit_mock_transcription_request( 44 | httpx_mock, 45 | mock_response=factories.generate_dict_factory( 46 | factories.TranscriptCompletedResponseFactory 47 | )(), 48 | config=aai.TranscriptionConfig(), 49 | ) 50 | assert request_body.get("auto_chapters") is None 51 | assert transcript.chapters is None 52 | 53 | 54 | def test_auto_chapters_enabled(httpx_mock: HTTPXMock): 55 | """ 56 | Tests that including `auto_chapters=True` in the `TranscriptionConfig` 57 | will result in `auto_chapters=True` in the request body, and that the 58 | response is properly parsed into a `Transcript` object 59 | """ 60 | mock_response = factories.generate_dict_factory(AutoChaptersResponseFactory)() 61 | request_body, transcript = unit_test_utils.submit_mock_transcription_request( 62 | httpx_mock, 63 | mock_response=mock_response, 64 | config=aai.TranscriptionConfig(auto_chapters=True), 65 | ) 66 | 67 | # Check that request body was properly defined 68 | assert request_body.get("auto_chapters") is True 69 | 70 | # Check that transcript was properly parsed from JSON response 71 | assert transcript.error is None 72 | assert transcript.chapters is not None 73 | assert len(transcript.chapters) > 0 74 | assert len(transcript.chapters) == len(mock_response["chapters"]) 75 | 76 | for response_chapter, transcript_chapter in zip( 77 | mock_response["chapters"], transcript.chapters 78 | ): 79 | assert transcript_chapter.summary == 
response_chapter["summary"] 80 | assert transcript_chapter.headline == response_chapter["headline"] 81 | assert transcript_chapter.gist == response_chapter["gist"] 82 | assert transcript_chapter.start == response_chapter["start"] 83 | assert transcript_chapter.end == response_chapter["end"] 84 | -------------------------------------------------------------------------------- /tests/unit/test_auto_highlights.py: -------------------------------------------------------------------------------- 1 | import factory 2 | from pytest_httpx import HTTPXMock 3 | 4 | import tests.unit.unit_test_utils as unit_test_utils 5 | import assemblyai as aai 6 | from tests.unit import factories 7 | 8 | aai.settings.api_key = "test" 9 | 10 | 11 | class AutohighlightResultFactory(factory.Factory): 12 | class Meta: 13 | model = aai.types.AutohighlightResult 14 | 15 | count = factory.Faker("pyint") 16 | rank = factory.Faker("pyfloat") 17 | text = factory.Faker("sentence") 18 | timestamps = factory.List([factory.SubFactory(factories.TimestampFactory)]) 19 | 20 | 21 | class AutohighlightResponseFactory(factory.Factory): 22 | class Meta: 23 | model = aai.types.AutohighlightResponse 24 | 25 | status = aai.types.StatusResult.success 26 | results = factory.List([factory.SubFactory(AutohighlightResultFactory)]) 27 | 28 | 29 | class AutohighlightTranscriptResponseFactory( 30 | factories.TranscriptCompletedResponseFactory 31 | ): 32 | auto_highlights_result = factory.SubFactory(AutohighlightResponseFactory) 33 | 34 | 35 | def test_auto_highlights_disabled_by_default(httpx_mock: HTTPXMock): 36 | """ 37 | Tests that excluding `auto_highlights` from the `TranscriptionConfig` will 38 | result in the default behavior of it being excluded from the request body 39 | """ 40 | request_body, transcript = unit_test_utils.submit_mock_transcription_request( 41 | httpx_mock, 42 | mock_response=factories.generate_dict_factory( 43 | factories.TranscriptCompletedResponseFactory 44 | )(), 45 | 
config=aai.TranscriptionConfig(), 46 | ) 47 | assert request_body.get("auto_highlights") is None 48 | assert transcript.auto_highlights is None 49 | 50 | 51 | def test_auto_highlights_enabled(httpx_mock: HTTPXMock): 52 | """ 53 | Tests that including `auto_highlights=True` in the `TranscriptionConfig` 54 | will result in `auto_highlights=True` in the request body, and that the 55 | response is properly parsed into a `Transcript` object 56 | """ 57 | mock_response = factories.generate_dict_factory( 58 | AutohighlightTranscriptResponseFactory 59 | )() 60 | request_body, transcript = unit_test_utils.submit_mock_transcription_request( 61 | httpx_mock, 62 | mock_response=mock_response, 63 | config=aai.TranscriptionConfig(auto_highlights=True), 64 | ) 65 | 66 | # Check that request body was properly defined 67 | assert request_body.get("auto_highlights") is True 68 | 69 | # Check that transcript was properly parsed from JSON response 70 | assert transcript.error is None 71 | assert transcript.auto_highlights is not None 72 | assert ( 73 | transcript.auto_highlights.status 74 | == mock_response["auto_highlights_result"]["status"] 75 | ) 76 | 77 | assert transcript.auto_highlights.results is not None 78 | assert len(transcript.auto_highlights.results) > 0 79 | assert len(transcript.auto_highlights.results) == len( 80 | mock_response["auto_highlights_result"]["results"] 81 | ) 82 | 83 | for response_result, transcript_result in zip( 84 | mock_response["auto_highlights_result"]["results"], 85 | transcript.auto_highlights.results, 86 | ): 87 | assert transcript_result.count == response_result["count"] 88 | assert transcript_result.rank == response_result["rank"] 89 | assert transcript_result.text == response_result["text"] 90 | 91 | for response_timestamp, transcript_timestamp in zip( 92 | response_result["timestamps"], transcript_result.timestamps 93 | ): 94 | assert transcript_timestamp.start == response_timestamp["start"] 95 | assert transcript_timestamp.end == 
def test_reset_client_on_settings_change():
    """
    A freshly constructed Transcriber must pick up the current value of the
    global ``aai.settings`` rather than a stale cached client.
    """
    aai.settings.api_key = "before"
    assert aai.Transcriber()._client.settings.api_key == "before"

    # Restore the shared value ("test") that every other test relies on.
    aai.settings.api_key = "test"
    assert aai.Transcriber()._client.settings.api_key == "test"
import random

import factory
import pytest
from pytest_httpx import HTTPXMock

import tests.unit.unit_test_utils as unit_test_utils
import assemblyai as aai
from tests.unit import factories

aai.settings.api_key = "test"


class ContentSafetySeverityScoreFactory(factory.Factory):
    """Fake `ContentSafetySeverityScore` with random low/medium/high floats."""

    class Meta:
        model = aai.types.ContentSafetySeverityScore

    low = factory.Faker("pyfloat")
    medium = factory.Faker("pyfloat")
    high = factory.Faker("pyfloat")


class ContentSafetyLabelResultFactory(factory.Factory):
    """Fake `ContentSafetyLabelResult` with a random label/confidence/severity."""

    class Meta:
        model = aai.types.ContentSafetyLabelResult

    label = factory.Faker("enum", enum_cls=aai.types.ContentSafetyLabel)
    confidence = factory.Faker("pyfloat")
    severity = factory.Faker("pyfloat")


class ContentSafetyResultFactory(factory.Factory):
    """Fake `ContentSafetyResult`: one labeled sentence with a timestamp."""

    class Meta:
        model = aai.types.ContentSafetyResult

    text = factory.Faker("sentence")
    labels = factory.List([factory.SubFactory(ContentSafetyLabelResultFactory)])
    timestamp = factory.SubFactory(factories.TimestampFactory)


class ContentSafetyResponseFactory(factory.Factory):
    """Fake `ContentSafetyResponse` with one result and one-entry summary dicts.

    NOTE: the summary keys are picked via `random.choice` at class-definition
    time, so they stay fixed for the lifetime of the test process.
    """

    class Meta:
        model = aai.types.ContentSafetyResponse

    status = aai.types.StatusResult.success
    results = factory.List([factory.SubFactory(ContentSafetyResultFactory)])
    summary = factory.Dict(
        {
            random.choice(list(aai.types.ContentSafetyLabel)).value: factory.Faker(
                "pyfloat"
            )
        }
    )
    severity_score_summary = factory.Dict(
        {
            random.choice(list(aai.types.ContentSafetyLabel)).value: factory.SubFactory(
                ContentSafetySeverityScoreFactory
            )
        }
    )


class ContentSafetyTranscriptResponseFactory(
    factories.TranscriptCompletedResponseFactory
):
    """Completed-transcript response that also carries content-safety labels."""

    content_safety_labels = factory.SubFactory(ContentSafetyResponseFactory)


def test_content_safety_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that excluding `content_safety` from the `TranscriptionConfig` will
    result in the default behavior of it being excluded from the request body
    """
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=factories.generate_dict_factory(
            factories.TranscriptCompletedResponseFactory
        )(),
        config=aai.TranscriptionConfig(),
    )
    assert request_body.get("content_safety") is None
    assert transcript.content_safety is None


def test_content_safety_enabled(httpx_mock: HTTPXMock):
    """
    Tests that including `content_safety=True` in the `TranscriptionConfig`
    will result in `content_safety=True` in the request body, and that the
    response is properly parsed into a `Transcript` object
    """
    mock_response = factories.generate_dict_factory(
        ContentSafetyTranscriptResponseFactory
    )()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(content_safety=True),
    )

    # Check that request body was properly defined
    assert request_body.get("content_safety") is True

    # Check that transcript was properly parsed from JSON response
    assert transcript.error is None
    assert transcript.content_safety is not None

    # Verify status
    assert transcript.content_safety.status == aai.types.StatusResult.success

    # Verify results: each parsed result must mirror the mocked payload entry.
    assert transcript.content_safety.results is not None
    assert len(transcript.content_safety.results) > 0
    assert len(transcript.content_safety.results) == len(
        mock_response["content_safety_labels"]["results"]
    )
    for response_result, transcript_result in zip(
        mock_response["content_safety_labels"]["results"],
        transcript.content_safety.results,
    ):
        assert transcript_result.text == response_result["text"]

        assert (
            transcript_result.timestamp.start == response_result["timestamp"]["start"]
        )
        assert transcript_result.timestamp.end == response_result["timestamp"]["end"]

        assert len(transcript_result.labels) > 0
        assert len(transcript_result.labels) == len(response_result["labels"])
        for response_label, transcript_label in zip(
            response_result["labels"], transcript_result.labels
        ):
            assert transcript_label.label == response_label["label"]
            assert transcript_label.confidence == response_label["confidence"]
            assert transcript_label.severity == response_label["severity"]

    # Verify summary: keys and values pass through unchanged.
    assert transcript.content_safety.summary is not None
    assert len(transcript.content_safety.summary) > 0
    assert len(transcript.content_safety.summary) == len(
        mock_response["content_safety_labels"]["summary"]
    )
    for response_summary_items, transcript_summary_items in zip(
        mock_response["content_safety_labels"]["summary"].items(),
        transcript.content_safety.summary.items(),
    ):
        response_summary_key, response_summary_value = response_summary_items
        transcript_summary_key, transcript_summary_value = transcript_summary_items

        assert transcript_summary_key == response_summary_key
        assert transcript_summary_value == response_summary_value

    # Verify severity score summary: each entry parsed into the score model.
    assert transcript.content_safety.severity_score_summary is not None
    assert len(transcript.content_safety.severity_score_summary) > 0
    assert len(transcript.content_safety.severity_score_summary) == len(
        mock_response["content_safety_labels"]["severity_score_summary"]
    )
    for (
        response_severity_score_summary_items,
        transcript_severity_score_summary_items,
    ) in zip(
        mock_response["content_safety_labels"]["severity_score_summary"].items(),
        transcript.content_safety.severity_score_summary.items(),
    ):
        (
            response_severity_score_summary_key,
            response_severity_score_summary_values,
        ) = response_severity_score_summary_items
        (
            transcript_severity_score_summary_key,
            transcript_severity_score_summary_values,
        ) = transcript_severity_score_summary_items

        assert (
            transcript_severity_score_summary_key == response_severity_score_summary_key
        )
        assert (
            transcript_severity_score_summary_values.high
            == response_severity_score_summary_values["high"]
        )
        assert (
            transcript_severity_score_summary_values.medium
            == response_severity_score_summary_values["medium"]
        )
        assert (
            transcript_severity_score_summary_values.low
            == response_severity_score_summary_values["low"]
        )


def test_content_safety_with_confidence_threshold(httpx_mock: HTTPXMock):
    """
    Tests that `content_safety_confidence` can be set in the `TranscriptionConfig`
    and will be included in the request body
    """
    confidence = 40
    request, _ = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=factories.generate_dict_factory(
            factories.TranscriptCompletedResponseFactory
        )(),
        config=aai.TranscriptionConfig(
            content_safety=True, content_safety_confidence=confidence
        ),
    )

    assert request.get("content_safety") is True
    assert request.get("content_safety_confidence") == confidence


@pytest.mark.parametrize("confidence", [1, 101])
def test_content_safety_with_invalid_confidence_threshold(
    httpx_mock: HTTPXMock, confidence: int
):
    """
    Tests that a `content_safety_confidence` outside the acceptable range will cause
    an exception to be raised before the request is sent
    """
    with pytest.raises(ValueError) as error:
        unit_test_utils.submit_mock_transcription_request(
            httpx_mock,
            mock_response={},  # We don't expect to produce a response
            config=aai.TranscriptionConfig(
                content_safety=True, content_safety_confidence=confidence
            ),
        )

    # Assert against the raised exception itself (`error.value`), not the
    # `ExceptionInfo` wrapper, so the check targets the actual error message.
    assert "content_safety_confidence" in str(error.value)

    # Check that the error was raised before any requests were made
    assert len(httpx_mock.get_requests()) == 0
def test_custom_spelling_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that not calling `set_custom_spelling()` on the `TranscriptionConfig`
    will result in the default behavior of it being excluded from the request body.
    """
    mock_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(),
    )

    # Neither the request nor the response should mention custom spelling.
    assert request_body.get("custom_spelling") is None
    assert transcript.json_response.get("custom_spelling") is None


def test_custom_spelling_set_config_succeeds():
    """
    Tests that calling `set_custom_spelling()` on the `TranscriptionConfig`
    will set the values correctly, and that the config values can be accessed again
    through the custom_spelling property.
    """
    config = aai.TranscriptionConfig()

    # A single replacement string gets normalized into a one-element list.
    config.set_custom_spelling({"AssemblyAI": "assemblyAI"})
    assert config.custom_spelling == {"AssemblyAI": ["assemblyAI"]}

    # Multiple pairs (mixing a string and a list) are all normalized to lists.
    config.set_custom_spelling(
        {"AssemblyAI": "assemblyAI", "Kubernetes": ["k8s", "kubernetes"]}, override=True
    )
    expected = {
        "AssemblyAI": ["assemblyAI"],
        "Kubernetes": ["k8s", "kubernetes"],
    }
    assert config.custom_spelling == expected


def test_custom_spelling_enabled(httpx_mock: HTTPXMock):
    """
    Tests that calling `set_custom_spelling()` on the `TranscriptionConfig`
    will result in correct `custom_spelling` in the request body, and that the
    response is properly parsed into the `custom_spelling` field.
    """
    mock_response = factories.generate_dict_factory(CustomSpellingResponseFactory)()

    # Mirror the mocked response values in the request config.
    spelling_entry = mock_response["custom_spelling"][0]
    config = aai.TranscriptionConfig().set_custom_spelling(
        {spelling_entry["to"]: spelling_entry["from"]}
    )

    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=config,
    )

    # The request must carry a non-empty custom_spelling list with both keys.
    sent_spelling = request_body["custom_spelling"]
    assert sent_spelling is not None and len(sent_spelling) > 0
    assert "from" in sent_spelling[0]
    assert "to" in sent_spelling[0]

    # The transcript should echo the same mapping back, error-free.
    assert transcript.error is None
    assert transcript.json_response["custom_spelling"] == sent_spelling
"set_content_safety", # content safety 23 | } 24 | 25 | # get all members 26 | non_raw_members = inspect.getmembers(aai.TranscriptionConfig) 27 | 28 | # just retrieve the names 29 | raw_member_names = set(aai.RawTranscriptionConfig.__fields__.keys()) 30 | raw_member_names.discard("model_config") 31 | non_raw_member_names = set( 32 | name for name, _ in non_raw_members if not name.startswith("_") 33 | ) 34 | 35 | # get the differences 36 | diff_lhs = non_raw_member_names.difference(raw_member_names) 37 | diff_rhs = raw_member_names.difference(non_raw_member_names) 38 | differences = diff_lhs.union(diff_rhs) 39 | 40 | # check for the special setters 41 | differences = differences - special_setters 42 | 43 | # no differences: no drift. 44 | assert not differences 45 | -------------------------------------------------------------------------------- /tests/unit/test_entity_detection.py: -------------------------------------------------------------------------------- 1 | import factory 2 | from pytest_httpx import HTTPXMock 3 | 4 | import tests.unit.unit_test_utils as unit_test_utils 5 | import assemblyai as aai 6 | from tests.unit import factories 7 | 8 | aai.settings.api_key = "test" 9 | 10 | 11 | class EntityFactory(factory.Factory): 12 | class Meta: 13 | model = aai.types.Entity 14 | 15 | entity_type = factory.Faker("enum", enum_cls=aai.types.EntityType) 16 | text = factory.Faker("sentence") 17 | start = factory.Faker("pyint") 18 | end = factory.Faker("pyint") 19 | 20 | 21 | class EntityDetectionResponseFactory(factories.TranscriptCompletedResponseFactory): 22 | entities = factory.List([factory.SubFactory(EntityFactory)]) 23 | 24 | 25 | def test_entity_detection_disabled_by_default(httpx_mock: HTTPXMock): 26 | """ 27 | Tests that excluding `entity_detection` from the `TranscriptionConfig` will 28 | result in the default behavior of it being excluded from the request body 29 | """ 30 | request_body, transcript = unit_test_utils.submit_mock_transcription_request( 31 | 
def test_entity_detection_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that excluding `entity_detection` from the `TranscriptionConfig` will
    result in the default behavior of it being excluded from the request body
    """
    mock_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(),
    )

    assert request_body.get("entity_detection") is None
    assert transcript.entities is None


def test_entity_detection_enabled(httpx_mock: HTTPXMock):
    """
    Tests that including `entity_detection=True` in the `TranscriptionConfig`
    will result in `entity_detection=True` in the request body, and that the
    response is properly parsed into a `Transcript` object
    """
    mock_response = factories.generate_dict_factory(EntityDetectionResponseFactory)()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(entity_detection=True),
    )

    # The flag must be forwarded in the request body.
    assert request_body.get("entity_detection") is True

    # The parsed transcript mirrors the mocked entities one-to-one.
    assert transcript.error is None
    assert transcript.entities is not None
    assert len(transcript.entities) > 0
    assert len(transcript.entities) == len(mock_response["entities"])

    # Every entity should carry non-blank text.
    assert all(len(entity.text.strip()) > 0 for entity in transcript.entities)
def _collect_stream_chunks(data: bytes, sample_rate: int):
    """Run `aai.extras.stream_file` over in-memory audio with file I/O and sleeps patched out."""
    with patch("builtins.open", mock_open(read_data=data)):
        with patch("time.sleep", return_value=None):
            return list(aai.extras.stream_file("fake_path", sample_rate))


def test_stream_file_empty_file():
    """
    Test streaming of an empty file.
    """
    chunks = _collect_stream_chunks(b"", 44100)

    # An empty file yields no chunks at all.
    assert len(chunks) == 0


def test_stream_file_small_file():
    """
    Tests streaming a file smaller than 300ms.
    """
    sample_rate = 44100
    audio_bytes = int(0.2 * sample_rate) * 2  # 200ms of 16-bit silence

    chunks = _collect_stream_chunks(b"\x00" * audio_bytes, sample_rate)

    # One single chunk, exactly the input audio, with no padding at the end.
    assert chunks == [b"\x00" * audio_bytes]


def test_stream_file_large_file():
    """
    Test streaming a file larger than 300ms.
    """
    sample_rate = 44100
    data = b"\x00" * (int(0.6 * sample_rate) * 2)  # 600ms of 16-bit silence

    chunks = _collect_stream_chunks(data, sample_rate)

    # 600ms of audio splits into two chunks.
    assert len(chunks) == 2


def test_stream_file_exact_file():
    """
    Test streaming a file exactly 300ms long.
    """
    sample_rate = 44100
    data = b"\x00" * (int(0.3 * sample_rate) * 2)  # exactly one chunk's worth

    chunks = _collect_stream_chunks(data, sample_rate)

    # Expecting exactly one chunk.
    assert len(chunks) == 1
def test_iab_categories_enabled(httpx_mock: HTTPXMock):
    """
    Tests that including `iab_categories=True` in the `TranscriptionConfig` will
    result in `iab_categories` being included in the request body, and that
    the response will be properly parsed into the `Transcript` object
    """
    mock_response = factories.generate_dict_factory(IABCategoriesResponseFactory)()

    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(iab_categories=True),
    )

    # The flag must be forwarded in the request body, and no error returned.
    assert request_body.get("iab_categories") is True
    assert transcript.error is None

    expected = mock_response.get("iab_categories_result", {})

    assert transcript.iab_categories is not None
    assert transcript.iab_categories.status == expected.get("status")

    # Validate results: parsed entries must mirror the mocked payload.
    expected_results = expected.get("results", [])
    actual_results = transcript.iab_categories.results

    assert actual_results is not None
    assert len(actual_results) == len(expected_results)
    assert len(actual_results) > 0

    for expected_result, actual_result in zip(expected_results, actual_results):
        assert actual_result.text == expected_result.get("text")
        assert len(actual_result.text) > 0

        expected_labels = expected_result.get("labels", [])
        assert len(actual_result.labels) > 0
        assert len(actual_result.labels) == len(expected_labels)
        for expected_label, actual_label in zip(expected_labels, actual_result.labels):
            assert actual_label.relevance == expected_label.get("relevance")
            assert actual_label.label == expected_label.get("label")

    # Validate summary: passed through verbatim.
    expected_summary = expected.get("summary", {})
    actual_summary = transcript.iab_categories.summary

    assert actual_summary is not None
    assert len(actual_summary) > 0
    assert actual_summary == expected_summary
class ImportFailureMocker:
    """Context manager that forces `import <module>` to fail while active.

    On entry it drops any cached copy of the module from `sys.modules` and
    installs itself as the highest-priority meta-path finder; its
    `find_spec` raises `ImportError` for the targeted module, simulating an
    environment where that dependency is not installed.
    """

    def __init__(self, module: str):
        # Fully-qualified name of the module whose import should fail.
        self.module = module

    def find_spec(self, fullname, path, target=None):
        # Only intercept the module under test; implicitly returning None
        # lets the regular finders handle every other import.
        if fullname == self.module:
            raise ImportError

    def __enter__(self):
        # Remove module if already imported, so the import machinery is
        # consulted again (and hits this finder).
        if self.module in sys.modules:
            del sys.modules[self.module]

        # Add self as first importer
        sys.meta_path.insert(0, self)
        return self

    def __exit__(self, type, value, traceback):
        # Remove *this* finder specifically. Popping index 0 (the previous
        # implementation) would remove the wrong finder if anything else was
        # inserted at the front of sys.meta_path in the meantime.
        sys.meta_path.remove(self)
def test_import_sdk_and_use_extra_functions_without_extras_installed(
    httpx_mock: HTTPXMock,
):
    """
    Verifies that the `extras` helpers that do not need pyaudio keep working
    when pyaudio cannot be imported.
    """
    with ImportFailureMocker("pyaudio"):
        __reload_assesmblyai_module()

        local_file = os.urandom(10)
        expected_upload_url = "https://example.org/audio.wav"

        # `stream_file` only reads from disk, so patch the file read.
        with patch("builtins.open", mock_open(read_data=local_file)):
            _ = aai.extras.stream_file(filepath="audio.wav", sample_rate=44_100)

        # `file_from_stream` uploads the raw bytes; mock the upload endpoint.
        httpx_mock.add_response(
            url=f"{aai.settings.base_url}{ENDPOINT_UPLOAD}",
            status_code=httpx.codes.OK,
            method="POST",
            json={"upload_url": expected_upload_url},
            match_content=local_file,
        )

        upload_url = aai.extras.file_from_stream(local_file)
        assert upload_url == expected_upload_url


def test_import_sdk_and_use_MicrophoneStream_without_extras_installed():
    """
    Verifies that constructing `MicrophoneStream` without pyaudio installed
    raises the dedicated "extras not installed" error.
    """
    with ImportFailureMocker("pyaudio"):
        __reload_assesmblyai_module()

        with pytest.raises(aai.extras.AssemblyAIExtrasNotInstalledError):
            aai.extras.MicrophoneStream()


def test_import_sdk_and_use_MicrophoneStream_with_extras_installed(
    mocker: pytest_mock.MockerFixture,
):
    """
    Verifies that `MicrophoneStream` can be constructed when pyaudio is
    importable (the audio device itself is mocked out).
    """
    import pyaudio

    __reload_assesmblyai_module()

    # Avoid touching real audio hardware.
    mocker.patch.object(pyaudio.PyAudio, "open", return_value=None)
    aai.extras.MicrophoneStream()

    # Test succeeds if no failures


class MultichannelResponseFactory(factories.TranscriptCompletedResponseFactory):
    """Completed-transcript response for a two-channel (multichannel) audio file."""

    multichannel = True
    audio_channels = 2
def test_multichannel_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that not setting `multichannel=True` in the `TranscriptionConfig`
    will result in the default behavior of it being excluded from the request body.
    """
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=factories.generate_dict_factory(
            factories.TranscriptCompletedResponseFactory
        )(),
        config=aai.TranscriptionConfig(),
    )
    assert request_body.get("multichannel") is None
    assert transcript.json_response.get("multichannel") is None


def test_multichannel_enabled(httpx_mock: HTTPXMock):
    """
    Tests that setting `multichannel=True` in the `TranscriptionConfig`
    will result in correct `multichannel` in the request body, and that the
    response is properly parsed into the `multichannel` and `utterances` field.

    (The previous docstring said "not setting" — a copy/paste of the
    disabled-by-default test; this test *does* enable the option.)
    """
    mock_response = factories.generate_dict_factory(MultichannelResponseFactory)()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(multichannel=True),
    )

    # Check that request body was properly defined
    multichannel_response = request_body.get("multichannel")
    assert multichannel_response is not None

    # Check that transcript has no errors and multichannel response is correctly returned
    assert transcript.error is None
    assert transcript.json_response["multichannel"] == multichannel_response
    assert transcript.json_response["audio_channels"] > 1

    # Check that utterances are correctly parsed, each with a positive channel id
    assert transcript.utterances is not None
    assert len(transcript.utterances) > 0
    for utterance in transcript.utterances:
        assert int(utterance.channel) > 0
def _disable_rw_threads(mocker: MockFixture):
    """
    Disable the read/write threads for the websocket
    """
    mocker.patch("threading.Thread.start", return_value=None)


@pytest.mark.parametrize(
    "encoding,token,expected_header",
    [
        (None, None, {"Authorization": "test"}),
        (aai.AudioEncoding.pcm_s16le, None, {"Authorization": "test"}),
        (aai.AudioEncoding.pcm_mulaw, None, {"Authorization": "test"}),
        (None, "12345678", None),
        (aai.AudioEncoding.pcm_s16le, "12345678", None),
    ],
)
def test_realtime_connect_has_parameters(
    encoding, token, expected_header, mocker: MockFixture
):
    """
    Test that the connect method has the correct parameters set
    """
    aai.settings.base_url = "https://api.assemblyai.com"

    # Captured by the mocked websocket_connect below.
    actual_url = None
    actual_additional_headers = None
    actual_open_timeout = None

    def mocked_websocket_connect(
        url: str, additional_headers: dict, open_timeout: float
    ):
        nonlocal actual_url, actual_additional_headers, actual_open_timeout
        actual_url = url
        actual_additional_headers = additional_headers
        actual_open_timeout = open_timeout

    mocker.patch(
        "assemblyai.transcriber.websocket_connect",
        new=mocked_websocket_connect,
    )
    _disable_rw_threads(mocker)

    transcriber = aai.RealtimeTranscriber(
        # Fix: accept the transcript argument like every other test's
        # callback; a zero-argument lambda would raise TypeError if invoked.
        on_data=lambda _: None,
        on_error=lambda error: print(error),
        sample_rate=44_100,
        word_boost=["AssemblyAI"],
        encoding=encoding,
        token=token,
    )

    transcriber.connect(timeout=15.0)

    # Rebuild the expected query string from the configured options.
    params = dict(sample_rate=44100, word_boost=json.dumps(["AssemblyAI"]))
    if encoding:
        params["encoding"] = encoding.value
    if token:
        params["token"] = token

    assert actual_url == f"wss://api.assemblyai.com/v2/realtime/ws?{urlencode(params)}"
    assert actual_additional_headers == expected_header
    assert actual_open_timeout == 15.0


def test_realtime_connect_succeeds(mocker: MockFixture):
    """
    Tests that the `RealtimeTranscriber` successfully connects to the `real-time` service.
    """
    on_error_called = False

    def on_error(error: aai.RealtimeError):
        nonlocal on_error_called
        on_error_called = True

    transcriber = aai.RealtimeTranscriber(
        on_data=lambda _: None,
        on_error=on_error,
        sample_rate=44_100,
    )

    mocker.patch(
        "assemblyai.transcriber.websocket_connect",
        return_value=MagicMock(),
    )

    # mock the read/write threads
    _disable_rw_threads(mocker)

    # should pass
    transcriber.connect()

    # no errors should be called
    assert not on_error_called
def test_realtime_token_connect_succeeds(mocker: MockFixture):
    """
    Tests that the `RealtimeTranscriber` successfully connects
    to the `real-time` service when a token is used.
    """
    errors = []

    # reset the API key so authentication must go through the token
    mocker.patch("assemblyai.settings.api_key", new=None)

    transcriber = aai.RealtimeTranscriber(
        on_data=lambda _: None,
        on_error=errors.append,
        sample_rate=44_100,
        token="12345",
    )

    mocker.patch(
        "assemblyai.transcriber.websocket_connect",
        return_value=MagicMock(),
    )

    # mock the read/write threads
    _disable_rw_threads(mocker)

    # should pass
    transcriber.connect()

    # connecting with a token must not surface any errors
    assert not errors


def test_realtime_connect_fails(mocker: MockFixture):
    """
    Tests that the `RealtimeTranscriber` fails to connect to the `real-time` service.
    """
    errors = []

    def on_error(error: aai.RealtimeError):
        # The raised exception is wrapped into a RealtimeError and forwarded.
        assert isinstance(error, aai.RealtimeError)
        assert "connection failed" in str(error)
        errors.append(error)

    transcriber = aai.RealtimeTranscriber(
        on_data=lambda _: None,
        on_error=on_error,
        sample_rate=44_100,
    )
    mocker.patch(
        "assemblyai.transcriber.websocket_connect",
        side_effect=Exception("connection failed"),
    )

    transcriber.connect()

    # the failure must have been reported through the error callback
    assert errors


def test_realtime__read_succeeds(mocker: MockFixture, faker: Faker):
    """
    Tests the `_read` method of the `_RealtimeTranscriberImpl` class.
    """
    expected_transcripts = [
        aai.RealtimeFinalTranscript(
            created=faker.date_time(),
            text=faker.sentence(),
            audio_start=0,
            audio_end=1,
            confidence=1.0,
            words=[],
            punctuated=True,
            text_formatted=True,
        )
    ]

    received_transcripts = []

    transcriber = aai.RealtimeTranscriber(
        on_data=received_transcripts.append,
        on_error=lambda _: None,
        sample_rate=44_100,
    )

    # Feed serialized messages through a mocked websocket; once the side
    # effects are exhausted the mock raises StopIteration, ending `_read`.
    transcriber._impl._websocket = MagicMock()
    transcriber._impl._websocket.recv.side_effect = [
        json.dumps(message.dict(), default=str) for message in expected_transcripts
    ]

    with pytest.raises(StopIteration):
        transcriber._impl._read()

    assert received_transcripts == expected_transcripts


def test_realtime__read_fails(mocker: MockFixture):
    """
    Tests the `_read` method of the `_RealtimeTranscriberImpl` class.
    """
    errors = []

    transcriber = aai.RealtimeTranscriber(
        on_data=lambda _: None,
        on_error=errors.append,
        sample_rate=44_100,
    )

    # A closed websocket must be translated into an error callback.
    transcriber._impl._websocket = MagicMock()
    transcriber._impl._websocket.recv.side_effect = (
        websockets.exceptions.ConnectionClosedOK(rcvd=None, sent=None)
    )

    transcriber._impl._read()

    assert errors
def test_realtime__write_succeeds(mocker: MockFixture):
    """
    Tests the `_write` method of the `_RealtimeTranscriberImpl` class.
    """
    audio_chunks = [
        bytes(range(1, 6)),
        bytes(range(6, 11)),
    ]

    actual_sent = []

    transcriber = aai.RealtimeTranscriber(
        on_data=lambda _: None,
        on_error=lambda _: None,
        sample_rate=44_100,
    )

    transcriber._impl._websocket = MagicMock()
    transcriber._impl._websocket.send = actual_sent.append
    # Let the write loop run exactly two iterations before the stop event fires.
    transcriber._impl._stop_event.is_set = MagicMock(side_effect=[False, False, True])

    for chunk in audio_chunks:
        transcriber.stream(chunk)

    transcriber._impl._write()

    # assert that the correct data was sent (= the exact input bytes, in order)
    assert actual_sent == audio_chunks


def test_realtime__handle_message_session_begins(mocker: MockFixture):
    """
    Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
    with the `SessionBegins` message.
    """
    test_message = {
        "message_type": "SessionBegins",
        "session_id": str(uuid.uuid4()),
        "expires_at": datetime.datetime.now().isoformat(),
    }

    opened_sessions = []

    def on_open(session_opened: aai.RealtimeSessionOpened):
        # The raw dict must be parsed into a typed RealtimeSessionOpened.
        assert isinstance(session_opened, aai.RealtimeSessionOpened)
        assert session_opened.session_id == uuid.UUID(test_message["session_id"])
        assert session_opened.expires_at.isoformat() == test_message["expires_at"]
        opened_sessions.append(session_opened)

    transcriber = aai.RealtimeTranscriber(
        on_open=on_open,
        on_data=lambda _: None,
        on_error=lambda _: None,
        sample_rate=44_100,
    )

    transcriber._impl._handle_message(test_message)

    assert opened_sessions


def test_realtime__handle_message_partial_transcript(mocker: MockFixture):
    """
    Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class
    with the `PartialTranscript` message.
    """
    test_message = {
        "message_type": "PartialTranscript",
        "text": "hello world",
        "audio_start": 0,
        "audio_end": 1500,
        "confidence": 0.99,
        "created": datetime.datetime.now().isoformat(),
        "words": [
            {
                "text": "hello",
                "start": 0,
                "end": 500,
                "confidence": 0.99,
            },
            {
                "text": "world",
                "start": 500,
                "end": 1500,
                "confidence": 0.99,
            },
        ],
    }

    received = []

    transcriber = aai.RealtimeTranscriber(
        on_data=received.append,
        on_error=lambda _: None,
        sample_rate=44_100,
    )

    transcriber._impl._handle_message(test_message)

    # The callback must have fired exactly once with a parsed partial transcript.
    assert len(received) == 1
    data = received[0]
    assert isinstance(data, aai.RealtimePartialTranscript)
    assert data.text == test_message["text"]
    assert data.audio_start == test_message["audio_start"]
    assert data.audio_end == test_message["audio_end"]
    assert data.confidence == test_message["confidence"]
    assert data.created.isoformat() == test_message["created"]
    # Every word entry should be parsed into a RealtimeWord, in order.
    assert data.words == [aai.RealtimeWord(**word) for word in test_message["words"]]
386 | """ 387 | 388 | test_message = { 389 | "message_type": "FinalTranscript", 390 | "text": "Hello, world!", 391 | "audio_start": 0, 392 | "audio_end": 1500, 393 | "confidence": 0.99, 394 | "created": datetime.datetime.now().isoformat(), 395 | "punctuated": True, 396 | "text_formatted": True, 397 | "words": [ 398 | { 399 | "text": "Hello,", 400 | "start": 0, 401 | "end": 500, 402 | "confidence": 0.99, 403 | }, 404 | { 405 | "text": "world!", 406 | "start": 500, 407 | "end": 1500, 408 | "confidence": 0.99, 409 | }, 410 | ], 411 | } 412 | 413 | on_data_called = False 414 | 415 | def on_data(data: aai.RealtimeFinalTranscript): 416 | nonlocal on_data_called 417 | on_data_called = True 418 | assert isinstance(data, aai.RealtimeFinalTranscript) 419 | assert data.text == test_message["text"] 420 | assert data.audio_start == test_message["audio_start"] 421 | assert data.audio_end == test_message["audio_end"] 422 | assert data.confidence == test_message["confidence"] 423 | assert data.created.isoformat() == test_message["created"] 424 | assert data.punctuated == test_message["punctuated"] 425 | assert data.text_formatted == test_message["text_formatted"] 426 | assert data.words == [ 427 | aai.RealtimeWord( 428 | text=test_message["words"][0]["text"], 429 | start=test_message["words"][0]["start"], 430 | end=test_message["words"][0]["end"], 431 | confidence=test_message["words"][0]["confidence"], 432 | ), 433 | aai.RealtimeWord( 434 | text=test_message["words"][1]["text"], 435 | start=test_message["words"][1]["start"], 436 | end=test_message["words"][1]["end"], 437 | confidence=test_message["words"][1]["confidence"], 438 | ), 439 | ] 440 | 441 | transcriber = aai.RealtimeTranscriber( 442 | on_data=on_data, 443 | on_error=lambda _: None, 444 | sample_rate=44_100, 445 | ) 446 | 447 | transcriber._impl._handle_message(test_message) 448 | 449 | assert on_data_called 450 | 451 | 452 | def test_realtime__handle_message_error_message(mocker: MockFixture): 453 | """ 454 | Tests the 
`_handle_message` method of the `_RealtimeTranscriberImpl` class 455 | with the error message. 456 | """ 457 | 458 | test_message = { 459 | "error": "test error", 460 | } 461 | 462 | on_error_called = False 463 | 464 | def on_error(error: aai.RealtimeError): 465 | nonlocal on_error_called 466 | on_error_called = True 467 | assert isinstance(error, aai.RealtimeError) 468 | assert str(error) == test_message["error"] 469 | 470 | transcriber = aai.RealtimeTranscriber( 471 | on_data=lambda _: None, 472 | on_error=on_error, 473 | sample_rate=44_100, 474 | ) 475 | 476 | transcriber._impl._handle_message(test_message) 477 | 478 | assert on_error_called 479 | 480 | 481 | def test_realtime__handle_message_session_information_message(mocker: MockFixture): 482 | """ 483 | Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class 484 | with the session information message. 485 | """ 486 | 487 | test_message = { 488 | "message_type": "SessionInformation", 489 | "audio_duration_seconds": 3000.0, 490 | } 491 | 492 | on_extra_session_information_called = False 493 | 494 | def on_extra_session_information(data: aai.RealtimeSessionInformation): 495 | nonlocal on_extra_session_information_called 496 | on_extra_session_information_called = True 497 | assert isinstance(data, aai.RealtimeSessionInformation) 498 | assert data.audio_duration_seconds == test_message["audio_duration_seconds"] 499 | 500 | transcriber = aai.RealtimeTranscriber( 501 | on_data=lambda _: None, 502 | on_error=lambda _: None, 503 | sample_rate=44_100, 504 | on_extra_session_information=on_extra_session_information, 505 | ) 506 | 507 | transcriber._impl._handle_message(test_message) 508 | 509 | assert on_extra_session_information_called 510 | 511 | 512 | def test_realtime__handle_message_unknown_message(mocker: MockFixture): 513 | """ 514 | Tests the `_handle_message` method of the `_RealtimeTranscriberImpl` class 515 | with an unknown message. 
516 | """ 517 | 518 | test_message = { 519 | "message_type": "Unknown", 520 | } 521 | 522 | on_data_called = False 523 | 524 | def on_data(data: aai.RealtimeTranscript): 525 | nonlocal on_data_called 526 | on_data_called = True 527 | 528 | on_error_called = False 529 | 530 | def on_error(error: aai.RealtimeError): 531 | nonlocal on_error_called 532 | on_error_called = True 533 | 534 | transcriber = aai.RealtimeTranscriber( 535 | on_data=on_data, 536 | on_error=on_error, 537 | sample_rate=44_100, 538 | ) 539 | 540 | transcriber._impl._handle_message(test_message) 541 | 542 | assert not on_data_called 543 | assert not on_error_called 544 | 545 | 546 | def test_create_temporary_token(httpx_mock: HTTPXMock): 547 | """ 548 | Tests whether the creation of a temporary token is successful. 549 | """ 550 | 551 | # mock the specific endpoint 552 | httpx_mock.add_response( 553 | url=f"{aai.settings.base_url}{ENDPOINT_REALTIME_TOKEN}", 554 | status_code=httpx.codes.OK, 555 | method="POST", 556 | json={"token": "123456"}, 557 | ) 558 | 559 | token = aai.RealtimeTranscriber.create_temporary_token(expires_in=3000) 560 | 561 | assert token == "123456" 562 | 563 | 564 | # TODO: create tests for the `RealtimeTranscriber.close` method 565 | -------------------------------------------------------------------------------- /tests/unit/test_redact_pii.py: -------------------------------------------------------------------------------- 1 | import httpx 2 | import pytest 3 | from pytest_httpx import HTTPXMock 4 | from pytest_mock import MockerFixture 5 | 6 | import tests.unit.unit_test_utils as unit_test_utils 7 | import assemblyai as aai 8 | from assemblyai.api import ENDPOINT_TRANSCRIPT 9 | from tests.unit import factories 10 | 11 | aai.settings.api_key = "test" 12 | 13 | 14 | class TranscriptWithPIIRedactionResponseFactory( 15 | factories.TranscriptCompletedResponseFactory 16 | ): 17 | redact_pii = True 18 | redact_pii_audio = True 19 | redact_pii_policies = [ 20 | 
class TranscriptWithPIIRedactionResponseFactory(
    factories.TranscriptCompletedResponseFactory
):
    # Completed-transcript response with PII redaction and redacted audio enabled.
    redact_pii = True
    redact_pii_audio = True
    redact_pii_policies = [
        aai.types.PIIRedactionPolicy.date,
    ]


def __standard_response() -> dict:
    """Return a plain completed-transcript mock response body."""
    return factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )()


def __pii_response() -> dict:
    """Return a completed-transcript mock response body with PII redaction enabled."""
    return factories.generate_dict_factory(
        TranscriptWithPIIRedactionResponseFactory
    )()


def test_redact_pii_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that excluding `redact_pii` from the `TranscriptionConfig` will
    result in the default behavior of all PII-redaction fields being excluded
    from the request body.
    """
    request_body, _ = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__standard_response(),
        config=aai.TranscriptionConfig(),
    )
    for field in (
        "redact_pii",
        "redact_pii_audio",
        "redact_pii_policies",
        "redact_pii_sub",
    ):
        assert request_body.get(field) is None


def test_redact_pii_enabled(httpx_mock: HTTPXMock):
    """
    Tests that enabling `redact_pii`, along with the required
    `redact_pii_policies` parameter, will result in the request body
    containing those fields.
    """
    policies = [
        aai.types.PIIRedactionPolicy.date,
        aai.types.PIIRedactionPolicy.phone_number,
    ]

    request_body, _ = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__pii_response(),
        config=aai.TranscriptionConfig(
            redact_pii=True,
            redact_pii_policies=policies,
        ),
    )

    assert request_body.get("redact_pii") is True
    assert request_body.get("redact_pii_policies") == policies


def test_redact_pii_enabled_with_optional_params(httpx_mock: HTTPXMock):
    """
    Tests that enabling `redact_pii`, along with the other optional parameters
    relevant to PII redaction, will result in the request body containing
    those fields.
    """
    policies = [
        aai.types.PIIRedactionPolicy.date,
        aai.types.PIIRedactionPolicy.phone_number,
    ]
    sub_type = aai.types.PIISubstitutionPolicy.entity_name

    request_body, _ = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__pii_response(),
        config=aai.TranscriptionConfig(
            redact_pii=True,
            redact_pii_audio=True,
            redact_pii_policies=policies,
            redact_pii_sub=sub_type,
        ),
    )

    assert request_body.get("redact_pii") is True
    assert request_body.get("redact_pii_audio") is True
    assert request_body.get("redact_pii_policies") == policies
    assert request_body.get("redact_pii_sub") == sub_type


def test_redact_pii_fails_without_policies(httpx_mock: HTTPXMock):
    """
    Tests that enabling `redact_pii` without specifying any policies will
    raise an exception before the API call is made.
    """
    with pytest.raises(ValueError) as error:
        unit_test_utils.submit_mock_transcription_request(
            httpx_mock,
            mock_response={},
            config=aai.TranscriptionConfig(
                redact_pii=True,
                # No policies!
            ),
        )

    assert "policy" in str(error)

    # the error must have been raised before any requests were made
    assert len(httpx_mock.get_requests()) == 0


def test_redact_pii_params_excluded_when_disabled(httpx_mock: HTTPXMock):
    """
    Tests that additional PII-redaction parameters are excluded from the
    submission request body if `redact_pii` itself is not enabled.
    """
    request_body, _ = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__standard_response(),
        config=aai.TranscriptionConfig(
            redact_pii=False,
            redact_pii_audio=True,
            redact_pii_policies=[aai.types.PIIRedactionPolicy.date],
            redact_pii_sub=aai.types.PIISubstitutionPolicy.entity_name,
        ),
    )

    for field in (
        "redact_pii",
        "redact_pii_audio",
        "redact_pii_policies",
        "redact_pii_sub",
    ):
        assert request_body.get(field) is None


def __get_redacted_audio_api_url(transcript: aai.Transcript) -> str:
    """Build the redacted-audio endpoint URL for the given transcript."""
    return (
        f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/redacted-audio"
    )


REDACTED_AUDIO_URL = "https://example.org/redacted-audio.wav"


def __mock_successful_pii_audio_responses(
    httpx_mock: HTTPXMock, transcript: aai.Transcript
):
    """Mock a pending (202) and then a completed redacted-audio response."""
    httpx_mock.add_response(
        url=__get_redacted_audio_api_url(transcript),
        status_code=202,
        method="GET",
        json={},
    )

    httpx_mock.add_response(
        url=__get_redacted_audio_api_url(transcript),
        status_code=httpx.codes.OK,
        method="GET",
        json={
            "redacted_audio_url": REDACTED_AUDIO_URL,
            "status": "redacted_audio_ready",
        },
    )


def __mock_failed_pii_audio_responses(
    httpx_mock: HTTPXMock, transcript: aai.Transcript
):
    """Mock a failing (400) redacted-audio response."""
    httpx_mock.add_response(
        url=__get_redacted_audio_api_url(transcript),
        status_code=400,
        method="GET",
        json={},
    )


def __submit_pii_audio_transcript(httpx_mock: HTTPXMock) -> aai.Transcript:
    """Submit a mock transcription request with PII audio redaction enabled."""
    _, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__pii_response(),
        config=aai.TranscriptionConfig(
            redact_pii=True,
            redact_pii_policies=[aai.types.PIIRedactionPolicy.date],
            redact_pii_audio=True,
        ),
    )
    return transcript


def test_get_pii_redacted_audio_url(httpx_mock: HTTPXMock):
    """
    Tests that the PII-redacted audio URL can be retrieved from the API
    with a successful response.
    """
    transcript = __submit_pii_audio_transcript(httpx_mock)

    __mock_successful_pii_audio_responses(httpx_mock, transcript)
    redacted_audio_url = transcript.get_redacted_audio_url()

    # a third and fourth request must have fetched the redacted audio information
    assert len(httpx_mock.get_requests()) == 4

    assert redacted_audio_url == REDACTED_AUDIO_URL


def test_get_pii_redacted_audio_url_fails_if_redact_pii_not_enabled_for_transcript(
    httpx_mock: HTTPXMock,
):
    """
    Tests that an error is thrown before any requests are made if
    `redact_pii` was not enabled for the transcript and
    `get_redacted_audio_url` is called.
    """
    _, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__standard_response(),  # standard response
        config=aai.TranscriptionConfig(),  # blank config
    )

    with pytest.raises(ValueError) as error:
        transcript.get_redacted_audio_url()

    assert "redact_pii" in str(error)

    # no additional requests for the redacted audio information were made
    assert len(httpx_mock.get_requests()) == 2


def test_get_pii_redacted_audio_url_fails_if_redact_pii_audio_not_enabled_for_transcript(
    httpx_mock: HTTPXMock,
):
    """
    Tests that an error is thrown before any requests are made if
    `redact_pii_audio` was not enabled for the transcript and
    `get_redacted_audio_url` is called.
    """
    _, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response={**__pii_response(), "redact_pii_audio": False},
        config=aai.TranscriptionConfig(
            redact_pii=True, redact_pii_policies=[aai.types.PIIRedactionPolicy.date]
        ),
    )

    with pytest.raises(ValueError) as error:
        transcript.get_redacted_audio_url()

    assert "redact_pii_audio" in str(error)

    # no additional requests for the redacted audio information were made
    assert len(httpx_mock.get_requests()) == 2


def test_get_pii_redacted_audio_url_fails_if_bad_response(httpx_mock: HTTPXMock):
    """
    Tests that `get_redacted_audio_url` raises a `RedactedAudioExpiredError` if
    the request to fetch the redacted audio URL returns a `400` status code,
    indicating that the redacted audio has expired.
    """
    transcript = __submit_pii_audio_transcript(httpx_mock)

    __mock_failed_pii_audio_responses(httpx_mock, transcript)
    with pytest.raises(aai.types.RedactedAudioExpiredError):
        transcript.get_redacted_audio_url()


def test_save_pii_redacted_audio(httpx_mock: HTTPXMock, mocker: MockerFixture):
    """
    Tests that calling `save_redacted_audio` will download the redacted audio
    file to the caller's file system.
    """
    transcript = __submit_pii_audio_transcript(httpx_mock)

    # Mock response that returns the redacted-audio URL
    __mock_successful_pii_audio_responses(httpx_mock, transcript)

    # Mock the redacted-audio URL response
    mock_audio_file_bytes = b"pretend this is a WAV file"
    httpx_mock.add_response(
        url=REDACTED_AUDIO_URL,
        status_code=httpx.codes.OK,
        method="GET",
        content=mock_audio_file_bytes,
    )

    # Set up mocks for writing to disk
    mock_file = mocker.mock_open()
    mocker.patch("builtins.open", mock_file)

    # Download the file
    downloaded_filepath = "redacted_audio.wav"
    transcript.save_redacted_audio(downloaded_filepath)

    # Ensure correct filepath was written to
    mock_file.assert_called_once_with(downloaded_filepath, "wb")

    # Ensure correct file content was written
    written = b"".join(call.args[0] for call in mock_file().write.call_args_list)
    assert written == mock_audio_file_bytes


def test_save_pii_redacted_audio_fails_if_redact_pii_not_enabled_for_transcript(
    httpx_mock: HTTPXMock,
):
    """
    Tests that an error is thrown before any requests are made if
    `redact_pii` was not enabled for the transcript and
    `save_redacted_audio` is called.
    """
    _, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=__standard_response(),  # standard response
        config=aai.TranscriptionConfig(),  # blank config
    )

    with pytest.raises(ValueError) as error:
        transcript.save_redacted_audio("redacted_audio.wav")

    assert "redact_pii" in str(error)

    # no additional requests for the redacted audio information were made
    assert len(httpx_mock.get_requests()) == 2


def test_save_pii_redacted_audio_fails_if_redact_pii_audio_not_enabled_for_transcript(
    httpx_mock: HTTPXMock,
):
    """
    Tests that an error is thrown before any requests are made if
    `redact_pii_audio` was not enabled for the transcript and
    `save_redacted_audio` is called.
    """
    _, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response={**__pii_response(), "redact_pii_audio": False},
        config=aai.TranscriptionConfig(
            redact_pii=True, redact_pii_policies=[aai.types.PIIRedactionPolicy.date]
        ),
    )

    with pytest.raises(ValueError) as error:
        transcript.save_redacted_audio("redacted_audio.wav")

    assert "redact_pii_audio" in str(error)

    # no additional requests for the redacted audio information were made
    assert len(httpx_mock.get_requests()) == 2


def test_save_pii_redacted_audio_fails_if_bad_response(httpx_mock: HTTPXMock):
    """
    Tests that `save_redacted_audio` raises a `RedactedAudioExpiredError` if
    the request to fetch the redacted audio URL returns a `400` status code,
    indicating that the redacted audio has expired.
    """
    transcript = __submit_pii_audio_transcript(httpx_mock)

    __mock_failed_pii_audio_responses(httpx_mock, transcript)
    with pytest.raises(aai.types.RedactedAudioExpiredError):
        transcript.save_redacted_audio("redacted_audio.wav")


def test_save_pii_redacted_audio_fails_if_bad_audio_url_response(httpx_mock: HTTPXMock):
    """
    Tests that `save_redacted_audio` raises a `RedactedAudioUnavailableError` if
    the request to fetch the redacted audio **file** returns a non-200 status
    code.
    """
    transcript = __submit_pii_audio_transcript(httpx_mock)

    __mock_successful_pii_audio_responses(httpx_mock, transcript)
    httpx_mock.add_response(
        url=REDACTED_AUDIO_URL,
        status_code=httpx.codes.NOT_FOUND,
        method="GET",
    )
    with pytest.raises(aai.types.RedactedAudioUnavailableError):
        transcript.save_redacted_audio("redacted_audio.wav")
def test_sentiment_analysis_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Tests that excluding `sentiment_analysis` from the `TranscriptionConfig`
    will result in the default behavior of it being excluded from the request
    body, and that the transcript carries no sentiment results.
    """
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=factories.generate_dict_factory(
            factories.TranscriptCompletedResponseFactory
        )(),
        config=aai.TranscriptionConfig(),
    )
    assert request_body.get("sentiment_analysis") is None
    assert transcript.sentiment_analysis is None


def test_sentiment_analysis_enabled(httpx_mock: HTTPXMock):
    """
    Tests that including `sentiment_analysis=True` in the `TranscriptionConfig`
    will result in `sentiment_analysis=True` in the request body, and that the
    response is properly parsed into a `Transcript` object.
    """
    mock_response = factories.generate_dict_factory(SentimentAnalysisResponseFactory)()
    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(sentiment_analysis=True),
    )

    # Check that request body was properly defined
    assert request_body.get("sentiment_analysis") is True

    # Check that transcript was properly parsed from JSON response
    assert transcript.error is None

    assert transcript.sentiment_analysis is not None
    assert len(transcript.sentiment_analysis) > 0
    assert len(transcript.sentiment_analysis) == len(
        mock_response["sentiment_analysis_results"]
    )

    # every parsed sentiment result must mirror its JSON counterpart
    for expected, actual in zip(
        mock_response["sentiment_analysis_results"],
        transcript.sentiment_analysis,
    ):
        assert actual.text == expected["text"]
        assert actual.start == expected["start"]
        assert actual.end == expected["end"]
        assert actual.confidence == expected["confidence"]
        assert actual.sentiment.value == expected["sentiment"]
        assert actual.speaker == expected["speaker"]


def test_sentiment_analysis_null_start(httpx_mock: HTTPXMock):
    """
    Tests that `start` converts null values to 0.
    """
    mock_response = {
        "audio_url": "https://example/audio.mp3",
        "status": "completed",
        "sentiment_analysis_results": [
            {
                "text": "hi",
                "start": None,
                "end": 100,
                "confidence": 0.99,
                "sentiment": "POSITIVE",
            }
        ],
    }
    _, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=aai.TranscriptionConfig(sentiment_analysis=True),
    )

    for result in transcript.sentiment_analysis:
        assert result.start == 0


def test_api_key_settings():
    """
    Tests that `ASSEMBLYAI_API_KEY` works correctly and that the
    non-prefixed `API_KEY` variable is ignored.
    """
    # Stash the current values so the test can restore them afterwards.
    tmp1 = os.environ.pop("ASSEMBLYAI_API_KEY", None)
    tmp2 = os.environ.pop("API_KEY", None)

    aai.settings.api_key = None
    reload(aai)
    assert aai.settings.api_key is None

    # this should not change the api key
    os.environ["API_KEY"] = "test"
    reload(aai)
    assert aai.settings.api_key is None

    # this should change the api key
    os.environ["ASSEMBLYAI_API_KEY"] = "test"
    reload(aai)
    assert aai.settings.api_key == "test"

    # reset: compare against None (not truthiness) so that an
    # originally-set-but-empty variable is restored rather than dropped
    if tmp1 is not None:
        os.environ["ASSEMBLYAI_API_KEY"] = tmp1
    else:
        os.environ.pop("ASSEMBLYAI_API_KEY", None)

    if tmp2 is not None:
        os.environ["API_KEY"] = tmp2
    else:
        os.environ.pop("API_KEY", None)

    reload(aai)
    aai.settings.api_key = "test"


def test_base_url_settings():
    """
    Tests that `ASSEMBLYAI_BASE_URL` works correctly and that the
    non-prefixed `BASE_URL` variable is ignored.
    """
    # Stash the current values so the test can restore them afterwards.
    tmp1 = os.environ.pop("ASSEMBLYAI_BASE_URL", None)
    tmp2 = os.environ.pop("BASE_URL", None)

    aai.settings.base_url = "https://api.assemblyai.com"
    reload(aai)
    assert aai.settings.base_url == "https://api.assemblyai.com"

    # this should not change the base url
    os.environ["BASE_URL"] = "https://test.com"
    reload(aai)
    assert aai.settings.base_url == "https://api.assemblyai.com"

    # this should change the base url
    os.environ["ASSEMBLYAI_BASE_URL"] = "https://test.com"
    reload(aai)
    assert aai.settings.base_url == "https://test.com"

    # reset: compare against None (not truthiness) so that an
    # originally-set-but-empty variable is restored rather than dropped
    if tmp1 is not None:
        os.environ["ASSEMBLYAI_BASE_URL"] = tmp1
    else:
        os.environ.pop("ASSEMBLYAI_BASE_URL", None)
    if tmp2 is not None:
        os.environ["BASE_URL"] = tmp2
    else:
        os.environ.pop("BASE_URL", None)

    reload(aai)
    aai.settings.api_key = "test"
def _disable_rw_threads(mocker: MockFixture):
    """
    Disable the read and write threads for the WebSocket.
    """

    mocker.patch("threading.Thread.start", return_value=None)


def _capture_websocket_connect(mocker: MockFixture) -> dict:
    """
    Replace `websocket_connect` with a stub that records its call arguments in
    the returned dict (keys: "url", "additional_headers", "open_timeout").
    """
    captured: dict = {}

    def fake_connect(url: str, additional_headers: dict, open_timeout: float):
        captured["url"] = url
        captured["additional_headers"] = additional_headers
        captured["open_timeout"] = open_timeout

    mocker.patch(
        "assemblyai.streaming.v3.client.websocket_connect",
        new=fake_connect,
    )
    return captured


def _assert_common_connect_args(captured: dict):
    """Assert the auth/version/user-agent headers and the open timeout."""
    headers = captured["additional_headers"]
    assert headers["Authorization"] == "test"
    assert headers["AssemblyAI-Version"] == "2025-05-12"
    assert "AssemblyAI/1.0" in headers["User-Agent"]

    assert captured["open_timeout"] == 15


def test_client_connect(mocker: MockFixture):
    captured = _capture_websocket_connect(mocker)
    _disable_rw_threads(mocker)

    client = StreamingClient(
        StreamingClientOptions(api_key="test", api_host="api.example.com")
    )

    params = StreamingParameters(sample_rate=16000)
    client.connect(params)

    query = urlencode({"sample_rate": params.sample_rate})
    assert captured["url"] == f"wss://api.example.com/v3/ws?{query}"
    _assert_common_connect_args(captured)


def test_client_connect_with_token(mocker: MockFixture):
    captured = _capture_websocket_connect(mocker)
    _disable_rw_threads(mocker)

    client = StreamingClient(
        StreamingClientOptions(token="test", api_host="api.example.com")
    )

    params = StreamingParameters(sample_rate=16000)
    client.connect(params)

    query = urlencode({"sample_rate": params.sample_rate})
    assert captured["url"] == f"wss://api.example.com/v3/ws?{query}"
    _assert_common_connect_args(captured)


def test_client_connect_all_parameters(mocker: MockFixture):
    captured = _capture_websocket_connect(mocker)
    _disable_rw_threads(mocker)

    client = StreamingClient(
        StreamingClientOptions(api_key="test", api_host="api.example.com")
    )

    params = StreamingParameters(
        sample_rate=16000,
        end_of_turn_confidence_threshold=0.5,
        min_end_of_turn_silence_when_confident=2000,
        max_turn_silence=3000,
    )

    client.connect(params)

    query = urlencode(
        {
            "end_of_turn_confidence_threshold": params.end_of_turn_confidence_threshold,
            "min_end_of_turn_silence_when_confident": params.min_end_of_turn_silence_when_confident,
            "max_turn_silence": params.max_turn_silence,
            "sample_rate": params.sample_rate,
        }
    )
    assert captured["url"] == f"wss://api.example.com/v3/ws?{query}"
    _assert_common_connect_args(captured)


def test_client_send_audio(mocker: MockFixture):
    _capture_websocket_connect(mocker)
    _disable_rw_threads(mocker)

    client = StreamingClient(
        StreamingClientOptions(api_key="test", api_host="api.example.com")
    )

    client.connect(StreamingParameters(sample_rate=16000))
    client.stream(b"test audio data")

    assert client._write_queue.qsize() == 1
    assert isinstance(client._write_queue.get(timeout=1), bytes)


class SummarizationResponseFactory(factories.TranscriptCompletedResponseFactory):
    # Completed-transcript response that also carries a summary.
    summary = factory.Faker("sentence")
"format_text"]) 17 | def test_summarization_fails_without_required_field( 18 | httpx_mock: HTTPXMock, required_field: str 19 | ): 20 | """ 21 | Tests whether the SDK raises an error before making a request 22 | if `summarization` is enabled and the given required field is disabled 23 | """ 24 | with pytest.raises(ValueError) as error: 25 | test_utils.submit_mock_transcription_request( 26 | httpx_mock, 27 | {}, 28 | config=aai.TranscriptionConfig( 29 | summarization=True, 30 | **{required_field: False}, # type: ignore 31 | ), 32 | ) 33 | 34 | # Check that the error message informs the user of the invalid parameter 35 | assert required_field in str(error) 36 | 37 | # Check that the error was raised before any requests were made 38 | assert len(httpx_mock.get_requests()) == 0 39 | 40 | 41 | def test_summarization_disabled_by_default(httpx_mock: HTTPXMock): 42 | """ 43 | Tests that excluding `summarization` from the `TranscriptionConfig` will 44 | result in the default behavior of it being excluded from the request body 45 | """ 46 | mock_response = factories.generate_dict_factory( 47 | factories.TranscriptCompletedResponseFactory 48 | )() 49 | request_body, transcript = test_utils.submit_mock_transcription_request( 50 | httpx_mock, 51 | mock_response, 52 | config=aai.TranscriptionConfig(), 53 | ) 54 | 55 | # Check that request body was properly defined 56 | assert request_body.get("summarization") is None 57 | 58 | # Check that transcript was properly parsed from JSON response 59 | assert transcript.error is None 60 | assert transcript.summary is None 61 | 62 | 63 | def test_default_summarization_params(httpx_mock: HTTPXMock): 64 | """ 65 | Tests that including `summarization=True` in the `TranscriptionConfig` 66 | will result in `summarization=True` in the request body. 
67 | """ 68 | mock_response = factories.generate_dict_factory(SummarizationResponseFactory)() 69 | request_body, transcript = test_utils.submit_mock_transcription_request( 70 | httpx_mock, mock_response, aai.TranscriptionConfig(summarization=True) 71 | ) 72 | 73 | # Check that request body was properly defined 74 | assert request_body.get("summarization") is True 75 | assert request_body.get("summary_model") is None 76 | assert request_body.get("summary_type") is None 77 | 78 | # Check that transcript was properly parsed from JSON response 79 | assert transcript.error is None 80 | assert transcript.summary == mock_response["summary"] 81 | 82 | 83 | def test_summarization_with_params(httpx_mock: HTTPXMock): 84 | """ 85 | Tests that including additional summarization parameters along with 86 | `summarization=True` in the `TranscriptionConfig` will result in all 87 | parameters being included in the request as well. 88 | """ 89 | 90 | summary_model = aai.SummarizationModel.conversational 91 | summary_type = aai.SummarizationType.bullets 92 | 93 | mock_response = factories.generate_dict_factory(SummarizationResponseFactory)() 94 | 95 | request_body, transcript = test_utils.submit_mock_transcription_request( 96 | httpx_mock, 97 | mock_response, 98 | aai.TranscriptionConfig( 99 | summarization=True, 100 | summary_model=summary_model, 101 | summary_type=summary_type, 102 | ), 103 | ) 104 | 105 | # Check that request body was properly defined 106 | assert request_body.get("summarization") is True 107 | assert request_body.get("summary_model") == summary_model 108 | assert request_body.get("summary_type") == summary_type 109 | 110 | # Check that transcript was properly parsed from JSON response 111 | assert transcript.error is None 112 | assert transcript.summary == mock_response["summary"] 113 | 114 | 115 | def test_summarization_params_excluded_when_disabled(httpx_mock: HTTPXMock): 116 | """ 117 | Tests that additional summarization parameters are excluded from the 
submission 118 | request body if `summarization` itself is not enabled. 119 | """ 120 | mock_response = factories.generate_dict_factory( 121 | factories.TranscriptCompletedResponseFactory 122 | )() 123 | request_body, transcript = test_utils.submit_mock_transcription_request( 124 | httpx_mock, 125 | mock_response, 126 | aai.TranscriptionConfig( 127 | summarization=False, 128 | summary_model=aai.SummarizationModel.conversational, 129 | summary_type=aai.SummarizationType.bullets, 130 | ), 131 | ) 132 | 133 | # Check that request body was properly defined 134 | assert request_body.get("summarization") is None 135 | assert request_body.get("summary_model") is None 136 | assert request_body.get("summary_type") is None 137 | 138 | # Check that transcript was properly parsed from JSON response 139 | assert transcript.error is None 140 | assert transcript.summary is None 141 | -------------------------------------------------------------------------------- /tests/unit/test_transcript.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | from urllib.parse import urlencode 3 | 4 | import httpx 5 | import pytest 6 | from faker import Faker 7 | from pytest_httpx import HTTPXMock 8 | 9 | import assemblyai as aai 10 | from assemblyai.api import ENDPOINT_TRANSCRIPT 11 | from tests.unit import factories 12 | from assemblyai.types import SpeechModel 13 | 14 | aai.settings.api_key = "test" 15 | 16 | 17 | def test_export_subtitles_succeeds(httpx_mock: HTTPXMock, faker: Faker): 18 | """ 19 | Tests whether exporting subtitles succeed. 
20 | """ 21 | 22 | # create a mock response of a completed transcript 23 | mock_transcript_response = factories.generate_dict_factory( 24 | factories.TranscriptCompletedResponseFactory 25 | )() 26 | 27 | expected_subtitles_srt = faker.text() 28 | expected_subtitles_vtt = faker.text() 29 | 30 | transcript = aai.Transcript.from_response( 31 | client=aai.Client.get_default(), 32 | response=aai.types.TranscriptResponse(**mock_transcript_response), 33 | ) 34 | 35 | # mock the specific endpoints 36 | httpx_mock.add_response( 37 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/srt", 38 | status_code=httpx.codes.OK, 39 | method="GET", 40 | text=expected_subtitles_srt, 41 | ) 42 | 43 | httpx_mock.add_response( 44 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/vtt", 45 | status_code=httpx.codes.OK, 46 | method="GET", 47 | text=expected_subtitles_vtt, 48 | ) 49 | 50 | srt_subtitles = transcript.export_subtitles_srt() 51 | vtt_subtitles = transcript.export_subtitles_vtt() 52 | 53 | assert srt_subtitles == expected_subtitles_srt 54 | assert vtt_subtitles == expected_subtitles_vtt 55 | 56 | # check whether we mocked everything 57 | assert len(httpx_mock.get_requests()) == 2 58 | 59 | 60 | def test_export_subtitles_fails(httpx_mock: HTTPXMock): 61 | """ 62 | Tests whether exporting subtitles fails. 
63 | """ 64 | 65 | # create a mock response of a completed transcript 66 | mock_transcript_response = factories.generate_dict_factory( 67 | factories.TranscriptCompletedResponseFactory 68 | )() 69 | 70 | transcript = aai.Transcript.from_response( 71 | client=aai.Client.get_default(), 72 | response=aai.types.TranscriptResponse(**mock_transcript_response), 73 | ) 74 | 75 | # mock the specific endpoints 76 | httpx_mock.add_response( 77 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/srt", 78 | status_code=httpx.codes.INTERNAL_SERVER_ERROR, 79 | method="GET", 80 | json={"error": "something went wrong"}, 81 | ) 82 | 83 | httpx_mock.add_response( 84 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/vtt", 85 | status_code=httpx.codes.INTERNAL_SERVER_ERROR, 86 | method="GET", 87 | json={"error": "something went wrong"}, 88 | ) 89 | 90 | with pytest.raises(aai.TranscriptError, match="something went wrong"): 91 | transcript.export_subtitles_srt() 92 | 93 | with pytest.raises(aai.TranscriptError, match="something went wrong"): 94 | transcript.export_subtitles_vtt() 95 | 96 | # check whether we mocked everything 97 | assert len(httpx_mock.get_requests()) == 2 98 | 99 | 100 | def test_word_search_succeeds(httpx_mock: HTTPXMock): 101 | """ 102 | Tests whether word search succeeds. 
103 | """ 104 | 105 | # create a mock response of a completed transcript 106 | mock_transcript_response = factories.generate_dict_factory( 107 | factories.TranscriptCompletedResponseFactory 108 | )() 109 | 110 | transcript = aai.Transcript.from_response( 111 | client=aai.Client.get_default(), 112 | response=aai.types.TranscriptResponse(**mock_transcript_response), 113 | ) 114 | 115 | # create a mock response for the word search 116 | mock_word_search_response = factories.generate_dict_factory( 117 | factories.WordSearchMatchResponseFactory 118 | )() 119 | 120 | search_words = { 121 | "words": ",".join(["test", "me"]), 122 | } 123 | # mock the specific endpoints 124 | url = httpx.URL( 125 | f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/word-search?{urlencode(search_words)}", 126 | ) 127 | 128 | httpx_mock.add_response( 129 | url=url, 130 | status_code=httpx.codes.OK, 131 | method="GET", 132 | json=mock_word_search_response, 133 | ) 134 | 135 | # mimic the SDK call 136 | matches = transcript.word_search(words=["test", "me"]) 137 | 138 | # check integrity of the response 139 | 140 | for idx, word_search in enumerate(matches): 141 | assert isinstance(word_search, aai.WordSearchMatch) 142 | assert word_search.count == mock_word_search_response["matches"][idx]["count"] 143 | assert ( 144 | word_search.timestamps 145 | == mock_word_search_response["matches"][idx]["timestamps"] 146 | ) 147 | assert word_search.text == mock_word_search_response["matches"][idx]["text"] 148 | assert ( 149 | word_search.indexes == mock_word_search_response["matches"][idx]["indexes"] 150 | ) 151 | 152 | # check whether we mocked everything 153 | assert len(httpx_mock.get_requests()) == 1 154 | 155 | 156 | def test_word_search_fails(httpx_mock: HTTPXMock): 157 | """ 158 | Tests whether word search fails. 
159 | """ 160 | 161 | # create a mock response of a completed transcript 162 | mock_transcript_response = factories.generate_dict_factory( 163 | factories.TranscriptCompletedResponseFactory 164 | )() 165 | 166 | transcript = aai.Transcript.from_response( 167 | client=aai.Client.get_default(), 168 | response=aai.types.TranscriptResponse(**mock_transcript_response), 169 | ) 170 | 171 | # mock the specific endpoints 172 | url = httpx.URL( 173 | f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/word-search?words=test", 174 | ) 175 | 176 | httpx_mock.add_response( 177 | url=url, 178 | status_code=httpx.codes.INTERNAL_SERVER_ERROR, 179 | method="GET", 180 | json={"error": "something went wrong"}, 181 | ) 182 | 183 | with pytest.raises(aai.TranscriptError, match="something went wrong"): 184 | transcript.word_search(words=["test"]) 185 | 186 | # check whether we mocked everything 187 | assert len(httpx_mock.get_requests()) == 1 188 | 189 | 190 | def test_get_sentences_and_paragraphs_succeeds(httpx_mock: HTTPXMock): 191 | """ 192 | Tests whether getting sentences and paragraphs succeeds. 
193 | """ 194 | 195 | # create a mock response of a completed transcript 196 | mock_transcript_response = factories.generate_dict_factory( 197 | factories.TranscriptCompletedResponseFactory 198 | )() 199 | 200 | # create a mock response for the sentences 201 | mock_sentences_response = factories.generate_dict_factory( 202 | factories.SentencesResponseFactory 203 | )() 204 | 205 | # create a mock response for the paragraphs 206 | mock_paragraphs_response = factories.generate_dict_factory( 207 | factories.ParagraphsResponseFactory 208 | )() 209 | 210 | transcript = aai.Transcript.from_response( 211 | client=aai.Client.get_default(), 212 | response=aai.types.TranscriptResponse(**mock_transcript_response), 213 | ) 214 | 215 | # mock the specific endpoints 216 | httpx_mock.add_response( 217 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/sentences", 218 | status_code=httpx.codes.OK, 219 | method="GET", 220 | json=mock_sentences_response, 221 | ) 222 | 223 | httpx_mock.add_response( 224 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/paragraphs", 225 | status_code=httpx.codes.OK, 226 | method="GET", 227 | json=mock_paragraphs_response, 228 | ) 229 | 230 | # mimic the SDK call 231 | sentences = transcript.get_sentences() 232 | paragraphs = transcript.get_paragraphs() 233 | 234 | # check integrity of the response 235 | def compare_words(lhs: List[aai.Word], rhs: List[Dict[str, Any]]) -> bool: 236 | """ 237 | Compares the list of Word objects with the list of dicts. 238 | 239 | Args: 240 | lhs: The list of Word objects. 241 | rhs: The list of dicts. 242 | 243 | Returns: 244 | True if the lists are equal, False otherwise. 
245 | """ 246 | for idx, word in enumerate(lhs): 247 | if word.text != rhs[idx]["text"]: 248 | return False 249 | if word.start != rhs[idx]["start"]: 250 | return False 251 | if word.end != rhs[idx]["end"]: 252 | return False 253 | return True 254 | 255 | for idx, sentence in enumerate(sentences): 256 | assert isinstance(sentence, aai.Sentence) 257 | assert compare_words( 258 | sentence.words, mock_sentences_response["sentences"][idx]["words"] 259 | ) 260 | 261 | for idx, paragraph in enumerate(paragraphs): 262 | assert isinstance(paragraph, aai.Paragraph) 263 | assert compare_words( 264 | paragraph.words, mock_paragraphs_response["paragraphs"][idx]["words"] 265 | ) 266 | 267 | # check whether we mocked everything 268 | assert len(httpx_mock.get_requests()) == 2 269 | 270 | 271 | def test_get_sentences_and_paragraphs_fails(httpx_mock: HTTPXMock): 272 | """ 273 | Tests whether getting sentences and paragraphs fails. 274 | """ 275 | 276 | # create a mock response of a completed transcript 277 | mock_transcript_response = factories.generate_dict_factory( 278 | factories.TranscriptCompletedResponseFactory 279 | )() 280 | 281 | transcript = aai.Transcript.from_response( 282 | client=aai.Client.get_default(), 283 | response=aai.types.TranscriptResponse(**mock_transcript_response), 284 | ) 285 | 286 | # mock the specific endpoints 287 | httpx_mock.add_response( 288 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/sentences", 289 | status_code=httpx.codes.INTERNAL_SERVER_ERROR, 290 | method="GET", 291 | json={"error": "something went wrong"}, 292 | ) 293 | 294 | httpx_mock.add_response( 295 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript.id}/paragraphs", 296 | status_code=httpx.codes.INTERNAL_SERVER_ERROR, 297 | method="GET", 298 | json={"error": "something went wrong"}, 299 | ) 300 | 301 | # mimic the SDK call 302 | with pytest.raises(aai.TranscriptError, match="something went wrong"): 303 | transcript.get_sentences() 304 | with 
pytest.raises(aai.TranscriptError, match="something went wrong"): 305 | transcript.get_paragraphs() 306 | 307 | # check whether we mocked everything 308 | assert len(httpx_mock.get_requests()) == 2 309 | 310 | 311 | def test_get_by_id_completed(httpx_mock: HTTPXMock): 312 | """ 313 | Tests that a completed transcript can be retrieved by its ID. 314 | """ 315 | transcript_id = "123" 316 | 317 | factory_class = factories.TranscriptCompletedResponseFactory 318 | 319 | mock_transcript_response = factories.generate_dict_factory(factory_class)() 320 | 321 | httpx_mock.add_response( 322 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 323 | status_code=httpx.codes.OK, 324 | method="GET", 325 | json=mock_transcript_response, 326 | ) 327 | 328 | transcript = aai.Transcript.get_by_id(transcript_id) 329 | 330 | assert isinstance(transcript, aai.Transcript) 331 | assert transcript.status == aai.TranscriptStatus.completed 332 | assert transcript.id == transcript_id 333 | assert transcript.error is None 334 | 335 | 336 | def test_get_by_id_error(httpx_mock: HTTPXMock): 337 | """ 338 | Tests that an error transcript can be retrieved by its ID. 339 | """ 340 | transcript_id = "123" 341 | 342 | factory_class = factories.TranscriptErrorResponseFactory 343 | 344 | mock_transcript_response = factories.generate_dict_factory(factory_class)() 345 | 346 | httpx_mock.add_response( 347 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 348 | status_code=httpx.codes.OK, 349 | method="GET", 350 | json=mock_transcript_response, 351 | ) 352 | 353 | transcript = aai.Transcript.get_by_id(transcript_id) 354 | 355 | assert isinstance(transcript, aai.Transcript) 356 | assert transcript.status == aai.TranscriptStatus.error 357 | assert transcript.id == transcript_id 358 | 359 | assert len(httpx_mock.get_requests()) == 1 360 | 361 | 362 | def test_get_by_id_fails(httpx_mock: HTTPXMock): 363 | """ 364 | Tests that a failed transcript lookup raises an exception. 
365 | """ 366 | test_id = 1234 367 | 368 | # json response upon failure 369 | response_json = {"error": "Transcript lookup error, transcript id not found"} 370 | 371 | # mock the specific endpoints 372 | httpx_mock.add_response( 373 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{test_id}", 374 | status_code=httpx.codes.BAD_REQUEST, 375 | method="GET", 376 | json=response_json, 377 | ) 378 | 379 | # Check that an error is properly raised when no such transcript exists 380 | with pytest.raises(aai.TranscriptError) as excinfo: 381 | aai.Transcript.get_by_id(test_id) 382 | 383 | # check whether the TranscriptError contains the specified error message 384 | assert response_json["error"] in str(excinfo.value) 385 | assert len(httpx_mock.get_requests()) == 1 386 | 387 | 388 | def test_get_by_id_async(httpx_mock: HTTPXMock): 389 | transcript_id = "123" 390 | factory_class = factories.TranscriptCompletedResponseFactory 391 | 392 | mock_transcript_response = factories.generate_dict_factory(factory_class)() 393 | 394 | httpx_mock.add_response( 395 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 396 | status_code=httpx.codes.OK, 397 | method="GET", 398 | json=mock_transcript_response, 399 | ) 400 | 401 | transcript_future = aai.Transcript.get_by_id_async(transcript_id) 402 | transcript = transcript_future.result() 403 | 404 | assert isinstance(transcript, aai.Transcript) 405 | assert transcript.status == aai.TranscriptStatus.completed 406 | assert transcript.id == transcript_id 407 | assert transcript.error is None 408 | 409 | 410 | def test_delete_by_id(httpx_mock: HTTPXMock): 411 | mock_transcript_response = factories.generate_dict_factory( 412 | factories.TranscriptDeletedResponseFactory 413 | )() 414 | transcript_id = mock_transcript_response["id"] 415 | httpx_mock.add_response( 416 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 417 | status_code=httpx.codes.OK, 418 | method="DELETE", 419 | json=mock_transcript_response,
420 | ) 421 | 422 | transcript = aai.Transcript.delete_by_id(transcript_id) 423 | 424 | assert isinstance(transcript, aai.Transcript) 425 | assert transcript.status == aai.TranscriptStatus.completed 426 | assert transcript.id == transcript_id 427 | assert transcript.error is None 428 | assert transcript.text == mock_transcript_response["text"] 429 | assert transcript.audio_url == mock_transcript_response["audio_url"] 430 | 431 | 432 | def test_delete_by_id_async(httpx_mock: HTTPXMock): 433 | mock_transcript_response = factories.generate_dict_factory( 434 | factories.TranscriptDeletedResponseFactory 435 | )() 436 | transcript_id = mock_transcript_response["id"] 437 | 438 | httpx_mock.add_response( 439 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 440 | status_code=httpx.codes.OK, 441 | method="DELETE", 442 | json=mock_transcript_response, 443 | ) 444 | 445 | transcript_future = aai.Transcript.delete_by_id_async(transcript_id) 446 | transcript = transcript_future.result() 447 | 448 | assert isinstance(transcript, aai.Transcript) 449 | assert transcript.status == aai.TranscriptStatus.completed 450 | assert transcript.id == transcript_id 451 | assert transcript.error is None 452 | assert transcript.text == mock_transcript_response["text"] 453 | assert transcript.audio_url == mock_transcript_response["audio_url"] 454 | -------------------------------------------------------------------------------- /tests/unit/test_transcript_group.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | import httpx 4 | from pytest_httpx import HTTPXMock 5 | 6 | import assemblyai as aai 7 | from assemblyai.api import ENDPOINT_TRANSCRIPT 8 | from tests.unit import factories 9 | 10 | 11 | def test_transcript_group_accepts_transcript_ids(): 12 | """ 13 | Tests whether a TranscriptGroup accepts transcript IDs. 
14 | """ 15 | transcript_ids = [str(uuid.uuid4()), str(uuid.uuid4())] 16 | 17 | transcript_group = aai.TranscriptGroup(transcript_ids=transcript_ids) 18 | 19 | assert [transcript.id for transcript in transcript_group] == transcript_ids 20 | 21 | 22 | def test_transcript_group_check_status(): 23 | """ 24 | Tests the TranscriptGroup's status 25 | """ 26 | 27 | # create a mock response of a completed transcript 28 | mock_completed_transcript = factories.generate_dict_factory( 29 | factories.TranscriptCompletedResponseFactory 30 | )() 31 | 32 | mock_queued_transcript = factories.generate_dict_factory( 33 | factories.TranscriptQueuedResponseFactory 34 | )() 35 | 36 | mock_processing_transcript = factories.generate_dict_factory( 37 | factories.TranscriptProcessingResponseFactory 38 | )() 39 | 40 | mock_error_transcript = factories.generate_dict_factory( 41 | factories.TranscriptErrorResponseFactory 42 | )() 43 | 44 | transcript_completed = aai.Transcript.from_response( 45 | client=aai.Client.get_default(), 46 | response=aai.types.TranscriptResponse(**mock_completed_transcript), 47 | ) 48 | 49 | transcript_queued = aai.Transcript.from_response( 50 | client=aai.Client.get_default(), 51 | response=aai.types.TranscriptResponse(**mock_queued_transcript), 52 | ) 53 | 54 | transcript_processing = aai.Transcript.from_response( 55 | client=aai.Client.get_default(), 56 | response=aai.types.TranscriptResponse(**mock_processing_transcript), 57 | ) 58 | 59 | transcript_error = aai.Transcript.from_response( 60 | client=aai.Client.get_default(), 61 | response=aai.types.TranscriptResponse(**mock_error_transcript), 62 | ) 63 | 64 | transcript_group = aai.TranscriptGroup() 65 | 66 | transcript_group.add_transcript(transcript_completed) 67 | assert transcript_group.status == aai.TranscriptStatus.completed 68 | 69 | transcript_group.add_transcript(transcript_queued) 70 | assert transcript_group.status == aai.TranscriptStatus.queued 71 | 72 | 
transcript_group.add_transcript(transcript_processing) 73 | assert transcript_group.status == aai.TranscriptStatus.queued 74 | 75 | transcript_group.add_transcript(transcript_error) 76 | assert transcript_group.status == aai.TranscriptStatus.error 77 | 78 | 79 | def test_get_by_ids(httpx_mock: HTTPXMock): 80 | transcript_ids = ["123", "456"] 81 | mock_transcript_response = factories.generate_dict_factory( 82 | factories.TranscriptCompletedResponseFactory 83 | )() 84 | for transcript_id in transcript_ids: 85 | httpx_mock.add_response( 86 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 87 | status_code=httpx.codes.OK, 88 | method="GET", 89 | json=mock_transcript_response, 90 | ) 91 | 92 | transcript_group = aai.TranscriptGroup.get_by_ids(transcript_ids) 93 | 94 | assert isinstance(transcript_group, aai.TranscriptGroup) 95 | assert transcript_group.status == aai.TranscriptStatus.completed 96 | for transcript in transcript_group: 97 | assert transcript.id in transcript_ids 98 | transcript_ids.remove(transcript.id) 99 | 100 | assert transcript.error is None 101 | assert len(transcript_ids) == 0 102 | 103 | 104 | def test_get_by_id_async(httpx_mock: HTTPXMock): 105 | transcript_ids = ["123", "456"] 106 | mock_transcript_response = factories.generate_dict_factory( 107 | factories.TranscriptCompletedResponseFactory 108 | )() 109 | for transcript_id in transcript_ids: 110 | httpx_mock.add_response( 111 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{transcript_id}", 112 | status_code=httpx.codes.OK, 113 | method="GET", 114 | json=mock_transcript_response, 115 | ) 116 | 117 | transcript_group_future = aai.TranscriptGroup.get_by_ids_async(transcript_ids) 118 | transcript_group = transcript_group_future.result() 119 | 120 | assert isinstance(transcript_group, aai.TranscriptGroup) 121 | assert transcript_group.status == aai.TranscriptStatus.completed 122 | for transcript in transcript_group: 123 | assert transcript.id in transcript_ids 124 | 
transcript_ids.remove(transcript.id) 125 | 126 | assert transcript.error is None 127 | assert len(transcript_ids) == 0 128 | -------------------------------------------------------------------------------- /tests/unit/unit_test_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any, Dict, Tuple 3 | 4 | import httpx 5 | from pytest_httpx import HTTPXMock 6 | 7 | import assemblyai as aai 8 | from assemblyai.api import ENDPOINT_TRANSCRIPT 9 | from tests.unit import factories 10 | 11 | 12 | def submit_mock_transcription_request( 13 | httpx_mock: HTTPXMock, 14 | mock_response: Dict[str, Any], 15 | config: aai.TranscriptionConfig, 16 | ) -> Tuple[Dict[str, Any], aai.transcriber.Transcript]: 17 | """ 18 | Helper function to abstract calling transcriber with given parameters, 19 | and perform some common assertions. 20 | 21 | Args: 22 | httpx_mock: HTTPXMock instance to use for mocking requests 23 | mock_response: Dict to use as mock response from API 24 | config: The `TranscriptionConfig` to use for transcription 25 | 26 | Returns: 27 | A tuple containing the JSON body of the initial submission request, 28 | and the `Transcript` object parsed from the mock response 29 | """ 30 | 31 | mock_transcript_id = mock_response.get("id", "mock_id") 32 | 33 | # Mock initial submission response (transcript is processing) 34 | mock_processing_response = factories.generate_dict_factory( 35 | factories.TranscriptProcessingResponseFactory 36 | )() 37 | 38 | httpx_mock.add_response( 39 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}", 40 | status_code=httpx.codes.OK, 41 | method="POST", 42 | json={ 43 | **mock_processing_response, 44 | "id": mock_transcript_id, # inject ID from main mock response 45 | }, 46 | ) 47 | 48 | # Mock polling-for-completeness response, with completed transcript 49 | httpx_mock.add_response( 50 | url=f"{aai.settings.base_url}{ENDPOINT_TRANSCRIPT}/{mock_transcript_id}", 51 | 
status_code=httpx.codes.OK, 52 | method="GET", 53 | json=mock_response, 54 | ) 55 | 56 | # == Make API request via SDK == 57 | transcript = aai.Transcriber().transcribe( 58 | data="https://example.org/audio.wav", 59 | config=config, 60 | ) 61 | 62 | # Check that submission and polling requests were made 63 | assert len(httpx_mock.get_requests()) == 2 64 | 65 | # Extract body of initial submission request 66 | request = httpx_mock.get_requests()[0] 67 | request_body = json.loads(request.content.decode()) 68 | 69 | return request_body, transcript 70 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py{38,39,310,311}-websockets{latest,11.0}-pyaudio{latest,0.2}-httpx{latest,0.24,0.23,0.22,0.21}-pydantic{latest,2,1.10,1.9,1.8,1.7}-typing-extensions 3 | 4 | [testenv] 5 | deps = 6 | # library dependencies 7 | websocketslatest: websockets 8 | websockets11.0: websockets>=11.0.0,<12.0.0 9 | httpxlatest: httpx 10 | httpx0.24: httpx>=0.24.0,<0.25.0 11 | httpx0.23: httpx>=0.23.0,<0.24.0 12 | httpx0.22: httpx>=0.22.0,<0.23.0 13 | httpx0.21: httpx>=0.21.0,<0.22.0 14 | pydanticlatest: pydantic 15 | pydantic2: pydantic>=2 16 | pydantic1.10: pydantic>=1.10.0,<1.11.0,!=1.10.7 17 | pydantic1.9: pydantic>=1.9.0,<1.10.0 18 | pydantic1.8: pydantic>=1.8.0,<1.9.0 19 | pydantic1.7: pydantic>=1.7.0,<1.8.0 20 | typing-extensions: typing-extensions>=3.7 21 | # extra dependencies 22 | pyaudiolatest: pyaudio 23 | pyaudio0.2: pyaudio>=0.2.13,<0.3.0 24 | # test dependencies 25 | pytest 26 | pytest-httpx 27 | pytest-xdist 28 | pytest-mock 29 | pytest-cov 30 | factory-boy 31 | allowlist_externals = pytest 32 | 33 | commands = pytest -n auto --cov-report term --cov-report xml:coverage.xml --cov=assemblyai 34 | --------------------------------------------------------------------------------