├── .gitignore ├── LICENSE ├── README.md ├── inference.ipynb ├── inference_client.py ├── inference_client_webrtc.py ├── inference_server.py ├── ioblocks.py ├── model.py ├── prompts ├── bob_duo.wav ├── bob_mono.wav ├── countdown_mono.wav └── toaskanymore.wav ├── requirements.txt ├── requirements_webrtc.txt ├── tokenizer.py ├── transformer.py └── utils ├── __init__.py ├── blocks.py ├── dist.py └── interp.py /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | *.wav 3 | *.mp3 4 | *.m4a 5 | !prompts/*.wav 6 | !prompts/*.mp3 7 | !prompts/*.m4a 8 | __pycache__ 9 | *ckpt 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Standard Intelligence PBC 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hertz-dev
2 |
3 | Hertz-dev is an open-source, first-of-its-kind base model for full-duplex conversational audio.
4 |
5 | See our blog post for more details: https://si.inc/hertz-dev/
6 |
7 | ## Setup
8 |
9 | Inference is known to work on Python 3.10 and CUDA 12.1. Other versions have not been tested as thoroughly. If you want to use CUDA 12.1, you'll need to install torch with `pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121` before running `pip install -r requirements.txt`.
10 |
11 | On Ubuntu you may need to install libportaudio: `sudo apt-get install libportaudio2`
12 |
13 | All three scripts will automatically download the models to the `./ckpt` directory, and checkpoints are also accessible at https://ckpt.si.inc/hertz-dev/index.txt
14 |
15 | ## Usage
16 |
17 | We recommend starting by using `inference.ipynb` to generate one- or two-channel completions from a prompt.
18 |
19 | Then, you can use `inference_client.py` and `inference_server.py` to talk to the model live through your microphone.
20 | These are currently experimental, and have primarily been tested with Ubuntu on the server and MacOS on the client.
21 |
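For reference, a minimal end-to-end launch might look like the following sketch (the flags are the ones defined in the argparse blocks of `inference_server.py` and `inference_client.py` further down in this repo, and the temperature values mirror the server's built-in defaults; adjust them to taste):

```bash
# On the GPU machine: start the model server (serves a websocket on port 8000)
python inference_server.py --prompt_path ./prompts/bob_mono.wav

# On the client machine: stream your microphone to the server
python inference_client.py --server ws://localhost:8000 \
    --token_temp 0.8 --categorical_temp 0.4 --gaussian_temp 0.1
```

If the server runs on a remote machine, point `--server` at it (e.g. `ws://remote-host:8000`) or forward port 8000 over `ssh`, similar to the port-forwarding note below.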
22 | Alternatively, you can use `inference_client_webrtc.py`, which is built on [streamlit](https://streamlit.io/) + [streamlit-webrtc](https://github.com/whitphx/streamlit-webrtc) and runs in a browser:
23 | ```bash
24 | # Install additional requirements
25 | pip install -r requirements_webrtc.txt
26 | # Run the client
27 | streamlit run inference_client_webrtc.py
28 | ```
29 | Then, access the client at [http://localhost:8501](http://localhost:8501).
30 |
31 | **Note**: If you host the streamlit client anywhere other than `localhost` you will need to connect with https to avoid errors (see [here](https://github.com/whitphx/streamlit-webrtc?tab=readme-ov-file#serving-from-remote-host) for more info). An easy workaround is to `ssh` from the client into the server with port forwarding `ssh -L 127.0.0.1:8501:remote-host:8501 user@remote-host`, after which you can access the client at [http://localhost:8501](http://localhost:8501) as usual. If serving from a remote host with https, you may need to use a STUN server to establish the connection. You can do this by passing the `--use_ice_servers` flag: `streamlit run inference_client_webrtc.py -- --use_ice_servers`.
--------------------------------------------------------------------------------
/inference.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import torch as T\n", 20 | "import torch.nn as nn\n", 21 | "import torch.nn.functional as F\n", 22 | "import torchaudio\n", 23 | "from utils import load_ckpt, print_colored\n", 24 | "from tokenizer import make_tokenizer\n", 25 | "from model import get_hertz_dev_config\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "from IPython.display import Audio, display\n", 28 | "\n", 29 | "\n", 30 | "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n", 31 | "# you need to install PyTorch with the correct CUDA version.
Run:\n", 32 | "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n", 33 | "\n", 34 | "device = 'cuda' if T.cuda.is_available() else 'cpu'\n", 35 | "T.cuda.set_device(0)\n", 36 | "print_colored(f\"Using device: {device}\", \"grey\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# This code will automatically download them if it can't find them.\n", 46 | "audio_tokenizer = make_tokenizer(device)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 7, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# We have different checkpoints for the single-speaker and two-speaker models\n", 56 | "# Set to True to load and run inference with the two-speaker model\n", 57 | "TWO_SPEAKER = False\n", 58 | "USE_PURE_AUDIO_ABLATION = False # We trained a base model with no text initialization at all. Toggle this to enable it.\n", 59 | "assert not (USE_PURE_AUDIO_ABLATION and TWO_SPEAKER) # We only have a single-speaker version of this model.\n" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "model_config = get_hertz_dev_config(is_split=TWO_SPEAKER, use_pure_audio_ablation=USE_PURE_AUDIO_ABLATION)\n", 69 | "\n", 70 | "generator = model_config()\n", 71 | "generator = generator.eval().to(T.bfloat16).to(device)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "def load_and_preprocess_audio(audio_path):\n", 81 | " print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n", 82 | " # Load audio file\n", 83 | " audio_tensor, sr = torchaudio.load(audio_path)\n", 84 | " print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n", 85 | " \n", 86 | " if TWO_SPEAKER:\n", 87 | " if audio_tensor.shape[0] == 1:\n", 88 | " print_colored(\"Converting mono to stereo...\", \"grey\")\n", 89 | " audio_tensor = audio_tensor.repeat(2, 1)\n", 90 | " print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n", 91 | " else:\n", 92 | " if audio_tensor.shape[0] == 2:\n", 93 | " print_colored(\"Converting stereo to mono...\", \"grey\")\n", 94 | " audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)\n", 95 | " print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n", 96 | " \n", 97 | " # Resample to 16kHz if needed\n", 98 | " if sr != 16000:\n", 99 | " print_colored(f\"Resampling from {sr}Hz to 16000Hz...\", \"grey\")\n", 100 | " resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)\n", 101 | " audio_tensor = resampler(audio_tensor)\n", 102 | " \n", 103 | " # Clip to 5 minutes if needed\n", 104 | " max_samples = 16000 * 60 * 5\n", 105 | " if audio_tensor.shape[1] > max_samples:\n", 106 | " print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n", 107 | " audio_tensor = audio_tensor[:, :max_samples]\n", 108 | "\n", 109 | " \n", 110 | " print_colored(\"Audio preprocessing complete!\", \"green\")\n", 111 | " return audio_tensor.unsqueeze(0)\n", 112 | "\n", 113 | "def display_audio(audio_tensor):\n", 114 | " audio_tensor = audio_tensor.cpu().squeeze()\n", 115 | " if audio_tensor.ndim == 1:\n", 116 | " audio_tensor = audio_tensor.unsqueeze(0)\n", 117 | " audio_tensor = audio_tensor.float()\n", 118 | "\n", 119 | " # Make a waveform plot\n", 120 | " plt.figure(figsize=(4, 1))\n", 
121 | " plt.plot(audio_tensor.numpy()[0], linewidth=0.5)\n", 122 | " plt.axis('off')\n", 123 | " plt.show()\n", 124 | "\n", 125 | " # Make an audio player\n", 126 | " display(Audio(audio_tensor.numpy(), rate=16000))\n", 127 | " print_colored(f\"Audio ready for playback ↑\", \"green\", bold=True)\n", 128 | " \n", 129 | " \n", 130 | "\n", 131 | "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n", 132 | "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n", 133 | "display_audio(prompt_audio)\n", 134 | "prompt_len_seconds = 3\n", 135 | "prompt_len = prompt_len_seconds * 8" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "print_colored(\"Encoding prompt...\", \"blue\")\n", 145 | "with T.autocast(device_type='cuda', dtype=T.bfloat16):\n", 146 | " if TWO_SPEAKER:\n", 147 | " encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n", 148 | " encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n", 149 | " encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n", 150 | " else:\n", 151 | " encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n", 152 | "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n", 153 | "print_colored(\"Prompt encoded successfully!\", \"green\")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "def get_completion(encoded_prompt_audio, prompt_len, gen_len=None):\n", 163 | " prompt_len_seconds = prompt_len / 8\n", 164 | " print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n", 165 | " print_colored(\"Completing audio...\", \"blue\")\n", 166 | " encoded_prompt_audio = encoded_prompt_audio[:, :prompt_len]\n", 167 | " with T.autocast(device_type='cuda', dtype=T.bfloat16):\n", 168 | " completed_audio_batch = generator.completion(\n", 169 | " encoded_prompt_audio, \n", 170 | " temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n", 171 | " use_cache=True,\n", 172 | " gen_len=gen_len)\n", 173 | "\n", 174 | " completed_audio = completed_audio_batch\n", 175 | " print_colored(f\"Decoding completion...\", \"blue\")\n", 176 | " if TWO_SPEAKER:\n", 177 | " decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n", 178 | " decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n", 179 | " decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n", 180 | " else:\n", 181 | " decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n", 182 | " print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n", 183 | "\n", 184 | " print_colored(\"Preparing audio for playback...\", \"blue\")\n", 185 | "\n", 186 | " audio_tensor = decoded_completion.cpu().squeeze()\n", 187 | " if audio_tensor.ndim == 1:\n", 188 | " audio_tensor = audio_tensor.unsqueeze(0)\n", 189 | " audio_tensor = audio_tensor.float()\n", 190 | "\n", 191 | " if audio_tensor.abs().max() > 1:\n", 192 | " audio_tensor = audio_tensor / audio_tensor.abs().max()\n", 193 | "\n", 194 | " return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n", 195 | "\n", 196 | "num_completions = 10\n", 197 | 
"print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n", 198 | "for _ in range(num_completions):\n", 199 | " completion = get_completion(encoded_prompt_audio, prompt_len, gen_len=20*8) # 20 seconds of generation\n", 200 | " display_audio(completion)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": ".venv", 221 | "language": "python", 222 | "name": "python3" 223 | }, 224 | "language_info": { 225 | "codemirror_mode": { 226 | "name": "ipython", 227 | "version": 3 228 | }, 229 | "file_extension": ".py", 230 | "mimetype": "text/x-python", 231 | "name": "python", 232 | "nbconvert_exporter": "python", 233 | "pygments_lexer": "ipython3", 234 | "version": "3.10.12" 235 | } 236 | }, 237 | "nbformat": 4, 238 | "nbformat_minor": 2 239 | } 240 | -------------------------------------------------------------------------------- /inference_client.py: -------------------------------------------------------------------------------- 1 | # server.py remains the same as before 2 | 3 | # Updated client.py 4 | import asyncio 5 | import websockets 6 | import sounddevice as sd 7 | import numpy as np 8 | import base64 9 | import queue 10 | import argparse 11 | import requests 12 | import time 13 | 14 | class AudioClient: 15 | def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None): 16 | # Convert ws:// to http:// for the base URL 17 | self.base_url = server_url.replace("ws://", "http://") 18 | self.server_url = f"{server_url}/audio" 19 | 20 | # Set temperatures if provided 21 | if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]): 22 | self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp) 23 | 24 | # Initialize queues 25 | self.audio_queue = queue.Queue() 26 | self.output_queue = queue.Queue() 27 | 28 | def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False): 29 | """Send temperature settings to server""" 30 | params = {} 31 | if token_temp is not None: 32 | params['token_temp'] = token_temp 33 | if categorical_temp is not None: 34 | params['categorical_temp'] = categorical_temp 35 | if gaussian_temp is not None: 36 | params['gaussian_temp'] = gaussian_temp 37 | 38 | response = requests.post(f"{self.base_url}/set_temperature", params=params) 39 | print(response.json()['message']) 40 | 41 | def audio_callback(self, indata, frames, time, status): 42 | """This is called for each audio block""" 43 | if status: 44 | print(status) 45 | # if np.isclose(indata, 0).all(): 46 | # raise Exception('Audio input is not working - received all zeros') 47 | # Convert float32 to int16 for efficient transmission 48 | indata_int16 = (indata.copy() * 32767).astype(np.int16) 49 | # indata_int16 = np.zeros_like(indata_int16) 50 | self.audio_queue.put(indata_int16) 51 | 52 | def output_stream_callback(self, outdata, frames, time, status): 53 | """Callback for output stream to get audio data""" 54 | if status: 55 | print(status) 56 | 57 | try: 58 | data = self.output_queue.get_nowait() 59 | data = data.astype(np.float32) / 32767.0 60 | if len(data) < len(outdata): 61 | outdata[:len(data)] = data 62 | outdata[len(data):] = 0 63 | else: 64 | outdata[:] 
= data[:len(outdata)] 65 | except queue.Empty: 66 | outdata.fill(0) 67 | 68 | async def process_audio(self): 69 | async with websockets.connect(self.server_url) as ws: 70 | while self.running: 71 | if not self.audio_queue.empty(): 72 | # Get recorded audio 73 | audio_data = self.audio_queue.get() 74 | print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}') 75 | 76 | # Convert to base64 77 | audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8') 78 | 79 | # Send to server 80 | time_sent = time.time() 81 | await ws.send(f"data:audio/raw;base64,{audio_b64}") 82 | 83 | # Receive processed audio 84 | response = await ws.recv() 85 | response = response.split(",")[1] 86 | time_received = time.time() 87 | print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms") 88 | processed_audio = np.frombuffer( 89 | base64.b64decode(response), 90 | dtype=np.int16 91 | ).reshape(-1, CHANNELS) 92 | print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}') 93 | 94 | self.output_queue.put(processed_audio) 95 | 96 | def start(self): 97 | self.running = True 98 | # Print audio device information 99 | devices = sd.query_devices() 100 | default_input = sd.query_devices(kind='input') 101 | default_output = sd.query_devices(kind='output') 102 | 103 | print("\nAudio Device Configuration:") 104 | print("-" * 50) 105 | print(f"Default Input Device:\n{default_input}\n") 106 | print(f"Default Output Device:\n{default_output}\n") 107 | print("\nAll Available Devices:") 108 | print("-" * 50) 109 | for i, device in enumerate(devices): 110 | print(f"Device {i}:") 111 | print(f"Name: {device['name']}") 112 | print(f"Channels (in/out): {device['max_input_channels']}/{device['max_output_channels']}") 113 | print(f"Sample Rates: {device['default_samplerate']}") 114 | print() 115 | input_device = input("Enter the index of the input device or press enter for default: ") 116 | output_device = input("Enter the index of the output device or press enter for default: ") 117 | if input_device == "": 118 | input_device = default_input['index'] 119 | if output_device == "": 120 | output_device = default_output['index'] 121 | with sd.InputStream(callback=self.audio_callback, 122 | channels=CHANNELS, 123 | samplerate=SAMPLE_RATE, 124 | device=int(input_device), 125 | blocksize=2000), \ 126 | sd.OutputStream(callback=self.output_stream_callback, 127 | channels=CHANNELS, 128 | samplerate=SAMPLE_RATE, 129 | blocksize=2000, 130 | device=int(output_device)): 131 | 132 | asyncio.run(self.process_audio()) 133 | 134 | def stop(self): 135 | self.running = False 136 | 137 | if __name__ == "__main__": 138 | parser = argparse.ArgumentParser(description='Audio Client with Temperature Control') 139 | parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter') 140 | parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter') 141 | parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter') 142 | parser.add_argument('--server', '-s', default="ws://localhost:8000", 143 | help='Server URL (default: ws://localhost:8000)') 144 | 145 | args = parser.parse_args() 146 | 147 | # Audio settings 148 | SAMPLE_RATE = 16000 149 | CHANNELS = 1 150 | 151 | client = AudioClient( 152 | server_url=args.server, 153 | token_temp=args.token_temp, 154 | 
categorical_temp=args.categorical_temp, 155 | gaussian_temp=args.gaussian_temp 156 | ) 157 | 158 | try: 159 | client.start() 160 | except KeyboardInterrupt: 161 | client.stop() -------------------------------------------------------------------------------- /inference_client_webrtc.py: -------------------------------------------------------------------------------- 1 | # server.py remains the same as before 2 | 3 | # Updated client.py 4 | import asyncio 5 | import websockets 6 | import numpy as np 7 | import base64 8 | import argparse 9 | import requests 10 | import time 11 | import torch 12 | import torchaudio 13 | 14 | import av 15 | import streamlit as st 16 | from typing import List 17 | from streamlit_webrtc import WebRtcMode, webrtc_streamer 18 | 19 | class AudioClient: 20 | def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None): 21 | # Convert ws:// to http:// for the base URL 22 | self.base_url = server_url.replace("ws://", "http://") 23 | self.server_url = f"{server_url}/audio" 24 | self.sound_check = False 25 | 26 | # Set temperatures if provided 27 | if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]): 28 | response_message = self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp) 29 | print(response_message) 30 | 31 | self.downsampler = torchaudio.transforms.Resample(STREAMING_SAMPLE_RATE, SAMPLE_RATE) 32 | self.upsampler = torchaudio.transforms.Resample(SAMPLE_RATE, STREAMING_SAMPLE_RATE) 33 | self.ws = None 34 | self.in_buffer = None 35 | self.out_buffer = None 36 | 37 | def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False): 38 | """Send temperature settings to server""" 39 | params = {} 40 | if token_temp is not None: 41 | params['token_temp'] = token_temp 42 | if categorical_temp is not None: 43 | params['categorical_temp'] = categorical_temp 44 | if gaussian_temp is not None: 45 | params['gaussian_temp'] = gaussian_temp 46 | 47 | response = requests.post(f"{self.base_url}/set_temperature", params=params) 48 | response_message = response.json()['message'] 49 | return response_message 50 | 51 | def _resample(self, audio_data: np.ndarray, resampler: torchaudio.transforms.Resample) -> np.ndarray: 52 | audio_data = audio_data.astype(np.float32) / 32767.0 53 | audio_data = resampler(torch.tensor(audio_data)).numpy() 54 | audio_data = (audio_data * 32767.0).astype(np.int16) 55 | return audio_data 56 | 57 | def upsample(self, audio_data: np.ndarray) -> np.ndarray: 58 | return self._resample(audio_data, self.upsampler) 59 | 60 | def downsample(self, audio_data: np.ndarray) -> np.ndarray: 61 | return self._resample(audio_data, self.downsampler) 62 | 63 | def from_s16_format(self, audio_data: np.ndarray, channels: int) -> np.ndarray: 64 | if channels == 2: 65 | audio_data = audio_data.reshape(-1, 2).T 66 | else: 67 | audio_data = audio_data.reshape(-1) 68 | return audio_data 69 | 70 | def to_s16_format(self, audio_data: np.ndarray): 71 | if len(audio_data.shape) == 2 and audio_data.shape[0] == 2: 72 | audio_data = audio_data.T.reshape(1, -1) 73 | elif len(audio_data.shape) == 1: 74 | audio_data = audio_data.reshape(1, -1) 75 | return audio_data 76 | 77 | def to_channels(self, audio_data: np.ndarray, channels: int) -> np.ndarray: 78 | current_channels = audio_data.shape[0] if len(audio_data.shape) == 2 else 1 79 | if current_channels == channels: 80 | return audio_data 81 | elif current_channels == 1 and channels == 2: 82 
| audio_data = np.tile(audio_data, 2).reshape(2, -1) 83 | elif current_channels == 2 and channels == 1: 84 | audio_data = audio_data.astype(np.float32) / 32767.0 85 | audio_data = audio_data.mean(axis=0) 86 | audio_data = (audio_data * 32767.0).astype(np.int16) 87 | return audio_data 88 | 89 | async def process_audio(self, audio_data: np.ndarray) -> np.ndarray: 90 | if self.ws is None: 91 | self.ws = await websockets.connect(self.server_url) 92 | 93 | audio_data = audio_data.reshape(-1, CHANNELS) 94 | print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}') 95 | 96 | # Convert to base64 97 | audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8') 98 | 99 | # Send to server 100 | time_sent = time.time() 101 | await self.ws.send(f"data:audio/raw;base64,{audio_b64}") 102 | 103 | # Receive processed audio 104 | response = await self.ws.recv() 105 | response = response.split(",")[1] 106 | time_received = time.time() 107 | print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms") 108 | processed_audio = np.frombuffer( 109 | base64.b64decode(response), 110 | dtype=np.int16 111 | ).reshape(-1, CHANNELS) 112 | print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}') 113 | 114 | if CHANNELS == 1: 115 | processed_audio = processed_audio.reshape(-1) 116 | return processed_audio 117 | 118 | async def queued_audio_frames_callback(self, frames: List[av.AudioFrame]) -> List[av.AudioFrame]: 119 | out_frames = [] 120 | for frame in frames: 121 | # Read in audio 122 | audio_data = frame.to_ndarray() 123 | 124 | # Convert input audio from s16 format, convert to `CHANNELS` number of channels, and downsample 125 | audio_data = self.from_s16_format(audio_data, len(frame.layout.channels)) 126 | audio_data = self.to_channels(audio_data, CHANNELS) 127 | audio_data = self.downsample(audio_data) 128 | 129 | # Add audio to input buffer 130 | if self.in_buffer is None: 131 | self.in_buffer = audio_data 132 | else: 133 | self.in_buffer = np.concatenate((self.in_buffer, audio_data), axis=-1) 134 | 135 | # Take BLOCK_SIZE samples from input buffer if available for processing 136 | if self.in_buffer.shape[0] >= BLOCK_SIZE: 137 | audio_data = self.in_buffer[:BLOCK_SIZE] 138 | self.in_buffer = self.in_buffer[BLOCK_SIZE:] 139 | else: 140 | audio_data = None 141 | 142 | # Process audio if available and add resulting audio to output buffer 143 | if audio_data is not None: 144 | if not self.sound_check: 145 | audio_data = await self.process_audio(audio_data) 146 | if self.out_buffer is None: 147 | self.out_buffer = audio_data 148 | else: 149 | self.out_buffer = np.concatenate((self.out_buffer, audio_data), axis=-1) 150 | 151 | # Take `out_samples` samples from output buffer if available for output 152 | out_samples = int(frame.samples * SAMPLE_RATE / STREAMING_SAMPLE_RATE) 153 | if self.out_buffer is not None and self.out_buffer.shape[0] >= out_samples: 154 | audio_data = self.out_buffer[:out_samples] 155 | self.out_buffer = self.out_buffer[out_samples:] 156 | else: 157 | audio_data = None 158 | 159 | # Output silence if no audio data available 160 | if audio_data is None: 161 | # output silence 162 | audio_data = np.zeros(out_samples, dtype=np.int16) 163 | 164 | # Upsample output audio, convert to original number of channels, and convert to s16 format 165 | audio_data = self.upsample(audio_data) 166 | audio_data = self.to_channels(audio_data, 
len(frame.layout.channels)) 167 | audio_data = self.to_s16_format(audio_data) 168 | 169 | # return audio data as AudioFrame 170 | new_frame = av.AudioFrame.from_ndarray(audio_data, format=frame.format.name, layout=frame.layout.name) 171 | new_frame.sample_rate = frame.sample_rate 172 | out_frames.append(new_frame) 173 | 174 | return out_frames 175 | 176 | def stop(self): 177 | if self.ws is not None: 178 | # TODO: this hangs. Figure out why. 179 | #asyncio.get_event_loop().run_until_complete(self.ws.close()) 180 | print("Websocket closed") 181 | self.ws = None 182 | self.in_buffer = None 183 | self.out_buffer = None 184 | 185 | if __name__ == "__main__": 186 | parser = argparse.ArgumentParser(description='Audio Client with Temperature Control') 187 | parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter') 188 | parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter') 189 | parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter') 190 | parser.add_argument('--server', '-s', default="ws://localhost:8000", 191 | help='Server URL (default: ws://localhost:8000)') 192 | parser.add_argument("--use_ice_servers", action="store_true", help="Use public STUN servers") 193 | 194 | args = parser.parse_args() 195 | 196 | # Audio settings 197 | STREAMING_SAMPLE_RATE = 48000 198 | SAMPLE_RATE = 16000 199 | BLOCK_SIZE = 2000 200 | CHANNELS = 1 201 | 202 | st.title("hertz-dev webrtc demo!") 203 | st.markdown(""" 204 | Welcome to the audio processing interface! Here you can talk live with hertz. 205 | - Process audio in real-time through your microphone 206 | - Adjust various temperature parameters for inference 207 | - Test your microphone with sound check mode 208 | - Enable/disable echo cancellation and noise suppression 209 | 210 | To begin, click the START button below and allow microphone access. 
211 | """) 212 | 213 | audio_client = st.session_state.get("audio_client") 214 | if audio_client is None: 215 | audio_client = AudioClient( 216 | server_url=args.server, 217 | token_temp=args.token_temp, 218 | categorical_temp=args.categorical_temp, 219 | gaussian_temp=args.gaussian_temp 220 | ) 221 | st.session_state.audio_client = audio_client 222 | 223 | with st.sidebar: 224 | st.markdown("## Inference Settings") 225 | token_temp_default = args.token_temp if args.token_temp is not None else 0.8 226 | token_temp = st.slider("Token Temperature", 0.05, 2.0, token_temp_default, step=0.05) 227 | categorical_temp_default = args.categorical_temp if args.categorical_temp is not None else 0.4 228 | categorical_temp = st.slider("Categorical Temperature", 0.01, 1.0, categorical_temp_default, step=0.01) 229 | gaussian_temp_default = args.gaussian_temp if args.gaussian_temp is not None else 0.1 230 | gaussian_temp = st.slider("Gaussian Temperature", 0.01, 1.0, gaussian_temp_default, step=0.01) 231 | if st.button("Set Temperatures"): 232 | response_message = audio_client.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp) 233 | st.write(response_message) 234 | 235 | st.markdown("## Microphone Settings") 236 | audio_client.sound_check = st.toggle("Sound Check (Echo)", value=False) 237 | echo_cancellation = st.toggle("Echo Cancellation*‡", value=False) 238 | noise_suppression = st.toggle("Noise Suppression*", value=False) 239 | st.markdown(r"\* *Restart stream to take effect*") 240 | st.markdown("‡ *May cause audio to cut out*") 241 | 242 | # Use a free STUN server from Google if --use_ice_servers is given 243 | # (found in get_ice_servers() at https://github.com/whitphx/streamlit-webrtc/blob/main/sample_utils/turn.py) 244 | rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]} if args.use_ice_servers else None 245 | audio_config = {"echoCancellation": echo_cancellation, "noiseSuppression": noise_suppression} 246 | webrtc_streamer( 247 | key="streamer", 248 | mode=WebRtcMode.SENDRECV, 249 | rtc_configuration=rtc_configuration, 250 | media_stream_constraints={"audio": audio_config, "video": False}, 251 | queued_audio_frames_callback=audio_client.queued_audio_frames_callback, 252 | on_audio_ended=audio_client.stop, 253 | async_processing=True, 254 | ) 255 | 256 | -------------------------------------------------------------------------------- /inference_server.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from fastapi import FastAPI, WebSocket 4 | from fastapi.middleware.cors import CORSMiddleware 5 | import base64 6 | import uvicorn 7 | import traceback 8 | import numpy as np 9 | import argparse 10 | 11 | import torch as T 12 | import torch.nn.functional as F 13 | import torchaudio 14 | 15 | import os 16 | from typing import Optional 17 | 18 | from utils import print_colored 19 | from model import get_hertz_dev_config 20 | 21 | 22 | argparse = argparse.ArgumentParser() 23 | 24 | argparse.add_argument('--prompt_path', type=str, default='./prompts/bob_mono.wav', help=""" 25 | We highly recommend making your own prompt based on a conversation between you and another person. 26 | bob_mono.wav seems to work better for two-channel than bob_stereo.wav. 
27 | """) 28 | args = argparse.parse_args() 29 | 30 | 31 | device = 'cuda' if T.cuda.is_available() else T.device('cpu') 32 | print_colored(f"Using device: {device}", "grey") 33 | 34 | model_config = get_hertz_dev_config(is_split=True) 35 | 36 | model = model_config() 37 | model = model.eval().bfloat16().to(device) 38 | 39 | app = FastAPI() 40 | 41 | app.add_middleware( 42 | CORSMiddleware, 43 | allow_origins=["*"], 44 | allow_credentials=True, 45 | allow_methods=["*"], 46 | allow_headers=["*"], 47 | ) 48 | 49 | 50 | # Hyperparams or something. 51 | SAMPLE_RATE = 16000 # Don't change this 52 | TEMPS = (0.8, (0.4, 0.1)) # You can change this, but there's also an endpoint for it. 53 | REPLAY_SECONDS = 3 # What the user hears as context. 54 | 55 | class AudioProcessor: 56 | def __init__(self, model, prompt_path): 57 | self.model = model 58 | self.prompt_path = prompt_path 59 | self.initialize_state(prompt_path) 60 | 61 | def initialize_state(self, prompt_path): 62 | loaded_audio, sr = torchaudio.load(prompt_path) 63 | self.replay_seconds = REPLAY_SECONDS 64 | 65 | if sr != SAMPLE_RATE: 66 | resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE) 67 | loaded_audio = resampler(loaded_audio) 68 | 69 | if loaded_audio.shape[0] == 1: 70 | loaded_audio = loaded_audio.repeat(2, 1) 71 | 72 | audio_length = loaded_audio.shape[-1] 73 | num_chunks = audio_length // 2000 74 | loaded_audio = loaded_audio[..., :num_chunks * 2000] 75 | 76 | self.loaded_audio = loaded_audio.to(device) 77 | 78 | with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode(): 79 | self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024) 80 | self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS) 81 | self.prompt_buffer = None 82 | self.prompt_position = 0 83 | self.chunks_until_live = int(self.replay_seconds * 8) 84 | self.initialize_prompt_buffer() 85 | print_colored("AudioProcessor state initialized", "green") 86 | 87 | def initialize_prompt_buffer(self): 88 | self.recorded_audio = self.loaded_audio 89 | prompt_audio = self.loaded_audio.reshape(1, 2, -1) 90 | prompt_audio = prompt_audio[:, :, -(16000*self.replay_seconds):].cpu().numpy() 91 | prompt_audio_mono = prompt_audio.mean(axis=1) 92 | self.prompt_buffer = np.array_split(prompt_audio_mono[0], int(self.replay_seconds * 8)) 93 | print_colored(f"Initialized prompt buffer with {len(self.prompt_buffer)} chunks", "grey") 94 | 95 | async def process_audio(self, audio_data): 96 | if self.chunks_until_live > 0: 97 | print_colored(f"Serving from prompt buffer, {self.chunks_until_live} chunks left", "grey") 98 | chunk = self.prompt_buffer[int(self.replay_seconds * 8) - self.chunks_until_live] 99 | self.chunks_until_live -= 1 100 | 101 | if self.chunks_until_live == 0: 102 | print_colored("Switching to live processing mode", "green") 103 | 104 | time.sleep(0.05) 105 | return chunk 106 | 107 | audio_tensor = T.from_numpy(audio_data).to(device) 108 | audio_tensor = audio_tensor.reshape(1, 1, -1) 109 | audio_tensor = T.cat([audio_tensor, self.next_model_audio], dim=1) 110 | 111 | with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode(): 112 | curr_model_audio = self.model.next_audio_from_audio( 113 | audio_tensor, 114 | temps=TEMPS 115 | ) 116 | print(f"Recorded audio shape {self.recorded_audio.shape}, audio tensor shape {audio_tensor.shape}") 117 | self.recorded_audio = T.cat([self.recorded_audio.cpu(), audio_tensor.squeeze(0).cpu()], dim=-1) 118 | 119 | self.next_model_audio = 
curr_model_audio 120 | 121 | return curr_model_audio.float().cpu().numpy() 122 | 123 | def cleanup(self): 124 | print_colored("Cleaning up audio processor...", "blue") 125 | os.makedirs('audio_recordings', exist_ok=True) 126 | torchaudio.save(f'audio_recordings/{time.strftime("%d-%H-%M")}.wav', self.recorded_audio.cpu(), SAMPLE_RATE) 127 | self.model.deinit_cache() 128 | self.initialize_state(self.prompt_path) 129 | print_colored("Audio processor cleanup complete", "green") 130 | 131 | @app.post("/set_temperature") 132 | async def set_temperature(token_temp: Optional[float] = None, categorical_temp: Optional[float] = None, gaussian_temp: Optional[float] = None): 133 | try: 134 | global TEMPS 135 | TEMPS = (token_temp, (categorical_temp, gaussian_temp)) 136 | 137 | print_colored(f"Temperature updated to: {TEMPS}", "green") 138 | return {"message": f"Temperature updated to: {TEMPS}", "status": "success"} 139 | except Exception as e: 140 | print_colored(f"Error setting temperature: {str(e)}", "red") 141 | return {"message": f"Error setting temperature: {str(e)}", "status": "error"} 142 | 143 | @app.websocket("/audio") 144 | async def websocket_endpoint(websocket: WebSocket): 145 | await websocket.accept() 146 | try: 147 | while True: 148 | data = await websocket.receive_text() 149 | audio_data = np.frombuffer( 150 | base64.b64decode(data.split(",")[1]), 151 | dtype=np.int16 152 | ) 153 | audio_data = audio_data.astype(np.float32) / 32767.0 154 | processed_audio = await audio_processor.process_audio(audio_data) 155 | processed_audio = (processed_audio * 32767).astype(np.int16) 156 | 157 | processed_data = base64.b64encode(processed_audio.tobytes()).decode('utf-8') 158 | await websocket.send_text(f"data:audio/raw;base64,{processed_data}") 159 | 160 | except Exception as e: 161 | print_colored(f"WebSocket error: {e}", "red") 162 | print_colored(f"Full traceback:\n{traceback.format_exc()}", "red") 163 | finally: 164 | audio_processor.cleanup() 165 | await websocket.close() 166 | 167 | 168 | audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path) 169 | 170 | if __name__ == "__main__": 171 | uvicorn.run(app, host="0.0.0.0", port=8000) 172 | print("Server started") 173 | -------------------------------------------------------------------------------- /ioblocks.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from functools import partial 3 | from contextlib import nullcontext 4 | from typing import List, Tuple 5 | from math import ceil 6 | 7 | import torch as T 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.distributed as dist 11 | from torch import Tensor, int32 12 | from torch.amp import autocast 13 | 14 | from einops import rearrange, pack, unpack 15 | 16 | 17 | from utils import si_module, exists, default, maybe 18 | 19 | 20 | @si_module 21 | class GaussianMixtureIOLayer(nn.Module): 22 | class Config: 23 | latent_dim: int 24 | dim: int 25 | num_components: int 26 | 27 | def __init__(self, c: Config): 28 | super().__init__() 29 | self.latent_dim = c.latent_dim 30 | self.num_components = c.num_components 31 | self.input_projection = nn.Linear(c.latent_dim, c.dim) 32 | 33 | self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim) 34 | self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim) 35 | self.fc_weight = nn.Linear(c.dim, c.num_components) 36 | 37 | def _square_plus(self, x): 38 | return (x + T.sqrt(T.square(x) + 4)) / 2 39 | 40 | def input(self, 
sampled_latents: T.Tensor) -> T.Tensor: 41 | """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)""" 42 | hidden = self.input_projection(sampled_latents) 43 | return hidden 44 | 45 | def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]: 46 | """float tensor (B, L, D) -> Tuple of locs, scales, and weights""" 47 | batch_size, seq_len, _ = h.shape 48 | 49 | locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim) 50 | scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim) 51 | weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components) 52 | 53 | return (locs, scales, weights) 54 | 55 | def loss(self, data, dataHat): 56 | locs, scales, weights = dataHat 57 | log_probs = -0.5 * T.sum( 58 | (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) + 59 | 2 * T.log(scales) + 60 | T.log(T.tensor(2 * T.pi)), 61 | dim=-1 62 | ) 63 | log_weights = F.log_softmax(weights, dim=-1) 64 | return -T.logsumexp(log_weights + log_probs, dim=-1) 65 | 66 | 67 | def temp_sample(self, orig_pdist, temp): 68 | locs, scales, weights = orig_pdist 69 | if temp is None: 70 | component_samples = locs + scales * T.randn_like(scales) 71 | mixture_samples = F.gumbel_softmax(weights, hard=True) 72 | sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2) 73 | elif isinstance(temp, tuple): 74 | assert len(temp) == 2 75 | categorical_temp, gaussian_temp = temp 76 | component_samples = locs + scales * gaussian_temp * T.randn_like(scales) 77 | mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True) 78 | sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2) 79 | else: 80 | component_samples = locs + scales * temp * T.randn_like(scales) 81 | mixture_samples = F.gumbel_softmax(weights / temp, hard=True) 82 | sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2) 83 | return sampled 84 | 85 | 86 | class GPTOutput(nn.Module): 87 | def __init__(self, dim, vocab_size): 88 | super().__init__() 89 | self.output = nn.Linear(dim, vocab_size, bias=False) 90 | 91 | def forward(self, x): 92 | return self.output(x) 93 | 94 | 95 | # helper functions 96 | 97 | def pack_one(t, pattern): 98 | return pack([t], pattern) 99 | 100 | def unpack_one(t, ps, pattern): 101 | return unpack(t, ps, pattern)[0] 102 | 103 | def first(l): 104 | return l[0] 105 | 106 | def round_up_multiple(num, mult): 107 | return ceil(num / mult) * mult 108 | 109 | def get_code_utilization(codes, codebook_size, get_global=False): 110 | if get_global and dist.is_initialized(): 111 | world_size = dist.get_world_size() 112 | else: 113 | world_size = 1 114 | 115 | if world_size > 1: 116 | gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)] 117 | dist.all_gather(gathered_tokens, codes) 118 | gathered_tokens = T.cat(gathered_tokens, dim=0) 119 | else: 120 | gathered_tokens = codes 121 | unique_tokens = len(T.unique(gathered_tokens)) 122 | code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size) 123 | return code_utilization 124 | 125 | # tensor helpers 126 | 127 | def round_ste(z: Tensor) -> Tensor: 128 | """Round with straight through gradients.""" 129 | zhat = z.round() 130 | return z + (zhat - z).detach() 131 | 132 | # main class 133 | # lucidrains fsq 134 | @si_module 135 | class FSQ(nn.Module): 136 | @property 137 | def needs_float32_params(self): 138 | return True 139 | 140 | class Config: 141 | levels: List[int] 142 | dim: int | 
None = None 143 | num_codebooks: int = 1 144 | keep_num_codebooks_dim: bool | None = None 145 | scale: float | None = None 146 | allowed_dtypes: Tuple[str, ...] = ('float32', 'float64') 147 | channel_first: bool = False 148 | projection_has_bias: bool = True 149 | return_indices: bool = True 150 | force_quantization_f32: bool = True 151 | use_rms: bool = False 152 | 153 | def __init__(self, c: Config): 154 | super().__init__() 155 | _levels = T.tensor(c.levels, dtype=int32) 156 | self.register_buffer("_levels", _levels, persistent = False) 157 | 158 | _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32) 159 | self.register_buffer("_basis", _basis, persistent = False) 160 | 161 | self.scale = c.scale 162 | 163 | codebook_dim = len(c.levels) 164 | self.codebook_dim = codebook_dim 165 | 166 | effective_codebook_dim = codebook_dim * c.num_codebooks 167 | self.num_codebooks = c.num_codebooks 168 | 169 | self.allowed_dtypes = [] 170 | for dtype_str in c.allowed_dtypes: 171 | if hasattr(T, dtype_str): 172 | self.allowed_dtypes.append(getattr(T, dtype_str)) 173 | else: 174 | raise ValueError(f"Invalid dtype string: {dtype_str}") 175 | 176 | self.effective_codebook_dim = effective_codebook_dim 177 | 178 | keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1) 179 | assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim) 180 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 181 | 182 | self.dim = default(c.dim, len(_levels) * c.num_codebooks) 183 | 184 | self.channel_first = c.channel_first 185 | 186 | has_projections = self.dim != effective_codebook_dim 187 | self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity() 188 | self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity() 189 | 190 | self.has_projections = has_projections 191 | 192 | self.return_indices = c.return_indices 193 | if c.return_indices: 194 | self.codebook_size = self._levels.prod().item() 195 | implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size)) 196 | self.register_buffer("implicit_codebook", implicit_codebook, persistent = False) 197 | 198 | self.allowed_dtypes = c.allowed_dtypes 199 | self.force_quantization_f32 = c.force_quantization_f32 200 | 201 | self.latent_loss = None 202 | 203 | def latent_metric(self, codes, get_global=False): 204 | return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)} 205 | 206 | def repr_from_latent(self, latent): 207 | return self.indices_to_codes(latent) 208 | 209 | def bound(self, z, eps: float = 1e-3): 210 | """ Bound `z`, an array of shape (..., d). """ 211 | half_l = (self._levels - 1) * (1 + eps) / 2 212 | offset = T.where(self._levels % 2 == 0, 0.5, 0.0) 213 | shift = (offset / half_l).atanh() 214 | return (z + shift).tanh() * half_l - offset 215 | 216 | def quantize(self, z): 217 | """ Quantizes z, returns quantized zhat, same shape as z. """ 218 | quantized = round_ste(self.bound(z)) 219 | half_width = self._levels // 2 # Renormalize to [-1, 1]. 
220 | return quantized / half_width 221 | 222 | def _scale_and_shift(self, zhat_normalized): 223 | half_width = self._levels // 2 224 | return (zhat_normalized * half_width) + half_width 225 | 226 | def _scale_and_shift_inverse(self, zhat): 227 | half_width = self._levels // 2 228 | return (zhat - half_width) / half_width 229 | 230 | def _indices_to_codes(self, indices): 231 | level_indices = self.indices_to_level_indices(indices) 232 | codes = self._scale_and_shift_inverse(level_indices) 233 | return codes 234 | 235 | def codes_to_indices(self, zhat): 236 | """ Converts a `code` to an index in the codebook. """ 237 | assert zhat.shape[-1] == self.codebook_dim 238 | zhat = self._scale_and_shift(zhat) 239 | return (zhat * self._basis).sum(dim=-1).to(int32) 240 | 241 | def indices_to_level_indices(self, indices): 242 | """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """ 243 | indices = rearrange(indices, '... -> ... 1') 244 | codes_non_centered = (indices // self._basis) % self._levels 245 | return codes_non_centered 246 | 247 | def indices_to_codes(self, indices): 248 | """ Inverse of `codes_to_indices`. """ 249 | assert exists(indices) 250 | 251 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 252 | 253 | codes = self._indices_to_codes(indices) 254 | 255 | if self.keep_num_codebooks_dim: 256 | codes = rearrange(codes, '... c d -> ... (c d)') 257 | 258 | codes = self.project_out(codes) 259 | 260 | if is_img_or_video or self.channel_first: 261 | codes = rearrange(codes, 'b ... d -> b d ...') 262 | 263 | return codes 264 | 265 | # @autocast(device_type='cuda', enabled = False) 266 | def forward(self, z, return_codes=False): 267 | """ 268 | einstein notation 269 | b - batch 270 | n - sequence (or flattened spatial dimensions) 271 | d - feature dimension 272 | c - number of codebook dim 273 | """ 274 | 275 | is_img_or_video = z.ndim >= 4 276 | need_move_channel_last = is_img_or_video or self.channel_first 277 | 278 | # standardize image or video into (batch, seq, dimension) 279 | 280 | if need_move_channel_last: 281 | z = rearrange(z, 'b d ... -> b ... d') 282 | z, ps = pack_one(z, 'b * d') 283 | 284 | assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' 285 | 286 | z = self.project_in(z) 287 | 288 | z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) 289 | 290 | # whether to force quantization step to be full precision or not 291 | 292 | force_f32 = self.force_quantization_f32 293 | quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext 294 | 295 | with quantization_context(): 296 | orig_dtype = z.dtype 297 | 298 | if force_f32 and orig_dtype not in self.allowed_dtypes: 299 | z = z.float() 300 | 301 | codes = self.quantize(z) 302 | 303 | # returning indices could be optional 304 | 305 | indices = None 306 | 307 | if self.return_indices: 308 | indices = self.codes_to_indices(codes) 309 | 310 | codes = rearrange(codes, 'b n c d -> b n (c d)') 311 | 312 | codes = codes.type(orig_dtype) 313 | 314 | # project out 315 | if return_codes: 316 | return codes, indices 317 | 318 | out = self.project_out(codes) 319 | 320 | # reconstitute image or video dimensions 321 | 322 | if need_move_channel_last: 323 | out = unpack_one(out, ps, 'b * d') 324 | out = rearrange(out, 'b ... 
d -> b d ...') 325 | 326 | indices = maybe(unpack_one)(indices, ps, 'b * c') 327 | 328 | if not self.keep_num_codebooks_dim and self.return_indices: 329 | indices = maybe(rearrange)(indices, '... 1 -> ...') 330 | 331 | # return quantized output and indices 332 | 333 | return out, indices -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch as T 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from ioblocks import GaussianMixtureIOLayer, FSQ 8 | 9 | from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm 10 | from tokenizer import make_tokenizer 11 | 12 | 13 | from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored 14 | from utils import load_ckpt 15 | 16 | 17 | @si_module 18 | class LatentQuantizer(nn.Module): 19 | class Config: 20 | compressor_config: Optional[FSQ.Config] = None 21 | 22 | dim: Optional[int] = None 23 | ff_dim: Optional[int] = None 24 | input_dim: int = None 25 | 26 | from_pretrained: Optional[Tuple[str, str]] = None 27 | 28 | def __init__(self, c: Config): 29 | super().__init__() 30 | 31 | if exists(c.from_pretrained): 32 | checkpoint = load_ckpt(*c.from_pretrained) 33 | else: 34 | assert exists(c.compressor_config), f'hmm {c}' 35 | 36 | self.compressor = c.compressor_config() 37 | self.ffnn = FFNN(c.dim, c.ff_dim) 38 | self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity() 39 | 40 | if exists(c.from_pretrained): 41 | self.load_state_dict(checkpoint) 42 | 43 | @T.no_grad() 44 | def forward(self, x, return_latent=False, known_latent=None): 45 | """ 46 | x: (B, S, D) 47 | """ 48 | if exists(known_latent): 49 | return self.compressor.indices_to_codes(known_latent) 50 | 51 | x = self.input(x) 52 | x = self.ffnn(x) 53 | x, tokens = self.compressor(x) 54 | 55 | if return_latent: 56 | return x, tokens 57 | return x 58 | 59 | 60 | @si_module 61 | class TransformerVAE(nn.Module): 62 | class Config: 63 | io_config: Optional[GaussianMixtureIOLayer.Config] = None 64 | stack_config: Optional[Stack.Config] = None 65 | quantizer_config: Optional[LatentQuantizer.Config] = None 66 | 67 | plex_layer: int = None 68 | plex_roll: int = 1 69 | split: bool = True 70 | 71 | from_pretrained: Optional[Tuple[str, str]] = None 72 | 73 | def __init__(self, c: Config): 74 | super().__init__() 75 | 76 | if exists(c.from_pretrained): 77 | checkpoint = load_ckpt(*c.from_pretrained) 78 | else: 79 | assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}' 80 | 81 | self.io = c.io_config() 82 | self.stack = c.stack_config() 83 | 84 | self.plex_layer = c.stack_config.layers//2 85 | self.plex_roll = c.plex_roll 86 | self.plex_dim = c.quantizer_config.dim 87 | 88 | assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}' 89 | self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim) 90 | self.out_norm = Norm(c.stack_config.dim) 91 | 92 | if c.split: 93 | self.io2 = c.io_config() 94 | self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim) 95 | 96 | self.io2.fc_loc = None 97 | self.io2.fc_scale = None 98 | self.io2.fc_weight = None 99 | 100 | kv_heads = c.stack_config.kv_heads or c.stack_config.n_head 101 | head_dim = 
c.stack_config.dim // c.stack_config.n_head 102 | self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0) 103 | cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim] 104 | self.cache_shape = cache_shape 105 | self.cache = [None] * self.cache_num_layers 106 | 107 | if exists(c.from_pretrained): 108 | result = self.load_state_dict(checkpoint, strict=False) 109 | print0_colored(result, 'yellow') 110 | 111 | self.quantizer = c.quantizer_config().eval() 112 | self.quantizer.requires_grad = False 113 | 114 | @T.no_grad() 115 | def quantize(self, x): 116 | if self.c.split: 117 | x1, x2 = x.chunk(2, dim=-1) 118 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 119 | quantized1 = self.quantizer(x1) 120 | quantized2 = self.quantizer(x2) 121 | return quantized1, quantized2 122 | else: 123 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 124 | return self.quantizer(x) 125 | 126 | @T.no_grad() 127 | def untokenize(self, token_data): 128 | return self.quantizer(None, known_latent=token_data) 129 | 130 | def init_cache(self, bsize, device, dtype, length:int=None): 131 | cache_shape = self.cache_shape.copy() 132 | cache_shape[1] = length or cache_shape[1] 133 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) 134 | 135 | def deinit_cache(self): 136 | self.cache = [None] * self.cache_num_layers 137 | 138 | @T.no_grad() 139 | def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None): 140 | if self.c.split: 141 | x1, x2 = data.chunk(2, dim=-1) 142 | x = self.io.input(x1) + self.io2.input(x2) 143 | else: 144 | x = self.io.input(data) 145 | 146 | cache_idx = 0 147 | for l, layer in enumerate(self.stack.layers): 148 | if l == self.plex_layer: 149 | if self.c.split: 150 | plex1, plex2 = self.quantize(data) 151 | plex1 = T.roll(plex1, -self.c.plex_roll, dims=1) 152 | plex2 = T.roll(plex2, -self.c.plex_roll, dims=1) 153 | if exists(next_tokens): 154 | plex1[:, -1:] = self.untokenize(next_tokens[0]) 155 | plex2[:, -1:] = self.untokenize(next_tokens[1]) 156 | x1 = x + self.plex_projection(plex1) 157 | x2 = x + self.plex_projection2(plex2) 158 | else: 159 | plex = self.quantize(data) 160 | plex = T.roll(plex, -self.c.plex_roll, dims=1) 161 | if exists(next_tokens): 162 | plex[:, -1:] = self.untokenize(next_tokens) 163 | x = x + self.plex_projection(plex) 164 | 165 | if l < self.plex_layer: 166 | x = layer(x, kv=self.cache[l]) 167 | else: 168 | if self.c.split: 169 | x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx]) 170 | cache_idx += 1 171 | x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx]) 172 | cache_idx += 1 173 | else: 174 | x = layer(x, kv=self.cache[l]) 175 | 176 | with T.autocast(device_type='cuda', dtype=T.bfloat16): 177 | if self.c.split: 178 | x1, x2 = self.out_norm(x1), self.out_norm(x2) 179 | out1, out2 = self.io.output(x1), self.io.output(x2) 180 | else: 181 | x = self.out_norm(x) 182 | out = self.io.output(x) 183 | 184 | if isnt(temps): 185 | if self.c.split: 186 | return out1, out2 187 | else: 188 | return out 189 | else: 190 | if self.c.split: 191 | next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :] 192 | next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :] 193 | next_data = T.cat([next_data1, next_data2], dim=-1) 194 | return next_data 195 | else: 196 | next_data = self.io.temp_sample(out, temps)[:, -1:, :] 197 | return next_data 198 | 
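# Note on the cache bookkeeping above: with the resynthesizer settings from
# get_hertz_dev_config (stack_config.layers=8, hence plex_layer=4) and split=True,
# cache_num_layers is 8 + (8 - 4) = 12. Layers 0..3 run the joint trunk and use one
# KV-cache slot each, while layers 4..7 process the two channels separately and
# consume two slots each, which is what the cache_idx increments in forward() walk
# through.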
199 | @si_module 200 | class HertzDevModel(nn.Module): 201 | class Config: 202 | dim: int 203 | vocab_size: int 204 | stack_config: Optional[Stack.Config] = None 205 | latent_size: int = 32 206 | 207 | split: bool = True 208 | 209 | quantizer_config: Optional[LatentQuantizer.Config] = None 210 | resynthesizer_config: Optional[TransformerVAE.Config] = None 211 | 212 | from_pretrained: Optional[Tuple[str, str]] = None 213 | 214 | def __init__(self, c: Config): 215 | super().__init__() 216 | 217 | if exists(c.from_pretrained): 218 | checkpoint = load_ckpt(*c.from_pretrained) 219 | else: 220 | assert (exists(c.stack_config)), f'hmm {c}' 221 | 222 | self.input = nn.Linear(c.latent_size, c.dim) 223 | if self.c.split: 224 | self.input2 = nn.Linear(c.latent_size, c.dim) 225 | 226 | self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta) 227 | 228 | self.layers = nn.ModuleList([ 229 | PerfBlock( 230 | dim=c.stack_config.dim, 231 | layer_id=l, 232 | n_head=c.stack_config.n_head, 233 | kv_heads=c.stack_config.kv_heads, 234 | ff_dim=c.stack_config.ff_dim, 235 | eps=c.stack_config.eps, 236 | shape_rotator=self.shape_rotator, 237 | ) for l in range(c.stack_config.layers) 238 | ]) 239 | 240 | self.output = GPTOutput(c.dim, c.vocab_size) 241 | if self.c.split: 242 | self.output2 = GPTOutput(c.dim, c.vocab_size) 243 | 244 | self.cache = [None] * c.stack_config.layers 245 | self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head 246 | self.head_dim = c.stack_config.dim // c.stack_config.n_head 247 | 248 | if exists(c.from_pretrained): 249 | result = self.load_state_dict(checkpoint, strict=False) 250 | print0_colored(result, 'yellow') 251 | 252 | self.resynthesizer = c.resynthesizer_config().eval() 253 | self.resynthesizer.requires_grad = False 254 | 255 | self.audio_tokenizer = make_tokenizer(device='cpu') 256 | self.audio_cache = None 257 | self.audio_latent_cache = None 258 | self.use_audio_cache = False 259 | 260 | @T.no_grad() 261 | def tokenize(self, audio_data): 262 | orig_audio_shape = audio_data.shape 263 | if exists(self.audio_cache): 264 | audio_data = T.cat([self.audio_cache, audio_data], dim=-1) 265 | self.audio_cache = audio_data[..., -(6*16_000):] 266 | elif self.use_audio_cache: 267 | self.audio_cache = audio_data[..., -(6*16_000):] 268 | 269 | if audio_data.shape[1] == 2: 270 | enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1]) 271 | enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2]) 272 | return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):] 273 | else: 274 | return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):] 275 | 276 | @T.no_grad() 277 | def untokenize(self, token_data): 278 | if exists(self.audio_latent_cache): 279 | token_data = T.cat([self.audio_latent_cache, token_data], dim=1) 280 | self.audio_latent_cache = token_data[:, -(6*8):] 281 | elif self.use_audio_cache: 282 | self.audio_latent_cache = token_data[:, -(6*8):] 283 | 284 | if token_data.shape[-1] == 2*self.c.latent_size: 285 | dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size]) 286 | dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:]) 287 | return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):] 288 | else: 289 | return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):] 290 | 291 | def init_cache(self, bsize, device, dtype, 
length:int=None): 292 | cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim] 293 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) 294 | self.resynthesizer.init_cache(bsize, device, dtype, length) 295 | self.use_audio_cache = True 296 | 297 | def deinit_cache(self): 298 | self.cache = [None] * len(self.layers) 299 | self.resynthesizer.deinit_cache() 300 | self.audio_cache = None 301 | self.audio_latent_cache = None 302 | self.use_audio_cache = False 303 | 304 | @T.no_grad() 305 | def forward(self, data): 306 | if self.c.split: 307 | x1, x2 = data.chunk(2, dim=-1) 308 | x = self.input(x1) + self.input2(x2) 309 | else: 310 | x = self.input(data) 311 | 312 | for l, layer in enumerate(self.layers): 313 | x = layer(x, kv=self.cache[l]) 314 | 315 | if self.c.split: 316 | return self.output(x), self.output2(x) 317 | else: 318 | return self.output(x) 319 | 320 | @T.no_grad() 321 | def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))): 322 | latents_in = self.tokenize(audio_data) 323 | next_latents = self.next_latent(latents_in, temps) 324 | next_model_latent = next_latents[..., self.c.latent_size:] 325 | audio_decoded = self.untokenize(next_model_latent)[..., -2000:] 326 | return audio_decoded 327 | 328 | 329 | @T.no_grad() 330 | def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))): 331 | 332 | if self.c.split: 333 | logits1, logits2 = self.forward(model_input) 334 | next_logits1 = logits1[:, -1] 335 | next_logits2 = logits2[:, -1] 336 | next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1) 337 | next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1) 338 | 339 | next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1]) 340 | else: 341 | logits = self.forward(model_input) 342 | next_logits = logits[:, -1] 343 | next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1) 344 | 345 | next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1]) 346 | 347 | return next_input 348 | 349 | 350 | @T.no_grad() 351 | def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor: 352 | """ 353 | only accepts latent-space data. 
354 | """ 355 | if use_cache: 356 | self.init_cache(data.shape[0], data.device, T.bfloat16) 357 | 358 | next_input = generated = data 359 | 360 | target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len) 361 | 362 | for _ in tqdm0(range(data.shape[1], target_len)): 363 | model_input = next_input if use_cache else generated 364 | 365 | next_input = self.next_latent(model_input, temps) 366 | 367 | generated = T.cat([generated, next_input], dim=1) 368 | 369 | if use_cache: 370 | self.deinit_cache() 371 | return generated 372 | 373 | 374 | 375 | def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False): 376 | if is_split: 377 | checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')] 378 | elif not use_pure_audio_ablation: 379 | checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')] 380 | else: 381 | checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')] 382 | 383 | quantizer_config=LatentQuantizer.Config( 384 | from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'), 385 | compressor_config=FSQ.Config( 386 | levels=[8,8,8,8,8], 387 | dim=2048, 388 | num_codebooks=1, 389 | keep_num_codebooks_dim=None, 390 | scale=None, 391 | allowed_dtypes=['float32', 'float64', 'bfloat16'], 392 | channel_first=False, 393 | projection_has_bias=True, 394 | return_indices=True, 395 | force_quantization_f32=True, 396 | use_rms=False 397 | ), 398 | dim=2048, 399 | ff_dim=8192, 400 | input_dim=32 401 | ) 402 | 403 | resynthesizer_config=TransformerVAE.Config( 404 | io_config=GaussianMixtureIOLayer.Config( 405 | latent_dim=32, 406 | dim=4096, 407 | num_components=8, 408 | ), 409 | stack_config=Stack.Config( 410 | layers=8, 411 | dim=4096, 412 | seq_len=8192, 413 | n_head=16, 414 | ff_dim=11008, 415 | kv_heads=16, 416 | eps=1e-5, 417 | theta=10_000 418 | ), 419 | quantizer_config=quantizer_config, 420 | plex_layer=None, 421 | plex_roll=1, 422 | split=is_split, 423 | from_pretrained=checkpoints[0], 424 | ) 425 | 426 | return HertzDevModel.Config( 427 | dim=4096, 428 | vocab_size=32_768, 429 | stack_config=Stack.Config( 430 | layers=32, 431 | dim=4096, 432 | seq_len=2048, 433 | n_head=32, 434 | ff_dim=None, 435 | kv_heads=None, 436 | eps=1e-5, 437 | theta=10_000, 438 | ), 439 | quantizer_config=quantizer_config, 440 | resynthesizer_config=resynthesizer_config, 441 | split=is_split, 442 | from_pretrained=checkpoints[1], 443 | ) -------------------------------------------------------------------------------- /prompts/bob_duo.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Standard-Intelligence/hertz-dev/076c2751cf1c73534b1cc4330298239d845f5ec3/prompts/bob_duo.wav -------------------------------------------------------------------------------- /prompts/bob_mono.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Standard-Intelligence/hertz-dev/076c2751cf1c73534b1cc4330298239d845f5ec3/prompts/bob_mono.wav -------------------------------------------------------------------------------- /prompts/countdown_mono.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Standard-Intelligence/hertz-dev/076c2751cf1c73534b1cc4330298239d845f5ec3/prompts/countdown_mono.wav -------------------------------------------------------------------------------- /prompts/toaskanymore.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Standard-Intelligence/hertz-dev/076c2751cf1c73534b1cc4330298239d845f5ec3/prompts/toaskanymore.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchaudio==2.5.1 3 | einops==0.8.0 4 | tqdm==4.66.6 5 | ipython==8.18.1 6 | numpy==1.26.3 7 | soundfile==0.12.1 8 | websockets==13.1 9 | requests==2.32.3 10 | sounddevice==0.5.1 11 | matplotlib==3.9.2 12 | fastapi==0.115.4 13 | uvicorn==0.32.0 14 | huggingface-hub[hf_transfer]==0.26.2 15 | IProgress==0.4 -------------------------------------------------------------------------------- /requirements_webrtc.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.33.0 2 | streamlit-webrtc==0.47.9 -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import Union, Tuple, Literal 4 | 5 | import torch as T 6 | import torch.nn as nn 7 | from torch.nn.utils.parametrizations import weight_norm 8 | 9 | from utils import load_ckpt 10 | from utils.interp import print_colored 11 | from utils import si_module, get_activation 12 | 13 | 14 | 15 | # Adapted from https://github.com/facebookresearch/AudioDec 16 | 17 | def Conv1d1x1(in_channels, out_channels, bias=True): 18 | return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias) 19 | 20 | 21 | class NonCausalConv1d(nn.Module): 22 | """1D noncausal convolution w/ 2-sides padding.""" 23 | 24 | def __init__( 25 | self, 26 | in_channels, 27 | out_channels, 28 | kernel_size, 29 | stride=1, 30 | padding=-1, 31 | dilation=1, 32 | groups=1, 33 | bias=True): 34 | super().__init__() 35 | self.in_channels = in_channels 36 | self.out_channels = out_channels 37 | self.kernel_size = kernel_size 38 | if padding < 0: 39 | padding = (kernel_size - 1) // 2 * dilation 40 | self.dilation = dilation 41 | self.conv = nn.Conv1d( 42 | in_channels=in_channels, 43 | out_channels=out_channels, 44 | kernel_size=kernel_size, 45 | stride=stride, 46 | padding=padding, 47 | dilation=dilation, 48 | groups=groups, 49 | bias=bias, 50 | ) 51 | 52 | def forward(self, x): 53 | """ 54 | Args: 55 | x (Tensor): Float tensor variable with the shape (B, C, T). 56 | Returns: 57 | Tensor: Float tensor variable with the shape (B, C, T). 
58 | """ 59 | x = self.conv(x) 60 | return x 61 | 62 | 63 | class NonCausalConvTranspose1d(nn.Module): 64 | """1D noncausal transpose convolution.""" 65 | 66 | def __init__( 67 | self, 68 | in_channels, 69 | out_channels, 70 | kernel_size, 71 | stride, 72 | padding=-1, 73 | output_padding=-1, 74 | groups=1, 75 | bias=True, 76 | ): 77 | super().__init__() 78 | if padding < 0: 79 | padding = (stride+1) // 2 80 | if output_padding < 0: 81 | output_padding = 1 if stride % 2 else 0 82 | self.deconv = nn.ConvTranspose1d( 83 | in_channels=in_channels, 84 | out_channels=out_channels, 85 | kernel_size=kernel_size, 86 | stride=stride, 87 | padding=padding, 88 | output_padding=output_padding, 89 | groups=groups, 90 | bias=bias, 91 | ) 92 | 93 | def forward(self, x): 94 | """ 95 | Args: 96 | x (Tensor): Float tensor variable with the shape (B, C, T). 97 | Returns: 98 | Tensor: Float tensor variable with the shape (B, C', T'). 99 | """ 100 | x = self.deconv(x) 101 | return x 102 | 103 | 104 | class CausalConv1d(NonCausalConv1d): 105 | def __init__( 106 | self, 107 | in_channels, 108 | out_channels, 109 | kernel_size, 110 | stride=1, 111 | dilation=1, 112 | groups=1, 113 | bias=True 114 | ): 115 | super(CausalConv1d, self).__init__( 116 | in_channels=in_channels, 117 | out_channels=out_channels, 118 | kernel_size=kernel_size, 119 | stride=stride, 120 | padding=0, 121 | dilation=dilation, 122 | groups=groups, 123 | bias=bias, 124 | ) 125 | self.stride = stride 126 | self.pad_length = (kernel_size - 1) * dilation 127 | def forward(self, x): 128 | pad = nn.ConstantPad1d((self.pad_length, 0), 0.0) 129 | x = pad(x) 130 | return self.conv(x) 131 | 132 | 133 | class CausalConvTranspose1d(NonCausalConvTranspose1d): 134 | def __init__( 135 | self, 136 | in_channels, 137 | out_channels, 138 | kernel_size, 139 | stride, 140 | bias=True, 141 | pad_buffer=None, 142 | ): 143 | super(CausalConvTranspose1d, self).__init__( 144 | in_channels=in_channels, 145 | out_channels=out_channels, 146 | kernel_size=kernel_size, 147 | stride=stride, 148 | padding=0, 149 | output_padding=0, 150 | bias=bias, 151 | ) 152 | self.stride = stride 153 | self.pad_length = (math.ceil(kernel_size/stride) - 1) 154 | if pad_buffer is None: 155 | pad_buffer = T.zeros(1, in_channels, self.pad_length) 156 | self.register_buffer("pad_buffer", pad_buffer) 157 | 158 | def forward(self, x): 159 | pad = nn.ReplicationPad1d((self.pad_length, 0)) 160 | x = pad(x) 161 | return self.deconv(x)[:, :, self.stride : -self.stride] 162 | 163 | def inference(self, x): 164 | x = T.cat((self.pad_buffer, x), -1) 165 | self.pad_buffer = x[:, :, -self.pad_length:] 166 | return self.deconv(x)[:, :, self.stride : -self.stride] 167 | 168 | def reset_buffer(self): 169 | self.pad_buffer.zero_() 170 | 171 | 172 | class NonCausalResUnit(nn.Module): 173 | def __init__( 174 | self, 175 | in_channels, 176 | out_channels, 177 | kernel_size=7, 178 | dilation=1, 179 | bias=False, 180 | ): 181 | super().__init__() 182 | self.activation = nn.ELU() 183 | self.conv1 = NonCausalConv1d( 184 | in_channels=in_channels, 185 | out_channels=out_channels, 186 | kernel_size=kernel_size, 187 | stride=1, 188 | dilation=dilation, 189 | bias=bias, 190 | ) 191 | self.conv2 = Conv1d1x1(out_channels, out_channels, bias) 192 | 193 | def forward(self, x): 194 | y = self.conv1(self.activation(x)) 195 | y = self.conv2(self.activation(y)) 196 | return x + y 197 | 198 | 199 | class CausalResUnit(NonCausalResUnit): 200 | def __init__( 201 | self, 202 | in_channels, 203 | out_channels, 204 | kernel_size=7, 
205 | dilation=1, 206 | bias=False, 207 | ): 208 | super(CausalResUnit, self).__init__( 209 | in_channels=in_channels, 210 | out_channels=out_channels, 211 | kernel_size=kernel_size, 212 | dilation=dilation, 213 | bias=bias, 214 | ) 215 | self.conv1 = CausalConv1d( 216 | in_channels=in_channels, 217 | out_channels=out_channels, 218 | kernel_size=kernel_size, 219 | stride=1, 220 | dilation=dilation, 221 | bias=bias, 222 | ) 223 | 224 | def inference(self, x): 225 | y = self.conv1.inference(self.activation(x)) 226 | y = self.conv2(self.activation(y)) 227 | return x + y 228 | 229 | 230 | class ResNetBlock(nn.Module): 231 | def __init__(self, 232 | in_channels, 233 | out_channels, 234 | stride, 235 | kernel_size=7, 236 | dilations=(1, 3, 9), 237 | bias=True, 238 | mode='encoder', 239 | ): 240 | super().__init__() 241 | assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!" 242 | 243 | self.mode = mode 244 | self.stride = stride 245 | 246 | ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d 247 | 248 | res_channels = in_channels if mode == 'encoder' else out_channels 249 | 250 | res_units = [CausalResUnit( 251 | res_channels, 252 | res_channels, 253 | kernel_size=kernel_size, 254 | dilation=dilation, 255 | ) for dilation in dilations] 256 | 257 | if in_channels == out_channels: 258 | if mode == 'encoder': 259 | self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride) 260 | if mode == 'decoder': 261 | self.upsample = nn.Upsample(scale_factor=stride, mode='nearest') 262 | conv_unit = nn.Conv1d( 263 | in_channels=in_channels, 264 | out_channels=out_channels, 265 | kernel_size=1, 266 | bias=bias, 267 | ) if in_channels != out_channels else nn.Identity() 268 | else: 269 | conv_unit = ConvUnit( 270 | in_channels=in_channels, 271 | out_channels=out_channels, 272 | kernel_size=(2 * stride), 273 | stride=stride, 274 | bias=bias, 275 | ) 276 | 277 | if mode == 'encoder': 278 | if in_channels == out_channels: 279 | self.res_block = nn.Sequential(*res_units, self.pool, conv_unit) 280 | else: 281 | self.res_block = nn.Sequential(*res_units, conv_unit) 282 | elif mode == 'decoder': 283 | if in_channels == out_channels: 284 | self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units) 285 | else: 286 | self.res_block = nn.Sequential(conv_unit, *res_units) 287 | 288 | def forward(self, x): 289 | out = x 290 | for unit in self.res_block: 291 | out = unit(out) 292 | return out 293 | 294 | def inference(self, x): 295 | for unit in self.res_block: 296 | x = unit.inference(x) 297 | return x 298 | 299 | 300 | 301 | 302 | @si_module 303 | class ResNetStack(nn.Module): 304 | """ 305 | ResNet encoder or decoder stack. Channel ratios 306 | and strides take the default order of from 307 | data/io-layer, to the middle of the model. 308 | """ 309 | class Config: 310 | input_channels: int = 1 311 | output_channels: int = 1 312 | encode_channels: int = 32 313 | decode_channel_multiplier: int = 1 314 | latent_dim: int = None 315 | kernel_size: int = 7 316 | bias: bool = True 317 | channel_ratios: Tuple[int, ...] = (2, 4, 8, 16) 318 | strides: Tuple[int, ...] = (3, 4, 5, 5) 319 | mode: Literal['encoder', 'decoder'] = 'encoder' 320 | 321 | def __init__(self, c: Config): 322 | super().__init__() 323 | assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!" 
324 | 325 | self.mode = c.mode 326 | 327 | assert len(c.channel_ratios) == len(c.strides) 328 | channel_ratios = (1,) + c.channel_ratios 329 | strides = c.strides 330 | self.middle_channels = c.encode_channels * channel_ratios[-1] 331 | if c.mode == 'decoder': 332 | channel_ratios = tuple(reversed(channel_ratios)) 333 | strides = tuple(reversed(strides)) 334 | 335 | self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1 336 | res_blocks = [ResNetBlock( 337 | c.encode_channels * channel_ratios[s_idx] * self.multiplier, 338 | c.encode_channels * channel_ratios[s_idx+1] * self.multiplier, 339 | stride, 340 | kernel_size=c.kernel_size, 341 | bias=c.bias, 342 | mode=c.mode, 343 | ) for s_idx, stride in enumerate(strides)] 344 | 345 | data_conv = CausalConv1d( 346 | in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier, 347 | out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels, 348 | kernel_size=c.kernel_size, 349 | stride=1, 350 | bias=False, 351 | ) 352 | 353 | if c.mode == 'encoder': 354 | self.res_stack = nn.Sequential(data_conv, *res_blocks) 355 | elif c.mode == 'decoder': 356 | self.res_stack = nn.Sequential(*res_blocks, data_conv) 357 | 358 | if c.latent_dim is not None: 359 | self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias) 360 | if self.multiplier != 1: 361 | self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias) 362 | 363 | def forward(self, x, return_feats=False): 364 | if self.c.latent_dim is not None and self.mode == 'decoder': 365 | x = self.latent_proj(x) 366 | if self.multiplier != 1: 367 | x = self.multiplier_proj(x) 368 | 369 | feats = [] 370 | for block in self.res_stack: 371 | x = block(x) 372 | if return_feats: 373 | feats.append(x) 374 | if self.c.latent_dim is not None and self.mode == 'encoder': 375 | x = self.latent_proj(x) 376 | if return_feats: 377 | feats.append(x) 378 | if return_feats: 379 | return feats 380 | return x 381 | 382 | def inference(self, x): 383 | for block in self.res_stack: 384 | x = block.inference(x) 385 | return x 386 | 387 | def reset_buffer(self): 388 | def _reset_buffer(m): 389 | if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d): 390 | m.reset_buffer() 391 | self.apply(_reset_buffer) 392 | 393 | def reset_parameters(self): 394 | def _reset_parameters(m): 395 | if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): 396 | m.weight.data.normal_(0.0, 0.01) 397 | 398 | self.apply(_reset_parameters) 399 | 400 | 401 | def apply_weight_norm(self): 402 | def _apply_weight_norm(m): 403 | if isinstance(m, nn.Conv1d) or isinstance( 404 | m, nn.ConvTranspose1d 405 | ): 406 | nn.utils.parametrizations.weight_norm(m) 407 | 408 | self.apply(_apply_weight_norm) 409 | 410 | 411 | def remove_weight_norm(self): 412 | def _remove_weight_norm(m): 413 | try: 414 | print(m) 415 | nn.utils.remove_weight_norm(m) 416 | except ValueError: # this module didn't have weight norm 417 | return 418 | 419 | self.apply(_remove_weight_norm) 420 | 421 | 422 | 423 | @si_module 424 | class GaussianZ(nn.Module): 425 | class Config: 426 | dim: int 427 | latent_dim: int 428 | bias: bool = False 429 | use_weight_norm: bool = False 430 | 431 | def __init__(self, c: Config): 432 | super().__init__() 433 | 434 | self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias) 435 | self.proj_out = nn.Linear(c.latent_dim, c.dim, 
bias=c.bias) 436 | 437 | if c.use_weight_norm: 438 | self.proj_in = weight_norm(self.proj_in) 439 | self.proj_out = weight_norm(self.proj_out) 440 | 441 | def reparam(self, mu, logvar): 442 | std = T.exp(logvar / 2) 443 | eps = T.randn_like(std) 444 | return mu + eps * std 445 | 446 | def kl_divergence(self, mu, logvar): 447 | return T.mean(-0.5 * T.sum( 448 | 1 + logvar - mu.pow(2) - logvar.exp(), 449 | dim=(1, 2)) 450 | ) 451 | 452 | def repr_from_latent(self, latent: Union[dict, T.Tensor]): 453 | if isinstance(latent, T.Tensor): 454 | z = latent 455 | else: 456 | z = self.reparam(latent['mu'], latent['logvar']) 457 | l = self.proj_out(z) 458 | return l 459 | 460 | def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]: 461 | mu, logvar = self.proj_in(x).chunk(2, dim=-1) 462 | kl_div = self.kl_divergence(mu, logvar) 463 | z = self.reparam(mu, logvar) 464 | xhat = self.proj_out(z) 465 | latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div} 466 | return xhat, latent 467 | 468 | 469 | 470 | @si_module 471 | class WaveCodec(nn.Module): 472 | class Config: 473 | resnet_config: ResNetStack.Config = None 474 | sample_rate: int = 16_000 475 | use_weight_norm: bool = False 476 | 477 | compressor_config: dataclass = None 478 | 479 | norm_stddev: float = 1.0 480 | 481 | def __init__(self, c: Config): 482 | super().__init__() 483 | self.norm_stddev = c.norm_stddev 484 | self.encoder = c.resnet_config(mode='encoder') 485 | self.sample_rate = c.sample_rate 486 | 487 | self.total_stride = 1 488 | for stride in c.resnet_config.strides: 489 | self.total_stride *= stride 490 | self.tokens_per_second = self.sample_rate / self.total_stride 491 | 492 | self.compressor = c.compressor_config(dim=self.encoder.middle_channels) 493 | 494 | self.decoder = c.resnet_config(mode='decoder') 495 | 496 | if c.use_weight_norm: 497 | self.encoder.apply_weight_norm() 498 | self.decoder.apply_weight_norm() 499 | self.encoder.reset_parameters() 500 | self.decoder.reset_parameters() 501 | 502 | def encode(self, data): 503 | return self.encoder(data/self.norm_stddev) 504 | 505 | def decode(self, latent): 506 | return self.decoder(latent.transpose(1, 2))*self.norm_stddev 507 | 508 | @T.no_grad() 509 | def latent_from_data(self, data, get_parameters=False): 510 | x = self.encode(data) 511 | l_in = x.transpose(1, 2) 512 | l, latent = self.compressor(l_in) 513 | return latent['z'] if not get_parameters else { 514 | 'mu': latent['mu'], 515 | 'logvar': latent['logvar'], 516 | 'z': latent['z'], 517 | } 518 | 519 | @T.no_grad() 520 | def data_from_latent(self, latent): 521 | l = self.compressor.repr_from_latent(latent) 522 | x = self.decode(l) 523 | return x 524 | 525 | def process(self, x): 526 | return self.latent_from_data(x) 527 | 528 | def unprocess(self, latent): 529 | return self.data_from_latent(latent) 530 | 531 | def forward(self, audio_input): 532 | x = self.encode(audio_input) 533 | 534 | l_in = x.transpose(1, 2) 535 | l, latent = self.compressor(l_in) 536 | 537 | xhat = self.decode(l) 538 | return xhat, latent 539 | 540 | 541 | 542 | def make_tokenizer(device='cuda'): 543 | generator_config = WaveCodec.Config( 544 | resnet_config=ResNetStack.Config( 545 | input_channels=1, 546 | output_channels=1, 547 | encode_channels=16, 548 | decode_channel_multiplier=4, 549 | kernel_size=7, 550 | bias=True, 551 | channel_ratios=(4, 8, 16, 16, 16, 16), 552 | strides=(2, 2, 4, 5, 5, 5), 553 | mode=None, 554 | ), 555 | use_weight_norm=True, 556 | 557 | compressor_config=GaussianZ.Config( 558 | dim=None, 559 | 
latent_dim=32, 560 | 561 | bias=True, 562 | use_weight_norm=True 563 | ), 564 | 565 | norm_stddev=0.05, 566 | ) 567 | checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97") 568 | 569 | tokenizer = generator_config() 570 | 571 | load_result = tokenizer.load_state_dict(checkpoint, strict=False) 572 | print_colored(f"Loaded tokenizer state dict: {load_result}", "grey") 573 | 574 | tokenizer = tokenizer.eval() 575 | # Only convert to bfloat16 if using CUDA 576 | if device == 'cuda': 577 | tokenizer = tokenizer.bfloat16() 578 | tokenizer = tokenizer.to(device) 579 | tokenizer.requires_grad_ = False 580 | return tokenizer 581 | 582 | -------------------------------------------------------------------------------- /transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, MutableMapping 2 | from typing import Union 3 | import math 4 | from contextlib import nullcontext 5 | 6 | import torch 7 | import torch as T 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | from torch.nn.attention import SDPBackend 12 | 13 | from einops import rearrange 14 | 15 | from utils import si_module, default, exists, load_ckpt 16 | 17 | CACHE_FILL_VALUE = -1 18 | 19 | def get_cache_len(cache: Optional[Tensor]) -> int: 20 | """ 21 | cache: (batch, seq_len, 2, kv_heads, head_dim) 22 | """ 23 | if cache is None: 24 | return 0 25 | nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1) 26 | length = nonzeros.sum(dim=-1).int() 27 | assert T.all(length == length[0]) 28 | return length[0] 29 | 30 | 31 | def rotate_half(x): 32 | x1, x2 = x.chunk(2, dim=-1) 33 | return torch.cat((-x2, x1), dim=-1) 34 | 35 | 36 | def apply_rotary_pos_emb(x, cos, sin, offset: int = 0): 37 | assert ( 38 | cos.shape[1] >= offset + x.shape[1] 39 | ), f"Offset and/or input sequence is too large,\ 40 | \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}" 41 | 42 | cos_out = cos[:, offset : offset + x.shape[1], :, :] 43 | sin_out = sin[:, offset : offset + x.shape[1], :, :] 44 | 45 | return (x * cos_out) + (rotate_half(x) * sin_out) 46 | 47 | 48 | # Adapted from https://github.com/foundation-model-stack/foundation-model-stack 49 | class ShapeRotator: 50 | def __init__( 51 | self, 52 | dim: int, 53 | end: int, 54 | theta: float = 10_000, 55 | ): 56 | super().__init__() 57 | self.dim = dim 58 | self.ratio = theta 59 | self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {} 60 | self.max_seq_len_cached: MutableMapping[int, int] = {} 61 | self.ntk_scaling = False 62 | self.max_seq_len = end 63 | 64 | def compute_freqs_cis(self, device, max_seq_len=None): 65 | alpha = 1 66 | dev_idx = device.index 67 | max_seq_len = default(max_seq_len, self.max_seq_len) 68 | 69 | if dev_idx not in self.cached_freqs: 70 | self.cached_freqs[dev_idx] = {} 71 | if dev_idx not in self.max_seq_len_cached: 72 | self.max_seq_len_cached[dev_idx] = 0 73 | 74 | 75 | if self.max_seq_len_cached[dev_idx] > 0: 76 | return 1 77 | max_seq_len = max(max_seq_len, self.max_seq_len) 78 | 79 | if ( 80 | 1 in self.cached_freqs[dev_idx] 81 | and max_seq_len <= self.max_seq_len_cached[dev_idx] 82 | ): 83 | return 1 84 | 85 | ratio = self.ratio 86 | dim = self.dim 87 | 88 | freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim)) 89 | 90 | t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype) 91 | freqs = torch.einsum("i,j->ij", t, freqs) 92 | emb = 
torch.cat((freqs, freqs), dim=-1).to(device) 93 | 94 | cos_to_cache = emb.cos()[None, :, None, :] 95 | sin_to_cache = emb.sin()[None, :, None, :] 96 | 97 | self.max_seq_len_cached[dev_idx] = max_seq_len 98 | 99 | self.cached_freqs[dev_idx][alpha] = torch.stack( 100 | [ 101 | cos_to_cache, 102 | sin_to_cache, 103 | ], 104 | dim=-1, 105 | ) 106 | 107 | return alpha 108 | 109 | def rotate( 110 | self, 111 | q: Tensor, 112 | k: Tensor, 113 | offset: int = 0, 114 | ) -> Tuple[Tensor, Tensor]: 115 | """ 116 | Args 117 | ---- 118 | q : torch.Tensor 119 | Embedded query tensor, expected size is B x S x H x Eh 120 | k : torch.Tensor 121 | Embedded query tensor, expected size is B x S x H x Eh 122 | """ 123 | assert len(q.size()) == 4 124 | assert len(k.size()) == 4 125 | 126 | seq_len = self.max_seq_len 127 | alpha = self.compute_freqs_cis(q.device, seq_len) 128 | freqs = self.cached_freqs[q.device.index][alpha] 129 | 130 | freqs = freqs.float() # 1 L D/2 2 2 131 | q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q) 132 | k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k) 133 | 134 | return q_out.view_as(q), k_out.view_as(k) 135 | 136 | class Linear(nn.Linear): 137 | def __init__(self, *args, **kwargs): 138 | super().__init__(*args, **kwargs, bias=False) 139 | 140 | class Norm(nn.Module): 141 | def __init__(self, 142 | dim: int, 143 | eps: float = 1e-5,) -> None: 144 | super().__init__() 145 | self.eps = eps 146 | self.weight = nn.Parameter(T.ones((dim,))) 147 | 148 | def forward(self, input: Tensor) -> Tensor: 149 | return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps) 150 | 151 | 152 | class FFNN(nn.Module): 153 | def __init__(self, 154 | dim: int, 155 | expand_dim: int = None,): 156 | super().__init__() 157 | expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256)) 158 | self.dim = dim 159 | self.expand_dim = expand_dim 160 | 161 | self.gateup_proj = Linear(dim, 2*expand_dim) 162 | self.down_proj = Linear(expand_dim, dim) 163 | 164 | def forward(self, x): 165 | gate, up = self.gateup_proj(x).chunk(2, dim=-1) 166 | return self.down_proj(up * F.silu(gate)) 167 | 168 | class GQA(nn.Module): 169 | def __init__(self, 170 | dim: int, 171 | n_head: int, 172 | shape_rotator: ShapeRotator, 173 | kv_heads: Optional[int] = None, 174 | eps: float = 1e-5, 175 | causal: bool = True,): 176 | super().__init__() 177 | self.n_heads = n_head 178 | self.kv_heads = default(kv_heads, n_head) 179 | self.head_dim = dim // n_head 180 | self.causal = causal 181 | 182 | self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads)) 183 | 184 | self.norm_q = Norm(self.head_dim*n_head, eps=eps) 185 | self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps) 186 | 187 | self.attn_out = Linear(dim, dim) 188 | 189 | self.shape_rotator = shape_rotator 190 | 191 | def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 192 | k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2) 193 | v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2) 194 | x = F.scaled_dot_product_attention( 195 | q.transpose(1, 2), 196 | k.transpose(1, 2), 197 | v.transpose(1, 2), 198 | is_causal=False if (q.size(1) != k.size(1)) else self.causal, 199 | ) 200 | x = x.transpose(1, 2).contiguous() 201 | return x 202 | 203 | def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,): 204 | cache_len = get_cache_len(kv_cache) 205 | q, k = self.shape_rotator.rotate(q, k, 
offset=cache_len) 206 | if exists(kv_cache): 207 | k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1) 208 | v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1) 209 | kv_cache[:, :k.size(1), 0] = k 210 | kv_cache[:, :v.size(1), 1] = v 211 | x = self._sdpa(q, k, v) 212 | return self.attn_out(rearrange(x, 'b s h d -> b s (h d)')) 213 | 214 | def _project(self, x): 215 | full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1) 216 | normed_full_q = self.norm_q(full_q).to(full_q.dtype) 217 | normed_full_k = self.norm_k(full_k).to(full_k.dtype) 218 | 219 | q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads) 220 | k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads) 221 | v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads) 222 | return q, k, v 223 | 224 | def forward(self, 225 | x: Tensor, 226 | kv: Optional[Tensor] = None,): 227 | """ 228 | x: (B, S, D) 229 | kv: (B, S, H, D) 230 | """ 231 | q, k, v = self._project(x) 232 | return self._attend(q, k, v, kv_cache=kv) 233 | 234 | 235 | class PreNormAttn(nn.Module): 236 | def __init__(self, 237 | dim: int, 238 | n_head: int, 239 | shape_rotator: ShapeRotator, 240 | kv_heads: Optional[int] = None, 241 | eps: float = 1e-5, 242 | causal: bool = True,): 243 | super().__init__() 244 | self.attn_norm = Norm(dim, eps=eps) 245 | self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal) 246 | 247 | def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor: 248 | """ 249 | x: (B, S, D) 250 | kv: (B, S, H, D) 251 | """ 252 | return x + self.attn(self.attn_norm(x), kv) 253 | 254 | class PreNormFFNN(nn.Module): 255 | def __init__(self, 256 | dim: int, 257 | ff_dim: int, 258 | eps: float = 1e-5,): 259 | super().__init__() 260 | self.ffnn_norm = Norm(dim, eps=eps) 261 | self.ffnn = FFNN(dim, ff_dim) 262 | 263 | def forward(self, x: Tensor) -> Tensor: 264 | return x + self.ffnn(self.ffnn_norm(x)) 265 | 266 | class Block(nn.Module): 267 | def __init__(self, 268 | dim: int, 269 | layer_id: int = 0, 270 | n_head: int = 16, 271 | kv_heads: Optional[int] = None, 272 | ff_dim: Optional[int] = None, 273 | eps: float = 1e-5, 274 | causal: bool = True, 275 | shape_rotator: ShapeRotator = None): 276 | super().__init__() 277 | self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal) 278 | self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps) 279 | self.dim = dim 280 | self.layer_id = layer_id 281 | self.head_dim = dim // n_head 282 | self.expand_dim = self.ffnn.ffnn.expand_dim 283 | 284 | self.reset_parameters() 285 | 286 | def reset_parameters(self): 287 | std = 1.0 / math.sqrt(self.dim) 288 | nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std) 289 | nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std) 290 | nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std) 291 | 292 | xstd = 1.0 / math.sqrt(self.expand_dim) 293 | nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd) 294 | 295 | def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor: 296 | """ 297 | x: (B, S, D) 298 | kv: (B, S, H, D) 299 | """ 300 | h = self.attn(x, kv) 301 | out = self.ffnn(h) 302 | return out 303 | 304 | 305 | 306 | class GPTOutput(nn.Module): 307 | def __init__(self, dim, vocab_size): 308 | super().__init__() 309 | self.dim = dim 310 | self.norm = Norm(dim) 311 | self.output = Linear(dim, vocab_size) 312 | 313 | self.reset_parameters() 314 | 315 | def 
reset_parameters(self): 316 | std = 1.0 / math.sqrt(self.dim**2) 317 | nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std) 318 | 319 | def forward(self, x): 320 | return self.output(self.norm(x)) 321 | 322 | @si_module 323 | class Stack(nn.Module): 324 | class Config: 325 | layers: int 326 | dim: int 327 | seq_len: int 328 | n_head: int = 32 329 | ff_dim: int = None 330 | kv_heads: int = None 331 | eps: float = 1e-5 332 | theta: Union[int, float] = 10_000 333 | causal: bool = True 334 | 335 | from_pretrained: Optional[Tuple[str, int]] = None 336 | 337 | def __init__(self, c: Config): 338 | super().__init__() 339 | 340 | from_pretrained = c.from_pretrained 341 | if exists(from_pretrained): 342 | checkpoint = load_ckpt(c.from_pretrained) 343 | 344 | self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta) 345 | 346 | self.layers = nn.ModuleList([ 347 | Block( 348 | dim=c.dim, 349 | layer_id=l, 350 | n_head=c.n_head, 351 | kv_heads=c.kv_heads, 352 | ff_dim=c.ff_dim, 353 | eps=c.eps, 354 | causal=c.causal, 355 | shape_rotator=self.shape_rotator, 356 | ) for l in range(c.layers) 357 | ]) 358 | 359 | kv_heads = c.kv_heads or c.n_head 360 | head_dim = c.dim // c.n_head 361 | cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim] 362 | self.cache_shape = cache_shape 363 | self.cache = [None] * c.layers 364 | 365 | if exists(from_pretrained): 366 | self.load_state_dict(checkpoint) 367 | 368 | def init_cache(self, bsize, device, dtype, length:int=None): 369 | if self.cache_shape is None: 370 | return 371 | cache_shape = self.cache_shape.copy() 372 | cache_shape[1] = length or cache_shape[1] 373 | self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1) 374 | 375 | def deinit_cache(self): 376 | self.cache = [None] * len(self.cache) 377 | 378 | def forward(self, x: Tensor) -> Tensor: 379 | for l, layer in enumerate(self.layers): 380 | x = layer(x, kv=self.cache[l]) 381 | return x 382 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | from .dist import * 3 | from .interp import * -------------------------------------------------------------------------------- /utils/blocks.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import TypeVar, Generic, Type, Optional 3 | from functools import wraps 4 | import time 5 | import random 6 | 7 | import torch as T 8 | import torch.nn as nn 9 | 10 | # @TODO: remove si_module from codebase 11 | # we use this in our research codebase to make modules from callable configs 12 | si_module_TpV = TypeVar('si_module_TpV') 13 | def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]: 14 | if not hasattr(cls, 'Config') or not isinstance(cls.Config, type): 15 | class Config: 16 | pass 17 | cls.Config = Config 18 | 19 | cls.Config = dataclass(cls.Config) 20 | 21 | class ConfigWrapper(cls.Config, Generic[si_module_TpV]): 22 | def __call__(self, *args, **kwargs) -> si_module_TpV: 23 | if len(kwargs) > 0: 24 | config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()} 25 | config_dict.update(kwargs) 26 | new_config = type(self)(**config_dict) 27 | return cls(new_config) 28 | else: 29 | return cls(self, *args) 30 | 31 | ConfigWrapper.__module__ = cls.__module__ 32 | ConfigWrapper.__name__ = 
f"{cls.__name__}Config" 33 | ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config" 34 | 35 | cls.Config = ConfigWrapper 36 | 37 | original_init = cls.__init__ 38 | def new_init(self, *args, **kwargs): 39 | self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None) 40 | original_init(self, *args, **kwargs) 41 | self.register_buffer('_device_tracker', T.Tensor(), persistent=False) 42 | 43 | cls.__init__ = new_init 44 | 45 | @property 46 | def device(self): 47 | return self._device_tracker.device 48 | 49 | @property 50 | def dtype(self): 51 | return self._device_tracker.dtype 52 | 53 | cls.device = device 54 | cls.dtype = dtype 55 | 56 | return cls 57 | 58 | 59 | def get_activation(nonlinear_activation, nonlinear_activation_params={}): 60 | if hasattr(nn, nonlinear_activation): 61 | return getattr(nn, nonlinear_activation)(**nonlinear_activation_params) 62 | else: 63 | raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn") 64 | 65 | 66 | def exists(v): 67 | return v is not None 68 | 69 | def isnt(v): 70 | return not exists(v) 71 | 72 | def truthyexists(v): 73 | return exists(v) and v is not False 74 | 75 | def truthyattr(obj, attr): 76 | return hasattr(obj, attr) and truthyexists(getattr(obj, attr)) 77 | 78 | defaultT = TypeVar('defaultT') 79 | 80 | def default(*args: Optional[defaultT]) -> Optional[defaultT]: 81 | for arg in args: 82 | if exists(arg): 83 | return arg 84 | return None 85 | 86 | def maybe(fn): 87 | @wraps(fn) 88 | def inner(x, *args, **kwargs): 89 | if not exists(x): 90 | return x 91 | return fn(x, *args, **kwargs) 92 | return inner 93 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as T 3 | import re 4 | from tqdm import tqdm 5 | from datetime import timedelta 6 | 7 | import requests 8 | import hashlib 9 | 10 | from io import BytesIO 11 | from huggingface_hub import hf_hub_download 12 | 13 | def rank0(): 14 | rank = os.environ.get('RANK') 15 | if rank is None or rank == '0': 16 | return True 17 | else: 18 | return False 19 | 20 | def local0(): 21 | local_rank = os.environ.get('LOCAL_RANK') 22 | if local_rank is None or local_rank == '0': 23 | return True 24 | else: 25 | return False 26 | class tqdm0(tqdm): 27 | def __init__(self, *args, **kwargs): 28 | total = kwargs.get('total', None) 29 | if total is None and len(args) > 0: 30 | try: 31 | total = len(args[0]) 32 | except TypeError: 33 | pass 34 | if total is not None: 35 | kwargs['miniters'] = max(1, total // 20) 36 | super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]') 37 | 38 | def print0(*args, **kwargs): 39 | if rank0(): 40 | print(*args, **kwargs) 41 | 42 | _PRINTED_IDS = set() 43 | 44 | def printonce(*args, id=None, **kwargs): 45 | if id is None: 46 | id = ' '.join(map(str, args)) 47 | 48 | if id not in _PRINTED_IDS: 49 | print(*args, **kwargs) 50 | _PRINTED_IDS.add(id) 51 | 52 | def print0once(*args, **kwargs): 53 | if rank0(): 54 | printonce(*args, **kwargs) 55 | 56 | def init_dist(): 57 | if T.distributed.is_initialized(): 58 | print0('Distributed already initialized') 59 | rank = T.distributed.get_rank() 60 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 61 | world_size = T.distributed.get_world_size() 62 | else: 63 | try: 64 | rank = int(os.environ['RANK']) 
65 | local_rank = int(os.environ['LOCAL_RANK']) 66 | world_size = int(os.environ['WORLD_SIZE']) 67 | device = f'cuda:{local_rank}' 68 | T.cuda.set_device(device) 69 | T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device)) 70 | print(f'Rank {rank} of {world_size}.') 71 | except Exception as e: 72 | print0once(f'Not initializing distributed env: {e}') 73 | rank = 0 74 | local_rank = 0 75 | world_size = 1 76 | return rank, local_rank, world_size 77 | 78 | def load_ckpt(load_from_location, expected_hash=None): 79 | os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1' #Disable this to speed up debugging errors with downloading from the hub 80 | if local0(): 81 | repo_id = "si-pbc/hertz-dev" 82 | print0(f'Loading checkpoint from repo_id {repo_id} and filename {load_from_location}.pt. This may take a while...') 83 | save_path = hf_hub_download(repo_id=repo_id, filename=f"{load_from_location}.pt") 84 | print0(f'Downloaded checkpoint to {save_path}') 85 | if expected_hash is not None: 86 | with open(save_path, 'rb') as f: 87 | file_hash = hashlib.md5(f.read()).hexdigest() 88 | if file_hash != expected_hash: 89 | print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.') 90 | os.remove(save_path) 91 | return load_ckpt(load_from_location, expected_hash) 92 | if T.distributed.is_initialized(): 93 | save_path = [save_path] 94 | T.distributed.broadcast_object_list(save_path, src=0) 95 | save_path = save_path[0] 96 | loaded = T.load(save_path, weights_only=False, map_location='cpu') 97 | print0(f'Loaded checkpoint from {save_path}') 98 | return loaded -------------------------------------------------------------------------------- /utils/interp.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | import os 3 | 4 | def rank0(): 5 | rank = os.environ.get('RANK') 6 | if rank is None or rank == '0': 7 | return True 8 | else: 9 | return False 10 | 11 | def print_colored(message, color='reset', bold=False, **kwargs): 12 | color_dict = { 13 | 'bold': '\033[1m', 14 | 'green': '\033[92m', 15 | 'yellow': '\033[93m', 16 | 'red': '\033[91m', 17 | 'blue': '\033[94m', 18 | 'grey': '\033[90m', 19 | 'white': '\033[97m', 20 | 'reset': '\033[0m' 21 | } 22 | 23 | color_code = color_dict.get(color.lower(), color_dict['reset']) 24 | prefix = color_dict['bold'] if bold else '' 25 | print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs) 26 | 27 | def print0_colored(*args, **kwargs): 28 | if rank0(): 29 | print_colored(*args, **kwargs) 30 | 31 | def param_count(module): 32 | def count_parameters(model): 33 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 34 | 35 | total_params = count_parameters(module) 36 | output = [f'Total model parameters: {total_params:,}', '---------------------------'] 37 | 38 | for name, child in module.named_children(): 39 | params = count_parameters(child) 40 | output.append(f'{name} parameters: {params:,}') 41 | 42 | return '\n'.join(output) 43 | 44 | def model_size_estimation(module): 45 | def estimate_size(model): 46 | param_size = sum(p.nelement() * p.element_size() for p in model.parameters()) 47 | buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers()) 48 | return param_size + buffer_size 49 | 50 | total_size = estimate_size(module) 51 | output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------'] 52 | 53 | 
for name, child in module.named_children(): 54 | child_size = estimate_size(child) 55 | output.append(f'{name} size: {child_size / 1024**2:.2f} MB') 56 | 57 | return '\n'.join(output) 58 | 59 | def layer_param_distribution(module): 60 | def count_parameters(model): 61 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 62 | 63 | def get_layer_types(model): 64 | layer_types = {} 65 | for name, module in model.named_modules(): 66 | layer_type = module.__class__.__name__ 67 | params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad) 68 | if params > 0: 69 | if layer_type not in layer_types: 70 | layer_types[layer_type] = 0 71 | layer_types[layer_type] += params 72 | return layer_types 73 | 74 | total_params = count_parameters(module) 75 | layer_types = get_layer_types(module) 76 | 77 | output = [f'Total trainable parameters: {total_params:,}', '---------------------------'] 78 | 79 | for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True): 80 | percentage = (count / total_params) * 100 81 | output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)') 82 | 83 | return '\n'.join(output) 84 | 85 | --------------------------------------------------------------------------------
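To make the relationship between these modules concrete, here is a minimal, hypothetical inference sketch assembled from the APIs above (get_hertz_dev_config, the callable si_module configs, and HertzDevModel's tokenize/completion/untokenize). It is not taken from the repository's own inference scripts; the prompt path, device, dtype, and ten-second generation length are illustrative assumptions, and running it requires a CUDA GPU plus the checkpoints that load_ckpt downloads from the Hugging Face Hub.

# Hypothetical usage sketch; values below are assumptions, not repository defaults.
import torch as T
import torchaudio

from model import get_hertz_dev_config

# Build the mono (non-split) model; si_module makes the returned Config callable,
# and calling it instantiates HertzDevModel with its pretrained weights.
model_config = get_hertz_dev_config(is_split=False)
model = model_config().eval().bfloat16().to('cuda')

# Load a 16 kHz mono prompt and tokenize it into latent frames (8 per second).
audio, sr = torchaudio.load('prompts/bob_mono.wav')
latents = model.tokenize(audio.unsqueeze(0).to('cuda', T.bfloat16))

# Continue for ~10 seconds in latent space, then decode only the new frames.
generated = model.completion(latents, temps=(0.8, (0.5, 0.1)), gen_len=8 * 10)
audio_out = model.untokenize(generated[:, latents.shape[1]:])
torchaudio.save('completion.wav', audio_out.squeeze(0).float().cpu(), 16_000)

For streaming, next_audio_from_audio performs one tokenize → next_latent → untokenize step per 2000-sample (125 ms at 16 kHz) hop.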