├── .gitignore
├── LICENSE
├── README.md
├── audio_datasets
│   ├── __init__.py
│   ├── constants.py
│   ├── helpers.py
│   ├── librispeech.py
│   └── text_grid_utils.py
├── compute_wer.py
├── data
│   └── aligned_librispeech.tar.gz
├── diffusion
│   ├── audio_denoising_diffusion.py
│   ├── noise_schedule.py
│   ├── optimizer.py
│   └── time_sampler.py
├── evaluation
│   └── evaluate_transcript.py
├── img
│   ├── esd.png
│   └── results.png
├── models
│   ├── modules
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── conv.py
│   │   ├── norm.py
│   │   ├── position.py
│   │   └── transformer.py
│   ├── transformer_wrapper.py
│   └── unet.py
├── neural_codec
│   └── encodec_wrapper.py
├── requirements.txt
├── scripts
│   ├── sample
│   │   └── sample_16_ls_testclean.sh
│   └── train
│       └── train.sh
├── train_audio_diffusion.py
└── utils
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Justin Lovelace
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sample-Efficient Diffusion for Text-To-Speech Synthesis
2 |
3 | This repository contains the implementation for SESD, a latent diffusion model for text-to-speech generation. We introduced this model in the following work:
4 |
5 | **Sample-Efficient Diffusion for Text-To-Speech Synthesis**
6 | Justin Lovelace, Soham Ray, Kwangyoun Kim, Kilian Q. Weinberger, Felix Wu
7 | Interspeech 2024
8 | [\[paper\]](https://www.arxiv.org/abs/2409.03717)
9 |
10 | ## Abstract
11 |
12 | This work introduces Sample-Efficient Speech Diffusion (SESD), an algorithm for effective speech synthesis in modest data regimes through latent diffusion. It is based on a novel diffusion architecture, that we call U-Audio Transformer (U-AT), that efficiently scales to long sequences and operates in the latent space of a pre-trained audio autoencoder. Conditioned on character-aware language model representations, SESD achieves impressive results despite training on less than 1k hours of speech – far less than current state-of-the-art systems. In fact, it synthesizes more intelligible speech than the state-of-the-art auto-regressive model, VALL-E, while using less than 2% the training data.
13 |
14 | ![SESD architecture](img/esd.png)
15 |
16 | Overview of our Sample-Efficient Speech Diffusion (SESD) architecture.
17 |
18 | ![Results](img/results.png)
19 |
20 |
21 |
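The components the abstract refers to map directly onto modules in this repository: a frozen [ByT5](https://huggingface.co/google/byt5-small) encoder provides the character-aware text representations, and `neural_codec/encodec_wrapper.py` wraps the pre-trained EnCodec autoencoder whose latent space the diffusion model operates in. The snippet below is only a rough sketch of that conditioning path (it is not a training or inference entry point):

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

from neural_codec.encodec_wrapper import EncodecWrapper

# Frozen character-aware text encoder (the same checkpoint used by the diffusion model).
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
text_encoder = T5ForConditionalGeneration.from_pretrained("google/byt5-small").get_encoder()

# Pre-trained neural audio codec that defines the latent space for diffusion.
codec = EncodecWrapper()

tokens = tokenizer(["hello world"], padding="max_length", truncation=True,
                   max_length=256, return_tensors="pt")
text_cond = text_encoder(tokens["input_ids"], tokens["attention_mask"]).last_hidden_state

with torch.no_grad():
    latents = codec.encode(torch.randn(1, 1, 24000))  # one second of (dummy) 24 kHz audio

print(text_cond.shape, latents.shape)  # text conditioning vs. audio latent frames
```
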
22 | If you find this work useful, please consider citing:
23 |
24 | ```bibtex
25 | @inproceedings{lovelace2024sample,
26 | title={Sample-Efficient Diffusion for Text-To-Speech Synthesis},
27 | author={Lovelace, Justin and Ray, Soham and Kim, Kwangyoun and Weinberger, Kilian Q and Wu, Felix},
28 | booktitle={Proc. Interspeech 2024},
29 | pages={4403--4407},
30 | year={2024}
31 | }
32 | ```
33 |
34 | ## Installation
35 |
36 | Install the required dependencies:
37 | ```bash
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ## Datasets
42 |
43 | We train and evaluate our models with the LibriSpeech dataset from the Hugging Face Hub, using the standard test-clean split for evaluation.
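
The `LibriSpeech` wrapper in `audio_datasets/librispeech.py` takes care of resampling to 24 kHz, lowercasing the transcripts, length filtering, and padding. A minimal sketch of how it is instantiated (mirroring the `Trainer`; note that `load_dataset` still downloads LibriSpeech from the Hub on first use, and `debug=True` trims the split to 100 utterances for quick inspection):

```python
from transformers import AutoTokenizer

from audio_datasets import LibriSpeech

tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
train_ds = LibriSpeech(split="train", tokenizer=tokenizer, debug=True)

item = train_ds[0]
# Padded waveform [1, L], transcript, and a boolean mask over EnCodec latent frames.
print(item["wav"].shape, item["text"], item["audio_mask"].shape)
```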
44 |
45 | ### Speaker Prompt Data
46 |
47 | For speaker-prompted generation, we condition on a three-second speech prompt drawn from another utterance by the same speaker. To extract the transcript corresponding to the prompt audio, we used the [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/first_steps/example.html#example-1-aligning-librispeech-english). An aligned version of LibriSpeech is provided at:
48 |
49 | ```bash
50 | data/aligned_librispeech.tar.gz
51 | ```
52 |
53 | After extracting the archive, update the `ALIGNED_DATA_DIR` path in `audio_datasets/constants.py` to point to your data directory.
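
To sanity-check the alignments, `audio_datasets/text_grid_utils.py` exposes the helper that pulls the last ~3 seconds of a prompt utterance together with its transcript (the TextGrid path below is just one example file from the archive):

```python
from audio_datasets.text_grid_utils import get_partial_transcript

tg_path = "data/aligned_librispeech/dev-clean/84/84-121123-0002.TextGrid"  # example file
prompt = get_partial_transcript(tg_path, transcript_end_time=3)
print(prompt["transcript"], prompt["start_time"], prompt["end_time"])
```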
54 |
55 | ## Training
56 |
57 | Our training setup:
58 | - Single Nvidia A6000 GPU
59 | - BF16 mixed precision training
60 | - Batch size and other parameters may need adjustment based on your hardware
61 |
62 | To train the diffusion model:
63 | ```bash
64 | ./scripts/train/train.sh
65 | ```
66 |
67 | ## Model Checkpoint
68 |
69 | Model checkpoint will be released soon!
70 |
71 | ## Inference
72 |
73 | To synthesize speech for the LibriSpeech test-clean set:
74 | ```bash
75 | ./scripts/sample/sample_16_ls_testclean.sh
76 | ```
77 |
78 | Note: Update the `--resume_dir` argument with the path to your trained model.
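
To score the intelligibility of the synthesized speech, `compute_wer.py` computes the word error rate against the test-clean transcripts. A sketch of the same flow from Python (the sample directory is simply the script's default and the guidance value is a placeholder; the corresponding `guide*/audio_*.wav` files must already exist):

```python
from audio_datasets import LibriSpeech
from compute_wer import get_wer

test_ds = LibriSpeech(split="test")
text_list = test_ds.hf_dataset["text"]

sample_dir = "saved_models/librispeech/test/librispeech_250k/samples/step_100000"
wer_dict = get_wer(sample_dir, text_list, num_samples=1237, guidances=(3.0,))  # placeholder guidance
print(wer_dict)
```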
79 |
80 | ## Questions and Support
81 |
82 | Feel free to create an issue if you have any questions.
83 |
84 | ## Acknowledgements
85 |
86 | This work builds upon excellent open-source implementations from [Phil Wang (Lucidrains)](https://github.com/lucidrains). In particular, we built on his PyTorch [DDPM implementation](https://github.com/lucidrains/denoising-diffusion-pytorch).
--------------------------------------------------------------------------------
/audio_datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from audio_datasets.librispeech import LibriSpeech
2 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, LATENT_SAMPLING_RATE
3 |
--------------------------------------------------------------------------------
/audio_datasets/constants.py:
--------------------------------------------------------------------------------
1 | MAX_DURATION_IN_SECONDS = 20
2 | ENCODEC_SAMPLING_RATE = 24000
3 | LIBRISPEECH_SAMPLING_RATE = 16000
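# EnCodec downsamples 24 kHz audio by a factor of 320, i.e. 24000 / 320 = 75 latent frames per second.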
4 | ENCODEC_REDUCTION_FACTOR = 320
5 | LATENT_SAMPLING_RATE = 75
6 | ALIGNED_DATA_DIR = '/home/jl3353/stable-audio-diffusion/data/aligned_librispeech/'
--------------------------------------------------------------------------------
/audio_datasets/helpers.py:
--------------------------------------------------------------------------------
1 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, MAX_DURATION_IN_SECONDS
2 |
3 |
4 | def round_up_to_multiple(number, multiple):
5 | remainder = number % multiple
6 | if remainder == 0:
7 | return number
8 | else:
9 | return number + (multiple - remainder)
10 |
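# Waveform lengths are rounded up so that the resulting EnCodec latent length (waveform length / 320)
# is a multiple of `multiple` (16 by default), presumably so the latent sequence divides evenly under
# the downsampling path of the diffusion backbone.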
11 | def round_up_to_waveform_multiple(number, multiple=16):
12 | waveform_multiple = multiple*ENCODEC_REDUCTION_FACTOR
13 | rounded_number = round_up_to_multiple(number, waveform_multiple)
14 | return rounded_number
15 |
16 | def compute_max_length(multiple=16):
17 | max_len = MAX_DURATION_IN_SECONDS*ENCODEC_SAMPLING_RATE
18 |
19 | max_len = round_up_to_waveform_multiple(max_len, multiple)
20 | return max_len
21 |
22 | def is_audio_length_in_range(audio, sampling_rate=ENCODEC_SAMPLING_RATE):
23 | return len(audio['array']) <= (MAX_DURATION_IN_SECONDS*sampling_rate)
24 |
25 | def is_audio_length_in_test_range(audio):
26 | return ((4*ENCODEC_SAMPLING_RATE) <= len(audio['array'])) and (len(audio['array']) <= (10*ENCODEC_SAMPLING_RATE))
27 |
--------------------------------------------------------------------------------
/audio_datasets/librispeech.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from datasets import load_dataset, Audio, concatenate_datasets
3 | import numpy as np
4 | import torch
5 | from tqdm import tqdm
6 | import torch.nn.functional as F
7 | import os
8 | import random
9 | import math
10 | import json
11 |
12 | from audio_datasets.text_grid_utils import get_partial_transcript
13 | from audio_datasets.helpers import compute_max_length, round_up_to_waveform_multiple, is_audio_length_in_range, is_audio_length_in_test_range
14 | from audio_datasets.constants import ENCODEC_REDUCTION_FACTOR, ENCODEC_SAMPLING_RATE, MAX_DURATION_IN_SECONDS, ALIGNED_DATA_DIR, LATENT_SAMPLING_RATE
15 |
16 | def lowercase_text(example):
17 | example["text"] = example["text"].lower()
18 | return example
19 |
20 | class LibriSpeech(Dataset):
21 | """
22 | Wrapper around HuggingFace dataset for processing.
23 | """
24 | def __init__(self, split='train', debug=False, tokenizer=None, max_seq_len=None, sampling_rate=None, duration_path=None):
25 | super().__init__()
26 | self.sr = ENCODEC_SAMPLING_RATE if sampling_rate is None else sampling_rate
27 | self.split = split
28 | self.split2dir = {'valid': 'dev-clean', 'test': 'test-clean'}
29 | if split == 'train':
30 | train100 = load_dataset('librispeech_asr', 'clean', split='train.100',)
31 | train360 = load_dataset('librispeech_asr', 'clean', split='train.360', )
32 | train500 = load_dataset('librispeech_asr', 'other', split='train.500', )
33 |
34 | self.hf_dataset = concatenate_datasets([train100, train360, train500])
35 | elif split == 'valid':
36 | self.hf_dataset = load_dataset('librispeech_asr', 'clean', split='validation')
37 | elif split == 'test':
38 | self.hf_dataset = load_dataset('librispeech_asr', 'clean', split='test')
39 | else:
40 |             raise ValueError(f"invalid split: {split}, must be in ['train', 'valid', 'test']")
41 |
42 | # Downsample to accelerate processing for debugging purposes
43 | if debug:
44 | self.hf_dataset = self.hf_dataset.select(range(100))
45 | # Resample to 24kHz for Encodec
46 | self.hf_dataset = self.hf_dataset.cast_column("audio", Audio(sampling_rate=self.sr))
47 |
48 | self.hf_dataset = self.hf_dataset.map(lowercase_text)
49 | if split == 'train':
50 | self.hf_dataset = self.hf_dataset.filter(is_audio_length_in_range, input_columns=['audio'])
51 | elif split == 'test':
52 | self.hf_dataset = self.hf_dataset.filter(is_audio_length_in_test_range, input_columns=['audio'])
53 |
54 | if self.split in {'valid', 'test'}:
55 | unique_speaker_ids = set(self.hf_dataset['speaker_id'])
56 | self.speaker_datasets = {speaker_id:self.hf_dataset.filter(lambda example: example["speaker_id"] == speaker_id) for speaker_id in unique_speaker_ids}
57 |
58 |
59 | self.max_seq_len = max_seq_len if max_seq_len is not None else compute_max_length()
60 | print(f'Max seq length: {self.max_seq_len/ENCODEC_REDUCTION_FACTOR}')
61 |
62 | if tokenizer is not None:
63 | self.hf_dataset = self.hf_dataset.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True, max_length=256), batched=True)
64 | self.tokenizer = tokenizer
65 |
66 | self.duration_dict = None
67 | if split == 'test' and duration_path is not None:
68 | # Read duration path json
69 | with open(duration_path, 'rt') as f:
70 | self.duration_dict = json.load(f)
71 | assert len(self.duration_dict['nucleus_pred']) == len(self.hf_dataset), f'Length of duration dict {len(self.duration_dict)} does not match length of dataset {len(self.hf_dataset)}'
72 | self.nucleus_pred_duration = [float(d) for d in self.duration_dict['nucleus_pred']]
73 |
74 | def __getitem__(self, index):
75 | example = self.hf_dataset[index]
76 | text = example['text']
77 | wav = example['audio']['array'][:self.max_seq_len]
78 | wavpath = example['audio']['path']
79 | npad = self.max_seq_len - len(wav)
80 | assert npad>=0, f'Waveform length {len(wav)} needs to be less than {self.max_seq_len}'
81 | wav_len = len(wav)
82 | # [1, L]: Channels x length
83 | audio_duration_sec = wav_len/ENCODEC_SAMPLING_RATE
84 | wav = torch.tensor(np.pad(wav, pad_width=(0, npad), mode='constant'), dtype=torch.float).unsqueeze(0)
85 |
86 | audio_mask = torch.zeros((self.max_seq_len//ENCODEC_REDUCTION_FACTOR,), dtype=torch.bool)
87 | if self.duration_dict is None:
88 | num_unmasked_tokens = round_up_to_waveform_multiple(wav_len)//ENCODEC_REDUCTION_FACTOR
89 | audio_mask[:num_unmasked_tokens] = True
90 | else:
91 | audio_duration_sec = self.nucleus_pred_duration[index]
92 | wav_len = int(audio_duration_sec*ENCODEC_SAMPLING_RATE)
93 | num_unmasked_tokens = round_up_to_waveform_multiple(wav_len)//ENCODEC_REDUCTION_FACTOR
94 | audio_mask[:num_unmasked_tokens] = True
95 |
96 | data = {'wav': wav, 'text': text, 'audio_duration': audio_duration_sec, 'path':wavpath}
97 | data['audio_mask'] = audio_mask
98 |
99 |
100 | # Speaker prompting
101 | if self.split in {'valid', 'test'}:
102 | split_dir = self.split2dir[self.split]
103 | speaker_id = example['speaker_id']
104 | speaker_ds = self.speaker_datasets[speaker_id]
105 | # Sample idx for n-1 elements and remap matching element to the last element
106 | speaker_idx = random.randint(0, len(speaker_ds)-2)
107 | if speaker_ds[speaker_idx]['id'] == example['id']:
108 | speaker_idx = len(speaker_ds)-1
109 | speaker_example = speaker_ds[speaker_idx]
110 | textgrid_path = os.path.join(ALIGNED_DATA_DIR, split_dir, f'{speaker_id}', f'{speaker_example["id"]}.TextGrid')
111 | partial_transcript = get_partial_transcript(textgrid_path)
112 |
113 | speaker_text = partial_transcript['transcript']
114 | speaker_start_time = partial_transcript['start_time']
115 | speaker_end_time = partial_transcript['end_time']
116 |             # Convert start/end times in seconds to waveform sample indices at the EnCodec sampling rate
117 |
118 | speaker_start_frame = math.floor(speaker_start_time * ENCODEC_SAMPLING_RATE)
119 | speaker_end_frame = math.ceil(speaker_end_time * ENCODEC_SAMPLING_RATE)
120 | speaker_wav_frames = speaker_end_frame - speaker_start_frame
121 | speaker_audio_duration_sec = speaker_wav_frames/ENCODEC_SAMPLING_RATE
122 | speaker_wav = speaker_example['audio']['array'][speaker_start_frame:speaker_end_frame]
123 | speaker_npad = self.max_seq_len - len(speaker_wav)
124 | assert speaker_npad>=0, f'Waveform length {len(speaker_wav)} needs to be less than {self.max_seq_len}'
125 | # [1, L]: Channels x length
126 |
127 | speaker_wav = torch.tensor(np.pad(speaker_wav, pad_width=(0, speaker_npad), mode='constant'), dtype=torch.float).unsqueeze(0)
128 |
129 | speaker_data = {'speaker_wav': speaker_wav, 'speaker_text': speaker_text, 'speaker_audio_duration': speaker_audio_duration_sec}
130 | data.update(speaker_data)
131 |
132 | inpaint_audio_mask = torch.zeros((self.max_seq_len//ENCODEC_REDUCTION_FACTOR,), dtype=torch.bool)
133 | num_unmasked_tokens = round_up_to_waveform_multiple(speaker_wav_frames+wav_len)//ENCODEC_REDUCTION_FACTOR
134 | inpaint_audio_mask[:num_unmasked_tokens] = True
135 | data['inpaint_audio_mask'] = inpaint_audio_mask
136 |
137 | if self.tokenizer is not None:
138 | data['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
139 | data['attention_mask'] = torch.tensor(example['attention_mask'], dtype=torch.long)
140 |
141 | return data
142 |
143 | def __len__(self):
144 | return len(self.hf_dataset)
145 |
146 | if __name__ == "__main__":
147 |     dataset = LibriSpeech(split='test')
148 |
149 |     example = dataset.__getitem__(0)
150 |     import soundfile as sf
151 |     os.makedirs('example_audio', exist_ok=True)
152 |     sf.write('example_audio/librispeech_sample.wav', example['wav'].squeeze(0).numpy(), ENCODEC_SAMPLING_RATE)
153 |     with open('example_audio/librispeech_text.txt', 'w') as f:
154 |         print(example['text'], file=f)
155 |
--------------------------------------------------------------------------------
/audio_datasets/text_grid_utils.py:
--------------------------------------------------------------------------------
1 | from praatio import textgrid
2 | import os
3 |
4 | data_path = '../data/aligned_librispeech/dev-clean/84/84-121123-0002.TextGrid'
5 |
6 | # tg = textgrid.openTextgrid(data_path, False)
7 | def get_word_intervals(textgrid_path):
8 | tg = textgrid.openTextgrid(textgrid_path, False)
9 | return tg.getTier("words").entries
10 |
11 | def get_partial_transcript(textgrid_path, transcript_end_time=3):
12 | intervals = get_word_intervals(textgrid_path)
13 | word_list = []
14 | start_time = 0
15 | end_time = intervals[-1].end
16 | # Reverse intervals to get the last word first
17 | intervals = intervals[::-1]
18 | for interval in intervals:
19 | if (end_time - interval.start) > transcript_end_time:
20 | break
21 | word_list.append(interval.label)
22 | start_time = interval.start
23 | # Reverse word_list to get the words in the correct order
24 | word_list = word_list[::-1]
25 | return {'transcript': ' '.join(word_list),
26 | 'end_time': end_time,
27 | 'start_time': start_time,}
28 |
29 |
30 | if __name__ == "__main__":
31 | partial_transcript = get_partial_transcript(data_path)
32 | print(partial_transcript)
--------------------------------------------------------------------------------
/compute_wer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from audio_datasets import LibriSpeech
5 | from evaluation.evaluate_transcript import compute_wer
6 | from utils.utils import parse_float_tuple
7 |
8 | def get_wer(sample_dir, text_list, num_samples, guidances, prefix_seconds=0.0):
9 | wer_dict = {}
10 | for guidance in guidances:
11 | filepaths_list = []
12 | for i in range(num_samples):
13 | if prefix_seconds > 0:
14 | filepaths_list.append(os.path.join(sample_dir, f'guide{guidance:.1f}_prefix{prefix_seconds:.1f}', f'audio_{i}.wav'))
15 | else:
16 | filepaths_list.append(os.path.join(sample_dir, f'guide{guidance:.1f}', f'audio_{i}.wav'))
17 |
18 | text_list = text_list[:num_samples]
19 | wer = compute_wer(filepaths_list, text_list)
20 | wer_dict[guidance] = wer
21 | print(f'WER for guidance {guidance}: {wer*100:.1f}')
22 | return wer_dict
23 |
24 | def main(args):
25 | test_ls_dataset = LibriSpeech(split='test')
26 | text_list = test_ls_dataset.hf_dataset['text']
27 |
28 | wer_dict = get_wer(args.sample_dir, text_list, args.num_samples, args.guidance, args.prefix_seconds)
29 |
30 | # Save wer_dict to a json file
31 | with open(os.path.join(args.sample_dir, 'wer.json'), 'w') as f:
32 | json.dump(wer_dict, f)
33 | # Print wer_dict
34 | print(wer_dict)
35 |
36 | if __name__ == "__main__":
37 |     parser = argparse.ArgumentParser(description="WER evaluation arguments")
38 | parser.add_argument("--sample_dir", type=str, default='saved_models/librispeech/test/librispeech_250k/samples/step_100000')
39 | parser.add_argument("--num_samples", type=int, default=1237)
40 |     parser.add_argument('--guidance', type=parse_float_tuple, help='Tuple of classifier-free guidance values to evaluate')
41 | parser.add_argument('--prefix_seconds', type=float, default=0.0)
42 |
43 | args = parser.parse_args()
44 | main(args)
--------------------------------------------------------------------------------
/data/aligned_librispeech.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinlovelace/SESD/c59b41d7ec98922bdb6bce4217ebcf6cd1c0d2b2/data/aligned_librispeech.tar.gz
--------------------------------------------------------------------------------
/diffusion/audio_denoising_diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | from pathlib import Path
3 | import random
4 | from functools import partial
5 | from collections import namedtuple, Counter
6 | from multiprocessing import cpu_count
7 | import os
8 | import numpy as np
9 | import csv
10 | import timeit
11 | import json
12 | import argparse
13 | from collections import defaultdict
14 | from contextlib import nullcontext
15 | import soundfile as sf
16 |
17 |
18 | import torch
19 | from torch import nn, einsum
20 | import torch.nn.functional as F
21 | from torch.utils.data import Dataset, DataLoader
22 |
23 | from einops import rearrange, reduce, repeat
24 |
25 | from tqdm.auto import tqdm
26 | from ema_pytorch import EMA
27 |
28 | from transformers import get_scheduler, AutoTokenizer, T5ForConditionalGeneration
29 |
30 | from accelerate import Accelerator, DistributedDataParallelKwargs
31 |
32 | from diffusion.optimizer import get_adamw_optimizer, get_lion_optimizer
33 | from utils.utils import compute_grad_norm
34 | import utils.utils as file_utils
35 | from diffusion.noise_schedule import *
36 | from diffusion.time_sampler import LossEMASampler
37 | from audio_datasets import LibriSpeech, ENCODEC_SAMPLING_RATE, LATENT_SAMPLING_RATE, ENCODEC_REDUCTION_FACTOR
38 | from neural_codec.encodec_wrapper import EncodecWrapper
39 | from utils.utils import get_output_dir
40 |
41 |
42 | ModelPrediction = namedtuple('ModelPrediction', ['pred_eps', 'pred_x_start', 'pred_v'])
43 |
44 | # Recommendation from https://arxiv.org/abs/2303.09556
45 | MIN_SNR_GAMMA = 5
46 |
47 | # helpers functions
48 |
49 | def exists(x):
50 | return x is not None
51 |
52 | def default(val, d):
53 | if exists(val):
54 | return val
55 | return d() if callable(d) else d
56 |
57 | def identity(t, *args, **kwargs):
58 | return t
59 |
60 | def cycle(dl):
61 | while True:
62 | for data in dl:
63 | yield data
64 |
65 | def num_to_groups(num, divisor):
66 | groups = num // divisor
67 | remainder = num % divisor
68 | arr = [divisor] * groups
69 | if remainder > 0:
70 | arr.append(remainder)
71 | return arr
72 |
73 | def l2norm(t):
74 | return F.normalize(t, dim = -1)
75 |
76 | # Avoid log(0)
77 | def log(t, eps = 1e-20):
78 | return torch.log(t.clamp(min = eps))
79 |
80 | def right_pad_dims_to(x, t):
81 | padding_dims = x.ndim - t.ndim
82 | if padding_dims <= 0:
83 | return t
84 | return t.view(*t.shape, *((1,) * padding_dims))
85 |
86 | # gaussian diffusion trainer class
87 |
88 | def extract(a, t, x_shape):
89 | b, *_ = t.shape
90 | out = a.gather(-1, t)
91 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
92 |
93 | def set_seeds(seed):
94 | random.seed(seed)
95 | torch.manual_seed(seed)
96 | torch.cuda.manual_seed(seed)
97 |
98 | def masked_mean(t, *, dim, mask = None):
99 | if not exists(mask):
100 | return t.mean(dim = dim)
101 |
102 | denom = mask.sum(dim = dim)
103 | masked_t = t.masked_fill(~mask, 0.)
104 |
105 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
106 |
107 |
108 | class GaussianDiffusion(nn.Module):
109 | def __init__(
110 | self,
111 | model,
112 | *,
113 | max_seq_len,
114 | sampling_timesteps = 250,
115 | text_encoder = 'google/byt5-small',
116 | loss_type = 'l1',
117 | objective = 'pred_v',
118 | parameterization = 'pred_v',
119 | train_schedule = 'cosine',
120 | ema_decay = 0.9999,
121 | sampling_schedule = None,
122 | scale = 1.,
123 | scale_by_std=True,
124 | sampler = 'ddim',
125 | log_var_interp = 0.2,
126 | langevin_step_size = 0.0,
127 | snr_ddim_threshold= None,
128 | unconditional_prob = 0.2,
129 | inpainting_prob = 0.5,
130 | inpainting_duration_mode = 0.01,
131 | inpainting_duration_concentration = 5,
132 | loss_weighting = 'edm',
133 | loss_weighting_args = {},
134 | ):
135 | super().__init__()
136 |
137 | self.denoising_network = EMA(model, beta = ema_decay, update_every = 1, power=3/4)
138 |
139 | self.max_seq_len = max_seq_len
140 |
141 | self.objective = objective
142 | self.parameterization = parameterization
143 | self.sampler=sampler
144 | self.log_var_interp = log_var_interp
145 | self.langevin_step_size = langevin_step_size
146 | self.snr_ddim_threshold = snr_ddim_threshold
147 |
148 | # Min-SNR weighting from https://arxiv.org/abs/2303.09556
149 |         # Option no longer supported; buffer kept for backwards compatibility with the state_dict of prior checkpoints
150 | self.register_buffer('min_snr_gamma', torch.tensor(MIN_SNR_GAMMA))
151 |
152 | self.loss_type = loss_type
153 |
154 | assert objective == 'pred_v', f'objective {objective} must be pred_v'
155 |
156 | self.adaptive_noise_schedule = (train_schedule == 'adaptive')
157 | if self.adaptive_noise_schedule:
158 | self.train_schedule = LossEMASampler()
159 | self.val_ema_sampler = LossEMASampler(ema_decay=0.0)
160 | if loss_weighting == 'v_weighting':
161 | assert objective == 'pred_v', f'objective {objective} must be pred_v for v weighting'
162 | self.diffusion_loss_weighting_obj = V_Weighting()
163 | self.diffusion_loss_weighting = self.diffusion_loss_weighting_obj.v_loss_weighting
164 | elif loss_weighting == 'lognormal_v_weighting':
165 | gamma_mean = loss_weighting_args['loss_weighting_mean']
166 | gamma_std = loss_weighting_args['loss_weighting_std']
167 | assert objective == 'pred_v', f'objective {objective} must be pred_v for lognormal v weighting'
168 | self.diffusion_loss_weighting_obj = LogNormal_V_Weighting(gamma_mean=gamma_mean, gamma_std=gamma_std, objective=objective)
169 | self.diffusion_loss_weighting = self.diffusion_loss_weighting_obj.v_loss_weighting
170 | elif loss_weighting == 'asymmetric_lognormal_v_weighting':
171 | gamma_mean = loss_weighting_args['loss_weighting_mean']
172 | gamma_std = loss_weighting_args['loss_weighting_std']
173 | gamma_std_mult = loss_weighting_args['loss_weighting_std_mult']
174 | assert objective == 'pred_v', f'objective {objective} must be pred_v for lognormal v weighting'
175 | self.diffusion_loss_weighting_obj = Asymmetric_LogNormal_V_Weighting(gamma_mean=gamma_mean, gamma_std=gamma_std, objective=objective, std_mult=gamma_std_mult)
176 | self.diffusion_loss_weighting = self.diffusion_loss_weighting_obj.v_loss_weighting
177 | else:
178 | raise ValueError(f'invalid loss weighting {loss_weighting}')
179 | else:
180 | self.logsnr_loss_tracker = LossEMASampler()
181 | if train_schedule == "simple_linear":
182 | alpha_schedule = simple_linear_schedule
183 | elif train_schedule == "beta_linear":
184 | alpha_schedule = beta_linear_schedule
185 | elif train_schedule == "cosine":
186 | alpha_schedule = cosine_schedule
187 | elif train_schedule == "sigmoid":
188 | alpha_schedule = sigmoid_schedule
189 | else:
190 | raise ValueError(f'invalid noise schedule {train_schedule}')
191 |
192 | self.train_schedule = partial(time_to_alpha, alpha_schedule=alpha_schedule, scale=scale)
193 |
194 | # Sampling schedule
195 | if sampling_schedule is None:
196 | sampling_alpha_schedule = None
197 | elif sampling_schedule == "simple_linear":
198 | sampling_alpha_schedule = simple_linear_schedule
199 | elif sampling_schedule == "beta_linear":
200 | sampling_alpha_schedule = beta_linear_schedule
201 | elif sampling_schedule == "cosine":
202 | sampling_alpha_schedule = cosine_schedule
203 | elif sampling_schedule == "sigmoid":
204 | sampling_alpha_schedule = sigmoid_schedule
205 | else:
206 | raise ValueError(f'invalid sampling schedule {sampling_schedule}')
207 |
208 | if exists(sampling_alpha_schedule):
209 | self.sampling_schedule = partial(time_to_alpha, alpha_schedule=sampling_alpha_schedule, scale=scale)
210 | else:
211 | self.sampling_schedule = self.train_schedule
212 |
213 |
214 | # Optionally rescale data to have unit variance
215 | self.scale_by_std = scale_by_std
216 | if scale_by_std:
217 | self.register_buffer('data_mean', torch.full((128,), fill_value=0.0))
218 | self.register_buffer('data_std', torch.full((128,), fill_value=-1.0))
219 | else:
220 | raise NotImplementedError
221 |
222 | # gamma schedules
223 |
224 | self.sampling_timesteps = sampling_timesteps
225 |
226 | # probability for self conditioning during training
227 |
228 | self.unconditional_prob = unconditional_prob
229 | self.inpainting_prob = inpainting_prob
230 |
231 | if self.unconditional_prob > 0:
232 | self.unconditional_bernoulli = torch.distributions.Bernoulli(probs=self.unconditional_prob)
233 | if self.inpainting_prob > 0:
234 | self.inpainting_bernoulli = torch.distributions.Bernoulli(probs=self.inpainting_prob)
235 | # Mode/Concentration parameterization of Beta distribution
236 | alpha = inpainting_duration_mode*(inpainting_duration_concentration-2) + 1
237 | beta = (1-inpainting_duration_mode)*(inpainting_duration_concentration-2)+1
238 | self.inpainting_duration_beta = torch.distributions.Beta(alpha, beta)
239 |
240 | self.text_encoder_id = text_encoder
241 | self.text_encoder = T5ForConditionalGeneration.from_pretrained(text_encoder, torch_dtype=torch.bfloat16).get_encoder()
242 | for param in self.text_encoder.parameters():
243 | param.requires_grad = False
244 | self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder)
245 |
246 | self.audio_codec = EncodecWrapper()
247 | for param in self.audio_codec.parameters():
248 | param.requires_grad = False
249 |
250 |
251 | def predict_start_from_noise(self, z_t, alpha, noise, sampling=False):
252 | alpha = right_pad_dims_to(z_t, alpha)
253 |
254 | return (z_t - (1-alpha).sqrt() * noise) / alpha.sqrt().clamp(min = 1e-8)
255 |
256 | def predict_noise_from_start(self, z_t, alpha, x0, sampling=False):
257 | alpha = right_pad_dims_to(z_t, alpha)
258 |
259 | return (z_t - alpha.sqrt() * x0) / (1-alpha).sqrt().clamp(min = 1e-8)
260 |
261 | def predict_start_from_v(self, z_t, alpha, v, sampling=False):
262 | alpha = right_pad_dims_to(z_t, alpha)
263 |
264 | x = alpha.sqrt() * z_t - (1-alpha).sqrt() * v
265 |
266 | return x
267 |
268 | def predict_noise_from_v(self, z_t, alpha, v, sampling=False):
269 | alpha = right_pad_dims_to(z_t, alpha)
270 |
271 | eps = (1-alpha).sqrt() * z_t + alpha.sqrt() * v
272 |
273 | return eps
274 |
275 | def predict_v_from_start_and_eps(self, z_t, alpha, x, noise, sampling=False):
276 | alpha = right_pad_dims_to(z_t, alpha)
277 |
278 | v = alpha.sqrt() * noise - x* (1-alpha).sqrt()
279 |
280 | return v
281 |
282 | def diffusion_model_predictions(self, z_t, alpha, *, text_cond, text_cond_mask, audio_mask, sampling=False, cls_free_guidance=1.0, fill_mask=None):
283 | time_cond = alpha.sqrt()
284 | inpainting_mask = fill_mask[:, 0, :].long()
285 | if sampling:
286 | model_output = self.denoising_network.ema_model(z_t, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask, inpainting_mask=inpainting_mask, audio_mask=audio_mask)
287 | if cls_free_guidance != 1.0:
288 | unc_text_cond = torch.zeros_like(text_cond)[:,:1,:]
289 | unc_text_cond_mask = torch.full_like(text_cond_mask, fill_value=False)[:,:1]
290 | if torch.sum(fill_mask) > 0:
291 | alpha = rearrange(alpha, 'b -> b () ()')
292 | noise = torch.randn_like(z_t)
293 | z_t[fill_mask] = (z_t*alpha.sqrt() + (1-alpha).sqrt()*noise)[fill_mask]
294 | unc_inpainting_mask = torch.full_like(inpainting_mask, fill_value=0)
295 | unc_model_output = self.denoising_network.ema_model(z_t, time_cond, text_cond=unc_text_cond, text_cond_mask=unc_text_cond_mask, inpainting_mask=unc_inpainting_mask, audio_mask=audio_mask)
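                # Classifier-free guidance: extrapolate the conditional prediction away from the
                # unconditional one; cls_free_guidance == 1.0 recovers the purely conditional output.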
296 | model_output = model_output*cls_free_guidance + unc_model_output*(1-cls_free_guidance)
297 | else:
298 | model_output = self.denoising_network.online_model(z_t, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask, inpainting_mask=inpainting_mask, audio_mask=audio_mask)
299 |
300 | if self.parameterization == 'pred_v':
301 | pred_v = model_output
302 | x_start = self.predict_start_from_v(z_t, alpha, pred_v, sampling=sampling)
303 | pred_eps = self.predict_noise_from_v(z_t, alpha, pred_v, sampling=sampling)
304 | else:
305 | raise ValueError(f'invalid objective {self.parameterization}')
306 |
307 | return ModelPrediction(pred_eps, x_start, pred_v)
308 |
309 | def get_sampling_timesteps(self, batch, *, device, start_time=1.0):
310 | times = torch.linspace(start_time, 0., self.sampling_timesteps + 1, device = device)
311 | times = repeat(times, 't -> b t', b = batch)
312 | times = torch.stack((times[:, :-1], times[:, 1:]), dim = 0)
313 | times = times.unbind(dim = -1)
314 | return times
315 |
316 | @torch.no_grad()
317 | def ddim_or_ddpm_sample(self, shape, text_cond, text_cond_mask, audio_mask, prefix_seconds=0, audio_latent=None, cls_free_guidance=1.0, speaker_frames=None, sampler='ddim', log_var_interp=.2, langevin_step_size=0.0, snr_ddim_threshold=None):
318 | batch, device = shape[0], next(self.denoising_network.ema_model.parameters()).device
319 |
320 | time_pairs = self.get_sampling_timesteps(batch, device = device)
321 |
322 | z_t = torch.randn(shape, device=device)
323 |
324 | fill_mask = None
325 |
326 | if prefix_seconds > 0:
327 | assert exists(audio_latent)
328 | if exists(speaker_frames):
329 | num_inpainting_frames = speaker_frames
330 | else:
331 | num_inpainting_frames = round(prefix_seconds*LATENT_SAMPLING_RATE)
332 |                 num_inpainting_frames = torch.full((batch,), fill_value=num_inpainting_frames, dtype=torch.int, device=device)
333 |
334 | indices = torch.arange(z_t.shape[2], device=device)
335 |
336 | # Construct mask to insert clean data
337 | fill_mask = repeat((indices <= num_inpainting_frames[:, None]), 'b l -> b c l', c=z_t.shape[1])
338 | else:
339 | fill_mask = torch.full_like(z_t, fill_value=0, dtype=torch.bool)
340 |
341 | x_start = None
342 |
343 | for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total = self.sampling_timesteps):
344 | # get predicted x0
345 | if prefix_seconds > 0:
346 | z_t[fill_mask] = audio_latent[fill_mask]
347 | if exists(x_start):
348 | x_start[fill_mask] = audio_latent[fill_mask]
349 |
350 | # get alpha sigma of time and next time
351 |
352 | alpha = self.sampling_schedule(time)
353 | alpha_next = self.sampling_schedule(time_next)
354 |
355 | model_output = self.diffusion_model_predictions(z_t, alpha, text_cond=text_cond, text_cond_mask=text_cond_mask, audio_mask=audio_mask, sampling=True, cls_free_guidance=cls_free_guidance, fill_mask=fill_mask)
356 |
357 | alpha, alpha_next = map(partial(right_pad_dims_to, z_t), (alpha, alpha_next))
358 |
359 | # calculate x0 and noise
360 |
361 | x_start = model_output.pred_x_start
362 |
363 | eps = model_output.pred_eps
364 |
365 | if time_next[0] <= 0:
366 | z_t = x_start
367 | continue
368 |
369 | # get noise
370 | snr = alpha_to_shifted_log_snr(alpha_next)[0].item()
371 | if sampler == 'ddim' or (exists(snr_ddim_threshold) and snr > snr_ddim_threshold):
372 | z_t = x_start * alpha_next.sqrt() + eps * (1-alpha_next).sqrt()
373 | elif sampler == 'ddpm':
374 | # get noise
375 | noise = torch.randn_like(z_t)
376 | alpha_now = alpha/alpha_next
377 |
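                # Interpolate in log space between the posterior variance (lower bound) and beta_t
                # (upper bound), as in Improved DDPM (Nichol & Dhariwal, 2021), but with a fixed
                # interpolation weight `log_var_interp` rather than a learned one.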
378 | min_var = torch.exp(torch.log1p(-alpha_next) - torch.log1p(-alpha)) * (1.0 -alpha_now)
379 | max_var = (1.0 - alpha_now)
380 | noise_param = log_var_interp
381 | sigma = torch.exp(noise_param * torch.log(max_var) + (1 - noise_param) * torch.log(min_var) )
382 | z_t = 1/alpha_now.sqrt() * (z_t - (1-alpha_now)/(1-alpha).sqrt() * eps) + torch.sqrt(sigma) * noise
383 |
384 | if langevin_step_size > 0:
385 | if prefix_seconds > 0:
386 | z_t[fill_mask] = audio_latent[fill_mask]
387 | if exists(x_start):
388 | x_start[fill_mask] = audio_latent[fill_mask]
389 | alpha_next = self.sampling_schedule(time_next)
390 | model_output = self.diffusion_model_predictions(z_t, alpha_next, text_cond=text_cond, text_cond_mask=text_cond_mask, audio_mask=audio_mask, sampling=True, cls_free_guidance=cls_free_guidance, fill_mask=fill_mask)
391 | alpha_next = right_pad_dims_to(z_t, alpha_next)
392 | noise = torch.randn_like(z_t)
393 | eps = model_output.pred_eps
394 | z_t = z_t - .5*langevin_step_size*(1-alpha_next).sqrt()*eps + math.sqrt(langevin_step_size)*(1-alpha_next).sqrt()*noise
395 | if prefix_seconds > 0:
396 | z_t[fill_mask] = audio_latent[fill_mask]
397 | return z_t
398 |
399 | def normalize_audio_latent(self, audio_latent):
400 | return (audio_latent - rearrange(self.data_mean, 'c -> () c ()')) / rearrange(self.data_std, 'c -> () c ()')
401 |
402 | def unnormalize_audio_latent(self, audio_latent):
403 | return audio_latent * rearrange(self.data_std, 'c -> () c ()') + rearrange(self.data_mean, 'c -> () c ()')
404 |
405 | @torch.no_grad()
406 | def sample(self, data, prefix_seconds=0, cls_free_guidance=1.0):
407 | # [B, L, d_lm]: Embedded text
408 | if prefix_seconds > 0:
409 | merged_text = [' '.join((speaker_text, text)) for speaker_text, text in zip(data['speaker_text'], data['text'])]
410 | tokenizer_output = self.text_tokenizer(merged_text, padding="max_length", truncation=True, max_length=256, return_tensors='pt').to(data['wav'].device)
411 | text_cond = self.text_encoder(tokenizer_output['input_ids'], tokenizer_output['attention_mask']).last_hidden_state.float()
412 | text_cond_mask = tokenizer_output['attention_mask'].bool()
413 | audio_latent = self.audio_codec.encode(data['speaker_wav'])
414 | speaker_frames = torch.floor(data['speaker_audio_duration'] * LATENT_SAMPLING_RATE).int()
415 | audio_mask = data['inpaint_audio_mask']
416 | else:
417 | text_cond = self.text_encoder(data['input_ids'], data['attention_mask']).last_hidden_state.float()
418 | text_cond_mask = data['attention_mask'].bool()
419 | # [B, d_audio, L]
420 | audio_latent = self.audio_codec.encode(data['wav'])
421 | speaker_frames = None
422 | audio_mask = data['audio_mask']
423 | audio_latent = self.normalize_audio_latent(audio_latent)
424 | latent_shape = audio_latent.shape
425 | assert self.sampler in {'ddim', 'ddpm'}
426 | sampled_audio_latent = self.ddim_or_ddpm_sample(latent_shape, text_cond, text_cond_mask, prefix_seconds=prefix_seconds, audio_latent=audio_latent, cls_free_guidance=cls_free_guidance, speaker_frames=speaker_frames, audio_mask=audio_mask, sampler=self.sampler, log_var_interp=self.log_var_interp, langevin_step_size=self.langevin_step_size, snr_ddim_threshold=self.snr_ddim_threshold)
427 | return self.unnormalize_audio_latent(sampled_audio_latent)
428 |
429 | @property
430 | def loss_fn(self):
431 | if self.loss_type == 'l1':
432 | return F.l1_loss
433 | elif self.loss_type == 'l2':
434 | return F.mse_loss
435 | else:
436 | raise ValueError(f'invalid loss type {self.loss_type}')
437 |
438 | def inpainting_enabled(self):
439 | return self.inpainting_prob > 0
440 |
441 | def get_loss_emas(self, split='train'):
442 | assert self.adaptive_noise_schedule, 'EMA loss only available with adaptive noise schedule'
443 | if split == 'train':
444 | return self.train_schedule.get_loss_emas()
445 | elif split == 'val':
446 | return self.val_ema_sampler.get_loss_emas()
447 | else:
448 | raise ValueError(f'invalid split {split}')
449 |
450 | def get_normalized_loss_emas(self, split='train'):
451 | assert self.adaptive_noise_schedule, 'EMA loss only available with adaptive noise schedule'
452 | if split == 'train':
453 | return self.train_schedule.get_normalized_loss_emas()
454 | elif split == 'val':
455 | return self.val_ema_sampler.get_normalized_loss_emas()
456 | else:
457 | raise ValueError(f'invalid split {split}')
458 |
459 | def get_unweighted_loss_emas(self, split='train'):
460 | if self.adaptive_noise_schedule:
461 | loss_ema = self.train_schedule
462 | else:
463 | assert split == 'train', 'EMA loss only available for train split without adaptive noise schedule'
464 | loss_ema = self.logsnr_loss_tracker
465 | if split == 'train':
466 | return loss_ema.get_unweighted_loss_emas()
467 | elif split == 'val':
468 | return self.val_ema_sampler.get_unweighted_loss_emas()
469 | else:
470 | raise ValueError(f'invalid split {split}')
471 |
472 | def save_adaptive_noise_schedule(self, path):
473 | if not self.adaptive_noise_schedule:
474 | ema_unweighted_loss_path = os.path.join(path, f'ema_unweighted_loss.png')
475 | self.logsnr_loss_tracker.save_unweighted_loss_emas(path=ema_unweighted_loss_path)
476 | return
477 | # save train plots
478 | density_path = os.path.join(path, f'schedule_density.png')
479 | self.train_schedule.save_density(path=density_path)
480 | cdf_path = os.path.join(path, f'schedule_cdf.png')
481 | self.train_schedule.save_cumulative_density(path=cdf_path)
482 | ema_loss_path = os.path.join(path, f'ema_loss.png')
483 | self.train_schedule.save_loss_emas(path=ema_loss_path)
484 | ema_unweighted_loss_path = os.path.join(path, f'ema_unweighted_loss.png')
485 | self.train_schedule.save_unweighted_loss_emas(path=ema_unweighted_loss_path)
486 |
487 | # save val plots
488 | density_path = os.path.join(path, f'schedule_density_val.png')
489 | self.val_ema_sampler.save_density(path=density_path)
490 | cdf_path = os.path.join(path, f'schedule_cdf_val.png')
491 | self.val_ema_sampler.save_cumulative_density(path=cdf_path)
492 | ema_loss_path = os.path.join(path, f'ema_loss_val.png')
493 | self.val_ema_sampler.save_loss_emas(path=ema_loss_path)
494 | ema_unweighted_loss_path = os.path.join(path, f'ema_unweighted_loss_val.png')
495 | self.val_ema_sampler.save_unweighted_loss_emas(path=ema_unweighted_loss_path)
496 |
497 | def adaptive_noise_schedule_enabled(self):
498 | return self.adaptive_noise_schedule
499 |
500 | def forward(self, data, accelerator=None):
501 |
502 | with torch.no_grad():
503 | # [B, L, d_lm]: Embedded text
504 | text_cond = self.text_encoder(data['input_ids'], data['attention_mask']).last_hidden_state.float()
505 | # [B, L]: Cross-attn mask
506 | text_cond_mask = data['attention_mask'].bool()
507 |
508 |
509 | # [B, d_audio, L]: embedded audio
510 | with torch.cuda.amp.autocast(enabled=False):
511 | audio_latent = self.audio_codec.encode(data['wav'])
512 |
513 | if self.scale_by_std:
514 | # Estimate standard deviation of the data from the first batch
515 | if self.data_std[0].item() < 0:
516 | gathered_audio_latent = accelerator.gather(audio_latent)
517 | gathered_audio_mask = accelerator.gather(data['audio_mask'])
518 | min_length = gathered_audio_mask.sum(1).min().item()
519 |
520 | # Compute mean and std of the data
521 | data_mean = reduce(gathered_audio_latent[:,:,:min_length], 'b c l -> c', 'mean')
522 | data_std = reduce(gathered_audio_latent[:,:,:min_length], 'b c l -> c', torch.std)
523 | self.data_mean = data_mean
524 | self.data_std = data_std
525 |
526 | print(f'Set data mean: {self.data_mean.tolist()}')
527 | print(f'Set data std: {self.data_std.tolist()}')
528 |
529 | audio_latent = self.normalize_audio_latent(audio_latent)
530 |
531 | batch, audio_channels, audio_length = audio_latent.shape
532 | device = audio_latent.device
533 |
534 | # Mask out text-conditioning with some probability to enable clf-free guidance
535 | if self.unconditional_prob > 0:
536 | unconditional_mask = self.unconditional_bernoulli.sample((batch,)).bool()
537 | text_cond_mask[unconditional_mask, :] = False
538 |
539 | # sample random times
540 |
541 | if self.adaptive_noise_schedule and self.training:
542 | gamma, schedule_density = self.train_schedule.sample(
543 | batch_size=batch, device=device)
544 | alpha = log_snr_to_alpha(gamma)
545 | elif self.adaptive_noise_schedule:
546 | gamma, schedule_density = self.train_schedule.sample(
547 | batch_size=batch, device=device, uniform=True)
548 | alpha = log_snr_to_alpha(gamma)
549 | else:
550 | times = torch.zeros((batch,), device = device).float().uniform_(0, 1.)
551 | alpha = self.train_schedule(times)
552 | gamma = alpha_to_shifted_log_snr(alpha, scale=1)
553 | padded_alpha = right_pad_dims_to(audio_latent, alpha)
554 |
555 | # noise sample
556 | noise = torch.randn_like(audio_latent)
557 |
558 | z_t = padded_alpha.sqrt() * audio_latent + (1-padded_alpha).sqrt() * noise
559 |
560 | # Inpainting logic
561 | inpainting_mask = None
562 | if self.inpainting_prob > 0:
563 | inpainting_batch_mask = self.inpainting_bernoulli.sample((batch,)).bool().to(device)
564 | # Sample durations to mask
565 | inpainting_durations = self.inpainting_duration_beta.sample((batch,)).to(device) * data['audio_duration']
566 | num_inpainting_frames = torch.round(inpainting_durations*LATENT_SAMPLING_RATE).int()
567 |
568 | # Sample where to mask
569 | indices = torch.arange(audio_length, device=device)
570 |
571 | # Construct mask to insert clean data
572 | inpainting_length_mask = ((indices <= num_inpainting_frames[:, None]))
573 | inpainting_mask = (inpainting_length_mask) & inpainting_batch_mask.unsqueeze(-1)
574 | fill_mask = repeat(inpainting_mask, 'b l -> b c l', c=audio_channels)
575 |
576 | z_t[fill_mask] = audio_latent[fill_mask]
577 | else:
578 | fill_mask = torch.full_like(z_t, fill_value=0, dtype=torch.bool)
579 |
580 | # predict and take gradient step
581 | predictions = self.diffusion_model_predictions(z_t, alpha, text_cond=text_cond, text_cond_mask=text_cond_mask, fill_mask=fill_mask, audio_mask=data['audio_mask'])
582 |
583 |
584 | if self.objective == 'pred_eps':
585 | target = noise
586 | pred = predictions.pred_eps
587 | elif self.objective == 'pred_v':
588 | # V-prediction from https://openreview.net/forum?id=TIdIXIpzhoI
589 | velocity = padded_alpha.sqrt() * noise - (1-padded_alpha).sqrt() * audio_latent
590 | target = velocity
591 | pred = predictions.pred_v
592 | else:
593 | raise ValueError(f'invalid objective {self.objective}')
594 |
595 | loss = self.loss_fn(pred, target, reduction = 'none')
596 | loss = reduce(loss, 'b c l -> b l', 'mean')
597 | if self.inpainting_prob > 0:
598 | # Need to align with gamma
599 | # Standard diffusion loss
600 | diff_batch = loss[~inpainting_batch_mask]
601 | diff_loss = masked_mean(diff_batch, dim=1, mask=data['audio_mask'][~inpainting_batch_mask])
602 |
603 | # Masked inpainting loss
604 | inpainting_batch = loss[inpainting_batch_mask]
605 | loss_mask = torch.logical_and((~inpainting_length_mask[inpainting_batch_mask]), data['audio_mask'][inpainting_batch_mask])
606 | inpainting_loss = masked_mean(inpainting_batch, dim=1, mask=loss_mask)
607 | loss = torch.cat([diff_loss, inpainting_loss], dim=0)
608 |
609 | gamma_diff_batch = gamma[~inpainting_batch_mask]
610 | gamma_inpainting_batch = gamma[inpainting_batch_mask]
611 | gamma = torch.cat([gamma_diff_batch, gamma_inpainting_batch], dim=0)
612 |
613 | if self.adaptive_noise_schedule:
614 | schedule_density_diff_batch = schedule_density[~inpainting_batch_mask]
615 | schedule_density_inpainting_batch = schedule_density[inpainting_batch_mask]
616 | schedule_density = torch.cat([schedule_density_diff_batch, schedule_density_inpainting_batch], dim=0)
617 | else:
618 | loss = masked_mean(loss, dim=1, mask=data['audio_mask'])
619 |
620 | if not self.adaptive_noise_schedule:
621 | self.logsnr_loss_tracker.update_with_all_unweighted_losses(gamma.squeeze(), loss)
622 | return loss.mean()
623 |
624 | diffusion_loss_weighting = self.diffusion_loss_weighting(gamma=gamma).squeeze()
625 | weighted_loss = diffusion_loss_weighting * loss
626 | # Update loss ema
627 | # Disable autocast for this step
628 | with torch.cuda.amp.autocast(enabled=False):
629 | if self.training:
630 | self.train_schedule.update_with_all_losses(gamma.squeeze(), weighted_loss)
631 | self.train_schedule.update_with_all_unweighted_losses(gamma.squeeze(), loss)
632 | else:
633 | self.val_ema_sampler.update_with_all_losses(gamma.squeeze(), weighted_loss)
634 | self.val_ema_sampler.update_with_all_unweighted_losses(gamma.squeeze(), loss)
635 | monte_carlo_weighted_loss = torch.exp(torch.log(diffusion_loss_weighting) - torch.log(schedule_density))*loss
636 | return (monte_carlo_weighted_loss).mean()
637 |
638 | # trainer class
639 |
640 | class Trainer(object):
641 | def __init__(
642 | self,
643 | args,
644 | diffusion,
645 | dataset_name,
646 | *,
647 | optimizer = 'adamw',
648 | batch_size = 16,
649 | gradient_accumulate_every = 1,
650 | train_lr = 1e-4,
651 | train_num_steps = 100000,
652 | lr_schedule = 'cosine',
653 | num_warmup_steps = 500,
654 | adam_betas = (0.9, 0.999),
655 | adam_weight_decay = 0.01,
656 | save_and_sample_every = 5000,
657 | num_samples = 25,
658 | mixed_precision = 'no',
659 | prefix_inpainting_seconds=0,
660 | duration_path=None,
661 | seed=None
662 | ):
663 | super().__init__()
664 |
665 | assert prefix_inpainting_seconds in {0., 3.0} or args.kilian, 'Currently only supports 3sec for inpainting'
666 | if exists(seed):
667 | set_seeds(seed)
668 |
669 | self.args = args
670 |
671 | self.accelerator = Accelerator(
672 | mixed_precision = mixed_precision,
673 | log_with='wandb',
674 | )
675 | self.num_devices = self.accelerator.num_processes
676 | args.num_devices = self.num_devices
677 |
678 | args.output_dir = get_output_dir(args)
679 |
680 | if self.accelerator.is_main_process:
681 | os.makedirs(args.output_dir)
682 | print(f'Created {args.output_dir}')
683 |
684 | with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
685 | json.dump(args.__dict__, f, indent=2)
686 | run = os.path.split(__file__)[-1].split(".")[0]
687 | self.accelerator.init_trackers(run, config=vars(args), init_kwargs={"wandb": {"dir": args.output_dir, "name": args.run_name}})
688 |
689 |
690 | self.diffusion = diffusion
691 |
692 | self.num_samples = num_samples
693 | self.save_and_sample_every = save_and_sample_every
694 | self.prefix_inpainting_seconds = prefix_inpainting_seconds
695 |
696 | self.batch_size = batch_size
697 | self.gradient_accumulate_every = gradient_accumulate_every
698 |
699 | self.train_num_steps = train_num_steps
700 | self.max_seq_len = diffusion.max_seq_len
701 |
702 |
703 | # dataset and dataloader
704 | if dataset_name == 'librispeech':
705 | self.dataset = LibriSpeech(split='train', tokenizer=diffusion.text_tokenizer)
706 | self.val_dataset = LibriSpeech(split='valid', tokenizer=diffusion.text_tokenizer)
707 | self.test_dataset = LibriSpeech(split='test', tokenizer=diffusion.text_tokenizer, max_seq_len=self.dataset.max_seq_len, duration_path=duration_path)
708 | else:
709 | raise ValueError(f'invalid dataset: {dataset_name}')
710 |
711 | self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, pin_memory=True, num_workers=2)
712 | self.val_dataloader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)
713 | self.test_dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
714 |
715 |
716 | # optimizer
717 |
718 | if optimizer == 'adamw':
719 | self.opt = get_adamw_optimizer(diffusion.parameters(), lr = train_lr, betas = adam_betas, weight_decay=adam_weight_decay)
720 | elif optimizer == 'lion':
721 | self.opt = get_lion_optimizer(diffusion.parameters(), lr = train_lr, weight_decay=adam_weight_decay)
722 | else:
723 | raise ValueError(f'invalid optimizer {optimizer}')
724 |
725 | # scheduler
726 |
727 | lr_scheduler = get_scheduler(
728 | lr_schedule,
729 | optimizer=self.opt,
730 | num_warmup_steps=num_warmup_steps*self.num_devices,
731 | num_training_steps=train_num_steps*self.num_devices, # Accelerate does num_devices steps at a time
732 | )
733 |
734 | # for logging results in a folder periodically
735 |
736 | if self.accelerator.is_main_process:
737 |
738 | self.results_folder = args.output_dir
739 |
740 | # step counter state
741 |
742 | self.step = 0
743 |
744 | # prepare model, dataloader, optimizer with accelerator
745 |
746 | self.diffusion, self.opt, self.dataloader, self.val_dataloader, self.test_dataloader, self.lr_scheduler = self.accelerator.prepare(self.diffusion, self.opt, self.dataloader, self.val_dataloader, self.test_dataloader, lr_scheduler)
747 | self.data_iter = cycle(self.dataloader)
748 | self.val_data_iter = cycle(self.val_dataloader)
749 |
750 | def save(self, save_step=False):
751 | if not self.accelerator.is_local_main_process:
752 | return
753 |
754 | data = {
755 | 'step': self.step,
756 | 'model': self.accelerator.get_state_dict(self.diffusion),
757 | 'opt': self.opt.state_dict(),
758 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
759 | 'scheduler': self.lr_scheduler.state_dict(),
760 | }
761 | if save_step:
762 | torch.save(data, f'{self.results_folder}/model_{self.step}.pt')
763 | else:
764 | torch.save(data, f'{self.results_folder}/model.pt')
765 |
766 | def load(self, file_path=None, best=False, init_only=False, ckpt_step=None):
767 | file_path = file_path if exists(file_path) else self.results_folder
768 | accelerator = self.accelerator
769 | device = accelerator.device
770 |
771 | if ckpt_step is not None:
772 | data = torch.load(f'{file_path}/model_{ckpt_step}.pt', map_location=device)
773 | else:
774 | data = torch.load(f'{file_path}/model.pt', map_location=device)
775 |
776 | model = self.accelerator.unwrap_model(self.diffusion)
777 | strict_load = not (init_only)
778 | model.load_state_dict(data['model'], strict=strict_load)
779 |
780 | if init_only:
781 | return
782 |
783 | # For backwards compatibility with earlier models
784 | if exists(self.accelerator.scaler) and exists(data['scaler']):
785 | self.accelerator.scaler.load_state_dict(data['scaler'])
786 |
787 | self.opt.load_state_dict(data['opt'])
788 |
789 | self.step = data['step']
790 | self.lr_scheduler.load_state_dict(data['scheduler'])
791 |
792 |
793 | @torch.no_grad()
794 | def sample(self, num_samples=None, seed=None, cls_free_guidance=1.0, test=False, prefix_seconds=0.):
795 | if exists(seed):
796 | set_seeds(seed)
797 | diffusion = self.accelerator.unwrap_model(self.diffusion)
798 | num_samples = default(num_samples, self.num_samples)
799 | self.diffusion.eval()
800 | num_sampled = 0
801 | dataloader = self.test_dataloader if test else self.val_dataloader
802 | for batch in dataloader:
803 | sampled_codec_latents = diffusion.sample(batch, prefix_seconds=prefix_seconds, cls_free_guidance=cls_free_guidance)
804 | sampled_wavs = diffusion.audio_codec.decode(sampled_codec_latents).squeeze() # [B, L]
805 |
806 | sampled_wavs = self.accelerator.gather_for_metrics(sampled_wavs).to('cpu')
807 |
808 | input_ids = self.accelerator.gather_for_metrics(batch['input_ids']).to('cpu')
809 |
810 | speaker_durations = self.accelerator.gather_for_metrics(batch['speaker_audio_duration']).to('cpu')
811 |
812 | if prefix_seconds > 0:
813 | num_audio_frames = batch['inpaint_audio_mask'].sum(dim=1)*ENCODEC_REDUCTION_FACTOR # [B]
814 | num_audio_frames = self.accelerator.gather_for_metrics(num_audio_frames).to('cpu')
815 | else:
816 | num_audio_frames = batch['audio_mask'].sum(dim=1)*ENCODEC_REDUCTION_FACTOR # [B]
817 | num_audio_frames = self.accelerator.gather_for_metrics(num_audio_frames).to('cpu')
818 |
819 |
820 | if self.accelerator.is_main_process:
821 | inpainting_suffix = f'_prefix{prefix_seconds}' if prefix_seconds>0 else ''
822 | step_folder = os.path.join(self.results_folder, 'samples', f'step_{self.step}')
823 | samples_folder = os.path.join(step_folder, f'guide{cls_free_guidance}{inpainting_suffix}')
824 | os.makedirs(samples_folder, exist_ok=True)
825 | text_list = [diffusion.text_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in input_ids]
826 | ref_frames = torch.ceil(speaker_durations*ENCODEC_SAMPLING_RATE).int()
827 | for idx in range(len(text_list)):
828 | text = text_list[idx]
829 | print(f'Saving idx: {idx+num_sampled}')
830 | with open(os.path.join(samples_folder, f'text_{idx+num_sampled}.txt'), 'w') as f:
831 | print(text, file=f)
832 | if prefix_seconds > 0:
833 | ref_frames_idx = ref_frames[idx].item()
834 | sf.write(os.path.join(samples_folder, f'audio_{idx+num_sampled}.wav'), sampled_wavs[idx][ref_frames_idx:num_audio_frames[idx].item()], ENCODEC_SAMPLING_RATE)
835 | sf.write(os.path.join(samples_folder, f'ref_{idx+num_sampled}.wav'), sampled_wavs[idx][:ref_frames_idx], ENCODEC_SAMPLING_RATE)
836 | else:
837 | sample_wav = sampled_wavs[idx][:num_audio_frames[idx].item()]
838 | sf.write(os.path.join(samples_folder, f'audio_{idx+num_sampled}.wav'), sample_wav, ENCODEC_SAMPLING_RATE)
839 |
840 | batch_size = self.num_devices * batch['wav'].shape[0]
841 | num_sampled += batch_size
842 |
843 |
844 | if exists(num_samples) and num_sampled >= num_samples:
845 | break
846 | self.diffusion.train()
847 | if self.accelerator.is_main_process and cls_free_guidance == 1.0:
848 | self.diffusion.save_adaptive_noise_schedule(step_folder)
849 |
850 |
851 | def train(self):
852 | accelerator = self.accelerator
853 | device = accelerator.device
854 |
855 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
856 |
857 | while self.step < self.train_num_steps:
858 |
859 | total_loss = 0.
860 |
861 | for _ in range(self.gradient_accumulate_every):
862 | data = next(self.data_iter)
863 | loss = self.diffusion(data, accelerator)
864 | loss = loss / self.gradient_accumulate_every
865 | total_loss += loss.item()
866 |
867 | self.accelerator.backward(loss)
868 |
869 |
870 | if accelerator.sync_gradients:
871 | grad_norm = compute_grad_norm(self.diffusion.parameters())
872 | accelerator.clip_grad_norm_(self.diffusion.parameters(), self.args.clip_grad_norm)
873 | self.opt.step()
874 | self.lr_scheduler.step()
875 | self.opt.zero_grad()
876 | self.step += 1
877 |
878 | if self.step % 10 == 0:
879 | logs = {
880 | "train/loss": total_loss,
881 | "learning_rate": self.lr_scheduler.get_last_lr()[0],
882 | "grad_norm": grad_norm,
883 | "step": self.step,
884 | "epoch": (self.step*self.gradient_accumulate_every)/len(self.dataloader),
885 | "samples": self.step*self.batch_size*self.gradient_accumulate_every*self.num_devices,
886 | }
887 | if self.diffusion.adaptive_noise_schedule_enabled():
888 | logs["train/weighted_loss_ema"] = self.diffusion.train_schedule.weights().mean()
889 | # Validation loss
890 | if self.step % 50 == 0:
891 | self.diffusion.eval()
892 | with torch.no_grad():
893 | total_val_loss = 0
894 | data = next(self.val_data_iter)
895 | loss = self.diffusion(data)
896 | total_val_loss += loss.item()
897 | if self.diffusion.adaptive_noise_schedule_enabled():
898 | logs['val/weighted_loss'] = self.diffusion.val_ema_sampler.weights().mean()
899 | self.diffusion.train()
900 |
901 | if accelerator.is_main_process:
902 | if self.diffusion.adaptive_noise_schedule_enabled():
903 | loss_emas_dict = self.diffusion.get_loss_emas()
904 | for key, value in loss_emas_dict.items():
905 | logs[f'train/ema_loss/{key}'] = value
906 | loss_emas_dict = self.diffusion.get_normalized_loss_emas()
907 | for key, value in loss_emas_dict.items():
908 | logs[f'train/ema/normalized_ema_loss/{key}'] = value
909 |
910 | # Val loss
911 | loss_emas_dict = self.diffusion.get_loss_emas(split='val')
912 | for key, value in loss_emas_dict.items():
913 | logs[f'val/loss/{key}'] = value
914 | loss_emas_dict = self.diffusion.get_unweighted_loss_emas(split='val')
915 | for key, value in loss_emas_dict.items():
916 | logs[f'val/unweighted_loss/{key}'] = value
917 | loss_emas_dict = self.diffusion.get_unweighted_loss_emas()
918 | for key, value in loss_emas_dict.items():
919 | logs[f'train/ema/unweighted_ema_loss/{key}'] = value
920 | pbar.set_postfix(**logs)
921 | accelerator.log(logs, step=self.step)
922 |
923 | accelerator.wait_for_everyone()
924 | # Update EMA
925 | accelerator.unwrap_model(self.diffusion).denoising_network.update()
926 |
927 | if self.step % self.save_and_sample_every == 0:
928 | self.sample()
929 | for cls_free_guidance in [3.0, 5.0]:
930 | self.sample(cls_free_guidance=cls_free_guidance)
931 |
932 | if self.prefix_inpainting_seconds > 0:
933 | self.sample(prefix_seconds=self.prefix_inpainting_seconds)
934 | for cls_free_guidance in [3.0, 5.0]:
935 | self.sample(cls_free_guidance=cls_free_guidance, prefix_seconds=self.prefix_inpainting_seconds)
936 | self.save(save_step=True)
937 |
938 | self.diffusion.train()
939 |
940 | pbar.update(1)
941 |
942 | # Save final model
943 | self.save()
944 | accelerator.end_training()
945 | accelerator.print('training complete')
--------------------------------------------------------------------------------
/diffusion/noise_schedule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | from functools import partial
5 | import matplotlib.pyplot as plt
6 | import json
7 |
8 | # Avoid log(0)
9 | def log(t, eps = 1e-12):
10 | return torch.log(t.clamp(min = eps))
11 |
12 | # noise schedules
13 |
14 | def simple_linear_schedule(t, clip_min = 1e-9):
15 | return (1 - t).clamp(min = clip_min)
16 |
17 | def beta_linear_schedule(t, clip_min = 1e-9):
18 | return torch.exp(-1e-4 - 10 * (t ** 2)).clamp(min = clip_min, max = 1.)
19 |
20 | def cosine_schedule(t, start = 0, end = 1, tau = 1, clip_min = 1e-9):
21 | power = 2 * tau
22 | v_start = math.cos(start * math.pi / 2) ** power
23 | v_end = math.cos(end * math.pi / 2) ** power
24 | output = torch.cos((t * (end - start) + start) * math.pi / 2) ** power
25 | output = (v_end - output) / (v_end - v_start)
26 | return output.clamp(min = clip_min)
27 |
28 | def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
29 | v_start = torch.tensor(start / tau).sigmoid()
30 | v_end = torch.tensor(end / tau).sigmoid()
31 | gamma = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
32 | return gamma.clamp_(min = clamp_min, max = 1.)
33 |
34 | def sigmoid_k_weighting(k, gamma):
35 | return torch.sigmoid(-gamma + k)
36 |
37 | class V_Weighting:
38 | def __init__(self, objective='pred_v'):
39 | self.objective = objective
40 |
41 | def v_loss_weighting(self, gamma):
42 | return 1/torch.cosh(-(gamma/2))
43 |
44 | def eps_loss_weighting(self, gamma):
45 | return torch.exp(-gamma/(2))
46 |
47 | class LogNormal_V_Weighting:
48 | def __init__(self, gamma_mean=0.0, gamma_std=0.0, objective='pred_v'):
49 | self.gamma_mean = gamma_mean
50 | self.gamma_std = gamma_std
51 | self.min_gamma = -15
52 | self.max_gamma = 15
53 |
54 | assert objective == 'pred_v'
55 |
56 | self.normal_dist = torch.distributions.normal.Normal(self.gamma_mean, self.gamma_std)
57 |
58 | self.log_max_weighting = self.normal_dist.log_prob(torch.tensor(gamma_mean))
59 |
60 | def v_loss_weighting(self, gamma):
61 | return torch.exp(self.normal_dist.log_prob(gamma) - self.log_max_weighting)
62 |
63 | def v_weighting(self, gamma):
64 | return self.v_loss_weighting(gamma)
65 |
66 | class LogCauchy_V_Weighting:
67 | def __init__(self, gamma_mean=0.0, gamma_std=0.0, objective='pred_v'):
68 | self.gamma_mean = gamma_mean
69 | self.gamma_std = gamma_std
70 | self.min_gamma = -15
71 | self.max_gamma = 15
72 |
73 | assert objective == 'pred_v'
74 |
75 | self.cauchy_dist = torch.distributions.cauchy.Cauchy(self.gamma_mean, self.gamma_std)
76 |
77 | self.log_max_weighting = self.cauchy_dist.log_prob(torch.tensor(gamma_mean))
78 |
79 | def v_loss_weighting(self, gamma):
80 | return torch.exp(self.cauchy_dist.log_prob(gamma) - self.log_max_weighting)
81 |
82 | def v_weighting(self, gamma):
83 | return self.v_loss_weighting(gamma)
84 |
85 | class Asymmetric_LogNormal_V_Weighting:
86 | def __init__(self, gamma_mean=0.0, gamma_std=0.0, std_mult=2.0, objective='pred_v'):
87 | self.gamma_mean = gamma_mean
88 | self.gamma_std = gamma_std
89 | self.min_gamma = -15
90 | self.max_gamma = 15
91 |
92 | assert objective == 'pred_v'
93 |
94 | self.neg_normal_v_weighting = LogCauchy_V_Weighting(gamma_mean, gamma_std*std_mult, objective='pred_v')
95 | self.normal_v_weighting = LogNormal_V_Weighting(gamma_mean, gamma_std, objective='pred_v')
96 |
97 | def v_loss_weighting(self, gamma):
98 | # Use normal weighting for gamma >= self.gamma_mean
99 | normal_weighting = self.normal_v_weighting.v_loss_weighting(gamma)
100 | # Use neg_normal weighting for gamma < self.gamma_mean
101 | neg_normal_weighting = self.neg_normal_v_weighting.v_loss_weighting(gamma)
102 | return torch.where(gamma < self.gamma_mean, neg_normal_weighting, normal_weighting)
103 |
104 | def v_weighting(self, gamma):
105 | # Use cauchy weighting for gamma < self.gamma_mean
106 | neg_normal_weighting = self.neg_normal_v_weighting.v_weighting(gamma)
107 | # Use normal weighting for gamma >= self.gamma_mean
108 | normal_weighting = self.normal_v_weighting.v_weighting(gamma)
109 | return torch.where(gamma < self.gamma_mean, neg_normal_weighting, normal_weighting)
110 |
111 | # converting gamma to alpha, sigma or logsnr
112 | def log_snr_to_alpha(log_snr):
113 | alpha = torch.sigmoid(log_snr)
114 | return alpha
115 |
116 | # Log-SNR shifting (https://arxiv.org/abs/2301.10972)
117 | def alpha_to_shifted_log_snr(alpha, scale = 1):
118 | return (log(alpha) - log(1 - alpha)).clamp(min=-20, max=20) + 2*np.log(scale).item()
119 |
120 | def time_to_alpha(t, alpha_schedule, scale):
121 | alpha = alpha_schedule(t)
122 | shifted_log_snr = alpha_to_shifted_log_snr(alpha, scale = scale)
123 | return log_snr_to_alpha(shifted_log_snr)
124 |
125 | def plot_noise_schedule(unscaled_sampling_schedule, name, y_value):
126 | assert y_value in {'alpha^2', 'alpha', 'log(SNR)'}
127 | t = torch.linspace(0, 1, 100) # 100 points between 0 and 1
128 | scales = [.2, .5, 1.0]
129 | for scale in scales:
130 | sampling_schedule = partial(time_to_alpha, alpha_schedule=unscaled_sampling_schedule, scale=scale)
131 | alphas = sampling_schedule(t) # Obtain noise schedule values for each t
132 | if y_value == 'alpha^2':
133 | y_axis_label = r'$\alpha^2_t$'
134 | y = alphas
135 | elif y_value == 'alpha':
136 | y_axis_label = r'$\alpha_t$'
137 | y = alphas.sqrt()
138 | elif y_value == 'log(SNR)':
139 | y_axis_label = r'$\log(\lambda_t)$'
140 | y = alpha_to_shifted_log_snr(alphas, scale=1)
141 |
142 | plt.plot(t.numpy(), y.numpy(), label=f'Scale: {scale:.1f}')
143 | if y_value == 'log(SNR)':
144 | plt.ylim(-15, 15)
145 | plt.xlabel('t')
146 | plt.ylabel(y_axis_label)
147 | plt.title(f'{name}')
148 | plt.legend()
149 | plt.savefig(f'viz/{name.lower()}_{y_value}.png')
150 | plt.clf()
151 |
152 |
153 | def plot_cosine_schedule():
154 | t = torch.linspace(0, 1, 100) # 100 points between 0 and 1
155 | sampling_schedule = cosine_schedule
156 | alphas = sampling_schedule(t) # Obtain noise schedule values for each t
157 | y = alphas
158 | plt.plot(t.numpy(), y.numpy())
159 | plt.xlabel('t')
160 |     plt.ylabel('alpha^2')
161 |     plt.title('Cosine Noise Schedule')
162 |     plt.savefig('viz/standard_cosine.png')
163 | plt.clf()
164 |
165 | def plot_weighting_functions():
166 | gamma = torch.linspace(-15, 15, 1000)
167 |
168 |
169 | # Plot v weighting
170 |     v_obj = V_Weighting(objective='pred_v')
171 | v_weighting = v_obj.v_loss_weighting(gamma)
172 |
173 |     # Log-normal V weighting with mean=-1.0, std=2.4
174 | v_obj = LogNormal_V_Weighting(gamma_mean=-1.0, gamma_std=2.4, objective='pred_v')
175 | v_weighting_lognormal = v_obj.v_loss_weighting(gamma)
176 |
177 |     # Asymmetric log-normal V weighting with mean=-1.0, std=2.4, std_mult=2.0
178 | v_obj = Asymmetric_LogNormal_V_Weighting(gamma_mean=-1.0, gamma_std=2.4, objective='pred_v', std_mult=2.0)
179 | v_weighting_logcauchynormal = v_obj.v_loss_weighting(gamma)
180 |
181 | # Create plot
182 | fig, ax = plt.subplots(figsize=(8, 4))
183 | ax.plot(gamma.numpy(), v_weighting_logcauchynormal.numpy(), label='Asymmetric Weighting')
184 | ax.plot(gamma.numpy(), v_weighting_lognormal.numpy(), label='Symmetric Weighting')
185 | ax.plot(gamma.numpy(), v_weighting.numpy(), label='V-Weighting (VoiceBox)')
186 | ax.set_xlabel(r'Log-SNR ($\lambda_t$)', fontsize=14)
187 | ax.set_ylabel(r'Loss Weight ($w(\lambda_t)$)', fontsize=14)
188 | ax.set_title('Loss Weighting Across Noise Levels', fontsize=14)
189 | # Create legend
190 | ax.legend(loc='upper right', ncol=1, fontsize=12)
191 | plt.savefig('viz/v_weighting_pred.png', bbox_inches='tight')
192 | plt.clf()
193 |
194 | # Save log-scale version
195 | fig, ax = plt.subplots()
196 | ax.plot(gamma.numpy(), v_weighting.numpy(), label='V Weighting')
197 | ax.plot(gamma.numpy(), v_weighting_lognormal.numpy(), label='Log-Normal V Weighting (mean=-1.0, std=2.4)')
198 |     ax.plot(gamma.numpy(), v_weighting_logcauchynormal.numpy(), label='Log-CauchyNormal V Weighting (mean=-1.0, std=2.4)')
199 | ax.set_xlabel(r'$\lambda_t$')
200 | ax.set_ylabel('Weighting (V-Space)')
201 | ax.set_title('V Weighting with pred_v Objective')
202 | ax.legend()
203 | ax.set_yscale('log')
204 | plt.savefig('viz/v_weighting_pred_v_log2.png')
205 | plt.clf()
206 |
207 | # Visualization entry point
208 | if __name__ == '__main__':
209 | plot_weighting_functions()
210 | plot_cosine_schedule()
211 |
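212 | # Usage sketch: map diffusion times t in [0, 1] to alpha^2 and shifted log-SNR
213 | # under the scaled cosine schedule defined above. The scale value here (0.5)
214 | # is illustrative rather than a training default.
215 | def example_alpha_schedule(scale=0.5):
216 |     t = torch.linspace(0., 1., 5)
217 |     schedule = partial(time_to_alpha, alpha_schedule=cosine_schedule, scale=scale)
218 |     alpha_sq = schedule(t)                                 # alpha^2_t after log-SNR shifting
219 |     log_snr = alpha_to_shifted_log_snr(alpha_sq, scale=1)  # corresponding log-SNR values
220 |     return alpha_sq, log_snr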
--------------------------------------------------------------------------------
/diffusion/optimizer.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Optional, Callable
2 |
3 | import torch
4 | from torch.optim.optimizer import Optimizer
5 | from torch.optim import AdamW
6 | import math
7 |
8 | # functions
9 |
10 | def exists(val):
11 | return val is not None
12 |
13 | # update functions
14 |
15 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
16 | # stepweight decay
17 |
18 | p.data.mul_(1 - lr * wd)
19 |
20 | # weight update
21 |
22 | update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
23 | p.add_(update, alpha = -lr)
24 |
25 | # decay the momentum running average coefficient
26 |
27 | exp_avg.lerp_(grad, 1 - beta2)
28 |
29 | # class
30 |
31 | class Lion(Optimizer):
32 | def __init__(
33 | self,
34 | params,
35 | lr: float = 1e-4,
36 | betas: Tuple[float, float] = (0.9, 0.99),
37 | weight_decay: float = 0.0,
38 | ):
39 | assert lr > 0.
40 | assert all([0. <= beta <= 1. for beta in betas])
41 |
42 | defaults = dict(
43 | lr = lr,
44 | betas = betas,
45 | weight_decay = weight_decay
46 | )
47 |
48 | super().__init__(params, defaults)
49 |
50 | self.update_fn = update_fn
51 |
52 | @torch.no_grad()
53 | def step(
54 | self,
55 | closure: Optional[Callable] = None
56 | ):
57 |
58 | loss = None
59 | if exists(closure):
60 | with torch.enable_grad():
61 | loss = closure()
62 |
63 | for group in self.param_groups:
64 | for p in filter(lambda p: exists(p.grad), group['params']):
65 |
66 | grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
67 |
68 | # init state - exponential moving average of gradient values
69 |
70 | if len(state) == 0:
71 | state['exp_avg'] = torch.zeros_like(p)
72 |
73 | exp_avg = state['exp_avg']
74 |
75 | self.update_fn(
76 | p,
77 | grad,
78 | exp_avg,
79 | lr,
80 | wd,
81 | beta1,
82 | beta2
83 | )
84 |
85 | return loss
86 |
87 | def get_corrected_weight_decay(lr, weight_decay):
88 | if weight_decay == 0:
89 | return 0
90 | corrected_weight_decay = math.exp(math.log(weight_decay) - math.log(lr))
91 | return corrected_weight_decay
92 |
93 | def separate_weight_decayable_params(params):
94 | # Exclude affine params in norms (e.g. LayerNorm, GroupNorm, etc.) and bias terms
95 | no_wd_params = [param for param in params if param.ndim < 2]
96 | wd_params = [param for param in params if param not in set(no_wd_params)]
97 | return wd_params, no_wd_params
98 |
99 | def get_adamw_optimizer(params, lr, betas, weight_decay, eps=1e-8):
100 | params = list(params)
101 | wd_params, no_wd_params = separate_weight_decayable_params(params)
102 |
103 | # Parameterize weight decay independently of learning rate
104 | corrected_weight_decay = get_corrected_weight_decay(lr, weight_decay)
105 |
106 | param_groups = [
107 | {'params': wd_params},
108 | {'params': no_wd_params, 'weight_decay': 0},
109 | ]
110 |
111 | return AdamW(param_groups, lr = lr, weight_decay = corrected_weight_decay, betas=betas, eps=eps)
112 |
113 | def get_lion_optimizer(params, lr, weight_decay):
114 | params = list(params)
115 | wd_params, no_wd_params = separate_weight_decayable_params(params)
116 |
117 |     # Parameterize weight decay independently of the learning rate
118 | corrected_weight_decay = get_corrected_weight_decay(lr, weight_decay)
119 |
120 | param_groups = [
121 | {'params': wd_params},
122 | {'params': no_wd_params, 'weight_decay': 0},
123 | ]
124 |
125 | return Lion(param_groups, lr = lr, weight_decay = corrected_weight_decay)
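126 | 
127 | # Usage sketch: build the decoupled-weight-decay AdamW optimizer for a small
128 | # module. The module and hyperparameters below are illustrative, not the
129 | # training defaults.
130 | def example_adamw():
131 |     from torch import nn
132 |     model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 16))
133 |     # norm/bias parameters (ndim < 2) land in the no-weight-decay group
134 |     return get_adamw_optimizer(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-6)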
--------------------------------------------------------------------------------
/diffusion/time_sampler.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from einops import rearrange
6 | import matplotlib.pyplot as plt
7 |
8 | class LossEMASampler(nn.Module):
9 | def __init__(self, n_bins=100, ema_decay=0.9, gamma_min=-15, gamma_max=15):
10 | super().__init__()
11 |
12 | self.n_bins = n_bins
13 | self.ema_decay = ema_decay
14 | # Register loss bins as a buffer so that it is saved with the model
15 | self.register_buffer("_loss_bins", torch.ones((n_bins), dtype=torch.float64))
16 | self.register_buffer("_unweighted_loss_bins", torch.ones((n_bins), dtype=torch.float64))
17 | gamma_range = gamma_max - gamma_min
18 | self.bin_length = gamma_range/n_bins
19 | # Register step as a buffer so that it is saved with the model
20 | self.register_buffer("step", torch.tensor(0))
21 | self.gamma_min = gamma_min
22 | self.gamma_max = gamma_max
23 |
24 | def set_bins_to_loss_weight(self, loss_weighting):
25 | gamma_range = self.gamma_max - self.gamma_min
26 | gammas = torch.arange(self.n_bins, dtype=torch.float64) * gamma_range / self.n_bins + self.gamma_min
27 | self._loss_bins = loss_weighting(gammas).to(self._loss_bins.device)
28 |
29 |
30 |     def weights(self):
31 | weights = self._loss_bins.clone()
32 | return weights
33 |
34 | def update_with_all_losses(self, gamma, losses):
35 | for i in range(self.n_bins):
36 | gamma0 = i*self.bin_length + self.gamma_min
37 | gamma1 = (i+1)*self.bin_length + self.gamma_min
38 |
39 | bin_mask = (gamma >= gamma0) & (gamma < gamma1)
40 | if bin_mask.any():
41 | self._loss_bins[i] = self.ema_decay * self._loss_bins[i] + (1-self.ema_decay) * losses[bin_mask].mean().item()
42 | self.step+=1
43 |
44 | def update_with_all_unweighted_losses(self, gamma, losses):
45 | for i in range(self.n_bins):
46 | gamma0 = i*self.bin_length + self.gamma_min
47 | gamma1 = (i+1)*self.bin_length + self.gamma_min
48 |
49 | bin_mask = (gamma >= gamma0) & (gamma < gamma1)
50 | if bin_mask.any():
51 | self._unweighted_loss_bins[i] = self.ema_decay * self._unweighted_loss_bins[i] + (1-self.ema_decay) * losses[bin_mask].mean().item()
52 | self.step+=1
53 |
54 | def sample(self, batch_size, device, uniform=False):
55 | if uniform:
56 | gamma = torch.rand((batch_size), device=device) * (self.gamma_max - self.gamma_min) + self.gamma_min
57 | density = torch.ones((batch_size), device=device)
58 | return gamma, density
59 | else:
60 | bin_weights = self.weights().to(device)
61 | bins = torch.multinomial(bin_weights, batch_size, replacement=True).to(device)
62 | samples = torch.rand((batch_size), device=device) * self.bin_length
63 | gamma = (samples + bins * self.bin_length + self.gamma_min)
64 |             # Check that all samples are in [gamma_min, gamma_max]
65 | assert (gamma >= self.gamma_min).all() and (gamma <= self.gamma_max).all()
66 | density = bin_weights[bins]
67 | return gamma, density
68 |
69 | def save_density(self, path):
70 | plt.figure(figsize=(10, 5))
71 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, self.weights().cpu().numpy())
72 | plt.xlabel("gamma")
73 | plt.ylabel("Density of Adaptive Noise Schedule")
74 | plt.grid(True) # Add grid lines
75 | plt.savefig(path)
76 | plt.close()
77 |
78 | def save_cumulative_density(self, path):
79 | plt.figure(figsize=(10, 5))
80 | weights = self.weights().cpu().numpy()
81 | weights = weights/weights.sum()
82 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, weights.cumsum())
83 | plt.xlabel("gamma")
84 | plt.ylabel("Cumulative Density of Adaptive Noise Schedule")
85 | plt.grid(True) # Add grid lines
86 | plt.savefig(path)
87 | plt.close()
88 |
89 | def save_loss_emas(self, path):
90 | plt.figure(figsize=(10, 5))
91 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, self._loss_bins.cpu().numpy())
92 | plt.xlabel("gamma")
93 | plt.ylabel("Loss EMAs")
94 | plt.grid(True) # Add grid lines
95 | plt.savefig(path)
96 | plt.close()
97 |
98 | def save_unweighted_loss_emas(self, path):
99 | plt.figure(figsize=(10, 5))
100 | plt.plot(np.arange(self.n_bins)*self.bin_length + self.gamma_min, self._unweighted_loss_bins.cpu().numpy())
101 | plt.xlabel("gamma")
102 | plt.ylabel("Unweighted Loss EMAs")
103 | plt.grid(True) # Add grid lines
104 | plt.savefig(path)
105 | plt.close()
106 |
107 | def get_loss_emas(self):
108 | # Return dict of loss emas where the key is the gamma range
109 | loss_emas = {}
110 | for i in range(self.n_bins):
111 | gamma0 = i*self.bin_length + self.gamma_min
112 | gamma1 = (i+1)*self.bin_length + self.gamma_min
113 | loss_emas[f"{gamma0:.2f}-{gamma1:.2f}"] = self._loss_bins[i].item()
114 | return loss_emas
115 |
116 | def get_unweighted_loss_emas(self):
117 | # Return dict of loss emas where the key is the gamma range
118 | loss_emas = {}
119 | for i in range(self.n_bins):
120 | gamma0 = i*self.bin_length + self.gamma_min
121 | gamma1 = (i+1)*self.bin_length + self.gamma_min
122 | loss_emas[f"{gamma0:.2f}-{gamma1:.2f}"] = self._unweighted_loss_bins[i].item()
123 | return loss_emas
124 |
125 | def get_normalized_loss_emas(self):
126 | # Return dict of loss emas where the key is the gamma range
127 | loss_emas = {}
128 | denominator = self._loss_bins.sum().item()
129 | for i in range(self.n_bins):
130 | gamma0 = i*self.bin_length + self.gamma_min
131 | gamma1 = (i+1)*self.bin_length + self.gamma_min
132 | loss_emas[f"{gamma0:.2f}-{gamma1:.2f}"] = self._loss_bins[i].item() / denominator
133 | return loss_emas
134 |
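135 | 
136 | # Usage sketch: draw noise levels (gammas) from the adaptive sampler and feed
137 | # per-example losses back into the EMA bins. The batch size and losses below
138 | # are illustrative stand-ins for real diffusion losses.
139 | def example_loss_ema_sampler(device='cpu'):
140 |     sampler = LossEMASampler(n_bins=100, ema_decay=0.9)
141 |     gamma, density = sampler.sample(batch_size=8, device=device)
142 |     losses = torch.rand(8, device=device)
143 |     sampler.update_with_all_losses(gamma, losses)
144 |     return gamma, density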
--------------------------------------------------------------------------------
/evaluation/evaluate_transcript.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from transformers import Wav2Vec2Processor, HubertForCTC, Wav2Vec2FeatureExtractor, WavLMForXVector
4 | from evaluate import load
5 | from tqdm import tqdm
6 | from einops import rearrange, reduce, repeat
7 | import sox
8 | import math
9 |
10 |
11 | def chunker(seq, size):
12 | return (seq[pos:pos + size] for pos in range(0, len(seq), size))
13 |
14 | def trim_silence(wav, sample_rate):
15 | np_arr = wav.numpy().squeeze()
16 | # create a transformer
17 | tfm = sox.Transformer()
18 | tfm.silence(location=-1, silence_threshold=.1)
19 | # transform an in-memory array and return an array
20 | y_out = tfm.build_array(input_array=np_arr, sample_rate_in=sample_rate)
21 | duration = len(y_out)/sample_rate
22 | if duration < .5:
23 | return wav
24 |
25 | return torch.tensor(y_out).unsqueeze(0)
26 |
27 |
28 | @torch.inference_mode()
29 | def compute_wer(wavpath_list, text_list, wav_list=None, model_id='facebook/hubert-large-ls960-ft', truncate=False):
30 | processor = Wav2Vec2Processor.from_pretrained(model_id)
31 | model = HubertForCTC.from_pretrained(model_id).to('cuda')
32 | model.eval()
33 |
34 | if wav_list is None:
35 | wav_list = [torchaudio.load(wavpath) for wavpath in wavpath_list]
36 | waveform_list = []
37 | sample_rate_list = []
38 | for waveform, sample_rate in wav_list:
39 | waveform_list.append(waveform)
40 | sample_rate_list.append(sample_rate)
41 |
42 |
43 | asr_text = []
44 | for i in tqdm(range(len(wav_list))):
45 | waveform = rearrange(waveform_list[i].squeeze(), 'l -> () l')
46 |         waveform = torchaudio.functional.resample(waveform, sample_rate_list[i], processor.feature_extractor.sampling_rate)
47 | if truncate:
48 | waveform = trim_silence(waveform, processor.feature_extractor.sampling_rate)
49 | input_values = processor(waveform, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").input_values.squeeze(0).to('cuda')
50 | logits = model(input_values).logits
51 | predicted_ids = torch.argmax(logits, dim=-1)
52 | transcription = [transcript.lower().strip() for transcript in processor.batch_decode(predicted_ids)]
53 |
54 | asr_text.extend(transcription)
55 |
56 | if text_list is None:
57 | return asr_text
58 | print(f'asr_text: {asr_text[:3]}')
59 | print(f'text_list: {text_list[:3]}')
60 | wer = load("wer")
61 | wer_score = wer.compute(predictions=asr_text, references=text_list)
62 | return wer_score
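63 | 
64 | 
65 | if __name__ == '__main__':
66 |     # Usage sketch: score generated samples against reference transcripts.
67 |     # The paths and transcripts below are placeholders.
68 |     wav_paths = ['samples/audio_0.wav', 'samples/audio_1.wav']
69 |     references = ['the first reference transcript', 'the second reference transcript']
70 |     print(f'WER: {compute_wer(wav_paths, references):.4f}')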
63 |
--------------------------------------------------------------------------------
/img/esd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinlovelace/SESD/c59b41d7ec98922bdb6bce4217ebcf6cd1c0d2b2/img/esd.png
--------------------------------------------------------------------------------
/img/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinlovelace/SESD/c59b41d7ec98922bdb6bce4217ebcf6cd1c0d2b2/img/results.png
--------------------------------------------------------------------------------
/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from models.modules.norm import RMSNorm, ConvRMSNorm, LayerNorm
2 | from models.modules.transformer import ConditionableTransformer
3 | from models.modules.conv import MaskedConv1d
4 | from models.modules.position import VariationalFourierFeatures, RelativePositionalEmbedding, AbsolutePositionalEmbedding
5 | from models.modules.blocks import FeedForward
--------------------------------------------------------------------------------
/models/modules/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, reduce, repeat, pack, unpack
6 | from einops.layers.torch import Rearrange
7 | import math
8 |
9 | from models.modules.norm import RMSNorm
10 |
11 | def exists(x):
12 | return x is not None
13 |
14 | def zero_init_(m):
15 | nn.init.zeros_(m.weight)
16 | if exists(m.bias):
17 | nn.init.zeros_(m.bias)
18 |
19 | class SwiGLU(nn.Module):
20 | def forward(self, x):
21 | x, gate = x.chunk(2, dim = -1)
22 | return F.silu(gate) * x
23 |
24 | class FeedForward(nn.Module):
25 | def __init__(
26 | self,
27 | dim,
28 | mult = 4,
29 | time_cond_dim = None,
30 | dropout = 0.,
31 | ):
32 | super().__init__()
33 | self.norm = RMSNorm(dim)
34 | inner_dim = int(dim * mult * 2 / 3)
35 | dim_out = dim
36 |
37 | self.time_cond = None
38 | self.dropout = nn.Dropout(dropout)
39 |
40 | if dropout > 0:
41 | self.net = nn.Sequential(
42 | nn.Linear(dim, inner_dim*2),
43 | SwiGLU(),
44 | nn.Dropout(dropout),
45 | nn.Linear(inner_dim, dim_out)
46 | )
47 | else:
48 | self.net = nn.Sequential(
49 | nn.Linear(dim, inner_dim*2),
50 | SwiGLU(),
51 | nn.Linear(inner_dim, dim_out)
52 | )
53 |
54 | if exists(time_cond_dim):
55 | self.time_cond = nn.Sequential(
56 | nn.SiLU(),
57 | nn.Linear(time_cond_dim, dim * 3),
58 | Rearrange('b d -> b 1 d')
59 | )
60 |
61 | zero_init_(self.time_cond[-2])
62 | else:
63 | zero_init_(self.net[-1])
64 |
65 |
66 | def forward(self, x, time = None):
67 | x = self.norm(x)
68 | if exists(self.time_cond):
69 | assert exists(time)
70 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2)
71 | x = (x * (scale + 1)) + shift
72 |
73 | x = self.net(x)
74 |
75 | if exists(self.time_cond):
76 | x = x*gate
77 |
78 | return x
79 |
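80 | # Usage sketch: a time-conditioned FeedForward block on a dummy sequence.
81 | # Shapes are illustrative: batch 2, length 10, width 64, time embedding 128.
82 | def example_feedforward():
83 |     ff = FeedForward(dim=64, mult=4, time_cond_dim=128, dropout=0.1)
84 |     x = torch.randn(2, 10, 64)
85 |     time = torch.randn(2, 128)
86 |     return ff(x, time=time)  # [2, 10, 64]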
--------------------------------------------------------------------------------
/models/modules/conv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import typing as tp
3 | import warnings
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from einops import rearrange, reduce, repeat
10 |
11 |
12 | class MaskedConv1d(nn.Conv1d):
13 | """Wrapper around Conv1d to provide masking functionality
14 | """
15 | def __init__(self, *args, **kwargs):
16 | super().__init__(*args, **kwargs)
17 |
18 | def forward(self, x, audio_mask=None):
19 | if audio_mask is not None:
20 | assert audio_mask.shape[-1] == x.shape[-1]
21 | conv_mask = rearrange(audio_mask, 'b l -> b () l')
22 | x = torch.where(conv_mask, x, 0)
23 | x = super().forward(x)
24 | return x
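25 | 
26 | 
27 | # Usage sketch: zero out padded frames before convolving. The shapes and the
28 | # amount of padding are illustrative.
29 | def example_masked_conv():
30 |     conv = MaskedConv1d(8, 16, kernel_size=3, padding=1)
31 |     x = torch.randn(2, 8, 20)
32 |     audio_mask = torch.zeros(2, 20, dtype=torch.bool)
33 |     audio_mask[:, :15] = True  # last 5 frames are padding
34 |     return conv(x, audio_mask=audio_mask)  # [2, 16, 20]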
--------------------------------------------------------------------------------
/models/modules/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, reduce, repeat
6 | from einops.layers.torch import Rearrange
7 |
8 |
9 | class RMSNorm(nn.Module):
10 | def __init__(self, dim):
11 | super().__init__()
12 | self.scale = dim ** 0.5
13 | self.gamma = nn.Parameter(torch.ones(dim))
14 |
15 | def forward(self, x):
16 | out = F.normalize(x, dim = -1) * self.scale * self.gamma
17 | return out
18 |
19 | class ConvRMSNorm(RMSNorm):
20 | """
21 | Convolution-friendly RMSNorm that moves channels to last dimensions
22 | before running the normalization and moves them back to original position right after.
23 | """
24 | def __init__(self, dim):
25 | super().__init__(dim)
26 |
27 | def forward(self, x):
28 | x = rearrange(x, 'b ... t -> b t ...')
29 | x = super().forward(x)
30 | x = rearrange(x, 'b t ... -> b ... t')
31 | return x
32 |
33 | # use layernorm without bias, more stable
34 |
35 | class LayerNorm(nn.Module):
36 | def __init__(self, dim):
37 | super().__init__()
38 | self.gamma = nn.Parameter(torch.ones(dim))
39 | self.register_buffer("beta", torch.zeros(dim))
40 |
41 | def forward(self, x):
42 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
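43 | 
44 | 
45 | # Usage sketch: RMSNorm over the channel dimension for sequence-first and
46 | # channel-first layouts. The shapes are illustrative.
47 | def example_norms():
48 |     x_seq = torch.randn(2, 10, 64)   # [B, L, C]
49 |     x_conv = torch.randn(2, 64, 10)  # [B, C, L]
50 |     return RMSNorm(64)(x_seq), ConvRMSNorm(64)(x_conv)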
--------------------------------------------------------------------------------
/models/modules/position.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from einops import rearrange, einsum, repeat
5 |
6 | from models.modules.blocks import FeedForward
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 | class VariationalFourierFeatures(nn.Module):
12 | """ following https://arxiv.org/abs/2107.00630 """
13 |
14 | def __init__(self, n_min=0, n_max=6, step=1):
15 | super().__init__()
16 | assert n_min <= n_max
17 | self.n_min = n_min
18 | self.n_max = n_max
19 | self.step = step
20 |
21 | def forward(self, x):
22 | # Create Base 2 Fourier features
23 | w = 2.**torch.arange(self.n_min, self.n_max+1, self.step, device = x.device, dtype = x.dtype) * 2 * math.pi
24 |
25 | if len(x.shape) == 3:
26 | w = repeat(w, 'f -> b l f', b = x.shape[0], l = x.shape[1])
27 | freqs = einsum(x, w, 'b l d, b l f -> b l d f')
28 | freqs = rearrange(freqs, 'b l d f -> b l (d f)')
29 | fouriered = torch.cat([x, freqs.sin(), freqs.cos()], dim=-1)
30 | elif len(x.shape) == 1:
31 | w = repeat(w, 'f -> l f', l = x.shape[0])
32 | freqs = einsum(x, w, 'l, l f -> l f')
33 | x = rearrange(x, 'l -> l ()')
34 | fouriered = torch.cat([x, freqs.sin(), freqs.cos()], dim=-1)
35 | return fouriered
36 |
37 | class FourierFeatureEmbedding(nn.Module):
38 | def __init__(self, dim, n_min=0, n_max=6, n_layers=3):
39 | super().__init__()
40 | self.variational_fourier_features = VariationalFourierFeatures(n_min=n_min, n_max=n_max)
41 |
42 | self.init_proj = nn.Linear((1+(n_max-n_min))*2+1, dim)
43 |
44 | self.layers = nn.ModuleList([FeedForward(dim) for _ in range(n_layers)])
45 |
46 | def forward(self, x):
47 | fourier_emb = self.init_proj(self.variational_fourier_features(x))
48 | for layer in self.layers:
49 | fourier_emb = layer(fourier_emb) + fourier_emb
50 |
51 | return fourier_emb
52 |
53 | class RelativePositionalEmbedding(nn.Module):
54 | def __init__(self, dim, n_min=0, n_max=6, n_layers=3):
55 | super().__init__()
56 | self.variational_fourier_features = VariationalFourierFeatures(n_min=n_min, n_max=n_max)
57 |
58 | self.init_proj = nn.Linear((1+(n_max-n_min))*2+1, dim)
59 |
60 | self.layers = nn.ModuleList([FeedForward(dim) for _ in range(n_layers)])
61 |
62 | def forward(self, x, attention_mask):
63 | position_indices = torch.arange(x.shape[1], device = x.device)
64 | # Need to handle masked contexts from classifier free guidance
65 | context_len = torch.sum(attention_mask, dim=-1)
66 | # Replace 0s with 1s to avoid divide by 0
67 | context_len = torch.where(context_len == 0, torch.ones_like(context_len), context_len)
68 | relative_position = repeat(position_indices, 'l -> b l', b = x.shape[0])/(context_len.unsqueeze(-1))
69 | relative_position = rearrange(relative_position, 'b l -> b l ()')
70 | relative_position_emb = self.init_proj(self.variational_fourier_features(relative_position))
71 | for layer in self.layers:
72 | relative_position_emb = layer(relative_position_emb) + relative_position_emb
73 |
74 | return relative_position_emb
75 |
76 |
77 | class AbsolutePositionalEmbedding(nn.Module):
78 | def __init__(self, dim, max_seq_len=512):
79 | super().__init__()
80 | self.scale = dim ** -0.5
81 | self.max_seq_len = max_seq_len
82 | self.emb = nn.Embedding(max_seq_len, dim)
83 | nn.init.normal_(self.emb.weight, std=.01)
84 |
85 | def forward(self, x, pos = None):
86 | seq_len, device = x.shape[1], x.device
87 | assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
88 |
89 | if not exists(pos):
90 | pos = torch.arange(seq_len, device = device)
91 |
92 | pos_emb = self.emb(pos)
93 | pos_emb = pos_emb * self.scale
94 | return pos_emb
95 |
96 | # From https://github.com/lucidrains/x-transformers/blob/c7cc22268c8ebceef55fe78343197f0af62edf18/x_transformers/x_transformers.py#L272
97 | class DynamicPositionBias(nn.Module):
98 | def __init__(self, dim, *, heads, depth=3, log_distance = False):
99 | super().__init__()
100 | assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
101 | self.log_distance = log_distance
102 |
103 | self.init_proj = nn.Linear(1, dim)
104 |
105 | self.layers = nn.ModuleList([FeedForward(dim) for _ in range(depth)])
106 |
107 | self.out = nn.Linear(dim, heads)
108 |
109 | @property
110 | def device(self):
111 | return next(self.parameters()).device
112 |
113 | def forward(self, i, j):
114 | assert i == j
115 | n, device = j, self.device
116 |
117 | # get the (n x n) matrix of distances
118 | seq_arange = torch.arange(n, device = device)
119 | context_arange = torch.arange(n, device = device)
120 | indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
121 | indices += (n - 1)
122 |
123 | # input to continuous positions MLP
124 | pos = torch.arange(-n + 1, n, device = device).float()
125 | pos = rearrange(pos, '... -> ... 1')
126 |
127 | if self.log_distance:
128 | pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
129 |
130 | pos_emb = self.init_proj(pos)
131 | for layer in self.layers:
132 | pos_emb = layer(pos_emb) + pos_emb
133 | pos_biases = self.out(pos_emb)
134 |
135 | # get position biases
136 | bias = pos_biases[indices]
137 | bias = rearrange(bias, 'i j h -> h i j')
138 | return bias
139 |
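140 | # Usage sketch: relative positional embeddings for a padded text context.
141 | # The shapes and padding pattern are illustrative.
142 | def example_relative_positions():
143 |     pos_emb = RelativePositionalEmbedding(dim=64, n_min=0, n_max=6)
144 |     context = torch.randn(2, 12, 256)                  # [B, L_text, d_lm]
145 |     attention_mask = torch.ones(2, 12, dtype=torch.bool)
146 |     attention_mask[1, 8:] = False                      # second example has 8 valid tokens
147 |     return pos_emb(context, attention_mask)            # [2, 12, 64]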
--------------------------------------------------------------------------------
/models/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | import torch.nn.functional as F
4 |
5 | from einops import rearrange, reduce, repeat, pack, unpack
6 | from einops.layers.torch import Rearrange
7 | import math
8 |
9 | from models.modules.norm import RMSNorm
10 | from models.modules.position import RelativePositionalEmbedding, DynamicPositionBias
11 | from models.modules.blocks import FeedForward
12 |
13 | def exists(x):
14 | return x is not None
15 |
16 | def zero_init_(m):
17 | nn.init.zeros_(m.weight)
18 | if exists(m.bias):
19 | nn.init.zeros_(m.bias)
20 |
21 | class Attention(nn.Module):
22 | def __init__(
23 | self,
24 | dim,
25 | dim_head = 32,
26 | time_cond_dim = None,
27 | dropout=0.
28 | ):
29 | super().__init__()
30 | assert dim % dim_head == 0, 'Dimension must be divisible by the head dimension'
31 | self.heads = dim // dim_head
32 |
33 | self.dropout = dropout
34 | self.time_cond = None
35 |
36 | self.norm = RMSNorm(dim)
37 |
38 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
39 | self.to_out = nn.Linear(dim, dim)
40 |
41 | if exists(time_cond_dim):
42 | self.time_cond = nn.Sequential(
43 | nn.SiLU(),
44 | nn.Linear(time_cond_dim, dim * 3),
45 | Rearrange('b d -> b 1 d')
46 | )
47 | zero_init_(self.time_cond[-2])
48 | else:
49 | zero_init_(self.to_out)
50 |
51 | def forward(self, x, attn_bias, time=None, audio_mask=None):
52 | b, c, n = x.shape
53 |
54 | x = self.norm(x)
55 |
56 | if exists(self.time_cond):
57 | assert exists(time)
58 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2)
59 | x = (x * (scale + 1)) + shift
60 |
61 | qkv = self.to_qkv(x).chunk(3, dim = 2)
62 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads).contiguous(), qkv)
63 |
64 | attn_bias = repeat(attn_bias, 'h i j -> b h i j', b=b)
65 |
66 | if exists(audio_mask):
67 | mask_value = -torch.finfo(q.dtype).max
68 | mask = rearrange(audio_mask, 'b l -> b () () l')
69 | attn_bias = attn_bias.masked_fill(~mask, mask_value)
70 |
71 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0., attn_mask=attn_bias)
72 | out = rearrange(out, 'b h n d -> b n (h d)')
73 |
74 | out = self.to_out(out)
75 | if exists(self.time_cond):
76 | out = out*gate
77 |
78 | return out
79 |
80 | class CrossAttention(nn.Module):
81 | def __init__(self, dim, dim_context, dim_head = 32, time_cond_dim=None, dropout=0., position_aware=False, cross_attn_pos_dim=None,):
82 | super().__init__()
83 | assert dim % dim_head == 0, 'Dimension must be divisible by the head dimension'
84 | self.heads = dim // dim_head
85 | self.dropout = dropout
86 | self.norm = RMSNorm(dim)
87 | self.time_cond = None
88 |
89 | self.position_aware = position_aware
90 |
91 | if self.position_aware:
92 | assert exists(cross_attn_pos_dim)
93 | self.pos_to_k = nn.Linear(cross_attn_pos_dim, dim, bias=False)
94 |
95 |
96 | self.norm_context = nn.LayerNorm(dim_context)
97 |
98 | self.null_kv = nn.Parameter(torch.randn(2, dim))
99 | self.to_q = nn.Linear(dim, dim, bias = False)
100 | self.to_kv = nn.Linear(dim_context, dim * 2, bias = False)
101 | self.to_out = nn.Linear(dim, dim)
102 |
103 | if exists(time_cond_dim):
104 | self.time_cond = nn.Sequential(
105 | nn.SiLU(),
106 | nn.Linear(time_cond_dim, dim * 3),
107 | Rearrange('b d -> b 1 d')
108 | )
109 | zero_init_(self.time_cond[-2])
110 | else:
111 | zero_init_(self.to_out)
112 |
113 | self.q_norm = RMSNorm(dim_head)
114 | self.k_norm = RMSNorm(dim_head)
115 |
116 | def forward(self, x, context, context_mask, time=None, context_pos=None,):
117 | '''
118 | x: [B, L_audio, d_unet]
119 | context: [B, L_text, d_lm]
120 | context_mask: [B, L_text]
121 | '''
122 | b, c, n = x.shape
123 | x = self.norm(x)
124 | if exists(self.time_cond):
125 | assert exists(time)
126 | scale, shift, gate = self.time_cond(time).chunk(3, dim = 2)
127 | x = (x * (scale + 1)) + shift
128 | context = self.norm_context(context)
129 |
130 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
131 |
132 | if self.position_aware:
133 | assert exists(context_pos)
134 | k = k + self.pos_to_k(context_pos)
135 |
136 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads).contiguous(), (q, k, v))
137 |
138 | # Null value for classifier free guidance
139 | nk, nv = map(lambda t: repeat(t, '(h d) -> b h 1 d', b = b, h=self.heads), self.null_kv.unbind(dim = -2))
140 | k = torch.cat((nk, k), dim = -2)
141 | v = torch.cat((nv, v), dim = -2)
142 |
143 | query_len = q.shape[2]
144 |
145 | # RMSNorm Trick for stability
146 | q = self.q_norm(q)
147 | k = self.k_norm(k)
148 | # Masking pad tokens
149 | context_mask = F.pad(context_mask, (1, 0), value = True)
150 | context_mask = repeat(context_mask, 'b j -> b h q_len j', h=self.heads, q_len=query_len)
151 |
152 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=context_mask, dropout_p=self.dropout if self.training else 0.)
153 | # attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
154 | # attn_weight = torch.softmax((q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))) + attn_mask, dim=-1)
155 | # out = attn_weight @ v
156 |
157 | out = rearrange(out, 'b h n d -> b n (h d)')
158 |
159 | out = self.to_out(out)
160 | if exists(self.time_cond):
161 | out = out*gate
162 |
163 | return out
164 |
165 |
166 | class ConditionableTransformer(nn.Module):
167 | def __init__(
168 | self,
169 | dim,
170 | dim_context,
171 | *,
172 | num_layers,
173 | time_cond_dim,
174 | dim_head = 64,
175 | ff_mult = 4,
176 | dropout=0.0,
177 | position_aware_cross_attention=False,
178 | num_registers=8,
179 | dense_connections=True,
180 | ):
181 | super().__init__()
182 | self.dim = dim
183 | self.num_layers = num_layers
184 |
185 | self.position_aware = position_aware_cross_attention
186 | if self.position_aware:
187 | cross_attn_pos_dim = dim//4
188 | self.context_pos_emb = RelativePositionalEmbedding(cross_attn_pos_dim, n_min=0, n_max=6)
189 | else:
190 | cross_attn_pos_dim = None
191 |
192 | self.layers = nn.ModuleList([])
193 | for _ in range(num_layers):
194 | self.layers.append(nn.ModuleList([
195 | Attention(dim = dim, dim_head = dim_head, dropout=dropout),
196 | CrossAttention(dim = dim, dim_head = dim_head, dim_context=dim_context, dropout=dropout, cross_attn_pos_dim=cross_attn_pos_dim, position_aware=position_aware_cross_attention),
197 | FeedForward(dim=dim, mult=ff_mult, time_cond_dim=time_cond_dim, dropout=dropout)
198 | ]))
199 |
200 | self.dynamic_pos_bias = DynamicPositionBias(dim = dim // 4, heads = dim // dim_head, log_distance = False, depth = 2)
201 |
202 | self.has_registers = num_registers > 0
203 | if num_registers > 0:
204 | self.memory_tokens = nn.Parameter(torch.randn(num_registers, dim))
205 |
206 | if dense_connections:
207 | assert num_layers % 2 == 0, 'number of layers must be divisible by 2 for dense connections'
208 | self.dense_blocks = nn.ModuleList([nn.Linear(dim*2, dim) for i in range(num_layers // 2)])
209 | self.dense_connections = dense_connections
210 |
211 | def forward(
212 | self,
213 | x,
214 | *,
215 | time,
216 | context,
217 | context_mask,
218 | audio_mask,
219 | ):
220 |
221 | if self.position_aware:
222 | context_pos = self.context_pos_emb(context, context_mask)
223 |
224 | else:
225 | context_pos = None
226 |
227 | if self.has_registers:
228 | mem = repeat(self.memory_tokens, 'l d -> b l d', b = x.shape[0])
229 | x, mem_packed_shape = pack((mem, x), 'b * d')
230 |
231 | mem_attn_mask = torch.ones_like(mem[:,:,0], dtype=torch.bool)
232 | audio_mask = torch.cat((mem_attn_mask, audio_mask), dim=1)
233 |
234 |
235 | i = j = x.shape[1]
236 | attn_bias = self.dynamic_pos_bias(i, j)
237 |
238 | hiddens = []
239 | for idx, (attn, cross_attn, ff) in enumerate(self.layers):
240 | if self.dense_connections:
241 | if self.has_registers:
242 | if idx < (self.num_layers // 2):
243 | # store hidden states for dense connections
244 | hiddens.append(x[:, mem.shape[1]:, :])
245 | else:
246 | concat_feats = torch.cat((x[:, mem.shape[1]:, :], hiddens.pop()), dim=-1)
247 | x[:, mem.shape[1]:, :] = self.dense_blocks[idx - (self.num_layers // 2)](concat_feats)
248 | else:
249 | if idx < (self.num_layers // 2):
250 | # store hidden states for dense connections
251 | hiddens.append(x)
252 | else:
253 | concat_feats = torch.cat((x, hiddens.pop()), dim=-1)
254 | x = self.dense_blocks[idx - (self.num_layers // 2)](concat_feats)
255 | res = x
256 | x = attn(x, attn_bias=attn_bias, audio_mask=audio_mask) + res
257 |
258 | res = x
259 | x = cross_attn(x, context = context, context_mask=context_mask, context_pos=context_pos) + res
260 |
261 | res = x
262 | x = ff(x, time=time) + res
263 |
264 | if self.has_registers:
265 | mem, x = unpack(x, mem_packed_shape, 'b * d')
266 |
267 | return x
268 |
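269 | # Usage sketch: run the conditionable transformer on dummy audio latents with
270 | # text cross-attention context. All dimensions here are illustrative.
271 | def example_conditionable_transformer():
272 |     model = ConditionableTransformer(dim=256, dim_context=768, num_layers=2, time_cond_dim=512)
273 |     x = torch.randn(2, 50, 256)                         # [B, L_audio, d]
274 |     time = torch.randn(2, 512)                          # time conditioning embedding
275 |     context = torch.randn(2, 16, 768)                   # [B, L_text, d_lm]
276 |     context_mask = torch.ones(2, 16, dtype=torch.bool)
277 |     audio_mask = torch.ones(2, 50, dtype=torch.bool)
278 |     return model(x, time=time, context=context, context_mask=context_mask, audio_mask=audio_mask)  # [2, 50, 256]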
--------------------------------------------------------------------------------
/models/transformer_wrapper.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from einops import rearrange, reduce
5 | from functools import partial
6 |
7 | from models.modules import RMSNorm, ConvRMSNorm, ConditionableTransformer, LayerNorm, MaskedConv1d
8 | # helpers functions
9 |
10 | def exists(x):
11 | return x is not None
12 |
13 | def default(val, d):
14 | if exists(val):
15 | return val
16 | return d() if callable(d) else d
17 |
18 | def zero_init_(m):
19 | nn.init.zeros_(m.weight)
20 | if exists(m.bias):
21 | nn.init.zeros_(m.bias)
22 |
23 | def masked_mean(t, *, dim, mask = None):
24 | if not exists(mask):
25 | return t.mean(dim = dim)
26 |
27 | denom = mask.sum(dim = dim, keepdim = True)
28 | mask = rearrange(mask, 'b n -> b n 1')
29 | masked_t = t.masked_fill(~mask, 0.)
30 |
31 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
32 |
33 |
34 | # sinusoidal positional embeds
35 |
36 | class SinusoidalPosEmb(nn.Module):
37 | def __init__(self, dim):
38 | super().__init__()
39 | self.dim = dim
40 |
41 | def forward(self, x):
42 | device = x.device
43 | half_dim = self.dim // 2
44 | emb = math.log(10000) / (half_dim - 1)
45 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
46 | emb = x[:, None] * emb[None, :]
47 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
48 | return emb
49 |
50 | class TransformerWrapper(nn.Module):
51 | def __init__(
52 | self,
53 | dim,
54 | text_dim,
55 | channels = 128,
56 | position_aware_cross_attention=False,
57 | inpainting_embedding = False,
58 | num_transformer_layers = 8,
59 | dropout=0.0,
60 | ):
61 | super().__init__()
62 |
63 |
64 | self.channels = channels
65 | self.out_dim = channels
66 |
67 | input_channels = channels
68 |
69 | self.init_conv = nn.Conv1d(input_channels, dim, 1)
70 |
71 |
72 | if inpainting_embedding:
73 | self.inpainting_embedding = nn.Embedding(2, dim)
74 | else:
75 | self.inpainting_embedding = None
76 |
77 | # time embeddings
78 |
79 | time_dim = dim * 2
80 |
81 | sinu_pos_emb = SinusoidalPosEmb(dim)
82 | fourier_dim = dim
83 |
84 | self.time_mlp = nn.Sequential(
85 | sinu_pos_emb,
86 | nn.Linear(fourier_dim, time_dim),
87 | nn.SiLU(),
88 | nn.Linear(time_dim, time_dim)
89 | )
90 |
91 | # layers
92 |
93 | self.transformer = ConditionableTransformer(dim, dim_context=text_dim, num_layers=num_transformer_layers, time_cond_dim=time_dim, dropout=dropout, position_aware_cross_attention=position_aware_cross_attention)
94 |
95 | self.final_conv = nn.Sequential(
96 | ConvRMSNorm(dim),
97 | nn.SiLU(),
98 | nn.Conv1d(dim, self.out_dim, 1)
99 | )
100 | zero_init_(self.final_conv[-1])
101 |
102 |
103 | self.to_text_non_attn_cond = nn.Sequential(
104 | nn.LayerNorm(text_dim),
105 | nn.Linear(text_dim, time_dim),
106 | nn.SiLU(),
107 | nn.Linear(time_dim, time_dim)
108 | )
109 |
110 | def forward(self, x, time_cond, text_cond=None, text_cond_mask=None, inpainting_mask=None, audio_mask=None):
111 | if not exists(audio_mask):
112 | audio_mask = torch.ones((x.shape[0], x.shape[2]), dtype=torch.bool, device=x.device)
113 | x = self.init_conv(x)
114 | if exists(self.inpainting_embedding):
115 | assert exists(inpainting_mask)
116 | inpainting_emb = self.inpainting_embedding(inpainting_mask)
117 | x = x + rearrange(inpainting_emb, 'b l c -> b c l')
118 |
119 | mean_pooled_context = masked_mean(text_cond, dim=1, mask=text_cond_mask)
120 | text_mean_cond = self.to_text_non_attn_cond(mean_pooled_context)
121 |
122 | # Rescale continuous time [0,1] to similar range as Ho et al. 2020
123 | t = self.time_mlp(time_cond*1000)
124 |
125 | t = t + text_mean_cond
126 |
127 | x = rearrange(x, 'b c l -> b l c')
128 |
129 | x = self.transformer(x, context=text_cond, context_mask=text_cond_mask, time=t, audio_mask=audio_mask)
130 | x = rearrange(x, 'b l c -> b c l')
131 |
132 | return self.final_conv(x)
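133 | 
134 | 
135 | # Usage sketch: run the wrapper on dummy EnCodec-style latents with dummy text
136 | # conditioning. The dimensions (512 model width, 768 text width) are illustrative.
137 | def example_transformer_wrapper():
138 |     model = TransformerWrapper(dim=512, text_dim=768, channels=128, num_transformer_layers=2)
139 |     x = torch.randn(2, 128, 100)                        # [B, C_latent, L_audio]
140 |     time_cond = torch.rand(2)                           # continuous diffusion time in [0, 1]
141 |     text_cond = torch.randn(2, 32, 768)                 # [B, L_text, d_text]
142 |     text_cond_mask = torch.ones(2, 32, dtype=torch.bool)
143 |     return model(x, time_cond, text_cond=text_cond, text_cond_mask=text_cond_mask)  # [2, 128, 100]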
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | from torch import nn, einsum
6 | import torch.nn.functional as F
7 |
8 | from einops import rearrange, reduce, repeat
9 | from einops.layers.torch import Rearrange, Reduce
10 |
11 | from models.modules import RMSNorm, ConvRMSNorm, ConditionableTransformer, LayerNorm, MaskedConv1d
12 |
13 | ENCODEC_DIM = 128
14 |
15 | # helpers functions
16 |
17 | def exists(x):
18 | return x is not None
19 |
20 | def default(val, d):
21 | if exists(val):
22 | return val
23 | return d() if callable(d) else d
24 |
25 | # small helper modules
26 |
27 | class Residual(nn.Module):
28 | def __init__(self, fn):
29 | super().__init__()
30 | self.fn = fn
31 |
32 | def forward(self, x, *args, **kwargs):
33 | return self.fn(x, *args, **kwargs) + x
34 |
35 | def Upsample(dim, dim_out = None):
36 | return nn.Sequential(
37 | nn.Upsample(scale_factor = 2, mode = 'nearest'),
38 | MaskedConv1d(dim, default(dim_out, dim), 3, padding = 1)
39 | )
40 |
41 | def Downsample(dim, dim_out = None):
42 | return MaskedConv1d(dim, default(dim_out, dim), 4, 2, 1)
43 |
44 |
45 | # sinusoidal positional embeds
46 |
47 | class SinusoidalPosEmb(nn.Module):
48 | def __init__(self, dim):
49 | super().__init__()
50 | self.dim = dim
51 |
52 | def forward(self, x):
53 | device = x.device
54 | half_dim = self.dim // 2
55 | emb = math.log(10000) / (half_dim - 1)
56 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
57 | emb = x[:, None] * emb[None, :]
58 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
59 | return emb
60 |
61 | class RandomOrLearnedSinusoidalPosEmb(nn.Module):
62 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
63 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
64 |
65 | def __init__(self, dim, is_random = False):
66 | super().__init__()
67 | assert (dim % 2) == 0
68 | half_dim = dim // 2
69 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
70 |
71 | def forward(self, x):
72 | x = rearrange(x, 'b -> b 1')
73 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
74 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
75 | fouriered = torch.cat((x, fouriered), dim = -1)
76 | return fouriered
77 | # building block modules
78 |
79 | class Block(nn.Module):
80 | def __init__(self, dim, dim_out, groups = 8, input_connection=False, input_dim=None):
81 | super().__init__()
82 | self.input_connection = input_connection
83 | if input_connection:
84 | assert exists(input_dim)
85 | self.proj = MaskedConv1d(dim + input_dim, dim_out, 3, padding = 1)
86 | else:
87 | self.proj = MaskedConv1d(dim, dim_out, 3, padding = 1)
88 | self.norm = ConvRMSNorm(dim)
89 | self.act = nn.SiLU()
90 |
91 | def forward(self, x, scale_shift = None, audio_mask=None, input_data=None):
92 | x = self.norm(x)
93 | if exists(scale_shift):
94 | scale, shift = scale_shift
95 | x = x * (scale + 1) + shift
96 | x = self.act(x)
97 | if self.input_connection:
98 | assert exists(input_data)
99 | x = torch.cat((x, input_data), dim=1)
100 |
101 | x = self.proj(x, audio_mask)
102 |
103 | return x
104 |
105 | class ResnetBlock(nn.Module):
106 | def __init__(self, dim, dim_out, *, input_connection=False, input_dim=None, time_emb_dim = None, groups = 32):
107 | super().__init__()
108 | if input_connection:
109 | assert exists(input_dim)
110 | self.input_connection = input_connection
111 | self.mlp = None
112 | if exists(time_emb_dim):
113 | self.mlp = nn.Sequential(
114 | nn.SiLU(),
115 | nn.Linear(time_emb_dim, dim_out * 3),
116 | Rearrange('b c -> b c 1')
117 | )
118 | zero_init_(self.mlp[-2])
119 |
120 | if input_connection:
121 | self.block1 = Block(dim, dim_out, groups = groups, input_connection=True, input_dim=input_dim)
122 | else:
123 | self.block1 = Block(dim, dim_out, groups = groups)
124 | self.block2 = Block(dim_out, dim_out, groups = groups)
125 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
126 |
127 | def forward(self, x, time_emb = None, audio_mask=None, input_data=None):
128 |
129 | scale_shift = None
130 | if exists(self.mlp):
131 | assert exists(time_emb)
132 | scale_shift = self.mlp(time_emb)
133 | scale, shift, gate = scale_shift.chunk(3, dim = 1)
134 | scale_shift = (scale, shift)
135 |
136 |
137 | h = self.block1(x, audio_mask=audio_mask, input_data=input_data)
138 |
139 | h = self.block2(h, scale_shift=scale_shift, audio_mask=audio_mask)
140 |
141 | if exists(self.mlp):
142 | h = h*gate
143 |
144 | return h + self.res_conv(x)
145 |
146 | class LinearAttention(nn.Module):
147 | def __init__(self, dim, heads = 4, dim_head = 32):
148 | super().__init__()
149 | self.scale = dim_head ** -0.5
150 | self.heads = heads
151 | hidden_dim = dim_head * heads
152 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False)
153 |
154 | self.to_out = nn.Conv1d(hidden_dim, dim, 1)
155 | zero_init_(self.to_out)
156 |
157 | def forward(self, x, audio_mask=None):
158 | b, c, n = x.shape
159 | qkv = self.to_qkv(x).chunk(3, dim = 1)
160 | q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv)
161 |
162 | if exists(audio_mask):
163 | mask_value = -torch.finfo(q.dtype).max
164 | mask = audio_mask[:, None, None, :]
165 | k = k.masked_fill(~mask, mask_value)
166 | v = v.masked_fill(~mask, 0.)
167 | del mask
168 |
169 | q = q.softmax(dim = -2)
170 | k = k.softmax(dim = -1)
171 |
172 | q = q * self.scale
173 |
174 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
175 |
176 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
177 | out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads)
178 | return self.to_out(out)
179 |
180 | def l2norm(t):
181 | return F.normalize(t, dim = -1)
182 |
183 |
184 | def masked_mean(t, *, dim, mask = None):
185 | if not exists(mask):
186 | return t.mean(dim = dim)
187 |
188 | denom = mask.sum(dim = dim, keepdim = True)
189 | mask = rearrange(mask, 'b n -> b n 1')
190 | masked_t = t.masked_fill(~mask, 0.)
191 |
192 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
193 |
194 | class PreNorm(nn.Module):
195 | def __init__(self, dim, fn):
196 | super().__init__()
197 | self.fn = fn
198 | self.norm = ConvRMSNorm(dim)
199 |
200 | def forward(self, x, *args, **kwargs):
201 | x = self.norm(x)
202 | return self.fn(x, *args, **kwargs)
203 |
204 | def zero_init_(m):
205 | nn.init.zeros_(m.weight)
206 | if exists(m.bias):
207 | nn.init.zeros_(m.bias)
208 |
209 | # model
210 |
211 | class Unet1D(nn.Module):
212 | def __init__(
213 | self,
214 | dim,
215 | text_dim,
216 | init_dim = None,
217 | out_dim = None,
218 | dim_mults=(1, 2, 4, 8),
219 | channels = 128,
220 | position_aware_cross_attention=False,
221 | inpainting_embedding = False,
222 | resnet_block_groups = 32,
223 | scale_skip_connection=False,
224 | num_transformer_layers = 3,
225 | dropout=0.0,
226 | num_transformer_registers=8,
227 | input_connection=False,
228 | ):
229 | super().__init__()
230 |
231 |
232 | self.channels = channels
233 |
234 | input_channels = channels
235 |
236 | init_dim = default(init_dim, dim)
237 | self.init_conv = nn.Conv1d(input_channels, init_dim, 1)
238 |
239 | dims = [init_dim, *map(lambda m: int(dim * m), dim_mults)]
240 | in_out = list(zip(dims[:-1], dims[1:]))
241 |
242 | block_klass = partial(ResnetBlock, groups = resnet_block_groups, input_connection=input_connection, input_dim=channels)
243 |
244 | if inpainting_embedding:
245 | self.inpainting_embedding = nn.Embedding(2, init_dim)
246 | else:
247 | self.inpainting_embedding = None
248 |
249 | # time embeddings
250 |
251 | time_dim = dim * 2
252 |
253 | sinu_pos_emb = SinusoidalPosEmb(dim)
254 | fourier_dim = dim
255 |
256 | self.time_mlp = nn.Sequential(
257 | sinu_pos_emb,
258 | nn.Linear(fourier_dim, time_dim),
259 | nn.SiLU(),
260 | nn.Linear(time_dim, time_dim)
261 | )
262 |
263 | # layers
264 |
265 | self.downs = nn.ModuleList([])
266 | self.ups = nn.ModuleList([])
267 | self.input_connection = input_connection
268 | num_resolutions = len(in_out)
269 |
270 | for ind, (dim_in, dim_out) in enumerate(in_out):
271 | is_last = ind >= (num_resolutions - 1)
272 |
273 | self.downs.append(nn.ModuleList([
274 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
275 | block_klass(dim_in, dim_in, time_emb_dim = time_dim),
276 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
277 | Downsample(dim_in, dim_out) if not is_last else MaskedConv1d(dim_in, dim_out, 3, padding = 1)
278 | ]))
279 |
280 | mid_dim = dims[-1]
281 | self.transformer = ConditionableTransformer(mid_dim, dim_context=text_dim, num_layers=num_transformer_layers, time_cond_dim=time_dim, dropout=dropout, position_aware_cross_attention=position_aware_cross_attention, num_registers=num_transformer_registers)
282 |
283 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
284 | is_last = ind == (len(in_out) - 1)
285 |
286 | self.ups.append(nn.ModuleList([
287 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
288 | block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
289 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
290 | Upsample(dim_out, dim_in) if not is_last else MaskedConv1d(dim_out, dim_in, 3, padding = 1)
291 | ]))
292 |
293 | self.out_dim = channels
294 |
295 | self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
296 |
297 | self.final_conv = nn.Sequential(
298 | ConvRMSNorm(dim),
299 | nn.SiLU(),
300 | nn.Conv1d(dim, self.out_dim, 1)
301 | )
302 | zero_init_(self.final_conv[-1])
303 |
304 |         # Scaling skip connections by 2**-0.5 accelerates convergence for image diffusion models
305 |         # We use it by default, but have not ablated it
306 | self.scale_skip_connection = (2 ** -0.5) if scale_skip_connection else 1
307 |
308 | self.to_text_non_attn_cond = nn.Sequential(
309 | nn.LayerNorm(text_dim),
310 | nn.Linear(text_dim, time_dim),
311 | nn.SiLU(),
312 | nn.Linear(time_dim, time_dim)
313 | )
314 |
315 | def forward(self, x, time_cond, text_cond=None, text_cond_mask=None, inpainting_mask=None, audio_mask=None):
316 |         assert x.shape[-1] % (2**(len(self.downs)-1)) == 0, f'Length of the audio latent must be a multiple of {2**(len(self.downs)-1)}'
317 | if not exists(audio_mask):
318 | audio_mask = torch.ones((x.shape[0], x.shape[2]), dtype=torch.bool, device=x.device)
319 |         assert torch.remainder(audio_mask.sum(dim=1), (2**(len(self.downs)-1))).sum().item()==0, f'Unmasked length of the audio mask must be a multiple of {2**(len(self.downs)-1)}'
320 |
321 | if self.input_connection:
322 | input_data = x.clone()
323 | # Zero out masked values
324 | input_data = input_data.masked_fill(~audio_mask[:, None, :], 0.)
325 | else:
326 | input_data = None
327 | x = self.init_conv(x)
328 | if exists(self.inpainting_embedding):
329 | assert exists(inpainting_mask)
330 | inpainting_emb = self.inpainting_embedding(inpainting_mask)
331 | x = x + rearrange(inpainting_emb, 'b l c -> b c l')
332 |
333 | r = x.clone()
334 |
335 | mean_pooled_context = masked_mean(text_cond, dim=1, mask=text_cond_mask)
336 | text_mean_cond = self.to_text_non_attn_cond(mean_pooled_context)
337 |
338 |         # Rescale continuous time in [0, 1] to a range similar to Ho et al. (2020)
339 | t = self.time_mlp(time_cond*1000)
340 |
341 | t = t + text_mean_cond
342 |
343 | h = []
344 | audio_mask_list = [audio_mask]
345 | input_data_list = [input_data]
346 | for idx, (block1, block2, attn, downsample) in enumerate(self.downs):
347 | x = block1(x, t, audio_mask=audio_mask_list[-1], input_data=input_data_list[-1])
348 | h.append(x)
349 |
350 | x = block2(x, t, audio_mask=audio_mask_list[-1], input_data=input_data_list[-1])
351 | x = attn(x, audio_mask=audio_mask_list[-1])
352 | h.append(x)
353 |
354 |
355 | x_prev_shape = x.shape
356 | x = downsample(x, audio_mask_list[-1])
357 | if x.shape[-1] != x_prev_shape[-1]:
358 | downsampled_mask = reduce(audio_mask_list[-1], 'b (l 2) -> b l', reduction='max')
359 | audio_mask_list.append(downsampled_mask)
360 | if self.input_connection:
361 |                     # Mean-pool the input data with einops to match the downsampled resolution
362 | input_data_list.append(reduce(input_data_list[-1], 'b c (l 2) -> b c l', reduction='mean'))
363 | x = rearrange(x, 'b c l -> b l c')
364 |
365 | x = self.transformer(x, context=text_cond, context_mask=text_cond_mask, time=t, audio_mask=audio_mask_list[-1])
366 | x = rearrange(x, 'b l c -> b c l')
367 |
368 | for block1, block2, attn, upsample in self.ups:
369 | x = torch.cat((x, h.pop()*(self.scale_skip_connection)), dim = 1)
370 | x = block1(x, t, audio_mask_list[-1], input_data_list[-1])
371 |
372 | x = torch.cat((x, h.pop()*(self.scale_skip_connection)), dim = 1)
373 | x = block2(x, t, audio_mask_list[-1], input_data_list[-1])
374 | x = attn(x, audio_mask_list[-1])
375 |
376 | # Awkward implementation to maintain backwards compatibility with previous checkpoints
377 | if isinstance(upsample, nn.Sequential):
378 | # Need to cast to float32 for upsampling
379 | # Upsample operation
380 | x = upsample[0](x.float())
381 | audio_mask_list.pop()
382 | # Masked conv operation
383 | x = upsample[1](x, audio_mask_list[-1])
384 |
385 | if self.input_connection:
386 | input_data_list.pop()
387 | else:
388 | x = upsample(x, audio_mask_list[-1])
389 |
390 | x = torch.cat((x, r), dim = 1)
391 |
392 | x = self.final_res_block(x, t, audio_mask, input_data)
393 | return self.final_conv(x)
--------------------------------------------------------------------------------
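
A minimal sketch of exercising the Unet1D above in isolation, assuming the helper modules (MaskedConv1d, ConvRMSNorm, ConditionableTransformer, Downsample, Upsample) behave as they are called in the forward pass; the hyperparameters and random inputs are illustrative only and differ from the training configuration in scripts/train/train.sh. The key constraint is that the latent length (and the unmasked length of any audio mask) must be a multiple of 2**(len(dim_mults)-1).

    import torch
    from models.unet import Unet1D

    # Illustrative, non-training hyperparameters: every stage shares width 256.
    model = Unet1D(dim=256, text_dim=512, dim_mults=(1, 1, 1, 1), channels=128)

    latents = torch.randn(2, 128, 256)               # [B, C, L]; L divisible by 2**(4-1) = 8
    time = torch.rand(2)                             # continuous diffusion time in [0, 1]
    text_cond = torch.randn(2, 32, 512)              # [B, T_text, text_dim] stand-in for text-encoder states
    text_mask = torch.ones(2, 32, dtype=torch.bool)

    out = model(latents, time, text_cond=text_cond, text_cond_mask=text_mask)
    print(out.shape)                                 # torch.Size([2, 128, 256]), same shape as the input latents
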
/neural_codec/encodec_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torch import nn, einsum
4 | from encodec import EncodecModel
5 | from encodec.utils import _linear_overlap_add
6 | from encodec.utils import convert_audio
7 | import typing as tp
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from einops import rearrange
12 | from functools import partial
13 | import os
14 |
15 | EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
16 |
17 | class EncodecWrapper(nn.Module):
18 | def __init__(self):
19 | super().__init__()
20 | self.codec = EncodecModel.encodec_model_24khz()
21 | self.codec.set_target_bandwidth(24.)
22 |
23 | def encode(self, x: torch.Tensor) -> torch.Tensor:
24 |         """Given a waveform tensor `x` of shape `[B, C, L]`, returns the
25 |         continuous (pre-quantization) EnCodec encoder embedding for `x`.
26 | 
27 |         The codec is used here without segmentation, so a single frame is
28 |         produced, and its rescaling factor is asserted to be `None`
29 |         (i.e. volume normalization is disabled).
30 |         """
31 | assert x.dim() == 3
32 | _, channels, length = x.shape
33 | assert channels > 0 and channels <= 2
34 | segment_length = self.codec.segment_length
35 | if segment_length is None:
36 | segment_length = length
37 | stride = length
38 | else:
39 | stride = self.codec.segment_stride # type: ignore
40 | assert stride is not None
41 |
42 | encoded_frames: tp.List[EncodedFrame] = []
43 | for offset in range(0, length, stride):
44 | frame = x[:, :, offset: offset + segment_length]
45 | encoded_frames.append(self._encode_frame(frame))
46 | assert len(encoded_frames) == 1
47 | assert encoded_frames[0][1] is None
48 | return encoded_frames[0][0]
49 |
50 | def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
51 | length = x.shape[-1]
52 | duration = length / self.codec.sample_rate
53 | assert self.codec.segment is None or duration <= 1e-5 + self.codec.segment
54 |
55 | if self.codec.normalize:
56 | mono = x.mean(dim=1, keepdim=True)
57 | volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
58 | scale = 1e-8 + volume
59 | x = x / scale
60 | scale = scale.view(-1, 1)
61 | else:
62 | scale = None
63 |
64 | emb = self.codec.encoder(x)
65 | return emb, scale
66 |
67 | def decode(self, emb: torch.Tensor, quantize: bool=True) -> torch.Tensor:
68 |         """Decode the given encoder embedding into a waveform.
69 | Note that the output might be a bit bigger than the input. In that case,
70 | any extra steps at the end can be trimmed.
71 | """
72 | encoded_frames = [(emb, None)]
73 | segment_length = self.codec.segment_length
74 | if segment_length is None:
75 | assert len(encoded_frames) == 1
76 |             return self._decode_frame(encoded_frames[0], quantize=quantize)
77 |
78 | frames = [self._decode_frame(frame, quantize=quantize) for frame in encoded_frames]
79 |         return _linear_overlap_add(frames, self.codec.segment_stride or 1)
80 |
81 | def _decode_frame(self, encoded_frame: EncodedFrame, quantize: bool=True) -> torch.Tensor:
82 | emb, scale = encoded_frame
83 | if quantize:
84 | codes = self.codec.quantizer.encode(emb, self.codec.frame_rate, self.codec.bandwidth)
85 | emb = self.codec.quantizer.decode(codes)
86 |
87 | # codes is [B, K, T], with T frames, K nb of codebooks.
88 | out = self.codec.decoder(emb)
89 | if scale is not None:
90 | out = out * scale.view(-1, 1, 1)
91 | return out
92 |
93 |
94 |     def forward(self, wav: torch.Tensor, sr: int, quantize: bool=True):
95 | # TODO: Revisit where to handle processing
96 | wav = convert_audio(wav, sr, self.codec.sample_rate, self.codec.channels)
97 | frames = self.encode(wav)
98 | return self.decode(frames, quantize=quantize)[:, :, :wav.shape[-1]]
99 |
100 |
101 |
102 | def test():
103 | def normalize_audio_latent(data_mean, data_std, audio_latent):
104 | return (audio_latent - rearrange(data_mean, 'c -> () c ()')) / rearrange(data_std, 'c -> () c ()')
105 |
106 | def unnormalize_audio_latent(data_mean, data_std, audio_latent):
107 | return audio_latent * rearrange(data_std, 'c -> () c ()') + rearrange(data_mean, 'c -> () c ()')
108 |
109 | codec = EncodecWrapper().to('cuda')
110 | import soundfile as sf
111 | from audio_datasets.librispeech import LibriSpeech, ENCODEC_SAMPLING_RATE
112 |
113 | test_dataset = LibriSpeech(split='test')
114 | # Path to saved model
115 |
116 | data = torch.load('../saved_models/')
117 | data_mean = data['model']['data_mean']
118 | data_std = data['model']['data_std']
119 |
120 | with torch.no_grad():
121 | for idx in range(len(test_dataset)):
122 |             example = test_dataset[idx]
123 |
124 | # [B, 1, L]: batch x channels x length
125 | batched_wav = example['wav'][:,:int(example['audio_duration']*ENCODEC_SAMPLING_RATE)].unsqueeze(0).to('cuda')
126 |
127 | # linspace log_snr from -15 to 15
128 | log_snrs = np.linspace(-15, 15, 121)
129 | for log_snr in log_snrs:
130 | print(f'log_snr: {log_snr.item()}')
131 | os.makedirs(f'example_audio/logsnr/{log_snr.item()}', exist_ok=True)
132 | alpha2 = torch.sigmoid(torch.tensor([log_snr], device=batched_wav.device, dtype=torch.float32))
133 |
134 | wav_emb = codec.encode(batched_wav)
135 | normalized_wav_emb = normalize_audio_latent(data_mean, data_std, wav_emb)
136 | noisy_wav_emb = alpha2.sqrt()*normalized_wav_emb + (1-alpha2).sqrt()*torch.randn_like(normalized_wav_emb)
137 | noisy_wav_emb /= alpha2.sqrt()
138 | noisy_wav_emb = unnormalize_audio_latent(data_mean, data_std, noisy_wav_emb)
139 | noisy_reconstruction = codec.decode(noisy_wav_emb)
140 |
141 | sf.write(f'example_audio/logsnr/{log_snr.item()}/audio_{idx}.wav', noisy_reconstruction.squeeze().to('cpu').numpy(), ENCODEC_SAMPLING_RATE)
142 |
143 |
144 | if __name__=='__main__':
145 | test()
--------------------------------------------------------------------------------
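
For reference, a minimal round-trip sketch of the EncodecWrapper API above. The one-second random waveform stands in for real 24 kHz mono audio, and the roughly [1, 128, 75] latent shape reflects the 24 kHz EnCodec encoder (128-dimensional embeddings at 75 frames per second); pretrained codec weights are downloaded on first use.

    import torch
    from neural_codec.encodec_wrapper import EncodecWrapper

    codec = EncodecWrapper().eval()

    wav = torch.randn(1, 1, 24000)                   # [B, channels, samples]: one second of (random) mono audio at 24 kHz
    with torch.no_grad():
        emb = codec.encode(wav)                      # continuous latent, roughly [1, 128, 75]
        recon = codec.decode(emb, quantize=True)     # decoded waveform, possibly a few samples longer than the input
    recon = recon[..., :wav.shape[-1]]               # trim the extra samples, as the decode docstring notes
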
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.24.1
2 | aiohttp==3.9.1
3 | aiosignal==1.3.1
4 | appdirs==1.4.4
5 | async-timeout==4.0.3
6 | attrs==23.1.0
7 | audioread==3.0.1
8 | beartype==0.16.4
9 | certifi==2023.11.17
10 | cffi==1.16.0
11 | charset-normalizer==3.3.2
12 | click==8.1.7
13 | contourpy==1.2.0
14 | cycler==0.12.1
15 | data==0.4
16 | datasets==2.15.0
17 | decorator==5.1.1
18 | dill==0.3.7
19 | docker-pycreds==0.4.0
20 | einops==0.7.0
21 | ema-pytorch==0.3.1
22 | encodec==0.1.1
23 | evaluate==0.4.1
24 | ffmpeg==1.4
25 | filelock==3.13.1
26 | fonttools==4.45.1
27 | frozenlist==1.4.0
28 | fsspec==2023.10.0
29 | funcsigs==1.0.2
30 | future==1.0.0
31 | gitdb==4.0.11
32 | GitPython==3.1.40
33 | huggingface-hub==0.19.4
34 | idna==3.6
35 | Jinja2==3.1.2
36 | jiwer==3.0.3
37 | joblib==1.3.2
38 | kiwisolver==1.4.5
39 | latex==0.7.0
40 | lazy_loader==0.3
41 | librosa==0.10.1
42 | llvmlite==0.41.1
43 | MarkupSafe==2.1.3
44 | matplotlib==3.8.2
45 | mpmath==1.3.0
46 | msgpack==1.0.7
47 | multidict==6.0.4
48 | multiprocess==0.70.15
49 | networkx==3.2.1
50 | numba==0.58.1
51 | numpy==1.26.2
52 | nvidia-cublas-cu12==12.1.3.1
53 | nvidia-cuda-cupti-cu12==12.1.105
54 | nvidia-cuda-nvrtc-cu12==12.1.105
55 | nvidia-cuda-runtime-cu12==12.1.105
56 | nvidia-cudnn-cu12==8.9.2.26
57 | nvidia-cufft-cu12==11.0.2.54
58 | nvidia-curand-cu12==10.3.2.106
59 | nvidia-cusolver-cu12==11.4.5.107
60 | nvidia-cusparse-cu12==12.1.0.106
61 | nvidia-nccl-cu12==2.18.1
62 | nvidia-nvjitlink-cu12==12.3.101
63 | nvidia-nvtx-cu12==12.1.105
64 | packaging==23.2
65 | pandas==2.1.3
66 | Pillow==10.1.0
67 | platformdirs==4.0.0
68 | pooch==1.8.0
69 | praatio==6.1.0
70 | protobuf==4.25.1
71 | psutil==5.9.6
72 | pyarrow==14.0.1
73 | pyarrow-hotfix==0.6
74 | pycparser==2.21
75 | pyparsing==3.1.1
76 | python-dateutil==2.8.2
77 | pytz==2023.3.post1
78 | PyYAML==6.0.1
79 | rapidfuzz==3.6.1
80 | regex==2023.10.3
81 | requests==2.31.0
82 | responses==0.18.0
83 | safetensors==0.4.0
84 | scikit-learn==1.3.2
85 | scipy==1.11.4
86 | sentry-sdk==1.37.1
87 | setproctitle==1.3.3
88 | shutilwhich==1.1.0
89 | six==1.16.0
90 | smmap==5.0.1
91 | soundfile==0.12.1
92 | sox==1.4.1
93 | soxr==0.3.7
94 | sympy==1.12
95 | tempdir==0.7.1
96 | threadpoolctl==3.2.0
97 | tokenizers==0.15.0
98 | torch==2.1.1
99 | torchaudio==2.1.1
100 | torchvision==0.16.1
101 | tqdm==4.66.1
102 | transformers==4.35.2
103 | triton==2.1.0
104 | typing_extensions==4.8.0
105 | tzdata==2023.3
106 | urllib3==2.1.0
107 | wandb==0.16.0
108 | xxhash==3.4.1
109 | yarl==1.9.3
110 |
--------------------------------------------------------------------------------
/scripts/sample/sample_16_ls_testclean.sh:
--------------------------------------------------------------------------------
1 | python train_audio_diffusion.py --eval_test --resume_dir saved_models/uvit_byt5large_12layer_final_scale50 --sampling_timesteps 250 --run_name test/sample16 --sampler ddpm --seed 42 --num_samples 16 --scale 0.5 --guidance 2.0,3.0,5.0
2 |
--------------------------------------------------------------------------------
/scripts/train/train.sh:
--------------------------------------------------------------------------------
1 | python ./train_audio_diffusion.py --dataset_name librispeech --optimizer adamw --text_encoder google/byt5-base --batch_size 64 --gradient_accumulation_steps 1 --run_name librispeech_250k --save_and_sample_every 50000 --learning_rate 2e-4 --mixed_precision bf16 --train_schedule adaptive --sampling_schedule cosine --scale .5 --loss_type l2 --scale_skip_connection --inpainting_prob 0.5 --num_train_steps 250000 --num_transformer_layers 8 --dim 512 --dim_mults 1,1,1,1 --num_samples 128 --position_aware_cross_attention --loss_weighting asymmetric_lognormal_v_weighting --loss_weighting_mean -1.0 --loss_weighting_std 2.4 --objective pred_v --inpainting_embedding --dropout 0.1 --adam_weight_decay 2e-4
--------------------------------------------------------------------------------
/train_audio_diffusion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils.utils import get_output_dir, parse_float_tuple
3 | import json
4 | import os
5 | import numpy as np
6 |
7 | from diffusion.audio_denoising_diffusion import GaussianDiffusion, Trainer
8 | from models.unet import Unet1D
9 | from models.transformer_wrapper import TransformerWrapper
10 | from transformers import AutoConfig
11 |
12 |
13 | def main(args):
14 |
15 | config = AutoConfig.from_pretrained(args.text_encoder)
16 | text_dim = config.d_model
17 |
18 | if args.model_arch == 'transformer':
19 | model = TransformerWrapper(
20 | dim=args.dim,
21 | text_dim=text_dim,
22 | channels = 128,
23 | position_aware_cross_attention=args.position_aware_cross_attention,
24 | inpainting_embedding = args.inpainting_embedding,
25 | num_transformer_layers=args.num_transformer_layers,
26 | dropout=args.dropout,
27 | )
28 | elif args.model_arch == 'unet':
29 | model = Unet1D(
30 | dim=args.dim,
31 | text_dim=text_dim,
32 | dim_mults=args.dim_mults,
33 | inpainting_embedding = args.inpainting_embedding,
34 | position_aware_cross_attention=args.position_aware_cross_attention,
35 | num_transformer_layers=args.num_transformer_layers,
36 | scale_skip_connection=args.scale_skip_connection,
37 | dropout=args.dropout,
38 | input_connection = args.input_connection,
39 | num_transformer_registers = args.num_registers,
40 | )
41 |
42 | args.trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
43 | print(f'Trainable params: {args.trainable_params}')
44 |
45 |
46 | loss_weighting_args = {
47 | 'loss_weighting_mean': args.loss_weighting_mean,
48 | 'loss_weighting_std': args.loss_weighting_std,
49 | 'loss_weighting_std_mult': args.loss_weighting_std_mult
50 | }
51 | diffusion = GaussianDiffusion(
52 | model,
53 | max_seq_len = 2048,
54 | text_encoder=args.text_encoder,
55 | sampling_timesteps = args.sampling_timesteps, # number of sampling steps
56 | sampler=args.sampler,
57 | log_var_interp = args.log_var_interp,
58 | langevin_step_size = args.langevin_step_size,
59 | snr_ddim_threshold = args.snr_ddim_threshold,
60 | train_schedule= args.train_schedule,
61 | sampling_schedule= args.sampling_schedule,
62 | loss_type = args.loss_type, # L1 or L2
63 | objective = args.objective,
64 | parameterization = args.parameterization,
65 | ema_decay = args.ema_decay,
66 | scale = args.scale,
67 | unconditional_prob = args.unconditional_prob,
68 | inpainting_prob = args.inpainting_prob,
69 | loss_weighting = args.loss_weighting,
70 | loss_weighting_args = loss_weighting_args,
71 | )
72 |
73 | trainer = Trainer(
74 | args=args,
75 | diffusion=diffusion,
76 | dataset_name=args.dataset_name,
77 | optimizer=args.optimizer,
78 | batch_size= args.batch_size,
79 | gradient_accumulate_every = args.gradient_accumulation_steps,
80 | train_lr = args.learning_rate,
81 | train_num_steps = args.num_train_steps,
82 | lr_schedule = args.lr_schedule,
83 | num_warmup_steps = args.lr_warmup_steps,
84 | adam_betas = (args.adam_beta1, args.adam_beta2),
85 | adam_weight_decay = args.adam_weight_decay,
86 | save_and_sample_every = args.save_and_sample_every,
87 | num_samples = args.num_samples,
88 | mixed_precision = args.mixed_precision,
89 | prefix_inpainting_seconds = args.prefix_inpainting_seconds,
90 | duration_path=args.test_duration_path,
91 | seed=args.seed,
92 | )
93 |
94 | if args.eval or args.eval_test:
95 | trainer.load(args.resume_dir, ckpt_step=args.ckpt_step)
96 | cls_free_guidances = args.guidance
97 | for cls_free_guidance in cls_free_guidances:
98 | trainer.sample(cls_free_guidance=cls_free_guidance, prefix_seconds=args.prefix_inpainting_seconds, test=args.eval_test, seed=42, test_duration_path=args.test_duration_path)
99 | return
100 |
101 | if args.init_model is not None:
102 | trainer.load(args.init_model, init_only=True)
103 |
104 | if args.resume:
105 | trainer.load(args.resume_dir, ckpt_step=args.ckpt_step)
106 |
107 | trainer.train()
108 |
109 | if __name__ == "__main__":
110 | parser = argparse.ArgumentParser(description="Training arguments")
111 | parser.add_argument("--dataset_name", type=str, default='librispeech')
112 | parser.add_argument("--save_dir", type=str, default="saved_models")
113 | parser.add_argument("--text_encoder", type=str, default="google/byt5-small")
114 | parser.add_argument("--output_dir", type=str, default=None)
115 | parser.add_argument("--resume_dir", type=str, default=None)
116 | parser.add_argument("--init_model", type=str, default=None)
117 | parser.add_argument("--run_name", type=str, default=None)
118 | parser.add_argument("--seed", type=int, default=None)
119 | # Architecture hyperparameters
120 | parser.add_argument("--model_arch",
121 | type=str,
122 | default="unet",
123 | choices=["unet", "transformer"],
124 | help="Choose the model architecture")
125 | parser.add_argument("--dim", type=int, default=512)
126 |     parser.add_argument('--dim_mults', type=parse_float_tuple, default=(1, 1, 1, 1.5), help='Comma-separated tuple of float values for dim_mults')
127 | parser.add_argument("--position_aware_cross_attention", action="store_true", default=False)
128 | parser.add_argument("--scale_skip_connection", action="store_true", default=False)
129 | parser.add_argument("--input_connection", action="store_true", default=False)
130 | parser.add_argument("--num_transformer_layers", type=int, default=3)
131 | parser.add_argument("--num_registers", type=int, default=8)
132 | parser.add_argument("--dropout", type=float, default=0.)
133 | parser.add_argument("--inpainting_embedding", action="store_true", default=False)
134 |
135 | # Optimization hyperparameters
136 | parser.add_argument("--optimizer", type=str, default="adamw")
137 | parser.add_argument("--batch_size", type=int, default=16)
138 | parser.add_argument("--num_train_steps", type=int, default=60000)
139 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
140 | parser.add_argument("--learning_rate", type=float, default=1e-4)
141 | parser.add_argument("--clip_grad_norm", type=float, default=1.0)
142 | parser.add_argument("--lr_schedule", type=str, default="cosine")
143 | parser.add_argument("--lr_warmup_steps", type=int, default=1000)
144 | parser.add_argument("--adam_beta1", type=float, default=0.9)
145 | parser.add_argument("--adam_beta2", type=float, default=0.999)
146 | parser.add_argument("--adam_weight_decay", type=float, default=0)
147 | parser.add_argument("--ema_decay", type=float, default=0.9999)
148 | # Diffusion Hyperparameters
149 | parser.add_argument(
150 | "--objective",
151 | type=str,
152 | default="pred_eps",
153 | choices=["pred_eps", "pred_x0", "pred_v"],
154 | help=(
155 |             "Which prediction target to use as the diffusion training objective."
156 | ),
157 | )
158 | parser.add_argument(
159 | "--parameterization",
160 | type=str,
161 | default="pred_v",
162 | choices=["pred_eps", "pred_x0", "pred_v"],
163 | help=(
164 | "Which output parameterization to use for the diffusion network."
165 | ),
166 | )
167 | parser.add_argument(
168 | "--loss_type",
169 | type=str,
170 | default="l1",
171 | choices=["l1", "l2"],
172 | help=(
173 | "Which loss function to use for diffusion."
174 | ),
175 | )
176 | parser.add_argument(
177 | "--loss_weighting",
178 | type=str,
179 | default="edm",
180 | choices=["edm", "sigmoid", "v_weighting", "lognormal_v_weighting", "asymmetric_lognormal_v_weighting", "monotonic_lognormal_v_weighting"],
181 | help=(
182 |             "Which loss weighting scheme to use for diffusion."
183 | ),
184 | )
185 | parser.add_argument(
186 | "--loss_weighting_mean",
187 | type=float,
188 | default=0.0,
189 | help=(
190 | "Mean for loss weighting function."
191 | ),
192 | )
193 | parser.add_argument(
194 | "--loss_weighting_std",
195 | type=float,
196 | default=1.0,
197 | help=(
198 | "Standard deviation for loss weighting function."
199 | ),
200 | )
201 | parser.add_argument(
202 | "--loss_weighting_std_mult",
203 | type=float,
204 | default=2.0,
205 | help=(
206 |             "Standard deviation multiplier for the loss weighting function."
207 | ),
208 | )
209 | parser.add_argument(
210 | "--train_schedule",
211 | type=str,
212 | default="cosine",
213 | choices=["beta_linear", "simple_linear", "cosine", 'sigmoid', 'adaptive'],
214 | help=(
215 |             "Which noise schedule to use during training."
216 | ),
217 | )
218 | parser.add_argument(
219 | "--sampling_schedule",
220 | type=str,
221 | default=None,
222 | choices=["beta_linear", "cosine", "simple_linear", None],
223 | help=(
224 |             "Which noise schedule to use during sampling."
225 | ),
226 | )
227 | parser.add_argument("--resume", action="store_true", default=False)
228 | parser.add_argument("--scale", type=float, default=1.0)
229 | parser.add_argument("--sampling_timesteps", type=int, default=250)
230 | parser.add_argument("--log_var_interp", type=float, default=0.2)
231 | parser.add_argument("--snr_ddim_threshold", type=float, default=None)
232 | parser.add_argument("--langevin_step_size", type=float, default=0.0)
233 | parser.add_argument("--ckpt_step", type=int, default=None)
234 | # Audio Training Parameters
235 | parser.add_argument("--unconditional_prob", type=float, default=.1)
236 | parser.add_argument("--inpainting_prob", type=float, default=.5)
237 | # Generation Arguments
238 | parser.add_argument("--save_and_sample_every", type=int, default=20000)
239 | parser.add_argument("--num_samples", type=int, default=None)
240 | parser.add_argument(
241 | "--sampler",
242 | type=str,
243 | default="ddpm",
244 | choices=["ddim", "ddpm"],
245 | help=(
246 |             "Which sampler to use for diffusion."
247 | ),
248 | )
249 | parser.add_argument("--prefix_inpainting_seconds", type=float, default=0.)
250 | # Accelerate arguments
251 | parser.add_argument(
252 | "--mixed_precision",
253 | type=str,
254 | default="no",
255 | choices=["no", "fp16", "bf16"],
256 | help=(
257 |             "Whether to use mixed precision. Choose "
258 |             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10 "
259 |             "and an Nvidia Ampere GPU."
260 | ),
261 | )
262 | # Load and eval model
263 | parser.add_argument("--eval", action="store_true", default=False)
264 | parser.add_argument("--eval_test", action="store_true", default=False)
265 | parser.add_argument("--test_duration_path", type=str, default=None)
266 |     parser.add_argument('--guidance', type=parse_float_tuple, help='Comma-separated tuple of classifier-free guidance scales to sample with')
267 |
268 | args = parser.parse_args()
269 | if args.eval or args.eval_test:
270 | assert args.resume_dir is not None
271 |
272 | if args.eval or args.eval_test:
273 | with open(os.path.join(args.resume_dir, 'args.json'), 'rt') as f:
274 | saved_args = json.load(f)
275 | args_dict = vars(args)
276 | heldout_params = {'run_name', 'output_dir', 'resume_dir', 'eval', 'eval_test', 'prefix_inpainting_seconds', 'num_samples', 'sampling_timesteps', 'sampling_schedule', 'scale', 'sampler', 'mixed_precision', 'guidance', 'ckpt_step', 'log_var_interp', 'langevin_step_size', 'snr_ddim_threshold', 'test_duration_path'}
277 | for k,v in saved_args.items():
278 | if k in heldout_params:
279 | continue
280 | args_dict[k] = v
281 |
282 | main(args)
--------------------------------------------------------------------------------
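
When --eval or --eval_test is set, the block above restores the architecture and diffusion hyperparameters saved in args.json at training time while keeping the sampling-related CLI flags listed in heldout_params. A small illustration of that merge, with made-up values:

    # Toy values only; the real keys come from args.json and the argparse namespace above.
    saved_args = {"dim": 512, "num_transformer_layers": 8, "sampler": "ddim", "scale": 1.0}
    cli_args = {"dim": 256, "num_transformer_layers": 3, "sampler": "ddpm", "scale": 0.5}
    heldout_params = {"sampler", "scale"}

    merged = dict(cli_args)
    for k, v in saved_args.items():
        if k in heldout_params:
            continue                                 # sampling-time flags keep their CLI values
        merged[k] = v                                # training-time values win for everything else

    assert merged == {"dim": 512, "num_transformer_layers": 8, "sampler": "ddpm", "scale": 0.5}
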
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | import os
4 | from pathlib import Path
5 | import argparse
6 |
7 | def compute_grad_norm(parameters):
8 | # implementation adapted from https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
9 | parameters = [p for p in parameters if p.grad is not None]
10 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), p=2) for p in parameters]), p=2).item()
11 | return total_norm
12 |
13 | def get_output_dir(args):
14 | model_dir = f'{args.dataset_name}/{args.run_name}/'
15 | output_dir = os.path.join(args.save_dir, model_dir)
16 | return output_dir
17 |
18 | def parse_float_tuple(dim_mults_str):
19 | try:
20 | dim_mults = tuple(map(float, dim_mults_str.split(',')))
21 | return dim_mults
22 | except ValueError:
23 |         raise argparse.ArgumentTypeError('Expected a comma-separated list of floats')
24 |
--------------------------------------------------------------------------------
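
parse_float_tuple is the argparse type shared by --dim_mults and --guidance in train_audio_diffusion.py; a quick usage sketch:

    from utils.utils import parse_float_tuple

    print(parse_float_tuple("1,1,1,1.5"))            # (1.0, 1.0, 1.0, 1.5), as passed to --dim_mults
    print(parse_float_tuple("2.0,3.0,5.0"))          # (2.0, 3.0, 5.0), as passed to --guidance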