├── .gitignore
├── README.md
├── config.json
├── dataset.py
├── dataset
│   ├── 1
│   │   ├── 1.code.pt
│   │   ├── 1.wav
│   │   ├── 1.wav.f0.npy
│   │   └── 1.wav.soft.pt
│   └── 2
│       ├── 2.code.pt
│       ├── 2.wav
│       ├── 2.wav.f0.npy
│       └── 2.wav.soft.pt
├── demo.ipynb
├── hubert
│   └── put_pretrained_model_here
├── infer.py
├── inference
│   ├── infer_tool.py
│   └── slicer.py
├── logs
│   ├── tts
│   │   └── tts_logs_here
│   └── vc
│       └── vc_logs_here
├── model.py
├── modules
│   └── commons.py
├── nsf_hifigan
│   ├── env.py
│   ├── models.py
│   └── utils.py
├── operations.py
├── output
│   └── output_here
├── parametrizations.py
├── parametrize.py
├── preprocess.py
├── raw
│   └── input_here
├── sampler
│   ├── dpm_solver.py
│   └── uni_pc.py
├── test.py
├── train.py
├── unet1d
│   ├── __init__.py
│   ├── activations.py
│   ├── attention.py
│   ├── attention_processor.py
│   ├── dual_transformer_1d.py
│   ├── embeddings.py
│   ├── lora.py
│   ├── outputs.py
│   ├── resnet.py
│   ├── transformer_1d.py
│   ├── unet_1d_blocks.py
│   └── unet_1d_condition.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.pyc
2 | **/*.wav
3 | settings.json
4 | **/*.npy
5 | **/*.pt
6 | **/events.*
7 | */*.json
8 | dataset/*
9 | dataset_processed/*
10 | logs/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # NS2VC_v2
3 |
4 | ## Unofficial implementation of NaturalSpeech2 for Voice Conversion
5 | Unlike the original NS2, this implementation uses Vocos instead of EnCodec as the vocoder for better quality, and uses ContentVec to replace the text embedding and duration/span process.
6 | I also adopted the unet1d conditional model from the diffusers library; thanks to them for their hard work.
7 | 
8 |
9 | ### About Zero shot generalization
10 | I made many attempts to improve the generalization of the model, and found that it behaves much like Stable Diffusion: if a "tag" is not in your training set, you can't get a promising result. A larger dataset with more speakers gives better generalization and better results. The model does ensure that speakers in the training set get good results.
11 | ### Demo
12 | | refer | input | output |
13 | | :----| :---- | :---- |
14 | |[refer0.webm](https://github.com/adelacvg/NS2VC/assets/27419496/abed2fdc-8366-4522-bbc7-646e0ae6b842)| [gt0.webm](https://github.com/adelacvg/NS2VC/assets/27419496/327794b0-e550-4932-8075-4be09e063d45)| [gen0.webm](https://github.com/adelacvg/NS2VC/assets/27419496/3defcd4a-6843-464c-a903-285a14751096)|
15 | |[refer1.webm](https://github.com/adelacvg/NS2VC/assets/27419496/3d924019-0a68-41a5-aeaf-928a9b8fa8b5)| [gt1.webm](https://github.com/adelacvg/NS2VC/assets/27419496/12fc1514-0edb-493d-a07f-3c94b0548557)| [gen1.webm](https://github.com/adelacvg/NS2VC/assets/27419496/f38e8780-1baf-48b5-b6e5-0ba3856599e2)|
16 | |[refer2.webm](https://github.com/adelacvg/NS2VC/assets/27419496/9759088b-10e7-4bb1-a0ed-c808e11b9f9e)|[gt2.webm](https://github.com/adelacvg/NS2VC/assets/27419496/ddff8bfc-7c6a-4d53-9b98-0d66c421d1d1)|[gen2.webm](https://github.com/adelacvg/NS2VC/assets/27419496/d72cb17d-6813-4d87-8ec5-929b2cc2fb15)|
17 | |[refer3.webm](https://github.com/adelacvg/NS2VC/assets/27419496/c9e045ac-914c-4b49-a112-c71acce2eb27)|[gt3.webm](https://github.com/adelacvg/NS2VC/assets/27419496/a684e11d-32fe-46e3-87e0-e0c6047a24dc)|[gen3.webm](https://github.com/adelacvg/NS2VC/assets/27419496/df3ceced-bfae-4272-a8d7-94a49826f04a)|
18 | |[refer4.webm](https://github.com/adelacvg/NS2VC/assets/27419496/e3191a18-44fc-477e-9ed4-60c42ad35b80)|[gt4.webm](https://github.com/adelacvg/NS2VC/assets/27419496/318a0843-89a5-46de-b1e2-2039a457bc17)|[gen4.webm](https://github.com/adelacvg/NS2VC/assets/27419496/06487dab-f047-4461-9e5c-4bd53bfdfd56)|
19 |
20 |
21 |
22 |
23 | ### Data preprocessing
24 | First of all, you need to download the ContentVec model and put it under the `hubert` folder.
25 | The model can be downloaded from here.
26 |
27 | The dataset structure can be like this:
28 |
29 | ```
30 | dataset
31 | ├── spk1
32 | │   ├── 1.wav
33 | │   ├── 2.wav
34 | │   ├── ...
35 | │   └── spk11
36 | │       ├── 11.wav
37 | ├── 3.wav
38 | ├── 4.wav
39 | ```
40 |
41 | Overall, you can organize the data in any way you like.
42 |
43 | Put the files with a `.wav` extension under the `dataset` folder, then run the following command to preprocess the data.
44 |
45 | ```bash
46 | python preprocess.py
47 | ```
48 |
49 | The preprocessed data will be saved under the processed_dataset folder.
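For reference, `dataset.py` expects each `X.wav` in the processed folder to be accompanied by `X.spec.pt`, `X.wav.f0.npy` and `X.wav.soft.pt`. A minimal sketch, assuming your processed data lives at the path configured as `training_files` in `config.json`, to check that these companion files exist before training:

```python
# Sketch: verify the companion files dataset.py will look for next to every .wav.
# The folder path is an assumption; point it at your own processed-data directory.
import os
from glob import glob

processed_dir = "../vc_dataset_processed"  # matches training_files in config.json
for wav in glob(os.path.join(processed_dir, "**/*.wav"), recursive=True):
    expected = [wav.replace(".wav", ".spec.pt"), wav + ".f0.npy", wav + ".soft.pt"]
    missing = [p for p in expected if not os.path.exists(p)]
    if missing:
        print(f"{wav} is missing: {missing}")
```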
50 |
51 | ## Requirements
52 |
53 | You can install the requirements by running the following command.
54 |
55 | ```bash
56 | pip install vocos accelerate matplotlib librosa unidecode inflect ema_pytorch tensorboard fairseq praat-parselmouth pyworld
57 | ```
58 |
59 | ### Training
60 | Install accelerate first, run `accelerate config` to configure the environment, and then run the following command to train the model.
61 |
62 | ```bash
63 | accelerate launch train.py
64 | ```
65 |
66 | ### Inference
67 |
68 | Change the device, model_path, clean_names and refer_names in `infer.py`, and then run the following command to run inference.
69 |
70 | ```bash
71 | python infer.py
72 | ```
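The defaults come from the argparse options in `infer.py` (`-m`, `-c`, `-n`, `-r`, `-t`, ...). An explicit invocation with the default file names might look like this (the checkpoint path is a placeholder for your own run):

```bash
python infer.py -m logs/vc/your_run/model-679.pt -c config.json \
    -n la.wav -r keli.wav -t 0
```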
73 | ### Continue training
74 | If you want to fine-tune or continue training a model,
75 | add
76 | ```python
77 | trainer.load('your_model_path')
78 | ```
79 | to `train.py`.
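A minimal sketch of where that call fits, assuming `train.py` builds a `Trainer` around the model as in `model.py`; the constructor arguments and the `train()` call shown here are placeholders, only `trainer.load()` is taken from this README:

```python
# Hypothetical resume-training sketch; the argument names below are assumptions.
from model import Trainer

trainer = Trainer(cfg_path='config.json')        # placeholder constructor arguments
trainer.load('logs/vc/your_run/model-679.pt')    # resume from an existing checkpoint
trainer.train()                                  # placeholder for the training entry point
```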
80 | ### Pretrained model
81 | Maybe coming soon, if I get enough data to train a good model.
82 |
83 | ### TTS
84 |
85 | If you want to use the TTS model, please check the TTS branch.
86 |
87 | ### Q&A
88 |
89 | QQ group: 801645314
90 | You can join the QQ group to discuss the project.
91 |
92 | Thanks to sovits4, naturalspeech2, imagen and diffusers for their great work.
93 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 |
4 | "train_batch_size":32,
5 | "gradient_accumulate_every": 1,
6 | "train_lr": 0.0001,
7 | "train_num_steps": 1000000,
8 | "ema_update_every": 10,
9 | "ema_decay": 0.995,
10 | "adam_betas": [0.9, 0.99],
11 | "save_and_sample_every":1000,
12 | "timesteps":1000,
13 | "sampling_timesteps":1000,
14 | "results_folder": "results",
15 | "logs_folder" : "logs/vc",
16 | "num_workers": 32,
17 | "eps": 1e-09,
18 | "keep_ckpts": 3,
19 | "all_in_mem": false
20 | },
21 | "data": {
22 | "training_files": "../vc_dataset_processed",
23 | "val_files": "../val_dataset_processed",
24 | "sampling_rate": 24000,
25 | "hop_length": 256
26 | },
27 | "phoneme_encoder":{
28 | "in_channels":256,
29 | "hidden_channels":256,
30 | "out_channels":256,
31 | "n_layers":6,
32 | "p_dropout":0.2
33 | },
34 | "f0_predictor":
35 | {
36 | "in_channels":256,
37 | "hidden_channels":256,
38 | "out_channels":1,
39 | "attention_layers":10,
40 | "n_heads":8,
41 | "p_dropout":0.5
42 | },
43 | "prompt_encoder":{
44 | "in_channels":100,
45 | "hidden_channels":256,
46 | "out_channels":256,
47 | "n_layers":6,
48 | "p_dropout":0.2
49 | },
50 | "diffusion_encoder":{
51 | "in_channels":100,
52 | "out_channels":100,
53 | "hidden_channels":256,
54 | "n_heads":8,
55 | "p_dropout":0.2
56 | }
57 | }
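The config is plain JSON and is read directly by the code; for example `Svc` in `inference/infer_tool.py` does `json.load(open(config_path))` and indexes `cfg['data']['sampling_rate']` and `cfg['data']['hop_length']`. With `sampling_rate` 24000 and `hop_length` 256, one spectrogram frame covers 256 / 24000 ≈ 10.7 ms (about 93.75 frames per second), so the 400-frame training slices taken in `dataset.py` correspond to roughly 4.3 s of audio. A minimal sketch of reading these values:

```python
# Sketch: read config.json the same way the repository code does.
import json

cfg = json.load(open("config.json"))
sr = cfg['data']['sampling_rate']   # 24000
hop = cfg['data']['hop_length']     # 256
print(f"{sr / hop:.2f} spectrogram frames per second")  # ~93.75
```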
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import os
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.utils.data
7 | import torchaudio
8 | import utils
9 | import torchaudio.transforms as T
10 | import random
11 |
12 |
13 | """Multi speaker version"""
14 |
15 | class TestDataset(torch.utils.data.Dataset):
16 | def __init__(self, audio_path, cfg, codec, all_in_mem: bool = False):
17 | self.audiopaths = glob(os.path.join(audio_path, "**/*.wav"), recursive=True)
18 | self.sampling_rate = cfg['data']['sampling_rate']
19 | self.hop_length = cfg['data']['hop_length']
20 | random.shuffle(self.audiopaths)
21 | self.all_in_mem = all_in_mem
22 | if self.all_in_mem:
23 | self.cache = [self.get_audio(p) for p in self.audiopaths]  # audiopaths are plain path strings from glob
24 |
25 | def get_audio(self, filename):
26 | audio, sampling_rate = torchaudio.load(filename)
27 | audio = T.Resample(sampling_rate, self.sampling_rate)(audio)
28 |
29 | spec = torch.load(filename.replace(".wav", ".spec.pt")).squeeze(0)
30 |
31 | f0 = np.load(filename + ".f0.npy")
32 | f0, uv = utils.interpolate_f0(f0)
33 | f0 = torch.FloatTensor(f0)
34 | uv = torch.FloatTensor(uv)
35 |
36 | c = torch.load(filename+ ".soft.pt")
37 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
38 |
39 | lmin = min(c.size(-1), spec.size(-1))
40 | assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename)
41 | assert abs(audio.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
42 | spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
43 | audio = audio[:, :lmin * self.hop_length]
44 | return c.detach(), f0.detach(), spec.detach(), audio.detach(), uv.detach()
45 |
46 | def __getitem__(self, index):
47 | return *self.get_audio(self.audiopaths[index]), *self.get_audio(self.audiopaths[(index+4)%self.__len__()])
48 |
49 | def __len__(self):
50 | return len(self.audiopaths)
51 |
52 |
53 | class NS2VCDataset(torch.utils.data.Dataset):
54 | """
55 | 1) loads audio, speaker_id, text pairs
56 | 2) normalizes text and converts them to sequences of integers
57 | 3) computes spectrograms from audio files.
58 | """
59 |
60 | def __init__(self, audio_path, cfg, codec, all_in_mem: bool = False):
61 | self.audiopaths = glob(os.path.join(audio_path, "**/*.wav"), recursive=True)
62 | self.sampling_rate = cfg['data']['sampling_rate']
63 | self.hop_length = cfg['data']['hop_length']
64 | # self.codec = codec
65 |
66 | # random.seed(1234)
67 | random.shuffle(self.audiopaths)
68 |
69 | self.all_in_mem = all_in_mem
70 | if self.all_in_mem:
71 | self.cache = [self.get_audio(p) for p in self.audiopaths]  # audiopaths are plain path strings from glob
72 |
73 | def get_audio(self, filename):
74 | audio, sampling_rate = torchaudio.load(filename)
75 | audio = T.Resample(sampling_rate, self.sampling_rate)(audio)
76 |
77 | spec = torch.load(filename.replace(".wav", ".spec.pt")).squeeze(0)
78 |
79 | f0 = np.load(filename + ".f0.npy")
80 | f0, uv = utils.interpolate_f0(f0)
81 | f0 = torch.FloatTensor(f0)
82 | uv = torch.FloatTensor(uv)
83 |
84 | c = torch.load(filename+ ".soft.pt")
85 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[0])
86 |
87 | lmin = min(c.size(-1), spec.size(-1))
88 | assert abs(c.size(-1) - spec.size(-1)) < 3, (c.size(-1), spec.size(-1), f0.shape, filename)
89 | assert abs(audio.shape[1]-lmin * self.hop_length) < 3 * self.hop_length
90 | spec, c, f0, uv = spec[:, :lmin], c[:, :lmin], f0[:lmin], uv[:lmin]
91 | audio = audio[:, :lmin * self.hop_length]
92 | return c.detach(), f0.detach(), spec.detach(), audio.detach(), uv.detach()
93 |
94 | def random_slice(self, c, f0, spec, audio, uv):
95 | if spec.shape[1] < 30:
96 | print("skip too short audio")
97 | return None
98 | if spec.shape[1] > 400:
99 | start = random.randint(0, spec.shape[1]-400)
100 | end = start + 400
101 | spec, c, f0, uv = spec[:, start:end], c[:, start:end], f0[start:end], uv[start:end]
102 | audio = audio[:, start * self.hop_length : end * self.hop_length]
103 | len_spec = spec.shape[1]
104 | l = random.randint(int(len_spec//3), int(len_spec//3*2))
105 | u = random.randint(0, len_spec-l)
106 | v = u + l
107 | refer = spec[:, u:v]
108 | c = torch.cat([c[:, :u], c[:, v:]], dim=-1)
109 | f0 = torch.cat([f0[:u], f0[v:]], dim=-1)
110 | spec = torch.cat([spec[:, :u], spec[:, v:]], dim=-1)
111 | uv = torch.cat([uv[:u], uv[v:]], dim=-1)
112 | audio = torch.cat([audio[:, :u * self.hop_length], audio[:, v * self.hop_length:]], dim=-1)
113 | assert c.shape[1] != 0
114 | assert refer.shape[1] != 0
115 | return refer, c, f0, spec, audio, uv
116 |
117 | def __getitem__(self, index):
118 | if self.all_in_mem:
119 | return self.random_slice(*self.cache[index])
120 | else:
121 | return self.random_slice(*self.get_audio(self.audiopaths[index]))
122 | # print(1)
123 |
124 | def __len__(self):
125 | return len(self.audiopaths)
126 |
127 |
128 | class TextAudioCollate:
129 |
130 | def __call__(self, batch):
131 | hop_length = 320
132 | batch = [b for b in batch if b is not None]
133 |
134 | input_lengths, ids_sorted_decreasing = torch.sort(
135 | torch.LongTensor([x[0].shape[1] for x in batch]),
136 | dim=0, descending=True)
137 |
138 | # refer, c, f0, spec, audio, uv
139 | max_refer_len = max([x[0].size(1) for x in batch])
140 | max_c_len = max([x[1].size(1) for x in batch])
141 | max_wav_len = max([x[4].size(1) for x in batch])
142 |
143 | lengths = torch.LongTensor(len(batch))
144 | refer_lengths = torch.LongTensor(len(batch))
145 |
146 | contentvec_dim = batch[0][1].shape[0]
147 | spec_dim = batch[0][3].shape[0]
148 | c_padded = torch.FloatTensor(len(batch), contentvec_dim, max_c_len+1)
149 | f0_padded = torch.FloatTensor(len(batch), max_c_len+1)
150 | spec_padded = torch.FloatTensor(len(batch), spec_dim, max_c_len+1)
151 | refer_padded = torch.FloatTensor(len(batch), spec_dim, max_refer_len+1)
152 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len+1)
153 | uv_padded = torch.FloatTensor(len(batch), max_c_len+1)
154 |
155 | c_padded.zero_()
156 | spec_padded.zero_()
157 | refer_padded.zero_()
158 | f0_padded.zero_()
159 | wav_padded.zero_()
160 | uv_padded.zero_()
161 |
162 | for i in range(len(ids_sorted_decreasing)):
163 | row = batch[ids_sorted_decreasing[i]]
164 |
165 | # refer, c, f0, spec, audio, uv
166 | len_refer = row[0].size(1)
167 | len_contentvec = row[1].size(1)
168 | len_wav = row[4].size(1)
169 |
170 | lengths[i] = len_contentvec
171 | refer_lengths[i] = len_refer
172 |
173 | refer_padded[i, :, :len_refer] = row[0][:]
174 | c_padded[i, :, :len_contentvec] = row[1][:]
175 | f0_padded[i, :len_contentvec] = row[2][:]
176 | spec_padded[i, :, :len_contentvec] = row[3][:]
177 | wav_padded[i, :, :len_wav] = row[4][:]
178 | uv_padded[i, :len_contentvec] = row[5][:]
179 |
180 | return c_padded, refer_padded, f0_padded, spec_padded, wav_padded, lengths, refer_lengths, uv_padded
181 |
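A minimal usage sketch for the classes above, following the standard PyTorch `DataLoader` pattern; the `codec` argument is not used inside `NS2VCDataset`, so it is passed as `None` here, and the batch size and worker count are taken from `config.json`:

```python
# Sketch: wire NS2VCDataset and TextAudioCollate into a DataLoader.
import json
import torch
from dataset import NS2VCDataset, TextAudioCollate

cfg = json.load(open("config.json"))
ds = NS2VCDataset(cfg['data']['training_files'], cfg, codec=None)
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=cfg['train']['train_batch_size'],
    num_workers=cfg['train']['num_workers'],
    collate_fn=TextAudioCollate(),
    shuffle=True,
)
# Each batch unpacks as: c, refer, f0, spec, wav, lengths, refer_lengths, uv
c, refer, f0, spec, wav, lengths, refer_lengths, uv = next(iter(loader))
```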
--------------------------------------------------------------------------------
/dataset/1/1.code.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.code.pt
--------------------------------------------------------------------------------
/dataset/1/1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.wav
--------------------------------------------------------------------------------
/dataset/1/1.wav.f0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.wav.f0.npy
--------------------------------------------------------------------------------
/dataset/1/1.wav.soft.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/1/1.wav.soft.pt
--------------------------------------------------------------------------------
/dataset/2/2.code.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.code.pt
--------------------------------------------------------------------------------
/dataset/2/2.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.wav
--------------------------------------------------------------------------------
/dataset/2/2.wav.f0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.wav.f0.npy
--------------------------------------------------------------------------------
/dataset/2/2.wav.soft.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/dataset/2/2.wav.soft.pt
--------------------------------------------------------------------------------
/hubert/put_pretrained_model_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/hubert/put_pretrained_model_here
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import io
2 | import time
3 | from pathlib import Path
4 |
5 | import librosa
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import soundfile
9 |
10 | from inference import infer_tool
11 | from inference import slicer
12 | from inference.infer_tool import Svc
13 |
14 | def main():
15 | import argparse
16 |
17 | parser = argparse.ArgumentParser(description='ns2vc inference')
18 |
19 | # Required
20 | parser.add_argument('-m', '--model_path', type=str, default="logs/vc/2023-10-01-17-47-21/model-679.pt",
21 | help='Path to the model.')
22 | parser.add_argument('-c', '--config_path', type=str, default="config.json",
23 | help='Path to the configuration file.')
24 | parser.add_argument('-r', '--refer_names', type=str, nargs='+', default=["keli.wav"],
25 | help='Reference audio path.')
26 | parser.add_argument('-n', '--clean_names', type=str, nargs='+', default=["la.wav"],
27 | help='A list of wav file names located in the raw folder.')
28 | parser.add_argument('-t', '--trans', type=int, nargs='+', default=[0],
29 | help='Pitch adjustment, supports positive and negative (semitone) values.')
30 |
31 | # Optional
32 | parser.add_argument('-a', '--auto_predict_f0', action='store_true', default=True,
33 | help='Automatic pitch prediction for voice conversion. Do not enable this when converting songs as it can cause serious pitch issues.')
34 | parser.add_argument('-cl', '--clip', type=float, default=0,
35 | help='Forced voice slicing. Set to 0 to turn it off (default); duration in seconds.')
36 | parser.add_argument('-lg', '--linear_gradient', type=float, default=0,
37 | help='The crossfade length of two audio slices, in seconds. If the voice sounds discontinuous after forced slicing, you can adjust this value; otherwise the default of 0 is recommended.')
38 | parser.add_argument('-fmp', '--f0_mean_pooling', action='store_true', default=False,
39 | help='Apply mean filter (pooling) to f0, which may improve some hoarse sounds. Enabling this option will reduce inference speed.')
40 |
41 | # generally keep default
42 | parser.add_argument('-sd', '--slice_db', type=int, default=-40,
43 | help='Loudness threshold for automatic slicing. For noisy audio it can be set to -30.')
44 | parser.add_argument('-d', '--device', type=str, default='cuda:2',
45 | help='Device used for inference. None means automatic selection.')
46 | parser.add_argument('-p', '--pad_seconds', type=float, default=0.5,
47 | help='Due to unknown reasons, there may be abnormal noise at the beginning and end. It will disappear after padding a short silent segment.')
48 | parser.add_argument('-wf', '--wav_format', type=str, default='wav',
49 | help='output format')
50 | parser.add_argument('-lgr', '--linear_gradient_retain', type=float, default=0.75,
51 | help='Proportion of the crossfade length to retain, range (0-1]. After forced slicing, the beginning and end of each segment need to be discarded.')
52 | parser.add_argument('-ft', '--f0_filter_threshold', type=float, default=0.05,
53 | help='F0 filtering threshold, valid only when f0_mean_pooling is enabled. Range 0-1. Lowering this value reduces the chance of going out of tune, but increases muffled or hoarse sounds.')
54 |
55 |
56 | args = parser.parse_args()
57 |
58 | clean_names = args.clean_names
59 | refer_names = args.refer_names
60 | trans = args.trans
61 | slice_db = args.slice_db
62 | wav_format = args.wav_format
63 | auto_predict_f0 = args.auto_predict_f0
64 | pad_seconds = args.pad_seconds
65 | clip = args.clip
66 | lg = args.linear_gradient
67 | lgr = args.linear_gradient_retain
68 | F0_mean_pooling = args.f0_mean_pooling
69 | cr_threshold = args.f0_filter_threshold
70 |
71 | svc_model = Svc(args.model_path, args.config_path, args.device)
72 | raw_folder = "raw"
73 | results_folder = "output"
74 | infer_tool.mkdir([raw_folder, results_folder])
75 |
76 | infer_tool.fill_a_to_b(trans, clean_names)
77 | for clean_name, tran in zip(clean_names, trans):
78 | raw_audio_path = f"{raw_folder}/{clean_name}"
79 | if "." not in raw_audio_path:
80 | raw_audio_path += ".wav"
81 | infer_tool.format_wav(raw_audio_path)
82 | wav_path = Path(raw_audio_path).with_suffix('.wav')
83 | chunks = slicer.cut(wav_path, db_thresh=slice_db)
84 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
85 | per_size = int(clip*audio_sr)
86 | lg_size = int(lg*audio_sr)
87 | lg_size_r = int(lg_size*lgr)
88 | lg_size_c_l = (lg_size-lg_size_r)//2
89 | lg_size_c_r = lg_size-lg_size_r-lg_size_c_l
90 | lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0
91 |
92 | for refer_name in refer_names:
93 | audio = []
94 | refer_path = f"{raw_folder}/{refer_name}"
95 | if "." not in refer_path:
96 | refer_path += ".wav"
97 | infer_tool.format_wav(refer_path)
98 | refer_path = Path(refer_path).with_suffix('.wav')
99 | for (slice_tag, data) in audio_data:
100 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
101 |
102 | length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
103 | if slice_tag:
104 | print('jump empty segment')
105 | _audio = np.zeros(length)
106 | audio.extend(list(infer_tool.pad_array(_audio, length)))
107 | continue
108 | if per_size != 0:
109 | datas = infer_tool.split_list_by_n(data, per_size,lg_size)
110 | else:
111 | datas = [data]
112 | # print(len(datas))
113 | for k,dat in enumerate(datas):
114 | per_length = int(np.ceil(len(dat) / audio_sr * svc_model.target_sample)) if clip!=0 else length
115 | if clip!=0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
116 | # padd
117 | pad_len = int(audio_sr * pad_seconds)
118 | dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
119 | raw_path = io.BytesIO()
120 | soundfile.write(raw_path, dat, audio_sr, format="wav")
121 | raw_path.seek(0)
122 | out_audio, out_sr = svc_model.infer(tran, raw_path, refer_path,
123 | auto_predict_f0=auto_predict_f0,
124 | F0_mean_pooling = F0_mean_pooling,
125 | cr_threshold = cr_threshold
126 | )
127 | # print(1)
128 | # print(out_audio.shape)
129 | _audio = out_audio.cpu().numpy()
130 | pad_len = int(svc_model.target_sample * pad_seconds)
131 | _audio = _audio[pad_len:-pad_len]
132 | _audio = infer_tool.pad_array(_audio, per_length)
133 | if lg_size!=0 and k!=0:
134 | lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr != 1 else audio[-lg_size:]
135 | lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr != 1 else _audio[0:lg_size]
136 | lg_pre = lg1*(1-lg)+lg2*lg
137 | audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr != 1 else audio[0:-lg_size]
138 | audio.extend(lg_pre)
139 | _audio = _audio[lg_size_c_l+lg_size_r:] if lgr != 1 else _audio[lg_size:]
140 | audio.extend(list(_audio))
141 | # print(1)
142 | key = "auto" if auto_predict_f0 else f"{tran}key"
143 | res_path = f'./{results_folder}/{clean_name}_{key}_{refer_name}.{wav_format}'
144 | soundfile.write(res_path, audio, svc_model.target_sample, format=wav_format)
145 | svc_model.clear_empty()
146 |
147 | if __name__ == '__main__':
148 | main()
149 |
--------------------------------------------------------------------------------
/inference/infer_tool.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import json
4 | import logging
5 | import os
6 | import time
7 | from pathlib import Path
8 | from inference import slicer
9 | import gc
10 | import librosa
11 | import numpy as np
12 | # import onnxruntime
13 | import soundfile
14 | import torch
15 | import torchaudio
16 | from vocos import Vocos
17 | import torchaudio.transforms as T
18 |
19 | from accelerate import Accelerator
20 | import utils
21 | from model import NaturalSpeech2, Trainer
22 |
23 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
24 | def load_mod(model_path, device, cfg):
25 | data = torch.load(model_path, map_location=device)
26 | model = NaturalSpeech2(cfg=cfg)
27 | model.load_state_dict(data['model'])
28 | model.to(device)
29 | return model
30 |
31 | def read_temp(file_name):
32 | if not os.path.exists(file_name):
33 | with open(file_name, "w") as f:
34 | f.write(json.dumps({"info": "temp_dict"}))
35 | return {}
36 | else:
37 | try:
38 | with open(file_name, "r") as f:
39 | data = f.read()
40 | data_dict = json.loads(data)
41 | if os.path.getsize(file_name) > 50 * 1024 * 1024:
42 | f_name = file_name.replace("\\", "/").split("/")[-1]
43 | print(f"clean {f_name}")
44 | for wav_hash in list(data_dict.keys()):
45 | if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600:
46 | del data_dict[wav_hash]
47 | except Exception as e:
48 | print(e)
49 | print(f"{file_name} error, rebuilding file automatically")
50 | data_dict = {"info": "temp_dict"}
51 | return data_dict
52 |
53 |
54 | def write_temp(file_name, data):
55 | with open(file_name, "w") as f:
56 | f.write(json.dumps(data))
57 |
58 |
59 | def timeit(func):
60 | def run(*args, **kwargs):
61 | t = time.time()
62 | res = func(*args, **kwargs)
63 | print('executing \'%s\' cost %.3fs' % (func.__name__, time.time() - t))
64 | return res
65 |
66 | return run
67 |
68 |
69 | def format_wav(audio_path):
70 | if Path(audio_path).suffix == '.wav':
71 | return
72 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
73 | soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
74 |
75 |
76 | def get_end_file(dir_path, end):
77 | file_lists = []
78 | for root, dirs, files in os.walk(dir_path):
79 | files = [f for f in files if f[0] != '.']
80 | dirs[:] = [d for d in dirs if d[0] != '.']
81 | for f_file in files:
82 | if f_file.endswith(end):
83 | file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
84 | return file_lists
85 |
86 |
87 | def get_md5(content):
88 | return hashlib.new("md5", content).hexdigest()
89 |
90 | def fill_a_to_b(a, b):
91 | if len(a) < len(b):
92 | for _ in range(0, len(b) - len(a)):
93 | a.append(a[0])
94 |
95 | def mkdir(paths: list):
96 | for path in paths:
97 | if not os.path.exists(path):
98 | os.mkdir(path)
99 |
100 | def pad_array(arr, target_length):
101 | current_length = arr.shape[0]
102 | if current_length >= target_length:
103 | return arr
104 | else:
105 | pad_width = target_length - current_length
106 | pad_left = pad_width // 2
107 | pad_right = pad_width - pad_left
108 | padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
109 | return padded_arr
110 |
111 | def split_list_by_n(list_collection, n, pre=0):
112 | for i in range(0, len(list_collection), n):
113 | yield list_collection[i-pre if i-pre>=0 else i: i + n]
114 |
115 |
116 | class F0FilterException(Exception):
117 | pass
118 |
119 | class Svc(object):
120 | def __init__(self, model_path, config_path,
121 | device=None,
122 | ):
123 | self.model_path = model_path
124 | if device is None:
125 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126 | else:
127 | self.dev = torch.device(device)
128 | self.model = None
129 | self.cfg = json.load(open(config_path))
130 | self.target_sample = self.cfg['data']['sampling_rate']
131 | self.hop_size = self.cfg['data']['hop_length']
132 | # load hubert
133 | self.hubert_model = utils.get_hubert_model().to(self.dev)
134 | self.load_model()
135 | self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
136 |
137 | def load_model(self):
138 | self.model = load_mod(self.model_path, self.dev, self.cfg)
139 | self.model.eval()
140 |
141 | def get_unit_f0_code(self, in_path, tran, refer_path, f0_filter ,F0_mean_pooling,cr_threshold=0.05):
142 | # c, refer, f0, uv, lengths, refer_lengths
143 | wav, sr = librosa.load(in_path, sr=self.target_sample)
144 |
145 | if F0_mean_pooling == True:
146 | f0, uv = utils.compute_f0_uv_torchcrepe(torch.FloatTensor(wav), sampling_rate=self.target_sample, hop_length=self.hop_size,device=self.dev,cr_threshold = cr_threshold)
147 | if f0_filter and sum(f0) == 0:
148 | raise F0FilterException("No voice detected")
149 | f0 = torch.FloatTensor(list(f0))
150 | uv = torch.FloatTensor(list(uv))
151 | if F0_mean_pooling == False:
152 | f0 = utils.compute_f0_parselmouth(wav, sampling_rate=self.target_sample, hop_length=self.hop_size)
153 | if f0_filter and sum(f0) == 0:
154 | raise F0FilterException("No voice detected")
155 | f0, uv = utils.interpolate_f0(f0)
156 | f0 = torch.FloatTensor(f0)
157 | uv = torch.FloatTensor(uv)
158 |
159 | f0 = f0 * 2 ** (tran / 12)
160 | f0 = f0.unsqueeze(0).to(self.dev)
161 | uv = uv.unsqueeze(0).to(self.dev)
162 |
163 | wav16k = librosa.resample(wav, orig_sr=self.target_sample, target_sr=16000)
164 | wav16k = torch.from_numpy(wav16k).to(self.dev)
165 | c = utils.get_hubert_content(self.hubert_model, wav_16k_tensor=wav16k)
166 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1])
167 |
168 | c = c.unsqueeze(0).to(self.dev)
169 |
170 | refer_wav, sr = torchaudio.load(refer_path)
171 | wav24k = T.Resample(sr, 24000)(refer_wav)
172 | spec_process = torchaudio.transforms.MelSpectrogram(
173 | sample_rate=24000,
174 | n_fft=1024,
175 | hop_length=256,
176 | n_mels=100,
177 | center=True,
178 | power=1,
179 | )
180 | spec = spec_process(wav24k)# 1 100 T
181 | spec = torch.log(torch.clip(spec, min=1e-7))
182 | refer = spec.to(self.dev)
183 |
184 | lengths = torch.LongTensor([c.shape[2]]).to(self.dev)
185 | refer_lengths = torch.LongTensor([refer.shape[2]]).to(self.dev)
186 |
187 | return c, refer, f0, uv, lengths, refer_lengths
188 |
189 | def infer(self, tran,
190 | raw_path,
191 | refer_path,
192 | auto_predict_f0=False,
193 | f0_filter=False,
194 | F0_mean_pooling=False,
195 | cr_threshold = 0.05
196 | ):
197 |
198 | c, refer, f0, uv, lengths, refer_lengths = self.get_unit_f0_code(raw_path, tran, refer_path, f0_filter,F0_mean_pooling,cr_threshold=cr_threshold)
199 | with torch.no_grad():
200 | start = time.time()
201 | audio,mel = self.model.sample(c, refer, f0, uv, lengths, refer_lengths, self.vocos, auto_predict_f0 =auto_predict_f0)
202 | audio = audio[0].detach().cpu()
203 | # print(audio.shape)
204 | use_time = time.time() - start
205 | print("ns2vc use time:{}".format(use_time))
206 | return audio, audio.shape[-1]
207 |
208 | def clear_empty(self):
209 | # clean up vram
210 | torch.cuda.empty_cache()
211 |
212 | def unload_model(self):
213 | # unload model
214 | self.model = self.model.to("cpu")
215 | del self.model
216 | gc.collect()
217 |
218 | def slice_inference(self,
219 | raw_audio_path,
220 | spk,
221 | tran,
222 | slice_db,
223 | cluster_infer_ratio,
224 | auto_predict_f0,
225 | noice_scale,
226 | pad_seconds=0.5,
227 | clip_seconds=0,
228 | lg_num=0,
229 | lgr_num =0.75,
230 | F0_mean_pooling = False,
231 | enhancer_adaptive_key = 0,
232 | cr_threshold = 0.05
233 | ):
234 | wav_path = raw_audio_path
235 | chunks = slicer.cut(wav_path, db_thresh=slice_db)
236 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
237 | per_size = int(clip_seconds*audio_sr)
238 | lg_size = int(lg_num*audio_sr)
239 | lg_size_r = int(lg_size*lgr_num)
240 | lg_size_c_l = (lg_size-lg_size_r)//2
241 | lg_size_c_r = lg_size-lg_size_r-lg_size_c_l
242 | lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0
243 |
244 | audio = []
245 | for (slice_tag, data) in audio_data:
246 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
247 | # padd
248 | length = int(np.ceil(len(data) / audio_sr * self.target_sample))
249 | if slice_tag:
250 | print('jump empty segment')
251 | _audio = np.zeros(length)
252 | audio.extend(list(pad_array(_audio, length)))
253 | continue
254 | if per_size != 0:
255 | datas = split_list_by_n(data, per_size,lg_size)
256 | else:
257 | datas = [data]
258 | for k,dat in enumerate(datas):
259 | per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
260 | if clip_seconds!=0: print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
261 | # padd
262 | pad_len = int(audio_sr * pad_seconds)
263 | dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
264 | raw_path = io.BytesIO()
265 | soundfile.write(raw_path, dat, audio_sr, format="wav")
266 | raw_path.seek(0)
267 | out_audio, out_sr = self.infer(spk, tran, raw_path,
268 | cluster_infer_ratio=cluster_infer_ratio,
269 | auto_predict_f0=auto_predict_f0,
270 | noice_scale=noice_scale,
271 | F0_mean_pooling = F0_mean_pooling,
272 | enhancer_adaptive_key = enhancer_adaptive_key,
273 | cr_threshold = cr_threshold
274 | )
275 | _audio = out_audio.cpu().numpy()
276 | pad_len = int(self.target_sample * pad_seconds)
277 | _audio = _audio[pad_len:-pad_len]
278 | _audio = pad_array(_audio, per_length)
279 | if lg_size!=0 and k!=0:
280 | lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:]
281 | lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr_num != 1 else _audio[0:lg_size]
282 | lg_pre = lg1*(1-lg)+lg2*lg
283 | audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size]
284 | audio.extend(lg_pre)
285 | _audio = _audio[lg_size_c_l+lg_size_r:] if lgr_num != 1 else _audio[lg_size:]
286 | audio.extend(list(_audio))
287 | return np.array(audio)
288 |
289 | class RealTimeVC:
290 | def __init__(self):
291 | self.last_chunk = None
292 | self.last_o = None
293 | self.chunk_len = 16000 # chunk length
294 | self.pre_len = 3840 # cross fade length, multiples of 640
295 |
296 | # Input and output are 1-dimensional numpy waveform arrays
297 |
298 | def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path,
299 | cluster_infer_ratio=0,
300 | auto_predict_f0=False,
301 | noice_scale=0.4,
302 | f0_filter=False):
303 |
304 | import maad
305 | audio, sr = torchaudio.load(input_wav_path)
306 | audio = audio.cpu().numpy()[0]
307 | temp_wav = io.BytesIO()
308 | if self.last_chunk is None:
309 | input_wav_path.seek(0)
310 |
311 | audio, sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path,
312 | cluster_infer_ratio=cluster_infer_ratio,
313 | auto_predict_f0=auto_predict_f0,
314 | noice_scale=noice_scale,
315 | f0_filter=f0_filter)
316 |
317 | audio = audio.cpu().numpy()
318 | self.last_chunk = audio[-self.pre_len:]
319 | self.last_o = audio
320 | return audio[-self.chunk_len:]
321 | else:
322 | audio = np.concatenate([self.last_chunk, audio])
323 | soundfile.write(temp_wav, audio, sr, format="wav")
324 | temp_wav.seek(0)
325 |
326 | audio, sr = svc_model.infer(speaker_id, f_pitch_change, temp_wav,
327 | cluster_infer_ratio=cluster_infer_ratio,
328 | auto_predict_f0=auto_predict_f0,
329 | noice_scale=noice_scale,
330 | f0_filter=f0_filter)
331 |
332 | audio = audio.cpu().numpy()
333 | ret = maad.util.crossfade(self.last_o, audio, self.pre_len)
334 | self.last_chunk = audio[-self.pre_len:]
335 | self.last_o = audio
336 | return ret[self.chunk_len:2 * self.chunk_len]
337 |
--------------------------------------------------------------------------------
/inference/slicer.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import torch
3 | import torchaudio
4 |
5 |
6 | class Slicer:
7 | def __init__(self,
8 | sr: int,
9 | threshold: float = -40.,
10 | min_length: int = 5000,
11 | min_interval: int = 300,
12 | hop_size: int = 20,
13 | max_sil_kept: int = 5000):
14 | if not min_length >= min_interval >= hop_size:
15 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
16 | if not max_sil_kept >= hop_size:
17 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
18 | min_interval = sr * min_interval / 1000
19 | self.threshold = 10 ** (threshold / 20.)
20 | self.hop_size = round(sr * hop_size / 1000)
21 | self.win_size = min(round(min_interval), 4 * self.hop_size)
22 | self.min_length = round(sr * min_length / 1000 / self.hop_size)
23 | self.min_interval = round(min_interval / self.hop_size)
24 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
25 |
26 | def _apply_slice(self, waveform, begin, end):
27 | if len(waveform.shape) > 1:
28 | return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
29 | else:
30 | return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
31 |
32 | # @timeit
33 | def slice(self, waveform):
34 | if len(waveform.shape) > 1:
35 | samples = librosa.to_mono(waveform)
36 | else:
37 | samples = waveform
38 | if samples.shape[0] <= self.min_length:
39 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
40 | rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
41 | sil_tags = []
42 | silence_start = None
43 | clip_start = 0
44 | for i, rms in enumerate(rms_list):
45 | # Keep looping while frame is silent.
46 | if rms < self.threshold:
47 | # Record start of silent frames.
48 | if silence_start is None:
49 | silence_start = i
50 | continue
51 | # Keep looping while frame is not silent and silence start has not been recorded.
52 | if silence_start is None:
53 | continue
54 | # Clear recorded silence start if interval is not enough or clip is too short
55 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept
56 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
57 | if not is_leading_silence and not need_slice_middle:
58 | silence_start = None
59 | continue
60 | # Need slicing. Record the range of silent frames to be removed.
61 | if i - silence_start <= self.max_sil_kept:
62 | pos = rms_list[silence_start: i + 1].argmin() + silence_start
63 | if silence_start == 0:
64 | sil_tags.append((0, pos))
65 | else:
66 | sil_tags.append((pos, pos))
67 | clip_start = pos
68 | elif i - silence_start <= self.max_sil_kept * 2:
69 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
70 | pos += i - self.max_sil_kept
71 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
72 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
73 | if silence_start == 0:
74 | sil_tags.append((0, pos_r))
75 | clip_start = pos_r
76 | else:
77 | sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
78 | clip_start = max(pos_r, pos)
79 | else:
80 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
81 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
82 | if silence_start == 0:
83 | sil_tags.append((0, pos_r))
84 | else:
85 | sil_tags.append((pos_l, pos_r))
86 | clip_start = pos_r
87 | silence_start = None
88 | # Deal with trailing silence.
89 | total_frames = rms_list.shape[0]
90 | if silence_start is not None and total_frames - silence_start >= self.min_interval:
91 | silence_end = min(total_frames, silence_start + self.max_sil_kept)
92 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
93 | sil_tags.append((pos, total_frames + 1))
94 | # Apply and return slices.
95 | if len(sil_tags) == 0:
96 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
97 | else:
98 | chunks = []
99 | # The first segment is not the beginning of the audio.
100 | if sil_tags[0][0]:
101 | chunks.append(
102 | {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
103 | for i in range(0, len(sil_tags)):
104 | # Mark audio segment. Skip the first segment.
105 | if i:
106 | chunks.append({"slice": False,
107 | "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
108 | # Mark all mute segments
109 | chunks.append({"slice": True,
110 | "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
111 | # The last segment is not the end.
112 | if sil_tags[-1][1] * self.hop_size < len(waveform):
113 | chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
114 | chunk_dict = {}
115 | for i in range(len(chunks)):
116 | chunk_dict[str(i)] = chunks[i]
117 | return chunk_dict
118 |
119 |
120 | def cut(audio_path, db_thresh=-30, min_len=5000):
121 | audio, sr = librosa.load(audio_path, sr=None)
122 | slicer = Slicer(
123 | sr=sr,
124 | threshold=db_thresh,
125 | min_length=min_len
126 | )
127 | chunks = slicer.slice(audio)
128 | return chunks
129 |
130 |
131 | def chunks2audio(audio_path, chunks):
132 | chunks = dict(chunks)
133 | audio, sr = torchaudio.load(audio_path)
134 | if len(audio.shape) == 2 and audio.shape[1] >= 2:
135 | audio = torch.mean(audio, dim=0).unsqueeze(0)
136 | audio = audio.cpu().numpy()[0]
137 | result = []
138 | for k, v in chunks.items():
139 | tag = v["split_time"].split(",")
140 | if tag[0] != tag[1]:
141 | result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
142 | return result, sr
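For reference, `infer.py` drives this module roughly as follows; a minimal standalone sketch (the input path is a placeholder):

```python
# Sketch: slice a wav on silence and iterate over the resulting chunks,
# mirroring how infer.py uses slicer.cut and slicer.chunks2audio.
from inference import slicer

wav_path = "raw/la.wav"  # placeholder input file
chunks = slicer.cut(wav_path, db_thresh=-40)
audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
for is_silent, data in audio_data:
    print("silent" if is_silent else "voiced", round(len(data) / audio_sr, 3), "s")
```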
--------------------------------------------------------------------------------
/logs/tts/tts_logs_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/logs/tts/tts_logs_here
--------------------------------------------------------------------------------
/logs/vc/vc_logs_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/logs/vc/vc_logs_here
--------------------------------------------------------------------------------
/modules/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | def slice_pitch_segments(x, ids_str, segment_size=4):
8 | ret = torch.zeros_like(x[:, :segment_size])
9 | for i in range(x.size(0)):
10 | idx_str = ids_str[i]
11 | idx_end = idx_str + segment_size
12 | ret[i] = x[i, idx_str:idx_end]
13 | return ret
14 |
15 | def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
16 | b, d, t = x.size()
17 | if x_lengths is None:
18 | x_lengths = t
19 | ids_str_max = x_lengths - segment_size + 1
20 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
21 | ret = slice_segments(x, ids_str, segment_size)
22 | ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
23 | return ret, ret_pitch, ids_str
24 |
25 | def init_weights(m, mean=0.0, std=0.01):
26 | classname = m.__class__.__name__
27 | if classname.find("Conv") != -1:
28 | m.weight.data.normal_(mean, std)
29 |
30 |
31 | def get_padding(kernel_size, dilation=1):
32 | return int((kernel_size*dilation - dilation)/2)
33 |
34 |
35 | def convert_pad_shape(pad_shape):
36 | l = pad_shape[::-1]
37 | pad_shape = [item for sublist in l for item in sublist]
38 | return pad_shape
39 |
40 |
41 | def intersperse(lst, item):
42 | result = [item] * (len(lst) * 2 + 1)
43 | result[1::2] = lst
44 | return result
45 |
46 |
47 | def kl_divergence(m_p, logs_p, m_q, logs_q):
48 | """KL(P||Q)"""
49 | kl = (logs_q - logs_p) - 0.5
50 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
51 | return kl
52 |
53 |
54 | def rand_gumbel(shape):
55 | """Sample from the Gumbel distribution, protect from overflows."""
56 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
57 | return -torch.log(-torch.log(uniform_samples))
58 |
59 |
60 | def rand_gumbel_like(x):
61 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
62 | return g
63 |
64 |
65 | def slice_segments(x, ids_str, segment_size=4):
66 | ret = torch.zeros_like(x[:, :, :segment_size])
67 | for i in range(x.size(0)):
68 | idx_str = ids_str[i]
69 | idx_end = idx_str + segment_size
70 | ret[i] = x[i, :, idx_str:idx_end]
71 | return ret
72 |
73 |
74 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
75 | b, d, t = x.size()
76 | if x_lengths is None:
77 | x_lengths = t
78 | ids_str_max = x_lengths - segment_size + 1
79 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
80 | ret = slice_segments(x, ids_str, segment_size)
81 | return ret, ids_str
82 |
83 |
84 | def rand_spec_segments(x, x_lengths=None, segment_size=4):
85 | b, d, t = x.size()
86 | if x_lengths is None:
87 | x_lengths = t
88 | ids_str_max = x_lengths - segment_size
89 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
90 | ret = slice_segments(x, ids_str, segment_size)
91 | return ret, ids_str
92 |
93 |
94 | def get_timing_signal_1d(
95 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
96 | position = torch.arange(length, dtype=torch.float)
97 | num_timescales = channels // 2
98 | log_timescale_increment = (
99 | math.log(float(max_timescale) / float(min_timescale)) /
100 | (num_timescales - 1))
101 | inv_timescales = min_timescale * torch.exp(
102 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
103 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
104 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
105 | signal = F.pad(signal, [0, 0, 0, channels % 2])
106 | signal = signal.view(1, channels, length)
107 | return signal
108 |
109 |
110 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
111 | b, channels, length = x.size()
112 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
113 | return x + signal.to(dtype=x.dtype, device=x.device)
114 |
115 |
116 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
117 | b, channels, length = x.size()
118 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
119 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
120 |
121 |
122 | def subsequent_mask(length):
123 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
124 | return mask
125 |
126 |
127 | @torch.jit.script
128 | def fused_add_tanh_sigmoid_multiply(input_a, n_channels):
129 | n_channels_int = n_channels[0]
130 | in_act = input_a
131 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
132 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
133 | # print(t_act.size(), s_act.size())
134 | acts = t_act * s_act
135 | return acts
136 |
137 |
138 | def convert_pad_shape(pad_shape):
139 | l = pad_shape[::-1]
140 | pad_shape = [item for sublist in l for item in sublist]
141 | return pad_shape
142 |
143 |
144 | def shift_1d(x):
145 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
146 | return x
147 |
148 |
149 | def sequence_mask(length, max_length=None):
150 | if max_length is None:
151 | max_length = length.max()
152 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
153 | return x.unsqueeze(0) < length.unsqueeze(1)
154 |
155 |
156 | def generate_path(duration, mask):
157 | """
158 | duration: [b, 1, t_x]
159 | mask: [b, 1, t_y, t_x]
160 | """
161 | device = duration.device
162 |
163 | b, _, t_y, t_x = mask.shape
164 | cum_duration = torch.cumsum(duration, -1)
165 |
166 | cum_duration_flat = cum_duration.view(b * t_x)
167 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
168 | path = path.view(b, t_x, t_y)
169 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
170 | path = path.unsqueeze(1).transpose(2,3) * mask
171 | return path
172 |
173 |
174 | def clip_grad_value_(parameters, clip_value, norm_type=2):
175 | if isinstance(parameters, torch.Tensor):
176 | parameters = [parameters]
177 | parameters = list(filter(lambda p: p.grad is not None, parameters))
178 | norm_type = float(norm_type)
179 | if clip_value is not None:
180 | clip_value = float(clip_value)
181 |
182 | total_norm = 0
183 | for p in parameters:
184 | param_norm = p.grad.data.norm(norm_type)
185 | total_norm += param_norm.item() ** norm_type
186 | if clip_value is not None:
187 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
188 | total_norm = total_norm ** (1. / norm_type)
189 | return total_norm
--------------------------------------------------------------------------------
/nsf_hifigan/env.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 |
5 | class AttrDict(dict):
6 | def __init__(self, *args, **kwargs):
7 | super(AttrDict, self).__init__(*args, **kwargs)
8 | self.__dict__ = self
9 |
10 |
11 | def build_env(config, config_name, path):
12 | t_path = os.path.join(path, config_name)
13 | if config != t_path:
14 | os.makedirs(path, exist_ok=True)
15 | shutil.copyfile(config, os.path.join(path, config_name))
--------------------------------------------------------------------------------
/nsf_hifigan/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from .env import AttrDict
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
10 | from .utils import init_weights, get_padding
11 |
12 | LRELU_SLOPE = 0.1
13 |
14 |
15 | def load_model(model_path, device='cuda'):
16 | h = load_config(model_path)
17 |
18 | generator = Generator(h).to(device)
19 |
20 | cp_dict = torch.load(model_path, map_location=device)
21 | generator.load_state_dict(cp_dict['generator'])
22 | generator.eval()
23 | generator.remove_weight_norm()
24 | del cp_dict
25 | return generator, h
26 |
27 | def load_config(model_path):
28 | config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
29 | with open(config_file) as f:
30 | data = f.read()
31 |
32 | json_config = json.loads(data)
33 | h = AttrDict(json_config)
34 | return h
35 |
36 |
37 | class ResBlock1(torch.nn.Module):
38 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
39 | super(ResBlock1, self).__init__()
40 | self.h = h
41 | self.convs1 = nn.ModuleList([
42 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
43 | padding=get_padding(kernel_size, dilation[0]))),
44 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
45 | padding=get_padding(kernel_size, dilation[1]))),
46 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
47 | padding=get_padding(kernel_size, dilation[2])))
48 | ])
49 | self.convs1.apply(init_weights)
50 |
51 | self.convs2 = nn.ModuleList([
52 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
53 | padding=get_padding(kernel_size, 1))),
54 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
55 | padding=get_padding(kernel_size, 1))),
56 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
57 | padding=get_padding(kernel_size, 1)))
58 | ])
59 | self.convs2.apply(init_weights)
60 |
61 | def forward(self, x):
62 | for c1, c2 in zip(self.convs1, self.convs2):
63 | xt = F.leaky_relu(x, LRELU_SLOPE)
64 | xt = c1(xt)
65 | xt = F.leaky_relu(xt, LRELU_SLOPE)
66 | xt = c2(xt)
67 | x = xt + x
68 | return x
69 |
70 | def remove_weight_norm(self):
71 | for l in self.convs1:
72 | remove_weight_norm(l)
73 | for l in self.convs2:
74 | remove_weight_norm(l)
75 |
76 |
77 | class ResBlock2(torch.nn.Module):
78 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
79 | super(ResBlock2, self).__init__()
80 | self.h = h
81 | self.convs = nn.ModuleList([
82 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
83 | padding=get_padding(kernel_size, dilation[0]))),
84 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
85 | padding=get_padding(kernel_size, dilation[1])))
86 | ])
87 | self.convs.apply(init_weights)
88 |
89 | def forward(self, x):
90 | for c in self.convs:
91 | xt = F.leaky_relu(x, LRELU_SLOPE)
92 | xt = c(xt)
93 | x = xt + x
94 | return x
95 |
96 | def remove_weight_norm(self):
97 | for l in self.convs:
98 | remove_weight_norm(l)
99 |
100 |
101 | class SineGen(torch.nn.Module):
102 | """ Definition of sine generator
103 | SineGen(samp_rate, harmonic_num = 0,
104 | sine_amp = 0.1, noise_std = 0.003,
105 | voiced_threshold = 0,
106 | flag_for_pulse=False)
107 | samp_rate: sampling rate in Hz
108 | harmonic_num: number of harmonic overtones (default 0)
109 | sine_amp: amplitude of sine waveform (default 0.1)
110 | noise_std: std of Gaussian noise (default 0.003)
111 | voiced_threshold: F0 threshold for U/V classification (default 0)
112 | flag_for_pulse: this SineGen is used inside PulseGen (default False)
113 | Note: when flag_for_pulse is True, the first time step of a voiced
114 | segment is always sin(np.pi) or cos(0)
115 | """
116 |
117 | def __init__(self, samp_rate, harmonic_num=0,
118 | sine_amp=0.1, noise_std=0.003,
119 | voiced_threshold=0):
120 | super(SineGen, self).__init__()
121 | self.sine_amp = sine_amp
122 | self.noise_std = noise_std
123 | self.harmonic_num = harmonic_num
124 | self.dim = self.harmonic_num + 1
125 | self.sampling_rate = samp_rate
126 | self.voiced_threshold = voiced_threshold
127 |
128 | def _f02uv(self, f0):
129 | # generate uv signal
130 | uv = torch.ones_like(f0)
131 | uv = uv * (f0 > self.voiced_threshold)
132 | return uv
133 |
134 | @torch.no_grad()
135 | def forward(self, f0, upp):
136 | """ sine_tensor, uv = forward(f0)
137 | input F0: tensor(batchsize=1, length, dim=1)
138 | f0 for unvoiced steps should be 0
139 | output sine_tensor: tensor(batchsize=1, length, dim)
140 | output uv: tensor(batchsize=1, length, 1)
141 | """
142 | f0 = f0.unsqueeze(-1)
143 | fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1)))
144 | rad_values = (fn / self.sampling_rate) % 1  # taking % 1 here means the per-harmonic products cannot be optimized in post-processing
145 | rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device)
146 | rand_ini[:, 0] = 0
147 | rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
148 | is_half = rad_values.dtype is not torch.float32
149 | tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1  (taking % 1 here would prevent the following cumsum from being optimized)
150 | if is_half:
151 | tmp_over_one = tmp_over_one.half()
152 | else:
153 | tmp_over_one = tmp_over_one.float()
154 | tmp_over_one *= upp
155 | tmp_over_one = F.interpolate(
156 | tmp_over_one.transpose(2, 1), scale_factor=upp,
157 | mode='linear', align_corners=True
158 | ).transpose(2, 1)
159 | rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1)
160 | tmp_over_one %= 1
161 | tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
162 | cumsum_shift = torch.zeros_like(rad_values)
163 | cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
164 | rad_values = rad_values.double()
165 | cumsum_shift = cumsum_shift.double()
166 | sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
167 | if is_half:
168 | sine_waves = sine_waves.half()
169 | else:
170 | sine_waves = sine_waves.float()
171 | sine_waves = sine_waves * self.sine_amp
172 | return sine_waves
173 |
174 |
175 | class SourceModuleHnNSF(torch.nn.Module):
176 | """ SourceModule for hn-nsf
177 | SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
178 | add_noise_std=0.003, voiced_threshod=0)
179 | sampling_rate: sampling_rate in Hz
180 | harmonic_num: number of harmonic above F0 (default: 0)
181 | sine_amp: amplitude of sine source signal (default: 0.1)
182 | add_noise_std: std of additive Gaussian noise (default: 0.003)
183 | note that amplitude of noise in unvoiced is decided
184 | by sine_amp
185 | voiced_threshold: threshold to set U/V given F0 (default: 0)
186 | Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
187 | F0_sampled (batchsize, length, 1)
188 | Sine_source (batchsize, length, 1)
189 | noise_source (batchsize, length 1)
190 | uv (batchsize, length, 1)
191 | """
192 |
193 | def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
194 | add_noise_std=0.003, voiced_threshod=0):
195 | super(SourceModuleHnNSF, self).__init__()
196 |
197 | self.sine_amp = sine_amp
198 | self.noise_std = add_noise_std
199 |
200 | # to produce sine waveforms
201 | self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
202 | sine_amp, add_noise_std, voiced_threshod)
203 |
204 | # to merge source harmonics into a single excitation
205 | self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
206 | self.l_tanh = torch.nn.Tanh()
207 |
208 | def forward(self, x, upp):
209 | sine_wavs = self.l_sin_gen(x, upp)
210 | sine_merge = self.l_tanh(self.l_linear(sine_wavs))
211 | return sine_merge
212 |
213 |
214 | class Generator(torch.nn.Module):
215 | def __init__(self, h):
216 | super(Generator, self).__init__()
217 | self.h = h
218 | self.num_kernels = len(h.resblock_kernel_sizes)
219 | self.num_upsamples = len(h.upsample_rates)
220 | self.m_source = SourceModuleHnNSF(
221 | sampling_rate=h.sampling_rate,
222 | harmonic_num=8
223 | )
224 | self.noise_convs = nn.ModuleList()
225 | self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
226 | resblock = ResBlock1 if h.resblock == '1' else ResBlock2
227 |
228 | self.ups = nn.ModuleList()
229 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
230 | c_cur = h.upsample_initial_channel // (2 ** (i + 1))
231 | self.ups.append(weight_norm(
232 | ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
233 | k, u, padding=(k - u) // 2)))
234 | if i + 1 < len(h.upsample_rates): #
235 | stride_f0 = int(np.prod(h.upsample_rates[i + 1:]))
236 | self.noise_convs.append(Conv1d(
237 | 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
238 | else:
239 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
240 | self.resblocks = nn.ModuleList()
241 | ch = h.upsample_initial_channel
242 | for i in range(len(self.ups)):
243 | ch //= 2
244 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
245 | self.resblocks.append(resblock(h, ch, k, d))
246 |
247 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
248 | self.ups.apply(init_weights)
249 | self.conv_post.apply(init_weights)
250 | self.upp = int(np.prod(h.upsample_rates))
251 |
252 | def forward(self, x, f0):
253 | har_source = self.m_source(f0, self.upp).transpose(1, 2)
254 | x = self.conv_pre(x)
255 | for i in range(self.num_upsamples):
256 | x = F.leaky_relu(x, LRELU_SLOPE)
257 | x = self.ups[i](x)
258 | x_source = self.noise_convs[i](har_source)
259 | x = x + x_source
260 | xs = None
261 | for j in range(self.num_kernels):
262 | if xs is None:
263 | xs = self.resblocks[i * self.num_kernels + j](x)
264 | else:
265 | xs += self.resblocks[i * self.num_kernels + j](x)
266 | x = xs / self.num_kernels
267 | x = F.leaky_relu(x)
268 | x = self.conv_post(x)
269 | x = torch.tanh(x)
270 |
271 | return x
272 |
273 | def remove_weight_norm(self):
274 | print('Removing weight norm...')
275 | for l in self.ups:
276 | remove_weight_norm(l)
277 | for l in self.resblocks:
278 | l.remove_weight_norm()
279 | remove_weight_norm(self.conv_pre)
280 | remove_weight_norm(self.conv_post)
281 |
282 |
283 | class DiscriminatorP(torch.nn.Module):
284 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
285 | super(DiscriminatorP, self).__init__()
286 | self.period = period
287 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
288 | self.convs = nn.ModuleList([
289 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
290 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
291 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
292 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
293 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
294 | ])
295 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
296 |
297 | def forward(self, x):
298 | fmap = []
299 |
300 | # 1d to 2d
301 | b, c, t = x.shape
302 | if t % self.period != 0: # pad first
303 | n_pad = self.period - (t % self.period)
304 | x = F.pad(x, (0, n_pad), "reflect")
305 | t = t + n_pad
306 | x = x.view(b, c, t // self.period, self.period)
307 |
308 | for l in self.convs:
309 | x = l(x)
310 | x = F.leaky_relu(x, LRELU_SLOPE)
311 | fmap.append(x)
312 | x = self.conv_post(x)
313 | fmap.append(x)
314 | x = torch.flatten(x, 1, -1)
315 |
316 | return x, fmap
317 |
318 |
319 | class MultiPeriodDiscriminator(torch.nn.Module):
320 | def __init__(self, periods=None):
321 | super(MultiPeriodDiscriminator, self).__init__()
322 | self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
323 | self.discriminators = nn.ModuleList()
324 | for period in self.periods:
325 | self.discriminators.append(DiscriminatorP(period))
326 |
327 | def forward(self, y, y_hat):
328 | y_d_rs = []
329 | y_d_gs = []
330 | fmap_rs = []
331 | fmap_gs = []
332 | for i, d in enumerate(self.discriminators):
333 | y_d_r, fmap_r = d(y)
334 | y_d_g, fmap_g = d(y_hat)
335 | y_d_rs.append(y_d_r)
336 | fmap_rs.append(fmap_r)
337 | y_d_gs.append(y_d_g)
338 | fmap_gs.append(fmap_g)
339 |
340 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
341 |
342 |
343 | class DiscriminatorS(torch.nn.Module):
344 | def __init__(self, use_spectral_norm=False):
345 | super(DiscriminatorS, self).__init__()
346 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
347 | self.convs = nn.ModuleList([
348 | norm_f(Conv1d(1, 128, 15, 1, padding=7)),
349 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
350 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
351 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
352 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
353 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
354 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
355 | ])
356 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
357 |
358 | def forward(self, x):
359 | fmap = []
360 | for l in self.convs:
361 | x = l(x)
362 | x = F.leaky_relu(x, LRELU_SLOPE)
363 | fmap.append(x)
364 | x = self.conv_post(x)
365 | fmap.append(x)
366 | x = torch.flatten(x, 1, -1)
367 |
368 | return x, fmap
369 |
370 |
371 | class MultiScaleDiscriminator(torch.nn.Module):
372 | def __init__(self):
373 | super(MultiScaleDiscriminator, self).__init__()
374 | self.discriminators = nn.ModuleList([
375 | DiscriminatorS(use_spectral_norm=True),
376 | DiscriminatorS(),
377 | DiscriminatorS(),
378 | ])
379 | self.meanpools = nn.ModuleList([
380 | AvgPool1d(4, 2, padding=2),
381 | AvgPool1d(4, 2, padding=2)
382 | ])
383 |
384 | def forward(self, y, y_hat):
385 | y_d_rs = []
386 | y_d_gs = []
387 | fmap_rs = []
388 | fmap_gs = []
389 | for i, d in enumerate(self.discriminators):
390 | if i != 0:
391 | y = self.meanpools[i - 1](y)
392 | y_hat = self.meanpools[i - 1](y_hat)
393 | y_d_r, fmap_r = d(y)
394 | y_d_g, fmap_g = d(y_hat)
395 | y_d_rs.append(y_d_r)
396 | fmap_rs.append(fmap_r)
397 | y_d_gs.append(y_d_g)
398 | fmap_gs.append(fmap_g)
399 |
400 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
401 |
402 |
403 | def feature_loss(fmap_r, fmap_g):
404 | loss = 0
405 | for dr, dg in zip(fmap_r, fmap_g):
406 | for rl, gl in zip(dr, dg):
407 | loss += torch.mean(torch.abs(rl - gl))
408 |
409 | return loss * 2
410 |
411 |
412 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
413 | loss = 0
414 | r_losses = []
415 | g_losses = []
416 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
417 | r_loss = torch.mean((1 - dr) ** 2)
418 | g_loss = torch.mean(dg ** 2)
419 | loss += (r_loss + g_loss)
420 | r_losses.append(r_loss.item())
421 | g_losses.append(g_loss.item())
422 |
423 | return loss, r_losses, g_losses
424 |
425 |
426 | def generator_loss(disc_outputs):
427 | loss = 0
428 | gen_losses = []
429 | for dg in disc_outputs:
430 | l = torch.mean((1 - dg) ** 2)
431 | gen_losses.append(l)
432 | loss += l
433 |
434 | return loss, gen_losses
--------------------------------------------------------------------------------
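The classes and helpers in nsf_hifigan/models.py follow the standard HiFi-GAN adversarial recipe. The following is a minimal, illustrative sketch of how MultiPeriodDiscriminator, MultiScaleDiscriminator and the three loss helpers are typically combined in a GAN training step; it is not this repository's own training code, and it assumes the repository root is on the Python path.

from nsf_hifigan.models import (MultiPeriodDiscriminator, MultiScaleDiscriminator,
                                discriminator_loss, generator_loss, feature_loss)

mpd = MultiPeriodDiscriminator()
msd = MultiScaleDiscriminator()

def discriminator_losses(y, y_g_hat):
    # y: real waveform (B, 1, T); y_g_hat: generated waveform, detached so the
    # generator is not updated by the discriminator step
    y_df_r, y_df_g, _, _ = mpd(y, y_g_hat.detach())
    y_ds_r, y_ds_g, _, _ = msd(y, y_g_hat.detach())
    loss_f, _, _ = discriminator_loss(y_df_r, y_df_g)
    loss_s, _, _ = discriminator_loss(y_ds_r, y_ds_g)
    return loss_f + loss_s

def generator_losses(y, y_g_hat):
    # adversarial + feature-matching terms (a mel reconstruction loss is usually added)
    _, y_df_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
    _, y_ds_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
    loss_fm = feature_loss(fmap_f_r, fmap_f_g) + feature_loss(fmap_s_r, fmap_s_g)
    loss_adv_f, _ = generator_loss(y_df_g)
    loss_adv_s, _ = generator_loss(y_ds_g)
    return loss_adv_f + loss_adv_s + loss_fm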
/nsf_hifigan/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import matplotlib
4 | import torch
5 | from torch.nn.utils import weight_norm
6 | matplotlib.use("Agg")
7 | import matplotlib.pylab as plt
8 |
9 |
10 | def plot_spectrogram(spectrogram):
11 | fig, ax = plt.subplots(figsize=(10, 2))
12 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
13 | interpolation='none')
14 | plt.colorbar(im, ax=ax)
15 |
16 | fig.canvas.draw()
17 | plt.close()
18 |
19 | return fig
20 |
21 |
22 | def init_weights(m, mean=0.0, std=0.01):
23 | classname = m.__class__.__name__
24 | if classname.find("Conv") != -1:
25 | m.weight.data.normal_(mean, std)
26 |
27 |
28 | def apply_weight_norm(m):
29 | classname = m.__class__.__name__
30 | if classname.find("Conv") != -1:
31 | weight_norm(m)
32 |
33 |
34 | def get_padding(kernel_size, dilation=1):
35 | return int((kernel_size*dilation - dilation)/2)
36 |
37 |
38 | def load_checkpoint(filepath, device):
39 | assert os.path.isfile(filepath)
40 | print("Loading '{}'".format(filepath))
41 | checkpoint_dict = torch.load(filepath, map_location=device)
42 | print("Complete.")
43 | return checkpoint_dict
44 |
45 |
46 | def save_checkpoint(filepath, obj):
47 | print("Saving checkpoint to {}".format(filepath))
48 | torch.save(obj, filepath)
49 | print("Complete.")
50 |
51 |
52 | def del_old_checkpoints(cp_dir, prefix, n_models=2):
53 | pattern = os.path.join(cp_dir, prefix + '????????')
54 | cp_list = glob.glob(pattern) # get checkpoint paths
55 | cp_list = sorted(cp_list)  # sort by iteration
56 | if len(cp_list) > n_models:  # if more than n_models checkpoints are found
57 | for cp in cp_list[:-n_models]:  # delete the oldest checkpoints, keeping the latest n_models
58 | open(cp, 'w').close()  # empty the file contents first
59 | os.unlink(cp)  # then delete the file (moved to trash when using Colab)
60 |
61 |
62 | def scan_checkpoint(cp_dir, prefix):
63 | pattern = os.path.join(cp_dir, prefix + '????????')
64 | cp_list = glob.glob(pattern)
65 | if len(cp_list) == 0:
66 | return None
67 | return sorted(cp_list)[-1]
68 |
--------------------------------------------------------------------------------
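The checkpoint helpers in nsf_hifigan/utils.py look for files named as a prefix followed by eight characters (e.g. a step number) and pick the newest one. A minimal sketch of resuming from the latest checkpoint; the directory and the "g_" prefix here are placeholders, not paths this repository is known to use.

import torch
from nsf_hifigan.utils import scan_checkpoint, load_checkpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cp_g = scan_checkpoint("path/to/checkpoints", "g_")   # newest matching path, or None
state = load_checkpoint(cp_g, device) if cp_g else None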
/output/output_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/output/output_here
--------------------------------------------------------------------------------
/parametrizations.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto
2 |
3 | import torch
4 | from torch import Tensor
5 | import parametrize
6 | from torch.nn.modules import Module
7 | import torch.nn.functional as F
8 |
9 | from typing import Optional
10 |
11 | __all__ = ['orthogonal', 'spectral_norm']
12 |
13 |
14 | def _is_orthogonal(Q, eps=None):
15 | n, k = Q.size(-2), Q.size(-1)
16 | Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
17 | # A reasonable eps, but not too large
18 | eps = 10. * n * torch.finfo(Q.dtype).eps
19 | return torch.allclose(Q.mH @ Q, Id, atol=eps)
20 |
21 |
22 | def _make_orthogonal(A):
23 | """ Assume that A is a tall matrix.
24 | Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
25 | """
26 | X, tau = torch.geqrf(A)
27 | Q = torch.linalg.householder_product(X, tau)
28 | # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
29 | Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
30 | return Q
31 |
32 |
33 | class _OrthMaps(Enum):
34 | matrix_exp = auto()
35 | cayley = auto()
36 | householder = auto()
37 |
38 |
39 | class _Orthogonal(Module):
40 | base: Tensor
41 |
42 | def __init__(self,
43 | weight,
44 | orthogonal_map: _OrthMaps,
45 | *,
46 | use_trivialization=True) -> None:
47 | super().__init__()
48 |
49 | # Note [Householder complex]
50 | # For complex tensors, it is not possible to compute the tensor `tau` necessary for
51 | # linalg.householder_product from the reflectors.
52 | # To see this, note that the reflectors have a shape like:
53 | # 0 0 0
54 | # * 0 0
55 | # * * 0
56 | # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
57 | # to parametrize the unitary matrices. Saving tau on its own does not work either, because
58 | # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
59 | # them as independent tensors we would not maintain the constraint
60 | # An equivalent reasoning holds for rectangular matrices
61 | if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
62 | raise ValueError("The householder parametrization does not support complex tensors.")
63 |
64 | self.shape = weight.shape
65 | self.orthogonal_map = orthogonal_map
66 | if use_trivialization:
67 | self.register_buffer("base", None)
68 |
69 | def forward(self, X: torch.Tensor) -> torch.Tensor:
70 | n, k = X.size(-2), X.size(-1)
71 | transposed = n < k
72 | if transposed:
73 | X = X.mT
74 | n, k = k, n
75 | # Here n > k and X is a tall matrix
76 | if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
77 | # We just need n x k - k(k-1)/2 parameters
78 | X = X.tril()
79 | if n != k:
80 | # Embed into a square matrix
81 | X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
82 | A = X - X.mH
83 | # A is skew-symmetric (or skew-hermitian)
84 | if self.orthogonal_map == _OrthMaps.matrix_exp:
85 | Q = torch.matrix_exp(A)
86 | elif self.orthogonal_map == _OrthMaps.cayley:
87 | # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
88 | Id = torch.eye(n, dtype=A.dtype, device=A.device)
89 | Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
90 | # Q is now orthogonal (or unitary) of size (..., n, n)
91 | if n != k:
92 | Q = Q[..., :k]
93 | # Q is now the size of the X (albeit perhaps transposed)
94 | else:
95 | # X is real here, as we do not support householder with complex numbers
96 | A = X.tril(diagonal=-1)
97 | tau = 2. / (1. + (A * A).sum(dim=-2))
98 | Q = torch.linalg.householder_product(A, tau)
99 | # The diagonal of X is 1's and -1's
100 | # We do not want to differentiate through this or update the diagonal of X hence the casting
101 | Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
102 |
103 | if hasattr(self, "base"):
104 | Q = self.base @ Q
105 | if transposed:
106 | Q = Q.mT
107 | return Q
108 |
109 | @torch.autograd.no_grad()
110 | def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
111 | if Q.shape != self.shape:
112 | raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
113 | f"Got a tensor of shape {Q.shape}.")
114 |
115 | Q_init = Q
116 | n, k = Q.size(-2), Q.size(-1)
117 | transpose = n < k
118 | if transpose:
119 | Q = Q.mT
120 | n, k = k, n
121 |
122 | # We make sure to always copy Q in every path
123 | if not hasattr(self, "base"):
124 | # Note [right_inverse expm cayley]
125 | # If we do not have use_trivialization=True, we just implement the inverse of the forward
126 | # map for the Householder. To see why, think that for the Cayley map,
127 | # we would need to find the matrix X \in R^{n x k} such that:
128 | # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
129 | # A = Y - Y.mH
130 | # cayley(A)[:, :k]
131 | # gives the original tensor. It is not clear how to do this.
132 | # Perhaps via some algebraic manipulation involving the QR like that of
133 | # Corollary 2.2 in Edelman, Arias and Smith?
134 | if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
135 | raise NotImplementedError("It is not possible to assign to the matrix exponential "
136 | "or the Cayley parametrizations when use_trivialization=False.")
137 |
138 | # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
139 | # Here Q is always real because we do not support householder and complex matrices.
140 | # See note [Householder complex]
141 | A, tau = torch.geqrf(Q)
142 | # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
143 | # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
144 | # The diagonal of Q is the diagonal of R from the qr decomposition
145 | A.diagonal(dim1=-2, dim2=-1).sign_()
146 | # Equality with zero is ok because LAPACK returns exactly zero when it does not want
147 | # to use a particular reflection
148 | A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
149 | return A.mT if transpose else A
150 | else:
151 | if n == k:
152 | # We check whether Q is orthogonal
153 | if not _is_orthogonal(Q):
154 | Q = _make_orthogonal(Q)
155 | else: # Is orthogonal
156 | Q = Q.clone()
157 | else:
158 | # Complete Q into a full n x n orthogonal matrix
159 | N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
160 | Q = torch.cat([Q, N], dim=-1)
161 | Q = _make_orthogonal(Q)
162 | self.base = Q
163 |
164 | # It is necessary to return the -Id, as we use the diagonal for the
165 | # Householder parametrization. Using -Id makes:
166 | # householder(torch.zeros(m,n)) == torch.eye(m,n)
167 | # Poor man's version of eye_like
168 | neg_Id = torch.zeros_like(Q_init)
169 | neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
170 | return neg_Id
171 |
172 |
173 | def orthogonal(module: Module,
174 | name: str = 'weight',
175 | orthogonal_map: Optional[str] = None,
176 | *,
177 | use_trivialization: bool = True) -> Module:
178 | r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
179 |
180 | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
181 | matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as
182 |
183 | .. math::
184 |
185 | \begin{align*}
186 | Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
187 | QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
188 | \end{align*}
189 |
190 | where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
191 | and the transpose when :math:`Q` is real-valued, and
192 | :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
193 | In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
194 | and orthonormal rows otherwise.
195 |
196 | If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
197 |
198 | The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor:
199 |
200 | - ``"matrix_exp"``/``"cayley"``:
201 | the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
202 | :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
203 | :math:`A` to give an orthogonal matrix.
204 | - ``"householder"``: computes a product of Householder reflectors
205 | (:func:`~torch.linalg.householder_product`).
206 |
207 | ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
208 | ``"householder"``, but they are slower to compute for very thin or very wide matrices.
209 |
210 | If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
211 | where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
212 | ``module.parametrizations.weight[0].base``. This helps the
213 | convergence of the parametrized layer at the expense of some extra memory use.
214 | See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
215 |
216 | Initial value of :math:`Q`:
217 | If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
218 | of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
219 | and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
220 | Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
221 | Otherwise, the initial value is the result of the composition of all the registered
222 | parametrizations applied to the original tensor.
223 |
224 | .. note::
225 | This function is implemented using the parametrization functionality
226 | in :func:`~torch.nn.utils.parametrize.register_parametrization`.
227 |
228 |
229 | .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
230 | .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
231 |
232 | Args:
233 | module (nn.Module): module on which to register the parametrization.
234 | name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
235 | orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
236 | Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
237 | use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
238 | Default: ``True``.
239 |
240 | Returns:
241 | The original module with an orthogonal parametrization registered to the specified
242 | weight
243 |
244 | Example::
245 |
246 | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
247 | >>> orth_linear = orthogonal(nn.Linear(20, 40))
248 | >>> orth_linear
249 | ParametrizedLinear(
250 | in_features=20, out_features=40, bias=True
251 | (parametrizations): ModuleDict(
252 | (weight): ParametrizationList(
253 | (0): _Orthogonal()
254 | )
255 | )
256 | )
257 | >>> # xdoctest: +IGNORE_WANT
258 | >>> Q = orth_linear.weight
259 | >>> torch.dist(Q.T @ Q, torch.eye(20))
260 | tensor(4.9332e-07)
261 | """
262 | weight = getattr(module, name, None)
263 | if not isinstance(weight, Tensor):
264 | raise ValueError(
265 | "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
266 | )
267 |
268 | # We could implement this for 1-dim tensors as the maps on the sphere
269 | # but I believe it'd bite more people than it'd help
270 | if weight.ndim < 2:
271 | raise ValueError("Expected a matrix or batch of matrices. "
272 | f"Got a tensor of {weight.ndim} dimensions.")
273 |
274 | if orthogonal_map is None:
275 | orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"
276 |
277 | orth_enum = getattr(_OrthMaps, orthogonal_map, None)
278 | if orth_enum is None:
279 | raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
280 | f'Got: {orthogonal_map}')
281 | orth = _Orthogonal(weight,
282 | orth_enum,
283 | use_trivialization=use_trivialization)
284 | parametrize.register_parametrization(module, name, orth, unsafe=True)
285 | return module
286 |
287 |
288 | class _WeightNorm(Module):
289 | def __init__(
290 | self,
291 | dim: int = 0,
292 | ) -> None:
293 | super().__init__()
294 | if dim is None:
295 | dim = -1
296 | self.dim = dim
297 |
298 | def forward(self, weight_g, weight_v):
299 | return torch._weight_norm(weight_v, weight_g, self.dim)
300 |
301 | def right_inverse(self, weight):
302 | # TODO: is the .data necessary?
303 | weight_g = torch.norm_except_dim(weight, 2, self.dim).data
304 | weight_v = weight.data
305 |
306 | return weight_g, weight_v
307 |
308 |
309 | def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
310 | r"""Applies weight normalization to a parameter in the given module.
311 |
312 | .. math::
313 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
314 |
315 | Weight normalization is a reparameterization that decouples the magnitude
316 | of a weight tensor from its direction. This replaces the parameter specified
317 | by :attr:`name` with two parameters: one specifying the magnitude
318 | and one specifying the direction.
319 |
320 | By default, with ``dim=0``, the norm is computed independently per output
321 | channel/plane. To compute a norm over the entire weight tensor, use
322 | ``dim=None``.
323 |
324 | See https://arxiv.org/abs/1602.07868
325 |
326 | Args:
327 | module (Module): containing module
328 | name (str, optional): name of weight parameter
329 | dim (int, optional): dimension over which to compute the norm
330 |
331 | Returns:
332 | The original module with the weight norm hook
333 |
334 | Example::
335 |
336 | >>> m = weight_norm(nn.Linear(20, 40), name='weight')
337 | >>> m
338 | Linear(in_features=20, out_features=40, bias=True)
339 | >>> m.parametrizations.weight.original0.size()
340 | torch.Size([40, 1])
341 | >>> m.parametrizations.weight.original1.size()
342 | torch.Size([40, 20])
343 |
344 | """
345 | weight = getattr(module, name, None)
346 | if not isinstance(weight, Tensor):
347 | raise ValueError(
348 | "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
349 | )
350 |
351 | _weight_norm = _WeightNorm(dim)
352 | parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)
353 |
354 | def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
355 | g_key = f"{prefix}{name}_g"
356 | v_key = f"{prefix}{name}_v"
357 | if g_key in state_dict and v_key in state_dict:
358 | original0 = state_dict.pop(g_key)
359 | original1 = state_dict.pop(v_key)
360 | state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
361 | state_dict[f"{prefix}parametrizations.{name}.original1"] = original1
362 | module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
363 | return module
364 |
365 |
366 | class _SpectralNorm(Module):
367 | def __init__(
368 | self,
369 | weight: torch.Tensor,
370 | n_power_iterations: int = 1,
371 | dim: int = 0,
372 | eps: float = 1e-12
373 | ) -> None:
374 | super().__init__()
375 | ndim = weight.ndim
376 | if dim >= ndim or dim < -ndim:
377 | raise IndexError("Dimension out of range (expected to be in range of "
378 | f"[-{ndim}, {ndim - 1}] but got {dim})")
379 |
380 | if n_power_iterations <= 0:
381 | raise ValueError('Expected n_power_iterations to be positive, but '
382 | 'got n_power_iterations={}'.format(n_power_iterations))
383 | self.dim = dim if dim >= 0 else dim + ndim
384 | self.eps = eps
385 | if ndim > 1:
386 | # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
387 | self.n_power_iterations = n_power_iterations
388 | weight_mat = self._reshape_weight_to_matrix(weight)
389 | h, w = weight_mat.size()
390 |
391 | u = weight_mat.new_empty(h).normal_(0, 1)
392 | v = weight_mat.new_empty(w).normal_(0, 1)
393 | self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
394 | self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))
395 |
396 | # Start with u, v initialized to some reasonable values by performing a number
397 | # of iterations of the power method
398 | self._power_method(weight_mat, 15)
399 |
400 | def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
401 | # Precondition
402 | assert weight.ndim > 1
403 |
404 | if self.dim != 0:
405 | # permute dim to front
406 | weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))
407 |
408 | return weight.flatten(1)
409 |
410 | @torch.autograd.no_grad()
411 | def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
412 | # See original note at torch/nn/utils/spectral_norm.py
413 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
414 | # updated in power iteration **in-place**. This is very important
415 | # because in `DataParallel` forward, the vectors (being buffers) are
416 | # broadcast from the parallelized module to each module replica,
417 | # which is a new module object created on the fly. And each replica
418 | # runs its own spectral norm power iteration. So simply assigning
419 | # the updated vectors to the module this function runs on will cause
420 | # the update to be lost forever. And the next time the parallelized
421 | # module is replicated, the same randomly initialized vectors are
422 | # broadcast and used!
423 | #
424 | # Therefore, to make the change propagate back, we rely on two
425 | # important behaviors (also enforced via tests):
426 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor
427 | # is already on correct device; and it makes sure that the
428 | # parallelized module is already on `device[0]`.
429 | # 2. If the out tensor in `out=` kwarg has correct shape, it will
430 | # just fill in the values.
431 | # Therefore, since the same power iteration is performed on all
432 | # devices, simply updating the tensors in-place will make sure that
433 | # the module replica on `device[0]` will update the _u vector on the
434 | # parallelized module (by shared storage).
435 | #
436 | # However, after we update `u` and `v` in-place, we need to **clone**
437 | # them before using them to normalize the weight. This is to support
438 | # backproping through two forward passes, e.g., the common pattern in
439 | # GAN training: loss = D(real) - D(fake). Otherwise, engine will
440 | # complain that variables needed to do backward for the first forward
441 | # (i.e., the `u` and `v` vectors) are changed in the second forward.
442 |
443 | # Precondition
444 | assert weight_mat.ndim > 1
445 |
446 | for _ in range(n_power_iterations):
447 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
448 | # are the first left and right singular vectors.
449 | # This power iteration produces approximations of `u` and `v`.
450 | self._u = F.normalize(torch.mv(weight_mat, self._v), # type: ignore[has-type]
451 | dim=0, eps=self.eps, out=self._u) # type: ignore[has-type]
452 | self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
453 | dim=0, eps=self.eps, out=self._v) # type: ignore[has-type]
454 |
455 | def forward(self, weight: torch.Tensor) -> torch.Tensor:
456 | if weight.ndim == 1:
457 | # Faster and more exact path, no need to approximate anything
458 | return F.normalize(weight, dim=0, eps=self.eps)
459 | else:
460 | weight_mat = self._reshape_weight_to_matrix(weight)
461 | if self.training:
462 | self._power_method(weight_mat, self.n_power_iterations)
463 | # See above on why we need to clone
464 | u = self._u.clone(memory_format=torch.contiguous_format)
465 | v = self._v.clone(memory_format=torch.contiguous_format)
466 | # The proper way of computing this should be through F.bilinear, but
467 | # it seems to have some efficiency issues:
468 | # https://github.com/pytorch/pytorch/issues/58093
469 | sigma = torch.dot(u, torch.mv(weight_mat, v))
470 | return weight / sigma
471 |
472 | def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
473 | # we may want to assert here that the passed value already
474 | # satisfies constraints
475 | return value
476 |
477 |
478 | def spectral_norm(module: Module,
479 | name: str = 'weight',
480 | n_power_iterations: int = 1,
481 | eps: float = 1e-12,
482 | dim: Optional[int] = None) -> Module:
483 | r"""Applies spectral normalization to a parameter in the given module.
484 |
485 | .. math::
486 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
487 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
488 |
489 | When applied on a vector, it simplifies to
490 |
491 | .. math::
492 | \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}
493 |
494 | Spectral normalization stabilizes the training of discriminators (critics)
495 | in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
496 | of the model. :math:`\sigma` is approximated performing one iteration of the
497 | `power method`_ every time the weight is accessed. If the dimension of the
498 | weight tensor is greater than 2, it is reshaped to 2D in power iteration
499 | method to get spectral norm.
500 |
501 |
502 | See `Spectral Normalization for Generative Adversarial Networks`_ .
503 |
504 | .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
505 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
506 |
507 | .. note::
508 | This function is implemented using the parametrization functionality
509 | in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
510 | reimplementation of :func:`torch.nn.utils.spectral_norm`.
511 |
512 | .. note::
513 | When this constraint is registered, the singular vectors associated to the largest
514 | singular value are estimated rather than sampled at random. These are then updated
515 | performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
516 | is accessed with the module on `training` mode.
517 |
518 | .. note::
519 | If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`,
520 | is in training mode on removal, it will perform another power iteration.
521 | If you'd like to avoid this iteration, set the module to eval mode
522 | before its removal.
523 |
524 | Args:
525 | module (nn.Module): containing module
526 | name (str, optional): name of weight parameter. Default: ``"weight"``.
527 | n_power_iterations (int, optional): number of power iterations to
528 | calculate spectral norm. Default: ``1``.
529 | eps (float, optional): epsilon for numerical stability in
530 | calculating norms. Default: ``1e-12``.
531 | dim (int, optional): dimension corresponding to number of outputs.
532 | Default: ``0``, except for modules that are instances of
533 | ConvTranspose{1,2,3}d, when it is ``1``
534 |
535 | Returns:
536 | The original module with a new parametrization registered to the specified
537 | weight
538 |
539 | Example::
540 |
541 | >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
542 | >>> # xdoctest: +IGNORE_WANT("non-deterministic")
543 | >>> snm = spectral_norm(nn.Linear(20, 40))
544 | >>> snm
545 | ParametrizedLinear(
546 | in_features=20, out_features=40, bias=True
547 | (parametrizations): ModuleDict(
548 | (weight): ParametrizationList(
549 | (0): _SpectralNorm()
550 | )
551 | )
552 | )
553 | >>> torch.linalg.matrix_norm(snm.weight, 2)
554 | tensor(1.0081, grad_fn=<AmaxBackward0>)
555 | """
556 | weight = getattr(module, name, None)
557 | if not isinstance(weight, Tensor):
558 | raise ValueError(
559 | "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
560 | )
561 |
562 | if dim is None:
563 | if isinstance(module, (torch.nn.ConvTranspose1d,
564 | torch.nn.ConvTranspose2d,
565 | torch.nn.ConvTranspose3d)):
566 | dim = 1
567 | else:
568 | dim = 0
569 | parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
570 | return module
571 |
--------------------------------------------------------------------------------
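parametrizations.py mirrors torch.nn.utils.parametrizations but registers the constraints through the local parametrize module. A minimal sketch of applying it to convolution layers, assuming the project root (where parametrize.py lives) is on the path; the layer shapes are arbitrary.

import torch.nn as nn
from parametrizations import weight_norm, spectral_norm

conv = weight_norm(nn.Conv1d(128, 256, 7, padding=3))    # weight = g * v / ||v||, per output channel
disc = spectral_norm(nn.Conv2d(1, 32, (5, 1), (3, 1)))   # weight / sigma(weight), one power iteration per access

w = conv.weight   # recomputed on access from parametrizations.weight.original0 / original1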
/preprocess.py:
--------------------------------------------------------------------------------
1 | import math
2 | import multiprocessing
3 | import os
4 | import argparse
5 | from random import shuffle
6 | import torchaudio
7 | import torchaudio.transforms as T
8 |
9 | import torch
10 | from glob import glob
11 | from tqdm import tqdm
12 |
13 | # from audiolm_pytorch import SoundStream, EncodecWrapper
14 | import utils
15 | import logging
16 |
17 | logging.getLogger("numba").setLevel(logging.WARNING)
18 | import librosa
19 | import numpy as np
20 |
21 | hps = utils.get_hparams_from_file("config.json")
22 | sampling_rate = hps.data.sampling_rate
23 | hop_length = hps.data.hop_length
24 | in_dir = ""
25 |
26 | def process_one(filename, hmodel):
27 | wav, sr = torchaudio.load(filename)
28 | if wav.shape[0] > 1: # mix to mono
29 | wav = wav.mean(dim=0, keepdim=True)
30 | wav16k = T.Resample(sr, 16000)(wav)
31 | wav24k = T.Resample(sr, 24000)(wav)
32 | filename = filename.replace(in_dir, in_dir+"_processed").replace('.mp3','.wav').replace('.flac','.wav')
33 | wav24k_path = filename
34 | if not os.path.exists(os.path.dirname(wav24k_path)):
35 | os.makedirs(os.path.dirname(wav24k_path))
36 | torchaudio.save(wav24k_path, wav24k, 24000)
37 | soft_path = filename + ".soft.pt"
38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39 | wav16k = wav16k.to(device)
40 | c = utils.get_hubert_content(hmodel, wav_16k_tensor=wav16k[0])
41 | torch.save(c.cpu(), soft_path)
42 |
43 | f0_path = filename + ".f0.npy"
44 | f0 = utils.compute_f0_dio(
45 | wav24k.cpu().numpy()[0], sampling_rate=24000, hop_length=hop_length
46 | )
47 | np.save(f0_path, f0)
48 |
49 | spec_path = filename.replace(".wav", ".spec.pt")
50 | spec_process = torchaudio.transforms.MelSpectrogram(
51 | sample_rate=24000,
52 | n_fft=1024,
53 | hop_length=256,
54 | n_mels=100,
55 | center=True,
56 | power=1,
57 | )
58 | spec = spec_process(wav24k)# 1 100 T
59 | spec = torch.log(torch.clip(spec, min=1e-7))
60 | torch.save(spec, spec_path)
61 |
62 |
63 | def process_batch(filenames):
64 | print("Loading hubert for content...")
65 | device = "cuda" if torch.cuda.is_available() else "cpu"
66 | hmodel = utils.get_hubert_model().to(device)
67 | # codec = EncodecWrapper()
68 | print("Loaded hubert.")
69 | for filename in tqdm(filenames):
70 | process_one(filename, hmodel)
71 |
72 |
73 | if __name__ == "__main__":
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument(
76 | "--in_dir", type=str, default="dataset", help="path to input dir"
77 | )
78 |
79 | args = parser.parse_args()
80 | filenames = glob(f"{args.in_dir}/**/*.wav", recursive=True)+glob(f"{args.in_dir}/**/*.flac", recursive=True) # [:10]
81 | in_dir = args.in_dir
82 | shuffle(filenames)
83 | process_batch(filenames)
84 |
--------------------------------------------------------------------------------
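Running `python preprocess.py --in_dir dataset` writes a 24 kHz copy of every wav/flac under `dataset_processed/`, plus contentvec features (`*.wav.soft.pt`), DIO F0 (`*.wav.f0.npy`), and a log-mel spectrogram (`*.spec.pt`) next to it. A minimal sketch of loading these per-utterance artifacts back; the example path is illustrative.

import numpy as np
import torch
import torchaudio

base = "dataset_processed/1/1.wav"                    # illustrative path
wav, sr = torchaudio.load(base)                       # 24 kHz mono waveform
c = torch.load(base + ".soft.pt")                     # hubert/contentvec content features
f0 = np.load(base + ".f0.npy")                        # frame-level F0 from compute_f0_dio
spec = torch.load(base.replace(".wav", ".spec.pt"))   # log-mel spectrogram, shape (1, 100, T)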
/raw/input_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adelacvg/NS2VC/f71eae076aee9ab3f61c08fe1a41ce700f15567f/raw/input_here
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import json
5 |
6 | import torchaudio
7 | # from model import NaturalSpeech2, F0Predictor, Diffusion_Encoder, encode
8 | from dataset import NS2VCDataset, TextAudioCollate
9 | from torch.utils.data import Dataset, DataLoader
10 | from multiprocessing import cpu_count
11 | import torchaudio.transforms as T
12 | # from model import rvq_ce_loss
13 |
14 |
15 |
16 | # if __name__ == '__main__':
17 | # cfg = json.load(open('config.json'))
18 |
19 | # collate_fn = TextAudioCollate()
20 | # codec = EncodecWrapper()
21 | # ds = NS2VCDataset(cfg, codec)
22 | # dl = DataLoader(ds, batch_size = cfg['train']['train_batch_size'], shuffle = True, pin_memory = True, num_workers = 0, collate_fn = collate_fn)
23 | # # c_padded, refer_padded, f0_padded, codes_padded, wav_padded, lengths, refer_lengths, uv_padded = next(iter(dl))
24 | # data = next(iter(dl))
25 | # model = NaturalSpeech2(cfg)
26 | # out = model(data, codec)
27 |
28 | # print(c_padded.shape, refer_padded.shape, f0_padded.shape, codes_padded.shape, wav_padded.shape, lengths.shape, refer_lengths.shape, uv_padded.shape)
29 | # torch.Size([8, 256, 276]) torch.Size([8, 128, 276]) torch.Size([8, 276]) torch.Size([8, 128, 276]) torch.Size([8, 1, 88320]) torch.Size([8]) torch.Size([8]) torch.Size([8, 276])
30 |
31 | # out.backward()
32 |
33 | # c_padded, refer_padded, f0_padded, codes_padded, wav_padded, lengths, refer_lengths, uv_padded = next(iter(dl))
34 | # # c_padded refer_padded
35 | # c = c_padded
36 | # refer = refer_padded
37 | # f0 = f0_padded
38 | # uv = uv_padded
39 | # codec = EncodecWrapper()
40 | # with torch.no_grad():
41 | # batches = num_to_groups(1, 1)
42 | # all_samples_list = list(map(lambda n: model.sample(c, refer, f0, uv, codec, batch_size=n), batches))
43 | # all_samples = torch.cat(all_samples_list, dim = 0)
44 | # torchaudio.save(f'sample.wav', all_samples, 24000)
45 | # print(lengths)
46 | # print(refer_lengths)
47 |
48 |
49 |
50 | # phoneme_encoder = TextEncoder(**cfg['phoneme_encoder'])
51 | # f0_predictor = F0Predictor(**cfg['f0_predictor'])
52 | # prompt_encoder = TextEncoder(**cfg['prompt_encoder'])
53 | # diff_model = Diffusion_Encoder(**cfg['diffusion_encoder'])
54 | # audio_prompt = torch.randn(3, 256, 80)
55 | # contentvec = torch.randn(3, 256, 200)
56 | # f0 = torch.randint(1,100,(3, 200))
57 | # noised_audio = torch.randn(3, 512, 200)
58 | # times = torch.randn(3)
59 | # audio_prompt_length = torch.tensor([3, 4, 5])
60 | # contentvec_length = torch.tensor([3, 4, 5])
61 | # #ok
62 | # audio_prompt = prompt_encoder(audio_prompt,audio_prompt_length)
63 | # #ok
64 | # f0_pred = f0_predictor(contentvec, audio_prompt, contentvec_length, audio_prompt_length)
65 | # #ok
66 | # content = phoneme_encoder(contentvec, contentvec_length,f0)
67 | # #ok
68 | # pred = diff_model(
69 | # noised_audio,
70 | # content, audio_prompt,
71 | # contentvec_length, audio_prompt_length,
72 | # times)
73 |
74 | # print(codes.shape)#24k 1 128 T2+1
75 |
76 |
77 |
78 | #reconstruction
79 | # codec = EncodecWrapper()
80 | # audio, sr = torchaudio.load('dataset/1.wav')
81 | # audio24k = T.Resample(sr, 24000)(audio)
82 | # torchaudio.save('1_24k.wav', audio24k, 24000)
83 |
84 | # codec.eval()
85 | # codes, _, _ = codec(audio24k, return_encoded = True)
86 | # audio = codec.decode(codes).squeeze(0)
87 | # torchaudio.save('1.wav', audio.detach(), 24000)
88 |
89 | # codec = EncodecWrapper()
90 | # gt = torch.randn(4, 128, 276)
91 | # pred = torch.randn(4, 128, 276)
92 | # _, indices, _, quantized_list = encode(gt,8,codec)
93 | # n_q=8
94 | # loss = rvq_ce_loss(gt.unsqueeze(0)-quantized_list, indices, codec, n_q)
95 | # print(loss)
96 | # loss = rvq_ce_loss(pred.unsqueeze(0)-quantized_list, indices, codec, n_q)
97 | # print(loss)
98 | # wav,sr = torchaudio.load('/home/hyc/val_dataset/common_voice_zh-CN_37110506.mp3')
99 | # wav24k = T.Resample(sr, 24000)(wav)
100 | # spec_process = torchaudio.transforms.MelSpectrogram(
101 | # sample_rate=24000,
102 | # n_fft=1024,
103 | # hop_length=256,
104 | # n_mels=100,
105 | # center=True,
106 | # power=1,
107 | # )
108 | # spec = spec_process(wav24k)# 1 100 T
109 | # spec = torch.log(torch.clip(spec, min=1e-7))
110 | # print(spec)
111 | # print(spec.shape)
112 |
113 | # prosody_process = torchaudio.transforms.MelSpectrogram(
114 | # sample_rate=24000,
115 | # n_fft=8192,
116 | # hop_length=4096,
117 | # n_mels=400,
118 | # center=True,
119 | # power=1,
120 | # )
121 | # prosody = prosody_process(wav24k)# 1 400 T
122 | # prosody = torch.log(torch.clip(prosody, min=1e-7))
123 | # prosody = torch.repeat_interleave(prosody, 16, dim=2)
124 | # prosody[:,:,16:] = (prosody[:,:,16:] + prosody[:,:,:-16]) / 2
125 | # print(prosody)
126 | # print(prosody.shape)
127 |
128 | import diffusers
129 | from diffusers import UNet1DModel,UNet2DConditionModel
130 | from model import NaturalSpeech2
131 |
132 | from unet1d import UNet1DConditionModel
133 |
134 | # a = torch.randn(4, 20, 10)
135 | # lengths = torch.tensor([10, 9, 8, 7])
136 | # print(torch.arange(10))
137 | # print(torch.arange(10).expand(4, 20, 10))
138 | # mask = torch.arange(10).expand(4, 20, 10) >= lengths.unsqueeze(1).unsqueeze(1)
139 | # a = a.masked_fill(mask,0)
140 | # print(a)
141 |
142 | # unet2d = UNet2DConditionModel(
143 | # block_out_channels=(1,2,4,4),
144 | # norm_num_groups=1,
145 | # cross_attention_dim=16,
146 | # attention_head_dim=1,
147 | # )
148 | # in_img = torch.randn(1,4,16,16)
149 | # cond = torch.randn(1,4,16)
150 | # out = unet2d(in_img, 3, cond)
151 | # print(out.sample.shape)
152 |
153 | # unet1d = UNet1DConditionModel(
154 | # in_channels=1,
155 | # out_channels=1,
156 | # block_out_channels=(4,8,8,8),
157 | # norm_num_groups=2,
158 | # cross_attention_dim=16,
159 | # attention_head_dim=2,
160 | # )
161 | # audio = torch.randn(1,1,17)
162 | # cond = torch.randn(1,20,16)
163 | # out = unet1d(audio, 3, cond)
164 | # print(out.sample.shape)
165 | from nsf_hifigan.models import load_model
166 | import utils
167 | wav, sr = torchaudio.load('raw/test1.wav')
168 | wav = T.Resample(sr, 44100)(wav)
169 | spec_process = torchaudio.transforms.MelSpectrogram(
170 | sample_rate=44100,
171 | n_fft=2048,
172 | hop_length=512,
173 | n_mels=128,
174 | center=True,
175 | power=1,
176 | )
177 |
178 | f0 = utils.compute_f0_dio(
179 | wav.cpu().numpy()[0], sampling_rate=44100, hop_length=512
180 | )
181 | f0 = torch.Tensor(f0)
182 | mel = spec_process(wav)
183 | mel = torch.log(torch.clip(mel, min=1e-7))
184 | device = 'cuda'
185 | vocoder = load_model('nsf_hifigan/model',device=device)[0]
186 | mel = mel.to(device)
187 | f0 = f0.to(device)
188 | length = min(mel.shape[2],f0.shape[0])
189 | mel = mel[:,:,:length]
190 | f0 = f0[:length]
191 | wav = vocoder(mel, f0).cpu().squeeze(0)
192 | torchaudio.save('recon.wav', wav, 44100)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from model import Trainer
2 |
3 | trainer = Trainer()
4 | trainer.load('logs/vc/2023-09-28-20-49-43/model-639.pt')
5 | trainer.train()
6 |
--------------------------------------------------------------------------------
/unet1d/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_1d_condition import UNet1DConditionModel
--------------------------------------------------------------------------------
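unet1d exposes UNet1DConditionModel. A minimal sketch of instantiating it, mirroring the toy configuration in the commented-out block of test.py above; the shapes are illustrative only.

import torch
from unet1d import UNet1DConditionModel

unet1d = UNet1DConditionModel(
    in_channels=1,
    out_channels=1,
    block_out_channels=(4, 8, 8, 8),
    norm_num_groups=2,
    cross_attention_dim=16,
    attention_head_dim=2,
)
audio = torch.randn(1, 1, 17)    # (batch, channels, length)
cond = torch.randn(1, 20, 16)    # (batch, tokens, cross_attention_dim)
out = unet1d(audio, 3, cond)     # timestep = 3
print(out.sample.shape)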
/unet1d/activations.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def get_activation(act_fn):
5 | if act_fn in ["swish", "silu"]:
6 | return nn.SiLU()
7 | elif act_fn == "mish":
8 | return nn.Mish()
9 | elif act_fn == "gelu":
10 | return nn.GELU()
11 | else:
12 | raise ValueError(f"Unsupported activation function: {act_fn}")
13 |
--------------------------------------------------------------------------------
/unet1d/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Dict, Optional
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from torch import nn
19 |
20 | from .activations import get_activation
21 | from .attention_processor import Attention
22 | from .embeddings import CombinedTimestepLabelEmbeddings
23 | from .lora import LoRACompatibleLinear
24 |
25 |
26 | class BasicTransformerBlock(nn.Module):
27 | r"""
28 | A basic Transformer block.
29 |
30 | Parameters:
31 | dim (`int`): The number of channels in the input and output.
32 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
33 | attention_head_dim (`int`): The number of channels in each head.
34 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
35 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
36 | only_cross_attention (`bool`, *optional*):
37 | Whether to use only cross-attention layers. In this case two cross attention layers are used.
38 | double_self_attention (`bool`, *optional*):
39 | Whether to use two self-attention layers. In this case no cross attention layers are used.
40 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
41 | num_embeds_ada_norm (:
42 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
43 | attention_bias (:
44 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
45 | """
46 |
47 | def __init__(
48 | self,
49 | dim: int,
50 | num_attention_heads: int,
51 | attention_head_dim: int,
52 | dropout=0.0,
53 | cross_attention_dim: Optional[int] = None,
54 | activation_fn: str = "geglu",
55 | num_embeds_ada_norm: Optional[int] = None,
56 | attention_bias: bool = False,
57 | only_cross_attention: bool = False,
58 | double_self_attention: bool = False,
59 | upcast_attention: bool = False,
60 | norm_elementwise_affine: bool = True,
61 | norm_type: str = "layer_norm",
62 | final_dropout: bool = False,
63 | ):
64 | super().__init__()
65 | self.only_cross_attention = only_cross_attention
66 |
67 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
68 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
69 |
70 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
71 | raise ValueError(
72 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
73 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
74 | )
75 |
76 | # Define 3 blocks. Each block has its own normalization layer.
77 | # 1. Self-Attn
78 | if self.use_ada_layer_norm:
79 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
80 | elif self.use_ada_layer_norm_zero:
81 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
82 | else:
83 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
84 | self.attn1 = Attention(
85 | query_dim=dim,
86 | heads=num_attention_heads,
87 | dim_head=attention_head_dim,
88 | dropout=dropout,
89 | bias=attention_bias,
90 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
91 | upcast_attention=upcast_attention,
92 | )
93 |
94 | # 2. Cross-Attn
95 | if cross_attention_dim is not None or double_self_attention:
96 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
97 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
98 | # the second cross attention block.
99 | self.norm2 = (
100 | AdaLayerNorm(dim, num_embeds_ada_norm)
101 | if self.use_ada_layer_norm
102 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
103 | )
104 | self.attn2 = Attention(
105 | query_dim=dim,
106 | cross_attention_dim=cross_attention_dim if not double_self_attention else None,
107 | heads=num_attention_heads,
108 | dim_head=attention_head_dim,
109 | dropout=dropout,
110 | bias=attention_bias,
111 | upcast_attention=upcast_attention,
112 | ) # is self-attn if encoder_hidden_states is none
113 | else:
114 | self.norm2 = None
115 | self.attn2 = None
116 |
117 | # 3. Feed-forward
118 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
119 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
120 |
121 | # let chunk size default to None
122 | self._chunk_size = None
123 | self._chunk_dim = 0
124 |
125 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
126 | # Sets chunk feed-forward
127 | self._chunk_size = chunk_size
128 | self._chunk_dim = dim
129 |
130 | def forward(
131 | self,
132 | hidden_states: torch.FloatTensor,
133 | attention_mask: Optional[torch.FloatTensor] = None,
134 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
135 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
136 | timestep: Optional[torch.LongTensor] = None,
137 | cross_attention_kwargs: Dict[str, Any] = None,
138 | class_labels: Optional[torch.LongTensor] = None,
139 | ):
140 | # Notice that normalization is always applied before the real computation in the following blocks.
141 | # 1. Self-Attention
142 | if self.use_ada_layer_norm:
143 | norm_hidden_states = self.norm1(hidden_states, timestep)
144 | elif self.use_ada_layer_norm_zero:
145 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
146 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
147 | )
148 | else:
149 | norm_hidden_states = self.norm1(hidden_states)
150 |
151 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
152 |
153 | attn_output = self.attn1(
154 | norm_hidden_states,
155 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
156 | attention_mask=attention_mask,
157 | **cross_attention_kwargs,
158 | )
159 | if self.use_ada_layer_norm_zero:
160 | attn_output = gate_msa.unsqueeze(1) * attn_output
161 | hidden_states = attn_output + hidden_states
162 |
163 | # 2. Cross-Attention
164 | if self.attn2 is not None:
165 | norm_hidden_states = (
166 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
167 | )
168 |
169 | attn_output = self.attn2(
170 | norm_hidden_states,
171 | encoder_hidden_states=encoder_hidden_states,
172 | attention_mask=encoder_attention_mask,
173 | **cross_attention_kwargs,
174 | )
175 | hidden_states = attn_output + hidden_states
176 |
177 | # 3. Feed-forward
178 | norm_hidden_states = self.norm3(hidden_states)
179 |
180 | if self.use_ada_layer_norm_zero:
181 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
182 |
183 | if self._chunk_size is not None:
184 | # "feed_forward_chunk_size" can be used to save memory
185 | if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
186 | raise ValueError(
187 | f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
188 | )
189 |
190 | num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
191 | ff_output = torch.cat(
192 | [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
193 | dim=self._chunk_dim,
194 | )
195 | else:
196 | ff_output = self.ff(norm_hidden_states)
197 |
198 | if self.use_ada_layer_norm_zero:
199 | ff_output = gate_mlp.unsqueeze(1) * ff_output
200 |
201 | hidden_states = ff_output + hidden_states
202 |
203 | return hidden_states
204 |
205 |
206 | class FeedForward(nn.Module):
207 | r"""
208 | A feed-forward layer.
209 |
210 | Parameters:
211 | dim (`int`): The number of channels in the input.
212 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
213 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
214 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
215 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
216 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
217 | """
218 |
219 | def __init__(
220 | self,
221 | dim: int,
222 | dim_out: Optional[int] = None,
223 | mult: int = 4,
224 | dropout: float = 0.0,
225 | activation_fn: str = "geglu",
226 | final_dropout: bool = False,
227 | ):
228 | super().__init__()
229 | inner_dim = int(dim * mult)
230 | dim_out = dim_out if dim_out is not None else dim
231 |
232 | if activation_fn == "gelu":
233 | act_fn = GELU(dim, inner_dim)
234 | elif activation_fn == "gelu-approximate":
235 | act_fn = GELU(dim, inner_dim, approximate="tanh")
236 | elif activation_fn == "geglu":
237 | act_fn = GEGLU(dim, inner_dim)
238 | elif activation_fn == "geglu-approximate":
239 | act_fn = ApproximateGELU(dim, inner_dim)
240 |
241 | self.net = nn.ModuleList([])
242 | # project in
243 | self.net.append(act_fn)
244 | # project dropout
245 | self.net.append(nn.Dropout(dropout))
246 | # project out
247 | self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
248 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
249 | if final_dropout:
250 | self.net.append(nn.Dropout(dropout))
251 |
252 | def forward(self, hidden_states):
253 | for module in self.net:
254 | hidden_states = module(hidden_states)
255 | return hidden_states
256 |
257 |
258 | class GELU(nn.Module):
259 | r"""
260 | GELU activation function with tanh approximation support with `approximate="tanh"`.
261 | """
262 |
263 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
264 | super().__init__()
265 | self.proj = nn.Linear(dim_in, dim_out)
266 | self.approximate = approximate
267 |
268 | def gelu(self, gate):
269 | if gate.device.type != "mps":
270 | return F.gelu(gate, approximate=self.approximate)
271 | # mps: gelu is not implemented for float16
272 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
273 |
274 | def forward(self, hidden_states):
275 | hidden_states = self.proj(hidden_states)
276 | hidden_states = self.gelu(hidden_states)
277 | return hidden_states
278 |
279 |
280 | class GEGLU(nn.Module):
281 | r"""
282 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
283 |
284 | Parameters:
285 | dim_in (`int`): The number of channels in the input.
286 | dim_out (`int`): The number of channels in the output.
287 | """
288 |
289 | def __init__(self, dim_in: int, dim_out: int):
290 | super().__init__()
291 | self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
292 |
293 | def gelu(self, gate):
294 | if gate.device.type != "mps":
295 | return F.gelu(gate)
296 | # mps: gelu is not implemented for float16
297 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
298 |
299 | def forward(self, hidden_states):
300 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
301 | return hidden_states * self.gelu(gate)
302 |
303 |
304 | class ApproximateGELU(nn.Module):
305 | """
306 | The approximate form of Gaussian Error Linear Unit (GELU)
307 |
308 | For more details, see section 2: https://arxiv.org/abs/1606.08415
309 | """
310 |
311 | def __init__(self, dim_in: int, dim_out: int):
312 | super().__init__()
313 | self.proj = nn.Linear(dim_in, dim_out)
314 |
315 | def forward(self, x):
316 | x = self.proj(x)
317 | return x * torch.sigmoid(1.702 * x)
318 |
319 |
320 | class AdaLayerNorm(nn.Module):
321 | """
322 | Norm layer modified to incorporate timestep embeddings.
323 | """
324 |
325 | def __init__(self, embedding_dim, num_embeddings):
326 | super().__init__()
327 | self.emb = nn.Embedding(num_embeddings, embedding_dim)
328 | self.silu = nn.SiLU()
329 | self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
330 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
331 |
332 | def forward(self, x, timestep):
333 | emb = self.linear(self.silu(self.emb(timestep)))
334 | scale, shift = torch.chunk(emb, 2)
335 | x = self.norm(x) * (1 + scale) + shift
336 | return x
337 |
338 |
339 | class AdaLayerNormZero(nn.Module):
340 | """
341 | Norm layer adaptive layer norm zero (adaLN-Zero).
342 | """
343 |
344 | def __init__(self, embedding_dim, num_embeddings):
345 | super().__init__()
346 |
347 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
348 |
349 | self.silu = nn.SiLU()
350 | self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
351 | self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
352 |
353 | def forward(self, x, timestep, class_labels, hidden_dtype=None):
354 | emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
355 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
356 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
357 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
358 |
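# A minimal usage sketch (illustrative helper, not part of the upstream code), assuming the
# `CombinedTimestepLabelEmbeddings` helper behaves as in diffusers: adaLN-Zero normalises the
# hidden states with timestep- and class-conditioned shift/scale and also returns the gates that
# the transformer block later applies to its attention and feed-forward branches.
def _example_ada_layer_norm_zero():
    import torch

    norm = AdaLayerNormZero(embedding_dim=64, num_embeddings=10)
    x = torch.randn(2, 16, 64)      # (batch, tokens, channels)
    t = torch.tensor([3, 7])        # one diffusion step per sample
    labels = torch.tensor([1, 4])   # class-conditioning labels, < num_embeddings
    x_out, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, t, labels, hidden_dtype=x.dtype)
    assert x_out.shape == x.shape and gate_msa.shape == (2, 64)
    return x_out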
359 |
360 | class AdaGroupNorm(nn.Module):
361 | """
362 | GroupNorm layer modified to incorporate timestep embeddings.
363 | """
364 |
365 | def __init__(
366 | self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
367 | ):
368 | super().__init__()
369 | self.num_groups = num_groups
370 | self.eps = eps
371 |
372 | if act_fn is None:
373 | self.act = None
374 | else:
375 | self.act = get_activation(act_fn)
376 |
377 | self.linear = nn.Linear(embedding_dim, out_dim * 2)
378 |
379 | def forward(self, x, emb):
380 | if self.act:
381 | emb = self.act(emb)
382 | emb = self.linear(emb)
383 | emb = emb[:, :, None, None]
384 | scale, shift = emb.chunk(2, dim=1)
385 |
386 | x = F.group_norm(x, self.num_groups, eps=self.eps)
387 | x = x * (1 + scale) + shift
388 | return x
389 |
--------------------------------------------------------------------------------
/unet1d/dual_transformer_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 |
16 | from torch import nn
17 |
18 | from .transformer_1d import Transformer2DModel, Transformer2DModelOutput
19 |
20 |
21 | class DualTransformer2DModel(nn.Module):
22 | """
23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24 |
25 | Parameters:
26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28 | in_channels (`int`, *optional*):
29 | Pass if the input is continuous. The number of channels in the input and output.
30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31 |         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35 | `ImagePositionalEmbeddings`.
36 | num_vector_embeds (`int`, *optional*):
37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38 | Includes the class for the masked latent pixel.
39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43 |             up to, but not more than, `num_embeds_ada_norm` steps.
44 | attention_bias (`bool`, *optional*):
45 | Configure if the TransformerBlocks' attention should contain a bias parameter.
46 | """
47 |
48 | def __init__(
49 | self,
50 | num_attention_heads: int = 16,
51 | attention_head_dim: int = 88,
52 | in_channels: Optional[int] = None,
53 | num_layers: int = 1,
54 | dropout: float = 0.0,
55 | norm_num_groups: int = 32,
56 | cross_attention_dim: Optional[int] = None,
57 | attention_bias: bool = False,
58 | sample_size: Optional[int] = None,
59 | num_vector_embeds: Optional[int] = None,
60 | activation_fn: str = "geglu",
61 | num_embeds_ada_norm: Optional[int] = None,
62 | ):
63 | super().__init__()
64 | self.transformers = nn.ModuleList(
65 | [
66 | Transformer2DModel(
67 | num_attention_heads=num_attention_heads,
68 | attention_head_dim=attention_head_dim,
69 | in_channels=in_channels,
70 | num_layers=num_layers,
71 | dropout=dropout,
72 | norm_num_groups=norm_num_groups,
73 | cross_attention_dim=cross_attention_dim,
74 | attention_bias=attention_bias,
75 | sample_size=sample_size,
76 | num_vector_embeds=num_vector_embeds,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | )
80 | for _ in range(2)
81 | ]
82 | )
83 |
84 | # Variables that can be set by a pipeline:
85 |
86 | # The ratio of transformer1 to transformer2's output states to be combined during inference
87 | self.mix_ratio = 0.5
88 |
89 | # The shape of `encoder_hidden_states` is expected to be
90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91 | self.condition_lengths = [77, 257]
92 |
93 | # Which transformer to use to encode which condition.
94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95 | self.transformer_index_for_condition = [1, 0]
96 |
97 | def forward(
98 | self,
99 | hidden_states,
100 | encoder_hidden_states,
101 | timestep=None,
102 | attention_mask=None,
103 | cross_attention_kwargs=None,
104 | return_dict: bool = True,
105 | ):
106 | """
107 | Args:
108 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109 | When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110 | hidden_states
111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113 | self-attention.
114 | timestep ( `torch.long`, *optional*):
115 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116 | attention_mask (`torch.FloatTensor`, *optional*):
117 | Optional attention mask to be applied in Attention
118 | return_dict (`bool`, *optional*, defaults to `True`):
119 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120 |
121 | Returns:
122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124 | returning a tuple, the first element is the sample tensor.
125 | """
126 | input_states = hidden_states
127 |
128 | encoded_states = []
129 | tokens_start = 0
130 | # attention_mask is not used yet
131 | for i in range(2):
132 | # for each of the two transformers, pass the corresponding condition tokens
133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134 | transformer_index = self.transformer_index_for_condition[i]
135 | encoded_state = self.transformers[transformer_index](
136 | input_states,
137 | encoder_hidden_states=condition_state,
138 | timestep=timestep,
139 | cross_attention_kwargs=cross_attention_kwargs,
140 | return_dict=False,
141 | )[0]
142 | encoded_states.append(encoded_state - input_states)
143 | tokens_start += self.condition_lengths[i]
144 |
145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146 | output_states = output_states + input_states
147 |
148 | if not return_dict:
149 | return (output_states,)
150 |
151 | return Transformer2DModelOutput(sample=output_states)
152 |
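# A minimal standalone sketch (illustrative only; no transformers are instantiated) of the token
# split driven by `condition_lengths` and the `mix_ratio` blend used in `forward` above. The
# tensor sizes are arbitrary stand-ins.
def _example_condition_split_and_mix():
    import torch

    encoder_hidden_states = torch.randn(1, 77 + 257, 768)   # concatenated condition tokens
    cond_a = encoder_hidden_states[:, :77]                   # routed to transformers[1]
    cond_b = encoder_hidden_states[:, 77:]                   # routed to transformers[0]

    input_states = torch.randn(1, 128, 320)
    branch_a = torch.randn(1, 128, 320)                      # stand-in for (output - input) of branch 0
    branch_b = torch.randn(1, 128, 320)                      # stand-in for (output - input) of branch 1
    mix_ratio = 0.5
    output_states = branch_a * mix_ratio + branch_b * (1 - mix_ratio) + input_states
    return cond_a.shape, cond_b.shape, output_states.shape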
--------------------------------------------------------------------------------
/unet1d/embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 | from typing import Optional
16 |
17 | import numpy as np
18 | import torch
19 | from torch import nn
20 |
21 | from .activations import get_activation
22 |
23 |
24 | def get_timestep_embedding(
25 | timesteps: torch.Tensor,
26 | embedding_dim: int,
27 | flip_sin_to_cos: bool = False,
28 | downscale_freq_shift: float = 1,
29 | scale: float = 1,
30 | max_period: int = 10000,
31 | ):
32 | """
33 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
34 |
35 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
36 | These may be fractional.
37 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
38 | embeddings. :return: an [N x dim] Tensor of positional embeddings.
39 | """
40 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
41 |
42 | half_dim = embedding_dim // 2
43 | exponent = -math.log(max_period) * torch.arange(
44 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
45 | )
46 | exponent = exponent / (half_dim - downscale_freq_shift)
47 |
48 | emb = torch.exp(exponent)
49 | emb = timesteps[:, None].float() * emb[None, :]
50 |
51 | # scale embeddings
52 | emb = scale * emb
53 |
54 | # concat sine and cosine embeddings
55 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
56 |
57 | # flip sine and cosine embeddings
58 | if flip_sin_to_cos:
59 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
60 |
61 | # zero pad
62 | if embedding_dim % 2 == 1:
63 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
64 | return emb
65 |
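# A minimal usage sketch (illustrative helper, not part of the upstream code): each timestep is
# mapped to a fixed sinusoidal feature vector of length `embedding_dim`.
def _example_timestep_embedding():
    t = torch.arange(4)                                       # four diffusion timesteps
    emb = get_timestep_embedding(t, embedding_dim=8, flip_sin_to_cos=True)
    assert emb.shape == (4, 8)                                # one row of sin/cos features per timestep
    return emb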
66 |
67 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
68 | """
69 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
70 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
71 | """
72 | grid_h = np.arange(grid_size, dtype=np.float32)
73 | grid_w = np.arange(grid_size, dtype=np.float32)
74 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
75 | grid = np.stack(grid, axis=0)
76 |
77 | grid = grid.reshape([2, 1, grid_size, grid_size])
78 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
79 | if cls_token and extra_tokens > 0:
80 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
81 | return pos_embed
82 |
83 |
84 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
85 | if embed_dim % 2 != 0:
86 | raise ValueError("embed_dim must be divisible by 2")
87 |
88 | # use half of dimensions to encode grid_h
89 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
90 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
91 |
92 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
93 | return emb
94 |
95 |
96 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
97 | """
98 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
99 | """
100 | if embed_dim % 2 != 0:
101 | raise ValueError("embed_dim must be divisible by 2")
102 |
103 | omega = np.arange(embed_dim // 2, dtype=np.float64)
104 | omega /= embed_dim / 2.0
105 | omega = 1.0 / 10000**omega # (D/2,)
106 |
107 | pos = pos.reshape(-1) # (M,)
108 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
109 |
110 | emb_sin = np.sin(out) # (M, D/2)
111 | emb_cos = np.cos(out) # (M, D/2)
112 |
113 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
114 | return emb
115 |
116 |
117 | class PatchEmbed(nn.Module):
118 | """2D Image to Patch Embedding"""
119 |
120 | def __init__(
121 | self,
122 | height=224,
123 | width=224,
124 | patch_size=16,
125 | in_channels=3,
126 | embed_dim=768,
127 | layer_norm=False,
128 | flatten=True,
129 | bias=True,
130 | ):
131 | super().__init__()
132 |
133 | num_patches = (height // patch_size) * (width // patch_size)
134 | self.flatten = flatten
135 | self.layer_norm = layer_norm
136 |
137 | self.proj = nn.Conv2d(
138 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
139 | )
140 | if layer_norm:
141 | self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
142 | else:
143 | self.norm = None
144 |
145 | pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
146 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
147 |
148 | def forward(self, latent):
149 | latent = self.proj(latent)
150 | if self.flatten:
151 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
152 | if self.layer_norm:
153 | latent = self.norm(latent)
154 | return latent + self.pos_embed
155 |
156 |
157 | class TimestepEmbedding(nn.Module):
158 | def __init__(
159 | self,
160 | in_channels: int,
161 | time_embed_dim: int,
162 | act_fn: str = "silu",
163 | out_dim: int = None,
164 | post_act_fn: Optional[str] = None,
165 | cond_proj_dim=None,
166 | ):
167 | super().__init__()
168 |
169 | self.linear_1 = nn.Linear(in_channels, time_embed_dim)
170 |
171 | if cond_proj_dim is not None:
172 | self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
173 | else:
174 | self.cond_proj = None
175 |
176 | self.act = get_activation(act_fn)
177 |
178 | if out_dim is not None:
179 | time_embed_dim_out = out_dim
180 | else:
181 | time_embed_dim_out = time_embed_dim
182 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
183 |
184 | if post_act_fn is None:
185 | self.post_act = None
186 | else:
187 | self.post_act = get_activation(post_act_fn)
188 |
189 | def forward(self, sample, condition=None):
190 | if condition is not None:
191 | sample = sample + self.cond_proj(condition)
192 | sample = self.linear_1(sample)
193 |
194 | if self.act is not None:
195 | sample = self.act(sample)
196 |
197 | sample = self.linear_2(sample)
198 |
199 | if self.post_act is not None:
200 | sample = self.post_act(sample)
201 | return sample
202 |
203 |
204 | class Timesteps(nn.Module):
205 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
206 | super().__init__()
207 | self.num_channels = num_channels
208 | self.flip_sin_to_cos = flip_sin_to_cos
209 | self.downscale_freq_shift = downscale_freq_shift
210 |
211 | def forward(self, timesteps):
212 | t_emb = get_timestep_embedding(
213 | timesteps,
214 | self.num_channels,
215 | flip_sin_to_cos=self.flip_sin_to_cos,
216 | downscale_freq_shift=self.downscale_freq_shift,
217 | )
218 | return t_emb
219 |
220 |
221 | class GaussianFourierProjection(nn.Module):
222 | """Gaussian Fourier embeddings for noise levels."""
223 |
224 | def __init__(
225 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
226 | ):
227 | super().__init__()
228 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
229 | self.log = log
230 | self.flip_sin_to_cos = flip_sin_to_cos
231 |
232 | if set_W_to_weight:
233 | # to delete later
234 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
235 |
236 | self.weight = self.W
237 |
238 | def forward(self, x):
239 | if self.log:
240 | x = torch.log(x)
241 |
242 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
243 |
244 | if self.flip_sin_to_cos:
245 | out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
246 | else:
247 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
248 | return out
249 |
250 |
251 | class ImagePositionalEmbeddings(nn.Module):
252 | """
253 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
254 | height and width of the latent space.
255 |
256 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
257 |
258 | For VQ-diffusion:
259 |
260 | Output vector embeddings are used as input for the transformer.
261 |
262 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
263 |
264 | Args:
265 | num_embed (`int`):
266 | Number of embeddings for the latent pixels embeddings.
267 | height (`int`):
268 | Height of the latent image i.e. the number of height embeddings.
269 | width (`int`):
270 | Width of the latent image i.e. the number of width embeddings.
271 | embed_dim (`int`):
272 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
273 | """
274 |
275 | def __init__(
276 | self,
277 | num_embed: int,
278 | height: int,
279 | width: int,
280 | embed_dim: int,
281 | ):
282 | super().__init__()
283 |
284 | self.height = height
285 | self.width = width
286 | self.num_embed = num_embed
287 | self.embed_dim = embed_dim
288 |
289 | self.emb = nn.Embedding(self.num_embed, embed_dim)
290 | self.height_emb = nn.Embedding(self.height, embed_dim)
291 | self.width_emb = nn.Embedding(self.width, embed_dim)
292 |
293 | def forward(self, index):
294 | emb = self.emb(index)
295 |
296 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
297 |
298 | # 1 x H x D -> 1 x H x 1 x D
299 | height_emb = height_emb.unsqueeze(2)
300 |
301 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
302 |
303 | # 1 x W x D -> 1 x 1 x W x D
304 | width_emb = width_emb.unsqueeze(1)
305 |
306 | pos_emb = height_emb + width_emb
307 |
308 |         # 1 x H x W x D -> 1 x L x D
309 | pos_emb = pos_emb.view(1, self.height * self.width, -1)
310 |
311 | emb = emb + pos_emb[:, : emb.shape[1], :]
312 |
313 | return emb
314 |
315 |
316 | class LabelEmbedding(nn.Module):
317 | """
318 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
319 |
320 | Args:
321 | num_classes (`int`): The number of classes.
322 | hidden_size (`int`): The size of the vector embeddings.
323 | dropout_prob (`float`): The probability of dropping a label.
324 | """
325 |
326 | def __init__(self, num_classes, hidden_size, dropout_prob):
327 | super().__init__()
328 | use_cfg_embedding = dropout_prob > 0
329 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
330 | self.num_classes = num_classes
331 | self.dropout_prob = dropout_prob
332 |
333 | def token_drop(self, labels, force_drop_ids=None):
334 | """
335 | Drops labels to enable classifier-free guidance.
336 | """
337 | if force_drop_ids is None:
338 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
339 | else:
340 | drop_ids = torch.tensor(force_drop_ids == 1)
341 | labels = torch.where(drop_ids, self.num_classes, labels)
342 | return labels
343 |
344 | def forward(self, labels: torch.LongTensor, force_drop_ids=None):
345 | use_dropout = self.dropout_prob > 0
346 | if (self.training and use_dropout) or (force_drop_ids is not None):
347 | labels = self.token_drop(labels, force_drop_ids)
348 | embeddings = self.embedding_table(labels)
349 | return embeddings
350 |
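# A minimal usage sketch (illustrative helper, not part of the upstream code): when
# `dropout_prob > 0`, an extra "null" row is appended to the table and dropped labels are mapped
# to index `num_classes`, which is what enables classifier-free guidance.
def _example_label_dropout():
    import torch

    emb = LabelEmbedding(num_classes=10, hidden_size=32, dropout_prob=0.1)
    labels = torch.tensor([2, 5, 7])
    dropped = emb.token_drop(labels, force_drop_ids=np.array([1, 1, 1]))
    assert torch.all(dropped == 10)   # every forced drop points at the null-class embedding
    return emb(labels, force_drop_ids=np.array([0, 1, 0]))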
351 |
352 | class TextImageProjection(nn.Module):
353 | def __init__(
354 | self,
355 | text_embed_dim: int = 1024,
356 | image_embed_dim: int = 768,
357 | cross_attention_dim: int = 768,
358 | num_image_text_embeds: int = 10,
359 | ):
360 | super().__init__()
361 |
362 | self.num_image_text_embeds = num_image_text_embeds
363 | self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
364 | self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
365 |
366 | def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
367 | batch_size = text_embeds.shape[0]
368 |
369 | # image
370 | image_text_embeds = self.image_embeds(image_embeds)
371 | image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
372 |
373 | # text
374 | text_embeds = self.text_proj(text_embeds)
375 |
376 | return torch.cat([image_text_embeds, text_embeds], dim=1)
377 |
378 |
379 | class ImageProjection(nn.Module):
380 | def __init__(
381 | self,
382 | image_embed_dim: int = 768,
383 | cross_attention_dim: int = 768,
384 | num_image_text_embeds: int = 32,
385 | ):
386 | super().__init__()
387 |
388 | self.num_image_text_embeds = num_image_text_embeds
389 | self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
390 | self.norm = nn.LayerNorm(cross_attention_dim)
391 |
392 | def forward(self, image_embeds: torch.FloatTensor):
393 | batch_size = image_embeds.shape[0]
394 |
395 | # image
396 | image_embeds = self.image_embeds(image_embeds)
397 | image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
398 | image_embeds = self.norm(image_embeds)
399 | return image_embeds
400 |
401 |
402 | class CombinedTimestepLabelEmbeddings(nn.Module):
403 | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
404 | super().__init__()
405 |
406 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
407 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
408 | self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
409 |
410 | def forward(self, timestep, class_labels, hidden_dtype=None):
411 | timesteps_proj = self.time_proj(timestep)
412 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
413 |
414 | class_labels = self.class_embedder(class_labels) # (N, D)
415 |
416 | conditioning = timesteps_emb + class_labels # (N, D)
417 |
418 | return conditioning
419 |
420 |
421 | class TextTimeEmbedding(nn.Module):
422 | def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
423 | super().__init__()
424 | self.norm1 = nn.LayerNorm(encoder_dim)
425 | self.pool = AttentionPooling(num_heads, encoder_dim)
426 | self.proj = nn.Linear(encoder_dim, time_embed_dim)
427 | self.norm2 = nn.LayerNorm(time_embed_dim)
428 |
429 | def forward(self, hidden_states):
430 | hidden_states = self.norm1(hidden_states)
431 | hidden_states = self.pool(hidden_states)
432 | hidden_states = self.proj(hidden_states)
433 | hidden_states = self.norm2(hidden_states)
434 | return hidden_states
435 |
436 |
437 | class TextImageTimeEmbedding(nn.Module):
438 | def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
439 | super().__init__()
440 | self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
441 | self.text_norm = nn.LayerNorm(time_embed_dim)
442 | self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
443 |
444 | def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
445 | # text
446 | time_text_embeds = self.text_proj(text_embeds)
447 | time_text_embeds = self.text_norm(time_text_embeds)
448 |
449 | # image
450 | time_image_embeds = self.image_proj(image_embeds)
451 |
452 | return time_image_embeds + time_text_embeds
453 |
454 |
455 | class ImageTimeEmbedding(nn.Module):
456 | def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
457 | super().__init__()
458 | self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
459 | self.image_norm = nn.LayerNorm(time_embed_dim)
460 |
461 | def forward(self, image_embeds: torch.FloatTensor):
462 | # image
463 | time_image_embeds = self.image_proj(image_embeds)
464 | time_image_embeds = self.image_norm(time_image_embeds)
465 | return time_image_embeds
466 |
467 |
468 | class ImageHintTimeEmbedding(nn.Module):
469 | def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
470 | super().__init__()
471 | self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
472 | self.image_norm = nn.LayerNorm(time_embed_dim)
473 | self.input_hint_block = nn.Sequential(
474 | nn.Conv2d(3, 16, 3, padding=1),
475 | nn.SiLU(),
476 | nn.Conv2d(16, 16, 3, padding=1),
477 | nn.SiLU(),
478 | nn.Conv2d(16, 32, 3, padding=1, stride=2),
479 | nn.SiLU(),
480 | nn.Conv2d(32, 32, 3, padding=1),
481 | nn.SiLU(),
482 | nn.Conv2d(32, 96, 3, padding=1, stride=2),
483 | nn.SiLU(),
484 | nn.Conv2d(96, 96, 3, padding=1),
485 | nn.SiLU(),
486 | nn.Conv2d(96, 256, 3, padding=1, stride=2),
487 | nn.SiLU(),
488 | nn.Conv2d(256, 4, 3, padding=1),
489 | )
490 |
491 | def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
492 | # image
493 | time_image_embeds = self.image_proj(image_embeds)
494 | time_image_embeds = self.image_norm(time_image_embeds)
495 | hint = self.input_hint_block(hint)
496 | return time_image_embeds, hint
497 |
498 |
499 | class AttentionPooling(nn.Module):
500 | # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
501 |
502 | def __init__(self, num_heads, embed_dim, dtype=None):
503 | super().__init__()
504 | self.dtype = dtype
505 | self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
506 | self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
507 | self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
508 | self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
509 | self.num_heads = num_heads
510 | self.dim_per_head = embed_dim // self.num_heads
511 |
512 | def forward(self, x):
513 | bs, length, width = x.size()
514 |
515 | def shape(x):
516 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
517 | x = x.view(bs, -1, self.num_heads, self.dim_per_head)
518 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
519 | x = x.transpose(1, 2)
520 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
521 | x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
522 | # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
523 | x = x.transpose(1, 2)
524 | return x
525 |
526 | class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
527 | x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
528 |
529 | # (bs*n_heads, class_token_length, dim_per_head)
530 | q = shape(self.q_proj(class_token))
531 | # (bs*n_heads, length+class_token_length, dim_per_head)
532 | k = shape(self.k_proj(x))
533 | v = shape(self.v_proj(x))
534 |
535 | # (bs*n_heads, class_token_length, length+class_token_length):
536 | scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
537 | weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
538 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
539 |
540 | # (bs*n_heads, dim_per_head, class_token_length)
541 | a = torch.einsum("bts,bcs->bct", weight, v)
542 |
543 | # (bs, length+1, width)
544 | a = a.reshape(bs, -1, 1).transpose(1, 2)
545 |
546 | return a[:, 0, :] # cls_token
547 |
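# A minimal usage sketch (illustrative helper, not part of the upstream code): the pooled
# "class token" attends over the whole sequence, collapsing a token sequence into a single
# vector, as `TextTimeEmbedding` above does before projecting into the time-embedding space.
def _example_attention_pooling():
    pool = AttentionPooling(num_heads=4, embed_dim=64)
    tokens = torch.randn(2, 10, 64)   # (batch, sequence, width)
    pooled = pool(tokens)
    assert pooled.shape == (2, 64)    # one pooled vector per sequence
    return pooled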
--------------------------------------------------------------------------------
/unet1d/lora.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | import torch.nn.functional as F
18 | from torch import nn
19 |
20 |
21 | class LoRALinearLayer(nn.Module):
22 | def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
23 | super().__init__()
24 |
25 | if rank > min(in_features, out_features):
26 | raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
27 |
28 | self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
29 | self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
30 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
31 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
32 | self.network_alpha = network_alpha
33 | self.rank = rank
34 |
35 | nn.init.normal_(self.down.weight, std=1 / rank)
36 | nn.init.zeros_(self.up.weight)
37 |
38 | def forward(self, hidden_states):
39 | orig_dtype = hidden_states.dtype
40 | dtype = self.down.weight.dtype
41 |
42 | down_hidden_states = self.down(hidden_states.to(dtype))
43 | up_hidden_states = self.up(down_hidden_states)
44 |
45 | if self.network_alpha is not None:
46 | up_hidden_states *= self.network_alpha / self.rank
47 |
48 | return up_hidden_states.to(orig_dtype)
49 |
50 |
51 | class LoRAConv1dLayer(nn.Module):
52 | def __init__(
53 |         self, in_features, out_features, rank=4, kernel_size=1, stride=1, padding=0, network_alpha=None
54 | ):
55 | super().__init__()
56 |
57 | if rank > min(in_features, out_features):
58 | raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
59 |
60 | self.down = nn.Conv1d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
61 |         # according to the official kohya_ss trainer, kernel_size is always fixed to 1 for the up layer
62 |         # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
63 |         self.up = nn.Conv1d(rank, out_features, kernel_size=1, stride=1, bias=False)
64 |
65 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
66 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
67 | self.network_alpha = network_alpha
68 | self.rank = rank
69 |
70 | nn.init.normal_(self.down.weight, std=1 / rank)
71 | nn.init.zeros_(self.up.weight)
72 |
73 | def forward(self, hidden_states):
74 | orig_dtype = hidden_states.dtype
75 | dtype = self.down.weight.dtype
76 |
77 | down_hidden_states = self.down(hidden_states.to(dtype))
78 | up_hidden_states = self.up(down_hidden_states)
79 |
80 | if self.network_alpha is not None:
81 | up_hidden_states *= self.network_alpha / self.rank
82 |
83 | return up_hidden_states.to(orig_dtype)
84 |
85 |
86 | class LoRACompatibleConv(nn.Conv1d):
87 | """
88 | A convolutional layer that can be used with LoRA.
89 | """
90 |
91 | def __init__(self, *args, lora_layer: Optional[LoRAConv1dLayer] = None, **kwargs):
92 | super().__init__(*args, **kwargs)
93 | self.lora_layer = lora_layer
94 |
95 | def set_lora_layer(self, lora_layer: Optional[LoRAConv1dLayer]):
96 | self.lora_layer = lora_layer
97 |
98 | def forward(self, x):
99 | if self.lora_layer is None:
100 |             # make sure to use the functional conv1d call, as otherwise torch.compile's graph will break
101 | # see: https://github.com/huggingface/diffusers/pull/4315
102 | return F.conv1d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
103 | else:
104 | return super().forward(x) + self.lora_layer(x)
105 |
106 |
107 | class LoRACompatibleLinear(nn.Linear):
108 | """
109 | A Linear layer that can be used with LoRA.
110 | """
111 |
112 | def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
113 | super().__init__(*args, **kwargs)
114 | self.lora_layer = lora_layer
115 |
116 |     def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
117 | self.lora_layer = lora_layer
118 |
119 | def forward(self, x):
120 | if self.lora_layer is None:
121 | return super().forward(x)
122 | else:
123 | return super().forward(x) + self.lora_layer(x)
124 |
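# A minimal usage sketch (illustrative helper, not part of the upstream code): a
# `LoRACompatibleLinear` behaves exactly like `nn.Linear` until a `LoRALinearLayer` is attached,
# and because the up-projection is zero-initialised the LoRA delta starts at exactly zero.
def _example_lora_linear():
    import torch

    base = LoRACompatibleLinear(32, 32)
    x = torch.randn(4, 32)
    y_plain = base(x)                                         # plain nn.Linear behaviour
    base.set_lora_layer(LoRALinearLayer(32, 32, rank=4, network_alpha=4))
    y_lora = base(x)                                          # base output + low-rank update
    assert torch.allclose(y_plain, y_lora)                    # delta is zero right after init
    return y_lora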
--------------------------------------------------------------------------------
/unet1d/outputs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Generic utilities
16 | """
17 |
18 | from collections import OrderedDict
19 | from dataclasses import fields
20 | from typing import Any, Tuple
21 |
22 | import numpy as np
23 |
24 | def is_tensor(x):
25 | """
26 | Tests if `x` is a `torch.Tensor` or `np.ndarray`.
27 | """
28 | import torch
29 |
30 | if isinstance(x, torch.Tensor):
31 | return True
32 |
33 | return isinstance(x, np.ndarray)
34 |
35 |
36 | class BaseOutput(OrderedDict):
37 | """
38 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
39 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
40 | Python dictionary.
41 |
42 |
43 |
44 | You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
45 | first.
46 |
47 |
48 | """
49 |
50 | def __post_init__(self):
51 | class_fields = fields(self)
52 |
53 | # Safety and consistency checks
54 | if not len(class_fields):
55 | raise ValueError(f"{self.__class__.__name__} has no fields.")
56 |
57 | first_field = getattr(self, class_fields[0].name)
58 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
59 |
60 | if other_fields_are_none and isinstance(first_field, dict):
61 | for key, value in first_field.items():
62 | self[key] = value
63 | else:
64 | for field in class_fields:
65 | v = getattr(self, field.name)
66 | if v is not None:
67 | self[field.name] = v
68 |
69 | def __delitem__(self, *args, **kwargs):
70 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
71 |
72 | def setdefault(self, *args, **kwargs):
73 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
74 |
75 | def pop(self, *args, **kwargs):
76 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
77 |
78 | def update(self, *args, **kwargs):
79 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
80 |
81 | def __getitem__(self, k):
82 | if isinstance(k, str):
83 | inner_dict = dict(self.items())
84 | return inner_dict[k]
85 | else:
86 | return self.to_tuple()[k]
87 |
88 | def __setattr__(self, name, value):
89 | if name in self.keys() and value is not None:
90 | # Don't call self.__setitem__ to avoid recursion errors
91 | super().__setitem__(name, value)
92 | super().__setattr__(name, value)
93 |
94 | def __setitem__(self, key, value):
95 | # Will raise a KeyException if needed
96 | super().__setitem__(key, value)
97 | # Don't call self.__setattr__ to avoid recursion errors
98 | super().__setattr__(key, value)
99 |
100 | def to_tuple(self) -> Tuple[Any]:
101 | """
102 | Convert self to a tuple containing all the attributes/keys that are not `None`.
103 | """
104 | return tuple(self[k] for k in self.keys())
105 |
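# A minimal usage sketch (the example dataclass is illustrative, not part of the upstream code):
# a `BaseOutput` subclass can be read by attribute, by string key, or by integer index.
def _example_base_output():
    from dataclasses import dataclass

    @dataclass
    class ExampleOutput(BaseOutput):
        sample: np.ndarray = None

    out = ExampleOutput(sample=np.zeros((2, 3)))
    assert out["sample"] is out.sample and out[0] is out.sample
    return out.to_tuple()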
--------------------------------------------------------------------------------
/unet1d/transformer_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Any, Dict, Optional
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | from torch import nn
20 |
21 | from .outputs import BaseOutput
22 | from .attention import BasicTransformerBlock
23 | from .embeddings import PatchEmbed
24 | from .lora import LoRACompatibleConv, LoRACompatibleLinear
25 |
26 |
27 | @dataclass
28 | class Transformer2DModelOutput(BaseOutput):
29 | """
30 | The output of [`Transformer2DModel`].
31 |
32 | Args:
33 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
34 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
35 | distributions for the unnoised latent pixels.
36 | """
37 |
38 | sample: torch.FloatTensor
39 |
40 |
41 | class Transformer2DModel(nn.Module):
42 | """
43 |     A Transformer model for 1D sequence data, adapted from the diffusers `Transformer2DModel` (the original class name is kept).
44 |
45 | Parameters:
46 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
47 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
48 | in_channels (`int`, *optional*):
49 | The number of channels in the input and output (specify if the input is **continuous**).
50 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
51 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
52 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
53 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
54 | This is fixed during training since it is used to learn a number of position embeddings.
55 | num_vector_embeds (`int`, *optional*):
56 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
57 | Includes the class for the masked latent pixel.
58 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
59 | num_embeds_ada_norm ( `int`, *optional*):
60 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is
61 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
62 | added to the hidden states.
63 |
64 |             During inference, you can denoise for up to, but not more than, `num_embeds_ada_norm` steps.
65 | attention_bias (`bool`, *optional*):
66 | Configure if the `TransformerBlocks` attention should contain a bias parameter.
67 | """
68 |
69 | def __init__(
70 | self,
71 | num_attention_heads: int = 16,
72 | attention_head_dim: int = 88,
73 | in_channels: Optional[int] = None,
74 | out_channels: Optional[int] = None,
75 | num_layers: int = 1,
76 | dropout: float = 0.0,
77 | norm_num_groups: int = 32,
78 | cross_attention_dim: Optional[int] = None,
79 | attention_bias: bool = False,
80 | sample_size: Optional[int] = None,
81 | num_vector_embeds: Optional[int] = None,
82 | patch_size: Optional[int] = None,
83 | activation_fn: str = "geglu",
84 | num_embeds_ada_norm: Optional[int] = None,
85 | use_linear_projection: bool = False,
86 | only_cross_attention: bool = False,
87 | upcast_attention: bool = False,
88 | norm_type: str = "layer_norm",
89 | norm_elementwise_affine: bool = True,
90 | ):
91 | super().__init__()
92 | self.use_linear_projection = use_linear_projection
93 | self.num_attention_heads = num_attention_heads
94 | self.attention_head_dim = attention_head_dim
95 | inner_dim = num_attention_heads * attention_head_dim
96 |
97 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
98 | # Define whether input is continuous or discrete depending on configuration
99 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
100 | self.is_input_vectorized = num_vector_embeds is not None
101 | self.is_input_patches = in_channels is not None and patch_size is not None
102 |
103 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
104 | deprecation_message = (
105 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
106 |                 " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
107 |                 " Please make sure to update the config accordingly, as leaving `norm_type` unchanged might lead to incorrect"
108 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
109 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
110 | )
111 | # deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
112 | norm_type = "ada_norm"
113 |
114 | if self.is_input_continuous and self.is_input_vectorized:
115 | raise ValueError(
116 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
117 | " sure that either `in_channels` or `num_vector_embeds` is None."
118 | )
119 | elif self.is_input_vectorized and self.is_input_patches:
120 | raise ValueError(
121 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
122 | " sure that either `num_vector_embeds` or `num_patches` is None."
123 | )
124 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
125 | raise ValueError(
126 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
127 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
128 | )
129 |
130 | # 2. Define input layers
131 | if self.is_input_continuous:
132 | self.in_channels = in_channels
133 |
134 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
135 | if use_linear_projection:
136 | self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
137 | else:
138 | self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
139 | elif self.is_input_patches:
140 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
141 |
142 | self.height = sample_size
143 | self.width = sample_size
144 |
145 | self.patch_size = patch_size
146 | self.pos_embed = PatchEmbed(
147 | height=sample_size,
148 | width=sample_size,
149 | patch_size=patch_size,
150 | in_channels=in_channels,
151 | embed_dim=inner_dim,
152 | )
153 |
154 | # 3. Define transformers blocks
155 | self.transformer_blocks = nn.ModuleList(
156 | [
157 | BasicTransformerBlock(
158 | inner_dim,
159 | num_attention_heads,
160 | attention_head_dim,
161 | dropout=dropout,
162 | cross_attention_dim=cross_attention_dim,
163 | activation_fn=activation_fn,
164 | num_embeds_ada_norm=num_embeds_ada_norm,
165 | attention_bias=attention_bias,
166 | only_cross_attention=only_cross_attention,
167 | upcast_attention=upcast_attention,
168 | norm_type=norm_type,
169 | norm_elementwise_affine=norm_elementwise_affine,
170 | )
171 | for d in range(num_layers)
172 | ]
173 | )
174 |
175 | # 4. Define output layers
176 | self.out_channels = in_channels if out_channels is None else out_channels
177 | if self.is_input_continuous:
178 | # TODO: should use out_channels for continuous projections
179 | if use_linear_projection:
180 | self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
181 | else:
182 | self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
183 | elif self.is_input_vectorized:
184 | self.norm_out = nn.LayerNorm(inner_dim)
185 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
186 | elif self.is_input_patches:
187 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
188 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
189 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
190 |
191 | def forward(
192 | self,
193 | hidden_states: torch.Tensor,
194 | encoder_hidden_states: Optional[torch.Tensor] = None,
195 | timestep: Optional[torch.LongTensor] = None,
196 | class_labels: Optional[torch.LongTensor] = None,
197 | cross_attention_kwargs: Dict[str, Any] = None,
198 | attention_mask: Optional[torch.Tensor] = None,
199 | encoder_attention_mask: Optional[torch.Tensor] = None,
200 | return_dict: bool = True,
201 | ):
202 | """
203 | The [`Transformer2DModel`] forward method.
204 |
205 | Args:
206 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
207 | Input `hidden_states`.
208 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
209 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
210 | self-attention.
211 | timestep ( `torch.LongTensor`, *optional*):
212 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
213 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
214 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
215 | `AdaLayerZeroNorm`.
216 | encoder_attention_mask ( `torch.Tensor`, *optional*):
217 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
218 |
219 | * Mask `(batch, sequence_length)` True = keep, False = discard.
220 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
221 |
222 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
223 | above. This bias will be added to the cross-attention scores.
224 | return_dict (`bool`, *optional*, defaults to `True`):
225 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
226 | tuple.
227 |
228 | Returns:
229 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
230 | `tuple` where the first element is the sample tensor.
231 | """
232 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
233 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
234 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
235 | # expects mask of shape:
236 | # [batch, key_tokens]
237 | # adds singleton query_tokens dimension:
238 | # [batch, 1, key_tokens]
239 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
240 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
241 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
242 | if attention_mask is not None and attention_mask.ndim == 2:
243 | # assume that mask is expressed as:
244 | # (1 = keep, 0 = discard)
245 | # convert mask into a bias that can be added to attention scores:
246 | # (keep = +0, discard = -10000.0)
247 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
248 | attention_mask = attention_mask.unsqueeze(1)
249 |
250 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
251 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
252 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
253 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
254 |
255 | # 1. Input
256 | if self.is_input_continuous:
257 | batch, _, time_length = hidden_states.shape
258 | residual = hidden_states
259 |
260 | hidden_states = self.norm(hidden_states)
261 | if not self.use_linear_projection:
262 | hidden_states = self.proj_in(hidden_states)
263 | inner_dim = hidden_states.shape[1]
264 | hidden_states = hidden_states.permute(0, 2, 1)
265 | else:
266 | inner_dim = hidden_states.shape[1]
267 | hidden_states = hidden_states.permute(0, 2, 1)
268 | hidden_states = self.proj_in(hidden_states)
269 | elif self.is_input_vectorized:
270 | hidden_states = self.latent_image_embedding(hidden_states)
271 | elif self.is_input_patches:
272 | hidden_states = self.pos_embed(hidden_states)
273 |
274 | # 2. Blocks
275 | for block in self.transformer_blocks:
276 | hidden_states = block(
277 | hidden_states,
278 | attention_mask=attention_mask,
279 | encoder_hidden_states=encoder_hidden_states,
280 | encoder_attention_mask=encoder_attention_mask,
281 | timestep=timestep,
282 | cross_attention_kwargs=cross_attention_kwargs,
283 | class_labels=class_labels,
284 | )
285 |
286 | # 3. Output
287 | if self.is_input_continuous:
288 | if not self.use_linear_projection:
289 | hidden_states = hidden_states.permute(0, 2, 1).contiguous()
290 | hidden_states = self.proj_out(hidden_states)
291 | else:
292 | hidden_states = self.proj_out(hidden_states)
293 | hidden_states = hidden_states.permute(0, 2, 1).contiguous()
294 |
295 | output = hidden_states + residual
296 | elif self.is_input_vectorized:
297 | hidden_states = self.norm_out(hidden_states)
298 | logits = self.out(hidden_states)
299 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
300 | logits = logits.permute(0, 2, 1)
301 |
302 | # log(p(x_0))
303 | output = F.log_softmax(logits.double(), dim=1).float()
304 | elif self.is_input_patches:
305 | # TODO: cleanup!
306 | conditioning = self.transformer_blocks[0].norm1.emb(
307 | timestep, class_labels, hidden_dtype=hidden_states.dtype
308 | )
309 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
310 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
311 | hidden_states = self.proj_out_2(hidden_states)
312 |
313 | # unpatchify
314 | height = width = int(hidden_states.shape[1] ** 0.5)
315 | hidden_states = hidden_states.reshape(
316 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
317 | )
318 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
319 | output = hidden_states.reshape(
320 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
321 | )
322 |
323 | if not return_dict:
324 | return (output,)
325 |
326 | return Transformer2DModelOutput(sample=output)
327 |
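# A minimal standalone sketch of the 2-D mask -> additive-bias conversion performed at the top of
# `Transformer2DModel.forward` above (1 = keep becomes 0.0, 0 = discard becomes -10000.0).
def _example_mask_to_bias():
    mask = torch.tensor([[1, 1, 0, 0]])                       # (batch, key_tokens)
    bias = (1 - mask.to(torch.float32)) * -10000.0            # keep -> 0.0, discard -> -10000.0
    bias = bias.unsqueeze(1)                                  # (batch, 1, key_tokens), broadcast over queries
    return bias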
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import re
4 | import sys
5 | import argparse
6 | import logging
7 | import json
8 | import subprocess
9 | import warnings
10 | import random
11 | import functools
12 |
13 | import librosa
14 | import numpy as np
15 | from scipy.io.wavfile import read
16 | import torch
17 | from torch.nn import functional as F
18 | from modules.commons import sequence_mask
19 |
20 | MATPLOTLIB_FLAG = False
21 |
22 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
23 | logger = logging
24 |
25 | f0_bin = 256
26 | f0_max = 1100.0
27 | f0_min = 50.0
28 | f0_mel_min = 1127 * np.log(1 + f0_min / 700)
29 | f0_mel_max = 1127 * np.log(1 + f0_max / 700)
30 |
31 |
32 | # def normalize_f0(f0, random_scale=True):
33 | # f0_norm = f0.clone() # create a copy of the input Tensor
34 | # batch_size, _, frame_length = f0_norm.shape
35 | # for i in range(batch_size):
36 | # means = torch.mean(f0_norm[i, 0, :])
37 | # if random_scale:
38 | # factor = random.uniform(0.8, 1.2)
39 | # else:
40 | # factor = 1
41 | # f0_norm[i, 0, :] = (f0_norm[i, 0, :] - means) * factor
42 | # return f0_norm
43 | # def normalize_f0(f0, random_scale=True):
44 | # means = torch.mean(f0[:, 0, :], dim=1, keepdim=True)
45 | # if random_scale:
46 | # factor = torch.Tensor(f0.shape[0],1).uniform_(0.8, 1.2).to(f0.device)
47 | # else:
48 | # factor = torch.ones(f0.shape[0], 1, 1).to(f0.device)
49 | # f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
50 | # return f0_norm
51 |
52 | def deprecated(func):
53 | """This is a decorator which can be used to mark functions
54 | as deprecated. It will result in a warning being emitted
55 | when the function is used."""
56 | @functools.wraps(func)
57 | def new_func(*args, **kwargs):
58 |         warnings.simplefilter('always', DeprecationWarning)  # make sure the warning is always shown
59 | warnings.warn("Call to deprecated function {}.".format(func.__name__),
60 | category=DeprecationWarning,
61 | stacklevel=2)
62 | warnings.simplefilter('default', DeprecationWarning) # reset filter
63 | return func(*args, **kwargs)
64 | return new_func
65 |
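# Illustrative sketch (hypothetical helper, not referenced elsewhere): how the
# @deprecated decorator above is meant to be applied. Calling the wrapped
# function emits a DeprecationWarning and then runs the original body.
@deprecated
def _normalize_f0_legacy_example(f0):
    return f0 - torch.mean(f0)
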
66 | def normalize_f0(f0, uv, random_scale=True):
67 | # calculate means based on x_mask
68 | uv_sum = torch.sum(uv, dim=1, keepdim=True)
69 | uv_sum[uv_sum == 0] = 9999
70 | means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum
71 |
72 | if random_scale:
73 | factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
74 | else:
75 | factor = torch.ones(f0.shape[0], 1).to(f0.device)
76 | # normalize f0 based on means and factor
77 | f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
78 | if torch.isnan(f0_norm).any():
79 |         raise ValueError("normalize_f0 produced NaN values")
80 | return f0_norm
81 |
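# Illustrative sketch of the shapes normalize_f0 expects: f0 is
# (batch, 1, frames) and uv is a (batch, frames) voiced/unvoiced mask with
# 1.0 on voiced frames. The values below are made up for demonstration.
def _normalize_f0_example():
    f0 = torch.full((2, 1, 100), 220.0)   # constant 220 Hz contours
    uv = torch.ones(2, 100)               # every frame voiced
    f0_norm = normalize_f0(f0, uv, random_scale=False)
    assert f0_norm.shape == (2, 1, 100)   # per-utterance mean removed
    return f0_norm
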
82 | def compute_f0_uv_torchcrepe(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512,device=None,cr_threshold=0.05):
83 | from modules.crepe import CrepePitchExtractor
84 | x = wav_numpy
85 | if p_len is None:
86 | p_len = x.shape[0]//hop_length
87 | else:
88 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
89 |
90 | f0_min = 50
91 | f0_max = 1100
92 |     f0_extractor = CrepePitchExtractor(hop_length=hop_length, f0_min=f0_min, f0_max=f0_max, device=device, threshold=cr_threshold)
93 |     f0, uv = f0_extractor(x[None, :].float(), sampling_rate, pad_to=p_len)
94 | return f0,uv
95 |
96 | def plot_data_to_numpy(x, y):
97 | global MATPLOTLIB_FLAG
98 | if not MATPLOTLIB_FLAG:
99 | import matplotlib
100 | matplotlib.use("Agg")
101 | MATPLOTLIB_FLAG = True
102 | mpl_logger = logging.getLogger('matplotlib')
103 | mpl_logger.setLevel(logging.WARNING)
104 | import matplotlib.pylab as plt
105 | import numpy as np
106 |
107 | fig, ax = plt.subplots(figsize=(10, 2))
108 | plt.plot(x)
109 | plt.plot(y)
110 | plt.tight_layout()
111 |
112 | fig.canvas.draw()
113 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
114 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
115 | plt.close()
116 | return data
117 |
118 |
119 |
120 | def interpolate_f0(f0):
121 |
122 | data = np.reshape(f0, (f0.size, 1))
123 |
124 | vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
125 | vuv_vector[data > 0.0] = 1.0
126 | vuv_vector[data <= 0.0] = 0.0
127 |
128 | ip_data = data
129 |
130 | frame_number = data.size
131 | last_value = 0.0
132 | for i in range(frame_number):
133 | if data[i] <= 0.0:
134 | j = i + 1
135 | for j in range(i + 1, frame_number):
136 | if data[j] > 0.0:
137 | break
138 | if j < frame_number - 1:
139 | if last_value > 0.0:
140 | step = (data[j] - data[i - 1]) / float(j - i)
141 | for k in range(i, j):
142 | ip_data[k] = data[i - 1] + step * (k - i + 1)
143 | else:
144 | for k in range(i, j):
145 | ip_data[k] = data[j]
146 | else:
147 | for k in range(i, frame_number):
148 | ip_data[k] = last_value
149 | else:
150 | ip_data[i] = data[i] # this may not be necessary
151 | last_value = data[i]
152 |
153 | return ip_data[:,0], vuv_vector[:,0]
154 |
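# Illustrative sketch: interpolate_f0 fills unvoiced (zero) frames by linear
# interpolation between neighbouring voiced values and also returns a per-frame
# voiced/unvoiced flag. The toy contour below is made up.
def _interpolate_f0_example():
    f0 = np.array([0.0, 100.0, 0.0, 0.0, 200.0, 0.0], dtype=np.float32)
    f0_interp, vuv = interpolate_f0(f0)
    # vuv marks the originally voiced frames: [0, 1, 0, 0, 1, 0];
    # interior zeros become values between 100 and 200, trailing zeros
    # are held at the last voiced value.
    return f0_interp, vuv
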
155 |
156 | def compute_f0_parselmouth(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
157 | import parselmouth
158 | x = wav_numpy
159 | if p_len is None:
160 | p_len = x.shape[0]//hop_length
161 | else:
162 | assert abs(p_len-x.shape[0]//hop_length) < 4, "pad length error"
163 | time_step = hop_length / sampling_rate * 1000
164 | f0_min = 50
165 | f0_max = 1100
166 | f0 = parselmouth.Sound(x, sampling_rate).to_pitch_ac(
167 | time_step=time_step / 1000, voicing_threshold=0.6,
168 | pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
169 |
170 | pad_size=(p_len - len(f0) + 1) // 2
171 | if(pad_size>0 or p_len - len(f0) - pad_size>0):
172 | f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
173 | return f0
174 |
175 | def resize_f0(x, target_len):
176 | source = np.array(x)
177 | source[source<0.001] = np.nan
178 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
179 | res = np.nan_to_num(target)
180 | return res
181 |
182 | def compute_f0_dio(wav_numpy, p_len=None, sampling_rate=44100, hop_length=512):
183 | import pyworld
184 | if p_len is None:
185 | p_len = wav_numpy.shape[0]//hop_length
186 | f0, t = pyworld.dio(
187 | wav_numpy.astype(np.double),
188 | fs=sampling_rate,
189 | f0_ceil=800,
190 | frame_period=1000 * hop_length / sampling_rate,
191 | )
192 | f0 = pyworld.stonemask(wav_numpy.astype(np.double), f0, t, sampling_rate)
193 | for index, pitch in enumerate(f0):
194 | f0[index] = round(pitch, 1)
195 | return resize_f0(f0, p_len)
196 |
197 | def f0_to_coarse(f0):
198 | is_torch = isinstance(f0, torch.Tensor)
199 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
200 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
201 |
202 | f0_mel[f0_mel <= 1] = 1
203 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
204 |     f0_coarse = (f0_mel + 0.5).int() if is_torch else np.rint(f0_mel).astype(int)
205 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
206 | return f0_coarse
207 |
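# Illustrative sketch: f0_to_coarse quantizes Hz values onto the 256-bin
# mel-scaled grid defined by f0_bin/f0_min/f0_max at the top of this file.
# Unvoiced frames (0 Hz) fall into bin 1; voiced frames land in 1..255.
def _f0_to_coarse_example():
    f0 = torch.tensor([0.0, 110.0, 220.0, 440.0])
    coarse = f0_to_coarse(f0)
    assert coarse.min() >= 1 and coarse.max() <= 255
    return coarse
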
208 |
209 | def get_hubert_model():
210 | vec_path = "hubert/checkpoint_best_legacy_500.pt"
211 | print("load model(s) from {}".format(vec_path))
212 | from fairseq import checkpoint_utils
213 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
214 | [vec_path],
215 | suffix="",
216 | )
217 | model = models[0]
218 | model.eval()
219 | return model
220 |
221 | def get_hubert_content(hmodel, wav_16k_tensor):
222 | feats = wav_16k_tensor
223 | if feats.dim() == 2: # double channels
224 | feats = feats.mean(-1)
225 | assert feats.dim() == 1, feats.dim()
226 | feats = feats.view(1, -1)
227 | padding_mask = torch.BoolTensor(feats.shape).fill_(False)
228 | inputs = {
229 | "source": feats.to(wav_16k_tensor.device),
230 | "padding_mask": padding_mask.to(wav_16k_tensor.device),
231 | "output_layer": 12, # layer 12
232 | }
233 | with torch.no_grad():
234 | logits = hmodel.extract_features(**inputs)
235 | feats = hmodel.final_proj(logits[0])
236 | return feats.transpose(1, 2)
237 |
238 |
239 | def get_content(cmodel, y):
240 | with torch.no_grad():
241 | c = cmodel.extract_features(y.squeeze(1))[0]
242 | c = c.transpose(1, 2)
243 | return c
244 |
245 |
246 |
247 | def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
248 | assert os.path.isfile(checkpoint_path)
249 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
250 | iteration = checkpoint_dict['iteration']
251 | learning_rate = checkpoint_dict['learning_rate']
252 | if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
253 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
254 | saved_state_dict = checkpoint_dict['model']
255 | if hasattr(model, 'module'):
256 | state_dict = model.module.state_dict()
257 | else:
258 | state_dict = model.state_dict()
259 | new_state_dict = {}
260 | for k, v in state_dict.items():
261 | try:
262 | # assert "dec" in k or "disc" in k
263 | # print("load", k)
264 | new_state_dict[k] = saved_state_dict[k]
265 | assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
266 |         except Exception:
267 |             print("error, %s is missing from the checkpoint or has a mismatched shape" % k)
268 |             logger.info("%s is missing from the checkpoint or has a mismatched shape" % k)
269 | new_state_dict[k] = v
270 | if hasattr(model, 'module'):
271 | model.module.load_state_dict(new_state_dict)
272 | else:
273 | model.load_state_dict(new_state_dict)
274 | print("load ")
275 | logger.info("Loaded checkpoint '{}' (iteration {})".format(
276 | checkpoint_path, iteration))
277 | return model, optimizer, learning_rate, iteration
278 |
279 |
280 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
281 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
282 | iteration, checkpoint_path))
283 | if hasattr(model, 'module'):
284 | state_dict = model.module.state_dict()
285 | else:
286 | state_dict = model.state_dict()
287 | torch.save({'model': state_dict,
288 | 'iteration': iteration,
289 | 'optimizer': optimizer.state_dict(),
290 | 'learning_rate': learning_rate}, checkpoint_path)
291 |
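# Illustrative sketch of the round trip between save_checkpoint and
# load_checkpoint above: the saved dict holds 'model', 'iteration',
# 'optimizer' and 'learning_rate'. The tiny model and path are hypothetical.
def _checkpoint_roundtrip_example(path="logs/example/model-1.pt"):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    model = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    save_checkpoint(model, optim, learning_rate=1e-4, iteration=1, checkpoint_path=path)
    return load_checkpoint(path, model, optim)
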
292 | def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
293 | """Freeing up space by deleting saved ckpts
294 |
295 | Arguments:
296 | path_to_models -- Path to the model directory
297 |     n_ckpts_to_keep -- Number of ckpts to keep, excluding any checkpoint ending in _0.pth
298 |     sort_by_time -- True -> delete the oldest ckpts by modification time
299 |                     False -> delete the lowest-numbered ckpts by the index in the filename
300 | """
301 | ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
302 |   name_key = (lambda _f: int(re.compile(r'model-(\d+)\.pt').match(_f).group(1)))
303 | time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
304 | sort_key = time_key if sort_by_time else name_key
305 | x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
306 | to_del = [os.path.join(path_to_models, fn) for fn in
307 | (x_sorted('model')[:-n_ckpts_to_keep])]
308 | del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
309 | del_routine = lambda x: [os.remove(x), del_info(x)]
310 | rs = [del_routine(fn) for fn in to_del]
311 |
312 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
313 | for k, v in scalars.items():
314 | writer.add_scalar(k, v, global_step)
315 | for k, v in histograms.items():
316 | writer.add_histogram(k, v, global_step)
317 | for k, v in images.items():
318 | writer.add_image(k, v, global_step, dataformats='HWC')
319 | for k, v in audios.items():
320 | writer.add_audio(k, v, global_step, audio_sampling_rate)
321 |
322 |
323 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
324 | f_list = glob.glob(os.path.join(dir_path, regex))
325 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
326 | x = f_list[-1]
327 | print(x)
328 | return x
329 |
330 |
331 | def plot_spectrogram_to_numpy(spectrogram):
332 | global MATPLOTLIB_FLAG
333 | if not MATPLOTLIB_FLAG:
334 | import matplotlib
335 | matplotlib.use("Agg")
336 | MATPLOTLIB_FLAG = True
337 | mpl_logger = logging.getLogger('matplotlib')
338 | mpl_logger.setLevel(logging.WARNING)
339 | import matplotlib.pylab as plt
340 | import numpy as np
341 |
342 | fig, ax = plt.subplots(figsize=(10,2))
343 | im = ax.imshow(spectrogram, aspect="auto", origin="lower",
344 | interpolation='none')
345 | plt.colorbar(im, ax=ax)
346 | plt.xlabel("Frames")
347 | plt.ylabel("Channels")
348 | plt.tight_layout()
349 |
350 | fig.canvas.draw()
351 |   data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
352 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
353 | plt.close()
354 | return data
355 |
356 |
357 | def plot_alignment_to_numpy(alignment, info=None):
358 | global MATPLOTLIB_FLAG
359 | if not MATPLOTLIB_FLAG:
360 | import matplotlib
361 | matplotlib.use("Agg")
362 | MATPLOTLIB_FLAG = True
363 | mpl_logger = logging.getLogger('matplotlib')
364 | mpl_logger.setLevel(logging.WARNING)
365 | import matplotlib.pylab as plt
366 | import numpy as np
367 |
368 | fig, ax = plt.subplots(figsize=(6, 4))
369 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
370 | interpolation='none')
371 | fig.colorbar(im, ax=ax)
372 | xlabel = 'Decoder timestep'
373 | if info is not None:
374 | xlabel += '\n\n' + info
375 | plt.xlabel(xlabel)
376 | plt.ylabel('Encoder timestep')
377 | plt.tight_layout()
378 |
379 | fig.canvas.draw()
380 |   data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
381 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
382 | plt.close()
383 | return data
384 |
385 |
386 | def load_wav_to_torch(full_path):
387 | sampling_rate, data = read(full_path)
388 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
389 |
390 |
391 | def load_filepaths_and_text(filename, split="|"):
392 | with open(filename, encoding='utf-8') as f:
393 | filepaths_and_text = [line.strip().split(split) for line in f]
394 | return filepaths_and_text
395 |
396 |
397 | def get_hparams(init=True):
398 | parser = argparse.ArgumentParser()
399 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
400 | help='JSON file for configuration')
401 | parser.add_argument('-m', '--model', type=str, required=True,
402 | help='Model name')
403 |
404 | args = parser.parse_args()
405 | model_dir = os.path.join("./logs", args.model)
406 |
407 | if not os.path.exists(model_dir):
408 | os.makedirs(model_dir)
409 |
410 | config_path = args.config
411 | config_save_path = os.path.join(model_dir, "config.json")
412 | if init:
413 | with open(config_path, "r") as f:
414 | data = f.read()
415 | with open(config_save_path, "w") as f:
416 | f.write(data)
417 | else:
418 | with open(config_save_path, "r") as f:
419 | data = f.read()
420 | config = json.loads(data)
421 |
422 | hparams = HParams(**config)
423 | hparams.model_dir = model_dir
424 | return hparams
425 |
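# Illustrative note on the command line get_hparams expects when a training
# script calls it (script and model names here are placeholders):
#     python train.py -c config.json -m my_model
# With init=True the config file is copied to ./logs/<model>/config.json;
# with init=False the previously saved copy in that directory is read instead.
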
426 |
427 | def get_hparams_from_dir(model_dir):
428 | config_save_path = os.path.join(model_dir, "config.json")
429 | with open(config_save_path, "r") as f:
430 | data = f.read()
431 | config = json.loads(data)
432 |
433 |   hparams = HParams(**config)
434 | hparams.model_dir = model_dir
435 | return hparams
436 |
437 |
438 | def get_hparams_from_file(config_path):
439 | with open(config_path, "r") as f:
440 | data = f.read()
441 | config = json.loads(data)
442 |
443 |   hparams = HParams(**config)
444 | return hparams
445 |
446 |
447 | def check_git_hash(model_dir):
448 | source_dir = os.path.dirname(os.path.realpath(__file__))
449 | if not os.path.exists(os.path.join(source_dir, ".git")):
450 |     logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
451 | source_dir
452 | ))
453 | return
454 |
455 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
456 |
457 | path = os.path.join(model_dir, "githash")
458 | if os.path.exists(path):
459 | saved_hash = open(path).read()
460 | if saved_hash != cur_hash:
461 |       logger.warning("git hash values are different. {}(saved) != {}(current)".format(
462 | saved_hash[:8], cur_hash[:8]))
463 | else:
464 | open(path, "w").write(cur_hash)
465 |
466 |
467 | def get_logger(model_dir, filename="train.log"):
468 | global logger
469 | logger = logging.getLogger(os.path.basename(model_dir))
470 | logger.setLevel(logging.DEBUG)
471 |
472 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
473 | if not os.path.exists(model_dir):
474 | os.makedirs(model_dir)
475 | h = logging.FileHandler(os.path.join(model_dir, filename))
476 | h.setLevel(logging.DEBUG)
477 | h.setFormatter(formatter)
478 | logger.addHandler(h)
479 | return logger
480 |
481 |
482 | def repeat_expand_2d(content, target_len):
483 | # content : [h, t]
484 |
485 | src_len = content.shape[-1]
486 | target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
487 | temp = torch.arange(src_len+1) * target_len / src_len
488 | current_pos = 0
489 | for i in range(target_len):
490 | if i < temp[current_pos+1]:
491 | target[:, i] = content[:, current_pos]
492 | else:
493 | current_pos += 1
494 | target[:, i] = content[:, current_pos]
495 |
496 | return target
497 |
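# Illustrative sketch: repeat_expand_2d stretches frame-level features such as
# contentvec output of shape [channels, src_frames] to a longer target frame
# count by repeating columns. The sizes below are arbitrary.
def _repeat_expand_2d_example():
    content = torch.randn(256, 50)              # 50 source frames
    expanded = repeat_expand_2d(content, 120)   # 120 target frames
    assert expanded.shape == (256, 120)
    return expanded
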
498 |
499 | def mix_model(model_paths,mix_rate,mode):
500 | mix_rate = torch.FloatTensor(mix_rate)/100
501 | model_tem = torch.load(model_paths[0])
502 | models = [torch.load(path)["model"] for path in model_paths]
503 | if mode == 0:
504 | mix_rate = F.softmax(mix_rate,dim=0)
505 | for k in model_tem["model"].keys():
506 | model_tem["model"][k] = torch.zeros_like(model_tem["model"][k])
507 | for i,model in enumerate(models):
508 | model_tem["model"][k] += model[k]*mix_rate[i]
509 | torch.save(model_tem,os.path.join(os.path.curdir,"output.pth"))
510 | return os.path.join(os.path.curdir,"output.pth")
511 |
512 | class HParams():
513 | def __init__(self, **kwargs):
514 | for k, v in kwargs.items():
515 |             if isinstance(v, dict):
516 | v = HParams(**v)
517 | self[k] = v
518 |
519 | def keys(self):
520 | return self.__dict__.keys()
521 |
522 | def items(self):
523 | return self.__dict__.items()
524 |
525 | def values(self):
526 | return self.__dict__.values()
527 |
528 | def __len__(self):
529 | return len(self.__dict__)
530 |
531 | def __getitem__(self, key):
532 | return getattr(self, key)
533 |
534 | def __setitem__(self, key, value):
535 | return setattr(self, key, value)
536 |
537 | def __contains__(self, key):
538 | return key in self.__dict__
539 |
540 | def __repr__(self):
541 | return self.__dict__.__repr__()
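
# Illustrative sketch of how HParams wraps a (hypothetical) config dict:
# nested dicts become nested HParams, and values are reachable both as
# attributes and as items, which is how the rest of the code reads hparams.
def _hparams_example():
    hps = HParams(**{"train": {"batch_size": 16}, "sample_rate": 24000})
    assert hps.train.batch_size == 16
    assert hps["sample_rate"] == 24000
    assert "train" in hps
    return hps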
--------------------------------------------------------------------------------