├── images
│   └── diagram.JPG
├── dataloader.py
├── LICENSE
├── README.md
├── model.py
└── train.py


/images/diagram.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erikqu/EnhanceNet-PyTorch/HEAD/images/diagram.JPG

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset
import torch
import cv2

class SRDataset(Dataset):
    '''
    Pairs of low- and high-resolution images for single image super resolution.

    ids = list of image filenames shared by both directories
    lr  = path to the low resolution (LR) image directory
    hr  = path to the high resolution (HR) image directory
    '''
    def __init__(self, ids, lr, hr):
        self.dir = ids
        self.lr = lr
        self.hr = hr

    def __len__(self):
        return len(self.dir)

    def __getitem__(self, index):
        filename = self.dir[index]
        #flag 1 makes cv2 load 3-channel (BGR) color images
        lower = cv2.imread(self.lr + filename, 1)
        higher = cv2.imread(self.hr + filename, 1)
        #transpose HWC -> CHW so pytorch plays nice
        lower = lower.transpose((2, 0, 1))
        higher = higher.transpose((2, 0, 1))
        #pass numpy arrays to torch and make float tensors
        lower = torch.from_numpy(lower).float()
        higher = torch.from_numpy(higher).float()
        return lower, higher
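
#A minimal usage sketch (assumed paths and filenames, not shipped with the
#repo): wrap SRDataset in a DataLoader the same way train.py does.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    #hypothetical layout: the same filenames exist in both directories
    ids = ["0001.png", "0002.png"]
    dataset = SRDataset(ids, lr="data/lr_128/", hr="data/hr_512/")
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    lr_img, hr_img = next(iter(loader))
    print(lr_img.shape, hr_img.shape)  #e.g. (2, 3, 128, 128) and (2, 3, 512, 512)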
and Sch{\"o}lkopf, Bernhard and Hirsch, Michael}, 14 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, 15 | pages={4501--4510}, 16 | year={2017}, 17 | organization={IEEE}, 18 | url={https://arxiv.org/abs/1612.07919/} 19 | } 20 | ``` 21 | 22 | # Description 23 | 24 | ENET-PA here is implemented in PyTorch as there is no current implementation in PyTorch. All credit goes to Sajjad et al. Adversarial learning along with perceptual loss (hence P+A). The model is in the form of a GAN and does 4x upscaling of 64x64 images to 512x512. 25 | 26 | #### https://arxiv.org/abs/1612.07919/ 27 | 28 | # TODO 29 | 30 | - Add texture matching loss (ENET-PAT) 31 | - Make user friendly for out-of-box train and test 32 | 33 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torchvision.models import vgg19 4 | import torchvision 5 | 6 | ''' 7 | EnhanceNet Implementation in PyTorch by Erik Quintanilla 8 | 9 | Single Image Super Resolution 10 | 11 | https://arxiv.org/abs/1612.07919/ 12 | 13 | This program assumes GPU. 14 | ''' 15 | 16 | class ResidualBlock(nn.Module): 17 | def __init__(self, in_features): 18 | super(ResidualBlock, self).__init__() 19 | self.conv_block = nn.Sequential( 20 | nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1), 21 | nn.ReLU(), 22 | nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1), 23 | ) 24 | 25 | def forward(self, x): 26 | return self.conv_block(x) 27 | 28 | 29 | class Generator(nn.Module): 30 | def __init__(self, in_channels=3, out_channels=3, residual_blocks=10): 31 | super(Generator, self).__init__() 32 | self.conv1 = nn.Sequential( 33 | nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1), 34 | nn.ReLU()) 35 | self.conv2 = nn.Sequential( 36 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 37 | nn.ReLU()) 38 | #Residual blocks 39 | residuals = [] 40 | for _ in range(residual_blocks): 41 | residuals.append(ResidualBlock(64)) 42 | self.residuals = nn.Sequential(*residuals) 43 | 44 | #nearest neighbor upsample 45 | self.upsample = nn.Sequential( 46 | nn.Upsample(scale_factor=2, mode='bicubic'), 47 | nn.Conv2d(64, 64, 3, 1, 1), 48 | nn.ReLU(), 49 | nn.Upsample(scale_factor=2, mode='bicubic'), 50 | nn.Conv2d(64, 64, 3, 1, 1), 51 | nn.ReLU()) 52 | self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU()) 53 | self.conv4 = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) 54 | self.resize = nn.Upsample(scale_factor=4, mode='bicubic',align_corners=True) 55 | 56 | def forward(self, x): 57 | out = self.conv1(x) 58 | out = self.residuals(out) 59 | out = self.conv2(out) 60 | out= self.upsample(out) 61 | out = self.conv3(out) 62 | i_res = self.conv4(out) 63 | i_bicubic = self.resize(x) 64 | out = torch.add(i_bicubic ,i_res) 65 | return out 66 | 67 | class Discriminator(nn.Module): 68 | def __init__(self, input_shape): 69 | super(Discriminator, self).__init__() 70 | layers = [] 71 | self.input_shape = input_shape 72 | in_channels, in_height, in_width = self.input_shape 73 | self.output_shape = (1, 8, 8) 74 | 75 | def discriminator_block(in_filters, out_filters, first_block=False): 76 | layers = [] 77 | layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1)) 78 | if not first_block: 79 | layers.append(nn.BatchNorm2d(out_filters)) 80 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from torchvision.models import vgg19

'''
EnhanceNet Implementation in PyTorch by Erik Quintanilla

Single Image Super Resolution

https://arxiv.org/abs/1612.07919/

This program assumes GPU.
'''

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        #residual connection: add the input back onto the block's output
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, residual_blocks=10):
        super(Generator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU())
        #residual blocks
        residuals = []
        for _ in range(residual_blocks):
            residuals.append(ResidualBlock(64))
        self.residuals = nn.Sequential(*residuals)

        #two 2x upsampling stages (4x total); the paper uses nearest-neighbor
        #upsampling here, this implementation uses bicubic
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bicubic'),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bicubic'),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU())
        self.conv4 = nn.Sequential(nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1))
        self.resize = nn.Upsample(scale_factor=4, mode='bicubic', align_corners=True)

    def forward(self, x):
        out = self.conv1(x)
        out = self.residuals(out)
        out = self.conv2(out)
        out = self.upsample(out)
        out = self.conv3(out)
        #the network only predicts a residual on top of bicubic upsampling
        i_res = self.conv4(out)
        i_bicubic = self.resize(x)
        out = i_bicubic + i_res
        return out

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        layers = []
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        #six stride-2 blocks downsample by 2**6 = 64, so 512x512 HR inputs
        #produce an 8x8 map of patch verdicts
        self.output_shape = (1, 8, 8)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        in_filters = in_channels
        for i, out_filters in enumerate([16, 32, 64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

class Vgg_Features(nn.Module):
    def __init__(self, pool_layer_num=9):
        '''
        "To capture both low-level and high-level features,
        we use a combination of the second and fifth pooling
        layers and compute the MSE on their feature activations."

        - Sajjadi et al.
        '''
        #vgg19 has maxpooling layers at indices [4, 9, 18, 27, 36];
        #slice with pool_layer_num + 1 so the pooling layer itself is included
        super(Vgg_Features, self).__init__()
        model = vgg19(pretrained=True)
        self.features = nn.Sequential(*list(model.features.children())[:pool_layer_num + 1])

    def forward(self, img):
        return self.features(img)
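
#A quick shape sanity check (a sketch, not part of the training pipeline):
#the generator maps 128x128 LR inputs to 512x512 outputs, and the
#discriminator maps 512x512 images to an 8x8 patch verdict map.
if __name__ == "__main__":
    x = torch.randn(1, 3, 128, 128)
    g = Generator(residual_blocks=10)
    d = Discriminator(input_shape=(3, 128, 128))
    sr = g(x)
    print(sr.shape)     #torch.Size([1, 3, 512, 512])
    print(d(sr).shape)  #torch.Size([1, 1, 8, 8])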
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image, make_grid
from model import *
from dataloader import *

'''
EnhanceNet Implementation in PyTorch by Erik Quintanilla

Single Image Super Resolution

https://arxiv.org/abs/1612.07919/

This program assumes GPU.
'''

#hyperparams
cuda = torch.cuda.is_available()
#torch.cuda.empty_cache()
height = 128   #LR input height; the generator outputs 4x (512)
width = 128    #LR input width
channels = 3
lr = .0009
b1 = .5
b2 = .9
batch_size = 3
n_epochs = 5
hr_shape = (height, width)  #only channels is actually used by the discriminator

save_interval = 100

#build models
generator = Generator(residual_blocks=10)
discriminator = Discriminator(input_shape=(channels, *hr_shape))
features_2 = Vgg_Features(pool_layer_num=9)   #9 is the index of the second pooling layer in vgg19.features
features_5 = Vgg_Features(pool_layer_num=36)  #36 is the index of the fifth pooling layer
#send to gpu
generator = generator.cuda()
discriminator = discriminator.cuda()
loss = torch.nn.MSELoss().cuda()
features_2.cuda()
features_5.cuda()

#set optimizers
g_opti = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
d_opti = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

#set up the data generator
curdir = ""
imagedir = np.load(curdir + "_ids.npy")
lowres = "M:/" + curdir + "_128/"
highres = "M:/" + curdir + "_512/"
gen = SRDataset(ids=imagedir, lr=lowres, hr=highres)
train_loader = DataLoader(gen, batch_size=batch_size, shuffle=True, num_workers=0)

load_weights = True

if load_weights:
    print("Loading old weights...")
    tmp = torch.load("M:/Experiments/EnhanceNet-PyTorch/saved_models/generator_0_4800.pth")
    generator.load_state_dict(tmp)
    tmp = torch.load("M:/Experiments/EnhanceNet-PyTorch/saved_models/discriminator_0_4800.pth")
    discriminator.load_state_dict(tmp)
    print("Best old weights loaded!")

Tensor = torch.cuda.FloatTensor

for epoch in range(n_epochs):
    for i, (lr_img, hr_img) in enumerate(train_loader):

        #move batches to the GPU; valid and fake are the discriminator targets,
        #sized from the actual batch in case the last batch is smaller
        lr_img = lr_img.type(Tensor)
        hr_img = hr_img.type(Tensor)
        valid = Tensor(np.ones((lr_img.size(0), *discriminator.output_shape)))
        fake = Tensor(np.zeros((lr_img.size(0), *discriminator.output_shape)))

        '''Generator'''

        #reset grads
        g_opti.zero_grad()

        #get our "fake" images from our generator
        generated_hr = generator(lr_img)

        #what does the discriminator think of these fake images?
        verdict = discriminator(generated_hr)

        #fetch loss

        #perceptual loss uses both the second and fifth pooling layers;
        #_2, _5 here denote the pooling layer
        generated_features = features_2(generated_hr)
        real_features = features_2(hr_img)
        feature_loss_2 = loss(generated_features, real_features.detach())

        generated_features = features_5(generated_hr)
        real_features = features_5(hr_img)
        feature_loss_5 = loss(generated_features, real_features.detach())

        total_feature_loss = (.2 * feature_loss_2) + feature_loss_5

        g_loss = loss(verdict, valid) + total_feature_loss #ENET-PA

        #backprop that loss
        g_loss.backward()
        #update our optimizer
        g_opti.step()

        '''Discriminator'''

        d_opti.zero_grad()

        #we shuffle the true and fake verdicts together so it's not trivial which is which
        #(note: MSE is order-invariant across the batch, so this is cosmetic)
        hr_verdicts = torch.cat([discriminator(hr_img), discriminator(generated_hr.detach())], dim=0)
        hr_labels = torch.cat([valid, fake], dim=0)
        idxs = torch.randperm(hr_labels.size(0), device=hr_labels.device)
        hr_verdicts = hr_verdicts[idxs]
        hr_labels = hr_labels[idxs]

        d_loss = loss(hr_verdicts, hr_labels)
        d_loss.backward()
        d_opti.step()

        print("D: %f G: %f \t Epoch: (%i/%i) Batch: (%i/%i)" % (d_loss.item(), g_loss.item(), epoch, n_epochs, i, len(train_loader)))
        if i % save_interval == 0:
            #cv2 loads BGR, so put the channels back in RGB order for saving
            generated_hr = generated_hr[:, [2, 1, 0]]
            hr_img = hr_img[:, [2, 1, 0]]
            #fancy grid so we can view
            generated_hr = make_grid(generated_hr, nrow=1, normalize=True)
            hr_img = make_grid(hr_img, nrow=1, normalize=True)
            tmp = torch.cat((hr_img, generated_hr), -1)
            save_image(tmp, "samples/%d.png" % i, normalize=False)
            #save_image(lr_img, "samples/originals/%d.png" % i, normalize=False)
            #save generator and discriminator
            torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
            torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)

--------------------------------------------------------------------------------