├── LICENSE ├── README.md ├── components ├── filtered_noise.py ├── harmonic_oscillator.py ├── loudness_extractor.py ├── ptcrepe │ ├── README.md │ └── ptcrepe │ │ ├── crepe.py │ │ └── utils.py └── reverb.py ├── configs └── violin.yaml ├── data ├── mp3_to_wav.sh └── violin │ ├── test │ ├── VIII.+Double.wav │ └── f0_0.004 │ │ └── VIII.+Double.f0.csv │ └── train │ ├── II.+Double.wav │ ├── III.+Corrente.wav │ ├── IV.+Double+Presto.wav │ ├── VI.+Double.wav │ └── f0_0.004 │ ├── II.+Double.f0.csv │ ├── III.+Corrente.f0.csv │ ├── IV.+Double+Presto.f0.csv │ └── VI.+Double.f0.csv ├── requirements.txt └── train ├── dataset └── audiodata.py ├── loss └── mss_loss.py ├── network └── autoencoder │ ├── autoencoder.py │ ├── decoder.py │ └── encoder.py ├── optimizer └── radam.py ├── test.py ├── train.py └── trainer ├── PinkModule └── logging.py ├── __init__.py ├── io.py └── trainer.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jongho Choi, Sungho Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch version of DDSP 2 | 3 | # DDSP : Differentiable Digital Signal Processing 4 | 5 | > Original Authors : Jesse Engel, Lamtharn (Hanoi) Hantrakul, Chenjie Gu, Adam Roberts (Google) 6 | 7 | > This Repository is NOT an official implement of authors. 8 | 9 | ## Demo Page ## 10 | 11 | - [Link](https://sweetcocoa.github.io/ddsp-pytorch-samples/) 12 | 13 | ## How to train with your own data 14 | 15 | 1. Clone this repository 16 | 17 | ```bash 18 | git clone https://github.com/sweetcocoa/ddsp-pytorch 19 | ``` 20 | 21 | 2. Prepare your own audio data. (wav, mp3, flac.. ) 22 | 3. Use ffmpeg to convert that audio's sampling rate to 16k 23 | 24 | ```bash 25 | # example 26 | ffmpeg -y -loglevel fatal -i $input_file -ac 1 -ar 16000 $output_file 27 | ``` 28 | 4. Use [CREPE](https://github.com/marl/crepe) to precalculate the fundamental frequency of the audio. 29 | 30 | ```bash 31 | # example 32 | crepe directory-to-audio/ --output directory-to-audio/f0_0.004/ --viterbi --step-size 4 33 | ``` 34 | 35 | 5. MAKE config file. (See configuration *config/violin.yaml* to make appropriate config file.) 
And edit train/train.py 36 | 37 | ```python 38 | config = setup(default_config="../configs/your_config.yaml") 39 | ``` 40 | 6. Run train/train.py 41 | 42 | ```bash 43 | cd train 44 | python train.py 45 | ``` 46 | 47 | ## How to test your own model ## 48 | 49 | ```bash 50 | cd train 51 | python test.py\ 52 | --input input.wav\ 53 | --output output.wav\ 54 | --ckpt trained_weight.pth\ 55 | --config config/your-config.yaml\ 56 | --wave_length 16000 57 | ``` 58 | 59 | ## Download pretrained weight file ### 60 | > [download](https://github.com/sweetcocoa/ddsp-pytorch/raw/models/weight.zip) 61 | 62 | ## Contact ## 63 | 64 | - Jongho Choi (sweetcocoa@snu.ac.kr, BS Student @ Seoul National Univ.) 65 | - Sungho Lee (dlfqhsdugod1106@gmail.com, BS Student @ Postech.) 66 | 67 | > Equally contributed. 68 | -------------------------------------------------------------------------------- /components/filtered_noise.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2020_01_20 - 2020_01_29 3 | Simple trainable filtered noise model for DDSP decoder. 4 | TODO : 5 | code refactoring 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class FilteredNoise(nn.Module): 14 | def __init__(self, frame_length = 64, attenuate_gain = 1e-2, device = 'cuda'): 15 | super(FilteredNoise, self).__init__() 16 | 17 | self.frame_length = frame_length 18 | self.device = device 19 | self.attenuate_gain = attenuate_gain 20 | 21 | def forward(self, z): 22 | """ 23 | Compute linear-phase LTI-FVR (time-varient in terms of frame by frame) filter banks in batch from network output, 24 | and create time-varying filtered noise by overlap-add method. 25 | 26 | Argument: 27 | z['H'] : filter coefficient bank for each batch, which will be used for constructing linear-phase filter. 28 | - dimension : (batch_num, frame_num, filter_coeff_length) 29 | 30 | """ 31 | 32 | batch_num, frame_num, filter_coeff_length = z['H'].shape 33 | self.filter_window = nn.Parameter(torch.hann_window(filter_coeff_length * 2 - 1, dtype = torch.float32), requires_grad = False).to(self.device) 34 | 35 | INPUT_FILTER_COEFFICIENT = z['H'] 36 | 37 | # Desired linear-phase filter can be obtained by time-shifting a zero-phase form (especially to a causal form to be real-time), 38 | # which has zero imaginery part in the frequency response. 39 | # Therefore, first we create a zero-phase filter in frequency domain. 40 | # Then, IDFT & make it causal form. length IDFT-ed signal size can be both even or odd, 41 | # but we choose odd number such that a single sample can represent the center of impulse response. 42 | ZERO_PHASE_FR_BANK = INPUT_FILTER_COEFFICIENT.unsqueeze(-1).expand(batch_num, frame_num, filter_coeff_length, 2).contiguous() 43 | ZERO_PHASE_FR_BANK[..., 1] = 0 44 | ZERO_PHASE_FR_BANK = ZERO_PHASE_FR_BANK.view(-1, filter_coeff_length, 2) 45 | zero_phase_ir_bank = torch.irfft(ZERO_PHASE_FR_BANK, 1, signal_sizes = (filter_coeff_length * 2 - 1,)) 46 | 47 | # Make linear phase causal impulse response & Hann-window it. 48 | # Then zero pad + DFT for linear convolution. 
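        # Rolling by (filter_coeff_length - 1) samples shifts the zero-phase impulse response
        # (peak at n = 0) to the center of its (2 * filter_coeff_length - 1)-sample window,
        # yielding a causal, linear-phase FIR before Hann windowing and zero padding.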
49 | linear_phase_ir_bank = zero_phase_ir_bank.roll(filter_coeff_length - 1, 1) 50 | windowed_linear_phase_ir_bank = linear_phase_ir_bank * self.filter_window.view(1, -1) 51 | zero_paded_windowed_linear_phase_ir_bank = nn.functional.pad(windowed_linear_phase_ir_bank, (0, self.frame_length - 1)) 52 | ZERO_PADED_WINDOWED_LINEAR_PHASE_FR_BANK = torch.rfft(zero_paded_windowed_linear_phase_ir_bank, 1) 53 | 54 | # Generate white noise & zero pad & DFT for linear convolution. 55 | noise = torch.rand(batch_num, frame_num, self.frame_length, dtype = torch.float32).view(-1, self.frame_length).to(self.device) * 2 - 1 56 | zero_paded_noise = nn.functional.pad(noise, (0, filter_coeff_length * 2 - 2)) 57 | ZERO_PADED_NOISE = torch.rfft(zero_paded_noise, 1) 58 | 59 | # Convolve & IDFT to make filtered noise frame, for each frame, noise band, and batch. 60 | FILTERED_NOISE = torch.zeros_like(ZERO_PADED_NOISE).to(self.device) 61 | FILTERED_NOISE[:, :, 0] = ZERO_PADED_NOISE[:, :, 0] * ZERO_PADED_WINDOWED_LINEAR_PHASE_FR_BANK[:, :, 0] \ 62 | - ZERO_PADED_NOISE[:, :, 1] * ZERO_PADED_WINDOWED_LINEAR_PHASE_FR_BANK[:, :, 1] 63 | FILTERED_NOISE[:, :, 1] = ZERO_PADED_NOISE[:, :, 0] * ZERO_PADED_WINDOWED_LINEAR_PHASE_FR_BANK[:, :, 1] \ 64 | + ZERO_PADED_NOISE[:, :, 1] * ZERO_PADED_WINDOWED_LINEAR_PHASE_FR_BANK[:, :, 0] 65 | filtered_noise = torch.irfft(FILTERED_NOISE, 1).view(batch_num, frame_num, -1) * self.attenuate_gain 66 | 67 | # Overlap-add to build time-varying filtered noise. 68 | overlap_add_filter = torch.eye(filtered_noise.shape[-1], requires_grad = False).unsqueeze(1).to(self.device) 69 | output_signal = nn.functional.conv_transpose1d(filtered_noise.transpose(1, 2), 70 | overlap_add_filter, 71 | stride = self.frame_length, 72 | padding = 0).squeeze(1) 73 | 74 | return output_signal 75 | -------------------------------------------------------------------------------- /components/harmonic_oscillator.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2020_01_15 - 2020_01_29 3 | Harmonic Oscillator model for DDSP decoder. 
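Sums sinusoids at integer multiples of f0; per-harmonic amplitudes and loudness are
upsampled from frame rate to audio rate before synthesis.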
4 | TODO : 5 | upsample + interpolation 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class HarmonicOscillator(nn.Module): 14 | def __init__(self, sr=16000, frame_length=64, attenuate_gain=0.02, device="cuda"): 15 | super(HarmonicOscillator, self).__init__() 16 | self.sr = sr 17 | self.frame_length = frame_length 18 | self.attenuate_gain = attenuate_gain 19 | 20 | self.device = device 21 | 22 | self.framerate_to_audiorate = nn.Upsample( 23 | scale_factor=self.frame_length, mode="linear", align_corners=False 24 | ) 25 | 26 | def forward(self, z): 27 | 28 | """ 29 | Compute Addictive Synthesis 30 | Argument: 31 | z['f0'] : fundamental frequency envelope for each sample 32 | - dimension (batch_num, frame_rate_time_samples) 33 | z['c'] : harmonic distribution of partials for each sample 34 | - dimension (batch_num, partial_num, frame_rate_time_samples) 35 | z['a'] : loudness of entire sound for each sample 36 | - dimension (batch_num, frame_rate_time_samples) 37 | Returns: 38 | addictive_output : synthesized sinusoids for each sample 39 | - dimension (batch_num, audio_rate_time_samples) 40 | """ 41 | 42 | fundamentals = z["f0"] 43 | framerate_c_bank = z["c"] 44 | 45 | num_osc = framerate_c_bank.shape[1] 46 | 47 | # Build a frequency envelopes of each partials from z['f0'] data 48 | partial_mult = ( 49 | torch.linspace(1, num_osc, num_osc, dtype=torch.float32).unsqueeze(-1).to(self.device) 50 | ) 51 | framerate_f0_bank = ( 52 | fundamentals.unsqueeze(-1).expand(-1, -1, num_osc).transpose(1, 2) * partial_mult 53 | ) 54 | 55 | # Antialias z['c'] 56 | mask_filter = (framerate_f0_bank < self.sr / 2).float() 57 | antialiased_framerate_c_bank = framerate_c_bank * mask_filter 58 | 59 | # Upsample frequency envelopes and build phase bank 60 | audiorate_f0_bank = self.framerate_to_audiorate(framerate_f0_bank) 61 | audiorate_phase_bank = torch.cumsum(audiorate_f0_bank / self.sr, 2) 62 | 63 | # Upsample amplitude envelopes 64 | audiorate_a_bank = self.framerate_to_audiorate(antialiased_framerate_c_bank) 65 | 66 | # Build harmonic sinusoid bank and sum to build harmonic sound 67 | sinusoid_bank = ( 68 | torch.sin(2 * np.pi * audiorate_phase_bank) * audiorate_a_bank * self.attenuate_gain 69 | ) 70 | 71 | framerate_loudness = z["a"] 72 | audiorate_loudness = self.framerate_to_audiorate(framerate_loudness.unsqueeze(0)).squeeze(0) 73 | 74 | addictive_output = torch.sum(sinusoid_bank, 1) * audiorate_loudness 75 | 76 | return addictive_output 77 | -------------------------------------------------------------------------------- /components/loudness_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2020_01_29 - 2020_02_03 3 | Loudness Extractor / Envelope Follower 4 | TODO : 5 | check appropriate gain structure 6 | GPU test 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class LoudnessExtractor(nn.Module): 15 | def __init__(self, 16 | sr = 16000, 17 | frame_length = 64, 18 | attenuate_gain = 2., 19 | device = 'cuda'): 20 | 21 | super(LoudnessExtractor, self).__init__() 22 | 23 | self.sr = sr 24 | self.frame_length = frame_length 25 | self.n_fft = self.frame_length * 5 26 | 27 | self.device = device 28 | 29 | self.attenuate_gain = attenuate_gain 30 | self.smoothing_window = nn.Parameter(torch.hann_window(self.n_fft, dtype = torch.float32), requires_grad = False).to(self.device) 31 | 32 | 33 | 34 | def torch_A_weighting(self, FREQUENCIES, min_db = -45.0): 35 | """ 36 
| Compute A-weighting weights in Decibel scale (codes from librosa) and 37 | transform into amplitude domain (with DB-SPL equation). 38 | 39 | Argument: 40 | FREQUENCIES : tensor of frequencies to return amplitude weight 41 | min_db : mininum decibel weight. appropriate min_db value is important, as 42 | exp/log calculation might raise numeric error with float32 type. 43 | 44 | Returns: 45 | weights : tensor of amplitude attenuation weights corresponding to the FREQUENCIES tensor. 46 | """ 47 | 48 | # Calculate A-weighting in Decibel scale. 49 | FREQUENCY_SQUARED = FREQUENCIES ** 2 50 | const = torch.tensor([12200, 20.6, 107.7, 737.9]) ** 2.0 51 | WEIGHTS_IN_DB = 2.0 + 20.0 * (torch.log10(const[0]) + 4 * torch.log10(FREQUENCIES) 52 | - torch.log10(FREQUENCY_SQUARED + const[0]) 53 | - torch.log10(FREQUENCY_SQUARED + const[1]) 54 | - 0.5 * torch.log10(FREQUENCY_SQUARED + const[2]) 55 | - 0.5 * torch.log10(FREQUENCY_SQUARED + const[3])) 56 | 57 | # Set minimum Decibel weight. 58 | if min_db is not None: 59 | WEIGHTS_IN_DB = torch.max(WEIGHTS_IN_DB, torch.tensor([min_db], dtype = torch.float32).to(self.device)) 60 | 61 | # Transform Decibel scale weight to amplitude scale weight. 62 | weights = torch.exp(torch.log(torch.tensor([10.], dtype = torch.float32).to(self.device)) * WEIGHTS_IN_DB / 10) 63 | 64 | return weights 65 | 66 | 67 | def forward(self, z): 68 | """ 69 | Compute A-weighted Loudness Extraction 70 | Input: 71 | z['audio'] : batch of time-domain signals 72 | Output: 73 | output_signal : batch of reverberated signals 74 | """ 75 | 76 | input_signal = z['audio'] 77 | paded_input_signal = nn.functional.pad(input_signal, (self.frame_length * 2, self.frame_length * 2)) 78 | sliced_signal = paded_input_signal.unfold(1, self.n_fft, self.frame_length) 79 | sliced_windowed_signal = sliced_signal * self.smoothing_window 80 | 81 | SLICED_SIGNAL = torch.rfft(sliced_windowed_signal, 1, onesided = False) 82 | 83 | SLICED_SIGNAL_LOUDNESS_SPECTRUM = torch.zeros(SLICED_SIGNAL.shape[:-1]) 84 | SLICED_SIGNAL_LOUDNESS_SPECTRUM = SLICED_SIGNAL[:, :, :, 0] ** 2 + SLICED_SIGNAL[:, :, :, 1] ** 2 85 | 86 | freq_bin_size = self.sr / self.n_fft 87 | FREQUENCIES = torch.tensor([(freq_bin_size * i) % (0.5 * self.sr) for i in range(self.n_fft)]).to(self.device) 88 | A_WEIGHTS = self.torch_A_weighting(FREQUENCIES) 89 | 90 | A_WEIGHTED_SLICED_SIGNAL_LOUDNESS_SPECTRUM = SLICED_SIGNAL_LOUDNESS_SPECTRUM * A_WEIGHTS 91 | A_WEIGHTED_SLICED_SIGNAL_LOUDNESS = torch.sqrt(torch.sum(A_WEIGHTED_SLICED_SIGNAL_LOUDNESS_SPECTRUM, 2)) / self.n_fft * self.attenuate_gain 92 | 93 | return A_WEIGHTED_SLICED_SIGNAL_LOUDNESS 94 | -------------------------------------------------------------------------------- /components/ptcrepe/README.md: -------------------------------------------------------------------------------- 1 | # CREPE Pitch Tracker (PyTorch) # 2 | 3 | - Original Tensorflow Implementation : [https://github.com/marl/crepe](https://github.com/marl/crepe) 4 | 5 | --- 6 | CREPE is a monophonic pitch tracker based on a deep convolutional neural network operating directly on the time-domain waveform input. CREPE is originally implemented with tensorflow, which is very inconvenient framework to use. 
7 | 8 | 9 | ## Usage 10 | 11 | ```python 12 | import crepe 13 | import torch 14 | device = torch.device(0) 15 | cr = crepe.CREPE("full").to(device) 16 | cr.predict("path/to/audio.file", "path/to/output/directory/", ) 17 | ``` 18 | 19 | ## WIP 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /components/ptcrepe/ptcrepe/crepe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | import os, sys 5 | from .utils import * 6 | import numpy as np 7 | 8 | 9 | class ConvBlock(nn.Module): 10 | def __init__(self, f, w, s, in_channels): 11 | super().__init__() 12 | p1 = (w - 1) // 2 13 | p2 = (w - 1) - p1 14 | self.pad = nn.ZeroPad2d((0, 0, p1, p2)) 15 | 16 | self.conv2d = nn.Conv2d( 17 | in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=s 18 | ) 19 | self.relu = nn.ReLU() 20 | self.bn = nn.BatchNorm2d(f) 21 | self.pool = nn.MaxPool2d(kernel_size=(2, 1)) 22 | self.dropout = nn.Dropout(0.25) 23 | 24 | def forward(self, x): 25 | x = self.pad(x) 26 | x = self.conv2d(x) 27 | x = self.relu(x) 28 | x = self.bn(x) 29 | x = self.pool(x) 30 | x = self.dropout(x) 31 | return x 32 | 33 | 34 | class CREPE(nn.Module): 35 | def __init__(self, model_capacity="full"): 36 | super().__init__() 37 | 38 | capacity_multiplier = {"tiny": 4, "small": 8, "medium": 16, "large": 24, "full": 32}[ 39 | model_capacity 40 | ] 41 | 42 | self.layers = [1, 2, 3, 4, 5, 6] 43 | filters = [n * capacity_multiplier for n in [32, 4, 4, 4, 8, 16]] 44 | filters = [1] + filters 45 | widths = [512, 64, 64, 64, 64, 64] 46 | strides = [(4, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)] 47 | 48 | for i in range(len(self.layers)): 49 | f, w, s, in_channel = filters[i + 1], widths[i], strides[i], filters[i] 50 | self.add_module("conv%d" % i, ConvBlock(f, w, s, in_channel)) 51 | 52 | self.linear = nn.Linear(64 * capacity_multiplier, 360) 53 | self.load_weight(model_capacity) 54 | self.eval() 55 | 56 | def load_weight(self, model_capacity): 57 | download_weights(model_capacity) 58 | package_dir = os.path.dirname(os.path.realpath(__file__)) 59 | filename = "crepe-{}.pth".format(model_capacity) 60 | self.load_state_dict(torch.load(os.path.join(package_dir, filename))) 61 | 62 | def forward(self, x): 63 | # x : shape (batch, sample) 64 | x = x.view(x.shape[0], 1, -1, 1) 65 | for i in range(len(self.layers)): 66 | x = self.__getattr__("conv%d" % i)(x) 67 | 68 | x = x.permute(0, 3, 2, 1) 69 | x = x.reshape(x.shape[0], -1) 70 | x = self.linear(x) 71 | x = torch.sigmoid(x) 72 | return x 73 | 74 | def get_activation(self, audio, sr, center=True, step_size=10, batch_size=128): 75 | """ 76 | audio : (N,) or (C, N) 77 | """ 78 | 79 | if sr != 16000: 80 | rs = torchaudio.transforms.Resample(sr, 16000) 81 | audio = rs(audio) 82 | 83 | if len(audio.shape) == 2: 84 | if audio.shape[0] == 1: 85 | audio = audio[0] 86 | else: 87 | audio = audio.mean(dim=0) # make mono 88 | 89 | def get_frame(audio, step_size, center): 90 | if center: 91 | audio = nn.functional.pad(audio, pad=(512, 512)) 92 | # make 1024-sample frames of the audio with hop length of 10 milliseconds 93 | hop_length = int(16000 * step_size / 1000) 94 | n_frames = 1 + int((len(audio) - 1024) / hop_length) 95 | assert audio.dtype == torch.float32 96 | itemsize = 1 # float32 byte size 97 | frames = torch.as_strided( 98 | audio, size=(1024, n_frames), stride=(itemsize, hop_length * itemsize) 99 | ) 100 | frames = frames.transpose(0, 1).clone() 
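            # Each row of `frames` is now one 1024-sample frame; the per-frame
            # mean/std normalization below matches the preprocessing of the original CREPE.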
101 | 102 | frames -= torch.mean(frames, axis=1).unsqueeze(-1) 103 | frames /= torch.std(frames, axis=1).unsqueeze(-1) 104 | return frames 105 | 106 | frames = get_frame(audio, step_size, center) 107 | activation_stack = [] 108 | device = self.linear.weight.device 109 | 110 | for i in range(0, len(frames), batch_size): 111 | f = frames[i : min(i + batch_size, len(frames))] 112 | f = f.to(device) 113 | act = self.forward(f) 114 | activation_stack.append(act.cpu()) 115 | activation = torch.cat(activation_stack, dim=0) 116 | return activation 117 | 118 | def predict(self, audio, sr, viterbi=False, center=True, step_size=10, batch_size=128): 119 | activation = self.get_activation(audio, sr, batch_size=batch_size, step_size=step_size) 120 | frequency = to_freq(activation, viterbi=viterbi) 121 | confidence = activation.max(dim=1)[0] 122 | time = torch.arange(confidence.shape[0]) * step_size / 1000.0 123 | return time, frequency, confidence, activation 124 | 125 | def process_file( 126 | self, 127 | file, 128 | output=None, 129 | viterbi=False, 130 | center=True, 131 | step_size=10, 132 | save_plot=False, 133 | batch_size=128, 134 | ): 135 | try: 136 | audio, sr = torchaudio.load(file) 137 | except ValueError: 138 | print("CREPE-pytorch : Could not read", file, file=sys.stderr) 139 | 140 | with torch.no_grad(): 141 | time, frequency, confidence, activation = self.predict( 142 | audio, 143 | sr, 144 | viterbi=viterbi, 145 | center=center, 146 | step_size=step_size, 147 | batch_size=batch_size, 148 | ) 149 | 150 | time, frequency, confidence, activation = ( 151 | time.numpy(), 152 | frequency.numpy(), 153 | confidence.numpy(), 154 | activation.numpy(), 155 | ) 156 | 157 | f0_file = os.path.join(output, os.path.basename(os.path.splitext(file)[0])) + ".f0.csv" 158 | f0_data = np.vstack([time, frequency, confidence]).transpose() 159 | np.savetxt( 160 | f0_file, 161 | f0_data, 162 | fmt=["%.3f", "%.3f", "%.6f"], 163 | delimiter=",", 164 | header="time,frequency,confidence", 165 | comments="", 166 | ) 167 | 168 | # save the salience visualization in a PNG file 169 | if save_plot: 170 | import matplotlib.cm 171 | from imageio import imwrite 172 | 173 | plot_file = ( 174 | os.path.join(output, os.path.basename(os.path.splitext(file)[0])) 175 | + ".activation.png" 176 | ) 177 | # to draw the low pitches in the bottom 178 | salience = np.flip(activation, axis=1) 179 | inferno = matplotlib.cm.get_cmap("inferno") 180 | image = inferno(salience.transpose()) 181 | 182 | imwrite(plot_file, (255 * image).astype(np.uint8)) 183 | 184 | 185 | if __name__ == "__main__": 186 | cr = CREPE().cuda() 187 | import glob 188 | 189 | files = glob.glob("/workspace/data/singing_raw_16k/*.wav") 190 | # files = ["../../ddsp/data/violin/VI.+Double.wav"] 191 | target = "/workspace/data/singing_raw_16k/f0_0.004/" 192 | from tqdm import tqdm 193 | 194 | tq = tqdm(files) 195 | for file in tq: 196 | tq.set_description(file) 197 | cr.process_file(file, target, step_size=4, viterbi=True) 198 | -------------------------------------------------------------------------------- /components/ptcrepe/ptcrepe/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | 6 | def download_weights(model_capacitiy): 7 | try: 8 | from urllib.request import urlretrieve 9 | except ImportError: 10 | from urllib import urlretrieve 11 | 12 | weight_file = "crepe-{}.pth".format(model_capacitiy) 13 | base_url = "https://github.com/sweetcocoa/crepe-pytorch/raw/models/" 
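    # The pretrained .pth weight file is downloaded from this URL on first use and
    # cached next to this module (see the isfile check below).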
14 | 15 | # in all other cases, decompress the weights file if necessary 16 | package_dir = os.path.dirname(os.path.realpath(__file__)) 17 | weight_path = os.path.join(package_dir, weight_file) 18 | if not os.path.isfile(weight_path): 19 | print("Downloading weight file {} from {} ...".format(weight_path, base_url + weight_file)) 20 | urlretrieve(base_url + weight_file, weight_path) 21 | 22 | 23 | def to_local_average_cents(salience, center=None): 24 | """ 25 | find the weighted average cents near the argmax bin 26 | """ 27 | 28 | if not hasattr(to_local_average_cents, "cents_mapping"): 29 | # the bin number-to-cents mapping 30 | to_local_average_cents.mapping = ( 31 | torch.tensor(np.linspace(0, 7180, 360)) + 1997.3794084376191 32 | ) 33 | 34 | if isinstance(salience, np.ndarray): 35 | salience = torch.from_numpy(salience) 36 | 37 | if salience.ndim == 1: 38 | if center is None: 39 | center = int(torch.argmax(salience)) 40 | start = max(0, center - 4) 41 | end = min(len(salience), center + 5) 42 | salience = salience[start:end] 43 | product_sum = torch.sum(salience * to_local_average_cents.mapping[start:end]) 44 | weight_sum = torch.sum(salience) 45 | return product_sum / weight_sum 46 | if salience.ndim == 2: 47 | return torch.tensor( 48 | [to_local_average_cents(salience[i, :]) for i in range(salience.shape[0])] 49 | ) 50 | 51 | raise Exception("label should be either 1d or 2d Tensor") 52 | 53 | 54 | def to_viterbi_cents(salience): 55 | """ 56 | Find the Viterbi path using a transition prior that induces pitch 57 | continuity. 58 | 59 | * Note : This is NOT implemented with pytorch. 60 | """ 61 | from hmmlearn import hmm 62 | 63 | # uniform prior on the starting pitch 64 | starting = np.ones(360) / 360 65 | 66 | # transition probabilities inducing continuous pitch 67 | xx, yy = np.meshgrid(range(360), range(360)) 68 | transition = np.maximum(12 - abs(xx - yy), 0) 69 | transition = transition / np.sum(transition, axis=1)[:, None] 70 | 71 | # emission probability = fixed probability for self, evenly distribute the 72 | # others 73 | self_emission = 0.1 74 | emission = np.eye(360) * self_emission + np.ones(shape=(360, 360)) * ((1 - self_emission) / 360) 75 | 76 | # fix the model parameters because we are not optimizing the model 77 | model = hmm.MultinomialHMM(360, starting, transition) 78 | model.startprob_, model.transmat_, model.emissionprob_ = starting, transition, emission 79 | 80 | # find the Viterbi path 81 | observations = np.argmax(salience, axis=1) 82 | path = model.predict(observations.reshape(-1, 1), [len(observations)]) 83 | 84 | return np.array( 85 | [to_local_average_cents(salience[i, :], path[i]) for i in range(len(observations))] 86 | ) 87 | 88 | 89 | def to_freq(activation, viterbi=False): 90 | if viterbi: 91 | cents = to_viterbi_cents(activation.detach().numpy()) 92 | cents = torch.tensor(cents) 93 | else: 94 | cents = to_local_average_cents(activation) 95 | 96 | frequency = 10 * 2 ** (cents / 1200) 97 | frequency[torch.isnan(frequency)] = 0 98 | return frequency 99 | -------------------------------------------------------------------------------- /components/reverb.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2020_01_17 - 2020_01_29 3 | Simple trainable FIR reverb model for DDSP decoder. 
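Learns a 48000-tap impulse response (3 s at 16 kHz) shaped by a trainable exponential
decay and mixed with an identity impulse through a trainable dry/wet gain.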
4 | TODO : 5 | numerically stable decays 6 | crossfade 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class TrainableFIRReverb(nn.Module): 15 | def __init__(self, reverb_length=48000, device="cuda"): 16 | 17 | super(TrainableFIRReverb, self).__init__() 18 | 19 | # default reverb length is set to 3sec. 20 | # thus this model can max out t60 to 3sec, which corresponds to rich chamber characters. 21 | self.reverb_length = reverb_length 22 | self.device = device 23 | 24 | # impulse response of reverb. 25 | self.fir = nn.Parameter( 26 | torch.rand(1, self.reverb_length, dtype=torch.float32).to(self.device) * 2 - 1, 27 | requires_grad=True, 28 | ) 29 | 30 | # Initialized drywet to around 26%. 31 | # but equal-loudness crossfade between identity impulse and fir reverb impulse is not implemented yet. 32 | self.drywet = nn.Parameter( 33 | torch.tensor([-1.0], dtype=torch.float32).to(self.device), requires_grad=True 34 | ) 35 | 36 | # Initialized decay to 5, to make t60 = 1sec. 37 | self.decay = nn.Parameter( 38 | torch.tensor([3.0], dtype=torch.float32).to(self.device), requires_grad=True 39 | ) 40 | 41 | def forward(self, z): 42 | """ 43 | Compute FIR Reverb 44 | Input: 45 | z['audio_synth'] : batch of time-domain signals 46 | Output: 47 | output_signal : batch of reverberated signals 48 | """ 49 | 50 | # Send batch of input signals in time domain to frequency domain. 51 | # Appropriate zero padding is required for linear convolution. 52 | input_signal = z["audio_synth"] 53 | zero_pad_input_signal = nn.functional.pad(input_signal, (0, self.fir.shape[-1] - 1)) 54 | INPUT_SIGNAL = torch.rfft(zero_pad_input_signal, 1) 55 | 56 | # Build decaying impulse response and send it to frequency domain. 57 | # Appropriate zero padding is required for linear convolution. 58 | # Dry-wet mixing is done by mixing impulse response, rather than mixing at the final stage. 59 | 60 | """ TODO 61 | Not numerically stable decay method? 62 | """ 63 | decay_envelope = torch.exp( 64 | -(torch.exp(self.decay) + 2) 65 | * torch.linspace(0, 1, self.reverb_length, dtype=torch.float32).to(self.device) 66 | ) 67 | decay_fir = self.fir * decay_envelope 68 | 69 | ir_identity = torch.zeros(1, decay_fir.shape[-1]).to(self.device) 70 | ir_identity[:, 0] = 1 71 | 72 | """ TODO 73 | Equal-loudness(intensity) crossfade between to ir. 74 | """ 75 | final_fir = ( 76 | torch.sigmoid(self.drywet) * decay_fir + (1 - torch.sigmoid(self.drywet)) * ir_identity 77 | ) 78 | zero_pad_final_fir = nn.functional.pad(final_fir, (0, input_signal.shape[-1] - 1)) 79 | 80 | FIR = torch.rfft(zero_pad_final_fir, 1) 81 | 82 | # Convolve and inverse FFT to get original signal. 
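        # torch.rfft keeps real/imaginary parts in the last dimension, so the complex
        # product (a + bi)(c + di) = (ac - bd) + (ad + bc)i is expanded element-wise below.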
83 | OUTPUT_SIGNAL = torch.zeros_like(INPUT_SIGNAL).to(self.device) 84 | OUTPUT_SIGNAL[:, :, 0] = ( 85 | INPUT_SIGNAL[:, :, 0] * FIR[:, :, 0] - INPUT_SIGNAL[:, :, 1] * FIR[:, :, 1] 86 | ) 87 | OUTPUT_SIGNAL[:, :, 1] = ( 88 | INPUT_SIGNAL[:, :, 0] * FIR[:, :, 1] + INPUT_SIGNAL[:, :, 1] * FIR[:, :, 0] 89 | ) 90 | 91 | output_signal = torch.irfft(OUTPUT_SIGNAL, 1) 92 | 93 | return output_signal 94 | -------------------------------------------------------------------------------- /configs/violin.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 64 2 | bidirectional: false 3 | ckpt: ../../ckpt/violin/200131.pth 4 | crepe: full 5 | experiment_name: DDSP_violin 6 | f0_threshold: 0.5 7 | frame_resolution: 0.004 8 | gpu: 0 9 | gru_units: 512 10 | loss: mss 11 | lr: 0.001 12 | lr_decay: 0.98 13 | lr_min: 1.0e-07 14 | lr_scheduler: multi 15 | metric: mss 16 | mlp_layers: 3 17 | mlp_units: 512 18 | n_fft: 2048 19 | n_freq: 65 20 | n_harmonics: 101 21 | n_mels: 128 22 | n_mfcc: 30 23 | num_step: 100000 24 | num_workers: 4 25 | optimizer: radam 26 | resume: false 27 | sample_rate: 16000 28 | seed: 940513 29 | tensorboard_dir: ../tensorboard_log/ 30 | test: ../data/violin/test/ 31 | train: ../data/violin/train/ 32 | use_reverb: true 33 | use_z: false 34 | valid_waveform_sec: 12 35 | validation_interval: 1000 36 | waveform_sec: 1 37 | z_units: 16 38 | -------------------------------------------------------------------------------- /data/mp3_to_wav.sh: -------------------------------------------------------------------------------- 1 | # Convert .mp3 to wav (16000hz mono) 2 | echo Converting the mp3 files to wav ... 3 | 4 | FOLDER=$PWD 5 | COUNTER=$(find -name *.mp3|wc -l) 6 | 7 | for f in $PWD/**/*.mp3; do 8 | COUNTER=$((COUNTER - 1)) 9 | echo -ne "\rConverting ($COUNTER) : $f..." 
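  # -ac 1 -ar 16000 downmixes to mono at 16 kHz; ${f/\.mp3/.wav} writes the result next to the source file.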
10 | ffmpeg -y -loglevel fatal -i $f -ac 1 -ar 16000 ${f/\.mp3/.wav} 11 | done 12 | -------------------------------------------------------------------------------- /data/violin/test/VIII.+Double.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/ddsp-pytorch/ea5f25318dd4cd22c601dd405ebc2bac8e3f4cb6/data/violin/test/VIII.+Double.wav -------------------------------------------------------------------------------- /data/violin/train/II.+Double.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/ddsp-pytorch/ea5f25318dd4cd22c601dd405ebc2bac8e3f4cb6/data/violin/train/II.+Double.wav -------------------------------------------------------------------------------- /data/violin/train/III.+Corrente.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/ddsp-pytorch/ea5f25318dd4cd22c601dd405ebc2bac8e3f4cb6/data/violin/train/III.+Corrente.wav -------------------------------------------------------------------------------- /data/violin/train/IV.+Double+Presto.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/ddsp-pytorch/ea5f25318dd4cd22c601dd405ebc2bac8e3f4cb6/data/violin/train/IV.+Double+Presto.wav -------------------------------------------------------------------------------- /data/violin/train/VI.+Double.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/ddsp-pytorch/ea5f25318dd4cd22c601dd405ebc2bac8e3f4cb6/data/violin/train/VI.+Double.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tqdm 4 | tensorboardX 5 | easydict 6 | omegaconf 7 | #torch 8 | #torchaudio 9 | #torchvision -------------------------------------------------------------------------------- /train/dataset/audiodata.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import pandas as pd 4 | from torch.utils.data import Dataset 5 | import numpy as np 6 | 7 | """ 8 | Output : Randomly cropped wave with specific length & corresponding f0 (if necessary). 
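AudioData picks a random crop of `waveform_sec` seconds; SupervisedAudioData additionally
loads the matching CREPE f0 CSV, zeroes frequencies whose confidence is below `f0_threshold`,
and returns audio/f0 pairs aligned frame-to-sample.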
9 | """ 10 | 11 | 12 | class AudioData(Dataset): 13 | def __init__( 14 | self, 15 | paths, 16 | seed=940513, 17 | waveform_sec=4.0, 18 | sample_rate=16000, 19 | waveform_transform=None, 20 | label_transform=None, 21 | ): 22 | super().__init__() 23 | self.paths = paths 24 | self.random = np.random.RandomState(seed) 25 | self.waveform_sec = waveform_sec 26 | self.waveform_transform = waveform_transform 27 | self.label_transform = label_transform 28 | self.sample_rate = sample_rate 29 | 30 | def __getitem__(self, idx): 31 | raise NotImplementedError 32 | 33 | def __len__(self): 34 | return len(self.paths) 35 | 36 | 37 | class SupervisedAudioData(AudioData): 38 | def __init__( 39 | self, 40 | paths, 41 | csv_paths, 42 | seed=940513, 43 | waveform_sec=1.0, 44 | sample_rate=16000, 45 | frame_resolution=0.004, 46 | f0_threshold=0.5, 47 | waveform_transform=None, 48 | label_transform=None, 49 | random_sample=True, 50 | ): 51 | super().__init__( 52 | paths=paths, 53 | seed=seed, 54 | waveform_sec=waveform_sec, 55 | sample_rate=sample_rate, 56 | waveform_transform=waveform_transform, 57 | label_transform=label_transform, 58 | ) 59 | self.csv_paths = csv_paths 60 | self.frame_resolution = frame_resolution 61 | self.f0_threshold = f0_threshold 62 | self.num_frame = int(self.waveform_sec / self.frame_resolution) # number of csv's row 63 | self.hop_length = int(self.sample_rate * frame_resolution) 64 | self.num_wave = int(self.sample_rate * self.waveform_sec) 65 | self.random_sample = random_sample 66 | 67 | def __getitem__(self, file_idx): 68 | target_f0 = pd.read_csv(self.csv_paths[file_idx]) 69 | 70 | # sample interval 71 | if self.random_sample: 72 | idx_from = self.random.randint( 73 | 1, len(target_f0) - self.num_frame 74 | ) # No samples from first frame - annoying to implement b.c it has to be padding at the first frame. 75 | else: 76 | idx_from = 1 77 | idx_to = idx_from + self.num_frame 78 | frame_from = target_f0["time"][idx_from] 79 | # frame_to = target_f0['time'][idx_to] 80 | confidence = target_f0["confidence"][idx_from:idx_to] 81 | 82 | f0 = target_f0["frequency"][idx_from:idx_to].values.astype(np.float32) 83 | f0[confidence < self.f0_threshold] = 0.0 84 | f0 = torch.from_numpy(f0) 85 | 86 | waveform_from = int(frame_from * self.sample_rate) 87 | # waveform_to = waveform_from + self.num_wave 88 | 89 | audio, sr = torchaudio.load( 90 | self.paths[file_idx], offset=waveform_from, num_frames=self.num_wave 91 | ) 92 | audio = audio[0] 93 | assert sr == self.sample_rate 94 | 95 | return dict(audio=audio, f0=f0,) 96 | -------------------------------------------------------------------------------- /train/loss/mss_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Multi-Scale Spectral Loss as described in DDSP, 3 | which is originally suggested in NSF (Wang et al., 2019) 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchaudio 9 | import torch.nn.functional as F 10 | 11 | 12 | class SSSLoss(nn.Module): 13 | """ 14 | Single-scale Spectral Loss. 
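    L1 distance between linear-magnitude spectrograms plus `alpha` times the L1 distance
    between their log-magnitudes, computed at a single FFT size (hop length = n_fft * (1 - overlap)).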
15 | """ 16 | 17 | def __init__(self, n_fft, alpha=1.0, overlap=0.75, eps=1e-7): 18 | super().__init__() 19 | self.n_fft = n_fft 20 | self.alpha = alpha 21 | self.eps = eps 22 | self.hop_length = int(n_fft * (1 - overlap)) # 25% of the length 23 | self.spec = torchaudio.transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length) 24 | 25 | def forward(self, x_pred, x_true): 26 | S_true = self.spec(x_true) 27 | S_pred = self.spec(x_pred) 28 | 29 | linear_term = F.l1_loss(S_pred, S_true) 30 | log_term = F.l1_loss((S_true + self.eps).log2(), (S_pred + self.eps).log2()) 31 | 32 | loss = linear_term + self.alpha * log_term 33 | return loss 34 | 35 | 36 | class MSSLoss(nn.Module): 37 | """ 38 | Multi-scale Spectral Loss. 39 | 40 | Usage :: 41 | 42 | mssloss = MSSLoss([2048, 1024, 512, 256], alpha=1.0, overlap=0.75) 43 | mssloss(y_pred, y_gt) 44 | 45 | input(y_pred, y_gt) : two of torch.tensor w/ shape(batch, 1d-wave) 46 | output(loss) : torch.tensor(scalar) 47 | """ 48 | 49 | def __init__(self, n_ffts: list, alpha=1.0, overlap=0.75, eps=1e-7, use_reverb=True): 50 | super().__init__() 51 | self.losses = nn.ModuleList([SSSLoss(n_fft, alpha, overlap, eps) for n_fft in n_ffts]) 52 | if use_reverb: 53 | self.signal_key = "audio_reverb" 54 | else: 55 | self.signal_key = "audio_synth" 56 | 57 | def forward(self, x_pred, x_true): 58 | if isinstance(x_pred, dict): 59 | x_pred = x_pred[self.signal_key] 60 | 61 | if isinstance(x_true, dict): 62 | x_true = x_true["audio"] 63 | 64 | # cut reverbation off 65 | x_pred = x_pred[..., : x_true.shape[-1]] 66 | 67 | losses = [loss(x_pred, x_true) for loss in self.losses] 68 | return sum(losses).sum() 69 | 70 | -------------------------------------------------------------------------------- /train/network/autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from components.harmonic_oscillator import HarmonicOscillator 5 | from components.reverb import TrainableFIRReverb 6 | from components.filtered_noise import FilteredNoise 7 | from network.autoencoder.decoder import Decoder 8 | from network.autoencoder.encoder import Encoder 9 | 10 | 11 | class AutoEncoder(nn.Module): 12 | def __init__(self, config): 13 | """ 14 | encoder_config 15 | use_z=False, 16 | sample_rate=16000, 17 | z_units=16, 18 | n_fft=2048, 19 | hop_length=64, 20 | n_mels=128, 21 | n_mfcc=30, 22 | gru_units=512 23 | 24 | decoder_config 25 | mlp_units=512, 26 | mlp_layers=3, 27 | use_z=False, 28 | z_units=16, 29 | n_harmonics=101, 30 | n_freq=65, 31 | gru_units=512, 32 | 33 | components_config 34 | sample_rate 35 | hop_length 36 | """ 37 | super().__init__() 38 | 39 | self.decoder = Decoder(config) 40 | self.encoder = Encoder(config) 41 | 42 | hop_length = frame_length = int(config.sample_rate * config.frame_resolution) 43 | 44 | self.harmonic_oscillator = HarmonicOscillator( 45 | sr=config.sample_rate, frame_length=hop_length 46 | ) 47 | 48 | self.filtered_noise = FilteredNoise(frame_length=hop_length) 49 | 50 | self.reverb = TrainableFIRReverb(reverb_length=config.sample_rate * 3) 51 | 52 | self.crepe = None 53 | self.config = config 54 | 55 | def forward(self, batch, add_reverb=True): 56 | """ 57 | z 58 | 59 | input(dict(f0, z(optional), l)) : a dict object which contains key-values below 60 | f0 : fundamental frequency for each frame. torch.tensor w/ shape(B, time) 61 | z : (optional) residual information. 
torch.tensor w/ shape(B, time, z_units) 62 | loudness : torch.tensor w/ shape(B, time) 63 | """ 64 | batch = self.encoder(batch) 65 | latent = self.decoder(batch) 66 | 67 | harmonic = self.harmonic_oscillator(latent) 68 | noise = self.filtered_noise(latent) 69 | 70 | audio = dict( 71 | harmonic=harmonic, noise=noise, audio_synth=harmonic + noise[:, : harmonic.shape[-1]] 72 | ) 73 | 74 | if self.config.use_reverb and add_reverb: 75 | audio["audio_reverb"] = self.reverb(audio) 76 | 77 | audio["a"] = latent["a"] 78 | audio["c"] = latent["c"] 79 | 80 | return audio 81 | 82 | def get_f0(self, x, sample_rate=16000, f0_threshold=0.5): 83 | """ 84 | input: 85 | x = torch.tensor((1), wave sample) 86 | 87 | output: 88 | f0 : (n_frames, ). fundamental frequencies 89 | """ 90 | if self.crepe is None: 91 | from components.ptcrepe.ptcrepe.crepe import CREPE 92 | 93 | self.crepe = CREPE(self.config.crepe) 94 | for param in self.parameters(): 95 | self.device = param.device 96 | break 97 | self.crepe = self.crepe.to(self.device) 98 | self.eval() 99 | 100 | with torch.no_grad(): 101 | time, f0, confidence, activation = self.crepe.predict( 102 | x, 103 | sr=sample_rate, 104 | viterbi=True, 105 | step_size=int(self.config.frame_resolution * 1000), 106 | batch_size=32, 107 | ) 108 | 109 | f0 = f0.float().to(self.device) 110 | f0[confidence < f0_threshold] = 0.0 111 | f0 = f0[:-1] 112 | 113 | return f0 114 | 115 | def reconstruction(self, x, sample_rate=16000, add_reverb=True, f0_threshold=0.5, f0=None): 116 | """ 117 | input: 118 | x = torch.tensor((1), wave sample) 119 | f0 (if exists) = (num_frames, ) 120 | 121 | output(dict): 122 | f0 : (n_frames, ). fundamental frequencies 123 | a : (n_frames, ). amplitudes 124 | c : (n_harmonics, n_frames). harmonic constants 125 | sig : (n_samples) 126 | audio_reverb : (n_samples + reverb, ). reconstructed signal 127 | """ 128 | self.eval() 129 | 130 | with torch.no_grad(): 131 | if f0 is None: 132 | f0 = self.get_f0(x, sample_rate=sample_rate, f0_threshold=f0_threshold) 133 | 134 | batch = dict(f0=f0.unsqueeze(0), audio=x.to(self.device),) 135 | 136 | recon = self.forward(batch, add_reverb=add_reverb) 137 | 138 | # make shape consistent(removing batch dim) 139 | for k, v in recon.items(): 140 | recon[k] = v[0] 141 | 142 | recon["f0"] = f0 143 | 144 | return recon 145 | -------------------------------------------------------------------------------- /train/network/autoencoder/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of decoder network architecture of DDSP. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class MLP(nn.Module): 11 | """ 12 | MLP (Multi-layer Perception). 13 | 14 | One layer consists of what as below: 15 | - 1 Dense Layer 16 | - 1 Layer Norm 17 | - 1 ReLU 18 | 19 | constructor arguments : 20 | n_input : dimension of input 21 | n_units : dimension of hidden unit 22 | n_layer : depth of MLP (the number of layers) 23 | relu : relu (default : nn.ReLU, can be changed to nn.LeakyReLU, nn.PReLU for example.) 24 | 25 | input(x): torch.tensor w/ shape(B, ... 
, n_input) 26 | output(x): torch.tensor w/ (B, ..., n_units) 27 | """ 28 | 29 | def __init__(self, n_input, n_units, n_layer, relu=nn.ReLU, inplace=True): 30 | super().__init__() 31 | self.n_layer = n_layer 32 | self.n_input = n_input 33 | self.n_units = n_units 34 | self.inplace = inplace 35 | 36 | self.add_module( 37 | f"mlp_layer1", 38 | nn.Sequential( 39 | nn.Linear(n_input, n_units), 40 | nn.LayerNorm(normalized_shape=n_units), 41 | relu(inplace=self.inplace), 42 | ), 43 | ) 44 | 45 | for i in range(2, n_layer + 1): 46 | self.add_module( 47 | f"mlp_layer{i}", 48 | nn.Sequential( 49 | nn.Linear(n_units, n_units), 50 | nn.LayerNorm(normalized_shape=n_units), 51 | relu(inplace=self.inplace), 52 | ), 53 | ) 54 | 55 | def forward(self, x): 56 | for i in range(1, self.n_layer + 1): 57 | x = self.__getattr__(f"mlp_layer{i}")(x) 58 | return x 59 | 60 | 61 | class Decoder(nn.Module): 62 | """ 63 | Decoder. 64 | 65 | Constructor arguments: 66 | use_z : (Bool), if True, Decoder will use z as input. 67 | mlp_units: 512 68 | mlp_layers: 3 69 | z_units: 16 70 | n_harmonics: 101 71 | n_freq: 65 72 | gru_units: 512 73 | bidirectional: False 74 | 75 | input(dict(f0, z(optional), l)) : a dict object which contains key-values below 76 | f0 : fundamental frequency for each frame. torch.tensor w/ shape(B, time) 77 | z : (optional) residual information. torch.tensor w/ shape(B, time, z_units) 78 | loudness : torch.tensor w/ shape(B, time) 79 | 80 | *note dimension of z is not specified in the paper. 81 | 82 | output : a dict object which contains key-values below 83 | f0 : same as input 84 | c : torch.tensor w/ shape(B, time, n_harmonics) which satisfies sum(c) == 1 85 | a : torch.tensor w/ shape(B, time) which satisfies a > 0 86 | H : noise filter in frequency domain. 
torch.tensor w/ shape(B, frame_num, filter_coeff_length) 87 | """ 88 | 89 | def __init__(self, config): 90 | super().__init__() 91 | 92 | self.config = config 93 | 94 | self.mlp_f0 = MLP(n_input=1, n_units=config.mlp_units, n_layer=config.mlp_layers) 95 | self.mlp_loudness = MLP(n_input=1, n_units=config.mlp_units, n_layer=config.mlp_layers) 96 | if config.use_z: 97 | self.mlp_z = MLP( 98 | n_input=config.z_units, n_units=config.mlp_units, n_layer=config.mlp_layers 99 | ) 100 | self.num_mlp = 3 101 | else: 102 | self.num_mlp = 2 103 | 104 | self.gru = nn.GRU( 105 | input_size=self.num_mlp * config.mlp_units, 106 | hidden_size=config.gru_units, 107 | num_layers=1, 108 | batch_first=True, 109 | bidirectional=config.bidirectional, 110 | ) 111 | 112 | self.mlp_gru = MLP( 113 | n_input=config.gru_units * 2 if config.bidirectional else config.gru_units, 114 | n_units=config.mlp_units, 115 | n_layer=config.mlp_layers, 116 | inplace=True, 117 | ) 118 | 119 | # one element for overall loudness 120 | self.dense_harmonic = nn.Linear(config.mlp_units, config.n_harmonics + 1) 121 | self.dense_filter = nn.Linear(config.mlp_units, config.n_freq) 122 | 123 | def forward(self, batch): 124 | f0 = batch["f0"].unsqueeze(-1) 125 | loudness = batch["loudness"].unsqueeze(-1) 126 | 127 | if self.config.use_z: 128 | z = batch["z"] 129 | latent_z = self.mlp_z(z) 130 | 131 | latent_f0 = self.mlp_f0(f0) 132 | latent_loudness = self.mlp_loudness(loudness) 133 | 134 | if self.config.use_z: 135 | latent = torch.cat((latent_f0, latent_z, latent_loudness), dim=-1) 136 | else: 137 | latent = torch.cat((latent_f0, latent_loudness), dim=-1) 138 | 139 | latent, (h) = self.gru(latent) 140 | latent = self.mlp_gru(latent) 141 | 142 | amplitude = self.dense_harmonic(latent) 143 | 144 | a = amplitude[..., 0] 145 | a = Decoder.modified_sigmoid(a) 146 | 147 | # a = torch.sigmoid(amplitude[..., 0]) 148 | c = F.softmax(amplitude[..., 1:], dim=-1) 149 | 150 | H = self.dense_filter(latent) 151 | H = Decoder.modified_sigmoid(H) 152 | 153 | c = c.permute(0, 2, 1) # to match the shape of harmonic oscillator's input. 
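        # Here a is (B, time), c is (B, n_harmonics, time) with each frame's harmonic
        # distribution summing to 1, and H is (B, time, n_freq) noise-filter magnitudes.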
154 | 155 | return dict(f0=batch["f0"], a=a, c=c, H=H) 156 | 157 | @staticmethod 158 | def modified_sigmoid(a): 159 | a = a.sigmoid() 160 | a = a.pow(2.3026) # log10 161 | a = a.mul(2.0) 162 | a.add_(1e-7) 163 | return a 164 | 165 | -------------------------------------------------------------------------------- /train/network/autoencoder/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import torch.nn as nn 4 | from components.loudness_extractor import LoudnessExtractor 5 | 6 | 7 | class Z_Encoder(nn.Module): 8 | def __init__( 9 | self, 10 | n_fft, 11 | hop_length, 12 | sample_rate=16000, 13 | n_mels=128, 14 | n_mfcc=30, 15 | gru_units=512, 16 | z_units=16, 17 | bidirectional=False, 18 | ): 19 | super().__init__() 20 | self.mfcc = torchaudio.transforms.MFCC( 21 | sample_rate=sample_rate, 22 | n_mfcc=n_mfcc, 23 | log_mels=True, 24 | melkwargs=dict( 25 | n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, f_min=20.0, f_max=8000.0, 26 | ), 27 | ) 28 | 29 | self.norm = nn.InstanceNorm1d(n_mfcc, affine=True) 30 | self.permute = lambda x: x.permute(0, 2, 1) 31 | self.gru = nn.GRU( 32 | input_size=n_mfcc, 33 | hidden_size=gru_units, 34 | num_layers=1, 35 | batch_first=True, 36 | bidirectional=bidirectional, 37 | ) 38 | self.dense = nn.Linear(gru_units * 2 if bidirectional else gru_units, z_units) 39 | 40 | def forward(self, batch): 41 | x = batch["audio"] 42 | x = self.mfcc(x) 43 | x = x[:, :, :-1] 44 | x = self.norm(x) 45 | x = self.permute(x) 46 | x, _ = self.gru(x) 47 | x = self.dense(x) 48 | return x 49 | 50 | 51 | class Encoder(nn.Module): 52 | """ 53 | Encoder. 54 | 55 | contains: Z_encoder, loudness extractor 56 | 57 | Constructor arguments: 58 | use_z : Bool, if True, Encoder will produce z as output. 59 | sample_rate=16000, 60 | z_units=16, 61 | n_fft=2048, 62 | n_mels=128, 63 | n_mfcc=30, 64 | gru_units=512, 65 | bidirectional=False 66 | 67 | input(dict(audio, f0)) : a dict object which contains key-values below 68 | f0 : fundamental frequency for each frame. torch.tensor w/ shape(B, frame) 69 | audio : raw audio w/ shape(B, time) 70 | 71 | output : a dict object which contains key-values below 72 | 73 | loudness : torch.tensor w/ shape(B, frame) 74 | f0 : same as input 75 | z : (optional) residual information. 
torch.tensor w/ shape(B, frame, z_units) 76 | """ 77 | 78 | def __init__(self, config): 79 | super().__init__() 80 | 81 | self.config = config 82 | self.hop_length = int(config.sample_rate * config.frame_resolution) 83 | 84 | self.loudness_extractor = LoudnessExtractor( 85 | sr=config.sample_rate, frame_length=self.hop_length, 86 | ) 87 | 88 | if config.use_z: 89 | self.z_encoder = Z_Encoder( 90 | sample_rate=config.sample_rate, 91 | n_fft=config.n_fft, 92 | hop_length=self.hop_length, 93 | n_mels=config.n_mels, 94 | n_mfcc=config.n_mfcc, 95 | gru_units=config.gru_units, 96 | z_units=config.z_units, 97 | bidirectional=config.bidirectional, 98 | ) 99 | 100 | def forward(self, batch): 101 | batch["loudness"] = self.loudness_extractor(batch) 102 | if self.config.use_z: 103 | batch["z"] = self.z_encoder(batch) 104 | 105 | if self.config.sample_rate % self.hop_length != 0: 106 | # if sample rate is not divided by hop_length 107 | # In short, this is not needed if sr == 16000 108 | batch["loudness"] = batch["loudness"][:, : batch["f0"].shape[-1]] 109 | batch["z"] = batch["z"][:, : batch["f0"].shape[-1]] 110 | 111 | return batch 112 | 113 | -------------------------------------------------------------------------------- /train/optimizer/radam.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original Source : https://github.com/LiyuanLucasLiu/RAdam 3 | """ 4 | 5 | import math 6 | import torch 7 | from torch.optim.optimizer import Optimizer, required 8 | 9 | class RAdam(Optimizer): 10 | 11 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 12 | if not 0.0 <= lr: 13 | raise ValueError("Invalid learning rate: {}".format(lr)) 14 | if not 0.0 <= eps: 15 | raise ValueError("Invalid epsilon value: {}".format(eps)) 16 | if not 0.0 <= betas[0] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 18 | if not 0.0 <= betas[1] < 1.0: 19 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 20 | 21 | self.degenerated_to_sgd = degenerated_to_sgd 22 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 23 | for param in params: 24 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 25 | param['buffer'] = [[None, None, None] for _ in range(10)] 26 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 27 | super(RAdam, self).__init__(params, defaults) 28 | 29 | def __setstate__(self, state): 30 | super(RAdam, self).__setstate__(state) 31 | 32 | def step(self, closure=None): 33 | 34 | loss = None 35 | if closure is not None: 36 | loss = closure() 37 | 38 | for group in self.param_groups: 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad.data.float() 44 | if grad.is_sparse: 45 | raise RuntimeError('RAdam does not support sparse gradients') 46 | 47 | p_data_fp32 = p.data.float() 48 | 49 | state = self.state[p] 50 | 51 | if len(state) == 0: 52 | state['step'] = 0 53 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 54 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 55 | else: 56 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 57 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 58 | 59 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 60 | beta1, beta2 = group['betas'] 61 | 62 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, 
grad) 63 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 64 | 65 | state['step'] += 1 66 | buffered = group['buffer'][int(state['step'] % 10)] 67 | if state['step'] == buffered[0]: 68 | N_sma, step_size = buffered[1], buffered[2] 69 | else: 70 | buffered[0] = state['step'] 71 | beta2_t = beta2 ** state['step'] 72 | N_sma_max = 2 / (1 - beta2) - 1 73 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 74 | buffered[1] = N_sma 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 79 | elif self.degenerated_to_sgd: 80 | step_size = 1.0 / (1 - beta1 ** state['step']) 81 | else: 82 | step_size = -1 83 | buffered[2] = step_size 84 | 85 | # more conservative since it's an approximated value 86 | if N_sma >= 5: 87 | if group['weight_decay'] != 0: 88 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 89 | denom = exp_avg_sq.sqrt().add_(group['eps']) 90 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 91 | p.data.copy_(p_data_fp32) 92 | elif step_size > 0: 93 | if group['weight_decay'] != 0: 94 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 95 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 96 | p.data.copy_(p_data_fp32) 97 | 98 | return loss 99 | 100 | class PlainRAdam(Optimizer): 101 | 102 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 103 | if not 0.0 <= lr: 104 | raise ValueError("Invalid learning rate: {}".format(lr)) 105 | if not 0.0 <= eps: 106 | raise ValueError("Invalid epsilon value: {}".format(eps)) 107 | if not 0.0 <= betas[0] < 1.0: 108 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 109 | if not 0.0 <= betas[1] < 1.0: 110 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 111 | 112 | self.degenerated_to_sgd = degenerated_to_sgd 113 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 114 | 115 | super(PlainRAdam, self).__init__(params, defaults) 116 | 117 | def __setstate__(self, state): 118 | super(PlainRAdam, self).__setstate__(state) 119 | 120 | def step(self, closure=None): 121 | 122 | loss = None 123 | if closure is not None: 124 | loss = closure() 125 | 126 | for group in self.param_groups: 127 | 128 | for p in group['params']: 129 | if p.grad is None: 130 | continue 131 | grad = p.grad.data.float() 132 | if grad.is_sparse: 133 | raise RuntimeError('RAdam does not support sparse gradients') 134 | 135 | p_data_fp32 = p.data.float() 136 | 137 | state = self.state[p] 138 | 139 | if len(state) == 0: 140 | state['step'] = 0 141 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 142 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 143 | else: 144 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 145 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 146 | 147 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 148 | beta1, beta2 = group['betas'] 149 | 150 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 151 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 152 | 153 | state['step'] += 1 154 | beta2_t = beta2 ** state['step'] 155 | N_sma_max = 2 / (1 - beta2) - 1 156 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 157 | 158 | 159 | # more conservative since it's an approximated value 160 | if N_sma >= 5: 161 | if group['weight_decay'] != 
0: 162 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 163 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 164 | denom = exp_avg_sq.sqrt().add_(group['eps']) 165 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 166 | p.data.copy_(p_data_fp32) 167 | elif self.degenerated_to_sgd: 168 | if group['weight_decay'] != 0: 169 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 170 | step_size = group['lr'] / (1 - beta1 ** state['step']) 171 | p_data_fp32.add_(-step_size, exp_avg) 172 | p.data.copy_(p_data_fp32) 173 | 174 | return loss 175 | 176 | 177 | class AdamW(Optimizer): 178 | 179 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 180 | if not 0.0 <= lr: 181 | raise ValueError("Invalid learning rate: {}".format(lr)) 182 | if not 0.0 <= eps: 183 | raise ValueError("Invalid epsilon value: {}".format(eps)) 184 | if not 0.0 <= betas[0] < 1.0: 185 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 186 | if not 0.0 <= betas[1] < 1.0: 187 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 188 | 189 | defaults = dict(lr=lr, betas=betas, eps=eps, 190 | weight_decay=weight_decay, warmup = warmup) 191 | super(AdamW, self).__init__(params, defaults) 192 | 193 | def __setstate__(self, state): 194 | super(AdamW, self).__setstate__(state) 195 | 196 | def step(self, closure=None): 197 | loss = None 198 | if closure is not None: 199 | loss = closure() 200 | 201 | for group in self.param_groups: 202 | 203 | for p in group['params']: 204 | if p.grad is None: 205 | continue 206 | grad = p.grad.data.float() 207 | if grad.is_sparse: 208 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 209 | 210 | p_data_fp32 = p.data.float() 211 | 212 | state = self.state[p] 213 | 214 | if len(state) == 0: 215 | state['step'] = 0 216 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 217 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 218 | else: 219 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 220 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 221 | 222 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 223 | beta1, beta2 = group['betas'] 224 | 225 | state['step'] += 1 226 | 227 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 228 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 229 | 230 | denom = exp_avg_sq.sqrt().add_(group['eps']) 231 | bias_correction1 = 1 - beta1 ** state['step'] 232 | bias_correction2 = 1 - beta2 ** state['step'] 233 | 234 | if group['warmup'] > state['step']: 235 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 236 | else: 237 | scheduled_lr = group['lr'] 238 | 239 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 240 | 241 | if group['weight_decay'] != 0: 242 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 243 | 244 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 245 | 246 | p.data.copy_(p_data_fp32) 247 | 248 | return loss -------------------------------------------------------------------------------- /train/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | args : 3 | --input : input wav 4 | --output : output wav path 5 | --ckpt : pretrained weight file 6 | --config : network-corresponding yaml config file 7 | --wave_length 
: wave length in format 8 | (default : 0, which means all) 9 | WARNING : gpu memory might be not enough. 10 | """ 11 | 12 | import torch 13 | import torchaudio 14 | import os, sys 15 | 16 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../") 17 | from network.autoencoder.autoencoder import AutoEncoder 18 | from omegaconf import OmegaConf 19 | 20 | import argparse 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--input", default=".wav") 24 | parser.add_argument("--output", default="output.wav") 25 | parser.add_argument("--ckpt", default=".pth") 26 | parser.add_argument("--config", default=".yaml") 27 | parser.add_argument("--wave_length", default=16000) 28 | args = parser.parse_args() 29 | 30 | y, sr = torchaudio.load(args.input, num_frames=None if args.wave_length == 0 else args.wave_length) 31 | 32 | config = OmegaConf.load(args.config) 33 | if sr != config.sample_rate: 34 | # Resample if sampling rate is not equal to model's 35 | resampler = torchaudio.transforms.Resample(sr, config.sample_rate) 36 | y = resampler(y) 37 | 38 | print("File :", args.input, "Loaded") 39 | 40 | net = AutoEncoder(config).cuda() 41 | net.load_state_dict(torch.load(args.ckpt)) 42 | net.eval() 43 | 44 | print("Network Loaded") 45 | 46 | recon = net.reconstruction(y) 47 | 48 | dereverb = recon["audio_synth"].cpu() 49 | torchaudio.save( 50 | os.path.splitext(args.output)[0] + "_synth.wav", dereverb, sample_rate=config.sample_rate 51 | ) 52 | 53 | if config.use_reverb: 54 | recon_add_reverb = recon["audio_reverb"].cpu() 55 | torchaudio.save( 56 | os.path.splitext(args.output)[0] + "_reverb.wav", 57 | recon_add_reverb, 58 | sample_rate=config.sample_rate, 59 | ) 60 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchaudio 4 | from omegaconf import OmegaConf 5 | import sys, os, tqdm, glob 6 | import numpy as np 7 | 8 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../") 9 | from torch.utils.data.dataloader import DataLoader 10 | import torch.optim as optim 11 | 12 | from trainer.trainer import Trainer 13 | from trainer.io import setup, set_seeds 14 | 15 | from dataset.audiodata import SupervisedAudioData, AudioData 16 | from network.autoencoder.autoencoder import AutoEncoder 17 | from loss.mss_loss import MSSLoss 18 | from optimizer.radam import RAdam 19 | 20 | """ 21 | "setup" allows you to OVERRIDE the config through command line interface 22 | - for example 23 | $ python train.py --batch_size 64 --lr 0.01 --use_reverb 24 | """ 25 | 26 | config = setup(default_config="../configs/violin.yaml") 27 | # config = setup(pdb_on_error=True, trace=False, autolog=False, default_config=dict( 28 | # # general config 29 | # ckpt="../../ddsp_ckpt/violin/200131.pth", # checkpoint 30 | # gpu="0", 31 | # num_workers=4, # number of dataloader thread 32 | # seed=940513, # random seed 33 | # tensorboard_dir="../tensorboard_log/", 34 | # experiment_name="DDSP_violin", # experiment results are compared w/ this name. 35 | 36 | # # data config 37 | # train="../data/violin/train/", # data directory. should contain f0, too. 38 | # test="../data/violin/test/", 39 | # waveform_sec=1.0, # the length of training data. 40 | # frame_resolution=0.004, # 1 / frame rate 41 | # batch_size=64, 42 | # f0_threshold=0.5, # f0 with confidence below threshold will go to ZERO. 
43 | # valid_waveform_sec=4.0, # the length of validation data 44 | # n_fft=2048, # (Z encoder) 45 | # n_mels=128, # (Z encoder) 46 | # n_mfcc=30, # (Z encoder) 47 | # sample_rate=16000, 48 | 49 | # # training config 50 | # num_step=100000, 51 | # validation_interval=1000, 52 | # lr=0.001, 53 | # lr_decay=0.98, 54 | # lr_min=1e-7, 55 | # lr_scheduler="multi", # 'plateau' 'no' 'cosine' 56 | # optimizer='radam', # 'adam', 'radam' 57 | # loss="mss", 58 | # metric="mss", 59 | # resume=False, # when training from a specific checkpoint. 60 | 61 | # # network config 62 | # mlp_units=512, 63 | # mlp_layers=3, 64 | # use_z=False, 65 | # use_reverb=False, 66 | # z_units=16, 67 | # n_harmonics=101, 68 | # n_freq=65, 69 | # gru_units=512, 70 | # crepe="full", 71 | # bidirectional=False, 72 | # )) 73 | 74 | print(OmegaConf.create(config.__dict__).pretty()) 75 | set_seeds(config.seed) 76 | Trainer.set_experiment_name(config.experiment_name) 77 | 78 | net = AutoEncoder(config).cuda() 79 | 80 | loss = MSSLoss([2048, 1024, 512, 256], use_reverb=config.use_reverb).cuda() 81 | 82 | # Define evaluation metrics 83 | if config.metric == "mss": 84 | 85 | def metric(output, gt): 86 | with torch.no_grad(): 87 | return -loss(output, gt) 88 | 89 | 90 | elif config.metric == "f0": 91 | # TODO Implement 92 | raise NotImplementedError 93 | else: 94 | raise NotImplementedError 95 | # -----------------------------/> 96 | 97 | # Dataset & Dataloader Prepare 98 | train_data = glob.glob(config.train + "/*.wav") * config.batch_size 99 | train_data_csv = [ 100 | os.path.dirname(wav) 101 | + f"/f0_{config.frame_resolution:.3f}/" 102 | + os.path.basename(os.path.splitext(wav)[0]) 103 | + ".f0.csv" 104 | for wav in train_data 105 | ] 106 | 107 | valid_data = glob.glob(config.test + "/*.wav") 108 | valid_data_csv = [ 109 | os.path.dirname(wav) 110 | + f"/f0_{config.frame_resolution:.3f}/" 111 | + os.path.basename(os.path.splitext(wav)[0]) 112 | + ".f0.csv" 113 | for wav in valid_data 114 | ] 115 | 116 | train_dataset = SupervisedAudioData( 117 | sample_rate=config.sample_rate, 118 | paths=train_data, 119 | csv_paths=train_data_csv, 120 | seed=config.seed, 121 | waveform_sec=config.waveform_sec, 122 | frame_resolution=config.frame_resolution, 123 | ) 124 | 125 | valid_dataset = SupervisedAudioData( 126 | sample_rate=config.sample_rate, 127 | paths=valid_data, 128 | csv_paths=valid_data_csv, 129 | seed=config.seed, 130 | waveform_sec=config.valid_waveform_sec, 131 | frame_resolution=config.frame_resolution, 132 | random_sample=False, 133 | ) 134 | 135 | train_dataloader = DataLoader( 136 | train_dataset, 137 | batch_size=config.batch_size, 138 | shuffle=True, 139 | num_workers=config.num_workers, 140 | pin_memory=True, 141 | ) 142 | 143 | valid_dataloader = DataLoader( 144 | valid_dataset, 145 | batch_size=int(config.batch_size // (config.valid_waveform_sec / config.waveform_sec)), 146 | shuffle=False, 147 | num_workers=config.num_workers, 148 | pin_memory=False, 149 | ) 150 | # -------------------------------------/> 151 | 152 | # Setting Optimizer 153 | if config.optimizer == "adam": 154 | optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=config.lr) 155 | elif config.optimizer == "radam": 156 | optimizer = RAdam(filter(lambda x: x.requires_grad, net.parameters()), lr=config.lr) 157 | else: 158 | raise NotImplementedError 159 | # -------------------------------------/> 160 | 161 | # Setting Scheduler 162 | if config.lr_scheduler == "cosine": 163 | # restart every T_0 * validation_interval steps 164 | 
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( 165 | optimizer, T_0=20, eta_min=config.lr_min 166 | ) 167 | elif config.lr_scheduler == "plateau": 168 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 169 | optimizer, mode="max", patience=5, factor=config.lr_decay 170 | ) 171 | elif config.lr_scheduler == "multi": 172 | # decay every ( 10000 // validation_interval ) steps 173 | scheduler = optim.lr_scheduler.MultiStepLR( 174 | optimizer, 175 | [(x + 1) * 10000 // config.validation_interval for x in range(10)], 176 | gamma=config.lr_decay, 177 | ) 178 | elif config.lr_scheduler == "no": 179 | scheduler = None 180 | else: 181 | raise ValueError(f"unknown lr_scheduler :: {config.lr_scheduler}") 182 | # ---------------------------------------/> 183 | 184 | trainer = Trainer( 185 | net, 186 | criterion=loss, 187 | metric=metric, 188 | train_dataloader=train_dataloader, 189 | val_dataloader=valid_dataloader, 190 | optimizer=optimizer, 191 | lr_scheduler=scheduler, 192 | ckpt=config.ckpt, 193 | is_data_dict=True, 194 | experiment_id=os.path.splitext(os.path.basename(config.ckpt))[0], 195 | tensorboard_dir=config.tensorboard_dir, 196 | ) 197 | 198 | save_counter = 0 199 | save_interval = 10 200 | 201 | 202 | def validation_callback(): 203 | global save_counter, save_interval 204 | # Save generated audio per every validation 205 | net.eval() 206 | 207 | def tensorboard_audio(data_loader, phase): 208 | 209 | bd = next(iter(data_loader)) 210 | for k, v in bd.items(): 211 | bd[k] = v.cuda() 212 | 213 | original_audio = bd["audio"][0] 214 | estimation = net(bd) 215 | 216 | if config.use_reverb: 217 | reconed_audio = estimation["audio_reverb"][0, : len(original_audio)] 218 | trainer.tensorboard.add_audio( 219 | f"{trainer.config['experiment_id']}/{phase}_recon", 220 | reconed_audio.cpu(), 221 | trainer.config["step"], 222 | sample_rate=config.sample_rate, 223 | ) 224 | 225 | reconed_audio_dereverb = estimation["audio_synth"][0, : len(original_audio)] 226 | trainer.tensorboard.add_audio( 227 | f"{trainer.config['experiment_id']}/{phase}_recon_dereverb", 228 | reconed_audio_dereverb.cpu(), 229 | trainer.config["step"], 230 | sample_rate=config.sample_rate, 231 | ) 232 | trainer.tensorboard.add_audio( 233 | f"{trainer.config['experiment_id']}/{phase}_original", 234 | original_audio.cpu(), 235 | trainer.config["step"], 236 | sample_rate=config.sample_rate, 237 | ) 238 | 239 | tensorboard_audio(train_dataloader, phase="train") 240 | tensorboard_audio(valid_dataloader, phase="valid") 241 | 242 | save_counter += 1 243 | if save_counter % save_interval == 0: 244 | trainer.save(trainer.ckpt + f"-{trainer.config['step']}") 245 | 246 | 247 | trainer.register_callback(validation_callback) 248 | if config.resume: 249 | trainer.load(config.ckpt) 250 | 251 | trainer.add_external_config(config) 252 | trainer.train(step=config.num_step, validation_interval=config.validation_interval) 253 | 254 | -------------------------------------------------------------------------------- /train/trainer/PinkModule/logging.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import sys 3 | 4 | 5 | class PinkBlackLogger: 6 | def __init__(self, fp, stream=sys.stdout): 7 | self.stream = stream 8 | self.fp = fp 9 | 10 | def write(self, message): 11 | self.fp.write(message) 12 | self.fp.flush() 13 | self.stream.write(message) 14 | 15 | def flush(self): 16 | self.stream.flush() 17 | 18 | 19 | def padding(arg, width, pad=' '): 20 | if isinstance(arg, float): 21 | return 
'{:.6f}'.format(arg).center(width, pad) 22 | elif isinstance(arg, int): 23 | return '{:6d}'.format(arg).center(width, pad) 24 | elif isinstance(arg, str): 25 | return arg.center(width, pad) 26 | elif isinstance(arg, tuple): 27 | if len(arg) != 2: 28 | raise ValueError('Unknown type: {}'.format(type(arg), arg)) 29 | if not isinstance(arg[1], str): 30 | raise ValueError('Unknown type: {}' 31 | .format(type(arg[1]), arg[1])) 32 | return padding(arg[0], width, pad=pad) 33 | else: 34 | raise ValueError('Unknown type: {}'.format(type(arg), arg)) 35 | 36 | 37 | def print_row(kwarg_list=[], pad=' '): 38 | len_kwargs = len(kwarg_list) 39 | term_width = shutil.get_terminal_size().columns 40 | width = min((term_width - 1 - len_kwargs) * 9 // 10, 150) // len_kwargs 41 | row = '|{}' * len_kwargs + '|' 42 | columns = [] 43 | for kwarg in kwarg_list: 44 | columns.append(padding(kwarg, width, pad=pad)) 45 | print(row.format(*columns)) 46 | -------------------------------------------------------------------------------- /train/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from . import io, trainer 2 | -------------------------------------------------------------------------------- /train/trainer/io.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from .PinkModule.logging import PinkBlackLogger 3 | from omegaconf import OmegaConf 4 | 5 | 6 | def convert_type(string: str): 7 | try: 8 | f = float(string) 9 | if f.is_integer(): 10 | return int(f) 11 | else: 12 | return f 13 | except ValueError: 14 | return string 15 | 16 | 17 | def get_args(default_config: dict): 18 | import argparse 19 | 20 | parser = argparse.ArgumentParser() 21 | 22 | if not "gpu" in default_config.keys(): 23 | parser.add_argument(f"--gpu", default=None, help="CUDA visible devices : default:None") 24 | 25 | for k, v in default_config.items(): 26 | k = k.lower() 27 | if isinstance(v, bool): 28 | if v is True: 29 | parser.add_argument(f"--{k}", help=f"{k} : default:{v}", action="store_false") 30 | else: 31 | parser.add_argument(f"--{k}", help=f"{k} : default:{v}", action="store_true") 32 | 33 | else: 34 | parser.add_argument(f"--{k}", default=v, help=f"{k} : default:{v}") 35 | args = parser.parse_args() 36 | 37 | if args.gpu and "gpu" in default_config.keys(): 38 | # Default argument로 gpu를 줬다면 이렇게 세팅 39 | os.environ.update({"CUDA_VISIBLE_DEVICES": str(args.gpu)}) 40 | 41 | for k in default_config.keys(): 42 | k = k.lower() 43 | val = getattr(args, k) 44 | if not isinstance(val, bool): 45 | setattr(args, k, convert_type(str(val))) 46 | 47 | return args 48 | 49 | 50 | def setup( 51 | trace=False, 52 | pdb_on_error=True, 53 | default_config=None, 54 | autolog=False, 55 | autolog_dir="pinkblack_autolog", 56 | ): 57 | """ 58 | :param trace: 59 | :param pdb_on_error: 60 | :param default_config: dict or str(yaml file) 61 | :param autolog: 62 | :param autolog_dir: 63 | gpu -> CUDA_VISIBLE_DEVICES 64 | :return: argparsed config 65 | 66 | Example >> 67 | ```bash 68 | CUDA_VISIBLE_DEVICES=1,3 python myscript.py --batch_size 32 --ckpt ckpt.pth --epochs 100 --lr 0.001 69 | ``` 70 | ```python3 71 | setup(default_config=dict(gpu="1,3", 72 | batch_size=32, 73 | lr=1e-3, 74 | epochs=100, 75 | ckpt="ckpt.pth")) 76 | ``` 77 | ```bash 78 | python myscript.py --gpu 1,3 --batch_size 32 --ckpt ckpt.pth --epochs 100 --lr 0.001 79 | ``` 80 | 81 | """ 82 | if trace: 83 | import backtrace 84 | 85 | backtrace.hook(align=True) 86 | 87 | if pdb_on_error: 
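# Wrap the default excepthook so that any uncaught exception (other than KeyboardInterrupt) drops into a pdb post-mortem session.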
88 | old_hook = sys.excepthook 89 | 90 | def new_hook(type_, value, tb): 91 | old_hook(type_, value, tb) 92 | if type_ != KeyboardInterrupt: 93 | import pdb 94 | 95 | pdb.post_mortem(tb) 96 | 97 | sys.excepthook = new_hook 98 | 99 | args = None 100 | if default_config is not None: 101 | if isinstance(default_config, str): 102 | default_config = OmegaConf.load(default_config) 103 | args = get_args(default_config) 104 | 105 | import time, datetime 106 | 107 | dt = datetime.datetime.fromtimestamp(time.time()) 108 | dt = datetime.datetime.strftime(dt, f"{os.path.basename(sys.argv[0])}.%Y%m%d_%H%M%S.log") 109 | 110 | if args is not None and hasattr(args, "ckpt"): 111 | logpath = args.ckpt + "_" + dt 112 | else: 113 | logpath = os.path.join(autolog_dir, dt) 114 | 115 | os.makedirs(os.path.dirname(logpath), exist_ok=True) 116 | 117 | if args is not None: 118 | conf = OmegaConf.create(args.__dict__) 119 | conf.save(logpath[:-4] + ".yaml") 120 | conf.save(args.ckpt + ".yaml") 121 | 122 | if autolog: 123 | fp = open(logpath, "w") 124 | sys.stdout = PinkBlackLogger(fp, sys.stdout) 125 | sys.stderr = PinkBlackLogger(fp, sys.stderr) 126 | print("PinkBlack :: args :", args.__dict__) 127 | 128 | return args 129 | 130 | 131 | def set_seeds(seed, strict=False): 132 | """ 133 | strict 가 True이면, Cudnn까지도 deterministic 하게 한다. 134 | cudnn은 아주 조금 stochastic한 연산 결과를 보여주므로, 정확한 재현이 필요하다면 True로 설정. 135 | 136 | If strict == True, then cudnn backend will be deterministic 137 | torch.backends.cudnn.deterministic = True 138 | """ 139 | import random 140 | import numpy as np 141 | import torch 142 | 143 | torch.manual_seed(seed) 144 | if torch.cuda.is_available(): 145 | torch.cuda.manual_seed_all(seed) 146 | if strict: 147 | torch.backends.cudnn.deterministic = True 148 | np.random.seed(seed) 149 | random.seed(seed) 150 | -------------------------------------------------------------------------------- /train/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import Adam 4 | from torch.nn.utils import clip_grad_norm_ 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | 7 | from tqdm import tqdm 8 | from time import time 9 | from tensorboardX.writer import SummaryWriter 10 | from datetime import datetime 11 | from collections import defaultdict 12 | 13 | import os 14 | import json 15 | import logging 16 | import pandas as pd 17 | 18 | from .PinkModule.logging import * 19 | 20 | 21 | class AverageMeter(object): 22 | """ 23 | Computes and stores the average and current value 24 | """ 25 | 26 | def __init__(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | def cal_accuracy(pred, target): 40 | pred = torch.max(pred, 1)[1] 41 | corrects = torch.sum(pred == target).float() 42 | return corrects / pred.size(0) 43 | 44 | 45 | class Trainer: 46 | experiment_name = None 47 | 48 | def __init__( 49 | self, 50 | net, 51 | criterion=None, 52 | metric=cal_accuracy, 53 | train_dataloader=None, 54 | val_dataloader=None, 55 | test_dataloader=None, 56 | optimizer=None, 57 | lr_scheduler=None, 58 | tensorboard_dir="./pinkblack_tb/", 59 | ckpt="./ckpt/ckpt.pth", 60 | experiment_id=None, 61 | clip_gradient_norm=False, 62 | is_data_dict=False, 63 | ): 64 | """ 65 | :param net: nn.Module Network 66 | :param criterion: loss function. 
__call__(prediction, *batch_y) 67 | :param metric: metric function __call__(prediction, *batch_y). 68 | *note* : bigger is better. (Early stopping keeps the checkpoint with the larger metric value.) 69 | 70 | :param train_dataloader: torch DataLoader for the training set 71 | :param val_dataloader: torch DataLoader for the validation set 72 | :param test_dataloader: optional torch DataLoader for the test set 73 | 74 | :param optimizer: torch.optim optimizer (default: Adam over the trainable parameters) 75 | :param lr_scheduler: optional learning-rate scheduler 76 | :param tensorboard_dir: tensorboard log directory 77 | :param ckpt: checkpoint path; save() also writes <ckpt>.optimizer, <ckpt>.scheduler, <ckpt>.config and <ckpt>.csv 78 | :param experiment_id: experiment id shown on tensorboard 79 | :param clip_gradient_norm: False or a scalar value (if a number is given, gradients are clipped to that norm.) 80 | :param is_data_dict: whether the dataloaders return dicts (each batch is a dict instead of an (x, y) pair). 81 | """ 82 | 83 | self.net = net 84 | self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion 85 | self.metric = metric 86 | 87 | self.dataloader = dict() 88 | if train_dataloader is not None: 89 | self.dataloader["train"] = train_dataloader 90 | if val_dataloader is not None: 91 | self.dataloader["val"] = val_dataloader 92 | if test_dataloader is not None: 93 | self.dataloader["test"] = test_dataloader 94 | 95 | if train_dataloader is None or val_dataloader is None: 96 | logging.warning("Init Trainer :: Two dataloaders are needed!") 97 | 98 | self.optimizer = ( 99 | Adam(filter(lambda p: p.requires_grad, self.net.parameters())) 100 | if optimizer is None 101 | else optimizer 102 | ) 103 | self.lr_scheduler = lr_scheduler 104 | 105 | self.ckpt = ckpt 106 | 107 | self.config = defaultdict(float) 108 | self.config["max_train_metric"] = -1e8 109 | self.config["max_val_metric"] = -1e8 110 | self.config["max_test_metric"] = -1e8 111 | self.config["tensorboard_dir"] = tensorboard_dir 112 | self.config["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S") 113 | self.config["clip_gradient_norm"] = clip_gradient_norm 114 | self.config["is_data_dict"] = is_data_dict 115 | 116 | if experiment_id is None: 117 | self.config["experiment_id"] = self.config["timestamp"] 118 | else: 119 | self.config["experiment_id"] = experiment_id 120 | 121 | self.dataframe = pd.DataFrame() 122 | 123 | self.device = Trainer.get_model_device(self.net) 124 | if self.device == torch.device("cpu"): 125 | logging.warning( 126 | "Init Trainer :: Do you really want to train the network on CPU instead of GPU?" 
127 | ) 128 | 129 | if self.config["tensorboard_dir"] is not None: 130 | self.tensorboard = SummaryWriter(self.config["tensorboard_dir"]) 131 | else: 132 | self.tensorboard = None 133 | 134 | self.callbacks = defaultdict(list) 135 | 136 | def register_callback(self, func, phase="val"): 137 | self.callbacks[phase].append(func) 138 | 139 | def save(self, f=None): 140 | if f is None: 141 | f = self.ckpt 142 | os.makedirs(os.path.dirname(f), exist_ok=True) 143 | if isinstance(self.net, nn.DataParallel): 144 | state_dict = self.net.module.state_dict() 145 | else: 146 | state_dict = self.net.state_dict() 147 | torch.save(state_dict, f) 148 | torch.save(self.optimizer.state_dict(), f + ".optimizer") 149 | 150 | if self.lr_scheduler is not None: 151 | torch.save(self.lr_scheduler.state_dict(), f + ".scheduler") 152 | 153 | with open(f + ".config", "w") as fp: 154 | json.dump(self.config, fp) 155 | 156 | self.dataframe.to_csv(f + ".csv", float_format="%.6f", index=False) 157 | 158 | def load(self, f=None): 159 | if f is None: 160 | f = self.ckpt 161 | 162 | if isinstance(self.net, nn.DataParallel): 163 | self.net.module.load_state_dict(torch.load(f, map_location=self.device)) 164 | else: 165 | self.net.load_state_dict(torch.load(f, map_location=self.device)) 166 | 167 | if os.path.exists(f + ".config"): 168 | with open(f + ".config", "r") as fp: 169 | dic = json.loads(fp.read()) 170 | self.config = defaultdict(float, dic) 171 | print("Loaded,", self.config) 172 | 173 | if os.path.exists(f + ".optimizer"): 174 | self.optimizer.load_state_dict(torch.load(f + ".optimizer")) 175 | 176 | if os.path.exists(f + ".scheduler") and self.lr_scheduler is not None: 177 | self.lr_scheduler.load_state_dict(torch.load(f + ".scheduler")) 178 | 179 | if os.path.exists(f + ".csv"): 180 | self.dataframe = pd.read_csv(f + ".csv") 181 | 182 | if self.config["tensorboard_dir"] is not None: 183 | self.tensorboard = SummaryWriter(self.config["tensorboard_dir"]) 184 | else: 185 | self.tensorboard = None 186 | 187 | def train( 188 | self, epoch=None, phases=None, step=None, validation_interval=1, save_every_validation=False 189 | ): 190 | """ 191 | :param epoch: number of passes over the train dataloader 192 | :param phases: subset of ['train', 'val', 'test']; phases that are not needed can be dropped. 193 | >> trainer.train(1, phases=['val']) 194 | 195 | :param step: total number of training steps when training is counted in steps instead of epochs. 196 | :param validation_interval: interval between validation runs (in train units) 197 | :param save_every_validation: if True, save a checkpoint at every validation. 
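>> trainer.train(step=config.num_step, validation_interval=config.validation_interval)  # step-based training, as in train/train.py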
198 | :return: None 199 | """ 200 | if phases is None: 201 | phases = list(self.dataloader.keys()) 202 | 203 | if epoch is None and step is None: 204 | raise ValueError("PinkBlack.trainer :: epoch or step should be specified.") 205 | 206 | train_unit = "epoch" if step is None else "step" 207 | self.config[train_unit] = int(self.config[train_unit]) 208 | 209 | num_unit = epoch if step is None else step 210 | validation_interval = 1 if validation_interval <= 0 else validation_interval 211 | 212 | kwarg_list = [train_unit] 213 | for phase in phases: 214 | kwarg_list += [f"{phase}_loss", f"{phase}_metric"] 215 | kwarg_list += ["lr", "time"] 216 | 217 | print_row(kwarg_list=[""] * len(kwarg_list), pad="-") 218 | print_row(kwarg_list=kwarg_list, pad=" ") 219 | print_row(kwarg_list=[""] * len(kwarg_list), pad="-") 220 | 221 | start = self.config[train_unit] 222 | 223 | for i in range(start, start + num_unit, validation_interval): 224 | start_time = time() 225 | if train_unit == "epoch": 226 | for phase in phases: 227 | self.config[f"{phase}_loss"], self.config[f"{phase}_metric"] = self._train( 228 | phase, num_steps=len(self.dataloader[phase]) 229 | ) 230 | for func in self.callbacks[phase]: 231 | func() 232 | self.config[train_unit] += 1 233 | elif train_unit == "step": 234 | for phase in phases: 235 | if phase == "train": 236 | # num_unit 이 validation interval로 나눠떨어지지 않는 경우 237 | num_steps = min((start + num_unit - i), validation_interval) 238 | self.config[train_unit] += num_steps 239 | else: 240 | num_steps = len(self.dataloader[phase]) 241 | self.config[f"{phase}_loss"], self.config[f"{phase}_metric"] = self._train( 242 | phase, num_steps=num_steps 243 | ) 244 | for func in self.callbacks[phase]: 245 | func() 246 | else: 247 | raise NotImplementedError 248 | 249 | if self.lr_scheduler is not None: 250 | if isinstance(self.lr_scheduler, ReduceLROnPlateau): 251 | self.lr_scheduler.step(self.config["val_metric"]) 252 | else: 253 | self.lr_scheduler.step() 254 | 255 | i_str = str(self.config[train_unit]) 256 | is_best = self.config["max_val_metric"] < self.config["val_metric"] 257 | if is_best: 258 | for phase in phases: 259 | self.config[f"max_{phase}_metric"] = max( 260 | self.config[f"max_{phase}_metric"], self.config[f"{phase}_metric"] 261 | ) 262 | i_str = (str(self.config[train_unit])) + "-best" 263 | 264 | elapsed_time = time() - start_time 265 | if self.tensorboard is not None: 266 | _loss, _metric = {}, {} 267 | for phase in phases: 268 | _loss[phase] = self.config[f"{phase}_loss"] 269 | _metric[phase] = self.config[f"{phase}_metric"] 270 | 271 | self.tensorboard.add_scalars( 272 | f"{self.config['experiment_id']}/loss", _loss, self.config[train_unit] 273 | ) 274 | self.tensorboard.add_scalars( 275 | f"{self.config['experiment_id']}/metric", _metric, self.config[train_unit] 276 | ) 277 | self.tensorboard.add_scalar( 278 | f"{self.config['experiment_id']}/time", elapsed_time, self.config[train_unit] 279 | ) 280 | self.tensorboard.add_scalar( 281 | f"{self.config['experiment_id']}/lr", 282 | self.optimizer.param_groups[0]["lr"], 283 | self.config[train_unit], 284 | ) 285 | 286 | print_kwarg = [i_str] 287 | for phase in phases: 288 | print_kwarg += [self.config[f"{phase}_loss"], self.config[f"{phase}_metric"]] 289 | print_kwarg += [self.optimizer.param_groups[0]["lr"], elapsed_time] 290 | 291 | print_row(kwarg_list=print_kwarg, pad=" ") 292 | print_row(kwarg_list=[""] * len(kwarg_list), pad="-") 293 | self.dataframe = self.dataframe.append( 294 | dict(zip(kwarg_list, print_kwarg)), 
ignore_index=True 295 | ) 296 | 297 | if is_best: 298 | self.save(self.ckpt) 299 | if Trainer.experiment_name is not None: 300 | self.update_experiment() 301 | 302 | if save_every_validation: 303 | self.save(self.ckpt + f"-{self.config[train_unit]}") 304 | 305 | def _step(self, phase, iterator, only_inference=False): 306 | 307 | if self.config["is_data_dict"]: 308 | batch_dict = next(iterator) 309 | batch_size = batch_dict[list(batch_dict.keys())[0]].size(0) 310 | for k, v in batch_dict.items(): 311 | batch_dict[k] = v.to(self.device) 312 | else: 313 | batch_x, batch_y = next(iterator) 314 | if isinstance(batch_x, list): 315 | batch_x = [x.to(self.device) for x in batch_x] 316 | else: 317 | batch_x = [batch_x.to(self.device)] 318 | 319 | if isinstance(batch_y, list): 320 | batch_y = [y.to(self.device) for y in batch_y] 321 | else: 322 | batch_y = [batch_y.to(self.device)] 323 | 324 | batch_size = batch_x[0].size(0) 325 | 326 | self.optimizer.zero_grad() 327 | with torch.set_grad_enabled(phase == "train"): 328 | if self.config["is_data_dict"]: 329 | outputs = self.net(batch_dict) 330 | if not only_inference: 331 | loss = self.criterion(outputs, batch_dict) 332 | else: 333 | outputs = self.net(*batch_x) 334 | if not only_inference: 335 | loss = self.criterion(outputs, *batch_y) 336 | 337 | if only_inference: 338 | return outputs 339 | 340 | if phase == "train": 341 | loss.backward() 342 | if self.config["clip_gradient_norm"]: 343 | clip_grad_norm_(self.net.parameters(), self.config["clip_gradient_norm"]) 344 | self.optimizer.step() 345 | 346 | with torch.no_grad(): 347 | if self.config["is_data_dict"]: 348 | metric = self.metric(outputs, batch_dict) 349 | else: 350 | metric = self.metric(outputs, *batch_y) 351 | 352 | return {"loss": loss.item(), "batch_size": batch_size, "metric": metric.item()} 353 | 354 | def _train(self, phase, num_steps=0): 355 | running_loss = AverageMeter() 356 | running_metric = AverageMeter() 357 | 358 | if phase == "train": 359 | self.net.train() 360 | else: 361 | self.net.eval() 362 | 363 | dataloader = self.dataloader[phase] 364 | step_iterator = iter(dataloader) 365 | tq = tqdm(range(num_steps), leave=False) 366 | for st in tq: 367 | if (st + 1) % len(dataloader) == 0: 368 | step_iterator = iter(dataloader) 369 | results = self._step(phase=phase, iterator=step_iterator) 370 | tq.set_description(f"Loss:{results['loss']:.4f}, Metric:{results['metric']:.4f}") 371 | running_loss.update(results["loss"], results["batch_size"]) 372 | running_metric.update(results["metric"], results["batch_size"]) 373 | 374 | return running_loss.avg, running_metric.avg 375 | 376 | def eval(self, dataloader=None): 377 | self.net.eval() 378 | if dataloader is None: 379 | dataloader = self.dataloader["val"] 380 | phase = "val" 381 | 382 | output_list = [] 383 | step_iterator = iter(dataloader) 384 | num_steps = len(dataloader) 385 | for st in tqdm(range(num_steps), leave=False): 386 | results = self._step(phase="val", iterator=step_iterator, only_inference=True) 387 | output_list.append(results) 388 | 389 | output_cat = torch.cat(output_list) 390 | return output_cat 391 | 392 | def add_external_config(self, args): 393 | """ 394 | args : a dict-like object which contains key-value configurations. 
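Each key is stored into self.config as "config_<key>", so it is saved with the checkpoint and appears in the experiment CSV written by update_experiment().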
395 | """ 396 | if isinstance(args, dict): 397 | new_d = defaultdict(float) 398 | for k, v in args.items(): 399 | new_d[f"config_{k}"] = v 400 | self.config.update(new_d) 401 | else: 402 | new_d = defaultdict(float) 403 | for k, v in args.__dict__.items(): 404 | new_d[f"config_{k}"] = v 405 | self.config.update(new_d) 406 | 407 | def update_experiment(self): 408 | assert Trainer.experiment_name is not None 409 | df_config = pd.DataFrame(pd.Series(self.config)).T.set_index("experiment_id") 410 | if os.path.exists(Trainer.experiment_name + ".csv"): 411 | df_ex = pd.read_csv(Trainer.experiment_name + ".csv", index_col=0) 412 | if self.config["experiment_id"] in df_ex.index: 413 | df_ex = df_ex.drop(self.config["experiment_id"]) 414 | df_ex = df_ex.append(df_config, sort=False) 415 | else: 416 | df_ex = df_config 417 | df_ex.to_csv(Trainer.experiment_name + ".csv") 418 | return df_ex 419 | 420 | @staticmethod 421 | def get_model_device(net): 422 | device = torch.device("cpu") 423 | for param in net.parameters(): 424 | device = param.device 425 | break 426 | return device 427 | 428 | @staticmethod 429 | def set_experiment_name(name): 430 | Trainer.experiment_name = name 431 | --------------------------------------------------------------------------------
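The snippet below is a minimal, hypothetical usage sketch (not part of the repository) of how the bundled `Trainer` could be driven with the `AdamW`-with-warmup optimizer from `optimizer/radam.py` on a toy classification task, relying on the trainer's default criterion (`nn.CrossEntropyLoss`) and metric (`cal_accuracy`). The toy data, model, and hyperparameter values are illustrative only; the import paths assume the script is run from the `train/` directory, as `train/train.py` does, and the pinned dependency versions from `requirements.txt` (e.g. `tensorboardX` installed and a pandas/torch old enough for `DataFrame.append` and the legacy in-place optimizer signatures).

```python
# Hypothetical usage sketch -- toy data, model, and hyperparameters are illustrative only.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from optimizer.radam import AdamW      # bundled AdamW with linear LR warmup
from trainer.trainer import Trainer    # bundled training loop

# Toy 10-class classification set, just enough to exercise the default loss/metric.
x = torch.randn(256, 32)
y = torch.randint(0, 10, (256,))
train_loader = DataLoader(TensorDataset(x[:192], y[:192]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x[192:], y[192:]), batch_size=32)

net = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

# 'warmup' counts optimizer steps; the effective LR ramps from near zero up to lr over 100 steps.
optimizer = AdamW(net.parameters(), lr=1e-3, weight_decay=1e-2, warmup=100)

trainer = Trainer(
    net,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    ckpt="./ckpt/toy.pth",   # best checkpoint plus .optimizer/.config/.csv files land here
)
trainer.train(epoch=5)       # epoch-based; train/train.py uses step-based training instead
```

Swapping this optimizer into `train/train.py` would only require adding a `warmup` entry to the config and an extra branch in the optimizer-selection block, since the rest of the training loop is optimizer-agnostic.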