├── .gitattributes
├── ComoSVC.py
├── Content
│   └── put_the_checkpoint_here
├── Features.py
├── LICENSE
├── README.md
├── Readme_CN.md
├── Vocoder.py
├── como.py
├── configs
│   └── the_config_files
├── configs_template
│   └── diffusion_template.yaml
├── data_loaders.py
├── dataset
│   └── the_prprocessed_data
├── dataset_slice
│   └── if_you_need_to_slice
├── filelists
│   └── put_the_txt_here
├── infer_tool.py
├── inference_main.py
├── logs
│   └── the_log_files
├── mel_processing.py
├── meldataset.py
├── pitch_extractor.py
├── preparation_slice.py
├── preprocessing1_resample.py
├── preprocessing2_flist.py
├── preprocessing3_feature.py
├── requirements.txt
├── saver.py
├── slicer.py
├── solver.py
├── train.py
├── utils.py
├── vocoder
│   ├── __init__.py
│   └── m4gan
│       ├── __init__.py
│       ├── hifigan.py
│       └── parallel_wavegan.py
└── wavenet.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/ComoSVC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import yaml
5 | from Vocoder import Vocoder
6 | from como import Como
7 |
8 |
9 | class DotDict(dict):
10 | def __getattr__(*args):
11 | val = dict.get(*args)
12 | return DotDict(val) if type(val) is dict else val
13 |
14 | __setattr__ = dict.__setitem__
15 | __delattr__ = dict.__delitem__
16 |
17 |
18 | def load_model_vocoder(
19 | model_path,
20 | device='cpu',
21 | config_path = None,
22 | total_steps=1,
23 | teacher=False
24 | ):
25 | if config_path is None:
26 | config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
27 | else:
28 | config_file = config_path
29 |
30 | with open(config_file, "r") as config:
31 | args = yaml.safe_load(config)
32 | args = DotDict(args)
33 |
34 | # load vocoder
35 | vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)
36 |
37 | # load model
38 | model = ComoSVC(
39 | args.data.encoder_out_channels,
40 | args.model.n_spk,
41 | args.model.use_pitch_aug,
42 | vocoder.dimension,
43 | args.model.n_layers,
44 | args.model.n_chans,
45 | args.model.n_hidden,
46 | total_steps,
47 | teacher
48 | )
49 |
50 | print(' [Loading] ' + model_path)
51 | ckpt = torch.load(model_path, map_location=torch.device(device))
52 | model.to(device)
53 | model.load_state_dict(ckpt['model'],strict=False)
54 | model.eval()
55 | return model, vocoder, args
56 |
57 |
58 | class ComoSVC(nn.Module):
59 | def __init__(
60 | self,
61 | input_channel,
62 | n_spk,
63 | use_pitch_aug=True,
64 | out_dims=128, # define in como
65 | n_layers=20,
66 | n_chans=384,
67 | n_hidden=100,
68 | total_steps=1,
69 | teacher=True
70 | ):
71 | super().__init__()
72 |
73 | self.unit_embed = nn.Linear(input_channel, n_hidden)
74 | self.f0_embed = nn.Linear(1, n_hidden)
75 | self.volume_embed = nn.Linear(1, n_hidden)
76 | self.teacher=teacher
77 |
78 | if use_pitch_aug:
79 | self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
80 | else:
81 | self.aug_shift_embed = None
82 | self.n_spk = n_spk
83 | if n_spk is not None and n_spk > 1:
84 | self.spk_embed = nn.Embedding(n_spk, n_hidden)
85 | self.n_hidden = n_hidden
86 | self.decoder = Como(out_dims, n_layers, n_chans, n_hidden, total_steps, teacher)
87 | self.input_channel = input_channel
88 |
89 | def forward(self, units, f0, volume, spk_id = None, aug_shift = None,
90 | gt_spec=None, infer=True):
91 |
92 | '''
93 | input:
94 | B x n_frames x n_unit
95 | return:
96 | dict of B x n_frames x feat
97 | '''
98 |
99 | x = self.unit_embed(units) + self.f0_embed((1+ f0 / 700).log()) + self.volume_embed(volume)
100 |
101 | if self.n_spk is not None and self.n_spk > 1:
102 | if spk_id.shape[1] > 1:
103 | g = spk_id.reshape((spk_id.shape[0], spk_id.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
104 | g = g * self.speaker_map # [N, S, B, 1, H]
105 | g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
106 | g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
107 | x = x + g
108 | else:
109 | x = x + self.spk_embed(spk_id)
110 |
111 | if self.aug_shift_embed is not None and aug_shift is not None:
112 | x = x + self.aug_shift_embed(aug_shift / 5)
113 |
114 | if not infer:
115 | output = self.decoder(gt_spec,x,infer=False)
116 | else:
117 | output = self.decoder(gt_spec,x,infer=True)
118 |
119 | return output
120 |
121 |
--------------------------------------------------------------------------------
/Content/put_the_checkpoint_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/Content/put_the_checkpoint_here
--------------------------------------------------------------------------------
/Features.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pyworld
4 | from fairseq import checkpoint_utils
5 |
6 |
7 | class SpeechEncoder(object):
8 | def __init__(self, vec_path="Content/checkpoint_best_legacy_500.pt", device=None):
9 | self.model = None # This is Model
10 | self.hidden_dim = 768
11 | pass
12 |
13 |
14 | def encoder(self, wav):
15 | """
16 | input: wav:[signal_length]
17 | output: embedding:[batchsize,hidden_dim,wav_frame]
18 | """
19 | pass
20 |
21 |
22 |
23 | class ContentVec768L12(SpeechEncoder):
24 | def __init__(self, vec_path="Content/checkpoint_best_legacy_500.pt", device=None):
25 | super().__init__()
26 | print("load model(s) from {}".format(vec_path))
27 | self.hidden_dim = 768
28 | models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
29 | [vec_path],
30 | suffix="",
31 | )
32 | if device is None:
33 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34 | else:
35 | self.dev = torch.device(device)
36 | self.model = models[0].to(self.dev)
37 | self.model.eval()
38 |
39 | def encoder(self, wav):
40 | feats = wav
41 | if feats.dim() == 2: # double channels
42 | feats = feats.mean(-1)
43 | assert feats.dim() == 1, feats.dim()
44 | feats = feats.view(1, -1)
45 | padding_mask = torch.BoolTensor(feats.shape).fill_(False)
46 | inputs = {
47 | "source": feats.to(wav.device),
48 | "padding_mask": padding_mask.to(wav.device),
49 | "output_layer": 12, # layer 12
50 | }
51 | with torch.no_grad():
52 | logits = self.model.extract_features(**inputs)
53 | return logits[0].transpose(1, 2)
54 |
55 |
56 |
57 | class F0Predictor(object):
58 | def compute_f0(self,wav,p_len):
59 | '''
60 | input: wav:[signal_length]
61 | p_len:int
62 | output: f0:[signal_length//hop_length]
63 | '''
64 | pass
65 |
66 | def compute_f0_uv(self,wav,p_len):
67 | '''
68 | input: wav:[signal_length]
69 | p_len:int
70 | output: f0:[signal_length//hop_length],uv:[signal_length//hop_length]
71 | '''
72 | pass
73 |
74 |
75 | class DioF0Predictor(F0Predictor):
76 | def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
77 | self.hop_length = hop_length
78 | self.f0_min = f0_min
79 | self.f0_max = f0_max
80 | self.sampling_rate = sampling_rate
81 | self.name = "dio"
82 |
83 | def interpolate_f0(self,f0):
84 | '''
 85 |         Interpolate the F0 contour over unvoiced regions and return a voiced/unvoiced mask.
86 | '''
87 | vuv_vector = np.zeros_like(f0, dtype=np.float32)
88 | vuv_vector[f0 > 0.0] = 1.0
89 | vuv_vector[f0 <= 0.0] = 0.0
90 |
91 | nzindex = np.nonzero(f0)[0]
92 | data = f0[nzindex]
93 | nzindex = nzindex.astype(np.float32)
94 | time_org = self.hop_length / self.sampling_rate * nzindex
95 | time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
96 |
97 | if data.shape[0] <= 0:
98 | return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
99 |
100 | if data.shape[0] == 1:
101 | return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
102 |
103 | f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
104 |
105 | return f0,vuv_vector
106 |
107 | def resize_f0(self,x, target_len):
108 | source = np.array(x)
109 | source[source<0.001] = np.nan
110 | target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
111 | res = np.nan_to_num(target)
112 | return res
113 |
114 | def compute_f0(self,wav,p_len=None):
115 | if p_len is None:
116 | p_len = wav.shape[0]//self.hop_length
117 | f0, t = pyworld.dio(
118 | wav.astype(np.double),
119 | fs=self.sampling_rate,
120 | f0_floor=self.f0_min,
121 | f0_ceil=self.f0_max,
122 | frame_period=1000 * self.hop_length / self.sampling_rate,
123 | )
124 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
125 | for index, pitch in enumerate(f0):
126 | f0[index] = round(pitch, 1)
127 | return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
128 |
129 | def compute_f0_uv(self,wav,p_len=None):
130 | if p_len is None:
131 | p_len = wav.shape[0]//self.hop_length
132 | f0, t = pyworld.dio(
133 | wav.astype(np.double),
134 | fs=self.sampling_rate,
135 | f0_floor=self.f0_min,
136 | f0_ceil=self.f0_max,
137 | frame_period=1000 * self.hop_length / self.sampling_rate,
138 | )
139 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
140 | for index, pitch in enumerate(f0):
141 | f0[index] = round(pitch, 1)
142 | return self.interpolate_f0(self.resize_f0(f0, p_len))
143 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yiwen LU
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
  2 | # CoMoSVC: Consistency Model Based Singing Voice Conversion
  3 | 
4 | [中文文档](./Readme_CN.md)
5 |
6 |
  7 | A consistency-model-based singing voice conversion system, inspired by [CoMoSpeech](https://github.com/zhenye234/CoMoSpeech): One-Step Speech and Singing Voice Synthesis via Consistency Model.
8 |
  9 | This is an implementation of the paper [CoMoSVC](https://arxiv.org/pdf/2401.01792.pdf).
10 | ## Improvements
 11 | The subjective evaluation results illustrating these improvements are reported in the [CoMoSVC paper](https://arxiv.org/pdf/2401.01792.pdf).
12 |
13 |
14 | ## Environment
 15 | The code has been tested on Python 3.8, so you can set up your Conda environment with the following command:
16 |
17 | ```shell
18 | conda create -n Your_Conda_Environment_Name python=3.8
19 | ```
 20 | Then, after activating your Conda environment, install the required packages with:
21 |
22 | ```shell
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ## Download the Checkpoints
27 | ### 1. m4singer_hifigan
28 |
29 | You should first download [m4singer_hifigan](https://drive.google.com/file/d/10LD3sq_zmAibl379yTW5M-LXy2l_xk6h/view) and then unzip the zip file by
30 | ```shell
31 | unzip m4singer_hifigan.zip
32 | ```
 33 | The vocoder checkpoints will then be in the `m4singer_hifigan` directory.
34 |
35 | ### 2. ContentVec
 36 | You should download the [ContentVec](https://ibm.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr) checkpoint and put it in the `Content` directory; it is used to extract the content features.
37 |
38 | ### 3. m4singer_pe
 39 | You should download the pitch extractor checkpoint [m4singer_pe](https://drive.google.com/file/d/19QtXNeqUjY3AjvVycEt3G83lXn2HwbaJ/view) and then unzip it by
40 |
41 | ```shell
42 | unzip m4singer_pe.zip
43 | ```
44 |
45 | ## Dataset Preparation
46 |
47 | You should first create the folders by
48 |
49 | ```shell
50 | mkdir dataset_raw
51 | mkdir dataset
52 | ```
 53 | Choose the preparation method that suits your data.
 54 | 
 55 | Preparation with slicing removes the silent parts and cuts the audio into slices for more stable training.
56 |
57 |
58 | ### 0. Preparation With Slicing
59 |
60 | Please place your original dataset in the `dataset_slice` directory.
61 |
 62 | The original audio can be in any format, which should be specified on the command line. You can designate the length of the slices you want; the unit of slice_size is milliseconds. The default format and slice_size are mp3 and 10000, respectively.
63 |
64 | ```shell
65 | python preparation_slice.py -w your_wavformat -s slice_size
66 | ```
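For example, keeping the documented defaults (mp3 input, 10-second slices), the call would be `python preparation_slice.py -w mp3 -s 10000`.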
67 |
68 | ### 1. Preparation Without Slicing
69 |
70 | You can just place the dataset in the `dataset_raw` directory with the following file structure:
71 |
72 | ```
73 | dataset_raw
74 | ├───speaker0
75 | │ ├───xxx1-xxx1.wav
76 | │ ├───...
77 | │ └───Lxx-0xx8.wav
78 | └───speaker1
79 | ├───xx2-0xxx2.wav
80 | ├───...
81 | └───xxx7-xxx007.wav
82 | ```
83 |
84 |
85 | ## Preprocessing
86 |
87 | ### 1. Resample to 24000Hz and mono
88 |
89 | ```shell
90 | python preprocessing1_resample.py -n num_process
91 | ```
 92 | num_process is the number of processes; the default is 5.
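For example, `python preprocessing1_resample.py -n 5` runs the resampling with the default five processes.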
93 |
94 | ### 2. Split the Training and Validation Datasets, and Generate Configuration Files.
95 |
96 | ```shell
97 | python preprocessing2_flist.py
98 | ```
99 |
100 |
101 | ### 3. Generate Features
102 |
103 | ```shell
104 | python preprocessing3_feature.py -c your_config_file -n num_processes
105 | ```
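For example, assuming the configuration file generated in step 2 is `configs/diffusion.yaml` (the actual path may differ in your setup), you could run `python preprocessing3_feature.py -c configs/diffusion.yaml -n 5`.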
106 |
107 |
108 | ## Training
109 |
110 | ### 1. Train the Teacher Model
111 |
112 | ```shell
113 | python train.py
114 | ```
115 | The checkpoints will be saved in the `logs/teacher` directory
116 |
117 | ### 2. Train the Consistency Model
118 |
119 | If you want to adjust the configuration, duplicate the config file and modify the parameters as needed.
120 |
121 |
122 | ```shell
123 | python train.py -t -c Your_new_configfile_path -p The_teacher_model_checkpoint_path
124 | ```
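For example, assuming you copied the generated configuration to a new file `configs/como.yaml` (a hypothetical path) and the teacher checkpoint is at `logs/teacher/model_800000.pt`, the command would be `python train.py -t -c configs/como.yaml -p logs/teacher/model_800000.pt`.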
125 |
126 | ## Inference
127 | First, put the audio files you want to convert in the `raw` directory.
128 |
129 | ### Inference by the Teacher Model
130 |
131 | ```shell
132 | python inference_main.py -ts 50 -tm "logs/teacher/model_800000.pt" -tc "logs/teacher/config.yaml" -n "src.wav" -k 0 -s "target_singer"
133 | ```
134 | -ts refers to the total number of iterative steps during inference for the teacher model
135 |
136 | -tm refers to the teacher_model_path
137 |
138 | -tc refers to the teacher_config_path
139 |
140 | -n refers to the source audio
141 |
142 | -k refers to the pitch shift in semitones; it can be positive or negative
143 |
144 | -s refers to the target singer
145 |
146 | ### Inference by the Consistency Model
147 |
148 | ```shell
149 | python inference_main.py -ts 1 -cm "logs/como/model_800000.pt" -cc "logs/como/config.yaml" -n "src.wav" -k 0 -s "target_singer" -t
150 | ```
151 | -ts refers to the total number of iterative steps during inference for the student model
152 |
153 | -cm refers to the como_model_path
154 |
155 | -cc refers to the como_config_path
156 |
157 | -t indicates that the consistency model (not the teacher) is used; the flag takes no value
158 |
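For scripted use, the same inference pipeline can also be driven from Python. The following is a minimal sketch built on `infer_tool.Svc` and its `slice_inference` method; the checkpoint, config, audio, and singer names are placeholders taken from the examples above, so adjust them to your own run:

```python
import os

import soundfile

from infer_tool import Svc

# Teacher model with multi-step sampling (e.g. 50 steps); for the consistency
# (student) model, pass teacher=False and total_steps=1 instead.
svc_model = Svc("logs/teacher/model_800000.pt",   # checkpoint path (placeholder)
                "logs/teacher/config.yaml",       # config path (placeholder)
                total_steps=50,
                teacher=True)

# Convert raw/src.wav into the voice of "target_singer" without pitch shift.
audio = svc_model.slice_inference(raw_audio_path="raw/src.wav",
                                  spk="target_singer",
                                  tran=0)

os.makedirs("result_teacher", exist_ok=True)
soundfile.write("result_teacher/src_target_singer.wav", audio,
                svc_model.target_sample, format="wav")
```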
--------------------------------------------------------------------------------
/Readme_CN.md:
--------------------------------------------------------------------------------
1 |
  2 | # CoMoSVC: One-Step Consistency Model Based Singing Voice Conversion
  3 | 
4 |
  5 | A consistency-model-based singing voice conversion and cloning system that can convert singing voices with one-step diffusion sampling. It is an implementation of the paper [CoMoSVC](https://arxiv.org/pdf/2401.01792.pdf) and builds on [CoMoSpeech](https://github.com/zhenye234/CoMoSpeech): One-Step Speech and Singing Voice Synthesis via Consistency Model.
6 |
7 |
8 |
  9 | # Environment Setup
 10 | Create a Conda virtual environment with Python 3.8:
11 |
12 | ```shell
13 | conda create -n Your_Conda_Environment_Name python=3.8
14 | ```
 15 | Install the required dependencies:
16 |
17 | ```shell
18 | pip install -r requirements.txt
19 | ```
 20 | ## Download the Checkpoints
 21 | ### 1. m4singer_hifigan
 22 | Download the vocoder [m4singer_hifigan](https://drive.google.com/file/d/10LD3sq_zmAibl379yTW5M-LXy2l_xk6h/view) and unzip it:
23 |
24 | ```shell
25 | unzip m4singer_hifigan.zip
26 | ```
27 |
 28 | The vocoder checkpoint will then be in the `m4singer_hifigan` directory.
29 |
30 | ### 2. ContentVec
31 |
 32 | Download [ContentVec](https://ibm.box.com/s/z1wgl1stco8ffooyatzdwsqn2psd9lrr) and place it in the `Content` directory; it is used to extract the content (lyrics) features.
 33 | 
 34 | ### 3. m4singer_pe
 35 | 
 36 | Download the pitch extractor [m4singer_pe](https://drive.google.com/file/d/19QtXNeqUjY3AjvVycEt3G83lXn2HwbaJ/view) and unzip it in the root directory:
37 |
38 | ```shell
39 | unzip m4singer_pe.zip
40 | ```
41 |
42 |
 43 | ## Dataset Preparation
 44 | 
 45 | 
 46 | Create two empty folders:
47 |
48 | ```shell
49 | mkdir dataset_raw
50 | mkdir dataset
51 | ```
52 |
 53 | Prepare unaccompanied (a cappella) recordings of the singer yourself, then follow the steps below.
 54 | 
 55 | ### 0. Preparation With Slicing
 56 | 
 57 | Place your original dataset in the `dataset_slice` directory.
 58 | 
 59 | The original audio can be in any format, which should be specified on the command line. You can designate the length of the slices you want; the unit of slice_size is milliseconds. The default format and slice_size are mp3 and 10000, respectively.
60 |
61 | ```shell
 62 | python preparation_slice.py -w your_wavformat -s slice_size
63 | ```
64 |
 65 | ### 1. Preparation Without Slicing
 66 | 
 67 | You can simply place the dataset in the `dataset_raw` directory with the following file structure:
68 |
69 |
70 | ```
71 | dataset_raw
72 | ├───speaker0
73 | │ ├───xxx1-xxx1.wav
74 | │ ├───...
75 | │ └───Lxx-0xx8.wav
76 | └───speaker1
77 | ├───xx2-0xxx2.wav
78 | ├───...
79 | └───xxx7-xxx007.wav
80 | ```
81 |
82 |
 83 | ## Preprocessing
 84 | 
 85 | ### 1. Resample to 24000 Hz and Mono
86 |
87 | ```shell
88 | python preprocessing1_resample.py -n num_process
89 | ```
 90 | num_process is the number of processes; the default is 5.
 91 | 
 92 | 
 93 | ### 2. Split the Training and Validation Datasets, and Generate Configuration Files
94 |
95 | ```shell
96 | python preprocessing2_flist.py
97 | ```
98 |
99 |
100 | ### 3. Generate Features
101 | 
102 | 
103 | 
104 | 
105 | Run the following command to extract all the features:
106 |
107 | ```shell
108 | python preprocessing3_feature.py -c your_config_file -n num_processes
109 | ```
110 |
111 |
112 |
113 |
114 | ## Training
115 | 
116 | ### 1. Train the Teacher Model
117 |
118 | ```shell
119 | python train.py
120 | ```
121 | The checkpoints will be saved in the `logs/teacher` directory.
122 | 
123 | ### 2. Train the Consistency Model
124 | 
125 | If you want to adjust the configuration, duplicate the config file and modify the parameters as needed.
126 |
127 |
128 |
129 | ```shell
130 | python train.py -t -c Your_new_configfile_path -p The_teacher_model_checkpoint_path
131 | ```
132 |
133 | ## Inference
134 | 
135 | First, put the audio files you want to convert in the `raw` directory.
136 | 
137 | ### Inference with the Teacher Model
138 |
139 | ```shell
140 | python inference_main.py -ts 50 -tm "logs/teacher/model_800000.pt" -tc "logs/teacher/config.yaml" -n "src.wav" -k 0 -s "target_singer"
141 | ```
142 | -ts refers to the number of iterative sampling steps for the teacher model during inference
143 | 
144 | -tm refers to the teacher model path
145 | 
146 | -tc refers to the teacher model config file
147 | 
148 | -n refers to the path of the source audio
149 | 
150 | -k refers to the pitch shift; it can be a positive or negative semitone value
151 | 
152 | -s refers to the target singer
153 |
154 |
155 |
156 | ### Inference with CoMoSVC
157 |
158 | ```shell
159 | python inference_main.py -ts 1 -cm "logs/como/model_800000.pt" -cc "logs/como/config.yaml" -n "src.wav" -k 0 -s "target_singer" -t
160 | ```
161 | -ts refers to the number of iterative sampling steps for the student model during inference
162 | 
163 | -cm refers to the CoMoSVC model path
164 | 
165 | -cc refers to the CoMoSVC model config file
166 | 
167 | -t indicates that the consistency model (not the teacher) is used; the flag takes no value
168 |
--------------------------------------------------------------------------------
/Vocoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchaudio.transforms import Resample
4 | from vocoder.m4gan.hifigan import HifiGanGenerator
5 |
6 | from mel_processing import mel_spectrogram, MAX_WAV_VALUE
7 | import utils
8 |
9 | class Vocoder:
10 | def __init__(self, vocoder_type, vocoder_ckpt, device = None):
11 | if device is None:
12 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
13 | self.device = device # device
14 | self.vocodertype = vocoder_type
15 | if vocoder_type == 'm4-gan':
16 | self.vocoder = M4GAN(vocoder_ckpt, device = device)
17 | else:
18 | raise ValueError(f" [x] Unknown vocoder: {vocoder_type}")
19 |
20 | self.resample_kernel = {}
21 | self.vocoder_sample_rate = self.vocoder.sample_rate()
22 | self.vocoder_hop_size = self.vocoder.hop_size()
23 | self.dimension = self.vocoder.dimension()
24 |
25 | def extract(self, audio, sample_rate, keyshift=0):
26 | # resample
27 | if sample_rate == self.vocoder_sample_rate:
28 | audio_res = audio
29 | else:
 30 |             key_str = str(sample_rate)  # the source sample rate (24 kHz here)
31 | if key_str not in self.resample_kernel:
32 | self.resample_kernel[key_str] = Resample(sample_rate, self.vocoder_sample_rate, lowpass_filter_width = 128).to(self.device)
33 |
 34 |             audio_res = self.resample_kernel[key_str](audio)  # resample the original audio
35 |
36 | # extract
37 | mel = self.vocoder.extract(audio_res, keyshift=keyshift) # B, n_frames, bins
38 | return mel
39 |
40 | def infer(self, mel, f0):
41 | f0 = f0[:,:mel.size(1),0]
42 | audio = self.vocoder(mel,f0)
43 | return audio
44 |
45 |
46 | class M4GAN(torch.nn.Module):
47 | def __init__(self, model_path, device=None):
48 | super().__init__()
49 | if device is None:
50 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
51 | self.device = device
52 | self.model_path = model_path
53 | self.model = None
54 | self.h = utils.load_config(os.path.join(os.path.split(model_path)[0], 'config.yaml'))
55 |
56 | def sample_rate(self):
57 | return self.h.audio_sample_rate
58 |
59 | def hop_size(self):
60 | return self.h.hop_size
61 |
62 | def dimension(self):
63 | return self.h.audio_num_mel_bins
64 |
65 | def extract(self, audio, keyshift=0):
66 |
67 | mel= mel_spectrogram(audio, self.h.fft_size, self.h.audio_num_mel_bins, self.h.audio_sample_rate, self.h.hop_size, self.h.win_size, self.h.fmin, self.h.fmax, keyshift=keyshift).transpose(1,2)
68 | # mel= mel_spectrogram(audio, 512, 80, 24000, 128, 512, 30, 12000, keyshift=keyshift).transpose(1,2)
69 | return mel
70 |
71 | def load_checkpoint(self, filepath, device):
72 | assert os.path.isfile(filepath)
73 | print("Loading '{}'".format(filepath))
74 | checkpoint_dict = torch.load(filepath, map_location=device)
75 | print("Complete.")
76 | return checkpoint_dict
77 |
78 |
79 | def forward(self, mel, f0):
80 | ckpt_dict = torch.load(self.model_path, map_location=self.device)
81 | state = ckpt_dict["state_dict"]["model_gen"]
82 | self.model = HifiGanGenerator(self.h).to(self.device)
83 | self.model.load_state_dict(state, strict=True)
84 | self.model.remove_weight_norm()
85 | self.model = self.model.eval()
86 | c = mel.transpose(2, 1)
87 | y = self.model(c,f0).view(-1)
88 |
89 | return y[None]
90 |
--------------------------------------------------------------------------------
/como.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import copy
4 | from pitch_extractor import PitchExtractor
5 | from wavenet import WaveNet
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class BaseModule(torch.nn.Module):
12 | def __init__(self):
13 | super(BaseModule, self).__init__()
14 |
15 | @property
16 | def nparams(self):
17 | """
18 | Returns number of trainable parameters of the module.
19 | """
20 | num_params = 0
21 | for name, param in self.named_parameters():
22 | if param.requires_grad:
23 | num_params += np.prod(param.detach().cpu().numpy().shape)
24 | return num_params
25 |
26 |
27 | def relocate_input(self, x: list):
28 | """
29 | Relocates provided tensors to the same device set for the module.
30 | """
31 | device = next(self.parameters()).device
32 | for i in range(len(x)):
33 | if isinstance(x[i], torch.Tensor) and x[i].device != device:
34 | x[i] = x[i].to(device)
35 | return x
36 |
37 |
38 | class Como(BaseModule):
39 | def __init__(self, out_dims, n_layers, n_chans, n_hidden,total_steps,teacher = True ):
40 | super().__init__()
41 | self.denoise_fn = WaveNet(out_dims, n_layers, n_chans, n_hidden)
42 | self.pe= PitchExtractor()
43 | self.teacher = teacher
44 | if not teacher:
45 | self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
46 | self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
47 |
48 | self.P_mean =-1.2
49 | self.P_std =1.2
50 | self.sigma_data =0.5
51 |
52 | self.sigma_min= 0.002
53 | self.sigma_max= 80
54 | self.rho=7
55 | self.N = 25
56 | self.total_steps=total_steps
57 | self.spec_min=-6
58 | self.spec_max=1.5
59 | step_indices = torch.arange(self.N)
60 | t_steps = (self.sigma_min ** (1 / self.rho) + step_indices / (self.N - 1) * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))) ** self.rho
 61 |         self.t_steps = torch.cat([torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)])  # round_sigma converts the schedule to a tensor
62 |
63 | def norm_spec(self, x):
64 | return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
65 |
66 | def denorm_spec(self, x):
67 | return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
68 |
69 | def EDMPrecond(self, x, sigma ,cond,denoise_fn):
70 | sigma = sigma.reshape(-1, 1, 1 )
71 | c_skip = self.sigma_data ** 2 / ((sigma-self.sigma_min) ** 2 + self.sigma_data ** 2)
72 | c_out = (sigma-self.sigma_min) * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
73 | c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
74 | c_noise = sigma.log() / 4
75 | F_x = denoise_fn((c_in * x), c_noise.flatten(),cond)
76 | D_x = c_skip * x + c_out * (F_x .squeeze(1) )
77 | return D_x
78 |
79 | def EDMLoss(self, x_start, cond):
80 | rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
81 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
82 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
83 | n = (torch.randn_like(x_start) ) * sigma # Generate Gaussian Noise
84 | D_yn = self.EDMPrecond(x_start + n, sigma ,cond,self.denoise_fn) # After Denoising
85 | loss = (weight * ((D_yn - x_start) ** 2))
86 | loss=loss.unsqueeze(1).unsqueeze(1)
87 | loss=loss.mean()
88 | return loss
89 |
90 | def round_sigma(self, sigma):
91 | return torch.as_tensor(sigma)
92 |
93 | def edm_sampler(self,latents, cond,num_steps=50, sigma_min=0.002, sigma_max=80, rho=7, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1):
94 | # Time step discretization.
95 | step_indices = torch.arange(num_steps, device=latents.device)
96 |
97 | num_steps=num_steps + 1
98 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
99 | t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
100 | # Main sampling loop.
101 | x_next = latents * t_steps[0]
102 | x_next = x_next.transpose(1,2)
103 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
104 | x_cur = x_next
105 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
106 | t_hat = self.round_sigma(t_cur + gamma * t_cur)
107 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)
108 | denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn) # mel,sigma,cond
109 | d_cur = (x_hat - denoised) / t_hat # 7th step
110 | x_next = x_hat + (t_next - t_hat) * d_cur
111 |
112 | return x_next
113 |
114 | def CTLoss_D(self,y, cond): # y is the gt_spec
115 | with torch.no_grad():
116 | mu = 0.95
117 | for p, ema_p in zip(self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()):
118 | ema_p.mul_(mu).add_(p, alpha=1 - mu)
119 | n = torch.randint(1, self.N, (y.shape[0],))
120 |
121 | z = torch.randn_like(y) # Gaussian Noise
122 | tn_1 = self.c_t_d(n + 1).reshape(-1, 1, 1).to(y.device)
123 | f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn)
124 |
125 | with torch.no_grad():
126 | tn = self.c_t_d(n ).reshape(-1, 1, 1).to(y.device)
127 | #euler step
128 | x_hat = y + tn_1 * z
129 | denoised = self.EDMPrecond(x_hat, tn_1 , cond,self.denoise_fn_pretrained)
130 | d_cur = (x_hat - denoised) / tn_1
131 | y_tn = x_hat + (tn - tn_1) * d_cur
132 | f_theta_ema = self.EDMPrecond( y_tn, tn,cond, self.denoise_fn_ema)
133 |
134 |         loss = (f_theta - f_theta_ema.detach()) ** 2  # for the consistency model, lambda = 1
135 | loss=loss.unsqueeze(1).unsqueeze(1)
136 | loss=loss.mean()
137 |
138 | return loss
139 |
140 | def c_t_d(self, i ):
141 | return self.t_steps[i]
142 |
143 | def get_t_steps(self,N):
144 | N=N+1
145 | step_indices = torch.arange( N ) #, device=latents.device)
146 | t_steps = (self.sigma_min ** (1 / self.rho) + step_indices / (N- 1) * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))) ** self.rho
147 |
148 | return t_steps.flip(0)# FLIP t_step
149 |
150 | def CT_sampler(self, latents, cond, t_steps=1):
151 | if t_steps ==1:
152 | t_steps=[80]
153 | else:
154 | t_steps=self.get_t_steps(t_steps)
155 | t_steps = torch.as_tensor(t_steps).to(latents.device)
156 | latents = latents * t_steps[0]
157 | latents = latents.transpose(1,2)
158 | x = self.EDMPrecond(latents, t_steps[0],cond,self.denoise_fn)
159 | for t in t_steps[1:-1]: # N-1 to 1
160 | z = torch.randn_like(x)
161 | x_tn = x + (t ** 2 - self.sigma_min ** 2).sqrt()*z
162 | x = self.EDMPrecond(x_tn, t,cond,self.denoise_fn)
163 | return x
164 |
165 | def forward(self, x, cond, infer=False):
166 |
167 | if self.teacher: # teacher model
168 | if not infer: # training
169 | x=self.norm_spec(x)
170 | loss = self.EDMLoss(x, cond)
171 | return loss
172 | else: # infer
173 | shape = (cond.shape[0], 80, cond.shape[1])
174 | x = torch.randn(shape, device=cond.device)
175 | x=self.edm_sampler(x, cond, self.total_steps)
176 | return self.denorm_spec(x)
177 | else: #Consistency distillation
178 | if not infer: # training
179 | x=self.norm_spec(x)
180 | loss = self.CTLoss_D(x, cond)
181 | return loss
182 | else: # infer
183 | shape = (cond.shape[0], 80, cond.shape[1])
184 | x = torch.randn(shape, device=cond.device) # The Input is the Random Noise
185 | x=self.CT_sampler(x,cond,self.total_steps)
186 | return self.denorm_spec(x)
187 |
188 |
189 |
--------------------------------------------------------------------------------
/configs/the_config_files:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/configs/the_config_files
--------------------------------------------------------------------------------
/configs_template/diffusion_template.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | sampling_rate: 24000
3 | hop_length: 128
4 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip
5 | filter_length: 512
6 | win_length: 512
7 | encoder: 'vec768l12' #
8 | cnhubertsoft_gate: 10
9 | encoder_sample_rate: 16000
10 | encoder_hop_size: 320
11 | encoder_out_channels: 768 #
12 | training_files: "filelists/train.txt"
13 | validation_files: "filelists/val.txt"
14 | extensions: # List of extension included in the data collection
15 | - wav
16 | unit_interpolate_mode: "nearest"
17 | model:
18 | type: 'Diffusion'
19 | n_layers: 20
20 | n_chans: 512
21 | n_hidden: 256
22 | use_pitch_aug: true
23 | n_spk: 1 # max number of different speakers
24 | device: cuda
25 | vocoder:
26 | type: 'm4-gan'
27 | ckpt: 'm4singer_hifigan/model_ckpt_steps_1970000.ckpt'
28 | infer:
29 | method: 'dpm-solver++' #
30 | env:
31 | comodir: logs/como
32 | expdir: logs/teacher
33 | gpu_id: 0
34 | train:
 35 |   num_workers: 4 # If both your CPU and GPU are very strong, setting this to 0 may be faster!
36 | amp_dtype: fp32 # fp32, fp16 or bf16 (fp16 or bf16 may be faster if it is supported by your gpu)
37 | batch_size: 48
 38 |   cache_all_data: true # Saves RAM or VRAM if set to false, but may be slower
39 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu
40 | cache_fp16: true
41 | epochs: 100000
42 | interval_log: 10
43 | interval_val: 2000
44 | interval_force_save: 2000
45 | lr: 0.0001
46 | comolr: 0.00005
47 | decay_step: 100000
48 | gamma: 0.5
49 | weight_decay: 0
50 | save_opt: false
51 | spk:
52 | 'SPEAKER1': 0
--------------------------------------------------------------------------------
/data_loaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import librosa
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | from tqdm import tqdm
9 |
10 | from utils import repeat_expand_2d
11 |
12 |
13 | def traverse_dir(
14 | root_dir,
15 | extensions,
16 | amount=None,
17 | str_include=None,
18 | str_exclude=None,
19 | is_pure=False,
20 | is_sort=False,
21 | is_ext=True):
22 |
23 | file_list = []
24 | cnt = 0
25 | for root, _, files in os.walk(root_dir):
26 | for file in files:
27 | if any([file.endswith(f".{ext}") for ext in extensions]):
28 | # path
29 | mix_path = os.path.join(root, file)
30 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
31 |
32 | # amount
33 | if (amount is not None) and (cnt == amount):
34 | if is_sort:
35 | file_list.sort()
36 | return file_list
37 |
38 | # check string
39 | if (str_include is not None) and (str_include not in pure_path):
40 | continue
41 | if (str_exclude is not None) and (str_exclude in pure_path):
42 | continue
43 |
44 | if not is_ext:
45 | ext = pure_path.split('.')[-1]
46 | pure_path = pure_path[:-(len(ext)+1)]
47 | file_list.append(pure_path)
48 | cnt += 1
49 | if is_sort:
50 | file_list.sort()
51 | return file_list
52 |
53 |
54 | def get_data_loaders(args, whole_audio=False):
55 | data_train = AudioDataset(
56 | filelists = args.data.training_files,
57 | waveform_sec=args.data.duration,
58 | hop_size=args.data.hop_length,
59 | sample_rate=args.data.sampling_rate,
60 | load_all_data=args.train.cache_all_data,
61 | whole_audio=whole_audio,
62 | extensions=args.data.extensions,
63 | n_spk=args.model.n_spk,
64 | spk=args.spk,
65 | device=args.train.cache_device,
66 | fp16=args.train.cache_fp16,
67 | unit_interpolate_mode = args.data.unit_interpolate_mode,
68 | use_aug=True)
69 | loader_train = torch.utils.data.DataLoader(
70 | data_train ,
71 | batch_size=args.train.batch_size if not whole_audio else 1,
72 | shuffle=True,
73 | num_workers=args.train.num_workers if args.train.cache_device=='cpu' else 0,
74 | persistent_workers=(args.train.num_workers > 0) if args.train.cache_device=='cpu' else False,
75 | pin_memory=True if args.train.cache_device=='cpu' else False
76 | )
77 | data_valid = AudioDataset(
78 | filelists = args.data.validation_files,
79 | waveform_sec=args.data.duration,
80 | hop_size=args.data.hop_length,
81 | sample_rate=args.data.sampling_rate,
82 | load_all_data=args.train.cache_all_data,
83 | whole_audio=True,
84 | spk=args.spk,
85 | extensions=args.data.extensions,
86 | unit_interpolate_mode = args.data.unit_interpolate_mode,
87 | n_spk=args.model.n_spk)
88 | loader_valid = torch.utils.data.DataLoader(
89 | data_valid,
90 | batch_size=1,
91 | shuffle=False,
92 | num_workers=0,
93 | pin_memory=True
94 | )
95 | return loader_train, loader_valid
96 |
97 |
98 | class AudioDataset(Dataset):
99 | def __init__(
100 | self,
101 | filelists,
102 | waveform_sec,
103 | hop_size,
104 | sample_rate,
105 | spk,
106 | load_all_data=True,
107 | whole_audio=False,
108 | extensions=['wav'],
109 | n_spk=1,
110 | device='cpu',
111 | fp16=False,
112 | use_aug=False,
113 | unit_interpolate_mode = 'left'
114 | ):
115 | super().__init__()
116 |
117 | self.waveform_sec = waveform_sec
118 | self.sample_rate = sample_rate
119 | self.hop_size = hop_size
120 | self.filelists = filelists
121 | self.whole_audio = whole_audio
122 | self.use_aug = use_aug
123 | self.data_buffer={}
124 | self.pitch_aug_dict = {}
125 | self.unit_interpolate_mode = unit_interpolate_mode
126 | # np.load(os.path.join(self.path_root, 'pitch_aug_dict.npy'), allow_pickle=True).item()
127 | if load_all_data:
128 | print('Load all the data filelists:', filelists)
129 | else:
130 | print('Load the f0, volume data filelists:', filelists)
131 | with open(filelists,"r") as f:
132 | self.paths = f.read().splitlines()
133 | for name_ext in tqdm(self.paths, total=len(self.paths)):
134 | path_audio = name_ext
135 | duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)
136 |
137 | path_f0 = name_ext + ".f0.npy"
138 | f0,_ = np.load(path_f0,allow_pickle=True)
139 | f0 = torch.from_numpy(np.array(f0,dtype=float)).float().unsqueeze(-1).to(device)
140 |
141 | path_volume = name_ext + ".vol.npy"
142 | volume = np.load(path_volume)
143 | volume = torch.from_numpy(volume).float().unsqueeze(-1).to(device)
144 |
145 | path_augvol = name_ext + ".aug_vol.npy"
146 | aug_vol = np.load(path_augvol)
147 | aug_vol = torch.from_numpy(aug_vol).float().unsqueeze(-1).to(device)
148 |
149 | if n_spk is not None and n_spk > 1:
150 | spk_name = name_ext.split("/")[-2]
151 | spk_id = spk[spk_name] if spk_name in spk else 0
152 | if spk_id < 0 or spk_id >= n_spk:
153 |                     raise ValueError(' [x] Multi-speaker training error: spk_id must be an integer from 0 to n_spk-1 ')
154 | else:
155 | spk_id = 0
156 | spk_id = torch.LongTensor(np.array([spk_id])).to(device)
157 |
158 | if load_all_data:
159 | '''
160 | audio, sr = librosa.load(path_audio, sr=self.sample_rate)
161 | if len(audio.shape) > 1:
162 | audio = librosa.to_mono(audio)
163 | audio = torch.from_numpy(audio).to(device)
164 | '''
165 | path_mel = name_ext + ".mel.npy"
166 | mel = np.load(path_mel)
167 | mel = torch.from_numpy(mel).to(device)
168 |
169 | path_augmel = name_ext + ".aug_mel.npy"
170 | aug_mel,keyshift = np.load(path_augmel, allow_pickle=True)
171 | aug_mel = np.array(aug_mel,dtype=float)
172 | aug_mel = torch.from_numpy(aug_mel).to(device)
173 | self.pitch_aug_dict[name_ext] = keyshift
174 |
175 | path_units = name_ext + ".soft.pt"
176 | units = torch.load(path_units).to(device)
177 | units = units[0]
178 | units = repeat_expand_2d(units,f0.size(0),unit_interpolate_mode).transpose(0,1)
179 |
180 | if fp16:
181 | mel = mel.half()
182 | aug_mel = aug_mel.half()
183 | units = units.half()
184 |
185 | self.data_buffer[name_ext] = {
186 | 'duration': duration,
187 | 'mel': mel,
188 | 'aug_mel': aug_mel,
189 | 'units': units,
190 | 'f0': f0,
191 | 'volume': volume,
192 | 'aug_vol': aug_vol,
193 | 'spk_id': spk_id
194 | }
195 | else:
196 | path_augmel = name_ext + ".aug_mel.npy"
197 | aug_mel,keyshift = np.load(path_augmel, allow_pickle=True)
198 | self.pitch_aug_dict[name_ext] = keyshift
199 | self.data_buffer[name_ext] = {
200 | 'duration': duration,
201 | 'f0': f0,
202 | 'volume': volume,
203 | 'aug_vol': aug_vol,
204 | 'spk_id': spk_id
205 | }
206 |
207 |
208 | def __getitem__(self, file_idx):
209 | name_ext = self.paths[file_idx]
210 | data_buffer = self.data_buffer[name_ext]
211 | # check duration. if too short, then skip
212 | if data_buffer['duration'] < (self.waveform_sec + 0.1):
213 | return self.__getitem__( (file_idx + 1) % len(self.paths))
214 |
215 | # get item
216 | return self.get_data(name_ext, data_buffer)
217 |
218 | def get_data(self, name_ext, data_buffer):
219 | name = os.path.splitext(name_ext)[0]
220 | frame_resolution = self.hop_size / self.sample_rate
221 | duration = data_buffer['duration']
222 | waveform_sec = duration if self.whole_audio else self.waveform_sec
223 |
224 | # load audio
225 | idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
226 | start_frame = int(idx_from / frame_resolution)
227 | units_frame_len = int(waveform_sec / frame_resolution)
228 | aug_flag = random.choice([True, False]) and self.use_aug
229 | '''
230 | audio = data_buffer.get('audio')
231 | if audio is None:
232 | path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
233 | audio, sr = librosa.load(
234 | path_audio,
235 | sr = self.sample_rate,
236 | offset = start_frame * frame_resolution,
237 | duration = waveform_sec)
238 | if len(audio.shape) > 1:
239 | audio = librosa.to_mono(audio)
240 | # clip audio into N seconds
241 | audio = audio[ : audio.shape[-1] // self.hop_size * self.hop_size]
242 | audio = torch.from_numpy(audio).float()
243 | else:
244 | audio = audio[start_frame * self.hop_size : (start_frame + units_frame_len) * self.hop_size]
245 | '''
246 | # load mel
247 | mel_key = 'aug_mel' if aug_flag else 'mel'
248 | mel = data_buffer.get(mel_key)
249 | if mel is None:
250 | mel = name_ext + ".mel.npy"
251 | mel = np.load(mel)
252 | mel = mel[start_frame : start_frame + units_frame_len]
253 | mel = torch.from_numpy(mel).float()
254 | else:
255 | mel = mel[start_frame : start_frame + units_frame_len]
256 |
257 | # load f0
258 | f0 = data_buffer.get('f0')
259 | aug_shift = 0
260 | if aug_flag:
261 | aug_shift = self.pitch_aug_dict[name_ext]
262 | f0_frames = 2 ** (aug_shift / 12) * f0[start_frame : start_frame + units_frame_len]
263 |
264 | # load units
265 | units = data_buffer.get('units')
266 | if units is None:
267 | path_units = name_ext + ".soft.pt"
268 | units = torch.load(path_units)
269 | units = units[0]
270 | units = repeat_expand_2d(units,f0.size(0),self.unit_interpolate_mode).transpose(0,1)
271 |
272 | units = units[start_frame : start_frame + units_frame_len]
273 |
274 | # load volume
275 | vol_key = 'aug_vol' if aug_flag else 'volume'
276 | volume = data_buffer.get(vol_key)
277 | volume_frames = volume[start_frame : start_frame + units_frame_len]
278 |
279 | # load spk_id
280 | spk_id = data_buffer.get('spk_id')
281 |
282 | # load shift
283 | aug_shift = torch.from_numpy(np.array([[aug_shift]])).float()
284 |
285 | return dict(mel=mel, f0=f0_frames, volume=volume_frames, units=units, spk_id=spk_id, aug_shift=aug_shift, name=name, name_ext=name_ext)
286 |
287 | def __len__(self):
288 | return len(self.paths)
--------------------------------------------------------------------------------
/dataset/the_prprocessed_data:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/dataset/the_prprocessed_data
--------------------------------------------------------------------------------
/dataset_slice/if_you_need_to_slice:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/dataset_slice/if_you_need_to_slice
--------------------------------------------------------------------------------
/filelists/put_the_txt_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/filelists/put_the_txt_here
--------------------------------------------------------------------------------
/infer_tool.py:
--------------------------------------------------------------------------------
1 | import io
2 | import logging
3 | import os
4 | import time
5 | from pathlib import Path
6 |
7 | import librosa
8 | import numpy as np
9 |
10 | import soundfile
11 | import torch
12 | import torchaudio
13 |
14 | import utils
15 | from ComoSVC import load_model_vocoder
16 | import slicer
17 |
18 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
19 |
20 |
21 | def format_wav(audio_path):
22 | if Path(audio_path).suffix == '.wav':
23 | return
24 | raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
25 | soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
26 |
27 |
28 | def get_end_file(dir_path, end):
29 | file_lists = []
30 | for root, dirs, files in os.walk(dir_path):
31 | files = [f for f in files if f[0] != '.']
32 | dirs[:] = [d for d in dirs if d[0] != '.']
33 | for f_file in files:
34 | if f_file.endswith(end):
35 | file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
36 | return file_lists
37 |
38 |
39 | def fill_a_to_b(a, b):
40 | if len(a) < len(b):
41 | for _ in range(0, len(b) - len(a)):
42 | a.append(a[0])
43 |
44 | def mkdir(paths: list):
45 | for path in paths:
46 | if not os.path.exists(path):
47 | os.mkdir(path)
48 |
49 | def pad_array(arr, target_length):
50 | current_length = arr.shape[0]
51 | if current_length >= target_length:
52 | return arr
53 | else:
54 | pad_width = target_length - current_length
55 | pad_left = pad_width // 2
56 | pad_right = pad_width - pad_left
57 | padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
58 | return padded_arr
59 |
60 | def split_list_by_n(list_collection, n, pre=0):
61 | for i in range(0, len(list_collection), n):
62 | yield list_collection[i-pre if i-pre>=0 else i: i + n]
63 |
64 |
65 | class F0FilterException(Exception):
66 | pass
67 |
68 | class Svc(object):
69 | def __init__(self,
70 | diffusion_model_path="logs/como/model_8000.pt",
71 | diffusion_config_path="configs/diffusion.yaml",
72 | total_steps=1,
73 | teacher = False
74 | ):
75 |
76 | self.teacher = teacher
77 | self.total_steps=total_steps
78 | self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79 | self.diffusion_model,self.vocoder,self.diffusion_args = load_model_vocoder(diffusion_model_path,self.dev,config_path=diffusion_config_path,total_steps=self.total_steps,teacher=self.teacher)
80 | self.target_sample = self.diffusion_args.data.sampling_rate
81 | self.hop_size = self.diffusion_args.data.hop_length
82 | self.spk2id = self.diffusion_args.spk
83 | self.dtype = torch.float32
84 | self.speech_encoder = self.diffusion_args.data.encoder
85 | self.unit_interpolate_mode = self.diffusion_args.data.unit_interpolate_mode if self.diffusion_args.data.unit_interpolate_mode is not None else 'left'
86 |
87 | # load hubert and model
88 |
89 | from Features import ContentVec768L12
90 | self.hubert_model = ContentVec768L12(device = self.dev)
91 | self.volume_extractor= utils.Volume_Extractor(self.hop_size)
92 |
93 |
94 |
95 | def get_unit_f0(self, wav, tran):
96 |
97 | if not hasattr(self,"f0_predictor_object") or self.f0_predictor_object is None:
98 | from Features import DioF0Predictor
99 | self.f0_predictor_object = DioF0Predictor(hop_length=self.hop_size,sampling_rate=self.target_sample)
100 | f0, uv = self.f0_predictor_object.compute_f0_uv(wav)
101 | f0 = torch.FloatTensor(f0).to(self.dev)
102 | uv = torch.FloatTensor(uv).to(self.dev)
103 |
104 | f0 = f0 * 2 ** (tran / 12)
105 | f0 = f0.unsqueeze(0)
106 | uv = uv.unsqueeze(0)
107 |
108 | wav = torch.from_numpy(wav).to(self.dev)
109 | if not hasattr(self,"audio16k_resample_transform"):
110 | self.audio16k_resample_transform = torchaudio.transforms.Resample(self.target_sample, 16000).to(self.dev)
111 | wav16k = self.audio16k_resample_transform(wav[None,:])[0]
112 |
113 | c = self.hubert_model.encoder(wav16k)
114 | c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
115 | c = c.unsqueeze(0)
116 | return c, f0, uv
117 |
118 | def infer(self, speaker, tran, raw_path):
119 | torchaudio.set_audio_backend("soundfile")
120 | wav, sr = torchaudio.load(raw_path)
121 |         if not hasattr(self,"audio_resample_transform") or self.audio_resample_transform.orig_freq != sr:
122 | self.audio_resample_transform = torchaudio.transforms.Resample(sr,self.target_sample)
123 | wav = self.audio_resample_transform(wav).numpy()[0]# (100080,)
124 | speaker_id = self.spk2id.get(speaker)
125 |
126 | if not speaker_id and type(speaker) is int:
127 | if len(self.spk2id.__dict__) >= speaker:
128 | speaker_id = speaker
129 | if speaker_id is None:
130 | raise RuntimeError("The name you entered is not in the speaker list!")
131 | sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
132 |
133 | c, f0, uv = self.get_unit_f0(wav, tran)
134 | n_frames = f0.size(1)
135 | c = c.to(self.dtype)
136 | f0 = f0.to(self.dtype)
137 | uv = uv.to(self.dtype)
138 |
139 | with torch.no_grad():
140 | start = time.time()
141 | vol = None
142 | audio = torch.FloatTensor(wav).to(self.dev)
143 | audio_mel = None
144 | vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None]
145 | f0 = f0[:,:,None] # torch.Size([1, 390]) to torch.Size([1, 390, 1])
146 | c = c.transpose(-1,-2)
147 | audio_mel = self.diffusion_model(c, f0, vol,spk_id = sid,gt_spec=audio_mel,infer=True)
148 | # print("inferencetool_audiomel",audio_mel.shape)
149 | audio = self.vocoder.infer(audio_mel, f0).squeeze()
150 | use_time = time.time() - start
151 | print("inference_time is:{}".format(use_time))
152 | return audio, audio.shape[-1], n_frames
153 |
154 | def clear_empty(self):
155 | # clean up vram
156 | torch.cuda.empty_cache()
157 |
158 |
159 | def slice_inference(self,
160 | raw_audio_path,
161 | spk,
162 | tran,
163 | slice_db=-40, # -40
164 | pad_seconds=0.5,
165 | clip_seconds=0,
166 | ):
167 |
168 | wav_path = Path(raw_audio_path).with_suffix('.wav')
169 | chunks = slicer.cut(wav_path, db_thresh=slice_db)
170 | audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
171 | per_size = int(clip_seconds*audio_sr)
172 | lg_size = 0
173 | global_frame = 0
174 | audio = []
175 | for (slice_tag, data) in audio_data:
176 | print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
177 | # padd
178 | length = int(np.ceil(len(data) / audio_sr * self.target_sample))
179 | if slice_tag:
180 | print('jump empty segment')
181 | _audio = np.zeros(length)
182 | audio.extend(list(pad_array(_audio, length)))
183 | global_frame += length // self.hop_size
184 | continue
185 | if per_size != 0:
186 | datas = split_list_by_n(data, per_size,lg_size)
187 | else:
188 | datas = [data]
189 | for k,dat in enumerate(datas):
190 | per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
191 | if clip_seconds!=0:
192 | print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
193 | # padd
194 | pad_len = int(audio_sr * pad_seconds)
195 | dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
196 | raw_path = io.BytesIO()
197 | soundfile.write(raw_path, dat, audio_sr, format="wav")
198 | raw_path.seek(0)
199 | out_audio, out_sr, out_frame = self.infer(spk, tran, raw_path)
200 | global_frame += out_frame
201 | _audio = out_audio.cpu().numpy()
202 | pad_len = int(self.target_sample * pad_seconds)
203 | _audio = _audio[pad_len:-pad_len]
204 | _audio = pad_array(_audio, per_length)
205 | audio.extend(list(_audio))
206 | return np.array(audio)
207 |
--------------------------------------------------------------------------------
/inference_main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import soundfile
3 | import os
4 | os.environ["CUDA_VISIBLE_DEVICES"]='1'
5 |
6 | import infer_tool
7 | from infer_tool import Svc
8 |
9 | logging.getLogger('numba').setLevel(logging.WARNING)
10 |
11 |
12 | def main():
13 | import argparse
14 |
15 | parser = argparse.ArgumentParser(description='comosvc inference')
 16 |     parser.add_argument('-t', '--teacher', action="store_false", help='pass this flag to use the consistency (student) model instead of the teacher model')
17 | parser.add_argument('-ts', '--total_steps', type=int,default=1,help='the total number of iterative steps during inference')
18 |
19 |
20 | parser.add_argument('--clip', type=float, default=0, help='Slicing the audios which are to be converted')
21 | parser.add_argument('-n','--clean_names', type=str, nargs='+', default=['1.wav'], help='The audios to be converted,should be put in "raw" directory')
22 | parser.add_argument('-k','--keys', type=int, nargs='+', default=[0], help='To Adjust the Key')
23 | parser.add_argument('-s','--spk_list', type=str, nargs='+', default=['singer1'], help='The target singer')
24 | parser.add_argument('-cm','--como_model_path', type=str, default="./logs/como/model_800000.pt", help='the path to checkpoint of CoMoSVC')
25 | parser.add_argument('-cc','--como_config_path', type=str, default="./logs/como/config.yaml", help='the path to config file of CoMoSVC')
26 | parser.add_argument('-tm','--teacher_model_path', type=str, default="./logs/teacher/model_800000.pt", help='the path to checkpoint of Teacher Model')
27 | parser.add_argument('-tc','--teacher_config_path', type=str, default="./logs/teacher/config.yaml", help='the path to config file of Teacher Model')
28 |
29 | args = parser.parse_args()
30 |
31 | clean_names = args.clean_names
32 | keys = args.keys
33 | spk_list = args.spk_list
34 | slice_db =-40
35 | wav_format = 'wav' # the format of the output audio
36 | pad_seconds = 0.5
37 | clip = args.clip
38 |
39 | if args.teacher:
40 | diffusion_model_path = args.teacher_model_path
41 | diffusion_config_path = args.teacher_config_path
42 | resultfolder='result_teacher'
43 | else:
44 | diffusion_model_path = args.como_model_path
45 | diffusion_config_path = args.como_config_path
 46 |         resultfolder='result_como'
47 |
48 | svc_model = Svc(diffusion_model_path,
49 | diffusion_config_path,
50 | args.total_steps,
51 | args.teacher)
52 |
53 | infer_tool.mkdir(["raw", resultfolder])
54 |
55 | infer_tool.fill_a_to_b(keys, clean_names)
56 | for clean_name, tran in zip(clean_names, keys):
57 | raw_audio_path = f"raw/{clean_name}"
58 | if "." not in raw_audio_path:
59 | raw_audio_path += ".wav"
60 | infer_tool.format_wav(raw_audio_path)
61 | for spk in spk_list:
62 | kwarg = {
63 | "raw_audio_path" : raw_audio_path,
64 | "spk" : spk,
65 | "tran" : tran,
66 | "slice_db" : slice_db,# -40
67 | "pad_seconds" : pad_seconds, # 0.5
68 | "clip_seconds" : clip, #0
69 |
70 | }
71 | audio = svc_model.slice_inference(**kwarg)
72 | step_num=diffusion_model_path.split('/')[-1].split('.')[0]
73 | if args.teacher:
74 | isdiffusion = "teacher"
75 | else:
76 | isdiffusion= "como"
77 | res_path = f'{resultfolder}/{clean_name}_{spk}_{isdiffusion}_{step_num}.{wav_format}'
78 | soundfile.write(res_path, audio, svc_model.target_sample, format=wav_format)
79 | svc_model.clear_empty()
80 |
81 | if __name__ == '__main__':
82 | main()
83 |
--------------------------------------------------------------------------------
/logs/the_log_files:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/logs/the_log_files
--------------------------------------------------------------------------------
/mel_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import torch
5 | import torch.utils.data
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from librosa.util import normalize
9 | from scipy.io.wavfile import read
10 | from librosa.filters import mel as librosa_mel_fn
11 | import pathlib
12 | from tqdm import tqdm
13 |
14 | MAX_WAV_VALUE = 32768.0
15 |
16 |
17 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
18 | return torch.log10(torch.clamp(x, min=clip_val) * C)
19 |
20 |
21 | def dynamic_range_decompression_torch(x, C=1):
22 | return torch.exp(x) / C
23 |
24 |
25 | def spectral_normalize_torch(magnitudes):
26 | output = dynamic_range_compression_torch(magnitudes)
27 | return output
28 |
29 |
30 | def spectral_de_normalize_torch(magnitudes):
31 | output = dynamic_range_decompression_torch(magnitudes)
32 | return output
33 |
34 |
35 | mel_basis = {}
36 | hann_window = {}
37 |
38 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, keyshift=0, speed=1,center=False):
39 |
40 | factor = 2 ** (keyshift / 12)
41 | n_fft_new = int(np.round(n_fft * factor))
42 | win_size_new = int(np.round(win_size * factor))
43 | hop_length_new = int(np.round(hop_size * speed))
44 |
45 |
46 | if torch.min(y) < -1.:
47 | print('min value is ', torch.min(y))
48 | if torch.max(y) > 1.:
49 | print('max value is ', torch.max(y))
50 |
51 | global mel_basis, hann_window
52 | mel_basis_key = str(fmax)+'_'+str(y.device)
53 | if mel_basis_key not in mel_basis:
 54 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # build the mel filter bank used to convert linear spectrograms to mel spectrograms
 55 |         mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)  # cache the filter bank as a torch tensor on the same device as y
56 |
57 | keyshift_key = str(keyshift)+'_'+str(y.device)
58 | if keyshift_key not in hann_window:
59 | hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
60 |
61 | pad_left = (win_size_new - hop_length_new) //2
62 | pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
63 | if pad_right < y.size(-1):
64 | mode = 'reflect'
65 | else:
66 | mode = 'constant'
67 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
68 | y = y.squeeze(1)
69 |
70 |
71 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key],
72 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
73 |
74 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
75 | if keyshift != 0:
76 | size = n_fft // 2 + 1
77 | resize = spec.size(1)
78 | if resize < size:
79 | spec = F.pad(spec, (0, 0, 0, size-resize))
80 | spec = spec[:, :size, :] * win_size / win_size_new
81 | spec = torch.matmul(mel_basis[mel_basis_key], spec)
82 |
83 | spec = spectral_normalize_torch(spec)
84 |
85 | return spec
86 |
87 |
88 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
89 | if torch.min(y) < -1.:
90 | print('min value is ', torch.min(y))
91 | if torch.max(y) > 1.:
92 | print('max value is ', torch.max(y))
93 |
94 | global hann_window
95 | dtype_device = str(y.dtype) + '_' + str(y.device)
96 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
97 | if wnsize_dtype_device not in hann_window:
98 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
99 |
100 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
101 | y = y.squeeze(1)
102 |
103 | y_dtype = y.dtype
104 | if y.dtype == torch.bfloat16:
105 | y = y.to(torch.float32)
106 |
107 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
108 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
109 | spec = torch.view_as_real(spec).to(y_dtype)
110 |
111 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
112 | return spec
113 |
114 |
115 |
--------------------------------------------------------------------------------
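
A quick usage sketch for mel_spectrogram above. The STFT/mel parameters here are illustrative placeholders, not necessarily the values configured in configs_template/diffusion_template.yaml:

    import torch
    from mel_processing import mel_spectrogram

    # fake 1-second batch of audio at 24 kHz, already in [-1, 1]
    y = torch.randn(1, 24000).clamp(-1., 1.)
    mel = mel_spectrogram(
        y, n_fft=1024, num_mels=80, sampling_rate=24000,
        hop_size=256, win_size=1024, fmin=40, fmax=12000,
    )
    print(mel.shape)  # (1, 80, n_frames)
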
/meldataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from librosa.filters import mel as librosa_mel_fn
6 |
7 |
8 | def spectral_normalize_torch(magnitudes):
9 | output = dynamic_range_compression_torch(magnitudes)
10 | return output
11 |
12 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
13 | return torch.log10(torch.clamp(x, min=clip_val) * C)
14 | mel_basis = {}
15 | hann_window = {}
16 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, keyshift=0, speed=1,center=False):
17 |
18 | factor = 2 ** (keyshift / 12)
19 | n_fft_new = int(np.round(n_fft * factor))
20 | win_size_new = int(np.round(win_size * factor))
21 | hop_length_new = int(np.round(hop_size * speed))
22 |
23 |
24 | if torch.min(y) < -1.:
25 | print('min value is ', torch.min(y))
26 | if torch.max(y) > 1.:
27 | print('max value is ', torch.max(y))
28 |
29 | global mel_basis, hann_window
30 | mel_basis_key = str(fmax)+'_'+str(y.device)
31 | if mel_basis_key not in mel_basis:
32 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # build the mel filterbank used to convert linear spectrograms to mel spectrograms
33 |         mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)  # cache it as a torch tensor on the same device as y
34 |
35 | keyshift_key = str(keyshift)+'_'+str(y.device)
36 | if keyshift_key not in hann_window:
37 | hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
38 |
39 | pad_left = (win_size_new - hop_length_new) //2
40 | pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
41 | if pad_right < y.size(-1):
42 | mode = 'reflect'
43 | else:
44 | mode = 'constant'
45 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
46 | y = y.squeeze(1)
47 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key],
48 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
49 |
50 | spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
51 | if keyshift != 0:
52 | size = n_fft // 2 + 1
53 | resize = spec.size(1)
54 | if resize < size:
55 | spec = F.pad(spec, (0, 0, 0, size-resize))
56 | spec = spec[:, :size, :] * win_size / win_size_new
57 | spec = torch.matmul(mel_basis[mel_basis_key], spec)
58 |
59 | spec = spectral_normalize_torch(spec)
60 |
61 | return spec
62 |
--------------------------------------------------------------------------------
/pitch_extractor.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import Parameter
5 | import torch.onnx.operators
6 | import torch.nn.functional as F
7 | import utils
8 |
9 |
10 | class Reshape(nn.Module):
11 | def __init__(self, *args):
12 | super(Reshape, self).__init__()
13 | self.shape = args
14 |
15 | def forward(self, x):
16 | return x.view(self.shape)
17 |
18 |
19 | class Permute(nn.Module):
20 | def __init__(self, *args):
21 | super(Permute, self).__init__()
22 | self.args = args
23 |
24 | def forward(self, x):
25 | return x.permute(self.args)
26 |
27 |
28 | class LinearNorm(torch.nn.Module):
29 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
30 | super(LinearNorm, self).__init__()
31 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
32 |
33 | torch.nn.init.xavier_uniform_(
34 | self.linear_layer.weight,
35 | gain=torch.nn.init.calculate_gain(w_init_gain))
36 |
37 | def forward(self, x):
38 | return self.linear_layer(x)
39 |
40 |
41 | class ConvNorm(torch.nn.Module):
42 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
43 | padding=None, dilation=1, bias=True, w_init_gain='linear'):
44 | super(ConvNorm, self).__init__()
45 | if padding is None:
46 | assert (kernel_size % 2 == 1)
47 | padding = int(dilation * (kernel_size - 1) / 2)
48 |
49 | self.conv = torch.nn.Conv1d(in_channels, out_channels,
50 | kernel_size=kernel_size, stride=stride,
51 | padding=padding, dilation=dilation,
52 | bias=bias)
53 |
54 | torch.nn.init.xavier_uniform_(
55 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
56 |
57 | def forward(self, signal):
58 | conv_signal = self.conv(signal)
59 | return conv_signal
60 |
61 |
62 | def Embedding(num_embeddings, embedding_dim, padding_idx=None):
63 | m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
64 | nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
65 | if padding_idx is not None:
66 | nn.init.constant_(m.weight[padding_idx], 0)
67 | return m
68 |
69 |
70 | # def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
71 | # # if not export and torch.cuda.is_available():
72 | # # try:
73 | # # from apex.normalization import FusedLayerNorm
74 | # # return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
75 | # # except ImportError:
76 | # # pass
77 | # return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
78 |
79 | class LayerNorm(torch.nn.LayerNorm):
80 | """Layer normalization module.
81 | :param int nout: output dim size
82 | :param int dim: dimension to be normalized
83 | """
84 |
85 | def __init__(self, nout, dim=-1):
86 |         """Construct a LayerNorm object."""
87 | super(LayerNorm, self).__init__(nout, eps=1e-12)
88 | self.dim = dim
89 |
90 | def forward(self, x):
91 | """Apply layer normalization.
92 | :param torch.Tensor x: input tensor
93 | :return: layer normalized tensor
94 | :rtype torch.Tensor
95 | """
96 | if self.dim == -1:
97 | return super(LayerNorm, self).forward(x)
98 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
99 |
100 | def Linear(in_features, out_features, bias=True):
101 | m = nn.Linear(in_features, out_features, bias)
102 | nn.init.xavier_uniform_(m.weight)
103 | if bias:
104 | nn.init.constant_(m.bias, 0.)
105 | return m
106 |
107 |
108 | class SinusoidalPositionalEmbedding(nn.Module):
109 | """This module produces sinusoidal positional embeddings of any length.
110 |
111 | Padding symbols are ignored.
112 | """
113 |
114 | def __init__(self, embedding_dim, padding_idx, init_size=1024):
115 | super().__init__()
116 | self.embedding_dim = embedding_dim
117 | self.padding_idx = padding_idx
118 | self.weights = SinusoidalPositionalEmbedding.get_embedding(
119 | init_size,
120 | embedding_dim,
121 | padding_idx,
122 | )
123 | self.register_buffer('_float_tensor', torch.FloatTensor(1))
124 |
125 | @staticmethod
126 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
127 | """Build sinusoidal embeddings.
128 |
129 | This matches the implementation in tensor2tensor, but differs slightly
130 | from the description in Section 3.5 of "Attention Is All You Need".
131 | """
132 | half_dim = embedding_dim // 2
133 | emb = math.log(10000) / (half_dim - 1)
134 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
135 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
136 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
137 | if embedding_dim % 2 == 1:
138 | # zero pad
139 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
140 | if padding_idx is not None:
141 | emb[padding_idx, :] = 0
142 | return emb
143 |
144 | def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
145 | """Input is expected to be of size [bsz x seqlen]."""
146 | bsz, seq_len = input.shape[:2]
147 | max_pos = self.padding_idx + 1 + seq_len
148 | if self.weights is None or max_pos > self.weights.size(0):
149 | # recompute/expand embeddings if needed
150 | self.weights = SinusoidalPositionalEmbedding.get_embedding(
151 | max_pos,
152 | self.embedding_dim,
153 | self.padding_idx,
154 | )
155 | self.weights = self.weights.to(self._float_tensor)
156 |
157 | if incremental_state is not None:
158 | # positions is the same for every token when decoding a single step
159 | pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
160 | return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
161 |
162 | positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
163 | return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
164 |
165 | def max_positions(self):
166 | """Maximum number of supported positions."""
167 | return int(1e5) # an arbitrary large number
168 |
169 |
170 | class ConvTBC(nn.Module):
171 | def __init__(self, in_channels, out_channels, kernel_size, padding=0):
172 | super(ConvTBC, self).__init__()
173 | self.in_channels = in_channels
174 | self.out_channels = out_channels
175 | self.kernel_size = kernel_size
176 | self.padding = padding
177 |
178 | self.weight = torch.nn.Parameter(torch.Tensor(
179 | self.kernel_size, in_channels, out_channels))
180 | self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
181 |
182 | def forward(self, input):
183 | return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
184 |
185 |
186 | class MultiheadAttention(nn.Module):
187 | def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
188 | add_bias_kv=False, add_zero_attn=False, self_attention=False,
189 | encoder_decoder_attention=False):
190 | super().__init__()
191 | self.embed_dim = embed_dim
192 | self.kdim = kdim if kdim is not None else embed_dim
193 | self.vdim = vdim if vdim is not None else embed_dim
194 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
195 |
196 | self.num_heads = num_heads
197 | self.dropout = dropout
198 | self.head_dim = embed_dim // num_heads
199 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
200 | self.scaling = self.head_dim ** -0.5
201 |
202 | self.self_attention = self_attention
203 | self.encoder_decoder_attention = encoder_decoder_attention
204 |
205 | assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
206 | 'value to be of the same size'
207 |
208 | if self.qkv_same_dim:
209 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
210 | else:
211 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
212 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
213 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
214 |
215 | if bias:
216 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
217 | else:
218 | self.register_parameter('in_proj_bias', None)
219 |
220 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
221 |
222 | if add_bias_kv:
223 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
224 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
225 | else:
226 | self.bias_k = self.bias_v = None
227 |
228 | self.add_zero_attn = add_zero_attn
229 |
230 | self.reset_parameters()
231 |
232 | self.enable_torch_version = False
233 | if hasattr(F, "multi_head_attention_forward"):
234 | self.enable_torch_version = True
235 | else:
236 | self.enable_torch_version = False
237 | self.last_attn_probs = None
238 |
239 | def reset_parameters(self):
240 | if self.qkv_same_dim:
241 | nn.init.xavier_uniform_(self.in_proj_weight)
242 | else:
243 | nn.init.xavier_uniform_(self.k_proj_weight)
244 | nn.init.xavier_uniform_(self.v_proj_weight)
245 | nn.init.xavier_uniform_(self.q_proj_weight)
246 |
247 | nn.init.xavier_uniform_(self.out_proj.weight)
248 | if self.in_proj_bias is not None:
249 | nn.init.constant_(self.in_proj_bias, 0.)
250 | nn.init.constant_(self.out_proj.bias, 0.)
251 | if self.bias_k is not None:
252 | nn.init.xavier_normal_(self.bias_k)
253 | if self.bias_v is not None:
254 | nn.init.xavier_normal_(self.bias_v)
255 |
256 | def forward(
257 | self,
258 | query, key, value,
259 | key_padding_mask=None,
260 | incremental_state=None,
261 | need_weights=True,
262 | static_kv=False,
263 | attn_mask=None,
264 | before_softmax=False,
265 | need_head_weights=False,
266 | enc_dec_attn_constraint_mask=None,
267 | reset_attn_weight=None
268 | ):
269 | """Input shape: Time x Batch x Channel
270 |
271 | Args:
272 | key_padding_mask (ByteTensor, optional): mask to exclude
273 | keys that are pads, of shape `(batch, src_len)`, where
274 | padding elements are indicated by 1s.
275 | need_weights (bool, optional): return the attention weights,
276 | averaged over heads (default: False).
277 | attn_mask (ByteTensor, optional): typically used to
278 | implement causal attention, where the mask prevents the
279 | attention from looking forward in time (default: None).
280 | before_softmax (bool, optional): return the raw attention
281 | weights and values before the attention softmax.
282 | need_head_weights (bool, optional): return the attention
283 | weights for each head. Implies *need_weights*. Default:
284 | return the average attention weights over all heads.
285 | """
286 | if need_head_weights:
287 | need_weights = True
288 |
289 | tgt_len, bsz, embed_dim = query.size()
290 | assert embed_dim == self.embed_dim
291 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
292 |
293 | if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
294 | if self.qkv_same_dim:
295 | return F.multi_head_attention_forward(query, key, value,
296 | self.embed_dim, self.num_heads,
297 | self.in_proj_weight,
298 | self.in_proj_bias, self.bias_k, self.bias_v,
299 | self.add_zero_attn, self.dropout,
300 | self.out_proj.weight, self.out_proj.bias,
301 | self.training, key_padding_mask, need_weights,
302 | attn_mask)
303 | else:
304 | return F.multi_head_attention_forward(query, key, value,
305 | self.embed_dim, self.num_heads,
306 | torch.empty([0]),
307 | self.in_proj_bias, self.bias_k, self.bias_v,
308 | self.add_zero_attn, self.dropout,
309 | self.out_proj.weight, self.out_proj.bias,
310 | self.training, key_padding_mask, need_weights,
311 | attn_mask, use_separate_proj_weight=True,
312 | q_proj_weight=self.q_proj_weight,
313 | k_proj_weight=self.k_proj_weight,
314 | v_proj_weight=self.v_proj_weight)
315 |
316 | if incremental_state is not None:
317 | print('Not implemented error.')
318 | exit()
319 | else:
320 | saved_state = None
321 |
322 | if self.self_attention:
323 | # self-attention
324 | q, k, v = self.in_proj_qkv(query)
325 | elif self.encoder_decoder_attention:
326 | # encoder-decoder attention
327 | q = self.in_proj_q(query)
328 | if key is None:
329 | assert value is None
330 | k = v = None
331 | else:
332 | k = self.in_proj_k(key)
333 | v = self.in_proj_v(key)
334 |
335 | else:
336 | q = self.in_proj_q(query)
337 | k = self.in_proj_k(key)
338 | v = self.in_proj_v(value)
339 | q *= self.scaling
340 |
341 | if self.bias_k is not None:
342 | assert self.bias_v is not None
343 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
344 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
345 | if attn_mask is not None:
346 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
347 | if key_padding_mask is not None:
348 | key_padding_mask = torch.cat(
349 | [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
350 |
351 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
352 | if k is not None:
353 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
354 | if v is not None:
355 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
356 |
357 | if saved_state is not None:
358 | print('Not implemented error.')
359 | exit()
360 |
361 | src_len = k.size(1)
362 |
363 | # This is part of a workaround to get around fork/join parallelism
364 | # not supporting Optional types.
365 | if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
366 | key_padding_mask = None
367 |
368 | if key_padding_mask is not None:
369 | assert key_padding_mask.size(0) == bsz
370 | assert key_padding_mask.size(1) == src_len
371 |
372 | if self.add_zero_attn:
373 | src_len += 1
374 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
375 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
376 | if attn_mask is not None:
377 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
378 | if key_padding_mask is not None:
379 | key_padding_mask = torch.cat(
380 | [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
381 |
382 | attn_weights = torch.bmm(q, k.transpose(1, 2))
383 | attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
384 |
385 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
386 |
387 | if attn_mask is not None:
388 | if len(attn_mask.shape) == 2:
389 | attn_mask = attn_mask.unsqueeze(0)
390 | elif len(attn_mask.shape) == 3:
391 | attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
392 | bsz * self.num_heads, tgt_len, src_len)
393 | attn_weights = attn_weights + attn_mask
394 |
395 | if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
396 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
397 | attn_weights = attn_weights.masked_fill(
398 | enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
399 | -1e9,
400 | )
401 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
402 |
403 | if key_padding_mask is not None:
404 | # don't attend to padding symbols
405 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
406 | attn_weights = attn_weights.masked_fill(
407 | key_padding_mask.unsqueeze(1).unsqueeze(2),
408 | -1e9,
409 | )
410 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
411 |
412 | attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
413 |
414 | if before_softmax:
415 | return attn_weights, v
416 |
417 | attn_weights_float = utils.softmax(attn_weights, dim=-1)
418 | attn_weights = attn_weights_float.type_as(attn_weights)
419 | attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
420 |
421 | if reset_attn_weight is not None:
422 | if reset_attn_weight:
423 | self.last_attn_probs = attn_probs.detach()
424 | else:
425 | assert self.last_attn_probs is not None
426 | attn_probs = self.last_attn_probs
427 | attn = torch.bmm(attn_probs, v)
428 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
429 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
430 | attn = self.out_proj(attn)
431 |
432 | if need_weights:
433 | attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
434 | if not need_head_weights:
435 | # average attention weights over heads
436 | attn_weights = attn_weights.mean(dim=0)
437 | else:
438 | attn_weights = None
439 |
440 | return attn, (attn_weights, attn_logits)
441 |
442 | def in_proj_qkv(self, query):
443 | return self._in_proj(query).chunk(3, dim=-1)
444 |
445 | def in_proj_q(self, query):
446 | if self.qkv_same_dim:
447 | return self._in_proj(query, end=self.embed_dim)
448 | else:
449 | bias = self.in_proj_bias
450 | if bias is not None:
451 | bias = bias[:self.embed_dim]
452 | return F.linear(query, self.q_proj_weight, bias)
453 |
454 | def in_proj_k(self, key):
455 | if self.qkv_same_dim:
456 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
457 | else:
458 | weight = self.k_proj_weight
459 | bias = self.in_proj_bias
460 | if bias is not None:
461 | bias = bias[self.embed_dim:2 * self.embed_dim]
462 | return F.linear(key, weight, bias)
463 |
464 | def in_proj_v(self, value):
465 | if self.qkv_same_dim:
466 | return self._in_proj(value, start=2 * self.embed_dim)
467 | else:
468 | weight = self.v_proj_weight
469 | bias = self.in_proj_bias
470 | if bias is not None:
471 | bias = bias[2 * self.embed_dim:]
472 | return F.linear(value, weight, bias)
473 |
474 | def _in_proj(self, input, start=0, end=None):
475 | weight = self.in_proj_weight
476 | bias = self.in_proj_bias
477 | weight = weight[start:end, :]
478 | if bias is not None:
479 | bias = bias[start:end]
480 | return F.linear(input, weight, bias)
481 |
482 |
483 | def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
484 | return attn_weights
485 |
486 |
487 | class Swish(torch.autograd.Function):
488 | @staticmethod
489 | def forward(ctx, i):
490 | result = i * torch.sigmoid(i)
491 | ctx.save_for_backward(i)
492 | return result
493 |
494 | @staticmethod
495 | def backward(ctx, grad_output):
496 |         i = ctx.saved_tensors[0]  # ctx.saved_variables is long-deprecated in PyTorch
497 | sigmoid_i = torch.sigmoid(i)
498 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
499 |
500 |
501 | class CustomSwish(nn.Module):
502 | def forward(self, input_tensor):
503 | return Swish.apply(input_tensor)
504 |
505 |
506 | class TransformerFFNLayer(nn.Module):
507 | def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
508 | super().__init__()
509 | self.kernel_size = kernel_size
510 | self.dropout = dropout
511 | self.act = act
512 | if padding == 'SAME':
513 | self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
514 | elif padding == 'LEFT':
515 | self.ffn_1 = nn.Sequential(
516 | nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
517 | nn.Conv1d(hidden_size, filter_size, kernel_size)
518 | )
519 | self.ffn_2 = Linear(filter_size, hidden_size)
520 | if self.act == 'swish':
521 | self.swish_fn = CustomSwish()
522 |
523 | def forward(self, x, incremental_state=None):
524 | # x: T x B x C
525 | if incremental_state is not None:
526 | assert incremental_state is None, 'Nar-generation does not allow this.'
527 | exit(1)
528 |
529 | x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
530 | x = x * self.kernel_size ** -0.5
531 |
532 | if incremental_state is not None:
533 | x = x[-1:]
534 | if self.act == 'gelu':
535 | x = F.gelu(x)
536 | if self.act == 'relu':
537 | x = F.relu(x)
538 | if self.act == 'swish':
539 | x = self.swish_fn(x)
540 | x = F.dropout(x, self.dropout, training=self.training)
541 | x = self.ffn_2(x)
542 | return x
543 |
544 |
545 | class BatchNorm1dTBC(nn.Module):
546 | def __init__(self, c):
547 | super(BatchNorm1dTBC, self).__init__()
548 | self.bn = nn.BatchNorm1d(c)
549 |
550 | def forward(self, x):
551 | """
552 |
553 | :param x: [T, B, C]
554 | :return: [T, B, C]
555 | """
556 | x = x.permute(1, 2, 0) # [B, C, T]
557 | x = self.bn(x) # [B, C, T]
558 | x = x.permute(2, 0, 1) # [T, B, C]
559 | return x
560 |
561 |
562 | class EncSALayer(nn.Module):
563 | def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
564 | relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
565 | super().__init__()
566 | self.c = c
567 | self.dropout = dropout
568 | self.num_heads = num_heads
569 | if num_heads > 0:
570 | if norm == 'ln':
571 | self.layer_norm1 = LayerNorm(c)
572 | elif norm == 'bn':
573 | self.layer_norm1 = BatchNorm1dTBC(c)
574 | self.self_attn = MultiheadAttention(
575 | self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
576 | )
577 | if norm == 'ln':
578 | self.layer_norm2 = LayerNorm(c)
579 | elif norm == 'bn':
580 | self.layer_norm2 = BatchNorm1dTBC(c)
581 | self.ffn = TransformerFFNLayer(
582 | c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
583 |
584 | def forward(self, x, encoder_padding_mask=None, **kwargs):
585 | layer_norm_training = kwargs.get('layer_norm_training', None)
586 | if layer_norm_training is not None:
587 | self.layer_norm1.training = layer_norm_training
588 | self.layer_norm2.training = layer_norm_training
589 | if self.num_heads > 0:
590 | residual = x
591 | x = self.layer_norm1(x)
592 | x, _, = self.self_attn(
593 | query=x,
594 | key=x,
595 | value=x,
596 | key_padding_mask=encoder_padding_mask
597 | )
598 | x = F.dropout(x, self.dropout, training=self.training)
599 | x = residual + x
600 | x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
601 |
602 | residual = x
603 | x = self.layer_norm2(x)
604 | x = self.ffn(x)
605 | x = F.dropout(x, self.dropout, training=self.training)
606 | x = residual + x
607 | x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
608 | return x
609 |
610 |
611 | class DecSALayer(nn.Module):
612 | def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
613 | super().__init__()
614 | self.c = c
615 | self.dropout = dropout
616 | self.layer_norm1 = LayerNorm(c)
617 | self.self_attn = MultiheadAttention(
618 | c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
619 | )
620 | self.layer_norm2 = LayerNorm(c)
621 | self.encoder_attn = MultiheadAttention(
622 | c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
623 | )
624 | self.layer_norm3 = LayerNorm(c)
625 | self.ffn = TransformerFFNLayer(
626 | c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
627 |
628 | def forward(
629 | self,
630 | x,
631 | encoder_out=None,
632 | encoder_padding_mask=None,
633 | incremental_state=None,
634 | self_attn_mask=None,
635 | self_attn_padding_mask=None,
636 | attn_out=None,
637 | reset_attn_weight=None,
638 | **kwargs,
639 | ):
640 | layer_norm_training = kwargs.get('layer_norm_training', None)
641 | if layer_norm_training is not None:
642 | self.layer_norm1.training = layer_norm_training
643 | self.layer_norm2.training = layer_norm_training
644 | self.layer_norm3.training = layer_norm_training
645 | residual = x
646 | x = self.layer_norm1(x)
647 | x, _ = self.self_attn(
648 | query=x,
649 | key=x,
650 | value=x,
651 | key_padding_mask=self_attn_padding_mask,
652 | incremental_state=incremental_state,
653 | attn_mask=self_attn_mask
654 | )
655 | x = F.dropout(x, self.dropout, training=self.training)
656 | x = residual + x
657 |
658 | residual = x
659 | x = self.layer_norm2(x)
660 | if encoder_out is not None:
661 | x, attn = self.encoder_attn(
662 | query=x,
663 | key=encoder_out,
664 | value=encoder_out,
665 | key_padding_mask=encoder_padding_mask,
666 | incremental_state=incremental_state,
667 | static_kv=True,
668 | enc_dec_attn_constraint_mask=None, #utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
669 | reset_attn_weight=reset_attn_weight
670 | )
671 | attn_logits = attn[1]
672 | else:
673 | assert attn_out is not None
674 | x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
675 | attn_logits = None
676 | x = F.dropout(x, self.dropout, training=self.training)
677 | x = residual + x
678 |
679 | residual = x
680 | x = self.layer_norm3(x)
681 | x = self.ffn(x, incremental_state=incremental_state)
682 | x = F.dropout(x, self.dropout, training=self.training)
683 | x = residual + x
684 | # if len(attn_logits.size()) > 3:
685 | # indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
686 | # attn_logits = attn_logits.gather(1,
687 | # indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
688 | return x, attn_logits
689 |
690 | class PitchPredictor(torch.nn.Module):
691 | def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
692 | dropout_rate=0.1, padding='SAME'):
693 |         """Initialize pitch predictor module.
694 | Args:
695 | idim (int): Input dimension.
696 | n_layers (int, optional): Number of convolutional layers.
697 | n_chans (int, optional): Number of channels of convolutional layers.
698 | kernel_size (int, optional): Kernel size of convolutional layers.
699 | dropout_rate (float, optional): Dropout rate.
700 | """
701 | super(PitchPredictor, self).__init__()
702 | self.conv = torch.nn.ModuleList()
703 | self.kernel_size = kernel_size
704 | self.padding = padding
705 | for idx in range(n_layers):
706 | in_chans = idim if idx == 0 else n_chans
707 | self.conv += [torch.nn.Sequential(
708 | torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
709 | if padding == 'SAME'
710 | else (kernel_size - 1, 0), 0),
711 | torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
712 | torch.nn.ReLU(),
713 | LayerNorm(n_chans, dim=1),
714 | torch.nn.Dropout(dropout_rate)
715 | )]
716 | self.linear = torch.nn.Linear(n_chans, odim)
717 | self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
718 | self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
719 |
720 | def forward(self, xs):
721 | """
722 |
723 | :param xs: [B, T, H]
724 | :return: [B, T, H]
725 | """
726 | positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0])
727 | xs = xs + positions
728 | xs = xs.transpose(1, -1) # (B, idim, Tmax)
729 | for f in self.conv:
730 | xs = f(xs) # (B, C, Tmax)
731 | # NOTE: calculate in log domain
732 | xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H)
733 | return xs
734 |
735 |
736 | class Prenet(nn.Module):
737 | def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
738 | super(Prenet, self).__init__()
739 | padding = kernel // 2
740 | self.layers = []
741 | self.strides = strides if strides is not None else [1] * n_layers
742 | for l in range(n_layers):
743 | self.layers.append(nn.Sequential(
744 | nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
745 | nn.ReLU(),
746 | nn.BatchNorm1d(out_dim)
747 | ))
748 | in_dim = out_dim
749 | self.layers = nn.ModuleList(self.layers)
750 | self.out_proj = nn.Linear(out_dim, out_dim)
751 |
752 | def forward(self, x):
753 | """
754 |
755 | :param x: [B, T, 80]
756 | :return: [L, B, T, H], [B, T, H]
757 | """
758 | padding_mask = x.abs().sum(-1).eq(0).data # [B, T]
759 | nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T]
760 | x = x.transpose(1, 2)
761 | hiddens = []
762 | for i, l in enumerate(self.layers):
763 | nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
764 | x = l(x) * nonpadding_mask_TB
765 | hiddens.append(x)
766 | hiddens = torch.stack(hiddens, 0) # [L, B, H, T]
767 | hiddens = hiddens.transpose(2, 3) # [L, B, T, H]
768 | x = self.out_proj(x.transpose(1, 2)) # [B, T, H]
769 | x = x * nonpadding_mask_TB.transpose(1, 2)
770 | return hiddens, x
771 |
772 |
773 | class ConvBlock(nn.Module):
774 | def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
775 | super().__init__()
776 | self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
777 | self.norm = norm
778 | if self.norm == 'bn':
779 | self.norm = nn.BatchNorm1d(n_chans)
780 | elif self.norm == 'in':
781 | self.norm = nn.InstanceNorm1d(n_chans, affine=True)
782 | elif self.norm == 'gn':
783 | self.norm = nn.GroupNorm(n_chans // 16, n_chans)
784 | elif self.norm == 'ln':
785 | self.norm = LayerNorm(n_chans // 16, n_chans)
786 | elif self.norm == 'wn':
787 | self.conv = torch.nn.utils.weight_norm(self.conv.conv)
788 | self.dropout = nn.Dropout(dropout)
789 | self.relu = nn.ReLU()
790 |
791 | def forward(self, x):
792 | """
793 |
794 | :param x: [B, C, T]
795 | :return: [B, C, T]
796 | """
797 | x = self.conv(x)
798 | if not isinstance(self.norm, str):
799 | if self.norm == 'none':
800 | pass
801 | elif self.norm == 'ln':
802 | x = self.norm(x.transpose(1, 2)).transpose(1, 2)
803 | else:
804 | x = self.norm(x)
805 | x = self.relu(x)
806 | x = self.dropout(x)
807 | return x
808 |
809 |
810 | class ConvStacks(nn.Module):
811 | def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
812 | dropout=0, strides=None, res=True):
813 | super().__init__()
814 | self.conv = torch.nn.ModuleList()
815 | self.kernel_size = kernel_size
816 | self.res = res
817 | self.in_proj = Linear(idim, n_chans)
818 | if strides is None:
819 | strides = [1] * n_layers
820 | else:
821 | assert len(strides) == n_layers
822 | for idx in range(n_layers):
823 | self.conv.append(ConvBlock(
824 | n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
825 | self.out_proj = Linear(n_chans, odim)
826 |
827 | def forward(self, x, return_hiddens=False):
828 | """
829 |
830 | :param x: [B, T, H]
831 | :return: [B, T, H]
832 | """
833 | x = self.in_proj(x)
834 | x = x.transpose(1, -1) # (B, idim, Tmax)
835 | hiddens = []
836 | for f in self.conv:
837 | x_ = f(x)
838 | x = x + x_ if self.res else x_ # (B, C, Tmax)
839 | hiddens.append(x)
840 | x = x.transpose(1, -1)
841 | x = self.out_proj(x) # (B, Tmax, H)
842 | if return_hiddens:
843 | hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
844 | return x, hiddens
845 | return x
846 |
847 |
848 | class PitchExtractor(nn.Module):
849 | def __init__(self, n_mel_bins=80, conv_layers=2):
850 | super().__init__()
851 | self.hidden_size = 256
852 | self.predictor_hidden = self.hidden_size
853 | self.conv_layers = conv_layers
854 |
855 | self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1])
856 | if self.conv_layers > 0:
857 | self.mel_encoder = ConvStacks(
858 | idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers)
859 | self.pitch_predictor = PitchPredictor(
860 | self.hidden_size, n_chans=self.predictor_hidden,
861 | n_layers=5, dropout_rate=0.5, odim=2,
862 | padding='SAME', kernel_size=5)
863 |
864 | def forward(self, mel_input=None):
865 | ret = {}
866 | mel_hidden = self.mel_prenet(mel_input)[1]
867 | if self.conv_layers > 0:
868 | mel_hidden = self.mel_encoder(mel_hidden)
869 |
870 | ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden)
871 |
872 | # pitch_padding = mel_input.abs().sum(-1) == 0
873 | # use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv']
874 |
875 | # ret['f0_denorm_pred'] = denorm_f0(
876 | # pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None,
877 | # hparams, pitch_padding=pitch_padding)
878 | return ret
--------------------------------------------------------------------------------
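
A minimal smoke test for PitchExtractor, assuming this repo's utils.make_positions is importable (SinusoidalPositionalEmbedding above depends on it). Shapes follow the docstrings: [B, T, 80] mel frames in, [B, T, 2] pitch/UV predictions out:

    import torch
    from pitch_extractor import PitchExtractor

    model = PitchExtractor(n_mel_bins=80, conv_layers=2)
    mel = torch.randn(2, 100, 80)        # batch of 2 clips, 100 mel frames each
    out = model(mel_input=mel)
    print(out['pitch_pred'].shape)       # expected: torch.Size([2, 100, 2])
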
/preparation_slice.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from pydub import AudioSegment
4 | from pydub.silence import split_on_silence
5 | from pydub.utils import make_chunks
6 | import argparse
7 | from multiprocessing import Pool
8 | from glob import glob
9 |
10 | def process(filename,wavformat,size):
11 |     songname = os.path.splitext(filename.split('/')[-1])[0]  # str.strip() removes characters, not a suffix, so use splitext
12 | singer = filename.split('/')[-2]
13 | slice_name = './dataset_raw/'+singer+'/'+songname
14 |
15 | if not os.path.exists('./dataset_raw/'+singer):
16 | os.mkdir('./dataset_raw/'+singer)
17 |
18 | # Removing the silent parts
19 | audio_segment = AudioSegment.from_file(filename) # Loading the audio
20 | list_split_on_silence = split_on_silence(
21 | audio_segment, min_silence_len=600,
22 | silence_thresh=-40,
23 | keep_silence=400)
24 |     merged = audio_segment[:1]
25 |     for i, chunk in enumerate(list_split_on_silence):
26 |         merged = merged + chunk
27 |
28 |     # Slicing
29 |     chunks = make_chunks(merged, size)
30 |
31 | for i, chunk in enumerate(chunks):
32 | chunk_name=slice_name+"_{0}.wav".format(i)
33 | if not os.path.exists(chunk_name):
34 | #logger1.info(chunk_name)
35 | chunk.export(chunk_name, format="wav")
36 |
37 |
38 |
39 | if __name__ == '__main__':
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument("-w","--wavformat", type=str, default="wav", help="the wavformat of original data")
42 | parser.add_argument("-s","--size", type=int, default=10000, help="the length of audio slices")
43 | args = parser.parse_args()
44 | wavformat = args.wavformat
45 | size = args.size
46 | files=glob('./dataset_slice/*/*.'+wavformat)
47 |
48 | for file in files:
49 | process(file,wavformat,size)
50 |
--------------------------------------------------------------------------------
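
The script is normally run over ./dataset_slice/<singer>/*.<wavformat> via its argparse entry point, but a single file can also be processed directly. The paths below are hypothetical, and ./dataset_raw must already exist:

    from preparation_slice import process

    # removes silence, then exports 10-second chunks to ./dataset_raw/singer0/
    process('./dataset_slice/singer0/song0.wav', wavformat='wav', size=10000)
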
/preprocessing1_resample.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import os,tqdm
3 | import multiprocessing as mp
4 | import soundfile as sf
5 | from glob import glob
6 | import argparse
7 |
8 | def resample_one(filename):
9 | singer=filename.split('/')[-2]
10 | songname=filename.split('/')[-1]
11 | output_path='dataset/'+singer+'/'+songname
12 | if os.path.exists(output_path):
13 | return
14 | wav, sr = librosa.load(filename, sr=24000)
15 | # normalize the volume
16 | wav = wav / (0.00001+max(abs(wav)))*0.95
17 | # write to file using soundfile
18 |     try:
19 |         sf.write(output_path, wav, 24000)
20 |     except Exception as e:
21 |         print("Error writing file", output_path, e)
22 |         return
23 |
24 | def mkdir_func(input_path):
25 | singer=input_path.split('/')[-2]
26 | out_dir = 'dataset/'+singer
27 | if not os.path.exists(out_dir):
28 | os.makedirs(out_dir)
29 |
30 | def resample_parallel(num_process):
31 | input_paths = glob('dataset_raw/*/*.wav')
32 | print("input_paths",len(input_paths))
33 | # multiprocessing with progress bar
34 | pool = mp.Pool(num_process)
35 | for _ in tqdm.tqdm(pool.imap_unordered(resample_one, input_paths), total=len(input_paths)):
36 | pass
37 |
38 | def path_parallel():
39 | input_paths = glob('dataset_raw/*/*.wav')
40 |     input_paths = list(set(input_paths))  # deduplicate paths
41 | print("input_paths",len(input_paths))
42 | for input_path in input_paths:
43 | mkdir_func(input_path)
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument("-n","--num_process", type=int, default=5, help="the number of process")
48 | args = parser.parse_args()
49 | num_process = args.num_process
50 | path_parallel()
51 | resample_parallel(num_process)
52 |
--------------------------------------------------------------------------------
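
The same per-file path can be exercised outside the multiprocessing pool; the file path below is hypothetical:

    from preprocessing1_resample import mkdir_func, resample_one

    path = 'dataset_raw/singer0/clip_0.wav'
    mkdir_func(path)       # make sure dataset/singer0 exists
    resample_one(path)     # load at 24 kHz, peak-normalize, write dataset/singer0/clip_0.wav
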
/preprocessing2_flist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import wave
5 | from random import shuffle
6 |
7 | from loguru import logger
8 | from tqdm import tqdm
9 |
10 | import utils
11 |
12 | pattern = re.compile(r'^[\.a-zA-Z0-9_\/]+$')
13 |
14 | def get_wav_duration(file_path):
15 | try:
16 | with wave.open(file_path, 'rb') as wav_file:
17 | # The number of frames
18 | n_frames = wav_file.getnframes()
19 | # The sampling rate
20 | framerate = wav_file.getframerate()
21 | # The duration
22 | return n_frames / float(framerate)
23 | except Exception as e:
24 |         logger.error(f"Failed to read {file_path}")
25 | raise e
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
30 | parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
31 | parser.add_argument("--source_dir", type=str, default="./dataset", help="path to source dir")
32 | args = parser.parse_args()
33 |
34 | train = []
35 | val = []
36 | idx = 0
37 | spk_dict = {}
38 | spk_id = 0
39 |
40 |
41 | for speaker in tqdm(os.listdir(args.source_dir)):
42 | spk_dict[speaker] = spk_id
43 | spk_id += 1
44 | wavs = []
45 |
46 | for file_name in os.listdir(os.path.join(args.source_dir, speaker)):
47 | if not file_name.endswith("wav"):
48 | continue
49 | if file_name.startswith("."):
50 | continue
51 |
52 | file_path = "/".join([args.source_dir, speaker, file_name])
53 |
54 | if get_wav_duration(file_path) < 0.3:
55 | logger.info("Skip too short audio: " + file_path)
56 | continue
57 |
58 | wavs.append(file_path)
59 |
60 | shuffle(wavs)
61 | train += wavs[2:]
62 | val += wavs[:2]
63 |
64 | shuffle(train)
65 | shuffle(val)
66 |
67 | logger.info("Writing " + args.train_list)
68 | with open(args.train_list, "w") as f:
69 | for fname in tqdm(train):
70 | wavpath = fname
71 | f.write(wavpath + "\n")
72 |
73 | logger.info("Writing " + args.val_list)
74 | with open(args.val_list, "w") as f:
75 | for fname in tqdm(val):
76 | wavpath = fname
77 | f.write(wavpath + "\n")
78 |
79 |
80 | d_config_template = utils.load_config("configs_template/diffusion_template.yaml")
81 | d_config_template["model"]["n_spk"] = spk_id
82 |
83 | d_config_template["spk"] = spk_dict
84 |
85 |
86 | logger.info("Writing to configs/diffusion.yaml")
87 | utils.save_config("configs/diffusion.yaml",d_config_template)
88 |
--------------------------------------------------------------------------------
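
get_wav_duration is what implements the 0.3-second filter above; it can also be called on its own (hypothetical path):

    from preprocessing2_flist import get_wav_duration

    print(get_wav_duration('dataset/singer0/clip_0.wav'))  # duration in seconds
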
/preprocessing3_feature.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | from concurrent.futures import ProcessPoolExecutor
6 | from glob import glob
7 | from random import shuffle
8 |
9 | import librosa
10 | import numpy as np
11 | import torch
12 | import torch.multiprocessing as mp
13 | from loguru import logger
14 | from tqdm import tqdm
15 |
16 | import utils
17 | from Vocoder import Vocoder
18 | from mel_processing import spectrogram_torch
19 |
20 | logging.getLogger("numba").setLevel(logging.WARNING)
21 | logging.getLogger("matplotlib").setLevel(logging.WARNING)
22 |
23 | dconfig = utils.load_config("./configs/diffusion.yaml")
24 |
25 |
26 | def process_one(filename, hmodel, device, hop_length, sampling_rate, filter_length, win_length):
27 | wav, sr = librosa.load(filename, sr=sampling_rate)
28 | audio_norm = torch.FloatTensor(wav)
29 | audio_norm = audio_norm.unsqueeze(0)
30 | soft_path = filename + ".soft.pt"
31 | if not os.path.exists(soft_path):
32 | wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
33 | wav16k = torch.from_numpy(wav16k).to(device)
34 | c = hmodel.encoder(wav16k) # extract content from the pre-trained model
35 | torch.save(c.cpu(), soft_path)
36 |
37 | f0_path = filename + ".f0.npy"
38 | if not os.path.exists(f0_path):
39 | from Features import DioF0Predictor
40 | f0_predictor= DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
41 | f0,uv = f0_predictor.compute_f0_uv(wav)
42 | np.save(f0_path, np.asanyarray((f0,uv),dtype=object))
43 |
44 |
45 | spec_path = filename.replace(".wav", ".spec.pt")
46 | if not os.path.exists(spec_path):
47 | # Process spectrogram
48 | # The following code can't be replaced by torch.FloatTensor(wav)
49 |         # because load_wav_to_torch returns a tensor that needs to be normalized
50 | if sr != sampling_rate:
51 | raise ValueError("{} SR doesn't match target {} SR".format(sr,sampling_rate))
52 | spec = spectrogram_torch(
53 | audio_norm,
54 | filter_length,
55 | sampling_rate,
56 | hop_length,
57 | win_length,
58 | center=False,
59 | )
60 | spec = torch.squeeze(spec, 0)
61 | torch.save(spec, spec_path)
62 |
63 |
64 |     volume_path = filename + ".vol.npy"
65 |     volume_extractor = utils.Volume_Extractor(hop_length)  # needed below for aug_vol even if .vol.npy already exists
66 |     if not os.path.exists(volume_path):
67 |         volume = volume_extractor.extract(audio_norm)
68 |         np.save(volume_path, volume.to('cpu').numpy())
69 |
70 |
71 | mel_path = filename + ".mel.npy"
72 | mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
73 |
74 | if not os.path.exists(mel_path) and mel_extractor is not None:
75 | mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate)
76 | mel = mel_t.squeeze().to('cpu').numpy()
77 | np.save(mel_path, mel)
78 |
79 | max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5
80 | max_shift = min(1, np.log10(1/max_amp))
81 | log10_vol_shift = random.uniform(-1, max_shift)
82 | keyshift = random.uniform(-5, 5)
83 |
84 | aug_mel_path = filename + ".aug_mel.npy"
85 | if not os.path.exists(aug_mel_path):
86 | aug_mel_t = mel_extractor.extract(audio_norm * (10 ** log10_vol_shift), sampling_rate, keyshift = keyshift)
87 | aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
88 | np.save(aug_mel_path,np.asanyarray((aug_mel,keyshift),dtype=object))
89 |
90 | aug_vol_path = filename + ".aug_vol.npy"
91 | if not os.path.exists(aug_vol_path):
92 | aug_vol = volume_extractor.extract(audio_norm * (10 ** log10_vol_shift))
93 | np.save(aug_vol_path,aug_vol.to('cpu').numpy())
94 |
95 |
96 | def process_batch(file_chunk, hop_length, sampling_rate, filter_length, win_length, device="cpu"):
97 | logger.info("Loading speech encoder for content...")
98 | rank = mp.current_process()._identity
99 | rank = rank[0] if len(rank) > 0 else 0
100 | if torch.cuda.is_available():
101 | gpu_id = rank % torch.cuda.device_count()
102 | device = torch.device(f"cuda:{gpu_id}")
103 | logger.info(f"Rank {rank} uses device {device}")
104 | from Features import ContentVec768L12
105 | hmodel = ContentVec768L12(device = device)
106 | logger.info(f"Loaded speech encoder for rank {rank}")
107 | for filename in tqdm(file_chunk, position = rank):
108 | process_one(filename, hmodel, device, hop_length, sampling_rate, filter_length, win_length)
109 |
110 | def parallel_process(filenames, num_processes, hop_length, sampling_rate, filter_length, win_length, device):
111 | with ProcessPoolExecutor(max_workers=num_processes) as executor:
112 | tasks = []
113 | for i in range(num_processes):
114 | start = int(i * len(filenames) / num_processes)
115 | end = int((i + 1) * len(filenames) / num_processes)
116 | file_chunk = filenames[start:end]
117 | tasks.append(executor.submit(process_batch, file_chunk, hop_length, sampling_rate, filter_length, win_length,device=device))
118 | for task in tqdm(tasks, position = 0):
119 | task.result()
120 |
121 | if __name__ == "__main__":
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('-c',"--config", type=str, default='configs/diffusion.yaml', help="path to input dir")
124 | parser.add_argument(
125 | '-n','--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores')
126 | args = parser.parse_args()
127 |
128 | dconfig = utils.load_config(args.config)
129 |
130 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
131 | logger.info("Using device: " + str(device))
132 | print("Loading Mel Extractor...",dconfig.vocoder.type)
133 | # mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device=device)
134 |
135 | filenames = glob("./dataset/*/*.wav", recursive=True) # [:10]
136 | shuffle(filenames)
137 | mp.set_start_method("spawn", force=True)
138 |
139 | num_processes = args.num_processes
140 | if num_processes == 0:
141 | num_processes = os.cpu_count()
142 |
143 | parallel_process(filenames, num_processes, dconfig.data.hop_length, dconfig.data.sampling_rate, dconfig.data.filter_length, dconfig.data.win_length, device)
--------------------------------------------------------------------------------
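
After process_one runs, each dataset/<spk>/<name>.wav is expected to gain sidecar files (.soft.pt, .f0.npy, .vol.npy, .mel.npy, .aug_mel.npy, .aug_vol.npy, plus <name>.spec.pt, with the exact paths constructed above). A hedged sketch for reading one of them back, using a hypothetical path:

    import numpy as np

    # .f0.npy stores an object array holding the (f0, uv) pair saved by process_one;
    # both are assumed to be 1-D per-frame numpy arrays
    f0, uv = np.load('dataset/singer0/clip_0.wav.f0.npy', allow_pickle=True)
    print(f0.shape, uv.shape)
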
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.5
2 | pyworld
3 | scipy==1.10.0
4 | SoundFile==0.12.1
5 | librosa==0.9.1
6 | torch==2.1.0
7 | fairseq==0.12.2
8 | torchaudio
9 | tqdm
10 | loguru
11 | pydub
12 | matplotlib
13 | tensorboard
14 | tensorboardX
15 | pyyaml
--------------------------------------------------------------------------------
/saver.py:
--------------------------------------------------------------------------------
1 | '''
2 | author: wayn391@mastertones
3 | '''
4 |
5 | import datetime
6 | import os
7 | import time
8 |
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import yaml
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 |
15 | class Saver(object):
16 | def __init__(
17 | self,
18 | args,
19 | save_dir,
20 | initial_global_step=-1):
21 |
22 | self.expdir = save_dir
23 | self.sample_rate = args.data.sampling_rate
24 |
25 | # cold start
26 | self.global_step = initial_global_step
27 | self.init_time = time.time()
28 | self.last_time = time.time()
29 |
30 | # makedirs
31 | os.makedirs(self.expdir, exist_ok=True)
32 |
33 | # path
34 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt')
35 |
36 | # ckpt
37 | os.makedirs(self.expdir, exist_ok=True)
38 |
39 | # writer
40 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))
41 |
42 | # save config
43 | path_config = os.path.join(self.expdir, 'config.yaml')
44 | with open(path_config, "w") as out_config:
45 | yaml.dump(dict(args), out_config)
46 |
47 |
48 | def log_info(self, msg):
49 | '''log method'''
50 | if isinstance(msg, dict):
51 | msg_list = []
52 | for k, v in msg.items():
53 | tmp_str = ''
54 | if isinstance(v, int):
55 | tmp_str = '{}: {:,}'.format(k, v)
56 | else:
57 | tmp_str = '{}: {}'.format(k, v)
58 |
59 | msg_list.append(tmp_str)
60 | msg_str = '\n'.join(msg_list)
61 | else:
62 | msg_str = msg
63 |
64 |         # display
65 | print(msg_str)
66 |
67 | # save
68 | with open(self.path_log_info, 'a') as fp:
69 | fp.write(msg_str+'\n')
70 |
71 | def log_value(self, dict):
72 | for k, v in dict.items():
73 | self.writer.add_scalar(k, v, self.global_step)
74 |
75 | def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5):
76 | spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1)
77 | spec = spec_cat[0]
78 | if isinstance(spec, torch.Tensor):
79 | spec = spec.cpu().numpy()
80 | fig = plt.figure(figsize=(12, 9))
81 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
82 | plt.tight_layout()
83 | self.writer.add_figure(name, fig, self.global_step)
84 |
85 | def log_audio(self, dict):
86 | for k, v in dict.items():
87 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)
88 |
89 | def get_interval_time(self, update=True):
90 | cur_time = time.time()
91 | time_interval = cur_time - self.last_time
92 | if update:
93 | self.last_time = cur_time
94 | return time_interval
95 |
96 | def get_total_time(self, to_str=True):
97 | total_time = time.time() - self.init_time
98 | if to_str:
99 | total_time = str(datetime.timedelta(
100 | seconds=total_time))[:-5]
101 | return total_time
102 |
103 | def save_model(
104 | self,
105 | model,
106 | optimizer,
107 | name='model',
108 | postfix='',
109 | to_json=False):
110 | # path
111 | if postfix:
112 | postfix = '_' + postfix
113 | path_pt = os.path.join(
114 | self.expdir , name+postfix+'.pt')
115 |
116 | # check
117 | print(' [*] model checkpoint saved: {}'.format(path_pt))
118 |
119 | # save
120 | if optimizer is not None:
121 | torch.save({
122 | 'global_step': self.global_step,
123 | 'model': model.state_dict(),
124 | 'optimizer': optimizer.state_dict()}, path_pt)
125 | else:
126 | torch.save({
127 | 'global_step': self.global_step,
128 | 'model': model.state_dict()}, path_pt)
129 |
130 |
131 | def delete_model(self, name='model', postfix=''):
132 | # path
133 | if postfix:
134 | postfix = '_' + postfix
135 | path_pt = os.path.join(
136 | self.expdir , name+postfix+'.pt')
137 |
138 | # delete
139 | if os.path.exists(path_pt):
140 | os.remove(path_pt)
141 | print(' [*] model checkpoint deleted: {}'.format(path_pt))
142 |
143 | def global_step_increment(self):
144 | self.global_step += 1
145 |
146 |
147 |
--------------------------------------------------------------------------------
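
A minimal sketch of how Saver might be driven, assuming a DotDict-style config that exposes data.sampling_rate (the experiment directory below is hypothetical):

    from ComoSVC import DotDict
    from saver import Saver

    args = DotDict({'data': {'sampling_rate': 24000}})
    saver = Saver(args, save_dir='logs/demo_exp', initial_global_step=0)
    saver.log_info({'epoch': 0, 'note': 'smoke test'})   # printed and appended to log_info.txt
    saver.log_value({'train/loss': 1.23})                # written to TensorBoard
    saver.global_step_increment()
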
/slicer.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import torch
3 | import torchaudio
4 |
5 |
6 | class Slicer:
7 | def __init__(self,
8 | sr: int,
9 | threshold: float = -40.,
10 | min_length: int = 5000,
11 | min_interval: int = 300,
12 | hop_size: int = 20,
13 | max_sil_kept: int = 5000):
14 | if not min_length >= min_interval >= hop_size:
15 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
16 | if not max_sil_kept >= hop_size:
17 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
18 | min_interval = sr * min_interval / 1000
19 | self.threshold = 10 ** (threshold / 20.)
20 | self.hop_size = round(sr * hop_size / 1000)
21 | self.win_size = min(round(min_interval), 4 * self.hop_size)
22 | self.min_length = round(sr * min_length / 1000 / self.hop_size)
23 | self.min_interval = round(min_interval / self.hop_size)
24 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
25 |
26 | def _apply_slice(self, waveform, begin, end):
27 | if len(waveform.shape) > 1:
28 | return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
29 | else:
30 | return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
31 |
32 | # @timeit
33 | def slice(self, waveform):
34 | if len(waveform.shape) > 1:
35 | samples = librosa.to_mono(waveform)
36 | else:
37 | samples = waveform
38 | if samples.shape[0] <= self.min_length:
39 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
40 | rms_list = librosa.feature.rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
41 | sil_tags = []
42 | silence_start = None
43 | clip_start = 0
44 | for i, rms in enumerate(rms_list):
45 | # Keep looping while frame is silent.
46 | if rms < self.threshold:
47 | # Record start of silent frames.
48 | if silence_start is None:
49 | silence_start = i
50 | continue
51 | # Keep looping while frame is not silent and silence start has not been recorded.
52 | if silence_start is None:
53 | continue
54 | # Clear recorded silence start if interval is not enough or clip is too short
55 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept
56 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
57 | if not is_leading_silence and not need_slice_middle:
58 | silence_start = None
59 | continue
60 | # Need slicing. Record the range of silent frames to be removed.
61 | if i - silence_start <= self.max_sil_kept:
62 | pos = rms_list[silence_start: i + 1].argmin() + silence_start
63 | if silence_start == 0:
64 | sil_tags.append((0, pos))
65 | else:
66 | sil_tags.append((pos, pos))
67 | clip_start = pos
68 | elif i - silence_start <= self.max_sil_kept * 2:
69 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
70 | pos += i - self.max_sil_kept
71 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
72 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
73 | if silence_start == 0:
74 | sil_tags.append((0, pos_r))
75 | clip_start = pos_r
76 | else:
77 | sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
78 | clip_start = max(pos_r, pos)
79 | else:
80 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
81 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
82 | if silence_start == 0:
83 | sil_tags.append((0, pos_r))
84 | else:
85 | sil_tags.append((pos_l, pos_r))
86 | clip_start = pos_r
87 | silence_start = None
88 | # Deal with trailing silence.
89 | total_frames = rms_list.shape[0]
90 | if silence_start is not None and total_frames - silence_start >= self.min_interval:
91 | silence_end = min(total_frames, silence_start + self.max_sil_kept)
92 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
93 | sil_tags.append((pos, total_frames + 1))
94 | # Apply and return slices.
95 | if len(sil_tags) == 0:
96 | return {"0": {"slice": False, "split_time": f"0,{len(waveform)}"}}
97 | else:
98 | chunks = []
99 |             # If the first silence does not start at the very beginning, prepend the leading voiced segment
100 | if sil_tags[0][0]:
101 | chunks.append(
102 | {"slice": False, "split_time": f"0,{min(waveform.shape[0], sil_tags[0][0] * self.hop_size)}"})
103 | for i in range(0, len(sil_tags)):
104 |                 # Mark the voiced segment between the previous and current silence (skipped for the first silence tag)
105 | if i:
106 | chunks.append({"slice": False,
107 | "split_time": f"{sil_tags[i - 1][1] * self.hop_size},{min(waveform.shape[0], sil_tags[i][0] * self.hop_size)}"})
108 |                 # Mark every silence segment
109 | chunks.append({"slice": True,
110 | "split_time": f"{sil_tags[i][0] * self.hop_size},{min(waveform.shape[0], sil_tags[i][1] * self.hop_size)}"})
111 |             # If the last silence does not reach the end, append the trailing voiced segment
112 | if sil_tags[-1][1] * self.hop_size < len(waveform):
113 | chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1] * self.hop_size},{len(waveform)}"})
114 | chunk_dict = {}
115 | for i in range(len(chunks)):
116 | chunk_dict[str(i)] = chunks[i]
117 | return chunk_dict
118 |
119 |
120 | def cut(audio_path, db_thresh=-30, min_len=5000):
121 | audio, sr = librosa.load(audio_path, sr=None)
122 | slicer = Slicer(
123 | sr=sr,
124 | threshold=db_thresh,
125 | min_length=min_len
126 | )
127 | chunks = slicer.slice(audio)
128 | return chunks
129 |
130 |
131 | def chunks2audio(audio_path, chunks):
132 | chunks = dict(chunks)
133 | audio, sr = torchaudio.load(audio_path)
134 |     if len(audio.shape) == 2 and audio.shape[0] >= 2:  # torchaudio returns [channels, samples]; downmix multi-channel audio
135 | audio = torch.mean(audio, dim=0).unsqueeze(0)
136 | audio = audio.cpu().numpy()[0]
137 | result = []
138 | for k, v in chunks.items():
139 | tag = v["split_time"].split(",")
140 | if tag[0] != tag[1]:
141 | result.append((v["slice"], audio[int(tag[0]):int(tag[1])]))
142 | return result, sr
143 |
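144 | # Example (editor's sketch, not part of the original file): one way the two helpers
145 | # above could be chained to split a recording, assuming soundfile is installed and
146 | # "raw/song.wav" / "dataset_slice/" are hypothetical paths.
147 | #
148 | #   import soundfile
149 | #   chunks = cut("raw/song.wav", db_thresh=-30, min_len=5000)
150 | #   segments, sr = chunks2audio("raw/song.wav", chunks)
151 | #   for idx, (is_silence, data) in enumerate(segments):
152 | #       if not is_silence:
153 | #           soundfile.write(f"dataset_slice/song_{idx}.wav", data, sr)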
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import librosa
4 | import numpy as np
5 | import torch
6 | from torch import autocast
7 | from torch.cuda.amp import GradScaler
8 |
9 | import utils
10 | from saver import Saver
11 |
12 |
13 | def test(args, model, vocoder, loader_test, saver):
14 | print(' [*] testing...')
15 | model.eval()
16 |
17 | # losses
18 | test_loss = 0.
19 |
20 |     # initialization
21 | num_batches = len(loader_test)
22 | rtf_all = []
23 |
24 | # run
25 | with torch.no_grad():
26 | for bidx, data in enumerate(loader_test):
27 | fn = data['name'][0].split("/")[-1]
28 | speaker = data['name'][0].split("/")[-2]
29 | print('--------')
30 | print('{}/{} - {}'.format(bidx, num_batches, fn))
31 |
32 | # unpack data
33 | for k in data.keys():
34 | if not k.startswith('name'):
35 | data[k] = data[k].to(args.device)
36 | print('>>', data['name'][0])
37 |
38 | # forward
39 | st_time = time.time()
40 | mel = model(
41 | data['units'],
42 | data['f0'],
43 | data['volume'],
44 | data['spk_id'],
45 | gt_spec= data['mel'],
46 | infer=True
47 | )
48 | signal = vocoder.infer(mel, data['f0'])
49 | ed_time = time.time()
50 |
51 | # RTF
52 | run_time = ed_time - st_time
53 | song_time = signal.shape[-1] / args.data.sampling_rate
54 | rtf = run_time / song_time
55 | print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
56 | rtf_all.append(rtf)
57 |
58 | # loss
59 | for i in range(args.train.batch_size):
60 | loss = model(
61 | data['units'],
62 | data['f0'],
63 | data['volume'],
64 | data['spk_id'],
65 | gt_spec=data['mel'],
66 | infer=False)
67 | if isinstance(loss, list):
68 | test_loss += loss[0].item()
69 | else:
70 | test_loss += loss.item()
71 |
72 | # log mel
73 | saver.log_spec(f"{speaker}_{fn}.wav", data['mel'], mel)
74 |
75 |             # log audio
76 | path_audio = data['name_ext'][0]
77 | audio, sr = librosa.load(path_audio, sr=args.data.sampling_rate)
78 | if len(audio.shape) > 1:
79 | audio = librosa.to_mono(audio)
80 | audio = torch.from_numpy(audio).unsqueeze(0).to(signal)
81 | saver.log_audio({f"{speaker}_{fn}_gt.wav": audio,f"{speaker}_{fn}_pred.wav": signal})
82 | # report
83 | test_loss /= args.train.batch_size
84 | test_loss /= num_batches
85 |
86 | # check
87 | print(' [test_loss] test_loss:', test_loss)
88 | print(' Real Time Factor', np.mean(rtf_all))
89 | return test_loss
90 |
91 |
92 | # def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test):
93 | def train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_test, teacher):
94 |
95 | # saver
96 | if teacher:
97 | save_dir=args.env.expdir
98 | else:
99 | save_dir=args.env.comodir
100 |
101 | saver = Saver(args, save_dir, initial_global_step=initial_global_step)
102 |
103 | # model size
104 | params_count = utils.get_network_paras_amount({'model': model})
105 | saver.log_info('--- model size ---')
106 | saver.log_info(params_count)
107 |
108 | # run
109 | num_batches = len(loader_train)
110 | model.train()
111 | saver.log_info('======= start training =======')
112 | scaler = GradScaler()
113 | if args.train.amp_dtype == 'fp32': # fp32
114 | dtype = torch.float32
115 | elif args.train.amp_dtype == 'fp16':
116 | dtype = torch.float16
117 | elif args.train.amp_dtype == 'bf16':
118 | dtype = torch.bfloat16
119 | else:
120 | raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype)
121 | saver.log_info("epoch|batch_idx/num_batches|output_dir|batch/s|lr|time|step")
122 | for epoch in range(args.train.epochs):
123 | for batch_idx, data in enumerate(loader_train):
124 | saver.global_step_increment()
125 | optimizer.zero_grad()
126 |
127 | # unpack data
128 | for k in data.keys():
129 | if not k.startswith('name'):
130 | data[k] = data[k].to(args.device)
131 |
132 | # forward
133 | if dtype == torch.float32:
134 | loss = model(data['units'].float(), data['f0'], data['volume'], data['spk_id'],
135 | aug_shift = data['aug_shift'], gt_spec=data['mel'].float(), infer=False)
136 | else:
137 | with autocast(device_type=args.device, dtype=dtype):
138 | loss = model(data['units'], data['f0'], data['volume'], data['spk_id'],
139 | aug_shift = data['aug_shift'], gt_spec=data['mel'], infer=False)
140 | if not teacher:
141 | loss=loss*1000
142 | # loss_mel = loss[0]*50
143 | # loss_pitch = loss[1]
144 | # loss = loss_mel +loss_pitch
145 | # loss=loss*1000
146 | # handle nan loss
147 | if torch.isnan(loss):
148 | raise ValueError(' [x] nan loss ')
149 | else:
150 | # backpropagate
151 | if dtype == torch.float32:
152 | loss.backward()
153 | optimizer.step()
154 | else:
155 | scaler.scale(loss).backward()
156 | scaler.step(optimizer)
157 | scaler.update()
158 | scheduler.step()
159 |
160 | # log loss
161 | if saver.global_step % args.train.interval_log == 0:
162 | current_lr = optimizer.param_groups[0]['lr']
163 | saver.log_info(
164 | 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format(
165 | epoch,
166 | batch_idx,
167 | num_batches,
168 | save_dir,
169 | args.train.interval_log/saver.get_interval_time(),
170 | current_lr,
171 | loss.item(),
172 | saver.get_total_time(),
173 | saver.global_step
174 | )
175 | )
176 |
177 | saver.log_value({
178 | 'train/loss': loss.item()
179 | })
180 |
181 | # if not teacher:
182 | # saver.log_value({
183 | # 'train/loss_pitch': loss_pitch.item()
184 | # })
185 | # saver.log_value({
186 | # 'train/loss_mel': loss_mel.item()
187 | # })
188 |
189 | saver.log_value({
190 | 'train/lr': current_lr
191 | })
192 |
193 | # validation
194 | if saver.global_step % args.train.interval_val == 0:
195 | optimizer_save = optimizer if args.train.save_opt else None
196 |
197 | # save latest
198 | saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')
199 | last_val_step = saver.global_step - args.train.interval_val
200 | if last_val_step % args.train.interval_force_save != 0:
201 | saver.delete_model(postfix=f'{last_val_step}')
202 |
203 | # run testing set
204 | test_loss = test(args, model, vocoder, loader_test, saver)
205 |
206 | # log loss
207 | saver.log_info(
208 | ' --- --- \nloss: {:.3f}. '.format(
209 | test_loss,
210 | )
211 | )
212 |
213 | saver.log_value({
214 | 'validation/loss': test_loss
215 | })
216 |
217 | model.train()
218 |
219 |
220 |
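221 | # Editor's note (sketch): the checkpoint-retention rule in the validation branch above
222 | # keeps every model_<step>.pt whose step is a multiple of interval_force_save and deletes
223 | # the other intermediate checkpoints at the next validation. With hypothetical values:
224 | #
225 | #   interval_val, interval_force_save, last_step = 2000, 10000, 20000
226 | #   kept = [s for s in range(interval_val, last_step + 1, interval_val)
227 | #           if s % interval_force_save == 0 or s == last_step]
228 | #   print(kept)  # [10000, 20000] -- only force-saved checkpoints plus the latest survive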
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from loguru import logger
5 | from torch.optim import lr_scheduler
6 |
7 | import os
8 |
9 | from data_loaders import get_data_loaders
10 | import utils
11 | from solver import train
12 | from ComoSVC import ComoSVC
13 | from Vocoder import Vocoder
14 | from utils import load_teacher_model_with_pitch
15 | from utils import traverse_dir
16 |
17 |
18 |
19 |
20 | def parse_args(args=None, namespace=None):
21 | """Parse command-line arguments."""
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument(
24 | "-c",
25 | "--config",
26 | type=str,
27 | default='configs/diffusion.yaml',
28 | help="path to the config file")
29 |
30 | parser.add_argument(
31 | "-t",
32 | "--teacher",
33 | action='store_false',
34 |         help="the teacher model is trained by default; pass this flag to train the student (CoMoSVC) model instead")
35 |
36 | parser.add_argument(
37 | "-s",
38 | "--total_steps",
39 | type=int,
40 | default=1,
41 | help="the number of iterative steps during inference")
42 |
43 | parser.add_argument(
44 | "-p",
45 | "--teacher_model_path",
46 | type=str,
47 | default="logs/teacher/model_800000.pt",
48 | help="path to teacher model")
49 | return parser.parse_args(args=args, namespace=namespace)
50 |
51 |
52 | if __name__ == '__main__':
53 | # parse commands
54 | cmd = parse_args()
55 | # load config
56 | args = utils.load_config(cmd.config)
57 | logger.info(' > config:'+ cmd.config)
58 | # teacher_or_not=cmd.teacher
59 | teacher_model_path=cmd.teacher_model_path
60 | # load vocoder
61 | vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
62 |
63 |
64 | # load model
65 | if cmd.teacher:
66 | model = ComoSVC(
67 | args.data.encoder_out_channels,
68 | args.model.n_spk,
69 | args.model.use_pitch_aug,#true
70 | vocoder.dimension,
71 | args.model.n_layers,
72 | args.model.n_chans,
73 | args.model.n_hidden,
74 | cmd.total_steps,
75 | teacher=cmd.teacher
76 | )
77 |
78 | optimizer = torch.optim.AdamW(model.parameters(),lr=args.train.lr)
79 | initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
80 |
81 | else:
82 | model = ComoSVC(
83 | args.data.encoder_out_channels,
84 | args.model.n_spk,
85 | args.model.use_pitch_aug,
86 | vocoder.dimension,
87 | args.model.n_layers,
88 | args.model.n_chans,
89 | args.model.n_hidden,
90 | cmd.total_steps,
91 | teacher=cmd.teacher
92 | )
93 | model = load_teacher_model_with_pitch(model,checkpoint_dir=cmd.teacher_model_path) # teacher model path
94 |
95 |
96 | # optimizer = torch.optim.AdamW(params=model.decoder.denoise_fn.parameters())
97 | optimizer = torch.optim.AdamW(params=model.decoder.denoise_fn.parameters())
98 | path_pt = traverse_dir(args.env.comodir, ['pt'], is_ext=False)
99 | if len(path_pt)>0:
100 | initial_global_step, model, optimizer = utils.load_model(args.env.comodir, model, optimizer, device=args.device)
101 | else:
102 | initial_global_step = 0
103 |
104 |
105 | if cmd.teacher:
106 | logger.info(f' > The Teacher Model is training now.')
107 | else:
108 | logger.info(f' > The Student Model CoMoSVC is training now.')
109 |
110 |
111 |
112 | for param_group in optimizer.param_groups:
113 | if cmd.teacher:
114 | param_group['initial_lr'] = args.train.lr
115 | param_group['lr'] = args.train.lr * (args.train.gamma ** max(((initial_global_step-2)//args.train.decay_step),0) )
116 | param_group['weight_decay'] = args.train.weight_decay
117 | else:
118 | param_group['initial_lr'] = args.train.comolr
119 | param_group['lr'] = args.train.comolr * (args.train.gamma ** max(((initial_global_step-2)//args.train.decay_step),0) )
120 | param_group['weight_decay'] = args.train.weight_decay
121 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma,last_epoch=initial_global_step-2)
122 |
123 | # device
124 | if args.device == 'cuda':
125 | torch.cuda.set_device(args.env.gpu_id)
126 | model.to(args.device)
127 |
128 | for state in optimizer.state.values():
129 | for k, v in state.items():
130 | if torch.is_tensor(v):
131 | state[k] = v.to(args.device)
132 |
133 | # datas
134 | loader_train, loader_valid = get_data_loaders(args, whole_audio=False)
135 |
136 | train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid, teacher=cmd.teacher)
137 |
138 |
139 |
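140 | # Editor's note (sketch): because "--teacher" uses action='store_false', omitting the flag
141 | # trains the teacher model, while passing "-t" switches to the student (CoMoSVC) stage,
142 | # which distills from the checkpoint given by "-p". Equivalent parse_args calls:
143 | #
144 | #   cmd = parse_args(['-c', 'configs/diffusion.yaml'])        # cmd.teacher == True  -> teacher stage
145 | #   cmd = parse_args(['-c', 'configs/diffusion.yaml', '-t',
146 | #                     '-p', 'logs/teacher/model_800000.pt'])  # cmd.teacher == False -> student stage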
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import yaml
4 | import os
5 | import copy
6 | from torch.nn import functional as F
7 |
8 | def repeat_expand_2d(content, target_len, mode = 'left'):
9 | # content : [h, t]
10 | return repeat_expand_2d_left(content, target_len) if mode == 'left' else repeat_expand_2d_other(content, target_len, mode)
11 |
12 |
13 |
14 | def repeat_expand_2d_left(content, target_len):
15 | # content : [h, t]
16 |
17 | src_len = content.shape[-1]
18 | target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(content.device)
19 | temp = torch.arange(src_len+1) * target_len / src_len
20 | current_pos = 0
21 | for i in range(target_len):
22 | if i < temp[current_pos+1]:
23 | target[:, i] = content[:, current_pos]
24 | else:
25 | current_pos += 1
26 | target[:, i] = content[:, current_pos]
27 |
28 | return target
29 |
30 |
31 | # mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area'
32 | def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
33 | # content : [h, t]
34 | content = content[None,:,:]
35 | target = F.interpolate(content,size=target_len,mode=mode)[0]
36 | return target
37 |
38 |
39 |
40 | def traverse_dir(
41 | root_dir,
42 | extensions,
43 | amount=None,
44 | str_include=None,
45 | str_exclude=None,
46 | is_pure=False,
47 | is_sort=False,
48 | is_ext=True):
49 |
50 | file_list = []
51 | cnt = 0
52 | for root, _, files in os.walk(root_dir):
53 | for file in files:
54 | if any([file.endswith(f".{ext}") for ext in extensions]):
55 | # path
56 | mix_path = os.path.join(root, file)
57 | pure_path = mix_path[len(root_dir)+1:] if is_pure else mix_path
58 |
59 | # amount
60 | if (amount is not None) and (cnt == amount):
61 | if is_sort:
62 | file_list.sort()
63 | return file_list
64 |
65 | # check string
66 | if (str_include is not None) and (str_include not in pure_path):
67 | continue
68 | if (str_exclude is not None) and (str_exclude in pure_path):
69 | continue
70 |
71 | if not is_ext:
72 | ext = pure_path.split('.')[-1]
73 | pure_path = pure_path[:-(len(ext)+1)]
74 | file_list.append(pure_path)
75 | cnt += 1
76 | if is_sort:
77 | file_list.sort()
78 | return file_list
79 |
80 |
81 | class DotDict(dict):
82 | def __getattr__(*args):
83 | val = dict.get(*args)
84 | return DotDict(val) if type(val) is dict else val
85 |
86 | __setattr__ = dict.__setitem__
87 | __delattr__ = dict.__delitem__
88 |
89 | def load_config(path_config):
90 | with open(path_config, "r") as config:
91 | args = yaml.safe_load(config)
92 | args = DotDict(args)
93 | return args
94 |
95 |
96 | def save_config(path_config,config):
97 | config = dict(config)
98 | with open(path_config, "w") as f:
99 | yaml.dump(config, f)
100 |
101 |
102 | class HParams():
103 | def __init__(self, **kwargs):
104 | for k, v in kwargs.items():
105 | if type(v) == dict:
106 | v = HParams(**v)
107 | self[k] = v
108 |
109 | def keys(self):
110 | return self.__dict__.keys()
111 |
112 | def items(self):
113 | return self.__dict__.items()
114 |
115 | def values(self):
116 | return self.__dict__.values()
117 |
118 | def __len__(self):
119 | return len(self.__dict__)
120 |
121 | def __getitem__(self, key):
122 | return getattr(self, key)
123 |
124 | def __setitem__(self, key, value):
125 | return setattr(self, key, value)
126 |
127 | def __contains__(self, key):
128 | return key in self.__dict__
129 |
130 | def __repr__(self):
131 | return self.__dict__.__repr__()
132 |
133 | def get(self,index):
134 | return self.__dict__.get(index)
135 |
136 |
137 | class InferHParams(HParams):
138 | def __init__(self, **kwargs):
139 | for k, v in kwargs.items():
140 | if type(v) == dict:
141 | v = InferHParams(**v)
142 | self[k] = v
143 |
144 | def __getattr__(self,index):
145 | return self.get(index)
146 |
147 |
148 | def make_positions(tensor, padding_idx):
149 | """Replace non-padding symbols with their position numbers.
150 | Position numbers begin at padding_idx+1. Padding symbols are ignored.
151 | """
152 | # The series of casts and type-conversions here are carefully
153 | # balanced to both work with ONNX export and XLA. In particular XLA
154 | # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
155 | # how to handle the dtype kwarg in cumsum.
156 | mask = tensor.ne(padding_idx).int()
157 | return (
158 | torch.cumsum(mask, dim=1).type_as(mask) * mask
159 | ).long() + padding_idx
160 |
161 |
162 |
163 | class Volume_Extractor:
164 | def __init__(self, hop_size = 512):
165 | self.hop_size = hop_size
166 |
167 | def extract(self, audio): # audio: 2d tensor array
168 | if not isinstance(audio,torch.Tensor):
169 | audio = torch.Tensor(audio)
170 | n_frames = int(audio.size(-1) // self.hop_size)
171 | audio2 = audio ** 2
172 | audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
173 | volume = torch.nn.functional.unfold(audio2[:,None,None,:],(1,self.hop_size),stride=self.hop_size)[:,:,:n_frames].mean(dim=1)[0]
174 | volume = torch.sqrt(volume)
175 | return volume
176 |
177 |
178 | def get_hparams_from_file(config_path, infer_mode = False):
179 | with open(config_path, "r") as f:
180 | data = f.read()
181 | config = json.loads(data)
182 | hparams =HParams(**config) if not infer_mode else InferHParams(**config)
183 | return hparams
184 |
185 |
186 | def load_model(
187 | expdir,
188 | model,
189 | optimizer,
190 | name='model',
191 | postfix='',
192 | device='cpu'):
193 | if postfix == '':
194 | postfix = '_' + postfix
195 | path = os.path.join(expdir, name+postfix)
196 | path_pt = traverse_dir(expdir, ['pt'], is_ext=False)
197 | global_step = 0
198 | if len(path_pt) > 0:
199 | steps = [s[len(path):] for s in path_pt]
200 | maxstep = max([int(s) if s.isdigit() else 0 for s in steps])
201 | if maxstep >= 0:
202 | path_pt = path+str(maxstep)+'.pt'
203 | else:
204 | path_pt = path+'best.pt'
205 | print(' [*] restoring model from', path_pt)
206 | ckpt = torch.load(path_pt, map_location=torch.device(device))
207 | global_step = ckpt['global_step']
208 | model.load_state_dict(ckpt['model'], strict=False)
209 | if ckpt.get("optimizer") is not None:
210 | optimizer.load_state_dict(ckpt['optimizer'])
211 | return global_step, model, optimizer
212 |
213 | def get_network_paras_amount(model_dict):
214 | info = dict()
215 | for model_name, model in model_dict.items():
216 | # all_params = sum(p.numel() for p in model.parameters())
217 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
218 |
219 | info[model_name] = trainable_params
220 | return info
221 |
222 |
223 | def load_teacher_model(model,checkpoint_dir):
224 | model_resumed = torch.load(checkpoint_dir)
225 | model.load_state_dict(model_resumed['model'],strict=False)
226 |
227 | model.decoder.denoise_fn_ema = copy.deepcopy(model.decoder.denoise_fn)
228 | model.decoder.denoise_fn_pretrained= copy.deepcopy(model.decoder.denoise_fn)
229 | return model
230 |
231 |
232 | def load_teacher_model_with_pitch(model,checkpoint_dir):
233 | model_resumed = torch.load(checkpoint_dir)
234 | model_pe_resumed = torch.load('./m4singer_pe/model_ckpt_steps_280000.ckpt')['state_dict']
235 | prefix_in_ckpt ='model'
236 | model_pe_resumed = {k[len(prefix_in_ckpt) + 1:]: v for k, v in model_pe_resumed.items()
237 | if k.startswith(f'{prefix_in_ckpt}.')}
238 | model.load_state_dict(model_resumed['model'],strict=False)
239 | model.decoder.pe.load_state_dict(model_pe_resumed,strict=True)
240 | model.decoder.denoise_fn_ema = copy.deepcopy(model.decoder.denoise_fn)
241 | model.decoder.denoise_fn_pretrained= copy.deepcopy(model.decoder.denoise_fn)
242 | return model
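243 |
244 | # Editor's sketch of the two feature helpers above (shapes are illustrative only):
245 | # Volume_Extractor returns the per-frame RMS of the signal, and repeat_expand_2d
246 | # stretches a [h, t] feature map to a target number of frames.
247 | #
248 | #   audio = torch.randn(1, 44100)                              # 1 s at 44.1 kHz
249 | #   volume = Volume_Extractor(hop_size=512).extract(audio)     # -> shape [86]
250 | #   units = torch.randn(256, 100)                              # [h, t] content features
251 | #   stretched = repeat_expand_2d(units, 187)                   # -> shape [256, 187]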
--------------------------------------------------------------------------------
/vocoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/vocoder/__init__.py
--------------------------------------------------------------------------------
/vocoder/m4gan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Grace9994/CoMoSVC/2ea8e644e2c5b3a8afc0762e870b9daacf3b5be5/vocoder/m4gan/__init__.py
--------------------------------------------------------------------------------
/vocoder/m4gan/hifigan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6 |
7 |
8 | from vocoder.m4gan.parallel_wavegan import SourceModuleHnNSF
9 | import numpy as np
10 |
11 | LRELU_SLOPE = 0.1
12 |
13 |
14 | def init_weights(m, mean=0.0, std=0.01):
15 | classname = m.__class__.__name__
16 | if classname.find("Conv") != -1:
17 | m.weight.data.normal_(mean, std)
18 |
19 |
20 | def apply_weight_norm(m):
21 | classname = m.__class__.__name__
22 | if classname.find("Conv") != -1:
23 | weight_norm(m)
24 |
25 |
26 | def get_padding(kernel_size, dilation=1):
27 | return int((kernel_size * dilation - dilation) / 2)
28 |
29 |
30 | class ResBlock1(torch.nn.Module):
31 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
32 | super(ResBlock1, self).__init__()
33 | self.h = h
34 | self.convs1 = nn.ModuleList([
35 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
36 | padding=get_padding(kernel_size, dilation[0]))),
37 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
38 | padding=get_padding(kernel_size, dilation[1]))),
39 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
40 | padding=get_padding(kernel_size, dilation[2])))
41 | ])
42 | self.convs1.apply(init_weights)
43 |
44 | self.convs2 = nn.ModuleList([
45 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
46 | padding=get_padding(kernel_size, 1))),
47 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48 | padding=get_padding(kernel_size, 1))),
49 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50 | padding=get_padding(kernel_size, 1)))
51 | ])
52 | self.convs2.apply(init_weights)
53 |
54 | def forward(self, x):
55 | for c1, c2 in zip(self.convs1, self.convs2):
56 | xt = F.leaky_relu(x, LRELU_SLOPE)
57 | xt = c1(xt)
58 | xt = F.leaky_relu(xt, LRELU_SLOPE)
59 | xt = c2(xt)
60 | x = xt + x
61 | return x
62 |
63 | def remove_weight_norm(self):
64 | for l in self.convs1:
65 | remove_weight_norm(l)
66 | for l in self.convs2:
67 | remove_weight_norm(l)
68 |
69 |
70 | class ResBlock2(torch.nn.Module):
71 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
72 | super(ResBlock2, self).__init__()
73 | self.h = h
74 | self.convs = nn.ModuleList([
75 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
76 | padding=get_padding(kernel_size, dilation[0]))),
77 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
78 | padding=get_padding(kernel_size, dilation[1])))
79 | ])
80 | self.convs.apply(init_weights)
81 |
82 | def forward(self, x):
83 | for c in self.convs:
84 | xt = F.leaky_relu(x, LRELU_SLOPE)
85 | xt = c(xt)
86 | x = xt + x
87 | return x
88 |
89 | def remove_weight_norm(self):
90 | for l in self.convs:
91 | remove_weight_norm(l)
92 |
93 |
94 | class Conv1d1x1(Conv1d):
95 | """1x1 Conv1d with customized initialization."""
96 |
97 | def __init__(self, in_channels, out_channels, bias):
98 | """Initialize 1x1 Conv1d module."""
99 | super(Conv1d1x1, self).__init__(in_channels, out_channels,
100 | kernel_size=1, padding=0,
101 | dilation=1, bias=bias)
102 |
103 |
104 | class HifiGanGenerator(torch.nn.Module):
105 | def __init__(self, h, c_out=1):
106 | super(HifiGanGenerator, self).__init__()
107 | self.h = h
108 | self.num_kernels = len(h['resblock_kernel_sizes'])
109 | self.num_upsamples = len(h['upsample_rates'])
110 |
111 | if h['use_pitch_embed']:
112 | self.harmonic_num = 8
113 | self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
114 | self.m_source = SourceModuleHnNSF(
115 | sampling_rate=h['audio_sample_rate'],
116 | harmonic_num=self.harmonic_num)
117 | self.noise_convs = nn.ModuleList()
118 | self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
119 | resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
120 |
121 | self.ups = nn.ModuleList()
122 | for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
123 | c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
124 | self.ups.append(weight_norm(
125 | ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
126 | if h['use_pitch_embed']:
127 | if i + 1 < len(h['upsample_rates']):
128 | stride_f0 = np.prod(h['upsample_rates'][i + 1:])
129 | self.noise_convs.append(Conv1d(
130 | 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
131 | else:
132 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
133 |
134 | self.resblocks = nn.ModuleList()
135 | for i in range(len(self.ups)):
136 | ch = h['upsample_initial_channel'] // (2 ** (i + 1))
137 | for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
138 | self.resblocks.append(resblock(h, ch, k, d))
139 |
140 | self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
141 | self.ups.apply(init_weights)
142 | self.conv_post.apply(init_weights)
143 |
144 | def forward(self, x, f0=None):
145 | if f0 is not None:
146 | # harmonic-source signal, noise-source signal, uv flag
147 | f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
148 | har_source, noi_source, uv = self.m_source(f0)
149 | har_source = har_source.transpose(1, 2)
150 |
151 | x = self.conv_pre(x)
152 | for i in range(self.num_upsamples):
153 | x = F.leaky_relu(x, LRELU_SLOPE)
154 | x = self.ups[i](x)
155 | if f0 is not None:
156 | x_source = self.noise_convs[i](har_source)
157 | x_source = torch.nn.functional.relu(x_source)
158 | tmp_shape = x_source.shape[1]
159 | x_source = torch.nn.functional.layer_norm(x_source.transpose(1, -1), (tmp_shape, )).transpose(1, -1)
160 | x = x + x_source
161 | xs = None
162 | for j in range(self.num_kernels):
163 | xs_ = self.resblocks[i * self.num_kernels + j](x)
164 | if xs is None:
165 | xs = xs_
166 | else:
167 | xs += xs_
168 | x = xs / self.num_kernels
169 | x = F.leaky_relu(x)
170 | x = self.conv_post(x)
171 | x = torch.tanh(x)
172 |
173 | return x
174 |
175 | def remove_weight_norm(self):
176 | print('Removing weight norm...')
177 | for l in self.ups:
178 | remove_weight_norm(l)
179 | for l in self.resblocks:
180 | l.remove_weight_norm()
181 | remove_weight_norm(self.conv_pre)
182 | remove_weight_norm(self.conv_post)
183 |
184 |
185 | class DiscriminatorP(torch.nn.Module):
186 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
187 | super(DiscriminatorP, self).__init__()
188 | self.use_cond = use_cond
189 | if use_cond:
190 | from utils.hparams import hparams
191 | t = hparams['hop_size']
192 | self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
193 | c_in = 2
194 |
195 | self.period = period
196 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
197 | self.convs = nn.ModuleList([
198 | norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
199 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
200 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
201 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
202 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
203 | ])
204 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
205 |
206 | def forward(self, x, mel):
207 | fmap = []
208 | if self.use_cond:
209 | x_mel = self.cond_net(mel)
210 | x = torch.cat([x_mel, x], 1)
211 | # 1d to 2d
212 | b, c, t = x.shape
213 | if t % self.period != 0: # pad first
214 | n_pad = self.period - (t % self.period)
215 | x = F.pad(x, (0, n_pad), "reflect")
216 | t = t + n_pad
217 | x = x.view(b, c, t // self.period, self.period)
218 |
219 | for l in self.convs:
220 | x = l(x)
221 | x = F.leaky_relu(x, LRELU_SLOPE)
222 | fmap.append(x)
223 | x = self.conv_post(x)
224 | fmap.append(x)
225 | x = torch.flatten(x, 1, -1)
226 |
227 | return x, fmap
228 |
229 |
230 | class MultiPeriodDiscriminator(torch.nn.Module):
231 | def __init__(self, use_cond=False, c_in=1):
232 | super(MultiPeriodDiscriminator, self).__init__()
233 | self.discriminators = nn.ModuleList([
234 | DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
235 | DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
236 | DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
237 | DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
238 | DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
239 | ])
240 |
241 | def forward(self, y, y_hat, mel=None):
242 | y_d_rs = []
243 | y_d_gs = []
244 | fmap_rs = []
245 | fmap_gs = []
246 | for i, d in enumerate(self.discriminators):
247 | y_d_r, fmap_r = d(y, mel)
248 | y_d_g, fmap_g = d(y_hat, mel)
249 | y_d_rs.append(y_d_r)
250 | fmap_rs.append(fmap_r)
251 | y_d_gs.append(y_d_g)
252 | fmap_gs.append(fmap_g)
253 |
254 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
255 |
256 |
257 | class DiscriminatorS(torch.nn.Module):
258 | def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
259 | super(DiscriminatorS, self).__init__()
260 | self.use_cond = use_cond
261 | if use_cond:
262 | t = np.prod(upsample_rates)
263 | self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
264 | c_in = 2
265 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
266 | self.convs = nn.ModuleList([
267 | norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
268 | norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
269 | norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
270 | norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
271 | norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
272 | norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
273 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
274 | ])
275 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
276 |
277 | def forward(self, x, mel):
278 | if self.use_cond:
279 | x_mel = self.cond_net(mel)
280 | x = torch.cat([x_mel, x], 1)
281 | fmap = []
282 | for l in self.convs:
283 | x = l(x)
284 | x = F.leaky_relu(x, LRELU_SLOPE)
285 | fmap.append(x)
286 | x = self.conv_post(x)
287 | fmap.append(x)
288 | x = torch.flatten(x, 1, -1)
289 |
290 | return x, fmap
291 |
292 |
293 | class MultiScaleDiscriminator(torch.nn.Module):
294 | def __init__(self, use_cond=False, c_in=1):
295 | super(MultiScaleDiscriminator, self).__init__()
296 | from utils.hparams import hparams
297 | self.discriminators = nn.ModuleList([
298 | DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
299 | upsample_rates=[4, 4, hparams['hop_size'] // 16],
300 | c_in=c_in),
301 | DiscriminatorS(use_cond=use_cond,
302 | upsample_rates=[4, 4, hparams['hop_size'] // 32],
303 | c_in=c_in),
304 | DiscriminatorS(use_cond=use_cond,
305 | upsample_rates=[4, 4, hparams['hop_size'] // 64],
306 | c_in=c_in),
307 | ])
308 | self.meanpools = nn.ModuleList([
309 | AvgPool1d(4, 2, padding=1),
310 | AvgPool1d(4, 2, padding=1)
311 | ])
312 |
313 | def forward(self, y, y_hat, mel=None):
314 | y_d_rs = []
315 | y_d_gs = []
316 | fmap_rs = []
317 | fmap_gs = []
318 | for i, d in enumerate(self.discriminators):
319 | if i != 0:
320 | y = self.meanpools[i - 1](y)
321 | y_hat = self.meanpools[i - 1](y_hat)
322 | y_d_r, fmap_r = d(y, mel)
323 | y_d_g, fmap_g = d(y_hat, mel)
324 | y_d_rs.append(y_d_r)
325 | fmap_rs.append(fmap_r)
326 | y_d_gs.append(y_d_g)
327 | fmap_gs.append(fmap_g)
328 |
329 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
330 |
331 |
332 | def feature_loss(fmap_r, fmap_g):
333 | loss = 0
334 | for dr, dg in zip(fmap_r, fmap_g):
335 | for rl, gl in zip(dr, dg):
336 | loss += torch.mean(torch.abs(rl - gl))
337 |
338 | return loss * 2
339 |
340 |
341 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
342 | r_losses = 0
343 | g_losses = 0
344 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
345 | r_loss = torch.mean((1 - dr) ** 2)
346 | g_loss = torch.mean(dg ** 2)
347 | r_losses += r_loss
348 | g_losses += g_loss
349 | r_losses = r_losses / len(disc_real_outputs)
350 | g_losses = g_losses / len(disc_real_outputs)
351 | return r_losses, g_losses
352 |
353 |
354 | def cond_discriminator_loss(outputs):
355 | loss = 0
356 | for dg in outputs:
357 | g_loss = torch.mean(dg ** 2)
358 | loss += g_loss
359 | loss = loss / len(outputs)
360 | return loss
361 |
362 |
363 | def generator_loss(disc_outputs):
364 | loss = 0
365 | for dg in disc_outputs:
366 | l = torch.mean((1 - dg) ** 2)
367 | loss += l
368 | loss = loss / len(disc_outputs)
369 | return loss
370 |
371 |
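372 | # Editor's sketch (not the project's shipped config): the generator reads its
373 | # hyperparameters from a plain dict; the values below are illustrative HiFi-GAN-style
374 | # settings chosen only so the shapes line up (conv_pre expects 80 mel bins).
375 | #
376 | #   h = {
377 | #       'resblock': '1',
378 | #       'resblock_kernel_sizes': [3, 7, 11],
379 | #       'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
380 | #       'upsample_rates': [8, 8, 2, 2],
381 | #       'upsample_kernel_sizes': [16, 16, 4, 4],
382 | #       'upsample_initial_channel': 512,
383 | #       'use_pitch_embed': True,
384 | #       'audio_sample_rate': 24000,
385 | #   }
386 | #   net = HifiGanGenerator(h)
387 | #   mel = torch.randn(1, 80, 100)             # [B, mel_bins, frames]
388 | #   f0 = torch.rand(1, 100) * 200 + 100       # frame-level F0 in Hz
389 | #   wav = net(mel, f0)                        # -> [1, 1, 100 * prod(upsample_rates)]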
--------------------------------------------------------------------------------
/vocoder/m4gan/parallel_wavegan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import sys
4 | import torch.nn.functional as torch_nn_func
5 |
6 |
7 | class SineGen(torch.nn.Module):
8 | """ Definition of sine generator
9 | SineGen(samp_rate, harmonic_num = 0,
10 | sine_amp = 0.1, noise_std = 0.003,
11 | voiced_threshold = 0,
12 | flag_for_pulse=False)
13 |
14 | samp_rate: sampling rate in Hz
15 | harmonic_num: number of harmonic overtones (default 0)
16 |     sine_amp: amplitude of sine-waveform (default 0.1)
17 | noise_std: std of Gaussian noise (default 0.003)
18 |     voiced_threshold: F0 threshold for U/V classification (default 0)
19 |     flag_for_pulse: this SineGen is used inside PulseGen (default False)
20 |
21 | Note: when flag_for_pulse is True, the first time step of a voiced
22 | segment is always sin(np.pi) or cos(0)
23 | """
24 |
25 | def __init__(self, samp_rate, harmonic_num=0,
26 | sine_amp=0.1, noise_std=0.003,
27 | voiced_threshold=0,
28 | flag_for_pulse=False):
29 | super(SineGen, self).__init__()
30 | self.sine_amp = sine_amp
31 | self.noise_std = noise_std
32 | self.harmonic_num = harmonic_num
33 | self.dim = self.harmonic_num + 1
34 | self.sampling_rate = samp_rate
35 | self.voiced_threshold = voiced_threshold
36 | self.flag_for_pulse = flag_for_pulse
37 |
38 | def _f02uv(self, f0):
39 | # generate uv signal
40 | uv = torch.ones_like(f0)
41 | uv = uv * (f0 > self.voiced_threshold)
42 | return uv
43 |
44 | def _f02sine(self, f0_values):
45 | """ f0_values: (batchsize, length, dim)
46 | where dim indicates fundamental tone and overtones
47 | """
48 |         # convert to F0 in rad. The integer part n can be ignored
49 | # because 2 * np.pi * n doesn't affect phase
50 | rad_values = (f0_values / self.sampling_rate) % 1
51 |
52 | # initial phase noise (no noise for fundamental component)
53 | rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
54 | device=f0_values.device)
55 | rand_ini[:, 0] = 0
56 | rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
57 |
58 |         # instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
59 | if not self.flag_for_pulse:
60 | # for normal case
61 |
62 | # To prevent torch.cumsum numerical overflow,
63 | # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
64 | # Buffer tmp_over_one_idx indicates the time step to add -1.
65 | # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
66 | tmp_over_one = torch.cumsum(rad_values, 1) % 1
67 | tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
68 | tmp_over_one[:, :-1, :]) < 0
69 | cumsum_shift = torch.zeros_like(rad_values)
70 | cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
71 |
72 | sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
73 | * 2 * np.pi)
74 | else:
75 | # If necessary, make sure that the first time step of every
76 | # voiced segments is sin(pi) or cos(0)
77 | # This is used for pulse-train generation
78 |
79 | # identify the last time step in unvoiced segments
80 | uv = self._f02uv(f0_values)
81 | uv_1 = torch.roll(uv, shifts=-1, dims=1)
82 | uv_1[:, -1, :] = 1
83 | u_loc = (uv < 1) * (uv_1 > 0)
84 |
85 |             # get the instantaneous phase
86 | tmp_cumsum = torch.cumsum(rad_values, dim=1)
87 | # different batch needs to be processed differently
88 | for idx in range(f0_values.shape[0]):
89 | temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
90 | temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
91 | # stores the accumulation of i.phase within
92 | # each voiced segments
93 | tmp_cumsum[idx, :, :] = 0
94 | tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
95 |
96 | # rad_values - tmp_cumsum: remove the accumulation of i.phase
97 | # within the previous voiced segment.
98 | i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
99 |
100 | # get the sines
101 | sines = torch.cos(i_phase * 2 * np.pi)
102 | return sines
103 |
104 | def forward(self, f0):
105 | """ sine_tensor, uv = forward(f0)
106 | input F0: tensor(batchsize=1, length, dim=1)
107 | f0 for unvoiced steps should be 0
108 | output sine_tensor: tensor(batchsize=1, length, dim)
109 | output uv: tensor(batchsize=1, length, 1)
110 | """
111 | with torch.no_grad():
112 | f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
113 | device=f0.device)
114 | # fundamental component
115 | f0_buf[:, :, 0] = f0[:, :, 0]
116 | for idx in np.arange(self.harmonic_num):
117 | # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
118 | f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
119 |
120 | # generate sine waveforms
121 | sine_waves = self._f02sine(f0_buf) * self.sine_amp
122 |
123 | # generate uv signal
124 | # uv = torch.ones(f0.shape)
125 | # uv = uv * (f0 > self.voiced_threshold)
126 | uv = self._f02uv(f0)
127 |
128 | # noise: for unvoiced should be similar to sine_amp
129 | # std = self.sine_amp/3 -> max value ~ self.sine_amp
130 | # . for voiced regions is self.noise_std
131 | noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
132 | noise = noise_amp * torch.randn_like(sine_waves)
133 |
134 | # first: set the unvoiced part to 0 by uv
135 | # then: additive noise
136 | sine_waves = sine_waves * uv + noise
137 | return sine_waves, uv, noise
138 |
139 |
140 | class PulseGen(torch.nn.Module):
141 | """ Definition of Pulse train generator
142 |
143 |     There are many ways to implement a pulse generator.
144 |     Here, PulseGen is based on SineGen.
145 | """
146 | def __init__(self, samp_rate, pulse_amp = 0.1,
147 | noise_std = 0.003, voiced_threshold = 0):
148 | super(PulseGen, self).__init__()
149 | self.pulse_amp = pulse_amp
150 | self.sampling_rate = samp_rate
151 | self.voiced_threshold = voiced_threshold
152 | self.noise_std = noise_std
153 | self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0, \
154 | sine_amp=self.pulse_amp, noise_std=0, \
155 | voiced_threshold=self.voiced_threshold, \
156 | flag_for_pulse=True)
157 |
158 | def forward(self, f0):
159 | """ Pulse train generator
160 | pulse_train, uv = forward(f0)
161 | input F0: tensor(batchsize=1, length, dim=1)
162 | f0 for unvoiced steps should be 0
163 | output pulse_train: tensor(batchsize=1, length, dim)
164 | output uv: tensor(batchsize=1, length, 1)
165 |
166 | Note: self.l_sine doesn't make sure that the initial phase of
167 | a voiced segment is np.pi, the first pulse in a voiced segment
168 | may not be at the first time step within a voiced segment
169 | """
170 | with torch.no_grad():
171 | sine_wav, uv, noise = self.l_sinegen(f0)
172 |
173 | # sine without additive noise
174 | pure_sine = sine_wav - noise
175 |
176 | # step t corresponds to a pulse if
177 | # sine[t] > sine[t+1] & sine[t] > sine[t-1]
178 | # & sine[t-1], sine[t+1], and sine[t] are voiced
179 | # or
180 | # sine[t] is voiced, sine[t-1] is unvoiced
181 | # we use torch.roll to simulate sine[t+1] and sine[t-1]
182 | sine_1 = torch.roll(pure_sine, shifts=1, dims=1)
183 | uv_1 = torch.roll(uv, shifts=1, dims=1)
184 | uv_1[:, 0, :] = 0
185 | sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)
186 | uv_2 = torch.roll(uv, shifts=-1, dims=1)
187 | uv_2[:, -1, :] = 0
188 |
189 | loc = (pure_sine > sine_1) * (pure_sine > sine_2) \
190 | * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \
191 | + (uv_1 < 1) * (uv > 0)
192 |
193 | # pulse train without noise
194 | pulse_train = pure_sine * loc
195 |
196 | # additive noise to pulse train
197 | # note that noise from sinegen is zero in voiced regions
198 | pulse_noise = torch.randn_like(pure_sine) * self.noise_std
199 |
200 | # with additive noise on pulse, and unvoiced regions
201 | pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)
202 | return pulse_train, sine_wav, uv, pulse_noise
203 |
204 |
205 | class SignalsConv1d(torch.nn.Module):
206 | """ Filtering input signal with time invariant filter
207 | Note: FIRFilter conducted filtering given fixed FIR weight
208 | SignalsConv1d convolves two signals
209 | Note: this is based on torch.nn.functional.conv1d
210 |
211 | """
212 |
213 | def __init__(self):
214 | super(SignalsConv1d, self).__init__()
215 |
216 | def forward(self, signal, system_ir):
217 | """ output = forward(signal, system_ir)
218 |
219 | signal: (batchsize, length1, dim)
220 | system_ir: (length2, dim)
221 |
222 | output: (batchsize, length1, dim)
223 | """
224 | if signal.shape[-1] != system_ir.shape[-1]:
225 | print("Error: SignalsConv1d expects shape:")
226 | print("signal (batchsize, length1, dim)")
227 | print("system_id (batchsize, length2, dim)")
228 | print("But received signal: {:s}".format(str(signal.shape)))
229 | print(" system_ir: {:s}".format(str(system_ir.shape)))
230 | sys.exit(1)
231 | padding_length = system_ir.shape[0] - 1
232 | groups = signal.shape[-1]
233 |
234 | # pad signal on the left
235 | signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), \
236 | (padding_length, 0))
237 | # prepare system impulse response as (dim, 1, length2)
238 | # also flip the impulse response
239 | ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), \
240 | dims=[2])
241 | # convolute
242 | output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
243 | return output.permute(0, 2, 1)
244 |
245 |
246 | class CyclicNoiseGen_v1(torch.nn.Module):
247 | """ CyclicnoiseGen_v1
248 | Cyclic noise with a single parameter of beta.
249 | Pytorch v1 implementation assumes f_t is also fixed
250 | """
251 |
252 | def __init__(self, samp_rate,
253 | noise_std=0.003, voiced_threshold=0):
254 | super(CyclicNoiseGen_v1, self).__init__()
255 | self.samp_rate = samp_rate
256 | self.noise_std = noise_std
257 | self.voiced_threshold = voiced_threshold
258 |
259 | self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0,
260 | noise_std=noise_std,
261 | voiced_threshold=voiced_threshold)
262 | self.l_conv = SignalsConv1d()
263 |
264 | def noise_decay(self, beta, f0mean):
265 | """ decayed_noise = noise_decay(beta, f0mean)
266 | decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate)
267 |
268 | beta: (dim=1) or (batchsize=1, 1, dim=1)
269 | f0mean (batchsize=1, 1, dim=1)
270 |
271 | decayed_noise (batchsize=1, length, dim=1)
272 | """
273 | with torch.no_grad():
274 | # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
275 | # truncate the noise when decayed by -40 dB
276 | length = 4.6 * self.samp_rate / f0mean
277 | length = length.int()
278 | time_idx = torch.arange(0, length, device=beta.device)
279 | time_idx = time_idx.unsqueeze(0).unsqueeze(2)
280 | time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])
281 |
282 | noise = torch.randn(time_idx.shape, device=beta.device)
283 |
284 | # due to Pytorch implementation, use f0_mean as the f0 factor
285 | decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
286 | return noise * self.noise_std * decay
287 |
288 | def forward(self, f0s, beta):
289 |         """ Produce cyclic noise
290 | """
291 | # pulse train
292 | pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
293 | pure_pulse = pulse_train - noise
294 |
295 | # decayed_noise (length, dim=1)
296 | if (uv < 1).all():
297 | # all unvoiced
298 | cyc_noise = torch.zeros_like(sine_wav)
299 | else:
300 | f0mean = f0s[uv > 0].mean()
301 |
302 | decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
303 | # convolute
304 | cyc_noise = self.l_conv(pure_pulse, decayed_noise)
305 |
306 |         # add noise in unvoiced segments
307 | cyc_noise = cyc_noise + noise * (1.0 - uv)
308 | return cyc_noise, pulse_train, sine_wav, uv, noise
309 |
310 |
311 | class SourceModuleCycNoise_v1(torch.nn.Module):
312 | """ SourceModuleCycNoise_v1
313 | SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
314 | sampling_rate: sampling_rate in Hz
315 |
316 | noise_std: std of Gaussian noise (default: 0.003)
317 | voiced_threshold: threshold to set U/V given F0 (default: 0)
318 |
319 | cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
320 | F0_upsampled (batchsize, length, 1)
321 | beta (1)
322 | cyc (batchsize, length, 1)
323 | noise (batchsize, length, 1)
324 | uv (batchsize, length, 1)
325 | """
326 |
327 | def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
328 | super(SourceModuleCycNoise_v1, self).__init__()
329 | self.sampling_rate = sampling_rate
330 | self.noise_std = noise_std
331 | self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std,
332 | voiced_threshod)
333 |
334 | def forward(self, f0_upsamped, beta):
335 | """
336 | cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
337 | F0_upsampled (batchsize, length, 1)
338 | beta (1)
339 | cyc (batchsize, length, 1)
340 | noise (batchsize, length, 1)
341 | uv (batchsize, length, 1)
342 | """
343 | # source for harmonic branch
344 | cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)
345 |
346 | # source for noise branch, in the same shape as uv
347 | noise = torch.randn_like(uv) * self.noise_std / 3
348 | return cyc, noise, uv
349 |
350 |
351 | class SourceModuleHnNSF(torch.nn.Module):
352 | """ SourceModule for hn-nsf
353 | SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
354 | add_noise_std=0.003, voiced_threshod=0)
355 | sampling_rate: sampling_rate in Hz
356 | harmonic_num: number of harmonic above F0 (default: 0)
357 | sine_amp: amplitude of sine source signal (default: 0.1)
358 | add_noise_std: std of additive Gaussian noise (default: 0.003)
359 | note that amplitude of noise in unvoiced is decided
360 | by sine_amp
361 |     voiced_threshold: threshold to set U/V given F0 (default: 0)
362 |
363 | Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
364 | F0_sampled (batchsize, length, 1)
365 | Sine_source (batchsize, length, 1)
366 |     noise_source (batchsize, length, 1)
367 | uv (batchsize, length, 1)
368 | """
369 |
370 | def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
371 | add_noise_std=0.003, voiced_threshod=0):
372 | super(SourceModuleHnNSF, self).__init__()
373 |
374 | self.sine_amp = sine_amp
375 | self.noise_std = add_noise_std
376 |
377 | # to produce sine waveforms
378 | self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
379 | sine_amp, add_noise_std, voiced_threshod)
380 |
381 | # to merge source harmonics into a single excitation
382 | self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
383 | self.l_tanh = torch.nn.Tanh()
384 |
385 | def forward(self, x):
386 | """
387 | Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
388 | F0_sampled (batchsize, length, 1)
389 | Sine_source (batchsize, length, 1)
390 |         noise_source (batchsize, length, 1)
391 | """
392 | # source for harmonic branch
393 | sine_wavs, uv, _ = self.l_sin_gen(x)
394 | sine_merge = self.l_tanh(self.l_linear(sine_wavs))
395 |
396 | # source for noise branch, in the same shape as uv
397 | noise = torch.randn_like(uv) * self.sine_amp / 3
398 | return sine_merge, noise, uv
399 |
400 |
401 | if __name__ == '__main__':
402 | source = SourceModuleCycNoise_v1(24000)
403 | x = torch.randn(16, 25600, 1)
404 |
405 |
406 |
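407 | # Editor's sketch of the hn-NSF source module above (shapes and values are illustrative):
408 | #
409 | #   nsf = SourceModuleHnNSF(sampling_rate=24000, harmonic_num=8)
410 | #   f0 = torch.rand(2, 25600, 1) * 200        # F0 in Hz, 0 marks unvoiced steps
411 | #   sine_merge, noise, uv = nsf(f0)
412 | #   # sine_merge: [2, 25600, 1] merged harmonic excitation
413 | #   # noise:      [2, 25600, 1] Gaussian noise source
414 | #   # uv:         [2, 25600, 1] voiced/unvoiced flag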
--------------------------------------------------------------------------------
/wavenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from math import sqrt
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn import Mish
8 |
9 |
10 | class Conv1d(torch.nn.Conv1d):
11 | def __init__(self, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | nn.init.kaiming_normal_(self.weight)
14 |
15 |
16 | class SinusoidalPosEmb(nn.Module):
17 | def __init__(self, dim):
18 | super().__init__()
19 | self.dim = dim
20 |
21 | def forward(self, x):
22 | device = x.device
23 | half_dim = self.dim // 2
24 | emb = math.log(10000) / (half_dim - 1)
25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
26 | emb = x[:, None] * emb[None, :]
27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
28 | return emb
29 |
30 |
31 | class ResidualBlock(nn.Module):
32 | def __init__(self, encoder_hidden, residual_channels, dilation):
33 | super().__init__()
34 | self.residual_channels = residual_channels
35 | self.dilated_conv = nn.Conv1d(
36 | residual_channels,
37 | 2 * residual_channels,
38 | kernel_size=3,
39 | padding=dilation,
40 | dilation=dilation
41 | )
42 | self.diffusion_projection = nn.Linear(residual_channels, residual_channels)
43 | self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1)
44 | self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1)
45 |
46 | def forward(self, x, conditioner, diffusion_step):
47 | # x:[48,512,187]
48 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) # [48, 512, 1]
49 | conditioner = self.conditioner_projection(conditioner) # [48, 1024, 187]
50 |         y = x + diffusion_step  # shapes [48, 512, 187] and [48, 512, 1]; the last dim of diffusion_step is broadcast across the 187 frames before the add
51 |
52 |         y = self.dilated_conv(y) + conditioner  # self.dilated_conv(y) has shape [48, 1024, 187]
53 |
54 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice
55 | gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
56 |         # gate and filter both have shape [48, 512, 187]
57 | y = torch.sigmoid(gate) * torch.tanh(filter) # [48, 512, 187]
58 |
59 | y = self.output_projection(y) # [48, 1024, 187]
60 |
61 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice
62 | residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1)
63 |         # both have shape [48, 512, 187]
64 | return (x + residual) / math.sqrt(2.0), skip
65 |
66 |
67 | class WaveNet(nn.Module):
68 | def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256):
69 | #in_dim=vocoder_dimension 512 n_hidden: 100 n_layers: 20 n_spk: 20
70 | super().__init__()
71 | self.input_projection = Conv1d(in_dims, n_chans, 1)
72 | self.diffusion_embedding = SinusoidalPosEmb(n_chans)
73 | self.mlp = nn.Sequential(
74 | nn.Linear(n_chans, n_chans * 4),
75 | Mish(),
76 | nn.Linear(n_chans * 4, n_chans)
77 | )
78 | self.residual_layers = nn.ModuleList([
79 | ResidualBlock(
80 | encoder_hidden=n_hidden,
81 | residual_channels=n_chans,
82 | dilation=1
83 | )
84 | for i in range(n_layers)
85 | ])
86 | self.skip_projection = Conv1d(n_chans, n_chans, 1)
87 | self.output_projection = Conv1d(n_chans, in_dims, 1)
88 | nn.init.zeros_(self.output_projection.weight)
89 |
90 | def forward(self, spec, diffusion_step,cond):
91 | """
92 |         :param spec: [B, 1, T, M] or [B, T, M], where M is the number of mel bins
93 |         :param diffusion_step: [B]
94 |         :param cond: [B, T, n_hidden]
95 |         :return: [B, T, M]
96 | """
97 | # print("spec_shape",spec.shape)
98 | # print("cond_shape",cond.shape)
99 | cond = cond.transpose(1,2) # [48,256,187]
100 |         x = spec.squeeze(1)  # no change here: [48, 187, 100]
101 | x = x.transpose(1,2)
102 | x = self.input_projection(x) # [B, residual_channel, T]
103 |
104 |
105 | x = F.relu(x)
106 | diffusion_step = self.diffusion_embedding(diffusion_step)
107 | diffusion_step = self.mlp(diffusion_step)
108 | skip = []
109 |         # n_layers=20: the input passes through 20 residual layers
110 | for layer in self.residual_layers:
111 | x, skip_connection = layer(x, cond, diffusion_step)
112 | skip.append(skip_connection)
113 |
114 | x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) #[48, 512, 187]
115 | x = self.skip_projection(x)
116 | x = F.relu(x) # [48, 512, 187]
117 | x = self.output_projection(x) # [B, mel_bins, T] [48, 100, 187]
118 | # output=x[:, None, :, :]
119 | # print(output.shape)
120 | return x[:, :, :].transpose(1,2) # [48, 187, 100]
121 |
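122 | # Editor's sketch of the expected tensor shapes, matching the sizes used in the
123 | # inline comments above (batch 48, 187 frames, 100 mel bins, n_hidden 256):
124 | #
125 | #   net = WaveNet(in_dims=100, n_layers=20, n_chans=512, n_hidden=256)
126 | #   spec = torch.randn(48, 1, 187, 100)       # noisy mel, [B, 1, T, M]
127 | #   step = torch.randint(0, 1000, (48,))      # diffusion step index per sample, [B]
128 | #   cond = torch.randn(48, 187, 256)          # conditioning features, [B, T, n_hidden]
129 | #   out = net(spec, step, cond)               # -> [48, 187, 100]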
--------------------------------------------------------------------------------