├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── pytorch-WGAN-GP.iml
│   ├── rSettings.xml
│   └── vcs.xml
├── README.md
├── __init__.py
├── dataset.py
├── display_result.py
├── img
│   ├── generated_images.png
│   └── paper1.png
├── layer.py
├── main.py
├── model.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | datasets/
3 | log/
4 | results/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/pytorch-WGAN-GP.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/rSettings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WGAN-GP
2 |
3 | ### Title
4 | [Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028)
5 |
6 | ### Abstract
7 | Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only low-quality samples or fail to converge. We find that these problems are often due to the use of weight clipping in WGAN to enforce a Lipschitz constraint on the critic, which can lead to undesired behavior. We propose an alternative to clipping weights: penalize the norm of gradient of the critic with respect to its input. Our proposed method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures with almost no hyperparameter tuning, including 101-layer ResNets and language models over discrete data. We also achieve high quality generations on CIFAR-10 and LSUN bedrooms.
8 |
9 | 
10 |
11 | ## Train
12 | $ python main.py --mode train \
13 | --scope [scope name] \
14 | --name_data [data name] \
15 | --dir_data [data directory] \
16 | --dir_log [log directory] \
17 | --dir_checkpoint [checkpoint directory] \
18 | --gpu_ids [gpu id; '-1': no gpu, '0, 1, ..., N-1': gpus]
19 | ---
20 | $ python main.py --mode train \
21 | --scope wgan-gp \
22 | --name_data celeba \
23 | --dir_data ./datasets \
24 | --dir_log ./log \
25 | --dir_checkpoint ./checkpoints \
26 | --gpu_ids 0
27 |
28 | * Set a unique **[scope name]** for each experiment.
29 | * Hyperparameters are written to **arg.txt** under the **[log directory]**.
30 | * To see how the directory hierarchy follows from these arguments, see **Directories structure** below.
31 |
32 |
33 | ## Test
34 | $ python main.py --mode test \
35 | --scope [scope name] \
36 | --name_data [data name] \
37 | --dir_data [data directory] \
38 | --dir_log [log directory] \
39 | --dir_checkpoint [checkpoint directory] \
40 | --dir_result [result directory] \
41 | --gpu_ids [gpu id; '-1': no gpu, '0, 1, ..., N-1': gpus]
42 | ---
43 | $ python main.py --mode test \
44 | --scope wgan-gp \
45 | --name_data celeba \
46 | --dir_data ./datasets \
47 | --dir_log ./log \
48 | --dir_checkpoint ./checkpoints \
49 | --dir_result ./results \
50 | --gpu_ids 0
51 |
52 | * To test with a trained network, set **[scope name]** to the value used in the **train** phase.
53 | * Generated images are saved in the **images** subfolder of **[result directory]/[scope name]/[data name]**.
54 | * An **index.html** file is also generated to display the generated images.
55 |
56 |
57 | ## Tensorboard
58 | $ tensorboard --logdir [log directory]/[scope name]/[data name] \
59 | --port [(optional) 4 digit port number]
60 | ---
61 | $ tensorboard --logdir ./log/wgan-gp/celeba \
62 | --port 6006
63 |
64 | After the above command executes, open **http://localhost:6006** in a browser.
65 |
66 | * You can change the port via **[(optional) 4 digit port number]**.
67 | * The default port number is **6006**.
68 |
69 |
70 | ## Results
71 | 
72 | * The results were generated by a network trained on the **celeba** dataset for **10 epochs**.
73 | * After the test phase runs, execute **display_result.py** to display the figure.
74 |
75 |
76 | ## Directories structure
77 | pytorch-WGAN-GP
78 | +---[dir_checkpoint]
79 | |   \---[scope]
80 | |       \---[name_data]
81 | |           +---model_epoch00000.pth
82 | |           |   ...
83 | |           \---model_epoch12345.pth
84 | +---[dir_data]
85 | |   \---[name_data]
86 | |       +---000000.png
87 | |       |   ...
88 | |       \---12345.png
89 | +---[dir_log]
90 | |   \---[scope]
91 | |       \---[name_data]
92 | |           +---arg.txt
93 | |           \---events.out.tfevents
94 | \---[dir_result]
95 |     \---[scope]
96 |         \---[name_data]
97 |             +---images
98 |             |   +---00000-output.png
99 |             |   |   ...
100 |             |   \---12345-output.png
101 |             \---index.html
102 |
103 | ---
104 |
105 | pytorch-WGAN-GP
106 | +---checkpoints
107 | |   \---wgan-gp
108 | |       \---celeba
109 | |           +---model_epoch00001.pth
110 | |           |   ...
111 | |           \---model_epoch0010.pth
112 | +---datasets
113 | |   \---celeba
114 | |       +---000001.jpg
115 | |       |   ...
116 | |       \---202599.jpg
117 | +---log
118 | |   \---wgan-gp
119 | |       \---celeba
120 | |           +---arg.txt
121 | |           \---events.out.tfevents
122 | \---results
123 |     \---wgan-gp
124 |         \---celeba
125 |             +---images
126 |             |   +---0000-output.png
127 |             |   |   ...
128 |             |   \---0127-output.png
129 |             \---index.html
130 |
131 | * The directories above are created according to the arguments passed to **main.py**.
132 |
--------------------------------------------------------------------------------
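The README's directory trees all follow one pattern: each output root (`--dir_checkpoint`, `--dir_log`, `--dir_result`) is nested first by `--scope` and then by `--name_data`, which is also why TensorBoard is pointed at `[log directory]/[scope name]/[data name]`. A minimal sketch of that composition, assuming plain `os.path.join`; `run_dirs` is a hypothetical helper, and the repo's own path handling lives in files not included here:

```python
import os


def run_dirs(dir_checkpoint, dir_log, dir_result, scope, name_data):
    """Hypothetical helper: per-run folders laid out as in the README trees."""
    return {
        # model_epochXXXX.pth files
        'checkpoint': os.path.join(dir_checkpoint, scope, name_data),
        # arg.txt and the TensorBoard event files
        'log': os.path.join(dir_log, scope, name_data),
        # generated PNGs; index.html sits one level up
        'result': os.path.join(dir_result, scope, name_data, 'images'),
    }


dirs = run_dirs('./checkpoints', './log', './results', 'wgan-gp', 'celeba')
print(dirs['log'])  # ./log/wgan-gp/celeba on POSIX systems
```

Keeping `scope` above `name_data` lets several experiments on the same dataset coexist without overwriting each other's checkpoints or logs.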
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hanyoseob/pytorch-WGAN-GP/311745b5e05828c71d8bc22d9dd10ccdae4ab000/__init__.py
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
from skimage import transform
import matplotlib.pyplot as plt


class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, data_type='float32', nch=1, transform=[]):
        self.data_dir = data_dir
        self.transform = transform
        self.nch = nch
        self.data_type = data_type

        lst_data = os.listdir(data_dir)

        self.names = lst_data

    def __getitem__(self, index):
        # Keep only the first nch channels (e.g. drop a PNG alpha channel).
        data = plt.imread(os.path.join(self.data_dir, self.names[index]))[:, :, :self.nch]

        # JPEGs load as uint8 in [0, 255]; PNGs already load as float32 in [0, 1].
        if data.dtype == np.uint8:
            data = data / 255.0

        if self.transform:
            data = self.transform(data)

        return data

    def __len__(self):
        return len(self.names)


class ToTensor(object):
    # HWC numpy array -> CHW float32 tensor
    def __call__(self, data):
        data = data.transpose((2, 0, 1)).astype(np.float32)
        return torch.from_numpy(data)


class Normalize(object):
    # [0, 1] -> [-1, 1], the range of the generator's tanh output
    def __call__(self, data):
        data = 2 * data - 1
        return data


class RandomFlip(object):
    # Horizontal flip with probability 0.5
    def __call__(self, data):
        if np.random.rand() > 0.5:
            data = np.fliplr(data)

        return data


class Rescale(object):
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, data):
        h, w = data.shape[:2]

        # An int scales the shorter side to output_size and keeps the aspect ratio;
        # a tuple gives the exact (height, width).
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        data = transform.resize(data, (new_h, new_w))
        return data


class CenterCrop(object):
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, data):
        h, w = data.shape[:2]

        new_h, new_w = self.output_size

        # Trim the excess symmetrically from both sides.
        top = int(abs(h - new_h) / 2)
        left = int(abs(w - new_w) / 2)

        data = data[top: top + new_h, left: left + new_w]

        return data


class RandomCrop(object):

    def __init__(self, output_size):

        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, data):
        h, w = data.shape[:2]

        new_h, new_w = self.output_size

        # randint's upper bound is exclusive: +1 allows every valid offset
        # and avoids an error when the image already has the crop size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        data = data[top: top + new_h, left: left + new_w]
        return data


class ToNumpy(object):
    # CHW (or NCHW) tensor -> HWC (or NHWC) numpy array
    def __call__(self, data):

        if data.ndim == 3:
            data = data.to('cpu').detach().numpy().transpose((1, 2, 0))
        elif data.ndim == 4:
            data = data.to('cpu').detach().numpy().transpose((0, 2, 3, 1))

        return data


class Denomalize(object):
    # Inverse of Normalize: [-1, 1] -> [0, 1]
    def __call__(self, data):

        return (data + 1) / 2
--------------------------------------------------------------------------------
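`Rescale(int)` scales the shorter side and keeps the aspect ratio, and `CenterCrop` then trims the longer side symmetrically; `Normalize` and `Denomalize` map pixel values between [0, 1] and [-1, 1]. The pure-Python sketch below reproduces that arithmetic. It assumes a 178 × 218 (width × height) aligned CelebA image and a `Rescale(64)` followed by `CenterCrop(64)`; the actual transform list is assembled in train.py, which is not shown, and the helper names are hypothetical.

```python
def rescale_size(h, w, output_size):
    """Same rule as Rescale(int): the shorter side becomes output_size."""
    if h > w:
        new_h, new_w = output_size * h / w, output_size
    else:
        new_h, new_w = output_size, output_size * w / h
    return int(new_h), int(new_w)


def center_crop_offsets(h, w, new_h, new_w):
    """Same rule as CenterCrop: top-left corner of a centred new_h x new_w window."""
    return int(abs(h - new_h) / 2), int(abs(w - new_w) / 2)


def normalize(v):
    return 2 * v - 1      # [0, 1] -> [-1, 1], as in Normalize


def denormalize(v):
    return (v + 1) / 2    # [-1, 1] -> [0, 1], as in Denomalize


h, w = rescale_size(218, 178, 64)              # 64 * 218 / 178 = 78.38 -> 78
top, left = center_crop_offsets(h, w, 64, 64)  # 7 rows trimmed from the top
print((h, w), (top, left))                     # (78, 64) (7, 0)
print(normalize(0.0), normalize(1.0))          # -1.0 1.0
```

Because the crop is centred, a face aligned in the middle of the source image stays in the middle of the 64 × 64 training crop.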
/display_result.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt

dir_result = './results/wgan-gp/celeba/images'
lst_result = os.listdir(dir_result)

np.random.shuffle(lst_result)

nx = 64
ny = 64
nch = 3

# Grid of n columns x m rows
n = 8
m = 4

# Draw n*m distinct indices from the first len(lst_result)//m shuffled names
# (this needs at least n*m*m = 128 images).
n_id = np.arange(len(lst_result)//m)
np.random.shuffle(n_id)
img = torch.zeros((n*m, ny, nx, nch))

for i in range(n*m):
    p = n_id[i]
    img[i, :, :, :] = torch.from_numpy(plt.imread(os.path.join(dir_result, lst_result[p]))[:, :, :nch])

# NHWC -> NCHW, the layout make_grid expects
img = img.permute((0, 3, 1, 2))

plt.figure(figsize=(n, m))
plt.axis("off")
# plt.title("Generated Images")
plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
plt.imshow(np.transpose(vutils.make_grid(img, padding=2, normalize=True), (1, 2, 0)))

plt.show()
--------------------------------------------------------------------------------
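`make_grid` is called without `nrow`, so it uses its default of 8 columns (matching `n`), giving 4 rows for the 32 images, each surrounded by a 2-pixel border. The NumPy sketch below mirrors that tiling arithmetic as we understand it rather than reproducing torchvision exactly; `tile_grid` is a hypothetical helper, and it skips the min-max rescaling that `normalize=True` applies to the whole batch.

```python
import numpy as np


def tile_grid(images, ncol, padding=2, pad_value=0.0):
    """Tile an (N, H, W, C) array into one padded mosaic with ncol columns."""
    num, h, w, c = images.shape
    nrow = -(-num // ncol)                      # ceil(num / ncol)
    cell_h, cell_w = h + padding, w + padding   # each tile plus one border
    grid = np.full((nrow * cell_h + padding, ncol * cell_w + padding, c),
                   pad_value, dtype=images.dtype)
    for k in range(num):
        r, col = divmod(k, ncol)                # fill row by row
        top, left = r * cell_h + padding, col * cell_w + padding
        grid[top:top + h, left:left + w] = images[k]
    return grid


imgs = np.random.rand(32, 64, 64, 3).astype(np.float32)
grid = tile_grid(imgs, ncol=8)
print(grid.shape)  # (266, 530, 3): 4 * 66 + 2 rows, 8 * 66 + 2 columns
```

The 8:4 tile layout is also why the script uses `figsize=(n, m)`: the figure's aspect ratio roughly matches the grid's.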
/img/generated_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hanyoseob/pytorch-WGAN-GP/311745b5e05828c71d8bc22d9dd10ccdae4ab000/img/generated_images.png
--------------------------------------------------------------------------------
/img/paper1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hanyoseob/pytorch-WGAN-GP/311745b5e05828c71d8bc22d9dd10ccdae4ab000/img/paper1.png
--------------------------------------------------------------------------------
/layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNR2d(nn.Module):
    # Conv2d -> optional norm -> optional (Leaky)ReLU -> optional dropout.
    # An empty list [] disables norm/relu/drop; for bias it means "choose automatically".
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # BatchNorm's learned shift makes a conv bias redundant.
        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []
        layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]

        if norm != []:
            layers += [Norm2d(nch_out, norm)]

        if relu != []:
            layers += [ReLU(relu)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)


class DECNR2d(nn.Module):
    # ConvTranspose2d -> optional norm -> optional (Leaky)ReLU -> optional dropout.
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []
        layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]

        if norm != []:
            layers += [Norm2d(nch_out, norm)]

        if relu != []:
            layers += [ReLU(relu)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        self.decbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.decbr(x)


class ResBlock(nn.Module):
    # Two padded convs with an identity skip: x + F(x).
    # The skip requires nch_in == nch_out and stride == 1.
    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []

        # 1st conv
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu, bias=bias)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        # 2nd conv: its input is the 1st conv's output, hence nch_out -> nch_out
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_out, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[], bias=bias)]

        self.resblk = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.resblk(x)


class CNR1d(nn.Module):
    # Linear -> optional BatchNorm1d -> optional (Leaky)ReLU -> optional dropout.
    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()

        if norm == 'bnorm':
            bias = False
        else:
            bias = True

        layers = []
        layers += [nn.Linear(nch_in, nch_out, bias=bias)]

        # Linear outputs are 2-D (N, C), so 1-D batch norm and plain dropout apply;
        # the 2-D variants expect 4-D (N, C, H, W) inputs.
        if norm == 'bnorm':
            layers += [nn.BatchNorm1d(nch_out)]

        if relu != []:
            layers += [ReLU(relu)]

        if drop != []:
            layers += [nn.Dropout(drop)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)


class Conv2d(nn.Module):
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)

    def forward(self, x):
        return self.conv(x)


class Deconv2d(nn.Module):
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super(Deconv2d, self).__init__()
        self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)

        # Alternative upsampling (resize-convolution), kept for reference:
        # layers = [nn.Upsample(scale_factor=2, mode='bilinear'),
        #           nn.ReflectionPad2d(1),
        #           nn.Conv2d(nch_in , nch_out, kernel_size=3, stride=1, padding=0)]
        #
        # self.deconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.deconv(x)


class Linear(nn.Module):
    def __init__(self, nch_in, nch_out):
        super(Linear, self).__init__()
        self.linear = nn.Linear(nch_in, nch_out)

    def forward(self, x):
        return self.linear(x)


class Norm2d(nn.Module):
    def __init__(self, nch, norm_mode):
        super(Norm2d, self).__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm2d(nch)
        elif norm_mode == 'inorm':
            self.norm = nn.InstanceNorm2d(nch)

    def forward(self, x):
        return self.norm(x)


class ReLU(nn.Module):
    # relu > 0: LeakyReLU with that negative slope; relu == 0: plain ReLU.
    def __init__(self, relu):
        super(ReLU, self).__init__()
        if relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        elif relu == 0:
            self.relu = nn.ReLU(True)

    def forward(self, x):
        return self.relu(x)


class Padding(nn.Module):
    def __init__(self, padding, padding_mode='zeros', value=0):
        super(Padding, self).__init__()
        if padding_mode == 'reflection':
            self.padding = nn.ReflectionPad2d(padding)
        elif padding_mode == 'replication':
            self.padding = nn.ReplicationPad2d(padding)
        elif padding_mode == 'constant':
            self.padding = nn.ConstantPad2d(padding, value)
        elif padding_mode == 'zeros':
            self.padding = nn.ZeroPad2d(padding)

    def forward(self, x):
        return self.padding(x)


class Pooling2d(nn.Module):
    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()

        if type == 'avg':
            self.pooling = nn.AvgPool2d(pool)
        elif type == 'max':
            self.pooling = nn.MaxPool2d(pool)
        elif type == 'conv':
            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        return self.pooling(x)


class UnPooling2d(nn.Module):
    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()

        if type == 'nearest':
            # align_corners is only valid for interpolating modes such as 'bilinear'.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        return self.unpooling(x)


class Concat(nn.Module):
    # Zero-pad x1 to x2's spatial size, then concatenate along channels.
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        diffy = x2.size()[2] - x1.size()[2]
        diffx = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2,
                        diffy // 2, diffy - diffy // 2])

        return torch.cat([x2, x1], dim=1)


class TV1dLoss(nn.Module):
    def __init__(self):
        super(TV1dLoss, self).__init__()

    def forward(self, input):
        # loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
        #        torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        loss = torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))

        return loss


class TV2dLoss(nn.Module):
    # Anisotropic total variation: mean absolute difference between
    # horizontal neighbours plus that between vertical neighbours.
    def __init__(self):
        super(TV2dLoss, self).__init__()

    def forward(self, input):
        loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
               torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        return loss


class SSIM2dLoss(nn.Module):
    # Placeholder: SSIM is not implemented, so this always returns 0.
    def __init__(self):
        super(SSIM2dLoss, self).__init__()

    def forward(self, input, target):
        loss = 0
        return loss


class GradientPaneltyLoss(nn.Module):
    def __init__(self):
        super(GradientPaneltyLoss, self).__init__()

    def forward(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        # y is the critic output for x (x must require grad); grad_outputs of ones
        # differentiates the sum of y. create_graph=True keeps the penalty differentiable.
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # Per-sample L2 norm of the flattened gradient, penalized for deviating from 1.
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)
--------------------------------------------------------------------------------
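`GradientPaneltyLoss` computes the penalty term of WGAN-GP, `mean((||grad_x D(x_hat)||_2 - 1)^2)`, without its weight lambda (10 in the paper). In the paper's Algorithm 1, `x_hat` is a random interpolation `eps * x_real + (1 - eps) * x_fake` with `eps ~ U[0, 1]` per sample; that sampling and the weighting presumably happen in train.py, which is not part of this excerpt. For a linear critic `D(x) = w . x + b` the gradient is `w` everywhere, so the penalty reduces to `(||w||_2 - 1)^2`. The NumPy sketch below checks that with finite differences; the critic and all names are illustrative, not the repo's.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative linear critic D(x) = w . x + b on flattened inputs; ||w|| = 5.
w = np.array([3.0, 4.0])
b = 0.5


def critic(x):
    return x @ w + b


def numeric_grad(f, x, h=1e-6):
    """Central finite differences of f at every row of x."""
    g = np.zeros_like(x)
    for j in range(x.shape[1]):
        step = np.zeros(x.shape[1])
        step[j] = h
        g[:, j] = (f(x + step) - f(x - step)) / (2 * h)
    return g


# Random interpolates between real and fake samples, one epsilon per sample.
x_real = rng.normal(size=(4, 2))
x_fake = rng.normal(size=(4, 2))
epsilon = rng.uniform(size=(4, 1))
x_hat = epsilon * x_real + (1 - epsilon) * x_fake

grad = numeric_grad(critic, x_hat)               # every row is close to w
grad_norm = np.sqrt(np.sum(grad ** 2, axis=1))   # per-sample L2 norm
penalty = np.mean((grad_norm - 1) ** 2)          # same reduction as GradientPaneltyLoss
print(penalty)                                   # about 16.0 = (5 - 1)^2
```

The penalty is zero only when the critic's gradient norm is 1 at the sampled points, which is how WGAN-GP softly enforces the 1-Lipschitz constraint instead of clipping weights.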
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch.backends.cudnn as cudnn

from train import *
from utils import *

# Avoid the "OMP: Error #15" abort when two OpenMP runtimes are loaded in one process.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Let cuDNN benchmark convolution algorithms and reuse the fastest (input sizes are fixed).
cudnn.benchmark = True
cudnn.fastest = True

## setup parse
parser = argparse.ArgumentParser(description='Train the WGAN-GP network',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# '-1': CPU only; '0' or '0,1,...': GPU ids
parser.add_argument('--gpu_ids', default='0', dest='gpu_ids')

parser.add_argument('--mode', default='train', choices=['train', 'test'], dest='mode')
parser.add_argument('--train_continue', default='off', choices=['on', 'off'], dest='train_continue')

parser.add_argument('--scope', default='wgan-gp', dest='scope')
parser.add_argument('--norm', type=str, default='inorm', dest='norm')

parser.add_argument('--name_data', type=str, default='celeba', dest='name_data')

parser.add_argument('--dir_checkpoint', default='./checkpoints', dest='dir_checkpoint')
parser.add_argument('--dir_log', default='./log', dest='dir_log')

parser.add_argument('--dir_data', default='../datasets', dest='dir_data')
parser.add_argument('--dir_result', default='./results', dest='dir_result')

parser.add_argument('--num_epoch', type=int, default=10, dest='num_epoch')
parser.add_argument('--batch_size', type=int, default=128, dest='batch_size')

parser.add_argument('--lr_G', type=float, default=2e-4, dest='lr_G')
parser.add_argument('--lr_D', type=float, default=2e-4, dest='lr_D')

parser.add_argument('--num_freq_disp', type=int, default=50, dest='num_freq_disp')
parser.add_argument('--num_freq_save', type=int, default=5, dest='num_freq_save')

parser.add_argument('--lr_policy', type=str, default='linear', choices=['linear', 'step', 'plateau', 'cosine'], dest='lr_policy')
parser.add_argument('--n_epochs', type=int, default=100, dest='n_epochs')
parser.add_argument('--n_epochs_decay', type=int, default=100, dest='n_epochs_decay')
parser.add_argument('--lr_decay_iters', type=int, default=50, dest='lr_decay_iters')

parser.add_argument('--wgt_gan', type=float, default=1e0, dest='wgt_gan')
parser.add_argument('--wgt_disc', type=float, default=1e0, dest='wgt_disc')

parser.add_argument('--optim', default='adam', choices=['sgd', 'adam', 'rmsprop'], dest='optim')
# type=float so a value passed on the command line is not left as a string
parser.add_argument('--beta1', type=float, default=0.5, dest='beta1')

# Generator input: an nch_in x ny_in x nx_in latent (100 x 1 x 1 by default)
parser.add_argument('--ny_in', type=int, default=1, dest='ny_in')
parser.add_argument('--nx_in', type=int, default=1, dest='nx_in')
parser.add_argument('--nch_in', type=int, default=100, dest='nch_in')

parser.add_argument('--ny_load', type=int, default=64, dest='ny_load')
parser.add_argument('--nx_load', type=int, default=64, dest='nx_load')
parser.add_argument('--nch_load', type=int, default=3, dest='nch_load')

parser.add_argument('--ny_out', type=int, default=64, dest='ny_out')
parser.add_argument('--nx_out', type=int, default=64, dest='nx_out')
parser.add_argument('--nch_out', type=int, default=3, dest='nch_out')

parser.add_argument('--nch_ker', type=int, default=64, dest='nch_ker')

parser.add_argument('--data_type', default='float32', dest='data_type')

PARSER = Parser(parser)


def main():
    ARGS = PARSER.get_arguments()
    PARSER.write_args()
    PARSER.print_args()

    TRAINER = Train(ARGS)

    if ARGS.mode == 'train':
        TRAINER.train()
    elif ARGS.mode == 'test':
        TRAINER.test()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
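With main.py's defaults (a 100 × 1 × 1 latent input, 64 × 64 × 3 output, `nch_ker=64`), the `DCGAN` generator and the `Discriminator` critic in model.py below grow and shrink the spatial size layer by layer. The sketch applies the standard PyTorch output-size formulas for `ConvTranspose2d` and `Conv2d` (dilation 1) to the `kernel_size`/`stride`/`padding` values those classes use. It assumes train.py wires up these two classes with main.py's defaults, which this excerpt does not show.

```python
def deconv_out(size, kernel, stride, padding, output_padding=0):
    # nn.ConvTranspose2d output size with dilation = 1
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def conv_out(size, kernel, stride, padding):
    # nn.Conv2d output size with dilation = 1
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for DCGAN.dec5 ... dec1
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
# (kernel, stride, padding) for Discriminator.dsc1 ... dsc5
disc_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1)]

size = 1  # ny_in = nx_in = 1
gen_sizes = []
for k, s, p in gen_layers:
    size = deconv_out(size, k, s, p)
    gen_sizes.append(size)

disc_sizes = []
for k, s, p in disc_layers:  # the critic consumes the generated 64 x 64 image
    size = conv_out(size, k, s, p)
    disc_sizes.append(size)

print(gen_sizes)   # [4, 8, 16, 32, 64]
print(disc_sizes)  # [32, 16, 8, 4, 3]
```

So the critic emits a 3 × 3 score map per image rather than a single scalar; how train.py reduces it to a loss is not shown here.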
/model.py:
--------------------------------------------------------------------------------
1 | from layer import *
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.optim import lr_scheduler
7 |
8 |
9 | class DCGAN(nn.Module):
10 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm'):
11 | super(DCGAN, self).__init__()
12 |
13 | self.nch_in = nch_in
14 | self.nch_out = nch_out
15 | self.nch_ker = nch_ker
16 | self.norm = norm
17 |
18 | if norm == 'bnorm':
19 | self.bias = False
20 | else:
21 | self.bias = True
22 |
23 | self.dec5 = DECNR2d(1 * self.nch_in, 8 * self.nch_ker, kernel_size=4, stride=1, padding=0, norm=self.norm, relu=0.0, drop=[])
24 | self.dec4 = DECNR2d(8 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0, drop=[])
25 | self.dec3 = DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0, drop=[])
26 | self.dec2 = DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0, drop=[])
27 | self.dec1 = Deconv2d(1 * self.nch_ker, 1 * self.nch_out,kernel_size=4, stride=2, padding=1, bias=False)
28 |
29 | def forward(self, x):
30 |
31 | x = self.dec5(x)
32 | x = self.dec4(x)
33 | x = self.dec3(x)
34 | x = self.dec2(x)
35 | x = self.dec1(x)
36 |
37 | x = torch.tanh(x)
38 |
39 | return x
40 |
41 |
42 | class UNet(nn.Module):
43 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm'):
44 | super(UNet, self).__init__()
45 |
46 | self.nch_in = nch_in
47 | self.nch_out = nch_out
48 | self.nch_ker = nch_ker
49 | self.norm = norm
50 |
51 | if norm == 'bnorm':
52 | self.bias = False
53 | else:
54 | self.bias = True
55 |
56 | self.enc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
57 | self.enc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
58 | self.enc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
59 | self.enc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
60 | self.enc5 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
61 | self.enc6 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
62 | self.enc7 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.2, drop=[])
63 | self.enc8 = CNR2d(8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[])
64 |
65 | self.dec8 = DECNR2d(1 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=0.5)
66 | self.dec7 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=0.5)
67 | self.dec6 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=0.5)
68 | self.dec5 = DECNR2d(2 * 8 * self.nch_ker, 8 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[])
69 | self.dec4 = DECNR2d(2 * 8 * self.nch_ker, 4 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[])
70 | self.dec3 = DECNR2d(2 * 4 * self.nch_ker, 2 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[])
71 | self.dec2 = DECNR2d(2 * 2 * self.nch_ker, 1 * self.nch_ker, stride=2, norm=self.norm, relu=0.0, drop=[])
72 | self.dec1 = DECNR2d(2 * 1 * self.nch_ker, 1 * self.nch_out, stride=2, norm=[], relu=[], drop=[], bias=False)
73 |
74 | def forward(self, x):
75 |
76 | enc1 = self.enc1(x)
77 | enc2 = self.enc2(enc1)
78 | enc3 = self.enc3(enc2)
79 | enc4 = self.enc4(enc3)
80 | enc5 = self.enc5(enc4)
81 | enc6 = self.enc6(enc5)
82 | enc7 = self.enc7(enc6)
83 | enc8 = self.enc8(enc7)
84 |
85 | dec8 = self.dec8(enc8)
86 | dec7 = self.dec7(torch.cat([enc7, dec8], dim=1))
87 | dec6 = self.dec6(torch.cat([enc6, dec7], dim=1))
88 | dec5 = self.dec5(torch.cat([enc5, dec6], dim=1))
89 | dec4 = self.dec4(torch.cat([enc4, dec5], dim=1))
90 | dec3 = self.dec3(torch.cat([enc3, dec4], dim=1))
91 | dec2 = self.dec2(torch.cat([enc2, dec3], dim=1))
92 | dec1 = self.dec1(torch.cat([enc1, dec2], dim=1))
93 |
94 | x = torch.tanh(dec1)
95 |
96 | return x
97 |
98 |
99 | class ResNet(nn.Module):
100 | def __init__(self, nch_in, nch_out, nch_ker=64, norm='bnorm', nblk=6):
101 | super(ResNet, self).__init__()
102 |
103 | self.nch_in = nch_in
104 | self.nch_out = nch_out
105 | self.nch_ker = nch_ker
106 | self.norm = norm
107 | self.nblk = nblk
108 |
109 | if norm == 'bnorm':
110 | self.bias = False
111 | else:
112 | self.bias = True
113 |
114 | self.enc1 = CNR2d(self.nch_in, 1 * self.nch_ker, kernel_size=7, stride=1, padding=3, norm=self.norm, relu=0.0)
115 |
116 | self.enc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
117 |
118 | self.enc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
119 |
120 | if self.nblk:
121 | res = []
122 |
123 | for i in range(self.nblk):
124 | res += [ResBlock(4 * self.nch_ker, 4 * self.nch_ker, kernel_size=3, stride=1, padding=1, norm=self.norm, relu=0.0, padding_mode='reflection')]
125 |
126 | self.res = nn.Sequential(*res)
127 |
128 | self.dec3 = DECNR2d(4 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
129 |
130 | self.dec2 = DECNR2d(2 * self.nch_ker, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.0)
131 |
132 | self.dec1 = CNR2d(1 * self.nch_ker, self.nch_out, kernel_size=7, stride=1, padding=3, norm=[], relu=[], bias=False)
133 |
134 | def forward(self, x):
135 | x = self.enc1(x)
136 | x = self.enc2(x)
137 | x = self.enc3(x)
138 |
139 | if self.nblk:
140 | x = self.res(x)
141 |
142 | x = self.dec3(x)
143 | x = self.dec2(x)
144 | x = self.dec1(x)
145 |
146 | x = torch.tanh(x)
147 |
148 | return x
149 |
150 |
151 | class Discriminator(nn.Module):
152 | def __init__(self, nch_in, nch_ker=64, norm='bnorm'):
153 | super(Discriminator, self).__init__()
154 |
155 | self.nch_in = nch_in
156 | self.nch_ker = nch_ker
157 | self.norm = norm
158 |
159 | if norm == 'bnorm':
160 | self.bias = False
161 | else:
162 | self.bias = True
163 |
164 | # dsc1 : 256 x 256 x 3 -> 128 x 128 x 64
165 | # dsc2 : 128 x 128 x 64 -> 64 x 64 x 128
166 | # dsc3 : 64 x 64 x 128 -> 32 x 32 x 256
167 | # dsc4 : 32 x 32 x 256 -> 32 x 32 x 512
168 | # dsc5 : 32 x 32 x 512 -> 32 x 32 x 1
169 |
170 | self.dsc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2)
171 | self.dsc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2)
172 | self.dsc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2)
173 | self.dsc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=self.norm, relu=0.2)
174 | self.dsc5 = CNR2d(8 * self.nch_ker, 1, kernel_size=4, stride=1, padding=1, norm=[], relu=[], bias=False)
175 |
176 | # self.dsc1 = CNR2d(1 * self.nch_in, 1 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2)
177 | # self.dsc2 = CNR2d(1 * self.nch_ker, 2 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2)
178 | # self.dsc3 = CNR2d(2 * self.nch_ker, 4 * self.nch_ker, kernel_size=4, stride=2, padding=1, norm=[], relu=0.2)
179 | # self.dsc4 = CNR2d(4 * self.nch_ker, 8 * self.nch_ker, kernel_size=4, stride=1, padding=1, norm=[], relu=0.2)
180 | # self.dsc5 = CNR2d(8 * self.nch_ker, 1, kernel_size=4, stride=1, padding=1, norm=[], relu=[], bias=False)
181 |
182 | def forward(self, x):
183 |
184 | x = self.dsc1(x)
185 | x = self.dsc2(x)
186 | x = self.dsc3(x)
187 | x = self.dsc4(x)
188 | x = self.dsc5(x)
189 |
190 | # x = torch.sigmoid(x)  # disabled: a WGAN critic outputs an unbounded score, not a probability
191 |
192 | return x
193 |
194 |
195 | def init_weights(net, init_type='normal', init_gain=0.02):
196 | """Initialize network weights.
197 |
198 | Parameters:
199 | net (network) -- network to be initialized
200 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
201 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
202 |
203 | We use 'normal' in the original pix2pix and CycleGAN papers, but xavier and kaiming might
204 | work better for some applications. Feel free to try them.
205 | """
206 | def init_func(m): # define the initialization function
207 | classname = m.__class__.__name__
208 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
209 | if init_type == 'normal':
210 | init.normal_(m.weight.data, 0.0, init_gain)
211 | elif init_type == 'xavier':
212 | init.xavier_normal_(m.weight.data, gain=init_gain)
213 | elif init_type == 'kaiming':
214 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
215 | elif init_type == 'orthogonal':
216 | init.orthogonal_(m.weight.data, gain=init_gain)
217 | else:
218 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
219 | if hasattr(m, 'bias') and m.bias is not None:
220 | init.constant_(m.bias.data, 0.0)
221 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
222 | init.normal_(m.weight.data, 1.0, init_gain)
223 | init.constant_(m.bias.data, 0.0)
224 |
225 | print('initialize network with %s' % init_type)
226 | net.apply(init_func) # apply the initialization function
227 |
228 |
229 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
230 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
231 | Parameters:
232 | net (network) -- the network to be initialized
233 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
234 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
235 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
236 |
237 | Return an initialized network.
238 | """
239 | if gpu_ids:
240 | assert(torch.cuda.is_available())
241 | net.to(gpu_ids[0])
242 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
243 | init_weights(net, init_type, init_gain=init_gain)
244 | return net
245 |
246 |
--------------------------------------------------------------------------------
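The spatial sizes in the `Discriminator` shape comments follow from the standard `Conv2d` output-size formula, floor((size + 2 * padding - kernel_size) / stride) + 1. The minimal pure-Python sketch below traces an input through `dsc1` to `dsc5` using their `kernel_size`/`stride`/`padding` settings. The helper names `conv2d_out` and `critic_map_size` are hypothetical and not part of this repo, and the 64 x 64 case is only an illustrative alternative input size.

```python
def conv2d_out(size, kernel_size, stride, padding):
    # Conv2d output-size formula with dilation = 1:
    # floor((size + 2 * padding - kernel_size) / stride) + 1
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) for dsc1..dsc5 as configured in Discriminator.__init__
DSC_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1)]


def critic_map_size(size, layers=DSC_LAYERS):
    # Return the spatial size before the first layer and after each layer.
    sizes = [size]
    for k, s, p in layers:
        sizes.append(conv2d_out(sizes[-1], k, s, p))
    return sizes


print(critic_map_size(256))  # [256, 128, 64, 32, 16, 15]
print(critic_map_size(64))   # [64, 32, 16, 8, 4, 3]
```

The critic therefore returns a spatial map of scores rather than one scalar per image. The WGAN terms in `train.py` reduce it with `torch.mean`, which averages over both the batch and the spatial positions.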
/train.py:
--------------------------------------------------------------------------------
1 | from model import *
2 | from dataset import *
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | from torch.optim import lr_scheduler
7 | from torchvision import transforms
8 | from torch.utils.tensorboard import SummaryWriter
9 | import matplotlib.pyplot as plt
10 | from matplotlib.pyplot import cm
11 | from statistics import mean
12 |
13 |
14 | class Train:
15 | def __init__(self, args):
16 | self.mode = args.mode
17 | self.train_continue = args.train_continue
18 |
19 | self.scope = args.scope
20 | self.dir_checkpoint = args.dir_checkpoint
21 | self.dir_log = args.dir_log
22 |
23 | self.dir_data = args.dir_data
24 | self.dir_result = args.dir_result
25 |
26 | self.num_epoch = args.num_epoch
27 | self.batch_size = args.batch_size
28 |
29 | self.lr_G = args.lr_G
30 | self.lr_D = args.lr_D
31 |
32 | self.wgt_gan = args.wgt_gan
33 | self.wgt_disc = args.wgt_disc
34 |
35 | self.optim = args.optim
36 | self.beta1 = args.beta1
37 |
38 | self.ny_in = args.ny_in
39 | self.nx_in = args.nx_in
40 | self.nch_in = args.nch_in
41 |
42 | self.ny_load = args.ny_load
43 | self.nx_load = args.nx_load
44 | self.nch_load = args.nch_load
45 |
46 | self.ny_out = args.ny_out
47 | self.nx_out = args.nx_out
48 | self.nch_out = args.nch_out
49 |
50 | self.nch_ker = args.nch_ker
51 |
52 | self.data_type = args.data_type
53 | self.norm = args.norm
54 |
55 | self.gpu_ids = args.gpu_ids
56 |
57 | self.num_freq_disp = args.num_freq_disp
58 | self.num_freq_save = args.num_freq_save
59 |
60 | self.name_data = args.name_data
61 |
62 | if self.gpu_ids and torch.cuda.is_available():
63 | self.device = torch.device("cuda:%d" % self.gpu_ids[0])
64 | torch.cuda.set_device(self.gpu_ids[0])
65 | else:
66 | self.device = torch.device("cpu")
67 |
68 | def save(self, dir_chck, netG, netD, optimG, optimD, epoch):
69 | if not os.path.exists(dir_chck):
70 | os.makedirs(dir_chck)
71 |
72 | torch.save({'netG': netG.state_dict(), 'netD': netD.state_dict(),
73 | 'optimG': optimG.state_dict(), 'optimD': optimD.state_dict()},
74 | '%s/model_epoch%04d.pth' % (dir_chck, epoch))
75 |
76 | def load(self, dir_chck, netG, netD=[], optimG=[], optimD=[], epoch=[], mode='train'):
77 | if not epoch:
78 | ckpt = os.listdir(dir_chck)
79 | ckpt.sort()
80 | epoch = int(ckpt[-1].split('epoch')[1].split('.pth')[0])
81 |
82 | dict_net = torch.load('%s/model_epoch%04d.pth' % (dir_chck, epoch), map_location=self.device)
83 |
84 | print('Loaded network from epoch %d' % epoch)
85 |
86 | if mode == 'train':
87 | netG.load_state_dict(dict_net['netG'])
88 | netD.load_state_dict(dict_net['netD'])
89 | optimG.load_state_dict(dict_net['optimG'])
90 | optimD.load_state_dict(dict_net['optimD'])
91 |
92 | return netG, netD, optimG, optimD, epoch
93 |
94 | elif mode == 'test':
95 | netG.load_state_dict(dict_net['netG'])
96 |
97 | return netG, epoch
98 |
99 | def preprocess(self, data):
100 | rescale = Rescale((self.ny_load, self.nx_load))
101 | randomcrop = RandomCrop((self.ny_out, self.nx_out))
102 | normalize = Normalize()
103 | randomflip = RandomFlip()
104 | totensor = ToTensor()
105 | # return totensor(randomcrop(rescale(randomflip(normalize(data)))))
106 | return totensor(normalize(rescale(data)))
107 |
108 | def deprocess(self, data):
109 | tonumpy = ToNumpy()
110 | denomalize = Denomalize()
111 | return denomalize(tonumpy(data))
112 |
113 |
114 | def train(self):
115 | mode = self.mode
116 |
117 | train_continue = self.train_continue
118 | num_epoch = self.num_epoch
119 |
120 | lr_G = self.lr_G
121 | lr_D = self.lr_D
122 |
123 | wgt_gan = self.wgt_gan
124 | wgt_disc = self.wgt_disc
125 |
126 | batch_size = self.batch_size
127 | device = self.device
128 |
129 | gpu_ids = self.gpu_ids
130 |
131 | nch_in = self.nch_in
132 | nch_out = self.nch_out
133 | nch_ker = self.nch_ker
134 |
135 | norm = self.norm
136 | name_data = self.name_data
137 |
138 | num_freq_disp = self.num_freq_disp
139 | num_freq_save = self.num_freq_save
140 |
141 | ny_in = self.ny_in
142 | nx_in = self.nx_in
143 |
144 | ## setup dataset
145 | dir_chck = os.path.join(self.dir_checkpoint, self.scope, name_data)
146 |
147 | dir_data_train = os.path.join(self.dir_data, name_data)
148 | dir_log = os.path.join(self.dir_log, self.scope, name_data)
149 |
150 | transform_train = transforms.Compose([Normalize(), Rescale((self.ny_load, self.nx_load)), ToTensor()])
151 | transform_inv = transforms.Compose([ToNumpy(), Denomalize()])
152 |
153 | dataset_train = Dataset(dir_data_train, data_type=self.data_type, nch=self.nch_out, transform=transform_train)
154 |
155 | loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
156 |
157 | num_train = len(dataset_train)
158 |
159 | num_batch_train = int((num_train / batch_size) + ((num_train % batch_size) != 0))
160 |
161 | ## setup network
162 | netG = DCGAN(nch_in, nch_out, nch_ker, norm)
163 | netD = Discriminator(nch_out, nch_ker, [])
164 |
165 | init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids)
166 | init_net(netD, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids)
167 |
168 | ## setup loss & optimization
169 | fn_GAN = nn.BCEWithLogitsLoss().to(device)
170 | fn_GP = GradientPaneltyLoss().to(device)
171 |
172 | paramsG = netG.parameters()
173 | paramsD = netD.parameters()
174 |
175 | optimG = torch.optim.Adam(paramsG, lr=lr_G, betas=(self.beta1, 0.999))
176 | optimD = torch.optim.Adam(paramsD, lr=lr_D, betas=(self.beta1, 0.999))
177 |
178 | # schedG = get_scheduler(optimG, self.opts)
179 | # schedD = get_scheduler(optimD, self.opts)
180 |
181 | # schedG = torch.optim.lr_scheduler.ExponentialLR(optimG, gamma=0.9)
182 | # schedD = torch.optim.lr_scheduler.ExponentialLR(optimD, gamma=0.9)
183 |
184 | ## load from checkpoints
185 | st_epoch = 0
186 |
187 | if train_continue == 'on':
188 | netG, netD, optimG, optimD, st_epoch = self.load(dir_chck, netG, netD, optimG, optimD, mode=mode)
189 |
190 | ## setup tensorboard
191 | writer_train = SummaryWriter(log_dir=dir_log)
192 |
193 | for epoch in range(st_epoch + 1, num_epoch + 1):
194 | ## training phase
195 | netG.train()
196 | netD.train()
197 |
198 | loss_G_train = []
199 | loss_D_real_train = []
200 | loss_D_fake_train = []
201 |
202 | for i, data in enumerate(loader_train, 1):
203 | def should(freq):
204 | return freq > 0 and (i % freq == 0 or i == num_batch_train)
205 |
206 | label = data.to(device)
207 | input = torch.randn(label.size(0), nch_in, ny_in, nx_in).to(device)
208 |
209 | # forward netG
210 | output = netG(input)
211 |
212 | # backward netD
213 | set_requires_grad(netD, True)
214 | optimD.zero_grad()
215 |
216 | pred_real = netD(label)
217 | pred_fake = netD(output.detach())
218 |
219 | alpha = torch.rand(label.size(0), 1, 1, 1).to(self.device)
220 | output_ = (alpha * label + (1 - alpha) * output.detach()).requires_grad_(True)
221 | src_out_ = netD(output_)
222 |
223 | # BCE Loss
224 | # loss_D_real = fn_GAN(pred_real, torch.ones_like(pred_real))
225 | # loss_D_fake = fn_GAN(pred_fake, torch.zeros_like(pred_fake))
226 |
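227 | # WGAN loss (sign-flipped critic: real samples are pushed to low scores, fakes to high; netG minimizes mean(pred_fake))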
228 | loss_D_real = torch.mean(pred_real)
229 | loss_D_fake = -torch.mean(pred_fake)
230 |
231 | # Gradient penalty Loss
232 | loss_D_gp = fn_GP(src_out_, output_)
233 |
234 | loss_D = 0.5 * (loss_D_real + loss_D_fake) + loss_D_gp
235 | # loss_D = 0.5 * (loss_D_real + loss_D_fake)
236 |
237 | loss_D.backward()
238 | optimD.step()
239 |
240 | # backward netG
241 | set_requires_grad(netD, False)
242 | optimG.zero_grad()
243 |
244 | pred_fake = netD(output)
245 |
246 | # loss_G = fn_GAN(pred_fake, torch.ones_like(pred_fake))
247 | loss_G = torch.mean(pred_fake)
248 |
249 | loss_G.backward()
250 | optimG.step()
251 |
252 | # get losses
253 | loss_G_train += [loss_G.item()]
254 | loss_D_real_train += [loss_D_real.item()]
255 | loss_D_fake_train += [loss_D_fake.item()]
256 |
257 | print('TRAIN: EPOCH %d: BATCH %04d/%04d: '
258 | 'GEN GAN: %.4f DISC FAKE: %.4f DISC REAL: %.4f' %
259 | (epoch, i, num_batch_train,
260 | mean(loss_G_train), mean(loss_D_fake_train), mean(loss_D_real_train)))
261 |
262 | if should(num_freq_disp):
263 | ## show output
264 | output = transform_inv(output)
265 | label = transform_inv(label)
266 |
267 | writer_train.add_images('output', output, num_batch_train * (epoch - 1) + i, dataformats='NHWC')
268 | writer_train.add_images('label', label, num_batch_train * (epoch - 1) + i, dataformats='NHWC')
269 |
270 | writer_train.add_scalar('loss_G', mean(loss_G_train), epoch)
271 | writer_train.add_scalar('loss_D_fake', mean(loss_D_fake_train), epoch)
272 | writer_train.add_scalar('loss_D_real', mean(loss_D_real_train), epoch)
273 | # writer_train.add_scalar('distance_Wasserstein', -(mean(loss_D_fake_train) + mean(loss_D_real_train)), epoch)
274 |
275 | # update scheduler
276 | # schedG.step()
277 | # schedD.step()
278 |
279 | ## save
280 | if (epoch % num_freq_save) == 0:
281 | self.save(dir_chck, netG, netD, optimG, optimD, epoch)
282 |
283 | writer_train.close()
284 |
285 | def test(self):
286 | mode = self.mode
287 |
288 | batch_size = self.batch_size
289 | device = self.device
290 | gpu_ids = self.gpu_ids
291 |
292 | ny_in = self.ny_in
293 | nx_in = self.nx_in
294 |
295 | nch_in = self.nch_in
296 | nch_out = self.nch_out
297 | nch_ker = self.nch_ker
298 |
299 | norm = self.norm
300 |
301 | name_data = self.name_data
302 |
303 | ## setup dataset
304 | dir_chck = os.path.join(self.dir_checkpoint, self.scope, name_data)
305 |
306 | dir_result = os.path.join(self.dir_result, self.scope, name_data)
307 | dir_result_save = os.path.join(dir_result, 'images')
308 | if not os.path.exists(dir_result_save):
309 | os.makedirs(dir_result_save)
310 |
311 | transform_inv = transforms.Compose([ToNumpy(), Denomalize()])
312 |
313 | ## setup network
314 | netG = DCGAN(nch_in, nch_out, nch_ker, norm)
315 | init_net(netG, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids)
316 |
317 | ## load from checkpoints
318 | st_epoch = 0
319 |
320 | netG, st_epoch = self.load(dir_chck, netG, mode=mode)
321 |
322 | ## test phase
323 | with torch.no_grad():
324 | netG.eval()
325 | # netG.train()
326 |
327 | input = torch.randn(batch_size, nch_in, ny_in, nx_in).to(device)
328 |
329 | output = netG(input)
330 |
331 | output = transform_inv(output)
332 |
333 | for j in range(output.shape[0]):
334 | name = j
335 | fileset = {'name': name,
336 | 'output': "%04d-output.png" % name}
337 |
338 | if nch_out == 3:
339 | plt.imsave(os.path.join(dir_result_save, fileset['output']), output[j, :, :, :].squeeze())
340 | elif nch_out == 1:
341 | plt.imsave(os.path.join(dir_result_save, fileset['output']), output[j, :, :, :].squeeze(), cmap=cm.gray)
342 |
343 | append_index(dir_result, fileset)
344 |
345 |
346 | def set_requires_grad(nets, requires_grad=False):
347 | """Set requires_grad=False for all the networks to avoid unnecessary computations
348 | Parameters:
349 | nets (network list) -- a list of networks
350 | requires_grad (bool) -- whether the networks require gradients or not
351 | """
352 | if not isinstance(nets, list):
353 | nets = [nets]
354 | for net in nets:
355 | if net is not None:
356 | for param in net.parameters():
357 | param.requires_grad = requires_grad
358 |
359 |
360 | def get_scheduler(optimizer, opt):
361 | """Return a learning rate scheduler
362 |
363 | Parameters:
364 | optimizer -- the optimizer of the network
365 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
366 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
367 |
368 | For 'linear', we keep the same learning rate for the first opt.n_epochs epochs
369 | and linearly decay the rate to zero over the next opt.n_epochs_decay epochs.
370 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
371 | See https://pytorch.org/docs/stable/optim.html for more details.
372 | """
373 | if opt.lr_policy == 'linear':
374 | def lambda_rule(epoch):
375 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
376 | return lr_l
377 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
378 | elif opt.lr_policy == 'step':
379 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
380 | elif opt.lr_policy == 'plateau':
381 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
382 | elif opt.lr_policy == 'cosine':
383 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
384 | else:
385 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
386 | return scheduler
387 |
388 |
389 | def append_index(dir_result, fileset, step=False):
390 | index_path = os.path.join(dir_result, "index.html")
391 | if os.path.exists(index_path):
392 | index = open(index_path, "a")
393 | else:
394 | index = open(index_path, "w")
395 | index.write("<html><body><table><tr>")
396 | if step:
397 | index.write("<th>step</th>")
398 | for key, value in fileset.items():
399 | index.write("<th>%s</th>" % key)
400 | index.write("</tr>\n")
401 |
402 | # for fileset in filesets:
403 | index.write("<tr>")
404 |
405 | if step:
406 | index.write("<td>%d</td>" % fileset["step"])
407 | index.write("<td>%s</td>" % fileset["name"])
408 |
409 | del fileset['name']
410 |
411 | for key, value in fileset.items():
412 | index.write("<td><img src='images/%s'></td>" % value)
413 |
414 | index.write("</tr>\n")
415 | index.close()
416 | return index_path
417 |
418 | def add_plot(output, label, writer, epoch=[], ylabel='Density', xlabel='Radius', namescope=[]):
419 | fig, ax = plt.subplots()
420 |
421 | ax.plot(output.transpose(1, 0).detach().numpy(), '-')
422 | ax.plot(label.transpose(1, 0).detach().numpy(), '--')
423 |
424 | ax.set_xlim(0, 400)
425 |
426 | ax.grid(True)
427 | ax.set_ylabel(ylabel)
428 | ax.set_xlabel(xlabel)
429 |
430 | writer.add_figure(namescope, fig, epoch)
431 |
--------------------------------------------------------------------------------
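The WGAN-GP terms in `Train.train()` use a sign convention that is the reverse of the WGAN-GP paper. There, the critic minimizes E[D(fake)] - E[D(real)] + lambda * GP. Here, the critic is pushed to score real samples low and fakes high, and `netG` minimizes `mean(pred_fake)`. The pure-Python sketch below reproduces that arithmetic on plain lists of critic scores. All function names are hypothetical. Because the repo's `GradientPaneltyLoss` is not shown here, the penalty is illustrated with an assumed linear critic D(x) = w . x, whose input gradient is w everywhere. The weight lambda = 10 is the WGAN-GP paper's default and is an assumption, not necessarily what this repo uses.

```python
from statistics import mean


def critic_loss(pred_real, pred_fake, gp):
    # Same arithmetic as loss_D in Train.train(): real scores are pushed LOW,
    # fake scores HIGH (reverse of the paper's convention), plus the penalty.
    loss_D_real = mean(pred_real)
    loss_D_fake = -mean(pred_fake)
    return 0.5 * (loss_D_real + loss_D_fake) + gp


def generator_loss(pred_fake):
    # netG minimizes the critic score of its samples, moving fakes toward the
    # "real" (low-score) side under the same convention.
    return mean(pred_fake)


def wasserstein_estimate(pred_real, pred_fake):
    # Under this convention E[D(fake)] - E[D(real)] estimates the distance; it
    # equals -(loss_D_real + loss_D_fake), as in the commented-out scalar.
    return mean(pred_fake) - mean(pred_real)


def gradient_penalty_linear(w, lam=10.0):
    # Hypothetical linear critic D(x) = w . x: the gradient at every interpolated
    # point is w, so lam * (||grad||_2 - 1)^2 is a constant.
    grad_norm = sum(wi * wi for wi in w) ** 0.5
    return lam * (grad_norm - 1.0) ** 2


real, fake = [1.0, 3.0], [0.0, 2.0]
print(critic_loss(real, fake, gp=0.0))      # 0.5
print(generator_loss(fake))                 # 1.0
print(wasserstein_estimate(real, fake))     # -1.0
print(gradient_penalty_linear([3.0, 4.0]))  # 160.0
```

Flipping both the critic and generator signs together, as `train.py` does, leaves the adversarial game unchanged. Only the meaning of "high" and "low" critic scores is swapped.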
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 | import logging
5 | import torch
6 | # import argparse
7 |
8 | '''
9 | class Logger:
10 | class Parser:
11 | '''
12 | class Parser:
13 | def __init__(self, parser):
14 | self.__parser = parser
15 | self.__args = parser.parse_args()
16 |
17 | # set gpu ids
18 | str_ids = self.__args.gpu_ids.split(',')
19 | self.__args.gpu_ids = []
20 | for str_id in str_ids:
21 | gpu_id = int(str_id)
22 | if gpu_id >= 0:
23 | self.__args.gpu_ids.append(gpu_id)
24 | # if len(self.__args.gpu_ids) > 0:
25 | # torch.cuda.set_device(self.__args.gpu_ids[0])
26 |
27 | def get_parser(self):
28 | return self.__parser
29 |
30 | def get_arguments(self):
31 | return self.__args
32 |
33 | def write_args(self):
34 | params_dict = vars(self.__args)
35 |
36 | log_dir = os.path.join(params_dict['dir_log'], params_dict['scope'], params_dict['name_data'])
37 | args_name = os.path.join(log_dir, 'args.txt')
38 |
39 | if not os.path.exists(log_dir):
40 | os.makedirs(log_dir)
41 |
42 | with open(args_name, 'wt') as args_fid:
43 | args_fid.write('----' * 10 + '\n')
44 | args_fid.write('{0:^40}'.format('PARAMETER TABLES') + '\n')
45 | args_fid.write('----' * 10 + '\n')
46 | for k, v in sorted(params_dict.items()):
47 | args_fid.write('{}'.format(str(k)) + ' : ' + ('{0:>%d}' % (35 - len(str(k)))).format(str(v)) + '\n')
48 | args_fid.write('----' * 10 + '\n')
49 |
50 | def print_args(self, name='PARAMETER TABLES'):
51 | params_dict = vars(self.__args)
52 |
53 | print('----' * 10)
54 | print('{0:^40}'.format(name))
55 | print('----' * 10)
56 | for k, v in sorted(params_dict.items()):
57 | if '__' not in str(k):
58 | print('{}'.format(str(k)) + ' : ' + ('{0:>%d}' % (35 - len(str(k)))).format(str(v)))
59 | print('----' * 10)
60 |
61 |
62 | class Logger:
63 | def __init__(self, info=logging.INFO, name=__name__):
64 | logger = logging.getLogger(name)
65 | logger.setLevel(info)
66 |
67 | self.__logger = logger
68 |
69 | def get_logger(self, handler_type='stream_handler'):
70 | if handler_type == 'stream_handler':
71 | handler = logging.StreamHandler()
72 | log_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
73 | handler.setFormatter(log_format)
74 | else:
75 | handler = logging.FileHandler('utils.log')
76 |
77 | self.__logger.addHandler(handler)
78 |
79 | return self.__logger
80 |
--------------------------------------------------------------------------------
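`Parser.__init__` turns the comma-separated `gpu_ids` string into a list of non-negative integers, so a negative id such as `-1` selects no GPU. The sketch below isolates that rule. `parse_gpu_ids` is a hypothetical helper, and the `argparse` setup is an assumed stand-in for the parser built in `main.py` (not shown here), modeling only a `--gpu_ids` option. `Parser` itself calls `parse_args()` on `sys.argv`; the sketch passes explicit argument lists so it runs standalone.

```python
import argparse


def parse_gpu_ids(spec):
    # Same rule as Parser.__init__: split on commas and keep only
    # non-negative ids, so "-1" yields an empty list (no GPU).
    gpu_ids = []
    for str_id in spec.split(','):
        gpu_id = int(str_id)
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)
    return gpu_ids


# Hypothetical stand-in for the real parser in main.py.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_ids', default='0', type=str)

args = parser.parse_args(['--gpu_ids=0,1'])
print(parse_gpu_ids(args.gpu_ids))  # [0, 1]

args = parser.parse_args(['--gpu_ids=-1'])
print(parse_gpu_ids(args.gpu_ids))  # []
```

An empty list makes `Train.__init__` fall back to `torch.device("cpu")`, because its device check requires `self.gpu_ids` to be non-empty.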