├── .gitignore ├── README.md ├── combine_datasets.py ├── compare_logs.py ├── eval_diffusion.py ├── eval_vqvae.py ├── plot_log.py ├── sample_diffusion.py ├── sample_vqvae.py ├── sample_vqvae_uncond.py ├── setup.py ├── stat_compare.py ├── stat_generate.py ├── train_classifier.py ├── train_diffusion.py ├── train_enc_pred.py ├── train_vqvae.py ├── train_vqvae_add.py ├── train_vqvae_uncond.py ├── voice_search_vqvae.py └── vq_voice_swap ├── __init__.py ├── dataset.py ├── diffusion ├── __init__.py ├── diffusion.py ├── make.py └── schedule.py ├── diffusion_model.py ├── ema.py ├── logger.py ├── loss_tracker.py ├── models ├── __init__.py ├── base.py ├── classifier.py ├── conv_encoder.py ├── encoder_predictor.py ├── make.py ├── unet.py └── wavegrad.py ├── smoothing.py ├── smoothing_test.py ├── train_loop.py ├── util.py ├── vq.py └── vq_vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vq-voice-swap 2 | 3 | This ongoing project aims to use [diffusion models](https://arxiv.org/abs/2006.11239) for speech generation and speaker conversion. It includes scripts for training and evaluating diffusion models on speech datasets like [LibriSpeech](https://www.openslr.org/12). 4 | 5 | This project initially started out as an experiment in using [VQ-VAE](https://arxiv.org/abs/1711.00937) + a diffusion model for speaker conversion. The results are now quite reasonable, but I am still working on improvements. 6 | 7 | Using this codebase, you can record yourself speaking and change the voice in the recording without changing the actual content (i.e. words). See [Using Pretrained Models for Speaker Conversion](#using-pretrained-models-for-speaker-conversion) to try it out! 8 | 9 | # What's Included 10 | 11 | This codebase includes data loaders for LibriSpeech, scripts to train diffusion models, VQ-VAEs, and classifiers, and various scripts for sampling and evaluating models. 12 | 13 | You can train unconditional diffusion models on speech data using [train_diffusion.py](train_diffusion.py). The resulting diffusion models can then be sampled via [sample_diffusion.py](sample_diffusion.py). 14 | 15 | You can train speaker-conditional VQ-VAE + diffusion models using [train_vqvae.py](train_vqvae.py). The model can then be used for speaker conversion by running [sample_vqvae.py](sample_vqvae.py). 16 | 17 | You can train classifiers using [train_classifier.py](train_classifier.py). Classifiers can be used with `sample_diffusion.py` to potentially improve sample quality, as done in [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233). 18 | 19 | # Using Pretrained Models for Speaker Conversion 20 | 21 | You can download a pre-trained speaker conversion VQ-VAE [here](http://data.aqnichol.com/vq-voice-swap/vqvae-unet-mfcc/). In particular, download the file `model_ema_0.99.pt` and put it into this directory. You can then sample from this model like so, given a source audio file `input.wav` which is at least four seconds long: 22 | 23 | ``` 24 | python3 sample_vqvae.py \ 25 | --encoding ulaw \ 26 | --input-file input.wav \ 27 | --label 0 \ 28 | model_ema_0.99.pt \ 29 | output.wav 30 | ``` 31 | 32 | Here we passed `--label 0`, where `0` is the target speaker class label. 
[This JSON file](http://data.aqnichol.com/vq-voice-swap/vqvae-unet-mfcc/classes.json) shows how numerical class labels map to LibriSpeech speakers. 33 | 34 | # Evaluations for Unconditional Models 35 | 36 | For generative modeling, loss or log-likelihood does not necessarily correspond well with sample quality. To evaluate sample quality, it is preferable to use perception-aware evaluation metrics, typically by leveraging pre-trained models. 37 | 38 | I have prepared evaluation metrics similar to [Inception Score](https://github.com/openai/improved-gan/tree/master/inception_score) and [FID](https://github.com/bioinf-jku/TTUR), but for speech segments rather than images: 39 | 40 | * **Class score** is similar to [Inception Score](https://github.com/openai/improved-gan/tree/master/inception_score), and higher is better. This should be taken as a measure of individual sample quality and speaker coverage. 41 | * **Frechet score** is similar to [FID](https://github.com/bioinf-jku/TTUR), and lower is better. Frechet score measures both fidelity and diversity, or more generally how well two distributions match. 42 | 43 | These evals use a [pre-trained speaker classifier](http://data.aqnichol.com/vq-voice-swap/eval/) to extract features. For evaluating a diffusion model, you must first generate a directory of 10k samples using the `sample_diffusion.py` script. Next, download the [pre-trained classifier](http://data.aqnichol.com/vq-voice-swap/eval/model_classifier.pt), and run [stat_generate.py](stat_generate.py) on your samples to gather statistics and compute a class score. Then you can generate or [download](http://data.aqnichol.com/vq-voice-swap/eval/train_clean_360.npz) statistics for the training set, and run [stat_compare.py](stat_compare.py) to compute the Frechet score. 44 | 45 | # Results on Unconditional Generation 46 | 47 | Here are all of the models I've trained (and released), with their corresponding evals. Each model links to a directory with samples, evaluation statistics, and the model checkpoints: 48 | 49 | * [unet32](http://data.aqnichol.com/vq-voice-swap/unet32): a 10M parameter UNet model with the default noise schedule. For this model, I sampled with 50 steps using a sample-time schedule `t = s^2` where `s` is linearly spaced. 50 | * Class score: 47.1 51 | * Frechet score: 2494 52 | * [unet64](http://data.aqnichol.com/vq-voice-swap/unet64/): a 50M parameter model which is otherwise similar to the unet32, but with some learning rate annealing at the end of training. 53 | * Class score: 69.0 54 | * Frechet score: 1834 55 | * [unet64/early_stopped](http://data.aqnichol.com/vq-voice-swap/unet64/early_stopped/): like unet64, but *without* learning rate annealing. Surprisingly, the Frechet score is much better, suggesting some kind of overfitting. 56 | * Class score: 51.5 57 | * Frechet score: 855 58 | -------------------------------------------------------------------------------- /combine_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Combine two or more LibriSpeech-like datasets into one directory with a shared index 3 | and symbolic links to the sub-datasets.
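For example (the directory paths here are placeholders): python combine_datasets.py /path/to/dataset1 /path/to/dataset2 /path/to/combined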
4 | """ 5 | 6 | import argparse 7 | import json 8 | import os 9 | import sys 10 | 11 | from vq_voice_swap.dataset import LibriSpeech 12 | 13 | 14 | def main(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("directories", type=str, nargs="+") 17 | parser.add_argument("output", type=str) 18 | args = parser.parse_args() 19 | 20 | if os.path.exists(args.output): 21 | print(f"error: output directory already exists: {args.output}") 22 | sys.exit(1) 23 | os.mkdir(args.output) 24 | 25 | full_index = {} 26 | for i, subdir in enumerate(args.directories): 27 | print(f"creating dataset for {subdir}...") 28 | dataset = LibriSpeech(subdir) 29 | prefix = f"{i:02}_" 30 | full_index.update({prefix + k: v for k, v in dataset.index.items()}) 31 | for speaker_id in dataset.index.keys(): 32 | os.symlink( 33 | os.path.join(subdir, speaker_id), 34 | os.path.join(args.output, prefix + speaker_id), 35 | ) 36 | 37 | with open(os.path.join(args.output, "index.json"), "w") as f: 38 | json.dump(full_index, f) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /compare_logs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot one or more values throughout training from one or more logs, showing 3 | them on the same plot for easy comparison. 4 | 5 | Pass log keys to the `--fields` flag, for example `--fields mse base_q0`. 6 | Keys can be regular expressions, such as `base.*`, to average values. 7 | 8 | The final (plain) arguments are `log_file [log_file_1 ...] output_image`. 9 | Each log file is plotted separately, and the legend will indicate which plots 10 | are from which log file. This makes it easy to compare runs. 11 | 12 | As an example, here's how to compare three fields across two runs: 13 | 14 | python compare_logs.py --fields base_q0 label_q0 cond_q0 -- log1.txt log2.txt out.png 15 | 16 | """ 17 | 18 | import argparse 19 | import os 20 | import re 21 | 22 | import matplotlib 23 | 24 | matplotlib.use("agg") 25 | import matplotlib.pyplot as plt 26 | 27 | from vq_voice_swap.logger import read_log 28 | from vq_voice_swap.smoothing import moving_average 29 | 30 | 31 | def main(): 32 | args = arg_parser().parse_args() 33 | 34 | for filename in args.log_files: 35 | name, _ = os.path.splitext(os.path.basename(filename)) 36 | for field in args.fields: 37 | entries = [(step, field_value(x, field)) for step, x in read_log(filename)] 38 | entries = [(x, y) for x, y in entries if y is not None] 39 | xs, ys = tuple(zip(*entries)) 40 | ys = moving_average(ys, args.smoothing) 41 | plt.plot(xs, ys, label=f"{name} {field}") 42 | plt.ylim(args.min_y, args.max_y) 43 | if args.max_x is not None: 44 | plt.xlim(0, args.max_x) 45 | plt.xlabel("step") 46 | plt.ylabel("loss") 47 | plt.legend() 48 | plt.savefig(args.out_file) 49 | 50 | 51 | def field_value(log_entry, field_expr): 52 | values = [v for k, v in log_entry.items() if re.match(field_expr, k)] 53 | if len(values) == 0: 54 | return None 55 | return sum(values) / len(values) 56 | 57 | 58 | def arg_parser(): 59 | parser = argparse.ArgumentParser( 60 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 61 | ) 62 | parser.add_argument("--smoothing", type=int, default=1) 63 | parser.add_argument("--max-x", type=float, default=None) 64 | parser.add_argument("--min-y", type=float, default=0.0) 65 | parser.add_argument("--max-y", type=float, default=1.0) 66 | parser.add_argument("--fields", type=str, nargs="+", 
default="base_q.") 67 | parser.add_argument("log_files", nargs="+", type=str) 68 | parser.add_argument("out_file", type=str) 69 | return parser 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /eval_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate how well a diffusion model performs. 3 | """ 4 | 5 | import argparse 6 | 7 | import torch 8 | 9 | from vq_voice_swap.dataset import create_data_loader 10 | from vq_voice_swap.diffusion_model import DiffusionModel 11 | from vq_voice_swap.loss_tracker import LossTracker 12 | 13 | 14 | def main(): 15 | args = arg_parser().parse_args() 16 | 17 | data_loader, _ = create_data_loader( 18 | directory=args.data_dir, batch_size=args.batch_size 19 | ) 20 | 21 | print("loading model from checkpoint...") 22 | model = DiffusionModel.load(args.checkpoint_path) 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | model.to(device) 26 | 27 | tracker = LossTracker(avg_size=1_000_000) 28 | 29 | num_samples = 0 30 | for data_batch in data_loader: 31 | audio_seq = data_batch["samples"][:, None].to(device) 32 | ts = torch.rand(args.batch_size, device=device) 33 | noise = torch.randn_like(audio_seq) 34 | samples = model.diffusion.sample_q(audio_seq, ts, epsilon=noise) 35 | with torch.no_grad(): 36 | noise_pred = model.predictor(samples, ts) 37 | losses = ((noise - noise_pred) ** 2).flatten(1).mean(dim=1) 38 | 39 | tracker.add(ts, losses) 40 | log_dict = tracker.log_dict() 41 | 42 | num_samples += len(ts) 43 | 44 | msg = " ".join([f"{key}={value:.06f}" for key, value in log_dict.items()]) 45 | print(f"{num_samples} samples: {msg}") 46 | 47 | 48 | def arg_parser(): 49 | parser = argparse.ArgumentParser( 50 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 51 | ) 52 | parser.add_argument("--batch-size", type=int, default=4) 53 | parser.add_argument("checkpoint_path", type=str) 54 | parser.add_argument("data_dir", type=str) 55 | return parser 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /eval_vqvae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate how much a VQ-VAE leverages labels by measuring how much worse the 3 | loss becomes when the label is randomized. 
4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from vq_voice_swap.dataset import create_data_loader 12 | from vq_voice_swap.loss_tracker import LossTracker 13 | from vq_voice_swap.vq_vae import ConcreteVQVAE 14 | 15 | 16 | def main(): 17 | args = arg_parser().parse_args() 18 | 19 | data_loader, num_labels = create_data_loader( 20 | directory=args.data_dir, batch_size=args.batch_size 21 | ) 22 | 23 | print("loading model from checkpoint...") 24 | model = ConcreteVQVAE.load(args.checkpoint_path) 25 | assert model.num_labels == num_labels 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | model.to(device) 29 | 30 | trackers = { 31 | key: LossTracker(avg_size=1_000_000, prefix=f"{key}_") for key in ["cond"] 32 | } 33 | output_stats = [ 34 | OutputStats(module, key) for key, module in (("cond", model.cond_predictor),) 35 | ] 36 | 37 | num_samples = 0 38 | for data_batch in data_loader: 39 | audio_seq = data_batch["samples"][:, None].to(device) 40 | labels = data_batch["label"].to(device) 41 | with torch.no_grad(): 42 | losses = model.losses(audio_seq, labels) 43 | 44 | log_dict = {} 45 | for key, mses in losses["mses_dict"].items(): 46 | trackers[key].add(losses["ts"], mses) 47 | log_dict.update(trackers[key].log_dict()) 48 | for stat in output_stats: 49 | log_dict.update(stat.log_dict()) 50 | 51 | num_samples += len(labels) 52 | 53 | msg = " ".join([f"{key}={value:.06f}" for key, value in log_dict.items()]) 54 | print(f"{num_samples} samples: {msg}") 55 | 56 | 57 | class OutputStats: 58 | def __init__(self, module: nn.Module, key: str): 59 | self.module = module 60 | self.key = key 61 | self.stds = LossTracker(prefix=f"{key}_std_") 62 | 63 | def hook(_module, inputs, output): 64 | self.stds.add(inputs[1], output.flatten(1).std(dim=1)) 65 | 66 | self.module.register_forward_hook(hook) 67 | 68 | def log_dict(self): 69 | return self.stds.log_dict() 70 | 71 | 72 | def arg_parser(): 73 | parser = argparse.ArgumentParser( 74 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 75 | ) 76 | parser.add_argument("--batch-size", type=int, default=4) 77 | parser.add_argument("checkpoint_path", type=str) 78 | parser.add_argument("data_dir", type=str) 79 | return parser 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /plot_log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plot the MSE over a run from its log file. 
3 | """ 4 | 5 | import argparse 6 | 7 | import matplotlib 8 | 9 | matplotlib.use("agg") 10 | import matplotlib.pyplot as plt 11 | 12 | from vq_voice_swap.logger import read_log 13 | from vq_voice_swap.smoothing import moving_average 14 | 15 | 16 | def main(): 17 | args = arg_parser().parse_args() 18 | entries = [(step, x["loss"]) for step, x in read_log(args.log_file)] 19 | xs, ys = list(zip(*entries)) 20 | ys = moving_average(ys, args.smoothing) 21 | plt.plot(xs, ys) 22 | plt.ylim(0, args.max_y) 23 | plt.xlabel("step") 24 | plt.ylabel("loss") 25 | plt.savefig(args.out_file) 26 | 27 | 28 | def arg_parser(): 29 | parser = argparse.ArgumentParser( 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 31 | ) 32 | parser.add_argument("--smoothing", type=int, default=100) 33 | parser.add_argument("--max-y", type=float, default=1.0) 34 | parser.add_argument("log_file", type=str) 35 | parser.add_argument("out_file", type=str) 36 | return parser 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /sample_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train an unconditional diffusion model on waveforms. 3 | """ 4 | 5 | import argparse 6 | import math 7 | import os 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from tqdm.auto import tqdm 13 | 14 | from vq_voice_swap.dataset import ChunkWriter 15 | from vq_voice_swap.diffusion_model import DiffusionModel 16 | from vq_voice_swap.models import Classifier 17 | 18 | 19 | def main(): 20 | args = arg_parser().parse_args() 21 | 22 | schedule = eval(args.schedule) 23 | 24 | model = DiffusionModel.load(args.checkpoint_path) 25 | 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | model.to(device) 28 | model.eval() 29 | 30 | if args.classifier_path: 31 | classifier = Classifier.load(args.classifier_path).to(device) 32 | classifier.eval() 33 | 34 | def cond_fn(x, ts, labels=None): 35 | if labels is None: 36 | labels = sample_labels(args, classifier.num_labels, len(ts), ts.device) 37 | with torch.enable_grad(): 38 | x = x.detach().clone().requires_grad_() 39 | logits = classifier(x, ts, use_checkpoint=args.grad_checkpoint) 40 | logprobs = F.log_softmax(logits, dim=-1) 41 | grads = torch.autograd.grad(logprobs[range(len(x)), labels].sum(), x)[0] 42 | return grads.detach() * args.classifier_scale 43 | 44 | else: 45 | cond_fn = None 46 | 47 | if args.num_samples is None: 48 | generate_one_sample( 49 | args, 50 | model, 51 | device, 52 | constrain=args.constrain, 53 | cond_fn=cond_fn, 54 | schedule=schedule, 55 | ) 56 | else: 57 | generate_many_samples( 58 | args, 59 | model, 60 | device, 61 | constrain=args.constrain, 62 | cond_fn=cond_fn, 63 | schedule=schedule, 64 | ) 65 | 66 | 67 | def generate_one_sample(args, model, device, cond_fn=None, **kwargs): 68 | x_T = torch.randn(1, 1, 64000, device=device) 69 | cond_pred, cond_fn = condition_on_sampled_labels(args, model, cond_fn, 1, device) 70 | sample = model.diffusion.ddpm_sample( 71 | x_T, cond_pred, args.sample_steps, progress=True, cond_fn=cond_fn, **kwargs 72 | ) 73 | 74 | writer = ChunkWriter(args.sample_path, 16000, encoding=args.encoding) 75 | writer.write(sample.view(-1).cpu().numpy()) 76 | writer.close() 77 | 78 | 79 | def generate_many_samples(args, model, device, cond_fn=None, **kwargs): 80 | os.mkdir(args.sample_path) 81 | 82 | num_batches = int(math.ceil(args.num_samples / 
args.batch_size)) 83 | count = 0 84 | 85 | for _ in tqdm(range(num_batches)): 86 | x_T = torch.randn(args.batch_size, 1, 64000, device=device) 87 | cond_pred, cond_fn_1 = condition_on_sampled_labels( 88 | args, model, cond_fn, args.batch_size, device 89 | ) 90 | sample = model.diffusion.ddpm_sample( 91 | x_T, 92 | cond_pred, 93 | args.sample_steps, 94 | progress=False, 95 | cond_fn=cond_fn_1, 96 | **kwargs, 97 | ) 98 | for seq in sample: 99 | if count == args.num_samples: 100 | break 101 | sample_path = os.path.join(args.sample_path, f"sample_{count:06}.wav") 102 | writer = ChunkWriter(sample_path, 16000, encoding=args.encoding) 103 | writer.write(seq.view(-1).cpu().numpy()) 104 | writer.close() 105 | count += 1 106 | 107 | 108 | def condition_on_sampled_labels(args, model, cond_fn, batch_size, device): 109 | if model.num_labels is None: 110 | return model.predictor, cond_fn 111 | labels = sample_labels(args, model.num_labels, batch_size, device) 112 | if cond_fn is not None: 113 | cond_fn = partial(cond_fn, labels=labels) 114 | return partial(model.predictor, labels=labels), cond_fn 115 | 116 | 117 | def sample_labels(args, num_labels, batch_size, device): 118 | if args.target_class is not None: 119 | out = torch.tensor([args.target_class] * batch_size) 120 | else: 121 | out = torch.randint(low=0, high=num_labels, size=(batch_size,)) 122 | return out.to(dtype=torch.long, device=device) 123 | 124 | 125 | def arg_parser(): 126 | parser = argparse.ArgumentParser( 127 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 128 | ) 129 | parser.add_argument("--checkpoint-path", default="model_diffusion.pt", type=str) 130 | parser.add_argument("--sample-steps", default=100, type=int) 131 | parser.add_argument("--batch-size", default=1, type=int) 132 | parser.add_argument("--constrain", action="store_true") 133 | parser.add_argument("--sample-path", default="sample.wav", type=str) 134 | parser.add_argument("--num-samples", default=None, type=int) 135 | parser.add_argument("--grad-checkpoint", action="store_true") 136 | parser.add_argument("--classifier-path", default=None, type=str) 137 | parser.add_argument("--classifier-scale", default=1.0, type=float) 138 | parser.add_argument("--target-class", default=None, type=int) 139 | parser.add_argument("--schedule", default="lambda t: t", type=str) 140 | parser.add_argument("--encoding", default="linear", type=str) 141 | return parser 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /sample_vqvae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encode and decode a sample from a VQ-VAE. 
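Example, mirroring the README (the checkpoint is the pretrained model_ema_0.99.pt, and input.wav must be at least four seconds long): python3 sample_vqvae.py --encoding ulaw --input-file input.wav --label 0 model_ema_0.99.pt output.wav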
3 | """ 4 | 5 | import argparse 6 | 7 | import torch 8 | 9 | from vq_voice_swap.dataset import ChunkReader, ChunkWriter 10 | from vq_voice_swap.models import EncoderPredictor 11 | from vq_voice_swap.vq_vae import VQVAE 12 | 13 | 14 | def main(): 15 | args = arg_parser().parse_args() 16 | 17 | print("loading model from checkpoint...") 18 | model = VQVAE.load(args.checkpoint_path) 19 | assert args.label < model.num_labels 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | model.to(device) 23 | 24 | enc_pred = None 25 | if args.enc_pred_path: 26 | print("loading encoder predictor") 27 | enc_pred = EncoderPredictor.load(args.enc_pred_path).to(device) 28 | 29 | print(f"loading waveform from {args.input_file}...") 30 | reader = ChunkReader( 31 | args.input_file, sample_rate=args.sample_rate, encoding=args.encoding 32 | ) 33 | try: 34 | chunk = reader.read(args.seconds * args.sample_rate) 35 | finally: 36 | reader.close() 37 | in_seq = torch.from_numpy(chunk[None, None]).to(device) 38 | 39 | print("encoding audio sequence...") 40 | if args.no_vq: 41 | with torch.no_grad(): 42 | encoded = model.encoder(in_seq) 43 | else: 44 | encoded = model.encode(in_seq) 45 | 46 | print("decoding audio samples...") 47 | labels = torch.tensor([args.label]).long().to(device) 48 | sample = model.decode( 49 | encoded, 50 | labels, 51 | steps=args.sample_steps, 52 | progress=True, 53 | constrain=True, 54 | enc_pred=enc_pred, 55 | enc_pred_scale=args.enc_pred_scale, 56 | ) 57 | 58 | if args.check_vq: 59 | assert not args.no_vq 60 | encoded_1 = model.encode(sample) 61 | count = (encoded == encoded_1).float().mean() 62 | print(f"fraction of consistent VQ codes: {count}") 63 | 64 | sample = sample.clamp(-1, 1).cpu().numpy().flatten() 65 | 66 | print(f"saving result to {args.output_file}...") 67 | writer = ChunkWriter( 68 | args.output_file, sample_rate=args.sample_rate, encoding=args.encoding 69 | ) 70 | try: 71 | writer.write(sample) 72 | finally: 73 | writer.close() 74 | 75 | 76 | def arg_parser(): 77 | parser = argparse.ArgumentParser( 78 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 79 | ) 80 | parser.add_argument("--sample-rate", type=int, default=16000) 81 | parser.add_argument("--sample-steps", type=int, default=100) 82 | parser.add_argument("--seconds", type=int, default=4) 83 | parser.add_argument("--label", type=int, default=None, required=True) 84 | parser.add_argument("--input-file", type=str, default=None, required=True) 85 | parser.add_argument("--encoding", type=str, default="linear") 86 | parser.add_argument("--enc-pred-path", type=str, default=None) 87 | parser.add_argument("--enc-pred-scale", type=float, default=1.0) 88 | parser.add_argument("--no-vq", action="store_true") 89 | parser.add_argument("--check-vq", action="store_true") 90 | parser.add_argument("checkpoint_path", type=str) 91 | parser.add_argument("output_file", type=str) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /sample_vqvae_uncond.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encode and decode a sample from a VQ-VAE, where decoding uses unconditional 3 | guidance. The model should have been fine-tuned using train_vqvae_uncond.py. 
4 | """ 5 | 6 | import argparse 7 | 8 | import torch 9 | 10 | from vq_voice_swap.dataset import ChunkReader, ChunkWriter 11 | from vq_voice_swap.vq_vae import VQVAE 12 | 13 | 14 | def main(): 15 | args = arg_parser().parse_args() 16 | 17 | schedule = eval(args.schedule) 18 | 19 | print("loading model from checkpoint...") 20 | model = VQVAE.load(args.checkpoint_path) 21 | assert args.label + 1 < model.num_labels 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | model.to(device) 25 | 26 | print(f"loading waveform from {args.input_file}...") 27 | reader = ChunkReader( 28 | args.input_file, sample_rate=args.sample_rate, encoding=args.encoding 29 | ) 30 | try: 31 | chunk = reader.read(args.seconds * args.sample_rate) 32 | finally: 33 | reader.close() 34 | in_seq = torch.from_numpy(chunk[None, None]).to(device) 35 | 36 | print("encoding audio sequence...") 37 | if args.no_vq: 38 | with torch.no_grad(): 39 | encoded = model.encoder(in_seq) 40 | else: 41 | encoded = model.encode(in_seq) 42 | 43 | print("decoding audio samples...") 44 | labels = torch.tensor([args.label]).long().to(device) 45 | sample = model.decode_uncond_guidance( 46 | encoded, 47 | labels, 48 | steps=args.sample_steps, 49 | progress=True, 50 | constrain=True, 51 | label_scale=args.guide_label_scale, 52 | vq_scale=args.guide_vq_scale, 53 | schedule=schedule, 54 | ) 55 | 56 | if args.check_vq: 57 | assert not args.no_vq 58 | encoded_1 = model.encode(sample) 59 | count = (encoded == encoded_1).float().mean() 60 | print(f"fraction of consistent VQ codes: {count}") 61 | 62 | sample = sample.clamp(-1, 1).cpu().numpy().flatten() 63 | 64 | print(f"saving result to {args.output_file}...") 65 | writer = ChunkWriter( 66 | args.output_file, sample_rate=args.sample_rate, encoding=args.encoding 67 | ) 68 | try: 69 | writer.write(sample) 70 | finally: 71 | writer.close() 72 | 73 | 74 | def arg_parser(): 75 | parser = argparse.ArgumentParser( 76 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 77 | ) 78 | parser.add_argument("--sample-rate", type=int, default=16000) 79 | parser.add_argument("--sample-steps", type=int, default=100) 80 | parser.add_argument("--seconds", type=int, default=4) 81 | parser.add_argument("--label", type=int, default=None, required=True) 82 | parser.add_argument("--input-file", type=str, default=None, required=True) 83 | parser.add_argument("--encoding", type=str, default="linear") 84 | parser.add_argument("--schedule", default="lambda t: t", type=str) 85 | parser.add_argument("--guide-label-scale", type=float, default=1.0) 86 | parser.add_argument("--guide-vq-scale", type=float, default=0.0) 87 | parser.add_argument("--no-vq", action="store_true") 88 | parser.add_argument("--check-vq", action="store_true") 89 | parser.add_argument("checkpoint_path", type=str) 90 | parser.add_argument("output_file", type=str) 91 | return parser 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="vq-voice-swap", 5 | py_modules=["vq_voice_swap"], 6 | install_requires=["numpy", "torch", "torchaudio", "tqdm"], 7 | ) 8 | -------------------------------------------------------------------------------- /stat_compare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compare the statistics of two batch statistics, similar 
to the Frechet 3 | Inception Distance. 4 | """ 5 | 6 | import argparse 7 | 8 | import numpy as np 9 | from scipy import linalg 10 | 11 | 12 | def main(): 13 | args = arg_parser().parse_args() 14 | stat1 = np.load(args.stat_1) 15 | stat2 = np.load(args.stat_2) 16 | print(frechet_distance(stat1["mean"], stat1["cov"], stat2["mean"], stat2["cov"])) 17 | 18 | 19 | def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 20 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 21 | assert ( 22 | mu1.shape == mu2.shape 23 | ), "Training and test mean vectors have different lengths" 24 | assert ( 25 | sigma1.shape == sigma2.shape 26 | ), "Training and test covariances have different dimensions" 27 | 28 | diff = mu1 - mu2 29 | 30 | # product might be almost singular 31 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 32 | if not np.isfinite(covmean).all(): 33 | msg = ( 34 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 35 | % eps 36 | ) 37 | print(msg) 38 | offset = np.eye(sigma1.shape[0]) * eps 39 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 40 | 41 | # numerical error might give slight imaginary component 42 | if np.iscomplexobj(covmean): 43 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 44 | m = np.max(np.abs(covmean.imag)) 45 | raise ValueError("Imaginary component {}".format(m)) 46 | covmean = covmean.real 47 | 48 | tr_covmean = np.trace(covmean) 49 | 50 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 51 | 52 | 53 | def arg_parser(): 54 | parser = argparse.ArgumentParser( 55 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 56 | ) 57 | parser.add_argument("stat_1", type=str) 58 | parser.add_argument("stat_2", type=str) 59 | return parser 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /stat_generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate feature statistics for a batch of samples using a classifier's 3 | features. Also computes a class score similar to Inception Score. 4 | """ 5 | 6 | import argparse 7 | import multiprocessing as mp 8 | import os 9 | from typing import Iterable, Iterator, List, Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from tqdm.auto import tqdm 15 | 16 | from vq_voice_swap.dataset import (ChunkReader, create_data_loader, 17 | lookup_audio_duration) 18 | from vq_voice_swap.models import Classifier 19 | 20 | 21 | def main(): 22 | args = arg_parser().parse_args() 23 | segments = load_segments(args) 24 | 25 | classifier = Classifier.load(args.checkpoint_path) 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | classifier.to(device) 29 | 30 | features = [] 31 | probs = [] 32 | for batch in batch_segments(args.batch_size, tqdm(segments)): 33 | ts = torch.zeros(len(batch)).to(device) 34 | batch = batch.to(device) 35 | with torch.no_grad(): 36 | fv = classifier.stem(batch, ts) 37 | features.extend(fv.cpu().numpy()) 38 | probs.extend(F.softmax(classifier.out(fv), dim=-1).cpu().numpy()) 39 | 40 | features = np.stack(features, axis=0) 41 | probs = np.stack(probs, axis=0) 42 | 43 | mean = np.mean(features, axis=0) 44 | cov = np.cov(features, rowvar=False) 45 | 46 | # Based on inception score. 
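# The class score below is exp(mean_x KL(p(y|x) || p(y))), computed from the classifier's softmax probabilities.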
 47 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L49 48 | kl = probs * (np.log(probs) - np.log(np.expand_dims(np.mean(probs, 0), 0))) 49 | kl = np.mean(np.sum(kl, 1)) 50 | score = np.exp(kl) 51 | print(f"classifier score: {score}") 52 | 53 | np.savez(args.output_path, mean=mean, cov=cov, probs=probs, class_score=score) 54 | 55 | 56 | def batch_segments( 57 | batch_size: int, segs: Iterator[torch.Tensor] 58 | ) -> Iterator[torch.Tensor]: 59 | batch = [] 60 | for seg in segs: 61 | batch.append(seg) 62 | if len(batch) == batch_size: 63 | yield torch.stack(batch)[:, None] 64 | batch = [] 65 | if len(batch): 66 | yield torch.stack(batch)[:, None] 67 | 68 | 69 | def load_segments(args) -> Iterator[torch.Tensor]: 70 | if (args.data_dir is None and args.sample_dir is None) or ( 71 | args.data_dir is not None and args.sample_dir is not None 72 | ): 73 | raise ValueError( 74 | "must specify --data-dir or --sample-dir, but not both" 75 | ) 76 | if args.data_dir is not None: 77 | loader, _ = create_data_loader(args.data_dir, batch_size=1) 78 | return segments_from_loader(args.num_samples, loader) 79 | else: 80 | files = [ 81 | os.path.join(args.sample_dir, x) 82 | for x in os.listdir(args.sample_dir) 83 | if not x.startswith(".") and x.endswith(".wav") 84 | ] 85 | if args.num_samples: 86 | files = files[: args.num_samples] 87 | return segments_from_files(files) 88 | 89 | 90 | def segments_from_loader( 91 | limit: Optional[int], loader: Iterable[dict] 92 | ) -> Iterator[torch.Tensor]: 93 | i = 0 94 | for batch in loader: 95 | yield batch["samples"].view(-1) 96 | i += 1 97 | if limit and i >= limit: 98 | break 99 | 100 | 101 | def segments_from_files(files: List[str]) -> Iterator[torch.Tensor]: 102 | ctx = mp.get_context("spawn") 103 | with ctx.Pool(4) as pool: 104 | for x in pool.imap_unordered(_read_audio_file, files): 105 | yield torch.from_numpy(x) 106 | 107 | 108 | def _read_audio_file(path: str) -> np.ndarray: 109 | duration = lookup_audio_duration(path) # may not be precise 110 | cr = ChunkReader(path, sample_rate=16000) 111 | return cr.read(16000 * int(duration + 2)) 112 | 113 | 114 | def arg_parser(): 115 | parser = argparse.ArgumentParser( 116 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 117 | ) 118 | parser.add_argument("--checkpoint-path", default="model_classifier.pt", type=str) 119 | parser.add_argument("--batch-size", default=4, type=int) 120 | parser.add_argument("--num-samples", default=None, type=int) 121 | parser.add_argument("--sample-dir", default=None, type=str) 122 | parser.add_argument("--data-dir", default=None, type=str) 123 | parser.add_argument("output_path", type=str) 124 | return parser 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a voice classifier on noised inputs. 3 | """ 4 | 5 | from vq_voice_swap.train_loop import ClassifierTrainLoop 6 | 7 | 8 | def main(): 9 | ClassifierTrainLoop().loop() 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /train_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on waveforms.
3 | """ 4 | 5 | from vq_voice_swap.train_loop import DiffusionTrainLoop 6 | 7 | 8 | def main(): 9 | DiffusionTrainLoop().loop() 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /train_enc_pred.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a model to predict the latents from a VQVAE encoder from noised audio 3 | samples. 4 | """ 5 | 6 | from vq_voice_swap.train_loop import EncoderPredictorTrainLoop 7 | 8 | 9 | def main(): 10 | EncoderPredictorTrainLoop().loop() 11 | 12 | 13 | if __name__ == "__main__": 14 | main() 15 | -------------------------------------------------------------------------------- /train_vqvae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train an VQ-VAE + diffusion model on waveforms. 3 | """ 4 | 5 | from vq_voice_swap.train_loop import VQVAETrainLoop 6 | 7 | 8 | def main(): 9 | VQVAETrainLoop().loop() 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /train_vqvae_add.py: -------------------------------------------------------------------------------- 1 | """ 2 | Add classes to a pre-trained VQVAE from a new dataset. 3 | """ 4 | 5 | from vq_voice_swap.train_loop import VQVAEAddClassesTrainLoop 6 | 7 | 8 | def main(): 9 | VQVAEAddClassesTrainLoop().loop() 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /train_vqvae_uncond.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fine-tune a pre-trained VQVAE to not always take VQ codes 3 | or classes as input. Adds a new zero class that is 4 | unconditional. 5 | """ 6 | 7 | from vq_voice_swap.train_loop import VQVAEUncondTrainLoop 8 | 9 | 10 | def main(): 11 | VQVAEUncondTrainLoop().loop() 12 | 13 | 14 | if __name__ == "__main__": 15 | main() 16 | -------------------------------------------------------------------------------- /voice_search_vqvae.py: -------------------------------------------------------------------------------- 1 | """ 2 | Find the class label that minimizes the reconstruction error of an audio clip. 3 | 4 | This can be seen as searching for the voice that best matches the actual 5 | speaker of a clip. 
6 | """ 7 | 8 | import argparse 9 | 10 | import torch 11 | from tqdm.auto import tqdm 12 | 13 | from vq_voice_swap.dataset import ChunkReader 14 | from vq_voice_swap.vq_vae import VQVAE 15 | 16 | 17 | def main(): 18 | args = arg_parser().parse_args() 19 | 20 | print("loading model from checkpoint...") 21 | model = VQVAE.load(args.checkpoint_path) 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | model.to(device) 25 | 26 | print(f"loading waveform from {args.input_file}...") 27 | reader = ChunkReader( 28 | args.input_file, sample_rate=args.sample_rate, encoding=args.encoding 29 | ) 30 | try: 31 | chunk = reader.read(args.seconds * args.sample_rate) 32 | finally: 33 | reader.close() 34 | in_seq = torch.from_numpy(chunk[None, None]).to(device) 35 | 36 | print("encoding audio sequence...") 37 | encoded = model.vq.embed(model.encode(in_seq)).detach() 38 | 39 | print("evaluating all losses...") 40 | labels = ( 41 | torch.tensor( 42 | [i for i in range(model.num_labels) for _ in range(args.num_timesteps)] 43 | ) 44 | .long() 45 | .to(device) 46 | ) 47 | ts = torch.linspace( 48 | 0.0, 1.0, steps=args.num_timesteps, dtype=torch.float32, device=device 49 | ).repeat(model.num_labels) 50 | losses = ( 51 | evaluate_losses( 52 | model, in_seq, labels, ts, encoded, args.batch_size, args.num_seeds 53 | ) 54 | .reshape([-1, args.num_timesteps]) 55 | .mean(-1) 56 | .cpu() 57 | .numpy() 58 | .tolist() 59 | ) 60 | 61 | print(f"top {min(args.top_k, len(losses))} sorted losses") 62 | print("-------") 63 | id_loss = sorted(enumerate(losses), key=lambda x: x[1]) 64 | for id, loss in id_loss[: args.top_k]: 65 | print(f"{id}\t\t{loss:.6f}") 66 | 67 | 68 | def evaluate_losses( 69 | model: VQVAE, 70 | targets: torch.Tensor, 71 | labels: torch.Tensor, 72 | ts: torch.Tensor, 73 | encoded: torch.Tensor, 74 | batch_size: int, 75 | num_seeds: int, 76 | ): 77 | results = [] 78 | 79 | # Fix a noise seed for every example to reduce variance 80 | epsilons = torch.randn_like(targets[None].repeat(num_seeds, 1, 1, 1)) 81 | 82 | for i in tqdm(range(0, len(labels), batch_size)): 83 | labels_mb = labels[i : i + batch_size] 84 | ts_mb = ts[i : i + batch_size] 85 | encoded_mb = encoded.repeat(len(ts_mb), 1, 1) 86 | targets_mb = targets.repeat(len(ts_mb), 1, 1) 87 | 88 | sub_results = [] 89 | for epsilon in epsilons: 90 | epsilon_mb = epsilon.repeat(len(ts_mb), 1, 1) 91 | noised_inputs = model.diffusion.sample_q( 92 | targets_mb, ts_mb, epsilon=epsilon_mb 93 | ) 94 | with torch.no_grad(): 95 | predictions = model.predictor( 96 | noised_inputs, ts_mb, cond=encoded_mb, labels=labels_mb 97 | ) 98 | mses = ((predictions - epsilon_mb) ** 2).flatten(1).mean(1) 99 | sub_results.append(mses) 100 | 101 | results.append(torch.stack(sub_results).mean(0)) 102 | 103 | return torch.cat(results) 104 | 105 | 106 | def arg_parser(): 107 | parser = argparse.ArgumentParser( 108 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 109 | ) 110 | parser.add_argument("--sample-rate", type=int, default=16000) 111 | parser.add_argument("--seconds", type=int, default=4) 112 | parser.add_argument("--encoding", type=str, default="linear") 113 | parser.add_argument("--num-timesteps", type=int, default=16) 114 | parser.add_argument("--num-seeds", type=int, default=1) 115 | parser.add_argument("--batch-size", type=int, default=16) 116 | parser.add_argument("--top-k", type=int, default=20) 117 | parser.add_argument("--input-file", type=str, default=None, required=True) 118 | parser.add_argument("checkpoint_path", type=str) 119 
| return parser 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /vq_voice_swap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/vq-voice-swap/fffc6b8ff176a41659bba1842cb501cee493b124/vq_voice_swap/__init__.py -------------------------------------------------------------------------------- /vq_voice_swap/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | from typing import Dict, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | from torch.utils.data import DataLoader, Dataset 8 | 9 | DURATION_ESTIMATE_SLACK = 0.05 10 | 11 | 12 | def create_data_loader( 13 | directory: str, batch_size: int, encoding="linear", num_workers=4, **dataset_kwargs 14 | ) -> Tuple[DataLoader, int]: 15 | """ 16 | Create an audio data loader, either from LibriSpeech or from a small 17 | synthetic set of tones. 18 | 19 | Returned batches are dicts containing at least two keys: 20 | - 'label' (int): the speaker ID. 21 | - 'samples' (tensor): an [N x T] batch of samples. 22 | 23 | :param directory: the LibriSpeech data directory, or "tones" to use a 24 | placeholder dataset. 25 | :param batch_size: the number of samples per batch. 26 | :param encoding: the audio encoding, "linear" or "ulaw". 27 | :param num_workers: number of parallel data loading threads. 28 | :return: a pair (loader, num_labels), where loader is the DataLoader and 29 | num_labels is one greater than the maximum label index. 30 | """ 31 | if directory == "tones": 32 | dataset = ToneDataset(encoding=encoding) 33 | else: 34 | dataset = LibriSpeech(directory, encoding=encoding, **dataset_kwargs) 35 | return ( 36 | DataLoader( 37 | dataset, 38 | batch_size=batch_size, 39 | shuffle=True, 40 | num_workers=num_workers, 41 | drop_last=True, 42 | ), 43 | len(dataset.speaker_ids), 44 | ) 45 | 46 | 47 | class LibriSpeech(Dataset): 48 | def __init__( 49 | self, 50 | directory: str, 51 | encoding: str = "linear", 52 | window_duration: float = 4.0, 53 | window_spacing: float = 0.2, 54 | sample_rate: int = 16000, 55 | ): 56 | self.directory = directory 57 | self.encoding = encoding 58 | self.window_duration = window_duration 59 | self.window_spacing = window_spacing 60 | self.sample_rate = sample_rate 61 | 62 | index_path = os.path.join(self.directory, "index.json") 63 | if os.path.exists(index_path): 64 | with open(index_path, "rt") as f: 65 | self.index = json.load(f) 66 | else: 67 | self.index = _build_file_index(directory) 68 | with open(index_path, "wt") as f: 69 | json.dump(self.index, f) 70 | 71 | self.speaker_ids = sorted(self.index.keys()) 72 | self.data = [] 73 | for label, speaker_id in enumerate(self.speaker_ids): 74 | self._create_speaker_data( 75 | label, os.path.join(self.directory, speaker_id), self.index[speaker_id] 76 | ) 77 | 78 | def _create_speaker_data( 79 | self, label: int, path: str, index_dict: Dict[str, Union[Dict, float]] 80 | ): 81 | for name, item in index_dict.items(): 82 | sub_path = os.path.join(path, name) 83 | if isinstance(item, float): 84 | window_samples = int(self.sample_rate * self.window_duration) 85 | space_samples = int(self.sample_rate * self.window_spacing) 86 | total_samples = int(self.sample_rate * (item - DURATION_ESTIMATE_SLACK)) 87 | idx = 0 88 | if window_samples >= total_samples: 89 | self.data.append(LibriSpeechDatum(label, 
sub_path, 0)) 90 | else: 91 | while idx + window_samples < total_samples: 92 | self.data.append(LibriSpeechDatum(label, sub_path, idx)) 93 | idx += space_samples 94 | else: 95 | self._create_speaker_data(label, sub_path, item) 96 | 97 | def __len__(self) -> int: 98 | return len(self.data) 99 | 100 | def __getitem__(self, index) -> Dict[str, Union[int, np.ndarray]]: 101 | datum = self.data[index] 102 | reader = ChunkReader(datum.path, self.sample_rate, encoding=self.encoding) 103 | try: 104 | reader.read(datum.offset) 105 | num_samples = int(self.sample_rate * self.window_duration) 106 | samples = reader.read(num_samples) 107 | samples = np.pad(samples, (0, num_samples - len(samples))) 108 | return {"label": datum.label, "samples": samples} 109 | finally: 110 | reader.close() 111 | 112 | 113 | class LibriSpeechDataError(Exception): 114 | pass 115 | 116 | 117 | class LibriSpeechDatum: 118 | def __init__(self, label: int, path: str, offset: int): 119 | self.label = label 120 | self.path = path 121 | self.offset = offset 122 | 123 | 124 | class ToneDataset(Dataset): 125 | """ 126 | A dataset where each "speaker" is a different frequency and each sample is 127 | just a phase-shifted sinusoidal wave. 128 | """ 129 | 130 | def __init__(self, encoding: str = "linear"): 131 | self.encoding = encoding 132 | self.speaker_ids = [300, 500, 1000] 133 | 134 | def __len__(self): 135 | return len(self.speaker_ids) * 10 136 | 137 | def __getitem__(self, index) -> Dict[str, Union[int, np.ndarray]]: 138 | speaker = index % len(self.speaker_ids) 139 | frequency = self.speaker_ids[speaker] 140 | phase = (index // len(self.speaker_ids)) / 10 141 | 142 | data = np.arange(0, 64000, step=1).astype(np.float32) / 16000 143 | coeffs = (data + phase) * np.pi * 2 * frequency 144 | 145 | samples = np.sin(coeffs) 146 | samples = encode_from_linear(samples, self.encoding) 147 | 148 | return { 149 | "label": speaker, 150 | "samples": samples, 151 | } 152 | 153 | 154 | def _build_file_index(data_dir: str) -> Dict[str, Union[Dict, float]]: 155 | result = {} 156 | for item in os.listdir(data_dir): 157 | item_path = os.path.join(data_dir, item) 158 | if item.endswith(".flac") and not item.startswith("."): 159 | result[item] = lookup_audio_duration(item_path) 160 | elif os.path.isdir(item_path): 161 | sub_result = _build_file_index(item_path) 162 | if len(sub_result): 163 | result[item] = sub_result 164 | return result 165 | 166 | 167 | class ChunkReader: 168 | """ 169 | An API for reading chunks of audio samples from an audio file. 170 | 171 | :param path: the path to the audio file. 172 | :param sample_rate: the number of samples per second, used for resampling. 
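:param encoding: the encoding applied to returned samples, either "linear" or "ulaw".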
 173 | 174 | Adapted from https://github.com/unixpickle/udt-voice-swap/blob/9ab0404c3e102ec19709c2d6e9763ae629b4f897/voice_swap/data.py#L63 175 | """ 176 | 177 | def __init__(self, path: str, sample_rate: int, encoding: str = "linear"): 178 | self.path = path 179 | self.sample_rate = sample_rate 180 | self.encoding = encoding 181 | self._done = False 182 | 183 | audio_reader, audio_writer = os.pipe() 184 | try: 185 | args = [ 186 | "ffmpeg", 187 | "-i", 188 | path, 189 | "-f", 190 | "s16le", 191 | "-ar", 192 | str(sample_rate), 193 | "-ac", 194 | "1", 195 | "pipe:%i" % audio_writer, 196 | ] 197 | self._ffmpeg_proc = subprocess.Popen( 198 | args, 199 | pass_fds=(audio_writer,), 200 | stdin=subprocess.DEVNULL, 201 | stderr=subprocess.DEVNULL, 202 | stdout=subprocess.DEVNULL, 203 | ) 204 | self._audio_reader = audio_reader 205 | audio_reader = None 206 | finally: 207 | os.close(audio_writer) 208 | if audio_reader is not None: 209 | os.close(audio_reader) 210 | 211 | self._reader = os.fdopen(self._audio_reader, "rb") 212 | 213 | def read(self, chunk_size: int) -> Optional[np.ndarray]: 214 | """ 215 | Read a chunk of audio samples from the file. 216 | 217 | :param chunk_size: the number of samples to read. 218 | :return: A chunk of audio, represented as a 1-D numpy array of floats, 219 | where each sample is in the range [-1, 1]. 220 | When there are no more samples left, None is returned. 221 | """ 222 | buf = self.read_raw(chunk_size) 223 | if buf is None: 224 | return None 225 | linear = np.frombuffer(buf, dtype="int16").astype("float32") / (2 ** 15) 226 | return encode_from_linear(linear, self.encoding) 227 | 228 | def read_raw(self, chunk_size) -> Optional[bytes]: 229 | if self._done: 230 | return None 231 | buffer_size = chunk_size * 2 232 | buf = self._reader.read(buffer_size) 233 | if len(buf) < buffer_size: 234 | self._done = True 235 | if not len(buf): 236 | return None 237 | return buf 238 | 239 | def close(self): 240 | if not self._done: 241 | self._reader.close() 242 | self._ffmpeg_proc.wait() 243 | else: 244 | self._ffmpeg_proc.wait() 245 | self._reader.close() 246 | 247 | 248 | class ChunkWriter: 249 | """ 250 | An API for writing chunks of audio samples to an audio file. 251 | 252 | :param path: the path to the audio file. 253 | :param sample_rate: the number of samples per second to write. 254 | """ 255 | 256 | def __init__(self, path: str, sample_rate: int, encoding: str = "linear"): 257 | self.path = path 258 | self.sample_rate = sample_rate 259 | self.encoding = encoding 260 | 261 | audio_reader, audio_writer = os.pipe() 262 | try: 263 | audio_format = ["-ar", str(sample_rate), "-ac", "1", "-f", "s16le"] 264 | audio_params = audio_format + [ 265 | "-probesize", 266 | "32", 267 | "-thread_queue_size", 268 | "60", 269 | "-i", 270 | "pipe:%i" % audio_reader, 271 | ] 272 | output_params = [path] 273 | self._ffmpeg_proc = subprocess.Popen( 274 | ["ffmpeg", "-y", *audio_params, *output_params], 275 | pass_fds=(audio_reader,), 276 | stdin=subprocess.DEVNULL, 277 | stderr=subprocess.DEVNULL, 278 | stdout=subprocess.DEVNULL, 279 | ) 280 | self._audio_writer = audio_writer 281 | audio_writer = None 282 | finally: 283 | if audio_writer is not None: 284 | os.close(audio_writer) 285 | os.close(audio_reader) 286 | 287 | self._writer = os.fdopen(self._audio_writer, "wb", buffering=1024) 288 | 289 | def write(self, chunk: np.ndarray): 290 | """ 291 | Write a chunk of audio samples to the file.
292 | 293 | :param chunk: a chunk of samples, stored as a 1-D numpy array of floats, 294 | where each sample is in the range [-1, 1]. 295 | """ 296 | chunk = np.clip(chunk, -1, 1) 297 | chunk = decode_to_linear(chunk, self.encoding) 298 | data = bytes((chunk * (2 ** 15 - 1)).astype("int16")) 299 | self._writer.write(data) 300 | 301 | def close(self): 302 | self._writer.close() 303 | self._ffmpeg_proc.wait() 304 | 305 | 306 | def lookup_audio_duration(path: str) -> float: 307 | p = subprocess.Popen( 308 | ["ffmpeg", "-i", path], 309 | stdin=subprocess.DEVNULL, 310 | stdout=subprocess.PIPE, 311 | stderr=subprocess.PIPE, 312 | ) 313 | _, output = p.communicate() 314 | output = str(output, "utf-8") 315 | lines = [x.strip() for x in output.split("\n")] 316 | duration_lines = [x for x in lines if x.startswith("Duration:")] 317 | if len(duration_lines) != 1: 318 | raise ValueError(f"unexpected output from ffmpeg for: {path}") 319 | duration_str = duration_lines[0].split(" ")[1].split(",")[0] 320 | hours, minutes, seconds = [float(x) for x in duration_str.split(":")] 321 | return seconds + (minutes + hours * 60) * 60 322 | 323 | 324 | def encode_from_linear(x: np.ndarray, encoding: str) -> np.ndarray: 325 | if encoding == "linear": 326 | return x 327 | elif encoding == "ulaw": 328 | return encode_u_law(x) 329 | else: 330 | raise ValueError(f"unknown audio encoding: {encoding}") 331 | 332 | 333 | def decode_to_linear(x: np.ndarray, encoding: str) -> np.ndarray: 334 | if encoding == "linear": 335 | return x 336 | elif encoding == "ulaw": 337 | return decode_u_law(x) 338 | else: 339 | raise ValueError(f"unknown audio encoding: {encoding}") 340 | 341 | 342 | def encode_u_law(x: np.ndarray, mu: float = 255.0) -> np.ndarray: 343 | return np.sign(x) * (np.log(1 + mu * np.abs(x)) / np.log(1 + mu)) 344 | 345 | 346 | def decode_u_law(x, mu: float = 255.0) -> np.ndarray: 347 | return np.sign(x) * (1 / mu) * ((1 + mu) ** np.abs(x) - 1) 348 | -------------------------------------------------------------------------------- /vq_voice_swap/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import Diffusion 2 | from .make import make_schedule 3 | from .schedule import CosSchedule, ExpSchedule, Schedule 4 | 5 | __all__ = ["Diffusion", "make_schedule", "CosSchedule", "Schedule", "ExpSchedule"] 6 | -------------------------------------------------------------------------------- /vq_voice_swap/diffusion/diffusion.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | import torch 4 | from tqdm.auto import tqdm 5 | 6 | from .schedule import Schedule 7 | 8 | 9 | class Diffusion: 10 | """ 11 | A PyTorch implementation of continuous-time diffusion. 12 | """ 13 | 14 | def __init__(self, schedule: Schedule): 15 | self.schedule = schedule 16 | 17 | def sample_q( 18 | self, x_0: torch.Tensor, ts: torch.Tensor, epsilon: torch.Tensor = None 19 | ) -> torch.Tensor: 20 | """ 21 | Sample from q(x_t | x_0) for a batch of x_0. 22 | """ 23 | if epsilon is None: 24 | epsilon = torch.randn_like(x_0) 25 | alphas = broadcast_as(self.schedule(ts), x_0) 26 | return alphas.sqrt() * x_0 + (1 - alphas).sqrt() * epsilon 27 | 28 | def eps_to_x0( 29 | self, x_t: torch.Tensor, ts: torch.Tensor, epsilon_prediction: torch.Tensor 30 | ) -> torch.Tensor: 31 | """ 32 | Evaluate the mean of p(x_0 | x_t), provided the model's epsilon 33 | prediction for x_t. 
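Concretely, this returns x_0 = (x_t - sqrt(1 - alpha(t)) * eps) / sqrt(alpha(t)).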
34 | """ 35 | alphas = broadcast_as(self.schedule(ts), x_t) 36 | return (x_t - (1 - alphas).sqrt() * epsilon_prediction) * alphas.rsqrt() 37 | 38 | def x0_to_eps( 39 | self, x_t: torch.Tensor, ts: torch.Tensor, x_0: torch.Tensor 40 | ) -> torch.Tensor: 41 | """ 42 | Compute the inverse of eps_to_x0() with respect to epsilon, computing 43 | the epsilon which would have given an x_0 prediction. 44 | """ 45 | alphas = broadcast_as(self.schedule(ts), x_t) 46 | return (x_t - x_0 * alphas.sqrt()) * (1 - alphas).rsqrt() 47 | 48 | def ddpm_previous( 49 | self, 50 | x_t: torch.Tensor, 51 | ts: torch.Tensor, 52 | step: float, 53 | epsilon_prediction: torch.Tensor, 54 | noise: torch.Tensor = None, 55 | sigma_large: bool = False, 56 | constrain: bool = False, 57 | cond_fn: Callable = None, 58 | ) -> torch.Tensor: 59 | """ 60 | Sample the previous timestep using reverse diffusion. 61 | """ 62 | if noise is None: 63 | noise = torch.randn_like(x_t) 64 | alphas_t = broadcast_as(self.schedule(ts), x_t) 65 | alphas_prev = broadcast_as(self.schedule(ts - step), x_t) 66 | alphas = alphas_t / alphas_prev 67 | betas = 1 - alphas 68 | 69 | def eps_to_prev(eps): 70 | return alphas.rsqrt() * (x_t - betas * (1 - alphas_t).rsqrt() * eps) 71 | 72 | def prev_to_eps(prev): 73 | return (-prev * alphas.sqrt() + x_t) * (1 - alphas_t).sqrt() / betas 74 | 75 | if not sigma_large: 76 | sigmas = betas * (1 - alphas_prev) / (1 - alphas_t) 77 | else: 78 | sigmas = betas 79 | 80 | if cond_fn is not None: 81 | mean_pred = eps_to_prev(epsilon_prediction) 82 | mean_pred = mean_pred + sigmas * cond_fn(mean_pred, ts - step) 83 | epsilon_prediction = prev_to_eps(mean_pred) 84 | 85 | if constrain: 86 | x0 = self.eps_to_x0(x_t, ts, epsilon_prediction) 87 | x0 = (x0 - x0.mean(dim=-1, keepdim=True)).clamp(-1, 1) 88 | epsilon_prediction = self.x0_to_eps(x_t, ts, x0) 89 | 90 | return eps_to_prev(epsilon_prediction) + sigmas.sqrt() * noise 91 | 92 | def ddpm_sample( 93 | self, 94 | x_T: torch.Tensor, 95 | predictor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 96 | steps: int, 97 | progress: bool = False, 98 | sigma_large: bool = False, 99 | constrain: bool = False, 100 | cond_fn: Callable = None, 101 | schedule: Callable = None, 102 | ) -> torch.Tensor: 103 | """ 104 | Sample x_0 from x_t using reverse diffusion. 105 | """ 106 | x_t = x_T 107 | ts = [(i + 1) / steps for i in range(steps)] 108 | t_step = 1 / steps 109 | 110 | its = enumerate(ts[::-1]) 111 | if progress: 112 | its = tqdm(its) 113 | 114 | for i, t in its: 115 | ts = torch.tensor([t] * x_T.shape[0]).to(x_T) 116 | if schedule is not None: 117 | t_step = schedule(ts) - schedule(ts - 1 / steps) 118 | ts = schedule(ts) 119 | 120 | with torch.no_grad(): 121 | eps = predictor(x_t, ts) 122 | x_t = self.ddpm_previous( 123 | x_t=x_t, 124 | ts=ts, 125 | step=t_step, 126 | epsilon_prediction=eps, 127 | noise=torch.zeros_like(x_T) if i + 1 == steps else None, 128 | sigma_large=sigma_large, 129 | constrain=constrain, 130 | cond_fn=cond_fn, 131 | ) 132 | 133 | return x_t 134 | 135 | def ddpm_losses( 136 | self, 137 | x: torch.tensor, 138 | predictor: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 139 | ts: Optional[torch.Tensor] = None, 140 | noise: Optional[torch.Tensor] = None, 141 | ) -> torch.Tensor: 142 | """ 143 | Compute the DDPM loss per batch. 
144 | """ 145 | if ts is None: 146 | ts = torch.rand(len(x), device=x.device) 147 | if noise is None: 148 | noise = torch.randn_like(x) 149 | samples = self.sample_q(x, ts, epsilon=noise) 150 | noise_pred = predictor(samples, ts) 151 | return ((noise - noise_pred) ** 2).flatten(1).mean(dim=1) 152 | 153 | 154 | def broadcast_as(ts: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor: 155 | while len(ts.shape) < len(tensor.shape): 156 | ts = ts[:, None] 157 | return ts.to(tensor) + torch.zeros_like(tensor) 158 | -------------------------------------------------------------------------------- /vq_voice_swap/diffusion/make.py: -------------------------------------------------------------------------------- 1 | from .schedule import CosSchedule, ExpSchedule, Schedule 2 | 3 | 4 | def make_schedule(name: str) -> Schedule: 5 | """ 6 | Create a schedule from a human-readable name. 7 | """ 8 | if name == "exp": 9 | return ExpSchedule() 10 | elif name == "cos": 11 | return CosSchedule() 12 | else: 13 | raise ValueError(f"unknown schedule: {name}") 14 | -------------------------------------------------------------------------------- /vq_voice_swap/diffusion/schedule.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | 6 | 7 | class Schedule(ABC): 8 | @abstractmethod 9 | def __call__(self, t: torch.Tensor) -> torch.Tensor: 10 | """ 11 | Evaluate alpha for the noise schedule to a timestep in [0, 1]. 12 | """ 13 | 14 | 15 | class ExpSchedule(Schedule): 16 | """ 17 | A noise schedule defined as exp(-k*t^2), which is nearly equivalent to 18 | using betas linearly interpolated from a tiny value to a larger value. 19 | """ 20 | 21 | def __init__(self, alpha_final: float = 1e-5): 22 | super().__init__() 23 | self.alpha_final = alpha_final 24 | 25 | # alpha(t) = exp(-k*t^2) 26 | # alpha(1.0) = exp(-k) 27 | # k = -ln(alpha(1.0)) 28 | self.k = -math.log(alpha_final) 29 | 30 | def __call__(self, t: torch.Tensor) -> torch.Tensor: 31 | return torch.exp(-self.k * (t ** 2)) 32 | 33 | 34 | class CosSchedule(Schedule): 35 | """ 36 | The squared cosine schedule cos(t*pi/2)^2, introduced by 37 | https://arxiv.org/abs/2102.09672. 38 | """ 39 | 40 | def __call__(self, t: torch.Tensor) -> torch.Tensor: 41 | return torch.cos(t * math.pi / 2) ** 2 42 | -------------------------------------------------------------------------------- /vq_voice_swap/diffusion_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | 5 | from .diffusion import Diffusion, make_schedule 6 | from .models import Savable, make_predictor 7 | 8 | 9 | class DiffusionModel(Savable): 10 | """ 11 | A diffusion model and its corresponding diffusion process. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | pred_name: str, 17 | base_channels: int, 18 | schedule_name: str = "exp", 19 | num_labels: Optional[int] = None, 20 | cond_channels: Optional[int] = None, 21 | dropout: float = 0.0, 22 | ): 23 | super().__init__() 24 | self.pred_name = pred_name 25 | self.base_channels = base_channels 26 | self.schedule_name = schedule_name 27 | self.num_labels = num_labels 28 | self.cond_channels = cond_channels 29 | 30 | # Fix bug in some checkpoints where dropout is a tuple. 
31 | self.dropout = dropout[0] if isinstance(dropout, tuple) else dropout 32 | 33 | self.predictor = make_predictor( 34 | pred_name, 35 | base_channels=base_channels, 36 | cond_channels=cond_channels, 37 | num_labels=num_labels, 38 | dropout=self.dropout, 39 | ) 40 | self.diffusion = Diffusion(make_schedule(schedule_name)) 41 | 42 | def forward(self, *args, **kwargs) -> torch.Tensor: 43 | return self.predictor(*args, **kwargs) 44 | 45 | def add_labels(self, n: int, end: bool = True): 46 | assert self.num_labels is not None, "model must be class-conditional" 47 | self.predictor.add_labels(n, end=end) 48 | self.num_labels += n 49 | 50 | def save_kwargs(self) -> Dict[str, Any]: 51 | return dict( 52 | pred_name=self.pred_name, 53 | base_channels=self.base_channels, 54 | schedule_name=self.schedule_name, 55 | num_labels=self.num_labels, 56 | cond_channels=self.cond_channels, 57 | dropout=self.dropout, 58 | ) 59 | -------------------------------------------------------------------------------- /vq_voice_swap/ema.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ModelEMA: 9 | """ 10 | An exponential moving average of model parameters. 11 | 12 | :param source_model: the non-EMA model whose parameters are copied. 13 | :param rates: a dict mapping parameter names (or prefixes of names) 14 | to EMA rates. Any parameter not explicitly named in this 15 | dict will use the rate from its longest prefix in the dict. 16 | """ 17 | 18 | def __init__(self, source_model: nn.Module, rates: Dict[str, float]): 19 | self.source_model = source_model 20 | self.rates = rates 21 | self.model = copy.deepcopy(source_model) 22 | 23 | def update(self): 24 | """ 25 | Update the EMA parameters based on the current source parameters. 26 | """ 27 | for (name, source), target in zip( 28 | self.source_model.named_parameters(), self.model.parameters() 29 | ): 30 | rate = 1 - lookup_longest_prefix(self.rates, name) 31 | with torch.no_grad(): 32 | target.add_(rate * (source - target)) 33 | 34 | 35 | def lookup_longest_prefix(values: Dict[str, float], name: str) -> float: 36 | longest = None 37 | for k in values.keys(): 38 | if name.startswith(k) and (longest is None or len(k) > len(longest)): 39 | longest = k 40 | if longest is None: 41 | raise KeyError(f"no rate prefix found for parameter: {name}") 42 | return values[longest] 43 | -------------------------------------------------------------------------------- /vq_voice_swap/logger.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterator, TextIO, Tuple, Union 2 | 3 | # The log line indicating that a checkpoint was saved. 4 | SAVED_MSG = "# saved\n" 5 | 6 | 7 | def read_log(log_reader: Union[str, TextIO]) -> Iterator[Tuple[int, Dict[str, Any]]]: 8 | """ 9 | Read entries in a log file as dicts. 10 | 11 | Returns an iterator over (step, dict) pairs. 
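Each entry line has the form "step N: key1=val1 key2=val2". A blank line stops reading, and lines beginning with '#' (such as the saved-checkpoint marker) are skipped.

A small usage sketch, assuming the log was written by Logger and contains a loss field:

    for step, values in read_log("train_log.txt"):
        print(step, values["loss"])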
12 | """ 13 | if isinstance(log_reader, str): 14 | with open(log_reader, "rt") as f: 15 | yield from read_log(f) 16 | return 17 | line_idx = 0 18 | while True: 19 | line = log_reader.readline().rstrip() 20 | line_idx += 1 21 | if not line: 22 | break 23 | elif line.startswith("#"): 24 | continue 25 | try: 26 | if not line.startswith("step "): 27 | raise ValueError 28 | step_str, kv_str = line[5:].split(": ") 29 | step_idx = int(step_str) 30 | kv_strs = kv_str.split(" ") 31 | kvs = {} 32 | for kv_str in kv_strs: 33 | k_str, v_str = kv_str.split("=") 34 | kvs[k_str] = float(v_str) 35 | except ValueError: 36 | raise ValueError(f"unexpected format at line {line_idx}") 37 | yield step_idx, kvs 38 | 39 | 40 | class Logger: 41 | """ 42 | Log training iterations to a file and to standard output. 43 | 44 | The log includes a dict of keys and numerical values for each step, as 45 | well as optional markers whenever checkpoints were saved to a file. 46 | 47 | The log can be resumed, in which case it is automatically truncated to the 48 | last save (or not truncated, if no saves are marked). 49 | To access the step of the first log message from a resume, look at the 50 | start_step attribute. 51 | """ 52 | 53 | def __init__(self, out_filename: str, resume: bool = False): 54 | self.start_step = 0 55 | if resume: 56 | with open(out_filename, "r") as in_file: 57 | all_lines = in_file.readlines() 58 | 59 | # The log may not include a save due to legacy code, but if 60 | # it does, we should truncate to it. 61 | if SAVED_MSG in all_lines: 62 | keep_lines = len(all_lines) - all_lines[::-1].index(SAVED_MSG) 63 | all_lines = all_lines[:keep_lines] 64 | 65 | step_lines = [x for x in all_lines if x.startswith("step ")] 66 | if len(step_lines): 67 | self.start_step = int(step_lines[-1].split(" ")[1].split(":")[0]) 68 | 69 | # Re-write the (possibly truncated) log. 70 | self.out_file = open(out_filename, "w+") 71 | self.out_file.write("".join(all_lines)) 72 | self.out_file.flush() 73 | else: 74 | self.out_file = open(out_filename, "w+") 75 | 76 | def log(self, step: int, **kwargs): 77 | fields = " ".join(f"{k}={v:.05f}" for k, v in kwargs.items()) 78 | log_line = f"step {step + self.start_step}: {fields}" 79 | self.out_file.write(log_line + "\n") 80 | self.out_file.flush() 81 | print(log_line) 82 | 83 | def mark_save(self): 84 | self.out_file.write(SAVED_MSG) 85 | self.out_file.flush() 86 | 87 | def close(self): 88 | self.out_file.close() 89 | -------------------------------------------------------------------------------- /vq_voice_swap/loss_tracker.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class LossTracker: 8 | """ 9 | Track loss averages throughout training to log separate 10 | quantiles of the diffusion loss function. 
11 | """ 12 | 13 | def __init__(self, quantiles: int = 4, avg_size: int = 1000, prefix: str = ""): 14 | self.quantiles = quantiles 15 | self.avg_size = avg_size 16 | self.prefix = prefix 17 | self.history = [[] for _ in range(quantiles)] 18 | 19 | def add(self, ts: torch.Tensor, mses: torch.Tensor): 20 | ts_list = ts.detach().cpu().numpy().tolist() 21 | mses_list = mses.detach().cpu().numpy().tolist() 22 | for t, mse in zip(ts_list, mses_list): 23 | quantile = int(t * (self.quantiles - 1e-8)) 24 | history = self.history[quantile] 25 | if len(history) == self.avg_size: 26 | del history[0] 27 | history.append(mse) 28 | 29 | def quantile_averages(self) -> List[Optional[float]]: 30 | return [float(np.mean(x)) if len(x) else None for x in self.history] 31 | 32 | def log_dict(self) -> Dict[str, float]: 33 | avgs = self.quantile_averages() 34 | return { 35 | f"{self.prefix}q{i}": avg for i, avg in enumerate(avgs) if avg is not None 36 | } 37 | -------------------------------------------------------------------------------- /vq_voice_swap/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Predictor, Savable, atomic_save 2 | from .classifier import Classifier, ClassifierStem 3 | from .conv_encoder import ConvMFCCEncoder 4 | from .encoder_predictor import EncoderPredictor 5 | from .make import make_encoder, make_predictor 6 | from .unet import UNetEncoder, UNetPredictor 7 | from .wavegrad import WaveGradEncoder, WaveGradPredictor 8 | 9 | __all__ = [ 10 | "Predictor", 11 | "Savable", 12 | "atomic_save", 13 | "Classifier", 14 | "ClassifierStem", 15 | "ConvMFCCEncoder", 16 | "EncoderPredictor", 17 | "make_encoder", 18 | "make_predictor", 19 | "UNetEncoder", 20 | "UNetPredictor", 21 | "WaveGradEncoder", 22 | "WaveGradPredictor", 23 | ] 24 | -------------------------------------------------------------------------------- /vq_voice_swap/models/base.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import tempfile 4 | from abc import abstractmethod 5 | from typing import Any, Callable, Dict, List 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class Predictor(nn.Module): 13 | @abstractmethod 14 | def forward(self, xs: torch.Tensor, ts: torch.Tensor, **kwargs) -> torch.Tensor: 15 | """ 16 | Apply the epsilon predictor to a batch of noised inputs. 17 | """ 18 | 19 | def condition(self, **kwargs) -> Callable: 20 | return functools.partial(self, **kwargs) 21 | 22 | @abstractmethod 23 | def add_labels(self, n: int, end: bool = True): 24 | """ 25 | Add a number of class label embeddings. 26 | 27 | This should only be called on class-conditional predictors. 28 | """ 29 | 30 | @abstractmethod 31 | def label_parameters(self) -> List[nn.Parameter]: 32 | """ 33 | Get all of the parameters which encode label information. 34 | 35 | This should only be called on class-conditional predictors. 36 | """ 37 | 38 | @property 39 | @abstractmethod 40 | def downsample_rate(self) -> int: 41 | """ 42 | Get the downsample rate to ensure that input sequences are evenly 43 | divisible by it. 44 | """ 45 | 46 | 47 | class Encoder(nn.Module): 48 | @abstractmethod 49 | def forward(self, xs: torch.Tensor, **kwargs) -> torch.Tensor: 50 | """ 51 | Apply the encoder to get a lower-resolution tensor. 
52 | """ 53 | 54 | @property 55 | @abstractmethod 56 | def downsample_rate(self) -> int: 57 | """ 58 | Get the downsample rate to ensure that input sequences are evenly 59 | divisible by it. 60 | """ 61 | 62 | 63 | class Savable(nn.Module): 64 | """ 65 | A module which saves constructor kwargs to reconstruct itself. 66 | """ 67 | 68 | @abstractmethod 69 | def save_kwargs(self) -> Dict[str, Any]: 70 | """ 71 | Get kwargs for restoring this model. 72 | """ 73 | 74 | def save_dict(self) -> Dict[str, Any]: 75 | """ 76 | Save a dict compatible with load_dict(). 77 | """ 78 | return { 79 | "kwargs": self.save_kwargs(), 80 | "state_dict": self.state_dict(), 81 | } 82 | 83 | @classmethod 84 | def load_dict(cls, state: Dict[str, Any]) -> Any: 85 | """ 86 | Construct an object saved with save_dict(). 87 | """ 88 | obj = cls(**state["kwargs"]) 89 | obj.load_state_dict(state["state_dict"]) 90 | return obj 91 | 92 | def save(self, path: str): 93 | """ 94 | Save this model to a file for loading with load(). 95 | """ 96 | atomic_save(self.save_dict(), path) 97 | 98 | @classmethod 99 | def load(cls, path: str): 100 | """ 101 | Load a fresh model instance from a file created with save(). 102 | """ 103 | state = torch.load(path, map_location="cpu") 104 | return cls.load_dict(state) 105 | 106 | def load_from_pretrained(self, model: nn.Module) -> int: 107 | """ 108 | Load the available parameters from a model into self. 109 | This only copies the union of self and model. 110 | 111 | :return: the total number of parameters copied. In particular, this is 112 | the sum of the product of the shapes of the parameters. 113 | """ 114 | src_params = dict(model.named_parameters()) 115 | dst_params = dict(self.named_parameters()) 116 | total = 0 117 | for name, dst in dst_params.items(): 118 | if name in src_params: 119 | with torch.no_grad(): 120 | if dst.shape != src_params[name].shape: 121 | raise RuntimeError( 122 | f"Parameter {name} has shape {dst.shape} in destination " 123 | f"but {src_params[name].shape} in source." 124 | ) 125 | dst.copy_(src_params[name]) 126 | total += np.prod(dst.shape) 127 | return total 128 | 129 | 130 | def atomic_save(state: Any, path: str): 131 | with tempfile.TemporaryDirectory() as tmp_dir: 132 | tmp_file = os.path.join(tmp_dir, "out.pt") 133 | torch.save(state, tmp_file) 134 | os.rename(tmp_file, path) 135 | -------------------------------------------------------------------------------- /vq_voice_swap/models/classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flexible audio sequence classification models. 3 | """ 4 | 5 | import math 6 | from typing import Any, Dict, Optional 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.checkpoint import checkpoint 12 | 13 | from .base import Savable 14 | from .unet import ResBlock, UNetPredictor, activation, norm_act, scale_module 15 | from .wavegrad import TimeEmbedding 16 | 17 | 18 | class Classifier(Savable): 19 | """ 20 | A module which adds an N-way linear layer head to a classifier stem. 
21 | """ 22 | 23 | def __init__(self, num_labels: int, **kwargs): 24 | super().__init__() 25 | self.num_labels = num_labels 26 | self.stem = ClassifierStem(**kwargs) 27 | self.out = nn.Sequential( 28 | activation(), scale_module(nn.Linear(self.stem.out_channels, num_labels)) 29 | ) 30 | 31 | def forward( 32 | self, x: torch.Tensor, ts: torch.Tensor, use_checkpoint: bool = False, **kwargs 33 | ) -> torch.Tensor: 34 | h = self.stem(x, ts, use_checkpoint=use_checkpoint, **kwargs) 35 | h = self.out(h) 36 | return h 37 | 38 | def save_kwargs(self) -> Dict[str, Any]: 39 | return dict( 40 | num_labels=self.num_labels, 41 | base_channels=self.stem.base_channels, 42 | channel_mult=self.stem.channel_mult, 43 | output_mult=self.stem.output_mult, 44 | depth_mult=self.stem.depth_mult, 45 | ) 46 | 47 | 48 | class ClassifierStem(nn.Module): 49 | """ 50 | A module which takes [N x 1 x T] sequences and produces feature vectors of 51 | the shape [N x C]. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | base_channels: int = 32, 57 | channel_mult: int = (1, 1, 2, 2, 2, 4, 4, 8, 8), 58 | output_mult: int = 16, 59 | depth_mult: int = 2, 60 | ): 61 | super().__init__() 62 | self.base_channels = base_channels 63 | self.channel_mult = channel_mult 64 | self.output_mult = output_mult 65 | self.depth_mult = depth_mult 66 | self.out_channels = base_channels * output_mult 67 | 68 | embed_dim = base_channels * 4 69 | self.embed_dim = embed_dim 70 | self.time_embed = TimeEmbedding(embed_dim) 71 | self.time_embed_extra = nn.Sequential( 72 | activation(), 73 | nn.Linear(embed_dim, embed_dim), 74 | ) 75 | 76 | self.in_conv = nn.Conv1d(1, base_channels, kernel_size=3, padding=1) 77 | 78 | self.blocks = nn.ModuleList([]) 79 | cur_channels = base_channels 80 | for ch_mult in channel_mult: 81 | for _ in range(depth_mult): 82 | self.blocks.append( 83 | ResBlock( 84 | channels=cur_channels, 85 | out_channels=ch_mult * base_channels, 86 | emb_channels=embed_dim, 87 | ) 88 | ) 89 | cur_channels = ch_mult * base_channels 90 | self.blocks.append( 91 | ResBlock( 92 | channels=cur_channels, 93 | out_channels=cur_channels, 94 | emb_channels=embed_dim, 95 | scale_factor=0.5, 96 | ) 97 | ) 98 | 99 | self.out = nn.Sequential( 100 | norm_act(cur_channels), 101 | AttentionPool1d( 102 | cur_channels, 103 | head_channels=min(cur_channels, 64), 104 | out_channels=self.out_channels, 105 | ), 106 | ) 107 | 108 | def conditional_embedding(self, ts: torch.Tensor, **kwargs) -> torch.Tensor: 109 | return self.time_embed_extra(self.time_embed(ts)) 110 | 111 | def forward( 112 | self, x: torch.Tensor, ts: torch.Tensor, use_checkpoint: bool = False, **kwargs 113 | ) -> torch.Tensor: 114 | emb = self.conditional_embedding(ts, **kwargs) 115 | h = self.in_conv(x) 116 | for block in self.blocks: 117 | if use_checkpoint: 118 | h = checkpoint(block, h, emb) 119 | else: 120 | h = block(h, emb) 121 | return self.out(h) 122 | 123 | def load_from_predictor(self, pred: UNetPredictor) -> int: 124 | dsts = [self.in_conv, self.time_embed, self.time_embed_extra, *self.blocks] 125 | srcs = [pred.in_conv, pred.time_embed, pred.time_embed_extra, *pred.down_blocks] 126 | total = 0 127 | for dst, src in zip(dsts, srcs): 128 | dst.load_state_dict(src.state_dict()) 129 | total += sum(np.prod(x.shape) for x in src.state_dict().values()) 130 | return total 131 | 132 | 133 | class AttentionPool1d(nn.Module): 134 | """ 135 | Adapted from: https://github.com/openai/guided-diffusion/blob/b16b0a180ffac9da8a6a03f1e78de8e96669eee8/guided_diffusion/unet.py#L22 136 | """ 137 | 138 | 
def __init__( 139 | self, 140 | channels: int, 141 | head_channels: int = 64, 142 | out_channels: Optional[int] = None, 143 | ): 144 | super().__init__() 145 | assert ( 146 | channels % head_channels == 0 147 | ), f"head channels ({head_channels}) must divide output channels ({out_channels})" 148 | self.qkv_proj = nn.Conv1d(channels, 3 * channels, 1) 149 | self.c_proj = nn.Conv1d(channels, out_channels or channels, 1) 150 | self.num_heads = channels // head_channels 151 | self.attention = QKVAttention(self.num_heads) 152 | 153 | def forward(self, x: torch.Tensor) -> torch.Tensor: 154 | x = torch.cat([torch.zeros_like(x[..., :1]), x], dim=-1) # NC(T+1) 155 | x = self.qkv_proj(x) 156 | x = self.attention(x) 157 | x = self.c_proj(x) 158 | return x[..., 0] 159 | 160 | 161 | class QKVAttention(nn.Module): 162 | """ 163 | Adapted from: https://github.com/openai/guided-diffusion/blob/b16b0a180ffac9da8a6a03f1e78de8e96669eee8/guided_diffusion/unet.py#L361 164 | """ 165 | 166 | def __init__(self, n_heads: int): 167 | super().__init__() 168 | self.n_heads = n_heads 169 | 170 | def forward(self, qkv: torch.Tensor) -> torch.Tensor: 171 | """ 172 | Apply QKV attention. 173 | 174 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 175 | :return: an [N x (H * C) x T] tensor after attention. 176 | """ 177 | bs, width, length = qkv.shape 178 | assert width % (3 * self.n_heads) == 0 179 | ch = width // (3 * self.n_heads) 180 | q, k, v = qkv.chunk(3, dim=1) 181 | scale = 1 / math.sqrt(math.sqrt(ch)) 182 | weight = torch.einsum( 183 | "bct,bcs->bts", 184 | (q * scale).view(bs * self.n_heads, ch, length), 185 | (k * scale).view(bs * self.n_heads, ch, length), 186 | ) 187 | weight = torch.softmax(weight, dim=-1) 188 | a = torch.einsum( 189 | "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) 190 | ) 191 | return a.reshape(bs, -1, length) 192 | -------------------------------------------------------------------------------- /vq_voice_swap/models/conv_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoders from https://arxiv.org/abs/1901.08810. 3 | """ 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.checkpoint import checkpoint 10 | 11 | from .base import Encoder 12 | 13 | 14 | class ConvMFCCEncoder(Encoder): 15 | """ 16 | The convolutional model with MFCC features at regular intervals. 17 | 18 | Requires torchaudio upon initialization. 
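The model computes 13 MFCCs plus delta and delta-delta features at mfcc_rate frames per second, then downsamples once more by a factor of two, so the output frame rate is mfcc_rate / 2.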
19 | """ 20 | 21 | def __init__( 22 | self, 23 | base_channels: int, 24 | out_channels: int = 64, 25 | input_ulaw: bool = True, 26 | input_rate: int = 16000, 27 | mfcc_rate: int = 100, 28 | version: int = 1, 29 | ): 30 | super().__init__() 31 | self.base_channels = base_channels 32 | self.out_channels = out_channels 33 | self.input_ulaw = input_ulaw 34 | self.input_rate = input_rate 35 | self.mfcc_rate = mfcc_rate 36 | self.mid_channels = base_channels * 12 37 | self.version = version 38 | 39 | assert mfcc_rate % 2 == 0, "must be able to downsample MFCCs once" 40 | assert input_rate % mfcc_rate == 0, "must evenly downsample input sequences" 41 | 42 | from torchaudio.transforms import MFCC 43 | 44 | if version == 2: 45 | n_fft = round(400 * input_rate / 16000) 46 | else: 47 | n_fft = (input_rate // self.mfcc_rate) * 2 48 | self.mfcc = MFCC( 49 | sample_rate=input_rate, 50 | n_mfcc=13, 51 | log_mels=version == 1, 52 | melkwargs=dict( 53 | n_fft=n_fft, 54 | hop_length=input_rate // self.mfcc_rate, 55 | n_mels=40 if version == 1 else 80, 56 | normalized=version == 2, 57 | ), 58 | ) 59 | 60 | self.blocks = nn.ModuleList( 61 | [ 62 | nn.Sequential( 63 | nn.Conv1d(13 * 3, self.mid_channels, 3, padding=1), 64 | nn.GELU(), 65 | ), 66 | ResConv(self.mid_channels, self.mid_channels, 3, padding=1), 67 | nn.Sequential( 68 | nn.Conv1d( 69 | self.mid_channels, self.mid_channels, 4, stride=2, padding=1 70 | ), 71 | nn.GELU(), 72 | ), 73 | *[ 74 | ResConv(self.mid_channels, self.mid_channels, 3, padding=1) 75 | for _ in range(2) 76 | ], 77 | *[ResConv(self.mid_channels, self.mid_channels, 1) for _ in range(4)], 78 | nn.Conv1d(self.mid_channels, self.out_channels, 1), 79 | ] 80 | ) 81 | # Zero output so that by default we don't affect the 82 | # behavior of downstream models. 83 | for p in self.blocks[-1].parameters(): 84 | with torch.no_grad(): 85 | p.zero_() 86 | 87 | def forward( 88 | self, 89 | x: torch.Tensor, 90 | use_checkpoint: bool = False, 91 | ) -> torch.Tensor: 92 | assert x.shape[1] == 1, "input must only have one channel" 93 | if self.input_ulaw: 94 | # MFCC layer expects linear waveform. 95 | x = invert_ulaw(x) 96 | h = self.mfcc(x[:, 0, :]) 97 | deriv = deltas(h) 98 | accel = deltas(deriv) 99 | h = torch.cat([h, deriv, accel], dim=1) 100 | for block in self.blocks: 101 | if use_checkpoint and h.requires_grad: 102 | h = checkpoint(block, h) 103 | else: 104 | h = block(h) 105 | return h 106 | 107 | @property 108 | def downsample_rate(self) -> int: 109 | return self.input_rate // (self.mfcc_rate // 2) 110 | 111 | 112 | class ResConv(nn.Module): 113 | def __init__(self, *args, **kwargs): 114 | super().__init__() 115 | self.conv = nn.Conv1d(*args, **kwargs) 116 | 117 | def forward(self, x): 118 | h = self.conv(x) 119 | h = F.gelu(h) 120 | return x + h 121 | 122 | 123 | def deltas(seq: torch.Tensor) -> torch.Tensor: 124 | right_shifted = torch.cat([seq[..., :1], seq[..., :-1]], dim=-1) 125 | left_shifted = torch.cat([seq[..., 1:], seq[..., -1:]], dim=-1) 126 | 127 | d1 = right_shifted - seq 128 | d2 = seq - left_shifted 129 | return (d1 + d2) / 2 130 | 131 | 132 | def invert_ulaw(x: torch.Tensor, mu: float = 255.0) -> torch.Tensor: 133 | return x.sign() * (1 / mu) * ((1 + mu) ** x.abs() - 1) 134 | -------------------------------------------------------------------------------- /vq_voice_swap/models/encoder_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models to predict the outputs of an Encoder from noised audio. 
3 | """ 4 | 5 | from typing import Any, Dict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .base import Savable 12 | from .unet import UNetPredictor 13 | 14 | 15 | class EncoderPredictor(Savable): 16 | """ 17 | A model which predicts a series of categorical variables. 18 | 19 | :param base_channels: channel multiplier for the model. 20 | :param downsample_rate: downsampling factor for the latents. 21 | :param num_latents: dictionary size we are predicting. 22 | :param bottleneck_dim: the bottleneck layer dimension. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | base_channels: int, 28 | downsample_rate: int, 29 | num_latents: int, 30 | bottleneck_dim: int = 64, 31 | ): 32 | super().__init__() 33 | self.base_channels = base_channels 34 | self.downsample_rate = downsample_rate 35 | self.num_latents = num_latents 36 | self.bottleneck_dim = bottleneck_dim 37 | self.unet = UNetPredictor(base_channels, out_channels=bottleneck_dim) 38 | self.out = nn.Conv1d(bottleneck_dim, num_latents, 1) 39 | 40 | def forward( 41 | self, x: torch.Tensor, ts: torch.Tensor, use_checkpoint: bool = False 42 | ) -> torch.Tensor: 43 | """ 44 | Predict the codes for a given sequence. 45 | 46 | :param x: an [N x C x T] Tensor. 47 | :param ts: an [N] Tensor of timesteps. 48 | :param use_checkpoint: if true, use gradient checkpointing. 49 | :return: an [N x D x T//R] Tensor of logits, where D is the number of 50 | categorical latents, and R is the downsampling rate. 51 | """ 52 | h = self.unet(x, ts, use_checkpoint=use_checkpoint) 53 | h = F.interpolate( 54 | h, size=(h.shape[-1] // self.downsample_rate,), mode="nearest" 55 | ) 56 | h = self.out(h) 57 | return h 58 | 59 | def losses( 60 | self, x: torch.Tensor, ts: torch.Tensor, targets: torch.Tensor, **kwargs 61 | ) -> torch.Tensor: 62 | losses = F.cross_entropy(self(x, ts, **kwargs), targets, reduction="none") 63 | return losses.mean(-1) 64 | 65 | def save_kwargs(self) -> Dict[str, Any]: 66 | return dict( 67 | base_channels=self.base_channels, 68 | downsample_rate=self.downsample_rate, 69 | num_latents=self.num_latents, 70 | bottleneck_dim=self.bottleneck_dim, 71 | ) 72 | -------------------------------------------------------------------------------- /vq_voice_swap/models/make.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .base import Encoder, Predictor 4 | from .conv_encoder import ConvMFCCEncoder 5 | from .unet import UNetEncoder, UNetPredictor 6 | from .wavegrad import WaveGradEncoder, WaveGradPredictor 7 | 8 | 9 | def make_predictor( 10 | pred_name: str, 11 | base_channels: int = 32, 12 | num_labels: Optional[int] = None, 13 | cond_channels: Optional[int] = None, 14 | dropout: float = 0.0, 15 | ) -> Predictor: 16 | """ 17 | Create a Predictor model from a human-readable name. 
18 | """ 19 | if pred_name == "wavegrad": 20 | assert not dropout, "dropout not supported for wavegrad" 21 | cond_mult = cond_channels // base_channels if cond_channels else 16 22 | return WaveGradPredictor( 23 | base_channels=base_channels, 24 | cond_mult=cond_mult, 25 | num_labels=num_labels, 26 | ) 27 | elif pred_name == "unet": 28 | return UNetPredictor( 29 | base_channels=base_channels, 30 | cond_channels=cond_channels, 31 | num_labels=num_labels, 32 | dropout=dropout, 33 | ) 34 | else: 35 | raise ValueError(f"unknown predictor: {pred_name}") 36 | 37 | 38 | def make_encoder( 39 | enc_name: str, 40 | base_channels: int = 32, 41 | cond_mult: int = 16, 42 | ) -> Encoder: 43 | """ 44 | Create an Encoder model from a human-readable name. 45 | """ 46 | if enc_name == "wavegrad": 47 | return WaveGradEncoder(cond_mult=cond_mult, base_channels=base_channels) 48 | elif enc_name == "unet": 49 | return UNetEncoder( 50 | base_channels=base_channels, out_channels=base_channels * cond_mult 51 | ) 52 | elif enc_name == "unet128": 53 | # Like unet, but with downsample rate 128 rather than 256. 54 | return UNetEncoder( 55 | base_channels=base_channels, 56 | channel_mult=(1, 1, 2, 2, 2, 4, 4, 8), 57 | out_channels=base_channels * cond_mult, 58 | ) 59 | elif enc_name == "unet128-dilated": 60 | return UNetEncoder( 61 | base_channels=base_channels, 62 | channel_mult=(1, 1, 2, 2, 2, 4, 4, 8), 63 | out_dilations=(4, 8, 16, 32), 64 | out_channels=base_channels * cond_mult, 65 | ) 66 | elif enc_name == "conv-mfcc-ulaw": 67 | return ConvMFCCEncoder( 68 | base_channels=base_channels, out_channels=base_channels * cond_mult 69 | ) 70 | elif enc_name == "conv-mfcc-ulaw-v2": 71 | return ConvMFCCEncoder( 72 | base_channels=base_channels, 73 | out_channels=base_channels * cond_mult, 74 | version=2, 75 | ) 76 | elif enc_name == "conv-mfcc-linear": 77 | return ConvMFCCEncoder( 78 | base_channels=base_channels, 79 | out_channels=base_channels * cond_mult, 80 | input_ulaw=False, 81 | ) 82 | else: 83 | raise ValueError(f"unknown encoder: {enc_name}") 84 | -------------------------------------------------------------------------------- /vq_voice_swap/models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/openai/guided-diffusion/blob/b16b0a180ffac9da8a6a03f1e78de8e96669eee8/guided_diffusion/unet.py. 
3 | """ 4 | 5 | from typing import List, Optional, Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | from .base import Encoder, Predictor 13 | from .wavegrad import TimeEmbedding 14 | 15 | 16 | class UNetPredictor(Predictor): 17 | def __init__( 18 | self, 19 | base_channels: int, 20 | channel_mult: Tuple[int] = (1, 1, 2, 2, 2, 4, 4, 8, 8), 21 | middle_dilations: Tuple[int] = (4, 8, 16, 32), 22 | depth_mult: int = 2, 23 | cond_channels: Optional[int] = None, 24 | num_labels: Optional[int] = None, 25 | in_channels: int = 1, 26 | out_channels: int = 1, 27 | dropout: float = 0.0, 28 | ): 29 | super().__init__() 30 | self.base_channels = base_channels 31 | self.channel_mult = channel_mult 32 | self.middle_dilations = middle_dilations 33 | self.depth_mult = depth_mult 34 | self.cond_channels = cond_channels 35 | self.num_labels = num_labels 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | 39 | embed_dim = base_channels * 4 40 | self.time_embed = TimeEmbedding(embed_dim) 41 | self.time_embed_extra = nn.Sequential( 42 | activation(), nn.Linear(embed_dim, embed_dim) 43 | ) 44 | if num_labels is not None: 45 | self.class_embed = nn.Embedding(num_labels, embed_dim) 46 | if cond_channels is not None: 47 | self.cond_proj = nn.Conv1d(cond_channels, base_channels, 3, padding=1) 48 | 49 | self.in_conv = nn.Conv1d(in_channels, base_channels, 3, padding=1) 50 | 51 | skip_channels = [base_channels] 52 | cur_channels = base_channels 53 | 54 | self.down_blocks = nn.ModuleList([]) 55 | for depth, mult in enumerate(channel_mult): 56 | for _ in range(depth_mult): 57 | self.down_blocks.append( 58 | ResBlock( 59 | channels=cur_channels, 60 | emb_channels=embed_dim, 61 | out_channels=mult * base_channels, 62 | dropout=dropout, 63 | ) 64 | ) 65 | cur_channels = mult * base_channels 66 | skip_channels.append(cur_channels) 67 | if depth != len(channel_mult) - 1: 68 | self.down_blocks.append( 69 | ResBlock( 70 | channels=cur_channels, 71 | emb_channels=embed_dim, 72 | scale_factor=0.5, 73 | dropout=dropout, 74 | ), 75 | ) 76 | skip_channels.append(cur_channels) 77 | 78 | self.middle_blocks = nn.ModuleList( 79 | [ 80 | ResBlock( 81 | channels=cur_channels, 82 | emb_channels=embed_dim, 83 | dilation=d, 84 | dropout=dropout, 85 | ) 86 | for d in middle_dilations 87 | ] 88 | ) 89 | 90 | self.up_blocks = nn.ModuleList([]) 91 | for depth, mult in list(enumerate(channel_mult))[::-1]: 92 | for _ in range(depth_mult + 1): 93 | in_ch = skip_channels.pop() 94 | self.up_blocks.append( 95 | ResBlock( 96 | channels=cur_channels + in_ch, 97 | emb_channels=embed_dim, 98 | out_channels=mult * base_channels, 99 | dropout=dropout, 100 | ) 101 | ) 102 | cur_channels = mult * base_channels 103 | if depth: 104 | self.up_blocks.append( 105 | ResBlock( 106 | channels=cur_channels, 107 | emb_channels=embed_dim, 108 | scale_factor=2.0, 109 | dropout=dropout, 110 | ) 111 | ) 112 | 113 | self.out = nn.Sequential( 114 | norm_act(base_channels), 115 | nn.Conv1d(base_channels, out_channels, 3, padding=1), 116 | ) 117 | 118 | def forward( 119 | self, 120 | x: torch.Tensor, 121 | ts: torch.Tensor, 122 | cond: Optional[torch.Tensor] = None, 123 | labels: Optional[torch.Tensor] = None, 124 | use_checkpoint: bool = False, 125 | ) -> torch.Tensor: 126 | assert (labels is None) == ( 127 | self.num_labels is None 128 | ), "must provide labels if and only if model is class conditional" 129 | assert (cond is None) == ( 130 | 
self.cond_channels is None 131 | ), "must provide cond sequence if and only if model is conditional" 132 | 133 | emb = self.time_embed_extra(self.time_embed(ts)) 134 | if labels is not None: 135 | emb = emb + self.class_embed(labels) 136 | 137 | h = self.in_conv(x) 138 | if cond is not None: 139 | h = h + F.interpolate(self.cond_proj(cond), h.shape[-1]) 140 | 141 | skips = [h] 142 | for block in self.down_blocks: 143 | if use_checkpoint: 144 | h = checkpoint(block, h, emb) 145 | else: 146 | h = block(h, emb) 147 | skips.append(h) 148 | for block in self.middle_blocks: 149 | if use_checkpoint: 150 | h = checkpoint(block, h, emb) 151 | else: 152 | h = block(h, emb) 153 | for i, block in enumerate(self.up_blocks): 154 | # No skip connection for upsampling block. 155 | if i % (self.depth_mult + 2) != self.depth_mult + 1: 156 | h = torch.cat([h, skips.pop()], axis=1) 157 | if use_checkpoint: 158 | h = checkpoint(block, h, emb) 159 | else: 160 | h = block(h, emb) 161 | 162 | h = self.out(h) 163 | return h 164 | 165 | def add_labels(self, n: int, end: bool = True): 166 | assert self.num_labels is not None 167 | old_weight = self.class_embed.weight.detach() 168 | old_count = self.num_labels 169 | 170 | self.num_labels += n 171 | self.class_embed = nn.Embedding(self.num_labels, old_weight.shape[-1]) 172 | with torch.no_grad(): 173 | if end: 174 | self.class_embed.weight[:old_count].copy_(old_weight) 175 | else: 176 | self.class_embed.weight[n:].copy_(old_weight) 177 | 178 | def label_parameters(self) -> List[nn.Parameter]: 179 | assert self.num_labels is not None 180 | return list(self.class_embed.parameters()) 181 | 182 | @property 183 | def downsample_rate(self) -> int: 184 | return 2 ** (len(self.channel_mult) - 1) 185 | 186 | 187 | class UNetEncoder(Encoder): 188 | def __init__( 189 | self, 190 | base_channels: int, 191 | channel_mult: Tuple[int] = (1, 1, 2, 2, 2, 4, 4, 8, 8), 192 | out_dilations: Tuple[int] = (), 193 | depth_mult: int = 2, 194 | in_channels: int = 1, 195 | out_channels: int = 512, 196 | ): 197 | super().__init__() 198 | self.base_channels = base_channels 199 | self.channel_mult = channel_mult 200 | self.depth_mult = depth_mult 201 | self.in_channels = in_channels 202 | self.out_channels = out_channels 203 | 204 | self.in_conv = nn.Conv1d(in_channels, base_channels, 3, padding=1) 205 | 206 | self.blocks = nn.ModuleList([]) 207 | 208 | cur_channels = base_channels 209 | for depth, mult in enumerate(channel_mult): 210 | for _ in range(depth_mult): 211 | self.blocks.append( 212 | ResBlock( 213 | channels=cur_channels, 214 | out_channels=mult * base_channels, 215 | ) 216 | ) 217 | cur_channels = mult * base_channels 218 | if depth != len(channel_mult) - 1: 219 | self.blocks.append(ResBlock(channels=cur_channels, scale_factor=0.5)) 220 | 221 | for d in out_dilations: 222 | self.blocks.append(ResBlock(channels=cur_channels, dilation=d)) 223 | 224 | self.out = nn.Sequential( 225 | norm_act(cur_channels), 226 | nn.Conv1d(cur_channels, out_channels, 3, padding=1), 227 | ) 228 | 229 | def forward( 230 | self, 231 | x: torch.Tensor, 232 | use_checkpoint: bool = False, 233 | ) -> torch.Tensor: 234 | h = self.in_conv(x) 235 | for block in self.blocks: 236 | if use_checkpoint: 237 | h = checkpoint(block, h) 238 | else: 239 | h = block(h) 240 | h = self.out(h) 241 | return h 242 | 243 | @property 244 | def downsample_rate(self) -> int: 245 | return 2 ** (len(self.channel_mult) - 1) 246 | 247 | 248 | class ResBlock(nn.Module): 249 | def __init__( 250 | self, 251 | channels: int, 252 | 
emb_channels: Optional[int] = None, 253 | out_channels: Optional[int] = None, 254 | scale_factor: float = 1.0, 255 | dilation: int = 2, 256 | dropout: float = 0.0, 257 | ): 258 | super().__init__() 259 | self.channels = channels 260 | self.emb_channels = emb_channels 261 | self.out_channels = out_channels or channels 262 | self.scale_factor = scale_factor 263 | self.dropout = dropout 264 | 265 | skip_conv = nn.Identity() 266 | if self.channels != self.out_channels: 267 | skip_conv = nn.Conv1d(self.channels, self.out_channels, 1) 268 | self.skip = nn.Sequential( 269 | Resize(scale_factor), 270 | skip_conv, 271 | ) 272 | 273 | if self.emb_channels: 274 | self.cond_layers = nn.Sequential( 275 | activation(), 276 | # Start with a small amount of conditioning. 277 | scale_module(nn.Linear(emb_channels, self.out_channels * 2), s=0.1), 278 | ) 279 | 280 | self.pre_cond = nn.Sequential( 281 | norm_act(channels), 282 | Resize(scale_factor), 283 | nn.Conv1d(self.channels, self.out_channels, 3, padding=1), 284 | normalization(self.out_channels), 285 | ) 286 | out_conv = scale_module( 287 | nn.Conv1d( 288 | self.out_channels, 289 | self.out_channels, 290 | 3, 291 | padding=dilation, 292 | dilation=dilation, 293 | ) 294 | ) 295 | if self.dropout: 296 | self.post_cond = nn.Sequential( 297 | activation(), 298 | nn.Dropout(p=dropout), 299 | out_conv, 300 | ) 301 | else: 302 | self.post_cond = nn.Sequential( 303 | activation(), 304 | out_conv, 305 | ) 306 | 307 | def forward( 308 | self, x: torch.Tensor, cond: Optional[torch.Tensor] = None 309 | ) -> torch.Tensor: 310 | h = self.pre_cond(x) 311 | if self.emb_channels: 312 | cond_ab = self.cond_layers(cond)[..., None] 313 | cond_a, cond_b = torch.split(cond_ab, self.out_channels, dim=1) 314 | h = h * (cond_a + 1) + cond_b 315 | h = self.post_cond(h) 316 | return self.skip(x) + h 317 | 318 | 319 | class Resize(nn.Module): 320 | def __init__(self, scale_factor: float): 321 | super().__init__() 322 | self.scale_factor = scale_factor 323 | 324 | def forward(self, x: torch.Tensor) -> torch.Tensor: 325 | if self.scale_factor == 1.0: 326 | return x 327 | if self.scale_factor < 1.0: 328 | down_factor = int(round(1 / self.scale_factor)) 329 | assert ( 330 | float(1 / down_factor - self.scale_factor) < 1e-5 331 | ), "scale factor must be integer or 1/integer" 332 | return F.avg_pool1d(x, down_factor) 333 | else: 334 | return F.interpolate(x, scale_factor=self.scale_factor) 335 | 336 | 337 | def norm_act(ch: int) -> nn.Module: 338 | return nn.Sequential(normalization(ch), activation()) 339 | 340 | 341 | def activation() -> nn.Module: 342 | return nn.GELU() 343 | 344 | 345 | def normalization(ch: int) -> nn.Module: 346 | num_groups = 32 347 | while ch % num_groups: 348 | num_groups //= 2 349 | return nn.GroupNorm(num_groups=num_groups, num_channels=ch) 350 | 351 | 352 | def scale_module(module: nn.Module, s: float = 0.0) -> nn.Module: 353 | for p in module.parameters(): 354 | with torch.no_grad(): 355 | p.mul_(s) 356 | return module 357 | -------------------------------------------------------------------------------- /vq_voice_swap/models/wavegrad.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model similar to that in GAN-TTS (https://arxiv.org/abs/1909.11646v2) 3 | and WaveGrad (https://arxiv.org/abs/2009.00713). 
4 | """ 5 | 6 | import math 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 12 | 13 | from .base import Encoder, Predictor 14 | 15 | 16 | class WaveGradPredictor(Predictor): 17 | def __init__( 18 | self, 19 | cond_mult: int = 16, 20 | base_channels: int = 32, 21 | num_labels: Optional[int] = None, 22 | ): 23 | super().__init__() 24 | self.cond_channels = cond_mult * base_channels 25 | self.base_channels = base_channels 26 | self.d_blocks = nn.ModuleList( 27 | [ 28 | nn.Conv1d(1, base_channels, 5, padding=2), 29 | DBlock(base_channels, base_channels * 4, 4), 30 | DBlock(base_channels * 4, base_channels * 4, 2), 31 | DBlock(base_channels * 4, base_channels * 8, 2), 32 | DBlock(base_channels * 8, base_channels * 16, 2), 33 | ] 34 | ) 35 | self.u_conv_1 = nn.Conv1d(self.cond_channels, base_channels * 24, 3, padding=1) 36 | self.u_blocks = nn.ModuleList( 37 | [ 38 | UBlock( 39 | base_channels * 24, 40 | base_channels * 16, 41 | base_channels * 16, 42 | 2, 43 | num_labels=num_labels, 44 | ), 45 | UBlock( 46 | base_channels * 16, 47 | base_channels * 16, 48 | base_channels * 8, 49 | 2, 50 | num_labels=num_labels, 51 | ), 52 | UBlock( 53 | base_channels * 16, 54 | base_channels * 8, 55 | base_channels * 4, 56 | 2, 57 | num_labels=num_labels, 58 | ), 59 | UBlock( 60 | base_channels * 8, 61 | base_channels * 4, 62 | base_channels * 4, 63 | 2, 64 | num_labels=num_labels, 65 | ), 66 | UBlock( 67 | base_channels * 4, 68 | base_channels * 4, 69 | base_channels, 70 | 4, 71 | num_labels=num_labels, 72 | ), 73 | ] 74 | ) 75 | self.u_ln = NCTLayerNorm(base_channels * 4) 76 | self.u_conv_2 = nn.Conv1d(base_channels * 4, 1, 3, padding=1) 77 | for p in self.u_conv_2.parameters(): 78 | with torch.no_grad(): 79 | p.zero_() 80 | 81 | def forward( 82 | self, 83 | x: torch.Tensor, 84 | t: torch.Tensor, 85 | cond: Optional[torch.Tensor] = None, 86 | labels: Optional[torch.Tensor] = None, 87 | use_checkpoint=False, 88 | ) -> torch.Tensor: 89 | assert x.shape[2] % 64 == 0, "timesteps must be divisible by 64" 90 | 91 | # Model doesn't need to be conditional 92 | if cond is None: 93 | cond = torch.zeros(x.shape[0], self.cond_channels, x.shape[2] // 64).to(x) 94 | 95 | d_outputs = [] 96 | d_input = x 97 | for block in self.d_blocks: 98 | if use_checkpoint: 99 | if not d_input.requires_grad: 100 | d_input = d_input.clone().requires_grad_(True) 101 | d_input = checkpoint(block, d_input) 102 | else: 103 | d_input = block(d_input) 104 | d_outputs.append(d_input) 105 | 106 | u_input = self.u_conv_1(cond) 107 | for block in self.u_blocks: 108 | 109 | def run_fn(u_input, d_output, block=block, t=t, labels=labels): 110 | return block(u_input, d_output, t, labels=labels) 111 | 112 | if use_checkpoint: 113 | u_input = checkpoint(run_fn, u_input, d_outputs.pop()) 114 | else: 115 | u_input = run_fn(u_input, d_outputs.pop()) 116 | out = self.u_ln(u_input) 117 | out = self.u_conv_2(out) 118 | return out 119 | 120 | def add_labels(self, n: int, end: bool = True): 121 | for block in self.u_blocks: 122 | block.add_labels(n, end=end) 123 | 124 | def label_parameters(self) -> List[nn.Parameter]: 125 | return [x for n, x in self.named_parameters() if "label_emb" in n] 126 | 127 | @property 128 | def downsample_rate(self) -> int: 129 | return 64 130 | 131 | 132 | class WaveGradEncoder(Encoder): 133 | """ 134 | An encoder-only version of WaveGradPredictor that can be used to downsample 135 | waveforms. 
136 | """ 137 | 138 | def __init__(self, cond_mult: int = 16, base_channels: int = 32): 139 | super().__init__() 140 | self.cond_channels = cond_mult * base_channels 141 | self.d_blocks = nn.Sequential( 142 | nn.Conv1d(1, base_channels, 5, padding=2), 143 | DBlock(base_channels, base_channels * 4, 4, extra_blocks=1), 144 | DBlock(base_channels * 4, base_channels * 4, 2, extra_blocks=1), 145 | DBlock(base_channels * 4, base_channels * 8, 2, extra_blocks=1), 146 | DBlock(base_channels * 8, base_channels * 16, 2, extra_blocks=1), 147 | DBlock(base_channels * 16, self.cond_channels, 2, extra_blocks=1), 148 | ) 149 | 150 | def forward(self, x: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor: 151 | if use_checkpoint: 152 | if not x.requires_grad: 153 | x = x.clone().requires_grad_() 154 | return checkpoint_sequential(self.d_blocks, len(self.d_blocks), x) 155 | else: 156 | return self.d_blocks(x) 157 | 158 | @property 159 | def downsample_rate(self) -> int: 160 | return 64 161 | 162 | 163 | class UBlock(nn.Module): 164 | def __init__( 165 | self, 166 | in_channels: int, 167 | out_channels: int, 168 | cond_channels: int, 169 | upsample_rate: int, 170 | num_labels: Optional[int] = None, 171 | ): 172 | super().__init__() 173 | self.in_channels = in_channels 174 | self.out_channels = out_channels 175 | self.cond_channels = cond_channels 176 | self.upsample_rate = upsample_rate 177 | 178 | def make_film(): 179 | return FILM(cond_channels, out_channels, num_labels=num_labels) 180 | 181 | self.film_1 = make_film() 182 | self.film_2 = make_film() 183 | self.film_3 = make_film() 184 | 185 | self.res_transform = nn.Sequential( 186 | nn.Upsample(scale_factor=upsample_rate), 187 | nn.Conv1d(in_channels, out_channels, 3, padding=1), 188 | ) 189 | self.block_1 = nn.Sequential( 190 | NCTLayerNorm(in_channels), 191 | nn.GELU(), 192 | nn.Upsample(scale_factor=upsample_rate), 193 | nn.Conv1d(in_channels, out_channels, 3, padding=1), 194 | ) 195 | self.block_2 = nn.Sequential( 196 | nn.GELU(), 197 | nn.Conv1d(out_channels, out_channels, 3, dilation=2, padding=2), 198 | ) 199 | self.block_3 = nn.Sequential( 200 | NCTLayerNorm(out_channels), 201 | nn.GELU(), 202 | nn.Conv1d(out_channels, out_channels, 3, dilation=4, padding=4), 203 | ) 204 | self.block_4 = nn.Sequential( 205 | nn.GELU(), 206 | nn.Conv1d(out_channels, out_channels, 3, dilation=8, padding=8), 207 | nn.GELU(), 208 | nn.Conv1d(out_channels, out_channels, 3, dilation=16, padding=16), 209 | ) 210 | 211 | def forward( 212 | self, 213 | h: torch.Tensor, 214 | z: torch.Tensor, 215 | t: torch.Tensor, 216 | labels: Optional[torch.Tensor] = None, 217 | ) -> torch.Tensor: 218 | res_out = self.res_transform(h) 219 | output = self.block_1(h) 220 | output = self.block_2(self.film_1(output, z, t, labels=labels)) 221 | output = output + res_out 222 | res_out = output 223 | output = self.block_3(self.film_2(output, z, t, labels=labels)) 224 | output = self.block_4(self.film_3(output, z, t, labels=labels)) 225 | return output + res_out 226 | 227 | def add_labels(self, n: int, end: bool = True): 228 | for film in [self.film_1, self.film_2, self.film_3]: 229 | film.add_labels(n, end=end) 230 | 231 | 232 | class DBlock(nn.Module): 233 | def __init__( 234 | self, 235 | in_channels: int, 236 | out_channels: int, 237 | downsample_rate: int, 238 | extra_blocks: int = 0, 239 | ): 240 | super().__init__() 241 | self.in_channels = in_channels 242 | self.out_channels = out_channels 243 | self.downsample_rate = downsample_rate 244 | self.extra_blocks = extra_blocks 245 
| 246 | self.res_transform = nn.Sequential( 247 | nn.Conv1d(in_channels, out_channels, 3, padding=1), 248 | nn.AvgPool1d(downsample_rate, stride=downsample_rate), 249 | ) 250 | self.block_1 = nn.Sequential( 251 | NCTLayerNorm(in_channels), 252 | nn.AvgPool1d(downsample_rate, stride=downsample_rate), 253 | nn.GELU(), 254 | nn.Conv1d(in_channels, out_channels, 3, padding=1), 255 | nn.GELU(), 256 | nn.Conv1d(out_channels, out_channels, 3, dilation=2, padding=2), 257 | ) 258 | self.extra = nn.ModuleList( 259 | [ 260 | nn.Sequential( 261 | NCTLayerNorm(out_channels), 262 | nn.GELU(), 263 | nn.Conv1d(out_channels, out_channels, 3, padding=1), 264 | nn.GELU(), 265 | nn.Conv1d(out_channels, out_channels, 3, dilation=4, padding=4), 266 | nn.GELU(), 267 | nn.Conv1d(out_channels, out_channels, 3, dilation=8, padding=8), 268 | ) 269 | for _ in range(extra_blocks) 270 | ] 271 | ) 272 | 273 | def forward(self, h: torch.Tensor) -> torch.Tensor: 274 | res = self.block_1(h) + self.res_transform(h) 275 | for block in self.extra: 276 | res = res + block(res) 277 | return res 278 | 279 | 280 | class FILM(nn.Module): 281 | """ 282 | A FiLM layer that conditions on a timestep and (possibly) a label. 283 | 284 | The timestep is a floating point in the range [0, 1], whereas the labels 285 | are integers in the range [0, num_labels). 286 | 287 | The output of a FiLM layer is a tuple (alpha, beta), where alpha is a 288 | multiplier and beta is a bias. 289 | """ 290 | 291 | def __init__( 292 | self, cond_channels: int, out_channels: int, num_labels: Optional[int] = None 293 | ): 294 | super().__init__() 295 | self.cond_channels = cond_channels 296 | self.out_channels = out_channels 297 | self.hidden_channels = out_channels * 2 298 | self.num_labels = num_labels 299 | self.time_emb = TimeEmbedding(self.hidden_channels) 300 | self.cond_emb = nn.Sequential( 301 | NCTLayerNorm(cond_channels), 302 | nn.Conv1d(cond_channels, self.hidden_channels, 3, padding=1), 303 | ) 304 | if num_labels is not None: 305 | self.label_emb = nn.Embedding(num_labels, self.hidden_channels) 306 | # Random initial label embeddings appears to hurt performance. 307 | with torch.no_grad(): 308 | self.label_emb.weight.zero_() 309 | else: 310 | self.label_emb = None 311 | self.out_layer = nn.Sequential( 312 | nn.GELU(), 313 | nn.Conv1d(self.hidden_channels, out_channels * 2, 3, padding=1), 314 | ) 315 | # Start off with little conditioning signal. 
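# The final conv's weight is scaled by 0.1 and its bias zeroed, so the (alpha, beta) modulation starts close to the identity.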
316 | with torch.no_grad(): 317 | self.out_layer[1].weight.mul_(0.1) 318 | self.out_layer[1].bias.mul_(0.0) 319 | 320 | def forward( 321 | self, 322 | inputs: torch.Tensor, 323 | cond: torch.Tensor, 324 | t: torch.Tensor, 325 | labels: Optional[torch.Tensor] = None, 326 | ) -> torch.Tensor: 327 | embedding = self.time_emb(t) 328 | assert (labels is None) == (self.label_emb is None) 329 | if labels is not None: 330 | embedding = embedding + self.label_emb(labels) 331 | while len(embedding.shape) < len(cond.shape): 332 | embedding = embedding[..., None] 333 | embedding = embedding + self.cond_emb(cond) 334 | alpha_beta = self.out_layer(embedding) 335 | alpha, beta = torch.split(alpha_beta, self.out_channels, dim=1) 336 | return inputs * (1 + alpha) + beta 337 | 338 | def add_labels(self, n: int, end: bool = True): 339 | assert self.num_labels is not None 340 | old_weight = self.label_emb.weight.detach() 341 | old_count = self.num_labels 342 | 343 | self.num_labels += n 344 | self.label_emb = nn.Embedding(self.num_labels, old_weight.shape[-1]) 345 | with torch.no_grad(): 346 | if end: 347 | self.label_emb.weight[:old_count].copy_(old_weight) 348 | else: 349 | self.label_emb.weight[n:].copy_(old_weight) 350 | 351 | 352 | class TimeEmbedding(nn.Module): 353 | def __init__(self, channels: int): 354 | super().__init__() 355 | assert not channels % 2, f"channels {channels} should be divisible by two" 356 | self.channels = channels 357 | self.proj = nn.Linear(channels, channels) 358 | 359 | def forward(self, t: torch.Tensor) -> torch.Tensor: 360 | half = self.channels // 2 361 | min_coeff = 0.1 362 | max_coeff = 100.0 363 | freqs = ( 364 | torch.exp( 365 | -math.log(max_coeff / min_coeff) 366 | * torch.arange(start=0, end=half, dtype=torch.float32) 367 | / (half - 1) 368 | ) 369 | * max_coeff 370 | ).to(t) 371 | args = t[:, None] * freqs[None] 372 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 373 | return self.proj(embedding) 374 | 375 | 376 | class NCTLayerNorm(nn.Module): 377 | """ 378 | LayerNorm that normalizes channels in NCT tensors. 379 | """ 380 | 381 | def __init__(self, ch: int): 382 | super().__init__() 383 | self.ln = nn.LayerNorm((ch,)) 384 | 385 | def forward(self, x: torch.Tensor) -> torch.Tensor: 386 | x = x.permute(0, 2, 1).contiguous() 387 | x = self.ln(x) 388 | x = x.permute(0, 2, 1).contiguous() 389 | return x 390 | -------------------------------------------------------------------------------- /vq_voice_swap/smoothing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def moving_average(xs: np.ndarray, window_size: int) -> np.ndarray: 5 | """ 6 | :param xs: a 1-D array of floating points. 7 | :param window_size: the number of points to average over. 8 | :return: an array like xs, where every entry is the average of window_size 9 | points in xs. Thus, entry k is the average of [k, k-1, ...]. 
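For the first window_size - 1 entries, the average is taken over all points seen so far (a growing window).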
10 | """ 11 | if len(xs) <= window_size: 12 | return np.cumsum(xs) / (np.arange(len(xs)) + 1) 13 | return np.concatenate( 14 | [ 15 | np.cumsum(xs)[: window_size - 1] / (np.arange(window_size - 1) + 1), 16 | np.convolve(xs, np.ones([window_size]) / window_size, mode="valid"), 17 | ] 18 | ) 19 | -------------------------------------------------------------------------------- /vq_voice_swap/smoothing_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from .smoothing import moving_average 5 | 6 | 7 | @pytest.mark.parametrize("length", [9, 10, 11, 51]) 8 | def test_moving_average(length: int): 9 | data = np.random.normal(size=(length,)) 10 | actual = moving_average(data, 10) 11 | expected = slow_moving_average(data, 10) 12 | assert np.allclose(actual, expected) 13 | 14 | 15 | def slow_moving_average(data: np.ndarray, window: int): 16 | res = np.zeros_like(data) 17 | for i in range(len(data)): 18 | start = max(0, i - window + 1) 19 | res[i] = np.mean(data[start : i + 1]) 20 | return res 21 | -------------------------------------------------------------------------------- /vq_voice_swap/train_loop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import time 6 | from abc import ABC, abstractmethod 7 | from typing import Any, Dict, Iterable, List, Set, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.optim import AdamW 13 | 14 | from vq_voice_swap.loss_tracker import LossTracker 15 | 16 | from .dataset import create_data_loader 17 | from .diffusion import Diffusion, make_schedule 18 | from .diffusion_model import DiffusionModel 19 | from .ema import ModelEMA 20 | from .logger import Logger 21 | from .loss_tracker import LossTracker 22 | from .models import Classifier, EncoderPredictor, Savable 23 | from .util import count_params, repeat_dataset 24 | from .vq import ReviveVQLoss, StandardVQLoss 25 | from .vq_vae import VQVAE 26 | 27 | 28 | class TrainLoop(ABC): 29 | """ 30 | An abstract training loop with methods to override for controlling 31 | different pieces of training. 
32 | """ 33 | 34 | def __init__(self, args=None): 35 | if args is None: 36 | args = self.arg_parser().parse_args() 37 | self.args = args 38 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | 40 | if not os.path.exists(args.output_dir): 41 | os.mkdir(args.output_dir) 42 | 43 | self.data_loader, self.num_labels = self.create_data_loader() 44 | self.model, self.resume = self.create_model() 45 | self.model.to(self.device) 46 | 47 | self.emas = self.create_emas() 48 | self.opt = self.create_opt() 49 | self.logger, self.tracker = self.create_logger_tracker() 50 | 51 | self.total_steps = self.logger.start_step 52 | self.loop_steps = 0 53 | 54 | self.freeze_parameters(self.frozen_parameters()) 55 | self.write_run_info() 56 | 57 | def loop(self): 58 | for i, data_batch in enumerate(repeat_dataset(self.data_loader)): 59 | self.total_steps = i + self.logger.start_step 60 | self.loop_steps = i 61 | self.step(data_batch) 62 | 63 | def step(self, data_batch: Dict[str, torch.Tensor]): 64 | self.opt.zero_grad() 65 | 66 | all_losses = [] 67 | all_ts = [] 68 | all_loss = 0.0 69 | all_extra = dict() 70 | 71 | for microbatch, weight in self.split_microbatches(data_batch): 72 | losses, ts, extra_losses = self.compute_losses(microbatch) 73 | 74 | # Re-weighted losses for microbatch averaging 75 | extra_losses = {k: v * weight for k, v in extra_losses.items()} 76 | loss = losses.mean() * weight 77 | for extra in extra_losses.values(): 78 | loss = loss + extra 79 | 80 | self.loss_backward(loss) 81 | 82 | # Needed to re-aggregate the microbatch losses for 83 | # normal logging. 84 | all_losses.append(losses.detach()) 85 | all_ts.append(ts) 86 | all_loss = all_loss + loss.detach() 87 | all_extra = { 88 | k: v.detach() + all_extra.get(k, 0.0) for k, v in extra_losses.items() 89 | } 90 | 91 | self.step_optimizer() 92 | self.log_losses( 93 | all_loss, torch.cat(all_losses, dim=0), torch.cat(all_ts, dim=0), all_extra 94 | ) 95 | 96 | if (self.total_steps + 1) % self.args.save_interval == 0: 97 | self.save() 98 | 99 | def split_microbatches( 100 | self, data_batch: Dict[str, torch.Tensor] 101 | ) -> List[Tuple[Dict[str, torch.Tensor], float]]: 102 | key = next(iter(data_batch.keys())) 103 | batch_size = len(data_batch[key]) 104 | if not self.args.microbatch or self.args.microbatch > batch_size: 105 | return [(data_batch, 1.0)] 106 | res = [] 107 | for i in range(0, batch_size, self.args.microbatch): 108 | sub_batch = { 109 | k: v[i : i + self.args.microbatch] for k, v in data_batch.items() 110 | } 111 | res.append((sub_batch, len(sub_batch[key]) / batch_size)) 112 | return res 113 | 114 | def loss_backward(self, loss: torch.Tensor): 115 | loss.backward() 116 | 117 | def step_optimizer(self): 118 | self.opt.step() 119 | for ema in self.emas.values(): 120 | ema.update() 121 | 122 | @abstractmethod 123 | def compute_losses( 124 | self, data_batch: Dict[str, torch.Tensor] 125 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 126 | """ 127 | Compute loss per batch element, and also return the diffusion timestep 128 | for each loss. Also return a (possibly empty) dict of other losses. 129 | 130 | :return: a tuple (losses, ts, other). 
131 | """ 132 | 133 | def log_losses( 134 | self, 135 | loss: torch.Tensor, 136 | losses: torch.Tensor, 137 | ts: torch.Tensor, 138 | extra_losses: Dict[str, torch.Tensor], 139 | ): 140 | self.tracker.add(ts, losses) 141 | other = {k: v.item() for k, v in extra_losses.items()} 142 | other.update(self.tracker.log_dict()) 143 | self.logger.log(self.loop_steps + 1, loss=loss.item(), **other) 144 | 145 | def save(self): 146 | self.model.save(self.checkpoint_path()) 147 | for rate, ema in self.emas.items(): 148 | ema.model.save(self.ema_path(rate)) 149 | torch.save(self.opt.state_dict(), self.opt_path()) 150 | self.logger.mark_save() 151 | 152 | def create_data_loader(self) -> Tuple[Iterable, int]: 153 | return create_data_loader( 154 | directory=self.args.data_dir, 155 | batch_size=self.args.batch_size, 156 | encoding=self.args.encoding, 157 | ) 158 | 159 | def create_model(self) -> Tuple[Savable, bool]: 160 | if os.path.exists(self.checkpoint_path()): 161 | print("loading from checkpoint...") 162 | model = self.model_class().load(self.checkpoint_path()) 163 | resume = True 164 | else: 165 | print("creating new model") 166 | model = self.create_new_model() 167 | resume = False 168 | 169 | if self.args.pretrained_path: 170 | print(f"loading from pretrained model: {self.args.pretrained_path} ...") 171 | num_params = self.load_from_pretrained(model) 172 | print(f"loaded {num_params} pre-trained parameters...") 173 | print(f"total parameters: {count_params(model)}") 174 | return model, resume 175 | 176 | def create_emas(self) -> Dict[float, ModelEMA]: 177 | res = {} 178 | for rate_str in self.args.ema_rate.split(","): 179 | rate = float(rate_str) 180 | assert rate not in res, "cannot have duplicate EMA rate" 181 | ema = ModelEMA(self.model, rates={"": rate}) 182 | path = self.ema_path(rate) 183 | if os.path.exists(path): 184 | print(f"loading EMA {rate} from checkpoint...") 185 | ema.model = self.model_class().load(path).to(self.device) 186 | res[rate] = ema 187 | return res 188 | 189 | def create_opt(self) -> torch.optim.Optimizer: 190 | opt = AdamW( 191 | self.model.parameters(), 192 | lr=self.args.lr, 193 | weight_decay=self.args.weight_decay, 194 | ) 195 | if os.path.exists(self.opt_path()): 196 | print("loading optimizer from checkpoint...") 197 | opt.load_state_dict(torch.load(self.opt_path(), map_location="cpu")) 198 | return opt 199 | 200 | def frozen_parameters(self) -> Set[nn.Parameter]: 201 | return set() 202 | 203 | def freeze_parameters(self, params: Set[nn.Parameter]): 204 | param_to_idx = {param: idx for idx, param in enumerate(self.model.parameters())} 205 | count = 0 206 | sd = self.opt.state_dict().copy() 207 | for p in params: 208 | self.freeze_parameter(param_to_idx[p], p, sd) 209 | count += p.numel() 210 | if count: 211 | self.opt.load_state_dict(sd) 212 | print(f"frozen parameters: {count}") 213 | 214 | def freeze_parameter( 215 | self, idx: int, param: nn.Parameter, opt_state: Dict[str, Any] 216 | ): 217 | param.requires_grad_(False) 218 | if idx in opt_state["state"]: 219 | # A step has been taken, and the parameter might have some 220 | # momentum built up that we need to cancel out. 
221 | assert opt_state["state"][idx]["exp_avg"].shape == param.shape 222 | opt_state["state"] = opt_state["state"].copy() 223 | opt_state["state"][idx] = opt_state["state"][idx].copy() 224 | opt_state["state"][idx]["exp_avg"].zero_() 225 | opt_state["state"][idx]["exp_avg_sq"].zero_() 226 | 227 | def create_logger_tracker(self) -> Tuple[Logger, LossTracker]: 228 | return Logger(self.log_path(), resume=self.resume), LossTracker() 229 | 230 | def checkpoint_path(self): 231 | return os.path.join(self.args.output_dir, "model.pt") 232 | 233 | def ema_path(self, rate): 234 | return os.path.join(self.args.output_dir, f"model_ema_{rate}.pt") 235 | 236 | def opt_path(self): 237 | return os.path.join(self.args.output_dir, "opt.pt") 238 | 239 | def log_path(self): 240 | return os.path.join(self.args.output_dir, "train_log.txt") 241 | 242 | @abstractmethod 243 | def model_class(self) -> Any: 244 | """ 245 | Get the Savable class used to construct models. 246 | """ 247 | 248 | @abstractmethod 249 | def create_new_model(self) -> Savable: 250 | """ 251 | Create a new instance of the model. 252 | """ 253 | 254 | def load_from_pretrained(self, model: Savable) -> int: 255 | pt = self.model_class().load(self.args.pretrained_path) 256 | return model.load_from_pretrained(pt) 257 | 258 | def write_run_info(self): 259 | filename = f"run_info_{int(time.time())}.json" 260 | with open(os.path.join(self.args.output_dir, filename), "w+") as f: 261 | json.dump(self.run_info(), f, indent=4) 262 | 263 | def run_info(self) -> Dict: 264 | return dict( 265 | args=self.args.__dict__, 266 | command=sys.argv[0], 267 | start_steps=self.total_steps, 268 | ) 269 | 270 | @classmethod 271 | def arg_parser(cls) -> argparse.ArgumentParser: 272 | """ 273 | Get an argument parser for the training command. 274 | """ 275 | parser = argparse.ArgumentParser( 276 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 277 | ) 278 | parser.add_argument("--lr", default=1e-4, type=float) 279 | parser.add_argument("--ema-rate", default="0.9999", type=str) 280 | parser.add_argument("--weight-decay", default=0.0, type=float) 281 | parser.add_argument("--batch-size", default=8, type=int) 282 | parser.add_argument("--microbatch", default=None, type=int) 283 | parser.add_argument("--output-dir", default=cls.default_output_dir(), type=str) 284 | parser.add_argument("--pretrained-path", default=None, type=str) 285 | parser.add_argument("--save-interval", default=1000, type=int) 286 | parser.add_argument("--grad-checkpoint", action="store_true") 287 | parser.add_argument("--encoding", default="linear", type=str) 288 | parser.add_argument("data_dir", type=str) 289 | return parser 290 | 291 | @classmethod 292 | @abstractmethod 293 | def default_output_dir(cls) -> str: 294 | """ 295 | Get the default directory name for training output. 
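        This value is used as the default for the --output-dir argument.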
296 | """ 297 | 298 | 299 | class DiffusionTrainLoop(TrainLoop): 300 | def compute_losses( 301 | self, data_batch: Dict[str, torch.Tensor] 302 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 303 | audio_seq = data_batch["samples"][:, None].to(self.device) 304 | if self.args.class_cond: 305 | extra_kwargs = dict(labels=data_batch["label"].to(self.device)) 306 | else: 307 | extra_kwargs = dict() 308 | ts = torch.rand(len(audio_seq), device=self.device) 309 | losses = self.model.diffusion.ddpm_losses( 310 | audio_seq, 311 | self.model.predictor.condition( 312 | use_checkpoint=self.args.grad_checkpoint, **extra_kwargs 313 | ), 314 | ts=ts, 315 | ) 316 | return losses, ts, dict() 317 | 318 | def model_class(self) -> Any: 319 | return DiffusionModel 320 | 321 | def create_new_model(self) -> Savable: 322 | return self.model_class()( 323 | pred_name=self.args.predictor, 324 | base_channels=self.args.base_channels, 325 | schedule_name=self.args.schedule, 326 | dropout=self.args.dropout, 327 | num_labels=self.num_labels if self.args.class_cond else None, 328 | ) 329 | 330 | @classmethod 331 | def arg_parser(cls) -> argparse.ArgumentParser: 332 | parser = super().arg_parser() 333 | parser.add_argument("--predictor", default="unet", type=str) 334 | parser.add_argument("--base-channels", default=32, type=int) 335 | parser.add_argument("--dropout", default=0.0, type=float) 336 | parser.add_argument("--schedule", default="exp", type=str) 337 | parser.add_argument("--class-cond", action="store_true") 338 | return parser 339 | 340 | @classmethod 341 | def default_output_dir(cls) -> str: 342 | return "ckpt_diffusion" 343 | 344 | 345 | class VQVAETrainLoop(DiffusionTrainLoop): 346 | def __init__(self, **kwargs): 347 | super().__init__(**kwargs) 348 | if self.args.revival_coeff: 349 | self.vq_loss = ReviveVQLoss( 350 | revival=self.args.revival_coeff, commitment=self.args.commitment_coeff 351 | ) 352 | else: 353 | self.vq_loss = StandardVQLoss(commitment=self.args.commitment_coeff) 354 | 355 | def compute_losses( 356 | self, data_batch: Dict[str, torch.Tensor] 357 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 358 | audio_seq = data_batch["samples"][:, None].to(self.device) 359 | if self.args.class_cond: 360 | extra_kwargs = dict(labels=data_batch["label"].to(self.device)) 361 | else: 362 | extra_kwargs = dict() 363 | losses = self.model.losses( 364 | self.vq_loss, 365 | audio_seq, 366 | jitter=self.args.jitter, 367 | **extra_kwargs, 368 | use_checkpoint=self.args.grad_checkpoint, 369 | ) 370 | return losses["mses"], losses["ts"], dict(vq_loss=losses["vq_loss"]) 371 | 372 | def model_class(self) -> Any: 373 | return VQVAE 374 | 375 | def create_model(self) -> Tuple[Savable, bool]: 376 | model, resume = super().create_model() 377 | model.vq.dead_rate = self.args.dead_rate 378 | return model, resume 379 | 380 | def create_new_model(self) -> Savable: 381 | return self.model_class()( 382 | pred_name=self.args.predictor, 383 | base_channels=self.args.base_channels, 384 | enc_name=self.args.encoder, 385 | cond_mult=self.args.cond_mult, 386 | dictionary_size=self.args.dictionary_size, 387 | schedule_name=self.args.schedule, 388 | dropout=self.args.dropout, 389 | num_labels=self.num_labels if self.args.class_cond else None, 390 | ) 391 | 392 | def frozen_parameters(self) -> Set[nn.Parameter]: 393 | res = set() 394 | if self.args.freeze_encoder: 395 | res.update(self.model.encoder.parameters()) 396 | if self.args.freeze_vq: 397 | res.update(self.model.vq.parameters()) 398 | 
return res 399 | 400 | @classmethod 401 | def arg_parser(cls) -> argparse.ArgumentParser: 402 | parser = super().arg_parser() 403 | parser.add_argument("--encoder", default="unet", type=str) 404 | parser.add_argument("--cond-mult", default=16, type=int) 405 | parser.add_argument("--dictionary-size", default=512, type=int) 406 | parser.add_argument("--freeze-encoder", action="store_true") 407 | parser.add_argument("--freeze-vq", action="store_true") 408 | parser.add_argument("--commitment-coeff", default=0.25, type=float) 409 | parser.add_argument("--revival-coeff", default=0.0, type=float) 410 | parser.add_argument("--dead-rate", default=100, type=int) 411 | parser.add_argument("--jitter", default=0.0, type=float) 412 | return parser 413 | 414 | def load_from_pretrained(self, model: Savable) -> int: 415 | pt, err = None, None 416 | for cls in [self.model_class(), DiffusionModel]: 417 | try: 418 | pt = cls.load(self.args.pretrained_path) 419 | except RuntimeError as exc: 420 | err = exc 421 | if pt is None: 422 | raise err 423 | return model.load_from_pretrained(pt) 424 | 425 | def step_optimizer(self): 426 | super().step_optimizer() 427 | if self.should_revive(): 428 | self.model.vq.revive_dead_entries() 429 | 430 | def should_revive(self) -> bool: 431 | return not self.args.revival_coeff and not self.args.freeze_vq 432 | 433 | @classmethod 434 | def default_output_dir(cls) -> str: 435 | return "ckpt_vqvae" 436 | 437 | 438 | class VQVAEAddClassesTrainLoop(VQVAETrainLoop): 439 | def __init__(self, **kwargs): 440 | # These are set during model load. 441 | self.pretrained_kwargs = None 442 | self.pretrained_num_labels = None 443 | 444 | super().__init__(**kwargs) 445 | assert self.args.class_cond 446 | 447 | def compute_losses( 448 | self, data_batch: Dict[str, torch.Tensor] 449 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 450 | data_batch["label"] = data_batch["label"] + self.pretrained_num_labels 451 | return super().compute_losses(data_batch) 452 | 453 | def create_model(self) -> Tuple[Savable, bool]: 454 | assert self.args.pretrained_path, "must load from a pre-trained VQVAE" 455 | assert self.args.class_cond, "must create a class-conditional model" 456 | pretrained = VQVAE.load(self.args.pretrained_path) 457 | self.pretrained_num_labels = pretrained.num_labels 458 | self.pretrained_kwargs = pretrained.save_kwargs() 459 | 460 | return super().create_model() 461 | 462 | def create_new_model(self) -> Savable: 463 | kwargs = self.pretrained_kwargs.copy() 464 | kwargs["num_labels"] = self.num_labels + self.pretrained_num_labels 465 | return self.model_class()(**kwargs) 466 | 467 | def load_from_pretrained(self, model: Savable) -> int: 468 | base_model = VQVAE.load(self.args.pretrained_path) 469 | base_model.add_labels(self.num_labels) 470 | return model.load_from_pretrained(base_model) 471 | 472 | def frozen_parameters(self) -> Set[nn.Parameter]: 473 | label_params = set(self.model.predictor.label_parameters()) 474 | x = set(x for x in self.model.parameters() if x not in label_params) 475 | return x 476 | 477 | def should_revive(self) -> bool: 478 | # Don't mess with the VQ codebook, since we might not be 479 | # using all of it for the new classes, but still want to 480 | # preserve functionality on the old classes. 481 | return False 482 | 483 | @classmethod 484 | def default_output_dir(cls) -> str: 485 | return "ckpt_vqvae_added" 486 | 487 | 488 | class VQVAEUncondTrainLoop(VQVAETrainLoop): 489 | def __init__(self, **kwargs): 490 | # These are set during model load. 
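        # create_model() populates these from the pre-trained checkpoint
        # before create_new_model() reads them.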
491 | self.pretrained_kwargs = None 492 | self.pretrained_num_labels = None 493 | 494 | super().__init__(**kwargs) 495 | assert self.args.class_cond 496 | 497 | def compute_losses( 498 | self, data_batch: Dict[str, torch.Tensor] 499 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 500 | label_mask = torch.rand(data_batch["label"].shape) > self.args.no_class_prob 501 | labels = (data_batch["label"] + 1) * label_mask 502 | 503 | audio_seq = data_batch["samples"][:, None].to(self.device) 504 | extra_kwargs = dict(labels=labels.to(self.device)) 505 | losses = self.model.losses( 506 | self.vq_loss, 507 | audio_seq, 508 | jitter=self.args.jitter, 509 | **extra_kwargs, 510 | use_checkpoint=self.args.grad_checkpoint, 511 | no_vq_prob=self.args.no_vq_prob, 512 | ) 513 | return losses["mses"], losses["ts"], dict(vq_loss=losses["vq_loss"]) 514 | 515 | def create_model(self) -> Tuple[Savable, bool]: 516 | assert self.args.pretrained_path, "must load from a pre-trained VQVAE" 517 | assert self.args.class_cond, "must create a class-conditional model" 518 | pretrained = VQVAE.load(self.args.pretrained_path) 519 | self.pretrained_num_labels = pretrained.num_labels 520 | self.pretrained_kwargs = pretrained.save_kwargs() 521 | 522 | return super().create_model() 523 | 524 | def create_new_model(self) -> Savable: 525 | kwargs = self.pretrained_kwargs.copy() 526 | kwargs["num_labels"] = self.pretrained_num_labels + 1 527 | return self.model_class()(**kwargs) 528 | 529 | def load_from_pretrained(self, model: Savable) -> int: 530 | base_model = VQVAE.load(self.args.pretrained_path) 531 | base_model.add_labels(1, end=False) 532 | return model.load_from_pretrained(base_model) 533 | 534 | @classmethod 535 | def arg_parser(cls) -> argparse.ArgumentParser: 536 | parser = super().arg_parser() 537 | parser.add_argument("--no-class-prob", default=0.1, type=float) 538 | parser.add_argument("--no-vq-prob", default=0.1, type=float) 539 | return parser 540 | 541 | @classmethod 542 | def default_output_dir(cls) -> str: 543 | return "ckpt_vqvae_uncond" 544 | 545 | 546 | class ClassifierTrainLoop(TrainLoop): 547 | def __init__(self, **kwargs): 548 | super().__init__(**kwargs) 549 | self.diffusion = Diffusion(make_schedule(self.args.schedule)) 550 | 551 | def compute_losses( 552 | self, data_batch: Dict[str, torch.Tensor] 553 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 554 | audio_seq = data_batch["samples"][:, None].to(self.device) 555 | labels = data_batch["label"].to(self.device) 556 | ts = self.sample_timesteps(len(audio_seq)) 557 | 558 | samples = self.diffusion.sample_q(audio_seq, ts) 559 | logits = self.model(samples, ts, use_checkpoint=self.args.grad_checkpoint) 560 | nlls = -F.log_softmax(logits, dim=-1)[range(len(labels)), labels] 561 | return nlls, ts, dict() 562 | 563 | def sample_timesteps(self, n: int) -> torch.Tensor: 564 | ts = torch.rand(n, device=self.device) 565 | if self.total_steps < self.args.curriculum_steps: 566 | frac = self.total_steps / self.args.curriculum_steps 567 | power = self.args.curriculum_start * (1 - frac) + frac 568 | ts = ts ** power 569 | return ts 570 | 571 | def model_class(self) -> Any: 572 | return Classifier 573 | 574 | def create_new_model(self) -> Savable: 575 | return self.model_class()( 576 | num_labels=self.num_labels, base_channels=self.args.base_channels 577 | ) 578 | 579 | def load_from_pretrained(self, model: Savable) -> int: 580 | dm = DiffusionModel.load(self.args.pretrained_path) 581 | return 
model.load_from_predictor(dm.predictor) 582 | 583 | @classmethod 584 | def arg_parser(cls) -> argparse.ArgumentParser: 585 | parser = super().arg_parser() 586 | parser.add_argument("--base-channels", default=32, type=int) 587 | parser.add_argument("--schedule", default="exp", type=str) 588 | parser.add_argument("--curriculum-start", default=30.0, type=float) 589 | parser.add_argument("--curriculum-steps", default=0, type=int) 590 | return parser 591 | 592 | @classmethod 593 | def default_output_dir(cls) -> str: 594 | return "ckpt_classifier" 595 | 596 | 597 | class EncoderPredictorTrainLoop(TrainLoop): 598 | def __init__(self, **kwargs): 599 | self.vq_vae = None 600 | super().__init__(**kwargs) 601 | 602 | def compute_losses( 603 | self, data_batch: Dict[str, torch.Tensor] 604 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 605 | audio_seq = data_batch["samples"][:, None].to(self.device) 606 | ts = self.sample_timesteps(len(audio_seq)) 607 | with torch.no_grad(): 608 | targets = self.vq_vae.encode(audio_seq) 609 | samples = self.vq_vae.diffusion.sample_q(audio_seq, ts) 610 | losses = self.model.losses( 611 | samples, ts, targets, use_checkpoint=self.args.grad_checkpoint 612 | ) 613 | return losses, ts, dict() 614 | 615 | def sample_timesteps(self, n: int) -> torch.Tensor: 616 | ts = torch.rand(n, device=self.device) 617 | if self.total_steps < self.args.curriculum_steps: 618 | frac = self.total_steps / self.args.curriculum_steps 619 | power = self.args.curriculum_start * (1 - frac) + frac 620 | ts = ts ** power 621 | return ts 622 | 623 | def model_class(self) -> Any: 624 | return EncoderPredictor 625 | 626 | def create_model(self) -> Tuple[Savable, bool]: 627 | self.vq_vae = VQVAE.load(self.args.vq_vae_path).to(self.device) 628 | return super().create_model() 629 | 630 | def create_new_model(self) -> Savable: 631 | return self.model_class()( 632 | base_channels=self.args.base_channels, 633 | downsample_rate=self.vq_vae.encoder.downsample_rate, 634 | num_latents=self.vq_vae.dictionary_size, 635 | ) 636 | 637 | @classmethod 638 | def arg_parser(cls) -> argparse.ArgumentParser: 639 | parser = super().arg_parser() 640 | parser.add_argument("--vq-vae-path", type=str, required=True) 641 | parser.add_argument("--base-channels", type=int, default=32) 642 | parser.add_argument("--curriculum-start", default=30.0, type=float) 643 | parser.add_argument("--curriculum-steps", default=0, type=int) 644 | return parser 645 | 646 | @classmethod 647 | def default_output_dir(cls) -> str: 648 | return "ckpt_enc_pred" 649 | -------------------------------------------------------------------------------- /vq_voice_swap/util.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Iterator 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def repeat_dataset(data_loader: Iterable) -> Iterator: 7 | while True: 8 | yield from data_loader 9 | 10 | 11 | def count_params(model: nn.Module) -> int: 12 | return sum(x.numel() for x in model.parameters()) 13 | -------------------------------------------------------------------------------- /vq_voice_swap/vq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the vector quantization step in a VQ-VAE. 
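Typical flow, as a minimal sketch (features stands in for any [N x C x T]
encoder output with C == num_channels):

    vq = VQ(num_channels=64, num_codes=512)
    out = vq(features)
    vq_loss = StandardVQLoss()(features, out["embedded"], vq.dictionary)
    quantized = out["passthrough"]  # straight-through gradients to features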
3 | 4 | Code adapted from: https://github.com/unixpickle/vq-vae-2/blob/6874db74dbc8e7a24c33303c0aa12d66d803c725/vq_vae_2/vq.py 5 | """ 6 | 7 | import random 8 | from abc import abstractmethod 9 | from typing import Callable, Dict, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class VQLoss(nn.Module): 18 | """ 19 | An abstract loss function for a VQ layer. 20 | """ 21 | 22 | @abstractmethod 23 | def forward( 24 | self, inputs: torch.Tensor, embedded: torch.Tensor, dictionary: torch.Tensor 25 | ) -> torch.Tensor: 26 | """ 27 | Compute a VQ loss given the unquantized vectors, the embedded outputs, 28 | and the entire embedding table. 29 | 30 | :param inputs: an [N x C x ...] Tensor of inputs. 31 | :param embedded: a Tensor like `inputs` of embedded vectors. 32 | :param dictionary: a [D x C] dictionary. 33 | """ 34 | 35 | 36 | class StandardVQLoss(VQLoss): 37 | """ 38 | The standard VQ-VAE loss for vector quantization. 39 | """ 40 | 41 | def __init__(self, commitment: float = 0.25): 42 | super().__init__() 43 | self.commitment = commitment 44 | 45 | def forward( 46 | self, inputs: torch.Tensor, embedded: torch.Tensor, dictionary: torch.Tensor 47 | ) -> torch.Tensor: 48 | _ = dictionary 49 | codebook_loss = ((inputs.detach() - embedded) ** 2).mean() 50 | comm_loss = ((inputs - embedded.detach()) ** 2).mean() 51 | return codebook_loss + self.commitment * comm_loss 52 | 53 | 54 | class ReviveVQLoss(StandardVQLoss): 55 | """ 56 | A VQ-VAE loss with an additional term pulling every codebook entry 57 | slightly towards the mean input to prevent dead centers. 58 | """ 59 | 60 | def __init__(self, revival: float, **kwargs): 61 | super().__init__(**kwargs) 62 | self.revival = revival 63 | 64 | def forward( 65 | self, inputs: torch.Tensor, embedded: torch.Tensor, dictionary: torch.Tensor 66 | ) -> torch.Tensor: 67 | loss = super().forward(inputs, embedded, dictionary) 68 | 69 | flat_inputs, _ = flatten_channels(inputs) 70 | distances = embedding_distances(dictionary, flat_inputs) 71 | return loss + self.revival * distances.mean() 72 | 73 | 74 | class VQ(nn.Module): 75 | """ 76 | A vector quantization layer. 77 | 78 | Inputs are Tensors of shape [N x C x ...]. 79 | Outputs include an embedded version of the input Tensor of the same shape, 80 | a quantized, discrete [N x ...] Tensor, and other losses. 81 | 82 | :param num_channels: the depth of the input Tensors. 83 | :param num_codes: the number of codebook entries. 84 | :param dead_rate: the number of forward passes after which a dictionary 85 | entry is considered dead if it has not been used. 86 | """ 87 | 88 | def __init__(self, num_channels: int, num_codes: int, dead_rate: int = 100): 89 | super().__init__() 90 | self.num_channels = num_channels 91 | self.num_codes = num_codes 92 | self.dead_rate = dead_rate 93 | 94 | self.dictionary = nn.Parameter(torch.randn(num_codes, num_channels)) 95 | self.register_buffer("usage_count", dead_rate * torch.ones(num_codes).long()) 96 | self._last_batch = None # used for revival 97 | 98 | def embed(self, idxs: torch.Tensor) -> torch.Tensor: 99 | """ 100 | Convert encoded indices into embeddings. 101 | 102 | :param idxs: an [N x ...] Tensor. 103 | :return: an [N x C x ...] Tensor with gradients to the dictionary. 
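        A shape-only sketch with illustrative sizes:

            vq = VQ(num_channels=4, num_codes=8)
            out = vq(torch.randn(2, 4, 16))  # out["idxs"]: [2 x 16]
            emb = vq.embed(out["idxs"])      # emb: [2 x 4 x 16]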
104 | """ 105 | batch_size = idxs.shape[0] 106 | new_shape = (batch_size, self.num_channels, *idxs.shape[1:]) 107 | idxs = idxs.view(batch_size, -1) 108 | embedded = F.embedding(idxs, self.dictionary) 109 | embedded = embedded.permute(0, 2, 1).reshape(new_shape) 110 | return embedded 111 | 112 | def forward(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]: 113 | """ 114 | Apply vector quantization. 115 | 116 | If the module is in training mode, this will also update the usage 117 | tracker and re-initialize dead dictionary entries. 118 | 119 | :param inputs: an [N x C x ...] Tensor. 120 | :return: a dict containing the following keys: 121 | - "embedded": a Tensor like inputs whose gradients flow to 122 | the dictionary. 123 | - "passthrough": a Tensor like inputs whose gradients are 124 | passed through to inputs. 125 | - "idxs": an [N x ...] integer Tensor of code indices. 126 | """ 127 | idxs_shape = (inputs.shape[0], *inputs.shape[2:]) 128 | x, unflatten_fn = flatten_channels(inputs) 129 | 130 | diffs = embedding_distances(self.dictionary, x) 131 | idxs = torch.argmin(diffs, dim=-1) 132 | embedded = self.embed(idxs) 133 | passthrough = embedded.detach() + (x - x.detach()) 134 | 135 | if self.training: 136 | self._update_tracker(idxs) 137 | self._last_batch = x.detach() 138 | 139 | return { 140 | "embedded": unflatten_fn(embedded), 141 | "passthrough": unflatten_fn(passthrough), 142 | "idxs": idxs.reshape(idxs_shape), 143 | } 144 | 145 | def revive_dead_entries(self): 146 | """ 147 | Use the dictionary usage tracker to re-initialize entries that aren't 148 | being used often. 149 | 150 | Uses statistics from the previous call to forward() to revive centers. 151 | Thus, forward() must have been called at least once. 152 | """ 153 | assert ( 154 | self._last_batch is not None 155 | ), "cannot revive dead entries until a batch has been run" 156 | inputs = self._last_batch 157 | 158 | counts = self.usage_count.detach().cpu().numpy() 159 | new_dictionary = None 160 | inputs_numpy = None 161 | input_probs = None 162 | for i, count in enumerate(counts): 163 | if count: 164 | continue 165 | if new_dictionary is None: 166 | new_dictionary = self.dictionary.detach().cpu().numpy() 167 | if inputs_numpy is None: 168 | inputs_numpy = inputs.detach().cpu().numpy() 169 | # K-means++ init: probabilities proportional to dist^2. 
170 | input_probs = ( 171 | embedding_distances(self.dictionary, inputs) 172 | .min(-1)[0] 173 | .detach() 174 | .clamp(min=0) 175 | .cpu() 176 | .numpy() 177 | ) 178 | input_probs /= np.sum(input_probs) 179 | new_dictionary[i] = inputs_numpy[ 180 | np.random.choice(len(input_probs), p=input_probs) 181 | ] 182 | counts[i] = self.dead_rate 183 | if new_dictionary is not None: 184 | dict_tensor = torch.from_numpy(new_dictionary).to(self.dictionary) 185 | counts_tensor = torch.from_numpy(counts).to(self.usage_count) 186 | with torch.no_grad(): 187 | self.dictionary.copy_(dict_tensor) 188 | self.usage_count.copy_(counts_tensor) 189 | 190 | def _update_tracker(self, idxs): 191 | raw_idxs = set(idxs.detach().cpu().numpy().flatten()) 192 | update = -np.ones([self.num_codes], dtype=np.int) 193 | for idx in raw_idxs: 194 | update[idx] = self.dead_rate 195 | self.usage_count.add_(torch.from_numpy(update).to(self.usage_count)) 196 | self.usage_count.clamp_(0, self.dead_rate) 197 | 198 | 199 | def embedding_distances(dictionary: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor: 200 | """ 201 | Compute distances between every embedding in a 202 | dictionary and every vector in a Tensor. 203 | 204 | This will not generate a huge intermediate Tensor, 205 | unlike the naive implementation. 206 | 207 | :param dictionary: a [D x C] Tensor. 208 | :param tensor: an [N x C] Tensor. 209 | 210 | :return: an [N x D] Tensor of distances. 211 | """ 212 | dict_norms = torch.sum(torch.pow(dictionary, 2), dim=-1) 213 | tensor_norms = torch.sum(torch.pow(tensor, 2), dim=-1) 214 | 215 | # Work-around for https://github.com/pytorch/pytorch/issues/18862. 216 | exp_tensor = tensor[..., None].view(-1, tensor.shape[-1], 1) 217 | exp_dict = dictionary[None].expand(exp_tensor.shape[0], *dictionary.shape) 218 | dots = torch.bmm(exp_dict, exp_tensor)[..., 0] 219 | dots = dots.view(*tensor.shape[:-1], dots.shape[-1]) 220 | 221 | return -2 * dots + dict_norms + tensor_norms[..., None] 222 | 223 | 224 | def flatten_channels( 225 | x: torch.Tensor, 226 | ) -> Tuple[torch.Tensor, Callable[[torch.Tensor], torch.Tensor]]: 227 | """ 228 | Turn an [N x C x ...] Tensor into a [B x C] Tensor. 229 | 230 | :return: a tuple (new_tensor, reverse_fn). The reverse_fn can be applied 231 | to a [B x C] Tensor to get back an [N x C x ...] Tensor. 232 | """ 233 | in_shape = x.shape 234 | batch, channels = in_shape[:2] 235 | x = x.view(batch, channels, -1) 236 | x = x.permute(0, 2, 1) 237 | permuted_shape = x.shape 238 | x = x.reshape(-1, channels) 239 | 240 | def reverse_fn(y): 241 | return y.reshape(permuted_shape).permute(0, 2, 1).reshape(in_shape) 242 | 243 | return x, reverse_fn 244 | -------------------------------------------------------------------------------- /vq_voice_swap/vq_vae.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | 5 | from .diffusion_model import DiffusionModel 6 | from .models import EncoderPredictor, make_encoder 7 | from .vq import VQ, VQLoss 8 | 9 | 10 | class VQVAE(DiffusionModel): 11 | """ 12 | A waveform VQ-VAE with a diffusion decoder. 
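    The encoder maps a waveform to a downsampled latent sequence, the VQ
    layer discretizes it, and the diffusion predictor denoises audio
    conditioned on the quantized latents (and, optionally, a speaker label).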
13 | """ 14 | 15 | def __init__( 16 | self, 17 | base_channels: int, 18 | enc_name: str = "unet", 19 | cond_mult: int = 16, 20 | dictionary_size: int = 512, 21 | **kwargs, 22 | ): 23 | encoder = make_encoder( 24 | enc_name=enc_name, base_channels=base_channels, cond_mult=cond_mult 25 | ) 26 | kwargs["cond_channels"] = base_channels * cond_mult 27 | super().__init__(base_channels=base_channels, **kwargs) 28 | self.enc_name = enc_name 29 | self.cond_mult = cond_mult 30 | self.dictionary_size = dictionary_size 31 | self.encoder = encoder 32 | self.vq = VQ(self.cond_channels, dictionary_size) 33 | 34 | def losses( 35 | self, 36 | vq_loss: VQLoss, 37 | inputs: torch.Tensor, 38 | labels: Optional[torch.Tensor] = None, 39 | jitter: float = 0.0, 40 | no_vq_prob: float = 0.0, 41 | **extra_kwargs: Any, 42 | ) -> Dict[str, torch.Tensor]: 43 | """ 44 | Compute losses for training the VQVAE. 45 | 46 | :param vq_loss: the vector-quantization loss function. 47 | :param inputs: the input [N x 1 x T] audio Tensor. 48 | :param labels: an [N] Tensor of integer labels. 49 | :param jitter: jitter regularization to use. 50 | :param no_vq_prob: probability of dropping VQ codes per sequence. 51 | :return: a dict containing the following keys: 52 | - "vq_loss": loss for the vector quantization layer. 53 | - "mse": mean loss for all batch elements. 54 | - "ts": a 1-D float tensor of the timesteps per batch entry. 55 | - "mses": a 1-D tensor of the mean MSE losses per batch entry. 56 | """ 57 | encoder_out = self.encoder(inputs, **extra_kwargs) 58 | if jitter: 59 | encoder_out = jitter_seq(encoder_out, jitter) 60 | vq_out = self.vq(encoder_out) 61 | vq_loss = vq_loss(encoder_out, vq_out["embedded"], self.vq.dictionary) 62 | 63 | ts = torch.rand(inputs.shape[0]).to(inputs) 64 | epsilon = torch.randn_like(inputs) 65 | noised_inputs = self.diffusion.sample_q(inputs, ts, epsilon=epsilon) 66 | cond = vq_out["passthrough"] 67 | 68 | if no_vq_prob: 69 | cond_mask = (torch.rand(len(cond)) > no_vq_prob).to(cond) 70 | while len(cond_mask.shape) < len(cond.shape): 71 | cond_mask = cond_mask[..., None] 72 | cond = cond * cond_mask 73 | 74 | predictions = self.predictor( 75 | noised_inputs, ts, cond=cond, labels=labels, **extra_kwargs 76 | ) 77 | mses = ((predictions - epsilon) ** 2).flatten(1).mean(1) 78 | mse = mses.mean() 79 | 80 | return {"vq_loss": vq_loss, "mse": mse, "ts": ts, "mses": mses} 81 | 82 | def encode(self, inputs: torch.Tensor) -> torch.Tensor: 83 | """ 84 | Encode a waveform as discrete symbols. 85 | 86 | :param inputs: an [N x 1 x T] audio Tensor. 87 | :return: an [N x T1] Tensor of latent codes. 88 | """ 89 | with torch.no_grad(): 90 | return self.vq(self.encoder(inputs))["idxs"] 91 | 92 | def decode( 93 | self, 94 | codes: torch.Tensor, 95 | labels: Optional[torch.Tensor] = None, 96 | steps: int = 100, 97 | progress: bool = False, 98 | constrain: bool = False, 99 | enc_pred: Optional[EncoderPredictor] = None, 100 | enc_pred_scale: float = 1.0, 101 | **kwargs, 102 | ) -> torch.Tensor: 103 | """ 104 | Sample the decoder using encoded audio and corresponding labels. 105 | 106 | :param codes: an [N x T1] Tensor of latent codes or an [N x C x T1] 107 | Tensor of latent code embeddings. 108 | :param labels: an [N] Tensor of integer labels. 109 | :param steps: number of diffusion steps. 110 | :param progress: if True, show a progress bar with tqdm. 111 | :param constrain: if True, clamp x_start predictions. 112 | :param enc_pred: an encoder predictor for guidance. 
113 | :param enc_pred_scale: the scale for guidance. 114 | :return: an [N x 1 x T] Tensor of audio. 115 | """ 116 | if len(codes.shape) == 2: 117 | cond_seq = self.vq.embed(codes) 118 | elif len(codes.shape) == 3: 119 | cond_seq = codes 120 | else: 121 | raise ValueError(f"unsupported codes shape: {codes.shape}") 122 | 123 | targets = self.vq(cond_seq)["idxs"] 124 | 125 | def cond_fn(x, ts): 126 | with torch.enable_grad(): 127 | x_grad = x.detach().clone().requires_grad_(True) 128 | losses = enc_pred.losses(x_grad, ts, targets) * targets.shape[-1] 129 | grads = torch.autograd.grad(losses.sum(), x_grad)[0] 130 | return grads * enc_pred_scale * -1 131 | 132 | x_T = torch.randn( 133 | codes.shape[0], 1, codes.shape[-1] * self.encoder.downsample_rate 134 | ).to(codes.device) 135 | return self.diffusion.ddpm_sample( 136 | x_T, 137 | lambda xs, ts, **kwargs: self.predictor( 138 | xs, ts, cond=cond_seq, labels=labels, **kwargs 139 | ), 140 | steps=steps, 141 | progress=progress, 142 | constrain=constrain, 143 | cond_fn=cond_fn if enc_pred is not None else None, 144 | **kwargs, 145 | ) 146 | 147 | def decode_uncond_guidance( 148 | self, 149 | codes: torch.Tensor, 150 | labels: Optional[torch.Tensor] = None, 151 | steps: int = 100, 152 | progress: bool = False, 153 | constrain: bool = False, 154 | label_scale: float = 0.0, 155 | vq_scale: float = 0.0, 156 | **kwargs, 157 | ) -> torch.Tensor: 158 | """ 159 | Sample the decoder using unconditional guidance towards encoded audio 160 | and corresponding labels. 161 | 162 | :param codes: an [N x T1] Tensor of latent codes or an [N x C x T1] 163 | Tensor of latent code embeddings. 164 | :param labels: an [N] Tensor of integer labels, which have not been 165 | offset for the unconditional label. 166 | :param steps: number of diffusion steps. 167 | :param progress: if True, show a progress bar with tqdm. 168 | :param constrain: if True, clamp x_start predictions. 169 | :param label_scale: guidance scale for labels (0.0 does no guidance). 170 | :param vq_scale: guidance scale for VQ codes (0.0 does no guidance). 171 | :return: an [N x 1 x T] Tensor of audio. 
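        The returned prediction is base + vq_scale * (base -
        pred_without_codes) + label_scale * (base - pred_without_label),
        where "without codes" zeroes the conditioning sequence and "without
        label" substitutes the unconditional label.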
172 | """ 173 | if len(codes.shape) == 2: 174 | cond_seq = self.vq.embed(codes) 175 | elif len(codes.shape) == 3: 176 | cond_seq = codes 177 | else: 178 | raise ValueError(f"unsupported codes shape: {codes.shape}") 179 | 180 | x_T = torch.randn( 181 | codes.shape[0], 1, codes.shape[-1] * self.encoder.downsample_rate 182 | ).to(codes.device) 183 | 184 | def pred_fn(xs, ts, **kwargs): 185 | xs = torch.cat([xs] * 3, dim=0) 186 | ts = torch.cat([ts] * 3, dim=0) 187 | kwargs = {k: torch.cat([v] * 3, dim=0) for k, v in kwargs.items()} 188 | 189 | cond_batch = cond_seq 190 | label_batch = labels + 1 191 | 192 | if vq_scale: 193 | cond_batch = torch.cat([cond_batch, torch.zeros_like(cond_seq)], dim=0) 194 | if label_batch is not None: 195 | label_batch = torch.cat([label_batch, labels + 1], dim=0) 196 | if labels is not None and label_scale: 197 | cond_batch = torch.cat([cond_batch, cond_seq], dim=0) 198 | label_batch = torch.cat([label_batch, torch.zeros_like(labels)], dim=0) 199 | outs = self.predictor(xs, ts, cond=cond_batch, labels=label_batch, **kwargs) 200 | 201 | base_pred = outs[: len(cond_seq)] 202 | outs = outs[len(cond_seq) :] 203 | pred = base_pred 204 | 205 | for flag, scale in [(True, vq_scale), (labels is not None, label_scale)]: 206 | if flag and scale: 207 | sub_out = outs[: len(cond_seq)] 208 | outs = outs[len(cond_seq) :] 209 | pred = pred + scale * (base_pred - sub_out) 210 | 211 | return pred 212 | 213 | return self.diffusion.ddpm_sample( 214 | x_T, 215 | pred_fn, 216 | steps=steps, 217 | progress=progress, 218 | constrain=constrain, 219 | **kwargs, 220 | ) 221 | 222 | @property 223 | def downsample_rate(self) -> int: 224 | """ 225 | Get the minimum divisor required for input sequences. 226 | """ 227 | # Naive lowest common multiple. 228 | x, y = super().downsample_rate, self.encoder.downsample_rate 229 | return next(i for i in range(x * y) if i % x == 0 and i % y == 0) 230 | 231 | def save_kwargs(self) -> Dict[str, Any]: 232 | res = super().save_kwargs() 233 | res.update( 234 | dict( 235 | enc_name=self.enc_name, 236 | cond_mult=self.cond_mult, 237 | dictionary_size=self.dictionary_size, 238 | ) 239 | ) 240 | return res 241 | 242 | 243 | def jitter_seq(seq: torch.Tensor, p: float) -> torch.Tensor: 244 | """ 245 | Apply temporal jitter to a latent sequence. 246 | 247 | This regularization technique was proposed in 248 | https://arxiv.org/abs/1901.08810. 249 | 250 | :param seq: an [N x C x T] Tensor. 251 | :param p: probability of a timestep being replaced. 252 | """ 253 | right_shifted = torch.cat([seq[..., :1], seq[..., :-1]], dim=-1) 254 | left_shifted = torch.cat([seq[..., 1:], seq[..., -1:]], dim=-1) 255 | nums = torch.rand(seq.shape[0], 1, seq.shape[-1]).to(seq.device) 256 | 257 | return torch.where( 258 | nums < p / 2, 259 | right_shifted, 260 | torch.where(nums < p, left_shifted, seq), 261 | ) 262 | --------------------------------------------------------------------------------