├── upsample4x.pth ├── .gitignore ├── README.md ├── model_utils.py ├── model.py └── run.py /upsample4x.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bguisard/SuperResolution/HEAD/upsample4x.pth -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | data 3 | results 4 | dev 5 | old_models 6 | .ipynb_checkpoints 7 | __pycache__ 8 | *.swp 9 | *.p 10 | *.csv 11 | *.bc 12 | *.ipynb 13 | *.pdf 14 | *.jpg 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Super Resolution 2 | 3 | #### * Work in progress * 4 | 5 | This is an implementation in PyTorch of the super resolution CNN based in the paper: [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155). 6 | 7 | There are a few differences between the proposed network and my implementation: 8 | 9 | - changed upsampling blocks to remove the "checkered board pattern" as sugested by [[2]](http://distill.pub/2016/deconv-checkerboard/) 10 | 11 | - Changed deconv kernel size to 4 and added padding to all steps in order to get the same dimensions as the paper if you decide not to use the upsampling block from [2]. 12 | 13 | - Removed the gaussian noise when generating lower resolution images 14 | 15 | The model was trained for 10 epochs on a random set of 20k images from the validation set of the MS-COCO dataset. 16 | 17 | 18 | ## Usage: 19 | 20 | - To use provided model weights: 21 | python run --model upsample4x.pth --target_image [PATH_TO_IMAGE_TO_UPSAMPLE] 22 | 23 | The output of the model will be an image named 'up_IMAGENAME' on your current working folder. 24 | 25 | - To train a new model: 26 | python run.py --mode train --image_folder [PATH_TO_TRAIN_FOLDER] --trn_epochs [EPOCHS] --l_rate [LEARNING_RATE] 27 | 28 | ## To-do: 29 | 30 | - Add a few good-to-haves on model (change loss function, process multiple images, prevent gray imgs, etc) 31 | 32 | - Add list of dependencies 33 | 34 | - Train on all of ImageNet dataset to improve results while avoiding overfitting. 35 | 36 | ## References 37 | 38 | [1] [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) 39 | 40 | [2] [Deconvolution and Checkerboard Artifacts](http://distill.pub/2016/deconv-checkerboard/) 41 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | import torch 5 | from torchvision import transforms 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 12 | 13 | 14 | def load_img(filepath): 15 | """ 16 | TO-DO: Skip grayscale images or return error 17 | """ 18 | img = Image.open(filepath) 19 | return img 20 | 21 | 22 | class DatasetFromFolder(torch.utils.data.Dataset): 23 | def __init__(self, image_dir, input_transform=None, target_transform=None): 24 | super(DatasetFromFolder, self).__init__() 25 | self.image_filenames = [os.path.join(image_dir, x) for x in os.listdir(image_dir) if is_image_file(x)] 26 | 27 | self.input_transform = input_transform 28 | self.target_transform = target_transform 29 | 30 | def __getitem__(self, index): 31 | input = load_img(self.image_filenames[index]) 32 | target = input.copy() 33 | if self.input_transform: 34 | input = self.input_transform(input) 35 | if self.target_transform: 36 | target = self.target_transform(target) 37 | 38 | return input, target 39 | 40 | def __len__(self): 41 | return len(self.image_filenames) 42 | 43 | 44 | def calculate_valid_crop_size(crop_size, upscale_factor): 45 | return crop_size - (crop_size % upscale_factor) 46 | 47 | 48 | def img_transform(crop_size, upscale_factor=1): 49 | return transforms.Compose([ 50 | transforms.Scale(crop_size // upscale_factor), 51 | transforms.CenterCrop(crop_size // upscale_factor), 52 | transforms.ToTensor()]) 53 | 54 | 55 | def get_training_set(path, hr_size=256, upscale_factor=4): 56 | train_dir = path 57 | crop_size = calculate_valid_crop_size(hr_size, upscale_factor) 58 | 59 | return DatasetFromFolder(train_dir, 60 | input_transform=img_transform(crop_size, upscale_factor=upscale_factor), 61 | target_transform=img_transform(crop_size)) 62 | 63 | 64 | def denorm_meanstd(arr, mean, std): 65 | new_img = np.zeros_like(arr) 66 | for i in range(3): 67 | new_img[i, :, :] = arr[i, :, :] * std[i] 68 | new_img[i, :, :] += mean[i] 69 | return new_img 70 | 71 | 72 | def image_loader(image_name, max_sz=256): 73 | """ forked from pytorch tutorials """ 74 | r_image = Image.open(image_name) 75 | mindim = np.min((np.max(r_image.size[:2]), max_sz)) 76 | 77 | loader = transforms.Compose([transforms.CenterCrop(mindim), 78 | transforms.ToTensor()]) 79 | 80 | image = Variable(loader(r_image)) 81 | 82 | return image.unsqueeze(0) 83 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | """ 8 | 9 | PyTorch implementation of: 10 | [1] [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155) 11 | 12 | With a few modifications as suggested by: 13 | [2] [Deconvolution and Checkerboard Artifacts](http://distill.pub/2016/deconv-checkerboard/) 14 | 15 | """ 16 | 17 | 18 | def create_loss_model(vgg, end_layer, use_maxpool=True, use_cuda=False): 19 | """ 20 | [1] uses the output of vgg16 relu2_2 layer as a loss function (layer8 on PyTorch default vgg16 model). 21 | This function expects a vgg16 model from PyTorch and will return a custom version up until layer = end_layer 22 | that will be used as our loss function. 23 | """ 24 | 25 | vgg = copy.deepcopy(vgg) 26 | 27 | model = nn.Sequential() 28 | 29 | if use_cuda: 30 | model.cuda(device_id=0) 31 | 32 | i = 0 33 | for layer in list(vgg): 34 | 35 | if i > end_layer: 36 | break 37 | 38 | if isinstance(layer, nn.Conv2d): 39 | name = "conv_" + str(i) 40 | model.add_module(name, layer) 41 | 42 | if isinstance(layer, nn.ReLU): 43 | name = "relu_" + str(i) 44 | model.add_module(name, layer) 45 | 46 | if isinstance(layer, nn.MaxPool2d): 47 | name = "pool_" + str(i) 48 | if use_maxpool: 49 | model.add_module(name, layer) 50 | else: 51 | avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding) 52 | model.add_module(name, avgpool) 53 | i += 1 54 | return model 55 | 56 | 57 | class ResidualBlock(nn.Module): 58 | """ Residual blocks as implemented in [1] """ 59 | def __init__(self, num, use_cuda=False): 60 | super(ResidualBlock, self).__init__() 61 | if use_cuda: 62 | self.c1 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1).cuda(device_id=0) 63 | self.c2 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1).cuda(device_id=0) 64 | self.b1 = nn.BatchNorm2d(num).cuda(device_id=0) 65 | self.b2 = nn.BatchNorm2d(num).cuda(device_id=0) 66 | else: 67 | self.c1 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1) 68 | self.c2 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=1) 69 | self.b1 = nn.BatchNorm2d(num) 70 | self.b2 = nn.BatchNorm2d(num) 71 | 72 | def forward(self, x): 73 | h = F.relu(self.b1(self.c1(x))) 74 | h = self.b2(self.c2(h)) 75 | return h + x 76 | 77 | 78 | class UpsampleBlock(nn.Module): 79 | """ Upsample block suggested by [2] to remove checkerboard pattern from images """ 80 | def __init__(self, num, use_cuda=False): 81 | super(UpsampleBlock, self).__init__() 82 | if use_cuda: 83 | self.up1 = nn.UpsamplingNearest2d(scale_factor=2).cuda(device_id=0) 84 | self.c2 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=0).cuda(device_id=0) 85 | self.b3 = nn.BatchNorm2d(num).cuda(device_id=0) 86 | else: 87 | self.up1 = nn.UpsamplingNearest2d(scale_factor=2) 88 | self.c2 = nn.Conv2d(num, num, kernel_size=3, stride=1, padding=0) 89 | self.b3 = nn.BatchNorm2d(num) 90 | 91 | def forward(self, x): 92 | h = self.up1(x) 93 | h = F.pad(h, (1, 1, 1, 1), mode='reflect') 94 | h = self.b3(self.c2(h)) 95 | return F.relu(h) 96 | 97 | 98 | class SuperRes4x(nn.Module): 99 | def __init__(self, use_cuda=False, use_UpBlock=True): 100 | 101 | super(SuperRes4x, self).__init__() 102 | # To-do: Retrain with self.uplock and self.use_cuda as parameters 103 | 104 | # self.upblock = use_UpBlock 105 | # self.use_cuda = use_cuda 106 | upblock = True 107 | 108 | # Downsizing layer 109 | if use_cuda: 110 | self.c1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4).cuda(device_id=0) 111 | self.b2 = nn.BatchNorm2d(64).cuda(device_id=0) 112 | else: 113 | self.c1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4) 114 | self.b2 = nn.BatchNorm2d(64) 115 | 116 | if upblock: 117 | # Loop for residual blocks 118 | self.rs = [ResidualBlock(64, use_cuda=use_cuda) for i in range(4)] 119 | # Loop for upsampling 120 | self.up = [UpsampleBlock(64, use_cuda=use_cuda) for i in range(2)] 121 | else: 122 | # Loop for residual blocks 123 | self.rs = [ResidualBlock(64, use_cuda=use_cuda) for i in range(4)] 124 | # Transposed convolution blocks 125 | if self.use_cuda: 126 | self.dc2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1).cuda(device_id=0) 127 | self.bc2 = nn.BatchNorm2d(64).cuda(device_id=0) 128 | self.dc3 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1).cuda(device_id=0) 129 | self.bc3 = nn.BatchNorm2d(64).cuda(device_id=0) 130 | else: 131 | self.dc2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1) 132 | self.bc2 = nn.BatchNorm2d(64) 133 | self.dc3 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1) 134 | self.bc3 = nn.BatchNorm2d(64) 135 | 136 | # Last convolutional layer 137 | if use_cuda: 138 | self.c3 = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4).cuda(device_id=0) 139 | else: 140 | self.c3 = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4) 141 | 142 | def forward(self, x): 143 | upblock = True 144 | # Downsizing layer - Large Kernel ensures large receptive field on the residual blocks 145 | h = F.relu(self.b2(self.c1(x))) 146 | 147 | # Residual Layers 148 | for r in self.rs: 149 | h = r(h) # will go through all residual blocks in this loop 150 | 151 | if upblock: 152 | # Upsampling Layers - improvement suggested by [2] to remove "checkerboard pattern" 153 | for u in self.up: 154 | h = u(h) # will go through all upsampling blocks in this loop 155 | else: 156 | # As recommended by [1] 157 | h = F.relu(self.bc2(self.dc2(h))) 158 | h = F.relu(self.bc3(self.dc3(h))) 159 | 160 | # Last layer and scaled tanh activation - Scaled from 0 to 1 instead of 0 - 255 161 | h = F.tanh(self.c3(h)) 162 | h = torch.add(h, 1.) 163 | h = torch.mul(h, 0.5) 164 | return h 165 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy as np 5 | 6 | from PIL import Image 7 | 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | 12 | from torch.utils.data import DataLoader 13 | from torchvision import models 14 | from torch.autograd import Variable 15 | 16 | from model_utils import * 17 | from model import * 18 | 19 | 20 | def train_SuperRes(model, optimizer, loader, epochs=1, use_vgg_loss=True, use_cuda=False, save_best=True): 21 | 22 | model.train() 23 | 24 | loss_fn = nn.MSELoss(size_average=False) 25 | 26 | best_loss = 1e6 27 | 28 | epoch = [1] 29 | while epoch[0] <= epochs: 30 | 31 | epoch_loss = 0 32 | 33 | for iteration, batch in enumerate(loader, 1): 34 | inp, tgt = Variable(batch[0]), Variable(batch[1]) 35 | if use_cuda: 36 | inp = inp.cuda(device_id=0) 37 | tgt = tgt.cuda(device_id=0) 38 | 39 | # Forward 40 | pred = model(inp) 41 | 42 | # Compute loss 43 | if use_vgg_loss: 44 | vgg_loss_inp = vgg_loss(pred) 45 | vgg_loss_tgt = vgg_loss(tgt) 46 | loss = loss_fn(vgg_loss_inp, vgg_loss_tgt) 47 | else: 48 | loss = loss_fn(pred, tgt) 49 | 50 | # Update loss 51 | epoch_loss += loss.data[0] 52 | 53 | # Backward pass 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | 58 | if iteration % 100 == 0: 59 | print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(loader), loss.data[0])) 60 | 61 | if epoch[0] % 1 == 0: 62 | print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(loader))) 63 | 64 | if save_best and epoch_loss < best_loss: 65 | print("Saving model at epoch: {}".format(epoch)) 66 | torch.save(model, "best_model.pth") 67 | best_loss = epoch_loss 68 | 69 | epoch[0] += 1 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument( 75 | '--mode', 76 | type=str, 77 | default='upsample', 78 | help='Flag for Training or Upsampling. Valid entries are \'train\' and \'upsample\' ' 79 | ) 80 | parser.add_argument( 81 | '--image_folder', 82 | type=str, 83 | nargs='?', 84 | default='', 85 | help='Path to image folder. This is where the images will be read for training or saved for upsampling.' 86 | ) 87 | parser.add_argument( 88 | '--hi_res_size', 89 | type=int, 90 | nargs='?', 91 | default=256, 92 | help='Size of hi_res image for training mode. Default is 256.' 93 | ) 94 | parser.add_argument( 95 | '--trn_epochs', 96 | type=int, 97 | nargs='?', 98 | default=5, 99 | help='Number of training epochs. Default is 5.' 100 | ) 101 | parser.add_argument( 102 | '--l_rate', 103 | type=int, 104 | nargs='?', 105 | default=1e-3, 106 | help='Learning rate for training. Default is 1e-3.' 107 | ) 108 | parser.add_argument( 109 | '--target_image', 110 | type=str, 111 | nargs='?', 112 | default='', 113 | help='Path to target image to upsample.' 114 | ) 115 | parser.add_argument( 116 | '--model', 117 | type=str, 118 | nargs='?', 119 | default='', 120 | help='Path to saved model file.' 121 | ) 122 | 123 | args = parser.parse_args() 124 | if args.mode not in ['train', 'upsample']: 125 | print ("Please choose \'train\' or \'upsample\' as mode") 126 | sys.exit() 127 | 128 | if args.mode == 'train': 129 | print ("\n *** INITIALIZING *** ") 130 | if args.image_folder == '': 131 | print ("Please specify folder with training input images\n") 132 | sys.exit() 133 | else: 134 | print ("running in mode: {}".format(args.mode)) 135 | print ("training folder: {} \n".format(args.image_folder)) 136 | 137 | # Define if flag to use GPU and dtype for variables 138 | use_cuda = torch.cuda.is_available() 139 | dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor 140 | 141 | # Load training dataset 142 | train_dataset = get_training_set(path=args.image_folder, 143 | hr_size=args.hi_res_size, 144 | upscale_factor=4) 145 | if use_cuda: 146 | train_data_loader = DataLoader(dataset=train_dataset, 147 | num_workers=4, 148 | batch_size=8, 149 | shuffle=True) 150 | else: 151 | train_data_loader = DataLoader(dataset=train_dataset, 152 | num_workers=4, 153 | batch_size=1, 154 | shuffle=True) 155 | 156 | # Create VGG model for loss function 157 | vgg16 = models.vgg16(pretrained=True).features 158 | if use_cuda: 159 | vgg16.cuda(device_id=0); 160 | 161 | vgg_loss = create_loss_model(vgg16, 8, use_cuda=use_cuda) 162 | 163 | for param in vgg_loss.parameters(): 164 | param.requires_grad = False 165 | 166 | model = SuperRes4x(use_cuda=use_cuda) 167 | optimizer = optim.Adam(model.parameters(), lr=args.l_rate) 168 | train_SuperRes(model, optimizer, train_data_loader, use_cuda=use_cuda, 169 | epochs=args.trn_epochs, use_vgg_loss=False) 170 | 171 | if args.mode == 'upsample': 172 | print ("\n *** INITIALIZING *** ") 173 | if args.target_image == '': 174 | print ("Please specify image to upsample\n") 175 | sys.exit() 176 | if args.model == '': 177 | print ("Please specify path to saved model\n") 178 | sys.exit() 179 | else: 180 | print ("running in mode: {}".format(args.mode)) 181 | print ("upsample image: {}".format(args.target_image)) 182 | print("using model: {} \n".format(args.model)) 183 | 184 | use_cuda = torch.cuda.is_available() 185 | dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor 186 | 187 | model = SuperRes4x(use_cuda=use_cuda) 188 | model = torch.load(args.model) 189 | 190 | image = image_loader(args.target_image).type(dtype) 191 | upsampled = model(image) 192 | 193 | if use_cuda: 194 | upsampled = upsampled.cpu().data.numpy().squeeze() 195 | else: 196 | upsampled = upsampled.data.numpy().squeeze() 197 | 198 | upsampled = np.swapaxes(upsampled, 0, 2) 199 | upsampled = np.swapaxes(upsampled, 0, 1) 200 | upsampled = np.array(upsampled * 255, dtype=np.uint8) 201 | upsampled = Image.fromarray(upsampled, 'RGB') 202 | upsampled.save('up_' + args.target_image.split('/')[-1]) 203 | --------------------------------------------------------------------------------