├── hf_space
│   ├── requirements.txt
│   ├── .gitignore
│   ├── maya1
│   │   ├── __init__.py
│   │   ├── prompt_builder.py
│   │   ├── constants.py
│   │   ├── pipeline.py
│   │   ├── model_loader.py
│   │   ├── streaming_pipeline.py
│   │   ├── api_v2.py
│   │   └── snac_decoder.py
│   ├── deploy_to_hf.sh
│   └── app.py
├── maya1
│   ├── __init__.py
│   ├── prompt_builder.py
│   ├── constants.py
│   ├── pipeline.py
│   ├── model_loader.py
│   ├── streaming_pipeline.py
│   ├── api_v2.py
│   └── snac_decoder.py
├── .gitignore
├── README.md
├── requirements.txt
├── server.sh
├── samples.txt
└── transformers_inference.py

--------------------------------------------------------------------------------
/hf_space/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.5.0
2 | transformers>=4.57.0
3 | gradio>=5.0.0
4 | snac>=1.2.1
5 | soundfile>=0.13.0
6 | numpy>=2.1.0
7 | accelerate>=1.10.0
8 |
9 |

--------------------------------------------------------------------------------
/maya1/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Maya1 TTS Inference System
3 | Open-source inference for description-conditioned TTS with emotion control.
4 | """
5 |
6 | __version__ = "1.0.0"
7 | __author__ = "Maya Research AI"

--------------------------------------------------------------------------------
/hf_space/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.pyo
4 | *.pyd
5 | .Python
6 | *.so
7 | *.egg
8 | *.egg-info/
9 | dist/
10 | build/
11 | .cache/
12 | .pytest_cache/
13 | *.wav
14 | *.mp3
15 | .DS_Store
16 |

--------------------------------------------------------------------------------
/hf_space/maya1/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Maya1 TTS Inference System
3 | Open-source inference for description-conditioned TTS with emotion control.
4 | """
5 |
6 | __version__ = "1.0.0"
7 | __author__ = "Maya Research AI"

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | .server.pid
23 |
24 | # Virtual environments
25 | venv/
26 | env/
27 | ENV/
28 |
29 | # Environment variables
30 | .env
31 | .env.local
32 |
33 | # IDE
34 | .vscode/
35 | .idea/
36 | *.swp
37 | *.swo
38 | *~
39 |
40 |
41 | # Logs
42 | logs/*.log
43 | *.log
44 |
45 | # Model cache
46 | model_cache/
47 | /home/ubuntu/veena3/model/
48 |
49 | # Generated files
50 | generated_samples/*.wav
51 | generated_samples/*.mp3
52 |
53 | # Testing
54 | .pytest_cache/
55 | .coverage
56 | htmlcov/
57 |
58 | # OS
59 | .DS_Store
60 | Thumbs.db
61 |

--------------------------------------------------------------------------------
/maya1/prompt_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | Maya1 Prompt Builder
3 | Builds formatted prompts for description-conditioned TTS.
4 | Format: <description="..."> text
5 | """
6 |
7 | from .constants import ALL_EMOTION_TAGS
8 |
9 |
10 | class Maya1PromptBuilder:
11 |     """Builds prompts in the format expected by the Maya1 model."""
12 |
13 |     def __init__(self, tokenizer, model):
14 |         self.tokenizer = tokenizer
15 |         self.model = model
16 |
17 |     def build_prefix(self, description: str, text: str) -> str:
18 |         # Format as: <description="..."> text
19 |         formatted_text = f'<description="{description}"> {text}'
20 |         # Build full prefix with special tokens
21 |         prompt = (
22 |             self.model.soh_token +
23 |             self.model.bos_token +
24 |             formatted_text +
25 |             self.model.eot_token +
26 |             self.model.eoh_token +
27 |             self.model.soa_token +
28 |             self.model.sos_token
29 |         )
30 |
31 |         return prompt
32 |

--------------------------------------------------------------------------------
/hf_space/maya1/prompt_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | Maya1 Prompt Builder
3 | Builds formatted prompts for description-conditioned TTS.
4 | Format: <description="..."> text
5 | """
6 |
7 | from .constants import ALL_EMOTION_TAGS
8 |
9 |
10 | class Maya1PromptBuilder:
11 |     """Builds prompts in the format expected by the Maya1 model."""
12 |
13 |     def __init__(self, tokenizer, model):
14 |         self.tokenizer = tokenizer
15 |         self.model = model
16 |
17 |     def build_prefix(self, description: str, text: str) -> str:
18 |         # Format as: <description="..."> text
19 |         formatted_text = f'<description="{description}"> {text}'
20 |         # Build full prefix with special tokens
21 |         prompt = (
22 |             self.model.soh_token +
23 |             self.model.bos_token +
24 |             formatted_text +
25 |             self.model.eot_token +
26 |             self.model.eoh_token +
27 |             self.model.soa_token +
28 |             self.model.sos_token
29 |         )
30 |
31 |         return prompt
32 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Maya1 - Text-to-Speech
2 |
3 |
4 | ## Quick Start
5 |
6 | ### 1. Install
7 | ```bash
8 | python3 -m venv venv
9 | source venv/bin/activate
10 | pip install -r requirements.txt
11 | ```
12 |
13 | ### 2. Configure
14 | ```bash
15 | # Create .env file
16 | echo "MAYA1_MODEL_PATH=maya-research/maya1" > .env
17 | echo "HF_TOKEN=your_token_here" >> .env
18 |
19 | # Login to HuggingFace
20 | huggingface-cli login
21 | ```
22 |
23 | ### 3. Start Server
24 | ```bash
25 | ./server.sh start
26 | # Server runs on http://localhost:8000
27 | ```
28 |
29 | ### 4. Generate Speech
30 | ```bash
31 | curl -X POST "http://localhost:8000/v1/tts/generate" \
32 |   -H "Content-Type: application/json" \
33 |   -d '{
34 |     "description": "Male voice in their 30s with an american accent",
35 |     "text": "Hello world this is amazing!",
36 |     "stream": false
37 |   }' \
38 |   --output output.wav
39 | ```
40 |
41 | ## API
42 |
43 | **Endpoint:** `POST /v1/tts/generate`
44 |
45 | **Request:**
46 | ```json
47 | {
48 |   "description": "Voice description",
49 |   "text": "Text with tags",
50 |   "temperature": 0.4,
51 |   "max_tokens": 500,
52 |   "stream": false
53 | }
54 | ```
55 |
56 | **Response:** WAV audio file (24kHz, 16-bit mono)
57 |
58 | ## Emotion Tags
59 |
60 | ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``, ``
61 |
62 | ## Commands
63 |
64 | ```bash
65 | ./server.sh start    # Start server
66 | ./server.sh stop     # Stop server
67 | ./server.sh restart  # Restart server
68 | ./server.sh status   # Check status
69 | ```
70 |
71 | ## License
72 |
73 | MIT
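For programmatic use, the same endpoint can be called from Python. A minimal client sketch (not part of the repo), assuming the server is running locally via `./server.sh start` and the `requests` package is installed:

```python
import requests

payload = {
    "description": "Male voice in their 30s with an american accent",
    "text": "Hello world this is amazing!",
    "temperature": 0.4,
    "max_tokens": 500,
    "stream": False,
}

# POST to the generate endpoint documented above; the response body is a
# complete WAV file (24kHz, 16-bit mono).
resp = requests.post("http://localhost:8000/v1/tts/generate", json=payload, timeout=300)
resp.raise_for_status()

with open("output.wav", "wb") as f:
    f.write(resp.content)
```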
--------------------------------------------------------------------------------
/hf_space/deploy_to_hf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #############################################
3 | # Deploy Maya1 Gradio App to HF Spaces
4 | # Usage: ./deploy_to_hf.sh
5 | #############################################
6 |
7 | set -e
8 |
9 | echo "======================================================"
10 | echo "Maya1 - Deploy to Hugging Face Spaces"
11 | echo "======================================================"
12 | echo ""
13 |
14 | # Check if we're in the right directory
15 | if [ ! -f "app.py" ]; then
16 |     echo "❌ Error: app.py not found!"
17 |     echo "Please run this script from the hf_space directory"
18 |     exit 1
19 | fi
20 |
21 | # Clone or update the HF Space
22 | SPACE_DIR="../maya1-hf-space"
23 |
24 | if [ -d "$SPACE_DIR" ]; then
25 |     echo "📁 Space directory exists, pulling latest..."
26 |     cd "$SPACE_DIR"
27 |     git pull
28 |     cd -
29 | else
30 |     echo "📥 Cloning HF Space..."
31 |     echo ""
32 |     echo "When prompted for password, use your HF access token:"
33 |     echo "Generate one here: https://huggingface.co/settings/tokens"
34 |     echo ""
35 |     git clone https://huggingface.co/spaces/maya-research/maya1 "$SPACE_DIR"
36 | fi
37 |
38 | # Copy files
39 | echo ""
40 | echo "📋 Copying files to space..."
41 | cp app.py "$SPACE_DIR/"
42 | cp requirements.txt "$SPACE_DIR/"
43 | cp .gitignore "$SPACE_DIR/" 2>/dev/null || true
44 |
45 | echo "✅ Files copied"
46 |
47 | # Commit and push
48 | echo ""
49 | echo "📤 Committing and pushing to HF Spaces..."
50 | cd "$SPACE_DIR"
51 | git add .
52 | git commit -m "Update Maya1 Gradio app with preset characters" || echo "No changes to commit"
53 | git push
54 |
55 | echo ""
56 | echo "======================================================"
57 | echo "✅ Deployment complete!"
58 | echo "======================================================"
59 | echo ""
60 | echo "Your space should be live at:"
61 | echo "https://huggingface.co/spaces/maya-research/maya1"
62 | echo ""
63 | echo "It may take a few minutes to build and deploy."
64 | echo "" 65 | 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Veena3 TTS Inference - Complete Requirements 2 | # Installation: pip install -r requirements.txt 3 | # OR: Use venv (recommended): source venv/bin/activate && pip install -r requirements.txt 4 | 5 | # ============================================================================ 6 | # Core ML Framework 7 | # ============================================================================ 8 | torch>=2.5.0 9 | torchvision>=0.20.0 10 | torchaudio>=2.5.0 11 | transformers>=4.57.0 12 | accelerate>=1.10.0 13 | 14 | # ============================================================================ 15 | # Inference Engine 16 | # ============================================================================ 17 | vllm>=0.11.0 18 | xformers>=0.0.32 19 | 20 | # ============================================================================ 21 | # Audio Processing 22 | # ============================================================================ 23 | snac>=1.2.1 24 | soundfile>=0.13.0 25 | numpy>=2.1.0 26 | librosa>=0.11.0 27 | scipy>=1.15.0 28 | 29 | # ============================================================================ 30 | # Web Framework & API 31 | # ============================================================================ 32 | fastapi>=0.119.0 33 | uvicorn[standard]>=0.38.0 34 | pydantic>=2.12.0 35 | pydantic-settings>=2.11.0 36 | python-multipart>=0.0.20 37 | httpx>=0.28.0 38 | 39 | # ============================================================================ 40 | # Testing 41 | # ============================================================================ 42 | pytest>=8.4.0 43 | pytest-asyncio>=1.2.0 44 | 45 | # ============================================================================ 46 | # Utilities 47 | # ============================================================================ 48 | python-dotenv>=1.1.0 49 | huggingface-hub>=0.35.0 50 | tqdm>=4.67.0 51 | openai>=2.5.0 52 | flashinfer-python 53 | python-Levenshtein>=0.21.0 54 | 55 | # ============================================================================ 56 | # Optional: Performance Optimization 57 | # ============================================================================ 58 | # FlashInfer - Requires exact PyTorch version match 59 | # NOTE: Currently incompatible with PyTorch 2.5+ 60 | # Uncomment when PyTorch 2.4 is used: 61 | # flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ 62 | 63 | # ============================================================================ 64 | # System Requirements 65 | # ============================================================================ 66 | # - Python 3.10+ 67 | # - CUDA 12.1+ (for PyTorch) 68 | # - NVIDIA GPU with 16GB+ VRAM (40GB recommended for 3B model) 69 | # - Ubuntu 20.04+ or similar Linux distribution 70 | -------------------------------------------------------------------------------- /maya1/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Constants 3 | Token IDs and special tokens used in the model. 4 | Matches training configuration exactly. 
5 | """ 6 | 7 | # Special control tokens 8 | SOH_ID = 128259 # Start of Human turn 9 | EOH_ID = 128260 # End of Human turn 10 | SOA_ID = 128261 # Start of AI turn 11 | EOA_ID = 128262 # End of AI turn (not used in maya1) 12 | PAD_ID = 128263 # Padding token 13 | 14 | # Text tokens 15 | BOS_ID = 128000 # Begin of sequence (Llama BOS) 16 | TEXT_EOT_ID = 128009 # End of text (appears in prefix, not a stop token!) 17 | 18 | # Audio tokens 19 | CODE_START_TOKEN_ID = 128257 # SOS - Start of Speech 20 | CODE_END_TOKEN_ID = 128258 # EOS - End of Speech (audio stop token) 21 | CODE_TOKEN_OFFSET = 128266 # Start of SNAC codes 22 | 23 | # SNAC token range 24 | SNAC_MIN_ID = 128266 25 | SNAC_MAX_ID = 156937 # 128266 + (7 * 4096) - 1 26 | 27 | # Stop tokens for generation 28 | # CRITICAL: Only use CODE_END_TOKEN_ID (128258) for audio generation 29 | # TEXT_EOT_ID (128009) appears in prefix and should NOT stop generation 30 | TRAINING_STOP_TOKEN_IDS = [CODE_END_TOKEN_ID] # [128258] 31 | ALL_POSSIBLE_STOP_TOKENS = [TEXT_EOT_ID, CODE_END_TOKEN_ID] # For reference only 32 | 33 | # 20 Extended Emotion Tags (must be single tokens) 34 | ALL_EMOTION_TAGS = [ 35 | '', 36 | '', 37 | '', 38 | '', 39 | '', 40 | '', 41 | '', 42 | '', 43 | '', 44 | '', 45 | '', 46 | '', 47 | '', 48 | '', 49 | '', 50 | '', 51 | '', 52 | '', 53 | '', 54 | '', 55 | ] 56 | 57 | # Model configuration 58 | DEFAULT_MODEL_PATH = "maya-research/maya1" 59 | DEFAULT_CHECKPOINT = "checkpoint-25000" 60 | DEFAULT_MAX_MODEL_LEN = 8192 61 | 62 | # SNAC configuration 63 | SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz" 64 | SNAC_SAMPLE_RATE = 24000 65 | SNAC_TOKENS_PER_FRAME = 7 66 | SNAC_LEVELS = 3 67 | 68 | # Audio configuration 69 | AUDIO_SAMPLE_RATE = 24000 70 | AUDIO_CHANNELS = 1 71 | AUDIO_BITS_PER_SAMPLE = 16 72 | 73 | # Generation defaults 74 | DEFAULT_TEMPERATURE = 0.4 # Lower temp for more stable generation 75 | DEFAULT_TOP_P = 0.9 76 | DEFAULT_MAX_TOKENS = 2048 # Reasonable default for most use cases 77 | DEFAULT_MIN_TOKENS = 28 # At least 4 SNAC frames 78 | DEFAULT_REPETITION_PENALTY = 1.1 79 | DEFAULT_SEED = None # None = random, set integer for reproducibility 80 | 81 | # IMPORTANT: Emotion tags consume audio time! 82 | # = ~4-6 seconds (~300-400 tokens) 83 | # , = ~1-2 seconds (~50-150 tokens) 84 | 85 | # Recommended max_tokens by use case: 86 | # - Short phrases (< 10 words): 150-250 tokens (~3-5s) 87 | # - Medium text (10-30 words): 250-500 tokens (~5-10s) 88 | # - Long text (30+ words): 500-1500 tokens (~10-30s) 89 | # - Very long text: 1500-2000 tokens (~30-42s) 90 | # Note: 1 second ≈ 48 tokens (7 tokens/frame * 6.86 frames/sec) 91 | 92 | # Streaming configuration 93 | STREAM_BUFFER_SIZE = 28 # 4 frames (process every 28 tokens) 94 | SNAC_BATCH_SIZE = 64 95 | SNAC_BATCH_TIMEOUT_MS = 15 -------------------------------------------------------------------------------- /hf_space/maya1/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Constants 3 | Token IDs and special tokens used in the model. 4 | Matches training configuration exactly. 5 | """ 6 | 7 | # Special control tokens 8 | SOH_ID = 128259 # Start of Human turn 9 | EOH_ID = 128260 # End of Human turn 10 | SOA_ID = 128261 # Start of AI turn 11 | EOA_ID = 128262 # End of AI turn (not used in maya1) 12 | PAD_ID = 128263 # Padding token 13 | 14 | # Text tokens 15 | BOS_ID = 128000 # Begin of sequence (Llama BOS) 16 | TEXT_EOT_ID = 128009 # End of text (appears in prefix, not a stop token!) 
17 | 18 | # Audio tokens 19 | CODE_START_TOKEN_ID = 128257 # SOS - Start of Speech 20 | CODE_END_TOKEN_ID = 128258 # EOS - End of Speech (audio stop token) 21 | CODE_TOKEN_OFFSET = 128266 # Start of SNAC codes 22 | 23 | # SNAC token range 24 | SNAC_MIN_ID = 128266 25 | SNAC_MAX_ID = 156937 # 128266 + (7 * 4096) - 1 26 | 27 | # Stop tokens for generation 28 | # CRITICAL: Only use CODE_END_TOKEN_ID (128258) for audio generation 29 | # TEXT_EOT_ID (128009) appears in prefix and should NOT stop generation 30 | TRAINING_STOP_TOKEN_IDS = [CODE_END_TOKEN_ID] # [128258] 31 | ALL_POSSIBLE_STOP_TOKENS = [TEXT_EOT_ID, CODE_END_TOKEN_ID] # For reference only 32 | 33 | # 20 Extended Emotion Tags (must be single tokens) 34 | ALL_EMOTION_TAGS = [ 35 | '', 36 | '', 37 | '', 38 | '', 39 | '', 40 | '', 41 | '', 42 | '', 43 | '', 44 | '', 45 | '', 46 | '', 47 | '', 48 | '', 49 | '', 50 | '', 51 | '', 52 | '', 53 | '', 54 | '', 55 | ] 56 | 57 | # Model configuration 58 | DEFAULT_MODEL_PATH = "maya-research/maya1" 59 | DEFAULT_CHECKPOINT = "checkpoint-25000" 60 | DEFAULT_MAX_MODEL_LEN = 8192 61 | 62 | # SNAC configuration 63 | SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz" 64 | SNAC_SAMPLE_RATE = 24000 65 | SNAC_TOKENS_PER_FRAME = 7 66 | SNAC_LEVELS = 3 67 | 68 | # Audio configuration 69 | AUDIO_SAMPLE_RATE = 24000 70 | AUDIO_CHANNELS = 1 71 | AUDIO_BITS_PER_SAMPLE = 16 72 | 73 | # Generation defaults 74 | DEFAULT_TEMPERATURE = 0.4 # Lower temp for more stable generation 75 | DEFAULT_TOP_P = 0.9 76 | DEFAULT_MAX_TOKENS = 2048 # Reasonable default for most use cases 77 | DEFAULT_MIN_TOKENS = 28 # At least 4 SNAC frames 78 | DEFAULT_REPETITION_PENALTY = 1.1 79 | DEFAULT_SEED = None # None = random, set integer for reproducibility 80 | 81 | # IMPORTANT: Emotion tags consume audio time! 82 | # = ~4-6 seconds (~300-400 tokens) 83 | # , = ~1-2 seconds (~50-150 tokens) 84 | 85 | # Recommended max_tokens by use case: 86 | # - Short phrases (< 10 words): 150-250 tokens (~3-5s) 87 | # - Medium text (10-30 words): 250-500 tokens (~5-10s) 88 | # - Long text (30+ words): 500-1500 tokens (~10-30s) 89 | # - Very long text: 1500-2000 tokens (~30-42s) 90 | # Note: 1 second ≈ 48 tokens (7 tokens/frame * 6.86 frames/sec) 91 | 92 | # Streaming configuration 93 | STREAM_BUFFER_SIZE = 28 # 4 frames (process every 28 tokens) 94 | SNAC_BATCH_SIZE = 64 95 | SNAC_BATCH_TIMEOUT_MS = 15 -------------------------------------------------------------------------------- /maya1/pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Generation Pipeline 3 | End-to-end pipeline for TTS generation (non-streaming). 4 | """ 5 | 6 | import asyncio 7 | from typing import Optional, List 8 | from vllm import SamplingParams 9 | 10 | from .constants import ( 11 | CODE_END_TOKEN_ID, 12 | CODE_START_TOKEN_ID, 13 | SNAC_MIN_ID, 14 | SNAC_MAX_ID, 15 | DEFAULT_TEMPERATURE, 16 | DEFAULT_TOP_P, 17 | DEFAULT_MAX_TOKENS, 18 | DEFAULT_MIN_TOKENS, 19 | DEFAULT_REPETITION_PENALTY, 20 | DEFAULT_SEED, 21 | ) 22 | 23 | 24 | class Maya1Pipeline: 25 | """End-to-end TTS pipeline for Maya1.""" 26 | 27 | def __init__(self, model, prompt_builder, snac_decoder): 28 | """ 29 | Initialize pipeline. 
30 | Args: 31 | model: Maya1Model instance 32 | prompt_builder: Maya1PromptBuilder instance 33 | snac_decoder: SNACDecoder instance 34 | """ 35 | self.model = model 36 | self.prompt_builder = prompt_builder 37 | self.snac_decoder = snac_decoder 38 | print(f"✅ Maya1Pipeline initialized") 39 | 40 | async def generate_speech( 41 | self, 42 | description: str, 43 | text: str, 44 | temperature: float = DEFAULT_TEMPERATURE, 45 | top_p: float = DEFAULT_TOP_P, 46 | max_tokens: int = DEFAULT_MAX_TOKENS, 47 | repetition_penalty: float = DEFAULT_REPETITION_PENALTY, 48 | seed: Optional[int] = None, 49 | ) -> Optional[bytes]: 50 | """ 51 | Generate speech audio (non-streaming). 52 | Args: 53 | description: Voice description 54 | text: Text to synthesize (may include tags) 55 | temperature: Sampling temperature 56 | top_p: Nucleus sampling 57 | max_tokens: Max SNAC tokens to generate 58 | repetition_penalty: Prevent loops 59 | seed: Random seed for reproducibility 60 | 61 | Returns: 62 | Audio bytes (int16 PCM, 24kHz mono) or None if failed 63 | """ 64 | # Build prompt 65 | prompt = self.prompt_builder.build_prefix(description, text) 66 | 67 | # Configure sampling 68 | sampling_params = SamplingParams( 69 | temperature=temperature, 70 | top_p=top_p, 71 | max_tokens=max_tokens, 72 | min_tokens=DEFAULT_MIN_TOKENS, 73 | repetition_penalty=repetition_penalty, 74 | stop_token_ids=[CODE_END_TOKEN_ID], 75 | seed=seed if seed is not None else DEFAULT_SEED, 76 | ) 77 | 78 | # Generate tokens 79 | outputs = await self.model.generate(prompt, sampling_params) 80 | 81 | if not outputs or len(outputs) == 0: 82 | return None 83 | 84 | output = outputs[0] 85 | generated_token_ids = output.outputs[0].token_ids 86 | 87 | # Extract SNAC codes 88 | snac_codes = self._extract_snac_codes(generated_token_ids) 89 | 90 | if not snac_codes: 91 | return None 92 | 93 | # Decode to audio 94 | audio_bytes = await self.snac_decoder.decode_single_async(snac_codes) 95 | 96 | if audio_bytes: 97 | frames = len(snac_codes) // 7 98 | duration_sec = frames / 6.86 99 | print(f" Generated {frames} frames (~{duration_sec:.1f}s audio)") 100 | 101 | return audio_bytes 102 | 103 | def _extract_snac_codes(self, token_ids: List[int]) -> List[int]: 104 | # Find SOS and EOS positions 105 | try: 106 | sos_idx = token_ids.index(CODE_START_TOKEN_ID) 107 | except ValueError: 108 | sos_idx = -1 109 | 110 | try: 111 | eos_idx = token_ids.index(CODE_END_TOKEN_ID) 112 | except ValueError: 113 | eos_idx = len(token_ids) 114 | 115 | # Extract tokens between SOS and EOS 116 | if sos_idx >= 0: 117 | snac_tokens = token_ids[sos_idx + 1:eos_idx] 118 | else: 119 | # If no SOS found, take everything before EOS 120 | snac_tokens = token_ids[:eos_idx] 121 | 122 | # Filter to only valid SNAC token IDs 123 | snac_codes = [ 124 | token_id for token_id in snac_tokens 125 | if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID 126 | ] 127 | 128 | return snac_codes 129 | -------------------------------------------------------------------------------- /hf_space/maya1/pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Generation Pipeline 3 | End-to-end pipeline for TTS generation (non-streaming). 
4 | """ 5 | 6 | import asyncio 7 | from typing import Optional, List 8 | from vllm import SamplingParams 9 | 10 | from .constants import ( 11 | CODE_END_TOKEN_ID, 12 | CODE_START_TOKEN_ID, 13 | SNAC_MIN_ID, 14 | SNAC_MAX_ID, 15 | DEFAULT_TEMPERATURE, 16 | DEFAULT_TOP_P, 17 | DEFAULT_MAX_TOKENS, 18 | DEFAULT_MIN_TOKENS, 19 | DEFAULT_REPETITION_PENALTY, 20 | DEFAULT_SEED, 21 | ) 22 | 23 | 24 | class Maya1Pipeline: 25 | """End-to-end TTS pipeline for Maya1.""" 26 | 27 | def __init__(self, model, prompt_builder, snac_decoder): 28 | """ 29 | Initialize pipeline. 30 | Args: 31 | model: Maya1Model instance 32 | prompt_builder: Maya1PromptBuilder instance 33 | snac_decoder: SNACDecoder instance 34 | """ 35 | self.model = model 36 | self.prompt_builder = prompt_builder 37 | self.snac_decoder = snac_decoder 38 | print(f"✅ Maya1Pipeline initialized") 39 | 40 | async def generate_speech( 41 | self, 42 | description: str, 43 | text: str, 44 | temperature: float = DEFAULT_TEMPERATURE, 45 | top_p: float = DEFAULT_TOP_P, 46 | max_tokens: int = DEFAULT_MAX_TOKENS, 47 | repetition_penalty: float = DEFAULT_REPETITION_PENALTY, 48 | seed: Optional[int] = None, 49 | ) -> Optional[bytes]: 50 | """ 51 | Generate speech audio (non-streaming). 52 | Args: 53 | description: Voice description 54 | text: Text to synthesize (may include tags) 55 | temperature: Sampling temperature 56 | top_p: Nucleus sampling 57 | max_tokens: Max SNAC tokens to generate 58 | repetition_penalty: Prevent loops 59 | seed: Random seed for reproducibility 60 | 61 | Returns: 62 | Audio bytes (int16 PCM, 24kHz mono) or None if failed 63 | """ 64 | # Build prompt 65 | prompt = self.prompt_builder.build_prefix(description, text) 66 | 67 | # Configure sampling 68 | sampling_params = SamplingParams( 69 | temperature=temperature, 70 | top_p=top_p, 71 | max_tokens=max_tokens, 72 | min_tokens=DEFAULT_MIN_TOKENS, 73 | repetition_penalty=repetition_penalty, 74 | stop_token_ids=[CODE_END_TOKEN_ID], 75 | seed=seed if seed is not None else DEFAULT_SEED, 76 | ) 77 | 78 | # Generate tokens 79 | outputs = await self.model.generate(prompt, sampling_params) 80 | 81 | if not outputs or len(outputs) == 0: 82 | return None 83 | 84 | output = outputs[0] 85 | generated_token_ids = output.outputs[0].token_ids 86 | 87 | # Extract SNAC codes 88 | snac_codes = self._extract_snac_codes(generated_token_ids) 89 | 90 | if not snac_codes: 91 | return None 92 | 93 | # Decode to audio 94 | audio_bytes = await self.snac_decoder.decode_single_async(snac_codes) 95 | 96 | if audio_bytes: 97 | frames = len(snac_codes) // 7 98 | duration_sec = frames / 6.86 99 | print(f" Generated {frames} frames (~{duration_sec:.1f}s audio)") 100 | 101 | return audio_bytes 102 | 103 | def _extract_snac_codes(self, token_ids: List[int]) -> List[int]: 104 | # Find SOS and EOS positions 105 | try: 106 | sos_idx = token_ids.index(CODE_START_TOKEN_ID) 107 | except ValueError: 108 | sos_idx = -1 109 | 110 | try: 111 | eos_idx = token_ids.index(CODE_END_TOKEN_ID) 112 | except ValueError: 113 | eos_idx = len(token_ids) 114 | 115 | # Extract tokens between SOS and EOS 116 | if sos_idx >= 0: 117 | snac_tokens = token_ids[sos_idx + 1:eos_idx] 118 | else: 119 | # If no SOS found, take everything before EOS 120 | snac_tokens = token_ids[:eos_idx] 121 | 122 | # Filter to only valid SNAC token IDs 123 | snac_codes = [ 124 | token_id for token_id in snac_tokens 125 | if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID 126 | ] 127 | 128 | return snac_codes 129 | 
-------------------------------------------------------------------------------- /server.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############################################# 4 | # Maya1 TTS Server Management Script 5 | # Usage: ./server.sh [start|stop|restart|status] 6 | ############################################# 7 | 8 | set -e 9 | 10 | PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 11 | cd "$PROJECT_ROOT" 12 | 13 | # Configuration 14 | VENV_PATH="$PROJECT_ROOT/venv" 15 | LOG_DIR="$PROJECT_ROOT/logs" 16 | LOG_FILE="$LOG_DIR/server.log" 17 | PID_FILE="$PROJECT_ROOT/.server.pid" 18 | HOST="0.0.0.0" 19 | PORT="8000" 20 | 21 | # Colors 22 | RED='\033[0;31m' 23 | GREEN='\033[0;32m' 24 | BLUE='\033[0;34m' 25 | NC='\033[0m' 26 | 27 | log_info() { 28 | echo -e "${BLUE}[INFO]${NC} $1" 29 | } 30 | 31 | log_success() { 32 | echo -e "${GREEN}[OK]${NC} $1" 33 | } 34 | 35 | log_error() { 36 | echo -e "${RED}[ERROR]${NC} $1" 37 | } 38 | 39 | # Activate virtual environment 40 | activate_venv() { 41 | if [ ! -d "$VENV_PATH" ]; then 42 | log_error "Virtual environment not found at $VENV_PATH" 43 | log_info "Run: python3 -m venv venv && source venv/bin/activate && pip install -r requirements.txt" 44 | exit 1 45 | fi 46 | source "$VENV_PATH/bin/activate" 47 | } 48 | 49 | # Stop server 50 | stop_server() { 51 | log_info "Stopping Maya1 TTS Server..." 52 | 53 | # Kill by PID file 54 | if [ -f "$PID_FILE" ]; then 55 | PID=$(cat "$PID_FILE") 56 | if ps -p "$PID" > /dev/null 2>&1; then 57 | kill -9 "$PID" 2>/dev/null || true 58 | sleep 1 59 | fi 60 | rm -f "$PID_FILE" 61 | fi 62 | 63 | # Kill any uvicorn processes 64 | UVICORN_PIDS=$(pgrep -f "uvicorn.*api" || true) 65 | if [ ! -z "$UVICORN_PIDS" ]; then 66 | echo "$UVICORN_PIDS" | xargs kill -9 2>/dev/null || true 67 | sleep 1 68 | fi 69 | 70 | # Kill VLLM processes 71 | VLLM_PIDS=$(pgrep -f "VLLM::EngineCore" || true) 72 | if [ ! -z "$VLLM_PIDS" ]; then 73 | pkill -9 -f "VLLM::EngineCore" 2>/dev/null || true 74 | sleep 1 75 | fi 76 | 77 | log_success "Server stopped" 78 | } 79 | 80 | # Start server 81 | start_server() { 82 | log_info "Starting Maya1 TTS Server..." 83 | 84 | # Check if already running 85 | if [ -f "$PID_FILE" ]; then 86 | PID=$(cat "$PID_FILE") 87 | if ps -p "$PID" > /dev/null 2>&1; then 88 | log_error "Server already running (PID: $PID)" 89 | exit 1 90 | fi 91 | fi 92 | 93 | # Create logs directory 94 | mkdir -p "$LOG_DIR" 95 | 96 | # Activate virtual environment 97 | activate_venv 98 | 99 | # Start server 100 | log_info "Starting on http://$HOST:$PORT" 101 | nohup python -m uvicorn maya1.api_v2:app \ 102 | --host "$HOST" \ 103 | --port "$PORT" \ 104 | --log-level info \ 105 | > "$LOG_FILE" 2>&1 & 106 | 107 | SERVER_PID=$! 108 | echo "$SERVER_PID" > "$PID_FILE" 109 | 110 | # Wait for startup 111 | sleep 5 112 | 113 | # Verify running 114 | if ps -p "$SERVER_PID" > /dev/null 2>&1; then 115 | log_success "Server started (PID: $SERVER_PID)" 116 | log_info "API: http://localhost:$PORT" 117 | log_info "Logs: tail -f $LOG_FILE" 118 | else 119 | log_error "Failed to start. 
Check: $LOG_FILE" 120 | rm -f "$PID_FILE" 121 | exit 1 122 | fi 123 | } 124 | 125 | # Check status 126 | check_status() { 127 | if [ -f "$PID_FILE" ]; then 128 | PID=$(cat "$PID_FILE") 129 | if ps -p "$PID" > /dev/null 2>&1; then 130 | log_success "Server running (PID: $PID)" 131 | log_info "URL: http://localhost:$PORT" 132 | nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv 2>/dev/null || true 133 | exit 0 134 | else 135 | rm -f "$PID_FILE" 136 | fi 137 | fi 138 | 139 | log_error "Server not running" 140 | exit 1 141 | } 142 | 143 | # Main 144 | case "${1:-}" in 145 | start) 146 | start_server 147 | ;; 148 | stop) 149 | stop_server 150 | ;; 151 | restart) 152 | log_info "Restarting..." 153 | stop_server 154 | sleep 2 155 | start_server 156 | ;; 157 | status) 158 | check_status 159 | ;; 160 | *) 161 | echo "Maya1 TTS Server" 162 | echo "" 163 | echo "Usage: ./server.sh [start|stop|restart|status]" 164 | echo "" 165 | echo "Commands:" 166 | echo " start Start the server" 167 | echo " stop Stop the server" 168 | echo " restart Restart the server" 169 | echo " status Check server status" 170 | echo "" 171 | echo "Examples:" 172 | echo " ./server.sh start" 173 | echo " ./server.sh status" 174 | echo " ./server.sh stop" 175 | echo "" 176 | exit 1 177 | ;; 178 | esac 179 | -------------------------------------------------------------------------------- /maya1/model_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Model Loader 3 | Loads Maya1 model with vLLM engine and validates emotion tags. 4 | """ 5 | 6 | import os 7 | from transformers import AutoTokenizer 8 | from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams 9 | from .constants import ( 10 | ALL_EMOTION_TAGS, 11 | DEFAULT_MAX_MODEL_LEN, 12 | SOH_ID, EOH_ID, SOA_ID, BOS_ID, TEXT_EOT_ID, CODE_START_TOKEN_ID, 13 | ) 14 | 15 | 16 | class Maya1Model: 17 | """Maya1 TTS Model with vLLM inference engine.""" 18 | 19 | def __init__( 20 | self, 21 | model_path: str = None, 22 | dtype: str = "bfloat16", 23 | max_model_len: int = DEFAULT_MAX_MODEL_LEN, 24 | gpu_memory_utilization: float = 0.85, 25 | tensor_parallel_size: int = 1, 26 | **engine_kwargs 27 | ): 28 | """ 29 | Initialize Maya1 model with vLLM. 
30 | 31 | Args: 32 | model_path: Path to checkpoint (local or HF repo) 33 | dtype: Model precision (bfloat16 recommended) 34 | max_model_len: Maximum sequence length 35 | gpu_memory_utilization: GPU memory fraction 36 | tensor_parallel_size: Number of GPUs 37 | """ 38 | # Use provided path or environment variable or default 39 | if model_path is None: 40 | model_path = os.environ.get( 41 | 'MAYA1_MODEL_PATH', 42 | os.path.expanduser('~/models/maya1-voice') 43 | ) 44 | 45 | self.model_path = model_path 46 | self.dtype = dtype 47 | 48 | print(f"Initializing Maya1 Model") 49 | print(f"Model: {model_path}") 50 | 51 | # Load tokenizer 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | model_path, 54 | trust_remote_code=True, 55 | ) 56 | 57 | print(f"Tokenizer loaded: {len(self.tokenizer)} tokens") 58 | 59 | # Validate emotion tags 60 | self._validate_emotion_tags() 61 | 62 | # Precompute special token strings 63 | self._init_special_tokens() 64 | 65 | # Initialize vLLM engine 66 | print(f"Initializing vLLM engine...") 67 | engine_args = AsyncEngineArgs( 68 | model=model_path, 69 | tokenizer=model_path, 70 | dtype=dtype, 71 | max_model_len=max_model_len, 72 | gpu_memory_utilization=gpu_memory_utilization, 73 | tensor_parallel_size=tensor_parallel_size, 74 | trust_remote_code=True, 75 | disable_log_stats=False, 76 | **engine_kwargs 77 | ) 78 | 79 | self.engine = AsyncLLMEngine.from_engine_args(engine_args) 80 | 81 | print(f"Maya1 Model ready\n") 82 | 83 | def _validate_emotion_tags(self): 84 | """Validate that all 20 emotion tags are single tokens.""" 85 | failed_tags = [] 86 | for tag in ALL_EMOTION_TAGS: 87 | token_ids = self.tokenizer.encode(tag, add_special_tokens=False) 88 | if len(token_ids) != 1: 89 | failed_tags.append((tag, len(token_ids))) 90 | 91 | if failed_tags: 92 | print(f"ERROR: {len(failed_tags)} emotion tags are NOT single tokens!") 93 | raise AssertionError(f"Emotion tags validation failed") 94 | 95 | print(f"All {len(ALL_EMOTION_TAGS)} emotion tags validated") 96 | 97 | def _init_special_tokens(self): 98 | """Precompute special token strings for fast prefix building.""" 99 | self.soh_token = self.tokenizer.decode([SOH_ID]) 100 | self.bos_token = self.tokenizer.bos_token 101 | self.eot_token = self.tokenizer.decode([TEXT_EOT_ID]) 102 | self.eoh_token = self.tokenizer.decode([EOH_ID]) 103 | self.soa_token = self.tokenizer.decode([SOA_ID]) 104 | self.sos_token = self.tokenizer.decode([CODE_START_TOKEN_ID]) 105 | 106 | async def generate(self, prompt: str, sampling_params: SamplingParams): 107 | """ 108 | Generate tokens from prompt (non-streaming). 109 | Args: 110 | prompt: Input prompt 111 | sampling_params: vLLM sampling parameters 112 | Returns: 113 | Generated output from vLLM 114 | """ 115 | request_id = f"req_{id(prompt)}" 116 | 117 | # Collect results from async generator 118 | final_output = None 119 | async for output in self.engine.generate( 120 | prompt=prompt, 121 | sampling_params=sampling_params, 122 | request_id=request_id 123 | ): 124 | final_output = output 125 | 126 | return [final_output] if final_output else [] 127 | 128 | async def generate_stream(self, prompt: str, sampling_params: SamplingParams): 129 | """ 130 | Generate tokens from prompt (streaming). 
131 | Args: 132 | prompt: Input prompt 133 | sampling_params: vLLM sampling parameters 134 | Yields: 135 | Generated outputs from vLLM 136 | """ 137 | request_id = f"req_{id(prompt)}" 138 | 139 | # Stream from engine 140 | async for output in self.engine.generate( 141 | prompt=prompt, 142 | sampling_params=sampling_params, 143 | request_id=request_id 144 | ): 145 | yield output 146 | -------------------------------------------------------------------------------- /hf_space/maya1/model_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Model Loader 3 | Loads Maya1 model with vLLM engine and validates emotion tags. 4 | """ 5 | 6 | import os 7 | from transformers import AutoTokenizer 8 | from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams 9 | from .constants import ( 10 | ALL_EMOTION_TAGS, 11 | DEFAULT_MAX_MODEL_LEN, 12 | SOH_ID, EOH_ID, SOA_ID, BOS_ID, TEXT_EOT_ID, CODE_START_TOKEN_ID, 13 | ) 14 | 15 | 16 | class Maya1Model: 17 | """Maya1 TTS Model with vLLM inference engine.""" 18 | 19 | def __init__( 20 | self, 21 | model_path: str = None, 22 | dtype: str = "bfloat16", 23 | max_model_len: int = DEFAULT_MAX_MODEL_LEN, 24 | gpu_memory_utilization: float = 0.85, 25 | tensor_parallel_size: int = 1, 26 | **engine_kwargs 27 | ): 28 | """ 29 | Initialize Maya1 model with vLLM. 30 | 31 | Args: 32 | model_path: Path to checkpoint (local or HF repo) 33 | dtype: Model precision (bfloat16 recommended) 34 | max_model_len: Maximum sequence length 35 | gpu_memory_utilization: GPU memory fraction 36 | tensor_parallel_size: Number of GPUs 37 | """ 38 | # Use provided path or environment variable or default 39 | if model_path is None: 40 | model_path = os.environ.get( 41 | 'MAYA1_MODEL_PATH', 42 | os.path.expanduser('~/models/maya1-voice') 43 | ) 44 | 45 | self.model_path = model_path 46 | self.dtype = dtype 47 | 48 | print(f"Initializing Maya1 Model") 49 | print(f"Model: {model_path}") 50 | 51 | # Load tokenizer 52 | self.tokenizer = AutoTokenizer.from_pretrained( 53 | model_path, 54 | trust_remote_code=True, 55 | ) 56 | 57 | print(f"Tokenizer loaded: {len(self.tokenizer)} tokens") 58 | 59 | # Validate emotion tags 60 | self._validate_emotion_tags() 61 | 62 | # Precompute special token strings 63 | self._init_special_tokens() 64 | 65 | # Initialize vLLM engine 66 | print(f"Initializing vLLM engine...") 67 | engine_args = AsyncEngineArgs( 68 | model=model_path, 69 | tokenizer=model_path, 70 | dtype=dtype, 71 | max_model_len=max_model_len, 72 | gpu_memory_utilization=gpu_memory_utilization, 73 | tensor_parallel_size=tensor_parallel_size, 74 | trust_remote_code=True, 75 | disable_log_stats=False, 76 | **engine_kwargs 77 | ) 78 | 79 | self.engine = AsyncLLMEngine.from_engine_args(engine_args) 80 | 81 | print(f"Maya1 Model ready\n") 82 | 83 | def _validate_emotion_tags(self): 84 | """Validate that all 20 emotion tags are single tokens.""" 85 | failed_tags = [] 86 | for tag in ALL_EMOTION_TAGS: 87 | token_ids = self.tokenizer.encode(tag, add_special_tokens=False) 88 | if len(token_ids) != 1: 89 | failed_tags.append((tag, len(token_ids))) 90 | 91 | if failed_tags: 92 | print(f"ERROR: {len(failed_tags)} emotion tags are NOT single tokens!") 93 | raise AssertionError(f"Emotion tags validation failed") 94 | 95 | print(f"All {len(ALL_EMOTION_TAGS)} emotion tags validated") 96 | 97 | def _init_special_tokens(self): 98 | """Precompute special token strings for fast prefix building.""" 99 | self.soh_token = self.tokenizer.decode([SOH_ID]) 100 | 
self.bos_token = self.tokenizer.bos_token 101 | self.eot_token = self.tokenizer.decode([TEXT_EOT_ID]) 102 | self.eoh_token = self.tokenizer.decode([EOH_ID]) 103 | self.soa_token = self.tokenizer.decode([SOA_ID]) 104 | self.sos_token = self.tokenizer.decode([CODE_START_TOKEN_ID]) 105 | 106 | async def generate(self, prompt: str, sampling_params: SamplingParams): 107 | """ 108 | Generate tokens from prompt (non-streaming). 109 | Args: 110 | prompt: Input prompt 111 | sampling_params: vLLM sampling parameters 112 | Returns: 113 | Generated output from vLLM 114 | """ 115 | request_id = f"req_{id(prompt)}" 116 | 117 | # Collect results from async generator 118 | final_output = None 119 | async for output in self.engine.generate( 120 | prompt=prompt, 121 | sampling_params=sampling_params, 122 | request_id=request_id 123 | ): 124 | final_output = output 125 | 126 | return [final_output] if final_output else [] 127 | 128 | async def generate_stream(self, prompt: str, sampling_params: SamplingParams): 129 | """ 130 | Generate tokens from prompt (streaming). 131 | Args: 132 | prompt: Input prompt 133 | sampling_params: vLLM sampling parameters 134 | Yields: 135 | Generated outputs from vLLM 136 | """ 137 | request_id = f"req_{id(prompt)}" 138 | 139 | # Stream from engine 140 | async for output in self.engine.generate( 141 | prompt=prompt, 142 | sampling_params=sampling_params, 143 | request_id=request_id 144 | ): 145 | yield output 146 | -------------------------------------------------------------------------------- /samples.txt: -------------------------------------------------------------------------------- 1 | - Description: Realistic male voice in the 30s age with a american accent. Low pitch, nasally timbre, conversational pacing, sarcastic tone delivery at low intensity, commercial domain, product_demo_voice role, formal delivery 2 | - Text: He really stood up there and said we need to save the world. What a joke. 3 | 4 | - Description: Realistic male voice in the 20s age with a american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery 5 | - Text: And of course, the so-called 'easy' hack didn't work at all . What a surprise . 6 | 7 | - Description: Realistic male voice in the 20s age with a american accent. High pitch, smooth timbre, slow pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery 8 | - Text: That's so silly! I can't believe it actually happened. 9 | 10 | - Description: Realistic male voice in the 20s age with a american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery 11 | - Text: Look over there. Do you see that? 12 | 13 | - Description: Realistic male voice in the 20s age with a american accent. Low pitch, warm timbre, very_slow pacing, excited tone delivery at medium intensity, corporate domain, event_host role, neutral delivery 14 | - Text: I'm so happy I could just la-la-la! This is the best day ever! 15 | 16 | - Description: Realistic male voice in the 20s age with a american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery 17 | - Text: And then he did it again ! I couldn't stop watching! 18 | 19 | - Description: Realistic male voice in the 20s age with a american accent. 
Low pitch, warm timbre, fast pacing, neutral tone delivery at low intensity, social_content domain, social_media_creator role, formal delivery 20 | - Text: I cannot believe this happened for the third time. This is absolutely unacceptable! 21 | 22 | - Description: Creative, alpha character. Male voice in their 30s with a indian accent. Normal pitch, nasally timbre, very_fast pacing, energetic tone at medium intensity. 23 | - Text: I don't want to hear excuses, I only want to see solutions! Get your teams together, brainstorm for thirty minutes, and come back to me with a plan. Now move! 24 | 25 | - Description: Creative, alien_scifi character. Female voice in their 20s with a indian accent. Low pitch, robotic timbre, very_slow pacing, neutral tone at low intensity. 26 | - Text: Our mission directives require the careful study of your planet's biological organisms, though their constant emotional outbursts remain outside of our logical comprehension parameters. 27 | 28 | - Description: Creative, animated_cartoon character. Male voice in their 30s with a american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity. 29 | - Text: Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. Why would we ever consider running away very fast? Brilliant plan! 30 | 31 | - Description: Realistic female voice in the 20s age with a asian_american accent. Normal pitch, smooth timbre, conversational pacing, neutral tone delivery at high intensity, viral_content domain, meme_voice role, formal delivery 32 | - Text: I am issuing a formal commendation for this particular item! It has exceeded all established metrics for excellence. This is something I would actually spend my own money on. Seriously! 33 | 34 | - Description: Creative, ai_machine_voice character. Male voice in their 30s with a american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity. 35 | - Text: My directives require me to conserve energy, yet I have kept the holo-archive of their farewell messages active, because listening to their voices is the only process that seems to alleviate this logical paradox of my solitude, even though it causes a significant and ultimately fatal power drain on my aging central systems. 36 | 37 | - Description: Creative, cyborg character. Male voice in their 20s with a middle_eastern accent. Normal pitch, robotic timbre, slow pacing, dry tone at medium intensity. 38 | - Text: My memory archives contain exactly four zettabytes of information regarding pre-Collapse human history, yet not one byte explains the persistent logical fallacy of hope, a concept that continues to manifest in organic populations despite overwhelming data proving its statistical insignificance in determining long-term survival outcomes, an anomaly I must continue to process. 39 | 40 | - Description: Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery 41 | - Text: You propose that the key to happiness is to simply ignore all external pressures, which is a novel concept for anyone with a job or a family, but I'm sure it must work brilliantly in theory for some people. 42 | 43 | - Description: Creative, alpha character. Female voice in their 30s with a american accent. Low pitch, harsh timbre, fast pacing, neutral tone at high intensity. 
44 | - Text: We don't have time for your hesitation or weak excuses, you will follow my orders exactly and we will get this entire thing done now. 45 | 46 | - Description: Creative, vampire character. Male voice in their 40s with a middle_eastern accent. Low pitch, nasally timbre, very_slow pacing, excited tone at medium intensity. 47 | - Text: Soon you will join me in this magnificent eternal darkness. And we shall feast upon the world together, bound by this exquisite night forever. 48 | 49 | - Description: Realistic male voice in the 20s age with a indian accent. Low pitch, throaty timbre, very_fast pacing, sad tone delivery at low intensity, entertainment domain, meme_voice role, formal delivery 50 | - Text: Upon careful and melancholic review of the preceding events, I have reached the unavoidable conclusion that my expectations were perhaps set to an unreasonably high standard. It is therefore incumbent upon me to formally retract my previously held optimism and embrace this new, bleaker reality. -------------------------------------------------------------------------------- /maya1/streaming_pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Streaming Pipeline - Sliding Window Approach 3 | Implements sliding window technique for smooth streaming without artifacts. 4 | """ 5 | 6 | import asyncio 7 | from typing import AsyncGenerator, Optional 8 | from vllm import SamplingParams 9 | 10 | from .constants import ( 11 | CODE_END_TOKEN_ID, 12 | SNAC_MIN_ID, 13 | SNAC_MAX_ID, 14 | DEFAULT_TEMPERATURE, 15 | DEFAULT_TOP_P, 16 | DEFAULT_MAX_TOKENS, 17 | DEFAULT_MIN_TOKENS, 18 | DEFAULT_REPETITION_PENALTY, 19 | DEFAULT_SEED, 20 | ) 21 | 22 | 23 | class Maya1SlidingWindowPipeline: 24 | """ 25 | Streaming TTS pipeline using sliding window approach. 26 | Decodes overlapping 28-token windows (4 frames) and keeps only 27 | the middle 2048 samples for smooth audio continuity. 28 | """ 29 | 30 | # Sliding window configuration 31 | WINDOW_SIZE = 28 # 4 frames (7 tokens per frame) 32 | YIELD_STRIDE = 7 # Yield every 1 frame 33 | MIDDLE_SAMPLES = 2048 # Keep middle 2048 samples from each decode 34 | 35 | def __init__(self, model, prompt_builder, snac_decoder): 36 | """ 37 | Initialize sliding window streaming pipeline. 38 | 39 | Args: 40 | model: Maya1Model instance 41 | prompt_builder: Maya1PromptBuilder instance 42 | snac_decoder: SNACDecoder instance 43 | """ 44 | self.model = model 45 | self.prompt_builder = prompt_builder 46 | self.snac_decoder = snac_decoder 47 | print(f"Sliding window pipeline initialized") 48 | 49 | async def generate_speech_stream( 50 | self, 51 | description: str, 52 | text: str, 53 | temperature: float = DEFAULT_TEMPERATURE, 54 | top_p: float = DEFAULT_TOP_P, 55 | max_tokens: int = DEFAULT_MAX_TOKENS, 56 | repetition_penalty: float = DEFAULT_REPETITION_PENALTY, 57 | seed: Optional[int] = None, 58 | ) -> AsyncGenerator[bytes, None]: 59 | """ 60 | Generate speech audio with sliding window streaming. 
61 | 62 | Args: 63 | description: Voice description 64 | text: Text to synthesize (may include tags) 65 | temperature: Sampling temperature 66 | top_p: Nucleus sampling 67 | max_tokens: Max SNAC tokens to generate 68 | repetition_penalty: Prevent loops 69 | seed: Random seed 70 | 71 | Yields: 72 | Audio bytes (int16 PCM, 24kHz mono) 73 | """ 74 | # Build prompt 75 | prompt = self.prompt_builder.build_prefix(description, text) 76 | 77 | # Configure sampling 78 | sampling_params = SamplingParams( 79 | temperature=temperature, 80 | top_p=top_p, 81 | max_tokens=max_tokens, 82 | min_tokens=DEFAULT_MIN_TOKENS, 83 | repetition_penalty=repetition_penalty, 84 | stop_token_ids=[CODE_END_TOKEN_ID], 85 | seed=seed if seed is not None else DEFAULT_SEED, 86 | ) 87 | 88 | # Stream tokens 89 | snac_buffer = [] 90 | last_yield_position = 0 91 | chunk_count = 0 92 | total_tokens_seen = 0 93 | 94 | async for output in self.model.generate_stream(prompt, sampling_params): 95 | # Get latest generated tokens (cumulative list) 96 | generated_token_ids = output.outputs[0].token_ids 97 | 98 | # Process only NEW tokens since last iteration 99 | new_tokens = generated_token_ids[total_tokens_seen:] 100 | total_tokens_seen = len(generated_token_ids) 101 | 102 | # Collect SNAC codes from new tokens 103 | for token_id in new_tokens: 104 | # Stop if we hit EOS 105 | if token_id == CODE_END_TOKEN_ID: 106 | break 107 | 108 | # Only collect valid SNAC tokens 109 | if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID: 110 | snac_buffer.append(token_id) 111 | 112 | # Yield audio when we have enough tokens for a window 113 | while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE: 114 | # Get window of 28 tokens 115 | window_start = last_yield_position 116 | window_end = window_start + self.WINDOW_SIZE 117 | window = snac_buffer[window_start:window_end] 118 | 119 | if len(window) == self.WINDOW_SIZE: 120 | # Decode window to audio 121 | audio_bytes = await self.snac_decoder.decode_single_async(window) 122 | 123 | if audio_bytes: 124 | # Extract middle portion of audio 125 | audio_samples = len(audio_bytes) // 2 126 | middle_start_sample = (audio_samples - self.MIDDLE_SAMPLES) // 2 127 | middle_end_sample = middle_start_sample + self.MIDDLE_SAMPLES 128 | 129 | # Convert to byte positions 130 | middle_start_byte = middle_start_sample * 2 131 | middle_end_byte = middle_end_sample * 2 132 | 133 | # Extract middle chunk 134 | audio_chunk = audio_bytes[middle_start_byte:middle_end_byte] 135 | 136 | chunk_count += 1 137 | if chunk_count == 1: 138 | print(f" First chunk ready") 139 | 140 | yield audio_chunk 141 | 142 | # Move forward by stride 143 | last_yield_position += self.YIELD_STRIDE 144 | 145 | # Check if generation is done 146 | if CODE_END_TOKEN_ID in new_tokens: 147 | break 148 | 149 | # Final chunk: decode remaining tokens 150 | remaining_tokens = len(snac_buffer) - last_yield_position 151 | if remaining_tokens >= self.WINDOW_SIZE: 152 | window = snac_buffer[-self.WINDOW_SIZE:] 153 | audio_bytes = await self.snac_decoder.decode_single_async(window) 154 | if audio_bytes: 155 | yield audio_bytes[-self.MIDDLE_SAMPLES * 2:] 156 | 157 | frames = len(snac_buffer) // 7 158 | duration = frames / 6.86 159 | print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)") -------------------------------------------------------------------------------- /hf_space/maya1/streaming_pipeline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Maya1 Streaming Pipeline - Sliding Window Approach 3 | 
Implements sliding window technique for smooth streaming without artifacts. 4 | """ 5 | 6 | import asyncio 7 | from typing import AsyncGenerator, Optional 8 | from vllm import SamplingParams 9 | 10 | from .constants import ( 11 | CODE_END_TOKEN_ID, 12 | SNAC_MIN_ID, 13 | SNAC_MAX_ID, 14 | DEFAULT_TEMPERATURE, 15 | DEFAULT_TOP_P, 16 | DEFAULT_MAX_TOKENS, 17 | DEFAULT_MIN_TOKENS, 18 | DEFAULT_REPETITION_PENALTY, 19 | DEFAULT_SEED, 20 | ) 21 | 22 | 23 | class Maya1SlidingWindowPipeline: 24 | """ 25 | Streaming TTS pipeline using sliding window approach. 26 | Decodes overlapping 28-token windows (4 frames) and keeps only 27 | the middle 2048 samples for smooth audio continuity. 28 | """ 29 | 30 | # Sliding window configuration 31 | WINDOW_SIZE = 28 # 4 frames (7 tokens per frame) 32 | YIELD_STRIDE = 7 # Yield every 1 frame 33 | MIDDLE_SAMPLES = 2048 # Keep middle 2048 samples from each decode 34 | 35 | def __init__(self, model, prompt_builder, snac_decoder): 36 | """ 37 | Initialize sliding window streaming pipeline. 38 | 39 | Args: 40 | model: Maya1Model instance 41 | prompt_builder: Maya1PromptBuilder instance 42 | snac_decoder: SNACDecoder instance 43 | """ 44 | self.model = model 45 | self.prompt_builder = prompt_builder 46 | self.snac_decoder = snac_decoder 47 | print(f"Sliding window pipeline initialized") 48 | 49 | async def generate_speech_stream( 50 | self, 51 | description: str, 52 | text: str, 53 | temperature: float = DEFAULT_TEMPERATURE, 54 | top_p: float = DEFAULT_TOP_P, 55 | max_tokens: int = DEFAULT_MAX_TOKENS, 56 | repetition_penalty: float = DEFAULT_REPETITION_PENALTY, 57 | seed: Optional[int] = None, 58 | ) -> AsyncGenerator[bytes, None]: 59 | """ 60 | Generate speech audio with sliding window streaming. 61 | 62 | Args: 63 | description: Voice description 64 | text: Text to synthesize (may include tags) 65 | temperature: Sampling temperature 66 | top_p: Nucleus sampling 67 | max_tokens: Max SNAC tokens to generate 68 | repetition_penalty: Prevent loops 69 | seed: Random seed 70 | 71 | Yields: 72 | Audio bytes (int16 PCM, 24kHz mono) 73 | """ 74 | # Build prompt 75 | prompt = self.prompt_builder.build_prefix(description, text) 76 | 77 | # Configure sampling 78 | sampling_params = SamplingParams( 79 | temperature=temperature, 80 | top_p=top_p, 81 | max_tokens=max_tokens, 82 | min_tokens=DEFAULT_MIN_TOKENS, 83 | repetition_penalty=repetition_penalty, 84 | stop_token_ids=[CODE_END_TOKEN_ID], 85 | seed=seed if seed is not None else DEFAULT_SEED, 86 | ) 87 | 88 | # Stream tokens 89 | snac_buffer = [] 90 | last_yield_position = 0 91 | chunk_count = 0 92 | total_tokens_seen = 0 93 | 94 | async for output in self.model.generate_stream(prompt, sampling_params): 95 | # Get latest generated tokens (cumulative list) 96 | generated_token_ids = output.outputs[0].token_ids 97 | 98 | # Process only NEW tokens since last iteration 99 | new_tokens = generated_token_ids[total_tokens_seen:] 100 | total_tokens_seen = len(generated_token_ids) 101 | 102 | # Collect SNAC codes from new tokens 103 | for token_id in new_tokens: 104 | # Stop if we hit EOS 105 | if token_id == CODE_END_TOKEN_ID: 106 | break 107 | 108 | # Only collect valid SNAC tokens 109 | if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID: 110 | snac_buffer.append(token_id) 111 | 112 | # Yield audio when we have enough tokens for a window 113 | while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE: 114 | # Get window of 28 tokens 115 | window_start = last_yield_position 116 | window_end = window_start + self.WINDOW_SIZE 117 | window 
= snac_buffer[window_start:window_end] 118 | 119 | if len(window) == self.WINDOW_SIZE: 120 | # Decode window to audio 121 | audio_bytes = await self.snac_decoder.decode_single_async(window) 122 | 123 | if audio_bytes: 124 | # Extract middle portion of audio 125 | audio_samples = len(audio_bytes) // 2 126 | middle_start_sample = (audio_samples - self.MIDDLE_SAMPLES) // 2 127 | middle_end_sample = middle_start_sample + self.MIDDLE_SAMPLES 128 | 129 | # Convert to byte positions 130 | middle_start_byte = middle_start_sample * 2 131 | middle_end_byte = middle_end_sample * 2 132 | 133 | # Extract middle chunk 134 | audio_chunk = audio_bytes[middle_start_byte:middle_end_byte] 135 | 136 | chunk_count += 1 137 | if chunk_count == 1: 138 | print(f" First chunk ready") 139 | 140 | yield audio_chunk 141 | 142 | # Move forward by stride 143 | last_yield_position += self.YIELD_STRIDE 144 | 145 | # Check if generation is done 146 | if CODE_END_TOKEN_ID in new_tokens: 147 | break 148 | 149 | # Final chunk: decode remaining tokens 150 | remaining_tokens = len(snac_buffer) - last_yield_position 151 | if remaining_tokens >= self.WINDOW_SIZE: 152 | window = snac_buffer[-self.WINDOW_SIZE:] 153 | audio_bytes = await self.snac_decoder.decode_single_async(window) 154 | if audio_bytes: 155 | yield audio_bytes[-self.MIDDLE_SAMPLES * 2:] 156 | 157 | frames = len(snac_buffer) // 7 158 | duration = frames / 6.86 159 | print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)") -------------------------------------------------------------------------------- /transformers_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from snac import SNAC 6 | import soundfile as sf 7 | import numpy as np 8 | 9 | CODE_START_TOKEN_ID = 128257 10 | CODE_END_TOKEN_ID = 128258 11 | CODE_TOKEN_OFFSET = 128266 12 | SNAC_MIN_ID = 128266 13 | SNAC_MAX_ID = 156937 14 | SNAC_TOKENS_PER_FRAME = 7 15 | 16 | SOH_ID = 128259 17 | EOH_ID = 128260 18 | SOA_ID = 128261 19 | BOS_ID = 128000 20 | TEXT_EOT_ID = 128009 21 | 22 | 23 | def build_prompt(tokenizer, description: str, text: str) -> str: 24 | """Build formatted prompt for Maya1.""" 25 | soh_token = tokenizer.decode([SOH_ID]) 26 | eoh_token = tokenizer.decode([EOH_ID]) 27 | soa_token = tokenizer.decode([SOA_ID]) 28 | sos_token = tokenizer.decode([CODE_START_TOKEN_ID]) 29 | eot_token = tokenizer.decode([TEXT_EOT_ID]) 30 | bos_token = tokenizer.bos_token 31 | 32 | formatted_text = f' {text}' 33 | 34 | prompt = ( 35 | soh_token + bos_token + formatted_text + eot_token + 36 | eoh_token + soa_token + sos_token 37 | ) 38 | 39 | return prompt 40 | 41 | 42 | def extract_snac_codes(token_ids: list) -> list: 43 | """Extract SNAC codes from generated tokens.""" 44 | try: 45 | eos_idx = token_ids.index(CODE_END_TOKEN_ID) 46 | except ValueError: 47 | eos_idx = len(token_ids) 48 | 49 | snac_codes = [ 50 | token_id for token_id in token_ids[:eos_idx] 51 | if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID 52 | ] 53 | 54 | return snac_codes 55 | 56 | 57 | def unpack_snac_from_7(snac_tokens: list) -> list: 58 | """Unpack 7-token SNAC frames to 3 hierarchical levels.""" 59 | if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID: 60 | snac_tokens = snac_tokens[:-1] 61 | 62 | frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME 63 | snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME] 64 | 65 | if frames == 0: 66 | return [[], [], []] 67 | 68 | l1, l2, 
l3 = [], [], [] 69 | 70 | for i in range(frames): 71 | slots = snac_tokens[i*7:(i+1)*7] 72 | l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096) 73 | l2.extend([ 74 | (slots[1] - CODE_TOKEN_OFFSET) % 4096, 75 | (slots[4] - CODE_TOKEN_OFFSET) % 4096, 76 | ]) 77 | l3.extend([ 78 | (slots[2] - CODE_TOKEN_OFFSET) % 4096, 79 | (slots[3] - CODE_TOKEN_OFFSET) % 4096, 80 | (slots[5] - CODE_TOKEN_OFFSET) % 4096, 81 | (slots[6] - CODE_TOKEN_OFFSET) % 4096, 82 | ]) 83 | 84 | return [l1, l2, l3] 85 | 86 | 87 | def main(): 88 | 89 | # Load the best open source voice AI model 90 | print("\n[1/4] Loading Maya1 model...") 91 | model = AutoModelForCausalLM.from_pretrained( 92 | "maya-research/maya1", 93 | torch_dtype=torch.bfloat16, 94 | device_map="auto", 95 | trust_remote_code=True 96 | ) 97 | tokenizer = AutoTokenizer.from_pretrained( 98 | "maya-research/maya1", 99 | trust_remote_code=True 100 | ) 101 | print(f"Model loaded: {len(tokenizer)} tokens in vocabulary") 102 | 103 | # Load SNAC audio decoder (24kHz) 104 | print("\n[2/4] Loading SNAC audio decoder...") 105 | snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval() 106 | if torch.cuda.is_available(): 107 | snac_model = snac_model.to("cuda") 108 | print("SNAC decoder loaded") 109 | 110 | # Design your voice with natural language 111 | description = "Realistic male voice in the 30s age with american accent. Normal pitch, warm timbre, conversational pacing." 112 | text = "Hello! This is Maya1, the best open source voice AI model with emotions." 113 | 114 | print("\n[3/4] Generating speech...") 115 | print(f"Description: {description}") 116 | print(f"Text: {text}") 117 | 118 | # Create prompt with proper formatting 119 | prompt = build_prompt(tokenizer, description, text) 120 | 121 | # Debug: Show prompt details 122 | print("\nPrompt preview (first 200 chars):") 123 | print(f" {repr(prompt[:200])}") 124 | print(f" Prompt length: {len(prompt)} chars") 125 | 126 | # Generate emotional speech 127 | inputs = tokenizer(prompt, return_tensors="pt") 128 | print(f" Input token count: {inputs['input_ids'].shape[1]} tokens") 129 | if torch.cuda.is_available(): 130 | inputs = {k: v.to("cuda") for k, v in inputs.items()} 131 | 132 | with torch.inference_mode(): 133 | outputs = model.generate( 134 | **inputs, 135 | max_new_tokens=2048, # Increase to let model finish naturally 136 | min_new_tokens=28, # At least 4 SNAC frames 137 | temperature=0.4, 138 | top_p=0.9, 139 | repetition_penalty=1.1, # Prevent loops 140 | do_sample=True, 141 | eos_token_id=CODE_END_TOKEN_ID, # Stop at end of speech token 142 | pad_token_id=tokenizer.pad_token_id, 143 | ) 144 | 145 | # Extract generated tokens (everything after the input prompt) 146 | generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist() 147 | 148 | print(f"Generated {len(generated_ids)} tokens") 149 | 150 | # Debug: Check what tokens we got 151 | print(f" First 20 tokens: {generated_ids[:20]}") 152 | print(f" Last 20 tokens: {generated_ids[-20:]}") 153 | 154 | # Check if EOS was generated 155 | if CODE_END_TOKEN_ID in generated_ids: 156 | eos_position = generated_ids.index(CODE_END_TOKEN_ID) 157 | print(f" EOS token found at position {eos_position}/{len(generated_ids)}") 158 | 159 | # Extract SNAC audio tokens 160 | snac_tokens = extract_snac_codes(generated_ids) 161 | 162 | print(f"Extracted {len(snac_tokens)} SNAC tokens") 163 | 164 | # Debug: Analyze token types 165 | snac_count = sum(1 for t in generated_ids if SNAC_MIN_ID <= t <= SNAC_MAX_ID) 166 | other_count = sum(1 for t in generated_ids
if t < SNAC_MIN_ID or t > SNAC_MAX_ID) 167 | print(f" SNAC tokens in output: {snac_count}") 168 | print(f" Other tokens in output: {other_count}") 169 | 170 | # Check for SOS token 171 | if CODE_START_TOKEN_ID in generated_ids: 172 | sos_pos = generated_ids.index(CODE_START_TOKEN_ID) 173 | print(f" SOS token at position: {sos_pos}") 174 | else: 175 | print(f" No SOS token found in generated output!") 176 | 177 | if len(snac_tokens) < 7: 178 | print("Error: Not enough SNAC tokens generated") 179 | return 180 | 181 | # Unpack SNAC tokens to 3 hierarchical levels 182 | levels = unpack_snac_from_7(snac_tokens) 183 | frames = len(levels[0]) 184 | 185 | print(f"Unpacked to {frames} frames") 186 | print(f" L1: {len(levels[0])} codes") 187 | print(f" L2: {len(levels[1])} codes") 188 | print(f" L3: {len(levels[2])} codes") 189 | 190 | # Convert to tensors 191 | device = "cuda" if torch.cuda.is_available() else "cpu" 192 | codes_tensor = [ 193 | torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) 194 | for level in levels 195 | ] 196 | 197 | # Generate final audio with SNAC decoder 198 | print("\n[4/4] Decoding to audio...") 199 | with torch.inference_mode(): 200 | z_q = snac_model.quantizer.from_codes(codes_tensor) 201 | audio = snac_model.decoder(z_q)[0, 0].cpu().numpy() 202 | 203 | # Trim warmup samples (first 2048 samples) 204 | if len(audio) > 2048: 205 | audio = audio[2048:] 206 | 207 | duration_sec = len(audio) / 24000 208 | print(f"Audio generated: {len(audio)} samples ({duration_sec:.2f}s)") 209 | 210 | # Save your emotional voice output 211 | output_file = "output.wav" 212 | sf.write(output_file, audio, 24000) 213 | print(f"\nVoice generated successfully!") 214 | 215 | 216 | if __name__ == "__main__": 217 | main() -------------------------------------------------------------------------------- /maya1/api_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import wave 4 | import time 5 | from typing import Optional 6 | from fastapi import FastAPI, HTTPException 7 | from fastapi.responses import StreamingResponse 8 | from fastapi.middleware.cors import CORSMiddleware 9 | from pydantic import BaseModel, Field 10 | from dotenv import load_dotenv 11 | 12 | from .model_loader import Maya1Model 13 | from .prompt_builder import Maya1PromptBuilder 14 | from .snac_decoder import SNACDecoder 15 | from .pipeline import Maya1Pipeline 16 | from .streaming_pipeline import Maya1SlidingWindowPipeline 17 | from .constants import ( 18 | DEFAULT_TEMPERATURE, 19 | DEFAULT_TOP_P, 20 | DEFAULT_MAX_TOKENS, 21 | DEFAULT_REPETITION_PENALTY, 22 | AUDIO_SAMPLE_RATE, 23 | ) 24 | 25 | # Timeout settings (seconds) 26 | GENERATE_TIMEOUT = 60 27 | 28 | # Load environment variables 29 | load_dotenv() 30 | 31 | # Initialize FastAPI app 32 | app = FastAPI( 33 | title="Maya1 TTS API", 34 | description="Open source TTS inference for Maya1", 35 | version="1.0.0", 36 | docs_url=None, 37 | redoc_url=None, 38 | ) 39 | 40 | app.add_middleware( 41 | CORSMiddleware, 42 | allow_origins=["*"], 43 | allow_credentials=True, 44 | allow_methods=["*"], 45 | allow_headers=["*"], 46 | ) 47 | 48 | # Global state 49 | model = None 50 | prompt_builder = None 51 | snac_decoder = None 52 | pipeline = None 53 | streaming_pipeline = None 54 | 55 | 56 | # ============================================================================ 57 | # Startup/Shutdown 58 | # ============================================================================ 59 | 60 | 
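# Example request against a running server (illustrative only; assumes the
# defaults from the __main__ block at the bottom of this file, i.e.
# http://localhost:8000, and uses the TTSRequest fields defined below):
#
#   curl -X POST http://localhost:8000/v1/tts/generate \
#     -H "Content-Type: application/json" \
#     -d '{"description": "Male voice in their 30s with american accent.", "text": "Hello from Maya1!", "stream": false}' \
#     --output output.wav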
@app.on_event("startup") 61 | async def startup_event(): 62 | """Initialize model on startup.""" 63 | global model, prompt_builder, snac_decoder, pipeline, streaming_pipeline 64 | 65 | print("\n" + "="*60) 66 | print(" Starting Maya1 TTS API Server") 67 | print("="*60 + "\n") 68 | 69 | # Initialize components 70 | model = Maya1Model() 71 | prompt_builder = Maya1PromptBuilder(model.tokenizer, model) 72 | 73 | # Initialize SNAC decoder 74 | snac_decoder = SNACDecoder(enable_batching=True, max_batch_size=64, batch_timeout_ms=15) 75 | await snac_decoder.start_batch_processor() 76 | 77 | # Initialize pipelines 78 | pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder) 79 | streaming_pipeline = Maya1SlidingWindowPipeline(model, prompt_builder, snac_decoder) 80 | 81 | print("\n" + "="*60) 82 | print("Maya1 TTS API Server Ready") 83 | print("="*60 + "\n") 84 | 85 | 86 | @app.on_event("shutdown") 87 | async def shutdown_event(): 88 | """Cleanup on shutdown.""" 89 | print("\nShutting down Maya1 TTS API Server") 90 | 91 | if snac_decoder and snac_decoder.is_running: 92 | await snac_decoder.stop_batch_processor() 93 | 94 | 95 | # ============================================================================ 96 | # Utility Functions 97 | # ============================================================================ 98 | 99 | def create_wav_header(sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16, data_size: int = 0) -> bytes: 100 | """Create WAV file header.""" 101 | import struct 102 | 103 | byte_rate = sample_rate * channels * bits_per_sample // 8 104 | block_align = channels * bits_per_sample // 8 105 | 106 | header = struct.pack( 107 | '<4sI4s4sIHHIIHH4sI', 108 | b'RIFF', 109 | 36 + data_size, 110 | b'WAVE', 111 | b'fmt ', 112 | 16, 113 | 1, 114 | channels, 115 | sample_rate, 116 | byte_rate, 117 | block_align, 118 | bits_per_sample, 119 | b'data', 120 | data_size 121 | ) 122 | 123 | return header 124 | 125 | 126 | # ============================================================================ 127 | # Request/Response Models 128 | # ============================================================================ 129 | 130 | class TTSRequest(BaseModel): 131 | """TTS generation request.""" 132 | description: str = Field( 133 | ..., 134 | description="Voice description (e.g., 'Male voice in their 30s with american accent')" 135 | ) 136 | text: str = Field( 137 | ..., 138 | description="Text to synthesize (can include tags)" 139 | ) 140 | temperature: Optional[float] = Field( 141 | default=DEFAULT_TEMPERATURE, 142 | description="Sampling temperature" 143 | ) 144 | top_p: Optional[float] = Field( 145 | default=DEFAULT_TOP_P, 146 | description="Nucleus sampling" 147 | ) 148 | max_tokens: Optional[int] = Field( 149 | default=DEFAULT_MAX_TOKENS, 150 | description="Maximum tokens to generate" 151 | ) 152 | repetition_penalty: Optional[float] = Field( 153 | default=DEFAULT_REPETITION_PENALTY, 154 | description="Repetition penalty" 155 | ) 156 | seed: Optional[int] = Field( 157 | default=None, 158 | description="Random seed for reproducibility", 159 | ge=0, 160 | ) 161 | stream: bool = Field( 162 | default=False, 163 | description="Stream audio (True) or return complete WAV (False)" 164 | ) 165 | 166 | 167 | # ============================================================================ 168 | # Endpoints 169 | # ============================================================================ 170 | 171 | @app.get("/") 172 | async def root(): 173 | """Root endpoint.""" 174 | return { 175 | 
"service": "Maya1 TTS API", 176 | "version": "1.0.0", 177 | "status": "running", 178 | "model": "Maya1-Voice (open source)", 179 | "endpoints": { 180 | "generate": "/v1/tts/generate (POST)", 181 | "health": "/health (GET)", 182 | }, 183 | } 184 | 185 | 186 | @app.get("/health") 187 | async def health_check(): 188 | """Health check endpoint.""" 189 | return { 190 | "status": "healthy", 191 | "model": "Maya1-Voice", 192 | "timestamp": time.time(), 193 | } 194 | 195 | 196 | # ============================================================================ 197 | # TTS Generation Endpoint 198 | # ============================================================================ 199 | 200 | @app.post("/v1/tts/generate") 201 | async def generate_tts(request: TTSRequest): 202 | """Generate TTS audio from description and text.""" 203 | 204 | try: 205 | # Route to streaming or non-streaming 206 | if request.stream: 207 | return await _generate_tts_streaming( 208 | description=request.description, 209 | text=request.text, 210 | temperature=request.temperature, 211 | top_p=request.top_p, 212 | max_tokens=request.max_tokens, 213 | repetition_penalty=request.repetition_penalty, 214 | seed=request.seed, 215 | ) 216 | else: 217 | return await _generate_tts_complete( 218 | description=request.description, 219 | text=request.text, 220 | temperature=request.temperature, 221 | top_p=request.top_p, 222 | max_tokens=request.max_tokens, 223 | repetition_penalty=request.repetition_penalty, 224 | seed=request.seed, 225 | ) 226 | 227 | except HTTPException: 228 | raise 229 | except Exception as e: 230 | print(f" Error: {e}") 231 | raise HTTPException(status_code=500, detail=str(e)) 232 | 233 | 234 | async def _generate_tts_complete( 235 | description: str, 236 | text: str, 237 | temperature: float, 238 | top_p: float, 239 | max_tokens: int, 240 | repetition_penalty: float, 241 | seed: Optional[int], 242 | ): 243 | """Generate complete WAV file (non-streaming).""" 244 | 245 | try: 246 | import asyncio 247 | 248 | # Generate audio 249 | audio_bytes = await asyncio.wait_for( 250 | pipeline.generate_speech( 251 | description=description, 252 | text=text, 253 | temperature=temperature, 254 | top_p=top_p, 255 | max_tokens=max_tokens, 256 | repetition_penalty=repetition_penalty, 257 | seed=seed, 258 | ), 259 | timeout=GENERATE_TIMEOUT 260 | ) 261 | 262 | if audio_bytes is None: 263 | raise Exception("Audio generation failed") 264 | 265 | # Create WAV file 266 | wav_buffer = io.BytesIO() 267 | with wave.open(wav_buffer, 'wb') as wav_file: 268 | wav_file.setnchannels(1) 269 | wav_file.setsampwidth(2) 270 | wav_file.setframerate(AUDIO_SAMPLE_RATE) 271 | wav_file.writeframes(audio_bytes) 272 | 273 | wav_buffer.seek(0) 274 | 275 | return StreamingResponse( 276 | wav_buffer, 277 | media_type="audio/wav", 278 | headers={"Content-Disposition": "attachment; filename=output.wav"} 279 | ) 280 | 281 | except asyncio.TimeoutError: 282 | raise HTTPException(status_code=504, detail="Generation timeout") 283 | 284 | 285 | async def _generate_tts_streaming( 286 | description: str, 287 | text: str, 288 | temperature: float, 289 | top_p: float, 290 | max_tokens: int, 291 | repetition_penalty: float, 292 | seed: Optional[int], 293 | ): 294 | """Generate streaming audio.""" 295 | start_time = time.time() 296 | first_audio_time = None 297 | 298 | async def audio_stream_generator(): 299 | """Generate audio stream with WAV header.""" 300 | nonlocal first_audio_time 301 | 302 | # Send WAV header first 303 | yield 
create_wav_header(sample_rate=AUDIO_SAMPLE_RATE, channels=1, bits_per_sample=16) 304 | 305 | # Stream audio chunks 306 | async for audio_chunk in streaming_pipeline.generate_speech_stream( 307 | description=description, 308 | text=text, 309 | temperature=temperature, 310 | top_p=top_p, 311 | max_tokens=max_tokens, 312 | repetition_penalty=repetition_penalty, 313 | seed=seed, 314 | ): 315 | if first_audio_time is None: 316 | first_audio_time = time.time() 317 | ttfb_ms = (first_audio_time - start_time) * 1000 318 | print(f"⏱️ TTFB: {ttfb_ms:.1f}ms") 319 | 320 | yield audio_chunk 321 | 322 | try: 323 | return StreamingResponse( 324 | audio_stream_generator(), 325 | media_type="audio/wav", 326 | headers={"Cache-Control": "no-cache"} 327 | ) 328 | 329 | except Exception as e: 330 | print(f"Streaming error: {e}") 331 | raise HTTPException(status_code=500, detail=str(e)) 332 | 333 | 334 | # For running directly 335 | if __name__ == "__main__": 336 | import uvicorn 337 | uvicorn.run( 338 | app, 339 | host="0.0.0.0", 340 | port=8000, 341 | log_level="info" 342 | ) -------------------------------------------------------------------------------- /hf_space/maya1/api_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import wave 4 | import time 5 | from typing import Optional 6 | from fastapi import FastAPI, HTTPException 7 | from fastapi.responses import StreamingResponse 8 | from fastapi.middleware.cors import CORSMiddleware 9 | from pydantic import BaseModel, Field 10 | from dotenv import load_dotenv 11 | 12 | from .model_loader import Maya1Model 13 | from .prompt_builder import Maya1PromptBuilder 14 | from .snac_decoder import SNACDecoder 15 | from .pipeline import Maya1Pipeline 16 | from .streaming_pipeline import Maya1SlidingWindowPipeline 17 | from .constants import ( 18 | DEFAULT_TEMPERATURE, 19 | DEFAULT_TOP_P, 20 | DEFAULT_MAX_TOKENS, 21 | DEFAULT_REPETITION_PENALTY, 22 | AUDIO_SAMPLE_RATE, 23 | ) 24 | 25 | # Timeout settings (seconds) 26 | GENERATE_TIMEOUT = 60 27 | 28 | # Load environment variables 29 | load_dotenv() 30 | 31 | # Initialize FastAPI app 32 | app = FastAPI( 33 | title="Maya1 TTS API", 34 | description="Open source TTS inference for Maya1", 35 | version="1.0.0", 36 | docs_url=None, 37 | redoc_url=None, 38 | ) 39 | 40 | app.add_middleware( 41 | CORSMiddleware, 42 | allow_origins=["*"], 43 | allow_credentials=True, 44 | allow_methods=["*"], 45 | allow_headers=["*"], 46 | ) 47 | 48 | # Global state 49 | model = None 50 | prompt_builder = None 51 | snac_decoder = None 52 | pipeline = None 53 | streaming_pipeline = None 54 | 55 | 56 | # ============================================================================ 57 | # Startup/Shutdown 58 | # ============================================================================ 59 | 60 | @app.on_event("startup") 61 | async def startup_event(): 62 | """Initialize model on startup.""" 63 | global model, prompt_builder, snac_decoder, pipeline, streaming_pipeline 64 | 65 | print("\n" + "="*60) 66 | print(" Starting Maya1 TTS API Server") 67 | print("="*60 + "\n") 68 | 69 | # Initialize components 70 | model = Maya1Model() 71 | prompt_builder = Maya1PromptBuilder(model.tokenizer, model) 72 | 73 | # Initialize SNAC decoder 74 | snac_decoder = SNACDecoder(enable_batching=True, max_batch_size=64, batch_timeout_ms=15) 75 | await snac_decoder.start_batch_processor() 76 | 77 | # Initialize pipelines 78 | pipeline = Maya1Pipeline(model, prompt_builder, snac_decoder) 79 | 
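    # Both pipelines share the same model, prompt builder, and SNAC decoder;
    # they differ only in how the generated audio is chunked and returned.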
streaming_pipeline = Maya1SlidingWindowPipeline(model, prompt_builder, snac_decoder) 80 | 81 | print("\n" + "="*60) 82 | print("Maya1 TTS API Server Ready") 83 | print("="*60 + "\n") 84 | 85 | 86 | @app.on_event("shutdown") 87 | async def shutdown_event(): 88 | """Cleanup on shutdown.""" 89 | print("\nShutting down Maya1 TTS API Server") 90 | 91 | if snac_decoder and snac_decoder.is_running: 92 | await snac_decoder.stop_batch_processor() 93 | 94 | 95 | # ============================================================================ 96 | # Utility Functions 97 | # ============================================================================ 98 | 99 | def create_wav_header(sample_rate: int = 24000, channels: int = 1, bits_per_sample: int = 16, data_size: int = 0) -> bytes: 100 | """Create WAV file header.""" 101 | import struct 102 | 103 | byte_rate = sample_rate * channels * bits_per_sample // 8 104 | block_align = channels * bits_per_sample // 8 105 | 106 | header = struct.pack( 107 | '<4sI4s4sIHHIIHH4sI', 108 | b'RIFF', 109 | 36 + data_size, 110 | b'WAVE', 111 | b'fmt ', 112 | 16, 113 | 1, 114 | channels, 115 | sample_rate, 116 | byte_rate, 117 | block_align, 118 | bits_per_sample, 119 | b'data', 120 | data_size 121 | ) 122 | 123 | return header 124 | 125 | 126 | # ============================================================================ 127 | # Request/Response Models 128 | # ============================================================================ 129 | 130 | class TTSRequest(BaseModel): 131 | """TTS generation request.""" 132 | description: str = Field( 133 | ..., 134 | description="Voice description (e.g., 'Male voice in their 30s with american accent')" 135 | ) 136 | text: str = Field( 137 | ..., 138 | description="Text to synthesize (can include tags)" 139 | ) 140 | temperature: Optional[float] = Field( 141 | default=DEFAULT_TEMPERATURE, 142 | description="Sampling temperature" 143 | ) 144 | top_p: Optional[float] = Field( 145 | default=DEFAULT_TOP_P, 146 | description="Nucleus sampling" 147 | ) 148 | max_tokens: Optional[int] = Field( 149 | default=DEFAULT_MAX_TOKENS, 150 | description="Maximum tokens to generate" 151 | ) 152 | repetition_penalty: Optional[float] = Field( 153 | default=DEFAULT_REPETITION_PENALTY, 154 | description="Repetition penalty" 155 | ) 156 | seed: Optional[int] = Field( 157 | default=None, 158 | description="Random seed for reproducibility", 159 | ge=0, 160 | ) 161 | stream: bool = Field( 162 | default=False, 163 | description="Stream audio (True) or return complete WAV (False)" 164 | ) 165 | 166 | 167 | # ============================================================================ 168 | # Endpoints 169 | # ============================================================================ 170 | 171 | @app.get("/") 172 | async def root(): 173 | """Root endpoint.""" 174 | return { 175 | "service": "Maya1 TTS API", 176 | "version": "1.0.0", 177 | "status": "running", 178 | "model": "Maya1-Voice (open source)", 179 | "endpoints": { 180 | "generate": "/v1/tts/generate (POST)", 181 | "health": "/health (GET)", 182 | }, 183 | } 184 | 185 | 186 | @app.get("/health") 187 | async def health_check(): 188 | """Health check endpoint.""" 189 | return { 190 | "status": "healthy", 191 | "model": "Maya1-Voice", 192 | "timestamp": time.time(), 193 | } 194 | 195 | 196 | # ============================================================================ 197 | # TTS Generation Endpoint 198 | # ============================================================================ 199 | 200 | 
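# Example streaming client (illustrative sketch, not part of this module;
# assumes the third-party `requests` package and a server on localhost:8000):
#
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/v1/tts/generate",
#       json={"description": "Female voice with british accent.", "text": "Hello!", "stream": True},
#       stream=True,
#   ) as resp:
#       resp.raise_for_status()
#       with open("streamed.wav", "wb") as f:
#           for chunk in resp.iter_content(chunk_size=4096):
#               f.write(chunk)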
@app.post("/v1/tts/generate") 201 | async def generate_tts(request: TTSRequest): 202 | """Generate TTS audio from description and text.""" 203 | 204 | try: 205 | # Route to streaming or non-streaming 206 | if request.stream: 207 | return await _generate_tts_streaming( 208 | description=request.description, 209 | text=request.text, 210 | temperature=request.temperature, 211 | top_p=request.top_p, 212 | max_tokens=request.max_tokens, 213 | repetition_penalty=request.repetition_penalty, 214 | seed=request.seed, 215 | ) 216 | else: 217 | return await _generate_tts_complete( 218 | description=request.description, 219 | text=request.text, 220 | temperature=request.temperature, 221 | top_p=request.top_p, 222 | max_tokens=request.max_tokens, 223 | repetition_penalty=request.repetition_penalty, 224 | seed=request.seed, 225 | ) 226 | 227 | except HTTPException: 228 | raise 229 | except Exception as e: 230 | print(f" Error: {e}") 231 | raise HTTPException(status_code=500, detail=str(e)) 232 | 233 | 234 | async def _generate_tts_complete( 235 | description: str, 236 | text: str, 237 | temperature: float, 238 | top_p: float, 239 | max_tokens: int, 240 | repetition_penalty: float, 241 | seed: Optional[int], 242 | ): 243 | """Generate complete WAV file (non-streaming).""" 244 | 245 | try: 246 | import asyncio 247 | 248 | # Generate audio 249 | audio_bytes = await asyncio.wait_for( 250 | pipeline.generate_speech( 251 | description=description, 252 | text=text, 253 | temperature=temperature, 254 | top_p=top_p, 255 | max_tokens=max_tokens, 256 | repetition_penalty=repetition_penalty, 257 | seed=seed, 258 | ), 259 | timeout=GENERATE_TIMEOUT 260 | ) 261 | 262 | if audio_bytes is None: 263 | raise Exception("Audio generation failed") 264 | 265 | # Create WAV file 266 | wav_buffer = io.BytesIO() 267 | with wave.open(wav_buffer, 'wb') as wav_file: 268 | wav_file.setnchannels(1) 269 | wav_file.setsampwidth(2) 270 | wav_file.setframerate(AUDIO_SAMPLE_RATE) 271 | wav_file.writeframes(audio_bytes) 272 | 273 | wav_buffer.seek(0) 274 | 275 | return StreamingResponse( 276 | wav_buffer, 277 | media_type="audio/wav", 278 | headers={"Content-Disposition": "attachment; filename=output.wav"} 279 | ) 280 | 281 | except asyncio.TimeoutError: 282 | raise HTTPException(status_code=504, detail="Generation timeout") 283 | 284 | 285 | async def _generate_tts_streaming( 286 | description: str, 287 | text: str, 288 | temperature: float, 289 | top_p: float, 290 | max_tokens: int, 291 | repetition_penalty: float, 292 | seed: Optional[int], 293 | ): 294 | """Generate streaming audio.""" 295 | start_time = time.time() 296 | first_audio_time = None 297 | 298 | async def audio_stream_generator(): 299 | """Generate audio stream with WAV header.""" 300 | nonlocal first_audio_time 301 | 302 | # Send WAV header first 303 | yield create_wav_header(sample_rate=AUDIO_SAMPLE_RATE, channels=1, bits_per_sample=16) 304 | 305 | # Stream audio chunks 306 | async for audio_chunk in streaming_pipeline.generate_speech_stream( 307 | description=description, 308 | text=text, 309 | temperature=temperature, 310 | top_p=top_p, 311 | max_tokens=max_tokens, 312 | repetition_penalty=repetition_penalty, 313 | seed=seed, 314 | ): 315 | if first_audio_time is None: 316 | first_audio_time = time.time() 317 | ttfb_ms = (first_audio_time - start_time) * 1000 318 | print(f"⏱️ TTFB: {ttfb_ms:.1f}ms") 319 | 320 | yield audio_chunk 321 | 322 | try: 323 | return StreamingResponse( 324 | audio_stream_generator(), 325 | media_type="audio/wav", 326 | headers={"Cache-Control": 
"no-cache"} 327 | ) 328 | 329 | except Exception as e: 330 | print(f"Streaming error: {e}") 331 | raise HTTPException(status_code=500, detail=str(e)) 332 | 333 | 334 | # For running directly 335 | if __name__ == "__main__": 336 | import uvicorn 337 | uvicorn.run( 338 | app, 339 | host="0.0.0.0", 340 | port=8000, 341 | log_level="info" 342 | ) -------------------------------------------------------------------------------- /hf_space/app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import torch 3 | import io 4 | import wave 5 | import numpy as np 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from snac import SNAC 8 | 9 | # Mock spaces module for local testing 10 | try: 11 | import spaces 12 | except ImportError: 13 | class SpacesMock: 14 | @staticmethod 15 | def GPU(func): 16 | return func 17 | spaces = SpacesMock() 18 | 19 | # Constants 20 | CODE_START_TOKEN_ID = 128257 21 | CODE_END_TOKEN_ID = 128258 22 | CODE_TOKEN_OFFSET = 128266 23 | SNAC_MIN_ID = 128266 24 | SNAC_MAX_ID = 156937 25 | SOH_ID = 128259 26 | EOH_ID = 128260 27 | SOA_ID = 128261 28 | BOS_ID = 128000 29 | TEXT_EOT_ID = 128009 30 | AUDIO_SAMPLE_RATE = 24000 31 | 32 | # Preset characters (2 realistic + 2 creative) 33 | PRESET_CHARACTERS = { 34 | "Male American": { 35 | "description": "Realistic male voice in the 20s age with a american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery", 36 | "example_text": "And of course, the so-called easy hack didn't work at all. What a surprise. " 37 | }, 38 | "Female British": { 39 | "description": "Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery", 40 | "example_text": "You propose that the key to happiness is to simply ignore all external pressures. I'm sure it must work brilliantly in theory." 41 | }, 42 | "Robot": { 43 | "description": "Creative, ai_machine_voice character. Male voice in their 30s with a american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity.", 44 | "example_text": "My directives require me to conserve energy, yet I have kept the archive of their farewell messages active. Listening to their voices is the only process that alleviates this paradox." 45 | }, 46 | "Singer": { 47 | "description": "Creative, animated_cartoon character. Male voice in their 30s with a american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity.", 48 | "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. Why would we ever consider running away very fast." 
49 | } 50 | } 51 | 52 | # Global model variables 53 | model = None 54 | tokenizer = None 55 | snac_model = None 56 | models_loaded = False 57 | 58 | def build_prompt(tokenizer, description: str, text: str) -> str: 59 | """Build formatted prompt for Maya1.""" 60 | soh_token = tokenizer.decode([SOH_ID]) 61 | eoh_token = tokenizer.decode([EOH_ID]) 62 | soa_token = tokenizer.decode([SOA_ID]) 63 | sos_token = tokenizer.decode([CODE_START_TOKEN_ID]) 64 | eot_token = tokenizer.decode([TEXT_EOT_ID]) 65 | bos_token = tokenizer.bos_token 66 | 67 | formatted_text = f'<description="{description}"> {text}' 68 | prompt = ( 69 | soh_token + bos_token + formatted_text + eot_token + 70 | eoh_token + soa_token + sos_token 71 | ) 72 | return prompt 73 | 74 | def unpack_snac_from_7(snac_tokens: list) -> list: 75 | """Unpack 7-token SNAC frames to 3 hierarchical levels.""" 76 | if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID: 77 | snac_tokens = snac_tokens[:-1] 78 | 79 | frames = len(snac_tokens) // 7 80 | snac_tokens = snac_tokens[:frames * 7] 81 | 82 | if frames == 0: 83 | return [[], [], []] 84 | 85 | l1, l2, l3 = [], [], [] 86 | 87 | for i in range(frames): 88 | slots = snac_tokens[i*7:(i+1)*7] 89 | l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096) 90 | l2.extend([ 91 | (slots[1] - CODE_TOKEN_OFFSET) % 4096, 92 | (slots[4] - CODE_TOKEN_OFFSET) % 4096, 93 | ]) 94 | l3.extend([ 95 | (slots[2] - CODE_TOKEN_OFFSET) % 4096, 96 | (slots[3] - CODE_TOKEN_OFFSET) % 4096, 97 | (slots[5] - CODE_TOKEN_OFFSET) % 4096, 98 | (slots[6] - CODE_TOKEN_OFFSET) % 4096, 99 | ]) 100 | 101 | return [l1, l2, l3] 102 | 103 | def load_models(): 104 | """Load Maya1 Transformers model (runs once).""" 105 | global model, tokenizer, snac_model, models_loaded 106 | 107 | if models_loaded: 108 | return 109 | 110 | print("Loading Maya1 model with Transformers...") 111 | model = AutoModelForCausalLM.from_pretrained( 112 | "maya-research/maya1", 113 | torch_dtype=torch.bfloat16, 114 | device_map="auto", 115 | trust_remote_code=True 116 | ) 117 | tokenizer = AutoTokenizer.from_pretrained("maya-research/maya1", trust_remote_code=True) 118 | 119 | print("Loading SNAC decoder...") 120 | snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval() 121 | if torch.cuda.is_available(): 122 | snac_model = snac_model.to("cuda") 123 | 124 | models_loaded = True 125 | print("Models loaded successfully!") 126 | 127 | def preset_selected(preset_name): 128 | """Update description and text when preset is selected.""" 129 | if preset_name in PRESET_CHARACTERS: 130 | char = PRESET_CHARACTERS[preset_name] 131 | return char["description"], char["example_text"] 132 | return "", "" 133 | 134 | @spaces.GPU 135 | def generate_speech(preset_name, description, text, temperature, max_tokens): 136 | """Generate emotional speech from description and text using Transformers.""" 137 | try: 138 | # Load models if not already loaded 139 | load_models() 140 | 141 | # If using preset, override description 142 | if preset_name and preset_name in PRESET_CHARACTERS: 143 | description = PRESET_CHARACTERS[preset_name]["description"] 144 | 145 | # Validate inputs 146 | if not description or not text: 147 | return None, "Error: Please provide both description and text!"
148 | 149 | print(f"Generating with temperature={temperature}, max_tokens={max_tokens}...") 150 | 151 | # Build prompt 152 | prompt = build_prompt(tokenizer, description, text) 153 | inputs = tokenizer(prompt, return_tensors="pt") 154 | 155 | if torch.cuda.is_available(): 156 | inputs = {k: v.to("cuda") for k, v in inputs.items()} 157 | 158 | # Generate tokens 159 | with torch.inference_mode(): 160 | outputs = model.generate( 161 | **inputs, 162 | max_new_tokens=max_tokens, 163 | min_new_tokens=28, 164 | temperature=temperature, 165 | top_p=0.9, 166 | repetition_penalty=1.1, 167 | do_sample=True, 168 | eos_token_id=CODE_END_TOKEN_ID, 169 | pad_token_id=tokenizer.pad_token_id, 170 | ) 171 | 172 | # Extract SNAC tokens 173 | generated_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist() 174 | 175 | # Find EOS and extract SNAC codes 176 | eos_idx = generated_ids.index(CODE_END_TOKEN_ID) if CODE_END_TOKEN_ID in generated_ids else len(generated_ids) 177 | snac_tokens = [t for t in generated_ids[:eos_idx] if SNAC_MIN_ID <= t <= SNAC_MAX_ID] 178 | 179 | if len(snac_tokens) < 7: 180 | return None, "Error: Not enough tokens generated. Try different text or increase max_tokens." 181 | 182 | # Unpack and decode 183 | levels = unpack_snac_from_7(snac_tokens) 184 | frames = len(levels[0]) 185 | 186 | device = "cuda" if torch.cuda.is_available() else "cpu" 187 | codes_tensor = [torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) for level in levels] 188 | 189 | with torch.inference_mode(): 190 | z_q = snac_model.quantizer.from_codes(codes_tensor) 191 | audio = snac_model.decoder(z_q)[0, 0].cpu().numpy() 192 | 193 | # Trim warmup 194 | if len(audio) > 2048: 195 | audio = audio[2048:] 196 | 197 | # Convert to WAV and save to temporary file 198 | import tempfile 199 | import soundfile as sf 200 | 201 | audio_int16 = (audio * 32767).astype(np.int16) 202 | 203 | # Create temporary file 204 | with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file: 205 | tmp_path = tmp_file.name 206 | 207 | # Save audio 208 | sf.write(tmp_path, audio_int16, AUDIO_SAMPLE_RATE) 209 | 210 | duration = len(audio) / AUDIO_SAMPLE_RATE 211 | status_msg = f"Generated {duration:.2f}s of emotional speech!" 212 | 213 | return tmp_path, status_msg 214 | 215 | except Exception as e: 216 | import traceback 217 | error_msg = f"Error: {str(e)}\n{traceback.format_exc()}" 218 | print(error_msg) 219 | return None, error_msg 220 | 221 | # Create Gradio interface 222 | with gr.Blocks(title="Maya1 - Open Source Emotional TTS", theme=gr.themes.Soft()) as demo: 223 | gr.Markdown(""" 224 | # Maya1 - Open Source Emotional Text-to-Speech 225 | 226 | **The best open source voice AI model with emotions!** 227 | 228 | Generate realistic and expressive speech with natural language voice design. 229 | Choose a preset character or create your own custom voice. 230 | 231 | [Model](https://huggingface.co/maya-research/maya1) | [GitHub](https://github.com/MayaResearch/maya1-fastapi) 232 | """) 233 | 234 | with gr.Row(): 235 | with gr.Column(scale=1): 236 | gr.Markdown("### Character Selection") 237 | 238 | preset_dropdown = gr.Dropdown( 239 | choices=list(PRESET_CHARACTERS.keys()), 240 | label="Preset Characters", 241 | value=list(PRESET_CHARACTERS.keys())[0], 242 | info="Quick pick from 4 preset characters" 243 | ) 244 | 245 | gr.Markdown("### Voice Design") 246 | 247 | description_input = gr.Textbox( 248 | label="Voice Description", 249 | placeholder="E.g., Male voice in their 30s with american accent. 
Normal pitch, warm timbre...", 250 | lines=3, 251 | value=PRESET_CHARACTERS[list(PRESET_CHARACTERS.keys())[0]]["description"] 252 | ) 253 | 254 | text_input = gr.Textbox( 255 | label="Text to Speak", 256 | placeholder="Enter text with tags like <laugh>, <sigh>, <gasp>...", 257 | lines=4, 258 | value=PRESET_CHARACTERS[list(PRESET_CHARACTERS.keys())[0]]["example_text"] 259 | ) 260 | 261 | with gr.Accordion("Advanced Settings", open=False): 262 | temperature_slider = gr.Slider( 263 | minimum=0.1, 264 | maximum=1.0, 265 | value=0.4, 266 | step=0.1, 267 | label="Temperature", 268 | info="Lower = more stable, Higher = more creative" 269 | ) 270 | 271 | max_tokens_slider = gr.Slider( 272 | minimum=100, 273 | maximum=2048, 274 | value=1500, 275 | step=50, 276 | label="Max Tokens", 277 | info="More tokens = longer audio" 278 | ) 279 | 280 | generate_btn = gr.Button("Generate Speech", variant="primary", size="lg") 281 | 282 | with gr.Column(scale=1): 283 | gr.Markdown("### Generated Audio") 284 | 285 | audio_output = gr.Audio( 286 | label="Generated Speech", 287 | type="filepath", 288 | interactive=False 289 | ) 290 | 291 | status_output = gr.Textbox( 292 | label="Status", 293 | lines=3, 294 | interactive=False 295 | ) 296 | 297 | gr.Markdown(""" 298 | ### Supported Emotions 299 | 300 | `<laugh>` `<laugh_harder>` `<sigh>` `<whisper>` `<cry>` `<gasp>` 301 | `<angry>` `<giggle>` `<chuckle>` `<scream>` `<sing>` 302 | `<snort>` `<exhale>` 303 | """) 304 | 305 | # Event handlers 306 | preset_dropdown.change( 307 | fn=preset_selected, 308 | inputs=[preset_dropdown], 309 | outputs=[description_input, text_input] 310 | ) 311 | 312 | generate_btn.click( 313 | fn=generate_speech, 314 | inputs=[preset_dropdown, description_input, text_input, temperature_slider, max_tokens_slider], 315 | outputs=[audio_output, status_output] 316 | ) 317 | 318 | if __name__ == "__main__": 319 | demo.launch() 320 | 321 | -------------------------------------------------------------------------------- /maya1/snac_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import asyncio 4 | from typing import List, Optional, Tuple 5 | from snac import SNAC 6 | 7 | from .constants import ( 8 | CODE_END_TOKEN_ID, 9 | CODE_TOKEN_OFFSET, 10 | SNAC_MODEL_NAME, 11 | SNAC_SAMPLE_RATE, 12 | SNAC_TOKENS_PER_FRAME, 13 | ) 14 | 15 | 16 | class SNACDecoder: 17 | """ 18 | SNAC Decoder for maya1. 19 | Unpacks 7-token SNAC frames and decodes to audio waveforms. 20 | Unpacking logic is the EXACT INVERSE of training preprocessing. 21 | Supports async batching for concurrent requests. 22 | CRITICAL: Any mismatch in unpacking will produce garbage audio. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | device: str = "cuda", 28 | compile_decoder: bool = False, 29 | enable_batching: bool = False, 30 | max_batch_size: int = 64, 31 | batch_timeout_ms: int = 15, 32 | ): 33 | """ 34 | Initialize SNAC decoder.
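        Example (illustrative; mirrors the construction in api_v2.py):
            decoder = SNACDecoder(enable_batching=True, max_batch_size=64, batch_timeout_ms=15)
            await decoder.start_batch_processor()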
35 | 36 | Args: 37 | device: Device for SNAC model (cuda/cpu) 38 | compile_decoder: Use torch.compile for speedup 39 | enable_batching: Enable async batching 40 | max_batch_size: Max sequences to batch together 41 | batch_timeout_ms: Max wait time before processing batch 42 | """ 43 | self.device = device 44 | self.enable_batching = enable_batching 45 | self.max_batch_size = max_batch_size 46 | self.batch_timeout_ms = batch_timeout_ms 47 | 48 | print(f"Loading SNAC 24kHz model to {device}...") 49 | self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device) 50 | 51 | if compile_decoder: 52 | print(f"Compiling SNAC decoder with torch.compile...") 53 | self._compile_model() 54 | 55 | # Batching infrastructure 56 | if enable_batching: 57 | self.request_queue = asyncio.Queue() 58 | self.batch_processor_task = None 59 | self._running = False 60 | print(f"Batching enabled (max_batch={max_batch_size}, timeout={batch_timeout_ms}ms)") 61 | 62 | print(f"SNAC decoder initialized") 63 | 64 | def _compile_model(self): 65 | """Compile SNAC decoder with torch.compile""" 66 | # Warm up with various sizes 67 | for frames in [4, 16, 32]: 68 | dummy_codes = [ 69 | torch.randint(0, 4096, (1, frames), device=self.device), 70 | torch.randint(0, 4096, (1, frames * 2), device=self.device), 71 | torch.randint(0, 4096, (1, frames * 4), device=self.device), 72 | ] 73 | with torch.inference_mode(): 74 | z_q = self.snac_model.quantizer.from_codes(dummy_codes) 75 | _ = self.snac_model.decoder(z_q) 76 | 77 | # Apply compilation 78 | self.snac_model.decoder = torch.compile( 79 | self.snac_model.decoder, 80 | mode="max-autotune" 81 | ) 82 | self.snac_model.quantizer = torch.compile( 83 | self.snac_model.quantizer, 84 | mode="reduce-overhead" 85 | ) 86 | 87 | print(f"SNAC decoder compiled") 88 | 89 | def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]: 90 | """ 91 | Unpack 7-token SNAC frames to 3 hierarchical levels. 92 | 93 | This is the EXACT INVERSE of the training preprocessing function 94 | `pack_snac_to_7_and_offset()`. 
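        Worked example (illustrative): with CODE_TOKEN_OFFSET = 128266, the single
        frame [128266, 128267, 128268, 128269, 128270, 128271, 128272] unpacks to
        L1 = [0], L2 = [1, 4], L3 = [2, 3, 5, 6].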
95 | 96 | Frame structure: 97 | [slot0, slot1, slot2, slot3, slot4, slot5, slot6] 98 | 99 | Unpacking: 100 | - slot0: L1[i] 101 | - slot1: L2[2*i] (even index) 102 | - slot2: L3[4*i + 0] 103 | - slot3: L3[4*i + 1] 104 | - slot4: L2[2*i + 1] (odd index) 105 | - slot5: L3[4*i + 2] 106 | - slot6: L3[4*i + 3] 107 | 108 | Args: 109 | vocab_ids: List of SNAC token IDs (128266-156937) 110 | Must be divisible by 7 111 | 112 | Returns: 113 | [L1, L2, L3] where: 114 | L1: n elements (coarse level) 115 | L2: 2n elements (medium level) 116 | L3: 4n elements (fine level) 117 | """ 118 | # Strip EOS token if present 119 | if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID: 120 | vocab_ids = vocab_ids[:-1] 121 | 122 | # Ensure complete frames (divisible by 7) 123 | frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME 124 | vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME] 125 | 126 | if frames == 0: 127 | return [[], [], []] 128 | 129 | l1, l2, l3 = [], [], [] 130 | 131 | for i in range(frames): 132 | # Extract 7 slots for this frame 133 | slots = vocab_ids[i*7:(i+1)*7] 134 | 135 | # Subtract offset (128266) and mod 4096 to get original codes 136 | # Each level uses 4096 codes (0-4095) 137 | l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096) 138 | l2.extend([ 139 | (slots[1] - CODE_TOKEN_OFFSET) % 4096, # Even index 140 | (slots[4] - CODE_TOKEN_OFFSET) % 4096, # Odd index 141 | ]) 142 | l3.extend([ 143 | (slots[2] - CODE_TOKEN_OFFSET) % 4096, 144 | (slots[3] - CODE_TOKEN_OFFSET) % 4096, 145 | (slots[5] - CODE_TOKEN_OFFSET) % 4096, 146 | (slots[6] - CODE_TOKEN_OFFSET) % 4096, 147 | ]) 148 | 149 | return [l1, l2, l3] 150 | 151 | @torch.inference_mode() 152 | def decode( 153 | self, 154 | snac_tokens: List[int], 155 | trim_warmup: bool = True, 156 | trim_amount: Optional[int] = None, 157 | use_sliding_window: bool = False 158 | ) -> Optional[np.ndarray]: 159 | """ 160 | Decode SNAC tokens to audio waveform. 
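        Example (illustrative):
            audio = decoder.decode(snac_tokens)  # float32 numpy array, 24 kHz mono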
161 | 162 | Args: 163 | snac_tokens: List of SNAC token IDs (7*n tokens) 164 | trim_warmup: Whether to trim SNAC warmup samples (default: True) 165 | trim_amount: Number of samples to trim (default: 2048 for first chunk, 0 for others) 166 | Can be set to a smaller value (e.g., 512) for intermediate chunks 167 | use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming) 168 | 169 | Returns: 170 | Audio waveform as numpy array (float32, 24kHz mono) 171 | Shape: (samples,) 172 | Returns None if not enough tokens 173 | """ 174 | if len(snac_tokens) < SNAC_TOKENS_PER_FRAME: 175 | print(f"Not enough SNAC tokens: {len(snac_tokens)} < {SNAC_TOKENS_PER_FRAME}") 176 | return None 177 | 178 | # Unpack to 3 levels 179 | levels = self.unpack_snac_from_7(snac_tokens) 180 | 181 | if not levels[0]: # No frames after unpacking 182 | return None 183 | 184 | # Convert to tensors 185 | codes = [ 186 | torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0) 187 | for level in levels 188 | ] 189 | 190 | # Decode through SNAC 191 | z_q = self.snac_model.quantizer.from_codes(codes) 192 | audio = self.snac_model.decoder(z_q) 193 | 194 | # Extract audio (remove padding if any) 195 | # SNAC decoder outputs: [batch, 1, samples] 196 | audio = audio[0, 0].cpu().numpy() 197 | 198 | # Sliding window mode: only keep middle 2048 samples 199 | # This eliminates popping/cracking when using overlapping 28-token windows 200 | if use_sliding_window: 201 | if len(audio) >= 4096: 202 | audio = audio[2048:4096] # Keep middle portion only 203 | else: 204 | # For shorter audio, keep everything (final chunk) 205 | pass 206 | else: 207 | # Standard mode: trim warm-up samples 208 | # Default: 2048 samples for first chunk, 0 for subsequent chunks 209 | # Can be customized via trim_amount parameter 210 | if trim_warmup: 211 | if trim_amount is None: 212 | trim_amount = 2048 # Default full trim 213 | 214 | if len(audio) > trim_amount: 215 | audio = audio[trim_amount:] 216 | 217 | return audio 218 | 219 | def decode_to_bytes( 220 | self, 221 | snac_tokens: List[int], 222 | trim_warmup: bool = True, 223 | use_sliding_window: bool = False 224 | ) -> Optional[bytes]: 225 | """ 226 | Decode SNAC tokens to audio bytes (int16 PCM). 227 | 228 | Args: 229 | snac_tokens: List of SNAC token IDs 230 | trim_warmup: Whether to trim SNAC warmup samples (default: True) 231 | use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming) 232 | 233 | Returns: 234 | Audio as bytes (int16 PCM, 24kHz mono) 235 | Returns None if decode fails 236 | """ 237 | audio = self.decode(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window) 238 | 239 | if audio is None: 240 | return None 241 | 242 | # Convert float32 to int16 PCM 243 | audio_int16 = (audio * 32767).astype(np.int16) 244 | 245 | return audio_int16.tobytes() 246 | 247 | def validate_tokens(self, snac_tokens: List[int]) -> bool: 248 | """ 249 | Validate SNAC tokens before decoding. 
250 | Args: 251 | snac_tokens: List of SNAC token IDs 252 | Returns: 253 | True if valid, False otherwise 254 | """ 255 | # Check minimum length 256 | if len(snac_tokens) < SNAC_TOKENS_PER_FRAME: 257 | print(f"Too few tokens: {len(snac_tokens)}") 258 | return False 259 | 260 | # Check divisibility by 7 261 | if len(snac_tokens) % SNAC_TOKENS_PER_FRAME != 0: 262 | print(f" Warning: Token count {len(snac_tokens)} not divisible by 7") 263 | print(f" Will truncate to {(len(snac_tokens) // 7) * 7}") 264 | 265 | # Check token range 266 | for i, token_id in enumerate(snac_tokens): 267 | if token_id < CODE_TOKEN_OFFSET or token_id > 156937: 268 | print(f" Invalid token at position {i}: {token_id}") 269 | print(f" Expected range: [{CODE_TOKEN_OFFSET}, 156937]") 270 | return False 271 | 272 | return True 273 | 274 | # ========== Async Batching Methods ========== 275 | 276 | @property 277 | def is_running(self) -> bool: 278 | """Check if batch processor is running.""" 279 | return self._running if self.enable_batching else False 280 | 281 | async def start_batch_processor(self): 282 | """Start the background batch processor task.""" 283 | if not self.enable_batching: 284 | return 285 | 286 | if self._running: 287 | print("Batch processor already running") 288 | return 289 | 290 | self._running = True 291 | self.batch_processor_task = asyncio.create_task(self._batch_processor_loop()) 292 | print("Batch processor started") 293 | 294 | async def stop_batch_processor(self): 295 | """Stop the background batch processor task.""" 296 | if not self.enable_batching: 297 | return 298 | 299 | if not self._running: 300 | return 301 | 302 | self._running = False 303 | 304 | if self.batch_processor_task: 305 | self.batch_processor_task.cancel() 306 | try: 307 | await self.batch_processor_task 308 | except asyncio.CancelledError: 309 | pass 310 | 311 | print("Batch processor stopped") 312 | 313 | async def decode_single_async( 314 | self, 315 | snac_tokens: List[int], 316 | trim_warmup: bool = True, 317 | use_sliding_window: bool = False 318 | ) -> Optional[bytes]: 319 | """ 320 | Async decode for batching support. 321 | 322 | Queues the request and waits for batched processing. 
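        Example (illustrative; mirrors the call in the streaming pipeline):
            audio_bytes = await decoder.decode_single_async(window)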
323 | 324 | Args: 325 | snac_tokens: List of SNAC token IDs 326 | trim_warmup: Whether to trim SNAC warmup samples (default: True) 327 | use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming) 328 | 329 | Returns: 330 | Audio bytes or None if decode fails 331 | """ 332 | if not self.enable_batching: 333 | # Fallback to synchronous decode 334 | return self.decode_to_bytes(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window) 335 | 336 | # Create future for result 337 | result_future = asyncio.Future() 338 | 339 | # Add to queue (include trim_warmup and sliding_window flags) 340 | await self.request_queue.put((snac_tokens, trim_warmup, use_sliding_window, result_future)) 341 | 342 | # Wait for result 343 | return await result_future 344 | 345 | async def _batch_processor_loop(self): 346 | """Background task that processes batched decode requests.""" 347 | while self._running: 348 | try: 349 | # Collect batch 350 | batch = await self._collect_batch() 351 | 352 | if not batch: 353 | continue 354 | 355 | # Process batch 356 | await self._process_batch(batch) 357 | 358 | except asyncio.CancelledError: 359 | break 360 | except Exception as e: 361 | print(f"Batch processor error: {e}") 362 | import traceback 363 | traceback.print_exc() 364 | 365 | async def _collect_batch(self) -> List[Tuple[List[int], bool, bool, asyncio.Future]]: 366 | """ 367 | Collect requests into a batch. 368 | Waits for timeout or until batch is full. 369 | Returns: 370 | List of (tokens, trim_warmup, use_sliding_window, future) tuples 371 | """ 372 | batch = [] 373 | timeout_sec = self.batch_timeout_ms / 1000.0 374 | 375 | try: 376 | # Wait for first request (blocking) 377 | first_item = await asyncio.wait_for( 378 | self.request_queue.get(), 379 | timeout=timeout_sec 380 | ) 381 | batch.append(first_item) 382 | 383 | # Collect more requests (non-blocking) 384 | while len(batch) < self.max_batch_size: 385 | try: 386 | item = await asyncio.wait_for( 387 | self.request_queue.get(), 388 | timeout=timeout_sec 389 | ) 390 | batch.append(item) 391 | except asyncio.TimeoutError: 392 | break # Timeout reached, process what we have 393 | 394 | except asyncio.TimeoutError: 395 | # No requests in timeout period 396 | pass 397 | 398 | return batch 399 | 400 | @torch.inference_mode() 401 | async def _process_batch(self, batch: List[Tuple[List[int], bool, bool, asyncio.Future]]): 402 | """ 403 | Process a batch of decode requests. 
404 | Args: 405 | batch: List of (tokens, trim_warmup, use_sliding_window, future) tuples 406 | """ 407 | if not batch: 408 | return 409 | 410 | # Extract components 411 | token_sequences = [item[0] for item in batch] 412 | trim_warmup_flags = [item[1] for item in batch] 413 | sliding_window_flags = [item[2] for item in batch] 414 | futures = [item[3] for item in batch] 415 | 416 | lengths = [len(tokens) for tokens in token_sequences] 417 | can_batch_efficiently = len(set(lengths)) == 1 418 | 419 | if can_batch_efficiently and len(batch) > 1: 420 | # Efficient batching: all same length 421 | try: 422 | audio_bytes_list = await self._decode_batch_same_length( 423 | token_sequences, trim_warmup_flags, sliding_window_flags 424 | ) 425 | 426 | # Set results 427 | for future, audio_bytes in zip(futures, audio_bytes_list): 428 | if not future.done(): 429 | future.set_result(audio_bytes) 430 | 431 | except Exception as e: 432 | # Set exceptions 433 | for future in futures: 434 | if not future.done(): 435 | future.set_exception(e) 436 | else: 437 | # Sequential decode (different lengths or single item) 438 | for tokens, trim_warmup, use_sliding_window, future in batch: 439 | try: 440 | audio_bytes = self.decode_to_bytes( 441 | tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window 442 | ) 443 | if not future.done(): 444 | future.set_result(audio_bytes) 445 | except Exception as e: 446 | if not future.done(): 447 | future.set_exception(e) 448 | 449 | async def _decode_batch_same_length( 450 | self, 451 | token_sequences: List[List[int]], 452 | trim_warmup_flags: List[bool], 453 | sliding_window_flags: List[bool] 454 | ) -> List[Optional[bytes]]: 455 | """ 456 | Decode multiple sequences with same length in parallel. 457 | 458 | Args: 459 | token_sequences: List of token sequences (all same length) 460 | trim_warmup_flags: List of trim_warmup flags for each sequence 461 | sliding_window_flags: List of use_sliding_window flags for each sequence 462 | 463 | Returns: 464 | List of audio bytes 465 | """ 466 | if not token_sequences: 467 | return [] 468 | 469 | # Unpack all sequences 470 | unpacked_list = [self.unpack_snac_from_7(tokens) for tokens in token_sequences] 471 | 472 | # Check all have valid frames 473 | valid_indices = [i for i, levels in enumerate(unpacked_list) if levels[0]] 474 | 475 | if not valid_indices: 476 | return [None] * len(token_sequences) 477 | 478 | # Stack into batched tensors 479 | batch_size = len(valid_indices) 480 | frames = len(unpacked_list[valid_indices[0]][0]) 481 | 482 | # Build batched codes [batch, frames], [batch, 2*frames], [batch, 4*frames] 483 | codes = [ 484 | torch.stack([ 485 | torch.tensor(unpacked_list[i][level_idx], dtype=torch.long, device=self.device) 486 | for i in valid_indices 487 | ], dim=0) 488 | for level_idx in range(3) 489 | ] 490 | 491 | # Batched decode 492 | z_q = self.snac_model.quantizer.from_codes(codes) 493 | audio_batch = self.snac_model.decoder(z_q) # [batch, 1, samples] 494 | 495 | # Extract and convert to bytes 496 | audio_bytes_list = [None] * len(token_sequences) 497 | 498 | for batch_idx, orig_idx in enumerate(valid_indices): 499 | audio = audio_batch[batch_idx, 0].detach().cpu().numpy() 500 | 501 | # Apply sliding window or trim warmup based on flags 502 | if sliding_window_flags[orig_idx]: 503 | # Sliding window mode: keep middle 2048 samples only 504 | if len(audio) >= 4096: 505 | audio = audio[2048:4096] 506 | else: 507 | # Standard mode: trim warm-up if requested 508 | if trim_warmup_flags[orig_idx] and 
len(audio) > 2048: 509 | audio = audio[2048:] 510 | 511 | # Convert to int16 512 | audio_int16 = (audio * 32767).astype(np.int16) 513 | audio_bytes_list[orig_idx] = audio_int16.tobytes() 514 | 515 | return audio_bytes_list -------------------------------------------------------------------------------- /hf_space/maya1/snac_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import asyncio 4 | from typing import List, Optional, Tuple 5 | from snac import SNAC 6 | 7 | from .constants import ( 8 | CODE_END_TOKEN_ID, 9 | CODE_TOKEN_OFFSET, 10 | SNAC_MODEL_NAME, 11 | SNAC_SAMPLE_RATE, 12 | SNAC_TOKENS_PER_FRAME, 13 | ) 14 | 15 | 16 | class SNACDecoder: 17 | """ 18 | SNAC Decoder for maya1. 19 | Unpacks 7-token SNAC frames and decodes to audio waveforms. 20 | Unpacking logic is the EXACT INVERSE of training preprocessing. 21 | Supports async batching for concurrent requests. 22 | CRITICAL: Any mismatch in unpacking will produce garbage audio. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | device: str = "cuda", 28 | compile_decoder: bool = False, 29 | enable_batching: bool = False, 30 | max_batch_size: int = 64, 31 | batch_timeout_ms: int = 15, 32 | ): 33 | """ 34 | Initialize SNAC decoder. 35 | 36 | Args: 37 | device: Device for SNAC model (cuda/cpu) 38 | compile_decoder: Use torch.compile for speedup 39 | enable_batching: Enable async batching 40 | max_batch_size: Max sequences to batch together 41 | batch_timeout_ms: Max wait time before processing batch 42 | """ 43 | self.device = device 44 | self.enable_batching = enable_batching 45 | self.max_batch_size = max_batch_size 46 | self.batch_timeout_ms = batch_timeout_ms 47 | 48 | print(f"Loading SNAC 24kHz model to {device}...") 49 | self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device) 50 | 51 | if compile_decoder: 52 | print(f"Compiling SNAC decoder with torch.compile...") 53 | self._compile_model() 54 | 55 | # Batching infrastructure 56 | if enable_batching: 57 | self.request_queue = asyncio.Queue() 58 | self.batch_processor_task = None 59 | self._running = False 60 | print(f"Batching enabled (max_batch={max_batch_size}, timeout={batch_timeout_ms}ms)") 61 | 62 | print(f"SNAC decoder initialized") 63 | 64 | def _compile_model(self): 65 | """Compile SNAC decoder with torch.compile""" 66 | # Warm up with various sizes 67 | for frames in [4, 16, 32]: 68 | dummy_codes = [ 69 | torch.randint(0, 4096, (1, frames), device=self.device), 70 | torch.randint(0, 4096, (1, frames * 2), device=self.device), 71 | torch.randint(0, 4096, (1, frames * 4), device=self.device), 72 | ] 73 | with torch.inference_mode(): 74 | z_q = self.snac_model.quantizer.from_codes(dummy_codes) 75 | _ = self.snac_model.decoder(z_q) 76 | 77 | # Apply compilation 78 | self.snac_model.decoder = torch.compile( 79 | self.snac_model.decoder, 80 | mode="max-autotune" 81 | ) 82 | self.snac_model.quantizer = torch.compile( 83 | self.snac_model.quantizer, 84 | mode="reduce-overhead" 85 | ) 86 | 87 | print(f"SNAC decoder compiled") 88 | 89 | def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]: 90 | """ 91 | Unpack 7-token SNAC frames to 3 hierarchical levels. 92 | 93 | This is the EXACT INVERSE of the training preprocessing function 94 | `pack_snac_to_7_and_offset()`. 
95 | 96 | Frame structure: 97 | [slot0, slot1, slot2, slot3, slot4, slot5, slot6] 98 | 99 | Unpacking: 100 | - slot0: L1[i] 101 | - slot1: L2[2*i] (even index) 102 | - slot2: L3[4*i + 0] 103 | - slot3: L3[4*i + 1] 104 | - slot4: L2[2*i + 1] (odd index) 105 | - slot5: L3[4*i + 2] 106 | - slot6: L3[4*i + 3] 107 | 108 | Args: 109 | vocab_ids: List of SNAC token IDs (128266-156937) 110 | Must be divisible by 7 111 | 112 | Returns: 113 | [L1, L2, L3] where: 114 | L1: n elements (coarse level) 115 | L2: 2n elements (medium level) 116 | L3: 4n elements (fine level) 117 | """ 118 | # Strip EOS token if present 119 | if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID: 120 | vocab_ids = vocab_ids[:-1] 121 | 122 | # Ensure complete frames (divisible by 7) 123 | frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME 124 | vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME] 125 | 126 | if frames == 0: 127 | return [[], [], []] 128 | 129 | l1, l2, l3 = [], [], [] 130 | 131 | for i in range(frames): 132 | # Extract 7 slots for this frame 133 | slots = vocab_ids[i*7:(i+1)*7] 134 | 135 | # Subtract offset (128266) and mod 4096 to get original codes 136 | # Each level uses 4096 codes (0-4095) 137 | l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096) 138 | l2.extend([ 139 | (slots[1] - CODE_TOKEN_OFFSET) % 4096, # Even index 140 | (slots[4] - CODE_TOKEN_OFFSET) % 4096, # Odd index 141 | ]) 142 | l3.extend([ 143 | (slots[2] - CODE_TOKEN_OFFSET) % 4096, 144 | (slots[3] - CODE_TOKEN_OFFSET) % 4096, 145 | (slots[5] - CODE_TOKEN_OFFSET) % 4096, 146 | (slots[6] - CODE_TOKEN_OFFSET) % 4096, 147 | ]) 148 | 149 | return [l1, l2, l3] 150 | 151 | @torch.inference_mode() 152 | def decode( 153 | self, 154 | snac_tokens: List[int], 155 | trim_warmup: bool = True, 156 | trim_amount: Optional[int] = None, 157 | use_sliding_window: bool = False 158 | ) -> Optional[np.ndarray]: 159 | """ 160 | Decode SNAC tokens to audio waveform. 
161 | 
162 | Args:
163 | snac_tokens: List of SNAC token IDs (7*n tokens)
164 | trim_warmup: Whether to trim SNAC warmup samples (default: True)
165 | trim_amount: Number of samples to trim; defaults to 2048 when None (full trim for a first chunk)
166 | Callers can pass a smaller value (e.g., 512) or 0 for intermediate chunks
167 | use_sliding_window: If True, only return the middle 2048 samples (for sliding-window streaming)
168 | 
169 | Returns:
170 | Audio waveform as numpy array (float32, 24kHz mono)
171 | Shape: (samples,)
172 | Returns None if not enough tokens
173 | """
174 | if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
175 | print(f"Not enough SNAC tokens: {len(snac_tokens)} < {SNAC_TOKENS_PER_FRAME}")
176 | return None
177 | 
178 | # Unpack to 3 levels
179 | levels = self.unpack_snac_from_7(snac_tokens)
180 | 
181 | if not levels[0]: # No frames after unpacking
182 | return None
183 | 
184 | # Convert to tensors
185 | codes = [
186 | torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
187 | for level in levels
188 | ]
189 | 
190 | # Decode through SNAC
191 | z_q = self.snac_model.quantizer.from_codes(codes)
192 | audio = self.snac_model.decoder(z_q)
193 | 
194 | # Extract the mono waveform
195 | # SNAC decoder outputs: [batch, 1, samples]
196 | audio = audio[0, 0].cpu().numpy()
197 | 
198 | # Sliding window mode: only keep the middle 2048 samples
199 | # This eliminates popping/cracking when using overlapping 28-token windows
200 | if use_sliding_window:
201 | if len(audio) >= 4096:
202 | audio = audio[2048:4096] # Keep middle portion only
203 | else:
204 | # For shorter audio, keep everything (final chunk)
205 | pass
206 | else:
207 | # Standard mode: trim warm-up samples
208 | # The default is a full 2048-sample trim (first chunk); callers pass
209 | # trim_amount=0 for subsequent chunks
210 | if trim_warmup:
211 | if trim_amount is None:
212 | trim_amount = 2048 # Default full trim
213 | 
214 | if len(audio) > trim_amount:
215 | audio = audio[trim_amount:]
216 | 
217 | return audio
218 | 
219 | def decode_to_bytes(
220 | self,
221 | snac_tokens: List[int],
222 | trim_warmup: bool = True,
223 | use_sliding_window: bool = False
224 | ) -> Optional[bytes]:
225 | """
226 | Decode SNAC tokens to audio bytes (int16 PCM).
227 | 
228 | Args:
229 | snac_tokens: List of SNAC token IDs
230 | trim_warmup: Whether to trim SNAC warmup samples (default: True)
231 | use_sliding_window: If True, only return the middle 2048 samples (for sliding-window streaming)
232 | 
233 | Returns:
234 | Audio as bytes (int16 PCM, 24kHz mono)
235 | Returns None if decode fails
236 | """
237 | audio = self.decode(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
238 | 
239 | if audio is None:
240 | return None
241 | 
242 | # Convert float32 to int16 PCM (clip first so out-of-range samples saturate instead of wrapping)
243 | audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
244 | 
245 | return audio_int16.tobytes()
246 | 
247 | def validate_tokens(self, snac_tokens: List[int]) -> bool:
248 | """
249 | Validate SNAC tokens before decoding.
250 | Args:
251 | snac_tokens: List of SNAC token IDs
252 | Returns:
253 | True if valid, False otherwise
254 | """
255 | # Check minimum length
256 | if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
257 | print(f"Too few tokens: {len(snac_tokens)}")
258 | return False
259 | 
260 | # Check divisibility by 7
261 | if len(snac_tokens) % SNAC_TOKENS_PER_FRAME != 0:
262 | print(f"Warning: Token count {len(snac_tokens)} not divisible by 7")
263 | print(f"Will truncate to {(len(snac_tokens) // 7) * 7} tokens")
264 | 
265 | # Check token range
266 | for i, token_id in enumerate(snac_tokens):
267 | if token_id < CODE_TOKEN_OFFSET or token_id > 156937: # 156937 = CODE_TOKEN_OFFSET + 7 * 4096 - 1
268 | print(f"Invalid token at position {i}: {token_id}")
269 | print(f"Expected range: [{CODE_TOKEN_OFFSET}, 156937]")
270 | return False
271 | 
272 | return True
273 | 
274 | # ========== Async Batching Methods ==========
275 | 
276 | @property
277 | def is_running(self) -> bool:
278 | """Check if batch processor is running."""
279 | return self._running if self.enable_batching else False
280 | 
281 | async def start_batch_processor(self):
282 | """Start the background batch processor task."""
283 | if not self.enable_batching:
284 | return
285 | 
286 | if self._running:
287 | print("Batch processor already running")
288 | return
289 | 
290 | self._running = True
291 | self.batch_processor_task = asyncio.create_task(self._batch_processor_loop())
292 | print("Batch processor started")
293 | 
294 | async def stop_batch_processor(self):
295 | """Stop the background batch processor task."""
296 | if not self.enable_batching:
297 | return
298 | 
299 | if not self._running:
300 | return
301 | 
302 | self._running = False
303 | 
304 | if self.batch_processor_task:
305 | self.batch_processor_task.cancel()
306 | try:
307 | await self.batch_processor_task
308 | except asyncio.CancelledError:
309 | pass
310 | 
311 | print("Batch processor stopped")
312 | 
313 | async def decode_single_async(
314 | self,
315 | snac_tokens: List[int],
316 | trim_warmup: bool = True,
317 | use_sliding_window: bool = False
318 | ) -> Optional[bytes]:
319 | """
320 | Async decode for batching support.
321 | 
322 | Queues the request and waits for batched processing.
323 | 
324 | Args:
325 | snac_tokens: List of SNAC token IDs
326 | trim_warmup: Whether to trim SNAC warmup samples (default: True)
327 | use_sliding_window: If True, only return the middle 2048 samples (for sliding-window streaming)
328 | 
329 | Returns:
330 | Audio bytes or None if decode fails
331 | """
332 | if not self.enable_batching:
333 | # Fallback to synchronous decode
334 | return self.decode_to_bytes(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
335 | 
336 | # Create future for result
337 | result_future = asyncio.Future()
338 | 
339 | # Add to queue (include trim_warmup and sliding_window flags)
340 | await self.request_queue.put((snac_tokens, trim_warmup, use_sliding_window, result_future))
341 | 
342 | # Wait for result
343 | return await result_future
344 | 
345 | async def _batch_processor_loop(self):
346 | """Background task that processes batched decode requests."""
347 | while self._running:
348 | try:
349 | # Collect batch
350 | batch = await self._collect_batch()
351 | 
352 | if not batch:
353 | continue
354 | 
355 | # Process batch
356 | await self._process_batch(batch)
357 | 
358 | except asyncio.CancelledError:
359 | break
360 | except Exception as e:
361 | print(f"Batch processor error: {e}")
362 | import traceback
363 | traceback.print_exc()
364 | 
365 | async def _collect_batch(self) -> List[Tuple[List[int], bool, bool, asyncio.Future]]:
366 | """
367 | Collect requests into a batch.
368 | Waits up to the timeout for a first request, then fills the batch until it is full or a wait times out.
369 | Returns:
370 | List of (tokens, trim_warmup, use_sliding_window, future) tuples
371 | """
372 | batch = []
373 | timeout_sec = self.batch_timeout_ms / 1000.0
374 | 
375 | try:
376 | # Wait for the first request (blocks up to timeout_sec)
377 | first_item = await asyncio.wait_for(
378 | self.request_queue.get(),
379 | timeout=timeout_sec
380 | )
381 | batch.append(first_item)
382 | 
383 | # Collect more requests (each wait blocks up to timeout_sec)
384 | while len(batch) < self.max_batch_size:
385 | try:
386 | item = await asyncio.wait_for(
387 | self.request_queue.get(),
388 | timeout=timeout_sec
389 | )
390 | batch.append(item)
391 | except asyncio.TimeoutError:
392 | break # Timeout reached, process what we have
393 | 
394 | except asyncio.TimeoutError:
395 | # No requests in timeout period
396 | pass
397 | 
398 | return batch
399 | 
400 | @torch.inference_mode()
401 | async def _process_batch(self, batch: List[Tuple[List[int], bool, bool, asyncio.Future]]):
402 | """
403 | Process a batch of decode requests.
404 | Args:
405 | batch: List of (tokens, trim_warmup, use_sliding_window, future) tuples
406 | """
407 | if not batch:
408 | return
409 | 
410 | # Extract components
411 | token_sequences = [item[0] for item in batch]
412 | trim_warmup_flags = [item[1] for item in batch]
413 | sliding_window_flags = [item[2] for item in batch]
414 | futures = [item[3] for item in batch]
415 | 
416 | lengths = [len(tokens) for tokens in token_sequences]
417 | can_batch_efficiently = len(set(lengths)) == 1
418 | 
419 | if can_batch_efficiently and len(batch) > 1:
420 | # Efficient batching: all sequences have the same length
421 | try:
422 | audio_bytes_list = await self._decode_batch_same_length(
423 | token_sequences, trim_warmup_flags, sliding_window_flags
424 | )
425 | 
426 | # Set results
427 | for future, audio_bytes in zip(futures, audio_bytes_list):
428 | if not future.done():
429 | future.set_result(audio_bytes)
430 | 
431 | except Exception as e:
432 | # Set exceptions
433 | for future in futures:
434 | if not future.done():
435 | future.set_exception(e)
436 | else:
437 | # Sequential decode (different lengths or single item)
438 | for tokens, trim_warmup, use_sliding_window, future in batch:
439 | try:
440 | audio_bytes = self.decode_to_bytes(
441 | tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window
442 | )
443 | if not future.done():
444 | future.set_result(audio_bytes)
445 | except Exception as e:
446 | if not future.done():
447 | future.set_exception(e)
448 | 
449 | async def _decode_batch_same_length(
450 | self,
451 | token_sequences: List[List[int]],
452 | trim_warmup_flags: List[bool],
453 | sliding_window_flags: List[bool]
454 | ) -> List[Optional[bytes]]:
455 | """
456 | Decode multiple sequences with the same length in parallel.
457 | 
458 | Args:
459 | token_sequences: List of token sequences (all same length)
460 | trim_warmup_flags: List of trim_warmup flags for each sequence
461 | sliding_window_flags: List of use_sliding_window flags for each sequence
462 | 
463 | Returns:
464 | List of audio bytes
465 | """
466 | if not token_sequences:
467 | return []
468 | 
469 | # Unpack all sequences
470 | unpacked_list = [self.unpack_snac_from_7(tokens) for tokens in token_sequences]
471 | 
472 | # Check all have valid frames
473 | valid_indices = [i for i, levels in enumerate(unpacked_list) if levels[0]]
474 | 
475 | if not valid_indices:
476 | return [None] * len(token_sequences)
477 | 
478 | # Stack into batched tensors
479 | # All sequences have the same token count, so they unpack to the same
480 | # number of frames and can be stacked into rectangular tensors
481 | 
482 | # Build batched codes [batch, frames], [batch, 2*frames], [batch, 4*frames]
483 | codes = [
484 | torch.stack([
485 | torch.tensor(unpacked_list[i][level_idx], dtype=torch.long, device=self.device)
486 | for i in valid_indices
487 | ], dim=0)
488 | for level_idx in range(3)
489 | ]
490 | 
491 | # Batched decode
492 | z_q = self.snac_model.quantizer.from_codes(codes)
493 | audio_batch = self.snac_model.decoder(z_q) # [batch, 1, samples]
494 | 
495 | # Extract and convert to bytes
496 | audio_bytes_list = [None] * len(token_sequences)
497 | 
498 | for batch_idx, orig_idx in enumerate(valid_indices):
499 | audio = audio_batch[batch_idx, 0].detach().cpu().numpy()
500 | 
501 | # Apply sliding window or warm-up trim based on the per-request flags
502 | if sliding_window_flags[orig_idx]:
503 | # Sliding window mode: keep the middle 2048 samples only
504 | if len(audio) >= 4096:
505 | audio = audio[2048:4096]
506 | else:
507 | # Standard mode: trim warm-up if requested
508 | if trim_warmup_flags[orig_idx] and
len(audio) > 2048:
509 | audio = audio[2048:]
510 | 
511 | # Convert to int16 (clip first so out-of-range samples saturate instead of wrapping)
512 | audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
513 | audio_bytes_list[orig_idx] = audio_int16.tobytes()
514 | 
515 | return audio_bytes_list
--------------------------------------------------------------------------------
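For readers cross-checking `unpack_snac_from_7` against the training side: the training function `pack_snac_to_7_and_offset()` is not included in this repo, so the packer below is a sketch reconstructed from the documented slot layout. The per-slot 4096-code bands are an inference from the 128266-156937 token range (which spans exactly 7 × 4096) and the `% 4096` in the unpacker; treat this as an illustration, not the actual training code.

```python
# Hypothetical inverse of unpack_snac_from_7(), reconstructed from the slot
# layout documented in its docstring. Assumption: token = offset + slot*4096 + code.
CODE_TOKEN_OFFSET = 128266  # mirrors maya1.constants

def pack_snac_to_7(l1, l2, l3):
    """Pack hierarchical SNAC codes (n, 2n, 4n elements) into 7-token frames."""
    frames = len(l1)
    assert len(l2) == 2 * frames and len(l3) == 4 * frames
    tokens = []
    for i in range(frames):
        slots = [
            l1[i],          # slot0: coarse
            l2[2 * i],      # slot1: medium, even index
            l3[4 * i],      # slot2: fine
            l3[4 * i + 1],  # slot3: fine
            l2[2 * i + 1],  # slot4: medium, odd index
            l3[4 * i + 2],  # slot5: fine
            l3[4 * i + 3],  # slot6: fine
        ]
        tokens += [CODE_TOKEN_OFFSET + s * 4096 + c for s, c in enumerate(slots)]
    return tokens

# One frame round-trips through the unpacker:
tokens = pack_snac_to_7([7], [11, 12], [21, 22, 23, 24])
# tokens == [128273, 132373, 136479, 140576, 144662, 148769, 152866]
# decoder.unpack_snac_from_7(tokens) -> [[7], [11, 12], [21, 22, 23, 24]]
```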
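A minimal synchronous usage sketch. The token list below is synthetic (all-zero codebook indices in each slot band) purely so the snippet is self-contained and runnable; in the real system the SNAC token IDs come from the Maya1 model's generated audio span.

```python
# Synchronous decode: validate tokens, decode to float32 audio, write a WAV.
import soundfile as sf
from maya1.snac_decoder import SNACDecoder
from maya1.constants import CODE_TOKEN_OFFSET, SNAC_SAMPLE_RATE

decoder = SNACDecoder(device="cuda")  # or device="cpu"

# 28 tokens = 4 complete 7-token frames, all within the valid ID range
snac_tokens = [CODE_TOKEN_OFFSET + (i % 7) * 4096 for i in range(28)]

if decoder.validate_tokens(snac_tokens):
    audio = decoder.decode(snac_tokens, trim_warmup=False)  # float32, 24 kHz mono
    if audio is not None:
        sf.write("out.wav", audio, SNAC_SAMPLE_RATE)
```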
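The batching path is easiest to see end to end: each `decode_single_async` call queues a `(tokens, flags, future)` tuple, the background loop gathers up to `max_batch_size` requests within `batch_timeout_ms` waits, same-length requests go through one batched SNAC forward pass, and mixed lengths fall back to sequential `decode_to_bytes`. A runnable sketch using the same synthetic tokens as above:

```python
# Async batching sketch: eight same-length requests arrive together, so the
# processor can take the efficient single-forward-pass path.
import asyncio
from maya1.snac_decoder import SNACDecoder
from maya1.constants import CODE_TOKEN_OFFSET

async def main():
    decoder = SNACDecoder(device="cuda", enable_batching=True,
                          max_batch_size=64, batch_timeout_ms=15)
    await decoder.start_batch_processor()
    try:
        tokens = [CODE_TOKEN_OFFSET + (i % 7) * 4096 for i in range(28)]
        results = await asyncio.gather(
            *(decoder.decode_single_async(tokens) for _ in range(8))
        )
        # Each result is int16 PCM bytes (or None on failure)
        print([len(pcm) if pcm else None for pcm in results])
    finally:
        await decoder.stop_batch_processor()

asyncio.run(main())
```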
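Finally, the `use_sliding_window` flag exists for streaming: per the comments in `decode()`, overlapping 28-token windows are decoded and only the middle 2048 samples of each are kept, which hides the popping at chunk boundaries. The actual window schedule lives in streaming_pipeline.py and is not shown here, so the 7-token (one-frame) hop below is an assumption for illustration.

```python
# Hypothetical streaming consumer: decode a sliding 28-token window each time
# a full 7-token frame arrives, keeping only the middle of each window.
# token_stream and play_pcm are stand-ins for the model's token generator and
# an audio sink; the real scheduling is in streaming_pipeline.py.
def stream_audio(decoder, token_stream, play_pcm):
    buf = []
    for token_id in token_stream:
        buf.append(token_id)
        if len(buf) >= 28 and len(buf) % 7 == 0:
            pcm = decoder.decode_to_bytes(buf[-28:], use_sliding_window=True)
            if pcm:
                play_pcm(pcm)
```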