├── icon.gif ├── icon.png ├── utils ├── __init__.py ├── blocks.py ├── interp.py └── dist.py ├── prompts ├── bob_duo.wav ├── bob_mono.wav ├── toaskanymore.wav └── countdown_mono.wav ├── reset.js ├── update.js ├── .gitignore ├── requirements.txt ├── install.js ├── start.js ├── README.md ├── pinokio.js ├── torch.js ├── inference_server.py ├── inference_client.py ├── app.py ├── inference.ipynb ├── LICENSE ├── ioblocks.py ├── transformer.py ├── model.py └── tokenizer.py /icon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cocktailpeanut/hallucinator/HEAD/icon.gif -------------------------------------------------------------------------------- /icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cocktailpeanut/hallucinator/HEAD/icon.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .dist import * 3 | from .interp import * -------------------------------------------------------------------------------- /prompts/bob_duo.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cocktailpeanut/hallucinator/HEAD/prompts/bob_duo.wav -------------------------------------------------------------------------------- /prompts/bob_mono.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cocktailpeanut/hallucinator/HEAD/prompts/bob_mono.wav -------------------------------------------------------------------------------- /prompts/toaskanymore.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cocktailpeanut/hallucinator/HEAD/prompts/toaskanymore.wav -------------------------------------------------------------------------------- /prompts/countdown_mono.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cocktailpeanut/hallucinator/HEAD/prompts/countdown_mono.wav -------------------------------------------------------------------------------- /reset.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | run: [{ 3 | method: "fs.rm", 4 | params: { 5 | path: "env" 6 | } 7 | }] 8 | } 9 | -------------------------------------------------------------------------------- /update.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | run: [{ 3 | method: "shell.run", 4 | params: { 5 | message: "git pull" 6 | } 7 | }] 8 | } 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | *.wav 3 | *.mp3 4 | *.m4a 5 | !prompts/*.wav 6 | !prompts/*.mp3 7 | !prompts/*.m4a 8 | __pycache__ 9 | *ckpt 10 | env 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | tqdm 3 | ipython 4 | numpy 5 | soundfile 6 | websockets 7 | requests 8 | sounddevice 9 | matplotlib 10 | fastapi 11 | uvicorn 12 | gradio 13 | 
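Note (editor's sketch): `torch` and `torchaudio` are intentionally not listed in requirements.txt; install.js installs them first via torch.js (or follow the manual PyTorch step in the README). A minimal pre-flight check, assuming the CUDA 12.1 build used elsewhere in this repo, before launching app.py:

```
# preflight.py (editor's sketch, not part of the original repo)
import torch
import torchaudio

print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
print("torchaudio", torchaudio.__version__)
assert torch.cuda.is_available(), "Hallucinator is NVIDIA-only; install the cu121 PyTorch build"
```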
-------------------------------------------------------------------------------- /install.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | run: [ 3 | { 4 | method: "script.start", 5 | params: { 6 | uri: "torch.js", 7 | params: { 8 | venv: "env", // Edit this to customize the venv folder path 9 | } 10 | } 11 | }, 12 | { 13 | method: "shell.run", 14 | params: { 15 | venv: "env", // Edit this to customize the venv folder path 16 | message: [ 17 | "pip install -r requirements.txt", 18 | ] 19 | } 20 | }, 21 | { 22 | method: "fs.link", 23 | params: { 24 | venv: "env" 25 | } 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /start.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | daemon: true, 3 | run: [ 4 | { 5 | method: "shell.run", 6 | params: { 7 | venv: "env", // Edit this to customize the venv folder path 8 | message: "python app.py", 9 | on: [{ 10 | // The regular expression pattern to monitor. 11 | // When this pattern occurs in the shell terminal, the shell will return, 12 | // and the script will go onto the next step. 13 | "event": "/http:\/\/[0-9:.]+/", 14 | 15 | // "done": true will move to the next step while keeping the shell alive. 16 | // "kill": true will move to the next step after killing the shell. 17 | "done": true 18 | }] 19 | } 20 | }, 21 | { 22 | // This step sets the local variable 'url'. 23 | // This local variable will be used in pinokio.js to display the "Open WebUI" tab when the value is set. 24 | method: "local.set", 25 | params: { 26 | // the input.event is the regular expression match object from the previous step 27 | url: "{{input.event[0]}}" 28 | } 29 | }, 30 | ] 31 | } 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hallucinator 2 | 3 | **[NVIDIA ONLY]** Upload an audio clip and let the AI autocomplete the rest. Powered by [hertz](https://github.com/Standard-Intelligence/hertz-dev) from Standard Intelligence. 4 | 5 | > See this thread from Standard Intelligence to learn what this model is capable of: https://x.com/si_pbc/status/1853184307063660723 6 | 7 | 8 | https://github.com/user-attachments/assets/13fe1e65-eafc-44bf-8d49-ed6ad7a0aaef 9 | 10 | 11 | # Credits 12 | 13 | This project is a slight modification of the official [hertz-dev](https://github.com/Standard-Intelligence/hertz-dev) project. Changes made: 14 | 15 | 1. **app.py:** added a gradio web ui for audio autocompletion 16 | 2. **transformer.py:** Use `SDPBackend.EFFICIENT_ATTENTION` instead of `SDPBackend.FLASH_ATTENTION` because installing flash attention takes too long and we might be dead by the time it finishes installing. 17 | 18 | 19 | # Install 20 | 21 | ## 1. One Click Install 22 | 23 | You can install it with one click on https://pinokio.computer 24 | 25 | ## 2. 
Manual Install 26 | 27 | Install pytorch 28 | 29 | ``` 30 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 31 | ``` 32 | 33 | instal dependencies 34 | 35 | ``` 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | ## Run 40 | 41 | ``` 42 | python app.py 43 | ``` 44 | -------------------------------------------------------------------------------- /pinokio.js: -------------------------------------------------------------------------------- 1 | const path = require('path') 2 | module.exports = { 3 | version: "2.0", 4 | title: "Hallucinator", 5 | description: "[NVIDIA ONLY] Autocomplete any voice(s), powered by Hertz AI (Standard Intelligence)", 6 | icon: "icon.gif", 7 | menu: async (kernel, info) => { 8 | let installed = info.exists("env") 9 | let running = { 10 | install: info.running("install.js"), 11 | start: info.running("start.js"), 12 | update: info.running("update.js"), 13 | reset: info.running("reset.js") 14 | } 15 | if (running.install) { 16 | return [{ 17 | default: true, 18 | icon: "fa-solid fa-plug", 19 | text: "Installing", 20 | href: "install.js", 21 | }] 22 | } else if (installed) { 23 | if (running.start) { 24 | let local = info.local("start.js") 25 | if (local && local.url) { 26 | return [{ 27 | default: true, 28 | icon: "fa-solid fa-rocket", 29 | text: "Open Web UI", 30 | href: local.url, 31 | }, { 32 | icon: 'fa-solid fa-terminal', 33 | text: "Terminal", 34 | href: "start.js", 35 | }] 36 | } else { 37 | return [{ 38 | default: true, 39 | icon: 'fa-solid fa-terminal', 40 | text: "Terminal", 41 | href: "start.js", 42 | }] 43 | } 44 | } else if (running.update) { 45 | return [{ 46 | default: true, 47 | icon: 'fa-solid fa-terminal', 48 | text: "Updating", 49 | href: "update.js", 50 | }] 51 | } else if (running.reset) { 52 | return [{ 53 | default: true, 54 | icon: 'fa-solid fa-terminal', 55 | text: "Resetting", 56 | href: "reset.js", 57 | }] 58 | } else { 59 | return [{ 60 | default: true, 61 | icon: "fa-solid fa-power-off", 62 | text: "Start", 63 | href: "start.js", 64 | }, { 65 | icon: "fa-solid fa-plug", 66 | text: "Update", 67 | href: "update.js", 68 | }, { 69 | icon: "fa-solid fa-plug", 70 | text: "Install", 71 | href: "install.js", 72 | }, { 73 | icon: "fa-regular fa-circle-xmark", 74 | text: "Reset", 75 | href: "reset.js", 76 | }] 77 | } 78 | } else { 79 | return [{ 80 | default: true, 81 | icon: "fa-solid fa-plug", 82 | text: "Install", 83 | href: "install.js", 84 | }] 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /torch.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | run: [ 3 | // windows nvidia 4 | { 5 | "when": "{{platform === 'win32' && gpu === 'nvidia'}}", 6 | "method": "shell.run", 7 | "params": { 8 | "venv": "{{args && args.venv ? args.venv : null}}", 9 | "path": "{{args && args.path ? args.path : '.'}}", 10 | "message": "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121" 11 | } 12 | }, 13 | // windows amd 14 | { 15 | "when": "{{platform === 'win32' && gpu === 'amd'}}", 16 | "method": "shell.run", 17 | "params": { 18 | "venv": "{{args && args.venv ? args.venv : null}}", 19 | "path": "{{args && args.path ? 
args.path : '.'}}", 20 | "message": "pip install torch-directml torchaudio torchvision numpy==1.26.4" 21 | } 22 | }, 23 | // windows cpu 24 | { 25 | "when": "{{platform === 'win32' && (gpu !== 'nvidia' && gpu !== 'amd')}}", 26 | "method": "shell.run", 27 | "params": { 28 | "venv": "{{args && args.venv ? args.venv : null}}", 29 | "path": "{{args && args.path ? args.path : '.'}}", 30 | "message": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu" 31 | } 32 | }, 33 | // mac 34 | { 35 | "when": "{{platform === 'darwin'}}", 36 | "method": "shell.run", 37 | "params": { 38 | "venv": "{{args && args.venv ? args.venv : null}}", 39 | "path": "{{args && args.path ? args.path : '.'}}", 40 | "message": "pip install torch torchvision torchaudio" 41 | } 42 | }, 43 | // linux nvidia 44 | { 45 | "when": "{{platform === 'linux' && gpu === 'nvidia'}}", 46 | "method": "shell.run", 47 | "params": { 48 | "venv": "{{args && args.venv ? args.venv : null}}", 49 | "path": "{{args && args.path ? args.path : '.'}}", 50 | "message": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu" 51 | } 52 | }, 53 | // linux rocm (amd) 54 | { 55 | "when": "{{platform === 'linux' && gpu === 'amd'}}", 56 | "method": "shell.run", 57 | "params": { 58 | "venv": "{{args && args.venv ? args.venv : null}}", 59 | "path": "{{args && args.path ? args.path : '.'}}", 60 | "message": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2" 61 | } 62 | }, 63 | // linux cpu 64 | { 65 | "when": "{{platform === 'linux' && (gpu !== 'amd' && gpu !=='nvidia')}}", 66 | "method": "shell.run", 67 | "params": { 68 | "venv": "{{args && args.venv ? args.venv : null}}", 69 | "path": "{{args && args.path ? 
args.path : '.'}}", 70 | "message": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu" 71 | } 72 | } 73 | ] 74 | } 75 | -------------------------------------------------------------------------------- /utils/blocks.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TypeVar, Generic, Type, Optional 3 | from functools import wraps 4 | import time 5 | import random 6 | 7 | import torch as T 8 | import torch.nn as nn 9 | 10 | # @TODO: remove si_module from codebase 11 | # we use this in our research codebase to make modules from callable configs 12 | si_module_TpV = TypeVar('si_module_TpV') 13 | def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]: 14 | if not hasattr(cls, 'Config') or not isinstance(cls.Config, type): 15 | class Config: 16 | pass 17 | cls.Config = Config 18 | 19 | cls.Config = dataclass(cls.Config) 20 | 21 | class ConfigWrapper(cls.Config, Generic[si_module_TpV]): 22 | def __call__(self, *args, **kwargs) -> si_module_TpV: 23 | if len(kwargs) > 0: 24 | config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()} 25 | config_dict.update(kwargs) 26 | new_config = type(self)(**config_dict) 27 | return cls(new_config) 28 | else: 29 | return cls(self, *args) 30 | 31 | ConfigWrapper.__module__ = cls.__module__ 32 | ConfigWrapper.__name__ = f"{cls.__name__}Config" 33 | ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config" 34 | 35 | cls.Config = ConfigWrapper 36 | 37 | original_init = cls.__init__ 38 | def new_init(self, *args, **kwargs): 39 | self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None) 40 | original_init(self, *args, **kwargs) 41 | self.register_buffer('_device_tracker', T.Tensor(), persistent=False) 42 | 43 | cls.__init__ = new_init 44 | 45 | @property 46 | def device(self): 47 | return self._device_tracker.device 48 | 49 | @property 50 | def dtype(self): 51 | return self._device_tracker.dtype 52 | 53 | cls.device = device 54 | cls.dtype = dtype 55 | 56 | return cls 57 | 58 | 59 | def get_activation(nonlinear_activation, nonlinear_activation_params={}): 60 | if hasattr(nn, nonlinear_activation): 61 | return getattr(nn, nonlinear_activation)(**nonlinear_activation_params) 62 | else: 63 | raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn") 64 | 65 | 66 | def exists(v): 67 | return v is not None 68 | 69 | def isnt(v): 70 | return not exists(v) 71 | 72 | def truthyexists(v): 73 | return exists(v) and v is not False 74 | 75 | def truthyattr(obj, attr): 76 | return hasattr(obj, attr) and truthyexists(getattr(obj, attr)) 77 | 78 | defaultT = TypeVar('defaultT') 79 | 80 | def default(*args: Optional[defaultT]) -> Optional[defaultT]: 81 | for arg in args: 82 | if exists(arg): 83 | return arg 84 | return None 85 | 86 | def maybe(fn): 87 | @wraps(fn) 88 | def inner(x, *args, **kwargs): 89 | if not exists(x): 90 | return x 91 | return fn(x, *args, **kwargs) 92 | return inner 93 | -------------------------------------------------------------------------------- /utils/interp.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | import os 3 | 4 | def rank0(): 5 | rank = os.environ.get('RANK') 6 | if rank is None or rank == '0': 7 | return True 8 | else: 9 | return False 10 | 11 | def 
print_colored(message, color='reset', bold=False, **kwargs): 12 | color_dict = { 13 | 'bold': '\033[1m', 14 | 'green': '\033[92m', 15 | 'yellow': '\033[93m', 16 | 'red': '\033[91m', 17 | 'blue': '\033[94m', 18 | 'grey': '\033[90m', 19 | 'white': '\033[97m', 20 | 'reset': '\033[0m' 21 | } 22 | 23 | color_code = color_dict.get(color.lower(), color_dict['reset']) 24 | prefix = color_dict['bold'] if bold else '' 25 | print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs) 26 | 27 | def print0_colored(*args, **kwargs): 28 | if rank0(): 29 | print_colored(*args, **kwargs) 30 | 31 | def param_count(module): 32 | def count_parameters(model): 33 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 34 | 35 | total_params = count_parameters(module) 36 | output = [f'Total model parameters: {total_params:,}', '---------------------------'] 37 | 38 | for name, child in module.named_children(): 39 | params = count_parameters(child) 40 | output.append(f'{name} parameters: {params:,}') 41 | 42 | return '\n'.join(output) 43 | 44 | def model_size_estimation(module): 45 | def estimate_size(model): 46 | param_size = sum(p.nelement() * p.element_size() for p in model.parameters()) 47 | buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers()) 48 | return param_size + buffer_size 49 | 50 | total_size = estimate_size(module) 51 | output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------'] 52 | 53 | for name, child in module.named_children(): 54 | child_size = estimate_size(child) 55 | output.append(f'{name} size: {child_size / 1024**2:.2f} MB') 56 | 57 | return '\n'.join(output) 58 | 59 | def layer_param_distribution(module): 60 | def count_parameters(model): 61 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 62 | 63 | def get_layer_types(model): 64 | layer_types = {} 65 | for name, module in model.named_modules(): 66 | layer_type = module.__class__.__name__ 67 | params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad) 68 | if params > 0: 69 | if layer_type not in layer_types: 70 | layer_types[layer_type] = 0 71 | layer_types[layer_type] += params 72 | return layer_types 73 | 74 | total_params = count_parameters(module) 75 | layer_types = get_layer_types(module) 76 | 77 | output = [f'Total trainable parameters: {total_params:,}', '---------------------------'] 78 | 79 | for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True): 80 | percentage = (count / total_params) * 100 81 | output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)') 82 | 83 | return '\n'.join(output) 84 | 85 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as T 3 | import re 4 | from tqdm import tqdm 5 | from datetime import timedelta 6 | 7 | import requests 8 | import hashlib 9 | 10 | from io import BytesIO 11 | 12 | def rank0(): 13 | rank = os.environ.get('RANK') 14 | if rank is None or rank == '0': 15 | return True 16 | else: 17 | return False 18 | 19 | def local0(): 20 | local_rank = os.environ.get('LOCAL_RANK') 21 | if local_rank is None or local_rank == '0': 22 | return True 23 | else: 24 | return False 25 | class tqdm0(tqdm): 26 | def __init__(self, *args, **kwargs): 27 | total = kwargs.get('total', None) 28 | if total is None and len(args) > 0: 29 | try: 30 | total = len(args[0]) 31 | except TypeError: 32 | pass 
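        # (Editor's note) When a total could be inferred, cap refreshes at roughly 20 updates
        # (miniters = total // 20) so rank-0 progress logs stay compact.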
33 | if total is not None: 34 | kwargs['miniters'] = max(1, total // 20) 35 | super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]') 36 | 37 | def print0(*args, **kwargs): 38 | if rank0(): 39 | print(*args, **kwargs) 40 | 41 | _PRINTED_IDS = set() 42 | 43 | def printonce(*args, id=None, **kwargs): 44 | if id is None: 45 | id = ' '.join(map(str, args)) 46 | 47 | if id not in _PRINTED_IDS: 48 | print(*args, **kwargs) 49 | _PRINTED_IDS.add(id) 50 | 51 | def print0once(*args, **kwargs): 52 | if rank0(): 53 | printonce(*args, **kwargs) 54 | 55 | def init_dist(): 56 | if T.distributed.is_initialized(): 57 | print0('Distributed already initialized') 58 | rank = T.distributed.get_rank() 59 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 60 | world_size = T.distributed.get_world_size() 61 | else: 62 | try: 63 | rank = int(os.environ['RANK']) 64 | local_rank = int(os.environ['LOCAL_RANK']) 65 | world_size = int(os.environ['WORLD_SIZE']) 66 | device = f'cuda:{local_rank}' 67 | T.cuda.set_device(device) 68 | T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device)) 69 | print(f'Rank {rank} of {world_size}.') 70 | except Exception as e: 71 | print0once(f'Not initializing distributed env: {e}') 72 | rank = 0 73 | local_rank = 0 74 | world_size = 1 75 | return rank, local_rank, world_size 76 | 77 | def load_ckpt(load_from_location, expected_hash=None): 78 | if local0(): 79 | os.makedirs('ckpt', exist_ok=True) 80 | url = f"https://ckpt.si.inc/hertz-dev/{load_from_location}.pt" 81 | save_path = f"ckpt/{load_from_location}.pt" 82 | if not os.path.exists(save_path): 83 | response = requests.get(url, stream=True) 84 | total_size = int(response.headers.get('content-length', 0)) 85 | with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar: 86 | for chunk in response.iter_content(chunk_size=8192): 87 | f.write(chunk) 88 | pbar.update(len(chunk)) 89 | if expected_hash is not None: 90 | with open(save_path, 'rb') as f: 91 | file_hash = hashlib.md5(f.read()).hexdigest() 92 | if file_hash != expected_hash: 93 | print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. 
Deleting checkpoint and trying again.') 94 | os.remove(save_path) 95 | return load_ckpt(load_from_location, expected_hash) 96 | if T.distributed.is_initialized(): 97 | T.distributed.barrier() # so that ranks don't try to load checkpoint before it's finished downloading 98 | loaded = T.load(f"ckpt/{load_from_location}.pt", weights_only=False, map_location='cpu') 99 | return loaded -------------------------------------------------------------------------------- /inference_server.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from fastapi import FastAPI, WebSocket 4 | from fastapi.middleware.cors import CORSMiddleware 5 | import base64 6 | import uvicorn 7 | import traceback 8 | import numpy as np 9 | import argparse 10 | 11 | import torch as T 12 | import torch.nn.functional as F 13 | import torchaudio 14 | 15 | import os 16 | from typing import Optional 17 | 18 | from utils import print_colored 19 | from model import get_hertz_dev_config 20 | 21 | 22 | argparse = argparse.ArgumentParser() 23 | 24 | argparse.add_argument('--prompt_path', type=str, default='./prompts/bob_mono.wav', help=""" 25 | We highly recommend making your own prompt based on a conversation between you and another person. 26 | bob_mono.wav seems to work better for two-channel than bob_stereo.wav. 27 | """) 28 | args = argparse.parse_args() 29 | 30 | 31 | device = 'cuda' if T.cuda.is_available() else T.device('cpu') 32 | print_colored(f"Using device: {device}", "grey") 33 | 34 | model_config = get_hertz_dev_config(is_split=True) 35 | 36 | model = model_config() 37 | model = model.eval().bfloat16().to(device) 38 | 39 | app = FastAPI() 40 | 41 | app.add_middleware( 42 | CORSMiddleware, 43 | allow_origins=["*"], 44 | allow_credentials=True, 45 | allow_methods=["*"], 46 | allow_headers=["*"], 47 | ) 48 | 49 | 50 | # Hyperparams or something. 51 | SAMPLE_RATE = 16000 # Don't change this 52 | TEMPS = (0.8, (0.4, 0.1)) # You can change this, but there's also an endpoint for it. 53 | REPLAY_SECONDS = 3 # What the user hears as context. 
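# (Editor's notes; comments added for clarity, not in the original file.)
# TEMPS has the shape (token_temp, (categorical_temp, gaussian_temp)) and is passed straight
# through to model.next_audio_from_audio(..., temps=TEMPS). The inner pair controls the
# Gumbel-softmax over mixture weights and the Gaussian noise scale in the latent sampler
# (see GaussianMixtureIOLayer.temp_sample in ioblocks.py).
# Audio is handled in 2000-sample chunks at 16 kHz, i.e. 8 chunks per second, which is why
# REPLAY_SECONDS is converted with int(self.replay_seconds * 8) below.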
54 | 55 | class AudioProcessor: 56 | def __init__(self, model, prompt_path): 57 | self.model = model 58 | self.prompt_path = prompt_path 59 | self.initialize_state(prompt_path) 60 | 61 | def initialize_state(self, prompt_path): 62 | loaded_audio, sr = torchaudio.load(prompt_path) 63 | self.replay_seconds = REPLAY_SECONDS 64 | 65 | if sr != SAMPLE_RATE: 66 | resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE) 67 | loaded_audio = resampler(loaded_audio) 68 | 69 | if loaded_audio.shape[0] == 1: 70 | loaded_audio = loaded_audio.repeat(2, 1) 71 | 72 | audio_length = loaded_audio.shape[-1] 73 | num_chunks = audio_length // 2000 74 | loaded_audio = loaded_audio[..., :num_chunks * 2000] 75 | 76 | self.loaded_audio = loaded_audio.to(device) 77 | 78 | with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode(): 79 | self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024) 80 | self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS) 81 | self.prompt_buffer = None 82 | self.prompt_position = 0 83 | self.chunks_until_live = int(self.replay_seconds * 8) 84 | self.initialize_prompt_buffer() 85 | print_colored("AudioProcessor state initialized", "green") 86 | 87 | def initialize_prompt_buffer(self): 88 | self.recorded_audio = self.loaded_audio 89 | prompt_audio = self.loaded_audio.reshape(1, 2, -1) 90 | prompt_audio = prompt_audio[:, :, -(16000*self.replay_seconds):].cpu().numpy() 91 | prompt_audio_mono = prompt_audio.mean(axis=1) 92 | self.prompt_buffer = np.array_split(prompt_audio_mono[0], int(self.replay_seconds * 8)) 93 | print_colored(f"Initialized prompt buffer with {len(self.prompt_buffer)} chunks", "grey") 94 | 95 | async def process_audio(self, audio_data): 96 | if self.chunks_until_live > 0: 97 | print_colored(f"Serving from prompt buffer, {self.chunks_until_live} chunks left", "grey") 98 | chunk = self.prompt_buffer[int(self.replay_seconds * 8) - self.chunks_until_live] 99 | self.chunks_until_live -= 1 100 | 101 | if self.chunks_until_live == 0: 102 | print_colored("Switching to live processing mode", "green") 103 | 104 | time.sleep(0.05) 105 | return chunk 106 | 107 | audio_tensor = T.from_numpy(audio_data).to(device) 108 | audio_tensor = audio_tensor.reshape(1, 1, -1) 109 | audio_tensor = T.cat([audio_tensor, self.next_model_audio], dim=1) 110 | 111 | with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode(): 112 | curr_model_audio = self.model.next_audio_from_audio( 113 | audio_tensor, 114 | temps=TEMPS 115 | ) 116 | print(f"Recorded audio shape {self.recorded_audio.shape}, audio tensor shape {audio_tensor.shape}") 117 | self.recorded_audio = T.cat([self.recorded_audio.cpu(), audio_tensor.squeeze(0).cpu()], dim=-1) 118 | 119 | self.next_model_audio = curr_model_audio 120 | 121 | return curr_model_audio.float().cpu().numpy() 122 | 123 | def cleanup(self): 124 | print_colored("Cleaning up audio processor...", "blue") 125 | os.makedirs('audio_recordings', exist_ok=True) 126 | torchaudio.save(f'audio_recordings/{time.strftime("%d-%H-%M")}.wav', self.recorded_audio.cpu(), SAMPLE_RATE) 127 | self.model.deinit_cache() 128 | self.initialize_state(self.prompt_path) 129 | print_colored("Audio processor cleanup complete", "green") 130 | 131 | @app.post("/set_temperature") 132 | async def set_temperature(token_temp: Optional[float] = None, categorical_temp: Optional[float] = None, gaussian_temp: Optional[float] = None): 133 | try: 134 | global TEMPS 135 | TEMPS = (token_temp, (categorical_temp, 
gaussian_temp)) 136 | 137 | print_colored(f"Temperature updated to: {TEMPS}", "green") 138 | return {"message": f"Temperature updated to: {TEMPS}", "status": "success"} 139 | except Exception as e: 140 | print_colored(f"Error setting temperature: {str(e)}", "red") 141 | return {"message": f"Error setting temperature: {str(e)}", "status": "error"} 142 | 143 | @app.websocket("/audio") 144 | async def websocket_endpoint(websocket: WebSocket): 145 | await websocket.accept() 146 | try: 147 | while True: 148 | data = await websocket.receive_text() 149 | audio_data = np.frombuffer( 150 | base64.b64decode(data.split(",")[1]), 151 | dtype=np.int16 152 | ) 153 | audio_data = audio_data.astype(np.float32) / 32767.0 154 | processed_audio = await audio_processor.process_audio(audio_data) 155 | processed_audio = (processed_audio * 32767).astype(np.int16) 156 | 157 | processed_data = base64.b64encode(processed_audio.tobytes()).decode('utf-8') 158 | await websocket.send_text(f"data:audio/raw;base64,{processed_data}") 159 | 160 | except Exception as e: 161 | print_colored(f"WebSocket error: {e}", "red") 162 | print_colored(f"Full traceback:\n{traceback.format_exc()}", "red") 163 | finally: 164 | audio_processor.cleanup() 165 | await websocket.close() 166 | 167 | 168 | audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path) 169 | 170 | if __name__ == "__main__": 171 | uvicorn.run(app, host="0.0.0.0", port=8000) 172 | print("Server started") 173 | -------------------------------------------------------------------------------- /inference_client.py: -------------------------------------------------------------------------------- 1 | # server.py remains the same as before 2 | 3 | # Updated client.py 4 | import asyncio 5 | import websockets 6 | import sounddevice as sd 7 | import numpy as np 8 | import base64 9 | import queue 10 | import argparse 11 | import requests 12 | import time 13 | 14 | class AudioClient: 15 | def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None): 16 | # Convert ws:// to http:// for the base URL 17 | self.base_url = server_url.replace("ws://", "http://") 18 | self.server_url = f"{server_url}/audio" 19 | 20 | # Set temperatures if provided 21 | if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]): 22 | self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp) 23 | 24 | # Initialize queues 25 | self.audio_queue = queue.Queue() 26 | self.output_queue = queue.Queue() 27 | 28 | def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False): 29 | """Send temperature settings to server""" 30 | params = {} 31 | if token_temp is not None: 32 | params['token_temp'] = token_temp 33 | if categorical_temp is not None: 34 | params['categorical_temp'] = categorical_temp 35 | if gaussian_temp is not None: 36 | params['gaussian_temp'] = gaussian_temp 37 | 38 | response = requests.post(f"{self.base_url}/set_temperature", params=params) 39 | print(response.json()['message']) 40 | 41 | def audio_callback(self, indata, frames, time, status): 42 | """This is called for each audio block""" 43 | if status: 44 | print(status) 45 | # if np.isclose(indata, 0).all(): 46 | # raise Exception('Audio input is not working - received all zeros') 47 | # Convert float32 to int16 for efficient transmission 48 | indata_int16 = (indata.copy() * 32767).astype(np.int16) 49 | # indata_int16 = np.zeros_like(indata_int16) 50 | 
self.audio_queue.put(indata_int16) 51 | 52 | def output_stream_callback(self, outdata, frames, time, status): 53 | """Callback for output stream to get audio data""" 54 | if status: 55 | print(status) 56 | 57 | try: 58 | data = self.output_queue.get_nowait() 59 | data = data.astype(np.float32) / 32767.0 60 | if len(data) < len(outdata): 61 | outdata[:len(data)] = data 62 | outdata[len(data):] = 0 63 | else: 64 | outdata[:] = data[:len(outdata)] 65 | except queue.Empty: 66 | outdata.fill(0) 67 | 68 | async def process_audio(self): 69 | async with websockets.connect(self.server_url) as ws: 70 | while self.running: 71 | if not self.audio_queue.empty(): 72 | # Get recorded audio 73 | audio_data = self.audio_queue.get() 74 | print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}') 75 | 76 | # Convert to base64 77 | audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8') 78 | 79 | # Send to server 80 | time_sent = time.time() 81 | await ws.send(f"data:audio/raw;base64,{audio_b64}") 82 | 83 | # Receive processed audio 84 | response = await ws.recv() 85 | response = response.split(",")[1] 86 | time_received = time.time() 87 | print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms") 88 | processed_audio = np.frombuffer( 89 | base64.b64decode(response), 90 | dtype=np.int16 91 | ).reshape(-1, CHANNELS) 92 | print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}') 93 | 94 | self.output_queue.put(processed_audio) 95 | 96 | def start(self): 97 | self.running = True 98 | # Print audio device information 99 | devices = sd.query_devices() 100 | default_input = sd.query_devices(kind='input') 101 | default_output = sd.query_devices(kind='output') 102 | 103 | print("\nAudio Device Configuration:") 104 | print("-" * 50) 105 | print(f"Default Input Device:\n{default_input}\n") 106 | print(f"Default Output Device:\n{default_output}\n") 107 | print("\nAll Available Devices:") 108 | print("-" * 50) 109 | for i, device in enumerate(devices): 110 | print(f"Device {i}:") 111 | print(f"Name: {device['name']}") 112 | print(f"Channels (in/out): {device['max_input_channels']}/{device['max_output_channels']}") 113 | print(f"Sample Rates: {device['default_samplerate']}") 114 | print() 115 | input_device = input("Enter the index of the input device or press enter for default: ") 116 | output_device = input("Enter the index of the output device or press enter for default: ") 117 | if input_device == "": 118 | input_device = default_input['index'] 119 | if output_device == "": 120 | output_device = default_output['index'] 121 | with sd.InputStream(callback=self.audio_callback, 122 | channels=CHANNELS, 123 | samplerate=SAMPLE_RATE, 124 | device=int(input_device), 125 | blocksize=2000), \ 126 | sd.OutputStream(callback=self.output_stream_callback, 127 | channels=CHANNELS, 128 | samplerate=SAMPLE_RATE, 129 | blocksize=2000, 130 | device=int(output_device)): 131 | 132 | asyncio.run(self.process_audio()) 133 | 134 | def stop(self): 135 | self.running = False 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser(description='Audio Client with Temperature Control') 139 | parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter') 140 | parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter') 141 | parser.add_argument('--gaussian_temp', 
'-t3', type=float, help='Gaussian (VAE) temperature parameter') 142 | parser.add_argument('--server', '-s', default="ws://localhost:8000", 143 | help='Server URL (default: ws://localhost:8000)') 144 | 145 | args = parser.parse_args() 146 | 147 | # Audio settings 148 | SAMPLE_RATE = 16000 149 | CHANNELS = 1 150 | 151 | client = AudioClient( 152 | server_url=args.server, 153 | token_temp=args.token_temp, 154 | categorical_temp=args.categorical_temp, 155 | gaussian_temp=args.gaussian_temp 156 | ) 157 | 158 | try: 159 | client.start() 160 | except KeyboardInterrupt: 161 | client.stop() -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchaudio 5 | from utils import load_ckpt, print_colored 6 | from tokenizer import make_tokenizer 7 | from model import get_hertz_dev_config 8 | import matplotlib.pyplot as plt 9 | from IPython.display import Audio, display 10 | import gradio as gr 11 | 12 | # If you get an error like "undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12", 13 | # you need to install PyTorch with the correct CUDA version. Run: 14 | # `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121` 15 | 16 | device = 'cuda' if T.cuda.is_available() else 'cpu' 17 | T.cuda.set_device(0) 18 | print_colored(f"Using device: {device}", "grey") 19 | 20 | # If you've already downloaded the model checkpoints, save them in ckpt/. 21 | # This code will automatically download them if it can't find them. 22 | audio_tokenizer = make_tokenizer(device) 23 | 24 | # We have different checkpoints for the single-speaker and two-speaker models 25 | # Set to True to load and run inference with the two-speaker model 26 | #TWO_SPEAKER = False 27 | TWO_SPEAKER = True 28 | USE_PURE_AUDIO_ABLATION = False # We trained a base model with no text initialization at all. Toggle this to enable it. 29 | assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER) # We only have a single-speaker version of this model. 
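# (Editor's note) The module-level TWO_SPEAKER flag above is only consulted by the assert that
# follows it; the per-request behaviour is driven by the "Number of Speakers" radio button,
# since get_completion() re-derives TWO_SPEAKER from speakers == 2 and
# load_and_preprocess_audio() likewise takes `speakers` as an argument. Only
# USE_PURE_AUDIO_ABLATION is still read from the module-level flag at generation time.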
30 | 31 | 32 | def load_and_preprocess_audio(audio_path, speakers): 33 | print_colored("Loading and preprocessing audio...", "blue", bold=True) 34 | # Load audio file 35 | audio_tensor, sr = torchaudio.load(audio_path) 36 | print_colored(f"Loaded audio shape: {audio_tensor.shape}", "grey") 37 | 38 | if speakers == 2: 39 | if audio_tensor.shape[0] == 1: 40 | print_colored("Converting mono to stereo...", "grey") 41 | audio_tensor = audio_tensor.repeat(2, 1) 42 | print_colored(f"Stereo audio shape: {audio_tensor.shape}", "grey") 43 | else: 44 | if audio_tensor.shape[0] == 2: 45 | print_colored("Converting stereo to mono...", "grey") 46 | audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0) 47 | print_colored(f"Mono audio shape: {audio_tensor.shape}", "grey") 48 | 49 | # Resample to 16kHz if needed 50 | if sr != 16000: 51 | print_colored(f"Resampling from {sr}Hz to 16000Hz...", "grey") 52 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) 53 | audio_tensor = resampler(audio_tensor) 54 | 55 | # Clip to 5 minutes if needed 56 | max_samples = 16000 * 60 * 5 57 | if audio_tensor.shape[1] > max_samples: 58 | print_colored("Clipping audio to 5 minutes...", "grey") 59 | audio_tensor = audio_tensor[:, :max_samples] 60 | 61 | 62 | print_colored("Audio preprocessing complete!", "green") 63 | return audio_tensor.unsqueeze(0) 64 | 65 | 66 | 67 | # Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts. 68 | #prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav') 69 | 70 | 71 | def get_completion(encoded_prompt_audio, prompt_len, gen_len, speakers, token_temp, categorical_temp, gaussian_temp): 72 | 73 | TWO_SPEAKER = (speakers == 2) 74 | model_config = get_hertz_dev_config(is_split=TWO_SPEAKER, use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION) 75 | generator = model_config() 76 | generator = generator.eval().to(T.bfloat16).to(device) 77 | 78 | 79 | prompt_len_seconds = prompt_len / 8 80 | print_colored(f"Prompt length: {prompt_len_seconds:.2f}s", "grey") 81 | print_colored("Completing audio...", "blue") 82 | encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len] 83 | 84 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 85 | completed_audio_batch = generator.completion( 86 | encoded_prompt_audio, 87 | temps=(token_temp, (categorical_temp, gaussian_temp)), 88 | #temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp)) 89 | use_cache=True, 90 | gen_len=gen_len 91 | ) 92 | 93 | completed_audio = completed_audio_batch 94 | print_colored(f"Decoding completion...", "blue") 95 | if TWO_SPEAKER: 96 | decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16()) 97 | decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16()) 98 | decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0) 99 | else: 100 | decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16()) 101 | print_colored(f"Decoded completion shape: {decoded_completion.shape}", "grey") 102 | 103 | print_colored("Preparing audio for playback...", "blue") 104 | 105 | audio_tensor = decoded_completion.cpu().squeeze() 106 | if audio_tensor.ndim == 1: 107 | audio_tensor = audio_tensor.unsqueeze(0) 108 | audio_tensor = audio_tensor.float() 109 | 110 | if audio_tensor.abs().max() > 1: 111 | audio_tensor = audio_tensor / audio_tensor.abs().max() 112 | 113 | #return audio_tensor[:, max(prompt_len*2000 - 16000, 0):] 114 | return 
audio_tensor 115 | 116 | def run(audio_path, prompt_len_seconds, gen_len_seconds, speakers, token_temp, categorical_temp, gaussian_temp): 117 | # 1. encode audio 118 | prompt_audio = load_and_preprocess_audio(audio_path, speakers) 119 | prompt_len = prompt_len_seconds * 8 120 | gen_len = gen_len_seconds * 8 121 | print(f"prompt_len={prompt_len}, gen_len={gen_len}") 122 | print_colored("Encoding prompt...", "blue") 123 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 124 | if speakers == 2: 125 | encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device)) 126 | encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device)) 127 | encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1) 128 | else: 129 | encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device)) 130 | print_colored(f"Encoded prompt shape: {encoded_prompt_audio.shape}", "grey") 131 | print_colored("Prompt encoded successfully!", "green") 132 | 133 | # 2. get completion 134 | audio_tensor = get_completion(encoded_prompt_audio, prompt_len, gen_len, speakers, token_temp, categorical_temp, gaussian_temp) 135 | audio_np = audio_tensor.numpy() 136 | audio_tensor = audio_tensor.cpu().squeeze() 137 | if audio_tensor.ndim == 1: 138 | audio_tensor = audio_tensor.unsqueeze(0) 139 | audio_tensor = audio_tensor.float() 140 | # audio_np = audio_tensor.numpy() 141 | 142 | torchaudio.save("generated.wav", audio_tensor, 16000) 143 | return "generated.wav" 144 | 145 | 146 | 147 | # sample_rate = 16000 148 | # return (sample_rate, audio_np) 149 | 150 | 151 | 152 | with gr.Blocks() as demo: 153 | with gr.Row(): 154 | with gr.Column(): 155 | with gr.Group(): 156 | audio = gr.Audio(label="Reference Audio", type="filepath") 157 | with gr.Row(): 158 | prompt_len_seconds = gr.Number(label="Continue from N sec", value=3) 159 | gen_len = gr.Number(label="Generate N seconds", value=10) 160 | speakers = gr.Radio(label="Number of Speakers", choices=[1,2], value=1) 161 | button = gr.Button("Generate") 162 | with gr.Accordion("Advanced", open=False): 163 | token_temp = gr.Number(label="token temperature", value=0.8) 164 | categorical_temp = gr.Number(label="categorical temperature", value=0.4) 165 | gaussian_temp = gr.Number(label="gaussian temperature", value=0.1) 166 | with gr.Column(): 167 | generated = gr.Audio(label="Generated", type="filepath", interactive=False) 168 | button.click( 169 | fn=run, 170 | inputs=[audio, prompt_len_seconds, gen_len, speakers, token_temp, categorical_temp, gaussian_temp], 171 | outputs=[generated] 172 | ) 173 | 174 | demo.launch() 175 | -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import torch as T\n", 20 | "import torch.nn as nn\n", 21 | "import torch.nn.functional as F\n", 22 | "import torchaudio\n", 23 | "from utils import load_ckpt, print_colored\n", 24 | "from tokenizer import make_tokenizer\n", 25 | "from model import get_hertz_dev_config\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from IPython.display import Audio, display\n", 28 | 
"\n", 29 | "\n", 30 | "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n", 31 | "# you need to install PyTorch with the correct CUDA version. Run:\n", 32 | "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n", 33 | "\n", 34 | "device = 'cuda' if T.cuda.is_available() else 'cpu'\n", 35 | "T.cuda.set_device(0)\n", 36 | "print_colored(f\"Using device: {device}\", \"grey\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# If you've already downloaded the model checkpoints, save them in ckpt/.\n", 46 | "# This code will automatically download them if it can't find them.\n", 47 | "audio_tokenizer = make_tokenizer(device)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 4, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# We have different checkpoints for the single-speaker and two-speaker models\n", 57 | "# Set to True to load and run inference with the two-speaker model\n", 58 | "TWO_SPEAKER = False\n", 59 | "USE_PURE_AUDIO_ABLATION = False # We trained a base model with no text initialization at all. Toggle this to enable it.\n", 60 | "assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER) # We only have a single-speaker version of this model.\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "model_config = get_hertz_dev_config(is_split=TWO_SPEAKER, use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION)\n", 70 | "\n", 71 | "generator = model_config()\n", 72 | "generator = generator.eval().to(T.bfloat16).to(device)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "def load_and_preprocess_audio(audio_path):\n", 82 | " print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n", 83 | " # Load audio file\n", 84 | " audio_tensor, sr = torchaudio.load(audio_path)\n", 85 | " print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n", 86 | " \n", 87 | " if TWO_SPEAKER:\n", 88 | " if audio_tensor.shape[0] == 1:\n", 89 | " print_colored(\"Converting mono to stereo...\", \"grey\")\n", 90 | " audio_tensor = audio_tensor.repeat(2, 1)\n", 91 | " print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n", 92 | " else:\n", 93 | " if audio_tensor.shape[0] == 2:\n", 94 | " print_colored(\"Converting stereo to mono...\", \"grey\")\n", 95 | " audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)\n", 96 | " print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n", 97 | " \n", 98 | " # Resample to 16kHz if needed\n", 99 | " if sr != 16000:\n", 100 | " print_colored(f\"Resampling from {sr}Hz to 16000Hz...\", \"grey\")\n", 101 | " resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)\n", 102 | " audio_tensor = resampler(audio_tensor)\n", 103 | " \n", 104 | " # Clip to 5 minutes if needed\n", 105 | " max_samples = 16000 * 60 * 5\n", 106 | " if audio_tensor.shape[1] > max_samples:\n", 107 | " print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n", 108 | " audio_tensor = audio_tensor[:, :max_samples]\n", 109 | "\n", 110 | " \n", 111 | " print_colored(\"Audio preprocessing complete!\", \"green\")\n", 112 | " return audio_tensor.unsqueeze(0)\n", 113 | "\n", 114 | "def display_audio(audio_tensor):\n", 
115 | " audio_tensor = audio_tensor.cpu().squeeze()\n", 116 | " if audio_tensor.ndim == 1:\n", 117 | " audio_tensor = audio_tensor.unsqueeze(0)\n", 118 | " audio_tensor = audio_tensor.float()\n", 119 | "\n", 120 | " # Make a waveform plot\n", 121 | " plt.figure(figsize=(4, 1))\n", 122 | " plt.plot(audio_tensor.numpy()[0], linewidth=0.5)\n", 123 | " plt.axis('off')\n", 124 | " plt.show()\n", 125 | "\n", 126 | " # Make an audio player\n", 127 | " display(Audio(audio_tensor.numpy(), rate=16000))\n", 128 | " print_colored(f\"Audio ready for playback ↑\", \"green\", bold=True)\n", 129 | " \n", 130 | " \n", 131 | "\n", 132 | "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n", 133 | "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n", 134 | "display_audio(prompt_audio)\n", 135 | "prompt_len_seconds = 3\n", 136 | "prompt_len = prompt_len_seconds * 8" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "print_colored(\"Encoding prompt...\", \"blue\")\n", 146 | "with T.autocast(device_type='cuda', dtype=T.bfloat16):\n", 147 | " if TWO_SPEAKER:\n", 148 | " encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n", 149 | " encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n", 150 | " encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n", 151 | " else:\n", 152 | " encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n", 153 | "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n", 154 | "print_colored(\"Prompt encoded successfully!\", \"green\")" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "def get_completion(encoded_prompt_audio, prompt_len, gen_len=None):\n", 164 | " prompt_len_seconds = prompt_len / 8\n", 165 | " print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n", 166 | " print_colored(\"Completing audio...\", \"blue\")\n", 167 | " encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len]\n", 168 | " with T.autocast(device_type='cuda', dtype=T.bfloat16):\n", 169 | " completed_audio_batch = generator.completion(\n", 170 | " encoded_prompt_audio, \n", 171 | " temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n", 172 | " use_cache=True,\n", 173 | " gen_len=gen_len)\n", 174 | "\n", 175 | " completed_audio = completed_audio_batch\n", 176 | " print_colored(f\"Decoding completion...\", \"blue\")\n", 177 | " if TWO_SPEAKER:\n", 178 | " decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n", 179 | " decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n", 180 | " decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n", 181 | " else:\n", 182 | " decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n", 183 | " print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n", 184 | "\n", 185 | " print_colored(\"Preparing audio for playback...\", \"blue\")\n", 186 | "\n", 187 | " audio_tensor = decoded_completion.cpu().squeeze()\n", 188 | " if audio_tensor.ndim == 1:\n", 189 | " audio_tensor = audio_tensor.unsqueeze(0)\n", 190 | " audio_tensor = 
audio_tensor.float()\n", 191 | "\n", 192 | " if audio_tensor.abs().max() > 1:\n", 193 | " audio_tensor = audio_tensor / audio_tensor.abs().max()\n", 194 | "\n", 195 | " return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n", 196 | "\n", 197 | "num_completions = 10\n", 198 | "print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n", 199 | "for _ in range(num_completions):\n", 200 | " completion = get_completion(encoded_prompt_audio, prompt_len, gen_len=20*8) # 20 seconds of generation\n", 201 | " display_audio(completion)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": ".venv", 222 | "language": "python", 223 | "name": "python3" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.10.12" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 2 240 | } 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Standard Intelligence PBC 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /ioblocks.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from functools import partial 3 | from contextlib import nullcontext 4 | from typing import List, Tuple 5 | from math import ceil 6 | 7 | import torch as T 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | from torch import Tensor, int32 12 | from torch.amp import autocast 13 | 14 | from einops import rearrange, pack, unpack 15 | 16 | 17 | from utils import si_module, exists, default, maybe 18 | 19 | 20 | @si_module 21 | class GaussianMixtureIOLayer(nn.Module): 22 | class Config: 23 | latent_dim: int 24 | dim: int 25 | num_components: int 26 | 27 | def __init__(self, c: Config): 28 | super().__init__() 29 | self.latent_dim = c.latent_dim 30 | self.num_components = c.num_components 31 | self.input_projection = nn.Linear(c.latent_dim, c.dim) 32 | 33 | self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim) 34 | self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim) 35 | self.fc_weight = nn.Linear(c.dim, c.num_components) 36 | 37 | def _square_plus(self, x): 38 | return (x + T.sqrt(T.square(x) + 4)) / 2 39 | 40 | def input(self, sampled_latents: T.Tensor) -> T.Tensor: 41 | """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)""" 42 | hidden = self.input_projection(sampled_latents) 43 | return hidden 44 | 45 | def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]: 46 | """float tensor (B, L, D) -> Tuple of locs, scales, and weights""" 47 | batch_size, seq_len, _ = h.shape 48 | 49 | locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim) 50 | scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim) 51 | weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components) 52 | 53 | return (locs, scales, weights) 54 | 55 | def loss(self, data, dataHat): 56 | locs, scales, weights = dataHat 57 | log_probs = -0.5 * T.sum( 58 | (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) + 59 | 2 * T.log(scales) + 60 | T.log(T.tensor(2 * T.pi)), 61 | dim=-1 62 | ) 63 | log_weights = F.log_softmax(weights, dim=-1) 64 | return -T.logsumexp(log_weights + log_probs, dim=-1) 65 | 66 | 67 | def temp_sample(self, orig_pdist, temp): 68 | locs, scales, weights = orig_pdist 69 | if temp is None: 70 | component_samples = locs + scales * T.randn_like(scales) 71 | mixture_samples = F.gumbel_softmax(weights, hard=True) 72 | sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2) 73 | elif isinstance(temp, tuple): 74 | assert len(temp) == 2 75 | categorical_temp, gaussian_temp = temp 76 | component_samples = locs + scales * gaussian_temp * T.randn_like(scales) 77 | mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True) 78 | sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2) 79 | else: 80 | component_samples = locs + scales * temp * T.randn_like(scales) 81 | mixture_samples = F.gumbel_softmax(weights / temp, hard=True) 82 | sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2) 83 | return sampled 84 | 85 | 86 | class GPTOutput(nn.Module): 87 | def __init__(self, dim, vocab_size): 88 | super().__init__() 89 | self.output = nn.Linear(dim, vocab_size, bias=False) 90 | 91 | def forward(self, x): 92 | return self.output(x) 93 | 94 
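# Minimal usage sketch for GaussianMixtureIOLayer above (illustrative only). It assumes
# the repo's si_module convention, where a module is built by calling its Config
# (as model.py does via `c.io_config()`); the dims below are arbitrary.
def _demo_gmm_io_layer():
    io = GaussianMixtureIOLayer.Config(latent_dim=32, dim=128, num_components=8)()
    latents = T.randn(2, 10, 32)                # (B, L, Z) pre-sampled latents
    h = io.input(latents)                       # (B, L, D) hidden states
    pdist = io.output(h)                        # (locs, scales, weights) of the mixture
    nll = io.loss(latents, pdist)               # (B, L) negative log-likelihood per position
    sample = io.temp_sample(pdist, (0.5, 0.1))  # (B, L, Z); temps = (categorical, gaussian)
    return h, nll, sample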
| 95 | # helper functions 96 | 97 | def pack_one(t, pattern): 98 | return pack([t], pattern) 99 | 100 | def unpack_one(t, ps, pattern): 101 | return unpack(t, ps, pattern)[0] 102 | 103 | def first(l): 104 | return l[0] 105 | 106 | def round_up_multiple(num, mult): 107 | return ceil(num / mult) * mult 108 | 109 | def get_code_utilization(codes, codebook_size, get_global=False): 110 | if get_global and dist.is_initialized(): 111 | world_size = dist.get_world_size() 112 | else: 113 | world_size = 1 114 | 115 | if world_size > 1: 116 | gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)] 117 | dist.all_gather(gathered_tokens, codes) 118 | gathered_tokens = T.cat(gathered_tokens, dim=0) 119 | else: 120 | gathered_tokens = codes 121 | unique_tokens = len(T.unique(gathered_tokens)) 122 | code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size) 123 | return code_utilization 124 | 125 | # tensor helpers 126 | 127 | def round_ste(z: Tensor) -> Tensor: 128 | """Round with straight through gradients.""" 129 | zhat = z.round() 130 | return z + (zhat - z).detach() 131 | 132 | # main class 133 | # lucidrains fsq 134 | @si_module 135 | class FSQ(nn.Module): 136 | @property 137 | def needs_float32_params(self): 138 | return True 139 | 140 | class Config: 141 | levels: List[int] 142 | dim: int | None = None 143 | num_codebooks: int = 1 144 | keep_num_codebooks_dim: bool | None = None 145 | scale: float | None = None 146 | allowed_dtypes: Tuple[str, ...] = ('float32', 'float64') 147 | channel_first: bool = False 148 | projection_has_bias: bool = True 149 | return_indices: bool = True 150 | force_quantization_f32: bool = True 151 | use_rms: bool = False 152 | 153 | def __init__(self, c: Config): 154 | super().__init__() 155 | _levels = T.tensor(c.levels, dtype=int32) 156 | self.register_buffer("_levels", _levels, persistent = False) 157 | 158 | _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32) 159 | self.register_buffer("_basis", _basis, persistent = False) 160 | 161 | self.scale = c.scale 162 | 163 | codebook_dim = len(c.levels) 164 | self.codebook_dim = codebook_dim 165 | 166 | effective_codebook_dim = codebook_dim * c.num_codebooks 167 | self.num_codebooks = c.num_codebooks 168 | 169 | self.allowed_dtypes = [] 170 | for dtype_str in c.allowed_dtypes: 171 | if hasattr(T, dtype_str): 172 | self.allowed_dtypes.append(getattr(T, dtype_str)) 173 | else: 174 | raise ValueError(f"Invalid dtype string: {dtype_str}") 175 | 176 | self.effective_codebook_dim = effective_codebook_dim 177 | 178 | keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1) 179 | assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim) 180 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 181 | 182 | self.dim = default(c.dim, len(_levels) * c.num_codebooks) 183 | 184 | self.channel_first = c.channel_first 185 | 186 | has_projections = self.dim != effective_codebook_dim 187 | self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity() 188 | self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity() 189 | 190 | self.has_projections = has_projections 191 | 192 | self.return_indices = c.return_indices 193 | if c.return_indices: 194 | self.codebook_size = self._levels.prod().item() 195 | implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size)) 196 | self.register_buffer("implicit_codebook", 
implicit_codebook, persistent = False) 197 | 198 | self.allowed_dtypes = c.allowed_dtypes 199 | self.force_quantization_f32 = c.force_quantization_f32 200 | 201 | self.latent_loss = None 202 | 203 | def latent_metric(self, codes, get_global=False): 204 | return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)} 205 | 206 | def repr_from_latent(self, latent): 207 | return self.indices_to_codes(latent) 208 | 209 | def bound(self, z, eps: float = 1e-3): 210 | """ Bound `z`, an array of shape (..., d). """ 211 | half_l = (self._levels - 1) * (1 + eps) / 2 212 | offset = T.where(self._levels % 2 == 0, 0.5, 0.0) 213 | shift = (offset / half_l).atanh() 214 | return (z + shift).tanh() * half_l - offset 215 | 216 | def quantize(self, z): 217 | """ Quantizes z, returns quantized zhat, same shape as z. """ 218 | quantized = round_ste(self.bound(z)) 219 | half_width = self._levels // 2 # Renormalize to [-1, 1]. 220 | return quantized / half_width 221 | 222 | def _scale_and_shift(self, zhat_normalized): 223 | half_width = self._levels // 2 224 | return (zhat_normalized * half_width) + half_width 225 | 226 | def _scale_and_shift_inverse(self, zhat): 227 | half_width = self._levels // 2 228 | return (zhat - half_width) / half_width 229 | 230 | def _indices_to_codes(self, indices): 231 | level_indices = self.indices_to_level_indices(indices) 232 | codes = self._scale_and_shift_inverse(level_indices) 233 | return codes 234 | 235 | def codes_to_indices(self, zhat): 236 | """ Converts a `code` to an index in the codebook. """ 237 | assert zhat.shape[-1] == self.codebook_dim 238 | zhat = self._scale_and_shift(zhat) 239 | return (zhat * self._basis).sum(dim=-1).to(int32) 240 | 241 | def indices_to_level_indices(self, indices): 242 | """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """ 243 | indices = rearrange(indices, '... -> ... 1') 244 | codes_non_centered = (indices // self._basis) % self._levels 245 | return codes_non_centered 246 | 247 | def indices_to_codes(self, indices): 248 | """ Inverse of `codes_to_indices`. """ 249 | assert exists(indices) 250 | 251 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 252 | 253 | codes = self._indices_to_codes(indices) 254 | 255 | if self.keep_num_codebooks_dim: 256 | codes = rearrange(codes, '... c d -> ... (c d)') 257 | 258 | codes = self.project_out(codes) 259 | 260 | if is_img_or_video or self.channel_first: 261 | codes = rearrange(codes, 'b ... d -> b d ...') 262 | 263 | return codes 264 | 265 | # @autocast(device_type='cuda', enabled = False) 266 | def forward(self, z, return_codes=False): 267 | """ 268 | einstein notation 269 | b - batch 270 | n - sequence (or flattened spatial dimensions) 271 | d - feature dimension 272 | c - number of codebook dim 273 | """ 274 | 275 | is_img_or_video = z.ndim >= 4 276 | need_move_channel_last = is_img_or_video or self.channel_first 277 | 278 | # standardize image or video into (batch, seq, dimension) 279 | 280 | if need_move_channel_last: 281 | z = rearrange(z, 'b d ... -> b ... 
d') 282 | z, ps = pack_one(z, 'b * d') 283 | 284 | assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' 285 | 286 | z = self.project_in(z) 287 | 288 | z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) 289 | 290 | # whether to force quantization step to be full precision or not 291 | 292 | force_f32 = self.force_quantization_f32 293 | quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext 294 | 295 | with quantization_context(): 296 | orig_dtype = z.dtype 297 | 298 | if force_f32 and orig_dtype not in self.allowed_dtypes: 299 | z = z.float() 300 | 301 | codes = self.quantize(z) 302 | 303 | # returning indices could be optional 304 | 305 | indices = None 306 | 307 | if self.return_indices: 308 | indices = self.codes_to_indices(codes) 309 | 310 | codes = rearrange(codes, 'b n c d -> b n (c d)') 311 | 312 | codes = codes.type(orig_dtype) 313 | 314 | # project out 315 | if return_codes: 316 | return codes, indices 317 | 318 | out = self.project_out(codes) 319 | 320 | # reconstitute image or video dimensions 321 | 322 | if need_move_channel_last: 323 | out = unpack_one(out, ps, 'b * d') 324 | out = rearrange(out, 'b ... d -> b d ...') 325 | 326 | indices = maybe(unpack_one)(indices, ps, 'b * c') 327 | 328 | if not self.keep_num_codebooks_dim and self.return_indices: 329 | indices = maybe(rearrange)(indices, '... 1 -> ...') 330 | 331 | # return quantized output and indices 332 | 333 | return out, indices -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, MutableMapping 2 | from typing import Union 3 | import math 4 | from contextlib import nullcontext 5 | 6 | import torch 7 | import torch as T 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | from torch.nn.attention import SDPBackend 12 | 13 | from einops import rearrange 14 | 15 | from utils import si_module, default, exists, load_ckpt 16 | 17 | CACHE_FILL_VALUE = -1 18 | 19 | def get_cache_len(cache: Optional[Tensor]) -> int: 20 | """ 21 | cache: (batch, seq_len, 2, kv_heads, head_dim) 22 | """ 23 | if cache is None: 24 | return 0 25 | nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1) 26 | length = nonzeros.sum(dim=-1).int() 27 | assert T.all(length == length[0]) 28 | return length[0] 29 | 30 | 31 | def rotate_half(x): 32 | x1, x2 = x.chunk(2, dim=-1) 33 | return torch.cat((-x2, x1), dim=-1) 34 | 35 | 36 | def apply_rotary_pos_emb(x, cos, sin, offset: int = 0): 37 | assert ( 38 | cos.shape[1] >= offset + x.shape[1] 39 | ), f"Offset and/or input sequence is too large,\ 40 | \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}" 41 | 42 | cos_out = cos[:, offset : offset + x.shape[1], :, :] 43 | sin_out = sin[:, offset : offset + x.shape[1], :, :] 44 | 45 | return (x * cos_out) + (rotate_half(x) * sin_out) 46 | 47 | 48 | # Adapted from https://github.com/foundation-model-stack/foundation-model-stack 49 | class ShapeRotator: 50 | def __init__( 51 | self, 52 | dim: int, 53 | end: int, 54 | theta: float = 10_000, 55 | ): 56 | super().__init__() 57 | self.dim = dim 58 | self.ratio = theta 59 | self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {} 60 | self.max_seq_len_cached: MutableMapping[int, int] = {} 61 | self.ntk_scaling = False 62 | self.max_seq_len = end 63 | 64 | 
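    # compute_freqs_cis lazily builds, per device index, a cached table of rotary
    # cos/sin values with shape (1, max_seq_len, 1, dim, 2) and returns the cache
    # key alpha (fixed at 1). rotate() slices that table at `offset` and applies
    # q' = q*cos + rotate_half(q)*sin (likewise for k), so new positions stay
    # aligned with entries already sitting in the KV cache.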
def compute_freqs_cis(self, device, max_seq_len=None): 65 | alpha = 1 66 | dev_idx = device.index 67 | max_seq_len = default(max_seq_len, self.max_seq_len) 68 | 69 | if dev_idx not in self.cached_freqs: 70 | self.cached_freqs[dev_idx] = {} 71 | if dev_idx not in self.max_seq_len_cached: 72 | self.max_seq_len_cached[dev_idx] = 0 73 | 74 | 75 | if self.max_seq_len_cached[dev_idx] > 0: 76 | return 1 77 | max_seq_len = max(max_seq_len, self.max_seq_len) 78 | 79 | if ( 80 | 1 in self.cached_freqs[dev_idx] 81 | and max_seq_len <= self.max_seq_len_cached[dev_idx] 82 | ): 83 | return 1 84 | 85 | ratio = self.ratio 86 | dim = self.dim 87 | 88 | freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim)) 89 | 90 | t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype) 91 | freqs = torch.einsum("i,j->ij", t, freqs) 92 | emb = torch.cat((freqs, freqs), dim=-1).to(device) 93 | 94 | cos_to_cache = emb.cos()[None, :, None, :] 95 | sin_to_cache = emb.sin()[None, :, None, :] 96 | 97 | self.max_seq_len_cached[dev_idx] = max_seq_len 98 | 99 | self.cached_freqs[dev_idx][alpha] = torch.stack( 100 | [ 101 | cos_to_cache, 102 | sin_to_cache, 103 | ], 104 | dim=-1, 105 | ) 106 | 107 | return alpha 108 | 109 | def rotate( 110 | self, 111 | q: Tensor, 112 | k: Tensor, 113 | offset: int = 0, 114 | ) -> Tuple[Tensor, Tensor]: 115 | """ 116 | Args 117 | ---- 118 | q : torch.Tensor 119 | Embedded query tensor, expected size is B x S x H x Eh 120 | k : torch.Tensor 121 | Embedded query tensor, expected size is B x S x H x Eh 122 | """ 123 | assert len(q.size()) == 4 124 | assert len(k.size()) == 4 125 | 126 | seq_len = self.max_seq_len 127 | alpha = self.compute_freqs_cis(q.device, seq_len) 128 | freqs = self.cached_freqs[q.device.index][alpha] 129 | 130 | freqs = freqs.float() # 1 L D/2 2 2 131 | q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q) 132 | k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k) 133 | 134 | return q_out.view_as(q), k_out.view_as(k) 135 | 136 | class Linear(nn.Linear): 137 | def __init__(self, *args, **kwargs): 138 | super().__init__(*args, **kwargs, bias=False) 139 | 140 | class Norm(nn.Module): 141 | def __init__(self, 142 | dim: int, 143 | eps: float = 1e-5,) -> None: 144 | super().__init__() 145 | self.eps = eps 146 | self.weight = nn.Parameter(T.ones((dim,))) 147 | 148 | def forward(self, input: Tensor) -> Tensor: 149 | return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps) 150 | 151 | 152 | class FFNN(nn.Module): 153 | def __init__(self, 154 | dim: int, 155 | expand_dim: int = None,): 156 | super().__init__() 157 | expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256)) 158 | self.dim = dim 159 | self.expand_dim = expand_dim 160 | 161 | self.gateup_proj = Linear(dim, 2*expand_dim) 162 | self.down_proj = Linear(expand_dim, dim) 163 | 164 | def forward(self, x): 165 | gate, up = self.gateup_proj(x).chunk(2, dim=-1) 166 | return self.down_proj(up * F.silu(gate)) 167 | 168 | class GQA(nn.Module): 169 | def __init__(self, 170 | dim: int, 171 | n_head: int, 172 | shape_rotator: ShapeRotator, 173 | kv_heads: Optional[int] = None, 174 | eps: float = 1e-5, 175 | causal: bool = True,): 176 | super().__init__() 177 | self.n_heads = n_head 178 | self.kv_heads = default(kv_heads, n_head) 179 | self.head_dim = dim // n_head 180 | self.causal = causal 181 | 182 | self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads)) 183 | 
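        # Fused QKV projection: n_heads query heads plus kv_heads key heads and
        # kv_heads value heads, each of width head_dim. Note that _project() splits
        # the result with .chunk(3), which only lines up when kv_heads == n_heads,
        # as in the configs shipped in model.py; _sdpa() repeat-interleaves K/V up
        # to n_heads for grouped-query attention.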
184 | self.norm_q = Norm(self.head_dim*n_head, eps=eps) 185 | self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps) 186 | 187 | self.attn_out = Linear(dim, dim) 188 | 189 | self.shape_rotator = shape_rotator 190 | 191 | def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 192 | k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2) 193 | v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2) 194 | #with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext(): 195 | with nn.attention.sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION) if k.device.type == 'cuda' else nullcontext(): 196 | x = F.scaled_dot_product_attention( 197 | q.transpose(1, 2), 198 | k.transpose(1, 2), 199 | v.transpose(1, 2), 200 | is_causal=False if (q.size(1) != k.size(1)) else self.causal, 201 | ) 202 | x = x.transpose(1, 2).contiguous() 203 | return x 204 | 205 | def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,): 206 | cache_len = get_cache_len(kv_cache) 207 | q, k = self.shape_rotator.rotate(q, k, offset=cache_len) 208 | if exists(kv_cache): 209 | k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1) 210 | v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1) 211 | kv_cache[:, :k.size(1), 0] = k 212 | kv_cache[:, :v.size(1), 1] = v 213 | x = self._sdpa(q, k, v) 214 | return self.attn_out(rearrange(x, 'b s h d -> b s (h d)')) 215 | 216 | def _project(self, x): 217 | full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1) 218 | normed_full_q = self.norm_q(full_q).to(full_q.dtype) 219 | normed_full_k = self.norm_k(full_k).to(full_k.dtype) 220 | 221 | q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads) 222 | k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads) 223 | v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads) 224 | return q, k, v 225 | 226 | def forward(self, 227 | x: Tensor, 228 | kv: Optional[Tensor] = None,): 229 | """ 230 | x: (B, S, D) 231 | kv: (B, S, H, D) 232 | """ 233 | q, k, v = self._project(x) 234 | return self._attend(q, k, v, kv_cache=kv) 235 | 236 | 237 | class PreNormAttn(nn.Module): 238 | def __init__(self, 239 | dim: int, 240 | n_head: int, 241 | shape_rotator: ShapeRotator, 242 | kv_heads: Optional[int] = None, 243 | eps: float = 1e-5, 244 | causal: bool = True,): 245 | super().__init__() 246 | self.attn_norm = Norm(dim, eps=eps) 247 | self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal) 248 | 249 | def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor: 250 | """ 251 | x: (B, S, D) 252 | kv: (B, S, H, D) 253 | """ 254 | return x + self.attn(self.attn_norm(x), kv) 255 | 256 | class PreNormFFNN(nn.Module): 257 | def __init__(self, 258 | dim: int, 259 | ff_dim: int, 260 | eps: float = 1e-5,): 261 | super().__init__() 262 | self.ffnn_norm = Norm(dim, eps=eps) 263 | self.ffnn = FFNN(dim, ff_dim) 264 | 265 | def forward(self, x: Tensor) -> Tensor: 266 | return x + self.ffnn(self.ffnn_norm(x)) 267 | 268 | class Block(nn.Module): 269 | def __init__(self, 270 | dim: int, 271 | layer_id: int = 0, 272 | n_head: int = 16, 273 | kv_heads: Optional[int] = None, 274 | ff_dim: Optional[int] = None, 275 | eps: float = 1e-5, 276 | causal: bool = True, 277 | shape_rotator: ShapeRotator = None): 278 | super().__init__() 279 | self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal) 280 | self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps) 281 | self.dim = dim 282 | self.layer_id = layer_id 283 | self.head_dim = dim // 
n_head 284 | self.expand_dim = self.ffnn.ffnn.expand_dim 285 | 286 | self.reset_parameters() 287 | 288 | def reset_parameters(self): 289 | std = 1.0 / math.sqrt(self.dim) 290 | nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std) 291 | nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std) 292 | nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std) 293 | 294 | xstd = 1.0 / math.sqrt(self.expand_dim) 295 | nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd) 296 | 297 | def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor: 298 | """ 299 | x: (B, S, D) 300 | kv: (B, S, H, D) 301 | """ 302 | h = self.attn(x, kv) 303 | out = self.ffnn(h) 304 | return out 305 | 306 | 307 | 308 | class GPTOutput(nn.Module): 309 | def __init__(self, dim, vocab_size): 310 | super().__init__() 311 | self.dim = dim 312 | self.norm = Norm(dim) 313 | self.output = Linear(dim, vocab_size) 314 | 315 | self.reset_parameters() 316 | 317 | def reset_parameters(self): 318 | std = 1.0 / math.sqrt(self.dim**2) 319 | nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std) 320 | 321 | def forward(self, x): 322 | return self.output(self.norm(x)) 323 | 324 | @si_module 325 | class Stack(nn.Module): 326 | class Config: 327 | layers: int 328 | dim: int 329 | seq_len: int 330 | n_head: int = 32 331 | ff_dim: int = None 332 | kv_heads: int = None 333 | eps: float = 1e-5 334 | theta: Union[int, float] = 10_000 335 | causal: bool = True 336 | 337 | from_pretrained: Optional[Tuple[str, int]] = None 338 | 339 | def __init__(self, c: Config): 340 | super().__init__() 341 | 342 | from_pretrained = c.from_pretrained 343 | if exists(from_pretrained): 344 | checkpoint = load_ckpt(c.from_pretrained) 345 | 346 | self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta) 347 | 348 | self.layers = nn.ModuleList([ 349 | Block( 350 | dim=c.dim, 351 | layer_id=l, 352 | n_head=c.n_head, 353 | kv_heads=c.kv_heads, 354 | ff_dim=c.ff_dim, 355 | eps=c.eps, 356 | causal=c.causal, 357 | shape_rotator=self.shape_rotator, 358 | ) for l in range(c.layers) 359 | ]) 360 | 361 | kv_heads = c.kv_heads or c.n_head 362 | head_dim = c.dim // c.n_head 363 | cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim] 364 | self.cache_shape = cache_shape 365 | self.cache = [None] * c.layers 366 | 367 | if exists(from_pretrained): 368 | self.load_state_dict(checkpoint) 369 | 370 | def init_cache(self, bsize, device, dtype, length:int=None): 371 | if self.cache_shape is None: 372 | return 373 | cache_shape = self.cache_shape.copy() 374 | cache_shape[1] = length or cache_shape[1] 375 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) 376 | 377 | def deinit_cache(self): 378 | self.cache = [None] * len(self.cache) 379 | 380 | def forward(self, x: Tensor) -> Tensor: 381 | for l, layer in enumerate(self.layers): 382 | x = layer(x, kv=self.cache[l]) 383 | return x 384 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch as T 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ioblocks import GaussianMixtureIOLayer, FSQ 8 | 9 | from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm 10 | 
from tokenizer import make_tokenizer 11 | 12 | 13 | from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored 14 | from utils import load_ckpt 15 | 16 | 17 | @si_module 18 | class LatentQuantizer(nn.Module): 19 | class Config: 20 | compressor_config: Optional[FSQ.Config] = None 21 | 22 | dim: Optional[int] = None 23 | ff_dim: Optional[int] = None 24 | input_dim: int = None 25 | 26 | from_pretrained: Optional[Tuple[str, str]] = None 27 | 28 | def __init__(self, c: Config): 29 | super().__init__() 30 | 31 | if exists(c.from_pretrained): 32 | checkpoint = load_ckpt(*c.from_pretrained) 33 | else: 34 | assert exists(c.compressor_config), f'hmm {c}' 35 | 36 | self.compressor = c.compressor_config() 37 | self.ffnn = FFNN(c.dim, c.ff_dim) 38 | self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity() 39 | 40 | if exists(c.from_pretrained): 41 | self.load_state_dict(checkpoint) 42 | 43 | @T.no_grad() 44 | def forward(self, x, return_latent=False, known_latent=None): 45 | """ 46 | x: (B, S, D) 47 | """ 48 | if exists(known_latent): 49 | return self.compressor.indices_to_codes(known_latent) 50 | 51 | x = self.input(x) 52 | x = self.ffnn(x) 53 | x, tokens = self.compressor(x) 54 | 55 | if return_latent: 56 | return x, tokens 57 | return x 58 | 59 | 60 | @si_module 61 | class TransformerVAE(nn.Module): 62 | class Config: 63 | io_config: Optional[GaussianMixtureIOLayer.Config] = None 64 | stack_config: Optional[Stack.Config] = None 65 | quantizer_config: Optional[LatentQuantizer.Config] = None 66 | 67 | plex_layer: int = None 68 | plex_roll: int = 1 69 | split: bool = True 70 | 71 | from_pretrained: Optional[Tuple[str, str]] = None 72 | 73 | def __init__(self, c: Config): 74 | super().__init__() 75 | 76 | if exists(c.from_pretrained): 77 | checkpoint = load_ckpt(*c.from_pretrained) 78 | else: 79 | assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}' 80 | 81 | self.io = c.io_config() 82 | self.stack = c.stack_config() 83 | 84 | self.plex_layer = c.stack_config.layers//2 85 | self.plex_roll = c.plex_roll 86 | self.plex_dim = c.quantizer_config.dim 87 | 88 | assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}' 89 | self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim) 90 | self.out_norm = Norm(c.stack_config.dim) 91 | 92 | if c.split: 93 | self.io2 = c.io_config() 94 | self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim) 95 | 96 | self.io2.fc_loc = None 97 | self.io2.fc_scale = None 98 | self.io2.fc_weight = None 99 | 100 | kv_heads = c.stack_config.kv_heads or c.stack_config.n_head 101 | head_dim = c.stack_config.dim // c.stack_config.n_head 102 | self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0) 103 | cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim] 104 | self.cache_shape = cache_shape 105 | self.cache = [None] * self.cache_num_layers 106 | 107 | if exists(c.from_pretrained): 108 | result = self.load_state_dict(checkpoint, strict=False) 109 | print0_colored(result, 'yellow') 110 | 111 | self.quantizer = c.quantizer_config().eval() 112 | self.quantizer.requires_grad = False 113 | 114 | @T.no_grad() 115 | def quantize(self, x): 116 | if self.c.split: 117 | x1, x2 = x.chunk(2, dim=-1) 118 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 119 | quantized1 = 
self.quantizer(x1) 120 | quantized2 = self.quantizer(x2) 121 | return quantized1, quantized2 122 | else: 123 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 124 | return self.quantizer(x) 125 | 126 | @T.no_grad() 127 | def untokenize(self, token_data): 128 | return self.quantizer(None, known_latent=token_data) 129 | 130 | def init_cache(self, bsize, device, dtype, length:int=None): 131 | cache_shape = self.cache_shape.copy() 132 | cache_shape[1] = length or cache_shape[1] 133 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) 134 | 135 | def deinit_cache(self): 136 | self.cache = [None] * self.cache_num_layers 137 | 138 | @T.no_grad() 139 | def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None): 140 | if self.c.split: 141 | x1, x2 = data.chunk(2, dim=-1) 142 | x = self.io.input(x1) + self.io2.input(x2) 143 | else: 144 | x = self.io.input(data) 145 | 146 | cache_idx = 0 147 | for l, layer in enumerate(self.stack.layers): 148 | if l == self.plex_layer: 149 | if self.c.split: 150 | plex1, plex2 = self.quantize(data) 151 | plex1 = T.roll(plex1, -self.c.plex_roll, dims=1) 152 | plex2 = T.roll(plex2, -self.c.plex_roll, dims=1) 153 | if exists(next_tokens): 154 | plex1[:, -1:] = self.untokenize(next_tokens[0]) 155 | plex2[:, -1:] = self.untokenize(next_tokens[1]) 156 | x1 = x + self.plex_projection(plex1) 157 | x2 = x + self.plex_projection2(plex2) 158 | else: 159 | plex = self.quantize(data) 160 | plex = T.roll(plex, -self.c.plex_roll, dims=1) 161 | if exists(next_tokens): 162 | plex[:, -1:] = self.untokenize(next_tokens) 163 | x = x + self.plex_projection(plex) 164 | 165 | if l < self.plex_layer: 166 | x = layer(x, kv=self.cache[l]) 167 | else: 168 | if self.c.split: 169 | x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx]) 170 | cache_idx += 1 171 | x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx]) 172 | cache_idx += 1 173 | else: 174 | x = layer(x, kv=self.cache[l]) 175 | 176 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 177 | if self.c.split: 178 | x1, x2 = self.out_norm(x1), self.out_norm(x2) 179 | out1, out2 = self.io.output(x1), self.io.output(x2) 180 | else: 181 | x = self.out_norm(x) 182 | out = self.io.output(x) 183 | 184 | if isnt(temps): 185 | if self.c.split: 186 | return out1, out2 187 | else: 188 | return out 189 | else: 190 | if self.c.split: 191 | next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :] 192 | next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :] 193 | next_data = T.cat([next_data1, next_data2], dim=-1) 194 | return next_data 195 | else: 196 | next_data = self.io.temp_sample(out, temps)[:, -1:, :] 197 | return next_data 198 | 199 | @si_module 200 | class HertzDevModel(nn.Module): 201 | class Config: 202 | dim: int 203 | vocab_size: int 204 | stack_config: Optional[Stack.Config] = None 205 | latent_size: int = 32 206 | 207 | split: bool = True 208 | 209 | quantizer_config: Optional[LatentQuantizer.Config] = None 210 | resynthesizer_config: Optional[TransformerVAE.Config] = None 211 | 212 | from_pretrained: Optional[Tuple[str, str]] = None 213 | 214 | def __init__(self, c: Config): 215 | super().__init__() 216 | 217 | if exists(c.from_pretrained): 218 | checkpoint = load_ckpt(*c.from_pretrained) 219 | else: 220 | assert (exists(c.stack_config)), f'hmm {c}' 221 | 222 | self.input = nn.Linear(c.latent_size, c.dim) 223 | if self.c.split: 224 | self.input2 = nn.Linear(c.latent_size, c.dim) 
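            # In the split (two-channel) configuration the latent stream holds two
            # channels side by side; forward() chunks it in half and routes each
            # half through its own input projection (input / input2) and its own
            # output head (output / output2).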
225 | 226 | self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta) 227 | 228 | self.layers = nn.ModuleList([ 229 | PerfBlock( 230 | dim=c.stack_config.dim, 231 | layer_id=l, 232 | n_head=c.stack_config.n_head, 233 | kv_heads=c.stack_config.kv_heads, 234 | ff_dim=c.stack_config.ff_dim, 235 | eps=c.stack_config.eps, 236 | shape_rotator=self.shape_rotator, 237 | ) for l in range(c.stack_config.layers) 238 | ]) 239 | 240 | self.output = GPTOutput(c.dim, c.vocab_size) 241 | if self.c.split: 242 | self.output2 = GPTOutput(c.dim, c.vocab_size) 243 | 244 | self.cache = [None] * c.stack_config.layers 245 | self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head 246 | self.head_dim = c.stack_config.dim // c.stack_config.n_head 247 | 248 | if exists(c.from_pretrained): 249 | result = self.load_state_dict(checkpoint, strict=False) 250 | print0_colored(result, 'yellow') 251 | 252 | self.resynthesizer = c.resynthesizer_config().eval() 253 | self.resynthesizer.requires_grad = False 254 | 255 | self.audio_tokenizer = make_tokenizer(device='cpu') 256 | self.audio_cache = None 257 | self.audio_latent_cache = None 258 | self.use_audio_cache = False 259 | 260 | @T.no_grad() 261 | def tokenize(self, audio_data): 262 | orig_audio_shape = audio_data.shape 263 | if exists(self.audio_cache): 264 | audio_data = T.cat([self.audio_cache, audio_data], dim=-1) 265 | self.audio_cache = audio_data[..., -(6*16_000):] 266 | elif self.use_audio_cache: 267 | self.audio_cache = audio_data[..., -(6*16_000):] 268 | 269 | if audio_data.shape[1] == 2: 270 | enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1]) 271 | enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2]) 272 | return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):] 273 | else: 274 | return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):] 275 | 276 | @T.no_grad() 277 | def untokenize(self, token_data): 278 | if exists(self.audio_latent_cache): 279 | token_data = T.cat([self.audio_latent_cache, token_data], dim=1) 280 | self.audio_latent_cache = token_data[:, -(6*8):] 281 | elif self.use_audio_cache: 282 | self.audio_latent_cache = token_data[:, -(6*8):] 283 | 284 | if token_data.shape[-1] == 2*self.c.latent_size: 285 | dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size]) 286 | dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:]) 287 | return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):] 288 | else: 289 | return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):] 290 | 291 | def init_cache(self, bsize, device, dtype, length:int=None): 292 | cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim] 293 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) 294 | self.resynthesizer.init_cache(bsize, device, dtype, length) 295 | self.use_audio_cache = True 296 | 297 | def deinit_cache(self): 298 | self.cache = [None] * len(self.layers) 299 | self.resynthesizer.deinit_cache() 300 | self.audio_cache = None 301 | self.audio_latent_cache = None 302 | self.use_audio_cache = False 303 | 304 | @T.no_grad() 305 | def forward(self, data): 306 | if self.c.split: 307 | x1, x2 = data.chunk(2, dim=-1) 308 | x = self.input(x1) + self.input2(x2) 309 | else: 310 | x = self.input(data) 311 | 312 | for 
l, layer in enumerate(self.layers): 313 | x = layer(x, kv=self.cache[l]) 314 | 315 | if self.c.split: 316 | return self.output(x), self.output2(x) 317 | else: 318 | return self.output(x) 319 | 320 | @T.no_grad() 321 | def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))): 322 | latents_in = self.tokenize(audio_data) 323 | next_latents = self.next_latent(latents_in, temps) 324 | next_model_latent = next_latents[..., self.c.latent_size:] 325 | audio_decoded = self.untokenize(next_model_latent)[..., -2000:] 326 | return audio_decoded 327 | 328 | 329 | @T.no_grad() 330 | def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))): 331 | 332 | if self.c.split: 333 | logits1, logits2 = self.forward(model_input) 334 | next_logits1 = logits1[:, -1] 335 | next_logits2 = logits2[:, -1] 336 | next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1) 337 | next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1) 338 | 339 | next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1]) 340 | else: 341 | logits = self.forward(model_input) 342 | next_logits = logits[:, -1] 343 | next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1) 344 | 345 | next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1]) 346 | 347 | return next_input 348 | 349 | 350 | @T.no_grad() 351 | def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor: 352 | """ 353 | only accepts latent-space data. 354 | """ 355 | if use_cache: 356 | self.init_cache(data.shape[0], data.device, T.bfloat16) 357 | 358 | next_input = generated = data 359 | 360 | target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len) 361 | 362 | for _ in tqdm0(range(data.shape[1], target_len)): 363 | model_input = next_input if use_cache else generated 364 | 365 | next_input = self.next_latent(model_input, temps) 366 | 367 | generated = T.cat([generated, next_input], dim=1) 368 | 369 | if use_cache: 370 | self.deinit_cache() 371 | return generated 372 | 373 | 374 | 375 | def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False): 376 | if is_split: 377 | checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')] 378 | elif not use_pure_audio_ablation: 379 | checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')] 380 | else: 381 | checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')] 382 | 383 | quantizer_config=LatentQuantizer.Config( 384 | from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'), 385 | compressor_config=FSQ.Config( 386 | levels=[8,8,8,8,8], 387 | dim=2048, 388 | num_codebooks=1, 389 | keep_num_codebooks_dim=None, 390 | scale=None, 391 | allowed_dtypes=['float32', 'float64', 'bfloat16'], 392 | channel_first=False, 393 | projection_has_bias=True, 394 | return_indices=True, 395 | force_quantization_f32=True, 396 | use_rms=False 397 | ), 398 | dim=2048, 399 | ff_dim=8192, 400 | input_dim=32 401 | ) 402 | 403 | resynthesizer_config=TransformerVAE.Config( 404 | io_config=GaussianMixtureIOLayer.Config( 405 | latent_dim=32, 406 | dim=4096, 407 | num_components=8, 408 | ), 409 | stack_config=Stack.Config( 410 | layers=8, 411 | dim=4096, 412 | 
seq_len=8192, 413 | n_head=16, 414 | ff_dim=11008, 415 | kv_heads=16, 416 | eps=1e-5, 417 | theta=10_000 418 | ), 419 | quantizer_config=quantizer_config, 420 | plex_layer=None, 421 | plex_roll=1, 422 | split=is_split, 423 | from_pretrained=checkpoints[0], 424 | ) 425 | 426 | return HertzDevModel.Config( 427 | dim=4096, 428 | vocab_size=32_768, 429 | stack_config=Stack.Config( 430 | layers=32, 431 | dim=4096, 432 | seq_len=2048, 433 | n_head=32, 434 | ff_dim=None, 435 | kv_heads=None, 436 | eps=1e-5, 437 | theta=10_000, 438 | ), 439 | quantizer_config=quantizer_config, 440 | resynthesizer_config=resynthesizer_config, 441 | split=is_split, 442 | from_pretrained=checkpoints[1], 443 | ) -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Union, Tuple, Literal 4 | 5 | import torch as T 6 | import torch.nn as nn 7 | from torch.nn.utils.parametrizations import weight_norm 8 | 9 | from utils import load_ckpt 10 | from utils.interp import print_colored 11 | from utils import si_module, get_activation 12 | 13 | 14 | 15 | # Adapted from https://github.com/facebookresearch/AudioDec 16 | 17 | def Conv1d1x1(in_channels, out_channels, bias=True): 18 | return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias) 19 | 20 | 21 | class NonCausalConv1d(nn.Module): 22 | """1D noncausal convolution w/ 2-sides padding.""" 23 | 24 | def __init__( 25 | self, 26 | in_channels, 27 | out_channels, 28 | kernel_size, 29 | stride=1, 30 | padding=-1, 31 | dilation=1, 32 | groups=1, 33 | bias=True): 34 | super().__init__() 35 | self.in_channels = in_channels 36 | self.out_channels = out_channels 37 | self.kernel_size = kernel_size 38 | if padding < 0: 39 | padding = (kernel_size - 1) // 2 * dilation 40 | self.dilation = dilation 41 | self.conv = nn.Conv1d( 42 | in_channels=in_channels, 43 | out_channels=out_channels, 44 | kernel_size=kernel_size, 45 | stride=stride, 46 | padding=padding, 47 | dilation=dilation, 48 | groups=groups, 49 | bias=bias, 50 | ) 51 | 52 | def forward(self, x): 53 | """ 54 | Args: 55 | x (Tensor): Float tensor variable with the shape (B, C, T). 56 | Returns: 57 | Tensor: Float tensor variable with the shape (B, C, T). 58 | """ 59 | x = self.conv(x) 60 | return x 61 | 62 | 63 | class NonCausalConvTranspose1d(nn.Module): 64 | """1D noncausal transpose convolution.""" 65 | 66 | def __init__( 67 | self, 68 | in_channels, 69 | out_channels, 70 | kernel_size, 71 | stride, 72 | padding=-1, 73 | output_padding=-1, 74 | groups=1, 75 | bias=True, 76 | ): 77 | super().__init__() 78 | if padding < 0: 79 | padding = (stride+1) // 2 80 | if output_padding < 0: 81 | output_padding = 1 if stride % 2 else 0 82 | self.deconv = nn.ConvTranspose1d( 83 | in_channels=in_channels, 84 | out_channels=out_channels, 85 | kernel_size=kernel_size, 86 | stride=stride, 87 | padding=padding, 88 | output_padding=output_padding, 89 | groups=groups, 90 | bias=bias, 91 | ) 92 | 93 | def forward(self, x): 94 | """ 95 | Args: 96 | x (Tensor): Float tensor variable with the shape (B, C, T). 97 | Returns: 98 | Tensor: Float tensor variable with the shape (B, C', T'). 
99 | """ 100 | x = self.deconv(x) 101 | return x 102 | 103 | 104 | class CausalConv1d(NonCausalConv1d): 105 | def __init__( 106 | self, 107 | in_channels, 108 | out_channels, 109 | kernel_size, 110 | stride=1, 111 | dilation=1, 112 | groups=1, 113 | bias=True 114 | ): 115 | super(CausalConv1d, self).__init__( 116 | in_channels=in_channels, 117 | out_channels=out_channels, 118 | kernel_size=kernel_size, 119 | stride=stride, 120 | padding=0, 121 | dilation=dilation, 122 | groups=groups, 123 | bias=bias, 124 | ) 125 | self.stride = stride 126 | self.pad_length = (kernel_size - 1) * dilation 127 | def forward(self, x): 128 | pad = nn.ConstantPad1d((self.pad_length, 0), 0.0) 129 | x = pad(x) 130 | return self.conv(x) 131 | 132 | 133 | class CausalConvTranspose1d(NonCausalConvTranspose1d): 134 | def __init__( 135 | self, 136 | in_channels, 137 | out_channels, 138 | kernel_size, 139 | stride, 140 | bias=True, 141 | pad_buffer=None, 142 | ): 143 | super(CausalConvTranspose1d, self).__init__( 144 | in_channels=in_channels, 145 | out_channels=out_channels, 146 | kernel_size=kernel_size, 147 | stride=stride, 148 | padding=0, 149 | output_padding=0, 150 | bias=bias, 151 | ) 152 | self.stride = stride 153 | self.pad_length = (math.ceil(kernel_size/stride) - 1) 154 | if pad_buffer is None: 155 | pad_buffer = T.zeros(1, in_channels, self.pad_length) 156 | self.register_buffer("pad_buffer", pad_buffer) 157 | 158 | def forward(self, x): 159 | pad = nn.ReplicationPad1d((self.pad_length, 0)) 160 | x = pad(x) 161 | return self.deconv(x)[:, :, self.stride : -self.stride] 162 | 163 | def inference(self, x): 164 | x = T.cat((self.pad_buffer, x), -1) 165 | self.pad_buffer = x[:, :, -self.pad_length:] 166 | return self.deconv(x)[:, :, self.stride : -self.stride] 167 | 168 | def reset_buffer(self): 169 | self.pad_buffer.zero_() 170 | 171 | 172 | class NonCausalResUnit(nn.Module): 173 | def __init__( 174 | self, 175 | in_channels, 176 | out_channels, 177 | kernel_size=7, 178 | dilation=1, 179 | bias=False, 180 | ): 181 | super().__init__() 182 | self.activation = nn.ELU() 183 | self.conv1 = NonCausalConv1d( 184 | in_channels=in_channels, 185 | out_channels=out_channels, 186 | kernel_size=kernel_size, 187 | stride=1, 188 | dilation=dilation, 189 | bias=bias, 190 | ) 191 | self.conv2 = Conv1d1x1(out_channels, out_channels, bias) 192 | 193 | def forward(self, x): 194 | y = self.conv1(self.activation(x)) 195 | y = self.conv2(self.activation(y)) 196 | return x + y 197 | 198 | 199 | class CausalResUnit(NonCausalResUnit): 200 | def __init__( 201 | self, 202 | in_channels, 203 | out_channels, 204 | kernel_size=7, 205 | dilation=1, 206 | bias=False, 207 | ): 208 | super(CausalResUnit, self).__init__( 209 | in_channels=in_channels, 210 | out_channels=out_channels, 211 | kernel_size=kernel_size, 212 | dilation=dilation, 213 | bias=bias, 214 | ) 215 | self.conv1 = CausalConv1d( 216 | in_channels=in_channels, 217 | out_channels=out_channels, 218 | kernel_size=kernel_size, 219 | stride=1, 220 | dilation=dilation, 221 | bias=bias, 222 | ) 223 | 224 | def inference(self, x): 225 | y = self.conv1.inference(self.activation(x)) 226 | y = self.conv2(self.activation(y)) 227 | return x + y 228 | 229 | 230 | class ResNetBlock(nn.Module): 231 | def __init__(self, 232 | in_channels, 233 | out_channels, 234 | stride, 235 | kernel_size=7, 236 | dilations=(1, 3, 9), 237 | bias=True, 238 | mode='encoder', 239 | ): 240 | super().__init__() 241 | assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!" 
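        # Encoder blocks run the residual units first and then downsample by
        # `stride` (AvgPool1d when in/out channels match, otherwise a strided
        # CausalConv1d); decoder blocks mirror this, upsampling first (nearest
        # Upsample or CausalConvTranspose1d) and then running the residual units.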
242 | 243 | self.mode = mode 244 | self.stride = stride 245 | 246 | ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d 247 | 248 | res_channels = in_channels if mode == 'encoder' else out_channels 249 | 250 | res_units = [CausalResUnit( 251 | res_channels, 252 | res_channels, 253 | kernel_size=kernel_size, 254 | dilation=dilation, 255 | ) for dilation in dilations] 256 | 257 | if in_channels == out_channels: 258 | if mode == 'encoder': 259 | self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride) 260 | if mode == 'decoder': 261 | self.upsample = nn.Upsample(scale_factor=stride, mode='nearest') 262 | conv_unit = nn.Conv1d( 263 | in_channels=in_channels, 264 | out_channels=out_channels, 265 | kernel_size=1, 266 | bias=bias, 267 | ) if in_channels != out_channels else nn.Identity() 268 | else: 269 | conv_unit = ConvUnit( 270 | in_channels=in_channels, 271 | out_channels=out_channels, 272 | kernel_size=(2 * stride), 273 | stride=stride, 274 | bias=bias, 275 | ) 276 | 277 | if mode == 'encoder': 278 | if in_channels == out_channels: 279 | self.res_block = nn.Sequential(*res_units, self.pool, conv_unit) 280 | else: 281 | self.res_block = nn.Sequential(*res_units, conv_unit) 282 | elif mode == 'decoder': 283 | if in_channels == out_channels: 284 | self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units) 285 | else: 286 | self.res_block = nn.Sequential(conv_unit, *res_units) 287 | 288 | def forward(self, x): 289 | out = x 290 | for unit in self.res_block: 291 | out = unit(out) 292 | return out 293 | 294 | def inference(self, x): 295 | for unit in self.res_block: 296 | x = unit.inference(x) 297 | return x 298 | 299 | 300 | 301 | 302 | @si_module 303 | class ResNetStack(nn.Module): 304 | """ 305 | ResNet encoder or decoder stack. Channel ratios 306 | and strides take the default order of from 307 | data/io-layer, to the middle of the model. 308 | """ 309 | class Config: 310 | input_channels: int = 1 311 | output_channels: int = 1 312 | encode_channels: int = 32 313 | decode_channel_multiplier: int = 1 314 | latent_dim: int = None 315 | kernel_size: int = 7 316 | bias: bool = True 317 | channel_ratios: Tuple[int, ...] = (2, 4, 8, 16) 318 | strides: Tuple[int, ...] = (3, 4, 5, 5) 319 | mode: Literal['encoder', 'decoder'] = 'encoder' 320 | 321 | def __init__(self, c: Config): 322 | super().__init__() 323 | assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!" 
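        # channel_ratios/strides are given from the data/io side toward the middle
        # of the model (see the class docstring); for the decoder they are reversed
        # below. The overall stride of the stack is the product of `strides`, which
        # sets how many audio samples map to one latent frame in WaveCodec.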
324 | 325 | self.mode = c.mode 326 | 327 | assert len(c.channel_ratios) == len(c.strides) 328 | channel_ratios = (1,) + c.channel_ratios 329 | strides = c.strides 330 | self.middle_channels = c.encode_channels * channel_ratios[-1] 331 | if c.mode == 'decoder': 332 | channel_ratios = tuple(reversed(channel_ratios)) 333 | strides = tuple(reversed(strides)) 334 | 335 | self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1 336 | res_blocks = [ResNetBlock( 337 | c.encode_channels * channel_ratios[s_idx] * self.multiplier, 338 | c.encode_channels * channel_ratios[s_idx+1] * self.multiplier, 339 | stride, 340 | kernel_size=c.kernel_size, 341 | bias=c.bias, 342 | mode=c.mode, 343 | ) for s_idx, stride in enumerate(strides)] 344 | 345 | data_conv = CausalConv1d( 346 | in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier, 347 | out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels, 348 | kernel_size=c.kernel_size, 349 | stride=1, 350 | bias=False, 351 | ) 352 | 353 | if c.mode == 'encoder': 354 | self.res_stack = nn.Sequential(data_conv, *res_blocks) 355 | elif c.mode == 'decoder': 356 | self.res_stack = nn.Sequential(*res_blocks, data_conv) 357 | 358 | if c.latent_dim is not None: 359 | self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias) 360 | if self.multiplier != 1: 361 | self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias) 362 | 363 | def forward(self, x, return_feats=False): 364 | if self.c.latent_dim is not None and self.mode == 'decoder': 365 | x = self.latent_proj(x) 366 | if self.multiplier != 1: 367 | x = self.multiplier_proj(x) 368 | 369 | feats = [] 370 | for block in self.res_stack: 371 | x = block(x) 372 | if return_feats: 373 | feats.append(x) 374 | if self.c.latent_dim is not None and self.mode == 'encoder': 375 | x = self.latent_proj(x) 376 | if return_feats: 377 | feats.append(x) 378 | if return_feats: 379 | return feats 380 | return x 381 | 382 | def inference(self, x): 383 | for block in self.res_stack: 384 | x = block.inference(x) 385 | return x 386 | 387 | def reset_buffer(self): 388 | def _reset_buffer(m): 389 | if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d): 390 | m.reset_buffer() 391 | self.apply(_reset_buffer) 392 | 393 | def reset_parameters(self): 394 | def _reset_parameters(m): 395 | if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): 396 | m.weight.data.normal_(0.0, 0.01) 397 | 398 | self.apply(_reset_parameters) 399 | 400 | 401 | def apply_weight_norm(self): 402 | def _apply_weight_norm(m): 403 | if isinstance(m, nn.Conv1d) or isinstance( 404 | m, nn.ConvTranspose1d 405 | ): 406 | nn.utils.parametrizations.weight_norm(m) 407 | 408 | self.apply(_apply_weight_norm) 409 | 410 | 411 | def remove_weight_norm(self): 412 | def _remove_weight_norm(m): 413 | try: 414 | print(m) 415 | nn.utils.remove_weight_norm(m) 416 | except ValueError: # this module didn't have weight norm 417 | return 418 | 419 | self.apply(_remove_weight_norm) 420 | 421 | 422 | 423 | @si_module 424 | class GaussianZ(nn.Module): 425 | class Config: 426 | dim: int 427 | latent_dim: int 428 | bias: bool = False 429 | use_weight_norm: bool = False 430 | 431 | def __init__(self, c: Config): 432 | super().__init__() 433 | 434 | self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias) 435 | self.proj_out = nn.Linear(c.latent_dim, c.dim, 
bias=c.bias) 436 | 437 | if c.use_weight_norm: 438 | self.proj_in = weight_norm(self.proj_in) 439 | self.proj_out = weight_norm(self.proj_out) 440 | 441 | def reparam(self, mu, logvar): 442 | std = T.exp(logvar / 2) 443 | eps = T.randn_like(std) 444 | return mu + eps * std 445 | 446 | def kl_divergence(self, mu, logvar): 447 | return T.mean(-0.5 * T.sum( 448 | 1 + logvar - mu.pow(2) - logvar.exp(), 449 | dim=(1, 2)) 450 | ) 451 | 452 | def repr_from_latent(self, latent: Union[dict, T.Tensor]): 453 | if isinstance(latent, T.Tensor): 454 | z = latent 455 | else: 456 | z = self.reparam(latent['mu'], latent['logvar']) 457 | l = self.proj_out(z) 458 | return l 459 | 460 | def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]: 461 | mu, logvar = self.proj_in(x).chunk(2, dim=-1) 462 | kl_div = self.kl_divergence(mu, logvar) 463 | z = self.reparam(mu, logvar) 464 | xhat = self.proj_out(z) 465 | latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div} 466 | return xhat, latent 467 | 468 | 469 | 470 | @si_module 471 | class WaveCodec(nn.Module): 472 | class Config: 473 | resnet_config: ResNetStack.Config = None 474 | sample_rate: int = 16_000 475 | use_weight_norm: bool = False 476 | 477 | compressor_config: dataclass = None 478 | 479 | norm_stddev: float = 1.0 480 | 481 | def __init__(self, c: Config): 482 | super().__init__() 483 | self.norm_stddev = c.norm_stddev 484 | self.encoder = c.resnet_config(mode='encoder') 485 | self.sample_rate = c.sample_rate 486 | 487 | self.total_stride = 1 488 | for stride in c.resnet_config.strides: 489 | self.total_stride *= stride 490 | self.tokens_per_second = self.sample_rate / self.total_stride 491 | 492 | self.compressor = c.compressor_config(dim=self.encoder.middle_channels) 493 | 494 | self.decoder = c.resnet_config(mode='decoder') 495 | 496 | if c.use_weight_norm: 497 | self.encoder.apply_weight_norm() 498 | self.decoder.apply_weight_norm() 499 | self.encoder.reset_parameters() 500 | self.decoder.reset_parameters() 501 | 502 | def encode(self, data): 503 | return self.encoder(data/self.norm_stddev) 504 | 505 | def decode(self, latent): 506 | return self.decoder(latent.transpose(1, 2))*self.norm_stddev 507 | 508 | @T.no_grad() 509 | def latent_from_data(self, data, get_parameters=False): 510 | x = self.encode(data) 511 | l_in = x.transpose(1, 2) 512 | l, latent = self.compressor(l_in) 513 | return latent['z'] if not get_parameters else { 514 | 'mu': latent['mu'], 515 | 'logvar': latent['logvar'], 516 | 'z': latent['z'], 517 | } 518 | 519 | @T.no_grad() 520 | def data_from_latent(self, latent): 521 | l = self.compressor.repr_from_latent(latent) 522 | x = self.decode(l) 523 | return x 524 | 525 | def process(self, x): 526 | return self.latent_from_data(x) 527 | 528 | def unprocess(self, latent): 529 | return self.data_from_latent(latent) 530 | 531 | def forward(self, audio_input): 532 | x = self.encode(audio_input) 533 | 534 | l_in = x.transpose(1, 2) 535 | l, latent = self.compressor(l_in) 536 | 537 | xhat = self.decode(l) 538 | return xhat, latent 539 | 540 | 541 | 542 | def make_tokenizer(device='cuda'): 543 | generator_config = WaveCodec.Config( 544 | resnet_config=ResNetStack.Config( 545 | input_channels=1, 546 | output_channels=1, 547 | encode_channels=16, 548 | decode_channel_multiplier=4, 549 | kernel_size=7, 550 | bias=True, 551 | channel_ratios=(4, 8, 16, 16, 16, 16), 552 | strides=(2, 2, 4, 5, 5, 5), 553 | mode=None, 554 | ), 555 | use_weight_norm=True, 556 | 557 | compressor_config=GaussianZ.Config( 558 | dim=None, 559 | 
latent_dim=32, 560 | 561 | bias=True, 562 | use_weight_norm=True 563 | ), 564 | 565 | norm_stddev=0.05, 566 | ) 567 | checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97") 568 | 569 | tokenizer = generator_config() 570 | 571 | load_result = tokenizer.load_state_dict(checkpoint, strict=False) 572 | print_colored(f"Loaded tokenizer state dict: {load_result}", "grey") 573 | 574 | tokenizer = tokenizer.eval() 575 | # Only convert to bfloat16 if using CUDA 576 | if device == 'cuda': 577 | tokenizer = tokenizer.bfloat16() 578 | tokenizer = tokenizer.to(device) 579 | tokenizer.requires_grad_ = False 580 | return tokenizer 581 | 582 | --------------------------------------------------------------------------------
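# --- Illustrative end-to-end sketch (not a file from this repository) ---
# A rough picture of how tokenizer.py and model.py compose for audio completion,
# assuming the Config-call convention above and that utils.load_ckpt fetches the
# referenced checkpoints. Paths, device handling and dtypes are illustrative only;
# app.py and inference_client.py are the maintained entry points, and generation is
# written for CUDA/bfloat16 (note the bfloat16 KV caches and autocasts above).
#
#   import torch as T
#   import soundfile as sf
#   from tokenizer import make_tokenizer
#   from model import get_hertz_dev_config
#
#   tokenizer = make_tokenizer(device='cpu')     # WaveCodec: 16 kHz audio <-> 8 latent frames/s
#   wav, sr = sf.read('prompts/bob_mono.wav', dtype='float32')   # assumed 16 kHz mono
#   audio = T.tensor(wav)[None, None, :]         # (B=1, C=1, T)
#
#   latent = tokenizer.latent_from_data(audio)   # (1, T // 2000, 32)
#   recon = tokenizer.data_from_latent(latent)   # (1, 1, ~T) round-trip reconstruction
#
#   model = get_hertz_dev_config(is_split=False)()                      # loads pretrained weights
#   generated = model.completion(latent, temps=(0.8, (0.5, 0.1)), gen_len=8 * 10)  # ~10 s more
#   audio_out = tokenizer.data_from_latent(generated)
#   sf.write('completion.wav', audio_out[0, 0].cpu().numpy(), 16_000)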