├── LICENSE
├── README.md
├── dataprovider.py
├── groundtruth.jpg
├── model.py
├── small.jpg
├── srresnet.py
└── superresolution.jpg
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-SRGAN
2 | Source: ![source](small.jpg)  SRResNetVgg5,4: ![super-resolution](superresolution.jpg)  (Ground Truth: ![ground truth](groundtruth.jpg))
3 |
4 | PyTorch version of the paper: [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)
5 | (currently it does not implement the GAN part; it trains the SRResNet generator with the VGG19-5,4 perceptual loss on top of the pixel-wise loss)
6 |
7 | You can train a network from scratch:
8 | (optionally, start by pretraining with just the pixel-wise loss on the SRResNet part:
9 | `python srresnet.py --image-dir traindir --cuda --pretraining --images 16384 --batchSize 16`)
10 |
11 | (use `--pretrained modelfile.pth` to continue from a pretraining run or a previous training run, for example)
12 | `python srresnet.py --image-dir traindir --cuda --images 16384 --batchSize 16`
13 |
14 |
15 | and then run inference with the arguments:
16 | `--pretrained model/model_epoch_80.pth --testing --test-image BSDS300/images/train/100075.jpg`
17 |
--------------------------------------------------------------------------------
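The "srresnet + vgg19-5,4 loss" mentioned in the README combines a pixel-wise MSE with an MSE between VGG19 feature maps of the generated and ground-truth images, weighted by `--percep-scale`. A minimal sketch of that combination (the function name is illustrative, and `reduction='sum'` stands in for the `size_average=False` criterion used in srresnet.py):

import torch.nn as nn


def combined_content_loss(gen, target, feature_extractor, percep_scale=0.006):
    # pixel-wise MSE between the generated and ground-truth images
    mse = nn.MSELoss(reduction='sum')
    pixel_loss = mse(gen, target)
    # perceptual MSE between VGG19 feature maps; target features carry no gradient
    percep_loss = mse(feature_extractor(gen), feature_extractor(target).detach())
    return pixel_loss + percep_scale * percep_loss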
/dataprovider.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.utils.data as data
3 | from os import listdir
4 | from os.path import join
5 | from PIL import Image
6 | from scipy.misc import imread  # note: removed in SciPy >= 1.2; imageio.imread is the suggested replacement
7 | import numpy as np
8 | import random
9 | import torch
10 |
11 |
12 | def is_image_file(filename):
13 | extensions = ['.png', '.jpg', '.jpeg', '.bmp']
14 | return any(filename.endswith(extension) for extension in extensions)
15 |
16 |
17 | class DatasetFromDir(data.Dataset):
18 | def __init__(self, file_path, samples, height=224, width=224):
19 |
20 | image_dir = file_path
21 | self.height = height
22 | self.width = width
23 | self.scale = 4
24 | self.labels = []
25 |
26 | image_filenames = [
27 | join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)]
28 |
29 | for i in image_filenames:
30 | print(i)
31 | if len(self.labels) >= samples:
32 | break
33 | img = imread(i)
34 | try:
35 | H, W = img.shape[0], img.shape[1]
36 | label_orig = Image.fromarray(np.uint8(img))
37 | if H <= W:
38 | if H < self.height:
39 | label_orig = label_orig.resize(
40 | (W * self.height // H, self.height), Image.ANTIALIAS)
41 | else:
42 | if W < self.width:
43 | label_orig = label_orig.resize(
44 | (self.width, H * self.width // W), Image.ANTIALIAS)
45 |                 W, H = label_orig.size  # PIL .size is (width, height)
46 | if H > self.height and W > self.width:
47 | self.labels.append(label_orig)
48 |
49 | if len(self.labels) >= samples:
50 | break
51 |
52 | except (ValueError, IndexError) as e:
53 | print(i)
54 | print(img.shape, img.dtype)
55 | print(e)
56 |
57 | print('we have {} training samples'.format(len(self.labels)))
58 |
59 | def __getitem__(self, index):
60 | while True: # hack to make sure we have a color image we can handle...
61 | index = random.randint(0, len(self.labels) - 1)
62 | label_orig = self.labels[index]
63 |
64 | W, H = label_orig.size
65 | left = random.randint(0, W - self.width - 1)
66 | top = random.randint(0, H - self.height - 1)
67 | right = left + self.width
68 | bottom = top + self.height
69 | label = label_orig.crop((left, top, right, bottom))
70 |
71 | data = label.resize(
72 | (self.width // self.scale, self.height // self.scale), Image.ANTIALIAS)
73 |
74 | data = np.asarray(data)
75 | label = np.asarray(label)
76 |
77 | # currently we work only on images with 3 channels
78 | if label.ndim == 3:
79 | if label.shape[2] != 3:
80 | label = label[:, :, 0:3]
81 | data = data[:, :, 0:3]
82 | l_width = label.shape[1]
83 | l_height = label.shape[0]
84 | d_width = data.shape[1]
85 | d_height = data.shape[0]
86 |                 # pack the HWC uint8 arrays into CHW float tensors scaled to [0, 1]
87 | input = torch.ByteTensor(
88 | torch.ByteStorage.from_buffer(data.transpose(2, 0, 1).tobytes())).float().div(255).view(3, d_height, d_width)
89 |
90 | target = torch.ByteTensor(
91 | torch.ByteStorage.from_buffer(label.transpose(2, 0, 1).tobytes())).float().div(255).view(3, l_height, l_width)
92 | return input, target
93 |
94 | def __len__(self):
95 | return len(self.labels)
96 |
--------------------------------------------------------------------------------
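A minimal usage sketch for DatasetFromDir above, assuming `traindir` (a placeholder path) holds at least 16 RGB images larger than 224x224 and that `scipy.misc.imread` is still available:

from torch.utils.data import DataLoader

import dataprovider

train_set = dataprovider.DatasetFromDir('traindir', samples=64, height=224, width=224)
loader = DataLoader(train_set, batch_size=16, shuffle=True)

lr_batch, hr_batch = next(iter(loader))
print(lr_batch.shape)  # torch.Size([16, 3, 56, 56])   -- inputs, downscaled by the fixed factor of 4
print(hr_batch.shape)  # torch.Size([16, 3, 224, 224]) -- ground-truth crops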
/groundtruth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kayr7/PyTorch-SRGAN/a62acd8abaef76269c2b206b96e26a488f4a2114/groundtruth.jpg
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | # one residual block of SRResNet: conv-BN-ReLU-conv-BN with an identity skip connection
9 |
10 |
11 | class Residual(nn.Module):
12 | def __init__(self, n_channels=64):
13 | super(Residual, self).__init__()
14 | self.n_channels = n_channels
15 | self.conv1 = nn.Conv2d(in_channels=self.n_channels,
16 | out_channels=self.n_channels,
17 | kernel_size=3,
18 | stride=1,
19 | padding=1,
20 | bias=False)
21 | self.bn1 = nn.BatchNorm2d(self.n_channels)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.conv2 = nn.Conv2d(in_channels=self.n_channels,
24 | out_channels=self.n_channels,
25 | kernel_size=3,
26 | stride=1,
27 | padding=1,
28 | bias=False)
29 | self.bn2 = nn.BatchNorm2d(self.n_channels)
30 |
31 | def forward(self, x):
32 | input = x
33 | output = torch.add(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))),
34 | input)
35 | return output
36 |
37 |
38 | class SubPixelConv(nn.Module):
39 | def __init__(self, n_channels=64, upsample=2):
40 | super(SubPixelConv, self).__init__()
41 | self.n_channels = n_channels
42 | self.upsample = upsample
43 | self.out_channels = self.upsample * self.upsample * self.n_channels
44 |
45 | self.conv = nn.Conv2d(in_channels=self.n_channels,
46 | out_channels=self.out_channels,
47 | kernel_size=3,
48 | stride=1,
49 | padding=1,
50 | bias=False)
51 | self.upsample_net = nn.PixelShuffle(self.upsample)
52 | self.relu = nn.ReLU(inplace=True)
53 |
54 | def forward(self, x):
55 | input = x
56 | output = self.relu(self.upsample_net(self.conv(x)))
57 | return output
58 |
59 |
60 | class SRResNet(nn.Module):
61 | def __init__(self, n_channels=64, n_blocks=15):
62 | super(SRResNet, self).__init__()
63 | self.n_channels = n_channels
64 | self.inConv = nn.Conv2d(in_channels=3, # RGB
65 | out_channels=self.n_channels,
66 |                                 kernel_size=3,  # the paper uses 9 here; other implementations commonly use 3
67 | stride=1,
68 | padding=1,
69 | bias=True)
70 | self.inRelu = nn.ReLU(inplace=True)
71 |
72 | self.resBlocks = self.make_block_layers(n_blocks, Residual)
73 |
74 | self.glueConv = nn.Conv2d(in_channels=self.n_channels,
75 | out_channels=self.n_channels,
76 | kernel_size=3,
77 | stride=1,
78 | padding=1,
79 | bias=True)
80 | self.glueBN = nn.BatchNorm2d(self.n_channels)
81 |
82 | self.upscaleBlock = self.make_block_layers(2, SubPixelConv)
83 |
84 | self.outConv = nn.Conv2d(in_channels=n_channels,
85 | out_channels=3, # RGB
86 | kernel_size=3, # paper has 9
87 | padding=1,
88 | bias=True)
89 |
90 | for m in self.modules():
91 | if isinstance(m, nn.Conv2d):
92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
93 | m.weight.data.normal_(0, math.sqrt(2. / n))
94 | if m.bias is not None:
95 | m.bias.data.zero_()
96 | elif isinstance(m, nn.BatchNorm2d):
97 | m.weight.data.fill_(1)
98 | if m.bias is not None:
99 | m.bias.data.zero_()
100 |
101 | def forward(self, x):
102 | first_step = self.inRelu(self.inConv(x))
103 | residual = first_step
104 | output = torch.add(self.glueBN(self.glueConv(self.resBlocks(first_step))),
105 | residual)
106 | output = self.outConv(self.upscaleBlock(output))
107 | return output
108 |
109 | def make_block_layers(self, n_blocks, block_fn):
110 | layers = [block_fn() for x in range(n_blocks)]
111 | return nn.Sequential(*layers)
112 |
--------------------------------------------------------------------------------
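A minimal forward-pass sketch for the SRResNet above, assuming a PyTorch version that accepts plain tensors (0.4+; the repository code itself still wraps inputs in Variable). The input size is illustrative:

import torch

from model import SRResNet

net = SRResNet()               # defaults: 64 feature channels, 15 residual blocks
net.eval()                     # keep BatchNorm running statistics untouched for this smoke test

x = torch.randn(1, 3, 56, 56)  # one low-resolution RGB patch
y = net(x)
print(y.shape)                 # torch.Size([1, 3, 224, 224]) -- two PixelShuffle(2) stages give 4x upscaling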
/small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kayr7/PyTorch-SRGAN/a62acd8abaef76269c2b206b96e26a488f4a2114/small.jpg
--------------------------------------------------------------------------------
/srresnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | from model import SRResNet, Residual, SubPixelConv
4 | import dataprovider
5 | import argparse
6 | import os
7 | import time
8 |
9 | from PIL import Image
10 | import random
11 | from scipy.misc import imread  # note: removed in SciPy >= 1.2; imageio.imread is the suggested replacement
12 | import numpy as np
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.parallel
17 | import torch.optim as optim
18 | import torch.backends.cudnn as cudnn
19 | import torchvision.datasets as datasets
20 | import torchvision.models as models
21 | from torch.autograd import Variable
22 | from torch.utils.data import DataLoader
23 |
24 |
25 | HEIGHT = 224
26 | WIDTH = 224
27 | SCALE = 4
28 |
29 | # Training settings
30 | parser = argparse.ArgumentParser(description='PyTorch SRResNet')
31 | parser.add_argument('--batchSize',
32 | type=int,
33 | default=64,
34 | help='Training batch size')
35 | parser.add_argument('--nEpochs',
36 | type=int,
37 | default=150,
38 | help='Number of epochs to train for')
39 | parser.add_argument('--lr',
40 | type=float,
41 | default=0.01,
42 |                     help='Learning Rate. Default=0.01')
43 | parser.add_argument('--step',
44 | type=int,
45 | default=10,
46 | help='learning rate decayed every n epochs, Default: n=10')
47 | parser.add_argument('--cuda',
48 | action='store_true',
49 | help='Use cuda?')
50 | parser.add_argument('--resume',
51 | default='',
52 | type=str,
53 | help='Path to checkpoint (default: none)')
54 | parser.add_argument('--start-epoch',
55 | default=1,
56 | type=int,
57 | help='Manual epoch number (useful on restarts)')
58 | parser.add_argument('--clip',
59 | type=float,
60 | default=0.4,
61 | help='Clipping Gradients. Default=0.4')
62 | parser.add_argument('--threads',
63 | type=int,
64 | default=0,
65 |                     help='Number of threads for data loader, Default: 0')
66 | parser.add_argument('--images',
67 | type=int,
68 | default=400,
69 |                     help='Number of training images to load, Default: 400')
70 | parser.add_argument('--test-image',
71 | default='',
72 | type=str,
73 | help='Path to image that should be scaled up')
74 | parser.add_argument('--momentum',
75 | default=0.9,
76 | type=float,
77 | help='Momentum, Default: 0.9')
78 | parser.add_argument('--weight-decay',
79 | '--wd',
80 | default=1e-4,
81 | type=float,
82 | help='Weight decay, Default: 1e-4')
83 | parser.add_argument('--percep-scale',
84 | default=0.006,
85 | type=float,
86 |                     help='weight of the perceptual (VGG feature) loss relative to the pixel-wise loss')
87 | parser.add_argument('--pretrained',
88 | default='',
89 | type=str,
90 | help='path to pretrained model (default: none)')
91 | parser.add_argument('--image-dir',
92 | default='',
93 | type=str,
94 | help='directory with images to train on (default: none)')
95 | parser.add_argument('--pretraining',
96 | action='store_true',
97 | help='pretraining step?')
98 | parser.add_argument('--testing',
99 | action='store_true',
100 | help='inference step?')
101 |
102 |
103 | def main():
104 |
105 | global opt, model, HEIGHT, WIDTH, SCALE
106 | opt = parser.parse_args()
107 | print(opt)
108 | test_image = None
109 | if opt.testing:
110 | opt.batchSize = 1
111 | img = imread(opt.test_image)
112 | HEIGHT, WIDTH = img.shape[0], img.shape[1]
113 | test_image = Image.fromarray(np.uint8(img))
114 | test_image = np.asarray(test_image)
115 |
116 | if test_image.ndim == 3:
117 | if test_image.shape[2] != 3:
118 | test_image = test_image[:, :, 0:3]
119 |
120 | test_image = torch.ByteTensor(
121 | torch.ByteStorage.from_buffer(test_image.transpose(2, 0, 1).tobytes())).float().div(255).view(-1, 3, HEIGHT, WIDTH)
122 | else:
123 | print('not good... we do not upscale non color images yet')
124 | return
125 |
126 | cuda = opt.cuda
127 | if cuda and not torch.cuda.is_available():
128 | raise Exception('No GPU found, please run without --cuda')
129 |
130 | opt.seed = random.randint(1, 10000)
131 | print('Random Seed: ', opt.seed)
132 | torch.manual_seed(opt.seed)
133 | if cuda:
134 | torch.cuda.manual_seed(opt.seed)
135 |
136 | model = SRResNet()
137 |
138 | # clean this mess up!
139 | if opt.testing:
140 | model.eval()
141 | mean = torch.zeros(opt.batchSize, 3, HEIGHT * SCALE, WIDTH * SCALE)
142 | mean[:, 0, :, :] = 0.485
143 | mean[:, 1, :, :] = 0.456
144 | mean[:, 2, :, :] = 0.406
145 |
146 | std = torch.zeros(opt.batchSize, 3, HEIGHT * SCALE, WIDTH * SCALE)
147 | std[:, 0, :, :] = 0.229
148 | std[:, 1, :, :] = 0.224
149 | std[:, 2, :, :] = 0.225
150 |
151 | tmean = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
152 | tmean[:, 0, :, :] = 0.485
153 | tmean[:, 1, :, :] = 0.456
154 | tmean[:, 2, :, :] = 0.406
155 |
156 | tstd = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
157 | tstd[:, 0, :, :] = 0.229
158 | tstd[:, 1, :, :] = 0.224
159 | tstd[:, 2, :, :] = 0.225
160 |
161 | else:
162 | model.train()
163 | mean = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
164 | mean[:, 0, :, :] = 0.485
165 | mean[:, 1, :, :] = 0.456
166 | mean[:, 2, :, :] = 0.406
167 |
168 | std = torch.zeros(opt.batchSize, 3, HEIGHT, WIDTH)
169 | std[:, 0, :, :] = 0.229
170 | std[:, 1, :, :] = 0.224
171 | std[:, 2, :, :] = 0.225
172 |
173 | tmean = torch.zeros(opt.batchSize, 3, HEIGHT // SCALE, WIDTH // SCALE)
174 | tmean[:, 0, :, :] = 0.485
175 | tmean[:, 1, :, :] = 0.456
176 | tmean[:, 2, :, :] = 0.406
177 |
178 | tstd = torch.zeros(opt.batchSize, 3, HEIGHT // SCALE, WIDTH // SCALE)
179 | tstd[:, 0, :, :] = 0.229
180 | tstd[:, 1, :, :] = 0.224
181 | tstd[:, 2, :, :] = 0.225
182 |
183 | if not opt.pretraining and not opt.testing:
184 | percep_model = models.__dict__['vgg19'](pretrained=True)
185 | percep_model.features = nn.Sequential(
186 |             *list(percep_model.features.children())[:-14])  # keep only the early VGG19 feature layers for the perceptual loss
187 | percep_model.eval()
188 |
189 | criterion = nn.MSELoss(size_average=False)
190 | lr = opt.lr
191 |
192 | if cuda:
193 | model = torch.nn.DataParallel(model).cuda()
194 | criterion = criterion.cuda()
195 | if not opt.pretraining and not opt.testing:
196 | percep_model = percep_model.cuda()
197 | mean = Variable(mean).cuda()
198 | std = Variable(std).cuda()
199 | tmean = Variable(tmean).cuda()
200 | tstd = Variable(tstd).cuda()
201 |
202 | if opt.pretrained:
203 | if os.path.isfile(opt.pretrained):
204 | print('=> loading model {}'.format(opt.pretrained))
205 | weights = torch.load(opt.pretrained)
206 | model.load_state_dict(weights['model'].state_dict())
207 | else:
208 | print('=> no model found at {}'.format(opt.pretrained))
209 |
210 | if opt.testing:
211 | test_image = Variable(test_image)
212 | if cuda:
213 | test_image = test_image.cuda()
214 |
215 | test_image = test_image.sub(tmean).div(tstd)
216 | gen = model(test_image)
217 | gened = torch.clamp(gen.mul(std).add(mean).mul(255.0), min=0., max=255.0).byte()[
218 | 0].data.cpu().numpy().transpose(1, 2, 0)
219 | gened = Image.fromarray(gened)
220 | gened.save('testing-sr.jpg')
221 |
222 | else:
223 | train_set = dataprovider.DatasetFromDir(
224 | opt.image_dir,
225 | samples=opt.images,
226 | width=WIDTH,
227 | height=HEIGHT)
228 |
229 | training_data_loader = DataLoader(
230 | dataset=train_set,
231 | num_workers=opt.threads,
232 | batch_size=opt.batchSize,
233 | shuffle=True)
234 |
235 | optimizer = optim.Adam(model.parameters(), lr=lr)
236 |
237 | counter = 0
238 | for epoch in range(opt.nEpochs):
239 |
240 | loss_sum = Variable(torch.zeros(1), requires_grad=False)
241 | if cuda:
242 | loss_sum = loss_sum.cuda()
243 |
244 | for iteration, batch in enumerate(training_data_loader, 1):
245 | counter = counter + 1
246 | input, target = (
247 | Variable(batch[0]),
248 | Variable(batch[1], requires_grad=False))
249 |
250 | if cuda:
251 | input = input.cuda()
252 | target = target.cuda()
253 |
254 | input = input.sub(tmean).div(tstd)
255 | target = target.sub(mean).div(std)
256 |
257 | gen = model(input)
258 | optimizer.zero_grad()
259 | loss = criterion(gen, target)
260 |
261 | if not opt.pretraining:
262 | out_percep = percep_model.features(gen)
263 | out_percep_real = Variable(percep_model.features(
264 | target).data, requires_grad=False)
265 | percep_loss = criterion(out_percep, out_percep_real)
266 | # loss_relation = percep_loss.div(loss)
267 |
268 |                 loss = loss.add(percep_loss.mul(opt.percep_scale))
269 |
270 | loss.backward()
271 | nn.utils.clip_grad_norm(model.parameters(), opt.clip)
272 | loss_sum.add_(loss)
273 | optimizer.step()
274 |
275 | if counter % 400 == 0:
276 | print('sum_of_loss = {}'.format(
277 | loss_sum.data.select(0, 0)))
278 | loss_sum = Variable(torch.zeros(1), requires_grad=False)
279 | if cuda:
280 | loss_sum = loss_sum.cuda()
281 |
282 | save_checkpoint(model, epoch)
283 | input = torch.clamp(input.mul(tstd).add(tmean).mul(
284 | 255.0), min=0., max=255.0).byte()[0].data.cpu().numpy().transpose(1, 2, 0)
285 | inp = Image.fromarray(input)
286 | label = torch.clamp(target.mul(std).add(mean).mul(255.0), min=0., max=255.0).byte()[
287 | 0].data.cpu().numpy().transpose(1, 2, 0)
288 | lab = Image.fromarray(label)
289 | gened = torch.clamp(gen.mul(std).add(mean).mul(255.0), min=0., max=255.0).byte()[
290 | 0].data.cpu().numpy().transpose(1, 2, 0)
291 | gened = Image.fromarray(gened)
292 | inp.save('input.jpg')
293 | lab.save('gt.jpg')
294 | gened.save('sr.jpg')
295 |
296 |
297 | def save_checkpoint(model, epoch):
298 | model_out_path = 'model/' + 'model_epoch_{}.pth'.format(epoch)
299 | state = {'epoch': epoch, 'model': model}
300 | if not os.path.exists('model/'):
301 | os.makedirs('model/')
302 |
303 | torch.save(state, model_out_path)
304 |
305 | print('Checkpoint saved to {}'.format(model_out_path))
306 |
307 |
308 | if __name__ == '__main__':
309 | main()
310 |
--------------------------------------------------------------------------------
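A minimal loading sketch for the checkpoints written by save_checkpoint() above. The path is illustrative; note that the whole nn.Module is pickled (not just a state dict), so a checkpoint saved from a DataParallel run would need the 'module.' key prefix stripped, and recent PyTorch versions may require torch.load(..., weights_only=False):

import torch

from model import SRResNet

checkpoint = torch.load('model/model_epoch_80.pth', map_location='cpu')
print(checkpoint['epoch'])                             # epoch at which the checkpoint was written

net = SRResNet()
net.load_state_dict(checkpoint['model'].state_dict())  # pull the weights out of the pickled module
net.eval()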
/superresolution.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kayr7/PyTorch-SRGAN/a62acd8abaef76269c2b206b96e26a488f4a2114/superresolution.jpg
--------------------------------------------------------------------------------