├── .gitignore ├── README.md ├── audio_datasets ├── __init__.py ├── constants.py ├── librispeech.py ├── mls.py ├── preprocess_mls.py └── text_grid_utils.py ├── data └── aligned_librispeech.tar.gz ├── diffusion ├── audio_denoising_diffusion.py ├── noise_schedule.py └── optimizer.py ├── evaluation └── evaluate_transcript.py ├── models ├── modules │ ├── __init__.py │ ├── conformer.py │ ├── conv.py │ ├── norm.py │ └── transformer.py └── unet.py ├── neural_codec └── encodec_wrapper.py ├── requirements.txt ├── scripts ├── sample │ └── sample_16_ls_testclean.sh └── train │ ├── train.sh │ └── train_distributed.sh ├── train_audio_diffusion.py ├── training.md └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Model checkpoints 132 | saved_models/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple-TTS 2 | 3 | This repo contains the implementation for Simple-TTS, a latent diffusion model for text-to-speech generation. Our submission describing this work is currently under review: 4 | 5 | **Simple-TTS: End-to-End Text-to-Speech Synthesis with Latent Diffusion**\ 6 | by Justin Lovelace, Soham Ray, Kwangyoun Kim, Kilian Q. Weinberger, and Felix Wu 7 | 8 | ## Environment 9 | Install the required dependencies with: 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Datasets 15 | 16 | We train our models on the English subset of the Multilingual LibriSpeech (MLS) dataset and use the standard LibriSpeech dataset for evaluation. 17 | 18 | For the MLS dataset, download `mls_english.tar.gz` from [https://www.openslr.org/94/](https://www.openslr.org/94/). Store the unzipped dataset at `/persist/data/mls/mls_english/` or update the `data_dir` path in `audio_datasets/mls.py` accordingly. The MLS dataset can then be processed by running the `audio_datasets/preprocess_mls.py` script. 19 | 20 | We access LibriSpeech through the Huggingface Hub. For speaker-prompted generation, we use the first three seconds of another utterance from the same speaker as the prompt. To extract the transcript corresponding to those first three seconds, we used the [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/first_steps/example.html#example-1-aligning-librispeech-english). An aligned version of LibriSpeech can be found at 21 | 22 | ```bash 23 | data/aligned_librispeech.tar.gz 24 | ``` 25 | 26 | Untar it and update the `ALIGNED_DATA_DIR` path in `audio_datasets/librispeech.py` to point to the extracted data directory. 27 | 28 | ## Training 29 | 30 | We trained all of the models using `bf16` mixed precision on Nvidia A10G GPUs. Settings such as the batch size in the provided scripts should be adjusted to match your hardware. 31 | 32 | We provide a sample script that trains the diffusion model with reasonable hyperparameters on a single Nvidia A10G GPU. Distributed training is recommended to increase the effective batch size, but the single-GPU script is useful for debugging. 33 | ```bash 34 | ./scripts/train/train.sh 35 | ``` 36 | 37 | We use [Huggingface Accelerate](https://huggingface.co/docs/accelerate/index) for distributed training and trained our final model on a `g5.48xlarge` instance (8 Nvidia A10Gs). After running `accelerate config` to set the appropriate environment variables (e.g. the number of GPUs), a distributed training job with our hyperparameter settings can be launched with 38 | ```bash 39 | ./scripts/train/train_distributed.sh 40 | ``` 41 | 42 |
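For reference, `scripts/train/train_distributed.sh` is the source of truth for the exact hyperparameters; a distributed run with Accelerate generally reduces to an `accelerate launch` of the training entry point. A minimal, illustrative sketch (the flag values below are assumptions for an 8-GPU `bf16` setup, and the script-specific training arguments are omitted):

```bash
# One-time interactive setup: number of processes, mixed precision, etc.
accelerate config

# Illustrative launch of the training entry point across 8 processes with bf16;
# the actual training arguments are the ones set in scripts/train/train_distributed.sh.
accelerate launch --num_processes 8 --mixed_precision bf16 train_audio_diffusion.py
```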
43 | ## Model Checkpoint 44 | 45 | Our model checkpoint can be downloaded from [here](https://simple-tts.awsdev.asapp.com/ckpt.tar.gz). 46 | 47 | The checkpoint folders contain an `args.json` with the hyperparameter settings for the model as well as the checkpoint itself. The model was trained for 200k steps with a global batch size of 256. The model is likely undertrained, and quality improvements could be gained from additional training. Using the `init_model` argument with the training scripts will initialize the model from the provided path. 48 | 49 | ## Sampling 50 | We provide a script for synthesizing speech for the LibriSpeech test-clean set: 51 | ```bash 52 | ./scripts/sample/sample_16_ls_testclean.sh 53 | ``` 54 | The `--resume_dir` argument should be updated with the path of a trained model. 55 | 56 | ## Contact 57 | Feel free to create an issue if you have any questions. 58 | 59 | 60 | ## Acknowledgement 61 | This work built upon excellent open-source implementations from [Phil Wang (Lucidrains)](https://github.com/lucidrains). Specifically, we built off of his PyTorch [DDPM implementation](https://github.com/lucidrains/denoising-diffusion-pytorch). 62 | -------------------------------------------------------------------------------- /audio_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_datasets.librispeech import LibriSpeech 2 | from audio_datasets.mls import MLS 3 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, LATENT_SAMPLING_RATE -------------------------------------------------------------------------------- /audio_datasets/constants.py: -------------------------------------------------------------------------------- 1 | ENCODEC_SAMPLING_RATE = 24000 2 | ENCODEC_REDUCTION_FACTOR = 320 3 | LATENT_SAMPLING_RATE = 75 -------------------------------------------------------------------------------- /audio_datasets/librispeech.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from datasets import load_dataset, Audio, concatenate_datasets 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | import torch.nn.functional as F 7 | import os 8 | import random 9 | import math 10 | 11 | from audio_datasets.text_grid_utils import get_partial_transcript 12 | 13 | def lowercase_text(example): 14 | example["text"] = example["text"].lower() 15 | return example 16 | 17 | 18 | 19 | MAX_DURATION_IN_SECONDS = 20 20 | ENCODEC_SAMPLING_RATE = 24000 21 | ENCODEC_REDUCTION_FACTOR = 320 22 | LATENT_SAMPLING_RATE = 75 23 | ALIGNED_DATA_DIR = '/persist/data/aligned_librispeech' 24 | 25 | def round_up_to_multiple(number, multiple): 26 | remainder = number % multiple 27 | if remainder == 0: 28 | return number 29 | else: 30 | return number + (multiple - remainder) 31 | 32 | def compute_max_length(multiple=128): 33 | max_len = MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE 34 | waveform_multiple = multiple*ENCODEC_REDUCTION_FACTOR 35 | 36 | max_len = round_up_to_multiple(max_len, waveform_multiple) 37 | return max_len 38 | 39 | def is_audio_length_in_range(audio): 40 | return len(audio['array']) <= (MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE) 41 | 42 | 43 | def is_audio_length_in_test_range(audio): 44 | return
((4*ENCODEC_SAMPLING_RATE) <= len(audio['array'])) and (len(audio['array']) <= (10*ENCODEC_SAMPLING_RATE)) 45 | 46 | 47 | class LibriSpeech(Dataset): 48 | """ 49 | Wrapper around HuggingFace dataset for processing. 50 | """ 51 | def __init__(self, split='train', debug=False, tokenizer=None, max_seq_len=None, sampling_rate=None): 52 | super().__init__() 53 | self.sr = ENCODEC_SAMPLING_RATE if sampling_rate is None else sampling_rate 54 | self.split = split 55 | self.split2dir = {'valid': 'dev-clean', 'test': 'test-clean'} 56 | if split == 'train': 57 | train100 = load_dataset('librispeech_asr', 'clean', split='train.100') 58 | train360 = load_dataset('librispeech_asr', 'clean', split='train.360') 59 | 60 | self.hf_dataset = concatenate_datasets([train100, train360]) 61 | elif split == 'valid': 62 | self.hf_dataset = load_dataset('librispeech_asr', 'clean', split='validation') 63 | elif split == 'test': 64 | self.hf_dataset = load_dataset('librispeech_asr', 'clean', split='test') 65 | else: 66 | raise ValueError(f"invalid split: {split}, must be in ['train', 'valid'] ") 67 | 68 | # Downsample to accelerate processing for debugging purposes 69 | if debug: 70 | self.hf_dataset = self.hf_dataset.select(range(100)) 71 | # Resample to 24kHz for Encodec 72 | self.hf_dataset = self.hf_dataset.cast_column("audio", Audio(sampling_rate=self.sr)) 73 | 74 | self.hf_dataset = self.hf_dataset.map(lowercase_text) 75 | if split == 'train': 76 | self.hf_dataset = self.hf_dataset.filter(is_audio_length_in_range, input_columns=['audio']) 77 | elif split == 'test': 78 | self.hf_dataset = self.hf_dataset.filter(is_audio_length_in_test_range, input_columns=['audio']) 79 | 80 | if self.split in {'valid', 'test'}: 81 | unique_speaker_ids = set(self.hf_dataset['speaker_id']) 82 | self.speaker_datasets = {speaker_id:self.hf_dataset.filter(lambda example: example["speaker_id"] == speaker_id) for speaker_id in unique_speaker_ids} 83 | 84 | 85 | self.max_seq_len = max_seq_len if max_seq_len is not None else compute_max_length() 86 | print(f'Max seq length: {self.max_seq_len/ENCODEC_REDUCTION_FACTOR}') 87 | 88 | if tokenizer is not None: 89 | self.hf_dataset = self.hf_dataset.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True, max_length=256), batched=True) 90 | self.tokenizer = tokenizer 91 | 92 | 93 | def __getitem__(self, index): 94 | example = self.hf_dataset[index] 95 | text = example['text'] 96 | wav = example['audio']['array'][:self.max_seq_len] 97 | wavpath = example['audio']['path'] 98 | npad = self.max_seq_len - len(wav) 99 | assert npad>=0, f'Waveform length {len(wav)} needs to be less than {self.max_seq_len}' 100 | # [1, L]: Channels x length 101 | audio_duration_sec = len(wav)/ENCODEC_SAMPLING_RATE 102 | wav = torch.tensor(np.pad(wav, pad_width=(0, npad), mode='constant'), dtype=torch.float).unsqueeze(0) 103 | 104 | data = {'wav': wav, 'text': text, 'audio_duration': audio_duration_sec, 'path':wavpath} 105 | 106 | 107 | # Speaker prompting 108 | if self.split in {'valid', 'test'}: 109 | split_dir = self.split2dir[self.split] 110 | speaker_id = example['speaker_id'] 111 | speaker_ds = self.speaker_datasets[speaker_id] 112 | # Sample idx for n-1 elements and remap matching element to the last element 113 | speaker_idx = random.randint(0, len(speaker_ds)-2) 114 | if speaker_ds[speaker_idx]['id'] == example['id']: 115 | speaker_idx = len(speaker_ds)-1 116 | speaker_example = speaker_ds[speaker_idx] 117 | textgrid_path = os.path.join(ALIGNED_DATA_DIR, split_dir, 
f'{speaker_id}', f'{speaker_example["id"]}.TextGrid') 118 | partial_transcript = get_partial_transcript(textgrid_path) 119 | 120 | speaker_text = partial_transcript['transcript'] 121 | speaker_wav_frames = math.ceil(partial_transcript['end_time'] * ENCODEC_SAMPLING_RATE) 122 | speaker_audio_duration_sec = speaker_wav_frames/ENCODEC_SAMPLING_RATE 123 | speaker_wav = speaker_example['audio']['array'][:speaker_wav_frames] 124 | speaker_npad = self.max_seq_len - len(speaker_wav) 125 | assert speaker_npad>=0, f'Waveform length {len(speaker_wav)} needs to be less than {self.max_seq_len}' 126 | # [1, L]: Channels x length 127 | 128 | speaker_wav = torch.tensor(np.pad(speaker_wav, pad_width=(0, speaker_npad), mode='constant'), dtype=torch.float).unsqueeze(0) 129 | 130 | speaker_data = {'speaker_wav': speaker_wav, 'speaker_text': speaker_text, 'speaker_audio_duration': speaker_audio_duration_sec} 131 | data.update(speaker_data) 132 | 133 | 134 | if self.tokenizer is not None: 135 | data['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long) 136 | data['attention_mask'] = torch.tensor(example['attention_mask'], dtype=torch.long) 137 | 138 | data['audio_mask'] = torch.ones((self.max_seq_len//ENCODEC_REDUCTION_FACTOR,), dtype=torch.bool) 139 | 140 | return data 141 | 142 | def __len__(self): 143 | return len(self.hf_dataset) 144 | 145 | if __name__ == "__main__": 146 | dataset = LibriSpeech(split='test') 147 | 148 | example = dataset.__getitem__(0) 149 | import soundfile as sf 150 | import pdb; pdb.set_trace() 151 | sf.write(f'example_audio/librispeech_sample.wav', example['wav'], ENCODEC_SAMPLING_RATE) 152 | with open(f'example_audio/librispeech_text.txt', 'w') as f: 153 | print(example['text'], file=f) 154 | -------------------------------------------------------------------------------- /audio_datasets/mls.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset as TorchDataset 2 | from datasets import load_dataset, Audio, concatenate_datasets 3 | from datasets import Dataset as HFDataset 4 | import numpy as np 5 | import os 6 | import random 7 | import csv 8 | 9 | import torch 10 | from tqdm import tqdm 11 | import torch.nn.functional as F 12 | 13 | from transformers import AutoTokenizer 14 | 15 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE 16 | 17 | def read_csv_into_dict(filename): 18 | data_dir = os.path.dirname(filename) 19 | with open(filename, 'r') as file: 20 | reader = csv.DictReader(file) 21 | data = {'audio': [], 'text':[]} 22 | 23 | for row in reader: 24 | for key, value in row.items(): 25 | if key == 'file_name': 26 | value = os.path.join(data_dir, value) 27 | data['audio'].append(value) 28 | continue 29 | elif key == 'text': 30 | data[key].append(value) 31 | else: 32 | raise ValueError(f'Unexpected csv key: {key}') 33 | 34 | return data 35 | 36 | MAX_DURATION_IN_SECONDS = 20 37 | 38 | def round_up_to_multiple(number, multiple): 39 | remainder = number % multiple 40 | if remainder == 0: 41 | return number 42 | else: 43 | return number + (multiple - remainder) 44 | 45 | def round_up_to_waveform_multiple(number, multiple=16): 46 | waveform_multiple = multiple*ENCODEC_REDUCTION_FACTOR 47 | rounded_number = round_up_to_multiple(number, waveform_multiple) 48 | return rounded_number 49 | 50 | def compute_max_length(multiple=16): 51 | max_len = MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE 52 | 53 | max_len = round_up_to_waveform_multiple(max_len, multiple) 54 | return 
max_len 55 | 56 | def is_audio_length_in_range(audio): 57 | return len(audio['array']) < (MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE) 58 | 59 | 60 | class MLS(TorchDataset): 61 | """ 62 | Wrapper around HuggingFace dataset for processing. 63 | """ 64 | def __init__(self, data_dir='/persist/data/mls/mls_english/' , split='train', debug=False, tokenizer=None, sampling_rate=None, max_text_len=256): 65 | super().__init__() 66 | self.sr = ENCODEC_SAMPLING_RATE if sampling_rate is None else sampling_rate 67 | print('Loading audio dataset...') 68 | if split == 'train': 69 | data_dict = read_csv_into_dict(os.path.join(data_dir, 'train', 'metadata.csv')) 70 | 71 | self.hf_dataset = HFDataset.from_dict(data_dict).cast_column("audio", Audio(sampling_rate=self.sr)) 72 | elif split == 'valid': 73 | self.hf_dataset = load_dataset("audiofolder", data_dir=os.path.join(data_dir, 'dev'))['train'].cast_column("audio", Audio(sampling_rate=self.sr)) 74 | else: 75 | raise ValueError(f"invalid split: {split}, must be in ['train', 'valid'] ") 76 | 77 | # Downsample to accelerate processing for debugging purposes 78 | if debug: 79 | self.hf_dataset = self.hf_dataset.select(range(100)) 80 | # Resample to 24kHz for Encodec 81 | 82 | self.max_text_len = max_text_len 83 | 84 | self.max_seq_len = compute_max_length() 85 | print(f'Max seq length: {self.max_seq_len/ENCODEC_REDUCTION_FACTOR}') 86 | 87 | self.tokenizer = tokenizer 88 | 89 | 90 | def __getitem__(self, index): 91 | example = self.hf_dataset[index] 92 | text = example['text'] 93 | wav = example['audio']['array'] 94 | npad = self.max_seq_len - len(wav) 95 | assert npad>=0, f'Waveform length {len(wav)} needs to be less than {self.max_seq_len}' 96 | # [1, L]: Channels x length 97 | wav_len = len(wav) 98 | audio_duration_sec = len(wav)/self.sr 99 | wav = torch.tensor(np.pad(wav, pad_width=(0, npad), mode='constant'), dtype=torch.float).unsqueeze(0) 100 | 101 | silence_tokens = round(random.random()*(self.max_seq_len-wav_len)) 102 | num_unmasked_tokens = round_up_to_waveform_multiple(wav_len+silence_tokens)//ENCODEC_REDUCTION_FACTOR 103 | audio_mask = torch.zeros((self.max_seq_len//ENCODEC_REDUCTION_FACTOR,), dtype=torch.bool) 104 | audio_mask[:num_unmasked_tokens] = True 105 | 106 | if self.tokenizer is not None: 107 | tokenized_text = self.tokenizer(example['text'], padding="max_length", truncation=True, max_length=self.max_text_len) 108 | input_ids = torch.tensor(tokenized_text['input_ids'], dtype=torch.long) 109 | attention_mask = torch.tensor(tokenized_text['attention_mask'], dtype=torch.long) 110 | else: 111 | input_ids = None 112 | attention_mask = None 113 | return {'wav': wav, 'text': text, 'path':example['audio']['path'], 'audio_duration': audio_duration_sec, 'wav_len':wav_len } 114 | 115 | 116 | return {'wav': wav, 'text': text, 'input_ids': input_ids, 'attention_mask': attention_mask, 'audio_duration': audio_duration_sec, 'wav_len':wav_len, 'audio_mask':audio_mask} 117 | 118 | def __len__(self): 119 | return len(self.hf_dataset) 120 | 121 | if __name__ == "__main__": 122 | text_tokenizer = AutoTokenizer.from_pretrained('google/byt5-small') 123 | train_ds = MLS(split='train', tokenizer=text_tokenizer) 124 | import pdb; pdb.set_trace() 125 | val_ds = MLS(split='valid') 126 | 127 | example = train_ds.__getitem__(0) 128 | import soundfile as sf 129 | sf.write(f'example_audio/mls_sample.wav', example['wav'].squeeze().numpy(), ENCODEC_SAMPLING_RATE) 130 | with open(f'example_audio/mls_text.txt', 'w') as f: 131 | print(example['text'], file=f) 132 | 
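As an aside, the length bookkeeping above is easy to check by hand. A small, self-contained sketch of the arithmetic behind `compute_max_length` and the `Max seq length` printout (constants taken from `audio_datasets/constants.py`; the snippet is illustrative and not part of the training code):

```python
# Reproduce the max-length computation used in audio_datasets/mls.py.
MAX_DURATION_IN_SECONDS = 20
ENCODEC_SAMPLING_RATE = 24000      # Hz
ENCODEC_REDUCTION_FACTOR = 320     # waveform samples per EnCodec latent frame

def round_up_to_multiple(number, multiple):
    remainder = number % multiple
    return number if remainder == 0 else number + (multiple - remainder)

raw_len = MAX_DURATION_IN_SECONDS * ENCODEC_SAMPLING_RATE       # 480,000 samples (20 s at 24 kHz)
waveform_multiple = 16 * ENCODEC_REDUCTION_FACTOR               # pad to a multiple of 16 latent frames
max_seq_len = round_up_to_multiple(raw_len, waveform_multiple)  # 481,280 samples
latent_len = max_seq_len // ENCODEC_REDUCTION_FACTOR            # 1,504 latent frames

print(max_seq_len, latent_len)  # 481280 1504
```

This matches the `Max seq length` value printed when the MLS dataset is constructed, since that printout divides the padded waveform length by the reduction factor.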
-------------------------------------------------------------------------------- /audio_datasets/preprocess_mls.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import soundfile as sf 4 | import numpy as np 5 | from tqdm import tqdm 6 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE 7 | 8 | MAX_DURATION_IN_SECONDS = 20 9 | 10 | 11 | def round_up_to_multiple(number, multiple): 12 | remainder = number % multiple 13 | if remainder == 0: 14 | return number 15 | else: 16 | return number + (multiple - remainder) 17 | 18 | def compute_max_length(multiple=128): 19 | max_len = MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE 20 | waveform_multiple = multiple*ENCODEC_REDUCTION_FACTOR 21 | 22 | max_len = round_up_to_multiple(max_len, waveform_multiple) 23 | return max_len 24 | 25 | def is_audio_length_in_range(audio, sampling_rate): 26 | return len(audio) <= (MAX_DURATION_IN_SECONDS*sampling_rate) 27 | 28 | def main(): 29 | max_length = compute_max_length() 30 | # Define the header names 31 | headers = ['file_name', 'text'] 32 | data_dir = '/persist/data/mls/mls_english/' 33 | 34 | for split in ['train', 'dev', 'test']: 35 | print(f'Converting {split} split...') 36 | # Specify the input and output file paths 37 | split_dir = os.path.join(data_dir, split) 38 | input_file = os.path.join(split_dir, f'transcripts.txt') 39 | output_file = os.path.join(split_dir, f'metadata.csv') 40 | 41 | # Open the input file for reading 42 | with open(input_file, 'r') as file: 43 | # Create a CSV writer for the output file 44 | with open(output_file, 'w', newline='') as csvfile: 45 | writer = csv.writer(csvfile) 46 | 47 | # Write the headers to the CSV file 48 | writer.writerow(headers) 49 | # Read each line in the input file 50 | for line in tqdm(file): 51 | # Split the line into file path and description 52 | audio_id, description = line.strip().split('\t') 53 | speaker_id, book_id, file_id = audio_id.split('_') 54 | file_path = os.path.join('audio', speaker_id, book_id, f'{audio_id}.flac') 55 | audio, samplerate = sf.read(os.path.join(split_dir, file_path)) 56 | if is_audio_length_in_range(audio, samplerate): 57 | # Write the file path and description as a row in the CSV file 58 | writer.writerow([file_path, description]) 59 | 60 | 61 | 62 | print('Conversion complete!') 63 | 64 | 65 | if __name__ == "__main__": 66 | main() -------------------------------------------------------------------------------- /audio_datasets/text_grid_utils.py: -------------------------------------------------------------------------------- 1 | from praatio import textgrid 2 | import os 3 | 4 | def get_word_intervals(textgrid_path): 5 | tg = textgrid.openTextgrid(textgrid_path, False) 6 | return tg.getTier("words").entries 7 | 8 | def get_partial_transcript(textgrid_path, transcript_end_time=3): 9 | intervals = get_word_intervals(textgrid_path) 10 | word_list = [] 11 | end_time = 0 12 | for interval in intervals: 13 | if interval.end > transcript_end_time: 14 | break 15 | word_list.append(interval.label) 16 | end_time = interval.end 17 | return {'transcript': ' '.join(word_list), 18 | 'end_time': end_time} 19 | 20 | 21 | if __name__ == "__main__": 22 | # Smoke test on a sample alignment; adjust the path to your setup 23 | data_path = '/persist/data/aligned_librispeech/test-clean/1089/1089-134686-0000.TextGrid' 24 | print(get_partial_transcript(data_path)) 25 | -------------------------------------------------------------------------------- /data/aligned_librispeech.tar.gz:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/simple-tts/25eb8be644c3dd5243f88b51dea315da62faf0f6/data/aligned_librispeech.tar.gz -------------------------------------------------------------------------------- /diffusion/audio_denoising_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | import random 4 | from functools import partial 5 | from collections import namedtuple, Counter 6 | from multiprocessing import cpu_count 7 | import os 8 | import numpy as np 9 | import csv 10 | import timeit 11 | import json 12 | import argparse 13 | from collections import defaultdict 14 | from contextlib import nullcontext 15 | import soundfile as sf 16 | 17 | 18 | import torch 19 | from torch import nn, einsum 20 | import torch.nn.functional as F 21 | from torch.utils.data import Dataset, DataLoader 22 | 23 | from einops import rearrange, reduce, repeat 24 | 25 | from tqdm.auto import tqdm 26 | from ema_pytorch import EMA 27 | 28 | from transformers import get_scheduler, AutoTokenizer, T5ForConditionalGeneration 29 | 30 | from accelerate import Accelerator, DistributedDataParallelKwargs 31 | 32 | from diffusion.optimizer import get_adamw_optimizer 33 | from utils.utils import compute_grad_norm 34 | import utils.utils as file_utils 35 | from diffusion.noise_schedule import * 36 | from audio_datasets import LibriSpeech, ENCODEC_SAMPLING_RATE, LATENT_SAMPLING_RATE, MLS 37 | from neural_codec.encodec_wrapper import EncodecWrapper 38 | from utils.utils import get_output_dir 39 | 40 | 41 | ModelPrediction = namedtuple('ModelPrediction', ['pred_eps', 'pred_x_start', 'pred_v', 'latents']) 42 | 43 | # Recommendation from https://arxiv.org/abs/2303.09556 44 | MIN_SNR_GAMMA = 5 45 | 46 | # helpers functions 47 | 48 | def exists(x): 49 | return x is not None 50 | 51 | def default(val, d): 52 | if exists(val): 53 | return val 54 | return d() if callable(d) else d 55 | 56 | def identity(t, *args, **kwargs): 57 | return t 58 | 59 | def cycle(dl): 60 | while True: 61 | for data in dl: 62 | yield data 63 | 64 | def num_to_groups(num, divisor): 65 | groups = num // divisor 66 | remainder = num % divisor 67 | arr = [divisor] * groups 68 | if remainder > 0: 69 | arr.append(remainder) 70 | return arr 71 | 72 | def l2norm(t): 73 | return F.normalize(t, dim = -1) 74 | 75 | # Avoid log(0) 76 | def log(t, eps = 1e-20): 77 | return torch.log(t.clamp(min = eps)) 78 | 79 | def right_pad_dims_to(x, t): 80 | padding_dims = x.ndim - t.ndim 81 | if padding_dims <= 0: 82 | return t 83 | return t.view(*t.shape, *((1,) * padding_dims)) 84 | 85 | # gaussian diffusion trainer class 86 | 87 | def extract(a, t, x_shape): 88 | b, *_ = t.shape 89 | out = a.gather(-1, t) 90 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 91 | 92 | def set_seeds(seed): 93 | random.seed(seed) 94 | torch.manual_seed(seed) 95 | torch.cuda.manual_seed(seed) 96 | 97 | def masked_mean(t, *, dim, mask = None): 98 | if not exists(mask): 99 | return t.mean(dim = dim) 100 | 101 | denom = mask.sum(dim = dim) 102 | masked_t = t.masked_fill(~mask, 0.) 
103 | 104 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 105 | 106 | 107 | class GaussianDiffusion(nn.Module): 108 | def __init__( 109 | self, 110 | model, 111 | *, 112 | max_seq_len, 113 | sampling_timesteps = 250, 114 | text_encoder = 'google/byt5-small', 115 | loss_type = 'l1', 116 | objective = 'pred_v', 117 | parameterization = 'pred_v', 118 | train_schedule = 'cosine', 119 | ema_decay = 0.9999, 120 | sampling_schedule = None, 121 | scale = 1., 122 | scale_by_std=True, 123 | sampler = 'ddim', 124 | unconditional_prob = 0.2, 125 | inpainting_prob = 0.5, 126 | inpainting_duration_mode = 0.01, 127 | inpainting_duration_concentration = 5, 128 | ): 129 | super().__init__() 130 | 131 | self.denoising_network = EMA(model, beta = ema_decay, update_every = 1, power=3/4) 132 | 133 | self.max_seq_len = max_seq_len 134 | 135 | self.objective = objective 136 | self.parameterization = parameterization 137 | self.sampler=sampler 138 | 139 | # Min-SNR weighting from https://arxiv.org/abs/2303.09556 140 | # Option no longer supported; buffer kept for backwards compatibility with statedict of prior checkpoints 141 | self.register_buffer('min_snr_gamma', torch.tensor(MIN_SNR_GAMMA)) 142 | 143 | self.loss_type = loss_type 144 | 145 | assert objective in {'pred_eps', 'pred_x0', 'pred_v'}, f'objective {objective} must be one of pred_eps, pred_x0, pred_v' 146 | 147 | if train_schedule == "simple_linear": 148 | alpha_schedule = simple_linear_schedule 149 | elif train_schedule == "beta_linear": 150 | alpha_schedule = beta_linear_schedule 151 | elif train_schedule == "cosine": 152 | alpha_schedule = cosine_schedule 153 | elif train_schedule == "sigmoid": 154 | alpha_schedule = sigmoid_schedule 155 | else: 156 | raise ValueError(f'invalid noise schedule {train_schedule}') 157 | 158 | self.train_schedule = partial(time_to_alpha, alpha_schedule=alpha_schedule, scale=scale) 159 | 160 | # Sampling schedule 161 | if sampling_schedule is None: 162 | sampling_alpha_schedule = None 163 | elif sampling_schedule == "simple_linear": 164 | sampling_alpha_schedule = simple_linear_schedule 165 | elif sampling_schedule == "beta_linear": 166 | sampling_alpha_schedule = beta_linear_schedule 167 | elif sampling_schedule == "cosine": 168 | sampling_alpha_schedule = cosine_schedule 169 | elif sampling_schedule == "sigmoid": 170 | sampling_alpha_schedule = sigmoid_schedule 171 | else: 172 | raise ValueError(f'invalid sampling schedule {sampling_schedule}') 173 | 174 | if exists(sampling_alpha_schedule): 175 | self.sampling_schedule = partial(time_to_alpha, alpha_schedule=sampling_alpha_schedule, scale=scale) 176 | else: 177 | self.sampling_schedule = self.train_schedule 178 | 179 | 180 | # Optionally rescale data to have unit variance 181 | self.scale_by_std = scale_by_std 182 | if scale_by_std: 183 | self.register_buffer('std_scale_factor', torch.tensor(-1.0)) 184 | else: 185 | self.std_scale_factor = 1.0 186 | 187 | # gamma schedules 188 | 189 | self.sampling_timesteps = sampling_timesteps 190 | 191 | # probability for self conditioning during training 192 | 193 | self.unconditional_prob = unconditional_prob 194 | self.inpainting_prob = inpainting_prob 195 | 196 | if self.unconditional_prob > 0: 197 | self.unconditional_bernoulli = torch.distributions.Bernoulli(probs=self.unconditional_prob) 198 | if self.inpainting_prob > 0: 199 | self.inpainting_bernoulli = torch.distributions.Bernoulli(probs=self.inpainting_prob) 200 | # Mode/Concentration parameterization of Beta distribution 201 | alpha = 
inpainting_duration_mode*(inpainting_duration_concentration-2) + 1 202 | beta = (1-inpainting_duration_mode)*(inpainting_duration_concentration-2)+1 203 | self.inpainting_duration_beta = torch.distributions.Beta(alpha, beta) 204 | 205 | self.text_encoder_id = text_encoder 206 | self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder, torch_dtype=torch.bfloat16).get_encoder() 207 | for param in self.text_encoder.parameters(): 208 | param.requires_grad = False 209 | self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder) 210 | 211 | self.audio_codec = EncodecWrapper() 212 | for param in self.audio_codec.parameters(): 213 | param.requires_grad = False 214 | 215 | 216 | def predict_start_from_noise(self, z_t, t, noise, sampling=False): 217 | time_to_alpha = self.sampling_schedule if sampling else self.train_schedule 218 | alpha = time_to_alpha(t) 219 | alpha = right_pad_dims_to(z_t, alpha) 220 | 221 | return (z_t - (1-alpha).sqrt() * noise) / alpha.sqrt().clamp(min = 1e-8) 222 | 223 | def predict_noise_from_start(self, z_t, t, x0, sampling=False): 224 | time_to_alpha = self.sampling_schedule if sampling else self.train_schedule 225 | alpha = time_to_alpha(t) 226 | alpha = right_pad_dims_to(z_t, alpha) 227 | 228 | return (z_t - alpha.sqrt() * x0) / (1-alpha).sqrt().clamp(min = 1e-8) 229 | 230 | def predict_start_from_v(self, z_t, t, v, sampling=False): 231 | time_to_alpha = self.sampling_schedule if sampling else self.train_schedule 232 | alpha = time_to_alpha(t) 233 | alpha = right_pad_dims_to(z_t, alpha) 234 | 235 | x = alpha.sqrt() * z_t - (1-alpha).sqrt() * v 236 | 237 | return x 238 | 239 | def predict_noise_from_v(self, z_t, t, v, sampling=False): 240 | time_to_alpha = self.sampling_schedule if sampling else self.train_schedule 241 | alpha = time_to_alpha(t) 242 | alpha = right_pad_dims_to(z_t, alpha) 243 | 244 | eps = (1-alpha).sqrt() * z_t + alpha.sqrt() * v 245 | 246 | return eps 247 | 248 | def predict_v_from_start_and_eps(self, z_t, t, x, noise, sampling=False): 249 | time_to_alpha = self.sampling_schedule if sampling else self.train_schedule 250 | alpha = time_to_alpha(t) 251 | alpha = right_pad_dims_to(z_t, alpha) 252 | 253 | v = alpha.sqrt() * noise - x* (1-alpha).sqrt() 254 | 255 | return v 256 | 257 | def diffusion_model_predictions(self, z_t, t, *, text_cond, text_cond_mask, sampling=False, cls_free_guidance=1.0, fill_mask=None, audio_mask=None): 258 | time_to_alpha = self.sampling_schedule if sampling else self.train_schedule 259 | time_cond = time_to_alpha(t).sqrt() 260 | latents = None 261 | inpainting_mask = fill_mask[:, 0, :].long() 262 | if sampling: 263 | model_output = self.denoising_network.ema_model(z_t, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask, inpainting_mask=inpainting_mask, audio_mask=audio_mask) 264 | if cls_free_guidance != 1.0: 265 | unc_text_cond = torch.zeros_like(text_cond)[:,:1,:] 266 | unc_text_cond_mask = torch.full_like(text_cond_mask, fill_value=False)[:,:1] 267 | if exists(fill_mask): 268 | alpha = rearrange(time_to_alpha(t), 'b -> b () ()') 269 | noise = torch.randn_like(z_t) 270 | z_t[fill_mask] = (z_t*alpha.sqrt() + (1-alpha).sqrt()*noise)[fill_mask] 271 | unc_inpainting_mask = torch.full_like(inpainting_mask, fill_value=0) 272 | unc_model_output = self.denoising_network.ema_model(z_t, time_cond, text_cond=unc_text_cond, text_cond_mask=unc_text_cond_mask, inpainting_mask=unc_inpainting_mask, audio_mask=audio_mask) 273 | model_output = model_output*cls_free_guidance + 
unc_model_output*(1-cls_free_guidance) 274 | else: 275 | model_output = self.denoising_network.online_model(z_t, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask, inpainting_mask=inpainting_mask, audio_mask=audio_mask) 276 | 277 | pred_v = None 278 | if self.parameterization == 'pred_eps': 279 | pred_eps = model_output 280 | x_start = self.predict_start_from_noise(z_t, t, pred_eps, sampling=sampling) 281 | elif self.parameterization =='pred_x0': 282 | x_start = model_output 283 | pred_eps = self.predict_noise_from_start(z_t, t, x_start, sampling=sampling) 284 | pred_v = self.predict_v_from_start_and_eps(z_t, t, x_start, pred_eps, sampling=sampling) 285 | elif self.parameterization == 'pred_v': 286 | pred_v = model_output 287 | x_start = self.predict_start_from_v(z_t, t, pred_v, sampling=sampling) 288 | pred_eps = self.predict_noise_from_v(z_t, t, pred_v, sampling=sampling) 289 | else: 290 | raise ValueError(f'invalid objective {self.parameterization}') 291 | 292 | return ModelPrediction(pred_eps, x_start, pred_v, latents) 293 | 294 | def get_sampling_timesteps(self, batch, *, device, start_time=1.0): 295 | times = torch.linspace(start_time, 0., self.sampling_timesteps + 1, device = device) 296 | times = repeat(times, 't -> b t', b = batch) 297 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0) 298 | times = times.unbind(dim = -1) 299 | return times 300 | 301 | 302 | @torch.no_grad() 303 | def ddim_or_ddpm_sample(self, shape, text_cond, text_cond_mask, prefix_seconds=0, audio_latent=None, cls_free_guidance=1.0, speaker_frames=None, sampler='ddim'): 304 | batch, device = shape[0], next(self.denoising_network.ema_model.parameters()).device 305 | 306 | time_pairs = self.get_sampling_timesteps(batch, device = device) 307 | 308 | z_t = torch.randn(shape, device=device) 309 | 310 | fill_mask = None 311 | 312 | if prefix_seconds > 0: 313 | assert exists(audio_latent) 314 | if exists(speaker_frames): 315 | num_inpainting_frames = speaker_frames 316 | else: 317 | num_inpainting_frames = round(prefix_seconds*LATENT_SAMPLING_RATE) 318 | torch.full((batch), fill_value=num_inpainting_frames, dtype=torch.int, device=device) 319 | 320 | indices = torch.arange(z_t.shape[2], device=device) 321 | 322 | # Construct mask to insert clean data 323 | fill_mask = repeat((indices <= num_inpainting_frames[:, None]), 'b l -> b c l', c=z_t.shape[1]) 324 | else: 325 | fill_mask = torch.full_like(z_t, fill_value=0, dtype=torch.bool) 326 | 327 | x_start = None 328 | latents = None 329 | 330 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.sampling_timesteps): 331 | # get predicted x0 332 | if prefix_seconds > 0: 333 | z_t[fill_mask] = audio_latent[fill_mask] 334 | if exists(x_start): 335 | x_start[fill_mask] = audio_latent[fill_mask] 336 | 337 | model_output = self.diffusion_model_predictions(z_t, time, text_cond=text_cond, text_cond_mask=text_cond_mask, sampling=True, cls_free_guidance=cls_free_guidance, fill_mask=fill_mask) 338 | 339 | # get alpha sigma of time and next time 340 | 341 | alpha = self.sampling_schedule(time) 342 | alpha_next = self.sampling_schedule(time_next) 343 | alpha, alpha_next = map(partial(right_pad_dims_to, z_t), (alpha, alpha_next)) 344 | 345 | # calculate x0 and noise 346 | 347 | x_start = model_output.pred_x_start 348 | 349 | eps = model_output.pred_eps 350 | 351 | if time_next[0] <= 0: 352 | z_t = x_start 353 | continue 354 | 355 | # get noise 356 | if sampler == 'ddim': 357 | z_t = x_start * alpha_next.sqrt() + eps * 
(1-alpha_next).sqrt() 358 | elif sampler == 'ddpm': 359 | # get noise 360 | noise = torch.randn_like(z_t) 361 | alpha_now = alpha/alpha_next 362 | 363 | min_var = torch.exp(torch.log1p(-alpha_next) - torch.log1p(-alpha)) * (1.0 -alpha_now) 364 | max_var = (1.0 - alpha_now) 365 | noise_param = 0.2 366 | sigma = torch.exp(noise_param * torch.log(max_var) + (1 - noise_param) * torch.log(min_var) ) 367 | z_t = 1/alpha_now.sqrt() * (z_t - (1-alpha_now)/(1-alpha).sqrt() * eps) + torch.sqrt(sigma) * noise 368 | if prefix_seconds > 0: 369 | z_t[fill_mask] = audio_latent[fill_mask] 370 | return z_t 371 | 372 | @torch.no_grad() 373 | def sample(self, data, prefix_seconds=0, cls_free_guidance=1.0): 374 | # [B, L, d_lm]: Embedded text 375 | if prefix_seconds > 0: 376 | merged_text = [' '.join((speaker_text, text)) for speaker_text, text in zip(data['speaker_text'], data['text'])] 377 | tokenizer_output = self.text_tokenizer(merged_text, padding="max_length", truncation=True, max_length=256, return_tensors='pt').to(data['input_ids'].device) 378 | text_cond = self.text_encoder(tokenizer_output['input_ids'], tokenizer_output['attention_mask']).last_hidden_state.float() 379 | text_cond_mask = tokenizer_output['attention_mask'].bool() 380 | audio_latent = self.audio_codec.encode(data['speaker_wav']) 381 | speaker_frames = torch.floor(data['speaker_audio_duration'] * LATENT_SAMPLING_RATE).int() 382 | else: 383 | text_cond = self.text_encoder(data['input_ids'], data['attention_mask']).last_hidden_state.float() 384 | text_cond_mask = data['attention_mask'].bool() 385 | # [B, d_audio, L] 386 | audio_latent = self.audio_codec.encode(data['wav']) 387 | speaker_frames = None 388 | 389 | audio_latent *= self.std_scale_factor 390 | latent_shape = audio_latent.shape 391 | assert self.sampler in {'ddim', 'ddpm'} 392 | sample_fn = partial(self.ddim_or_ddpm_sample, sampler=self.sampler) 393 | return sample_fn(latent_shape, text_cond, text_cond_mask, prefix_seconds=prefix_seconds, audio_latent=audio_latent, cls_free_guidance=cls_free_guidance, speaker_frames=speaker_frames) / self.std_scale_factor 394 | 395 | @property 396 | def loss_fn(self): 397 | if self.loss_type == 'l1': 398 | return F.l1_loss 399 | elif self.loss_type == 'l2': 400 | return F.mse_loss 401 | else: 402 | raise ValueError(f'invalid loss type {self.loss_type}') 403 | 404 | def inpainting_enabled(self): 405 | return self.inpainting_prob > 0 406 | 407 | def forward(self, data, accelerator=None): 408 | 409 | with torch.no_grad(): 410 | # [B, L, d_lm]: Embedded text 411 | text_cond = self.text_encoder(data['input_ids'], data['attention_mask']).last_hidden_state.float() 412 | # [B, L]: Cross-attn mask 413 | text_cond_mask = data['attention_mask'].bool() 414 | 415 | 416 | # [B, d_audio, L]: embedded audio 417 | with torch.cuda.amp.autocast(enabled=False): 418 | audio_latent = self.audio_codec.encode(data['wav']) 419 | 420 | if self.scale_by_std: 421 | # Estimate standard deviation of the data from the first batch 422 | if self.std_scale_factor < 0: 423 | del self.std_scale_factor 424 | gathered_audio_latent = accelerator.gather(audio_latent) 425 | self.register_buffer('std_scale_factor', 1. 
/ gathered_audio_latent.flatten().std()) 426 | print(f'Setting std scale factor: {self.std_scale_factor.item()}') 427 | 428 | audio_latent *= self.std_scale_factor 429 | 430 | batch, audio_channels, audio_length = audio_latent.shape 431 | device = audio_latent.device 432 | 433 | # Mask out text-conditioning with some probability to enable clf-free guidance 434 | if self.unconditional_prob > 0: 435 | unconditional_mask = self.unconditional_bernoulli.sample((batch,)).bool() 436 | text_cond_mask[unconditional_mask, :] = False 437 | 438 | # sample random times 439 | 440 | times = torch.zeros((batch,), device = device).float().uniform_(0, 1.) 441 | 442 | # noise sample 443 | 444 | noise = torch.randn_like(audio_latent) 445 | 446 | alpha = self.train_schedule(times) 447 | alpha = right_pad_dims_to(audio_latent, alpha) 448 | 449 | z_t = alpha.sqrt() * audio_latent + (1-alpha).sqrt() * noise 450 | 451 | # Inpainting logic 452 | inpainting_mask = None 453 | if self.inpainting_prob > 0: 454 | inpainting_batch_mask = self.inpainting_bernoulli.sample((batch,)).bool().to(device) 455 | # Sample durations to mask 456 | inpainting_durations = self.inpainting_duration_beta.sample((batch,)).to(device) * data['audio_duration'] 457 | num_inpainting_frames = torch.round(inpainting_durations*LATENT_SAMPLING_RATE).int() 458 | 459 | # Sample where to mask 460 | indices = torch.arange(audio_length, device=device) 461 | 462 | # Construct mask to insert clean data 463 | inpainting_length_mask = ((indices <= num_inpainting_frames[:, None])) 464 | inpainting_mask = (inpainting_length_mask) & inpainting_batch_mask.unsqueeze(-1) 465 | fill_mask = repeat(inpainting_mask, 'b l -> b c l', c=audio_channels) 466 | 467 | z_t[fill_mask] = audio_latent[fill_mask] 468 | else: 469 | fill_mask = torch.full_like(z_t, fill_value=0, dtype=torch.bool) 470 | 471 | velocity = alpha.sqrt() * noise - (1-alpha).sqrt() * audio_latent 472 | 473 | # predict and take gradient step 474 | predictions = self.diffusion_model_predictions(z_t, times, text_cond=text_cond, text_cond_mask=text_cond_mask, fill_mask=fill_mask, audio_mask=data['audio_mask']) 475 | 476 | 477 | if self.objective == 'pred_x0': 478 | target = audio_latent 479 | pred = predictions.pred_x_start 480 | elif self.objective == 'pred_eps': 481 | target = noise 482 | pred = predictions.pred_eps 483 | elif self.objective == 'pred_v': 484 | # V-prediction from https://openreview.net/forum?id=TIdIXIpzhoI 485 | target = velocity 486 | assert exists(predictions.pred_v) 487 | pred = predictions.pred_v 488 | else: 489 | raise NotImplementedError 490 | 491 | loss = self.loss_fn(pred, target, reduction = 'none') 492 | if self.inpainting_prob > 0: 493 | loss = reduce(loss, 'b c l -> b l', 'mean') 494 | # Standard diffusion loss 495 | diff_batch = loss[~inpainting_batch_mask] 496 | diff_loss = masked_mean(diff_batch, dim=1, mask=data['audio_mask'][~inpainting_batch_mask]) 497 | 498 | # Masked inpainting loss 499 | inpainting_batch = loss[inpainting_batch_mask] 500 | loss_mask = torch.logical_and((~inpainting_length_mask[inpainting_batch_mask]), data['audio_mask'][inpainting_batch_mask]) 501 | inpainting_loss = masked_mean(inpainting_batch, dim=1, mask=loss_mask) 502 | loss = torch.cat([diff_loss, inpainting_loss], dim=0) 503 | else: 504 | loss = reduce(loss, 'b c l -> b l', 'mean') 505 | loss = masked_mean(inpainting_batch, dim=1, mask=data['audio_mask']) 506 | 507 | return loss.mean() 508 | 509 | # trainer class 510 | 511 | class Trainer(object): 512 | def __init__( 513 | self, 514 | args, 515 
| diffusion, 516 | dataset_name, 517 | *, 518 | batch_size = 16, 519 | gradient_accumulate_every = 1, 520 | train_lr = 1e-4, 521 | train_num_steps = 100000, 522 | lr_schedule = 'cosine', 523 | num_warmup_steps = 500, 524 | adam_betas = (0.9, 0.999), 525 | adam_weight_decay = 0.01, 526 | save_and_sample_every = 5000, 527 | num_samples = 25, 528 | mixed_precision = 'no', 529 | prefix_inpainting_seconds=0, 530 | seed=None 531 | ): 532 | super().__init__() 533 | 534 | assert prefix_inpainting_seconds in {0., 3.0}, 'Currently only supports 3sec for inpainting' 535 | if exists(seed): 536 | set_seeds(seed) 537 | 538 | self.args = args 539 | 540 | self.accelerator = Accelerator( 541 | mixed_precision = mixed_precision, 542 | log_with=['mlflow'], 543 | ) 544 | self.num_devices = self.accelerator.num_processes 545 | args.num_devices = self.num_devices 546 | 547 | args.output_dir = get_output_dir(args) 548 | 549 | if self.accelerator.is_main_process: 550 | os.makedirs(args.output_dir) 551 | print(f'Created {args.output_dir}') 552 | 553 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f: 554 | json.dump(args.__dict__, f, indent=2) 555 | run = os.path.split(__file__)[-1].split(".")[0] 556 | if self.num_devices > 1: 557 | run += '_multi' 558 | else: 559 | run += '_debug' 560 | self.accelerator.init_trackers(run, config=vars(args), init_kwargs={"mlflow": {"logging_dir": args.output_dir, "run_name": args.run_name}}) 561 | 562 | 563 | self.diffusion = diffusion 564 | 565 | self.num_samples = num_samples 566 | self.save_and_sample_every = save_and_sample_every 567 | self.prefix_inpainting_seconds = prefix_inpainting_seconds 568 | 569 | self.batch_size = batch_size 570 | self.gradient_accumulate_every = gradient_accumulate_every 571 | 572 | self.train_num_steps = train_num_steps 573 | self.max_seq_len = diffusion.max_seq_len 574 | 575 | 576 | # dataset and dataloader 577 | if dataset_name == 'librispeech': 578 | self.dataset = LibriSpeech(split='train', tokenizer=diffusion.text_tokenizer) 579 | self.val_dataset = LibriSpeech(split='valid', tokenizer=diffusion.text_tokenizer) 580 | self.test_dataset = LibriSpeech(split='test', tokenizer=diffusion.text_tokenizer, max_seq_len=self.dataset.max_seq_len) 581 | elif dataset_name == 'mls': 582 | self.dataset = MLS(split='train', tokenizer=diffusion.text_tokenizer, max_text_len=256 if 'byt5' in diffusion.text_encoder_id else 128) 583 | self.val_dataset = LibriSpeech(split='valid', tokenizer=diffusion.text_tokenizer, max_seq_len=self.dataset.max_seq_len) 584 | self.test_dataset = LibriSpeech(split='test', tokenizer=diffusion.text_tokenizer, max_seq_len=self.dataset.max_seq_len) 585 | else: 586 | raise ValueError(f'invalid dataset: {dataset_name}') 587 | 588 | self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=2) 589 | self.val_dataloader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False) 590 | self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False) 591 | 592 | 593 | # optimizer 594 | 595 | self.opt = get_adamw_optimizer(diffusion, lr = train_lr, betas = adam_betas, weight_decay=adam_weight_decay) 596 | 597 | # scheduler 598 | 599 | lr_scheduler = get_scheduler( 600 | lr_schedule, 601 | optimizer=self.opt, 602 | num_warmup_steps=num_warmup_steps*self.num_devices, 603 | num_training_steps=train_num_steps*self.num_devices, # Accelerate does num_devices steps at a time 604 | ) 605 | 606 | # for logging results in a 
folder periodically 607 | 608 | if self.accelerator.is_main_process: 609 | 610 | self.results_folder = args.output_dir 611 | 612 | # step counter state 613 | 614 | self.step = 0 615 | 616 | # prepare model, dataloader, optimizer with accelerator 617 | 618 | self.diffusion, self.opt, self.dataloader, self.val_dataloader, self.test_dataloader, self.lr_scheduler = self.accelerator.prepare(self.diffusion, self.opt, self.dataloader, self.val_dataloader, self.test_dataloader, lr_scheduler) 619 | self.data_iter = cycle(self.dataloader) 620 | self.val_data_iter = cycle(self.val_dataloader) 621 | 622 | def save(self, save_step=False): 623 | if not self.accelerator.is_local_main_process: 624 | return 625 | 626 | data = { 627 | 'step': self.step, 628 | 'model': self.accelerator.get_state_dict(self.diffusion), 629 | 'opt': self.opt.state_dict(), 630 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, 631 | 'scheduler': self.lr_scheduler.state_dict(), 632 | } 633 | if save_step: 634 | torch.save(data, f'{self.results_folder}/model_{self.step}.pt') 635 | else: 636 | torch.save(data, f'{self.results_folder}/model.pt') 637 | 638 | def load(self, file_path=None, best=False, init_only=False): 639 | file_path = file_path if exists(file_path) else self.results_folder 640 | accelerator = self.accelerator 641 | device = accelerator.device 642 | 643 | if best: 644 | data = torch.load(f'{file_path}/best_model.pt', map_location=device) 645 | else: 646 | data = torch.load(f'{file_path}/model.pt', map_location=device) 647 | 648 | model = self.accelerator.unwrap_model(self.diffusion) 649 | strict_load = not (init_only) 650 | model.load_state_dict(data['model'], strict=strict_load) 651 | 652 | if init_only: 653 | return 654 | 655 | # For backwards compatibility with earlier models 656 | if exists(self.accelerator.scaler) and exists(data['scaler']): 657 | self.accelerator.scaler.load_state_dict(data['scaler']) 658 | 659 | self.opt.load_state_dict(data['opt']) 660 | 661 | self.step = data['step'] 662 | self.lr_scheduler.load_state_dict(data['scheduler']) 663 | 664 | 665 | @torch.no_grad() 666 | def sample(self, num_samples=None, seed=None, cls_free_guidance=1.0, test=False, prefix_seconds=0.): 667 | if exists(seed): 668 | set_seeds(seed) 669 | diffusion = self.accelerator.unwrap_model(self.diffusion) 670 | num_samples = default(num_samples, self.num_samples) 671 | self.diffusion.eval() 672 | inpainting_enabled = diffusion.inpainting_enabled() and diffusion.sampler != 'dpmpp' and exists(num_samples) 673 | num_sampled = 0 674 | dataloader = self.test_dataloader if test else self.val_dataloader 675 | for batch in dataloader: 676 | sampled_codec_latents = diffusion.sample(batch, prefix_seconds=prefix_seconds, cls_free_guidance=cls_free_guidance) 677 | sampled_wavs = diffusion.audio_codec.decode(sampled_codec_latents).squeeze() 678 | sampled_wavs = self.accelerator.gather_for_metrics(sampled_wavs).to('cpu') 679 | 680 | input_ids = self.accelerator.gather_for_metrics(batch['input_ids']).to('cpu') 681 | 682 | speaker_durations = self.accelerator.gather_for_metrics(batch['speaker_audio_duration']).to('cpu') 683 | 684 | if self.accelerator.is_main_process: 685 | inpainting_suffix = f'_prefix{prefix_seconds}' if prefix_seconds>0 else '' 686 | samples_folder = os.path.join(self.results_folder, 'samples', f'step_{self.step}', f'guide{cls_free_guidance}{inpainting_suffix}') 687 | os.makedirs(samples_folder, exist_ok=True) 688 | text_list = [diffusion.text_tokenizer.decode(g, 
skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in input_ids] 689 | ref_frames = torch.ceil(speaker_durations*ENCODEC_SAMPLING_RATE).int() 690 | for idx in range(len(text_list)): 691 | text = text_list[idx] 692 | print(f'Saving idx: {idx+num_sampled}') 693 | with open(os.path.join(samples_folder, f'text_{idx+num_sampled}.txt'), 'w') as f: 694 | print(text, file=f) 695 | if prefix_seconds > 0: 696 | ref_frames_idx = ref_frames[idx].item() 697 | sf.write(os.path.join(samples_folder, f'audio_{idx+num_sampled}.wav'), sampled_wavs[idx][ref_frames_idx:], ENCODEC_SAMPLING_RATE) 698 | sf.write(os.path.join(samples_folder, f'ref_{idx+num_sampled}.wav'), sampled_wavs[idx][:ref_frames_idx], ENCODEC_SAMPLING_RATE) 699 | else: 700 | sf.write(os.path.join(samples_folder, f'audio_{idx+num_sampled}.wav'), sampled_wavs[idx], ENCODEC_SAMPLING_RATE) 701 | batch_size = self.num_devices * batch['wav'].shape[0] 702 | num_sampled += batch_size 703 | 704 | 705 | if exists(num_samples) and num_sampled >= num_samples: 706 | break 707 | 708 | 709 | def train(self): 710 | accelerator = self.accelerator 711 | device = accelerator.device 712 | 713 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: 714 | 715 | while self.step < self.train_num_steps: 716 | 717 | total_loss = 0. 718 | 719 | for _ in range(self.gradient_accumulate_every): 720 | data = next(self.data_iter) 721 | loss = self.diffusion(data, accelerator) 722 | loss = loss / self.gradient_accumulate_every 723 | total_loss += loss.item() 724 | 725 | self.accelerator.backward(loss) 726 | 727 | 728 | if accelerator.sync_gradients: 729 | accelerator.clip_grad_norm_(self.diffusion.parameters(), self.args.clip_grad_norm) 730 | self.opt.step() 731 | grad_norm = compute_grad_norm(self.diffusion.parameters()) 732 | self.lr_scheduler.step() 733 | self.opt.zero_grad() 734 | self.step += 1 735 | 736 | if self.step % 10 == 0: 737 | logs = { 738 | "loss": total_loss, 739 | "learning_rate": self.lr_scheduler.get_last_lr()[0], 740 | "grad_norm": grad_norm, 741 | "step": self.step, 742 | "epoch": (self.step*self.gradient_accumulate_every)/len(self.dataloader), 743 | "samples": self.step*self.batch_size*self.gradient_accumulate_every*self.num_devices 744 | } 745 | # Validation loss 746 | if self.step % 50 == 0: 747 | with torch.no_grad(): 748 | total_val_loss = 0 749 | data = next(self.val_data_iter) 750 | loss = self.diffusion(data) 751 | total_val_loss += loss.item() 752 | logs['val_loss'] = total_val_loss 753 | if accelerator.is_main_process: 754 | pbar.set_postfix(**logs) 755 | accelerator.log(logs, step=self.step) 756 | 757 | accelerator.wait_for_everyone() 758 | # Update EMA 759 | accelerator.unwrap_model(self.diffusion).denoising_network.update() 760 | 761 | if self.step % self.save_and_sample_every == 0: 762 | self.sample() 763 | for cls_free_guidance in [2.0, 3.0, 5.0]: 764 | self.sample(cls_free_guidance=cls_free_guidance) 765 | 766 | if self.prefix_inpainting_seconds > 0: 767 | self.sample(prefix_seconds=self.prefix_inpainting_seconds) 768 | for cls_free_guidance in [2.0, 3.0, 5.0]: 769 | self.sample(cls_free_guidance=cls_free_guidance, prefix_seconds=self.prefix_inpainting_seconds) 770 | self.save() 771 | if self.step % (self.save_and_sample_every*2) == 0: 772 | self.save(save_step=True) 773 | 774 | self.diffusion.train() 775 | 776 | pbar.update(1) 777 | 778 | accelerator.end_training() 779 | accelerator.print('training complete') 
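The conversions between `x_start`, `eps`, and `v` in `GaussianDiffusion` above follow the standard v-prediction identities: with `z_t = sqrt(alpha) * x_start + sqrt(1 - alpha) * eps` and `v = sqrt(alpha) * eps - sqrt(1 - alpha) * x_start`, one recovers `x_start = sqrt(alpha) * z_t - sqrt(1 - alpha) * v` and `eps = sqrt(1 - alpha) * z_t + sqrt(alpha) * v`. A small standalone check of those round-trips (tensor shapes are illustrative only and are not taken from the training code):

```python
import torch

# Check the v-prediction identities used by GaussianDiffusion:
#   z_t = sqrt(alpha) * x0  + sqrt(1 - alpha) * eps
#   v   = sqrt(alpha) * eps - sqrt(1 - alpha) * x0
torch.manual_seed(0)
alpha = torch.rand(8, 1, 1)        # arbitrary alpha_t values in (0, 1)
x0 = torch.randn(8, 128, 1504)     # illustrative latent shape [B, C, L]
eps = torch.randn_like(x0)

z_t = alpha.sqrt() * x0 + (1 - alpha).sqrt() * eps
v = alpha.sqrt() * eps - (1 - alpha).sqrt() * x0

# Invert the parameterization, matching predict_start_from_v / predict_noise_from_v.
x0_rec = alpha.sqrt() * z_t - (1 - alpha).sqrt() * v
eps_rec = (1 - alpha).sqrt() * z_t + alpha.sqrt() * v

assert torch.allclose(x0_rec, x0, atol=1e-5)
assert torch.allclose(eps_rec, eps, atol=1e-5)
print('v-parameterization round-trip checks pass')
```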
-------------------------------------------------------------------------------- /diffusion/noise_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | from functools import partial 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | # Avoid log(0) 9 | def log(t, eps = 1e-12): 10 | return torch.log(t.clamp(min = eps)) 11 | 12 | # noise schedules 13 | 14 | def simple_linear_schedule(t, clip_min = 1e-9): 15 | return (1 - t).clamp(min = clip_min) 16 | 17 | def beta_linear_schedule(t, clip_min = 1e-9): 18 | return torch.exp(-1e-4 - 10 * (t ** 2)).clamp(min = clip_min, max = 1.) 19 | 20 | def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9): 21 | power = 2 * tau 22 | v_start = math.cos(start * math.pi / 2) ** power 23 | v_end = math.cos(end * math.pi / 2) ** power 24 | output = torch.cos((t * (end - start) + start) * math.pi / 2) ** power 25 | output = (v_end - output) / (v_end - v_start) 26 | return output.clamp(min = clip_min) 27 | 28 | def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9): 29 | v_start = torch.tensor(start / tau).sigmoid() 30 | v_end = torch.tensor(end / tau).sigmoid() 31 | gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) 32 | return gamma.clamp_(min = clamp_min, max = 1.) 33 | 34 | # converting gamma to alpha, sigma or logsnr 35 | def log_snr_to_alpha(log_snr): 36 | alpha = torch.sigmoid(log_snr) 37 | return alpha 38 | 39 | # Log-SNR shifting (https://arxiv.org/abs/2301.10972) 40 | def alpha_to_shifted_log_snr(alpha, scale = 1): 41 | return (log(alpha) - log(1 - alpha)).clamp(min=-20, max=20) + 2*np.log(scale).item() 42 | 43 | def time_to_alpha(t, alpha_schedule, scale): 44 | alpha = alpha_schedule(t) 45 | shifted_log_snr = alpha_to_shifted_log_snr(alpha, scale = scale) 46 | return log_snr_to_alpha(shifted_log_snr) 47 | 48 | def plot_noise_schedule(unscaled_sampling_schedule, name, y_value): 49 | assert y_value in {'alpha^2', 'alpha', 'log(SNR)'} 50 | t = torch.linspace(0, 1, 100) # 100 points between 0 and 1 51 | scales = [.2, .5, 1.0] 52 | for scale in scales: 53 | sampling_schedule = partial(time_to_alpha, alpha_schedule=unscaled_sampling_schedule, scale=scale) 54 | alphas = sampling_schedule(t) # Obtain noise schedule values for each t 55 | if y_value == 'alpha^2': 56 | y_axis_label = r'$\alpha^2_t$' 57 | y = alphas 58 | elif y_value == 'alpha': 59 | y_axis_label = r'$\alpha_t$' 60 | y = alphas.sqrt() 61 | elif y_value == 'log(SNR)': 62 | y_axis_label = r'$\log(\lambda_t)$' 63 | y = alpha_to_shifted_log_snr(alphas, scale=1) 64 | 65 | plt.plot(t.numpy(), y.numpy(), label=f'Scale: {scale:.1f}') 66 | if y_value == 'log(SNR)': 67 | plt.ylim(-15, 15) 68 | plt.xlabel('t') 69 | plt.ylabel(y_axis_label) 70 | plt.title(f'{name}') 71 | plt.legend() 72 | plt.savefig(f'viz/{name.lower()}_{y_value}.png') 73 | plt.clf() 74 | 75 | 76 | def plot_side_by_side_noise_schedule(unscaled_sampling_schedule, name): 77 | 78 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5)) 79 | 80 | t = torch.linspace(0, 1, 100) 81 | scales = [1.0, .5, .2] 82 | 83 | for i, y_value in enumerate(['alpha', 'log(SNR)']): 84 | 85 | for scale in scales: 86 | sampling_schedule = partial(time_to_alpha, alpha_schedule=unscaled_sampling_schedule, scale=scale) 87 | alphas = sampling_schedule(t) 88 | 89 | if y_value == 'alpha': 90 | y_axis_label = r'$\alpha_t$' 91 | y = alphas.sqrt() 92 | elif y_value == 'log(SNR)': 93 | y_axis_label = 
r'$\log(\lambda_t)$' 94 | y = alpha_to_shifted_log_snr(alphas, scale=1) 95 | 96 | ax = ax1 if i == 0 else ax2 97 | ax.plot(t.numpy(), y.numpy(), label=f'Scale: {scale:.1f}') 98 | 99 | if y_value == 'log(SNR)': 100 | ax.set_ylim(-15, 15) 101 | 102 | ax.set_xlabel('t', fontsize=14) 103 | ax.set_ylabel(y_axis_label, fontsize=14) 104 | ax.legend() 105 | 106 | 107 | fig.suptitle(f'{name}', fontsize=18) 108 | fig.tight_layout() 109 | plt.savefig(f'viz/{name.lower().replace(" ", "_")}.png') 110 | plt.clf() 111 | 112 | 113 | def plot_cosine_schedule(): 114 | t = torch.linspace(0, 1, 100) # 100 points between 0 and 1 115 | sampling_schedule = cosine_schedule 116 | alphas = sampling_schedule(t) # Obtain noise schedule values for each t 117 | y = alphas 118 | plt.plot(t.numpy(), y.numpy()) 119 | plt.xlabel('t') 120 | plt.ylabel(f'alpha^2') 121 | plt.title(f'Cosine Noise Schedule') 122 | plt.savefig(f'viz/standard_cosine.png') 123 | plt.clf() 124 | 125 | def visualize(): 126 | unscaled_sampling_schedule = cosine_schedule 127 | 128 | plot_side_by_side_noise_schedule(unscaled_sampling_schedule, 'Shifted Cosine Noise Schedules') 129 | 130 | 131 | 132 | if __name__=='__main__': 133 | visualize() -------------------------------------------------------------------------------- /diffusion/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.optim import AdamW 3 | 4 | # Implementation from Timm: https://github.com/huggingface/pytorch-image-models/blob/2d0dbd17e388953ab81a5c56f80074eff962ea6b/timm/optim/optim_factory.py#L40 5 | # Exclude bias and normalization (e.g. LayerNorm) params 6 | def param_groups_weight_decay( 7 | model: nn.Module, 8 | weight_decay=.01, 9 | no_weight_decay_list=() 10 | ): 11 | no_weight_decay_list = set(no_weight_decay_list) 12 | decay = [] 13 | no_decay = [] 14 | for name, param in model.named_parameters(): 15 | if not param.requires_grad: 16 | continue 17 | 18 | if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: 19 | no_decay.append(param) 20 | else: 21 | decay.append(param) 22 | 23 | return [ 24 | {'params': no_decay, 'weight_decay': 0.}, 25 | {'params': decay, 'weight_decay': weight_decay}] 26 | 27 | def get_adamw_optimizer(model, lr, betas, weight_decay): 28 | param_groups = param_groups_weight_decay(model, weight_decay=weight_decay) 29 | 30 | return AdamW(param_groups, lr=lr, weight_decay=weight_decay, betas=betas) 31 | -------------------------------------------------------------------------------- /evaluation/evaluate_transcript.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from transformers import Wav2Vec2Processor, HubertForCTC, Wav2Vec2FeatureExtractor, WavLMForXVector 4 | from evaluate import load 5 | from tqdm import tqdm 6 | from einops import rearrange, reduce, repeat 7 | import sox 8 | import math 9 | 10 | 11 | def chunker(seq, size): 12 | return (seq[pos:pos + size] for pos in range(0, len(seq), size)) 13 | 14 | def trim_silence(wav, sample_rate): 15 | np_arr = wav.numpy().squeeze() 16 | # create a transformer 17 | tfm = sox.Transformer() 18 | tfm.silence(location=-1, silence_threshold=.1) 19 | # transform an in-memory array and return an array 20 | y_out = tfm.build_array(input_array=np_arr, sample_rate_in=sample_rate) 21 | duration = len(y_out)/sample_rate 22 | if duration < .5: 23 | return wav 24 | 25 | return torch.tensor(y_out).unsqueeze(0) 26 | 27 | 28 | 
@torch.inference_mode() 29 | def compute_wer(wavpath_list, text_list, wav_list=None, model_id='facebook/hubert-large-ls960-ft', truncate=False): 30 | processor = Wav2Vec2Processor.from_pretrained(model_id) 31 | model = HubertForCTC.from_pretrained(model_id).to('cuda') 32 | model.eval() 33 | 34 | if wav_list is None: 35 | wav_list = [torchaudio.load(wavpath) for wavpath in wavpath_list] 36 | waveform_list = [] 37 | sample_rate_list = [] 38 | for waveform, sample_rate in wav_list: 39 | waveform_list.append(waveform) 40 | sample_rate_list.append(sample_rate) 41 | 42 | 43 | asr_text = [] 44 | for i in tqdm(range(len(wav_list))): 45 | waveform = rearrange(waveform_list[i].squeeze(), 'l -> () l') 46 | waveform = torchaudio.functional.resample(waveform, sample_rate_list[0], processor.feature_extractor.sampling_rate) 47 | if truncate: 48 | waveform = trim_silence(waveform, processor.feature_extractor.sampling_rate) 49 | input_values = processor(waveform, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").input_values.squeeze(0).to('cuda') 50 | logits = model(input_values).logits 51 | predicted_ids = torch.argmax(logits, dim=-1) 52 | transcription = [transcript.lower().strip() for transcript in processor.batch_decode(predicted_ids)] 53 | 54 | asr_text.extend(transcription) 55 | 56 | if text_list is None: 57 | return asr_text 58 | print(f'asr_text: {asr_text[:3]}') 59 | print(f'text_list: {text_list[:3]}') 60 | wer = load("wer") 61 | wer_score = wer.compute(predictions=asr_text, references=text_list) 62 | return wer_score 63 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from models.modules.norm import RMSNorm, ConvRMSNorm, LayerNorm 2 | from models.modules.transformer import ConditionableTransformer 3 | from models.modules.conformer import ConformerConvBlock 4 | from models.modules.conv import MaskedConv1d -------------------------------------------------------------------------------- /models/modules/conformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, reduce, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | from models.modules.norm import RMSNorm, ConvRMSNorm 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | def zero_init_(m): 14 | nn.init.zeros_(m.weight) 15 | if exists(m.bias): 16 | nn.init.zeros_(m.bias) 17 | 18 | def default(val, d): 19 | if exists(val): 20 | return val 21 | return d() if callable(d) else d 22 | 23 | class ConformerConvBlock(torch.nn.Module): 24 | """Convolution block used in the conformer block""" 25 | 26 | def __init__( 27 | self, 28 | dim, 29 | dim_out=None, 30 | depthwise_kernel_size=17, 31 | expansion_factor=2, 32 | time_cond_dim = None, 33 | channels_first=True, 34 | zero_init=True, 35 | ): 36 | """ 37 | Args: 38 | dim: Embedding dimension 39 | depthwise_kernel_size: Depthwise conv layer kernel size 40 | """ 41 | super(ConformerConvBlock, self).__init__() 42 | dim_out = default(dim_out, dim) 43 | self.channels_first = channels_first 44 | inner_dim = dim * expansion_factor 45 | 46 | self.time_cond = None 47 | self.time_gate = exists(time_cond_dim) and zero_init 48 | if exists(time_cond_dim): 49 | self.time_cond = nn.Sequential( 50 | nn.SiLU(), 51 | nn.Linear(time_cond_dim, dim * 3), 52 | Rearrange('b d -> b 1 d') 
53 | ) 54 | zero_init_(self.time_cond[-2]) 55 | 56 | 57 | assert ( 58 | depthwise_kernel_size - 1 59 | ) % 2 == 0, f"kernel_size: {depthwise_kernel_size} should be a odd number for 'SAME' padding" 60 | self.pointwise_conv1 = torch.nn.Conv1d( 61 | dim, 62 | 2 * inner_dim, 63 | kernel_size=1, 64 | stride=1, 65 | padding=0, 66 | ) 67 | self.glu = torch.nn.GLU(dim=1) 68 | self.depthwise_conv = torch.nn.Conv1d( 69 | inner_dim, 70 | inner_dim, 71 | depthwise_kernel_size, 72 | stride=1, 73 | padding=(depthwise_kernel_size - 1) // 2, 74 | groups=inner_dim 75 | ) 76 | self.norm = ConvRMSNorm(dim) if self.channels_first else RMSNorm(dim) 77 | self.activation = nn.SiLU() 78 | self.pointwise_conv2 = torch.nn.Conv1d( 79 | inner_dim, 80 | dim_out, 81 | kernel_size=1, 82 | stride=1, 83 | padding=0, 84 | ) 85 | if (not self.time_gate) and zero_init: 86 | zero_init_(self.pointwise_conv2) 87 | 88 | 89 | def forward(self, x, time = None, scale_shift=None,): 90 | """ 91 | Args: 92 | x: Input of shape B X T X C 93 | Returns: 94 | Tensor of shape B X T X C 95 | """ 96 | assert not (exists(self.time_cond) and exists(scale_shift)) 97 | 98 | x = self.norm(x) 99 | if exists(self.time_cond): 100 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 101 | x = (x * (scale + 1)) + shift 102 | elif exists(scale_shift): 103 | scale, shift, = scale_shift 104 | x = (x * (scale + 1)) + shift 105 | 106 | if not self.channels_first: 107 | x = rearrange(x, 'b l c -> b c l') 108 | 109 | # GLU mechanism 110 | x = self.pointwise_conv1(x) # (batch, 2*inner_dim, len) 111 | x = self.glu(x) # (batch, inner_dim, len) 112 | 113 | # 1D Depthwise Conv 114 | x = self.depthwise_conv(x) 115 | x = self.activation(x) 116 | 117 | x = self.pointwise_conv2(x) 118 | 119 | if not self.channels_first: 120 | x = rearrange(x, 'b c l -> b l c') 121 | 122 | if exists(self.time_cond): 123 | x = x*gate 124 | 125 | return x 126 | -------------------------------------------------------------------------------- /models/modules/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import typing as tp 3 | import warnings 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from einops import rearrange, reduce, repeat 10 | 11 | 12 | class MaskedConv1d(nn.Conv1d): 13 | """Wrapper around Conv1d to provide masking functionality 14 | """ 15 | def __init__(self, *args, **kwargs): 16 | super().__init__(*args, **kwargs) 17 | 18 | def forward(self, x, audio_mask=None): 19 | if audio_mask is not None: 20 | assert audio_mask.shape[-1] == x.shape[-1] 21 | conv_mask = rearrange(audio_mask, 'b l -> b () l') 22 | x = torch.where(conv_mask, x, 0) 23 | x = super().forward(x) 24 | return x -------------------------------------------------------------------------------- /models/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, reduce, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | 9 | class RMSNorm(nn.Module): 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.scale = dim ** 0.5 13 | self.gamma = nn.Parameter(torch.ones(dim)) 14 | 15 | def forward(self, x): 16 | out = F.normalize(x, dim = -1) * self.scale * self.gamma 17 | return out 18 | 19 | class ConvRMSNorm(RMSNorm): 20 | """ 21 | Convolution-friendly RMSNorm that moves channels to last dimensions 22 | before running the 
normalization and moves them back to original position right after. 23 | """ 24 | def __init__(self, dim): 25 | super().__init__(dim) 26 | 27 | def forward(self, x): 28 | x = rearrange(x, 'b ... t -> b t ...') 29 | x = super().forward(x) 30 | x = rearrange(x, 'b t ... -> b ... t') 31 | return x 32 | 33 | # use layernorm without bias, more stable 34 | 35 | class LayerNorm(nn.Module): 36 | def __init__(self, dim): 37 | super().__init__() 38 | self.gamma = nn.Parameter(torch.ones(dim)) 39 | self.register_buffer("beta", torch.zeros(dim)) 40 | 41 | def forward(self, x): 42 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) -------------------------------------------------------------------------------- /models/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange, reduce, repeat 6 | from einops.layers.torch import Rearrange 7 | import math 8 | 9 | from models.modules.norm import ConvRMSNorm, RMSNorm 10 | from models.modules.conformer import ConformerConvBlock 11 | 12 | def exists(x): 13 | return x is not None 14 | 15 | def zero_init_(m): 16 | nn.init.zeros_(m.weight) 17 | if exists(m.bias): 18 | nn.init.zeros_(m.bias) 19 | 20 | class GEGLU(nn.Module): 21 | def forward(self, x): 22 | x, gate = x.chunk(2, dim = -1) 23 | return F.gelu(gate) * x 24 | 25 | class FeedForward(nn.Module): 26 | def __init__( 27 | self, 28 | dim, 29 | mult = 4, 30 | time_cond_dim = None, 31 | ): 32 | super().__init__() 33 | self.norm = RMSNorm(dim) 34 | inner_dim = int(dim * mult * 2 / 3) 35 | dim_out = dim 36 | 37 | self.time_cond = None 38 | 39 | self.net = nn.Sequential( 40 | nn.Linear(dim, inner_dim*2), 41 | GEGLU(), 42 | nn.Linear(inner_dim, dim_out) 43 | ) 44 | 45 | if exists(time_cond_dim): 46 | self.time_cond = nn.Sequential( 47 | nn.SiLU(), 48 | nn.Linear(time_cond_dim, dim * 3), 49 | Rearrange('b d -> b 1 d') 50 | ) 51 | 52 | zero_init_(self.time_cond[-2]) 53 | else: 54 | zero_init_(self.net[-1]) 55 | 56 | 57 | def forward(self, x, time = None): 58 | x = self.norm(x) 59 | if exists(self.time_cond): 60 | assert exists(time) 61 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 62 | x = (x * (scale + 1)) + shift 63 | 64 | x = self.net(x) 65 | 66 | if exists(self.time_cond): 67 | x = x*gate 68 | 69 | return x 70 | 71 | 72 | class AbsolutePositionalEmbedding(nn.Module): 73 | def __init__(self, dim, max_seq_len=512): 74 | super().__init__() 75 | self.scale = dim ** -0.5 76 | self.max_seq_len = max_seq_len 77 | self.emb = nn.Embedding(max_seq_len, dim) 78 | nn.init.normal_(self.emb.weight, std=.01) 79 | 80 | def forward(self, x, pos = None): 81 | seq_len, device = x.shape[1], x.device 82 | assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' 83 | 84 | if not exists(pos): 85 | pos = torch.arange(seq_len, device = device) 86 | 87 | pos_emb = self.emb(pos) 88 | pos_emb = pos_emb * self.scale 89 | return pos_emb 90 | 91 | # From https://github.com/lucidrains/x-transformers/blob/c7cc22268c8ebceef55fe78343197f0af62edf18/x_transformers/x_transformers.py#L272 92 | class DynamicPositionBias(nn.Module): 93 | def __init__(self, dim, *, heads, depth=2, log_distance = False): 94 | super().__init__() 95 | assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1' 96 | self.log_distance = 
log_distance 97 | 98 | self.mlp = nn.ModuleList([]) 99 | 100 | self.mlp.append(nn.Sequential( 101 | nn.Linear(1, dim), 102 | nn.SiLU() 103 | )) 104 | 105 | for _ in range(depth - 1): 106 | self.mlp.append(nn.Sequential( 107 | nn.Linear(dim, dim), 108 | nn.SiLU() 109 | )) 110 | 111 | self.mlp.append(nn.Linear(dim, heads)) 112 | 113 | @property 114 | def device(self): 115 | return next(self.parameters()).device 116 | 117 | def forward(self, i, j): 118 | assert i == j 119 | n, device = j, self.device 120 | 121 | # get the (n x n) matrix of distances 122 | seq_arange = torch.arange(n, device = device) 123 | context_arange = torch.arange(n, device = device) 124 | indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j') 125 | indices += (n - 1) 126 | 127 | # input to continuous positions MLP 128 | pos = torch.arange(-n + 1, n, device = device).float() 129 | pos = rearrange(pos, '... -> ... 1') 130 | 131 | if self.log_distance: 132 | pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) 133 | 134 | for layer in self.mlp: 135 | pos = layer(pos) 136 | 137 | # get position biases 138 | bias = pos[indices] 139 | bias = rearrange(bias, 'i j h -> h i j') 140 | return bias 141 | 142 | class Attention(nn.Module): 143 | def __init__( 144 | self, 145 | dim, 146 | dim_head = 32, 147 | time_cond_dim = None, 148 | dropout=0. 149 | ): 150 | super().__init__() 151 | assert dim % dim_head == 0, 'Dimension must be divisible by the head dimension' 152 | self.heads = dim // dim_head 153 | 154 | self.dropout = dropout 155 | self.time_cond = None 156 | 157 | self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = self.heads, log_distance = False, depth = 2) 158 | 159 | self.norm = RMSNorm(dim) 160 | 161 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False) 162 | self.to_out = nn.Linear(dim, dim) 163 | 164 | if exists(time_cond_dim): 165 | self.time_cond = nn.Sequential( 166 | nn.SiLU(), 167 | nn.Linear(time_cond_dim, dim * 3), 168 | Rearrange('b d -> b 1 d') 169 | ) 170 | zero_init_(self.time_cond[-2]) 171 | else: 172 | zero_init_(self.to_out) 173 | 174 | def forward(self, x, time=None, audio_mask=None): 175 | b, c, n = x.shape 176 | 177 | x = self.norm(x) 178 | 179 | if exists(self.time_cond): 180 | assert exists(time) 181 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 182 | x = (x * (scale + 1)) + shift 183 | 184 | qkv = self.to_qkv(x).chunk(3, dim = 2) 185 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads).contiguous(), qkv) 186 | 187 | i, j = map(lambda t: t.shape[-2], (q, k)) 188 | 189 | attn_bias = self.rel_pos(i, j) 190 | attn_bias = repeat(attn_bias, 'h i j -> b h i j', b=b) 191 | 192 | if exists(audio_mask): 193 | mask_value = -torch.finfo(q.dtype).max 194 | mask = rearrange(audio_mask, 'b l -> b () () l') 195 | attn_bias = attn_bias.masked_fill(~mask, mask_value) 196 | 197 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0., attn_mask=attn_bias) 198 | out = rearrange(out, 'b h n d -> b n (h d)') 199 | 200 | out = self.to_out(out) 201 | if exists(self.time_cond): 202 | out = out*gate 203 | 204 | return out 205 | 206 | class CrossAttention(nn.Module): 207 | def __init__(self, dim, dim_context, dim_head = 32, time_cond_dim=None, dropout=0.): 208 | super().__init__() 209 | assert dim % dim_head == 0, 'Dimension must be divisible by the head dimension' 210 | self.heads = dim // dim_head 211 | self.dropout = dropout 212 | self.norm = 
RMSNorm(dim) 213 | self.time_cond = None 214 | self.time_cond = None 215 | if exists(time_cond_dim): 216 | self.time_cond = nn.Sequential( 217 | nn.SiLU(), 218 | nn.Linear(time_cond_dim, dim * 3), 219 | Rearrange('b d -> b 1 d') 220 | ) 221 | zero_init_(self.time_cond[-2]) 222 | self.time_cond = None 223 | if exists(time_cond_dim): 224 | self.time_cond = nn.Sequential( 225 | nn.SiLU(), 226 | nn.Linear(time_cond_dim, dim * 3), 227 | Rearrange('b d -> b 1 d') 228 | ) 229 | zero_init_(self.time_cond[-2]) 230 | 231 | 232 | self.norm_context = nn.LayerNorm(dim_context) 233 | 234 | 235 | self.null_kv = nn.Parameter(torch.randn(2, dim)) 236 | self.to_q = nn.Linear(dim, dim, bias = False) 237 | self.to_kv = nn.Linear(dim_context, dim * 2, bias = False) 238 | self.to_out = nn.Linear(dim, dim) 239 | 240 | if exists(time_cond_dim): 241 | self.time_cond = nn.Sequential( 242 | nn.SiLU(), 243 | nn.Linear(time_cond_dim, dim * 3), 244 | Rearrange('b d -> b 1 d') 245 | ) 246 | zero_init_(self.time_cond[-2]) 247 | else: 248 | zero_init_(self.to_out) 249 | 250 | self.q_norm = RMSNorm(dim_head) 251 | self.k_norm = RMSNorm(dim_head) 252 | 253 | def forward(self, x, context, context_mask, time=None): 254 | ''' 255 | x: [B, L_audio, d_unet] 256 | context: [B, L_text, d_lm] 257 | context_mask: [B, L_text] 258 | ''' 259 | b, c, n = x.shape 260 | x = self.norm(x) 261 | if exists(self.time_cond): 262 | assert exists(time) 263 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2) 264 | x = (x * (scale + 1)) + shift 265 | context = self.norm_context(context) 266 | 267 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 268 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads).contiguous(), (q, k, v)) 269 | # Null value for classifier free guidance 270 | 271 | nk, nv = map(lambda t: repeat(t, '(h d) -> b h 1 d', b = b, h=self.heads), self.null_kv.unbind(dim = -2)) 272 | k = torch.cat((nk, k), dim = -2) 273 | v = torch.cat((nv, v), dim = -2) 274 | 275 | query_len = q.shape[2] 276 | 277 | # RMSNorm Trick for stability 278 | q = self.q_norm(q) 279 | k = self.k_norm(k) 280 | # Masking pad tokens 281 | context_mask = F.pad(context_mask, (1, 0), value = True) 282 | context_mask = repeat(context_mask, 'b j -> b h q_len j', h=self.heads, q_len=query_len) 283 | 284 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=context_mask, dropout_p=self.dropout if self.training else 0.) 
285 | # attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask 286 | # attn_weight = torch.softmax((q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))) + attn_mask, dim=-1) 287 | # out = attn_weight @ v 288 | 289 | out = rearrange(out, 'b h n d -> b n (h d)') 290 | 291 | out = self.to_out(out) 292 | if exists(self.time_cond): 293 | out = out*gate 294 | 295 | return out 296 | 297 | 298 | class ConditionableTransformer(nn.Module): 299 | def __init__( 300 | self, 301 | dim, 302 | dim_context, 303 | *, 304 | num_layers, 305 | time_cond_dim, 306 | dim_head = 64, 307 | ff_mult = 4, 308 | dropout=0.0, 309 | conformer=False, 310 | ): 311 | super().__init__() 312 | self.dim = dim 313 | 314 | self.layers = nn.ModuleList([]) 315 | for _ in range(num_layers): 316 | self.layers.append(nn.ModuleList([ 317 | Attention(dim = dim, dim_head = dim_head, dropout=dropout), 318 | CrossAttention(dim = dim, dim_head = dim_head, dim_context=dim_context, dropout=dropout), 319 | ConformerConvBlock(dim, time_cond_dim=time_cond_dim, channels_first=False) if conformer else None, 320 | FeedForward(dim=dim, mult=ff_mult, time_cond_dim=time_cond_dim) 321 | ])) 322 | 323 | 324 | def forward( 325 | self, 326 | x, 327 | *, 328 | time, 329 | context, 330 | context_mask, 331 | audio_mask, 332 | ): 333 | for attn, cross_attn, conv, ff in self.layers: 334 | res = x 335 | x = attn(x, audio_mask=audio_mask) + res 336 | 337 | res = x 338 | x = cross_attn(x, context = context, context_mask=context_mask) + res 339 | 340 | if conv is not None: 341 | res = x 342 | x = conv(x, time=time) + res 343 | 344 | res = x 345 | x = ff(x, time=time) + res 346 | 347 | return x 348 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn, einsum 6 | import torch.nn.functional as F 7 | 8 | from einops import rearrange, reduce, repeat 9 | from einops.layers.torch import Rearrange 10 | 11 | from models.modules import RMSNorm, ConvRMSNorm, ConditionableTransformer, LayerNorm, MaskedConv1d 12 | 13 | # helpers functions 14 | 15 | def exists(x): 16 | return x is not None 17 | 18 | def default(val, d): 19 | if exists(val): 20 | return val 21 | return d() if callable(d) else d 22 | 23 | # small helper modules 24 | 25 | class Residual(nn.Module): 26 | def __init__(self, fn): 27 | super().__init__() 28 | self.fn = fn 29 | 30 | def forward(self, x, *args, **kwargs): 31 | return self.fn(x, *args, **kwargs) + x 32 | 33 | def Upsample(dim, dim_out = None): 34 | return nn.Sequential( 35 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 36 | MaskedConv1d(dim, default(dim_out, dim), 3, padding = 1) 37 | ) 38 | 39 | def Downsample(dim, dim_out = None): 40 | return MaskedConv1d(dim, default(dim_out, dim), 4, 2, 1) 41 | 42 | 43 | # sinusoidal positional embeds 44 | 45 | class SinusoidalPosEmb(nn.Module): 46 | def __init__(self, dim): 47 | super().__init__() 48 | self.dim = dim 49 | 50 | def forward(self, x): 51 | device = x.device 52 | half_dim = self.dim // 2 53 | emb = math.log(10000) / (half_dim - 1) 54 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 55 | emb = x[:, None] * emb[None, :] 56 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 57 | return emb 58 | 59 | class RandomOrLearnedSinusoidalPosEmb(nn.Module): 60 | """ following @crowsonkb 's lead with random (learned optional) 
sinusoidal pos emb """ 61 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 62 | 63 | def __init__(self, dim, is_random = False): 64 | super().__init__() 65 | assert (dim % 2) == 0 66 | half_dim = dim // 2 67 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) 68 | 69 | def forward(self, x): 70 | x = rearrange(x, 'b -> b 1') 71 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 72 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 73 | fouriered = torch.cat((x, fouriered), dim = -1) 74 | return fouriered 75 | # building block modules 76 | 77 | class Block(nn.Module): 78 | def __init__(self, dim, dim_out, groups = 8): 79 | super().__init__() 80 | self.proj = MaskedConv1d(dim, dim_out, 3, padding = 1) 81 | self.norm = nn.GroupNorm(groups, dim) 82 | self.act = nn.SiLU() 83 | 84 | def forward(self, x, scale_shift = None, audio_mask=None): 85 | x = self.norm(x) 86 | if exists(scale_shift): 87 | scale, shift = scale_shift 88 | x = x * (scale + 1) + shift 89 | x = self.act(x) 90 | 91 | x = self.proj(x, audio_mask) 92 | 93 | return x 94 | 95 | class ResnetBlock(nn.Module): 96 | def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 32): 97 | super().__init__() 98 | self.mlp = None 99 | if exists(time_emb_dim): 100 | self.mlp = nn.Sequential( 101 | nn.SiLU(), 102 | nn.Linear(time_emb_dim, dim_out * 3), 103 | Rearrange('b c -> b c 1') 104 | ) 105 | zero_init_(self.mlp[-2]) 106 | 107 | 108 | self.block1 = Block(dim, dim_out, groups = groups) 109 | self.block2 = Block(dim_out, dim_out, groups = groups) 110 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity() 111 | 112 | def forward(self, x, time_emb = None, audio_mask=None): 113 | 114 | scale_shift = None 115 | if exists(self.mlp): 116 | assert exists(time_emb) 117 | scale_shift = self.mlp(time_emb) 118 | scale, shift, gate = scale_shift.chunk(3, dim = 1) 119 | scale_shift = (scale, shift) 120 | 121 | h = self.block1(x, audio_mask=audio_mask) 122 | 123 | h = self.block2(h, scale_shift=scale_shift, audio_mask=audio_mask) 124 | 125 | if exists(self.mlp): 126 | h = h*gate 127 | 128 | return h + self.res_conv(x) 129 | 130 | class LinearAttention(nn.Module): 131 | def __init__(self, dim, heads = 4, dim_head = 32): 132 | super().__init__() 133 | self.scale = dim_head ** -0.5 134 | self.heads = heads 135 | hidden_dim = dim_head * heads 136 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False) 137 | 138 | self.to_out = nn.Conv1d(hidden_dim, dim, 1) 139 | zero_init_(self.to_out) 140 | 141 | def forward(self, x, audio_mask=None): 142 | b, c, n = x.shape 143 | qkv = self.to_qkv(x).chunk(3, dim = 1) 144 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv) 145 | 146 | if exists(audio_mask): 147 | mask_value = -torch.finfo(q.dtype).max 148 | mask = audio_mask[:, None, None, :] 149 | k = k.masked_fill(~mask, mask_value) 150 | v = v.masked_fill(~mask, 0.) 
151 | del mask 152 | 153 | q = q.softmax(dim = -2) 154 | k = k.softmax(dim = -1) 155 | 156 | q = q * self.scale 157 | 158 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 159 | 160 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 161 | out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads) 162 | return self.to_out(out) 163 | 164 | def l2norm(t): 165 | return F.normalize(t, dim = -1) 166 | 167 | 168 | def masked_mean(t, *, dim, mask = None): 169 | if not exists(mask): 170 | return t.mean(dim = dim) 171 | 172 | denom = mask.sum(dim = dim, keepdim = True) 173 | mask = rearrange(mask, 'b n -> b n 1') 174 | masked_t = t.masked_fill(~mask, 0.) 175 | 176 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 177 | 178 | class PreNorm(nn.Module): 179 | def __init__(self, dim, fn): 180 | super().__init__() 181 | self.fn = fn 182 | self.norm = ConvRMSNorm(dim) 183 | 184 | def forward(self, x, *args, **kwargs): 185 | x = self.norm(x) 186 | return self.fn(x, *args, **kwargs) 187 | 188 | def zero_init_(m): 189 | nn.init.zeros_(m.weight) 190 | if exists(m.bias): 191 | nn.init.zeros_(m.bias) 192 | 193 | class FeedForward(nn.Module): 194 | def __init__(self, dim, mult = 4, zero_init=False): 195 | super().__init__() 196 | self.norm = LayerNorm(dim) 197 | 198 | inner_dim = int(dim * mult) 199 | self.net = nn.Sequential( 200 | nn.Linear(dim, inner_dim), 201 | nn.GELU(), 202 | nn.Linear(inner_dim, dim) 203 | ) 204 | if zero_init: 205 | zero_init_(self.net[-1]) 206 | 207 | def forward(self, x): 208 | x = self.norm(x) 209 | 210 | return self.net(x) 211 | 212 | # model 213 | 214 | class Unet1D(nn.Module): 215 | def __init__( 216 | self, 217 | dim, 218 | text_dim, 219 | init_dim = None, 220 | out_dim = None, 221 | dim_mults=(1, 2, 4, 8), 222 | channels = 128, 223 | conformer_transformer=False, 224 | inpainting_embedding = False, 225 | resnet_block_groups = 32, 226 | scale_skip_connection=False, 227 | num_transformer_layers = 3, 228 | dropout=0.0, 229 | ): 230 | super().__init__() 231 | 232 | 233 | self.channels = channels 234 | 235 | input_channels = channels 236 | 237 | init_dim = default(init_dim, dim) 238 | self.init_conv = nn.Conv1d(input_channels, init_dim, 1) 239 | 240 | dims = [init_dim, *map(lambda m: int(dim * m), dim_mults)] 241 | in_out = list(zip(dims[:-1], dims[1:])) 242 | 243 | block_klass = partial(ResnetBlock, groups = resnet_block_groups) 244 | 245 | if inpainting_embedding: 246 | self.inpainting_embedding = nn.Embedding(2, init_dim) 247 | else: 248 | self.inpainting_embedding = None 249 | 250 | # time embeddings 251 | 252 | time_dim = dim * 2 253 | 254 | sinu_pos_emb = SinusoidalPosEmb(dim) 255 | fourier_dim = dim 256 | 257 | self.time_mlp = nn.Sequential( 258 | sinu_pos_emb, 259 | nn.Linear(fourier_dim, time_dim), 260 | nn.SiLU(), 261 | nn.Linear(time_dim, time_dim) 262 | ) 263 | 264 | # layers 265 | 266 | self.downs = nn.ModuleList([]) 267 | self.ups = nn.ModuleList([]) 268 | num_resolutions = len(in_out) 269 | 270 | for ind, (dim_in, dim_out) in enumerate(in_out): 271 | is_last = ind >= (num_resolutions - 1) 272 | 273 | self.downs.append(nn.ModuleList([ 274 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 275 | block_klass(dim_in, dim_in, time_emb_dim = time_dim), 276 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 277 | Downsample(dim_in, dim_out) if not is_last else MaskedConv1d(dim_in, dim_out, 3, padding = 1) 278 | ])) 279 | 280 | mid_dim = dims[-1] 281 | self.transformer = ConditionableTransformer(mid_dim, dim_context=text_dim, 
num_layers=num_transformer_layers, time_cond_dim=time_dim, dropout=dropout, conformer=conformer_transformer) 282 | 283 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): 284 | is_last = ind == (len(in_out) - 1) 285 | 286 | self.ups.append(nn.ModuleList([ 287 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 288 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), 289 | Residual(PreNorm(dim_out, LinearAttention(dim_out))), 290 | Upsample(dim_out, dim_in) if not is_last else MaskedConv1d(dim_out, dim_in, 3, padding = 1) 291 | ])) 292 | 293 | self.out_dim = channels 294 | 295 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) 296 | 297 | self.final_conv = nn.Conv1d(dim, self.out_dim, 1) 298 | 299 | # Accelerates convergence for image diffusion models 300 | # Use it by default, but haven't ablated 301 | self.scale_skip_connection = (2 ** -0.5) if scale_skip_connection else 1 302 | 303 | self.to_text_non_attn_cond = nn.Sequential( 304 | nn.LayerNorm(text_dim), 305 | nn.Linear(text_dim, time_dim), 306 | nn.SiLU(), 307 | nn.Linear(time_dim, time_dim) 308 | ) 309 | 310 | def forward(self, x, time_cond, text_cond=None, text_cond_mask=None, inpainting_mask=None, audio_mask=None): 311 | assert x.shape[-1] % (2**(len(self.downs)-1)) == 0, f'Length of the audio latent must be a factor of {2**(len(self.downs)-1)}' 312 | if not exists(audio_mask): 313 | audio_mask = torch.ones((x.shape[0], x.shape[2]), dtype=torch.bool, device=x.device) 314 | assert torch.remainder(audio_mask.sum(dim=1), (2**(len(self.downs)-1))).sum().item()==0, f'Length of audio mask must be a factor of {2**(len(self.downs)-1)}' 315 | x = self.init_conv(x) 316 | if exists(self.inpainting_embedding): 317 | assert exists(inpainting_mask) 318 | inpainting_emb = self.inpainting_embedding(inpainting_mask) 319 | x = x + rearrange(inpainting_emb, 'b l c -> b c l') 320 | 321 | r = x.clone() 322 | 323 | mean_pooled_context = masked_mean(text_cond, dim=1, mask=text_cond_mask) 324 | text_mean_cond = self.to_text_non_attn_cond(mean_pooled_context) 325 | 326 | # Rescale continuous time [0,1] to similar range as Ho et al. 
2020 327 | t = self.time_mlp(time_cond*1000) 328 | 329 | t = t + text_mean_cond 330 | 331 | h = [] 332 | audio_mask_list = [audio_mask] 333 | for block1, block2, attn, downsample in self.downs: 334 | x = block1(x, t, audio_mask=audio_mask_list[-1]) 335 | h.append(x) 336 | 337 | x = block2(x, t, audio_mask=audio_mask_list[-1]) 338 | x = attn(x, audio_mask=audio_mask_list[-1]) 339 | h.append(x) 340 | 341 | 342 | x_prev_shape = x.shape 343 | x = downsample(x, audio_mask_list[-1]) 344 | if x.shape[-1] != x_prev_shape[-1]: 345 | downsampled_mask = reduce(audio_mask_list[-1], 'b (l 2) -> b l', reduction='max') 346 | audio_mask_list.append(downsampled_mask) 347 | x = rearrange(x, 'b c l -> b l c') 348 | 349 | x = self.transformer(x, context=text_cond, context_mask=text_cond_mask, time=t, audio_mask=audio_mask_list[-1]) 350 | x = rearrange(x, 'b l c -> b c l') 351 | 352 | for block1, block2, attn, upsample in self.ups: 353 | x = torch.cat((x, h.pop()*(self.scale_skip_connection)), dim = 1) 354 | x = block1(x, t, audio_mask_list[-1]) 355 | 356 | x = torch.cat((x, h.pop()*(self.scale_skip_connection)), dim = 1) 357 | x = block2(x, t, audio_mask_list[-1]) 358 | x = attn(x, audio_mask_list[-1]) 359 | 360 | # Awkward implementation to maintain backwards compatibility with previous checkpoints 361 | if isinstance(upsample, nn.Sequential): 362 | # Need to cast to float32 for upsampling 363 | # Upsample operation 364 | x = upsample[0](x.float()) 365 | audio_mask_list.pop() 366 | # Masked conv operation 367 | x = upsample[1](x, audio_mask_list[-1]) 368 | 369 | else: 370 | x = upsample(x, audio_mask_list[-1]) 371 | 372 | x = torch.cat((x, r), dim = 1) 373 | 374 | x = self.final_res_block(x, t, audio_mask) 375 | return self.final_conv(x) -------------------------------------------------------------------------------- /neural_codec/encodec_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torch import nn, einsum 4 | from encodec import EncodecModel 5 | from encodec.utils import _linear_overlap_add 6 | from encodec.utils import convert_audio 7 | import typing as tp 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from einops import rearrange 12 | from functools import partial 13 | import os 14 | 15 | EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]] 16 | 17 | class EncodecWrapper(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | self.codec = EncodecModel.encodec_model_24khz() 21 | self.codec.set_target_bandwidth(24.) 22 | 23 | def encode(self, x: torch.Tensor) -> torch.Tensor: 24 | """Given a tensor `x`, returns a list of frames containing 25 | the discrete encoded codes for `x`, along with rescaling factors 26 | for each segment, when `self.normalize` is True. 27 | 28 | Each frames is a tuple `(codebook, scale)`, with `codebook` of 29 | shape `[B, K, T]`, with `K` the number of codebooks. 
30 | """ 31 | assert x.dim() == 3 32 | _, channels, length = x.shape 33 | assert channels > 0 and channels <= 2 34 | segment_length = self.codec.segment_length 35 | if segment_length is None: 36 | segment_length = length 37 | stride = length 38 | else: 39 | stride = self.codec.segment_stride # type: ignore 40 | assert stride is not None 41 | 42 | encoded_frames: tp.List[EncodedFrame] = [] 43 | for offset in range(0, length, stride): 44 | frame = x[:, :, offset: offset + segment_length] 45 | encoded_frames.append(self._encode_frame(frame)) 46 | assert len(encoded_frames) == 1 47 | assert encoded_frames[0][1] is None 48 | return encoded_frames[0][0] 49 | 50 | def _encode_frame(self, x: torch.Tensor) -> EncodedFrame: 51 | length = x.shape[-1] 52 | duration = length / self.codec.sample_rate 53 | assert self.codec.segment is None or duration <= 1e-5 + self.codec.segment 54 | 55 | if self.codec.normalize: 56 | mono = x.mean(dim=1, keepdim=True) 57 | volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() 58 | scale = 1e-8 + volume 59 | x = x / scale 60 | scale = scale.view(-1, 1) 61 | else: 62 | scale = None 63 | 64 | emb = self.codec.encoder(x) 65 | return emb, scale 66 | 67 | def decode(self, emb: torch.Tensor, quantize: bool=True) -> torch.Tensor: 68 | """Decode the given frames into a waveform. 69 | Note that the output might be a bit bigger than the input. In that case, 70 | any extra steps at the end can be trimmed. 71 | """ 72 | encoded_frames = [(emb, None)] 73 | segment_length = self.codec.segment_length 74 | if segment_length is None: 75 | assert len(encoded_frames) == 1 76 | return self._decode_frame(encoded_frames[0]) 77 | 78 | frames = [self._decode_frame(frame, quantize=quantize) for frame in encoded_frames] 79 | return _linear_overlap_add(frames, self.segment_stride or 1) 80 | 81 | def _decode_frame(self, encoded_frame: EncodedFrame, quantize: bool=True) -> torch.Tensor: 82 | emb, scale = encoded_frame 83 | if quantize: 84 | codes = self.codec.quantizer.encode(emb, self.codec.frame_rate, self.codec.bandwidth) 85 | emb = self.codec.quantizer.decode(codes) 86 | 87 | # codes is [B, K, T], with T frames, K nb of codebooks. 
88 | out = self.codec.decoder(emb) 89 | if scale is not None: 90 | out = out * scale.view(-1, 1, 1) 91 | return out 92 | 93 | 94 | def forward(self, wav:torch.tensor, sr:int, quantize:bool=True): 95 | # TODO: Revisit where to handle processing 96 | wav = convert_audio(wav, sr, self.codec.sample_rate, self.codec.channels) 97 | frames = self.encode(wav) 98 | return self.decode(frames, quantize=quantize)[:, :, :wav.shape[-1]] 99 | 100 | 101 | 102 | def test(): 103 | codec = EncodecWrapper().to('cuda') 104 | import sys 105 | import soundfile as sf 106 | sys.path.append('../') 107 | from audio_datasets.librispeech import LibriSpeech, ENCODEC_SAMPLING_RATE 108 | from diffusion.noise_schedule import simple_linear_schedule, time_to_alpha, cosine_schedule 109 | 110 | test_dataset = LibriSpeech(split='test') 111 | with torch.no_grad(): 112 | for idx in range(len(test_dataset)): 113 | example = test_dataset.__getitem__(idx) 114 | 115 | # [B, 1, L]: batch x channels x length 116 | batched_wav = example['wav'][:,:int(example['audio_duration']*ENCODEC_SAMPLING_RATE)].unsqueeze(0).to('cuda') 117 | 118 | scales = [1.0, .5, .2] 119 | for scale in scales: 120 | times = [.05*x for x in range(0, 21)] 121 | alpha_schedule = partial(time_to_alpha, alpha_schedule=cosine_schedule, scale=scale) 122 | os.makedirs(f'example_audio/cosine/scale{scale}/', exist_ok=True) 123 | std_scale_factor = .22 124 | for time in times: 125 | os.makedirs(f'example_audio/cosine/scale{scale}/time{time}', exist_ok=True) 126 | alpha = alpha_schedule(torch.tensor([time], device=batched_wav.device)) 127 | 128 | wav_emb = codec.encode(batched_wav)*std_scale_factor 129 | noisy_wav_emb = alpha.sqrt()*wav_emb + (1-alpha).sqrt()*torch.randn_like(wav_emb) 130 | noisy_wav_emb /= alpha.sqrt() 131 | noisy_wav_emb /= std_scale_factor 132 | noisy_reconstruction = codec.decode(noisy_wav_emb) 133 | 134 | sf.write(f'example_audio/cosine/scale{scale}/time{time}/audio_{idx}.wav', noisy_reconstruction.squeeze().to('cpu').numpy(), ENCODEC_SAMPLING_RATE) 135 | 136 | 137 | 138 | 139 | if __name__=='__main__': 140 | test() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.21.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | alembic==1.11.1 6 | anyascii==0.3.2 7 | appdirs==1.4.4 8 | async-timeout==4.0.2 9 | attrs==23.1.0 10 | audioread==3.0.0 11 | Babel==2.12.1 12 | bangla==0.0.2 13 | biopython==1.79 14 | blinker==1.6.2 15 | bnnumerizer==0.0.2 16 | bnunicodenormalizer==0.1.1 17 | boltons==23.0.0 18 | bottle==0.12.25 19 | cachetools==5.3.1 20 | certifi==2023.5.7 21 | cffi==1.15.1 22 | charset-normalizer==3.1.0 23 | clean-fid==0.1.35 24 | click==8.1.3 25 | clip-anytorch==2.5.2 26 | cloudpickle==2.2.1 27 | cmake==3.26.3 28 | contourpy==1.0.7 29 | coqpit==0.0.17 30 | cycler==0.11.0 31 | Cython==0.29.28 32 | databricks-cli==0.17.7 33 | dataclassy==1.0.1 34 | datasets==2.12.0 35 | dateparser==1.1.8 36 | decorator==5.1.1 37 | dill==0.3.6 38 | docker==6.1.3 39 | docker-pycreds==0.4.0 40 | docopt==0.6.2 41 | einops==0.6.1 42 | ema-pytorch==0.2.3 43 | encodec==0.1.1 44 | entrypoints==0.4 45 | evaluate==0.4.0 46 | filelock==3.12.0 47 | Flask==2.3.2 48 | fonttools==4.39.4 49 | frozenlist==1.3.3 50 | fsspec==2023.5.0 51 | ftfy==6.1.1 52 | g2pkk==0.1.2 53 | gitdb==4.0.10 54 | GitPython==3.1.31 55 | google-auth==2.22.0 56 | google-auth-oauthlib==1.0.0 57 | greenlet==2.0.2 58 | grpcio==1.56.0 59 | gruut==2.2.3 60 | 
gruut-ipa==0.13.0 61 | gruut-lang-de==2.0.0 62 | gruut-lang-en==2.0.0 63 | gruut-lang-es==2.0.0 64 | gruut-lang-fr==2.0.2 65 | gunicorn==20.1.0 66 | hdbscan==0.8.33 67 | huggingface-hub==0.15.1 68 | HyperPyYAML==1.2.1 69 | idna==3.4 70 | imageio==2.31.0 71 | importlib-metadata==6.6.0 72 | inflect==5.6.0 73 | initdb==0.5.0 74 | itsdangerous==2.1.2 75 | jamo==0.4.1 76 | jieba==0.42.1 77 | Jinja2==3.1.2 78 | jiwer==3.0.2 79 | joblib==1.2.0 80 | jsonlines==1.2.0 81 | jsonmerge==1.9.0 82 | jsonschema==4.17.3 83 | k-diffusion==0.0.15 84 | kiwisolver==1.4.4 85 | kneed==0.8.5 86 | kornia==0.6.12 87 | lazy_loader==0.2 88 | librosa==0.10.0.post2 89 | lit==16.0.5.post0 90 | llvmlite==0.39.1 91 | Mako==1.2.4 92 | Markdown==3.4.3 93 | markdown-it-py==3.0.0 94 | MarkupSafe==2.1.3 95 | matplotlib==3.7.1 96 | mdurl==0.1.2 97 | mecab-python3==1.0.5 98 | mlflow==2.4.1 99 | Montreal-Forced-Aligner==2.2.15 100 | mpmath==1.3.0 101 | msgpack==1.0.5 102 | multidict==6.0.4 103 | multiprocess==0.70.14 104 | networkx==2.8.8 105 | nltk==3.8.1 106 | num2words==0.5.12 107 | numba==0.56.4 108 | numpy==1.23.5 109 | nvidia-cublas-cu11==11.10.3.66 110 | nvidia-cuda-cupti-cu11==11.7.101 111 | nvidia-cuda-nvrtc-cu11==11.7.99 112 | nvidia-cuda-runtime-cu11==11.7.99 113 | nvidia-cudnn-cu11==8.5.0.96 114 | nvidia-cufft-cu11==10.9.0.58 115 | nvidia-curand-cu11==10.2.10.91 116 | nvidia-cusolver-cu11==11.4.0.1 117 | nvidia-cusparse-cu11==11.7.4.91 118 | nvidia-nccl-cu11==2.14.3 119 | nvidia-nvtx-cu11==11.7.91 120 | oauthlib==3.2.2 121 | packaging==23.1 122 | pandas==2.0.2 123 | pathtools==0.1.2 124 | pgvector==0.1.8 125 | Pillow==9.5.0 126 | pooch==1.6.0 127 | praatio==6.0.1 128 | protobuf==3.19.6 129 | psutil==5.9.5 130 | psycopg2==2.9.6 131 | pyarrow==12.0.0 132 | pyasn1==0.5.0 133 | pyasn1-modules==0.3.0 134 | pycparser==2.21 135 | pydub==0.25.1 136 | Pygments==2.15.1 137 | PyJWT==2.7.0 138 | pynini==2.1.5 139 | pynndescent==0.5.10 140 | pyparsing==3.0.9 141 | pypinyin==0.49.0 142 | pyrsistent==0.19.3 143 | pysbd==0.3.4 144 | python-crfsuite==0.9.9 145 | python-dateutil==2.8.2 146 | pytz==2023.3 147 | PyWavelets==1.4.1 148 | PyYAML==6.0 149 | querystring-parser==1.2.4 150 | rapidfuzz==2.13.7 151 | regex==2023.6.3 152 | requests==2.31.0 153 | requests-oauthlib==1.3.1 154 | resampy==0.4.2 155 | resize-right==0.0.2 156 | responses==0.18.0 157 | rich==13.4.2 158 | rich-click==1.6.1 159 | rsa==4.9 160 | ruamel.yaml==0.17.28 161 | ruamel.yaml.clib==0.2.7 162 | safetensors==0.3.1 163 | scikit-image==0.21.0 164 | scikit-learn==1.2.2 165 | scipy==1.10.1 166 | seaborn==0.12.2 167 | sentencepiece==0.1.99 168 | sentry-sdk==1.25.1 169 | setproctitle==1.3.2 170 | six==1.16.0 171 | smmap==5.0.0 172 | soundfile==0.12.1 173 | sox==1.4.1 174 | soxr==0.3.5 175 | speechbrain==0.5.14 176 | SQLAlchemy==2.0.16 177 | sqlparse==0.4.4 178 | sympy==1.12 179 | tabulate==0.9.0 180 | tensorboard==2.13.0 181 | tensorboard-data-server==0.7.1 182 | tensorboardX==2.6 183 | threadpoolctl==3.1.0 184 | tifffile==2023.4.12 185 | tokenizers==0.13.3 186 | torch==2.0.1 187 | torchaudio==2.0.2 188 | torchdiffeq==0.2.3 189 | torchsde==0.2.5 190 | torchvision==0.15.2 191 | tqdm==4.65.0 192 | trainer==0.0.20 193 | trampoline==0.1.2 194 | transformers==4.31.0 195 | triton==2.0.0 196 | TTS==0.14.3 197 | typing_extensions==4.6.3 198 | tzdata==2023.3 199 | tzlocal==5.0.1 200 | umap-learn==0.5.1 201 | unidic-lite==1.0.8 202 | urllib3==1.26.16 203 | waitress==2.1.2 204 | wandb==0.15.4 205 | wcwidth==0.2.6 206 | websocket-client==1.5.3 207 | Werkzeug==2.3.5 208 | xxhash==3.2.0 
209 | yarl==1.9.2 210 | zipp==3.15.0 211 | -------------------------------------------------------------------------------- /scripts/sample/sample_16_ls_testclean.sh: -------------------------------------------------------------------------------- 1 | python train_audio_diffusion.py --eval_test --resume_dir saved_models/uvit_byt5large_12layer_final_scale50 --sampling_timesteps 250 --run_name test/sample16 --sampler ddpm --seed 42 --num_samples 16 --scale 0.5 --guidance 2.0,3.0,5.0 2 | -------------------------------------------------------------------------------- /scripts/train/train.sh: -------------------------------------------------------------------------------- 1 | python ./train_audio_diffusion.py --dataset_name mls --text_encoder google/byt5-large --batch_size 16 --gradient_accumulation_steps 2 --run_name uat_byt5large_12layer --save_and_sample_every 5000 --learning_rate 1e-4 --mixed_precision bf16 --scale .5 --loss_type l1 --scale_skip_connection --inpainting_prob 0.5 --num_train_steps 200000 --num_transformer_layers 12 --dim 512 --dim_mults 1,1,1,1.5 --num_samples 128 --inpainting_embedding -------------------------------------------------------------------------------- /scripts/train/train_distributed.sh: -------------------------------------------------------------------------------- 1 | accelerate launch ./train_audio_diffusion.py --dataset_name mls --text_encoder google/byt5-large --batch_size 16 --gradient_accumulation_steps 2 --run_name uat_byt5large_12layer --save_and_sample_every 5000 --learning_rate 1e-4 --mixed_precision bf16 --scale .5 --loss_type l1 --scale_skip_connection --inpainting_prob 0.5 --num_train_steps 200000 --num_transformer_layers 12 --dim 512 --dim_mults 1,1,1,1.5 --num_samples 128 --inpainting_embedding -------------------------------------------------------------------------------- /train_audio_diffusion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.utils import get_output_dir, parse_float_tuple 3 | import json 4 | import os 5 | import numpy as np 6 | 7 | from diffusion.audio_denoising_diffusion import GaussianDiffusion, Trainer 8 | from models.unet import Unet1D 9 | from transformers import AutoConfig 10 | 11 | 12 | def main(args): 13 | 14 | config = AutoConfig.from_pretrained(args.text_encoder) 15 | text_dim = config.d_model 16 | 17 | model = Unet1D( 18 | dim=args.dim, 19 | text_dim=text_dim, 20 | dim_mults=args.dim_mults, 21 | inpainting_embedding = args.inpainting_embedding, 22 | conformer_transformer=args.conformer_transformer, 23 | num_transformer_layers=args.num_transformer_layers, 24 | scale_skip_connection=args.scale_skip_connection, 25 | dropout=args.dropout, 26 | ) 27 | 28 | args.trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 29 | print(f'Trainable params: {args.trainable_params}') 30 | 31 | diffusion = GaussianDiffusion( 32 | model, 33 | max_seq_len = 2048, 34 | text_encoder=args.text_encoder, 35 | sampling_timesteps = args.sampling_timesteps, # number of sampling steps 36 | sampler=args.sampler, 37 | train_schedule= args.train_schedule, 38 | sampling_schedule= args.sampling_schedule, 39 | loss_type = args.loss_type, # L1 or L2 40 | objective = args.objective, 41 | parameterization = args.parameterization, 42 | ema_decay = args.ema_decay, 43 | scale = args.scale, 44 | unconditional_prob = args.unconditional_prob, 45 | inpainting_prob = args.inpainting_prob, 46 | ) 47 | 48 | trainer = Trainer( 49 | args=args, 50 | diffusion=diffusion, 
51 | dataset_name=args.dataset_name, 52 | batch_size= args.batch_size, 53 | gradient_accumulate_every = args.gradient_accumulation_steps, 54 | train_lr = args.learning_rate, 55 | train_num_steps = args.num_train_steps, 56 | lr_schedule = args.lr_schedule, 57 | num_warmup_steps = args.lr_warmup_steps, 58 | adam_betas = (args.adam_beta1, args.adam_beta2), 59 | adam_weight_decay = args.adam_weight_decay, 60 | save_and_sample_every = args.save_and_sample_every, 61 | num_samples = args.num_samples, 62 | mixed_precision = args.mixed_precision, 63 | prefix_inpainting_seconds = args.prefix_inpainting_seconds, 64 | seed=args.seed, 65 | ) 66 | 67 | if args.eval or args.eval_test: 68 | trainer.load(args.resume_dir) 69 | # trainer.sample() 70 | if args.prefix_inpainting_seconds > 0: 71 | cls_free_guidances = args.guidance 72 | elif 'ablation' in args.run_name: 73 | cls_free_guidances = [5.0] # Use stronger guidance for ablations b/c they are under-trained 74 | else: 75 | cls_free_guidances = args.guidance 76 | for cls_free_guidance in cls_free_guidances: 77 | trainer.sample(cls_free_guidance=cls_free_guidance, prefix_seconds=args.prefix_inpainting_seconds, test=args.eval_test, seed=42) 78 | return 79 | 80 | if args.init_model is not None: 81 | trainer.load(args.init_model, init_only=True) 82 | 83 | if args.resume: 84 | trainer.load(args.resume_dir) 85 | 86 | trainer.train() 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser(description="Training arguments") 90 | parser.add_argument("--dataset_name", type=str, default='librispeech') 91 | parser.add_argument("--save_dir", type=str, default="saved_models") 92 | parser.add_argument("--text_encoder", type=str, default="google/byt5-small") 93 | parser.add_argument("--output_dir", type=str, default=None) 94 | parser.add_argument("--resume_dir", type=str, default=None) 95 | parser.add_argument("--init_model", type=str, default=None) 96 | parser.add_argument("--run_name", type=str, default=None) 97 | parser.add_argument("--seed", type=int, default=None) 98 | # Architecture hyperparameters 99 | parser.add_argument("--dim", type=int, default=512) 100 | parser.add_argument('--dim_mults', type=parse_float_tuple, default=(1, 1, 1, 1.5), help='Tuple of integer values for dim_mults') 101 | parser.add_argument("--conformer_transformer", action="store_true", default=False) 102 | parser.add_argument("--scale_skip_connection", action="store_true", default=False) 103 | parser.add_argument("--num_transformer_layers", type=int, default=3) 104 | parser.add_argument("--dropout", type=float, default=0.) 
105 | parser.add_argument("--inpainting_embedding", action="store_true", default=False) 106 | 107 | # Optimization hyperparameters 108 | parser.add_argument("--optimizer", type=str, default="adamw") 109 | parser.add_argument("--batch_size", type=int, default=16) 110 | parser.add_argument("--num_train_steps", type=int, default=60000) 111 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1) 112 | parser.add_argument("--learning_rate", type=float, default=1e-4) 113 | parser.add_argument("--clip_grad_norm", type=float, default=1.0) 114 | parser.add_argument("--lr_schedule", type=str, default="cosine") 115 | parser.add_argument("--lr_warmup_steps", type=int, default=1000) 116 | parser.add_argument("--adam_beta1", type=float, default=0.9) 117 | parser.add_argument("--adam_beta2", type=float, default=0.999) 118 | parser.add_argument("--adam_weight_decay", type=float, default=0) 119 | parser.add_argument("--ema_decay", type=float, default=0.9999) 120 | # Diffusion Hyperparameters 121 | parser.add_argument( 122 | "--objective", 123 | type=str, 124 | default="pred_v", 125 | choices=["pred_eps", "pred_x0", "pred_v"], 126 | help=( 127 | "Which loss objective to use for the diffusion objective." 128 | ), 129 | ) 130 | parser.add_argument( 131 | "--parameterization", 132 | type=str, 133 | default="pred_v", 134 | choices=["pred_eps", "pred_x0", "pred_v"], 135 | help=( 136 | "Which output parameterization to use for the diffusion network." 137 | ), 138 | ) 139 | parser.add_argument( 140 | "--loss_type", 141 | type=str, 142 | default="l1", 143 | choices=["l1", "l2"], 144 | help=( 145 | "Which loss function to use for diffusion." 146 | ), 147 | ) 148 | parser.add_argument( 149 | "--train_schedule", 150 | type=str, 151 | default="cosine", 152 | choices=["beta_linear", "simple_linear", "cosine", 'sigmoid'], 153 | help=( 154 | "Which noise schedule to use." 155 | ), 156 | ) 157 | parser.add_argument( 158 | "--sampling_schedule", 159 | type=str, 160 | default=None, 161 | choices=["beta_linear", "cosine", "simple_linear", None], 162 | help=( 163 | "Which noise schedule to use." 164 | ), 165 | ) 166 | parser.add_argument("--resume", action="store_true", default=False) 167 | parser.add_argument("--scale", type=float, default=1.0) 168 | parser.add_argument("--sampling_timesteps", type=int, default=250) 169 | # Audio Training Parameters 170 | parser.add_argument("--unconditional_prob", type=float, default=.1) 171 | parser.add_argument("--inpainting_prob", type=float, default=.5) 172 | # Generation Arguments 173 | parser.add_argument("--save_and_sample_every", type=int, default=5000) 174 | parser.add_argument("--num_samples", type=int, default=None) 175 | parser.add_argument( 176 | "--sampler", 177 | type=str, 178 | default="ddim", 179 | choices=["ddim", "ddpm"], 180 | help=( 181 | "Which sampler use for diffusion." 182 | ), 183 | ) 184 | parser.add_argument("--prefix_inpainting_seconds", type=float, default=0.) 185 | # Accelerate arguments 186 | parser.add_argument( 187 | "--mixed_precision", 188 | type=str, 189 | default="no", 190 | choices=["no", "fp16", "bf16"], 191 | help=( 192 | "Whether to use mixed precision. Choose" 193 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 194 | "and an Nvidia Ampere GPU." 
195 |         ),
196 |     )
197 |     # Load and eval model
198 |     parser.add_argument("--eval", action="store_true", default=False)
199 |     parser.add_argument("--eval_test", action="store_true", default=False)
200 |     parser.add_argument('--guidance', type=parse_float_tuple, help='Tuple of float values for classifier-free guidance')
201 | 
202 |     args = parser.parse_args()
203 |     if args.eval or args.eval_test:
204 |         assert args.resume_dir is not None
205 | 
206 |     if args.eval or args.eval_test:
207 |         with open(os.path.join(args.resume_dir, 'args.json'), 'rt') as f:
208 |             saved_args = json.load(f)
209 |         args_dict = vars(args)
210 |         heldout_params = {'run_name', 'output_dir', 'resume_dir', 'eval', 'eval_test', 'prefix_inpainting_seconds', 'num_samples', 'sampling_timesteps', 'sampling_schedule', 'scale', 'sampler', 'mixed_precision', 'guidance',}
211 |         for k,v in saved_args.items():
212 |             if k in heldout_params:
213 |                 continue
214 |             args_dict[k] = v
215 | 
216 |     main(args)
--------------------------------------------------------------------------------
/training.md:
--------------------------------------------------------------------------------
1 | # Training Guide
2 | 
3 | This guide lays out the training-related parameters in the codebase along with some commentary on best practices. We also list the default values from the codebase.
4 | 
5 | ## Standard Training Hyperparameters
6 | 
7 | | Argument | Default | Notes |
8 | |-|-|-|
9 | | --optimizer | 'adamw' | |
10 | | --batch_size | 16 | |
11 | | --num_train_steps | 60000 | |
12 | | --gradient_accumulation_steps | 1 | |
13 | | --learning_rate | 1e-4 | |
14 | | --clip_grad_norm | 1.0 | |
15 | | --lr_schedule | 'cosine' | |
16 | | --lr_warmup_steps | 1000 | |
17 | | --adam_beta1 | 0.9 | |
18 | | --adam_beta2 | 0.999 | |
19 | | --adam_weight_decay | 0 | |
20 | | --dropout | 0.0 | |
21 | | --ema_decay | 0.9999 | |
22 | | --mixed_precision | 'no' | Accelerate argument |
23 | 
24 | ### Commentary
25 | 
26 | The training setup is fairly standard. We use the AdamW optimizer with a cosine learning rate schedule and a short linear warmup. Diffusion models are quite robust to overfitting, so we don't employ any explicit regularization (weight decay or dropout). Regularization may be helpful if training the model for much longer. We clip gradients to a norm of 1.0, which improves training stability.
27 | 
28 | For this model, longer training times are generally better. Our released checkpoint was trained for 200k steps with a global batch size of 256 (per-device batch size of 16, 2 gradient accumulation steps, and 8 GPUs). It is likely significantly under-trained and further training would be beneficial. Large batch sizes tend to help diffusion models because of the stochasticity of the training objective, so distributed training and/or gradient accumulation are recommended.
29 | 
30 | We trained our models with `bf16` mixed precision.
31 | 
32 | ## Architecture Hyperparameters
33 | 
34 | | Argument | Default |
35 | |-|-|
36 | | --dim | 512 |
37 | | --dim_mults | (1, 1, 1, 1.5) |
38 | | --conformer_transformer | False |
39 | | --scale_skip_connection | False |
40 | | --num_transformer_layers | 3 |
41 | 
42 | ### Commentary
43 | The diffusion model consists of a U-Net and a transformer. The first half of the U-Net downsamples the input to produce low-resolution features that are processed by the transformer. The output of the transformer is then upsampled by the second half of the U-Net to the original resolution to generate the final prediction.
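To make this structure concrete, the sketch below mirrors how `models/unet.py` derives the per-level feature widths and the overall downsampling factor from the `dim` and `dim_mults` arguments described below. It uses the released-model values discussed in this guide and is purely illustrative; it is not part of the training entry point.

```python
# Minimal sketch: per-level U-Net widths implied by `dim` and `dim_mults`.
# Mirrors `dims = [init_dim, *map(lambda m: int(dim * m), dim_mults)]` in models/unet.py.
dim = 512                      # --dim for the released model
dim_mults = (1, 1, 1, 1.5)     # --dim_mults for the released model

dims = [dim, *(int(dim * m) for m in dim_mults)]
print(dims)                    # [512, 512, 512, 512, 768]

# The latent is downsampled once per resolution transition, so its length must be
# divisible by this factor (see the assertion in Unet1D.forward).
downsample_factor = 2 ** (len(dim_mults) - 1)
print(downsample_factor)       # 8

# The transformer (with the cross-attention layers) operates at the final width.
transformer_dim = dims[-1]
print(transformer_dim)         # 768
```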
44 | 
45 | The structure of the U-Net model is determined by the `dim` and `dim_mults` arguments. The `dim` argument controls the initial dimensionality of the model. The `dim_mults` argument controls the number of downsampling layers in the U-Net and the feature dimensionality at each level. The feature dimensionality is defined as a multiple of the original `dim` value.
46 | 
47 | Therefore, our U-Net model has four resolution levels, and the dimensionality in the middle of the network is `768 = 512 * 1.5`. The final dimensionality of the U-Net model is also the dimensionality of the transformer model. The `num_transformer_layers` argument and the final dimensionality of the U-Net control the size of the transformer. The transformer model contains the cross-attention layers and is therefore primarily responsible for the text-audio alignment. Our released model has a 768d transformer with 12 layers.
48 | 
49 | To scale up the model, past work on image generation has shown that it's sufficient to scale the middle of the network, leaving the downsampling/upsampling layers unchanged. Therefore, scaling up the transformer dimensionality and depth is likely the most effective way to scale up the network.
50 | 
51 | For diffusion models, it's important for the model dimensionality to be meaningfully larger than the dimensionality of the input data. Given that the dimensionality of the EnCodec features is `128`, using `dim=512` for the initial dimensionality is a reasonable choice. We would be cautious about decreasing this initial dimensionality.
52 | 
53 | It's been shown for text-to-image diffusion models that scaling the U-Net skip connections by a constant factor significantly accelerates convergence. We used this trick (controlled by the `--scale_skip_connection` flag) when training our model, but did not investigate its impact in detail.
54 | 
55 | We also included an option to introduce a conformer-style convolution layer into the transformer (the `--conformer_transformer` flag), but did not end up using it in our primary model. Its use didn't seem to make a significant difference in our preliminary investigation, but we didn't explore it in great detail. The additional feedforward block does meaningfully increase the size of the transformer, so a fair comparison would need to control for that.
56 | 
57 | ## Speaker-Prompted Generation Arguments
58 | 
59 | | Argument | Default |
60 | |-|-|
61 | | --inpainting_embedding | False |
62 | | --inpainting_prob | 0.5 |
63 | 
64 | We train our model for both zero-shot TTS (i.e. generating speech given only the transcript) and speaker-prompted TTS (i.e. generating speech in the style of some speaker) in a multi-task manner. We train the model for speaker-prompted TTS by only adding noise to the latter portion of the audio latent, providing the model with a clean speech prompt. The `--inpainting_prob` flag controls the portion of instances used in the speaker-prompted setting. The `--inpainting_embedding` flag introduces binary embeddings that are added to the input to mark the prompt speech frames. We recommend enabling this flag.
65 | 
66 | ## Diffusion Hyperparameters
67 | 
68 | | Argument | Default |
69 | |-|-|
70 | | --objective | 'pred_v' |
71 | | --parameterization | 'pred_v' |
72 | | --loss_type | 'l1' |
73 | | --scale | 1.0 |
74 | | --unconditional_prob | 0.1 |
75 | 
76 | Diffusion models consist of a denoising network that accepts some noisy data as input and attempts to recover the original data. In practice, this network can be parameterized in a variety of different ways.
We parameterize our denoising network as a velocity prediction (or v-prediction) network. See [[1]](https://openreview.net/forum?id=TIdIXIpzhoI) for a discussion of the various parameterizations. The v-parameterization has been widely adopted (e.g. [[2]](https://huggingface.co/stabilityai/stable-diffusion-2-1)[[3]](https://arxiv.org/abs/2301.11093)[[4]](https://arxiv.org/abs/2204.03458)) and is therefore a reasonable choice. One nuance is that you can treat the output of the network (`--parameterization`) and the loss function (`--objective`) as two separate design decisions (also discussed in [[1]](https://openreview.net/forum?id=TIdIXIpzhoI)). However, setting both to `pred_v` is a reasonable default choice. (A short numerical sketch of these quantities appears at the end of this guide.)
77 | 
78 | The noise schedule is one of the most critical hyperparameters for the quality of the generations. We use the widespread cosine noise schedule and find adjusting the scale factor of the noise schedule to be important for ensuring text-speech alignment. See [[5]](https://arxiv.org/abs/2301.10972) for a detailed discussion of this choice. We set the `--scale` flag to `0.5` for our final checkpoint.
79 | 
80 | Diffusion models are trained with a regression loss. Both the `l1` and `l2` loss are commonly used in the literature. There's some work [[6]](https://arxiv.org/abs/2111.05826) suggesting that the `l1` loss leads to more conservative generations while the `l2` loss leads to more diverse generations, potentially at the cost of quality. We utilized the `l1` loss to emphasize quality over diversity, but the `l2` loss may be advisable depending on the application.
81 | 
82 | To enable the use of classifier-free guidance [[7]](https://arxiv.org/abs/2207.12598), we drop the conditioning information with some probability. The probability is controlled by the `--unconditional_prob` flag. It should generally be in the `0.1-0.2` range, and we use `0.1` by default.
83 | 
84 | ## Validation Parameters
85 | 
86 | | Argument | Default |
87 | |-|-|
88 | | --save_and_sample_every | 5000 |
89 | | --sampling_timesteps | 250 |
90 | | --num_samples | None |
91 | | --sampler | 'ddim' |
92 | 
93 | We generate samples from the validation set every `--save_and_sample_every` steps. The `--num_samples` flag controls how many validation samples to generate and the `--sampling_timesteps` flag controls the number of timesteps used for generation. The `--num_samples` flag should be set reasonably low (e.g. 128) to avoid spending too much time on validation. We discuss sampling in more detail in the main README.
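To make the noise-schedule and parameterization discussion above more concrete, here is a small, self-contained numerical sketch of a scaled cosine schedule and the `pred_v` target. The formulas follow [[1]](https://openreview.net/forum?id=TIdIXIpzhoI) and [[5]](https://arxiv.org/abs/2301.10972); the exact conventions in `diffusion/noise_schedule.py` may differ (in particular, how the `--scale` factor is applied below is an assumption), so treat this as an illustration rather than a reference implementation.

```python
import math
import torch

def cosine_alpha_sigma(t):
    """Continuous-time cosine schedule: alpha_t = cos(pi*t/2), sigma_t = sin(pi*t/2)."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def scaled_alpha_sigma(t, scale=0.5):
    """One common way to apply a scale factor: scale the clean data before noising and
    renormalize so the noisy latent keeps unit variance. This shifts the log-SNR of the
    schedule by 2*log(scale), so a smaller scale means noisier training examples."""
    alpha, sigma = cosine_alpha_sigma(t)
    alpha = alpha * scale
    norm = torch.sqrt(alpha ** 2 + sigma ** 2)
    return alpha / norm, sigma / norm

# One training step's worth of targets, using the pred_v parameterization.
torch.manual_seed(0)
x0 = torch.randn(4, 128, 256)      # clean latents (e.g. EnCodec features)
t = torch.rand(4).view(-1, 1, 1)   # one diffusion time per example
eps = torch.randn_like(x0)

alpha, sigma = scaled_alpha_sigma(t, scale=0.5)
z_t = alpha * x0 + sigma * eps     # noisy latents fed to the network
v_target = alpha * eps - sigma * x0  # pred_v regression target

# Given a network output v_hat, x0 and eps can be recovered as:
v_hat = v_target                   # pretend the network is perfect
x0_hat = alpha * z_t - sigma * v_hat
eps_hat = sigma * z_t + alpha * v_hat
print(torch.allclose(x0_hat, x0, atol=1e-5), torch.allclose(eps_hat, eps, atol=1e-5))
```

With `--loss_type l1`, the training loss would then be the mean absolute error between the network output and `v_target`.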
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | import os
4 | from pathlib import Path
5 | import argparse
6 | 
7 | def compute_grad_norm(parameters):
8 |     # implementation adapted from https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
9 |     parameters = [p for p in parameters if p.grad is not None]
10 |     total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), p=2) for p in parameters]), p=2).item()
11 |     return total_norm
12 | 
13 | # TODO: update
14 | def get_output_dir(args):
15 |     model_dir = f'{args.dataset_name}/{args.run_name}/'
16 |     output_dir = os.path.join(args.save_dir, model_dir)
17 |     return output_dir
18 | 
19 | def parse_float_tuple(float_tuple_str):
20 |     try:
21 |         float_tuple = tuple(map(float, float_tuple_str.split(',')))
22 |         return float_tuple
23 |     except ValueError:
24 |         raise argparse.ArgumentTypeError('Expected a comma-separated list of floats (e.g. "1,1,1,1.5")')
25 | 
--------------------------------------------------------------------------------
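As a quick illustration, `parse_float_tuple` above is the `type=` converter used for tuple-valued command-line arguments such as `--guidance`. The snippet below is a hypothetical usage example (assuming it is run from the repository root); the sweep loop at the end is illustrative only and not part of the codebase.

```python
import argparse
from utils.utils import parse_float_tuple

parser = argparse.ArgumentParser()
parser.add_argument('--guidance', type=parse_float_tuple,
                    help='Tuple of float values for classifier-free guidance scales')
args = parser.parse_args(['--guidance', '1.0,2.0,3.0'])
print(args.guidance)  # (1.0, 2.0, 3.0)

# Hypothetical sweep over the parsed guidance scales:
for w in args.guidance:
    print(f'sampling with classifier-free guidance scale {w}')
```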