├── Inference.ipynb ├── LICENSE ├── Make Training Data.ipynb ├── PGGAN.py ├── README.md ├── STFT.py ├── normalizer.py ├── phase_operation.py ├── pytorch_nsynth_lib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ └── nsynth.cpython-35.pyc ├── from_tfrecord.py └── nsynth.py ├── spec_ops.py ├── spectrograms_helper.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Wokerker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Make Training Data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.utils.data as data\n", 11 | "import torchvision.transforms as transforms\n", 12 | "import numpy as np\n", 13 | "from pytorch_nsynth_lib.nsynth import NSynth\n", 14 | "from IPython.display import Audio\n", 15 | "\n", 16 | "import librosa\n", 17 | "import librosa.display\n", 18 | "import phase_operation\n", 19 | "from tqdm import tqdm\n", 20 | "import h5py" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import spec_ops as spec_ops\n", 30 | "import phase_operation as phase_op\n", 31 | "import spectrograms_helper as spec_helper" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "train_data = h5py.File('../data/Nsynth_melspec_IF_pitch.hdf5', 'w')\n" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 4, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# audio samples are loaded as an int16 numpy array\n", 50 | "# rescale intensity range as float [-1, 1]\n", 51 | "toFloat = transforms.Lambda(lambda x: x / np.iinfo(np.int16).max)\n", 52 | "# use instrument_family and instrument_source as classification targets\n", 53 | "dataset = NSynth(\n", 54 | " \"../data/nsynth/nsynth-train\",\n", 55 | " transform=toFloat,\n", 56 | " blacklist_pattern=[ \"string\"], # blacklist string instrument\n", 57 | " categorical_field_list=[\"instrument_family\",\"pitch\"])\n", 58 | "loader = data.DataLoader(dataset, batch_size=1, shuffle=True)\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": 
"code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def expand(mat):\n", 68 | " expand_vec = np.expand_dims(mat[:,125],axis=1)\n", 69 | " expanded = np.hstack((mat,expand_vec,expand_vec))\n", 70 | " return expanded" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "spec_list=[]\n", 80 | "pitch_list=[]\n", 81 | "IF_list =[]\n", 82 | "mel_spec_list=[]\n", 83 | "mel_IF_list=[]\n", 84 | "\n", 85 | "pitch_set =set()\n", 86 | "count=0\n", 87 | "for samples, instrument_family, pitch, targets in loader:\n", 88 | " \n", 89 | " pitch = targets['pitch'].data.numpy()[0]\n", 90 | "\n", 91 | " if pitch < 24 or pitch > 84:\n", 92 | "# print(\"pitch\",pitch)\n", 93 | " continue\n", 94 | " \n", 95 | " sample = samples.data.numpy().squeeze()\n", 96 | " spec = librosa.stft(sample, n_fft=2048, hop_length = 512)\n", 97 | " \n", 98 | " magnitude = np.log(np.abs(spec)+ 1.0e-6)[:1024]\n", 99 | "# print(\"magnitude Max\",magnitude.max(),\"magnitude Min\",magnitude.min())\n", 100 | " angle =np.angle(spec)\n", 101 | "# print(\"angle Max\",angle.max(),\"angle Min\",angle.min())\n", 102 | "\n", 103 | " IF = phase_operation.instantaneous_frequency(angle,time_axis=1)[:1024]\n", 104 | " \n", 105 | " magnitude = expand(magnitude)\n", 106 | " IF = expand(IF)\n", 107 | " logmelmag2, mel_p = spec_helper.specgrams_to_melspecgrams(magnitude, IF)\n", 108 | "\n", 109 | "# pitch = targets['pitch'].data.numpy()[0]\n", 110 | " \n", 111 | " \n", 112 | " assert magnitude.shape ==(1024, 128)\n", 113 | " assert IF.shape ==(1024, 128)\n", 114 | " \n", 115 | "# spec_list.append(magnitude)\n", 116 | "# IF_list.append(IF)\n", 117 | " pitch_list.append(pitch)\n", 118 | " mel_spec_list.append(logmelmag2)\n", 119 | " mel_IF_list.append(mel_p)\n", 120 | " pitch_set.add(pitch)\n", 121 | " \n", 122 | " count+=1\n", 123 | " if count%10000==0:\n", 124 | " print(count)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# train_data.create_dataset(\"Spec\", data=spec_list)\n", 134 | "# train_data.create_dataset(\"IF\", data=IF_list)\n", 135 | "train_data.create_dataset(\"pitch\", data=pitch_list)\n", 136 | "train_data.create_dataset(\"mel_Spec\", data=mel_spec_list)\n", 137 | "train_data.create_dataset(\"mel_IF\", data=mel_IF_list)\n" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 3", 158 | "language": "python", 159 | "name": "python3" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.5.2" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } 177 | -------------------------------------------------------------------------------- /PGGAN.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional 
as F
8 | from torch.nn.init import kaiming_normal_, calculate_gain
9 | 
10 | 
11 | #os.environ["CUDA_VISIBLE_DEVICES"]
12 | 
13 | #################################################################################
14 | # Construct Help Functions Class#################################################
15 | #################################################################################
16 | class HelpFunc(object):
17 |     @staticmethod
18 |     def process_transition(a, b):
19 |         """
20 |         Resize tensor a to tensor b's size, shrinking by 'average pooling'
21 |         and growing by 'nearest neighbor filtering',
22 |         as mentioned below Figure 2 of the paper https://arxiv.org/pdf/1710.10196.pdf
23 |         :param torch.Tensor a: a tensor with size [batch, channel, height, width]
24 |         :param torch.Tensor b: similar to a
25 |         :return torch.Tensor :
26 |         """
27 |         a_batch, a_channel, a_height, a_width = a.size()
28 |         b_batch, b_channel, b_height, b_width = b.size()
29 |         # Drop feature maps
30 |         if a_channel > b_channel:
31 |             a = a[:, :b_channel]
32 | 
33 |         if a_height > b_height:
34 |             assert a_height % b_height == 0 and a_width % b_width == 0
35 |             assert a_height / b_height == a_width / b_width
36 |             ks = int(a_height // b_height)
37 |             a = F.avg_pool2d(a, kernel_size=ks, stride=ks, padding=0, ceil_mode=False, count_include_pad=False)
38 | 
39 |         if a_height < b_height:
40 |             assert b_height % a_height == 0 and b_width % a_width == 0
41 |             assert b_height / a_height == b_width / a_width
42 |             sf = b_height // a_height
43 |             a = F.interpolate(a, scale_factor=sf, mode='nearest')  # F.upsample is deprecated
44 | 
45 |         # Add feature maps (zeros created on the same device/dtype as a).
46 |         if a_channel < b_channel:
47 |             z = torch.zeros((a_batch, b_channel - a_channel, b_height, b_width), device=a.device, dtype=a.dtype)
48 |             a = torch.cat([a, z], 1)
49 |         # print("a size: ", a.size())
50 |         return a
51 | 
52 | 
53 | #################################################################################
54 | # Construct Middle Classes ######################################################
55 | #################################################################################
56 | class PixelWiseNormLayer(nn.Module):
57 |     """
58 |     Mentioned in '4.2 PIXELWISE FEATURE VECTOR NORMALIZATION IN GENERATOR'
59 |     'Local response normalization'
60 |     """
61 | 
62 |     def __init__(self):
63 |         super(PixelWiseNormLayer, self).__init__()
64 | 
65 |     def forward(self, x):
66 |         return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
67 | 
68 | 
69 | class EqualizedLearningRateLayer(nn.Module):
70 |     """
71 |     Mentioned in '4.1 EQUALIZED LEARNING RATE'
72 |     Applies equalized learning rate to the preceding layer.
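    At construction the wrapped layer's weights are divided once by their
    He-init standard deviation c = sqrt(mean(w**2)), and forward() multiplies
    the activations by c, which equalizes the effective learning rate
    across layers.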
73 |     *'To initialize all bias parameters to zero and all weights
74 |     according to the normal distribution with unit variance'
75 |     """
76 | 
77 |     def __init__(self, layer):
78 |         super(EqualizedLearningRateLayer, self).__init__()
79 |         self.layer_ = layer
80 | 
81 |         # He's Initializer (He et al., 2015)
82 |         kaiming_normal_(self.layer_.weight, a=calculate_gain('conv2d'))
83 |         # Because the mean is 0 after Kaiming initialization
84 |         self.layer_norm_constant_ = (torch.mean(self.layer_.weight.data ** 2)) ** 0.5
85 |         self.layer_.weight.data.copy_(self.layer_.weight.data / self.layer_norm_constant_)
86 | 
87 |         self.bias_ = self.layer_.bias if self.layer_.bias is not None else None  # `is not None`: truthiness of a multi-element tensor is ambiguous
88 |         self.layer_.bias = None
89 | 
90 |     def forward(self, x):
91 |         self.layer_norm_constant_ = self.layer_norm_constant_.type(torch.cuda.FloatTensor)
92 |         x = self.layer_norm_constant_ * x
93 |         if self.bias_ is not None:
94 |             # x += self.bias_.view(1, -1, 1, 1).expand_as(x)
95 |             x += self.bias_.view(1, self.bias_.size()[0], 1, 1)
96 |         return x
97 | 
98 | 
99 | class MiniBatchAverageLayer(nn.Module):
100 |     def __init__(self,
101 |                  offset=1e-8  # From the original implementation
102 |                  # https://github.com/tkarras/progressive_growing_of_gans/blob/master/networks.py #L135
103 |                  ):
104 |         super(MiniBatchAverageLayer, self).__init__()
105 |         self.offset_ = offset
106 | 
107 |     def forward(self, x):
108 |         # Following Chapter 3 of the paper:
109 |         # Compute the standard deviation for each feature
110 |         # at each spatial location, then average to arrive at a single value
111 |         stddev = torch.sqrt(torch.mean((x - torch.mean(x, dim=0, keepdim=True))**2, dim=0, keepdim=True) + self.offset_)
112 |         inject_shape = list(x.size())[:]
113 |         inject_shape[1] = 1  # Inject one channel of data in the second dim (channel dim). See Chapter 3 and Table 2
114 |         inject = torch.mean(stddev, dim=1, keepdim=True)
115 |         inject = inject.expand(inject_shape)
116 |         return torch.cat((x, inject), dim=1)
117 | 
118 | 
119 | #################################################################################
120 | # Construct Generator and Discriminator #########################################
121 | #################################################################################
122 | class Generator(nn.Module):
123 |     def __init__(self,
124 |                  resolution,  # Output resolution. Overridden based on dataset.
125 |                  latent_size,  # Dimensionality of the latent vectors.
126 |                  final_channel=3,  # Output channel size, for rgb always 3
127 |                  fmap_base=2 ** 13,  # Overall multiplier for the number of feature maps.
128 |                  fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
129 |                  fmap_max=2 ** 9,  # Maximum number of feature maps in any layer.
130 |                  is_tanh=False,
131 |                  channel_list=None
132 | 
133 |                  ):
134 |         super(Generator, self).__init__()
135 |         self.latent_size_ = latent_size
136 |         self.is_tanh_ = is_tanh
137 |         self.final_channel_ = final_channel
138 | 
139 |         # Use (fmap_base, fmap_decay, fmap_max)
140 |         # to control every level's in||out channels
141 |         self.fmap_base_ = fmap_base
142 |         self.fmap_decay_ = fmap_decay
143 |         self.fmap_max_ = fmap_max
144 |         image_pyramid_ = int(np.log2(resolution))  # max level of the Image Pyramid
145 |         self.resolution_ = 2 ** image_pyramid_  # correct resolution
146 |         self.net_level_max_ = image_pyramid_ - 1  # minus 1 in order to exclude last rgb layer
147 |         self.channel_list = channel_list
148 | 
149 |         self.lod_layers_ = nn.ModuleList()  # layer blocks exclude to_rgb layer
150 |         self.rgb_layers_ = nn.ModuleList()  # rgb layers each correspond to specific level.
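        # (With the defaults, _get_channel_by_stage gives min(8192 / 2**level, 512)
        #  channels per level: 512 up through level 4, then 256, 128, ...
        #  Passing an explicit channel_list overrides this schedule.)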
151 | 152 | for level in range(self.net_level_max_): 153 | self._construct_by_level(level) 154 | 155 | self.net_level_ = self.net_level_max_ # set default net level as max level 156 | self.net_status_ = "stable" # "stable" or "fadein" 157 | self.net_alpha_ = 1.0 # the previous stage's weight 158 | 159 | 160 | @property 161 | def net_config(self): 162 | """ 163 | Return current net's config. 164 | The config is used to control forward 165 | The pipeline was mentioned below Figure2 of the Paper 166 | """ 167 | return self.net_level_, self.net_status_, self.net_alpha_ 168 | 169 | @net_config.setter 170 | def net_config(self, config_list): 171 | """ 172 | :param iterable config_list: [net_level, net_status, net_alpha] 173 | :return: 174 | """ 175 | self.net_level_, self.net_status_, self.net_alpha_ = config_list 176 | 177 | def forward(self, x): 178 | """ 179 | The pipeline was mentioned below Figure2 of the Paper 180 | """ 181 | if self.net_status_ == "stable": 182 | cur_output_level = self.net_level_ 183 | # print("self.net_level_+1",self.net_level_+1) 184 | for cursor in range(self.net_level_+1): 185 | x = self.lod_layers_[cursor](x) 186 | # print(cursor,x.size()) 187 | x = self.rgb_layers_[cur_output_level](x) 188 | 189 | elif self.net_status_ == "fadein": 190 | pre_output_level = self.net_level_ - 1 191 | cur_output_level = self.net_level_ 192 | pre_weight, cur_weight = self.net_alpha_, 1.0 - self.net_alpha_ 193 | output_cache = [] 194 | for cursor in range(self.net_level_+1): 195 | x = self.lod_layers_[cursor](x) 196 | if cursor == pre_output_level: 197 | output_cache.append(self.rgb_layers_[cursor](x)) 198 | if cursor == cur_output_level: 199 | output_cache.append(self.rgb_layers_[cursor](x)) 200 | x = HelpFunc.process_transition(output_cache[0], output_cache[1]) * pre_weight \ 201 | + output_cache[1] * cur_weight 202 | 203 | else: 204 | raise AttributeError("Please set the net_status: ['stable', 'fadein']") 205 | 206 | # """Final Layer""" 207 | # if self.net_level_max_ == self.net_level_: 208 | # print("Reach MAx") 209 | # x = F.tanh(x) 210 | 211 | 212 | return x 213 | 214 | def _construct_by_level(self, cursor): 215 | in_level = cursor 216 | out_level = cursor + 1 217 | if self.channel_list is not None: 218 | in_channels=self.channel_list[in_level] 219 | out_channels=self.channel_list[out_level] 220 | print("Cursor",cursor,in_channels,out_channels) 221 | else: 222 | in_channels, out_channels = map(self._get_channel_by_stage, (in_level, out_level)) 223 | 224 | block_type = "First" if cursor == 0 else "UpSample" 225 | self._create_block(in_channels, out_channels, block_type) # construct previous (max_level - 1) layers 226 | self._create_block(out_channels, self.final_channel_, "ToRGB") # construct rgb layer for each previous level 227 | 228 | def _create_block(self, in_channels, out_channels, block_type): 229 | """ 230 | Create a network block 231 | :param block_type: only can be "First"||"UpSample"||"ToRGB" 232 | :return: 233 | """ 234 | block_cache = [] 235 | if block_type in ["First", "UpSample"]: 236 | if block_type == "First": 237 | block_cache.append(PixelWiseNormLayer()) 238 | block_cache.append(nn.Conv2d(self.latent_size_+128, out_channels, 239 | kernel_size=(2,16), stride=1, padding=(1,15), bias=False)) 240 | if block_type == "UpSample": 241 | block_cache.append(nn.Upsample(scale_factor=2, mode='nearest')) 242 | block_cache.append(nn.Conv2d(in_channels, out_channels, 243 | kernel_size=3, stride=1, padding=1, bias=False)) 244 | 
block_cache.append(EqualizedLearningRateLayer(block_cache[-1])) 245 | block_cache.append(nn.LeakyReLU(negative_slope=0.2)) 246 | block_cache.append(PixelWiseNormLayer()) 247 | block_cache.append(nn.Conv2d(out_channels, out_channels, 248 | kernel_size=3, stride=1, padding=1, bias=False)) 249 | block_cache.append(EqualizedLearningRateLayer(block_cache[-1])) 250 | block_cache.append(nn.LeakyReLU(negative_slope=0.2)) 251 | block_cache.append(PixelWiseNormLayer()) 252 | self.lod_layers_.append(nn.Sequential(*block_cache)) 253 | elif block_type == "ToRGB": 254 | block_cache.append(nn.Conv2d(in_channels, out_channels=out_channels, 255 | kernel_size=1, stride=1, padding=0, bias=False)) 256 | block_cache.append(EqualizedLearningRateLayer(block_cache[-1])) 257 | if self.is_tanh_ is True: 258 | block_cache.append(nn.Tanh()) 259 | self.rgb_layers_.append(nn.Sequential(*block_cache)) 260 | else: 261 | raise TypeError("'block_type' must in ['First', 'UpSample', 'ToRGB']") 262 | 263 | def _get_channel_by_stage(self, level): 264 | return min(int(self.fmap_base_ / (2.0 ** (level * self.fmap_decay_))), self.fmap_max_) 265 | 266 | 267 | class Discriminator(nn.Module): 268 | def __init__(self, 269 | resolution, # Output resolution. Overridden based on dataset. 270 | input_channel=2, # input channel size, for rgb always 3 271 | fmap_base=2 ** 13, # Overall multiplier for the number of feature maps. 272 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 273 | fmap_max=2 ** 9, # Maximum number of feature maps in any layer. 274 | is_sigmoid=False, 275 | channel_list=None 276 | ): 277 | super(Discriminator, self).__init__() 278 | self.input_channel_ = input_channel 279 | self.is_sigmoid_ = is_sigmoid 280 | # Use (fmap_max, fmap_decay, fmap_max) 281 | # to control every level's in||out channels 282 | self.fmap_base_ = fmap_base 283 | self.fmap_decay_ = fmap_decay 284 | self.fmap_max_ = fmap_max 285 | image_pyramid_ = int(np.log2(resolution)) # max level of the Image Pyramid 286 | self.resolution_ = 2 ** image_pyramid_ # correct resolution 287 | self.net_level_max_ = image_pyramid_ - 1 # minus 1 in order to exclude first rgb layer 288 | self.channel_list=channel_list 289 | self.lod_layers_ = nn.ModuleList() # layer blocks exclude to_rgb layer 290 | self.rgb_layers_ = nn.ModuleList() # rgb layers each correspond to specific level. 
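        # (Blocks are registered from the highest level down to level 1, so
        #  lod_layers_[0] handles the largest resolution; the "Minibatch"
        #  block and the pitch/real-fake linear heads are created last.)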
291 | 292 | for level in range(self.net_level_max_, 0, -1): 293 | self._construct_by_level(level) 294 | 295 | self.net_level_ = self.net_level_max_ # set default net level as max level 296 | self.net_status_ = "stable" # "stable" or "fadein" 297 | self.net_alpha_ = 1.0 # the previous stage's weight 298 | self.Softmax = nn.LogSoftmax(dim=1) 299 | 300 | @property 301 | def net_config(self): 302 | return self.net_level_, self.net_status_, self.net_alpha_ 303 | 304 | @net_config.setter 305 | def net_config(self, config_list): 306 | self.net_level_, self.net_status_, self.net_alpha_ = config_list 307 | 308 | def forward(self, x): 309 | if self.net_status_ == "stable": 310 | cur_input_level = self.net_level_max_ - self.net_level_ - 1 311 | x = self.rgb_layers_[cur_input_level](x) 312 | for cursor in range(cur_input_level, self.net_level_max_): 313 | x = self.lod_layers_[cursor](x) 314 | 315 | 316 | B = x.size()[0] 317 | x = x.reshape(B,-1) # flatten 318 | 319 | pitch_distribution = self.Softmax(self.pitch_classifier(x)) 320 | discriminator_output= self.discriminator_classifier(x) 321 | return pitch_distribution, discriminator_output 322 | 323 | elif self.net_status_ == "fadein": 324 | pre_input_level = self.net_level_max_ - self.net_level_ 325 | cur_input_level = self.net_level_max_ - self.net_level_ - 1 326 | pre_weight, cur_weight = self.net_alpha_, 1.0 - self.net_alpha_ 327 | 328 | x_pre_cache = self.rgb_layers_[pre_input_level](x) 329 | x_cur_cache = self.rgb_layers_[cur_input_level](x) 330 | x_cur_cache = self.lod_layers_[cur_input_level](x_cur_cache) 331 | x = HelpFunc.process_transition(x_pre_cache, x_cur_cache) * pre_weight + x_cur_cache * cur_weight 332 | 333 | for cursor in range(cur_input_level + 1, self.net_level_max_): 334 | x = self.lod_layers_[cursor](x) 335 | 336 | B = x.size()[0] 337 | x = x.reshape(B,-1) 338 | pitch_distribution = self.Softmax(self.pitch_classifier(x)) 339 | discriminator_output= self.discriminator_classifier(x) 340 | return pitch_distribution, discriminator_output 341 | 342 | else: 343 | raise AttributeError("Please set the net_status: ['stable', 'fadein']") 344 | 345 | return x 346 | 347 | def _construct_by_level(self, cursor): 348 | in_level = cursor 349 | out_level = cursor - 1 350 | 351 | if self.channel_list is not None: 352 | in_channels=self.channel_list[in_level] 353 | out_channels=self.channel_list[out_level] 354 | print("Cursor",cursor,in_channels,out_channels) 355 | else: 356 | in_channels, out_channels = map(self._get_channel_by_stage, (in_level, out_level)) 357 | 358 | block_type = "Minibatch" if cursor == 1 else "DownSample" 359 | self._create_block(in_channels, out_channels, block_type) # construct (max_level-1) layers(exclude rgb layer) 360 | self._create_block(self.input_channel_, in_channels, "FromRGB") # construct rgb layer for each previous level 361 | 362 | """ Create pitch classifier and discriminator output""" 363 | if block_type == "Minibatch": 364 | self.pitch_classifier= nn.Linear(self.channel_list[0]*2*16, 128) 365 | self.discriminator_classifier= nn.Linear(self.channel_list[0]*2*16, 1) 366 | 367 | 368 | def _create_block(self, in_channels, out_channels, block_type): 369 | """ 370 | Create a network block 371 | :param block_type: only can be "Minibatch"||"DownSample"||"FromRGB" 372 | :return: 373 | """ 374 | block_cache = [] 375 | if block_type == "DownSample": 376 | block_cache.append(nn.Conv2d(in_channels, out_channels, 377 | kernel_size=3, stride=1, padding=1, bias=False)) 378 | 
block_cache.append(EqualizedLearningRateLayer(block_cache[-1]))
379 |             block_cache.append(nn.LeakyReLU(negative_slope=0.2))
380 |             block_cache.append(nn.Conv2d(out_channels, out_channels,
381 |                                          kernel_size=3, stride=1, padding=1, bias=False))
382 |             block_cache.append(EqualizedLearningRateLayer(block_cache[-1]))
383 |             block_cache.append(nn.LeakyReLU(negative_slope=0.2))
384 |             block_cache.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
385 |             self.lod_layers_.append(nn.Sequential(*block_cache))
386 |         elif block_type == "FromRGB":
387 |             block_cache.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
388 |                                          kernel_size=1, stride=1, padding=0, bias=False))
389 |             block_cache.append(EqualizedLearningRateLayer(block_cache[-1]))
390 |             block_cache.append(nn.LeakyReLU(negative_slope=0.2))
391 |             self.rgb_layers_.append(nn.Sequential(*block_cache))
392 |         elif block_type == "Minibatch":
393 |             block_cache.append(MiniBatchAverageLayer())
394 |             block_cache.append(nn.Conv2d(in_channels + 1, out_channels,
395 |                                          kernel_size=3, stride=1, padding=1, bias=False))
396 |             block_cache.append(EqualizedLearningRateLayer(block_cache[-1]))
397 |             block_cache.append(nn.LeakyReLU(negative_slope=0.2))
398 |             block_cache.append(nn.Conv2d(out_channels, out_channels,
399 |                                          kernel_size=3, stride=1, padding=1, bias=False))
400 |             block_cache.append(EqualizedLearningRateLayer(block_cache[-1]))
401 |             block_cache.append(nn.LeakyReLU(negative_slope=0.2))
402 |             # block_cache.append(nn.Conv2d(out_channels, out_channels=1,
403 |             #                              kernel_size=1, stride=1, padding=0, bias=False))
404 |             # block_cache.append(EqualizedLearningRateLayer(block_cache[-1]))
405 | 
406 |             if self.is_sigmoid_ is True:
407 |                 block_cache.append(nn.Sigmoid())
408 |             self.lod_layers_.append(nn.Sequential(*block_cache))
409 |         else:
410 |             raise TypeError("'block_type' must be in ['Minibatch', 'DownSample', 'FromRGB']")
411 | 
412 |     def _get_channel_by_stage(self, level):
413 |         return min(int(self.fmap_base_ / (2.0 ** (level * self.fmap_decay_))), self.fmap_max_)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GANsynth-pytorch
2 | A simplified PyTorch implementation of GANSynth from Magenta.
3 | Some of the code is borrowed from the original Magenta repo.
4 | 
5 | ## Note
6 | This repo only supports the best setting in the GANSynth paper, which is a simplified version compared to the original TensorFlow implementation by the Magenta team.
7 | So if you want to test other frequency settings, you may need to modify the code.
8 | 
9 | 
10 | 
11 | ## Prepare Data
12 | Use the Make Training Data notebook to generate the HDF5 file for training.
13 | 
14 | ## Train a new model
15 | You first have to create a directory for saving model checkpoints and output spectrograms.
16 | ```sh
17 | python3 train.py
18 | ```
19 | 
20 | ## Inference
21 | Use the Inference notebook to load a trained model and generate audio.
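In outline, sampling from a trained generator looks like the sketch below. This is an illustrative sketch, not a drop-in script: the checkpoint name follows the pattern `train.py` saves for the final stage, while `latent_size` and `channel_list` are assumptions that must match whatever your training run actually used.

```python
import torch
from PGGAN import Generator

latent_size = 256                                      # assumption: match your training config
channel_list = [256, 256, 256, 256, 256, 128, 64, 32]  # hypothetical schedule, one entry per level
# resolution only sets the block count here: 2**8 -> 7 blocks -> 2x16 base grown to 128x1024
g_net = Generator(resolution=256, latent_size=latent_size, final_channel=2,
                  is_tanh=True, channel_list=channel_list).cuda()
g_net.load_state_dict(torch.load('test/Gnet_128x1024.pth'))
g_net.net_config = [6, "stable", 1.0]                  # final level, no fade-in

# Condition on MIDI pitch 60 with a 128-dim one-hot vector, as train.py does.
pitch = torch.zeros(1, 128).scatter_(1, torch.tensor([[60]]), 1.0)
z = torch.randn(1, latent_size, 1, 1)
fake = g_net(torch.cat((z, pitch.unsqueeze(2).unsqueeze(3)), dim=1).cuda())
print(fake.shape)  # (1, 2, 128, 1024): [log-mel magnitude, mel IF], still in [-1, 1]
```

The two output channels are still in the `DataNormalizer` range; the notebook denormalizes them, maps mel back to linear frequency with `spectrograms_helper.melspecgrams_to_specgrams`, and inverts the STFT to obtain audio. Note that `EqualizedLearningRateLayer.forward` casts to `torch.cuda.FloatTensor`, so inference requires a GPU.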
22 | 23 | ## Example Generated Audio 24 | https://drive.google.com/open?id=1tNnOtcqCpgTTXGmkHJBA4K6MalBjdsPC 25 | 26 | ## Reference 27 | - https://github.com/shanexn/pytorch-pggan 28 | - https://github.com/kwon-young/pytorch-nsynth 29 | - [GANsynth Magenta](https://github.com/tensorflow/magenta/tree/master/magenta/models/gansynth) 30 | - [GANsynth ICLR Paper](https://arxiv.org/abs/1902.08710) 31 | -------------------------------------------------------------------------------- /STFT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.autograd import Variable 6 | 7 | class STFT(torch.nn.Module): 8 | def __init__(self, filter_length=1024, hop_length=512): 9 | super(STFT, self).__init__() 10 | 11 | self.filter_length = filter_length 12 | self.hop_length = hop_length 13 | self.forward_transform = None 14 | scale = self.filter_length / self.hop_length 15 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 16 | 17 | cutoff = int((self.filter_length / 2 + 1)) 18 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 19 | np.imag(fourier_basis[:cutoff, :])]) 20 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 21 | inverse_basis = torch.FloatTensor(np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 22 | 23 | self.register_buffer('forward_basis', forward_basis.float()) 24 | self.register_buffer('inverse_basis', inverse_basis.float()) 25 | self.num_samples = 219904 26 | 27 | def transform(self, input_data): 28 | num_batches = input_data.size(0) 29 | num_samples = input_data.size(1) 30 | 31 | self.num_samples = num_samples 32 | 33 | input_data = input_data.view(num_batches, 1, num_samples) 34 | forward_transform = F.conv1d(input_data, 35 | Variable(self.forward_basis, requires_grad=False), 36 | stride = self.hop_length, 37 | padding = self.filter_length) 38 | cutoff = int((self.filter_length / 2) + 1) 39 | real_part = forward_transform[:, :cutoff, :] 40 | imag_part = forward_transform[:, cutoff:, :] 41 | 42 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 43 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 44 | return magnitude, phase 45 | 46 | def inverse(self, magnitude, phase): 47 | # print("magnitude",magnitude[0,0:2,0:10]) 48 | # print("phase",phase[0,0:2,0:10]) 49 | 50 | recombine_magnitude_phase = torch.cat([magnitude*torch.cos(phase), 51 | magnitude*torch.sin(phase)], dim=1) 52 | # print("recombine_magnitude_phase",recombine_magnitude_phase.size()) 53 | # print("recombine_magnitude_phase",recombine_magnitude_phase[0,0:2,0:10]) 54 | 55 | inverse_transform = F.conv_transpose1d(recombine_magnitude_phase, 56 | Variable(self.inverse_basis, requires_grad=False), 57 | stride=self.hop_length, 58 | padding=0) 59 | inverse_transform = inverse_transform[:, :, self.filter_length:] 60 | inverse_transform = inverse_transform[:, :, :self.num_samples] 61 | # print("inverse_transform",inverse_transform[0,0:2,0:10]) 62 | # print("inverse_transform",inverse_transform.size()) 63 | return inverse_transform 64 | 65 | def forward(self, input_data): 66 | self.magnitude, self.phase = self.transform(input_data) 67 | reconstruction = self.inverse(self.magnitude, self.phase) 68 | return reconstruction 69 | -------------------------------------------------------------------------------- /normalizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import torch 3 | 4 | 5 | 6 | class DataNormalizer(object): 7 | def __init__(self, dataloader): 8 | self.dataloader = dataloader 9 | # self.s_a = 0 10 | # self.s_b = 0 11 | # self.p_a = 0 12 | # self.p_b = 0 13 | 14 | 15 | self._range_normalizer(magnitude_margin=0.8, IF_margin=1.0) 16 | print("s_a:", self.s_a ) 17 | print("s_b:", self.s_b ) 18 | print("p_a:", self.p_a) 19 | print("p_b:", self.p_b) 20 | 21 | # def _range_normalizer(x, margin): 22 | # # x = x.flatten() 23 | # min_x = x.min() 24 | # max_x = x.max() 25 | 26 | # a = margin * (2.0 / (max_x - min_x)) 27 | # b = margin * (-2.0 * min_x / (max_x - min_x) - 1.0) 28 | # return a, b 29 | 30 | def _range_normalizer(self, magnitude_margin, IF_margin): 31 | # x = x.flatten() 32 | min_spec = 10000 33 | max_spec = -10000 34 | min_IF = 10000 35 | max_IF = -10000 36 | 37 | for batch_idx, (spec, IF, pitch_label, mel_spec, mel_IF) in enumerate(self.dataloader.train_loader): 38 | 39 | # training mel 40 | spec = mel_spec 41 | IF = mel_IF 42 | 43 | 44 | 45 | # print("spec",spec.shape) 46 | # print("IF",IF.shape) 47 | 48 | if spec.min() < min_spec: min_spec=spec.min() 49 | if spec.max() > max_spec: max_spec=spec.max() 50 | 51 | if IF.min() < min_IF: min_IF=IF.min() 52 | if IF.max() > max_IF: max_IF=IF.max() 53 | 54 | # print(min_spec) 55 | # print(max_spec) 56 | # print(min_IF) 57 | # print(max_IF) 58 | 59 | self.s_a = magnitude_margin * (2.0 / (max_spec - min_spec)) 60 | self.s_b = magnitude_margin * (-2.0 * min_spec / (max_spec - min_spec) - 1.0) 61 | 62 | self.p_a = IF_margin * (2.0 / (max_IF - min_IF)) 63 | self.p_b = IF_margin * (-2.0 * min_IF / (max_IF - min_IF) - 1.0) 64 | 65 | 66 | def normalize(self, feature_map): 67 | # print("feature_map",feature_map.shape) 68 | # spec = feature_map[:, :, :, 0] 69 | # IF = feature_map[:, :, :, 1] 70 | 71 | # s_a, s_b = self._range_normalizer(spec, 0.8) 72 | # p_a, p_b = self._range_normalizer(IF, 0.8) 73 | 74 | a = np.asarray([self.s_a, self.p_a])[None, :, None, None] 75 | b = np.asarray([self.s_b, self.p_b])[None, :, None, None] 76 | a = torch.FloatTensor(a).cuda() 77 | b = torch.FloatTensor(b).cuda() 78 | # print("feature_map",feature_map.shape) 79 | feature_map = feature_map *a + b 80 | # print("spec Max",feature_map[:,0,:,:].max()) 81 | # print("spec min",feature_map[:,0,:,:].min()) 82 | # print("IF Max",feature_map[:,1,:,:].max()) 83 | # print("IF min",feature_map[:,1,:,:].min()) 84 | 85 | # clip_spec = spec *s_a + s_b 86 | # clip_IF = IF*p_a + p_b 87 | # return clip_spec, clip_IF, (s_a, s_b), (p_a, p_b) 88 | return feature_map 89 | 90 | def denormalize(spec, IF, s_a, s_b, p_a, p_b): 91 | spec = (spec -s_b) / s_a 92 | IF = (IF-p_b) / p_a 93 | return spec, IF -------------------------------------------------------------------------------- /phase_operation.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | from intervaltree import Interval,IntervalTree 4 | 5 | 6 | def diff(x, axis): 7 | """Take the finite difference of a tensor along an axis. 8 | Args: 9 | x: Input tensor of any dimension. 10 | axis: Axis on which to take the finite difference. 11 | Returns: 12 | d: Tensor with size less than x by 1 along the difference dimension. 13 | Raises: 14 | ValueError: Axis out of range for tensor. 
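    Example (note the slicing below assumes a 2-D input):
        diff(np.array([[0., 1., 3.], [2., 2., 2.]]), axis=1)
        -> array([[1., 2.], [0., 0.]])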
15 | """ 16 | shape = x.shape 17 | 18 | begin_back = [0 for unused_s in range(len(shape))] 19 | # print("begin_back",begin_back) 20 | begin_front = [0 for unused_s in range(len(shape))] 21 | 22 | begin_front[axis] = 1 23 | # print("begin_front",begin_front) 24 | 25 | size = list(shape) 26 | size[axis] -= 1 27 | # print("size",size) 28 | slice_front = x[begin_front[0]:begin_front[0]+size[0], begin_front[1]:begin_front[1]+size[1]] 29 | slice_back = x[begin_back[0]:begin_back[0]+size[0], begin_back[1]:begin_back[1]+size[1]] 30 | 31 | # slice_front = tf.slice(x, begin_front, size) 32 | # slice_back = tf.slice(x, begin_back, size) 33 | # print("slice_front",slice_front) 34 | # print(slice_front.shape) 35 | # print("slice_back",slice_back) 36 | 37 | d = slice_front - slice_back 38 | return d 39 | 40 | 41 | def unwrap(p, discont=np.pi, axis=-1): 42 | """Unwrap a cyclical phase tensor. 43 | Args: 44 | p: Phase tensor. 45 | discont: Float, size of the cyclic discontinuity. 46 | axis: Axis of which to unwrap. 47 | Returns: 48 | unwrapped: Unwrapped tensor of same size as input. 49 | """ 50 | dd = diff(p, axis=axis) 51 | # print("dd",dd) 52 | ddmod = np.mod(dd+np.pi,2.0*np.pi)-np.pi # ddmod = tf.mod(dd + np.pi, 2.0 * np.pi) - np.pi 53 | # print("ddmod",ddmod) 54 | 55 | idx = np.logical_and(np.equal(ddmod, -np.pi),np.greater(dd,0)) # idx = tf.logical_and(tf.equal(ddmod, -np.pi), tf.greater(dd, 0)) 56 | # print("idx",idx) 57 | ddmod = np.where(idx, np.ones_like(ddmod) *np.pi, ddmod) # ddmod = tf.where(idx, tf.ones_like(ddmod) * np.pi, ddmod) 58 | # print("ddmod",ddmod) 59 | ph_correct = ddmod - dd 60 | # print("ph_corrct",ph_correct) 61 | 62 | idx = np.less(np.abs(dd), discont) # idx = tf.less(tf.abs(dd), discont) 63 | 64 | ddmod = np.where(idx, np.zeros_like(ddmod), dd) # ddmod = tf.where(idx, tf.zeros_like(ddmod), dd) 65 | ph_cumsum = np.cumsum(ph_correct, axis=axis) # ph_cumsum = tf.cumsum(ph_correct, axis=axis) 66 | # print("idx",idx) 67 | # print("ddmod",ddmod) 68 | # print("ph_cumsum",ph_cumsum) 69 | 70 | 71 | shape = np.array(p.shape) # shape = p.get_shape().as_list() 72 | 73 | shape[axis] = 1 74 | ph_cumsum = np.concatenate([np.zeros(shape, dtype=p.dtype), ph_cumsum], axis=axis) 75 | #ph_cumsum = tf.concat([tf.zeros(shape, dtype=p.dtype), ph_cumsum], axis=axis) 76 | unwrapped = p + ph_cumsum 77 | # print("unwrapped",unwrapped) 78 | return unwrapped 79 | 80 | 81 | def instantaneous_frequency(phase_angle, time_axis): 82 | """Transform a fft tensor from phase angle to instantaneous frequency. 83 | Unwrap and take the finite difference of the phase. Pad with initial phase to 84 | keep the tensor the same size. 85 | Args: 86 | phase_angle: Tensor of angles in radians. [Batch, Time, Freqs] 87 | time_axis: Axis over which to unwrap and take finite difference. 88 | Returns: 89 | dphase: Instantaneous frequency (derivative of phase). Same size as input. 
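        Because the unwrapped finite difference is divided by np.pi, the
        returned values are scaled to lie approximately in [-1, 1].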
90 | """ 91 | phase_unwrapped = unwrap(phase_angle, axis=time_axis) 92 | # print("phase_unwrapped",phase_unwrapped.shape) 93 | 94 | dphase = diff(phase_unwrapped, axis=time_axis) 95 | # print("dphase",dphase.shape) 96 | 97 | # Add an initial phase to dphase 98 | size = np.array(phase_unwrapped.shape) 99 | # size = phase_unwrapped.get_shape().as_list() 100 | 101 | size[time_axis] = 1 102 | # print("size",size) 103 | begin = [0 for unused_s in size] 104 | # phase_slice = tf.slice(phase_unwrapped, begin, size) 105 | # print("begin",begin) 106 | phase_slice = phase_unwrapped[begin[0]:begin[0]+size[0], begin[1]:begin[1]+size[1]] 107 | # print("phase_slice",phase_slice.shape) 108 | dphase = np.concatenate([phase_slice, dphase], axis=time_axis) / np.pi 109 | 110 | # dphase = tf.concat([phase_slice, dphase], axis=time_axis) / np.pi 111 | return dphase 112 | 113 | 114 | def polar2rect(mag, phase_angle): 115 | """Convert polar-form complex number to its rectangular form.""" 116 | # mag = np.complex(mag) 117 | temp_mag = np.zeros(mag.shape,dtype=np.complex_) 118 | temp_phase = np.zeros(mag.shape,dtype=np.complex_) 119 | 120 | for i, time in enumerate(mag): 121 | for j, time_id in enumerate(time): 122 | # print(mag[i,j]) 123 | temp_mag[i,j] = np.complex(mag[i,j]) 124 | # print(temp_mag[i,j]) 125 | 126 | for i, time in enumerate(phase_angle): 127 | for j, time_id in enumerate(time): 128 | temp_phase[i,j] = np.complex(np.cos(phase_angle[i,j]), np.sin(phase_angle[i,j])) 129 | # print(temp_mag[i,j]) 130 | 131 | # phase = np.complex(np.cos(phase_angle), np.sin(phase_angle)) 132 | 133 | return temp_mag * temp_phase 134 | -------------------------------------------------------------------------------- /pytorch_nsynth_lib/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for pytorch-nsynth.""" 4 | 5 | __author__ = """Kwon-Young Choi""" 6 | __email__ = 'kwon-young.choi@hotmail.fr' 7 | __version__ = '0.1.0' 8 | -------------------------------------------------------------------------------- /pytorch_nsynth_lib/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ss12f32v/GANsynth-pytorch/fd282b98f3392375cff7b2a53baa6592e40aed19/pytorch_nsynth_lib/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /pytorch_nsynth_lib/__pycache__/nsynth.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ss12f32v/GANsynth-pytorch/fd282b98f3392375cff7b2a53baa6592e40aed19/pytorch_nsynth_lib/__pycache__/nsynth.cpython-35.pyc -------------------------------------------------------------------------------- /pytorch_nsynth_lib/from_tfrecord.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: from_tfrecord.py 3 | Author: Kwon-Young Choi 4 | Email: kwon-young.choi@hotmail.fr 5 | Date: 2018-11-12 6 | Description: read nsynth dataset from tfrecord file 7 | """ 8 | import tensorflow as tf 9 | import autodebug 10 | 11 | 12 | def parser(serialized_example): 13 | features = tf.parse_single_example( 14 | serialized_example, 15 | features={ 16 | 'note': tf.FixedLenFeature([], tf.int64), 17 | 'instrument': tf.FixedLenFeature([], tf.int64), 18 | 'pitch': tf.FixedLenFeature([], tf.int64), 19 | 'velocity': tf.FixedLenFeature([], tf.int64), 20 | 'sample_rate': 
tf.FixedLenFeature([], tf.int64), 21 | 'audio': tf.FixedLenSequenceFeature( 22 | shape=[], dtype=tf.float32, allow_missing=True), 23 | 'qualities': tf.FixedLenSequenceFeature( 24 | shape=[], dtype=tf.int64, allow_missing=True), 25 | 'instrument_family': tf.FixedLenFeature([], tf.int64), 26 | 'instrument_source': tf.FixedLenFeature([], tf.int64), 27 | }) 28 | return features 29 | 30 | 31 | data_path = ['data/nsynth-test.tfrecord'] 32 | dataset = tf.data.TFRecordDataset(data_path) 33 | dataset = dataset.map(parser) 34 | dataset = dataset.batch(32) 35 | iterator = dataset.make_one_shot_iterator() 36 | batch_notes = iterator.get_next() 37 | 38 | with tf.Session() as sess: 39 | cpt = 0 40 | while True: 41 | print(cpt) 42 | try: 43 | out = sess.run(batch_notes) 44 | for key, value in out.items(): 45 | print(key, value.dtype, value.shape) 46 | except tf.errors.OutOfRangeError: 47 | break 48 | cpt += 1 49 | -------------------------------------------------------------------------------- /pytorch_nsynth_lib/nsynth.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: nsynth.py 3 | Author: Kwon-Young Choi 4 | Email: kwon-young.choi@hotmail.fr 5 | Date: 2018-11-13 6 | Description: Load NSynth dataset using pytorch Dataset. 7 | If you want to modify the output of the dataset, use the transform 8 | and target_transform callbacks as ususal. 9 | """ 10 | import os 11 | import json 12 | import glob 13 | import numpy as np 14 | import scipy.io.wavfile 15 | import torch 16 | import torch.utils.data as data 17 | import torchvision.transforms as transforms 18 | from sklearn.preprocessing import LabelEncoder 19 | 20 | 21 | class NSynth(data.Dataset): 22 | 23 | """Pytorch dataset for NSynth dataset 24 | args: 25 | root: root dir containing examples.json and audio directory with 26 | wav files. 27 | transform (callable, optional): A function/transform that takes in 28 | a sample and returns a transformed version. 29 | target_transform (callable, optional): A function/transform that takes 30 | in the target and transforms it. 31 | blacklist_pattern: list of string used to blacklist dataset element. 32 | If one of the string is present in the audio filename, this sample 33 | together with its metadata is removed from the dataset. 34 | categorical_field_list: list of string. Each string is a key like 35 | instrument_family that will be used as a classification target. 36 | Each field value will be encoding as an integer using sklearn 37 | LabelEncoder. 
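        Each __getitem__ call returns [audio sample, *encoded categorical
        targets, full json metadata dict], in that order.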
38 | """ 39 | 40 | def __init__(self, root, transform=None, target_transform=None, 41 | blacklist_pattern=[], 42 | categorical_field_list=["instrument_family"]): 43 | """Constructor""" 44 | assert(isinstance(root, str)) 45 | assert(isinstance(blacklist_pattern, list)) 46 | assert(isinstance(categorical_field_list, list)) 47 | self.root = root 48 | self.filenames = glob.glob(os.path.join(root, "audio/*.wav")) 49 | with open(os.path.join(root, "examples.json"), "r") as f: 50 | self.json_data = json.load(f) 51 | for pattern in blacklist_pattern: 52 | self.filenames, self.json_data = self.blacklist( 53 | self.filenames, self.json_data, pattern) 54 | self.categorical_field_list = categorical_field_list 55 | self.le = [] 56 | for i, field in enumerate(self.categorical_field_list): 57 | self.le.append(LabelEncoder()) 58 | field_values = [value[field] for value in self.json_data.values()] 59 | self.le[i].fit(field_values) 60 | self.transform = transform 61 | self.target_transform = target_transform 62 | 63 | def blacklist(self, filenames, json_data, pattern): 64 | filenames = [filename for filename in filenames 65 | if pattern not in filename] 66 | json_data = { 67 | key: value for key, value in json_data.items() 68 | if pattern not in key 69 | } 70 | return filenames, json_data 71 | 72 | def __len__(self): 73 | return len(self.filenames) 74 | 75 | def __getitem__(self, index): 76 | """ 77 | Args: 78 | index (int): Index 79 | Returns: 80 | tuple: (audio sample, *categorical targets, json_data) 81 | """ 82 | name = self.filenames[index] 83 | _, sample = scipy.io.wavfile.read(name) 84 | target = self.json_data[os.path.splitext(os.path.basename(name))[0]] 85 | categorical_target = [ 86 | le.transform([target[field]])[0] 87 | for field, le in zip(self.categorical_field_list, self.le)] 88 | if self.transform is not None: 89 | sample = self.transform(sample) 90 | if self.target_transform is not None: 91 | target = self.target_transform(target) 92 | return [sample, *categorical_target, target] 93 | 94 | 95 | if __name__ == "__main__": 96 | # audio samples are loaded as an int16 numpy array 97 | # rescale intensity range as float [-1, 1] 98 | toFloat = transforms.Lambda(lambda x: x / np.iinfo(np.int16).max) 99 | # use instrument_family and instrument_source as classification targets 100 | dataset = NSynth( 101 | "../nsynth-test", 102 | transform=toFloat, 103 | blacklist_pattern=["string"], # blacklist string instrument 104 | categorical_field_list=["instrument_family", "instrument_source"]) 105 | loader = data.DataLoader(dataset, batch_size=32, shuffle=True) 106 | for samples, instrument_family_target, instrument_source_target, targets \ 107 | in loader: 108 | print(samples.shape, instrument_family_target.shape, 109 | instrument_source_target.shape) 110 | print(torch.min(samples), torch.max(samples)) 111 | -------------------------------------------------------------------------------- /spec_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | # import tensorflow as tf 3 | 4 | # mel spectrum constants. 
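# (HTK-style scale: hertz_to_mel(f) = _MEL_HIGH_FREQUENCY_Q * log(1 + f / _MEL_BREAK_FREQUENCY_HERTZ),
#  with mel_to_hertz below as its exact inverse.)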
5 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 6 | _MEL_HIGH_FREQUENCY_Q = 1127.0 7 | 8 | 9 | def mel_to_hertz(mel_values): 10 | """Converts frequencies in `mel_values` from the mel scale to linear scale.""" 11 | return _MEL_BREAK_FREQUENCY_HERTZ * ( 12 | np.exp(np.array(mel_values) / _MEL_HIGH_FREQUENCY_Q) - 1.0) 13 | 14 | 15 | def hertz_to_mel(frequencies_hertz): 16 | """Converts frequencies in `frequencies_hertz` in Hertz to the mel scale.""" 17 | return _MEL_HIGH_FREQUENCY_Q * np.log( 18 | 1.0 + (np.array(frequencies_hertz) / _MEL_BREAK_FREQUENCY_HERTZ)) 19 | 20 | 21 | def linear_to_mel_weight_matrix(num_mel_bins=20, 22 | num_spectrogram_bins=129, 23 | sample_rate=16000, 24 | lower_edge_hertz=125.0, 25 | upper_edge_hertz=3800.0): 26 | """Returns a matrix to warp linear scale spectrograms to the mel scale. 27 | Adapted from tf.contrib.signal.linear_to_mel_weight_matrix with a minimum 28 | band width (in Hz scale) of 1.5 * freq_bin. To preserve accuracy, 29 | we compute the matrix at float64 precision and then cast to `dtype` 30 | at the end. This function can be constant folded by graph optimization 31 | since there are no Tensor inputs. 32 | Args: 33 | num_mel_bins: Int, number of output frequency dimensions. 34 | num_spectrogram_bins: Int, number of input frequency dimensions. 35 | sample_rate: Int, sample rate of the audio. 36 | lower_edge_hertz: Float, lowest frequency to consider. 37 | upper_edge_hertz: Float, highest frequency to consider. 38 | Returns: 39 | Numpy float32 matrix of shape [num_spectrogram_bins, num_mel_bins]. 40 | Raises: 41 | ValueError: Input argument in the wrong range. 42 | """ 43 | # Validate input arguments 44 | if num_mel_bins <= 0: 45 | raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) 46 | if num_spectrogram_bins <= 0: 47 | raise ValueError( 48 | 'num_spectrogram_bins must be positive. Got: %s' % num_spectrogram_bins) 49 | if sample_rate <= 0.0: 50 | raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) 51 | if lower_edge_hertz < 0.0: 52 | raise ValueError( 53 | 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz) 54 | if lower_edge_hertz >= upper_edge_hertz: 55 | raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % 56 | (lower_edge_hertz, upper_edge_hertz)) 57 | if upper_edge_hertz > sample_rate / 2: 58 | raise ValueError('upper_edge_hertz must not be larger than the Nyquist ' 59 | 'frequency (sample_rate / 2). Got: %s for sample_rate: %s' 60 | % (upper_edge_hertz, sample_rate)) 61 | 62 | # HTK excludes the spectrogram DC bin. 63 | bands_to_zero = 1 64 | nyquist_hertz = sample_rate / 2.0 65 | linear_frequencies = np.linspace( 66 | 0.0, nyquist_hertz, num_spectrogram_bins)[bands_to_zero:, np.newaxis] 67 | # spectrogram_bins_mel = hertz_to_mel(linear_frequencies) 68 | 69 | # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The 70 | # center of each band is the lower and upper edge of the adjacent bands. 71 | # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into 72 | # num_mel_bins + 2 pieces. 
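  # Adjacent bands therefore overlap: band i rises from edge i, peaks at
  # edge i+1, and falls back to zero at edge i+2, linearly in the mel domain.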
73 | band_edges_mel = np.linspace( 74 | hertz_to_mel(lower_edge_hertz), hertz_to_mel(upper_edge_hertz), 75 | num_mel_bins + 2) 76 | 77 | lower_edge_mel = band_edges_mel[0:-2] 78 | center_mel = band_edges_mel[1:-1] 79 | upper_edge_mel = band_edges_mel[2:] 80 | 81 | freq_res = nyquist_hertz / float(num_spectrogram_bins) 82 | freq_th = 1.5 * freq_res 83 | for i in range(0, num_mel_bins): 84 | center_hz = mel_to_hertz(center_mel[i]) 85 | lower_hz = mel_to_hertz(lower_edge_mel[i]) 86 | upper_hz = mel_to_hertz(upper_edge_mel[i]) 87 | if upper_hz - lower_hz < freq_th: 88 | rhs = 0.5 * freq_th / (center_hz + _MEL_BREAK_FREQUENCY_HERTZ) 89 | dm = _MEL_HIGH_FREQUENCY_Q * np.log(rhs + np.sqrt(1.0 + rhs**2)) 90 | lower_edge_mel[i] = center_mel[i] - dm 91 | upper_edge_mel[i] = center_mel[i] + dm 92 | 93 | lower_edge_hz = mel_to_hertz(lower_edge_mel)[np.newaxis, :] 94 | center_hz = mel_to_hertz(center_mel)[np.newaxis, :] 95 | upper_edge_hz = mel_to_hertz(upper_edge_mel)[np.newaxis, :] 96 | 97 | # Calculate lower and upper slopes for every spectrogram bin. 98 | # Line segments are linear in the mel domain, not Hertz. 99 | lower_slopes = (linear_frequencies - lower_edge_hz) / ( 100 | center_hz - lower_edge_hz) 101 | upper_slopes = (upper_edge_hz - linear_frequencies) / ( 102 | upper_edge_hz - center_hz) 103 | 104 | # Intersect the line segments with each other and zero. 105 | mel_weights_matrix = np.maximum(0.0, np.minimum(lower_slopes, upper_slopes)) 106 | 107 | # Re-add the zeroed lower bins we sliced out above. 108 | # [freq, mel] 109 | mel_weights_matrix = np.pad(mel_weights_matrix, [[bands_to_zero, 0], [0, 0]], 110 | 'constant') 111 | return mel_weights_matrix 112 | 113 | 114 | 115 | 116 | 117 | def random_phase_in_radians(shape, dtype): 118 | return np.pi * (2 * tf.random_uniform(shape, dtype=dtype) - 1.0) 119 | 120 | 121 | -------------------------------------------------------------------------------- /spectrograms_helper.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import spec_ops as spec_ops 4 | import phase_operation as phase_op 5 | 6 | def _linear_to_mel_matrix(): 7 | """Get the mel transformation matrix.""" 8 | _sample_rate=16000 9 | _mel_downscale=1 10 | num_freq_bins = 2048 // 2 11 | lower_edge_hertz = 0.0 12 | upper_edge_hertz = 16000 / 2.0 13 | num_mel_bins = num_freq_bins // _mel_downscale 14 | return spec_ops.linear_to_mel_weight_matrix( 15 | num_mel_bins, num_freq_bins, _sample_rate, lower_edge_hertz, 16 | upper_edge_hertz) 17 | 18 | def _mel_to_linear_matrix(): 19 | """Get the inverse mel transformation matrix.""" 20 | m = _linear_to_mel_matrix() 21 | m_t = np.transpose(m) 22 | p = np.matmul(m, m_t) 23 | d = [1.0 / x if np.abs(x) > 1.0e-8 else x for x in np.sum(p, axis=0)] 24 | return np.matmul(m_t, np.diag(d)) 25 | 26 | 27 | 28 | def melspecgrams_to_specgrams(logmelmag2, mel_p): 29 | """Converts melspecgrams to specgrams. 30 | Args: 31 | melspecgrams: Tensor of log magnitudes and instantaneous frequencies, 32 | shape [freq, time], mel scaling of frequencies. 33 | Returns: 34 | specgrams: Tensor of log magnitudes and instantaneous frequencies, 35 | shape [freq, time]. 
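    Note: _mel_to_linear_matrix is only a pseudo-inverse of the mel weight
    matrix, so converting specgrams to melspecgrams and back is an
    approximate round trip rather than an exact one.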
36 | """ 37 | logmelmag2 = logmelmag2.T 38 | mel_p = mel_p.T 39 | logmelmag2 = np.array([logmelmag2]) 40 | mel_p = np.array([mel_p]) 41 | 42 | 43 | mel2l = _mel_to_linear_matrix() 44 | mag2 = np.tensordot(np.exp(logmelmag2), mel2l, 1) 45 | logmag = 0.5 * np.log(mag2+1e-6) 46 | mel_phase_angle = np.cumsum(mel_p * np.pi, axis=1) 47 | phase_angle = np.tensordot(mel_phase_angle, mel2l, 1) 48 | p = phase_op.instantaneous_frequency(phase_angle,time_axis=1) 49 | return logmag[0].T, p[0].T 50 | 51 | 52 | def specgrams_to_melspecgrams(magnitude, IF): 53 | """Converts specgrams to melspecgrams. 54 | Args: 55 | specgrams: Tensor of log magnitudes and instantaneous frequencies, 56 | shape [freq, time]. 57 | Returns: 58 | melspecgrams: Tensor of log magnitudes and instantaneous frequencies, 59 | shape [freq, time], mel scaling of frequencies. 60 | """ 61 | logmag = magnitude.T 62 | p = IF.T 63 | mag2 = np.exp(2.0 * logmag) 64 | mag2 = np.array([mag2]) 65 | phase_angle = np.cumsum(p * np.pi, axis=1) 66 | phase_angle = np.array([phase_angle]) 67 | 68 | l2mel = _linear_to_mel_matrix() 69 | logmelmag2 = np.log(np.tensordot(mag2,l2mel,axes=1) + 1e-6) 70 | mel_phase_angle = np.tensordot(phase_angle, l2mel, axes=1) 71 | mel_p = phase_op.instantaneous_frequency(mel_phase_angle,time_axis=1) 72 | return logmelmag2[0].T, mel_p[0].T -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import warnings 4 | from PIL import Image 5 | from PGGAN import * 6 | from normalizer import DataNormalizer 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | import torchvision.utils as vutils 10 | 11 | import torch.utils.data as udata 12 | import torchvision.datasets as vdatasets 13 | import torchvision.transforms as transforms 14 | from sklearn.model_selection import train_test_split 15 | import h5py 16 | import matplotlib.pyplot as plt 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' 19 | 20 | 21 | folder ='test' 22 | BATCH_SIZE= 4 23 | DEVICE = torch.device("cuda:0") 24 | G_LR = 2e-4 25 | D_LR = 4e-4 26 | ADAM_BETA = (0.0, 0.99) 27 | ADAM_EPS = 1e-8 28 | LAMBDA_FOR_WGANGP = 1 29 | CRITIC_FOR_WGANGP = 1 30 | 31 | TOTAL_DATA_SIZE=11864 #validation Nsynth dataset size. If use trainingset, this argument should be changed 32 | 33 | 34 | class NsynthDataLoader(object): 35 | def __init__(self): 36 | cuda = torch.device("cuda:0") 37 | instr='guitar' 38 | # Nsynth_spec_IF_pitch.hdf5 39 | # self.dataset = h5py.File('../data/Nsynth_spec_IF_pitch.hdf5','r') 40 | self.dataset = h5py.File('data/Nsynth_valid_spec_IF_pitch_and_melSpec.hdf5','r') 41 | spec = self.dataset["Spec"][:] 42 | IF = self.dataset["IF"][:] 43 | pitch = self.dataset["pitch"][:] 44 | mel_spec = self.dataset["mel_Spec"][:] 45 | mel_IF = self.dataset["mel_IF"][:] 46 | 47 | self.train_dataset = udata.TensorDataset(torch.Tensor(spec), 48 | torch.Tensor(IF), 49 | torch.LongTensor(pitch), 50 | torch.Tensor(mel_spec), 51 | torch.Tensor(mel_IF) 52 | ) 53 | self.train_loader = udata.DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True) 54 | 55 | def change_batch_size(self, batch_size): 56 | self.train_loader = udata.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True) 57 | 58 | 59 | class PGGAN(object): 60 | def __init__(self, 61 | resolution, # Resolution. 62 | latent_size, # Dimensionality of the latent vectors. 
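                 # (The generator's first conv actually consumes latent_size + 128
                 #  inputs: _train concatenates a 128-dim one-hot pitch vector.)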
63 | dataloader, 64 | criterion_type="GAN", # ["GAN", "WGAN-GP"] 65 | rgb_channel=2, # Output channel size, for rgb always 3 66 | fmap_base=2 ** 11, # Overall multiplier for the number of feature maps. 67 | fmap_decay=1.0, # log2 feature map reduction when doubling the resolution. 68 | fmap_max=2 ** 6, # Maximum number of feature maps in any layer. 69 | is_tanh=True, 70 | is_sigmoid=True 71 | ): 72 | self.dataloader=dataloader 73 | self.latent_size_ = latent_size 74 | self.rgb_channel_ = rgb_channel 75 | self.fmap_base_ = fmap_base 76 | self.fmap_decay_ = fmap_decay 77 | self.fmap_max_ = fmap_max 78 | 79 | # self.stable_and_fadein_step = [6, 6, 6, 6, 6, 6, 30] 80 | self.stable_and_fadein_step = [1,1,1,1,1,1,1] 81 | 82 | 83 | 84 | 85 | self.criterion_type_ = criterion_type 86 | self.is_tanh_ = is_tanh 87 | self.is_sigmoid_ = False if self.criterion_type_ in ["WGAN-GP"] else is_sigmoid 88 | 89 | self.gradient_weight_real_ = torch.FloatTensor([-1]).cuda() 90 | self.gradient_weight_fake_ = torch.FloatTensor([1]).cuda() 91 | 92 | 93 | self.init_data_normalizer() 94 | self._init_network() 95 | self.avg_layer = torch.nn.AvgPool2d((2,2),stride=(2,2)) 96 | 97 | 98 | def train(self): 99 | 100 | self._init_optimizer() 101 | self._init_criterion() 102 | 103 | # Declare Model status 104 | net_level = 0 105 | net_status = "stable" 106 | net_alpha = 1.0 107 | 108 | 109 | for cur_level in range(7): 110 | self.stable_steps = self.stable_and_fadein_step[cur_level] 111 | self.fadein_steps = self.stable_and_fadein_step[cur_level] 112 | 113 | 114 | if cur_level==6: 115 | self.dataloader.change_batch_size(batch_size=4) 116 | self.stable_steps = 100 117 | 118 | if cur_level ==0: 119 | net_status == "stable" 120 | for step in range(self.stable_steps): 121 | self._train(cur_level, net_status, net_alpha, step) 122 | else: 123 | net_status = "fadein" 124 | 125 | for step in range(self.fadein_steps): 126 | # net_alpha = 1.0 - (step + 1) / fadein_steps 127 | self._train(cur_level, "fadein", net_alpha, step) 128 | 129 | for step in range(self.stable_steps*2): 130 | net_alpha = 1.0 131 | self._train(cur_level, "stable", net_alpha, step) 132 | if cur_level ==6: 133 | torch.save(self.g_net.state_dict(), folder + '/Gnet_%dx%d_step%d.pth' % (2 ** (cur_level + 1), 2 ** (cur_level + 4), step)) 134 | torch.save(self.d_net.state_dict(), folder + '/Dnet_%dx%d_step%d.pth' % (2 ** (cur_level + 1), 2 ** (cur_level + 4), step)) 135 | torch.save(self.g_net.state_dict(), folder + '/Gnet_%dx%d.pth' % (2 ** (cur_level + 1), 2 ** (cur_level + 4))) 136 | torch.save(self.d_net.state_dict(), folder + '/Dnet_%dx%d.pth' % (2 ** (cur_level + 1), 2 ** (cur_level + 4))) 137 | 138 | 139 | def _train(self, net_level, net_status, net_alpha, cur_step): 140 | 141 | current_level_res = 2 ** (net_level + 1) 142 | 143 | # for batch_idx, (score, target_spectrum, IF) in enumerate(self.dataloader.train_loader): 144 | for batch_idx, (spec, IF, pitch_label, mel_spec, mel_IF) in enumerate(self.dataloader.train_loader): 145 | 146 | 147 | # train mel spec IF 148 | spec = mel_spec 149 | IF = mel_IF 150 | 151 | stack_real_image = torch.stack((spec,IF),dim=1).cuda() 152 | stack_real_image = torch.transpose(stack_real_image,2,3) 153 | 154 | little_batch_size = spec.size()[0] 155 | 156 | 157 | 158 | while stack_real_image.size()[2] != current_level_res: 159 | stack_real_image = self.avg_layer(stack_real_image) 160 | stack_real_image = stack_real_image.cuda(0) 161 | 162 | 163 | if net_status =='stable': 164 | net_alpha = 1.0 165 | elif net_status =='fadein': 166 | 

    def _train(self, net_level, net_status, net_alpha, cur_step):
        current_level_res = 2 ** (net_level + 1)

        for batch_idx, (spec, IF, pitch_label, mel_spec, mel_IF) in enumerate(self.dataloader.train_loader):
            # Train on the mel-scaled magnitude and IF channels.
            spec = mel_spec
            IF = mel_IF

            stack_real_image = torch.stack((spec, IF), dim=1).cuda()
            stack_real_image = torch.transpose(stack_real_image, 2, 3)

            little_batch_size = spec.size()[0]

            # Downsample the real sample until it matches the current level's resolution.
            while stack_real_image.size()[2] != current_level_res:
                stack_real_image = self.avg_layer(stack_real_image)
            stack_real_image = stack_real_image.cuda(0)

            if net_status == 'stable':
                net_alpha = 1.0
            elif net_status == 'fadein':
                if little_batch_size == BATCH_SIZE:
                    net_alpha = 1.0 - (cur_step * TOTAL_DATA_SIZE + batch_idx * little_batch_size) / (self.fadein_steps * TOTAL_DATA_SIZE)
                else:
                    # Last, smaller batch of the epoch.
                    net_alpha = 1.0 - (cur_step * TOTAL_DATA_SIZE + batch_idx * BATCH_SIZE + little_batch_size) / (self.fadein_steps * TOTAL_DATA_SIZE)

            if net_alpha < 0.0:
                print("Alpha too small (< 0)")
                return

            # Propagate level/status/alpha to both networks.
            self.g_net.net_config = [net_level, net_status, net_alpha]
            self.d_net.net_config = [net_level, net_status, net_alpha]

            # Make fake condition vector: a random pitch in [0, 128) as one-hot.
            pitch_label = pitch_label.cuda()
            fake_pitch_label = torch.LongTensor(little_batch_size, 1).random_() % 128
            fake_one_hot_pitch_condition_vector = torch.zeros(little_batch_size, 128).scatter_(1, fake_pitch_label, 1).unsqueeze(2).unsqueeze(3).cuda()
            fake_pitch_label = fake_pitch_label.cuda().squeeze()

            # Generate a random latent vector and append the pitch condition.
            fake_seed = torch.randn(little_batch_size, self.latent_size_, 1, 1).cuda()
            fake_seed_and_pitch_condition = torch.cat((fake_seed, fake_one_hot_pitch_condition_vector), dim=1)

            fake_generated_sample = self.g_net(fake_seed_and_pitch_condition)

            stack_real_image = self.data_normalizer.normalize(stack_real_image)
            pitch_real, d_real = self.d_net(stack_real_image)
            pitch_fake, d_fake = self.d_net(fake_generated_sample.detach())

            # WGAN-GP: update d_net (real: -1, fake: +1).
            for p in self.d_net.parameters():
                p.requires_grad = True
            self.d_net.zero_grad()

            # Train D with real samples.
            mean_real = d_real.mean()  # WGAN loss
            # mean_real = nn.ReLU()(1.0 - d_real).mean()  # hinge loss

            # Train D with fake samples.
            mean_fake = d_fake.mean()  # WGAN loss
            # mean_fake = nn.ReLU()(1.0 + d_fake).mean()  # hinge loss

            # Train D with the gradient penalty.
            gradient_penalty = 10 * self._gradient_penalty(stack_real_image.data, fake_generated_sample.data, little_batch_size, current_level_res)

            # Train D with the auxiliary pitch-classifier loss.
            real_pitch_loss = self.NLL_loss(pitch_real, pitch_label)
            fake_pitch_loss = self.NLL_loss(pitch_fake, fake_pitch_label)
            p_loss = 10 * real_pitch_loss

            # D_loss = mean_fake + mean_real + gradient_penalty + p_loss  # hinge
            D_loss = mean_fake - mean_real + gradient_penalty + p_loss  # WGAN-GP
            Wasserstein_D = mean_real - mean_fake

            D_loss.backward()
            self.d_optim.step()
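
            # Objectives in play, written out for reference: the critic loss
            # above is
            #     L_D = E[D(fake)] - E[D(real)] + 10 * GP + 10 * NLL(pitch | real)
            # and the generator update below uses the mirrored terms
            #     L_G = -E[D(fake)] + 10 * NLL(pitch | fake)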

            if batch_idx % 3 == 0:  # update G every third batch only, to avoid mode collapse caused by an overly strong generator
                # Update g_net.
                for p in self.d_net.parameters():
                    p.requires_grad = False  # skip unnecessary gradient computation on D
                self.g_net.zero_grad()
                pitch_fake, d_fake = self.d_net(fake_generated_sample)
                mean_fake = d_fake.mean()

                fake_pitch_loss = self.NLL_loss(pitch_fake, fake_pitch_label)
                timed_fake_pitch_loss = 10 * fake_pitch_loss
                G_loss = -mean_fake + timed_fake_pitch_loss
                G_loss.backward()
                self.g_optim.step()

            if batch_idx % 1400 == 0:
                self.generate_picture(fake_generated_sample[:, 0, :, :], current_level_res, cur_step, batch_idx, net_status)
            if batch_idx % 200 == 0:
                print("Resolution:{}x{}, Status:{}, Cur_step:{}, Batch_id:{}, D_loss:{}, W_D:{}, M_fake:{}, M_Real:{}, GP:{}, Real_P_Loss:{}, Fake_P_Loss:{}, G_loss:{}, Net_alpha:{}".format(
                    current_level_res, current_level_res * (2 ** 3), net_status, cur_step, batch_idx,
                    D_loss, Wasserstein_D, mean_fake, mean_real,
                    gradient_penalty, real_pitch_loss, fake_pitch_loss,
                    G_loss, net_alpha))
                print("self.fadein_steps", self.fadein_steps * TOTAL_DATA_SIZE,
                      "cur_step", cur_step * TOTAL_DATA_SIZE + batch_idx * little_batch_size,
                      "net_alpha", net_alpha)

    def generate_picture(self, spec, resolution, step, batch_idx, status):
        spec = spec.data.cpu().numpy()
        stack_spec = np.hstack((spec[0], spec[1], spec[2], spec[3]))
        flip_stack = np.flipud(stack_spec)
        fig = plt.figure(figsize=(20, 7))
        plt.imshow(flip_stack, aspect='auto')  # flip so low frequencies sit at the bottom
        plt.savefig(folder + "/{}_{}_{}_{}_{}_sample.png".format(resolution, resolution * (2 ** 3), status, step, batch_idx))
        plt.close(fig)  # release the figure; this is called many times per run

    def _init_criterion(self):
        self.criterion = self._gradient_penalty
        self.NLL_loss = nn.NLLLoss()

    def init_data_normalizer(self):
        self.data_normalizer = DataNormalizer(self.dataloader)

    def _init_network(self):
        # Initialize Generator and Discriminator.
        print("Create Network")
        self.g_net = Generator(256, self.latent_size_, self.rgb_channel_,
                               is_tanh=self.is_tanh_, channel_list=[256, 256, 256, 256, 256, 128, 64, 32])
        self.d_net = Discriminator(256, self.rgb_channel_,
                                   is_sigmoid=self.is_sigmoid_, channel_list=[256, 256, 256, 256, 128, 64, 32, 32])
        self.g_net.cuda(0)
        self.d_net.cuda(0)
        print(self.g_net)
        print(self.d_net)

    def _init_optimizer(self):
        self.g_optim = optim.Adam(self.g_net.parameters(), lr=G_LR, betas=ADAM_BETA, eps=ADAM_EPS)
        self.d_optim = optim.Adam(self.d_net.parameters(), lr=D_LR, betas=ADAM_BETA, eps=ADAM_EPS)

    def _gradient_penalty(self, real_data, fake_data, batch_size, res):
        """
        Gradient penalty from page 4 of 'Improved Training of Wasserstein GANs'.
        Implementation adapted from https://github.com/caogang/wgan-gp
        """
        epsilon = torch.rand(batch_size, 1)
        # Integer division: expand() needs int sizes (nelement()/batch_size is a float in Python 3).
        epsilon = epsilon.expand(batch_size, real_data.nelement() // batch_size).contiguous().view(batch_size, 2, res, res * (2 ** 3))
        epsilon = epsilon.cuda()

        # Random point on the line between a real and a fake sample.
        median_x = epsilon * real_data + ((1 - epsilon) * fake_data)
        median_x = median_x.cuda()
        median_data = torch.autograd.Variable(median_x, requires_grad=True)

        _, d_median_data = self.d_net(median_data)

        gradients = torch.autograd.grad(outputs=d_median_data, inputs=median_data,
                                        grad_outputs=torch.ones(d_median_data.size()).cuda(),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA_FOR_WGANGP
        return gradient_penalty


def main():
    dataloader = NsynthDataLoader()
    p = PGGAN(512, 256, dataloader=dataloader, criterion_type="WGAN-GP")
    p.train()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------