├── LICENSE
├── README.md
├── data_utils.py
├── figs
│   ├── fig1.png
│   └── fig2.png
├── model_sr.py
├── networks.py
├── pytorch_ssim
│   └── __init__.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Zhuang Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Exploring Simple and Transferable Recognition-Aware Image Processing

This repo contains the code and instructions to reproduce the results in

[Exploring Simple and Transferable Recognition-Aware Image Processing](https://arxiv.org/abs/1910.09185). IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI).

Zhuang Liu, Hung-Ju Wang, Tinghui Zhou, Zhiqiang Shen, Bingyi Kang, Evan Shelhamer, Trevor Darrell.

<img src="figs/fig1.png" alt="drawing"/>

Fig. 1: Image processing aims for images that look visually pleasing to humans, but not necessarily for images that are accurately recognized by machines. In this work we aim to enhance the recognition accuracy of the output images.

### Abstract
Recent progress in image recognition has stimulated the deployment of vision systems at an unprecedented scale. As a result, visual data are now often consumed not only by humans but also by machines. Existing image processing methods only optimize for better human perception, yet the resulting images may not be accurately recognized by machines. This can be undesirable, e.g., the images can be improperly handled by search engines or recommendation systems. In this work, we propose simple approaches to improve the machine interpretability of processed images: optimizing the recognition loss directly on the image processing network or through an intermediate transforming model. Interestingly, the processing model's ability to enhance recognition quality can transfer when evaluated on models of different architectures, recognized categories, tasks and training datasets. This makes the solutions applicable even when we do not have knowledge of future recognition models, e.g., if we upload processed images to the Internet. We conduct experiments on multiple image processing tasks, with ImageNet classification and PASCAL VOC detection as recognition tasks.
With our simple methods, substantial accuracy gains can be achieved with strong transferability and minimal image quality loss. Through a user study we further show that the accuracy gain can transfer to a black-box, third-party cloud model. Finally, we try to explain this transferability phenomenon by demonstrating the similarities of different models' decision boundaries.

<img src="figs/fig2.png" alt="drawing"/>

Fig. 2: Left: RA (Recognition-Aware) processing. In addition to the image processing loss, we add a recognition loss using a fixed recognition model R, for the processing model P to optimize. Right: RA with transformer. "Recognition Loss" stands for the dashed box in the left figure. A transformer T is introduced between the output of P and the input of R to optimize the recognition loss. We cut the gradient flowing from the recognition loss to P, so that P only optimizes the image processing loss and the image quality is not affected.

### Dependencies
PyTorch 1.5.0 and the corresponding version of torchvision (0.6.0). The code should also run on other recent versions of PyTorch (0.4.0+).

Please install them following the official instructions at [PyTorch](https://pytorch.org/).

### Data Preparation
Download and uncompress the ImageNet classification dataset from http://image-net.org/download to `PATH_TO_IMAGENET`, which should contain the subfolders `train/` and `val/`.

### Training
The examples below are for the super-resolution task; change `--task` to `dn` or `jpeg` for denoising or JPEG-deblocking.
The processing model P is an SRResNet and the recognition model R is a ResNet-18; see the options in train.py.
Models, logs and some visualizations will be written to the output folder (`--save-dir`).

Plain Processing

    CUDA_VISIBLE_DEVICES=0 python train.py --l 0 --save-dir checkpoints_sr/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra --data PATH_TO_IMAGENET

RA Processing

    CUDA_VISIBLE_DEVICES=0 python train.py --l 0.001 --save-dir checkpoints_sr/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra --data PATH_TO_IMAGENET

RA with Transformer

    CUDA_VISIBLE_DEVICES=0 python train.py --l 0.01 --save-dir checkpoints_sr_T/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra_transformer --data PATH_TO_IMAGENET

Unsupervised RA

    CUDA_VISIBLE_DEVICES=0 python train.py --l 10 --save-dir checkpoints_sr_U/ --task sr --sr-arch SRResNet --arch resnet18 --mode ra_unsupervised --data PATH_TO_IMAGENET

### Evaluation
After training, you can test the resulting image processing models on multiple R architectures, i.e., evaluate how well the accuracy gain transfers to different architectures.

Plain Processing, RA Processing or Unsupervised RA

    CUDA_VISIBLE_DEVICES=0 python train.py --cross-evaluate --model-sr PATH_TO_MODEL --task sr --mode ra --data PATH_TO_IMAGENET

RA with Transformer

    CUDA_VISIBLE_DEVICES=0 python train.py --cross-evaluate --model-sr PATH_TO_SR_MODEL --model-transformer PATH_TO_TRANSFORMER_MODEL --task sr --mode ra_transform --data PATH_TO_IMAGENET

After evaluation finishes, results will be saved in the same folder as `PATH_TO_MODEL`.
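For orientation, the two training modes in Fig. 2 boil down to the following loss computation. This is a minimal, illustrative sketch, not the literal code in train.py: `P`, `R`, `T`, `normalize`, `mse`, `ce` and `lam` (the `--l` coefficient) are placeholder names.

    # RA processing: P minimizes the processing loss plus lambda * recognition loss
    out = P(x_degraded)
    loss_p = mse(out, x_clean) + lam * ce(R(normalize(out)), label)

    # RA with transformer: the gradient of the recognition loss is cut before P,
    # so P keeps optimizing image quality only, while T absorbs the recognition loss
    out = P(x_degraded)
    loss_p = mse(out, x_clean)
    loss_t = ce(R(normalize(T(out.detach()))), label)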
### Pretrained Models
We provide pretrained models of Plain Processing, RA Processing and Unsupervised RA at the links below, for all three tasks.
The recognition model R used in the loss is ResNet-18.

| Task | Models |
| ------------- | ----------- |
| Super-resolution | [Google Drive](https://drive.google.com/drive/folders/1U6AGvTyl7BewnwPDxzxSyd6cfxWJ1Tkd?usp=sharing) |
| Denoising | [Google Drive](https://drive.google.com/drive/folders/1LyGyMtpqDI2ExVCzL_4inC6X_lndnEvl?usp=sharing) |
| JPEG-deblocking | [Google Drive](https://drive.google.com/drive/folders/1E4TDXwFUtJbRx8fNgVkUOhCzL4011CX2?usp=sharing) |

These models can be evaluated following the commands above.

### Results

The provided pretrained models should produce the results shown in the following tables (ImageNet top-1 accuracy in %, the same as the corresponding results in the paper).

Note that the R model used to train all P models here is ResNet-18, hence the table is different from Table 1 in the paper, but it covers the results of Tables 1, 2 and 10.

#### Super-resolution

P Model/Evaluation on R | ResNet-18 | ResNet-50 | ResNet-101 | DenseNet-121 | VGG-16
-------|:-------:|:--------:|:--------:|:--------:|:--------:|
Plain Processing | 52.6 | 58.8 | 61.9 | 57.7 | 50.2
RA Processing | 61.8 | 66.7 | 68.8 | 64.7 | 58.2
Unsupervised RA | 61.3 | 66.3 | 68.6 | 64.5 | 57.3

#### Denoising
P Model/Evaluation on R | ResNet-18 | ResNet-50 | ResNet-101 | DenseNet-121 | VGG-16
-------|:-------:|:--------:|:--------:|:--------:|:--------:|
Plain Processing | 61.9 | 68.0 | 69.1 | 66.4 | 60.9
RA Processing | 65.1 | 70.6 | 71.9 | 69.1 | 63.8
Unsupervised RA | 61.7 | 67.9 | 69.7 | 66.4 | 60.5

#### JPEG-deblocking
P Model/Evaluation on R | ResNet-18 | ResNet-50 | ResNet-101 | DenseNet-121 | VGG-16
-------|:-------:|:--------:|:--------:|:--------:|:--------:|
Plain Processing | 48.2 | 53.8 | 56.0 | 52.9 | 42.4
RA Processing | 57.7 | 62.3 | 64.3 | 60.7 | 52.8
Unsupervised RA | 53.8 | 59.1 | 62.0 | 57.5 | 50.0

Models trained with this code should also produce similar results.
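For reference, each number above is the top-1 accuracy of a recognition model R on the processed ImageNet validation images, along the lines of the sketch below. This is illustrative only: `P`, `normalize` and `val_loader` are placeholders, and the actual procedure is the `--cross-evaluate` path in train.py.

    import torch
    import torchvision.models as models

    R = models.resnet50(pretrained=True).cuda().eval()      # one column of the tables
    correct = total = 0
    with torch.no_grad():
        for x_degraded, x_clean, label in val_loader:        # triples from data_utils.py
            out = P(x_degraded.cuda())                       # processed image in [0, 1]
            pred = R(normalize(out)).argmax(dim=1)           # ImageNet mean/std normalization
            correct += (pred == label.cuda()).sum().item()
            total += label.size(0)
    print('top-1: {:.1f}%'.format(100.0 * correct / total))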
### Contact
You are welcome to open issues or to contact liuzhuangthu@gmail.com.
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.utils.data
from torchvision import datasets, transforms
from PIL import Image


class SRImageFolder(datasets.ImageFolder):
    """Yields (low-res input, high-res target, class label) triples for super-resolution."""

    def __init__(self, traindir, train_transform):
        super(SRImageFolder, self).__init__(traindir, train_transform)
        self.upscale = 4

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img_output_PIL = self.transform(img)

        # bicubic-downsample the transformed crop to create the low-res input
        lr_size = img_output_PIL.size[0] // self.upscale
        img_input_PIL = transforms.Resize((lr_size, lr_size), Image.BICUBIC)(img_output_PIL)

        img_output = transforms.ToTensor()(img_output_PIL)
        img_input = transforms.ToTensor()(img_input_PIL)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img_input, img_output, target


class DNImageFolder(datasets.ImageFolder):
    """Yields (noisy input, clean target, class label) triples for denoising."""

    def __init__(self, traindir, train_transform, deterministic=False):
        super(DNImageFolder, self).__init__(traindir, train_transform)
        self.std = 0.1                       # Gaussian noise level
        self.deterministic = deterministic
        print("constructing DN Image folder")

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img_output_PIL = self.transform(img)
            img_output = transforms.ToTensor()(img_output_PIL)

        if self.deterministic:
            # fix the noise per image so that evaluation is repeatable
            torch.manual_seed(index)
        noise = torch.randn(img_output.size()) * self.std
        img_input = torch.clamp(img_output + noise, 0, 1)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img_input, img_output, target


class JPEGImageFolder(datasets.ImageFolder):
    """Yields (JPEG-compressed input, clean target, class label) triples for JPEG-deblocking."""

    def __init__(self, traindir, train_transform, tmp_dir):
        super(JPEGImageFolder, self).__init__(traindir, train_transform)
        self.quality = 10                    # JPEG quality used to create the degraded input
        self.tmp_dir = tmp_dir
        os.makedirs(tmp_dir, exist_ok=True)

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img_output_PIL = self.transform(img)

        # round-trip through an on-disk JPEG to obtain the compressed input
        img_output_PIL.save(self.tmp_dir + '{}.jpeg'.format(index), quality=self.quality)
        img_input_PIL = Image.open(self.tmp_dir + '{}.jpeg'.format(index))
        os.remove(self.tmp_dir + "{}.jpeg".format(index))

        img_output = transforms.ToTensor()(img_output_PIL)
        img_input = transforms.ToTensor()(img_input_PIL)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img_input, img_output, target
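# NOTE (summary): the three folders above share one contract -- __getitem__
# returns a (degraded_input, clean_target, class_label) triple of tensors in
# [0, 1], so the training loop can combine an image processing loss (input vs.
# target) with a recognition loss on the label. SelfImageFolder below is the
# identity variant: input == target, for the case of no degradation.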
class SelfImageFolder(datasets.ImageFolder):

    def __init__(self, traindir, train_transform):
        super(SelfImageFolder, self).__init__(traindir, train_transform)

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img_output_PIL = self.transform(img)

        img_output = transforms.ToTensor()(img_output_PIL)
        img_input = img_output + 0.          # copy, so input and target are distinct tensors

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img_input, img_output, target


if __name__ == '__main__':
    traindir = '/scratch/zhuangl/datasets/imagenet/train'
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
    ])
    # JPEGImageFolder needs a scratch directory for its temporary files
    # (any writable path works; './tmp_jpeg/' is just an example)
    train_dataset = JPEGImageFolder(traindir, train_transform, './tmp_jpeg/')
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=5, shuffle=True,
        num_workers=1, pin_memory=True, sampler=None)

    for i, (img_input, img_output, target) in enumerate(train_loader):
        print(i)
--------------------------------------------------------------------------------
/figs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuzhuang13/Transferable_RA/fecdc63292abed59fffd9237c7ea4d7a4db0aef4/figs/fig1.png
--------------------------------------------------------------------------------
/figs/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuzhuang13/Transferable_RA/fecdc63292abed59fffd9237c7ea4d7a4db0aef4/figs/fig2.png
--------------------------------------------------------------------------------
/model_sr.py:
--------------------------------------------------------------------------------
### code from https://github.com/leftthomas/SRGAN/blob/master/model.py

import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


class SimpleNet(nn.Module):
    """Sub-pixel SR net: four convs followed by PixelShuffle upsampling.
    Note: conv1 expects a single-channel (e.g. Y of YCbCr) input; the
    `channel` argument is currently unused."""

    def __init__(self, upscale_factor, channel):
        super(SimpleNet, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        # in-place initializers (init.orthogonal is deprecated)
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
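# Shape sketch (illustrative): the SRResNet below ("ResNet") maps an RGB crop to
# upscale_factor times the resolution, e.g.
#   p = ResNet(upscale_factor=4, channel=3)
#   p(torch.randn(1, 3, 56, 56)).shape  ->  torch.Size([1, 3, 224, 224])
# With upscale_factor=1 (denoising / JPEG-deblocking) no upsampling block is added.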
class ResNet(nn.Module):
    def __init__(self, upscale_factor, channel, residual=False):
        upsample_block_num = int(math.log(upscale_factor, 2))

        super(ResNet, self).__init__()

        self.residual = residual

        c = channel
        self.block1 = nn.Sequential(
            nn.Conv2d(c, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        # (tanh + 1) / 2 maps the output to [0, 1]
        # (torch.tanh replaces the deprecated F.tanh)
        if self.residual:
            return torch.clamp(x - (torch.tanh(block8) + 1) / 2, 0, 1)
        else:
            return (torch.tanh(block8) + 1) / 2


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual


class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        # sub-pixel convolution: expand channels by up_scale^2, then PixelShuffle
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        # torch.sigmoid replaces the deprecated F.sigmoid
        return torch.sigmoid(self.net(x).view(batch_size))


# ---------- DenseNet-based SR model ----------
def xavier(param):
    init.xavier_uniform_(param)

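# Channel-growth sketch (illustrative): every SingleLayer below concatenates its
# conv output onto its input, so channels grow by growthRate per layer. With
# SRDenseNet(16, 16, 8, 8), the trunk enters the bottleneck with
# 16 + 16 * 8 * 8 = 1040 channels, which the 1x1 Bottleneck reduces to 128.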
class SingleLayer(nn.Module):
    def __init__(self, inChannels, growthRate):
        super(SingleLayer, self).__init__()
        self.conv = nn.Conv2d(inChannels, growthRate, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = torch.cat((x, out), 1)
        return out


class SingleBlock(nn.Module):
    def __init__(self, inChannels, growthRate, nDenselayer):
        super(SingleBlock, self).__init__()
        self.block = self._make_dense(inChannels, growthRate, nDenselayer)

    def _make_dense(self, inChannels, growthRate, nDenselayer):
        layers = []
        for i in range(int(nDenselayer)):
            layers.append(SingleLayer(inChannels, growthRate))
            inChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class SRDenseNet(nn.Module):
    def __init__(self, inChannels, growthRate, nDenselayer, nBlock):
        super(SRDenseNet, self).__init__()

        self.conv1 = nn.Conv2d(3, growthRate, kernel_size=3, padding=1, bias=True)

        inChannels = growthRate

        self.denseblock = self._make_block(inChannels, growthRate, nDenselayer, nBlock)
        inChannels += growthRate * nDenselayer * nBlock

        self.Bottleneck = nn.Conv2d(in_channels=inChannels, out_channels=128, kernel_size=1, padding=0, bias=True)

        # two stride-2 transposed convolutions give a fixed 4x upscale
        self.convt1 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True)
        self.convt2 = nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True)

        self.conv2 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, padding=1, bias=True)

    def _make_block(self, inChannels, growthRate, nDenselayer, nBlock):
        blocks = []
        for i in range(int(nBlock)):
            blocks.append(SingleBlock(inChannels, growthRate, nDenselayer))
            inChannels += growthRate * nDenselayer
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.denseblock(out)
        out = self.Bottleneck(out)
        out = self.convt1(out)
        out = self.convt2(out)

        HR = self.conv2(out)
        return HR


if __name__ == '__main__':
    # smoke test: SRDenseNet upscales by a fixed factor of 4
    net = SRDenseNet(16, 16, 8, 8)
    a = torch.randn(3, 3, 10, 10)    # conv1 expects a 3-channel input
    out = net(a)
    print(out.shape)                 # torch.Size([3, 3, 40, 40])
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler


###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters and do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of the learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For the other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # raise, not return, the error
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler

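# Worked example for lambda_rule (illustrative numbers): with opt.epoch_count=1,
# opt.niter=100 and opt.niter_decay=100, it returns 1.0 for epochs 0..99, then
# decays linearly: epoch 100 -> 100/101, ..., epoch 199 -> 1/101 (~0.01) of the
# initial learning rate.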
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm's weight is not a matrix; only the normal distribution applies
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register the CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- whether to use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator

    Our current implementation provides two types of generators:
        U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
        A Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

    The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator

    Our current implementation provides three types of discriminators:
        [basic]: the 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70×70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter <n_layers_D> (default=3, as used in [basic] (PatchGAN)).

        [pixel]: a 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':  # classify whether each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


##############################################################################
# Classes
##############################################################################
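# Usage sketch for the GANLoss class defined below (illustrative, pix2pix-style;
# netD, real_images and fake_images are placeholder names):
#   criterionGAN = GANLoss('lsgan').to(device)
#   loss_D_real = criterionGAN(netD(real_images), True)
#   loss_D_fake = criterionGAN(netD(fake_images.detach()), False)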
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: Do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with the ground truth label, with the size of the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate the loss given the Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in the WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)     -- discriminator network
        real_data (tensor) -- real images
        fake_data (tensor) -- generated images from the generator
        device (str)       -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)         -- whether we mix real and fake data or not [real | fake | mixed].
        constant (float)   -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)  -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':  # use either real images, fake images, or a linear interpolation of the two
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            alpha = alpha.to(device)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten per sample
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps for stability
        return gradient_penalty, gradients
    else:
        return 0.0, None


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)     -- the number of channels in input images
            output_nc (int)    -- the number of channels in output images
            ngf (int)          -- the number of filters in the last conv layer
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to use dropout layers
            n_blocks (int)     -- the number of ResNet blocks
            padding_type (str) -- the name of the padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        # map the Tanh output from [-1, 1] to [0, 1]
        return (self.model(input) + 1) / 2

class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement the skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)          -- the number of channels in the conv layer.
            padding_type (str) -- the name of the padding layer: reflect | replicate | zero
            norm_layer         -- normalization layer
            use_dropout (bool) -- whether to use dropout layers.
            use_bias (bool)    -- whether the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in the UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct the unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for i in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with a skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int)                      -- the number of filters in the outer conv layer
            inner_nc (int)                      -- the number of filters in the inner conv layer
            input_nc (int)                      -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)                    -- whether this module is the outermost module
            innermost (bool)                    -- whether this module is the innermost module
            norm_layer                          -- normalization layer
            use_dropout (bool)                  -- whether to use dropout layers.
482 | """ 483 | super(UnetSkipConnectionBlock, self).__init__() 484 | self.outermost = outermost 485 | if type(norm_layer) == functools.partial: 486 | use_bias = norm_layer.func == nn.InstanceNorm2d 487 | else: 488 | use_bias = norm_layer == nn.InstanceNorm2d 489 | if input_nc is None: 490 | input_nc = outer_nc 491 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 492 | stride=2, padding=1, bias=use_bias) 493 | downrelu = nn.LeakyReLU(0.2, True) 494 | downnorm = norm_layer(inner_nc) 495 | uprelu = nn.ReLU(True) 496 | upnorm = norm_layer(outer_nc) 497 | 498 | if outermost: 499 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 500 | kernel_size=4, stride=2, 501 | padding=1) 502 | down = [downconv] 503 | up = [uprelu, upconv, nn.Tanh()] 504 | model = down + [submodule] + up 505 | elif innermost: 506 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 507 | kernel_size=4, stride=2, 508 | padding=1, bias=use_bias) 509 | down = [downrelu, downconv] 510 | up = [uprelu, upconv, upnorm] 511 | model = down + up 512 | else: 513 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 514 | kernel_size=4, stride=2, 515 | padding=1, bias=use_bias) 516 | down = [downrelu, downconv, downnorm] 517 | up = [uprelu, upconv, upnorm] 518 | 519 | if use_dropout: 520 | model = down + [submodule] + up + [nn.Dropout(0.5)] 521 | else: 522 | model = down + [submodule] + up 523 | 524 | self.model = nn.Sequential(*model) 525 | 526 | def forward(self, x): 527 | if self.outermost: 528 | return self.model(x) 529 | else: # add skip connections 530 | return torch.cat([x, self.model(x)], 1) 531 | 532 | 533 | class NLayerDiscriminator(nn.Module): 534 | """Defines a PatchGAN discriminator""" 535 | 536 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 537 | """Construct a PatchGAN discriminator 538 | 539 | Parameters: 540 | input_nc (int) -- the number of channels in input images 541 | ndf (int) -- the number of filters in the last conv layer 542 | n_layers (int) -- the number of conv layers in the discriminator 543 | norm_layer -- normalization layer 544 | """ 545 | super(NLayerDiscriminator, self).__init__() 546 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 547 | use_bias = norm_layer.func != nn.BatchNorm2d 548 | else: 549 | use_bias = norm_layer != nn.BatchNorm2d 550 | 551 | kw = 4 552 | padw = 1 553 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 554 | nf_mult = 1 555 | nf_mult_prev = 1 556 | for n in range(1, n_layers): # gradually increase the number of filters 557 | nf_mult_prev = nf_mult 558 | nf_mult = min(2 ** n, 8) 559 | sequence += [ 560 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 561 | norm_layer(ndf * nf_mult), 562 | nn.LeakyReLU(0.2, True) 563 | ] 564 | 565 | nf_mult_prev = nf_mult 566 | nf_mult = min(2 ** n_layers, 8) 567 | sequence += [ 568 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 569 | norm_layer(ndf * nf_mult), 570 | nn.LeakyReLU(0.2, True) 571 | ] 572 | 573 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 574 | self.model = nn.Sequential(*sequence) 575 | 576 | def forward(self, input): 577 | """Standard forward.""" 578 | return self.model(input) 579 | 580 | 581 | class PixelDiscriminator(nn.Module): 582 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 583 | 584 | def 
__init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 585 | """Construct a 1x1 PatchGAN discriminator 586 | 587 | Parameters: 588 | input_nc (int) -- the number of channels in input images 589 | ndf (int) -- the number of filters in the last conv layer 590 | norm_layer -- normalization layer 591 | """ 592 | super(PixelDiscriminator, self).__init__() 593 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 594 | use_bias = norm_layer.func != nn.InstanceNorm2d 595 | else: 596 | use_bias = norm_layer != nn.InstanceNorm2d 597 | 598 | self.net = [ 599 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 600 | nn.LeakyReLU(0.2, True), 601 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 602 | norm_layer(ndf * 2), 603 | nn.LeakyReLU(0.2, True), 604 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 605 | 606 | self.net = nn.Sequential(*self.net) 607 | 608 | def forward(self, input): 609 | """Standard forward.""" 610 | return self.net(input) 611 | -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/Po-Hsun-Su/pytorch-ssim 2 | 3 | from math import exp 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | def gaussian(window_size, sigma): 11 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 12 | return gauss / gauss.sum() 13 | 14 | 15 | def create_window(window_size, channel): 16 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 17 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 18 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 19 | return window 20 | 21 | 22 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 23 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 24 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 25 | 26 | mu1_sq = mu1.pow(2) 27 | mu2_sq = mu2.pow(2) 28 | mu1_mu2 = mu1 * mu2 29 | 30 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 31 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 32 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 33 | 34 | C1 = 0.01 ** 2 35 | C2 = 0.03 ** 2 36 | 37 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 38 | 39 | if size_average: 40 | return ssim_map.mean() 41 | else: 42 | return ssim_map.mean(1).mean(1).mean(1) 43 | 44 | 45 | class SSIM(torch.nn.Module): 46 | def __init__(self, window_size=11, size_average=True): 47 | super(SSIM, self).__init__() 48 | self.window_size = window_size 49 | self.size_average = size_average 50 | self.channel = 1 51 | self.window = create_window(window_size, self.channel) 52 | 53 | def forward(self, img1, img2): 54 | (_, channel, _, _) = img1.size() 55 | 56 | if channel == self.channel and self.window.data.type() == img1.data.type(): 57 | window = self.window 58 | else: 59 | window = create_window(self.window_size, channel) 60 | 61 | if img1.is_cuda: 62 | window = window.cuda(img1.get_device()) 63 | window = window.type_as(img1) 64 | 65 | self.window = window 66 | self.channel = 
channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import pdb
from model_sr import SimpleNet, ResNet, Discriminator, SRDenseNet
from torch.autograd import Variable
from PIL import Image
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize
import numpy as np
import json
from math import log10
import pytorch_ssim
from networks import ResnetGenerator
from data_utils import SRImageFolder, DNImageFolder, JPEGImageFolder, SelfImageFolder
from time import gmtime, strftime

model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--data', default='/home/zhuangl/datasets/imagenet',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=5, type=int, metavar='N',
                    help='number of data loading workers (default: 5)')
parser.add_argument('--epochs', default=6, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=20, type=int,
                    metavar='N', help='mini-batch size (default: 20)')
parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=1, type=int,
                    metavar='N', help='print frequency (default: 1)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_false',
                    help='R is ImageNet pre-trained by default; passing this flag disables that')
parser.add_argument('--world-size', default=1, type=int, 67 | help='number of distributed processes') 68 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 69 | help='url used to set up distributed training') 70 | parser.add_argument('--dist-backend', default='gloo', type=str, 71 | help='distributed backend') 72 | parser.add_argument('--evaluate-notransform', action='store_true', help='whether to evaluate bicubic interpolation') 73 | parser.add_argument('--upscale', default=4, type=int) # SR up resolution 74 | parser.add_argument('--l', default=0, type=float) # coefficient for RA loss, lambda in paper 75 | parser.add_argument('--save-dir', default='checkpoint/default/', type=str) 76 | parser.add_argument('--mode', default='sr',type=str) # mode, ra, ra_transformer, ra_unsupervised 77 | parser.add_argument('--task', default='sr', type=str) # 78 | parser.add_argument('--std', default=0.1, type=float) # noise level for denoising 79 | parser.add_argument('--L', default=1, type=float) 80 | parser.add_argument('--model-sr', default='test', type=str) # for evaluation 81 | parser.add_argument('--model-transformer', default=None, type=str) # for evaluation 82 | parser.add_argument('--test-batch-size', default=20, type=int) 83 | parser.add_argument('--cross-evaluate', action='store_true') 84 | parser.add_argument('--custom-evaluate', action='store_true') 85 | parser.add_argument('--custom-evaluate-model', default='', type=str) 86 | parser.add_argument('--sr-arch', default='SRResNet', type=str) 87 | parser.add_argument('--transformer-arch', default='pix2pix', type=str) 88 | parser.add_argument('--lower_lr', action='store_false', help='whether to lower lr every certain epochs') # default is True 89 | parser.add_argument('--vis', action='store_true', help='whether to visualize sr results') 90 | parser.add_argument('--l_soft', default=0.001, type=float) 91 | # parser.add_argument('--sr_model', action='store_true', help='whether to use the SRResNet model in dn and jpeg') 92 | best_prec1 = 0 93 | 94 | # get high res images output by bicubic interpolation, only used in notransform test for sr 95 | def trans_RGB_bicubic(data): 96 | up = args.upscale 97 | ims_np = (data.clone()*255.).permute(0, 2, 3, 1).numpy().astype(np.uint8) 98 | 99 | hr_size = ims_np.shape[1] 100 | 101 | lr_size = hr_size // up 102 | 103 | rgb_hrs = data.new().resize_(data.size(0), 3, hr_size, hr_size).zero_() 104 | 105 | for i, im_np in enumerate(ims_np): 106 | im = Image.fromarray(im_np, 'RGB') 107 | rgb_lr = Resize((lr_size, lr_size), Image.BICUBIC)(im) 108 | rgb_hr = Resize((hr_size, hr_size), Image.BICUBIC)(rgb_lr) 109 | rgb_hr = ToTensor()(rgb_hr) 110 | rgb_hrs[i].copy_(rgb_hr) 111 | return rgb_hrs 112 | 113 | # normalize the output of sr, to fit cls network input 114 | # input 0-1, output: normalized imagenet network input 115 | def process_to_input_cls(RGB): 116 | means = [0.485, 0.456, 0.406] 117 | stds = [0.229, 0.224, 0.225] 118 | RGB_new = torch.autograd.Variable(RGB.data.new(*RGB.size())) 119 | 120 | RGB_new[:, 0, :, :] = (RGB[:, 0, :, :] - means[0]) / stds[0] 121 | RGB_new[:, 1, :, :] = (RGB[:, 1, :, :] - means[1]) / stds[1] 122 | RGB_new[:, 2, :, :] = (RGB[:, 2, :, :] - means[2]) / stds[2] 123 | 124 | return RGB_new 125 | 126 | def main(): 127 | 128 | global args, best_prec1 129 | args = parser.parse_args() 130 | 131 | if 'small' in args.data: 132 | args.epochs = 30 133 | else: 134 | args.epochs = 6 135 | 136 | print(args) 137 | 138 | args.distributed = args.world_size > 1 139 | 140 | 141 | if 
def main():

    global args, best_prec1
    args = parser.parse_args()

    if 'small' in args.data:
        args.epochs = 30
    else:
        args.epochs = 6

    print(args)

    args.distributed = args.world_size > 1

    if args.mode == 'ra_transform':
        save_dir_extra = '_'.join([args.sr_arch, args.transformer_arch, args.arch])
    elif args.mode == 'ra_unsupervised':
        args.l = 10
        save_dir_extra = '_'.join([args.sr_arch, str(args.l), args.arch])
    elif args.mode == 'ra':
        save_dir_extra = '_'.join([args.sr_arch, str(args.l), args.arch])
    else:  # fallback so save_dir_extra is defined for any other mode
        save_dir_extra = '_'.join([args.sr_arch, str(args.l), args.arch])

    args.save_dir = args.save_dir + save_dir_extra

    os.makedirs(args.save_dir, exist_ok=True)
    print('making directory ', args.save_dir)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)

    # create the recognition model R
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # single machine, one or more GPUs
    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
            model.eval()
        else:
            model = torch.nn.DataParallel(model).cuda()
            model.eval()
    else:
        return  # distributed training is not supported

    # processing model P: two choices for sr, two for dn/jpeg
    if args.task == 'sr':
        if args.sr_arch == 'SRResNet':
            model_sr = ResNet(upscale_factor=4, channel=3, residual=False)
        elif args.sr_arch == 'SRDenseNet':
            model_sr = SRDenseNet(16, 16, 8, 8)
    elif args.task == 'dn' or args.task == 'jpeg':
        if args.sr_arch == 'SRResNet':
            model_sr = ResNet(upscale_factor=1, channel=3, residual=False)
        elif args.sr_arch == 'pix2pix':
            model_sr = ResnetGenerator(3, 3, n_blocks=6)
    model_sr = torch.nn.DataParallel(model_sr).cuda()
    model_sr.train()

    # transformer model T (only trained in the ra_transform mode)
    if args.transformer_arch == 'SRResNet':
        model_transformer = ResNet(upscale_factor=1, channel=3, residual=False)
    elif args.transformer_arch == 'pix2pix':
        model_transformer = ResnetGenerator(3, 3, n_blocks=6)

    model_transformer = torch.nn.DataParallel(model_transformer).cuda()
    model_transformer.train()

    criterion_sr = nn.MSELoss()
    criterion_sr.cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer_sr = torch.optim.Adam(model_sr.parameters(), lr=args.lr)
    optimizer_transformer = torch.optim.Adam(model_transformer.parameters(), lr=args.lr)

    optimizer = torch.optim.SGD(model.parameters(), 0.01,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint (restores the recognition model R only)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip()])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224)])

    if args.task == 'sr':
        train_dataset = SRImageFolder(traindir, train_transform)
        val_dataset = SRImageFolder(valdir, val_transform)
    elif args.task == 'dn':
        train_dataset = DNImageFolder(traindir, train_transform)
        val_dataset = DNImageFolder(valdir, val_transform, deterministic=True)
    elif args.task == 'self':
        train_dataset = SelfImageFolder(traindir, train_transform)
        val_dataset = SelfImageFolder(valdir, val_transform)
    elif args.task == 'jpeg':
        # temporary folders for the intermediate JPEG-compressed images
        randomfoldername = strftime("%Y-%m-%d_%H-%M-%S", gmtime())
        randomfoldername += str(os.getpid())
        train_dataset = JPEGImageFolder(traindir, train_transform, tmp_dir=args.data + '/trash/{}_{}/'.format(randomfoldername, np.random.randint(1, 1000)))
        val_dataset = JPEGImageFolder(valdir, val_transform, tmp_dir=args.data + '/trash/{}_{}/'.format(randomfoldername, np.random.randint(1, 1000)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # resolve the mode string to the corresponding training/evaluation
    # function defined below (ra, ra_transform or ra_unsupervised)
    run = eval(args.mode)

    # evaluation options
    if args.evaluate:
        model_sr = load_model(args.model_sr, model_sr)

        loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
        save_file = args.model_sr + '_{}.txt'.format(args.arch)
        np.savetxt(save_file, [loss_sr, loss_cls, top1, top5, psnr, ssim])
        return

    if args.custom_evaluate:  # evaluate with a custom recognition model R
        model_sr = load_model(args.model_sr, model_sr)

        model = torch.load(args.custom_evaluate_model).cuda()
        loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
        save_file = args.model_sr + '_custom.txt'
        np.savetxt(save_file, [loss_sr, loss_cls, top1, top5, psnr, ssim])
        return

    if args.evaluate_notransform:
        os.makedirs('notransform_results/' + args.task, exist_ok=True)
        loss_sr, loss_cls, top1, top5, psnr, ssim = validate_notransform(val_loader, model, criterion_sr, criterion)
        save_file = 'notransform_results/' + args.task + '/{}.txt'.format(args.arch)
        np.savetxt(save_file, [loss_sr, loss_cls, top1, top5, psnr, ssim])
        return
    # cross-evaluation: test the trained processing model on multiple
    # recognition architectures to measure transferability
    if args.cross_evaluate:
        basic_model_list = ['resnet18', 'resnet50', 'vgg16_bn', 'resnet101', 'densenet121']
        more_model_list = ['densenet169', 'densenet201', 'vgg13_bn', 'vgg19_bn']
        other_model_list = ['vgg13', 'vgg16', 'vgg19', 'inception_v3']
        # swap basic_model_list for model_list in the loop below to also
        # cover the additional architectures
        model_list = basic_model_list + more_model_list + other_model_list

        model_sr = load_model(args.model_sr, model_sr)
        model_sr = nn.DataParallel(model_sr)
        log = {}
        if args.model_transformer is not None:
            model_transformer = torch.load(args.model_transformer).cuda()
            model_transformer = nn.DataParallel(model_transformer)
            run = ra_transform

        for arch in basic_model_list:
            model = models.__dict__[arch](pretrained=True)
            model = torch.nn.DataParallel(model).cuda()
            loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)

            if isinstance(top1, torch.Tensor):
                log[arch] = top1.item()
            else:
                log[arch] = top1

        with open(args.model_sr + '_' + run.__name__ + '.txt', 'w') as outfile:
            json.dump(log, outfile)
        return

    if args.vis:
        model_sr = load_model(args.model_sr, model_sr)
        vis(val_loader, model_sr, model_transformer, model, criterion_sr, criterion, run)
        return

    log = []

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:  # unreachable: main() returns early when distributed
            train_sampler.set_epoch(epoch)
        if args.lower_lr:
            adjust_learning_rate(optimizer_sr, epoch)
            adjust_learning_rate(optimizer_transformer, epoch)

        log_tmp = []

        # train for one epoch
        loss_sr, loss_cls, top1, top5, psnr, ssim = train(train_loader, model_sr, model_transformer, model, optimizer_sr, optimizer_transformer, criterion_sr, criterion, epoch, run=run)
        log_tmp += [loss_sr, loss_cls, top1, top5, psnr, ssim]

        # evaluate on validation set
        loss_sr, loss_cls, top1, top5, psnr, ssim = validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run=run)
        log_tmp += [loss_sr, loss_cls, top1, top5, psnr, ssim]

        log.append(log_tmp)
        np.savetxt(os.path.join(args.save_dir, 'log.txt'), log)

        model_sr_out_path = os.path.join(args.save_dir, "model_sr_epoch_{}.pth".format(epoch))
        torch.save(model_sr, model_sr_out_path)
        print("Checkpoint saved to {}".format(model_sr_out_path))

        if args.mode == 'ra_transform':
            model_transformer_out_path = os.path.join(args.save_dir, "model_transformer_epoch_{}.pth".format(epoch))
            torch.save(model_transformer, model_transformer_out_path)
            print("Checkpoint saved to {}".format(model_transformer_out_path))

    args.model_sr = model_sr_out_path
    vis(val_loader, model_sr, model_transformer, model, criterion_sr, criterion, run)

# Checkpoints store the full pickled module, possibly saved with an older
# PyTorch version, so copy its weights into a freshly constructed model
# instead of using the unpickled object directly.
def load_model(model_path, model_sr):
    load_dict = torch.load(model_path).state_dict()
    model_dict = model_sr.state_dict()
    model_dict.update(load_dict)
    model_sr.load_state_dict(model_dict)

    return model_sr
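# Hedged alternative (illustrative, not used by the script): a more defensive
# loader that accepts either a pickled nn.Module, as saved by torch.save(model,
# path) above, or a bare state_dict, and tolerates missing/renamed keys.
def load_model_flexible(model_path, model_sr):
    obj = torch.load(model_path, map_location='cpu')
    state = obj.state_dict() if hasattr(obj, 'state_dict') else obj
    model_sr.load_state_dict(state, strict=False)  # ignore key mismatches
    return model_sr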
def ra(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model,
       optimizer_sr, optimizer_transformer, criterion_sr, criterion, train=True):
    if train:
        optimizer_sr.zero_grad()

    output_sr = model_sr(input_sr_var)
    loss_sr = criterion_sr(output_sr, target_sr_var)

    # recognition loss on the processed image
    input_cls = process_to_input_cls(output_sr)
    output_cls = model(input_cls)
    loss_cls = criterion(output_cls, target_cls_var)

    # compute ssim for every image; skipped during training to save time
    ssim = 0
    if not train:
        for i in range(output_sr.size(0)):
            sr_image = output_sr[i].unsqueeze(0)
            hr_image = target_sr_var[i].unsqueeze(0)
            ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
        ssim = ssim / output_sr.size(0)

    loss = loss_sr + args.l * loss_cls

    if train:
        loss.backward()
        optimizer_sr.step()

    return loss_sr, loss_cls, output_cls, ssim


def ra_unsupervised(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model,
                    optimizer_sr, optimizer_transformer, criterion_sr, criterion, train=True):
    if train:
        optimizer_sr.zero_grad()

    output_sr = model_sr(input_sr_var)
    loss_sr = criterion_sr(output_sr, target_sr_var)

    input_cls = process_to_input_cls(output_sr)
    output_cls = model(input_cls)

    # soft target: R's prediction on the ground-truth image, detached so the
    # target distribution carries no gradient
    output_cls_soft_target = model(process_to_input_cls(target_sr_var)).detach()
    loss_cls = criterion_sr(nn.Softmax(dim=1)(output_cls), nn.Softmax(dim=1)(output_cls_soft_target))

    # compute ssim for every image; skipped during training to save time
    ssim = 0
    if not train:
        for i in range(output_sr.size(0)):
            sr_image = output_sr[i].unsqueeze(0)
            hr_image = target_sr_var[i].unsqueeze(0)
            ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
        ssim = ssim / output_sr.size(0)

    loss = loss_sr + args.l * loss_cls

    if train:
        loss.backward()
        optimizer_sr.step()

    return loss_sr, loss_cls, output_cls, ssim
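# Hedged sketch of the unsupervised RA loss above in isolation: match the
# softmax output R produces on the processed image to the (detached) softmax
# output R produces on the ground-truth image, so no labels are required.
# The helper name is ours; shapes follow the code above.
def soft_consistency_loss(model, output_sr, target_sr):
    probs_sr = nn.Softmax(dim=1)(model(process_to_input_cls(output_sr)))
    with torch.no_grad():  # the target distribution carries no gradient
        probs_hr = nn.Softmax(dim=1)(model(process_to_input_cls(target_sr)))
    return nn.MSELoss()(probs_sr, probs_hr)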
# In ra_transform, the processing model P optimizes only the image processing
# loss; the transformer T optimizes the recognition loss on P's detached
# output, so the image quality of P is unaffected.
def ra_transform(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model,
                 optimizer_sr, optimizer_transformer, criterion_sr, criterion, train=True):
    if train:
        optimizer_sr.zero_grad()
        optimizer_transformer.zero_grad()

    output_sr = model_sr(input_sr_var)
    loss_sr = criterion_sr(output_sr, target_sr_var)

    if train:
        loss_sr.backward()
        optimizer_sr.step()

    # cut the gradient so the recognition loss does not flow back into P
    output_sr.detach_()
    input_cls = process_to_input_cls(model_transformer(output_sr))

    output_cls = model(input_cls)
    loss_cls = criterion(output_cls, target_cls_var)

    # compute ssim for every image; skipped during training to save time
    ssim = 0
    if not train:
        for i in range(output_sr.size(0)):
            sr_image = output_sr[i].unsqueeze(0)
            hr_image = target_sr_var[i].unsqueeze(0)
            ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
        ssim = ssim / output_sr.size(0)

    loss = args.l * loss_cls

    if train:
        loss.backward()
        optimizer_transformer.step()

    return loss_sr, loss_cls, output_cls, ssim
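# Hedged helper (illustrative only; the loops below inline this): PSNR from an
# MSE value for images in [0, 1], i.e. PSNR = 10 * log10(MAX^2 / MSE) with
# MAX = 1; the epsilon guards against a perfect, zero-MSE batch.
def psnr_from_mse(mse, eps=1e-9):
    return 10 * log10(1.0 / (mse + eps))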
def train(train_loader, model_sr, model_transformer, model, optimizer_sr, optimizer_transformer, criterion_sr, criterion, epoch, run):

    torch.manual_seed(epoch)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    run_time = AverageMeter()
    process_time = AverageMeter()
    losses = AverageMeter()
    losses_sr = AverageMeter()
    losses_cls = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    psnr_avg = AverageMeter()
    ssim_avg = AverageMeter()

    model_sr.train()
    model_transformer.train()

    # the recognition model R stays fixed in eval mode
    if isinstance(model, list):
        for i in range(len(model)):
            model[i].eval()
            print(model[i].training)
    else:
        model.eval()

    end = time.time()

    for i, (img_input, img_output, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        input_sr_var = img_input.cuda()
        target_sr_var = img_output.cuda()
        target_cls_var = target.cuda()
        target = target.cuda()

        start_run = time.time()

        loss_sr, loss_cls, output_cls, ssim = run(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model, optimizer_sr,
                                                  optimizer_transformer, criterion_sr, criterion, train=True)

        run_time.update(time.time() - start_run)

        process_start = time.time()
        # PSNR from the MSE loss, for images in [0, 1]
        psnr = 10 * log10(1 / (loss_sr.item() + 1e-9))
        process_time.update(time.time() - process_start)

        prec1, prec5 = accuracy(output_cls.data, target, topk=(1, 5))
        top1.update(prec1[0], img_input.size(0))
        top5.update(prec5[0], img_input.size(0))
        losses_sr.update(loss_sr.item(), img_input.size(0))
        losses_cls.update(loss_cls.item(), img_input.size(0))
        psnr_avg.update(psnr, img_input.size(0))
        ssim_avg.update(ssim, img_input.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Process {process_time.val:.3f} ({process_time.avg:.3f})\t'
                  'Run {run_time.val:.3f} ({run_time.avg:.3f})\t'
                  'Loss_sr {loss_sr.val:.4f} ({loss_sr.avg:.4f})\t'
                  'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, process_time=process_time, run_time=run_time, loss_sr=losses_sr, loss_cls=losses_cls, top1=top1, top5=top5))
    return losses_sr.avg, losses_cls.avg, top1.avg, top5.avg, psnr_avg.avg, ssim_avg.avg


def validate(val_loader, model_sr, model_transformer, model, optimizer_sr, criterion_sr, criterion, run):

    torch.manual_seed(1)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_sr = AverageMeter()
    losses_cls = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    psnr_avg = AverageMeter()
    ssim_avg = AverageMeter()

    model_sr.eval()
    if model_transformer is not None:
        model_transformer.eval()

    if isinstance(model, list):
        for i in range(len(model)):
            model[i].eval()
            print(model[i].training)
    else:
        model.eval()

    end = time.time()

    for i, (img_input, img_output, target) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)  # was async=True, a keyword removed in Python 3.7+
        input_sr_var = img_input.cuda()
        target_sr_var = img_output.cuda()
        target_cls_var = target

        # no gradients needed at evaluation time (replaces the removed volatile=True flag)
        with torch.no_grad():
            loss_sr, loss_cls, output_cls, ssim = run(input_sr_var, target_sr_var, target_cls_var, model_sr, model_transformer, model, optimizer_sr=optimizer_sr,
                                                      optimizer_transformer=None, criterion_sr=criterion_sr, criterion=criterion, train=False)

        psnr = 10 * log10(1 / (loss_sr.item() + 1e-9))

        prec1, prec5 = accuracy(output_cls.data, target, topk=(1, 5))
        top1.update(prec1[0], img_input.size(0))
        top5.update(prec5[0], img_input.size(0))
        losses_sr.update(loss_sr.item(), img_input.size(0))
        losses_cls.update(loss_cls.item(), img_input.size(0))
        psnr_avg.update(psnr, img_input.size(0))
        ssim_avg.update(ssim, img_input.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss_sr {loss_sr.val:.4f} ({loss_sr.avg:.4f})\t'
                  'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                  'PSNR {psnr.val:.3f} ({psnr.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss_sr=losses_sr, loss_cls=losses_cls,
                      top1=top1, top5=top5, psnr=psnr_avg))
    return losses_sr.avg, losses_cls.avg, top1.avg, top5.avg, psnr_avg.avg, ssim_avg.avg
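# Hedged sketch (our illustrative addition) of the inference pattern that
# validate() drives through run(): process a degraded batch with P, optionally
# refine it with T, normalize, and classify with R, without building a graph.
def infer_batch(model_sr, model_transformer, model, img_input):
    with torch.no_grad():
        out = model_sr(img_input.cuda())
        if model_transformer is not None:
            out = model_transformer(out)
        return model(process_to_input_cls(out))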
# evaluate "no processing": bicubic upsampling for sr, or the unprocessed
# degraded input for dn/jpeg
def validate_notransform(val_loader, model, criterion_sr, criterion):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_sr = AverageMeter()
    losses_cls = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    psnr_avg = AverageMeter()
    ssim_avg = AverageMeter()

    if isinstance(model, list):
        for i in range(len(model)):
            model[i].eval()
            print(model[i].training)
    else:
        model.eval()

    end = time.time()

    for i, (img_input, img_output, target) in enumerate(val_loader):
        # no gradients needed (replaces the removed volatile=True flags)
        with torch.no_grad():
            target = target.cuda(non_blocking=True)  # was async=True, a keyword removed in Python 3.7+
            target_sr_var = img_output.cuda()
            target_cls_var = target

            # "output" without processing: the bicubic result for sr
            # (tensor in [0, 1]), otherwise the degraded input itself
            if args.task == 'sr':
                output_sr = trans_RGB_bicubic(img_output).cuda()
            else:
                output_sr = img_input.cuda()

            # the rest mirrors the computation in ra()
            loss_sr = criterion_sr(output_sr, target_sr_var)

            input_cls = process_to_input_cls(output_sr)
            output_cls = model(input_cls)
            loss_cls = criterion(output_cls, target_cls_var)

            ssim = 0
            for j in range(output_sr.size(0)):
                sr_image = output_sr[j].unsqueeze(0)
                hr_image = target_sr_var[j].unsqueeze(0)
                ssim += pytorch_ssim.ssim(sr_image, hr_image).item()
            ssim = ssim / output_sr.size(0)

            psnr = 10 * log10(1 / (loss_sr.item() + 1e-9))

            prec1, prec5 = accuracy(output_cls.data, target, topk=(1, 5))
            top1.update(prec1[0], img_input.size(0))
            top5.update(prec5[0], img_input.size(0))
            losses_sr.update(loss_sr.item(), img_input.size(0))
            losses_cls.update(loss_cls.item(), img_input.size(0))
            psnr_avg.update(psnr, img_input.size(0))
            ssim_avg.update(ssim, img_input.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss_sr {loss_sr.val:.4f} ({loss_sr.avg:.4f})\t'
                      'Loss_cls {loss_cls.val:.4f} ({loss_cls.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                      'PSNR {psnr.val:.3f} ({psnr.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss_sr=losses_sr, loss_cls=losses_cls,
                          top1=top1, top5=top5, psnr=psnr_avg))
    return losses_sr.avg, losses_cls.avg, top1.avg, top5.avg, psnr_avg.avg, ssim_avg.avg


def vis(val_loader, model_sr, model_transformer, model, criterion_sr, criterion, run):

    torch.manual_seed(1)
    image_list = []
    image_list_input = []
    image_list_target = []
    for i, (img_input, img_output, target) in enumerate(val_loader):
        if i > 10:
            break

        with torch.no_grad():
            input_sr_var = img_input.cuda()
            target_sr_var = img_output.cuda()
            output_sr = model_sr(input_sr_var)
        im = image_from_RGB(output_sr[0])
        im_input = image_from_RGB(input_sr_var[0])
        im_target = image_from_RGB(target_sr_var[0])

        image_list.append(im)
        image_list_input.append(im_input)
        image_list_target.append(im_target)

    im_save = combine_image(image_list)
    im_save.save(args.model_sr + '_output.png')
    im_save_input = combine_image(image_list_input)
    im_save_input.save(args.model_sr + '_input.png')
    im_save_target = combine_image(image_list_target)
    im_save_target.save(args.model_sr + '_target.png')

    return im_save
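# Hedged alternative (illustrative only) to the manual PIL stitching done by
# image_from_RGB/combine_image below: torchvision can save a whole batch as a
# single image grid in one call.
def save_batch_grid(batch, path):
    from torchvision.utils import save_image
    save_image(batch.clamp(0, 1), path, nrow=batch.size(0))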
# utility functions

# visualization helpers
def image_from_RGB(out):
    if out.size(0) == 3:
        out = out.permute(1, 2, 0).cpu()
        out_img_y = out.data.numpy()
        out_img_y *= 255.0
        out_img_y = out_img_y.clip(0, 255)
        out_img_y = Image.fromarray(np.uint8(out_img_y), mode='RGB')
    elif out.size(0) == 1:
        out = out.cpu()
        out_img_y = out.data.numpy()
        out_img_y *= 255.0
        out_img_y = out_img_y.clip(0, 255)
        out_img_y = out_img_y[0]
        out_img_y = Image.fromarray(np.uint8(out_img_y), mode='L')

    return out_img_y


# paste a list of PIL images side by side into one image
def combine_image(images):
    widths, heights = zip(*(i.size for i in images))

    total_width = sum(widths)
    max_height = max(heights)

    new_im = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.size[0]

    return new_im


# checkpoint helper from standard ImageNet training, not used here
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Decays the learning rate by 10x at fixed epochs; the schedule depends on the dataset size"""
    if 'small' in args.data:
        if epoch in range(20):
            lr = args.lr
        elif epoch in range(20, 25):
            lr = args.lr * 0.1
        elif epoch in range(25, 30):
            lr = args.lr * 0.01
    else:
        if epoch in [0, 1, 2, 3]:
            lr = args.lr
        elif epoch in [4]:
            lr = args.lr * 0.1
        elif epoch in [5]:
            lr = args.lr * 0.01

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: the slice is non-contiguous for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------