├── .gitignore ├── requirements.txt ├── .env.sample ├── text_processor.py ├── LICENSE ├── README.md ├── fastapi_app.py ├── audio_decoder.py ├── audio_analysis.py └── audio_generator.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | outputs 3 | debug_audio_errors 4 | debug_audio_success 5 | .env -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.6.1 2 | aiohttp==3.13.2 3 | aiosignal==1.4.0 4 | annotated-doc==0.0.3 5 | annotated-types==0.7.0 6 | anyio==4.11.0 7 | astor==0.8.1 8 | attrs==25.4.0 9 | blake3==1.0.8 10 | cachetools==6.2.1 11 | cbor2==5.7.1 12 | certifi==2025.10.5 13 | cffi==2.0.0 14 | charset-normalizer==3.4.4 15 | click==8.2.1 16 | cloudpickle==3.1.2 17 | compressed-tensors==0.11.0 18 | cupy-cuda12x==13.6.0 19 | depyf==0.19.0 20 | dill==0.4.0 21 | diskcache==5.6.3 22 | distro==1.9.0 23 | dnspython==2.8.0 24 | einops==0.8.1 25 | email-validator==2.3.0 26 | fastapi==0.121.1 27 | fastapi-cli==0.0.14 28 | fastapi-cloud-cli==0.3.1 29 | fastrlock==0.8.3 30 | filelock==3.20.0 31 | frozendict==2.4.6 32 | frozenlist==1.8.0 33 | fsspec==2025.10.0 34 | gguf==0.17.1 35 | h11==0.16.0 36 | hf-transfer==0.1.9 37 | hf-xet==1.2.0 38 | httpcore==1.0.9 39 | httptools==0.7.1 40 | httpx==0.28.1 41 | huggingface-hub==0.36.0 42 | idna==3.11 43 | interegular==0.3.3 44 | jinja2==3.1.6 45 | jiter==0.12.0 46 | jsonschema==4.25.1 47 | jsonschema-specifications==2025.9.1 48 | lark==1.2.2 49 | llguidance==0.7.30 50 | llvmlite==0.44.0 51 | lm-format-enforcer==0.11.3 52 | markdown-it-py==4.0.0 53 | markupsafe==3.0.3 54 | mdurl==0.1.2 55 | mistral-common==1.8.5 56 | mpmath==1.3.0 57 | msgpack==1.1.2 58 | msgspec==0.19.0 59 | multidict==6.7.0 60 | networkx==3.5 61 | ninja==1.13.0 62 | numba==0.61.2 63 | numpy==2.2.6 64 | nvidia-cublas-cu12==12.8.4.1 65 | nvidia-cuda-cupti-cu12==12.8.90 66 | nvidia-cuda-nvrtc-cu12==12.8.93 67 | nvidia-cuda-runtime-cu12==12.8.90 68 | nvidia-cudnn-cu12==9.10.2.21 69 | nvidia-cufft-cu12==11.3.3.83 70 | nvidia-cufile-cu12==1.13.1.3 71 | nvidia-curand-cu12==10.3.9.90 72 | nvidia-cusolver-cu12==11.7.3.90 73 | nvidia-cusparse-cu12==12.5.8.93 74 | nvidia-cusparselt-cu12==0.7.1 75 | nvidia-nccl-cu12==2.27.3 76 | nvidia-nvjitlink-cu12==12.8.93 77 | nvidia-nvtx-cu12==12.8.90 78 | openai==2.7.1 79 | openai-harmony==0.0.8 80 | opencv-python-headless==4.12.0.88 81 | orpheus-speech==0.1.0 82 | outlines-core==0.2.11 83 | packaging==25.0 84 | partial-json-parser==0.2.1.1.post6 85 | pillow==12.0.0 86 | prometheus-client==0.23.1 87 | prometheus-fastapi-instrumentator==7.1.0 88 | propcache==0.4.1 89 | protobuf==6.33.0 90 | psutil==7.1.3 91 | py-cpuinfo==9.0.0 92 | pybase64==1.4.2 93 | pycountry==24.6.1 94 | pycparser==2.23 95 | pydantic==2.12.4 96 | pydantic-core==2.41.5 97 | pydantic-extra-types==2.10.6 98 | pygments==2.19.2 99 | python-dotenv==1.2.1 100 | python-json-logger==4.0.0 101 | python-multipart==0.0.20 102 | pyyaml==6.0.3 103 | pyzmq==27.1.0 104 | ray==2.51.1 105 | referencing==0.37.0 106 | regex==2025.11.3 107 | requests==2.32.5 108 | rich==14.2.0 109 | rich-toolkit==0.15.1 110 | rignore==0.7.6 111 | rpds-py==0.28.0 112 | safetensors==0.6.2 113 | scipy==1.16.3 114 | sentencepiece==0.2.1 115 | sentry-sdk==2.43.0 116 | setproctitle==1.3.7 117 | setuptools==79.0.1 118 | shellingham==1.5.4 119 | six==1.17.0 120 | snac==1.2.1 121 | sniffio==1.3.1 122 | soundfile==0.13.1 123 | 
soxr==1.0.0
124 | starlette==0.49.3
125 | sympy==1.14.0
126 | tiktoken==0.12.0
127 | tokenizers==0.22.1
128 | torch==2.8.0
129 | torchaudio==2.8.0
130 | torchvision==0.23.0
131 | tqdm==4.67.1
132 | transformers==4.57.1
133 | triton==3.4.0
134 | typer==0.20.0
135 | typing-extensions==4.15.0
136 | typing-inspection==0.4.2
137 | urllib3==2.5.0
138 | uvicorn==0.38.0
139 | uvloop==0.22.1
140 | vllm==0.10.2
141 | watchfiles==1.1.1
142 | websockets==15.0.1
143 | xformers==0.0.32.post1
144 | xgrammar==0.1.23
145 | yarl==1.22.0
146 | 
--------------------------------------------------------------------------------
/.env.sample:
--------------------------------------------------------------------------------
1 | # ===============================
2 | # Orpheus TTS FastAPI Configuration
3 | # ===============================
4 | # Copy this file to .env and modify the values as needed
5 | # Note: values shown here are recommended starting points; a few (GPU memory utilization, temperature, repetition penalty, max parallel sequences) differ from the built-in fallbacks in fastapi_app.py, which apply when a variable is unset
6 | 
7 | # ===============================
8 | # MODEL CONFIGURATION
9 | # ===============================
10 | 
11 | # Model name from HuggingFace or local path
12 | TTS_MODEL_NAME="canopylabs/orpheus-tts-0.1-finetune-prod"
13 | 
14 | # Data type for model weights and computation
15 | # Options: bfloat16 (default), float16, float32, bf16, fp16, fp32
16 | # bfloat16: Best balance of quality and VRAM usage (recommended)
17 | # float16: Slightly lower VRAM, may have numerical issues
18 | # float32: Highest quality but uses more VRAM
19 | TTS_DTYPE="bfloat16"
20 | 
21 | # Maximum sequence length (prompt + generated tokens)
22 | # Larger values use more VRAM but allow longer text processing
23 | # Recommended: 6144-8192 for most use cases
24 | TTS_MAX_MODEL_LEN="8192"
25 | 
26 | # Number of GPUs to use for tensor parallelism
27 | # Set to 1 for single GPU setups
28 | TTS_TENSOR_PARALLEL_SIZE="1"
29 | 
30 | # Fraction of GPU memory to use (0.0 to 1.0)
31 | # Recommendations by GPU:
32 | # - RTX 3090/4090 (24GB): 0.9
33 | # - RTX A6000 (48GB): 0.5-0.6
34 | # - H100 (80GB): 0.3-0.4
35 | TTS_GPU_MEMORY_UTILIZATION="0.95"
36 | 
37 | # ===============================
38 | # PERFORMANCE CONFIGURATION
39 | # ===============================
40 | 
41 | # Thread pool size for file I/O
42 | # Recommended: 4-16 based on CPU cores
43 | TTS_MAX_WORKERS="16"
44 | 
45 | # ===============================
46 | # GENERATION PARAMETERS (DEFAULTS)
47 | # ===============================
48 | # These are used as defaults when not specified in API requests
49 | # Individual requests can override these values
50 | 
51 | # Temperature for sampling (0.0 to 2.0)
52 | # Lower values (0.1-0.3) = more deterministic/consistent
53 | # Higher values (0.5-1.0) = more creative/varied
54 | # Default 0.3 works well for TTS
55 | TTS_TEMPERATURE="0.3"
56 | 
57 | # Top-p nucleus sampling (0.0 to 1.0)
58 | # Lower values = more focused/predictable
59 | # Higher values = more diverse outputs
60 | # Default 0.9 provides good quality
61 | TTS_TOP_P="0.9"
62 | 
63 | # Repetition penalty (0.5 to 2.0)
64 | # Values > 1.0 discourage repetition
65 | # Values < 1.0 encourage repetition
66 | # Default 1.3 prevents excessive repetition
67 | TTS_REPETITION_PENALTY="1.3"
68 | 
69 | # Maximum tokens to generate per request
70 | # Higher values allow longer audio but use more VRAM/time
71 | # Must be less than TTS_MAX_MODEL_LEN minus prompt length
72 | TTS_MAX_TOKENS="4096"
73 | 
74 | # ===============================
75 | # LOGGING AND DEBUG
76 | # ===============================
77 | 
78 | # Logging level: 
DEBUG, INFO, WARNING, ERROR 79 | # DEBUG: Verbose logging including sampling parameters 80 | # INFO: Standard operational logging (recommended) 81 | # WARNING: Only warnings and errors 82 | # ERROR: Only errors 83 | LOG_LEVEL="INFO" 84 | 85 | # =============================== 86 | # ADDITIONAL vLLM/SYSTEM SETTINGS 87 | # =============================== 88 | 89 | # Force vLLM V0 engine for stability (recommended) 90 | VLLM_USE_V1="0" 91 | 92 | # SNAC decoder device (cuda/ cpu/ mps) 93 | # Use cuda if you have sufficient VRAM, cpu otherwise 94 | SNAC_DEVICE="cuda" 95 | 96 | # Specify which GPU to use (for multi-GPU systems) 97 | # CUDA_VISIBLE_DEVICES="0" 98 | 99 | # Specify the max number of parallel requests you'll be making to the TTS model 100 | TTS_MAX_NUM_SEQS="64" 101 | 102 | # =============================== 103 | # USAGE EXAMPLES 104 | # =============================== 105 | 106 | # High Quality Setup (requires 24GB+ VRAM): 107 | # TTS_DTYPE="float32" 108 | # TTS_GPU_MEMORY_UTILIZATION="0.95" 109 | # TTS_MAX_MODEL_LEN="8192" 110 | 111 | # Memory Efficient Setup (16-20GB VRAM): 112 | # TTS_DTYPE="bfloat16" 113 | # TTS_GPU_MEMORY_UTILIZATION="0.8" 114 | # TTS_MAX_MODEL_LEN="6144" 115 | 116 | # Creative/Varied Output: 117 | # TTS_TEMPERATURE="0.5" 118 | # TTS_TOP_P="0.95" 119 | # TTS_REPETITION_PENALTY="1.2" 120 | 121 | # Deterministic/Consistent Output: 122 | # TTS_TEMPERATURE="0.1" 123 | # TTS_TOP_P="0.8" 124 | # TTS_REPETITION_PENALTY="1.0" 125 | -------------------------------------------------------------------------------- /text_processor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | import os 4 | from dotenv import load_dotenv 5 | 6 | load_dotenv() 7 | 8 | # Logging Configuration 9 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 10 | 11 | # Setup logging 12 | logging.basicConfig(level=getattr(logging, LOG_LEVEL)) 13 | logger = logging.getLogger(__name__) 14 | 15 | def split_text_into_sentences(text): 16 | """ 17 | Split text into sentences while preserving dialogue integrity and narrative flow. 
18 | 19 | This function handles: 20 | - Dialogue preservation (keeping quotes together) 21 | - Dialogue attribution (keeping "he said" with dialogue) 22 | - Paragraph boundaries 23 | - Complex punctuation within dialogue 24 | - Reasonable chunk sizes for TTS processing 25 | """ 26 | # First, split by paragraphs to maintain document structure 27 | paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()] 28 | 29 | all_segments = [] 30 | 31 | for paragraph in paragraphs: 32 | segments = _split_paragraph_intelligently(paragraph) 33 | all_segments.extend(segments) 34 | 35 | # Post-process to ensure reasonable sizes and combine very short segments 36 | return _combine_short_segments(all_segments) 37 | 38 | def _split_paragraph_intelligently(paragraph): 39 | """Split a single paragraph while preserving dialogue and narrative flow.""" 40 | # If paragraph is short enough, return as-is 41 | if len(paragraph) <= 200: 42 | return [paragraph] 43 | 44 | # Find all dialogue blocks (text within quotes) 45 | dialogue_pattern = r'"([^"]*)"' 46 | dialogues = list(re.finditer(dialogue_pattern, paragraph)) 47 | 48 | segments = [] 49 | last_end = 0 50 | 51 | for dialogue_match in dialogues: 52 | # Get text before this dialogue 53 | before_dialogue = paragraph[last_end:dialogue_match.start()].strip() 54 | 55 | # Get the dialogue with quotes 56 | dialogue_text = dialogue_match.group(0) 57 | dialogue_start = dialogue_match.start() 58 | dialogue_end = dialogue_match.end() 59 | 60 | # Look for attribution after the dialogue 61 | # Common attribution patterns: "he said", "she cried", "Jack whispered", etc. 62 | attribution_pattern = r'^\s*([A-Z][a-zA-Z]*\.?\s+[a-zA-Z]+(?:\s+[a-zA-Z]+)*(?:\s+(?:said|cried|whispered|shouted|asked|replied|continued|added|muttered|declared|announced|exclaimed|sobbed|laughed|sighed|nodded|shook|clutched|tugged|looked|turned|moved|went|came|walked|ran|stood|sat|knelt|rose|smiled|frowned|gasped|breathed|swallowed|attempted|tried|began|finished)(?:\s+[a-zA-Z]+)*)?)' 63 | 64 | # Look ahead for attribution (up to 150 characters) 65 | text_after_dialogue = paragraph[dialogue_end:dialogue_end + 150] 66 | attribution_match = re.match(attribution_pattern, text_after_dialogue) 67 | 68 | # Look for attribution before the dialogue (within last 100 characters of before_dialogue) 69 | before_attribution_pattern = r'([A-Z][a-zA-Z]*\.?\s+[a-zA-Z]+(?:\s+[a-zA-Z]+)*)\s*$' 70 | before_attribution_match = None 71 | if before_dialogue: 72 | before_attribution_match = re.search(before_attribution_pattern, before_dialogue[-100:]) 73 | 74 | # Process text before dialogue 75 | if before_dialogue: 76 | if before_attribution_match: 77 | # Split the text, keeping attribution with dialogue 78 | attribution_start_in_before = before_dialogue.rfind(before_attribution_match.group(1)) 79 | pre_attribution_text = before_dialogue[:attribution_start_in_before].strip() 80 | 81 | if pre_attribution_text: 82 | # Split the pre-attribution text if it's long 83 | pre_segments = _split_long_text(pre_attribution_text) 84 | segments.extend(pre_segments) 85 | else: 86 | # No attribution before, split the before_dialogue text normally 87 | before_segments = _split_long_text(before_dialogue) 88 | segments.extend(before_segments) 89 | 90 | # Create the dialogue segment 91 | dialogue_segment = "" 92 | 93 | # Add preceding attribution if found 94 | if before_attribution_match: 95 | dialogue_segment += before_attribution_match.group(1) + " " 96 | 97 | # Add the dialogue 98 | dialogue_segment += dialogue_text 99 | 100 | # Add 
following attribution if found
101 |         if attribution_match:
102 |             dialogue_segment += " " + attribution_match.group(1).strip()
103 |             last_end = dialogue_end + attribution_match.end()
104 |         else:
105 |             last_end = dialogue_end
106 | 
107 |         segments.append(dialogue_segment.strip())
108 | 
109 |     # Handle any remaining text after the last dialogue
110 |     remaining_text = paragraph[last_end:].strip()
111 |     if remaining_text:
112 |         remaining_segments = _split_long_text(remaining_text)
113 |         segments.extend(remaining_segments)
114 | 
115 |     # If no dialogues were found, just split the paragraph normally
116 |     if not dialogues:
117 |         return _split_long_text(paragraph)
118 | 
119 |     return segments
120 | 
121 | def _split_long_text(text, max_length=400):
122 |     """Split long text on sentence boundaries, preserving meaning."""
123 |     if len(text) <= max_length:
124 |         return [text.strip()] if text.strip() else []
125 | 
126 |     # Split on sentence endings, but be careful with abbreviations:
127 |     # only split where the next sentence starts with a capital or a quote
128 |     sentence_pattern = r'(?<=[.!?])\s+(?=[A-Z"\'])'
129 |     sentences = re.split(sentence_pattern, text)
130 | 
131 |     chunks = []
132 |     current_chunk = ""
133 | 
134 |     for sentence in sentences:
135 |         sentence = sentence.strip()
136 |         if not sentence:
137 |             continue
138 | 
139 |         # If adding this sentence would exceed max_length, start a new chunk
140 |         if current_chunk and len(current_chunk) + len(sentence) + 1 > max_length:
141 |             chunks.append(current_chunk)
142 |             current_chunk = sentence
143 |         else:
144 |             if current_chunk:
145 |                 current_chunk += " " + sentence
146 |             else:
147 |                 current_chunk = sentence
148 | 
149 |     # Add the final chunk
150 |     if current_chunk:
151 |         chunks.append(current_chunk)
152 | 
153 |     return chunks
154 | 
155 | def _combine_short_segments(segments, min_length=40, max_length=500):
156 |     """Combine very short segments while keeping segments under max_length."""
157 |     if not segments:
158 |         return []
159 | 
160 |     combined = []
161 |     current_segment = ""
162 | 
163 |     for segment in segments:
164 |         segment = segment.strip()
165 |         if not segment:
166 |             continue
167 | 
168 |         # If adding this segment would exceed max_length, finalize current segment
169 |         if current_segment and len(current_segment) + len(segment) + 1 > max_length:
170 |             combined.append(current_segment)
171 |             current_segment = segment
172 |         # If current segment is too short, try to combine with next
173 |         elif len(current_segment) < min_length and len(current_segment) + len(segment) + 1 <= max_length:
174 |             if current_segment:
175 |                 current_segment += " " + segment
176 |             else:
177 |                 current_segment = segment
178 |         # If the current segment is good length, add it and start new
179 |         else:
180 |             if current_segment:
181 |                 combined.append(current_segment)
182 |             current_segment = segment
183 | 
184 |     # Add the last segment
185 |     if current_segment:
186 |         combined.append(current_segment)
187 | 
188 |     return combined
189 | 
190 | def create_batches(sentences, max_batch_chars=500):
191 |     """Create batches by combining sentences up to max_batch_chars"""
192 |     batches = []
193 |     current_batch = ""
194 | 
195 |     for sentence in sentences:
196 |         # If adding this sentence would exceed the batch size, start a new batch
197 |         if len(current_batch) + len(sentence) > max_batch_chars and current_batch:
198 |             batches.append(current_batch)
199 |             current_batch = sentence
200 |         else:
201 |             # Add separator space if needed
202 |             if current_batch:
203 |                 current_batch += " "
204 |             current_batch += sentence
205 | 
206 |     # Add the last batch if it's not empty
207 |     if current_batch:
208 |         batches.append(current_batch)
209 | 
210 |     # logger.info(f"Created {len(batches)} batches from {len(sentences)} sentences")
211 |     return batches
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Orpheus TTS FastAPI Server (Async)
2 | 
3 | A high-performance, production-ready FastAPI server that provides OpenAI-compatible Text-to-Speech (TTS) endpoints using the [Orpheus TTS](https://github.com/canopyai/Orpheus-TTS) model, with advanced error detection and async parallel processing. This project uses the original `orpheus-speech` Python package with a vLLM backend, loading the model in bfloat16 by default (with float16/float32 options). Higher-precision formats require more VRAM but eliminate the audio quality issues and artifacts commonly found in quantized models or alternative inference engines.
4 | 
5 | The server features multi-stage error detection, adaptive retry logic, and comprehensive debugging tools to ensure reliable, high-quality audio generation. It supports async parallel chunk processing with intelligent text chunking that preserves dialogue and narrative flow.
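
As a quick illustration of that chunking behaviour, the sketch below drives the two helpers from `text_processor.py` directly. The sample text and printout are illustrative only; `max_batch_chars=500` mirrors the threshold `fastapi_app.py` uses before switching to chunked processing.

```python
# Standalone sketch: run the server's chunking pipeline outside the API.
from text_processor import split_text_into_sentences, create_batches

long_text = (
    '"Hold on," Mira whispered. "Something is moving out there."\n\n'
    "The corridor stretched ahead of them, dark and silent. Nobody moved. "
    "After a long moment she stepped forward and pushed open the door."
)

segments = split_text_into_sentences(long_text)          # dialogue-aware segments
batches = create_batches(segments, max_batch_chars=500)  # batches sent to the model

for i, batch in enumerate(batches):
    print(f"batch {i}: {len(batch)} chars -> {batch[:60]!r}")
```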
6 | 7 | ## 🚀 Features 8 | 9 | - **OpenAI-compatible API**: Drop-in replacement for OpenAI's TTS API 10 | - **Async Parallel Processing**: Process multiple text chunks simultaneously for faster generation 11 | - **Direct vLLM Integration**: Uses vLLM's AsyncLLMEngine for optimal performance 12 | - **Intelligent Text Chunking**: Preserves dialogue, attribution, and narrative flow 13 | - **Complete Audio Files**: Returns complete WAV files optimized for quality 14 | - **Advanced Error Detection**: Multi-stage analysis prevents audio artifacts and quality issues 15 | - **Adaptive Retry Logic**: Up to 5 automatic retries with parameter adjustment 16 | - **Token Repetition Detection**: Prevents infinite audio loops with pattern analysis 17 | - **Audio Quality Analysis**: Detects silence, repetition, stretching, and monotonic audio 18 | - **Duration Outlier Detection**: Identifies abnormally long audio generation 19 | - **Debug & Success Logging**: Optional comprehensive logging for troubleshooting and tuning 20 | - **SNAC Audio Decoding**: High-quality audio reconstruction from tokens 21 | 22 | ## 🔧 Architecture 23 | 24 | ### Modular Design 25 | The server follows a modular architecture with specialized components: 26 | 27 | - **`fastapi_app.py`**: Main FastAPI application with request handling, CORS, and timeout middleware 28 | - **`audio_generator.py`**: Advanced audio generation with multi-stage error detection and retry logic 29 | - **`audio_decoder.py`**: SNAC-based audio decoding with custom exception handling 30 | - **`text_processor.py`**: Intelligent text chunking that preserves dialogue and narrative flow 31 | - **`audio_analysis.py`**: Comprehensive audio quality analysis using spectrograms and cross-correlation 32 | 33 | ### Async Processing Pipeline 34 | ``` 35 | ┌─────────────────────────────────────────────────┐ 36 | │ FastAPI Request Handler (Async) │ 37 | ├─────────────────────────────────────────────────┤ 38 | │ Text Processing & Chunking (Intelligent) │ 39 | ├─────────────────────────────────────────────────┤ 40 | │ Parallel Token Generation (vLLM AsyncEngine) │ 41 | ├─────────────────────────────────────────────────┤ 42 | │ Multi-Stage Error Detection: │ 43 | │ • Token Repetition Detection │ 44 | │ • Token Count Ratio Analysis │ 45 | │ • Audio Duration Outlier Detection │ 46 | │ • Audio Quality Analysis (Silence/Stretch/Mono) │ 47 | ├─────────────────────────────────────────────────┤ 48 | │ Adaptive Retry Logic (5 attempts max) │ 49 | ├─────────────────────────────────────────────────┤ 50 | │ SNAC Audio Decoding │ 51 | ├─────────────────────────────────────────────────┤ 52 | │ Audio File Generation (Async I/O) │ 53 | ├─────────────────────────────────────────────────┤ 54 | │ Debug/Success Logging (Optional) │ 55 | └─────────────────────────────────────────────────┘ 56 | ``` 57 | 58 | ### Key Improvements Over Sync Version 59 | - **4x faster** for long texts (parallel chunk processing) 60 | - **Non-blocking operations** throughout the pipeline 61 | - **Better resource utilization** with optimized thread pools 62 | - **GPU memory monitoring** and automatic optimization 63 | - **Advanced error detection** with multi-stage analysis to prevent audio artifacts 64 | - **Intelligent retry logic** with adaptive parameter adjustment 65 | - **Comprehensive debugging** with audio quality analysis and metadata logging 66 | 67 | ## 📋 API Endpoints 68 | 69 | ### Core Endpoints 70 | 71 | | Endpoint | Method | Description | 72 | |----------|---------|-------------| 73 | | `/v1/audio/speech` | 
POST | Generate speech from text (OpenAI-compatible) | 74 | | `/health` | GET | Health check and model status | 75 | | `/` | GET | API information and available endpoints | 76 | 77 | ### Interactive Documentation 78 | 79 | - **Swagger UI**: `http://localhost:8880/docs` 80 | 81 | ## 🛠️ Installation 82 | 83 | ### Prerequisites 84 | 85 | - CUDA-capable GPU (minimum 16GB VRAM recommended) 86 | - Sufficient disk space for model downloads 87 | 88 | ### Install Dependencies 89 | 90 | 1. Install uv 91 | ```bash 92 | curl -LsSf https://astral.sh/uv/install.sh | sh 93 | ``` 94 | 2. Create a virtual environment with Python 3.12: 95 | ```bash 96 | uv venv --python 3.12 97 | ``` 98 | 3. Activate the virtual environment: 99 | ```bash 100 | source .venv/bin/activate 101 | ``` 102 | 4. Install dependencies 103 | ```bash 104 | # Install the required packages 105 | uv pip install -r requirements.txt 106 | ``` 107 | 5. Get access to the Orpheus model from [here](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft) 108 | 6. Login to Huggingface using the CLI after gaining access to the repo 109 | ```bash 110 | huggingface-cli login 111 | ``` 112 | 113 | ## 🔧 Configuration 114 | 115 | ### Environment Variables 116 | 117 | The server supports extensive configuration through environment variables. Copy `.env.sample` to `.env` and modify as needed: 118 | 119 | ```bash 120 | cp .env.sample .env 121 | # Edit .env with your preferred settings 122 | ``` 123 | 124 | ## 🎯 Usage 125 | 126 | ### Starting the Server 127 | 128 | ```bash 129 | uvicorn fastapi_app:app --host 0.0.0.0 --port 8880 130 | ``` 131 | 132 | The server will start on `http://localhost:8880` by default. 133 | 134 | ### Making Requests 135 | 136 | #### Using cURL 137 | 138 | ```bash 139 | # Generate speech 140 | curl -X POST "http://localhost:8880/v1/audio/speech" \ 141 | -H "Content-Type: application/json" \ 142 | -d '{ 143 | "model": "orpheus", 144 | "input": "Hello, this is a test of the Orpheus TTS API!", 145 | "voice": "tara", 146 | "response_format": "wav", 147 | "speed": 1.0 148 | }' \ 149 | --output speech.wav 150 | ``` 151 | 152 | #### Using Python with requests 153 | 154 | ```python 155 | import requests 156 | 157 | response = requests.post( 158 | "http://localhost:8880/v1/audio/speech", 159 | json={ 160 | "model": "orpheus", 161 | "input": "Hello, this is a test!", 162 | "voice": "tara", 163 | "response_format": "wav", 164 | "speed": 1.0 165 | } 166 | ) 167 | 168 | with open("output.wav", "wb") as f: 169 | f.write(response.content) 170 | ``` 171 | 172 | #### Using OpenAI Client Library 173 | 174 | ```python 175 | from openai import OpenAI 176 | 177 | # Configure client to use local server 178 | client = OpenAI( 179 | api_key="dummy-key", # Not validated by local server 180 | base_url="http://localhost:8880" 181 | ) 182 | 183 | response = client.audio.speech.create( 184 | model="orpheus", 185 | voice="tara", 186 | input="Hello, this is a test!" 
187 | ) 188 | 189 | with open("speech.wav", "wb") as f: 190 | f.write(response.content) 191 | ``` 192 | 193 | ## 🎤 Voice Options 194 | 195 | The API supports native Orpheus voice names: 196 | 197 | ### Available Voices 198 | - `tara` (default) - neutral/balanced female 199 | - `leah` - warm female voice 200 | - `jess` - expressive female 201 | - `leo` - deep male voice 202 | - `dan` - male voice 203 | - `mia` - young female voice 204 | - `zac` - male voice 205 | - `zoe` - female voice 206 | 207 | ## 🔧 API Reference 208 | 209 | ### POST /v1/audio/speech 210 | 211 | Generate speech from text with automatic parallel processing and advanced error detection for long texts. 212 | 213 | **Request Body:** 214 | ```json 215 | { 216 | "model": "orpheus", // Model to use 217 | "input": "Text to speak", // Text to synthesize (auto-chunked if long) 218 | "voice": "tara", // Voice name (see voice options above) 219 | "response_format": "wav", // Audio format (currently only "wav") 220 | "speed": 1.0, // Speech speed (0.25 to 4.0) 221 | 222 | // Optional: Override environment defaults for sampling 223 | "temperature": 0.2, // Sampling temperature (0.0-2.0, optional) 224 | "top_p": 0.9, // Top-p nucleus sampling (0.0-1.0, optional) 225 | "repetition_penalty": 1.1, // Repetition penalty (0.5-2.0, optional) 226 | "max_tokens": 4096 // Maximum tokens to generate (100-8192, optional) 227 | } 228 | ``` 229 | 230 | **Features:** 231 | - **Automatic Text Chunking**: Long texts (>500 chars) are intelligently split preserving dialogue and narrative flow 232 | - **Parallel Processing**: Multiple chunks processed simultaneously with individual retry logic 233 | - **Multi-Stage Error Detection**: Advanced analysis prevents audio artifacts: 234 | - Token repetition detection (prevents infinite loops) 235 | - Token count ratio analysis (catches outlier generation) 236 | - Audio duration outlier detection (identifies abnormally slow generation) 237 | - Audio quality analysis (detects silence, repetition, stretching, monotonic audio) 238 | - **Adaptive Retry Logic**: Up to 5 automatic retries with parameter adjustment for failed chunks 239 | - **Seamless Audio**: Token-level combination creates single WAV file with natural flow 240 | - **Debug Logging**: Optional comprehensive logging of generation process and error analysis 241 | - **Configurable Parameters**: Override defaults per request with environment fallbacks 242 | 243 | **Response:** 244 | - Content-Type: `audio/wav` 245 | - Body: Binary audio data 246 | 247 | **Error Handling:** 248 | The API implements sophisticated error detection and recovery: 249 | 250 | - **Token-Level Errors**: Repetition patterns, count outliers, invalid ranges 251 | - **Audio-Level Errors**: Duration outliers, quality issues (silence, repetition, stretching) 252 | - **Adaptive Retries**: Automatic parameter adjustment (temperature, repetition_penalty) on retry attempts 253 | - **Graceful Degradation**: Failed chunks are logged and skipped rather than failing entire request 254 | - **Debug Output**: Failed generations saved to `debug_audio_errors/` for analysis 255 | 256 | ### GET /health 257 | 258 | Health check endpoint with detailed system information. 259 | 260 | **Response:** 261 | ```json 262 | { 263 | "status": "healthy", 264 | "model_loaded": true, 265 | "timestamp": 1699896489.123 266 | } 267 | ``` 268 | 269 | ## 📄 License 270 | 271 | This project follows the same license as the original Orpheus TTS repository. 
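
## 🧪 Example: Per-Request Parameter Overrides

The optional `temperature`, `top_p`, `repetition_penalty`, and `max_tokens` fields from the API reference above can be set on individual requests; anything omitted falls back to the `TTS_*` environment defaults. A minimal sketch with `requests` (the specific values are illustrative, not tuned recommendations):

```python
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "orpheus",
        "input": "Let me try that again, a little more carefully this time.",
        "voice": "leo",
        "response_format": "wav",
        "temperature": 0.2,         # more deterministic than the sample default of 0.3
        "top_p": 0.8,               # more focused sampling
        "repetition_penalty": 1.3,  # discourages token loops
        "max_tokens": 2048,         # caps generation length
    },
    timeout=600,  # the server allows TTS requests up to 10 minutes
)
response.raise_for_status()

with open("speech_override.wav", "wb") as f:
    f.write(response.content)
```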
272 | 
273 | ## 🔗 Links
274 | 
275 | - [Original Orpheus TTS Repository](https://github.com/canopyai/Orpheus-TTS)
276 | - [OpenAI TTS API Documentation](https://platform.openai.com/docs/guides/text-to-speech)
277 | - [FastAPI Documentation](https://fastapi.tiangolo.com/)
278 | - [vLLM Documentation](https://docs.vllm.ai/)
279 | 
280 | ## 🔍 Troubleshooting
281 | 
282 | ### Common Issues
283 | 
284 | 1. **Model Loading Fails**
285 |    - Ensure sufficient GPU memory (min 14GB)
286 |    - Check CUDA installation and compatibility
287 |    - Verify internet connection for model downloads
288 |    - Adjust `TTS_GPU_MEMORY_UTILIZATION` if OOM errors occur
289 | 
290 | 2. **Out of Memory Errors**
291 |    - Reduce `TTS_GPU_MEMORY_UTILIZATION` (try 0.8 or 0.7)
292 |    - Decrease `TTS_MAX_MODEL_LEN` (try 6144 or 4096)
293 |    - Switch to `TTS_DTYPE="float16"` for lower VRAM usage
294 | 
295 | 3. **Token Repetition Errors**
296 |    - Check debug logs in `debug_audio_errors/` for pattern analysis
297 |    - Increase `TTS_REPETITION_PENALTY` (try 1.2-1.5)
298 |    - Adjust `TTS_TEMPERATURE` (try 0.3-0.5 for more variety)
299 |    - Reduce `TTS_MAX_TOKENS` if hitting generation limits
300 |    - Enable debug logging to analyze token patterns
301 | 
302 | 4. **Audio Duration Outliers**
303 |    - Check `debug_audio_errors/` for duration analysis
304 |    - Lower `TTS_TEMPERATURE` (try 0.1-0.2) for more consistent output
305 |    - Increase `TTS_REPETITION_PENALTY` (try 1.3+) to prevent loops
306 |    - Adjust `TTS_TOP_P` (try 0.7-0.8) for more focused generation
307 |    - Review text for emotion tags that may extend duration
308 | 
309 | 5. **Audio Quality Issues (Silence/Stretching/Repetition)**
310 |    - Enable `ENABLE_DEBUG_SAVING=true` to save problematic audio
311 |    - Check `debug_audio_errors/` for quality metrics
312 |    - Adjust `TTS_TEMPERATURE` and `TTS_REPETITION_PENALTY`
313 |    - Use audio analysis tool: `python audio_analysis.py <audio_file>`
314 |    - Consider text preprocessing for emotion tags
315 | 
316 | 6. **Slow Performance**
317 |    - Increase `TTS_MAX_WORKERS` (try 16+)
318 |    - Ensure GPU is being used (`nvidia-smi`)
319 |    - Check if text is being chunked properly
320 |    - Verify `TTS_GPU_MEMORY_UTILIZATION` isn't too low
321 |    - Monitor chunk processing times in logs
322 | 
323 | 7. **Engine Shutdown Errors**
324 |    - Ensure using V0 engine: `VLLM_USE_V1="0"`
325 |    - Don't overload with too many parallel requests
326 |    - Monitor GPU memory usage
327 |    - Check `TTS_TENSOR_PARALLEL_SIZE` matches available GPUs
328 | 
329 | 8. **Poor Audio Quality**
330 |    - Try `TTS_DTYPE="float32"` for highest quality
331 |    - Adjust `TTS_TEMPERATURE` (lower = more consistent)
332 |    - Tune `TTS_REPETITION_PENALTY` (higher = less repetitive)
333 |    - Check `TTS_TOP_P` settings (0.8-0.95 range)
334 |    - Enable success logging to analyze good generations
335 | 
336 | ### Debug Tools
337 | 
338 | The server includes comprehensive debugging tools:
339 | 
340 | - **Debug Audio Saving**: Failed generations saved to `debug_audio_errors/` with metadata
341 | - **Success Logging**: Successful generations saved to `debug_audio_success/` for tuning
342 | - **Audio Analysis**: Use `python audio_analysis.py <audio_file>` for quality analysis
343 | - **Token Inspection**: Check `debug_audio_success/` for reproducibility data
344 | - **Metadata Logging**: Full generation statistics and error analysis
345 | 
346 | Enable debugging with environment variables:
347 | ```bash
348 | ENABLE_DEBUG_SAVING=true
349 | ENABLE_SUCCESS_LOGGING=true
350 | LOG_LEVEL=DEBUG
351 | ```
352 | 
--------------------------------------------------------------------------------
/fastapi_app.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Literal
2 | import time
3 | import logging
4 | import os
5 | import uuid
6 | import asyncio
7 | from contextlib import asynccontextmanager
8 | from concurrent.futures import ThreadPoolExecutor
9 | 
10 | from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
11 | from fastapi.responses import FileResponse
12 | from fastapi.middleware.cors import CORSMiddleware
13 | from pydantic import BaseModel, Field
14 | import torch
15 | import uvicorn
16 | 
17 | # Force V0 engine to avoid async issues
18 | os.environ["VLLM_USE_V1"] = "0"
19 | 
20 | from dotenv import load_dotenv
21 | 
22 | # Import our new modules
23 | from audio_decoder import initialize_snac_model, shutdown_snac_model
24 | from audio_generator import (
25 |     OrpheusModelExtended,
26 |     generate_speech_tokens_direct,
27 |     generate_speech_chunks,
28 |     tokens_to_audio_file
29 | )
30 | from text_processor import split_text_into_sentences, create_batches
31 | 
32 | load_dotenv()
33 | 
34 | def get_dtype_from_string(dtype_str: str) -> torch.dtype:
35 |     """Convert string to torch dtype"""
36 |     dtype_map = {
37 |         "bfloat16": torch.bfloat16,
38 |         "float16": torch.float16,
39 |         "float32": torch.float32,
40 |         "bf16": torch.bfloat16,  # alias
41 |         "fp16": torch.float16,  # alias
42 |         "fp32": torch.float32,  # alias
43 |     }
44 |     dtype_str = dtype_str.lower()
45 |     if dtype_str not in dtype_map:
46 |         logging.getLogger(__name__).warning(f"Unknown dtype '{dtype_str}', defaulting to bfloat16")  # module-level logger isn't configured yet at import time
47 |         return torch.bfloat16
48 |     return dtype_map[dtype_str]
49 | 
50 | # Model Configuration
51 | MODEL_NAME = os.getenv("TTS_MODEL_NAME", "canopylabs/orpheus-tts-0.1-finetune-prod")
52 | DTYPE = get_dtype_from_string(os.getenv("TTS_DTYPE", "bfloat16"))
53 | MAX_MODEL_LEN = int(os.getenv("TTS_MAX_MODEL_LEN", "8192"))
54 | TENSOR_PARALLEL_SIZE = int(os.getenv("TTS_TENSOR_PARALLEL_SIZE", "1"))
55 | GPU_MEMORY_UTILIZATION = float(os.getenv("TTS_GPU_MEMORY_UTILIZATION", "0.9"))
56 | MAX_NUM_SEQS = int(os.getenv("TTS_MAX_NUM_SEQS", "32"))
57 | 
58 | # Performance Configuration
59 | MAX_WORKERS = int(os.getenv("TTS_MAX_WORKERS", "16"))
60 | 
61 | # SNAC Configuration
62 | SNAC_DEVICE = os.getenv("SNAC_DEVICE", "cuda")
63 | 
64 | # Generation Parameters (defaults)
65 | DEFAULT_TEMPERATURE = float(os.getenv("TTS_TEMPERATURE", "0.2"))
66 | DEFAULT_TOP_P = float(os.getenv("TTS_TOP_P", "0.9"))
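# NOTE: these fallbacks differ from some values suggested in .env.sample
# (temperature 0.2 vs 0.3, repetition penalty 1.1 vs 1.3, GPU memory
# utilization 0.9 vs 0.95); variables set in the environment or .env take precedence.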
67 | DEFAULT_REPETITION_PENALTY = float(os.getenv("TTS_REPETITION_PENALTY", "1.1")) 68 | DEFAULT_MAX_TOKENS = int(os.getenv("TTS_MAX_TOKENS", "4096")) 69 | 70 | # Logging Configuration 71 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 72 | 73 | # Setup logging 74 | logging.basicConfig(level=getattr(logging, LOG_LEVEL)) 75 | logger = logging.getLogger(__name__) 76 | 77 | # Log configuration on startup 78 | logger.info("🔧 Configuration loaded:") 79 | logger.info(f" Model: {MODEL_NAME}") 80 | logger.info(f" Dtype: {DTYPE}") 81 | logger.info(f" Max Model Length: {MAX_MODEL_LEN}") 82 | logger.info(f" Tensor Parallel Size: {TENSOR_PARALLEL_SIZE}") 83 | logger.info(f" GPU Memory Utilization: {GPU_MEMORY_UTILIZATION}") 84 | logger.info(f" Max Workers: {MAX_WORKERS}") 85 | logger.info(f" SNAC Device: {SNAC_DEVICE}") 86 | logger.info(f" Temperature: {DEFAULT_TEMPERATURE}") 87 | logger.info(f" Top P: {DEFAULT_TOP_P}") 88 | logger.info(f" Repetition Penalty: {DEFAULT_REPETITION_PENALTY}") 89 | logger.info(f" Max Tokens: {DEFAULT_MAX_TOKENS}") 90 | logger.info(f" Max Num Sequences: {MAX_NUM_SEQS}") 91 | 92 | # Global engine variable 93 | engine = None 94 | executor = None 95 | 96 | @asynccontextmanager 97 | async def lifespan(app: FastAPI): 98 | """FastAPI lifespan context manager for startup and shutdown events""" 99 | # Startup 100 | logger.info("🚀 Loading Orpheus TTS model...") 101 | global engine, executor 102 | 103 | # Load SNAC model for audio decoding 104 | initialize_snac_model(device=SNAC_DEVICE) 105 | 106 | # Load Orpheus TTS model 107 | engine = OrpheusModelExtended( 108 | model_name=MODEL_NAME, 109 | dtype=DTYPE, 110 | max_model_len=MAX_MODEL_LEN, 111 | tensor_parallel_size=TENSOR_PARALLEL_SIZE, 112 | gpu_memory_utilization=GPU_MEMORY_UTILIZATION, 113 | max_num_seqs=MAX_NUM_SEQS 114 | ) 115 | 116 | # Create thread pool executor for file I/O only (no more token decoding) 117 | executor = ThreadPoolExecutor(max_workers=MAX_WORKERS) 118 | 119 | logger.info("✅ Model loaded successfully!") 120 | 121 | # Create outputs directory 122 | os.makedirs("outputs", exist_ok=True) 123 | logger.info("📁 Created outputs directory") 124 | 125 | # Create debug audio directory for edge case analysis 126 | os.makedirs("debug_audio_errors", exist_ok=True) 127 | logger.info("📁 Created debug_audio_errors directory for edge case logging") 128 | 129 | # Create success logging directory for debugging and parameter tuning 130 | os.makedirs("debug_audio_success", exist_ok=True) 131 | logger.info("📁 Created debug_audio_success directory for successful request logging") 132 | 133 | yield 134 | 135 | # Shutdown 136 | logger.info("🔄 Shutting down Orpheus TTS model...") 137 | if executor: 138 | executor.shutdown(wait=True) 139 | shutdown_snac_model() 140 | engine = None 141 | executor = None 142 | logger.info("✅ Shutdown complete!") 143 | 144 | app = FastAPI( 145 | title="Orpheus TTS API", 146 | description="OpenAI-compatible Text-to-Speech API using Orpheus TTS", 147 | version="1.0.0", 148 | lifespan=lifespan 149 | ) 150 | 151 | # Add CORS middleware for web clients 152 | app.add_middleware( 153 | CORSMiddleware, 154 | allow_origins=["*"], # Configure as needed for security 155 | allow_credentials=True, 156 | allow_methods=["*"], 157 | allow_headers=["*"], 158 | ) 159 | 160 | # Add timeout middleware for long-running requests 161 | @app.middleware("http") 162 | async def timeout_middleware(request: Request, call_next): 163 | """Middleware to handle long-running requests with extended timeout""" 164 | start_time = 
time.time() 165 | 166 | # Set extended timeout for TTS endpoints 167 | if request.url.path.startswith("/v1/audio/"): 168 | # For TTS endpoints, set a much longer timeout (10 minutes) 169 | timeout = 600.0 170 | else: 171 | # For other endpoints, use default timeout (30 seconds) 172 | timeout = 30.0 173 | 174 | try: 175 | # Process request with timeout 176 | response = await asyncio.wait_for(call_next(request), timeout=timeout) 177 | 178 | # Log request duration 179 | process_time = time.time() - start_time 180 | # logger.info(f"Request {request.method} {request.url.path} completed in {process_time:.2f}s") 181 | 182 | return response 183 | 184 | except asyncio.TimeoutError: 185 | logger.error(f"Request {request.method} {request.url.path} timed out after {timeout}s") 186 | raise HTTPException( 187 | status_code=504, 188 | detail={ 189 | "error": "timeout_error", 190 | "message": f"Request timed out after {timeout} seconds", 191 | "type": "server_error", 192 | } 193 | ) 194 | except Exception as e: 195 | process_time = time.time() - start_time 196 | logger.error(f"Request {request.method} {request.url.path} failed after {process_time:.2f}s: {e}") 197 | raise 198 | 199 | # OpenAI-compatible request models 200 | class SpeechRequest(BaseModel): 201 | model: Literal["orpheus"] = Field(default="orpheus", description="The TTS model to use") 202 | input: str = Field(..., description="The text to generate audio for", max_length=4096) 203 | voice: Literal["tara", "leah", "jess", "leo", "dan", "mia", "zac", "zoe"] = Field( 204 | default="tara", 205 | description="The voice to use for generation" 206 | ) 207 | response_format: Literal["wav"] = Field( 208 | default="wav", 209 | description="The format of the audio output" 210 | ) 211 | speed: float = Field(default=1.0, ge=0.25, le=4.0, description="The speed of the generated audio") 212 | 213 | # Optional sampling parameters (will use environment defaults if not specified) 214 | temperature: Optional[float] = Field( 215 | default=None, 216 | ge=0.0, 217 | le=2.0, 218 | description="Temperature for sampling (0.0 to 2.0). Uses environment default if not specified." 219 | ) 220 | top_p: Optional[float] = Field( 221 | default=None, 222 | ge=0.0, 223 | le=1.0, 224 | description="Top-p for nucleus sampling (0.0 to 1.0). Uses environment default if not specified." 225 | ) 226 | repetition_penalty: Optional[float] = Field( 227 | default=None, 228 | ge=0.5, 229 | le=2.0, 230 | description="Repetition penalty (0.5 to 2.0). Uses environment default if not specified." 231 | ) 232 | max_tokens: Optional[int] = Field( 233 | default=None, 234 | ge=100, 235 | le=8192, 236 | description="Maximum tokens to generate. Uses environment default if not specified." 237 | ) 238 | 239 | def cleanup_file(file_path: str) -> None: 240 | """Clean up a single file after serving""" 241 | try: 242 | if os.path.exists(file_path): 243 | os.remove(file_path) 244 | # logger.info(f"Cleaned up file: {file_path}") 245 | except Exception as e: 246 | logger.info(f"Failed to cleanup file {file_path}: {e}") 247 | 248 | @app.post("/v1/audio/speech") 249 | async def create_speech(request: SpeechRequest, background_tasks: BackgroundTasks): 250 | """ 251 | Create speech from text using OpenAI-compatible API. 252 | 253 | Generates audio from the input text using the Orpheus TTS model. 254 | Automatically chunks long text and combines audio at token level for seamless playback. 255 | Returns audio file in WAV format. 
256 | 257 | Features: 258 | - Automatic text chunking for long inputs 259 | - Token-level audio combination for seamless playback 260 | - Fully async processing with no threading overhead 261 | - Intelligent retry logic for audio decoding errors 262 | - Detailed logging of chunk processing and retry statistics 263 | - Configurable retry attempts and delays 264 | """ 265 | try: 266 | # logger.info(f"Got request: {request.input}") 267 | 268 | # Validate input 269 | if not request.input.strip(): 270 | raise HTTPException( 271 | status_code=400, 272 | detail={ 273 | "error": "validation_error", 274 | "message": "Input text cannot be empty", 275 | "type": "invalid_request_error", 276 | } 277 | ) 278 | 279 | # logger.info(f"Processing TTS request: {len(request.input)} chars, voice: {request.voice}") 280 | 281 | # Generate unique filename 282 | unique_id = uuid.uuid4() 283 | output_path = f"outputs/{request.voice}_{unique_id}.wav" 284 | 285 | # Generate speech 286 | start_time = time.time() 287 | try: 288 | if len(request.input) > 500: 289 | # Split the text into sentences 290 | sentences = split_text_into_sentences(request.input) 291 | # logger.info(f"Split text into {len(sentences)} segments") 292 | 293 | # Create batches by combining sentences up to max_batch_chars 294 | batches = create_batches(sentences, max_batch_chars=500) 295 | # logger.info(f"Created {len(batches)} batches for processing") 296 | 297 | # Generate tokens for all chunks and combine them 298 | combined_tokens, metadata = await generate_speech_chunks( 299 | engine, batches, request.voice, 300 | request.temperature, request.top_p, 301 | request.repetition_penalty, request.max_tokens, 302 | DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_REPETITION_PENALTY, DEFAULT_MAX_TOKENS, 303 | executor 304 | ) 305 | 306 | # Convert combined tokens to audio file 307 | if executor is None: 308 | raise HTTPException(status_code=500, detail="Executor not initialized") 309 | file_stats = await tokens_to_audio_file(combined_tokens, output_path, executor) 310 | 311 | # Log detailed results including retry statistics 312 | retry_stats = metadata.get("retry_stats", {}) 313 | # logger.info(f"Chunked processing complete - Total attempts: {retry_stats.get('total_attempts', 0)}, " 314 | # f"Total retries: {retry_stats.get('total_retries', 0)}, " 315 | # f"Failed chunks: {retry_stats.get('failed_chunks', 0)}") 316 | else: 317 | # Single processing for shorter text using direct async access 318 | token_chunks, metadata = await generate_speech_tokens_direct( 319 | engine, request.input, request.voice, 320 | request.temperature, request.top_p, 321 | request.repetition_penalty, request.max_tokens, 322 | DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_REPETITION_PENALTY, DEFAULT_MAX_TOKENS, 323 | executor 324 | ) 325 | 326 | # Convert tokens to audio file 327 | if executor is None: 328 | raise HTTPException(status_code=500, detail="Executor not initialized") 329 | file_stats = await tokens_to_audio_file(token_chunks, output_path, executor) 330 | 331 | # Combine metadata 332 | result = {**metadata, **file_stats} 333 | 334 | # Log single processing results including retry information 335 | attempts = metadata.get("attempts", 1) 336 | retries = metadata.get("retries", 0) 337 | # logger.info(f"Single processing complete - Duration: {file_stats['duration_seconds']:.2f}s, " 338 | # f"Attempts: {attempts}, Retries: {retries}") 339 | 340 | except Exception as e: 341 | logger.error(f"Error during TTS generation: {e}") 342 | raise HTTPException( 343 | status_code=500, 344 | 
detail={ 345 | "error": "processing_error", 346 | "message": f"TTS generation failed: {str(e)}", 347 | "type": "server_error", 348 | } 349 | ) 350 | 351 | end_time = time.time() 352 | generation_time = round(end_time - start_time, 2) 353 | 354 | # Verify the output file was created 355 | if not os.path.exists(output_path): 356 | raise HTTPException( 357 | status_code=500, 358 | detail={ 359 | "error": "processing_error", 360 | "message": "Audio file generation failed", 361 | "type": "server_error", 362 | } 363 | ) 364 | 365 | # logger.info(f"TTS generation completed in {generation_time}s, returning file: {output_path}") 366 | 367 | # Schedule file cleanup after response is sent 368 | background_tasks.add_task(cleanup_file, output_path) 369 | 370 | # Return audio file 371 | return FileResponse( 372 | path=output_path, 373 | media_type="audio/wav", 374 | filename=f"{request.voice}_{unique_id}.wav", 375 | headers={ 376 | "Content-Disposition": f"attachment; filename={request.voice}_{unique_id}.wav" 377 | } 378 | ) 379 | 380 | except HTTPException: 381 | # Re-raise HTTP exceptions as-is 382 | raise 383 | except ValueError as e: 384 | # Handle validation errors 385 | logger.info(f"Validation error: {str(e)}") 386 | raise HTTPException( 387 | status_code=400, 388 | detail={ 389 | "error": "validation_error", 390 | "message": str(e), 391 | "type": "invalid_request_error", 392 | } 393 | ) 394 | except Exception as e: 395 | # Handle unexpected errors 396 | logger.error(f"Unexpected error in speech generation: {str(e)}") 397 | raise HTTPException( 398 | status_code=500, 399 | detail={ 400 | "error": "internal_error", 401 | "message": "An unexpected error occurred", 402 | "type": "server_error", 403 | } 404 | ) 405 | 406 | @app.get("/health") 407 | async def health_check(): 408 | """Health check endpoint""" 409 | return { 410 | "status": "healthy", 411 | "model_loaded": engine is not None, 412 | "timestamp": time.time() 413 | } 414 | 415 | @app.get("/") 416 | async def root(): 417 | """Root endpoint with API information""" 418 | return { 419 | "message": "Orpheus TTS API", 420 | "version": "1.0.0", 421 | "description": "OpenAI-compatible Text-to-Speech API", 422 | "endpoints": { 423 | "speech": "/v1/audio/speech", 424 | "health": "/health" 425 | } 426 | } 427 | 428 | if __name__ == "__main__": 429 | uvicorn.run( 430 | "fastapi_app:app", 431 | host="0.0.0.0", 432 | port=8880, 433 | reload=False, 434 | log_level="info", 435 | timeout_keep_alive=600, # Keep connection alive for 10 minutes 436 | ) -------------------------------------------------------------------------------- /audio_decoder.py: -------------------------------------------------------------------------------- 1 | # This file is a modified version of the orpheus_tts.decoder.py file. Some improvements were made to the error handling of failed audio decoding. 2 | 3 | import torch 4 | import numpy as np 5 | import logging 6 | import statistics 7 | import re 8 | import os 9 | from typing import Optional, AsyncGenerator, Union 10 | from snac import SNAC 11 | 12 | from dotenv import load_dotenv 13 | 14 | load_dotenv() 15 | 16 | # Logging Configuration 17 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 18 | 19 | # Setup logging 20 | logging.basicConfig(level=getattr(logging, LOG_LEVEL)) 21 | logger = logging.getLogger(__name__) 22 | 23 | def normalize_and_count_words(text: str) -> int: 24 | """ 25 | Normalize text and count actual words, handling various separators. 
26 | 27 | Treats hyphens, underscores, pipes, slashes, and other separators as word boundaries. 28 | Removes punctuation and extra whitespace. 29 | 30 | Examples: 31 | "Hello world" -> 2 words 32 | "-INSULT-ALBUS-DUMBLEDORE-" -> 3 words (INSULT, ALBUS, DUMBLEDORE) 33 | "hello_world_test" -> 3 words 34 | "a/b/c" -> 3 words 35 | 36 | Args: 37 | text: Input text 38 | 39 | Returns: 40 | int: Number of actual words 41 | """ 42 | if not text or not text.strip(): 43 | return 0 44 | 45 | # Replace common separators with spaces 46 | # This includes: hyphens, underscores, pipes, slashes, etc. 47 | normalized = re.sub(r'[-_|/\\]+', ' ', text) 48 | 49 | # Remove other punctuation and special characters 50 | normalized = re.sub(r'[^\w\s]', ' ', normalized) 51 | 52 | # Split by whitespace and filter out empty strings 53 | words = [word for word in normalized.split() if word.strip()] 54 | 55 | return len(words) 56 | 57 | # Custom exceptions for audio decoding errors 58 | class AudioDecodingError(Exception): 59 | """Base exception for audio decoding errors""" 60 | pass 61 | 62 | class InsufficientTokensError(AudioDecodingError): 63 | """Raised when there are not enough tokens to decode audio""" 64 | pass 65 | 66 | class InvalidTokenRangeError(AudioDecodingError): 67 | """Raised when tokens are outside the valid range (0-4096)""" 68 | pass 69 | 70 | class TokenParsingError(AudioDecodingError): 71 | """Raised when token parsing fails""" 72 | pass 73 | 74 | class TokenFormatError(AudioDecodingError): 75 | """Raised when tokens are not in the expected format""" 76 | pass 77 | 78 | class TokenRepetitionError(AudioDecodingError): 79 | """Raised when repetitive token patterns are detected that cause audio artifacts""" 80 | pass 81 | 82 | class AudioDurationOutlierError(AudioDecodingError): 83 | """Raised when audio duration is an outlier for the given text length""" 84 | pass 85 | 86 | class TokenCountOutlierError(AudioDecodingError): 87 | """Raised when token count is suspiciously high relative to input text length""" 88 | pass 89 | 90 | # Global SNAC model variable 91 | _snac_model = None 92 | _snac_device = None 93 | 94 | def initialize_snac_model(device: str = "cuda") -> None: 95 | """Initialize the SNAC model for audio decoding""" 96 | global _snac_model, _snac_device 97 | 98 | if _snac_model is None: 99 | logger.info(f"Loading SNAC model on device: {device}") 100 | _snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval() 101 | _snac_device = device 102 | _snac_model = _snac_model.to(_snac_device) 103 | logger.info("SNAC model loaded successfully") 104 | 105 | def shutdown_snac_model() -> None: 106 | """Shutdown the SNAC model""" 107 | global _snac_model, _snac_device 108 | 109 | if _snac_model is not None: 110 | logger.info("Shutting down SNAC model") 111 | _snac_model = None 112 | _snac_device = None 113 | 114 | def convert_to_audio(multiframe: list[int], count: int) -> bytes: 115 | """Convert multiframe tokens to audio bytes with proper error handling""" 116 | if _snac_model is None: 117 | raise RuntimeError("SNAC model not initialized. Call initialize_snac_model() first.") 118 | 119 | if len(multiframe) < 7: 120 | logger.error(f"DEBUG: Not enough tokens to decode! 
Got {len(multiframe)} tokens, need at least 7") 121 | raise InsufficientTokensError(f"Not enough tokens to decode audio: {len(multiframe)} < 7") 122 | 123 | codes_0 = torch.tensor([], device=_snac_device, dtype=torch.int32) 124 | codes_1 = torch.tensor([], device=_snac_device, dtype=torch.int32) 125 | codes_2 = torch.tensor([], device=_snac_device, dtype=torch.int32) 126 | 127 | num_frames = len(multiframe) // 7 128 | frame = multiframe[:num_frames*7] 129 | 130 | for j in range(num_frames): 131 | i = 7*j 132 | if codes_0.shape[0] == 0: 133 | codes_0 = torch.tensor([frame[i]], device=_snac_device, dtype=torch.int32) 134 | else: 135 | codes_0 = torch.cat([codes_0, torch.tensor([frame[i]], device=_snac_device, dtype=torch.int32)]) 136 | 137 | if codes_1.shape[0] == 0: 138 | codes_1 = torch.tensor([frame[i+1]], device=_snac_device, dtype=torch.int32) 139 | codes_1 = torch.cat([codes_1, torch.tensor([frame[i+4]], device=_snac_device, dtype=torch.int32)]) 140 | else: 141 | codes_1 = torch.cat([codes_1, torch.tensor([frame[i+1]], device=_snac_device, dtype=torch.int32)]) 142 | codes_1 = torch.cat([codes_1, torch.tensor([frame[i+4]], device=_snac_device, dtype=torch.int32)]) 143 | 144 | if codes_2.shape[0] == 0: 145 | codes_2 = torch.tensor([frame[i+2]], device=_snac_device, dtype=torch.int32) 146 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+3]], device=_snac_device, dtype=torch.int32)]) 147 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+5]], device=_snac_device, dtype=torch.int32)]) 148 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+6]], device=_snac_device, dtype=torch.int32)]) 149 | else: 150 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+2]], device=_snac_device, dtype=torch.int32)]) 151 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+3]], device=_snac_device, dtype=torch.int32)]) 152 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+5]], device=_snac_device, dtype=torch.int32)]) 153 | codes_2 = torch.cat([codes_2, torch.tensor([frame[i+6]], device=_snac_device, dtype=torch.int32)]) 154 | 155 | codes = [codes_0.unsqueeze(0), codes_1.unsqueeze(0), codes_2.unsqueeze(0)] 156 | 157 | # Check that all tokens are between 0 and 4096 158 | if (torch.any(codes[0] < 0) or torch.any(codes[0] > 4096) or 159 | torch.any(codes[1] < 0) or torch.any(codes[1] > 4096) or 160 | torch.any(codes[2] < 0) or torch.any(codes[2] > 4096)): 161 | 162 | # Find specific out-of-range values for better error reporting 163 | invalid_tokens = [] 164 | for i, code_tensor in enumerate(codes): 165 | invalid_mask = (code_tensor < 0) | (code_tensor > 4096) 166 | if torch.any(invalid_mask): 167 | invalid_values = code_tensor[invalid_mask].tolist() 168 | invalid_tokens.append(f"codes_{i}: {invalid_values}") 169 | 170 | error_msg = f"Some tokens are out of range (0-4096): {'; '.join(invalid_tokens)}" 171 | logger.error(f"DEBUG: {error_msg}") 172 | raise InvalidTokenRangeError(error_msg) 173 | 174 | with torch.inference_mode(): 175 | audio_hat = _snac_model.decode(codes) 176 | 177 | audio_slice = audio_hat[:, :, 2048:4096] 178 | detached_audio = audio_slice.detach().cpu() 179 | audio_np = detached_audio.numpy() 180 | audio_int16 = (audio_np * 32767).astype(np.int16) 181 | audio_bytes = audio_int16.tobytes() 182 | 183 | return audio_bytes 184 | 185 | def turn_token_into_id(token_string: str, index: int) -> int: 186 | """Convert token string to ID with proper error handling""" 187 | # Strip whitespace 188 | token_string = token_string.strip() 189 | 190 | # Find the last token in the string 191 | 
last_token_start = token_string.rfind("<custom_token_") 192 | 193 | if last_token_start == -1: 194 | logger.error(f"DEBUG: No custom token found in string: '{token_string}'") 195 | raise TokenFormatError(f"No custom token found in string: '{token_string}'") 196 | 197 | # Extract the last token 198 | last_token = token_string[last_token_start:] 199 | 200 | # Tokens have the form <custom_token_N>, where N = token_id + 10 + ((index % 7) * 4096) 201 | if last_token.startswith("<custom_token_") and last_token.endswith(">"): 202 | try: 203 | number_str = last_token[14:-1] 204 | token_id = int(number_str) - 10 - ((index % 7) * 4096) 205 | return token_id 206 | except ValueError as e: 207 | logger.error(f"DEBUG: Value error in token conversion: {e}, token: '{last_token}'") 208 | raise TokenParsingError(f"Failed to parse token number from '{last_token}': {e}") 209 | else: 210 | logger.error(f"DEBUG: Token not in expected format: '{last_token}'") 211 | raise TokenFormatError(f"Token not in expected format: '{last_token}'") 212 | 213 | def check_token_repetition(tokens: list[str], max_tokens: int = 4096) -> None: 214 | """ 215 | Check for repetitive token patterns that cause audio artifacts. 216 | Uses adaptive detection based on the actual max_tokens limit and generation progress. 217 | 218 | Args: 219 | tokens: List of raw token strings from the language model 220 | max_tokens: Maximum tokens configured for generation 221 | 222 | Raises: 223 | TokenRepetitionError: When repetitive patterns are detected 224 | """ 225 | logger.debug(f"Checking {len(tokens)} tokens for repetition patterns (max_tokens: {max_tokens})") 226 | 227 | # Convert tokens to token IDs for pattern analysis 228 | token_id_sequence = [] 229 | count = 0 230 | 231 | for token_string in tokens: 232 | try: 233 | token_id = turn_token_into_id(token_string, count) 234 | if token_id > 0: 235 | token_id_sequence.append(token_id) 236 | count += 1 237 | except (TokenParsingError, TokenFormatError): 238 | # Skip invalid tokens, just like in tokens_decoder 239 | continue 240 | 241 | logger.debug(f"Converted {len(token_id_sequence)} valid token IDs for repetition analysis") 242 | 243 | # Only check for repetition if we have a significant number of tokens 244 | # Based on observation: repetition started around token 150, so check when we have >100 tokens 245 | MIN_TOKENS_FOR_REPETITION_CHECK = max(100, max_tokens // 40) # At least 100 or 2.5% of max_tokens 246 | 247 | if len(token_id_sequence) < MIN_TOKENS_FOR_REPETITION_CHECK: 248 | logger.debug(f"Skipping repetition check: {len(token_id_sequence)} tokens < {MIN_TOKENS_FOR_REPETITION_CHECK} minimum") 249 | return 250 | 251 | # Adaptive configuration based on max_tokens and current generation progress 252 | generation_progress = len(token_id_sequence) / max_tokens 253 | 254 | # Dynamic pattern detection configuration based on max_tokens 255 | # Scale pattern lengths with max_tokens to handle different environments 256 | MIN_PATTERN_LENGTH = max(10, max_tokens // 400) # 10 for 4096 tokens, scales up for larger limits 257 | MAX_PATTERN_LENGTH = max(50, max_tokens // 80) # ~51 for 4096 tokens, scales up for larger limits 258 | 259 | # Ensure reasonable bounds 260 | MIN_PATTERN_LENGTH = min(MIN_PATTERN_LENGTH, 50) # Cap at 50 to avoid excessive computation 261 | MAX_PATTERN_LENGTH = min(MAX_PATTERN_LENGTH, 200) # Cap at 200 to avoid excessive computation 262 | 263 | # Look at a larger window for long generations (up to 50% of generation) 264 | REPETITION_WINDOW = min(len(token_id_sequence), max(500, int(len(token_id_sequence) * 0.5))) 265 | 266 | # Adaptive repetition threshold based on generation progress and max_tokens 267 | # Higher base thresholds for more conservative detection 268 | base_threshold = max(8, max_tokens // 500) # 8 for 4096 tokens, scales up for larger limits 269 | 270 | if generation_progress < 0.3: # Early in generation 271 | REPETITION_THRESHOLD = base_threshold + 5 # More conservative 272 | elif generation_progress < 0.7: # Mid generation 273 |
REPETITION_THRESHOLD = base_threshold + 2 # Moderate 274 | else: # Late in generation (likely approaching limit) 275 | REPETITION_THRESHOLD = max(5, base_threshold - 2) # More aggressive for late-stage 276 | 277 | logger.debug(f"Repetition detection config: window={REPETITION_WINDOW}, pattern_len={MIN_PATTERN_LENGTH}-{MAX_PATTERN_LENGTH}, threshold={REPETITION_THRESHOLD}, progress={generation_progress:.2%}, max_tokens={max_tokens}") 278 | 279 | # Check for repetition patterns 280 | recent_tokens = token_id_sequence[-REPETITION_WINDOW:] 281 | 282 | # Check for repeating patterns of different lengths 283 | for pattern_length in range(MIN_PATTERN_LENGTH, min(MAX_PATTERN_LENGTH + 1, REPETITION_WINDOW // 2 + 1)): 284 | if len(recent_tokens) >= pattern_length * REPETITION_THRESHOLD: 285 | # Extract the pattern from the end and check if it repeats 286 | pattern = recent_tokens[-pattern_length:] 287 | repetitions = 0 288 | 289 | # Count consecutive repetitions at the end 290 | for i in range(len(recent_tokens) - pattern_length, -1, -pattern_length): 291 | if i >= 0 and i + pattern_length <= len(recent_tokens): 292 | if recent_tokens[i:i + pattern_length] == pattern: 293 | repetitions += 1 294 | else: 295 | break 296 | 297 | if repetitions >= REPETITION_THRESHOLD: 298 | logger.error(f"🔄 REPETITION DETECTED: Pattern length {pattern_length} repeated {repetitions} times") 299 | logger.error(f"🔄 Pattern: {pattern}") 300 | logger.error(f"🔄 Generation progress: {generation_progress:.1%} ({len(token_id_sequence)}/{max_tokens} tokens)") 301 | logger.error(f"🔄 Recent token sequence: {token_id_sequence[-min(50, len(token_id_sequence)):]}") 302 | raise TokenRepetitionError(f"Repetitive pattern detected: {pattern_length}-token pattern repeated {repetitions} times at {generation_progress:.1%} progress") 303 | 304 | logger.debug(f"No repetition patterns detected in {len(token_id_sequence)} tokens ({generation_progress:.1%} progress)") 305 | 306 | def check_token_count_ratio(tokens: list[str], input_text: str, max_tokens: int) -> dict: 307 | """ 308 | Check if token count is suspiciously high relative to input text. 309 | Returns a dict with check results and metrics for logging. 310 | 311 | Args: 312 | tokens: List of raw token strings from the language model 313 | input_text: The original input text 314 | max_tokens: Maximum tokens configured for generation 315 | 316 | Returns: 317 | dict: Metrics including is_outlier, token_count, word_count, ratio, etc. 
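        Example (illustrative: only the token count and the input text matter here,
        not the token contents):
            >>> m = check_token_count_ratio(["<custom_token_123>"] * 90, "Hello world", 4096)
            >>> (m["word_count"], m["tokens_per_word"], m["is_ratio_outlier"])
            (2, 45.0, False)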
318 | 319 | Raises: 320 | TokenCountOutlierError: When token count is suspiciously high 321 | """ 322 | word_count = normalize_and_count_words(input_text) 323 | token_count = len(tokens) 324 | 325 | # Typically, each word generates roughly 10-50 tokens (7 per frame, multiple frames per phoneme) 326 | # If we're generating way more than expected, something might be wrong 327 | EXPECTED_TOKENS_PER_WORD = 50 # Conservative estimate 328 | MAX_RATIO_MULTIPLIER = 3.0 # Allow 3x the expected 329 | 330 | # For very short texts (1-5 words), be much more lenient 331 | # Short utterances with punctuation can legitimately generate more tokens 332 | if word_count <= 5: 333 | MAX_RATIO_MULTIPLIER = 10.0 # Very lenient for short texts 334 | elif word_count <= 10: 335 | MAX_RATIO_MULTIPLIER = 6.0 # More lenient for short texts 336 | 337 | max_expected_tokens = word_count * EXPECTED_TOKENS_PER_WORD * MAX_RATIO_MULTIPLIER 338 | tokens_per_word = token_count / word_count if word_count > 0 else 0 339 | 340 | # Also check if we're hitting the max_tokens limit (often indicates repetition loop) 341 | token_limit_threshold = max_tokens * 0.95 # 95% of limit 342 | is_near_limit = token_count > token_limit_threshold 343 | is_ratio_outlier = token_count > max_expected_tokens 344 | 345 | metrics = { 346 | "token_count": token_count, 347 | "word_count": word_count, 348 | "tokens_per_word": round(tokens_per_word, 2), 349 | "max_expected_tokens": int(max_expected_tokens), 350 | "is_ratio_outlier": is_ratio_outlier, 351 | "is_near_limit": is_near_limit, 352 | "max_tokens": max_tokens, 353 | "limit_usage_percent": round((token_count / max_tokens) * 100, 2) if max_tokens > 0 else 0 354 | } 355 | 356 | if is_ratio_outlier: 357 | logger.error(f"⚠️ Token count outlier: {token_count} tokens for {word_count} words " 358 | f"(ratio: {tokens_per_word:.1f} tokens/word, expected max: {max_expected_tokens:.0f})") 359 | raise TokenCountOutlierError( 360 | f"Token count {token_count} exceeds expected {max_expected_tokens:.0f} " 361 | f"for {word_count} words (ratio: {tokens_per_word:.1f} tokens/word)" 362 | ) 363 | 364 | if is_near_limit: 365 | logger.warning( 366 | f"⚠️ Token generation hit {token_count}/{max_tokens} tokens " 367 | f"({metrics['limit_usage_percent']:.1f}% of limit) - possible repetition loop" 368 | ) 369 | 370 | logger.debug(f"Token count check passed: {token_count} tokens, {tokens_per_word:.1f} tokens/word") 371 | return metrics 372 | 373 | def check_token_variance(tokens: list[str], window_size: int = 100) -> dict: 374 | """ 375 | Check if recent tokens have suspiciously low variance (indicating repetition). 376 | Returns a dict with check results and metrics for logging.
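    Repetition tends to show up as a low coefficient of variation (std/mean over
    the recent token-ID window) or a low unique-token ratio; the thresholds used
    below are 0.15 and 0.20 respectively, and currently only produce a warning.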
377 | 378 | Args: 379 | tokens: List of raw token strings from the language model 380 | window_size: Number of recent tokens to analyze 381 | 382 | Returns: 383 | dict: Metrics including variance, mean, coefficient_of_variation, is_low_variance 384 | 385 | Note: 386 | Low variance currently triggers a warning only; the TokenRepetitionError raise below is commented out until false positives are ruled out. 387 | """ 388 | metrics = { 389 | "window_size": window_size, 390 | "tokens_analyzed": 0, 391 | "variance": 0, 392 | "mean": 0, 393 | "std_dev": 0, 394 | "coefficient_of_variation": 0, 395 | "is_low_variance": False, 396 | "unique_tokens_ratio": 0 397 | } 398 | 399 | if len(tokens) < window_size: 400 | logger.debug(f"Skipping variance check: {len(tokens)} tokens < {window_size} window size") 401 | return metrics 402 | 403 | # Convert recent tokens to IDs 404 | recent_token_ids = [] 405 | count = 0 406 | for token in tokens[-window_size:]: 407 | try: 408 | token_id = turn_token_into_id(token, count) 409 | if token_id > 0: 410 | recent_token_ids.append(token_id) 411 | count += 1 412 | except (TokenParsingError, TokenFormatError): 413 | continue 414 | 415 | if len(recent_token_ids) < 50: 416 | logger.debug(f"Skipping variance check: only {len(recent_token_ids)} valid tokens") 417 | return metrics 418 | 419 | # Calculate variance and other statistics 420 | variance = statistics.variance(recent_token_ids) 421 | mean = statistics.mean(recent_token_ids) 422 | std_dev = statistics.stdev(recent_token_ids) 423 | 424 | # Coefficient of variation (normalized std dev) 425 | coefficient_of_variation = std_dev / mean if mean > 0 else 0 426 | 427 | # Check unique token ratio 428 | unique_tokens = len(set(recent_token_ids)) 429 | unique_ratio = unique_tokens / len(recent_token_ids) 430 | 431 | metrics.update({ 432 | "tokens_analyzed": len(recent_token_ids), 433 | "variance": round(variance, 2), 434 | "mean": round(mean, 2), 435 | "std_dev": round(std_dev, 2), 436 | "coefficient_of_variation": round(coefficient_of_variation, 3), 437 | "unique_tokens": unique_tokens, 438 | "unique_tokens_ratio": round(unique_ratio, 3) 439 | }) 440 | 441 | # Low variance relative to range indicates repetition 442 | # Token IDs are 0-4096, so we'd expect decent variance 443 | LOW_CV_THRESHOLD = 0.15 # Very low variation 444 | LOW_UNIQUE_RATIO_THRESHOLD = 0.20 # Less than 20% unique tokens 445 | 446 | is_low_variance = coefficient_of_variation < LOW_CV_THRESHOLD or unique_ratio < LOW_UNIQUE_RATIO_THRESHOLD 447 | metrics["is_low_variance"] = is_low_variance 448 | 449 | if is_low_variance: 450 | logger.warning( 451 | f"⚠️ Low token variance detected: CV={coefficient_of_variation:.3f}, " 452 | f"unique_ratio={unique_ratio:.3f}, variance={variance:.1f}, mean={mean:.1f}" 453 | ) 454 | # For now, just log a warning - this could raise TokenRepetitionError once false positives are ruled out 455 | # raise TokenRepetitionError( 456 | # f"Low token variance detected: CV={coefficient_of_variation:.3f}, " 457 | # f"unique_ratio={unique_ratio:.3f}" 458 | # ) 459 | 460 | logger.debug(f"Token variance check: CV={coefficient_of_variation:.3f}, unique_ratio={unique_ratio:.3f}") 461 | return metrics 462 | 463 | async def tokens_decoder(tokens: list[str]) -> AsyncGenerator[bytes, None]: 464 | """ 465 | Decode tokens into audio bytes with proper error handling.
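    Valid tokens are buffered and decoded in overlapping 28-token (4-frame)
    windows: once more than 27 tokens have accumulated, every 7th token triggers
    a decode of the last 28 tokens via convert_to_audio, which returns samples
    2048:4096 of the decoded window.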
466 | 467 | Args: 468 | tokens: List of token strings from the language model 469 | 470 | Yields: 471 | Audio bytes chunks 472 | 473 | Raises: 474 | AudioDecodingError: When token decoding fails 475 | """ 476 | buffer = [] 477 | count = 0 478 | 479 | logger.debug(f"Starting token decoding for {len(tokens)} tokens") 480 | 481 | for token_string in tokens: 482 | try: 483 | token_id = turn_token_into_id(token_string, count) 484 | 485 | if token_id > 0: 486 | buffer.append(token_id) 487 | count += 1 488 | 489 | # Process buffer when we have enough tokens 490 | if count % 7 == 0 and count > 27: 491 | buffer_to_proc = buffer[-28:] 492 | try: 493 | audio_samples = convert_to_audio(buffer_to_proc, count) 494 | if audio_samples is not None: 495 | yield audio_samples 496 | except AudioDecodingError as e: 497 | # Re-raise audio decoding errors to trigger retry 498 | logger.error(f"Audio decoding failed at count {count}: {e}") 499 | raise 500 | 501 | except (TokenParsingError, TokenFormatError) as e: 502 | # Log token parsing errors but continue processing 503 | logger.info(f"Token parsing error at count {count}: {e}") 504 | continue 505 | 506 | logger.debug(f"Token decoding completed. Processed {count} valid tokens") -------------------------------------------------------------------------------- /audio_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Advanced audio analysis for detecting TTS generation issues. 3 | Uses cross-correlation for repetition detection and spectrogram analysis for stretched audio. 4 | Detects: repeated segments, stretched/slowed audio, silence, clipping, and energy anomalies. 5 | """ 6 | 7 | import numpy as np 8 | import logging 9 | import os 10 | import wave 11 | from pathlib import Path 12 | from typing import Tuple, Dict, List 13 | from scipy.signal import spectrogram, correlate 14 | from dotenv import load_dotenv 15 | 16 | load_dotenv() 17 | 18 | # Logging Configuration 19 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 20 | 21 | # Setup logging 22 | logging.basicConfig(level=getattr(logging, LOG_LEVEL)) 23 | logger = logging.getLogger(__name__) 24 | 25 | # Global flag to enable/disable audio segment extraction for debugging 26 | ENABLE_SEGMENT_EXTRACTION = os.getenv("EXTRACT_AUDIO_SEGMENTS", "False").lower() == "true" 27 | 28 | def analyze_audio_quality(audio_bytes: bytes, sample_rate: int = 24000, 29 | emotion_tag_count: int = 0) -> Tuple[bool, Dict]: 30 | """ 31 | Analyze audio quality using cross-correlation and spectrogram analysis. 32 | Detects: repeated segments, stretched audio, silence, clipping, and energy anomalies. 33 | 34 | Args: 35 | audio_bytes: Raw audio data (16-bit PCM) 36 | sample_rate: Sample rate in Hz (default 24000 for Orpheus) 37 | emotion_tag_count: Number of emotion tags in the text (for adjusting silence thresholds) 38 | 39 | Returns: 40 | Tuple of (has_issues: bool, metrics: dict) 41 | """ 42 | # Convert bytes to numpy array (-1.0 to 1.0 range) 43 | audio_array = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0 44 | 45 | if len(audio_array) == 0: 46 | return True, {"error": "Empty audio"} 47 | 48 | metrics = {} 49 | has_issues = False 50 | 51 | # === 1. 
BASIC STATS (silence, clipping, energy) === 52 | stats = compute_basic_stats(audio_array) 53 | metrics["stats"] = stats 54 | 55 | # Check for silence issues (adjusted for emotion tags) 56 | silence_threshold = 60.0 + (emotion_tag_count * 5.0) # Base 60%, more tolerance with emotion tags 57 | silence_threshold = min(silence_threshold, 75.0) # Cap at 75% so the check stays meaningful 58 | 59 | if stats["silence_percent"] > silence_threshold: 60 | has_issues = True 61 | logger.warning(f"⚠️ Excessive silence: {stats['silence_percent']:.1f}% of audio") 62 | 63 | # Check for clipping issues 64 | if stats["clipped_percent"] > 1.0: 65 | has_issues = True 66 | logger.warning(f"⚠️ Audio clipping detected: {stats['clipped_percent']:.1f}%") 67 | 68 | # Check for very low energy (volume problem) 69 | if stats["rms_energy"] < 0.01: 70 | has_issues = True 71 | logger.warning(f"⚠️ Very low audio energy: {stats['rms_energy']:.4f}") 72 | 73 | # === 2. REPEATED SEGMENTS DETECTION (cross-correlation) === 74 | repeated_segments = detect_repeated_segments(audio_array, sample_rate) 75 | # Require at least 2 repeated segments to flag as issue (reduces false positives) 76 | # Single similar segment can occur naturally in speech (similar phonemes) 77 | has_repetition = len(repeated_segments) >= 2 78 | 79 | metrics["repeated_segments"] = { 80 | "segments": repeated_segments, 81 | "count": len(repeated_segments), 82 | "has_repetition": has_repetition 83 | } 84 | 85 | if has_repetition: 86 | has_issues = True 87 | logger.warning(f"⚠️ Repeated audio segments detected: {len(repeated_segments)} regions") 88 | 89 | # === 3. SUSTAINED STRETCHING DETECTION === 90 | has_sustained_stretching, sustained_metrics = detect_sustained_stretching(audio_array, sample_rate) 91 | metrics["sustained_stretching"] = sustained_metrics 92 | 93 | if has_sustained_stretching: 94 | has_issues = True 95 | logger.warning(f"⚠️ Sustained audio stretching detected: {sustained_metrics['max_sustained_duration']:.2f}s") 96 | 97 | # === 4. MONOTONIC AUDIO DETECTION (pitch variance) === 98 | is_monotonic, pitch_variance = detect_monotonic_audio(audio_array, sample_rate) 99 | metrics["monotonic"] = { 100 | "is_monotonic": is_monotonic, 101 | "pitch_variance": pitch_variance 102 | } 103 | 104 | if is_monotonic: 105 | has_issues = True 106 | logger.warning(f"⚠️ Monotonic audio detected (pitch variance: {pitch_variance:.1f})") 107 | 108 | # Overall assessment 109 | metrics["has_quality_issues"] = has_issues 110 | metrics["duration_seconds"] = len(audio_array) / sample_rate 111 | 112 | return has_issues, metrics 113 | 114 | 115 | def compute_basic_stats(audio: np.ndarray) -> Dict: 116 | """ 117 | Compute basic audio statistics: silence %, clipping %, energy, etc. 118 | 119 | Args: 120 | audio: Audio signal as numpy array (-1.0 to 1.0) 121 | 122 | Returns: 123 | Dictionary with basic audio statistics 124 | """ 125 | return { 126 | "max": float(np.max(audio)), 127 | "min": float(np.min(audio)), 128 | "mean": float(np.mean(audio)), 129 | "rms_energy": float(np.sqrt(np.mean(audio**2))), 130 | "silence_percent": float(np.mean(np.abs(audio) < 0.01)) * 100, 131 | "clipped_percent": float(np.mean(np.abs(audio) > 0.98)) * 100, 132 | } 133 | 134 | 135 | def detect_repeated_segments(audio: np.ndarray, sample_rate: int, 136 | window_ms: int = 300, 137 | similarity_threshold: float = 0.85, 138 | energy_threshold: float = 0.02) -> List[Tuple[float, float]]: 139 | """ 140 | Detect repeated or looped audio segments using normalized cross-correlation. 141 | Uses sliding-window similarity detection.
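    Example (synthetic looped tone, illustrative only; real speech correlates
    far less strongly across adjacent windows):
        >>> import numpy as np
        >>> sr = 24000
        >>> tone = 0.5 * np.sin(2 * np.pi * 220 * np.arange(sr) / sr)
        >>> regions = detect_repeated_segments(np.tile(tone, 2).astype(np.float32), sr)
        >>> len(regions) > 0  # adjacent windows of a pure tone correlate at ~1.0
        True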
142 | 143 | Args: 144 | audio: Audio signal as numpy array (-1.0 to 1.0) 145 | sample_rate: Sample rate in Hz 146 | window_ms: Window size in milliseconds for segment comparison 147 | similarity_threshold: Correlation threshold (0.85 = 85% similar) 148 | energy_threshold: Minimum RMS energy to consider (skips silent segments) 149 | 150 | Returns: 151 | List of (time_start, time_end) tuples for suspicious repeated regions 152 | """ 153 | window_samples = int(sample_rate * window_ms / 1000) 154 | repeated = [] 155 | 156 | for i in range(0, len(audio) - 2 * window_samples, window_samples): 157 | seg1 = audio[i:i + window_samples] 158 | seg2 = audio[i + window_samples:i + 2 * window_samples] 159 | 160 | # Skip if either segment has no variation (pure silence) 161 | if np.std(seg1) == 0 or np.std(seg2) == 0: 162 | continue 163 | 164 | # Skip segments with very low energy (silence/near-silence) 165 | rms1 = np.sqrt(np.mean(seg1 ** 2)) 166 | rms2 = np.sqrt(np.mean(seg2 ** 2)) 167 | if rms1 < energy_threshold or rms2 < energy_threshold: 168 | continue 169 | 170 | # Normalize segments 171 | seg1_norm = (seg1 - np.mean(seg1)) / np.std(seg1) 172 | seg2_norm = (seg2 - np.mean(seg2)) / np.std(seg2) 173 | 174 | # Cross-correlation 175 | corr = correlate(seg1_norm, seg2_norm, mode='valid') 176 | corr /= len(seg1_norm) 177 | sim = np.max(corr) 178 | 179 | if sim > similarity_threshold: 180 | time_start = i / sample_rate 181 | time_end = (i + 2 * window_samples) / sample_rate 182 | repeated.append((round(time_start, 2), round(time_end, 2))) 183 | 184 | return repeated 185 | 186 | 187 | def detect_sustained_stretching(audio: np.ndarray, sample_rate: int, 188 | min_sustained_duration: float = 1.0, 189 | energy_threshold: float = 0.02, 190 | variation_threshold: float = 0.1977) -> Tuple[bool, Dict]: 191 | """ 192 | Detect stretched/prolonged audio (e.g., "iiiiiii", "myyyyyyy", "buuuuut") by finding 193 | long periods of sustained high energy with low variation. 194 | 195 | Stretched phonemes maintain consistent amplitude for unnaturally long periods. 196 | This works for vowels, consonants, or any speech sound that gets abnormally prolonged. 
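    The signal is scanned with 50 ms RMS windows at a 50% hop; any contiguous run
    of windows above energy_threshold that lasts longer than min_sustained_duration
    with an energy coefficient of variation below variation_threshold is flagged.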
197 | 198 | Args: 199 | audio: Audio signal as numpy array (-1.0 to 1.0) 200 | sample_rate: Sample rate in Hz 201 | min_sustained_duration: Minimum duration to flag as stretched (seconds) 202 | energy_threshold: Minimum RMS energy to consider as "active" speech 203 | variation_threshold: Maximum coefficient of variation for sustained regions 204 | 205 | Returns: 206 | Tuple of (has_stretching: bool, metrics: dict) 207 | """ 208 | # Calculate RMS energy in small windows (50ms) 209 | window_size = int(0.05 * sample_rate) 210 | hop_size = window_size // 2 211 | 212 | rms_values = [] 213 | for i in range(0, len(audio) - window_size, hop_size): 214 | window = audio[i:i + window_size] 215 | rms = np.sqrt(np.mean(window ** 2)) 216 | rms_values.append(rms) 217 | 218 | rms_values = np.array(rms_values) 219 | 220 | if len(rms_values) < 20: # Too short to analyze 221 | return False, {"info": "Audio too short for sustained vowel analysis"} 222 | 223 | # Find regions with sustained high energy 224 | is_high_energy = rms_values > energy_threshold 225 | 226 | # Look for long stretches of high energy with low variation 227 | max_sustained_duration = 0.0 228 | max_sustained_variation = 0.0 229 | current_stretch = [] 230 | current_stretch_start_idx = 0 231 | sustained_regions = [] 232 | 233 | for i, high_energy in enumerate(is_high_energy): 234 | if high_energy: 235 | if len(current_stretch) == 0: 236 | current_stretch_start_idx = i # Mark the start 237 | current_stretch.append(rms_values[i]) 238 | else: 239 | # End of a stretch 240 | if len(current_stretch) > 0: 241 | duration_seconds = (len(current_stretch) * hop_size) / sample_rate 242 | 243 | # Check if this stretch has low variation (sustained) 244 | if len(current_stretch) >= 3: # Need at least 3 samples 245 | mean_energy = np.mean(current_stretch) 246 | std_energy = np.std(current_stretch) 247 | cv = std_energy / mean_energy if mean_energy > 0 else 0 248 | 249 | if duration_seconds > min_sustained_duration and cv < variation_threshold: 250 | # Calculate actual time positions in the audio 251 | start_time = (current_stretch_start_idx * hop_size) / sample_rate 252 | end_time = start_time + duration_seconds 253 | 254 | sustained_regions.append({ 255 | 'start_time': round(float(start_time), 2), 256 | 'end_time': round(float(end_time), 2), 257 | 'duration': round(float(duration_seconds), 2), 258 | 'cv': round(float(cv), 3), 259 | 'mean_energy': round(float(mean_energy), 4) 260 | }) 261 | 262 | if duration_seconds > max_sustained_duration: 263 | max_sustained_duration = duration_seconds 264 | max_sustained_variation = cv 265 | 266 | current_stretch = [] 267 | 268 | # Check final stretch 269 | if len(current_stretch) > 0: 270 | duration_seconds = (len(current_stretch) * hop_size) / sample_rate 271 | if len(current_stretch) >= 3: 272 | mean_energy = np.mean(current_stretch) 273 | std_energy = np.std(current_stretch) 274 | cv = std_energy / mean_energy if mean_energy > 0 else 0 275 | 276 | if duration_seconds > min_sustained_duration and cv < variation_threshold: 277 | # Calculate actual time positions in the audio 278 | start_time = (current_stretch_start_idx * hop_size) / sample_rate 279 | end_time = start_time + duration_seconds 280 | 281 | sustained_regions.append({ 282 | 'start_time': round(float(start_time), 2), 283 | 'end_time': round(float(end_time), 2), 284 | 'duration': round(float(duration_seconds), 2), 285 | 'cv': round(float(cv), 3), 286 | 'mean_energy': round(float(mean_energy), 4) 287 | }) 288 | 289 | if duration_seconds > 
max_sustained_duration: 290 | max_sustained_duration = duration_seconds 291 | max_sustained_variation = cv 292 | 293 | has_stretching = len(sustained_regions) > 0 294 | 295 | metrics = { 296 | "has_sustained_stretching": has_stretching, 297 | "sustained_regions_count": len(sustained_regions), 298 | "max_sustained_duration": round(float(max_sustained_duration), 2), 299 | "max_sustained_variation": round(float(max_sustained_variation), 3), 300 | "sustained_regions": sustained_regions # Show all regions with timestamps 301 | } 302 | 303 | return has_stretching, metrics 304 | 305 | 306 | def detect_monotonic_audio(audio: np.ndarray, sample_rate: int, 307 | min_pitch_variance: float = 30.0) -> Tuple[bool, float]: 308 | """ 309 | Detect monotonic/flat TTS artifacts by analyzing pitch stability. 310 | 311 | Only checks for TOO LOW variance (monotonic/robot-like speech). 312 | 313 | Uses spectrogram to compute dominant frequency over time, then measures variance. 314 | 315 | Args: 316 | audio: Audio signal as numpy array (-1.0 to 1.0) 317 | sample_rate: Sample rate in Hz 318 | min_pitch_variance: Minimum acceptable pitch variance (default 30 Hz²) 319 | 320 | Returns: 321 | Tuple of (is_monotonic: bool, pitch_variance: float) 322 | """ 323 | try: 324 | # Compute spectrogram 325 | f, t, Sxx = spectrogram(audio, sample_rate, nperseg=1024) 326 | 327 | # Find dominant frequency at each time step 328 | dominant_freq = f[np.argmax(Sxx, axis=0)] 329 | 330 | # Calculate pitch variance 331 | pitch_variance = float(np.var(dominant_freq)) 332 | 333 | # Only check for low variance (monotonic/flat) 334 | is_monotonic = pitch_variance < min_pitch_variance 335 | 336 | return is_monotonic, round(pitch_variance, 2) 337 | 338 | except Exception as e: 339 | logger.warning(f"Error in monotonic audio detection: {e}") 340 | return False, 0.0 341 | 342 | 343 | def quick_audio_check(audio_chunks: list[bytes], sample_rate: int = 24000, 344 | emotion_tag_count: int = 0) -> Tuple[bool, Dict]: 345 | """ 346 | Quick check of audio quality using cross-correlation and spectrogram analysis. 347 | Detects: repeated segments, stretched audio, silence, clipping. 
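    Example (chunks as yielded by tokens_decoder; illustrative only):
        >>> has_issues, summary = quick_audio_check(chunks)  # doctest: +SKIP
        >>> summary["has_repetition"], summary["duration_seconds"]  # doctest: +SKIP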
348 | 349 | Args: 350 | audio_chunks: List of audio chunk bytes 351 | sample_rate: Sample rate in Hz 352 | emotion_tag_count: Number of emotion tags in the text (affects silence thresholds) 353 | 354 | Returns: 355 | Tuple of (has_issues: bool, summary_metrics: dict) 356 | """ 357 | # Combine chunks 358 | combined_audio = b''.join(audio_chunks) 359 | 360 | # Run full analysis 361 | has_issues, metrics = analyze_audio_quality(combined_audio, sample_rate, emotion_tag_count) 362 | 363 | # Summarize key findings 364 | stats = metrics.get("stats", {}) 365 | repeated = metrics.get("repeated_segments", {}) 366 | sustained = metrics.get("sustained_stretching", {}) 367 | monotonic = metrics.get("monotonic", {}) 368 | 369 | summary = { 370 | "has_quality_issues": has_issues, 371 | "duration_seconds": metrics.get("duration_seconds", 0), 372 | "silence_percent": stats.get("silence_percent", 0), 373 | "clipped_percent": stats.get("clipped_percent", 0), 374 | "rms_energy": stats.get("rms_energy", 0), 375 | "repeated_segment_count": repeated.get("count", 0), 376 | "has_repetition": repeated.get("has_repetition", False), 377 | "has_sustained_stretching": sustained.get("has_sustained_stretching", False), 378 | "max_sustained_duration": sustained.get("max_sustained_duration", 0), 379 | "is_monotonic": monotonic.get("is_monotonic", False), 380 | "pitch_variance": monotonic.get("pitch_variance", 0), 381 | "emotion_tag_count": emotion_tag_count, 382 | } 383 | 384 | return has_issues, summary 385 | 386 | 387 | def extract_audio_segment(input_path: str, start_time: float, end_time: float, output_path: str): 388 | """ 389 | Extract a segment from a WAV file based on timestamps. 390 | 391 | Args: 392 | input_path: Path to input WAV file 393 | start_time: Start time in seconds 394 | end_time: End time in seconds 395 | output_path: Path to output WAV file 396 | """ 397 | try: 398 | # Open input file 399 | with wave.open(input_path, 'rb') as wf_in: 400 | # Get audio parameters 401 | channels = wf_in.getnchannels() 402 | sample_width = wf_in.getsampwidth() 403 | frame_rate = wf_in.getframerate() 404 | 405 | # Calculate frame positions 406 | start_frame = int(start_time * frame_rate) 407 | end_frame = int(end_time * frame_rate) 408 | num_frames = end_frame - start_frame 409 | 410 | # Validate 411 | total_frames = wf_in.getnframes() 412 | if start_frame < 0 or end_frame > total_frames: 413 | logger.warning(f"Invalid time range for extraction: {start_time:.2f}s to {end_time:.2f}s") 414 | return 415 | 416 | # Move to start position 417 | wf_in.setpos(start_frame) 418 | 419 | # Read the segment 420 | audio_data = wf_in.readframes(num_frames) 421 | 422 | # Write to output file 423 | with wave.open(output_path, 'wb') as wf_out: 424 | wf_out.setnchannels(channels) 425 | wf_out.setsampwidth(sample_width) 426 | wf_out.setframerate(frame_rate) 427 | wf_out.writeframes(audio_data) 428 | 429 | logger.info(f"✂️ Extracted segment: {start_time:.2f}s to {end_time:.2f}s → {output_path}") 430 | 431 | except Exception as e: 432 | logger.error(f"Failed to extract segment: {e}") 433 | 434 | 435 | def analyze_audio_file(file_path: str, emotion_tag_count: int = 0) -> Dict: 436 | """ 437 | Analyze an audio file and return comprehensive quality metrics. 
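    Example (path is illustrative):
        >>> metrics = analyze_audio_file("debug_audio_errors/audio.wav")  # doctest: +SKIP
        >>> metrics["has_quality_issues"]  # doctest: +SKIP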
438 | 439 | Args: 440 | file_path: Path to WAV audio file 441 | emotion_tag_count: Number of emotion tags in the original text (optional) 442 | 443 | Returns: 444 | Dictionary with analysis results 445 | """ 446 | # Read WAV file 447 | with wave.open(file_path, 'rb') as wf: 448 | sample_rate = wf.getframerate() 449 | n_frames = wf.getnframes() 450 | audio_bytes = wf.readframes(n_frames) 451 | 452 | # Run analysis 453 | has_issues, metrics = analyze_audio_quality(audio_bytes, sample_rate, emotion_tag_count) 454 | 455 | return metrics 456 | 457 | 458 | def main(): 459 | """ 460 | Command-line interface for analyzing audio files. 461 | Usage: python audio_analysis.py <audio_file.wav> [emotion_tag_count] [--extract] 462 | """ 463 | import sys 464 | import json 465 | 466 | if len(sys.argv) < 2 or '--help' in sys.argv or '-h' in sys.argv: 467 | print("Usage: python audio_analysis.py <audio_file.wav> [emotion_tag_count] [--extract]") 468 | print("\nArguments:") 469 | print(" audio_file.wav Path to audio file to analyze") 470 | print(" emotion_tag_count Number of emotion tags (optional, default: 0)") 471 | print(" --extract Extract problematic segments to files (optional)") 472 | print("\nExamples:") 473 | print(" python audio_analysis.py debug_audio_errors/audio.wav") 474 | print(" python audio_analysis.py debug_audio_errors/audio.wav 2") 475 | print(" python audio_analysis.py debug_audio_errors/audio.wav 0 --extract") 476 | print("\nEnvironment Variables:") 477 | print(" EXTRACT_AUDIO_SEGMENTS=true Enable automatic segment extraction") 478 | sys.exit(1) 479 | 480 | # Parse arguments 481 | file_path = sys.argv[1] 482 | emotion_tag_count = 0 483 | extract_flag = '--extract' in sys.argv 484 | 485 | # Parse emotion_tag_count if provided 486 | for arg in sys.argv[2:]: 487 | if arg != '--extract': 488 | try: 489 | emotion_tag_count = int(arg) 490 | except ValueError: 491 | pass 492 | 493 | # Enable extraction if flag is set OR environment variable is true 494 | global ENABLE_SEGMENT_EXTRACTION 495 | if extract_flag: 496 | ENABLE_SEGMENT_EXTRACTION = True 497 | 498 | print(f"\n{'='*70}") 499 | print(f"🎵 Audio Quality Analysis") 500 | print(f"{'='*70}") 501 | print(f"File: {file_path}") 502 | print(f"Emotion tags: {emotion_tag_count}") 503 | if ENABLE_SEGMENT_EXTRACTION: 504 | print(f"Segment extraction: ✂️ ENABLED") 505 | print(f"{'='*70}\n") 506 | 507 | try: 508 | # Analyze the file 509 | metrics = analyze_audio_file(file_path, emotion_tag_count) 510 | 511 | # Extract key sections 512 | stats = metrics.get("stats", {}) 513 | repeated = metrics.get("repeated_segments", {}) 514 | sustained = metrics.get("sustained_stretching", {}) 515 | monotonic = metrics.get("monotonic", {}) 516 | has_issues = metrics.get("has_quality_issues", False) 517 | 518 | # Display results 519 | print("📊 OVERALL ASSESSMENT") 520 | print(f" Quality Issues: {'❌ YES' if has_issues else '✅ NO'}") 521 | print(f" Duration: {metrics.get('duration_seconds', 0):.2f} seconds") 522 | print() 523 | 524 | print("📈 BASIC STATISTICS") 525 | print(f" RMS Energy: {stats.get('rms_energy', 0):.4f}") 526 | print(f" Silence: {stats.get('silence_percent', 0):.1f}%") 527 | print(f" Clipping: {stats.get('clipped_percent', 0):.1f}%") 528 | print(f" Peak (max): {stats.get('max', 0):.4f}") 529 | print(f" Peak (min): {stats.get('min', 0):.4f}") 530 | print() 531 | 532 | print("🔄 REPETITION DETECTION") 533 | print(f" Has Repetition: {'❌ YES' if repeated.get('has_repetition', False) else '✅ NO'}") 534 | print(f" Repeated Segments: {repeated.get('count', 0)}") 535 | if
repeated.get('segments'): 536 | print(f" Time Ranges:") 537 | for start, end in repeated.get('segments', [])[:5]: # Show first 5 538 | print(f" - {start:.2f}s to {end:.2f}s") 539 | if len(repeated.get('segments', [])) > 5: 540 | print(f" ... and {len(repeated.get('segments', [])) - 5} more") 541 | print() 542 | 543 | print("🎤 SUSTAINED AUDIO STRETCHING DETECTION") 544 | has_sustained = sustained.get('has_sustained_stretching', False) 545 | max_duration = sustained.get('max_sustained_duration', 0) 546 | sustained_count = sustained.get('sustained_regions_count', 0) 547 | sustained_regions_list = sustained.get('sustained_regions', []) 548 | 549 | print(f" Has Stretching: {'❌ YES' if has_sustained else '✅ NO'}") 550 | print(f" Sustained Regions: {sustained_count}") 551 | 552 | if has_sustained: 553 | print(f" Longest Sustained: {max_duration:.2f}s") 554 | print(f" Max Variation: {sustained.get('max_sustained_variation', 0):.3f}") 555 | 556 | if sustained_regions_list: 557 | print(f" Time Ranges:") 558 | for region in sustained_regions_list[:10]: # Show first 10 559 | start = region.get('start_time', 0) 560 | end = region.get('end_time', 0) 561 | dur = region.get('duration', 0) 562 | cv = region.get('cv', 0) 563 | print(f" - {start:.2f}s to {end:.2f}s ({dur:.2f}s, CV: {cv:.3f})") 564 | 565 | if len(sustained_regions_list) > 10: 566 | print(f" ... and {len(sustained_regions_list) - 10} more regions") 567 | 568 | print() 569 | 570 | print("🎵 MONOTONIC AUDIO DETECTION") 571 | is_monotonic_flag = monotonic.get('is_monotonic', False) 572 | pitch_var = monotonic.get('pitch_variance', 0) 573 | print(f" Is Monotonic: {'❌ YES' if is_monotonic_flag else '✅ NO'}") 574 | print(f" Pitch Variance: {pitch_var:.2f} Hz²") 575 | 576 | if is_monotonic_flag: 577 | print(f" Type: Flat/Robot-like (variance too low < 30 Hz²)") 578 | else: 579 | print(f" Normal Range: > 30 Hz²") 580 | print() 581 | 582 | # Summary 583 | print(f"{'='*70}") 584 | if has_issues: 585 | print("⚠️ ISSUES DETECTED - Audio may need regeneration") 586 | issues = [] 587 | if stats.get('silence_percent', 0) > 60: 588 | issues.append(f"Excessive silence ({stats.get('silence_percent', 0):.1f}%)") 589 | if stats.get('clipped_percent', 0) > 1: 590 | issues.append(f"Audio clipping ({stats.get('clipped_percent', 0):.1f}%)") 591 | if stats.get('rms_energy', 0) < 0.01: 592 | issues.append(f"Low energy ({stats.get('rms_energy', 0):.4f})") 593 | if repeated.get('has_repetition', False): 594 | issues.append(f"Repeated segments ({repeated.get('count', 0)})") 595 | if sustained.get('has_sustained_stretching', False): 596 | max_dur = sustained.get('max_sustained_duration', 0) 597 | issues.append(f"Sustained audio stretching ({max_dur:.2f}s)") 598 | if monotonic.get('is_monotonic', False): 599 | pitch_var = monotonic.get('pitch_variance', 0) 600 | issues.append(f"Monotonic audio (variance: {pitch_var:.1f} Hz²)") 601 | 602 | for issue in issues: 603 | print(f" • {issue}") 604 | else: 605 | print("✅ Audio quality is good") 606 | print(f"{'='*70}\n") 607 | 608 | # Print full JSON for debugging 609 | print("📋 FULL METRICS (JSON):") 610 | print(json.dumps(metrics, indent=2)) 611 | 612 | # Extract problematic segments if enabled 613 | if ENABLE_SEGMENT_EXTRACTION and has_issues: 614 | print(f"\n{'='*70}") 615 | print("✂️ EXTRACTING PROBLEMATIC SEGMENTS") 616 | print(f"{'='*70}") 617 | 618 | # Get the directory where the input file is located 619 | input_dir = Path(file_path).parent 620 | input_stem = Path(file_path).stem # filename without extension 621 | 622 | 
extracted_count = 0 623 | 624 | # Extract sustained stretching segments 625 | if sustained.get('has_sustained_stretching', False): 626 | sustained_regions_list = sustained.get('sustained_regions', []) 627 | for i, region in enumerate(sustained_regions_list, 1): 628 | start = region.get('start_time', 0) 629 | end = region.get('end_time', 0) 630 | output_name = f"{input_stem}_sustained_stretch_{i}_{start:.2f}s-{end:.2f}s.wav" 631 | output_path = input_dir / output_name 632 | 633 | extract_audio_segment(file_path, start, end, str(output_path)) 634 | extracted_count += 1 635 | print(f" ✅ Sustained stretch #{i}: {output_path.name}") 636 | 637 | # Extract repeated segments 638 | if repeated.get('has_repetition', False): 639 | repeated_segments_list = repeated.get('segments', []) 640 | for i, (start, end) in enumerate(repeated_segments_list, 1): 641 | output_name = f"{input_stem}_repetition_{i}_{start:.2f}s-{end:.2f}s.wav" 642 | output_path = input_dir / output_name 643 | 644 | extract_audio_segment(file_path, start, end, str(output_path)) 645 | extracted_count += 1 646 | print(f" ✅ Repetition #{i}: {output_path.name}") 647 | 648 | if extracted_count > 0: 649 | print(f"\n✂️ Extracted {extracted_count} segment(s) to: {input_dir}") 650 | else: 651 | print(f"\n⚠️ Issues detected but no extractable segments") 652 | 653 | print(f"{'='*70}") 654 | 655 | except FileNotFoundError: 656 | print(f"❌ Error: File not found: {file_path}") 657 | sys.exit(1) 658 | except Exception as e: 659 | print(f"❌ Error analyzing audio: {e}") 660 | import traceback 661 | traceback.print_exc() 662 | sys.exit(1) 663 | 664 | 665 | if __name__ == "__main__": 666 | main() 667 | 668 | -------------------------------------------------------------------------------- /audio_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import asyncio 3 | import uuid 4 | import os 5 | import wave 6 | import logging 7 | import json 8 | from datetime import datetime 9 | from typing import Optional, Tuple, List, Dict, Literal 10 | from concurrent.futures import ThreadPoolExecutor 11 | 12 | from orpheus_tts import OrpheusModel 13 | from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams 14 | from audio_decoder import ( 15 | tokens_decoder, 16 | AudioDecodingError, 17 | TokenRepetitionError, 18 | AudioDurationOutlierError, 19 | TokenCountOutlierError, 20 | check_token_repetition, 21 | check_token_count_ratio, 22 | check_token_variance, 23 | normalize_and_count_words 24 | ) 25 | from audio_analysis import quick_audio_check 26 | from dotenv import load_dotenv 27 | 28 | load_dotenv() 29 | 30 | # Logging Configuration 31 | LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() 32 | 33 | # Setup logging 34 | logging.basicConfig(level=getattr(logging, LOG_LEVEL)) 35 | logger = logging.getLogger(__name__) 36 | 37 | # Retry configuration 38 | MAX_RETRIES = 5 39 | RETRY_DELAY = 0.1 40 | 41 | # Debug configuration 42 | DEBUG_AUDIO_DIR = "debug_audio_errors" 43 | DEBUG_SUCCESS_DIR = "debug_audio_success" 44 | 45 | # Debug Configuration 46 | ENABLE_DEBUG_SAVING = os.getenv("ENABLE_DEBUG_SAVING", "False").lower() == "true" 47 | ENABLE_SUCCESS_LOGGING = os.getenv("ENABLE_SUCCESS_LOGGING", "False").lower() == "true" 48 | 49 | def is_audio_duration_outlier(text: str, duration_seconds: float) -> Tuple[bool, dict]: 50 | """ 51 | Check if audio duration is an outlier for given text. 52 | Uses multi-tier detection to catch both obvious outliers and moderate slowdowns. 
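The tiers are: an absolute words-per-second bound, a length-graduated deviation check against an expected 3.0 words/sec rate, and a flat grace buffer in seconds.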
53 | Accounts for emotion/sound effect tags that add non-speech audio duration. 54 | 55 | Args: 56 | text: The input text 57 | duration_seconds: The generated audio duration in seconds 58 | 59 | Returns: 60 | Tuple of (is_outlier: bool, metrics: dict) 61 | """ 62 | # Count words using normalized counting (handles hyphens, underscores, etc.) 63 | word_count = normalize_and_count_words(text) 64 | 65 | # Detect emotion/sound effect tags that add extra audio duration 66 | # These tags cause the model to generate vocal sounds/effects beyond just speech 67 | import re 68 | emotion_tags = re.findall(r'<(\w+)>', text) 69 | 70 | # Orpheus-supported emotion/sound tags that add audio duration 71 | # Each tag causes the model to generate vocal sound effects 72 | ORPHEUS_EMOTION_TAGS = [ 73 | 'laugh', # Laughter or laughing sounds 74 | 'chuckle', # Light laughter or chuckling 75 | 'sigh', # Sighing sounds (resignation/relief) 76 | 'cough', # Coughing or throat clearing 77 | 'sniffle', # Sniffling or nasal sounds 78 | 'groan', # Groaning (discomfort/frustration) 79 | 'yawn', # Yawning (tiredness) 80 | 'gasp', # Gasping (surprise/shock) 81 | ] 82 | 83 | # Count emotion tags in the text 84 | tag_count = sum(1 for tag in emotion_tags if tag.lower() in ORPHEUS_EMOTION_TAGS) 85 | 86 | # Add time allowance for emotion tags 87 | # Conservative estimate: 1.2 seconds per tag (covers most vocal effects without being too lenient) 88 | # This prevents false positives while still catching real repetition issues 89 | SECONDS_PER_EMOTION_TAG = 1.2 90 | emotion_tag_allowance = tag_count * SECONDS_PER_EMOTION_TAG 91 | 92 | # === TIER 1: Absolute threshold check (catches extreme outliers) === 93 | # Tightened thresholds based on typical TTS behavior 94 | MIN_WORDS_PER_SECOND = 2.0 # Slow TTS (120 wpm) - was 1.5 (too lenient) 95 | MAX_WORDS_PER_SECOND = 4.5 # Very fast TTS (270 wpm) - was 4.0 96 | 97 | # Expected duration range 98 | min_expected_duration = word_count / MAX_WORDS_PER_SECOND 99 | max_expected_duration = word_count / MIN_WORDS_PER_SECOND 100 | 101 | # No percentage buffer is applied here; a flat 15% leniency factor is applied to the final threshold below 102 | max_expected_duration_absolute = max_expected_duration 103 | 104 | # # For very short texts, be more lenient 105 | # # Single words with punctuation can have natural pauses and emphasis 106 | # if word_count == 1: 107 | # max_expected_duration_absolute *= 5.0 # Very lenient for single words 108 | # elif word_count <= 5: 109 | # max_expected_duration_absolute *= 3.0 # Lenient for short texts (1-5 words) 110 | # elif word_count <= 10: 111 | # max_expected_duration_absolute *= 2.2 # Moderately lenient for short phrases (6-10 words) 112 | 113 | # === TIER 2: Expected speech rate check (catches moderate slowdowns) === 114 | # Use typical TTS speech rate as baseline 115 | # Based on observed data, this model tends to generate faster speech (3+ wps) 116 | EXPECTED_WORDS_PER_SECOND = 3.0 # Adjusted for this model's typical output (180 wpm) 117 | expected_duration = word_count / EXPECTED_WORDS_PER_SECOND 118 | 119 | # Graduated deviation thresholds based on text length 120 | # Shorter texts have more natural variation in pacing 121 | if word_count <= 20: 122 | # Very short texts: use absolute threshold only (most lenient) 123 | max_expected_duration_final = max_expected_duration_absolute 124 | detection_method = "absolute" 125 | elif word_count <= 60: 126 | # Medium texts (21-60 words): allow more deviation due to higher natural pacing variation 127 | MAX_DEVIATION_PERCENT = 35 # More lenient for
medium texts 128 | max_expected_duration_relative = expected_duration * (1 + MAX_DEVIATION_PERCENT / 100) 129 | max_expected_duration_final = min(max_expected_duration_absolute, max_expected_duration_relative) 130 | detection_method = "absolute" if max_expected_duration_absolute < max_expected_duration_relative else "relative" 131 | elif word_count <= 100: 132 | # Longer texts (61-100 words): moderate deviation 133 | MAX_DEVIATION_PERCENT = 22 # Moderate 134 | max_expected_duration_relative = expected_duration * (1 + MAX_DEVIATION_PERCENT / 100) 135 | max_expected_duration_final = min(max_expected_duration_absolute, max_expected_duration_relative) 136 | detection_method = "absolute" if max_expected_duration_absolute < max_expected_duration_relative else "relative" 137 | else: 138 | # Very long texts (>100 words): strict detection 139 | MAX_DEVIATION_PERCENT = 17 # Strict for catching edge cases like the 103-word, 40s sample 140 | max_expected_duration_relative = expected_duration * (1 + MAX_DEVIATION_PERCENT / 100) 141 | max_expected_duration_final = min(max_expected_duration_absolute, max_expected_duration_relative) 142 | detection_method = "absolute" if max_expected_duration_absolute < max_expected_duration_relative else "relative" 143 | 144 | # Calculate metrics 145 | words_per_second = word_count / duration_seconds if duration_seconds > 0 else 0 146 | deviation_from_expected = ((duration_seconds - expected_duration) / expected_duration * 100) if expected_duration > 0 else 0 147 | 148 | # === TIER 3: Grace buffer (absolute seconds, not percentage) === 149 | # Add a flat grace buffer to account for natural TTS variation 150 | # This prevents flagging samples that are only marginally over the percentage threshold 151 | # Graduated by text length to be appropriate for each range 152 | if word_count <= 20: 153 | GRACE_BUFFER_SECONDS = 1.0 # Conservative for short texts (prevents missing real slowdowns) 154 | elif word_count <= 60: 155 | GRACE_BUFFER_SECONDS = 2.5 # Medium buffer for medium texts (21-60 words) - increased to match 61-100 range 156 | elif word_count <= 100: 157 | GRACE_BUFFER_SECONDS = 2.5 # Need more buffer for 61-100 word range due to natural variation 158 | else: 159 | GRACE_BUFFER_SECONDS = 0.0 # No grace buffer for >100 words (strict edge case detection) 160 | 161 | # Apply grace buffer and emotion tag allowance 162 | adjusted_threshold = max_expected_duration_final + emotion_tag_allowance + GRACE_BUFFER_SECONDS 163 | adjusted_threshold = 1.15 * adjusted_threshold # Flat 15% leniency on top of all tiers to reduce false positives 164 | 165 | # Determine if this is an outlier (after accounting for emotion tags and grace buffer) 166 | is_outlier = duration_seconds > adjusted_threshold 167 | 168 | metrics = { 169 | "word_count": word_count, 170 | "duration_seconds": round(duration_seconds, 2), 171 | "words_per_second": round(words_per_second, 2), 172 | "min_expected_duration": round(min_expected_duration, 2), 173 | "max_expected_duration": round(adjusted_threshold, 2), 174 | "max_expected_duration_base": round(max_expected_duration_final, 2), 175 | "expected_duration_baseline": round(expected_duration, 2), 176 | "max_expected_absolute": round(max_expected_duration_absolute, 2), 177 | "max_expected_relative": round(max_expected_duration_relative, 2) if word_count > 20 else None, 178 | "detection_method": detection_method, 179 | "grace_buffer_seconds": GRACE_BUFFER_SECONDS, 180 | "emotion_tag_count": tag_count, 181 | "emotion_tag_allowance": round(emotion_tag_allowance, 2), 182 | "emotion_tags_detected": emotion_tags if tag_count > 0 else
[], 183 | "is_duration_outlier": is_outlier, 184 | "duration_deviation_percent": round(((duration_seconds - adjusted_threshold) / adjusted_threshold * 100), 2) if adjusted_threshold > 0 else 0, 185 | "deviation_from_expected_percent": round(deviation_from_expected, 2) 186 | } 187 | 188 | if is_outlier: 189 | emotion_tag_info = f", +{emotion_tag_allowance:.1f}s emotion" if tag_count > 0 else "" 190 | logger.warning( 191 | f"⚠️ Audio duration outlier detected ({detection_method}): " 192 | f"{duration_seconds:.2f}s for {word_count} words " 193 | f"(expected: ~{expected_duration:.2f}s, max: {max_expected_duration_final:.2f}s +{GRACE_BUFFER_SECONDS:.1f}s grace{emotion_tag_info}, " 194 | f"actual: {words_per_second:.2f} words/sec, deviation: {deviation_from_expected:+.1f}%)" 195 | ) 196 | 197 | return is_outlier, metrics 198 | 199 | def convert_numpy_types(obj): 200 | """Convert numpy types to native Python types for JSON serialization""" 201 | import numpy as np 202 | 203 | if isinstance(obj, np.integer): 204 | return int(obj) 205 | elif isinstance(obj, np.floating): 206 | return float(obj) 207 | elif isinstance(obj, np.ndarray): 208 | return obj.tolist() 209 | elif isinstance(obj, dict): 210 | return {key: convert_numpy_types(value) for key, value in obj.items()} 211 | elif isinstance(obj, list): 212 | return [convert_numpy_types(item) for item in obj] 213 | else: 214 | return obj 215 | 216 | async def save_debug_audio_with_metadata(audio_chunks: list[bytes], text: str, error_type: str, 217 | request_id: str, metadata: dict, executor: ThreadPoolExecutor) -> str: 218 | """ 219 | Save problematic audio file and detailed JSON metadata for debugging. 220 | 221 | Args: 222 | audio_chunks: The audio data chunks 223 | text: The input text 224 | error_type: Type of error/edge case detected 225 | request_id: Unique request identifier 226 | metadata: Additional metadata to save (should include 'attempts' field) 227 | executor: Thread pool executor for file I/O 228 | 229 | Returns: 230 | Path to the saved debug directory 231 | """ 232 | if not ENABLE_DEBUG_SAVING: 233 | logger.debug("Debug saving is disabled") 234 | return "" 235 | 236 | try: 237 | # Get attempt number from metadata 238 | attempt_num = metadata.get("attempts", 1) 239 | 240 | # Create debug directory with error type, attempt, timestamp, and request ID for easy browsing 241 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 242 | debug_dir = os.path.join(DEBUG_AUDIO_DIR, f"{error_type}_attempt{attempt_num}_{timestamp}_{request_id}") 243 | os.makedirs(debug_dir, exist_ok=True) 244 | 245 | # Save audio file 246 | audio_path = os.path.join(debug_dir, "audio.wav") 247 | 248 | def write_debug_audio(): 249 | with wave.open(audio_path, "wb") as wf: 250 | wf.setnchannels(1) # mono 251 | wf.setsampwidth(2) # 16-bit 252 | wf.setframerate(24000) # 24kHz 253 | 254 | for chunk in audio_chunks: 255 | if chunk: 256 | wf.writeframes(chunk) 257 | 258 | # Write audio in executor 259 | await asyncio.get_running_loop().run_in_executor(executor, write_debug_audio) 260 | 261 | # Prepare comprehensive metadata with retry info at top level for easy viewing 262 | # Convert numpy types to native Python types for JSON serialization 263 | debug_metadata = { 264 | "timestamp": timestamp, 265 | "request_id": request_id, 266 | "error_type": error_type, 267 | "attempt_number": attempt_num, 268 | "retry_number": attempt_num - 1, # 0-indexed retry count 269 | "input_text": text, 270 | "text_length": len(text), 271 | "word_count": normalize_and_count_words(text), 272 | 
"audio_file": "audio.wav", 273 | "full_metadata": convert_numpy_types(metadata) 274 | } 275 | 276 | # Save metadata as JSON 277 | metadata_path = os.path.join(debug_dir, "metadata.json") 278 | 279 | def write_metadata(): 280 | with open(metadata_path, "w", encoding="utf-8") as f: 281 | json.dump(debug_metadata, f, indent=2, ensure_ascii=False) 282 | 283 | # Write metadata in executor 284 | await asyncio.get_running_loop().run_in_executor(executor, write_metadata) 285 | 286 | logger.info(f"🔍 Debug files saved to: {debug_dir}") 287 | return debug_dir 288 | 289 | except Exception as e: 290 | logger.error(f"Failed to save debug files: {e}") 291 | return "" 292 | 293 | async def save_successful_audio_with_metadata(audio_chunks: list[bytes], text: str, 294 | request_id: str, metadata: dict, 295 | tokens: list[str], executor: ThreadPoolExecutor) -> str: 296 | """ 297 | Save successful audio generation with complete metadata for debugging and fine-tuning. 298 | 299 | This function saves all successful requests (not just errors) with comprehensive metadata 300 | including tokens, detection metrics, and parameters. This allows for later debugging and 301 | fine-tuning of detection parameters (repetition, duration, stretching, etc.). 302 | 303 | Args: 304 | audio_chunks: The audio data chunks 305 | text: The input text 306 | request_id: Unique request identifier 307 | metadata: Complete metadata from generation (includes all checks and metrics) 308 | tokens: The generated tokens (for reproducibility) 309 | executor: Thread pool executor for file I/O 310 | 311 | Returns: 312 | Path to the saved debug directory 313 | """ 314 | if not ENABLE_SUCCESS_LOGGING: 315 | logger.debug("Success logging is disabled") 316 | return "" 317 | 318 | try: 319 | # Calculate audio duration 320 | total_audio_frames = sum(len(chunk) // 2 for chunk in audio_chunks if chunk) # 16-bit audio 321 | duration = total_audio_frames / 24000 # 24kHz sample rate 322 | 323 | # Create directory with duration and request_id as specified by user 324 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 325 | duration_str = f"{duration:.2f}s".replace(".", "_") 326 | debug_dir = os.path.join(DEBUG_SUCCESS_DIR, f"{duration_str}_{request_id}") 327 | os.makedirs(debug_dir, exist_ok=True) 328 | 329 | # Save audio file with duration in filename as specified 330 | duration_str = f"{duration:.2f}s".replace(".", "_") 331 | audio_filename = f"{duration_str}_{request_id}.wav" 332 | audio_path = os.path.join(debug_dir, audio_filename) 333 | 334 | def write_debug_audio(): 335 | with wave.open(audio_path, "wb") as wf: 336 | wf.setnchannels(1) # mono 337 | wf.setsampwidth(2) # 16-bit 338 | wf.setframerate(24000) # 24kHz 339 | 340 | for chunk in audio_chunks: 341 | if chunk: 342 | wf.writeframes(chunk) 343 | 344 | # Write audio in executor 345 | await asyncio.get_running_loop().run_in_executor(executor, write_debug_audio) 346 | 347 | # Prepare comprehensive metadata with ALL information needed for debugging 348 | # Convert numpy types to native Python types for JSON serialization 349 | debug_metadata = { 350 | "timestamp": timestamp, 351 | "request_id": request_id, 352 | "status": "success", 353 | "duration_seconds": round(duration, 2), 354 | "input_text": text, 355 | "text_length": len(text), 356 | "word_count": normalize_and_count_words(text), 357 | "audio_file": audio_filename, 358 | "token_count": len(tokens), 359 | "tokens": tokens, # Save tokens for reproducibility 360 | "attempts": metadata.get("attempts", 1), 361 | "retries": 
metadata.get("retries", 0), 362 | 363 | # Detection results - critical for fine-tuning parameters 364 | "detection_results": { 365 | "token_repetition": { 366 | "passed": True, # Since this is a successful request 367 | "metrics": metadata.get("checks", {}).get("token_count", {}) 368 | }, 369 | "token_variance": metadata.get("checks", {}).get("token_variance", {}), 370 | "audio_duration": { 371 | "passed": not metadata.get("checks", {}).get("audio_duration", {}).get("is_duration_outlier", False), 372 | "metrics": metadata.get("checks", {}).get("audio_duration", {}) 373 | }, 374 | "audio_quality": { 375 | "passed": True, # Since this is a successful request 376 | "metrics": metadata.get("checks", {}).get("audio_quality", {}) 377 | } 378 | }, 379 | 380 | # Complete metadata for full context 381 | "full_metadata": convert_numpy_types(metadata) 382 | } 383 | 384 | # Save metadata as JSON 385 | metadata_path = os.path.join(debug_dir, "metadata.json") 386 | 387 | def write_metadata(): 388 | with open(metadata_path, "w", encoding="utf-8") as f: 389 | json.dump(debug_metadata, f, indent=2, ensure_ascii=False) 390 | 391 | # Write metadata in executor 392 | await asyncio.get_running_loop().run_in_executor(executor, write_metadata) 393 | 394 | # Also save tokens as a separate text file for easy inspection 395 | tokens_path = os.path.join(debug_dir, "tokens.txt") 396 | 397 | def write_tokens(): 398 | with open(tokens_path, "w", encoding="utf-8") as f: 399 | for i, token in enumerate(tokens): 400 | f.write(f"{i}: {repr(token)}\n") 401 | 402 | # Write tokens in executor 403 | await asyncio.get_running_loop().run_in_executor(executor, write_tokens) 404 | 405 | logger.info(f"✅ Success log saved to: {debug_dir}") 406 | return debug_dir 407 | 408 | except Exception as e: 409 | logger.error(f"Failed to save success log: {e}") 410 | return "" 411 | 412 | class OrpheusModelExtended(OrpheusModel): 413 | """Extended OrpheusModel with additional vLLM parameters""" 414 | 415 | def __init__(self, model_name, dtype=torch.bfloat16, max_model_len=2048, tensor_parallel_size=1, gpu_memory_utilization=0.9, max_num_seqs=64, enable_chunked_prefill=True, enable_prefix_caching=True): 416 | # Store additional parameters 417 | self.max_model_len = max_model_len 418 | self.tensor_parallel_size = tensor_parallel_size 419 | self.gpu_memory_utilization = gpu_memory_utilization 420 | self.enable_chunked_prefill = enable_chunked_prefill 421 | self.enable_prefix_caching = enable_prefix_caching 422 | self.max_num_seqs = max_num_seqs 423 | # Call parent constructor with original parameters 424 | super().__init__(model_name, dtype) 425 | 426 | def _setup_engine(self): 427 | """Override to include additional vLLM parameters""" 428 | # Map torch dtype to vLLM ModelDType literals 429 | vllm_dtype: Literal["auto", "half", "float16", "bfloat16", "float", "float32"] 430 | if self.dtype == torch.bfloat16: 431 | vllm_dtype = "bfloat16" 432 | elif self.dtype == torch.float16: 433 | vllm_dtype = "float16" 434 | elif self.dtype == torch.float32: 435 | vllm_dtype = "float32" 436 | else: 437 | vllm_dtype = "bfloat16" # default fallback 438 | 439 | engine_args = AsyncEngineArgs( 440 | enforce_eager=False, 441 | model=self.model_name, 442 | dtype=vllm_dtype, 443 | max_model_len=self.max_model_len, 444 | tensor_parallel_size=self.tensor_parallel_size, 445 | gpu_memory_utilization=self.gpu_memory_utilization, 446 | enable_chunked_prefill=self.enable_chunked_prefill, 447 | enable_prefix_caching=self.enable_prefix_caching, 448 | 
412 | class OrpheusModelExtended(OrpheusModel):
413 | """Extended OrpheusModel with additional vLLM parameters"""
414 |
415 | def __init__(self, model_name, dtype=torch.bfloat16, max_model_len=2048, tensor_parallel_size=1, gpu_memory_utilization=0.9, max_num_seqs=64, enable_chunked_prefill=True, enable_prefix_caching=True):
416 | # Store additional parameters first (the parent constructor invokes _setup_engine(), which reads them)
417 | self.max_model_len = max_model_len
418 | self.tensor_parallel_size = tensor_parallel_size
419 | self.gpu_memory_utilization = gpu_memory_utilization
420 | self.enable_chunked_prefill = enable_chunked_prefill
421 | self.enable_prefix_caching = enable_prefix_caching
422 | self.max_num_seqs = max_num_seqs
423 | # Call parent constructor with original parameters
424 | super().__init__(model_name, dtype)
425 |
426 | def _setup_engine(self):
427 | """Override to include additional vLLM parameters"""
428 | # Map torch dtype to vLLM ModelDType literals
429 | vllm_dtype: Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
430 | if self.dtype == torch.bfloat16:
431 | vllm_dtype = "bfloat16"
432 | elif self.dtype == torch.float16:
433 | vllm_dtype = "float16"
434 | elif self.dtype == torch.float32:
435 | vllm_dtype = "float32"
436 | else:
437 | vllm_dtype = "bfloat16" # default fallback
438 |
439 | engine_args = AsyncEngineArgs(
440 | enforce_eager=False,
441 | model=self.model_name,
442 | dtype=vllm_dtype,
443 | max_model_len=self.max_model_len,
444 | tensor_parallel_size=self.tensor_parallel_size,
445 | gpu_memory_utilization=self.gpu_memory_utilization,
446 | enable_chunked_prefill=self.enable_chunked_prefill,
447 | enable_prefix_caching=self.enable_prefix_caching,
448 | max_num_seqs=self.max_num_seqs
449 | )
450 | return AsyncLLMEngine.from_engine_args(engine_args)
451 | 
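# Illustrative construction (model name and sizing are assumptions, not values from this repo):
#
#   model = OrpheusModelExtended(
#       model_name="canopylabs/orpheus-3b-0.1-ft",  # hypothetical checkpoint
#       dtype=torch.bfloat16,
#       max_model_len=2048,
#       gpu_memory_utilization=0.9,
#       max_num_seqs=64,
#   )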
452 | async def generate_tokens_async(engine, prompt: str, voice: str, request_id: str,
453 | temperature: Optional[float] = None, top_p: Optional[float] = None,
454 | repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None,
455 | default_temperature: float = 0.2, default_top_p: float = 0.9,
456 | default_repetition_penalty: float = 1.1, default_max_tokens: int = 4096) -> list[str]:
457 | """Generate tokens using the async vLLM engine directly; 'engine' is the OrpheusModelExtended instance, which exposes .tokeniser and .engine"""
458 |
459 | # Format prompt using the same logic as OrpheusModel, framing it with Orpheus control tokens
460 | adapted_prompt = f"{voice}: {prompt}"
461 | prompt_tokens = engine.tokeniser(adapted_prompt, return_tensors="pt")
462 | start_token = torch.tensor([[128259]], dtype=torch.int64)
463 | end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
464 | all_input_ids = torch.cat([start_token, prompt_tokens.input_ids, end_tokens], dim=1)
465 | prompt_string = engine.tokeniser.decode(all_input_ids[0])
466 |
467 | # Use provided parameters or fall back to defaults
468 | temperature = temperature if temperature is not None else default_temperature
469 | top_p = top_p if top_p is not None else default_top_p
470 | repetition_penalty = repetition_penalty if repetition_penalty is not None else default_repetition_penalty
471 | max_tokens = max_tokens if max_tokens is not None else default_max_tokens
472 |
473 | # Set up sampling parameters with configurable values
474 | sampling_params = SamplingParams(
475 | temperature=temperature,
476 | top_p=top_p,
477 | max_tokens=max_tokens,
478 | stop_token_ids=[128258],
479 | repetition_penalty=repetition_penalty,
480 | )
481 |
482 | logger.debug(f"Using sampling params - temp: {temperature}, top_p: {top_p}, "
483 | f"max_tokens: {max_tokens}, rep_penalty: {repetition_penalty}")
484 |
485 | # Generate tokens using async engine
486 | tokens = []
487 | async for result in engine.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
488 | tokens.append(result.outputs[0].text)
489 |
490 | return tokens
491 |
492 | async def generate_speech_tokens_with_retry(engine, prompt: str, voice: str, request_id: str,
493 | temperature: Optional[float] = None, top_p: Optional[float] = None,
494 | repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None,
495 | default_temperature: float = 0.2, default_top_p: float = 0.9,
496 | default_repetition_penalty: float = 1.1, default_max_tokens: int = 4096,
497 | executor: Optional[ThreadPoolExecutor] = None) -> tuple[list[bytes], dict]:
498 | """Generate speech tokens with retry logic for audio decoding errors"""
499 |
500 | last_error = None
501 | current_repetition_penalty = repetition_penalty # Track current penalty for adjustments
502 | current_temperature = temperature if temperature is not None else default_temperature
503 | effective_max_tokens = max_tokens if max_tokens is not None else default_max_tokens
504 |
505 | for attempt in range(MAX_RETRIES):
506 | token_repetition_error = None
507 | token_count_error = None
508 | audio_chunks = []
509 | tokens = []
510 |
511 | try:
512 | # Generate tokens using async engine
513 | tokens = await generate_tokens_async(engine, prompt, voice, request_id,
514 | current_temperature, top_p, current_repetition_penalty, max_tokens,
515 | default_temperature, default_top_p, default_repetition_penalty, default_max_tokens)
516 |
517 | # === STAGE 1: Token-level checks (BEFORE audio generation - lightweight) ===
518 |
519 | # Check 1: Token repetition patterns (catch but don't raise yet)
520 | try:
521 | check_token_repetition(tokens, effective_max_tokens)
522 | except TokenRepetitionError as e:
523 | token_repetition_error = e
524 | logger.warning(f"⚠️ Token repetition detected, will generate audio for debugging: {e}")
525 |
526 | # Check 2: Token count vs text length ratio (catch but don't raise yet)
527 | try:
528 | token_count_metrics = check_token_count_ratio(tokens, prompt, effective_max_tokens)
529 | except TokenCountOutlierError as e:
530 | token_count_error = e
531 | token_count_metrics = {
532 | "token_count": len(tokens),
533 | "word_count": normalize_and_count_words(prompt),
534 | "is_ratio_outlier": True,
535 | "error": str(e)
536 | }
537 | logger.warning(f"⚠️ Token count outlier detected, will generate audio for debugging: {e}")
538 |
539 | # Check 3: Token variance analysis (always runs, returns metrics)
540 | token_variance_metrics = check_token_variance(tokens, window_size=100)
541 |
542 | # === STAGE 2: Audio generation ===
543 |
544 | # Try to decode tokens to audio chunks (even if token checks failed, for debugging)
545 | async for audio_chunk in tokens_decoder(tokens):
546 | audio_chunks.append(audio_chunk)
547 |
548 | # If we get here, decoding was successful
549 | chunk_count = len(audio_chunks)
550 | total_audio_frames = sum(len(chunk) // 2 for chunk in audio_chunks if chunk) # 16-bit audio
551 | duration = total_audio_frames / 24000 # 24kHz sample rate
552 |
553 | # === STAGE 3: Audio-level checks (AFTER audio generation) ===
554 |
555 | # Check 4: Audio duration outlier detection
556 | is_outlier, duration_metrics = is_audio_duration_outlier(prompt, duration)
557 |
558 | # Check 5: Advanced audio quality analysis (if available and enabled)
559 | audio_quality_issues = False
560 | audio_quality_metrics = {}
561 |
562 | # Skip audio quality checks for extremely short texts (≤3 words)
563 | # These naturally have high silence ratios due to lead-in/out
564 | word_count_for_check = normalize_and_count_words(prompt)
565 | skip_quality_check = word_count_for_check <= 3
566 |
567 | if not skip_quality_check:
568 | try:
569 | # Count emotion tags for audio quality analysis
570 | import re
571 | emotion_tags_found = re.findall(r'<(\w+)>', prompt)
572 | ORPHEUS_EMOTION_TAGS = ['laugh', 'chuckle', 'sigh', 'cough', 'sniffle', 'groan', 'yawn', 'gasp']
573 | emotion_tag_count = sum(1 for tag in emotion_tags_found if tag.lower() in ORPHEUS_EMOTION_TAGS)
574 |
575 | audio_quality_issues, audio_quality_metrics = quick_audio_check(audio_chunks, sample_rate=24000, emotion_tag_count=emotion_tag_count)
576 | if audio_quality_issues:
577 | logger.warning(
578 | f"⚠️ Audio quality issues detected for request_id: {request_id} - "
579 | f"Silence: {audio_quality_metrics.get('silence_percent', 0):.1f}%, "
580 | f"Repetition: {audio_quality_metrics.get('has_repetition', False)}, "
581 | f"Stretched: {audio_quality_metrics.get('is_stretched', False)}"
582 | )
583 | except Exception as e:
584 | logger.debug(f"Audio quality analysis failed: {e}")
585 | audio_quality_metrics = {"error": str(e)}
586 | else:
587 | audio_quality_metrics = {"skipped": "text too short (≤3 words)", "word_count": word_count_for_check}
588 |
589 | # Prepare comprehensive metadata
590 | metadata = {
591 | "request_id": request_id,
592 | "chunk_count": chunk_count,
593 | "total_audio_frames": total_audio_frames,
594 | "duration_seconds": round(duration, 2),
595 | "voice": voice, 
596 | "prompt_length": len(prompt),
597 | "attempts": attempt + 1,
598 | "retries": attempt,
599 | "checks": {
600 | "token_count": token_count_metrics,
601 | "token_variance": token_variance_metrics,
602 | "audio_duration": duration_metrics,
603 | "audio_quality": audio_quality_metrics if audio_quality_metrics else None
604 | }
605 | }
606 |
607 | # Now handle token-level errors (after we have audio for debugging)
608 | if token_repetition_error is not None:
609 | logger.error(f"🚨 Token repetition error for request_id: {request_id}")
610 |
611 | # Add error details to metadata
612 | metadata["error"] = str(token_repetition_error)
613 | metadata["token_count"] = len(tokens)
614 |
615 | # Save debug files if executor is available
616 | if executor is not None:
617 | await save_debug_audio_with_metadata(
618 | audio_chunks, prompt, "token_repetition", request_id, metadata, executor
619 | )
620 |
621 | # Raise the error to trigger retry
622 | raise token_repetition_error
623 |
624 | if token_count_error is not None:
625 | logger.error(f"🚨 Token count error for request_id: {request_id}")
626 |
627 | # Add error details to metadata
628 | metadata["error"] = str(token_count_error)
629 | metadata["token_count"] = len(tokens)
630 |
631 | # Save debug files if executor is available
632 | if executor is not None:
633 | await save_debug_audio_with_metadata(
634 | audio_chunks, prompt, "token_count_outlier", request_id, metadata, executor
635 | )
636 |
637 | # Raise the error to trigger retry
638 | raise token_count_error
639 |
640 | # If duration is an outlier OR audio quality issues detected, save debug files and retry
641 | if is_outlier or audio_quality_issues:
642 | # Determine error type and message
643 | has_low_variance = token_variance_metrics.get("is_low_variance", False)
644 |
645 | if is_outlier and audio_quality_issues:
646 | error_type = "duration_outlier_with_quality_issues"
647 | error_msg = (
648 | f"Audio duration outlier with quality issues: {duration:.2f}s for {duration_metrics['word_count']} words "
649 | f"(expected max: {duration_metrics['max_expected_duration']:.2f}s, "
650 | f"silence: {audio_quality_metrics.get('silence_percent', 0):.1f}%, "
651 | f"repetition: {audio_quality_metrics.get('has_repetition', False)}, "
652 | f"stretched: {audio_quality_metrics.get('is_stretched', False)})"
653 | )
654 | logger.error(f"🚨 {error_msg} for request_id: {request_id}")
655 | elif is_outlier and has_low_variance:
656 | error_type = "duration_outlier_with_low_variance"
657 | error_msg = (
658 | f"Audio duration outlier with low token variance: {duration:.2f}s for {duration_metrics['word_count']} words "
659 | f"(expected max: {duration_metrics['max_expected_duration']:.2f}s, "
660 | f"CV: {token_variance_metrics.get('coefficient_of_variation', 'N/A')}, "
661 | f"unique ratio: {token_variance_metrics.get('unique_tokens_ratio', 'N/A')})"
662 | )
663 | logger.error(f"🚨 {error_msg} for request_id: {request_id}")
664 | elif is_outlier:
665 | error_type = "duration_outlier"
666 | error_msg = (
667 | f"Audio duration outlier: {duration:.2f}s for {duration_metrics['word_count']} words "
668 | f"(expected max: {duration_metrics['max_expected_duration']:.2f}s, "
669 | f"detection method: {duration_metrics.get('detection_method', 'unknown')})"
670 | )
671 | logger.error(f"🚨 {error_msg} for request_id: {request_id}")
672 | else: # audio_quality_issues only
673 | error_type = "audio_quality_issues"
674 | error_msg = (
675 | f"Audio quality issues detected: {duration:.2f}s for {duration_metrics['word_count']} words "
676 | f"(silence: {audio_quality_metrics.get('silence_percent', 0):.1f}%, "
677 | f"repetition: {audio_quality_metrics.get('has_repetition', False)}, "
678 | f"repeated_segments: {audio_quality_metrics.get('repeated_segment_count', 0)}, "
679 | f"stretched: {audio_quality_metrics.get('is_stretched', False)})"
680 | )
681 | logger.error(f"🚨 {error_msg} for request_id: {request_id}")
682 |
683 | # Add error details to metadata
684 | metadata["token_count"] = len(tokens)
685 | metadata["error"] = error_msg
686 | metadata["has_low_variance"] = has_low_variance
687 | metadata["has_audio_quality_issues"] = audio_quality_issues
688 |
689 | # Save debug files if executor is available
690 | if executor is not None:
691 | await save_debug_audio_with_metadata(
692 | audio_chunks, prompt, error_type, request_id, metadata, executor
693 | )
694 |
695 | # Raise a duration-outlier error to trigger retry (quality-only issues reuse the same retry path)
696 | raise AudioDurationOutlierError(error_msg)
702 |
703 | # Check for warnings in variance or token count (log but don't retry)
704 | has_warnings = (
705 | token_variance_metrics.get("is_low_variance", False) or
706 | token_count_metrics.get("is_near_limit", False)
707 | )
708 |
709 | if has_warnings and executor is not None:
710 | logger.warning(f"⚠️ Edge case warnings detected for request_id: {request_id}")
711 | # Save debug files for analysis even if not retrying
712 | await save_debug_audio_with_metadata(
713 | audio_chunks, prompt, "edge_case_warning", request_id, metadata, executor
714 | )
715 |
716 | # Save successful requests for debugging and fine-tuning detection parameters
717 | if executor is not None:
718 | await save_successful_audio_with_metadata(
719 | audio_chunks, prompt, request_id, metadata, tokens, executor
720 | )
721 |
722 | if attempt > 0:
723 | logger.info(f"✅ Token generation successful after {attempt + 1} attempts for request_id: {request_id}")
724 |
725 | return audio_chunks, metadata
726 |
727 | except TokenRepetitionError as e:
728 | last_error = e
729 | logger.error(f"🔄 Token repetition detected on attempt {attempt + 1}/{MAX_RETRIES} for request_id: {request_id}: {e}")
730 |
731 | if attempt < MAX_RETRIES - 1: # Not the last attempt
732 | # For repetition errors, adjust sampling parameters
733 | logger.info(f"⏳ Retrying with adjusted parameters in {RETRY_DELAY} seconds...")
734 |
735 | # Increase repetition penalty for retry attempts
736 | if current_repetition_penalty is None:
737 | current_repetition_penalty = default_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) # Increase by 15% per retry
738 | current_temperature = default_temperature * (1.0 + 0.15 * (attempt + 1))
739 | else:
740 | current_repetition_penalty = current_repetition_penalty * (1.0 + 0.15 * (attempt + 1))
741 | current_temperature = current_temperature * (1.0 + 0.15 * (attempt + 1))
742 | logger.info(f"🔧 Adjusted repetition penalty to {current_repetition_penalty:.2f} and temperature to {current_temperature:.2f} for retry")
743 | await asyncio.sleep(RETRY_DELAY)
744 | else:
745 | logger.error(f"❌ All {MAX_RETRIES} attempts failed due to token repetition for request_id: {request_id}")
746 |
747 | except TokenCountOutlierError as e:
748 | last_error = e
749 | logger.error(f"🔄 Token count outlier on attempt {attempt + 
1}/{MAX_RETRIES} for request_id: {request_id}: {e}") 750 | 751 | if attempt < MAX_RETRIES - 1: # Not the last attempt 752 | # For token count outliers, adjust sampling parameters 753 | logger.info(f"⏳ Retrying with adjusted parameters in {RETRY_DELAY} seconds...") 754 | 755 | # Increase repetition penalty for retry attempts 756 | if current_repetition_penalty is None: 757 | current_repetition_penalty = default_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) # Increase by 15% per retry 758 | current_temperature = default_temperature * (1.0 + 0.15 * (attempt + 1)) 759 | else: 760 | current_repetition_penalty = current_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) 761 | current_temperature = current_temperature * (1.0 + 0.15 * (attempt + 1)) 762 | 763 | logger.info(f"🔧 Adjusted repetition penalty to {current_repetition_penalty:.2f} and temperature to {current_temperature:.2f} for retry") 764 | await asyncio.sleep(RETRY_DELAY) 765 | else: 766 | logger.error(f"❌ All {MAX_RETRIES} attempts failed due to token count outlier for request_id: {request_id}") 767 | 768 | except AudioDurationOutlierError as e: 769 | last_error = e 770 | logger.error(f"🔄 Audio duration outlier on attempt {attempt + 1}/{MAX_RETRIES} for request_id: {request_id}: {e}") 771 | 772 | if attempt < MAX_RETRIES - 1: # Not the last attempt 773 | logger.info(f"⏳ Retrying in {RETRY_DELAY} seconds...") 774 | 775 | # Increase repetition penalty slightly for retry attempts 776 | if current_repetition_penalty is None: 777 | current_repetition_penalty = default_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) 778 | current_temperature = default_temperature * (1.0 + 0.15 * (attempt + 1)) 779 | else: 780 | current_repetition_penalty = current_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) 781 | current_temperature = current_temperature * (1.0 + 0.15 * (attempt + 1)) 782 | logger.info(f"🔧 Adjusted repetition penalty to {current_repetition_penalty:.2f} and temperature to {current_temperature:.2f} for retry") 783 | await asyncio.sleep(RETRY_DELAY) 784 | else: 785 | logger.error(f"❌ All {MAX_RETRIES} attempts failed due to audio duration outlier for request_id: {request_id}") 786 | 787 | except AudioDecodingError as e: 788 | last_error = e 789 | logger.info(f"🔄 Audio decoding failed on attempt {attempt + 1}/{MAX_RETRIES} for request_id: {request_id}: {e}") 790 | 791 | # No need to save debug audio - audio decoding errors are always true positives 792 | 793 | if attempt < MAX_RETRIES - 1: # Not the last attempt 794 | logger.info(f"⏳ Retrying in {RETRY_DELAY} seconds...") 795 | 796 | # Increase repetition penalty slightly for retry attempts 797 | if current_repetition_penalty is None: 798 | current_repetition_penalty = default_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) 799 | current_temperature = default_temperature * (1.0 + 0.15 * (attempt + 1)) 800 | else: 801 | current_repetition_penalty = current_repetition_penalty * (1.0 + 0.15 * (attempt + 1)) 802 | current_temperature = current_temperature * (1.0 + 0.15 * (attempt + 1)) 803 | logger.info(f"🔧 Adjusted repetition penalty to {current_repetition_penalty:.2f} and temperature to {current_temperature:.2f} for retry") 804 | await asyncio.sleep(RETRY_DELAY) 805 | else: 806 | logger.error(f"❌ All {MAX_RETRIES} attempts failed for request_id: {request_id}") 807 | 808 | except Exception as e: 809 | # For other exceptions, don't retry - they're likely not transient 810 | logger.error(f"❌ Non-retryable error in token generation for request_id: {request_id}: {e}") 811 | raise 812 | 
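# Worked example of the escalation schedule used in the handlers above
# (defaults: repetition penalty 1.1, temperature 0.2; adjustments compound across retries):
#   after 1st failure: factor 1.0 + 0.15*1 = 1.15 -> penalty ≈ 1.27, temperature ≈ 0.23
#   after 2nd failure: factor 1.0 + 0.15*2 = 1.30 -> penalty ≈ 1.64, temperature ≈ 0.30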
813 | # If we get here, all retries failed - raise the specific error type 814 | if isinstance(last_error, TokenRepetitionError): 815 | raise TokenRepetitionError(f"Token repetition persisted after {MAX_RETRIES} attempts. Last error: {last_error}") 816 | elif isinstance(last_error, TokenCountOutlierError): 817 | raise TokenCountOutlierError(f"Token count outlier persisted after {MAX_RETRIES} attempts. Last error: {last_error}") 818 | elif isinstance(last_error, AudioDurationOutlierError): 819 | raise AudioDurationOutlierError(f"Audio duration outlier persisted after {MAX_RETRIES} attempts. Last error: {last_error}") 820 | else: 821 | raise AudioDecodingError(f"Token generation failed after {MAX_RETRIES} attempts. Last error: {last_error}") 822 | 823 | async def generate_speech_tokens_direct(engine, prompt: str, voice: str, 824 | temperature: Optional[float] = None, top_p: Optional[float] = None, 825 | repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, 826 | default_temperature: float = 0.2, default_top_p: float = 0.9, 827 | default_repetition_penalty: float = 1.1, default_max_tokens: int = 4096, 828 | executor: Optional[ThreadPoolExecutor] = None) -> tuple[list[bytes], dict]: 829 | """Generate speech tokens using direct async vLLM engine access with retry logic""" 830 | try: 831 | # Generate unique request ID 832 | request_id = f"req-{uuid.uuid4().hex[:8]}" 833 | 834 | # logger.info(f"Starting speech token generation with retry logic for voice: {voice}, request_id: {request_id}") 835 | 836 | # Use retry logic 837 | audio_chunks, metadata = await generate_speech_tokens_with_retry( 838 | engine, prompt, voice, request_id, 839 | temperature, top_p, repetition_penalty, max_tokens, 840 | default_temperature, default_top_p, default_repetition_penalty, default_max_tokens, 841 | executor 842 | ) 843 | 844 | # logger.info(f"Speech token generation complete: {metadata['chunk_count']} chunks, " 845 | # f"{metadata['duration_seconds']:.2f}s audio, {metadata['attempts']} attempts") 846 | 847 | return audio_chunks, metadata 848 | 849 | except Exception as e: 850 | logger.error(f"Error in generate_speech_tokens_direct: {e}") 851 | raise 852 | 853 | def combine_token_chunks(token_chunks_list: list[list[bytes]]) -> list[bytes]: 854 | """Combine multiple token chunk lists into a single list""" 855 | combined_chunks = [] 856 | total_chunks = 0 857 | 858 | for chunk_list in token_chunks_list: 859 | combined_chunks.extend(chunk_list) 860 | total_chunks += len(chunk_list) 861 | 862 | # logger.info(f"Combined {len(token_chunks_list)} batches into {total_chunks} total chunks") 863 | return combined_chunks 864 | 865 | async def generate_speech_chunks(engine, text_chunks: list[str], voice: str, 866 | temperature: Optional[float] = None, top_p: Optional[float] = None, 867 | repetition_penalty: Optional[float] = None, max_tokens: Optional[int] = None, 868 | default_temperature: float = 0.2, default_top_p: float = 0.9, 869 | default_repetition_penalty: float = 1.1, default_max_tokens: int = 4096, 870 | executor: Optional[ThreadPoolExecutor] = None) -> tuple[list[bytes], dict]: 871 | """ 872 | Generate speech tokens for multiple text chunks in parallel with retry logic. 873 | Returns combined tokens and detailed metadata. 
874 | """ 875 | total_metadata = { 876 | "total_chunks": len(text_chunks), 877 | "chunk_details": [], 878 | "combined_stats": {}, 879 | "retry_stats": { 880 | "total_attempts": 0, 881 | "total_retries": 0, 882 | "failed_chunks": 0 883 | } 884 | } 885 | 886 | # logger.info(f"Starting parallel processing of {len(text_chunks)} text chunks with retry logic") 887 | 888 | # Generate tokens for all chunks in parallel with retry logic 889 | chunk_tasks = [] 890 | for i, chunk in enumerate(text_chunks): 891 | # logger.info(f"Queuing chunk {i+1}/{len(text_chunks)}: {len(chunk)} characters") 892 | request_id = f"chunk-{i+1}-{uuid.uuid4().hex[:8]}" 893 | 894 | # Each chunk gets its own retry logic 895 | task = generate_speech_tokens_with_retry( 896 | engine, chunk, voice, request_id, 897 | temperature, top_p, repetition_penalty, max_tokens, 898 | default_temperature, default_top_p, default_repetition_penalty, default_max_tokens, 899 | executor 900 | ) 901 | chunk_tasks.append(task) 902 | 903 | # Wait for all chunks to complete 904 | chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) 905 | 906 | # Process results and collect metadata 907 | all_chunk_results = [] 908 | failed_chunks = 0 909 | 910 | for i, result in enumerate(chunk_results): 911 | if isinstance(result, BaseException): 912 | # Chunk failed completely 913 | failed_chunks += 1 914 | logger.error(f"Chunk {i+1} failed after all retries: {result}") 915 | total_metadata["retry_stats"]["failed_chunks"] += 1 916 | # Add empty result to maintain order 917 | all_chunk_results.append([]) 918 | 919 | chunk_info = { 920 | "chunk_index": i, 921 | "text_length": len(text_chunks[i]), 922 | "audio_chunks": 0, 923 | "duration_seconds": 0, 924 | "audio_frames": 0, 925 | "request_id": f"chunk-{i+1}-failed", 926 | "attempts": MAX_RETRIES, 927 | "retries": MAX_RETRIES - 1, 928 | "failed": True, 929 | "error": str(result) 930 | } 931 | total_metadata["chunk_details"].append(chunk_info) 932 | else: 933 | # Chunk succeeded - result is a tuple 934 | chunk_tokens, chunk_metadata = result 935 | all_chunk_results.append(chunk_tokens) 936 | 937 | # Add chunk-specific metadata 938 | chunk_info = { 939 | "chunk_index": i, 940 | "text_length": len(text_chunks[i]), 941 | "audio_chunks": chunk_metadata["chunk_count"], 942 | "duration_seconds": chunk_metadata["duration_seconds"], 943 | "audio_frames": chunk_metadata["total_audio_frames"], 944 | "request_id": chunk_metadata["request_id"], 945 | "attempts": chunk_metadata.get("attempts", 1), 946 | "retries": chunk_metadata.get("retries", 0), 947 | "failed": False 948 | } 949 | total_metadata["chunk_details"].append(chunk_info) 950 | 951 | # Update retry stats 952 | total_metadata["retry_stats"]["total_attempts"] += chunk_metadata.get("attempts", 1) 953 | total_metadata["retry_stats"]["total_retries"] += chunk_metadata.get("retries", 0) 954 | 955 | # logger.info(f"Chunk {i+1} complete: {chunk_metadata['chunk_count']} audio chunks, " 956 | # f"{chunk_metadata['duration_seconds']}s duration, " 957 | # f"{chunk_metadata.get('attempts', 1)} attempts") 958 | 959 | # Log retry statistics 960 | retry_stats = total_metadata["retry_stats"] 961 | # logger.info(f"Retry statistics: {retry_stats['total_attempts']} total attempts, " 962 | # f"{retry_stats['total_retries']} total retries, " 963 | # f"{retry_stats['failed_chunks']} failed chunks") 964 | 965 | # Combine all token chunks (excluding failed ones) 966 | combined_tokens = combine_token_chunks(all_chunk_results) 967 | 968 | # Calculate combined statistics 969 
| successful_chunks = [chunk for chunk in total_metadata["chunk_details"] if not chunk.get("failed", False)]
970 | total_duration = sum(chunk["duration_seconds"] for chunk in successful_chunks)
971 | total_frames = sum(chunk["audio_frames"] for chunk in successful_chunks)
972 | total_audio_chunks = sum(chunk["audio_chunks"] for chunk in successful_chunks)
973 |
974 | total_metadata["combined_stats"] = {
975 | "total_text_length": sum(len(chunk) for chunk in text_chunks),
976 | "total_audio_chunks": total_audio_chunks,
977 | "total_duration_seconds": round(total_duration, 2),
978 | "total_audio_frames": total_frames,
979 | "successful_chunks": len(successful_chunks),
980 | "failed_chunks": failed_chunks,
981 | "voice": voice
982 | }
983 |
984 | # logger.info(f"Generated and combined {len(text_chunks)} text chunks into {len(combined_tokens)} audio chunks, "
985 | # f"total duration: {total_duration:.2f}s, {failed_chunks} failed chunks")
986 |
987 | # If any chunks failed, raise an error
988 | if failed_chunks > 0:
989 | raise AudioDecodingError(f"{failed_chunks} of {len(text_chunks)} chunks failed after retries")
990 |
991 | return combined_tokens, total_metadata
992 |
993 | def _write_tokens_to_file(token_chunks: list[bytes], output_path: str) -> dict:
994 | """Helper function to write token chunks to WAV file in executor"""
995 | # Ensure output directory exists (fall back to the current directory for bare filenames)
996 | os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
997 |
998 | # Create WAV file
999 | with wave.open(output_path, "wb") as wf:
1000 | wf.setnchannels(1) # mono
1001 | wf.setsampwidth(2) # 16-bit
1002 | wf.setframerate(24000) # 24kHz
1003 |
1004 | total_frames = 0
1005 |
1006 | for chunk in token_chunks:
1007 | if chunk:
1008 | frame_count = len(chunk) // (wf.getsampwidth() * wf.getnchannels())
1009 | total_frames += frame_count
1010 | wf.writeframes(chunk)
1011 |
1012 | duration = total_frames / wf.getframerate()
1013 |
1014 | file_stats = {
1015 | "total_chunks": len(token_chunks),
1016 | "total_frames": total_frames,
1017 | "duration_seconds": round(duration, 2),
1018 | "file_size_bytes": os.path.getsize(output_path),
1019 | "output_path": output_path
1020 | }
1021 |
1022 | return file_stats
1023 |
1024 | async def tokens_to_audio_file(token_chunks: list[bytes], output_path: str, executor: ThreadPoolExecutor) -> dict:
1025 | """Convert token chunks to WAV audio file"""
1026 | try:
1027 | # Run file I/O in executor
1028 | file_stats = await asyncio.get_running_loop().run_in_executor(
1029 | executor, _write_tokens_to_file, token_chunks, output_path
1030 | )
1031 |
1032 | # logger.info(f"Audio file created: {len(token_chunks)} chunks, {file_stats['duration_seconds']:.2f}s, {file_stats['file_size_bytes']} bytes")
1033 |
1034 | return file_stats
1035 |
1036 | except Exception as e:
1037 | logger.error(f"Error in tokens_to_audio_file: {e}")
1038 | # Clean up partial file on error
1039 | if os.path.exists(output_path):
1040 | os.remove(output_path)
1041 | raise --------------------------------------------------------------------------------
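# End-to-end usage sketch (illustrative, not part of the module): assumes an initialized
# OrpheusModelExtended named `model`, pre-chunked input text, and a thread pool for file I/O.
# The voice name and output path are assumptions.
#
#   import asyncio
#   from concurrent.futures import ThreadPoolExecutor
#
#   async def main():
#       executor = ThreadPoolExecutor(max_workers=4)
#       text_chunks = ["Hello there.", "This is a second sentence."]
#       audio_chunks, metadata = await generate_speech_chunks(
#           model, text_chunks, voice="tara", executor=executor
#       )
#       stats = await tokens_to_audio_file(audio_chunks, "outputs/speech.wav", executor)
#       print(f"{stats['duration_seconds']}s written to {stats['output_path']}")
#
#   asyncio.run(main())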