├── .gitignore
├── README.md
├── config.json
├── dataset.py
├── dataset
│   ├── 1
│   │   ├── 1.code.pt
│   │   ├── 1.wav
│   │   ├── 1.wav.f0.npy
│   │   └── 1.wav.soft.pt
│   └── 2
│       ├── 2.code.pt
│       ├── 2.wav
│       ├── 2.wav.f0.npy
│       └── 2.wav.soft.pt
├── demo.ipynb
├── hubert
│   └── put_pretrained_model_here
├── infer.py
├── inference
│   ├── infer_tool.py
│   └── slicer.py
├── logs
│   ├── tts
│   │   └── tts_logs_here
│   └── vc
│       └── vc_logs_here
├── model.py
├── modules
│   └── commons.py
├── nsf_hifigan
│   ├── env.py
│   ├── models.py
│   └── utils.py
├── operations.py
├── output
│   └── output_here
├── parametrizations.py
├── parametrize.py
├── preprocess.py
├── raw
│   └── input_here
├── sampler
│   ├── dpm_solver.py
│   └── uni_pc.py
├── test.py
├── train.py
├── unet1d
│   ├── __init__.py
│   ├── activations.py
│   ├── attention.py
│   ├── attention_processor.py
│   ├── dual_transformer_1d.py
│   ├── embeddings.py
│   ├── lora.py
│   ├── outputs.py
│   ├── resnet.py
│   ├── transformer_1d.py
│   ├── unet_1d_blocks.py
│   └── unet_1d_condition.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.pyc
2 | **/*.wav
3 | settings.json
4 | **/*.npy
5 | **/*.pt
6 | **/events.*
7 | */*.json
8 | dataset/*
9 | dataset_processed/*
10 | logs/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # NS2VC_v2
3 |
4 | ## Unofficial implementation of NaturalSpeech2 for Voice Conversion
5 | Unlike the original NaturalSpeech2, this implementation uses Vocos instead of EnCodec as the vocoder for better quality, and uses ContentVec in place of the text embedding and the duration/span process (a sketch of the log-mel front-end used with Vocos is given below).
6 | I also adopted the unet1d conditional model from the diffusers library; thanks for their hard work.
7 | ![ns2_v2](https://github.com/adelacvg/NS2VC/assets/27419496/f62c48ad-874f-4165-805d-90b55d481dcd)
8 |
9 | ### About zero-shot generalization
10 | I made many attempts to improve the generalization of the model and found that it behaves much like Stable Diffusion: if a "tag" is not in your training set, you cannot expect a promising result. A larger dataset with more speakers gives better generalization and better results. The model does produce good results for speakers seen in the training set.
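For reference, because Vocos consumes log-mel spectrograms rather than EnCodec codes, the reference/target features are 100-bin log-mels at 24 kHz. Below is a minimal sketch of that front-end, using the settings found in `inference/infer_tool.py`; the random waveform is only a placeholder.

```python
# Log-mel front-end matching the settings used in inference/infer_tool.py.
import torch
import torchaudio

spec_process = torchaudio.transforms.MelSpectrogram(
    sample_rate=24000,   # model sampling rate (config.json: data.sampling_rate)
    n_fft=1024,
    hop_length=256,      # config.json: data.hop_length
    n_mels=100,
    center=True,
    power=1,
)
wav24k = torch.randn(1, 24000)                # placeholder: 1 s of audio at 24 kHz
spec = spec_process(wav24k)                   # shape (1, 100, T)
spec = torch.log(torch.clip(spec, min=1e-7))  # log compression with clipping
```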
11 | ### Demo 12 | | refer | input | output | 13 | | :----| :---- | :---- | 14 | |[refer0.webm](https://github.com/adelacvg/NS2VC/assets/27419496/abed2fdc-8366-4522-bbc7-646e0ae6b842)| [gt0.webm](https://github.com/adelacvg/NS2VC/assets/27419496/327794b0-e550-4932-8075-4be09e063d45)| [gen0.webm](https://github.com/adelacvg/NS2VC/assets/27419496/3defcd4a-6843-464c-a903-285a14751096)| 15 | |[refer1.webm](https://github.com/adelacvg/NS2VC/assets/27419496/3d924019-0a68-41a5-aeaf-928a9b8fa8b5)| [gt1.webm](https://github.com/adelacvg/NS2VC/assets/27419496/12fc1514-0edb-493d-a07f-3c94b0548557)| [gen1.webm](https://github.com/adelacvg/NS2VC/assets/27419496/f38e8780-1baf-48b5-b6e5-0ba3856599e2)| 16 | |[refer2.webm](https://github.com/adelacvg/NS2VC/assets/27419496/9759088b-10e7-4bb1-a0ed-c808e11b9f9e)|[gt2.webm](https://github.com/adelacvg/NS2VC/assets/27419496/ddff8bfc-7c6a-4d53-9b98-0d66c421d1d1)|[gen2.webm](https://github.com/adelacvg/NS2VC/assets/27419496/d72cb17d-6813-4d87-8ec5-929b2cc2fb15)| 17 | |[refer3.webm](https://github.com/adelacvg/NS2VC/assets/27419496/c9e045ac-914c-4b49-a112-c71acce2eb27)|[gt3.webm](https://github.com/adelacvg/NS2VC/assets/27419496/a684e11d-32fe-46e3-87e0-e0c6047a24dc)|[gen3.webm](https://github.com/adelacvg/NS2VC/assets/27419496/df3ceced-bfae-4272-a8d7-94a49826f04a)| 18 | |[refer4.webm](https://github.com/adelacvg/NS2VC/assets/27419496/e3191a18-44fc-477e-9ed4-60c42ad35b80)|[gt4.webm](https://github.com/adelacvg/NS2VC/assets/27419496/318a0843-89a5-46de-b1e2-2039a457bc17)|[gen4.webm](https://github.com/adelacvg/NS2VC/assets/27419496/06487dab-f047-4461-9e5c-4bd53bfdfd56)| 19 | 20 | 21 | 22 | 23 | ### Data preprocessing 24 | First of all, you need to download the contentvec model and put it under the hubert folder. 25 | The model can be download from here. 26 | 27 | The dataset structure can be like this: 28 | 29 | ``` 30 | dataset 31 | ├── spk1 32 | │ ├── 1.wav 33 | │ ├── 2.wav 34 | │ ├── ... 35 | │ └── spk11 36 | │ ├── 11.wav 37 | ├── 3.wav 38 | ├── 4.wav 39 | ``` 40 | 41 | Overall, you can put the data in any way you like. 42 | 43 | Put the data with .wav extension under the dataset folder, and then run the following command to preprocess the data. 44 | 45 | ```python 46 | python preprocess.py 47 | ``` 48 | 49 | The preprocessed data will be saved under the processed_dataset folder. 50 | 51 | ## Requirements 52 | 53 | You can install the requirements by running the following command. 54 | 55 | ```python 56 | pip install vocos accelerate matplotlib librosa unidecode inflect ema_pytorch tensorboard fairseq praat-parselmouth pyworld 57 | ``` 58 | 59 | ### Training 60 | Install the accelerate first, run `accelerate config` to configure the environment, and then run the following command to train the model. 61 | 62 | ```python 63 | accelerate launch train.py 64 | ``` 65 | 66 | ### Inference 67 | 68 | Change the device, model_path, clean_names and refer_names in the inference.py, and then run the following command to inference the model. 69 | 70 | ```python 71 | python infer.py 72 | ``` 73 | ### Continue training 74 | If you want to fine tune or continue to train a model. 75 | Add 76 | ```python 77 | trainer.load('your_model_path') 78 | ``` 79 | to the `train.py`. 80 | ### Pretrained model 81 | Maybe comming soon, if I had enough data for a good model. 82 | 83 | ### TTS 84 | 85 | If you want to use the TTS model, please check the TTS branch. 86 | 87 | ### Q&A 88 | 89 | qq group:801645314 90 | You can add the qq group to discuss the project. 
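As a quick reference, the inference settings described above are also exposed as command-line flags by `infer.py`, so instead of editing the file you can pass them directly. A typical run might look like the following (the model path and the wav files under `raw/` are placeholders):

```python
python infer.py -m logs/vc/your_run/model-679.pt -c config.json -n la.wav -r keli.wav -t 0 -d cuda:0
```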
91 | 92 | Thanks to sovits4, naturalspeech2 and imagen diffusersfor their great works. 93 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | 4 | "train_batch_size":32, 5 | "gradient_accumulate_every": 1, 6 | "train_lr": 0.0001, 7 | "train_num_steps": 1000000, 8 | "ema_update_every": 10, 9 | "ema_decay": 0.995, 10 | "adam_betas": [0.9, 0.99], 11 | "save_and_sample_every":1000, 12 | "timesteps":1000, 13 | "sampling_timesteps":1000, 14 | "results_folder": "results", 15 | "logs_folder" : "logs/vc", 16 | "num_workers": 32, 17 | "eps": 1e-09, 18 | "keep_ckpts": 3, 19 | "all_in_mem": false 20 | }, 21 | "data": { 22 | "training_files": "../vc_dataset_processed", 23 | "val_files": "../val_dataset_processed", 24 | "sampling_rate": 24000, 25 | "hop_length": 256 26 | }, 27 | "phoneme_encoder":{ 28 | "in_channels":256, 29 | "hidden_channels":256, 30 | "out_channels":256, 31 | "n_layers":6, 32 | "p_dropout":0.2 33 | }, 34 | "f0_predictor": 35 | { 36 | "in_channels":256, 37 | "hidden_channels":256, 38 | "out_channels":1, 39 | "attention_layers":10, 40 | "n_heads":8, 41 | "p_dropout":0.5 42 | }, 43 | "prompt_encoder":{ 44 | "in_channels":100, 45 | "hidden_channels":256, 46 | "out_channels":256, 47 | "n_layers":6, 48 | "p_dropout":0.2 49 | }, 50 | "diffusion_encoder":{ 51 | "in_channels":100, 52 | "out_channels":100, 53 | "hidden_channels":256, 54 | "n_heads":8, 55 | "p_dropout":0.2 56 | } 57 | } -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | import torchaudio 8 | import utils 9 | import torchaudio.transforms as T 10 | import random 11 | 12 | 13 | """Multi speaker version""" 14 | 15 | class TestDataset(torch.utils.data.Dataset): 16 | def __init__(self, audio_path, cfg, codec, all_in_mem: bool = False): 17 | self.audiopaths = glob(os.path.join(audio_path, "**/*.wav"), recursive=True) 18 | self.sampling_rate = cfg['data']['sampling_rate'] 19 | self.hop_length = cfg['data']['hop_length'] 20 | random.shuffle(self.audiopaths) 21 | self.all_in_mem = all_in_mem 22 | if self.all_in_mem: 23 | self.cache = [self.get_audio(p[0]) for p in self.audiopaths] 24 | 25 | def get_audio(self, filename): 26 | audio, sampling_rate = torchaudio.load(filename) 27 | audio = T.Resample(sampling_rate, self.sampling_rate)(audio) 28 | 29 | spec = torch.load(filename.replace(".wav", ".spec.pt")).squeeze(0) 30 | 31 | f0 = np.load(filename + ".f0.npy") 32 | f0, uv = utils.interpolate_f0(f0) 33 | f0 = torch.FloatTensor(f0) 34 | uv = torch.FloatTensor(uv) 35 | 36 | c = torch.load(filename+ ".soft.pt") 37 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0]) 38 | 39 | lmin = min(c.size(-1), spec.size(-1)) 40 | assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename) 41 | assert abs(audio.shape[1]-lmin * self.hop_length) < 3 * self.hop_length 42 | spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin] 43 | audio = audio[:, :lmin * self.hop_length] 44 | return c.detach(), f0.detach(), spec.detach(), audio.detach(), uv.detach() 45 | 46 | def __getitem__(self, index): 47 | return *self.get_audio(self.audiopaths[index]), *self.get_audio(self.audiopaths[(index+4)%self.__len__()]) 48 | 49 | def 
__len__(self): 50 | return len(self.audiopaths) 51 | 52 | 53 | class NS2VCDataset(torch.utils.data.Dataset): 54 | """ 55 | 1) loads audio, speaker_id, text pairs 56 | 2) normalizes text and converts them to sequences of integers 57 | 3) computes spectrograms from audio files. 58 | """ 59 | 60 | def __init__(self, audio_path, cfg, codec, all_in_mem: bool = False): 61 | self.audiopaths = glob(os.path.join(audio_path, "**/*.wav"), recursive=True) 62 | self.sampling_rate = cfg['data']['sampling_rate'] 63 | self.hop_length = cfg['data']['hop_length'] 64 | # self.codec = codec 65 | 66 | # random.seed(1234) 67 | random.shuffle(self.audiopaths) 68 | 69 | self.all_in_mem = all_in_mem 70 | if self.all_in_mem: 71 | self.cache = [self.get_audio(p[0]) for p in self.audiopaths] 72 | 73 | def get_audio(self, filename): 74 | audio, sampling_rate = torchaudio.load(filename) 75 | audio = T.Resample(sampling_rate, self.sampling_rate)(audio) 76 | 77 | spec = torch.load(filename.replace(".wav", ".spec.pt")).squeeze(0) 78 | 79 | f0 = np.load(filename + ".f0.npy") 80 | f0, uv = utils.interpolate_f0(f0) 81 | f0 = torch.FloatTensor(f0) 82 | uv = torch.FloatTensor(uv) 83 | 84 | c = torch.load(filename+ ".soft.pt") 85 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0]) 86 | 87 | lmin = min(c.size(-1), spec.size(-1)) 88 | assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename) 89 | assert abs(audio.shape[1]-lmin * self.hop_length) < 3 * self.hop_length 90 | spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin] 91 | audio = audio[:, :lmin * self.hop_length] 92 | return c.detach(), f0.detach(), spec.detach(), audio.detach(), uv.detach() 93 | 94 | def random_slice(self, c, f0, spec, audio, uv): 95 | if spec.shape[1] < 30: 96 | print("skip too short audio") 97 | return None 98 | if spec.shape[1] > 400: 99 | start = random.randint(0, spec.shape[1]-400) 100 | end = start + 400 101 | spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end] 102 | audio = audio[:, start * self.hop_length : end * self.hop_length] 103 | len_spec = spec.shape[1] 104 | l = random.randint(int(len_spec//3), int(len_spec//3*2)) 105 | u = random.randint(0, len_spec-l) 106 | v = u + l 107 | refer = spec[:, u:v] 108 | c = torch.cat([c[:, :u], c[:, v:]], dim=-1) 109 | f0 = torch.cat([f0[:u], f0[v:]], dim=-1) 110 | spec = torch.cat([spec[:, :u], spec[:, v:]], dim=-1) 111 | uv = torch.cat([uv[:u], uv[v:]], dim=-1) 112 | audio = torch.cat([audio[:, :u * self.hop_length], audio[:, v * self.hop_length:]], dim=-1) 113 | assert c.shape[1] != 0 114 | assert refer.shape[1] != 0 115 | return refer, c, f0, spec, audio, uv 116 | 117 | def __getitem__(self, index): 118 | if self.all_in_mem: 119 | return self.random_slice(*self.cache[index]) 120 | else: 121 | return self.random_slice(*self.get_audio(self.audiopaths[index])) 122 | # print(1) 123 | 124 | def __len__(self): 125 | return len(self.audiopaths) 126 | 127 | 128 | class TextAudioCollate: 129 | 130 | def __call__(self, batch): 131 | hop_length = 320 132 | batch = [b for b in batch if b is not None] 133 | 134 | input_lengths, ids_sorted_decreasing = torch.sort( 135 | torch.LongTensor([x[0].shape[1] for x in batch]), 136 | dim=0, descending=True) 137 | 138 | # refer, c, f0, spec, audio, uv 139 | max_refer_len = max([x[0].size(1) for x in batch]) 140 | max_c_len = max([x[1].size(1) for x in batch]) 141 | max_wav_len = max([x[4].size(1) for x in batch]) 142 | 143 | lengths = torch.LongTensor(len(batch)) 144 | refer_lengths = 
torch.LongTensor(len(batch)) 145 | 146 | contentvec_dim = batch[0][1].shape[0] 147 | spec_dim = batch[0][3].shape[0] 148 | c_padded = torch.FloatTensor(len(batch), contentvec_dim, max_c_len+1) 149 | f0_padded = torch.FloatTensor(len(batch), max_c_len+1) 150 | spec_padded = torch.FloatTensor(len(batch), spec_dim, max_c_len+1) 151 | refer_padded = torch.FloatTensor(len(batch), spec_dim, max_refer_len+1) 152 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len+1) 153 | uv_padded = torch.FloatTensor(len(batch), max_c_len+1) 154 | 155 | c_padded.zero_() 156 | spec_padded.zero_() 157 | refer_padded.zero_() 158 | f0_padded.zero_() 159 | wav_padded.zero_() 160 | uv_padded.zero_() 161 | 162 | for i in range(len(ids_sorted_decreasing)): 163 | row = batch[ids_sorted_decreasing[i]] 164 | 165 | # refer, c, f0, spec, audio, uv 166 | len_refer = row[0].size(1) 167 | len_contentvec = row[1].size(1) 168 | len_wav = row[4].size(1) 169 | 170 | lengths[i] = len_contentvec 171 | refer_lengths[i] = len_refer 172 | 173 | refer_padded[i, :, :len_refer] = row[0][:] 174 | c_padded[i, :, :len_contentvec] = row[1][:] 175 | f0_padded[i, :len_contentvec] = row[2][:] 176 | spec_padded[i, :, :len_contentvec] = row[3][:] 177 | wav_padded[i, :, :len_wav] = row[4][:] 178 | uv_padded[i, :len_contentvec] = row[5][:] 179 | 180 | return c_padded, refer_padded, f0_padded, spec_padded, wav_padded, lengths, refer_lengths, uv_padded 181 | -------------------------------------------------------------------------------- /dataset/1/1.code.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.code.pt -------------------------------------------------------------------------------- /dataset/1/1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.wav -------------------------------------------------------------------------------- /dataset/1/1.wav.f0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.wav.f0.npy -------------------------------------------------------------------------------- /dataset/1/1.wav.soft.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.wav.soft.pt -------------------------------------------------------------------------------- /dataset/2/2.code.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.code.pt -------------------------------------------------------------------------------- /dataset/2/2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.wav -------------------------------------------------------------------------------- /dataset/2/2.wav.f0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.wav.f0.npy 
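Referring back to `dataset.py` and `config.json` above, here is a minimal sketch of wiring the dataset and collate function into a training `DataLoader`; the data path and batch settings are read from `config.json`, and `codec=None` is an assumption (the argument is unused inside the class).

```python
# Sketch: build a training DataLoader from NS2VCDataset + TextAudioCollate.
import json
import torch
from dataset import NS2VCDataset, TextAudioCollate

cfg = json.load(open("config.json"))
dataset = NS2VCDataset("../vc_dataset_processed", cfg, codec=None)  # codec is not used internally
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=cfg["train"]["train_batch_size"],
    num_workers=cfg["train"]["num_workers"],
    collate_fn=TextAudioCollate(),   # zero-pads variable-length items
    shuffle=True,
)
# Each batch unpacks as: c, refer, f0, spec, wav, lengths, refer_lengths, uv
c, refer, f0, spec, wav, lengths, refer_lengths, uv = next(iter(loader))
```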
-------------------------------------------------------------------------------- /dataset/2/2.wav.soft.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.wav.soft.pt -------------------------------------------------------------------------------- /hubert/put_pretrained_model_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/hubert/put_pretrained_model_here -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import io 2 | import time 3 | from pathlib import Path 4 | 5 | import librosa 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import soundfile 9 | 10 | from inference import infer_tool 11 | from inference import slicer 12 | from inference.infer_tool import Svc 13 | 14 | def main(): 15 | import argparse 16 | 17 | parser = argparse.ArgumentParser(description='ns2vc inference') 18 | 19 | # Required 20 | parser.add_argument('-m', '--model_path', type=str, default="logs/vc/2023-10-01-17-47-21/model-679.pt", 21 | help='Path to the model.') 22 | parser.add_argument('-c', '--config_path', type=str, default="config.json", 23 | help='Path to the configuration file.') 24 | parser.add_argument('-r', '--refer_names', type=str, nargs='+', default=["keli.wav"], 25 | help='Reference audio path.') 26 | parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["la.wav"], 27 | help='A list of wav file names located in the raw folder.') 28 | parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0], 29 | help='Pitch adjustment, supports positive and negative (semitone) values.') 30 | 31 | # Optional 32 | parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=True, 33 | help='Automatic pitch prediction for voice conversion. Do not enable this when converting songs as it can cause serious pitch issues.') 34 | parser.add_argument('-cl', '--clip', type=float, default=0, 35 | help='Voice forced slicing. Set to 0 to turn off(default), duration in seconds.') 36 | parser.add_argument('-lg', '--linear_gradient', type=float, default=0, 37 | help='The cross fade length of two audio slices in seconds. If there is a discontinuous voice after forced slicing, you can adjust this value. Otherwise, it is recommended to use. Default 0.') 38 | parser.add_argument('-fmp', '--f0_mean_pooling', action='store_true', default=False, 39 | help='Apply mean filter (pooling) to f0, which may improve some hoarse sounds. Enabling this option will reduce inference speed.') 40 | 41 | # generally keep default 42 | parser.add_argument('-sd', '--slice_db', type=int, default=-40, 43 | help='Loudness for automatic slicing. For noisy audio it can be set to -30') 44 | parser.add_argument('-d', '--device', type=str, default='cuda:2', 45 | help='Device used for inference. None means auto selecting.') 46 | parser.add_argument('-p', '--pad_seconds', type=float, default=0.5, 47 | help='Due to unknown reasons, there may be abnormal noise at the beginning and end. 
It will disappear after padding a short silent segment.') 48 | parser.add_argument('-wf', '--wav_format', type=str, default='wav', 49 | help='output format') 50 | parser.add_argument('-lgr', '--linear_gradient_retain', type=float, default=0.75, 51 | help='Proportion of cross length retention, range (0-1]. After forced slicing, the beginning and end of each segment need to be discarded.') 52 | parser.add_argument('-ft', '--f0_filter_threshold', type=float, default=0.05, 53 | help='F0 Filtering threshold: This parameter is valid only when f0_mean_pooling is enabled. Values range from 0 to 1. Reducing this value reduces the probability of being out of tune, but increases matte.') 54 | 55 | 56 | args = parser.parse_args() 57 | 58 | clean_names = args.clean_names 59 | refer_names = args.refer_names 60 | trans = args.trans 61 | slice_db = args.slice_db 62 | wav_format = args.wav_format 63 | auto_predict_f0 = args.auto_predict_f0 64 | pad_seconds = args.pad_seconds 65 | clip = args.clip 66 | lg = args.linear_gradient 67 | lgr = args.linear_gradient_retain 68 | F0_mean_pooling = args.f0_mean_pooling 69 | cr_threshold = args.f0_filter_threshold 70 | 71 | svc_model = Svc(args.model_path, args.config_path, args.device) 72 | raw_folder = "raw" 73 | results_folder = "output" 74 | infer_tool.mkdir([raw_folder, results_folder]) 75 | 76 | infer_tool.fill_a_to_b(trans, clean_names) 77 | for clean_name, tran in zip(clean_names, trans): 78 | raw_audio_path = f"{raw_folder}/{clean_name}" 79 | if "." not in raw_audio_path: 80 | raw_audio_path += ".wav" 81 | infer_tool.format_wav(raw_audio_path) 82 | wav_path = Path(raw_audio_path).with_suffix('.wav') 83 | chunks = slicer.cut(wav_path, db_thresh=slice_db) 84 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks) 85 | per_size = int(clip*audio_sr) 86 | lg_size = int(lg*audio_sr) 87 | lg_size_r = int(lg_size*lgr) 88 | lg_size_c_l = (lg_size-lg_size_r)//2 89 | lg_size_c_r = lg_size-lg_size_r-lg_size_c_l 90 | lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0 91 | 92 | for refer_name in refer_names: 93 | audio = [] 94 | refer_path = f"{raw_folder}/{refer_name}" 95 | if "." 
not in refer_path: 96 | refer_path += ".wav" 97 | infer_tool.format_wav(refer_path) 98 | refer_path = Path(refer_path).with_suffix('.wav') 99 | for (slice_tag, data) in audio_data: 100 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======') 101 | 102 | length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample)) 103 | if slice_tag: 104 | print('jump empty segment') 105 | _audio = np.zeros(length) 106 | audio.extend(list(infer_tool.pad_array(_audio, length))) 107 | continue 108 | if per_size != 0: 109 | datas = infer_tool.split_list_by_n(data, per_size,lg_size) 110 | else: 111 | datas = [data] 112 | # print(len(datas)) 113 | for k,dat in enumerate(datas): 114 | per_length = int(np.ceil(len(dat) / audio_sr * svc_model.target_sample)) if clip!=0 else length 115 | if clip!=0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======') 116 | # padd 117 | pad_len = int(audio_sr * pad_seconds) 118 | dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) 119 | raw_path = io.BytesIO() 120 | soundfile.write(raw_path, dat, audio_sr, format="wav") 121 | raw_path.seek(0) 122 | out_audio, out_sr = svc_model.infer(tran, raw_path, refer_path, 123 | auto_predict_f0=auto_predict_f0, 124 | F0_mean_pooling = F0_mean_pooling, 125 | cr_threshold = cr_threshold 126 | ) 127 | # print(1) 128 | # print(out_audio.shape) 129 | _audio = out_audio.cpu().numpy() 130 | pad_len = int(svc_model.target_sample * pad_seconds) 131 | _audio = _audio[pad_len:-pad_len] 132 | _audio = infer_tool.pad_array(_audio, per_length) 133 | if lg_size!=0 and k!=0: 134 | lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr != 1 else audio[-lg_size:] 135 | lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr != 1 else _audio[0:lg_size] 136 | lg_pre = lg1*(1-lg)+lg2*lg 137 | audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr != 1 else audio[0:-lg_size] 138 | audio.extend(lg_pre) 139 | _audio = _audio[lg_size_c_l+lg_size_r:] if lgr != 1 else _audio[lg_size:] 140 | audio.extend(list(_audio)) 141 | # print(1) 142 | key = "auto" if auto_predict_f0 else f"{tran}key" 143 | res_path = f'./{results_folder}/{clean_name}_{key}_{refer_name}.{wav_format}' 144 | soundfile.write(res_path, audio, svc_model.target_sample, format=wav_format) 145 | svc_model.clear_empty() 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /inference/infer_tool.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import json 4 | import logging 5 | import os 6 | import time 7 | from pathlib import Path 8 | from inference import slicer 9 | import gc 10 | import librosa 11 | import numpy as np 12 | # import onnxruntime 13 | import soundfile 14 | import torch 15 | import torchaudio 16 | from vocos import Vocos 17 | import torchaudio.transforms as T 18 | 19 | from accelerate import Accelerator 20 | import utils 21 | from model import NaturalSpeech2, Trainer 22 | 23 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 24 | def load_mod(model_path, device, cfg): 25 | data = torch.load(model_path, map_location=device) 26 | model = NaturalSpeech2(cfg=cfg) 27 | model.load_state_dict(data['model']) 28 | model.to(device) 29 | return model 30 | 31 | def read_temp(file_name): 32 | if not os.path.exists(file_name): 33 | with open(file_name, "w") as f: 34 | f.write(json.dumps({"info": "temp_dict"})) 35 | return {} 36 | else: 37 | try: 38 | with open(file_name, "r") as f: 
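            # Note: the JSON cache read below is pruned once the file exceeds 50 MB;
            # entries whose "time" field is older than 14 days are dropped before returning.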
39 | data = f.read() 40 | data_dict = json.loads(data) 41 | if os.path.getsize(file_name) > 50 * 1024 * 1024: 42 | f_name = file_name.replace("\\", "/").split("/")[-1] 43 | print(f"clean {f_name}") 44 | for wav_hash in list(data_dict.keys()): 45 | if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600: 46 | del data_dict[wav_hash] 47 | except Exception as e: 48 | print(e) 49 | print(f"{file_name} error,auto rebuild file") 50 | data_dict = {"info": "temp_dict"} 51 | return data_dict 52 | 53 | 54 | def write_temp(file_name, data): 55 | with open(file_name, "w") as f: 56 | f.write(json.dumps(data)) 57 | 58 | 59 | def timeit(func): 60 | def run(*args, **kwargs): 61 | t = time.time() 62 | res = func(*args, **kwargs) 63 | print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t)) 64 | return res 65 | 66 | return run 67 | 68 | 69 | def format_wav(audio_path): 70 | if Path(audio_path).suffix == '.wav': 71 | return 72 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None) 73 | soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate) 74 | 75 | 76 | def get_end_file(dir_path, end): 77 | file_lists = [] 78 | for root, dirs, files in os.walk(dir_path): 79 | files = [f for f in files if f[0] != '.'] 80 | dirs[:] = [d for d in dirs if d[0] != '.'] 81 | for f_file in files: 82 | if f_file.endswith(end): 83 | file_lists.append(os.path.join(root, f_file).replace("\\", "/")) 84 | return file_lists 85 | 86 | 87 | def get_md5(content): 88 | return hashlib.new("md5", content).hexdigest() 89 | 90 | def fill_a_to_b(a, b): 91 | if len(a) < len(b): 92 | for _ in range(0, len(b) - len(a)): 93 | a.append(a[0]) 94 | 95 | def mkdir(paths: list): 96 | for path in paths: 97 | if not os.path.exists(path): 98 | os.mkdir(path) 99 | 100 | def pad_array(arr, target_length): 101 | current_length = arr.shape[0] 102 | if current_length >= target_length: 103 | return arr 104 | else: 105 | pad_width = target_length - current_length 106 | pad_left = pad_width // 2 107 | pad_right = pad_width - pad_left 108 | padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0)) 109 | return padded_arr 110 | 111 | def split_list_by_n(list_collection, n, pre=0): 112 | for i in range(0, len(list_collection), n): 113 | yield list_collection[i-pre if i-pre>=0 else i: i + n] 114 | 115 | 116 | class F0FilterException(Exception): 117 | pass 118 | 119 | class Svc(object): 120 | def __init__(self, model_path, config_path, 121 | device=None, 122 | ): 123 | self.model_path = model_path 124 | if device is None: 125 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 126 | else: 127 | self.dev = torch.device(device) 128 | self.model = None 129 | self.cfg = json.load(open(config_path)) 130 | self.target_sample = self.cfg['data']['sampling_rate'] 131 | self.hop_size = self.cfg['data']['hop_length'] 132 | # load hubert 133 | self.hubert_model = utils.get_hubert_model().to(self.dev) 134 | self.load_model() 135 | self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") 136 | 137 | def load_model(self): 138 | self.model = load_mod(self.model_path, self.dev, self.cfg) 139 | self.model.eval() 140 | 141 | def get_unit_f0_code(self, in_path, tran, refer_path, f0_filter ,F0_mean_pooling,cr_threshold=0.05): 142 | # c, refer, f0, uv, lengths, refer_lengths 143 | wav, sr = librosa.load(in_path, sr=self.target_sample) 144 | 145 | if F0_mean_pooling == True: 146 | f0, uv = utils.compute_f0_uv_torchcrepe(torch.FloatTensor(wav), 
sampling_rate=self.target_sample, hop_length=self.hop_size,device=self.dev,cr_threshold = cr_threshold) 147 | if f0_filter and sum(f0) == 0: 148 | raise F0FilterException("No voice detected") 149 | f0 = torch.FloatTensor(list(f0)) 150 | uv = torch.FloatTensor(list(uv)) 151 | if F0_mean_pooling == False: 152 | f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size) 153 | if f0_filter and sum(f0) == 0: 154 | raise F0FilterException("No voice detected") 155 | f0, uv = utils.interpolate_f0(f0) 156 | f0 = torch.FloatTensor(f0) 157 | uv = torch.FloatTensor(uv) 158 | 159 | f0 = f0 * 2 ** (tran / 12) 160 | f0 = f0.unsqueeze(0).to(self.dev) 161 | uv = uv.unsqueeze(0).to(self.dev) 162 | 163 | wav16k = librosa.resample(wav, orig_sr=self.target_sample, target_sr=16000) 164 | wav16k = torch.from_numpy(wav16k).to(self.dev) 165 | c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k) 166 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1]) 167 | 168 | c = c.unsqueeze(0).to(self.dev) 169 | 170 | refer_wav, sr = torchaudio.load(refer_path) 171 | wav24k = T.Resample(sr, 24000)(refer_wav) 172 | spec_process = torchaudio.transforms.MelSpectrogram( 173 | sample_rate=24000, 174 | n_fft=1024, 175 | hop_length=256, 176 | n_mels=100, 177 | center=True, 178 | power=1, 179 | ) 180 | spec = spec_process(wav24k)# 1 100 T 181 | spec = torch.log(torch.clip(spec, min=1e-7)) 182 | refer = spec.to(self.dev) 183 | 184 | lengths = torch.LongTensor([c.shape[2]]).to(self.dev) 185 | refer_lengths = torch.LongTensor([refer.shape[2]]).to(self.dev) 186 | 187 | return c, refer, f0, uv, lengths, refer_lengths 188 | 189 | def infer(self, tran, 190 | raw_path, 191 | refer_path, 192 | auto_predict_f0=False, 193 | f0_filter=False, 194 | F0_mean_pooling=False, 195 | cr_threshold = 0.05 196 | ): 197 | 198 | c, refer, f0, uv, lengths, refer_lengths = self.get_unit_f0_code(raw_path, tran, refer_path, f0_filter,F0_mean_pooling,cr_threshold=cr_threshold) 199 | with torch.no_grad(): 200 | start = time.time() 201 | audio,mel = self.model.sample(c, refer, f0, uv, lengths, refer_lengths, self.vocos, auto_predict_f0 =auto_predict_f0) 202 | audio = audio[0].detach().cpu() 203 | # print(audio.shape) 204 | use_time = time.time() - start 205 | print("ns2vc use time:{}".format(use_time)) 206 | return audio, audio.shape[-1] 207 | 208 | def clear_empty(self): 209 | # clean up vram 210 | torch.cuda.empty_cache() 211 | 212 | def unload_model(self): 213 | # unload model 214 | self.model = self.model.to("cpu") 215 | del self.model 216 | gc.collect() 217 | 218 | def slice_inference(self, 219 | raw_audio_path, 220 | spk, 221 | tran, 222 | slice_db, 223 | cluster_infer_ratio, 224 | auto_predict_f0, 225 | noice_scale, 226 | pad_seconds=0.5, 227 | clip_seconds=0, 228 | lg_num=0, 229 | lgr_num =0.75, 230 | F0_mean_pooling = False, 231 | enhancer_adaptive_key = 0, 232 | cr_threshold = 0.05 233 | ): 234 | wav_path = raw_audio_path 235 | chunks = slicer.cut(wav_path, db_thresh=slice_db) 236 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks) 237 | per_size = int(clip_seconds*audio_sr) 238 | lg_size = int(lg_num*audio_sr) 239 | lg_size_r = int(lg_size*lgr_num) 240 | lg_size_c_l = (lg_size-lg_size_r)//2 241 | lg_size_c_r = lg_size-lg_size_r-lg_size_c_l 242 | lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0 243 | 244 | audio = [] 245 | for (slice_tag, data) in audio_data: 246 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======') 247 | # padd 248 | length = 
int(np.ceil(len(data) / audio_sr * self.target_sample)) 249 | if slice_tag: 250 | print('jump empty segment') 251 | _audio = np.zeros(length) 252 | audio.extend(list(pad_array(_audio, length))) 253 | continue 254 | if per_size != 0: 255 | datas = split_list_by_n(data, per_size,lg_size) 256 | else: 257 | datas = [data] 258 | for k,dat in enumerate(datas): 259 | per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length 260 | if clip_seconds!=0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======') 261 | # padd 262 | pad_len = int(audio_sr * pad_seconds) 263 | dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])]) 264 | raw_path = io.BytesIO() 265 | soundfile.write(raw_path, dat, audio_sr, format="wav") 266 | raw_path.seek(0) 267 | out_audio, out_sr = self.infer(spk, tran, raw_path, 268 | cluster_infer_ratio=cluster_infer_ratio, 269 | auto_predict_f0=auto_predict_f0, 270 | noice_scale=noice_scale, 271 | F0_mean_pooling = F0_mean_pooling, 272 | enhancer_adaptive_key = enhancer_adaptive_key, 273 | cr_threshold = cr_threshold 274 | ) 275 | _audio = out_audio.cpu().numpy() 276 | pad_len = int(self.target_sample * pad_seconds) 277 | _audio = _audio[pad_len:-pad_len] 278 | _audio = pad_array(_audio, per_length) 279 | if lg_size!=0 and k!=0: 280 | lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:] 281 | lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr_num != 1 else _audio[0:lg_size] 282 | lg_pre = lg1*(1-lg)+lg2*lg 283 | audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size] 284 | audio.extend(lg_pre) 285 | _audio = _audio[lg_size_c_l+lg_size_r:] if lgr_num != 1 else _audio[lg_size:] 286 | audio.extend(list(_audio)) 287 | return np.array(audio) 288 | 289 | class RealTimeVC: 290 | def __init__(self): 291 | self.last_chunk = None 292 | self.last_o = None 293 | self.chunk_len = 16000 # chunk length 294 | self.pre_len = 3840 # cross fade length, multiples of 640 295 | 296 | # Input and output are 1-dimensional numpy waveform arrays 297 | 298 | def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path, 299 | cluster_infer_ratio=0, 300 | auto_predict_f0=False, 301 | noice_scale=0.4, 302 | f0_filter=False): 303 | 304 | import maad 305 | audio, sr = torchaudio.load(input_wav_path) 306 | audio = audio.cpu().numpy()[0] 307 | temp_wav = io.BytesIO() 308 | if self.last_chunk is None: 309 | input_wav_path.seek(0) 310 | 311 | audio, sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path, 312 | cluster_infer_ratio=cluster_infer_ratio, 313 | auto_predict_f0=auto_predict_f0, 314 | noice_scale=noice_scale, 315 | f0_filter=f0_filter) 316 | 317 | audio = audio.cpu().numpy() 318 | self.last_chunk = audio[-self.pre_len:] 319 | self.last_o = audio 320 | return audio[-self.chunk_len:] 321 | else: 322 | audio = np.concatenate([self.last_chunk, audio]) 323 | soundfile.write(temp_wav, audio, sr, format="wav") 324 | temp_wav.seek(0) 325 | 326 | audio, sr = svc_model.infer(speaker_id, f_pitch_change, temp_wav, 327 | cluster_infer_ratio=cluster_infer_ratio, 328 | auto_predict_f0=auto_predict_f0, 329 | noice_scale=noice_scale, 330 | f0_filter=f0_filter) 331 | 332 | audio = audio.cpu().numpy() 333 | ret = maad.util.crossfade(self.last_o, audio, self.pre_len) 334 | self.last_chunk = audio[-self.pre_len:] 335 | self.last_o = audio 336 | return ret[self.chunk_len:2 * self.chunk_len] 337 | 
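For orientation, the `Svc` class defined above can also be driven directly without going through `infer.py`. A minimal sketch follows; the model path, device, and wav paths are placeholders.

```python
# Sketch: single conversion with Svc (paths and device are assumptions).
import soundfile
from inference.infer_tool import Svc

svc = Svc("logs/vc/your_run/model-679.pt", "config.json", device="cuda:0")
audio, _ = svc.infer(
    tran=0,                     # pitch shift in semitones
    raw_path="raw/la.wav",      # source speech (content and f0)
    refer_path="raw/keli.wav",  # reference speech (timbre)
    auto_predict_f0=True,
    F0_mean_pooling=False,
)
soundfile.write("output/la_to_keli.wav", audio.numpy(), svc.target_sample)
```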
-------------------------------------------------------------------------------- /inference/slicer.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | import torchaudio 4 | 5 | 6 | class Slicer: 7 | def __init__(self, 8 | sr: int, 9 | threshold: float = -40., 10 | min_length: int = 5000, 11 | min_interval: int = 300, 12 | hop_size: int = 20, 13 | max_sil_kept: int = 5000): 14 | if not min_length >= min_interval >= hop_size: 15 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') 16 | if not max_sil_kept >= hop_size: 17 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') 18 | min_interval = sr * min_interval / 1000 19 | self.threshold = 10 ** (threshold / 20.) 20 | self.hop_size = round(sr * hop_size / 1000) 21 | self.win_size = min(round(min_interval), 4 * self.hop_size) 22 | self.min_length = round(sr * min_length / 1000 / self.hop_size) 23 | self.min_interval = round(min_interval / self.hop_size) 24 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) 25 | 26 | def _apply_slice(self, waveform, begin, end): 27 | if len(waveform.shape) > 1: 28 | return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] 29 | else: 30 | return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] 31 | 32 | # @timeit 33 | def slice(self, waveform): 34 | if len(waveform.shape) > 1: 35 | samples = librosa.to_mono(waveform) 36 | else: 37 | samples = waveform 38 | if samples.shape[0] <= self.min_length: 39 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} 40 | rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) 41 | sil_tags = [] 42 | silence_start = None 43 | clip_start = 0 44 | for i, rms in enumerate(rms_list): 45 | # Keep looping while frame is silent. 46 | if rms < self.threshold: 47 | # Record start of silent frames. 48 | if silence_start is None: 49 | silence_start = i 50 | continue 51 | # Keep looping while frame is not silent and silence start has not been recorded. 52 | if silence_start is None: 53 | continue 54 | # Clear recorded silence start if interval is not enough or clip is too short 55 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept 56 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length 57 | if not is_leading_silence and not need_slice_middle: 58 | silence_start = None 59 | continue 60 | # Need slicing. Record the range of silent frames to be removed. 
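            # (Note) Three cases follow: a short silent run is cut at its quietest
            # frame; a run up to twice max_sil_kept is cut near both of its edges;
            # a longer run keeps max_sil_kept frames on each side and marks the
            # middle as a removable silent range.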
61 | if i - silence_start <= self.max_sil_kept: 62 | pos = rms_list[silence_start: i + 1].argmin() + silence_start 63 | if silence_start == 0: 64 | sil_tags.append((0, pos)) 65 | else: 66 | sil_tags.append((pos, pos)) 67 | clip_start = pos 68 | elif i - silence_start <= self.max_sil_kept * 2: 69 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() 70 | pos += i - self.max_sil_kept 71 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 72 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 73 | if silence_start == 0: 74 | sil_tags.append((0, pos_r)) 75 | clip_start = pos_r 76 | else: 77 | sil_tags.append((min(pos_l, pos), max(pos_r, pos))) 78 | clip_start = max(pos_r, pos) 79 | else: 80 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 81 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 82 | if silence_start == 0: 83 | sil_tags.append((0, pos_r)) 84 | else: 85 | sil_tags.append((pos_l, pos_r)) 86 | clip_start = pos_r 87 | silence_start = None 88 | # Deal with trailing silence. 89 | total_frames = rms_list.shape[0] 90 | if silence_start is not None and total_frames - silence_start >= self.min_interval: 91 | silence_end = min(total_frames, silence_start + self.max_sil_kept) 92 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start 93 | sil_tags.append((pos, total_frames + 1)) 94 | # Apply and return slices. 95 | if len(sil_tags) == 0: 96 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}} 97 | else: 98 | chunks = [] 99 | # The first segment is not the beginning of the audio. 100 | if sil_tags[0][0]: 101 | chunks.append( 102 | {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"}) 103 | for i in range(0, len(sil_tags)): 104 | # Mark audio segment. Skip the first segment. 105 | if i: 106 | chunks.append({"slice": False, 107 | "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"}) 108 | # Mark all mute segments 109 | chunks.append({"slice": True, 110 | "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"}) 111 | # The last segment is not the end. 
112 | if sil_tags[-1][1] * self.hop_size < len(waveform): 113 | chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"}) 114 | chunk_dict = {} 115 | for i in range(len(chunks)): 116 | chunk_dict[str(i)] = chunks[i] 117 | return chunk_dict 118 | 119 | 120 | def cut(audio_path, db_thresh=-30, min_len=5000): 121 | audio, sr = librosa.load(audio_path, sr=None) 122 | slicer = Slicer( 123 | sr=sr, 124 | threshold=db_thresh, 125 | min_length=min_len 126 | ) 127 | chunks = slicer.slice(audio) 128 | return chunks 129 | 130 | 131 | def chunks2audio(audio_path, chunks): 132 | chunks = dict(chunks) 133 | audio, sr = torchaudio.load(audio_path) 134 | if len(audio.shape) == 2 and audio.shape[1] >= 2: 135 | audio = torch.mean(audio, dim=0).unsqueeze(0) 136 | audio = audio.cpu().numpy()[0] 137 | result = [] 138 | for k, v in chunks.items(): 139 | tag = v["split_time"].split(",") 140 | if tag[0] != tag[1]: 141 | result.append((v["slice"], audio[int(tag[0]):int(tag[1])])) 142 | return result, sr -------------------------------------------------------------------------------- /logs/tts/tts_logs_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/logs/tts/tts_logs_here -------------------------------------------------------------------------------- /logs/vc/vc_logs_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/logs/vc/vc_logs_here -------------------------------------------------------------------------------- /modules/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | def slice_pitch_segments(x, ids_str, segment_size=4): 8 | ret = torch.zeros_like(x[:, :segment_size]) 9 | for i in range(x.size(0)): 10 | idx_str = ids_str[i] 11 | idx_end = idx_str + segment_size 12 | ret[i] = x[i, idx_str:idx_end] 13 | return ret 14 | 15 | def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4): 16 | b, d, t = x.size() 17 | if x_lengths is None: 18 | x_lengths = t 19 | ids_str_max = x_lengths - segment_size + 1 20 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 21 | ret = slice_segments(x, ids_str, segment_size) 22 | ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size) 23 | return ret, ret_pitch, ids_str 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def get_padding(kernel_size, dilation=1): 32 | return int((kernel_size*dilation - dilation)/2) 33 | 34 | 35 | def convert_pad_shape(pad_shape): 36 | l = pad_shape[::-1] 37 | pad_shape = [item for sublist in l for item in sublist] 38 | return pad_shape 39 | 40 | 41 | def intersperse(lst, item): 42 | result = [item] * (len(lst) * 2 + 1) 43 | result[1::2] = lst 44 | return result 45 | 46 | 47 | def kl_divergence(m_p, logs_p, m_q, logs_q): 48 | """KL(P||Q)""" 49 | kl = (logs_q - logs_p) - 0.5 50 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 51 | return kl 52 | 53 | 54 | def rand_gumbel(shape): 55 | """Sample from the Gumbel distribution, protect from overflows.""" 56 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 57 | return -torch.log(-torch.log(uniform_samples)) 58 | 59 | 60 | def rand_gumbel_like(x): 61 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 62 | return g 63 | 64 | 65 | def slice_segments(x, ids_str, segment_size=4): 66 | ret = torch.zeros_like(x[:, :, :segment_size]) 67 | for i in range(x.size(0)): 68 | idx_str = ids_str[i] 69 | idx_end = idx_str + segment_size 70 | ret[i] = x[i, :, idx_str:idx_end] 71 | return ret 72 | 73 | 74 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 75 | b, d, t = x.size() 76 | if x_lengths is None: 77 | x_lengths = t 78 | ids_str_max = x_lengths - segment_size + 1 79 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 80 | ret = slice_segments(x, ids_str, segment_size) 81 | return ret, ids_str 82 | 83 | 84 | def rand_spec_segments(x, x_lengths=None, segment_size=4): 85 | b, d, t = x.size() 86 | if x_lengths is None: 87 | x_lengths = t 88 | ids_str_max = x_lengths - segment_size 89 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 90 | ret = slice_segments(x, ids_str, segment_size) 91 | return ret, ids_str 92 | 93 | 94 | def get_timing_signal_1d( 95 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 96 | position = torch.arange(length, dtype=torch.float) 97 | num_timescales = channels // 2 98 | log_timescale_increment = ( 99 | math.log(float(max_timescale) / float(min_timescale)) / 100 | (num_timescales - 1)) 101 | inv_timescales = min_timescale * torch.exp( 102 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 103 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 104 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 105 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 106 | signal = signal.view(1, channels, length) 107 | return signal 108 | 109 | 110 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 111 | b, channels, length = x.size() 112 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 113 | return x + signal.to(dtype=x.dtype, device=x.device) 114 | 115 | 116 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 117 | b, channels, length = x.size() 118 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 119 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 120 | 121 | 122 | def subsequent_mask(length): 123 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 124 | return mask 125 | 126 | 127 | @torch.jit.script 128 | def fused_add_tanh_sigmoid_multiply(input_a, n_channels): 129 | n_channels_int = n_channels[0] 130 | in_act = input_a 131 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 132 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 133 | # print(t_act.size(), s_act.size()) 134 | acts = t_act * s_act 135 | return acts 136 | 137 | 138 | def convert_pad_shape(pad_shape): 139 | l = pad_shape[::-1] 140 | pad_shape = [item for sublist in l for item in sublist] 141 | return pad_shape 142 | 143 | 144 | def shift_1d(x): 145 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 146 | return x 147 | 148 | 149 | def sequence_mask(length, max_length=None): 150 | if max_length is None: 151 | max_length = length.max() 152 | x = 
torch.arange(max_length, dtype=length.dtype, device=length.device) 153 | return x.unsqueeze(0) < length.unsqueeze(1) 154 | 155 | 156 | def generate_path(duration, mask): 157 | """ 158 | duration: [b, 1, t_x] 159 | mask: [b, 1, t_y, t_x] 160 | """ 161 | device = duration.device 162 | 163 | b, _, t_y, t_x = mask.shape 164 | cum_duration = torch.cumsum(duration, -1) 165 | 166 | cum_duration_flat = cum_duration.view(b * t_x) 167 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 168 | path = path.view(b, t_x, t_y) 169 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 170 | path = path.unsqueeze(1).transpose(2,3) * mask 171 | return path 172 | 173 | 174 | def clip_grad_value_(parameters, clip_value, norm_type=2): 175 | if isinstance(parameters, torch.Tensor): 176 | parameters = [parameters] 177 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 178 | norm_type = float(norm_type) 179 | if clip_value is not None: 180 | clip_value = float(clip_value) 181 | 182 | total_norm = 0 183 | for p in parameters: 184 | param_norm = p.grad.data.norm(norm_type) 185 | total_norm += param_norm.item() ** norm_type 186 | if clip_value is not None: 187 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 188 | total_norm = total_norm ** (1. / norm_type) 189 | return total_norm -------------------------------------------------------------------------------- /nsf_hifigan/env.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | 10 | 11 | def build_env(config, config_name, path): 12 | t_path = os.path.join(path, config_name) 13 | if config != t_path: 14 | os.makedirs(path, exist_ok=True) 15 | shutil.copyfile(config, os.path.join(path, config_name)) -------------------------------------------------------------------------------- /nsf_hifigan/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from .env import AttrDict 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 9 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 10 | from .utils import init_weights, get_padding 11 | 12 | LRELU_SLOPE = 0.1 13 | 14 | 15 | def load_model(model_path, device='cuda'): 16 | h = load_config(model_path) 17 | 18 | generator = Generator(h).to(device) 19 | 20 | cp_dict = torch.load(model_path, map_location=device) 21 | generator.load_state_dict(cp_dict['generator']) 22 | generator.eval() 23 | generator.remove_weight_norm() 24 | del cp_dict 25 | return generator, h 26 | 27 | def load_config(model_path): 28 | config_file = os.path.join(os.path.split(model_path)[0], 'config.json') 29 | with open(config_file) as f: 30 | data = f.read() 31 | 32 | json_config = json.loads(data) 33 | h = AttrDict(json_config) 34 | return h 35 | 36 | 37 | class ResBlock1(torch.nn.Module): 38 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 39 | super(ResBlock1, self).__init__() 40 | self.h = h 41 | self.convs1 = nn.ModuleList([ 42 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 43 | padding=get_padding(kernel_size, dilation[0]))), 44 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 45 | 
padding=get_padding(kernel_size, dilation[1]))), 46 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 47 | padding=get_padding(kernel_size, dilation[2]))) 48 | ]) 49 | self.convs1.apply(init_weights) 50 | 51 | self.convs2 = nn.ModuleList([ 52 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 53 | padding=get_padding(kernel_size, 1))), 54 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 55 | padding=get_padding(kernel_size, 1))), 56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 57 | padding=get_padding(kernel_size, 1))) 58 | ]) 59 | self.convs2.apply(init_weights) 60 | 61 | def forward(self, x): 62 | for c1, c2 in zip(self.convs1, self.convs2): 63 | xt = F.leaky_relu(x, LRELU_SLOPE) 64 | xt = c1(xt) 65 | xt = F.leaky_relu(xt, LRELU_SLOPE) 66 | xt = c2(xt) 67 | x = xt + x 68 | return x 69 | 70 | def remove_weight_norm(self): 71 | for l in self.convs1: 72 | remove_weight_norm(l) 73 | for l in self.convs2: 74 | remove_weight_norm(l) 75 | 76 | 77 | class ResBlock2(torch.nn.Module): 78 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 79 | super(ResBlock2, self).__init__() 80 | self.h = h 81 | self.convs = nn.ModuleList([ 82 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 83 | padding=get_padding(kernel_size, dilation[0]))), 84 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 85 | padding=get_padding(kernel_size, dilation[1]))) 86 | ]) 87 | self.convs.apply(init_weights) 88 | 89 | def forward(self, x): 90 | for c in self.convs: 91 | xt = F.leaky_relu(x, LRELU_SLOPE) 92 | xt = c(xt) 93 | x = xt + x 94 | return x 95 | 96 | def remove_weight_norm(self): 97 | for l in self.convs: 98 | remove_weight_norm(l) 99 | 100 | 101 | class SineGen(torch.nn.Module): 102 | """ Definition of sine generator 103 | SineGen(samp_rate, harmonic_num = 0, 104 | sine_amp = 0.1, noise_std = 0.003, 105 | voiced_threshold = 0, 106 | flag_for_pulse=False) 107 | samp_rate: sampling rate in Hz 108 | harmonic_num: number of harmonic overtones (default 0) 109 | sine_amp: amplitude of sine-wavefrom (default 0.1) 110 | noise_std: std of Gaussian noise (default 0.003) 111 | voiced_thoreshold: F0 threshold for U/V classification (default 0) 112 | flag_for_pulse: this SinGen is used inside PulseGen (default False) 113 | Note: when flag_for_pulse is True, the first time step of a voiced 114 | segment is always sin(np.pi) or cos(0) 115 | """ 116 | 117 | def __init__(self, samp_rate, harmonic_num=0, 118 | sine_amp=0.1, noise_std=0.003, 119 | voiced_threshold=0): 120 | super(SineGen, self).__init__() 121 | self.sine_amp = sine_amp 122 | self.noise_std = noise_std 123 | self.harmonic_num = harmonic_num 124 | self.dim = self.harmonic_num + 1 125 | self.sampling_rate = samp_rate 126 | self.voiced_threshold = voiced_threshold 127 | 128 | def _f02uv(self, f0): 129 | # generate uv signal 130 | uv = torch.ones_like(f0) 131 | uv = uv * (f0 > self.voiced_threshold) 132 | return uv 133 | 134 | @torch.no_grad() 135 | def forward(self, f0, upp): 136 | """ sine_tensor, uv = forward(f0) 137 | input F0: tensor(batchsize=1, length, dim=1) 138 | f0 for unvoiced steps should be 0 139 | output sine_tensor: tensor(batchsize=1, length, dim) 140 | output uv: tensor(batchsize=1, length, 1) 141 | """ 142 | f0 = f0.unsqueeze(-1) 143 | fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))) 144 | rad_values = (fn / self.sampling_rate) % 1 
###%1意味着n_har的乘积无法后处理优化 145 | rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device) 146 | rand_ini[:, 0] = 0 147 | rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini 148 | is_half = rad_values.dtype is not torch.float32 149 | tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 150 | if is_half: 151 | tmp_over_one = tmp_over_one.half() 152 | else: 153 | tmp_over_one = tmp_over_one.float() 154 | tmp_over_one *= upp 155 | tmp_over_one = F.interpolate( 156 | tmp_over_one.transpose(2, 1), scale_factor=upp, 157 | mode='linear', align_corners=True 158 | ).transpose(2, 1) 159 | rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) 160 | tmp_over_one %= 1 161 | tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 162 | cumsum_shift = torch.zeros_like(rad_values) 163 | cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 164 | rad_values = rad_values.double() 165 | cumsum_shift = cumsum_shift.double() 166 | sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) 167 | if is_half: 168 | sine_waves = sine_waves.half() 169 | else: 170 | sine_waves = sine_waves.float() 171 | sine_waves = sine_waves * self.sine_amp 172 | return sine_waves 173 | 174 | 175 | class SourceModuleHnNSF(torch.nn.Module): 176 | """ SourceModule for hn-nsf 177 | SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, 178 | add_noise_std=0.003, voiced_threshod=0) 179 | sampling_rate: sampling_rate in Hz 180 | harmonic_num: number of harmonic above F0 (default: 0) 181 | sine_amp: amplitude of sine source signal (default: 0.1) 182 | add_noise_std: std of additive Gaussian noise (default: 0.003) 183 | note that amplitude of noise in unvoiced is decided 184 | by sine_amp 185 | voiced_threshold: threhold to set U/V given F0 (default: 0) 186 | Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) 187 | F0_sampled (batchsize, length, 1) 188 | Sine_source (batchsize, length, 1) 189 | noise_source (batchsize, length 1) 190 | uv (batchsize, length, 1) 191 | """ 192 | 193 | def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, 194 | add_noise_std=0.003, voiced_threshod=0): 195 | super(SourceModuleHnNSF, self).__init__() 196 | 197 | self.sine_amp = sine_amp 198 | self.noise_std = add_noise_std 199 | 200 | # to produce sine waveforms 201 | self.l_sin_gen = SineGen(sampling_rate, harmonic_num, 202 | sine_amp, add_noise_std, voiced_threshod) 203 | 204 | # to merge source harmonics into a single excitation 205 | self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) 206 | self.l_tanh = torch.nn.Tanh() 207 | 208 | def forward(self, x, upp): 209 | sine_wavs = self.l_sin_gen(x, upp) 210 | sine_merge = self.l_tanh(self.l_linear(sine_wavs)) 211 | return sine_merge 212 | 213 | 214 | class Generator(torch.nn.Module): 215 | def __init__(self, h): 216 | super(Generator, self).__init__() 217 | self.h = h 218 | self.num_kernels = len(h.resblock_kernel_sizes) 219 | self.num_upsamples = len(h.upsample_rates) 220 | self.m_source = SourceModuleHnNSF( 221 | sampling_rate=h.sampling_rate, 222 | harmonic_num=8 223 | ) 224 | self.noise_convs = nn.ModuleList() 225 | self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) 226 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2 227 | 228 | self.ups = nn.ModuleList() 229 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 230 | c_cur = h.upsample_initial_channel // (2 ** (i + 1)) 231 | 
self.ups.append(weight_norm( 232 | ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), 233 | k, u, padding=(k - u) // 2))) 234 | if i + 1 < len(h.upsample_rates): # 235 | stride_f0 = int(np.prod(h.upsample_rates[i + 1:])) 236 | self.noise_convs.append(Conv1d( 237 | 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) 238 | else: 239 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) 240 | self.resblocks = nn.ModuleList() 241 | ch = h.upsample_initial_channel 242 | for i in range(len(self.ups)): 243 | ch //= 2 244 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 245 | self.resblocks.append(resblock(h, ch, k, d)) 246 | 247 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 248 | self.ups.apply(init_weights) 249 | self.conv_post.apply(init_weights) 250 | self.upp = int(np.prod(h.upsample_rates)) 251 | 252 | def forward(self, x, f0): 253 | har_source = self.m_source(f0, self.upp).transpose(1, 2) 254 | x = self.conv_pre(x) 255 | for i in range(self.num_upsamples): 256 | x = F.leaky_relu(x, LRELU_SLOPE) 257 | x = self.ups[i](x) 258 | x_source = self.noise_convs[i](har_source) 259 | x = x + x_source 260 | xs = None 261 | for j in range(self.num_kernels): 262 | if xs is None: 263 | xs = self.resblocks[i * self.num_kernels + j](x) 264 | else: 265 | xs += self.resblocks[i * self.num_kernels + j](x) 266 | x = xs / self.num_kernels 267 | x = F.leaky_relu(x) 268 | x = self.conv_post(x) 269 | x = torch.tanh(x) 270 | 271 | return x 272 | 273 | def remove_weight_norm(self): 274 | print('Removing weight norm...') 275 | for l in self.ups: 276 | remove_weight_norm(l) 277 | for l in self.resblocks: 278 | l.remove_weight_norm() 279 | remove_weight_norm(self.conv_pre) 280 | remove_weight_norm(self.conv_post) 281 | 282 | 283 | class DiscriminatorP(torch.nn.Module): 284 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 285 | super(DiscriminatorP, self).__init__() 286 | self.period = period 287 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 288 | self.convs = nn.ModuleList([ 289 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 290 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 291 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 292 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), 293 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), 294 | ]) 295 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 296 | 297 | def forward(self, x): 298 | fmap = [] 299 | 300 | # 1d to 2d 301 | b, c, t = x.shape 302 | if t % self.period != 0: # pad first 303 | n_pad = self.period - (t % self.period) 304 | x = F.pad(x, (0, n_pad), "reflect") 305 | t = t + n_pad 306 | x = x.view(b, c, t // self.period, self.period) 307 | 308 | for l in self.convs: 309 | x = l(x) 310 | x = F.leaky_relu(x, LRELU_SLOPE) 311 | fmap.append(x) 312 | x = self.conv_post(x) 313 | fmap.append(x) 314 | x = torch.flatten(x, 1, -1) 315 | 316 | return x, fmap 317 | 318 | 319 | class MultiPeriodDiscriminator(torch.nn.Module): 320 | def __init__(self, periods=None): 321 | super(MultiPeriodDiscriminator, self).__init__() 322 | self.periods = periods if periods is not None else [2, 3, 5, 7, 11] 323 | self.discriminators = nn.ModuleList() 324 | for period in self.periods: 325 | 
self.discriminators.append(DiscriminatorP(period)) 326 | 327 | def forward(self, y, y_hat): 328 | y_d_rs = [] 329 | y_d_gs = [] 330 | fmap_rs = [] 331 | fmap_gs = [] 332 | for i, d in enumerate(self.discriminators): 333 | y_d_r, fmap_r = d(y) 334 | y_d_g, fmap_g = d(y_hat) 335 | y_d_rs.append(y_d_r) 336 | fmap_rs.append(fmap_r) 337 | y_d_gs.append(y_d_g) 338 | fmap_gs.append(fmap_g) 339 | 340 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 341 | 342 | 343 | class DiscriminatorS(torch.nn.Module): 344 | def __init__(self, use_spectral_norm=False): 345 | super(DiscriminatorS, self).__init__() 346 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 347 | self.convs = nn.ModuleList([ 348 | norm_f(Conv1d(1, 128, 15, 1, padding=7)), 349 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), 350 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), 351 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), 352 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), 353 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), 354 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 355 | ]) 356 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 357 | 358 | def forward(self, x): 359 | fmap = [] 360 | for l in self.convs: 361 | x = l(x) 362 | x = F.leaky_relu(x, LRELU_SLOPE) 363 | fmap.append(x) 364 | x = self.conv_post(x) 365 | fmap.append(x) 366 | x = torch.flatten(x, 1, -1) 367 | 368 | return x, fmap 369 | 370 | 371 | class MultiScaleDiscriminator(torch.nn.Module): 372 | def __init__(self): 373 | super(MultiScaleDiscriminator, self).__init__() 374 | self.discriminators = nn.ModuleList([ 375 | DiscriminatorS(use_spectral_norm=True), 376 | DiscriminatorS(), 377 | DiscriminatorS(), 378 | ]) 379 | self.meanpools = nn.ModuleList([ 380 | AvgPool1d(4, 2, padding=2), 381 | AvgPool1d(4, 2, padding=2) 382 | ]) 383 | 384 | def forward(self, y, y_hat): 385 | y_d_rs = [] 386 | y_d_gs = [] 387 | fmap_rs = [] 388 | fmap_gs = [] 389 | for i, d in enumerate(self.discriminators): 390 | if i != 0: 391 | y = self.meanpools[i - 1](y) 392 | y_hat = self.meanpools[i - 1](y_hat) 393 | y_d_r, fmap_r = d(y) 394 | y_d_g, fmap_g = d(y_hat) 395 | y_d_rs.append(y_d_r) 396 | fmap_rs.append(fmap_r) 397 | y_d_gs.append(y_d_g) 398 | fmap_gs.append(fmap_g) 399 | 400 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 401 | 402 | 403 | def feature_loss(fmap_r, fmap_g): 404 | loss = 0 405 | for dr, dg in zip(fmap_r, fmap_g): 406 | for rl, gl in zip(dr, dg): 407 | loss += torch.mean(torch.abs(rl - gl)) 408 | 409 | return loss * 2 410 | 411 | 412 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 413 | loss = 0 414 | r_losses = [] 415 | g_losses = [] 416 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 417 | r_loss = torch.mean((1 - dr) ** 2) 418 | g_loss = torch.mean(dg ** 2) 419 | loss += (r_loss + g_loss) 420 | r_losses.append(r_loss.item()) 421 | g_losses.append(g_loss.item()) 422 | 423 | return loss, r_losses, g_losses 424 | 425 | 426 | def generator_loss(disc_outputs): 427 | loss = 0 428 | gen_losses = [] 429 | for dg in disc_outputs: 430 | l = torch.mean((1 - dg) ** 2) 431 | gen_losses.append(l) 432 | loss += l 433 | 434 | return loss, gen_losses -------------------------------------------------------------------------------- /nsf_hifigan/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | matplotlib.use("Agg") 7 | import 
matplotlib.pylab as plt 8 | 9 | 10 | def plot_spectrogram(spectrogram): 11 | fig, ax = plt.subplots(figsize=(10, 2)) 12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 13 | interpolation='none') 14 | plt.colorbar(im, ax=ax) 15 | 16 | fig.canvas.draw() 17 | plt.close() 18 | 19 | return fig 20 | 21 | 22 | def init_weights(m, mean=0.0, std=0.01): 23 | classname = m.__class__.__name__ 24 | if classname.find("Conv") != -1: 25 | m.weight.data.normal_(mean, std) 26 | 27 | 28 | def apply_weight_norm(m): 29 | classname = m.__class__.__name__ 30 | if classname.find("Conv") != -1: 31 | weight_norm(m) 32 | 33 | 34 | def get_padding(kernel_size, dilation=1): 35 | return int((kernel_size*dilation - dilation)/2) 36 | 37 | 38 | def load_checkpoint(filepath, device): 39 | assert os.path.isfile(filepath) 40 | print("Loading '{}'".format(filepath)) 41 | checkpoint_dict = torch.load(filepath, map_location=device) 42 | print("Complete.") 43 | return checkpoint_dict 44 | 45 | 46 | def save_checkpoint(filepath, obj): 47 | print("Saving checkpoint to {}".format(filepath)) 48 | torch.save(obj, filepath) 49 | print("Complete.") 50 | 51 | 52 | def del_old_checkpoints(cp_dir, prefix, n_models=2): 53 | pattern = os.path.join(cp_dir, prefix + '????????') 54 | cp_list = glob.glob(pattern)  # get checkpoint paths 55 | cp_list = sorted(cp_list)  # sort by iteration 56 | if len(cp_list) > n_models:  # if more than n_models models are found 57 | for cp in cp_list[:-n_models]:  # delete the oldest models other than the latest n_models 58 | open(cp, 'w').close()  # empty file contents 59 | os.unlink(cp)  # delete file (move to trash when using Colab) 60 | 61 | 62 | def scan_checkpoint(cp_dir, prefix): 63 | pattern = os.path.join(cp_dir, prefix + '????????') 64 | cp_list = glob.glob(pattern) 65 | if len(cp_list) == 0: 66 | return None 67 | return sorted(cp_list)[-1] 68 | -------------------------------------------------------------------------------- /output/output_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/output/output_here -------------------------------------------------------------------------------- /parametrizations.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | import torch 4 | from torch import Tensor 5 | import parametrize 6 | from torch.nn.modules import Module 7 | import torch.nn.functional as F 8 | 9 | from typing import Optional 10 | 11 | __all__ = ['orthogonal', 'spectral_norm'] 12 | 13 | 14 | def _is_orthogonal(Q, eps=None): 15 | n, k = Q.size(-2), Q.size(-1) 16 | Id = torch.eye(k, dtype=Q.dtype, device=Q.device) 17 | # A reasonable eps, but not too large 18 | eps = 10. * n * torch.finfo(Q.dtype).eps 19 | return torch.allclose(Q.mH @ Q, Id, atol=eps) 20 | 21 | 22 | def _make_orthogonal(A): 23 | """ Assume that A is a tall matrix. 24 | Compute the Q factor s.t. 
A = QR (A may be complex) and diag(R) is real and non-negative 25 | """ 26 | X, tau = torch.geqrf(A) 27 | Q = torch.linalg.householder_product(X, tau) 28 | # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs 29 | Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) 30 | return Q 31 | 32 | 33 | class _OrthMaps(Enum): 34 | matrix_exp = auto() 35 | cayley = auto() 36 | householder = auto() 37 | 38 | 39 | class _Orthogonal(Module): 40 | base: Tensor 41 | 42 | def __init__(self, 43 | weight, 44 | orthogonal_map: _OrthMaps, 45 | *, 46 | use_trivialization=True) -> None: 47 | super().__init__() 48 | 49 | # Note [Householder complex] 50 | # For complex tensors, it is not possible to compute the tensor `tau` necessary for 51 | # linalg.householder_product from the reflectors. 52 | # To see this, note that the reflectors have a shape like: 53 | # 0 0 0 54 | # * 0 0 55 | # * * 0 56 | # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters 57 | # to parametrize the unitary matrices. Saving tau on its own does not work either, because 58 | # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise 59 | # them as independent tensors we would not maintain the constraint 60 | # An equivalent reasoning holds for rectangular matrices 61 | if weight.is_complex() and orthogonal_map == _OrthMaps.householder: 62 | raise ValueError("The householder parametrization does not support complex tensors.") 63 | 64 | self.shape = weight.shape 65 | self.orthogonal_map = orthogonal_map 66 | if use_trivialization: 67 | self.register_buffer("base", None) 68 | 69 | def forward(self, X: torch.Tensor) -> torch.Tensor: 70 | n, k = X.size(-2), X.size(-1) 71 | transposed = n < k 72 | if transposed: 73 | X = X.mT 74 | n, k = k, n 75 | # Here n > k and X is a tall matrix 76 | if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley: 77 | # We just need n x k - k(k-1)/2 parameters 78 | X = X.tril() 79 | if n != k: 80 | # Embed into a square matrix 81 | X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) 82 | A = X - X.mH 83 | # A is skew-symmetric (or skew-hermitian) 84 | if self.orthogonal_map == _OrthMaps.matrix_exp: 85 | Q = torch.matrix_exp(A) 86 | elif self.orthogonal_map == _OrthMaps.cayley: 87 | # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1} 88 | Id = torch.eye(n, dtype=A.dtype, device=A.device) 89 | Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5)) 90 | # Q is now orthogonal (or unitary) of size (..., n, n) 91 | if n != k: 92 | Q = Q[..., :k] 93 | # Q is now the size of the X (albeit perhaps transposed) 94 | else: 95 | # X is real here, as we do not support householder with complex numbers 96 | A = X.tril(diagonal=-1) 97 | tau = 2. / (1. + (A * A).sum(dim=-2)) 98 | Q = torch.linalg.householder_product(A, tau) 99 | # The diagonal of X is 1's and -1's 100 | # We do not want to differentiate through this or update the diagonal of X hence the casting 101 | Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) 102 | 103 | if hasattr(self, "base"): 104 | Q = self.base @ Q 105 | if transposed: 106 | Q = Q.mT 107 | return Q 108 | 109 | @torch.autograd.no_grad() 110 | def right_inverse(self, Q: torch.Tensor) -> torch.Tensor: 111 | if Q.shape != self.shape: 112 | raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. 
" 113 | f"Got a tensor of shape {Q.shape}.") 114 | 115 | Q_init = Q 116 | n, k = Q.size(-2), Q.size(-1) 117 | transpose = n < k 118 | if transpose: 119 | Q = Q.mT 120 | n, k = k, n 121 | 122 | # We always make sure to always copy Q in every path 123 | if not hasattr(self, "base"): 124 | # Note [right_inverse expm cayley] 125 | # If we do not have use_trivialization=True, we just implement the inverse of the forward 126 | # map for the Householder. To see why, think that for the Cayley map, 127 | # we would need to find the matrix X \in R^{n x k} such that: 128 | # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1) 129 | # A = Y - Y.mH 130 | # cayley(A)[:, :k] 131 | # gives the original tensor. It is not clear how to do this. 132 | # Perhaps via some algebraic manipulation involving the QR like that of 133 | # Corollary 2.2 in Edelman, Arias and Smith? 134 | if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp: 135 | raise NotImplementedError("It is not possible to assign to the matrix exponential " 136 | "or the Cayley parametrizations when use_trivialization=False.") 137 | 138 | # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition. 139 | # Here Q is always real because we do not support householder and complex matrices. 140 | # See note [Householder complex] 141 | A, tau = torch.geqrf(Q) 142 | # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could 143 | # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition 144 | # The diagonal of Q is the diagonal of R from the qr decomposition 145 | A.diagonal(dim1=-2, dim2=-1).sign_() 146 | # Equality with zero is ok because LAPACK returns exactly zero when it does not want 147 | # to use a particular reflection 148 | A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1 149 | return A.mT if transpose else A 150 | else: 151 | if n == k: 152 | # We check whether Q is orthogonal 153 | if not _is_orthogonal(Q): 154 | Q = _make_orthogonal(Q) 155 | else: # Is orthogonal 156 | Q = Q.clone() 157 | else: 158 | # Complete Q into a full n x n orthogonal matrix 159 | N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device) 160 | Q = torch.cat([Q, N], dim=-1) 161 | Q = _make_orthogonal(Q) 162 | self.base = Q 163 | 164 | # It is necessary to return the -Id, as we use the diagonal for the 165 | # Householder parametrization. Using -Id makes: 166 | # householder(torch.zeros(m,n)) == torch.eye(m,n) 167 | # Poor man's version of eye_like 168 | neg_Id = torch.zeros_like(Q_init) 169 | neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.) 170 | return neg_Id 171 | 172 | 173 | def orthogonal(module: Module, 174 | name: str = 'weight', 175 | orthogonal_map: Optional[str] = None, 176 | *, 177 | use_trivialization: bool = True) -> Module: 178 | r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices. 179 | 180 | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized 181 | matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as 182 | 183 | .. 
math:: 184 | 185 | \begin{align*} 186 | Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\ 187 | QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n} 188 | \end{align*} 189 | 190 | where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex 191 | and the transpose when :math:`Q` is real-valued, and 192 | :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. 193 | In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n` 194 | and orthonormal rows otherwise. 195 | 196 | If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`. 197 | 198 | The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor: 199 | 200 | - ``"matrix_exp"``/``"cayley"``: 201 | the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_ 202 | :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric 203 | :math:`A` to give an orthogonal matrix. 204 | - ``"householder"``: computes a product of Householder reflectors 205 | (:func:`~torch.linalg.householder_product`). 206 | 207 | ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than 208 | ``"householder"``, but they are slower to compute for very thin or very wide matrices. 209 | 210 | If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework", 211 | where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under 212 | ``module.parametrizations.weight[0].base``. This helps the 213 | convergence of the parametrized layer at the expense of some extra memory use. 214 | See `Trivializations for Gradient-Based Optimization on Manifolds`_ . 215 | 216 | Initial value of :math:`Q`: 217 | If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value 218 | of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case) 219 | and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`). 220 | Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``. 221 | Otherwise, the initial value is the result of the composition of all the registered 222 | parametrizations applied to the original tensor. 223 | 224 | .. note:: 225 | This function is implemented using the parametrization functionality 226 | in :func:`~torch.nn.utils.parametrize.register_parametrization`. 227 | 228 | 229 | .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map 230 | .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501 231 | 232 | Args: 233 | module (nn.Module): module on which to register the parametrization. 234 | name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``. 235 | orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``. 236 | Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise. 237 | use_trivialization (bool, optional): whether to use the dynamic trivialization framework. 238 | Default: ``True``. 
239 | 240 | Returns: 241 | The original module with an orthogonal parametrization registered to the specified 242 | weight 243 | 244 | Example:: 245 | 246 | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) 247 | >>> orth_linear = orthogonal(nn.Linear(20, 40)) 248 | >>> orth_linear 249 | ParametrizedLinear( 250 | in_features=20, out_features=40, bias=True 251 | (parametrizations): ModuleDict( 252 | (weight): ParametrizationList( 253 | (0): _Orthogonal() 254 | ) 255 | ) 256 | ) 257 | >>> # xdoctest: +IGNORE_WANT 258 | >>> Q = orth_linear.weight 259 | >>> torch.dist(Q.T @ Q, torch.eye(20)) 260 | tensor(4.9332e-07) 261 | """ 262 | weight = getattr(module, name, None) 263 | if not isinstance(weight, Tensor): 264 | raise ValueError( 265 | "Module '{}' has no parameter or buffer with name '{}'".format(module, name) 266 | ) 267 | 268 | # We could implement this for 1-dim tensors as the maps on the sphere 269 | # but I believe it'd bite more people than it'd help 270 | if weight.ndim < 2: 271 | raise ValueError("Expected a matrix or batch of matrices. " 272 | f"Got a tensor of {weight.ndim} dimensions.") 273 | 274 | if orthogonal_map is None: 275 | orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder" 276 | 277 | orth_enum = getattr(_OrthMaps, orthogonal_map, None) 278 | if orth_enum is None: 279 | raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". ' 280 | f'Got: {orthogonal_map}') 281 | orth = _Orthogonal(weight, 282 | orth_enum, 283 | use_trivialization=use_trivialization) 284 | parametrize.register_parametrization(module, name, orth, unsafe=True) 285 | return module 286 | 287 | 288 | class _WeightNorm(Module): 289 | def __init__( 290 | self, 291 | dim: int = 0, 292 | ) -> None: 293 | super().__init__() 294 | if dim is None: 295 | dim = -1 296 | self.dim = dim 297 | 298 | def forward(self, weight_g, weight_v): 299 | return torch._weight_norm(weight_v, weight_g, self.dim) 300 | 301 | def right_inverse(self, weight): 302 | # TODO: is the .data necessary? 303 | weight_g = torch.norm_except_dim(weight, 2, self.dim).data 304 | weight_v = weight.data 305 | 306 | return weight_g, weight_v 307 | 308 | 309 | def weight_norm(module: Module, name: str = 'weight', dim: int = 0): 310 | r"""Applies weight normalization to a parameter in the given module. 311 | 312 | .. math:: 313 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} 314 | 315 | Weight normalization is a reparameterization that decouples the magnitude 316 | of a weight tensor from its direction. This replaces the parameter specified 317 | by :attr:`name` with two parameters: one specifying the magnitude 318 | and one specifying the direction. 319 | 320 | By default, with ``dim=0``, the norm is computed independently per output 321 | channel/plane. To compute a norm over the entire weight tensor, use 322 | ``dim=None``. 
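    For example (a rough sketch): for a ``nn.Conv1d`` weight of shape
    ``(out_channels, in_channels, kernel_size)``, ``dim=0`` keeps one magnitude per output
    channel, so ``original0`` has shape ``(out_channels, 1, 1)``, while ``dim=None`` keeps a
    single scalar magnitude shared by the whole tensor.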
323 | 324 | See https://arxiv.org/abs/1602.07868 325 | 326 | Args: 327 | module (Module): containing module 328 | name (str, optional): name of weight parameter 329 | dim (int, optional): dimension over which to compute the norm 330 | 331 | Returns: 332 | The original module with the weight norm hook 333 | 334 | Example:: 335 | 336 | >>> m = weight_norm(nn.Linear(20, 40), name='weight') 337 | >>> m 338 | Linear(in_features=20, out_features=40, bias=True) 339 | >>> m.parametrizations.weight.original0.size() 340 | torch.Size([40, 1]) 341 | >>> m.parametrizations.weight.original1.size() 342 | torch.Size([40, 20]) 343 | 344 | """ 345 | weight = getattr(module, name, None) 346 | if not isinstance(weight, Tensor): 347 | raise ValueError( 348 | "Module '{}' has no parameter or buffer with name '{}'".format(module, name) 349 | ) 350 | 351 | _weight_norm = _WeightNorm(dim) 352 | parametrize.register_parametrization(module, name, _weight_norm, unsafe=True) 353 | 354 | def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 355 | g_key = f"{prefix}{name}_g" 356 | v_key = f"{prefix}{name}_v" 357 | if g_key in state_dict and v_key in state_dict: 358 | original0 = state_dict.pop(g_key) 359 | original1 = state_dict.pop(v_key) 360 | state_dict[f"{prefix}parametrizations.{name}.original0"] = original0 361 | state_dict[f"{prefix}parametrizations.{name}.original1"] = original1 362 | module._register_load_state_dict_pre_hook(_weight_norm_compat_hook) 363 | return module 364 | 365 | 366 | class _SpectralNorm(Module): 367 | def __init__( 368 | self, 369 | weight: torch.Tensor, 370 | n_power_iterations: int = 1, 371 | dim: int = 0, 372 | eps: float = 1e-12 373 | ) -> None: 374 | super().__init__() 375 | ndim = weight.ndim 376 | if dim >= ndim or dim < -ndim: 377 | raise IndexError("Dimension out of range (expected to be in range of " 378 | f"[-{ndim}, {ndim - 1}] but got {dim})") 379 | 380 | if n_power_iterations <= 0: 381 | raise ValueError('Expected n_power_iterations to be positive, but ' 382 | 'got n_power_iterations={}'.format(n_power_iterations)) 383 | self.dim = dim if dim >= 0 else dim + ndim 384 | self.eps = eps 385 | if ndim > 1: 386 | # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward) 387 | self.n_power_iterations = n_power_iterations 388 | weight_mat = self._reshape_weight_to_matrix(weight) 389 | h, w = weight_mat.size() 390 | 391 | u = weight_mat.new_empty(h).normal_(0, 1) 392 | v = weight_mat.new_empty(w).normal_(0, 1) 393 | self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps)) 394 | self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps)) 395 | 396 | # Start with u, v initialized to some reasonable values by performing a number 397 | # of iterations of the power method 398 | self._power_method(weight_mat, 15) 399 | 400 | def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor: 401 | # Precondition 402 | assert weight.ndim > 1 403 | 404 | if self.dim != 0: 405 | # permute dim to front 406 | weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim)) 407 | 408 | return weight.flatten(1) 409 | 410 | @torch.autograd.no_grad() 411 | def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None: 412 | # See original note at torch/nn/utils/spectral_norm.py 413 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are 414 | # updated in power iteration **in-place**. 
This is very important 415 | # because in `DataParallel` forward, the vectors (being buffers) are 416 | # broadcast from the parallelized module to each module replica, 417 | # which is a new module object created on the fly. And each replica 418 | # runs its own spectral norm power iteration. So simply assigning 419 | # the updated vectors to the module this function runs on will cause 420 | # the update to be lost forever. And the next time the parallelized 421 | # module is replicated, the same randomly initialized vectors are 422 | # broadcast and used! 423 | # 424 | # Therefore, to make the change propagate back, we rely on two 425 | # important behaviors (also enforced via tests): 426 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor 427 | # is already on correct device; and it makes sure that the 428 | # parallelized module is already on `device[0]`. 429 | # 2. If the out tensor in `out=` kwarg has correct shape, it will 430 | # just fill in the values. 431 | # Therefore, since the same power iteration is performed on all 432 | # devices, simply updating the tensors in-place will make sure that 433 | # the module replica on `device[0]` will update the _u vector on the 434 | # parallelized module (by shared storage). 435 | # 436 | # However, after we update `u` and `v` in-place, we need to **clone** 437 | # them before using them to normalize the weight. This is to support 438 | # backproping through two forward passes, e.g., the common pattern in 439 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will 440 | # complain that variables needed to do backward for the first forward 441 | # (i.e., the `u` and `v` vectors) are changed in the second forward. 442 | 443 | # Precondition 444 | assert weight_mat.ndim > 1 445 | 446 | for _ in range(n_power_iterations): 447 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 448 | # are the first left and right singular vectors. 449 | # This power iteration produces approximations of `u` and `v`. 450 | self._u = F.normalize(torch.mv(weight_mat, self._v), # type: ignore[has-type] 451 | dim=0, eps=self.eps, out=self._u) # type: ignore[has-type] 452 | self._v = F.normalize(torch.mv(weight_mat.t(), self._u), 453 | dim=0, eps=self.eps, out=self._v) # type: ignore[has-type] 454 | 455 | def forward(self, weight: torch.Tensor) -> torch.Tensor: 456 | if weight.ndim == 1: 457 | # Faster and more exact path, no need to approximate anything 458 | return F.normalize(weight, dim=0, eps=self.eps) 459 | else: 460 | weight_mat = self._reshape_weight_to_matrix(weight) 461 | if self.training: 462 | self._power_method(weight_mat, self.n_power_iterations) 463 | # See above on why we need to clone 464 | u = self._u.clone(memory_format=torch.contiguous_format) 465 | v = self._v.clone(memory_format=torch.contiguous_format) 466 | # The proper way of computing this should be through F.bilinear, but 467 | # it seems to have some efficiency issues: 468 | # https://github.com/pytorch/pytorch/issues/58093 469 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 470 | return weight / sigma 471 | 472 | def right_inverse(self, value: torch.Tensor) -> torch.Tensor: 473 | # we may want to assert here that the passed value already 474 | # satisfies constraints 475 | return value 476 | 477 | 478 | def spectral_norm(module: Module, 479 | name: str = 'weight', 480 | n_power_iterations: int = 1, 481 | eps: float = 1e-12, 482 | dim: Optional[int] = None) -> Module: 483 | r"""Applies spectral normalization to a parameter in the given module. 
484 | 485 | .. math:: 486 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, 487 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 488 | 489 | When applied on a vector, it simplifies to 490 | 491 | .. math:: 492 | \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2} 493 | 494 | Spectral normalization stabilizes the training of discriminators (critics) 495 | in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant 496 | of the model. :math:`\sigma` is approximated performing one iteration of the 497 | `power method`_ every time the weight is accessed. If the dimension of the 498 | weight tensor is greater than 2, it is reshaped to 2D in power iteration 499 | method to get spectral norm. 500 | 501 | 502 | See `Spectral Normalization for Generative Adversarial Networks`_ . 503 | 504 | .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration 505 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 506 | 507 | .. note:: 508 | This function is implemented using the parametrization functionality 509 | in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a 510 | reimplementation of :func:`torch.nn.utils.spectral_norm`. 511 | 512 | .. note:: 513 | When this constraint is registered, the singular vectors associated to the largest 514 | singular value are estimated rather than sampled at random. These are then updated 515 | performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor 516 | is accessed with the module on `training` mode. 517 | 518 | .. note:: 519 | If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`, 520 | is in training mode on removal, it will perform another power iteration. 521 | If you'd like to avoid this iteration, set the module to eval mode 522 | before its removal. 523 | 524 | Args: 525 | module (nn.Module): containing module 526 | name (str, optional): name of weight parameter. Default: ``"weight"``. 527 | n_power_iterations (int, optional): number of power iterations to 528 | calculate spectral norm. Default: ``1``. 529 | eps (float, optional): epsilon for numerical stability in 530 | calculating norms. Default: ``1e-12``. 531 | dim (int, optional): dimension corresponding to number of outputs. 
532 | Default: ``0``, except for modules that are instances of 533 | ConvTranspose{1,2,3}d, when it is ``1`` 534 | 535 | Returns: 536 | The original module with a new parametrization registered to the specified 537 | weight 538 | 539 | Example:: 540 | 541 | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) 542 | >>> # xdoctest: +IGNORE_WANT("non-deterministic") 543 | >>> snm = spectral_norm(nn.Linear(20, 40)) 544 | >>> snm 545 | ParametrizedLinear( 546 | in_features=20, out_features=40, bias=True 547 | (parametrizations): ModuleDict( 548 | (weight): ParametrizationList( 549 | (0): _SpectralNorm() 550 | ) 551 | ) 552 | ) 553 | >>> torch.linalg.matrix_norm(snm.weight, 2) 554 | tensor(1.0081, grad_fn=) 555 | """ 556 | weight = getattr(module, name, None) 557 | if not isinstance(weight, Tensor): 558 | raise ValueError( 559 | "Module '{}' has no parameter or buffer with name '{}'".format(module, name) 560 | ) 561 | 562 | if dim is None: 563 | if isinstance(module, (torch.nn.ConvTranspose1d, 564 | torch.nn.ConvTranspose2d, 565 | torch.nn.ConvTranspose3d)): 566 | dim = 1 567 | else: 568 | dim = 0 569 | parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps)) 570 | return module 571 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | import os 4 | import argparse 5 | from random import shuffle 6 | import torchaudio 7 | import torchaudio.transforms as T 8 | 9 | import torch 10 | from glob import glob 11 | from tqdm import tqdm 12 | 13 | # from audiolm_pytorch import SoundStream, EncodecWrapper 14 | import utils 15 | import logging 16 | 17 | logging.getLogger("numba").setLevel(logging.WARNING) 18 | import librosa 19 | import numpy as np 20 | 21 | hps = utils.get_hparams_from_file("config.json") 22 | sampling_rate = hps.data.sampling_rate 23 | hop_length = hps.data.hop_length 24 | in_dir = "" 25 | 26 | def process_one(filename, hmodel): 27 | wav, sr = torchaudio.load(filename) 28 | if wav.shape[0] > 1: # mix to mono 29 | wav = wav.mean(dim=0, keepdim=True) 30 | wav16k = T.Resample(sr, 16000)(wav) 31 | wav24k = T.Resample(sr, 24000)(wav) 32 | filename = filename.replace(in_dir, in_dir+"_processed").replace('.mp3','.wav').replace('.flac','.wav') 33 | wav24k_path = filename 34 | if not os.path.exists(os.path.dirname(wav24k_path)): 35 | os.makedirs(os.path.dirname(wav24k_path)) 36 | torchaudio.save(wav24k_path, wav24k, 24000) 37 | soft_path = filename + ".soft.pt" 38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | wav16k = wav16k.to(device) 40 | c = utils.get_hubert_content(hmodel, wav_16k_tensor=wav16k[0]) 41 | torch.save(c.cpu(), soft_path) 42 | 43 | f0_path = filename + ".f0.npy" 44 | f0 = utils.compute_f0_dio( 45 | wav24k.cpu().numpy()[0], sampling_rate=24000, hop_length=hop_length 46 | ) 47 | np.save(f0_path, f0) 48 | 49 | spec_path = filename.replace(".wav", ".spec.pt") 50 | spec_process = torchaudio.transforms.MelSpectrogram( 51 | sample_rate=24000, 52 | n_fft=1024, 53 | hop_length=256, 54 | n_mels=100, 55 | center=True, 56 | power=1, 57 | ) 58 | spec = spec_process(wav24k)# 1 100 T 59 | spec = torch.log(torch.clip(spec, min=1e-7)) 60 | torch.save(spec, spec_path) 61 | 62 | 63 | def process_batch(filenames): 64 | print("Loading hubert for content...") 65 | device = "cuda" if torch.cuda.is_available() else "cpu" 66 | hmodel = 
utils.get_hubert_model().to(device) 67 | # codec = EncodecWrapper() 68 | print("Loaded hubert.") 69 | for filename in tqdm(filenames): 70 | process_one(filename, hmodel) 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument( 76 | "--in_dir", type=str, default="dataset", help="path to input dir" 77 | ) 78 | 79 | args = parser.parse_args() 80 | filenames = glob(f"{args.in_dir}/**/*.wav", recursive=True)+glob(f"{args.in_dir}/**/*.flac", recursive=True) # [:10] 81 | in_dir = args.in_dir 82 | shuffle(filenames) 83 | process_batch(filenames) 84 | -------------------------------------------------------------------------------- /raw/input_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/raw/input_here -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import json 5 | 6 | import torchaudio 7 | # from model import NaturalSpeech2, F0Predictor, Diffusion_Encoder, encode 8 | from dataset import NS2VCDataset, TextAudioCollate 9 | from torch.utils.data import Dataset, DataLoader 10 | from multiprocessing import cpu_count 11 | import torchaudio.transforms as T 12 | # from model import rvq_ce_loss 13 | 14 | 15 | 16 | # if __name__ == '__main__': 17 | # cfg = json.load(open('config.json')) 18 | 19 | # collate_fn = TextAudioCollate() 20 | # codec = EncodecWrapper() 21 | # ds = NS2VCDataset(cfg, codec) 22 | # dl = DataLoader(ds, batch_size = cfg['train']['train_batch_size'], shuffle = True, pin_memory = True, num_workers = 0, collate_fn = collate_fn) 23 | # # c_padded, refer_padded, f0_padded, codes_padded, wav_padded, lengths, refer_lengths, uv_padded = next(iter(dl)) 24 | # data = next(iter(dl)) 25 | # model = NaturalSpeech2(cfg) 26 | # out = model(data, codec) 27 | 28 | # print(c_padded.shape, refer_padded.shape, f0_padded.shape, codes_padded.shape, wav_padded.shape, lengths.shape, refer_lengths.shape, uv_padded.shape) 29 | # torch.Size([8, 256, 276]) torch.Size([8, 128, 276]) torch.Size([8, 276]) torch.Size([8, 128, 276]) torch.Size([8, 1, 88320]) torch.Size([8]) torch.Size([8]) torch.Size([8, 276]) 30 | 31 | # out.backward() 32 | 33 | # c_padded, refer_padded, f0_padded, codes_padded, wav_padded, lengths, refer_lengths, uv_padded = next(iter(dl)) 34 | # # c_padded refer_padded 35 | # c = c_padded 36 | # refer = refer_padded 37 | # f0 = f0_padded 38 | # uv = uv_padded 39 | # codec = EncodecWrapper() 40 | # with torch.no_grad(): 41 | # batches = num_to_groups(1, 1) 42 | # all_samples_list = list(map(lambda n: model.sample(c, refer, f0, uv, codec, batch_size=n), batches)) 43 | # all_samples = torch.cat(all_samples_list, dim = 0) 44 | # torchaudio.save(f'sample.wav', all_samples, 24000) 45 | # print(lengths) 46 | # print(refer_lengths) 47 | 48 | 49 | 50 | # phoneme_encoder = TextEncoder(**cfg['phoneme_encoder']) 51 | # f0_predictor = F0Predictor(**cfg['f0_predictor']) 52 | # prompt_encoder = TextEncoder(**cfg['prompt_encoder']) 53 | # diff_model = Diffusion_Encoder(**cfg['diffusion_encoder']) 54 | # audio_prompt = torch.randn(3, 256, 80) 55 | # contentvec = torch.randn(3, 256, 200) 56 | # f0 = torch.randint(1,100,(3, 200)) 57 | # noised_audio = torch.randn(3, 512, 200) 58 | # times = torch.randn(3) 59 | # audio_prompt_length = 
torch.tensor([3, 4, 5]) 60 | # contentvec_length = torch.tensor([3, 4, 5]) 61 | # #ok 62 | # audio_prompt = prompt_encoder(audio_prompt,audio_prompt_length) 63 | # #ok 64 | # f0_pred = f0_predictor(contentvec, audio_prompt, contentvec_length, audio_prompt_length) 65 | # #ok 66 | # content = phoneme_encoder(contentvec, contentvec_length,f0) 67 | # #ok 68 | # pred = diff_model( 69 | # noised_audio, 70 | # content, audio_prompt, 71 | # contentvec_length, audio_prompt_length, 72 | # times) 73 | 74 | # print(codes.shape)#24k 1 128 T2+1 75 | 76 | 77 | 78 | #reconstruction 79 | # codec = EncodecWrapper() 80 | # audio, sr = torchaudio.load('dataset/1.wav') 81 | # audio24k = T.Resample(sr, 24000)(audio) 82 | # torchaudio.save('1_24k.wav', audio24k, 24000) 83 | 84 | # codec.eval() 85 | # codes, _, _ = codec(audio24k, return_encoded = True) 86 | # audio = codec.decode(codes).squeeze(0) 87 | # torchaudio.save('1.wav', audio.detach(), 24000) 88 | 89 | # codec = EncodecWrapper() 90 | # gt = torch.randn(4, 128, 276) 91 | # pred = torch.randn(4, 128, 276) 92 | # _, indices, _, quantized_list = encode(gt,8,codec) 93 | # n_q=8 94 | # loss = rvq_ce_loss(gt.unsqueeze(0)-quantized_list, indices, codec, n_q) 95 | # print(loss) 96 | # loss = rvq_ce_loss(pred.unsqueeze(0)-quantized_list, indices, codec, n_q) 97 | # print(loss) 98 | # wav,sr = torchaudio.load('/home/hyc/val_dataset/common_voice_zh-CN_37110506.mp3') 99 | # wav24k = T.Resample(sr, 24000)(wav) 100 | # spec_process = torchaudio.transforms.MelSpectrogram( 101 | # sample_rate=24000, 102 | # n_fft=1024, 103 | # hop_length=256, 104 | # n_mels=100, 105 | # center=True, 106 | # power=1, 107 | # ) 108 | # spec = spec_process(wav24k)# 1 100 T 109 | # spec = torch.log(torch.clip(spec, min=1e-7)) 110 | # print(spec) 111 | # print(spec.shape) 112 | 113 | # prosody_process = torchaudio.transforms.MelSpectrogram( 114 | # sample_rate=24000, 115 | # n_fft=8192, 116 | # hop_length=4096, 117 | # n_mels=400, 118 | # center=True, 119 | # power=1, 120 | # ) 121 | # prosody = prosody_process(wav24k)# 1 400 T 122 | # prosody = torch.log(torch.clip(prosody, min=1e-7)) 123 | # prosody = torch.repeat_interleave(prosody, 16, dim=2) 124 | # prosody[:,:,16:] = (prosody[:,:,16:] + prosody[:,:,:-16]) / 2 125 | # print(prosody) 126 | # print(prosody.shape) 127 | 128 | import diffusers 129 | from diffusers import UNet1DModel,UNet2DConditionModel 130 | from model import NaturalSpeech2 131 | 132 | from unet1d import UNet1DConditionModel 133 | 134 | # a = torch.randn(4, 20, 10) 135 | # lengths = torch.tensor([10, 9, 8, 7]) 136 | # print(torch.arange(10)) 137 | # print(torch.arange(10).expand(4, 20, 10)) 138 | # mask = torch.arange(10).expand(4, 20, 10) >= lengths.unsqueeze(1).unsqueeze(1) 139 | # a = a.masked_fill(mask,0) 140 | # print(a) 141 | 142 | # unet2d = UNet2DConditionModel( 143 | # block_out_channels=(1,2,4,4), 144 | # norm_num_groups=1, 145 | # cross_attention_dim=16, 146 | # attention_head_dim=1, 147 | # ) 148 | # in_img = torch.randn(1,4,16,16) 149 | # cond = torch.randn(1,4,16) 150 | # out = unet2d(in_img, 3, cond) 151 | # print(out.sample.shape) 152 | 153 | # unet1d = UNet1DConditionModel( 154 | # in_channels=1, 155 | # out_channels=1, 156 | # block_out_channels=(4,8,8,8), 157 | # norm_num_groups=2, 158 | # cross_attention_dim=16, 159 | # attention_head_dim=2, 160 | # ) 161 | # audio = torch.randn(1,1,17) 162 | # cond = torch.randn(1,20,16) 163 | # out = unet1d(audio, 3, cond) 164 | # print(out.sample.shape) 165 | from nsf_hifigan.models import load_model 166 | import 
utils 167 | wav, sr = torchaudio.load('raw/test1.wav') 168 | wav = T.Resample(sr, 44100)(wav) 169 | spec_process = torchaudio.transforms.MelSpectrogram( 170 | sample_rate=44100, 171 | n_fft=2048, 172 | hop_length=512, 173 | n_mels=128, 174 | center=True, 175 | power=1, 176 | ) 177 | 178 | f0 = utils.compute_f0_dio( 179 | wav.cpu().numpy()[0], sampling_rate=44100, hop_length=512 180 | ) 181 | f0 = torch.Tensor(f0) 182 | mel = spec_process(wav) 183 | mel = torch.log(torch.clip(mel, min=1e-7)) 184 | device = 'cuda' 185 | vocoder = load_model('nsf_hifigan/model',device=device)[0] 186 | mel = mel.to(device) 187 | f0 = f0.to(device) 188 | length = min(mel.shape[2],f0.shape[0]) 189 | mel = mel[:,:,:length] 190 | f0 = f0[:length] 191 | wav = vocoder(mel, f0).cpu().squeeze(0) 192 | torchaudio.save('recon.wav', wav, 44100) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from model import Trainer 2 | 3 | trainer = Trainer() 4 | trainer.load('logs/vc/2023-09-28-20-49-43/model-639.pt') 5 | trainer.train() 6 | -------------------------------------------------------------------------------- /unet1d/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_1d_condition import UNet1DConditionModel -------------------------------------------------------------------------------- /unet1d/activations.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def get_activation(act_fn): 5 | if act_fn in ["swish", "silu"]: 6 | return nn.SiLU() 7 | elif act_fn == "mish": 8 | return nn.Mish() 9 | elif act_fn == "gelu": 10 | return nn.GELU() 11 | else: 12 | raise ValueError(f"Unsupported activation function: {act_fn}") 13 | -------------------------------------------------------------------------------- /unet1d/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from .activations import get_activation 21 | from .attention_processor import Attention 22 | from .embeddings import CombinedTimestepLabelEmbeddings 23 | from .lora import LoRACompatibleLinear 24 | 25 | 26 | class BasicTransformerBlock(nn.Module): 27 | r""" 28 | A basic Transformer block. 29 | 30 | Parameters: 31 | dim (`int`): The number of channels in the input and output. 32 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 33 | attention_head_dim (`int`): The number of channels in each head. 34 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
35 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 36 | only_cross_attention (`bool`, *optional*): 37 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 38 | double_self_attention (`bool`, *optional*): 39 | Whether to use two self-attention layers. In this case no cross attention layers are used. 40 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 41 | num_embeds_ada_norm (: 42 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 43 | attention_bias (: 44 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | dim: int, 50 | num_attention_heads: int, 51 | attention_head_dim: int, 52 | dropout=0.0, 53 | cross_attention_dim: Optional[int] = None, 54 | activation_fn: str = "geglu", 55 | num_embeds_ada_norm: Optional[int] = None, 56 | attention_bias: bool = False, 57 | only_cross_attention: bool = False, 58 | double_self_attention: bool = False, 59 | upcast_attention: bool = False, 60 | norm_elementwise_affine: bool = True, 61 | norm_type: str = "layer_norm", 62 | final_dropout: bool = False, 63 | ): 64 | super().__init__() 65 | self.only_cross_attention = only_cross_attention 66 | 67 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 68 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 69 | 70 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 71 | raise ValueError( 72 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 73 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 74 | ) 75 | 76 | # Define 3 blocks. Each block has its own normalization layer. 77 | # 1. Self-Attn 78 | if self.use_ada_layer_norm: 79 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 80 | elif self.use_ada_layer_norm_zero: 81 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 82 | else: 83 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 84 | self.attn1 = Attention( 85 | query_dim=dim, 86 | heads=num_attention_heads, 87 | dim_head=attention_head_dim, 88 | dropout=dropout, 89 | bias=attention_bias, 90 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 91 | upcast_attention=upcast_attention, 92 | ) 93 | 94 | # 2. Cross-Attn 95 | if cross_attention_dim is not None or double_self_attention: 96 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 97 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 98 | # the second cross attention block. 99 | self.norm2 = ( 100 | AdaLayerNorm(dim, num_embeds_ada_norm) 101 | if self.use_ada_layer_norm 102 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 103 | ) 104 | self.attn2 = Attention( 105 | query_dim=dim, 106 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 107 | heads=num_attention_heads, 108 | dim_head=attention_head_dim, 109 | dropout=dropout, 110 | bias=attention_bias, 111 | upcast_attention=upcast_attention, 112 | ) # is self-attn if encoder_hidden_states is none 113 | else: 114 | self.norm2 = None 115 | self.attn2 = None 116 | 117 | # 3. 
Feed-forward 118 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 119 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 120 | 121 | # let chunk size default to None 122 | self._chunk_size = None 123 | self._chunk_dim = 0 124 | 125 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): 126 | # Sets chunk feed-forward 127 | self._chunk_size = chunk_size 128 | self._chunk_dim = dim 129 | 130 | def forward( 131 | self, 132 | hidden_states: torch.FloatTensor, 133 | attention_mask: Optional[torch.FloatTensor] = None, 134 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 135 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 136 | timestep: Optional[torch.LongTensor] = None, 137 | cross_attention_kwargs: Dict[str, Any] = None, 138 | class_labels: Optional[torch.LongTensor] = None, 139 | ): 140 | # Notice that normalization is always applied before the real computation in the following blocks. 141 | # 1. Self-Attention 142 | if self.use_ada_layer_norm: 143 | norm_hidden_states = self.norm1(hidden_states, timestep) 144 | elif self.use_ada_layer_norm_zero: 145 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 146 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 147 | ) 148 | else: 149 | norm_hidden_states = self.norm1(hidden_states) 150 | 151 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 152 | 153 | attn_output = self.attn1( 154 | norm_hidden_states, 155 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 156 | attention_mask=attention_mask, 157 | **cross_attention_kwargs, 158 | ) 159 | if self.use_ada_layer_norm_zero: 160 | attn_output = gate_msa.unsqueeze(1) * attn_output 161 | hidden_states = attn_output + hidden_states 162 | 163 | # 2. Cross-Attention 164 | if self.attn2 is not None: 165 | norm_hidden_states = ( 166 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 167 | ) 168 | 169 | attn_output = self.attn2( 170 | norm_hidden_states, 171 | encoder_hidden_states=encoder_hidden_states, 172 | attention_mask=encoder_attention_mask, 173 | **cross_attention_kwargs, 174 | ) 175 | hidden_states = attn_output + hidden_states 176 | 177 | # 3. Feed-forward 178 | norm_hidden_states = self.norm3(hidden_states) 179 | 180 | if self.use_ada_layer_norm_zero: 181 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 182 | 183 | if self._chunk_size is not None: 184 | # "feed_forward_chunk_size" can be used to save memory 185 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: 186 | raise ValueError( 187 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
188 | ) 189 | 190 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size 191 | ff_output = torch.cat( 192 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], 193 | dim=self._chunk_dim, 194 | ) 195 | else: 196 | ff_output = self.ff(norm_hidden_states) 197 | 198 | if self.use_ada_layer_norm_zero: 199 | ff_output = gate_mlp.unsqueeze(1) * ff_output 200 | 201 | hidden_states = ff_output + hidden_states 202 | 203 | return hidden_states 204 | 205 | 206 | class FeedForward(nn.Module): 207 | r""" 208 | A feed-forward layer. 209 | 210 | Parameters: 211 | dim (`int`): The number of channels in the input. 212 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 213 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 214 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 215 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 216 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 217 | """ 218 | 219 | def __init__( 220 | self, 221 | dim: int, 222 | dim_out: Optional[int] = None, 223 | mult: int = 4, 224 | dropout: float = 0.0, 225 | activation_fn: str = "geglu", 226 | final_dropout: bool = False, 227 | ): 228 | super().__init__() 229 | inner_dim = int(dim * mult) 230 | dim_out = dim_out if dim_out is not None else dim 231 | 232 | if activation_fn == "gelu": 233 | act_fn = GELU(dim, inner_dim) 234 | if activation_fn == "gelu-approximate": 235 | act_fn = GELU(dim, inner_dim, approximate="tanh") 236 | elif activation_fn == "geglu": 237 | act_fn = GEGLU(dim, inner_dim) 238 | elif activation_fn == "geglu-approximate": 239 | act_fn = ApproximateGELU(dim, inner_dim) 240 | 241 | self.net = nn.ModuleList([]) 242 | # project in 243 | self.net.append(act_fn) 244 | # project dropout 245 | self.net.append(nn.Dropout(dropout)) 246 | # project out 247 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) 248 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 249 | if final_dropout: 250 | self.net.append(nn.Dropout(dropout)) 251 | 252 | def forward(self, hidden_states): 253 | for module in self.net: 254 | hidden_states = module(hidden_states) 255 | return hidden_states 256 | 257 | 258 | class GELU(nn.Module): 259 | r""" 260 | GELU activation function with tanh approximation support with `approximate="tanh"`. 261 | """ 262 | 263 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): 264 | super().__init__() 265 | self.proj = nn.Linear(dim_in, dim_out) 266 | self.approximate = approximate 267 | 268 | def gelu(self, gate): 269 | if gate.device.type != "mps": 270 | return F.gelu(gate, approximate=self.approximate) 271 | # mps: gelu is not implemented for float16 272 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) 273 | 274 | def forward(self, hidden_states): 275 | hidden_states = self.proj(hidden_states) 276 | hidden_states = self.gelu(hidden_states) 277 | return hidden_states 278 | 279 | 280 | class GEGLU(nn.Module): 281 | r""" 282 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 283 | 284 | Parameters: 285 | dim_in (`int`): The number of channels in the input. 286 | dim_out (`int`): The number of channels in the output. 
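    In effect (a short sketch of the forward pass defined below): the input is projected to
    ``2 * dim_out`` features, split into a value half and a gate half, and the output is
    ``value * gelu(gate)``.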
287 | """ 288 | 289 | def __init__(self, dim_in: int, dim_out: int): 290 | super().__init__() 291 | self.proj = LoRACompatibleLinear(dim_in, dim_out * 2) 292 | 293 | def gelu(self, gate): 294 | if gate.device.type != "mps": 295 | return F.gelu(gate) 296 | # mps: gelu is not implemented for float16 297 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 298 | 299 | def forward(self, hidden_states): 300 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 301 | return hidden_states * self.gelu(gate) 302 | 303 | 304 | class ApproximateGELU(nn.Module): 305 | """ 306 | The approximate form of Gaussian Error Linear Unit (GELU) 307 | 308 | For more details, see section 2: https://arxiv.org/abs/1606.08415 309 | """ 310 | 311 | def __init__(self, dim_in: int, dim_out: int): 312 | super().__init__() 313 | self.proj = nn.Linear(dim_in, dim_out) 314 | 315 | def forward(self, x): 316 | x = self.proj(x) 317 | return x * torch.sigmoid(1.702 * x) 318 | 319 | 320 | class AdaLayerNorm(nn.Module): 321 | """ 322 | Norm layer modified to incorporate timestep embeddings. 323 | """ 324 | 325 | def __init__(self, embedding_dim, num_embeddings): 326 | super().__init__() 327 | self.emb = nn.Embedding(num_embeddings, embedding_dim) 328 | self.silu = nn.SiLU() 329 | self.linear = nn.Linear(embedding_dim, embedding_dim * 2) 330 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) 331 | 332 | def forward(self, x, timestep): 333 | emb = self.linear(self.silu(self.emb(timestep))) 334 | scale, shift = torch.chunk(emb, 2) 335 | x = self.norm(x) * (1 + scale) + shift 336 | return x 337 | 338 | 339 | class AdaLayerNormZero(nn.Module): 340 | """ 341 | Norm layer adaptive layer norm zero (adaLN-Zero). 342 | """ 343 | 344 | def __init__(self, embedding_dim, num_embeddings): 345 | super().__init__() 346 | 347 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) 348 | 349 | self.silu = nn.SiLU() 350 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) 351 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 352 | 353 | def forward(self, x, timestep, class_labels, hidden_dtype=None): 354 | emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) 355 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) 356 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 357 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 358 | 359 | 360 | class AdaGroupNorm(nn.Module): 361 | """ 362 | GroupNorm layer modified to incorporate timestep embeddings. 
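    In effect (a short sketch of the forward pass defined below): ``emb`` is optionally passed
    through ``act_fn``, projected to ``2 * out_dim`` channels and split into ``scale`` and
    ``shift``, which are applied as ``x = group_norm(x) * (1 + scale) + shift``.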
363 | """ 364 | 365 | def __init__( 366 | self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 367 | ): 368 | super().__init__() 369 | self.num_groups = num_groups 370 | self.eps = eps 371 | 372 | if act_fn is None: 373 | self.act = None 374 | else: 375 | self.act = get_activation(act_fn) 376 | 377 | self.linear = nn.Linear(embedding_dim, out_dim * 2) 378 | 379 | def forward(self, x, emb): 380 | if self.act: 381 | emb = self.act(emb) 382 | emb = self.linear(emb) 383 | emb = emb[:, :, None, None] 384 | scale, shift = emb.chunk(2, dim=1) 385 | 386 | x = F.group_norm(x, self.num_groups, eps=self.eps) 387 | x = x * (1 + scale) + shift 388 | return x 389 | -------------------------------------------------------------------------------- /unet1d/dual_transformer_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Optional 15 | 16 | from torch import nn 17 | 18 | from .transformer_1d import Transformer2DModel, Transformer2DModelOutput 19 | 20 | 21 | class DualTransformer2DModel(nn.Module): 22 | """ 23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. 24 | 25 | Parameters: 26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 28 | in_channels (`int`, *optional*): 29 | Pass if the input is continuous. The number of channels in the input and output. 30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 31 | dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. 32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 35 | `ImagePositionalEmbeddings`. 36 | num_vector_embeds (`int`, *optional*): 37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. 38 | Includes the class for the masked latent pixel. 39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. 41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used 42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for 43 | up to but not more than steps than `num_embeds_ada_norm`. 
44 | attention_bias (`bool`, *optional*): 45 | Configure if the TransformerBlocks' attention should contain a bias parameter. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_attention_heads: int = 16, 51 | attention_head_dim: int = 88, 52 | in_channels: Optional[int] = None, 53 | num_layers: int = 1, 54 | dropout: float = 0.0, 55 | norm_num_groups: int = 32, 56 | cross_attention_dim: Optional[int] = None, 57 | attention_bias: bool = False, 58 | sample_size: Optional[int] = None, 59 | num_vector_embeds: Optional[int] = None, 60 | activation_fn: str = "geglu", 61 | num_embeds_ada_norm: Optional[int] = None, 62 | ): 63 | super().__init__() 64 | self.transformers = nn.ModuleList( 65 | [ 66 | Transformer2DModel( 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=attention_head_dim, 69 | in_channels=in_channels, 70 | num_layers=num_layers, 71 | dropout=dropout, 72 | norm_num_groups=norm_num_groups, 73 | cross_attention_dim=cross_attention_dim, 74 | attention_bias=attention_bias, 75 | sample_size=sample_size, 76 | num_vector_embeds=num_vector_embeds, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | ) 80 | for _ in range(2) 81 | ] 82 | ) 83 | 84 | # Variables that can be set by a pipeline: 85 | 86 | # The ratio of transformer1 to transformer2's output states to be combined during inference 87 | self.mix_ratio = 0.5 88 | 89 | # The shape of `encoder_hidden_states` is expected to be 90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` 91 | self.condition_lengths = [77, 257] 92 | 93 | # Which transformer to use to encode which condition. 94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` 95 | self.transformer_index_for_condition = [1, 0] 96 | 97 | def forward( 98 | self, 99 | hidden_states, 100 | encoder_hidden_states, 101 | timestep=None, 102 | attention_mask=None, 103 | cross_attention_kwargs=None, 104 | return_dict: bool = True, 105 | ): 106 | """ 107 | Args: 108 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 109 | When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 110 | hidden_states 111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 113 | self-attention. 114 | timestep ( `torch.long`, *optional*): 115 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 116 | attention_mask (`torch.FloatTensor`, *optional*): 117 | Optional attention mask to be applied in Attention 118 | return_dict (`bool`, *optional*, defaults to `True`): 119 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 120 | 121 | Returns: 122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: 123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When 124 | returning a tuple, the first element is the sample tensor. 
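        Examples:
            A minimal shape sketch (illustrative only; the head count, channel size and
            `cross_attention_dim` below are arbitrary choices, not values used elsewhere in
            this repo):

            >>> import torch
            >>> dual = DualTransformer2DModel(
            ...     num_attention_heads=2, attention_head_dim=8, in_channels=16,
            ...     norm_num_groups=4, cross_attention_dim=16,
            ... )
            >>> hidden = torch.randn(1, 16, 50)        # (batch, channels, length)
            >>> cond = torch.randn(1, 77 + 257, 16)    # both condition streams, concatenated
            >>> dual(hidden, cond, return_dict=False)[0].shape
            torch.Size([1, 16, 50])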
125 | """ 126 | input_states = hidden_states 127 | 128 | encoded_states = [] 129 | tokens_start = 0 130 | # attention_mask is not used yet 131 | for i in range(2): 132 | # for each of the two transformers, pass the corresponding condition tokens 133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] 134 | transformer_index = self.transformer_index_for_condition[i] 135 | encoded_state = self.transformers[transformer_index]( 136 | input_states, 137 | encoder_hidden_states=condition_state, 138 | timestep=timestep, 139 | cross_attention_kwargs=cross_attention_kwargs, 140 | return_dict=False, 141 | )[0] 142 | encoded_states.append(encoded_state - input_states) 143 | tokens_start += self.condition_lengths[i] 144 | 145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) 146 | output_states = output_states + input_states 147 | 148 | if not return_dict: 149 | return (output_states,) 150 | 151 | return Transformer2DModelOutput(sample=output_states) 152 | -------------------------------------------------------------------------------- /unet1d/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | from typing import Optional 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | 21 | from .activations import get_activation 22 | 23 | 24 | def get_timestep_embedding( 25 | timesteps: torch.Tensor, 26 | embedding_dim: int, 27 | flip_sin_to_cos: bool = False, 28 | downscale_freq_shift: float = 1, 29 | scale: float = 1, 30 | max_period: int = 10000, 31 | ): 32 | """ 33 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 34 | 35 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 36 | These may be fractional. 37 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 38 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 
39 | """ 40 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 41 | 42 | half_dim = embedding_dim // 2 43 | exponent = -math.log(max_period) * torch.arange( 44 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 45 | ) 46 | exponent = exponent / (half_dim - downscale_freq_shift) 47 | 48 | emb = torch.exp(exponent) 49 | emb = timesteps[:, None].float() * emb[None, :] 50 | 51 | # scale embeddings 52 | emb = scale * emb 53 | 54 | # concat sine and cosine embeddings 55 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 56 | 57 | # flip sine and cosine embeddings 58 | if flip_sin_to_cos: 59 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 60 | 61 | # zero pad 62 | if embedding_dim % 2 == 1: 63 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 64 | return emb 65 | 66 | 67 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 68 | """ 69 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 70 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 71 | """ 72 | grid_h = np.arange(grid_size, dtype=np.float32) 73 | grid_w = np.arange(grid_size, dtype=np.float32) 74 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 75 | grid = np.stack(grid, axis=0) 76 | 77 | grid = grid.reshape([2, 1, grid_size, grid_size]) 78 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 79 | if cls_token and extra_tokens > 0: 80 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 81 | return pos_embed 82 | 83 | 84 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 85 | if embed_dim % 2 != 0: 86 | raise ValueError("embed_dim must be divisible by 2") 87 | 88 | # use half of dimensions to encode grid_h 89 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 90 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 91 | 92 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 93 | return emb 94 | 95 | 96 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 97 | """ 98 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 99 | """ 100 | if embed_dim % 2 != 0: 101 | raise ValueError("embed_dim must be divisible by 2") 102 | 103 | omega = np.arange(embed_dim // 2, dtype=np.float64) 104 | omega /= embed_dim / 2.0 105 | omega = 1.0 / 10000**omega # (D/2,) 106 | 107 | pos = pos.reshape(-1) # (M,) 108 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 109 | 110 | emb_sin = np.sin(out) # (M, D/2) 111 | emb_cos = np.cos(out) # (M, D/2) 112 | 113 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 114 | return emb 115 | 116 | 117 | class PatchEmbed(nn.Module): 118 | """2D Image to Patch Embedding""" 119 | 120 | def __init__( 121 | self, 122 | height=224, 123 | width=224, 124 | patch_size=16, 125 | in_channels=3, 126 | embed_dim=768, 127 | layer_norm=False, 128 | flatten=True, 129 | bias=True, 130 | ): 131 | super().__init__() 132 | 133 | num_patches = (height // patch_size) * (width // patch_size) 134 | self.flatten = flatten 135 | self.layer_norm = layer_norm 136 | 137 | self.proj = nn.Conv2d( 138 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 139 | ) 140 | if layer_norm: 141 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 142 | else: 143 | self.norm = None 144 | 145 | pos_embed = get_2d_sincos_pos_embed(embed_dim, 
int(num_patches**0.5)) 146 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 147 | 148 | def forward(self, latent): 149 | latent = self.proj(latent) 150 | if self.flatten: 151 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 152 | if self.layer_norm: 153 | latent = self.norm(latent) 154 | return latent + self.pos_embed 155 | 156 | 157 | class TimestepEmbedding(nn.Module): 158 | def __init__( 159 | self, 160 | in_channels: int, 161 | time_embed_dim: int, 162 | act_fn: str = "silu", 163 | out_dim: int = None, 164 | post_act_fn: Optional[str] = None, 165 | cond_proj_dim=None, 166 | ): 167 | super().__init__() 168 | 169 | self.linear_1 = nn.Linear(in_channels, time_embed_dim) 170 | 171 | if cond_proj_dim is not None: 172 | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) 173 | else: 174 | self.cond_proj = None 175 | 176 | self.act = get_activation(act_fn) 177 | 178 | if out_dim is not None: 179 | time_embed_dim_out = out_dim 180 | else: 181 | time_embed_dim_out = time_embed_dim 182 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) 183 | 184 | if post_act_fn is None: 185 | self.post_act = None 186 | else: 187 | self.post_act = get_activation(post_act_fn) 188 | 189 | def forward(self, sample, condition=None): 190 | if condition is not None: 191 | sample = sample + self.cond_proj(condition) 192 | sample = self.linear_1(sample) 193 | 194 | if self.act is not None: 195 | sample = self.act(sample) 196 | 197 | sample = self.linear_2(sample) 198 | 199 | if self.post_act is not None: 200 | sample = self.post_act(sample) 201 | return sample 202 | 203 | 204 | class Timesteps(nn.Module): 205 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 206 | super().__init__() 207 | self.num_channels = num_channels 208 | self.flip_sin_to_cos = flip_sin_to_cos 209 | self.downscale_freq_shift = downscale_freq_shift 210 | 211 | def forward(self, timesteps): 212 | t_emb = get_timestep_embedding( 213 | timesteps, 214 | self.num_channels, 215 | flip_sin_to_cos=self.flip_sin_to_cos, 216 | downscale_freq_shift=self.downscale_freq_shift, 217 | ) 218 | return t_emb 219 | 220 | 221 | class GaussianFourierProjection(nn.Module): 222 | """Gaussian Fourier embeddings for noise levels.""" 223 | 224 | def __init__( 225 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False 226 | ): 227 | super().__init__() 228 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 229 | self.log = log 230 | self.flip_sin_to_cos = flip_sin_to_cos 231 | 232 | if set_W_to_weight: 233 | # to delete later 234 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 235 | 236 | self.weight = self.W 237 | 238 | def forward(self, x): 239 | if self.log: 240 | x = torch.log(x) 241 | 242 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi 243 | 244 | if self.flip_sin_to_cos: 245 | out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) 246 | else: 247 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 248 | return out 249 | 250 | 251 | class ImagePositionalEmbeddings(nn.Module): 252 | """ 253 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the 254 | height and width of the latent space. 
255 | 256 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 257 | 258 | For VQ-diffusion: 259 | 260 | Output vector embeddings are used as input for the transformer. 261 | 262 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. 263 | 264 | Args: 265 | num_embed (`int`): 266 | Number of embeddings for the latent pixels embeddings. 267 | height (`int`): 268 | Height of the latent image i.e. the number of height embeddings. 269 | width (`int`): 270 | Width of the latent image i.e. the number of width embeddings. 271 | embed_dim (`int`): 272 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 273 | """ 274 | 275 | def __init__( 276 | self, 277 | num_embed: int, 278 | height: int, 279 | width: int, 280 | embed_dim: int, 281 | ): 282 | super().__init__() 283 | 284 | self.height = height 285 | self.width = width 286 | self.num_embed = num_embed 287 | self.embed_dim = embed_dim 288 | 289 | self.emb = nn.Embedding(self.num_embed, embed_dim) 290 | self.height_emb = nn.Embedding(self.height, embed_dim) 291 | self.width_emb = nn.Embedding(self.width, embed_dim) 292 | 293 | def forward(self, index): 294 | emb = self.emb(index) 295 | 296 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) 297 | 298 | # 1 x H x D -> 1 x H x 1 x D 299 | height_emb = height_emb.unsqueeze(2) 300 | 301 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) 302 | 303 | # 1 x W x D -> 1 x 1 x W x D 304 | width_emb = width_emb.unsqueeze(1) 305 | 306 | pos_emb = height_emb + width_emb 307 | 308 | # 1 x H x W x D -> 1 x L xD 309 | pos_emb = pos_emb.view(1, self.height * self.width, -1) 310 | 311 | emb = emb + pos_emb[:, : emb.shape[1], :] 312 | 313 | return emb 314 | 315 | 316 | class LabelEmbedding(nn.Module): 317 | """ 318 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 319 | 320 | Args: 321 | num_classes (`int`): The number of classes. 322 | hidden_size (`int`): The size of the vector embeddings. 323 | dropout_prob (`float`): The probability of dropping a label. 324 | """ 325 | 326 | def __init__(self, num_classes, hidden_size, dropout_prob): 327 | super().__init__() 328 | use_cfg_embedding = dropout_prob > 0 329 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 330 | self.num_classes = num_classes 331 | self.dropout_prob = dropout_prob 332 | 333 | def token_drop(self, labels, force_drop_ids=None): 334 | """ 335 | Drops labels to enable classifier-free guidance. 
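        Example (illustrative sketch; entries of `force_drop_ids` equal to 1 are replaced by the
        extra "null" class index, here 10):

            >>> import torch
            >>> label_emb = LabelEmbedding(num_classes=10, hidden_size=32, dropout_prob=0.1)
            >>> labels = torch.tensor([1, 2, 3])
            >>> label_emb.token_drop(labels, force_drop_ids=torch.tensor([1, 0, 1])).tolist()
            [10, 2, 10]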
336 | """ 337 | if force_drop_ids is None: 338 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 339 | else: 340 | drop_ids = torch.tensor(force_drop_ids == 1) 341 | labels = torch.where(drop_ids, self.num_classes, labels) 342 | return labels 343 | 344 | def forward(self, labels: torch.LongTensor, force_drop_ids=None): 345 | use_dropout = self.dropout_prob > 0 346 | if (self.training and use_dropout) or (force_drop_ids is not None): 347 | labels = self.token_drop(labels, force_drop_ids) 348 | embeddings = self.embedding_table(labels) 349 | return embeddings 350 | 351 | 352 | class TextImageProjection(nn.Module): 353 | def __init__( 354 | self, 355 | text_embed_dim: int = 1024, 356 | image_embed_dim: int = 768, 357 | cross_attention_dim: int = 768, 358 | num_image_text_embeds: int = 10, 359 | ): 360 | super().__init__() 361 | 362 | self.num_image_text_embeds = num_image_text_embeds 363 | self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) 364 | self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) 365 | 366 | def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): 367 | batch_size = text_embeds.shape[0] 368 | 369 | # image 370 | image_text_embeds = self.image_embeds(image_embeds) 371 | image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) 372 | 373 | # text 374 | text_embeds = self.text_proj(text_embeds) 375 | 376 | return torch.cat([image_text_embeds, text_embeds], dim=1) 377 | 378 | 379 | class ImageProjection(nn.Module): 380 | def __init__( 381 | self, 382 | image_embed_dim: int = 768, 383 | cross_attention_dim: int = 768, 384 | num_image_text_embeds: int = 32, 385 | ): 386 | super().__init__() 387 | 388 | self.num_image_text_embeds = num_image_text_embeds 389 | self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) 390 | self.norm = nn.LayerNorm(cross_attention_dim) 391 | 392 | def forward(self, image_embeds: torch.FloatTensor): 393 | batch_size = image_embeds.shape[0] 394 | 395 | # image 396 | image_embeds = self.image_embeds(image_embeds) 397 | image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1) 398 | image_embeds = self.norm(image_embeds) 399 | return image_embeds 400 | 401 | 402 | class CombinedTimestepLabelEmbeddings(nn.Module): 403 | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): 404 | super().__init__() 405 | 406 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) 407 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 408 | self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) 409 | 410 | def forward(self, timestep, class_labels, hidden_dtype=None): 411 | timesteps_proj = self.time_proj(timestep) 412 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) 413 | 414 | class_labels = self.class_embedder(class_labels) # (N, D) 415 | 416 | conditioning = timesteps_emb + class_labels # (N, D) 417 | 418 | return conditioning 419 | 420 | 421 | class TextTimeEmbedding(nn.Module): 422 | def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64): 423 | super().__init__() 424 | self.norm1 = nn.LayerNorm(encoder_dim) 425 | self.pool = AttentionPooling(num_heads, encoder_dim) 426 | self.proj = nn.Linear(encoder_dim, time_embed_dim) 427 | self.norm2 = nn.LayerNorm(time_embed_dim) 428 | 
429 | def forward(self, hidden_states): 430 | hidden_states = self.norm1(hidden_states) 431 | hidden_states = self.pool(hidden_states) 432 | hidden_states = self.proj(hidden_states) 433 | hidden_states = self.norm2(hidden_states) 434 | return hidden_states 435 | 436 | 437 | class TextImageTimeEmbedding(nn.Module): 438 | def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): 439 | super().__init__() 440 | self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) 441 | self.text_norm = nn.LayerNorm(time_embed_dim) 442 | self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) 443 | 444 | def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): 445 | # text 446 | time_text_embeds = self.text_proj(text_embeds) 447 | time_text_embeds = self.text_norm(time_text_embeds) 448 | 449 | # image 450 | time_image_embeds = self.image_proj(image_embeds) 451 | 452 | return time_image_embeds + time_text_embeds 453 | 454 | 455 | class ImageTimeEmbedding(nn.Module): 456 | def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): 457 | super().__init__() 458 | self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) 459 | self.image_norm = nn.LayerNorm(time_embed_dim) 460 | 461 | def forward(self, image_embeds: torch.FloatTensor): 462 | # image 463 | time_image_embeds = self.image_proj(image_embeds) 464 | time_image_embeds = self.image_norm(time_image_embeds) 465 | return time_image_embeds 466 | 467 | 468 | class ImageHintTimeEmbedding(nn.Module): 469 | def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): 470 | super().__init__() 471 | self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) 472 | self.image_norm = nn.LayerNorm(time_embed_dim) 473 | self.input_hint_block = nn.Sequential( 474 | nn.Conv2d(3, 16, 3, padding=1), 475 | nn.SiLU(), 476 | nn.Conv2d(16, 16, 3, padding=1), 477 | nn.SiLU(), 478 | nn.Conv2d(16, 32, 3, padding=1, stride=2), 479 | nn.SiLU(), 480 | nn.Conv2d(32, 32, 3, padding=1), 481 | nn.SiLU(), 482 | nn.Conv2d(32, 96, 3, padding=1, stride=2), 483 | nn.SiLU(), 484 | nn.Conv2d(96, 96, 3, padding=1), 485 | nn.SiLU(), 486 | nn.Conv2d(96, 256, 3, padding=1, stride=2), 487 | nn.SiLU(), 488 | nn.Conv2d(256, 4, 3, padding=1), 489 | ) 490 | 491 | def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor): 492 | # image 493 | time_image_embeds = self.image_proj(image_embeds) 494 | time_image_embeds = self.image_norm(time_image_embeds) 495 | hint = self.input_hint_block(hint) 496 | return time_image_embeds, hint 497 | 498 | 499 | class AttentionPooling(nn.Module): 500 | # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 501 | 502 | def __init__(self, num_heads, embed_dim, dtype=None): 503 | super().__init__() 504 | self.dtype = dtype 505 | self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) 506 | self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) 507 | self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) 508 | self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) 509 | self.num_heads = num_heads 510 | self.dim_per_head = embed_dim // self.num_heads 511 | 512 | def forward(self, x): 513 | bs, length, width = x.size() 514 | 515 | def shape(x): 516 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 517 | x = x.view(bs, -1, self.num_heads, self.dim_per_head) 518 | # (bs, length, n_heads, dim_per_head) --> (bs, 
n_heads, length, dim_per_head) 519 | x = x.transpose(1, 2) 520 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 521 | x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) 522 | # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) 523 | x = x.transpose(1, 2) 524 | return x 525 | 526 | class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) 527 | x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) 528 | 529 | # (bs*n_heads, class_token_length, dim_per_head) 530 | q = shape(self.q_proj(class_token)) 531 | # (bs*n_heads, length+class_token_length, dim_per_head) 532 | k = shape(self.k_proj(x)) 533 | v = shape(self.v_proj(x)) 534 | 535 | # (bs*n_heads, class_token_length, length+class_token_length): 536 | scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) 537 | weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards 538 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 539 | 540 | # (bs*n_heads, dim_per_head, class_token_length) 541 | a = torch.einsum("bts,bcs->bct", weight, v) 542 | 543 | # (bs, length+1, width) 544 | a = a.reshape(bs, -1, 1).transpose(1, 2) 545 | 546 | return a[:, 0, :] # cls_token 547 | -------------------------------------------------------------------------------- /unet1d/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional 16 | 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | 21 | class LoRALinearLayer(nn.Module): 22 | def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): 23 | super().__init__() 24 | 25 | if rank > min(in_features, out_features): 26 | raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") 27 | 28 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) 29 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) 30 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
31 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 32 | self.network_alpha = network_alpha 33 | self.rank = rank 34 | 35 | nn.init.normal_(self.down.weight, std=1 / rank) 36 | nn.init.zeros_(self.up.weight) 37 | 38 | def forward(self, hidden_states): 39 | orig_dtype = hidden_states.dtype 40 | dtype = self.down.weight.dtype 41 | 42 | down_hidden_states = self.down(hidden_states.to(dtype)) 43 | up_hidden_states = self.up(down_hidden_states) 44 | 45 | if self.network_alpha is not None: 46 | up_hidden_states *= self.network_alpha / self.rank 47 | 48 | return up_hidden_states.to(orig_dtype) 49 | 50 | 51 | class LoRAConv1dLayer(nn.Module): 52 | def __init__( 53 | self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None 54 | ): 55 | super().__init__() 56 | 57 | if rank > min(in_features, out_features): 58 | raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") 59 | 60 | self.down = nn.Conv1d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 61 | # according to the official kohya_ss trainer kernel_size are always fixed for the up layer 62 | # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 63 | self.up = nn.Conv1d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) 64 | 65 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 66 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 67 | self.network_alpha = network_alpha 68 | self.rank = rank 69 | 70 | nn.init.normal_(self.down.weight, std=1 / rank) 71 | nn.init.zeros_(self.up.weight) 72 | 73 | def forward(self, hidden_states): 74 | orig_dtype = hidden_states.dtype 75 | dtype = self.down.weight.dtype 76 | 77 | down_hidden_states = self.down(hidden_states.to(dtype)) 78 | up_hidden_states = self.up(down_hidden_states) 79 | 80 | if self.network_alpha is not None: 81 | up_hidden_states *= self.network_alpha / self.rank 82 | 83 | return up_hidden_states.to(orig_dtype) 84 | 85 | 86 | class LoRACompatibleConv(nn.Conv1d): 87 | """ 88 | A convolutional layer that can be used with LoRA. 89 | """ 90 | 91 | def __init__(self, *args, lora_layer: Optional[LoRAConv1dLayer] = None, **kwargs): 92 | super().__init__(*args, **kwargs) 93 | self.lora_layer = lora_layer 94 | 95 | def set_lora_layer(self, lora_layer: Optional[LoRAConv1dLayer]): 96 | self.lora_layer = lora_layer 97 | 98 | def forward(self, x): 99 | if self.lora_layer is None: 100 | # make sure to the functional Conv2D function as otherwise torch.compile's graph will break 101 | # see: https://github.com/huggingface/diffusers/pull/4315 102 | return F.conv1d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 103 | else: 104 | return super().forward(x) + self.lora_layer(x) 105 | 106 | 107 | class LoRACompatibleLinear(nn.Linear): 108 | """ 109 | A Linear layer that can be used with LoRA. 
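    Example (illustrative sketch; sizes and rank are arbitrary):

        >>> import torch
        >>> linear = LoRACompatibleLinear(32, 64)
        >>> linear.set_lora_layer(LoRALinearLayer(32, 64, rank=4))
        >>> x = torch.randn(2, 32)
        >>> linear(x).shape   # base Linear output plus the (zero-initialized) LoRA delta
        torch.Size([2, 64])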
110 | """ 111 | 112 | def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): 113 | super().__init__(*args, **kwargs) 114 | self.lora_layer = lora_layer 115 | 116 | def set_lora_layer(self, lora_layer: Optional[LoRAConv1dLayer]): 117 | self.lora_layer = lora_layer 118 | 119 | def forward(self, x): 120 | if self.lora_layer is None: 121 | return super().forward(x) 122 | else: 123 | return super().forward(x) + self.lora_layer(x) 124 | -------------------------------------------------------------------------------- /unet1d/outputs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Generic utilities 16 | """ 17 | 18 | from collections import OrderedDict 19 | from dataclasses import fields 20 | from typing import Any, Tuple 21 | 22 | import numpy as np 23 | 24 | def is_tensor(x): 25 | """ 26 | Tests if `x` is a `torch.Tensor` or `np.ndarray`. 27 | """ 28 | import torch 29 | 30 | if isinstance(x, torch.Tensor): 31 | return True 32 | 33 | return isinstance(x, np.ndarray) 34 | 35 | 36 | class BaseOutput(OrderedDict): 37 | """ 38 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 39 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 40 | Python dictionary. 41 | 42 | 43 | 44 | You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple 45 | first. 
46 | 47 | 48 | """ 49 | 50 | def __post_init__(self): 51 | class_fields = fields(self) 52 | 53 | # Safety and consistency checks 54 | if not len(class_fields): 55 | raise ValueError(f"{self.__class__.__name__} has no fields.") 56 | 57 | first_field = getattr(self, class_fields[0].name) 58 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) 59 | 60 | if other_fields_are_none and isinstance(first_field, dict): 61 | for key, value in first_field.items(): 62 | self[key] = value 63 | else: 64 | for field in class_fields: 65 | v = getattr(self, field.name) 66 | if v is not None: 67 | self[field.name] = v 68 | 69 | def __delitem__(self, *args, **kwargs): 70 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") 71 | 72 | def setdefault(self, *args, **kwargs): 73 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") 74 | 75 | def pop(self, *args, **kwargs): 76 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") 77 | 78 | def update(self, *args, **kwargs): 79 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") 80 | 81 | def __getitem__(self, k): 82 | if isinstance(k, str): 83 | inner_dict = dict(self.items()) 84 | return inner_dict[k] 85 | else: 86 | return self.to_tuple()[k] 87 | 88 | def __setattr__(self, name, value): 89 | if name in self.keys() and value is not None: 90 | # Don't call self.__setitem__ to avoid recursion errors 91 | super().__setitem__(name, value) 92 | super().__setattr__(name, value) 93 | 94 | def __setitem__(self, key, value): 95 | # Will raise a KeyException if needed 96 | super().__setitem__(key, value) 97 | # Don't call self.__setattr__ to avoid recursion errors 98 | super().__setattr__(key, value) 99 | 100 | def to_tuple(self) -> Tuple[Any]: 101 | """ 102 | Convert self to a tuple containing all the attributes/keys that are not `None`. 103 | """ 104 | return tuple(self[k] for k in self.keys()) 105 | -------------------------------------------------------------------------------- /unet1d/transformer_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from .outputs import BaseOutput 22 | from .attention import BasicTransformerBlock 23 | from .embeddings import PatchEmbed 24 | from .lora import LoRACompatibleConv, LoRACompatibleLinear 25 | 26 | 27 | @dataclass 28 | class Transformer2DModelOutput(BaseOutput): 29 | """ 30 | The output of [`Transformer2DModel`]. 
31 | 32 | Args: 33 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 34 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 35 | distributions for the unnoised latent pixels. 36 | """ 37 | 38 | sample: torch.FloatTensor 39 | 40 | 41 | class Transformer2DModel(nn.Module): 42 | """ 43 | A 2D Transformer model for image-like data. 44 | 45 | Parameters: 46 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 47 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 48 | in_channels (`int`, *optional*): 49 | The number of channels in the input and output (specify if the input is **continuous**). 50 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 51 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 52 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 53 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 54 | This is fixed during training since it is used to learn a number of position embeddings. 55 | num_vector_embeds (`int`, *optional*): 56 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 57 | Includes the class for the masked latent pixel. 58 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 59 | num_embeds_ada_norm ( `int`, *optional*): 60 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 61 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 62 | added to the hidden states. 63 | 64 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 65 | attention_bias (`bool`, *optional*): 66 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 67 | """ 68 | 69 | def __init__( 70 | self, 71 | num_attention_heads: int = 16, 72 | attention_head_dim: int = 88, 73 | in_channels: Optional[int] = None, 74 | out_channels: Optional[int] = None, 75 | num_layers: int = 1, 76 | dropout: float = 0.0, 77 | norm_num_groups: int = 32, 78 | cross_attention_dim: Optional[int] = None, 79 | attention_bias: bool = False, 80 | sample_size: Optional[int] = None, 81 | num_vector_embeds: Optional[int] = None, 82 | patch_size: Optional[int] = None, 83 | activation_fn: str = "geglu", 84 | num_embeds_ada_norm: Optional[int] = None, 85 | use_linear_projection: bool = False, 86 | only_cross_attention: bool = False, 87 | upcast_attention: bool = False, 88 | norm_type: str = "layer_norm", 89 | norm_elementwise_affine: bool = True, 90 | ): 91 | super().__init__() 92 | self.use_linear_projection = use_linear_projection 93 | self.num_attention_heads = num_attention_heads 94 | self.attention_head_dim = attention_head_dim 95 | inner_dim = num_attention_heads * attention_head_dim 96 | 97 | # 1. 
Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 98 | # Define whether input is continuous or discrete depending on configuration 99 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 100 | self.is_input_vectorized = num_vector_embeds is not None 101 | self.is_input_patches = in_channels is not None and patch_size is not None 102 | 103 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 104 | deprecation_message = ( 105 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 106 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 107 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 108 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 109 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 110 | ) 111 | # deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 112 | norm_type = "ada_norm" 113 | 114 | if self.is_input_continuous and self.is_input_vectorized: 115 | raise ValueError( 116 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 117 | " sure that either `in_channels` or `num_vector_embeds` is None." 118 | ) 119 | elif self.is_input_vectorized and self.is_input_patches: 120 | raise ValueError( 121 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 122 | " sure that either `num_vector_embeds` or `num_patches` is None." 123 | ) 124 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 125 | raise ValueError( 126 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 127 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 128 | ) 129 | 130 | # 2. Define input layers 131 | if self.is_input_continuous: 132 | self.in_channels = in_channels 133 | 134 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 135 | if use_linear_projection: 136 | self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) 137 | else: 138 | self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 139 | elif self.is_input_patches: 140 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 141 | 142 | self.height = sample_size 143 | self.width = sample_size 144 | 145 | self.patch_size = patch_size 146 | self.pos_embed = PatchEmbed( 147 | height=sample_size, 148 | width=sample_size, 149 | patch_size=patch_size, 150 | in_channels=in_channels, 151 | embed_dim=inner_dim, 152 | ) 153 | 154 | # 3. 
Define transformers blocks 155 | self.transformer_blocks = nn.ModuleList( 156 | [ 157 | BasicTransformerBlock( 158 | inner_dim, 159 | num_attention_heads, 160 | attention_head_dim, 161 | dropout=dropout, 162 | cross_attention_dim=cross_attention_dim, 163 | activation_fn=activation_fn, 164 | num_embeds_ada_norm=num_embeds_ada_norm, 165 | attention_bias=attention_bias, 166 | only_cross_attention=only_cross_attention, 167 | upcast_attention=upcast_attention, 168 | norm_type=norm_type, 169 | norm_elementwise_affine=norm_elementwise_affine, 170 | ) 171 | for d in range(num_layers) 172 | ] 173 | ) 174 | 175 | # 4. Define output layers 176 | self.out_channels = in_channels if out_channels is None else out_channels 177 | if self.is_input_continuous: 178 | # TODO: should use out_channels for continuous projections 179 | if use_linear_projection: 180 | self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) 181 | else: 182 | self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 183 | elif self.is_input_vectorized: 184 | self.norm_out = nn.LayerNorm(inner_dim) 185 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 186 | elif self.is_input_patches: 187 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 188 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 189 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 190 | 191 | def forward( 192 | self, 193 | hidden_states: torch.Tensor, 194 | encoder_hidden_states: Optional[torch.Tensor] = None, 195 | timestep: Optional[torch.LongTensor] = None, 196 | class_labels: Optional[torch.LongTensor] = None, 197 | cross_attention_kwargs: Dict[str, Any] = None, 198 | attention_mask: Optional[torch.Tensor] = None, 199 | encoder_attention_mask: Optional[torch.Tensor] = None, 200 | return_dict: bool = True, 201 | ): 202 | """ 203 | The [`Transformer2DModel`] forward method. 204 | 205 | Args: 206 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 207 | Input `hidden_states`. 208 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 209 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 210 | self-attention. 211 | timestep ( `torch.LongTensor`, *optional*): 212 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 213 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 214 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 215 | `AdaLayerZeroNorm`. 216 | encoder_attention_mask ( `torch.Tensor`, *optional*): 217 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 218 | 219 | * Mask `(batch, sequence_length)` True = keep, False = discard. 220 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 221 | 222 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 223 | above. This bias will be added to the cross-attention scores. 224 | return_dict (`bool`, *optional*, defaults to `True`): 225 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 226 | tuple. 
227 | 228 | Returns: 229 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 230 | `tuple` where the first element is the sample tensor. 231 | """ 232 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 233 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 234 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 235 | # expects mask of shape: 236 | # [batch, key_tokens] 237 | # adds singleton query_tokens dimension: 238 | # [batch, 1, key_tokens] 239 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 240 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 241 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 242 | if attention_mask is not None and attention_mask.ndim == 2: 243 | # assume that mask is expressed as: 244 | # (1 = keep, 0 = discard) 245 | # convert mask into a bias that can be added to attention scores: 246 | # (keep = +0, discard = -10000.0) 247 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 248 | attention_mask = attention_mask.unsqueeze(1) 249 | 250 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 251 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 252 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 253 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 254 | 255 | # 1. Input 256 | if self.is_input_continuous: 257 | batch, _, time_length = hidden_states.shape 258 | residual = hidden_states 259 | 260 | hidden_states = self.norm(hidden_states) 261 | if not self.use_linear_projection: 262 | hidden_states = self.proj_in(hidden_states) 263 | inner_dim = hidden_states.shape[1] 264 | hidden_states = hidden_states.permute(0, 2, 1) 265 | else: 266 | inner_dim = hidden_states.shape[1] 267 | hidden_states = hidden_states.permute(0, 2, 1) 268 | hidden_states = self.proj_in(hidden_states) 269 | elif self.is_input_vectorized: 270 | hidden_states = self.latent_image_embedding(hidden_states) 271 | elif self.is_input_patches: 272 | hidden_states = self.pos_embed(hidden_states) 273 | 274 | # 2. Blocks 275 | for block in self.transformer_blocks: 276 | hidden_states = block( 277 | hidden_states, 278 | attention_mask=attention_mask, 279 | encoder_hidden_states=encoder_hidden_states, 280 | encoder_attention_mask=encoder_attention_mask, 281 | timestep=timestep, 282 | cross_attention_kwargs=cross_attention_kwargs, 283 | class_labels=class_labels, 284 | ) 285 | 286 | # 3. Output 287 | if self.is_input_continuous: 288 | if not self.use_linear_projection: 289 | hidden_states = hidden_states.permute(0, 2, 1).contiguous() 290 | hidden_states = self.proj_out(hidden_states) 291 | else: 292 | hidden_states = self.proj_out(hidden_states) 293 | hidden_states = hidden_states.permute(0, 2, 1).contiguous() 294 | 295 | output = hidden_states + residual 296 | elif self.is_input_vectorized: 297 | hidden_states = self.norm_out(hidden_states) 298 | logits = self.out(hidden_states) 299 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 300 | logits = logits.permute(0, 2, 1) 301 | 302 | # log(p(x_0)) 303 | output = F.log_softmax(logits.double(), dim=1).float() 304 | elif self.is_input_patches: 305 | # TODO: cleanup! 
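            # (Comment added for clarity, not in the upstream file.) For patched inputs the
            # final layer is adaLN-style: the timestep/class embedding yields a shift/scale
            # that modulates `norm_out`, `proj_out_2` maps each token back to
            # patch_size * patch_size * out_channels values, and the einsum/reshape below
            # stitches the patch grid back into an image-shaped tensor.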
306 | conditioning = self.transformer_blocks[0].norm1.emb( 307 | timestep, class_labels, hidden_dtype=hidden_states.dtype 308 | ) 309 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 310 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 311 | hidden_states = self.proj_out_2(hidden_states) 312 | 313 | # unpatchify 314 | height = width = int(hidden_states.shape[1] ** 0.5) 315 | hidden_states = hidden_states.reshape( 316 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 317 | ) 318 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 319 | output = hidden_states.reshape( 320 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 321 | ) 322 | 323 | if not return_dict: 324 | return (output,) 325 | 326 | return Transformer2DModelOutput(sample=output) 327 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import re 4 | import sys 5 | import argparse 6 | import logging 7 | import json 8 | import subprocess 9 | import warnings 10 | import random 11 | import functools 12 | 13 | import librosa 14 | import numpy as np 15 | from scipy.io.wavfile import read 16 | import torch 17 | from torch.nn import functional as F 18 | from modules.commons import sequence_mask 19 | 20 | MATPLOTLIB_FLAG = False 21 | 22 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 23 | logger = logging 24 | 25 | f0_bin = 256 26 | f0_max = 1100.0 27 | f0_min = 50.0 28 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 29 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 30 | 31 | 32 | # def normalize_f0(f0, random_scale=True): 33 | # f0_norm = f0.clone() # create a copy of the input Tensor 34 | # batch_size, _, frame_length = f0_norm.shape 35 | # for i in range(batch_size): 36 | # means = torch.mean(f0_norm[i, 0, :]) 37 | # if random_scale: 38 | # factor = random.uniform(0.8, 1.2) 39 | # else: 40 | # factor = 1 41 | # f0_norm[i, 0, :] = (f0_norm[i, 0, :] - means) * factor 42 | # return f0_norm 43 | # def normalize_f0(f0, random_scale=True): 44 | # means = torch.mean(f0[:, 0, :], dim=1, keepdim=True) 45 | # if random_scale: 46 | # factor = torch.Tensor(f0.shape[0],1).uniform_(0.8, 1.2).to(f0.device) 47 | # else: 48 | # factor = torch.ones(f0.shape[0], 1, 1).to(f0.device) 49 | # f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) 50 | # return f0_norm 51 | 52 | def deprecated(func): 53 | """This is a decorator which can be used to mark functions 54 | as deprecated. 
It will result in a warning being emitted 55 | when the function is used.""" 56 | @functools.wraps(func) 57 | def new_func(*args, **kwargs): 58 | warnings.simplefilter('always', DeprecationWarning) # turn off filter 59 | warnings.warn("Call to deprecated function {}.".format(func.__name__), 60 | category=DeprecationWarning, 61 | stacklevel=2) 62 | warnings.simplefilter('default', DeprecationWarning) # reset filter 63 | return func(*args, **kwargs) 64 | return new_func 65 | 66 | def normalize_f0(f0, uv, random_scale=True): 67 | # calculate means based on x_mask 68 | uv_sum = torch.sum(uv, dim=1, keepdim=True) 69 | uv_sum[uv_sum == 0] = 9999 70 | means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum 71 | 72 | if random_scale: 73 | factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device) 74 | else: 75 | factor = torch.ones(f0.shape[0], 1).to(f0.device) 76 | # normalize f0 based on means and factor 77 | f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1) 78 | if torch.isnan(f0_norm).any(): 79 | exit(0) 80 | return f0_norm 81 | 82 | def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None,cr_threshold=0.05): 83 | from modules.crepe import CrepePitchExtractor 84 | x = wav_numpy 85 | if p_len is None: 86 | p_len = x.shape[0]//hop_length 87 | else: 88 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error" 89 | 90 | f0_min = 50 91 | f0_max = 1100 92 | F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=cr_threshold) 93 | f0,uv = F0Creper(x[None,:].float(),sampling_rate,pad_to=p_len) 94 | return f0,uv 95 | 96 | def plot_data_to_numpy(x, y): 97 | global MATPLOTLIB_FLAG 98 | if not MATPLOTLIB_FLAG: 99 | import matplotlib 100 | matplotlib.use("Agg") 101 | MATPLOTLIB_FLAG = True 102 | mpl_logger = logging.getLogger('matplotlib') 103 | mpl_logger.setLevel(logging.WARNING) 104 | import matplotlib.pylab as plt 105 | import numpy as np 106 | 107 | fig, ax = plt.subplots(figsize=(10, 2)) 108 | plt.plot(x) 109 | plt.plot(y) 110 | plt.tight_layout() 111 | 112 | fig.canvas.draw() 113 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 114 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 115 | plt.close() 116 | return data 117 | 118 | 119 | 120 | def interpolate_f0(f0): 121 | 122 | data = np.reshape(f0, (f0.size, 1)) 123 | 124 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32) 125 | vuv_vector[data > 0.0] = 1.0 126 | vuv_vector[data <= 0.0] = 0.0 127 | 128 | ip_data = data 129 | 130 | frame_number = data.size 131 | last_value = 0.0 132 | for i in range(frame_number): 133 | if data[i] <= 0.0: 134 | j = i + 1 135 | for j in range(i + 1, frame_number): 136 | if data[j] > 0.0: 137 | break 138 | if j < frame_number - 1: 139 | if last_value > 0.0: 140 | step = (data[j] - data[i - 1]) / float(j - i) 141 | for k in range(i, j): 142 | ip_data[k] = data[i - 1] + step * (k - i + 1) 143 | else: 144 | for k in range(i, j): 145 | ip_data[k] = data[j] 146 | else: 147 | for k in range(i, frame_number): 148 | ip_data[k] = last_value 149 | else: 150 | ip_data[i] = data[i] # this may not be necessary 151 | last_value = data[i] 152 | 153 | return ip_data[:,0], vuv_vector[:,0] 154 | 155 | 156 | def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 157 | import parselmouth 158 | x = wav_numpy 159 | if p_len is None: 160 | p_len = x.shape[0]//hop_length 161 | else: 162 | assert abs(p_len-x.shape[0]//hop_length) < 4, 
"pad length error" 163 | time_step = hop_length / sampling_rate * 1000 164 | f0_min = 50 165 | f0_max = 1100 166 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac( 167 | time_step=time_step / 1000, voicing_threshold=0.6, 168 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] 169 | 170 | pad_size=(p_len - len(f0) + 1) // 2 171 | if(pad_size>0 or p_len - len(f0) - pad_size>0): 172 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant') 173 | return f0 174 | 175 | def resize_f0(x, target_len): 176 | source = np.array(x) 177 | source[source<0.001] = np.nan 178 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source) 179 | res = np.nan_to_num(target) 180 | return res 181 | 182 | def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512): 183 | import pyworld 184 | if p_len is None: 185 | p_len = wav_numpy.shape[0]//hop_length 186 | f0, t = pyworld.dio( 187 | wav_numpy.astype(np.double), 188 | fs=sampling_rate, 189 | f0_ceil=800, 190 | frame_period=1000 * hop_length / sampling_rate, 191 | ) 192 | f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate) 193 | for index, pitch in enumerate(f0): 194 | f0[index] = round(pitch, 1) 195 | return resize_f0(f0, p_len) 196 | 197 | def f0_to_coarse(f0): 198 | is_torch = isinstance(f0, torch.Tensor) 199 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 200 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 201 | 202 | f0_mel[f0_mel <= 1] = 1 203 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 204 | f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(np.int) 205 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 206 | return f0_coarse 207 | 208 | 209 | def get_hubert_model(): 210 | vec_path = "hubert/checkpoint_best_legacy_500.pt" 211 | print("load model(s) from {}".format(vec_path)) 212 | from fairseq import checkpoint_utils 213 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( 214 | [vec_path], 215 | suffix="", 216 | ) 217 | model = models[0] 218 | model.eval() 219 | return model 220 | 221 | def get_hubert_content(hmodel, wav_16k_tensor): 222 | feats = wav_16k_tensor 223 | if feats.dim() == 2: # double channels 224 | feats = feats.mean(-1) 225 | assert feats.dim() == 1, feats.dim() 226 | feats = feats.view(1, -1) 227 | padding_mask = torch.BoolTensor(feats.shape).fill_(False) 228 | inputs = { 229 | "source": feats.to(wav_16k_tensor.device), 230 | "padding_mask": padding_mask.to(wav_16k_tensor.device), 231 | "output_layer": 12, # layer 12 232 | } 233 | with torch.no_grad(): 234 | logits = hmodel.extract_features(**inputs) 235 | feats = hmodel.final_proj(logits[0]) 236 | return feats.transpose(1, 2) 237 | 238 | 239 | def get_content(cmodel, y): 240 | with torch.no_grad(): 241 | c = cmodel.extract_features(y.squeeze(1))[0] 242 | c = c.transpose(1, 2) 243 | return c 244 | 245 | 246 | 247 | def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): 248 | assert os.path.isfile(checkpoint_path) 249 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 250 | iteration = checkpoint_dict['iteration'] 251 | learning_rate = checkpoint_dict['learning_rate'] 252 | if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: 253 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 254 | 
254 |     saved_state_dict = checkpoint_dict['model']
255 |     if hasattr(model, 'module'):
256 |         state_dict = model.module.state_dict()
257 |     else:
258 |         state_dict = model.state_dict()
259 |     new_state_dict = {}
260 |     for k, v in state_dict.items():
261 |         try:
262 |             # assert "dec" in k or "disc" in k
263 |             # print("load", k)
264 |             new_state_dict[k] = saved_state_dict[k]
265 |             assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
266 |         except Exception:
267 |             print("error, %s is missing from the checkpoint or has a mismatched shape" % k)
268 |             logger.info("%s is missing from the checkpoint or has a mismatched shape" % k)
269 |             new_state_dict[k] = v
270 |     if hasattr(model, 'module'):
271 |         model.module.load_state_dict(new_state_dict)
272 |     else:
273 |         model.load_state_dict(new_state_dict)
274 |     print("load ")
275 |     logger.info("Loaded checkpoint '{}' (iteration {})".format(
276 |         checkpoint_path, iteration))
277 |     return model, optimizer, learning_rate, iteration
278 | 
279 | 
280 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
281 |     logger.info("Saving model and optimizer state at iteration {} to {}".format(
282 |         iteration, checkpoint_path))
283 |     if hasattr(model, 'module'):
284 |         state_dict = model.module.state_dict()
285 |     else:
286 |         state_dict = model.state_dict()
287 |     torch.save({'model': state_dict,
288 |                 'iteration': iteration,
289 |                 'optimizer': optimizer.state_dict(),
290 |                 'learning_rate': learning_rate}, checkpoint_path)
291 | 
292 | def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
293 |     """Free up space by deleting saved ckpts
294 | 
295 |     Arguments:
296 |     path_to_models -- Path to the model directory
297 |     n_ckpts_to_keep -- Number of ckpts to keep, excluding any *_0.pth files
298 |     sort_by_time -- True -> chronologically delete ckpts
299 |                     False -> lexicographically delete ckpts
300 |     """
301 |     ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
302 |     name_key = (lambda _f: int(re.compile(r'model-(\d+)\.pt').match(_f).group(1)))
303 |     time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
304 |     sort_key = time_key if sort_by_time else name_key
305 |     x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
306 |     to_del = [os.path.join(path_to_models, fn) for fn in
307 |               (x_sorted('model')[:-n_ckpts_to_keep])]
308 |     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
Free up space by deleting ckpt {fn}") 309 | del_routine = lambda x: [os.remove(x), del_info(x)] 310 | rs = [del_routine(fn) for fn in to_del] 311 | 312 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 313 | for k, v in scalars.items(): 314 | writer.add_scalar(k, v, global_step) 315 | for k, v in histograms.items(): 316 | writer.add_histogram(k, v, global_step) 317 | for k, v in images.items(): 318 | writer.add_image(k, v, global_step, dataformats='HWC') 319 | for k, v in audios.items(): 320 | writer.add_audio(k, v, global_step, audio_sampling_rate) 321 | 322 | 323 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 324 | f_list = glob.glob(os.path.join(dir_path, regex)) 325 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 326 | x = f_list[-1] 327 | print(x) 328 | return x 329 | 330 | 331 | def plot_spectrogram_to_numpy(spectrogram): 332 | global MATPLOTLIB_FLAG 333 | if not MATPLOTLIB_FLAG: 334 | import matplotlib 335 | matplotlib.use("Agg") 336 | MATPLOTLIB_FLAG = True 337 | mpl_logger = logging.getLogger('matplotlib') 338 | mpl_logger.setLevel(logging.WARNING) 339 | import matplotlib.pylab as plt 340 | import numpy as np 341 | 342 | fig, ax = plt.subplots(figsize=(10,2)) 343 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 344 | interpolation='none') 345 | plt.colorbar(im, ax=ax) 346 | plt.xlabel("Frames") 347 | plt.ylabel("Channels") 348 | plt.tight_layout() 349 | 350 | fig.canvas.draw() 351 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 352 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 353 | plt.close() 354 | return data 355 | 356 | 357 | def plot_alignment_to_numpy(alignment, info=None): 358 | global MATPLOTLIB_FLAG 359 | if not MATPLOTLIB_FLAG: 360 | import matplotlib 361 | matplotlib.use("Agg") 362 | MATPLOTLIB_FLAG = True 363 | mpl_logger = logging.getLogger('matplotlib') 364 | mpl_logger.setLevel(logging.WARNING) 365 | import matplotlib.pylab as plt 366 | import numpy as np 367 | 368 | fig, ax = plt.subplots(figsize=(6, 4)) 369 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 370 | interpolation='none') 371 | fig.colorbar(im, ax=ax) 372 | xlabel = 'Decoder timestep' 373 | if info is not None: 374 | xlabel += '\n\n' + info 375 | plt.xlabel(xlabel) 376 | plt.ylabel('Encoder timestep') 377 | plt.tight_layout() 378 | 379 | fig.canvas.draw() 380 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 381 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 382 | plt.close() 383 | return data 384 | 385 | 386 | def load_wav_to_torch(full_path): 387 | sampling_rate, data = read(full_path) 388 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 389 | 390 | 391 | def load_filepaths_and_text(filename, split="|"): 392 | with open(filename, encoding='utf-8') as f: 393 | filepaths_and_text = [line.strip().split(split) for line in f] 394 | return filepaths_and_text 395 | 396 | 397 | def get_hparams(init=True): 398 | parser = argparse.ArgumentParser() 399 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 400 | help='JSON file for configuration') 401 | parser.add_argument('-m', '--model', type=str, required=True, 402 | help='Model name') 403 | 404 | args = parser.parse_args() 405 | model_dir = os.path.join("./logs", args.model) 406 | 407 | if not os.path.exists(model_dir): 408 | os.makedirs(model_dir) 409 | 410 | config_path = args.config 411 | 
config_save_path = os.path.join(model_dir, "config.json") 412 | if init: 413 | with open(config_path, "r") as f: 414 | data = f.read() 415 | with open(config_save_path, "w") as f: 416 | f.write(data) 417 | else: 418 | with open(config_save_path, "r") as f: 419 | data = f.read() 420 | config = json.loads(data) 421 | 422 | hparams = HParams(**config) 423 | hparams.model_dir = model_dir 424 | return hparams 425 | 426 | 427 | def get_hparams_from_dir(model_dir): 428 | config_save_path = os.path.join(model_dir, "config.json") 429 | with open(config_save_path, "r") as f: 430 | data = f.read() 431 | config = json.loads(data) 432 | 433 | hparams =HParams(**config) 434 | hparams.model_dir = model_dir 435 | return hparams 436 | 437 | 438 | def get_hparams_from_file(config_path): 439 | with open(config_path, "r") as f: 440 | data = f.read() 441 | config = json.loads(data) 442 | 443 | hparams =HParams(**config) 444 | return hparams 445 | 446 | 447 | def check_git_hash(model_dir): 448 | source_dir = os.path.dirname(os.path.realpath(__file__)) 449 | if not os.path.exists(os.path.join(source_dir, ".git")): 450 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 451 | source_dir 452 | )) 453 | return 454 | 455 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 456 | 457 | path = os.path.join(model_dir, "githash") 458 | if os.path.exists(path): 459 | saved_hash = open(path).read() 460 | if saved_hash != cur_hash: 461 | logger.warn("git hash values are different. {}(saved) != {}(current)".format( 462 | saved_hash[:8], cur_hash[:8])) 463 | else: 464 | open(path, "w").write(cur_hash) 465 | 466 | 467 | def get_logger(model_dir, filename="train.log"): 468 | global logger 469 | logger = logging.getLogger(os.path.basename(model_dir)) 470 | logger.setLevel(logging.DEBUG) 471 | 472 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 473 | if not os.path.exists(model_dir): 474 | os.makedirs(model_dir) 475 | h = logging.FileHandler(os.path.join(model_dir, filename)) 476 | h.setLevel(logging.DEBUG) 477 | h.setFormatter(formatter) 478 | logger.addHandler(h) 479 | return logger 480 | 481 | 482 | def repeat_expand_2d(content, target_len): 483 | # content : [h, t] 484 | 485 | src_len = content.shape[-1] 486 | target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device) 487 | temp = torch.arange(src_len+1) * target_len / src_len 488 | current_pos = 0 489 | for i in range(target_len): 490 | if i < temp[current_pos+1]: 491 | target[:, i] = content[:, current_pos] 492 | else: 493 | current_pos += 1 494 | target[:, i] = content[:, current_pos] 495 | 496 | return target 497 | 498 | 499 | def mix_model(model_paths,mix_rate,mode): 500 | mix_rate = torch.FloatTensor(mix_rate)/100 501 | model_tem = torch.load(model_paths[0]) 502 | models = [torch.load(path)["model"] for path in model_paths] 503 | if mode == 0: 504 | mix_rate = F.softmax(mix_rate,dim=0) 505 | for k in model_tem["model"].keys(): 506 | model_tem["model"][k] = torch.zeros_like(model_tem["model"][k]) 507 | for i,model in enumerate(models): 508 | model_tem["model"][k] += model[k]*mix_rate[i] 509 | torch.save(model_tem,os.path.join(os.path.curdir,"output.pth")) 510 | return os.path.join(os.path.curdir,"output.pth") 511 | 512 | class HParams(): 513 | def __init__(self, **kwargs): 514 | for k, v in kwargs.items(): 515 | if type(v) == dict: 516 | v = HParams(**v) 517 | self[k] = v 518 | 519 | def keys(self): 520 | return self.__dict__.keys() 521 | 522 | def 
522 |     def items(self):
523 |         return self.__dict__.items()
524 | 
525 |     def values(self):
526 |         return self.__dict__.values()
527 | 
528 |     def __len__(self):
529 |         return len(self.__dict__)
530 | 
531 |     def __getitem__(self, key):
532 |         return getattr(self, key)
533 | 
534 |     def __setitem__(self, key, value):
535 |         return setattr(self, key, value)
536 | 
537 |     def __contains__(self, key):
538 |         return key in self.__dict__
539 | 
540 |     def __repr__(self):
541 |         return self.__dict__.__repr__()
--------------------------------------------------------------------------------
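
The sketch below is only a minimal illustration of how the f0 and contentvec helpers in `utils.py` fit together; it is not the project's actual pipeline, which lives in `preprocess.py`. The file path, the use of librosa for loading/resampling, and the 44.1 kHz / 16 kHz rates are assumptions taken from the defaults in the code; `f0_to_coarse` also relies on the `f0_bin` / `f0_mel_min` / `f0_mel_max` constants defined earlier in `utils.py`, and `get_hubert_model` expects the contentvec checkpoint at `hubert/checkpoint_best_legacy_500.pt`.

```python
# Hypothetical usage sketch of the f0 / contentvec helpers in utils.py.
import librosa
import torch

import utils

wav_path = "raw/example.wav"  # hypothetical input file

# f0 at the model frame rate (44.1 kHz audio, hop_length=512 by default)
wav44k, sr = librosa.load(wav_path, sr=44100)
f0 = utils.compute_f0_dio(wav44k, sampling_rate=sr, hop_length=512)
f0, uv = utils.interpolate_f0(f0)                   # fill unvoiced gaps, get a voiced/unvoiced mask
coarse = utils.f0_to_coarse(torch.from_numpy(f0))   # quantized pitch ids (needs f0_bin etc. from utils.py)

# contentvec features are extracted from 16 kHz audio
wav16k, _ = librosa.load(wav_path, sr=16000)
hmodel = utils.get_hubert_model()                   # loads hubert/checkpoint_best_legacy_500.pt via fairseq
content = utils.get_hubert_content(hmodel, torch.from_numpy(wav16k))  # (1, channels, frames)

print(coarse.shape, content.shape)
```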