├── .gitmodules
├── blocks
│   ├── __init__.py
│   ├── embedders.py
│   ├── residual_vq.py
│   └── utils.py
├── losses
│   ├── __init__.py
│   ├── time_losses.py
│   ├── perceptual_losses.py
│   ├── freq_losses.py
│   └── adv_losses.py
├── test
│   ├── __init__.py
│   ├── profiler.py
│   ├── pqmf_test.py
│   └── utils_test.py
├── dataset
│   └── __init__.py
├── decoders
│   └── __init__.py
├── diffusion
│   ├── __init__.py
│   ├── model.py
│   ├── sampling.py
│   ├── xunet.py
│   ├── transformers.py
│   ├── inference.py
│   └── utils.py
├── effects
│   ├── __init__.py
│   └── tcn.py
├── encoders
│   ├── __init__.py
│   ├── t5.py
│   ├── wavelets.py
│   ├── gabor_filter.py
│   ├── perceiver_resampler.py
│   └── encoders.py
├── prompts
│   ├── __init__.py
│   └── prompters.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
├── viz
│   ├── __init__.py
│   ├── viz.py.bak
│   └── viz.py
├── README.md
├── autoencoders
│   ├── __init__.py
│   └── soundstream.py
├── audio_diffusion.egg-info
│   ├── dependency_links.txt
│   ├── top_level.txt
│   ├── requires.txt
│   ├── PKG-INFO
│   └── SOURCES.txt
├── model_configs
│   ├── k_audio_uncond_1.json
│   ├── k_audio_uncond_long_1.json
│   ├── k_audio_diffae_avg_1.json
│   ├── k_audio_diffae_no_attn_1.json
│   ├── k_audio_diffae_1.json
│   ├── k_audio_diffae_long.json
│   └── k_audio_diffae_2.json
├── script_configs
│   ├── generate_defaults.ini
│   ├── clap_scores_defaults.ini
│   ├── vocoder_defaults.ini
│   ├── ae_defaults.ini
│   └── latent_defaults.ini
├── setup.py
├── train_archive
│   ├── train_k_defaults.ini
│   ├── train_uncond_k.py
│   ├── train_dvae_k.py
│   ├── train_ad_uncond_full.py
│   └── train_ad_upsampler.py
├── clap_test.py
├── LICENSE
├── convert_mp3_to_wav.py
├── .gitignore
├── defaults.ini
├── prune_latent_uncond.py
├── pca_analysis.py
├── train_clap_duration_predictor.py
├── test_encodec.py
├── train_local_transformer_clap_encoder.py
├── test_data_loader.py
├── prune_ckpt.py
├── train_wavelet_transformer.py
└── prune_clap_ckpt.py
/.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /blocks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /decoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /effects/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /prompts/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viz/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # audio-diffusion -------------------------------------------------------------------------------- /autoencoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audio_diffusion.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /audio_diffusion.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | byol 2 | dataset 3 | diffusion 4 | dvae 5 | encoders 6 | test 7 | -------------------------------------------------------------------------------- /audio_diffusion.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | auraloss 2 | einops 3 | fairscale 4 | nwt-pytorch 5 | perceiver-pytorch 6 | pytorch_lightning 7 | torch 8 | torchaudio 9 | vector-quantize-pytorch 10 | wandb 11 | -------------------------------------------------------------------------------- /audio_diffusion.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: audio-diffusion 3 | Version: 1.0.0 4 | Summary: UNKNOWN 5 | Home-page: https://github.com/zqevans/audio-diffusion.git 6 | Author: Zach Evans 7 | License: UNKNOWN 8 | Platform: UNKNOWN 9 | License-File: LICENSE 10 | 11 | UNKNOWN 12 | 13 | -------------------------------------------------------------------------------- /test/profiler.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | class Profiler: 4 | 5 | def __init__(self): 6 | self.ticks = [[time(), None]] 7 | 8 | def tick(self, msg): 9 | self.ticks.append([time(), msg]) 10 | 11 | def __repr__(self): 12 | rep = 80 * "=" + "\n" 13 | for i in range(1, len(self.ticks)): 14 | msg = self.ticks[i][1] 15 | ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] 16 | rep += msg + f": {ellapsed*1000:.2f}ms\n" 17 | rep += 80 * "=" + "\n\n\n" 18 | return rep -------------------------------------------------------------------------------- /model_configs/k_audio_uncond_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 65536, 5 | "pqmf_bands": 4, 6 | "mapping_out": 128, 7 | "depths": [2, 3, 3, 3, 3, 3, 4, 4], 8 | "strides": [4, 4, 4, 2, 2, 2, 2], 9 | "channels": [128, 256, 512, 512, 512, 512, 512, 512], 10 | "self_attn_depths": [false, false, false, false, false, true, true, true], 11 | "dropout_rate": 0.0, 12 | "ema_decay":0.995, 13 | "sigma_data": 0.25, 14 | "sigma_min": 1e-2, 15 | "sigma_max": 3, 16 | "sigma_sample_density": { 17 | "type": "lognormal", 18 | "mean": -1.2, 
19 | "std": 1.2 20 | } 21 | } -------------------------------------------------------------------------------- /model_configs/k_audio_uncond_long_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 524288, 5 | "sample_rate": 12000, 6 | "pqmf_bands": 16, 7 | "mapping_out": 128, 8 | "depths": [2, 2, 2, 3, 3, 3, 4, 4], 9 | "strides": [4, 4, 4, 2, 2, 2, 2], 10 | "channels": [256, 256, 512, 512, 512, 512, 512, 512], 11 | "self_attn_depths": [false, false, false, false, false, true, true, true], 12 | "dropout_rate": 0.0, 13 | "ema_decay":0.995, 14 | "sigma_data": 0.25, 15 | "sigma_min": 1e-2, 16 | "sigma_max": 3, 17 | "sigma_sample_density": { 18 | "type": "lognormal", 19 | "mean": -1.2, 20 | "std": 1.2 21 | } 22 | } -------------------------------------------------------------------------------- /losses/time_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusion.pqmf import CachedPQMF as PQMF 4 | 5 | #Might work, but incredibly CPU-intensive 6 | 7 | class MultiScalePQMFLoss(torch.nn.Module): 8 | def __init__(self,n_channels = 2, attenuation=70, band_counts=[8, 16, 32]): 9 | super(MultiScalePQMFLoss, self).__init__() 10 | self.pqmfs = [] 11 | 12 | for band_count in band_counts: 13 | self.pqmfs.append(PQMF(n_channels, attenuation, band_count)) 14 | 15 | def forward(self, input, target): 16 | loss = 0.0 17 | for pqmf in self.pqmfs: 18 | input_pqmf = pqmf(input.float().cpu()) 19 | target_pqmf = pqmf(target.float().cpu()) 20 | loss += torch.square(input_pqmf - target_pqmf).mean() 21 | 22 | return loss -------------------------------------------------------------------------------- /model_configs/k_audio_diffae_avg_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 65536, 5 | "pqmf_bands": 1, 6 | "encoder_base_channels": 32, 7 | "encoder_c_mults": [2, 4, 8, 16, 32], 8 | "encoder_strides": [4, 4, 2, 2, 2], 9 | "local_latent_dim": 64, 10 | "mapping_out": 128, 11 | "depths": [2, 3, 3, 3, 3, 3, 4, 4], 12 | "strides": [4, 4, 4, 2, 2, 2, 2], 13 | "channels": [256, 256, 512, 512, 512, 512, 512, 512], 14 | "self_attn_depths": [false, false, false, false, false, true, true, true], 15 | "dropout_rate": 0.0, 16 | "ema_decay":0.995, 17 | "sigma_data": 0.25, 18 | "sigma_min": 1e-2, 19 | "sigma_max": 3, 20 | "sigma_sample_density": { 21 | "type": "lognormal", 22 | "mean": -1.2, 23 | "std": 1.2 24 | } 25 | } -------------------------------------------------------------------------------- /script_configs/generate_defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | # Number of audio samples for the training input 5 | sample_size = 2097152 6 | 7 | # Number of audio samples to make 8 | batch_size = 4 9 | 10 | # eta for DDIM sampling 11 | eta = 0.0 12 | 13 | # classifier-free guidance scale 14 | cfg_scale = 5.0 15 | 16 | # number of text-conditioned model steps 17 | gen_steps = 100 18 | 19 | # number of diffusion decoder steps 20 | decoder_steps = 15 21 | 22 | # the random seed 23 | seed = -1 24 | 25 | # The sample rate of the audio 26 | sample_rate = 48000 27 | 28 | # checkpoint file to (re)start training from 29 | ckpt_path = '' 30 | 31 | # directory to save the checkpoints in 32 | save_dir = './generated' 33 | 34 | # Path to a pretrained CLAP checkpoint 35 
| clap_ckpt_path = '' 36 | 37 | clap_fusion = False 38 | 39 | clap_amodel='HTSAT-tiny' -------------------------------------------------------------------------------- /model_configs/k_audio_diffae_no_attn_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 65536, 5 | "pqmf_bands": 1, 6 | "encoder_base_channels": 32, 7 | "encoder_c_mults": [2, 4, 8, 16, 32], 8 | "encoder_strides": [4, 4, 2, 2, 2], 9 | "local_latent_dim": 64, 10 | "global_latent_dim": 128, 11 | "mapping_out": 128, 12 | "depths": [2, 3, 3, 3, 3, 3, 4], 13 | "strides": [4, 4, 4, 2, 2, 2], 14 | "channels": [256, 256, 512, 512, 512, 512, 512], 15 | "self_attn_depths": [false, false, false, false, false, false, false], 16 | "dropout_rate": 0.0, 17 | "ema_decay":0.995, 18 | "sigma_data": 0.25, 19 | "sigma_min": 1e-2, 20 | "sigma_max": 3, 21 | "sigma_sample_density": { 22 | "type": "lognormal", 23 | "mean": -1.2, 24 | "std": 1.2 25 | } 26 | } -------------------------------------------------------------------------------- /script_configs/clap_scores_defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = test-latent-diffae 6 | 7 | # the batch size 8 | batch_size = 8 9 | 10 | # number of GPUs to use for training 11 | num_gpus = 1 12 | 13 | # number of nodes to use for training 14 | num_nodes = 1 15 | 16 | # number of CPU workers for the DataLoader 17 | num_workers = 12 18 | 19 | # Number of audio samples for the training input 20 | sample_size = 65536 21 | 22 | # the random seed 23 | seed = 42 24 | 25 | # The sample rate of the audio 26 | sample_rate = 48000 27 | 28 | # randomly crop input audio? 
(for augmentation) 29 | random_crop = True 30 | 31 | # directory to save CLAP scores 32 | save_dir = '' 33 | 34 | # Checkpoint for a pre-trained CLAP model 35 | clap_ckpt_path='' 36 | 37 | clap_fusion=False 38 | 39 | clap_amodel='HTSAT-tiny' -------------------------------------------------------------------------------- /model_configs/k_audio_diffae_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 65536, 5 | "pqmf_bands": 1, 6 | "encoder_base_channels": 32, 7 | "encoder_c_mults": [2, 4, 8, 16, 32], 8 | "encoder_strides": [4, 4, 2, 2, 2], 9 | "local_latent_dim": 64, 10 | "global_latent_dim": 128, 11 | "mapping_out": 128, 12 | "depths": [2, 3, 3, 3, 3, 3, 4, 4], 13 | "strides": [4, 4, 4, 2, 2, 2, 2], 14 | "channels": [256, 256, 512, 512, 512, 512, 512, 512], 15 | "self_attn_depths": [false, false, false, false, false, true, true, true], 16 | "dropout_rate": 0.0, 17 | "ema_decay":0.995, 18 | "sigma_data": 0.25, 19 | "sigma_min": 1e-2, 20 | "sigma_max": 3, 21 | "sigma_sample_density": { 22 | "type": "lognormal", 23 | "mean": -1.2, 24 | "std": 1.2 25 | } 26 | } -------------------------------------------------------------------------------- /model_configs/k_audio_diffae_long.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 65536, 5 | "pqmf_bands": 1, 6 | "encoder_base_channels": 16, 7 | "encoder_c_mults": [2, 4, 8, 16, 32, 32, 32, 32], 8 | "encoder_strides": [4, 4, 4, 2, 2, 2, 2, 2], 9 | "local_latent_dim": 32, 10 | "global_latent_dim": 128, 11 | "mapping_out": 128, 12 | "depths": [2, 3, 3, 3, 3, 3, 4, 4], 13 | "strides": [4, 4, 4, 2, 2, 2, 2], 14 | "channels": [256, 256, 256, 256, 256, 256, 256, 256], 15 | "self_attn_depths": [false, false, false, false, false, false, false, false], 16 | "dropout_rate": 0.0, 17 | "ema_decay":0.995, 18 | "sigma_data": 0.25, 19 | "sigma_min": 1e-2, 20 | "sigma_max": 3, 21 | "sigma_sample_density": { 22 | "type": "lognormal", 23 | "mean": -1.2, 24 | "std": 1.2 25 | } 26 | } -------------------------------------------------------------------------------- /audio_diffusion.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.py 4 | audio_diffusion.egg-info/PKG-INFO 5 | audio_diffusion.egg-info/SOURCES.txt 6 | audio_diffusion.egg-info/dependency_links.txt 7 | audio_diffusion.egg-info/requires.txt 8 | audio_diffusion.egg-info/top_level.txt 9 | byol/__init__.py 10 | byol/byol_pytorch.py 11 | dataset/__init__.py 12 | dataset/dataloader.py 13 | dataset/dataset.py 14 | diffusion/__init__.py 15 | diffusion/inference.py 16 | diffusion/model.py 17 | diffusion/pqmf.py 18 | diffusion/utils.py 19 | diffusion/FastDiff/FastDiff_model.py 20 | diffusion/FastDiff/__init__.py 21 | diffusion/FastDiff/modules.py 22 | diffusion/FastDiff/util.py 23 | dvae/__init__.py 24 | dvae/dvae.py 25 | dvae/residual_memcodes.py 26 | encoders/__init__.py 27 | encoders/encoders.py 28 | encoders/learner.py 29 | encoders/losses.py 30 | test/__init__.py 31 | test/pqmf_test.py 32 | test/utils_test.py -------------------------------------------------------------------------------- /model_configs/k_audio_diffae_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "audio_v1", 3 | "input_channels": 2, 4 | "input_size": 65536, 5 | "pqmf_bands": 1, 6 | 
"encoder_base_channels": 32, 7 | "encoder_c_mults": [2, 4, 8, 16, 32, 32, 32, 32, 32], 8 | "encoder_strides": [2, 2, 2, 2, 2, 2, 2, 2, 2], 9 | "local_latent_dim": 64, 10 | "global_latent_dim": 128, 11 | "mapping_out": 128, 12 | "depths": [2, 2, 3, 3, 3, 3, 3, 4, 4], 13 | "strides": [4, 4, 2, 2, 2, 2, 2, 2], 14 | "channels": [256, 256, 512, 512, 512, 512, 512, 512, 512], 15 | "self_attn_depths": [false, false, false, false, false, true, true, true, true], 16 | "dropout_rate": 0.0, 17 | "ema_decay":0.995, 18 | "sigma_data": 0.25, 19 | "sigma_min": 1e-2, 20 | "sigma_max": 3, 21 | "sigma_sample_density": { 22 | "type": "lognormal", 23 | "mean": -1.2, 24 | "std": 1.2 25 | } 26 | } -------------------------------------------------------------------------------- /test/pqmf_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from diffusion.pqmf import CachedPQMF as PQMF 4 | 5 | import unittest 6 | 7 | class TestPqmf(unittest.TestCase): 8 | 9 | def test_pqmf_shapes_equal(self): 10 | signal = torch.randn([1, 1, 131072]) 11 | pqmf = PQMF(1, 70, 32) 12 | encoded = pqmf(signal) 13 | decoded = pqmf.inverse(encoded) 14 | 15 | #the inverse has the same shape as the original 16 | self.assertEqual(list(signal.shape), list(decoded.shape)) 17 | 18 | def test_pqmf_stereo_shapes(self): 19 | signal = torch.randn([1, 2, 131072]) 20 | pqmf = PQMF(2, 70, 32) 21 | encoded = pqmf(signal) 22 | print(encoded.shape) 23 | decoded = pqmf.inverse(encoded) 24 | 25 | #the inverse has the same shape as the original 26 | self.assertEqual(list(signal.shape), list(decoded.shape)) 27 | 28 | 29 | if __name__ == '__main__': 30 | unittest.main() -------------------------------------------------------------------------------- /losses/perceptual_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import librosa as li 5 | 6 | class Loudness(nn.Module): 7 | def __init__(self, sr, block_size, n_fft=2048): 8 | super().__init__() 9 | self.sr = sr 10 | self.block_size = block_size 11 | self.n_fft = n_fft 12 | 13 | f = np.linspace(0, sr / 2, n_fft // 2 + 1) + 1e-7 14 | a_weight = li.A_weighting(f).reshape(-1, 1) 15 | 16 | self.register_buffer("a_weight", torch.from_numpy(a_weight).float()) 17 | self.register_buffer("window", torch.hann_window(self.n_fft)) 18 | 19 | def forward(self, x): 20 | x = torch.stft( 21 | x.squeeze(1), 22 | self.n_fft, 23 | self.block_size, 24 | self.n_fft, 25 | center=True, 26 | window=self.window, 27 | return_complex=True, 28 | ).abs() 29 | x = torch.log(x + 1e-7) + self.a_weight 30 | return torch.mean(x, 1, keepdim=True) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='audio-diffusion', 5 | version='1.0.0', 6 | url='https://github.com/zqevans/audio-diffusion.git', 7 | author='Zach Evans', 8 | packages=find_packages(), 9 | install_requires=[ 10 | 'a-transformers', 11 | 'accelerate', 12 | 'aeiou', 13 | 'archisound', 14 | 'audio-diffusion-pytorch', 15 | 'audio-encoders-pytorch', 16 | 'audio-metadata', 17 | 'auraloss', 18 | 'einops', 19 | 'ema-pytorch', 20 | 'encodec', 21 | 'k-diffusion', 22 | 'pandas', 23 | 'pedalboard', 24 | 'perceiver-pytorch', 25 | 'prefigure', 26 | 'pyloudnorm', 27 | 'pytorch_lightning', 28 | 'PyWavelets', 29 | 'quantizer-pytorch', 30 
| 'scipy', 31 | 'torch', 32 | 'torchaudio', 33 | 'tqdm', 34 | 'transformers', 35 | 'vector-quantize-pytorch', 36 | 'wandb', 37 | 'cached_conv @ git+https://github.com/caillonantoine/cached_conv.git', 38 | ], 39 | ) -------------------------------------------------------------------------------- /blocks/embedders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import List 4 | from transformers import AutoTokenizer, T5EncoderModel 5 | from einops import rearrange 6 | 7 | class TextEmbedder(nn.Module): 8 | def __init__(self, model: str = "t5-base", max_length: int = 64, enable_grad = False): 9 | super().__init__() 10 | 11 | self.tokenizer = AutoTokenizer.from_pretrained(model, model_max_length = max_length) 12 | self.transformer = T5EncoderModel.from_pretrained(model, max_length=max_length) 13 | 14 | if not enable_grad: 15 | self.transformer.requires_grad_(False) 16 | 17 | self.max_length = max_length 18 | self.enable_grad = enable_grad 19 | 20 | def forward(self, texts: List[str]) -> Tensor: 21 | 22 | encoded = self.tokenizer( 23 | texts, 24 | truncation=True, 25 | max_length=self.max_length, 26 | padding="max_length", 27 | return_tensors="pt", 28 | ) 29 | 30 | device = next(self.transformer.parameters()).device 31 | input_ids = encoded["input_ids"].to(device) 32 | attention_mask = encoded["attention_mask"].to(device).to(torch.bool) 33 | 34 | if not self.enable_grad: 35 | self.transformer.eval() 36 | 37 | embedding = self.transformer( 38 | input_ids=input_ids, attention_mask=attention_mask 39 | )["last_hidden_state"] 40 | 41 | return embedding, attention_mask -------------------------------------------------------------------------------- /blocks/residual_vq.py: -------------------------------------------------------------------------------- 1 | #Copied and modified from https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/residual_vq.py 2 | 3 | from functools import partial 4 | import torch 5 | from torch import nn 6 | from vector_quantize_pytorch import VectorQuantize 7 | 8 | class ResidualVQ(nn.Module): 9 | """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ 10 | def __init__( 11 | self, 12 | *, 13 | num_quantizers, 14 | shared_codebook = False, 15 | **kwargs 16 | ): 17 | super().__init__() 18 | self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)]) 19 | 20 | if not shared_codebook: 21 | return 22 | 23 | first_vq, *rest_vq = self.layers 24 | codebook = first_vq._codebook 25 | 26 | for vq in rest_vq: 27 | vq._codebook = codebook 28 | 29 | def forward(self, x): 30 | quantized_out = 0. 
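        # Residual vector quantization (Algorithm 1 of the SoundStream paper):
        # each quantizer codes the residual left by the previous stage, the stage
        # outputs are summed into quantized_out, and the per-stage indices and
        # losses are collected for the caller.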
31 | residual = x 32 | 33 | all_losses = [] 34 | all_indices = [] 35 | 36 | for layer in self.layers: 37 | quantized, indices, loss = layer(residual) 38 | residual = residual - quantized 39 | quantized_out = quantized_out + quantized 40 | 41 | all_indices.append(indices) 42 | all_losses.append(loss) 43 | 44 | all_losses, all_indices = map(partial(torch.stack, dim = -1), (all_losses, all_indices)) 45 | return quantized_out, all_indices, all_losses -------------------------------------------------------------------------------- /script_configs/vocoder_defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = test-vocoder 6 | 7 | # training data directory 8 | training_dir = /home/ubuntu/datasets/SignalTrain_LA2A_Dataset_1.1 9 | 10 | # the batch size 11 | batch_size = 8 12 | 13 | # number of GPUs to use for training 14 | num_gpus = 1 15 | 16 | # number of nodes to use for training 17 | num_nodes = 1 18 | 19 | # number of CPU workers for the DataLoader 20 | num_workers = 12 21 | 22 | # Number of audio samples for the training input 23 | sample_size = 65536 24 | 25 | # Number of epochs between demos 26 | demo_every = 20 27 | 28 | # Number of denoising steps for the demos 29 | demo_steps = 250 30 | 31 | # Number of demos to create 32 | num_demos = 16 33 | 34 | # the random seed 35 | seed = 42 36 | 37 | # Batches for gradient accumulation 38 | accum_batches = 1 39 | 40 | # The sample rate of the audio 41 | sample_rate = 48000 42 | 43 | # Number of steps between checkpoints 44 | checkpoint_every = 10000 45 | 46 | # the EMA decay 47 | ema_decay = 0.995 48 | 49 | use_mfcc = False 50 | 51 | n_fft = 1024 52 | 53 | hop_length = 256 54 | 55 | n_mels = 80 56 | 57 | # If true training data is kept in RAM 58 | cache_training_data = False 59 | 60 | # randomly crop input audio? 
(for augmentation) 61 | random_crop = True 62 | 63 | # checkpoint file to (re)start training from 64 | ckpt_path = '' 65 | 66 | # configuration model specifying model hyperparameters 67 | model_config = '' 68 | 69 | #the multiprocessing start method ['fork', 'forkserver', 'spawn'] 70 | start_method = 'forkserver' -------------------------------------------------------------------------------- /train_archive/train_k_defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = test-dvae 6 | 7 | # training data directory 8 | training_dir = /home/ubuntu/datasets/SignalTrain_LA2A_Dataset_1.1 9 | 10 | # the batch size 11 | batch_size = 8 12 | 13 | # number of GPUs to use for training 14 | num_gpus = 1 15 | 16 | # number of CPU workers for the DataLoader 17 | num_workers = 12 18 | 19 | # Number of samples to train on must be a multiple of 16384 20 | sample_size = 65536 21 | 22 | # Number of epochs between demos 23 | demo_every = 50 24 | 25 | # Number of denoising steps for the demos 26 | demo_steps = 250 27 | 28 | # Number of demos to create 29 | num_demos = 16 30 | 31 | # the random seed 32 | seed = 42 33 | 34 | # Batches for gradient accumulation 35 | accum_batches = 1 36 | 37 | # The sample rate of the audio 38 | sample_rate = 48000 39 | 40 | # Number of steps between checkpoints 41 | checkpoint_every = 10000 42 | 43 | # the EMA decay 44 | ema_decay = 0.995 45 | 46 | # the dimension of the local latents 47 | latent_dim = 64 48 | 49 | # The dimension of the global latent 50 | global_latent_dim = 128 51 | 52 | # If true training data is kept in RAM 53 | cache_training_data = False 54 | 55 | # randomly crop input audio? (for augmentation) 56 | random_crop = True 57 | 58 | # normalize input audio? 
59 | norm_inputs = False 60 | 61 | # checkpoint file to (re)start training from 62 | ckpt_path = '' 63 | 64 | # learning rate 65 | lr = 4e-5 66 | 67 | model_config = '' 68 | 69 | wandb_entity = 'zqevans' 70 | 71 | wandb_group = 'harmonai' 72 | 73 | wandb_project = 'k-audio-diffusion' 74 | 75 | wandb_save_model = False 76 | 77 | #the multiprocessing start method ['fork', 'forkserver', 'spawn'] 78 | start_method = 'spawn' -------------------------------------------------------------------------------- /clap_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import laion_clap 3 | import torch 4 | from dataset.dataset import get_wds_loader 5 | import numpy as np 6 | from sklearn.decomposition import PCA 7 | import matplotlib.pyplot as plt 8 | 9 | def main(args): 10 | clap_model = laion_clap.CLAP_Module(enable_fusion=True).to("cuda") 11 | clap_model.load_ckpt(model_id=3) 12 | 13 | names = [] 14 | 15 | train_dl = get_wds_loader( 16 | batch_size=args.batch_size, 17 | s3_url_prefix=None, 18 | sample_size=args.sample_size, 19 | names=names, 20 | sample_rate=args.sample_rate, 21 | num_workers=args.num_workers, 22 | recursive=True, 23 | random_crop=True, 24 | epoch_steps=10, 25 | ) 26 | 27 | all_embeddings = [] 28 | 29 | for i, batch in enumerate(iter(train_dl)): 30 | print(f"Batch {i}") 31 | audios, jsons, timestamps = batch 32 | audios = audios[0].to("cuda") 33 | 34 | audios = torch.mean(audios, dim=1) 35 | 36 | clap_audio_embeds = clap_model.get_audio_embedding_from_data(audios.cpu().numpy()) 37 | all_embeddings.append(clap_audio_embeds) 38 | 39 | all_embeddings = np.concatenate(all_embeddings, axis=0) 40 | 41 | pca = PCA(n_components=2) 42 | embeddings_2d = pca.fit_transform(all_embeddings) 43 | 44 | plt.scatter(embeddings_2d[:,0], embeddings_2d[:,1]) 45 | plt.savefig("pca_chart.png") 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser(description='Description of your program') 49 | parser.add_argument('-b','--batch_size', help='Batch size', type=int, default=8) 50 | parser.add_argument('-s','--sample_size', help='Sample size', type=int, default=480000) 51 | parser.add_argument('-r','--sample_rate', help='Sample rate', type=int, default=48000) 52 | parser.add_argument('-w','--num_workers', help='Number of workers', type=int, default=12) 53 | args = parser.parse_args() 54 | 55 | main(args) -------------------------------------------------------------------------------- /diffusion/model.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import math 3 | import pytorch_lightning as pl 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from functools import partial 9 | 10 | from .utils import get_alphas_sigmas 11 | 12 | from .pqmf import CachedPQMF as PQMF 13 | 14 | from blocks.blocks import FourierFeatures 15 | from autoencoders.models import AudioAutoencoder 16 | from decoders.diffusion_decoder import DiffusionAttnUnet1D 17 | 18 | @torch.no_grad() 19 | def ema_update(model, averaged_model, decay): 20 | """Incorporates updated model parameters into an exponential moving averaged 21 | version of a model. 
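    The in-place update is averaged = decay * averaged + (1 - decay) * param for each parameter, while buffers are copied over from the current model unchanged.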
It should be called after each optimizer step.""" 22 | model_params = dict(model.named_parameters()) 23 | averaged_params = dict(averaged_model.named_parameters()) 24 | assert model_params.keys() == averaged_params.keys() 25 | 26 | for name, param in model_params.items(): 27 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 28 | 29 | model_buffers = dict(model.named_buffers()) 30 | averaged_buffers = dict(averaged_model.named_buffers()) 31 | assert model_buffers.keys() == averaged_buffers.keys() 32 | 33 | for name, buf in model_buffers.items(): 34 | averaged_buffers[name].copy_(buf) 35 | 36 | 37 | class LatentAudioDiffusion(pl.LightningModule): 38 | def __init__( 39 | self, 40 | autoencoder: AudioAutoencoder, 41 | io_channels = 32, 42 | n_attn_layers=4, 43 | channels = [512] * 6 + [1024] * 4, 44 | depth = 10 45 | ): 46 | super().__init__() 47 | 48 | self.latent_dim = autoencoder.latent_dim 49 | self.downsampling_ratio = autoencoder.downsampling_ratio 50 | 51 | self.diffusion = DiffusionAttnUnet1D( 52 | io_channels=self.latent_dim, 53 | n_attn_layers=n_attn_layers, 54 | c_mults=channels, 55 | depth=depth 56 | ) 57 | 58 | self.autoencoder = autoencoder 59 | 60 | def encode(self, reals): 61 | return self.autoencoder.encode(reals) 62 | 63 | def decode(self, latents): 64 | return self.autoencoder.decode(latents) -------------------------------------------------------------------------------- /script_configs/ae_defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = test-ae 6 | 7 | # training data directory 8 | training_dir = /home/ubuntu/datasets/SignalTrain_LA2A_Dataset_1.1 9 | 10 | # the batch size 11 | batch_size = 8 12 | 13 | # number of GPUs to use for training 14 | num_gpus = 1 15 | 16 | # number of nodes to use for training 17 | num_nodes = 1 18 | 19 | # number of CPU workers for the DataLoader 20 | num_workers = 12 21 | 22 | # Number of samples to train on must be a multiple of 16384 23 | sample_size = 65536 24 | 25 | # Number of epochs between demos 26 | demo_every = 50 27 | 28 | # Number of demos to create 29 | num_demos = 16 30 | 31 | # the random seed 32 | seed = 42 33 | 34 | # Batches for gradient accumulation 35 | accum_batches = 1 36 | 37 | # The sample rate of the audio 38 | sample_rate = 48000 39 | 40 | # Number of steps between checkpoints 41 | checkpoint_every = 5000 42 | 43 | # the EMA decay 44 | ema_decay = 0.995 45 | 46 | # the validation set 47 | latent_dim = 64 48 | 49 | # the validation set 50 | codebook_size = 1024 51 | 52 | # number of quantizers 53 | num_quantizers = 1 54 | 55 | # Number of residual quantizers 56 | num_residuals = 1 57 | 58 | # number of heads for the memcodes 59 | num_heads = 8 60 | 61 | # If true training data is kept in RAM 62 | cache_training_data = False 63 | 64 | # number of sub-bands for the PQMF filter 65 | pqmf_bands = 1 66 | 67 | # randomly crop input audio? (for augmentation) 68 | random_crop = True 69 | 70 | # normalize input audio? 
71 | norm_inputs = False 72 | 73 | # checkpoint file to (re)start training from 74 | ckpt_path = '' 75 | 76 | # directory to place preprocessed audio 77 | preprocessed_dir = '' 78 | 79 | # directory to save the checkpoints in 80 | save_dir = '' 81 | 82 | # Depth of the autoencoder 83 | depth = 6 84 | 85 | # Number of attention layers in the autoencoder 86 | n_attn_layers = 0 87 | 88 | #the multiprocessing start method ['fork', 'forkserver', 'spawn'] 89 | start_method = 'forkserver' 90 | 91 | skip_adv_losses = False 92 | 93 | warmup_steps = 100000 94 | 95 | encoder_diffae_ckpt = '' 96 | 97 | pretrained_ckpt_path = '' -------------------------------------------------------------------------------- /diffusion/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from tqdm import trange 4 | 5 | # Define the noise schedule and sampling loop 6 | def get_alphas_sigmas(t): 7 | """Returns the scaling factors for the clean image (alpha) and for the 8 | noise (sigma), given a timestep.""" 9 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 10 | 11 | def alpha_sigma_to_t(alpha, sigma): 12 | """Returns a timestep, given the scaling factors for the clean image and for 13 | the noise.""" 14 | return torch.atan2(sigma, alpha) / math.pi * 2 15 | 16 | @torch.no_grad() 17 | def sample(model, x, steps, eta, **extra_args): 18 | """Draws samples from a model given starting noise.""" 19 | ts = x.new_ones([x.shape[0]]) 20 | 21 | # Create the noise schedule 22 | t = torch.linspace(1, 0, steps + 1)[:-1] 23 | 24 | alphas, sigmas = get_alphas_sigmas(t) 25 | 26 | # The sampling loop 27 | for i in trange(steps): 28 | 29 | # Get the model output (v, the predicted velocity) 30 | with torch.cuda.amp.autocast(): 31 | v = model(x, ts * t[i], **extra_args).float() 32 | 33 | # Predict the noise and the denoised image 34 | pred = x * alphas[i] - v * sigmas[i] 35 | eps = x * sigmas[i] + v * alphas[i] 36 | 37 | # If we are not on the last timestep, compute the noisy image for the 38 | # next timestep. 
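        # DDIM-style update: the next noise level sigmas[i + 1] is split into a
        # stochastic share ddim_sigma (scaled by eta) and a deterministic share
        # adjusted_sigma, so eta = 0 gives a fully deterministic DDIM step while
        # larger eta re-injects fresh Gaussian noise at each step.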
39 | if i < steps - 1: 40 | # If eta > 0, adjust the scaling factor for the predicted noise 41 | # downward according to the amount of additional noise to add 42 | ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ 43 | (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() 44 | adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() 45 | 46 | # Recombine the predicted noise and predicted denoised image in the 47 | # correct proportions for the next step 48 | x = pred * alphas[i + 1] + eps * adjusted_sigma 49 | 50 | # Add the correct amount of fresh noise 51 | if eta: 52 | x += torch.randn_like(x) * ddim_sigma 53 | 54 | # If we are on the last timestep, output the denoised image 55 | return pred 56 | -------------------------------------------------------------------------------- /test/utils_test.py: -------------------------------------------------------------------------------- 1 | from diffusion.utils import MidSideEncoding, Stereo 2 | import torch 3 | 4 | import unittest 5 | 6 | class TestUtils(unittest.TestCase): 7 | 8 | def test_mid_side_encoding_invertible(self): 9 | signal = torch.randn([2, 131072], device='cuda') 10 | encoder = MidSideEncoding() 11 | 12 | #Encoding and decoding should be a no-op 13 | self.assertTrue(torch.equal(signal, encoder(encoder(signal)))) 14 | 15 | def test_mid_side_encoding_mono(self): 16 | mono_signal = torch.randn([1, 131072], device='cuda') 17 | # Make the signal stereo 18 | signal = mono_signal.repeat(2, 1) 19 | encoder = MidSideEncoding() 20 | encoded_mono = encoder(signal) 21 | 22 | #Mid content should be the same as the mono signal 23 | #TODO: the data is correct but this check is returning false. Need to figure out why 24 | #self.assertTrue(torch.equal(mono_signal, encoded_mono[0])) 25 | 26 | #Side content should be all zeros 27 | zeros = torch.zeros_like(encoded_mono[1]) 28 | self.assertTrue(torch.equal(zeros, encoded_mono[1])) 29 | 30 | def test_mono_to_stereo(self): 31 | mono_signal = torch.randn([1, 131072], device='cuda') 32 | stereo_signal = Stereo()(mono_signal) 33 | signal_shape = stereo_signal.shape 34 | self.assertEqual(len(signal_shape), 2) 35 | self.assertEqual(signal_shape[0], 2) 36 | self.assertEqual(signal_shape[1], 131072) 37 | 38 | def test_surround_to_stereo(self): 39 | mono_signal = torch.randn([6, 131072], device='cuda') 40 | stereo_signal = Stereo()(mono_signal) 41 | signal_shape = stereo_signal.shape 42 | self.assertEqual(len(signal_shape), 2) 43 | self.assertEqual(signal_shape[0], 2) 44 | self.assertEqual(signal_shape[1], 131072) 45 | 46 | def test_one_channel_to_stereo(self): 47 | mono_signal = torch.randn([131072], device='cuda') 48 | stereo_signal = Stereo()(mono_signal) 49 | signal_shape = stereo_signal.shape 50 | self.assertEqual(len(signal_shape), 2) 51 | self.assertEqual(signal_shape[0], 2) 52 | self.assertEqual(signal_shape[1], 131072) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zach Evans 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 
10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | MIT License 25 | 26 | Copyright (c) 2022 Phil Wang 27 | 28 | Permission is hereby granted, free of charge, to any person obtaining a copy 29 | of this software and associated documentation files (the "Software"), to deal 30 | in the Software without restriction, including without limitation the rights 31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 32 | copies of the Software, and to permit persons to whom the Software is 33 | furnished to do so, subject to the following conditions: 34 | 35 | The above copyright notice and this permission notice shall be included in all 36 | copies or substantial portions of the Software. 37 | 38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 41 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 42 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 43 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 44 | SOFTWARE. 
-------------------------------------------------------------------------------- /encoders/t5.py: -------------------------------------------------------------------------------- 1 | #Copied from https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/t5.py 2 | import torch 3 | from transformers import T5Tokenizer, T5EncoderModel, T5Config 4 | 5 | def exists(val): 6 | return val is not None 7 | 8 | # config 9 | 10 | MAX_LENGTH = 256 11 | 12 | DEFAULT_T5_NAME = 'google/t5-v1_1-base' 13 | 14 | T5_CONFIGS = {} 15 | 16 | # singleton globals 17 | 18 | def get_tokenizer(name): 19 | tokenizer = T5Tokenizer.from_pretrained(name) 20 | return tokenizer 21 | 22 | def get_model(name): 23 | model = T5EncoderModel.from_pretrained(name) 24 | return model 25 | 26 | def get_model_and_tokenizer(name): 27 | global T5_CONFIGS 28 | 29 | if name not in T5_CONFIGS: 30 | T5_CONFIGS[name] = dict() 31 | if "model" not in T5_CONFIGS[name]: 32 | T5_CONFIGS[name]["model"] = get_model(name) 33 | if "tokenizer" not in T5_CONFIGS[name]: 34 | T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name) 35 | 36 | return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer'] 37 | 38 | def get_encoded_dim(name): 39 | if name not in T5_CONFIGS: 40 | # avoids loading the model if we only want to get the dim 41 | config = T5Config.from_pretrained(name) 42 | T5_CONFIGS[name] = dict(config=config) 43 | elif "config" in T5_CONFIGS[name]: 44 | config = T5_CONFIGS[name]["config"] 45 | elif "model" in T5_CONFIGS[name]: 46 | config = T5_CONFIGS[name]["model"].config 47 | else: 48 | assert False 49 | return config.d_model 50 | 51 | # encoding text 52 | 53 | def t5_encode_text(texts, name = DEFAULT_T5_NAME): 54 | t5, tokenizer = get_model_and_tokenizer(name) 55 | 56 | if torch.cuda.is_available(): 57 | t5 = t5.cuda() 58 | 59 | device = next(t5.parameters()).device 60 | 61 | encoded = tokenizer.batch_encode_plus( 62 | texts, 63 | return_tensors = "pt", 64 | padding = 'longest', 65 | max_length = MAX_LENGTH, 66 | truncation = True 67 | ) 68 | 69 | input_ids = encoded.input_ids.to(device) 70 | attn_mask = encoded.attention_mask.to(device) 71 | 72 | t5.eval() 73 | 74 | with torch.no_grad(): 75 | output = t5(input_ids = input_ids, attention_mask = attn_mask) 76 | encoded_text = output.last_hidden_state.detach() 77 | 78 | return encoded_text, attn_mask.bool() 79 | -------------------------------------------------------------------------------- /convert_mp3_to_wav.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | 3 | import argparse 4 | from glob import glob 5 | import os 6 | from multiprocessing import Pool, cpu_count, Barrier 7 | from functools import partial 8 | import tqdm 9 | from tqdm.contrib.concurrent import process_map 10 | import torch 11 | import torchaudio 12 | from torchaudio import transforms as T 13 | import math 14 | from dataset.dataset import get_audio_filenames 15 | import subprocess 16 | 17 | def process_one_file(filenames, args, file_ind): 18 | "this chunks up one file" 19 | filename = filenames[file_ind] # this is actually input_path+/+filename 20 | input_paths = args.input_paths 21 | new_filename = None 22 | 23 | path, ext = os.path.splitext(filename) 24 | 25 | new_filename = f"{path}.wav" 26 | 27 | if new_filename is None: 28 | print(f"ERROR: Something went wrong with name of input file {filename}. 
Skipping.") 29 | return 30 | try: 31 | subprocess.call(['ffmpeg', '-i', filename, 32 | new_filename]) 33 | except Exception as e: 34 | print(e) 35 | print(f"Error loading {filename} or writing chunks. Skipping.") 36 | 37 | return 38 | 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 42 | parser.add_argument('input_paths', nargs='+', help='Path(s) of a file or a folder of files. (recursive)') 43 | args = parser.parse_args() 44 | 45 | #print(f" output_path = {args.output_path}") 46 | #print(f" chunk_size = {args.chunk_size}") 47 | 48 | torchaudio.set_audio_backend("sox_io") 49 | 50 | print("Getting list of input filenames") 51 | filenames = get_audio_filenames(args.input_paths, exts=[".mp3"]) 52 | # for path in args.input_paths: 53 | # for ext in ['wav','flac','ogg']: 54 | # filenames += glob(f'{path}/**/*.{ext}', recursive=True) 55 | n = len(filenames) 56 | print(f"Got {n} input filenames") 57 | 58 | # for i in range(n): 59 | # process_one_file(filenames, args, i) 60 | 61 | print("Processing files (in parallel)") 62 | wrapper = partial(process_one_file, filenames, args) 63 | r = process_map(wrapper, range(0, n), chunksize=1, max_workers=48) # different chunksize used by tqdm. max_workers is to avoid annoying other ppl 64 | 65 | print("Finished") 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | generated/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | wandb/ 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | junk/ 136 | .DS_Store -------------------------------------------------------------------------------- /losses/freq_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import auraloss 3 | 4 | class PerceptualSumAndDifferenceSTFTLoss(torch.nn.Module): 5 | """Sum and difference stereo STFT loss module. 6 | See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) 7 | Args: 8 | fft_sizes (list, optional): List of FFT sizes. 9 | hop_sizes (list, optional): List of hop sizes. 10 | win_lengths (list, optional): List of window lengths. 11 | window (str, optional): Window function type. 12 | w_sum (float, optional): Weight of the sum loss component. Default: 1.0 13 | w_diff (float, optional): Weight of the difference loss component. Default: 1.0 14 | output (str, optional): Format of the loss returned. 15 | 'loss' : Return only the raw, aggregate loss term. 16 | 'full' : Return the raw loss, plus intermediate loss terms. 17 | Default: 'loss' 18 | Returns: 19 | loss: 20 | Aggregate loss term. Only returned if output='loss'. 21 | loss, sum_loss, diff_loss: 22 | Aggregate and intermediate loss terms. Only returned if output='full'. 
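    Example (illustrative sketch; assumes stereo waveforms shaped [batch, 2, samples]):
        >>> loss_fn = PerceptualSumAndDifferenceSTFTLoss(sample_rate=48000)
        >>> pred, target = torch.randn(4, 2, 65536), torch.randn(4, 2, 65536)
        >>> loss = loss_fn(pred, target)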
23 | """ 24 | 25 | def __init__( 26 | self, 27 | fft_sizes=[1024, 2048, 512], 28 | hop_sizes=[120, 240, 50], 29 | win_lengths=[600, 1200, 240], 30 | window="hann_window", 31 | w_sum=1.0, 32 | w_diff=1.0, 33 | output="loss", 34 | sample_rate=48000, 35 | **stft_args 36 | ): 37 | super(PerceptualSumAndDifferenceSTFTLoss, self).__init__() 38 | self.sd = auraloss.perceptual.SumAndDifference() 39 | self.w_sum = w_sum 40 | self.w_diff = w_diff 41 | self.output = output 42 | self.mrstft = auraloss.freq.MultiResolutionSTFTLoss(fft_sizes, hop_sizes, win_lengths, window, sample_rate=sample_rate, **stft_args) 43 | self.aw_fir = auraloss.perceptual.FIRFilter(filter_type="aw", fs=sample_rate) 44 | 45 | def forward(self, input, target): 46 | # input_sum, input_diff = self.aw_fir(*self.sd(input)) 47 | # target_sum, target_diff = self.aw_fir(*self.sd(target)) 48 | 49 | input_sum, input_diff = self.sd(input) 50 | target_sum, target_diff = self.sd(target) 51 | 52 | sum_loss = self.mrstft(input_sum, target_sum) 53 | diff_loss = self.mrstft(input_diff, target_diff) 54 | loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2 55 | 56 | if self.output == "loss": 57 | return loss 58 | elif self.output == "full": 59 | return loss, sum_loss, diff_loss -------------------------------------------------------------------------------- /script_configs/latent_defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = test-latent-diffae 6 | 7 | # training data directory 8 | training_dir = /home/ubuntu/datasets/SignalTrain_LA2A_Dataset_1.1 9 | 10 | # the batch size 11 | batch_size = 8 12 | 13 | # number of GPUs to use for training 14 | num_gpus = 1 15 | 16 | # number of nodes to use for training 17 | num_nodes = 1 18 | 19 | # number of CPU workers for the DataLoader 20 | num_workers = 12 21 | 22 | # Number of audio samples for the training input 23 | sample_size = 65536 24 | 25 | # Number of epochs between demos 26 | demo_every = 200 27 | 28 | # Number of denoising steps for the demos 29 | demo_steps = 250 30 | 31 | # Number of demos to create 32 | num_demos = 16 33 | 34 | # the random seed 35 | seed = 42 36 | 37 | # Batches for gradient accumulation 38 | accum_batches = 1 39 | 40 | # The sample rate of the audio 41 | sample_rate = 48000 42 | 43 | # Number of steps between checkpoints 44 | checkpoint_every = 10000 45 | 46 | # the EMA decay 47 | ema_decay = 0.995 48 | 49 | # the dimension of the local latents 50 | latent_dim = 64 51 | 52 | # the size of the codebook 53 | codebook_size = 1024 54 | 55 | # number of quantizers 56 | num_quantizers = 0 57 | 58 | # number of residual quantizers (same as above, just depends on which file you're using) 59 | num_residuals = 0 60 | 61 | # number of heads for the memcodes 62 | num_heads = 8 63 | 64 | # If true training data is kept in RAM 65 | cache_training_data = False 66 | 67 | # number of sub-bands for the PQMF filter 68 | pqmf_bands = 1 69 | 70 | # randomly crop input audio? (for augmentation) 71 | random_crop = True 72 | 73 | # normalize input audio? 
74 | norm_inputs = False 75 | 76 | #Number of steps before freezing encoder in dVAE 77 | warmup_steps = 100000 78 | 79 | # checkpoint file to (re)start training from 80 | ckpt_path = '' 81 | 82 | # pretrained diffusion autoencoder checkpoint file 83 | pretrained_ckpt_path = '' 84 | 85 | # configuration model specifying model hyperparameters 86 | model_config = '' 87 | 88 | # directory to save the checkpoints in 89 | save_dir = '' 90 | 91 | #the multiprocessing start method ['fork', 'forkserver', 'spawn'] 92 | start_method = 'spawn' 93 | 94 | preprocessed_dir = '' 95 | 96 | depth=8 97 | 98 | # Name of the run for automatic restarts 99 | run_name = '' 100 | 101 | # Checkpoint for a pre-trained CLAP model 102 | clap_ckpt_path='' 103 | 104 | clap_fusion=False 105 | 106 | clap_amodel='HTSAT-tiny' -------------------------------------------------------------------------------- /defaults.ini: -------------------------------------------------------------------------------- 1 | 2 | [DEFAULTS] 3 | 4 | #name of the run 5 | name = test-dvae 6 | 7 | # training data directory 8 | training_dir = /home/ubuntu/datasets/SignalTrain_LA2A_Dataset_1.1 9 | 10 | # the batch size 11 | batch_size = 8 12 | 13 | # number of GPUs to use for training 14 | num_gpus = 1 15 | 16 | # number of nodes to use for training 17 | num_nodes = 1 18 | 19 | # number of CPU workers for the DataLoader 20 | num_workers = 8 21 | 22 | # Number of audio samples for the training input 23 | sample_size = 65536 24 | 25 | # Number of epochs between demos 26 | demo_every = 20 27 | 28 | # Number of denoising steps for the demos 29 | demo_steps = 250 30 | 31 | # Number of demos to create 32 | num_demos = 16 33 | 34 | # the random seed 35 | seed = 42 36 | 37 | # Batches for gradient accumulation 38 | accum_batches = 1 39 | 40 | # The sample rate of the audio 41 | sample_rate = 48000 42 | 43 | # Number of steps between checkpoints 44 | checkpoint_every = 10000 45 | 46 | # the EMA decay 47 | ema_decay = 0.995 48 | 49 | # the validation set 50 | latent_dim = 32 51 | 52 | # the validation set 53 | codebook_size = 1024 54 | 55 | # number of quantizers 56 | num_quantizers = 0 57 | 58 | # number of residual quantizers (same as above, depending on sfile) 59 | num_residuals = 0 60 | 61 | # number of heads for the memcodes 62 | num_heads = 8 63 | 64 | # If true training data is kept in RAM 65 | cache_training_data = False 66 | 67 | # number of sub-bands for the PQMF filter 68 | pqmf_bands = 1 69 | 70 | # randomly crop input audio? (for augmentation) 71 | random_crop = True 72 | 73 | # normalize input audio? 74 | norm_inputs = False 75 | 76 | #Number of steps before freezing encoder in dVAE 77 | warmup_steps = 100000 78 | 79 | # for jukebox imbeddings. 
0 (high res), 1 (med), or 2 (low res) 80 | jukebox_layer = 0 81 | 82 | # If true, use MFCCs instead of spectrograms for training 83 | use_mfcc = True 84 | 85 | # checkpoint file to (re)start training from 86 | ckpt_path = '' 87 | 88 | # configuration model specifying model hyperparameters 89 | model_config = '' 90 | 91 | #the multiprocessing start method ['fork', 'forkserver', 'spawn'] 92 | start_method = 'spawn' 93 | 94 | global_latent_dim = 128 95 | 96 | preprocessed_dir = '' 97 | 98 | depth=8 99 | 100 | diffusion_kernel_size = 5 101 | 102 | upsample_ratio = 4 103 | 104 | patch_size = 1 105 | 106 | # directory to save the checkpoints in 107 | save_dir = '' 108 | 109 | # Name of the run for automatic restarts 110 | run_name = '' -------------------------------------------------------------------------------- /encoders/wavelets.py: -------------------------------------------------------------------------------- 1 | """The 1D discrete wavelet transform for PyTorch.""" 2 | 3 | from einops import rearrange 4 | import pywt 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | def get_filter_bank(wavelet): 11 | filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) 12 | if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): 13 | filt = filt[:, 1:] 14 | return filt 15 | 16 | 17 | class WaveletEncode1d(nn.Module): 18 | def __init__(self, channels, wavelet, levels): 19 | super().__init__() 20 | self.wavelet = wavelet 21 | self.channels = channels 22 | self.levels = levels 23 | filt = get_filter_bank(wavelet) 24 | assert filt.shape[-1] % 2 == 1 25 | kernel = filt[:2, None] 26 | kernel = torch.flip(kernel, dims=(-1,)) 27 | index_i = torch.repeat_interleave(torch.arange(2), channels) 28 | index_j = torch.tile(torch.arange(channels), (2,)) 29 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 30 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 31 | self.register_buffer("kernel", kernel_final) 32 | 33 | def forward(self, x): 34 | for i in range(self.levels): 35 | low, rest = x[:, : self.channels], x[:, self.channels :] 36 | pad = self.kernel.shape[-1] // 2 37 | low = F.pad(low, (pad, pad), "reflect") 38 | low = F.conv1d(low, self.kernel, stride=2) 39 | rest = rearrange( 40 | rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels 41 | ) 42 | x = torch.cat([low, rest], dim=1) 43 | return x 44 | 45 | 46 | class WaveletDecode1d(nn.Module): 47 | def __init__(self, channels, wavelet, levels): 48 | super().__init__() 49 | self.wavelet = wavelet 50 | self.channels = channels 51 | self.levels = levels 52 | filt = get_filter_bank(wavelet) 53 | assert filt.shape[-1] % 2 == 1 54 | kernel = filt[2:, None] 55 | index_i = torch.repeat_interleave(torch.arange(2), channels) 56 | index_j = torch.tile(torch.arange(channels), (2,)) 57 | kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) 58 | kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] 59 | self.register_buffer("kernel", kernel_final) 60 | 61 | def forward(self, x): 62 | for i in range(self.levels): 63 | low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] 64 | pad = self.kernel.shape[-1] // 2 + 2 65 | low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) 66 | low = F.pad(low, (pad, pad), "reflect") 67 | low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) 68 | low = F.conv_transpose1d( 69 | low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 70 | ) 71 | low = low[..., pad - 1 : -pad] 72 | rest = 
rearrange( 73 | rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels 74 | ) 75 | x = torch.cat([low, rest], dim=1) 76 | return x -------------------------------------------------------------------------------- /effects/tcn.py: -------------------------------------------------------------------------------- 1 | #Copied from https://github.com/csteinmetz1/steerable-nafx/blob/main/steerable-nafx.ipynb 2 | import torch 3 | 4 | def causal_crop(x, length: int): 5 | if x.shape[-1] != length: 6 | stop = x.shape[-1] - 1 7 | start = stop - length 8 | x = x[..., start:stop] 9 | return x 10 | 11 | class FiLM(torch.nn.Module): 12 | def __init__( 13 | self, 14 | cond_dim, # dim of conditioning input 15 | num_features, # dim of the conv channel 16 | batch_norm=True, 17 | ): 18 | super().__init__() 19 | self.num_features = num_features 20 | self.batch_norm = batch_norm 21 | if batch_norm: 22 | self.bn = torch.nn.BatchNorm1d(num_features, affine=False) 23 | self.adaptor = torch.nn.Linear(cond_dim, num_features * 2) 24 | 25 | def forward(self, x, cond): 26 | 27 | cond = self.adaptor(cond) 28 | g, b = torch.chunk(cond, 2, dim=-1) 29 | g = g.permute(0, 2, 1) 30 | b = b.permute(0, 2, 1) 31 | 32 | if self.batch_norm: 33 | x = self.bn(x) # apply BatchNorm without affine 34 | x = (x * g) + b # then apply conditional affine 35 | 36 | return x 37 | 38 | class TCNBlock(torch.nn.Module): 39 | def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True): 40 | super().__init__() 41 | print(f"Kernel size: {kernel_size}") 42 | print(f"Dilation: {dilation}") 43 | self.conv = torch.nn.Conv1d( 44 | in_channels, 45 | out_channels, 46 | kernel_size, 47 | dilation=dilation, 48 | padding=0, #((kernel_size-1)//2)*dilation, 49 | bias=True) 50 | if cond_dim > 0: 51 | self.film = FiLM(cond_dim, out_channels, batch_norm=False) 52 | if activation: 53 | #self.act = torch.nn.Tanh() 54 | self.act = torch.nn.PReLU() 55 | self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False) 56 | 57 | def forward(self, x, c=None): 58 | x_in = x 59 | x = self.conv(x) 60 | if hasattr(self, "film"): 61 | x = self.film(x, c) 62 | if hasattr(self, "act"): 63 | x = self.act(x) 64 | x_res = causal_crop(self.res(x_in), x.shape[-1]) 65 | x = x + x_res 66 | 67 | return x 68 | 69 | class TCN(torch.nn.Module): 70 | def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0): 71 | super().__init__() 72 | self.kernel_size = kernel_size 73 | self.n_channels = n_channels 74 | self.dilation_growth = dilation_growth 75 | self.n_blocks = n_blocks 76 | self.stack_size = n_blocks 77 | 78 | self.blocks = torch.nn.ModuleList() 79 | for n in range(n_blocks): 80 | if n == 0: 81 | in_ch = n_inputs 82 | out_ch = n_channels 83 | act = True 84 | elif (n+1) == n_blocks: 85 | in_ch = n_channels 86 | out_ch = n_outputs 87 | act = True 88 | else: 89 | in_ch = n_channels 90 | out_ch = n_channels 91 | act = True 92 | 93 | dilation = dilation_growth ** n 94 | self.blocks.append(TCNBlock(in_ch, out_ch, kernel_size, dilation, cond_dim=cond_dim, activation=act)) 95 | 96 | def forward(self, x, c=None): 97 | for block in self.blocks: 98 | x = block(x, c) 99 | 100 | return x -------------------------------------------------------------------------------- /diffusion/xunet.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from 
a_unet import ( 6 | ClassifierFreeGuidancePlugin, 7 | Conv, 8 | Module, 9 | TextConditioningPlugin, 10 | TimeConditioningPlugin, 11 | default, 12 | exists, 13 | ) 14 | from a_unet.apex import ( 15 | AttentionItem, 16 | CrossAttentionItem, 17 | InjectChannelsItem, 18 | ModulationItem, 19 | ResnetItem, 20 | SkipCat, 21 | SkipModulate, 22 | XBlock, 23 | XUNet, 24 | ) 25 | 26 | 27 | def UNetV0( 28 | dim: int, 29 | in_channels: int, 30 | channels: Sequence[int], 31 | factors: Sequence[int], 32 | items: Sequence[int], 33 | attentions: Optional[Sequence[int]] = None, 34 | cross_attentions: Optional[Sequence[int]] = None, 35 | context_channels: Optional[Sequence[int]] = None, 36 | attention_features: Optional[int] = None, 37 | attention_heads: Optional[int] = None, 38 | embedding_features: Optional[int] = None, 39 | resnet_groups: int = 8, 40 | use_modulation: bool = True, 41 | modulation_features: int = 1024, 42 | embedding_max_length: Optional[int] = None, 43 | use_time_conditioning: bool = True, 44 | use_embedding_cfg: bool = False, 45 | use_text_conditioning: bool = False, 46 | out_channels: Optional[int] = None, 47 | embedder: Optional[nn.Module] = None 48 | ): 49 | # Set defaults and check lengths 50 | num_layers = len(channels) 51 | attentions = default(attentions, [0] * num_layers) 52 | cross_attentions = default(cross_attentions, [0] * num_layers) 53 | context_channels = default(context_channels, [0] * num_layers) 54 | xs = (channels, factors, items, attentions, cross_attentions, context_channels) 55 | assert all(len(x) == num_layers for x in xs) # type: ignore 56 | 57 | # Define UNet type 58 | UNetV0 = XUNet 59 | 60 | if use_embedding_cfg: 61 | msg = "use_embedding_cfg requires embedding_max_length" 62 | assert exists(embedding_max_length), msg 63 | UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length) 64 | 65 | if use_text_conditioning: 66 | UNetV0 = TextConditioningPlugin(UNetV0, embedder) 67 | 68 | if use_time_conditioning: 69 | assert use_modulation, "use_time_conditioning requires use_modulation=True" 70 | UNetV0 = TimeConditioningPlugin(UNetV0) 71 | 72 | # Build 73 | return UNetV0( 74 | dim=dim, 75 | in_channels=in_channels, 76 | out_channels=out_channels, 77 | blocks=[ 78 | XBlock( 79 | channels=channels, 80 | factor=factor, 81 | context_channels=ctx_channels, 82 | items=( 83 | [ResnetItem] 84 | + [ModulationItem] * use_modulation 85 | + [InjectChannelsItem] * (ctx_channels > 0) 86 | + [AttentionItem] * att 87 | + [CrossAttentionItem] * cross 88 | ) 89 | * items, 90 | ) 91 | for channels, factor, items, att, cross, ctx_channels in zip(*xs) # type: ignore # noqa 92 | ], 93 | skip_t=SkipModulate if use_modulation else SkipCat, 94 | attention_features=attention_features, 95 | attention_heads=attention_heads, 96 | embedding_features=embedding_features, 97 | modulation_features=modulation_features, 98 | resnet_groups=resnet_groups, 99 | ) 100 | -------------------------------------------------------------------------------- /viz/viz.py.bak: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from pathlib import Path 4 | from matplotlib.backends.backend_agg import FigureCanvasAgg 5 | import matplotlib.cm as cm 6 | import matplotlib.pyplot as plt 7 | from matplotlib.colors import Normalize 8 | from matplotlib.figure import Figure 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import torch 13 | from torch import optim, nn 14 | from torch.nn import functional as F 15 | import torchaudio 16 | import 
torchaudio.transforms as T 17 | import librosa 18 | from einops import rearrange 19 | 20 | import wandb 21 | import numpy as np 22 | import pandas as pd 23 | 24 | 25 | 26 | def embeddings_table(tokens): 27 | "make a table of embeddings for use with wandb" 28 | features, labels = [], [] 29 | embeddings = rearrange(tokens, 'b d n -> b n d') # each demo sample is n vectors in d-dim space 30 | for i in range(embeddings.size()[0]): # nested for's are slow but sure ;-) 31 | for j in range(embeddings.size()[1]): 32 | features.append(embeddings[i,j].detach().cpu().numpy()) 33 | labels.append([f'demo{i}']) # labels does the grouping / color for each point 34 | features = np.array(features) 35 | #print("\nfeatures.shape = ",features.shape) 36 | labels = np.concatenate(labels, axis=0) 37 | cols = [f"dim_{i}" for i in range(features.shape[1])] 38 | df = pd.DataFrame(features, columns=cols) 39 | df['LABEL'] = labels 40 | return wandb.Table(columns=df.columns.to_list(), data=df.values) 41 | 42 | 43 | def proj_pca(tokens, proj_dims=3): 44 | "this projects via PCA, grabbing the first _3_ dimensions" 45 | A = rearrange(tokens, 'b d n -> (b n) d') # put all the vectors into the same d-dim space 46 | k = proj_dims 47 | (U, S, V) = torch.pca_lowrank(A) 48 | proj_data = torch.matmul(A, V[:, :k]) # this is the actual PCA projection step 49 | return torch.reshape(proj_data, (tokens.size()[0], -1, proj_dims)) # put it in shape [batch, n, 3] 50 | 51 | 52 | def pca_point_cloud(tokens): 53 | "produces a 3D wandb point cloud of the tokens using PCA" 54 | data = proj_pca(tokens).cpu().numpy() 55 | points = [] 56 | cmap = cm.tab20 # 20 farly distinct colors 57 | norm = Normalize(vmin=0, vmax=data.shape[0]) 58 | for bi in range(data.shape[0]): # batch index 59 | [r, g, b, _] = [int(255*x) for x in cmap(norm(bi))] 60 | for n in range(data.shape[1]): 61 | #points.append([data[b,n,0], data[b,n,1], data[b,n,2], color]) # only works for color=1 to 14 62 | points.append([data[bi,n,0], data[bi,n,1], data[bi,n,2], r, g, b]) 63 | 64 | point_cloud = np.array(points) 65 | return wandb.Object3D(point_cloud) 66 | 67 | 68 | def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None): 69 | """ 70 | # cf. https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html 71 | 72 | """ 73 | fig = Figure(figsize=(5, 4), dpi=100) 74 | canvas = FigureCanvasAgg(fig) 75 | axs = fig.add_subplot() 76 | axs.set_title(title or 'Spectrogram (db)') 77 | axs.set_ylabel(ylabel) 78 | axs.set_xlabel('frame') 79 | im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect) 80 | if xmax: 81 | axs.set_xlim((0, xmax)) 82 | fig.colorbar(im, ax=axs) 83 | canvas.draw() 84 | rgba = np.asarray(canvas.buffer_rgba()) 85 | return Image.fromarray(rgba) 86 | 87 | 88 | def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000): 89 | """ 90 | # cf. 
https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html 91 | """ 92 | n_fft = 1024 93 | win_length = None 94 | hop_length = 512 95 | n_mels = 128 96 | 97 | mel_spectrogram_op = T.MelSpectrogram( 98 | sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, 99 | hop_length=hop_length, center=True, pad_mode="reflect", power=power, 100 | norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk") 101 | 102 | melspec = mel_spectrogram_op(waveform)[0] # TODO: only left channel for now 103 | return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)') 104 | 105 | 106 | def token_spectrogram_image(tokens): 107 | pass -------------------------------------------------------------------------------- /prompts/prompters.py: -------------------------------------------------------------------------------- 1 | import audio_metadata 2 | import random 3 | 4 | # A class that generates text prompts from audio metadata 5 | class MetadataPrompter(): 6 | def __init__(self, sample_rate=48000): 7 | self.sample_rate = sample_rate 8 | 9 | # Creates and returns a text prompt given a metadata object 10 | def get_track_prompt_from_file_metadata(self, filepath): 11 | 12 | try: 13 | track_metadata = audio_metadata.load(filepath) 14 | except Exception as e: 15 | print(f"Couldn't load metadata for {filepath}: {e}") 16 | return filepath #If the metadata can't be loaded, use the file path as the prompt 17 | 18 | properties = [] 19 | 20 | if "tags" in track_metadata: 21 | tags = track_metadata["tags"] 22 | 23 | add_property = lambda name, key: random.random() > 0.1 and properties.append(f"{name}: {', '.join(tags[key])}") 24 | 25 | if 'title' in tags: 26 | add_property("Title", "title") 27 | if 'artist' in tags: 28 | add_property("Artist", "artist") 29 | if 'album' in tags: 30 | add_property("Album", "album") 31 | if 'genre' in tags: 32 | add_property("Genre", "genre") 33 | if 'label' in tags: 34 | add_property("Label", "label") 35 | if 'date' in tags: 36 | add_property("Date", "date") 37 | 38 | random.shuffle(properties) 39 | return "|".join(properties) 40 | 41 | def get_prompt_from_jmann_metadata(metadata): 42 | 43 | properties = [] 44 | 45 | song_attributes = metadata["attributes"] 46 | 47 | attributes = [ 48 | 'Date', 49 | 'Location', 50 | 'Topic', 51 | 'Mood', 52 | 'Beard', 53 | 'Genre', 54 | 'Style', 55 | 'Key', 56 | 'Tempo', 57 | 'Year', 58 | 'Instrument' 59 | ] 60 | 61 | traits = {} 62 | 63 | traits["Name"] = [metadata["name"]] 64 | 65 | traits["Artist"] = ["Jonathan Mann"] 66 | 67 | for attribute in song_attributes: 68 | trait_type = attribute["trait_type"] 69 | value = attribute["value"] 70 | 71 | if trait_type not in attributes: 72 | continue 73 | 74 | if trait_type in traits: 75 | traits[trait_type].append(value) 76 | else: 77 | traits[trait_type] = [value] 78 | 79 | for trait in traits: 80 | properties.append(f"{trait}: {', '.join(traits[trait])}") 81 | 82 | # Sample a random number of properties 83 | properties = random.sample(properties, random.randint(1, len(properties))) 84 | 85 | return "|".join(properties) 86 | 87 | def get_prompt_from_audio_file_metadata(metadata): 88 | 89 | properties = [] 90 | 91 | tags = [ 92 | 'title', 93 | 'artist', 94 | 'album', 95 | 'genre', 96 | 'label', 97 | 'date', 98 | 'composer', 99 | 'bpm', 100 | ] 101 | 102 | for tag in metadata.keys(): 103 | if tag in tags: 104 | properties.append(f"{tag}: {', '.join(metadata[tag])}") 105 | 106 | if len(properties) == 0: 107 | if "path" in metadata: 108 | return metadata["path"] 109 | 
elif "text" in metadata: 110 | return metadata["text"] 111 | else: 112 | return "" 113 | 114 | # Sample a random number of properties 115 | properties = random.sample(properties, random.randint(1, len(properties))) 116 | #random.shuffle(properties) 117 | 118 | return "|".join(properties) 119 | 120 | def get_prompt_from_fma_metadata(metadata): 121 | 122 | properties = [] 123 | 124 | keys = ['genre', 'album', 'song_title', 'artist', 'composer'] 125 | 126 | if "original_data" in metadata: 127 | original_data = metadata["original_data"] 128 | 129 | for key in keys: 130 | if key == "song_title": 131 | prompt_key = "title" 132 | else: 133 | prompt_key = key 134 | 135 | if key in original_data and str(original_data[key]) != "nan": 136 | properties.append(f"{prompt_key}: {original_data[key]}") 137 | 138 | properties = random.sample(properties, random.randint(1, len(properties))) 139 | return "|".join(properties) -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | from audio_diffusion.models import LatentAudioDiffusion 4 | from audio_diffusion.utils import get_alphas_sigmas 5 | from ema_pytorch import EMA 6 | from audio_diffusion.utils import InverseLR 7 | from torch.nn import functional as F 8 | import pytorch_lightning as pl 9 | 10 | class UnconditionalAudioDiffusionTrainer(pl.LightningModule): 11 | def __init__(self, diffusion_model): 12 | super().__init__() 13 | 14 | self.diffusion = diffusion_model 15 | 16 | self.diffusion_ema = EMA( 17 | self.diffusion, 18 | beta = 0.9999, 19 | power=3/4, 20 | update_every = 1, 21 | update_after_step = 1000 22 | ) 23 | 24 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 25 | 26 | def configure_optimizers(self): 27 | optimizer = optim.Adam([*self.diffusion.parameters()], lr=1e-4) 28 | 29 | scheduler = InverseLR(optimizer, inv_gamma=50000, power=1/2, warmup=0.9) 30 | 31 | return [optimizer], [scheduler] 32 | 33 | def training_step(self, batch, batch_idx): 34 | reals = batch 35 | 36 | # Draw uniformly distributed continuous timesteps 37 | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) 38 | 39 | # Calculate the noise schedule parameters for those timesteps 40 | alphas, sigmas = get_alphas_sigmas(t) 41 | 42 | # Combine the ground truth images and the noise 43 | alphas = alphas[:, None, None] 44 | sigmas = sigmas[:, None, None] 45 | noise = torch.randn_like(reals) 46 | noised_reals = reals * alphas + noise * sigmas 47 | targets = noise * alphas - reals * sigmas 48 | 49 | with torch.cuda.amp.autocast(): 50 | v = self.diffusion(noised_reals, t) 51 | mse_loss = F.mse_loss(v, targets) 52 | loss = mse_loss 53 | 54 | log_dict = { 55 | 'train/loss': loss.detach(), 56 | 'train/lr': self.lr_schedulers().get_last_lr()[0] 57 | } 58 | 59 | self.log_dict(log_dict, prog_bar=True, on_step=True) 60 | return loss 61 | 62 | def on_before_zero_grad(self, *args, **kwargs): 63 | self.diffusion_ema.update() 64 | 65 | class UnconditionalLatentAudioDiffusionTrainer(pl.LightningModule): 66 | def __init__(self, latent_diffusion_model: LatentAudioDiffusion): 67 | super().__init__() 68 | 69 | self.diffusion = latent_diffusion_model 70 | 71 | self.diffusion_ema = EMA( 72 | self.diffusion, 73 | beta = 0.9999, 74 | power=3/4, 75 | update_every = 1, 76 | update_after_step = 1 77 | ) 78 | 79 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 80 | 81 | def
configure_optimizers(self): 82 | optimizer = optim.Adam([*self.diffusion.parameters()], lr=1e-4) 83 | 84 | scheduler = InverseLR(optimizer, inv_gamma=100, power=1, warmup=0.9) 85 | 86 | return [optimizer], [scheduler] 87 | 88 | def training_step(self, batch, batch_idx): 89 | reals = batch 90 | 91 | with torch.cuda.amp.autocast(): 92 | with torch.no_grad(): 93 | latents = self.diffusion.encode(reals) 94 | 95 | # Draw uniformly distributed continuous timesteps 96 | t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) 97 | 98 | # Calculate the noise schedule parameters for those timesteps 99 | alphas, sigmas = get_alphas_sigmas(t) 100 | 101 | # Combine the ground truth images and the noise 102 | alphas = alphas[:, None, None] 103 | sigmas = sigmas[:, None, None] 104 | noise = torch.randn_like(latents) 105 | noised_latents = latents * alphas + noise * sigmas 106 | targets = noise * alphas - latents * sigmas 107 | 108 | with torch.cuda.amp.autocast(): 109 | v = self.diffusion(noised_latents, t) 110 | mse_loss = F.mse_loss(v, targets) 111 | loss = mse_loss 112 | 113 | log_dict = { 114 | 'train/loss': loss.detach(), 115 | 'train/lr': self.lr_schedulers().get_last_lr()[0] 116 | } 117 | 118 | self.log_dict(log_dict, prog_bar=True, on_step=True) 119 | return loss 120 | 121 | def on_before_zero_grad(self, *args, **kwargs): 122 | self.diffusion_ema.update() -------------------------------------------------------------------------------- /encoders/gabor_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class LogGaborFilterBank(nn.Module): 7 | def __init__(self, num_filters, filter_length, audio_channels, sample_rate): 8 | super(LogGaborFilterBank, self).__init__() 9 | 10 | self.num_filters = num_filters 11 | self.filter_length = filter_length 12 | self.audio_channels = audio_channels 13 | self.sample_rate = sample_rate 14 | 15 | self.filters = self._create_filters() 16 | 17 | def _create_filters(self): 18 | filters = [] 19 | 20 | # Define parameters for Log Gabor filters 21 | min_center_frequency = 50 22 | max_center_frequency = self.sample_rate // 2 23 | bandwidth = 1 24 | 25 | # Calculate center frequencies in log scale 26 | center_frequencies = np.geomspace( 27 | min_center_frequency, max_center_frequency, self.num_filters) 28 | 29 | for channel in range(self.audio_channels): 30 | channel_filters = [] 31 | for freq in center_frequencies: 32 | t = torch.linspace(-self.filter_length // 2, self.filter_length // 33 | 2, self.filter_length) / self.sample_rate 34 | omega = 2 * np.pi * freq 35 | sigma = freq / bandwidth 36 | 37 | # Create the filter in the frequency domain 38 | log_gabor = torch.exp(-((torch.log(omega * t) - 39 | torch.log(freq)) ** 2) / (2 * (np.log(sigma)) ** 2)) 40 | # Set the DC component to zero 41 | log_gabor[self.filter_length // 2] = 0 42 | 43 | # Transform the filter to the time domain 44 | log_gabor_time = torch.ifft(log_gabor, n=self.filter_length) 45 | 46 | channel_filters.append(log_gabor_time.view(1, 1, -1)) 47 | 48 | filters.append(torch.cat(channel_filters, dim=0)) 49 | 50 | return nn.Parameter(torch.cat(filters, dim=1), requires_grad=False) 51 | 52 | def forward(self, x): 53 | """ 54 | Apply the filter bank to an input audio signal. 55 | 56 | Args: 57 | x (torch.Tensor): Input audio signal with shape (batch_size, audio_channels, signal_length). 
58 | 59 | Returns: 60 | torch.Tensor: Encoded audio signal with shape (batch_size, num_filters * audio_channels, output_length). 61 | """ 62 | batch_size, audio_channels, signal_length = x.size() 63 | 64 | # Apply the filter bank to the input audio signal 65 | x = x.view(batch_size * audio_channels, 1, signal_length) 66 | y = torch.conv1d(x, self.filters, padding=self.filter_length // 2) 67 | 68 | # Reshape the output tensor 69 | output_length = y.size(-1) 70 | y = y.view(batch_size, self.num_filters * 71 | self.audio_channels, output_length) 72 | 73 | return y 74 | 75 | def decode(self, encoded_audio): 76 | """ 77 | Apply the inverse filter bank to the encoded audio signal. 78 | 79 | Args: 80 | encoded_audio (torch.Tensor): Encoded audio signal with shape (batch_size, num_filters * audio_channels, signal_length). 81 | 82 | Returns: 83 | torch.Tensor: Decoded audio signal with shape (batch_size, audio_channels, output_length). 84 | """ 85 | batch_size, _, signal_length = encoded_audio.size() 86 | output_length = signal_length + self.filter_length - 1 87 | 88 | # Reshape the encoded audio tensor 89 | encoded_audio = encoded_audio.view( 90 | batch_size, self.num_filters, self.audio_channels, -1) 91 | 92 | # Transpose the tensor to have the filter dimension first 93 | encoded_audio = encoded_audio.permute(0, 1, 3, 2) 94 | 95 | # Apply the inverse filter bank to each frequency channel and channel separately 96 | decoded_audio = [] 97 | for channel in range(self.audio_channels): 98 | channel_filters = self.filters[:, channel * 99 | self.num_filters:(channel+1)*self.num_filters, :] 100 | channel_encoded_audio = encoded_audio[:, :, :, channel] 101 | channel_decoded_audio = torch.conv1d( 102 | channel_encoded_audio, channel_filters, padding=self.filter_length - 1) 103 | decoded_audio.append(channel_decoded_audio) 104 | 105 | # Reshape the output tensor 106 | decoded_audio = torch.stack(decoded_audio, dim=-1) 107 | decoded_audio = decoded_audio.view( 108 | batch_size, self.audio_channels, output_length) 109 | 110 | return decoded_audio 111 | -------------------------------------------------------------------------------- /viz/viz.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | from pathlib import Path 4 | from matplotlib.backends.backend_agg import FigureCanvasAgg 5 | import matplotlib.cm as cm 6 | import matplotlib.pyplot as plt 7 | from matplotlib.colors import Normalize 8 | from matplotlib.figure import Figure 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import torch 13 | from torch import optim, nn 14 | from torch.nn import functional as F 15 | import torchaudio 16 | import torchaudio.transforms as T 17 | import librosa 18 | from einops import rearrange 19 | 20 | import wandb 21 | import numpy as np 22 | import pandas as pd 23 | 24 | 25 | 26 | def embeddings_table(tokens): 27 | "make a table of embeddings for use with wandb" 28 | features, labels = [], [] 29 | embeddings = rearrange(tokens, 'b d n -> b n d') # each demo sample is n vectors in d-dim space 30 | for i in range(embeddings.size()[0]): # nested for's are slow but sure ;-) 31 | for j in range(embeddings.size()[1]): 32 | features.append(embeddings[i,j].detach().cpu().numpy()) 33 | labels.append([f'demo{i}']) # labels does the grouping / color for each point 34 | features = np.array(features) 35 | #print("\nfeatures.shape = ",features.shape) 36 | labels = np.concatenate(labels, axis=0) 37 | cols = [f"dim_{i}" for i in range(features.shape[1])] 38 | df = 
pd.DataFrame(features, columns=cols) 39 | df['LABEL'] = labels 40 | return wandb.Table(columns=df.columns.to_list(), data=df.values) 41 | 42 | 43 | def proj_pca(tokens, proj_dims=3): 44 | "this projects via PCA, grabbing the first _3_ dimensions" 45 | A = rearrange(tokens, 'b d n -> (b n) d') # put all the vectors into the same d-dim space 46 | k = proj_dims 47 | (U, S, V) = torch.pca_lowrank(A) 48 | proj_data = torch.matmul(A, V[:, :k]) # this is the actual PCA projection step 49 | return torch.reshape(proj_data, (tokens.size()[0], -1, proj_dims)) # put it in shape [batch, n, 3] 50 | 51 | 52 | def pca_point_cloud(tokens): 53 | "produces a 3D wandb point cloud of the tokens using PCA" 54 | data = proj_pca(tokens).cpu().numpy() 55 | points = [] 56 | cmap = cm.tab20 # 20 farly distinct colors 57 | norm = Normalize(vmin=0, vmax=data.shape[0]) 58 | for bi in range(data.shape[0]): # batch index 59 | [r, g, b, _] = [int(255*x) for x in cmap(norm(bi))] 60 | for n in range(data.shape[1]): 61 | #points.append([data[b,n,0], data[b,n,1], data[b,n,2], color]) # only works for color=1 to 14 62 | points.append([data[bi,n,0], data[bi,n,1], data[bi,n,2], r, g, b]) 63 | 64 | point_cloud = np.array(points) 65 | return wandb.Object3D(point_cloud) 66 | 67 | 68 | def spectrogram_image(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None, db_range=[35,120]): 69 | """ 70 | # cf. https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html 71 | 72 | """ 73 | fig = Figure(figsize=(5, 4), dpi=100) 74 | canvas = FigureCanvasAgg(fig) 75 | axs = fig.add_subplot() 76 | axs.set_title(title or 'Spectrogram (db)') 77 | axs.set_ylabel(ylabel) 78 | axs.set_xlabel('frame') 79 | im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect, vmin=db_range[0], vmax=db_range[1]) 80 | if xmax: 81 | axs.set_xlim((0, xmax)) 82 | fig.colorbar(im, ax=axs) 83 | canvas.draw() 84 | rgba = np.asarray(canvas.buffer_rgba()) 85 | return Image.fromarray(rgba) 86 | 87 | 88 | def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000): 89 | """ 90 | # cf. 
https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html 91 | """ 92 | n_fft = 1024 93 | win_length = None 94 | hop_length = 512 95 | n_mels = 80 96 | 97 | mel_spectrogram_op = T.MelSpectrogram( 98 | sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, 99 | hop_length=hop_length, center=True, pad_mode="reflect", power=power, 100 | norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk") 101 | 102 | melspec = mel_spectrogram_op(waveform.float()) 103 | melspec = melspec[0] # TODO: only left channel for now 104 | return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)') 105 | 106 | 107 | def tokens_spectrogram_image(tokens, aspect='auto', title='Embeddings', ylabel='index'): 108 | embeddings = rearrange(tokens, 'b d n -> (b n) d') 109 | print(f"tokens_spectrogram_image: embeddings.shape = ",embeddings.shape) 110 | fig = Figure(figsize=(10, 4), dpi=100) 111 | canvas = FigureCanvasAgg(fig) 112 | axs = fig.add_subplot() 113 | axs.set_title(title or 'Embeddings') 114 | axs.set_ylabel(ylabel) 115 | axs.set_xlabel('time frame') 116 | im = axs.imshow(embeddings.cpu().numpy().T, origin='lower', aspect=aspect, interpolation='none') #.T because numpy is x/y 'backwards' 117 | fig.colorbar(im, ax=axs) 118 | canvas.draw() 119 | rgba = np.asarray(canvas.buffer_rgba()) 120 | return Image.fromarray(rgba) 121 | -------------------------------------------------------------------------------- /prune_latent_uncond.py: -------------------------------------------------------------------------------- 1 | #@title Imports and definitions 2 | import argparse 3 | from contextlib import contextmanager 4 | from copy import deepcopy 5 | import math 6 | from pathlib import Path 7 | 8 | import sys 9 | import gc 10 | 11 | from autoencoders.soundstream import SoundStreamXLEncoder, SoundStreamXLDecoder 12 | from autoencoders.models import AudioAutoencoder 13 | from audio_encoders_pytorch import Encoder1d 14 | from ema_pytorch import EMA 15 | from audio_diffusion_pytorch.modules import UNetConditional1d 16 | 17 | from audio_diffusion_pytorch import T5Embedder, NumberEmbedder 18 | 19 | import torch 20 | from torch import optim, nn 21 | from torch.nn import functional as F 22 | from torch.utils import data 23 | from tqdm import trange 24 | from einops import rearrange 25 | 26 | import torchaudio 27 | from decoders.diffusion_decoder import DiffusionAttnUnet1D 28 | import numpy as np 29 | 30 | import random 31 | from diffusion.utils import Stereo, PadCrop 32 | from glob import glob 33 | 34 | from torch.nn.parameter import Parameter 35 | 36 | class LatentAudioDiffusion(nn.Module): 37 | def __init__(self, autoencoder: AudioAutoencoder): 38 | super().__init__() 39 | 40 | 41 | self.latent_dim = autoencoder.latent_dim 42 | 43 | self.second_stage_latent_dim = 32 44 | 45 | factors = [2, 2, 2, 2] 46 | 47 | self.latent_downsampling_ratio = np.prod(factors) 48 | 49 | self.downsampling_ratio = autoencoder.downsampling_ratio * self.latent_downsampling_ratio 50 | 51 | self.latent_encoder = Encoder1d( 52 | in_channels=self.latent_dim, 53 | out_channels = self.second_stage_latent_dim, 54 | channels = 128, 55 | multipliers = [1, 2, 4, 8, 8], 56 | factors = factors, 57 | num_blocks = [8, 8, 8, 8], 58 | ) 59 | 60 | self.diffusion = DiffusionAttnUnet1D( 61 | io_channels=self.latent_dim, 62 | cond_dim = self.second_stage_latent_dim, 63 | n_attn_layers=0, 64 | c_mults=[512] * 10, 65 | depth=10 66 | ) 67 | 68 | self.autoencoder = autoencoder 69 | 70 | self.autoencoder.requires_grad_(False) 
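# Note: the first-stage autoencoder is frozen above, so only the latent encoder and the
# diffusion prior are trainable in this wrapper. The SobolEngine below appears to be a
# low-discrepancy sampler for drawing diffusion timesteps, mirroring its use in trainer/trainer.py.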
71 | 72 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 73 | 74 | def encode(self, reals): 75 | first_stage_latents = self.autoencoder.encode(reals) 76 | 77 | second_stage_latents = self.latent_encoder(first_stage_latents) 78 | 79 | second_stage_latents = torch.tanh(second_stage_latents) 80 | 81 | return second_stage_latents 82 | 83 | def decode(self, latents, steps=250, device="cuda"): 84 | first_stage_latent_noise = torch.randn([latents.shape[0], self.latent_dim, latents.shape[2]*self.latent_downsampling_ratio]).to(device) 85 | 86 | t = torch.linspace(1, 0, steps + 1, device=device)[:-1] 87 | 88 | step_list = get_spliced_ddpm_cosine_schedule(t) 89 | 90 | first_stage_sampled = sampling.iplms_sample(self.diffusion, first_stage_latent_noise, step_list, {"cond":latents}) 91 | #first_stage_sampled = sample(self.diffusion, first_stage_latent_noise, steps, 0, cond=latents) 92 | decoded = self.autoencoder.decode(first_stage_sampled) 93 | return decoded 94 | 95 | def prune_ckpt_weights(stacked_state_dict): 96 | new_state_dict = {} 97 | for name, param in stacked_state_dict.items(): 98 | if name.startswith("diffusion_ema.ema_model."): 99 | new_name = name.replace("diffusion_ema.ema_model.", "diffusion.") 100 | if isinstance(param, Parameter): 101 | # backwards compatibility for serialized parameters 102 | param = param.data 103 | new_state_dict[new_name] = param 104 | elif name.startswith("autoencoder") or name.startswith("timestamp_embedder"): 105 | new_state_dict[name] = param 106 | 107 | return new_state_dict 108 | 109 | 110 | if __name__ == "__main__": 111 | 112 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 113 | parser.add_argument('--ckpt_path', help='Path to the checkpoint to be pruned') 114 | args = parser.parse_args() 115 | 116 | print("Creating the model...") 117 | 118 | first_stage_config = {"capacity": 64, "c_mults": [2, 4, 8, 16, 32], "strides": [2, 2, 2, 2, 2], "latent_dim": 32} 119 | 120 | first_stage_autoencoder = AudioAutoencoder( 121 | **first_stage_config 122 | ).eval() 123 | 124 | latent_diffusion_config = 125 | 126 | latent_diffae = LatentAudioDiffusionAutoencoder(autoencoder=first_stage_autoencoder) 127 | 128 | model = StackedAELatentDiffusionCond(latent_diffae) 129 | 130 | ckpt_state_dict = torch.load(args.ckpt_path)["state_dict"] 131 | print(ckpt_state_dict.keys()) 132 | 133 | new_ckpt = {} 134 | 135 | new_ckpt["state_dict"] = prune_ckpt_weights(ckpt_state_dict) 136 | 137 | model.load_state_dict(new_ckpt["state_dict"]) 138 | 139 | torch.save(new_ckpt, f'./pruned.ckpt') -------------------------------------------------------------------------------- /pca_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from sklearn.decomposition import PCA 6 | import argparse 7 | import torchaudio 8 | from autoencoders.models import AudioAutoencoder, LatentAudioDiffusionAutoencoder 9 | 10 | def load_audio(audio_dir, autoencoder): 11 | target_sample_rate = 48000 12 | target_sequence_length = 262144 13 | audio_data = [] 14 | for filename in os.listdir(audio_dir): 15 | print(f'Processing {filename}...') 16 | if filename.endswith('.wav'): 17 | filepath = os.path.join(audio_dir, filename) 18 | audio, sample_rate = torchaudio.load(filepath) 19 | 20 | # Resample audio to target sample rate 21 | if sample_rate != target_sample_rate: 22 | resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, 
new_freq=target_sample_rate) 23 | audio = resampler(audio) 24 | 25 | # Pad or crop audio to target sequence length 26 | audio_length = audio.shape[1] 27 | if audio_length > target_sequence_length: 28 | audio = audio[:, :target_sequence_length] 29 | elif audio_length < target_sequence_length: 30 | pad_length = target_sequence_length - audio_length 31 | audio = torch.nn.functional.pad(audio, (0, pad_length)) 32 | 33 | # Ensure audio is stereo 34 | if audio.shape[0] == 1: 35 | audio = torch.cat([audio, audio], dim=0) 36 | 37 | audio = audio.unsqueeze(0).float() 38 | encoded_audio = encode_audio(audio, autoencoder) 39 | print(encoded_audio.shape) 40 | audio_data.append(encoded_audio.cpu()) 41 | 42 | return torch.cat(audio_data, dim=0) 43 | 44 | def encode_audio(audio, autoencoder): 45 | encoded_audio = autoencoder.encode(audio.to("cuda")) 46 | return encoded_audio 47 | 48 | def perform_pca(data): 49 | pca = PCA() 50 | pca.fit(data) 51 | pca_data = pca.transform(data) 52 | components = pca.components_ 53 | explained_variance_ratios = pca.explained_variance_ratio_ 54 | cumulative_variance_ratio = np.cumsum(explained_variance_ratios) 55 | informative_dimensions = np.argmax(cumulative_variance_ratio >= 0.95) + 1 56 | 57 | return pca_data, components, explained_variance_ratios, informative_dimensions 58 | 59 | def plot_scree(explained_variance_ratios): 60 | plt.plot(np.arange(1, explained_variance_ratios.size+1), explained_variance_ratios, 'bo-', linewidth=2) 61 | plt.xlabel('Principal Component') 62 | plt.ylabel('Explained Variance Ratio') 63 | plt.title('Scree Plot') 64 | plt.grid() 65 | 66 | def save_plot(filename): 67 | plt.savefig(filename) 68 | 69 | if __name__ == '__main__': 70 | # Parse command-line arguments 71 | parser = argparse.ArgumentParser(description='Perform PCA on audio signals') 72 | parser.add_argument('audio_dir', metavar='audio_dir', type=str, 73 | help='path to directory containing audio files') 74 | parser.add_argument('--pretrained_ckpt_path', type=str, 75 | help='path to pretrained checkpoint') 76 | args = parser.parse_args() 77 | 78 | first_stage_config = {"capacity": 64, "c_mults": [2, 4, 8, 16, 32], "strides": [2, 2, 2, 2, 2], "latent_dim": 32} 79 | 80 | first_stage_autoencoder = AudioAutoencoder( 81 | **first_stage_config 82 | ).eval() 83 | 84 | latent_diffae_config = { 85 | "second_stage_latent_dim": 32, 86 | "downsample_factors": [2, 2, 2, 2], 87 | "encoder_base_channels": 128, 88 | "encoder_channel_mults": [1, 2, 4, 8, 8], 89 | "encoder_num_blocks": [8, 8, 8, 8], 90 | "diffusion_channel_dims": [512] * 10 91 | } 92 | 93 | latent_diffae = LatentAudioDiffusionAutoencoder(autoencoder=first_stage_autoencoder, **latent_diffae_config).to("cuda").eval().requires_grad_(False) 94 | 95 | print(f'Loading pretrained diffAE checkpoint from {args.pretrained_ckpt_path}...') 96 | latent_diffae.load_state_dict(torch.load(args.pretrained_ckpt_path, map_location='cpu')['state_dict']) 97 | 98 | print(f'Loading audio files from {args.audio_dir}...') 99 | # Load the audio files and encode them 100 | audio_data = load_audio(args.audio_dir, latent_diffae) 101 | 102 | # Reshape the data to [num_samples, num_features] 103 | num_samples, num_channels, sequence_length = audio_data.shape 104 | print(f'Number of samples: {num_samples}') 105 | data = audio_data.permute(0, 2, 1).reshape(-1, num_channels).numpy() 106 | 107 | # Perform PCA 108 | print('Performing PCA...') 109 | pca_data, components, explained_variance_ratios, informative_dimensions = perform_pca(data) 110 | 111 | # Plot the scree 
plot 112 | print('Plotting scree plot...') 113 | plot_scree(explained_variance_ratios) 114 | 115 | # Save the plot as a PNG file 116 | save_plot('scree_plot.png') 117 | 118 | print(f'Number of informative dimensions: {informative_dimensions}') -------------------------------------------------------------------------------- /losses/adv_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.utils.weight_norm as wn 5 | import cached_conv as cc 6 | from encodec.msstftd import MultiScaleSTFTDiscriminator 7 | 8 | class Discriminator(nn.Module): 9 | 10 | def __init__(self, in_size, capacity, multiplier, n_layers): 11 | super().__init__() 12 | 13 | net = [ 14 | wn(cc.Conv1d(in_size, capacity, 15, padding=cc.get_padding(15))) 15 | ] 16 | net.append(nn.LeakyReLU(.2)) 17 | 18 | for i in range(n_layers): 19 | net.append( 20 | wn( 21 | cc.Conv1d( 22 | capacity * multiplier**i, 23 | min(1024, capacity * multiplier**(i + 1)), 24 | 41, 25 | stride=multiplier, 26 | padding=cc.get_padding(41, multiplier), 27 | groups=multiplier**(i + 1), 28 | ))) 29 | net.append(nn.LeakyReLU(.2)) 30 | 31 | net.append( 32 | wn( 33 | cc.Conv1d( 34 | min(1024, capacity * multiplier**(i + 1)), 35 | min(1024, capacity * multiplier**(i + 1)), 36 | 5, 37 | padding=cc.get_padding(5), 38 | ))) 39 | net.append(nn.LeakyReLU(.2)) 40 | net.append( 41 | wn(cc.Conv1d(min(1024, capacity * multiplier**(i + 1)), 1, 1))) 42 | self.net = nn.ModuleList(net) 43 | 44 | def forward(self, x): 45 | feature = [] 46 | for layer in self.net: 47 | x = layer(x) 48 | if isinstance(layer, nn.Conv1d): 49 | feature.append(x) 50 | return feature 51 | 52 | 53 | class StackDiscriminators(nn.Module): 54 | 55 | def __init__(self, n_dis, *args, **kwargs): 56 | super().__init__() 57 | self.discriminators = nn.ModuleList( 58 | [Discriminator(*args, **kwargs) for i in range(n_dis)], ) 59 | 60 | def forward(self, x): 61 | features = [] 62 | for layer in self.discriminators: 63 | features.append(layer(x)) 64 | x = nn.functional.avg_pool1d(x, 2) 65 | return features 66 | 67 | def adversarial_combine(self, score_real, score_fake): 68 | loss_dis = torch.relu(1 - score_real) + torch.relu(1 + score_fake) 69 | loss_dis = loss_dis.mean() 70 | loss_gen = -score_fake.mean() 71 | return loss_dis, loss_gen 72 | 73 | def loss(self, x, y): 74 | feature_matching_distance = 0. 
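# The loop below accumulates an L1 feature-matching distance over the intermediate
# discriminator features at every scale, and passes the final activation of each scale
# to adversarial_combine as the logit for the hinge-style discriminator/generator losses.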
75 | feature_true = self.forward(x) 76 | feature_fake = self.forward(y) 77 | 78 | loss_dis = 0 79 | loss_adv = 0 80 | 81 | pred_true = 0 82 | pred_fake = 0 83 | 84 | for scale_true, scale_fake in zip(feature_true, feature_fake): 85 | feature_matching_distance = feature_matching_distance + sum( 86 | map( 87 | lambda x, y: abs(x - y).mean(), 88 | scale_true, 89 | scale_fake, 90 | )) / len(scale_true) 91 | 92 | _dis, _adv = self.adversarial_combine( 93 | scale_true[-1], 94 | scale_fake[-1], 95 | ) 96 | 97 | pred_true = pred_true + scale_true[-1].mean() 98 | pred_fake = pred_fake + scale_fake[-1].mean() 99 | 100 | loss_dis = loss_dis + _dis 101 | loss_adv = loss_adv + _adv 102 | 103 | return loss_dis, loss_adv, feature_matching_distance, pred_true, pred_fake 104 | 105 | class EncodecDiscriminator(nn.Module): 106 | 107 | def __init__(self, *args, **kwargs): 108 | super().__init__() 109 | self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) 110 | 111 | def forward(self, x): 112 | logits, features = self.discriminators(x) 113 | return logits, features 114 | 115 | def adversarial_combine(self, score_real, score_fake): 116 | loss_dis = torch.relu(1 - score_real) + torch.relu(1 + score_fake) 117 | loss_dis = loss_dis.mean() 118 | loss_gen = -score_fake.mean() 119 | return loss_dis, loss_gen 120 | 121 | def loss(self, x, y): 122 | feature_matching_distance = 0. 123 | logits_true, feature_true = self.forward(x) 124 | logits_fake, feature_fake = self.forward(y) 125 | 126 | loss_dis = 0 127 | loss_adv = 0 128 | 129 | pred_true = 0 130 | pred_fake = 0 131 | 132 | for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): 133 | 134 | feature_matching_distance = feature_matching_distance + sum( 135 | map( 136 | lambda x, y: abs(x - y).mean(), 137 | scale_true, 138 | scale_fake, 139 | )) / len(scale_true) 140 | 141 | _dis, _adv = self.adversarial_combine( 142 | logits_true[i], 143 | logits_fake[i], 144 | ) 145 | 146 | pred_true = pred_true + logits_true[i].mean() 147 | pred_fake = pred_fake + logits_fake[i].mean() 148 | 149 | loss_dis = loss_dis + _dis 150 | loss_adv = loss_adv + _adv 151 | 152 | return loss_dis, loss_adv, feature_matching_distance, pred_true, pred_fake -------------------------------------------------------------------------------- /train_clap_duration_predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | import math, random 5 | 6 | import sys 7 | import torch 8 | from torch import optim, nn 9 | from torch.nn import functional as F 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.utilities.distributed import rank_zero_only 12 | from einops import rearrange 13 | import numpy as np 14 | import torchaudio 15 | import laion_clap 16 | 17 | from dataset.dataset import get_wds_loader 18 | 19 | def unwrap_text(str_or_tuple): 20 | if type(str_or_tuple) is tuple: 21 | return random.choice(str_or_tuple) 22 | elif type(str_or_tuple) is str: 23 | return str_or_tuple 24 | 25 | class ClapDurationPredictor(pl.LightningModule): 26 | def __init__(self, clap_model: laion_clap.CLAP_Module): 27 | super().__init__() 28 | 29 | self.clap_model = clap_model 30 | 31 | # CLAP embeddings are 512-dim 32 | self.embedding_features = 512 33 | 34 | self.hidden_dim = 1024 35 | 36 | self.max_seconds = 512 37 | 38 | self.to_seconds_embed = nn.Sequential( 39 | nn.Linear(self.embedding_features, self.hidden_dim), 40 | nn.ReLU(), 41 | 
nn.Linear(self.hidden_dim, self.hidden_dim), 42 | nn.ReLU(), 43 | nn.Linear(self.hidden_dim, self.max_seconds + 1), 44 | nn.ReLU(), 45 | nn.Softmax(dim=-1) 46 | ) 47 | 48 | def get_clap_features(self, prompts, layer_ix=-2): 49 | prompt_tokens = self.clap_model.tokenizer(prompts) 50 | prompt_features = self.clap_model.model.text_branch( 51 | input_ids=prompt_tokens["input_ids"].to(device=self.device, non_blocking=True), 52 | attention_mask=prompt_tokens["attention_mask"].to( 53 | device=self.device, non_blocking=True 54 | ), 55 | output_hidden_states=True 56 | )["hidden_states"][layer_ix] 57 | 58 | masks = prompt_tokens["attention_mask"].to(device=self.device, non_blocking=True) 59 | 60 | return prompt_features, masks 61 | 62 | def configure_optimizers(self): 63 | return optim.Adam([*self.to_seconds_embed.parameters()], lr=4e-5) 64 | 65 | def training_step(self, batch, batch_idx): 66 | _, jsons, _ = batch 67 | 68 | condition_strings = [unwrap_text(json["prompt"][0]) for json in jsons] 69 | 70 | seconds_totals = [json["seconds_total"][0] for json in jsons] 71 | 72 | seconds_totals = torch.tensor(seconds_totals).to(self.device) 73 | seconds_totals = seconds_totals.clamp(0, self.max_seconds) 74 | 75 | with torch.no_grad(): 76 | # Get text embeds 77 | text_embeddings = self.clap_model.get_text_embedding(condition_strings, use_tensor=True) 78 | 79 | second_predictions = self.to_seconds_embed(text_embeddings) 80 | 81 | seconds_totals_one_hot = F.one_hot(seconds_totals, num_classes=self.max_seconds + 1).float() 82 | 83 | loss = F.binary_cross_entropy(seconds_totals_one_hot, second_predictions) 84 | 85 | log_dict = { 86 | 'train/loss': loss.detach(), 87 | } 88 | 89 | self.log_dict(log_dict, prog_bar=True, on_step=True) 90 | return loss 91 | 92 | class ExceptionCallback(pl.Callback): 93 | def on_exception(self, trainer, module, err): 94 | print(f'{type(err).__name__}: {err}') 95 | 96 | 97 | def main(): 98 | 99 | args = get_all_args() 100 | 101 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 102 | print('Using device:', device) 103 | torch.manual_seed(args.seed) 104 | 105 | names = [ 106 | ] 107 | 108 | train_dl = get_wds_loader( 109 | batch_size=args.batch_size, 110 | s3_url_prefix=None, 111 | sample_size=args.sample_size, 112 | names=names, 113 | sample_rate=args.sample_rate, 114 | num_workers=args.num_workers, 115 | recursive=True, 116 | random_crop=True, 117 | epoch_steps=10000 118 | ) 119 | 120 | exc_callback = ExceptionCallback() 121 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) 122 | clap_model = laion_clap.CLAP_Module(enable_fusion=args.clap_fusion, device=device, amodel= args.clap_amodel).requires_grad_(False).eval() 123 | 124 | if args.clap_ckpt_path: 125 | clap_model.load_ckpt(ckpt=args.clap_ckpt_path) 126 | else: 127 | clap_model.load_ckpt(model_id=1) 128 | 129 | duration_predictor = ClapDurationPredictor(clap_model) 130 | 131 | wandb_logger = pl.loggers.WandbLogger(project=args.name) 132 | wandb_logger.watch(duration_predictor) 133 | push_wandb_config(wandb_logger, args) 134 | 135 | diffusion_trainer = pl.Trainer( 136 | devices=args.num_gpus, 137 | accelerator="gpu", 138 | num_nodes = args.num_nodes, 139 | strategy='ddp_find_unused_parameters_false', 140 | #precision=16, 141 | accumulate_grad_batches=args.accum_batches, 142 | callbacks=[ckpt_callback, exc_callback], 143 | logger=wandb_logger, 144 | log_every_n_steps=1, 145 | max_epochs=10000000, 146 | default_root_dir=args.save_dir 147 | ) 148 | 149 | 
diffusion_trainer.fit(duration_predictor, train_dl, ckpt_path=args.ckpt_path) 150 | 151 | if __name__ == '__main__': 152 | main() 153 | 154 | -------------------------------------------------------------------------------- /diffusion/transformers.py: -------------------------------------------------------------------------------- 1 | from x_transformers import ContinuousTransformerWrapper, Encoder 2 | from einops import rearrange 3 | import torch 4 | from torch import nn 5 | from encoders.wavelets import WaveletEncode1d, WaveletDecode1d 6 | from blocks.blocks import FourierFeatures 7 | 8 | class DiffusionTransformer(nn.Module): 9 | def __init__(self, 10 | io_channels=32, 11 | input_length=512, 12 | cond_token_dim=0, 13 | embed_dim=768, 14 | depth=12, 15 | num_heads=8, 16 | wavelet_levels=0): 17 | 18 | super().__init__() 19 | 20 | self.cond_token_dim = cond_token_dim 21 | self.wavelet_levels = wavelet_levels 22 | 23 | data_channels = io_channels 24 | 25 | # Wavelet decomposition 26 | if self.wavelet_levels > 0: 27 | self.wavelet_encoder = WaveletEncode1d(io_channels, "bior4.4", levels = self.wavelet_levels) 28 | self.wavelet_decoder = WaveletDecode1d(io_channels, "bior4.4", levels=self.wavelet_levels) 29 | data_channels = data_channels * (2**self.wavelet_levels) 30 | input_length = input_length // (2**self.wavelet_levels) 31 | 32 | 33 | # Timestep embeddings 34 | timestep_features_dim = 256 35 | 36 | self.timestep_features = FourierFeatures(1, timestep_features_dim) 37 | 38 | self.to_timestep_embed = nn.Sequential( 39 | nn.Linear(timestep_features_dim, embed_dim, bias=True), 40 | nn.SiLU(), 41 | nn.Linear(embed_dim, embed_dim, bias=True), 42 | ) 43 | 44 | if cond_token_dim > 0: 45 | # Conditioning tokens 46 | self.to_cond_embed = nn.Sequential( 47 | nn.Linear(cond_token_dim, embed_dim, bias=False), 48 | nn.SiLU(), 49 | nn.Linear(embed_dim, embed_dim, bias=False) 50 | ) 51 | 52 | # Transformer 53 | 54 | self.transformer = ContinuousTransformerWrapper( 55 | dim_in=data_channels, 56 | dim_out=data_channels, 57 | max_seq_len=input_length + 1, #1 for time conditioning 58 | attn_layers = Encoder( 59 | dim=embed_dim, 60 | depth=depth, 61 | heads=num_heads, 62 | cross_attend = True, 63 | zero_init_branch_output=True, 64 | rotary_pos_emb =True, 65 | ff_swish = True, # set this to True 66 | ff_glu = True 67 | ) 68 | ) 69 | 70 | self.preprocess_conv = nn.Conv1d(data_channels, data_channels, 3, padding=1, bias=False) 71 | nn.init.zeros_(self.preprocess_conv.weight) 72 | self.postprocess_conv = nn.Conv1d(data_channels, data_channels, 3, padding=1, bias=False) 73 | nn.init.zeros_(self.postprocess_conv.weight) 74 | 75 | def forward( 76 | self, 77 | x, 78 | t, 79 | cond_tokens=None, 80 | cond_token_mask=None, 81 | cfg_scale=1.0, 82 | cfg_dropout_prob=0.0): 83 | 84 | # Get the batch of timestep embeddings 85 | timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) 86 | 87 | timestep_embed = timestep_embed.unsqueeze(1) 88 | 89 | if cond_tokens is not None: 90 | 91 | cond_tokens = self.to_cond_embed(cond_tokens) 92 | 93 | # CFG dropout 94 | if cfg_dropout_prob > 0.0: 95 | null_embed = torch.zeros_like(cond_tokens, device=cond_tokens.device) 96 | dropout_mask = torch.bernoulli(torch.full((cond_tokens.shape[0], 1, 1), cfg_dropout_prob, device=cond_tokens.device)).to(torch.bool) 97 | cond_tokens = torch.where(dropout_mask, null_embed, cond_tokens) 98 | 99 | if self.wavelet_levels > 0: 100 | x = self.wavelet_encoder(x) 101 | 102 | x = self.preprocess_conv(x) + x 103 | 104 | x 
= rearrange(x, "b c t -> b t c") 105 | 106 | if cond_tokens is not None and cfg_scale != 1.0: 107 | # Classifier-free guidance 108 | # Concatenate conditioned and unconditioned inputs on the batch dimension 109 | batch_inputs = torch.cat([x, x], dim=0) 110 | 111 | null_embed = torch.zeros_like(cond_tokens, device=cond_tokens.device) 112 | 113 | batch_timestep = torch.cat([timestep_embed, timestep_embed], dim=0) 114 | batch_cond = torch.cat([cond_tokens, null_embed], dim=0) 115 | if cond_token_mask is not None: 116 | batch_masks = torch.cat([cond_token_mask, cond_token_mask], dim=0) 117 | else: 118 | batch_masks = None 119 | 120 | output = self.transformer(batch_inputs, prepend_embeds=batch_timestep, context=batch_cond, context_mask=batch_masks) 121 | 122 | cond_output, uncond_output = torch.chunk(output, 2, dim=0) 123 | output = uncond_output + (cond_output - uncond_output) * cfg_scale 124 | 125 | else: 126 | output = self.transformer(x, prepend_embeds=timestep_embed, context=cond_tokens, context_mask=cond_token_mask) 127 | 128 | output = rearrange(output, "b t c -> b c t")[:,:,1:] 129 | 130 | if self.wavelet_levels > 0: 131 | output = self.wavelet_decoder(output) 132 | 133 | output = self.postprocess_conv(output) + output 134 | 135 | return output -------------------------------------------------------------------------------- /encoders/perceiver_resampler.py: -------------------------------------------------------------------------------- 1 | #Copied from https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py 2 | 3 | # attention pooling 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, einsum 7 | 8 | from einops import rearrange, repeat 9 | from einops.layers.torch import Rearrange 10 | from einops_exts import rearrange_many 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def masked_mean(t, *, dim, mask = None): 16 | if not exists(mask): 17 | return t.mean(dim = dim) 18 | 19 | denom = mask.sum(dim = dim, keepdim = True) 20 | mask = rearrange(mask, 'b n -> b n 1') 21 | masked_t = t.masked_fill(~mask, 0.) 
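# Masked positions are zeroed out above; the sum over the sequence dimension is then
# divided by the per-example count of unmasked positions (clamped to avoid division by zero).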
22 | 23 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 24 | 25 | # norms and residuals 26 | 27 | class LayerNorm(nn.Module): 28 | def __init__(self, dim): 29 | super().__init__() 30 | self.gamma = nn.Parameter(torch.ones(dim)) 31 | self.register_buffer("beta", torch.zeros(dim)) 32 | 33 | def forward(self, x): 34 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 35 | 36 | def FeedForward(dim, mult = 2): 37 | hidden_dim = int(dim * mult) 38 | return nn.Sequential( 39 | LayerNorm(dim), 40 | nn.Linear(dim, hidden_dim, bias = False), 41 | nn.GELU(), 42 | LayerNorm(hidden_dim), 43 | nn.Linear(hidden_dim, dim, bias = False) 44 | ) 45 | 46 | 47 | class PerceiverAttention(nn.Module): 48 | def __init__( 49 | self, 50 | *, 51 | dim, 52 | dim_head = 64, 53 | heads = 8 54 | ): 55 | super().__init__() 56 | self.scale = dim_head ** -0.5 57 | self.heads = heads 58 | inner_dim = dim_head * heads 59 | 60 | self.norm = nn.LayerNorm(dim) 61 | self.norm_latents = nn.LayerNorm(dim) 62 | 63 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 64 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 65 | 66 | self.to_out = nn.Sequential( 67 | nn.Linear(inner_dim, dim, bias = False), 68 | nn.LayerNorm(dim) 69 | ) 70 | 71 | def forward(self, x, latents, mask = None): 72 | x = self.norm(x) 73 | latents = self.norm_latents(latents) 74 | 75 | b, h = x.shape[0], self.heads 76 | 77 | q = self.to_q(latents) 78 | 79 | # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to 80 | kv_input = torch.cat((x, latents), dim = -2) 81 | k, v = self.to_kv(kv_input).chunk(2, dim = -1) 82 | 83 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h) 84 | 85 | q = q * self.scale 86 | 87 | # attention 88 | 89 | sim = einsum('... i d, ... j d -> ... i j', q, k) 90 | 91 | if exists(mask): 92 | max_neg_value = -torch.finfo(sim.dtype).max 93 | mask = F.pad(mask, (0, latents.shape[-2]), value = True) 94 | mask = rearrange(mask, 'b j -> b 1 1 j') 95 | sim = sim.masked_fill(~mask, max_neg_value) 96 | 97 | attn = sim.softmax(dim = -1) 98 | 99 | out = einsum('... i j, ... j d -> ... 
i d', attn, v) 100 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 101 | return self.to_out(out) 102 | 103 | class PerceiverResampler(nn.Module): 104 | def __init__( 105 | self, 106 | *, 107 | dim, 108 | depth, 109 | dim_head = 64, 110 | heads = 8, 111 | num_latents = 64, 112 | num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence 113 | max_seq_len = 512, 114 | ff_mult = 4, 115 | ): 116 | super().__init__() 117 | self.pos_emb = nn.Embedding(max_seq_len, dim) 118 | 119 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 120 | 121 | self.to_latents_from_mean_pooled_seq = None 122 | 123 | if num_latents_mean_pooled > 0: 124 | self.to_latents_from_mean_pooled_seq = nn.Sequential( 125 | LayerNorm(dim), 126 | nn.Linear(dim, dim * num_latents_mean_pooled), 127 | Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled) 128 | ) 129 | 130 | self.layers = nn.ModuleList([]) 131 | for _ in range(depth): 132 | self.layers.append(nn.ModuleList([ 133 | PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads), 134 | FeedForward(dim = dim, mult = ff_mult) 135 | ])) 136 | 137 | def forward(self, x, mask = None): 138 | n, device = x.shape[1], x.device 139 | pos_emb = self.pos_emb(torch.arange(n, device = device)) 140 | 141 | x_with_pos = x + pos_emb 142 | 143 | latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0]) 144 | 145 | if exists(self.to_latents_from_mean_pooled_seq): 146 | meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)) 147 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 148 | latents = torch.cat((meanpooled_latents, latents), dim = -2) 149 | 150 | for attn, ff in self.layers: 151 | latents = attn(x_with_pos, latents, mask = mask) + latents 152 | latents = ff(latents) + latents 153 | 154 | return latents 155 | -------------------------------------------------------------------------------- /test_encodec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | from contextlib import contextmanager 5 | from copy import deepcopy 6 | import math 7 | from pathlib import Path 8 | 9 | import sys, re 10 | import random 11 | import torch 12 | from torch import optim, nn 13 | from torch.nn import functional as F 14 | from torchaudio import transforms as T 15 | from torch.utils import data 16 | from tqdm import trange 17 | from einops import rearrange 18 | import numpy as np 19 | import torchaudio 20 | 21 | from functools import partial 22 | 23 | import wandb 24 | 25 | from dataset.dataset import get_all_s3_urls, get_s3_contents, get_wds_loader, wds_preprocess, log_and_continue, is_valid_sample 26 | import webdataset as wds 27 | import time 28 | 29 | def base_plus_ext(path): 30 | """Split off all file extensions. 31 | Returns base, allext. 32 | :param path: path with extensions 33 | :param returns: path with all extensions removed 34 | """ 35 | match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) 36 | if not match: 37 | return None, None 38 | return match.group(1), match.group(2) 39 | 40 | def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): 41 | """Return function over iterator that groups key, value pairs into samples. 
42 | :param keys: function that splits the key into key and extension (base_plus_ext) 43 | :param lcase: convert suffixes to lower case (Default value = True) 44 | """ 45 | print("Running new function") 46 | current_sample = None 47 | for filesample in data: 48 | assert isinstance(filesample, dict) 49 | fname, value = filesample["fname"], filesample["data"] 50 | prefix, suffix = keys(fname) 51 | if wds.tariterators.trace: 52 | print( 53 | prefix, 54 | suffix, 55 | current_sample.keys() if isinstance(current_sample, dict) else None, 56 | ) 57 | if prefix is None: 58 | continue 59 | if lcase: 60 | suffix = suffix.lower() 61 | if current_sample is None or prefix != current_sample["__key__"]: 62 | if valid_sample(current_sample): 63 | yield current_sample 64 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) 65 | if suffix in current_sample: 66 | raise ValueError( 67 | f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}" 68 | ) 69 | if suffixes is None or suffix in suffixes: 70 | current_sample[suffix] = value 71 | if valid_sample(current_sample): 72 | yield current_sample 73 | 74 | def valid_sample(sample): 75 | """Check whether a sample is valid. 76 | :param sample: sample to be checked 77 | """ 78 | return ( 79 | sample is not None 80 | and isinstance(sample, dict) 81 | and len(list(sample.keys())) > 0 82 | and not sample.get("__bad__", False) 83 | ) 84 | 85 | 86 | # Creates and returns a text prompt given a metadata object 87 | def get_prompt_from_metadata(metadata): 88 | 89 | print(metadata) 90 | 91 | return "" 92 | 93 | 94 | def main(): 95 | 96 | args = get_all_args() 97 | 98 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 99 | print('Using device:', device) 100 | torch.manual_seed(args.seed) 101 | 102 | print("Creating data loader") 103 | 104 | preprocess_fn = partial(wds_preprocess, 105 | sample_size=args.sample_size, 106 | sample_rate=args.sample_rate, 107 | random_crop=args.random_crop, 108 | verbose=True, 109 | normalize_lufs=-12.0, 110 | metadata_prompt_funcs={"FMA_stereo": get_prompt_from_metadata} 111 | ) 112 | 113 | names = [ 114 | 115 | ] 116 | 117 | urls = get_all_s3_urls( 118 | names=names, 119 | #s3_url_prefix="", 120 | recursive=True, 121 | ) 122 | 123 | #urls = ["s3://s-harmonai/datasets/"] 124 | 125 | def print_inputs(inputs): 126 | print(f"Sample: {inputs}") 127 | return inputs 128 | 129 | wds.tariterators.group_by_keys = group_by_keys 130 | 131 | dataset = wds.DataPipeline( 132 | wds.ResampledShards(urls), # Yields a single .tar URL 133 | wds.split_by_worker, 134 | wds.map(print_inputs), 135 | wds.tarfile_to_samples(handler=log_and_continue), # Opens up a stream to the TAR file, yields files grouped by keys 136 | wds.decode(wds.torch_audio, handler=log_and_continue), 137 | wds.map(preprocess_fn, handler=log_and_continue), 138 | wds.select(is_valid_sample), 139 | wds.to_tuple("audio", "json", "timestamps", handler=log_and_continue), 140 | wds.batched(args.batch_size, partial=False) 141 | ) 142 | 143 | train_dl = wds.WebLoader(dataset, num_workers=args.num_workers) 144 | 145 | print("Creating data loader") 146 | 147 | #for json in train_dl: 148 | for epoch_num in range(1): 149 | train_iter = iter(train_dl) 150 | print(f"Starting epoch {epoch_num}") 151 | start_time = time.time() 152 | for i, sample in enumerate(train_iter): 153 | #json = next(train_dl) 154 | audio, json, timestamps = sample 155 | print(f"Epoch {epoch_num} Batch {i}") 156 | print(audio.shape) 157 | samples_per_sec = ((i+1) * 
args.batch_size) / (time.time() - start_time) 158 | print(f"Samples/sec this epoch: {samples_per_sec}") 159 | #time.sleep(5.0) 160 | 161 | if __name__ == '__main__': 162 | main() -------------------------------------------------------------------------------- /encoders/encoders.py: -------------------------------------------------------------------------------- 1 | ## Modified from https://github.com/wesbz/SoundStream/blob/main/net.py 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from perceiver_pytorch import Perceiver 7 | from einops import rearrange 8 | from blocks.blocks import Downsample1d, SelfAttention1d, ResConvBlock 9 | 10 | class AttnResEncoder1D(nn.Module): 11 | def __init__( 12 | self, 13 | global_args, 14 | n_io_channels=2, 15 | depth=12, 16 | n_attn_layers = 5, 17 | downsamples = [0, 2, 2, 2] + [2] * 8, 18 | c_mults = [128, 128, 256, 256] + [512] * 8 19 | ): 20 | super().__init__() 21 | 22 | max_depth = 12 23 | depth = min(depth, max_depth) 24 | 25 | self.act = torch.tanh 26 | 27 | c_mults = c_mults[:depth] 28 | downsamples = downsamples[:depth] 29 | 30 | conv_block = ResConvBlock 31 | 32 | attn_start_layer = depth - n_attn_layers - 1 33 | 34 | c = c_mults[0] 35 | layers = [nn.Sequential( 36 | conv_block(n_io_channels, c, c), 37 | conv_block(c, c, c), 38 | conv_block(c, c, c), 39 | conv_block(c, c, c), 40 | )] 41 | 42 | for i in range(1, depth): 43 | c = c_mults[i] 44 | #downsample = downsamples[i] 45 | c_prev = c_mults[i - 1] 46 | add_attn = i >= attn_start_layer and n_attn_layers > 0 47 | layers.append(nn.Sequential( 48 | Downsample1d(kernel="cubic"), 49 | conv_block(c_prev, c, c), 50 | SelfAttention1d( 51 | c, c // 32) if add_attn else nn.Identity(), 52 | conv_block(c, c, c), 53 | SelfAttention1d( 54 | c, c // 32) if add_attn else nn.Identity(), 55 | conv_block(c, c, c), 56 | SelfAttention1d( 57 | c, c // 32) if add_attn else nn.Identity(), 58 | conv_block(c, c, c), 59 | SelfAttention1d( 60 | c, c // 32) if add_attn else nn.Identity(), 61 | )) 62 | 63 | 64 | layers.append(nn.Sequential( 65 | conv_block(c, c, c), 66 | conv_block(c, c, c), 67 | conv_block(c, c, c), 68 | conv_block(c, c, global_args.latent_dim, is_last=True) 69 | ) 70 | ) 71 | 72 | self.net = nn.Sequential(*layers) 73 | 74 | print(f"Encoder downsampling ratio: {np.prod(downsamples[1:])}") 75 | 76 | with torch.no_grad(): 77 | for param in self.net.parameters(): 78 | param *= 0.5 79 | 80 | def forward(self, input): 81 | return self.act(self.net(input)) 82 | 83 | class GlobalEncoder(nn.Sequential): 84 | def __init__(self, latent_size, io_channels): 85 | c_in = io_channels 86 | c_mults = [128, 128] + [latent_size] * 12 87 | layers = [] 88 | c_mult_prev = c_in 89 | for i, c_mult in enumerate(c_mults): 90 | is_last = i == len(c_mults) - 1 91 | layers.append(ResConvBlock(c_mult_prev, c_mult, c_mult)) 92 | layers.append(ResConvBlock( 93 | c_mult, c_mult, c_mult, is_last=is_last)) 94 | if not is_last: 95 | layers.append(Downsample1d()) 96 | else: 97 | layers.append(nn.AdaptiveAvgPool1d(1)) 98 | layers.append(nn.Flatten()) 99 | c_mult_prev = c_mult 100 | super().__init__(*layers) 101 | 102 | class AudioPerceiverEncoder(nn.Module): 103 | def __init__(self, 104 | in_channels = 2, 105 | out_features = 256, 106 | depth=10, 107 | self_per_cross_attn=2, 108 | weight_tie_layers = False, 109 | internal_num_latents = 256, 110 | internal_latent_dim = 512, 111 | ): 112 | super().__init__() 113 | self.net = Perceiver( 114 | input_channels=in_channels, # number of channels for each token of the input 115 
| input_axis=1,# number of axis for input data (1 for audio, 2 for images, 3 for video) 116 | num_freq_bands=128,# number of freq bands, with original value (2 * K + 1) 117 | max_freq=1000., # maximum frequency, hyperparameter depending on how fine the data is 118 | depth=depth,# depth of net. The shape of the final attention mechanism will be: 119 | # depth * (cross attention -> self_per_cross_attn * self attention) 120 | num_latents=internal_num_latents, # number of latents, or induced set points, or centroids. different papers giving it different names 121 | latent_dim=internal_latent_dim, # latent dimension 122 | cross_heads=1, # number of heads for cross attention. paper said 1 123 | latent_heads=8, # number of heads for latent self attention, 8 124 | cross_dim_head=64, # number of dimensions per cross attention head 125 | latent_dim_head=64, # number of dimensions per latent self attention head 126 | num_classes=out_features, # output number of classes 127 | attn_dropout=0., 128 | ff_dropout=0., 129 | weight_tie_layers=weight_tie_layers,# whether to weight tie layers (optional, as indicated in the diagram) 130 | fourier_encode_data=True, # whether to auto-fourier encode the data, using the input_axis given. defaults to True, but can be turned off if you are fourier encoding the data yourself 131 | self_per_cross_attn=self_per_cross_attn # number of self attention blocks per cross attention 132 | ) 133 | 134 | def forward(self, input): 135 | perceiver_input = rearrange(input, 'b d n -> b n d') 136 | return self.net(perceiver_input) 137 | 138 | 139 | -------------------------------------------------------------------------------- /train_local_transformer_clap_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | 5 | import sys, os 6 | import random 7 | import torch 8 | from torch import optim, nn 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | from tqdm import trange 12 | import pytorch_lightning as pl 13 | from einops import rearrange 14 | import numpy as np 15 | import torchaudio 16 | 17 | import wandb 18 | 19 | from ema_pytorch import EMA 20 | import laion_clap 21 | 22 | from diffusion.model import ema_update 23 | from dataset.dataset import get_wds_loader 24 | from blocks.utils import InverseLR 25 | from autoencoders.transformer_ae import TransformerEncoder1D 26 | 27 | from diffusion.pqmf import CachedPQMF as PQMF 28 | 29 | 30 | def unwrap_text(str_or_tuple): 31 | if type(str_or_tuple) is tuple: 32 | return random.choice(str_or_tuple) 33 | elif type(str_or_tuple) is str: 34 | return str_or_tuple 35 | 36 | class ClapAudioEncoder(nn.Module): 37 | def __init__( 38 | self, 39 | in_channels = 1, 40 | pqmf_bands = 32, 41 | ): 42 | super().__init__() 43 | 44 | self.embedding_features = 512 45 | 46 | self.pqmf_bands = pqmf_bands 47 | 48 | if self.pqmf_bands > 1: 49 | self.pqmf = PQMF(in_channels, 70, self.pqmf_bands) 50 | 51 | self.audio_encoder = TransformerEncoder1D( 52 | in_channels = in_channels * self.pqmf_bands, 53 | out_channels = self.embedding_features, 54 | embed_dims = [96, 192, 384, 768], 55 | heads = [4, 8, 16, 32], 56 | depths = [2, 2, 2, 12], 57 | ratios = [4, 4, 2, 2], 58 | local_attn_window_size = 64 59 | ) 60 | 61 | self.pooling = nn.AdaptiveAvgPool1d(1) 62 | 63 | def forward(self, x): 64 | if self.pqmf_bands > 1: 65 | x = self.pqmf(x) 66 | x = self.audio_encoder(x) 67 | x = self.pooling(x) 68 | x = 
x.squeeze(-1) 69 | x = F.normalize(x, dim=-1) 70 | return x 71 | 72 | class ClapAudioEncoderTrainer(pl.LightningModule): 73 | def __init__(self): 74 | super().__init__() 75 | 76 | self.text_embedder = laion_clap.CLAP_Module(enable_fusion=False).requires_grad_(False).eval() 77 | 78 | self.text_embedder.load_ckpt(model_id=1) 79 | 80 | self.embedding_features = 512 81 | 82 | self.audio_encoder = ClapAudioEncoder() 83 | 84 | self.audio_encoder_ema = EMA( 85 | self.audio_encoder, 86 | beta = 0.9999, 87 | power=3/4, 88 | update_every = 1, 89 | update_after_step = 1 90 | ) 91 | 92 | def configure_optimizers(self): 93 | optimizer = optim.Adam([*self.audio_encoder.parameters()], lr=4e-5) 94 | 95 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=1e-6) 96 | 97 | return [optimizer], [scheduler] 98 | 99 | def training_step(self, batch, batch_idx): 100 | reals, jsons, timestamps = batch 101 | reals = reals[0] 102 | 103 | # Mono input 104 | reals = reals.mean(1, keepdim=True) 105 | 106 | condition_strings = [unwrap_text(json["text"][0]) for json in jsons] 107 | 108 | #print(condition_strings) 109 | 110 | with torch.cuda.amp.autocast(): 111 | 112 | with torch.no_grad(): 113 | text_embeddings = self.text_embedder.get_text_embedding(condition_strings) 114 | text_embeddings = torch.from_numpy(text_embeddings).to(self.device) 115 | 116 | audio_embeddings = self.audio_encoder(reals) 117 | 118 | cosine_sim = F.cosine_similarity(audio_embeddings, text_embeddings, dim=1) 119 | loss = -cosine_sim.mean() 120 | 121 | log_dict = { 122 | 'train/loss': loss.detach(), 123 | 'train/cosine_sim': cosine_sim.mean().detach(), 124 | 'train/lr': self.lr_schedulers().get_last_lr()[0], 125 | 'train/ema_decay': self.audio_encoder_ema.get_current_decay() 126 | } 127 | 128 | self.log_dict(log_dict, prog_bar=True, on_step=True) 129 | return loss 130 | 131 | def on_before_zero_grad(self, *args, **kwargs): 132 | self.audio_encoder_ema.update() 133 | 134 | class ExceptionCallback(pl.Callback): 135 | def on_exception(self, trainer, module, err): 136 | print(f'{type(err).__name__}: {err}') 137 | 138 | def main(): 139 | 140 | args = get_all_args() 141 | 142 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 143 | print('Using device:', device) 144 | torch.manual_seed(args.seed) 145 | 146 | names = [] 147 | 148 | metadata_prompt_funcs = {} 149 | train_dl = get_wds_loader( 150 | batch_size=args.batch_size, 151 | s3_url_prefix=None, 152 | sample_size=args.sample_size, 153 | names=names, 154 | sample_rate=args.sample_rate, 155 | num_workers=args.num_workers, 156 | recursive=True, 157 | random_crop=True, 158 | epoch_steps=10000, 159 | ) 160 | 161 | exc_callback = ExceptionCallback() 162 | 163 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1, save_last=True) 164 | 165 | clap_audio_encoder = ClapAudioEncoderTrainer() 166 | 167 | wandb_logger = pl.loggers.WandbLogger(project=args.name) 168 | wandb_logger.watch(clap_audio_encoder) 169 | push_wandb_config(wandb_logger, args) 170 | 171 | pl_trainer = pl.Trainer( 172 | devices=args.num_gpus, 173 | accelerator="gpu", 174 | num_nodes = args.num_nodes, 175 | strategy='ddp', 176 | precision=16, 177 | accumulate_grad_batches=args.accum_batches, 178 | callbacks=[ckpt_callback, exc_callback], 179 | logger=wandb_logger, 180 | log_every_n_steps=1, 181 | max_epochs=10000000, 182 | default_root_dir=args.save_dir, 183 | #gradient_clip_val=1.0, 184 | #track_grad_norm=2, 185 | #detect_anomaly = True 186 | ) 
187 | 188 | pl_trainer.fit(clap_audio_encoder, train_dl) 189 | 190 | if __name__ == '__main__': 191 | main() -------------------------------------------------------------------------------- /train_archive/train_uncond_k.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | from contextlib import contextmanager 5 | from copy import deepcopy 6 | import math 7 | from pathlib import Path 8 | 9 | import sys 10 | import torch 11 | from torch import optim, nn 12 | from torch.nn import functional as F 13 | from torch.utils import data 14 | from tqdm import trange 15 | import pytorch_lightning as pl 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | from einops import rearrange 18 | 19 | from diffusion.pqmf import CachedPQMF as PQMF 20 | import torchaudio 21 | 22 | import auraloss 23 | 24 | import wandb 25 | 26 | from aeiou.datasets import AudioDataset 27 | import k_diffusion as K 28 | 29 | from decoders.diffusion_decoder import DiffusionAttnUnet1D 30 | from diffusion.model import ema_update 31 | from viz.viz import embeddings_table, pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 32 | 33 | SIGMA_MIN = 0.0001 34 | SIGMA_MAX = 1 35 | 36 | class DiffusionUncond(pl.LightningModule): 37 | def __init__(self, global_args): 38 | super().__init__() 39 | #self.diffusion = DiffusionAttnUnet1D(io_channels=2, pqmf_bands=global_args.pqmf_bands, n_attn_layers=4) 40 | 41 | self.inner_model = DiffusionAttnUnet1D( 42 | io_channels=2, 43 | pqmf_bands = global_args.pqmf_bands, 44 | n_attn_layers=4, 45 | ) 46 | 47 | self.model = K.Denoiser(self.inner_model, sigma_data = 0.4) 48 | 49 | self.model_ema = deepcopy(self.model) 50 | 51 | self.ema_decay = global_args.ema_decay 52 | 53 | def configure_optimizers(self): 54 | return optim.Adam([*self.model.parameters()], lr=4e-5) 55 | 56 | def training_step(self, batch, batch_idx): 57 | reals = batch 58 | 59 | std = torch.std(batch).detach() 60 | 61 | with torch.cuda.amp.autocast(): 62 | sigma = K.utils.rand_log_normal([reals.shape[0]], loc=-2, scale=1.0, device=self.device) 63 | loss = self.model.loss(reals, torch.randn_like(reals), sigma).mean() 64 | 65 | log_dict = { 66 | 'train/loss': loss.detach(), 67 | 'train/std': std 68 | } 69 | 70 | self.log_dict(log_dict, prog_bar=True, on_step=True) 71 | return loss 72 | 73 | def on_before_zero_grad(self, *args, **kwargs): 74 | decay = 0.95 if self.current_epoch < 25 else self.ema_decay 75 | ema_update(self.model, self.model_ema, decay) 76 | 77 | class ExceptionCallback(pl.Callback): 78 | def on_exception(self, trainer, module, err): 79 | print(f'{type(err).__name__}: {err}', file=sys.stderr) 80 | 81 | class DemoCallback(pl.Callback): 82 | def __init__(self, global_args): 83 | super().__init__() 84 | self.demo_every = global_args.demo_every 85 | self.num_demos = global_args.num_demos 86 | self.demo_samples = global_args.sample_size 87 | self.demo_steps = global_args.demo_steps 88 | self.sample_rate = global_args.sample_rate 89 | 90 | @rank_zero_only 91 | @torch.no_grad() 92 | #def on_train_epoch_end(self, trainer, module): 93 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): 94 | last_demo_step = -1 95 | if (trainer.global_step - 1) % self.demo_every != 0 or last_demo_step == trainer.global_step: 96 | #if trainer.current_epoch % self.demo_every != 0: 97 | return 98 | 99 | last_demo_step = trainer.global_step 100 | 101 | 
print("Getting noise") 102 | noise = torch.randn([self.num_demos, 2, self.demo_samples]).to(module.device) 103 | 104 | try: 105 | 106 | print("Starting sampling") 107 | fakes = K.sampling.sample_dpm_fast(module.model_ema, noise, SIGMA_MIN, SIGMA_MAX, self.demo_steps) 108 | 109 | # Put the demos together 110 | fakes = rearrange(fakes, 'b d n -> d (b n)') 111 | 112 | log_dict = {} 113 | 114 | filename = f'demo_{trainer.global_step:08}.wav' 115 | fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu() 116 | torchaudio.save(filename, fakes, self.sample_rate) 117 | 118 | 119 | log_dict[f'demo'] = wandb.Audio(filename, 120 | sample_rate=self.sample_rate, 121 | caption=f'Reconstructed') 122 | 123 | 124 | log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) 125 | 126 | 127 | trainer.logger.experiment.log(log_dict, step=trainer.global_step) 128 | except Exception as e: 129 | print(f'{type(e).__name__}: {e}', file=sys.stderr) 130 | 131 | def main(): 132 | 133 | args = get_all_args() 134 | 135 | args.latent_dim = 0 136 | 137 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 138 | print('Using device:', device) 139 | torch.manual_seed(args.seed) 140 | 141 | train_set = AudioDataset( 142 | [args.training_dir], 143 | sample_rate=args.sample_rate, 144 | sample_size=args.sample_size, 145 | random_crop=args.random_crop, 146 | augs='Stereo()' 147 | ) 148 | 149 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, 150 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True) 151 | wandb_logger = pl.loggers.WandbLogger(project=args.name) 152 | exc_callback = ExceptionCallback() 153 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) 154 | demo_callback = DemoCallback(args) 155 | diffusion_model = DiffusionUncond(args) 156 | wandb_logger.watch(diffusion_model) 157 | push_wandb_config(wandb_logger, args) 158 | 159 | diffusion_trainer = pl.Trainer( 160 | devices=args.num_gpus, 161 | accelerator="gpu", 162 | num_nodes = args.num_nodes, 163 | strategy='ddp', 164 | precision=16, 165 | accumulate_grad_batches=args.accum_batches, 166 | callbacks=[ckpt_callback, demo_callback, exc_callback], 167 | logger=wandb_logger, 168 | log_every_n_steps=1, 169 | max_epochs=10000000, 170 | ) 171 | 172 | diffusion_trainer.fit(diffusion_model, train_dl, ckpt_path=args.ckpt_path) 173 | 174 | if __name__ == '__main__': 175 | main() -------------------------------------------------------------------------------- /diffusion/inference.py: -------------------------------------------------------------------------------- 1 | from tqdm import trange 2 | import torch 3 | from functools import partial 4 | from .utils import get_alphas_sigmas 5 | 6 | import math 7 | 8 | from scipy import integrate 9 | import torch 10 | from tqdm import trange 11 | 12 | from blocks import utils 13 | 14 | @torch.no_grad() 15 | def sample(model, signal, steps, eta): 16 | """Draws samples from a model given starting noise.""" 17 | ts = signal.new_ones([signal.shape[0]]) 18 | 19 | # Create the noise schedule 20 | t = torch.linspace(1, 0, steps + 1)[:-1] 21 | alphas, sigmas = get_alphas_sigmas(t) 22 | 23 | z_sem = model.encode(signal) 24 | 25 | #Create the noise to sample from (adds a stochastic variation) 26 | x = torch.randn_like(signal) 27 | 28 | # The sampling loop 29 | for i in trange(steps): 30 | 31 | # Get the model output (v, the predicted velocity) 32 | with torch.cuda.amp.autocast(): 33 | v = model.decode(x, ts * t[i], 
z_sem).float() 34 | 35 | # Predict the noise and the denoised image 36 | pred = x * alphas[i] - v * sigmas[i] 37 | eps = x * sigmas[i] + v * alphas[i] 38 | 39 | # If we are not on the last timestep, compute the noisy image for the 40 | # next timestep. 41 | if i < steps - 1: 42 | # If eta > 0, adjust the scaling factor for the predicted noise 43 | # downward according to the amount of additional noise to add 44 | ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ 45 | (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() 46 | adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() 47 | 48 | # Recombine the predicted noise and predicted denoised image in the 49 | # correct proportions for the next step 50 | x = pred * alphas[i + 1] + eps * adjusted_sigma 51 | 52 | # Add the correct amount of fresh noise 53 | if eta: 54 | x += torch.randn_like(x) * ddim_sigma 55 | 56 | # If we are on the last timestep, output the denoised image 57 | return pred 58 | 59 | 60 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): 61 | """Constructs the noise schedule of Karras et al. (2022).""" 62 | ramp = torch.linspace(0, 1, n) 63 | min_inv_rho = sigma_min ** (1 / rho) 64 | max_inv_rho = sigma_max ** (1 / rho) 65 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 66 | sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) 67 | return sigmas.to(device) 68 | 69 | 70 | def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): 71 | """Constructs an exponential noise schedule.""" 72 | sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() 73 | return torch.cat([sigmas, sigmas.new_zeros([1])]) 74 | 75 | def make_sample_density(config): 76 | config = config['sigma_sample_density'] 77 | if config['type'] == 'lognormal': 78 | loc = config['mean'] if 'mean' in config else config['loc'] 79 | scale = config['std'] if 'std' in config else config['scale'] 80 | return partial(utils.rand_log_normal, loc=loc, scale=scale) 81 | if config['type'] == 'loglogistic': 82 | loc = config['loc'] 83 | scale = config['scale'] 84 | min_value = config['min_value'] if 'min_value' in config else 0. 
85 | max_value = config['max_value'] if 'max_value' in config else float('inf') 86 | return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value) 87 | if config['type'] == 'loguniform': 88 | min_value = config['min_value'] 89 | max_value = config['max_value'] 90 | return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value) 91 | raise ValueError('Unknown sample density type') 92 | 93 | 94 | def to_d(x, sigma, denoised): 95 | """Converts a denoiser output to a Karras ODE derivative.""" 96 | return (x - denoised) / utils.append_dims(sigma, x.ndim) 97 | 98 | 99 | @torch.no_grad() 100 | def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, second_order=True, s_churn=0.): 101 | extra_args = {} if extra_args is None else extra_args 102 | s_in = x.new_ones([x.shape[0]]) 103 | for i in trange(len(sigmas) - 1, disable=disable): 104 | gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) 105 | eps = torch.randn_like(x) 106 | sigma_hat = sigmas[i] * (gamma + 1) 107 | if gamma > 0: 108 | x = x + eps * (sigma_hat**2 - sigmas[i]**2)**0.5 109 | denoised = model(x, sigma_hat * s_in, **extra_args) 110 | d = to_d(x, sigma_hat, denoised) 111 | if callback is not None: 112 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) 113 | dt = sigmas[i + 1] - sigma_hat 114 | if sigmas[i + 1] == 0 or not second_order: 115 | x = x + d * dt 116 | else: 117 | x_2 = x + d * dt 118 | denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) 119 | d_2 = to_d(x_2, sigmas[i + 1], denoised_2) 120 | d_prime = (d + d_2) / 2 121 | x = x + d_prime * dt 122 | return x 123 | 124 | 125 | def linear_multistep_coeff(order, t, i, j): 126 | if order - 1 > i: 127 | raise ValueError(f'Order {order} too high for step {i}') 128 | def fn(tau): 129 | prod = 1. 
130 | for k in range(order): 131 | if j == k: 132 | continue 133 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 134 | return prod 135 | return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] 136 | 137 | 138 | @torch.no_grad() 139 | def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): 140 | extra_args = {} if extra_args is None else extra_args 141 | s_in = x.new_ones([x.shape[0]]) 142 | ds = [] 143 | for i in trange(len(sigmas) - 1, disable=disable): 144 | denoised = model(x, sigmas[i] * s_in, **extra_args) 145 | d = to_d(x, sigmas[i], denoised) 146 | ds.append(d) 147 | if len(ds) > order: 148 | ds.pop(0) 149 | if callback is not None: 150 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 151 | cur_order = min(i + 1, order) 152 | coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] 153 | x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) 154 | return x -------------------------------------------------------------------------------- /train_archive/train_dvae_k.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from copy import deepcopy 5 | import json 6 | import math 7 | from pathlib import Path 8 | 9 | from prefigure.prefigure import get_all_args, push_wandb_config 10 | 11 | import accelerate 12 | import torch 13 | from torch import optim 14 | from torch import multiprocessing as mp 15 | from torch.utils import data 16 | from torchvision import utils 17 | from tqdm import tqdm 18 | 19 | from dataset.dataset import SampleDataset 20 | from decoders.diffusion_decoder import AudioDenoiserModel, Denoiser 21 | 22 | from blocks import utils 23 | 24 | from diffusion import inference 25 | 26 | def main(): 27 | 28 | args = get_all_args() 29 | 30 | mp.set_start_method(args.start_method) 31 | 32 | model_config = json.load(open(args.model_config)) 33 | 34 | size = model_config['input_size'] 35 | 36 | accelerator = accelerate.Accelerator() 37 | device = accelerator.device 38 | print('Using device:', device, flush=True) 39 | 40 | assert model_config['type'] == 'audio_v1' 41 | 42 | inner_model = AudioDenoiserModel( 43 | model_config['input_channels'], 44 | model_config['mapping_out'], 45 | model_config['depths'], 46 | model_config['channels'], 47 | model_config['self_attn_depths'], 48 | dropout_rate=model_config['dropout_rate'], 49 | #mapping_cond_dim=9, 50 | ) 51 | accelerator.print('Parameters:', utils.n_params(inner_model)) 52 | 53 | # If logging to wandb, initialize the run 54 | use_wandb = accelerator.is_main_process and args.wandb_project 55 | if use_wandb: 56 | import wandb 57 | config = vars(args) 58 | config['model_config'] = model_config 59 | config['params'] = utils.n_params(inner_model) 60 | wandb.init(project=args.wandb_project, entity=args.wandb_entity, group=args.wandb_group, config=config, save_code=True) 61 | 62 | opt = optim.AdamW(inner_model.parameters(), lr=args.lr, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3) 63 | sched = utils.InverseLR(opt, inv_gamma=50000, power=1/2, warmup=0.99) 64 | ema_sched = utils.EMAWarmup(power=2/3, max_value=0.9999) 65 | 66 | train_set = SampleDataset([args.training_dir], args) 67 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, 68 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True, drop_last=True) 69 | 70 | inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl) 
71 | 72 | if use_wandb: 73 | wandb.watch(inner_model) 74 | 75 | model = Denoiser(inner_model, sigma_data=model_config['sigma_data']) 76 | model_ema = deepcopy(model) 77 | 78 | if args.resume: 79 | ckpt = torch.load(args.resume, map_location='cpu') 80 | accelerator.unwrap_model(model.inner_model).load_state_dict(ckpt['model']) 81 | accelerator.unwrap_model(model_ema.inner_model).load_state_dict(ckpt['model_ema']) 82 | opt.load_state_dict(ckpt['opt']) 83 | sched.load_state_dict(ckpt['sched']) 84 | ema_sched.load_state_dict(ckpt['ema_sched']) 85 | epoch = ckpt['epoch'] + 1 86 | step = ckpt['step'] + 1 87 | del ckpt 88 | else: 89 | epoch = 0 90 | step = 0 91 | 92 | 93 | sigma_min = model_config['sigma_min'] 94 | sigma_max = model_config['sigma_max'] 95 | assert model_config['sigma_sample_density']['type'] == 'lognormal' 96 | sigma_mean = model_config['sigma_sample_density']['mean'] 97 | sigma_std = model_config['sigma_sample_density']['std'] 98 | 99 | @torch.no_grad() 100 | @utils.eval_mode(model_ema) 101 | def demo(): 102 | if accelerator.is_local_main_process: 103 | tqdm.write('Sampling...') 104 | filename = f'{args.name}_demo_{step:08}.png' 105 | n_per_proc = math.ceil(args.n_to_sample / accelerator.num_processes) 106 | x = torch.randn([n_per_proc, model_config['input_channels'], size[0]], device=device) * sigma_max 107 | sigmas = inference.get_sigmas_karras(50, sigma_min, sigma_max, rho=7., device=device) 108 | x_0 = inference.sample_lms(model_ema, x, sigmas, disable=not accelerator.is_local_main_process) 109 | x_0 = accelerator.gather(x_0)[:args.n_to_sample] 110 | 111 | if accelerator.is_main_process: 112 | # Demo sample logging. If you get here, great job! 113 | print(x_0.shape) 114 | 115 | def save(): 116 | accelerator.wait_for_everyone() 117 | filename = f'{args.name}_{step:08}.pth' 118 | if accelerator.is_local_main_process: 119 | tqdm.write(f'Saving to {filename}...') 120 | obj = { 121 | 'model': accelerator.unwrap_model(model.inner_model).state_dict(), 122 | 'model_ema': accelerator.unwrap_model(model_ema.inner_model).state_dict(), 123 | 'opt': opt.state_dict(), 124 | 'sched': sched.state_dict(), 125 | 'ema_sched': ema_sched.state_dict(), 126 | 'epoch': epoch, 127 | 'step': step 128 | } 129 | accelerator.save(obj, filename) 130 | if args.wandb_save_model and use_wandb: 131 | wandb.save(filename) 132 | 133 | try: 134 | while True: 135 | for batch in tqdm(train_dl, disable=not accelerator.is_local_main_process): 136 | opt.zero_grad() 137 | reals = batch[0] 138 | noise = torch.randn_like(reals) 139 | sigma = torch.distributions.LogNormal(sigma_mean, sigma_std).sample([reals.shape[0]]).to(device) 140 | loss = model.loss(reals, noise, sigma).mean() 141 | accelerator.backward(loss) 142 | opt.step() 143 | sched.step() 144 | ema_decay = ema_sched.get_value() 145 | utils.ema_update(model, model_ema, ema_decay) 146 | ema_sched.step() 147 | 148 | if accelerator.is_local_main_process: 149 | if step % 25 == 0: 150 | tqdm.write(f'Epoch: {epoch}, step: {step}, loss: {loss.item():g}') 151 | 152 | if use_wandb: 153 | log_dict = { 154 | 'epoch': epoch, 155 | 'loss': loss.item(), 156 | 'lr': sched.get_last_lr()[0], 157 | 'ema_decay': ema_decay, 158 | } 159 | wandb.log(log_dict, step=step) 160 | 161 | if step % args.demo_every == 0: 162 | demo() 163 | 164 | if step > 0 and step % args.save_every == 0: 165 | save() 166 | 167 | step += 1 168 | epoch += 1 169 | except KeyboardInterrupt: 170 | pass 171 | 172 | 173 | if __name__ == '__main__': 174 | main() 
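For reference, a minimal sketch of the same sampling path that demo() above uses, pulled out of the training loop: it composes get_sigmas_karras and sample_lms from diffusion/inference.py around an EMA-wrapped denoiser. The helper name generate_demo and the channel count, sample length, step count, and sigma range are illustrative assumptions, not values fixed by this repository or its configs.

import torch
from diffusion import inference

@torch.no_grad()
def generate_demo(model_ema, n_samples=4, channels=2, length=65536,
                  steps=50, sigma_min=1e-2, sigma_max=3.0, device='cuda'):
    # Karras et al. (2022) noise schedule, same call demo() makes (rho=7)
    sigmas = inference.get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device=device)
    # Start from noise scaled to the largest sigma, then integrate with the LMS sampler
    x = torch.randn([n_samples, channels, length], device=device) * sigma_max
    return inference.sample_lms(model_ema, x, sigmas)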
-------------------------------------------------------------------------------- /test_data_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | from contextlib import contextmanager 5 | from copy import deepcopy 6 | import math 7 | from pathlib import Path 8 | 9 | import sys, re 10 | import random 11 | import torch 12 | from torch import optim, nn 13 | from torch.nn import functional as F 14 | from torchaudio import transforms as T 15 | from torch.utils import data 16 | from tqdm import trange 17 | from einops import rearrange 18 | import numpy as np 19 | import torchaudio 20 | 21 | from functools import partial 22 | 23 | import wandb 24 | 25 | from dataset.dataset import get_all_s3_urls, get_s3_contents, get_wds_loader, wds_preprocess, log_and_continue, is_valid_sample 26 | from prompts.prompters import get_prompt_from_jmann_metadata, get_prompt_from_fma_metadata, get_prompt_from_audio_file_metadata 27 | import webdataset as wds 28 | import time 29 | 30 | def base_plus_ext(path): 31 | """Split off all file extensions. 32 | Returns base, allext. 33 | :param path: path with extensions 34 | :param returns: path with all extensions removed 35 | """ 36 | match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) 37 | if not match: 38 | return None, None 39 | return match.group(1), match.group(2) 40 | 41 | def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): 42 | """Return function over iterator that groups key, value pairs into samples. 43 | :param keys: function that splits the key into key and extension (base_plus_ext) 44 | :param lcase: convert suffixes to lower case (Default value = True) 45 | """ 46 | print("Running new function") 47 | current_sample = None 48 | for filesample in data: 49 | assert isinstance(filesample, dict) 50 | fname, value = filesample["fname"], filesample["data"] 51 | prefix, suffix = keys(fname) 52 | if wds.tariterators.trace: 53 | print( 54 | prefix, 55 | suffix, 56 | current_sample.keys() if isinstance(current_sample, dict) else None, 57 | ) 58 | if prefix is None: 59 | continue 60 | if lcase: 61 | suffix = suffix.lower() 62 | if current_sample is None or prefix != current_sample["__key__"]: 63 | if valid_sample(current_sample): 64 | yield current_sample 65 | current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) 66 | if suffix in current_sample: 67 | raise ValueError( 68 | f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}" 69 | ) 70 | if suffixes is None or suffix in suffixes: 71 | current_sample[suffix] = value 72 | if valid_sample(current_sample): 73 | yield current_sample 74 | 75 | def valid_sample(sample): 76 | """Check whether a sample is valid. 
77 | :param sample: sample to be checked 78 | """ 79 | return ( 80 | sample is not None 81 | and isinstance(sample, dict) 82 | and len(list(sample.keys())) > 0 83 | and not sample.get("__bad__", False) 84 | ) 85 | 86 | 87 | # Creates and returns a text prompt given a metadata object 88 | def get_prompt_from_metadata(metadata): 89 | 90 | #print(metadata) 91 | 92 | return "" 93 | 94 | 95 | def main(): 96 | 97 | args = get_all_args() 98 | 99 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 100 | print('Using device:', device) 101 | torch.manual_seed(args.seed) 102 | 103 | print("Creating data loader") 104 | 105 | # train_dl = iter(get_wds_loader( 106 | # batch_size=args.batch_size, 107 | # s3_url_prefix=None, 108 | # sample_size=args.sample_size, 109 | # names=names, 110 | # sample_rate=args.sample_rate, 111 | # num_workers=args.num_workers, 112 | # recursive=True, 113 | # random_crop=True 114 | # )) 115 | 116 | 117 | 118 | names = [ 119 | 120 | ] 121 | 122 | metadata_prompt_funcs = {} 123 | 124 | 125 | urls = get_all_s3_urls( 126 | names=names, 127 | #s3_url_prefix="", 128 | recursive=True, 129 | ) 130 | 131 | preprocess_fn = partial(wds_preprocess, 132 | sample_size=args.sample_size, 133 | sample_rate=args.sample_rate, 134 | random_crop=args.random_crop, 135 | #verbose=True, 136 | normalize_lufs=-12.0, 137 | metadata_prompt_funcs=metadata_prompt_funcs 138 | ) 139 | 140 | 141 | def print_inputs(inputs): 142 | print(f"Sample: {inputs}") 143 | return inputs 144 | 145 | wds.tariterators.group_by_keys = group_by_keys 146 | 147 | dataset = wds.DataPipeline( 148 | wds.ResampledShards(urls), # Yields a single .tar URL 149 | wds.split_by_worker, 150 | wds.map(print_inputs), 151 | wds.tarfile_to_samples(handler=log_and_continue), # Opens up a stream to the TAR file, yields files grouped by keys 152 | #wds.shuffle(bufsize=100, initial=10), # Pulls from iterator until initial value 153 | wds.decode(wds.torch_audio, handler=log_and_continue), 154 | #wds.LMDBCached(f"/scratch/wds_lmdb_{time.time()}", mrm -ap_size=6e12), 155 | wds.map(preprocess_fn, handler=log_and_continue), 156 | wds.select(is_valid_sample), 157 | #wds.Cached(), 158 | #wds.shuffle(bufsize=100, initial=10, handler=log_and_continue), # Pulls from iterator until initial value 159 | wds.to_tuple("audio", "json", "timestamps", handler=log_and_continue), 160 | wds.batched(args.batch_size, partial=False) 161 | ) 162 | 163 | train_dl = wds.WebLoader(dataset, num_workers=args.num_workers) 164 | 165 | print("Creating data loader") 166 | 167 | max_seconds_total = 0 168 | 169 | #for json in train_dl: 170 | for epoch_num in range(1): 171 | train_iter = iter(train_dl) 172 | print(f"Starting epoch {epoch_num}") 173 | start_time = time.time() 174 | for i, sample in enumerate(train_iter): 175 | #json = next(train_dl) 176 | audio, jsons, timestamps = sample 177 | print(f"Epoch {epoch_num} Batch {i}") 178 | for json in jsons: 179 | prompt = json["prompt"][0] 180 | seconds_total = json["seconds_total"][0].item() 181 | if seconds_total > max_seconds_total: 182 | max_seconds_total = seconds_total 183 | print(prompt) 184 | print(max_seconds_total) 185 | 186 | print(max_seconds_total) 187 | # print(audio.shape) 188 | samples_per_sec = ((i+1) * args.batch_size) / (time.time() - start_time) 189 | #print(f"Samples/sec this epoch: {samples_per_sec}") 190 | #time.sleep(5.0) 191 | 192 | if __name__ == '__main__': 193 | main() -------------------------------------------------------------------------------- 
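The DataPipeline built above can be exercised without S3 by pointing the same stages at local tar shards; the following is a hedged sketch of that reduced setup, in which the shard path, sample_size, sample_rate, and batch size are placeholder assumptions (get_wds_loader in dataset/dataset.py remains the supported entry point).

from functools import partial
import webdataset as wds
from dataset.dataset import wds_preprocess, log_and_continue, is_valid_sample

urls = ["./shards/audio-000000.tar"]  # placeholder local shard path

preprocess_fn = partial(wds_preprocess, sample_size=65536, sample_rate=48000, random_crop=True)

dataset = wds.DataPipeline(
    wds.ResampledShards(urls),                          # yields shard URLs, as in the script above
    wds.split_by_worker,
    wds.tarfile_to_samples(handler=log_and_continue),   # streams each tar, groups files by key
    wds.decode(wds.torch_audio, handler=log_and_continue),
    wds.map(preprocess_fn, handler=log_and_continue),
    wds.select(is_valid_sample),
    wds.to_tuple("audio", "json", "timestamps", handler=log_and_continue),
    wds.batched(4, partial=False),
)

loader = wds.WebLoader(dataset, num_workers=2)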
/train_archive/train_ad_uncond_full.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | from contextlib import contextmanager 5 | from copy import deepcopy 6 | import math 7 | from pathlib import Path 8 | 9 | import sys 10 | import torch 11 | from torch import optim, nn 12 | from torch.nn import functional as F 13 | from torch.utils import data 14 | from tqdm import trange 15 | import pytorch_lightning as pl 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | from einops import rearrange 18 | 19 | from diffusion.pqmf import CachedPQMF as PQMF 20 | import torchaudio 21 | 22 | from ema_pytorch import EMA 23 | 24 | import auraloss 25 | 26 | import wandb 27 | 28 | from aeiou.datasets import AudioDataset 29 | 30 | from audio_diffusion_pytorch import AudioDiffusionModel, LogNormalDistribution, Distribution 31 | from diffusion.model import ema_update 32 | from aeiou.viz import audio_spectrogram_image 33 | 34 | class DiffusionUncond(pl.LightningModule): 35 | def __init__(self, global_args): 36 | super().__init__() 37 | 38 | self.diffusion = AudioDiffusionModel( 39 | in_channels=2, 40 | channels=256, 41 | patch_blocks=1, 42 | patch_factor=1, 43 | multipliers=[4, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3], 44 | factors=[1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2], 45 | num_blocks=[1, 1, 1, 2, 2, 2, 2, 4, 4, 4, 4], 46 | attentions=[0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2], 47 | attention_heads=8, 48 | attention_features=64, 49 | attention_multiplier=2, 50 | attention_use_rel_pos=False, 51 | resnet_groups=8, 52 | kernel_multiplier_downsample=2, 53 | use_nearest_upsample=False, 54 | use_skip_scale=True, 55 | use_context_time=True, 56 | use_magnitude_channels=False, 57 | use_stft=True, 58 | stft_num_fft=1023, 59 | stft_hop_length=256, 60 | diffusion_type="v", 61 | diffusion_sigma_distribution=UniformDistribution(), 62 | ) 63 | 64 | self.diffusion_ema = EMA( 65 | self.diffusion, 66 | beta=0.9999, 67 | power=3/4, 68 | update_after_step=100, 69 | update_every=1 70 | ) 71 | 72 | def configure_optimizers(self): 73 | return optim.Adam([*self.diffusion.parameters()], lr=4e-5) 74 | 75 | def training_step(self, batch, batch_idx): 76 | reals = batch 77 | 78 | batch_stdev = torch.std(batch.detach()) 79 | 80 | with torch.cuda.amp.autocast(): 81 | loss = self.diffusion(reals) 82 | 83 | log_dict = { 84 | 'train/loss': loss.detach(), 85 | 'train/data_stdev': batch_stdev 86 | } 87 | 88 | self.log_dict(log_dict, prog_bar=True, on_step=True) 89 | return loss 90 | 91 | def on_before_zero_grad(self, *args, **kwargs): 92 | self.diffusion_ema.update() 93 | 94 | class ExceptionCallback(pl.Callback): 95 | def on_exception(self, trainer, module, err): 96 | print(f'{type(err).__name__}: {err}', file=sys.stderr) 97 | 98 | 99 | class DemoCallback(pl.Callback): 100 | def __init__(self, global_args): 101 | super().__init__() 102 | self.demo_every = global_args.demo_every 103 | self.num_demos = global_args.num_demos 104 | self.demo_samples = global_args.sample_size 105 | self.demo_steps = global_args.demo_steps 106 | self.sample_rate = global_args.sample_rate 107 | 108 | @rank_zero_only 109 | @torch.no_grad() 110 | #def on_train_epoch_end(self, trainer, module): 111 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): 112 | last_demo_step = -1 113 | if (trainer.global_step - 1) % self.demo_every != 0 or last_demo_step == trainer.global_step: 114 | #if trainer.current_epoch % 
self.demo_every != 0: 115 | return 116 | 117 | last_demo_step = trainer.global_step 118 | 119 | try: 120 | print("Creating noise") 121 | noise = torch.randn([self.num_demos, 2, self.demo_samples]).to(module.device) 122 | 123 | print("Starting sampling") 124 | 125 | fakes = module.diffusion_ema.ema_model.sample(noise=noise, num_steps=self.demo_steps) 126 | 127 | print("Rearranging demos") 128 | # Put the demos together 129 | fakes = rearrange(fakes, 'b d n -> d (b n)') 130 | 131 | log_dict = {} 132 | 133 | filename = f'demo_{trainer.global_step:08}.wav' 134 | fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu() 135 | torchaudio.save(filename, fakes, self.sample_rate) 136 | 137 | 138 | log_dict[f'demo'] = wandb.Audio(filename, 139 | sample_rate=self.sample_rate, 140 | caption=f'Demos') 141 | 142 | 143 | log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) 144 | 145 | 146 | trainer.logger.experiment.log(log_dict, step=trainer.global_step) 147 | except Exception as e: 148 | print(f'{type(e).__name__}: {e}') 149 | 150 | def main(): 151 | 152 | args = get_all_args() 153 | 154 | args.latent_dim = 0 155 | #args.random_crop = False 156 | 157 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 158 | print('Using device:', device) 159 | torch.manual_seed(args.seed) 160 | 161 | train_set = AudioDataset( 162 | [args.training_dir], 163 | sample_rate=args.sample_rate, 164 | sample_size=args.sample_size, 165 | random_crop=args.random_crop, 166 | augs='Stereo()' 167 | ) 168 | 169 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, 170 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True) 171 | wandb_logger = pl.loggers.WandbLogger(project=args.name) 172 | exc_callback = ExceptionCallback() 173 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) 174 | demo_callback = DemoCallback(args) 175 | diffusion_model = DiffusionUncond(args) 176 | wandb_logger.watch(diffusion_model) 177 | push_wandb_config(wandb_logger, args) 178 | 179 | diffusion_trainer = pl.Trainer( 180 | devices=args.num_gpus, 181 | accelerator="gpu", 182 | num_nodes = args.num_nodes, 183 | strategy='ddp_find_unused_parameters_false', 184 | precision=16, 185 | accumulate_grad_batches=args.accum_batches, 186 | callbacks=[ckpt_callback, demo_callback, exc_callback], 187 | logger=wandb_logger, 188 | log_every_n_steps=1, 189 | max_epochs=10000000, 190 | ) 191 | 192 | diffusion_trainer.fit(diffusion_model, train_dl, ckpt_path=args.ckpt_path) 193 | 194 | if __name__ == '__main__': 195 | main() -------------------------------------------------------------------------------- /diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import random 4 | import math 5 | 6 | # Define the diffusion noise schedule 7 | def get_alphas_sigmas(t): 8 | return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) 9 | 10 | class PadCrop(nn.Module): 11 | def __init__(self, n_samples, randomize=True): 12 | super().__init__() 13 | self.n_samples = n_samples 14 | self.randomize = randomize 15 | 16 | def __call__(self, signal): 17 | n, s = signal.shape 18 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 19 | end = start + self.n_samples 20 | output = signal.new_zeros([n, self.n_samples]) 21 | output[:, :min(s, self.n_samples)] = signal[:, start:end] 22 | return output 23 | 24 | from typing 
import Tuple 25 | 26 | class PadCrop_Normalized_T(nn.Module): 27 | 28 | def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): 29 | 30 | super().__init__() 31 | 32 | self.n_samples = n_samples 33 | self.sample_rate = sample_rate 34 | self.randomize = randomize 35 | 36 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: 37 | 38 | n_channels, n_samples = source.shape 39 | 40 | upper_bound = max(0, n_samples - self.n_samples) 41 | 42 | offset = 0 43 | if(self.randomize and n_samples > self.n_samples): 44 | offset = random.randint(0, upper_bound + 1) 45 | 46 | t_start = offset / (upper_bound + self.n_samples) 47 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 48 | 49 | chunk = source.new_zeros([n_channels, self.n_samples]) 50 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 51 | 52 | seconds_start = math.floor(offset / self.sample_rate) 53 | seconds_total = math.ceil(n_samples / self.sample_rate) 54 | 55 | return ( 56 | chunk, 57 | t_start, 58 | t_end, 59 | seconds_start, 60 | seconds_total 61 | ) 62 | 63 | class PhaseFlipper(nn.Module): 64 | "she was PHAAAAAAA-AAAASE FLIPPER, a random invert yeah" 65 | def __init__(self, p=0.5): 66 | super().__init__() 67 | self.p = p 68 | def __call__(self, signal): 69 | return -signal if (random.random() < self.p) else signal 70 | 71 | 72 | class FillTheNoise(nn.Module): 73 | "randomly adds a bit of noise, just to spice things up" 74 | def __init__(self, p=0.33): 75 | super().__init__() 76 | self.p = p 77 | def __call__(self, signal): 78 | return signal + 0.25*random.random()*(2*torch.rand_like(signal)-1) if (random.random() < self.p) else signal 79 | 80 | 81 | class OneMinus(nn.Module): 82 | "aka Destructo: subtracts the signal from +/- 1, just to spice things up" 83 | def __init__(self, p=0.2): 84 | super().__init__() 85 | self.p = p 86 | def __call__(self, signal): 87 | return 0.9*torch.sign(signal) - signal if (random.random() < self.p) else signal 88 | 89 | class RandPool(nn.Module): 90 | def __init__(self, p=0.2): 91 | self.p, self.maxkern = p, 100 92 | def __call__(self, signal): 93 | if (random.random() < self.p): 94 | ksize = int(random.random()*self.maxkern) 95 | avger = nn.AvgPool1d(kernel_size=ksize, stride=1, padding=1) 96 | return avger(signal) 97 | else: 98 | return signal 99 | 100 | 101 | class NormInputs(nn.Module): 102 | "useful for quiet inputs. 
intended to be part of augmentation chain; not activated by default" 103 | def __init__(self, do_norm=False): 104 | super().__init__() 105 | self.do_norm = do_norm 106 | self.eps = 1e-2 107 | def __call__(self, signal): 108 | return signal if (not self.do_norm) else signal/(torch.amax(signal,-1)[0] + self.eps) 109 | 110 | class Mono(nn.Module): 111 | def __call__(self, signal): 112 | return torch.mean(signal, dim=0) if len(signal.shape) > 1 else signal 113 | 114 | class Stereo(nn.Module): 115 | def __call__(self, signal): 116 | signal_shape = signal.shape 117 | # Check if it's mono 118 | if len(signal_shape) == 1: # s -> 2, s 119 | signal = signal.unsqueeze(0).repeat(2, 1) 120 | elif len(signal_shape) == 2: 121 | if signal_shape[0] == 1: #1, s -> 2, s 122 | signal = signal.repeat(2, 1) 123 | elif signal_shape[0] > 2: #?, s -> 2,s 124 | signal = signal[:2, :] 125 | 126 | return signal 127 | 128 | class RandomGain(nn.Module): 129 | def __init__(self, min_gain, max_gain): 130 | super().__init__() 131 | self.min_gain = min_gain 132 | self.max_gain = max_gain 133 | 134 | def __call__(self, signal): 135 | gain = random.uniform(self.min_gain, self.max_gain) 136 | signal = signal * gain 137 | 138 | return signal 139 | 140 | class MidSideEncoding(nn.Module): 141 | def __call__(self, signal): 142 | #signal_shape should be 2, s 143 | left = signal[0] 144 | right = signal[1] 145 | mid = (left + right) / 2 146 | side = (left - right) / 2 147 | signal[0] = mid 148 | signal[1] = side 149 | 150 | return signal 151 | 152 | 153 | 154 | # Taken from https://github.com/serkansulun/pytorch-pixelshuffle1d/blob/master/pixelshuffle1d.py 155 | 156 | class PixelShuffle1D(torch.nn.Module): 157 | """ 158 | 1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf 159 | Upscales sample length, downscales channel length 160 | "short" is input, "long" is output 161 | """ 162 | def __init__(self, upscale_factor): 163 | super(PixelShuffle1D, self).__init__() 164 | self.upscale_factor = upscale_factor 165 | 166 | def forward(self, x): 167 | batch_size = x.shape[0] 168 | short_channel_len = x.shape[1] 169 | short_width = x.shape[2] 170 | 171 | long_channel_len = short_channel_len // self.upscale_factor 172 | long_width = self.upscale_factor * short_width 173 | 174 | x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width]) 175 | x = x.permute(0, 2, 3, 1).contiguous() 176 | x = x.view(batch_size, long_channel_len, long_width) 177 | 178 | return x 179 | 180 | class PixelUnshuffle1D(torch.nn.Module): 181 | """ 182 | Inverse of 1D pixel shuffler 183 | Upscales channel length, downscales sample length 184 | "long" is input, "short" is output 185 | """ 186 | def __init__(self, downscale_factor): 187 | super(PixelUnshuffle1D, self).__init__() 188 | self.downscale_factor = downscale_factor 189 | 190 | def forward(self, x): 191 | batch_size = x.shape[0] 192 | long_channel_len = x.shape[1] 193 | long_width = x.shape[2] 194 | 195 | short_channel_len = long_channel_len * self.downscale_factor 196 | short_width = long_width // self.downscale_factor 197 | 198 | x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor]) 199 | x = x.permute(0, 3, 1, 2).contiguous() 200 | x = x.view([batch_size, short_channel_len, short_width]) 201 | return x -------------------------------------------------------------------------------- /prune_ckpt.py: -------------------------------------------------------------------------------- 1 | #@title Imports and definitions 2 | import argparse 3 | 
from contextlib import contextmanager 4 | from copy import deepcopy 5 | import math 6 | from pathlib import Path 7 | 8 | import sys 9 | import gc 10 | 11 | from autoencoders.soundstream import SoundStreamXLEncoder, SoundStreamXLDecoder 12 | from autoencoders.models import AudioAutoencoder 13 | from audio_encoders_pytorch import Encoder1d 14 | from ema_pytorch import EMA 15 | from audio_diffusion_pytorch.modules import UNetCFG1d 16 | 17 | from audio_diffusion_pytorch import T5Embedder, NumberEmbedder 18 | 19 | import torch 20 | from torch import optim, nn 21 | from torch.nn import functional as F 22 | from torch.utils import data 23 | from tqdm import trange 24 | from einops import rearrange 25 | 26 | import torchaudio 27 | from decoders.diffusion_decoder import DiffusionAttnUnet1D 28 | import numpy as np 29 | 30 | import random 31 | from diffusion.utils import Stereo, PadCrop 32 | from glob import glob 33 | 34 | from torch.nn.parameter import Parameter 35 | 36 | class LatentAudioDiffusionAutoencoder(nn.Module): 37 | def __init__(self, autoencoder: AudioAutoencoder): 38 | super().__init__() 39 | 40 | 41 | self.latent_dim = autoencoder.latent_dim 42 | 43 | self.second_stage_latent_dim = 32 44 | 45 | factors = [2, 2, 2, 2] 46 | 47 | self.latent_downsampling_ratio = np.prod(factors) 48 | 49 | self.downsampling_ratio = autoencoder.downsampling_ratio * self.latent_downsampling_ratio 50 | 51 | self.latent_encoder = Encoder1d( 52 | in_channels=self.latent_dim, 53 | out_channels = self.second_stage_latent_dim, 54 | channels = 128, 55 | multipliers = [1, 2, 4, 8, 8], 56 | factors = factors, 57 | num_blocks = [8, 8, 8, 8], 58 | ) 59 | 60 | self.diffusion = DiffusionAttnUnet1D( 61 | io_channels=self.latent_dim, 62 | cond_dim = self.second_stage_latent_dim, 63 | n_attn_layers=0, 64 | c_mults=[512] * 10, 65 | depth=10 66 | ) 67 | 68 | self.autoencoder = autoencoder 69 | 70 | self.autoencoder.requires_grad_(False) 71 | 72 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 73 | 74 | def encode(self, reals): 75 | first_stage_latents = self.autoencoder.encode(reals) 76 | 77 | second_stage_latents = self.latent_encoder(first_stage_latents) 78 | 79 | second_stage_latents = torch.tanh(second_stage_latents) 80 | 81 | return second_stage_latents 82 | 83 | def decode(self, latents, steps=250, device="cuda"): 84 | first_stage_latent_noise = torch.randn([latents.shape[0], self.latent_dim, latents.shape[2]*self.latent_downsampling_ratio]).to(device) 85 | 86 | t = torch.linspace(1, 0, steps + 1, device=device)[:-1] 87 | 88 | step_list = get_spliced_ddpm_cosine_schedule(t) 89 | 90 | first_stage_sampled = sampling.iplms_sample(self.diffusion, first_stage_latent_noise, step_list, {"cond":latents}) 91 | #first_stage_sampled = sample(self.diffusion, first_stage_latent_noise, steps, 0, cond=latents) 92 | decoded = self.autoencoder.decode(first_stage_sampled) 93 | return decoded 94 | 95 | class StackedAELatentDiffusionCond(nn.Module): 96 | def __init__(self, latent_ae: LatentAudioDiffusionAutoencoder, diffusion_config): 97 | super().__init__() 98 | 99 | self.latent_dim = latent_ae.second_stage_latent_dim 100 | self.downsampling_ratio = latent_ae.downsampling_ratio 101 | 102 | embedding_max_len = 128 103 | 104 | self.embedder = T5Embedder(model='t5-base', max_length=embedding_max_len).requires_grad_(False) 105 | 106 | self.embedding_features = 768 107 | 108 | self.timestamp_embedder = NumberEmbedder(features=self.embedding_features) 109 | 110 | self.diffusion = UNetCFG1d(**diffusion_config) 111 | 112 | # with 
torch.no_grad(): 113 | # for param in self.diffusion.parameters(): 114 | # param *= 0.5 115 | 116 | # self.diffusion_ema = EMA( 117 | # self.diffusion, 118 | # beta = 0.9999, 119 | # power=3/4, 120 | # update_every = 1, 121 | # update_after_step = 1000 122 | # ) 123 | 124 | self.autoencoder = latent_ae 125 | 126 | self.autoencoder.requires_grad_(False) 127 | 128 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 129 | 130 | def encode(self, reals): 131 | return self.autoencoder.encode(reals) 132 | 133 | def decode(self, latents, steps=250, device="cuda"): 134 | return self.autoencoder.decode(latents, steps, device=device) 135 | 136 | 137 | def prune_ckpt_weights(stacked_state_dict): 138 | new_state_dict = {} 139 | for name, param in stacked_state_dict.items(): 140 | if name.startswith("diffusion_ema.ema_model."): 141 | new_name = name.replace("diffusion_ema.ema_model.", "diffusion.") 142 | if isinstance(param, Parameter): 143 | # backwards compatibility for serialized parameters 144 | param = param.data 145 | new_state_dict[new_name] = param 146 | elif name.startswith("autoencoder") or name.startswith("timestamp_embedder"): 147 | new_state_dict[name] = param 148 | 149 | return new_state_dict 150 | 151 | 152 | if __name__ == "__main__": 153 | 154 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 155 | parser.add_argument('--ckpt_path', help='Path to the checkpoint to be pruned') 156 | #parser.add_argument('--output_path', help='Path to the checkpoint to be pruned') 157 | args = parser.parse_args() 158 | 159 | print("Creating the model...") 160 | 161 | first_stage_config = {"capacity": 64, "c_mults": [2, 4, 8, 16, 32], "strides": [2, 2, 2, 2, 2], "latent_dim": 32} 162 | 163 | first_stage_autoencoder = AudioAutoencoder( 164 | **first_stage_config 165 | ).eval() 166 | 167 | diffusion_config = dict( 168 | in_channels = 32, 169 | context_embedding_features = 768, 170 | context_embedding_max_length = 128 + 2, #2 for timestep embeds 171 | channels = 256, 172 | resnet_groups = 8, 173 | kernel_multiplier_downsample = 2, 174 | multipliers = [2, 3, 4, 4], 175 | factors = [1, 2, 4], 176 | num_blocks = [3, 3, 3], 177 | attentions = [0, 3, 3, 3], 178 | attention_heads = 12, 179 | attention_features = 64, 180 | attention_multiplier = 4, 181 | attention_use_rel_pos=True, 182 | attention_rel_pos_max_distance=2048, 183 | attention_rel_pos_num_buckets=64, 184 | use_nearest_upsample = False, 185 | use_skip_scale = True, 186 | use_context_time = True 187 | ) 188 | 189 | latent_diffae = LatentAudioDiffusionAutoencoder(autoencoder=first_stage_autoencoder) 190 | 191 | model = StackedAELatentDiffusionCond(latent_diffae, diffusion_config=diffusion_config) 192 | 193 | ckpt_state_dict = torch.load(args.ckpt_path)["state_dict"] 194 | #print(ckpt_state_dict.keys()) 195 | 196 | new_ckpt = {} 197 | 198 | new_ckpt["state_dict"] = prune_ckpt_weights(ckpt_state_dict) 199 | 200 | new_ckpt["diffusion_config"] = diffusion_config 201 | 202 | model.load_state_dict(new_ckpt["state_dict"], strict=False) 203 | 204 | torch.save(new_ckpt, f'./pruned.ckpt') -------------------------------------------------------------------------------- /blocks/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import warnings 3 | 4 | import torch 5 | from torch import optim 6 | 7 | def append_dims(x, target_dims): 8 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 9 | 
dims_to_append = target_dims - x.ndim 10 | if dims_to_append < 0: 11 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 12 | return x[(...,) + (None,) * dims_to_append] 13 | 14 | 15 | def n_params(module): 16 | """Returns the number of trainable parameters in a module.""" 17 | return sum(p.numel() for p in module.parameters()) 18 | 19 | 20 | @contextmanager 21 | def train_mode(model, mode=True): 22 | """A context manager that places a model into training mode and restores 23 | the previous mode on exit.""" 24 | modes = [module.training for module in model.modules()] 25 | try: 26 | yield model.train(mode) 27 | finally: 28 | for i, module in enumerate(model.modules()): 29 | module.training = modes[i] 30 | 31 | 32 | def eval_mode(model): 33 | """A context manager that places a model into evaluation mode and restores 34 | the previous mode on exit.""" 35 | return train_mode(model, False) 36 | 37 | 38 | @torch.no_grad() 39 | def ema_update(model, averaged_model, decay): 40 | """Incorporates updated model parameters into an exponential moving averaged 41 | version of a model. It should be called after each optimizer step.""" 42 | model_params = dict(model.named_parameters()) 43 | averaged_params = dict(averaged_model.named_parameters()) 44 | assert model_params.keys() == averaged_params.keys() 45 | 46 | for name, param in model_params.items(): 47 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 48 | 49 | model_buffers = dict(model.named_buffers()) 50 | averaged_buffers = dict(averaged_model.named_buffers()) 51 | assert model_buffers.keys() == averaged_buffers.keys() 52 | 53 | for name, buf in model_buffers.items(): 54 | averaged_buffers[name].copy_(buf) 55 | 56 | 57 | class EMAWarmup: 58 | """Implements an EMA warmup using an inverse decay schedule. 59 | If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are 60 | good values for models you plan to train for a million or more steps (reaches decay 61 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 62 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 63 | 215.4k steps). 64 | Args: 65 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 66 | power (float): Exponential factor of EMA warmup. Default: 1. 67 | min_value (float): The minimum EMA decay rate. Default: 0. 68 | max_value (float): The maximum EMA decay rate. Default: 1. 69 | start_at (int): The epoch to start averaging at. Default: 0. 70 | last_epoch (int): The index of last epoch. Default: 0. 71 | """ 72 | 73 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, 74 | last_epoch=0): 75 | self.inv_gamma = inv_gamma 76 | self.power = power 77 | self.min_value = min_value 78 | self.max_value = max_value 79 | self.start_at = start_at 80 | self.last_epoch = last_epoch 81 | 82 | def state_dict(self): 83 | """Returns the state of the class as a :class:`dict`.""" 84 | return dict(self.__dict__.items()) 85 | 86 | def load_state_dict(self, state_dict): 87 | """Loads the class's state. 88 | Args: 89 | state_dict (dict): scaler state. Should be an object returned 90 | from a call to :meth:`state_dict`. 91 | """ 92 | self.__dict__.update(state_dict) 93 | 94 | def get_value(self): 95 | """Gets the current EMA decay rate.""" 96 | epoch = max(0, self.last_epoch - self.start_at) 97 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power 98 | return 0. 
if epoch < 0 else min(self.max_value, max(self.min_value, value)) 99 | 100 | def step(self): 101 | """Updates the step count.""" 102 | self.last_epoch += 1 103 | 104 | 105 | class InverseLR(optim.lr_scheduler._LRScheduler): 106 | """Implements an inverse decay learning rate schedule with an optional exponential 107 | warmup. When last_epoch=-1, sets initial lr as lr. 108 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 109 | (1 / 2)**power of its original value. 110 | Args: 111 | optimizer (Optimizer): Wrapped optimizer. 112 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 113 | power (float): Exponential factor of learning rate decay. Default: 1. 114 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 115 | Default: 0. 116 | final_lr (float): The final learning rate. Default: 0. 117 | last_epoch (int): The index of last epoch. Default: -1. 118 | verbose (bool): If ``True``, prints a message to stdout for 119 | each update. Default: ``False``. 120 | """ 121 | 122 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 123 | last_epoch=-1, verbose=False): 124 | self.inv_gamma = inv_gamma 125 | self.power = power 126 | if not 0. <= warmup < 1: 127 | raise ValueError('Invalid value for warmup') 128 | self.warmup = warmup 129 | self.final_lr = final_lr 130 | super().__init__(optimizer, last_epoch, verbose) 131 | 132 | def get_lr(self): 133 | if not self._get_lr_called_within_step: 134 | warnings.warn("To get the last learning rate computed by the scheduler, " 135 | "please use `get_last_lr()`.") 136 | 137 | return self._get_closed_form_lr() 138 | 139 | def _get_closed_form_lr(self): 140 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 141 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 142 | return [warmup * max(self.final_lr, base_lr * lr_mult) 143 | for base_lr in self.base_lrs] 144 | 145 | 146 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): 147 | """Draws samples from an lognormal distribution.""" 148 | return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() 149 | 150 | 151 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 152 | """Draws samples from an optionally truncated log-logistic distribution.""" 153 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) 154 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) 155 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid() 156 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid() 157 | u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf 158 | return u.logit().mul(scale).add(loc).exp().to(dtype) 159 | 160 | 161 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): 162 | """Draws samples from an log-uniform distribution.""" 163 | min_value = math.log(min_value) 164 | max_value = math.log(max_value) 165 | return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() -------------------------------------------------------------------------------- /autoencoders/soundstream.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from vector_quantize_pytorch import ResidualVQ 6 | 7 | def mod_sigmoid(x): 8 | 
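    # Squashes any real input into roughly (1e-7, 2): a sigmoid raised to the 2.3 power,
    # scaled by 2, plus a small epsilon so the output never reaches exactly zero.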
return 2 * torch.sigmoid(x)**2.3 + 1e-7 9 | 10 | # Generator 11 | class CausalConv1d(nn.Conv1d): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) 15 | 16 | def forward(self, x): 17 | return self._conv_forward(F.pad(x, [self.causal_padding, 0]), self.weight, self.bias) 18 | 19 | 20 | class CausalConvTranspose1d(nn.ConvTranspose1d): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1) + self.output_padding[0] + 1 - self.stride[0] 24 | 25 | def forward(self, x, output_size=None): 26 | if self.padding_mode != 'zeros': 27 | raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d') 28 | 29 | assert isinstance(self.padding, tuple) 30 | output_padding = self._output_padding( 31 | x, output_size, self.stride, self.padding, self.kernel_size, self.dilation) 32 | return F.conv_transpose1d( 33 | x, self.weight, self.bias, self.stride, self.padding, 34 | output_padding, self.groups, self.dilation)[...,:-self.causal_padding] 35 | 36 | 37 | class ResidualUnit(nn.Module): 38 | def __init__(self, in_channels, out_channels, dilation): 39 | super().__init__() 40 | 41 | self.dilation = dilation 42 | 43 | self.layers = nn.Sequential( 44 | CausalConv1d(in_channels=in_channels, out_channels=out_channels, 45 | kernel_size=7, dilation=dilation), 46 | nn.ELU(), 47 | nn.Conv1d(in_channels=out_channels, out_channels=out_channels, 48 | kernel_size=1) 49 | ) 50 | 51 | def forward(self, x): 52 | return x + self.layers(x) 53 | 54 | 55 | class EncoderBlock(nn.Module): 56 | def __init__(self, in_channels, out_channels, stride): 57 | super().__init__() 58 | 59 | self.layers = nn.Sequential( 60 | ResidualUnit(in_channels=in_channels, 61 | out_channels=in_channels, dilation=1), 62 | nn.ELU(), 63 | ResidualUnit(in_channels=in_channels, 64 | out_channels=in_channels, dilation=3), 65 | nn.ELU(), 66 | ResidualUnit(in_channels=in_channels, 67 | out_channels=in_channels, dilation=9), 68 | nn.ELU(), 69 | ResidualUnit(in_channels=in_channels, 70 | out_channels=in_channels, dilation=1), 71 | nn.ELU(), 72 | ResidualUnit(in_channels=in_channels, 73 | out_channels=in_channels, dilation=3), 74 | nn.ELU(), 75 | ResidualUnit(in_channels=in_channels, 76 | out_channels=in_channels, dilation=9), 77 | nn.ELU(), 78 | CausalConv1d(in_channels=in_channels, out_channels=out_channels, 79 | kernel_size=2*stride, stride=stride) 80 | ) 81 | 82 | def forward(self, x): 83 | return self.layers(x) 84 | 85 | 86 | class DecoderBlock(nn.Module): 87 | def __init__(self, in_channels, out_channels, stride): 88 | super().__init__() 89 | 90 | self.layers = nn.Sequential( 91 | CausalConvTranspose1d(in_channels=in_channels, 92 | out_channels=out_channels, 93 | kernel_size=2*stride, stride=stride), 94 | nn.ELU(), 95 | ResidualUnit(in_channels=out_channels, out_channels=out_channels, 96 | dilation=1), 97 | nn.ELU(), 98 | ResidualUnit(in_channels=out_channels, out_channels=out_channels, 99 | dilation=3), 100 | nn.ELU(), 101 | ResidualUnit(in_channels=out_channels, out_channels=out_channels, 102 | dilation=9), 103 | nn.ELU(), 104 | ResidualUnit(in_channels=out_channels, out_channels=out_channels, 105 | dilation=1), 106 | nn.ELU(), 107 | ResidualUnit(in_channels=out_channels, out_channels=out_channels, 108 | dilation=3), 109 | nn.ELU(), 110 | ResidualUnit(in_channels=out_channels, out_channels=out_channels, 111 | dilation=9), 112 | ) 113 | 114 | def 
forward(self, x): 115 | return self.layers(x) 116 | 117 | 118 | class SoundStreamXLEncoder(nn.Module): 119 | def __init__(self, in_channels=2, capacity=32, latent_dim=128, c_mults = [2, 4, 4, 4, 8, 16], strides = [2, 2, 2, 4, 5, 8]): 120 | super().__init__() 121 | 122 | c_mults = [1] + c_mults 123 | 124 | self.depth = len(c_mults) 125 | 126 | layers = [ 127 | CausalConv1d(in_channels=in_channels, out_channels=c_mults[0] * capacity, kernel_size=7), 128 | nn.ELU() 129 | ] 130 | 131 | for i in range(self.depth-1): 132 | layers.append(EncoderBlock(in_channels=c_mults[i]*capacity, out_channels=c_mults[i+1]*capacity, stride=strides[i])) 133 | layers.append(nn.ELU()) 134 | 135 | layers.append(CausalConv1d(in_channels=c_mults[-1]*capacity, out_channels=latent_dim, kernel_size=3)) 136 | 137 | self.layers = nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | return self.layers(x) 141 | 142 | 143 | class SoundStreamXLDecoder(nn.Module): 144 | def __init__(self, out_channels=2, capacity=32, latent_dim=128, c_mults = [2, 4, 4, 4, 8, 16], strides = [2, 2, 2, 4, 5, 8]): 145 | super().__init__() 146 | 147 | c_mults = [1] + c_mults 148 | 149 | self.depth = len(c_mults) 150 | 151 | layers = [ 152 | CausalConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*capacity, kernel_size=7), 153 | nn.ELU() 154 | ] 155 | 156 | for i in range(self.depth-1, 0, -1): 157 | layers.append(DecoderBlock(in_channels=c_mults[i]*capacity, out_channels=c_mults[i-1]*capacity, stride=strides[i-1])) 158 | layers.append(nn.ELU()) 159 | 160 | layers.append(CausalConv1d(in_channels=c_mults[0] * capacity, out_channels=out_channels, kernel_size=7)) 161 | 162 | self.layers = nn.Sequential(*layers) 163 | 164 | def forward(self, x): 165 | return self.layers(x) 166 | 167 | 168 | # class SoundStreamXL(nn.Module): 169 | # def __init__(self, n_io_channels, n_feature_channels, latent_dim, n_quantizers=8, codebook_size=1024): 170 | # super().__init__() 171 | 172 | # self.encoder = SoundStreamXLEncoder(in_channels=n_io_channels, capacity=n_feature_channels, latent_dim=latent_dim) 173 | # self.decoder = SoundStreamXLDecoder(out_channels=n_io_channels, capacity=n_feature_channels, latent_dim=latent_dim) 174 | 175 | # self.quantizer = ResidualVQ( 176 | # num_quantizers=n_quantizers, 177 | # dim=latent_dim, 178 | # codebook_size=codebook_size, 179 | # kmeans_init=True, 180 | # kmeans_iters=100, 181 | # threshold_ema_dead_code=2, 182 | # #use_cosine_sim=True, 183 | # ) 184 | 185 | # def forward(self, x): 186 | # encoded = self.encoder(x) 187 | # quantized, indices, losses = self.quantizer(encoded) 188 | # decoded = self.decoder(quantized) 189 | # return decoded, indices, losses 190 | 191 | -------------------------------------------------------------------------------- /train_wavelet_transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | from contextlib import contextmanager 5 | import math 6 | from pathlib import Path 7 | 8 | import sys 9 | import torch 10 | from torch import optim, nn 11 | from torch.nn import functional as F 12 | from torch.utils import data 13 | from tqdm import trange 14 | import pytorch_lightning as pl 15 | from pytorch_lightning.utilities.distributed import rank_zero_only 16 | from einops import rearrange 17 | import numpy as np 18 | import torchaudio 19 | 20 | import wandb 21 | 22 | from encoders.wavelets import WaveletEncode1d, WaveletDecode1d 23 | 24 | from blocks.utils 
import InverseLR 25 | from ema_pytorch import EMA 26 | from aeiou.viz import embeddings_table, pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 27 | from aeiou.datasets import AudioDataset 28 | from dataset.dataset import SampleDataset 29 | from dataset.dataset import get_wds_loader 30 | 31 | from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder 32 | 33 | class WaveletTransformer(pl.LightningModule): 34 | def __init__(self): 35 | super().__init__() 36 | 37 | self.levels = 8 38 | 39 | self.latent_dim = 2 ** (self.levels+1) 40 | self.downsampling_ratio = 2 ** self.levels 41 | 42 | self.transformer = ContinuousTransformerWrapper( 43 | dim_in=self.latent_dim, 44 | dim_out=self.latent_dim, 45 | max_seq_len=1024, 46 | attn_layers = Decoder( 47 | dim=768, 48 | depth=12, 49 | header=8 50 | ) 51 | ) 52 | 53 | self.transformer = ContinuousAutoregressiveWrapper(self.transformer) 54 | 55 | self.transformer_ema = EMA( 56 | self.transformer, 57 | beta = 0.9999, 58 | power=3/4, 59 | update_every = 1, 60 | update_after_step = 1 61 | ) 62 | 63 | self.encoder = WaveletEncode1d(2, "bior4.4", levels = self.levels) 64 | self.decoder = WaveletDecode1d(2, "bior4.4", levels = self.levels) 65 | 66 | def encode(self, reals): 67 | return self.encoder(reals) 68 | 69 | def decode(self, wavelets): 70 | return self.decoder(wavelets) 71 | 72 | def configure_optimizers(self): 73 | optimizer = optim.Adam([*self.transformer.parameters()], lr=1e-4) 74 | 75 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500, eta_min=1e-6) 76 | 77 | return [optimizer], [scheduler] 78 | 79 | def training_step(self, batch, batch_idx): 80 | reals = batch 81 | #reals = reals[0] 82 | 83 | wavelets = self.encode(reals) 84 | 85 | mask = None 86 | 87 | wavelets = rearrange(wavelets, "b c n -> b n c") 88 | 89 | with torch.cuda.amp.autocast(): 90 | loss = self.transformer(wavelets) 91 | 92 | log_dict = { 93 | 'train/loss': loss.detach(), 94 | 'train/lr': self.lr_schedulers().get_last_lr()[0] 95 | } 96 | 97 | self.log_dict(log_dict, prog_bar=True, on_step=True) 98 | return loss 99 | 100 | def on_before_zero_grad(self, *args, **kwargs): 101 | self.transformer_ema.update() 102 | 103 | class ExceptionCallback(pl.Callback): 104 | def on_exception(self, trainer, module, err): 105 | print(f'{type(err).__name__}: {err}', file=sys.stderr) 106 | 107 | 108 | class DemoCallback(pl.Callback): 109 | def __init__(self, global_args, demo_dl): 110 | super().__init__() 111 | self.demo_every = global_args.demo_every 112 | self.demo_samples = global_args.sample_size 113 | self.num_demos = global_args.num_demos 114 | self.sample_rate = global_args.sample_rate 115 | self.demo_dl = iter(demo_dl) 116 | 117 | @rank_zero_only 118 | @torch.no_grad() 119 | #def on_train_epoch_end(self, trainer, module): 120 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): 121 | last_demo_step = -1 122 | if (trainer.global_step - 1) % self.demo_every != 0 or last_demo_step == trainer.global_step: 123 | #if trainer.current_epoch % self.demo_every != 0: 124 | return 125 | 126 | last_demo_step = trainer.global_step 127 | 128 | print("Starting demo") 129 | 130 | n_samples = self.demo_samples//module.downsampling_ratio 131 | 132 | demo_reals = next(self.demo_dl).to(module.device) 133 | 134 | try: 135 | real_wavelets = module.encode(demo_reals).to(module.device) 136 | 137 | real_wavelets = rearrange(real_wavelets, "b c n -> b n c") 138 | 139 | start_embeds = real_wavelets[:, :1, :] 140 | 141 | 
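            # Seed the autoregressive transformer with the first wavelet frame of the real
            # audio; the EMA copy then generates n_samples continuation frames below, which
            # are decoded back to a waveform for logging.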
print(f"Start embeds: {start_embeds.shape}") 142 | 143 | fake_wavelets = module.transformer_ema.ema_model.generate(start_embeds, n_samples) 144 | 145 | print("Decoding") 146 | 147 | print(f"Fake wavelets: {fake_wavelets.shape}") 148 | 149 | fakes = module.decode(fake_wavelets) 150 | 151 | print() 152 | 153 | # Put the demos together 154 | fakes = rearrange(fakes, 'b d n -> d (b n)') 155 | 156 | log_dict = {} 157 | 158 | print("Saving files") 159 | filename = f'demo_{trainer.global_step:08}.wav' 160 | fakes = fakes.clamp(-1, 1).mul(32767).to(torch.int16).cpu() 161 | torchaudio.save(filename, fakes, self.sample_rate) 162 | 163 | 164 | log_dict[f'demo'] = wandb.Audio(filename, 165 | sample_rate=self.sample_rate, 166 | caption=f'Reconstructed') 167 | 168 | log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) 169 | 170 | print("Done logging") 171 | trainer.logger.experiment.log(log_dict) 172 | 173 | except Exception as e: 174 | print(f'{type(e).__name__}: {e}') 175 | 176 | def main(): 177 | 178 | args = get_all_args() 179 | 180 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 181 | print('Using device:', device) 182 | torch.manual_seed(args.seed) 183 | 184 | train_set = AudioDataset( 185 | [args.training_dir], 186 | sample_rate=args.sample_rate, 187 | sample_size=args.sample_size, 188 | random_crop=args.random_crop, 189 | augs='Stereo(), PhaseFlipper()' 190 | ) 191 | 192 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, 193 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True, drop_last=True) 194 | 195 | demo_dl = data.DataLoader(train_set, args.num_demos, num_workers=args.num_workers, shuffle=True) 196 | 197 | 198 | wandb_logger = pl.loggers.WandbLogger(project=args.name) 199 | 200 | exc_callback = ExceptionCallback() 201 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) 202 | demo_callback = DemoCallback(args, demo_dl) 203 | 204 | if args.ckpt_path: 205 | model = WaveletTransformer.load_from_checkpoint(args.ckpt_path, strict=False) 206 | else: 207 | model = WaveletTransformer() 208 | 209 | wandb_logger.watch(model) 210 | push_wandb_config(wandb_logger, args) 211 | 212 | trainer = pl.Trainer( 213 | devices=args.num_gpus, 214 | accelerator="gpu", 215 | num_nodes = args.num_nodes, 216 | strategy='ddp_find_unused_parameters_false', 217 | precision=16, 218 | accumulate_grad_batches=args.accum_batches, 219 | callbacks=[ckpt_callback, demo_callback, exc_callback], 220 | logger=wandb_logger, 221 | log_every_n_steps=1, 222 | max_epochs=10000000, 223 | default_root_dir=args.save_dir 224 | ) 225 | 226 | trainer.fit(model, train_dl) 227 | 228 | if __name__ == '__main__': 229 | main() 230 | 231 | -------------------------------------------------------------------------------- /train_archive/train_ad_upsampler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from prefigure.prefigure import get_all_args, push_wandb_config 4 | from contextlib import contextmanager 5 | from copy import deepcopy 6 | import math 7 | from pathlib import Path 8 | 9 | import sys 10 | import torch 11 | from torch import optim, nn 12 | from torch.nn import functional as F 13 | from torch.utils import data 14 | from tqdm import trange 15 | import pytorch_lightning as pl 16 | from pytorch_lightning.utilities.distributed import rank_zero_only 17 | from einops import rearrange 18 | 19 | from diffusion.pqmf import CachedPQMF as 
PQMF 20 | import torchaudio 21 | 22 | import auraloss 23 | 24 | import wandb 25 | 26 | from aeiou.datasets import AudioDataset 27 | from dataset.dataset import SampleDataset 28 | 29 | from audio_diffusion_pytorch import AudioDiffusionUpsampler 30 | from audio_diffusion_pytorch.utils import downsample, upsample 31 | from diffusion.model import ema_update 32 | from aeiou.viz import embeddings_table, pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image 33 | 34 | class DiffusionUncond(pl.LightningModule): 35 | def __init__(self, global_args): 36 | super().__init__() 37 | 38 | self.diffusion = AudioDiffusionUpsampler( 39 | factor = 3, 40 | in_channels = 2, 41 | channels = 128, 42 | patch_blocks = 1, 43 | patch_factor = 8, 44 | resnet_groups = 8, 45 | kernel_multiplier_downsample = 2, 46 | multipliers = [1, 2, 4, 4, 4, 4, 4], 47 | factors = [2, 2, 2, 2, 2, 2], 48 | num_blocks = [2, 2, 2, 2, 2, 2], 49 | attentions = [0, 0, 0, 0, 1, 1, 1], 50 | attention_heads = 8, 51 | attention_features = 128, 52 | attention_multiplier = 4 53 | ) 54 | 55 | self.diffusion_ema = deepcopy(self.diffusion) 56 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=global_args.seed) 57 | self.ema_decay = global_args.ema_decay 58 | 59 | def configure_optimizers(self): 60 | return optim.Adam([*self.diffusion.parameters()], lr=1e-4) 61 | 62 | def training_step(self, batch, batch_idx): 63 | reals = batch 64 | 65 | loss = self.diffusion(reals) 66 | 67 | log_dict = { 68 | 'train/loss': loss.detach() 69 | } 70 | 71 | self.log_dict(log_dict, prog_bar=True, on_step=True) 72 | return loss 73 | 74 | def on_before_zero_grad(self, *args, **kwargs): 75 | decay = 0.95 if self.current_epoch < 25 else self.ema_decay 76 | ema_update(self.diffusion, self.diffusion_ema, decay) 77 | 78 | class ExceptionCallback(pl.Callback): 79 | def on_exception(self, trainer, module, err): 80 | print(f'{type(err).__name__}: {err}', file=sys.stderr) 81 | 82 | 83 | class DemoCallback(pl.Callback): 84 | def __init__(self, demo_dl, global_args): 85 | super().__init__() 86 | self.demo_every = global_args.demo_every 87 | self.demo_samples = global_args.sample_size 88 | self.demo_steps = global_args.demo_steps 89 | self.demo_dl = iter(demo_dl) 90 | self.sample_rate = global_args.sample_rate 91 | 92 | 93 | @rank_zero_only 94 | @torch.no_grad() 95 | #def on_train_epoch_end(self, trainer, module): 96 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): 97 | last_demo_step = -1 98 | if (trainer.global_step - 1) % self.demo_every != 0 or last_demo_step == trainer.global_step: 99 | #if trainer.current_epoch % self.demo_every != 0: 100 | return 101 | 102 | downsample_factor = 3 103 | 104 | last_demo_step = trainer.global_step 105 | 106 | demo_reals, _ = next(self.demo_dl) 107 | 108 | try: 109 | downsampled = downsample(demo_reals, downsample_factor) 110 | 111 | upsampled = module.diffusion_ema.sample(downsampled, downsample_factor) 112 | 113 | # Put the demos together 114 | upsampled = rearrange(upsampled, 'b d n -> d (b n)') 115 | 116 | log_dict = {} 117 | 118 | upsample_filename = f'upsampled_{trainer.global_step:08}.wav' 119 | upsampled = upsampled.clamp(-1, 1).mul(32767).to(torch.int16).cpu() 120 | torchaudio.save(upsample_filename, upsampled, self.sample_rate) 121 | 122 | 123 | downsample_filename = f'downsampled_{trainer.global_step:08}.wav' 124 | downsampled = downsampled.clamp(-1, 1).mul(32767).to(torch.int16).cpu() 125 | torchaudio.save(downsample_filename, downsampled, self.sample_rate // downsample_factor) 126 
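            # Log a matched set of demo clips (real, downsampled, and diffusion-upsampled)
            # plus their mel spectrograms to wandb, so the upsampler can be judged against
            # the ground-truth audio at each demo step.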
| 127 | reals_filename = f'reals_{trainer.global_step:08}.wav' 128 | demo_reals = demo_reals.clamp(-1, 1).mul(32767).to(torch.int16).cpu() 129 | torchaudio.save(reals_filename, demo_reals, self.sample_rate) 130 | 131 | 132 | log_dict[f'upsampled'] = wandb.Audio(upsample_filename, 133 | sample_rate=self.sample_rate, 134 | caption=f'Upsampled') 135 | log_dict[f'downsampled'] = wandb.Audio(downsample_filename, 136 | sample_rate=self.sample_rate // downsample_factor, 137 | caption=f'Downsampled') 138 | log_dict[f'real'] = wandb.Audio(reals_filename, 139 | sample_rate=self.sample_rate, 140 | caption=f'Real') 141 | 142 | log_dict[f'real_melspec_left'] = wandb.Image(audio_spectrogram_image(demo_reals)) 143 | log_dict[f'downsample_melspec_left'] = wandb.Image(audio_spectrogram_image(downsampled)) 144 | log_dict[f'upsample_melspec_left'] = wandb.Image(audio_spectrogram_image(upsampled)) 145 | 146 | 147 | trainer.logger.experiment.log(log_dict, step=trainer.global_step) 148 | except Exception as e: 149 | print(f'{type(e).__name__}: {e}', file=sys.stderr) 150 | 151 | def main(): 152 | 153 | args = get_all_args() 154 | 155 | args.latent_dim = 0 156 | 157 | #args.random_crop = False 158 | 159 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 160 | print('Using device:', device) 161 | torch.manual_seed(args.seed) 162 | 163 | train_set = AudioDataset( 164 | [args.training_dir], 165 | sample_rate=args.sample_rate, 166 | sample_size=args.sample_size, 167 | random_crop=args.random_crop, 168 | augs='Stereo()' 169 | ) 170 | 171 | #train_set = SampleDataset([args.training_dir], args, keywords=["kick", "snare", "clap", "snap", "hat", "cymbal", "crash", "ride"]) 172 | 173 | train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, 174 | num_workers=args.num_workers, persistent_workers=True, pin_memory=True) 175 | wandb_logger = pl.loggers.WandbLogger(project=args.name) 176 | demo_dl = data.DataLoader(train_set, args.num_demos, num_workers=args.num_workers, shuffle=True) 177 | exc_callback = ExceptionCallback() 178 | ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1) 179 | demo_callback = DemoCallback(demo_dl, args) 180 | diffusion_model = DiffusionUncond(args) 181 | wandb_logger.watch(diffusion_model) 182 | push_wandb_config(wandb_logger, args) 183 | 184 | diffusion_trainer = pl.Trainer( 185 | devices=args.num_gpus, 186 | accelerator="gpu", 187 | num_nodes = args.num_nodes, 188 | strategy='ddp', 189 | #precision=16, 190 | accumulate_grad_batches=args.accum_batches, 191 | callbacks=[ckpt_callback, demo_callback, exc_callback], 192 | logger=wandb_logger, 193 | log_every_n_steps=1, 194 | max_epochs=10000000, 195 | ) 196 | 197 | diffusion_trainer.fit(diffusion_model, train_dl, ckpt_path=args.ckpt_path) 198 | 199 | if __name__ == '__main__': 200 | main() -------------------------------------------------------------------------------- /prune_clap_ckpt.py: -------------------------------------------------------------------------------- 1 | #@title Imports and definitions 2 | import argparse 3 | from contextlib import contextmanager 4 | from copy import deepcopy 5 | import math 6 | from pathlib import Path 7 | 8 | import sys 9 | import gc 10 | 11 | from autoencoders.soundstream import SoundStreamXLEncoder, SoundStreamXLDecoder 12 | from autoencoders.models import AudioAutoencoder 13 | from audio_encoders_pytorch import Encoder1d 14 | from ema_pytorch import EMA 15 | from audio_diffusion_pytorch.modules import UNetCFG1d 16 | 17 | from 
audio_diffusion_pytorch import T5Embedder, NumberEmbedder 18 | 19 | import torch 20 | from torch import optim, nn 21 | from torch.nn import functional as F 22 | from torch.utils import data 23 | from tqdm import trange 24 | from einops import rearrange 25 | 26 | import torchaudio 27 | from decoders.diffusion_decoder import DiffusionAttnUnet1D 28 | import numpy as np 29 | import laion_clap 30 | 31 | import random 32 | from diffusion.utils import Stereo, PadCrop 33 | from glob import glob 34 | 35 | from torch.nn.parameter import Parameter 36 | 37 | class LatentAudioDiffusionAutoencoder(nn.Module): 38 | def __init__(self, autoencoder: AudioAutoencoder): 39 | super().__init__() 40 | 41 | 42 | self.latent_dim = autoencoder.latent_dim 43 | 44 | self.second_stage_latent_dim = 32 45 | 46 | factors = [2, 2, 2, 2] 47 | 48 | self.latent_downsampling_ratio = np.prod(factors) 49 | 50 | self.downsampling_ratio = autoencoder.downsampling_ratio * self.latent_downsampling_ratio 51 | 52 | self.latent_encoder = Encoder1d( 53 | in_channels=self.latent_dim, 54 | out_channels = self.second_stage_latent_dim, 55 | channels = 128, 56 | multipliers = [1, 2, 4, 8, 8], 57 | factors = factors, 58 | num_blocks = [8, 8, 8, 8], 59 | ) 60 | 61 | self.diffusion = DiffusionAttnUnet1D( 62 | io_channels=self.latent_dim, 63 | cond_dim = self.second_stage_latent_dim, 64 | n_attn_layers=0, 65 | c_mults=[512] * 10, 66 | depth=10 67 | ) 68 | 69 | self.autoencoder = autoencoder 70 | 71 | self.autoencoder.requires_grad_(False) 72 | 73 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 74 | 75 | def encode(self, reals): 76 | first_stage_latents = self.autoencoder.encode(reals) 77 | 78 | second_stage_latents = self.latent_encoder(first_stage_latents) 79 | 80 | second_stage_latents = torch.tanh(second_stage_latents) 81 | 82 | return second_stage_latents 83 | 84 | def decode(self, latents, steps=250, device="cuda"): 85 | first_stage_latent_noise = torch.randn([latents.shape[0], self.latent_dim, latents.shape[2]*self.latent_downsampling_ratio]).to(device) 86 | 87 | t = torch.linspace(1, 0, steps + 1, device=device)[:-1] 88 | 89 | step_list = get_spliced_ddpm_cosine_schedule(t) 90 | 91 | first_stage_sampled = sampling.iplms_sample(self.diffusion, first_stage_latent_noise, step_list, {"cond":latents}) 92 | #first_stage_sampled = sample(self.diffusion, first_stage_latent_noise, steps, 0, cond=latents) 93 | decoded = self.autoencoder.decode(first_stage_sampled) 94 | return decoded 95 | 96 | class StackedAELatentDiffusionCond(nn.Module): 97 | def __init__(self, latent_ae: LatentAudioDiffusionAutoencoder, clap_module: laion_clap.CLAP_Module, diffusion_config): 98 | super().__init__() 99 | 100 | self.latent_dim = latent_ae.second_stage_latent_dim 101 | self.downsampling_ratio = latent_ae.downsampling_ratio 102 | 103 | self.embedding_features = 512 104 | 105 | self.embedder = clap_module 106 | 107 | self.diffusion = UNetCFG1d(**diffusion_config) 108 | 109 | self.autoencoder = latent_ae 110 | 111 | self.autoencoder.requires_grad_(False) 112 | 113 | self.rng = torch.quasirandom.SobolEngine(1, scramble=True) 114 | 115 | def encode(self, reals): 116 | return self.autoencoder.encode(reals) 117 | 118 | def decode(self, latents, steps=250, device="cuda"): 119 | return self.autoencoder.decode(latents, steps, device=device) 120 | 121 | 122 | def prune_ckpt_weights(stacked_state_dict): 123 | new_state_dict = {} 124 | for name, param in stacked_state_dict.items(): 125 | if name.startswith("diffusion_ema.ema_model."): 126 | new_name = 
name.replace("diffusion_ema.ema_model.", "diffusion.") 127 | if isinstance(param, Parameter): 128 | # backwards compatibility for serialized parameters 129 | param = param.data 130 | new_state_dict[new_name] = param 131 | elif name.startswith("autoencoder") or name.startswith("embedder"): 132 | new_state_dict[name] = param 133 | 134 | return new_state_dict 135 | 136 | 137 | if __name__ == "__main__": 138 | 139 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 140 | parser.add_argument('--ckpt_path', help='Path to the checkpoint to be pruned') 141 | parser.add_argument('--clap_ckpt_path', help='Path to the CLAP checkpoint') 142 | parser.add_argument('--clap_fusion', action='store_true', help='Enable CLAP fusion', default=False) 143 | parser.add_argument('--clap_amodel', help='CLAP amodel', default="HTSAT-tiny") 144 | #parser.add_argument('--output_path', help='Path to the checkpoint to be pruned') 145 | args = parser.parse_args() 146 | 147 | print("Creating the model...") 148 | 149 | first_stage_config = {"capacity": 64, "c_mults": [2, 4, 8, 16, 32], "strides": [2, 2, 2, 2, 2], "latent_dim": 32} 150 | 151 | first_stage_autoencoder = AudioAutoencoder( 152 | **first_stage_config 153 | ).eval() 154 | 155 | diffusion_config = dict( 156 | in_channels = 32, 157 | context_embedding_features = 512, 158 | context_embedding_max_length = 1, 159 | channels = 256, 160 | resnet_groups = 8, 161 | kernel_multiplier_downsample = 2, 162 | multipliers = [2, 3, 4, 4, 4, 4], 163 | factors = [1, 2, 2, 4, 4], 164 | num_blocks = [3, 3, 3, 3, 3], 165 | attentions = [0, 0, 2, 2, 2, 2], 166 | attention_heads = 16, 167 | attention_features = 64, 168 | attention_multiplier = 4, 169 | attention_use_rel_pos=True, 170 | attention_rel_pos_max_distance=2048, 171 | attention_rel_pos_num_buckets=256, 172 | use_nearest_upsample = False, 173 | use_skip_scale = True, 174 | use_context_time = True, 175 | ) 176 | 177 | clap_config = dict( 178 | clap_fusion = args.clap_fusion, 179 | clap_amodel = args.clap_amodel 180 | ) 181 | 182 | latent_diffae = LatentAudioDiffusionAutoencoder(autoencoder=first_stage_autoencoder) 183 | 184 | clap_model = laion_clap.CLAP_Module(enable_fusion=args.clap_fusion, amodel= args.clap_amodel).requires_grad_(False).eval() 185 | 186 | if args.clap_ckpt_path: 187 | clap_model.load_ckpt(ckpt=args.clap_ckpt_path) 188 | else: 189 | clap_model.load_ckpt(model_id=1) 190 | 191 | model = StackedAELatentDiffusionCond(latent_diffae, clap_module=clap_model, diffusion_config=diffusion_config) 192 | 193 | ckpt_state_dict = torch.load(args.ckpt_path)["state_dict"] 194 | #print(ckpt_state_dict.keys()) 195 | 196 | new_ckpt = {} 197 | 198 | new_ckpt["state_dict"] = prune_ckpt_weights(ckpt_state_dict) 199 | 200 | new_ckpt["model_config"] = dict( 201 | version = (0, 0, 1), 202 | model_info = dict( 203 | name = 'Clap Conditioned Dance Diffusion Model', 204 | description = 'v1.0', 205 | type = 'CCDD', 206 | native_chunk_size = 2097152, 207 | sample_rate = 48000, 208 | ), 209 | autoencoder_config = first_stage_config, 210 | diffusion_config = diffusion_config, 211 | clap_config = clap_config 212 | ) 213 | 214 | model.load_state_dict(new_ckpt["state_dict"], strict=False) 215 | 216 | torch.save(new_ckpt, f'./pruned.ckpt') --------------------------------------------------------------------------------
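# Illustrative usage sketch (not part of the repository): one way a checkpoint written by
# prune_clap_ckpt.py above might be reloaded for inference. The class and config names are
# taken from that script; importing them from prune_clap_ckpt assumes the repo root is on
# sys.path, and "pruned.ckpt" is the default filename the script writes.
import torch
import laion_clap

from autoencoders.models import AudioAutoencoder
from prune_clap_ckpt import LatentAudioDiffusionAutoencoder, StackedAELatentDiffusionCond

ckpt = torch.load("pruned.ckpt", map_location="cpu")
model_config = ckpt["model_config"]

# Rebuild the same module hierarchy the pruning script used, driven by the saved config
# instead of hard-coded values.
first_stage = AudioAutoencoder(**model_config["autoencoder_config"]).eval()
latent_diffae = LatentAudioDiffusionAutoencoder(autoencoder=first_stage)
clap_model = laion_clap.CLAP_Module(
    enable_fusion=model_config["clap_config"]["clap_fusion"],
    amodel=model_config["clap_config"]["clap_amodel"],
).requires_grad_(False).eval()

model = StackedAELatentDiffusionCond(
    latent_diffae,
    clap_module=clap_model,
    diffusion_config=model_config["diffusion_config"],
)

# strict=False mirrors the pruning script's own sanity check; depending on how the training
# run saved the CLAP weights, clap_model.load_ckpt(...) may still be needed for text prompts.
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()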