├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── Dance_Diffusion.ipynb
├── Finetune_Dance_Diffusion.ipynb
├── LICENSE
├── README.md
├── audio_diffusion
│   ├── __init__.py
│   ├── blocks.py
│   ├── models.py
│   └── utils.py
├── cog.yaml
├── dataset
│   ├── __init__.py
│   └── dataset.py
├── defaults.ini
├── meta.json
├── predict.py
├── setup.py
├── train_uncond.py
└── viz
    ├── __init__.py
    └── viz.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | pull_request: {}
5 | push:
6 | branches: ["main"]
7 | tags: ["*"]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v2
14 |
15 | - name: Setup Python
16 | uses: actions/setup-python@v1
17 | with:
18 | python-version: 3.9
19 |
20 | - name: Install linters
21 | run: pip install black flake8 isort
22 |
23 | # - name: linting
24 | # run: |
25 | # isort --diff .
26 | # black --check .
27 | # flake8
28 | # # - run: pytest --cov --cov-fail-under=80
29 |
30 | docker-image:
31 | runs-on: ubuntu-latest
32 | steps:
33 | - uses: actions/checkout@v2
34 |
35 | - name: Install cog
36 | run: |
37 | curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m`
38 | chmod +x /usr/local/bin/cog
39 | - name: Build docker image
40 | run: cog build -t image:test
41 |
42 | # - name: Smoke test docker image
43 | # run: |
44 | # docker run --rm image:test
45 | - name: Configure AWS credentials
46 | uses: aws-actions/configure-aws-credentials@v1
47 | with:
48 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
49 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
50 | aws-region: us-east-1
51 |
52 | - name: Login to Amazon ECR
53 | id: login-ecr
54 | uses: aws-actions/amazon-ecr-login@v1
55 |
56 | - name: Create repository if it doesn't exist yet
57 | run: aws ecr describe-repositories --repository-names ${{ github.repository }} || aws ecr create-repository --repository-name ${{ github.repository }}
58 |
59 | - name: Build, tag, and push image to Amazon ECR
60 | env:
61 | ECR_REGISTRY: 614871946825.dkr.ecr.us-east-1.amazonaws.com #${{ steps.login-ecr.outputs.registry }}
62 | ECR_REPOSITORY: ${{ github.repository }}
63 | IMAGE_TAG: latest
64 | run: |
65 | docker tag image:test $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
66 | docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
67 | docker inspect image:test > inspect.json
68 |
69 | - uses: webfactory/ssh-agent@v0.5.4
70 | with:
71 | ssh-private-key: ${{ secrets.SSH_KEY }}
72 | - name: Update model registry
73 | env:
74 | ECR_REGISTRY: 614871946825.dkr.ecr.us-east-1.amazonaws.com #${{ steps.login-ecr.outputs.registry }}
75 | ECR_REPOSITORY: ${{ github.repository }}
76 | IMAGE_TAG: latest
77 | run: |
78 | git config --global user.email "ci@pollinations.ai"
79 | git config --global user.name "pollinations-ci"
80 | git clone git@github.com:pollinations/model-index.git
81 | mkdir -p model-index/${{ github.repository }}
82 | cp meta.json model-index/${{ github.repository }}/meta.json
83 | cp inspect.json model-index/${{ github.repository }}/inspect.json
84 | cd model-index && python add_image.py ${{ github.repository }} $ECR_REGISTRY/$ECR_REPOSITORY && cd ..
85 | cd model-index && git add . && (git commit -m "Updated ${{ github.repository }}: ${{ github.event.head_commit.message }}" && git push) || echo "model index not updated"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/Dance_Diffusion.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "
"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {
13 | "id": "HHcTRGvUmoME"
14 | },
15 | "source": [
16 | "# Dance Diffusion v0.10\n",
17 | "\n",
18 | "Welcome to the Dance Diffusion beta!\n",
19 | "\n",
20 | "Dance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai. For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page.\n",
21 | "\n",
22 | "[Click here to ensure you are using the latest version](https://colab.research.google.com/github/Harmonai-org/sample-generator/blob/main/Dance_Diffusion.ipynb)\n",
23 | "\n",
24 | "**Audio diffusion tools in this notebook**:\n",
25 | "\n",
26 | "- Unconditional random audio sample generation\n",
27 | "- Audio sample regeneration/style transfer using a single audio file\n",
28 | "- Audio interpolation between two audio files"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {
34 | "id": "1iZwJ9ong-pH"
35 | },
36 | "source": [
37 | "# Instructions\n",
38 | "\n"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "-ITvsXU6hCAx"
45 | },
46 | "source": [
47 | "## Before anything else\n",
48 | "- Run the \"Setup\" section\n",
49 | "- Sign in to the Google Drive account you want to save your models in\n",
50 | "- Select the model you want to sample from in the \"Model settings\" section, this determines the length and sound of your samples\n",
51 | "- The `save_to_wandb` option futher down adds the ability to log your audio generations to [Weights & Biases](https://www.wandb.ai/site), an experiment tracking and model and data versioning tool.\n",
52 | "\n",
53 | "## For random sample generation\n",
54 | "- Choose the number of random samples you would like Dance Diffusion to generate for you \n",
55 | "- Choose the number of diffusion steps you would like Dance Diffusion to execute\n",
56 | "- Make sure the \"skip_for_run_all\" checkbox is unchecked\n",
57 | "- Run the cell under the \"Generate new sounds\" header\n",
58 | "\n",
59 | "## To regenerate your own sounds\n",
60 | "- Enter the path to an audio file you want to regenerate, or upload when prompted\n",
61 | "- Make sure the \"skip_for_run_all\" checkbox is unchecked\n",
62 | "- Run the cell under the \"Regenerate your own sounds\" header\n",
63 | "\n",
64 | "## To interpolate between two different sounds\n",
65 | "- Enter the paths to two audio files you want to interpolate between, or upload them when prompted\n",
66 | "- Make sure the \"skip_for_run_all\" checkbox is unchecked\n",
67 | "- Run the cell under the \"Interpolate between sounds\" header"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {
73 | "id": "pJkAc1j4pfAt"
74 | },
75 | "source": [
76 | "### Credits & License"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {
82 | "id": "cjHsmepZqNMs"
83 | },
84 | "source": [
85 | "\n",
86 | "\n",
87 | "Original notebook by Zach Evans (https://github.com/zqevans, https://twitter.com/zqevans). \n",
88 | "\n",
89 | "Overall structure and setup code taken from Disco Diffusion (https://www.discodiffusion.com)\n",
90 | "\n",
91 | "Interpolation and audio display code from CRASH inference notebook (https://github.com/simonrouard/CRASH)\n",
92 | "\n",
93 | "Spruced up by Chris the Wizard (https://twitter.com/chris_wizard)\n",
94 | "\n"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {
100 | "id": "u97w34BXmust"
101 | },
102 | "source": [
103 | "Licensed under the MIT License\n",
104 | "\n",
105 | "Copyright (c) 2022 Zach Evans\n",
106 | "\n",
107 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
108 | "of this software and associated documentation files (the \"Software\"), to deal\n",
109 | "in the Software without restriction, including without limitation the rights\n",
110 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
111 | "copies of the Software, and to permit persons to whom the Software is\n",
112 | "furnished to do so, subject to the following conditions:\n",
113 | "\n",
114 | "The above copyright notice and this permission notice shall be included in\n",
115 | "all copies or substantial portions of the Software.\n",
116 | "\n",
117 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
118 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
119 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
120 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
121 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
122 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
123 | "THE SOFTWARE.\n"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {
130 | "cellView": "form",
131 | "id": "aBF6kI89LJ5O"
132 | },
133 | "outputs": [],
134 | "source": [
135 | "#@title <- View Changelog\n",
136 | "skip_for_run_all = True #@param {type: 'boolean'}\n",
137 | "\n",
138 | "if skip_for_run_all == False:\n",
139 | " print(\n",
140 | " '''\n",
141 | "\n",
142 | " v0.1 Update: Jul 30, 2022 - zqevans\n",
143 | " \n",
144 | " - Added Dance Diffusion notebook\n",
145 | "\n",
146 | " v0.2 Update: Aug 8, 2022 - zqevans\n",
147 | "\n",
148 | " - Moved to models trained on \n",
149 | "\n",
150 | " v0.3 Update: Aug 11, 2022 - zqevans\n",
151 | "\n",
152 | " - Reverted to old model architecture\n",
153 | " - Fixed CRASH sampling code\n",
154 | "\n",
155 | " v0.4 Update: Aug 16, 2022 - zqevans\n",
156 | "\n",
157 | " - Added jmann-small-190k model\n",
158 | "\n",
159 | " v0.5 Update: Aug 17, 2022 - zqevans\n",
160 | "\n",
161 | " - Added interpolations\n",
162 | " \n",
163 | " v0.6 Update: Aug 18, 2022 - zqevans\n",
164 | "\n",
165 | " - Fixed bug in interpolations\n",
166 | "\n",
167 | " v0.7 Update: Aug 20, 2022 - zqevans\n",
168 | " - Added maestro-150k model\n",
169 | " - Added unlocked-250k model\n",
170 | " - Improved documentation\n",
171 | "\n",
172 | " v0.7.1 Update: Aug 21, 2022 - chris the wizard\n",
173 | " - Added introduction\n",
174 | " - Added instructions\n",
175 | " - Added skips for sections\n",
176 | " - Added upload prompts for audio files\n",
177 | " - Removed stale demos\n",
178 | "\n",
179 | " v0.8 Update: Aug 24, 2022 - zqevans\n",
180 | " - Added Honk model\n",
181 | " - Removed Rave Archive model\n",
182 | " - Added sample length multipliers and batch sizes to regeneration and interpolation\n",
183 | "\n",
184 | " v0.9 Update: Aug 24, 2022 - zqevans\n",
185 | " - Added glitch.cool model\n",
186 | " - Added jmann-large model\n",
187 | " - Regenerated sounds are now output individually\n",
188 | " - Added custom model sample_size and sample_rate options\n",
189 | "\n",
190 | " v0.10 Update: Sep 26, 2022 - morganmcg1\n",
191 | " - Added optional, off by default, Weights & Biases logging of the generated audio samples\n",
192 | " '''\n",
193 | " )"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "metadata": {
199 | "id": "lU97ZiP7nSKS"
200 | },
201 | "source": [
202 | "# Setup\n",
203 | "Run everything in this section before any generation"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {
210 | "cellView": "form",
211 | "id": "mxb-qgh0nUOf"
212 | },
213 | "outputs": [],
214 | "source": [
215 | "#@title Check GPU Status\n",
216 | "import subprocess\n",
217 | "simple_nvidia_smi_display = True#@param {type:\"boolean\"}\n",
218 | "if simple_nvidia_smi_display:\n",
219 | " #!nvidia-smi\n",
220 | " nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
221 | " print(nvidiasmi_output)\n",
222 | "else:\n",
223 | " #!nvidia-smi -i 0 -e 0\n",
224 | " nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
225 | " print(nvidiasmi_output)\n",
226 | " nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
227 | " print(nvidiasmi_ecc_note)"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": null,
233 | "metadata": {
234 | "cellView": "form",
235 | "id": "T_mFtzHvnlJL"
236 | },
237 | "outputs": [],
238 | "source": [
239 | "#@title Prepare folders\n",
240 | "import subprocess, os, sys, ipykernel\n",
241 | "\n",
242 | "def gitclone(url, targetdir=None):\n",
243 | " if targetdir:\n",
244 | " res = subprocess.run(['git', 'clone', url, targetdir], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
245 | " else:\n",
246 | " res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
247 | " print(res)\n",
248 | "\n",
249 | "def pipi(modulestr):\n",
250 | " res = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
251 | " print(res)\n",
252 | "\n",
253 | "def pipie(modulestr):\n",
254 | " res = subprocess.run(['git', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
255 | " print(res)\n",
256 | "\n",
257 | "def wget(url, outputdir):\n",
258 | " # Using the !wget command instead of the subprocess to get the loading bar\n",
259 | " !wget $url -O $outputdir\n",
260 | " # res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
261 | " # print(res)\n",
262 | "\n",
263 | "try:\n",
264 | " from google.colab import drive\n",
265 | " print(\"Google Colab detected. Using Google Drive.\")\n",
266 | " is_colab = True\n",
267 | " #@markdown Check to connect your Google Drive\n",
268 | " google_drive = True #@param {type:\"boolean\"}\n",
269 | " #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
270 | " save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
271 | "except:\n",
272 | " is_colab = False\n",
273 | " google_drive = False\n",
274 | " save_models_to_google_drive = False\n",
275 | " print(\"Google Colab not detected.\")\n",
276 | "\n",
277 | "if is_colab:\n",
278 | " if google_drive is True:\n",
279 | " drive.mount('/content/drive')\n",
280 | " ai_root = '/content/drive/MyDrive/AI'\n",
281 | " root_path = f'{ai_root}/Dance_Diffusion'\n",
282 | " else:\n",
283 | " root_path = '/content'\n",
284 | "else:\n",
285 | " root_path = os.getcwd()\n",
286 | "\n",
287 | "import os\n",
288 | "def createPath(filepath):\n",
289 | " os.makedirs(filepath, exist_ok=True)\n",
290 | "\n",
291 | "initDirPath = f'{root_path}/init_audio'\n",
292 | "createPath(initDirPath)\n",
293 | "outDirPath = f'{root_path}/audio_out'\n",
294 | "createPath(outDirPath)\n",
295 | "\n",
296 | "if is_colab:\n",
297 | " if google_drive and not save_models_to_google_drive or not google_drive:\n",
298 | " model_path = '/content/models'\n",
299 | " createPath(model_path)\n",
300 | " if google_drive and save_models_to_google_drive:\n",
301 | " model_path = f'{ai_root}/models'\n",
302 | " createPath(model_path)\n",
303 | "else:\n",
304 | " model_path = f'{root_path}/models'\n",
305 | " createPath(model_path)\n",
306 | "\n",
307 | "# libraries = f'{root_path}/libraries'\n",
308 | "# createPath(libraries)\n",
309 | "\n",
310 | "#@markdown Check the box below to save your generated audio to [Weights & Biases](https://wandb.ai/site)\n",
311 | "save_to_wandb = False #@param {type: \"boolean\"}\n",
312 | "\n",
313 | "if save_to_wandb:\n",
314 | " print(\"\\nInstalling wandb...\")\n",
315 | " os.system(\"pip install -qqq wandb --upgrade\")\n",
316 | " import wandb\n",
317 | " # Check if logged in to wandb\n",
318 | " try:\n",
319 | " import netrc\n",
320 | " netrc.netrc().hosts['api.wandb.ai']\n",
321 | " wandb.login()\n",
322 | " except:\n",
323 | " print(\"\\nPlease log in to Weights & Biases...\")\n",
324 | " print(\"1. Sign up for a free wandb account here: https://www.wandb.ai/site\")\n",
325 | " print(\"2. Enter your wandb API key, from https://wandb.ai/authorize, in the field below to log in: \\n\")\n",
326 | " wandb.login()"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {
333 | "cellView": "form",
334 | "id": "y9BS0ks1oEgP"
335 | },
336 | "outputs": [],
337 | "source": [
338 | "#@title Install dependencies\n",
339 | "!git clone https://github.com/harmonai-org/sample-generator\n",
340 | "!git clone --recursive https://github.com/crowsonkb/v-diffusion-pytorch\n",
341 | "!pip install /content/sample-generator\n",
342 | "!pip install /content/v-diffusion-pytorch\n",
343 | "!pip install ipywidgets==7.7.1"
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 4,
349 | "metadata": {
350 | "cellView": "form",
351 | "id": "haxvUGZ0VpzA"
352 | },
353 | "outputs": [],
354 | "source": [
355 | "#@title Imports and definitions\n",
356 | "from prefigure.prefigure import get_all_args\n",
357 | "from contextlib import contextmanager\n",
358 | "from copy import deepcopy\n",
359 | "import math\n",
360 | "from pathlib import Path\n",
361 | "from google.colab import files\n",
362 | "\n",
363 | "import sys\n",
364 | "import gc\n",
365 | "\n",
366 | "from diffusion import sampling\n",
367 | "import torch\n",
368 | "from torch import optim, nn\n",
369 | "from torch.nn import functional as F\n",
370 | "from torch.utils import data\n",
371 | "from tqdm import trange\n",
372 | "from einops import rearrange\n",
373 | "\n",
374 | "import torchaudio\n",
375 | "from audio_diffusion.models import DiffusionAttnUnet1D\n",
376 | "import numpy as np\n",
377 | "\n",
378 | "import random\n",
379 | "import matplotlib.pyplot as plt\n",
380 | "import IPython.display as ipd\n",
381 | "from audio_diffusion.utils import Stereo, PadCrop\n",
382 | "from glob import glob\n",
383 | "\n",
384 | "#@title Model code\n",
385 | "class DiffusionUncond(nn.Module):\n",
386 | " def __init__(self, global_args):\n",
387 | " super().__init__()\n",
388 | "\n",
389 | " self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers = 4)\n",
390 | " self.diffusion_ema = deepcopy(self.diffusion)\n",
391 | " self.rng = torch.quasirandom.SobolEngine(1, scramble=True)\n",
392 | "\n",
393 | "import matplotlib.pyplot as plt\n",
394 | "import IPython.display as ipd\n",
395 | "\n",
396 | "def plot_and_hear(audio, sr):\n",
397 | " display(ipd.Audio(audio.cpu().clamp(-1, 1), rate=sr))\n",
398 | " plt.plot(audio.cpu().t().numpy())\n",
399 | " \n",
400 | "def load_to_device(path, sr):\n",
401 | " audio, file_sr = torchaudio.load(path)\n",
402 | " if sr != file_sr:\n",
403 | " audio = torchaudio.transforms.Resample(file_sr, sr)(audio)\n",
404 | " audio = audio.to(device)\n",
405 | " return audio\n",
406 | "\n",
407 | "def get_alphas_sigmas(t):\n",
408 | " \"\"\"Returns the scaling factors for the clean image (alpha) and for the\n",
409 | " noise (sigma), given a timestep.\"\"\"\n",
410 | " return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
411 | "\n",
412 | "def get_crash_schedule(t):\n",
413 | " sigma = torch.sin(t * math.pi / 2) ** 2\n",
414 | " alpha = (1 - sigma ** 2) ** 0.5\n",
415 | " return alpha_sigma_to_t(alpha, sigma)\n",
416 | "\n",
417 | "def t_to_alpha_sigma(t):\n",
418 | " \"\"\"Returns the scaling factors for the clean image and for the noise, given\n",
419 | " a timestep.\"\"\"\n",
420 | " return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
421 | "\n",
422 | "def alpha_sigma_to_t(alpha, sigma):\n",
423 | " \"\"\"Returns a timestep, given the scaling factors for the clean image and for\n",
424 | " the noise.\"\"\"\n",
425 | " return torch.atan2(sigma, alpha) / math.pi * 2\n",
426 | "\n",
427 | "#@title Args\n",
428 | "sample_size = 65536 \n",
429 | "sample_rate = 48000 \n",
430 | "latent_dim = 0 \n",
431 | "\n",
432 | "class Object(object):\n",
433 | " pass\n",
434 | "\n",
435 | "args = Object()\n",
436 | "args.sample_size = sample_size\n",
437 | "args.sample_rate = sample_rate\n",
438 | "args.latent_dim = latent_dim\n",
439 | "\n",
440 | "#@title Logging\n",
441 | "def get_one_channel(audio_data, channel):\n",
442 | " '''\n",
443 | " Takes a numpy audio array and returns 1 channel\n",
444 | " '''\n",
445 | " # Check if the audio has more than 1 channel \n",
446 | " if len(audio_data.shape) > 1:\n",
447 | " is_stereo = True \n",
448 | " if np.argmax(audio_data.shape)==0:\n",
449 | " audio_data = audio_data[:,channel] \n",
450 | " else:\n",
451 | " audio_data = audio_data[channel,:]\n",
452 | " else:\n",
453 | " is_stereo = False\n",
454 | "\n",
455 | " return audio_data\n",
456 | "\n",
457 | "def log_audio_to_wandb(\n",
458 | " generated, model_name, custom_ckpt_path, steps, batch_size, sample_rate, sample_size, \n",
459 | " generated_all=None, channel=0, original_sample=None, gen_type='new_sounds', noise_level=None, sample_length_mult=None, file_path=None\n",
460 | " ):\n",
461 | "\n",
462 | " print('\\nSaving your audio generations to Weights & Biases...')\n",
463 | "\n",
464 | " # Get model name\n",
465 | " if model_name == \"custom\":\n",
466 | " wandb_model_name = custom_ckpt_path\n",
467 | " else:\n",
468 | " wandb_model_name = model_name\n",
469 | " \n",
470 | " # Create config to log to wandb\n",
471 | " wandb_config = {\n",
472 | " \"model\":model_name,\n",
473 | " \"steps\":steps,\n",
474 | " \"batch_size\":batch_size,\n",
475 | " \"sample_rate\":sample_rate,\n",
476 | " \"sample_size\":sample_size,\n",
477 | " \"channel\":channel,\n",
478 | " \"gen_type\":gen_type,\n",
479 | " \"noise_level\":noise_level,\n",
480 | " \"sample_length_mult\":sample_length_mult,\n",
481 | " \"file_path\":file_path\n",
482 | " }\n",
483 | "\n",
484 | " # Create a new wandb run\n",
485 | " wandb.init(project='harmonai-audio-gen', config=wandb_config)\n",
486 | " wandb_run_url = wandb.run.get_url()\n",
487 | "\n",
488 | " # Create a Weights & Biases Table\n",
489 | " audio_generations_table = wandb.Table(columns=['audio', 'steps', 'model', 'batch_size', \n",
490 | " 'sample_rate', 'sample_size', 'duration'])\n",
491 | "\n",
492 | " # Add each individual generated sample to a wandb Table\n",
493 | " for idx, g in enumerate(generated.cpu().numpy()):\n",
494 | " \n",
495 | " # Check if the audio has more than 1 channel \n",
496 | " if idx==0: \n",
497 | " if len(g.shape) > 1:\n",
498 | " stereo = True \n",
499 | " else:\n",
500 | " stereo = False\n",
501 | "\n",
502 | " if stereo:\n",
503 | " g = g[channel]\n",
504 | "\n",
505 | " duration = np.max(g.shape) / sample_rate \n",
506 | " wandb_audio = wandb.Audio(g, sample_rate=sample_rate, caption=wandb_model_name)\n",
507 | " audio_generations_table.add_data(wandb_audio, steps, wandb_model_name, batch_size, \n",
508 | " sample_rate, sample_size, duration)\n",
509 | "\n",
510 | " # Log the samples Tables and finish the wandb run\n",
511 | " wandb.log({f'{gen_type}/harmonai_generations' : audio_generations_table})\n",
512 | " \n",
513 | " # Log the combined samples in another wandb Table\n",
514 | " if generated_all is not None:\n",
515 | " g_all = get_one_channel(generated_all, channel)\n",
516 | " duration_all = np.max(g_all.shape) / sample_rate \n",
517 | " audio_all_generations_table = wandb.Table(columns=['audio', 'steps', 'model', 'batch_size', \n",
518 | " 'sample_rate', 'sample_size', 'duration'])\n",
519 | " wandb_all_audio = wandb.Audio(g_all.cpu().numpy(), sample_rate=sample_rate, caption=wandb_model_name)\n",
520 | " audio_all_generations_table.add_data(wandb_all_audio, steps, wandb_model_name, batch_size, \n",
521 | " sample_rate, sample_size, duration_all)\n",
522 | " wandb.log({f'{gen_type}/all_harmonai_generations' : audio_all_generations_table})\n",
523 | "\n",
524 | " if original_sample is not None:\n",
525 | " original_sample = get_one_channel(original_sample, channel)\n",
526 | " audio_original_sample_table = wandb.Table(columns=['audio', 'file_path'])\n",
527 | " wandb_original_audio = wandb.Audio(original_sample, sample_rate=sample_rate)\n",
528 | " audio_original_sample_table.add_data(wandb_original_audio, file_path)\n",
529 | " wandb.log({f'{gen_type}/original_sample' : audio_original_sample_table})\n",
530 | " \n",
531 | " wandb.finish()\n",
532 | "\n",
533 | " print(f'Your audio generations are saved in Weights & Biases here: {wandb_run_url}\\n')"
534 | ]
535 | },
536 | {
537 | "cell_type": "markdown",
538 | "metadata": {
539 | "id": "SMQ8vYNQO22Y"
540 | },
541 | "source": [
542 | "# Model settings"
543 | ]
544 | },
545 | {
546 | "cell_type": "markdown",
547 | "metadata": {
548 | "id": "LWxBqHH_Yjvt"
549 | },
550 | "source": [
551 | "Select the model you want to sample from:\n",
552 | "---\n",
553 | "Model name | Description | Sample rate | Output samples\n",
554 | "--- | --- | --- | ---\n",
555 | "glitch-440k |Trained on clips from samples provided by [glitch.cool](https://glitch.cool) | 48000 | 65536\n",
556 | "jmann-small-190k |Trained on clips from Jonathan Mann's [Song-A-Day](https://songaday.world/) project | 48000 | 65536\n",
557 | "jmann-large-580k |Trained on clips from Jonathan Mann's [Song-A-Day](https://songaday.world/) project | 48000 | 131072\n",
558 | "maestro-150k |Trained on piano clips from the [MAESTRO](https://magenta.tensorflow.org/datasets/maestro) dataset | 16000 | 65536\n",
559 | "unlocked-250k |Trained on clips from the [Unlocked Recordings](https://archive.org/details/unlockedrecordings) dataset | 16000 | 65536\n",
560 | "honk-140k |Trained on recordings of the Canada Goose from [xeno-canto](https://xeno-canto.org/) | 16000 | 65536\n"
561 | ]
562 | },
563 | {
564 | "cell_type": "code",
565 | "execution_count": null,
566 | "metadata": {
567 | "cellView": "form",
568 | "id": "JHsHQcc6rHu7"
569 | },
570 | "outputs": [],
571 | "source": [
572 | "from urllib.parse import urlparse\n",
573 | "import hashlib\n",
574 | "#@title Create the model\n",
575 | "model_name = \"glitch-440k\" #@param [\"glitch-440k\", \"jmann-small-190k\", \"jmann-large-580k\", \"maestro-150k\", \"unlocked-250k\", \"honk-140k\", \"custom\"]\n",
576 | "\n",
577 | "#@markdown ###Custom options\n",
578 | "\n",
579 | "#@markdown If you have a custom fine-tuned model, choose \"custom\" above and enter a path to the model checkpoint here\n",
580 | "\n",
581 | "#@markdown These options will not affect non-custom models\n",
582 | "custom_ckpt_path = ''#@param {type: 'string'}\n",
583 | "\n",
584 | "custom_sample_rate = 16000 #@param {type: 'number'}\n",
585 | "custom_sample_size = 65536 #@param {type: 'number'}\n",
586 | "\n",
587 | "models_map = {\n",
588 | "\n",
589 | " \"glitch-440k\": {'downloaded': False,\n",
590 | " 'sha': \"48caefdcbb7b15e1a0b3d08587446936302535de74b0e05e0d61beba865ba00a\", \n",
591 | " 'uri_list': [\"https://model-server.zqevans2.workers.dev/gwf-440k.ckpt\"],\n",
592 | " 'sample_rate': 48000,\n",
593 | " 'sample_size': 65536\n",
594 | " },\n",
595 | " \"jmann-small-190k\": {'downloaded': False,\n",
596 | " 'sha': \"1e2a23a54e960b80227303d0495247a744fa1296652148da18a4da17c3784e9b\", \n",
597 | " 'uri_list': [\"https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt\"],\n",
598 | " 'sample_rate': 48000,\n",
599 | " 'sample_size': 65536\n",
600 | " },\n",
601 | " \"jmann-large-580k\": {'downloaded': False,\n",
602 | " 'sha': \"6b32b5ff1c666c4719da96a12fd15188fa875d6f79f8dd8e07b4d54676afa096\", \n",
603 | " 'uri_list': [\"https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt\"],\n",
604 | " 'sample_rate': 48000,\n",
605 | " 'sample_size': 131072\n",
606 | " },\n",
607 | " \"maestro-150k\": {'downloaded': False,\n",
608 | " 'sha': \"49d9abcae642e47c2082cec0b2dce95a45dc6e961805b6500204e27122d09485\", \n",
609 | " 'uri_list': [\"https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt\"],\n",
610 | " 'sample_rate': 16000,\n",
611 | " 'sample_size': 65536\n",
612 | " },\n",
613 | " \"unlocked-250k\": {'downloaded': False,\n",
614 | " 'sha': \"af337c8416732216eeb52db31dcc0d49a8d48e2b3ecaa524cb854c36b5a3503a\", \n",
615 | " 'uri_list': [\"https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt\"],\n",
616 | " 'sample_rate': 16000,\n",
617 | " 'sample_size': 65536\n",
618 | " },\n",
619 | " \"honk-140k\": {'downloaded': False,\n",
620 | " 'sha': \"a66847844659d287f55b7adbe090224d55aeafdd4c2b3e1e1c6a02992cb6e792\", \n",
621 | " 'uri_list': [\"https://model-server.zqevans2.workers.dev/honk-140k.ckpt\"],\n",
622 | " 'sample_rate': 16000,\n",
623 | " 'sample_size': 65536\n",
624 | " },\n",
625 | "}\n",
626 | "\n",
627 | "#@markdown If you're having issues with model downloads, check this to compare the SHA:\n",
628 | "check_model_SHA = True #@param{type:\"boolean\"}\n",
629 | "\n",
630 | "def get_model_filename(diffusion_model_name):\n",
631 | " model_uri = models_map[diffusion_model_name]['uri_list'][0]\n",
632 | " model_filename = os.path.basename(urlparse(model_uri).path)\n",
633 | " return model_filename\n",
634 | "\n",
635 | "def download_model(diffusion_model_name, uri_index=0):\n",
636 | " if diffusion_model_name != 'custom':\n",
637 | " model_filename = get_model_filename(diffusion_model_name)\n",
638 | " model_local_path = os.path.join(model_path, model_filename)\n",
639 | " if os.path.exists(model_local_path) and check_model_SHA:\n",
640 | " print(f'Checking {diffusion_model_name} File')\n",
641 | " with open(model_local_path, \"rb\") as f:\n",
642 | " bytes = f.read() \n",
643 | " hash = hashlib.sha256(bytes).hexdigest()\n",
644 | " print(f'SHA: {hash}')\n",
645 | " if hash == models_map[diffusion_model_name]['sha']:\n",
646 | " print(f'{diffusion_model_name} SHA matches')\n",
647 | " models_map[diffusion_model_name]['downloaded'] = True\n",
648 | " else:\n",
649 | " print(f\"{diffusion_model_name} SHA doesn't match. Will redownload it.\")\n",
650 | " elif os.path.exists(model_local_path) and not check_model_SHA or models_map[diffusion_model_name]['downloaded']:\n",
651 | " print(f'{diffusion_model_name} already downloaded. If the file is corrupt, enable check_model_SHA.')\n",
652 | " models_map[diffusion_model_name]['downloaded'] = True\n",
653 | "\n",
654 | " if not models_map[diffusion_model_name]['downloaded']:\n",
655 | " for model_uri in models_map[diffusion_model_name]['uri_list']:\n",
656 | " wget(model_uri, model_local_path)\n",
657 | " with open(model_local_path, \"rb\") as f:\n",
658 | " bytes = f.read() \n",
659 | " hash = hashlib.sha256(bytes).hexdigest()\n",
660 | " print(f'SHA: {hash}')\n",
661 | " if os.path.exists(model_local_path):\n",
662 | " models_map[diffusion_model_name]['downloaded'] = True\n",
663 | " return\n",
664 | " else:\n",
665 | " print(f'{diffusion_model_name} model download from {model_uri} failed. Will try any fallback uri.')\n",
666 | " print(f'{diffusion_model_name} download failed.')\n",
667 | "\n",
668 | "if model_name == \"custom\":\n",
669 | " ckpt_path = custom_ckpt_path\n",
670 | " args.sample_size = custom_sample_size\n",
671 | " args.sample_rate = custom_sample_rate\n",
672 | "else:\n",
673 | " model_info = models_map[model_name]\n",
674 | " download_model(model_name)\n",
675 | " ckpt_path = f'{model_path}/{get_model_filename(model_name)}'\n",
676 | " args.sample_size = model_info[\"sample_size\"]\n",
677 | " args.sample_rate = model_info[\"sample_rate\"]\n",
678 | "\n",
679 | "print(\"Creating the model...\")\n",
680 | "model = DiffusionUncond(args)\n",
681 | "model.load_state_dict(torch.load(ckpt_path)[\"state_dict\"])\n",
682 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
683 | "model = model.requires_grad_(False).to(device)\n",
684 | "print(\"Model created\")\n",
685 | "\n",
686 | "# # Remove non-EMA\n",
687 | "del model.diffusion\n",
688 | "\n",
689 | "model_fn = model.diffusion_ema\n",
690 | "\n"
691 | ]
692 | },
693 | {
694 | "cell_type": "markdown",
695 | "metadata": {
696 | "id": "_GQK9yZHTr_z"
697 | },
698 | "source": [
699 | "# Generate new sounds\n",
700 | "\n",
701 | "Feeding white noise into the model to be denoised creates novel sounds in the \"space\" of the training data."
702 | ]
703 | },
704 | {
705 | "cell_type": "code",
706 | "execution_count": null,
707 | "metadata": {
708 | "cellView": "form",
709 | "id": "zntGqLTJq6xU"
710 | },
711 | "outputs": [],
712 | "source": [
713 | "#@markdown How many audio clips to create\n",
714 | "batch_size = 16#@param {type:\"number\"}\n",
715 | "\n",
716 | "#@markdown Number of steps (100 is a good start, more steps trades off speed for quality)\n",
717 | "steps = 100 #@param {type:\"number\"}\n",
718 | "\n",
719 | "#@markdown Check the box below to save your generated audio to [Weights & Biases](https://www.wandb.ai/site)\n",
720 | "save_new_generations_to_wandb = False #@param {type: \"boolean\"}\n",
721 | "\n",
722 | "#@markdown Check the box below to skip this section when running all cells\n",
723 | "skip_for_run_all = False #@param {type: \"boolean\"}\n",
724 | "\n",
725 | "if not skip_for_run_all:\n",
726 | " torch.cuda.empty_cache()\n",
727 | " gc.collect()\n",
728 | "\n",
729 | " # Generate random noise to sample from\n",
730 | " noise = torch.randn([batch_size, 2, args.sample_size]).to(device)\n",
731 | "\n",
732 | " t = torch.linspace(1, 0, steps + 1, device=device)[:-1]\n",
733 | " step_list = get_crash_schedule(t)\n",
734 | "\n",
735 | " # Generate the samples from the noise\n",
736 | " generated = sampling.iplms_sample(model_fn, noise, step_list, {})\n",
737 | "\n",
738 | " # Hard-clip the generated audio\n",
739 | " generated = generated.clamp(-1, 1)\n",
740 | "\n",
741 | " # Put the demos together\n",
742 | " generated_all = rearrange(generated, 'b d n -> d (b n)')\n",
743 | "\n",
744 | " print(\"All samples\")\n",
745 | " plot_and_hear(generated_all, args.sample_rate)\n",
746 | " for ix, gen_sample in enumerate(generated):\n",
747 | " print(f'sample #{ix + 1}')\n",
748 | " display(ipd.Audio(gen_sample.cpu(), rate=args.sample_rate))\n",
749 | "\n",
750 | " # If Weights & Biases logging enabled, save generations\n",
751 | " if save_new_generations_to_wandb:\n",
752 | " # Check if logged in to wandb\n",
753 | " try:\n",
754 | " import netrc\n",
755 | " netrc.netrc().hosts['api.wandb.ai']\n",
756 | "\n",
757 | " log_audio_to_wandb(generated, model_name, custom_ckpt_path, steps, batch_size, \n",
758 | " args.sample_rate, args.sample_size, generated_all=generated_all)\n",
759 | " except:\n",
760 | " print(\"Not logged in to Weights & Biases, please tick the `save_to_wandb` box at the top of this notebook and run that cell again to log in to Weights & Biases first\")\n",
761 | "\n",
762 | "else:\n",
763 | " print(\"Skipping section, uncheck 'skip_for_run_all' to enable\")\n"
764 | ]
765 | },
766 | {
767 | "cell_type": "markdown",
768 | "metadata": {
769 | "id": "v0WKP7ku67vn"
770 | },
771 | "source": [
772 | "# Regenerate your own sounds\n",
773 | "By adding noise to an audio file and running it through the model to be denoised, new details will be created, pulling the audio closer to the \"sonic space\" of the model. The more noise you add, the more the sound will change.\n",
774 | "\n",
775 | "The effect of this is a kind of \"style transfer\" on the audio. For those familiar with image generation models, this is analogous to an \"init image\"."
776 | ]
777 | },
778 | {
779 | "cell_type": "code",
780 | "execution_count": null,
781 | "metadata": {
782 | "cellView": "form",
783 | "id": "bKgS7vZc4lN9"
784 | },
785 | "outputs": [],
786 | "source": [
787 | "\n",
788 | "#@markdown Enter a path to an audio file you want to alter, or leave blank to upload a file (.wav or .flac)\n",
789 | "file_path = \"\" #@param{type:\"string\"}\n",
790 | "\n",
791 | "#@markdown Total number of steps (100 is a good start, more steps trades off speed for quality)\n",
792 | "steps = 100#@param {type:\"number\"}\n",
793 | "\n",
794 | "#@markdown How much (0-1) to re-noise the original sample. Adding more noise (a higher number) means a bigger change to the input audio\n",
795 | "noise_level = 0.6#@param {type:\"number\"}\n",
796 | "\n",
797 | "#@markdown Multiplier on the default sample length from the model, allows for longer audio clips at the expense of VRAM\n",
798 | "sample_length_mult = 4#@param {type:\"number\"}\n",
799 | "\n",
800 | "#@markdown How many variations to create\n",
801 | "batch_size = 4 #@param {type:\"number\"}\n",
802 | "\n",
803 | "#@markdown Check the box below to save your generated audio to [Weights & Biases](https://www.wandb.ai/site)\n",
804 | "save_own_generations_to_wandb = False #@param {type: \"boolean\"}\n",
805 | "\n",
806 | "#@markdown Check the box below to skip this section when running all cells\n",
807 | "skip_for_run_all = False #@param {type: \"boolean\"}\n",
808 | "\n",
809 | "effective_length = args.sample_size * sample_length_mult\n",
810 | "\n",
811 | "if not skip_for_run_all:\n",
812 | " torch.cuda.empty_cache()\n",
813 | " gc.collect()\n",
814 | "\n",
815 | " if file_path == \"\":\n",
816 | " print(\"No file path provided, please upload a file\")\n",
817 | " uploaded = files.upload()\n",
818 | " file_path = list(uploaded.keys())[0]\n",
819 | "\n",
820 | " augs = torch.nn.Sequential(\n",
821 | " PadCrop(effective_length, randomize=True),\n",
822 | " Stereo()\n",
823 | " )\n",
824 | "\n",
825 | " audio_sample = load_to_device(file_path, args.sample_rate)\n",
826 | "\n",
827 | " audio_sample = augs(audio_sample).unsqueeze(0).repeat([batch_size, 1, 1])\n",
828 | "\n",
829 | " print(\"Initial audio sample\")\n",
830 | " plot_and_hear(audio_sample[0], args.sample_rate)\n",
831 | "\n",
832 | " t = torch.linspace(0, 1, steps + 1, device=device)\n",
833 | " step_list = get_crash_schedule(t)\n",
834 | " step_list = step_list[step_list < noise_level]\n",
835 | "\n",
836 | " alpha, sigma = t_to_alpha_sigma(step_list[-1])\n",
837 | " noised = torch.randn([batch_size, 2, effective_length], device='cuda')\n",
838 | " noised = audio_sample * alpha + noised * sigma\n",
839 | "\n",
840 | " generated = sampling.iplms_sample(model_fn, noised, step_list.flip(0)[:-1], {})\n",
841 | "\n",
842 | " print(\"Regenerated audio samples\")\n",
843 | " plot_and_hear(rearrange(generated, 'b d n -> d (b n)'), args.sample_rate)\n",
844 | "\n",
845 | " for ix, gen_sample in enumerate(generated):\n",
846 | " print(f'sample #{ix + 1}')\n",
847 | " display(ipd.Audio(gen_sample.cpu(), rate=args.sample_rate))\n",
848 | "\n",
849 | " # If Weights & Biases logging enabled, save generations\n",
850 | " if save_own_generations_to_wandb:\n",
851 | " # Check if logged in to wandb\n",
852 | " try:\n",
853 | " import netrc\n",
854 | " netrc.netrc().hosts['api.wandb.ai']\n",
855 | "\n",
856 | " log_audio_to_wandb(generated, model_name, custom_ckpt_path, steps, batch_size, \n",
857 | " args.sample_rate, args.sample_size, file_path=file_path, original_sample=audio_sample[0].cpu().numpy(),\n",
858 | " noise_level=noise_level, gen_type='own_file')\n",
859 | " except:\n",
860 | " print(\"Not logged in to Weights & Biases, please tick the `save_to_wandb` box at the top of this notebook and run that cell again to log in to Weights & Biases first\")\n",
861 | "\n",
862 | "else:\n",
863 | " print(\"Skipping section, uncheck 'skip_for_run_all' to enable\")"
864 | ]
865 | },
866 | {
867 | "cell_type": "markdown",
868 | "metadata": {
869 | "id": "vW8N8GCCM-yT"
870 | },
871 | "source": [
872 | "# Interpolate between sounds\n",
873 | "Diffusion models allow for interpolation between inputs through a process of deterministic noising and denoising. \n",
874 | "\n",
875 | "By deterministically noising two audio files, interpolating between the results, and deterministically denoising them, we can can create new sounds \"between\" the audio files provided."
876 | ]
877 | },
878 | {
879 | "cell_type": "code",
880 | "execution_count": null,
881 | "metadata": {
882 | "cellView": "form",
883 | "id": "l3Al3thgO5rb"
884 | },
885 | "outputs": [],
886 | "source": [
887 | "# Interpolation code taken and modified from CRASH\n",
888 | "def compute_interpolation_in_latent(latent1, latent2, lambd):\n",
889 | " '''\n",
890 | " Implementation of Spherical Linear Interpolation: https://en.wikipedia.org/wiki/Slerp\n",
891 | " latent1: tensor of shape (2, n)\n",
892 | " latent2: tensor of shape (2, n)\n",
893 | " lambd: list of floats between 0 and 1 representing the parameter t of the Slerp\n",
894 | " '''\n",
895 | " device = latent1.device\n",
896 | " lambd = torch.tensor(lambd)\n",
897 | "\n",
898 | " assert(latent1.shape[0] == latent2.shape[0])\n",
899 | "\n",
900 | " # get the number of channels\n",
901 | " nc = latent1.shape[0]\n",
902 | " interps = []\n",
903 | " for channel in range(nc):\n",
904 | " \n",
905 | " cos_omega = latent1[channel]@latent2[channel] / \\\n",
906 | " (torch.linalg.norm(latent1[channel])*torch.linalg.norm(latent2[channel]))\n",
907 | " omega = torch.arccos(cos_omega).item()\n",
908 | "\n",
909 | " a = torch.sin((1-lambd)*omega) / np.sin(omega)\n",
910 | " b = torch.sin(lambd*omega) / np.sin(omega)\n",
911 | " a = a.unsqueeze(1).to(device)\n",
912 | " b = b.unsqueeze(1).to(device)\n",
913 | " interps.append(a * latent1[channel] + b * latent2[channel])\n",
914 | " return rearrange(torch.cat(interps), \"(c b) n -> b c n\", c=nc) \n",
915 | "\n",
916 | "#@markdown Enter the paths to two audio files to interpolate between (.wav or .flac)\n",
917 | "source_audio_path = \"\" #@param{type:\"string\"}\n",
918 | "target_audio_path = \"\" #@param{type:\"string\"}\n",
919 | "\n",
920 | "#@markdown Total number of steps (100 is a good start, can go lower for more speed/less quality)\n",
921 | "steps = 100#@param {type:\"number\"}\n",
922 | "\n",
923 | "#@markdown Number of interpolated samples\n",
924 | "n_interps = 12 #@param {type:\"number\"}\n",
925 | "\n",
926 | "#@markdown Multiplier on the default sample length from the model, allows for longer audio clips at the expense of VRAM\n",
927 | "sample_length_mult = 1#@param {type:\"number\"}\n",
928 | "\n",
929 | "#@markdown Check the box below to skip this section when running all cells\n",
930 | "skip_for_run_all = False #@param {type: \"boolean\"}\n",
931 | "\n",
932 | "effective_length = args.sample_size * sample_length_mult\n",
933 | "\n",
934 | "if not skip_for_run_all:\n",
935 | "\n",
936 | " augs = torch.nn.Sequential(\n",
937 | " PadCrop(effective_length, randomize=True),\n",
938 | " Stereo()\n",
939 | " )\n",
940 | "\n",
941 | " if source_audio_path == \"\":\n",
942 | " print(\"No file path provided for the source audio, please upload a file\")\n",
943 | " uploaded = files.upload()\n",
944 | " source_audio_path = list(uploaded.keys())[0]\n",
945 | "\n",
946 | " audio_sample_1 = load_to_device(source_audio_path, args.sample_rate)\n",
947 | "\n",
948 | " print(\"Source audio sample loaded\")\n",
949 | "\n",
950 | " if target_audio_path == \"\":\n",
951 | " print(\"No file path provided for the target audio, please upload a file\")\n",
952 | " uploaded = files.upload()\n",
953 | " target_audio_path = list(uploaded.keys())[0]\n",
954 | "\n",
955 | " audio_sample_2 = load_to_device(target_audio_path, args.sample_rate)\n",
956 | "\n",
957 | " print(\"Target audio sample loaded\")\n",
958 | "\n",
959 | " audio_samples = augs(audio_sample_1).unsqueeze(0).repeat([2, 1, 1])\n",
960 | " audio_samples[1] = augs(audio_sample_2)\n",
961 | "\n",
962 | " print(\"Initial audio samples\")\n",
963 | " plot_and_hear(audio_samples[0], args.sample_rate)\n",
964 | " plot_and_hear(audio_samples[1], args.sample_rate)\n",
965 | "\n",
966 | " t = torch.linspace(0, 1, steps + 1, device=device)\n",
967 | " step_list = get_crash_schedule(t)\n",
968 | "\n",
969 | " reversed = sampling.iplms_sample(model_fn, audio_samples, step_list, {}, is_reverse=True)\n",
970 | "\n",
971 | " latent_series = compute_interpolation_in_latent(reversed[0], reversed[1], [k/n_interps for k in range(n_interps + 2)])\n",
972 | "\n",
973 | " generated = sampling.iplms_sample(model_fn, latent_series, step_list.flip(0)[:-1], {})\n",
974 | "\n",
975 | " # Put the demos together\n",
976 | " generated_all = rearrange(generated, 'b d n -> d (b n)')\n",
977 | "\n",
978 | " print(\"Full interpolation\")\n",
979 | " plot_and_hear(generated_all, args.sample_rate)\n",
980 | " for ix, gen_sample in enumerate(generated):\n",
981 | " print(f'sample #{ix + 1}')\n",
982 | " display(ipd.Audio(gen_sample.cpu(), rate=args.sample_rate))\n",
983 | "else:\n",
984 | " print(\"Skipping section, uncheck 'skip_for_run_all' to enable\") "
985 | ]
986 | }
987 | ],
988 | "metadata": {
989 | "accelerator": "GPU",
990 | "colab": {
991 | "collapsed_sections": [],
992 | "provenance": []
993 | },
994 | "gpuClass": "standard",
995 | "kernelspec": {
996 | "display_name": "Python 3",
997 | "name": "python3"
998 | },
999 | "language_info": {
1000 | "name": "python"
1001 | }
1002 | },
1003 | "nbformat": 4,
1004 | "nbformat_minor": 0
1005 | }
1006 |
--------------------------------------------------------------------------------
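For readers who want to run unconditional generation outside Colab, the sketch below distills the notebook's "Imports and definitions", "Create the model", and "Generate new sounds" cells into a plain script. It assumes this repo and crowsonkb/v-diffusion-pytorch are installed; `ckpt_path`, the batch size, and the step count are placeholders rather than anything prescribed by the repository, and `sample_size`/`sample_rate` must match whichever checkpoint from the model table you download.

```python
# Minimal sketch of unconditional generation, distilled from Dance_Diffusion.ipynb.
# Assumes `pip install .` of this repo plus crowsonkb/v-diffusion-pytorch;
# ckpt_path below is a placeholder, and sample_size / sample_rate must match the model.
import math
from copy import deepcopy

import torch
import torchaudio
from torch import nn
from diffusion import sampling                      # from v-diffusion-pytorch
from audio_diffusion.models import DiffusionAttnUnet1D


class Args:
    sample_size = 65536
    sample_rate = 48000
    latent_dim = 0


class DiffusionUncond(nn.Module):
    """Same wrapper the notebook uses; the checkpoints store its state dict."""
    def __init__(self, global_args):
        super().__init__()
        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)


def get_crash_schedule(t):
    # CRASH noise schedule, exactly as in the notebook:
    # sigma = sin^2(pi*t/2), alpha = sqrt(1 - sigma^2), mapped back to a timestep.
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma ** 2) ** 0.5
    return torch.atan2(sigma, alpha) / math.pi * 2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt_path = 'models/gwf-440k.ckpt'  # placeholder: any downloaded checkpoint from the table

model = DiffusionUncond(Args())
model.load_state_dict(torch.load(ckpt_path, map_location=device)['state_dict'])
model = model.requires_grad_(False).to(device)
model_fn = model.diffusion_ema      # sample from the EMA weights, as the notebook does

batch_size, steps = 4, 100
noise = torch.randn([batch_size, 2, Args.sample_size], device=device)
t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
step_list = get_crash_schedule(t)

generated = sampling.iplms_sample(model_fn, noise, step_list, {}).clamp(-1, 1)
for i, clip in enumerate(generated):
    torchaudio.save(f'sample_{i}.wav', clip.cpu(), Args.sample_rate)
```

The "Regenerate your own sounds" and "Interpolate between sounds" cells reuse the same schedule and sampler: regeneration partially noises a real clip up to `noise_level` before denoising it, while interpolation runs the sampler in reverse (`is_reverse=True`) on two clips, slerps between the resulting latents, and then denoises each interpolated latent.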
/Finetune_Dance_Diffusion.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "HHcTRGvUmoME"
7 | },
8 | "source": [
9 | "# Dance Diffusion finetune"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "u97w34BXmust"
16 | },
17 | "source": [
18 | "Licensed under the MIT License\n",
19 | "\n",
20 | "Copyright (c) 2022 Zach Evans\n",
21 | "\n",
22 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
23 | "of this software and associated documentation files (the \"Software\"), to deal\n",
24 | "in the Software without restriction, including without limitation the rights\n",
25 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
26 | "copies of the Software, and to permit persons to whom the Software is\n",
27 | "furnished to do so, subject to the following conditions:\n",
28 | "\n",
29 | "The above copyright notice and this permission notice shall be included in\n",
30 | "all copies or substantial portions of the Software.\n",
31 | "\n",
32 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
33 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
34 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
35 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
36 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
37 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
38 | "THE SOFTWARE.\n"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {
44 | "id": "lU97ZiP7nSKS"
45 | },
46 | "source": [
47 | "# Set Up"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {
54 | "cellView": "form",
55 | "id": "mxb-qgh0nUOf"
56 | },
57 | "outputs": [],
58 | "source": [
59 | "#@title Check GPU Status\n",
60 | "import subprocess\n",
61 | "simple_nvidia_smi_display = True#@param {type:\"boolean\"}\n",
62 | "if simple_nvidia_smi_display:\n",
63 | " #!nvidia-smi\n",
64 | " nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
65 | " print(nvidiasmi_output)\n",
66 | "else:\n",
67 | " #!nvidia-smi -i 0 -e 0\n",
68 | " nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
69 | " print(nvidiasmi_output)\n",
70 | " nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
71 | " print(nvidiasmi_ecc_note)"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {
78 | "cellView": "form",
79 | "id": "T_mFtzHvnlJL"
80 | },
81 | "outputs": [],
82 | "source": [
83 | "#@title Prepare folders\n",
84 | "import subprocess, os, sys, ipykernel\n",
85 | "\n",
86 | "def gitclone(url, targetdir=None):\n",
87 | " if targetdir:\n",
88 | " res = subprocess.run(['git', 'clone', url, targetdir], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
89 | " else:\n",
90 | " res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
91 | " print(res)\n",
92 | "\n",
93 | "def pipi(modulestr):\n",
94 | " res = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
95 | " print(res)\n",
96 | "\n",
97 | "def pipie(modulestr):\n",
98 | " res = subprocess.run(['git', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
99 | " print(res)\n",
100 | "\n",
101 | "def wget(url, outputdir):\n",
102 | " res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
103 | " print(res)\n",
104 | "\n",
105 | "try:\n",
106 | " from google.colab import drive\n",
107 | " print(\"Google Colab detected. Using Google Drive.\")\n",
108 | " is_colab = True\n",
109 | " google_drive = True #@param {type:\"boolean\"}\n",
110 | " #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
111 | " save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
112 | "except:\n",
113 | " is_colab = False\n",
114 | " google_drive = False\n",
115 | " save_models_to_google_drive = False\n",
116 | " print(\"Google Colab not detected.\")\n",
117 | "\n",
118 | "if is_colab:\n",
119 | " if google_drive is True:\n",
120 | " drive.mount('/content/drive')\n",
121 | " root_path = '/content/drive/MyDrive/AI/Bass_Diffusion'\n",
122 | " else:\n",
123 | " root_path = '/content'\n",
124 | "else:\n",
125 | " root_path = os.getcwd()\n",
126 | "\n",
127 | "import os\n",
128 | "def createPath(filepath):\n",
129 | " os.makedirs(filepath, exist_ok=True)\n",
130 | "\n",
131 | "initDirPath = f'{root_path}/init_audio'\n",
132 | "createPath(initDirPath)\n",
133 | "outDirPath = f'{root_path}/audio_out'\n",
134 | "createPath(outDirPath)\n",
135 | "\n",
136 | "if is_colab:\n",
137 | " if google_drive and not save_models_to_google_drive or not google_drive:\n",
138 | " model_path = '/content/models'\n",
139 | " createPath(model_path)\n",
140 | " if google_drive and save_models_to_google_drive:\n",
141 | " model_path = f'{root_path}/models'\n",
142 | " createPath(model_path)\n",
143 | "else:\n",
144 | " model_path = f'{root_path}/models'\n",
145 | " createPath(model_path)"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {
152 | "id": "y9BS0ks1oEgP",
153 | "cellView": "form"
154 | },
155 | "outputs": [],
156 | "source": [
157 | "#@title Install dependencies\n",
158 | "!git clone https://github.com/harmonai-org/sample-generator\n",
159 | "!pip install /content/sample-generator"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "source": [
165 | "# Train"
166 | ],
167 | "metadata": {
168 | "id": "0xq2TJzIPTcJ"
169 | }
170 | },
171 | {
172 | "cell_type": "code",
173 | "source": [
174 | "#@markdown Log in to [Weights & Biases](https://wandb.ai/) for run tracking\n",
175 | "!wandb login"
176 | ],
177 | "metadata": {
178 | "cellView": "form",
179 | "id": "oxJFFEZ8CD8g"
180 | },
181 | "execution_count": null,
182 | "outputs": []
183 | },
184 | {
185 | "cell_type": "code",
186 | "source": [
187 | "#@markdown Name for the finetune project, used as the W&B project name, as well as the directory for the saved checkpoints\n",
188 | "NAME=\"dd-drums-finetune\" #@param {type:\"string\"}\n",
189 | "\n",
190 | "#@markdown Path to the directory of audio data to use for fine-tuning\n",
191 | "TRAINING_DIR=\"/content/drive/MyDrive/Audio/Drums\" #@param {type:\"string\"}\n",
192 | "\n",
193 | "#@markdown Path to the checkpoint to fine-tune\n",
194 | "CKPT_PATH=\"/content/drive/MyDrive/AI/models/jmann-small-190k.ckpt\" #@param {type:\"string\"}\n",
195 | "\n",
196 | "#@markdown Directory path for saving the fine-tuned outputs\n",
197 | "OUTPUT_DIR=\"/content/drive/MyDrive/AI/models/DanceDiffusion/finetune\" #@param {type:\"string\"}\n",
198 | "\n",
199 | "#@markdown Number of training steps between demos\n",
200 | "DEMO_EVERY=250 #@param {type:\"number\"}\n",
201 | "\n",
202 | "#@markdown Number of training steps between saving model checkpoints\n",
203 | "CHECKPOINT_EVERY=500 #@param {type:\"number\"}\n",
204 | "\n",
205 | "#@markdown Sample rate to train at\n",
206 | "SAMPLE_RATE=48000 #@param {type:\"number\"}\n",
207 | "\n",
208 | "#@markdown Number of audio samples per training sample\n",
209 | "SAMPLE_SIZE=65536 #@param {type:\"number\"}\n",
210 | "\n",
211 | "#@markdown If true, the audio samples provided will be randomly cropped to SAMPLE_SIZE samples\n",
212 | "#@markdown\n",
213 | "#@markdown Turn off if you want to ensure the training data always starts at the beginning of the audio files (good for things like drum one-shots)\n",
214 | "RANDOM_CROP=True #@param {type:\"boolean\"}\n",
215 | "\n",
216 | "#@markdown Batch size to fine-tune (make it as high as it can go for your GPU)\n",
217 | "BATCH_SIZE=2 #@param {type:\"number\"}\n",
218 | "\n",
219 | "#@markdown Accumulate gradients over n batches, useful for training on one GPU. \n",
220 | "#@markdown\n",
221 | "#@markdown Effective batch size is BATCH_SIZE * ACCUM_BATCHES.\n",
222 | "#@markdown\n",
223 | "#@markdown Also increases the time between demos and saved checkpoints\n",
224 | "ACCUM_BATCHES=4 #@param {type:\"number\"}\n",
225 | "\n",
226 | "random_crop_str = f\"--random-crop True\" if RANDOM_CROP else \"\"\n",
227 | "\n",
228 | "# Escape spaces in paths\n",
229 | "CKPT_PATH = CKPT_PATH.replace(f\" \", f\"\\ \")\n",
230 | "OUTPUT_DIR = f\"{OUTPUT_DIR}/{NAME}\".replace(f\" \", f\"\\ \")\n",
231 | "\n",
232 | "%cd /content/sample-generator/\n",
233 | "\n",
234 | "!python3 /content/sample-generator/train_uncond.py --ckpt-path $CKPT_PATH --name $NAME --training-dir $TRAINING_DIR --sample-size $SAMPLE_SIZE --sample-rate $SAMPLE_RATE --batch-size $BATCH_SIZE --demo-every $DEMO_EVERY --demo-steps 100 --checkpoint-every $CHECKPOINT_EVERY --num-workers 2 $random_crop_str --save-path $OUTPUT_DIR"
235 | ],
236 | "metadata": {
237 | "id": "-Q0XrS0HEmch",
238 | "cellView": "form"
239 | },
240 | "execution_count": null,
241 | "outputs": []
242 | }
243 | ],
244 | "metadata": {
245 | "accelerator": "GPU",
246 | "colab": {
247 | "collapsed_sections": [
248 | "HHcTRGvUmoME"
249 | ],
250 | "name": "Finetune Dance Diffusion.ipynb",
251 | "provenance": []
252 | },
253 | "gpuClass": "standard",
254 | "kernelspec": {
255 | "display_name": "Python 3",
256 | "name": "python3"
257 | },
258 | "language_info": {
259 | "name": "python"
260 | }
261 | },
262 | "nbformat": 4,
263 | "nbformat_minor": 0
264 | }
--------------------------------------------------------------------------------
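A quick note on the accumulation arithmetic in the cell above: with the form defaults BATCH_SIZE=2 and ACCUM_BATCHES=4, the effective batch size is 2 × 4 = 8, and because the optimizer only steps once per 4 batches, DEMO_EVERY=250 and CHECKPOINT_EVERY=500 correspond to roughly 1000 and 2000 training batches between demos and checkpoints.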
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Harmonai-org
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sample-generator
2 | Tools to train a generative model on arbitrary audio samples
3 |
4 | Dance Diffusion notebook: [Open in Colab](https://colab.research.google.com/github/Harmonai-org/sample-generator/blob/main/Dance_Diffusion.ipynb)
5 |
6 | Dance Diffusion fine-tune notebook: [Open in Colab](https://colab.research.google.com/github/Harmonai-org/sample-generator/blob/main/Finetune_Dance_Diffusion.ipynb)
7 |
8 | ## Todo
9 |
10 | - [x] Add inference notebook
11 | - [x] Add interpolations to notebook
12 | - [x] Add fine-tune notebook
13 | - [ ] Add guidance to notebook
14 |
--------------------------------------------------------------------------------
/audio_diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pollinations/dance-diffusion/1a8eb27c2c985920575b3908530c7478fe9200b7/audio_diffusion/__init__.py
--------------------------------------------------------------------------------
/audio_diffusion/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, main, skip=None):
8 | super().__init__()
9 | self.main = nn.Sequential(*main)
10 | self.skip = skip if skip else nn.Identity()
11 |
12 | def forward(self, input):
13 | return self.main(input) + self.skip(input)
14 |
15 | # Noise level (and other) conditioning
16 | class ResConvBlock(ResidualBlock):
17 | def __init__(self, c_in, c_mid, c_out, is_last=False):
18 | skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
19 | super().__init__([
20 | nn.Conv1d(c_in, c_mid, 5, padding=2),
21 | nn.GroupNorm(1, c_mid),
22 | nn.GELU(),
23 | nn.Conv1d(c_mid, c_out, 5, padding=2),
24 | nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
25 | nn.GELU() if not is_last else nn.Identity(),
26 | ], skip)
27 |
28 | class SelfAttention1d(nn.Module):
29 | def __init__(self, c_in, n_head=1, dropout_rate=0.):
30 | super().__init__()
31 | assert c_in % n_head == 0
32 | self.norm = nn.GroupNorm(1, c_in)
33 | self.n_head = n_head
34 | self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
35 | self.out_proj = nn.Conv1d(c_in, c_in, 1)
36 | self.dropout = nn.Dropout(dropout_rate, inplace=True)
37 |
38 | def forward(self, input):
39 | n, c, s = input.shape
40 | qkv = self.qkv_proj(self.norm(input))
41 | qkv = qkv.view(
42 | [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
43 | q, k, v = qkv.chunk(3, dim=1)
44 | scale = k.shape[3]**-0.25
45 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
46 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
47 | return input + self.dropout(self.out_proj(y))
48 |
49 | class SkipBlock(nn.Module):
50 | def __init__(self, *main):
51 | super().__init__()
52 | self.main = nn.Sequential(*main)
53 |
54 | def forward(self, input):
55 | return torch.cat([self.main(input), input], dim=1)
56 |
57 | class FourierFeatures(nn.Module):
58 | def __init__(self, in_features, out_features, std=1.):
59 | super().__init__()
60 | assert out_features % 2 == 0
61 | self.weight = nn.Parameter(torch.randn(
62 | [out_features // 2, in_features]) * std)
63 |
64 | def forward(self, input):
65 | f = 2 * math.pi * input @ self.weight.T
66 | return torch.cat([f.cos(), f.sin()], dim=-1)
67 |
68 |
69 | _kernels = {
70 | 'linear':
71 | [1 / 8, 3 / 8, 3 / 8, 1 / 8],
72 | 'cubic':
73 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
74 | 0.43359375, 0.11328125, -0.03515625, -0.01171875],
75 | 'lanczos3':
76 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
77 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
78 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
79 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
80 | }
81 |
82 |
83 | class Downsample1d(nn.Module):
84 | def __init__(self, kernel='linear', pad_mode='reflect'):
85 | super().__init__()
86 | self.pad_mode = pad_mode
87 | kernel_1d = torch.tensor(_kernels[kernel])
88 | self.pad = kernel_1d.shape[0] // 2 - 1
89 | self.register_buffer('kernel', kernel_1d)
90 |
91 | def forward(self, x):
92 | x = F.pad(x, (self.pad,) * 2, self.pad_mode)
93 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
94 | indices = torch.arange(x.shape[1], device=x.device)
95 | weight[indices, indices] = self.kernel.to(weight)
96 | return F.conv1d(x, weight, stride=2)
97 |
98 |
99 | class Upsample1d(nn.Module):
100 | def __init__(self, kernel='linear', pad_mode='reflect'):
101 | super().__init__()
102 | self.pad_mode = pad_mode
103 | kernel_1d = torch.tensor(_kernels[kernel]) * 2
104 | self.pad = kernel_1d.shape[0] // 2 - 1
105 | self.register_buffer('kernel', kernel_1d)
106 |
107 | def forward(self, x):
108 | x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
109 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
110 | indices = torch.arange(x.shape[1], device=x.device)
111 | weight[indices, indices] = self.kernel.to(weight)
112 | return F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
113 |
--------------------------------------------------------------------------------
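Note: a minimal shape-check sketch for the blocks above, assuming the package has been installed from this repo so that audio_diffusion is importable. Downsample1d halves the time dimension with a stride-2 low-pass filter, Upsample1d doubles it with the matching transposed convolution, and ResConvBlock preserves channels and length when c_in == c_out:

    import torch
    from audio_diffusion.blocks import Downsample1d, Upsample1d, ResConvBlock

    x = torch.randn(1, 8, 64)                  # (batch, channels, samples)
    print(Downsample1d("linear")(x).shape)     # torch.Size([1, 8, 32])
    print(Upsample1d("linear")(x).shape)       # torch.Size([1, 8, 128])
    print(ResConvBlock(8, 16, 8)(x).shape)     # torch.Size([1, 8, 64])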
/audio_diffusion/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .blocks import SkipBlock, FourierFeatures, SelfAttention1d, ResConvBlock, Downsample1d, Upsample1d
6 | from .utils import append_dims, expand_to_planes
7 |
8 | class DiffusionAttnUnet1D(nn.Module):
9 | def __init__(
10 | self,
11 | global_args,
12 | io_channels = 2,
13 | depth=14,
14 | n_attn_layers = 6,
15 | c_mults = [128, 128, 256, 256] + [512] * 10
16 | ):
17 | super().__init__()
18 |
19 | self.timestep_embed = FourierFeatures(1, 16)
20 |
21 | attn_layer = depth - n_attn_layers - 1
22 |
23 | block = nn.Identity()
24 |
25 | conv_block = ResConvBlock
26 |
27 | for i in range(depth, 0, -1):
28 | c = c_mults[i - 1]
29 | if i > 1:
30 | c_prev = c_mults[i - 2]
31 | add_attn = i >= attn_layer and n_attn_layers > 0
32 | block = SkipBlock(
33 | Downsample1d("cubic"),
34 | conv_block(c_prev, c, c),
35 | SelfAttention1d(
36 | c, c // 32) if add_attn else nn.Identity(),
37 | conv_block(c, c, c),
38 | SelfAttention1d(
39 | c, c // 32) if add_attn else nn.Identity(),
40 | conv_block(c, c, c),
41 | SelfAttention1d(
42 | c, c // 32) if add_attn else nn.Identity(),
43 | block,
44 | conv_block(c * 2 if i != depth else c, c, c),
45 | SelfAttention1d(
46 | c, c // 32) if add_attn else nn.Identity(),
47 | conv_block(c, c, c),
48 | SelfAttention1d(
49 | c, c // 32) if add_attn else nn.Identity(),
50 | conv_block(c, c, c_prev),
51 | SelfAttention1d(c_prev, c_prev //
52 | 32) if add_attn else nn.Identity(),
53 | Upsample1d(kernel="cubic")
54 | # nn.Upsample(scale_factor=2, mode='linear',
55 | # align_corners=False),
56 | )
57 | else:
58 | block = nn.Sequential(
59 | conv_block(io_channels + 16 + global_args.latent_dim, c, c),
60 | conv_block(c, c, c),
61 | conv_block(c, c, c),
62 | block,
63 | conv_block(c * 2, c, c),
64 | conv_block(c, c, c),
65 | conv_block(c, c, io_channels, is_last=True),
66 | )
67 | self.net = block
68 |
69 | with torch.no_grad():
70 | for param in self.net.parameters():
71 | param *= 0.5
72 |
73 | def forward(self, input, t, cond=None):
74 | timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
75 |
76 | inputs = [input, timestep_embed]
77 |
78 | if cond is not None:
79 | cond = F.interpolate(cond, (input.shape[2], ), mode='linear', align_corners=False)
80 | inputs.append(cond)
81 |
82 | return self.net(torch.cat(inputs, dim=1))
--------------------------------------------------------------------------------
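Note: a minimal sketch of calling DiffusionAttnUnet1D directly, with toy settings chosen only for a quick shape check (the training and inference code in this repo use the default depth of 14 on 65536-sample windows). The model reads only latent_dim from global_args, and the input length must be divisible by 2**(depth - 1):

    import argparse
    import torch
    from audio_diffusion.models import DiffusionAttnUnet1D

    args = argparse.Namespace(latent_dim=0)      # unconditional, as in defaults.ini
    model = DiffusionAttnUnet1D(args, io_channels=2, depth=4, n_attn_layers=0)

    audio = torch.randn(1, 2, 2048)              # stereo, length divisible by 2**(depth - 1)
    t = torch.rand(1)                            # one diffusion timestep per batch item
    v = model(audio, t)                          # predicted velocity, same shape as the input
    print(v.shape)                               # torch.Size([1, 2, 2048])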
/audio_diffusion/utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import warnings
3 |
4 | import torch
5 | from torch import nn
6 | import random
7 | import math
8 | from torch import optim
9 |
10 | def append_dims(x, target_dims):
11 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
12 | dims_to_append = target_dims - x.ndim
13 | if dims_to_append < 0:
14 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
15 | return x[(...,) + (None,) * dims_to_append]
16 |
17 |
18 | def n_params(module):
19 | """Returns the number of trainable parameters in a module."""
20 | return sum(p.numel() for p in module.parameters())
21 |
22 |
23 | @contextmanager
24 | def train_mode(model, mode=True):
25 | """A context manager that places a model into training mode and restores
26 | the previous mode on exit."""
27 | modes = [module.training for module in model.modules()]
28 | try:
29 | yield model.train(mode)
30 | finally:
31 | for i, module in enumerate(model.modules()):
32 | module.training = modes[i]
33 |
34 |
35 | def eval_mode(model):
36 | """A context manager that places a model into evaluation mode and restores
37 | the previous mode on exit."""
38 | return train_mode(model, False)
39 |
40 | @torch.no_grad()
41 | def ema_update(model, averaged_model, decay):
42 | """Incorporates updated model parameters into an exponential moving averaged
43 | version of a model. It should be called after each optimizer step."""
44 | model_params = dict(model.named_parameters())
45 | averaged_params = dict(averaged_model.named_parameters())
46 | assert model_params.keys() == averaged_params.keys()
47 |
48 | for name, param in model_params.items():
49 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
50 |
51 | model_buffers = dict(model.named_buffers())
52 | averaged_buffers = dict(averaged_model.named_buffers())
53 | assert model_buffers.keys() == averaged_buffers.keys()
54 |
55 | for name, buf in model_buffers.items():
56 | averaged_buffers[name].copy_(buf)
57 |
58 |
59 | class EMAWarmup:
60 | """Implements an EMA warmup using an inverse decay schedule.
61 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
62 | good values for models you plan to train for a million or more steps (reaches decay
63 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
64 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
65 | 215.4k steps).
66 | Args:
67 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
68 | power (float): Exponential factor of EMA warmup. Default: 1.
69 | min_value (float): The minimum EMA decay rate. Default: 0.
70 | max_value (float): The maximum EMA decay rate. Default: 1.
71 | start_at (int): The epoch to start averaging at. Default: 0.
72 | last_epoch (int): The index of last epoch. Default: 0.
73 | """
74 |
75 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
76 | last_epoch=0):
77 | self.inv_gamma = inv_gamma
78 | self.power = power
79 | self.min_value = min_value
80 | self.max_value = max_value
81 | self.start_at = start_at
82 | self.last_epoch = last_epoch
83 |
84 | def state_dict(self):
85 | """Returns the state of the class as a :class:`dict`."""
86 | return dict(self.__dict__.items())
87 |
88 | def load_state_dict(self, state_dict):
89 | """Loads the class's state.
90 | Args:
91 | state_dict (dict): scaler state. Should be an object returned
92 | from a call to :meth:`state_dict`.
93 | """
94 | self.__dict__.update(state_dict)
95 |
96 | def get_value(self):
97 | """Gets the current EMA decay rate."""
98 | epoch = max(0, self.last_epoch - self.start_at)
99 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
100 | return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
101 |
102 | def step(self):
103 | """Updates the step count."""
104 | self.last_epoch += 1
105 |
106 |
107 | class InverseLR(optim.lr_scheduler._LRScheduler):
108 | """Implements an inverse decay learning rate schedule with an optional exponential
109 | warmup. When last_epoch=-1, sets initial lr as lr.
110 | inv_gamma is the number of steps/epochs required for the learning rate to decay to
111 | (1 / 2)**power of its original value.
112 | Args:
113 | optimizer (Optimizer): Wrapped optimizer.
114 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
115 | power (float): Exponential factor of learning rate decay. Default: 1.
116 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
117 | Default: 0.
118 | final_lr (float): The final learning rate. Default: 0.
119 | last_epoch (int): The index of last epoch. Default: -1.
120 | verbose (bool): If ``True``, prints a message to stdout for
121 | each update. Default: ``False``.
122 | """
123 |
124 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
125 | last_epoch=-1, verbose=False):
126 | self.inv_gamma = inv_gamma
127 | self.power = power
128 | if not 0. <= warmup < 1:
129 | raise ValueError('Invalid value for warmup')
130 | self.warmup = warmup
131 | self.final_lr = final_lr
132 | super().__init__(optimizer, last_epoch, verbose)
133 |
134 | def get_lr(self):
135 | if not self._get_lr_called_within_step:
136 | warnings.warn("To get the last learning rate computed by the scheduler, "
137 | "please use `get_last_lr()`.")
138 |
139 | return self._get_closed_form_lr()
140 |
141 | def _get_closed_form_lr(self):
142 | warmup = 1 - self.warmup ** (self.last_epoch + 1)
143 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
144 | return [warmup * max(self.final_lr, base_lr * lr_mult)
145 | for base_lr in self.base_lrs]
146 |
147 |
148 | # Define the diffusion noise schedule
149 | def get_alphas_sigmas(t):
150 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
151 |
152 | def append_dims(x, target_dims):
153 | """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
154 | dims_to_append = target_dims - x.ndim
155 | if dims_to_append < 0:
156 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
157 | return x[(...,) + (None,) * dims_to_append]
158 |
159 | def expand_to_planes(input, shape):
160 | return input[..., None].repeat([1, 1, shape[2]])
161 |
162 | class PadCrop(nn.Module):
163 | def __init__(self, n_samples, randomize=True):
164 | super().__init__()
165 | self.n_samples = n_samples
166 | self.randomize = randomize
167 |
168 | def __call__(self, signal):
169 | n, s = signal.shape
170 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
171 | end = start + self.n_samples
172 | output = signal.new_zeros([n, self.n_samples])
173 | output[:, :min(s, self.n_samples)] = signal[:, start:end]
174 | return output
175 |
176 | class RandomPhaseInvert(nn.Module):
177 | def __init__(self, p=0.5):
178 | super().__init__()
179 | self.p = p
180 | def __call__(self, signal):
181 | return -signal if (random.random() < self.p) else signal
182 |
183 | class Stereo(nn.Module):
184 | def __call__(self, signal):
185 | signal_shape = signal.shape
186 | # Check if it's mono
187 | if len(signal_shape) == 1: # s -> 2, s
188 | signal = signal.unsqueeze(0).repeat(2, 1)
189 | elif len(signal_shape) == 2:
190 | if signal_shape[0] == 1: #1, s -> 2, s
191 | signal = signal.repeat(2, 1)
192 | elif signal_shape[0] > 2: #?, s -> 2,s
193 | signal = signal[:2, :]
194 |
195 | return signal
196 |
--------------------------------------------------------------------------------
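Note: a small sketch of the data transforms and the cosine noise schedule above, assuming the package is importable. PadCrop zero-pads or crops to a fixed length, Stereo duplicates mono to two channels, and get_alphas_sigmas satisfies alpha**2 + sigma**2 = 1 at every timestep:

    import torch
    from audio_diffusion.utils import PadCrop, Stereo, get_alphas_sigmas

    mono = torch.randn(1, 12345)                     # shorter than the target window
    cropped = PadCrop(65536, randomize=False)(mono)  # zero-padded to (1, 65536)
    stereo = Stereo()(cropped)                       # duplicated to (2, 65536)
    print(cropped.shape, stereo.shape)

    t = torch.linspace(0, 1, 5)
    alphas, sigmas = get_alphas_sigmas(t)
    print(alphas ** 2 + sigmas ** 2)                 # numerically all ones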
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | # set to true if your model requires a GPU
6 | gpu: true
7 |
8 | # a list of ubuntu apt packages to install
9 | system_packages:
10 | - "curl"
11 | - "ffmpeg"
12 |
13 | # python version in the form '3.8' or '3.8.12'
14 | python_version: "3.8"
15 |
16 |   # a list of packages in the format <package-name>==<version>
17 | python_packages:
18 | - "numpy==1.21.6"
19 | - "torchvision==0.10.0"
20 | - "torch==1.9.0"
21 | - "torchaudio==0.9.0"
22 |
23 | # commands run after the environment is setup
24 | run:
25 | - pip install --upgrade pip
26 | - pip install einops pandas prefigure pytorch_lightning scipy torch tqdm wandb
27 | - mkdir /models
28 | - git clone https://github.com/harmonai-org/sample-generator
29 | - git clone --recursive https://github.com/crowsonkb/v-diffusion-pytorch
30 | - pip install /sample-generator
31 | - pip install /v-diffusion-pytorch
32 | - pip install ipywidgets==7.7.1
33 | - pip install matplotlib soundfile
34 | - wget -O /models/gwf-440k.ckpt https://model-server.zqevans2.workers.dev/gwf-440k.ckpt > /dev/null 2>&1
35 | - wget -O /models/jmann-large-580k.ckpt https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt > /dev/null 2>&1
36 | - wget -O /models/maestro-uncond-150k.ckpt https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt > /dev/null 2>&1
37 | - wget -O /models/unlocked-uncond-250k.ckpt https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt > /dev/null 2>&1
38 | predict: "predict.py:Predictor"
39 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pollinations/dance-diffusion/1a8eb27c2c985920575b3908530c7478fe9200b7/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torchaudio import transforms as T
4 | import random
5 | from glob import glob
6 | import os
7 | from audio_diffusion.utils import Stereo, PadCrop, RandomPhaseInvert
8 | import tqdm
9 | from multiprocessing import Pool, cpu_count
10 | from functools import partial
11 |
12 | class SampleDataset(torch.utils.data.Dataset):
13 | def __init__(self, paths, global_args):
14 | super().__init__()
15 | self.filenames = []
16 |
17 | print(f"Random crop: {global_args.random_crop}")
18 |
19 | self.augs = torch.nn.Sequential(
20 | PadCrop(global_args.sample_size, randomize=global_args.random_crop),
21 | RandomPhaseInvert(),
22 | )
23 |
24 | self.encoding = torch.nn.Sequential(
25 | Stereo()
26 | )
27 |
28 | for path in paths:
29 | for ext in ['wav','flac','ogg','aiff','aif','mp3']:
30 | self.filenames += glob(f'{path}/**/*.{ext}', recursive=True)
31 |
32 | self.sr = global_args.sample_rate
33 | if hasattr(global_args,'load_frac'):
34 | self.load_frac = global_args.load_frac
35 | else:
36 | self.load_frac = 1.0
37 | self.num_gpus = global_args.num_gpus
38 |
39 | self.cache_training_data = global_args.cache_training_data
40 |
41 | if self.cache_training_data: self.preload_files()
42 |
43 |
44 | def load_file(self, filename):
45 | audio, sr = torchaudio.load(filename)
46 | if sr != self.sr:
47 | resample_tf = T.Resample(sr, self.sr)
48 | audio = resample_tf(audio)
49 | return audio
50 |
51 | def load_file_ind(self, file_list,i): # used when caching training data
52 | return self.load_file(file_list[i]).cpu()
53 |
54 | def get_data_range(self): # for parallel runs, only grab part of the data
55 | start, stop = 0, len(self.filenames)
56 | try:
57 | local_rank = int(os.environ["LOCAL_RANK"])
58 | world_size = int(os.environ["WORLD_SIZE"])
59 | interval = stop//world_size
60 | start, stop = local_rank*interval, (local_rank+1)*interval
61 | print("local_rank, world_size, start, stop =",local_rank, world_size, start, stop)
62 | return start, stop
63 | #rank = os.environ["RANK"]
64 | except KeyError as e: # we're on GPU 0 and the others haven't been initialized yet
65 | start, stop = 0, len(self.filenames)//self.num_gpus
66 | return start, stop
67 |
68 | def preload_files(self):
69 | n = int(len(self.filenames)*self.load_frac)
70 | print(f"Caching {n} input audio files:")
71 | wrapper = partial(self.load_file_ind, self.filenames)
72 | start, stop = self.get_data_range()
73 |         with Pool(processes=cpu_count()) as p:  # one worker per CPU core; reduce this if file I/O becomes the bottleneck
74 | self.audio_files = list(tqdm.tqdm(p.imap(wrapper, range(start,stop)), total=stop-start))
75 |
76 | def __len__(self):
77 | return len(self.filenames)
78 |
79 | def __getitem__(self, idx):
80 | audio_filename = self.filenames[idx]
81 | try:
82 | if self.cache_training_data:
83 | audio = self.audio_files[idx] # .copy()
84 | else:
85 | audio = self.load_file(audio_filename)
86 |
87 | #Run augmentations on this sample (including random crop)
88 | if self.augs is not None:
89 | audio = self.augs(audio)
90 |
91 | audio = audio.clamp(-1, 1)
92 |
93 | #Encode the file to assist in prediction
94 | if self.encoding is not None:
95 | audio = self.encoding(audio)
96 |
97 | return (audio, audio_filename)
98 | except Exception as e:
99 | # print(f'Couldn\'t load file {audio_filename}: {e}')
100 | return self[random.randrange(len(self))]
--------------------------------------------------------------------------------
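Note: a minimal sketch of building SampleDataset outside the training script. The directory path is a placeholder, and the Namespace carries just the attributes the class reads (in train_uncond.py these come from defaults.ini via prefigure):

    import argparse
    from torch.utils import data
    from dataset.dataset import SampleDataset

    args = argparse.Namespace(
        sample_size=65536,           # samples per training example
        sample_rate=48000,           # resample target
        random_crop=True,            # random start offset in PadCrop
        num_gpus=1,
        cache_training_data=False,   # set True to preload all audio into RAM
    )

    # Point this at a real directory of audio; wav/flac/ogg/aiff/aif/mp3 are globbed recursively
    train_set = SampleDataset(["/path/to/audio"], args)
    train_dl = data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)
    audio, filename = train_set[0]   # (2, 65536) stereo tensor and its source path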
/defaults.ini:
--------------------------------------------------------------------------------
1 |
2 | [DEFAULTS]
3 |
4 | #name of the run
5 | name = dd-finetune
6 |
7 | # training data directory
8 | training_dir = ''
9 |
10 | # the batch size
11 | batch_size = 8
12 |
13 | # number of GPUs to use for training
14 | num_gpus = 1
15 |
16 | # number of nodes to use for training
17 | num_nodes = 1
18 |
19 | # number of CPU workers for the DataLoader
20 | num_workers = 2
21 |
22 | # Number of audio samples for the training input
23 | sample_size = 65536
24 |
25 | # Number of steps between demos
26 | demo_every = 1000
27 |
28 | # Number of denoising steps for the demos
29 | demo_steps = 250
30 |
31 | # Number of demos to create
32 | num_demos = 16
33 |
34 | # the EMA decay
35 | ema_decay = 0.995
36 |
37 | # the random seed
38 | seed = 42
39 |
40 | # Batches for gradient accumulation
41 | accum_batches = 4
42 |
43 | # The sample rate of the audio
44 | sample_rate = 48000
45 |
46 | # Number of steps between checkpoints
47 | checkpoint_every = 10000
48 |
49 | # unused, required by the model code
50 | latent_dim = 0
51 |
52 | # If true training data is kept in RAM
53 | cache_training_data = False
54 |
55 | # randomly crop input audio? (for augmentation)
56 | random_crop = True
57 |
58 | # checkpoint file to (re)start training from
59 | ckpt_path = ''
60 |
61 | # Path to output the model checkpoints
62 | save_path = ''
63 |
64 | #the multiprocessing start method ['fork', 'forkserver', 'spawn']
65 | start_method = 'spawn'
66 |
--------------------------------------------------------------------------------
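Note: every key above can be overridden on the command line in its dashed form, which is how the fine-tune notebook calls train_uncond.py; for example, passing `--training-dir /my/audio --batch-size 4 --random-crop True` (the path is a placeholder) overrides training_dir, batch_size and random_crop while the remaining keys keep the defaults in this file.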
/meta.json:
--------------------------------------------------------------------------------
1 | {
2 | "featured": true,
3 | "url": "pollinations/dance-diffusion",
4 | "key": "614871946825.dkr.ecr.us-east-1.amazonaws.com/pollinations/dance-diffusion",
5 | "name": "Dance Diffusion",
6 | "img": "https://images.squarespace-cdn.com/content/v1/62b318b33406753bcd6a17a5/ffa25141-8195-4916-8acf-7d5a44f08dfe/Transparent_Harmonai+Logo-02+%281%29.png?format=1500w",
7 | "path": "dancediffusion",
8 | "category": "8 Audio",
9 | "credits": "[harmonai.org](https://harmonai.org)",
10 |   "description": "\n\n\nDance Diffusion is the first in a suite of generative audio tools for producers and musicians to be released by Harmonai. For more info or to get involved in the development of these tools, please visit https://harmonai.org and fill out the form on the front page.",
11 | "pollinator_group": [
12 | "a100"
13 | ]
14 | }
15 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # Prediction interface for Cog ⚙️
2 | # https://github.com/replicate/cog/blob/main/docs/python.md
3 |
4 | import gc
5 | import hashlib
6 | import json
7 | import math
8 | import os
9 | import random
10 | import sys
11 | from contextlib import contextmanager
12 | from copy import deepcopy
13 | from glob import glob
14 | from pathlib import Path
15 |
16 | from urllib.parse import urlparse
17 |
18 | import IPython.display as ipd
19 | import matplotlib.pyplot as plt
20 | import numpy as np
21 | import soundfile
22 | import torch
23 | import torchaudio
24 | import wandb
25 | from cog import BasePredictor, Input, Path
26 | from diffusion import sampling
27 | from einops import rearrange
28 | from prefigure.prefigure import get_all_args
29 | from torch import nn, optim
30 | from torch.nn import functional as F
31 | from torch.utils import data
32 | from tqdm import trange
33 |
34 | from audio_diffusion.models import DiffusionAttnUnet1D
35 | from audio_diffusion.utils import PadCrop, Stereo
36 |
37 | #@title Args
38 |
39 | latent_dim = 0
40 |
41 |
42 |
43 | def wget(url, outputdir):
44 | # Using the !wget command instead of the subprocess to get the loading bar
45 | os.system(f"wget {url} -O {outputdir}")
46 | # res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')
47 | # print(res)
48 |
49 | def report_status(**kwargs):
50 | status = json.dumps(kwargs)
51 | print(f"pollen_status: {status}")
52 |
53 |
54 | # #@markdown Number of steps (100 is a good start, more steps trades off speed for quality)
55 | # steps = 100 #@param {type:"number"}
56 |
57 | class Predictor(BasePredictor):
58 | def setup(self):
59 | """Load the model into memory to make running multiple predictions efficient"""
60 | # self.model = torch.load("./weights.pth")
61 | self.loaded_model_fn = None
62 | self.loaded_model_name = None
63 | os.system("ls -l /models")
64 | def predict(
65 | self,
66 | model_name: str = Input(description="Model", default = "maestro-150k", choices=["glitch-440k", "jmann-large-580k", "maestro-150k", "unlocked-250k"]),
67 | length: float = Input(description="Number of seconds to generate", default=8),
68 | batch_size: int = Input(description="How many samples to generate", default=1),
69 | steps: int = Input(description="Number of steps, higher numbers will give more refined output but will take longer. The maximum is 150.", default=100),
70 | ) -> Path:
71 | """Run a single prediction on the model"""
72 |
73 | # JSON encode {title: "Pimping your prompt", payload: prompt }
74 | #report_status(title="Translating", payload=prompt)
75 |
76 | args = Object()
77 |
78 | args.latent_dim = latent_dim
79 |
80 | #@title Create the model
81 | model_path = "/models"
82 |
83 | model_info = models_map[model_name]
84 | args.sample_rate = model_info["sample_rate"]
85 | args.sample_size = int(((args.sample_rate * length) // 8192) * 8192)
86 | print("sample_size", args.sample_size)
87 |
88 |
89 |         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90 |         if self.loaded_model_name != model_name:
91 |             download_model(model_name, 0, model_path)
92 |             ckpt_path = f'{model_path}/{get_model_filename(model_name)}'
93 |             print("Creating the model...")
94 |             model = DiffusionUncond(args)
95 |             model.load_state_dict(torch.load(ckpt_path)["state_dict"])
96 |             model = model.requires_grad_(False).to(device)
97 |             # Remove the non-EMA weights; only the EMA copy is used for sampling
98 |             del model.diffusion
99 |             self.loaded_model_fn = model.diffusion_ema
100 |             self.loaded_model_name = model_name
101 |             print("Model created")
102 | 
103 |         model_fn = self.loaded_model_fn
104 |
105 |
106 |
107 | torch.cuda.empty_cache()
108 | gc.collect()
109 |
110 | # Generate random noise to sample from
111 | noise = torch.randn([batch_size, 2, args.sample_size]).to(device)
112 |
113 | t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
114 | step_list = get_crash_schedule(t)
115 |
116 | # Generate the samples from the noise
117 | generated = sampling.iplms_sample(model_fn, noise, step_list, {})
118 |
119 | # Hard-clip the generated audio
120 | generated = generated.clamp(-1, 1)
121 |
122 | print("All samples")
123 | # plot_and_hear(generated_all, args.sample_rate)
124 | samples = []
125 | for ix, gen_sample in enumerate(generated):
126 | print(f'sample #{ix + 1}')
127 | #audio = ipd.Audio(gen_sample.cpu(), rate=args.sample_rate)
128 | print(gen_sample.shape)
129 | samples.append(gen_sample)
130 |
131 |
132 |
133 | # concatenate the samples
134 | samples = torch.cat(samples, dim=1)
135 | # save to disk (format is c n)
136 | soundfile.write(f'/tmp/sample.wav', samples.permute(1,0).cpu().numpy(), args.sample_rate)
137 | return Path(f"/tmp/sample.wav")
138 |
139 |
140 |
141 | #@title Model code
142 | class DiffusionUncond(nn.Module):
143 | def __init__(self, global_args):
144 | super().__init__()
145 |
146 | self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers = 4)
147 | self.diffusion_ema = deepcopy(self.diffusion)
148 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
149 |
150 | import IPython.display as ipd
151 | import matplotlib.pyplot as plt
152 |
153 |
154 | def plot_and_hear(audio, sr):
155 | display(ipd.Audio(audio.cpu().clamp(-1, 1), rate=sr))
156 | plt.plot(audio.cpu().t().numpy())
157 |
158 | def load_to_device(path, sr):
159 | audio, file_sr = torchaudio.load(path)
160 | if sr != file_sr:
161 | audio = torchaudio.transforms.Resample(file_sr, sr)(audio)
162 |     audio = audio.to('cuda' if torch.cuda.is_available() else 'cpu')
163 | return audio
164 |
165 | def get_alphas_sigmas(t):
166 | """Returns the scaling factors for the clean image (alpha) and for the
167 | noise (sigma), given a timestep."""
168 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
169 |
170 | def get_crash_schedule(t):
171 | sigma = torch.sin(t * math.pi / 2) ** 2
172 | alpha = (1 - sigma ** 2) ** 0.5
173 | return alpha_sigma_to_t(alpha, sigma)
174 |
175 | def t_to_alpha_sigma(t):
176 | """Returns the scaling factors for the clean image and for the noise, given
177 | a timestep."""
178 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
179 |
180 | def alpha_sigma_to_t(alpha, sigma):
181 | """Returns a timestep, given the scaling factors for the clean image and for
182 | the noise."""
183 | return torch.atan2(sigma, alpha) / math.pi * 2
184 |
185 | class Object(object):
186 | pass
187 |
188 |
189 |
190 | #@title Logging
191 | def get_one_channel(audio_data, channel):
192 | '''
193 | Takes a numpy audio array and returns 1 channel
194 | '''
195 | # Check if the audio has more than 1 channel
196 | if len(audio_data.shape) > 1:
197 | is_stereo = True
198 | if np.argmax(audio_data.shape)==0:
199 | audio_data = audio_data[:,channel]
200 | else:
201 | audio_data = audio_data[channel,:]
202 | else:
203 | is_stereo = False
204 |
205 | return audio_data
206 |
207 |
208 |
209 |
210 |
211 | def get_model_filename(diffusion_model_name):
212 | model_uri = models_map[diffusion_model_name]['uri_list'][0]
213 | model_filename = os.path.basename(urlparse(model_uri).path)
214 | return model_filename
215 |
216 | def download_model(diffusion_model_name, uri_index=0, model_path='/models'):
217 | if diffusion_model_name != 'custom':
218 | model_filename = get_model_filename(diffusion_model_name)
219 | model_local_path = os.path.join(model_path, model_filename)
220 |
221 |
222 | if not models_map[diffusion_model_name]['downloaded']:
223 | for model_uri in models_map[diffusion_model_name]['uri_list']:
224 | wget(model_uri, model_local_path)
225 | with open(model_local_path, "rb") as f:
226 | bytes = f.read()
227 | hash = hashlib.sha256(bytes).hexdigest()
228 | print(f'SHA: {hash}')
229 | if os.path.exists(model_local_path):
230 | models_map[diffusion_model_name]['downloaded'] = True
231 | return
232 | else:
233 | print(f'{diffusion_model_name} model download from {model_uri} failed. Will try any fallback uri.')
234 | print(f'{diffusion_model_name} download failed.')
235 |
236 |
237 | models_map = {
238 |
239 | "glitch-440k": {'downloaded': True,
240 | 'sha': "48caefdcbb7b15e1a0b3d08587446936302535de74b0e05e0d61beba865ba00a",
241 | 'uri_list': ["https://model-server.zqevans2.workers.dev/gwf-440k.ckpt"],
242 | 'sample_rate': 48000,
243 | 'sample_size': 65536
244 | },
245 | "jmann-small-190k": {'downloaded': False,
246 | 'sha': "1e2a23a54e960b80227303d0495247a744fa1296652148da18a4da17c3784e9b",
247 | 'uri_list': ["https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt"],
248 | 'sample_rate': 48000,
249 | 'sample_size': 65536
250 | },
251 | "jmann-large-580k": {'downloaded': True,
252 | 'sha': "6b32b5ff1c666c4719da96a12fd15188fa875d6f79f8dd8e07b4d54676afa096",
253 | 'uri_list': ["https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt"],
254 | 'sample_rate': 48000,
255 | 'sample_size': 131072
256 | },
257 | "maestro-150k": {'downloaded': True,
258 | 'sha': "49d9abcae642e47c2082cec0b2dce95a45dc6e961805b6500204e27122d09485",
259 | 'uri_list': ["https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt"],
260 | 'sample_rate': 16000,
261 | 'sample_size': 65536
262 | },
263 | "unlocked-250k": {'downloaded': True,
264 | 'sha': "af337c8416732216eeb52db31dcc0d49a8d48e2b3ecaa524cb854c36b5a3503a",
265 | 'uri_list': ["https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt"],
266 | 'sample_rate': 16000,
267 | 'sample_size': 65536
268 | },
269 | "honk-140k": {'downloaded': False,
270 | 'sha': "a66847844659d287f55b7adbe090224d55aeafdd4c2b3e1e1c6a02992cb6e792",
271 | 'uri_list': ["https://model-server.zqevans2.workers.dev/honk-140k.ckpt"],
272 | 'sample_rate': 16000,
273 | 'sample_size': 65536
274 | },
275 | }
276 |
--------------------------------------------------------------------------------
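Note: a minimal sketch of exercising the Predictor outside the Cog runtime, assuming the dependencies from cog.yaml are installed, a GPU is available, and the checkpoints have been downloaded to /models as the build step does:

    from predict import Predictor

    predictor = Predictor()
    predictor.setup()
    wav_path = predictor.predict(
        model_name="maestro-150k",   # one of the choices defined in models_map
        length=8,                    # seconds, rounded down to a multiple of 8192 samples
        batch_size=1,
        steps=100,
    )
    print(wav_path)                  # /tmp/sample.wav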
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='audio-diffusion',
5 | version='1.0.0',
6 | url='https://github.com/zqevans/audio-diffusion.git',
7 | author='Zach Evans',
8 | packages=find_packages(),
9 | install_requires=[
10 | 'einops',
11 | 'pandas',
12 | 'prefigure',
13 | 'pytorch_lightning',
14 | 'scipy',
15 | 'torch',
16 | 'torchaudio',
17 | 'tqdm',
18 | 'wandb',
19 | ],
20 | )
--------------------------------------------------------------------------------
/train_uncond.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from prefigure.prefigure import get_all_args, push_wandb_config
4 | from contextlib import contextmanager
5 | from copy import deepcopy
6 | import math
7 | from pathlib import Path
8 |
9 | import sys
10 | import torch
11 | from torch import optim, nn
12 | from torch.nn import functional as F
13 | from torch.utils import data
14 | from tqdm import trange
15 | import pytorch_lightning as pl
16 | from pytorch_lightning.utilities.distributed import rank_zero_only
17 | from einops import rearrange
18 | import torchaudio
19 | import wandb
20 |
21 | from dataset.dataset import SampleDataset
22 |
23 | from audio_diffusion.models import DiffusionAttnUnet1D
24 | from audio_diffusion.utils import ema_update
25 | from viz.viz import audio_spectrogram_image
26 |
27 |
28 | # Define the noise schedule and sampling loop
29 | def get_alphas_sigmas(t):
30 | """Returns the scaling factors for the clean image (alpha) and for the
31 | noise (sigma), given a timestep."""
32 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
33 |
34 | def get_crash_schedule(t):
35 | sigma = torch.sin(t * math.pi / 2) ** 2
36 | alpha = (1 - sigma ** 2) ** 0.5
37 | return alpha_sigma_to_t(alpha, sigma)
38 |
39 | def alpha_sigma_to_t(alpha, sigma):
40 | """Returns a timestep, given the scaling factors for the clean image and for
41 | the noise."""
42 | return torch.atan2(sigma, alpha) / math.pi * 2
43 |
44 | @torch.no_grad()
45 | def sample(model, x, steps, eta):
46 | """Draws samples from a model given starting noise."""
47 | ts = x.new_ones([x.shape[0]])
48 |
49 | # Create the noise schedule
50 | t = torch.linspace(1, 0, steps + 1)[:-1]
51 |
52 | t = get_crash_schedule(t)
53 |
54 | alphas, sigmas = get_alphas_sigmas(t)
55 |
56 | # The sampling loop
57 | for i in trange(steps):
58 |
59 | # Get the model output (v, the predicted velocity)
60 | with torch.cuda.amp.autocast():
61 | v = model(x, ts * t[i]).float()
62 |
63 | # Predict the noise and the denoised image
64 | pred = x * alphas[i] - v * sigmas[i]
65 | eps = x * sigmas[i] + v * alphas[i]
66 |
67 | # If we are not on the last timestep, compute the noisy image for the
68 | # next timestep.
69 | if i < steps - 1:
70 | # If eta > 0, adjust the scaling factor for the predicted noise
71 | # downward according to the amount of additional noise to add
72 | ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
73 | (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
74 | adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
75 |
76 | # Recombine the predicted noise and predicted denoised image in the
77 | # correct proportions for the next step
78 | x = pred * alphas[i + 1] + eps * adjusted_sigma
79 |
80 | # Add the correct amount of fresh noise
81 | if eta:
82 | x += torch.randn_like(x) * ddim_sigma
83 |
84 | # If we are on the last timestep, output the denoised image
85 | return pred
86 |
87 |
88 |
89 | class DiffusionUncond(pl.LightningModule):
90 | def __init__(self, global_args):
91 | super().__init__()
92 |
93 | self.diffusion = DiffusionAttnUnet1D(global_args, io_channels=2, n_attn_layers=4)
94 | self.diffusion_ema = deepcopy(self.diffusion)
95 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=global_args.seed)
96 | self.ema_decay = global_args.ema_decay
97 |
98 | def configure_optimizers(self):
99 | return optim.Adam([*self.diffusion.parameters()], lr=4e-5)
100 |
101 | def training_step(self, batch, batch_idx):
102 | reals = batch[0]
103 |
104 | # Draw uniformly distributed continuous timesteps
105 | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
106 |
107 | t = get_crash_schedule(t)
108 |
109 | # Calculate the noise schedule parameters for those timesteps
110 | alphas, sigmas = get_alphas_sigmas(t)
111 |
112 | # Combine the ground truth images and the noise
113 | alphas = alphas[:, None, None]
114 | sigmas = sigmas[:, None, None]
115 | noise = torch.randn_like(reals)
116 | noised_reals = reals * alphas + noise * sigmas
117 | targets = noise * alphas - reals * sigmas
118 |
119 | with torch.cuda.amp.autocast():
120 | v = self.diffusion(noised_reals, t)
121 | mse_loss = F.mse_loss(v, targets)
122 | loss = mse_loss
123 |
124 | log_dict = {
125 | 'train/loss': loss.detach(),
126 | 'train/mse_loss': mse_loss.detach(),
127 | }
128 |
129 | self.log_dict(log_dict, prog_bar=True, on_step=True)
130 | return loss
131 |
132 | def on_before_zero_grad(self, *args, **kwargs):
133 | decay = 0.95 if self.current_epoch < 25 else self.ema_decay
134 | ema_update(self.diffusion, self.diffusion_ema, decay)
135 |
136 | class ExceptionCallback(pl.Callback):
137 | def on_exception(self, trainer, module, err):
138 | print(f'{type(err).__name__}: {err}', file=sys.stderr)
139 |
140 |
141 | class DemoCallback(pl.Callback):
142 | def __init__(self, global_args):
143 | super().__init__()
144 | self.demo_every = global_args.demo_every
145 | self.num_demos = global_args.num_demos
146 | self.demo_samples = global_args.sample_size
147 | self.demo_steps = global_args.demo_steps
148 | self.sample_rate = global_args.sample_rate
149 | self.last_demo_step = -1
150 |
151 | @rank_zero_only
152 | @torch.no_grad()
153 | #def on_train_epoch_end(self, trainer, module):
154 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
155 |
156 | if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
157 | return
158 |
159 | self.last_demo_step = trainer.global_step
160 |
161 | noise = torch.randn([self.num_demos, 2, self.demo_samples]).to(module.device)
162 |
163 | try:
164 | fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
165 |
166 | # Put the demos together
167 | fakes = rearrange(fakes, 'b d n -> d (b n)')
168 |
169 | log_dict = {}
170 |
171 | filename = f'demo_{trainer.global_step:08}.wav'
172 | fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
173 | torchaudio.save(filename, fakes, self.sample_rate)
174 |
175 |
176 | log_dict[f'demo'] = wandb.Audio(filename,
177 | sample_rate=self.sample_rate,
178 | caption=f'Demo')
179 |
180 | log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
181 |
182 | trainer.logger.experiment.log(log_dict, step=trainer.global_step)
183 | except Exception as e:
184 | print(f'{type(e).__name__}: {e}', file=sys.stderr)
185 |
186 | def main():
187 |
188 | args = get_all_args()
189 |
190 | args.latent_dim = 0
191 |
192 | save_path = None if args.save_path == "" else args.save_path
193 |
194 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
195 | print('Using device:', device)
196 | torch.manual_seed(args.seed)
197 |
198 | train_set = SampleDataset([args.training_dir], args)
199 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True,
200 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True)
201 | wandb_logger = pl.loggers.WandbLogger(project=args.name)
202 |
203 | exc_callback = ExceptionCallback()
204 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1, dirpath=save_path)
205 | demo_callback = DemoCallback(args)
206 |
207 | diffusion_model = DiffusionUncond(args)
208 |
209 | wandb_logger.watch(diffusion_model)
210 | push_wandb_config(wandb_logger, args)
211 |
212 | diffusion_trainer = pl.Trainer(
213 | gpus=args.num_gpus,
214 | accelerator="gpu",
215 | # num_nodes = args.num_nodes,
216 | # strategy='ddp',
217 | precision=16,
218 | accumulate_grad_batches=args.accum_batches,
219 | callbacks=[ckpt_callback, demo_callback, exc_callback],
220 | logger=wandb_logger,
221 | log_every_n_steps=1,
222 | max_epochs=10000000,
223 | )
224 |
225 | diffusion_trainer.fit(diffusion_model, train_dl, ckpt_path=args.ckpt_path)
226 |
227 | if __name__ == '__main__':
228 | main()
229 |
230 |
--------------------------------------------------------------------------------
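Note: a quick smoke test of the sampling loop above with untrained weights, assuming the repo's pinned dependencies are installed (a GPU is strongly recommended for the default depth-14 network). The Namespace mirrors only the keys DiffusionUncond and DiffusionAttnUnet1D read:

    import argparse
    import torch
    from train_uncond import DiffusionUncond, sample

    args = argparse.Namespace(latent_dim=0, seed=42, ema_decay=0.995)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    module = DiffusionUncond(args).to(device)

    noise = torch.randn(1, 2, 65536, device=device)   # one stereo window of pure noise
    demo = sample(module.diffusion_ema, noise, steps=10, eta=0)
    print(demo.shape)                                 # torch.Size([1, 2, 65536])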
/viz/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pollinations/dance-diffusion/1a8eb27c2c985920575b3908530c7478fe9200b7/viz/__init__.py
--------------------------------------------------------------------------------
/viz/viz.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | from pathlib import Path
4 | from matplotlib.backends.backend_agg import FigureCanvasAgg
5 | import matplotlib.cm as cm
6 | import matplotlib.pyplot as plt
7 | from matplotlib.colors import Normalize
8 | from matplotlib.figure import Figure
9 | import numpy as np
10 | from PIL import Image
11 |
12 | import torch
13 | from torch import optim, nn
14 | from torch.nn import functional as F
15 | import torchaudio
16 | import torchaudio.transforms as T
17 | import librosa
18 | from einops import rearrange
19 |
20 | import wandb
21 | import numpy as np
22 | import pandas as pd
23 |
24 | def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None, db_range=[35,120]):
25 | """
26 | # cf. https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html
27 |
28 | """
29 | fig = Figure(figsize=(5, 4), dpi=100)
30 | canvas = FigureCanvasAgg(fig)
31 | axs = fig.add_subplot()
32 | axs.set_title(title or 'Spectrogram (db)')
33 | axs.set_ylabel(ylabel)
34 | axs.set_xlabel('frame')
35 | im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect, vmin=db_range[0], vmax=db_range[1])
36 | if xmax:
37 | axs.set_xlim((0, xmax))
38 | fig.colorbar(im, ax=axs)
39 | canvas.draw()
40 | rgba = np.asarray(canvas.buffer_rgba())
41 | return Image.fromarray(rgba)
42 |
43 |
44 | def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000):
45 | """
46 | # cf. https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html
47 | """
48 | n_fft = 1024
49 | win_length = None
50 | hop_length = 512
51 | n_mels = 80
52 |
53 | mel_spectrogram_op = T.MelSpectrogram(
54 | sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
55 | hop_length=hop_length, center=True, pad_mode="reflect", power=power,
56 | norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk")
57 |
58 | melspec = mel_spectrogram_op(waveform.float())
59 | melspec = melspec[0] # TODO: only left channel for now
60 | return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)')
61 |
--------------------------------------------------------------------------------
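Note: a minimal usage sketch for the spectrogram helper above, assuming librosa and matplotlib are available in the environment (neither is listed in setup.py). Only the left channel is rendered, per the TODO in audio_spectrogram_image:

    import torch
    from viz.viz import audio_spectrogram_image

    waveform = torch.randn(2, 48000)                                # 1 s of stereo noise at 48 kHz
    image = audio_spectrogram_image(waveform, sample_rate=48000)    # returns a PIL.Image
    image.save("demo_melspec.png")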