├── checkpoints └── checkpoint ├── .gitignore ├── PSNR.py ├── data ├── README.md ├── aug_test.m └── aug_train.m ├── MODEL.py ├── MODEL_FACTORIZED.py ├── README.md ├── PLOT.py ├── TEST.py └── VDSR.py /checkpoints/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "VDSR_adam_epoch_021.ckpt-74008" 2 | all_model_checkpoint_paths: "VDSR_adam_epoch_021.ckpt-74008" 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | trash/ 3 | *.pyc 4 | data/train/ 5 | data/train 6 | data/test 7 | data/Set* 8 | data/291 9 | data/91 10 | *.mat 11 | checkpoints/*ckpt* 12 | psnr/*ckpt* 13 | preserve.zip 14 | -------------------------------------------------------------------------------- /PSNR.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import math 4 | 5 | def psnr(target, ref, scale): 6 | """Return the PSNR (dB) between `target` and `ref` for pixel values in [0, 1]. 7 | 8 | A border of `scale` pixels is cropped from the first two axes (rows, columns) 9 | before comparing. Identical images give float('inf'). 10 | """ 11 | # compare in float64: subtracting unsigned integer images would wrap around 12 | target_data = np.asarray(target, dtype=np.float64) 13 | ref_data = np.asarray(ref, dtype=np.float64) 14 | if scale > 0: 15 | # the slice [0:-0] selects nothing, so crop only for a positive scale 16 | target_data = target_data[scale:-scale, scale:-scale] 17 | ref_data = ref_data[scale:-scale, scale:-scale] 18 | 19 | diff = ref_data - target_data 20 | rmse = math.sqrt(np.mean(diff ** 2.)) 21 | if rmse == 0: 22 | # identical images: PSNR is unbounded 23 | return float('inf') 24 | return 20*math.log10(1.0/rmse) 25 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # data generation 2 | 0. Download train/test data from [original author's project page](http://cv.snu.ac.kr/research/VDSR/) 3 | 1. Download and unzip 291 dataset, and set the proper directory in 'aug_train.m'. 4 | 2. Download and unzip other test dataset (Set5, Set14, B100, Urban100), and set the proper directory in 'aug_test.m'. 5 | 3.
run 'aug_train.m' and 'aug_test.m' matlab code for patch/data generation 6 | 7 | 8 | - Please note that data are generated/manipulated in Matlab, for a good bicubic interpolation and reproducibility. (OpenCV2 interpolation is strange) 9 | - Too much data will make the network diverge. I'm currently using patches with original/rotate90/original flipped/rotate90 flipped, and you can find the data [here](https://drive.google.com/file/d/0B4KsMpU0Beosc1FNQVlFZWlMOG8/view?usp=sharing) 10 | -------------------------------------------------------------------------------- /data/aug_test.m: -------------------------------------------------------------------------------- 1 | 2 | target = 'Set14'; 3 | dataDir = fullfile('./', target); 4 | count = 0; 5 | f_lst = dir(fullfile(dataDir, '*.bmp')); 6 | folder = fullfile('test', target); 7 | mkdir(folder); 8 | for f_iter = 1:numel(f_lst) 9 | % disp(f_iter); 10 | f_info = f_lst(f_iter); 11 | if f_info.name == '.' 12 | continue; 13 | end 14 | f_path = fullfile(dataDir,f_info.name); 15 | disp(f_path); 16 | img_raw = imread(f_path); 17 | if size(img_raw,3)==3 18 | img_raw = rgb2ycbcr(img_raw); 19 | img_raw = img_raw(:,:,1); 20 | % else 21 | % img_raw = rgb2ycbcr(repmat(img_raw, [1 1 3])); 22 | end 23 | 24 | img_raw = im2double(img_raw); 25 | 26 | img_size = size(img_raw); 27 | width = img_size(2); 28 | height = img_size(1); 29 | 30 | img_raw = img_raw(1:height-mod(height,12),1:width-mod(width,12),:); 31 | 32 | img_size = size(img_raw); 33 | 34 | img_2 = imresize(imresize(img_raw,1/2,'bicubic'),[img_size(1),img_size(2)],'bicubic'); 35 | img_3 = imresize(imresize(img_raw,1/3,'bicubic'),[img_size(1),img_size(2)],'bicubic'); 36 | img_4 = imresize(imresize(img_raw,1/4,'bicubic'),[img_size(1),img_size(2)],'bicubic'); 37 | 38 | patch_name = sprintf('%s/%d',folder,count); 39 | 40 | save(patch_name, 'img_raw'); 41 | save(sprintf('%s_2', patch_name), 'img_2'); 42 | save(sprintf('%s_3', patch_name), 'img_3'); 43 | save(sprintf('%s_4', patch_name), 
'img_4'); 44 | 45 | count = count + 1; 46 | display(count); 47 | 48 | 49 | end 50 | -------------------------------------------------------------------------------- /MODEL.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def model(input_tensor): 5 | with tf.device("/gpu:0"): 6 | weights = [] 7 | tensor = None 8 | 9 | #conv_00_w = tf.get_variable("conv_00_w", [3,3,1,64], initializer=tf.contrib.layers.xavier_initializer()) 10 | conv_00_w = tf.get_variable("conv_00_w", [3,3,1,64], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/9))) 11 | conv_00_b = tf.get_variable("conv_00_b", [64], initializer=tf.constant_initializer(0)) 12 | weights.append(conv_00_w) 13 | weights.append(conv_00_b) 14 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(input_tensor, conv_00_w, strides=[1,1,1,1], padding='SAME'), conv_00_b)) 15 | 16 | for i in range(18): 17 | #conv_w = tf.get_variable("conv_%02d_w" % (i+1), [3,3,64,64], initializer=tf.contrib.layers.xavier_initializer()) 18 | conv_w = tf.get_variable("conv_%02d_w" % (i+1), [3,3,64,64], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/64))) 19 | conv_b = tf.get_variable("conv_%02d_b" % (i+1), [64], initializer=tf.constant_initializer(0)) 20 | weights.append(conv_w) 21 | weights.append(conv_b) 22 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1,1,1,1], padding='SAME'), conv_b)) 23 | 24 | #conv_w = tf.get_variable("conv_19_w", [3,3,64,1], initializer=tf.contrib.layers.xavier_initializer()) 25 | conv_w = tf.get_variable("conv_20_w", [3,3,64,1], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/64))) 26 | conv_b = tf.get_variable("conv_20_b", [1], initializer=tf.constant_initializer(0)) 27 | weights.append(conv_w) 28 | weights.append(conv_b) 29 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1,1,1,1], padding='SAME'), conv_b) 30 | 31 | tensor = tf.add(tensor, 
input_tensor) 32 | return tensor, weights 33 | -------------------------------------------------------------------------------- /MODEL_FACTORIZED.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def model_factorized(input_tensor): 5 | with tf.device("/gpu:0"): 6 | weights = [] 7 | tensor = None 8 | 9 | conv_00_w = tf.get_variable("conv_00_w", [3,3,1,64], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/32))) 10 | conv_00_b = tf.get_variable("conv_00_b", [64], initializer=tf.constant_initializer(0)) 11 | weights.append(conv_00_w) 12 | weights.append(conv_00_b) 13 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(input_tensor, conv_00_w, strides=[1,1,1,1], padding='SAME'), conv_00_b)) 14 | 15 | depth = 50 16 | for i in range(depth-1): 17 | depthwise_filter = tf.get_variable("depth_conv_%02d_w" % (i+1), [3,3,64,1], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/32))) 18 | pointwise_filter = tf.get_variable("point_conv_%02d_w" % (i+1), [1,1,64,64], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/1/128))) 19 | conv_b = tf.get_variable("conv_%02d_b" % (i+1), [64], initializer=tf.constant_initializer(0)) 20 | weights.append(depthwise_filter) 21 | weights.append(pointwise_filter) 22 | weights.append(conv_b) 23 | conv_tensor = tf.nn.bias_add(tf.nn.separable_conv2d(tensor, depthwise_filter, pointwise_filter, [1,1,1,1], padding='SAME'), conv_b) 24 | """ 25 | conv_tensor = tf.nn.relu(tf.nn.depthwise_conv2d(tensor, depthwise_filter, [1,1,1,1], padding='SAME')) 26 | conv_tensor = tf.nn.bias_add(tf.nn.conv2d(conv_tensor, pointwise_filter, [1,1,1,1], padding='VALID'), conv_b) 27 | """ 28 | tensor = tf.nn.relu(tf.add(tensor, conv_tensor)) 29 | 30 | 31 | conv_w = tf.get_variable("conv_%02d_w"%depth, [3,3,64,1], initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/64))) 32 | conv_b = tf.get_variable("conv_%02d_b"%depth, [1], 
initializer=tf.constant_initializer(0)) 33 | weights.append(conv_w) 34 | weights.append(conv_b) 35 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1,1,1,1], padding='SAME'), conv_b) 36 | 37 | tensor = tf.add(tensor, input_tensor) 38 | return tensor, weights 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-vdsr 2 | 3 | ## Overview 4 | This is a TensorFlow implementation of ["Accurate Image Super-Resolution Using Very Deep Convolutional Networks", CVPR '16](http://cv.snu.ac.kr/research/VDSR/VDSR_CVPR2016.pdf). 5 | - [The author's project page](http://cv.snu.ac.kr/research/VDSR/) 6 | - To download the data required for training/testing, please refer to the README.md in the data directory. 7 | 8 | ## Files 9 | - VDSR.py : main training file. 10 | - MODEL.py : model definition. 11 | - MODEL_FACTORIZED.py : model definition for Factorized CNN. (not recommended; kept for record purposes only) 12 | - PSNR.py : PSNR computation in Python 13 | - TEST.py : test all the saved checkpoints 14 | - PLOT.py : plot the test results from TEST.py 15 | 16 | ## How To Use 17 | ### Training 18 | ```shell 19 | # to start from scratch 20 | python VDSR.py 21 | # to resume from a checkpoint 22 | python VDSR.py --model_path ./checkpoints/CHECKPOINT_NAME.ckpt 23 | ``` 24 | ### Testing 25 | ```shell 26 | # this will test all the checkpoints in the ./checkpoints directory.
27 | # and save the results in the ./psnr directory 28 | python TEST.py 29 | ``` 30 | ### Plot Result 31 | ```shell 32 | # plot the PSNR results stored in the ./psnr directory 33 | python PLOT.py 34 | ``` 35 | 36 | ## Result 37 | The checkpoint file is [here](https://drive.google.com/file/d/0B4KsMpU0BeosbDB2NllZZkdvY1U/view?usp=sharing&resourcekey=0-G924x9W58xBdEWG0ACKlLw) 38 | ##### Results on Set 5 39 | 40 | | Scale | Bicubic | VDSR | tf_VDSR | 41 | |:---------:|:-------:|:----:|:-------:| 42 | | **2x** - PSNR/SSIM| 33.66/0.9299 | 37.53/0.9587 | 37.24 | 43 | | **3x** - PSNR/SSIM| 30.39/0.8682 | 33.66/0.9213 | 33.37 | 44 | | **4x** - PSNR/SSIM| 28.42/0.8104 | 31.35/0.8838 | 31.09 | 45 | 46 | ##### Results on Set 14 47 | 48 | | Scale | Bicubic | VDSR | tf_VDSR | 49 | |:---------:|:-------:|:----:|:-------:| 50 | | **2x** - PSNR/SSIM| 30.24/0.8688 | 33.03/0.9124 | 32.80 | 51 | | **3x** - PSNR/SSIM| 27.55/0.7742 | 29.77/0.8314 | 29.67 | 52 | | **4x** - PSNR/SSIM| 26.00/0.7027 | 28.01/0.7674 | 27.87 | 53 | 54 | ## Remarks 55 | - Training is further accelerated by fetching data asynchronously. 56 | - I tried to accelerate the network with the idea from [Factorized CNN](https://arxiv.org/pdf/1608.04337v1.pdf). It can be implemented with `tf.nn.depthwise_conv2d` and a 1x1 convolution, but it was not very effective. 57 | - Thanks to @harungunaydin's comment: **AdamOptimizer** gives much more stable training, and an option to use it has been added.
58 | -------------------------------------------------------------------------------- /PLOT.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pickle, glob 3 | import numpy as np 4 | import sys 5 | psnr_prefix = './psnr/*' 6 | psnr_paths = sorted(glob.glob(psnr_prefix)) 7 | 8 | psnr_means = {} 9 | 10 | def filter_by_scale(row, scale): 11 | return row[-1]==scale 12 | 13 | for i, psnr_path in enumerate(psnr_paths): 14 | print "" 15 | print psnr_path 16 | psnr_dict = None 17 | epoch = str(i)#psnr_path.split("_")[-1] 18 | with open(psnr_path, 'rb') as f: 19 | psnr_dict = pickle.load(f) 20 | dataset_keys = psnr_dict.keys() 21 | for j, key in enumerate(dataset_keys): 22 | print 'dataset', key 23 | psnr_list = psnr_dict[key] 24 | psnr_np = np.array(psnr_list) 25 | 26 | psnr_np_2 = psnr_np[np.array([filter_by_scale(row,2) for row in psnr_np])] 27 | psnr_np_3 = psnr_np[np.array([filter_by_scale(row,3) for row in psnr_np])] 28 | psnr_np_4 = psnr_np[np.array([filter_by_scale(row,4) for row in psnr_np])] 29 | print "x2:",np.mean(psnr_np_2, axis=0).tolist() 30 | print "x3:",np.mean(psnr_np_3, axis=0).tolist() 31 | print "x4:",np.mean(psnr_np_4, axis=0).tolist() 32 | 33 | mean_2 = np.mean(psnr_np_2, axis=0).tolist() 34 | mean_3 = np.mean(psnr_np_3, axis=0).tolist() 35 | mean_4 = np.mean(psnr_np_4, axis=0).tolist() 36 | psnr_mean = [mean_2, mean_3, mean_4] 37 | #print 'psnr mean', psnr_mean 38 | if psnr_means.has_key(key): 39 | psnr_means[key][epoch] = psnr_mean 40 | else: 41 | psnr_means[key] = {epoch: psnr_mean} 42 | 43 | #sys.exit(1) 44 | 45 | keys = psnr_means.keys() 46 | for i, key in enumerate(keys): 47 | psnr_dict = psnr_means[key] 48 | epochs = sorted(psnr_dict.keys(), key=int) # keys are str(i); sort numerically so '10' does not come before '2' 49 | x_axis = [] 50 | bicub_mean = [] 51 | vdsr_mean_2 = [] 52 | vdsr_mean_3 = [] 53 | vdsr_mean_4 = [] 54 | 55 | for epoch in epochs: 56 | print epoch 57 | print psnr_dict[epoch] 58 | x_axis.append(int(epoch)) 59 |
bicub_mean.append(psnr_dict[epoch][0][0]) 60 | vdsr_mean_2.append(psnr_dict[epoch][0][1]) 61 | vdsr_mean_3.append(psnr_dict[epoch][1][1]) 62 | vdsr_mean_4.append(psnr_dict[epoch][2][1]) 63 | plt.figure(i) 64 | print key 65 | print len(x_axis), len(bicub_mean), len(vdsr_mean_2) 66 | print vdsr_mean_2 67 | print "x2", np.argmax(vdsr_mean_2), np.max(vdsr_mean_2) 68 | print "x3", np.argmax(vdsr_mean_3), np.max(vdsr_mean_3) 69 | print "x4", np.argmax(vdsr_mean_4), np.max(vdsr_mean_4) 70 | lines_vdsr = plt.plot(vdsr_mean_2, 'g', vdsr_mean_3, 'y', vdsr_mean_4, 'b') 71 | plt.setp(lines_vdsr, linewidth=3.0) # widen every curve, not only those from the last plot() call 72 | plt.legend(lines_vdsr, ['x2', 'x3', 'x4']) 73 | plt.title(key) 74 | 75 | """ 76 | psnr_means : 77 | { 78 | 'DATASET_NAME' : 79 | { 80 | 'EPOCH' : [bicubic psnr, vdsr psnr] 81 | } 82 | 'DATASET_NAME_2': 83 | { 84 | 'EPOCH' : [bicubic psnr, vdsr psnr] 85 | } 86 | ... 87 | } 88 | """ 89 | for i, psnr_path in enumerate(psnr_paths): 90 | print i, psnr_path 91 | plt.show() 92 | -------------------------------------------------------------------------------- /TEST.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import misc 3 | from PIL import Image 4 | import tensorflow as tf 5 | import glob, os, re 6 | from PSNR import psnr 7 | import scipy.io 8 | import pickle 9 | from MODEL import model 10 | #from MODEL_FACTORIZED import model_factorized 11 | import time 12 | DATA_PATH = "./data/test/" 13 | 14 | import argparse 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model_path") 17 | args = parser.parse_args() 18 | model_path = args.model_path 19 | def get_img_list(data_path): 20 | l = glob.glob(os.path.join(data_path,"*")) 21 | l = [f for f in l if re.search("^\d+.mat$", os.path.basename(f))] 22 | train_list = [] 23 | for f in l: 24 | if os.path.exists(f): 25 | if os.path.exists(f[:-4]+"_2.mat"): train_list.append([f, f[:-4]+"_2.mat", 2]) 26 | if os.path.exists(f[:-4]+"_3.mat"): train_list.append([f, f[:-4]+"_3.mat",
3]) 27 | if os.path.exists(f[:-4]+"_4.mat"): train_list.append([f, f[:-4]+"_4.mat", 4]) 28 | return train_list 29 | def get_test_image(test_list, offset, batch_size): 30 | target_list = test_list[offset:offset+batch_size] 31 | input_list = [] 32 | gt_list = [] 33 | scale_list = [] 34 | for pair in target_list: 35 | print pair[1] 36 | mat_dict = scipy.io.loadmat(pair[1]) 37 | input_img = None 38 | if mat_dict.has_key("img_2"): input_img = mat_dict["img_2"] 39 | elif mat_dict.has_key("img_3"): input_img = mat_dict["img_3"] 40 | elif mat_dict.has_key("img_4"): input_img = mat_dict["img_4"] 41 | else: continue 42 | gt_img = scipy.io.loadmat(pair[0])['img_raw'] 43 | input_list.append(input_img) 44 | gt_list.append(gt_img) 45 | scale_list.append(pair[2]) 46 | return input_list, gt_list, scale_list 47 | def test_VDSR_with_sess(epoch, ckpt_path, data_path,sess): 48 | folder_list = glob.glob(os.path.join(data_path, 'Set*')) 49 | print 'folder_list', folder_list 50 | saver.restore(sess, ckpt_path) 51 | 52 | psnr_dict = {} 53 | for folder_path in folder_list: 54 | psnr_list = [] 55 | img_list = get_img_list(folder_path) 56 | for i in range(len(img_list)): 57 | input_list, gt_list, scale_list = get_test_image(img_list, i, 1) 58 | input_y = input_list[0] 59 | gt_y = gt_list[0] 60 | start_t = time.time() 61 | img_vdsr_y = sess.run([output_tensor], feed_dict={input_tensor: np.reshape(input_y, (1, input_y.shape[0], input_y.shape[1], 1))}) 62 | img_vdsr_y = np.reshape(img_vdsr_y, (input_y.shape[0], input_y.shape[1])) 63 | end_t = time.time() 64 | print "end_t",end_t,"start_t",start_t 65 | print "time consumption",end_t-start_t 66 | print "image_size", input_y.shape 67 | 68 | psnr_bicub = psnr(input_y, gt_y, scale_list[0]) 69 | psnr_vdsr = psnr(img_vdsr_y, gt_y, scale_list[0]) 70 | print "PSNR: bicubic %f\tVDSR %f" % (psnr_bicub, psnr_vdsr) 71 | psnr_list.append([psnr_bicub, psnr_vdsr, scale_list[0]]) 72 | psnr_dict[os.path.basename(folder_path)] = psnr_list 73 | if not os.path.isdir('psnr'): os.makedirs('psnr') # the psnr/ output directory is not part of the repository 74 | with open('psnr/%s' % os.path.basename(ckpt_path), 'wb') as f: 75 | pickle.dump(psnr_dict, f) 76 | def test_VDSR(epoch, ckpt_path, data_path): 77 | with tf.Session() as sess: 78 | test_VDSR_with_sess(epoch, ckpt_path, data_path, sess) 79 | if __name__ == '__main__': 80 | model_list = sorted(glob.glob("./checkpoints/VDSR_adam_epoch_*")) 81 | # strip the .meta/.index/.data-* suffixes so each checkpoint prefix is restored exactly once 82 | model_list = sorted(set(re.sub(r"\.(meta|index|data-\d+-of-\d+)$", "", fn) for fn in model_list)) 83 | with tf.Session() as sess: 84 | input_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1)) 85 | shared_model = tf.make_template('shared_model', model) 86 | output_tensor, weights = shared_model(input_tensor) 87 | #output_tensor, weights = model(input_tensor) 88 | saver = tf.train.Saver(weights) 89 | tf.initialize_all_variables().run() 90 | for model_ckpt in model_list: 91 | print model_ckpt 92 | epoch = int(model_ckpt.split('epoch_')[-1].split('.ckpt')[0]) 93 | #if epoch<60: 94 | # continue 95 | print "Testing model",model_ckpt 96 | test_VDSR_with_sess(epoch, model_ckpt, DATA_PATH, sess) 97 | -------------------------------------------------------------------------------- /data/aug_train.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | dataDir = '291';%fullfile('data', '291'); 6 | mkdir('train'); 7 | count = 0; 8 | f_lst = []; 9 | f_lst = [f_lst; dir(fullfile(dataDir, '*.jpg'))]; 10 | f_lst = [f_lst; dir(fullfile(dataDir, '*.bmp'))]; 11 | for f_iter = 1:numel(f_lst) 12 | % disp(f_iter); 13 | f_info = f_lst(f_iter); 14 | if f_info.name == '.'
15 | continue; 16 | end 17 | f_path = fullfile(dataDir,f_info.name); 18 | img_raw = imread(f_path); 19 | img_raw = rgb2ycbcr(img_raw); 20 | img_raw = im2double(img_raw(:,:,1)); 21 | 22 | img_size = size(img_raw); 23 | width = img_size(2); 24 | height = img_size(1); 25 | 26 | img_raw = img_raw(1:height-mod(height,12),1:width-mod(width,12),:); 27 | 28 | img_size = size(img_raw); 29 | patch_size = 41; 30 | stride = 41; 31 | x_size = (img_size(2)-patch_size)/stride+1; 32 | y_size = (img_size(1)-patch_size)/stride+1; 33 | 34 | img_2 = imresize(imresize(img_raw,1/2,'bicubic'),[img_size(1),img_size(2)],'bicubic'); 35 | img_3 = imresize(imresize(img_raw,1/3,'bicubic'),[img_size(1),img_size(2)],'bicubic'); 36 | img_4 = imresize(imresize(img_raw,1/4,'bicubic'),[img_size(1),img_size(2)],'bicubic'); 37 | 38 | for x = 0:x_size-1 39 | for y = 0:y_size-1 40 | x_coord = x*stride; y_coord = y*stride; 41 | patch_name = sprintf('train/%d',count); 42 | 43 | patch = imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0); 44 | save(patch_name, 'patch'); 45 | patch = imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0); 46 | save(sprintf('%s_2', patch_name), 'patch'); 47 | patch = imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0); 48 | save(sprintf('%s_3', patch_name), 'patch'); 49 | patch = imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0); 50 | save(sprintf('%s_4', patch_name), 'patch'); 51 | 52 | count = count+1; 53 | 54 | patch_name = sprintf('train/%d',count); 55 | 56 | patch = imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90); 57 | save(patch_name, 'patch'); 58 | patch = imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90); 59 | save(sprintf('%s_2', patch_name), 'patch'); 60 | patch = imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90); 61 | save(sprintf('%s_3', patch_name), 
'patch'); 62 | patch = imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90); 63 | save(sprintf('%s_4', patch_name), 'patch'); 64 | 65 | count = count+1; 66 | 67 | patch_name = sprintf('train/%d',count); 68 | 69 | patch = fliplr(imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0)); 70 | save(patch_name, 'patch'); 71 | patch = fliplr(imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0)); 72 | save(sprintf('%s_2', patch_name), 'patch'); 73 | patch = fliplr(imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0)); 74 | save(sprintf('%s_3', patch_name), 'patch'); 75 | patch = fliplr(imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 0)); 76 | save(sprintf('%s_4', patch_name), 'patch'); 77 | 78 | count = count+1; 79 | 80 | patch_name = sprintf('train/%d',count); 81 | 82 | patch = fliplr(imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90)); 83 | save(patch_name, 'patch'); 84 | patch = fliplr(imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90)); 85 | save(sprintf('%s_2', patch_name), 'patch'); 86 | patch = fliplr(imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90)); 87 | save(sprintf('%s_3', patch_name), 'patch'); 88 | patch = fliplr(imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 90)); 89 | save(sprintf('%s_4', patch_name), 'patch'); 90 | 91 | count = count+1; 92 | 93 | 94 | %{ 95 | patch_name = sprintf('aug/%d',count); 96 | 97 | patch = imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180); 98 | save(patch_name, 'patch'); 99 | patch = imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180); 100 | save(sprintf('%s_2', patch_name), 'patch'); 101 | patch = imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180); 102 | save(sprintf('%s_3', 
patch_name), 'patch'); 103 | patch = imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180); 104 | save(sprintf('%s_4', patch_name), 'patch'); 105 | 106 | count = count+1; 107 | 108 | patch_name = sprintf('aug/%d',count); 109 | 110 | patch = fliplr(imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 111 | save(patch_name, 'patch'); 112 | patch = fliplr(imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 113 | save(sprintf('%s_2', patch_name), 'patch'); 114 | patch = fliplr(imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 115 | save(sprintf('%s_3', patch_name), 'patch'); 116 | patch = fliplr(imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 117 | save(sprintf('%s_4', patch_name), 'patch'); 118 | 119 | count = count+1; 120 | 121 | patch_name = sprintf('aug/%d',count); 122 | 123 | patch = imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 270); 124 | save(patch_name, 'patch'); 125 | patch = imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 270); 126 | save(sprintf('%s_2', patch_name), 'patch'); 127 | patch = imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 270); 128 | save(sprintf('%s_3', patch_name), 'patch'); 129 | patch = imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 270); 130 | save(sprintf('%s_4', patch_name), 'patch'); 131 | 132 | count = count+1; 133 | 134 | patch_name = sprintf('aug/%d',count); 135 | 136 | patch = fliplr(imrotate(img_raw(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 137 | save(patch_name, 'patch'); 138 | patch = fliplr(imrotate(img_2(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 139 | save(sprintf('%s_2', patch_name), 'patch'); 140 | patch = fliplr(imrotate(img_3(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 
180)); 141 | save(sprintf('%s_3', patch_name), 'patch'); 142 | patch = fliplr(imrotate(img_4(y_coord+1:y_coord+patch_size,x_coord+1:x_coord+patch_size,:), 180)); 143 | save(sprintf('%s_4', patch_name), 'patch'); 144 | 145 | count = count+1; 146 | %} 147 | end 148 | end 149 | 150 | display(count); 151 | 152 | 153 | end 154 | -------------------------------------------------------------------------------- /VDSR.py: -------------------------------------------------------------------------------- 1 | import os, glob, re, signal, sys, argparse, threading, time 2 | from random import shuffle 3 | import random 4 | import tensorflow as tf 5 | from PIL import Image 6 | import numpy as np 7 | import scipy.io 8 | from MODEL import model 9 | from PSNR import psnr 10 | from TEST import test_VDSR 11 | #from MODEL_FACTORIZED import model_factorized 12 | DATA_PATH = "./data/train/" 13 | IMG_SIZE = (41, 41) 14 | BATCH_SIZE = 64 15 | BASE_LR = 0.0001 16 | LR_RATE = 0.1 17 | LR_STEP_SIZE = 120 18 | MAX_EPOCH = 120 19 | 20 | USE_QUEUE_LOADING = True 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model_path") 24 | args = parser.parse_args() 25 | model_path = args.model_path 26 | 27 | TEST_DATA_PATH = "./data/test/" 28 | 29 | def get_train_list(data_path): 30 | l = glob.glob(os.path.join(data_path,"*")) 31 | print len(l) 32 | l = [f for f in l if re.search("^\d+.mat$", os.path.basename(f))] 33 | print len(l) 34 | train_list = [] 35 | for f in l: 36 | if os.path.exists(f): 37 | if os.path.exists(f[:-4]+"_2.mat"): train_list.append([f, f[:-4]+"_2.mat"]) 38 | if os.path.exists(f[:-4]+"_3.mat"): train_list.append([f, f[:-4]+"_3.mat"]) 39 | if os.path.exists(f[:-4]+"_4.mat"): train_list.append([f, f[:-4]+"_4.mat"]) 40 | return train_list 41 | 42 | def get_image_batch(train_list,offset,batch_size): 43 | target_list = train_list[offset:offset+batch_size] 44 | input_list = [] 45 | gt_list = [] 46 | cbcr_list = [] 47 | for pair in target_list: 48 | input_img = 
scipy.io.loadmat(pair[1])['patch'] 49 | gt_img = scipy.io.loadmat(pair[0])['patch'] 50 | input_list.append(input_img) 51 | gt_list.append(gt_img) 52 | input_list = np.array(input_list) 53 | input_list.resize([BATCH_SIZE, IMG_SIZE[1], IMG_SIZE[0], 1]) 54 | gt_list = np.array(gt_list) 55 | gt_list.resize([BATCH_SIZE, IMG_SIZE[1], IMG_SIZE[0], 1]) 56 | return input_list, gt_list, np.array(cbcr_list) 57 | 58 | def get_test_image(test_list, offset, batch_size): 59 | target_list = test_list[offset:offset+batch_size] 60 | input_list = [] 61 | gt_list = [] 62 | for pair in target_list: 63 | mat_dict = scipy.io.loadmat(pair[1]) 64 | input_img = None 65 | if mat_dict.has_key("img_2"): input_img = mat_dict["img_2"] 66 | elif mat_dict.has_key("img_3"): input_img = mat_dict["img_3"] 67 | elif mat_dict.has_key("img_4"): input_img = mat_dict["img_4"] 68 | else: continue 69 | gt_img = scipy.io.loadmat(pair[0])['img_raw'] 70 | input_list.append(input_img if input_img.ndim == 2 else input_img[:,:,0]) # aug_test.m stores 2-D (Y-channel) images, so only slice when a channel axis exists 71 | gt_list.append(gt_img if gt_img.ndim == 2 else gt_img[:,:,0]) 72 | return input_list, gt_list 73 | 74 | if __name__ == '__main__': 75 | train_list = get_train_list(DATA_PATH) 76 | 77 | if not USE_QUEUE_LOADING: 78 | print "not use queue loading, just sequential loading..."
79 | 80 | 81 | ### WITHOUT ASYNCHRONOUS DATA LOADING ### 82 | 83 | train_input = tf.placeholder(tf.float32, shape=(BATCH_SIZE, IMG_SIZE[0], IMG_SIZE[1], 1)) 84 | train_gt = tf.placeholder(tf.float32, shape=(BATCH_SIZE, IMG_SIZE[0], IMG_SIZE[1], 1)) 85 | 86 | ### WITHOUT ASYNCHRONOUS DATA LOADING ### 87 | 88 | else: 89 | print "use queue loading" 90 | 91 | 92 | ### WITH ASYNCHRONOUS DATA LOADING ### 93 | 94 | train_input_single = tf.placeholder(tf.float32, shape=(IMG_SIZE[0], IMG_SIZE[1], 1)) 95 | train_gt_single = tf.placeholder(tf.float32, shape=(IMG_SIZE[0], IMG_SIZE[1], 1)) 96 | q = tf.FIFOQueue(10000, [tf.float32, tf.float32], [[IMG_SIZE[0], IMG_SIZE[1], 1], [IMG_SIZE[0], IMG_SIZE[1], 1]]) 97 | enqueue_op = q.enqueue([train_input_single, train_gt_single]) 98 | 99 | train_input, train_gt = q.dequeue_many(BATCH_SIZE) 100 | 101 | ### WITH ASYNCHRONOUS DATA LOADING ### 102 | 103 | 104 | shared_model = tf.make_template('shared_model', model) 105 | #train_output, weights = model(train_input) 106 | train_output, weights = shared_model(train_input) 107 | loss = tf.reduce_sum(tf.nn.l2_loss(tf.subtract(train_output, train_gt))) 108 | for w in weights: 109 | loss += tf.nn.l2_loss(w)*1e-4 110 | tf.summary.scalar("loss", loss) 111 | 112 | global_step = tf.Variable(0, trainable=False) 113 | learning_rate = tf.train.exponential_decay(BASE_LR, global_step*BATCH_SIZE, len(train_list)*LR_STEP_SIZE, LR_RATE, staircase=True) 114 | tf.summary.scalar("learning rate", learning_rate) 115 | 116 | optimizer = tf.train.AdamOptimizer(learning_rate)#tf.train.MomentumOptimizer(learning_rate, 0.9) 117 | opt = optimizer.minimize(loss, global_step=global_step) 118 | 119 | saver = tf.train.Saver(weights, max_to_keep=0) 120 | 121 | shuffle(train_list) 122 | config = tf.ConfigProto() 123 | #config.operation_timeout_in_ms=10000 124 | 125 | with tf.Session(config=config) as sess: 126 | #TensorBoard open log with "tensorboard --logdir=logs" 127 | if not os.path.exists('logs'): 128 | os.mkdir('logs') 
129 | merged = tf.summary.merge_all() 130 | file_writer = tf.summary.FileWriter('logs', sess.graph) 131 | 132 | tf.initialize_all_variables().run() 133 | 134 | if model_path: 135 | print "restore model..." 136 | saver.restore(sess, model_path) 137 | print "Done" 138 | 139 | ### WITH ASYNCHRONOUS DATA LOADING ### 140 | def load_and_enqueue(coord, file_list, enqueue_op, train_input_single, train_gt_single, idx=0, num_thread=1): 141 | count = 0; 142 | length = len(file_list) 143 | try: 144 | while not coord.should_stop(): 145 | i = count % length; 146 | input_img = scipy.io.loadmat(file_list[i][1])['patch'].reshape([IMG_SIZE[0], IMG_SIZE[1], 1]) 147 | gt_img = scipy.io.loadmat(file_list[i][0])['patch'].reshape([IMG_SIZE[0], IMG_SIZE[1], 1]) 148 | sess.run(enqueue_op, feed_dict={train_input_single:input_img, train_gt_single:gt_img}) 149 | count+=1 150 | except Exception as e: 151 | print "stopping...", idx, e 152 | ### WITH ASYNCHRONOUS DATA LOADING ### 153 | threads = [] 154 | def signal_handler(signum,frame): 155 | sess.run(q.close(cancel_pending_enqueues=True)) 156 | coord.request_stop() 157 | coord.join(threads) 158 | print "Done" 159 | sys.exit(1) 160 | original_sigint = signal.getsignal(signal.SIGINT) 161 | signal.signal(signal.SIGINT, signal_handler) 162 | 163 | if USE_QUEUE_LOADING: 164 | # create threads 165 | num_thread=20 166 | coord = tf.train.Coordinator() 167 | for i in range(num_thread): 168 | length = len(train_list)/num_thread 169 | t = threading.Thread(target=load_and_enqueue, args=(coord, train_list[i*length:(i+1)*length],enqueue_op, train_input_single, train_gt_single, i, num_thread)) 170 | threads.append(t) 171 | t.start() 172 | print "num thread:" , len(threads) 173 | 174 | for epoch in xrange(0, MAX_EPOCH): 175 | max_step=len(train_list)//BATCH_SIZE 176 | for step in range(max_step): 177 | _,l,output,lr, g_step, summary = sess.run([opt, loss, train_output, learning_rate, global_step, merged]) 178 | print "[epoch %2.4f] loss %.4f\t lr 
%.5f"%(epoch+(float(step)*BATCH_SIZE/len(train_list)), np.sum(l)/BATCH_SIZE, lr) 179 | file_writer.add_summary(summary, step+epoch*max_step) 180 | #print "[epoch %2.4f] loss %.4f\t lr %.5f\t norm %.2f"%(epoch+(float(step)*BATCH_SIZE/len(train_list)), np.sum(l)/BATCH_SIZE, lr, norm) 181 | saver.save(sess, "./checkpoints/VDSR_adam_epoch_%03d.ckpt" % epoch ,global_step=global_step) 182 | # training finished: stop the loader threads (they loop until coord.should_stop() and are not daemons, so they would keep the process alive) 183 | coord.request_stop() 184 | sess.run(q.close(cancel_pending_enqueues=True)) 185 | coord.join(threads) 186 | else: 187 | for epoch in xrange(0, MAX_EPOCH): 188 | for step in range(len(train_list)//BATCH_SIZE): 189 | offset = step*BATCH_SIZE 190 | input_data, gt_data, cbcr_data = get_image_batch(train_list, offset, BATCH_SIZE) 191 | feed_dict = {train_input: input_data, train_gt: gt_data} 192 | _,l,output,lr, g_step = sess.run([opt, loss, train_output, learning_rate, global_step], feed_dict=feed_dict) 193 | print "[epoch %2.4f] loss %.4f\t lr %.5f"%(epoch+(float(step)*BATCH_SIZE/len(train_list)), np.sum(l)/BATCH_SIZE, lr) 194 | del input_data, gt_data, cbcr_data 195 | 196 | # use the same checkpoint prefix as the queue path so TEST.py ("VDSR_adam_epoch_*") finds these too 197 | saver.save(sess, "./checkpoints/VDSR_adam_epoch_%03d.ckpt" % epoch ,global_step=global_step) 198 | 199 | --------------------------------------------------------------------------------