├── .gitignore ├── LICENSE ├── README.md ├── callbacks.py ├── data.py ├── data ├── p1.png ├── p2.png ├── p3.png ├── p4.png ├── p5.png ├── p6.png └── p7.png ├── figures ├── c1.png ├── c2.png ├── c3.png ├── c4.png ├── c5.png ├── c6.png ├── dloss.png ├── gloss.png ├── movie1.gif ├── movie2.gif ├── ploss.png ├── rgan.png ├── rrdb.png └── wall.png ├── inference.py ├── layers.py ├── losses.py ├── models.py ├── notebook.ipynb ├── requirements.txt ├── results ├── sr-p1.jpg ├── sr-p2.jpg ├── sr-p3.jpg ├── sr-p4.jpg ├── sr-p5.jpg ├── sr-p6.jpg └── sr-p7.jpg ├── train.py └── weights └── GeneratorVG4(1520).h5 /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rishik Mourya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Super Resolution for Real Time Image Enhancement 2 | 3 | 4 | 5 |
6 | 7 |
8 | 9 | # Final Results from Validation data 10 | 11 | 12 | 13 | 16 | 19 | 20 |
14 | 15 | 17 | 18 |
21 | 22 | # Introduction 23 | 24 | It is not surprising that adversarial training is indeed possible for super resolution tasks. The problem with pretty much all the state of the art models is that they are just not usable for native deployment, because their sizes run into hundreds of MBs. 25 | 26 | Also, most of the known super resolution frameworks only work on a single compression method, such as bicubic, nearest or bilinear. This helps the model a lot in beating the previous state of the art scores, but such models perform poorly on real life low resolution images, which might not come from the same compression method the model was trained on. 27 | 28 | So for me the main goal with this project wasn't just to create yet another super resolution model, but to develop the lightest model possible which works well on arbitrary compression methods. 29 | 30 | 31 | With this goal in mind, I tried adopting the modern super resolution frameworks such as relativistic adversarial training and content loss optimization (I've mainly followed [ESRGAN](https://arxiv.org/abs/1809.00219), with a few changes in the objective function), and finally was able to create a **model of size 5MB!!!** 32 | 33 | # API Usage 34 | 35 | ```python 36 | from inference import enhance_image 37 | 38 | enhance_image( 39 | lr_image, # low resolution image as an array/tensor; or pass lr_path = <path to the low resolution image> 40 | sr_path, # output path where the enhanced image will be saved 41 | visualize, # whether to plot the low resolution and enhanced images 42 | size, # matplotlib figure size, e.g. (20, 16) 43 | ) # e.g. enhance_image(lr_path = 'data/p1.png', sr_path = 'results/sr-p1.jpg') 44 | ``` 45 | 46 | # CLI Usage 47 | 48 | ``` 49 | usage: inference.py [-h] [--lr-path LR_PATH] [--sr-path SR_PATH] 50 | 51 | Super Resolution for Real Time Image Enhancement 52 | 53 | optional arguments: 54 | -h, --help show this help message and exit 55 | --lr-path LR_PATH Path to the low resolution image. 56 | --sr-path SR_PATH Output path where the enhanced image would be saved. 57 | ``` 58 | 59 | # Model architectures 60 | 61 | The main building block of the generator is the Residual in Residual Dense Block (RRDB), which consists of a classic DenseNet-style module coupled with a residual connection. 62 | 63 |
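For reference, here is a minimal, self-contained sketch of the RRDB idea; the full implementation in `layers.py` additionally uses BatchNorm, PReLU activations and learnable residual scales, and `dense_block`/`rrdb` are just illustrative names:

```python
import tensorflow as tf
from tensorflow.keras import layers

def dense_block(x, filters = 64):
    # densely connected convolutions: each conv also sees the outputs of the previous ones
    x1 = layers.Conv2D(filters // 2, 3, padding = 'same', activation = 'relu')(x)
    x2 = layers.Conv2D(filters // 2, 3, padding = 'same', activation = 'relu')(tf.concat([x1, x], axis = 3))
    out = layers.Conv2D(filters, 3, padding = 'same')(tf.concat([x2, x1], axis = 3))
    return out + x  # local residual connection

def rrdb(x, filters = 64, residual_scale = 0.5):
    # three dense blocks chained, with a residual connection over the whole group
    out = dense_block(dense_block(dense_block(x, filters), filters), filters)
    return x + residual_scale * out
```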
64 | 65 |
66 | 67 | In the original paper the authors removed the batch normalization layers in order to avoid checkerboard artifacts, but due to the extremely small size of my model I found batch normalization quite effective, both for speeding up training and for better quality results. 68 | 69 | Another change I made to the original architecture is replacing the nearest-neighbor upsampling procedure with pixel shuffle (`tf.nn.depth_to_space`, which rearranges an `H x W x (C * r^2)` tensor into an `rH x rW x C` one), which helped a lot in producing highly detailed outputs given the size of the model. 70 | 71 | The discriminator is made up of blocks of classic convolution layers, each followed by batch normalization and a LeakyReLU non-linearity. 72 | 73 | # Relativistic Discriminator 74 | 75 | A relativistic discriminator tries to predict the probability that a real 76 | image is relatively more realistic than a fake one: 77 | 78 |
79 | $$D_{Ra}(x_r, x_f) = \sigma\big(C(x_r) - \mathbb{E}_{x_f}[C(x_f)]\big)$$ 80 |
81 | where $\sigma$ is the sigmoid function, $C(\cdot)$ is the raw (non-transformed) discriminator output, and $x_r$, $x_f$ are real and generated images respectively. 82 | So the discriminator and the generator are optimized to minimize these corresponding losses: 83 | 84 | Discriminator's adversarial loss: 85 |
86 | $$L_D^{Ra} = -\mathbb{E}_{x_r}\big[\log D_{Ra}(x_r, x_f)\big] - \mathbb{E}_{x_f}\big[\log\big(1 - D_{Ra}(x_f, x_r)\big)\big]$$ 87 |
88 | 89 | Generator's adversarial loss: 90 |
91 | $$L_G^{Ra} = -\mathbb{E}_{x_f}\big[\log D_{Ra}(x_f, x_r)\big] - \mathbb{E}_{x_r}\big[\log\big(1 - D_{Ra}(x_r, x_f)\big)\big]$$ 92 |
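In code, both losses reduce to binary cross-entropy on mean-shifted logits. A condensed sketch, mirroring the `ragan` branches in `losses.py`:

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits = True)

def disc_ragan_loss(real_logits, fake_logits):
    # real images should score higher than the average fake, and fakes lower than the average real
    real_loss = bce(tf.ones_like(real_logits), real_logits - tf.reduce_mean(fake_logits))
    fake_loss = bce(tf.zeros_like(fake_logits), fake_logits - tf.reduce_mean(real_logits))
    return real_loss + fake_loss

def gen_ragan_loss(real_logits, fake_logits):
    # the generator pushes fakes to look more realistic than the average real
    real_loss = bce(tf.ones_like(fake_logits), fake_logits - tf.reduce_mean(real_logits))
    fake_loss = bce(tf.zeros_like(real_logits), real_logits - tf.reduce_mean(fake_logits))
    return real_loss + fake_loss
```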
93 | 94 | # Perceptual Loss (Final objective for the Generator) 95 | 96 | The original perceptual loss introduced in the SRGAN paper combines the adversarial loss and the content loss obtained from the features of the final convolution layers of VGG Net. 97 | 98 | The perceptual loss is found to be more effective when constraining on features before activation rather than after activation as practiced in SRGAN. 99 | 100 | To make the perceptual loss more effective, I additionally added the pre-activation feature disparities from both shallow and deep layers, making the generator produce better results. 101 | 102 | In addition to the content loss and relativistic adversarial optimization, a simple pixel loss is also added to the generator's final objective as per the paper. 103 | 104 | Now based on my experiments I found it really hard for the generator to produce highly detailed outputs when it's also minimizing the pixel loss (I attribute this observation to the fact that my model is very small). 105 | 106 | This is a bit surprising, because optimizing an additional objective function which has the same optima should help speed up the training. My interpretation is that since super resolution is not a one-to-one mapping, as multiple plausible outputs exist for a single low resolution patch (more on patch size below), forcing the generator to converge to a single output would cause it to produce not detailed outputs but instead the average of all those possible outputs. 107 | 108 | So I tried reducing the pixel loss weight coefficient down to the 1e-2 to 1e-4 range as described in the paper, then compared the results with a generator trained without any pixel loss, and found that the pixel loss brings no significant visual improvement. So given my constrained training environment (Google Colab), I decided not to utilize the pixel loss as part of the generator's final objective. 109 | 110 | So here's the generator's final loss: 111 |
112 | $$L_G = L_{content} + \lambda \, L_G^{Ra}$$ 113 | with the content term weighted 1.0 and $\lambda$ = 0.09 (see `LOSS_WEIGHTS` in `train.py`).
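The content term above is computed on pre-activation VGG19 features. A minimal sketch of how the frozen feature extractor is built (this mirrors `setup_content_loss` in `losses.py`):

```python
from tensorflow.keras import applications, Model

vgg = applications.VGG19(input_shape = (224, 224, 3), include_top = False, weights = 'imagenet')
for idx in (5, 10, 20):
    vgg.layers[idx].activation = None  # take features *before* the ReLU is applied

feature_extractor = Model(
    inputs = vgg.input,
    outputs = [vgg.layers[idx].output for idx in (5, 10, 20)]
)
feature_extractor.trainable = False  # VGG stays frozen; it only provides features
```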
114 | 115 | # Patch size effect 116 | 117 | Ideally, the larger the patch size, the better the adversarial training, and hence the better the results, since an enlarged receptive field helps both models capture more semantic information. The paper therefore uses 96x96 up to 192x192 patch resolutions, but since I was constrained to Google Colab, my patch size was only 32x32 😶, and that too with a batch size of 8. 118 | 119 | # Multiple compression methods 120 | 121 | The idea is to make the generator independent of the compression that is applied to the training dataset, so that it's much more robust on real life samples. 122 | 123 | For this I randomly applied nearest, bilinear and bicubic compressions to the data points every time a batch is processed, as sketched below. 124 |
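A condensed sketch of the idea (the full version, which also keeps the dataset's own low resolution image as one of the options, is `random_compression` in `data.py`; `METHODS` is just an illustrative name):

```python
import tensorflow as tf

SCALE = 4
METHODS = ('bicubic', 'bilinear', 'nearest')

def random_compression(hr):
    # pick a different downsampling method every time the example is drawn
    idx = tf.random.uniform(shape = (), maxval = len(METHODS), dtype = tf.int32)
    size = tf.shape(hr)[:2] // SCALE
    branches = [lambda m = m: tf.image.resize(hr, size, method = m) for m in METHODS]
    lr = tf.switch_case(idx, branches)  # graph-friendly random branch selection
    return tf.round(tf.clip_by_value(lr, 0, 255)), hr
```

125 | # Validation Results after ~500 epochs 126 | 127 |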

|Loss type|Value| 128 | |--|--| 129 | |Content Loss (L1) [5th, 10th and 20th pre-activation features from VGG19]| ~38.582| 130 | |Style Loss (L1) [320th pre-activation features from EfficientNetB4]| ~1.1752| 131 | |Adversarial Loss| ~1.550| 132 | 133 | # Visual Comparisons 134 | 135 | Below are some of the common benchmark images that most super resolution papers compare on (these were not used in the training data). 136 | 137 |
138 | 139 |
140 |
141 | 142 |
143 |
144 | 145 |
146 |
147 | 148 |
149 |
150 | 151 |
152 |
153 | 154 |
155 | 156 | # Author - Rishik Mourya 157 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from data import visualize_samples 3 | 4 | class CheckpointCallback(tf.keras.callbacks.Callback): 5 | def __init__(self, checkpoint_dir, resume = False, epoch_step = 1): 6 | super(CheckpointCallback, self).__init__() 7 | 8 | self.checkpoint_dir = checkpoint_dir 9 | self.resume = resume 10 | self.epoch_step = epoch_step 11 | 12 | def setup_checkpoint(self, *args, **kwargs): 13 | self.checkpoint = tf.train.Checkpoint( 14 | generator = self.model.generator, 15 | discriminator = self.model.discriminator, 16 | generator_optimizer = self.model.generator.optimizer, 17 | discriminator_optimizer = self.model.discriminator.optimizer 18 | ) 19 | self.manager = tf.train.CheckpointManager( 20 | self.checkpoint, 21 | directory = self.checkpoint_dir, 22 | checkpoint_name = 'SRGAN', 23 | max_to_keep = 1 24 | ) 25 | 26 | if self.resume: 27 | self.load_checkpoint() 28 | else: 29 | print('Starting training from scratch...\n') 30 | 31 | def on_batch_end(self, batch, *args, **kwargs): 32 | # save every `epoch_step` epochs worth of batches; self.params is filled in by Keras during fit(), 33 | # and 'steps' is the number of batches per epoch (assumes the dataset's cardinality is known) 34 | if (batch + 1) % int(self.epoch_step * self.params['steps']) == 0: 35 | print(f"\n\nCheckpoint saved to {self.manager.save()}\n") 36 | 37 | def load_checkpoint(self): 38 | if self.manager.latest_checkpoint: 39 | self.checkpoint.restore(self.manager.latest_checkpoint) 40 | print(f"Checkpoint restored from '{self.manager.latest_checkpoint}'\n") 41 | else: 42 | print("No checkpoints found, initializing from scratch...\n") 43 | 44 | def set_lr(self, lr, beta_1 = 0.9): 45 | print(f'Continuing with learning rate: {lr}') 46 | self.model.generator.optimizer.beta_1 = beta_1 47 | self.model.generator.optimizer.learning_rate = lr 48 | self.model.discriminator.optimizer.beta_1 = beta_1 49 | self.model.discriminator.optimizer.learning_rate = lr 50 | 51 | class ProgressCallback(tf.keras.callbacks.Callback): 52 | def __init__(self, logs_step, generator_step): 53 | super(ProgressCallback, self).__init__() 54 | 55 | self.logs_step = logs_step 56 | self.generator_step = generator_step 57 | 58 | def on_batch_end(self, batch, logs, **kwargs): 59 | # visualize a few samples every `generator_step` epochs worth of batches 60 | if (batch + 1) % int(self.generator_step * self.params['steps']) == 0: 61 | if self.model.perceptual_finetune: 62 | visualize_samples( 63 | images_lists = (self.model.lrs[:3], self.model.srs[:3], self.model.hrs[:3]), 64 | titles = ('Low Resolution', 'Predicted Enhanced', 'High Resolution'), 65 | size = (11, 11) 66 | ) 67 | else: 68 | visualize_samples( 69 | images_lists = (self.model.lrs[:3], self.model.srs[:3]), 70 | titles = ('Low Resolution', 'Predicted Enhanced'), 71 | size = (7, 7) 72 | ) -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import matplotlib.pyplot as plt 3 | 4 | HR_SIZE = 128 5 | SCALE = 4 6 | LR_SIZE = HR_SIZE // SCALE 7 | BATCH_SIZE = 8 8 | 9 | 10 | # [====================================================] 11 | # [================ Random Compressions ===============] 12 | # [====================================================] 13 | 14 | def random_compression(example): 15 | hr = example['hr'] 16 | hr_shape = tf.shape(hr) 17 | compression_idx = tf.random.uniform(shape = (), maxval = 7, dtype = tf.int32) 18 | 19 | if compression_idx == 0 or compression_idx == 1: 20 | # bicubic 21 | lr = tf.image.resize(hr, [hr_shape[0] // SCALE,
hr_shape[1] // SCALE], method = 'bicubic') 22 | lr = tf.cast(tf.round(tf.clip_by_value(lr, 0, 255)), tf.uint8) 23 | elif compression_idx == 2 or compression_idx == 3: 24 | # bilinear 25 | lr = tf.image.resize(hr, [hr_shape[0] // SCALE, hr_shape[1] // SCALE], method = 'bilinear') 26 | lr = tf.cast(tf.round(tf.clip_by_value(lr, 0, 255)), tf.uint8) 27 | elif compression_idx == 4 or compression_idx == 5: 28 | # nearest 29 | lr = tf.image.resize(hr, [hr_shape[0] // SCALE, hr_shape[1] // SCALE], method = 'nearest') 30 | lr = tf.cast(tf.round(tf.clip_by_value(lr, 0, 255)), tf.uint8) 31 | else: 32 | # default: keep the dataset's own low resolution image 33 | lr = example['lr'] 34 | 35 | return lr, hr 36 | 37 | # [======================================================] 38 | # [============= Spatial Random Augmentations ===========] 39 | # [======================================================] 40 | 41 | @tf.function() 42 | def random_crop(lr, hr): 43 | lr_shape = tf.shape(lr)[:2] 44 | 45 | lr_w = tf.random.uniform(shape = (), maxval = lr_shape[1] - LR_SIZE + 1, dtype = tf.int32) 46 | lr_h = tf.random.uniform(shape = (), maxval = lr_shape[0] - LR_SIZE + 1, dtype = tf.int32) 47 | 48 | hr_w = lr_w * SCALE 49 | hr_h = lr_h * SCALE 50 | 51 | lr_cropped = lr[lr_h:lr_h + LR_SIZE, lr_w: lr_w + LR_SIZE] 52 | hr_cropped = hr[hr_h:hr_h + HR_SIZE, hr_w: hr_w + HR_SIZE] 53 | 54 | return lr_cropped, hr_cropped 55 | 56 | @tf.function() 57 | def random_rotate(lr, hr): 58 | rn = tf.random.uniform(shape = (), maxval = 4, dtype = tf.int32) 59 | return tf.image.rot90(lr, rn), tf.image.rot90(hr, rn) 60 | 61 | @tf.function() 62 | def random_spatial_augmentation(lrs, hrs): 63 | lrs, hrs = tf.cond( 64 | tf.random.uniform(shape = (), maxval = 1) < 0.5, 65 | lambda: (lrs, hrs), 66 | lambda: random_rotate(lrs, hrs) 67 | ) 68 | 69 | return tf.cast(lrs, tf.float32), tf.cast(hrs, tf.float32) 70 | 71 | def visualize_samples(images_lists, titles = None, size = (12, 12), masked = False): 72 | if titles: 73 | assert len(images_lists) == len(titles) 74 | 75 | cols = len(images_lists) 76 | 77 | for images in zip(*images_lists): 78 | plt.figure(figsize = size) 79 | for idx, image in enumerate(images): 80 | plt.subplot(1, cols, idx + 1) 81 | plt.imshow(tf.cast(tf.round(tf.clip_by_value(image, 0, 255)), tf.uint8)) 82 | plt.axis('off') 83 | if titles: 84 | plt.title(titles[idx]) 85 | plt.show() -------------------------------------------------------------------------------- /data/p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p1.png -------------------------------------------------------------------------------- /data/p2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p2.png -------------------------------------------------------------------------------- /data/p3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p3.png -------------------------------------------------------------------------------- /data/p4.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p4.png -------------------------------------------------------------------------------- /data/p5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p5.png -------------------------------------------------------------------------------- /data/p6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p6.png -------------------------------------------------------------------------------- /data/p7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/data/p7.png -------------------------------------------------------------------------------- /figures/c1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/c1.png -------------------------------------------------------------------------------- /figures/c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/c2.png -------------------------------------------------------------------------------- /figures/c3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/c3.png -------------------------------------------------------------------------------- /figures/c4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/c4.png -------------------------------------------------------------------------------- /figures/c5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/c5.png -------------------------------------------------------------------------------- /figures/c6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/c6.png -------------------------------------------------------------------------------- /figures/dloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/dloss.png -------------------------------------------------------------------------------- /figures/gloss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/gloss.png 
-------------------------------------------------------------------------------- /figures/movie1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/movie1.gif -------------------------------------------------------------------------------- /figures/movie2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/movie2.gif -------------------------------------------------------------------------------- /figures/ploss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/ploss.png -------------------------------------------------------------------------------- /figures/rgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/rgan.png -------------------------------------------------------------------------------- /figures/rrdb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/rrdb.png -------------------------------------------------------------------------------- /figures/wall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/figures/wall.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from data import visualize_samples 3 | from models import Generator 4 | 5 | generator = Generator() 6 | generator.load_weights('weights/GeneratorVG4(1520).h5') 7 | 8 | def enhance_image(lr_image = None, lr_path = None, sr_path = None, visualize = True, size = (20, 16)): 9 | assert any([lr_image is not None, lr_path]) 10 | if lr_path: 11 | lr_image = tf.image.decode_jpeg(tf.io.read_file(f"{lr_path}"), channels = 3) 12 | 13 | sr_image = generator(tf.expand_dims(lr_image, 0), training = False)[0] 14 | sr_image = tf.clip_by_value(sr_image, 0, 255) 15 | sr_image = tf.round(sr_image) 16 | sr_image = tf.cast(sr_image, tf.uint8) 17 | 18 | if visualize: 19 | visualize_samples(images_lists = [[lr_image], [sr_image]], titles = ['LR Image', 'SR_Image'], size = size) 20 | 21 | if sr_path: 22 | tf.io.write_file(sr_path, tf.image.encode_jpeg(sr_image)) 23 | 24 | if __name__ == '__main__': 25 | import argparse 26 | 27 | parser = argparse.ArgumentParser(description = 'Super Resolution for Real Time Image Enhancement') 28 | parser.add_argument('--lr-path', type = str, help = 'Path to the low resolution image.') 29 | parser.add_argument('--sr-path', type = str, default = None, help = 'Output path where the enhanced image would be saved.') 30 | 31 | args = parser.parse_args() 32 | 33 | enhance_image( 34 | lr_path = args.lr_path, 35 | sr_path = args.sr_path 36 | ) 37 | -------------------------------------------------------------------------------- /layers.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | 4 | kernel_init = tf.keras.initializers.GlorotNormal() 5 | 6 | class Conv2D(layers.Conv2D): 7 | def __init__(self, kernel_size = 3, padding = 'same', **kwargs): 8 | super(Conv2D, self).__init__( 9 | kernel_size = kernel_size, 10 | padding = padding, 11 | kernel_initializer = kernel_init, 12 | bias_initializer = tf.keras.initializers.Zeros(), 13 | **kwargs 14 | ) 15 | 16 | class Conv2DBlock(layers.Layer): 17 | def __init__(self, filters, batchnorm = True, activate = True, **kwargs): 18 | super(Conv2DBlock, self).__init__() 19 | 20 | self.conv = Conv2D(filters = filters, **kwargs) 21 | # I forgot to set use_bias to False... 22 | # you please set it to False if you want to save some parameters 23 | # because batchnorm right after conv layer is gonna make the biases obsolete 24 | self.batchnorm = layers.BatchNormalization() if batchnorm else None 25 | self.activate = layers.PReLU(shared_axes = [1, 2]) if activate else None 26 | 27 | def call(self, inputs): 28 | x = self.conv(inputs) 29 | if self.batchnorm: 30 | x = self.batchnorm(x) 31 | if self.activate: 32 | x = self.activate(x) 33 | return x 34 | 35 | class PixelShuffleUpSampling(layers.Layer): 36 | def __init__(self, filters, scale, **kwargs): 37 | super(PixelShuffleUpSampling, self).__init__(**kwargs) 38 | 39 | self.conv1 = Conv2DBlock(filters = filters, batchnorm = False, activate = False) 40 | self.upsample = layers.Lambda(lambda x: tf.nn.depth_to_space(x, scale)) 41 | self.prelu = layers.PReLU(shared_axes = [1, 2]) 42 | 43 | def call(self, x): 44 | x = self.conv1(x) 45 | x = self.upsample(x) 46 | x = self.prelu(x) 47 | return x 48 | 49 | class ResidualDenseBlock(layers.Layer): 50 | def __init__(self, filters = 64): 51 | super(ResidualDenseBlock, self).__init__() 52 | 53 | self.conv1 = Conv2DBlock(filters = filters // 2) 54 | self.conv2 = Conv2DBlock(filters = filters // 2) 55 | self.conv3 = Conv2DBlock(filters = filters, activate = False) 56 | 57 | def call(self, inputs): 58 | x1 = self.conv1(inputs) 59 | x2 = self.conv2(tf.concat([x1, inputs], 3)) 60 | outputs = self.conv3(tf.concat([x2, x1], 3)) 61 | 62 | return outputs + inputs 63 | 64 | class RRDBlock(layers.Layer): 65 | def __init__(self, filters, **kwargs): 66 | super(RRDBlock, self).__init__(**kwargs) 67 | 68 | self.rdb_1 = ResidualDenseBlock(filters) 69 | self.rdb_2 = ResidualDenseBlock(filters) 70 | self.rdb_3 = ResidualDenseBlock(filters) 71 | 72 | self.rrdb_inputs_scales = tf.Variable( 73 | tf.constant(value = 1.0, dtype = tf.float32, shape = [1, 1, 1, filters]), 74 | name = f'{self.name}_rrdb_inputs_scales', 75 | trainable = True 76 | ) 77 | self.rrdb_outputs_scales = tf.Variable( 78 | tf.constant(value = 0.5, dtype = tf.float32, shape = [1, 1, 1, filters]), 79 | name = f'{self.name}_rrdb_outputs_scales', 80 | trainable = True 81 | ) 82 | 83 | def call(self, inputs): 84 | x1 = self.rdb_1(inputs) 85 | x2 = self.rdb_2(x1) 86 | outputs = self.rdb_3(x2) 87 | 88 | return (self.rrdb_inputs_scales * inputs) + (self.rrdb_outputs_scales * outputs) -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import losses, applications, Model 3 | 4 | class PixelLossTraining: 5 | def setup_pixel_loss(self, pixel_loss): 6 | if pixel_loss == 'l1': 7 | self.pixel_loss_type 
= losses.MeanAbsoluteError() 8 | elif pixel_loss == 'l2': 9 | self.pixel_loss_type = losses.MeanSquaredError() 10 | 11 | @tf.function 12 | def pixel_loss(self, srs, hrs): 13 | return self.pixel_loss_type(hrs, srs) 14 | 15 | class VGGContentTraining: 16 | def setup_content_loss(self, content_loss): 17 | if content_loss == 'l1': 18 | self.content_loss_type = losses.MeanAbsoluteError() 19 | elif content_loss == 'l2': 20 | self.content_loss_type = losses.MeanSquaredError() 21 | 22 | vgg = applications.VGG19( 23 | input_shape = (224, 224, 3), 24 | include_top = False, 25 | weights = 'imagenet' 26 | ) 27 | 28 | vgg.layers[5].activation = None 29 | vgg.layers[10].activation = None 30 | vgg.layers[20].activation = None 31 | 32 | self.feature_extractor = Model( 33 | inputs = vgg.input, 34 | outputs = [ 35 | vgg.layers[5].output, 36 | vgg.layers[10].output, 37 | vgg.layers[20].output 38 | ] 39 | ) 40 | for layer in self.feature_extractor.layers: 41 | layer.trainable = False 42 | 43 | @tf.function 44 | def content_loss(self, srs, hrs): 45 | srs = applications.vgg19.preprocess_input(tf.image.resize(srs, (224, 224))) 46 | hrs = applications.vgg19.preprocess_input(tf.image.resize(hrs, (224, 224))) 47 | 48 | srs_features = self.feature_extractor(srs) 49 | hrs_features = self.feature_extractor(hrs) 50 | 51 | loss = 0.0 52 | for srs_feature, hrs_feature in zip(srs_features, hrs_features): 53 | loss += self.content_loss_type(hrs_feature / 12.75, srs_feature / 12.75) 54 | 55 | return loss 56 | 57 | class GramStyleTraining: 58 | # I tried using this loss but didn't notice any significant improvement, 59 | # so after the first few epochs of training I stopped optimizing the model for this loss to save some time 60 | def setup_gram_style_loss(self, style_loss): 61 | if style_loss == 'l1': 62 | self.style_loss_type = losses.MeanAbsoluteError() 63 | elif style_loss == 'l2': 64 | self.style_loss_type = losses.MeanSquaredError() 65 | 66 | efficientnet = applications.EfficientNetB4( 67 | input_shape = (224, 224, 3), 68 | include_top = False, 69 | weights = 'imagenet' 70 | ) 71 | 72 | self.style_features_extractor = Model( 73 | inputs = efficientnet.input, 74 | outputs = [ 75 | # efficientnet.layers[25].output, 76 | # efficientnet.layers[84].output, 77 | # efficientnet.layers[143].output, 78 | efficientnet.layers[320].output, 79 | # efficientnet.layers[467].output, 80 | ] 81 | ) 82 | for layer in self.style_features_extractor.layers: 83 | layer.trainable = False 84 | 85 | @tf.function 86 | def gram_matrix(self, features): 87 | features = tf.transpose(features, (0, 3, 1, 2)) # (-1, C, H, W) 88 | features_a = tf.reshape(features, (tf.shape(features)[0], tf.shape(features)[1], -1)) # (-1, C, H * W) 89 | features_b = tf.reshape(features, (tf.shape(features)[0], -1, tf.shape(features)[1])) # (-1, H * W, C) 90 | 91 | return tf.linalg.matmul(features_a, features_b) # (-1, C, C) 92 | 93 | @tf.function 94 | def gram_style_loss(self, srs, hrs): 95 | srs = applications.efficientnet.preprocess_input(tf.image.resize(srs, (224, 224))) 96 | hrs = applications.efficientnet.preprocess_input(tf.image.resize(hrs, (224, 224))) 97 | 98 | srs_features = self.style_features_extractor(srs)[0] # the extractor returns a single-element list -> (-1, H, W, C) 99 | hrs_features = self.style_features_extractor(hrs)[0] # (-1, H, W, C) 100 | 101 | # style_loss = 0.0 102 | # for srs_feature, hrs_feature in zip(srs_features, hrs_features): 103 | srs_gram = self.gram_matrix(srs_features) 104 | hrs_gram = self.gram_matrix(hrs_features) 105 | 106 | style_loss = self.style_loss_type(hrs_gram, srs_gram) 107 | 108
| return style_loss 109 | 110 | class AdversarialTraining: 111 | def setup_adversarial_loss(self, adv_loss): 112 | self.adv_loss_type = adv_loss 113 | self.binary_cross_entropy = losses.BinaryCrossentropy(from_logits = True) 114 | 115 | @tf.function 116 | def gen_adv_loss(self, fake_logits, real_logits = None): 117 | if self.adv_loss_type == 'gan': 118 | loss = self.binary_cross_entropy(tf.ones_like(fake_logits), fake_logits) 119 | 120 | elif self.adv_loss_type == 'ragan': 121 | real_loss = self.binary_cross_entropy(tf.ones_like(fake_logits), fake_logits - tf.reduce_mean(real_logits)) 122 | fake_loss = self.binary_cross_entropy(tf.zeros_like(real_logits), real_logits - tf.reduce_mean(fake_logits)) 123 | loss = real_loss + fake_loss 124 | 125 | return loss 126 | 127 | @tf.function 128 | def disc_adv_loss(self, fake_logits, real_logits): 129 | if self.adv_loss_type == 'gan': 130 | real_loss = self.binary_cross_entropy(tf.ones_like(real_logits), real_logits) 131 | fake_loss = self.binary_cross_entropy(tf.zeros_like(fake_logits), fake_logits) 132 | 133 | elif self.adv_loss_type == 'ragan': 134 | real_loss = self.binary_cross_entropy(tf.ones_like(real_logits), real_logits - tf.reduce_mean(fake_logits)) 135 | fake_loss = self.binary_cross_entropy(tf.zeros_like(fake_logits), fake_logits - tf.reduce_mean(real_logits)) 136 | 137 | return real_loss + fake_loss -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | from tensorflow.keras import layers, Model 3 | from data import HR_SIZE 4 | 5 | GEN_FILTERS = 64 6 | DISC_FILTERS = 64 7 | 8 | def Generator(): 9 | lr_image = layers.Input(shape = (None, None, 3)) 10 | 11 | spatial_feats = layers.Lambda(lambda x: x / 255.0)(lr_image) 12 | spatial_feats = Conv2DBlock(filters = GEN_FILTERS, kernel_size = 3, strides = 1, padding = 'same', batchnorm = False)(spatial_feats) 13 | spatial_feats = Conv2DBlock(filters = GEN_FILTERS, kernel_size = 1, strides = 1, padding = 'valid', batchnorm = False)(spatial_feats) 14 | 15 | rrdb1 = RRDBlock(GEN_FILTERS)(spatial_feats) 16 | rrdb2 = RRDBlock(GEN_FILTERS)(rrdb1) 17 | rrdb3 = RRDBlock(GEN_FILTERS)(rrdb2) 18 | rrdb4 = RRDBlock(GEN_FILTERS)(rrdb3) 19 | 20 | upsample1 = PixelShuffleUpSampling(GEN_FILTERS * 4, 2)(rrdb4) 21 | upsample2 = PixelShuffleUpSampling(GEN_FILTERS * 4, 2)(upsample1) 22 | 23 | x = Conv2DBlock(filters = GEN_FILTERS, batchnorm = False)(upsample2) 24 | x = Conv2DBlock(filters = 3, kernel_size = 3, activate = False, batchnorm = False)(x) 25 | x = layers.Activation('tanh')(x) 26 | 27 | sr_image = layers.Lambda(lambda x: (x + 1) * 127.5)(x) 28 | 29 | return Model(inputs = lr_image, outputs = sr_image, name = 'Generator') 30 | 31 | def Discriminator(): 32 | hr_image = layers.Input(shape = (HR_SIZE, HR_SIZE, 3)) 33 | x = layers.Lambda(lambda x: x / 127.5 - 1)(hr_image) 34 | 35 | x = Conv2D(kernel_size = 3, filters = DISC_FILTERS // 2)(x) 36 | x = layers.LeakyReLU(0.2)(x) 37 | 38 | x = Conv2D(filters = DISC_FILTERS // 2, kernel_size = 3, strides = 2)(x) 39 | # I forgot to set use_bias to False... 
40 | # you please set it to False if you want to save some parameters 41 | # because batchnorm right after conv layer is gonna make the biases obsolete 42 | x = layers.BatchNormalization()(x) 43 | x = layers.LeakyReLU(0.2)(x) 44 | 45 | x = Conv2D(filters = DISC_FILTERS, kernel_size = 3)(x) 46 | x = layers.BatchNormalization()(x) 47 | x = layers.LeakyReLU(0.2)(x) 48 | 49 | x = Conv2D(filters = DISC_FILTERS, kernel_size = 3, strides = 2)(x) 50 | x = layers.BatchNormalization()(x) 51 | x = layers.LeakyReLU(0.2)(x) 52 | 53 | x = Conv2D(filters = DISC_FILTERS * 2, kernel_size = 3, strides = 1)(x) 54 | x = layers.BatchNormalization()(x) 55 | x = layers.LeakyReLU(0.2)(x) 56 | 57 | x = Conv2D(filters = DISC_FILTERS * 2, kernel_size = 3, strides = 2)(x) 58 | x = layers.BatchNormalization()(x) 59 | x = layers.LeakyReLU(0.2)(x) 60 | 61 | x = Conv2D(filters = DISC_FILTERS * 4, kernel_size = 3, strides = 1)(x) 62 | x = layers.BatchNormalization()(x) 63 | x = layers.LeakyReLU(0.2)(x) 64 | 65 | x = Conv2D(filters = DISC_FILTERS * 4, kernel_size = 3, strides = 2)(x) 66 | x = layers.BatchNormalization()(x) 67 | x = layers.LeakyReLU(0.2)(x) 68 | 69 | x = layers.Flatten()(x) 70 | 71 | x = layers.Dense(1024)(x) 72 | x = layers.LeakyReLU(0.2)(x) 73 | x = layers.Dense(1024)(x) 74 | x = layers.LeakyReLU(0.2)(x) 75 | 76 | logits = layers.Dense(1)(x) 77 | 78 | return Model(inputs = hr_image, outputs = logits, name = 'Discriminator') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.4.1 2 | tensorflow-datasets==4.3.0 3 | matplotlib==3.2.1 -------------------------------------------------------------------------------- /results/sr-p1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p1.jpg -------------------------------------------------------------------------------- /results/sr-p2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p2.jpg -------------------------------------------------------------------------------- /results/sr-p3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p3.jpg -------------------------------------------------------------------------------- /results/sr-p4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p4.jpg -------------------------------------------------------------------------------- /results/sr-p5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p5.jpg -------------------------------------------------------------------------------- /results/sr-p6.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p6.jpg -------------------------------------------------------------------------------- /results/sr-p7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/results/sr-p7.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from tensorflow.keras import Model, optimizers 4 | import tensorflow_datasets as tfds 5 | from data import * 6 | from models import * 7 | from callbacks import * 8 | from losses import * 9 | 10 | train_data = tfds.load(f'div2k/bicubic_x{SCALE}', split = 'train', shuffle_files = True) 11 | train_data = train_data.map(random_compression, num_parallel_calls = tf.data.AUTOTUNE) 12 | train_data = train_data.map(random_crop, num_parallel_calls = tf.data.AUTOTUNE) 13 | train_data = train_data.batch(BATCH_SIZE, drop_remainder = True) 14 | train_data = train_data.map(random_spatial_augmentation, num_parallel_calls = tf.data.AUTOTUNE) 15 | 16 | train_data = train_data.prefetch(tf.data.AUTOTUNE) 17 | 18 | for lrs, hrs in train_data: 19 | break 20 | 21 | print(lrs.shape, hrs.shape) 22 | print(lrs.dtype, hrs.dtype) 23 | print(tf.reduce_min(lrs), tf.reduce_max(lrs)) 24 | print(tf.reduce_min(hrs), tf.reduce_max(hrs)) 25 | 26 | visualize_samples(images_lists = (lrs[:15], hrs[:15]), titles = ('Low Resolution', 'High Resolution'), size = (8, 8)) 27 | 28 | class SRGAN( 29 | Model, 30 | PixelLossTraining, 31 | GramStyleTraining, 32 | VGGContentTraining, 33 | AdversarialTraining 34 | ): 35 | def __init__( 36 | self, 37 | generator, 38 | discriminator, 39 | ): 40 | super(SRGAN, self).__init__(dynamic = True) 41 | 42 | self.generator = generator 43 | self.discriminator = discriminator 44 | 45 | def compile( 46 | self, 47 | 48 | generator_optimizer, 49 | discriminator_optimizer, 50 | 51 | perceptual_finetune, 52 | 53 | pixel_loss, 54 | style_loss, 55 | content_loss, 56 | adv_loss, 57 | 58 | loss_weights, 59 | ): 60 | super(SRGAN, self).compile() 61 | 62 | self.generator.optimizer = generator_optimizer 63 | self.discriminator.optimizer = discriminator_optimizer 64 | 65 | self.perceptual_finetune = perceptual_finetune 66 | 67 | self.setup_pixel_loss(pixel_loss) 68 | # self.setup_gram_style_loss(style_loss) 69 | # uncomment this to utilize style loss function 70 | self.setup_content_loss(content_loss) 71 | self.setup_adversarial_loss(adv_loss) 72 | 73 | if self.perceptual_finetune: 74 | self.loss_weights = loss_weights 75 | 76 | def train_step(self, batch): 77 | self.lrs = batch[0] 78 | self.hrs = batch[1] 79 | 80 | if self.perceptual_finetune: 81 | # [=================== Training Discriminator ===================] 82 | 83 | with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape: 84 | self.srs = self.generator(self.lrs, training = True) 85 | 86 | real_logits = self.discriminator(self.hrs, training = True) 87 | fake_logits = self.discriminator(self.srs, training = True) 88 | 89 | content_loss = self.loss_weights['content_loss'] * self.content_loss(self.srs, self.hrs) 90 | gen_adv_loss = self.loss_weights['adv_loss'] * self.gen_adv_loss(fake_logits, real_logits) 91 | perceptual_loss = content_loss + gen_adv_loss 92 | 93 | # style_loss =
self.loss_weights['style_loss'] * self.gram_style_loss(self.srs, self.hrs) 94 | # uncomment this and add it to gen_loss to utilize style loss function 95 | 96 | gen_loss = perceptual_loss 97 | 98 | disc_adv_loss = self.disc_adv_loss(fake_logits, real_logits) 99 | 100 | discriminator_gradients = disc_tape.gradient(disc_adv_loss, self.discriminator.trainable_variables) 101 | generator_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables) 102 | 103 | self.discriminator.optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables)) 104 | self.generator.optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables)) 105 | 106 | return { 107 | 'Perceptual Loss': perceptual_loss, 108 | # 'Style Loss': style_loss, 109 | 'Generator Adv Loss': gen_adv_loss, 110 | 'Discriminator Adv Loss': disc_adv_loss, 111 | } 112 | 113 | else: 114 | with tf.GradientTape() as gen_tape: 115 | self.srs = self.generator(self.lrs, training = True) 116 | 117 | pixel_loss = self.pixel_loss(self.srs, self.hrs) 118 | 119 | generator_gradients = gen_tape.gradient(pixel_loss, self.generator.trainable_variables) 120 | self.generator.optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables)) 121 | 122 | return { 123 | 'Pixel Loss': pixel_loss, 124 | } 125 | 126 | EPOCHS = 1000 127 | LR = 0.002 128 | BETA_1 = 0.9 129 | BETA_2 = 0.999 130 | 131 | PERCEPTUAL_FINETUNE = False 132 | # first train the model to simply minimize the pixel loss 133 | # you can set the LR a bit high, like .001 - .004 134 | # once your pixel loss has saturated, set PERCEPTUAL_FINETUNE to True, 135 | # and reduce the LR down to something like .0004 or .0002 136 | # keep monitoring your outputs, and manually reduce the lr when you see the outputs aren't improving anymore 137 | # I tried reducing the lr automatically when the loss plateaus, but it didn't work as expected 138 | 139 | PIXEL_LOSS = 'l1' 140 | STYLE_LOSS = 'l1' 141 | CONTENT_LOSS = 'l1' 142 | # the l1 loss type worked much better for all the above losses in my experiments 143 | ADV_LOSS = 'ragan' 144 | 145 | LOSS_WEIGHTS = {'content_loss': 1.0, 'adv_loss': 0.09, 'style_loss': 1.0} 146 | # Don't forget to tune the adv_loss weight, and observe the outputs per epoch 147 | # you'll get to see artifacts caused by the adversarial training 148 | # partial artifacts would be fine, but when you see the outputs are getting weird 149 | # then reduce the adv_loss weight 150 | 151 | CHECKPOINT_DIR = os.path.join('drive', 'MyDrive', 'Model-Checkpoints', 'Super Resolution') 152 | 153 | generator_optimizer = optimizers.Adam( 154 | learning_rate = LR, 155 | beta_1 = BETA_1, 156 | beta_2 = BETA_2 157 | ) 158 | discriminator_optimizer = optimizers.Adam( 159 | learning_rate = LR, 160 | beta_1 = BETA_1, 161 | beta_2 = BETA_2 162 | ) 163 | 164 | generator = Generator() 165 | generator.summary(100) 166 | 167 | discriminator = Discriminator() 168 | discriminator.summary(100) 169 | 170 | 171 | srgan = SRGAN(generator, discriminator) 172 | srgan.compile( 173 | generator_optimizer = generator_optimizer, 174 | discriminator_optimizer = discriminator_optimizer, 175 | 176 | perceptual_finetune = PERCEPTUAL_FINETUNE, 177 | pixel_loss = PIXEL_LOSS, 178 | style_loss = STYLE_LOSS, 179 | content_loss = CONTENT_LOSS, 180 | adv_loss = ADV_LOSS, 181 | 182 | loss_weights = LOSS_WEIGHTS 183 | ) 184 | 185 | ckpt_callback = CheckpointCallback( 186 | checkpoint_dir = CHECKPOINT_DIR, 187 | resume = True, 188 | epoch_step = 4 189 | ) 190 |
ckpt_callback.set_model(srgan) 191 | ckpt_callback.setup_checkpoint(srgan) 192 | ckpt_callback.set_lr(LR, BETA_1) 193 | 194 | srgan.fit( 195 | train_data.repeat(EPOCHS // 10), 196 | epochs = 10, 197 | callbacks = [ 198 | ckpt_callback, 199 | ProgressCallback( 200 | logs_step = 0.2, 201 | generator_step = 2 202 | ) 203 | ] 204 | ) -------------------------------------------------------------------------------- /weights/GeneratorVG4(1520).h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/braindotai/Real-Time-Super-Resolution/fa160991361624b3b748e4eae6a144ed60eafd5c/weights/GeneratorVG4(1520).h5 --------------------------------------------------------------------------------