├── vgg_para └── README.txt ├── save_para └── README.txt ├── test ├── 0.jpg ├── 1.jpg ├── 15.jpg ├── 16.jpg ├── 17.jpg ├── 19.jpg ├── 21.jpg ├── 23.jpg ├── 30.jpg └── 31.jpg ├── IMAGES ├── sr1.jpg ├── sr2.jpg ├── sr3.jpg ├── sr4.jpg ├── sr5.jpg ├── sr6.jpg ├── down1.jpg ├── down2.jpg ├── down3.jpg ├── down4.jpg ├── down5.jpg ├── down6.jpg ├── bicubic1.jpg ├── bicubic2.jpg ├── bicubic3.jpg ├── bicubic4.jpg ├── bicubic5.jpg ├── bicubic6.jpg ├── networks.jpg └── wganloss.jpg ├── Raw img ├── 0.jpg ├── 1.jpg ├── 23.jpg ├── 30.jpg └── 31.jpg ├── results ├── 40.jpg └── 60.jpg ├── ImageNet ├── ILSVRC2012_val_00000001.JPEG ├── ILSVRC2012_val_00000002.JPEG └── ILSVRC2012_val_00000003.JPEG ├── LICENSE ├── utils.py ├── main.py ├── README.md ├── network.py ├── train.py └── ops.py /vgg_para/README.txt: -------------------------------------------------------------------------------- 1 | SAVE THE PARA OF VGG19 -------------------------------------------------------------------------------- /save_para/README.txt: -------------------------------------------------------------------------------- 1 | SAVE THE MODEL IN THIS FOLDER -------------------------------------------------------------------------------- /test/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/0.jpg -------------------------------------------------------------------------------- /test/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/1.jpg -------------------------------------------------------------------------------- /test/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/15.jpg -------------------------------------------------------------------------------- /test/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/16.jpg -------------------------------------------------------------------------------- /test/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/17.jpg -------------------------------------------------------------------------------- /test/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/19.jpg -------------------------------------------------------------------------------- /test/21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/21.jpg -------------------------------------------------------------------------------- /test/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/23.jpg -------------------------------------------------------------------------------- /test/30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/30.jpg -------------------------------------------------------------------------------- /test/31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/test/31.jpg -------------------------------------------------------------------------------- /IMAGES/sr1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/sr1.jpg -------------------------------------------------------------------------------- /IMAGES/sr2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/sr2.jpg -------------------------------------------------------------------------------- /IMAGES/sr3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/sr3.jpg -------------------------------------------------------------------------------- /IMAGES/sr4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/sr4.jpg -------------------------------------------------------------------------------- /IMAGES/sr5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/sr5.jpg -------------------------------------------------------------------------------- /IMAGES/sr6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/sr6.jpg -------------------------------------------------------------------------------- /Raw img/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/Raw img/0.jpg -------------------------------------------------------------------------------- /Raw img/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/Raw img/1.jpg -------------------------------------------------------------------------------- /Raw img/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/Raw img/23.jpg -------------------------------------------------------------------------------- /Raw img/30.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/Raw img/30.jpg -------------------------------------------------------------------------------- /Raw img/31.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/Raw img/31.jpg -------------------------------------------------------------------------------- /results/40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/results/40.jpg -------------------------------------------------------------------------------- /results/60.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/results/60.jpg -------------------------------------------------------------------------------- /IMAGES/down1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/down1.jpg -------------------------------------------------------------------------------- /IMAGES/down2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/down2.jpg -------------------------------------------------------------------------------- /IMAGES/down3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/down3.jpg -------------------------------------------------------------------------------- /IMAGES/down4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/down4.jpg -------------------------------------------------------------------------------- /IMAGES/down5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/down5.jpg -------------------------------------------------------------------------------- /IMAGES/down6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/down6.jpg -------------------------------------------------------------------------------- /IMAGES/bicubic1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/bicubic1.jpg -------------------------------------------------------------------------------- /IMAGES/bicubic2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/bicubic2.jpg -------------------------------------------------------------------------------- /IMAGES/bicubic3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/bicubic3.jpg -------------------------------------------------------------------------------- /IMAGES/bicubic4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/bicubic4.jpg -------------------------------------------------------------------------------- /IMAGES/bicubic5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/bicubic5.jpg -------------------------------------------------------------------------------- /IMAGES/bicubic6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/bicubic6.jpg -------------------------------------------------------------------------------- /IMAGES/networks.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/networks.jpg -------------------------------------------------------------------------------- /IMAGES/wganloss.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/IMAGES/wganloss.jpg -------------------------------------------------------------------------------- /ImageNet/ILSVRC2012_val_00000001.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/ImageNet/ILSVRC2012_val_00000001.JPEG -------------------------------------------------------------------------------- /ImageNet/ILSVRC2012_val_00000002.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/ImageNet/ILSVRC2012_val_00000002.JPEG -------------------------------------------------------------------------------- /ImageNet/ILSVRC2012_val_00000003.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/HEAD/ImageNet/ILSVRC2012_val_00000003.JPEG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 MarTinGuo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc as misc 3 | from PIL import Image 4 | import os 5 | 6 | 7 | 8 | def read_crop_data(path, batch_size, shape, factor): 9 | h = shape[0] 10 | w = shape[1] 11 | c = shape[2] 12 | filenames = os.listdir(path) 13 | rand_selects = np.random.randint(0, filenames.__len__(), [batch_size]) 14 | batch = np.zeros([batch_size, h, w, c]) 15 | downsampled = np.zeros([batch_size, h//factor, w//factor, c]) 16 | for idx, select in enumerate(rand_selects): 17 | try: 18 | img = np.array(Image.open(path + filenames[select]))[:, :, :3] 19 | crop = random_crop(img, h) 20 | batch[idx, :, :, :] = crop 21 | downsampled[idx, :, :, :] = misc.imresize(crop, [h // factor, w // factor]) 22 | except: 23 | img = np.array(Image.open(path + filenames[0]))[:, :, :3] 24 | crop = random_crop(img, h) 25 | batch[idx, :, :, :] = crop 26 | downsampled[idx, :, :, :] = misc.imresize(crop, [h//factor, w//factor]) 27 | return batch, downsampled 28 | 29 | def random_crop(img, size): 30 | h = img.shape[0] 31 | w = img.shape[1] 32 | start_x = np.random.randint(0, h - size + 1) 33 | start_y = np.random.randint(0, w - size + 1) 34 | return img[start_x:start_x + size, start_y:start_y + size, :] 35 | 36 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from train import train, test, up_scale 3 | from PIL import Image 4 | import numpy as np 5 | import scipy.misc as misc 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument("--batch_size", type=int, default=4) #The paper: 16 11 | parser.add_argument("--lambd", type=float, default=1e-3) 12 | parser.add_argument("--learning_rate", type=float, default=1e-4) 13 | parser.add_argument("--clip_v", type=float, default=0.05) 14 | parser.add_argument("--B", type=int, default=5) #The paper: 16 15 | parser.add_argument("--max_itr", type=int, default=100000) #The paper: 600000 16 | parser.add_argument("--path_trainset", type=str, default="./ImageNet/") 17 | parser.add_argument("--path_vgg", type=str, default="./vgg_para/") 18 | parser.add_argument("--path_save_model", type=str, default="./save_para/") 19 | 20 | parser.add_argument("--is_trained", type=bool, default=False) 21 | 22 | args = parser.parse_args() 23 | 24 | if args.is_trained: 25 | parser.add_argument("--path_test_img", type=str, default="./test/0.jpg") 26 | args = parser.parse_args() 27 | img = np.array(Image.open(args.path_test_img)) 28 | h, w = img.shape[0] // 4, img.shape[1] // 4 #down sample factor: 4 29 | downsampled_img = misc.imresize(img, [h, w]) 30 | test(downsampled_img, img, args.B) 31 | else: 32 | train(batch_size=args.batch_size, lambd=args.lambd, init_lr=args.learning_rate, clip_v=args.clip_v, B=args.B, 33 | max_itr=args.max_itr, path_trainset=args.path_trainset, path_vgg=args.path_vgg, 34 | path_save_model=args.path_save_model) 35 | 36 | 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SRGAN-with-WGAN-Loss-TensorFlow 2 | SRGAN with WGAN loss function in TensorFlow 3 | ## Introduction 4 | This code mainly address the problem of super resolution, [Super Resolution Generative Adversarial Networks](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf) 5 | #### There are four different from the paper: 6 | 1. The loss function, we use WGAN loss, instead of standard GAN loss. 7 | 2. The network architecture, Because of our poor device, in generator, we just use 5 residual block (paper: 16), and in discriminator, we use the standard DCGAN's discriminator. 8 | 3. The training set, device problem again,:cry: we just use a part of ImageNet ([ImageNet Val](http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar)) which just contains 50,000 images. 9 | 4. The max iteration, we just train the model about 100,000 iterations, instead of the paper 600,000. 10 | 11 | ![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/networks.jpg) 12 | 13 | ## How to use 14 | 1. Download the dataset [ImageNet Val](http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar) 15 | 2. unzip dataset and put it into the folder 'ImageNet' 16 | ``` 17 | ├── test 18 | ├── save_para 19 | ├── results 20 | ├── vgg_para 21 | ├── ImageNet 22 |    ├── ILSVRC2012_val_00000001.JPEG 23 | ├── ILSVRC2012_val_00000002.JPEG 24 | ├── ILSVRC2012_val_00000003.JPEG 25 | ├── ILSVRC2012_val_00000004.JPEG 26 | ├── ILSVRC2012_val_00000005.JPEG 27 | ├── ILSVRC2012_val_00000006.JPEG 28 | ... 29 | ``` 30 | 3. execute the file main.py 31 | ## Requirements 32 | - python3.5 33 | - tensorflow1.4.0 34 | - pillow 35 | - numpy 36 | - scipy 37 | - skimage 38 | ## Results 39 | #### Train procedure WGAN Loss 40 | ![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/wganloss.jpg) 41 | 42 | |Down sampled|Bicubic (x4)|SRGAN (x4)| 43 | |-|-|-| 44 | |![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/down1.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/bicubic1.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/sr1.jpg)| 45 | |![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/down2.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/bicubic2.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/sr2.jpg)| 46 | |![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/down3.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/bicubic3.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/sr3.jpg)| 47 | |![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/down4.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/bicubic4.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/sr4.jpg)| 48 | |![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/down5.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/bicubic5.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/sr5.jpg)| 49 | |![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/down6.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/bicubic6.jpg)|![](https://github.com/MingtaoGuo/SRGAN-with-WGAN-Loss-TensorFlow/blob/master/IMAGES/sr6.jpg)| 50 | ## Reference 51 | [1] Ledig C, Theis L, Huszár F, et al. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network[C]//CVPR. 2017, 2(3): 4. 52 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | from ops import conv, relu, conv_, max_pooling, fully_connected, batchnorm, prelu, leaky_relu, B_residual_blocks, pixelshuffler 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | class generator: 6 | def __init__(self, name, B): 7 | self.name = name 8 | self.B = B 9 | 10 | def __call__(self, inputs, train_phase): 11 | with tf.variable_scope(self.name): 12 | inputs = conv("conv1", inputs, 64, 9) 13 | inputs = prelu("alpha1", inputs) 14 | skip_connection = tf.identity(inputs) 15 | #The paper has 16 residual blocks 16 | for b in range(1, self.B + 1): 17 | inputs = B_residual_blocks("B"+str(b), inputs, train_phase) 18 | # inputs = B_residual_blocks("B2", inputs, train_phase) 19 | # inputs = B_residual_blocks("B3", inputs, train_phase) 20 | # inputs = B_residual_blocks("B4", inputs, train_phase) 21 | # inputs = B_residual_blocks("B5", inputs, train_phase) 22 | inputs = conv("conv2", inputs, 64, 3) 23 | inputs = batchnorm(inputs, train_phase, "BN") 24 | inputs = inputs + skip_connection 25 | inputs = conv("conv3", inputs, 256, 3) 26 | inputs = pixelshuffler(inputs, 2) 27 | inputs = prelu("alpha2", inputs) 28 | inputs = conv("conv4", inputs, 256, 3) 29 | inputs = pixelshuffler(inputs, 2) 30 | inputs = prelu("alpha3", inputs) 31 | inputs = conv("conv5", inputs, 3, 9) 32 | return tf.nn.tanh(inputs) 33 | 34 | def var_list(self): 35 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 36 | 37 | class discriminator: 38 | def __init__(self, name): 39 | self.name = name 40 | 41 | def __call__(self, inputs, train_phase): 42 | with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE): 43 | # inputs = tf.random_crop(inputs, [-1, 70, 70, 3]) 44 | inputs = conv("conv1_1", inputs, 64, 3, 2) 45 | inputs = leaky_relu(inputs, 0.2) 46 | # inputs = conv("conv1_2", inputs, 64, 3, is_SN=True) 47 | # inputs = leaky_relu(inputs, 0.2) 48 | inputs = conv("conv2_1", inputs, 128, 3, 2) 49 | inputs = batchnorm(inputs, train_phase, "BN1") 50 | inputs = leaky_relu(inputs, 0.2) 51 | # inputs = conv("conv2_2", inputs, 128, 3, is_SN=True) 52 | # inputs = leaky_relu(inputs, 0.2) 53 | inputs = conv("conv3_1", inputs, 256, 3, 2) 54 | inputs = batchnorm(inputs, train_phase, "BN2") 55 | inputs = leaky_relu(inputs, 0.2) 56 | # inputs = conv("conv3_2", inputs, 256, 3, is_SN=True) 57 | # inputs = leaky_relu(inputs, 0.2) 58 | inputs = conv("conv4_1", inputs, 512, 3, 2) 59 | inputs = batchnorm(inputs, train_phase, "BN3") 60 | inputs = leaky_relu(inputs, 0.2) 61 | # inputs = fully_connected("fc", inputs, 512, is_SN=True) 62 | output = fully_connected("output", inputs, 1) 63 | return output 64 | 65 | def var_list(self): 66 | return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name) 67 | 68 | def vggnet(inputs, vgg_path): 69 | inputs = (inputs + 1) * 127.5 70 | inputs = tf.reverse(inputs, [-1]) - np.array([103.939, 116.779, 123.68]) 71 | para = np.load(vgg_path+"vgg19.npy", encoding="latin1").item() 72 | inputs = relu(conv_(inputs, para["conv1_1"][0], para["conv1_1"][1])) 73 | inputs = relu(conv_(inputs, para["conv1_2"][0], para["conv1_2"][1])) 74 | inputs = max_pooling(inputs) 75 | inputs = relu(conv_(inputs, para["conv2_1"][0], para["conv2_1"][1])) 76 | inputs = relu(conv_(inputs, para["conv2_2"][0], para["conv2_2"][1])) 77 | F = inputs 78 | inputs = max_pooling(inputs) 79 | inputs = relu(conv_(inputs, para["conv3_1"][0], para["conv3_1"][1])) 80 | inputs = relu(conv_(inputs, para["conv3_2"][0], para["conv3_2"][1])) 81 | inputs = relu(conv_(inputs, para["conv3_3"][0], para["conv3_3"][1])) 82 | inputs = max_pooling(inputs) 83 | inputs = relu(conv_(inputs, para["conv4_1"][0], para["conv4_1"][1])) 84 | inputs = relu(conv_(inputs, para["conv4_2"][0], para["conv4_2"][1])) 85 | inputs = relu(conv_(inputs, para["conv4_3"][0], para["conv4_3"][1])) 86 | return F 87 | 88 | 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from network import generator, discriminator, vggnet 2 | import tensorflow as tf 3 | from utils import read_crop_data 4 | import numpy as np 5 | from PIL import Image 6 | import scipy.misc as misc 7 | from skimage.measure import compare_psnr as psnr 8 | from skimage.measure import compare_ssim as ssim 9 | import time 10 | 11 | def up_scale(downsampled_img): 12 | downsampled_img = downsampled_img[np.newaxis, :, :, :] 13 | downsampled = tf.placeholder(tf.float32, [None, None, None, 3]) 14 | train_phase = tf.placeholder(tf.bool) 15 | G = generator("generator") 16 | SR = G(downsampled, train_phase) 17 | sess = tf.Session() 18 | sess.run(tf.global_variables_initializer()) 19 | saver = tf.train.Saver() 20 | saver.restore(sess, "./save_para/.\\model.ckpt") 21 | SR_img = sess.run(SR, feed_dict={downsampled: downsampled_img/127.5 - 1, train_phase: False}) 22 | Image.fromarray(np.uint8((SR_img[0, :, :, :] + 1)*127.5)).show() 23 | Image.fromarray(np.uint8((downsampled_img[0, :, :, :]))).show() 24 | sess.close() 25 | 26 | def test(downsampled_img, img, B): 27 | downsampled_img = downsampled_img[np.newaxis, :, :, :] 28 | downsampled = tf.placeholder(tf.float32, [None, None, None, 3]) 29 | train_phase = tf.placeholder(tf.bool) 30 | G = generator("generator", B) 31 | SR = G(downsampled, train_phase) 32 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 33 | sess.run(tf.global_variables_initializer()) 34 | saver = tf.train.Saver() 35 | saver.restore(sess, "./save_para/.\\model.ckpt") 36 | SR_img = sess.run(SR, feed_dict={downsampled: downsampled_img/127.5 - 1, train_phase: False}) 37 | Image.fromarray(np.uint8((SR_img[0, :, :, :] + 1)*127.5)).show() 38 | Image.fromarray(np.uint8((downsampled_img[0, :, :, :]))).show() 39 | h = img.shape[0] 40 | w = img.shape[1] 41 | bic_img = misc.imresize(downsampled_img[0, :, :, :], [h, w]) 42 | Image.fromarray(np.uint8((bic_img))).show() 43 | SR_img = misc.imresize(SR_img[0, :, :, :], [h, w]) 44 | p = psnr(img, SR_img) 45 | s = ssim(img, SR_img, multichannel=True) 46 | p1 = psnr(img, bic_img) 47 | s1 = ssim(img, bic_img, multichannel=True) 48 | print("SR PSNR: %f, SR SSIM:%f, BIC PSNR: %f, BIC SSIM: %f"%(p, s, p1, s1)) 49 | sess.close() 50 | 51 | def train(batch_size=4, lambd=1e-3, init_lr=1e-4, clip_v=0.05, B=16, max_itr=100000, path_trainset="./ImageNet/", path_vgg="./vgg_para/", path_save_model="./save_para/"): 52 | inputs = tf.placeholder(tf.float32, [None, 96, 96, 3]) 53 | downsampled = tf.placeholder(tf.float32, [None, 24, 24, 3]) 54 | train_phase = tf.placeholder(tf.bool) 55 | learning_rate = tf.placeholder(tf.float32) 56 | G = generator("generator", B) 57 | D = discriminator("discriminator") 58 | SR = G(downsampled, train_phase) 59 | phi = vggnet(tf.concat([inputs, SR], axis=0), path_vgg) 60 | phi = tf.split(phi, num_or_size_splits=2, axis=0) 61 | phi_gt = phi[0] 62 | phi_sr = phi[1] 63 | real_logits = D(inputs, train_phase) 64 | fake_logits = D(SR, train_phase) 65 | D_loss = tf.reduce_mean(fake_logits) - tf.reduce_mean(real_logits) 66 | G_loss = -tf.reduce_mean(fake_logits) * lambd + tf.nn.l2_loss(phi_sr - phi_gt) / batch_size 67 | clip_D = [var.assign(tf.clip_by_value(var, -clip_v, clip_v)) for var in D.var_list()] 68 | D_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(D_loss, var_list=D.var_list()) 69 | G_opt = tf.train.RMSPropOptimizer(learning_rate).minimize(G_loss, var_list=G.var_list()) 70 | sess = tf.Session() 71 | sess.run(tf.global_variables_initializer()) 72 | saver = tf.train.Saver() 73 | # saver.restore(sess, "./save_para/.\\model.ckpt") 74 | lr0 = init_lr 75 | for itr in range(max_itr): 76 | if itr == max_itr // 2 or itr == max_itr * 3 // 4: 77 | lr0 = lr0 / 10 78 | s0 = time.time() 79 | batch, down_batch = read_crop_data(path_trainset, batch_size, [96, 96, 3], 4) 80 | e0 = time.time() 81 | batch = batch/127.5 - 1 82 | down_batch = down_batch/127.5 - 1 83 | s1 = time.time() 84 | sess.run(D_opt, feed_dict={inputs: batch, downsampled: down_batch, train_phase: True, learning_rate: lr0}) 85 | sess.run(clip_D) 86 | sess.run(G_opt, feed_dict={inputs: batch, downsampled: down_batch, train_phase: True, learning_rate: lr0}) 87 | e1 = time.time() 88 | if itr % 200 == 0: 89 | [d_loss, g_loss, sr] = sess.run([D_loss, G_loss, SR], feed_dict={downsampled: down_batch, inputs: batch, train_phase: False}) 90 | raw = np.uint8((batch[0] + 1) * 127.5) 91 | bicub = misc.imresize(np.uint8((down_batch[0] + 1) * 127.5), [96, 96]) 92 | gen = np.uint8((sr[0, :, :, :] + 1) * 127.5) 93 | print("Iteration: %d, D_loss: %f, G_loss: %e, PSNR: %f, SSIM: %f, Read_time: %f, Update_time: %f" % (itr, d_loss, g_loss, psnr(raw, gen), ssim(raw, gen, multichannel=True), e0 - s0, e1 - s1)) 94 | Image.fromarray(np.concatenate((raw, bicub, gen), axis=1)).save("./results/" + str(itr) + ".jpg") 95 | if itr % 5000 == 0: 96 | saver.save(sess, path_save_model+"model.ckpt") 97 | 98 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | 7 | def conv(name, inputs, nums_out, k_size, strides=1, is_SN=False): 8 | nums_in = int(inputs.shape[-1]) 9 | with tf.variable_scope(name): 10 | kernel = tf.get_variable("weights", [k_size, k_size, nums_in, nums_out], initializer=tf.truncated_normal_initializer(stddev=0.02)) 11 | bias = tf.get_variable("bias", [nums_out], initializer=tf.constant_initializer(0.)) 12 | if is_SN: 13 | inputs = tf.nn.conv2d(inputs, spectral_normalization(name, kernel), [1, strides, strides, 1], "SAME") + bias 14 | else: 15 | inputs = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], "SAME") + bias 16 | return inputs 17 | 18 | def conv_(inputs, w, b): 19 | return tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME") + b 20 | 21 | def max_pooling(inputs): 22 | return tf.nn.max_pool(inputs, [1, 2, 2, 1], [1, 2, 2, 1], "SAME") 23 | 24 | def deconv(name, inputs, nums_out, k_size, strides=2): 25 | nums_in = int(inputs.shape[-1]) 26 | B = tf.shape(inputs)[0] 27 | H = inputs.shape[1] 28 | W = inputs.shape[2] 29 | with tf.variable_scope(name): 30 | kernel = tf.get_variable("weights", [k_size, k_size, nums_out, nums_in], initializer=tf.truncated_normal_initializer(stddev=0.02)) 31 | bias = tf.get_variable("bias", [nums_out], initializer=tf.constant_initializer(0.)) 32 | inputs = tf.nn.conv2d_transpose(inputs, kernel, [B, H * 2, W * 2, nums_out], [1, strides, strides, 1], "SAME") + bias 33 | return inputs 34 | 35 | def B_residual_blocks(name, inputs, train_phase): 36 | temp = tf.identity(inputs) 37 | with tf.variable_scope(name): 38 | inputs = conv("conv1", inputs, 64, 3) 39 | inputs = batchnorm(inputs, train_phase, "BN1") 40 | inputs = prelu("alpha1", inputs) 41 | inputs = conv("conv2", inputs, 64, 3) 42 | inputs = batchnorm(inputs, train_phase, "BN2") 43 | return temp + inputs 44 | 45 | def pixelshuffler(inputs, factor): 46 | B = tf.shape(inputs)[0] 47 | H = tf.shape(inputs)[1] 48 | W = tf.shape(inputs)[2] 49 | nums_in = int(inputs.shape[-1]) 50 | nums_out = nums_in // factor ** 2 51 | inputs = tf.split(inputs, num_or_size_splits=nums_out, axis=-1) 52 | output = 0 53 | for idx, split in enumerate(inputs): 54 | temp = tf.reshape(split, [B, H, W, factor, factor]) 55 | temp = tf.transpose(temp, perm=[0, 1, 4, 2, 3]) 56 | temp = tf.reshape(temp, [B, H * factor, W * factor, 1]) 57 | if idx == 0: 58 | output = temp 59 | else: 60 | output = tf.concat([output, temp], axis=-1) 61 | return output 62 | 63 | def prelu(name, inputs): 64 | with tf.variable_scope(name): 65 | slope = tf.get_variable(name+"alpha", [1], initializer=tf.constant_initializer(0.01)) 66 | return tf.maximum(inputs, inputs * slope) 67 | 68 | def relu(inputs): 69 | return tf.nn.relu(inputs) 70 | 71 | def tanh(inputs): 72 | return tf.nn.tanh(inputs) 73 | 74 | def leaky_relu(inputs, slope=0.2): 75 | return tf.maximum(slope * inputs, inputs) 76 | 77 | def global_sum_pooling(inputs): 78 | return tf.reduce_sum(inputs, axis=[1, 2]) 79 | 80 | def batchnorm(x, train_phase, scope_bn): 81 | #Batch Normalization 82 | #Ioffe S, Szegedy C. Batch normalization: accelerating deep network training by reducing internal covariate shift[J]. 2015:448-456. 83 | with tf.variable_scope(scope_bn): 84 | beta = tf.Variable(tf.constant(0.0, shape=[x.shape[-1]]), name='beta', trainable=True) 85 | gamma = tf.Variable(tf.constant(1.0, shape=[x.shape[-1]]), name='gamma', trainable=True) 86 | batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments') 87 | ema = tf.train.ExponentialMovingAverage(decay=0.5) 88 | 89 | def mean_var_with_update(): 90 | ema_apply_op = ema.apply([batch_mean, batch_var]) 91 | with tf.control_dependencies([ema_apply_op]): 92 | return tf.identity(batch_mean), tf.identity(batch_var) 93 | 94 | mean, var = tf.cond(train_phase, mean_var_with_update, 95 | lambda: (ema.average(batch_mean), ema.average(batch_var))) 96 | normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3) 97 | return normed 98 | 99 | 100 | def fully_connected(name, inputs, nums_out, is_SN=False): 101 | inputs = tf.layers.flatten(inputs) 102 | with tf.variable_scope(name): 103 | W = tf.get_variable("weights", [inputs.shape[-1], nums_out], initializer=tf.truncated_normal_initializer(stddev=0.02)) 104 | b = tf.get_variable("bias", [nums_out]) 105 | if is_SN: 106 | return tf.matmul(inputs, spectral_normalization(name, W)) + b 107 | else: 108 | return tf.matmul(inputs, W) + b 109 | 110 | def _l2normalize(v, eps=1e-12): 111 | return v / tf.sqrt(tf.reduce_sum(tf.square(v)) + eps) 112 | 113 | 114 | def max_singular_value(W, u=None, Ip=1): 115 | if u is None: 116 | u = tf.get_variable("u", [1, W.shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) #1 x ch 117 | _u = u 118 | _v = 0 119 | for _ in range(Ip): 120 | _v = _l2normalize(tf.matmul(_u, W), eps=1e-12) 121 | _u = _l2normalize(tf.matmul(_v, W, transpose_b=True), eps=1e-12) 122 | sigma = tf.reduce_sum(tf.matmul(_u, W) * _v) 123 | return sigma, _u, _v 124 | 125 | def spectral_normalization(name, W, Ip=1): 126 | u = tf.get_variable(name + "_u", [1, W.shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) # 1 x ch 127 | W_mat = tf.transpose(tf.reshape(W, [-1, W.shape[-1]])) 128 | sigma, _u, _ = max_singular_value(W_mat, u, Ip) 129 | with tf.control_dependencies([tf.assign(u, _u)]): 130 | W_sn = W / sigma 131 | return W_sn 132 | 133 | 134 | 135 | 136 | --------------------------------------------------------------------------------