├── LICENSE ├── README.md ├── data.py ├── examples ├── kennedy.png ├── kennedy_zssr.png ├── lincoln.png └── lincoln_zssr.png ├── net.py ├── source_target_transforms.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unofficial PyTorch implementation of "Zero-Shot" Super-Resolution using Deep Internal Learning 2 | 3 | Unofficial Implementation of *1712.06087 "Zero-Shot" Super-Resolution using Deep Internal Learning by Assaf Shocher, Nadav Cohen, Michal Irani.* 4 | 5 | Official Project page: http://www.wisdom.weizmann.ac.il/~vision/zssr/ 6 | 7 | Paper: https://arxiv.org/abs/1712.06087 8 | 9 | 10 | ---------- 11 | 12 | 13 | This trains a deep neural network to perform super resolution using a single image. 14 | 15 | The network is not trained on additional images, and only uses information from within the target image. 16 | Pairs of high resolution and low resolution patches are sampled from the image, and the network fits their difference. 
17 | 18 | ![Low resolution](https://github.com/jacobgil/pytorch-zssr/blob/master/examples/kennedy.png?raw=true) 19 | ![ZSSR](https://github.com/jacobgil/pytorch-zssr/blob/master/examples/kennedy_zssr.png?raw=true) 20 | 21 | ![Low resolution](https://github.com/jacobgil/pytorch-zssr/blob/master/examples/lincoln.png?raw=true) 22 | ![ZSSR](https://github.com/jacobgil/pytorch-zssr/blob/master/examples/lincoln_zssr.png?raw=true) 23 | 24 | 25 | ---------- 26 | 27 | 28 | TODO: 29 | - Implement additional augmentation using the "Geometric self ensemble" mentioned in the paper. 30 | - Implement gradual increase of the super resolution factor as described in the paper. 31 | - Support for arbitrary kernel estimation and sampling with arbitrary kernels. The current implementation only resamples the images using bicubic interpolation. 32 | 33 | Deviations from paper: 34 | - Instead of fitting the loss and analyzing its standard deviation, the network is trained for a constant number of batches. The learning rate is divided by 10 every 10,000 iterations. 35 | 36 | 37 | # Usage 38 | Example: ```python train.py --img img.png``` 39 | ``` 40 | usage: train.py [-h] [--num_batches NUM_BATCHES] [--crop CROP] [--lr LR] 41 | [--factor FACTOR] [--img IMG] 42 | 43 | optional arguments: 44 | -h, --help show this help message and exit 45 | --num_batches NUM_BATCHES 46 | Number of batches to run 47 | --crop CROP Random crop size 48 | --lr LR Base learning rate for Adam 49 | --factor FACTOR Interpolation factor.
50 | --img IMG Path to input img 51 | ``` 52 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import numpy as np 3 | import sys 4 | import random 5 | import torch 6 | from torchvision import transforms 7 | from torchvision.transforms import functional as F 8 | import numbers 9 | import cv2 10 | from source_target_transforms import * 11 | 12 | class DataSampler: 13 | def __init__(self, img, sr_factor, crop_size): 14 | self.img = img 15 | self.sr_factor = sr_factor 16 | self.pairs = self.create_hr_lr_pairs() 17 | sizes = np.float32([x[0].size[0]*x[0].size[1] / float(img.size[0]*img.size[1]) \ 18 | for x in self.pairs]) 19 | self.pair_probabilities = sizes / np.sum(sizes) 20 | 21 | self.transform = transforms.Compose([ 22 | RandomRotationFromSequence([0, 90, 180, 270]), 23 | RandomHorizontalFlip(), 24 | RandomVerticalFlip(), 25 | RandomCrop(crop_size), 26 | ToTensor()]) 27 | 28 | def create_hr_lr_pairs(self): 29 | smaller_side = min(self.img.size[0 : 2]) 30 | larger_side = max(self.img.size[0 : 2]) 31 | 32 | factors = [] 33 | for i in range(smaller_side//5, smaller_side+1): 34 | downsampled_smaller_side = i 35 | zoom = float(downsampled_smaller_side)/smaller_side 36 | downsampled_larger_side = round(larger_side*zoom) 37 | if downsampled_smaller_side%self.sr_factor==0 and \ 38 | downsampled_larger_side%self.sr_factor==0: 39 | factors.append(zoom) 40 | 41 | pairs = [] 42 | for zoom in factors: 43 | hr = self.img.resize((int(self.img.size[0]*zoom), \ 44 | int(self.img.size[1]*zoom)), \ 45 | resample=PIL.Image.BICUBIC) 46 | 47 | lr = hr.resize((int(hr.size[0]/self.sr_factor), \ 48 | int(hr.size[1]/self.sr_factor)), 49 | resample=PIL.Image.BICUBIC) 50 | 51 | lr = lr.resize(hr.size, resample=PIL.Image.BICUBIC) 52 | 53 | pairs.append((hr, lr)) 54 | 55 | return pairs 56 | 57 | def generate_data(self): 58 | while True: 59 | hr, lr = 
random.choices(self.pairs, weights=self.pair_probabilities, k=1)[0] 60 | hr_tensor, lr_tensor = self.transform((hr, lr)) 61 | hr_tensor = torch.unsqueeze(hr_tensor, 0) 62 | lr_tensor = torch.unsqueeze(lr_tensor, 0) 63 | yield hr_tensor, lr_tensor 64 | 65 | if __name__ == '__main__': 66 | img = PIL.Image.open(sys.argv[1]) 67 | sampler = DataSampler(img, 2) 68 | for x in sampler.generate_data(): 69 | hr, lr = x 70 | hr = hr.numpy().transpose((1, 2, 0)) 71 | lr = lr.numpy().transpose((1, 2, 0)) -------------------------------------------------------------------------------- /examples/kennedy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-zssr/ee2364def43455829dc108037d2af0aaf0b6d69c/examples/kennedy.png -------------------------------------------------------------------------------- /examples/kennedy_zssr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-zssr/ee2364def43455829dc108037d2af0aaf0b6d69c/examples/kennedy_zssr.png -------------------------------------------------------------------------------- /examples/lincoln.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-zssr/ee2364def43455829dc108037d2af0aaf0b6d69c/examples/lincoln.png -------------------------------------------------------------------------------- /examples/lincoln_zssr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobgil/pytorch-zssr/ee2364def43455829dc108037d2af0aaf0b6d69c/examples/lincoln_zssr.png -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ZSSRNet(nn.Module): 5 | def __init__(self, 
input_channels=3, kernel_size=3, channels=64): 6 | super(ZSSRNet, self).__init__() 7 | 8 | self.conv0 = nn.Conv2d(input_channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 9 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 10 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 11 | self.conv3 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 12 | self.conv4 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 13 | self.conv5 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 14 | self.conv6 = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 15 | self.conv7 = nn.Conv2d(channels, input_channels, kernel_size=kernel_size, padding=kernel_size//2, bias=True) 16 | 17 | self.relu = nn.ReLU() 18 | 19 | def forward(self, x): 20 | x = self.relu(self.conv0(x)) 21 | x = self.relu(self.conv1(x)) 22 | x = self.relu(self.conv2(x)) 23 | x = self.relu(self.conv3(x)) 24 | x = self.relu(self.conv4(x)) 25 | x = self.relu(self.conv5(x)) 26 | x = self.relu(self.conv6(x)) 27 | x = self.conv7(x) 28 | 29 | return x -------------------------------------------------------------------------------- /source_target_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import PIL 3 | import random 4 | from torchvision import transforms 5 | from torchvision.transforms import functional as F 6 | import numbers 7 | 8 | class RandomRotationFromSequence(object): 9 | """Rotate the image by angle. 10 | Args: 11 | degrees (sequence or float or int): Range of degrees to select from. 12 | If degrees is a number instead of sequence like (min, max), the range of degrees 13 | will be (-degrees, +degrees). 
14 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 15 | An optional resampling filter. 16 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 17 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 18 | expand (bool, optional): Optional expansion flag. 19 | If true, expands the output to make it large enough to hold the entire rotated image. 20 | If false or omitted, make the output image the same size as the input image. 21 | Note that the expand flag assumes rotation around the center and no translation. 22 | center (2-tuple, optional): Optional center of rotation. 23 | Origin is the upper left corner. 24 | Default is the center of the image. 25 | """ 26 | 27 | def __init__(self, degrees, resample=False, expand=False, center=None): 28 | self.degrees = degrees 29 | self.resample = resample 30 | self.expand = expand 31 | self.center = center 32 | 33 | @staticmethod 34 | def get_params(degrees): 35 | """Get parameters for ``rotate`` for a random rotation. 36 | Returns: 37 | sequence: params to be passed to ``rotate`` for random rotation. 38 | """ 39 | angle = np.random.choice(degrees) 40 | return angle 41 | 42 | def __call__(self, data): 43 | """ 44 | img (PIL Image): Image to be rotated. 45 | Returns: 46 | PIL Image: Rotated image. 47 | """ 48 | hr, lr = data 49 | angle = self.get_params(self.degrees) 50 | return F.rotate(hr, angle, self.resample, self.expand, self.center), \ 51 | F.rotate(lr, angle, self.resample, self.expand, self.center) 52 | 53 | class RandomHorizontalFlip(object): 54 | """Horizontally flip the given PIL Image randomly with a probability of 0.5.""" 55 | 56 | def __call__(self, data): 57 | """ 58 | Args: 59 | img (PIL Image): Image to be flipped. 60 | Returns: 61 | PIL Image: Randomly flipped image. 
62 | """ 63 | hr, lr = data 64 | if random.random() < 0.5: 65 | return F.hflip(hr), F.hflip(lr) 66 | return hr, lr 67 | 68 | class RandomVerticalFlip(object): 69 | """Vertically flip the given PIL Image randomly with a probability of 0.5.""" 70 | 71 | def __call__(self, data): 72 | """ 73 | Args: 74 | img (PIL Image): Image to be flipped. 75 | Returns: 76 | PIL Image: Randomly flipped image. 77 | """ 78 | hr, lr = data 79 | if random.random() < 0.5: 80 | return F.vflip(hr), F.vflip(lr) 81 | return hr, lr 82 | 83 | class RandomCrop(object): 84 | """Crop the given PIL Image at a random location. 85 | Args: 86 | size (sequence or int): Desired output size of the crop. If size is an 87 | int instead of sequence like (h, w), a square crop (size, size) is 88 | made. 89 | padding (int or sequence, optional): Optional padding on each border 90 | of the image. Default is 0, i.e no padding. If a sequence of length 91 | 4 is provided, it is used to pad left, top, right, bottom borders 92 | respectively. 93 | """ 94 | 95 | def __init__(self, size, padding=0): 96 | if isinstance(size, numbers.Number): 97 | self.size = (int(size), int(size)) 98 | else: 99 | self.size = size 100 | self.padding = padding 101 | 102 | @staticmethod 103 | def get_params(data, output_size): 104 | """Get parameters for ``crop`` for a random crop. 105 | Args: 106 | img (PIL Image): Image to be cropped. 107 | output_size (tuple): Expected output size of the crop. 108 | Returns: 109 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 110 | """ 111 | hr, lr = data 112 | w, h = hr.size 113 | th, tw = output_size 114 | if w == tw or h == th: 115 | return 0, 0, h, w 116 | 117 | if w < tw or h < th: 118 | th, tw = h//2, w//2 119 | 120 | i = random.randint(0, h - th) 121 | j = random.randint(0, w - tw) 122 | return i, j, th, tw 123 | 124 | def __call__(self, data): 125 | """ 126 | Args: 127 | img (PIL Image): Image to be cropped. 128 | Returns: 129 | PIL Image: Cropped image. 
130 | """ 131 | hr, lr = data 132 | if self.padding > 0: 133 | hr = F.pad(hr, self.padding) 134 | lr = F.pad(lr, self.padding) 135 | 136 | i, j, h, w = self.get_params(data, self.size) 137 | return F.crop(hr, i, j, h, w), F.crop(lr, i, j, h, w) 138 | 139 | class ToTensor(object): 140 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 141 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 142 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 143 | """ 144 | 145 | def __call__(self, data): 146 | """ 147 | Args: 148 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 149 | Returns: 150 | Tensor: Converted image. 151 | """ 152 | hr, lr = data 153 | return F.to_tensor(hr), F.to_tensor(lr) 154 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from net import ZSSRNet 3 | from data import DataSampler 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from torch.nn import init 9 | import PIL 10 | import sys 11 | from torchvision import transforms 12 | import tqdm 13 | import argparse 14 | 15 | def weights_init_kaiming(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Conv') != -1: 18 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 19 | elif classname.find('Linear') != -1: 20 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 21 | elif classname.find('BatchNorm2d') != -1: 22 | init.normal(m.weight.data, 1.0, 0.02) 23 | init.constant(m.bias.data, 0.0) 24 | 25 | 26 | def adjust_learning_rate(optimizer, new_lr): 27 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 28 | for param_group in optimizer.param_groups: 29 | param_group['lr'] = new_lr 30 | 31 | def train(model, img, sr_factor, num_batches, learning_rate, crop_size): 32 | loss 
= nn.L1Loss() 33 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 34 | sampler = DataSampler(img, sr_factor, crop_size) 35 | model.cuda() 36 | with tqdm.tqdm(total=num_batches, miniters=1, mininterval=0) as progress: 37 | for iter, (hr, lr) in enumerate(sampler.generate_data()): 38 | model.zero_grad() 39 | 40 | lr = Variable(lr).cuda() 41 | hr = Variable(hr).cuda() 42 | 43 | output = model(lr) + lr 44 | error = loss(output, hr) 45 | 46 | cpu_loss = error.data.cpu().numpy()[0] 47 | 48 | progress.set_description("Iteration: {iter} Loss: {loss}, Learning Rate: {lr}".format( \ 49 | iter=iter, loss=cpu_loss, lr=learning_rate)) 50 | progress.update() 51 | 52 | if iter > 0 and iter % 10000 == 0: 53 | learning_rate = learning_rate / 10 54 | adjust_learning_rate(optimizer, new_lr=learning_rate) 55 | print("Learning rate reduced to {lr}".format(lr=learning_rate) ) 56 | 57 | error.backward() 58 | optimizer.step() 59 | 60 | if iter > num_batches: 61 | print('Done training.') 62 | break 63 | 64 | 65 | def test(model, img, sr_factor): 66 | model.eval() 67 | 68 | img = img.resize((int(img.size[0]*sr_factor), \ 69 | int(img.size[1]*sr_factor)), resample=PIL.Image.BICUBIC) 70 | img.save('low_res.png') 71 | 72 | img = transforms.ToTensor()(img) 73 | img = torch.unsqueeze(img, 0) 74 | input = Variable(img.cuda()) 75 | residual = model(input) 76 | output = input + residual 77 | 78 | output = output.cpu().data[0, :, :, :] 79 | o = output.numpy() 80 | o[np.where(o < 0)] = 0.0 81 | o[np.where(o > 1)] = 1.0 82 | output = torch.from_numpy(o) 83 | output = transforms.ToPILImage()(output) 84 | output.save('zssr.png') 85 | 86 | def get_args(): 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--num_batches', type=int, default=15000, \ 89 | help='Number of batches to run') 90 | parser.add_argument('--crop', type=int, default=128, \ 91 | help='Random crop size') 92 | parser.add_argument('--lr', type=float, default=0.00001, \ 93 | help='Base learning rate for Adam') 
94 | parser.add_argument('--factor', type=int, default=2, \ 95 | help='Interpolation factor.') 96 | parser.add_argument('--img', type=str, help='Path to input img') 97 | 98 | args = parser.parse_args() 99 | 100 | return args 101 | 102 | if __name__ == '__main__': 103 | args = get_args() 104 | 105 | img = PIL.Image.open(args.img) 106 | num_channels = len(np.array(img).shape) 107 | if num_channels == 3: 108 | model = ZSSRNet(input_channels = 3) 109 | elif num_channels == 2: 110 | model = ZSSRNet(input_channels = 1) 111 | else: 112 | print("Expecting RGB or gray image, instead got", img.size) 113 | sys.exit(1) 114 | 115 | # Weight initialization 116 | model.apply(weights_init_kaiming) 117 | 118 | train(model, img, args.factor, args.num_batches, args.lr, args.crop) 119 | test(model, img, args.factor) --------------------------------------------------------------------------------