├── LICENSE
├── README.md
├── configs
│   └── SinStack.yaml
├── data_loaders.py
├── ddsp
│   ├── __init__.py
│   ├── audio_analysis.py
│   ├── loss.py
│   ├── mel2control.py
│   ├── model_conformer_naive.py
│   ├── utils.py
│   └── vocoder.py
├── export.py
├── harmonic_noise_extract.py
├── logger
│   ├── __init__.py
│   ├── saver.py
│   └── utils.py
├── main.py
├── onnx_infer.py
├── preprocess.py
├── requirements.txt
├── train.py
├── u_noise.ckpt
└── v_noise.ckpt
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 yxlllc
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mini DDSP Vocoders
2 | Reference projects:
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | A mini DDSP vocoder based on pc-ddsp; inference needs neither an ISTFT nor any complicated filter design.
13 | The idea is to replace pc-ddsp's additive sine synthesis with llsm-style stacking of sinusoids at fixed frequencies,
14 | and to synthesise breath noise from a set of prefabricated noise bands (split by frequency in advance) multiplied by the amplitude tensor the model outputs.
15 |
16 | The feature extractor in this project supports harmonic separation; run
17 |
18 | ```bash
19 | python harmonic_noise_extract.py --input_file /xxx
20 | ```
21 |
22 | to separate the harmonics.
23 |
24 |
25 |
26 | ## 1. Install dependencies
27 |
28 | First install PyTorch; see the [official website](https://pytorch.org/) for instructions.
29 |
30 | Then run:
31 | ```bash
32 | pip install -r requirements.txt
33 | ```
34 |
35 | ## 2. Preprocessing
36 |
37 | Put the training data in `data/train/audio` and the validation data in `data/val/audio`.
38 | Separate the harmonics with a VR model and put the separated harmonic audio in `data/train/harmonic_audio` and `data/val/harmonic_audio`
39 | (didn't we just say harmonic separation is built in, so why a VR model? Because a VR model separates harmonics better than harmonic_noise_extract.py does).
40 |
41 | Run the following command to preprocess:
42 |
43 | ```bash
44 | python preprocess.py --config configs/SinStack.yaml
45 | ```
46 |
47 | ## 3. Training
48 |
49 | ```bash
50 | python train.py --config configs/SinStack.yaml
51 | ```
52 | If you run out of GPU memory, swap Sine_Generator for Sine_Generator_Fast in vocoder.py.
53 |
54 | There is a phase loss (lambda_phase) among the parameters; don't enable it, the model just can't learn it :)
55 | ## 4. Inference
56 |
57 | ```bash
58 | python main.py --model_path /xxx --input /audio --output /audio
59 | ```
60 |
--------------------------------------------------------------------------------
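Note: the fixed-frequency sine-stack idea described in the README reduces to per-frame additive cosine synthesis with Hann-windowed 50% overlap-add, which is what `Sine_Generator` in `ddsp/vocoder.py` implements. A minimal self-contained sketch of that idea (toy shapes, constant f0, and random amplitudes are illustrative assumptions, not repo code):

```python
import numpy as np
import torch

hop_size, sr = 512, 44100                  # matches configs/SinStack.yaml
T, max_nhar = 20, 8                        # frames, harmonics (toy sizes)
f0 = torch.full((T,), 220.0)               # constant pitch for the sketch
ampl = torch.rand(max_nhar, T) / max_nhar  # stand-in for the model's sin_mag
phase = torch.zeros(max_nhar, T)           # stand-in for sin_phase

n = torch.arange(-hop_size, hop_size)      # one window, centred on each frame
window = torch.hann_window(2 * hop_size)   # win_size = 2 * hop_size, 50% overlap
k = torch.arange(1, max_nhar + 1).unsqueeze(-1)  # harmonic index column

y = torch.zeros(T * hop_size + 2 * hop_size)
for t in range(T):
    arg = 2 * np.pi * f0[t] * k / sr * n + phase[:, t:t + 1]  # (max_nhar, win)
    frame = (ampl[:, t:t + 1] * torch.cos(arg)).sum(0) * window
    y[t * hop_size : t * hop_size + 2 * hop_size] += frame    # overlap-add
y = y[hop_size : hop_size + T * hop_size]  # trim the half-window of padding
```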
/configs/SinStack.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | f0_extractor: 'parselmouth' # 'parselmouth' (singing) or 'dio' (speech) or 'harvest' (speech)
3 | f0_min: 65 # about C2
4 | f0_max: 800 # about G5
5 | sampling_rate: 44100
6 | n_fft: 2048
7 | win_size: 2048
8 | hop_size: 512 # Equal to hop_length
9 | n_mels: 128
10 | mel_fmin: 40
11 | mel_fmax: 16000 # <= sampling_rate / 2
12 | max_nhar: 128 # must be equal to model.n_sin_hars
13 | relative_winsize: 4
14 | duration: 2 # Audio duration during training, must be less than the duration of the shortest audio clip
15 | train_path: data/train # Create a folder named "audio" under this path and put the audio clip in it
16 | valid_path: data/val # Create a folder named "audio" under this path and put the audio clip in it
17 | # directories: ['audio', 'haudio', 'mel', 'f0', 'uv', 'ampl', 'phase']
18 |   mel_clamp: 0.000001 # floor applied to the mel spectrogram before the log, to avoid numerical instability
19 | model:
20 | type: 'SinStack'
21 | win_length: 2048
22 | n_sin_hars: 128
23 | n_noise_bin: 64
24 |
25 | triangle_ReLU : true # use triangle ReLU instead of ReLU
26 | triangle_ReLU_up: 0.2
27 | triangle_ReLU_down: 0.8
28 |
29 | uv_noise_k : 512
30 | loss:
31 | fft_min: 256
32 | fft_max: 2048
33 |   n_scale: 4 # number of random FFT scales drawn per step by the RSS loss
34 | lambda_uv: 0.0 # uv regularization
35 | lambda_ampl: 0.3 # amplitude regularization
36 | lambda_phase: 0.0 # phase regularization
37 | uv_tolerance: 0.05 # set it to a large value or try other f0 extractors if val_loss_uv is much higher than train_loss_uv
38 | detach_uv_step: 200
39 | device: cuda
40 | env:
41 | expdir: exp/test
42 | gpu_id: 0
43 | train:
44 |   num_workers: 2 # if your CPU and GPU are both fast, setting this to 0 may be faster!
45 | batch_size: 10
46 |   cache_all_data: true # set false to save RAM, at the cost of slower training
47 | epochs: 100000
48 | interval_log: 10
49 | interval_val: 1000
50 | lr: 0.0005
51 | weight_decay: 0
--------------------------------------------------------------------------------
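Note: this config is consumed through the `DotDict` wrapper defined in `ddsp/vocoder.py` (the real loader, `load_model()`, uses `yaml.unsafe_load`). A hedged stand-alone sketch of that access pattern:

```python
import yaml

class DotDict(dict):  # copied from ddsp/vocoder.py
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

with open('configs/SinStack.yaml') as f:
    args = DotDict(yaml.safe_load(f))  # plain scalar config, safe_load suffices

# constraint stated in the config comments
assert args.data.max_nhar == args.model.n_sin_hars
print(args.data.sampling_rate, args.model.type)
```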
/data_loaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import librosa
5 | import torch
6 |
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset
9 | from logger.utils import traverse_dir
10 |
11 | from typing import Dict, List, Optional, Tuple
12 |
13 |
14 | def get_data_loaders(args, whole_audio=False):
15 | data_train = AudioDataset(
16 | args.data.train_path,
17 | waveform_sec=args.data.duration,
18 | hop_size=args.data.hop_size,
19 | sample_rate=args.data.sampling_rate,
20 | load_all_data=args.train.cache_all_data,
21 | whole_audio=whole_audio,
22 | volume_aug=True)
23 | loader_train = torch.utils.data.DataLoader(
24 | data_train ,
25 | batch_size=args.train.batch_size if not whole_audio else 1,
26 | shuffle=True,
27 | num_workers=args.train.num_workers,
28 | persistent_workers=(args.train.num_workers > 0),
29 | pin_memory=True
30 | )
31 | data_valid = AudioDataset(
32 | args.data.valid_path,
33 | waveform_sec=args.data.duration,
34 | hop_size=args.data.hop_size,
35 | sample_rate=args.data.sampling_rate,
36 | load_all_data=args.train.cache_all_data,
37 | whole_audio=True,
38 | volume_aug=False)
39 | loader_valid = torch.utils.data.DataLoader(
40 | data_valid,
41 | batch_size=1,
42 | shuffle=False,
43 | num_workers=0,
44 | pin_memory=True
45 | )
46 | return loader_train, loader_valid
47 |
48 |
49 | class AudioDataset(Dataset):
50 | def __init__(
51 | self,
52 | path_root,
53 | waveform_sec,
54 | hop_size,
55 | sample_rate,
56 | load_all_data=True,
57 | whole_audio=False,
58 | volume_aug=False
59 | ):
60 | super().__init__()
61 |
62 | self.waveform_sec = waveform_sec
63 | self.sample_rate = sample_rate
64 | self.hop_size = hop_size
65 | self.path_root = path_root
66 | self.paths = traverse_dir(
67 | os.path.join(path_root, 'audio'),
68 | extension='wav',
69 | is_pure=True,
70 | is_sort=True,
71 | is_ext=False
72 | )
73 | self.whole_audio = whole_audio
74 | self.volume_aug = volume_aug
75 | self.data_buffer={}
76 |         if load_all_data:
77 |             print('Loading all data from:', path_root)
78 |         else:
79 |             print('Loading f0/uv data from:', path_root)
80 | for name in tqdm(self.paths, total=len(self.paths)):
81 | path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
82 | duration = librosa.get_duration(filename = path_audio, sr = self.sample_rate)
83 |
84 | path_f0 = os.path.join(self.path_root, 'f0', name) + '.npy'
85 | f0 = np.load(path_f0)
86 | f0 = torch.from_numpy(f0).float().unsqueeze(-1)
87 |
88 | path_uv = os.path.join(self.path_root, 'uv', name) + '.npy'
89 | uv = np.load(path_uv)
90 | uv = torch.from_numpy(uv).float()
91 |
92 | path_ampl = os.path.join(self.path_root, 'ampl', name) + '.npy'
93 | ampl = np.load(path_ampl)
94 | ampl = torch.from_numpy(ampl).float()
95 |
96 | path_phase = os.path.join(self.path_root, 'phase', name) + '.npy'
97 | phase = np.load(path_phase)
98 | phase = torch.from_numpy(phase).float()
99 |
100 | if load_all_data:
101 | audio, sr = librosa.load(path_audio, sr=self.sample_rate)
102 | audio = torch.from_numpy(audio).float()
103 |
104 | path_mel = os.path.join(self.path_root, 'mel', name) + '.npy'
105 | audio_mel = np.load(path_mel)
106 | audio_mel = torch.from_numpy(audio_mel).float()
107 |
108 | self.data_buffer[name] = {
109 | 'duration': duration,
110 | 'audio': audio,
111 | 'audio_mel': audio_mel,
112 | 'f0': f0,
113 | 'uv': uv,
114 | 'ampl': ampl,
115 | 'phase': phase
116 | }
117 | else:
118 | self.data_buffer[name] = {
119 | 'duration': duration,
120 | 'f0': f0,
121 | 'uv': uv,
122 | 'ampl': ampl,
123 | 'phase': phase
124 | }
125 |
126 |
127 | def __getitem__(self, file_idx):
128 | name = self.paths[file_idx]
129 | data_buffer = self.data_buffer[name]
130 | # check duration. if too short, then skip
131 | if data_buffer['duration'] < (self.waveform_sec + 0.1):
132 | return self.__getitem__( (file_idx + 1) % len(self.paths))
133 |
134 | # get item
135 | return self.get_data(name, data_buffer)
136 |
137 | def get_data(self, name, data_buffer):
138 | frame_resolution = self.hop_size / self.sample_rate
139 | duration = data_buffer['duration']
140 | waveform_sec = duration if self.whole_audio else self.waveform_sec
141 |
142 | # load audio
143 | idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1)
144 | start_frame = int(idx_from / frame_resolution)
145 | mel_frame_len = int(waveform_sec / frame_resolution)
146 | audio = data_buffer.get('audio')
147 | if audio is None:
148 | path_audio = os.path.join(self.path_root, 'audio', name) + '.wav'
149 | audio, sr = librosa.load(
150 | path_audio,
151 | sr = self.sample_rate,
152 | offset = start_frame * frame_resolution,
153 | duration = waveform_sec)
154 | # clip audio into N seconds
155 | audio = audio[..., : audio.shape[-1] // self.hop_size * self.hop_size]
156 | audio = torch.from_numpy(audio).float()
157 | else:
158 | audio = audio[..., start_frame * self.hop_size : (start_frame + mel_frame_len) * self.hop_size].clone()
159 |
160 | # load mel
161 | audio_mel = data_buffer.get('audio_mel')
162 | if audio_mel is None:
163 | path_mel = os.path.join(self.path_root, 'mel', name) + '.npy'
164 | audio_mel = np.load(path_mel)
165 | audio_mel = audio_mel[:, start_frame : start_frame + mel_frame_len]
166 | audio_mel = torch.from_numpy(audio_mel).float()
167 | else:
168 | audio_mel = audio_mel[:, start_frame : start_frame + mel_frame_len].clone()
169 |
170 | # load ampl
171 | ampl = data_buffer.get('ampl')
172 | ampl_frames = ampl[:,start_frame : start_frame + mel_frame_len]
173 | phase = data_buffer.get('phase')
174 | phase_frames = phase[:,start_frame : start_frame + mel_frame_len]
175 |
176 | # load f0
177 | f0 = data_buffer.get('f0')
178 | f0_frames = f0[start_frame : start_frame + mel_frame_len]
179 |
180 | # load uv
181 | uv = data_buffer.get('uv')
182 | uv_frames = uv[start_frame : start_frame + mel_frame_len]
183 |
184 | # volume augmentation
185 | #if self.volume_aug:
186 | # max_amp = float(torch.max(torch.abs(audio))) + 1e-5
187 | # max_shift = min(1, np.log10(1/max_amp))
188 | # log10_mel_shift = random.uniform(-1, max_shift)
189 | # audio *= (10 ** log10_mel_shift)
190 | # audio_mel += log10_mel_shift
191 | # audio_mel = torch.clamp(audio_mel, min=-5)
192 |
193 | if self.volume_aug:
194 | max_amp = float(torch.max(torch.abs(audio))) + 1e-5
195 | max_shift = min(1, np.log(1/max_amp))
196 | log_mel_shift = random.uniform(-1, max_shift)
197 | audio *= (np.e ** log_mel_shift)
198 | audio_mel += log_mel_shift
199 | #audio_mel = torch.clamp(audio_mel, min=-11.512938235)
200 |
201 | return dict(
202 | audio=audio,
203 | f0=f0_frames,
204 | uv=uv_frames,
205 | mel=audio_mel,
206 | ampl=ampl_frames,
207 | phase=phase_frames,
208 | name=name
209 | )
210 |
211 | def __len__(self):
212 | return len(self.paths)
213 |
--------------------------------------------------------------------------------
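Note on the volume augmentation in `get_data()`: mel bins are linear in waveform amplitude, so scaling the waveform by `e**shift` is equivalent to adding `shift` to the natural-log mel, which is why the same scalar appears on both sides. A quick numeric check using the repo's own extractor (assumes the repo root is importable; the toy signal and tolerance are arbitrary):

```python
import torch
from ddsp.audio_analysis import MelAnalysis

mel_fn = MelAnalysis(sampling_rate=44100, win_size=2048, hop_size=512,
                     n_mels=128, mel_fmin=40, mel_fmax=16000, clamp=1e-9)
x = torch.randn(44100) * 0.01   # quiet toy signal, keeps the clamp inactive
shift = 0.7
m1 = mel_fn(x * torch.e ** shift)  # scale the waveform by e**shift
m2 = mel_fn(x) + shift             # shift the natural-log mel by the same amount
print(torch.allclose(m1, m2, atol=1e-4))  # True wherever no clamping occurs
```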
/ddsp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjzxkxdn/Mini-DDSP/305f069c07f0214bf65a28f519e2149e864885a2/ddsp/__init__.py
--------------------------------------------------------------------------------
/ddsp/audio_analysis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 | import matplotlib.pyplot as plt
6 | import pyworld as pw
7 | import parselmouth as pm
8 |
9 | import soundfile as sf
10 | import numpy as np
11 |
12 | from ddsp.utils import get_mel_fn, get_n_fft
13 | from ddsp.vocoder import Sine_Generator
14 |
15 | import matplotlib.pyplot as plt
16 |
17 | from typing import Tuple, Union, Optional, Dict
18 |
19 | def czt(
20 | x: torch.Tensor,
21 | m: int,
22 | A: torch.Tensor,
23 | W: torch.Tensor
24 | ):
25 | """
26 | Args:
27 | x: tensor [shape = (n)]
28 | m: int
29 | A: complex
30 | W: complex
31 |
32 | A = A_0 * exp(j * θ)
33 | W = W_0 * exp(-j * ϕ)
34 |
35 |     Typically A_0 = 1, W_0 = 1, θ = 2π·(f0/fs), ϕ = 2π·(f0/fs)
36 | """
37 | n = x.shape[0]
38 | #l = int(2 ** np.ceil(np.log2(n + m - 1)))
39 |     # same computation done with torch:
40 | l = int(2 ** torch.ceil(torch.log2(torch.tensor(n + m - 1, dtype=torch.double))))
41 |
42 | w = W ** (torch.arange(max(m, n), dtype=torch.double, device=x.device)**2 / 2)
43 |
44 | gn = torch.zeros(l, dtype=torch.complex128, device=x.device)
45 | gn[:n] = (x * (A ** (-torch.arange(0, n, dtype=torch.double,device=x.device))) * w[:n])
46 |
47 | hn = torch.zeros(l, dtype=torch.complex128, device=x.device)
48 | hn[:m] = 1/w[:m]
49 | hn[l-n+1:]=1/torch.flip(w[1:n],dims=(0,))
50 |
51 | yk = torch.fft.fft(gn) * torch.fft.fft(hn)
52 | qn = torch.fft.ifft(yk)
53 | yn = qn[:m] * w[:m]
54 |
55 | return yn
56 |
57 | @torch.jit.script
58 | def sinusoidal_analysis_czt_for(
59 | audio : torch.Tensor,
60 | f0 : torch.Tensor,
61 | sr : int,
62 | hop_size: int,
63 | max_nhar: int,
64 | relative_winsize: int
65 | ):
66 | """
67 | 先用for循环实现,以后优化掉
68 | Args:
69 | audio: Tensor [shape = (t)]
70 | f0: Tensor [shape = (n_frames)]
71 | max_nhar: int
72 | relative_winsize: int
73 | Returns:
74 | log_ampl: [shape = (max_nhar, n_frames)]
75 | x: [shape = (max_nhar, n_frames)]
76 | phase: [shape = (max_nhar, n_frames)]
77 | """
78 | n_frames = f0.shape[0]
79 | f0 = f0.to(audio.device)
80 |
81 | n_fft, f0_min = get_n_fft(f0, sr, relative_winsize)
82 | f0 = f0.clamp(min=f0_min)
83 |
84 | nhar = torch.floor(sr / f0 / 2).clamp(max=max_nhar)
85 | winsize = torch.round(sr / f0 * relative_winsize / 2) * 2
86 |
87 | pad = int(n_fft // 2)
88 | audio_pad = F.pad(audio, [pad, pad])
89 |
90 | ampl = torch.zeros((max_nhar, n_frames), device=audio.device)
91 | phase = torch.zeros((max_nhar, n_frames), device=audio.device)
92 |
93 |     # for-loop implementation for now; to be optimized later
94 | for i in range(n_frames):
95 | f0_i = f0[i]
96 | f0_i = f0_i.to(dtype=torch.double)
97 | nhar_i = int(nhar[i])
98 | winsize_i = int(winsize[i])
99 | start_i = int(i * hop_size)+pad
100 |
101 | window = torch.blackman_window(winsize_i, device=audio.device)
102 |
103 | audio_frame = audio_pad[start_i-winsize_i//2 : start_i+winsize_i//2]
104 | audio_frame = audio_frame * window
105 |
106 | A = torch.exp(torch.complex(torch.tensor(0.,dtype=torch.double,device=audio.device), 2 * torch.pi * f0_i/sr))
107 | W = torch.exp(torch.complex(torch.tensor(0.,dtype=torch.double,device=audio.device), -2 * torch.pi * f0_i/sr))
108 |
109 | yn = czt(audio_frame, nhar_i, A, W)
110 |         yn = 2.381 * (yn / (len(audio_frame)//2+1))  # 2.381 ≈ 1/0.42: Blackman-window coherent-gain correction
111 |
112 | ampl[:int(nhar_i), i] = torch.abs(yn)
113 | phase[:int(nhar_i), i] = torch.angle(yn)
114 |
115 | return ampl, phase
116 |
117 | def variable_window_STFT(
118 | audio: torch.Tensor,
119 | n_fft: int,
120 | hop_size: int,
121 | window_size: torch.Tensor
122 | ):
123 | '''
124 |     STFT with a per-frame variable window size.
125 | Args:
126 | audio: [shape = (t)]
127 | window_size: Tensor [shape = (n_frames)]
128 | Returns:
129 | S: [shape = (n_fft//2+1, n_frames)]
130 | '''
131 | pad = int(n_fft // 2) # n_frames = t//hop_size + 1
132 | audio_unfold = F.pad(audio, [pad, pad]).unfold(0, n_fft, hop_size) # (n_frames, n_fft)
133 |
134 | window_tensor = generate_window_tensor(window_size, n_fft) # (n_fft, n_frames)
135 | audio_unfold = audio_unfold * window_tensor.T # (n_frames, n_fft)
136 | S = torch.fft.rfft(audio_unfold).T # (n_fft//2+1, n_frames)
137 |
138 | return S
139 |
140 | def generate_window_tensor(window_size, n_fft):
141 | n_frames = window_size.shape[0]
142 | window_tensor = torch.zeros((n_fft, n_frames))
143 |
144 | for i in range(n_frames):
145 | winsize = int(window_size[i])
146 | window = torch.blackman_window(winsize)
147 | pad_size = (n_fft - winsize) // 2
148 | window_tensor[pad_size:pad_size + winsize, i] = window
149 |
150 | return window_tensor
151 |
152 | def sinusoidal_analysis_qifft(
153 | audio: torch.Tensor,
154 | sr: int,
155 | hop_size: int,
156 | f0: torch.Tensor,
157 | max_nhar: int,
158 | relative_winsize: int,
159 | standard_winsize = 1024
160 | ):
161 |     # NOTE: buggy, do not use (kept for reference)
162 | n_frames = f0.shape[0]
163 |
164 | n_fft, f0_min = get_n_fft(f0, sr, relative_winsize)
165 | f0 = f0.clamp(min=f0_min)
166 |
167 | winsize_size = torch.round((sr / f0 * relative_winsize / 2) * 2).int() # (n_frames)
168 |
169 | standard_window = torch.blackman_window(standard_winsize)
170 | standard_normalizer = 0.5 * torch.sum(standard_window)
171 | normalizer = standard_winsize / standard_normalizer / winsize_size # (n_frames)
172 |
173 | S = variable_window_STFT(
174 | audio, n_fft, hop_size, winsize_size
175 | ) * normalizer # (n_fft//2+1, n_frames)
176 |
177 | spec_magn, spec_phse = torch.abs(S), torch.angle(S)
178 | log_spec_magn = torch.log(torch.clamp(spec_magn, min=1e-8))
179 |
180 | qifft_tensor = torch.zeros((3, max_nhar, n_frames), device=audio.device)
181 | peak_bin_tensor = torch.zeros((max_nhar, n_frames), device=audio.device)
182 | remove_above_nhar = torch.zeros((max_nhar, n_frames), device=audio.device)
183 |
184 | tolerance = 0.3
185 | nhars = torch.clamp((sr / (f0 * 2)).floor().int(), max=max_nhar)
186 | f0_proportions = f0 / sr * n_fft
187 |
188 | for i in range(n_frames):
189 | nhar, f0_proportion = nhars[i], f0_proportions[i]
190 | remove_above_nhar[:nhar, i] = 1
191 |
192 | l_idxs = (
193 | (torch.arange(
194 | 1,nhar+1, device=audio.device) - tolerance
195 | ) * f0_proportion
196 | ).round().clamp(1, n_fft//2 - 1)
197 |
198 | u_idxs = (
199 | (torch.arange(
200 | 1,nhar+1, device=audio.device) + tolerance
201 | ) * f0_proportion
202 | ).round().clamp(1, n_fft//2 - 1)
203 |
204 | for j in range(nhar):
205 | l_idx, u_idx = int(l_idxs[j]), int(u_idxs[j])
206 | peak_bin = torch.argmax(log_spec_magn[l_idx:u_idx+1, i])
207 | peak_bin += l_idx
208 | peak_bin_tensor[j, i] = peak_bin
209 |
210 | qifft_tensor[0, j, i] = log_spec_magn[peak_bin - 1, i]
211 | qifft_tensor[1, j, i] = log_spec_magn[peak_bin , i]
212 | qifft_tensor[2, j, i] = log_spec_magn[peak_bin + 1, i]
213 |
214 | log_ampl, x = qifft(qifft_tensor)
215 | ampl = torch.exp(log_ampl)
216 | phase = torch.zeros((max_nhar, n_frames), device=audio.device)
217 |
218 | x = x + peak_bin_tensor
219 | interp_x = np.linspace(0, n_fft//2, n_fft//2 + 1)
220 | for i in range(n_frames):
221 | phase[:, i] = torch.from_numpy(
222 | np.interp(
223 | x[:, i].numpy(),
224 | interp_x,
225 | np.unwrap(spec_phse[:, i].numpy())
226 | )
227 | )
228 |
229 | ampl = ampl * remove_above_nhar
230 | phase = phase * remove_above_nhar
231 |
232 | return ampl, phase
233 |
234 | def qifft(qifft_tensor: torch.Tensor):
235 | '''
236 | Args:
237 | qifft_tensor: (3, max_nhar, n_frames)
238 | '''
239 | a = qifft_tensor[0, :, :] # (max_nhar, n_frames)
240 | b = qifft_tensor[1, :, :]
241 | c = qifft_tensor[2, :, :]
242 |
243 | a1 = (a + c) / 2.0 - b # (max_nhar, n_frames)
244 | a2 = c - b - a1
245 | x = -a2 / (a1 + 1e-8) * 0.5
246 | x[torch.abs(x) > 1] = 0
247 |
248 | ret = a1 * x * x + a2 * x + b # (max_nhar, n_frames)
249 | idx = ret > b + 0.2
250 | ret[idx] = b[idx] + 0.2 # Why? I don't know.
251 | return ret, x
252 |
253 | class SinusoidalAnalyzer(nn.Module):
254 | def __init__(
255 | self,
256 | sampling_rate: int,
257 | hop_size : int,
258 | max_nhar : int,
259 | relative_winsize: int,
260 | device : str = 'cpu',
261 | ):
262 | super().__init__()
263 | self.sampling_rate = sampling_rate
264 | self.hop_size = hop_size
265 | self.max_nhar = max_nhar
266 | self.relative_winsize = relative_winsize
267 | self.device = device
268 |
269 | def forward(
270 | self,
271 | x : torch.Tensor,
272 | f0 : torch.Tensor,
273 |         model : str = 'czt' # only 'czt' works for now
274 | ) -> Tuple[torch.Tensor, torch.Tensor]:
275 | """
276 | Analyze the given audio signal to extract sinusoidal parameters.
277 |
278 | Args:
279 | x (torch.Tensor): Audio signal tensor of shape (t,).
280 |             f0 (torch.Tensor): F0 tensor of shape (n_frames,).
281 |             model (str): Type of sinusoidal analysis model ('czt' or 'qifft';
282 |                 only 'czt' is currently reliable).
283 |
284 | Returns:
285 |             Tuple[torch.Tensor, torch.Tensor]: Amplitude and phase, each of shape (max_nhar, n_frames).
286 | """
287 | if model == 'czt':
288 | ampl, phase = sinusoidal_analysis_czt_for(
289 | x, f0, self.sampling_rate, self.hop_size, self.max_nhar, self.relative_winsize
290 | )
291 | elif model == 'qifft':
292 | ampl, phase = sinusoidal_analysis_qifft(
293 | x, self.sampling_rate, self.hop_size, f0, self.max_nhar, self.relative_winsize
294 | )
295 | else:
296 | raise ValueError(f" [x] Unknown sinusoidal analysis model: {model}")
297 |
298 | return ampl, phase
299 |
300 |
301 |
302 | class F0Analyzer(nn.Module):
303 | def __init__(
304 | self,
305 | sampling_rate: int,
306 | f0_extractor : str,
307 | hop_size : int,
308 | f0_min : float,
309 | f0_max : float,
310 | ):
311 | """
312 | Args:
313 | sampling_rate (int): Sampling rate of the audio signal.
314 | f0_extractor (str): Type of F0 extractor ('parselmouth', 'dio', or 'harvest').
315 | f0_min (float): Minimum F0 in Hz.
316 | f0_max (float): Maximum F0 in Hz.
317 | hop_size (int): Hop size in samples.
318 | """
319 | super(F0Analyzer, self).__init__()
320 | self.sampling_rate = sampling_rate
321 | self.f0_extractor = f0_extractor
322 | self.hop_size = hop_size
323 | self.f0_min = f0_min
324 | self.f0_max = f0_max
325 |
326 | def forward(
327 | self,
328 | x: torch.Tensor,
329 | n_frames: int
330 | ) -> torch.Tensor:
331 | """
332 | Analyze the given audio signal to extract F0.
333 |
334 | Args:
335 | x (torch.Tensor): Audio signal tensor of shape (t,).
336 | n_frames (int): Number of frames, equal to the length of mel_spec.
337 |
338 | Returns:
339 |             Extracted f0 and uv mask (np.ndarray), each of shape (n_frames,).
340 | """
341 | x = x.to('cpu').numpy()
342 |
343 | if self.f0_extractor == 'parselmouth':
344 | f0 = self._extract_f0_parselmouth(x, n_frames)
345 | elif self.f0_extractor == 'dio':
346 | f0 = self._extract_f0_dio(x, n_frames)
347 | elif self.f0_extractor == 'harvest':
348 | f0 = self._extract_f0_harvest(x, n_frames)
349 | else:
350 | raise ValueError(f" [x] Unknown f0 extractor: {self.f0_extractor}")
351 |
352 | uv = f0 == 0
353 | return f0, uv
354 |
355 | def _extract_f0_parselmouth(self, x: np.ndarray, n_frames):
356 | l_pad = int(
357 | np.ceil(
358 | 1.5 / self.f0_min * self.sampling_rate
359 | )
360 | )
361 | r_pad = self.hop_size * ((len(x) - 1) // self.hop_size + 1) - len(x) + l_pad + 1
362 | padded_signal = np.pad(x, (l_pad, r_pad))
363 |
364 | sound = pm.Sound(padded_signal, self.sampling_rate)
365 | pitch = sound.to_pitch_ac(
366 | time_step=self.hop_size / self.sampling_rate,
367 | voicing_threshold=0.6,
368 | pitch_floor=self.f0_min,
369 | pitch_ceiling=1100
370 | )
371 |
372 | f0 = pitch.selected_array['frequency']
373 | if len(f0) < n_frames:
374 | f0 = np.pad(f0, (0, n_frames - len(f0)))
375 | f0 = f0[:n_frames]
376 |
377 | return f0
378 |
379 | def _extract_f0_dio(self, x: np.ndarray, n_frames: int) -> np.ndarray:
380 | _f0, t = pw.dio(
381 | x.astype('double'),
382 | self.sampling_rate,
383 | f0_floor=self.f0_min,
384 | f0_ceil=self.f0_max,
385 | channels_in_octave=2,
386 | frame_period=(1000 * self.hop_size / self.sampling_rate)
387 | )
388 |
389 | f0 = pw.stonemask(x.astype('double'), _f0, t, self.sampling_rate)
390 | return f0.astype('float')[:n_frames]
391 |
392 | def _extract_f0_harvest(self, x: np.ndarray, n_frames: int) -> np.ndarray:
393 | f0, _ = pw.harvest(
394 | x.astype('double'),
395 | self.sampling_rate,
396 | f0_floor=self.f0_min,
397 | f0_ceil=self.f0_max,
398 | frame_period=(1000 * self.hop_size / self.sampling_rate)
399 | )
400 | return f0.astype('float')[:n_frames]
401 |
402 |
403 | class MelAnalysis(nn.Module):
404 | def __init__(
405 | self,
406 | sampling_rate: int,
407 | win_size : int,
408 | hop_size : int,
409 | n_mels : int,
410 | n_fft : Optional[int] = None,
411 | mel_fmin : float = 0.0,
412 | mel_fmax : Optional[float] = None,
413 | clamp : float = 1e-5,
414 | device : str = 'cpu',
415 | ):
416 | super().__init__()
417 | n_fft = win_size if n_fft is None else n_fft
418 | self.hann_window: Dict[str, torch.Tensor] = {}
419 |
420 | mel_basis = get_mel_fn(
421 | sr=sampling_rate,
422 | n_fft=n_fft,
423 | n_mels=n_mels,
424 | fmin=mel_fmin,
425 | fmax=mel_fmax,
426 | htk=False,
427 | device=device,
428 | )
429 | mel_basis = mel_basis.float()
430 | self.register_buffer("mel_basis", mel_basis)
431 |
432 | self.hop_size: int = hop_size
433 | self.win_size: int = win_size
434 | self.n_fft : int = n_fft
435 | self.clamp : float = clamp
436 |
437 | def _get_hann_window(self, win_size: int, device) -> torch.Tensor:
438 | key: str = f"{win_size}_{device}"
439 | if key not in self.hann_window:
440 | self.hann_window[key] = torch.hann_window(win_size).to(device)
441 | return self.hann_window[key]
442 |
443 | def _get_mel(
444 | self,
445 | audio: torch.Tensor,
446 | keyshift: float = 0.0,
447 | speed: float = 1.0,
448 | diffsinger: bool = True
449 | ) -> torch.Tensor:
450 | factor: float = 2 ** (keyshift / 12)
451 | n_fft_new: int = int(np.round(self.n_fft * factor))
452 | win_size_new: int = int(np.round(self.win_size * factor))
453 | hop_size_new: int = int(np.round(self.hop_size * speed))
454 | hann_window_new: torch.Tensor = self._get_hann_window(win_size_new, audio.device)
455 |
456 |         # handle stereo input by keeping the first channel
457 |         if len(audio.shape) == 2:
458 |             print("stereo signal, using the first channel")
459 | audio = audio[:, 0]
460 |
461 | if diffsinger:
462 | audio = F.pad(
463 | audio.unsqueeze(0),
464 | ((win_size_new - hop_size_new) // 2, (win_size_new - hop_size_new + 1) // 2),
465 | mode='reflect'
466 | ).squeeze(0)
467 | center: bool = False
468 | else:
469 | center = True
470 |
471 | fft = torch.stft(
472 | audio,
473 | n_fft=n_fft_new,
474 | hop_length=hop_size_new,
475 | win_length=win_size_new,
476 | window=hann_window_new,
477 | center=center,
478 | return_complex=True
479 | )
480 | magnitude: torch.Tensor = fft.abs()
481 |
482 | if keyshift != 0:
483 | size: int = self.n_fft // 2 + 1
484 | resize: int = magnitude.size(1)
485 | if resize < size:
486 | magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
487 | magnitude = magnitude[:, :size, :] * self.win_size / win_size_new
488 |
489 | mel_output: torch.Tensor = torch.matmul(self.mel_basis, magnitude)
490 | return mel_output
491 |
492 | def forward(
493 | self,
494 | audio: torch.Tensor,
495 | keyshift: float = 0.0,
496 | speed: float = 1.0,
497 | diffsinger:bool = True,
498 | mel_base: str = 'e'
499 | ) -> torch.Tensor:
500 | if torch.min(audio) < -1.0 or torch.max(audio) > 1.0:
501 | print('Audio values exceed [-1., 1.] range')
502 |
503 | mel: torch.Tensor = self._get_mel(audio, keyshift=keyshift, speed=speed, diffsinger=diffsinger)
504 | log_mel_spec: torch.Tensor = torch.log(torch.clamp(mel, min=self.clamp))
505 |
506 | if mel_base != 'e':
507 | assert mel_base in ['10', 10], "mel_base must be 'e', '10' or 10."
508 | log_mel_spec *= 0.434294 # Convert log to base 10
509 |
510 | # mel shape: (n_mels, n_frames)
511 |
512 | return log_mel_spec
513 |
514 | if __name__ == '__main__':
515 | # test
516 | from utils import interp_f0
517 |
518 | audiopath = r'E:/pc-ddsp5.29/TheKitchenettes_Alive_RAW_08_030.wav'
519 | audio=sf.read(audiopath)[0]
520 | audio= torch.from_numpy(audio).float()
521 |
522 | sampling_rate = 44100
523 | hop_size = 512
524 | n_fft = 2048
525 | n_mels = 128
526 |
527 | mel_extractor = MelAnalysis(
528 | sampling_rate=sampling_rate,
529 | win_size=n_fft,
530 | hop_size=hop_size,
531 | n_mels=n_mels,
532 | n_fft=n_fft,
533 | mel_fmin=40,
534 | mel_fmax=16000,
535 | clamp=1e-6,
536 | device='cpu',
537 | )
538 |
539 | mel=mel_extractor(audio, keyshift=0.0, speed=1.0, diffsinger=True, mel_base='e')
540 | print(mel.shape)
541 |     n_frames = mel.shape[1]
542 |     print(n_frames)
543 | print(audio.shape)
544 | print(audio.shape[0]//hop_size)
545 |
546 | f0_analyzer = F0Analyzer(
547 | sampling_rate=sampling_rate,
548 | f0_extractor='parselmouth',
549 | hop_size=hop_size,
550 | f0_min=40,
551 | f0_max=800,
552 | )
553 |
554 |     f0, uv = f0_analyzer(audio, n_frames=n_frames)
555 |
556 | print(f0.shape)
557 | f0, _ = interp_f0(f0, uv)
558 |
559 |
560 | ampl, phase = sinusoidal_analysis_czt_for(
561 | audio=audio,
562 | sr=sampling_rate,
563 | hop_size=hop_size,
564 | f0=torch.from_numpy(f0),
565 | max_nhar=256,
566 | relative_winsize=4,
567 | )
568 |
569 | vocoder = Sine_Generator(hop_size=hop_size, sampling_rate=sampling_rate, device='cpu')
570 |     y = vocoder(ampl.unsqueeze(0), phase.unsqueeze(0), torch.from_numpy(f0).view(1, 1, -1) * torch.arange(1, 257).view(1, -1, 1))  # f0_list must hold k*f0 per harmonic
571 |
572 | print(y.shape)
573 |
574 |     sf.write('test.wav', y.squeeze(0).numpy(), sampling_rate, subtype='FLOAT')
575 |
576 |
577 |
--------------------------------------------------------------------------------
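Note: `czt()` above is Bluestein's algorithm, so its output can be checked against the direct sum it accelerates, y[k] = Σₙ x[n]·A⁻ⁿ·Wⁿᵏ. A hedged verification sketch (sizes and tolerance are arbitrary):

```python
import numpy as np
import torch
from ddsp.audio_analysis import czt

n, m = 64, 16
x = torch.randn(n, dtype=torch.double)
theta = torch.tensor(2 * np.pi * 440.0 / 44100, dtype=torch.double)
A = torch.exp(1j * theta)    # A_0 = 1, start point on the unit circle
W = torch.exp(-1j * theta)   # W_0 = 1, step between evaluated frequencies

fast = czt(x, m, A, W)
ns = torch.arange(n, dtype=torch.double)
ks = torch.arange(m, dtype=torch.double)
naive = ((x * A ** -ns) * W ** (ks.unsqueeze(1) * ns.unsqueeze(0))).sum(-1)
print(torch.allclose(fast, naive, atol=1e-8))  # True
```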
/ddsp/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchaudio
6 |
7 | from .utils import upsample
8 |
9 | class HybridLoss(nn.Module):
10 | def __init__(
11 | self,
12 | block_size,
13 | fft_min,
14 | fft_max,
15 | n_scale,
16 | lambda_uv,
17 | lambda_ampl,
18 | lambda_phase,
19 | device
20 | ):
21 | super().__init__()
22 | self.loss_rss_func = RSSLoss(fft_min, fft_max, n_scale, device = device)
23 | self.loss_uv_func = UVLoss(block_size)
24 | self.loss_ampl_func = AmplLoss()
25 | self.loss_phase_func = PhaseLoss()
26 | self.lambda_uv = lambda_uv
27 | self.lambda_ampl = lambda_ampl
28 | self.lambda_phase = lambda_phase
29 |
30 | def forward(
31 | self,
32 | signal,
33 | s_h,
34 | ampl,
35 | phase,
36 | x_true,
37 | uv_true,
38 | ampl_true,
39 | phase_true,
40 | detach_uv=False,
41 | uv_tolerance=0.05
42 | ):
43 | loss_rss = self.loss_rss_func(signal, x_true)
44 | loss_uv = self.loss_uv_func(signal, s_h, uv_true)
45 | if detach_uv or loss_uv < uv_tolerance:
46 | loss_uv = loss_uv.detach()
47 | loss_ampl = self.loss_ampl_func(ampl, ampl_true)
48 | # loss_phase = self.loss_phase_func(phase, phase_true)
49 | loss_phase = torch.tensor(0.,device="cuda")
50 | loss = loss_rss + self.lambda_uv * loss_uv + self.lambda_ampl * loss_ampl + self.lambda_phase * loss_phase
51 | return loss, (loss_rss, loss_uv, loss_ampl, loss_phase)
52 |
53 | class UVLoss(nn.Module):
54 | def __init__(self, block_size, eps = 1e-5):
55 | super().__init__()
56 | self.block_size = block_size
57 | self.eps = eps
58 |
59 | def forward(self, signal, s_h, uv_true):
60 | uv_mask = upsample(uv_true.unsqueeze(1), self.block_size).squeeze(1)
61 | loss = torch.mean(torch.linalg.norm(s_h * uv_mask, dim = 1) / (torch.linalg.norm(signal * uv_mask , dim = 1) + self.eps))
62 | return loss
63 |
64 | class AmplLoss(nn.Module):
65 | def __init__(self, alpha=1.0, eps = 1e-5):
66 | super().__init__()
67 | self.eps = eps
68 | self.alpha = alpha
69 |
70 | def forward(self, ampl_pred, ampl_true):
71 | ampl_true = ampl_true + self.eps
72 | ampl_pred = ampl_pred + self.eps
73 |
74 | converge_term = torch.mean(
75 | torch.linalg.norm(
76 | ampl_true - ampl_pred,
77 | dim = (1, 2)
78 | ) /
79 | torch.linalg.norm(
80 | ampl_true + ampl_pred,
81 | dim = (1, 2)
82 | )
83 | )
84 |
85 | log_term = F.l1_loss(ampl_pred.log(), ampl_true.log())
86 |
87 | return converge_term + self.alpha * log_term
88 |
89 |
90 | class PhaseLoss(nn.Module):
91 | def __init__(self, eps = 1e-5):
92 | super().__init__()
93 | self.eps = eps
94 |
95 | @staticmethod
96 | def unwrap(x):
97 | return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
98 |
99 | def GD_loss(self, phase_pred, phase_true):
100 | gd_true_diff = phase_true[:,1:,:] - phase_true[:,:-1,:]
101 | gd_pred_diff = phase_pred[:,1:,:] - phase_pred[:,:-1,:]
102 | gd_loss = torch.mean(self.unwrap(gd_true_diff - gd_pred_diff))
103 | return gd_loss
104 |
105 | def PTD_loss(self, phase_pred, phase_true):
106 | ptd_true_diff = phase_true[:,:,1:] - phase_true[:,:,:-1]
107 | ptd_pred_diff = phase_pred[:,:,1:] - phase_pred[:,:,:-1]
108 | ptd_loss = torch.mean(self.unwrap(ptd_true_diff - ptd_pred_diff))
109 | return ptd_loss
110 |
111 |
112 | def forward(self, phase_pred, phase_true):
113 | gd_loss = self.GD_loss(phase_pred, phase_true)
114 | ptd_loss = self.PTD_loss(phase_pred, phase_true)
115 | loss = gd_loss + ptd_loss
116 | return loss
117 |
118 |
119 | class SSSLoss(nn.Module):
120 | """
121 | Single-scale Spectral Loss.
122 | """
123 |
124 | def __init__(self, n_fft=111, alpha=1.0, overlap=0, eps=1e-7):
125 | super().__init__()
126 | self.n_fft = n_fft
127 | self.alpha = alpha
128 | self.eps = eps
129 |         self.hop_length = int(n_fft * (1 - overlap))  # e.g. 25% of n_fft at overlap=0.75
130 | self.spec = torchaudio.transforms.Spectrogram(
131 | n_fft = self.n_fft,
132 | hop_length = self.hop_length,
133 | power=1,
134 | normalized=True,
135 | center=False
136 | )
137 |
138 | def forward(self, x_true, x_pred):
139 | S_true = self.spec(x_true) + self.eps
140 | S_pred = self.spec(x_pred) + self.eps
141 |
142 | converge_term = torch.mean(
143 | torch.linalg.norm(
144 | S_true - S_pred, dim = (1, 2)
145 | ) /
146 | torch.linalg.norm(
147 | S_true + S_pred, dim = (1, 2)
148 | )
149 | )
150 |
151 | log_term = F.l1_loss(S_true.log(), S_pred.log())
152 |
153 | loss = converge_term + self.alpha * log_term
154 | return loss
155 |
156 |
157 | class MSSLoss(nn.Module):
158 | """
159 | Multi-scale Spectral Loss.
160 | Usage ::
161 | mssloss = MSSLoss([2048, 1024, 512, 256], alpha=1.0, overlap=0.75)
162 | mssloss(y_pred, y_gt)
163 | input(y_pred, y_gt) : two of torch.tensor w/ shape(batch, 1d-wave)
164 | output(loss) : torch.tensor(scalar)
165 |
166 | 48k: n_ffts=[2048, 1024, 512, 256]
167 | 24k: n_ffts=[1024, 512, 256, 128]
168 | """
169 |
170 | def __init__(self, n_ffts, alpha=1.0, overlap=0.75, eps=1e-7):
171 | super().__init__()
172 | self.losses = nn.ModuleList([SSSLoss(n_fft, alpha, overlap, eps) for n_fft in n_ffts])
173 |
174 | def forward(self, x_pred, x_true):
175 | x_pred = x_pred[..., :x_true.shape[-1]]
176 | value = 0.
177 | for loss in self.losses:
178 | value += loss(x_true, x_pred)
179 | return value
180 |
181 | class RSSLoss(nn.Module):
182 | '''
183 | Random-scale Spectral Loss.
184 | '''
185 |
186 | def __init__(self, fft_min, fft_max, n_scale, alpha=1.0, overlap=0, eps=1e-7, device='cuda'):
187 | super().__init__()
188 | self.fft_min = fft_min
189 | self.fft_max = fft_max
190 | self.n_scale = n_scale
191 | self.lossdict = {}
192 |         for n_fft in range(fft_min, fft_max):  # pre-build one SSSLoss per FFT size in [fft_min, fft_max)
193 | self.lossdict[n_fft] = SSSLoss(n_fft, alpha, overlap, eps).to(device)
194 |
195 | def forward(self, x_pred, x_true):
196 | value = 0.
197 | n_ffts = torch.randint(self.fft_min, self.fft_max, (self.n_scale,))
198 | for n_fft in n_ffts:
199 | loss_func = self.lossdict[int(n_fft)]
200 | value += loss_func(x_true, x_pred)
201 | return value / self.n_scale
202 |
203 |
204 |
205 |
--------------------------------------------------------------------------------
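Note: `RSSLoss` pre-builds one `SSSLoss` (and one torchaudio spectrogram) per FFT size in `[fft_min, fft_max)`, then averages `n_scale` randomly drawn scales per call. A minimal usage sketch (a small FFT range is assumed here to keep construction cheap; the config uses 256 to 2048):

```python
import torch
from ddsp.loss import RSSLoss

loss_fn = RSSLoss(fft_min=256, fft_max=512, n_scale=4, device='cpu')
x_pred = torch.randn(2, 44100)  # (batch, samples)
x_true = torch.randn(2, 44100)
print(loss_fn(x_pred, x_true))  # scalar tensor; stochastic across calls
```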
/ddsp/mel2control.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils import weight_norm
4 |
5 | from ddsp.model_conformer_naive import ConformerNaiveEncoder
6 |
7 | def split_to_dict(tensor, tensor_splits):
8 | """Split a tensor into a dictionary of multiple tensors."""
9 | labels = []
10 | sizes = []
11 |
12 | for k, v in tensor_splits.items():
13 | labels.append(k)
14 | sizes.append(v)
15 |
16 | tensors = torch.split(tensor, sizes, dim=-1)
17 | return dict(zip(labels, tensors))
18 |
19 |
20 | class Mel2Control(nn.Module):
21 | def __init__(
22 | self,
23 | n_mels,
24 | n_sin_hars,
25 | block_size,
26 | output_splits
27 | ):
28 | super().__init__()
29 | self.output_splits = output_splits
30 | self.mel_emb = nn.Linear(n_mels, 256)
31 | self.phase_emb = nn.Linear(n_sin_hars, 256)
32 | self.decoder = ConformerNaiveEncoder(
33 | num_layers=3,
34 | num_heads=8,
35 | dim_model=256,
36 | use_norm=False,
37 | conv_only=True,
38 | conv_dropout=0,
39 | atten_dropout=0.1)
40 | self.norm = nn.LayerNorm(256)
41 | self.n_out = sum([v for k, v in output_splits.items()])
42 | self.dense_out = weight_norm(nn.Linear(256, self.n_out))
43 |
44 | def forward(self, mel, inp):
45 |
46 | '''
47 | input:
48 | B x n_frames x n_mels
49 | return:
50 | dict of B x n_frames x feat
51 | '''
52 | # print("mel2control mel",mel.shape)
53 | x = self.mel_emb(mel) + self.phase_emb(inp)
54 | x = self.decoder(x)
55 | x = self.norm(x)
56 | e = self.dense_out(x)
57 | controls = split_to_dict(e, self.output_splits)
58 |
59 | return controls
60 |
61 |
--------------------------------------------------------------------------------
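Note: `split_to_dict()` turns the decoder's flat output into named control tensors; the sizes below mirror the `split_map` built in `SinStack.__init__` (`ddsp/vocoder.py`). A quick illustration:

```python
import torch
from ddsp.mel2control import split_to_dict

splits = {'sin_mag': 128, 'sin_phase': 128, 'noise_mag': 64}
e = torch.randn(2, 100, sum(splits.values()))  # B x n_frames x n_out
ctrls = split_to_dict(e, splits)
print({k: tuple(v.shape) for k, v in ctrls.items()})
# {'sin_mag': (2, 100, 128), 'sin_phase': (2, 100, 128), 'noise_mag': (2, 100, 64)}
```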
/ddsp/model_conformer_naive.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | # From https://github.com/CNChTu/Diffusion-SVC/ by CNChTu
5 | # License: MIT
6 |
7 |
8 | class ConformerNaiveEncoder(nn.Module):
9 | """
10 | Conformer Naive Encoder
11 |
12 | Args:
13 | dim_model (int): Dimension of model
14 | num_layers (int): Number of layers
15 | num_heads (int): Number of heads
16 | use_norm (bool): Whether to use norm for FastAttention, only True can use bf16/fp16, default False
17 | conv_only (bool): Whether to use only conv module without attention, default False
18 | conv_dropout (float): Dropout rate of conv module, default 0.
19 | atten_dropout (float): Dropout rate of attention module, default 0.
20 | """
21 |
22 | def __init__(self,
23 | num_layers: int,
24 | num_heads: int,
25 | dim_model: int,
26 | use_norm: bool = False,
27 | conv_only: bool = False,
28 | conv_dropout: float = 0.,
29 | atten_dropout: float = 0.
30 | ):
31 | super().__init__()
32 | self.num_layers = num_layers
33 | self.num_heads = num_heads
34 | self.dim_model = dim_model
35 | self.use_norm = use_norm
36 |         self.residual_dropout = 0.1 # deprecated; kept only for checkpoint compatibility
37 |         self.attention_dropout = 0.1 # deprecated; kept only for checkpoint compatibility
38 |
39 | self.encoder_layers = nn.ModuleList(
40 | [
41 | CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout)
42 | for _ in range(num_layers)
43 | ]
44 | )
45 |
46 | def forward(self, x, mask=None) -> torch.Tensor:
47 | """
48 | Args:
49 | x (torch.Tensor): Input tensor (#batch, length, dim_model)
50 | mask (torch.Tensor): Mask tensor, default None
51 | return:
52 | torch.Tensor: Output tensor (#batch, length, dim_model)
53 | """
54 |
55 | for (i, layer) in enumerate(self.encoder_layers):
56 | x = layer(x, mask)
57 | return x # (#batch, length, dim_model)
58 |
59 |
60 | class CFNEncoderLayer(nn.Module):
61 | """
62 | Conformer Naive Encoder Layer
63 |
64 | Args:
65 | dim_model (int): Dimension of model
66 | num_heads (int): Number of heads
67 | use_norm (bool): Whether to use norm for FastAttention, only True can use bf16/fp16, default False
68 | conv_only (bool): Whether to use only conv module without attention, default False
69 |         conv_dropout (float): Dropout rate of conv module, default 0.
70 | atten_dropout (float): Dropout rate of attention module, default 0.1
71 | """
72 |
73 | def __init__(self,
74 | dim_model: int,
75 | num_heads: int = 8,
76 | use_norm: bool = False,
77 | conv_only: bool = False,
78 | conv_dropout: float = 0.,
79 | atten_dropout: float = 0.1
80 | ):
81 | super().__init__()
82 |
83 | self.conformer = ConformerConvModule(dim_model, use_norm=use_norm, dropout=conv_dropout)
84 |
85 | self.norm = nn.LayerNorm(dim_model)
86 |
87 |         self.dropout = nn.Dropout(0.1) # deprecated; kept only for checkpoint compatibility
88 |
89 |         # attention branch: a standard nn.TransformerEncoderLayer
90 | if not conv_only:
91 | self.attn = nn.TransformerEncoderLayer(
92 | d_model=dim_model,
93 | nhead=num_heads,
94 | dim_feedforward=dim_model * 4,
95 | dropout=atten_dropout,
96 | activation='gelu'
97 | )
98 | else:
99 | self.attn = None
100 |
101 | def forward(self, x, mask=None) -> torch.Tensor:
102 | """
103 | Args:
104 | x (torch.Tensor): Input tensor (#batch, length, dim_model)
105 | mask (torch.Tensor): Mask tensor, default None
106 | return:
107 | torch.Tensor: Output tensor (#batch, length, dim_model)
108 | """
109 | if self.attn is not None:
110 |             x = x + (self.attn(self.norm(x), src_mask=mask))  # TransformerEncoderLayer takes src_mask, not mask
111 |
112 | x = x + (self.conformer(x))
113 |
114 | return x # (#batch, length, dim_model)
115 |
116 |
117 | class ConformerConvModule(nn.Module):
118 | def __init__(
119 | self,
120 | dim,
121 | expansion_factor=2,
122 | kernel_size=31,
123 | dropout=0.,
124 | use_norm=False,
125 | conv_model_type='mode1'
126 | ):
127 | super().__init__()
128 |
129 | inner_dim = dim * expansion_factor
130 | padding = calc_same_padding(kernel_size)
131 |
132 | if conv_model_type == 'mode1':
133 | self.net = nn.Sequential(
134 | nn.LayerNorm(dim) if use_norm else nn.Identity(),
135 | Transpose((1, 2)),
136 | nn.Conv1d(dim, inner_dim * 2, 1),
137 | nn.GLU(dim=1),
138 | nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding[0], groups=inner_dim),
139 | nn.PReLU(num_parameters=inner_dim),
140 | nn.Conv1d(inner_dim, dim, 1),
141 | Transpose((1, 2)),
142 | nn.Dropout(dropout)
143 | )
144 | elif conv_model_type == 'mode2':
145 | raise NotImplementedError('mode2 not implemented yet')
146 | else:
147 | raise ValueError(f'{conv_model_type} is not a valid conv_model_type')
148 |
149 | def forward(self, x):
150 | return self.net(x)
151 |
152 |
153 | def calc_same_padding(kernel_size):
154 | pad = kernel_size // 2
155 | return (pad, pad - (kernel_size + 1) % 2)
156 |
157 |
158 | class Transpose(nn.Module):
159 | def __init__(self, dims):
160 | super().__init__()
161 | assert len(dims) == 2, 'dims must be a tuple of two dimensions'
162 | self.dims = dims
163 |
164 | def forward(self, x):
165 | return x.transpose(*self.dims)
166 |
--------------------------------------------------------------------------------
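Note: a shape smoke test for `ConformerNaiveEncoder` with the exact arguments `Mel2Control` uses (3 conv-only layers, dim 256); a hedged sketch, not part of the repo:

```python
import torch
from ddsp.model_conformer_naive import ConformerNaiveEncoder

enc = ConformerNaiveEncoder(num_layers=3, num_heads=8, dim_model=256,
                            use_norm=False, conv_only=True,
                            conv_dropout=0., atten_dropout=0.1)
x = torch.randn(2, 100, 256)  # (batch, length, dim_model)
print(enc(x).shape)           # torch.Size([2, 100, 256])
```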
/ddsp/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from scipy.interpolate import CubicSpline
5 |
6 | def get_mel_fn(
7 | sr : float,
8 | n_fft : int,
9 | n_mels : int,
10 | fmin : float,
11 | fmax : float,
12 | htk : bool,
13 | device : str = 'cpu'
14 | ) -> torch.Tensor:
15 | '''
16 | Args:
17 | htk: bool
18 | Whether to use HTK formula or Slaney formula for mel calculation'
19 | Returns:
20 | weights: Tensor [shape = (n_mels, n_fft // 2 + 1)]
21 | '''
22 | fmin = torch.tensor(fmin, device=device)
23 | fmax = torch.tensor(fmax, device=device)
24 |
25 | if htk:
26 | min_mel = 2595.0 * torch.log10(1.0 + fmin / 700.0)
27 | max_mel = 2595.0 * torch.log10(1.0 + fmax / 700.0)
28 | mels = torch.linspace(min_mel, max_mel, n_mels + 2, device=device)
29 | mel_f = 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
30 | else:
31 | f_sp = 200.0 / 3
32 | min_log_hz = 1000.0
33 | min_log_mel = (min_log_hz) / f_sp
34 | logstep = torch.log(torch.tensor(6.4, device=device)) / 27.0
35 |
36 | if fmin >= min_log_hz:
37 | min_mel = min_log_mel + torch.log(fmin / min_log_hz) / logstep
38 | else:
39 | min_mel = (fmin) / f_sp
40 |
41 | if fmax >= min_log_hz:
42 | max_mel = min_log_mel + torch.log(fmax / min_log_hz) / logstep
43 | else:
44 | max_mel = (fmax) / f_sp
45 |
46 | mels = torch.linspace(min_mel, max_mel, n_mels + 2, device=device)
47 | mel_f = torch.zeros_like(mels)
48 |
49 | log_t = mels >= min_log_mel
50 | mel_f[~log_t] =f_sp * mels[~log_t]
51 | mel_f[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
52 |
53 | n_mels = int(n_mels)
54 | N = 1 + n_fft // 2
55 | weights = torch.zeros((n_mels, N), device=device)
56 |
57 | fftfreqs = (sr / n_fft) * torch.arange(0, N, device=device)
58 |
59 | fdiff = torch.diff(mel_f)
60 | ramps = mel_f.unsqueeze(1) - fftfreqs.unsqueeze(0)
61 |
62 | lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
63 | upper = ramps[2:] / fdiff[1:].unsqueeze(1)
64 |     weights = torch.clamp(torch.min(lower, upper), min=0.0)  # avoids a CPU/GPU device mismatch from torch.max(torch.tensor(0.0), ...)
65 |
66 | enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
67 | weights *= enorm.unsqueeze(1)
68 |
69 | return weights
70 |
71 | def expand_uv(uv):
72 | uv = uv.astype('float')
73 | uv = np.min(np.array([uv[:-2],uv[1:-1],uv[2:]]),axis=0)
74 | uv = np.pad(uv, (1, 1), constant_values=(uv[0], uv[-1]))
75 |
76 | return uv
77 |
78 |
79 | def norm_f0(f0: np.ndarray, uv=None):
80 | if uv is None:
81 | uv = f0 == 0
82 |
83 | f0 = np.log2(f0 + uv) # avoid arithmetic error
84 | f0[uv] = -np.inf
85 |
86 | return f0
87 |
88 | def denorm_f0(f0: np.ndarray, uv, pitch_padding=None):
89 | f0 = 2 ** f0
90 |
91 | if uv is not None:
92 | f0[uv > 0] = 0
93 |
94 | if pitch_padding is not None:
95 | f0[pitch_padding] = 0
96 |
97 | return f0
98 |
99 |
100 | def interp_f0_spline(f0: np.ndarray, uv=None):
101 | if uv is None:
102 | uv = f0 == 0
103 | f0max = np.max(f0)
104 | f0 = norm_f0(f0, uv)
105 |
106 | if uv.any() and not uv.all():
107 | spline = CubicSpline(np.where(~uv)[0], f0[~uv])
108 | f0[uv] = spline(np.where(uv)[0])
109 |
110 | return np.clip(denorm_f0(f0, uv=None),0,f0max), uv
111 |
112 | def interp_f0(f0: np.ndarray, uv=None):
113 | if uv is None:
114 | uv = f0 == 0
115 | f0 = norm_f0(f0, uv)
116 |
117 | if uv.any() and not uv.all():
118 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
119 |
120 | return denorm_f0(f0, uv=None), uv
121 |
122 |
123 | def get_n_fft(f0: torch.Tensor, sr: int, relative_winsize: int):
124 | '''
125 | Args:
126 | f0: Tensor [shape = (n_frames)]
127 | relative_winsize : int
128 |             Window size relative to the f0 period (in periods, not seconds)
129 | Returns:
130 | n_fft: int
131 | '''
132 |     # drop f0 values below 20 Hz; occasional tiny f0 values would make n_fft blow up
133 | f0 = f0[f0 > 20]
134 | f0_min = f0.min()
135 | if f0_min > 1000:
136 | f0_min = torch.tensor(1000)
137 |
138 | max_winsize = torch.round(sr / f0_min * relative_winsize / 2) * 2
139 | n_fft = 2 ** torch.ceil(torch.log2(max_winsize))
140 |
141 | return n_fft.int(), f0_min
142 |
143 | def upsample(signal, factor):
144 | '''
145 | signal: B x C X T
146 | factor: int
147 | return: B x C X T*factor
148 | '''
149 | signal = nn.functional.interpolate(
150 | torch.cat(
151 | (signal,signal[:,:,-1:]),
152 | 2
153 | ),
154 | size=signal.shape[-1] * factor + 1,
155 | mode='linear',
156 | align_corners=True
157 | )
158 | signal = signal[:,:,:-1]
159 | return signal
160 |
161 |
162 | if __name__ == '__main__':
163 | # test
164 | '''librosa_mel = librosa_mel_fn(sr=44100, n_fft=2048, n_mels=128, fmin=20, fmax=22050, htk=False)
165 | custom_mel = get_mel_fn(sr=44100, n_fft=2048, n_mels=128, fmin=20, fmax=22050, htk=False, device='cpu').to('cpu')
166 | print(torch.allclose(torch.tensor(librosa_mel), custom_mel, atol=1e-5))
167 | print(np.max( np.abs(librosa_mel - custom_mel.numpy()) ))
168 |     # plot the two mel filter banks and their difference
169 | plt.figure(figsize=(50, 5))
170 | plt.subplot(3, 1, 1)
171 | plt.imshow(librosa_mel, origin='lower')
172 | plt.title('librosa_mel')
173 | plt.subplot(3, 1, 2)
174 | plt.imshow(custom_mel.numpy(), origin='lower')
175 | plt.title('custom_mel')
176 | plt.subplot(3, 1, 3)
177 | plt.imshow(np.abs(librosa_mel - custom_mel.numpy()), origin='lower')
178 | plt.title('diff')
179 | plt.show()'''
180 |
181 | # test get_n_fft
182 | f0 = torch.tensor([2000], dtype=torch.float32)
183 | sr = 44100
184 | relative_winsize = 4
185 | n_fft = get_n_fft(f0, sr, relative_winsize)
186 | print(n_fft)
187 |
--------------------------------------------------------------------------------
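Note: `upsample()` stretches frame-rate controls to sample rate by linear interpolation; `UVLoss` in `ddsp/loss.py` uses it to expand the uv mask by `hop_size`. A hedged shape check:

```python
import torch
from ddsp.utils import upsample

ctrl = torch.rand(1, 1, 100)      # B x C x T at frame rate
audio_rate = upsample(ctrl, 512)  # linear interpolation to B x C x (T * 512)
print(audio_rate.shape)           # torch.Size([1, 1, 51200])
```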
/ddsp/vocoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import yaml
8 |
9 | from ddsp.mel2control import Mel2Control
10 | from .utils import upsample
11 |
12 | def compute_inphase(f0_sum, hop_size, sampling_rate, device, inference=False):
13 | '''
14 | Args:
15 | f0_sum: [shape = (batch, C, T-1)]
16 | hop_size: int
17 | sampling_rate: int
18 |         device: str
19 |         inference: bool, default False
20 |             when True, the phase cumulative sum runs in float64
21 |             to limit accumulated drift over long audio
22 |     Returns:
23 |         inphase: [shape = (batch, C, T)]
24 |     '''
25 | B, C, _ = f0_sum.shape
26 | if inference:
27 | inphase = torch.cumsum(
28 | (f0_sum * np.pi / sampling_rate * hop_size).double()%(2*np.pi), dim=2
29 | )
30 | else:
31 | inphase = torch.cumsum(
32 | (f0_sum * np.pi / sampling_rate * hop_size)%(2*np.pi), dim=2
33 | )
34 |
35 | inphase = torch.cat(
36 | (torch.zeros(B, C, 1).to(device), inphase), dim=2
37 | ) % (2*np.pi)
38 | return inphase # [batch, C, T]
39 |
40 | def replicate(t, x, batch):
41 | """
42 | Replicates tensor t to have length x.
43 | Args:
44 | t: input tensor [batch, channels, time]
45 | x: output tensor length
46 | batch: int
47 | Returns:
48 | replicated: tensor [batch, channels, x]
49 | """
50 | repeat_times = (x + t.size(-1) - 1) // t.size(-1)
51 | replicated = t.repeat(batch, 1, repeat_times)
52 | replicated = replicated[:,:,:x]
53 | return replicated
54 |
55 | def get_remove_above_fmax(n_harm, pitch, fmax, level_start=1):
56 | '''
57 | Args:
58 | pitch: b x t x 1
59 | fmax: float
60 | level_start: int, default 1
61 | Returns:
62 | aa: b x t x n_harm
63 | '''
64 | pitches = pitch * torch.arange(level_start, n_harm + level_start).to(pitch)
65 | rm = (pitches < fmax).float() + 1e-7
66 | return rm
67 |
68 | class DotDict(dict):
69 | def __getattr__(*args):
70 | val = dict.get(*args)
71 | return DotDict(val) if type(val) is dict else val
72 |
73 | __setattr__ = dict.__setitem__
74 | __delattr__ = dict.__delitem__
75 |
76 | def load_model(model_path, device='cpu'):
77 | config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
78 | with open(config_file, "r") as config:
79 | args = yaml.unsafe_load(config)
80 | args = DotDict(args)
81 |
82 | model_path = Path(model_path)
83 | # load model
84 | print(' [Loading] ' + str(model_path))
85 | if model_path.suffix == '.jit':
86 | model = torch.jit.load(model_path, map_location=torch.device(device))
87 | else:
88 | if args.model.type == 'SinStack':
89 | model = SinStack(
90 | args,
91 | device=device)
92 | else:
93 | raise ValueError(f" [x] Unknown Model: {args.model.type}")
94 | model.to(device)
95 | ckpt = torch.load(model_path, map_location=torch.device(device))
96 | model.load_state_dict(ckpt['model'])
97 | model.eval()
98 | return model, args
99 |
100 | class SinStack(torch.nn.Module):
101 | def __init__(self,
102 | args,
103 | device='cuda'):
104 | super().__init__()
105 |
106 |         print(' [DDSP Model] SinStack Sinusoidal Additive Synthesiser')
107 | # params
108 | self.register_buffer("sampling_rate", torch.tensor(args.data.sampling_rate))
109 | self.register_buffer("hop_size", torch.tensor(args.data.hop_size))
110 | self.register_buffer("win_length", torch.tensor(2*args.data.hop_size))
111 | self.register_buffer("window", torch.hann_window(2*args.data.hop_size))
112 | self.register_buffer("sin_mag", torch.tensor(args.model.n_sin_hars))
113 | self.register_buffer("noise_mag", torch.tensor(args.model.n_noise_bin))
114 | self.register_buffer("uv_noise_k", torch.tensor(args.model.uv_noise_k))
115 |
116 | # Mel2Control
117 | split_map = {
118 | 'sin_mag' : args.model.n_sin_hars,
119 | 'sin_phase': args.model.n_sin_hars,
120 | 'noise_mag': args.model.n_noise_bin,
121 | }
122 |
123 | self.register_buffer("u_noise", \
124 | torch.load('u_noise.ckpt', map_location=torch.device(device)))
125 | self.register_buffer("v_noise", \
126 | torch.load('v_noise.ckpt', map_location=torch.device(device)))
127 |
128 | print(' [DDSP Model] Mel2Control',self.v_noise.shape)
129 |
130 | self.mel2ctrl = Mel2Control(
131 | args.data.n_mels,
132 | args.model.n_sin_hars,
133 | args.data.hop_size,
134 | split_map
135 | )
136 | self.sine_generator = Sine_Generator(
137 | args.data.hop_size,
138 | args.data.sampling_rate,
139 | device=device
140 | )
141 | self.noise_generator = Noise_Generator(
142 | args.data.sampling_rate,
143 | args.data.hop_size,
144 | self.v_noise,
145 | self.u_noise,
146 | device=device,
147 | triangle_ReLU = args.model.triangle_ReLU,
148 | triangle_ReLU_up = args.model.triangle_ReLU_up,
149 | triangle_ReLU_down = args.model.triangle_ReLU_down,
150 | )
151 |
152 | self.device = device
153 |
154 | def phase_prediction(self, phase_pre_model, sin_phase, f0_list,inference):
155 | '''
156 | Args:
157 | f0_list: [shape = (batch, max_nhar, T)]
158 | sin_phase: [shape = (batch, max_nhar, T)]
159 | inference: bool, default False
160 |
161 | Returns:
162 | sin_phase: [shape = (batch, max_nhar, T)]
163 | '''
164 | if phase_pre_model == 'offset':
165 | f0_sum = f0_list[:, :, 1:] + f0_list[:, :, :-1] # [batch, max_nhar, T-1]
166 | inphase = compute_inphase(
167 | f0_sum,
168 | self.hop_size,
169 | self.sampling_rate,
170 | self.device,
171 | inference=inference
172 | )
173 | sin_phase = inphase + sin_phase # [batch, max_nhar, T]
174 | elif phase_pre_model == 'adjacent difference':
175 | f0_sum = (f0_list[:, 0, 1:] + f0_list[:, 0, :-1]).unsqueeze(1) # [batch, 1, T-1]
176 | inphase = compute_inphase(
177 | f0_sum,
178 | self.hop_size,
179 | self.sampling_rate,
180 | self.device,
181 | inference=inference
182 | )
183 |             sin_phase[:, 0, :] = sin_phase[:, 0, :] + inphase.squeeze(1) # add the running phase to the fundamental
184 | sin_phase = torch.cumsum(sin_phase, dim=1) # [batch, max_nhar, T]
185 | elif phase_pre_model == 'fundamental difference':
186 | f0_sum = (f0_list[:, 0, 1:] + f0_list[:, 0, :-1]).unsqueeze(1) # [batch, 1, T-1]
187 | inphase = compute_inphase(
188 | f0_sum,
189 | self.hop_size,
190 | self.sampling_rate,
191 | self.device,
192 | inference=inference
193 | )
194 |             sin_phase[:, 0, :] = sin_phase[:, 0, :] + inphase.squeeze(1) # add the running phase to the fundamental
195 | sin_phase[:, 1:, :] = sin_phase[:, 1:, :] + sin_phase[:, 0, :].unsqueeze(1)
196 | elif phase_pre_model == 'absolute position':
197 | pass
198 | else:
199 | raise ValueError(f" [x] Unknown phase_pre_model: {phase_pre_model}")
200 |
201 | return sin_phase
202 |
203 |
204 | def forward(
205 | self,
206 | mel_frames,
207 | f0_frames,
208 | inference=False,
209 |         phase_pre_model='offset', # 'offset', 'adjacent difference', 'fundamental difference' or 'absolute position'
210 | **kwargs
211 | ):
212 | '''
213 | mel_frames: B x n_mels x n_frames
214 | f0_frames: B x n_frames x 1
215 | '''
216 | nhar_range = torch.arange(
217 | start = 1, end = self.sin_mag + 1, device=self.device
218 | ).unsqueeze(0).unsqueeze(-1) # [max_nhar] -> [1, max_nhar, 1]
219 |
220 | f0_list = f0_frames.unsqueeze(1).squeeze(3) * nhar_range # [batch, 1, T] * [1, max_nhar, 1] -> [batch, max_nhar, T]
221 | inp = self.phase_prediction(phase_pre_model, torch.tensor(0.0), f0_list, inference)
222 |
223 | inp = inp.to(torch.float32)
224 |
225 | # parameter prediction
226 | ctrls = self.mel2ctrl(mel_frames.transpose(1, 2), inp.transpose(1, 2))
227 |
228 | sin_mag = torch.exp(ctrls['sin_mag']) / 128 # b x T x max_nhar
229 | sin_phase = ctrls['sin_phase'] # b x T x max_nhar
230 | noise_mag = torch.exp(ctrls['noise_mag']) / 128 # b x T x n_noise
231 |
232 | # permutation
233 | sin_mag = sin_mag.permute(0, 2, 1) # b x max_nhar x T
234 | sin_phase = sin_phase.permute(0, 2, 1) # b x max_nhar x T
235 | noise_mag = noise_mag.permute(0, 2, 1) # b x n_noise x T
236 | B, max_nhar, T = sin_mag.shape
237 |
238 | # remove above fmax
239 | rm_mask = get_remove_above_fmax(
240 | max_nhar,
241 | f0_frames,
242 | fmax = self.sampling_rate / 2
243 | ).permute(0, 2, 1)
244 | # sin_phase = self.phase_prediction(phase_pre_model, sin_phase, f0_list, inference)
245 |         sin_phase = sin_phase + inp # with 'offset' the initial phase need not be computed twice; for the other modes, comment this line out and use the line above instead
246 | sin_mag, sin_phase = rm_mask * sin_mag, rm_mask * sin_phase
247 |
248 | harmonic = self.sine_generator(sin_mag, sin_phase, f0_list)
249 |
250 | # get uv mask
251 | noise_mag_total = noise_mag.sum(dim=1) # b x T
252 | harmonic_mag_first = sin_mag[:, 0, :] # b x T
253 | uv_mask = self.uv_noise_k *harmonic_mag_first / (self.uv_noise_k*harmonic_mag_first + noise_mag_total + 1e-7) # b x T
254 |
255 | # noise generation
256 | noise = self.noise_generator(noise_mag, uv_mask, f0_frames.permute(0, 2, 1))
257 |
258 | signal = harmonic + noise # [batch, T*hop_size]
259 |
260 | return signal, 0, (harmonic, noise), (sin_mag, sin_phase)
261 |
262 |
263 |
264 | class Sine_Generator(torch.nn.Module):
265 | def __init__(self, hop_size, sampling_rate, device='cpu'):
266 | super().__init__()
267 | self.hop_size = hop_size
268 | self.win_size = hop_size * 2
269 | self.window = torch.hann_window(self.win_size).to(device)
270 | self.sampling_rate = sampling_rate
271 | self.device = device
272 |
273 | def forward(self, ampl, phase, f0_list, inference=False):
274 | '''
275 | ampl: B x max_nhar x T
276 | phase: B x max_nhar x T
277 | f0_list: B x max_nhar x T
278 | '''
279 | B, _, T = ampl.shape
280 | x_list = torch.arange(-self.hop_size, self.hop_size).to(self.device)
281 | x_list = x_list.unsqueeze(0).unsqueeze(0).unsqueeze(0) # 1 x 1 x 1 x win_size
282 |
283 | freq_list = (2 * np.pi * f0_list / self.sampling_rate).unsqueeze(-1) * x_list + phase.unsqueeze(-1) # [batch, max_nhar, T, win_size]
284 |
285 | y_tmp = torch.cos(freq_list) * ampl.unsqueeze(-1) # [batch, max_nhar, T, win_size]
286 | y_tmp = y_tmp.sum(dim=1) # [batch, T, win_size]
287 |
288 | hann_window = self.window.unsqueeze(0).unsqueeze(0) # [1, 1, win_size]
289 | y_tmp_weighted = y_tmp * hann_window # [batch, T, win_size]
290 |
291 | '''# reshape y_tmp_weighted into a layout suitable for fold
292 | y_tmp_weighted_reshaped = y_tmp_weighted.permute(0, 2, 1).contiguous()
293 | y_tmp_weighted_reshaped = y_tmp_weighted_reshaped.view(B * self.win_size, T)
294 | y_tmp_weighted = y_tmp * hann_window # [batch, T, win_size]
295 | # use fold to realize the sliding-window (overlap-add) sum
296 | output = torch.nn.functional.fold(
297 | y_tmp_weighted_reshaped.unsqueeze(0),
298 | output_size=(T * self.hop_size + self.hop_size, 1),
299 | kernel_size=(self.win_size, 1),
300 | stride=(self.hop_size, 1)
301 | )
302 | output = output[:, :, self.hop_size:]
303 |
304 | y_return = output.squeeze(0).view(B, T * self.hop_size)'''
305 |
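    | # Overlap-add without fold: each frame spans win_size = 2*hop_size samples and
    | # overlaps its neighbours by hop_size, so the even- and odd-indexed frames each
    | # form a non-overlapping stream; flattening both streams and summing them with
    | # a hop_size offset reproduces the sliding-window sum.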
306 | y_tmp_padded = F.pad(y_tmp_weighted, (0, 0, 0, T % 2), "constant", 0)
307 | new_T = y_tmp_padded.shape[1]
308 |
309 | y_tmp_reshaped = y_tmp_padded.view(B, new_T//2, 2, self.win_size)
310 | tensor_even = y_tmp_reshaped[:, :, 0, :] # even time steps
311 | tensor_odd = y_tmp_reshaped[:, :, 1, :] # odd time steps
312 |
313 | tensor_even = tensor_even.reshape(B, (new_T//2)*self.win_size)
314 | tensor_odd = tensor_odd.reshape(B, (new_T//2)*self.win_size)
315 |
316 | tensor_even = F.pad(tensor_even, (0, self.hop_size), "constant", 0)
317 | tensor_odd = F.pad(tensor_odd, (self.hop_size, 0), "constant", 0)
318 | cat_tensor = torch.cat((tensor_even.unsqueeze(-1), tensor_odd.unsqueeze(-1)), dim=2)
319 | sum_tensor = torch.sum(cat_tensor, dim=2)
320 | y_return = sum_tensor[:, self.hop_size:T*self.hop_size + self.hop_size]
321 |
322 | return y_return # [batch, T*hop_size]
323 |
324 | class Sine_Generator_Fast(torch.nn.Module):
325 | def __init__(self, hop_size, sampling_rate, device='cpu'):
326 | super().__init__()
327 | self.hop_size = hop_size
328 | self.win_size = hop_size * 2
329 | self.window = torch.hann_window(self.win_size).to(device)
330 | self.sampling_rate = sampling_rate
331 | self.device = device
332 |
333 | def forward(self, ampl, phase, f0_list, inference=False):
334 | '''
335 | ampl: B x max_nhar x T
336 | phase: B x max_nhar x T
337 | f0_list: B x max_nhar x T
338 | '''
339 | B, max_nhar, T = ampl.shape
340 |
341 | k = 16
342 | winsize = self.win_size
343 | x_start_list = torch.arange(-winsize/2, winsize/2, winsize/k, device=self.device)
344 | x_start_list_c = x_start_list.unsqueeze(0).unsqueeze(0).unsqueeze(0) # 1 x 1 x 1 x k
345 |
346 | x_start_list = (x_start_list + winsize/2).to(torch.int)
347 |
348 | omega = 2 * np.pi * f0_list / self.sampling_rate # [batch, max_nhar, T]
349 | omega = omega.unsqueeze(-1) # [batch, max_nhar, T, 1]
350 |
351 | ampl = ampl.unsqueeze(-1) # [batch, max_nhar, T, 1]
352 | phase = phase.unsqueeze(-1) # [batch, max_nhar, T, 1]
353 |
354 | c = 2 * torch.cos(omega) # [batch, max_nhar, T, 1]
355 | y_tmp = torch.zeros(B, max_nhar, T, winsize, device=self.device)
356 | y_tmp[:,:,:,x_start_list] = ampl * torch.cos(omega * x_start_list_c + phase) # [batch, max_nhar, T, k]
357 | y_tmp[:,:,:,x_start_list+1] = ampl * torch.cos(omega * (x_start_list_c + 1) + phase) # [batch, max_nhar, T, k]
358 |
359 |
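    | # Recurrence trick: with c = 2*cos(omega), the identity
    | #   cos((n+1)*omega + phi) = c*cos(n*omega + phi) - cos((n-1)*omega + phi)
    | # extends each of the k seeded segments sample by sample, so only the two seed
    | # samples per segment need an explicit cos() evaluation.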
360 | for i in range(2, int(winsize/k)):
361 | y_tmp[:,:,:,x_start_list+i] = c * y_tmp[:,:,:,x_start_list+i-1] - y_tmp[:,:,:,x_start_list+i-2]
362 |
363 | # y_tmp = torch.cos(freq_list) * ampl.unsqueeze(-1) # [batch, max_nhar, T, win_size]
364 | y_tmp = y_tmp.sum(dim=1) # [batch, T, win_size]
365 |
366 | hann_window = self.window.unsqueeze(0).unsqueeze(0) # [1, 1, win_size]
367 |
368 | y_tmp_weighted = y_tmp * hann_window # [batch, T, win_size]
369 |
370 | '''
371 | # reshape y_tmp_weighted into a layout suitable for fold
372 | y_tmp_weighted_reshaped = y_tmp_weighted.permute(0, 2, 1).contiguous()
373 | y_tmp_weighted_reshaped = y_tmp_weighted_reshaped.view(B * self.win_size, T)
374 |
375 | # use fold to realize the sliding-window (overlap-add) sum
376 | output = torch.nn.functional.fold(
377 | y_tmp_weighted_reshaped.unsqueeze(0),
378 | output_size=(T * self.hop_size + self.hop_size, 1),
379 | kernel_size=(self.win_size, 1),
380 | stride=(self.hop_size, 1)
381 | )
382 | output = output[:, :, self.hop_size:]
383 |
384 | y_return = output.squeeze(0).view(B, T * self.hop_size)'''
385 |
386 |
387 | y_tmp_padded = F.pad(y_tmp_weighted, (0, 0, 0, T % 2), "constant", 0)
388 | new_T = y_tmp_padded.shape[1]
389 |
390 | y_tmp_reshaped = y_tmp_padded.view(B, new_T//2, 2, self.win_size)
391 | tensor_even = y_tmp_reshaped[:, :, 0, :] # even time steps
392 | tensor_odd = y_tmp_reshaped[:, :, 1, :] # odd time steps
393 |
394 | tensor_even = tensor_even.reshape(B, (new_T//2)*self.win_size)
395 | tensor_odd = tensor_odd.reshape(B, (new_T//2)*self.win_size)
396 |
397 | tensor_even = F.pad(tensor_even, (0, self.hop_size), "constant", 0)
398 | tensor_odd = F.pad(tensor_odd, (self.hop_size, 0), "constant", 0)
399 | cat_tensor = torch.cat((tensor_even.unsqueeze(-1), tensor_odd.unsqueeze(-1)), dim=2)
400 | sum_tensor = torch.sum(cat_tensor, dim=2)
401 | y_return = sum_tensor[:, self.hop_size:T*self.hop_size + self.hop_size]
402 |
403 | return y_return # [batch, T*hop_size]
404 |
405 |
406 | class Noise_Generator(torch.nn.Module):
407 |
408 | def __init__(self, sampling_rate, hop_size, v_noise, u_noise, triangle_ReLU=True, triangle_ReLU_up=0.2, triangle_ReLU_down=0.8, device='cpu'):
409 | super().__init__()
410 | self.sampling_rate = sampling_rate
411 | self.hop_size = hop_size
412 | self.device = device
413 | self.noiseop = v_noise
414 | self.noiseran = u_noise
415 | self.triangle_ReLU = triangle_ReLU
416 | self.triangle_ReLU_up = triangle_ReLU_up
417 | self.triangle_ReLU_down = triangle_ReLU_down
418 |
419 | @staticmethod
420 | def Triangle_ReLU(x: torch.Tensor, x1, x2):
421 | '''
422 | Triangle ReLU activation function
423 | '''
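    | # Piecewise-linear triangle over the phase x in [0, 1]: rises 0 -> 1 on
    | # [0, x1], falls 1 -> 0 on [x1, x1 + x2], and stays at 0 afterwards.
    | # Rough usage sketch (values assumed for illustration):
    | #   x = torch.linspace(0, 1, 11).view(1, 1, -1)
    | #   Noise_Generator.Triangle_ReLU(x, 0.2, 0.8)  # peaks at 1.0 where x = 0.2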
424 | return -(torch.relu(-(x-x1))/x1) - (torch.relu(x-x1)/x2) + torch.relu(x-x1-x2)/x2 + 1
425 |
426 | def fast_phase_gen(self, f0_frames):
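    | # Per-sample phase accumulator: integrates f0 within each frame assuming a
    | # linear f0 ramp between frames (the 0.5 * ds0 term), then chains frames with
    | # a cumulative sum of each frame's end phase so the phase stays continuous
    | # modulo 1.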
427 | n = torch.arange(self.hop_size, device=f0_frames.device)
428 | s0 = f0_frames / self.sampling_rate
429 | ds0 = F.pad(s0[:, 1:, :] - s0[:, :-1, :], (0, 0, 0, 1))
430 | rad = s0 * (n + 1) + 0.5 * ds0 * n * (n + 1) / self.hop_size
431 | rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5
432 | rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0_frames)
433 | rad += F.pad(rad_acc[:, :-1, :], (0, 0, 1, 0))
434 | phase = rad.reshape(f0_frames.shape[0], -1, 1)%1
435 | return phase
436 |
437 | def forward(self, noise_mag, uv_mask, f0_frames):
438 | '''
439 | noise_mag: B x n_noise x T
440 | uv_mask: B x T
441 | f0_frames: B x 1 x T
442 | return: B x T*hop_size
443 | '''
444 | B, _, T = noise_mag.shape
445 | noise_mag_upsamp = upsample(noise_mag, self.hop_size) # b x n_noise x T*hop_size
446 | uv_mask_upsamp = upsample(uv_mask.unsqueeze(1), self.hop_size) # b x 1 x T*hop_size
447 | noiseran = replicate(self.noiseran, T*self.hop_size, batch=B) # b x n_noise x T*hop_size
448 |
449 | if self.triangle_ReLU:
450 | x = self.fast_phase_gen(f0_frames.transpose(1,2)).transpose(1,2) # b x 1 x T*hop_size
451 |
452 | triangle_mask = self.Triangle_ReLU(x, self.triangle_ReLU_up, self.triangle_ReLU_down)
453 | triangle_mask = triangle_mask * uv_mask_upsamp + (1 - uv_mask_upsamp) # keep full noise in unvoiced regions
454 | noise_ = noiseran * triangle_mask # b x n_noise x T*hop_size
455 | else:
456 | noiseop = replicate(self.noiseop, T*self.hop_size, batch=B)
457 | noise_ = (noiseop * uv_mask_upsamp + noiseran * (1 - uv_mask_upsamp)) # b x n_noise x T*hop_size
458 |
459 |
460 | noise = noise_ * noise_mag_upsamp # b x n_noise x T*hop_size
461 | noise = noise.sum(dim=1) # b x T*hop_size
462 | return noise
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 |
4 | import torch
5 |
6 | from ddsp.vocoder import load_model
7 |
8 |
9 | class DDSPWrapper(torch.nn.Module):
10 | def __init__(self, module, device):
11 | super().__init__()
12 | self.model = module
13 | self.to(device)
14 |
15 | def forward(self, mel, f0):
16 | mel = mel.transpose(1, 2)
17 | f0 = f0[..., None]
18 | signal, _, (harmonic, noise), (sin_mag, sin_phase) = self.model(mel, f0)
19 | print(f' [Output] signal: {signal.shape}, harmonic: {harmonic.shape}, noise: {noise.shape}, sin_mag: {sin_mag.shape}, sin_phase: {sin_phase.shape}')
20 | return signal
21 |
22 | def parse_args(args=None, namespace=None):
23 | parser = argparse.ArgumentParser(
24 | description='Export model to standalone PyTorch traced module or ONNX format'
25 | )
26 | parser.add_argument(
27 | '-m',
28 | '--model_path',
29 | type=str,
30 | required=True,
31 | help='path to model file'
32 | )
33 | parser.add_argument(
34 | '--traced',
35 | required=False,
36 | action='store_true',
37 | help='export to traced module format'
38 | )
39 | parser.add_argument(
40 | '--onnx',
41 | required=False,
42 | action='store_true',
43 | help='export to ONNX format'
44 | )
45 | cmd = parser.parse_args(args=args, namespace=namespace)
46 | if not cmd.traced and not cmd.onnx:
47 | parser.error('either --traced or --onnx should be specified.')
48 | return cmd
49 |
50 |
51 | def main():
52 | device = 'cpu'
53 | # parse commands
54 | cmd = parse_args()
55 |
56 | # load model
57 | model, args = load_model(cmd.model_path, device=device)
58 | model = DDSPWrapper(model, device)
59 |
60 | # extract model dirname and filename
61 | directory = os.path.dirname(os.path.abspath(cmd.model_path))
62 | name = os.path.basename(cmd.model_path).rsplit('.', maxsplit=1)[0]
63 |
64 | # load input
65 | n_mel_channels = args.data.n_mels
66 | n_frames = 10
67 | mel = torch.randn((1, n_frames, n_mel_channels), dtype=torch.float32, device=device)
68 | f0 = torch.FloatTensor([[440.] * n_frames]).to(device)
69 | print(f' [Input] mel: {mel.shape}, f0: {f0.shape}')
70 |
71 | # export model
72 | with torch.no_grad():
73 | if cmd.traced:
74 | torch_version = torch.version.__version__.rsplit('+', maxsplit=1)[0]
75 | export_path = os.path.join(directory, f'{name}-traced-torch{torch_version}.jit')
76 | print(f' [Tracing] {cmd.model_path} => {export_path}')
77 | model = torch.jit.trace(
78 | model,
79 | (
80 | mel,
81 | f0
82 | ),
83 | check_trace=False
84 | )
85 | torch.jit.save(model, export_path)
86 |
87 | elif cmd.onnx:
88 | # Prepare the export path for ONNX format
89 | onnx_version = "unknown"
90 | try:
91 | import onnx
92 | onnx_version = onnx.__version__
93 | except ImportError:
94 | print("Warning: ONNX package is not installed. Please install it to enable ONNX export.")
95 | return
96 |
97 | export_path = os.path.join(directory, f'{name}-torch{torch.version.__version__[:5]}-onnx{onnx_version}.onnx')
98 | print(f' [Exporting] {cmd.model_path} => {export_path}')
99 |
100 | # Export the model to ONNX
101 | torch.onnx.export(
102 | model,
103 | (mel, f0),
104 | export_path,
105 | export_params=True,
106 | opset_version=15,
107 | input_names=['mel', 'f0'],
108 | output_names=['output'],
109 | dynamic_axes={'mel': {1: 'n_frames'}, 'f0': {1: 'n_frames'},
110 | 'output': {1: 'n_samples'}}
111 | )
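    | # dynamic_axes marks axis 1 of mel/f0 (n_frames) and of the output (n_samples)
    | # as variable-length, so the exported graph accepts inputs with any number of
    | # frames.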
112 |
113 | print(f"Model has been successfully exported to {export_path}")
114 | def simplify_onnx_model():
115 |
116 | import onnx
117 | from onnxsim import simplify
118 |
119 | model = onnx.load(r"E:\pc-ddsp5.29\exp\qixuan8\model_19000-torch2.1.0-onnx1.16.2.onnx")
120 |
121 | model_simp, check = simplify(model)
122 |
123 | assert check, "Simplified ONNX model could not be validated"
124 |
125 | onnx.save(model_simp, 'output_model_simplified.onnx')
126 |
127 | if __name__ == '__main__':
128 | main()
129 | #simplify_onnx_model()
--------------------------------------------------------------------------------
/harmonic_noise_extract.py:
--------------------------------------------------------------------------------
1 | import os
2 | import soundfile as sf
3 | import numpy as np
4 | import click
5 | import torch
6 |
7 | from ddsp.audio_analysis import F0Analyzer, SinusoidalAnalyzer
8 | from ddsp.vocoder import Sine_Generator, Sine_Generator_Fast
9 |
10 | @click.command()
11 | @click.option('--input_file', type=click.Path(exists=True))
12 | def main(input_file):
13 | # Load input file
14 | audio, sr = sf.read(input_file)
15 | audio = torch.from_numpy(audio).float()
16 | hop_size = 512
17 | sin_mag = 64
18 |
19 | f0, uv = F0Analyzer(sampling_rate = sr,
20 | f0_extractor = 'parselmouth',
21 | hop_size = hop_size,
22 | f0_min = 30,
23 | f0_max = 800)(audio, len(audio)//hop_size)
24 | f0 = torch.from_numpy(f0).float()
25 | f0_frames = f0.unsqueeze(0).unsqueeze(1) # [1 x 1 x n_frames]
26 | uv = torch.from_numpy(~uv).float()
27 |
28 | nhar_range = torch.arange(
29 | start = 1, end = sin_mag + 1).unsqueeze(0).unsqueeze(-1) # [max_nhar] -> [1, max_nhar, 1]
30 | f0_list = f0_frames * nhar_range # [1 x max_nhar x n_frames]
31 |
32 | ampl, phase = SinusoidalAnalyzer(sampling_rate = sr,
33 | hop_size = hop_size,
34 | max_nhar = sin_mag,
35 | relative_winsize = 4
36 | )(
37 | audio,
38 | f0=f0,
39 | )
40 |
41 | ampl = ampl * uv.unsqueeze(0) # zero out unvoiced frames
42 | phase = phase * uv.unsqueeze(0)
43 |
44 | harmonic_audio = Sine_Generator_Fast(hop_size=hop_size, sampling_rate=sr)(ampl.unsqueeze(0), phase.unsqueeze(0), f0_list)
45 | harmonic_audio = harmonic_audio.squeeze(0)
46 |
47 | harmonic_audio = torch.nn.functional.pad( # pad to the length of the input audio
48 | harmonic_audio, (0, audio.shape[0] - harmonic_audio.shape[0]), 'constant', 0)
49 |
50 | noise_audio = audio - harmonic_audio
51 |
52 | # Save output file
53 | harmonic_audio_file = os.path.splitext(input_file)[0] + '_harmonic.wav'
54 | sf.write(harmonic_audio_file, harmonic_audio.numpy(), sr)
55 |
56 | noise_audio_file = os.path.splitext(input_file)[0] + '_noise.wav'
57 | sf.write(noise_audio_file, noise_audio.numpy(), sr)
58 |
59 |
60 |
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
--------------------------------------------------------------------------------
/logger/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjzxkxdn/Mini-DDSP/305f069c07f0214bf65a28f519e2149e864885a2/logger/__init__.py
--------------------------------------------------------------------------------
/logger/saver.py:
--------------------------------------------------------------------------------
1 | '''
2 | author: wayn391@mastertones
3 | '''
4 |
5 | import os
6 | import time
7 | import yaml
8 | import datetime
9 | import torch
10 |
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | class Saver(object):
14 | def __init__(
15 | self,
16 | args,
17 | initial_global_step=-1):
18 |
19 | self.expdir = args.env.expdir
20 | self.sample_rate = args.data.sampling_rate
21 |
22 | # cold start
23 | self.global_step = initial_global_step
24 | self.init_time = time.time()
25 | self.last_time = time.time()
26 |
27 | # makedirs
28 | os.makedirs(self.expdir, exist_ok=True)
29 |
30 | # path
31 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt')
32 |
33 | # ckpt
34 | os.makedirs(self.expdir, exist_ok=True)
35 |
36 | # writer
37 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs'))
38 |
39 | # save config
40 | path_config = os.path.join(self.expdir, 'config.yaml')
41 | with open(path_config, "w") as out_config:
42 | yaml.dump(dict(args), out_config)
43 |
44 |
45 | def log_info(self, msg):
46 | '''log method'''
47 | if isinstance(msg, dict):
48 | msg_list = []
49 | for k, v in msg.items():
50 | tmp_str = ''
51 | if isinstance(v, int):
52 | tmp_str = '{}: {:,}'.format(k, v)
53 | else:
54 | tmp_str = '{}: {}'.format(k, v)
55 |
56 | msg_list.append(tmp_str)
57 | msg_str = '\n'.join(msg_list)
58 | else:
59 | msg_str = msg
60 |
61 | # display
62 | print(msg_str)
63 |
64 | # save
65 | with open(self.path_log_info, 'a') as fp:
66 | fp.write(msg_str+'\n')
67 |
68 | def log_value(self, dict):
69 | for k, v in dict.items():
70 | self.writer.add_scalar(k, v, self.global_step)
71 |
72 | def log_audio(self, dict):
73 | for k, v in dict.items():
74 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)
75 |
76 | def get_interval_time(self, update=True):
77 | cur_time = time.time()
78 | time_interval = cur_time - self.last_time
79 | if update:
80 | self.last_time = cur_time
81 | return time_interval
82 |
83 | def get_total_time(self, to_str=True):
84 | total_time = time.time() - self.init_time
85 | if to_str:
86 | total_time = str(datetime.timedelta(
87 | seconds=total_time))[:-5]
88 | return total_time
89 |
90 | def save_model(
91 | self,
92 | model,
93 | optimizer,
94 | name='model',
95 | postfix='',
96 | to_json=False):
97 | # path
98 | if postfix:
99 | postfix = '_' + postfix
100 | path_pt = os.path.join(
101 | self.expdir , name+postfix+'.pt')
102 |
103 | # check
104 | print(' [*] model checkpoint saved: {}'.format(path_pt))
105 |
106 | # save
107 | torch.save({
108 | 'global_step': self.global_step,
109 | 'model': model.state_dict(),
110 | 'optimizer': optimizer.state_dict()}, path_pt)
111 |
112 | # to json
113 | '''
114 | if to_json:
115 | path_json = os.path.join(
116 | self.expdir , name+'.json')
117 | utils.to_json(path_params, path_json)'''
118 |
119 | def global_step_increment(self):
120 | self.global_step += 1
121 |
122 |
123 |
--------------------------------------------------------------------------------
/logger/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import json
4 | import torch
5 | from pathlib import Path
6 | from typing import Tuple, Any, Dict, TypeVar, Union
7 |
8 | def traverse_dir(
9 | root_dir: Path,
10 | extension: str,
11 | amount: int = None,
12 | str_include: str = None,
13 | str_exclude: str = None,
14 | is_pure: bool = False,
15 | is_sort: bool = False,
16 | is_ext: bool = True
17 | ) -> list:
18 | """
19 | Traverse a directory and return a list of files with the given extension.
20 |
21 | Args:
22 | root_dir: root directory
23 | extension: file extension
24 | amount: maximum number of files to return (None returns all files)
25 | str_include: only keep paths that contain this substring
26 | str_exclude: skip paths that contain this substring
27 | is_pure: return paths relative to root_dir
28 | is_sort: sort the returned list by file name
29 | is_ext: keep the file extension in the returned paths
30 |
31 | Returns:
32 | list of file paths
33 | """
34 | root_dir = Path(root_dir).resolve() # Ensure the path is absolute
35 | file_list = []
36 |
37 | for path in root_dir.rglob(f"*{'.' + extension if extension is not None else ''}"):
38 | path_str = str(path)
39 | relative_path = path_str[len(str(root_dir)) + 1:] if is_pure else path_str
40 |
41 | if str_include and str_include not in relative_path:
42 | continue
43 | if str_exclude and str_exclude in relative_path:
44 | continue
45 |
46 | if not is_ext:
47 | relative_path = '.'.join(relative_path.split('.')[:-1])
48 |
49 | file_list.append(relative_path)
50 |
51 | if amount is not None and len(file_list) >= amount:
52 | break
53 |
54 | if is_sort:
55 | file_list.sort()
56 |
57 | return file_list
58 |
59 | T = TypeVar('T', bound='DotDict')
60 |
61 | class DotDict(Dict[str, Any]):
62 | """
63 | DotDict lets dictionary entries be accessed with dot notation, like object attributes.
64 | Essentially the same as the DotDict in pc-ddsp, but nested dictionaries work as well.
65 |
66 | Example:
67 | config = DotDict({'key': {'nested_key': 'value'}})
68 | print(config.key.nested_key)
69 | Output: 'value'
70 | """
71 | def __init__(self, *args: Union[Dict[str, Any], 'DotDict'], **kwargs: Any) -> None:
72 | super().__init__()
73 | for arg in args:
74 | if isinstance(arg, dict):
75 | for key, value in arg.items():
76 | self[key] = self._convert(value)
77 | elif isinstance(arg, DotDict):
78 | self.update(arg)
79 | for key, value in kwargs.items():
80 | self[key] = self._convert(value)
81 |
82 | def _convert(self, value: Any) -> Any:
83 | if isinstance(value, dict) and not isinstance(value, DotDict):
84 | return DotDict(value)
85 | return value
86 |
87 | def __getattr__(self: T, name: str) -> Any:
88 | try:
89 | value = self[name]
90 | if isinstance(value, dict) and not isinstance(value, DotDict):
91 | self[name] = value = DotDict(value)
92 | return value
93 | except KeyError:
94 | raise AttributeError(f"'DotDict' object has no attribute '{name}'")
95 |
96 | def __setattr__(self, name: str, value: Any) -> None:
97 | self[name] = self._convert(value)
98 |
99 | def __delattr__(self, name: str) -> None:
100 | try:
101 | del self[name]
102 | except KeyError:
103 | raise AttributeError(f"'DotDict' object has no attribute '{name}'")
104 |
105 |
106 | def get_network_paras_amount(model_dict: Dict[str, torch.nn.Module]) -> Dict[str, Tuple[int, int]]:
107 | """
108 | Count model parameters. Essentially the same as the pc-ddsp function, but also reports total and trainable parameter counts.
109 |
110 | Args:
111 | model_dict: dict mapping model names to models.
112 |
113 | Returns:
114 | dict mapping each model name to (total_params, trainable_params).
115 | """
116 | info = {}
117 | for model_name, model in model_dict.items():
118 | total_params = sum(p.numel() for p in model.parameters())
119 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
120 |
121 | info[model_name] = (total_params, trainable_params)
122 |
123 | return info
124 |
125 |
126 | def load_config(path_config: Path) -> 'DotDict':
127 | """
128 | Adapted from the load_config function in pc-ddsp, with added error handling.
129 | Loads and parses the YAML config file at the given path and returns a DotDict,
130 | so config entries can be accessed with dot notation.
131 |
132 | Args:
133 | path_config: path to the config file
134 |
135 | Return:
136 | DotDict containing the config data
137 | """
138 | try:
139 | with open(path_config, "r") as config_file:
140 | args = yaml.safe_load(config_file)
141 | return DotDict(args)
142 | except FileNotFoundError:
143 | raise ValueError(f"配置文件未找到: {path_config}")
144 | except yaml.YAMLError:
145 | raise ValueError(f"YAML 文件格式错误: {path_config}")
146 |
147 | def validate_config(config: DotDict) -> None:
148 | """
149 | Validate the config file, raising an exception if any problem is found.
150 |
151 | Args:
152 | config: DotDict containing the config data
153 | """
154 | # data section
155 | assert config.data.f0_extractor in ['parselmouth', 'dio', 'harvest'], "f0_extractor must be 'parselmouth', 'dio' or 'harvest'"
156 | assert isinstance(config.data.f0_min, int), "f0_min must be an integer"
157 | assert isinstance(config.data.f0_max, int), "f0_max must be an integer"
158 | assert config.data.f0_min < config.data.f0_max, "f0_min must be less than f0_max"
159 | assert config.data.sampling_rate > 0, "sampling_rate must be greater than 0"
160 | assert config.data.n_fft > 0, "n_fft must be greater than 0"
161 | assert config.data.win_length > 0, "win_length must be greater than 0"
162 | assert config.data.block_size > 0, "block_size must be greater than 0"
163 | assert config.data.block_size == config.data.win_length // 2, "block_size should equal half of win_length"
164 | assert config.data.n_mels > 0, "n_mels must be greater than 0"
165 | assert config.data.mel_fmin >= 0, "mel_fmin must be greater than or equal to 0"
166 | assert config.data.mel_fmax > config.data.mel_fmin, "mel_fmax must be greater than mel_fmin"
167 | assert config.data.duration > 0, "duration must be greater than 0"
168 | assert isinstance(config.data.train_path, str), "train_path must be a string"
169 | assert isinstance(config.data.valid_path, str), "valid_path must be a string"
170 |
171 | # model section
172 | assert config.model.type in ['CombSub'], "model.type must be 'CombSub'"
173 | assert config.model.win_length > 0, "model.win_length must be greater than 0"
174 | assert isinstance(config.model.use_mean_filter, bool), "use_mean_filter must be a boolean"
175 | assert config.model.n_mag_harmonic > 0, "n_mag_harmonic must be greater than 0"
176 | assert config.model.n_mag_noise > 0, "n_mag_noise must be greater than 0"
177 |
178 | # loss section
179 | assert config.loss.fft_min > 0, "fft_min must be greater than 0"
180 | assert config.loss.fft_max > config.loss.fft_min, "fft_max must be greater than fft_min"
181 | assert config.loss.n_scale > 0, "n_scale must be greater than 0"
182 | assert config.loss.lambda_uv > 0, "lambda_uv must be greater than 0"
183 | assert config.loss.uv_tolerance >= 0, "uv_tolerance must be greater than or equal to 0"
184 | assert config.loss.detach_uv_step > 0, "detach_uv_step must be greater than 0"
185 |
186 | # device section
187 | assert config.device in ['cuda', 'cpu'], "device must be 'cuda' or 'cpu'"
188 |
189 | # environment section
190 | assert isinstance(config.env.expdir, str), "expdir must be a string"
191 | assert isinstance(config.env.gpu_id, int), "gpu_id must be an integer"
192 |
193 | # training section
194 | assert config.train.num_workers >= 0, "num_workers must be greater than or equal to 0"
195 | assert config.train.batch_size > 0, "batch_size must be greater than 0"
196 | assert config.train.epochs > 0, "epochs must be greater than 0"
197 | assert config.train.interval_log > 0, "interval_log must be greater than 0"
198 | assert config.train.interval_val > 0, "interval_val must be greater than 0"
199 | assert config.train.lr > 0, "lr must be greater than 0"
200 | assert config.train.weight_decay >= 0, "weight_decay must be greater than or equal to 0"
201 |
202 |
203 |
204 | def to_json(path_params: str, path_json: str) -> None:
205 | # Adapted from pc-ddsp; behaviour unchanged.
206 | params = torch.load(path_params, map_location=torch.device('cpu'))
207 | raw_state_dict: Dict[str, Any] = {k: v.tolist()
208 | for k, v in params.items()}
209 |
210 | # Serialize the dict to JSON first, to keep the file locked for as short a time as possible.
211 | with open(path_json, 'w') as outfile:
212 | json_data = json.dumps(raw_state_dict, indent='\t')
213 | outfile.write(json_data)
214 |
215 | def load_model(
216 | expdir: Path,
217 | model: torch.nn.Module,
218 | optimizer: torch.optim.Optimizer,
219 | name: str = 'model',
220 | postfix: str = '',
221 | device: str = 'cpu'
222 | ) -> Tuple[int, torch.nn.Module, torch.optim.Optimizer]:
223 | """
224 | Load the model and optimizer from the latest checkpoint.
225 | Args:
226 | expdir: directory where the model is saved
227 | model: the model
228 | optimizer: the optimizer
229 | name: model name
230 | postfix: checkpoint name postfix
231 | Returns:
232 | global_step: global step
233 | model: the model
234 | optimizer: the optimizer
235 | """
236 | if postfix:
237 | postfix = '_' + postfix
238 |
239 | path = expdir / (name + postfix)
240 | path_pt = list(expdir.glob('*.pt'))
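    | # Checkpoints are expected to be named '{name}{postfix}_{step}.pt' or
    | # '{name}{postfix}_best.pt'; a 'best' checkpoint is preferred, otherwise the
    | # one with the highest step number is restored.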
241 |
242 | global_step = 0
243 | if path_pt:
244 | steps = []
245 | for p in path_pt:
246 | if p.name.startswith(name + postfix):
247 | step = p.name[len(name + postfix)+1:].split('.')[0]
248 | if step == "best":
249 | steps = [-1]
250 | break
251 | else:
252 | steps.append(int(step))
253 |
254 | maxstep = max(steps or [0])
255 |
256 | if maxstep > 0:
257 | path_pt = path.with_name(f'{path.name}_{maxstep}.pt')
258 | else:
259 | path_pt = path.with_name(f'{path.name}_best.pt')
260 |
261 | print(' [*] Restoring model from', path_pt)
262 | ckpt = torch.load(path_pt, map_location=torch.device(device))
263 | global_step = ckpt.get('global_step', 0)
264 | model.load_state_dict(ckpt['model'])
265 | optimizer.load_state_dict(ckpt['optimizer'])
266 |
267 | return global_step, model, optimizer
268 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | from tqdm import tqdm
6 | from logger.utils import DotDict
7 | import multiprocessing
8 | import soundfile as sf
9 | import click
10 | from pathlib import Path
11 | from preprocess import Preprocessor
12 | from ddsp.vocoder import load_model
13 |
14 | def infer(
15 | model : torch.nn.Module,
16 | input : Path,
17 | output: Path,
18 | args : DotDict,
19 | key : float,
20 | device: str,
21 | sample_rate: int
22 | ):
23 | '''
24 | Args:
25 | input : audio file path
26 | output: output audio file path
27 | key : the key change in semitones
28 | '''
29 | # Process single file
30 | print(f"Processing file: {input}")
31 | audio, sr = sf.read(str(input))
32 |
33 | assert sr == sample_rate, f"\
34 | Sample rate of input file {sr} does not match \
35 | model sample rate {sample_rate}"
36 |
37 | # preprocess
38 | preprocessor = Preprocessor(args, device)
39 | mel, f0, uv = preprocessor.mel_f0_uv_process(torch.from_numpy(audio).float())
40 |
41 | print(f"Input shape: {mel.shape}, F0 shape: {f0.shape}, UV shape: {uv.shape}")
42 | print("f0dtype: ", f0.dtype, "uvdtype: ", uv.dtype)
43 | mel = mel.astype(np.float32)
44 | f0 = f0.astype(np.float32)
45 | uv = uv.astype(np.float32)
46 | print("f0dtype: ", f0.dtype, "uvdtype: ", uv.dtype)
47 | # np.save(output.with_suffix('.npy'), mel)
48 |
49 |
50 | # key change
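    | # a shift of k semitones scales f0 by 2**(k/12) (12 semitones = one octave)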
51 | key_change = float(key)
52 | if key_change != 0:
53 | output_f0 = f0 * 2 ** (key_change / 12)
54 | else:
55 | output_f0 = None
56 |
57 | # forward and save the output
58 | with torch.no_grad():
59 | if output_f0 is None:
60 | signal, _, (s_h, s_n), (sin_mag, sin_phase) = model(torch.tensor(mel).float().unsqueeze(0).to(device), torch.tensor(f0).unsqueeze(0).unsqueeze(-1).to(device))
61 | else:
62 | signal, _, (s_h, s_n), (sin_mag, sin_phase) = model(torch.tensor(mel).float().unsqueeze(0).to(device), torch.tensor(output_f0).unsqueeze(0).unsqueeze(-1).to(device))
63 | signal = signal.squeeze().cpu().numpy()
64 | s_h = s_h.squeeze().cpu().numpy()
65 | s_n = s_n.squeeze().cpu().numpy()
66 | sf.write(str(output), signal, args.data.sampling_rate, subtype='FLOAT')
67 | sf.write(str(output.with_suffix('.harmonic.wav')), s_h, args.data.sampling_rate, subtype='FLOAT')
68 | sf.write(str(output.with_suffix('.noise.wav')), s_n, args.data.sampling_rate, subtype='FLOAT')
69 |
70 | @click.command()
71 | @click.option(
72 | '--model_path', type=click.Path(
73 | exists=True, file_okay=True, dir_okay=False, readable=True,
74 | path_type=Path, resolve_path=True
75 | ),
76 | required=True, metavar='CONFIG_FILE',
77 | help='The path to the model.'
78 | )
79 | @click.option(
80 | '--input', type=click.Path(
81 | exists=True, file_okay=True, dir_okay=True, readable=True,
82 | path_type=Path, resolve_path=True
83 | ),
84 | required=True,
85 | help='The path to the WAV file or directory containing WAV files.'
86 | )
87 | @click.option(
88 | '--output', type=click.Path(
89 | exists=True, file_okay=True, dir_okay=True, readable=True,
90 | path_type=Path, resolve_path=True
91 | ),
92 | required=True,
93 | help='The path to the output directory.'
94 | )
95 | @click.option(
96 | '--key', type=int, default=0,
97 | help='key changed (number of semitones)'
98 | )
99 |
100 | def main(model_path, input, output, key):
101 |
102 | # cpu inference is fast enough!
103 | device = 'cpu'
104 | #device = 'cuda' if torch.cuda.is_available() else 'cpu'
105 |
106 | model, args = load_model(model_path, device=device)
107 | print(f"Model loaded: {model_path}")
108 |
109 | if input.is_file():
110 | infer(model, input, output / input.name, args, key, device, args.data.sampling_rate)
111 | elif input.is_dir():
112 | assert output.is_dir(),\
113 | "If input is a directory, output must be a directory as well."
114 | for file in tqdm(input.glob('*.wav')):
115 | infer(
116 | model,
117 | file,
118 | output / file.name,
119 | args,
120 | key,
121 | device,
122 | args.data.sampling_rate
123 | )
124 | if __name__ == '__main__':
125 | main()
--------------------------------------------------------------------------------
/onnx_infer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import onnxruntime
4 | import yaml
5 |
6 |
7 | from tqdm import tqdm
8 | from logger.utils import DotDict
9 | import soundfile as sf
10 | import click
11 | from pathlib import Path
12 | from preprocess import Preprocessor
13 | from ddsp.vocoder import load_model
14 |
15 | def infer_onnx(
16 | model : torch.nn.Module,
17 | input : Path,
18 | output: Path,
19 | args : DotDict,
20 | key : float,
21 | device: str,
22 | sample_rate: int
23 | ):
24 | '''
25 | Args:
26 | input : audio file path
27 | output: output audio file path
28 | key : the key change in semitones
29 | '''
30 | # Process single file
31 | print(f"Processing file: {input}")
32 | audio, sr = sf.read(str(input))
33 |
34 | assert sr == sample_rate, f"\
35 | Sample rate of input file {sr} does not match \
36 | model sample rate {sample_rate}"
37 |
38 | # preprocess
39 | preprocessor = Preprocessor(args, device)
40 | mel, f0, uv = preprocessor.mel_f0_uv_process(torch.from_numpy(audio).float())
41 |
42 | print(f"Input shape: {mel.shape}, F0 shape: {f0.shape}, UV shape: {uv.shape}")
43 | # np.save(output.with_suffix('.npy'), mel)
44 |
45 |
46 | # forward and save the output
47 | '''with torch.no_grad():
48 | if output_f0 is None:
49 | signal, _, (s_h, s_n), (sin_mag, sin_phase) = model(torch.tensor(mel).float().unsqueeze(0).to(device), torch.tensor(f0).unsqueeze(0).unsqueeze(-1).to(device))
50 | else:
51 | signal, _, (s_h, s_n) = model(torch.tensor(mel).float().unsqueeze(0).to(device), torch.tensor(f0).unsqueeze(0).unsqueeze(-1).to(device))
52 | signal = signal.squeeze().cpu().numpy()
53 | s_h = s_h.squeeze().cpu().numpy()
54 | s_n = s_n.squeeze().cpu().numpy()
55 | sf.write(str(output), signal, args.data.sampling_rate,subtype='FLOAT')
56 | sf.write(str(output.with_suffix('.harmonic.wav')), s_h, args.data.sampling_rate,subtype='FLOAT')
57 | sf.write(str(output.with_suffix('.noise.wav')), s_n, args.data.sampling_rate,subtype='FLOAT') '''
58 | # ONNX inference
59 | input_name1 = model.get_inputs()[0].name
60 | input_name2 = model.get_inputs()[1].name
61 | output_name = model.get_outputs()[0].name
62 | input_shape1 = model.get_inputs()[0].shape
63 | input_shape2 = model.get_inputs()[1].shape
64 | output_shape = model.get_outputs()[0].shape
65 | print(f"Input name1: {input_name1}, Input name2: {input_name2}, Output name: {output_name}")
66 | print(f"Input shape1: {input_shape1}, Input shape2: {input_shape2}, Output shape: {output_shape}")
67 | # expand mel from [128, n] to [1, 128, n]
68 | mel = np.expand_dims(mel, axis=0)
69 | # transpose mel from [1, 128, n] to [1, n, 128]
70 | mel = np.transpose(mel, (0, 2, 1))
71 |
72 | f0 = np.expand_dims(f0, axis=0)
73 | ort_inputs = {input_name1: np.array(mel, dtype=np.float32), input_name2: np.array(f0, dtype=np.float32)}
74 |
75 | ort_outs = model.run(None, ort_inputs)
76 | signal = ort_outs[0]
77 | # squeeze signal from [1, n] to [n]
78 | signal = signal.squeeze()
79 | sf.write(str(output), signal, args.data.sampling_rate, subtype='FLOAT')
80 |
81 |
82 | @click.command()
83 | @click.option(
84 | '--model_path', type=click.Path(
85 | exists=True, file_okay=True, dir_okay=False, readable=True,
86 | path_type=Path, resolve_path=True
87 | ),
88 | required=True, metavar='CONFIG_FILE',
89 | help='The path to the model.'
90 | )
91 | @click.option(
92 | '--input', type=click.Path(
93 | exists=True, file_okay=True, dir_okay=True, readable=True,
94 | path_type=Path, resolve_path=True
95 | ),
96 | required=True,
97 | help='The path to the WAV file or directory containing WAV files.'
98 | )
99 | @click.option(
100 | '--output', type=click.Path(
101 | exists=True, file_okay=True, dir_okay=True, readable=True,
102 | path_type=Path, resolve_path=True
103 | ),
104 | required=True,
105 | help='The path to the output directory.'
106 | )
107 | @click.option(
108 | '--key', type=int, default=0,
109 | help='key changed (number of semitones)'
110 | )
111 |
112 | def main(model_path, input, output, key):
113 |
114 | # cpu inference is fast enough!
115 | device = 'cpu'
116 | #device = 'cuda' if torch.cuda.is_available() else 'cpu'
117 |
118 | #model, args = load_model(model_path, device=device)
119 | model = onnxruntime.InferenceSession(str(model_path))
120 | args = DotDict(yaml.load(open("E:/pc-ddsp5.29/configs/SinStack.yaml", 'r'), Loader=yaml.FullLoader)) # NOTE: hardcoded config path
121 | print(f"Model loaded: {model_path}")
122 |
123 | if input.is_file():
124 | infer_onnx(model, input, output / input.name, args, key, device, args.data.sampling_rate)
125 | elif input.is_dir():
126 | assert output.is_dir(),\
127 | "If input is a directory, output must be a directory as well."
128 | for file in tqdm(input.glob('*.wav')):
129 | infer_onnx(
130 | model,
131 | file,
132 | output / file.name,
133 | args,
134 | key,
135 | device,
136 | args.data.sampling_rate
137 | )
138 | if __name__ == '__main__':
139 | main()
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 |
5 |
6 | from tqdm import tqdm
7 | from logger.utils import DotDict, traverse_dir, load_config
8 | import soundfile as sf
9 | import click
10 | from pathlib import Path
11 | from ddsp.audio_analysis import MelAnalysis, F0Analyzer, SinusoidalAnalyzer
12 | from ddsp.utils import interp_f0, expand_uv
13 |
14 | class Preprocessor:
15 | def __init__(self, config: DotDict, device: str):
16 | self.config = config
17 |
18 | self.sampling_rate = config.data.sampling_rate
19 | self.train_path = Path(config.data.train_path)
20 | self.valid_path = Path(config.data.valid_path)
21 |
22 | self.device = device
23 | print(f'Preprocessor using device: {self.device}')
24 |
25 | self.mel_extractor = MelAnalysis(
26 | sampling_rate = self.sampling_rate,
27 | hop_size = config.data.hop_size,
28 | win_size = config.data.win_size,
29 | n_fft = config.data.n_fft,
30 | n_mels = config.data.n_mels,
31 | mel_fmin = config.data.mel_fmin,
32 | mel_fmax = config.data.mel_fmax,
33 | clamp = config.data.mel_clamp,
34 | device = self.device
35 | )
36 |
37 | self.f0_extractor = F0Analyzer(
38 | sampling_rate = self.sampling_rate,
39 | f0_extractor = config.data.f0_extractor,
40 | hop_size = config.data.hop_size,
41 | f0_min = config.data.f0_min,
42 | f0_max = config.data.f0_max,
43 | )
44 |
45 | self.sinusoidal_analyzer = SinusoidalAnalyzer(
46 | sampling_rate = self.sampling_rate,
47 | hop_size = config.data.hop_size,
48 | max_nhar = config.data.max_nhar,
49 | relative_winsize = config.data.relative_winsize,
50 | device = self.device
51 | )
52 |
53 | def __call__(self):
54 | return self.preprocess()
55 |
56 | def preprocess(self):
57 | for base_path in [self.train_path, self.valid_path]:
58 | # list files
59 | filelist = traverse_dir(
60 | base_path / "audio",
61 | extension="wav",
62 | is_pure=True,
63 | is_sort=True,
64 | is_ext=False,
65 | )
66 |
67 | for file in tqdm(filelist):
68 | path_harmonic_audio = base_path / "harmonic_audio" / f'{file}.wav'
69 | path_audio = base_path / "audio" / f'{file}.wav'
70 | path_phase = base_path / "phase" / f'{file}.npy'
71 | path_ampl = base_path / "ampl" / f'{file}.npy'
72 | path_mel = base_path / "mel" / f'{file}.npy'
73 | path_f0 = base_path / "f0" / f'{file}.npy'
74 | path_uv = base_path / "uv" / f'{file}.npy'
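    | # features are written alongside the audio as
    | # <train|val>/{mel,f0,uv,ampl,phase}/<name>.npy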
75 |
76 | # load audio
77 | audio, sr = sf.read(str(path_audio))
78 | assert sr == self.sampling_rate, f'Sampling rate of {path_audio} is not {self.sampling_rate}'
79 | audio = torch.from_numpy(audio).float().to(self.device)
80 |
81 | haudio, sr = sf.read(str(path_harmonic_audio))
82 | assert sr == self.sampling_rate, f'Sampling rate of {path_harmonic_audio} is not {self.sampling_rate}'
83 | haudio = torch.from_numpy(haudio).float().to(self.device)
84 |
85 | try:
86 | assert audio.shape[0] == haudio.shape[0]
87 | # extract mel, f0, uv
88 | mel, f0, uv = self.mel_f0_uv_process(audio)
89 |
90 | # extract amplitude and phase
91 | tf0 = torch.from_numpy(f0).float().to(self.device)
92 | ampl, phase = self.ampl_phase_process(haudio, tf0)
93 | except Exception:
94 | Path(path_audio).unlink(missing_ok=True)
95 | tqdm.write(f'Feature extraction for {path_audio} failed. File deleted.')
96 | continue
97 |
98 | # create the output directories
99 | path_mel.parent.mkdir(parents=True, exist_ok=True)
100 | path_f0.parent.mkdir(parents=True, exist_ok=True)
101 | path_uv.parent.mkdir(parents=True, exist_ok=True)
102 | path_phase.parent.mkdir(parents=True, exist_ok=True)
103 | path_ampl.parent.mkdir(parents=True, exist_ok=True)
104 |
105 | # save npy files
106 | np.save(path_mel, mel)
107 | np.save(path_f0 , f0 )
108 | np.save(path_uv , uv )
109 | np.save(path_phase, phase)
110 | np.save(path_ampl, ampl)
111 |
112 | def mel_f0_uv_process(self, audio: torch.Tensor):
113 | # extract mel
114 | mel = self.mel_extractor(audio, diffsinger=True)
115 | mel = mel.to('cpu').numpy()
116 |
117 | # extract f0 and uv
118 | f0, uv = self.f0_extractor(audio, n_frames=mel.shape[1])
119 |
120 | f0, _ = interp_f0(f0, uv)
121 | uv = expand_uv(uv)
122 |
123 | return mel, f0, uv
124 |
125 | def ampl_phase_process(self, audio: torch.Tensor, f0 : torch.Tensor):
126 | # extract amplitude and phase
127 | ampl, phase = self.sinusoidal_analyzer(audio, f0, model = 'czt')
128 | ampl = ampl.to('cpu').numpy()
129 | phase = phase.to('cpu').numpy()
130 |
131 | return ampl, phase
132 |
133 |
134 |
135 | @click.command(help='Preprocess audio files')
136 | @click.option(
137 | '--config', type=click.Path(
138 | exists=True, file_okay=True, dir_okay=False, readable=True,
139 | path_type=Path, resolve_path=True
140 | ),
141 | required=True, metavar='CONFIG_FILE',
142 | help='The path to the config file.'
143 | )
144 | @click.option(
145 | '--device', type=str, default=None,
146 | help='The device to use for preprocessing.'
147 | )
148 | def main(config, device):
149 |
150 | if device is None:
151 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
152 | print(f'Preprocessor using device: {device}')
153 |
154 | # load config
155 | args = load_config(config)
156 |
157 | # TODO: add config validation
158 | # validate_config(args)
159 |
160 | preprocessor = Preprocessor(args, device)
161 | preprocessor()
162 |
163 | if __name__ == '__main__':
164 | main()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gin
2 | gin_config
3 | librosa
4 | numpy
5 | praat-parselmouth
6 | pyworld
7 | PyYAML
8 | SoundFile
9 | tqdm
10 | tensorboard
11 | onnx
12 | onnxruntime-gpu
13 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | from pathlib import Path
5 | from typing import Any, Dict
6 | import click
7 |
8 |
9 | from logger import utils
10 | from data_loaders import get_data_loaders
11 | from ddsp.vocoder import SinStack
12 | from ddsp.loss import HybridLoss
13 | from logger.utils import DotDict, load_model
14 | from logger.saver import Saver
15 |
16 | class ModelTrainer:
17 | def __init__(self, config: DotDict, device: str):
18 | self.args = config
19 | self.device = device
20 |
21 | self.load_model = load_model
22 |
23 | self.model = SinStack(
24 | args=config,
25 | device=device
26 | ).to(device)
27 | self.optimizer = torch.optim.AdamW(self.model.parameters())
28 |
29 | self.initial_global_step, self.model, self.optimizer \
30 | = self.load_model(
31 | Path(config.env.expdir),
32 | self.model,
33 | self.optimizer
34 | )
35 |
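    | # re-apply lr/weight_decay from the config, overriding whatever was stored in
    | # a restored optimizer state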
36 | for param_group in self.optimizer.param_groups:
37 | param_group['lr'] = config.train.lr
38 | param_group['weight_decay'] = config.train.weight_decay
39 |
40 | self.loss_func = HybridLoss(
41 | config.data.hop_size,
42 | config.loss.fft_min,
43 | config.loss.fft_max,
44 | config.loss.n_scale,
45 | config.loss.lambda_uv,
46 | config.loss.lambda_ampl,
47 | config.loss.lambda_phase,
48 | device
49 | ).to(self.device)
50 |
51 | # device
52 | if device == 'cuda':
53 | torch.cuda.set_device(config.env.gpu_id)
54 |
55 | for state in self.optimizer.state.values():
56 | for k, v in state.items():
57 | if torch.is_tensor(v):
58 | state[k] = v.to(device)
59 |
60 | self.loader_train, self.loader_valid \
61 | = get_data_loaders(self.args, whole_audio=False)
62 |
63 | def train(self):
64 | saver = Saver(
65 | self.args,
66 | initial_global_step=self.initial_global_step
67 | )
68 | print(f' [*] experiment dir: {self.args.env.expdir}')
69 |
70 | params_count = utils.get_network_paras_amount({'model': self.model})
71 | saver.log_info('--- model size ---')
72 | saver.log_info(params_count)
73 |
74 | best_loss = np.inf
75 | num_batches = len(self.loader_train)
76 | self.model.train()
77 | saver.log_info('======= start training =======')
78 |
79 | for epoch in range(self.args.train.epochs):
80 | for batch_idx, data in enumerate(self.loader_train):
81 | saver.global_step_increment()
82 |
83 | self.train_process_batch(
84 | saver,
85 | data,
86 | batch_idx,
87 | num_batches,
88 | epoch,
89 | best_loss
90 | )
91 |
92 | def train_process_batch(
93 | self,
94 | saver: Any,
95 | data: Dict[str, Any],
96 | batch_idx: int,
97 | num_batches: int,
98 | epoch: int,
99 | best_loss: float
100 | ):
101 | self.optimizer.zero_grad()
102 |
103 | # unpack data
104 | self.move_data_to_device(data)
105 |
106 | # forward
107 | signal, _, (s_h, s_n), (pre_ampl, pre_phase) = self.model(
108 | data['mel'], data['f0'], inference=False
109 | )
110 | # loss
111 | detach_uv = saver.global_step < self.args.loss.detach_uv_step
112 | loss, (loss_rss, loss_uv, loss_ampl, loss_phase) = self.loss_func(
113 | signal,
114 | s_h,
115 | pre_ampl,
116 | pre_phase,
117 | data['audio'],
118 | data['uv'],
119 | data['ampl'],
120 | data['phase'],
121 | detach_uv=detach_uv,
122 | uv_tolerance=self.args.loss.uv_tolerance
123 | )
124 |
125 | # handle nan loss and back propagate
126 | if torch.isnan(loss):
127 | raise ValueError(' [x] nan loss ')
128 | else:
129 | loss.backward()
130 | self.optimizer.step()
131 |
132 | self.log_training_progress(
133 | saver,
134 | loss,
135 | loss_rss,
136 | loss_uv,
137 | loss_ampl,
138 | loss_phase,
139 | batch_idx,
140 | num_batches,
141 | epoch
142 | )
143 |
144 | self.validate_and_save(saver, best_loss)
145 |
146 | def move_data_to_device(self, data: Dict[str, Any]):
147 | for k in data.keys():
148 | if k != 'name':
149 | data[k] = data[k].to(self.args.device)
150 |
151 | def log_training_progress(
152 | self,
153 | saver: Any,
154 | loss: torch.Tensor,
155 | loss_rss: torch.Tensor,
156 | loss_uv: torch.Tensor,
157 | loss_ampl: torch.Tensor,
158 | loss_phase: torch.Tensor,
159 | batch_idx: int,
160 | num_batches: int,
161 | epoch: int
162 | ):
163 | if saver.global_step % self.args.train.interval_log == 0:
164 | saver.log_info(
165 | 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | loss: {:.3f} | rss: {:.3f} | uv: {:.3f} | ampl: {:.3f} | phase: {:.3f} | time: {} | step: {}'.format(
166 | epoch,
167 | batch_idx,
168 | num_batches,
169 | self.args.env.expdir,
170 | self.args.train.interval_log/saver.get_interval_time(),
171 | loss.item(),
172 | loss_rss.item(),
173 | loss_uv.item(),
174 | loss_ampl.item(),
175 | loss_phase.item(),
176 | saver.get_total_time(),
177 | saver.global_step
178 | )
179 | )
180 |
181 | saver.log_value({
182 | 'train/loss': loss.item(),
183 | 'train/rss': loss_rss.item(),
184 | 'train/uv': loss_uv.item(),
185 | 'train/ampl': loss_ampl.item(),
186 | 'train/phase': loss_phase.item()
187 | })
188 |
189 | def validate_and_save(self, saver: Any, best_loss: float):
190 | if saver.global_step % self.args.train.interval_val == 0:
191 | saver.save_model(self.model, self.optimizer, postfix=f'{saver.global_step}')
192 | test_loss, test_loss_rss, test_loss_uv, loss_ampl, loss_phase = self.test(self.args, self.loss_func, self.loader_valid, saver)
193 | saver.log_info(self.get_validation_message(test_loss, test_loss_rss, test_loss_uv, loss_ampl, loss_phase))
194 | saver.log_value({
195 | 'validation/loss': test_loss,
196 | 'validation/rss': test_loss_rss,
197 | 'validation/uv': test_loss_uv
198 | })
199 | #self.update_best_model(saver, best_loss, test_loss)
200 |
201 | def get_validation_message(self, test_loss: float, test_loss_rss: float, test_loss_uv: float, test_loss_ampl: float, test_loss_phase: float) -> str:
202 | return ' --- --- \nloss: {:.3f} | rss: {:.3f} | uv: {:.3f} | ampl: {:.3f} | phase: {:.3f}. '.format(test_loss, test_loss_rss, test_loss_uv, test_loss_ampl, test_loss_phase)
203 |
204 | def update_best_model(self, saver: Any, best_loss: float, test_loss: float):
205 | if test_loss < best_loss:
206 | saver.log_info(' [V] best model updated.')
207 | saver.save_model(self.model, self.optimizer, postfix='best')
208 | best_loss = test_loss
209 |
210 | def test(self, args, loss_func, loader_test, saver):
211 | print(' [*] testing...')
212 | self.model.eval()
213 |
214 | # losses
215 | test_loss = 0.
216 | test_loss_rss = 0.
217 | test_loss_uv = 0.
218 | test_loss_ampl = 0.
219 | test_loss_phase = 0.
220 |
221 | # initialization
222 | num_batches = len(loader_test)
223 | rtf_all = []
224 |
225 | # run
226 | with torch.no_grad():
227 | for bidx, data in enumerate(loader_test):
228 |
229 | loss, loss_rss, loss_uv, loss_ampl, loss_phase = self.test_process_batch(
230 | args, loss_func, saver, num_batches, rtf_all, bidx, data
231 | )
232 | test_loss += loss.item()
233 | test_loss_rss += loss_rss.item()
234 | test_loss_uv += loss_uv.item()
235 | test_loss_ampl += loss_ampl.item()
236 | test_loss_phase += loss_phase.item()
237 |
238 |
239 | # report
240 | test_loss /= num_batches
241 | test_loss_rss /= num_batches
242 | test_loss_uv /= num_batches
243 | test_loss_ampl /= num_batches
244 | test_loss_phase /= num_batches
245 |
246 | # check
247 | print(' [test_loss] test_loss:', test_loss)
248 | print(' [test_loss_rss] test_loss_rss:', test_loss_rss)
249 | print(' [test_loss_uv] test_loss_uv:', test_loss_uv)
250 | print(' [test_loss_ampl] test_loss_ampl:', test_loss_ampl)
251 | print(' [test_loss_phase] test_loss_phase:', test_loss_phase)
252 | print(' Real Time Factor', np.mean(rtf_all))
253 | return test_loss, test_loss_rss, test_loss_uv, test_loss_ampl, test_loss_phase
254 |
255 | def test_process_batch(
256 | self,
257 | args: DotDict,
258 | loss_func: HybridLoss,
259 | saver,
260 | num_batches,
261 | rtf_all,
262 | bidx,
263 | data,
264 | ):
265 | fn = data['name'][0]
266 | print('--------')
267 | print('{}/{} - {}'.format(bidx, num_batches, fn))
268 |
269 | # unpack data
270 | self.move_data_to_device(data)
271 | print('>>', data['name'][0])
272 |
273 | # forward
274 | st_time = time.time()
275 | signal, _, (s_h, s_n), (pre_ampl, pre_phase) = self.model(data['mel'], data['f0'], inference=False)
276 | ed_time = time.time()
277 |
278 | # crop: at test time, the audio length is not necessarily an integer multiple of block_size.
279 | signal = self.crop_audio(data, signal)
280 |
281 | # RTF
282 | self.compute_RTF(args, rtf_all, data, ed_time - st_time)
283 |
284 | # log
285 | saver.log_audio({fn+'/gt.wav': data['audio'], fn+'/pred.wav': signal})
286 |
287 | # loss
288 | loss, (loss_rss, loss_uv, loss_ampl, loss_phase) = loss_func(
289 | signal,
290 | s_h,
291 | pre_ampl,
292 | pre_phase,
293 | data['audio'],
294 | data['uv'],
295 | data['ampl'],
296 | data['phase'],
297 | detach_uv=True
298 | )
299 |
300 | return loss, loss_rss, loss_uv, loss_ampl, loss_phase
301 |
302 |
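    | # RTF (real-time factor) = synthesis time / audio duration; values below 1
    | # mean faster-than-real-time synthesis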
303 | def compute_RTF(self, args, rtf_all, data, run_time):
304 | song_time = data['audio'].shape[-1] / args.data.sampling_rate
305 | rtf = run_time / song_time
306 | print('RTF: {} | {} / {}'.format(rtf, run_time, song_time))
307 | rtf_all.append(rtf)
308 |
309 | def crop_audio(self, data, signal):
310 | min_len = np.min([signal.shape[1], data['audio'].shape[1]])
311 | signal = signal[:,:min_len]
312 | data['audio'] = data['audio'][:,:min_len]
313 | return signal
314 |
315 | @click.command()
316 | @click.option(
317 | '--config', type=click.Path(
318 | exists=True, file_okay=True, dir_okay=False, readable=True,
319 | path_type=Path, resolve_path=True
320 | ),
321 | required=True, metavar='CONFIG_FILE',
322 | help='The path to the config file.'
323 | )
324 | def main(config):
325 | print(' > starting training...')
326 |
327 | # load config
328 | args = utils.load_config(config)
329 | print(' > config:', config)
330 | print(' > exp:', args.env.expdir)
331 |
332 | trainer = ModelTrainer(args, args.device)
333 | trainer.train()
334 |
335 |
336 | if __name__ == '__main__':
337 | main()
338 |
--------------------------------------------------------------------------------
/u_noise.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjzxkxdn/Mini-DDSP/305f069c07f0214bf65a28f519e2149e864885a2/u_noise.ckpt
--------------------------------------------------------------------------------
/v_noise.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjzxkxdn/Mini-DDSP/305f069c07f0214bf65a28f519e2149e864885a2/v_noise.ckpt
--------------------------------------------------------------------------------