├── .gitignore ├── LICENSE ├── README.md ├── audio_datasets ├── __init__.py ├── constants.py ├── helpers.py ├── librispeech.py └── text_grid_utils.py ├── compute_wer.py ├── data └── aligned_librispeech.tar.gz ├── diffusion ├── audio_denoising_diffusion.py ├── noise_schedule.py ├── optimizer.py └── time_sampler.py ├── evaluation └── evaluate_transcript.py ├── img ├── esd.png └── results.png ├── models ├── modules │ ├── __init__.py │ ├── blocks.py │ ├── conv.py │ ├── norm.py │ ├── position.py │ └── transformer.py ├── transformer_wrapper.py └── unet.py ├── neural_codec └── encodec_wrapper.py ├── requirements.txt ├── scripts ├── sample │ └── sample_16_ls_testclean.sh └── train │ └── train.sh ├── train_audio_diffusion.py └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Justin Lovelace 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sample-Efficient Diffusion for Text-To-Speech Synthesis 2 | 3 | This repository contains the implementation for SESD, a latent diffusion model for text-to-speech generation. We introduced this model in the following work: 4 | 5 | **Sample-Efficient Diffusion for Text-To-Speech Synthesis** 6 | Justin Lovelace, Soham Ray, Kwangyoun Kim, Kilian Q. Weinberger, Felix Wu 7 | Interspeech 2024 8 | [\[paper\]](https://www.arxiv.org/abs/2409.03717) 9 | 10 | ## Abstract 11 | 12 | This work introduces Sample-Efficient Speech Diffusion (SESD), an algorithm for effective speech synthesis in modest data regimes through latent diffusion. 
It is based on a novel diffusion architecture that we call the U-Audio Transformer (U-AT), which efficiently scales to long sequences and operates in the latent space of a pre-trained audio autoencoder. Conditioned on character-aware language model representations, SESD achieves impressive results despite training on less than 1k hours of speech – far less than current state-of-the-art systems. In fact, it synthesizes more intelligible speech than the state-of-the-art auto-regressive model, VALL-E, while using less than 2% of the training data.
13 | 
14 | 
15 | ![Model Architecture](img/esd.png)
16 | 
17 | *Overview of our Sample-Efficient Speech Diffusion (SESD) architecture.*
18 | 
19 | ![Main Results](img/results.png)
20 | 
21 | 
22 | If you find this work useful, please consider citing:
23 | 
24 | ```bibtex
25 | @inproceedings{lovelace2024sample,
26 |   title={Sample-Efficient Diffusion for Text-To-Speech Synthesis},
27 |   author={Lovelace, Justin and Ray, Soham and Kim, Kwangyoun and Weinberger, Kilian Q and Wu, Felix},
28 |   booktitle={Proc. Interspeech 2024},
29 |   pages={4403--4407},
30 |   year={2024}
31 | }
32 | ```
33 | 
34 | ## Installation
35 | 
36 | Install the required dependencies:
37 | ```bash
38 | pip install -r requirements.txt
39 | ```
40 | 
41 | ## Datasets
42 | 
43 | We train and evaluate our models on the LibriSpeech dataset, loaded from the Hugging Face Hub (`librispeech_asr`).
44 | 
45 | ### Speaker Prompt Data
46 | 
47 | For speaker-prompted generation, we condition on a three-second prompt taken from another utterance by the same speaker. To extract the transcript corresponding to the prompt audio, we used the [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/first_steps/example.html#example-1-aligning-librispeech-english). An aligned version of LibriSpeech can be found at:
48 | 
49 | ```bash
50 | data/aligned_librispeech.tar.gz
51 | ```
52 | 
53 | After extracting the archive, update the `ALIGNED_DATA_DIR` path in `audio_datasets/constants.py` to point to your data directory (a setup sketch is included in the appendix at the end of this README).
54 | 
55 | ## Training
56 | 
57 | Our training setup:
58 | - Single Nvidia A6000 GPU
59 | - BF16 mixed-precision training
60 | - Batch size and other parameters may need adjustment based on your hardware
61 | 
62 | To train the diffusion model:
63 | ```bash
64 | ./scripts/train/train.sh
65 | ```
66 | 
67 | ## Model Checkpoint
68 | 
69 | The model checkpoint will be released soon!
70 | 
71 | ## Inference
72 | 
73 | To synthesize speech for the LibriSpeech test-clean set:
74 | ```bash
75 | ./scripts/sample/sample_16_ls_testclean.sh
76 | ```
77 | 
78 | Note: Update the `--resume_dir` argument with the path to your trained model.
79 | 
80 | ## Questions and Support
81 | 
82 | Feel free to create an issue if you have any questions.
83 | 
84 | ## Acknowledgements
85 | 
86 | This work builds upon excellent open-source implementations from [Phil Wang (Lucidrains)](https://github.com/lucidrains). Specifically, we built on his PyTorch [DDPM implementation](https://github.com/lucidrains/denoising-diffusion-pytorch).
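87 | 
88 | ## Appendix: Aligned Data Setup (Sketch)
89 | 
90 | A minimal sketch of the setup described in the Datasets section. It assumes the repository root as the working directory and that the archive unpacks to `data/aligned_librispeech/` (the layout referenced by `audio_datasets/constants.py` and `audio_datasets/text_grid_utils.py`); adjust the paths for your machine.
91 | 
92 | ```python
93 | import tarfile
94 | 
95 | # Unpack the aligned LibriSpeech TextGrids shipped with the repository.
96 | with tarfile.open('data/aligned_librispeech.tar.gz') as tar:
97 |     tar.extractall('data/')
98 | 
99 | # Next, set ALIGNED_DATA_DIR in audio_datasets/constants.py to the absolute path
100 | # of the extracted directory, e.g. '/abs/path/to/SESD/data/aligned_librispeech/'.
101 | 
102 | # Optional sanity check: build a ~3-second prompt transcript from one alignment
103 | # (this file path appears in audio_datasets/text_grid_utils.py).
104 | from audio_datasets.text_grid_utils import get_partial_transcript
105 | 
106 | info = get_partial_transcript('data/aligned_librispeech/dev-clean/84/84-121123-0002.TextGrid')
107 | print(info['transcript'], info['start_time'], info['end_time'])
108 | ```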
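109 | 
110 | ## Appendix: Text Conditioning (Sketch)
111 | 
112 | SESD conditions the denoiser on byte-level (character-aware) ByT5 encoder states: `GaussianDiffusion` in `diffusion/audio_denoising_diffusion.py` loads the `google/byt5-small` encoder in bfloat16, freezes it, and tokenizes transcripts to a fixed length of 256. The sketch below only reproduces that conditioning step; the example sentence is arbitrary.
113 | 
114 | ```python
115 | import torch
116 | from transformers import AutoTokenizer, T5ForConditionalGeneration
117 | 
118 | text_encoder_id = 'google/byt5-small'
119 | tokenizer = AutoTokenizer.from_pretrained(text_encoder_id)
120 | # Only the ByT5 encoder is used; it stays frozen during diffusion training.
121 | text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder_id).get_encoder()
122 | 
123 | batch = tokenizer(['the birch canoe slid on the smooth planks'],
124 |                   padding='max_length', truncation=True, max_length=256,
125 |                   return_tensors='pt')
126 | with torch.no_grad():
127 |     # [B, 256, d_lm]: per-byte hidden states used as cross-attention conditioning.
128 |     text_cond = text_encoder(batch['input_ids'], batch['attention_mask']).last_hidden_state
129 | text_cond_mask = batch['attention_mask'].bool()  # [B, 256] cross-attention mask
130 | print(text_cond.shape, text_cond_mask.shape)
131 | ```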
-------------------------------------------------------------------------------- /audio_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_datasets.librispeech import LibriSpeech 2 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, LATENT_SAMPLING_RATE 3 | -------------------------------------------------------------------------------- /audio_datasets/constants.py: -------------------------------------------------------------------------------- 1 | MAX_DURATION_IN_SECONDS = 20 2 | ENCODEC_SAMPLING_RATE = 24000 3 | LIBRISPEECH_SAMPLING_RATE = 16000 4 | ENCODEC_REDUCTION_FACTOR = 320 5 | LATENT_SAMPLING_RATE = 75 6 | ALIGNED_DATA_DIR = '/home/jl3353/stable-audio-diffusion/data/aligned_librispeech/' -------------------------------------------------------------------------------- /audio_datasets/helpers.py: -------------------------------------------------------------------------------- 1 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, MAX_DURATION_IN_SECONDS 2 | 3 | 4 | def round_up_to_multiple(number, multiple): 5 | remainder = number % multiple 6 | if remainder == 0: 7 | return number 8 | else: 9 | return number + (multiple - remainder) 10 | 11 | def round_up_to_waveform_multiple(number, multiple=16): 12 | waveform_multiple = multiple*ENCODEC_REDUCTION_FACTOR 13 | rounded_number = round_up_to_multiple(number, waveform_multiple) 14 | return rounded_number 15 | 16 | def compute_max_length(multiple=16): 17 | max_len = MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE 18 | 19 | max_len = round_up_to_waveform_multiple(max_len, multiple) 20 | return max_len 21 | 22 | def is_audio_length_in_range(audio, sampling_rate=ENCODEC_SAMPLING_RATE): 23 | return len(audio['array']) <= (MAX_DURATION_IN_SECONDS*sampling_rate) 24 | 25 | def is_audio_length_in_test_range(audio): 26 | return ((4*ENCODEC_SAMPLING_RATE) <= len(audio['array'])) and (len(audio['array']) <= (10*ENCODEC_SAMPLING_RATE)) 27 | -------------------------------------------------------------------------------- /audio_datasets/librispeech.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from datasets import load_dataset, Audio, concatenate_datasets 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | import torch.nn.functional as F 7 | import os 8 | import random 9 | import math 10 | import json 11 | 12 | from audio_datasets.text_grid_utils import get_partial_transcript 13 | from audio_datasets.helpers import compute_max_length, round_up_to_waveform_multiple, is_audio_length_in_range, is_audio_length_in_test_range 14 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, MAX_DURATION_IN_SECONDS, ALIGNED_DATA_DIR, LATENT_SAMPLING_RATE 15 | 16 | def lowercase_text(example): 17 | example["text"] = example["text"].lower() 18 | return example 19 | 20 | class LibriSpeech(Dataset): 21 | """ 22 | Wrapper around HuggingFace dataset for processing. 
23 | """ 24 | def __init__(self, split='train', debug=False, tokenizer=None, max_seq_len=None, sampling_rate=None, duration_path=None): 25 | super().__init__() 26 | self.sr = ENCODEC_SAMPLING_RATE if sampling_rate is None else sampling_rate 27 | self.split = split 28 | self.split2dir = {'valid': 'dev-clean', 'test': 'test-clean'} 29 | if split == 'train': 30 | train100 = load_dataset('librispeech_asr', 'clean', split='train.100',) 31 | train360 = load_dataset('librispeech_asr', 'clean', split='train.360', ) 32 | train500 = load_dataset('librispeech_asr', 'other', split='train.500', ) 33 | 34 | self.hf_dataset = concatenate_datasets([train100, train360, train500]) 35 | elif split == 'valid': 36 | self.hf_dataset = load_dataset('librispeech_asr', 'clean', split='validation') 37 | elif split == 'test': 38 | self.hf_dataset = load_dataset('librispeech_asr', 'clean', split='test') 39 | else: 40 | raise ValueError(f"invalid split: {split}, must be in ['train', 'valid'] ") 41 | 42 | # Downsample to accelerate processing for debugging purposes 43 | if debug: 44 | self.hf_dataset = self.hf_dataset.select(range(100)) 45 | # Resample to 24kHz for Encodec 46 | self.hf_dataset = self.hf_dataset.cast_column("audio", Audio(sampling_rate=self.sr)) 47 | 48 | self.hf_dataset = self.hf_dataset.map(lowercase_text) 49 | if split == 'train': 50 | self.hf_dataset = self.hf_dataset.filter(is_audio_length_in_range, input_columns=['audio']) 51 | elif split == 'test': 52 | self.hf_dataset = self.hf_dataset.filter(is_audio_length_in_test_range, input_columns=['audio']) 53 | 54 | if self.split in {'valid', 'test'}: 55 | unique_speaker_ids = set(self.hf_dataset['speaker_id']) 56 | self.speaker_datasets = {speaker_id:self.hf_dataset.filter(lambda example: example["speaker_id"] == speaker_id) for speaker_id in unique_speaker_ids} 57 | 58 | 59 | self.max_seq_len = max_seq_len if max_seq_len is not None else compute_max_length() 60 | print(f'Max seq length: {self.max_seq_len/ENCODEC_REDUCTION_FACTOR}') 61 | 62 | if tokenizer is not None: 63 | self.hf_dataset = self.hf_dataset.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True, max_length=256), batched=True) 64 | self.tokenizer = tokenizer 65 | 66 | self.duration_dict = None 67 | if split == 'test' and duration_path is not None: 68 | # Read duration path json 69 | with open(duration_path, 'rt') as f: 70 | self.duration_dict = json.load(f) 71 | assert len(self.duration_dict['nucleus_pred']) == len(self.hf_dataset), f'Length of duration dict {len(self.duration_dict)} does not match length of dataset {len(self.hf_dataset)}' 72 | self.nucleus_pred_duration = [float(d) for d in self.duration_dict['nucleus_pred']] 73 | 74 | def __getitem__(self, index): 75 | example = self.hf_dataset[index] 76 | text = example['text'] 77 | wav = example['audio']['array'][:self.max_seq_len] 78 | wavpath = example['audio']['path'] 79 | npad = self.max_seq_len - len(wav) 80 | assert npad>=0, f'Waveform length {len(wav)} needs to be less than {self.max_seq_len}' 81 | wav_len = len(wav) 82 | # [1, L]: Channels x length 83 | audio_duration_sec = wav_len/ENCODEC_SAMPLING_RATE 84 | wav = torch.tensor(np.pad(wav, pad_width=(0, npad), mode='constant'), dtype=torch.float).unsqueeze(0) 85 | 86 | audio_mask = torch.zeros((self.max_seq_len//ENCODEC_REDUCTION_FACTOR,), dtype=torch.bool) 87 | if self.duration_dict is None: 88 | num_unmasked_tokens = round_up_to_waveform_multiple(wav_len)//ENCODEC_REDUCTION_FACTOR 89 | audio_mask[:num_unmasked_tokens] = True 90 | else: 91 | 
audio_duration_sec = self.nucleus_pred_duration[index] 92 | wav_len = int(audio_duration_sec*ENCODEC_SAMPLING_RATE) 93 | num_unmasked_tokens = round_up_to_waveform_multiple(wav_len)//ENCODEC_REDUCTION_FACTOR 94 | audio_mask[:num_unmasked_tokens] = True 95 | 96 | data = {'wav': wav, 'text': text, 'audio_duration': audio_duration_sec, 'path':wavpath} 97 | data['audio_mask'] = audio_mask 98 | 99 | 100 | # Speaker prompting 101 | if self.split in {'valid', 'test'}: 102 | split_dir = self.split2dir[self.split] 103 | speaker_id = example['speaker_id'] 104 | speaker_ds = self.speaker_datasets[speaker_id] 105 | # Sample idx for n-1 elements and remap matching element to the last element 106 | speaker_idx = random.randint(0, len(speaker_ds)-2) 107 | if speaker_ds[speaker_idx]['id'] == example['id']: 108 | speaker_idx = len(speaker_ds)-1 109 | speaker_example = speaker_ds[speaker_idx] 110 | textgrid_path = os.path.join(ALIGNED_DATA_DIR, split_dir, f'{speaker_id}', f'{speaker_example["id"]}.TextGrid') 111 | partial_transcript = get_partial_transcript(textgrid_path) 112 | 113 | speaker_text = partial_transcript['transcript'] 114 | speaker_start_time = partial_transcript['start_time'] 115 | speaker_end_time = partial_transcript['end_time'] 116 | # Convert seconds to encodec frame rate 117 | 118 | speaker_start_frame = math.floor(speaker_start_time * ENCODEC_SAMPLING_RATE) 119 | speaker_end_frame = math.ceil(speaker_end_time * ENCODEC_SAMPLING_RATE) 120 | speaker_wav_frames = speaker_end_frame - speaker_start_frame 121 | speaker_audio_duration_sec = speaker_wav_frames/ENCODEC_SAMPLING_RATE 122 | speaker_wav = speaker_example['audio']['array'][speaker_start_frame:speaker_end_frame] 123 | speaker_npad = self.max_seq_len - len(speaker_wav) 124 | assert speaker_npad>=0, f'Waveform length {len(speaker_wav)} needs to be less than {self.max_seq_len}' 125 | # [1, L]: Channels x length 126 | 127 | speaker_wav = torch.tensor(np.pad(speaker_wav, pad_width=(0, speaker_npad), mode='constant'), dtype=torch.float).unsqueeze(0) 128 | 129 | speaker_data = {'speaker_wav': speaker_wav, 'speaker_text': speaker_text, 'speaker_audio_duration': speaker_audio_duration_sec} 130 | data.update(speaker_data) 131 | 132 | inpaint_audio_mask = torch.zeros((self.max_seq_len//ENCODEC_REDUCTION_FACTOR,), dtype=torch.bool) 133 | num_unmasked_tokens = round_up_to_waveform_multiple(speaker_wav_frames+wav_len)//ENCODEC_REDUCTION_FACTOR 134 | inpaint_audio_mask[:num_unmasked_tokens] = True 135 | data['inpaint_audio_mask'] = inpaint_audio_mask 136 | 137 | if self.tokenizer is not None: 138 | data['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long) 139 | data['attention_mask'] = torch.tensor(example['attention_mask'], dtype=torch.long) 140 | 141 | return data 142 | 143 | def __len__(self): 144 | return len(self.hf_dataset) 145 | 146 | if __name__ == "__main__": 147 | dataset = LibriSpeech(split='test') 148 | 149 | example = dataset.__getitem__(0) 150 | import soundfile as sf 151 | import pdb; pdb.set_trace() 152 | sf.write(f'example_audio/librispeech_sample.wav', example['wav'], ENCODEC_SAMPLING_RATE) 153 | with open(f'example_audio/librispeech_text.txt', 'w') as f: 154 | print(example['text'], file=f) 155 | -------------------------------------------------------------------------------- /audio_datasets/text_grid_utils.py: -------------------------------------------------------------------------------- 1 | from praatio import textgrid 2 | import os 3 | 4 | data_path = 
'../data/aligned_librispeech/dev-clean/84/84-121123-0002.TextGrid' 5 | 6 | # tg = textgrid.openTextgrid(data_path, False) 7 | def get_word_intervals(textgrid_path): 8 | tg = textgrid.openTextgrid(textgrid_path, False) 9 | return tg.getTier("words").entries 10 | 11 | def get_partial_transcript(textgrid_path, transcript_end_time=3): 12 | intervals = get_word_intervals(textgrid_path) 13 | word_list = [] 14 | start_time = 0 15 | end_time = intervals[-1].end 16 | # Reverse intervals to get the last word first 17 | intervals = intervals[::-1] 18 | for interval in intervals: 19 | if (end_time - interval.start) > transcript_end_time: 20 | break 21 | word_list.append(interval.label) 22 | start_time = interval.start 23 | # Reverse word_list to get the words in the correct order 24 | word_list = word_list[::-1] 25 | return {'transcript': ' '.join(word_list), 26 | 'end_time': end_time, 27 | 'start_time': start_time,} 28 | 29 | 30 | if __name__ == "__main__": 31 | partial_transcript = get_partial_transcript(data_path) 32 | print(partial_transcript) -------------------------------------------------------------------------------- /compute_wer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from audio_datasets import LibriSpeech 5 | from evaluation.evaluate_transcript import compute_wer 6 | from utils.utils import parse_float_tuple 7 | 8 | def get_wer(sample_dir, text_list, num_samples, guidances, prefix_seconds=0.0): 9 | wer_dict = {} 10 | for guidance in guidances: 11 | filepaths_list = [] 12 | for i in range(num_samples): 13 | if prefix_seconds > 0: 14 | filepaths_list.append(os.path.join(sample_dir, f'guide{guidance:.1f}_prefix{prefix_seconds:.1f}', f'audio_{i}.wav')) 15 | else: 16 | filepaths_list.append(os.path.join(sample_dir, f'guide{guidance:.1f}', f'audio_{i}.wav')) 17 | 18 | text_list = text_list[:num_samples] 19 | wer = compute_wer(filepaths_list, text_list) 20 | wer_dict[guidance] = wer 21 | print(f'WER for guidance {guidance}: {wer*100:.1f}') 22 | return wer_dict 23 | 24 | def main(args): 25 | test_ls_dataset = LibriSpeech(split='test') 26 | text_list = test_ls_dataset.hf_dataset['text'] 27 | 28 | wer_dict = get_wer(args.sample_dir, text_list, args.num_samples, args.guidance, args.prefix_seconds) 29 | 30 | # Save wer_dict to a json file 31 | with open(os.path.join(args.sample_dir, 'wer.json'), 'w') as f: 32 | json.dump(wer_dict, f) 33 | # Print wer_dict 34 | print(wer_dict) 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="Training arguments") 38 | parser.add_argument("--sample_dir", type=str, default='saved_models/librispeech/test/librispeech_250k/samples/step_100000') 39 | parser.add_argument("--num_samples", type=int, default=1237) 40 | parser.add_argument('--guidance', type=parse_float_tuple, help='Tuple of float values for dim_mults') 41 | parser.add_argument('--prefix_seconds', type=float, default=0.0) 42 | 43 | args = parser.parse_args() 44 | main(args) -------------------------------------------------------------------------------- /data/aligned_librispeech.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinlovelace/SESD/c59b41d7ec98922bdb6bce4217ebcf6cd1c0d2b2/data/aligned_librispeech.tar.gz -------------------------------------------------------------------------------- /diffusion/audio_denoising_diffusion.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | import random 4 | from functools import partial 5 | from collections import namedtuple, Counter 6 | from multiprocessing import cpu_count 7 | import os 8 | import numpy as np 9 | import csv 10 | import timeit 11 | import json 12 | import argparse 13 | from collections import defaultdict 14 | from contextlib import nullcontext 15 | import soundfile as sf 16 | 17 | 18 | import torch 19 | from torch import nn, einsum 20 | import torch.nn.functional as F 21 | from torch.utils.data import Dataset, DataLoader 22 | 23 | from einops import rearrange, reduce, repeat 24 | 25 | from tqdm.auto import tqdm 26 | from ema_pytorch import EMA 27 | 28 | from transformers import get_scheduler, AutoTokenizer, T5ForConditionalGeneration 29 | 30 | from accelerate import Accelerator, DistributedDataParallelKwargs 31 | 32 | from diffusion.optimizer import get_adamw_optimizer, get_lion_optimizer 33 | from utils.utils import compute_grad_norm 34 | import utils.utils as file_utils 35 | from diffusion.noise_schedule import * 36 | from diffusion.time_sampler import LossEMASampler 37 | from audio_datasets import LibriSpeech, ENCODEC_SAMPLING_RATE, LATENT_SAMPLING_RATE, ENCODEC_REDUCTION_FACTOR 38 | from neural_codec.encodec_wrapper import EncodecWrapper 39 | from utils.utils import get_output_dir 40 | 41 | 42 | ModelPrediction = namedtuple('ModelPrediction', ['pred_eps', 'pred_x_start', 'pred_v']) 43 | 44 | # Recommendation from https://arxiv.org/abs/2303.09556 45 | MIN_SNR_GAMMA = 5 46 | 47 | # helpers functions 48 | 49 | def exists(x): 50 | return x is not None 51 | 52 | def default(val, d): 53 | if exists(val): 54 | return val 55 | return d() if callable(d) else d 56 | 57 | def identity(t, *args, **kwargs): 58 | return t 59 | 60 | def cycle(dl): 61 | while True: 62 | for data in dl: 63 | yield data 64 | 65 | def num_to_groups(num, divisor): 66 | groups = num // divisor 67 | remainder = num % divisor 68 | arr = [divisor] * groups 69 | if remainder > 0: 70 | arr.append(remainder) 71 | return arr 72 | 73 | def l2norm(t): 74 | return F.normalize(t, dim = -1) 75 | 76 | # Avoid log(0) 77 | def log(t, eps = 1e-20): 78 | return torch.log(t.clamp(min = eps)) 79 | 80 | def right_pad_dims_to(x, t): 81 | padding_dims = x.ndim - t.ndim 82 | if padding_dims <= 0: 83 | return t 84 | return t.view(*t.shape, *((1,) * padding_dims)) 85 | 86 | # gaussian diffusion trainer class 87 | 88 | def extract(a, t, x_shape): 89 | b, *_ = t.shape 90 | out = a.gather(-1, t) 91 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 92 | 93 | def set_seeds(seed): 94 | random.seed(seed) 95 | torch.manual_seed(seed) 96 | torch.cuda.manual_seed(seed) 97 | 98 | def masked_mean(t, *, dim, mask = None): 99 | if not exists(mask): 100 | return t.mean(dim = dim) 101 | 102 | denom = mask.sum(dim = dim) 103 | masked_t = t.masked_fill(~mask, 0.) 
104 | 105 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 106 | 107 | 108 | class GaussianDiffusion(nn.Module): 109 | def __init__( 110 | self, 111 | model, 112 | *, 113 | max_seq_len, 114 | sampling_timesteps = 250, 115 | text_encoder = 'google/byt5-small', 116 | loss_type = 'l1', 117 | objective = 'pred_v', 118 | parameterization = 'pred_v', 119 | train_schedule = 'cosine', 120 | ema_decay = 0.9999, 121 | sampling_schedule = None, 122 | scale = 1., 123 | scale_by_std=True, 124 | sampler = 'ddim', 125 | log_var_interp = 0.2, 126 | langevin_step_size = 0.0, 127 | snr_ddim_threshold= None, 128 | unconditional_prob = 0.2, 129 | inpainting_prob = 0.5, 130 | inpainting_duration_mode = 0.01, 131 | inpainting_duration_concentration = 5, 132 | loss_weighting = 'edm', 133 | loss_weighting_args = {}, 134 | ): 135 | super().__init__() 136 | 137 | self.denoising_network = EMA(model, beta = ema_decay, update_every = 1, power=3/4) 138 | 139 | self.max_seq_len = max_seq_len 140 | 141 | self.objective = objective 142 | self.parameterization = parameterization 143 | self.sampler=sampler 144 | self.log_var_interp = log_var_interp 145 | self.langevin_step_size = langevin_step_size 146 | self.snr_ddim_threshold = snr_ddim_threshold 147 | 148 | # Min-SNR weighting from https://arxiv.org/abs/2303.09556 149 | # Option no longer supported; buffer kept for backwards compatibility with statedict of prior checkpoints 150 | self.register_buffer('min_snr_gamma', torch.tensor(MIN_SNR_GAMMA)) 151 | 152 | self.loss_type = loss_type 153 | 154 | assert objective == 'pred_v', f'objective {objective} must be pred_v' 155 | 156 | self.adaptive_noise_schedule = (train_schedule == 'adaptive') 157 | if self.adaptive_noise_schedule: 158 | self.train_schedule = LossEMASampler() 159 | self.val_ema_sampler = LossEMASampler(ema_decay=0.0) 160 | if loss_weighting == 'v_weighting': 161 | assert objective == 'pred_v', f'objective {objective} must be pred_v for v weighting' 162 | self.diffusion_loss_weighting_obj = V_Weighting() 163 | self.diffusion_loss_weighting = self.diffusion_loss_weighting_obj.v_loss_weighting 164 | elif loss_weighting == 'lognormal_v_weighting': 165 | gamma_mean = loss_weighting_args['loss_weighting_mean'] 166 | gamma_std = loss_weighting_args['loss_weighting_std'] 167 | assert objective == 'pred_v', f'objective {objective} must be pred_v for lognormal v weighting' 168 | self.diffusion_loss_weighting_obj = LogNormal_V_Weighting(gamma_mean=gamma_mean, gamma_std=gamma_std, objective=objective) 169 | self.diffusion_loss_weighting = self.diffusion_loss_weighting_obj.v_loss_weighting 170 | elif loss_weighting == 'asymmetric_lognormal_v_weighting': 171 | gamma_mean = loss_weighting_args['loss_weighting_mean'] 172 | gamma_std = loss_weighting_args['loss_weighting_std'] 173 | gamma_std_mult = loss_weighting_args['loss_weighting_std_mult'] 174 | assert objective == 'pred_v', f'objective {objective} must be pred_v for lognormal v weighting' 175 | self.diffusion_loss_weighting_obj = Asymmetric_LogNormal_V_Weighting(gamma_mean=gamma_mean, gamma_std=gamma_std, objective=objective, std_mult=gamma_std_mult) 176 | self.diffusion_loss_weighting = self.diffusion_loss_weighting_obj.v_loss_weighting 177 | else: 178 | raise ValueError(f'invalid loss weighting {loss_weighting}') 179 | else: 180 | self.logsnr_loss_tracker = LossEMASampler() 181 | if train_schedule == "simple_linear": 182 | alpha_schedule = simple_linear_schedule 183 | elif train_schedule == "beta_linear": 184 | alpha_schedule = beta_linear_schedule 185 | elif 
train_schedule == "cosine": 186 | alpha_schedule = cosine_schedule 187 | elif train_schedule == "sigmoid": 188 | alpha_schedule = sigmoid_schedule 189 | else: 190 | raise ValueError(f'invalid noise schedule {train_schedule}') 191 | 192 | self.train_schedule = partial(time_to_alpha, alpha_schedule=alpha_schedule, scale=scale) 193 | 194 | # Sampling schedule 195 | if sampling_schedule is None: 196 | sampling_alpha_schedule = None 197 | elif sampling_schedule == "simple_linear": 198 | sampling_alpha_schedule = simple_linear_schedule 199 | elif sampling_schedule == "beta_linear": 200 | sampling_alpha_schedule = beta_linear_schedule 201 | elif sampling_schedule == "cosine": 202 | sampling_alpha_schedule = cosine_schedule 203 | elif sampling_schedule == "sigmoid": 204 | sampling_alpha_schedule = sigmoid_schedule 205 | else: 206 | raise ValueError(f'invalid sampling schedule {sampling_schedule}') 207 | 208 | if exists(sampling_alpha_schedule): 209 | self.sampling_schedule = partial(time_to_alpha, alpha_schedule=sampling_alpha_schedule, scale=scale) 210 | else: 211 | self.sampling_schedule = self.train_schedule 212 | 213 | 214 | # Optionally rescale data to have unit variance 215 | self.scale_by_std = scale_by_std 216 | if scale_by_std: 217 | self.register_buffer('data_mean', torch.full((128,), fill_value=0.0)) 218 | self.register_buffer('data_std', torch.full((128,), fill_value=-1.0)) 219 | else: 220 | raise NotImplementedError 221 | 222 | # gamma schedules 223 | 224 | self.sampling_timesteps = sampling_timesteps 225 | 226 | # probability for self conditioning during training 227 | 228 | self.unconditional_prob = unconditional_prob 229 | self.inpainting_prob = inpainting_prob 230 | 231 | if self.unconditional_prob > 0: 232 | self.unconditional_bernoulli = torch.distributions.Bernoulli(probs=self.unconditional_prob) 233 | if self.inpainting_prob > 0: 234 | self.inpainting_bernoulli = torch.distributions.Bernoulli(probs=self.inpainting_prob) 235 | # Mode/Concentration parameterization of Beta distribution 236 | alpha = inpainting_duration_mode*(inpainting_duration_concentration-2) + 1 237 | beta = (1-inpainting_duration_mode)*(inpainting_duration_concentration-2)+1 238 | self.inpainting_duration_beta = torch.distributions.Beta(alpha, beta) 239 | 240 | self.text_encoder_id = text_encoder 241 | self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder, torch_dtype=torch.bfloat16).get_encoder() 242 | for param in self.text_encoder.parameters(): 243 | param.requires_grad = False 244 | self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder) 245 | 246 | self.audio_codec = EncodecWrapper() 247 | for param in self.audio_codec.parameters(): 248 | param.requires_grad = False 249 | 250 | 251 | def predict_start_from_noise(self, z_t, alpha, noise, sampling=False): 252 | alpha = right_pad_dims_to(z_t, alpha) 253 | 254 | return (z_t - (1-alpha).sqrt() * noise) / alpha.sqrt().clamp(min = 1e-8) 255 | 256 | def predict_noise_from_start(self, z_t, alpha, x0, sampling=False): 257 | alpha = right_pad_dims_to(z_t, alpha) 258 | 259 | return (z_t - alpha.sqrt() * x0) / (1-alpha).sqrt().clamp(min = 1e-8) 260 | 261 | def predict_start_from_v(self, z_t, alpha, v, sampling=False): 262 | alpha = right_pad_dims_to(z_t, alpha) 263 | 264 | x = alpha.sqrt() * z_t - (1-alpha).sqrt() * v 265 | 266 | return x 267 | 268 | def predict_noise_from_v(self, z_t, alpha, v, sampling=False): 269 | alpha = right_pad_dims_to(z_t, alpha) 270 | 271 | eps = (1-alpha).sqrt() * z_t + alpha.sqrt() * v 272 | 273 | return 
eps 274 | 275 | def predict_v_from_start_and_eps(self, z_t, alpha, x, noise, sampling=False): 276 | alpha = right_pad_dims_to(z_t, alpha) 277 | 278 | v = alpha.sqrt() * noise - x* (1-alpha).sqrt() 279 | 280 | return v 281 | 282 | def diffusion_model_predictions(self, z_t, alpha, *, text_cond, text_cond_mask, audio_mask, sampling=False, cls_free_guidance=1.0, fill_mask=None): 283 | time_cond = alpha.sqrt() 284 | inpainting_mask = fill_mask[:, 0, :].long() 285 | if sampling: 286 | model_output = self.denoising_network.ema_model(z_t, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask, inpainting_mask=inpainting_mask, audio_mask=audio_mask) 287 | if cls_free_guidance != 1.0: 288 | unc_text_cond = torch.zeros_like(text_cond)[:,:1,:] 289 | unc_text_cond_mask = torch.full_like(text_cond_mask, fill_value=False)[:,:1] 290 | if torch.sum(fill_mask) > 0: 291 | alpha = rearrange(alpha, 'b -> b () ()') 292 | noise = torch.randn_like(z_t) 293 | z_t[fill_mask] = (z_t*alpha.sqrt() + (1-alpha).sqrt()*noise)[fill_mask] 294 | unc_inpainting_mask = torch.full_like(inpainting_mask, fill_value=0) 295 | unc_model_output = self.denoising_network.ema_model(z_t, time_cond, text_cond=unc_text_cond, text_cond_mask=unc_text_cond_mask, inpainting_mask=unc_inpainting_mask, audio_mask=audio_mask) 296 | model_output = model_output*cls_free_guidance + unc_model_output*(1-cls_free_guidance) 297 | else: 298 | model_output = self.denoising_network.online_model(z_t, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask, inpainting_mask=inpainting_mask, audio_mask=audio_mask) 299 | 300 | if self.parameterization == 'pred_v': 301 | pred_v = model_output 302 | x_start = self.predict_start_from_v(z_t, alpha, pred_v, sampling=sampling) 303 | pred_eps = self.predict_noise_from_v(z_t, alpha, pred_v, sampling=sampling) 304 | else: 305 | raise ValueError(f'invalid objective {self.parameterization}') 306 | 307 | return ModelPrediction(pred_eps, x_start, pred_v) 308 | 309 | def get_sampling_timesteps(self, batch, *, device, start_time=1.0): 310 | times = torch.linspace(start_time, 0., self.sampling_timesteps + 1, device = device) 311 | times = repeat(times, 't -> b t', b = batch) 312 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) 313 | times = times.unbind(dim = -1) 314 | return times 315 | 316 | @torch.no_grad() 317 | def ddim_or_ddpm_sample(self, shape, text_cond, text_cond_mask, audio_mask, prefix_seconds=0, audio_latent=None, cls_free_guidance=1.0, speaker_frames=None, sampler='ddim', log_var_interp=.2, langevin_step_size=0.0, snr_ddim_threshold=None): 318 | batch, device = shape[0], next(self.denoising_network.ema_model.parameters()).device 319 | 320 | time_pairs = self.get_sampling_timesteps(batch, device = device) 321 | 322 | z_t = torch.randn(shape, device=device) 323 | 324 | fill_mask = None 325 | 326 | if prefix_seconds > 0: 327 | assert exists(audio_latent) 328 | if exists(speaker_frames): 329 | num_inpainting_frames = speaker_frames 330 | else: 331 | num_inpainting_frames = round(prefix_seconds*LATENT_SAMPLING_RATE) 332 | torch.full((batch), fill_value=num_inpainting_frames, dtype=torch.int, device=device) 333 | 334 | indices = torch.arange(z_t.shape[2], device=device) 335 | 336 | # Construct mask to insert clean data 337 | fill_mask = repeat((indices <= num_inpainting_frames[:, None]), 'b l -> b c l', c=z_t.shape[1]) 338 | else: 339 | fill_mask = torch.full_like(z_t, fill_value=0, dtype=torch.bool) 340 | 341 | x_start = None 342 | 343 | for time, time_next in tqdm(time_pairs, desc = 
'sampling loop time step', total = self.sampling_timesteps): 344 | # get predicted x0 345 | if prefix_seconds > 0: 346 | z_t[fill_mask] = audio_latent[fill_mask] 347 | if exists(x_start): 348 | x_start[fill_mask] = audio_latent[fill_mask] 349 | 350 | # get alpha sigma of time and next time 351 | 352 | alpha = self.sampling_schedule(time) 353 | alpha_next = self.sampling_schedule(time_next) 354 | 355 | model_output = self.diffusion_model_predictions(z_t, alpha, text_cond=text_cond, text_cond_mask=text_cond_mask, audio_mask=audio_mask, sampling=True, cls_free_guidance=cls_free_guidance, fill_mask=fill_mask) 356 | 357 | alpha, alpha_next = map(partial(right_pad_dims_to, z_t), (alpha, alpha_next)) 358 | 359 | # calculate x0 and noise 360 | 361 | x_start = model_output.pred_x_start 362 | 363 | eps = model_output.pred_eps 364 | 365 | if time_next[0] <= 0: 366 | z_t = x_start 367 | continue 368 | 369 | # get noise 370 | snr = alpha_to_shifted_log_snr(alpha_next)[0].item() 371 | if sampler == 'ddim' or (exists(snr_ddim_threshold) and snr > snr_ddim_threshold): 372 | z_t = x_start * alpha_next.sqrt() + eps * (1-alpha_next).sqrt() 373 | elif sampler == 'ddpm': 374 | # get noise 375 | noise = torch.randn_like(z_t) 376 | alpha_now = alpha/alpha_next 377 | 378 | min_var = torch.exp(torch.log1p(-alpha_next) - torch.log1p(-alpha)) * (1.0 -alpha_now) 379 | max_var = (1.0 - alpha_now) 380 | noise_param = log_var_interp 381 | sigma = torch.exp(noise_param * torch.log(max_var) + (1 - noise_param) * torch.log(min_var) ) 382 | z_t = 1/alpha_now.sqrt() * (z_t - (1-alpha_now)/(1-alpha).sqrt() * eps) + torch.sqrt(sigma) * noise 383 | 384 | if langevin_step_size > 0: 385 | if prefix_seconds > 0: 386 | z_t[fill_mask] = audio_latent[fill_mask] 387 | if exists(x_start): 388 | x_start[fill_mask] = audio_latent[fill_mask] 389 | alpha_next = self.sampling_schedule(time_next) 390 | model_output = self.diffusion_model_predictions(z_t, alpha_next, text_cond=text_cond, text_cond_mask=text_cond_mask, audio_mask=audio_mask, sampling=True, cls_free_guidance=cls_free_guidance, fill_mask=fill_mask) 391 | alpha_next = right_pad_dims_to(z_t, alpha_next) 392 | noise = torch.randn_like(z_t) 393 | eps = model_output.pred_eps 394 | z_t = z_t - .5*langevin_step_size*(1-alpha_next).sqrt()*eps + math.sqrt(langevin_step_size)*(1-alpha_next).sqrt()*noise 395 | if prefix_seconds > 0: 396 | z_t[fill_mask] = audio_latent[fill_mask] 397 | return z_t 398 | 399 | def normalize_audio_latent(self, audio_latent): 400 | return (audio_latent - rearrange(self.data_mean, 'c -> () c ()')) / rearrange(self.data_std, 'c -> () c ()') 401 | 402 | def unnormalize_audio_latent(self, audio_latent): 403 | return audio_latent * rearrange(self.data_std, 'c -> () c ()') + rearrange(self.data_mean, 'c -> () c ()') 404 | 405 | @torch.no_grad() 406 | def sample(self, data, prefix_seconds=0, cls_free_guidance=1.0): 407 | # [B, L, d_lm]: Embedded text 408 | if prefix_seconds > 0: 409 | merged_text = [' '.join((speaker_text, text)) for speaker_text, text in zip(data['speaker_text'], data['text'])] 410 | tokenizer_output = self.text_tokenizer(merged_text, padding="max_length", truncation=True, max_length=256, return_tensors='pt').to(data['wav'].device) 411 | text_cond = self.text_encoder(tokenizer_output['input_ids'], tokenizer_output['attention_mask']).last_hidden_state.float() 412 | text_cond_mask = tokenizer_output['attention_mask'].bool() 413 | audio_latent = self.audio_codec.encode(data['speaker_wav']) 414 | speaker_frames = 
torch.floor(data['speaker_audio_duration'] * LATENT_SAMPLING_RATE).int() 415 | audio_mask = data['inpaint_audio_mask'] 416 | else: 417 | text_cond = self.text_encoder(data['input_ids'], data['attention_mask']).last_hidden_state.float() 418 | text_cond_mask = data['attention_mask'].bool() 419 | # [B, d_audio, L] 420 | audio_latent = self.audio_codec.encode(data['wav']) 421 | speaker_frames = None 422 | audio_mask = data['audio_mask'] 423 | audio_latent = self.normalize_audio_latent(audio_latent) 424 | latent_shape = audio_latent.shape 425 | assert self.sampler in {'ddim', 'ddpm'} 426 | sampled_audio_latent = self.ddim_or_ddpm_sample(latent_shape, text_cond, text_cond_mask, prefix_seconds=prefix_seconds, audio_latent=audio_latent, cls_free_guidance=cls_free_guidance, speaker_frames=speaker_frames, audio_mask=audio_mask, sampler=self.sampler, log_var_interp=self.log_var_interp, langevin_step_size=self.langevin_step_size, snr_ddim_threshold=self.snr_ddim_threshold) 427 | return self.unnormalize_audio_latent(sampled_audio_latent) 428 | 429 | @property 430 | def loss_fn(self): 431 | if self.loss_type == 'l1': 432 | return F.l1_loss 433 | elif self.loss_type == 'l2': 434 | return F.mse_loss 435 | else: 436 | raise ValueError(f'invalid loss type {self.loss_type}') 437 | 438 | def inpainting_enabled(self): 439 | return self.inpainting_prob > 0 440 | 441 | def get_loss_emas(self, split='train'): 442 | assert self.adaptive_noise_schedule, 'EMA loss only available with adaptive noise schedule' 443 | if split == 'train': 444 | return self.train_schedule.get_loss_emas() 445 | elif split == 'val': 446 | return self.val_ema_sampler.get_loss_emas() 447 | else: 448 | raise ValueError(f'invalid split {split}') 449 | 450 | def get_normalized_loss_emas(self, split='train'): 451 | assert self.adaptive_noise_schedule, 'EMA loss only available with adaptive noise schedule' 452 | if split == 'train': 453 | return self.train_schedule.get_normalized_loss_emas() 454 | elif split == 'val': 455 | return self.val_ema_sampler.get_normalized_loss_emas() 456 | else: 457 | raise ValueError(f'invalid split {split}') 458 | 459 | def get_unweighted_loss_emas(self, split='train'): 460 | if self.adaptive_noise_schedule: 461 | loss_ema = self.train_schedule 462 | else: 463 | assert split == 'train', 'EMA loss only available for train split without adaptive noise schedule' 464 | loss_ema = self.logsnr_loss_tracker 465 | if split == 'train': 466 | return loss_ema.get_unweighted_loss_emas() 467 | elif split == 'val': 468 | return self.val_ema_sampler.get_unweighted_loss_emas() 469 | else: 470 | raise ValueError(f'invalid split {split}') 471 | 472 | def save_adaptive_noise_schedule(self, path): 473 | if not self.adaptive_noise_schedule: 474 | ema_unweighted_loss_path = os.path.join(path, f'ema_unweighted_loss.png') 475 | self.logsnr_loss_tracker.save_unweighted_loss_emas(path=ema_unweighted_loss_path) 476 | return 477 | # save train plots 478 | density_path = os.path.join(path, f'schedule_density.png') 479 | self.train_schedule.save_density(path=density_path) 480 | cdf_path = os.path.join(path, f'schedule_cdf.png') 481 | self.train_schedule.save_cumulative_density(path=cdf_path) 482 | ema_loss_path = os.path.join(path, f'ema_loss.png') 483 | self.train_schedule.save_loss_emas(path=ema_loss_path) 484 | ema_unweighted_loss_path = os.path.join(path, f'ema_unweighted_loss.png') 485 | self.train_schedule.save_unweighted_loss_emas(path=ema_unweighted_loss_path) 486 | 487 | # save val plots 488 | density_path = os.path.join(path, 
f'schedule_density_val.png') 489 | self.val_ema_sampler.save_density(path=density_path) 490 | cdf_path = os.path.join(path, f'schedule_cdf_val.png') 491 | self.val_ema_sampler.save_cumulative_density(path=cdf_path) 492 | ema_loss_path = os.path.join(path, f'ema_loss_val.png') 493 | self.val_ema_sampler.save_loss_emas(path=ema_loss_path) 494 | ema_unweighted_loss_path = os.path.join(path, f'ema_unweighted_loss_val.png') 495 | self.val_ema_sampler.save_unweighted_loss_emas(path=ema_unweighted_loss_path) 496 | 497 | def adaptive_noise_schedule_enabled(self): 498 | return self.adaptive_noise_schedule 499 | 500 | def forward(self, data, accelerator=None): 501 | 502 | with torch.no_grad(): 503 | # [B, L, d_lm]: Embedded text 504 | text_cond = self.text_encoder(data['input_ids'], data['attention_mask']).last_hidden_state.float() 505 | # [B, L]: Cross-attn mask 506 | text_cond_mask = data['attention_mask'].bool() 507 | 508 | 509 | # [B, d_audio, L]: embedded audio 510 | with torch.cuda.amp.autocast(enabled=False): 511 | audio_latent = self.audio_codec.encode(data['wav']) 512 | 513 | if self.scale_by_std: 514 | # Estimate standard deviation of the data from the first batch 515 | if self.data_std[0].item() < 0: 516 | gathered_audio_latent = accelerator.gather(audio_latent) 517 | gathered_audio_mask = accelerator.gather(data['audio_mask']) 518 | min_length = gathered_audio_mask.sum(1).min().item() 519 | 520 | # Compute mean and std of the data 521 | data_mean = reduce(gathered_audio_latent[:,:,:min_length], 'b c l -> c', 'mean') 522 | data_std = reduce(gathered_audio_latent[:,:,:min_length], 'b c l -> c', torch.std) 523 | self.data_mean = data_mean 524 | self.data_std = data_std 525 | 526 | print(f'Set data mean: {self.data_mean.tolist()}') 527 | print(f'Set data std: {self.data_std.tolist()}') 528 | 529 | audio_latent = self.normalize_audio_latent(audio_latent) 530 | 531 | batch, audio_channels, audio_length = audio_latent.shape 532 | device = audio_latent.device 533 | 534 | # Mask out text-conditioning with some probability to enable clf-free guidance 535 | if self.unconditional_prob > 0: 536 | unconditional_mask = self.unconditional_bernoulli.sample((batch,)).bool() 537 | text_cond_mask[unconditional_mask, :] = False 538 | 539 | # sample random times 540 | 541 | if self.adaptive_noise_schedule and self.training: 542 | gamma, schedule_density = self.train_schedule.sample( 543 | batch_size=batch, device=device) 544 | alpha = log_snr_to_alpha(gamma) 545 | elif self.adaptive_noise_schedule: 546 | gamma, schedule_density = self.train_schedule.sample( 547 | batch_size=batch, device=device, uniform=True) 548 | alpha = log_snr_to_alpha(gamma) 549 | else: 550 | times = torch.zeros((batch,), device = device).float().uniform_(0, 1.) 
551 | alpha = self.train_schedule(times) 552 | gamma = alpha_to_shifted_log_snr(alpha, scale=1) 553 | padded_alpha = right_pad_dims_to(audio_latent, alpha) 554 | 555 | # noise sample 556 | noise = torch.randn_like(audio_latent) 557 | 558 | z_t = padded_alpha.sqrt() * audio_latent + (1-padded_alpha).sqrt() * noise 559 | 560 | # Inpainting logic 561 | inpainting_mask = None 562 | if self.inpainting_prob > 0: 563 | inpainting_batch_mask = self.inpainting_bernoulli.sample((batch,)).bool().to(device) 564 | # Sample durations to mask 565 | inpainting_durations = self.inpainting_duration_beta.sample((batch,)).to(device) * data['audio_duration'] 566 | num_inpainting_frames = torch.round(inpainting_durations*LATENT_SAMPLING_RATE).int() 567 | 568 | # Sample where to mask 569 | indices = torch.arange(audio_length, device=device) 570 | 571 | # Construct mask to insert clean data 572 | inpainting_length_mask = ((indices <= num_inpainting_frames[:, None])) 573 | inpainting_mask = (inpainting_length_mask) & inpainting_batch_mask.unsqueeze(-1) 574 | fill_mask = repeat(inpainting_mask, 'b l -> b c l', c=audio_channels) 575 | 576 | z_t[fill_mask] = audio_latent[fill_mask] 577 | else: 578 | fill_mask = torch.full_like(z_t, fill_value=0, dtype=torch.bool) 579 | 580 | # predict and take gradient step 581 | predictions = self.diffusion_model_predictions(z_t, alpha, text_cond=text_cond, text_cond_mask=text_cond_mask, fill_mask=fill_mask, audio_mask=data['audio_mask']) 582 | 583 | 584 | if self.objective == 'pred_eps': 585 | target = noise 586 | pred = predictions.pred_eps 587 | elif self.objective == 'pred_v': 588 | # V-prediction from https://openreview.net/forum?id=TIdIXIpzhoI 589 | velocity = padded_alpha.sqrt() * noise - (1-padded_alpha).sqrt() * audio_latent 590 | target = velocity 591 | pred = predictions.pred_v 592 | else: 593 | raise ValueError(f'invalid objective {self.objective}') 594 | 595 | loss = self.loss_fn(pred, target, reduction = 'none') 596 | loss = reduce(loss, 'b c l -> b l', 'mean') 597 | if self.inpainting_prob > 0: 598 | # Need to align with gamma 599 | # Standard diffusion loss 600 | diff_batch = loss[~inpainting_batch_mask] 601 | diff_loss = masked_mean(diff_batch, dim=1, mask=data['audio_mask'][~inpainting_batch_mask]) 602 | 603 | # Masked inpainting loss 604 | inpainting_batch = loss[inpainting_batch_mask] 605 | loss_mask = torch.logical_and((~inpainting_length_mask[inpainting_batch_mask]), data['audio_mask'][inpainting_batch_mask]) 606 | inpainting_loss = masked_mean(inpainting_batch, dim=1, mask=loss_mask) 607 | loss = torch.cat([diff_loss, inpainting_loss], dim=0) 608 | 609 | gamma_diff_batch = gamma[~inpainting_batch_mask] 610 | gamma_inpainting_batch = gamma[inpainting_batch_mask] 611 | gamma = torch.cat([gamma_diff_batch, gamma_inpainting_batch], dim=0) 612 | 613 | if self.adaptive_noise_schedule: 614 | schedule_density_diff_batch = schedule_density[~inpainting_batch_mask] 615 | schedule_density_inpainting_batch = schedule_density[inpainting_batch_mask] 616 | schedule_density = torch.cat([schedule_density_diff_batch, schedule_density_inpainting_batch], dim=0) 617 | else: 618 | loss = masked_mean(loss, dim=1, mask=data['audio_mask']) 619 | 620 | if not self.adaptive_noise_schedule: 621 | self.logsnr_loss_tracker.update_with_all_unweighted_losses(gamma.squeeze(), loss) 622 | return loss.mean() 623 | 624 | diffusion_loss_weighting = self.diffusion_loss_weighting(gamma=gamma).squeeze() 625 | weighted_loss = diffusion_loss_weighting * loss 626 | # Update loss ema 627 | # Disable 
autocast for this step 628 | with torch.cuda.amp.autocast(enabled=False): 629 | if self.training: 630 | self.train_schedule.update_with_all_losses(gamma.squeeze(), weighted_loss) 631 | self.train_schedule.update_with_all_unweighted_losses(gamma.squeeze(), loss) 632 | else: 633 | self.val_ema_sampler.update_with_all_losses(gamma.squeeze(), weighted_loss) 634 | self.val_ema_sampler.update_with_all_unweighted_losses(gamma.squeeze(), loss) 635 | monte_carlo_weighted_loss = torch.exp(torch.log(diffusion_loss_weighting) - torch.log(schedule_density))*loss 636 | return (monte_carlo_weighted_loss).mean() 637 | 638 | # trainer class 639 | 640 | class Trainer(object): 641 | def __init__( 642 | self, 643 | args, 644 | diffusion, 645 | dataset_name, 646 | *, 647 | optimizer = 'adamw', 648 | batch_size = 16, 649 | gradient_accumulate_every = 1, 650 | train_lr = 1e-4, 651 | train_num_steps = 100000, 652 | lr_schedule = 'cosine', 653 | num_warmup_steps = 500, 654 | adam_betas = (0.9, 0.999), 655 | adam_weight_decay = 0.01, 656 | save_and_sample_every = 5000, 657 | num_samples = 25, 658 | mixed_precision = 'no', 659 | prefix_inpainting_seconds=0, 660 | duration_path=None, 661 | seed=None 662 | ): 663 | super().__init__() 664 | 665 | assert prefix_inpainting_seconds in {0., 3.0} or args.kilian, 'Currently only supports 3sec for inpainting' 666 | if exists(seed): 667 | set_seeds(seed) 668 | 669 | self.args = args 670 | 671 | self.accelerator = Accelerator( 672 | mixed_precision = mixed_precision, 673 | log_with='wandb', 674 | ) 675 | self.num_devices = self.accelerator.num_processes 676 | args.num_devices = self.num_devices 677 | 678 | args.output_dir = get_output_dir(args) 679 | 680 | if self.accelerator.is_main_process: 681 | os.makedirs(args.output_dir) 682 | print(f'Created {args.output_dir}') 683 | 684 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f: 685 | json.dump(args.__dict__, f, indent=2) 686 | run = os.path.split(__file__)[-1].split(".")[0] 687 | self.accelerator.init_trackers(run, config=vars(args), init_kwargs={"wandb": {"dir": args.output_dir, "name": args.run_name}}) 688 | 689 | 690 | self.diffusion = diffusion 691 | 692 | self.num_samples = num_samples 693 | self.save_and_sample_every = save_and_sample_every 694 | self.prefix_inpainting_seconds = prefix_inpainting_seconds 695 | 696 | self.batch_size = batch_size 697 | self.gradient_accumulate_every = gradient_accumulate_every 698 | 699 | self.train_num_steps = train_num_steps 700 | self.max_seq_len = diffusion.max_seq_len 701 | 702 | 703 | # dataset and dataloader 704 | if dataset_name == 'librispeech': 705 | self.dataset = LibriSpeech(split='train', tokenizer=diffusion.text_tokenizer) 706 | self.val_dataset = LibriSpeech(split='valid', tokenizer=diffusion.text_tokenizer) 707 | self.test_dataset = LibriSpeech(split='test', tokenizer=diffusion.text_tokenizer, max_seq_len=self.dataset.max_seq_len, duration_path=duration_path) 708 | else: 709 | raise ValueError(f'invalid dataset: {dataset_name}') 710 | 711 | self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=2) 712 | self.val_dataloader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False) 713 | self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False) 714 | 715 | 716 | # optimizer 717 | 718 | if optimizer == 'adamw': 719 | self.opt = get_adamw_optimizer(diffusion.parameters(), lr = train_lr, betas = adam_betas, 
weight_decay=adam_weight_decay) 720 | elif optimizer == 'lion': 721 | self.opt = get_lion_optimizer(diffusion.parameters(), lr = train_lr, weight_decay=adam_weight_decay) 722 | else: 723 | raise ValueError(f'invalid optimizer {optimizer}') 724 | 725 | # scheduler 726 | 727 | lr_scheduler = get_scheduler( 728 | lr_schedule, 729 | optimizer=self.opt, 730 | num_warmup_steps=num_warmup_steps*self.num_devices, 731 | num_training_steps=train_num_steps*self.num_devices, # Accelerate does num_devices steps at a time 732 | ) 733 | 734 | # for logging results in a folder periodically 735 | 736 | if self.accelerator.is_main_process: 737 | 738 | self.results_folder = args.output_dir 739 | 740 | # step counter state 741 | 742 | self.step = 0 743 | 744 | # prepare model, dataloader, optimizer with accelerator 745 | 746 | self.diffusion, self.opt, self.dataloader, self.val_dataloader, self.test_dataloader, self.lr_scheduler = self.accelerator.prepare(self.diffusion, self.opt, self.dataloader, self.val_dataloader, self.test_dataloader, lr_scheduler) 747 | self.data_iter = cycle(self.dataloader) 748 | self.val_data_iter = cycle(self.val_dataloader) 749 | 750 | def save(self, save_step=False): 751 | if not self.accelerator.is_local_main_process: 752 | return 753 | 754 | data = { 755 | 'step': self.step, 756 | 'model': self.accelerator.get_state_dict(self.diffusion), 757 | 'opt': self.opt.state_dict(), 758 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, 759 | 'scheduler': self.lr_scheduler.state_dict(), 760 | } 761 | if save_step: 762 | torch.save(data, f'{self.results_folder}/model_{self.step}.pt') 763 | else: 764 | torch.save(data, f'{self.results_folder}/model.pt') 765 | 766 | def load(self, file_path=None, best=False, init_only=False, ckpt_step=None): 767 | file_path = file_path if exists(file_path) else self.results_folder 768 | accelerator = self.accelerator 769 | device = accelerator.device 770 | 771 | if ckpt_step is not None: 772 | data = torch.load(f'{file_path}/model_{ckpt_step}.pt', map_location=device) 773 | else: 774 | data = torch.load(f'{file_path}/model.pt', map_location=device) 775 | 776 | model = self.accelerator.unwrap_model(self.diffusion) 777 | strict_load = not (init_only) 778 | model.load_state_dict(data['model'], strict=strict_load) 779 | 780 | if init_only: 781 | return 782 | 783 | # For backwards compatibility with earlier models 784 | if exists(self.accelerator.scaler) and exists(data['scaler']): 785 | self.accelerator.scaler.load_state_dict(data['scaler']) 786 | 787 | self.opt.load_state_dict(data['opt']) 788 | 789 | self.step = data['step'] 790 | self.lr_scheduler.load_state_dict(data['scheduler']) 791 | 792 | 793 | @torch.no_grad() 794 | def sample(self, num_samples=None, seed=None, cls_free_guidance=1.0, test=False, prefix_seconds=0.): 795 | if exists(seed): 796 | set_seeds(seed) 797 | diffusion = self.accelerator.unwrap_model(self.diffusion) 798 | num_samples = default(num_samples, self.num_samples) 799 | self.diffusion.eval() 800 | num_sampled = 0 801 | dataloader = self.test_dataloader if test else self.val_dataloader 802 | for batch in dataloader: 803 | sampled_codec_latents = diffusion.sample(batch, prefix_seconds=prefix_seconds, cls_free_guidance=cls_free_guidance) 804 | sampled_wavs = diffusion.audio_codec.decode(sampled_codec_latents).squeeze() # [B, L] 805 | 806 | sampled_wavs = self.accelerator.gather_for_metrics(sampled_wavs).to('cpu') 807 | 808 | input_ids = 
self.accelerator.gather_for_metrics(batch['input_ids']).to('cpu') 809 | 810 | speaker_durations = self.accelerator.gather_for_metrics(batch['speaker_audio_duration']).to('cpu') 811 | 812 | if prefix_seconds > 0: 813 | num_audio_frames = batch['inpaint_audio_mask'].sum(dim=1)*ENCODEC_REDUCTION_FACTOR # [B] 814 | num_audio_frames = self.accelerator.gather_for_metrics(num_audio_frames).to('cpu') 815 | else: 816 | num_audio_frames = batch['audio_mask'].sum(dim=1)*ENCODEC_REDUCTION_FACTOR # [B] 817 | num_audio_frames = self.accelerator.gather_for_metrics(num_audio_frames).to('cpu') 818 | 819 | 820 | if self.accelerator.is_main_process: 821 | inpainting_suffix = f'_prefix{prefix_seconds}' if prefix_seconds>0 else '' 822 | step_folder = os.path.join(self.results_folder, 'samples', f'step_{self.step}') 823 | samples_folder = os.path.join(step_folder, f'guide{cls_free_guidance}{inpainting_suffix}') 824 | os.makedirs(samples_folder, exist_ok=True) 825 | text_list = [diffusion.text_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in input_ids] 826 | ref_frames = torch.ceil(speaker_durations*ENCODEC_SAMPLING_RATE).int() 827 | for idx in range(len(text_list)): 828 | text = text_list[idx] 829 | print(f'Saving idx: {idx+num_sampled}') 830 | with open(os.path.join(samples_folder, f'text_{idx+num_sampled}.txt'), 'w') as f: 831 | print(text, file=f) 832 | if prefix_seconds > 0: 833 | ref_frames_idx = ref_frames[idx].item() 834 | sf.write(os.path.join(samples_folder, f'audio_{idx+num_sampled}.wav'), sampled_wavs[idx][ref_frames_idx:num_audio_frames[idx].item()], ENCODEC_SAMPLING_RATE) 835 | sf.write(os.path.join(samples_folder, f'ref_{idx+num_sampled}.wav'), sampled_wavs[idx][:ref_frames_idx], ENCODEC_SAMPLING_RATE) 836 | else: 837 | sample_wav = sampled_wavs[idx][:num_audio_frames[idx].item()] 838 | sf.write(os.path.join(samples_folder, f'audio_{idx+num_sampled}.wav'), sample_wav, ENCODEC_SAMPLING_RATE) 839 | 840 | batch_size = self.num_devices * batch['wav'].shape[0] 841 | num_sampled += batch_size 842 | 843 | 844 | if exists(num_samples) and num_sampled >= num_samples: 845 | break 846 | self.diffusion.train() 847 | if self.accelerator.is_main_process and cls_free_guidance == 1.0: 848 | self.diffusion.save_adaptive_noise_schedule(step_folder) 849 | 850 | 851 | def train(self): 852 | accelerator = self.accelerator 853 | device = accelerator.device 854 | 855 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: 856 | 857 | while self.step < self.train_num_steps: 858 | 859 | total_loss = 0. 
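# Gradient accumulation: each micro-batch loss below is scaled by 1/gradient_accumulate_every so that the summed gradients match a single large batch.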
860 | 861 | for _ in range(self.gradient_accumulate_every): 862 | data = next(self.data_iter) 863 | loss = self.diffusion(data, accelerator) 864 | loss = loss / self.gradient_accumulate_every 865 | total_loss += loss.item() 866 | 867 | self.accelerator.backward(loss) 868 | 869 | 870 | if accelerator.sync_gradients: 871 | grad_norm = compute_grad_norm(self.diffusion.parameters()) 872 | accelerator.clip_grad_norm_(self.diffusion.parameters(), self.args.clip_grad_norm) 873 | self.opt.step() 874 | self.lr_scheduler.step() 875 | self.opt.zero_grad() 876 | self.step += 1 877 | 878 | if self.step % 10 == 0: 879 | logs = { 880 | "train/loss": total_loss, 881 | "learning_rate": self.lr_scheduler.get_last_lr()[0], 882 | "grad_norm": grad_norm, 883 | "step": self.step, 884 | "epoch": (self.step*self.gradient_accumulate_every)/len(self.dataloader), 885 | "samples": self.step*self.batch_size*self.gradient_accumulate_every*self.num_devices, 886 | } 887 | if self.diffusion.adaptive_noise_schedule_enabled(): 888 | logs["train/weighted_loss_ema"] = self.diffusion.train_schedule.weights().mean() 889 | # Validation loss 890 | if self.step % 50 == 0: 891 | self.diffusion.eval() 892 | with torch.no_grad(): 893 | total_val_loss = 0 894 | data = next(self.val_data_iter) 895 | loss = self.diffusion(data) 896 | total_val_loss += loss.item() 897 | if self.diffusion.adaptive_noise_schedule_enabled(): 898 | logs['val/weighted_loss'] = self.diffusion.val_ema_sampler.weights().mean() 899 | self.diffusion.train() 900 | 901 | if accelerator.is_main_process: 902 | if self.diffusion.adaptive_noise_schedule_enabled(): 903 | loss_emas_dict = self.diffusion.get_loss_emas() 904 | for key, value in loss_emas_dict.items(): 905 | logs[f'train/ema_loss/{key}'] = value 906 | loss_emas_dict = self.diffusion.get_normalized_loss_emas() 907 | for key, value in loss_emas_dict.items(): 908 | logs[f'train/ema/normalized_ema_loss/{key}'] = value 909 | 910 | # Val loss 911 | loss_emas_dict = self.diffusion.get_loss_emas(split='val') 912 | for key, value in loss_emas_dict.items(): 913 | logs[f'val/loss/{key}'] = value 914 | loss_emas_dict = self.diffusion.get_unweighted_loss_emas(split='val') 915 | for key, value in loss_emas_dict.items(): 916 | logs[f'val/unweighted_loss/{key}'] = value 917 | loss_emas_dict = self.diffusion.get_unweighted_loss_emas() 918 | for key, value in loss_emas_dict.items(): 919 | logs[f'train/ema/unweighted_ema_loss/{key}'] = value 920 | pbar.set_postfix(**logs) 921 | accelerator.log(logs, step=self.step) 922 | 923 | accelerator.wait_for_everyone() 924 | # Update EMA 925 | accelerator.unwrap_model(self.diffusion).denoising_network.update() 926 | 927 | if self.step % self.save_and_sample_every == 0: 928 | self.sample() 929 | for cls_free_guidance in [3.0, 5.0]: 930 | self.sample(cls_free_guidance=cls_free_guidance) 931 | 932 | if self.prefix_inpainting_seconds > 0: 933 | self.sample(prefix_seconds=self.prefix_inpainting_seconds) 934 | for cls_free_guidance in [3.0, 5.0]: 935 | self.sample(cls_free_guidance=cls_free_guidance, prefix_seconds=self.prefix_inpainting_seconds) 936 | self.save(save_step=True) 937 | 938 | self.diffusion.train() 939 | 940 | pbar.update(1) 941 | 942 | # Save final model 943 | self.save() 944 | accelerator.end_training() 945 | accelerator.print('training complete') -------------------------------------------------------------------------------- /diffusion/noise_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 
import numpy as np 4 | from functools import partial 5 | import matplotlib.pyplot as plt 6 | import json 7 | 8 | # Avoid log(0) 9 | def log(t, eps = 1e-12): 10 | return torch.log(t.clamp(min = eps)) 11 | 12 | # noise schedules 13 | 14 | def simple_linear_schedule(t, clip_min = 1e-9): 15 | return (1 - t).clamp(min = clip_min) 16 | 17 | def beta_linear_schedule(t, clip_min = 1e-9): 18 | return torch.exp(-1e-4 - 10 * (t ** 2)).clamp(min = clip_min, max = 1.) 19 | 20 | def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9): 21 | power = 2 * tau 22 | v_start = math.cos(start * math.pi / 2) ** power 23 | v_end = math.cos(end * math.pi / 2) ** power 24 | output = torch.cos((t * (end - start) + start) * math.pi / 2) ** power 25 | output = (v_end - output) / (v_end - v_start) 26 | return output.clamp(min = clip_min) 27 | 28 | def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9): 29 | v_start = torch.tensor(start / tau).sigmoid() 30 | v_end = torch.tensor(end / tau).sigmoid() 31 | gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) 32 | return gamma.clamp_(min = clamp_min, max = 1.) 33 | 34 | def sigmoid_k_weighting(k, gamma): 35 | return torch.sigmoid(-gamma + k) 36 | 37 | class V_Weighting: 38 | def __init__(self, objective='pred_v'): 39 | self.objective = objective 40 | 41 | def v_loss_weighting(self, gamma): 42 | return 1/torch.cosh(-(gamma/2)) 43 | 44 | def eps_loss_weighting(self, gamma): 45 | return torch.exp(-gamma/(2)) 46 | 47 | class LogNormal_V_Weighting: 48 | def __init__(self, gamma_mean=0.0, gamma_std=0.0, objective='pred_v'): 49 | self.gamma_mean = gamma_mean 50 | self.gamma_std = gamma_std 51 | self.min_gamma = -15 52 | self.max_gamma = 15 53 | 54 | assert objective == 'pred_v' 55 | 56 | self.normal_dist = torch.distributions.normal.Normal(self.gamma_mean, self.gamma_std) 57 | 58 | self.log_max_weighting = self.normal_dist.log_prob(torch.tensor(gamma_mean)) 59 | 60 | def v_loss_weighting(self, gamma): 61 | return torch.exp(self.normal_dist.log_prob(gamma) - self.log_max_weighting) 62 | 63 | def v_weighting(self, gamma): 64 | return self.v_loss_weighting(gamma) 65 | 66 | class LogCauchy_V_Weighting: 67 | def __init__(self, gamma_mean=0.0, gamma_std=0.0, objective='pred_v'): 68 | self.gamma_mean = gamma_mean 69 | self.gamma_std = gamma_std 70 | self.min_gamma = -15 71 | self.max_gamma = 15 72 | 73 | assert objective == 'pred_v' 74 | 75 | self.cauchy_dist = torch.distributions.cauchy.Cauchy(self.gamma_mean, self.gamma_std) 76 | 77 | self.log_max_weighting = self.cauchy_dist.log_prob(torch.tensor(gamma_mean)) 78 | 79 | def v_loss_weighting(self, gamma): 80 | return torch.exp(self.cauchy_dist.log_prob(gamma) - self.log_max_weighting) 81 | 82 | def v_weighting(self, gamma): 83 | return self.v_loss_weighting(gamma) 84 | 85 | class Asymmetric_LogNormal_V_Weighting: 86 | def __init__(self, gamma_mean=0.0, gamma_std=0.0, std_mult=2.0, objective='pred_v'): 87 | self.gamma_mean = gamma_mean 88 | self.gamma_std = gamma_std 89 | self.min_gamma = -15 90 | self.max_gamma = 15 91 | 92 | assert objective == 'pred_v' 93 | 94 | self.neg_normal_v_weighting = LogCauchy_V_Weighting(gamma_mean, gamma_std*std_mult, objective='pred_v') 95 | self.normal_v_weighting = LogNormal_V_Weighting(gamma_mean, gamma_std, objective='pred_v') 96 | 97 | def v_loss_weighting(self, gamma): 98 | # Use normal weighting for gamma >= self.gamma_mean 99 | normal_weighting = self.normal_v_weighting.v_loss_weighting(gamma) 100 | # Use neg_normal weighting for gamma < 
self.gamma_mean 101 | neg_normal_weighting = self.neg_normal_v_weighting.v_loss_weighting(gamma) 102 | return torch.where(gamma < self.gamma_mean, neg_normal_weighting, normal_weighting) 103 | 104 | def v_weighting(self, gamma): 105 | # Use cauchy weighting for gamma < self.gamma_mean 106 | neg_normal_weighting = self.neg_normal_v_weighting.v_weighting(gamma) 107 | # Use normal weighting for gamma >= self.gamma_mean 108 | normal_weighting = self.normal_v_weighting.v_weighting(gamma) 109 | return torch.where(gamma < self.gamma_mean, neg_normal_weighting, normal_weighting) 110 | 111 | # converting gamma to alpha, sigma or logsnr 112 | def log_snr_to_alpha(log_snr): 113 | alpha = torch.sigmoid(log_snr) 114 | return alpha 115 | 116 | # Log-SNR shifting (https://arxiv.org/abs/2301.10972) 117 | def alpha_to_shifted_log_snr(alpha, scale = 1): 118 | return (log(alpha) - log(1 - alpha)).clamp(min=-20, max=20) + 2*np.log(scale).item() 119 | 120 | def time_to_alpha(t, alpha_schedule, scale): 121 | alpha = alpha_schedule(t) 122 | shifted_log_snr = alpha_to_shifted_log_snr(alpha, scale = scale) 123 | return log_snr_to_alpha(shifted_log_snr) 124 | 125 | def plot_noise_schedule(unscaled_sampling_schedule, name, y_value): 126 | assert y_value in {'alpha^2', 'alpha', 'log(SNR)'} 127 | t = torch.linspace(0, 1, 100) # 100 points between 0 and 1 128 | scales = [.2, .5, 1.0] 129 | for scale in scales: 130 | sampling_schedule = partial(time_to_alpha, alpha_schedule=unscaled_sampling_schedule, scale=scale) 131 | alphas = sampling_schedule(t) # Obtain noise schedule values for each t 132 | if y_value == 'alpha^2': 133 | y_axis_label = r'$\alpha^2_t$' 134 | y = alphas 135 | elif y_value == 'alpha': 136 | y_axis_label = r'$\alpha_t$' 137 | y = alphas.sqrt() 138 | elif y_value == 'log(SNR)': 139 | y_axis_label = r'$\log(\lambda_t)$' 140 | y = alpha_to_shifted_log_snr(alphas, scale=1) 141 | 142 | plt.plot(t.numpy(), y.numpy(), label=f'Scale: {scale:.1f}') 143 | if y_value == 'log(SNR)': 144 | plt.ylim(-15, 15) 145 | plt.xlabel('t') 146 | plt.ylabel(y_axis_label) 147 | plt.title(f'{name}') 148 | plt.legend() 149 | plt.savefig(f'viz/{name.lower()}_{y_value}.png') 150 | plt.clf() 151 | 152 | 153 | def plot_cosine_schedule(): 154 | t = torch.linspace(0, 1, 100) # 100 points between 0 and 1 155 | sampling_schedule = cosine_schedule 156 | alphas = sampling_schedule(t) # Obtain noise schedule values for each t 157 | y = alphas 158 | plt.plot(t.numpy(), y.numpy()) 159 | plt.xlabel('t') 160 | plt.ylabel(f'alpha^2') 161 | plt.title(f'Cosine Noise Schedule') 162 | plt.savefig(f'viz/standard_cosine.png') 163 | plt.clf() 164 | 165 | def plot_weighting_functions(): 166 | gamma = torch.linspace(-15, 15, 1000) 167 | 168 | 169 | # Plot v weighting 170 | v_obj = V_Weighting(objective='pred_eps') 171 | v_weighting = v_obj.v_loss_weighting(gamma) 172 | 173 | # Log-normal V weighting with mean=-1.0, std=2.4 174 | v_obj = LogNormal_V_Weighting(gamma_mean=-1.0, gamma_std=2.4, objective='pred_v') 175 | v_weighting_lognormal = v_obj.v_loss_weighting(gamma) 176 | 177 | # Asymmetric (log-Cauchy / log-normal) V weighting with mean=-1.0, std=2.4 178 | v_obj = Asymmetric_LogNormal_V_Weighting(gamma_mean=-1.0, gamma_std=2.4, objective='pred_v', std_mult=2.0) 179 | v_weighting_logcauchynormal = v_obj.v_loss_weighting(gamma) 180 | 181 | # Create plot 182 | fig, ax = plt.subplots(figsize=(8, 4)) 183 | ax.plot(gamma.numpy(), v_weighting_logcauchynormal.numpy(), label='Asymmetric Weighting') 184 | ax.plot(gamma.numpy(), v_weighting_lognormal.numpy(),
label='Symmetric Weighting') 185 | ax.plot(gamma.numpy(), v_weighting.numpy(), label='V-Weighting (VoiceBox)') 186 | ax.set_xlabel(r'Log-SNR ($\lambda_t$)', fontsize=14) 187 | ax.set_ylabel(r'Loss Weight ($w(\lambda_t)$)', fontsize=14) 188 | ax.set_title('Loss Weighting Across Noise Levels', fontsize=14) 189 | # Create legend 190 | ax.legend(loc='upper right', ncol=1, fontsize=12) 191 | plt.savefig('viz/v_weighting_pred.png', bbox_inches='tight') 192 | plt.clf() 193 | 194 | # Save log-scale version 195 | fig, ax = plt.subplots() 196 | ax.plot(gamma.numpy(), v_weighting.numpy(), label='V Weighting') 197 | ax.plot(gamma.numpy(), v_weighting_lognormal.numpy(), label='Log-Normal V Weighting (mean=-1.0, std=2.4)') 198 | ax.plot(gamma.numpy(), v_weighting_logcauchynormal.numpy(), label='Log-CauchyNormal V Weighting (mean=0.0, std=2.4)') 199 | ax.set_xlabel(r'$\lambda_t$') 200 | ax.set_ylabel('Weighting (V-Space)') 201 | ax.set_title('V Weighting with pred_v Objective') 202 | ax.legend() 203 | ax.set_yscale('log') 204 | plt.savefig('viz/v_weighting_pred_v_log2.png') 205 | plt.clf() 206 | 207 | # Make sure to call this function in your main visualization routine 208 | if __name__ == '__main__': 209 | plot_weighting_functions() 210 | plot_cosine_schedule() 211 | -------------------------------------------------------------------------------- /diffusion/optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional, Callable 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | from torch.optim import AdamW 6 | import math 7 | 8 | # functions 9 | 10 | def exists(val): 11 | return val is not None 12 | 13 | # update functions 14 | 15 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): 16 | # stepweight decay 17 | 18 | p.data.mul_(1 - lr * wd) 19 | 20 | # weight update 21 | 22 | update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_() 23 | p.add_(update, alpha = -lr) 24 | 25 | # decay the momentum running average coefficient 26 | 27 | exp_avg.lerp_(grad, 1 - beta2) 28 | 29 | # class 30 | 31 | class Lion(Optimizer): 32 | def __init__( 33 | self, 34 | params, 35 | lr: float = 1e-4, 36 | betas: Tuple[float, float] = (0.9, 0.99), 37 | weight_decay: float = 0.0, 38 | ): 39 | assert lr > 0. 40 | assert all([0. <= beta <= 1. 
for beta in betas]) 41 | 42 | defaults = dict( 43 | lr = lr, 44 | betas = betas, 45 | weight_decay = weight_decay 46 | ) 47 | 48 | super().__init__(params, defaults) 49 | 50 | self.update_fn = update_fn 51 | 52 | @torch.no_grad() 53 | def step( 54 | self, 55 | closure: Optional[Callable] = None 56 | ): 57 | 58 | loss = None 59 | if exists(closure): 60 | with torch.enable_grad(): 61 | loss = closure() 62 | 63 | for group in self.param_groups: 64 | for p in filter(lambda p: exists(p.grad), group['params']): 65 | 66 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p] 67 | 68 | # init state - exponential moving average of gradient values 69 | 70 | if len(state) == 0: 71 | state['exp_avg'] = torch.zeros_like(p) 72 | 73 | exp_avg = state['exp_avg'] 74 | 75 | self.update_fn( 76 | p, 77 | grad, 78 | exp_avg, 79 | lr, 80 | wd, 81 | beta1, 82 | beta2 83 | ) 84 | 85 | return loss 86 | 87 | def get_corrected_weight_decay(lr, weight_decay): 88 | if weight_decay == 0: 89 | return 0 90 | corrected_weight_decay = math.exp(math.log(weight_decay) - math.log(lr)) 91 | return corrected_weight_decay 92 | 93 | def separate_weight_decayable_params(params): 94 | # Exclude affine params in norms (e.g. LayerNorm, GroupNorm, etc.) and bias terms 95 | no_wd_params = [param for param in params if param.ndim < 2] 96 | wd_params = [param for param in params if param not in set(no_wd_params)] 97 | return wd_params, no_wd_params 98 | 99 | def get_adamw_optimizer(params, lr, betas, weight_decay, eps=1e-8): 100 | params = list(params) 101 | wd_params, no_wd_params = separate_weight_decayable_params(params) 102 | 103 | # Parameterize weight decay independently of learning rate 104 | corrected_weight_decay = get_corrected_weight_decay(lr, weight_decay) 105 | 106 | param_groups = [ 107 | {'params': wd_params}, 108 | {'params': no_wd_params, 'weight_decay': 0}, 109 | ] 110 | 111 | return AdamW(param_groups, lr = lr, weight_decay = corrected_weight_decay, betas=betas, eps=eps) 112 | 113 | def get_lion_optimizer(params, lr, weight_decay): 114 | params = list(params) 115 | wd_params, no_wd_params = separate_weight_decayable_params(params) 116 | 117 | # Parameterize weight decay as corrected of learning rate 118 | corrected_weight_decay = get_corrected_weight_decay(lr, weight_decay) 119 | 120 | param_groups = [ 121 | {'params': wd_params}, 122 | {'params': no_wd_params, 'weight_decay': 0}, 123 | ] 124 | 125 | return Lion(param_groups, lr = lr, weight_decay = corrected_weight_decay) -------------------------------------------------------------------------------- /diffusion/time_sampler.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | import matplotlib.pyplot as plt 7 | 8 | class LossEMASampler(nn.Module): 9 | def __init__(self, n_bins=100, ema_decay=0.9, gamma_min=-15, gamma_max=15): 10 | super().__init__() 11 | 12 | self.n_bins = n_bins 13 | self.ema_decay = ema_decay 14 | # Register loss bins as a buffer so that it is saved with the model 15 | self.register_buffer("_loss_bins", torch.ones((n_bins), dtype=torch.float64)) 16 | self.register_buffer("_unweighted_loss_bins", torch.ones((n_bins), dtype=torch.float64)) 17 | gamma_range = gamma_max - gamma_min 18 | self.bin_length = gamma_range/n_bins 19 | # Register step as a buffer so that it is saved with the model 20 | self.register_buffer("step", torch.tensor(0)) 21 | self.gamma_min = 
gamma_min 22 | self.gamma_max = gamma_max 23 | 24 | def set_bins_to_loss_weight(self, loss_weighting): 25 | gamma_range = self.gamma_max - self.gamma_min 26 | gammas = torch.arange(self.n_bins, dtype=torch.float64) * gamma_range / self.n_bins + self.gamma_min 27 | self._loss_bins = loss_weighting(gammas).to(self._loss_bins.device) 28 | 29 | 30 | def weights(self, ): 31 | weights = self._loss_bins.clone() 32 | return weights 33 | 34 | def update_with_all_losses(self, gamma, losses): 35 | for i in range(self.n_bins): 36 | gamma0 = i*self.bin_length + self.gamma_min 37 | gamma1 = (i+1)*self.bin_length + self.gamma_min 38 | 39 | bin_mask = (gamma >= gamma0) & (gamma < gamma1) 40 | if bin_mask.any(): 41 | self._loss_bins[i] = self.ema_decay * self._loss_bins[i] + (1-self.ema_decay) * losses[bin_mask].mean().item() 42 | self.step+=1 43 | 44 | def update_with_all_unweighted_losses(self, gamma, losses): 45 | for i in range(self.n_bins): 46 | gamma0 = i*self.bin_length + self.gamma_min 47 | gamma1 = (i+1)*self.bin_length + self.gamma_min 48 | 49 | bin_mask = (gamma >= gamma0) & (gamma < gamma1) 50 | if bin_mask.any(): 51 | self._unweighted_loss_bins[i] = self.ema_decay * self._unweighted_loss_bins[i] + (1-self.ema_decay) * losses[bin_mask].mean().item() 52 | self.step+=1 53 | 54 | def sample(self, batch_size, device, uniform=False): 55 | if uniform: 56 | gamma = torch.rand((batch_size), device=device) * (self.gamma_max - self.gamma_min) + self.gamma_min 57 | density = torch.ones((batch_size), device=device) 58 | return gamma, density 59 | else: 60 | bin_weights = self.weights().to(device) 61 | bins = torch.multinomial(bin_weights, batch_size, replacement=True).to(device) 62 | samples = torch.rand((batch_size), device=device) * self.bin_length 63 | gamma = (samples + bins * self.bin_length + self.gamma_min) 64 | # Check all samples in [-gamma_min, gamma_max] 65 | assert (gamma >= self.gamma_min).all() and (gamma <= self.gamma_max).all() 66 | density = bin_weights[bins] 67 | return gamma, density 68 | 69 | def save_density(self, path): 70 | plt.figure(figsize=(10, 5)) 71 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, self.weights().cpu().numpy()) 72 | plt.xlabel("gamma") 73 | plt.ylabel("Density of Adaptive Noise Schedule") 74 | plt.grid(True) # Add grid lines 75 | plt.savefig(path) 76 | plt.close() 77 | 78 | def save_cumulative_density(self, path): 79 | plt.figure(figsize=(10, 5)) 80 | weights = self.weights().cpu().numpy() 81 | weights = weights/weights.sum() 82 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, weights.cumsum()) 83 | plt.xlabel("gamma") 84 | plt.ylabel("Cumulative Density of Adaptive Noise Schedule") 85 | plt.grid(True) # Add grid lines 86 | plt.savefig(path) 87 | plt.close() 88 | 89 | def save_loss_emas(self, path): 90 | plt.figure(figsize=(10, 5)) 91 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, self._loss_bins.cpu().numpy()) 92 | plt.xlabel("gamma") 93 | plt.ylabel("Loss EMAs") 94 | plt.grid(True) # Add grid lines 95 | plt.savefig(path) 96 | plt.close() 97 | 98 | def save_unweighted_loss_emas(self, path): 99 | plt.figure(figsize=(10, 5)) 100 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, self._unweighted_loss_bins.cpu().numpy()) 101 | plt.xlabel("gamma") 102 | plt.ylabel("Unweighted Loss EMAs") 103 | plt.grid(True) # Add grid lines 104 | plt.savefig(path) 105 | plt.close() 106 | 107 | def get_loss_emas(self): 108 | # Return dict of loss emas where the key is the gamma range 109 | loss_emas = {} 110 
| for i in range(self.n_bins): 111 | gamma0 = i*self.bin_length + self.gamma_min 112 | gamma1 = (i+1)*self.bin_length + self.gamma_min 113 | loss_emas[f"{gamma0:.2f}-{gamma1:.2f}"] = self._loss_bins[i].item() 114 | return loss_emas 115 | 116 | def get_unweighted_loss_emas(self): 117 | # Return dict of loss emas where the key is the gamma range 118 | loss_emas = {} 119 | for i in range(self.n_bins): 120 | gamma0 = i*self.bin_length + self.gamma_min 121 | gamma1 = (i+1)*self.bin_length + self.gamma_min 122 | loss_emas[f"{gamma0:.2f}-{gamma1:.2f}"] = self._unweighted_loss_bins[i].item() 123 | return loss_emas 124 | 125 | def get_normalized_loss_emas(self): 126 | # Return dict of loss emas where the key is the gamma range 127 | loss_emas = {} 128 | denominator = self._loss_bins.sum().item() 129 | for i in range(self.n_bins): 130 | gamma0 = i*self.bin_length + self.gamma_min 131 | gamma1 = (i+1)*self.bin_length + self.gamma_min 132 | loss_emas[f"{gamma0:.2f}-{gamma1:.2f}"] = self._loss_bins[i].item() / denominator 133 | return loss_emas 134 | -------------------------------------------------------------------------------- /evaluation/evaluate_transcript.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from transformers import Wav2Vec2Processor, HubertForCTC, Wav2Vec2FeatureExtractor, WavLMForXVector 4 | from evaluate import load 5 | from tqdm import tqdm 6 | from einops import rearrange, reduce, repeat 7 | import sox 8 | import math 9 | 10 | 11 | def chunker(seq, size): 12 | return (seq[pos:pos + size] for pos in range(0, len(seq), size)) 13 | 14 | def trim_silence(wav, sample_rate): 15 | np_arr = wav.numpy().squeeze() 16 | # create a transformer 17 | tfm = sox.Transformer() 18 | tfm.silence(location=-1, silence_threshold=.1) 19 | # transform an in-memory array and return an array 20 | y_out = tfm.build_array(input_array=np_arr, sample_rate_in=sample_rate) 21 | duration = len(y_out)/sample_rate 22 | if duration < .5: 23 | return wav 24 | 25 | return torch.tensor(y_out).unsqueeze(0) 26 | 27 | 28 | @torch.inference_mode() 29 | def compute_wer(wavpath_list, text_list, wav_list=None, model_id='facebook/hubert-large-ls960-ft', truncate=False): 30 | processor = Wav2Vec2Processor.from_pretrained(model_id) 31 | model = HubertForCTC.from_pretrained(model_id).to('cuda') 32 | model.eval() 33 | 34 | if wav_list is None: 35 | wav_list = [torchaudio.load(wavpath) for wavpath in wavpath_list] 36 | waveform_list = [] 37 | sample_rate_list = [] 38 | for waveform, sample_rate in wav_list: 39 | waveform_list.append(waveform) 40 | sample_rate_list.append(sample_rate) 41 | 42 | 43 | asr_text = [] 44 | for i in tqdm(range(len(wav_list))): 45 | waveform = rearrange(waveform_list[i].squeeze(), 'l -> () l') 46 | waveform = torchaudio.functional.resample(waveform, sample_rate_list[i], processor.feature_extractor.sampling_rate) 47 | if truncate: 48 | waveform = trim_silence(waveform, processor.feature_extractor.sampling_rate) 49 | input_values = processor(waveform, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").input_values.squeeze(0).to('cuda') 50 | logits = model(input_values).logits 51 | predicted_ids = torch.argmax(logits, dim=-1) 52 | transcription = [transcript.lower().strip() for transcript in processor.batch_decode(predicted_ids)] 53 | 54 | asr_text.extend(transcription) 55 | 56 | if text_list is None: 57 | return asr_text 58 | print(f'asr_text: {asr_text[:3]}') 59 | print(f'text_list: {text_list[:3]}')
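# Word error rate is computed with the HuggingFace `evaluate` WER metric, comparing the HuBERT CTC transcripts above against the reference texts.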
60 | wer = load("wer") 61 | wer_score = wer.compute(predictions=asr_text, references=text_list) 62 | return wer_score 63 | -------------------------------------------------------------------------------- /img/esd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinlovelace/SESD/c59b41d7ec98922bdb6bce4217ebcf6cd1c0d2b2/img/esd.png -------------------------------------------------------------------------------- /img/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinlovelace/SESD/c59b41d7ec98922bdb6bce4217ebcf6cd1c0d2b2/img/results.png -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from models.modules.norm import RMSNorm, ConvRMSNorm, LayerNorm 2 | from models.modules.transformer import ConditionableTransformer 3 | from models.modules.conv import MaskedConv1d 4 | from models.modules.position import VariationalFourierFeatures, RelativePositionalEmbedding, AbsolutePositionalEmbedding 5 | from models.modules.blocks import FeedForward -------------------------------------------------------------------------------- /models/modules/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, reduce, repeat, pack, unpack 6 | from einops.layers.torch import Rearrange 7 | import math 8 | 9 | from models.modules.norm import RMSNorm 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | def zero_init_(m): 15 | nn.init.zeros_(m.weight) 16 | if exists(m.bias): 17 | nn.init.zeros_(m.bias) 18 | 19 | class SwiGLU(nn.Module): 20 | def forward(self, x): 21 | x, gate = x.chunk(2, dim = -1) 22 | return F.silu(gate) * x 23 | 24 | class FeedForward(nn.Module): 25 | def __init__( 26 | self, 27 | dim, 28 | mult = 4, 29 | time_cond_dim = None, 30 | dropout = 0., 31 | ): 32 | super().__init__() 33 | self.norm = RMSNorm(dim) 34 | inner_dim = int(dim * mult * 2 / 3) 35 | dim_out = dim 36 | 37 | self.time_cond = None 38 | self.dropout = nn.Dropout(dropout) 39 | 40 | if dropout > 0: 41 | self.net = nn.Sequential( 42 | nn.Linear(dim, inner_dim*2), 43 | SwiGLU(), 44 | nn.Dropout(dropout), 45 | nn.Linear(inner_dim, dim_out) 46 | ) 47 | else: 48 | self.net = nn.Sequential( 49 | nn.Linear(dim, inner_dim*2), 50 | SwiGLU(), 51 | nn.Linear(inner_dim, dim_out) 52 | ) 53 | 54 | if exists(time_cond_dim): 55 | self.time_cond = nn.Sequential( 56 | nn.SiLU(), 57 | nn.Linear(time_cond_dim, dim * 3), 58 | Rearrange('b d -> b 1 d') 59 | ) 60 | 61 | zero_init_(self.time_cond[-2]) 62 | else: 63 | zero_init_(self.net[-1]) 64 | 65 | 66 | def forward(self, x, time = None): 67 | x = self.norm(x) 68 | if exists(self.time_cond): 69 | assert exists(time) 70 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 71 | x = (x * (scale + 1)) + shift 72 | 73 | x = self.net(x) 74 | 75 | if exists(self.time_cond): 76 | x = x*gate 77 | 78 | return x 79 | -------------------------------------------------------------------------------- /models/modules/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing as tp 3 | import warnings 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 
from einops import rearrange, reduce, repeat 10 | 11 | 12 | class MaskedConv1d(nn.Conv1d): 13 | """Wrapper around Conv1d to provide masking functionality 14 | """ 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | 18 | def forward(self, x, audio_mask=None): 19 | if audio_mask is not None: 20 | assert audio_mask.shape[-1] == x.shape[-1] 21 | conv_mask = rearrange(audio_mask, 'b l -> b () l') 22 | x = torch.where(conv_mask, x, 0) 23 | x = super().forward(x) 24 | return x -------------------------------------------------------------------------------- /models/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, reduce, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | 9 | class RMSNorm(nn.Module): 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.scale = dim ** 0.5 13 | self.gamma = nn.Parameter(torch.ones(dim)) 14 | 15 | def forward(self, x): 16 | out = F.normalize(x, dim = -1) * self.scale * self.gamma 17 | return out 18 | 19 | class ConvRMSNorm(RMSNorm): 20 | """ 21 | Convolution-friendly RMSNorm that moves channels to last dimensions 22 | before running the normalization and moves them back to original position right after. 23 | """ 24 | def __init__(self, dim): 25 | super().__init__(dim) 26 | 27 | def forward(self, x): 28 | x = rearrange(x, 'b ... t -> b t ...') 29 | x = super().forward(x) 30 | x = rearrange(x, 'b t ... -> b ... t') 31 | return x 32 | 33 | # use layernorm without bias, more stable 34 | 35 | class LayerNorm(nn.Module): 36 | def __init__(self, dim): 37 | super().__init__() 38 | self.gamma = nn.Parameter(torch.ones(dim)) 39 | self.register_buffer("beta", torch.zeros(dim)) 40 | 41 | def forward(self, x): 42 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) -------------------------------------------------------------------------------- /models/modules/position.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from einops import rearrange, einsum, repeat 5 | 6 | from models.modules.blocks import FeedForward 7 | 8 | def exists(val): 9 | return val is not None 10 | 11 | class VariationalFourierFeatures(nn.Module): 12 | """ following https://arxiv.org/abs/2107.00630 """ 13 | 14 | def __init__(self, n_min=0, n_max=6, step=1): 15 | super().__init__() 16 | assert n_min <= n_max 17 | self.n_min = n_min 18 | self.n_max = n_max 19 | self.step = step 20 | 21 | def forward(self, x): 22 | # Create Base 2 Fourier features 23 | w = 2.**torch.arange(self.n_min, self.n_max+1, self.step, device = x.device, dtype = x.dtype) * 2 * math.pi 24 | 25 | if len(x.shape) == 3: 26 | w = repeat(w, 'f -> b l f', b = x.shape[0], l = x.shape[1]) 27 | freqs = einsum(x, w, 'b l d, b l f -> b l d f') 28 | freqs = rearrange(freqs, 'b l d f -> b l (d f)') 29 | fouriered = torch.cat([x, freqs.sin(), freqs.cos()], dim=-1) 30 | elif len(x.shape) == 1: 31 | w = repeat(w, 'f -> l f', l = x.shape[0]) 32 | freqs = einsum(x, w, 'l, l f -> l f') 33 | x = rearrange(x, 'l -> l ()') 34 | fouriered = torch.cat([x, freqs.sin(), freqs.cos()], dim=-1) 35 | return fouriered 36 | 37 | class FourierFeatureEmbedding(nn.Module): 38 | def __init__(self, dim, n_min=0, n_max=6, n_layers=3): 39 | super().__init__() 40 | self.variational_fourier_features = VariationalFourierFeatures(n_min=n_min, 
n_max=n_max) 41 | 42 | self.init_proj = nn.Linear((1+(n_max-n_min))*2+1, dim) 43 | 44 | self.layers = nn.ModuleList([FeedForward(dim) for _ in range(n_layers)]) 45 | 46 | def forward(self, x): 47 | fourier_emb = self.init_proj(self.variational_fourier_features(x)) 48 | for layer in self.layers: 49 | fourier_emb = layer(fourier_emb) + fourier_emb 50 | 51 | return fourier_emb 52 | 53 | class RelativePositionalEmbedding(nn.Module): 54 | def __init__(self, dim, n_min=0, n_max=6, n_layers=3): 55 | super().__init__() 56 | self.variational_fourier_features = VariationalFourierFeatures(n_min=n_min, n_max=n_max) 57 | 58 | self.init_proj = nn.Linear((1+(n_max-n_min))*2+1, dim) 59 | 60 | self.layers = nn.ModuleList([FeedForward(dim) for _ in range(n_layers)]) 61 | 62 | def forward(self, x, attention_mask): 63 | position_indices = torch.arange(x.shape[1], device = x.device) 64 | # Need to handle masked contexts from classifier free guidance 65 | context_len = torch.sum(attention_mask, dim=-1) 66 | # Replace 0s with 1s to avoid divide by 0 67 | context_len = torch.where(context_len == 0, torch.ones_like(context_len), context_len) 68 | relative_position = repeat(position_indices, 'l -> b l', b = x.shape[0])/(context_len.unsqueeze(-1)) 69 | relative_position = rearrange(relative_position, 'b l -> b l ()') 70 | relative_position_emb = self.init_proj(self.variational_fourier_features(relative_position)) 71 | for layer in self.layers: 72 | relative_position_emb = layer(relative_position_emb) + relative_position_emb 73 | 74 | return relative_position_emb 75 | 76 | 77 | class AbsolutePositionalEmbedding(nn.Module): 78 | def __init__(self, dim, max_seq_len=512): 79 | super().__init__() 80 | self.scale = dim ** -0.5 81 | self.max_seq_len = max_seq_len 82 | self.emb = nn.Embedding(max_seq_len, dim) 83 | nn.init.normal_(self.emb.weight, std=.01) 84 | 85 | def forward(self, x, pos = None): 86 | seq_len, device = x.shape[1], x.device 87 | assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' 88 | 89 | if not exists(pos): 90 | pos = torch.arange(seq_len, device = device) 91 | 92 | pos_emb = self.emb(pos) 93 | pos_emb = pos_emb * self.scale 94 | return pos_emb 95 | 96 | # From https://github.com/lucidrains/x-transformers/blob/c7cc22268c8ebceef55fe78343197f0af62edf18/x_transformers/x_transformers.py#L272 97 | class DynamicPositionBias(nn.Module): 98 | def __init__(self, dim, *, heads, depth=3, log_distance = False): 99 | super().__init__() 100 | assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1' 101 | self.log_distance = log_distance 102 | 103 | self.init_proj = nn.Linear(1, dim) 104 | 105 | self.layers = nn.ModuleList([FeedForward(dim) for _ in range(depth)]) 106 | 107 | self.out = nn.Linear(dim, heads) 108 | 109 | @property 110 | def device(self): 111 | return next(self.parameters()).device 112 | 113 | def forward(self, i, j): 114 | assert i == j 115 | n, device = j, self.device 116 | 117 | # get the (n x n) matrix of distances 118 | seq_arange = torch.arange(n, device = device) 119 | context_arange = torch.arange(n, device = device) 120 | indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j') 121 | indices += (n - 1) 122 | 123 | # input to continuous positions MLP 124 | pos = torch.arange(-n + 1, n, device = device).float() 125 | pos = rearrange(pos, '... -> ... 
1') 126 | 127 | if self.log_distance: 128 | pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) 129 | 130 | pos_emb = self.init_proj(pos) 131 | for layer in self.layers: 132 | pos_emb = layer(pos_emb) + pos_emb 133 | pos_biases = self.out(pos_emb) 134 | 135 | # get position biases 136 | bias = pos_biases[indices] 137 | bias = rearrange(bias, 'i j h -> h i j') 138 | return bias 139 | -------------------------------------------------------------------------------- /models/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, reduce, repeat, pack, unpack 6 | from einops.layers.torch import Rearrange 7 | import math 8 | 9 | from models.modules.norm import RMSNorm 10 | from models.modules.position import RelativePositionalEmbedding, DynamicPositionBias 11 | from models.modules.blocks import FeedForward 12 | 13 | def exists(x): 14 | return x is not None 15 | 16 | def zero_init_(m): 17 | nn.init.zeros_(m.weight) 18 | if exists(m.bias): 19 | nn.init.zeros_(m.bias) 20 | 21 | class Attention(nn.Module): 22 | def __init__( 23 | self, 24 | dim, 25 | dim_head = 32, 26 | time_cond_dim = None, 27 | dropout=0. 28 | ): 29 | super().__init__() 30 | assert dim % dim_head == 0, 'Dimension must be divisible by the head dimension' 31 | self.heads = dim // dim_head 32 | 33 | self.dropout = dropout 34 | self.time_cond = None 35 | 36 | self.norm = RMSNorm(dim) 37 | 38 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False) 39 | self.to_out = nn.Linear(dim, dim) 40 | 41 | if exists(time_cond_dim): 42 | self.time_cond = nn.Sequential( 43 | nn.SiLU(), 44 | nn.Linear(time_cond_dim, dim * 3), 45 | Rearrange('b d -> b 1 d') 46 | ) 47 | zero_init_(self.time_cond[-2]) 48 | else: 49 | zero_init_(self.to_out) 50 | 51 | def forward(self, x, attn_bias, time=None, audio_mask=None): 52 | b, c, n = x.shape 53 | 54 | x = self.norm(x) 55 | 56 | if exists(self.time_cond): 57 | assert exists(time) 58 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 59 | x = (x * (scale + 1)) + shift 60 | 61 | qkv = self.to_qkv(x).chunk(3, dim = 2) 62 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads).contiguous(), qkv) 63 | 64 | attn_bias = repeat(attn_bias, 'h i j -> b h i j', b=b) 65 | 66 | if exists(audio_mask): 67 | mask_value = -torch.finfo(q.dtype).max 68 | mask = rearrange(audio_mask, 'b l -> b () () l') 69 | attn_bias = attn_bias.masked_fill(~mask, mask_value) 70 | 71 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0., attn_mask=attn_bias) 72 | out = rearrange(out, 'b h n d -> b n (h d)') 73 | 74 | out = self.to_out(out) 75 | if exists(self.time_cond): 76 | out = out*gate 77 | 78 | return out 79 | 80 | class CrossAttention(nn.Module): 81 | def __init__(self, dim, dim_context, dim_head = 32, time_cond_dim=None, dropout=0., position_aware=False, cross_attn_pos_dim=None,): 82 | super().__init__() 83 | assert dim % dim_head == 0, 'Dimension must be divisible by the head dimension' 84 | self.heads = dim // dim_head 85 | self.dropout = dropout 86 | self.norm = RMSNorm(dim) 87 | self.time_cond = None 88 | self.time_cond = None 89 | self.position_aware = position_aware 90 | 91 | if self.position_aware: 92 | assert exists(cross_attn_pos_dim) 93 | self.pos_to_k = nn.Linear(cross_attn_pos_dim, dim, bias=False) 94 | 95 | 96 | self.norm_context 
= nn.LayerNorm(dim_context) 97 | 98 | self.null_kv = nn.Parameter(torch.randn(2, dim)) 99 | self.to_q = nn.Linear(dim, dim, bias = False) 100 | self.to_kv = nn.Linear(dim_context, dim * 2, bias = False) 101 | self.to_out = nn.Linear(dim, dim) 102 | 103 | if exists(time_cond_dim): 104 | self.time_cond = nn.Sequential( 105 | nn.SiLU(), 106 | nn.Linear(time_cond_dim, dim * 3), 107 | Rearrange('b d -> b 1 d') 108 | ) 109 | zero_init_(self.time_cond[-2]) 110 | else: 111 | zero_init_(self.to_out) 112 | 113 | self.q_norm = RMSNorm(dim_head) 114 | self.k_norm = RMSNorm(dim_head) 115 | 116 | def forward(self, x, context, context_mask, time=None, context_pos=None,): 117 | ''' 118 | x: [B, L_audio, d_unet] 119 | context: [B, L_text, d_lm] 120 | context_mask: [B, L_text] 121 | ''' 122 | b, c, n = x.shape 123 | x = self.norm(x) 124 | if exists(self.time_cond): 125 | assert exists(time) 126 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 127 | x = (x * (scale + 1)) + shift 128 | context = self.norm_context(context) 129 | 130 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 131 | 132 | if self.position_aware: 133 | assert exists(context_pos) 134 | k = k + self.pos_to_k(context_pos) 135 | 136 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads).contiguous(), (q, k, v)) 137 | 138 | # Null value for classifier free guidance 139 | nk, nv = map(lambda t: repeat(t, '(h d) -> b h 1 d', b = b, h=self.heads), self.null_kv.unbind(dim = -2)) 140 | k = torch.cat((nk, k), dim = -2) 141 | v = torch.cat((nv, v), dim = -2) 142 | 143 | query_len = q.shape[2] 144 | 145 | # RMSNorm Trick for stability 146 | q = self.q_norm(q) 147 | k = self.k_norm(k) 148 | # Masking pad tokens 149 | context_mask = F.pad(context_mask, (1, 0), value = True) 150 | context_mask = repeat(context_mask, 'b j -> b h q_len j', h=self.heads, q_len=query_len) 151 | 152 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=context_mask, dropout_p=self.dropout if self.training else 0.) 
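# The commented-out block below is the explicit softmax-attention equivalent of the fused scaled_dot_product_attention call above, presumably kept for reference.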
153 | # attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask 154 | # attn_weight = torch.softmax((q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))) + attn_mask, dim=-1) 155 | # out = attn_weight @ v 156 | 157 | out = rearrange(out, 'b h n d -> b n (h d)') 158 | 159 | out = self.to_out(out) 160 | if exists(self.time_cond): 161 | out = out*gate 162 | 163 | return out 164 | 165 | 166 | class ConditionableTransformer(nn.Module): 167 | def __init__( 168 | self, 169 | dim, 170 | dim_context, 171 | *, 172 | num_layers, 173 | time_cond_dim, 174 | dim_head = 64, 175 | ff_mult = 4, 176 | dropout=0.0, 177 | position_aware_cross_attention=False, 178 | num_registers=8, 179 | dense_connections=True, 180 | ): 181 | super().__init__() 182 | self.dim = dim 183 | self.num_layers = num_layers 184 | 185 | self.position_aware = position_aware_cross_attention 186 | if self.position_aware: 187 | cross_attn_pos_dim = dim//4 188 | self.context_pos_emb = RelativePositionalEmbedding(cross_attn_pos_dim, n_min=0, n_max=6) 189 | else: 190 | cross_attn_pos_dim = None 191 | 192 | self.layers = nn.ModuleList([]) 193 | for _ in range(num_layers): 194 | self.layers.append(nn.ModuleList([ 195 | Attention(dim = dim, dim_head = dim_head, dropout=dropout), 196 | CrossAttention(dim = dim, dim_head = dim_head, dim_context=dim_context, dropout=dropout, cross_attn_pos_dim=cross_attn_pos_dim, position_aware=position_aware_cross_attention), 197 | FeedForward(dim=dim, mult=ff_mult, time_cond_dim=time_cond_dim, dropout=dropout) 198 | ])) 199 | 200 | self.dynamic_pos_bias = DynamicPositionBias(dim = dim // 4, heads = dim // dim_head, log_distance = False, depth = 2) 201 | 202 | self.has_registers = num_registers > 0 203 | if num_registers > 0: 204 | self.memory_tokens = nn.Parameter(torch.randn(num_registers, dim)) 205 | 206 | if dense_connections: 207 | assert num_layers % 2 == 0, 'number of layers must be divisible by 2 for dense connections' 208 | self.dense_blocks = nn.ModuleList([nn.Linear(dim*2, dim) for i in range(num_layers // 2)]) 209 | self.dense_connections = dense_connections 210 | 211 | def forward( 212 | self, 213 | x, 214 | *, 215 | time, 216 | context, 217 | context_mask, 218 | audio_mask, 219 | ): 220 | 221 | if self.position_aware: 222 | context_pos = self.context_pos_emb(context, context_mask) 223 | 224 | else: 225 | context_pos = None 226 | 227 | if self.has_registers: 228 | mem = repeat(self.memory_tokens, 'l d -> b l d', b = x.shape[0]) 229 | x, mem_packed_shape = pack((mem, x), 'b * d') 230 | 231 | mem_attn_mask = torch.ones_like(mem[:,:,0], dtype=torch.bool) 232 | audio_mask = torch.cat((mem_attn_mask, audio_mask), dim=1) 233 | 234 | 235 | i = j = x.shape[1] 236 | attn_bias = self.dynamic_pos_bias(i, j) 237 | 238 | hiddens = [] 239 | for idx, (attn, cross_attn, ff) in enumerate(self.layers): 240 | if self.dense_connections: 241 | if self.has_registers: 242 | if idx < (self.num_layers // 2): 243 | # store hidden states for dense connections 244 | hiddens.append(x[:, mem.shape[1]:, :]) 245 | else: 246 | concat_feats = torch.cat((x[:, mem.shape[1]:, :], hiddens.pop()), dim=-1) 247 | x[:, mem.shape[1]:, :] = self.dense_blocks[idx - (self.num_layers // 2)](concat_feats) 248 | else: 249 | if idx < (self.num_layers // 2): 250 | # store hidden states for dense connections 251 | hiddens.append(x) 252 | else: 253 | concat_feats = torch.cat((x, hiddens.pop()), dim=-1) 254 | x = self.dense_blocks[idx - (self.num_layers // 2)](concat_feats) 255 | res = x 256 | x = 
attn(x, attn_bias=attn_bias, audio_mask=audio_mask) + res 257 | 258 | res = x 259 | x = cross_attn(x, context = context, context_mask=context_mask, context_pos=context_pos) + res 260 | 261 | res = x 262 | x = ff(x, time=time) + res 263 | 264 | if self.has_registers: 265 | mem, x = unpack(x, mem_packed_shape, 'b * d') 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /models/transformer_wrapper.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, reduce 5 | from functools import partial 6 | 7 | from models.modules import RMSNorm, ConvRMSNorm, ConditionableTransformer, LayerNorm, MaskedConv1d 8 | # helpers functions 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | def default(val, d): 14 | if exists(val): 15 | return val 16 | return d() if callable(d) else d 17 | 18 | def zero_init_(m): 19 | nn.init.zeros_(m.weight) 20 | if exists(m.bias): 21 | nn.init.zeros_(m.bias) 22 | 23 | def masked_mean(t, *, dim, mask = None): 24 | if not exists(mask): 25 | return t.mean(dim = dim) 26 | 27 | denom = mask.sum(dim = dim, keepdim = True) 28 | mask = rearrange(mask, 'b n -> b n 1') 29 | masked_t = t.masked_fill(~mask, 0.) 30 | 31 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 32 | 33 | 34 | # sinusoidal positional embeds 35 | 36 | class SinusoidalPosEmb(nn.Module): 37 | def __init__(self, dim): 38 | super().__init__() 39 | self.dim = dim 40 | 41 | def forward(self, x): 42 | device = x.device 43 | half_dim = self.dim // 2 44 | emb = math.log(10000) / (half_dim - 1) 45 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 46 | emb = x[:, None] * emb[None, :] 47 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 48 | return emb 49 | 50 | class TransformerWrapper(nn.Module): 51 | def __init__( 52 | self, 53 | dim, 54 | text_dim, 55 | channels = 128, 56 | position_aware_cross_attention=False, 57 | inpainting_embedding = False, 58 | num_transformer_layers = 8, 59 | dropout=0.0, 60 | ): 61 | super().__init__() 62 | 63 | 64 | self.channels = channels 65 | self.out_dim = channels 66 | 67 | input_channels = channels 68 | 69 | self.init_conv = nn.Conv1d(input_channels, dim, 1) 70 | 71 | 72 | if inpainting_embedding: 73 | self.inpainting_embedding = nn.Embedding(2, dim) 74 | else: 75 | self.inpainting_embedding = None 76 | 77 | # time embeddings 78 | 79 | time_dim = dim * 2 80 | 81 | sinu_pos_emb = SinusoidalPosEmb(dim) 82 | fourier_dim = dim 83 | 84 | self.time_mlp = nn.Sequential( 85 | sinu_pos_emb, 86 | nn.Linear(fourier_dim, time_dim), 87 | nn.SiLU(), 88 | nn.Linear(time_dim, time_dim) 89 | ) 90 | 91 | # layers 92 | 93 | self.transformer = ConditionableTransformer(dim, dim_context=text_dim, num_layers=num_transformer_layers, time_cond_dim=time_dim, dropout=dropout, position_aware_cross_attention=position_aware_cross_attention) 94 | 95 | self.final_conv = nn.Sequential( 96 | ConvRMSNorm(dim), 97 | nn.SiLU(), 98 | nn.Conv1d(dim, self.out_dim, 1) 99 | ) 100 | zero_init_(self.final_conv[-1]) 101 | 102 | 103 | self.to_text_non_attn_cond = nn.Sequential( 104 | nn.LayerNorm(text_dim), 105 | nn.Linear(text_dim, time_dim), 106 | nn.SiLU(), 107 | nn.Linear(time_dim, time_dim) 108 | ) 109 | 110 | def forward(self, x, time_cond, text_cond=None, text_cond_mask=None, inpainting_mask=None, audio_mask=None): 111 | if not exists(audio_mask): 112 | audio_mask = torch.ones((x.shape[0], x.shape[2]), dtype=torch.bool, 
device=x.device) 113 | x = self.init_conv(x) 114 | if exists(self.inpainting_embedding): 115 | assert exists(inpainting_mask) 116 | inpainting_emb = self.inpainting_embedding(inpainting_mask) 117 | x = x + rearrange(inpainting_emb, 'b l c -> b c l') 118 | 119 | mean_pooled_context = masked_mean(text_cond, dim=1, mask=text_cond_mask) 120 | text_mean_cond = self.to_text_non_attn_cond(mean_pooled_context) 121 | 122 | # Rescale continuous time [0,1] to similar range as Ho et al. 2020 123 | t = self.time_mlp(time_cond*1000) 124 | 125 | t = t + text_mean_cond 126 | 127 | x = rearrange(x, 'b c l -> b l c') 128 | 129 | x = self.transformer(x, context=text_cond, context_mask=text_cond_mask, time=t, audio_mask=audio_mask) 130 | x = rearrange(x, 'b l c -> b c l') 131 | 132 | return self.final_conv(x) -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, reduce, repeat 9 | from einops.layers.torch import Rearrange, Reduce 10 | 11 | from models.modules import RMSNorm, ConvRMSNorm, ConditionableTransformer, LayerNorm, MaskedConv1d 12 | 13 | ENCODEC_DIM = 128 14 | 15 | # helpers functions 16 | 17 | def exists(x): 18 | return x is not None 19 | 20 | def default(val, d): 21 | if exists(val): 22 | return val 23 | return d() if callable(d) else d 24 | 25 | # small helper modules 26 | 27 | class Residual(nn.Module): 28 | def __init__(self, fn): 29 | super().__init__() 30 | self.fn = fn 31 | 32 | def forward(self, x, *args, **kwargs): 33 | return self.fn(x, *args, **kwargs) + x 34 | 35 | def Upsample(dim, dim_out = None): 36 | return nn.Sequential( 37 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 38 | MaskedConv1d(dim, default(dim_out, dim), 3, padding = 1) 39 | ) 40 | 41 | def Downsample(dim, dim_out = None): 42 | return MaskedConv1d(dim, default(dim_out, dim), 4, 2, 1) 43 | 44 | 45 | # sinusoidal positional embeds 46 | 47 | class SinusoidalPosEmb(nn.Module): 48 | def __init__(self, dim): 49 | super().__init__() 50 | self.dim = dim 51 | 52 | def forward(self, x): 53 | device = x.device 54 | half_dim = self.dim // 2 55 | emb = math.log(10000) / (half_dim - 1) 56 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 57 | emb = x[:, None] * emb[None, :] 58 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 59 | return emb 60 | 61 | class RandomOrLearnedSinusoidalPosEmb(nn.Module): 62 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ 63 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 64 | 65 | def __init__(self, dim, is_random = False): 66 | super().__init__() 67 | assert (dim % 2) == 0 68 | half_dim = dim // 2 69 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) 70 | 71 | def forward(self, x): 72 | x = rearrange(x, 'b -> b 1') 73 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 74 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 75 | fouriered = torch.cat((x, fouriered), dim = -1) 76 | return fouriered 77 | # building block modules 78 | 79 | class Block(nn.Module): 80 | def __init__(self, dim, dim_out, groups = 8, input_connection=False, input_dim=None): 81 | super().__init__() 82 | self.input_connection = input_connection 83 | if input_connection: 
84 | assert exists(input_dim) 85 | self.proj = MaskedConv1d(dim + input_dim, dim_out, 3, padding = 1) 86 | else: 87 | self.proj = MaskedConv1d(dim, dim_out, 3, padding = 1) 88 | self.norm = ConvRMSNorm(dim) 89 | self.act = nn.SiLU() 90 | 91 | def forward(self, x, scale_shift = None, audio_mask=None, input_data=None): 92 | x = self.norm(x) 93 | if exists(scale_shift): 94 | scale, shift = scale_shift 95 | x = x * (scale + 1) + shift 96 | x = self.act(x) 97 | if self.input_connection: 98 | assert exists(input_data) 99 | x = torch.cat((x, input_data), dim=1) 100 | 101 | x = self.proj(x, audio_mask) 102 | 103 | return x 104 | 105 | class ResnetBlock(nn.Module): 106 | def __init__(self, dim, dim_out, *, input_connection=False, input_dim=None, time_emb_dim = None, groups = 32): 107 | super().__init__() 108 | if input_connection: 109 | assert exists(input_dim) 110 | self.input_connection = input_connection 111 | self.mlp = None 112 | if exists(time_emb_dim): 113 | self.mlp = nn.Sequential( 114 | nn.SiLU(), 115 | nn.Linear(time_emb_dim, dim_out * 3), 116 | Rearrange('b c -> b c 1') 117 | ) 118 | zero_init_(self.mlp[-2]) 119 | 120 | if input_connection: 121 | self.block1 = Block(dim, dim_out, groups = groups, input_connection=True, input_dim=input_dim) 122 | else: 123 | self.block1 = Block(dim, dim_out, groups = groups) 124 | self.block2 = Block(dim_out, dim_out, groups = groups) 125 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 126 | 127 | def forward(self, x, time_emb = None, audio_mask=None, input_data=None): 128 | 129 | scale_shift = None 130 | if exists(self.mlp): 131 | assert exists(time_emb) 132 | scale_shift = self.mlp(time_emb) 133 | scale, shift, gate = scale_shift.chunk(3, dim = 1) 134 | scale_shift = (scale, shift) 135 | 136 | 137 | h = self.block1(x, audio_mask=audio_mask, input_data=input_data) 138 | 139 | h = self.block2(h, scale_shift=scale_shift, audio_mask=audio_mask) 140 | 141 | if exists(self.mlp): 142 | h = h*gate 143 | 144 | return h + self.res_conv(x) 145 | 146 | class LinearAttention(nn.Module): 147 | def __init__(self, dim, heads = 4, dim_head = 32): 148 | super().__init__() 149 | self.scale = dim_head ** -0.5 150 | self.heads = heads 151 | hidden_dim = dim_head * heads 152 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False) 153 | 154 | self.to_out = nn.Conv1d(hidden_dim, dim, 1) 155 | zero_init_(self.to_out) 156 | 157 | def forward(self, x, audio_mask=None): 158 | b, c, n = x.shape 159 | qkv = self.to_qkv(x).chunk(3, dim = 1) 160 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv) 161 | 162 | if exists(audio_mask): 163 | mask_value = -torch.finfo(q.dtype).max 164 | mask = audio_mask[:, None, None, :] 165 | k = k.masked_fill(~mask, mask_value) 166 | v = v.masked_fill(~mask, 0.) 167 | del mask 168 | 169 | q = q.softmax(dim = -2) 170 | k = k.softmax(dim = -1) 171 | 172 | q = q * self.scale 173 | 174 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 175 | 176 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 177 | out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads) 178 | return self.to_out(out) 179 | 180 | def l2norm(t): 181 | return F.normalize(t, dim = -1) 182 | 183 | 184 | def masked_mean(t, *, dim, mask = None): 185 | if not exists(mask): 186 | return t.mean(dim = dim) 187 | 188 | denom = mask.sum(dim = dim, keepdim = True) 189 | mask = rearrange(mask, 'b n -> b n 1') 190 | masked_t = t.masked_fill(~mask, 0.) 
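# Average only over unmasked positions; clamping the denominator guards against fully-masked rows.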
191 | 192 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 193 | 194 | class PreNorm(nn.Module): 195 | def __init__(self, dim, fn): 196 | super().__init__() 197 | self.fn = fn 198 | self.norm = ConvRMSNorm(dim) 199 | 200 | def forward(self, x, *args, **kwargs): 201 | x = self.norm(x) 202 | return self.fn(x, *args, **kwargs) 203 | 204 | def zero_init_(m): 205 | nn.init.zeros_(m.weight) 206 | if exists(m.bias): 207 | nn.init.zeros_(m.bias) 208 | 209 | # model 210 | 211 | class Unet1D(nn.Module): 212 | def __init__( 213 | self, 214 | dim, 215 | text_dim, 216 | init_dim = None, 217 | out_dim = None, 218 | dim_mults=(1, 2, 4, 8), 219 | channels = 128, 220 | position_aware_cross_attention=False, 221 | inpainting_embedding = False, 222 | resnet_block_groups = 32, 223 | scale_skip_connection=False, 224 | num_transformer_layers = 3, 225 | dropout=0.0, 226 | num_transformer_registers=8, 227 | input_connection=False, 228 | ): 229 | super().__init__() 230 | 231 | 232 | self.channels = channels 233 | 234 | input_channels = channels 235 | 236 | init_dim = default(init_dim, dim) 237 | self.init_conv = nn.Conv1d(input_channels, init_dim, 1) 238 | 239 | dims = [init_dim, *map(lambda m: int(dim * m), dim_mults)] 240 | in_out = list(zip(dims[:-1], dims[1:])) 241 | 242 | block_klass = partial(ResnetBlock, groups = resnet_block_groups, input_connection=input_connection, input_dim=channels) 243 | 244 | if inpainting_embedding: 245 | self.inpainting_embedding = nn.Embedding(2, init_dim) 246 | else: 247 | self.inpainting_embedding = None 248 | 249 | # time embeddings 250 | 251 | time_dim = dim * 2 252 | 253 | sinu_pos_emb = SinusoidalPosEmb(dim) 254 | fourier_dim = dim 255 | 256 | self.time_mlp = nn.Sequential( 257 | sinu_pos_emb, 258 | nn.Linear(fourier_dim, time_dim), 259 | nn.SiLU(), 260 | nn.Linear(time_dim, time_dim) 261 | ) 262 | 263 | # layers 264 | 265 | self.downs = nn.ModuleList([]) 266 | self.ups = nn.ModuleList([]) 267 | self.input_connection = input_connection 268 | num_resolutions = len(in_out) 269 | 270 | for ind, (dim_in, dim_out) in enumerate(in_out): 271 | is_last = ind >= (num_resolutions - 1) 272 | 273 | self.downs.append(nn.ModuleList([ 274 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 275 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 276 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 277 | Downsample(dim_in, dim_out) if not is_last else MaskedConv1d(dim_in, dim_out, 3, padding = 1) 278 | ])) 279 | 280 | mid_dim = dims[-1] 281 | self.transformer = ConditionableTransformer(mid_dim, dim_context=text_dim, num_layers=num_transformer_layers, time_cond_dim=time_dim, dropout=dropout, position_aware_cross_attention=position_aware_cross_attention, num_registers=num_transformer_registers) 282 | 283 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 284 | is_last = ind == (len(in_out) - 1) 285 | 286 | self.ups.append(nn.ModuleList([ 287 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 288 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 289 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 290 | Upsample(dim_out, dim_in) if not is_last else MaskedConv1d(dim_out, dim_in, 3, padding = 1) 291 | ])) 292 | 293 | self.out_dim = channels 294 | 295 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) 296 | 297 | self.final_conv = nn.Sequential( 298 | ConvRMSNorm(dim), 299 | nn.SiLU(), 300 | nn.Conv1d(dim, self.out_dim, 1) 301 | ) 302 | zero_init_(self.final_conv[-1]) 303 | 304 | # Accelerates 
convergence for image diffusion models 305 | # We use it by default but have not ablated it 306 | self.scale_skip_connection = (2 ** -0.5) if scale_skip_connection else 1 307 | 308 | self.to_text_non_attn_cond = nn.Sequential( 309 | nn.LayerNorm(text_dim), 310 | nn.Linear(text_dim, time_dim), 311 | nn.SiLU(), 312 | nn.Linear(time_dim, time_dim) 313 | ) 314 | 315 | def forward(self, x, time_cond, text_cond=None, text_cond_mask=None, inpainting_mask=None, audio_mask=None): 316 | assert x.shape[-1] % (2**(len(self.downs)-1)) == 0, f'Length of the audio latent must be a multiple of {2**(len(self.downs)-1)}' 317 | if not exists(audio_mask): 318 | audio_mask = torch.ones((x.shape[0], x.shape[2]), dtype=torch.bool, device=x.device) 319 | assert torch.remainder(audio_mask.sum(dim=1), (2**(len(self.downs)-1))).sum().item()==0, f'Length of the audio mask must be a multiple of {2**(len(self.downs)-1)}' 320 | 321 | if self.input_connection: 322 | input_data = x.clone() 323 | # Zero out masked values 324 | input_data = input_data.masked_fill(~audio_mask[:, None, :], 0.) 325 | else: 326 | input_data = None 327 | x = self.init_conv(x) 328 | if exists(self.inpainting_embedding): 329 | assert exists(inpainting_mask) 330 | inpainting_emb = self.inpainting_embedding(inpainting_mask) 331 | x = x + rearrange(inpainting_emb, 'b l c -> b c l') 332 | 333 | r = x.clone() 334 | 335 | mean_pooled_context = masked_mean(text_cond, dim=1, mask=text_cond_mask) 336 | text_mean_cond = self.to_text_non_attn_cond(mean_pooled_context) 337 | 338 | # Rescale continuous time in [0, 1] to a range similar to Ho et al. (2020) 339 | t = self.time_mlp(time_cond*1000) 340 | 341 | t = t + text_mean_cond 342 | 343 | h = [] 344 | audio_mask_list = [audio_mask] 345 | input_data_list = [input_data] 346 | for idx, (block1, block2, attn, downsample) in enumerate(self.downs): 347 | x = block1(x, t, audio_mask=audio_mask_list[-1], input_data=input_data_list[-1]) 348 | h.append(x) 349 | 350 | x = block2(x, t, audio_mask=audio_mask_list[-1], input_data=input_data_list[-1]) 351 | x = attn(x, audio_mask=audio_mask_list[-1]) 352 | h.append(x) 353 | 354 | 355 | x_prev_shape = x.shape 356 | x = downsample(x, audio_mask_list[-1]) 357 | if x.shape[-1] != x_prev_shape[-1]: 358 | downsampled_mask = reduce(audio_mask_list[-1], 'b (l 2) -> b l', reduction='max') 359 | audio_mask_list.append(downsampled_mask) 360 | if self.input_connection: 361 | # Mean-pool the input data with einops to match the downsampled resolution 362 | input_data_list.append(reduce(input_data_list[-1], 'b c (l 2) -> b c l', reduction='mean')) 363 | x = rearrange(x, 'b c l -> b l c') 364 | 365 | x = self.transformer(x, context=text_cond, context_mask=text_cond_mask, time=t, audio_mask=audio_mask_list[-1]) 366 | x = rearrange(x, 'b l c -> b c l') 367 | 368 | for block1, block2, attn, upsample in self.ups: 369 | x = torch.cat((x, h.pop()*(self.scale_skip_connection)), dim = 1) 370 | x = block1(x, t, audio_mask_list[-1], input_data_list[-1]) 371 | 372 | x = torch.cat((x, h.pop()*(self.scale_skip_connection)), dim = 1) 373 | x = block2(x, t, audio_mask_list[-1], input_data_list[-1]) 374 | x = attn(x, audio_mask_list[-1]) 375 | 376 | # Awkward implementation to maintain backwards compatibility with previous checkpoints 377 | if isinstance(upsample, nn.Sequential): 378 | # Need to cast to float32 for upsampling 379 | # Upsample operation 380 | x = upsample[0](x.float()) 381 | audio_mask_list.pop() 382 | # Masked conv operation 383 | x = upsample[1](x, audio_mask_list[-1]) 384 | 385 | if self.input_connection: 386 | input_data_list.pop() 387 | else: 388
| x = upsample(x, audio_mask_list[-1]) 389 | 390 | x = torch.cat((x, r), dim = 1) 391 | 392 | x = self.final_res_block(x, t, audio_mask, input_data) 393 | return self.final_conv(x) -------------------------------------------------------------------------------- /neural_codec/encodec_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torch import nn, einsum 4 | from encodec import EncodecModel 5 | from encodec.utils import _linear_overlap_add 6 | from encodec.utils import convert_audio 7 | import typing as tp 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from einops import rearrange 12 | from functools import partial 13 | import os 14 | 15 | EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]] 16 | 17 | class EncodecWrapper(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.codec = EncodecModel.encodec_model_24khz() 21 | self.codec.set_target_bandwidth(24.) 22 | 23 | def encode(self, x: torch.Tensor) -> torch.Tensor: 24 | """Encode a waveform tensor `x` of shape `[B, C, T]` with the EnCodec encoder and 25 | return the continuous (pre-quantization) latent of shape `[B, D, T']`. 26 | Quantization is deferred to `decode`. 27 | 28 | The wrapper assumes the whole input fits into a single segment, so exactly one 29 | frame is produced and its rescaling factor is expected to be `None`. 30 | """ 31 | assert x.dim() == 3 32 | _, channels, length = x.shape 33 | assert channels > 0 and channels <= 2 34 | segment_length = self.codec.segment_length 35 | if segment_length is None: 36 | segment_length = length 37 | stride = length 38 | else: 39 | stride = self.codec.segment_stride # type: ignore 40 | assert stride is not None 41 | 42 | encoded_frames: tp.List[EncodedFrame] = [] 43 | for offset in range(0, length, stride): 44 | frame = x[:, :, offset: offset + segment_length] 45 | encoded_frames.append(self._encode_frame(frame)) 46 | assert len(encoded_frames) == 1 47 | assert encoded_frames[0][1] is None 48 | return encoded_frames[0][0] 49 | 50 | def _encode_frame(self, x: torch.Tensor) -> EncodedFrame: 51 | length = x.shape[-1] 52 | duration = length / self.codec.sample_rate 53 | assert self.codec.segment is None or duration <= 1e-5 + self.codec.segment 54 | 55 | if self.codec.normalize: 56 | mono = x.mean(dim=1, keepdim=True) 57 | volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() 58 | scale = 1e-8 + volume 59 | x = x / scale 60 | scale = scale.view(-1, 1) 61 | else: 62 | scale = None 63 | 64 | emb = self.codec.encoder(x) 65 | return emb, scale 66 | 67 | def decode(self, emb: torch.Tensor, quantize: bool=True) -> torch.Tensor: 68 | """Decode the given frames into a waveform. 69 | Note that the output might be a bit bigger than the input. In that case, 70 | any extra steps at the end can be trimmed. 71 | """ 72 | encoded_frames = [(emb, None)] 73 | segment_length = self.codec.segment_length 74 | if segment_length is None: 75 | assert len(encoded_frames) == 1 76 | return self._decode_frame(encoded_frames[0], quantize=quantize) 77 | 78 | frames = [self._decode_frame(frame, quantize=quantize) for frame in encoded_frames] 79 | return _linear_overlap_add(frames, self.codec.segment_stride or 1) 80 | 81 | def _decode_frame(self, encoded_frame: EncodedFrame, quantize: bool=True) -> torch.Tensor: 82 | emb, scale = encoded_frame 83 | if quantize: 84 | codes = self.codec.quantizer.encode(emb, self.codec.frame_rate, self.codec.bandwidth) 85 | emb = self.codec.quantizer.decode(codes) 86 | 87 | # codes is [B, K, T], with T frames, K nb of codebooks.
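# With quantize=True, quantizer.encode maps the continuous embedding onto EnCodec's residual-VQ codebooks
# and quantizer.decode maps the codes back, so the decoder below receives a latent snapped onto the codec's
# discrete code space; with quantize=False the raw continuous embedding is decoded directly.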
88 | out = self.codec.decoder(emb) 89 | if scale is not None: 90 | out = out * scale.view(-1, 1, 1) 91 | return out 92 | 93 | 94 | def forward(self, wav:torch.tensor, sr:int, quantize:bool=True): 95 | # TODO: Revisit where to handle processing 96 | wav = convert_audio(wav, sr, self.codec.sample_rate, self.codec.channels) 97 | frames = self.encode(wav) 98 | return self.decode(frames, quantize=quantize)[:, :, :wav.shape[-1]] 99 | 100 | 101 | 102 | def test(): 103 | def normalize_audio_latent(data_mean, data_std, audio_latent): 104 | return (audio_latent - rearrange(data_mean, 'c -> () c ()')) / rearrange(data_std, 'c -> () c ()') 105 | 106 | def unnormalize_audio_latent(data_mean, data_std, audio_latent): 107 | return audio_latent * rearrange(data_std, 'c -> () c ()') + rearrange(data_mean, 'c -> () c ()') 108 | 109 | codec = EncodecWrapper().to('cuda') 110 | import soundfile as sf 111 | from audio_datasets.librispeech import LibriSpeech, ENCODEC_SAMPLING_RATE 112 | 113 | test_dataset = LibriSpeech(split='test') 114 | # Path to saved model 115 | 116 | data = torch.load('../saved_models/') 117 | data_mean = data['model']['data_mean'] 118 | data_std = data['model']['data_std'] 119 | 120 | with torch.no_grad(): 121 | for idx in range(len(test_dataset)): 122 | example = test_dataset.__getitem__(idx) 123 | 124 | # [B, 1, L]: batch x channels x length 125 | batched_wav = example['wav'][:,:int(example['audio_duration']*ENCODEC_SAMPLING_RATE)].unsqueeze(0).to('cuda') 126 | 127 | # linspace log_snr from -15 to 15 128 | log_snrs = np.linspace(-15, 15, 121) 129 | for log_snr in log_snrs: 130 | print(f'log_snr: {log_snr.item()}') 131 | os.makedirs(f'example_audio/logsnr/{log_snr.item()}', exist_ok=True) 132 | alpha2 = torch.sigmoid(torch.tensor([log_snr], device=batched_wav.device, dtype=torch.float32)) 133 | 134 | wav_emb = codec.encode(batched_wav) 135 | normalized_wav_emb = normalize_audio_latent(data_mean, data_std, wav_emb) 136 | noisy_wav_emb = alpha2.sqrt()*normalized_wav_emb + (1-alpha2).sqrt()*torch.randn_like(normalized_wav_emb) 137 | noisy_wav_emb /= alpha2.sqrt() 138 | noisy_wav_emb = unnormalize_audio_latent(data_mean, data_std, noisy_wav_emb) 139 | noisy_reconstruction = codec.decode(noisy_wav_emb) 140 | 141 | sf.write(f'example_audio/logsnr/{log_snr.item()}/audio_{idx}.wav', noisy_reconstruction.squeeze().to('cpu').numpy(), ENCODEC_SAMPLING_RATE) 142 | 143 | 144 | if __name__=='__main__': 145 | test() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.24.1 2 | aiohttp==3.9.1 3 | aiosignal==1.3.1 4 | appdirs==1.4.4 5 | async-timeout==4.0.3 6 | attrs==23.1.0 7 | audioread==3.0.1 8 | beartype==0.16.4 9 | certifi==2023.11.17 10 | cffi==1.16.0 11 | charset-normalizer==3.3.2 12 | click==8.1.7 13 | contourpy==1.2.0 14 | cycler==0.12.1 15 | data==0.4 16 | datasets==2.15.0 17 | decorator==5.1.1 18 | dill==0.3.7 19 | docker-pycreds==0.4.0 20 | einops==0.7.0 21 | ema-pytorch==0.3.1 22 | encodec==0.1.1 23 | evaluate==0.4.1 24 | ffmpeg==1.4 25 | filelock==3.13.1 26 | fonttools==4.45.1 27 | frozenlist==1.4.0 28 | fsspec==2023.10.0 29 | funcsigs==1.0.2 30 | future==1.0.0 31 | gitdb==4.0.11 32 | GitPython==3.1.40 33 | huggingface-hub==0.19.4 34 | idna==3.6 35 | Jinja2==3.1.2 36 | jiwer==3.0.3 37 | joblib==1.3.2 38 | kiwisolver==1.4.5 39 | latex==0.7.0 40 | lazy_loader==0.3 41 | librosa==0.10.1 42 | llvmlite==0.41.1 43 | MarkupSafe==2.1.3 44 | 
matplotlib==3.8.2 45 | mpmath==1.3.0 46 | msgpack==1.0.7 47 | multidict==6.0.4 48 | multiprocess==0.70.15 49 | networkx==3.2.1 50 | numba==0.58.1 51 | numpy==1.26.2 52 | nvidia-cublas-cu12==12.1.3.1 53 | nvidia-cuda-cupti-cu12==12.1.105 54 | nvidia-cuda-nvrtc-cu12==12.1.105 55 | nvidia-cuda-runtime-cu12==12.1.105 56 | nvidia-cudnn-cu12==8.9.2.26 57 | nvidia-cufft-cu12==11.0.2.54 58 | nvidia-curand-cu12==10.3.2.106 59 | nvidia-cusolver-cu12==11.4.5.107 60 | nvidia-cusparse-cu12==12.1.0.106 61 | nvidia-nccl-cu12==2.18.1 62 | nvidia-nvjitlink-cu12==12.3.101 63 | nvidia-nvtx-cu12==12.1.105 64 | packaging==23.2 65 | pandas==2.1.3 66 | Pillow==10.1.0 67 | platformdirs==4.0.0 68 | pooch==1.8.0 69 | praatio==6.1.0 70 | protobuf==4.25.1 71 | psutil==5.9.6 72 | pyarrow==14.0.1 73 | pyarrow-hotfix==0.6 74 | pycparser==2.21 75 | pyparsing==3.1.1 76 | python-dateutil==2.8.2 77 | pytz==2023.3.post1 78 | PyYAML==6.0.1 79 | rapidfuzz==3.6.1 80 | regex==2023.10.3 81 | requests==2.31.0 82 | responses==0.18.0 83 | safetensors==0.4.0 84 | scikit-learn==1.3.2 85 | scipy==1.11.4 86 | sentry-sdk==1.37.1 87 | setproctitle==1.3.3 88 | shutilwhich==1.1.0 89 | six==1.16.0 90 | smmap==5.0.1 91 | soundfile==0.12.1 92 | sox==1.4.1 93 | soxr==0.3.7 94 | sympy==1.12 95 | tempdir==0.7.1 96 | threadpoolctl==3.2.0 97 | tokenizers==0.15.0 98 | torch==2.1.1 99 | torchaudio==2.1.1 100 | torchvision==0.16.1 101 | tqdm==4.66.1 102 | transformers==4.35.2 103 | triton==2.1.0 104 | typing_extensions==4.8.0 105 | tzdata==2023.3 106 | urllib3==2.1.0 107 | wandb==0.16.0 108 | xxhash==3.4.1 109 | yarl==1.9.3 110 | -------------------------------------------------------------------------------- /scripts/sample/sample_16_ls_testclean.sh: -------------------------------------------------------------------------------- 1 | python train_audio_diffusion.py --eval_test --resume_dir saved_models/uvit_byt5large_12layer_final_scale50 --sampling_timesteps 250 --run_name test/sample16 --sampler ddpm --seed 42 --num_samples 16 --scale 0.5 --guidance 2.0,3.0,5.0 2 | -------------------------------------------------------------------------------- /scripts/train/train.sh: -------------------------------------------------------------------------------- 1 | python ./train_audio_diffusion.py --dataset_name librispeech --optimizer adamw --text_encoder google/byt5-base --batch_size 64 --gradient_accumulation_steps 1 --run_name librispeech_250k --save_and_sample_every 50000 --learning_rate 2e-4 --mixed_precision bf16 --train_schedule adaptive --sampling_schedule cosine --scale .5 --loss_type l2 --scale_skip_connection --inpainting_prob 0.5 --num_train_steps 250000 --num_transformer_layers 8 --dim 512 --dim_mults 1,1,1,1 --num_samples 128 --position_aware_cross_attention --loss_weighting asymmetric_lognormal_v_weighting --loss_weighting_mean -1.0 --loss_weighting_std 2.4 --objective pred_v --inpainting_embedding --dropout 0.1 --adam_weight_decay 2e-4 -------------------------------------------------------------------------------- /train_audio_diffusion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.utils import get_output_dir, parse_float_tuple 3 | import json 4 | import os 5 | import numpy as np 6 | 7 | from diffusion.audio_denoising_diffusion import GaussianDiffusion, Trainer 8 | from models.unet import Unet1D 9 | from models.transformer_wrapper import TransformerWrapper 10 | from transformers import AutoConfig 11 | 12 | 13 | def main(args): 14 | 15 | config = 
AutoConfig.from_pretrained(args.text_encoder) 16 | text_dim = config.d_model 17 | 18 | if args.model_arch == 'transformer': 19 | model = TransformerWrapper( 20 | dim=args.dim, 21 | text_dim=text_dim, 22 | channels = 128, 23 | position_aware_cross_attention=args.position_aware_cross_attention, 24 | inpainting_embedding = args.inpainting_embedding, 25 | num_transformer_layers=args.num_transformer_layers, 26 | dropout=args.dropout, 27 | ) 28 | elif args.model_arch == 'unet': 29 | model = Unet1D( 30 | dim=args.dim, 31 | text_dim=text_dim, 32 | dim_mults=args.dim_mults, 33 | inpainting_embedding = args.inpainting_embedding, 34 | position_aware_cross_attention=args.position_aware_cross_attention, 35 | num_transformer_layers=args.num_transformer_layers, 36 | scale_skip_connection=args.scale_skip_connection, 37 | dropout=args.dropout, 38 | input_connection = args.input_connection, 39 | num_transformer_registers = args.num_registers, 40 | ) 41 | 42 | args.trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 43 | print(f'Trainable params: {args.trainable_params}') 44 | 45 | 46 | loss_weighting_args = { 47 | 'loss_weighting_mean': args.loss_weighting_mean, 48 | 'loss_weighting_std': args.loss_weighting_std, 49 | 'loss_weighting_std_mult': args.loss_weighting_std_mult 50 | } 51 | diffusion = GaussianDiffusion( 52 | model, 53 | max_seq_len = 2048, 54 | text_encoder=args.text_encoder, 55 | sampling_timesteps = args.sampling_timesteps, # number of sampling steps 56 | sampler=args.sampler, 57 | log_var_interp = args.log_var_interp, 58 | langevin_step_size = args.langevin_step_size, 59 | snr_ddim_threshold = args.snr_ddim_threshold, 60 | train_schedule= args.train_schedule, 61 | sampling_schedule= args.sampling_schedule, 62 | loss_type = args.loss_type, # L1 or L2 63 | objective = args.objective, 64 | parameterization = args.parameterization, 65 | ema_decay = args.ema_decay, 66 | scale = args.scale, 67 | unconditional_prob = args.unconditional_prob, 68 | inpainting_prob = args.inpainting_prob, 69 | loss_weighting = args.loss_weighting, 70 | loss_weighting_args = loss_weighting_args, 71 | ) 72 | 73 | trainer = Trainer( 74 | args=args, 75 | diffusion=diffusion, 76 | dataset_name=args.dataset_name, 77 | optimizer=args.optimizer, 78 | batch_size= args.batch_size, 79 | gradient_accumulate_every = args.gradient_accumulation_steps, 80 | train_lr = args.learning_rate, 81 | train_num_steps = args.num_train_steps, 82 | lr_schedule = args.lr_schedule, 83 | num_warmup_steps = args.lr_warmup_steps, 84 | adam_betas = (args.adam_beta1, args.adam_beta2), 85 | adam_weight_decay = args.adam_weight_decay, 86 | save_and_sample_every = args.save_and_sample_every, 87 | num_samples = args.num_samples, 88 | mixed_precision = args.mixed_precision, 89 | prefix_inpainting_seconds = args.prefix_inpainting_seconds, 90 | duration_path=args.test_duration_path, 91 | seed=args.seed, 92 | ) 93 | 94 | if args.eval or args.eval_test: 95 | trainer.load(args.resume_dir, ckpt_step=args.ckpt_step) 96 | cls_free_guidances = args.guidance 97 | for cls_free_guidance in cls_free_guidances: 98 | trainer.sample(cls_free_guidance=cls_free_guidance, prefix_seconds=args.prefix_inpainting_seconds, test=args.eval_test, seed=42, test_duration_path=args.test_duration_path) 99 | return 100 | 101 | if args.init_model is not None: 102 | trainer.load(args.init_model, init_only=True) 103 | 104 | if args.resume: 105 | trainer.load(args.resume_dir, ckpt_step=args.ckpt_step) 106 | 107 | trainer.train() 108 | 109 | if __name__ == 
"__main__": 110 | parser = argparse.ArgumentParser(description="Training arguments") 111 | parser.add_argument("--dataset_name", type=str, default='librispeech') 112 | parser.add_argument("--save_dir", type=str, default="saved_models") 113 | parser.add_argument("--text_encoder", type=str, default="google/byt5-small") 114 | parser.add_argument("--output_dir", type=str, default=None) 115 | parser.add_argument("--resume_dir", type=str, default=None) 116 | parser.add_argument("--init_model", type=str, default=None) 117 | parser.add_argument("--run_name", type=str, default=None) 118 | parser.add_argument("--seed", type=int, default=None) 119 | # Architecture hyperparameters 120 | parser.add_argument("--model_arch", 121 | type=str, 122 | default="unet", 123 | choices=["unet", "transformer"], 124 | help="Choose the model architecture") 125 | parser.add_argument("--dim", type=int, default=512) 126 | parser.add_argument('--dim_mults', type=parse_float_tuple, default=(1, 1, 1, 1.5), help='Tuple of integer values for dim_mults') 127 | parser.add_argument("--position_aware_cross_attention", action="store_true", default=False) 128 | parser.add_argument("--scale_skip_connection", action="store_true", default=False) 129 | parser.add_argument("--input_connection", action="store_true", default=False) 130 | parser.add_argument("--num_transformer_layers", type=int, default=3) 131 | parser.add_argument("--num_registers", type=int, default=8) 132 | parser.add_argument("--dropout", type=float, default=0.) 133 | parser.add_argument("--inpainting_embedding", action="store_true", default=False) 134 | 135 | # Optimization hyperparameters 136 | parser.add_argument("--optimizer", type=str, default="adamw") 137 | parser.add_argument("--batch_size", type=int, default=16) 138 | parser.add_argument("--num_train_steps", type=int, default=60000) 139 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 140 | parser.add_argument("--learning_rate", type=float, default=1e-4) 141 | parser.add_argument("--clip_grad_norm", type=float, default=1.0) 142 | parser.add_argument("--lr_schedule", type=str, default="cosine") 143 | parser.add_argument("--lr_warmup_steps", type=int, default=1000) 144 | parser.add_argument("--adam_beta1", type=float, default=0.9) 145 | parser.add_argument("--adam_beta2", type=float, default=0.999) 146 | parser.add_argument("--adam_weight_decay", type=float, default=0) 147 | parser.add_argument("--ema_decay", type=float, default=0.9999) 148 | # Diffusion Hyperparameters 149 | parser.add_argument( 150 | "--objective", 151 | type=str, 152 | default="pred_eps", 153 | choices=["pred_eps", "pred_x0", "pred_v"], 154 | help=( 155 | "Which loss objective to use for the diffusion objective." 156 | ), 157 | ) 158 | parser.add_argument( 159 | "--parameterization", 160 | type=str, 161 | default="pred_v", 162 | choices=["pred_eps", "pred_x0", "pred_v"], 163 | help=( 164 | "Which output parameterization to use for the diffusion network." 165 | ), 166 | ) 167 | parser.add_argument( 168 | "--loss_type", 169 | type=str, 170 | default="l1", 171 | choices=["l1", "l2"], 172 | help=( 173 | "Which loss function to use for diffusion." 174 | ), 175 | ) 176 | parser.add_argument( 177 | "--loss_weighting", 178 | type=str, 179 | default="edm", 180 | choices=["edm", "sigmoid", "v_weighting", "lognormal_v_weighting", "asymmetric_lognormal_v_weighting", "monotonic_lognormal_v_weighting"], 181 | help=( 182 | "Which loss function to use for diffusion." 
183 | ), 184 | ) 185 | parser.add_argument( 186 | "--loss_weighting_mean", 187 | type=float, 188 | default=0.0, 189 | help=( 190 | "Mean for the loss weighting function." 191 | ), 192 | ) 193 | parser.add_argument( 194 | "--loss_weighting_std", 195 | type=float, 196 | default=1.0, 197 | help=( 198 | "Standard deviation for the loss weighting function." 199 | ), 200 | ) 201 | parser.add_argument( 202 | "--loss_weighting_std_mult", 203 | type=float, 204 | default=2.0, 205 | help=( 206 | "Standard deviation multiplier for the loss weighting function." 207 | ), 208 | ) 209 | parser.add_argument( 210 | "--train_schedule", 211 | type=str, 212 | default="cosine", 213 | choices=["beta_linear", "simple_linear", "cosine", 'sigmoid', 'adaptive'], 214 | help=( 215 | "Which noise schedule to use for training." 216 | ), 217 | ) 218 | parser.add_argument( 219 | "--sampling_schedule", 220 | type=str, 221 | default=None, 222 | choices=["beta_linear", "cosine", "simple_linear", None], 223 | help=( 224 | "Which noise schedule to use for sampling." 225 | ), 226 | ) 227 | parser.add_argument("--resume", action="store_true", default=False) 228 | parser.add_argument("--scale", type=float, default=1.0) 229 | parser.add_argument("--sampling_timesteps", type=int, default=250) 230 | parser.add_argument("--log_var_interp", type=float, default=0.2) 231 | parser.add_argument("--snr_ddim_threshold", type=float, default=None) 232 | parser.add_argument("--langevin_step_size", type=float, default=0.0) 233 | parser.add_argument("--ckpt_step", type=int, default=None) 234 | # Audio Training Parameters 235 | parser.add_argument("--unconditional_prob", type=float, default=.1) 236 | parser.add_argument("--inpainting_prob", type=float, default=.5) 237 | # Generation Arguments 238 | parser.add_argument("--save_and_sample_every", type=int, default=20000) 239 | parser.add_argument("--num_samples", type=int, default=None) 240 | parser.add_argument( 241 | "--sampler", 242 | type=str, 243 | default="ddpm", 244 | choices=["ddim", "ddpm"], 245 | help=( 246 | "Which sampler to use for diffusion." 247 | ), 248 | ) 249 | parser.add_argument("--prefix_inpainting_seconds", type=float, default=0.) 250 | # Accelerate arguments 251 | parser.add_argument( 252 | "--mixed_precision", 253 | type=str, 254 | default="no", 255 | choices=["no", "fp16", "bf16"], 256 | help=( 257 | "Whether to use mixed precision. Choose " 258 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 " 259 | "and an Nvidia Ampere GPU."
260 | ), 261 | ) 262 | # Load and eval model 263 | parser.add_argument("--eval", action="store_true", default=False) 264 | parser.add_argument("--eval_test", action="store_true", default=False) 265 | parser.add_argument("--test_duration_path", type=str, default=None) 266 | parser.add_argument('--guidance', type=parse_float_tuple, help='Tuple of classifier-free guidance scales') 267 | 268 | args = parser.parse_args() 269 | if args.eval or args.eval_test: 270 | assert args.resume_dir is not None 271 | 272 | if args.eval or args.eval_test: 273 | with open(os.path.join(args.resume_dir, 'args.json'), 'rt') as f: 274 | saved_args = json.load(f) 275 | args_dict = vars(args) 276 | heldout_params = {'run_name', 'output_dir', 'resume_dir', 'eval', 'eval_test', 'prefix_inpainting_seconds', 'num_samples', 'sampling_timesteps', 'sampling_schedule', 'scale', 'sampler', 'mixed_precision', 'guidance', 'ckpt_step', 'log_var_interp', 'langevin_step_size', 'snr_ddim_threshold', 'test_duration_path'} 277 | for k,v in saved_args.items(): 278 | if k in heldout_params: 279 | continue 280 | args_dict[k] = v 281 | 282 | main(args) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | import os 4 | from pathlib import Path 5 | import argparse 6 | 7 | def compute_grad_norm(parameters): 8 | # implementation adapted from https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_ 9 | parameters = [p for p in parameters if p.grad is not None] 10 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), p=2) for p in parameters]), p=2).item() 11 | return total_norm 12 | 13 | def get_output_dir(args): 14 | model_dir = f'{args.dataset_name}/{args.run_name}/' 15 | output_dir = os.path.join(args.save_dir, model_dir) 16 | return output_dir 17 | 18 | def parse_float_tuple(dim_mults_str): 19 | try: 20 | dim_mults = tuple(map(float, dim_mults_str.split(','))) 21 | return dim_mults 22 | except ValueError: 23 | raise argparse.ArgumentTypeError('Argument must be a comma-separated list of floats') 24 | --------------------------------------------------------------------------------
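A minimal usage sketch (not part of the repository) of how parse_float_tuple above feeds the argument parser in train_audio_diffusion.py; the flag values mirror those in scripts/train/train.sh and scripts/sample/sample_16_ls_testclean.sh, and the tiny parser here is illustrative only:

import argparse

from utils.utils import parse_float_tuple

parser = argparse.ArgumentParser()
parser.add_argument('--dim_mults', type=parse_float_tuple, default=(1, 1, 1, 1.5))
parser.add_argument('--guidance', type=parse_float_tuple)

# Comma-separated strings are converted to tuples of floats.
args = parser.parse_args(['--dim_mults', '1,1,1,1', '--guidance', '2.0,3.0,5.0'])
assert args.dim_mults == (1.0, 1.0, 1.0, 1.0)
assert args.guidance == (2.0, 3.0, 5.0)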