├── .gitignore ├── FUNDING.yml ├── LICENSE ├── README.md ├── activate_venv.bat ├── audio_diffusion ├── __init__.py ├── blocks.py ├── models.py └── utils.py ├── dataset ├── __init__.py └── dataset.py ├── defaults.ini ├── example_launch_command.txt ├── make_audio_chunks.ipynb ├── setup.py ├── train_latent_cond.py ├── training_args.md ├── utils └── patch_bnb.py ├── viz ├── __init__.py └── viz.py └── windows_setup.cmd /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Dataset 132 | audio_chunks/ 133 | 134 | # Models 135 | models/ 136 | 137 | # Demos 138 | demo_* -------------------------------------------------------------------------------- /FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: serp-ai 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Harmonai-org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README 🎁 2 | 3 | # About 4 | 5 | This repo is for text-to-audio latent diffusion utilizing a denoising unet and Meta's Encodec. The unet is trained to denoise Encodec's encoded codebooks while taking in t5 text embeddings as conditioning. Encodec's decoder can then take the denoised codebooks and decode them back into an uncompressed .wav file. 6 | 7 | The architecture is by no means perfect as it is being actively tested/worked on. If you have any suggestions for improvements to try, please don't hesitate to let us know!
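As a rough mental model of how the pieces are meant to fit together at inference time, here is a minimal sketch. It is **not** this repo's inference script (those are still coming, see below): the `t5-base` checkpoint name, the UNet hyperparameters, the v-objective DDIM-style sampler, and the rounding of the denoised latent back to integer codes are all assumptions and may differ from what `train_latent_cond.py` actually does.

```python
import torch
from encodec import EncodecModel
from transformers import T5Tokenizer, T5EncoderModel
from audio_diffusion.models import UNetModel
from audio_diffusion.utils import get_alphas_sigmas

# Text conditioning: T5 encoder hidden states are fed to the unet via cross-attention.
tokenizer = T5Tokenizer.from_pretrained("t5-base")           # assumed checkpoint
text_encoder = T5EncoderModel.from_pretrained("t5-base")
tokens = tokenizer(["a slow, echoing piano melody"], return_tensors="pt")
context = text_encoder(**tokens).last_hidden_state           # [B, seq_len, 768]

# Audio codec: Encodec at 24 kHz; 3 kbps target bandwidth gives 4 codebooks (matches defaults.ini).
codec = EncodecModel.encodec_model_24khz()
codec.set_target_bandwidth(3.0)

# Denoiser: the hyperparameters below are placeholders, not the values used for training.
unet = UNetModel(
    sample_size=4096, in_channels=4, out_channels=4, model_channels=128,
    num_res_blocks=2, attention_resolutions=(4, 8), channel_mult=(1, 2, 4, 8),
    num_head_channels=32, use_audio_transformer=True, context_dim=context.shape[-1],
)
# unet.load_state_dict(...)  # a trained checkpoint would be loaded here

# A minimal sampler, assuming the unet predicts "v" (as in dance diffusion / k-diffusion).
# 4096 latent frames ~= 1,310,720 samples / Encodec's 320-sample hop, i.e. the default sample_size.
x = torch.randn(1, 4, 4096)                                   # start from pure noise
steps = torch.linspace(1.0, 0.0, 51)
with torch.no_grad():
    for t_now, t_next in zip(steps[:-1], steps[1:]):
        alpha, sigma = get_alphas_sigmas(t_now)
        v = unet(x, timesteps=t_now[None], context=context)
        pred_x0 = x * alpha - v * sigma                       # predicted clean latent
        eps = x * sigma + v * alpha                           # predicted noise
        alpha_next, sigma_next = get_alphas_sigmas(t_next)
        x = pred_x0 * alpha_next + eps * sigma_next           # step toward t_next

# Map the denoised latent back to integer codes (assumption) and let Encodec decode to audio.
codes = x.round().clamp(0, 1023).long()                       # Encodec uses 1024-entry codebooks
wav = codec.decode([(codes, None)])                           # [1, 1, n_samples] at 24 kHz
```

Training (`train_latent_cond.py`) roughly runs the other way: audio chunks are encoded with Encodec, noise is added to the resulting codebook representation, and the unet learns to undo it while conditioned on the T5 embedding of a prompt derived from the file name (see `dataset/dataset.py`).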
8 | 9 | # Instructions 10 | 11 | - Clone the repo 12 | - Set up your environment 13 | - Launch the `train_latent_cond.py` file with accelerate (`example_launch_command.txt` in root directory for an example) 14 | - `training_args.md` in root directory for argument explanations 15 | - Inferencing scripts/notebooks/trained models coming soon 16 | 17 | # Shout Outs 18 | 19 | - Thanks to [Hugging Face](https://huggingface.co/) for diffusers/transformers and being a huge contribution to the open source community 20 | - Thanks to [HarmonAI](https://www.harmonai.org/) for their audio diffusion research and contributions to the open source community 21 | - Thanks to [Stable Diffusion](https://stability.ai/) and OpenAI for the unet/cross-attention base code and for their open source contributions 22 | - Thanks to [Meta](https://github.com/facebookresearch/encodec) for open sourcing Encodec and all of their other open source contributions 23 | - Thanks to [Google](https://github.com/google-research/text-to-text-transfer-transformer) for open sourcing the t5 large language model. 24 | - Shoutout to [EveryDream](https://github.com/victorchall/EveryDream2trainer) for windows venv setup and bnb patch 25 | -------------------------------------------------------------------------------- /activate_venv.bat: -------------------------------------------------------------------------------- 1 | call venv/scripts/activate.bat -------------------------------------------------------------------------------- /audio_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/serp-ai/ai-text-to-audio-latent-diffusion/55230601b4f34b30bc52568f58619c8c33b3e202/audio_diffusion/__init__.py -------------------------------------------------------------------------------- /audio_diffusion/blocks.py: -------------------------------------------------------------------------------- 1 | from torch import nn, einsum 2 | from abc import abstractmethod 3 | from einops import rearrange, repeat 4 | from typing import Optional, Any 5 | from inspect import isfunction 6 | import os 7 | import numpy as np 8 | import math 9 | from torch.nn import functional as F 10 | import torch 11 | 12 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 13 | # dummy replace 14 | def convert_module_to_f16(x): 15 | pass 16 | 17 | def convert_module_to_f32(x): 18 | pass 19 | 20 | try: 21 | import xformers 22 | import xformers.ops 23 | XFORMERS_IS_AVAILBLE = True 24 | except: 25 | XFORMERS_IS_AVAILBLE = False 26 | print("No module 'xformers'. Proceeding without it.") 27 | 28 | class TimestepBlock(nn.Module): 29 | """ 30 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 31 | Any module where forward() takes timestep embeddings as a second argument. 32 | """ 33 | 34 | @abstractmethod 35 | def forward(self, x, emb): 36 | """ 37 | Apply the module to `x` given `emb` timestep embeddings. 38 | """ 39 | 40 | 41 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 42 | """ 43 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 44 | A sequential module that passes timestep embeddings to the children that 45 | support it as an extra input. 
46 | """ 47 | 48 | def forward(self, x, emb, context=None): 49 | for layer in self: 50 | if isinstance(layer, TimestepBlock): 51 | x = layer(x, emb) 52 | elif isinstance(layer, AudioTransformer): 53 | x = layer(x, context) 54 | else: 55 | x = layer(x) 56 | return x 57 | 58 | 59 | class ResBlock(TimestepBlock): 60 | """ 61 | Adapted from https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 62 | A residual block that can optionally change the number of channels. 63 | :param channels: the number of input channels. 64 | :param emb_channels: the number of timestep embedding channels. 65 | :param dropout: the rate of dropout. 66 | :param out_channels: if specified, the number of out channels. 67 | :param use_conv: if True and out_channels is specified, use a spatial 68 | convolution instead of a smaller 1x1 convolution to change the 69 | channels in the skip connection. 70 | :param dims: determines if the signal is 1D, 2D, or 3D. 71 | :param use_checkpoint: if True, use gradient checkpointing on this module. 72 | :param up: if True, use this block for upsampling. 73 | :param down: if True, use this block for downsampling. 74 | """ 75 | 76 | def __init__( 77 | self, 78 | channels, 79 | emb_channels, 80 | dropout, 81 | out_channels=None, 82 | use_conv=False, 83 | use_scale_shift_norm=False, 84 | dims=1, 85 | use_checkpoint=False, 86 | up=False, 87 | down=False, 88 | ): 89 | super().__init__() 90 | self.channels = channels 91 | self.emb_channels = emb_channels 92 | self.dropout = dropout 93 | self.out_channels = out_channels or channels 94 | self.use_conv = use_conv 95 | self.use_checkpoint = use_checkpoint 96 | self.use_scale_shift_norm = use_scale_shift_norm 97 | 98 | self.in_layers = nn.Sequential( 99 | normalization(channels), 100 | nn.SiLU(), 101 | conv_nd(dims, channels, self.out_channels, 5, padding=2), 102 | ) 103 | 104 | self.updown = up or down 105 | 106 | if up: 107 | self.h_upd = Upsample(channels, False, dims) 108 | self.x_upd = Upsample(channels, False, dims) 109 | elif down: 110 | self.h_upd = Downsample(channels, False, dims) 111 | self.x_upd = Downsample(channels, False, dims) 112 | else: 113 | self.h_upd = self.x_upd = nn.Identity() 114 | 115 | self.emb_layers = nn.Sequential( 116 | nn.SiLU(), 117 | linear( 118 | emb_channels, 119 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels 120 | ), 121 | ) 122 | self.out_layers = nn.Sequential( 123 | normalization(self.out_channels), 124 | nn.SiLU(), 125 | nn.Dropout(p=dropout), 126 | zero_module( 127 | conv_nd(dims, self.out_channels, self.out_channels, 5, padding=2) 128 | ), 129 | ) 130 | 131 | if self.out_channels == channels: 132 | self.skip_connection = nn.Identity() 133 | elif use_conv: 134 | self.skip_connection = conv_nd( 135 | dims, channels, self.out_channels, 1, bias=False 136 | ) 137 | else: 138 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, bias=False) 139 | 140 | def forward(self, x, emb): 141 | """ 142 | Apply the block to a Tensor, conditioned on a timestep embedding. 143 | :param x: an [N x C x ...] Tensor of features. 144 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 145 | :return: an [N x C x ...] Tensor of outputs. 
146 | """ 147 | return checkpoint( 148 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 149 | ) 150 | 151 | 152 | def _forward(self, x, emb): 153 | if self.updown: 154 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 155 | h = in_rest(x) 156 | h = self.h_upd(h) 157 | x = self.x_upd(x) 158 | h = in_conv(h) 159 | else: 160 | h = self.in_layers(x) 161 | emb_out = self.emb_layers(emb).type(h.dtype) 162 | while len(emb_out.shape) < len(h.shape): 163 | emb_out = emb_out[..., None] 164 | if self.use_scale_shift_norm: 165 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 166 | scale, shift = torch.chunk(emb_out, 2, dim=1) 167 | h = out_norm(h) * (1 + scale) + shift 168 | h = out_rest(h) 169 | else: 170 | h = h + emb_out 171 | h = self.out_layers(h) 172 | return self.skip_connection(x) + h 173 | 174 | 175 | class AttentionBlock(nn.Module): 176 | """ 177 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 178 | An attention block that allows spatial positions to attend to each other. 179 | Originally ported from here, but adapted to the N-d case. 180 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 181 | """ 182 | 183 | def __init__( 184 | self, 185 | channels, 186 | num_heads=1, 187 | num_head_channels=-1, 188 | use_checkpoint=False, 189 | use_new_attention_order=False, 190 | ): 191 | super().__init__() 192 | self.channels = channels 193 | if num_head_channels == -1: 194 | self.num_heads = num_heads 195 | else: 196 | assert ( 197 | channels % num_head_channels == 0 198 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 199 | self.num_heads = channels // num_head_channels 200 | self.use_checkpoint = use_checkpoint 201 | self.norm = normalization(channels) 202 | self.qkv = conv_nd(1, channels, channels * 3, 1) 203 | if use_new_attention_order: 204 | # split qkv before split heads 205 | self.attention = QKVAttention(self.num_heads) 206 | else: 207 | # split heads before split qkv 208 | self.attention = QKVAttentionLegacy(self.num_heads) 209 | 210 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 211 | 212 | def forward(self, x): 213 | return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 214 | #return pt_checkpoint(self._forward, x) # pytorch 215 | 216 | def _forward(self, x): 217 | b, c, *spatial = x.shape 218 | x = x.reshape(b, c, -1) 219 | qkv = self.qkv(self.norm(x)) 220 | h = self.attention(qkv) 221 | h = self.proj_out(h) 222 | return (x + h).reshape(b, c, *spatial) 223 | 224 | 225 | def count_flops_attn(model, _x, y): 226 | """ 227 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 228 | A counter for the `thop` package to count the operations in an 229 | attention operation. 230 | Meant to be used like: 231 | macs, params = thop.profile( 232 | model, 233 | inputs=(inputs, timestamps), 234 | custom_ops={QKVAttention: QKVAttention.count_flops}, 235 | ) 236 | """ 237 | b, c, *spatial = y[0].shape 238 | num_spatial = int(np.prod(spatial)) 239 | # We perform two matmuls with the same number of ops. 240 | # The first computes the weight matrix, the second computes 241 | # the combination of the value vectors. 
242 | matmul_ops = 2 * b * (num_spatial ** 2) * c 243 | model.total_ops += torch.DoubleTensor([matmul_ops]) 244 | 245 | 246 | class QKVAttentionLegacy(nn.Module): 247 | """ 248 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 249 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 250 | """ 251 | 252 | def __init__(self, n_heads): 253 | super().__init__() 254 | self.n_heads = n_heads 255 | 256 | def forward(self, qkv): 257 | """ 258 | Apply QKV attention. 259 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 260 | :return: an [N x (H * C) x T] tensor after attention. 261 | """ 262 | bs, width, length = qkv.shape 263 | assert width % (3 * self.n_heads) == 0 264 | ch = width // (3 * self.n_heads) 265 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 266 | scale = 1 / math.sqrt(math.sqrt(ch)) 267 | weight = torch.einsum( 268 | "bct,bcs->bts", q * scale, k * scale 269 | ) # More stable with f16 than dividing afterwards 270 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 271 | a = torch.einsum("bts,bcs->bct", weight, v) 272 | return a.reshape(bs, -1, length) 273 | 274 | @staticmethod 275 | def count_flops(model, _x, y): 276 | return count_flops_attn(model, _x, y) 277 | 278 | 279 | class QKVAttention(nn.Module): 280 | """ 281 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 282 | A module which performs QKV attention and splits in a different order. 283 | """ 284 | 285 | def __init__(self, n_heads): 286 | super().__init__() 287 | self.n_heads = n_heads 288 | 289 | def forward(self, qkv): 290 | """ 291 | Apply QKV attention. 292 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 293 | :return: an [N x (H * C) x T] tensor after attention. 294 | """ 295 | bs, width, length = qkv.shape 296 | assert width % (3 * self.n_heads) == 0 297 | ch = width // (3 * self.n_heads) 298 | q, k, v = qkv.chunk(3, dim=1) 299 | scale = 1 / math.sqrt(math.sqrt(ch)) 300 | weight = torch.einsum( 301 | "bct,bcs->bts", 302 | (q * scale).view(bs * self.n_heads, ch, length), 303 | (k * scale).view(bs * self.n_heads, ch, length), 304 | ) # More stable with f16 than dividing afterwards 305 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 306 | a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 307 | return a.reshape(bs, -1, length) 308 | 309 | @staticmethod 310 | def count_flops(model, _x, y): 311 | return count_flops_attn(model, _x, y) 312 | 313 | 314 | class Downsample(nn.Module): 315 | """ 316 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 317 | A downsampling layer with an optional convolution. 318 | :param channels: channels in the inputs and outputs. 319 | :param use_conv: a bool determining if a convolution is applied. 320 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 321 | downsampling occurs in the inner-two dimensions. 
322 | """ 323 | 324 | def __init__(self, channels, use_conv, dims=1, out_channels=None,padding=2): 325 | super().__init__() 326 | self.channels = channels 327 | self.out_channels = out_channels or channels 328 | self.use_conv = use_conv 329 | self.dims = dims 330 | stride = 2 if dims != 3 else (1, 2, 2) 331 | if use_conv: 332 | self.op = conv_nd( 333 | dims, self.channels, self.out_channels, 5, stride=stride, padding=padding 334 | ) 335 | else: 336 | assert self.channels == self.out_channels 337 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 338 | 339 | def forward(self, x): 340 | assert x.shape[1] == self.channels 341 | return self.op(x) 342 | 343 | 344 | class Upsample(nn.Module): 345 | """ 346 | From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 347 | An upsampling layer with an optional convolution. 348 | :param channels: channels in the inputs and outputs. 349 | :param use_conv: a bool determining if a convolution is applied. 350 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 351 | upsampling occurs in the inner-two dimensions. 352 | """ 353 | 354 | def __init__(self, channels, use_conv, dims=1, out_channels=None, padding=2): 355 | super().__init__() 356 | self.channels = channels 357 | self.out_channels = out_channels or channels 358 | self.use_conv = use_conv 359 | self.dims = dims 360 | if use_conv: 361 | self.conv = conv_nd(dims, self.channels, self.out_channels, 5, padding=padding) 362 | 363 | def forward(self, x): 364 | assert x.shape[1] == self.channels 365 | if self.dims == 3: 366 | x = F.interpolate( 367 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 368 | ) 369 | else: 370 | x = F.interpolate(x.float(), scale_factor=2, mode="nearest") 371 | if self.use_conv: 372 | x = self.conv(x) 373 | return x 374 | 375 | 376 | class AudioTransformer(nn.Module): 377 | """ 378 | adapted from https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 379 | Transformer block for audio-like data. 380 | First, project the input (aka embedding) 381 | and reshape to b, t, d. 382 | Then apply standard transformer action. 
383 | Finally, reshape to audio 384 | """ 385 | def __init__(self, in_channels, n_heads, d_head, 386 | depth=1, dropout=0., context_dim=None, 387 | disable_self_attn=False, use_linear=False, 388 | use_checkpoint=True): 389 | super().__init__() 390 | if exists(context_dim) and not isinstance(context_dim, list): 391 | context_dim = [context_dim] 392 | self.in_channels = in_channels 393 | inner_dim = n_heads * d_head 394 | self.norm = Normalize(in_channels) 395 | if not use_linear: 396 | self.proj_in = nn.Conv1d(in_channels, 397 | inner_dim, 398 | kernel_size=1, 399 | stride=1, 400 | padding=0) 401 | else: 402 | self.proj_in = nn.Linear(in_channels, inner_dim) 403 | 404 | self.transformer_blocks = nn.ModuleList( 405 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 406 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 407 | for d in range(depth)] 408 | ) 409 | if not use_linear: 410 | self.proj_out = zero_module(nn.Conv1d(inner_dim, 411 | in_channels, 412 | kernel_size=1, 413 | stride=1, 414 | padding=0)) 415 | else: 416 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 417 | self.use_linear = use_linear 418 | 419 | def forward(self, x, context=None): 420 | # note: if no context is given, cross-attention defaults to self-attention 421 | if not isinstance(context, list): 422 | context = [context] 423 | b, c, l = x.shape 424 | x_in = x 425 | x = self.norm(x) 426 | if not self.use_linear: 427 | x = self.proj_in(x) 428 | x = rearrange(x, 'b c l -> b l c').contiguous() 429 | if self.use_linear: 430 | x = self.proj_in(x) 431 | for i, block in enumerate(self.transformer_blocks): 432 | x = block(x, context=context[i]) 433 | if self.use_linear: 434 | x = self.proj_out(x) 435 | x = rearrange(x, 'b l c -> b c l', l=l).contiguous() 436 | if not self.use_linear: 437 | x = self.proj_out(x) 438 | return x + x_in 439 | 440 | 441 | def exists(val): 442 | return val is not None 443 | 444 | 445 | def uniq(arr): 446 | return{el: True for el in arr}.keys() 447 | 448 | 449 | def default(val, d): 450 | if exists(val): 451 | return val 452 | return d() if isfunction(d) else d 453 | 454 | 455 | def max_neg_value(t): 456 | return -torch.finfo(t.dtype).max 457 | 458 | 459 | def init_(tensor): 460 | dim = tensor.shape[-1] 461 | std = 1 / math.sqrt(dim) 462 | tensor.uniform_(-std, std) 463 | return tensor 464 | 465 | 466 | def expand_to_planes(input, shape): 467 | return input[..., None].repeat([1, 1, shape[2]]) 468 | 469 | 470 | # feedforward 471 | class GEGLU(nn.Module): 472 | def __init__(self, dim_in, dim_out): 473 | super().__init__() 474 | self.proj = nn.Linear(dim_in, dim_out * 2) 475 | 476 | def forward(self, x): 477 | x, gate = self.proj(x).chunk(2, dim=-1) 478 | return x * F.gelu(gate) 479 | 480 | 481 | class FeedForward(nn.Module): 482 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 483 | super().__init__() 484 | inner_dim = int(dim * mult) 485 | dim_out = default(dim_out, dim) 486 | project_in = nn.Sequential( 487 | nn.Linear(dim, inner_dim), 488 | nn.GELU() 489 | ) if not glu else GEGLU(dim, inner_dim) 490 | 491 | self.net = nn.Sequential( 492 | project_in, 493 | nn.Dropout(dropout), 494 | nn.Linear(inner_dim, dim_out) 495 | ) 496 | 497 | def forward(self, x): 498 | return self.net(x) 499 | 500 | 501 | def zero_module(module): 502 | """ 503 | Zero out the parameters of a module and return it. 
504 | """ 505 | for p in module.parameters(): 506 | p.detach().zero_() 507 | return module 508 | 509 | 510 | def Normalize(in_channels): 511 | return torch.nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=1e-6, affine=True) 512 | 513 | 514 | class CrossAttention(nn.Module): 515 | """From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py""" 516 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 517 | super().__init__() 518 | inner_dim = dim_head * heads 519 | context_dim = default(context_dim, query_dim) 520 | 521 | self.scale = dim_head ** -0.5 522 | self.heads = heads 523 | 524 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 525 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 526 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 527 | 528 | self.to_out = nn.Sequential( 529 | nn.Linear(inner_dim, query_dim), 530 | nn.Dropout(dropout) 531 | ) 532 | 533 | def forward(self, x, context=None, mask=None): 534 | h = self.heads 535 | 536 | q = self.to_q(x) 537 | context = default(context, x) 538 | k = self.to_k(context) 539 | v = self.to_v(context) 540 | 541 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 542 | 543 | # force cast to fp32 to avoid overflowing 544 | if _ATTN_PRECISION =="fp32": 545 | with torch.autocast(enabled=False, device_type = 'cuda'): 546 | q, k = q.float(), k.float() 547 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 548 | else: 549 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 550 | 551 | del q, k 552 | 553 | if exists(mask): 554 | mask = rearrange(mask, 'b ... -> b (...)') 555 | max_neg_value = -torch.finfo(sim.dtype).max 556 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 557 | sim.masked_fill_(~mask, max_neg_value) 558 | 559 | # attention, what we cannot get enough of 560 | sim = sim.softmax(dim=-1) 561 | 562 | out = einsum('b i j, b j d -> b i d', sim, v) 563 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 564 | return self.to_out(out) 565 | 566 | 567 | class MemoryEfficientCrossAttention(nn.Module): 568 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 569 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 570 | super().__init__() 571 | print(f"Setting up {self.__class__.__name__}. 
Query dim is {query_dim}, context_dim is {context_dim} and using " 572 | f"{heads} heads.") 573 | inner_dim = dim_head * heads 574 | context_dim = default(context_dim, query_dim) 575 | 576 | self.heads = heads 577 | self.dim_head = dim_head 578 | 579 | self.to_q = nn.Linear(query_dim, inner_dim) 580 | self.to_k = nn.Linear(context_dim, inner_dim) 581 | self.to_v = nn.Linear(context_dim, inner_dim) 582 | 583 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), 584 | nn.Dropout(dropout)) 585 | self.attention_op: Optional[Any] = None 586 | 587 | def forward(self, x, context=None, mask=None): 588 | q = self.to_q(x) 589 | context = default(context, x) 590 | k = self.to_k(context) 591 | v = self.to_v(context) 592 | 593 | b, _, _ = q.shape 594 | q, k, v = map( 595 | lambda t: t.unsqueeze(3) 596 | .reshape(b, t.shape[1], self.heads, self.dim_head) 597 | .permute(0, 2, 1, 3) 598 | .reshape(b * self.heads, t.shape[1], self.dim_head) 599 | .contiguous(), 600 | (q, k, v), 601 | ) 602 | 603 | # actually compute the attention, what we cannot get enough of 604 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 605 | if exists(mask): 606 | raise NotImplementedError 607 | out = ( 608 | out.unsqueeze(0) 609 | .reshape(b, self.heads, out.shape[1], self.dim_head) 610 | .permute(0, 2, 1, 3) 611 | .reshape(b, out.shape[1], self.heads * self.dim_head) 612 | ) 613 | return self.to_out(out) 614 | 615 | 616 | class BasicTransformerBlock(nn.Module): 617 | """From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py""" 618 | ATTENTION_MODES = { 619 | "softmax": CrossAttention, # vanilla attention 620 | "softmax-xformers": MemoryEfficientCrossAttention 621 | } 622 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 623 | disable_self_attn=False): 624 | super().__init__() 625 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" 626 | assert attn_mode in self.ATTENTION_MODES 627 | attn_cls = self.ATTENTION_MODES[attn_mode] 628 | self.disable_self_attn = disable_self_attn 629 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 630 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 631 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 632 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 633 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 634 | self.norm1 = nn.LayerNorm(dim) 635 | self.norm2 = nn.LayerNorm(dim) 636 | self.norm3 = nn.LayerNorm(dim) 637 | self.checkpoint = checkpoint 638 | 639 | def forward(self, x, context=None): 640 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 641 | 642 | def _forward(self, x, context=None): 643 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 644 | x = self.attn2(self.norm2(x), context=context) + x 645 | x = self.ff(self.norm3(x)) + x 646 | return x 647 | 648 | 649 | def extract_into_tensor(a, t, x_shape): 650 | b, *_ = t.shape 651 | out = a.gather(-1, t) 652 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 653 | 654 | 655 | def checkpoint(func, inputs, params, flag): 656 | """ 657 | Evaluate a function without caching intermediate activations, allowing for 658 | reduced memory at the expense of extra compute in the backward pass. 659 | :param func: the function to evaluate. 
660 | :param inputs: the argument sequence to pass to `func`. 661 | :param params: a sequence of parameters `func` depends on but does not 662 | explicitly take as arguments. 663 | :param flag: if False, disable gradient checkpointing. 664 | """ 665 | if flag: 666 | args = tuple(inputs) + tuple(params) 667 | return CheckpointFunction.apply(func, len(inputs), *args) 668 | else: 669 | return func(*inputs) 670 | 671 | 672 | class CheckpointFunction(torch.autograd.Function): 673 | """From https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py""" 674 | @staticmethod 675 | def forward(ctx, run_function, length, *args): 676 | ctx.run_function = run_function 677 | ctx.input_tensors = list(args[:length]) 678 | ctx.input_params = list(args[length:]) 679 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 680 | "dtype": torch.get_autocast_gpu_dtype(), 681 | "cache_enabled": torch.is_autocast_cache_enabled()} 682 | with torch.no_grad(): 683 | output_tensors = ctx.run_function(*ctx.input_tensors) 684 | return output_tensors 685 | 686 | @staticmethod 687 | def backward(ctx, *output_grads): 688 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 689 | with torch.enable_grad(), \ 690 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 691 | # Fixes a bug where the first op in run_function modifies the 692 | # Tensor storage in place, which is not allowed for detach()'d 693 | # Tensors. 694 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 695 | output_tensors = ctx.run_function(*shallow_copies) 696 | input_grads = torch.autograd.grad( 697 | output_tensors, 698 | ctx.input_tensors + ctx.input_params, 699 | output_grads, 700 | allow_unused=True, 701 | ) 702 | del ctx.input_tensors 703 | del ctx.input_params 704 | del output_tensors 705 | return (None, None) + input_grads 706 | 707 | 708 | def zero_module(module): 709 | """ 710 | Zero out the parameters of a module and return it. 711 | """ 712 | for p in module.parameters(): 713 | p.detach().zero_() 714 | return module 715 | 716 | 717 | def normalization(channels): 718 | """ 719 | Make a standard normalization layer. 720 | :param channels: number of input channels. 721 | :return: an nn.Module for normalization. 722 | """ 723 | return GroupNorm(1, channels) 724 | 725 | 726 | class GroupNorm(nn.GroupNorm): 727 | def forward(self, x): 728 | return super().forward(x.float()).type(x.dtype) 729 | 730 | def conv_nd(dims, *args, **kwargs): 731 | """ 732 | Create a 1D, 2D, or 3D convolution module. 733 | """ 734 | if dims == 1: 735 | return nn.Conv1d(*args, **kwargs) 736 | elif dims == 2: 737 | return nn.Conv2d(*args, **kwargs) 738 | elif dims == 3: 739 | return nn.Conv3d(*args, **kwargs) 740 | raise ValueError(f"unsupported dimensions: {dims}") 741 | 742 | 743 | def linear(*args, **kwargs): 744 | """ 745 | Create a linear module. 746 | """ 747 | return nn.Linear(*args, **kwargs) 748 | 749 | 750 | def avg_pool_nd(dims, *args, **kwargs): 751 | """ 752 | Create a 1D, 2D, or 3D average pooling module. 
753 | """ 754 | if dims == 1: 755 | return nn.AvgPool1d(*args, **kwargs) 756 | elif dims == 2: 757 | return nn.AvgPool2d(*args, **kwargs) 758 | elif dims == 3: 759 | return nn.AvgPool3d(*args, **kwargs) 760 | raise ValueError(f"unsupported dimensions: {dims}") 761 | -------------------------------------------------------------------------------- /audio_diffusion/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from diffusers.models.embeddings import GaussianFourierProjection 6 | 7 | from .blocks import conv_nd, exists, normalization, zero_module, convert_module_to_f16, convert_module_to_f32 8 | from .blocks import TimestepEmbedSequential, ResBlock, AttentionBlock, AudioTransformer, Downsample, Upsample 9 | 10 | 11 | class UNetModel(nn.Module): 12 | """ 13 | Adapted from https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py 14 | The full UNet model with attention and timestep embedding. 15 | :param in_channels: channels in the input Tensor. 16 | :param model_channels: base channel count for the model. 17 | :param out_channels: channels in the output Tensor. 18 | :param num_res_blocks: number of residual blocks per downsample. 19 | :param attention_resolutions: a collection of downsample rates at which 20 | attention will take place. May be a set, list, or tuple. 21 | For example, if this contains 4, then at 4x downsampling, attention 22 | will be used. 23 | :param dropout: the dropout probability. 24 | :param channel_mult: channel multiplier for each level of the UNet. 25 | :param conv_resample: if True, use learned convolutions for upsampling and 26 | downsampling. 27 | :param dims: determines if the signal is 1D, 2D, or 3D. 28 | :param num_classes: if specified (as an int), then this model will be 29 | class-conditional with `num_classes` classes. 30 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 31 | :param num_heads: the number of attention heads in each attention layer. 32 | :param num_heads_channels: if specified, ignore num_heads and instead use 33 | a fixed channel width per attention head. 34 | :param num_heads_upsample: works with num_heads to set a different number 35 | of heads for upsampling. Deprecated. 36 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 37 | :param resblock_updown: use residual blocks for up/downsampling. 38 | :param use_new_attention_order: use a different attention pattern for potentially 39 | increased efficiency. 
40 | """ 41 | 42 | def __init__( 43 | self, 44 | sample_size, 45 | in_channels, 46 | model_channels, 47 | out_channels, 48 | num_res_blocks, 49 | attention_resolutions, 50 | dropout=0, 51 | channel_mult=(1, 2, 4, 8), 52 | conv_resample=True, 53 | dims=1, 54 | num_classes=None, 55 | use_checkpoint=False, 56 | use_fp16=False, 57 | num_heads=-1, 58 | num_head_channels=-1, 59 | num_heads_upsample=-1, 60 | use_scale_shift_norm=False, 61 | resblock_updown=False, 62 | use_new_attention_order=False, 63 | use_audio_transformer=False, # custom transformer support 64 | transformer_depth=1, # custom transformer support 65 | context_dim=None, # custom transformer support 66 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 67 | legacy=True, 68 | disable_self_attentions=None, 69 | num_attention_blocks=None, 70 | disable_middle_self_attn=False, 71 | use_linear_in_transformer=False, 72 | flip_sin_to_cos=False 73 | ): 74 | super().__init__() 75 | if use_audio_transformer: 76 | assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' 77 | 78 | if context_dim is not None: 79 | assert use_audio_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 80 | from omegaconf.listconfig import ListConfig 81 | if type(context_dim) == ListConfig: 82 | context_dim = list(context_dim) 83 | 84 | if num_heads_upsample == -1: 85 | num_heads_upsample = num_heads 86 | 87 | if num_heads == -1: 88 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 89 | 90 | if num_head_channels == -1: 91 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 92 | 93 | self.sample_size = sample_size 94 | self.in_channels = in_channels 95 | self.model_channels = model_channels 96 | self.out_channels = out_channels 97 | if isinstance(num_res_blocks, int): 98 | self.num_res_blocks = len(channel_mult) * [num_res_blocks] 99 | else: 100 | if len(num_res_blocks) != len(channel_mult): 101 | raise ValueError("provide num_res_blocks either as an int (globally constant) or " 102 | "as a list/tuple (per-level) with the same length as channel_mult") 103 | self.num_res_blocks = num_res_blocks 104 | if disable_self_attentions is not None: 105 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not 106 | assert len(disable_self_attentions) == len(channel_mult) 107 | if num_attention_blocks is not None: 108 | assert len(num_attention_blocks) == len(self.num_res_blocks) 109 | assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) 110 | print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" 111 | f"This option has LESS priority than attention_resolutions {attention_resolutions}, " 112 | f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " 113 | f"attention will still not be set.") 114 | 115 | self.attention_resolutions = attention_resolutions 116 | self.dropout = dropout 117 | self.channel_mult = channel_mult 118 | self.conv_resample = conv_resample 119 | self.num_classes = num_classes 120 | self.use_checkpoint = use_checkpoint 121 | self.dtype = torch.float16 if use_fp16 else torch.float32 122 | self.num_heads = num_heads 123 | self.num_head_channels = num_head_channels 124 | self.num_heads_upsample = num_heads_upsample 125 | self.predict_codebook_ids = n_embed is not None 126 | 127 | self.time_proj = GaussianFourierProjection( 128 | embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos 129 | ) 130 | timestep_input_dim = 16 131 | time_embed_dim = self.model_channels * 4 132 | 133 | self.time_embed = nn.Sequential( 134 | nn.Linear(timestep_input_dim, time_embed_dim), 135 | nn.SiLU(), 136 | nn.Linear(time_embed_dim, time_embed_dim), 137 | ) 138 | 139 | if self.num_classes is not None: 140 | if isinstance(self.num_classes, int): 141 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 142 | elif self.num_classes == "continuous": 143 | print("setting up linear c_adm embedding layer") 144 | self.label_emb = nn.Linear(1, time_embed_dim) 145 | else: 146 | raise ValueError() 147 | 148 | self.input_blocks = nn.ModuleList( 149 | [ 150 | TimestepEmbedSequential( 151 | conv_nd(dims, in_channels, model_channels, 5, padding=2) 152 | ) 153 | ] 154 | ) 155 | self._feature_size = model_channels 156 | input_block_chans = [model_channels] 157 | ch = model_channels 158 | ds = 1 159 | for level, mult in enumerate(channel_mult): 160 | for nr in range(self.num_res_blocks[level]): 161 | layers = [ 162 | ResBlock( 163 | ch, 164 | time_embed_dim, 165 | dropout, 166 | out_channels=mult * model_channels, 167 | dims=dims, 168 | use_checkpoint=use_checkpoint, 169 | use_scale_shift_norm=use_scale_shift_norm, 170 | ) 171 | ] 172 | ch = mult * model_channels 173 | if ds in attention_resolutions: 174 | if num_head_channels == -1: 175 | dim_head = ch // num_heads 176 | else: 177 | num_heads = ch // num_head_channels 178 | dim_head = num_head_channels 179 | if legacy: 180 | #num_heads = 1 181 | dim_head = ch // num_heads if use_audio_transformer else num_head_channels 182 | if exists(disable_self_attentions): 183 | disabled_sa = disable_self_attentions[level] 184 | else: 185 | disabled_sa = False 186 | 187 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: 188 | layers.append( 189 | AttentionBlock( 190 | ch, 191 | use_checkpoint=use_checkpoint, 192 | num_heads=num_heads, 193 | num_head_channels=dim_head, 194 | use_new_attention_order=use_new_attention_order, 195 | ) if not use_audio_transformer else AudioTransformer( 196 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 197 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 198 | use_checkpoint=use_checkpoint 199 | ) 200 | ) 201 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 202 | self._feature_size += ch 203 | input_block_chans.append(ch) 204 | if level != len(channel_mult) - 1: 205 | out_ch = ch 206 | self.input_blocks.append( 207 | TimestepEmbedSequential( 208 | ResBlock( 209 | ch, 210 | time_embed_dim, 211 | dropout, 212 | out_channels=out_ch, 213 | dims=dims, 214 | 
use_checkpoint=use_checkpoint, 215 | use_scale_shift_norm=use_scale_shift_norm, 216 | down=True, 217 | ) 218 | if resblock_updown 219 | else Downsample( 220 | ch, conv_resample, dims=dims, out_channels=out_ch 221 | ) 222 | ) 223 | ) 224 | ch = out_ch 225 | input_block_chans.append(ch) 226 | ds *= 2 227 | self._feature_size += ch 228 | 229 | if num_head_channels == -1: 230 | dim_head = ch // num_heads 231 | else: 232 | num_heads = ch // num_head_channels 233 | dim_head = num_head_channels 234 | if legacy: 235 | #num_heads = 1 236 | dim_head = ch // num_heads if use_audio_transformer else num_head_channels 237 | self.middle_block = TimestepEmbedSequential( 238 | ResBlock( 239 | ch, 240 | time_embed_dim, 241 | dropout, 242 | dims=dims, 243 | use_checkpoint=use_checkpoint, 244 | use_scale_shift_norm=use_scale_shift_norm, 245 | ), 246 | AttentionBlock( 247 | ch, 248 | use_checkpoint=use_checkpoint, 249 | num_heads=num_heads, 250 | num_head_channels=dim_head, 251 | use_new_attention_order=use_new_attention_order, 252 | ) if not use_audio_transformer else AudioTransformer( # always uses a self-attn 253 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 254 | disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, 255 | use_checkpoint=use_checkpoint 256 | ), 257 | ResBlock( 258 | ch, 259 | time_embed_dim, 260 | dropout, 261 | dims=dims, 262 | use_checkpoint=use_checkpoint, 263 | use_scale_shift_norm=use_scale_shift_norm, 264 | ), 265 | ) 266 | self._feature_size += ch 267 | 268 | self.output_blocks = nn.ModuleList([]) 269 | for level, mult in list(enumerate(channel_mult))[::-1]: 270 | for i in range(self.num_res_blocks[level] + 1): 271 | ich = input_block_chans.pop() 272 | layers = [ 273 | ResBlock( 274 | ch + ich, 275 | time_embed_dim, 276 | dropout, 277 | out_channels=model_channels * mult, 278 | dims=dims, 279 | use_checkpoint=use_checkpoint, 280 | use_scale_shift_norm=use_scale_shift_norm, 281 | ) 282 | ] 283 | ch = model_channels * mult 284 | if ds in attention_resolutions: 285 | if num_head_channels == -1: 286 | dim_head = ch // num_heads 287 | else: 288 | num_heads = ch // num_head_channels 289 | dim_head = num_head_channels 290 | if legacy: 291 | #num_heads = 1 292 | dim_head = ch // num_heads if use_audio_transformer else num_head_channels 293 | if exists(disable_self_attentions): 294 | disabled_sa = disable_self_attentions[level] 295 | else: 296 | disabled_sa = False 297 | 298 | if not exists(num_attention_blocks) or i < num_attention_blocks[level]: 299 | layers.append( 300 | AttentionBlock( 301 | ch, 302 | use_checkpoint=use_checkpoint, 303 | num_heads=num_heads_upsample, 304 | num_head_channels=dim_head, 305 | use_new_attention_order=use_new_attention_order, 306 | ) if not use_audio_transformer else AudioTransformer( 307 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 308 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 309 | use_checkpoint=use_checkpoint 310 | ) 311 | ) 312 | if level and i == self.num_res_blocks[level]: 313 | out_ch = ch 314 | layers.append( 315 | ResBlock( 316 | ch, 317 | time_embed_dim, 318 | dropout, 319 | out_channels=out_ch, 320 | dims=dims, 321 | use_checkpoint=use_checkpoint, 322 | use_scale_shift_norm=use_scale_shift_norm, 323 | up=True, 324 | ) 325 | if resblock_updown 326 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 327 | ) 328 | ds //= 2 329 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 330 | self._feature_size += ch 
331 | 332 | self.out = nn.Sequential( 333 | normalization(ch), 334 | nn.SiLU(), 335 | zero_module(conv_nd(dims, model_channels, out_channels, 5, padding=2)), 336 | ) 337 | if self.predict_codebook_ids: 338 | self.id_predictor = nn.Sequential( 339 | normalization(ch), 340 | conv_nd(dims, model_channels, n_embed, 1), 341 | #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits 342 | ) 343 | 344 | def convert_to_fp16(self): 345 | """ 346 | Convert the torso of the model to float16. 347 | """ 348 | self.input_blocks.apply(convert_module_to_f16) 349 | self.middle_block.apply(convert_module_to_f16) 350 | self.output_blocks.apply(convert_module_to_f16) 351 | 352 | def convert_to_fp32(self): 353 | """ 354 | Convert the torso of the model to float32. 355 | """ 356 | self.input_blocks.apply(convert_module_to_f32) 357 | self.middle_block.apply(convert_module_to_f32) 358 | self.output_blocks.apply(convert_module_to_f32) 359 | 360 | def forward(self, x, timesteps=None, context=None, y=None,**kwargs): 361 | """ 362 | Apply the model to an input batch. 363 | :param x: an [N x C x ...] Tensor of inputs. 364 | :param timesteps: a 1-D batch of timesteps. 365 | :param context: conditioning plugged in via crossattn 366 | :param y: an [N] Tensor of labels, if class-conditional. 367 | :return: an [N x C x ...] Tensor of outputs. 368 | """ 369 | assert (y is not None) == ( 370 | self.num_classes is not None 371 | ), "must specify y if and only if the model is class-conditional" 372 | hs = [] 373 | 374 | t_emb = self.time_proj(timesteps) 375 | emb = self.time_embed(t_emb) 376 | 377 | if self.num_classes is not None: 378 | assert y.shape[0] == x.shape[0] 379 | emb = emb + self.label_emb(y) 380 | 381 | h = x.type(self.dtype) 382 | for module in self.input_blocks: 383 | h = module(h, emb, context) 384 | hs.append(h) 385 | h = self.middle_block(h, emb, context) 386 | for module in self.output_blocks: 387 | h = torch.cat([h, hs.pop()], dim=1) 388 | h = module(h, emb, context) 389 | h = h.type(x.dtype) 390 | if self.predict_codebook_ids: 391 | return self.id_predictor(h) 392 | else: 393 | return self.out(h) 394 | -------------------------------------------------------------------------------- /audio_diffusion/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | import random 7 | import math 8 | from torch import optim 9 | 10 | def append_dims(x, target_dims): 11 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 12 | dims_to_append = target_dims - x.ndim 13 | if dims_to_append < 0: 14 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 15 | return x[(...,) + (None,) * dims_to_append] 16 | 17 | 18 | def n_params(module): 19 | """Returns the number of trainable parameters in a module.""" 20 | return sum(p.numel() for p in module.parameters()) 21 | 22 | 23 | @contextmanager 24 | def train_mode(model, mode=True): 25 | """A context manager that places a model into training mode and restores 26 | the previous mode on exit.""" 27 | modes = [module.training for module in model.modules()] 28 | try: 29 | yield model.train(mode) 30 | finally: 31 | for i, module in enumerate(model.modules()): 32 | module.training = modes[i] 33 | 34 | 35 | def eval_mode(model): 36 | """A context manager that places a model into evaluation mode and restores 37 | the previous mode on exit.""" 38 | 
return train_mode(model, False) 39 | 40 | @torch.no_grad() 41 | def ema_update(model, averaged_model, decay): 42 | """Incorporates updated model parameters into an exponential moving averaged 43 | version of a model. It should be called after each optimizer step.""" 44 | model_params = dict(model.named_parameters()) 45 | # move parameters to CPU to save memory 46 | model_params = {name: param.cpu() for name, param in model_params.items()} 47 | averaged_params = dict(averaged_model.named_parameters()) 48 | assert model_params.keys() == averaged_params.keys() 49 | 50 | for name, param in model_params.items(): 51 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 52 | 53 | model_buffers = dict(model.named_buffers()) 54 | # move buffers to CPU to save memory 55 | model_buffers = {name: buf.cpu() for name, buf in model_buffers.items()} 56 | averaged_buffers = dict(averaged_model.named_buffers()) 57 | assert model_buffers.keys() == averaged_buffers.keys() 58 | 59 | for name, buf in model_buffers.items(): 60 | averaged_buffers[name].copy_(buf) 61 | 62 | 63 | class EMAWarmup: 64 | """Implements an EMA warmup using an inverse decay schedule. 65 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are 66 | good values for models you plan to train for a million or more steps (reaches decay 67 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 68 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 69 | 215.4k steps). 70 | Args: 71 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 72 | power (float): Exponential factor of EMA warmup. Default: 1. 73 | min_value (float): The minimum EMA decay rate. Default: 0. 74 | max_value (float): The maximum EMA decay rate. Default: 1. 75 | start_at (int): The epoch to start averaging at. Default: 0. 76 | last_epoch (int): The index of last epoch. Default: 0. 77 | """ 78 | 79 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, 80 | last_epoch=0): 81 | self.inv_gamma = inv_gamma 82 | self.power = power 83 | self.min_value = min_value 84 | self.max_value = max_value 85 | self.start_at = start_at 86 | self.last_epoch = last_epoch 87 | 88 | def state_dict(self): 89 | """Returns the state of the class as a :class:`dict`.""" 90 | return dict(self.__dict__.items()) 91 | 92 | def load_state_dict(self, state_dict): 93 | """Loads the class's state. 94 | Args: 95 | state_dict (dict): scaler state. Should be an object returned 96 | from a call to :meth:`state_dict`. 97 | """ 98 | self.__dict__.update(state_dict) 99 | 100 | def get_value(self): 101 | """Gets the current EMA decay rate.""" 102 | epoch = max(0, self.last_epoch - self.start_at) 103 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power 104 | return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) 105 | 106 | def step(self): 107 | """Updates the step count.""" 108 | self.last_epoch += 1 109 | 110 | 111 | class InverseLR(optim.lr_scheduler._LRScheduler): 112 | """Implements an inverse decay learning rate schedule with an optional exponential 113 | warmup. When last_epoch=-1, sets initial lr as lr. 114 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 115 | (1 / 2)**power of its original value. 116 | Args: 117 | optimizer (Optimizer): Wrapped optimizer. 118 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 
119 | power (float): Exponential factor of learning rate decay. Default: 1. 120 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 121 | Default: 0. 122 | final_lr (float): The final learning rate. Default: 0. 123 | last_epoch (int): The index of last epoch. Default: -1. 124 | verbose (bool): If ``True``, prints a message to stdout for 125 | each update. Default: ``False``. 126 | """ 127 | 128 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 129 | last_epoch=-1, verbose=False): 130 | self.inv_gamma = inv_gamma 131 | self.power = power 132 | if not 0. <= warmup < 1: 133 | raise ValueError('Invalid value for warmup') 134 | self.warmup = warmup 135 | self.final_lr = final_lr 136 | super().__init__(optimizer, last_epoch, verbose) 137 | 138 | def get_lr(self): 139 | if not self._get_lr_called_within_step: 140 | warnings.warn("To get the last learning rate computed by the scheduler, " 141 | "please use `get_last_lr()`.") 142 | 143 | return self._get_closed_form_lr() 144 | 145 | def _get_closed_form_lr(self): 146 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 147 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 148 | return [warmup * max(self.final_lr, base_lr * lr_mult) 149 | for base_lr in self.base_lrs] 150 | 151 | 152 | # Define the diffusion noise schedule 153 | def get_alphas_sigmas(t): 154 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 155 | 156 | def append_dims(x, target_dims): 157 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 158 | dims_to_append = target_dims - x.ndim 159 | if dims_to_append < 0: 160 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 161 | return x[(...,) + (None,) * dims_to_append] 162 | 163 | def expand_to_planes(input, shape): 164 | return input[..., None].repeat([1, 1, shape[2]]) 165 | 166 | class PadCrop(nn.Module): 167 | def __init__(self, n_samples, randomize=True): 168 | super().__init__() 169 | self.n_samples = n_samples 170 | self.randomize = randomize 171 | 172 | def __call__(self, signal): 173 | n, s = signal.shape 174 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 175 | end = start + self.n_samples 176 | output = signal.new_zeros([n, self.n_samples]) 177 | output[:, :min(s, self.n_samples)] = signal[:, start:end] 178 | return output 179 | 180 | class RandomPhaseInvert(nn.Module): 181 | def __init__(self, p=0.5): 182 | super().__init__() 183 | self.p = p 184 | def __call__(self, signal): 185 | return -signal if (random.random() < self.p) else signal 186 | 187 | class Stereo(nn.Module): 188 | def __call__(self, signal): 189 | signal_shape = signal.shape 190 | # Check if it's mono 191 | if len(signal_shape) == 1: # s -> 2, s 192 | signal = signal.unsqueeze(0).repeat(2, 1) 193 | elif len(signal_shape) == 2: 194 | if signal_shape[0] == 1: #1, s -> 2, s 195 | signal = signal.repeat(2, 1) 196 | elif signal_shape[0] > 2: #?, s -> 2,s 197 | signal = signal[:2, :] 198 | 199 | return signal 200 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/serp-ai/ai-text-to-audio-latent-diffusion/55230601b4f34b30bc52568f58619c8c33b3e202/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/dataset.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torchaudio import transforms as T 4 | import random 5 | from glob import glob 6 | import os 7 | import re 8 | from audio_diffusion.utils import Stereo, PadCrop, RandomPhaseInvert 9 | import tqdm 10 | from multiprocessing import Pool, cpu_count 11 | from functools import partial 12 | 13 | import torch 14 | import torchaudio 15 | from torchaudio import transforms as T 16 | import random 17 | from glob import glob 18 | import os 19 | from audio_diffusion.utils import Stereo, PadCrop, RandomPhaseInvert 20 | import tqdm 21 | from multiprocessing import Pool, cpu_count 22 | from functools import partial 23 | from encodec.utils import convert_audio 24 | 25 | class SampleDataset(torch.utils.data.Dataset): 26 | def __init__(self, paths, global_args, tokenizer=None): 27 | super().__init__() 28 | self.filenames = [] 29 | 30 | print(f"Random crop: {global_args.random_crop}") 31 | self.augs = torch.nn.Sequential( 32 | PadCrop(global_args.sample_size, randomize=global_args.random_crop), 33 | # RandomPhaseInvert(), 34 | ) 35 | 36 | if tokenizer is not None: 37 | self.tokenizer = tokenizer 38 | else: 39 | self.tokenizer = None 40 | 41 | for path in paths: 42 | for ext in ['wav','flac','ogg','aiff','aif','mp3']: 43 | self.filenames += glob(f'{path}/**/*.{ext}', recursive=True) 44 | 45 | self.sr = global_args.sample_rate 46 | if hasattr(global_args,'load_frac'): 47 | self.load_frac = global_args.load_frac 48 | else: 49 | self.load_frac = 1.0 50 | self.num_gpus = global_args.num_gpus 51 | 52 | # if hasattr(global_args,'channels'): 53 | # self.channels = global_args.channels 54 | # else: 55 | # self.channels = 2 56 | self.channels = 1 57 | 58 | self.use_text_dropout = global_args.use_text_dropout 59 | self.text_dropout_prob = global_args.text_dropout_prob 60 | 61 | self.shuffle_prompts = global_args.shuffle_prompts 62 | self.shuffle_prompts_sep = global_args.shuffle_prompts_sep.strip("\'").strip('\"') 63 | self.shuffle_prompts_prob = global_args.shuffle_prompts_prob 64 | 65 | self.cache_training_data = global_args.cache_training_data 66 | 67 | if self.cache_training_data: self.preload_files() 68 | 69 | 70 | def load_file(self, filename): 71 | wav, sr = torchaudio.load(filename) 72 | wav = torch.mean(wav, dim=0, keepdim=True) # convert to 1 channel 73 | wav = wav.unsqueeze(0) 74 | audio = convert_audio(wav, sr, self.sr, self.channels) 75 | return audio 76 | 77 | def load_file_ind(self, file_list,i): # used when caching training data 78 | return self.load_file(file_list[i]).cpu() 79 | 80 | def get_data_range(self): # for parallel runs, only grab part of the data 81 | start, stop = 0, len(self.filenames) 82 | try: 83 | local_rank = int(os.environ["LOCAL_RANK"]) 84 | world_size = int(os.environ["WORLD_SIZE"]) 85 | interval = stop//world_size 86 | start, stop = local_rank*interval, (local_rank+1)*interval 87 | print("local_rank, world_size, start, stop =",local_rank, world_size, start, stop) 88 | return start, stop 89 | except KeyError as e: # we're on GPU 0 and the others haven't been initialized yet 90 | start, stop = 0, len(self.filenames)//self.num_gpus 91 | return start, stop 92 | 93 | def preload_files(self): 94 | n = int(len(self.filenames)*self.load_frac) 95 | print(f"Caching {n} input audio files:") 96 | wrapper = partial(self.load_file_ind, self.filenames) 97 | start, stop = self.get_data_range() 98 | with Pool(processes=cpu_count()) as p: # //8 to avoid FS bottleneck and/or 
too many processes (b/c * num_gpus) 99 | self.audio_files = list(tqdm.tqdm(p.imap(wrapper, range(start,stop)), total=stop-start)) 100 | 101 | def __len__(self): 102 | return len(self.filenames) 103 | 104 | def __getitem__(self, idx): 105 | audio_filename = self.filenames[idx] 106 | try: 107 | if self.cache_training_data: 108 | audio = self.audio_files[idx] # .copy() 109 | else: 110 | audio = self.load_file(audio_filename) 111 | 112 | if len(audio.shape) > 2: 113 | audio = audio.squeeze(0) 114 | #Run augmentations on this sample (including random crop) 115 | if self.augs is not None: 116 | audio = self.augs(audio) 117 | 118 | if self.use_text_dropout: 119 | if random.random() < self.text_dropout_prob: 120 | return (audio, '') 121 | audio_filename = os.path.splitext(audio_filename)[0].lower().split('/')[-1].split('\\')[-1] 122 | # remove _chunk{num} from filename 123 | audio_filename = re.sub(r'_chunk\d+', '', audio_filename) 124 | if self.shuffle_prompts and random.random() < self.shuffle_prompts_prob: 125 | # split the filename by seperator and shuffle the order 126 | audio_filename = audio_filename.split(self.shuffle_prompts_sep) 127 | random.shuffle(audio_filename) 128 | audio_filename = self.shuffle_prompts_sep.join(audio_filename) 129 | return (audio, audio_filename) 130 | except Exception as e: 131 | print(f'Couldn\'t load file {audio_filename}: {e}') 132 | return self[random.randrange(len(self))] 133 | -------------------------------------------------------------------------------- /defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = dd-finetune 6 | 7 | # training data directory 8 | training_dir = '' 9 | 10 | # the batch size 11 | batch_size = 8 12 | 13 | # number of GPUs to use for training 14 | num_gpus = 1 15 | 16 | # number of nodes to use for training 17 | num_nodes = 1 18 | 19 | # number of CPU workers for the DataLoader 20 | num_workers = 2 21 | 22 | # Number of audio samples for the training input 23 | sample_size = 1310720 24 | 25 | # Number of steps between demos 26 | demo_every = 1000 27 | 28 | # Number of denoising steps for the demos 29 | demo_steps = 250 30 | 31 | # Number of demos to create 32 | num_demos = 16 33 | 34 | # the EMA decay 35 | ema_decay = 0.995 36 | 37 | # the random seed 38 | seed = 42 39 | 40 | # Batches for gradient accumulation 41 | accum_batches = 4 42 | 43 | # The sample rate of the audio 44 | sample_rate = 24000 45 | 46 | # Number of steps between checkpoints 47 | checkpoint_every = 10000 48 | 49 | # unused, required by the model code 50 | latent_dim = 0 51 | 52 | # If true training data is kept in RAM 53 | cache_training_data = False 54 | 55 | # randomly crop input audio? 
(for augmentation) 56 | random_crop = False 57 | 58 | # checkpoint file to (re)start training from 59 | ckpt_path = '' 60 | 61 | # Path to output the model checkpoints 62 | save_path = '' 63 | 64 | # Resume training from checkpoint 65 | resume_from_checkpoint = '' 66 | 67 | #the multiprocessing start method ['fork', 'forkserver', 'spawn'] 68 | start_method = 'spawn' 69 | 70 | # Whether to save the model checkpoints to Weights & Biases 71 | save_wandb = 'none' 72 | 73 | # What precision to use for training 74 | precision = 'bf16' 75 | 76 | # Learning rate 77 | lr = 4e-5 78 | 79 | # Scale lr 80 | scale_lr = False 81 | 82 | # Lr warmup steps 83 | lr_warmup_steps = 0 84 | 85 | # 8-bit Optimizer 86 | use_8bit_optim = False 87 | 88 | # Gradient checkpointing 89 | gradient_checkpointing = False 90 | 91 | # Adam beta1 92 | adam_beta1 = 0.9 93 | 94 | # Adam beta2 95 | adam_beta2 = 0.999 96 | 97 | # Adam eps 98 | adam_epsilon = 1e-8 99 | 100 | # Weight decay 101 | adam_weight_decay = 1e-2 102 | 103 | # Max gradient norm 104 | max_grad_norm = 1.0 105 | 106 | # Number of epochs 107 | num_epochs = 999999999 108 | 109 | # Max steps 110 | max_train_steps = 0 111 | 112 | # Lr scheduler 113 | lr_scheduler ='constant' 114 | 115 | # Target bandwidth 116 | target_bandwidth=3 117 | 118 | # Channels 119 | channels = 4 120 | 121 | # Train text encoder 122 | train_text_encoder = '' 123 | 124 | # Embedder path 125 | embedder_path = '' 126 | 127 | # Use embedder 128 | use_embedder = True 129 | 130 | # Use text dropout 131 | use_text_dropout = False 132 | 133 | # Text dropout prob 134 | text_dropout_prob = 0.2 135 | 136 | # Shuffle prompts 137 | shuffle_prompts = False 138 | 139 | # Shuffle prompts seperator string 140 | shuffle_prompts_sep = ', ' 141 | 142 | # Shuffle prompts probability 143 | shuffle_prompts_prob = 1.0 -------------------------------------------------------------------------------- /example_launch_command.txt: -------------------------------------------------------------------------------- 1 | accelerate launch ^ 2 | --mixed_precision=bf16 ^ 3 | --num_processes=1 ^ 4 | --num_machines=1 ^ 5 | --dynamo_backend=no ^ 6 | train_latent_cond.py ^ 7 | --ckpt-path=none ^ 8 | --resume-from-checkpoint=none ^ 9 | --name=latent_dance_diffusion ^ 10 | --precision=bf16 ^ 11 | --training-dir=audio_chunks ^ 12 | --sample-size=1310720 ^ 13 | --accum-batches=12 ^ 14 | --batch-size=4 ^ 15 | --demo-every=500 ^ 16 | --checkpoint-every=500 ^ 17 | --num-workers=5 ^ 18 | --num-gpus=1 ^ 19 | --lr=4e-5 ^ 20 | --save-path=models/LatentDanceDiffusion ^ 21 | --num-demos=1 ^ 22 | --target-bandwidth=3 ^ 23 | --gradient-checkpointing=True ^ 24 | --random-crop=True ^ 25 | --use-8bit-optim=True ^ 26 | --use-embedder=True ^ 27 | --use-text-dropout=True ^ 28 | --shuffle-prompts=True ^ 29 | --shuffle-prompts-sep=", " ^ 30 | --shuffle-prompts-prob=1.0 -------------------------------------------------------------------------------- /make_audio_chunks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "from pydub import AudioSegment\n", 11 | "\n", 12 | "def chunk_audio(input_folder, output_folder, chunk_size_samples, sr=24000, spacing=0.5, last_chunk_min_seconds=60, debug=True, errors_=None):\n", 13 | " if not os.path.exists(output_folder):\n", 14 | " os.makedirs(output_folder)\n", 15 | " errors = []\n", 16 | " for filename in 
os.listdir(input_folder):\n", 17 | " try:\n", 18 | " if not filename.endswith('.wav'):\n", 19 | " continue\n", 20 | " filepath = os.path.join(input_folder, filename)\n", 21 | " if errors_ is not None:\n", 22 | " if filepath not in errors_:\n", 23 | " continue\n", 24 | " sound = AudioSegment.from_wav(filepath)\n", 25 | " sound = sound.set_frame_rate(sr)\n", 26 | " if debug:\n", 27 | " print(f'Resampled {filepath} to {sr} Hz')\n", 28 | " \n", 29 | " samples = sound.get_array_of_samples()\n", 30 | " n_samples = len(samples)\n", 31 | " n_chunks = n_samples//chunk_size_samples + (n_samples % chunk_size_samples > 0)\n", 32 | " start = 0\n", 33 | " for i in range(n_chunks):\n", 34 | " end = start + chunk_size_samples\n", 35 | " if i == n_chunks - 1 and sound.duration_seconds < last_chunk_min_seconds:\n", 36 | " start = start + int( (last_chunk_min_seconds*sr - n_samples)/(n_chunks-1) )\n", 37 | " end = start + chunk_size_samples\n", 38 | " if end > n_samples:\n", 39 | " end = n_samples\n", 40 | " chunk = sound._spawn(samples[start:end])\n", 41 | " output_filename = os.path.splitext(filename)[0] + f'_chunk{i}.wav'\n", 42 | " output_filepath = os.path.join(output_folder, output_filename)\n", 43 | " chunk.export(output_filepath, format='wav')\n", 44 | " start = end - int(spacing*chunk_size_samples)\n", 45 | " except Exception as e:\n", 46 | " print(f'Error processing {filepath}: {e}')\n", 47 | " errors.append(filepath)\n", 48 | " return errors" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "input_folder = 'audio'\n", 58 | "output_folder = 'audio_chunks'\n", 59 | "chunk_size_samples = 4000000\n", 60 | "sr = 24000\n", 61 | "norm = False\n", 62 | "spacing = 0.1\n", 63 | "debug = True\n", 64 | "\n", 65 | "errors = chunk_audio(input_folder, output_folder, chunk_size_samples, sr=sr, norm=norm, spacing=spacing, debug=debug)" 66 | ] 67 | } 68 | ], 69 | "metadata": { 70 | "kernelspec": { 71 | "display_name": "Python 3", 72 | "language": "python", 73 | "name": "python3" 74 | }, 75 | "language_info": { 76 | "name": "python", 77 | "version": "3.10.8 (tags/v3.10.8:aaaf517, Oct 11 2022, 16:50:30) [MSC v.1933 64 bit (AMD64)]" 78 | }, 79 | "orig_nbformat": 4, 80 | "vscode": { 81 | "interpreter": { 82 | "hash": "0e022ae142024989f92e98d32365f88ae2215fe7b3636042713ef46624b031fa" 83 | } 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 2 88 | } 89 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='sample-generator', 5 | version='1.0.0', 6 | url='https://github.com/harmonai-org/sample-generator.git', 7 | author='Zach Evans', 8 | packages=find_packages(), 9 | install_requires=[ 10 | 'einops', 11 | 'pandas', 12 | 'prefigure', 13 | 'pytorch_lightning', 14 | 'scipy', 15 | 'torch', 16 | 'torchaudio', 17 | 'tqdm', 18 | 'wandb', 19 | ], 20 | ) 21 | -------------------------------------------------------------------------------- /train_latent_cond.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | import os 4 | import itertools 5 | from tqdm.auto import tqdm 6 | from typing import List 7 | 8 | from prefigure.prefigure import get_all_args 9 | from copy import deepcopy 10 | import math 11 | 12 | from accelerate import Accelerator 13 | from 
accelerate.logging import get_logger 14 | from accelerate.utils import set_seed 15 | 16 | import torch 17 | from torch import nn 18 | from torch.nn import functional as F 19 | from torch.utils import data 20 | 21 | from diffusers.optimization import get_scheduler 22 | 23 | from dataset.dataset import SampleDataset 24 | 25 | from audio_diffusion.models import UNetModel 26 | from audio_diffusion.utils import ema_update 27 | 28 | from encodec import EncodecModel 29 | from encodec.utils import save_audio 30 | 31 | from transformers import T5Tokenizer, T5EncoderModel 32 | 33 | from diffusers.schedulers import ( 34 | DDIMScheduler, 35 | DPMSolverMultistepScheduler, 36 | EulerAncestralDiscreteScheduler, 37 | EulerDiscreteScheduler, 38 | LMSDiscreteScheduler, 39 | PNDMScheduler, 40 | DDPMScheduler 41 | ) 42 | 43 | 44 | logger = get_logger(__name__) 45 | 46 | 47 | class FrozenT5Embedder(nn.Module): 48 | """Uses the T5 transformer encoder for text 49 | 50 | Code from: https://github.com/justinpinkney/stable-diffusion/blob/main/ldm/modules/encoders/modules.py""" 51 | def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl 52 | super().__init__() 53 | self.tokenizer = T5Tokenizer.from_pretrained(version) 54 | self.transformer = T5EncoderModel.from_pretrained(version) 55 | self.device = device 56 | self.max_length = max_length 57 | 58 | def freeze(self): 59 | self.transformer = self.transformer.eval() 60 | for param in self.parameters(): 61 | param.requires_grad = False 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | outputs = self.transformer(input_ids=tokens) 68 | 69 | z = outputs.last_hidden_state 70 | return z 71 | 72 | def encode(self, text): 73 | return self(text) 74 | 75 | 76 | # Denoising loop 77 | def sample(unet, codes, embedder, scheduler, device, num_inference_steps=50, batch_size=1, prompt='kawaii, future bass, edm', negative_prompt=None, do_classifier_free_guidance=True, guidance_scale=7, eta=0.0): 78 | """Code adapted from: https://github.com/huggingface/diffusers/blob/debc74f442dc74210528eb6d8a4d1f7f27fa18c3/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py""" 79 | # prepare timesteps 80 | scheduler.set_timesteps(num_inference_steps, device=device) 81 | timesteps = scheduler.timesteps 82 | 83 | text_embeddings = embedder(prompt) 84 | # get unconditional embeddings for classifier free guidance 85 | if do_classifier_free_guidance: 86 | uncond_tokens: List[str] 87 | if negative_prompt is None: 88 | uncond_tokens = [""] * batch_size 89 | elif type(prompt) is not type(negative_prompt): 90 | raise TypeError( 91 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 92 | f" {type(prompt)}." 93 | ) 94 | elif isinstance(negative_prompt, str): 95 | uncond_tokens = [negative_prompt] 96 | elif batch_size != len(negative_prompt): 97 | raise ValueError( 98 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 99 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 100 | " the batch size of `prompt`." 
101 | ) 102 | else: 103 | uncond_tokens = negative_prompt 104 | 105 | uncond_embeddings = embedder(uncond_tokens) 106 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 107 | 108 | ts = codes.new_ones([codes.shape[0]]) 109 | 110 | with tqdm(total=num_inference_steps): 111 | for i, t in enumerate(timesteps): 112 | # expand the codes if we are doing classifier free guidance 113 | code_model_input = torch.cat([codes] * 2) if do_classifier_free_guidance else codes 114 | code_model_input = scheduler.scale_model_input(code_model_input, t) 115 | 116 | # predict the noise residual 117 | noise_pred = unet(code_model_input, timesteps=t*ts, context=text_embeddings) 118 | 119 | # perform guidance 120 | if do_classifier_free_guidance: 121 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 122 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 123 | 124 | # compute the previous noisy sample x_t -> x_t-1 125 | codes = scheduler.step(noise_pred, t, codes).prev_sample 126 | return codes 127 | 128 | 129 | def target_bandwidth_to_channels(target_bandwidth): 130 | """Maps target bandwidth to number of channels""" 131 | if target_bandwidth == 1.5: 132 | return 2 133 | elif target_bandwidth == 3.0: 134 | return 4 135 | elif target_bandwidth == 6.0: 136 | return 8 137 | elif target_bandwidth == 12.0: 138 | return 16 139 | elif target_bandwidth == 24.0: 140 | return 32 141 | else: 142 | raise ValueError(f"Invalid target bandwidth: {target_bandwidth}") 143 | 144 | 145 | class DiffusionUncond(nn.Module): 146 | def __init__(self, global_args): 147 | super().__init__() 148 | self.unet = UNetModel( 149 | in_channels=global_args.channels, # depends on target bandwidth 150 | out_channels=global_args.channels, # depends on target bandwidth 151 | sample_size=4096, # size of latent codes (change depending on length of audio) 152 | model_channels=320, # hyperparameter to tune 153 | attention_resolutions=[4, 2, 1], # hyperparameter to tune 154 | num_res_blocks=6, # hyperparameter to tune 155 | channel_mult=[ 1, 2, 4, 4 ], # hyperparameter to tune 156 | use_audio_transformer=True, # hyperparameter to tune 157 | use_linear_in_transformer=True, # hyperparameter to tune 158 | transformer_depth=1, # hyperparameter to tune 159 | num_head_channels=64, # hyperparameter to tune 160 | dropout=0.0, # hyperparameter to tune 161 | use_checkpoint=True if global_args.gradient_checkpointing not in [None, False, 'false', '', 'False'] else False, 162 | context_dim=1024, 163 | legacy=False, 164 | dims=1 165 | ) 166 | self.unet_ema = deepcopy(self.unet) 167 | self.encodec = EncodecModel.encodec_model_24khz() 168 | self.encodec.set_target_bandwidth(global_args.target_bandwidth) 169 | 170 | 171 | def main(): 172 | """Code adapted from: https://github.com/Harmonai-org/sample-generator/blob/main/train_uncond.py""" 173 | args = get_all_args() 174 | 175 | save_path = None if args.save_path == "" else args.save_path 176 | 177 | args.channels = target_bandwidth_to_channels(args.target_bandwidth) 178 | 179 | print(f'Using {args.channels} channels for target bandwidth {args.target_bandwidth}') 180 | 181 | accelerator = Accelerator( 182 | gradient_accumulation_steps=args.accum_batches, 183 | mixed_precision=args.precision, 184 | log_with="tensorboard", 185 | logging_dir=save_path, 186 | ) 187 | 188 | if args.scale_lr: 189 | args.lr = ( 190 | args.lr * args.accum_batches * args.batch_size * accelerator.num_gpus 191 | ) 192 | 193 | # taken from stable diffusion v-prediction model defaults, not 
sure if the most optimal, tune as needed 194 | scheduler_config = { 195 | "beta_end": 0.012, 196 | "beta_schedule": "scaled_linear", 197 | "beta_start": 0.00085, 198 | "clip_sample": False, 199 | "num_train_timesteps": 1000, 200 | "prediction_type": "epsilon", # "v_prediction", 201 | "trained_betas": None 202 | } 203 | 204 | model = DiffusionUncond(args) 205 | if args.ckpt_path.lower() not in [None, '', 'none', 'false']: 206 | print(f'Loading checkpoint from {args.ckpt_path}') 207 | model.unet_ema.load_state_dict(torch.load(args.ckpt_path)) 208 | model.unet.load_state_dict(torch.load(args.ckpt_path.replace('-ema', ''))) 209 | print('Loaded checkpoint') 210 | 211 | if args.use_embedder: 212 | # Load the tokenizer 213 | if args.embedder_path not in ['none', 'false', '']: 214 | embedder = FrozenT5Embedder(version=args.embedder_path, device="cuda", max_length=77) 215 | else: 216 | embedder = FrozenT5Embedder(version="google/t5-v1_1-large", device="cuda", max_length=77) 217 | if args.train_text_encoder in ['none', 'false', '', None, False]: 218 | embedder.freeze() 219 | else: 220 | print('Training text encoder') 221 | else: 222 | embedder = None 223 | 224 | train_set = SampleDataset([args.training_dir], args) 225 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, 226 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True, drop_last=True) 227 | 228 | # Load models 229 | unet = model.unet 230 | unet_ema = model.unet_ema 231 | unet_ema.requires_grad_(False) 232 | encodec = model.encodec 233 | encodec.to(accelerator.device) 234 | encodec.requires_grad_(False) 235 | 236 | if args.gradient_checkpointing not in [None, '', 'none', 'false', False]: 237 | print('Enabling gradient checkpointing') 238 | if args.train_text_encoder not in ['none', 'false', '', None, False]: 239 | embedder.gradient_checkpointing_enable() 240 | 241 | # Use 8-bit Adam for lower memory usage 242 | if args.use_8bit_optim: 243 | try: 244 | import bitsandbytes as bnb 245 | except ImportError: 246 | raise ImportError( 247 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 248 | ) 249 | 250 | optimizer_class = bnb.optim.AdamW8bit 251 | print('using 8-bit Adam') 252 | else: 253 | optimizer_class = torch.optim.AdamW 254 | 255 | params_to_optimize = ( 256 | unet.parameters() 257 | ) 258 | optimizer = optimizer_class( 259 | params_to_optimize, 260 | lr=args.lr, 261 | betas=(args.adam_beta1, args.adam_beta2), 262 | weight_decay=args.adam_weight_decay, 263 | eps=args.adam_epsilon, 264 | ) 265 | 266 | if args.max_train_steps == 0 or args.max_train_steps == '': 267 | args.max_train_steps = None 268 | # Scheduler and math around the number of training steps. 
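    # Worked example of the step math below, assuming a dataloader of 10,000 batches,
    # accum_batches=12 and num_epochs=1 (these numbers are illustrative only):
    #   num_update_steps_per_epoch = ceil(10000 / 12) = 834 optimizer steps
    #   max_train_steps            = 1 * 834 = 834 (when max_train_steps is not set explicitly)
    # The lr scheduler is stepped once per micro-batch, which is why its warmup and
    # training step counts are multiplied by accum_batches.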
269 | overrode_max_train_steps = False 270 | num_update_steps_per_epoch = math.ceil(len(train_dl) / args.accum_batches) 271 | if args.max_train_steps is None: 272 | args.max_train_steps = args.num_epochs * num_update_steps_per_epoch 273 | overrode_max_train_steps = True 274 | 275 | lr_scheduler = get_scheduler( 276 | args.lr_scheduler, 277 | optimizer=optimizer, 278 | num_warmup_steps=args.lr_warmup_steps * args.accum_batches, 279 | num_training_steps=args.max_train_steps * args.accum_batches, 280 | ) 281 | 282 | if not args.use_embedder or (args.train_text_encoder in ['none', 'false', ''] and args.use_embedder): 283 | unet, optimizer, train_dl, lr_scheduler = accelerator.prepare( 284 | unet, optimizer, train_dl, lr_scheduler 285 | ) 286 | else: 287 | unet, embedder, optimizer, train_dl, lr_scheduler = accelerator.prepare( 288 | unet, embedder, optimizer, train_dl, lr_scheduler 289 | ) 290 | accelerator.register_for_checkpointing(lr_scheduler) 291 | 292 | weight_dtype = torch.float32 293 | if args.precision == "fp16": 294 | weight_dtype = torch.float16 295 | elif args.precision == "bf16": 296 | weight_dtype = torch.bfloat16 297 | 298 | encodec.to(accelerator.device, dtype=torch.float32) 299 | if args.train_text_encoder in ['none', 'false', ''] and args.use_embedder: 300 | embedder.to(accelerator.device, dtype=weight_dtype) 301 | 302 | set_seed(args.seed) 303 | noise_scheduler = DDPMScheduler(**scheduler_config) 304 | 305 | scheduler_config["set_alpha_to_one"] = False 306 | scheduler_config["steps_offset"] = 1 307 | 308 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 309 | num_update_steps_per_epoch = math.ceil(len(train_dl) / args.accum_batches) 310 | if overrode_max_train_steps: 311 | args.max_train_steps = args.num_epochs * num_update_steps_per_epoch 312 | # Afterwards we recalculate our number of training epochs 313 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 314 | 315 | # We need to initialize the trackers we use, and also store our configuration. 316 | # The trackers initializes automatically on the main process. 317 | if accelerator.is_main_process: 318 | accelerator.init_trackers("latentdancebooth", config=vars(args)) 319 | 320 | # Train! 321 | total_batch_size = args.batch_size * accelerator.num_processes * args.accum_batches 322 | 323 | logger.info("***** Running training *****") 324 | logger.info(f" Num examples = {len(train_set)}") 325 | logger.info(f" Num batches each epoch = {len(train_dl)}") 326 | logger.info(f" Num Epochs = {args.num_epochs}") 327 | logger.info(f" Instantaneous batch size per device = {args.batch_size}") 328 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 329 | logger.info(f" Gradient Accumulation steps = {args.accum_batches}") 330 | logger.info(f" Total optimization steps = {args.max_train_steps}") 331 | global_step = 0 332 | first_epoch = 0 333 | last_demo_step = -1 334 | 335 | if args.resume_from_checkpoint.lower() not in ['', 'none', 'false']: 336 | if args.resume_from_checkpoint != "latest": 337 | path = os.path.basename(args.resume_from_checkpoint) 338 | else: 339 | # Get the mos recent checkpoint 340 | dirs = os.listdir(args.save_path) 341 | dirs = [d for d in dirs if d.startswith("checkpoint")] 342 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 343 | path = dirs[-1] 344 | accelerator.print(f"Resuming from checkpoint {path}") 345 | accelerator.load_state(os.path.join(args.save_path, path)) 346 | global_step = int(path.split("-")[1]) 347 | 348 | resume_global_step = global_step * args.accum_batches 349 | first_epoch = resume_global_step // num_update_steps_per_epoch 350 | resume_step = resume_global_step % num_update_steps_per_epoch 351 | 352 | # Only show the progress bar once on each machine. 353 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 354 | progress_bar.set_description("Steps") 355 | 356 | for epoch in range(first_epoch, args.num_train_epochs): 357 | unet.train() 358 | if args.use_embedder and args.train_text_encoder not in ['none', 'false', '']: 359 | embedder.train() 360 | for step, batch in enumerate(train_dl): 361 | # Skip steps until we reach the resumed step 362 | if args.resume_from_checkpoint.lower() not in ['', 'none', 'false'] and epoch == first_epoch and step < resume_step: 363 | if step % args.accum_batches == 0: 364 | progress_bar.update(1) 365 | continue 366 | 367 | with accelerator.accumulate(unet): 368 | # move audio to discrete codes 369 | encoded_frames = encodec.encode(batch[0]) 370 | 371 | codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).to(accelerator.device, dtype=weight_dtype) 372 | 373 | # scale from 0 - 1023 to -1 to 1 374 | codes = (codes / 511.5) - 1 375 | codes = torch.clamp(codes, -1., 1.) 
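                # Worked example of the rescaling above: Encodec code indices lie in [0, 1023],
                # so codes / 511.5 - 1 maps 0 -> -1.0, 511 -> ~-0.001 and 1023 -> +1.0.
                # The demo/sampling path below inverts this with ((x + 1) * 511.5).to(torch.long)
                # followed by clamp(0, 1023) before handing the codes back to Encodec's decoder.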
376 | 377 | # Sample noise that we'll add to the codes 378 | noise = torch.randn_like(codes) 379 | bsz = codes.shape[0] 380 | 381 | # Sample a random timestep for each audio sample 382 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=codes.device) 383 | timesteps = timesteps.long() 384 | 385 | # Add noise to the codes according to the noise magnitude at each timestep 386 | # (this is the forward diffusion process) 387 | noisy_codes = noise_scheduler.add_noise(codes, noise, timesteps) 388 | 389 | if args.use_embedder: 390 | # Get the text conditioning 391 | input_ids = embedder(batch[1]) 392 | 393 | noise_pred = unet(noisy_codes, timesteps=timesteps, context=input_ids) 394 | else: 395 | noise_pred = unet(noisy_codes, timesteps=timesteps) 396 | 397 | # Get the target for loss depending on the prediction type 398 | if noise_scheduler.config.prediction_type == "epsilon": 399 | target = noise 400 | elif noise_scheduler.config.prediction_type == "v_prediction": 401 | target = noise_scheduler.get_velocity(codes, noise, timesteps) 402 | else: 403 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 404 | 405 | loss = F.mse_loss(noise_pred.float(), target.float(), reduction='mean') 406 | 407 | accelerator.backward(loss) 408 | if accelerator.sync_gradients: 409 | params_to_clip = ( 410 | itertools.chain(unet.parameters(), embedder.parameters()) 411 | if args.use_embedder and args.train_text_encoder not in ['none', 'false', ''] else 412 | unet.parameters() 413 | ) 414 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 415 | optimizer.step() 416 | ema_update(unet, unet_ema, args.ema_decay) 417 | lr_scheduler.step() 418 | optimizer.zero_grad() 419 | 420 | # Checks if the accelerator has performed an optimization step behind the scenes 421 | if accelerator.sync_gradients: 422 | progress_bar.update(1) 423 | global_step += 1 424 | 425 | if global_step % args.checkpoint_every == 0: 426 | if accelerator.is_main_process: 427 | save_path = os.path.join(args.save_path, f"checkpoint-{global_step}") 428 | accelerator.save_state(save_path) 429 | logger.info(f"Saved state to {save_path}") 430 | torch.save(unet_ema.state_dict(), os.path.join(args.save_path, f"_checkpoint-ema-{global_step}.pkl")) 431 | torch.save(unet.state_dict(), os.path.join(args.save_path, f"_checkpoint-{global_step}.pkl")) 432 | if args.train_text_encoder not in ['none', 'false', '']: 433 | torch.save(embedder.state_dict(), os.path.join(args.save_path, f"_checkpoint-t5-{global_step}.pkl")) 434 | 435 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 436 | progress_bar.set_postfix(**logs) 437 | accelerator.log(logs, step=global_step) 438 | 439 | if global_step >= args.max_train_steps: 440 | break 441 | 442 | if (global_step - 1) % args.demo_every != 0 or last_demo_step == global_step: 443 | continue 444 | 445 | if accelerator.is_main_process: 446 | with torch.no_grad(): 447 | last_demo_step = global_step 448 | 449 | noise = torch.randn([args.num_demos, args.channels, 4096]).to(accelerator.device, dtype=weight_dtype) 450 | 451 | try: 452 | if args.use_embedder: 453 | fakes = sample(unet, noise, embedder, DDIMScheduler(**scheduler_config), accelerator.device, num_inference_steps=args.demo_steps) 454 | else: 455 | fakes = sample(unet, noise, None, DDIMScheduler(**scheduler_config), accelerator.device, num_inference_steps=args.demo_steps) 456 | 457 | # scale from -1 to 1 to 0 - 1023 and discretize 458 | fakes = ((fakes + 1) * 
511.5).to(torch.long) 459 | fakes = fakes.clamp(0, 1023) 460 | 461 | # decode 462 | decoded_frames = encodec.decode([(fakes, None)]) 463 | 464 | # save demos 465 | filename = f'demo_{global_step:08}.wav' 466 | for i, audio in enumerate(decoded_frames): 467 | if i > 0: 468 | save_audio(audio.cpu(), f"{filename[:-4]}_{i}.wav", encodec.sample_rate) 469 | else: 470 | save_audio(audio.cpu(), filename, encodec.sample_rate) 471 | except Exception as e: 472 | print(f'{type(e).__name__}: {e}', file=sys.stderr) 473 | 474 | accelerator.wait_for_everyone() 475 | 476 | if accelerator.is_main_process: 477 | accelerator.save(unet.state_dict(), f"checkpoint-{global_step}.pkl") 478 | 479 | if __name__ == '__main__': 480 | main() 481 | -------------------------------------------------------------------------------- /training_args.md: -------------------------------------------------------------------------------- 1 | training_dir - training data directory 2 | 3 | batch_size - the batch size per device 4 | 5 | num_gpus - number of GPUs to use for training 6 | 7 | num_nodes - number of nodes to use for training 8 | 9 | num_workers - number of CPU workers for the DataLoader 10 | 11 | sample_size - number of audio samples for the training input (uncompressed 24kHz audio) 12 | 13 | demo_every - number of steps between demos 14 | 15 | demo_steps - number of denoising steps for the demos 16 | 17 | num_demos - number of demos to create 18 | 19 | ema_decay - the EMA decay 20 | 21 | seed - the random seed 22 | 23 | accum_batches - number of batches for gradient accumulation 24 | 25 | checkpoint_every - number of steps between checkpoints 26 | 27 | cache_training_data - if true training data is kept in RAM 28 | 29 | random_crop - if true randomly crop input audio (for augmentation) 30 | 31 | ckpt_path - checkpoint file to (re)start training from (point this at the EMA weights; the non-EMA weights are loaded automatically alongside them) 32 | 33 | save_path - path to output the model checkpoints 34 | 35 | resume_from_checkpoint - resume training from checkpoint 36 | 37 | precision - what precision to use for training 38 | 39 | lr - learning rate 40 | 41 | scale_lr - whether or not to scale the learning rate (lr * accum_batches * batch_size * num_gpus) 42 | 43 | lr_warmup_steps - learning rate warmup steps 44 | 45 | use_8bit_optim - if true use 8-bit optimizer 46 | 47 | gradient_checkpointing - if true use gradient checkpointing 48 | 49 | adam_beta1 - adam beta1 50 | 51 | adam_beta2 - adam beta2 52 | 53 | adam_epsilon - adam eps 54 | 55 | adam_weight_decay - adam weight decay 56 | 57 | max_grad_norm - max gradient norm 58 | 59 | num_epochs - number of epochs until training is finished 60 | 61 | max_train_steps - maximum number of training steps until training is finished 62 | 63 | lr_scheduler - what learning rate scheduler to use 64 | 65 | target_bandwidth - target bandwidth for Encodec's compression 66 | 67 | train_text_encoder - if true train the text encoder 68 | 69 | embedder_path - path (or Hugging Face model name) of the T5 text encoder 70 | 71 | use_text_dropout - if true, randomly replace the prompt with '' based on text_dropout_prob 72 | 73 | text_dropout_prob - chance that the prompt will be dropped 74 | 75 | shuffle_prompts - if true, shuffle the prompt segments (split on the separator string shuffle_prompts_sep) with probability shuffle_prompts_prob 76 | 77 | shuffle_prompts_sep - string to separate the prompt by 78 | 79 | shuffle_prompts_prob - probability that the prompt will be shuffled -------------------------------------------------------------------------------- /utils/patch_bnb.py:
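As a usage note for the training script above: the sample() helper plus Encodec's decoder are all that is needed for offline generation. Below is a hypothetical standalone sketch adapted from the demo path in train_latent_cond.py; the checkpoint path, prompt and step count are assumptions, not part of the repo.

```python
# Hypothetical standalone sampling sketch, mirroring the demo path in train_latent_cond.py.
import types
import torch
from diffusers.schedulers import DDIMScheduler
from encodec.utils import save_audio

from train_latent_cond import DiffusionUncond, FrozenT5Embedder, sample

device = "cuda"
# target_bandwidth=3.0 -> 4 codebook channels (see target_bandwidth_to_channels above)
args = types.SimpleNamespace(channels=4, target_bandwidth=3.0, gradient_checkpointing=False)

model = DiffusionUncond(args)
# assumed checkpoint path, following the "_checkpoint-ema-{step}.pkl" naming used at save time
model.unet_ema.load_state_dict(torch.load("models/LatentDanceDiffusion/_checkpoint-ema-10000.pkl"))
unet = model.unet_ema.to(device).eval()
encodec = model.encodec.to(device)

embedder = FrozenT5Embedder(version="google/t5-v1_1-large", device=device, max_length=77)
embedder.freeze()
embedder.to(device)

# same config the training script passes to its demo DDIMScheduler
scheduler = DDIMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
    clip_sample=False, num_train_timesteps=1000, prediction_type="epsilon",
    set_alpha_to_one=False, steps_offset=1,
)

with torch.no_grad():
    noise = torch.randn([1, args.channels, 4096], device=device)
    codes = sample(unet, noise, embedder, scheduler, device,
                   num_inference_steps=250, prompt="kawaii, future bass, edm")
    # map the denoised codes from [-1, 1] back to integer codebook indices and decode
    codes = ((codes + 1) * 511.5).to(torch.long).clamp(0, 1023)
    audio = encodec.decode([(codes, None)])[0]
    save_audio(audio.cpu(), "demo.wav", encodec.sample_rate)
```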
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright [2022] Victor C Hall 3 | Licensed under the GNU Affero General Public License; 4 | You may not use this code except in compliance with the License. 5 | You may obtain a copy of the License at 6 | https://www.gnu.org/licenses/agpl-3.0.en.html 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | # see: https://github.com/TimDettmers/bitsandbytes/issues/30 for explanation 15 | import sys 16 | import os 17 | from subprocess import check_output 18 | import shutil 19 | 20 | _CEXT_PATCH = " self.lib = ct.cdll.LoadLibrary(str(binary_path))" 21 | _MAIN_PATCH = " return 'libbitsandbytes_cuda116.dll'" 22 | 23 | def patch_main(): 24 | bnbpath_main = "venv/Lib/site-packages/bitsandbytes/cuda_setup/main.py" 25 | try: 26 | with open(bnbpath_main, "r") as f: 27 | contents = f.read() 28 | contents = contents.split('\n') 29 | except Exception as ex: 30 | print(f"cannot find bitsandbytes install, aborting, error: {ex}") 31 | return False 32 | 33 | main_patched = False 34 | 35 | for i, line in enumerate(contents): 36 | if i == 112: 37 | if line != _MAIN_PATCH: 38 | contents[i] = _MAIN_PATCH 39 | main_patched = True 40 | else: 41 | print(" *** Already patched!") 42 | main_patched = True 43 | 44 | assert main_patched, "unable to patch bitsandbytes, may be mismatched version, requires 0.35.0" 45 | 46 | with open(bnbpath_main, "w") as f: 47 | for line in contents: 48 | f.write(line + "\n") 49 | #print(contents) 50 | 51 | return main_patched 52 | 53 | def patch_cext(): 54 | bnbpath_cextension = "venv/Lib/site-packages/bitsandbytes/cextension.py" 55 | try: 56 | with open(bnbpath_cextension, "r") as f: 57 | contents = f.read() 58 | contents = contents.split('\n') 59 | except Exception as ex: 60 | print(f"cannot find bitsandbytes install, aborting, error: {ex}") 61 | return False 62 | 63 | cext_patched = False 64 | 65 | for i, line in enumerate(contents): 66 | # update both lines 28 and 31 to be sure correct dll is returned 67 | if (i == 30 or i == 27): 68 | if line != _CEXT_PATCH: 69 | contents[i] = _CEXT_PATCH 70 | cext_patched = True 71 | else: 72 | cext_patched = True 73 | 74 | assert cext_patched, "unable to patch bitsandbytes, died midprocess, something broke and may need to reinstall bitsandbytes==0.35.0" 75 | 76 | with open(bnbpath_cextension, "w") as f: 77 | for line in contents: 78 | f.write(line + "\n") 79 | #print(contents) 80 | 81 | return cext_patched 82 | 83 | def iswindows(): 84 | return sys.platform.startswith('win') 85 | 86 | def error(): 87 | print("Somethnig went wrong trying to patch bitsandbytes, aborting") 88 | print("make sure your venv is activated and try again") 89 | print("or if activated try: ") 90 | print(" pip install bitsandbytes==0.35.0") 91 | raise RuntimeError("** FATAL ERROR: unable to patch bitsandbytes for Windows env") 92 | 93 | def check_dlls(): 94 | dll_exists = os.path.exists("venv/Lib/site-packages/bitsandbytes/libbitsandbytes_cuda116.dll") 95 | if not dll_exists: 96 | if not os.path.exists("tmp/bnb_cache"): 97 | check_output("git clone https://github.com/DeXtmL/bitsandbytes-win-prebuilt tmp/bnb_cache", shell=True) 98 | shutil.copy("tmp/bnb_cache/libbitsandbytes_cuda116.dll", 
"venv/Lib/site-packages/bitsandbytes/libbitsandbytes_cuda116.dll") 99 | dll_exists = os.path.exists("venv/Lib/site-packages/bitsandbytes/libbitsandbytes_cuda116.dll") 100 | return dll_exists 101 | 102 | def main(): 103 | """ 104 | applies a patch for windows compatibility for bitsandbytes 0.35.0 for using their AdamW8bit optimizer 105 | """ 106 | if iswindows(): 107 | print() 108 | print(" *** Applying bitsandbytes patch for windows ***") 109 | if not check_dlls(): 110 | print("unable to find bitsandbytes dll or clone them from git, aborting") 111 | raise RuntimeError("** FATAL ERROR: unable to patch bitsandbytes for Windows env") 112 | 113 | main_patched = patch_main() 114 | cext_patched = patch_cext() 115 | if main_patched and cext_patched: 116 | try: 117 | print(" *************************************************************") 118 | print(" *** bitsandbytes windows patch applied, attempting import *** ") 119 | import bitsandbytes 120 | print(f" *** bitsandbytes patch succeeded, everything looks good! ***") 121 | except: 122 | error() 123 | else: 124 | error() 125 | else: 126 | print(" *** not using windows environment, skipping bitsandbytes patch ***") 127 | return 128 | 129 | if __name__ == "__main__": 130 | main() -------------------------------------------------------------------------------- /viz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/serp-ai/ai-text-to-audio-latent-diffusion/55230601b4f34b30bc52568f58619c8c33b3e202/viz/__init__.py -------------------------------------------------------------------------------- /viz/viz.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from pathlib import Path 4 | from matplotlib.backends.backend_agg import FigureCanvasAgg 5 | import matplotlib.cm as cm 6 | import matplotlib.pyplot as plt 7 | from matplotlib.colors import Normalize 8 | from matplotlib.figure import Figure 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import torch 13 | from torch import optim, nn 14 | from torch.nn import functional as F 15 | import torchaudio 16 | import torchaudio.transforms as T 17 | import librosa 18 | from einops import rearrange 19 | 20 | import wandb 21 | import numpy as np 22 | import pandas as pd 23 | 24 | def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None, db_range=[35,120]): 25 | """ 26 | # cf. https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html 27 | 28 | """ 29 | fig = Figure(figsize=(5, 4), dpi=100) 30 | canvas = FigureCanvasAgg(fig) 31 | axs = fig.add_subplot() 32 | axs.set_title(title or 'Spectrogram (db)') 33 | axs.set_ylabel(ylabel) 34 | axs.set_xlabel('frame') 35 | im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect, vmin=db_range[0], vmax=db_range[1]) 36 | if xmax: 37 | axs.set_xlim((0, xmax)) 38 | fig.colorbar(im, ax=axs) 39 | canvas.draw() 40 | rgba = np.asarray(canvas.buffer_rgba()) 41 | return Image.fromarray(rgba) 42 | 43 | 44 | def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000): 45 | """ 46 | # cf. 
https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html 47 | """ 48 | n_fft = 1024 49 | win_length = None 50 | hop_length = 512 51 | n_mels = 80 52 | 53 | mel_spectrogram_op = T.MelSpectrogram( 54 | sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, 55 | hop_length=hop_length, center=True, pad_mode="reflect", power=power, 56 | norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk") 57 | 58 | melspec = mel_spectrogram_op(waveform.float()) 59 | melspec = melspec[0] # TODO: only left channel for now 60 | return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)') 61 | -------------------------------------------------------------------------------- /windows_setup.cmd: -------------------------------------------------------------------------------- 1 | python -m venv venv 2 | call "venv\Scripts\activate.bat" 3 | echo should be in venv here 4 | cd . 5 | python -m pip install --upgrade pip 6 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" 7 | pip install transformers==4.25.1 8 | pip install diffusers[torch]==0.10.2 9 | pip install pynvml==11.4.1 10 | pip install bitsandbytes==0.35.0 11 | git clone https://github.com/DeXtmL/bitsandbytes-win-prebuilt tmp/bnb_cache 12 | pip install ftfy==6.1.1 13 | pip install aiohttp==3.8.3 14 | pip install "tensorboard>=2.11.0" 15 | pip install protobuf==3.20.1 16 | pip install wandb==0.13.6 17 | pip install pyre-extensions==0.0.23 18 | pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl 19 | ::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall 20 | pip install pytorch-lightning==1.6.5 21 | pip install OmegaConf==2.2.3 22 | pip install numpy==1.23.5 23 | pip install einops pandas prefigure scipy tqdm pydub encodec 24 | python utils/patch_bnb.py 25 | GOTO :eof 26 | 27 | :ERROR 28 | echo Something blew up. Make sure Python 3.10.x is installed and in your PATH. 29 | 30 | :eof 31 | ECHO done 32 | pause --------------------------------------------------------------------------------
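Finally, a small hypothetical example of using viz.audio_spectrogram_image to inspect one of the saved demos (the filename is an assumption based on the demo_{step:08}.wav pattern in the training script):

```python
import torchaudio
from viz.viz import audio_spectrogram_image

wav, sr = torchaudio.load("demo_00000500.wav")      # assumed demo output from a training run
img = audio_spectrogram_image(wav, sample_rate=sr)  # returns a PIL.Image of the mel spectrogram
img.save("demo_00000500_melspec.png")
```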