├── .gitignore
├── README.md
├── debug.py
├── make_wrapper.py
├── realtime
│   ├── CMakeLists.txt
│   ├── generate_deb.py
│   ├── helppatch.pd
│   ├── libwavae
│   │   ├── CMakeLists.txt
│   │   └── src
│   │       ├── deepAudioEngine.h
│   │       ├── test.cpp
│   │       ├── wavae.cpp
│   │       └── wavae.h
│   └── src
│       ├── decoder.cpp
│       ├── encoder.cpp
│       └── signal_in_out_base.c
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── cached_padding.py
│   ├── data.py
│   ├── domain_adaptation.py
│   ├── gan_modules.py
│   ├── hparams.py
│   ├── losses_train.md
│   ├── melencoder.py
│   ├── model.py
│   ├── pca_utils.py
│   ├── resampling.py
│   ├── train_utils.py
│   └── vanilla_vae.py
├── train.py
└── udls
    ├── __init__.py
    ├── base_dataset.py
    ├── domain_adaptation.py
    └── simple_dataset.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.pt
3 | *.pth
4 | *.DS_Store
5 | *runs*
6 | *.mdb
7 | *.vscode
8 | *.ts
9 | *build*
10 | *bin*
11 | *.wav
12 | *.deb
13 | *.json
14 | !*package.json
15 | *node_modules*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WaVAE
2 | 
3 | Despite its name, it's not a waveform-based VAE, but a melspectrogram-based one with a MelGAN decoder. There is also an adversarial regularization of the latent space, used to extract the loudness from it. You can even use it on your own dataset, and train the whole thing in maybe 2-4 days on a single GPU.
4 | 
5 | This model offers realtime generation and a highly compressed, expressive latent representation.
6 | 
7 | ## PureData usage demo
8 | 
9 | [![Celine to Scream](https://img.youtube.com/vi/Q3Ejm_ll6KU/0.jpg)](https://www.youtube.com/watch?v=Q3Ejm_ll6KU)
10 | 
11 | 
12 | ## Usage
13 | 
14 | Train the spectral model
15 | ```bash
16 | python train.py -c vanilla --wav-loc YOUR_DATA_FOLDER --name ENTER_A_COOL_NAME
17 | ```
18 | 
19 | Remember to delete the `preprocessed` folder between each training, as the models don't use the same preprocessing pipeline. (You can also use the `--lmdb-loc` flag with a different path for each model.)
20 | 
21 | Train the waveform model
22 | ```bash
23 | python train.py -c melgan --wav-loc YOUR_DATA_FOLDER --name ENTER_THE_SAME_COOL_NAME
24 | ```
25 | 
26 | The training script logs into the `runs` folder; you can visualize the logs using `tensorboard`.
27 | 
28 | 
29 | Once both models are trained, trace them using
30 | ```bash
31 | python make_wrapper.py --name AGAIN_THE_SAME_COOL_NAME
32 | ```
33 | 
34 | It will produce a traced script in `runs/COOL_NAME/COOLNAME_LOTSOFWEIRDNUMBERS.ts`. It can be deployed and used in a libtorch C++ environment, without having to use the source code. AND if you want to use the realtime abilities of this model, just pass `--use-cached-padding true --buffer-size 2048`.
35 | 
36 | ## Compiling
37 | 
38 | 
39 | To compile the pd externals, you can use CMake
40 | ```bash
41 | cmake -DCMAKE_PREFIX_PATH=/path.to.libtorch -DCMAKE_BUILD_TYPE=[Release / Debug] -DCUDNN_LIBRARY_PATH=path.to.libcudnn.so -DCUDNN_INCLUDE_PATH=path.to.cudnn.include -G [Ninja / Xcode / "Unix Makefiles"] ../
42 | ```
43 | 
44 | Or even better, use the precompiled binaries available in the **Release** section of this project.
45 | Just remember to download the CUDA 10.1 cxx11 ABI version of libtorch and unzip it in `/usr/lib/` 46 | 47 | (only tested on ubuntu 18.04 - 19.10 - 20.04) 48 | -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import torch 3 | torch.set_grad_enabled(False) 4 | 5 | model = torch.jit.load("runs/screams/screams_48kHz_16z_4096b.ts") 6 | x = torch.randn(1,4096) 7 | 8 | z = model.encode(x) 9 | # %% 10 | y = model.decode(z) 11 | # %% 12 | print(y.shape) 13 | # %% 14 | -------------------------------------------------------------------------------- /make_wrapper.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import numpy as np 8 | 9 | from src import config 10 | from src import get_model, compute_pca, LogLoudness 11 | from src.resampling import Resampling 12 | 13 | torch.set_grad_enabled(False) 14 | config.parse_args() 15 | 16 | NAME = config.NAME 17 | ROOT = path.join("runs/", config.NAME) 18 | PCA = True 19 | 20 | config_melgan = ".".join(path.join(ROOT, "melgan", "config").split("/")) 21 | config_vanilla = ".".join(path.join(ROOT, "vanilla", "config").split("/")) 22 | 23 | 24 | class BufferSTFT(nn.Module): 25 | def __init__(self, buffer_size, hop_length): 26 | super().__init__() 27 | n_frame = (config.BUFFER_SIZE // config.HOP_LENGTH - 1) 28 | buffer = torch.zeros(1, 2048 + n_frame * hop_length) 29 | self.register_buffer("buffer", buffer) 30 | self.buffer_size = buffer_size 31 | 32 | def forward(self, x): 33 | self.buffer = torch.roll(self.buffer, -self.buffer_size, -1) 34 | self.buffer[:, -self.buffer_size:] = x 35 | return self.buffer 36 | 37 | 38 | class TracedMelEncoder(nn.Module): 39 | def __init__(self, melencoder, buffer, hop_length, use_buffer=True): 40 | super().__init__() 41 | self.melencoder = melencoder 42 | self.buffer = torch.jit.script(buffer) 43 | self.use_buffer = use_buffer 44 | self.hop_length = hop_length 45 | 46 | def forward(self, x): 47 | if self.use_buffer: 48 | x = self.buffer(x) 49 | return self.melencoder(x) 50 | 51 | 52 | class Wrapper(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | 56 | # BUILDING MELGAN ################################################# 57 | hparams_melgan = importlib.import_module(config_melgan).config 58 | hparams_melgan.override(USE_CACHED_PADDING=config.USE_CACHED_PADDING) 59 | melgan = get_model(hparams_melgan) 60 | 61 | pretrained_state_dict = torch.load(path.join(ROOT, "melgan", 62 | "melgan_state.pth"), 63 | map_location="cpu")[0] 64 | state_dict = melgan.state_dict() 65 | state_dict.update(pretrained_state_dict) 66 | melgan.load_state_dict(state_dict) 67 | ################################################################### 68 | 69 | # BUILDING VANILLA ################################################ 70 | hparams_vanilla = importlib.import_module(config_vanilla).config 71 | hparams_vanilla.override(USE_CACHED_PADDING=config.USE_CACHED_PADDING) 72 | vanilla = get_model(hparams_vanilla) 73 | 74 | pretrained_state_dict = torch.load(path.join(ROOT, "vanilla", 75 | "vanilla_state.pth"), 76 | map_location="cpu") 77 | state_dict = vanilla.state_dict() 78 | state_dict.update(pretrained_state_dict) 79 | vanilla.load_state_dict(state_dict) 80 | ################################################################### 81 | 82 | vanilla.eval() 83 | melgan.eval() 84 | 
85 | # PRETRACE MODELS ################################################# 86 | self.latent_size = int(config.CHANNELS[-1] // 2) 87 | self.mel_size = int(config.CHANNELS[0]) 88 | 89 | if config.USE_CACHED_PADDING: 90 | test_wav = torch.randn(1, config.BUFFER_SIZE) 91 | test_mel = torch.randn(1, config.INPUT_SIZE, 2) 92 | if hparams_vanilla.EXTRACT_LOUDNESS: 93 | test_z = torch.randn(1, self.latent_size + 1, 1) 94 | else: 95 | test_z = torch.randn(1, self.latent_size, 1) 96 | 97 | else: 98 | test_wav = torch.randn(1, 8192) 99 | test_mel = torch.randn(1, config.INPUT_SIZE, 16) 100 | if hparams_vanilla.EXTRACT_LOUDNESS: 101 | test_z = torch.randn(1, self.latent_size + 1, 16) 102 | else: 103 | test_z = torch.randn(1, self.latent_size, 16) 104 | 105 | melencoder = TracedMelEncoder( 106 | vanilla.melencoder, 107 | BufferSTFT(config.BUFFER_SIZE, config.HOP_LENGTH), 108 | config.HOP_LENGTH, config.USE_CACHED_PADDING) 109 | 110 | logloudness = LogLoudness( 111 | int(hparams_vanilla.HOP_LENGTH * np.prod(hparams_vanilla.RATIOS)), 112 | 1e-4) 113 | 114 | self.trace_logloudness = torch.jit.script(logloudness) 115 | self.trace_melencoder = torch.jit.trace(melencoder, 116 | test_wav, 117 | check_trace=False) 118 | self.trace_encoder = torch.jit.trace(vanilla.topvae.encoder, 119 | test_mel, 120 | check_trace=False) 121 | self.trace_decoder = torch.jit.trace(vanilla.topvae.decoder, 122 | test_z, 123 | check_trace=False) 124 | self.trace_melgan = torch.jit.trace(melgan.decoder, 125 | test_mel, 126 | check_trace=False) 127 | 128 | config.override(SAMPRATE=hparams_vanilla.SAMPRATE, 129 | N_SIGNAL=hparams_vanilla.N_SIGNAL, 130 | EXTRACT_LOUDNESS=hparams_vanilla.EXTRACT_LOUDNESS, 131 | TYPE=hparams_vanilla.TYPE, 132 | HOP_LENGTH=hparams_vanilla.HOP_LENGTH, 133 | RATIOS=hparams_vanilla.RATIOS, 134 | WAV_LOC=hparams_vanilla.WAV_LOC, 135 | LMDB_LOC=hparams_vanilla.LMDB_LOC) 136 | 137 | self.pca = None 138 | 139 | self.resampling = torch.jit.script( 140 | Resampling(config.TARGET_SR, config.SAMPRATE)) 141 | 142 | if PCA: 143 | try: 144 | self.pca = torch.load(path.join(ROOT, "pca.pth")) 145 | print("Precomputed pca found") 146 | 147 | except: 148 | if config.USE_CACHED_PADDING: 149 | raise Exception( 150 | "PCA should be first computed in non cache mode") 151 | print("No precomputed pca found. 
Computing.") 152 | self.pca = None 153 | 154 | if self.pca == None: 155 | self.pca = compute_pca(self, 32) 156 | torch.save(self.pca, path.join(ROOT, "pca.pth")) 157 | 158 | self.register_buffer("mean", self.pca[0]) 159 | self.register_buffer("std", self.pca[1]) 160 | self.register_buffer("U", self.pca[2]) 161 | 162 | self.extract_loudness = config.EXTRACT_LOUDNESS 163 | 164 | def forward(self, x): 165 | return self.decode(self.encode(x)) 166 | 167 | @torch.jit.export 168 | def melencode(self, x): 169 | return self.trace_melencoder(x) 170 | 171 | @torch.jit.export 172 | def encode(self, x): 173 | x = x.unsqueeze(1) 174 | x = self.resampling.from_target_sampling_rate(x) 175 | x = x.squeeze(1) 176 | 177 | mel = self.melencode(x) 178 | z = self.trace_encoder(mel) 179 | 180 | mean, logvar = torch.split(z, self.latent_size, 1) 181 | z = torch.randn_like(mean) * torch.exp(logvar) + mean 182 | 183 | if self.pca is not None: 184 | z = (z.permute(0, 2, 1) - self.mean).matmul(self.U).div( 185 | self.std).permute(0, 2, 1) 186 | if self.extract_loudness: 187 | loudness = self.trace_logloudness(x) 188 | z = torch.cat([loudness, z], 1) 189 | 190 | z = z.repeat_interleave(self.resampling.ratio).reshape( 191 | z.shape[0], 192 | z.shape[1], 193 | -1, 194 | ) 195 | return z 196 | 197 | @torch.jit.export 198 | def decode(self, z): 199 | z = z[..., ::self.resampling.ratio] 200 | if self.pca is not None: 201 | if self.extract_loudness: 202 | loud, z = z[:, :1, :], z[:, 1:, :] 203 | z = (z.permute(0, 2, 1).matmul( 204 | self.U.permute(1, 0) * self.std) + self.mean).permute( 205 | 0, 2, 1) 206 | z = torch.cat([loud, z], 1) 207 | else: 208 | z = (z.permute(0, 2, 1).matmul( 209 | self.U.permute(1, 0) * self.std) + self.mean).permute( 210 | 0, 2, 1) 211 | mel = torch.sigmoid(self.trace_decoder(z)) 212 | mel = torch.split(mel, self.mel_size, 1)[0] 213 | waveform = self.trace_melgan(mel) 214 | waveform = self.resampling.to_target_sampling_rate(waveform) 215 | return waveform 216 | 217 | 218 | if __name__ == "__main__": 219 | wrapper = Wrapper().cpu() 220 | 221 | name_list = [ 222 | config.NAME, 223 | str(int(np.floor(config.TARGET_SR / 1000))) + "kHz", 224 | str(config.CHANNELS[-1] // 2 + int(config.EXTRACT_LOUDNESS)) + "z" 225 | ] 226 | if config.USE_CACHED_PADDING: 227 | name_list.append( 228 | str(config.BUFFER_SIZE * wrapper.resampling.ratio) + "b") 229 | 230 | name = "_".join(name_list) + ".ts" 231 | torch.jit.script(wrapper).save(path.join(ROOT, name)) 232 | -------------------------------------------------------------------------------- /realtime/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0) 2 | 3 | project(WAVAE) 4 | 5 | add_subdirectory("./libwavae") 6 | 7 | add_library(encoder SHARED src/encoder.cpp) 8 | set_target_properties(encoder PROPERTIES PREFIX "" SUFFIX "~.pd_linux") 9 | 10 | add_library(decoder SHARED src/decoder.cpp) 11 | set_target_properties(decoder PROPERTIES PREFIX "" SUFFIX "~.pd_linux") 12 | 13 | -------------------------------------------------------------------------------- /realtime/generate_deb.py: -------------------------------------------------------------------------------- 1 | from os import system, makedirs 2 | 3 | VERSION = input("Version: ") 4 | PACKAGE = f"wavae_{VERSION}" 5 | 6 | makedirs(f"{PACKAGE}/usr/lib/") 7 | makedirs(f"{PACKAGE}/usr/local/lib/pd-externals/wavae/") 8 | makedirs(f"{PACKAGE}/DEBIAN/") 9 | 10 | system(f"cp build/*.pd_linux {PACKAGE}/usr/local/lib/pd-externals/wavae/") 
11 | system(f"cp helppatch.pd {PACKAGE}/usr/local/lib/pd-externals/wavae/help-encoder~.pd") 12 | system(f"cp helppatch.pd {PACKAGE}/usr/local/lib/pd-externals/wavae/help-decoder~.pd") 13 | 14 | system(f"cp build/libwavae/libwavae.so {PACKAGE}/usr/lib/") 15 | 16 | with open(f"{PACKAGE}/DEBIAN/control", "w") as control: 17 | control.write("Package: wavae\n") 18 | control.write(f"Version: {VERSION}\n") 19 | control.write("Maintainer: Antoine CAILLON \n") 20 | control.write("Depends: nvidia-cuda-toolkit\n") 21 | control.write("Architecture: all\n") 22 | control.write( 23 | "Description: WaVAE puredata external. Needs libtorch in /usr/lib\n") 24 | 25 | system(f"dpkg-deb --build {PACKAGE}") 26 | if not input("Enter any key to prevent temporary folder destruction: "): 27 | system(f"rm -fr {PACKAGE}/") 28 | -------------------------------------------------------------------------------- /realtime/helppatch.pd: -------------------------------------------------------------------------------- 1 | #N canvas 697 352 450 167 12; 2 | #N canvas 129 175 713 300 ai_magic 0; 3 | #X obj 16 220 loadbang; 4 | #X msg 16 245 set 2048 1 0.5; 5 | #X obj 16 270 block~; 6 | #X obj 17 64 wavae/encoder~ 8 2048; 7 | #X obj 16 150 wavae/decoder~ 8 2048; 8 | #X obj 16 30 inlet~; 9 | #X obj 16 182 outlet~; 10 | #X text 239 49 The wavae external is composed of two sub externals: 11 | an encoder~ and a decoder~ (making a complete autoencoder <3). The 12 | first arguments sets the latent space dimensionality (swag) and the 13 | second arguments sets the buffer size. This subpatch is where the ai 14 | magic happens ! Actually there is a reason why we encapsulate the encoding 15 | / decoding process inside a subpatch: given that the generation of 16 | a traced model is based on a fixed buffer size / sampling rate basis 17 | \, we have to enforce it using the block~ object. For example here 18 | \, we have a model trained at 24kHz \, with a buffer size of 2048 \, 19 | hence we put jack at 48kHz and multiply by 0.5 it to reach 24kHz in 20 | this subpatch. 
/realtime/libwavae/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0)
2 | project(libwavae)
3 | 
4 | find_package(Torch REQUIRED)
5 | 
6 | file(GLOB SRC src/*.cpp)
7 | 
8 | # TEST EXECUTABLE
9 | add_executable(test ${SRC})
10 | target_link_libraries(test "${TORCH_LIBRARIES}")
11 | set_property(TARGET test PROPERTY CXX_STANDARD 14)
12 | 
13 | # ACTUAL LIBRARY
14 | add_library(wavae SHARED ${SRC})
15 | target_link_libraries(wavae "${TORCH_LIBRARIES}")
16 | set_property(TARGET wavae PROPERTY CXX_STANDARD 14)
--------------------------------------------------------------------------------
/realtime/libwavae/src/deepAudioEngine.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <string>
3 | 
4 | #define DIM_REDUCTION_FACTOR 512
5 | 
6 | class DeepAudioEngine {
7 | public:
8 |   virtual void perform(float *in_buffer, float *out_buffer,
9 |                        int dsp_vec_size) = 0;
10 |   virtual int load(std::string name) = 0;
11 |   virtual void set_latent_number(int n) = 0;
12 | };
--------------------------------------------------------------------------------
/realtime/libwavae/src/test.cpp:
--------------------------------------------------------------------------------
1 | #include "wavae.h"
2 | #include <iostream>
3 | #include <string>
4 | 
5 | #define LATENT_NUMBER 16
6 | #define BUFFERSIZE 2048
7 | 
8 | int main(int argc, char const *argv[]) {
9 | 
10 |   DeepAudioEngine *encoder = new wavae::Encoder;
11 |   encoder->set_latent_number(LATENT_NUMBER);
12 |   int error = encoder->load("trace_model.ts");
13 | 
14 |   DeepAudioEngine *decoder = new wavae::Decoder;
15 |   decoder->set_latent_number(LATENT_NUMBER);
16 |   error = decoder->load("trace_model.ts");
17 | 
18 |   // latents travel at audio rate: one value per sample and per dimension
19 |   float *inbuffer = new float[BUFFERSIZE];
20 |   float *outbuffer = new float[BUFFERSIZE];
21 |   float *zbuffer = new float[LATENT_NUMBER * BUFFERSIZE];
22 | 
23 |   // LOOP TEST
24 |   for (int i(0); i < 100; i++) {
25 |     std::cout << i << std::endl;
26 |     encoder->perform(inbuffer, zbuffer, BUFFERSIZE);
27 |     decoder->perform(zbuffer, outbuffer, BUFFERSIZE);
28 |   }
29 | 
30 |   return 0;
31 | }
--------------------------------------------------------------------------------
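One way to see the buffer bookkeeping in `test.cpp` above and in the engine below: latents are shipped at audio rate, each latent frame being held for `DIM_REDUCTION_FACTOR` samples on the way out and averaged back down on the way in. A minimal PyTorch sketch of that round trip (illustrative, not part of the library):

```python
import torch

LATENT, BUFFER, FACTOR = 16, 2048, 512
z = torch.randn(1, LATENT, BUFFER // FACTOR)   # 4 latent frames per buffer

# Encoder::perform: hold each latent value for FACTOR samples (audio rate)
z_audio_rate = z.repeat_interleave(FACTOR)
assert z_audio_rate.numel() == LATENT * BUFFER

# Decoder::perform: average each FACTOR-sample run back into one latent frame
z_back = z_audio_rate.reshape(1, LATENT, -1, FACTOR).mean(-1)
assert torch.allclose(z_back, z)               # exact inverse for constant runs
```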
/realtime/libwavae/src/wavae.cpp:
--------------------------------------------------------------------------------
1 | #include "wavae.h"
2 | #include "deepAudioEngine.h"
3 | #include <iostream>
4 | #include <torch/script.h>
5 | 
6 | #define DEVICE torch::kCUDA
7 | #define CPU torch::kCPU
8 | 
9 | // ENCODER /////////////////////////////////////////////////////////
10 | 
11 | wavae::Encoder::Encoder() {
12 |   model_loaded = 0;
13 |   at::init_num_threads();
14 | }
15 | 
16 | void wavae::Encoder::set_latent_number(int n) { latent_number = n; }
17 | 
18 | void wavae::Encoder::perform(float *in_buffer, float *out_buffer,
19 |                              int dsp_vec_size) {
20 |   torch::NoGradGuard no_grad;
21 | 
22 |   if (model_loaded) {
23 | 
24 |     auto tensor = torch::from_blob(in_buffer, {1, dsp_vec_size});
25 |     tensor = tensor.to(DEVICE);
26 | 
27 |     std::vector<torch::jit::IValue> input;
28 |     input.push_back(tensor);
29 | 
30 |     auto out_tensor = model.get_method("encode")(std::move(input)).toTensor();
31 | 
32 |     out_tensor = out_tensor.repeat_interleave(DIM_REDUCTION_FACTOR);
33 |     out_tensor = out_tensor.to(CPU);
34 | 
35 |     auto out = out_tensor.contiguous().data_ptr<float>();
36 | 
37 |     for (int i(0); i < latent_number * dsp_vec_size; i++) {
38 |       out_buffer[i] = out[i];
39 |     }
40 | 
41 |   } else {
42 | 
43 |     for (int i(0); i < latent_number * dsp_vec_size; i++) {
44 |       out_buffer[i] = 0;
45 |     }
46 |   }
47 | }
48 | 
49 | int wavae::Encoder::load(std::string name) {
50 |   try {
51 |     model = torch::jit::load(name);
52 |     model.to(DEVICE);
53 |     model_loaded = 1;
54 |     return 0;
55 |   } catch (const std::exception &e) {
56 |     std::cerr << e.what() << '\n';
57 |     return 1;
58 |   }
59 | }
60 | 
61 | // DECODER /////////////////////////////////////////////////////////
62 | 
63 | wavae::Decoder::Decoder() {
64 |   model_loaded = 0;
65 |   at::init_num_threads();
66 | }
67 | 
68 | void wavae::Decoder::set_latent_number(int n) { latent_number = n; }
69 | 
70 | void wavae::Decoder::perform(float *in_buffer, float *out_buffer,
71 |                              int dsp_vec_size) {
72 | 
73 |   torch::NoGradGuard no_grad;
74 | 
75 |   if (model_loaded) {
76 | 
77 |     auto tensor = torch::from_blob(in_buffer, {1, latent_number, dsp_vec_size});
78 |     tensor =
79 |         tensor.reshape({1, latent_number, -1, DIM_REDUCTION_FACTOR}).mean(-1);
80 |     tensor = tensor.to(DEVICE);
81 | 
82 |     std::vector<torch::jit::IValue> input;
83 |     input.push_back(tensor);
84 | 
85 |     auto out_tensor = model.get_method("decode")(std::move(input))
86 |                           .toTensor()
87 |                           .reshape({-1})
88 |                           .contiguous();
89 | 
90 |     out_tensor = out_tensor.to(CPU);
91 | 
92 |     auto out = out_tensor.data_ptr<float>();
93 | 
94 |     for (int i(0); i < dsp_vec_size; i++) {
95 |       out_buffer[i] = out[i];
96 |     }
97 |   } else {
98 |     for (int i(0); i < dsp_vec_size; i++) {
99 |       out_buffer[i] = 0;
100 |     }
101 |   }
102 | }
103 | 
104 | int wavae::Decoder::load(std::string name) {
105 |   try {
106 |     model = torch::jit::load(name);
107 |     model.to(DEVICE);
108 |     model_loaded = 1;
109 |     return 0;
110 |   } catch (const std::exception &e) {
111 |     std::cerr << e.what() << '\n';
112 |     return 1;
113 |   }
114 | }
115 | 
116 | extern "C" {
117 | DeepAudioEngine *get_encoder() { return new wavae::Encoder; }
118 | DeepAudioEngine *get_decoder() { return new wavae::Decoder; }
119 | }
--------------------------------------------------------------------------------
/realtime/libwavae/src/wavae.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "deepAudioEngine.h"
3 | #include <string>
4 | #include <torch/script.h>
5 | 
6 | namespace wavae {
7 | 
8 | class Encoder : public DeepAudioEngine {
9 | public:
10 |   Encoder();
11 |   void perform(float *in_buffer, float *out_buffer, int dsp_vec_size) override;
12 |   int load(std::string name) override;
13 |   void set_latent_number(int n) override;
14 | 
15 | protected:
16 |   int model_loaded;
17 |   int latent_number;
18 |   torch::jit::script::Module model;
19 | };
20 | 
21 | class Decoder : public DeepAudioEngine {
22 | public:
23 |   Decoder();
24 |   void perform(float *in_buffer, float *out_buffer, int dsp_vec_size) override;
25 |   int load(std::string name) override;
26 |   void set_latent_number(int n) override;
27 | 
28 | protected:
29 |   int model_loaded;
30 |   int latent_number;
31 |   torch::jit::script::Module model;
32 | };
33 | 
34 | } // namespace wavae
35 | 
--------------------------------------------------------------------------------
/realtime/src/decoder.cpp:
--------------------------------------------------------------------------------
1 | #include "../libwavae/src/deepAudioEngine.h"
2 | #include "cstring"
3 | #include "dlfcn.h"
4 | #include <m_pd.h>
5 | #include "pthread.h"
6 | #include "sched.h"
7 | #include "thread"
8 | #include <cstdio>
9 | 
10 | #define DAE DeepAudioEngine
11 | 
12 | static t_class *decoder_tilde_class;
13 | 
14 | typedef struct _decoder_tilde {
15 |   t_object x_obj;
16 |   t_sample f;
17 | 
18 |   // OBJECT ATTRIBUTES
19 |   int loaded, latent_number, buffer_size, activated;
20 |   float *in_buffer, *out_buffer, fadein;
21 |   std::thread *worker;
22 |   DAE *model;
23 | 
24 |   // DSP RELATED MEMORY MAPS
25 |   float **dsp_in_vec, *dsp_out_vec;
26 |   int dsp_vec_size;
27 | 
28 | } t_decoder_tilde;
29 | 
30 | void perform(t_decoder_tilde *x) {
31 |   // SET THREAD TO REALTIME PRIORITY
32 |   pthread_t this_thread = pthread_self();
33 |   struct sched_param params;
34 |   params.sched_priority = sched_get_priority_max(SCHED_FIFO);
35 |   int ret = pthread_setschedparam(this_thread, SCHED_FIFO, &params);
36 | 
37 |   // COMPUTATION
38 |   x->model->perform(x->in_buffer, x->out_buffer, x->buffer_size);
39 | }
40 | 
41 | t_int *decoder_tilde_perform(t_int *w) {
42 |   t_decoder_tilde *x = (t_decoder_tilde *)w[1];
43 |   if (x->dsp_vec_size != x->buffer_size) {
44 |     char error[80];
45 |     sprintf(error, "decoder: expecting buffer %d, got %d", x->buffer_size,
46 |             x->dsp_vec_size);
47 |     post(error);
48 |     for (int i(0); i < x->dsp_vec_size; i++) {
49 |       x->dsp_out_vec[i] = 0;
50 |     }
51 |   } else if (x->activated == 0) {
52 |     for (int i(0); i < x->dsp_vec_size; i++) {
53 |       x->dsp_out_vec[i] = 0;
54 |     }
55 |   } else {
56 |     // WAIT FOR PREVIOUS PROCESS TO END
57 |     if (x->worker) {
58 |       x->worker->join();
59 |     }
60 | 
61 |     // COPY INPUT BUFFER TO OBJECT
62 |     for (int d(0); d < x->latent_number; d++) {
63 |       memcpy(x->in_buffer + (d * x->buffer_size), x->dsp_in_vec[d],
64 |              x->buffer_size * sizeof(float));
65 |     }
66 | 
67 |     // COPY PREVIOUS OUTPUT BUFFER TO PD
68 |     memcpy(x->dsp_out_vec, x->out_buffer, x->buffer_size * sizeof(float));
69 | 
70 |     // FADE IN
71 |     if (x->fadein < .99) {
72 |       for (int i(0); i < x->buffer_size; i++) {
73 |         x->dsp_out_vec[i] *= x->fadein;
74 |         x->fadein = x->loaded ? x->fadein * .99999 + 0.00001 : x->fadein;
75 |       }
76 |     }
77 | 
78 |     // START NEXT COMPUTATION
79 |     x->worker = new std::thread(perform, x);
80 |   }
81 |   return w + 2;
82 | }
83 | 
84 | void decoder_tilde_dsp(t_decoder_tilde *x, t_signal **sp) {
85 |   x->dsp_vec_size = sp[0]->s_n;
86 |   for (int i(0); i < x->latent_number; i++) {
87 |     x->dsp_in_vec[i] = sp[i]->s_vec;
88 |   }
89 |   x->dsp_out_vec = sp[x->latent_number]->s_vec;
90 |   dsp_add(decoder_tilde_perform, 1, x);
91 | }
92 | 
93 | void decoder_tilde_free(t_decoder_tilde *x) {
94 |   if (x->worker) {
95 |     x->worker->join();
96 |   }
97 |   delete[] x->in_buffer;
98 |   delete[] x->out_buffer;
99 |   delete[] x->dsp_in_vec;
100 |   delete x->model;
101 | }
102 | 
103 | void *decoder_tilde_new(t_floatarg latent_number, t_floatarg buffer_size) {
104 |   t_decoder_tilde *x = (t_decoder_tilde *)pd_new(decoder_tilde_class);
105 | 
106 |   x->latent_number = int(latent_number) == 0 ? 16 : int(latent_number);
107 |   x->buffer_size = int(buffer_size) == 0 ? 512 : int(buffer_size);
108 |   x->activated = 1;
109 | 
110 |   outlet_new(&x->x_obj, &s_signal);
111 |   for (int i(1); i < x->latent_number; i++) {
112 |     inlet_new(&x->x_obj, &x->x_obj.ob_pd, &s_signal, &s_signal);
113 |   }
114 | 
115 |   x->in_buffer = new float[x->latent_number * x->buffer_size];
116 |   x->out_buffer = new float[x->buffer_size];
117 | 
118 |   x->worker = NULL;
119 | 
120 |   x->loaded = 0;
121 |   x->fadein = 0;
122 | 
123 |   void *hndl = dlopen("/usr/lib/libwavae.so", RTLD_LAZY);
124 |   if (!hndl) {
125 |     hndl = dlopen("./libwavae/libwavae.so", RTLD_LAZY);
126 |     post("Using local version of libwavae");
127 |   }
128 | 
129 |   x->model = reinterpret_cast<DAE *(*)()>(dlsym(hndl, "get_decoder"))();
130 |   x->model->set_latent_number(x->latent_number);
131 | 
132 |   x->dsp_in_vec = new float *[x->latent_number];
133 | 
134 |   return (void *)x;
135 | }
136 | 
137 | void decoder_tilde_load(t_decoder_tilde *x, t_symbol *sym) {
138 |   x->loaded = 0;
139 |   x->fadein = 0;
140 | 
141 |   int statut = x->model->load(sym->s_name);
142 | 
143 |   if (statut == 0) {
144 |     x->loaded = 1;
145 |     post("decoder loaded");
146 |   } else {
147 |     post("decoder failed loading model");
148 |   }
149 | }
150 | 
151 | void decoder_tilde_activate(t_decoder_tilde *x, t_floatarg arg) {
152 |   x->activated = int(arg);
153 | }
154 | 
155 | extern "C" {
156 | void decoder_tilde_setup(void) {
157 |   decoder_tilde_class =
158 |       class_new(gensym("decoder~"), (t_newmethod)decoder_tilde_new,
159 |                 (t_method)decoder_tilde_free, sizeof(t_decoder_tilde), 0, A_DEFFLOAT, A_DEFFLOAT, 0);
160 | 
161 |   class_addmethod(decoder_tilde_class, (t_method)decoder_tilde_dsp,
162 |                   gensym("dsp"), A_CANT, 0);
163 |   class_addmethod(decoder_tilde_class, (t_method)decoder_tilde_load,
164 |                   gensym("load"), A_SYMBOL, A_NULL);
165 |   class_addmethod(decoder_tilde_class, (t_method)decoder_tilde_activate,
166 |                   gensym("activate"), A_DEFFLOAT, A_NULL);
167 | 
168 |   CLASS_MAINSIGNALIN(decoder_tilde_class, t_decoder_tilde, f);
169 | }
170 | }
171 | 
--------------------------------------------------------------------------------
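A side note on the fade-in recursion in `decoder_tilde_perform` above (`fadein = fadein * .99999 + 0.00001`): the gap to full gain decays by a factor of 0.99999 per sample, so the ramp is exponential. A rough estimate of its length, in plain Python (illustrative only):

```python
import math

# 1 - fadein shrinks by 0.99999 each sample; the loop stops scaling once
# fadein >= .99, i.e. once the gap has dropped below 0.01.
n = math.log(0.01) / math.log(0.99999)
print(round(n))  # ~460515 samples, i.e. roughly 10 s at 48 kHz
```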
/realtime/src/encoder.cpp:
--------------------------------------------------------------------------------
1 | #include "../libwavae/src/deepAudioEngine.h"
2 | #include "cstring"
3 | #include "dlfcn.h"
4 | #include <m_pd.h>
5 | #include "pthread.h"
6 | #include "sched.h"
7 | #include "thread"
8 | #include <cstdio>
9 | #include <string>
10 | 
11 | #define DAE DeepAudioEngine
12 | 
13 | static t_class *encoder_tilde_class;
14 | 
15 | typedef struct _encoder_tilde {
16 |   t_object x_obj;
17 |   t_sample f;
18 | 
19 |   // OBJECT ATTRIBUTES
20 |   int latent_number, buffer_size, activated;
21 |   float *in_buffer, *out_buffer;
22 |   std::thread *worker;
23 |   DAE *model;
24 | 
25 |   // DSP RELATED MEMORY MAPS
26 |   float *dsp_in_vec, **dsp_out_vec;
27 |   int dsp_vec_size;
28 | 
29 | } t_encoder_tilde;
30 | 
31 | void perform(t_encoder_tilde *x) {
32 |   // SET THREAD TO REALTIME PRIORITY
33 |   pthread_t this_thread = pthread_self();
34 |   struct sched_param params;
35 |   params.sched_priority = sched_get_priority_max(SCHED_FIFO);
36 |   int ret = pthread_setschedparam(this_thread, SCHED_FIFO, &params);
37 | 
38 |   // COMPUTATION
39 |   x->model->perform(x->in_buffer, x->out_buffer, x->buffer_size);
40 | }
41 | 
42 | t_int *encoder_tilde_perform(t_int *w) {
43 |   t_encoder_tilde *x = (t_encoder_tilde *)w[1];
44 | 
45 |   if (x->dsp_vec_size != x->buffer_size) {
46 |     char error[80];
47 |     sprintf(error, "encoder: expecting buffer %d, got %d", x->buffer_size,
48 |             x->dsp_vec_size);
49 |     post(error);
50 |     for (int d(0); d < x->latent_number; d++) {
51 |       for (int i(0); i < x->dsp_vec_size; i++) {
52 |         x->dsp_out_vec[d][i] = 0;
53 |       }
54 |     }
55 |   } else if (x->activated == 0) {
56 |     for (int d(0); d < x->latent_number; d++) {
57 |       for (int i(0); i < x->dsp_vec_size; i++) {
58 |         x->dsp_out_vec[d][i] = 0;
59 |       }
60 |     }
61 |   } else {
62 |     // WAIT FOR PREVIOUS PROCESS TO END
63 |     if (x->worker) {
64 |       x->worker->join();
65 |     }
66 | 
67 |     // COPY INPUT BUFFER TO OBJECT
68 |     memcpy(x->in_buffer, x->dsp_in_vec, x->dsp_vec_size * sizeof(float));
69 | 
70 |     // COPY PREVIOUS OUTPUT BUFFER TO PD
71 |     for (int d(0); d < x->latent_number; d++) {
72 |       memcpy(x->dsp_out_vec[d], x->out_buffer + (d * x->dsp_vec_size),
73 |              x->dsp_vec_size * sizeof(float));
74 |     }
75 | 
76 |     // START NEXT COMPUTATION
77 |     x->worker = new std::thread(perform, x);
78 |   }
79 |   return w + 2;
80 | }
81 | 
82 | void encoder_tilde_dsp(t_encoder_tilde *x, t_signal **sp) {
83 |   x->dsp_in_vec = sp[0]->s_vec;
84 |   x->dsp_vec_size = sp[0]->s_n;
85 |   for (int i(0); i < x->latent_number; i++) {
86 |     x->dsp_out_vec[i] = sp[i + 1]->s_vec;
87 |   }
88 |   dsp_add(encoder_tilde_perform, 1, x);
89 | }
90 | 
91 | void encoder_tilde_free(t_encoder_tilde *x) {
92 |   if (x->worker) {
93 |     x->worker->join();
94 |   }
95 |   delete[] x->in_buffer;
96 |   delete[] x->out_buffer;
97 |   delete[] x->dsp_out_vec;
98 |   delete x->model;
99 | }
100 | 
101 | void *encoder_tilde_new(t_floatarg latent_number, t_floatarg buffer_size) {
102 |   t_encoder_tilde *x = (t_encoder_tilde *)pd_new(encoder_tilde_class);
103 | 
104 |   x->latent_number = int(latent_number) == 0 ? 16 : int(latent_number);
105 |   x->buffer_size = int(buffer_size) == 0 ? 512 : int(buffer_size);
106 |   x->activated = 1;
107 | 
108 |   for (int i(0); i < x->latent_number; i++) {
109 |     outlet_new(&x->x_obj, &s_signal);
110 |   }
111 | 
112 |   x->in_buffer = new float[x->buffer_size];
113 |   x->out_buffer = new float[x->latent_number * x->buffer_size];
114 | 
115 |   x->worker = NULL;
116 | 
117 |   void *hndl = dlopen("/usr/lib/libwavae.so", RTLD_LAZY);
118 |   if (!hndl) {
119 |     hndl = dlopen("./libwavae/libwavae.so", RTLD_LAZY);
120 |     post("Using local version of libwavae");
121 |   }
122 | 
123 |   x->model = reinterpret_cast<DAE *(*)()>(dlsym(hndl, "get_encoder"))();
124 |   x->model->set_latent_number(x->latent_number);
125 | 
126 |   x->dsp_out_vec = new float *[x->latent_number];
127 | 
128 |   return (void *)x;
129 | }
130 | 
131 | void encoder_tilde_load(t_encoder_tilde *x, t_symbol *sym) {
132 |   int statut = x->model->load(sym->s_name);
133 | 
134 |   if (statut == 0) {
135 |     post("encoder loaded");
136 |   } else {
137 |     post("encoder failed loading model");
138 |   }
139 | }
140 | 
141 | void encoder_tilde_activate(t_encoder_tilde *x, t_floatarg arg) {
142 |   x->activated = int(arg);
143 | }
144 | 
145 | extern "C" {
146 | void encoder_tilde_setup(void) {
147 |   encoder_tilde_class =
148 |       class_new(gensym("encoder~"), (t_newmethod)encoder_tilde_new,
149 |                 (t_method)encoder_tilde_free, sizeof(t_encoder_tilde), 0, A_DEFFLOAT, A_DEFFLOAT, 0);
150 | 
151 |   class_addmethod(encoder_tilde_class, (t_method)encoder_tilde_dsp,
152 |                   gensym("dsp"), A_CANT, 0);
153 |   class_addmethod(encoder_tilde_class, (t_method)encoder_tilde_load,
154 |                   gensym("load"), A_SYMBOL, A_NULL);
155 |   class_addmethod(encoder_tilde_class, (t_method)encoder_tilde_activate,
156 |                   gensym("activate"), A_DEFFLOAT, A_NULL);
157 | 
158 |   CLASS_MAINSIGNALIN(encoder_tilde_class, t_encoder_tilde, f);
159 | }
160 | }
161 | 
--------------------------------------------------------------------------------
/realtime/src/signal_in_out_base.c:
-------------------------------------------------------------------------------- 1 | #include "m_pd.h" 2 | 3 | static t_class *pan_tilde_class; 4 | 5 | typedef struct _pan_tilde 6 | { 7 | t_object x_obj; 8 | t_sample f_pan; 9 | t_sample f; 10 | 11 | t_inlet *x_in2; 12 | t_inlet *x_in3; 13 | t_outlet *x_out; 14 | } t_pan_tilde; 15 | 16 | t_int *pan_tilde_perform(t_int *w) 17 | { 18 | t_pan_tilde *x = (t_pan_tilde *)(w[1]); 19 | t_sample *in1 = (t_sample *)(w[2]); 20 | t_sample *in2 = (t_sample *)(w[3]); 21 | t_sample *out = (t_sample *)(w[4]); 22 | int n = (int)(w[5]); 23 | t_sample f_pan = (x->f_pan < 0) ? 0.0 : (x->f_pan > 1) ? 1.0 : x->f_pan; 24 | 25 | while (n--) 26 | *out++ = (*in1++) * (1 - f_pan) + (*in2++) * f_pan; 27 | 28 | return (w + 6); 29 | } 30 | 31 | void pan_tilde_dsp(t_pan_tilde *x, t_signal **sp) 32 | { 33 | dsp_add(pan_tilde_perform, 5, x, 34 | sp[0]->s_vec, sp[1]->s_vec, sp[2]->s_vec, sp[0]->s_n); 35 | } 36 | 37 | void pan_tilde_free(t_pan_tilde *x) 38 | { 39 | inlet_free(x->x_in2); 40 | inlet_free(x->x_in3); 41 | outlet_free(x->x_out); 42 | } 43 | 44 | void *pan_tilde_new(t_floatarg f) 45 | { 46 | t_pan_tilde *x = (t_pan_tilde *)pd_new(pan_tilde_class); 47 | 48 | x->f_pan = f; 49 | 50 | x->x_in2 = inlet_new(&x->x_obj, &x->x_obj.ob_pd, &s_signal, &s_signal); 51 | x->x_in3 = floatinlet_new(&x->x_obj, &x->f_pan); 52 | x->x_out = outlet_new(&x->x_obj, &s_signal); 53 | 54 | return (void *)x; 55 | } 56 | 57 | void pan_tilde_setup(void) 58 | { 59 | pan_tilde_class = class_new(gensym("pan~"), 60 | (t_newmethod)pan_tilde_new, 61 | 0, sizeof(t_pan_tilde), 62 | CLASS_DEFAULT, 63 | A_DEFFLOAT, 0); 64 | 65 | class_addmethod(pan_tilde_class, 66 | (t_method)pan_tilde_dsp, gensym("dsp"), A_CANT, 0); 67 | CLASS_MAINSIGNALIN(pan_tilde_class, t_pan_tilde, f); 68 | } 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | lmdb>=0.98 3 | tqdm>=4.46.0 4 | scipy>=1.4.1 5 | librosa>=0.7.2 6 | effortless_config>=0.6.1 7 | numpy>=1.18.4 8 | matplotlib>=3.2.1 9 | scikit_learn>=0.23.1 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .hparams import config 2 | 3 | from .cached_padding import CachedConv1d, cache_pad, CachedConvTranspose1d 4 | 5 | from .gan_modules import Generator, Discriminator 6 | from .melencoder import MelEncoder 7 | from .vanilla_vae import TopVAE 8 | from .domain_adaptation import Classifier 9 | 10 | from .model import get_model 11 | 12 | from .data import preprocess, Loader, get_flattening_function, gaussian_cdf, log_loudness, LogLoudness 13 | from .train_utils import train_step_melgan, train_step_vanilla 14 | from .pca_utils import compute_pca -------------------------------------------------------------------------------- /src/cached_padding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import config 4 | 5 | SCRIPT = True 6 | 7 | 8 | def cache_pad(*args, **kwargs): 9 | if SCRIPT: 10 | return torch.jit.script(CachedPadding(*args, **kwargs)) 11 | else: 12 | return CachedPadding(*args, **kwargs) 13 | 14 | 15 | class CachedPadding(nn.Module): 16 | """ 17 | Cached padding (buffer based inference) 18 | 19 | Replace nn.Conv1d(C,x,x,padding=P) with 20 | 21 | nn.Sequential( 22 | CachedPadding(P, C, True), 23 | nn.Conv1d(C,x,x,padding=0) 24 | ) 25 | 26 | And replace nn.ConvTranspose1d(C,x,2 * r,stride=r, padding=r//2) with 27 | 28 | nn.Sequential( 29 | CachedPadding(1, C, True), 30 | nn.ConvTranspose1d(C,x,2 * r,stride=r, padding=r//2 + r) 31 | ) 32 | """ 33 | def __init__(self, 34 | padding, 35 | channels, 36 | cache=False, 37 | pad_mode="constant", 38 | crop=False): 39 | super().__init__() 40 | self.padding = padding 41 | self.pad_mode = pad_mode 42 | 43 | left_pad = torch.zeros(1, channels, padding) 44 | self.register_buffer("left_pad", left_pad) 45 | 46 | self.cache = cache 47 | self.crop = crop 48 | 49 | def forward(self, x): 50 | if self.cache: 51 | padded_x = torch.cat([self.left_pad, x], -1) 52 | self.left_pad = padded_x[..., -self.padding:] 53 | if self.crop: 54 | padded_x = padded_x[..., :-(self.padding)] 55 | else: 56 | padded_x = nn.functional.pad( 57 | x, (self.padding // 2, self.padding // 2), mode=self.pad_mode) 58 | return padded_x 59 | 60 | def reset(self): 61 | self.left_pad.zero_() 62 | 63 | def __repr__(self): 64 | return f"CachedPadding(padding={self.padding}, cache={self.cache})" 65 | 66 | 67 | class CachedConv1d(nn.Module): 68 | def __init__(self, 69 | in_chan, 70 | out_chan, 71 | kernel, 72 | stride, 73 | padding, 74 | dilation=(1, ), 75 | cache=False, 76 | pad_mode="constant", 77 | weight_norm=False): 78 | super().__init__() 79 | self.pad = cache_pad(2 * padding, in_chan, cache, pad_mode) 80 | self.conv = nn.Conv1d(in_chan, 81 | out_chan, 82 | kernel, 83 | stride, 84 | dilation=dilation) 85 | if weight_norm: 86 | self.conv = nn.utils.weight_norm(self.conv) 87 | 88 | def forward(self, x): 89 | x = self.pad(x) 90 | x = self.conv(x) 91 | return x 92 | 93 | 94 | class CachedConvTranspose1d(nn.Module): 95 | def __init__(self, 96 | in_chan, 97 | out_chan, 98 | kernel, 99 | stride, 100 | dilation=(1, ), 101 | cache=False, 102 | pad_mode="constant", 103 | weight_norm=False): 104 | super().__init__() 105 | assert kernel == 2 * stride, "WESH" 106 | self.cache = cache 107 | self.stride = stride 108 | self.pad = cache_pad(1, in_chan, cache, pad_mode) 109 | self.conv = nn.ConvTranspose1d(in_chan, 110 | out_chan, 111 | kernel_size=kernel, 112 | stride=stride, 113 | padding=0) 114 | if weight_norm: 115 | self.conv = nn.utils.weight_norm(self.conv) 116 | 117 | def forward(self, x): 118 | if self.cache: 119 | x = self.pad(x) 120 | 121 | x = self.conv(x) 122 | 123 | if self.cache: 124 | x = x[..., self.stride:-self.stride] 125 | else: 126 | x = x[..., :-self.stride] 127 | 128 | return x -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | from udls import SimpleDataset 2 | import librosa as li 3 | import torch 4 | import torch.nn as nn 5 | 6 | import numpy as np 7 | from sklearn.mixture import GaussianMixture 8 | from scipy.special import erf 9 | 10 | from tqdm import tqdm 11 | 12 | from . 
import config 13 | 14 | 15 | class LogLoudness(nn.Module): 16 | def __init__(self, size, eps): 17 | super().__init__() 18 | win = torch.hann_window(size) / torch.mean(torch.hann_window(size)) 19 | win = win.reshape(1, 1, -1) 20 | self.register_buffer("win", win) 21 | self.eps = eps 22 | self.logeps = np.log(self.eps) 23 | self.size = size 24 | 25 | def forward(self, x): 26 | x = torch.stack(torch.split(x, self.size, -1), 1) 27 | x *= self.win 28 | logrms = .5 * torch.log(torch.clamp(torch.mean(x**2, -1), self.eps, 1)) 29 | logrms = (self.logeps - 2 * logrms) / self.logeps 30 | return logrms.unsqueeze(1) 31 | 32 | 33 | def log_loudness(x, size, eps=1e-4): 34 | x_win = x.reshape(x.shape[0], -1, size) 35 | win = np.hanning(size) 36 | win /= np.mean(win) 37 | win = win.reshape(1, 1, -1) 38 | x_win = x_win * win 39 | log_rms = .5 * np.log(np.clip(np.mean(x_win**2, -1), eps, 1)) 40 | log_rms = (np.log(eps) - 2 * log_rms) / np.log(eps) 41 | return log_rms.reshape(log_rms.shape[0], 1, -1) 42 | 43 | 44 | def gaussian_cdf(weights, means, stds): 45 | cdf = lambda x: np.sum([ 46 | w * .5 * (1 + erf((x - m) / (s * np.sqrt(2)))) 47 | for w, m, s in zip(weights, means, stds) 48 | ], 0) 49 | return cdf 50 | 51 | 52 | def get_flattening_function(x, n_mixture=10): 53 | # FIT GMM ON DATA 54 | gmm = GaussianMixture(n_mixture).fit(x.reshape(-1, 1)) 55 | weights, means, variances = gmm.weights_, gmm.means_, gmm.covariances_ 56 | weights = weights.reshape(-1) 57 | means = means.reshape(-1) 58 | stds = np.sqrt(variances.reshape(-1)) 59 | 60 | return weights, means, stds 61 | 62 | 63 | def preprocess(name): 64 | try: 65 | x = li.load(name, config.SAMPRATE)[0] 66 | except KeyboardInterrupt: 67 | exit() 68 | except: 69 | return None 70 | border = len(x) % config.N_SIGNAL 71 | 72 | if len(x) < config.N_SIGNAL: 73 | x = np.pad(x, (0, config.N_SIGNAL - len(x))) 74 | 75 | elif border: 76 | x = x[:-border] 77 | 78 | x = x.reshape(-1, config.N_SIGNAL) 79 | 80 | if config.TYPE == "vanilla": 81 | log_rms = log_loudness(x, config.HOP_LENGTH * np.prod(config.RATIOS)) 82 | x = zip(x, log_rms) 83 | 84 | return x 85 | 86 | 87 | class Loader(torch.utils.data.Dataset): 88 | def __init__(self, cat, config=config): 89 | super().__init__() 90 | if config.WAV_LOC is not None: 91 | wav_loc = config.WAV_LOC.split(",") 92 | else: 93 | wav_loc = None 94 | self.dataset = SimpleDataset(config.LMDB_LOC, 95 | folder_list=config.WAV_LOC, 96 | file_list=config.FILE_LIST, 97 | preprocess_function=preprocess, 98 | map_size=1e11) 99 | self.cat = cat 100 | 101 | def __len__(self): 102 | return len(self.dataset) 103 | 104 | def __getitem__(self, idx): 105 | if config.TYPE == "vanilla": 106 | sample = [] 107 | loudness = [] 108 | for i in range(self.cat): 109 | s, l = self.dataset[(idx + i) % self.__len__()] 110 | sample.append(torch.from_numpy(s).float()) 111 | loudness.append(torch.from_numpy(l).float()) 112 | sample = torch.cat(sample, -1) 113 | loudness = torch.cat(loudness, -1) 114 | return sample, loudness 115 | else: 116 | return self.dataset[idx] 117 | -------------------------------------------------------------------------------- /src/domain_adaptation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import config 4 | 5 | 6 | class GradientReverse(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, x, lam): 9 | ctx.lam = lam 10 | return x 11 | 12 | @staticmethod 13 | def backward(ctx, grad): 14 | return -ctx.lam * grad, None 15 | 16 | 17 | class Classifier(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | conv = [] 21 | for i in range(len(config.CLASSIFIER_CHANNELS) - 1): 22 | conv.append( 23 | nn.Conv1d(config.CLASSIFIER_CHANNELS[i], 24 | config.CLASSIFIER_CHANNELS[i + 1], 25 | 5, 26 | padding=5 // 2)) 27 | if i != len(config.CLASSIFIER_CHANNELS) - 2: 28 | conv.append(nn.ReLU()) 29 | conv.append(nn.BatchNorm1d(config.CLASSIFIER_CHANNELS[i + 1])) 30 | 31 | lin = [] 32 | for i in range(len(config.CLASSIFIER_LIN_SIZE) - 1): 33 | lin.append( 34 | nn.Linear(config.CLASSIFIER_LIN_SIZE[i], 35 | config.CLASSIFIER_LIN_SIZE[i + 1])) 36 | if i != len(config.CLASSIFIER_LIN_SIZE) - 2: 37 | lin.append(nn.ReLU()) 38 | 39 | self.gradient_reversal = GradientReverse.apply 40 | self.conv = nn.Sequential(*conv) 41 | self.lin = nn.Sequential(*lin) 42 | 43 | def forward(self, z, lam=1): 44 | bs = z.shape[0] 45 | z = self.gradient_reversal(z, lam) 46 | z = self.conv(z) 47 | z = z.permute(0, 2, 1).reshape(-1, config.CLASSIFIER_LIN_SIZE[0]) 48 | z = self.lin(z) 49 | z = z.reshape(bs, -1, config.CLASSIFIER_LIN_SIZE[-1]).permute(0, 2, 1) 50 | mean, logvar = torch.split(z, 1, 1) 51 | return mean, logvar -------------------------------------------------------------------------------- /src/gan_modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from librosa.filters import mel as librosa_mel_fn 5 | from torch.nn.utils import weight_norm 6 | import numpy as np 7 | 8 | from . 
import config, CachedConvTranspose1d, CachedConv1d, cache_pad 9 | 10 | 11 | def weights_init(m): 12 | classname = m.__class__.__name__ 13 | if classname == "CachedConv1d" or classname == "CachedConvTranspose1d": 14 | m.conv.weight.data.normal_(0.0, 0.02) 15 | elif classname.find("Conv") != -1: 16 | m.weight.data.normal_(0.0, 0.02) 17 | elif classname.find("BatchNorm2d") != -1: 18 | m.weight.data.normal_(1.0, 0.02) 19 | m.bias.data.fill_(0) 20 | 21 | 22 | def WNConv1d(*args, **kwargs): 23 | return weight_norm(nn.Conv1d(*args, **kwargs)) 24 | 25 | 26 | def WNConvTranspose1d(*args, **kwargs): 27 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 28 | 29 | 30 | class ResnetBlock(nn.Module): 31 | def __init__(self, dim, dilation=1, use_cached_padding=False): 32 | super().__init__() 33 | self.block = nn.Sequential( 34 | nn.LeakyReLU(0.2), 35 | CachedConv1d(dim, dim, 3, 1, dilation, dilation, 36 | use_cached_padding, "reflect", True), 37 | nn.LeakyReLU(0.2), 38 | WNConv1d(dim, dim, kernel_size=1), 39 | ) 40 | self.shortcut = WNConv1d(dim, dim, kernel_size=1) 41 | self.dilation = dilation 42 | self.use_cached_padding = use_cached_padding 43 | self.residual_padding = cache_pad(dilation, 44 | dim, 45 | cache=use_cached_padding, 46 | crop=True) 47 | 48 | def forward(self, x): 49 | blockout = self.block(x) 50 | shortcut = self.shortcut(x) 51 | if self.use_cached_padding: 52 | shortcut = self.residual_padding(shortcut) 53 | return blockout + shortcut 54 | 55 | 56 | class Generator(nn.Module): 57 | def __init__(self, 58 | input_size=config.INPUT_SIZE, 59 | ngf=config.NGF, 60 | n_residual_layers=config.N_RES_G, 61 | ratios=config.RATIOS, 62 | use_cached_padding=config.USE_CACHED_PADDING): 63 | 64 | super().__init__() 65 | self.hop_length = np.prod(ratios) 66 | mult = int(2**len(ratios)) 67 | 68 | model = [ 69 | CachedConv1d(input_size, 70 | mult * ngf, 71 | 7, 72 | 1, 73 | 3, 74 | cache=use_cached_padding, 75 | pad_mode="reflect", 76 | weight_norm=True) 77 | ] 78 | 79 | # Upsample to raw audio scale 80 | for i, r in enumerate(ratios): 81 | model += [ 82 | nn.LeakyReLU(0.2), 83 | CachedConvTranspose1d(mult * ngf, 84 | mult * ngf // 2, 85 | r * 2, 86 | r, 87 | cache=use_cached_padding, 88 | weight_norm=True) 89 | ] 90 | 91 | for j in range(n_residual_layers): 92 | model += [ 93 | ResnetBlock(mult * ngf // 2, 94 | dilation=3**j, 95 | use_cached_padding=use_cached_padding) 96 | ] 97 | 98 | mult //= 2 99 | 100 | model += [ 101 | nn.LeakyReLU(0.2), 102 | CachedConv1d(ngf, 103 | 1, 104 | 7, 105 | 1, 106 | 3, 107 | cache=use_cached_padding, 108 | pad_mode="reflect", 109 | weight_norm=True), 110 | nn.Tanh(), 111 | ] 112 | 113 | self.model = nn.Sequential(*model) 114 | self.apply(weights_init) 115 | 116 | def forward(self, x): 117 | x = self.model(x) 118 | return x 119 | 120 | 121 | class NLayerDiscriminator(nn.Module): 122 | def __init__(self, ndf, n_layers, downsampling_factor): 123 | super().__init__() 124 | model = nn.ModuleDict() 125 | 126 | model["layer_0"] = nn.Sequential( 127 | nn.ReflectionPad1d(7), 128 | WNConv1d(1, ndf, kernel_size=15), 129 | nn.LeakyReLU(0.2, True), 130 | ) 131 | 132 | nf = ndf 133 | stride = downsampling_factor 134 | for n in range(1, n_layers + 1): 135 | nf_prev = nf 136 | nf = min(nf * stride, 1024) 137 | 138 | model["layer_%d" % n] = nn.Sequential( 139 | WNConv1d( 140 | nf_prev, 141 | nf, 142 | kernel_size=stride * 10 + 1, 143 | stride=stride, 144 | padding=stride * 5, 145 | groups=nf_prev // 4, 146 | ), 147 | nn.LeakyReLU(0.2, True), 148 | ) 149 | 150 | nf = min(nf * 2, 
1024) 151 | model["layer_%d" % (n_layers + 1)] = nn.Sequential( 152 | WNConv1d(nf_prev, nf, kernel_size=5, stride=1, padding=2), 153 | nn.LeakyReLU(0.2, True), 154 | ) 155 | 156 | model["layer_%d" % (n_layers + 2)] = WNConv1d(nf, 157 | 1, 158 | kernel_size=3, 159 | stride=1, 160 | padding=1) 161 | 162 | self.model = model 163 | 164 | def forward(self, x): 165 | results = [] 166 | for key, layer in self.model.items(): 167 | x = layer(x) 168 | results.append(x) 169 | return results 170 | 171 | 172 | class Discriminator(nn.Module): 173 | def __init__(self, 174 | num_D=config.NUM_D, 175 | ndf=config.NDF, 176 | n_layers=config.N_LAYER_D, 177 | downsampling_factor=config.DOWNSAMP_D): 178 | super().__init__() 179 | self.model = nn.ModuleDict() 180 | for i in range(num_D): 181 | self.model[f"disc_{i}"] = NLayerDiscriminator( 182 | ndf, n_layers, downsampling_factor) 183 | 184 | self.downsample = nn.AvgPool1d(4, 185 | stride=2, 186 | padding=1, 187 | count_include_pad=False) 188 | self.apply(weights_init) 189 | 190 | def forward(self, x): 191 | results = [] 192 | for key, disc in self.model.items(): 193 | results.append(disc(x)) 194 | x = self.downsample(x) 195 | return results 196 | -------------------------------------------------------------------------------- /src/hparams.py: -------------------------------------------------------------------------------- 1 | from effortless_config import Config, setting 2 | 3 | 4 | class config(Config): 5 | groups = ["vanilla", "melgan"] 6 | 7 | TYPE = setting(default="vanilla", vanilla="vanilla", melgan="melgan") 8 | 9 | # MELGAN PARAMETERS 10 | INPUT_SIZE = 128 11 | NGF = 32 12 | N_RES_G = 3 13 | 14 | HOP_LENGTH = 256 15 | 16 | RATIOS = setting(default=[1, 1, 1, 2, 1, 1, 1], 17 | vanilla=[1, 1, 1, 2, 1, 1, 1], 18 | melgan=[8, 8, 2, 2]) 19 | 20 | NUM_D = 3 21 | NDF = 16 22 | N_LAYER_D = 4 23 | DOWNSAMP_D = 4 24 | 25 | # AUTOENCODER 26 | CHANNELS = [128, 256, 256, 512, 512, 512, 128, 32] 27 | KERNEL = 5 28 | EXTRACT_LOUDNESS = False 29 | AUGMENT = setting(default=5, vanilla=5, melgan=1) 30 | 31 | # CLASSIFIER 32 | CLASSIFIER_CHANNELS = [16, 64, 256] 33 | CLASSIFIER_LIN_SIZE = [256, 64, 2] 34 | 35 | # TRAIN PARAMETERS 36 | PATH_PREPEND = "./runs/" 37 | SAMPRATE = 24000 38 | N_SIGNAL = setting(default=2**15, vanilla=2**15, melgan=2**14) 39 | EPOCH = 1000 40 | BATCH = 1 41 | LR = 1e-4 42 | NAME = "untitled" 43 | CKPT = None 44 | 45 | WAV_LOC = None 46 | FILE_LIST = None 47 | LMDB_LOC = "./preprocessed" 48 | 49 | BACKUP = 10000 50 | EVAL = 1000 51 | 52 | # INCREMENTAL GENERATION 53 | USE_CACHED_PADDING = False 54 | BUFFER_SIZE = 1024 55 | TARGET_SR = 48000 56 | 57 | 58 | if __name__ == "__main__": 59 | print(config) 60 | -------------------------------------------------------------------------------- /src/losses_train.md: -------------------------------------------------------------------------------- 1 | # Training melgan 2 | 3 | The loss definition is quite something, so here is creepy copy-pasta of the original training procedure. 
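In equation form, what the snippet below computes (a reading of the code, not part of the original file): with $D_k$ the $k$-th discriminator scale, $D_k^{(j)}$ its $j$-th intermediate feature map, $x$ real and $\hat{x}$ generated audio,

$$\mathcal{L}_D = \sum_{k} \mathbb{E}\big[\mathrm{relu}(1 + D_k(\hat{x}))\big] + \mathbb{E}\big[\mathrm{relu}(1 - D_k(x))\big]$$

$$\mathcal{L}_G = -\sum_{k} \mathbb{E}\big[D_k(\hat{x})\big] + \lambda_{\mathrm{feat}} \cdot \frac{4}{n_D\,(L+1)} \sum_{k=1}^{n_D} \sum_{j} \big\| D_k^{(j)}(\hat{x}) - D_k^{(j)}(x) \big\|_1$$

where $n_D$ is `args.num_D`, $L$ is `args.n_layers_D`, and $\lambda_{\mathrm{feat}}$ is `args.lambda_feat`.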
4 | 5 | ```python 6 | ####################### 7 | # Train Discriminator # 8 | ####################### 9 | D_fake_det = netD(x_pred_t.cuda().detach()) 10 | D_real = netD(x_t.cuda()) 11 | 12 | loss_D = 0 13 | for scale in D_fake_det: 14 | loss_D += F.relu(1 + scale[-1]).mean() 15 | 16 | for scale in D_real: 17 | loss_D += F.relu(1 - scale[-1]).mean() 18 | 19 | netD.zero_grad() 20 | loss_D.backward() 21 | optD.step() 22 | 23 | ################### 24 | # Train Generator # 25 | ################### 26 | D_fake = netD(x_pred_t.cuda()) 27 | 28 | loss_G = 0 29 | for scale in D_fake: 30 | loss_G += -scale[-1].mean() 31 | 32 | loss_feat = 0 33 | feat_weights = 4.0 / (args.n_layers_D + 1) 34 | D_weights = 1.0 / args.num_D 35 | wt = D_weights * feat_weights 36 | for i in range(args.num_D): 37 | for j in range(len(D_fake[i]) - 1): 38 | loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach()) 39 | 40 | netG.zero_grad() 41 | (loss_G + args.lambda_feat * loss_feat).backward() 42 | optG.step() 43 | ``` 44 | 45 | -------------------------------------------------------------------------------- /src/melencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import librosa as li 4 | import numpy as np 5 | from . import config 6 | 7 | module = lambda x: torch.sqrt(x[..., 0]**2 + x[..., 1]**2) 8 | 9 | 10 | class MelEncoder(nn.Module): 11 | def __init__(self, sampling_rate, hop, input_size, center=False): 12 | super().__init__() 13 | self.hop = hop 14 | self.nfft = 2048 15 | 16 | mel = li.filters.mel(sampling_rate, self.nfft, input_size, fmin=80) 17 | mel = torch.from_numpy(mel) 18 | 19 | self.register_buffer("mel", mel) 20 | self.center = center 21 | 22 | def forward(self, x): 23 | if len(x.shape) == 3: 24 | x = x.squeeze(1) 25 | 26 | if not config.USE_CACHED_PADDING: 27 | x = nn.functional.pad(x, (0, self.nfft - self.hop)) 28 | 29 | S = torch.stft(x, self.nfft, self.hop, 512, center=self.center) 30 | S = 2 * module(S) / 512 31 | S_mel = self.mel.matmul(S) 32 | 33 | if self.training: 34 | S_mel = S_mel[..., :x.shape[-1] // self.hop] 35 | return (torch.log10(torch.clamp(S_mel, min=1e-5)) + 5) / 5 -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import Generator, Discriminator, MelEncoder, TopVAE, config, Classifier 4 | 5 | 6 | class Vanilla(nn.Module): 7 | def __init__(self, sampling_rate, hop, ratios, input_size, channels, 8 | kernel, use_cached_padding, extract_loudness): 9 | super().__init__() 10 | self.melencoder = MelEncoder(sampling_rate=sampling_rate, 11 | hop=hop, 12 | input_size=input_size, 13 | center=False) 14 | self.topvae = TopVAE(channels=channels, 15 | kernel=kernel, 16 | ratios=ratios, 17 | use_cached_padding=use_cached_padding, 18 | extract_loudness=extract_loudness) 19 | 20 | if extract_loudness: 21 | self.classifier = Classifier() 22 | 23 | def forward(self, x, loudness=None): 24 | S = self.melencoder(x) 25 | y, mean_y, logvar_y, mean_z, logvar_z = self.topvae(S, loudness) 26 | return y, mean_y, logvar_y, mean_z, logvar_z 27 | 28 | 29 | class melGAN(nn.Module): 30 | def __init__(self, sampling_rate, hop, ratios, input_size, ngf, n_res_g, 31 | use_cached_padding): 32 | super().__init__() 33 | self.encoder = MelEncoder(sampling_rate=sampling_rate, 34 | hop=hop, 35 | input_size=input_size, 36 | center=False) 37 | self.decoder = Generator(input_size=input_size, 38 | ngf=ngf, 39 | n_residual_layers=n_res_g, 40 | ratios=ratios, 41 | use_cached_padding=use_cached_padding) 42 | 43 | def forward(self, x, mel_encoded=False): 44 | if mel_encoded: 45 | mel = x 46 | else: 47 | mel = self.encoder(x) 48 | 49 | y = self.decoder(mel) 50 | return y 51 | 52 | 53 | def get_model(config=config): 54 | if config.TYPE == "melgan": 55 | return melGAN(sampling_rate=config.SAMPRATE, 56 | hop=config.HOP_LENGTH, 57 | ratios=config.RATIOS, 58 | input_size=config.INPUT_SIZE, 59 | ngf=config.NGF, 60 | n_res_g=config.N_RES_G, 61 | use_cached_padding=config.USE_CACHED_PADDING) 62 | 63 | elif config.TYPE == "vanilla": 64 | return Vanilla(sampling_rate=config.SAMPRATE, 65 | hop=config.HOP_LENGTH, 66 | ratios=config.RATIOS, 67 | input_size=config.INPUT_SIZE, 68 | channels=config.CHANNELS, 69 | kernel=config.KERNEL, 70 | use_cached_padding=config.USE_CACHED_PADDING, 71 | extract_loudness=config.EXTRACT_LOUDNESS) 72 | else: 73 | raise Exception(f"Model type {config.TYPE} not understood") 74 | -------------------------------------------------------------------------------- /src/pca_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
import Loader, config 3 | from tqdm import tqdm 4 | 5 | 6 | def compute_pca(model, batch_size): 7 | print(config) 8 | loader = Loader(5) 9 | dataloader = torch.utils.data.DataLoader(loader, 10 | batch_size=batch_size, 11 | drop_last=False, 12 | shuffle=True) 13 | 14 | z = [] 15 | 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | model = model.to(device) 18 | 19 | for elm in tqdm(dataloader, desc="parsing dataset..."): 20 | z_ = model.encode(elm[0].squeeze().to(device)) # SHAPE B x Z x T 21 | z_ = z_.permute(0, 2, 1).reshape(-1, z_.shape[1]).cpu() # SHAPE BT x Z 22 | z.append(z_) 23 | 24 | z = torch.cat(z, 0) 25 | z = z[torch.randperm(z.shape[0])][:10000].permute(1, 0) 26 | z = z[:, torch.max(z, 0)[0] < 10] 27 | 28 | mean = torch.mean(z, -1, keepdim=True) 29 | std = 3 * torch.std(z) # 99.7% of the range (normal law) 30 | U = torch.svd(z - mean, some=False)[0] 31 | # U = torch.svd(z, some=False)[0] 32 | 33 | # torch.save(z, "z.pth") 34 | # torch.save(mean, "mean.pth") 35 | # torch.save(std, "std.pth") 36 | # torch.save(U, "U.pth") 37 | 38 | return mean.reshape(1, 1, -1), std, U 39 | -------------------------------------------------------------------------------- /src/resampling.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import kaiserord, firwin 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from src.cached_padding import CachedConv1d 8 | 9 | 10 | def kaiser_filter(wc, atten, N=None): 11 | """ 12 | Computes a kaiser lowpass filter 13 | 14 | Parameters 15 | ---------- 16 | 17 | wc: float 18 | Angular frequency 19 | 20 | atten: float 21 | Attenuation (dB, positive) 22 | """ 23 | N_, beta = kaiserord(atten, wc / np.pi) 24 | N_ = 2 * (N_ // 2) + 1 25 | N = N if N is not None else N_ 26 | h = firwin(N, wc, window=('kaiser', beta), scale=False, nyq=np.pi) 27 | return h 28 | 29 | 30 | class Resampling(nn.Module): 31 | def __init__(self, target_sr, source_sr): 32 | super().__init__() 33 | ratio = target_sr // source_sr 34 | assert int(ratio) == ratio 35 | 36 | wc = np.pi / ratio 37 | filt = kaiser_filter(wc, 140) 38 | filt = torch.from_numpy(filt).float() 39 | 40 | self.downsample = CachedConv1d( 41 | 1, 42 | 1, 43 | len(filt), 44 | stride=ratio, 45 | padding=len(filt) // 2, 46 | cache=True, 47 | ) 48 | 49 | self.downsample.conv.weight.data.copy_(filt.reshape(1, 1, -1)) 50 | self.downsample.conv.bias.data.zero_() 51 | 52 | pad = len(filt) % ratio 53 | 54 | filt = nn.functional.pad(filt, (pad, 0)) 55 | filt = filt.reshape(-1, ratio).permute(1, 0) # ratio x T 56 | 57 | pad = (filt.shape[-1] + 1) % 2 58 | filt = nn.functional.pad(filt, (pad, 0)).unsqueeze(1) 59 | 60 | self.upsample = CachedConv1d( 61 | 1, 62 | 2, 63 | filt.shape[-1], 64 | stride=1, 65 | padding=filt.shape[-1] // 2, 66 | cache=True, 67 | ) 68 | 69 | self.upsample.conv.weight.data.copy_(filt) 70 | self.upsample.conv.bias.data.zero_() 71 | 72 | self.ratio = ratio 73 | 74 | @torch.jit.export 75 | def from_target_sampling_rate(self, x): 76 | return self.downsample(x) 77 | 78 | @torch.jit.export 79 | def to_target_sampling_rate(self, x): 80 | x = self.upsample(x) # B x 2 x T 81 | x = x.permute(0, 2, 1).reshape(x.shape[0], -1).unsqueeze(1) 82 | return x 83 | -------------------------------------------------------------------------------- /src/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . 
--------------------------------------------------------------------------------
/src/train_utils.py:
--------------------------------------------------------------------------------
import torch
from . import config
import torch.nn.functional as F
from os import path
import matplotlib.pyplot as plt
import numpy as np


def train_step_melgan(model, opt, data, writer, ROOT, step, device):
    gen, dis = model
    opt_gen, opt_dis = opt

    data = data.unsqueeze(1).to(device)

    y = gen(data)

    # TRAIN DISCRIMINATOR
    D_fake = dis(y.detach())
    D_real = dis(data)

    loss_D = 0

    # hinge loss: push D(fake) below -1 and D(real) above +1
    for scale in D_fake:
        loss_D += torch.relu(1 + scale[-1]).mean()
    for scale in D_real:
        loss_D += torch.relu(1 - scale[-1]).mean()

    opt_dis.zero_grad()
    loss_D.backward()
    opt_dis.step()

    # TRAIN GENERATOR
    D_fake = dis(y)

    loss_G = 0
    for scale in D_fake:
        loss_G += -scale[-1].mean()

    # feature matching: L1 between real and fake discriminator activations
    loss_feat = 0
    feat_weights = 4.0 / (config.N_LAYER_D + 1)
    D_weights = 1.0 / config.NUM_D
    wt = D_weights * feat_weights
    for i in range(config.NUM_D):
        for j in range(len(D_fake[i]) - 1):
            loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

    loss_complete = loss_G + 10 * loss_feat

    opt_gen.zero_grad()
    loss_complete.backward()
    opt_gen.step()

    writer.add_scalar("loss discriminator", loss_D, step)
    writer.add_scalar("loss adversarial", loss_G, step)
    writer.add_scalar("loss features", loss_feat, step)

    if step % config.BACKUP == 0:
        backup_name = path.join(ROOT, "melgan_state.pth")
        states = [gen.state_dict(), dis.state_dict()]
        torch.save(states, backup_name)

    if step % config.EVAL == 0:
        writer.add_audio("original", data.reshape(-1), step, config.SAMPRATE)
        writer.add_audio("generated", y.reshape(-1), step, config.SAMPRATE)

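# train_step_vanilla (below) optimizes the spectral model: a diagonal
# gaussian negative log-likelihood on the mel frames plus a KL term weighted
# by 0.1 and, when config.EXTRACT_LOUDNESS is set, a gaussian regression
# loss on the loudness predicted from z, whose influence is annealed
# following 1 - exp(-step / 1e5).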
def train_step_vanilla(model,
                       opt,
                       data,
                       writer,
                       ROOT,
                       step,
                       device,
                       flattening=None):
    if config.EXTRACT_LOUDNESS:
        sample, loudness = data
        sample = sample.to(device)
        loudness = loudness.to(device)
        fl = loudness.cpu().detach().numpy().reshape(-1)
        fl = flattening(fl)  # map the loudness through its CDF to flatten its distribution
        fl = torch.from_numpy(fl).float().to(loudness.device)
    else:
        sample = data[0].to(device)
        loudness = None

    with torch.no_grad():
        S = model.melencoder(sample)

    # COMPUTE AUTOENCODER REC AND REG LOSSES
    out = model.topvae.loss(S, loudness)
    y, mean_y, logvar_y, mean_z, logvar_z, loss_rec, loss_reg = out
    loss = loss_rec + .1 * loss_reg

    # COMPUTE DOMAIN ADAPTATION LOSS
    if config.EXTRACT_LOUDNESS:
        z = torch.randn_like(mean_z) * torch.exp(logvar_z) + mean_z
        mean_loudness, logvar_loudness = model.classifier(
            z, 1 - np.exp(-step / 100000))
        mean_loudness = torch.sigmoid(mean_loudness).reshape(-1)
        logvar_loudness = torch.clamp(logvar_loudness, -10, 0).reshape(-1)

        loss_da = torch.mean(logvar_loudness + (mean_loudness - fl)**2 *
                             torch.exp(-logvar_loudness))
        loss += loss_da

    opt.zero_grad()
    loss.backward()
    opt.step()

    writer.add_scalar("loss_rec", loss_rec, step)
    writer.add_scalar("loss_reg", loss_reg, step)

    if config.EXTRACT_LOUDNESS:
        writer.add_scalar("loss_da", loss_da, step)
        writer.add_scalar("lambda da", 1 - np.exp(-step / 100000), step)

    if step % config.BACKUP == 0:
        backup_name = path.join(ROOT, "vanilla_state.pth")
        states = model.state_dict()
        torch.save(states, backup_name)

    if step % config.EVAL == 0:
        writer.add_histogram("mean_y", mean_y.reshape(-1), step)
        writer.add_histogram("logvar_y", logvar_y.reshape(-1), step)
        writer.add_histogram("mean_z", mean_z.reshape(-1), step)
        writer.add_histogram("logvar_z", logvar_z.reshape(-1), step)

        if config.EXTRACT_LOUDNESS:
            writer.add_histogram("mean_loudness", mean_loudness.reshape(-1),
                                 step)
            writer.add_histogram("logvar_loudness",
                                 logvar_loudness.reshape(-1), step)
            writer.add_histogram("flattened_loudness", fl.reshape(-1), step)

        ori = S.detach().cpu().numpy()
        ori = np.concatenate([o for o in ori[:4]], -1)

        rec = y.detach().cpu().numpy()
        rec = np.concatenate([r for r in rec[:4]], -1)

        img = np.concatenate([rec, ori], 0)

        plt.figure(figsize=(20, 10))
        plt.imshow(img, aspect="auto", origin="lower", cmap="magma")
        plt.axis("off")
        plt.grid(False)
        plt.tight_layout()
        writer.add_figure("reconstruction", plt.gcf(), step)
        plt.close()
--------------------------------------------------------------------------------
/src/vanilla_vae.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from . import config, CachedConv1d, CachedConvTranspose1d
import numpy as np


class ConvEncoder(nn.Module):
    """
    Multi Layer Convolutional Variational Encoder
    """
    def __init__(self, channels, kernel, ratios, use_cached_padding):
        super().__init__()

        self.channels = channels
        self.kernel = kernel
        self.ratios = ratios

        self.convs = []
        for i in range(len(self.ratios)):
            self.convs += [
                CachedConv1d(self.channels[i],
                             self.channels[i + 1],
                             self.kernel,
                             padding=self.kernel // 2,
                             stride=self.ratios[i],
                             cache=use_cached_padding)
            ]
            if i != len(self.ratios) - 1:
                self.convs += [
                    nn.LeakyReLU(),
                    nn.BatchNorm1d(self.channels[i + 1])
                ]

        self.convs = nn.Sequential(*self.convs)

    def forward(self, x):
        x = self.convs(x)
        return x


class ConvDecoder(nn.Module):
    """
    Multi Layer Convolutional Variational Decoder
    """
    def __init__(self, channels, ratios, kernel, use_cached_padding,
                 extract_loudness):
        super().__init__()

        self.channels = list(channels)
        self.ratios = ratios
        self.kernel = kernel

        self.channels[0] *= 2    # the decoder outputs a mean AND a logvar
        self.channels[-1] //= 2  # its input z is half the encoder output

        if extract_loudness:
            self.channels[-1] += 1  # one extra input channel for the loudness

        self.convs = []

        for i in range(len(self.ratios))[::-1]:
            if self.ratios[i] == 1:
                self.convs += [
                    CachedConv1d(self.channels[i + 1],
                                 self.channels[i],
                                 self.kernel,
                                 stride=1,
                                 padding=self.kernel // 2,
                                 cache=use_cached_padding)
                ]

            else:
                self.convs += [
                    CachedConvTranspose1d(self.channels[i + 1],
                                          self.channels[i],
                                          2 * self.ratios[i],
                                          stride=self.ratios[i],
                                          cache=use_cached_padding)
                ]
            if i:
                self.convs += [
                    nn.LeakyReLU(),
                    nn.BatchNorm1d(self.channels[i])
                ]

        self.convs = nn.Sequential(*self.convs)

    def forward(self, x):
        x = self.convs(x)
        return x

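# Channel bookkeeping between the two halves: the encoder outputs
# channels[-1] feature maps that TopVAE.encode splits into a mean and a
# logvar of channels[-1] // 2 each, so the decoder takes channels[-1] // 2
# input channels (plus one when the loudness is concatenated) and produces
# 2 * channels[0] maps, split again into a mean and a logvar.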
class TopVAE(nn.Module):
    """
    Top Variational Auto Encoder
    """
    def __init__(self, channels, kernel, ratios, use_cached_padding,
                 extract_loudness):
        super().__init__()
        self.encoder = ConvEncoder(channels, kernel, ratios,
                                   use_cached_padding)
        self.decoder = ConvDecoder(channels, ratios, kernel,
                                   use_cached_padding, extract_loudness)

        self.channels = channels

        skipped = 0
        for p in self.parameters():
            try:
                nn.init.xavier_normal_(p)
            except ValueError:  # xavier only applies to >= 2d tensors, skip biases
                skipped += 1

    def encode(self, x):
        out = self.encoder(x)
        mean, logvar = torch.split(out, self.channels[-1] // 2, 1)
        # NOTE: exp(logvar) is used as the std (rather than exp(.5 * logvar),
        # which the KL term in loss() would suggest); kept as-is for
        # compatibility with trained checkpoints
        z = torch.randn_like(mean) * torch.exp(logvar) + mean
        return z, mean, logvar

    def decode(self, z):
        rec = self.decoder(z)
        mean, logvar = torch.split(rec, self.channels[0], 1)
        mean = torch.sigmoid(mean)
        logvar = torch.clamp(logvar, min=-10, max=0)
        y = torch.randn_like(mean) * torch.exp(logvar) + mean
        return y, mean, logvar

    def deterministic_decode(self, z):
        rec = self.decoder(z)
        mean = torch.split(rec, self.channels[0], 1)[0]
        return torch.sigmoid(mean)

    def forward(self, x, loudness):
        z, mean_z, logvar_z = self.encode(x)
        if loudness is not None:
            z = torch.cat([loudness, z], 1)
        y, mean_y, logvar_y = self.decode(z)
        return y, mean_y, logvar_y, mean_z, logvar_z

    def loss(self, x, loudness):
        y, mean_y, logvar_y, mean_z, logvar_z = self.forward(x, loudness)

        # gaussian negative log-likelihood of x under the decoded distribution
        loss_rec = logvar_y + (x - mean_y)**2 * torch.exp(-logvar_y)

        # KL divergence between q(z|x) and a standard normal prior
        loss_reg = mean_z**2 + torch.exp(logvar_z) - logvar_z - 1

        loss_rec = torch.mean(loss_rec)
        loss_reg = torch.mean(loss_reg)

        return y, mean_y, logvar_y, mean_z, logvar_z, loss_rec, loss_reg
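For orientation, a minimal round-trip through TopVAE; the channel, kernel
and ratio values below are illustrative (not the repository configuration),
and the cached convolutions are assumed to preserve the usual stride
arithmetic:

import torch
from src.vanilla_vae import TopVAE

vae = TopVAE(channels=[128, 256, 32],
             kernel=5,
             ratios=[2, 2],
             use_cached_padding=False,
             extract_loudness=False)

S = torch.rand(1, 128, 64)           # B x n_mel x T, in [0, 1]
z, mean_z, logvar_z = vae.encode(S)  # z: B x 16 x 16 (channels[-1] // 2, T / 4)
y, mean_y, logvar_y = vae.decode(z)  # y: B x 128 x 64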
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from effortless_config import Config

import numpy as np

from src import config
from src import get_model, Discriminator, preprocess
from src import train_step_melgan, train_step_vanilla
from src import Loader, get_flattening_function, gaussian_cdf

from tqdm import tqdm
from os import path

config.parse_args()

# PREPARE DATA
dataset = Loader(config.AUGMENT)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=config.BATCH,
                                         shuffle=True,
                                         drop_last=True)

# PREPARE MODELS
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# MELGAN TRAINING
if config.TYPE == "melgan":
    gen = get_model()
    dis = Discriminator()

    if config.CKPT is not None:
        ckptgen, ckptdis = torch.load(config.CKPT, map_location="cpu")
        gen.load_state_dict(ckptgen)
        dis.load_state_dict(ckptdis)

    gen = gen.to(device)
    dis = dis.to(device)

    # PREPARE OPTIMIZERS
    opt_gen = torch.optim.Adam(gen.parameters(), lr=config.LR, betas=[.5, .9])
    opt_dis = torch.optim.Adam(dis.parameters(), lr=config.LR, betas=[.5, .9])

    model = gen, dis
    opt = opt_gen, opt_dis

# VANILLA VAE TRAINING
if config.TYPE == "vanilla":
    model = get_model()
    if config.CKPT is not None:
        ckpt = torch.load(config.CKPT, map_location="cpu")
        model.load_state_dict(ckpt)
    model = model.to(device)

    # PREPARE OPTIMIZER
    opt = torch.optim.Adam(model.parameters(), lr=config.LR)

ROOT = path.join(config.PATH_PREPEND, config.NAME, config.TYPE)
writer = SummaryWriter(ROOT, flush_secs=20)

with open(path.join(ROOT, "config.py"), "w") as config_out:
    config_out.write("from effortless_config import Config\n")
    config_out.write(str(config))

# POST LOADING PROCESSING
with torch.no_grad():
    if config.TYPE == "vanilla" and config.EXTRACT_LOUDNESS:
        try:
            weights, means, stds = np.load(path.join(ROOT, "flatten.npy"))
            print("flattening function found, loading")
            flattening_function = gaussian_cdf(weights, means, stds)
        except FileNotFoundError:
            loudness = []
            for sample, loud_ in tqdm(dataloader, desc="parsing loudness"):
                loudness.append(loud_.reshape(-1))
            loudness = torch.cat(loudness, 0).unsqueeze(1).numpy()
            loudness = loudness[:1000000]
            print("flattening dataset loudness...")
            weights, means, stds = get_flattening_function(loudness)
            np.save(path.join(ROOT, "flatten.npy"), [weights, means, stds])
            flattening_function = gaussian_cdf(weights, means, stds)

    else:
        flattening_function = None

print("Starting training!")

# TRAINING PROCESS
step = 0
for e in range(config.EPOCH):
    for batch in tqdm(dataloader):
        if config.TYPE == "vanilla":
            train_step_vanilla(model,
                               opt,
                               batch,
                               writer,
                               ROOT,
                               step,
                               device,
                               flattening=flattening_function)

        elif config.TYPE == "melgan":
            train_step_melgan(model, opt, batch, writer, ROOT, step, device)

        step += 1
--------------------------------------------------------------------------------
/udls/__init__.py:
--------------------------------------------------------------------------------
from .base_dataset import SimpleLMDBDataset
from .domain_adaptation import DomainAdaptationDataset
from .simple_dataset import SimpleDataset
--------------------------------------------------------------------------------
/udls/base_dataset.py:
--------------------------------------------------------------------------------
import torch
import lmdb
import pickle


class SimpleLMDBDataset(torch.utils.data.Dataset):
    """
    Wraps a LMDB database as a torch compatible Dataset
    """
    def __init__(self, out_database_location, map_size=1e9):
        super().__init__()
        self.env = lmdb.open(out_database_location,
                             map_size=int(map_size),  # lmdb expects an integer map size
                             lock=False)
        with self.env.begin(write=False) as txn:
            lmdblength = txn.get("length".encode("utf-8"))
            self.len = int(lmdblength) if lmdblength is not None else 0

    def __len__(self):
        return self.len

    def __setitem__(self, idx, value):
        with self.env.begin(write=True) as txn:
            txn.put(f"{idx:08d}".encode("utf-8"), pickle.dumps(value))
            if idx > self.len - 1:
                self.len = idx + 1
                txn.put("length".encode("utf-8"),
                        f"{self.len:08d}".encode("utf-8"))

    def __getitem__(self, idx):
        with self.env.begin(write=False) as txn:
            value = pickle.loads(txn.get(f"{idx:08d}".encode("utf-8")))
        return value
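A short usage sketch of the wrapper above (the path is hypothetical; values
are pickled transparently on both sides):

from udls import SimpleLMDBDataset

db = SimpleLMDBDataset("/tmp/demo_db", map_size=int(1e8))
db[0] = {"x": [1, 2, 3]}
print(len(db), db[0])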
--------------------------------------------------------------------------------
/udls/domain_adaptation.py:
--------------------------------------------------------------------------------
import torch
from . import SimpleLMDBDataset
from pathlib import Path
import librosa as li
from concurrent.futures import ProcessPoolExecutor
from os import makedirs, path
from tqdm import tqdm
import numpy as np


def dummy_load(name):
    """
    Preprocess function that takes one audio path and loads it into
    chunks of 2048 samples.
    """
    x = li.load(name, sr=16000)[0]
    if len(x) % 2048:
        x = x[:-(len(x) % 2048)]
    x = x.reshape(-1, 2048)
    return x


class DomainAdaptationDataset(torch.utils.data.Dataset):
    def __init__(self,
                 out_database_location,
                 folder_list,
                 preprocess_function=dummy_load,
                 extension="*.wav",
                 map_size=1e9):
        super().__init__()

        self.domains = []

        makedirs(out_database_location, exist_ok=True)
        self.folder_list = folder_list
        self.preprocess_function = preprocess_function
        self.extension = extension

        for folder in folder_list:
            self.domains.append(
                SimpleLMDBDataset(
                    path.join(out_database_location,
                              path.basename(path.normpath(folder))), map_size))

        # IF NO DATA INSIDE DATASET: PREPROCESS
        self.len = np.sum([len(env) for env in self.domains])

        if self.len == 0:
            self._preprocess()
            self.len = np.sum([len(env) for env in self.domains])

        if self.len == 0:
            raise Exception("No data found!")

    def _preprocess(self):
        for index_env, (folder,
                        env) in enumerate(zip(self.folder_list, self.domains)):
            files = Path(folder).rglob(self.extension)

            index = 0

            with ProcessPoolExecutor(max_workers=16) as executor:
                for output in tqdm(
                        executor.map(self.preprocess_function, files),
                        desc=f"parsing dataset for env {index_env}"):
                    if len(output):
                        for elm in output:
                            env[index] = elm
                            index += 1

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        for i, env in enumerate(self.domains):
            if index >= len(env):
                index -= len(env)
            else:
                return i, env[index]
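A short usage sketch (the folders are hypothetical and assumed to contain
wav files); __getitem__ routes a flat index to the right domain and returns
a (domain_index, chunk) pair:

from udls import DomainAdaptationDataset

dataset = DomainAdaptationDataset("/tmp/da_db",
                                  ["/data/violin", "/data/voice"])
domain, chunk = dataset[0]  # chunk: np.ndarray of 2048 samples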
--------------------------------------------------------------------------------
/udls/simple_dataset.py:
--------------------------------------------------------------------------------
import torch
from . import SimpleLMDBDataset
from pathlib import Path
import librosa as li
from concurrent.futures import ProcessPoolExecutor, TimeoutError
from os import makedirs, path
from tqdm import tqdm
import numpy as np


def dummy_load(name):
    """
    Preprocess function that takes one audio path and loads it into
    chunks of 2048 samples. Returns None when the file is shorter than
    one chunk.
    """
    x = li.load(name, sr=16000)[0]
    if len(x) % 2048:
        x = x[:-(len(x) % 2048)]
    x = x.reshape(-1, 2048)
    if x.shape[0]:
        return x
    else:
        return None


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        out_database_location,
        folder_list=None,
        file_list=None,
        preprocess_function=dummy_load,
        extension="*.wav,*.aif",
        map_size=1e9,
        multiprocess=True,
        split_percent=.2,
        split_set="train",
        seed=0,
    ):
        super().__init__()

        assert folder_list is not None or file_list is not None

        self.env = SimpleLMDBDataset(out_database_location, map_size)

        self.folder_list = folder_list
        self.file_list = file_list

        self.preprocess_function = preprocess_function
        self.extension = extension
        self.multiprocess = multiprocess

        makedirs(out_database_location, exist_ok=True)

        # IF NO DATA INSIDE DATASET: PREPROCESS
        self.len = len(self.env)

        if self.len == 0:
            self._preprocess()
            self.len = len(self.env)

        if self.len == 0:
            raise Exception("No data found!")

        self.index = np.arange(self.len)
        np.random.seed(seed)
        np.random.shuffle(self.index)

        if split_set == "train":
            self.len = int(np.floor((1 - split_percent) * self.len))
            self.offset = 0

        elif split_set == "test":
            self.offset = int(np.floor((1 - split_percent) * self.len))
            self.len = self.len - self.offset

        elif split_set == "full":
            self.offset = 0

    def _preprocess(self):
        extension = self.extension.split(",")
        idx = 0
        wavs = []

        # POPULATE WAV LIST
        if self.folder_list is not None:
            for f, folder in enumerate(self.folder_list.split(",")):
                print("Recursive search in {}".format(folder))
                for ext in extension:
                    wavs.extend(list(Path(folder).rglob(ext)))

        else:
            with open(self.file_list, "r") as file_list:
                # skip empty lines (e.g. a trailing newline)
                wavs = [w for w in file_list.read().split("\n") if w.strip()]

        # CREATE ASYNCHRONOUS PREPROCESS TASKS
        if self.multiprocess:
            futures = []
            with ProcessPoolExecutor() as executor:
                for wav in wavs:
                    futures.append((path.basename(wav),
                                    executor.submit(self.preprocess_function,
                                                    wav)))
                loader = tqdm(futures)
                for name, f in loader:
                    loader.set_description("{}".format(name))
                    try:
                        output = f.result(timeout=60)
                    except TimeoutError:
                        output = None
                        print("Failed to preprocess {}".format(name))
                    if output is not None:
                        for o in output:
                            self.env[idx] = o
                            idx += 1
        else:
            loader = tqdm(wavs)
            for wav in loader:
                loader.set_description("{}".format(path.basename(wav)))
                output = self.preprocess_function(wav)
                if output is not None:
                    for o in output:
                        self.env[idx] = o
                        idx += 1

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.env[self.index[index + self.offset]]
--------------------------------------------------------------------------------
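Finally, a short usage sketch of SimpleDataset (the paths are hypothetical);
on first instantiation it preprocesses every matching file into the LMDB
store, and afterwards serves 2048-sample chunks directly:

from udls import SimpleDataset

dataset = SimpleDataset("/tmp/preprocessed",
                        folder_list="/data/wavs",
                        split_set="train")
chunk = dataset[0]  # np.ndarray of 2048 samples at 16 kHz
print(len(dataset), chunk.shape)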