├── cfgs
│   ├── config.py
│   └── config_latest.py
├── .gitignore
├── models
│   ├── __init__.py
│   ├── mss.py
│   ├── meta.py
│   ├── unet.py
│   └── svsgan.py
├── dataset.py
├── preprocess.py
├── README.md
├── main.py
└── train.py
--------------------------------------------------------------------------------
/cfgs/config.py:
--------------------------------------------------------------------------------
from .config_latest import cfg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*__pycache__
*.wav
.DS_Store
*.txt
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .unet import *
from .svsgan import *
from .mss import *
--------------------------------------------------------------------------------
/models/mss.py:
--------------------------------------------------------------------------------
# Placeholder for the MSS model (see the reference table in README.md).
import torch
import torch.nn as nn
import torch.nn.functional as F
--------------------------------------------------------------------------------
/cfgs/config_latest.py:
--------------------------------------------------------------------------------
from easydict import EasyDict as edict

cfg = edict()

cfg.len_frame = 1024       # STFT window size (n_fft)
cfg.len_hop = 1024 // 4    # hop length: 75% window overlap
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data_utils
import librosa
import numpy as np

from cfgs.config import cfg


class Data(data_utils.Dataset):
    def __init__(self, manifest):
        # The manifest lists one audio file path per line (see preprocess.py).
        self.files = [line.strip() for line in open(manifest).readlines()]

    def __getitem__(self, idx):
        # sound, sample_rate = torchaudio.load(self.files[idx])
        # Stereo input is assumed, with channel 0 holding the non-vocal
        # (instrumental) track, as in iKala.
        sound, sample_rate = librosa.load(self.files[idx], mono=False)
        mono = librosa.to_mono(sound)

        len_frame = cfg.len_frame
        len_hop = cfg.len_hop
        spectrogram_mono = librosa.stft(mono, n_fft=len_frame, hop_length=len_hop)
        spectrogram_nonvocal = librosa.stft(sound[0], n_fft=len_frame, hop_length=len_hop)
        # The STFT is complex; keep the magnitude. Casting complex values
        # straight to float32 would silently drop the imaginary part.
        spectrogram_mono = np.abs(spectrogram_mono).astype(np.float32)
        spectrogram_nonvocal = np.abs(spectrogram_nonvocal).astype(np.float32)

        return spectrogram_mono, spectrogram_nonvocal

    def __len__(self):
        return len(self.files)
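
# Illustrative sketch (not used elsewhere in this repo): given a predicted
# soft mask over the mixture magnitude, a time-domain estimate can be
# resynthesised with the mixture phase. The function name and the `mask`
# argument are assumptions for illustration, not part of the original code.
def reconstruct(mixture, mask):
    spec = librosa.stft(mixture, n_fft=cfg.len_frame, hop_length=cfg.len_hop)
    # Masked magnitude recombined with the original phase.
    estimate = np.abs(spec) * mask * np.exp(1j * np.angle(spec))
    return librosa.istft(estimate, hop_length=cfg.len_hop)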

if __name__ == '__main__':
    data = Data('data_train.txt')
    print(len(data))
    # Spectrogram lengths vary per file, so batch them one at a time here.
    loader = data_utils.DataLoader(data, batch_size=1)
    for batch in loader:
        print(batch)
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import os
import random


def preprocess(cfg):
    records = [os.path.join(root, filename)
               for root, dirs, filenames in os.walk(cfg.dir)
               for filename in filenames if filename.endswith('.wav')]
    # Shuffle before splitting so the test set is not biased by walk order.
    random.shuffle(records)
    total_num = len(records)
    test_num = int(cfg.test_ratio * total_num)
    train_num = total_num - test_num
    train_records = records[:train_num]
    test_records = records[train_num:]

    with open(cfg.name + '_all.txt', 'w') as f:
        f.writelines(i + '\n' for i in records)
    with open(cfg.name + '_train.txt', 'w') as f:
        f.writelines(i + '\n' for i in train_records)
    with open(cfg.name + '_test.txt', 'w') as f:
        f.writelines(i + '\n' for i in test_records)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', help='prefix of the generated manifest files.', default='data')
    parser.add_argument('--test_ratio', help='ratio of test data', type=float, default=0.1)
    parser.add_argument('--dir', help='directory to scan for .wav files', default='./')
    args = parser.parse_args()
    preprocess(args)
--------------------------------------------------------------------------------
/models/meta.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(object):
    """Template trainer: subclasses supply the network and the loss."""

    def __init__(self, args):
        pass

    def get_model(self):
        raise NotImplementedError

    def get_optimizer(self):
        raise NotImplementedError

    def loss_func(self, *args):
        raise NotImplementedError

    def train(self, train_iter, eval_iter, args):
        model = self.get_model()
        if args.cuda:
            model.cuda()

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        steps = 0
        model.train()
        for epoch in range(args.epochs):
            for batch in train_iter:
                # The DataLoader yields (feature, target) tuples.
                feature, target = batch
                if args.cuda:
                    feature, target = feature.cuda(), target.cuda()

                optimizer.zero_grad()
                output = model(feature)
                loss = self.loss_func(output, target)
                loss.backward()
                optimizer.step()

                steps += 1
                if steps % args.test_interval == 0 and eval_iter:
                    self.eval(eval_iter)
                if steps % args.save_interval == 0:
                    torch.save(model, args.save_path)

    def eval(self, eval_iter):
        raise NotImplementedError
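
# A minimal concrete subclass, as an illustrative sketch only: the linear
# stand-in network and the L1 loss are assumptions, not part of this repo.
class L1Model(Model):
    def get_model(self):
        return nn.Linear(8, 8)              # stand-in network

    def loss_func(self, output, target):
        return F.l1_loss(output, target)    # plugged into Model.train above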
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Overview

# Requirements
- pytorch
- librosa
- easydict
```shell
pip3 install librosa easydict
```
# Usage
- data helper

As the model only takes wave files as input, the preprocessing step collects all supported audio file paths into manifest files.
```bash
python3 preprocess.py --dir <data_dir>
```
- training
```bash
python3 main.py --model <model> --train --train_manifest <manifest>
```
- evaluation
```bash
python3 main.py --model <model> --eval --eval_manifest <manifest>
```

- predict
```shell
python3 main.py --model <model> --predict
```
# Results
Our results on iKala; numbers in parentheses are from the paper or official repo.

| Model | NSDR(Vocal) | NSDR(Instrumental) | SIR(Vocal) | SIR(Instrumental) | SAR(Vocal) | SAR(Instrumental) |
|:-----:|:-----------:|:------------------:|:----------:|:-----------------:|:----------:|:-----------------:|
| U-Net | (11.094) | (14.435) | (23.960) | (21.832) | (17.715) | (14.120) |
| SVSGAN | (-) | (-) | (23.70) | (-) | (14.10) | (-) |
| GRU-RIS | (-) | (-) | (23.70) | (-) | (14.10) | (-) |

# Reference
| Model | Original Paper | Official Repo |
|:-----:|:-----:|:-----:|
| U-Net | [Singing Voice Separation with Deep U-Net Convolutional Networks](https://ismir2017.smcnus.org/wp-content/uploads/2017/10/171_Paper.pdf) | - |
| SVSGAN | [SVSGAN: Singing Voice Separation via Generative Adversarial Network](https://arxiv.org/abs/1710.11428) | - |
| MSS | [Singing Voice Separation via Recurrent Inference and Skip-Filtering Connections](https://arxiv.org/abs/1711.01437) | [Js-Mim/mss_pytorch](https://github.com/Js-Mim/mss_pytorch) |

# Troubleshooting
1. `audioread.exceptions.NoBackendError` raised by `librosa.load`: audioread needs a decoding backend, and installing ffmpeg usually resolves it:
```shell
sudo apt-get install ffmpeg
```
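# Notes
With the defaults in `cfgs/config_latest.py` (`n_fft = 1024`, `hop_length = 256`), librosa produces spectrograms with `1 + 1024 // 2 = 513` frequency bins. A quick sanity check (illustrative only):
```python
import numpy as np
import librosa

y = np.zeros(22050, dtype=np.float32)   # one second at librosa's default 22.05 kHz
S = librosa.stft(y, n_fft=1024, hop_length=256)
print(S.shape)                          # (513, n_frames)
```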
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse

import torch.utils.data as data_utils

from models import *
from dataset import Data

models_dict = {
    'UNet': UNet
}


def get_args():
    def check_args(args):
        # Every operation needs a valid model.
        assert args.model in models_dict
        # Training requires a training manifest.
        if args.train:
            assert args.train_manifest is not None
        # Evaluation requires an evaluation manifest.
        if args.eval:
            assert args.eval_manifest is not None

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='model to use.')
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', default='0,1')
    parser.add_argument('--load', help='load weights')
    parser.add_argument('--batch_size', help='batch size', type=int, default=8)
    parser.add_argument('--max_epoch', help='maximum number of training epochs', type=int, default=80)
    parser.add_argument('--log_dir', help='directory of logging', default=None)
    parser.add_argument('--train_manifest', help='manifest listing training audio paths')
    parser.add_argument('--eval_manifest', help='manifest listing evaluation audio paths')

    parser.add_argument('--train', action='store_true')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--predict', action='store_true')

    args = parser.parse_args()
    check_args(args)
    return args


def main(args):
    train_loader = None
    if args.train_manifest:
        train_dataset = Data(args.train_manifest)
        train_loader = data_utils.DataLoader(train_dataset, args.batch_size)

    eval_loader = None
    if args.eval_manifest:
        eval_dataset = Data(args.eval_manifest)
        eval_loader = data_utils.DataLoader(eval_dataset, args.batch_size)

    net = models_dict[args.model](args)

    if args.train:
        net.train(train_loader, eval_loader, args)
    elif args.eval:
        net.eval(eval_loader, args)
    else:
        net.predict()


if __name__ == '__main__':
    args = get_args()
    main(args)
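
# Example invocations (illustrative; the manifest names are the
# preprocess.py defaults, the data directory is an assumption):
#   python3 preprocess.py --dir ./iKala --name data
#   python3 main.py --model UNet --train --train_manifest data_train.txt
#   python3 main.py --model UNet --eval --eval_manifest data_test.txt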
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# Standalone training script; main.py is the maintained entry point.
import argparse

import torch
import torch.nn.functional as F
import torch.utils.data as data_utils

from models import *
from dataset import Data


# Mirrors the architecture config in models/unet.py.
class config:
    class encoder:
        leakiness = 0.2
        ch_in = [1, 16, 32, 64, 128, 256]
        ch_out = [16, 32, 64, 128, 256, 512]
        kernel_size = (5, 5)
        stride = 2

    class decoder:
        ch_in = [512, 512, 256, 128, 64, 32]
        ch_out = [256, 128, 64, 32, 16]
        kernel_size = (5, 5)
        stride = 2


def train(model, train_iter, dev_iter, loss_func, args):
    """
    # Arguments
        model: network to optimize
        train_iter: iterator over training batches
        dev_iter: iterator over validation batches
        loss_func: callable mapping (output, target) to a scalar loss
        args: namespace providing cuda, lr, epochs, test_interval,
              save_interval and save_path
    """
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    model.train()
    for epoch in range(args.epochs):
        for batch in train_iter:
            feature, target = batch
            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()

            output = model(feature)

            loss = loss_func(output, target)
            loss.backward()
            optimizer.step()

            steps += 1

            if steps % args.test_interval == 0:
                eval(model, dev_iter, loss_func, args)
            if steps % args.save_interval == 0:
                torch.save(model, args.save_path)


def eval(model, data_iter, loss_func, args):
    model.eval()

    for batch in data_iter:
        feature, target = batch
        output = model(feature)
        loss = loss_func(output, target)

    model.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', default='0,1')
    parser.add_argument('--load', help='load model weights')
    parser.add_argument('--batch_size', help='batch size', type=int, default=8)
    parser.add_argument('--max_epoch', help='maximum number of training epochs', type=int, default=80)
    parser.add_argument('--log_dir', help="directory of logging", default=None)
    args = parser.parse_args()

    train_dataset = Data('data_train.txt')
    test_dataset = Data('data_test.txt')
    train_loader = data_utils.DataLoader(train_dataset, args.batch_size)
    test_loader = data_utils.DataLoader(test_dataset, args.batch_size)

    net = UNet(args)
    # The hyper-parameters below are assumptions for this standalone script.
    args.cuda = torch.cuda.is_available()
    args.lr = 1e-3
    args.epochs = args.max_epoch
    args.test_interval = 100
    args.save_interval = 100
    args.save_path = 'unet.pt'
    train(net.model, train_loader, test_loader, F.l1_loss, args)
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import floor, ceil

__all__ = ['UNet']


def same_padding_conv(x, conv):
    """Apply conv with 'same'-style padding so spatial dims scale exactly by the stride."""
    dim = len(x.size())
    if dim == 4:
        b, c, h, w = x.size()
    elif dim == 5:
        b, t, c, h, w = x.size()
    else:
        raise NotImplementedError()

    if isinstance(conv, nn.Conv2d):
        # PyTorch orders kernel_size/stride as (height, width).
        padding = (h // conv.stride[0] - 1) * conv.stride[0] + conv.kernel_size[0] - h
        padding_t = floor(padding / 2)
        padding_b = ceil(padding / 2)
        padding = (w // conv.stride[1] - 1) * conv.stride[1] + conv.kernel_size[1] - w
        padding_l = floor(padding / 2)
        padding_r = ceil(padding / 2)
        # F.pad takes (left, right, top, bottom) for the last two dims.
        x = F.pad(x, pad=(padding_l, padding_r, padding_t, padding_b))
        x = conv(x)
    elif isinstance(conv, nn.ConvTranspose2d):
        # conv.stride is a tuple and must be indexed.
        padding = (h - 1) * conv.stride[0] + conv.kernel_size[0] - h * conv.stride[0]
        padding_t = floor(padding / 2)
        padding_b = ceil(padding / 2)
        padding = (w - 1) * conv.stride[1] + conv.kernel_size[1] - w * conv.stride[1]
        padding_l = floor(padding / 2)
        padding_r = ceil(padding / 2)
        x = conv(x)
        x = x[:, :, padding_t:-padding_b, padding_l:-padding_r]
    else:
        raise NotImplementedError()
    return x


class config:
    class encoder:
        leakiness = 0.2
        ch_in = [1, 16, 32, 64, 128, 256]
        ch_out = [16, 32, 64, 128, 256, 512]
        kernel_size = (5, 5)
        stride = 2

    class decoder:
        ch_in = [512, 512, 256, 128, 64, 32]
        ch_out = [256, 128, 64, 32, 16]
        kernel_size = (5, 5)
        stride = 2


def loss_func(x, predict, y):
    # L1 distance between the masked mixture and the target spectrogram.
    return torch.sum(torch.abs(x * predict - y))


class Net(nn.Module):
    def __init__(self, config):
        super(Net, self).__init__()
        # ModuleList (not a plain Python list) so the layers' parameters
        # are registered with the module.
        self.convs = nn.ModuleList()
        self.deconvs = nn.ModuleList()
        self.kernel_size = config.encoder.kernel_size
        self.stride = config.encoder.stride
        for i in range(len(config.encoder.ch_out)):
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=config.encoder.ch_in[i],
                        out_channels=config.encoder.ch_out[i],
                        kernel_size=self.kernel_size,
                        stride=self.stride
                    ),
                    nn.BatchNorm2d(config.encoder.ch_out[i]),
                    nn.LeakyReLU(config.encoder.leakiness),
                )
            )
        for i in range(len(config.decoder.ch_out)):
            self.deconvs.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        in_channels=config.decoder.ch_in[i],
                        out_channels=config.decoder.ch_out[i],
                        kernel_size=config.decoder.kernel_size,
                        stride=config.decoder.stride
                    ),
                    nn.BatchNorm2d(config.decoder.ch_out[i]),
                    nn.ReLU()
                )
            )

    def forward(self, x):
        skip_connections = []
        for layer_idx, conv in enumerate(self.convs):
            x = same_padding_conv(x, conv)
            if layer_idx != len(self.convs) - 1:
                skip_connections.append(x)
        for layer_idx, deconv in enumerate(self.deconvs):
            x = same_padding_conv(x, deconv)
            if layer_idx < 3:
                # 50% dropout on the first three decoder layers, as in the U-Net paper.
                x = F.dropout2d(x, p=0.5)
            x = torch.cat([skip_connections.pop(), x], dim=1)
        return x


class UNet(object):
    def __init__(self, args):
        self.batch_size = args.batch_size
        self.model = Net(config)

    def train(self, train_iter, dev_iter, args):
        if args.cuda:
            self.model.cuda()

        optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)

        steps = 0
        self.model.train()
        for epoch in range(args.epochs):
            for batch in train_iter:
                feature, target = batch
                if args.cuda:
                    feature, target = feature.cuda(), target.cuda()

                optimizer.zero_grad()
                output = self.model(feature)
                loss = loss_func(feature, output, target)
                loss.backward()
                optimizer.step()

                steps += 1

                if steps % args.test_interval == 0 and dev_iter:
                    self.eval(dev_iter, args)
                if steps % args.save_interval == 0:
                    torch.save(self.model, args.save_path)

    def eval(self, data_iter, args):
        pass


if __name__ == '__main__':
    import numpy as np

    unet = Net(config)
    print(unet)

    out = unet(torch.Tensor(np.ones((7, 1, 512, 128))))
    print(out.size())
--------------------------------------------------------------------------------
/models/svsgan.py:
--------------------------------------------------------------------------------
# NOTE: adapted from a conditional-GAN reference implementation and still
# shaped for MNIST-sized inputs; legacy PyTorch idioms (Variable, volatile,
# .data[0]) are kept as in that code. The SVS data pipeline is not wired up yet.
import os
import pickle
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# import utils

__all__ = ['SVSGAN']


class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # (The original `dataset == 'mnist' or 'fashion-mnist'` was always true.)
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62 + 10
            self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        # utils.initialize_weights(self)

    def forward(self, input, label):
        x = torch.cat([input, label], 1)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x


class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1 + 10
            self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        # utils.initialize_weights(self)

    def forward(self, input, label):
        x = torch.cat([input, label], 1)
        x = self.conv(x)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x


class SVSGAN(object):
    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        # utils.print_network(self.G)
        # utils.print_network(self.D)
        print('-----------------------------------------------')

        # load the dataset; train() expects self.data_X / self.data_Y to be set here.
        # self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62
        self.y_dim = 10

        # fixed noise & condition
        self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
        for i in range(10):
            self.sample_z_[i * self.y_dim] = torch.rand(1, self.z_dim)
            for j in range(1, self.y_dim):
                self.sample_z_[i * self.y_dim + j] = self.sample_z_[i * self.y_dim]

        temp = torch.zeros((10, 1))
        for i in range(self.y_dim):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(10):
            temp_y[i * self.y_dim: (i + 1) * self.y_dim] = temp

        self.sample_y_ = torch.zeros((self.sample_num, self.y_dim))
        self.sample_y_.scatter_(1, temp_y.type(torch.LongTensor), 1)
        if self.gpu_mode:
            self.sample_z_, self.sample_y_ = Variable(self.sample_z_.cuda(), volatile=True), Variable(self.sample_y_.cuda(), volatile=True)
        else:
            self.sample_z_, self.sample_y_ = Variable(self.sample_z_, volatile=True), Variable(self.sample_y_, volatile=True)

    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.fill = torch.zeros([10, 10, self.data_X.size()[2], self.data_X.size()[3]])
        for i in range(10):
            self.fill[i, i, :, :] = 1

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter * self.batch_size:(iter + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                y_vec_ = self.data_Y[iter * self.batch_size:(iter + 1) * self.batch_size]
                y_fill_ = self.fill[torch.max(y_vec_, 1)[1].squeeze()]

                if self.gpu_mode:
                    x_, z_, y_vec_, y_fill_ = Variable(x_.cuda()), Variable(z_.cuda()), \
                                              Variable(y_vec_.cuda()), Variable(y_fill_.cuda())
                else:
                    x_, z_, y_vec_, y_fill_ = Variable(x_), Variable(z_), Variable(y_vec_), Variable(y_fill_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_, y_fill_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
                                                                        self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        # utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
        #                          self.epoch)
        # utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_, self.sample_y_)
        else:
            """ random noise """
            temp = torch.LongTensor(self.batch_size, 1).random_() % 10
            sample_y_ = torch.FloatTensor(self.batch_size, 10)
            sample_y_.zero_()
            sample_y_.scatter_(1, temp, 1)
            if self.gpu_mode:
                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True), \
                                       Variable(sample_y_.cuda(), volatile=True)
            else:
                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True), \
                                       Variable(sample_y_, volatile=True)

            samples = self.G(sample_z_, sample_y_)

        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        # utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
        #                   self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
--------------------------------------------------------------------------------