├── LICENSE
├── README.md
├── data_utils.py
├── figs
│   ├── fig1.png
│   └── fig2.png
├── model_sr.py
├── networks.py
├── pytorch_ssim
│   └── __init__.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Zhuang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Exploring Simple and Transferable Recognition-Aware Image Processing
2 |
3 | This repo contains the code and instructions to reproduce the results in
4 |
5 | [Exploring Simple and Transferable Recognition-Aware Image Processing](https://arxiv.org/abs/1910.09185). IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI).
6 |
7 | Zhuang Liu, Hung-Ju Wang, Tinghui Zhou, Zhiqiang Shen, Bingyi Kang, Evan Shelhamer, Trevor Darrell.
8 |
9 |
10 | ![fig1](figs/fig1.png)
11 |
12 | Fig. 1: Image processing aims to produce images that look visually pleasing to humans, but not necessarily ones that are accurately recognized by machines. In this work we try to enhance the recognition accuracy of output images.
13 |
14 | ### Abstract
15 | Recent progress in image recognition has stimulated the deployment of vision systems at an unprecedented scale. As a result, visual data are now often consumed not only by humans but also by machines. Existing image processing methods only optimize for better human perception, yet the resulting images may not be accurately recognized by machines. This can be undesirable, e.g., the images can be improperly handled by search engines or recommendation systems. In this work, we propose simple approaches to improve machine interpretability of processed images: optimizing the recognition loss directly on the image processing network or through an intermediate transforming model. Interestingly, the processing model's ability to enhance recognition quality can transfer when evaluated on models of different architectures, recognized categories, tasks and training datasets. This makes the solutions applicable even when we do not have the knowledge of future recognition models, e.g., if we upload processed images to the Internet. We conduct experiments on multiple image processing tasks, with ImageNet classification and PASCAL VOC detection as recognition tasks. With our simple methods, substantial accuracy gain can be achieved with strong transferability and minimal image quality loss. Through a user study we further show that the accuracy gain can transfer to a black-box, third-party cloud model. Finally, we try to explain this transferability phenomenon by demonstrating the similarities of different models' decision boundaries.
16 |
17 |
18 | ![fig2](figs/fig2.png)
19 | Fig. 2: Left: RA (Recognition-Aware) processing. In addition to the image processing loss, we add a recognition loss using a fixed recognition model R for the processing model P to optimize. Right: RA with transformer. “Recognition Loss” stands for the dashed box in the left figure. A Transformer T is introduced between the output of P and the input of R to optimize the recognition loss. We cut the gradient flowing from the recognition loss to P, so that P only optimizes the image processing loss and the image quality is not affected.
20 |
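In code, the RA objective amounts to adding a weighted recognition loss on top of the usual reconstruction loss. Below is a minimal PyTorch sketch, illustrative only: `model_P`, `model_R`, `model_T`, `normalize`, `lam`, `mse` and `ce` are placeholder names, and the actual training loop lives in `train.py`.

```python
out = model_P(lr_input)                        # processed image, in [0, 1]
loss_P = mse(out, hr_target)                   # image processing loss
if mode == 'ra':                               # RA: P also optimizes recognition
    loss_P = loss_P + lam * ce(model_R(normalize(out)), label)
elif mode == 'ra_transform':                   # RA with transformer T
    t_out = model_T(out.detach())              # detach: no recognition gradient into P
    loss_T = ce(model_R(normalize(t_out)), label)
```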
21 |
22 | ### Dependencies
23 | PyTorch 1.5.0 and the corresponding version of torchvision (0.6.0). The code should also run with other recent versions of PyTorch (0.4.0+).
24 |
25 | Please install PyTorch following the official instructions at [PyTorch](https://pytorch.org/).
26 |
27 | ### Data Preparation
28 | Download and uncompress the ImageNet classification dataset from http://image-net.org/download to `PATH_TO_IMAGENET`, which should contain subfolders `train/` and `val/`.
29 |
30 |
31 | ### Training
32 | The examples given are for the super-resolution task; change `--task` to `dn` or `jpeg` for denoising or JPEG-deblocking.
33 | The model P is an SRResNet and the model R is a ResNet-18; see the options in `train.py`.
34 | Models, logs and some visualizations will be saved in the output folder (`--save-dir`).
35 |
36 | Plain Processing
37 |
38 | CUDA_VISIBLE_DEVICES=0 python train.py --l 0 --save-dir checkpoints_sr/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra --data PATH_TO_IMAGENET
39 |
40 | RA Processing
41 |
42 | CUDA_VISIBLE_DEVICES=0 python train.py --l 0.001 --save-dir checkpoints_sr/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra --data PATH_TO_IMAGENET
43 |
44 | RA with Transformer
45 |
46 | CUDA_VISIBLE_DEVICES=0 python train.py --l 0.01 --save-dir checkpoints_sr_T/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra_transform --data PATH_TO_IMAGENET
47 |
48 | Unsupervised RA
49 |
50 | CUDA_VISIBLE_DEVICES=0 python train.py --l 10 --save-dir checkpoints_sr_U/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra_unsupervised --data PATH_TO_IMAGENET
51 |
52 |
53 |
54 | ### Evaluation
55 | After training, we can test the resulting image processing models on multiple R architectures, i.e., evaluate transferability across architectures.
56 |
57 | Plain Processing, RA Processing or Unsupervised RA
58 |
59 | CUDA_VISIBLE_DEVICES=0 python train.py --cross-evaluate --model-sr PATH_TO_MODEL --task sr --mode ra --data PATH_TO_IMAGENET
60 |
61 | RA with Transformer
62 |
63 | CUDA_VISIBLE_DEVICES=0 python train.py --cross-evaluate --model-sr PATH_TO_SR_MODEL --model-transformer PATH_TO_TRANSFORMER_MODEL --task sr --mode ra_transform --data PATH_TO_IMAGENET
64 |
65 | After evaluation finishes, results will be saved in the same folder as `PATH_TO_MODEL`.
66 |
67 |
68 | ### Pretrained Models
69 | We provide pretrained models for Plain Processing, RA Processing and Unsupervised RA at the links below, for all three tasks.
70 | The recognition model R used in the recognition loss is ResNet-18.
71 | | Task | Models |
72 | | ------------- | ----------- |
73 | | Super-resolution | [Google Drive](https://drive.google.com/drive/folders/1U6AGvTyl7BewnwPDxzxSyd6cfxWJ1Tkd?usp=sharing) |
74 | | Denoising | [Google Drive](https://drive.google.com/drive/folders/1LyGyMtpqDI2ExVCzL_4inC6X_lndnEvl?usp=sharing) |
75 | | JPEG-deblocking | [Google Drive](https://drive.google.com/drive/folders/1E4TDXwFUtJbRx8fNgVkUOhCzL4011CX2?usp=sharing) |
76 |
77 | These pretrained models can be evaluated with the commands above.
78 |
79 | ### Results
80 |
81 | The provided pretrained models should produce the results shown in the following tables (ImageNet accuracy in %, matching the corresponding results in the paper).
82 |
83 | Note that the R model used to train all P models here is ResNet-18, so the tables below differ from Table 1 in the paper, but they cover the results of Tables 1, 2 and 10.
84 |
85 |
86 | #### Super-resolution
87 |
88 | P Model/Evaluation on R | ResNet-18 | ResNet-50 | ResNet-101 | DenseNet-121 | VGG-16
89 | -------|:-------:|:--------:|:--------:|:--------:|:--------:|
90 | Plain Processing |52.6 | 58.8 | 61.9| 57.7 | 50.2
91 | RA Processing |61.8 |66.7 | 68.8| 64.7| 58.2
92 | Unsupervised RA |61.3 | 66.3 | 68.6| 64.5 | 57.3
93 |
94 | #### Denoising
95 | P Model/Evaluation on R | ResNet-18 | ResNet-50 | ResNet-101 | DenseNet-121 | VGG-16
96 | -------|:-------:|:--------:|:--------:|:--------:|:--------:|
97 | Plain Processing |61.9 | 68.0 | 69.1 | 66.4 | 60.9
98 | RA Processing |65.1 |70.6 | 71.9 | 69.1 | 63.8
99 | Unsupervised RA |61.7 |67.9 | 69.7 | 66.4 | 60.5
100 |
101 | #### JPEG-deblocking
102 | P Model/Evaluation on R | ResNet-18 | ResNet-50 | ResNet-101 | DenseNet-121 | VGG-16
103 | -------|:-------:|:--------:|:--------:|:--------:|:--------:|
104 | Plain Processing |48.2 | 53.8| 56.0| 52.9 | 42.4
105 | RA Processing |57.7 |62.3 |64.3 | 60.7 | 52.8
106 | Unsupervised RA |53.8 |59.1| 62.0| 57.5 | 50.0
107 |
108 | Models trained with this code should also produce similar results.
109 |
110 | ### Contact
111 | You are welcome to open an issue or contact liuzhuangthu@gmail.com.
112 |
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import os, time, shutil, argparse
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | from torchvision import datasets, transforms
10 | import torchvision.models as models
11 | import pdb
12 | from PIL import Image
13 | import threading
14 |
15 | class SRImageFolder(datasets.ImageFolder):
16 |
17 | def __init__(self, traindir, train_transform):
18 | super(SRImageFolder, self).__init__(traindir, train_transform)
19 | self.upscale = 4
20 |
21 | def __getitem__(self, index):
22 |
23 | path, target = self.imgs[index]
24 | img = self.loader(path)
25 |
26 | if self.transform is not None:
27 | img_output_PIL = self.transform(img)
28 |
29 | lr_size = img_output_PIL.size[0] // self.upscale
30 | img_input_PIL = transforms.Resize((lr_size, lr_size), Image.BICUBIC)(img_output_PIL)
31 |
32 | img_output = transforms.ToTensor()(img_output_PIL)
33 | img_input = transforms.ToTensor()(img_input_PIL)
34 |
35 | if self.target_transform is not None:
36 | target = self.target_transform(target)
37 |
38 | return img_input, img_output, target
39 |
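# Usage sketch (illustrative path): each item is (lr_input, hr_target, class_label);
# the low-res input is the transformed crop bicubic-downscaled by self.upscale (4x).
#   ds = SRImageFolder('/PATH_TO_IMAGENET/train', transforms.Compose(
#       [transforms.Resize(256), transforms.RandomCrop(224)]))
#   lr, hr, y = ds[0]   # lr: 3x56x56 tensor, hr: 3x224x224 tensor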
40 | class DNImageFolder(datasets.ImageFolder):
41 |
42 | def __init__(self, traindir, train_transform, deterministic=False):
43 | # self.lr_size = 54
44 | super(DNImageFolder, self).__init__(traindir, train_transform)
45 | self.std = 0.1
46 | self.deterministic = deterministic
47 | print("constructing DN Image folder")
48 | # pass
49 |
50 | def __getitem__(self, index):
51 |
52 | path, target = self.imgs[index]
53 | img = self.loader(path)
54 |
55 | # print(len(self.imgs))
56 |
57 | if self.transform is not None:
58 | img_output_PIL = self.transform(img)
59 | img_output = transforms.ToTensor()(img_output_PIL)
60 |
61 | if self.deterministic:
62 | torch.manual_seed(index)
63 | noise = torch.randn(img_output.size()) * self.std
64 | img_input = torch.clamp(img_output + noise, 0, 1)
65 |
66 |
67 | if self.target_transform is not None:
68 | target = self.target_transform(target)
69 |
70 | return img_input, img_output, target
71 |
72 | class JPEGImageFolder(datasets.ImageFolder):
73 |
74 | def __init__(self, traindir, train_transform, tmp_dir):
75 | super(JPEGImageFolder, self).__init__(traindir, train_transform)
76 |
77 | self.quality = 10
78 | self.tmp_dir = tmp_dir
79 | os.makedirs(tmp_dir, exist_ok=True)
80 |
81 |
82 | def __getitem__(self, index):
83 |
84 | path, target = self.imgs[index]
85 | img = self.loader(path)
86 |
87 |
88 | if self.transform is not None:
89 | img_output_PIL = self.transform(img)
90 |
91 | img_output_PIL.save(self.tmp_dir + '{}.jpeg'.format(index), quality=self.quality)
92 | img_input_PIL = Image.open(self.tmp_dir + '{}.jpeg'.format(index))
93 | os.remove(self.tmp_dir + "{}.jpeg".format(index))
94 |
95 | img_output = transforms.ToTensor()(img_output_PIL)
96 | img_input = transforms.ToTensor()(img_input_PIL)
97 |
98 | if self.target_transform is not None:
99 | target = self.target_transform(target)
100 |
101 | return img_input, img_output, target
102 |
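# Note: an equivalent in-memory JPEG round-trip (a sketch, avoiding the temp files
# used above) would be:
#   import io
#   buf = io.BytesIO()
#   img_output_PIL.save(buf, format='JPEG', quality=self.quality)
#   img_input_PIL = Image.open(buf)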
103 | class SelfImageFolder(datasets.ImageFolder):
104 |
105 | def __init__(self, traindir, train_transform):
106 | super(SelfImageFolder, self).__init__(traindir, train_transform)
107 |
108 | def __getitem__(self, index):
109 |
110 | path, target = self.imgs[index]
111 | img = self.loader(path)
112 |
113 | if self.transform is not None:
114 | img_output_PIL = self.transform(img)
115 |
116 | img_output = transforms.ToTensor()(img_output_PIL)
117 | img_input = img_output + 0.
118 |
119 | if self.target_transform is not None:
120 | target = self.target_transform(target)
121 |
122 | return img_input, img_output, target
123 | # return Variable(img_input).cuda(), Variable(img_output), Variable(target)
124 |
125 | if __name__ == '__main__':
126 |     traindir = '/scratch/zhuangl/datasets/imagenet/train'
127 |     train_transform = transforms.Compose([
128 |         transforms.Resize(256),
129 |         transforms.RandomCrop(224),
130 |         transforms.RandomHorizontalFlip(),
131 |     ])
132 |     # JPEGImageFolder requires a tmp dir (with trailing slash) for the JPEG round-trip
133 |     train_dataset = JPEGImageFolder(traindir, train_transform, tmp_dir='jpeg_tmp/')
134 |     train_loader = torch.utils.data.DataLoader(
135 |         train_dataset, batch_size=5, shuffle=True,
136 |         num_workers=1, pin_memory=True, sampler=None)
137 |
138 |     for i, (img_input, img_output, target) in enumerate(train_loader):
139 |         print(i)
140 |
--------------------------------------------------------------------------------
/figs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuzhuang13/Transferable_RA/fecdc63292abed59fffd9237c7ea4d7a4db0aef4/figs/fig1.png
--------------------------------------------------------------------------------
/figs/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuzhuang13/Transferable_RA/fecdc63292abed59fffd9237c7ea4d7a4db0aef4/figs/fig2.png
--------------------------------------------------------------------------------
/model_sr.py:
--------------------------------------------------------------------------------
1 | ### code from https://github.com/leftthomas/SRGAN/blob/master/model.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | import torch.nn.functional as F
7 | import math
8 | from torch.autograd import Variable
9 |
10 | from math import sqrt
11 | import numpy as np
12 | import pdb
13 |
14 |
15 | class SimpleNet(nn.Module):
16 | def __init__(self, upscale_factor, channel):
17 | super(SimpleNet, self).__init__()
18 | # if channel == 'RGB':
19 | # init_channel = 3
20 | # elif channel == 'YCbCr':
21 | # init_channel = 1
22 | self.relu = nn.ReLU()
23 | self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
24 | self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
25 | self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
26 | self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
27 | self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
28 |
29 | self._initialize_weights()
30 |
31 | def forward(self, x):
32 | x = self.relu(self.conv1(x))
33 | x = self.relu(self.conv2(x))
34 | x = self.relu(self.conv3(x))
35 | x = self.pixel_shuffle(self.conv4(x))
36 | return x
37 |
38 | def _initialize_weights(self):
39 |         init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
40 |         init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
41 |         init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
42 |         init.orthogonal_(self.conv4.weight)
43 |
44 |
45 | class ResNet(nn.Module):
46 | def __init__(self, upscale_factor, channel, residual=False):
47 | upsample_block_num = int(math.log(upscale_factor, 2))
48 |
49 | super(ResNet, self).__init__()
50 |
51 | self.residual=residual
52 |
53 | c = channel
54 | self.block1 = nn.Sequential(
55 | nn.Conv2d(c, 64, kernel_size=9, padding=4),
56 | nn.PReLU()
57 | )
58 | self.block2 = ResidualBlock(64)
59 | self.block3 = ResidualBlock(64)
60 | self.block4 = ResidualBlock(64)
61 | self.block5 = ResidualBlock(64)
62 | self.block6 = ResidualBlock(64)
63 | self.block7 = nn.Sequential(
64 | nn.Conv2d(64, 64, kernel_size=3, padding=1),
65 | nn.PReLU()
66 | )
67 |         block8 = [UpsampleBlock(64, 2) for _ in range(upsample_block_num)]
68 | block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
69 | self.block8 = nn.Sequential(*block8)
70 |
71 | def forward(self, x):
72 | block1 = self.block1(x)
73 | block2 = self.block2(block1)
74 | block3 = self.block3(block2)
75 | block4 = self.block4(block3)
76 | block5 = self.block5(block4)
77 | block6 = self.block6(block5)
78 | block7 = self.block7(block6)
79 | block8 = self.block8(block1 + block7)
80 |
81 |         if self.residual:
82 |             # residual mode: subtract a predicted correction from the input
83 |             return torch.clamp(x - (torch.tanh(block8) + 1) / 2, 0, 1)
84 |         else:
85 |             # map the tanh output from [-1, 1] to [0, 1]
86 |             return (torch.tanh(block8) + 1) / 2
87 |
88 |
89 | class ResidualBlock(nn.Module):
90 | def __init__(self, channels):
91 | super(ResidualBlock, self).__init__()
92 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
93 | self.bn1 = nn.BatchNorm2d(channels)
94 | self.prelu = nn.PReLU()
95 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
96 | self.bn2 = nn.BatchNorm2d(channels)
97 |
98 | def forward(self, x):
99 | residual = self.conv1(x)
100 | residual = self.bn1(residual)
101 | residual = self.prelu(residual)
102 | residual = self.conv2(residual)
103 | residual = self.bn2(residual)
104 |
105 | return x + residual
106 |
107 |
108 | class UpsampleBlock(nn.Module):
109 |     def __init__(self, in_channels, up_scale):
110 |         super(UpsampleBlock, self).__init__()
111 | self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
112 | self.pixel_shuffle = nn.PixelShuffle(up_scale)
113 | self.prelu = nn.PReLU()
114 |
115 | def forward(self, x):
116 | x = self.conv(x)
117 | x = self.pixel_shuffle(x)
118 | x = self.prelu(x)
119 | return x
120 |
121 | class Discriminator(nn.Module):
122 | def __init__(self):
123 | super(Discriminator, self).__init__()
124 | self.net = nn.Sequential(
125 | nn.Conv2d(3, 64, kernel_size=3, padding=1),
126 | nn.LeakyReLU(0.2),
127 |
128 | nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
129 | nn.BatchNorm2d(64),
130 | nn.LeakyReLU(0.2),
131 |
132 | nn.Conv2d(64, 128, kernel_size=3, padding=1),
133 | nn.BatchNorm2d(128),
134 | nn.LeakyReLU(0.2),
135 |
136 | nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
137 | nn.BatchNorm2d(128),
138 | nn.LeakyReLU(0.2),
139 |
140 | nn.Conv2d(128, 256, kernel_size=3, padding=1),
141 | nn.BatchNorm2d(256),
142 | nn.LeakyReLU(0.2),
143 |
144 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
145 | nn.BatchNorm2d(256),
146 | nn.LeakyReLU(0.2),
147 |
148 | nn.Conv2d(256, 512, kernel_size=3, padding=1),
149 | nn.BatchNorm2d(512),
150 | nn.LeakyReLU(0.2),
151 |
152 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
153 | nn.BatchNorm2d(512),
154 | nn.LeakyReLU(0.2),
155 |
156 | nn.AdaptiveAvgPool2d(1),
157 | nn.Conv2d(512, 1024, kernel_size=1),
158 | nn.LeakyReLU(0.2),
159 | nn.Conv2d(1024, 1, kernel_size=1)
160 | )
161 |
162 | def forward(self, x):
163 | batch_size = x.size(0)
164 |         return torch.sigmoid(self.net(x).view(batch_size))
165 |
166 | # ---------- DenseNet (SRDenseNet) ---------- #
167 | def xavier(param):
168 |     init.xavier_uniform_(param)
169 | class SingleLayer(nn.Module):
170 | def __init__(self, inChannels,growthRate):
171 | super(SingleLayer, self).__init__()
172 | self.conv =nn.Conv2d(inChannels,growthRate,kernel_size=3,padding=1, bias=True)
173 | def forward(self, x):
174 | out = F.relu(self.conv(x))
175 | out = torch.cat((x, out), 1)
176 | return out
177 |
178 | class SingleBlock(nn.Module):
179 | def __init__(self, inChannels,growthRate,nDenselayer):
180 | super(SingleBlock, self).__init__()
181 | self.block= self._make_dense(inChannels,growthRate, nDenselayer)
182 |
183 | def _make_dense(self,inChannels,growthRate, nDenselayer):
184 | layers = []
185 | for i in range(int(nDenselayer)):
186 | layers.append(SingleLayer(inChannels,growthRate))
187 | inChannels += growthRate
188 | return nn.Sequential(*layers)
189 |
190 | def forward(self, x):
191 | out=self.block(x)
192 | return out
193 |
194 | class SRDenseNet(nn.Module):
195 | def __init__(self,inChannels,growthRate,nDenselayer,nBlock):
196 | super(SRDenseNet,self).__init__()
197 |
198 | self.conv1 = nn.Conv2d(3,growthRate,kernel_size=3, padding=1,bias=True)
199 |
200 | inChannels = growthRate
201 |
202 | self.denseblock = self._make_block(inChannels,growthRate, nDenselayer,nBlock)
203 | inChannels +=growthRate* nDenselayer*nBlock
204 |
205 | self.Bottleneck = nn.Conv2d(in_channels=inChannels, out_channels=128, kernel_size=1,padding=0, bias=True)
206 |
207 | self.convt1 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True)
208 |
209 | self.convt2 =nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True)
210 |
211 | self.conv2 =nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3,padding=1, bias=True)
212 |
213 |
214 | # for m in self.modules():
215 | # if isinstance(m, nn.Conv2d):
216 | # xavier(m.weight.data)
217 | # if m.bias is not None:
218 | # m.bias.data.zero_()
219 |
220 | def _make_block(self, inChannels,growthRate, nDenselayer,nBlock):
221 | blocks =[]
222 | for i in range(int(nBlock)):
223 | blocks.append(SingleBlock(inChannels,growthRate,nDenselayer))
224 | inChannels += growthRate* nDenselayer
225 | return nn.Sequential(* blocks)
226 |
227 | def forward(self,x):
228 | out = F.relu(self.conv1(x))
229 | out = self.denseblock(out)
230 | out = self.Bottleneck(out)
231 | out = self.convt1(out)
232 | out = self.convt2(out)
233 |
234 | HR = self.conv2(out)
235 | return HR
236 |
237 | if __name__ == '__main__':
238 |     # smoke test: SRDenseNet takes 3-channel input and upscales 4x
239 |     a = Variable(torch.randn(2, 3, 24, 24))
240 |     net = SRDenseNet(16, 16, 8, 8)  # default upscale is 4
241 |     print(net(a).size())  # expect (2, 3, 96, 96)
242 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 |
8 | ###############################################################################
9 | # Helper Functions
10 | ###############################################################################
11 | def get_norm_layer(norm_type='instance'):
12 | """Return a normalization layer
13 |
14 | Parameters:
15 | norm_type (str) -- the name of the normalization layer: batch | instance | none
16 |
17 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
18 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
19 | """
20 | if norm_type == 'batch':
21 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
22 | elif norm_type == 'instance':
23 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
24 | elif norm_type == 'none':
25 | norm_layer = None
26 | else:
27 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
28 | return norm_layer
29 |
30 |
31 | def get_scheduler(optimizer, opt):
32 | """Return a learning rate scheduler
33 |
34 | Parameters:
35 | optimizer -- the optimizer of the network
36 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
37 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
38 |
39 | For 'linear', we keep the same learning rate for the first epochs
40 | and linearly decay the rate to zero over the next epochs.
41 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
42 | See https://pytorch.org/docs/stable/optim.html for more details.
43 | """
44 | if opt.lr_policy == 'linear':
45 | def lambda_rule(epoch):
46 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
47 | return lr_l
48 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
49 | elif opt.lr_policy == 'step':
50 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
51 | elif opt.lr_policy == 'plateau':
52 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
53 | elif opt.lr_policy == 'cosine':
54 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
55 | else:
56 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
57 | return scheduler
58 |
59 |
60 | def init_weights(net, init_type='normal', init_gain=0.02):
61 | """Initialize network weights.
62 |
63 | Parameters:
64 | net (network) -- network to be initialized
65 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
66 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
67 |
68 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
69 | work better for some applications. Feel free to try yourself.
70 | """
71 | def init_func(m): # define the initialization function
72 | classname = m.__class__.__name__
73 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
74 | if init_type == 'normal':
75 | init.normal_(m.weight.data, 0.0, init_gain)
76 | elif init_type == 'xavier':
77 | init.xavier_normal_(m.weight.data, gain=init_gain)
78 | elif init_type == 'kaiming':
79 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
80 | elif init_type == 'orthogonal':
81 | init.orthogonal_(m.weight.data, gain=init_gain)
82 | else:
83 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
84 | if hasattr(m, 'bias') and m.bias is not None:
85 | init.constant_(m.bias.data, 0.0)
86 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
87 | init.normal_(m.weight.data, 1.0, init_gain)
88 | init.constant_(m.bias.data, 0.0)
89 |
90 | print('initialize network with %s' % init_type)
91 | net.apply(init_func) # apply the initialization function
92 |
93 |
94 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
95 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
96 | Parameters:
97 | net (network) -- the network to be initialized
98 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
99 |         init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
100 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
101 |
102 | Return an initialized network.
103 | """
104 | if len(gpu_ids) > 0:
105 | assert(torch.cuda.is_available())
106 | net.to(gpu_ids[0])
107 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
108 | init_weights(net, init_type, init_gain=init_gain)
109 | return net
110 |
111 |
112 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
113 | """Create a generator
114 |
115 | Parameters:
116 | input_nc (int) -- the number of channels in input images
117 | output_nc (int) -- the number of channels in output images
118 | ngf (int) -- the number of filters in the last conv layer
119 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
120 | norm (str) -- the name of normalization layers used in the network: batch | instance | none
121 | use_dropout (bool) -- if use dropout layers.
122 | init_type (str) -- the name of our initialization method.
123 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
124 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
125 |
126 | Returns a generator
127 |
128 | Our current implementation provides two types of generators:
129 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
130 | The original U-Net paper: https://arxiv.org/abs/1505.04597
131 |
132 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
133 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
134 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
135 |
136 |
137 |     The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
138 | """
139 | net = None
140 | norm_layer = get_norm_layer(norm_type=norm)
141 |
142 | if netG == 'resnet_9blocks':
143 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
144 | elif netG == 'resnet_6blocks':
145 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
146 | elif netG == 'unet_128':
147 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
148 | elif netG == 'unet_256':
149 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
150 | else:
151 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
152 | return init_net(net, init_type, init_gain, gpu_ids)
153 |
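# Usage sketch: train.py constructs ResnetGenerator directly for the transformer T,
# but an equivalent 3-channel pix2pix-style generator via this helper would be:
#   netT = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
#                   norm='batch', gpu_ids=[0])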
154 |
155 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
156 | """Create a discriminator
157 |
158 | Parameters:
159 | input_nc (int) -- the number of channels in input images
160 | ndf (int) -- the number of filters in the first conv layer
161 | netD (str) -- the architecture's name: basic | n_layers | pixel
162 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
163 | norm (str) -- the type of normalization layers used in the network.
164 | init_type (str) -- the name of the initialization method.
165 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
166 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
167 |
168 | Returns a discriminator
169 |
170 | Our current implementation provides three types of discriminators:
171 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
172 | It can classify whether 70×70 overlapping patches are real or fake.
173 | Such a patch-level discriminator architecture has fewer parameters
174 | than a full-image discriminator and can work on arbitrarily-sized images
175 | in a fully convolutional fashion.
176 |
177 |         [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
178 |         with the parameter <n_layers_D> (default=3, as used in [basic] (PatchGAN)).
179 |
180 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
181 | It encourages greater color diversity but has no effect on spatial statistics.
182 |
183 |     The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
184 | """
185 | net = None
186 | norm_layer = get_norm_layer(norm_type=norm)
187 |
188 | if netD == 'basic': # default PatchGAN classifier
189 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
190 | elif netD == 'n_layers': # more options
191 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
192 | elif netD == 'pixel': # classify if each pixel is real or fake
193 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
194 | else:
195 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
196 | return init_net(net, init_type, init_gain, gpu_ids)
197 |
198 |
199 | ##############################################################################
200 | # Classes
201 | ##############################################################################
202 | class GANLoss(nn.Module):
203 | """Define different GAN objectives.
204 |
205 | The GANLoss class abstracts away the need to create the target label tensor
206 | that has the same size as the input.
207 | """
208 |
209 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
210 | """ Initialize the GANLoss class.
211 |
212 | Parameters:
213 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
214 |             target_real_label (float) - - label for a real image
215 |             target_fake_label (float) - - label for a fake image
216 |
217 | Note: Do not use sigmoid as the last layer of Discriminator.
218 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
219 | """
220 | super(GANLoss, self).__init__()
221 | self.register_buffer('real_label', torch.tensor(target_real_label))
222 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
223 | self.gan_mode = gan_mode
224 | if gan_mode == 'lsgan':
225 | self.loss = nn.MSELoss()
226 | elif gan_mode == 'vanilla':
227 | self.loss = nn.BCEWithLogitsLoss()
228 | elif gan_mode in ['wgangp']:
229 | self.loss = None
230 | else:
231 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
232 |
233 | def get_target_tensor(self, prediction, target_is_real):
234 | """Create label tensors with the same size as the input.
235 |
236 | Parameters:
237 |             prediction (tensor) - - typically the prediction from a discriminator
238 | target_is_real (bool) - - if the ground truth label is for real images or fake images
239 |
240 | Returns:
241 | A label tensor filled with ground truth label, and with the size of the input
242 | """
243 |
244 | if target_is_real:
245 | target_tensor = self.real_label
246 | else:
247 | target_tensor = self.fake_label
248 | return target_tensor.expand_as(prediction)
249 |
250 | def __call__(self, prediction, target_is_real):
251 |         """Calculate loss given Discriminator's output and ground truth labels.
252 |
253 | Parameters:
254 |             prediction (tensor) - - typically the prediction output from a discriminator
255 | target_is_real (bool) - - if the ground truth label is for real images or fake images
256 |
257 | Returns:
258 | the calculated loss.
259 | """
260 | if self.gan_mode in ['lsgan', 'vanilla']:
261 | target_tensor = self.get_target_tensor(prediction, target_is_real)
262 | loss = self.loss(prediction, target_tensor)
263 | elif self.gan_mode == 'wgangp':
264 | if target_is_real:
265 | loss = -prediction.mean()
266 | else:
267 | loss = prediction.mean()
268 | return loss
269 |
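# Minimal usage sketch (hypothetical netD, real and fake tensors) for a discriminator step:
#   criterion = GANLoss('lsgan')
#   loss_D = 0.5 * (criterion(netD(real), True) + criterion(netD(fake.detach()), False))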
270 |
271 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
272 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
273 |
274 | Arguments:
275 | netD (network) -- discriminator network
276 | real_data (tensor array) -- real images
277 | fake_data (tensor array) -- generated images from the generator
278 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
279 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
280 |         constant (float)           -- the constant used in the formula (||gradient||_2 - constant)^2
281 | lambda_gp (float) -- weight for this loss
282 |
283 | Returns the gradient penalty loss
284 | """
285 | if lambda_gp > 0.0:
286 | if type == 'real': # either use real images, fake images, or a linear interpolation of two.
287 | interpolatesv = real_data
288 | elif type == 'fake':
289 | interpolatesv = fake_data
290 | elif type == 'mixed':
291 | alpha = torch.rand(real_data.shape[0], 1)
292 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
293 | alpha = alpha.to(device)
294 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
295 | else:
296 | raise NotImplementedError('{} not implemented'.format(type))
297 | interpolatesv.requires_grad_(True)
298 | disc_interpolates = netD(interpolatesv)
299 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
300 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
301 | create_graph=True, retain_graph=True, only_inputs=True)
302 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
303 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
304 | return gradient_penalty, gradients
305 | else:
306 | return 0.0, None
307 |
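# Sketch of a WGAN-GP discriminator loss using the helper above (hypothetical tensors):
#   gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device, type='mixed')
#   loss_D = netD(fake.detach()).mean() - netD(real).mean() + gp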
308 |
309 | class ResnetGenerator(nn.Module):
310 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
311 |
312 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
313 | """
314 |
315 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
316 | """Construct a Resnet-based generator
317 |
318 | Parameters:
319 | input_nc (int) -- the number of channels in input images
320 | output_nc (int) -- the number of channels in output images
321 | ngf (int) -- the number of filters in the last conv layer
322 | norm_layer -- normalization layer
323 | use_dropout (bool) -- if use dropout layers
324 | n_blocks (int) -- the number of ResNet blocks
325 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
326 | """
327 | assert(n_blocks >= 0)
328 | super(ResnetGenerator, self).__init__()
329 | if type(norm_layer) == functools.partial:
330 | use_bias = norm_layer.func == nn.InstanceNorm2d
331 | else:
332 | use_bias = norm_layer == nn.InstanceNorm2d
333 |
334 | model = [nn.ReflectionPad2d(3),
335 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
336 | norm_layer(ngf),
337 | nn.ReLU(True)]
338 |
339 | n_downsampling = 2
340 | for i in range(n_downsampling): # add downsampling layers
341 | mult = 2 ** i
342 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
343 | norm_layer(ngf * mult * 2),
344 | nn.ReLU(True)]
345 |
346 | mult = 2 ** n_downsampling
347 | for i in range(n_blocks): # add ResNet blocks
348 |
349 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
350 |
351 | for i in range(n_downsampling): # add upsampling layers
352 | mult = 2 ** (n_downsampling - i)
353 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
354 | kernel_size=3, stride=2,
355 | padding=1, output_padding=1,
356 | bias=use_bias),
357 | norm_layer(int(ngf * mult / 2)),
358 | nn.ReLU(True)]
359 | model += [nn.ReflectionPad2d(3)]
360 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
361 | model += [nn.Tanh()]
362 |
363 | self.model = nn.Sequential(*model)
364 |
365 | def forward(self, input):
366 | """Standard forward"""
367 | # normalize to 0-1
368 | return (self.model(input) + 1) / 2
369 |
370 |
371 | class ResnetBlock(nn.Module):
372 | """Define a Resnet block"""
373 |
374 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
375 | """Initialize the Resnet block
376 |
377 | A resnet block is a conv block with skip connections
378 | We construct a conv block with build_conv_block function,
379 | and implement skip connections in function.
380 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
381 | """
382 | super(ResnetBlock, self).__init__()
383 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
384 |
385 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
386 | """Construct a convolutional block.
387 |
388 | Parameters:
389 | dim (int) -- the number of channels in the conv layer.
390 | padding_type (str) -- the name of padding layer: reflect | replicate | zero
391 | norm_layer -- normalization layer
392 | use_dropout (bool) -- if use dropout layers.
393 | use_bias (bool) -- if the conv layer uses bias or not
394 |
395 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
396 | """
397 | conv_block = []
398 | p = 0
399 | if padding_type == 'reflect':
400 | conv_block += [nn.ReflectionPad2d(1)]
401 | elif padding_type == 'replicate':
402 | conv_block += [nn.ReplicationPad2d(1)]
403 | elif padding_type == 'zero':
404 | p = 1
405 | else:
406 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
407 |
408 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
409 | if use_dropout:
410 | conv_block += [nn.Dropout(0.5)]
411 |
412 | p = 0
413 | if padding_type == 'reflect':
414 | conv_block += [nn.ReflectionPad2d(1)]
415 | elif padding_type == 'replicate':
416 | conv_block += [nn.ReplicationPad2d(1)]
417 | elif padding_type == 'zero':
418 | p = 1
419 | else:
420 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
421 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
422 |
423 | return nn.Sequential(*conv_block)
424 |
425 | def forward(self, x):
426 | """Forward function (with skip connections)"""
427 | out = x + self.conv_block(x) # add skip connections
428 | return out
429 |
430 |
431 | class UnetGenerator(nn.Module):
432 | """Create a Unet-based generator"""
433 |
434 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
435 | """Construct a Unet generator
436 | Parameters:
437 | input_nc (int) -- the number of channels in input images
438 | output_nc (int) -- the number of channels in output images
439 |             num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
440 |                                an image of size 128x128 will become of size 1x1 at the bottleneck
441 | ngf (int) -- the number of filters in the last conv layer
442 | norm_layer -- normalization layer
443 |
444 | We construct the U-Net from the innermost layer to the outermost layer.
445 | It is a recursive process.
446 | """
447 | super(UnetGenerator, self).__init__()
448 | # construct unet structure
449 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
450 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
451 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
452 | # gradually reduce the number of filters from ngf * 8 to ngf
453 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
454 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
455 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
456 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
457 |
458 | def forward(self, input):
459 | """Standard forward"""
460 | return self.model(input)
461 |
462 |
463 | class UnetSkipConnectionBlock(nn.Module):
464 | """Defines the Unet submodule with skip connection.
465 | X -------------------identity----------------------
466 | |-- downsampling -- |submodule| -- upsampling --|
467 | """
468 |
469 | def __init__(self, outer_nc, inner_nc, input_nc=None,
470 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
471 | """Construct a Unet submodule with skip connections.
472 |
473 | Parameters:
474 | outer_nc (int) -- the number of filters in the outer conv layer
475 | inner_nc (int) -- the number of filters in the inner conv layer
476 | input_nc (int) -- the number of channels in input images/features
477 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
478 | outermost (bool) -- if this module is the outermost module
479 | innermost (bool) -- if this module is the innermost module
480 | norm_layer -- normalization layer
481 |             use_dropout (bool) -- whether to use dropout layers.
482 | """
483 | super(UnetSkipConnectionBlock, self).__init__()
484 | self.outermost = outermost
485 | if type(norm_layer) == functools.partial:
486 | use_bias = norm_layer.func == nn.InstanceNorm2d
487 | else:
488 | use_bias = norm_layer == nn.InstanceNorm2d
489 | if input_nc is None:
490 | input_nc = outer_nc
491 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
492 | stride=2, padding=1, bias=use_bias)
493 | downrelu = nn.LeakyReLU(0.2, True)
494 | downnorm = norm_layer(inner_nc)
495 | uprelu = nn.ReLU(True)
496 | upnorm = norm_layer(outer_nc)
497 |
498 | if outermost:
499 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
500 | kernel_size=4, stride=2,
501 | padding=1)
502 | down = [downconv]
503 | up = [uprelu, upconv, nn.Tanh()]
504 | model = down + [submodule] + up
505 | elif innermost:
506 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
507 | kernel_size=4, stride=2,
508 | padding=1, bias=use_bias)
509 | down = [downrelu, downconv]
510 | up = [uprelu, upconv, upnorm]
511 | model = down + up
512 | else:
513 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
514 | kernel_size=4, stride=2,
515 | padding=1, bias=use_bias)
516 | down = [downrelu, downconv, downnorm]
517 | up = [uprelu, upconv, upnorm]
518 |
519 | if use_dropout:
520 | model = down + [submodule] + up + [nn.Dropout(0.5)]
521 | else:
522 | model = down + [submodule] + up
523 |
524 | self.model = nn.Sequential(*model)
525 |
526 | def forward(self, x):
527 | if self.outermost:
528 | return self.model(x)
529 | else: # add skip connections
530 | return torch.cat([x, self.model(x)], 1)
531 |
532 |
533 | class NLayerDiscriminator(nn.Module):
534 | """Defines a PatchGAN discriminator"""
535 |
536 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
537 | """Construct a PatchGAN discriminator
538 |
539 | Parameters:
540 | input_nc (int) -- the number of channels in input images
541 | ndf (int) -- the number of filters in the last conv layer
542 | n_layers (int) -- the number of conv layers in the discriminator
543 | norm_layer -- normalization layer
544 | """
545 | super(NLayerDiscriminator, self).__init__()
546 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
547 | use_bias = norm_layer.func != nn.BatchNorm2d
548 | else:
549 | use_bias = norm_layer != nn.BatchNorm2d
550 |
551 | kw = 4
552 | padw = 1
553 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
554 | nf_mult = 1
555 | nf_mult_prev = 1
556 | for n in range(1, n_layers): # gradually increase the number of filters
557 | nf_mult_prev = nf_mult
558 | nf_mult = min(2 ** n, 8)
559 | sequence += [
560 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
561 | norm_layer(ndf * nf_mult),
562 | nn.LeakyReLU(0.2, True)
563 | ]
564 |
565 | nf_mult_prev = nf_mult
566 | nf_mult = min(2 ** n_layers, 8)
567 | sequence += [
568 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
569 | norm_layer(ndf * nf_mult),
570 | nn.LeakyReLU(0.2, True)
571 | ]
572 |
573 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
574 | self.model = nn.Sequential(*sequence)
575 |
576 | def forward(self, input):
577 | """Standard forward."""
578 | return self.model(input)
579 |
580 |
581 | class PixelDiscriminator(nn.Module):
582 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
583 |
584 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
585 | """Construct a 1x1 PatchGAN discriminator
586 |
587 | Parameters:
588 | input_nc (int) -- the number of channels in input images
589 | ndf (int) -- the number of filters in the last conv layer
590 | norm_layer -- normalization layer
591 | """
592 | super(PixelDiscriminator, self).__init__()
593 |         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
594 |             use_bias = norm_layer.func == nn.InstanceNorm2d
595 |         else:
596 |             use_bias = norm_layer == nn.InstanceNorm2d
597 |
598 | self.net = [
599 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
600 | nn.LeakyReLU(0.2, True),
601 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
602 | norm_layer(ndf * 2),
603 | nn.LeakyReLU(0.2, True),
604 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
605 |
606 | self.net = nn.Sequential(*self.net)
607 |
608 | def forward(self, input):
609 | """Standard forward."""
610 | return self.net(input)
611 |
--------------------------------------------------------------------------------
/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/Po-Hsun-Su/pytorch-ssim
2 |
3 | from math import exp
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 |
10 | def gaussian(window_size, sigma):
11 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
12 | return gauss / gauss.sum()
13 |
14 |
15 | def create_window(window_size, channel):
16 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
18 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
19 | return window
20 |
21 |
22 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
23 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
24 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
25 |
26 | mu1_sq = mu1.pow(2)
27 | mu2_sq = mu2.pow(2)
28 | mu1_mu2 = mu1 * mu2
29 |
30 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
31 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
32 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
33 |
34 | C1 = 0.01 ** 2
35 | C2 = 0.03 ** 2
36 |
37 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
38 |
39 | if size_average:
40 | return ssim_map.mean()
41 | else:
42 | return ssim_map.mean(1).mean(1).mean(1)
43 |
44 |
45 | class SSIM(torch.nn.Module):
46 | def __init__(self, window_size=11, size_average=True):
47 | super(SSIM, self).__init__()
48 | self.window_size = window_size
49 | self.size_average = size_average
50 | self.channel = 1
51 | self.window = create_window(window_size, self.channel)
52 |
53 | def forward(self, img1, img2):
54 | (_, channel, _, _) = img1.size()
55 |
56 | if channel == self.channel and self.window.data.type() == img1.data.type():
57 | window = self.window
58 | else:
59 | window = create_window(self.window_size, channel)
60 |
61 | if img1.is_cuda:
62 | window = window.cuda(img1.get_device())
63 | window = window.type_as(img1)
64 |
65 | self.window = window
66 | self.channel = channel
67 |
68 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
69 |
70 |
71 | def ssim(img1, img2, window_size=11, size_average=True):
72 | (_, channel, _, _) = img1.size()
73 | window = create_window(window_size, channel)
74 |
75 | if img1.is_cuda:
76 | window = window.cuda(img1.get_device())
77 | window = window.type_as(img1)
78 |
79 | return _ssim(img1, img2, window, window_size, channel, size_average)
80 |
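# Usage sketch: SSIM between two batched images in [0, 1], shape (N, C, H, W):
#   img1 = torch.rand(1, 3, 64, 64)
#   img2 = img1.clamp(0.1, 0.9)
#   print(ssim(img1, img2).item())   # close to 1 for similar images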
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | import torch.backends.cudnn as cudnn
10 | import torch.distributed as dist
11 | import torch.optim
12 | import torch.utils.data
13 | import torch.utils.data.distributed
14 | import torchvision.transforms as transforms
15 | import torchvision.datasets as datasets
16 | import torchvision.models as models
17 | import pdb
18 | from model_sr import SimpleNet, ResNet, Discriminator, SRDenseNet
19 | from torch.autograd import Variable
20 | from PIL import Image
21 | from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
22 | import numpy as np
23 | import json
24 | from math import log10
25 | import pytorch_ssim
26 | from networks import ResnetGenerator
27 | from data_utils import SRImageFolder, DNImageFolder, JPEGImageFolder, SelfImageFolder
28 | from time import gmtime, strftime
29 |
30 | # from skimage import io, color
31 |
32 | model_names = sorted(name for name in models.__dict__
33 | if name.islower() and not name.startswith("__")
34 | and callable(models.__dict__[name]))
35 |
36 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
37 | parser.add_argument('--data', default='/home/zhuangl/datasets/imagenet',
38 | help='path to dataset')
39 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
40 | choices=model_names,
41 | help='model architecture: ' +
42 | ' | '.join(model_names) +
43 | ' (default: resnet18)')
44 | parser.add_argument('-j', '--workers', default=5, type=int, metavar='N',
45 |                     help='number of data loading workers (default: 5)')
46 | parser.add_argument('--epochs', default=6, type=int, metavar='N',
47 | help='number of total epochs to run')
48 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
49 | help='manual epoch number (useful on restarts)')
50 | parser.add_argument('-b', '--batch-size', default=20, type=int,
51 |                     metavar='N', help='mini-batch size (default: 20)')
52 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
53 | metavar='LR', help='initial learning rate')
54 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
55 | help='momentum')
56 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
57 | metavar='W', help='weight decay (default: 1e-4)')
58 | parser.add_argument('--print-freq', '-p', default=1, type=int,
59 |                     metavar='N', help='print frequency (default: 1)')
60 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
61 | help='path to latest checkpoint (default: none)')
62 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
63 | help='evaluate model on validation set')
64 | parser.add_argument('--pretrained', dest='pretrained', action='store_false',
65 |                     help='pass to disable the pre-trained recognition model (default: use pre-trained)')
66 | parser.add_argument('--world-size', default=1, type=int,
67 | help='number of distributed processes')
68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
69 | help='url used to set up distributed training')
70 | parser.add_argument('--dist-backend', default='gloo', type=str,
71 | help='distributed backend')
72 | parser.add_argument('--evaluate-notransform', action='store_true', help='whether to evaluate bicubic interpolation')
73 | parser.add_argument('--upscale', default=4, type=int) # SR up resolution
74 | parser.add_argument('--l', default=0, type=float) # coefficient for RA loss, lambda in paper
75 | parser.add_argument('--save-dir', default='checkpoint/default/', type=str)
76 | parser.add_argument('--mode', default='sr', type=str) # mode: ra, ra_transform, ra_unsupervised
77 | parser.add_argument('--task', default='sr', type=str) # task: sr, dn, jpeg
78 | parser.add_argument('--std', default=0.1, type=float) # noise level for denoising
79 | parser.add_argument('--L', default=1, type=float)
80 | parser.add_argument('--model-sr', default='test', type=str) # for evaluation
81 | parser.add_argument('--model-transformer', default=None, type=str) # for evaluation
82 | parser.add_argument('--test-batch-size', default=20, type=int)
83 | parser.add_argument('--cross-evaluate', action='store_true')
84 | parser.add_argument('--custom-evaluate', action='store_true')
85 | parser.add_argument('--custom-evaluate-model', default='', type=str)
86 | parser.add_argument('--sr-arch', default='SRResNet', type=str)
87 | parser.add_argument('--transformer-arch', default='pix2pix', type=str)
88 | parser.add_argument('--lower_lr', action='store_false', help='pass to disable lowering lr every certain epochs') # default is True
89 | parser.add_argument('--vis', action='store_true', help='whether to visualize sr results')
90 | parser.add_argument('--l_soft', default=0.001, type=float)
91 | # parser.add_argument('--sr_model', action='store_true', help='whether to use the SRResNet model in dn and jpeg')
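   | # Example invocations (paths and values below are illustrative, not prescribed by the repo):
   | #   train RA super-resolution:  python train.py --mode ra --task sr --l 0.01 --data /path/to/imagenet
   | #   evaluate a trained model:   python train.py --evaluate --model-sr checkpoint/default/SRResNet_0.01_resnet18/model_sr_epoch_5.pth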
92 | best_prec1 = 0
93 |
94 | # get a bicubic down/up-sampled version of the high-res images; only used in the no-transform baseline for sr
95 | def trans_RGB_bicubic(data):
96 | up = args.upscale
97 | ims_np = (data.clone()*255.).permute(0, 2, 3, 1).numpy().astype(np.uint8)
98 |
99 | hr_size = ims_np.shape[1]
100 |
101 | lr_size = hr_size // up
102 |
103 | rgb_hrs = data.new().resize_(data.size(0), 3, hr_size, hr_size).zero_()
104 |
105 | for i, im_np in enumerate(ims_np):
106 | im = Image.fromarray(im_np, 'RGB')
107 | rgb_lr = Resize((lr_size, lr_size), Image.BICUBIC)(im)
108 | rgb_hr = Resize((hr_size, hr_size), Image.BICUBIC)(rgb_lr)
109 | rgb_hr = ToTensor()(rgb_hr)
110 | rgb_hrs[i].copy_(rgb_hr)
111 | return rgb_hrs
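   | # e.g. with 224x224 crops and --upscale 4, images go down to 56x56 and back up to
   | # 224x224, matching what an SR model would receive and produce.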
112 |
113 | # normalize the output of sr to fit the cls network input
114 | # input: RGB in [0, 1]; output: ImageNet-normalized network input
115 | def process_to_input_cls(RGB):
116 | means = [0.485, 0.456, 0.406]
117 | stds = [0.229, 0.224, 0.225]
118 | RGB_new = torch.autograd.Variable(RGB.data.new(*RGB.size()))
119 |
120 | RGB_new[:, 0, :, :] = (RGB[:, 0, :, :] - means[0]) / stds[0]
121 | RGB_new[:, 1, :, :] = (RGB[:, 1, :, :] - means[1]) / stds[1]
122 | RGB_new[:, 2, :, :] = (RGB[:, 2, :, :] - means[2]) / stds[2]
123 |
124 | return RGB_new
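   | # This mirrors transforms.Normalize(mean, std), but runs on the tensor output of the
   | # processing model, so gradients from the recognition loss can flow back through it.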
125 |
126 | def main():
127 |
128 | global args, best_prec1
129 | args = parser.parse_args()
130 |
131 |     if 'small' in args.data:  # epoch budget depends on dataset size; this overrides --epochs
132 | args.epochs = 30
133 | else:
134 | args.epochs = 6
135 |
136 | print(args)
137 |
138 | args.distributed = args.world_size > 1
139 |
140 |
141 | if args.mode == 'ra_transform':
142 | save_dir_extra = '_'.join([args.sr_arch, args.transformer_arch, args.arch])
143 |     elif args.mode == 'ra_unsupervised':
144 |         args.l = 10  # unsupervised training uses a fixed, larger lambda
145 | save_dir_extra = '_'.join([args.sr_arch, str(args.l), args.arch])
146 | elif args.mode == 'ra':
147 | save_dir_extra = '_'.join([args.sr_arch, str(args.l), args.arch])
148 |
149 | args.save_dir = args.save_dir + save_dir_extra
150 |
151 | os.makedirs(args.save_dir, exist_ok=True)
152 | print('making directory ', args.save_dir)
153 |
154 | if args.distributed:
155 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
156 | world_size=args.world_size)
157 |
158 | # create model
159 | if args.pretrained:
160 | print("=> using pre-trained model '{}'".format(args.arch))
161 | model = models.__dict__[args.arch](pretrained=True)
162 | else:
163 | print("=> creating model '{}'".format(args.arch))
164 | model = models.__dict__[args.arch]()
165 |
166 |     # single machine, possibly with multiple GPUs
167 | if not args.distributed:
168 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
169 | model.features = torch.nn.DataParallel(model.features)
170 | model.cuda()
171 | model.eval()
172 | else:
173 |         model = torch.nn.DataParallel(model).cuda()  # wrap for multi-GPU inference
174 | model.eval()
175 |     else:
176 |         return  # distributed training is not supported in this script
177 |
178 |     # two model choices for sr and two for dn/jpeg
179 |     # default for sr is SRResNet; default for dn/jpeg is pix2pix
180 | if args.task == 'sr':
181 | if args.sr_arch == 'SRResNet':
182 | model_sr = ResNet(upscale_factor=4, channel=3, residual=False)
183 | elif args.sr_arch == 'SRDenseNet':
184 |             model_sr = SRDenseNet(16, 16, 8, 8)
185 | elif args.task == 'dn' or args.task == 'jpeg':
186 | if args.sr_arch == 'SRResNet':
187 | model_sr = ResNet(upscale_factor=1, channel=3, residual=False)
188 | elif args.sr_arch == 'pix2pix':
189 | model_sr = ResnetGenerator(3, 3, n_blocks=6)
190 | model_sr = torch.nn.DataParallel(model_sr).cuda()
191 | model_sr.train()
192 |
193 |     # transformer model; only trained and used in the ra_transform mode
194 | if args.transformer_arch == 'SRResNet':
195 | model_transformer = ResNet(upscale_factor=1, channel=3, residual=False)
196 | elif args.transformer_arch == 'pix2pix':
197 | model_transformer = ResnetGenerator(3, 3, n_blocks=6)
198 |
199 | model_transformer = torch.nn.DataParallel(model_transformer).cuda()
200 | model_transformer.train()
201 |
202 |
203 | criterion_sr = nn.MSELoss()
204 | criterion_sr.cuda()
205 | criterion = nn.CrossEntropyLoss().cuda()
206 |
207 |     optimizer_sr = torch.optim.Adam(model_sr.parameters(), lr=args.lr)  # previously 0.001 by default; now 0.0001
208 | optimizer_transformer = torch.optim.Adam(model_transformer.parameters(), lr=args.lr)
209 |
210 | optimizer = torch.optim.SGD(model.parameters(), 0.01,
211 | momentum=args.momentum,
212 | weight_decay=args.weight_decay)
213 |
214 |     # optionally resume the recognition model from a checkpoint (the sr model is not resumed)
215 | if args.resume:
216 | if os.path.isfile(args.resume):
217 | print("=> loading checkpoint '{}'".format(args.resume))
218 | checkpoint = torch.load(args.resume)
219 | args.start_epoch = checkpoint['epoch']
220 | best_prec1 = checkpoint['best_prec1']
221 | model.load_state_dict(checkpoint['state_dict'])
222 | optimizer.load_state_dict(checkpoint['optimizer'])
223 | print("=> loaded checkpoint '{}' (epoch {})"
224 | .format(args.resume, checkpoint['epoch']))
225 | else:
226 | print("=> no checkpoint found at '{}'".format(args.resume))
227 |
228 | cudnn.benchmark = True
229 |
230 | # Data loading code
231 | traindir = os.path.join(args.data, 'train')
232 | valdir = os.path.join(args.data, 'val')
233 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
234 | # std=[0.229, 0.224, 0.225])
235 | train_transform = transforms.Compose([
236 | transforms.Resize(256),
237 | transforms.RandomCrop(224),
238 | transforms.RandomHorizontalFlip()])
239 | val_transform = transforms.Compose([
240 | transforms.Resize(256),
241 | transforms.CenterCrop(224)])
242 |
243 | if args.task == 'sr':
244 | train_dataset = SRImageFolder(traindir, train_transform)
245 | val_dataset = SRImageFolder(valdir, val_transform)
246 | elif args.task == 'dn':
247 | train_dataset = DNImageFolder(traindir, train_transform)
248 | val_dataset = DNImageFolder(valdir, val_transform, deterministic=True)
249 | elif args.task == 'self':
250 | train_dataset = SelfImageFolder(traindir, train_transform)
251 | val_dataset = SelfImageFolder(valdir, val_transform)
252 | elif args.task == 'jpeg':
253 | randomfoldername = strftime("%Y-%m-%d_%H-%M-%S", gmtime())
254 | randomfoldername += str(os.getpid())
255 | train_dataset = JPEGImageFolder(traindir, train_transform, tmp_dir=args.data + '/trash/{}_{}/'.format(randomfoldername, np.random.randint(1, 1000)))
256 | val_dataset = JPEGImageFolder(valdir, val_transform, tmp_dir=args.data + '/trash/{}_{}/'.format(randomfoldername, np.random.randint(1, 1000)))
257 |
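   | # All four dataset classes yield (img_input, img_output, target): the degraded input
   | # for the processing model, the clean reference image, and the ImageNet class label
   | # (see how the batches are unpacked in the training loop below).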
258 | train_loader = torch.utils.data.DataLoader(
259 | train_dataset, batch_size=args.batch_size, shuffle=True,
260 | num_workers=args.workers, pin_memory=True, sampler=None)
261 |
262 | val_loader = torch.utils.data.DataLoader(
263 | val_dataset, batch_size=args.test_batch_size, shuffle=False,
264 | num_workers=args.workers, pin_memory=True)
265 |
266 |     run = {'ra': ra, 'ra_transform': ra_transform, 'ra_unsupervised': ra_unsupervised}[args.mode]  # explicit dispatch is safer than eval()
267 |
268 | # evaluation options
269 | if args.evaluate:
270 |
271 | model_sr = load_model(args.model_sr, model_sr)
272 |
273 | loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
274 |
275 | save_file = args.model_sr + '_{}.txt'.format(args.arch)
276 | np.savetxt(save_file, [loss_sr, loss_cls, top1, top5, psnr, ssim])
277 | return
278 |
279 |     if args.custom_evaluate:  # evaluate against a custom recognition model R
280 |
281 | model_sr = load_model(args.model_sr, model_sr)
282 |
283 | model = torch.load(args.custom_evaluate_model).cuda()
284 | loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
285 |
286 | save_file = args.model_sr + '_custom.txt'
287 | np.savetxt(save_file, [loss_sr, loss_cls, top1, top5, psnr, ssim])
288 | return
289 |
290 | if args.evaluate_notransform:
291 | os.makedirs('notransform_results/' + args.task, exist_ok=True)
292 | loss_sr, loss_cls, top1, top5, psnr, ssim = validate_notransform(val_loader, model, criterion_sr, criterion)
293 | save_file = 'notransform_results/' + args.task + '/{}.txt'.format(args.arch)
294 | np.savetxt(save_file, [loss_sr, loss_cls, top1, top5, psnr, ssim])
295 | return
296 |
297 |
298 |
299 | if args.cross_evaluate:
300 | basic_model_list = ['resnet18','resnet50','vgg16_bn', 'resnet101', 'densenet121']
301 |
302 | more_model_list = ['densenet169', 'densenet201', 'vgg13_bn', 'vgg19_bn']
303 | other_model_list = ['vgg13', 'vgg16', 'vgg19', 'inception_v3']
304 |
305 |         model_list = basic_model_list + more_model_list + other_model_list  # full list; the loop below covers basic_model_list
306 |
307 |
308 | model_sr = load_model(args.model_sr, model_sr)
309 |
310 |
311 | model_sr = nn.DataParallel(model_sr)
312 | log = {}
313 | if args.model_transformer is not None:
314 | model_transformer = torch.load(args.model_transformer).cuda()
315 | model_transformer = nn.DataParallel(model_transformer)
316 |             run = ra_transform
317 |
318 | for arch in basic_model_list:
319 | model = models.__dict__[arch](pretrained=True)
320 | model = torch.nn.DataParallel(model).cuda()
321 | loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
322 |
323 |
324 | if isinstance(top1, torch.Tensor):
325 | log[arch] = top1.item()
326 | else:
327 | log[arch] = top1
328 |
329 | with open(args.model_sr + '_' + run.__name__ + '.txt', 'w') as outfile:
330 | json.dump(log, outfile)
331 | return
332 |
333 | if args.vis:
334 | model_sr = load_model(args.model_sr, model_sr)
335 | vis(val_loader, model_sr, model_transformer, model, criterion_sr, criterion, run)
336 | return
337 |
338 | log = []
339 |
340 | for epoch in range(args.start_epoch, args.epochs):
341 |         # distributed training exits early in main(), so there is no train_sampler to re-seed here
342 |
343 | if args.lower_lr:
344 | adjust_learning_rate(optimizer_sr, epoch)
345 | adjust_learning_rate(optimizer_transformer, epoch)
346 |
347 | log_tmp = []
348 |
349 | # train for one epoch
350 | loss_sr, loss_cls, top1, top5, psnr, ssim = train(train_loader, model_sr, model_transformer, model, optimizer_sr, optimizer_transformer, criterion_sr, criterion, epoch, run=run)
351 | log_tmp += [loss_sr, loss_cls, top1, top5, psnr, ssim]
352 |
353 |
354 | # evaluate on validation set
355 | loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
356 | log_tmp += [loss_sr, loss_cls, top1, top5, psnr, ssim]
357 |
358 | log.append(log_tmp)
359 | np.savetxt(os.path.join(args.save_dir, 'log.txt'), log)
360 |
361 | model_sr_out_path = os.path.join(args.save_dir, "model_sr_epoch_{}.pth".format(epoch))
362 | torch.save(model_sr, model_sr_out_path)
363 | print("Checkpoint saved to {}".format(model_sr_out_path))
364 |
365 | if args.mode == 'ra_transform':
366 | model_transformer_out_path = os.path.join(args.save_dir, "model_transformer_epoch_{}.pth".format(epoch))
367 | torch.save(model_transformer, model_transformer_out_path)
368 | print("Checkpoint saved to {}".format(model_transformer_out_path))
369 |
370 |
371 | args.model_sr = model_sr_out_path
372 | vis(val_loader, model_sr, model_transformer, model, criterion_sr, criterion, run)
373 |
374 | # checkpoints may have been saved with an older PyTorch version, so load weights via state_dict
375 | def load_model(model_path, model_sr):
376 | load_dict = torch.load(model_path).state_dict()
377 | model_dict = model_sr.state_dict()
378 | model_dict.update(load_dict)
379 | model_sr.load_state_dict(model_dict)
380 |
381 | return model_sr
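   | # e.g. model_sr = load_model('checkpoint/default/SRResNet_0.01_resnet18/model_sr_epoch_5.pth', model_sr)
   | # (illustrative path; checkpoints are whole serialized modules, so only their
   | # state_dict is transplanted into a freshly built model)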
382 |
383 | def ra(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model,
384 | optimizer_sr, optimizer_transformer, criterion_sr, criterion, train=True):
385 | if train:
386 | optimizer_sr.zero_grad()
387 |
388 |
389 |     output_sr = model_sr(input_sr_var)
390 |
391 | loss_sr = criterion_sr(output_sr, target_sr_var)
392 |
393 | loss_cls = 0
394 |
395 | input_cls = process_to_input_cls(output_sr)
396 | output_cls = model(input_cls)
397 | loss_cls = criterion(output_cls, target_cls_var)
398 |
399 |
400 | # compute ssim for every image
401 | ssim = 0
402 |     # not computed during training, to save time
403 | if not train:
404 | for i in range(output_sr.size(0)):
405 | sr_image = output_sr[i].unsqueeze(0)
406 | hr_image = target_sr_var[i].unsqueeze(0)
407 | ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
408 | ssim = ssim / output_sr.size(0)
409 |
410 | loss = loss_sr + args.l * loss_cls
411 |
412 | if train:
413 | loss.backward()
414 | optimizer_sr.step()
415 |
416 | return loss_sr, loss_cls, output_cls, ssim
417 |
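   | # ra trains the processing model end-to-end on the joint objective
   | #     loss = loss_sr + lambda * loss_cls    (lambda is the --l flag),
   | # so a single optimizer_sr step trades off reconstruction error against how well
   | # the frozen recognition model classifies the processed image.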
418 |
419 | def ra_unsupervised(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model,
420 | optimizer_sr, optimizer_transformer, criterion_sr, criterion, train=True):
421 | if train:
422 | optimizer_sr.zero_grad()
423 |
424 | output_sr = model_sr(input_sr_var)
425 | loss_sr = criterion_sr(output_sr, target_sr_var)
426 |
427 | loss_cls = 0
428 |
429 | input_cls = process_to_input_cls(output_sr)
430 | output_cls = model(input_cls)
431 |
432 | output_cls_soft_target_v = model(process_to_input_cls(target_sr_var))
433 |
434 |
435 |     output_cls_soft_target = Variable(torch.zeros(output_cls_soft_target_v.size())).cuda()
436 |     output_cls_soft_target.data.copy_(output_cls_soft_target_v.data)  # copy via .data detaches the soft targets
437 |     loss_cls = criterion_sr(nn.Softmax(dim=1)(output_cls), nn.Softmax(dim=1)(output_cls_soft_target))
438 |
439 |
440 |
441 |
442 | # compute ssim for every image
443 | ssim = 0
444 |     # not computed during training, to save time
445 | if not train:
446 | for i in range(output_sr.size(0)):
447 | sr_image = output_sr[i].unsqueeze(0)
448 | hr_image = target_sr_var[i].unsqueeze(0)
449 | ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
450 | ssim = ssim / output_sr.size(0)
451 |
452 | loss = loss_sr + args.l * loss_cls
453 |
454 | if train:
455 | loss.backward()
456 | optimizer_sr.step()
457 |
458 | return loss_sr, loss_cls, output_cls, ssim
459 |
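   | # ra_unsupervised needs no ground-truth labels: the recognition model's softmax on
   | # the clean target image serves as a soft target, and the MSE criterion
   | # (criterion_sr) pulls the processed image's softmax toward it.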
460 | # in ra_transform, the sr model optimizes only the sr loss; a separate transformer optimizes the recognition loss
461 | def ra_transform(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model,
462 | optimizer_sr, optimizer_transformer, criterion_sr, criterion, train=True):
463 | if train:
464 | optimizer_sr.zero_grad()
465 | optimizer_transformer.zero_grad()
466 |
467 | output_sr = model_sr(input_sr_var)
468 | loss_sr = criterion_sr(output_sr, target_sr_var)
469 |
470 | if train:
471 | loss_sr.backward()
472 | optimizer_sr.step()
473 |
474 | loss_cls = 0
475 |
476 |     output_sr.detach_()  # block gradients so the recognition loss cannot reach model_sr
477 | input_cls = process_to_input_cls(model_transformer(output_sr))
478 |
479 | output_cls = model(input_cls)
480 | loss_cls = criterion(output_cls, target_cls_var)
481 | # compute ssim for every image
482 | ssim = 0
483 |     # not computed during training, to save time
484 | if not train:
485 | for i in range(output_sr.size(0)):
486 | sr_image = output_sr[i].unsqueeze(0)
487 | hr_image = target_sr_var[i].unsqueeze(0)
488 | ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
489 | ssim = ssim / output_sr.size(0)
490 |
491 | loss = args.l * loss_cls
492 |
493 | if train:
494 | loss.backward()
495 | optimizer_transformer.step()
496 |
497 | return loss_sr, loss_cls, output_cls, ssim
498 |
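   | # ra_transform decouples the two objectives: optimizer_sr fits model_sr to the MSE
   | # loss alone, while optimizer_transformer trains the transformer on the recognition
   | # loss over detached SR outputs, leaving the SR model's image quality untouched.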
499 |
500 |
501 | def train(train_loader, model_sr, model_transformer, model, optimizer_sr, optimizer_transformer, criterion_sr, criterion, epoch, run):
502 |
503 | torch.manual_seed(epoch)
504 | batch_time = AverageMeter()
505 | data_time = AverageMeter()
506 | run_time = AverageMeter()
507 | process_time = AverageMeter()
508 | losses = AverageMeter()
509 | losses_sr = AverageMeter()
510 | losses_cls = AverageMeter()
511 | top1 = AverageMeter()
512 | top5 = AverageMeter()
513 | psnr_avg = AverageMeter()
514 | ssim_avg = AverageMeter()
515 |
516 | model_sr.train()
517 |     # model_transformer is always constructed in main(), so no None check is needed here
518 | model_transformer.train()
519 |
520 |     # keep the recognition model(s) frozen in eval mode
521 |     if isinstance(model, list):
522 |         for m in model:
523 |             m.eval()
524 |     else:
525 |         model.eval()
526 |
527 |
528 | end = time.time()
529 |
530 | for i, (img_input, img_output, target) in enumerate(train_loader):
531 | data_time.update(time.time() - end)
532 |
533 | input_sr_var = Variable(img_input.cuda())
534 | target_sr_var = Variable(img_output).cuda()
535 | target_cls_var = Variable(target).cuda()
536 | target = target.cuda()
537 |
538 | start_run = time.time()
539 |
540 | loss_sr, loss_cls, output_cls, ssim = run(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model, optimizer_sr,
541 | optimizer_transformer, criterion_sr, criterion, train=True)
542 |
543 | run_time.update(time.time() - start_run)
544 |
545 | process_start = time.time()
546 |         psnr = 10 * log10(1 / (loss_sr.item() + 1e-9))  # small epsilon guards log10 when the MSE is 0
547 | process_time.update(time.time() - process_start)
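   | # With images in [0, 1], PSNR = 10 * log10(MAX^2 / MSE) and MAX = 1, so the MSE
   | # loss converts directly: an MSE of 0.001 corresponds to 10 * log10(1000) = 30 dB.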
548 |
549 | prec1, prec5 = accuracy(output_cls.data, target, topk=(1, 5))
550 |
551 | top1.update(prec1[0], img_input.size(0))
552 | top5.update(prec5[0], img_input.size(0))
553 | losses_sr.update(loss_sr.item(), img_input.size(0))
554 | losses_cls.update(loss_cls.item(), img_input.size(0))
555 | psnr_avg.update(psnr, img_input.size(0))
556 | ssim_avg.update(ssim, img_input.size(0))
557 |
558 | batch_time.update(time.time() - end)
559 | end = time.time()
560 |
561 | if i % args.print_freq == 0:
562 | print('Epoch: [{0}][{1}/{2}]\t'
563 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
564 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
565 | 'Process {process_time.val:.3f} ({process_time.avg:.3f})\t'
566 | 'Run {run_time.val:.3f} ({run_time.avg:.3f})\t'
567 |                   # 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
568 |                   'Loss_sr {loss_sr.val:.4f} ({loss_sr.avg:.3f})\t'
569 |                   'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.3f})\t'
570 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
571 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
572 |                   epoch, i, len(train_loader), batch_time=batch_time,
573 |                   data_time=data_time, process_time=process_time, run_time=run_time, loss=losses, loss_sr=losses_sr, loss_cls=losses_cls, top1=top1, top5=top5))
574 | return losses_sr.avg, losses_cls.avg, top1.avg, top5.avg, psnr_avg.avg, ssim_avg.avg
575 |
576 |
577 | def validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run):
578 |
579 | torch.manual_seed(1)
580 |
581 | batch_time = AverageMeter()
582 | data_time = AverageMeter()
583 | losses = AverageMeter()
584 | losses_sr = AverageMeter()
585 | losses_cls = AverageMeter()
586 | top1 = AverageMeter()
587 | top5 = AverageMeter()
588 | psnr_avg = AverageMeter()
589 | ssim_avg = AverageMeter()
590 |
591 | model_sr.eval()
592 | if model_transformer is not None:
593 | model_transformer.eval()
594 |
595 |     if isinstance(model, list):
596 |         for m in model:
597 |             m.eval()
598 |     else:
599 |         model.eval()
600 |
601 |
602 | end = time.time()
603 |
604 | for i, (img_input, img_output, target) in enumerate(val_loader):
605 |         target = target.cuda(non_blocking=True)  # `async` is a reserved word in Python 3.7+
606 |         input_sr_var = Variable(img_input, volatile=True).cuda()
607 |         target_sr_var = Variable(img_output, volatile=True).cuda()
608 |         target_cls_var = Variable(target, volatile=True).cuda()
609 |
610 |
611 | loss_sr, loss_cls, output_cls, ssim = run(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model, optimizer_sr=optimizer_sr,
612 | optimizer_transformer=None, criterion_sr=criterion_sr, criterion=criterion, train=False)
613 |
614 | psnr = 10 * log10(1 / (loss_sr.item() + 1e-9))
615 |
616 | prec1, prec5 = accuracy(output_cls.data, target, topk=(1, 5))
617 |
618 | top1.update(prec1[0], img_input.size(0))
619 | top5.update(prec5[0], img_input.size(0))
620 | losses_sr.update(loss_sr.item(), img_input.size(0))
621 | losses_cls.update(loss_cls.item(), img_input.size(0))
622 | psnr_avg.update(psnr, img_input.size(0))
623 | ssim_avg.update(ssim, img_input.size(0))
624 |
625 | batch_time.update(time.time() - end)
626 | end = time.time()
627 |
628 | if i % args.print_freq == 0:
629 | print('Test: [{0}/{1}]\t'
630 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
631 | 'Loss_sr {loss_sr.val:.4f} ({loss_sr.avg:.4f})\t'
632 |                   'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.3f})\t'
633 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
634 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
635 | 'PSNR {psnr.val:.3f} ({psnr.avg:.3f})'.format(
636 | i, len(val_loader), batch_time=batch_time, loss_sr=losses_sr, loss_cls=losses_cls,
637 | top1=top1, top5=top5, psnr=psnr_avg))
638 | return losses_sr.avg, losses_cls.avg, top1.avg, top5.avg, psnr_avg.avg, ssim_avg.avg
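   | # The six returned averages are the same quantities main() appends to log.txt
   | # (sr loss, cls loss, top-1, top-5, PSNR, SSIM), logged for train and val side by side.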
639 |
640 |
641 | # evaluate "no processing"
642 | def validate_notransform(val_loader, model, criterion_sr, criterion):
643 |
644 | batch_time = AverageMeter()
645 | data_time = AverageMeter()
646 | losses = AverageMeter()
647 | losses_sr = AverageMeter()
648 | losses_cls = AverageMeter()
649 | top1 = AverageMeter()
650 | top5 = AverageMeter()
651 | psnr_avg = AverageMeter()
652 | ssim_avg = AverageMeter()
653 |
654 |     if isinstance(model, list):
655 |         for m in model:
656 |             m.eval()
657 |     else:
658 |         model.eval()
659 |
660 |
661 | end = time.time()
662 |
663 |     for i, (img_input, img_output, target) in enumerate(val_loader):
664 |         target = target.cuda(non_blocking=True)  # `async` is a reserved word in Python 3.7+
665 |         target_sr_var = Variable(img_output).cuda()
666 |         target_cls_var = Variable(target).cuda()
667 |
668 |         # output of bicubic interpolation (tensor in [0, 1])
669 |         if args.task == 'sr':
670 |             output_sr = Variable(trans_RGB_bicubic(img_output), volatile=True).cuda()
671 |         else:
672 |             output_sr = Variable(img_input, volatile=True).cuda()
673 |
674 |         # the remainder mirrors the ra function
675 |         loss_sr = criterion_sr(output_sr, target_sr_var)
676 |
677 |         input_cls = process_to_input_cls(output_sr)
678 |         output_cls = model(input_cls)
679 |         loss_cls = criterion(output_cls, target_cls_var)
680 |
681 |         ssim = 0
682 |         for j in range(output_sr.size(0)):
683 |             sr_image = output_sr[j].unsqueeze(0)
684 |             hr_image = target_sr_var[j].unsqueeze(0)
685 |             ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
686 |         ssim = ssim / output_sr.size(0)
687 |
688 |         psnr = 10 * log10(1 / (loss_sr.item() + 1e-9))  # small epsilon guards log10 when the MSE is 0
689 |
690 |         prec1, prec5 = accuracy(output_cls.data, target, topk=(1, 5))
691 |         top1.update(prec1[0], img_input.size(0))
692 |         top5.update(prec5[0], img_input.size(0))
693 |         losses_sr.update(loss_sr.item(), img_input.size(0))
694 |         losses_cls.update(loss_cls.item(), img_input.size(0))
695 |         psnr_avg.update(psnr, img_input.size(0))
696 |         ssim_avg.update(ssim, img_input.size(0))
697 |
698 |         batch_time.update(time.time() - end)
699 |         end = time.time()
700 |
701 |         if i % args.print_freq == 0:
702 |             print('Test: [{0}/{1}]\t'
703 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
704 |                   'Loss_sr {loss_sr.val:.4f} ({loss_sr.avg:.4f})\t'
705 |                   'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.3f})\t'
706 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
707 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
708 |                   'PSNR {psnr.val:.3f} ({psnr.avg:.3f})'.format(
709 |                   i, len(val_loader), batch_time=batch_time, loss_sr=losses_sr, loss_cls=losses_cls,
710 |                   top1=top1, top5=top5, psnr=psnr_avg))
711 |     return losses_sr.avg, losses_cls.avg, top1.avg, top5.avg, psnr_avg.avg, ssim_avg.avg
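   | # This baseline scores the unprocessed degradation directly: bicubic upsampling
   | # for sr, or the noisy/compressed input itself for dn/jpeg.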
716 |
717 | def vis(val_loader, model_sr, model_transformer, model, criterion_sr, criterion, run):
718 |
719 | torch.manual_seed(1)
720 | image_list = []
721 | image_list_input = []
722 | image_list_target = []
723 | for i, (img_input, img_output, target) in enumerate(val_loader):
724 | if i > 10:
725 | break
726 |
727 | input_sr_var = Variable(img_input, volatile=True).cuda()
728 | target_sr_var = Variable(img_output, volatile=True).cuda()
729 | target_cls_var = Variable(target, volatile=True).cuda()
730 | output_sr = model_sr(input_sr_var)
731 | im = image_from_RGB(output_sr[0])
732 | im_input = image_from_RGB(input_sr_var[0])
733 | im_target = image_from_RGB(target_sr_var[0])
734 |
735 | image_list.append(im)
736 | image_list_input.append(im_input)
737 | image_list_target.append(im_target)
738 |
739 | im_save = combine_image(image_list)
740 | im_save.save(args.model_sr + '_output.png')
741 | im_save_input = combine_image(image_list_input)
742 | im_save_input.save(args.model_sr + '_input.png')
743 | im_save_target = combine_image(image_list_target)
744 | im_save_target.save(args.model_sr + '_target.png')
745 |
746 | return im_save
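   | # vis saves three horizontal strips (processed outputs, inputs, targets) next to
   | # the checkpoint named by --model-sr, for quick visual inspection.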
747 |
748 |
749 | # utility functions
750 |
751 | # util functions for visualization
752 | def image_from_RGB(out):
753 |     # convert a [0, 1] CHW tensor (3-channel RGB or 1-channel) into a PIL image
754 | if out.size(0) == 3:
755 | out = out.permute(1,2,0).cpu()
756 | out_img_y = out.data.numpy()
757 | out_img_y *= 255.0
758 | out_img_y = out_img_y.clip(0, 255)
759 | out_img_y = Image.fromarray(np.uint8(out_img_y), mode='RGB')
760 | elif out.size(0) == 1:
761 | out = out.cpu()
762 | out_img_y = out.data.numpy()
763 | out_img_y *= 255.0
764 |
765 | out_img_y = out_img_y.clip(0, 255)
766 | out_img_y = out_img_y[0]
767 |
768 | out_img_y = Image.fromarray(np.uint8(out_img_y), mode='L')
769 |
770 |
771 |     return out_img_y
772 |
773 |
774 | def combine_image(images):
775 |     # expects a list of PIL images, e.g. [Image.open(p) for p in paths]
776 | widths, heights = zip(*(i.size for i in images))
777 |
778 | total_width = sum(widths)
779 | max_height = max(heights)
780 |
781 | new_im = Image.new('RGB', (total_width, max_height))
782 |
783 | x_offset = 0
784 | for im in images:
785 | new_im.paste(im, (x_offset,0))
786 | x_offset += im.size[0]
787 |
788 |
789 | return new_im
790 |
791 | # util functions with imagenet training, not in use
792 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
793 | torch.save(state, filename)
794 | if is_best:
795 | shutil.copyfile(filename, 'model_best.pth.tar')
796 |
797 |
798 | class AverageMeter(object):
799 | """Computes and stores the average and current value"""
800 | def __init__(self):
801 | self.reset()
802 |
803 | def reset(self):
804 | self.val = 0
805 | self.avg = 0
806 | self.sum = 0
807 | self.count = 0
808 |
809 | def update(self, val, n=1):
810 | self.val = val
811 | self.sum += val * n
812 | self.count += n
813 | self.avg = self.sum / self.count
814 |
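   | # e.g. top1 = AverageMeter(); top1.update(87.5, n=img_input.size(0)) keeps a
   | # batch-size-weighted running mean in top1.avg, which is what gets logged.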
815 |
816 | def adjust_learning_rate(optimizer, epoch):
817 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
818 |     if 'small' in args.data:
819 |         if epoch < 20:
820 |             lr = args.lr
821 |         elif epoch < 25:
822 |             lr = args.lr * 0.1
823 |         else:
824 |             lr = args.lr * 0.01
825 |     else:
826 |         if epoch < 4:
827 |             lr = args.lr
828 |         elif epoch == 4:
829 |             lr = args.lr * 0.1
830 |         else:
831 |             lr = args.lr * 0.01
832 |
833 | for param_group in optimizer.param_groups:
834 | param_group['lr'] = lr
835 |
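   | # Resulting schedule: on full ImageNet (6 epochs), epochs 0-3 run at --lr, epoch 4
   | # at 0.1x, epoch 5 at 0.01x; the 30-epoch 'small' runs decay at epochs 20 and 25.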
836 |
837 | def accuracy(output, target, topk=(1,)):
838 | """Computes the precision@k for the specified values of k"""
839 | maxk = max(topk)
840 | batch_size = target.size(0)
841 |
842 | _, pred = output.topk(maxk, 1, True, True)
843 | pred = pred.t()
844 | correct = pred.eq(target.view(1, -1).expand_as(pred))
845 |
846 | res = []
847 | for k in topk:
848 |         correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
849 | res.append(correct_k.mul_(100.0 / batch_size))
850 | return res
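   | # e.g. with a batch of 20 and 15 correct top-1 predictions, accuracy(output, target,
   | # topk=(1, 5)) returns [tensor([75.]), tensor([...])], i.e. percentages per k.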
851 |
852 |
853 | if __name__ == '__main__':
854 | main()
855 |
--------------------------------------------------------------------------------