├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── pytorch-WGAN-GP.iml ├── rSettings.xml └── vcs.xml ├── README.md ├── __init__.py ├── dataset.py ├── display_result.py ├── img ├── generated_images.png └── paper1.png ├── layer.py ├── main.py ├── model.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | datasets/ 3 | log/ 4 | results/ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml 3 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/pytorch-WGAN-GP.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/rSettings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 
| -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WGAN-GP 2 | 3 | ### Title 4 | [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028) 5 | 6 | ### Abstract 7 | Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only low-quality samples or fail to converge. We find that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input. Our proposed method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures with almost no hyperparameter tuning, including 101-layer ResNets and language models over discrete data. We also achieve high quality generations on CIFAR-10 and LSUN bedrooms. 8 | 9 | ![alt text](./img/paper1.png "Novelty of WGAN-GP") 10 | 11 | ## Train 12 | $ python main.py --mode train \ 13 | --scope [scope name] \ 14 | --name_data [data name] \ 15 | --dir_data [data directory] \ 16 | --dir_log [log directory] \ 17 | --dir_checkpoint [checkpoint directory] \ 18 | --gpu_ids [gpu id; '-1': no gpu, '0, 1, ..., N-1': gpus] 19 | --- 20 | $ python main.py --mode train \ 21 | --scope wgan-gp \ 22 | --name_data celeba \ 23 | --dir_data ./datasets \ 24 | --dir_log ./log \ 25 | --dir_checkpoint ./checkpoints \ 26 | --gpu_ids 0 27 | 28 | * Set **[scope name]** uniquely. 29 | * Hyperparameters are written to **arg.txt** under the **[log directory]**. 30 | * To understand how the directory hierarchy follows from these arguments, see **Directories structure** below. 
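The way the command-line arguments map onto directories can be sketched in Python. This is a minimal illustration, assuming each output root is joined with the scope and the data name as shown in the directories structure section; the helper name `build_dirs` is hypothetical and is not part of this repository.

```python
import os


def build_dirs(dir_checkpoint, dir_log, dir_result, dir_data, scope, name_data):
    """Hypothetical helper: compose per-run directories from the CLI arguments."""
    return {
        # checkpoints, logs and results are nested as [root]/[scope]/[name_data]
        "checkpoint": os.path.join(dir_checkpoint, scope, name_data),
        "log": os.path.join(dir_log, scope, name_data),
        "result": os.path.join(dir_result, scope, name_data),
        # generated images go into an "images" subfolder of the result directory
        "result_images": os.path.join(dir_result, scope, name_data, "images"),
        # the dataset is looked up as [dir_data]/[name_data], without the scope
        "data": os.path.join(dir_data, name_data),
    }


dirs = build_dirs("./checkpoints", "./log", "./results", "./datasets",
                  scope="wgan-gp", name_data="celeba")
for key, path in dirs.items():
    print(key, "->", path)
```

Because the scope is part of every output path, reusing a scope name for a different experiment would write into the same checkpoint and log folders, which is why the scope should be unique.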
31 | 32 | 33 | ## Test 34 | $ python main.py --mode test \ 35 | --scope [scope name] \ 36 | --name_data [data name] \ 37 | --dir_data [data directory] \ 38 | --dir_log [log directory] \ 39 | --dir_checkpoint [checkpoint directory] \ 40 | --dir_result [result directory] \ 41 | --gpu_ids [gpu id; '-1': no gpu, '0, 1, ..., N-1': gpus] 42 | --- 43 | $ python main.py --mode test \ 44 | --scope wgan-gp \ 45 | --name_data celeba \ 46 | --dir_data ./datasets \ 47 | --dir_log ./log \ 48 | --dir_checkpoint ./checkpoints \ 49 | --dir_result ./results \ 50 | --gpu_ids 0 51 | 52 | * To test a trained network, set **[scope name]** to the one used in the **train** phase. 53 | * Generated images are saved in the **images** subfolder under the **[result directory]** folder. 54 | * **index.html** is also generated to display the generated images. 55 | 56 | 57 | ## Tensorboard 58 | $ tensorboard --logdir [log directory]/[scope name]/[data name] \ 59 | --port [(optional) 4 digit port number] 60 | --- 61 | $ tensorboard --logdir ./log/wgan-gp/celeba \ 62 | --port 6006 63 | 64 | After the above command runs, go to **http://localhost:6006** 65 | 66 | * You can change **[(optional) 4 digit port number]**. 67 | * The default port number is **6006**. 68 | 69 | 70 | ## Results 71 | ![alt text](./img/generated_images.png "Generated Images by WGAN-GP") 72 | * The results were generated by a network trained on the **celeba** dataset for **10 epochs**. 73 | * After the Test phase runs, execute **display_result.py** to display the figure. 74 | 75 | 76 | ## Directories structure 77 | pytorch-WGAN-GP 78 | +---[dir_checkpoint] 79 | | \---[scope] 80 | | \---[name_data] 81 | | +---model_epoch00000.pth 82 | | | ... 83 | | \---model_epoch12345.pth 84 | +---[dir_data] 85 | | \---[name_data] 86 | | +---000000.png 87 | | | ... 
88 | | \---12345.png 89 | +---[dir_log] 90 | | \---[scope] 91 | | \---[name_data] 92 | | +---arg.txt 93 | | \---events.out.tfevents 94 | \---[dir_result] 95 | \---[scope] 96 | \---[name_data] 97 | +---images 98 | | +---00000-output.png 99 | | | ... 100 | | +---12345-output.png 101 | \---index.html 102 | 103 | --- 104 | 105 | pytorch-WGAN-GP 106 | +---checkpoints 107 | | \---wgan-gp 108 | | \---celeba 109 | | +---model_epoch00001.pth 110 | | | ... 111 | | \---model_epoch0010.pth 112 | +---datasets 113 | | \---celeba 114 | | +---000001.jpg 115 | | | ... 116 | | \---202599.jpg 117 | +---log 118 | | \---wgan-gp 119 | | \---celeba 120 | | +---arg.txt 121 | | \---events.out.tfevents 122 | \---results 123 | \---wgan-gp 124 | \---celeba 125 | +---images 126 | | +---0000-output.png 127 | | | ... 128 | | +---0127-output.png 129 | \---index.html 130 | 131 | * Above directory is created by setting arguments when **main.py** is executed. 132 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanyoseob/pytorch-WGAN-GP/311745b5e05828c71d8bc22d9dd10ccdae4ab000/__init__.py -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from skimage import transform 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | class Dataset(torch.utils.data.Dataset): 9 | def __init__(self, data_dir, data_type='float32', nch=1, transform=[]): 10 | self.data_dir = data_dir 11 | self.transform = transform 12 | self.nch = nch 13 | self.data_type = data_type 14 | 15 | lst_data = os.listdir(data_dir) 16 | 17 | self.names = lst_data 18 | 19 | def __getitem__(self, index): 20 | data = plt.imread(os.path.join(self.data_dir, self.names[index]))[:, :, :self.nch] 21 | 22 | if 
data.dtype == np.uint8: 23 | data = data / 255.0 24 | 25 | if self.transform: 26 | data = self.transform(data) 27 | 28 | return data 29 | 30 | def __len__(self): 31 | return len(self.names) 32 | 33 | 34 | class ToTensor(object): 35 | def __call__(self, data): 36 | data = data.transpose((2, 0, 1)).astype(np.float32) 37 | return torch.from_numpy(data) 38 | 39 | 40 | class Normalize(object): 41 | def __call__(self, data): 42 | data = 2 * data - 1 43 | return data 44 | 45 | 46 | class RandomFlip(object): 47 | def __call__(self, data): 48 | if np.random.rand() > 0.5: 49 | data = np.fliplr(data) 50 | 51 | return data 52 | 53 | 54 | class Rescale(object): 55 | def __init__(self, output_size): 56 | assert isinstance(output_size, (int, tuple)) 57 | self.output_size = output_size 58 | 59 | def __call__(self, data): 60 | h, w = data.shape[:2] 61 | 62 | if isinstance(self.output_size, int): 63 | if h > w: 64 | new_h, new_w = self.output_size * h / w, self.output_size 65 | else: 66 | new_h, new_w = self.output_size, self.output_size * w / h 67 | else: 68 | new_h, new_w = self.output_size 69 | 70 | new_h, new_w = int(new_h), int(new_w) 71 | 72 | data = transform.resize(data, (new_h, new_w)) 73 | return data 74 | 75 | 76 | class CenterCrop(object): 77 | def __init__(self, output_size): 78 | assert isinstance(output_size, (int, tuple)) 79 | if isinstance(output_size, int): 80 | self.output_size = (output_size, output_size) 81 | else: 82 | assert len(output_size) == 2 83 | self.output_size = output_size 84 | 85 | def __call__(self, data): 86 | h, w = data.shape[:2] 87 | 88 | new_h, new_w = self.output_size 89 | 90 | top = int(abs(h - new_h) / 2) 91 | left = int(abs(w - new_w) / 2) 92 | 93 | data = data[top: top + new_h, left: left + new_w] 94 | 95 | return data 96 | 97 | 98 | class RandomCrop(object): 99 | 100 | def __init__(self, output_size): 101 | 102 | assert isinstance(output_size, (int, tuple)) 103 | if isinstance(output_size, int): 104 | self.output_size = (output_size, 
output_size) 105 | else: 106 | assert len(output_size) == 2 107 | self.output_size = output_size 108 | 109 | def __call__(self, data): 110 | h, w = data.shape[:2] 111 | 112 | new_h, new_w = self.output_size 113 | 114 | top = np.random.randint(0, h - new_h + 1) 115 | left = np.random.randint(0, w - new_w + 1) 116 | 117 | data = data[top: top + new_h, left: left + new_w] 118 | return data 119 | 120 | 121 | class ToNumpy(object): 122 | def __call__(self, data): 123 | 124 | if data.ndim == 3: 125 | data = data.to('cpu').detach().numpy().transpose((1, 2, 0)) 126 | elif data.ndim == 4: 127 | data = data.to('cpu').detach().numpy().transpose((0, 2, 3, 1)) 128 | 129 | return data 130 | 131 | 132 | class Denomalize(object): 133 | def __call__(self, data): 134 | 135 | return (data + 1) / 2 136 | -------------------------------------------------------------------------------- /display_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torchvision.utils as vutils 5 | import matplotlib.pyplot as plt 6 | 7 | dir_result = './results/wgan-gp/celeba/images' 8 | lst_result = os.listdir(dir_result) 9 | 10 | np.random.shuffle(lst_result) 11 | 12 | nx = 64 13 | ny = 64 14 | nch = 3 15 | 16 | n = 8 17 | m = 4 18 | 19 | n_id = np.arange(len(lst_result)//m) 20 | np.random.shuffle(n_id) 21 | img = torch.zeros((n*m, ny, nx, nch)) 22 | 23 | for i in range(n*m): 24 | p = n_id[i] 25 | img[i, :, :, :] = torch.from_numpy(plt.imread(os.path.join(dir_result, lst_result[p]))[:, :, :nch]) 26 | 27 | img = img.permute((0, 3, 1, 2)) 28 | 29 | plt.figure(figsize=(n, m)) 30 | plt.axis("off") 31 | # plt.title("Generated Images") 32 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0) 33 | plt.imshow(np.transpose(vutils.make_grid(img, padding=2, normalize=True), (1, 2, 0))) 34 | 35 | plt.show() 36 | 37 | -------------------------------------------------------------------------------- 
/img/generated_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanyoseob/pytorch-WGAN-GP/311745b5e05828c71d8bc22d9dd10ccdae4ab000/img/generated_images.png -------------------------------------------------------------------------------- /img/paper1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanyoseob/pytorch-WGAN-GP/311745b5e05828c71d8bc22d9dd10ccdae4ab000/img/paper1.png -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CNR2d(nn.Module): 7 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]): 8 | super().__init__() 9 | 10 | if bias == []: 11 | if norm == 'bnorm': 12 | bias = False 13 | else: 14 | bias = True 15 | 16 | layers = [] 17 | layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)] 18 | 19 | if norm != []: 20 | layers += [Norm2d(nch_out, norm)] 21 | 22 | if relu != []: 23 | layers += [ReLU(relu)] 24 | 25 | if drop != []: 26 | layers += [nn.Dropout2d(drop)] 27 | 28 | self.cbr = nn.Sequential(*layers) 29 | 30 | def forward(self, x): 31 | return self.cbr(x) 32 | 33 | 34 | class DECNR2d(nn.Module): 35 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]): 36 | super().__init__() 37 | 38 | if bias == []: 39 | if norm == 'bnorm': 40 | bias = False 41 | else: 42 | bias = True 43 | 44 | layers = [] 45 | layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)] 46 | 47 | if norm != []: 48 | layers += [Norm2d(nch_out, norm)] 49 | 50 | 
if relu != []: 51 | layers += [ReLU(relu)] 52 | 53 | if drop != []: 54 | layers += [nn.Dropout2d(drop)] 55 | 56 | self.decbr = nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | return self.decbr(x) 60 | 61 | 62 | class ResBlock(nn.Module): 63 | def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]): 64 | super().__init__() 65 | 66 | if bias == []: 67 | if norm == 'bnorm': 68 | bias = False 69 | else: 70 | bias = True 71 | 72 | layers = [] 73 | 74 | # 1st conv 75 | layers += [Padding(padding, padding_mode=padding_mode)] 76 | layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)] 77 | 78 | if drop != []: 79 | layers += [nn.Dropout2d(drop)] 80 | 81 | # 2nd conv 82 | layers += [Padding(padding, padding_mode=padding_mode)] 83 | layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])] 84 | 85 | self.resblk = nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | return x + self.resblk(x) 89 | 90 | 91 | class CNR1d(nn.Module): 92 | def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]): 93 | super().__init__() 94 | 95 | if norm == 'bnorm': 96 | bias = False 97 | else: 98 | bias = True 99 | 100 | layers = [] 101 | layers += [nn.Linear(nch_in, nch_out, bias=bias)] 102 | 103 | if norm != []: 104 | layers += [Norm2d(nch_out, norm)] 105 | 106 | if relu != []: 107 | layers += [ReLU(relu)] 108 | 109 | if drop != []: 110 | layers += [nn.Dropout2d(drop)] 111 | 112 | self.cbr = nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | return self.cbr(x) 116 | 117 | 118 | class Conv2d(nn.Module): 119 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True): 120 | super(Conv2d, self).__init__() 121 | self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) 122 | 123 | def forward(self, x): 
124 | return self.conv(x) 125 | 126 | 127 | class Deconv2d(nn.Module): 128 | def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True): 129 | super(Deconv2d, self).__init__() 130 | self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias) 131 | 132 | # layers = [nn.Upsample(scale_factor=2, mode='bilinear'), 133 | # nn.ReflectionPad2d(1), 134 | # nn.Conv2d(nch_in , nch_out, kernel_size=3, stride=1, padding=0)] 135 | # 136 | # self.deconv = nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | return self.deconv(x) 140 | 141 | 142 | class Linear(nn.Module): 143 | def __init__(self, nch_in, nch_out): 144 | super(Linear, self).__init__() 145 | self.linear = nn.Linear(nch_in, nch_out) 146 | 147 | def forward(self, x): 148 | return self.linear(x) 149 | 150 | 151 | class Norm2d(nn.Module): 152 | def __init__(self, nch, norm_mode): 153 | super(Norm2d, self).__init__() 154 | if norm_mode == 'bnorm': 155 | self.norm = nn.BatchNorm2d(nch) 156 | elif norm_mode == 'inorm': 157 | self.norm = nn.InstanceNorm2d(nch) 158 | 159 | def forward(self, x): 160 | return self.norm(x) 161 | 162 | 163 | class ReLU(nn.Module): 164 | def __init__(self, relu): 165 | super(ReLU, self).__init__() 166 | if relu > 0: 167 | self.relu = nn.LeakyReLU(relu, True) 168 | elif relu == 0: 169 | self.relu = nn.ReLU(True) 170 | 171 | def forward(self, x): 172 | return self.relu(x) 173 | 174 | 175 | class Padding(nn.Module): 176 | def __init__(self, padding, padding_mode='zeros', value=0): 177 | super(Padding, self).__init__() 178 | if padding_mode == 'reflection': 179 | self. 
padding = nn.ReflectionPad2d(padding) 180 | elif padding_mode == 'replication': 181 | self.padding = nn.ReplicationPad2d(padding) 182 | elif padding_mode == 'constant': 183 | self.padding = nn.ConstantPad2d(padding, value) 184 | elif padding_mode == 'zeros': 185 | self.padding = nn.ZeroPad2d(padding) 186 | 187 | def forward(self, x): 188 | return self.padding(x) 189 | 190 | 191 | class Pooling2d(nn.Module): 192 | def __init__(self, nch=[], pool=2, type='avg'): 193 | super().__init__() 194 | 195 | if type == 'avg': 196 | self.pooling = nn.AvgPool2d(pool) 197 | elif type == 'max': 198 | self.pooling = nn.MaxPool2d(pool) 199 | elif type == 'conv': 200 | self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool) 201 | 202 | def forward(self, x): 203 | return self.pooling(x) 204 | 205 | 206 | class UnPooling2d(nn.Module): 207 | def __init__(self, nch=[], pool=2, type='nearest'): 208 | super().__init__() 209 | 210 | if type == 'nearest': 211 | self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest') 212 | elif type == 'bilinear': 213 | self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True) 214 | elif type == 'conv': 215 | self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool) 216 | 217 | def forward(self, x): 218 | return self.unpooling(x) 219 | 220 | 221 | class Concat(nn.Module): 222 | def __init__(self): 223 | super().__init__() 224 | 225 | def forward(self, x1, x2): 226 | diffy = x2.size()[2] - x1.size()[2] 227 | diffx = x2.size()[3] - x1.size()[3] 228 | 229 | x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2, 230 | diffy // 2, diffy - diffy // 2]) 231 | 232 | return torch.cat([x2, x1], dim=1) 233 | 234 | 235 | class TV1dLoss(nn.Module): 236 | def __init__(self): 237 | super(TV1dLoss, self).__init__() 238 | 239 | def forward(self, input): 240 | # loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \ 241 | # torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, 
:])) 242 | loss = torch.mean(torch.abs(input[:, :-1] - input[:, 1:])) 243 | 244 | return loss 245 | 246 | 247 | class TV2dLoss(nn.Module): 248 | def __init__(self): 249 | super(TV2dLoss, self).__init__() 250 | 251 | def forward(self, input): 252 | loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \ 253 | torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :])) 254 | return loss 255 | 256 | 257 | class SSIM2dLoss(nn.Module): 258 | def __init__(self): 259 | super(SSIM2dLoss, self).__init__() 260 | 261 | def forward(self, input, targer): 262 | loss = 0 263 | return loss 264 | 265 | 266 | class GradientPaneltyLoss(nn.Module): 267 | def __init__(self): 268 | super(GradientPaneltyLoss, self).__init__() 269 | 270 | def forward(self, y, x): 271 | """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" 272 | weight = torch.ones_like(y) 273 | dydx = torch.autograd.grad(outputs=y, 274 | inputs=x, 275 | grad_outputs=weight, 276 | retain_graph=True, 277 | create_graph=True, 278 | only_inputs=True)[0] 279 | 280 | dydx = dydx.view(dydx.size(0), -1) 281 | dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1)) 282 | return torch.mean((dydx_l2norm - 1) ** 2) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.backends.cudnn as cudnn 3 | 4 | from train import * 5 | from utils import * 6 | 7 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 8 | 9 | cudnn.benchmark = True 10 | cudnn.fastest = True 11 | 12 | ## setup parse 13 | parser = argparse.ArgumentParser(description='Train the WGAN-GP network', 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | 16 | parser.add_argument('--gpu_ids', default='0', dest='gpu_ids') 17 | 18 | parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode') 19 | parser.add_argument('--train_continue', default='off', choices=['on', 'off'], 
dest='train_continue') 20 | 21 | parser.add_argument('--scope', default='wgan-gp', dest='scope') 22 | parser.add_argument('--norm', type=str, default='inorm', dest='norm') 23 | 24 | parser.add_argument('--name_data', type=str, default='celeba', dest='name_data') 25 | 26 | parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint') 27 | parser.add_argument('--dir_log', default='./log', dest='dir_log') 28 | 29 | parser.add_argument('--dir_data', default='../datasets', dest='dir_data') 30 | parser.add_argument('--dir_result', default='./results', dest='dir_result') 31 | 32 | parser.add_argument('--num_epoch', type=int, default=10, dest='num_epoch') 33 | parser.add_argument('--batch_size', type=int, default=128, dest='batch_size') 34 | 35 | parser.add_argument('--lr_G', type=float, default=2e-4, dest='lr_G') 36 | parser.add_argument('--lr_D', type=float, default=2e-4, dest='lr_D') 37 | 38 | parser.add_argument('--num_freq_disp', type=int, default=50, dest='num_freq_disp') 39 | parser.add_argument('--num_freq_save', type=int, default=5, dest='num_freq_save') 40 | 41 | parser.add_argument('--lr_policy', type=str, default='linear', choices=['linear', 'step', 'plateau', 'cosine'], dest='lr_policy') 42 | parser.add_argument('--n_epochs', type=int, default=100, dest='n_epochs') 43 | parser.add_argument('--n_epochs_decay', type=int, default=100, dest='n_epochs_decay') 44 | parser.add_argument('--lr_decay_iters', type=int, default=50, dest='lr_decay_iters') 45 | 46 | parser.add_argument('--wgt_gan', type=float, default=1e0, dest='wgt_gan') 47 | parser.add_argument('--wgt_disc', type=float, default=1e0, dest='wgt_disc') 48 | 49 | parser.add_argument('--optim', default='adam', choices=['sgd', 'adam', 'rmsprop'], dest='optim') 50 | parser.add_argument('--beta1', type=float, default=0.5, dest='beta1') 51 | 52 | parser.add_argument('--ny_in', type=int, default=1, dest='ny_in') 53 | parser.add_argument('--nx_in', type=int, default=1, dest='nx_in') 54 | 
parser.add_argument('--nch_in', type=int, default=100, dest='nch_in') 55 | 56 | parser.add_argument('--ny_load', type=int, default=64, dest='ny_load') 57 | parser.add_argument('--nx_load', type=int, default=64, dest='nx_load') 58 | parser.add_argument('--nch_load', type=int, default=3, dest='nch_load') 59 | 60 | parser.add_argument('--ny_out', type=int, default=64, dest='ny_out') 61 | parser.add_argument('--nx_out', type=int, default=64, dest='nx_out') 62 | parser.add_argument('--nch_out', type=int, default=3, dest='nch_out') 63 | 64 | parser.add_argument('--nch_ker', type=int, default=64, dest='nch_ker') 65 | 66 | parser.add_argument('--data_type', default='float32', dest='data_type') 67 | 68 | PARSER = Parser(parser) 69 | 70 | def main(): 71 | ARGS = PARSER.get_arguments() 72 | PARSER.write_args() 73 | PARSER.print_args() 74 | 75 | TRAINER = Train(ARGS) 76 | 77 | if ARGS.mode == 'train': 78 | TRAINER.train() 79 | elif ARGS.mode == 'test': 80 | TRAINER.test() 81 | 82 | if __name__ == '__main__': 83 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from layer import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.optim import lr_scheduler 7 | 8 | 9 | class DCGAN(nn.Module): 10 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm'): 11 | super(DCGAN, self).__init__() 12 | 13 | self.nch_in = nch_in 14 | self.nch_out = nch_out 15 | self.nch_ker = nch_ker 16 | self.norm = norm 17 | 18 | if norm == 'bnorm': 19 | self.bias = False 20 | else: 21 | self.bias = True 22 | 23 | self.dec5 = DECNR2d(1 * self.nch_in, 8 * self.nch_ker, kernel_size=4, stride=1, padding=0, norm=self.norm, relu=0.0, drop=[]) 24 | self.dec4 = DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0, drop=[]) 25 | self.dec3 = DECNR2d(4 * self.nch_ker, 2 * 
self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0, drop=[]) 26 | self.dec2 = DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0, drop=[]) 27 | self.dec1 = Deconv2d(1 * self.nch_ker, 1 * self.nch_out,kernel_size=4, stride=2, padding=1, bias=False) 28 | 29 | def forward(self, x): 30 | 31 | x = self.dec5(x) 32 | x = self.dec4(x) 33 | x = self.dec3(x) 34 | x = self.dec2(x) 35 | x = self.dec1(x) 36 | 37 | x = torch.tanh(x) 38 | 39 | return x 40 | 41 | 42 | class UNet(nn.Module): 43 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm'): 44 | super(UNet, self).__init__() 45 | 46 | self.nch_in = nch_in 47 | self.nch_out = nch_out 48 | self.nch_ker = nch_ker 49 | self.norm = norm 50 | 51 | if norm == 'bnorm': 52 | self.bias = False 53 | else: 54 | self.bias = True 55 | 56 | self.enc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 57 | self.enc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 58 | self.enc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 59 | self.enc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 60 | self.enc5 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 61 | self.enc6 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 62 | self.enc7 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[]) 63 | self.enc8 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[]) 64 | 65 | self.dec8 = DECNR2d(1 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=0.5) 66 | self.dec7 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=0.5) 67 | self.dec6 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, 
norm=self.norm, relu=0.0, drop=0.5) 68 | self.dec5 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[]) 69 | self.dec4 = DECNR2d(2 * 8 * self.nch_ker, 4 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[]) 70 | self.dec3 = DECNR2d(2 * 4 * self.nch_ker, 2 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[]) 71 | self.dec2 = DECNR2d(2 * 2 * self.nch_ker, 1 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[]) 72 | self.dec1 = DECNR2d(2 * 1 * self.nch_ker, 1 * self.nch_out, stride=2, norm=[], relu=[], drop=[], bias=False) 73 | 74 | def forward(self, x): 75 | 76 | enc1 = self.enc1(x) 77 | enc2 = self.enc2(enc1) 78 | enc3 = self.enc3(enc2) 79 | enc4 = self.enc4(enc3) 80 | enc5 = self.enc5(enc4) 81 | enc6 = self.enc6(enc5) 82 | enc7 = self.enc7(enc6) 83 | enc8 = self.enc8(enc7) 84 | 85 | dec8 = self.dec8(enc8) 86 | dec7 = self.dec7(torch.cat([enc7, dec8], dim=1)) 87 | dec6 = self.dec6(torch.cat([enc6, dec7], dim=1)) 88 | dec5 = self.dec5(torch.cat([enc5, dec6], dim=1)) 89 | dec4 = self.dec4(torch.cat([enc4, dec5], dim=1)) 90 | dec3 = self.dec3(torch.cat([enc3, dec4], dim=1)) 91 | dec2 = self.dec2(torch.cat([enc2, dec3], dim=1)) 92 | dec1 = self.dec1(torch.cat([enc1, dec2], dim=1)) 93 | 94 | x = torch.tanh(dec1) 95 | 96 | return x 97 | 98 | 99 | class ResNet(nn.Module): 100 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=6): 101 | super(ResNet, self).__init__() 102 | 103 | self.nch_in = nch_in 104 | self.nch_out = nch_out 105 | self.nch_ker = nch_ker 106 | self.norm = norm 107 | self.nblk = nblk 108 | 109 | if norm == 'bnorm': 110 | self.bias = False 111 | else: 112 | self.bias = True 113 | 114 | self.enc1 = CNR2d(self.nch_in, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0) 115 | 116 | self.enc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) 117 | 118 | self.enc3 = CNR2d(2 * self.nch_ker, 4 * 
self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) 119 | 120 | if self.nblk: 121 | res = [] 122 | 123 | for i in range(self.nblk): 124 | res += [ResBlock(4 * self.nch_ker, 4 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')] 125 | 126 | self.res = nn.Sequential(*res) 127 | 128 | self.dec3 = DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) 129 | 130 | self.dec2 = DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0) 131 | 132 | self.dec1 = CNR2d(1 * self.nch_ker, self.nch_out, kernel_size=7, stride=1, padding=3, norm=[], relu=[], bias=False) 133 | 134 | def forward(self, x): 135 | x = self.enc1(x) 136 | x = self.enc2(x) 137 | x = self.enc3(x) 138 | 139 | if self.nblk: 140 | x = self.res(x) 141 | 142 | x = self.dec3(x) 143 | x = self.dec2(x) 144 | x = self.dec1(x) 145 | 146 | x = torch.tanh(x) 147 | 148 | return x 149 | 150 | 151 | class Discriminator(nn.Module): 152 | def __init__(self, nch_in, nch_ker=64, norm='bnorm'): 153 | super(Discriminator, self).__init__() 154 | 155 | self.nch_in = nch_in 156 | self.nch_ker = nch_ker 157 | self.norm = norm 158 | 159 | if norm == 'bnorm': 160 | self.bias = False 161 | else: 162 | self.bias = True 163 | 164 | # dsc1 : 256 x 256 x 3 -> 128 x 128 x 64 165 | # dsc2 : 128 x 128 x 64 -> 64 x 64 x 128 166 | # dsc3 : 64 x 64 x 128 -> 32 x 32 x 256 167 | # dsc4 : 32 x 32 x 256 -> 32 x 32 x 512 168 | # dsc5 : 32 x 32 x 512 -> 32 x 32 x 1 169 | 170 | self.dsc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 171 | self.dsc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 172 | self.dsc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 173 | self.dsc4 = CNR2d(4 * self.nch_ker, 8 * 
self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2) 174 | self.dsc5 = CNR2d(8 * self.nch_ker, 1, kernel_size=4, stride=1, padding=1, norm=[], relu=[], bias=False) 175 | 176 | # self.dsc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2) 177 | # self.dsc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2) 178 | # self.dsc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2) 179 | # self.dsc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=1, padding=1, norm=[], relu=0.2) 180 | # self.dsc5 = CNR2d(8 * self.nch_ker, 1, kernel_size=4, stride=1, padding=1, norm=[], relu=[], bias=False) 181 | 182 | def forward(self, x): 183 | 184 | x = self.dsc1(x) 185 | x = self.dsc2(x) 186 | x = self.dsc3(x) 187 | x = self.dsc4(x) 188 | x = self.dsc5(x) 189 | 190 | # x = torch.sigmoid(x) 191 | 192 | return x 193 | 194 | 195 | def init_weights(net, init_type='normal', init_gain=0.02): 196 | """Initialize network weights. 197 | 198 | Parameters: 199 | net (network) -- network to be initialized 200 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 201 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 202 | 203 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 204 | work better for some applications. Feel free to try yourself. 
205 | """ 206 | def init_func(m): # define the initialization function 207 | classname = m.__class__.__name__ 208 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 209 | if init_type == 'normal': 210 | init.normal_(m.weight.data, 0.0, init_gain) 211 | elif init_type == 'xavier': 212 | init.xavier_normal_(m.weight.data, gain=init_gain) 213 | elif init_type == 'kaiming': 214 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 215 | elif init_type == 'orthogonal': 216 | init.orthogonal_(m.weight.data, gain=init_gain) 217 | else: 218 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 219 | if hasattr(m, 'bias') and m.bias is not None: 220 | init.constant_(m.bias.data, 0.0) 221 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 222 | init.normal_(m.weight.data, 1.0, init_gain) 223 | init.constant_(m.bias.data, 0.0) 224 | 225 | print('initialize network with %s' % init_type) 226 | net.apply(init_func) # apply the initialization function 227 | 228 | 229 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 230 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 231 | Parameters: 232 | net (network) -- the network to be initialized 233 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 234 | gain (float) -- scaling factor for normal, xavier and orthogonal. 235 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 236 | 237 | Return an initialized network. 
238 | """ 239 | if gpu_ids: 240 | assert(torch.cuda.is_available()) 241 | net.to(gpu_ids[0]) 242 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 243 | init_weights(net, init_type, init_gain=init_gain) 244 | return net 245 | 246 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from dataset import * 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from torchvision import transforms 8 | from torch.utils.tensorboard import SummaryWriter 9 | import matplotlib.pyplot as plt 10 | from matplotlib.pyplot import cm 11 | from statistics import mean 12 | 13 | 14 | class Train: 15 | def __init__(self, args): 16 | self.mode = args.mode 17 | self.train_continue = args.train_continue 18 | 19 | self.scope = args.scope 20 | self.dir_checkpoint = args.dir_checkpoint 21 | self.dir_log = args.dir_log 22 | 23 | self.dir_data = args.dir_data 24 | self.dir_result = args.dir_result 25 | 26 | self.num_epoch = args.num_epoch 27 | self.batch_size = args.batch_size 28 | 29 | self.lr_G = args.lr_G 30 | self.lr_D = args.lr_D 31 | 32 | self.wgt_gan = args.wgt_gan 33 | self.wgt_disc = args.wgt_disc 34 | 35 | self.optim = args.optim 36 | self.beta1 = args.beta1 37 | 38 | self.ny_in = args.ny_in 39 | self.nx_in = args.nx_in 40 | self.nch_in = args.nch_in 41 | 42 | self.ny_load = args.ny_load 43 | self.nx_load = args.nx_load 44 | self.nch_load = args.nch_load 45 | 46 | self.ny_out = args.ny_out 47 | self.nx_out = args.nx_out 48 | self.nch_out = args.nch_out 49 | 50 | self.nch_ker = args.nch_ker 51 | 52 | self.data_type = args.data_type 53 | self.norm = args.norm 54 | 55 | self.gpu_ids = args.gpu_ids 56 | 57 | self.num_freq_disp = args.num_freq_disp 58 | self.num_freq_save = args.num_freq_save 59 | 60 | self.name_data = args.name_data 61 | 62 | if self.gpu_ids and torch.cuda.is_available(): 63 | self.device = torch.device("cuda:%d" % 
self.gpu_ids[0]) 64 | torch.cuda.set_device(self.gpu_ids[0]) 65 | else: 66 | self.device = torch.device("cpu") 67 | 68 | def save(self, dir_chck, netG, netD, optimG, optimD, epoch): 69 | if not os.path.exists(dir_chck): 70 | os.makedirs(dir_chck) 71 | 72 | torch.save({'netG': netG.state_dict(), 'netD': netD.state_dict(), 73 | 'optimG': optimG.state_dict(), 'optimD': optimD.state_dict()}, 74 | '%s/model_epoch%04d.pth' % (dir_chck, epoch)) 75 | 76 | def load(self, dir_chck, netG, netD=[], optimG=[], optimD=[], epoch=[], mode='train'): 77 | if not epoch: 78 | ckpt = os.listdir(dir_chck) 79 | ckpt.sort() 80 | epoch = int(ckpt[-1].split('epoch')[1].split('.pth')[0]) 81 | 82 | dict_net = torch.load('%s/model_epoch%04d.pth' % (dir_chck, epoch)) 83 | 84 | print('Loaded %dth network' % epoch) 85 | 86 | if mode == 'train': 87 | netG.load_state_dict(dict_net['netG']) 88 | netD.load_state_dict(dict_net['netD']) 89 | optimG.load_state_dict(dict_net['optimG']) 90 | optimD.load_state_dict(dict_net['optimD']) 91 | 92 | return netG, netD, optimG, optimD, epoch 93 | 94 | elif mode == 'test': 95 | netG.load_state_dict(dict_net['netG']) 96 | 97 | return netG, epoch 98 | 99 | def preprocess(self, data): 100 | rescale = Rescale((self.ny_load, self.nx_load)) 101 | randomcrop = RandomCrop((self.ny_out, self.nx_out)) 102 | normalize = Normalize() 103 | randomflip = RandomFlip() 104 | totensor = ToTensor() 105 | # return totensor(randomcrop(rescale(randomflip(nomalize(data))))) 106 | return totensor(normalize(rescale(data))) 107 | 108 | def deprocess(self, data): 109 | tonumpy = ToNumpy() 110 | denomalize = Denomalize() 111 | return denomalize(tonumpy(data)) 112 | 113 | 114 | def train(self): 115 | mode = self.mode 116 | 117 | train_continue = self.train_continue 118 | num_epoch = self.num_epoch 119 | 120 | lr_G = self.lr_G 121 | lr_D = self.lr_D 122 | 123 | wgt_gan = self.wgt_gan 124 | wgt_disc = self.wgt_disc 125 | 126 | batch_size = self.batch_size 127 | device = self.device 128 | 129 | 
gpu_ids = self.gpu_ids 130 | 131 | nch_in = self.nch_in 132 | nch_out = self.nch_out 133 | nch_ker = self.nch_ker 134 | 135 | norm = self.norm 136 | name_data = self.name_data 137 | 138 | num_freq_disp = self.num_freq_disp 139 | num_freq_save = self.num_freq_save 140 | 141 | ny_in = self.ny_in 142 | nx_in = self.nx_in 143 | 144 | ## setup dataset 145 | dir_chck = os.path.join(self.dir_checkpoint, self.scope, name_data) 146 | 147 | dir_data_train = os.path.join(self.dir_data, name_data) 148 | dir_log = os.path.join(self.dir_log, self.scope, name_data) 149 | 150 | transform_train = transforms.Compose([Normalize(), Rescale((self.ny_load, self.nx_load)), ToTensor()]) 151 | transform_inv = transforms.Compose([ToNumpy(), Denomalize()]) 152 | 153 | dataset_train = Dataset(dir_data_train, data_type=self.data_type, nch=self.nch_out, transform=transform_train) 154 | 155 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8) 156 | 157 | num_train = len(dataset_train) 158 | 159 | num_batch_train = int((num_train / batch_size) + ((num_train % batch_size) != 0)) 160 | 161 | ## setup network 162 | netG = DCGAN(nch_in, nch_out, nch_ker, norm) 163 | netD = Discriminator(nch_out, nch_ker, []) 164 | 165 | init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids) 166 | init_net(netD, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids) 167 | 168 | ## setup loss & optimization 169 | fn_GAN = nn.BCEWithLogitsLoss().to(device) 170 | fn_GP = GradientPaneltyLoss().to(device) 171 | 172 | paramsG = netG.parameters() 173 | paramsD = netD.parameters() 174 | 175 | optimG = torch.optim.Adam(paramsG, lr=lr_G, betas=(self.beta1, 0.999)) 176 | optimD = torch.optim.Adam(paramsD, lr=lr_D, betas=(self.beta1, 0.999)) 177 | 178 | # schedG = get_scheduler(optimG, self.opts) 179 | # schedD = get_scheduler(optimD, self.opts) 180 | 181 | # schedG = torch.optim.lr_scheduler.ExponentialLR(optimG, gamma=0.9) 182 | # schedD = 
torch.optim.lr_scheduler.ExponentialLR(optimD, gamma=0.9) 183 | 184 | ## load from checkpoints 185 | st_epoch = 0 186 | 187 | if train_continue == 'on': 188 | netG, netD, optimG, optimD, st_epoch = self.load(dir_chck, netG, netD, optimG, optimD, mode=mode) 189 | 190 | ## setup tensorboard 191 | writer_train = SummaryWriter(log_dir=dir_log) 192 | 193 | for epoch in range(st_epoch + 1, num_epoch + 1): 194 | ## training phase 195 | netG.train() 196 | netD.train() 197 | 198 | loss_G_train = [] 199 | loss_D_real_train = [] 200 | loss_D_fake_train = [] 201 | 202 | for i, data in enumerate(loader_train, 1): 203 | def should(freq): 204 | return freq > 0 and (i % freq == 0 or i == num_batch_train) 205 | 206 | label = data.to(device) 207 | input = torch.randn(label.size(0), nch_in, ny_in, nx_in).to(device) 208 | 209 | # forward netG 210 | output = netG(input) 211 | 212 | # backward netD 213 | set_requires_grad(netD, True) 214 | optimD.zero_grad() 215 | 216 | pred_real = netD(label) 217 | pred_fake = netD(output.detach()) 218 | 219 | alpha = torch.rand(label.size(0), 1, 1, 1).to(self.device) 220 | output_ = (alpha * label + (1 - alpha) * output.detach()).requires_grad_(True) 221 | src_out_ = netD(output_) 222 | 223 | # BCE Loss 224 | # loss_D_real = fn_GAN(pred_real, torch.ones_like(pred_real)) 225 | # loss_D_fake = fn_GAN(pred_fake, torch.zeros_like(pred_fake)) 226 | 227 | # WGAN Loss 228 | loss_D_real = torch.mean(pred_real) 229 | loss_D_fake = -torch.mean(pred_fake) 230 | 231 | # Gradient penalty Loss 232 | loss_D_gp = fn_GP(src_out_, output_) 233 | 234 | loss_D = 0.5 * (loss_D_real + loss_D_fake) + loss_D_gp 235 | # loss_D = 0.5 * (loss_D_real + loss_D_fake) 236 | 237 | loss_D.backward() 238 | optimD.step() 239 | 240 | # backward netG 241 | set_requires_grad(netD, False) 242 | optimG.zero_grad() 243 | 244 | pred_fake = netD(output) 245 | 246 | # loss_G = fn_GAN(pred_fake, torch.ones_like(pred_fake)) 247 | loss_G = torch.mean(pred_fake) 248 | 249 | loss_G.backward() 250 | 
optimG.step() 251 | 252 | # get losses 253 | loss_G_train += [loss_G.item()] 254 | loss_D_real_train += [loss_D_real.item()] 255 | loss_D_fake_train += [loss_D_fake.item()] 256 | 257 | print('TRAIN: EPOCH %d: BATCH %04d/%04d: ' 258 | 'GEN GAN: %.4f DISC FAKE: %.4f DISC REAL: %.4f' % 259 | (epoch, i, num_batch_train, 260 | mean(loss_G_train), mean(loss_D_fake_train), mean(loss_D_real_train))) 261 | 262 | if should(num_freq_disp): 263 | ## show output 264 | output = transform_inv(output) 265 | label = transform_inv(label) 266 | 267 | writer_train.add_images('output', output, num_batch_train * (epoch - 1) + i, dataformats='NHWC') 268 | writer_train.add_images('label', label, num_batch_train * (epoch - 1) + i, dataformats='NHWC') 269 | 270 | writer_train.add_scalar('loss_G', mean(loss_G_train), epoch) 271 | writer_train.add_scalar('loss_D_fake', mean(loss_D_fake_train), epoch) 272 | writer_train.add_scalar('loss_D_real', mean(loss_D_real_train), epoch) 273 | # writer_train.add_scalar('distance_Wasserstein', -(mean(loss_D_fake_train) + mean(loss_D_real_train)), epoch) 274 | 275 | # update schduler 276 | # schedG.step() 277 | # schedD.step() 278 | 279 | ## save 280 | if (epoch % num_freq_save) == 0: 281 | self.save(dir_chck, netG, netD, optimG, optimD, epoch) 282 | 283 | writer_train.close() 284 | 285 | def test(self): 286 | mode = self.mode 287 | 288 | batch_size = self.batch_size 289 | device = self.device 290 | gpu_ids = self.gpu_ids 291 | 292 | ny_in = self.ny_in 293 | nx_in = self.nx_in 294 | 295 | nch_in = self.nch_in 296 | nch_out = self.nch_out 297 | nch_ker = self.nch_ker 298 | 299 | norm = self.norm 300 | 301 | name_data = self.name_data 302 | 303 | ## setup dataset 304 | dir_chck = os.path.join(self.dir_checkpoint, self.scope, name_data) 305 | 306 | dir_result = os.path.join(self.dir_result, self.scope, name_data) 307 | dir_result_save = os.path.join(dir_result, 'images') 308 | if not os.path.exists(dir_result_save): 309 | os.makedirs(dir_result_save) 310 | 
311 | transform_inv = transforms.Compose([ToNumpy(), Denomalize()]) 312 | 313 | ## setup network 314 | netG = DCGAN(nch_in, nch_out, nch_ker, norm) 315 | init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids) 316 | 317 | ## load from checkpoints 318 | st_epoch = 0 319 | 320 | netG, st_epoch = self.load(dir_chck, netG, mode=mode) 321 | 322 | ## test phase 323 | with torch.no_grad(): 324 | netG.eval() 325 | # netG.train() 326 | 327 | input = torch.randn(batch_size, nch_in, ny_in, nx_in).to(device) 328 | 329 | output = netG(input) 330 | 331 | output = transform_inv(output) 332 | 333 | for j in range(output.shape[0]): 334 | name = j 335 | fileset = {'name': name, 336 | 'output': "%04d-output.png" % name} 337 | 338 | if nch_out == 3: 339 | plt.imsave(os.path.join(dir_result_save, fileset['output']), output[j, :, :, :].squeeze()) 340 | elif nch_out == 1: 341 | plt.imsave(os.path.join(dir_result_save, fileset['output']), output[j, :, :, :].squeeze(), cmap=cm.gray) 342 | 343 | append_index(dir_result, fileset) 344 | 345 | 346 | def set_requires_grad(nets, requires_grad=False): 347 | """Set requires_grad=False for all the networks to avoid unnecessary computations 348 | Parameters: 349 | nets (network list) -- a list of networks 350 | requires_grad (bool) -- whether the networks require gradients or not 351 | """ 352 | if not isinstance(nets, list): 353 | nets = [nets] 354 | for net in nets: 355 | if net is not None: 356 | for param in net.parameters(): 357 | param.requires_grad = requires_grad 358 | 359 | 360 | def get_scheduler(optimizer, opt): 361 | """Return a learning rate scheduler 362 | 363 | Parameters: 364 | optimizer -- the optimizer of the network 365 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
366 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 367 | 368 | For 'linear', we keep the same learning rate for the first opt.n_epochs epochs 369 | and linearly decay the rate to zero over the next opt.n_epochs_decay epochs. 370 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 371 | See https://pytorch.org/docs/stable/optim.html for more details. 372 | """ 373 | if opt.lr_policy == 'linear': 374 | def lambda_rule(epoch): 375 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 376 | return lr_l 377 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 378 | elif opt.lr_policy == 'step': 379 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 380 | elif opt.lr_policy == 'plateau': 381 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 382 | elif opt.lr_policy == 'cosine': 383 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 384 | else: 385 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 386 | return scheduler 387 | 388 | 389 | def append_index(dir_result, fileset, step=False): 390 | index_path = os.path.join(dir_result, "index.html") 391 | if os.path.exists(index_path): 392 | index = open(index_path, "a") 393 | else: 394 | index = open(index_path, "w") 395 | index.write("") 396 | if step: 397 | index.write("") 398 | for key, value in fileset.items(): 399 | index.write("" % key) 400 | index.write('') 401 | 402 | # for fileset in filesets: 403 | index.write("") 404 | 405 | if step: 406 | index.write("" % fileset["step"]) 407 | index.write("" % fileset["name"]) 408 | 409 | del fileset['name'] 410 | 411 | for key, value in fileset.items(): 412 | index.write("" % value) 413 | 414 | index.write("") 415 | return index_path 416 | 417 | 418 | def add_plot(output, label, writer, epoch=[], 
ylabel='Density', xlabel='Radius', namescope=[]): 419 | fig, ax = plt.subplots() 420 | 421 | ax.plot(output.transpose(1, 0).detach().numpy(), '-') 422 | ax.plot(label.transpose(1, 0).detach().numpy(), '--') 423 | 424 | ax.set_xlim(0, 400) 425 | 426 | ax.grid(True) 427 | ax.set_ylabel(ylabel) 428 | ax.set_xlabel(xlabel) 429 | 430 | writer.add_figure(namescope, fig, epoch) 431 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import logging 5 | import torch 6 | # import argparse 7 | 8 | '''' 9 | class Logger: 10 | class Parser: 11 | ''' 12 | class Parser: 13 | def __init__(self, parser): 14 | self.__parser = parser 15 | self.__args = parser.parse_args() 16 | 17 | # set gpu ids 18 | str_ids = self.__args.gpu_ids.split(',') 19 | self.__args.gpu_ids = [] 20 | for str_id in str_ids: 21 | id = int(str_id) 22 | if id >= 0: 23 | self.__args.gpu_ids.append(id) 24 | # if len(self.__args.gpu_ids) > 0: 25 | # torch.cuda.set_device(self.__args.gpu_ids[0]) 26 | 27 | def get_parser(self): 28 | return self.__parser 29 | 30 | def get_arguments(self): 31 | return self.__args 32 | 33 | def write_args(self): 34 | params_dict = vars(self.__args) 35 | 36 | log_dir = os.path.join(params_dict['dir_log'], params_dict['scope'], params_dict['name_data']) 37 | args_name = os.path.join(log_dir, 'args.txt') 38 | 39 | if not os.path.exists(log_dir): 40 | os.makedirs(log_dir) 41 | 42 | with open(args_name, 'wt') as args_fid: 43 | args_fid.write('----' * 10 + '\n') 44 | args_fid.write('{0:^40}'.format('PARAMETER TABLES') + '\n') 45 | args_fid.write('----' * 10 + '\n') 46 | for k, v in sorted(params_dict.items()): 47 | args_fid.write('{}'.format(str(k)) + ' : ' + ('{0:>%d}' % (35 - len(str(k)))).format(str(v)) + '\n') 48 | args_fid.write('----' * 10 + '\n') 49 | 50 | def print_args(self, 
name='PARAMETER TABLES'): 51 | params_dict = vars(self.__args) 52 | 53 | print('----' * 10) 54 | print('{0:^40}'.format(name)) 55 | print('----' * 10) 56 | for k, v in sorted(params_dict.items()): 57 | if '__' not in str(k): 58 | print('{}'.format(str(k)) + ' : ' + ('{0:>%d}' % (35 - len(str(k)))).format(str(v))) 59 | print('----' * 10) 60 | 61 | 62 | class Logger: 63 | def __init__(self, info=logging.INFO, name=__name__): 64 | logger = logging.getLogger(name) 65 | logger.setLevel(info) 66 | 67 | self.__logger = logger 68 | 69 | def get_logger(self, handler_type='stream_handler'): 70 | if handler_type == 'stream_handler': 71 | handler = logging.StreamHandler() 72 | log_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 73 | handler.setFormatter(log_format) 74 | else: 75 | handler = logging.FileHandler('utils.log') 76 | 77 | self.__logger.addHandler(handler) 78 | 79 | return self.__logger 80 | --------------------------------------------------------------------------------
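The GPU-id handling in `Parser.__init__` above (split the `gpu_ids` string on commas, keep the non-negative integers) can be sketched as a standalone helper. This is only an illustrative sketch: `parse_gpu_ids` is a hypothetical name that does not exist in the repository, and skipping empty entries is an added assumption, since the original loop would raise `ValueError` on an empty token.

```python
def parse_gpu_ids(spec):
    """Turn a comma-separated string such as "0,1" into a list of non-negative ints.

    Hypothetical helper mirroring the loop in Parser.__init__; not part of utils.py.
    """
    ids = []
    for token in spec.split(','):
        token = token.strip()
        if not token:
            # Assumption: tolerate empty entries (e.g. a trailing comma).
            # The original code calls int() directly and would raise here.
            continue
        gpu_id = int(token)
        if gpu_id >= 0:
            # Negative ids such as -1 are dropped, which leaves an empty list
            # and makes Train fall back to the CPU device.
            ids.append(gpu_id)
    return ids


if __name__ == '__main__':
    print(parse_gpu_ids('0,1'))
    print(parse_gpu_ids('-1'))
```

Keeping the parsing in a small pure function like this would make it testable without constructing an `argparse` parser, though whether that refactoring suits the project is a judgment call.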