├── .gitignore ├── LICENSE ├── README.md ├── config.yaml ├── dataset.py ├── evaluate.py ├── imgs ├── screenshot_audio.png └── screenshot_melspec.png ├── loss.py ├── model.py ├── preprocess.py ├── requirements.txt ├── results ├── jsut │ ├── loss.png │ ├── sample1_groundtruth.wav │ ├── sample1_synthesized.wav │ ├── sample2_groundtruth.wav │ └── sample2_synthesized.wav └── jvs │ ├── loss.png │ ├── sample1_open_groundtruth.wav │ ├── sample1_open_synthesized.wav │ ├── sample2_closed_groundtruth.wav │ └── sample2_closed_synthesized.wav ├── run.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | __pycache__ 107 | .vscode 108 | .DS_Store 109 | 110 | # MFA 111 | montreal-forced-aligner/ 112 | 113 | # data, checkpoint, and models 114 | output/ 115 | preprocessed/ 116 | log* 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Takaaki Saeki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simplified Neural Source Filter Model 2 | 3 | My implementation of the simplified neural source filter (S-NSF) model from [this paper](https://arxiv.org/abs/1904.12088), with some modifications. 4 | I evaluated this implementation on the [JSUT Corpus](https://arxiv.org/abs/1711.00354) (a single-speaker Japanese speech corpus) and the [JVS Corpus](https://arxiv.org/abs/1908.06248) (a multi-speaker Japanese speech corpus). 5 | You can find some training results (around 700k iterations) under [results/](./results/). 6 | 7 | 8 | ## Updates 9 | - 2021/08/04: Initial commit 10 | 11 | ## Usage 12 | First, install the dependencies with 13 | ``` 14 | $pip install -r requirements.txt 15 | ``` 16 | 17 | You can configure various training parameters by editing `config.yaml`. 18 | 19 | ### Preprocessing 20 | ``` 21 | $python preprocess.py 22 | ``` 23 | 24 | ### Start training 25 | ``` 26 | $python train.py 27 | ``` 28 | 29 | ### Visualize results 30 | ``` 31 | $tensorboard --logdir=${log_path} 32 | ``` 33 | 34 | ![](./imgs/screenshot_melspec.png) 35 | ![](./imgs/screenshot_audio.png) 36 | 37 | 38 | ## References 39 | - [Neural source-filter waveform models for statistical parametric speech synthesis](https://arxiv.org/abs/1904.12088) 40 | - [JSUT corpus: free large-scale Japanese speech corpus for end-to-end speech synthesis](https://arxiv.org/abs/1711.00354) 41 | - [JVS corpus: free Japanese multi-speaker voice corpus](https://arxiv.org/abs/1908.06248) 42 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | dataset_path: "/home/hdd1/jsut_basic5000/wav16k" 3 | preprocessed_path: "/workspace/simplified_neural_source_filter/preprocessed/jsut" 4 | tflog_path: "/home/hdd1/simplified_neural_source_filter/output/jsut/log" 5 | ckpt_path: "/home/hdd1/simplified_neural_source_filter/output/jsut/ckpt" 6 | #dataset_path: "/home/hdd1/jvs100_16k" 7 | #preprocessed_path: "/workspace/simplified_neural_source_filter/preprocessed/jvs" 8 | #tflog_path: "/home/hdd1/simplified_neural_source_filter/output/jvs/log" 9 | #ckpt_path: "/home/hdd1/simplified_neural_source_filter/output/jvs/ckpt" 10 | 11 | preprocess: 12 | corpus: "jsut" 13 | #corpus: "jvs" 14 | num_train: 4950 15 | #num_train: 95 16 | sampling_rate: 16000 17 | frame_length: 400 18 | frame_shift: 80 19 | fft_length: 1024 20 | segment_length: 2 21 | sp_dim: 80 22 | 23 | model: 24 | n_harmonic: 7 25 | phi: 0.
26 | alpha: 0.1 27 | sigma: 0.003 28 | rnn_hidden: 32 29 | cnn_out: 64 30 | n_convlayer: 10 31 | n_transformblock: 5 32 | 33 | train: 34 | batch_size: 2 35 | data_parallel: False 36 | step_total: 2000000 37 | restore_step: 0 38 | learning_rate: 0.0005 39 | grad_clip_thresh: 1.0 40 | plot_step: 1000 41 | val_step: 1000 42 | save_step: 1000 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import json 4 | import math 5 | import os 6 | import numpy as np 7 | import pickle 8 | import glob 9 | import matplotlib.pyplot as plt 10 | import math 11 | from torch.utils.data import Dataset, DataLoader 12 | import torch 13 | 14 | class Dataset(Dataset): 15 | 16 | def __init__(self, filetxt, config, device): 17 | 18 | self.preprocessed_path = config["path"]["preprocessed_path"] 19 | with open(os.path.join(self.preprocessed_path, filetxt), 'r') as fr: 20 | self.filelists = [path.strip('\n') for path in fr] 21 | self.device = device 22 | 23 | self.f0s = [] 24 | self.sps = [] 25 | self.wavs = [] 26 | for filepath in self.filelists: 27 | 28 | if config["preprocess"]["corpus"] == "jvs": 29 | basename = os.path.basename(os.path.dirname(filepath)) + "-" + os.path.splitext(os.path.basename(filepath))[0] 30 | elif config["preprocess"]["corpus"] == "jsut": 31 | basename = os.path.splitext(os.path.basename(filepath))[0] 32 | 33 | f0 = np.load(os.path.join(self.preprocessed_path, "f0", "f0-{}.npy".format(basename))) 34 | self.f0s.extend(f0) 35 | sp = np.load(os.path.join(self.preprocessed_path, "sp", "sp-{}.npy".format(basename))) 36 | self.sps.extend(sp) 37 | wav = np.load(os.path.join(self.preprocessed_path, "wav", "wav-{}.npy".format(basename))) 38 | self.wavs.extend(wav) 39 | 40 | self.f0s = np.asarray(self.f0s) 41 | self.sps = np.asarray(self.sps) 42 | self.wavs = np.asarray(self.wavs) 43 | 44 | def __len__(self): 45 | return self.f0s.shape[0] 46 | 47 | def __getitem__(self, idx): 48 | 49 | f0 = self.f0s[idx, :] 50 | sp = self.sps[idx, :, :] 51 | wav = self.wavs[idx, :] 52 | 53 | sample = { 54 | 'f0': f0, 55 | 'sp': sp, 56 | 'wav': wav 57 | } 58 | 59 | return sample 60 | 61 | def collate_fn(self, batch): 62 | 63 | batch_size = len(batch) 64 | 65 | f0s = [] 66 | sps = [] 67 | wavs = [] 68 | 69 | for idx in range(batch_size): 70 | f0s.append(batch[idx]['f0']) 71 | sps.append(batch[idx]['sp']) 72 | wavs.append(batch[idx]['wav']) 73 | f0s = np.asarray(f0s) 74 | sps = np.asarray(sps) 75 | wavs = np.asarray(wavs) 76 | 77 | f0s = torch.from_numpy(f0s).to(self.device) 78 | sps = torch.from_numpy(sps).to(self.device) 79 | wavs = torch.from_numpy(wavs).to(self.device) 80 | feature = torch.cat((f0s.unsqueeze(1), sps), dim=1) 81 | 82 | output = [ 83 | f0s, 84 | sps, 85 | feature, 86 | wavs 87 | ] 88 | 89 | return output 90 | 91 | if __name__ == '__main__': 92 | 93 | device = torch.device("cpu" if torch.cuda.is_available() else "cpu") 94 | train_dataset = Dataset('train.txt', 'preprocessed', device) 95 | train_loader = DataLoader( 96 | train_dataset, 97 | batch_size=16, 98 | shuffle=True, 99 | collate_fn=train_dataset.collate_fn, 100 | ) 101 | 102 | n_batch = 0 103 | for batchs in train_loader: 104 | n_batch += 1 105 | print( 106 | "Training set with size {} is composed of {} batches.".format( 107 | len(train_dataset), n_batch 108 | ) 109 | ) -------------------------------------------------------------------------------- /evaluate.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | 7 | from dataset import Dataset 8 | from utils import log 9 | 10 | def evaluate(config, model, criteria, step, device, val_logger=None): 11 | 12 | valid_dataset = Dataset('val.txt', config, device) 13 | valid_loader = DataLoader( 14 | valid_dataset, 15 | batch_size=config["train"]["batch_size"], 16 | shuffle=True, 17 | collate_fn=valid_dataset.collate_fn, 18 | num_workers=0 19 | ) 20 | 21 | # Evaluation 22 | loss_sums = 0. 23 | for batch in valid_loader: 24 | 25 | with torch.no_grad(): 26 | 27 | output = model(batch[0], batch[2]) 28 | loss = criteria(output, batch[3]) 29 | loss_sums += loss.item() * batch[3].shape[0]  # weight by the number of waveforms; batch is [f0, sp, feature, wav] 30 | 31 | loss_mean = loss_sums / len(valid_dataset) 32 | 33 | message = "Validation Step {}, Loss: {}".format( 34 | step, loss_mean 35 | ) 36 | 37 | log(val_logger, step, loss=loss_mean) 38 | 39 | log(val_logger, 40 | audio=output[0, :], 41 | sampling_rate=config["preprocess"]["sampling_rate"], 42 | tag="Validation/audio_step{}_synthesized".format(step), 43 | ) 44 | log(val_logger, 45 | audio=batch[3][0, :], 46 | sampling_rate=config["preprocess"]["sampling_rate"], 47 | tag="Validation/audio_step{}_groundtruth".format(step), 48 | ) 49 | log(val_logger, 50 | figwav=output[0, :], 51 | sampling_rate=config["preprocess"]["sampling_rate"], 52 | tag="Validation/melspec_step{}_synthesized".format(step), 53 | ) 54 | log(val_logger, 55 | figwav=batch[3][0, :], 56 | sampling_rate=config["preprocess"]["sampling_rate"], 57 | tag="Validation/melspec_step{}_groundtruth".format(step), 58 | ) 59 | 60 | return message -------------------------------------------------------------------------------- /imgs/screenshot_audio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/imgs/screenshot_audio.png -------------------------------------------------------------------------------- /imgs/screenshot_melspec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/imgs/screenshot_melspec.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchaudio 6 | 7 | class MultiScaleSpectralLoss(nn.Module): 8 | """ 9 | Reference: DDSP: Differentiable Digital Signal Processing 10 | https://arxiv.org/abs/2001.04643 11 | """ 12 | def __init__(self): 13 | super(MultiScaleSpectralLoss, self).__init__() 14 | self.alpha = 1.0 15 | self.fft_sizes = [2048, 512, 256, 128, 64] 16 | self.spectrograms = [] 17 | for fftsize in self.fft_sizes: 18 | self.spectrograms.append( 19 | torchaudio.transforms.Spectrogram(n_fft=fftsize, hop_length=fftsize//4, power=2) 20 | ) 21 | self.spectrograms = nn.ModuleList(self.spectrograms) 22 | self.l1loss = nn.L1Loss() 23 | self.eps = 1e-10 24 | 25 | def forward(self, wav_out, wav_target): 26 | loss = 0.
27 | for spectrogram in self.spectrograms: 28 | S_out = spectrogram(wav_out) 29 | S_target = spectrogram(wav_target) 30 | log_S_out = torch.log(S_out+self.eps) 31 | log_S_target = torch.log(S_target+self.eps) 32 | loss += (self.l1loss(S_out, S_target) + self.alpha * self.l1loss(log_S_out, log_S_target)) 33 | return loss 34 | 35 | if __name__ == '__main__': 36 | 37 | # test 38 | wav_out = torch.ones(1, 64000) 39 | wav_target = torch.ones(1, 64000) 40 | criteria = MultiScaleSpectralLoss() 41 | loss = criteria(wav_out, wav_target) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import yaml 7 | 8 | class ConvLayers(nn.Module): 9 | def __init__(self, config, n_layer=10): 10 | super(ConvLayers, self).__init__() 11 | self.channel = config["model"]["cnn_out"] 12 | convs = [] 13 | for n in range(n_layer): 14 | dil = n % 10 15 | convs.append( 16 | nn.Conv1d(self.channel, self.channel, 3, dilation=2**dil, stride=1, padding=2**dil), 17 | ) 18 | self.conv_layers = nn.Sequential(*convs) 19 | 20 | def forward(self, excitation, feature): 21 | conv_out = self.conv_layers(excitation) 22 | output = conv_out + excitation + feature 23 | return torch.tanh(output) 24 | 25 | class TransformBlock(nn.Module): 26 | def __init__(self, config, n_convlayer=10): 27 | super(TransformBlock, self).__init__() 28 | self.channel = config["model"]["cnn_out"] 29 | self.in_linear = nn.Linear(1, self.channel) 30 | self.out_linear = nn.Linear(self.channel, 1) 31 | self.conv_layers = ConvLayers(config, n_convlayer) 32 | 33 | def forward(self, excitation, feature): 34 | output = excitation.transpose(1, 2) 35 | output = self.in_linear(output).transpose(1, 2) 36 | output = self.conv_layers(output, feature) 37 | output = self.out_linear(output.transpose(1, 2)) 38 | output = excitation + output.transpose(1, 2) 39 | return output 40 | 41 | class SourceFilter(nn.Module): 42 | def __init__(self, config, device): 43 | super(SourceFilter, self).__init__() 44 | self.source = SourceModule(config, device) 45 | self.filter = FilterModule(config) 46 | 47 | def forward(self, f0s, feature): 48 | excitation = self.source(f0s) 49 | output = self.filter(excitation, feature) 50 | return output 51 | 52 | class SourceModule(nn.Module): 53 | def __init__(self, config, device): 54 | super(SourceModule, self).__init__() 55 | self.n_harmonic = config["model"]["n_harmonic"] 56 | self.phi = config["model"]["phi"] 57 | self.alpha = config["model"]["alpha"] 58 | self.sigma = config["model"]["sigma"] 59 | self.SR = config["preprocess"]["sampling_rate"] 60 | self.frame_shift = config["preprocess"]["frame_shift"] 61 | self.device = device 62 | 63 | self.phi = torch.rand(self.n_harmonic, requires_grad=False) * -1. * np.pi 64 | self.phi[0] = 0. 65 | 66 | self.amplitude = nn.Parameter(torch.ones(self.n_harmonic+1), requires_grad=True) 67 | torch.nn.init.normal_(self.amplitude, 0.0, 1.0) 68 | 69 | def forward(self, f0s): 70 | f0s = torch.repeat_interleave(f0s, self.frame_shift, dim=1) # upsampling 71 | output = 0. 
72 | for i in range(self.n_harmonic): 73 | output += self.amplitude[i] * self._signal(f0s*(i+1), self.phi[i]) 74 | output = torch.tanh(output + self.amplitude[self.n_harmonic]) 75 | return output 76 | 77 | def _signal(self, freq, phi): 78 | noise = torch.normal(0., self.sigma, size=freq.shape).to(self.device) 79 | eplus = self.alpha * torch.sin(torch.cumsum(2.*np.pi*freq/self.SR, dim=1) + phi) + noise 80 | argplus = torch.where(freq > 0, 1, 0).to(self.device) 81 | argzero = torch.where(freq == 0, 1, 0).to(self.device) 82 | excitation = eplus*argplus + argzero*self.alpha/(3.*self.sigma)*noise 83 | return excitation 84 | 85 | class FilterModule(nn.Module): 86 | def __init__(self, config): 87 | super(FilterModule, self).__init__() 88 | self.in_dim = 1 + config["preprocess"]["sp_dim"] 89 | self.rnn_hidden = config["model"]["rnn_hidden"] 90 | self.cnn_out = config["model"]["cnn_out"] 91 | self.frame_shift = config["preprocess"]["frame_shift"] 92 | self.n_convlayer = config["model"]["n_convlayer"] 93 | self.n_transformblock = config["model"]["n_transformblock"] 94 | 95 | self.bilstm = nn.LSTM(self.in_dim, self.rnn_hidden, num_layers=1, bidirectional=True, batch_first=True) 96 | self.conv = nn.Conv1d(in_channels=self.rnn_hidden*2, out_channels=self.cnn_out, kernel_size=3, stride=1, padding=1) 97 | self.transform_blocks = nn.ModuleList([TransformBlock(config, self.n_convlayer) for n in range(self.n_transformblock)]) 98 | 99 | def forward(self, excitation, feature): 100 | feature, _ = self.bilstm(feature.transpose(1, 2)) 101 | feature = self.conv(feature.transpose(1, 2)) 102 | feature = torch.repeat_interleave(feature, self.frame_shift, dim=2) # upsampling 103 | output = excitation.unsqueeze(1) 104 | for n in range(self.n_transformblock): 105 | output = self.transform_blocks[n](output, feature) 106 | return output.squeeze(1) 107 | 108 | if __name__ == '__main__': 109 | 110 | # test for basic SourceFilter 111 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 112 | config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) 113 | f0s = torch.zeros((8, 800)).to(device) # expanded f0s: (B, time) 114 | feature = torch.zeros((8, 81, 800)).to(device) 115 | source_filter_model = SourceFilter(config, device).to(device) 116 | output = source_filter_model(f0s, feature) 117 | assert output.shape == (8, 64000) 118 | 119 | # test for FilterModule 120 | excitation = torch.zeros((8, 64000)).to(device) 121 | feature = torch.zeros((8, 81, 800)).to(device) 122 | filter_module = FilterModule(config).to(device) 123 | output = filter_module(excitation, feature) 124 | assert output.shape == (8, 64000) -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import numpy as np 4 | import os 5 | import librosa 6 | import soundfile as sf 7 | import tqdm 8 | import pyworld 9 | import glob 10 | import matplotlib.pyplot as plt 11 | import random 12 | import pickle 13 | import argparse 14 | import yaml 15 | 16 | def main(config): 17 | 18 | # configs 19 | datapath = config["path"]["dataset_path"] 20 | out_dir = config["path"]["preprocessed_path"] 21 | num_train = config["preprocess"]["num_train"] 22 | SR = config["preprocess"]["sampling_rate"] 23 | frame_length = config["preprocess"]["frame_length"] # (sample) 24 | frame_shift = config["preprocess"]["frame_shift"] # (sample) 25 | fft_length = config["preprocess"]["fft_length"] # 
(sample) 26 | segment_length = config["preprocess"]["segment_length"] # (second) 27 | sp_dim = config["preprocess"]["sp_dim"] 28 | 29 | os.makedirs(out_dir, exist_ok=True) 30 | 31 | if config["preprocess"]["corpus"] == "jvs": 32 | train_filelists = [] 33 | val_filelists = [] 34 | spk_idxs = list(range(1, 101)) 35 | random.shuffle(spk_idxs) 36 | for n in range(100): 37 | idx = '0'*(3 - len(str(spk_idxs[n]))) + str(spk_idxs[n]) 38 | if n < num_train: 39 | train_filelists.extend(glob.glob(os.path.join(datapath, 'jvs{}'.format(idx) ,'*.wav'))) 40 | else: 41 | val_filelists.extend(glob.glob(os.path.join(datapath, 'jvs{}'.format(idx) ,'*.wav'))) 42 | filelists = train_filelists + val_filelists 43 | elif config["preprocess"]["corpus"] == "jsut": 44 | filelists = glob.glob(os.path.join(datapath, '*.wav')) 45 | random.seed(0) 46 | random.shuffle(filelists) 47 | train_filelists = filelists[:num_train] 48 | val_filelists = filelists[num_train:] 49 | else: 50 | raise NotImplementedError() 51 | 52 | with open(os.path.join(out_dir, "train.txt"), "w", encoding="utf-8") as f: 53 | for m in train_filelists: 54 | f.write(m + "\n") 55 | with open(os.path.join(out_dir, "val.txt"), "w", encoding="utf-8") as f: 56 | for m in val_filelists: 57 | f.write(m + "\n") 58 | 59 | os.makedirs(out_dir, exist_ok=True) 60 | os.makedirs(os.path.join(out_dir, 'wav'), exist_ok=True) 61 | os.makedirs(os.path.join(out_dir, 'f0'), exist_ok=True) 62 | os.makedirs(os.path.join(out_dir, 'sp'), exist_ok=True) 63 | 64 | for idx, wavpath in enumerate(tqdm.tqdm(filelists)): 65 | 66 | if config["preprocess"]["corpus"] == "jvs": 67 | basename = os.path.basename(os.path.dirname(wavpath)) + "-" + os.path.splitext(os.path.basename(wavpath))[0] 68 | elif config["preprocess"]["corpus"] == "jsut": 69 | basename = os.path.splitext(os.path.basename(wavpath))[0] 70 | 71 | wav, SR = sf.read(wavpath) 72 | wav, _ = librosa.effects.trim(wav, top_db=20) 73 | num_seg = len(wav) // (segment_length*SR) 74 | f0s = [] 75 | sps = [] 76 | wavsegs = [] 77 | for n in range(num_seg): 78 | wav_seg = wav[n*(segment_length*SR):(n+1)*(segment_length*SR)] 79 | _f0, _t = pyworld.dio(wav_seg.astype(np.double), fs=SR, frame_period=frame_shift/SR*1000) 80 | f0 = pyworld.stonemask(wav_seg.astype(np.double), _f0, _t, SR) 81 | sp = np.sqrt(pyworld.cheaptrick(wav_seg.astype(np.double), f0, _t, SR)) 82 | melfilter = librosa.filters.mel(sr=SR, n_fft=fft_length, n_mels=sp_dim) 83 | melsp = np.dot(melfilter, sp.T) 84 | f0s.append(f0[:SR*segment_length//frame_shift]) 85 | sps.append(melsp[:, :SR*segment_length//frame_shift]) 86 | wavsegs.append(wav_seg) 87 | 88 | f0s = np.asarray(f0s).astype(np.float32) 89 | np.save( 90 | os.path.join(out_dir, "f0", "f0-{}.npy".format(basename)), 91 | f0s, 92 | ) 93 | sps = np.asarray(sps).astype(np.float32) 94 | np.save( 95 | os.path.join(out_dir, "sp", "sp-{}.npy".format(basename)), 96 | sps, 97 | ) 98 | wavsegs = np.asarray(wavsegs).astype(np.float32) 99 | np.save( 100 | os.path.join(out_dir, "wav", "wav-{}.npy".format(basename)), 101 | wavsegs, 102 | ) 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument( 107 | "--num_train", 108 | type=int, 109 | default=None, 110 | required=False 111 | ) 112 | parser.add_argument( 113 | "--sp_dim", 114 | type=int, 115 | default=None, 116 | required=False 117 | ) 118 | parser.add_argument( 119 | "--corpus", 120 | type=str, 121 | default=None, 122 | required=False 123 | ) 124 | parser.add_argument( 125 | "--dataset_path", 126 | type=str, 127 | 
default=None, 128 | required=False 129 | ) 130 | parser.add_argument( 131 | "--preprocessed_path", 132 | type=str, 133 | default=None, 134 | required=False 135 | ) 136 | args = parser.parse_args() 137 | 138 | config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) 139 | for name in ['num_train', 'corpus', 'sp_dim']: 140 | if getattr(args, name) is not None: 141 | config['preprocess'][name] = getattr(args, name) 142 | for path in ['dataset_path', 'preprocessed_path']: 143 | if getattr(args, path) is not None: 144 | config['path'][path] = getattr(args, path) 145 | main(config) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | SoundFile 2 | numpy 3 | librosa 4 | pyworld 5 | matplotlib 6 | torch==1.7.0 7 | torchaudio==0.7.0 -------------------------------------------------------------------------------- /results/jsut/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jsut/loss.png -------------------------------------------------------------------------------- /results/jsut/sample1_groundtruth.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jsut/sample1_groundtruth.wav -------------------------------------------------------------------------------- /results/jsut/sample1_synthesized.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jsut/sample1_synthesized.wav -------------------------------------------------------------------------------- /results/jsut/sample2_groundtruth.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jsut/sample2_groundtruth.wav -------------------------------------------------------------------------------- /results/jsut/sample2_synthesized.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jsut/sample2_synthesized.wav -------------------------------------------------------------------------------- /results/jvs/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jvs/loss.png -------------------------------------------------------------------------------- /results/jvs/sample1_open_groundtruth.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jvs/sample1_open_groundtruth.wav -------------------------------------------------------------------------------- /results/jvs/sample1_open_synthesized.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jvs/sample1_open_synthesized.wav -------------------------------------------------------------------------------- /results/jvs/sample2_closed_groundtruth.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jvs/sample2_closed_groundtruth.wav -------------------------------------------------------------------------------- /results/jvs/sample2_closed_synthesized.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Takaaki-Saeki/simplified_neural_source_filter/bdbb7ed6acd6532f9415202d85bd8285bcf49eb7/results/jvs/sample2_closed_synthesized.wav -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | pip install --upgrade pip 2 | pip install -r requirements.txt 3 | 4 | python preprocess.py \ 5 | --num_train=4950 \ 6 | --sp_dim=80 \ 7 | --corpus="jsut" 8 | 9 | python train.py \ 10 | --num_train=4950 \ 11 | --sp_dim=80 \ 12 | --corpus="jsut" 13 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import os 4 | import torch 5 | import yaml 6 | import torch.nn as nn 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | from utils import get_model, log 13 | from loss import MultiScaleSpectralLoss 14 | from dataset import Dataset 15 | from evaluate import evaluate 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | def main(config): 20 | 21 | # Get dataset 22 | train_dataset = Dataset('train.txt', config, device) 23 | train_loader = DataLoader( 24 | train_dataset, 25 | batch_size=config["train"]["batch_size"], 26 | shuffle=True, 27 | collate_fn=train_dataset.collate_fn, 28 | num_workers=0 29 | ) 30 | 31 | # Prepare model and optimizer 32 | model, optimizer = get_model(config, device, train=True) 33 | if config["train"]["data_parallel"]: 34 | model = nn.DataParallel(model) 35 | criteria = MultiScaleSpectralLoss().to(device) 36 | 37 | # Init logger 38 | train_tflog_path = os.path.join(config["path"]["tflog_path"], "train") 39 | val_tflog_path = os.path.join(config["path"]["tflog_path"], "val") 40 | train_logger = SummaryWriter(train_tflog_path) 41 | val_logger = SummaryWriter(val_tflog_path) 42 | os.makedirs(train_tflog_path, exist_ok=True) 43 | os.makedirs(val_tflog_path, exist_ok=True) 44 | os.makedirs(config["path"]["ckpt_path"], exist_ok=True) 45 | 46 | # Extract training config 47 | step_total = config["train"]["step_total"] 48 | plot_step = config["train"]["plot_step"] 49 | val_step = config["train"]["val_step"] 50 | save_step = config["train"]["save_step"] 51 | step = config["train"]["restore_step"] 52 | epoch = step // len(train_dataset) 53 | 54 | # Set progress bar 55 | step_bar = tqdm(total=config["train"]["step_total"], desc="Training", position=0) 56 | step_bar.n = config["train"]["restore_step"] 57 | step_bar.update() 58 | 59 | # main training loop 60 | while True: 61 | for batch in train_loader: 62 | 63 | 
torch.autograd.set_detect_anomaly(True) 64 | optimizer.zero_grad() 65 | 66 | output = model(batch[0], batch[2]) 67 | loss = criteria(output, batch[3]) 68 | 69 | loss.backward() 70 | nn.utils.clip_grad_norm_(model.parameters(), config["train"]["grad_clip_thresh"]) 71 | optimizer.step() 72 | 73 | message1 = "Step: {}, Epoch: {}, ".format(step, epoch) 74 | message2 = "Train Loss: {}".format(loss.item()) 75 | step_bar.write(message1 + message2) 76 | 77 | log(train_logger, step, loss=loss) 78 | 79 | # fig 80 | if step % plot_step == 0: 81 | log(train_logger, 82 | audio=output[0, :], 83 | sampling_rate=config["preprocess"]["sampling_rate"], 84 | tag="Training/audio_step{}_synthesized".format(step), 85 | ) 86 | log(train_logger, 87 | audio=batch[3][0, :], 88 | sampling_rate=config["preprocess"]["sampling_rate"], 89 | tag="Training/audio_step{}_groundtruth".format(step), 90 | ) 91 | log(train_logger, 92 | figwav=output[0, :], 93 | sampling_rate=config["preprocess"]["sampling_rate"], 94 | tag="Training/melspec_step{}_synthesized".format(step), 95 | ) 96 | log(train_logger, 97 | figwav=batch[3][0, :], 98 | sampling_rate=config["preprocess"]["sampling_rate"], 99 | tag="Training/melspec_step{}_groundtruth".format(step), 100 | ) 101 | 102 | if step % val_step == 0: 103 | model.eval() 104 | message = evaluate(config, model, criteria, step, device, val_logger) 105 | step_bar.write(message) 106 | model.train() 107 | 108 | if step % save_step == 0: 109 | torch.save( 110 | { 111 | "model": model.state_dict(), 112 | "optimizer": optimizer.state_dict(), 113 | }, 114 | os.path.join( 115 | config["path"]["ckpt_path"], 116 | "{}.pth.tar".format(step), 117 | ), 118 | ) 119 | 120 | step += 1 121 | step_bar.update(1) 122 | if step == step_total: 123 | quit() 124 | epoch += 1 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument( 129 | "--num_train", 130 | type=int, 131 | default=None, 132 | required=False 133 | ) 134 | parser.add_argument( 135 | "--sp_dim", 136 | type=int, 137 | default=None, 138 | required=False 139 | ) 140 | parser.add_argument( 141 | "--corpus", 142 | type=str, 143 | default=None, 144 | required=False 145 | ) 146 | parser.add_argument( 147 | "--dataset_path", 148 | type=str, 149 | default=None, 150 | required=False 151 | ) 152 | parser.add_argument( 153 | "--preprocessed_path", 154 | type=str, 155 | default=None, 156 | required=False 157 | ) 158 | parser.add_argument( 159 | "--tflog_path", 160 | type=str, 161 | default=None, 162 | required=False 163 | ) 164 | parser.add_argument( 165 | "--ckpt_path", 166 | type=str, 167 | default=None, 168 | required=False 169 | ) 170 | args = parser.parse_args() 171 | 172 | config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader) 173 | for name in ['num_train', 'corpus', 'sp_dim']: 174 | if getattr(args, name) is not None: 175 | config['preprocess'][name] = getattr(args, name) 176 | for path in ['dataset_path', 'preprocessed_path', 'tflog_path', 'ckpt_path']: 177 | if getattr(args, path) is not None: 178 | config['path'][path] = getattr(args, path) 179 | main(config) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021, Takaaki Saeki 2 | 3 | import os 4 | import torch 5 | import librosa 6 | import librosa.display 7 | import matplotlib.pyplot as plt 8 | from model import SourceFilter 9 | import numpy as np 10 | import soundfile as sf 11 | 12 | def 
get_model(config, device, train=False): 13 | 14 | model = SourceFilter(config, device).to(device) 15 | if config["train"]["restore_step"]: 16 | ckpt_path = os.path.join( 17 | config["path"]["ckpt_path"], 18 | "{}.pth.tar".format(config["train"]["restore_step"]) 19 | ) 20 | ckpt = torch.load(ckpt_path) 21 | model.load_state_dict(ckpt["model"]) 22 | 23 | if train: 24 | optim = torch.optim.Adam(model.parameters(), lr=config["train"]["learning_rate"]) 25 | if config["train"]["restore_step"]: 26 | optim.load_state_dict(ckpt["optimizer"]) 27 | for state in optim.state.values(): 28 | for k, v in state.items(): 29 | if isinstance(v, torch.Tensor): 30 | state[k] = v.to(device) 31 | model.train() 32 | return model, optim 33 | 34 | model.eval() 35 | model.requires_grad_(False)  # method call (not attribute assignment) so the parameters are actually frozen 36 | return model 37 | 38 | def log(logger, step=None, loss=None, figwav=None, audio=None, sampling_rate=16000, tag=""): 39 | if loss is not None: 40 | logger.add_scalar("Loss/total_loss", loss, step) 41 | 42 | if audio is not None: 43 | audio = audio.detach().cpu().numpy() 44 | logger.add_audio( 45 | tag, 46 | audio / max(abs(audio)), 47 | sample_rate=sampling_rate, 48 | ) 49 | 50 | if figwav is not None: 51 | figwav = figwav.detach().cpu().numpy() 52 | fig = plot_melspec(figwav / max(abs(figwav))) 53 | logger.add_figure(tag, fig) 54 | 55 | def plot_melspec(wav, sampling_rate=16000, frame_length=400, fft_length=512, frame_shift=80): 56 | melspec = librosa.feature.melspectrogram( 57 | wav, 58 | sr=sampling_rate, 59 | hop_length=frame_shift, 60 | win_length=frame_length, 61 | n_mels=128, 62 | fmax=sampling_rate//2 63 | ) 64 | fig, ax = plt.subplots() 65 | melspec_db = librosa.power_to_db(melspec, ref=np.max) 66 | img = librosa.display.specshow( 67 | melspec_db, 68 | x_axis='time', 69 | y_axis='mel', 70 | sr=sampling_rate, 71 | hop_length=frame_shift, 72 | fmax=sampling_rate//2, 73 | ax=ax) 74 | ax.set_title("Melspectrogram", fontsize="medium") 75 | return fig --------------------------------------------------------------------------------
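A note on synthesis: the repository contains preprocessing, training, and validation code, but no standalone inference script. The sketch below shows one way the existing pieces could be wired together to vocode a preprocessed utterance with a trained checkpoint. It is only a minimal illustration assuming the default `config.yaml`; the file name `synthesize.py` and the `--ckpt`/`--basename`/`--out` arguments are hypothetical and not part of the repository.
```
# synthesize.py (hypothetical helper, not included in the repository)
# Minimal sketch: load a checkpoint written by train.py and vocode one
# utterance from the arrays produced by preprocess.py.
import argparse
import os

import numpy as np
import soundfile as sf
import torch
import yaml

from model import SourceFilter

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)      # e.g. <ckpt_path>/700000.pth.tar
    parser.add_argument("--basename", type=str, required=True)  # an utterance id used during preprocessing
    parser.add_argument("--out", type=str, default="synthesized.wav")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

    model = SourceFilter(config, device).to(device)
    ckpt = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    # preprocess.py stores per-utterance arrays:
    #   f0: (n_segments, n_frames), sp: (n_segments, sp_dim, n_frames), both float32.
    pre_dir = config["path"]["preprocessed_path"]
    f0 = np.load(os.path.join(pre_dir, "f0", "f0-{}.npy".format(args.basename)))
    sp = np.load(os.path.join(pre_dir, "sp", "sp-{}.npy".format(args.basename)))

    f0s = torch.from_numpy(f0).to(device)                # (n_segments, n_frames)
    sps = torch.from_numpy(sp).to(device)                # (n_segments, sp_dim, n_frames)
    feature = torch.cat((f0s.unsqueeze(1), sps), dim=1)  # (n_segments, 1 + sp_dim, n_frames), as in collate_fn

    with torch.no_grad():
        wav = model(f0s, feature)                        # (n_segments, n_frames * frame_shift)

    # Concatenate the fixed-length segments and write a single, peak-normalized file.
    wav = wav.cpu().numpy().reshape(-1)
    sf.write(args.out, wav / max(np.abs(wav).max(), 1e-8), config["preprocess"]["sampling_rate"])
```
Each row of the saved arrays corresponds to one non-overlapping `segment_length`-second segment cut by `preprocess.py`, so flattening the batched output simply concatenates the segments in their original order.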