├── LICENSE ├── README.md ├── create_data_lists.py ├── datasets.py ├── eval.py ├── img ├── baboon.png ├── cyberpunk1.png ├── cyberpunk4.png ├── cyberpunk6.png ├── cyberpunk7.png ├── cyberpunk8.png ├── cyberpunk9.png ├── discriminator.PNG ├── discriminator_forward_pass_1.PNG ├── discriminator_forward_pass_2.PNG ├── discriminator_update_1.PNG ├── discriminator_update_2.PNG ├── earth.png ├── flowers.png ├── generator_discriminator_forward_pass.PNG ├── generator_forward_pass.PNG ├── generator_update_1.PNG ├── generator_update_2.PNG ├── incomplete.jpg ├── interleaved_training.PNG ├── learn_part_network.PNG ├── lenna.png ├── loss_function.PNG ├── man.png ├── perceptual_loss.PNG ├── pixel_shuffle.PNG ├── pixel_shuffle_layer.PNG ├── samurai.png ├── samurai_bicubic.png ├── samurai_bilinear.png ├── samurai_hr.png ├── samurai_lanczos.png ├── samurai_lr.png ├── samurai_nn.png ├── samurai_sr.png ├── samurai_srgan.png ├── samurai_srresnet.png ├── skip_connections_1.PNG ├── skip_connections_2.PNG ├── skip_connections_3.PNG ├── skip_connections_4.PNG ├── srgan.PNG ├── srresnet.PNG ├── srresnet_forward_pass.PNG ├── srresnet_update.PNG ├── subpixel_convolution.PNG ├── tiger.png ├── upsampling_bilinear_1.PNG ├── upsampling_bilinear_2.PNG ├── upsampling_empty.PNG ├── upsampling_lr.PNG ├── upsampling_methods.PNG ├── upsampling_nn.PNG ├── upsampling_sr.PNG ├── vgg_forward_pass.PNG ├── why_loss_function.PNG ├── why_not_mse_1.PNG ├── why_not_mse_2.PNG ├── why_not_mse_3.PNG └── zebra.png ├── models.py ├── super_resolve.py ├── train_srgan.py ├── train_srresnet.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sagar Vinodababu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a **[PyTorch](https://pytorch.org) Tutorial to Super-Resolution**. 2 | 3 | This is also a tutorial for learning about **GANs** and how they work, regardless of intended task or application. 4 | 5 | This is the fifth in [a series of tutorials](https://github.com/sgrvinod/Deep-Tutorials-for-PyTorch) I'm writing about _implementing_ cool models on your own with the amazing PyTorch library. 6 | 7 | Basic knowledge of PyTorch, convolutional neural networks is assumed. 
8 | 9 | If you're new to PyTorch, first read [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) and [Learning PyTorch with Examples](https://pytorch.org/tutorials/beginner/pytorch_with_examples.html). 10 | 11 | Questions, suggestions, or corrections can be posted as [issues](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/issues). 12 | 13 | I'm using `PyTorch 1.4` in `Python 3.6`. 14 | 15 | # Contents 16 | 17 | [***Objective***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#objective) 18 | 19 | [***Concepts***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#concepts) 20 | 21 | [***Overview***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#overview) 22 | 23 | [***Implementation***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#implementation) 24 | 25 | [***Training***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#training) 26 | 27 | [***Evaluation***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#evaluation) 28 | 29 | [***Inference***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#inference) 30 | 31 | [***Frequently Asked Questions***](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution#frequently-asked-questions) 32 | 33 | # Objective 34 | 35 | **To build a model that can realistically increase image resolution.** 36 | 37 | Super-resolution (SR) models essentially hallucinate new pixels where previously there were none. In this tutorial, we will try to _quadruple_ the dimensions of an image i.e. increase the number of pixels by 16x! 38 | 39 | We're going to be implementing [_Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network_](https://arxiv.org/abs/1609.04802). It's not just that the results are very impressive... it's also a great introduction to GANs! 40 | 41 | We will train the two models described in the paper — the SRResNet, and the SRGAN which greatly improves upon the former through adversarial training. 42 | 43 | Before you proceed, take a look at some examples generated from low-resolution images not seen during training. _Enhance!_ 44 | 45 | --- 46 | 47 |
871 | 872 |  873 | 874 | --- 875 | 876 | # Frequently Asked Questions 877 | 878 | I will populate this section over time from common questions asked in the [*Issues*](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/issues) section of this repository. 879 | 880 | **Why are super-resolved (SR) images from the Generator passed through the Discriminator twice? Why not simply *reuse* the output of the Discriminator from the first time?** 881 | 882 | Yes, we do discriminate SR images *twice* – 883 | 884 | - When training the Generator, we pass SR images through the Discriminator, and use the Discriminator's output in the adversarial loss function with the incorrect but desired $HR$ label. 885 | 886 | - When training the Discriminator, we pass SR images through the Discriminator, and use the Discriminator's output to calculate the binary cross entropy loss with the correct and desired $SR$ label. 887 | 888 | In the first instance, our goal is to update the parameters $\theta_G$ of the Generator using the gradients of the loss function with respect to $\theta_G$. And indeed, the Generator *is* a part of the computational graph over which we backpropagate gradients. 889 | 890 | In the second instance, our goal is to update only the parameters $\theta_D$ of the Discriminator, which are *upstream* of $\theta_G$ in the *backwards* direction as we backpropagate gradients. 891 | 892 | In other words, it is not necessary to calculate the gradients of the loss function with respect to $\theta_G$ when training the Discriminator, and there is *no* need for the Generator to be a part of the computational graph! Having it so would be expensive because backpropagation is expensive. Therefore, we *detach* the SR images from the computational graph in the second instance, causing it to become, essentially, an independent variable with no memory of the computational graph (i.e. the Generator) that led to its creation. 893 | 894 | This is why we forward-propagate twice – once with the SR images a part of the full SRGAN computational graph, *requiring* backpropagation across the Generator, and once with the SR images detached from the Generator's computational graph, *preventing* backpropagation across the Generator. 895 | 896 | Forward-propagating twice is *much* cheaper than backpropagating twice. 897 | 898 | **How does subpixel convolution compare with transposed convolution?** 899 | 900 | They seem rather similar to me, and should be able to achieve similar results. 901 | 902 | They can be mathematically equivalent if, for a desired upsampling factor $s$, and a kernel size $k$ used in the subpixel convolution, the kernel size for the transposed convolution is $sk$. The number of parameters in this case will also be the same – $ns^2 * i * k * k$ for the former and $n * i * sk * sk$ for the latter. 903 | 904 | However, there are indications from some people that subpixel convolution *is* superior in particular ways, although I do not understand why. See this [paper](https://arxiv.org/pdf/1609.07009.pdf), this [repository](https://github.com/atriumlts/subpixel), and this [Reddit thread](https://www.reddit.com/r/MachineLearning/comments/n5ru8r/d_subpixel_convolutions_vs_transposed_convolutions/). Perhaps the [original paper](https://arxiv.org/pdf/1609.05158.pdf) too. 
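To make the parameter-count comparison above concrete, here is a minimal PyTorch sketch (illustrative only, not part of this repository; the sizes i = n = 64, k = 3, s = 2 are made up, and the transposed convolution's padding is chosen so the two output sizes match for these particular k and s):

```python
import torch
from torch import nn

i, n, k, s = 64, 64, 3, 2  # input channels, output channels, kernel size, upscaling factor

# Subpixel convolution: an ordinary convolution producing n * s^2 channels, then a pixel shuffle
subpixel = nn.Sequential(nn.Conv2d(i, n * s ** 2, kernel_size=k, padding=k // 2),
                         nn.PixelShuffle(upscale_factor=s))

# Transposed convolution with kernel size s * k and stride s
# (padding=2 makes the output size match the subpixel branch for this k and s)
transposed = nn.ConvTranspose2d(i, n, kernel_size=s * k, stride=s, padding=2)

x = torch.randn(1, i, 24, 24)
print(subpixel(x).shape)    # torch.Size([1, 64, 48, 48])
print(transposed(x).shape)  # torch.Size([1, 64, 48, 48])

# Weight counts are identical: n * s^2 * i * k * k == n * i * (s*k) * (s*k) == 147456
print(sum(p.numel() for name, p in subpixel.named_parameters() if name.endswith('weight')))
print(sum(p.numel() for name, p in transposed.named_parameters() if name.endswith('weight')))
```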
905 | 906 | Obviously, being mathematically equivalent does not mean they are optimizable or learnable or efficient in the same way, but if anyone can knows *why* subpixel convolution can yield superior results, please open an issue and let me know so I can add this information to this tutorial. 907 | -------------------------------------------------------------------------------- /create_data_lists.py: -------------------------------------------------------------------------------- 1 | from utils import create_data_lists 2 | 3 | if __name__ == '__main__': 4 | create_data_lists(train_folders=['/media/ssd/sr data/train2014', 5 | '/media/ssd/sr data/val2014'], 6 | test_folders=['/media/ssd/sr data/BSDS100', 7 | '/media/ssd/sr data/Set5', 8 | '/media/ssd/sr data/Set14'], 9 | min_size=100, 10 | output_folder='./') 11 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import json 4 | import os 5 | from PIL import Image 6 | from utils import ImageTransforms 7 | 8 | 9 | class SRDataset(Dataset): 10 | """ 11 | A PyTorch Dataset to be used by a PyTorch DataLoader. 12 | """ 13 | 14 | def __init__(self, data_folder, split, crop_size, scaling_factor, lr_img_type, hr_img_type, test_data_name=None): 15 | """ 16 | :param data_folder: # folder with JSON data files 17 | :param split: one of 'train' or 'test' 18 | :param crop_size: crop size of target HR images 19 | :param scaling_factor: the input LR images will be downsampled from the target HR images by this factor; the scaling done in the super-resolution 20 | :param lr_img_type: the format for the LR image supplied to the model; see convert_image() in utils.py for available formats 21 | :param hr_img_type: the format for the HR image supplied to the model; see convert_image() in utils.py for available formats 22 | :param test_data_name: if this is the 'test' split, which test dataset? (for example, "Set14") 23 | """ 24 | 25 | self.data_folder = data_folder 26 | self.split = split.lower() 27 | self.crop_size = int(crop_size) 28 | self.scaling_factor = int(scaling_factor) 29 | self.lr_img_type = lr_img_type 30 | self.hr_img_type = hr_img_type 31 | self.test_data_name = test_data_name 32 | 33 | assert self.split in {'train', 'test'} 34 | if self.split == 'test' and self.test_data_name is None: 35 | raise ValueError("Please provide the name of the test dataset!") 36 | assert lr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'} 37 | assert hr_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'} 38 | 39 | # If this is a training dataset, then crop dimensions must be perfectly divisible by the scaling factor 40 | # (If this is a test dataset, images are not cropped to a fixed size, so this variable isn't used) 41 | if self.split == 'train': 42 | assert self.crop_size % self.scaling_factor == 0, "Crop dimensions are not perfectly divisible by scaling factor! This will lead to a mismatch in the dimensions of the original HR patches and their super-resolved (SR) versions!" 
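# Illustrative note (not in the original file): with the defaults used elsewhere in this tutorial,
# crop_size = 96 and scaling_factor = 4, every HR training patch is 96x96 and its LR counterpart is
# 96 / 4 = 24x24, so the assert above guarantees the SR output (24 * 4 = 96) matches the HR patch exactly.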
43 | 44 | # Read list of image-paths 45 | if self.split == 'train': 46 | with open(os.path.join(data_folder, 'train_images.json'), 'r') as j: 47 | self.images = json.load(j) 48 | else: 49 | with open(os.path.join(data_folder, self.test_data_name + '_test_images.json'), 'r') as j: 50 | self.images = json.load(j) 51 | 52 | # Select the correct set of transforms 53 | self.transform = ImageTransforms(split=self.split, 54 | crop_size=self.crop_size, 55 | scaling_factor=self.scaling_factor, 56 | lr_img_type=self.lr_img_type, 57 | hr_img_type=self.hr_img_type) 58 | 59 | def __getitem__(self, i): 60 | """ 61 | This method is required to be defined for use in the PyTorch DataLoader. 62 | 63 | :param i: index to retrieve 64 | :return: the 'i'th pair LR and HR images to be fed into the model 65 | """ 66 | # Read image 67 | img = Image.open(self.images[i], mode='r') 68 | img = img.convert('RGB') 69 | if img.width <= 96 or img.height <= 96: 70 | print(self.images[i], img.width, img.height) 71 | lr_img, hr_img = self.transform(img) 72 | 73 | return lr_img, hr_img 74 | 75 | def __len__(self): 76 | """ 77 | This method is required to be defined for use in the PyTorch DataLoader. 78 | 79 | :return: size of this data (in number of images) 80 | """ 81 | return len(self.images) 82 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 3 | from datasets import SRDataset 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | # Data 8 | data_folder = "./" 9 | test_data_names = ["Set5", "Set14", "BSDS100"] 10 | 11 | # Model checkpoints 12 | srgan_checkpoint = "./checkpoint_srgan.pth.tar" 13 | srresnet_checkpoint = "./checkpoint_srresnet.pth.tar" 14 | 15 | # Load model, either the SRResNet or the SRGAN 16 | # srresnet = torch.load(srresnet_checkpoint)['model'].to(device) 17 | # srresnet.eval() 18 | # model = srresnet 19 | srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device) 20 | srgan_generator.eval() 21 | model = srgan_generator 22 | 23 | # Evaluate 24 | for test_data_name in test_data_names: 25 | print("\nFor %s:\n" % test_data_name) 26 | 27 | # Custom dataloader 28 | test_dataset = SRDataset(data_folder, 29 | split='test', 30 | crop_size=0, 31 | scaling_factor=4, 32 | lr_img_type='imagenet-norm', 33 | hr_img_type='[-1, 1]', 34 | test_data_name=test_data_name) 35 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, 36 | pin_memory=True) 37 | 38 | # Keep track of the PSNRs and the SSIMs across batches 39 | PSNRs = AverageMeter() 40 | SSIMs = AverageMeter() 41 | 42 | # Prohibit gradient computation explicitly because I had some problems with memory 43 | with torch.no_grad(): 44 | # Batches 45 | for i, (lr_imgs, hr_imgs) in enumerate(test_loader): 46 | # Move to default device 47 | lr_imgs = lr_imgs.to(device) # (batch_size (1), 3, w / 4, h / 4), imagenet-normed 48 | hr_imgs = hr_imgs.to(device) # (batch_size (1), 3, w, h), in [-1, 1] 49 | 50 | # Forward prop. 
51 | sr_imgs = model(lr_imgs) # (1, 3, w, h), in [-1, 1] 52 | 53 | # Calculate PSNR and SSIM 54 | sr_imgs_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel').squeeze( 55 | 0) # (w, h), in y-channel 56 | hr_imgs_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel').squeeze(0) # (w, h), in y-channel 57 | psnr = peak_signal_noise_ratio(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(), 58 | data_range=255.) 59 | ssim = structural_similarity(hr_imgs_y.cpu().numpy(), sr_imgs_y.cpu().numpy(), 60 | data_range=255.) 61 | PSNRs.update(psnr, lr_imgs.size(0)) 62 | SSIMs.update(ssim, lr_imgs.size(0)) 63 | 64 | # Print average PSNR and SSIM 65 | print('PSNR - {psnrs.avg:.3f}'.format(psnrs=PSNRs)) 66 | print('SSIM - {ssims.avg:.3f}'.format(ssims=SSIMs)) 67 | 68 | print("\n") 69 | -------------------------------------------------------------------------------- /img/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/baboon.png -------------------------------------------------------------------------------- /img/cyberpunk1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/cyberpunk1.png -------------------------------------------------------------------------------- /img/cyberpunk4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/cyberpunk4.png -------------------------------------------------------------------------------- /img/cyberpunk6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/cyberpunk6.png -------------------------------------------------------------------------------- /img/cyberpunk7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/cyberpunk7.png -------------------------------------------------------------------------------- /img/cyberpunk8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/cyberpunk8.png -------------------------------------------------------------------------------- /img/cyberpunk9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/cyberpunk9.png -------------------------------------------------------------------------------- /img/discriminator.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/discriminator.PNG -------------------------------------------------------------------------------- /img/discriminator_forward_pass_1.PNG: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/discriminator_forward_pass_1.PNG -------------------------------------------------------------------------------- /img/discriminator_forward_pass_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/discriminator_forward_pass_2.PNG -------------------------------------------------------------------------------- /img/discriminator_update_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/discriminator_update_1.PNG -------------------------------------------------------------------------------- /img/discriminator_update_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/discriminator_update_2.PNG -------------------------------------------------------------------------------- /img/earth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/earth.png -------------------------------------------------------------------------------- /img/flowers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/flowers.png -------------------------------------------------------------------------------- /img/generator_discriminator_forward_pass.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/generator_discriminator_forward_pass.PNG -------------------------------------------------------------------------------- /img/generator_forward_pass.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/generator_forward_pass.PNG -------------------------------------------------------------------------------- /img/generator_update_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/generator_update_1.PNG -------------------------------------------------------------------------------- /img/generator_update_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/generator_update_2.PNG -------------------------------------------------------------------------------- /img/incomplete.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/incomplete.jpg -------------------------------------------------------------------------------- /img/interleaved_training.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/interleaved_training.PNG -------------------------------------------------------------------------------- /img/learn_part_network.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/learn_part_network.PNG -------------------------------------------------------------------------------- /img/lenna.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/lenna.png -------------------------------------------------------------------------------- /img/loss_function.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/loss_function.PNG -------------------------------------------------------------------------------- /img/man.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/man.png -------------------------------------------------------------------------------- /img/perceptual_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/perceptual_loss.PNG -------------------------------------------------------------------------------- /img/pixel_shuffle.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/pixel_shuffle.PNG -------------------------------------------------------------------------------- /img/pixel_shuffle_layer.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/pixel_shuffle_layer.PNG -------------------------------------------------------------------------------- /img/samurai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai.png -------------------------------------------------------------------------------- /img/samurai_bicubic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_bicubic.png -------------------------------------------------------------------------------- 
/img/samurai_bilinear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_bilinear.png -------------------------------------------------------------------------------- /img/samurai_hr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_hr.png -------------------------------------------------------------------------------- /img/samurai_lanczos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_lanczos.png -------------------------------------------------------------------------------- /img/samurai_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_lr.png -------------------------------------------------------------------------------- /img/samurai_nn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_nn.png -------------------------------------------------------------------------------- /img/samurai_sr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_sr.png -------------------------------------------------------------------------------- /img/samurai_srgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_srgan.png -------------------------------------------------------------------------------- /img/samurai_srresnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/samurai_srresnet.png -------------------------------------------------------------------------------- /img/skip_connections_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/skip_connections_1.PNG -------------------------------------------------------------------------------- /img/skip_connections_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/skip_connections_2.PNG -------------------------------------------------------------------------------- /img/skip_connections_3.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/skip_connections_3.PNG -------------------------------------------------------------------------------- /img/skip_connections_4.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/skip_connections_4.PNG -------------------------------------------------------------------------------- /img/srgan.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/srgan.PNG -------------------------------------------------------------------------------- /img/srresnet.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/srresnet.PNG -------------------------------------------------------------------------------- /img/srresnet_forward_pass.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/srresnet_forward_pass.PNG -------------------------------------------------------------------------------- /img/srresnet_update.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/srresnet_update.PNG -------------------------------------------------------------------------------- /img/subpixel_convolution.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/subpixel_convolution.PNG -------------------------------------------------------------------------------- /img/tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/tiger.png -------------------------------------------------------------------------------- /img/upsampling_bilinear_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_bilinear_1.PNG -------------------------------------------------------------------------------- /img/upsampling_bilinear_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_bilinear_2.PNG -------------------------------------------------------------------------------- /img/upsampling_empty.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_empty.PNG 
-------------------------------------------------------------------------------- /img/upsampling_lr.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_lr.PNG -------------------------------------------------------------------------------- /img/upsampling_methods.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_methods.PNG -------------------------------------------------------------------------------- /img/upsampling_nn.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_nn.PNG -------------------------------------------------------------------------------- /img/upsampling_sr.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/upsampling_sr.PNG -------------------------------------------------------------------------------- /img/vgg_forward_pass.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/vgg_forward_pass.PNG -------------------------------------------------------------------------------- /img/why_loss_function.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/why_loss_function.PNG -------------------------------------------------------------------------------- /img/why_not_mse_1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/why_not_mse_1.PNG -------------------------------------------------------------------------------- /img/why_not_mse_2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/why_not_mse_2.PNG -------------------------------------------------------------------------------- /img/why_not_mse_3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/why_not_mse_3.PNG -------------------------------------------------------------------------------- /img/zebra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sgrvinod/a-PyTorch-Tutorial-to-Super-Resolution/c30556bc79a9d43539f2d658722bcb7ff5011ddb/img/zebra.png -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import 
torchvision 4 | import math 5 | 6 | 7 | class ConvolutionalBlock(nn.Module): 8 | """ 9 | A convolutional block, comprising convolutional, BN, activation layers. 10 | """ 11 | 12 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, batch_norm=False, activation=None): 13 | """ 14 | :param in_channels: number of input channels 15 | :param out_channels: number of output channe;s 16 | :param kernel_size: kernel size 17 | :param stride: stride 18 | :param batch_norm: include a BN layer? 19 | :param activation: Type of activation; None if none 20 | """ 21 | super(ConvolutionalBlock, self).__init__() 22 | 23 | if activation is not None: 24 | activation = activation.lower() 25 | assert activation in {'prelu', 'leakyrelu', 'tanh'} 26 | 27 | # A container that will hold the layers in this convolutional block 28 | layers = list() 29 | 30 | # A convolutional layer 31 | layers.append( 32 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 33 | padding=kernel_size // 2)) 34 | 35 | # A batch normalization (BN) layer, if wanted 36 | if batch_norm is True: 37 | layers.append(nn.BatchNorm2d(num_features=out_channels)) 38 | 39 | # An activation layer, if wanted 40 | if activation == 'prelu': 41 | layers.append(nn.PReLU()) 42 | elif activation == 'leakyrelu': 43 | layers.append(nn.LeakyReLU(0.2)) 44 | elif activation == 'tanh': 45 | layers.append(nn.Tanh()) 46 | 47 | # Put together the convolutional block as a sequence of the layers in this container 48 | self.conv_block = nn.Sequential(*layers) 49 | 50 | def forward(self, input): 51 | """ 52 | Forward propagation. 53 | 54 | :param input: input images, a tensor of size (N, in_channels, w, h) 55 | :return: output images, a tensor of size (N, out_channels, w, h) 56 | """ 57 | output = self.conv_block(input) # (N, out_channels, w, h) 58 | 59 | return output 60 | 61 | 62 | class SubPixelConvolutionalBlock(nn.Module): 63 | """ 64 | A subpixel convolutional block, comprising convolutional, pixel-shuffle, and PReLU activation layers. 65 | """ 66 | 67 | def __init__(self, kernel_size=3, n_channels=64, scaling_factor=2): 68 | """ 69 | :param kernel_size: kernel size of the convolution 70 | :param n_channels: number of input and output channels 71 | :param scaling_factor: factor to scale input images by (along both dimensions) 72 | """ 73 | super(SubPixelConvolutionalBlock, self).__init__() 74 | 75 | # A convolutional layer that increases the number of channels by scaling factor^2, followed by pixel shuffle and PReLU 76 | self.conv = nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (scaling_factor ** 2), 77 | kernel_size=kernel_size, padding=kernel_size // 2) 78 | # These additional channels are shuffled to form additional pixels, upscaling each dimension by the scaling factor 79 | self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor) 80 | self.prelu = nn.PReLU() 81 | 82 | def forward(self, input): 83 | """ 84 | Forward propagation. 
85 | 86 | :param input: input images, a tensor of size (N, n_channels, w, h) 87 | :return: scaled output images, a tensor of size (N, n_channels, w * scaling factor, h * scaling factor) 88 | """ 89 | output = self.conv(input) # (N, n_channels * scaling factor^2, w, h) 90 | output = self.pixel_shuffle(output) # (N, n_channels, w * scaling factor, h * scaling factor) 91 | output = self.prelu(output) # (N, n_channels, w * scaling factor, h * scaling factor) 92 | 93 | return output 94 | 95 | 96 | class ResidualBlock(nn.Module): 97 | """ 98 | A residual block, comprising two convolutional blocks with a residual connection across them. 99 | """ 100 | 101 | def __init__(self, kernel_size=3, n_channels=64): 102 | """ 103 | :param kernel_size: kernel size 104 | :param n_channels: number of input and output channels (same because the input must be added to the output) 105 | """ 106 | super(ResidualBlock, self).__init__() 107 | 108 | # The first convolutional block 109 | self.conv_block1 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, 110 | batch_norm=True, activation='PReLu') 111 | 112 | # The second convolutional block 113 | self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, 114 | batch_norm=True, activation=None) 115 | 116 | def forward(self, input): 117 | """ 118 | Forward propagation. 119 | 120 | :param input: input images, a tensor of size (N, n_channels, w, h) 121 | :return: output images, a tensor of size (N, n_channels, w, h) 122 | """ 123 | residual = input # (N, n_channels, w, h) 124 | output = self.conv_block1(input) # (N, n_channels, w, h) 125 | output = self.conv_block2(output) # (N, n_channels, w, h) 126 | output = output + residual # (N, n_channels, w, h) 127 | 128 | return output 129 | 130 | 131 | class SRResNet(nn.Module): 132 | """ 133 | The SRResNet, as defined in the paper. 134 | """ 135 | 136 | def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4): 137 | """ 138 | :param large_kernel_size: kernel size of the first and last convolutions which transform the inputs and outputs 139 | :param small_kernel_size: kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks 140 | :param n_channels: number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks 141 | :param n_blocks: number of residual blocks 142 | :param scaling_factor: factor to scale input images by (along both dimensions) in the subpixel convolutional block 143 | """ 144 | super(SRResNet, self).__init__() 145 | 146 | # Scaling factor must be 2, 4, or 8 147 | scaling_factor = int(scaling_factor) 148 | assert scaling_factor in {2, 4, 8}, "The scaling factor must be 2, 4, or 8!" 
149 | 150 | # The first convolutional block 151 | self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size, 152 | batch_norm=False, activation='PReLu') 153 | 154 | # A sequence of n_blocks residual blocks, each containing a skip-connection across the block 155 | self.residual_blocks = nn.Sequential( 156 | *[ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for i in range(n_blocks)]) 157 | 158 | # Another convolutional block 159 | self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, 160 | kernel_size=small_kernel_size, 161 | batch_norm=True, activation=None) 162 | 163 | # Upscaling is done by sub-pixel convolution, with each such block upscaling by a factor of 2 164 | n_subpixel_convolution_blocks = int(math.log2(scaling_factor)) 165 | self.subpixel_convolutional_blocks = nn.Sequential( 166 | *[SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=n_channels, scaling_factor=2) for i 167 | in range(n_subpixel_convolution_blocks)]) 168 | 169 | # The last convolutional block 170 | self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size, 171 | batch_norm=False, activation='Tanh') 172 | 173 | def forward(self, lr_imgs): 174 | """ 175 | Forward prop. 176 | 177 | :param lr_imgs: low-resolution input images, a tensor of size (N, 3, w, h) 178 | :return: super-resolution output images, a tensor of size (N, 3, w * scaling factor, h * scaling factor) 179 | """ 180 | output = self.conv_block1(lr_imgs) # (N, 3, w, h) 181 | residual = output # (N, n_channels, w, h) 182 | output = self.residual_blocks(output) # (N, n_channels, w, h) 183 | output = self.conv_block2(output) # (N, n_channels, w, h) 184 | output = output + residual # (N, n_channels, w, h) 185 | output = self.subpixel_convolutional_blocks(output) # (N, n_channels, w * scaling factor, h * scaling factor) 186 | sr_imgs = self.conv_block3(output) # (N, 3, w * scaling factor, h * scaling factor) 187 | 188 | return sr_imgs 189 | 190 | 191 | class Generator(nn.Module): 192 | """ 193 | The generator in the SRGAN, as defined in the paper. Architecture identical to the SRResNet. 194 | """ 195 | 196 | def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4): 197 | """ 198 | :param large_kernel_size: kernel size of the first and last convolutions which transform the inputs and outputs 199 | :param small_kernel_size: kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks 200 | :param n_channels: number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks 201 | :param n_blocks: number of residual blocks 202 | :param scaling_factor: factor to scale input images by (along both dimensions) in the subpixel convolutional block 203 | """ 204 | super(Generator, self).__init__() 205 | 206 | # The generator is simply an SRResNet, as above 207 | self.net = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size, 208 | n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor) 209 | 210 | def initialize_with_srresnet(self, srresnet_checkpoint): 211 | """ 212 | Initialize with weights from a trained SRResNet. 
213 | 214 | :param srresnet_checkpoint: checkpoint filepath 215 | """ 216 | srresnet = torch.load(srresnet_checkpoint)['model'] 217 | self.net.load_state_dict(srresnet.state_dict()) 218 | 219 | print("\nLoaded weights from pre-trained SRResNet.\n") 220 | 221 | def forward(self, lr_imgs): 222 | """ 223 | Forward prop. 224 | 225 | :param lr_imgs: low-resolution input images, a tensor of size (N, 3, w, h) 226 | :return: super-resolution output images, a tensor of size (N, 3, w * scaling factor, h * scaling factor) 227 | """ 228 | sr_imgs = self.net(lr_imgs) # (N, n_channels, w * scaling factor, h * scaling factor) 229 | 230 | return sr_imgs 231 | 232 | 233 | class Discriminator(nn.Module): 234 | """ 235 | The discriminator in the SRGAN, as defined in the paper. 236 | """ 237 | 238 | def __init__(self, kernel_size=3, n_channels=64, n_blocks=8, fc_size=1024): 239 | """ 240 | :param kernel_size: kernel size in all convolutional blocks 241 | :param n_channels: number of output channels in the first convolutional block, after which it is doubled in every 2nd block thereafter 242 | :param n_blocks: number of convolutional blocks 243 | :param fc_size: size of the first fully connected layer 244 | """ 245 | super(Discriminator, self).__init__() 246 | 247 | in_channels = 3 248 | 249 | # A series of convolutional blocks 250 | # The first, third, fifth (and so on) convolutional blocks increase the number of channels but retain image size 251 | # The second, fourth, sixth (and so on) convolutional blocks retain the same number of channels but halve image size 252 | # The first convolutional block is unique because it does not employ batch normalization 253 | conv_blocks = list() 254 | for i in range(n_blocks): 255 | out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels 256 | conv_blocks.append( 257 | ConvolutionalBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 258 | stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0, activation='LeakyReLu')) 259 | in_channels = out_channels 260 | self.conv_blocks = nn.Sequential(*conv_blocks) 261 | 262 | # An adaptive pool layer that resizes it to a standard size 263 | # For the default input size of 96 and 8 convolutional blocks, this will have no effect 264 | self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6)) 265 | 266 | self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size) 267 | 268 | self.leaky_relu = nn.LeakyReLU(0.2) 269 | 270 | self.fc2 = nn.Linear(1024, 1) 271 | 272 | # Don't need a sigmoid layer because the sigmoid operation is performed by PyTorch's nn.BCEWithLogitsLoss() 273 | 274 | def forward(self, imgs): 275 | """ 276 | Forward propagation. 277 | 278 | :param imgs: high-resolution or super-resolution images which must be classified as such, a tensor of size (N, 3, w * scaling factor, h * scaling factor) 279 | :return: a score (logit) for whether it is a high-resolution image, a tensor of size (N) 280 | """ 281 | batch_size = imgs.size(0) 282 | output = self.conv_blocks(imgs) 283 | output = self.adaptive_pool(output) 284 | output = self.fc1(output.view(batch_size, -1)) 285 | output = self.leaky_relu(output) 286 | logit = self.fc2(output) 287 | 288 | return logit 289 | 290 | 291 | class TruncatedVGG19(nn.Module): 292 | """ 293 | A truncated VGG19 network, such that its output is the 'feature map obtained by the j-th convolution (after activation) 294 | before the i-th maxpooling layer within the VGG19 network', as defined in the paper. 
295 | 296 | Used to calculate the MSE loss in this VGG feature-space, i.e. the VGG loss. 297 | """ 298 | 299 | def __init__(self, i, j): 300 | """ 301 | :param i: the index i in the definition above 302 | :param j: the index j in the definition above 303 | """ 304 | super(TruncatedVGG19, self).__init__() 305 | 306 | # Load the pre-trained VGG19 available in torchvision 307 | vgg19 = torchvision.models.vgg19(pretrained=True) 308 | 309 | maxpool_counter = 0 310 | conv_counter = 0 311 | truncate_at = 0 312 | # Iterate through the convolutional section ("features") of the VGG19 313 | for layer in vgg19.features.children(): 314 | truncate_at += 1 315 | 316 | # Count the number of maxpool layers and the convolutional layers after each maxpool 317 | if isinstance(layer, nn.Conv2d): 318 | conv_counter += 1 319 | if isinstance(layer, nn.MaxPool2d): 320 | maxpool_counter += 1 321 | conv_counter = 0 322 | 323 | # Break if we reach the jth convolution after the (i - 1)th maxpool 324 | if maxpool_counter == i - 1 and conv_counter == j: 325 | break 326 | 327 | # Check if conditions were satisfied 328 | assert maxpool_counter == i - 1 and conv_counter == j, "One or both of i=%d and j=%d are not valid choices for the VGG19!" % ( 329 | i, j) 330 | 331 | # Truncate to the jth convolution (+ activation) before the ith maxpool layer 332 | self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1]) 333 | 334 | def forward(self, input): 335 | """ 336 | Forward propagation 337 | :param input: high-resolution or super-resolution images, a tensor of size (N, 3, w * scaling factor, h * scaling factor) 338 | :return: the specified VGG19 feature map, a tensor of size (N, feature_map_channels, feature_map_w, feature_map_h) 339 | """ 340 | output = self.truncated_vgg19(input) # (N, feature_map_channels, feature_map_w, feature_map_h) 341 | 342 | return output 343 | -------------------------------------------------------------------------------- /super_resolve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import * 3 | from PIL import Image, ImageDraw, ImageFont 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | # Model checkpoints 8 | srgan_checkpoint = "./checkpoint_srgan.pth.tar" 9 | srresnet_checkpoint = "./checkpoint_srresnet.pth.tar" 10 | 11 | # Load models 12 | srresnet = torch.load(srresnet_checkpoint)['model'].to(device) 13 | srresnet.eval() 14 | srgan_generator = torch.load(srgan_checkpoint)['generator'].to(device) 15 | srgan_generator.eval() 16 | 17 | 18 | def visualize_sr(img, halve=False): 19 | """ 20 | Visualizes the super-resolved images from the SRResNet and SRGAN for comparison with the bicubic-upsampled image 21 | and the original high-resolution (HR) image, as done in the paper. 22 | 23 | :param img: filepath of the HR iamge 24 | :param halve: halve each dimension of the HR image to make sure it's not greater than the dimensions of your screen? 25 | For instance, for a 2160p HR image, the LR image will be of 540p (1080p/4) resolution. On a 1080p screen, 26 | you will therefore be looking at a comparison between a 540p LR image and a 1080p SR/HR image because 27 | your 1080p screen can only display the 2160p SR/HR image at a downsampled 1080p. This is only an 28 | APPARENT rescaling of 2x. 29 | If you want to reduce HR resolution by a different extent, modify accordingly. 
30 | """ 31 | # Load image, downsample to obtain low-res version 32 | hr_img = Image.open(img, mode="r") 33 | hr_img = hr_img.convert('RGB') 34 | if halve: 35 | hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)), 36 | Image.LANCZOS) 37 | lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)), 38 | Image.BICUBIC) 39 | 40 | # Bicubic Upsampling 41 | bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC) 42 | 43 | # Super-resolution (SR) with SRResNet 44 | sr_img_srresnet = srresnet(convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device)) 45 | sr_img_srresnet = sr_img_srresnet.squeeze(0).cpu().detach() 46 | sr_img_srresnet = convert_image(sr_img_srresnet, source='[-1, 1]', target='pil') 47 | 48 | # Super-resolution (SR) with SRGAN 49 | sr_img_srgan = srgan_generator(convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device)) 50 | sr_img_srgan = sr_img_srgan.squeeze(0).cpu().detach() 51 | sr_img_srgan = convert_image(sr_img_srgan, source='[-1, 1]', target='pil') 52 | 53 | # Create grid 54 | margin = 40 55 | grid_img = Image.new('RGB', (2 * hr_img.width + 3 * margin, 2 * hr_img.height + 3 * margin), (255, 255, 255)) 56 | 57 | # Font 58 | draw = ImageDraw.Draw(grid_img) 59 | try: 60 | font = ImageFont.truetype("calibril.ttf", size=23) 61 | # It will also look for this file in your OS's default fonts directory, where you may have the Calibri Light font installed if you have MS Office 62 | # Otherwise, use any TTF font of your choice 63 | except OSError: 64 | print( 65 | "Defaulting to a terrible font. To use a font of your choice, include the link to its TTF file in the function.") 66 | font = ImageFont.load_default() 67 | 68 | # Place bicubic-upsampled image 69 | grid_img.paste(bicubic_img, (margin, margin)) 70 | text_size = font.getsize("Bicubic") 71 | draw.text(xy=[margin + bicubic_img.width / 2 - text_size[0] / 2, margin - text_size[1] - 5], text="Bicubic", 72 | font=font, 73 | fill='black') 74 | 75 | # Place SRResNet image 76 | grid_img.paste(sr_img_srresnet, (2 * margin + bicubic_img.width, margin)) 77 | text_size = font.getsize("SRResNet") 78 | draw.text( 79 | xy=[2 * margin + bicubic_img.width + sr_img_srresnet.width / 2 - text_size[0] / 2, margin - text_size[1] - 5], 80 | text="SRResNet", font=font, fill='black') 81 | 82 | # Place SRGAN image 83 | grid_img.paste(sr_img_srgan, (margin, 2 * margin + sr_img_srresnet.height)) 84 | text_size = font.getsize("SRGAN") 85 | draw.text( 86 | xy=[margin + bicubic_img.width / 2 - text_size[0] / 2, 2 * margin + sr_img_srresnet.height - text_size[1] - 5], 87 | text="SRGAN", font=font, fill='black') 88 | 89 | # Place original HR image 90 | grid_img.paste(hr_img, (2 * margin + bicubic_img.width, 2 * margin + sr_img_srresnet.height)) 91 | text_size = font.getsize("Original HR") 92 | draw.text(xy=[2 * margin + bicubic_img.width + sr_img_srresnet.width / 2 - text_size[0] / 2, 93 | 2 * margin + sr_img_srresnet.height - text_size[1] - 1], text="Original HR", font=font, fill='black') 94 | 95 | # Display grid 96 | grid_img.show() 97 | 98 | return grid_img 99 | 100 | 101 | if __name__ == '__main__': 102 | grid_img = visualize_sr("/media/ssd/sr data/Set14/baboon.png") 103 | -------------------------------------------------------------------------------- /train_srgan.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch.backends.cudnn as cudnn 3 | from torch import nn 4 | from models import 
Generator, Discriminator, TruncatedVGG19 5 | from datasets import SRDataset 6 | from utils import * 7 | 8 | # Data parameters 9 | data_folder = './' # folder with JSON data files 10 | crop_size = 96 # crop size of target HR images 11 | scaling_factor = 4 # the scaling factor for the generator; the input LR images will be downsampled from the target HR images by this factor 12 | 13 | # Generator parameters 14 | large_kernel_size_g = 9 # kernel size of the first and last convolutions which transform the inputs and outputs 15 | small_kernel_size_g = 3 # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks 16 | n_channels_g = 64 # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks 17 | n_blocks_g = 16 # number of residual blocks 18 | srresnet_checkpoint = "./checkpoint_srresnet.pth.tar" # filepath of the trained SRResNet checkpoint used for initialization 19 | 20 | # Discriminator parameters 21 | kernel_size_d = 3 # kernel size in all convolutional blocks 22 | n_channels_d = 64 # number of output channels in the first convolutional block, after which it is doubled in every 2nd block thereafter 23 | n_blocks_d = 8 # number of convolutional blocks 24 | fc_size_d = 1024 # size of the first fully connected layer 25 | 26 | # Learning parameters 27 | checkpoint = None # path to model (SRGAN) checkpoint, None if none 28 | batch_size = 16 # batch size 29 | start_epoch = 0 # start at this epoch 30 | iterations = 2e5 # number of training iterations 31 | workers = 4 # number of workers for loading data in the DataLoader 32 | vgg19_i = 5 # the index i in the definition for VGG loss; see paper or models.py 33 | vgg19_j = 4 # the index j in the definition for VGG loss; see paper or models.py 34 | beta = 1e-3 # the coefficient to weight the adversarial loss in the perceptual loss 35 | print_freq = 500 # print training status once every __ batches 36 | lr = 1e-4 # learning rate 37 | grad_clip = None # clip if gradients are exploding 38 | 39 | # Default device 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | cudnn.benchmark = True 43 | 44 | 45 | def main(): 46 | """ 47 | Training. 
48 | """ 49 | global start_epoch, epoch, checkpoint, srresnet_checkpoint 50 | 51 | # Initialize model or load checkpoint 52 | if checkpoint is None: 53 | # Generator 54 | generator = Generator(large_kernel_size=large_kernel_size_g, 55 | small_kernel_size=small_kernel_size_g, 56 | n_channels=n_channels_g, 57 | n_blocks=n_blocks_g, 58 | scaling_factor=scaling_factor) 59 | 60 | # Initialize generator network with pretrained SRResNet 61 | generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint) 62 | 63 | # Initialize generator's optimizer 64 | optimizer_g = torch.optim.Adam(params=filter(lambda p: p.requires_grad, generator.parameters()), 65 | lr=lr) 66 | 67 | # Discriminator 68 | discriminator = Discriminator(kernel_size=kernel_size_d, 69 | n_channels=n_channels_d, 70 | n_blocks=n_blocks_d, 71 | fc_size=fc_size_d) 72 | 73 | # Initialize discriminator's optimizer 74 | optimizer_d = torch.optim.Adam(params=filter(lambda p: p.requires_grad, discriminator.parameters()), 75 | lr=lr) 76 | 77 | else: 78 | checkpoint = torch.load(checkpoint) 79 | start_epoch = checkpoint['epoch'] + 1 80 | generator = checkpoint['generator'] 81 | discriminator = checkpoint['discriminator'] 82 | optimizer_g = checkpoint['optimizer_g'] 83 | optimizer_d = checkpoint['optimizer_d'] 84 | print("\nLoaded checkpoint from epoch %d.\n" % (checkpoint['epoch'] + 1)) 85 | 86 | # Truncated VGG19 network to be used in the loss calculation 87 | truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j) 88 | truncated_vgg19.eval() 89 | 90 | # Loss functions 91 | content_loss_criterion = nn.MSELoss() 92 | adversarial_loss_criterion = nn.BCEWithLogitsLoss() 93 | 94 | # Move to default device 95 | generator = generator.to(device) 96 | discriminator = discriminator.to(device) 97 | truncated_vgg19 = truncated_vgg19.to(device) 98 | content_loss_criterion = content_loss_criterion.to(device) 99 | adversarial_loss_criterion = adversarial_loss_criterion.to(device) 100 | 101 | # Custom dataloaders 102 | train_dataset = SRDataset(data_folder, 103 | split='train', 104 | crop_size=crop_size, 105 | scaling_factor=scaling_factor, 106 | lr_img_type='imagenet-norm', 107 | hr_img_type='imagenet-norm') 108 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers, 109 | pin_memory=True) 110 | 111 | # Total number of epochs to train for 112 | epochs = int(iterations // len(train_loader) + 1) 113 | 114 | # Epochs 115 | for epoch in range(start_epoch, epochs): 116 | 117 | # At the halfway point, reduce learning rate to a tenth 118 | if epoch == int((iterations / 2) // len(train_loader) + 1): 119 | adjust_learning_rate(optimizer_g, 0.1) 120 | adjust_learning_rate(optimizer_d, 0.1) 121 | 122 | # One epoch's training 123 | train(train_loader=train_loader, 124 | generator=generator, 125 | discriminator=discriminator, 126 | truncated_vgg19=truncated_vgg19, 127 | content_loss_criterion=content_loss_criterion, 128 | adversarial_loss_criterion=adversarial_loss_criterion, 129 | optimizer_g=optimizer_g, 130 | optimizer_d=optimizer_d, 131 | epoch=epoch) 132 | 133 | # Save checkpoint 134 | torch.save({'epoch': epoch, 135 | 'generator': generator, 136 | 'discriminator': discriminator, 137 | 'optimizer_g': optimizer_g, 138 | 'optimizer_d': optimizer_d}, 139 | 'checkpoint_srgan.pth.tar') 140 | 141 | 142 | def train(train_loader, generator, discriminator, truncated_vgg19, content_loss_criterion, adversarial_loss_criterion, 143 | optimizer_g, optimizer_d, epoch): 144 | """ 145 | One epoch's training. 
146 | 147 | :param train_loader: train dataloader 148 | :param generator: generator 149 | :param discriminator: discriminator 150 | :param truncated_vgg19: truncated VGG19 network 151 | :param content_loss_criterion: content loss function (Mean Squared-Error loss) 152 | :param adversarial_loss_criterion: adversarial loss function (Binary Cross-Entropy loss) 153 | :param optimizer_g: optimizer for the generator 154 | :param optimizer_d: optimizer for the discriminator 155 | :param epoch: epoch number 156 | """ 157 | # Set to train mode 158 | generator.train() 159 | discriminator.train() # training mode enables batch normalization 160 | 161 | batch_time = AverageMeter() # forward prop. + back prop. time 162 | data_time = AverageMeter() # data loading time 163 | losses_c = AverageMeter() # content loss 164 | losses_a = AverageMeter() # adversarial loss in the generator 165 | losses_d = AverageMeter() # adversarial loss in the discriminator 166 | 167 | start = time.time() 168 | 169 | # Batches 170 | for i, (lr_imgs, hr_imgs) in enumerate(train_loader): 171 | data_time.update(time.time() - start) 172 | 173 | # Move to default device 174 | lr_imgs = lr_imgs.to(device) # (batch_size (N), 3, 24, 24), imagenet-normed 175 | hr_imgs = hr_imgs.to(device) # (batch_size (N), 3, 96, 96), imagenet-normed 176 | 177 | # GENERATOR UPDATE 178 | 179 | # Generate 180 | sr_imgs = generator(lr_imgs) # (N, 3, 96, 96), in [-1, 1] 181 | sr_imgs = convert_image(sr_imgs, source='[-1, 1]', target='imagenet-norm') # (N, 3, 96, 96), imagenet-normed 182 | 183 | # Calculate VGG feature maps for the super-resolved (SR) and high resolution (HR) images 184 | sr_imgs_in_vgg_space = truncated_vgg19(sr_imgs) 185 | hr_imgs_in_vgg_space = truncated_vgg19(hr_imgs).detach() # detached because they're constant, targets 186 | 187 | # Discriminate super-resolved (SR) images 188 | sr_discriminated = discriminator(sr_imgs) # (N) 189 | 190 | # Calculate the Perceptual loss 191 | content_loss = content_loss_criterion(sr_imgs_in_vgg_space, hr_imgs_in_vgg_space) 192 | adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.ones_like(sr_discriminated)) 193 | perceptual_loss = content_loss + beta * adversarial_loss 194 | 195 | # Back-prop. 196 | optimizer_g.zero_grad() 197 | perceptual_loss.backward() 198 | 199 | # Clip gradients, if necessary 200 | if grad_clip is not None: 201 | clip_gradient(optimizer_g, grad_clip) 202 | 203 | # Update generator 204 | optimizer_g.step() 205 | 206 | # Keep track of loss 207 | losses_c.update(content_loss.item(), lr_imgs.size(0)) 208 | losses_a.update(adversarial_loss.item(), lr_imgs.size(0)) 209 | 210 | # DISCRIMINATOR UPDATE 211 | 212 | # Discriminate super-resolution (SR) and high-resolution (HR) images 213 | hr_discriminated = discriminator(hr_imgs) 214 | sr_discriminated = discriminator(sr_imgs.detach()) 215 | # But didn't we already discriminate the SR images earlier, before updating the generator (G)? Why not just use that here? 216 | # Because, if we used that, we'd be back-propagating (finding gradients) over the G too when backward() is called 217 | # It's actually faster to detach the SR images from the G and forward-prop again, than to back-prop. over the G unnecessarily 218 | # See FAQ section in the tutorial 219 | 220 | # Binary Cross-Entropy loss 221 | adversarial_loss = adversarial_loss_criterion(sr_discriminated, torch.zeros_like(sr_discriminated)) + \ 222 | adversarial_loss_criterion(hr_discriminated, torch.ones_like(hr_discriminated)) 223 | 224 | # Back-prop. 
225 | optimizer_d.zero_grad() 226 | adversarial_loss.backward() 227 | 228 | # Clip gradients, if necessary 229 | if grad_clip is not None: 230 | clip_gradient(optimizer_d, grad_clip) 231 | 232 | # Update discriminator 233 | optimizer_d.step() 234 | 235 | # Keep track of loss 236 | losses_d.update(adversarial_loss.item(), hr_imgs.size(0)) 237 | 238 | # Keep track of batch times 239 | batch_time.update(time.time() - start) 240 | 241 | # Reset start time 242 | start = time.time() 243 | 244 | # Print status 245 | if i % print_freq == 0: 246 | print('Epoch: [{0}][{1}/{2}]----' 247 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----' 248 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----' 249 | 'Cont. Loss {loss_c.val:.4f} ({loss_c.avg:.4f})----' 250 | 'Adv. Loss {loss_a.val:.4f} ({loss_a.avg:.4f})----' 251 | 'Disc. Loss {loss_d.val:.4f} ({loss_d.avg:.4f})'.format(epoch, 252 | i, 253 | len(train_loader), 254 | batch_time=batch_time, 255 | data_time=data_time, 256 | loss_c=losses_c, 257 | loss_a=losses_a, 258 | loss_d=losses_d)) 259 | 260 | del lr_imgs, hr_imgs, sr_imgs, hr_imgs_in_vgg_space, sr_imgs_in_vgg_space, hr_discriminated, sr_discriminated # free some memory since their histories may be stored 261 | 262 | 263 | if __name__ == '__main__': 264 | main() 265 | -------------------------------------------------------------------------------- /train_srresnet.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch.backends.cudnn as cudnn 3 | import torch 4 | from torch import nn 5 | from models import SRResNet 6 | from datasets import SRDataset 7 | from utils import * 8 | 9 | # Data parameters 10 | data_folder = './' # folder with JSON data files 11 | crop_size = 96 # crop size of target HR images 12 | scaling_factor = 4 # the scaling factor for the generator; the input LR images will be downsampled from the target HR images by this factor 13 | 14 | # Model parameters 15 | large_kernel_size = 9 # kernel size of the first and last convolutions which transform the inputs and outputs 16 | small_kernel_size = 3 # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks 17 | n_channels = 64 # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks 18 | n_blocks = 16 # number of residual blocks 19 | 20 | # Learning parameters 21 | checkpoint = None # path to model checkpoint, None if none 22 | batch_size = 16 # batch size 23 | start_epoch = 0 # start at this epoch 24 | iterations = 1e6 # number of training iterations 25 | workers = 4 # number of workers for loading data in the DataLoader 26 | print_freq = 500 # print training status once every __ batches 27 | lr = 1e-4 # learning rate 28 | grad_clip = None # clip if gradients are exploding 29 | 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | 32 | cudnn.benchmark = True 33 | 34 | 35 | def main(): 36 | """ 37 | Training. 
38 | """ 39 | global start_epoch, epoch, checkpoint 40 | 41 | # Initialize model or load checkpoint 42 | if checkpoint is None: 43 | model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size, 44 | n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor) 45 | # Initialize the optimizer 46 | optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), 47 | lr=lr) 48 | 49 | else: 50 | checkpoint = torch.load(checkpoint) 51 | start_epoch = checkpoint['epoch'] + 1 52 | model = checkpoint['model'] 53 | optimizer = checkpoint['optimizer'] 54 | 55 | # Move to default device 56 | model = model.to(device) 57 | criterion = nn.MSELoss().to(device) 58 | 59 | # Custom dataloaders 60 | train_dataset = SRDataset(data_folder, 61 | split='train', 62 | crop_size=crop_size, 63 | scaling_factor=scaling_factor, 64 | lr_img_type='imagenet-norm', 65 | hr_img_type='[-1, 1]') 66 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers, 67 | pin_memory=True) # note that we're passing the collate function here 68 | 69 | # Total number of epochs to train for 70 | epochs = int(iterations // len(train_loader) + 1) 71 | 72 | # Epochs 73 | for epoch in range(start_epoch, epochs): 74 | # One epoch's training 75 | train(train_loader=train_loader, 76 | model=model, 77 | criterion=criterion, 78 | optimizer=optimizer, 79 | epoch=epoch) 80 | 81 | # Save checkpoint 82 | torch.save({'epoch': epoch, 83 | 'model': model, 84 | 'optimizer': optimizer}, 85 | 'checkpoint_srresnet.pth.tar') 86 | 87 | 88 | def train(train_loader, model, criterion, optimizer, epoch): 89 | """ 90 | One epoch's training. 91 | 92 | :param train_loader: DataLoader for training data 93 | :param model: model 94 | :param criterion: content loss function (Mean Squared-Error loss) 95 | :param optimizer: optimizer 96 | :param epoch: epoch number 97 | """ 98 | model.train() # training mode enables batch normalization 99 | 100 | batch_time = AverageMeter() # forward prop. + back prop. time 101 | data_time = AverageMeter() # data loading time 102 | losses = AverageMeter() # loss 103 | 104 | start = time.time() 105 | 106 | # Batches 107 | for i, (lr_imgs, hr_imgs) in enumerate(train_loader): 108 | data_time.update(time.time() - start) 109 | 110 | # Move to default device 111 | lr_imgs = lr_imgs.to(device) # (batch_size (N), 3, 24, 24), imagenet-normed 112 | hr_imgs = hr_imgs.to(device) # (batch_size (N), 3, 96, 96), in [-1, 1] 113 | 114 | # Forward prop. 115 | sr_imgs = model(lr_imgs) # (N, 3, 96, 96), in [-1, 1] 116 | 117 | # Loss 118 | loss = criterion(sr_imgs, hr_imgs) # scalar 119 | 120 | # Backward prop. 
121 | optimizer.zero_grad() 122 | loss.backward() 123 | 124 | # Clip gradients, if necessary 125 | if grad_clip is not None: 126 | clip_gradient(optimizer, grad_clip) 127 | 128 | # Update model 129 | optimizer.step() 130 | 131 | # Keep track of loss 132 | losses.update(loss.item(), lr_imgs.size(0)) 133 | 134 | # Keep track of batch time 135 | batch_time.update(time.time() - start) 136 | 137 | # Reset start time 138 | start = time.time() 139 | 140 | # Print status 141 | if i % print_freq == 0: 142 | print('Epoch: [{0}][{1}/{2}]----' 143 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----' 144 | 'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----' 145 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(train_loader), 146 | batch_time=batch_time, 147 | data_time=data_time, loss=losses)) 148 | del lr_imgs, hr_imgs, sr_imgs # free some memory since their histories may be stored 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | import json 4 | import random 5 | import torchvision.transforms.functional as FT 6 | import torch 7 | import math 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | # Some constants 12 | rgb_weights = torch.FloatTensor([65.481, 128.553, 24.966]).to(device) 13 | imagenet_mean = torch.FloatTensor([0.485, 0.456, 0.406]).unsqueeze(1).unsqueeze(2) 14 | imagenet_std = torch.FloatTensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(2) 15 | imagenet_mean_cuda = torch.FloatTensor([0.485, 0.456, 0.406]).to(device).unsqueeze(0).unsqueeze(2).unsqueeze(3) 16 | imagenet_std_cuda = torch.FloatTensor([0.229, 0.224, 0.225]).to(device).unsqueeze(0).unsqueeze(2).unsqueeze(3) 17 | 18 | 19 | def create_data_lists(train_folders, test_folders, min_size, output_folder): 20 | """ 21 | Create lists for images in the training set and each of the test sets. 22 | 23 | :param train_folders: folders containing the training images; these will be merged 24 | :param test_folders: folders containing the test images; each test folder will form its own test set 25 | :param min_size: minimum width and height of images to be considered 26 | :param output_folder: save data lists here 27 | """ 28 | print("\nCreating data lists... 
this may take some time.\n") 29 | train_images = list() 30 | for d in train_folders: 31 | for i in os.listdir(d): 32 | img_path = os.path.join(d, i) 33 | img = Image.open(img_path, mode='r') 34 | if img.width >= min_size and img.height >= min_size: 35 | train_images.append(img_path) 36 | print("There are %d images in the training data.\n" % len(train_images)) 37 | with open(os.path.join(output_folder, 'train_images.json'), 'w') as j: 38 | json.dump(train_images, j) 39 | 40 | for d in test_folders: 41 | test_images = list() 42 | test_name = d.split("/")[-1] 43 | for i in os.listdir(d): 44 | img_path = os.path.join(d, i) 45 | img = Image.open(img_path, mode='r') 46 | if img.width >= min_size and img.height >= min_size: 47 | test_images.append(img_path) 48 | print("There are %d images in the %s test data.\n" % (len(test_images), test_name)) 49 | with open(os.path.join(output_folder, test_name + '_test_images.json'), 'w') as j: 50 | json.dump(test_images, j) 51 | 52 | print("JSONS containing lists of Train and Test images have been saved to %s\n" % output_folder) 53 | 54 | 55 | def convert_image(img, source, target): 56 | """ 57 | Convert an image from a source format to a target format. 58 | 59 | :param img: image 60 | :param source: source format, one of 'pil' (PIL image), '[0, 1]' or '[-1, 1]' (pixel value ranges) 61 | :param target: target format, one of 'pil' (PIL image), '[0, 255]', '[0, 1]', '[-1, 1]' (pixel value ranges), 62 | 'imagenet-norm' (pixel values standardized by imagenet mean and std.), 63 | 'y-channel' (luminance channel Y in the YCbCr color format, used to calculate PSNR and SSIM) 64 | :return: converted image 65 | """ 66 | assert source in {'pil', '[0, 1]', '[-1, 1]'}, "Cannot convert from source format %s!" % source 67 | assert target in {'pil', '[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm', 68 | 'y-channel'}, "Cannot convert to target format %s!" % target 69 | 70 | # Convert from source to [0, 1] 71 | if source == 'pil': 72 | img = FT.to_tensor(img) 73 | 74 | elif source == '[0, 1]': 75 | pass # already in [0, 1] 76 | 77 | elif source == '[-1, 1]': 78 | img = (img + 1.) / 2. 79 | 80 | # Convert from [0, 1] to target 81 | if target == 'pil': 82 | img = FT.to_pil_image(img) 83 | 84 | elif target == '[0, 255]': 85 | img = 255. * img 86 | 87 | elif target == '[0, 1]': 88 | pass # already in [0, 1] 89 | 90 | elif target == '[-1, 1]': 91 | img = 2. * img - 1. 92 | 93 | elif target == 'imagenet-norm': 94 | if img.ndimension() == 3: 95 | img = (img - imagenet_mean) / imagenet_std 96 | elif img.ndimension() == 4: 97 | img = (img - imagenet_mean_cuda) / imagenet_std_cuda 98 | 99 | elif target == 'y-channel': 100 | # Based on definitions at https://github.com/xinntao/BasicSR/wiki/Color-conversion-in-SR 101 | # torch.dot() does not work the same way as numpy.dot() 102 | # So, use torch.matmul() to find the dot product between the last dimension of an 4-D tensor and a 1-D tensor 103 | img = torch.matmul(255. * img.permute(0, 2, 3, 1)[:, 4:-4, 4:-4, :], rgb_weights) / 255. + 16. 104 | 105 | return img 106 | 107 | 108 | class ImageTransforms(object): 109 | """ 110 | Image transformation pipeline. 
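For the training split, a random fixed-size crop of the source image serves as the HR image; for the test split, the largest possible center crop divisible by the scaling factor is used. The HR image is then bicubically downsampled by the scaling factor to create the LR image, and both are converted to the requested formats.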
111 | """ 112 | 113 | def __init__(self, split, crop_size, scaling_factor, lr_img_type, hr_img_type): 114 | """ 115 | :param split: one of 'train' or 'test' 116 | :param crop_size: crop size of HR images 117 | :param scaling_factor: LR images will be downsampled from the HR images by this factor 118 | :param lr_img_type: the target format for the LR image; see convert_image() above for available formats 119 | :param hr_img_type: the target format for the HR image; see convert_image() above for available formats 120 | """ 121 | self.split = split.lower() 122 | self.crop_size = crop_size 123 | self.scaling_factor = scaling_factor 124 | self.lr_img_type = lr_img_type 125 | self.hr_img_type = hr_img_type 126 | 127 | assert self.split in {'train', 'test'} 128 | 129 | def __call__(self, img): 130 | """ 131 | :param img: a PIL source image from which the HR image will be cropped, and then downsampled to create the LR image 132 | :return: LR and HR images in the specified format 133 | """ 134 | 135 | # Crop 136 | if self.split == 'train': 137 | # Take a random fixed-size crop of the image, which will serve as the high-resolution (HR) image 138 | left = random.randint(1, img.width - self.crop_size) 139 | top = random.randint(1, img.height - self.crop_size) 140 | right = left + self.crop_size 141 | bottom = top + self.crop_size 142 | hr_img = img.crop((left, top, right, bottom)) 143 | else: 144 | # Take the largest possible center-crop of it such that its dimensions are perfectly divisible by the scaling factor 145 | x_remainder = img.width % self.scaling_factor 146 | y_remainder = img.height % self.scaling_factor 147 | left = x_remainder // 2 148 | top = y_remainder // 2 149 | right = left + (img.width - x_remainder) 150 | bottom = top + (img.height - y_remainder) 151 | hr_img = img.crop((left, top, right, bottom)) 152 | 153 | # Downsize this crop to obtain a low-resolution version of it 154 | lr_img = hr_img.resize((int(hr_img.width / self.scaling_factor), int(hr_img.height / self.scaling_factor)), 155 | Image.BICUBIC) 156 | 157 | # Sanity check 158 | assert hr_img.width == lr_img.width * self.scaling_factor and hr_img.height == lr_img.height * self.scaling_factor 159 | 160 | # Convert the LR and HR image to the required type 161 | lr_img = convert_image(lr_img, source='pil', target=self.lr_img_type) 162 | hr_img = convert_image(hr_img, source='pil', target=self.hr_img_type) 163 | 164 | return lr_img, hr_img 165 | 166 | 167 | class AverageMeter(object): 168 | """ 169 | Keeps track of most recent, average, sum, and count of a metric. 170 | """ 171 | 172 | def __init__(self): 173 | self.reset() 174 | 175 | def reset(self): 176 | self.val = 0 177 | self.avg = 0 178 | self.sum = 0 179 | self.count = 0 180 | 181 | def update(self, val, n=1): 182 | self.val = val 183 | self.sum += val * n 184 | self.count += n 185 | self.avg = self.sum / self.count 186 | 187 | 188 | def clip_gradient(optimizer, grad_clip): 189 | """ 190 | Clips gradients computed during backpropagation to avoid explosion of gradients. 191 | 192 | :param optimizer: optimizer with the gradients to be clipped 193 | :param grad_clip: clip value 194 | """ 195 | for group in optimizer.param_groups: 196 | for param in group['params']: 197 | if param.grad is not None: 198 | param.grad.data.clamp_(-grad_clip, grad_clip) 199 | 200 | 201 | def save_checkpoint(state, filename): 202 | """ 203 | Save model checkpoint. 
204 | 205 | :param state: checkpoint contents :param filename: filename to save the checkpoint to 206 | """ 207 | 208 | torch.save(state, filename) 209 | 210 | 211 | def adjust_learning_rate(optimizer, shrink_factor): 212 | """ 213 | Shrinks learning rate by a specified factor. 214 | 215 | :param optimizer: optimizer whose learning rate must be shrunk. 216 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 217 | """ 218 | 219 | print("\nDECAYING learning rate.") 220 | for param_group in optimizer.param_groups: 221 | param_group['lr'] = param_group['lr'] * shrink_factor 222 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 223 | --------------------------------------------------------------------------------
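A usage note on `convert_image()` above: its `'y-channel'` target exists because PSNR and SSIM for super-resolution are conventionally computed on the luminance (Y) channel. Below is a minimal, illustrative sketch of how that conversion could feed a PSNR calculation. The helper name `psnr_y`, the peak value of 255, and the assumption that both inputs are 4-D batches already on the default `device` are choices made for this example only; the repository's own `eval.py` performs the actual evaluation.

```python
import torch

from utils import convert_image


def psnr_y(sr_imgs, hr_imgs):
    """
    Illustrative sketch: PSNR (in dB) on the luminance (Y) channel.

    :param sr_imgs: super-resolved images, a tensor of size (N, 3, h, w) with pixel values in [-1, 1]
    :param hr_imgs: original HR images, a tensor of the same size and pixel range
    :return: average PSNR over the batch
    """
    # convert_image() maps [-1, 1] RGB to Y-channel values (roughly in [16, 235]) and crops a 4-pixel border
    sr_y = convert_image(sr_imgs, source='[-1, 1]', target='y-channel')
    hr_y = convert_image(hr_imgs, source='[-1, 1]', target='y-channel')

    # Mean squared error on the Y channel, then the standard PSNR formula with a peak value of 255
    mse = torch.mean((sr_y - hr_y) ** 2)
    return 10. * torch.log10(255. ** 2 / mse)
```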