├── LICENSE ├── README.md ├── dataprovider.py ├── groundtruth.jpg ├── model.py ├── small.jpg ├── srresnet.py └── superresolution.jpg /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-SRGAN 2 |

[Side-by-side comparison: source image (small.jpg), SRResNet + VGG 5,4 output (superresolution.jpg), ground truth (groundtruth.jpg)]

3 | 
4 | PyTorch version of the paper: [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) 
5 | (The GAN part is not implemented yet; currently this is SRResNet trained with a VGG19 (5,4) perceptual loss.) 
6 | 
7 | You can train a network from scratch. 
8 | Optionally, start with a pretraining run that uses only the pixel-wise loss on the SRResNet part: 
9 | `python srresnet.py --image-dir traindir --cuda --pretraining --images 16384 --batchSize 16` 
10 | 
11 | Use `--pretrained modelfile.pth` to continue from a pretraining run or a previous checkpoint, and train with the perceptual loss: 
12 | `python srresnet.py --image-dir traindir --cuda --images 16384 --batchSize 16` 
13 | 
14 | 
15 | Finally, run inference by adding the arguments: 
16 | `--pretrained model/model_epoch_80.pth --testing --test-image BSDS300/images/train/100075.jpg` 
17 | 
-------------------------------------------------------------------------------- /dataprovider.py: -------------------------------------------------------------------------------- 
1 | 
2 | import torch.utils.data as data
3 | from os import listdir
4 | from os.path import join
5 | from PIL import Image
6 | from scipy.misc import imread  # removed in SciPy >= 1.2; imageio.imread is a drop-in replacement
7 | import numpy as np
8 | import random
9 | import torch
10 | 
11 | 
12 | def is_image_file(filename):
13 |     extensions = ['.png', '.jpg', '.jpeg', '.bmp']
14 |     return any(filename.lower().endswith(extension) for extension in extensions)
15 | 
16 | 
17 | class DatasetFromDir(data.Dataset):
18 |     def __init__(self, file_path, samples, height=224, width=224):
19 | 
20 |         image_dir = file_path
21 |         self.height = height
22 |         self.width = width
23 |         self.scale = 4
24 |         self.labels = []
25 | 
26 |         image_filenames = [
27 |             join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]
28 | 
29 |         for i in image_filenames:
30 |             print(i)
31 |             if len(self.labels) >= samples:
32 |                 break
33 |             img = imread(i)
34 |             try:
35 |                 H, W = img.shape[0], img.shape[1]
36 |                 label_orig = Image.fromarray(np.uint8(img))
37 |                 if H <= W:
38 |                     if H < self.height:
39 |                         label_orig = label_orig.resize(
40 |                             (W * self.height // H, self.height), Image.ANTIALIAS)
41 |                 else:
42 |                     if W < self.width:
43 |                         label_orig = label_orig.resize(
44 |                             (self.width, H * self.width // W), Image.ANTIALIAS)
45 |                 W, H = label_orig.size  # PIL's size is (width, height)
46 |                 if H > self.height and W > self.width:
47 |                     self.labels.append(label_orig)
48 | 
49 |                     if len(self.labels) >= samples:
50 |                         break
51 | 
52 |             except (ValueError, IndexError) as e:
53 |                 print(i)
54 |                 print(img.shape, img.dtype)
55 |                 print(e)
56 | 
57 |         print('we have {} training samples'.format(len(self.labels)))
58 | 
59 |     def __getitem__(self, index):
60 |         while True:  # hack to make sure we have a color image we can handle...
61 | index = random.randint(0, len(self.labels) - 1) 62 | label_orig = self.labels[index] 63 | 64 | W, H = label_orig.size 65 | left = random.randint(0, W - self.width - 1) 66 | top = random.randint(0, H - self.height - 1) 67 | right = left + self.width 68 | bottom = top + self.height 69 | label = label_orig.crop((left, top, right, bottom)) 70 | 71 | data = label.resize( 72 | (self.width // self.scale, self.height // self.scale), Image.ANTIALIAS) 73 | 74 | data = np.asarray(data) 75 | label = np.asarray(label) 76 | 77 | # currently we work only on images with 3 channels 78 | if label.ndim == 3: 79 | if label.shape[2] != 3: 80 | label = label[:, :, 0:3] 81 | data = data[:, :, 0:3] 82 | l_width = label.shape[1] 83 | l_height = label.shape[0] 84 | d_width = data.shape[1] 85 | d_height = data.shape[0] 86 | 87 | input = torch.ByteTensor( 88 | torch.ByteStorage.from_buffer(data.transpose(2, 0, 1).tobytes())).float().div(255).view(3, d_height, d_width) 89 | 90 | target = torch.ByteTensor( 91 | torch.ByteStorage.from_buffer(label.transpose(2, 0, 1).tobytes())).float().div(255).view(3, l_height, l_width) 92 | return input, target 93 | 94 | def __len__(self): 95 | return len(self.labels) 96 | -------------------------------------------------------------------------------- /groundtruth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kayr7/PyTorch-SRGAN/a62acd8abaef76269c2b206b96e26a488f4a2114/groundtruth.jpg -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | # this is one block for a resnet 9 | 10 | 11 | class Residual(nn.Module): 12 | def __init__(self, n_channels=64): 13 | super(Residual, self).__init__() 14 | self.n_channels = n_channels 15 | self.conv1 = nn.Conv2d(in_channels=self.n_channels, 16 | out_channels=self.n_channels, 17 | kernel_size=3, 18 | stride=1, 19 | padding=1, 20 | bias=False) 21 | self.bn1 = nn.BatchNorm2d(self.n_channels) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = nn.Conv2d(in_channels=self.n_channels, 24 | out_channels=self.n_channels, 25 | kernel_size=3, 26 | stride=1, 27 | padding=1, 28 | bias=False) 29 | self.bn2 = nn.BatchNorm2d(self.n_channels) 30 | 31 | def forward(self, x): 32 | input = x 33 | output = torch.add(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))), 34 | input) 35 | return output 36 | 37 | 38 | class SubPixelConv(nn.Module): 39 | def __init__(self, n_channels=64, upsample=2): 40 | super(SubPixelConv, self).__init__() 41 | self.n_channels = n_channels 42 | self.upsample = upsample 43 | self.out_channels = self.upsample * self.upsample * self.n_channels 44 | 45 | self.conv = nn.Conv2d(in_channels=self.n_channels, 46 | out_channels=self.out_channels, 47 | kernel_size=3, 48 | stride=1, 49 | padding=1, 50 | bias=False) 51 | self.upsample_net = nn.PixelShuffle(self.upsample) 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | def forward(self, x): 55 | input = x 56 | output = self.relu(self.upsample_net(self.conv(x))) 57 | return output 58 | 59 | 60 | class SRResNet(nn.Module): 61 | def __init__(self, n_channels=64, n_blocks=15): 62 | super(SRResNet, self).__init__() 63 | self.n_channels = n_channels 64 | self.inConv = nn.Conv2d(in_channels=3, # RGB 65 | out_channels=self.n_channels, 66 | kernel_size=3, # in paper it is 9, somehow other 
implementations always used 3 67 | stride=1, 68 | padding=1, 69 | bias=True) 70 | self.inRelu = nn.ReLU(inplace=True) 71 | 72 | self.resBlocks = self.make_block_layers(n_blocks, Residual) 73 | 74 | self.glueConv = nn.Conv2d(in_channels=self.n_channels, 75 | out_channels=self.n_channels, 76 | kernel_size=3, 77 | stride=1, 78 | padding=1, 79 | bias=True) 80 | self.glueBN = nn.BatchNorm2d(self.n_channels) 81 | 82 | self.upscaleBlock = self.make_block_layers(2, SubPixelConv) 83 | 84 | self.outConv = nn.Conv2d(in_channels=n_channels, 85 | out_channels=3, # RGB 86 | kernel_size=3, # paper has 9 87 | padding=1, 88 | bias=True) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 93 | m.weight.data.normal_(0, math.sqrt(2. / n)) 94 | if m.bias is not None: 95 | m.bias.data.zero_() 96 | elif isinstance(m, nn.BatchNorm2d): 97 | m.weight.data.fill_(1) 98 | if m.bias is not None: 99 | m.bias.data.zero_() 100 | 101 | def forward(self, x): 102 | first_step = self.inRelu(self.inConv(x)) 103 | residual = first_step 104 | output = torch.add(self.glueBN(self.glueConv(self.resBlocks(first_step))), 105 | residual) 106 | output = self.outConv(self.upscaleBlock(output)) 107 | return output 108 | 109 | def make_block_layers(self, n_blocks, block_fn): 110 | layers = [block_fn() for x in range(n_blocks)] 111 | return nn.Sequential(*layers) 112 | -------------------------------------------------------------------------------- /small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kayr7/PyTorch-SRGAN/a62acd8abaef76269c2b206b96e26a488f4a2114/small.jpg -------------------------------------------------------------------------------- /srresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from model import SRResNet, Residual, SubPixelConv 4 | import dataprovider 5 | import argparse 6 | import os 7 | import time 8 | 9 | from PIL import Image 10 | import random 11 | from scipy.misc import imread 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.optim as optim 18 | import torch.backends.cudnn as cudnn 19 | import torchvision.datasets as datasets 20 | import torchvision.models as models 21 | from torch.autograd import Variable 22 | from torch.utils.data import DataLoader 23 | 24 | 25 | HEIGHT = 224 26 | WIDTH = 224 27 | SCALE = 4 28 | 29 | # Training settings 30 | parser = argparse.ArgumentParser(description='PyTorch VDSR') 31 | parser.add_argument('--batchSize', 32 | type=int, 33 | default=64, 34 | help='Training batch size') 35 | parser.add_argument('--nEpochs', 36 | type=int, 37 | default=150, 38 | help='Number of epochs to train for') 39 | parser.add_argument('--lr', 40 | type=float, 41 | default=0.01, 42 | help='Learning Rate. Default=0.1') 43 | parser.add_argument('--step', 44 | type=int, 45 | default=10, 46 | help='learning rate decayed every n epochs, Default: n=10') 47 | parser.add_argument('--cuda', 48 | action='store_true', 49 | help='Use cuda?') 50 | parser.add_argument('--resume', 51 | default='', 52 | type=str, 53 | help='Path to checkpoint (default: none)') 54 | parser.add_argument('--start-epoch', 55 | default=1, 56 | type=int, 57 | help='Manual epoch number (useful on restarts)') 58 | parser.add_argument('--clip', 59 | type=float, 60 | default=0.4, 61 | help='Clipping Gradients. 
Default=0.4') 62 | parser.add_argument('--threads', 63 | type=int, 64 | default=0, 65 | help='Number of threads for data loader, Default: 4') 66 | parser.add_argument('--images', 67 | type=int, 68 | default=400, 69 | help='Number of threads for data loader, Default: 400') 70 | parser.add_argument('--test-image', 71 | default='', 72 | type=str, 73 | help='Path to image that should be scaled up') 74 | parser.add_argument('--momentum', 75 | default=0.9, 76 | type=float, 77 | help='Momentum, Default: 0.9') 78 | parser.add_argument('--weight-decay', 79 | '--wd', 80 | default=1e-4, 81 | type=float, 82 | help='Weight decay, Default: 1e-4') 83 | parser.add_argument('--percep-scale', 84 | default=0.006, 85 | type=float, 86 | help='weight to content vs pixel') 87 | parser.add_argument('--pretrained', 88 | default='', 89 | type=str, 90 | help='path to pretrained model (default: none)') 91 | parser.add_argument('--image-dir', 92 | default='', 93 | type=str, 94 | help='directory with images to train on (default: none)') 95 | parser.add_argument('--pretraining', 96 | action='store_true', 97 | help='pretraining step?') 98 | parser.add_argument('--testing', 99 | action='store_true', 100 | help='inference step?') 101 | 102 | 103 | def main(): 104 | 105 | global opt, model, HEIGHT, WIDTH, SCALE 106 | opt = parser.parse_args() 107 | print(opt) 108 | test_image = None 109 | if opt.testing: 110 | opt.batchSize = 1 111 | img = imread(opt.test_image) 112 | HEIGHT, WIDTH = img.shape[0], img.shape[1] 113 | test_image = Image.fromarray(np.uint8(img)) 114 | test_image = np.asarray(test_image) 115 | 116 | if test_image.ndim == 3: 117 | if test_image.shape[2] != 3: 118 | test_image = test_image[:, :, 0:3] 119 | 120 | test_image = torch.ByteTensor( 121 | torch.ByteStorage.from_buffer(test_image.transpose(2, 0, 1).tobytes())).float().div(255).view(-1, 3, HEIGHT, WIDTH) 122 | else: 123 | print('not good... we do not upscale non color images yet') 124 | return 125 | 126 | cuda = opt.cuda 127 | if cuda and not torch.cuda.is_available(): 128 | raise Exception('No GPU found, please run without --cuda') 129 | 130 | opt.seed = random.randint(1, 10000) 131 | print('Random Seed: ', opt.seed) 132 | torch.manual_seed(opt.seed) 133 | if cuda: 134 | torch.cuda.manual_seed(opt.seed) 135 | 136 | model = SRResNet() 137 | 138 | # clean this mess up! 
139 | if opt.testing: 140 | model.eval() 141 | mean = torch.zeros(opt.batchSize, 3, HEIGHT * SCALE, WIDTH * SCALE) 142 | mean[:, 0, :, :] = 0.485 143 | mean[:, 1, :, :] = 0.456 144 | mean[:, 2, :, :] = 0.406 145 | 146 | std = torch.zeros(opt.batchSize, 3, HEIGHT * SCALE, WIDTH * SCALE) 147 | std[:, 0, :, :] = 0.229 148 | std[:, 1, :, :] = 0.224 149 | std[:, 2, :, :] = 0.225 150 | 151 | tmean = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH) 152 | tmean[:, 0, :, :] = 0.485 153 | tmean[:, 1, :, :] = 0.456 154 | tmean[:, 2, :, :] = 0.406 155 | 156 | tstd = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH) 157 | tstd[:, 0, :, :] = 0.229 158 | tstd[:, 1, :, :] = 0.224 159 | tstd[:, 2, :, :] = 0.225 160 | 161 | else: 162 | model.train() 163 | mean = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH) 164 | mean[:, 0, :, :] = 0.485 165 | mean[:, 1, :, :] = 0.456 166 | mean[:, 2, :, :] = 0.406 167 | 168 | std = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH) 169 | std[:, 0, :, :] = 0.229 170 | std[:, 1, :, :] = 0.224 171 | std[:, 2, :, :] = 0.225 172 | 173 | tmean = torch.zeros(opt.batchSize, 3, HEIGHT // SCALE, WIDTH // SCALE) 174 | tmean[:, 0, :, :] = 0.485 175 | tmean[:, 1, :, :] = 0.456 176 | tmean[:, 2, :, :] = 0.406 177 | 178 | tstd = torch.zeros(opt.batchSize, 3, HEIGHT // SCALE, WIDTH // SCALE) 179 | tstd[:, 0, :, :] = 0.229 180 | tstd[:, 1, :, :] = 0.224 181 | tstd[:, 2, :, :] = 0.225 182 | 183 | if not opt.pretraining and not opt.testing: 184 | percep_model = models.__dict__['vgg19'](pretrained=True) 185 | percep_model.features = nn.Sequential( 186 | *list(percep_model.features.children())[:-14]) 187 | percep_model.eval() 188 | 189 | criterion = nn.MSELoss(size_average=False) 190 | lr = opt.lr 191 | 192 | if cuda: 193 | model = torch.nn.DataParallel(model).cuda() 194 | criterion = criterion.cuda() 195 | if not opt.pretraining and not opt.testing: 196 | percep_model = percep_model.cuda() 197 | mean = Variable(mean).cuda() 198 | std = Variable(std).cuda() 199 | tmean = Variable(tmean).cuda() 200 | tstd = Variable(tstd).cuda() 201 | 202 | if opt.pretrained: 203 | if os.path.isfile(opt.pretrained): 204 | print('=> loading model {}'.format(opt.pretrained)) 205 | weights = torch.load(opt.pretrained) 206 | model.load_state_dict(weights['model'].state_dict()) 207 | else: 208 | print('=> no model found at {}'.format(opt.pretrained)) 209 | 210 | if opt.testing: 211 | test_image = Variable(test_image) 212 | if cuda: 213 | test_image = test_image.cuda() 214 | 215 | test_image = test_image.sub(tmean).div(tstd) 216 | gen = model(test_image) 217 | gened = torch.clamp(gen.mul(std).add(mean).mul(255.0), min=0., max=255.0).byte()[ 218 | 0].data.cpu().numpy().transpose(1, 2, 0) 219 | gened = Image.fromarray(gened) 220 | gened.save('testing-sr.jpg') 221 | 222 | else: 223 | train_set = dataprovider.DatasetFromDir( 224 | opt.image_dir, 225 | samples=opt.images, 226 | width=WIDTH, 227 | height=HEIGHT) 228 | 229 | training_data_loader = DataLoader( 230 | dataset=train_set, 231 | num_workers=opt.threads, 232 | batch_size=opt.batchSize, 233 | shuffle=True) 234 | 235 | optimizer = optim.Adam(model.parameters(), lr=lr) 236 | 237 | counter = 0 238 | for epoch in range(opt.nEpochs): 239 | 240 | loss_sum = Variable(torch.zeros(1), requires_grad=False) 241 | if cuda: 242 | loss_sum = loss_sum.cuda() 243 | 244 | for iteration, batch in enumerate(training_data_loader, 1): 245 | counter = counter + 1 246 | input, target = ( 247 | Variable(batch[0]), 248 | Variable(batch[1], requires_grad=False)) 249 | 250 | if cuda: 251 | input = 
input.cuda() 252 | target = target.cuda() 253 | 254 | input = input.sub(tmean).div(tstd) 255 | target = target.sub(mean).div(std) 256 | 257 | gen = model(input) 258 | optimizer.zero_grad() 259 | loss = criterion(gen, target) 260 | 261 | if not opt.pretraining: 262 | out_percep = percep_model.features(gen) 263 | out_percep_real = Variable(percep_model.features( 264 | target).data, requires_grad=False) 265 | percep_loss = criterion(out_percep, out_percep_real) 266 | # loss_relation = percep_loss.div(loss) 267 | 268 | loss = loss.add(percep_loss.mul(opt.percep_scale)) # loss_relation)) 269 | 270 | loss.backward() 271 | nn.utils.clip_grad_norm(model.parameters(), opt.clip) 272 | loss_sum.add_(loss) 273 | optimizer.step() 274 | 275 | if counter % 400 == 0: 276 | print('sum_of_loss = {}'.format( 277 | loss_sum.data.select(0, 0))) 278 | loss_sum = Variable(torch.zeros(1), requires_grad=False) 279 | if cuda: 280 | loss_sum = loss_sum.cuda() 281 | 282 | save_checkpoint(model, epoch) 283 | input = torch.clamp(input.mul(tstd).add(tmean).mul( 284 | 255.0), min=0., max=255.0).byte()[0].data.cpu().numpy().transpose(1, 2, 0) 285 | inp = Image.fromarray(input) 286 | label = torch.clamp(target.mul(std).add(mean).mul(255.0), min=0., max=255.0).byte()[ 287 | 0].data.cpu().numpy().transpose(1, 2, 0) 288 | lab = Image.fromarray(label) 289 | gened = torch.clamp(gen.mul(std).add(mean).mul(255.0), min=0., max=255.0).byte()[ 290 | 0].data.cpu().numpy().transpose(1, 2, 0) 291 | gened = Image.fromarray(gened) 292 | inp.save('input.jpg') 293 | lab.save('gt.jpg') 294 | gened.save('sr.jpg') 295 | 296 | 297 | def save_checkpoint(model, epoch): 298 | model_out_path = 'model/' + 'model_epoch_{}.pth'.format(epoch) 299 | state = {'epoch': epoch, 'model': model} 300 | if not os.path.exists('model/'): 301 | os.makedirs('model/') 302 | 303 | torch.save(state, model_out_path) 304 | 305 | print('Checkpoint saved to {}'.format(model_out_path)) 306 | 307 | 308 | if __name__ == '__main__': 309 | main() 310 | -------------------------------------------------------------------------------- /superresolution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kayr7/PyTorch-SRGAN/a62acd8abaef76269c2b206b96e26a488f4a2114/superresolution.jpg --------------------------------------------------------------------------------
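For reference, below is a minimal stand-alone inference sketch (not part of the repository) that mirrors the `--testing` branch of `srresnet.py`: it loads a checkpoint written by `save_checkpoint()` (which pickles the whole model under the `'model'` key), applies the same ImageNet normalization, runs the network, and saves the 4x-upscaled result. The checkpoint path, input image, and output filename are placeholder assumptions, it assumes it is run from the repository root so `model.py` is importable, and on recent PyTorch releases `torch.load` may additionally need `weights_only=False` for this kind of full-model checkpoint.

```python
# Hypothetical CPU inference sketch; paths below are assumptions, not repo defaults.
import numpy as np
import torch
from PIL import Image

from model import SRResNet

CHECKPOINT = 'model/model_epoch_80.pth'  # assumed: checkpoint from save_checkpoint()
INPUT_IMAGE = 'small.jpg'                # assumed: a low-resolution RGB input image

# ImageNet statistics, matching the normalization used in srresnet.py.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def upscale(checkpoint_path, image_path, out_path='upscaled.jpg'):
    # The checkpoint stores the (possibly DataParallel-wrapped) model object itself.
    state = torch.load(checkpoint_path, map_location='cpu')
    sd = state['model'].state_dict()
    # Strip the 'module.' prefix left by DataParallel when training with --cuda.
    sd = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in sd.items()}

    net = SRResNet()
    net.load_state_dict(sd)
    net.eval()

    # Load the image, scale to [0, 1], and normalize like the training/testing code.
    img = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32) / 255.0
    x = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).unsqueeze(0)
    x = (x - MEAN) / STD

    with torch.no_grad():
        y = net(x)  # fully convolutional, so any input size works; output is 4x larger

    # De-normalize, clamp to [0, 1], and save as an 8-bit image.
    y = torch.clamp(y * STD + MEAN, 0.0, 1.0)[0].numpy().transpose(1, 2, 0)
    Image.fromarray((y * 255.0).round().astype(np.uint8)).save(out_path)


if __name__ == '__main__':
    upscale(CHECKPOINT, INPUT_IMAGE)
```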