├── README.md ├── READYME_PNG ├── LMVE_qpx4.png ├── QE-subnet.png ├── compareWithMFQE.png ├── results_of_each_frame.png ├── 主观图排版_LMVE.png └── 主观图排版_wraped.png ├── experiment_tjc.xlsx ├── test ├── test_MF │ ├── LMVE_model.py │ ├── LMVE_test.py │ ├── LSVE_model.py │ ├── UTILS_MF_ra.py │ ├── __pycache__ │ │ ├── MFVDSingle_ra.cpython-35.pyc │ │ ├── UTILS_MF_ra.cpython-35.pyc │ │ └── yangNet.cpython-35.pyc │ └── checkpoints │ │ ├── MF_qp22_ra_01222_LMVE_double │ │ ├── MF_qp22_ra_01222_LMVE_double_149.ckpt.data-00000-of-00001 │ │ ├── MF_qp22_ra_01222_LMVE_double_149.ckpt.index │ │ ├── MF_qp22_ra_01222_LMVE_double_149.ckpt.meta │ │ └── checkpoint │ │ ├── MF_qp27_ra_01202_LMVE_double │ │ ├── MF_qp27_ra_01202_LMVE_double_201.ckpt.data-00000-of-00001 │ │ ├── MF_qp27_ra_01202_LMVE_double_201.ckpt.index │ │ ├── MF_qp27_ra_01202_LMVE_double_201.ckpt.meta │ │ └── checkpoint │ │ ├── MF_qp32_ra_01262_LMVE_double │ │ ├── MF_qp32_ra_01262_LMVE_double_138.ckpt.data-00000-of-00001 │ │ ├── MF_qp32_ra_01262_LMVE_double_138.ckpt.index │ │ ├── MF_qp32_ra_01262_LMVE_double_138.ckpt.meta │ │ └── checkpoint │ │ └── MF_qp37_ra_01121 │ │ ├── MF_qp37_ra_01121_402.ckpt.data-00000-of-00001 │ │ ├── MF_qp37_ra_01121_402.ckpt.index │ │ ├── MF_qp37_ra_01121_402.ckpt.meta │ │ ├── a.txt │ │ └── checkpoint └── test_Single │ ├── LSVE_ra_model.py │ ├── LSVE_test.py │ ├── UTILS_MF_ra_ALL_Single.py │ ├── __pycache__ │ ├── MFVDSingle_ra.cpython-35.pyc │ └── UTILS_MF_ra_ALL_Single.cpython-35.pyc │ ├── checkpoints │ ├── SF_qp22_ra_01251_LSVE │ │ ├── SF_qp22_ra_01251_LSVE_105.ckpt.data-00000-of-00001 │ │ ├── SF_qp22_ra_01251_LSVE_105.ckpt.index │ │ ├── SF_qp22_ra_01251_LSVE_105.ckpt.meta │ │ └── checkpoint │ ├── SF_qp27_ra_01232_LSVE │ │ ├── SF_qp27_ra_01232_LSVE_169.ckpt.data-00000-of-00001 │ │ ├── SF_qp27_ra_01232_LSVE_169.ckpt.index │ │ ├── SF_qp27_ra_01232_LSVE_169.ckpt.meta │ │ └── checkpoint │ ├── SF_qp32_ra_01242_LSVE │ │ ├── SF_qp32_ra_01242_LSVE_185.ckpt.data-00000-of-00001 │ │ ├── 
SF_qp32_ra_01242_LSVE_185.ckpt.index │ │ ├── SF_qp32_ra_01242_LSVE_185.ckpt.meta │ │ └── checkpoint │ └── SF_qp37_ra_01174_LSVE │ │ ├── MF_qp37_ra_01174_LSVE_291.ckpt.data-00000-of-00001 │ │ ├── MF_qp37_ra_01174_LSVE_291.ckpt.index │ │ ├── MF_qp37_ra_01174_LSVE_291.ckpt.meta │ │ └── checkpoint │ └── outdata │ └── SF_qp27_ra_01232_LSVE │ ├── events.out.tfevents.1549718071.AIR-PC │ └── events.out.tfevents.1549718367.AIR-PC └── train ├── train_MF ├── LMVE_ra_model.py ├── LMVE_ra_train.py └── UTILS_MF_ra.py └── train_Single ├── LSVE_Single_ra_train.py ├── LSVE_ra_model.py └── UTILS_single_ra.py /README.md: -------------------------------------------------------------------------------- 1 | ## LEARNING-BASED MULTI-FRAME VIDEO QUALITY ENHANCEMENT 2 | Junchao Tong*, Xilin Wu*, Dandan Ding*, Zheng Zhu**, Zoe Liu**
3 | \* Hangzhou Normal University
4 | ** Visionular Inc.
5 | 6 | ___ 7 | * ### Example of the subjective quality of HEVC-compressed video (1st and 2nd columns) and of the preprocessed frames (MF, HF). 8 | `Here, MF denotes a Moderate-quality Frame and HF a High-quality Frame.` 9 | 10 | ![](https://github.com/IVC-Projects/LMVE/blob/master/READYME_PNG/主观图排版_wraped.png) 11 | 12 | ___ 13 | * ### Subjective quality on *FourPeople* and *BasketballPass* at QP=37 in the RA configuration. 14 | ![](https://github.com/IVC-Projects/LMVE/blob/master/READYME_PNG/主观图排版_LMVE.png)
15 | 16 | ___ 17 | ## Experiments 18 | ### 1. PSNR (dB) of the proposed LMVE compared with LSVE and HEVC. 19 | 20 | ![](https://github.com/IVC-Projects/LMVE/blob/master/READYME_PNG/LMVE_qpx4.png) 21 |
22 | 23 | ### 2. Performance of LMVE compared with MFQE at QP=37. 24 | 25 | ![](https://github.com/IVC-Projects/LMVE/blob/master/READYME_PNG/compareWithMFQE.png) 26 |
27 |
28 | * The architecture of the QE-subnet is shown in Figure 6, and the details of its convolutional layers are listed in Table 3. 29 | ![](https://github.com/IVC-Projects/LMVE/blob/master/READYME_PNG/QE-subnet.png) 30 | 31 | ### 3. Per-frame PSNR (dB) over the first 21 frames of the FourPeople and PeopleOnStreet sequences. 32 | ![](https://github.com/IVC-Projects/LMVE/blob/master/READYME_PNG/results_of_each_frame.png) 33 | * `For more details on the experiments, download the "experiment_tjc.xlsx" file from this repository.` 34 | 35 | ___ 36 | The source code of FlowNet 2.0 is available at: https://github.com/lmb-freiburg/flownet2 37 |
38 | The source code of MFQE is available at: https://github.com/ryangBUAA/MFQE 39 | ___ 40 | Thanks for reading! 41 |
42 | Best wishes,
43 | Junchao Tong. 44 | -------------------------------------------------------------------------------- /READYME_PNG/LMVE_qpx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/LMVE_qpx4.png -------------------------------------------------------------------------------- /READYME_PNG/QE-subnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/QE-subnet.png -------------------------------------------------------------------------------- /READYME_PNG/compareWithMFQE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/compareWithMFQE.png -------------------------------------------------------------------------------- /READYME_PNG/results_of_each_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/results_of_each_frame.png -------------------------------------------------------------------------------- /READYME_PNG/主观图排版_LMVE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/主观图排版_LMVE.png -------------------------------------------------------------------------------- /READYME_PNG/主观图排版_wraped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/主观图排版_wraped.png 
-------------------------------------------------------------------------------- /experiment_tjc.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/experiment_tjc.xlsx -------------------------------------------------------------------------------- /test/test_MF/LMVE_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def model_double(inputHigh1Data_tensor, inputLowData_tensor, inputHigh2Data_tensor): 6 | # with tf.device("/gpu:0"): 7 | input_before = inputHigh1Data_tensor # highData 8 | input_current = inputLowData_tensor # lowData 9 | 10 | input_after = inputHigh2Data_tensor 11 | 12 | # due to don't have training_Set at right now, so let it be annotation. 13 | tensor = None 14 | 15 | # ----------------------------------------------Frame -1-------------------------------------------------------- 16 | input_before_w = tf.get_variable("input_before_w", [5, 5, 1, 64], 17 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25))) 18 | input_before_b = tf.get_variable("input_before_b", [64], initializer=tf.constant_initializer(0)) 19 | input_high1_tensor = tf.nn.relu( 20 | tf.nn.bias_add(tf.nn.conv2d(input_before, input_before_w, strides=[1, 1, 1, 1], padding='SAME'), input_before_b)) 21 | # ----------------------------------------------Frame 0-------------------------------------------------------- 22 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64], 23 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25))) 24 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0)) 25 | input_low_tensor = tf.nn.relu( 26 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'), 27 | input_current_b)) 28 | 29 | # 
----------------------------------------------Frame 1-------------------------------------------------------- 30 | input_after_w = tf.get_variable("input_after_w", [5, 5, 1, 64], 31 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25))) 32 | input_after_b = tf.get_variable("input_after_b", [64], initializer=tf.constant_initializer(0)) 33 | input_high2_tensor = tf.nn.relu( 34 | tf.nn.bias_add(tf.nn.conv2d(input_after, input_after_w, strides=[1, 1, 1, 1], padding='SAME'), 35 | input_after_b)) 36 | # ------------------------------------------Frame -1\0\1 concat------------------------------------------ 37 | input_tensor_Concat = tf.concat([input_high1_tensor, input_low_tensor, input_high2_tensor], axis=3) 38 | # ----------------------------------1x1 conv, for reduce number of parameters---------------- 39 | input_1x1_w = tf.get_variable("input_1x1_w", [1, 1, 192, 64], 40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 192))) 41 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0)) 42 | input_1x1_tensor = tf.nn.relu( 43 | tf.nn.bias_add(tf.nn.conv2d(input_tensor_Concat, input_1x1_w, strides=[1, 1, 1, 1], padding='SAME'), 44 | input_1x1_b)) 45 | tensor = input_1x1_tensor 46 | 47 | # --------------------------------------start iteration for last layers---------------------------------- 48 | convId = 0 49 | for i in range(18): 50 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64], 51 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 52 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 53 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0)) 54 | convId += 1 55 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)) 56 | 57 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1], 58 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 
64))) 59 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0)) 60 | convId += 1 61 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 62 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b) 63 | 64 | tensor = tf.add(tensor, input_current) 65 | return tensor 66 | -------------------------------------------------------------------------------- /test/test_MF/LMVE_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import itertools 3 | from LMVE_model import model_double 4 | from LSVE_model import model_single 5 | from UTILS_MF_ra import * 6 | 7 | tf.logging.set_verbosity(tf.logging.WARN) 8 | 9 | EXP_DATA1 = "MF_qp22_ra_01222_LMVE_double" # double checkpoints 10 | EXP_DATA2 = "MF_qp22_ra_01222_LMVE_double" # single checkpoints not be used 11 | MODEL_DOUBLE_PATH1 = "./checkpoints/%s/"%(EXP_DATA1) 12 | MODEL_DOUBLE_PATH2 = "./checkpoints/%s/"%(EXP_DATA2) 13 | HIGHDATA_Parent_PATH = r"E:\MF\HEVC_TestSequenct\qp37_ldp\warped_yuv\FourPeople_qp22_ai_1280x720" # warped high frames 14 | QP_LOWDATA_PATH = r'E:\MF\HEVC_TestSequenct\rec\qp37\ldp\E\FourPeople_qp37_ldp_1280x720' # low frames 15 | GT_PATH = r"E:\MF\HEVC_TestSequenct\org\E\FourPeople_1280x720" # ground truth 16 | DL_path = 'E:\MF\HEVC_TestSequenct\DL_rec\double\qp37' 17 | OUT_DATA_PATH = "./outdata/%s/"%(EXP_DATA1) 18 | 19 | 20 | # Ground truth images dir should be the 2nd component of 'fileOrDir' if 2 components are given. 21 | 22 | ##cb, cr components are not implemented 23 | def prepare_test_data(fileOrDir): 24 | doubleData_ycbcr = [] 25 | doubleGT_y = [] 26 | singleData_ycbcr = [] 27 | singleGT_y = [] 28 | fileName_list = [] 29 | #The input is a single file. 30 | if len(fileOrDir) == 3: 31 | # return the whole absolute path. 
32 | fileName_list = load_file_list(fileOrDir[1]) 33 | # double_list # [[high, low1, label1], [[h21,h22], low2, label2]] 34 | # single_list # [[low1, lable1], [2,2] ....] 35 | double_list, single_list = get_test_list(HIGHDATA_Parent_PATH, load_file_list(fileOrDir[1]), 36 | load_file_list(fileOrDir[2])) 37 | 38 | for pair in double_list: 39 | high1Data_List = [] 40 | lowData_List = [] 41 | high2Data_List = [] 42 | 43 | high1Data_imgY = c_getYdata(pair[0][0]) 44 | high2Data_imgY = c_getYdata(pair[0][1]) 45 | lowData_imgY = c_getYdata(pair[1]) 46 | CbCr = c_getCbCr(pair[1]) 47 | gt_imgY = c_getYdata(pair[2]) 48 | 49 | #normalize 50 | high1Data_imgY = normalize(high1Data_imgY) 51 | lowData_imgY = normalize(lowData_imgY) 52 | high2Data_imgY = normalize(high2Data_imgY) 53 | 54 | high1Data_imgY = np.resize(high1Data_imgY, (1, high1Data_imgY.shape[0], high1Data_imgY.shape[1],1)) 55 | lowData_imgY = np.resize(lowData_imgY, (1, lowData_imgY.shape[0], lowData_imgY.shape[1], 1)) 56 | high2Data_imgY = np.resize(high2Data_imgY, (1, high2Data_imgY.shape[0], high2Data_imgY.shape[1], 1)) 57 | gt_imgY = np.resize(gt_imgY, (1, gt_imgY.shape[0], gt_imgY.shape[1],1)) 58 | 59 | ## act as a placeholder 60 | 61 | high1Data_List.append([high1Data_imgY, 0]) 62 | lowData_List.append([lowData_imgY, CbCr]) 63 | high2Data_List.append([high2Data_imgY, 0]) 64 | doubleData_ycbcr.append([high1Data_List, lowData_List, high2Data_List]) 65 | doubleGT_y.append(gt_imgY) 66 | 67 | # single_list # [[low1, lable1], [2,2] ....] 
68 | for pair in single_list: 69 | lowData_list = [] 70 | lowData_imgY = c_getYdata(pair[0]) 71 | CbCr = c_getCbCr(pair[0]) 72 | gt_imgY = c_getYdata(pair[1]) 73 | 74 | # normalize 75 | lowData_imgY = normalize(lowData_imgY) 76 | 77 | lowData_imgY = np.resize(lowData_imgY, (1, lowData_imgY.shape[0], lowData_imgY.shape[1], 1)) 78 | gt_imgY = np.resize(gt_imgY, (1, gt_imgY.shape[0], gt_imgY.shape[1], 1)) 79 | 80 | lowData_list.append([lowData_imgY, CbCr]) 81 | singleData_ycbcr.append(lowData_list) 82 | singleGT_y.append(gt_imgY) 83 | 84 | else: 85 | print("Invalid Inputs...!tjc!") 86 | exit(0) 87 | 88 | return doubleData_ycbcr, doubleGT_y, singleData_ycbcr, singleGT_y, fileName_list 89 | 90 | def test_all_ckpt(modelPath1, modelPath2, fileOrDir): 91 | max = [0, 0] 92 | tem1 = [f for f in os.listdir(modelPath1) if 'data' in f] 93 | ckptFiles1 = sorted([r.split('.data')[0] for r in tem1]) 94 | tem2 = [f for f in os.listdir(modelPath2) if 'data' in f] 95 | ckptFiles2 = sorted([r.split('.data')[0] for r in tem2]) 96 | re_psnr = tf.placeholder('float32') 97 | tf.summary.scalar('re_psnr', re_psnr) 98 | 99 | doubleData_ycbcr, doubleGT_y, singleData_ycbcr, singleGT_y, fileName_list = prepare_test_data(fileOrDir) 100 | total_time, total_psnr = 0, 0 101 | total_imgs = len(fileName_list) 102 | count = 0 103 | for i in range(total_imgs): 104 | if i % 4 != 0: 105 | count += 1 106 | # sorry! this place write so difficult!【[[[h1,0]],[[low,0]],[[h2, 0]]], [[[h1,0]],[[low,0]],[[h2, 0]]]】 107 | j = i - (i//4) - 1 108 | imgHigh1DataY = doubleData_ycbcr[j][0][0][0] 109 | imgLowDataY = doubleData_ycbcr[j][1][0][0] 110 | imgLowCbCr = doubleData_ycbcr[j][1][0][1] 111 | imgHigh2DataY = doubleData_ycbcr[j][2][0][0] 112 | gtY = doubleGT_y[j] if doubleGT_y else 0 113 | start_t = time.time() 114 | for ckpt1 in ckptFiles1: 115 | epoch = int(ckpt1.split('_')[-1].split('.')[0]) 116 | if epoch != 149: 117 | continue 118 | # very important!!!!!!! 
119 | tf.reset_default_graph() 120 | # Double section 121 | high1Data_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1)) 122 | lowData_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1)) 123 | high2Data_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1)) 124 | shared_model1 = tf.make_template('shared_model', model_double) 125 | output_tensor1 = shared_model1(high1Data_tensor, lowData_tensor, high2Data_tensor) 126 | # output_tensor = shared_model(input_tensor) 127 | output_tensor1 = tf.clip_by_value(output_tensor1, 0., 1.) 128 | output_tensor1 = output_tensor1 * 255 129 | with tf.Session() as sess: 130 | saver = tf.train.Saver(tf.global_variables()) 131 | sess.run(tf.global_variables_initializer()) 132 | 133 | saver.restore(sess, os.path.join(modelPath1, ckpt1)) 134 | out = sess.run(output_tensor1, feed_dict={high1Data_tensor: imgHigh1DataY, 135 | lowData_tensor:imgLowDataY, high2Data_tensor: imgHigh2DataY}) 136 | hevc = psnr(imgLowDataY * 255.0, gtY) 137 | out = np.around(out) 138 | out = out.astype('int') 139 | out = np.reshape(out, [1, out.shape[1], out.shape[2], 1]) 140 | 141 | Y = np.reshape(out, [out.shape[1], out.shape[2]]) 142 | Y = np.array(list(itertools.chain.from_iterable(Y))) 143 | U = imgLowCbCr[0] 144 | V = imgLowCbCr[1] 145 | creatPath = os.path.join(DL_path, fileName_list[i].split('\\')[-2]) 146 | if not os.path.exists(creatPath): 147 | os.mkdir(creatPath) 148 | 149 | if doubleGT_y: 150 | p = psnr(out, gtY) 151 | 152 | path = os.path.join(DL_path, 153 | fileName_list[i].split('\\')[-2], 154 | fileName_list[i].split('\\')[-1].split('.')[0]) + '_%.4f' % (p-hevc)+ '.yuv' 155 | 156 | YUV = np.concatenate((Y, U, V)) 157 | YUV = YUV.astype('uint8') 158 | YUV.tofile(path) 159 | 160 | total_psnr += p 161 | print("qp37\tepoch:%d\t%s\t%.4f\n" % (epoch, fileName_list[i], p)) 162 | 163 | duration_t = time.time() - start_t 164 | total_time += duration_t 165 | else: #single frame 166 | continue 167 | count += 1 168 | j = i // 4 169 
| lowDataY = singleData_ycbcr[j][0][0] 170 | imgLowCbCr = singleData_ycbcr[j][0][1] 171 | gtY = singleGT_y[j] if singleGT_y else 0 172 | hevc = psnr(lowDataY * 255.0, gtY) 173 | start_t = time.time() 174 | for ckpt2 in ckptFiles2: 175 | epoch = int(ckpt2.split('_')[-1].split('.')[0]) 176 | if epoch != 169: 177 | continue 178 | 179 | tf.reset_default_graph() 180 | # Single section 181 | lowSingleData_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1)) 182 | shared_model2 = tf.make_template('shared_model', model_single) 183 | output_tensor2 = shared_model2(lowSingleData_tensor) 184 | output_tensor2 = tf.clip_by_value(output_tensor2, 0., 1.) 185 | output_tensor2 = output_tensor2 * 255 186 | with tf.Session() as sess: 187 | saver = tf.train.Saver(tf.global_variables()) 188 | sess.run(tf.global_variables_initializer()) 189 | 190 | saver.restore(sess, os.path.join(modelPath2, ckpt2)) 191 | out = sess.run(output_tensor2, feed_dict={lowSingleData_tensor: lowDataY}) 192 | out = np.around(out) 193 | out = out.astype('int') 194 | out = np.reshape(out, [1, out.shape[1], out.shape[2], 1]) 195 | Y = np.reshape(out, [out.shape[1], out.shape[2]]) 196 | Y = np.array(list(itertools.chain.from_iterable(Y))) 197 | U = imgLowCbCr[0] 198 | V = imgLowCbCr[1] 199 | creatPath = os.path.join(DL_path, fileName_list[i].split('\\')[-2]) 200 | if not os.path.exists(creatPath): 201 | os.mkdir(creatPath) 202 | 203 | if singleGT_y: 204 | p = psnr(out, gtY) 205 | path = os.path.join(DL_path, fileName_list[i].split('\\')[-2], 206 | fileName_list[i].split('\\')[-1].split('.')[0]) + '_%.4f' % (p - hevc) + '.yuv' 207 | YUV = np.concatenate((Y, U, V)) 208 | YUV = YUV.astype('uint8') 209 | YUV.tofile(path) 210 | 211 | total_psnr += p 212 | print("qp37\tepoch:%d\t%s\t%.4f\n" % (epoch, fileName_list[i], p)) 213 | 214 | duration_t = time.time() - start_t 215 | total_time += duration_t 216 | 217 | 218 | print("AVG_DURATION:%.2f\tAVG_PSNR:%.4f"%(total_time/total_imgs, total_psnr / count)) 219 | 
print('count:', count) 220 | avg_psnr = total_psnr / count 221 | if avg_psnr > max[0]: 222 | max[0] = avg_psnr 223 | max[1] = epoch 224 | 225 | if __name__ == '__main__': 226 | test_all_ckpt(MODEL_DOUBLE_PATH1, MODEL_DOUBLE_PATH2, [HIGHDATA_Parent_PATH, QP_LOWDATA_PATH, GT_PATH]) -------------------------------------------------------------------------------- /test/test_MF/LSVE_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def model_single(inputLowData_tensor): 6 | # with tf.device("/gpu:0"): 7 | input_current = inputLowData_tensor # lowData 8 | 9 | # due to don't have training_Set at right now, so let it be annotation. 10 | tensor = None 11 | 12 | # ----------------------------------------------Frame 0-------------------------------------------------------- 13 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64], 14 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25))) 15 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0)) 16 | input_low_tensor = tf.nn.relu( 17 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'), 18 | input_current_b)) 19 | 20 | # ----------------------------------1x1 conv, for reduce number of parameters---------------- 21 | input_3x3_w = tf.get_variable("input_1x1_w", [3, 3, 64, 64], 22 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 64 / 9))) 23 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0)) 24 | input_3x3_tensor = tf.nn.relu( 25 | tf.nn.bias_add(tf.nn.conv2d(input_low_tensor, input_3x3_w, strides=[1, 1, 1, 1], padding='SAME'), 26 | input_1x1_b)) 27 | tensor = input_3x3_tensor 28 | 29 | # --------------------------------------start iteration for last layers---------------------------------- 30 | convId = 0 31 | for i in range(18): 32 | conv_w = 
tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64], 33 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 34 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 35 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0)) 36 | convId += 1 37 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)) 38 | 39 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1], 40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 41 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0)) 42 | convId += 1 43 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 44 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b) 45 | tensor = tf.add(tensor, input_current) 46 | return tensor 47 | -------------------------------------------------------------------------------- /test/test_MF/UTILS_MF_ra.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math, os, random, re 4 | from PIL import Image 5 | from LMVE_test import BATCH_SIZE 6 | from LMVE_test import PATCH_SIZE 7 | 8 | # due to a batch trainingSet come from one picture. I design a algorithm to make the TrainingSet more diversity. 9 | def normalize(x): 10 | x = x / 255. 11 | return truncate(x, 0., 1.) 12 | 13 | def denormalize(x): 14 | x = x * 255. 15 | return truncate(x, 0., 255.) 16 | 17 | def truncate(input, min, max): 18 | input = np.where(input > min, input, min) 19 | input = np.where(input < max, input, max) 20 | return input 21 | 22 | def remap(input): 23 | input = 16+219/255*input 24 | return truncate(input, 16.0, 235.0) 25 | 26 | def deremap(input): 27 | input = (input-16)*255/219 28 | return truncate(input, 0.0, 255.0) 29 | 30 | # return the whole absolute path. 
31 | def load_file_list(directory): 32 | list = [] 33 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]: 34 | list.append(os.path.join(directory,filename)) 35 | return list 36 | 37 | def searchHighData(currentLowDataIndex, highDataList, highIndexList): 38 | searchOffset = 3 39 | searchedHighDataIndexList = [] 40 | searchedHighData = [] 41 | for i in range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1): 42 | if i in highIndexList: 43 | searchedHighDataIndexList.append(i) 44 | assert len(searchedHighDataIndexList) == 2, 'search method have error!' 45 | for tempData in highDataList: 46 | if int(os.path.basename(tempData).split('.')[0].split('_')[-1]) \ 47 | == searchedHighDataIndexList[0] == searchedHighDataIndexList[1]: 48 | searchedHighData.append(tempData) 49 | return searchedHighData 50 | 51 | 52 | # return like this"[[[high1Data, lowData], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole path. 53 | def get_test_list2(highDataList, lowDataList, labelList): 54 | assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList)) 55 | # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48] 56 | highIndexList = [q for q in range(49) if q % 4 == 0] 57 | test_list = [] 58 | for tempDataPath in lowDataList: 59 | tempData = [] 60 | temp = [] 61 | # this place should changed on the different situation. 
62 | currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1]) 63 | searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList) 64 | tempData.append(searchedHighData[0]) 65 | tempData.append(tempDataPath) 66 | tempData.append(searchedHighData[1]) 67 | 68 | i = list(lowDataList).index(tempDataPath) 69 | 70 | temp.append(tempData) 71 | temp.append(labelList[i]) 72 | test_list.append(temp) 73 | return test_list 74 | 75 | def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList): 76 | tempData = [] 77 | temp = [] 78 | test_list = [] 79 | for i in range(len(lowDataList)): 80 | tempData.append(high1DataList[i]) 81 | tempData.append(lowDataList[i]) 82 | tempData.append(high2DataList[i]) 83 | temp.append(tempData) 84 | temp.append(labelList[i]) 85 | 86 | test_list.append(temp) 87 | return test_list 88 | 89 | # [[high, low1, label1], [[h21,h22], low2, label2]] 90 | def get_test_list(HIGHDATA_Parent_PATH, lowDataList, labelList): 91 | doubleTest_list = [] 92 | singleTest_list = [] 93 | HighDirList = os.listdir(HIGHDATA_Parent_PATH) 94 | # convert string-list to int-list. 
95 | HighDirList = list(map(int, HighDirList)) 96 | 97 | for lowdata in lowDataList: 98 | 99 | tempData = [] 100 | lowdataIndex = int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) 101 | if lowdataIndex % 4 != 0: 102 | if lowdataIndex in HighDirList: 103 | High_current_path = HIGHDATA_Parent_PATH + '/' + str(lowdataIndex) 104 | TWO_HighData = os.listdir(High_current_path) 105 | Two_HighData = [os.path.join(High_current_path, T) for T in TWO_HighData] 106 | tempData.append(Two_HighData) 107 | tempData.append(lowdata) 108 | labelIndex = list(lowDataList).index(lowdata) 109 | tempData.append(labelList[labelIndex]) 110 | doubleTest_list.append(tempData) 111 | else: 112 | tempData.append(lowdata) 113 | labelIndex = list(lowDataList).index(lowdata) 114 | if int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) == \ 115 | int(os.path.basename(labelList[labelIndex]).split('.')[0].split('_')[-1]): 116 | 117 | tempData.append(labelList[labelIndex]) 118 | singleTest_list.append(tempData) 119 | 120 | return doubleTest_list, singleTest_list 121 | 122 | # return like this"[[[high1Data, lowData, high2Data], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole path. 123 | def get_train_list(high1DataList, lowDataList, high2DataList, labelList): 124 | assert len(lowDataList) == len(high1DataList) == len(labelList) == len(high2DataList), \ 125 | "low:%d, high1:%d, label:%d, high2:%d"%(len(lowDataList), len(high1DataList), len(labelList), len(high2DataList)) 126 | 127 | train_list = [] 128 | for i in range(len(labelList)): 129 | tempData = [] 130 | temp = [] 131 | # this place should changed on the different situation. 
132 | if int(os.path.basename(high1DataList[i]).split('_')[-1].split('.')[0]) + 4 == \ 133 | int(os.path.basename(lowDataList[i]).split('_')[-1].split('.')[0]) + 2 == \ 134 | int(os.path.basename(high2DataList[i]).split('_')[-1].split('.')[0]): 135 | tempData.append(high1DataList[i]) 136 | tempData.append(lowDataList[i]) 137 | tempData.append(high2DataList[i]) 138 | temp.append(tempData) 139 | temp.append(labelList[i]) 140 | 141 | else: 142 | raise Exception('len(lowData) not equal with len(highData)...') 143 | train_list.append(temp) 144 | return train_list 145 | 146 | def prepare_nn_data(train_list): 147 | batchSizeRandomList = random.sample(range(0,len(train_list)), 8) 148 | gt_list = [] 149 | high1Data_list = [] 150 | lowData_list = [] 151 | high2Data_list = [] 152 | for i in batchSizeRandomList: 153 | high1Data_image = c_getYdata(train_list[i][0][0]) 154 | lowData_image = c_getYdata(train_list[i][0][1]) 155 | high2Data_image = c_getYdata(train_list[i][0][2]) 156 | gt_image = c_getYdata(train_list[i][1]) 157 | for j in range(0, 8): 158 | #crop images to the disired size. 
159 | high1Data_imgY, lowData_imgY, high2Data_imgY, gt_imgY = \ 160 | crop(high1Data_image, lowData_image, high2Data_image, gt_image, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray") 161 | 162 | #normalize 163 | high1Data_imgY = normalize(high1Data_imgY) 164 | lowData_imgY = normalize(lowData_imgY) 165 | high2Data_imgY = normalize(high2Data_imgY) 166 | gt_imgY = normalize(gt_imgY) 167 | 168 | high1Data_list.append(high1Data_imgY) 169 | lowData_list.append(lowData_imgY) 170 | high2Data_list.append(high2Data_imgY) 171 | gt_list.append(gt_imgY) 172 | 173 | high1Data_list = np.resize(high1Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 174 | lowData_list = np.resize(lowData_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 175 | high2Data_list = np.resize(high2Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 176 | gt_list = np.resize(gt_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 177 | 178 | return high1Data_list, lowData_list, high2Data_list, gt_list 179 | 180 | def getWH(yuvfileName): 181 | deyuv=re.compile(r'(.+?)\.') 182 | deyuvFilename=deyuv.findall(yuvfileName)[0] #去yuv后缀的文件名 183 | if 'x' in os.path.basename(deyuvFilename).split('_')[-2]: 184 | wxh = os.path.basename(deyuvFilename).split('_')[-2] 185 | elif 'x' in os.path.basename(deyuvFilename).split('_')[1]: 186 | wxh = os.path.basename(deyuvFilename).split('_')[1] 187 | else: 188 | # print(yuvfileName) 189 | raise Exception('do not find wxh') 190 | w, h = wxh.split('x') 191 | return int(w), int(h) 192 | 193 | def getYdata(path, size): 194 | w = size[0] 195 | h = size[1] 196 | with open(path, 'rb') as fp: 197 | fp.seek(0, 0) 198 | Yt = fp.read() 199 | tem = Image.frombytes('L', [w, h], Yt) 200 | Yt = np.asarray(tem, dtype='float32') 201 | return Yt 202 | 203 | 204 | def c_getYdata(path): 205 | return getYdata(path, getWH(path)) 206 | 207 | def c_getCbCr(path): 208 | w, h = getWH(path) 209 | CbCr = [] 210 | with open(path, 'rb+') as file: 211 | y = file.read(h * w) 212 | if y == b'': 213 | 
return '' 214 | u = file.read(h * w // 4) 215 | v = file.read(h * w // 4) 216 | # convert string-list to int-list. 217 | u = list(map(int, u)) 218 | v = list(map(int, v)) 219 | CbCr.append(u) 220 | CbCr.append(v) 221 | return CbCr 222 | 223 | def img2y(input_img): 224 | if np.asarray(input_img).shape[2] == 3: 225 | input_imgY = input_img.convert('YCbCr').split()[0] 226 | input_imgCb, input_imgCr = input_img.convert('YCbCr').split()[1:3] 227 | input_imgY = np.asarray(input_imgY, dtype='float32') 228 | input_imgCb = np.asarray(input_imgCb, dtype='float32') 229 | input_imgCr = np.asarray(input_imgCr, dtype='float32') 230 | 231 | #Concatenate Cb, Cr components for easy, they are used in pair anyway. 232 | input_imgCb = np.expand_dims(input_imgCb,2) 233 | input_imgCr = np.expand_dims(input_imgCr,2) 234 | input_imgCbCr = np.concatenate((input_imgCb, input_imgCr), axis=2) 235 | 236 | elif np.asarray(input_img).shape[2] == 1: 237 | print("This image has one channal only.") 238 | #If the num of channal is 1, remain. 239 | input_imgY = input_img 240 | input_imgCbCr = None 241 | else: 242 | print("The num of channal is neither 3 nor 1.") 243 | exit() 244 | return input_imgY, input_imgCbCr 245 | 246 | # def crop(input_image, gt_image, patch_width, patch_height, img_type): 247 | def crop(high1Data_image, lowData_image, high2Data_image, gt_image, patch_width, patch_height, img_type): 248 | assert type(high1Data_image) == type(gt_image) == type(lowData_image) == type(high2Data_image), "types are different." 
249 | high1Data_cropped = [] 250 | lowData_cropped = [] 251 | high2Data_cropped = [] 252 | gt_cropped = [] 253 | 254 | # return a ndarray object 255 | if img_type == "ndarray": 256 | in_row_ind = random.randint(0,high1Data_image.shape[0]-patch_width) 257 | in_col_ind = random.randint(0,high1Data_image.shape[1]-patch_height) 258 | 259 | high1Data_cropped = high1Data_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height] 260 | lowData_cropped = lowData_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height] 261 | high2Data_cropped = high2Data_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height] 262 | gt_cropped = gt_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height] 263 | 264 | #return an "Image" object 265 | elif img_type == "Image": 266 | pass 267 | return high1Data_cropped, lowData_cropped, high2Data_cropped, gt_cropped 268 | 269 | def save_images(inputY, inputCbCr, size, image_path): 270 | """Save mutiple images into one single image. 271 | 272 | # Parameters 273 | # ----------- 274 | # images : numpy array [batch, w, h, c] 275 | # size : list of two int, row and column number. 276 | # number of images should be equal or less than size[0] * size[1] 277 | # image_path : string. 
278 | # 279 | # Examples 280 | # --------- 281 | # # >>> images = np.random.rand(64, 100, 100, 3) 282 | # # >>> tl.visualize.save_images(images, [8, 8], 'temp.png') 283 | """ 284 | def merge(images, size): 285 | h, w = images.shape[1], images.shape[2] 286 | img = np.zeros((h * size[0], w * size[1], 3)) 287 | for idx, image in enumerate(images): 288 | i = idx % size[1] 289 | j = idx // size[1] 290 | img[j*h:j*h+h, i*w:i*w+w, :] = image 291 | return img 292 | 293 | inputY = inputY.astype('uint8') 294 | inputCbCr = inputCbCr.astype('uint8') 295 | output_concat = np.concatenate((inputY, inputCbCr), axis=3) 296 | 297 | assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat)) 298 | 299 | new_output = merge(output_concat, size) 300 | 301 | new_output = new_output.astype('uint8') 302 | 303 | img = Image.fromarray(new_output, mode='YCbCr') 304 | img = img.convert('RGB') 305 | img.save(image_path) 306 | 307 | def get_image_batch(train_list,offset,batch_size): 308 | target_list = train_list[offset:offset+batch_size] 309 | input_list = [] 310 | gt_list = [] 311 | inputcbcr_list = [] 312 | for pair in target_list: 313 | input_img = Image.open(pair[0]) 314 | gt_img = Image.open(pair[1]) 315 | 316 | #crop images to the disired size. 
317 | input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image") 318 | 319 | #focus on Y channal only 320 | input_imgY, input_imgCbCr = img2y(input_img) 321 | gt_imgY, gt_imgCbCr = img2y(gt_img) 322 | 323 | #input_imgY = normalize(input_imgY) 324 | #gt_imgY = normalize(gt_imgY) 325 | 326 | input_list.append(input_imgY) 327 | gt_list.append(gt_imgY) 328 | inputcbcr_list.append(input_imgCbCr) 329 | 330 | input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 331 | gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 332 | 333 | return input_list, gt_list, inputcbcr_list 334 | 335 | def save_test_img(inputY, inputCbCr, path): 336 | assert len(inputY.shape) == 4, "the tensor Y's shape is %s"%inputY.shape 337 | assert inputY.shape[0] == 1, "the fitst component must be 1, has not been completed otherwise.{}".format(inputY.shape) 338 | 339 | inputY = np.squeeze(inputY, axis=0) 340 | inputY = inputY.astype('uint8') 341 | 342 | inputCbCr = inputCbCr.astype('uint8') 343 | 344 | output_concat = np.concatenate((inputY, inputCbCr), axis=2) 345 | img = Image.fromarray(output_concat, mode='YCbCr') 346 | img = img.convert('RGB') 347 | img.save(path) 348 | 349 | def psnr(hr_image, sr_image, max_value=255.0): 350 | eps = 1e-10 351 | if((type(hr_image)==type(np.array([]))) or (type(hr_image)==type([]))): 352 | hr_image_data = np.asarray(hr_image, 'float32') 353 | sr_image_data = np.asarray(sr_image, 'float32') 354 | 355 | diff = sr_image_data - hr_image_data 356 | mse = np.mean(diff*diff) 357 | mse = np.maximum(eps, mse) 358 | return float(10*math.log10(max_value*max_value/mse)) 359 | else: 360 | assert len(hr_image.shape)==4 and len(sr_image.shape)==4 361 | diff = hr_image - sr_image 362 | mse = tf.reduce_mean(tf.square(diff)) 363 | mse = tf.maximum(mse, eps) 364 | return 10*tf.log(max_value*max_value/mse)/math.log(10) 365 | 366 | def getBeforeNNBlockDict(img, w, h): 367 | # print(img[:1500, : 2000]) 368 | 
blockSize = 1000 369 | padding = 32 370 | yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1) 371 | xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1) 372 | tempImg = {} 373 | i = 0 374 | for yBlock in range(yBlockNum): 375 | for xBlock in range(xBlockNum): 376 | if yBlock == 0: 377 | if xBlock == 0: 378 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding] 379 | elif xBlock == xBlockNum - 1: 380 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w] 381 | else: 382 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 383 | elif yBlock == yBlockNum - 1: 384 | if xBlock == 0: 385 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding] 386 | elif xBlock == xBlockNum - 1: 387 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w] 388 | else: 389 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 390 | elif xBlock == 0: 391 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding] 392 | elif xBlock == xBlockNum - 1: 393 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w] 394 | else: 395 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 396 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 397 | i += i 398 | l = tempImg[i].astype('uint8') 399 | l = Image.fromarray(l) 400 | l.show() 401 | -------------------------------------------------------------------------------- /test/test_MF/__pycache__/MFVDSingle_ra.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/__pycache__/MFVDSingle_ra.cpython-35.pyc -------------------------------------------------------------------------------- 
/test/test_MF/__pycache__/UTILS_MF_ra.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/__pycache__/UTILS_MF_ra.cpython-35.pyc -------------------------------------------------------------------------------- /test/test_MF/__pycache__/yangNet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/__pycache__/yangNet.cpython-35.pyc -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.index -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.meta -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "MF_qp22_ra_01222_MFVD_double_149.ckpt" 2 | all_model_checkpoint_paths: "MF_qp22_ra_01222_MFVD_double_149.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.index -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.meta 
-------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "MF_qp27_ra_01202_MFVD_double_201.ckpt" 2 | all_model_checkpoint_paths: "MF_qp27_ra_01202_MFVD_double_201.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.index -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.meta -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/checkpoint: -------------------------------------------------------------------------------- 1 | 
model_checkpoint_path: "MF_qp32_ra_01262_MFVD_double_138.ckpt" 2 | all_model_checkpoint_paths: "MF_qp32_ra_01262_MFVD_double_138.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.index -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.meta -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp37_ra_01121/a.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/a.txt -------------------------------------------------------------------------------- /test/test_MF/checkpoints/MF_qp37_ra_01121/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: 
"MF_qp37_ra_01121_002.ckpt" 2 | all_model_checkpoint_paths: "MF_qp37_ra_01121_002.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_Single/LSVE_ra_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def model_single(inputLowData_tensor): 6 | # with tf.device("/gpu:0"): 7 | input_current = inputLowData_tensor # lowData 8 | 9 | # due to don't have training_Set at right now, so let it be annotation. 10 | tensor = None 11 | 12 | # ----------------------------------------------Frame 0-------------------------------------------------------- 13 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64], 14 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25))) 15 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0)) 16 | input_low_tensor = tf.nn.relu( 17 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'), 18 | input_current_b)) 19 | 20 | # ----------------------------------1x1 conv, for reduce number of parameters---------------- 21 | input_3x3_w = tf.get_variable("input_1x1_w", [3, 3, 64, 64], 22 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 64 / 9))) 23 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0)) 24 | input_3x3_tensor = tf.nn.relu( 25 | tf.nn.bias_add(tf.nn.conv2d(input_low_tensor, input_3x3_w, strides=[1, 1, 1, 1], padding='SAME'), 26 | input_1x1_b)) 27 | tensor = input_3x3_tensor 28 | 29 | # --------------------------------------start iteration for last layers---------------------------------- 30 | convId = 0 31 | for i in range(18): 32 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64], 33 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 34 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, 
conv_w) 35 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0)) 36 | convId += 1 37 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)) 38 | 39 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1], 40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 41 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0)) 42 | convId += 1 43 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 44 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b) 45 | tensor = tf.add(tensor, input_current) 46 | return tensor 47 | -------------------------------------------------------------------------------- /test/test_Single/LSVE_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | from LSVE_ra_model import model_single 3 | from UTILS_MF_ra_ALL_Single import * 4 | import itertools 5 | 6 | tf.logging.set_verbosity(tf.logging.WARN) 7 | 8 | EXP_DATA = "SF_qp27_ra_01232_LSVE" # single checkpoints 9 | TESTOUT_PATH = "./testout/%s/"%(EXP_DATA) 10 | MODEL_PATH = "./checkpoints/%s/"%(EXP_DATA) 11 | QP_LOWDATA_PATH = r'E:\MF\HEVC_TestSequenct\rec\qp27\ra\D\BasketballPass_qp22_ai_416x240' # low frames 12 | GT_PATH = r"E:\MF\HEVC_TestSequenct\org\D\BasketballPass_416x240" # ground truth 13 | DL_path = 'E:\MF\HEVC_TestSequenct\DL_rec\Single\qp27' 14 | OUT_DATA_PATH = "./outdata/%s/"%(EXP_DATA) 15 | NOFILTER = {'q22':42.2758, 'q27':38.9788, 'qp32':35.8667, 'q37':32.8257,'qp37':32.8257} 16 | 17 | # Ground truth images dir should be the 2nd component of 'fileOrDir' if 2 components are given. 
18 | 19 | ##cb, cr components are not implemented 20 | def prepare_test_data(fileOrDir): 21 | if not os.path.exists(TESTOUT_PATH): 22 | os.mkdir(TESTOUT_PATH) 23 | 24 | singleData_ycbcr = [] 25 | singleGT_y = [] 26 | fileName_list = [] 27 | 28 | #The input is a single file. 29 | if len(fileOrDir) == 2: 30 | # return the whole absolute path. 31 | fileName_list = load_file_list(fileOrDir[0]) 32 | # double_list # [[high, low1, label1], [[h21,h22], low2, label2]] 33 | # single_list # [[low1, lable1], [2,2] ....] 34 | single_list = get_test_list(load_file_list(fileOrDir[0]), load_file_list(fileOrDir[1])) 35 | 36 | # single_list # [[low1, lable1], [2,2] ....] 37 | for pair in single_list: 38 | lowData_list = [] 39 | lowData_imgY = c_getYdata(pair[0]) 40 | CbCr = c_getCbCr(pair[0]) 41 | gt_imgY = c_getYdata(pair[1]) 42 | 43 | # normalize 44 | lowData_imgY = normalize(lowData_imgY) 45 | 46 | lowData_imgY = np.resize(lowData_imgY, (1, lowData_imgY.shape[0], lowData_imgY.shape[1], 1)) 47 | gt_imgY = np.resize(gt_imgY, (1, gt_imgY.shape[0], gt_imgY.shape[1], 1)) 48 | 49 | lowData_list.append([lowData_imgY, CbCr]) 50 | singleData_ycbcr.append(lowData_list) 51 | singleGT_y.append(gt_imgY) 52 | 53 | else: 54 | print("Invalid Inputs...!tjc!") 55 | exit(0) 56 | 57 | return singleData_ycbcr, singleGT_y, fileName_list 58 | 59 | def test_all_ckpt(modelPath, fileOrDir): 60 | max = [0, 0] 61 | tem = [f for f in os.listdir(modelPath) if 'data' in f] 62 | ckptFiles = sorted([r.split('.data')[0] for r in tem]) 63 | re_psnr = tf.placeholder('float32') 64 | tf.summary.scalar('re_psnr', re_psnr) 65 | 66 | with tf.Session() as sess: 67 | lowData_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1)) 68 | shared_model = tf.make_template('shared_model', model_single) 69 | output_tensor = shared_model(lowData_tensor) 70 | output_tensor = tf.clip_by_value(output_tensor, 0., 1.) 
71 | output_tensor = output_tensor * 255 72 | 73 | merged = tf.summary.merge_all() 74 | file_writer = tf.summary.FileWriter(OUT_DATA_PATH, sess.graph) 75 | saver = tf.train.Saver(tf.global_variables()) 76 | sess.run(tf.global_variables_initializer()) 77 | 78 | singleData_ycbcr, singleGT_y, fileName_list = prepare_test_data(fileOrDir) 79 | 80 | for ckpt in ckptFiles: 81 | epoch = int(ckpt.split('_')[-1].split('.')[0]) 82 | if epoch != 169: 83 | continue 84 | 85 | saver.restore(sess, os.path.join(modelPath,ckpt)) 86 | total_time, total_psnr = 0, 0 87 | total_imgs = len(fileName_list) 88 | count = 0 89 | for i in range(total_imgs): 90 | if i%4==0: 91 | continue 92 | count += 1 93 | 94 | # sorry! this place write so difficult!【[[[h1,0]],[[low,0]],[[h2, 0]]], [[[h1,0]],[[low,0]],[[h2, 0]]]】 95 | lowDataY = singleData_ycbcr[i][0][0] 96 | imgLowCbCr = singleData_ycbcr[i][0][1] 97 | gtY = singleGT_y[i] if singleGT_y else 0 98 | 99 | #### adopt the split frame method to deal with the out of memory situation. 
#### 100 | 101 | start_t = time.time() 102 | out = sess.run(output_tensor, feed_dict={lowData_tensor: lowDataY}) 103 | out = np.around(out) 104 | out = out.astype('int') 105 | out = np.reshape(out, [1, out.shape[1], out.shape[2], 1]) 106 | hevc = psnr(lowDataY * 255.0, gtY) 107 | duration_t = time.time() - start_t 108 | total_time += duration_t 109 | Y = np.reshape(out, [out.shape[1], out.shape[2]]) 110 | Y = np.array(list(itertools.chain.from_iterable(Y))) 111 | U = imgLowCbCr[0] 112 | V = imgLowCbCr[1] 113 | creatPath = os.path.join(DL_path, fileName_list[i].split('\\')[-2]) 114 | if not os.path.exists(creatPath): 115 | os.mkdir(creatPath) 116 | 117 | if singleGT_y: 118 | p = psnr(out, gtY) 119 | path = os.path.join(DL_path, 120 | fileName_list[i].split('\\')[-2], 121 | fileName_list[i].split('\\')[-1].split('.')[0]) + '_%.4f' % (p - hevc) + '.yuv' 122 | 123 | YUV = np.concatenate((Y, U, V)) 124 | YUV = YUV.astype('uint8') 125 | YUV.tofile(path) 126 | 127 | total_psnr += p 128 | print("qp??\tepoch:%d\t%s\t%.4f\n" % (epoch, fileName_list[i], p)) 129 | #print("took:%.2fs\t psnr:%.2f name:%s"%(duration_t, p, save_path)) 130 | 131 | 132 | 133 | print("AVG_DURATION:%.2f\tAVG_PSNR:%.4f"%(total_time/total_imgs, total_psnr / count)) 134 | print('count:', count) 135 | # avg_psnr = total_psnr/total_imgs 136 | avg_psnr = total_psnr / count 137 | avg_duration = (total_time/total_imgs) 138 | if avg_psnr > max[0]: 139 | max[0] = avg_psnr 140 | max[1] = epoch 141 | 142 | if __name__ == '__main__': 143 | test_all_ckpt(MODEL_PATH, [QP_LOWDATA_PATH, GT_PATH]) -------------------------------------------------------------------------------- /test/test_Single/UTILS_MF_ra_ALL_Single.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math, os, random, re 4 | from PIL import Image 5 | BATCH_SIZE = 64 6 | PATCH_SIZE = (64, 64) 7 | 8 | # due to a batch trainingSet come from one picture. 
# I designed an algorithm to make the training set more diverse.
def normalize(x):
    """Map 8-bit sample values to [0, 1]: divide by 255, then clamp."""
    x = x / 255.
    return truncate(x, 0., 1.)


def denormalize(x):
    """Map [0, 1] values back to 8-bit scale: multiply by 255, then clamp to [0, 255]."""
    x = x * 255.
    return truncate(x, 0., 255.)


def truncate(input, min, max):
    """Clamp ``input`` element-wise into ``[min, max]`` and return a NumPy array.

    The parameter names shadow builtins but are kept so existing keyword
    callers keep working. Unlike ``np.clip``, NaN entries become ``min``
    because ``NaN > min`` is False in the first ``np.where``.
    """
    input = np.where(input > min, input, min)
    input = np.where(input < max, input, max)
    return input


def remap(input):
    """Linearly map [0, 255] onto [16, 235] and clamp to that range."""
    input = 16+219/255*input
    return truncate(input, 16.0, 235.0)


def deremap(input):
    """Inverse of ``remap``: map [16, 235] back onto [0, 255] and clamp."""
    input = (input-16)*255/219
    return truncate(input, 0.0, 255.0)


def load_file_list(directory):
    """Return the full paths of the files (``os.path.isfile``) directly in ``directory``.

    The paths are sorted. ``os.listdir`` guarantees no order, and callers pair
    the listings of parallel directories (decoded frames vs. ground truth)
    index by index, so both listings must come back in the same order.
    """
    paths = (os.path.join(directory, name) for name in os.listdir(directory))
    return sorted(path for path in paths if os.path.isfile(path))


def searchHighData(currentLowDataIndex, highDataList, highIndexList, searchOffset=3):
    """Find the high-quality frames surrounding a low-quality frame.

    Args:
        currentLowDataIndex: frame number of the low-quality frame.
        highDataList: candidate high-quality frame paths. A path's frame number
            is the last ``_``-separated token of its basename before the first
            ``.`` (e.g. ``seq_416x240_8.yuv`` -> 8).
        highIndexList: frame numbers that are high-quality frames.
        searchOffset: how many frames to look on each side (default 3).

    Returns:
        The paths whose frame number equals one of the two high-quality frame
        numbers found in the window. They are ordered by frame number
        (previous frame first), so callers can use ``[0]`` and ``[1]``.

    Raises:
        AssertionError: if the window does not contain exactly two
            high-quality frame numbers.
    """
    window = range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1)
    searchedHighDataIndexList = [i for i in window if i in highIndexList]
    assert len(searchedHighDataIndexList) == 2, 'search method have error!'

    # Match EITHER neighbour. A chained ``idx == a == b`` would also demand
    # a == b, which never holds for two distinct frame numbers.
    numbered = [(int(os.path.basename(path).split('.')[0].split('_')[-1]), path) for path in highDataList]
    return [path for wanted in searchedHighDataIndexList for number, path in numbered if number == wanted]


# get_test_list2: entries look like [[high1Data, lowData, high2Data], label], all full paths.
def get_test_list2(highDataList, lowDataList, labelList, highIndexList=None):
    """Build ``[[high1Data, lowData, high2Data], label]`` test entries.

    For every low-quality frame, the two surrounding high-quality frames are
    looked up with ``searchHighData``. ``labelList[i]`` is the ground truth for
    ``lowDataList[i]``.

    Args:
        highDataList: paths of the high-quality frames.
        lowDataList: paths of the low-quality frames. The frame number is the
            last ``_``-separated token of the basename before the first ``.``.
        labelList: ground-truth paths, aligned index by index with ``lowDataList``.
        highIndexList: frame numbers of the high-quality frames. Defaults to
            every 4th frame in 0..48 (``[0, 4, ..., 48]``).

    Raises:
        AssertionError: if ``lowDataList`` and ``labelList`` differ in length.
    """
    assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList))

    if highIndexList is None:
        # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]
        highIndexList = [q for q in range(49) if q % 4 == 0]
    test_list = []
    # enumerate (rather than list.index) keeps each label aligned even when a
    # path appears twice, and avoids a quadratic scan.
    for i, tempDataPath in enumerate(lowDataList):
        currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1])
        searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList)
        test_list.append([[searchedHighData[0], tempDataPath, searchedHighData[1]], labelList[i]])
    return test_list


def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList):
    """Combine four parallel path lists into ``[[high1, low, high2], label]`` entries.

    Entry ``i`` is
    ``[[high1DataList[i], lowDataList[i], high2DataList[i]], labelList[i]]``.
    One entry is produced per element of ``lowDataList``.

    Raises:
        IndexError: if another list is shorter than ``lowDataList``.
    """
    # Build fresh inner lists per entry. Reusing one accumulator across
    # iterations made every entry the same, ever-growing list.
    return [[[high1DataList[i], low, high2DataList[i]], labelList[i]]
            for i, low in enumerate(lowDataList)]


def get_test_list(lowDataList, labelList):
    """Pair each low-quality frame with its ground-truth frame.

    Returns ``[[lowData, label], ...]``. ``labelList[i]`` is paired with
    ``lowDataList[i]`` only when both basenames carry the same frame number
    (last ``_``-separated token before the first ``.``). Mismatched entries
    are silently dropped.

    Raises:
        IndexError: if ``labelList`` is shorter than ``lowDataList``.
    """
    def frame_number(path):
        return int(os.path.basename(path).split('.')[0].split('_')[-1])

    singleTest_list = []
    for i, lowdata in enumerate(lowDataList):
        label = labelList[i]
        if frame_number(lowdata) == frame_number(label):
            singleTest_list.append([lowdata, label])
    return singleTest_list

# return like this"[[[high1Data, lowData, high2Data], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole
# path.
def get_train_list(high1DataList, lowDataList, high2DataList, labelList):
    """Group parallel path lists into ``[[high1, low, high2], label]`` training entries.

    Entry ``i`` combines ``high1DataList[i]``, ``lowDataList[i]`` and
    ``high2DataList[i]`` with ``labelList[i]``. The frame numbers (last
    ``_``-separated token of the basename, before the ``.``) must satisfy
    ``low == high1 + 2`` and ``high2 == high1 + 4``.

    Raises:
        AssertionError: if the four lists differ in length.
        Exception: if an entry's frame numbers do not line up as above.
    """
    assert len(lowDataList) == len(high1DataList) == len(labelList) == len(high2DataList), \
        "low:%d, high1:%d, label:%d, high2:%d"%(len(lowDataList), len(high1DataList), len(labelList), len(high2DataList))

    def frame_number(path):
        # e.g. "BasketballPass_416x240_8.yuv" -> 8
        return int(os.path.basename(path).split('_')[-1].split('.')[0])

    train_list = []
    for high1, low, high2, label in zip(high1DataList, lowDataList, high2DataList, labelList):
        # This layout (high1 = n, low = n + 2, high2 = n + 4) is specific to the
        # current data set; adjust the offsets here if it changes.
        if not (frame_number(high1) + 4 == frame_number(low) + 2 == frame_number(high2)):
            raise Exception('frame numbers are not aligned: %s, %s, %s' % (high1, low, high2))
        train_list.append([[high1, low, high2], label])
    return train_list


def prepare_nn_data(train_list):
    """Sample one training batch of co-located, normalised patches.

    Eight distinct entries of ``train_list`` (``[[high1, low, high2], label]``
    paths, as built by ``get_train_list``) are picked at random. Eight random
    co-located ``PATCH_SIZE`` patches are cut from each entry's four frames,
    giving 64 patches per stream.

    Returns:
        ``(high1, low, high2, gt)`` arrays, each passed through ``np.resize``
        to shape ``(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)``.

    Raises:
        ValueError: from ``random.sample`` if ``train_list`` has fewer than
            eight entries.
    """
    frames_per_batch = 8
    patches_per_frame = 8
    high1Data_list, lowData_list, high2Data_list, gt_list = [], [], [], []
    streams = (high1Data_list, lowData_list, high2Data_list, gt_list)
    for i in random.sample(range(0, len(train_list)), frames_per_batch):
        inputs, label = train_list[i][0], train_list[i][1]
        images = (c_getYdata(inputs[0]), c_getYdata(inputs[1]), c_getYdata(inputs[2]), c_getYdata(label))
        for _ in range(patches_per_frame):
            patches = crop(images[0], images[1], images[2], images[3], PATCH_SIZE[0], PATCH_SIZE[1], "ndarray")
            for stream, patch in zip(streams, patches):
                stream.append(normalize(patch))

    # NOTE: np.resize silently repeats/truncates if the patch count ever
    # differs from BATCH_SIZE.
    shape = (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)
    return tuple(np.resize(stream, shape) for stream in streams)


def getWH(yuvfileName):
    """Parse ``(width, height)`` from a YUV file name.

    The basename (up to its first ``.``) is split on ``_``. The second-to-last
    token is tried first, then the second token. The first one of the form
    ``<W>x<H>`` wins, e.g. ``BasketballPass_416x240_3.yuv`` -> ``(416, 240)``.

    Raises:
        Exception: if neither token is a ``<W>x<H>`` size.
    """
    # Work on the basename only. Cutting the whole path at its first '.' broke
    # on directories containing dots and on relative paths such as "../x".
    stem = os.path.basename(yuvfileName).split('.')[0]  # file name without the yuv suffix
    parts = stem.split('_')
    if len(parts) >= 2:
        for token in (parts[-2], parts[1]):
            match = re.fullmatch(r'(\d+)x(\d+)', token)
            if match:
                return int(match.group(1)), int(match.group(2))
    raise Exception('do not find wxh')


def c_getCbCr(path):
    """Read the chroma planes of the first 4:2:0 frame of a raw YUV file.

    The frame size comes from ``getWH(path)``. Returns ``[U, V]`` as lists of
    ints (each ``w * h // 4`` samples when the file is complete), or ``''``
    when the file is empty.
    """
    w, h = getWH(path)
    luma_size = w * h
    # Read-only access suffices ("rb+" needlessly failed on read-only files).
    with open(path, 'rb') as file:
        if file.read(luma_size) == b'':
            return ''
        u = file.read(luma_size // 4)
        v = file.read(luma_size // 4)
    # Iterating a bytes object already yields ints.
    return [list(u), list(v)]


def getYdata(path, size):
    """Return the luma plane of the first frame of a raw YUV file.

    ``size`` is ``(width, height)``. The result is a float32 array of shape
    ``(height, width)``.

    Raises:
        ValueError: if the file holds fewer than ``width * height`` bytes.
    """
    w, h = size[0], size[1]
    with open(path, 'rb') as fp:
        luma = fp.read(w * h)  # only the leading Y plane is needed
    if len(luma) < w * h:
        raise ValueError("not enough image data")
    return np.frombuffer(luma, dtype=np.uint8).reshape(h, w).astype('float32')


def c_getYdata(path):
    """Luma plane of ``path``, with the frame size parsed from its file name."""
    return getYdata(path, getWH(path))


def img2y(input_img):
    """Split an image into its Y plane and stacked CbCr planes.

    For a 3-channel PIL image, returns ``(Y, CbCr)``: ``Y`` is a float32
    ``(H, W)`` array and ``CbCr`` a float32 ``(H, W, 2)`` array. For a
    single-channel input (2-D, or 3-D with one channel) the input is returned
    unchanged together with ``None``.

    Raises:
        ValueError: for any other channel count.
    """
    shape = np.asarray(input_img).shape
    if len(shape) == 2:
        channels = 1
    elif len(shape) == 3:
        channels = shape[2]
    else:
        channels = None

    if channels == 3:
        ycbcr = input_img.convert('YCbCr')
        input_imgY, input_imgCb, input_imgCr = (np.asarray(band, dtype='float32') for band in ycbcr.split())
        # Cb and Cr are always used together, so keep them stacked as (H, W, 2).
        input_imgCbCr = np.stack((input_imgCb, input_imgCr), axis=2)
        return input_imgY, input_imgCbCr
    if channels == 1:
        print("This image has one channal only.")
        # A single-channel image is already luma; return it untouched.
        return input_img, None
    # Raise instead of exit(): a helper should not terminate the interpreter.
    raise ValueError("The num of channal is neither 3 nor 1.")


def crop(high1Data_image, lowData_image, high2Data_image, gt_image, patch_width, patch_height, img_type):
    """Cut one co-located random patch out of four frames.

    With ``img_type == "ndarray"``, a random top-left corner is drawn from
    ``high1Data_image``'s shape and the same window is sliced out of all four
    arrays. ``patch_width`` spans rows and ``patch_height`` spans columns. Any
    other ``img_type`` (including the unimplemented ``"Image"``) returns four
    empty lists.

    Raises:
        AssertionError: if the four inputs are not of the same type.
        ValueError: from ``random.randint`` if the frame is smaller than the patch.
    """
    assert type(high1Data_image) == type(gt_image) == type(lowData_image) == type(high2Data_image), "types are different."
    if img_type != "ndarray":
        # Cropping of PIL "Image" objects was never implemented here.
        return [], [], [], []

    top = random.randint(0, high1Data_image.shape[0] - patch_width)
    left = random.randint(0, high1Data_image.shape[1] - patch_height)
    window = (slice(top, top + patch_width), slice(left, left + patch_height))
    return tuple(image[window] for image in (high1Data_image, lowData_image, high2Data_image, gt_image))


def save_images(inputY, inputCbCr, size, image_path):
    """Tile a batch of Y/CbCr patches into one grid image and save it as RGB.

    Args:
        inputY: array of shape ``(batch, h, w, c_y)``.
        inputCbCr: array of shape ``(batch, h, w, c_cbcr)``. It is concatenated
            with ``inputY`` on the last axis, and the result must have 3
            channels (presumably ``c_y == 1`` and ``c_cbcr == 2``).
        size: ``[rows, cols]`` of the grid; ``batch`` must not exceed
            ``rows * cols``.
        image_path: destination file passed to PIL's ``Image.save``.

    Values are cast to uint8 without clipping.
    """
    def merge(images, size):
        h, w = images.shape[1], images.shape[2]
        img = np.zeros((h * size[0], w * size[1], 3))
        for idx, image in enumerate(images):
            row, col = divmod(idx, size[1])
            img[row*h:row*h+h, col*w:col*w+w, :] = image
        return img

    inputY = inputY.astype('uint8')
    inputCbCr = inputCbCr.astype('uint8')
    output_concat = np.concatenate((inputY, inputCbCr), axis=3)

    assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat))

    new_output = merge(output_concat, size).astype('uint8')

    img = Image.fromarray(new_output, mode='YCbCr')
    img = img.convert('RGB')
    img.save(image_path)


def _random_crop_pair(first_img, second_img, patch_rows, patch_cols):
    """Crop the same random ``patch_rows`` x ``patch_cols`` window from two PIL images."""
    width, height = first_img.size
    top = random.randint(0, height - patch_rows)
    left = random.randint(0, width - patch_cols)
    box = (left, top, left + patch_cols, top + patch_rows)
    return first_img.crop(box), second_img.crop(box)


def get_image_batch(train_list,offset,batch_size):
    """Load ``batch_size`` (input, ground-truth) image pairs starting at ``offset``.

    Each entry of ``train_list`` is an ``(input_path, gt_path)`` pair. Both
    images are cropped at the same random ``PATCH_SIZE`` window. Their Y planes
    are returned as arrays of shape ``(batch_size, PATCH_SIZE[0],
    PATCH_SIZE[1], 1)``, together with the list of the inputs' CbCr planes.
    """
    input_list = []
    gt_list = []
    inputcbcr_list = []
    for pair in train_list[offset:offset+batch_size]:
        # Close the files once the crops have been converted to arrays.
        with Image.open(pair[0]) as input_file, Image.open(pair[1]) as gt_file:
            # crop() only handles ndarray frames (its "Image" branch is a
            # no-op), so crop the PIL images to the desired size here.
            input_img, gt_img = _random_crop_pair(input_file, gt_file, PATCH_SIZE[0], PATCH_SIZE[1])

            # focus on the Y channel only
            input_imgY, input_imgCbCr = img2y(input_img)
            gt_imgY, _ = img2y(gt_img)
        input_list.append(input_imgY)
        gt_list.append(gt_imgY)
        inputcbcr_list.append(input_imgCbCr)

    input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
    gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))

    return input_list, gt_list, inputcbcr_list


def save_test_img(inputY, inputCbCr, path):
    """Write one Y plane plus its CbCr planes to ``path`` as an RGB image.

    ``inputY`` must have shape ``(1, H, W, c)``. After the batch axis is
    squeezed, it is concatenated with ``inputCbCr`` on axis 2, and the result
    must have 3 channels for the YCbCr conversion. Values are cast to uint8
    without clipping.
    """
    # Wrap the shape in a 1-tuple: "%s" % shape treats a multi-element tuple
    # as several format arguments and raises TypeError instead of asserting.
    assert len(inputY.shape) == 4, "the tensor Y's shape is %s" % (inputY.shape,)
    assert inputY.shape[0] == 1, "the fitst component must be 1, has not been completed otherwise.{}".format(inputY.shape)

    inputY = np.squeeze(inputY, axis=0).astype('uint8')
    inputCbCr = inputCbCr.astype('uint8')

    output_concat = np.concatenate((inputY, inputCbCr), axis=2)
    img = Image.fromarray(output_concat, mode='YCbCr')
    img = img.convert('RGB')
    img.save(path)


def psnr(hr_image, sr_image, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two images.

    NumPy arrays and lists are compared eagerly and a Python float is
    returned. Any other input is treated as a pair of rank-4 TensorFlow
    tensors, and a scalar tensor is returned. The MSE is floored at 1e-10, so
    identical inputs give a large finite value instead of infinity.
    """
    eps = 1e-10
    if isinstance(hr_image, (np.ndarray, list)):
        diff = np.asarray(sr_image, 'float32') - np.asarray(hr_image, 'float32')
        mse = np.maximum(eps, np.mean(diff * diff))
        return float(10 * math.log10(max_value * max_value / mse))

    assert len(hr_image.shape) == 4 and len(sr_image.shape) == 4
    diff = hr_image - sr_image
    mse = tf.maximum(tf.reduce_mean(tf.square(diff)), eps)
    return 10 * tf.log(max_value * max_value / mse) / math.log(10)


def getBeforeNNBlockDict(img, w, h):
    blockSize = 1000
    padding = 32
    yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1)
xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1) 354 | tempImg = {} 355 | i = 0 356 | for yBlock in range(yBlockNum): 357 | for xBlock in range(xBlockNum): 358 | if yBlock == 0: 359 | if xBlock == 0: 360 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding] 361 | elif xBlock == xBlockNum - 1: 362 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w] 363 | else: 364 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 365 | elif yBlock == yBlockNum - 1: 366 | if xBlock == 0: 367 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding] 368 | elif xBlock == xBlockNum - 1: 369 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w] 370 | else: 371 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 372 | elif xBlock == 0: 373 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding] 374 | elif xBlock == xBlockNum - 1: 375 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w] 376 | else: 377 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 378 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 379 | i += i 380 | l = tempImg[i].astype('uint8') 381 | l = Image.fromarray(l) 382 | l.show() 383 | -------------------------------------------------------------------------------- /test/test_Single/__pycache__/MFVDSingle_ra.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/__pycache__/MFVDSingle_ra.cpython-35.pyc -------------------------------------------------------------------------------- /test/test_Single/__pycache__/UTILS_MF_ra_ALL_Single.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/__pycache__/UTILS_MF_ra_ALL_Single.cpython-35.pyc -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.index -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.meta -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "SF_qp22_ra_01251_VDSRSingle_105.ckpt" 2 | all_model_checkpoint_paths: "SF_qp22_ra_01251_VDSRSingle_105.ckpt" 3 | 
-------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.index -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.meta -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "SF_qp27_ra_01232_VDSRSingle_169.ckpt" 2 | all_model_checkpoint_paths: "SF_qp27_ra_01232_VDSRSingle_169.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.index -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.meta -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "SF_qp32_ra_01242_VDSRSingle_185.ckpt" 2 | all_model_checkpoint_paths: "SF_qp32_ra_01242_VDSRSingle_185.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- 
/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.index -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.meta -------------------------------------------------------------------------------- /test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "MF_qp37_ra_01174_VDSRSingle_291.ckpt" 2 | all_model_checkpoint_paths: "MF_qp37_ra_01174_VDSRSingle_291.ckpt" 3 | -------------------------------------------------------------------------------- /test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718071.AIR-PC: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718071.AIR-PC -------------------------------------------------------------------------------- /test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718367.AIR-PC: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718367.AIR-PC -------------------------------------------------------------------------------- /train/train_MF/LMVE_ra_model.py: --------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def _conv_layer(x, name, shape, stddev, relu=True, collect=False):
    """Stride-1 'SAME' convolution plus bias, optionally followed by ReLU.

    Creates ``<name>_w`` (shape ``shape``, N(0, stddev) init) and ``<name>_b``
    (length ``shape[-1]``, zero init) through ``tf.get_variable``. The names
    match existing checkpoints.

    Args:
        x: input tensor in the default NHWC layout of ``tf.nn.conv2d``.
        name: variable-name prefix.
        shape: kernel shape ``[kh, kw, in_channels, out_channels]``.
        stddev: standard deviation of the weight initializer.
        relu: apply ReLU to the output when True.
        collect: add the kernel to ``tf.GraphKeys.WEIGHTS`` when True (the
            training script adds an L2 penalty on that collection).
    """
    w = tf.get_variable(name + "_w", shape,
                        initializer=tf.random_normal_initializer(stddev=stddev))
    if collect:
        tf.add_to_collection(tf.GraphKeys.WEIGHTS, w)
    b = tf.get_variable(name + "_b", [shape[-1]], initializer=tf.constant_initializer(0))
    out = tf.nn.bias_add(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME'), b)
    return tf.nn.relu(out) if relu else out


def model_double(inputHigh1Data_tensor, inputLowData_tensor, inputHigh2Data_tensor):
    """Enhance a low-quality frame using the two high-quality frames around it.

    Pipeline:
    1. Each input goes through its own 5x5 conv (64 filters).
    2. The three feature maps are concatenated and fused by a 1x1 conv down
       to 64 channels.
    3. Eighteen 3x3 conv+ReLU layers refine the features.
    4. A final 3x3 conv projects to one channel.
    5. The result is added to ``inputLowData_tensor`` as a residual.

    Inputs are single-channel tensors of equal batch/height/width.

    Variable names (``input_before_w``, ``conv_00_w`` ...) and their creation
    order are unchanged, so existing checkpoints still load.

    Returns:
        Tensor shaped like ``inputLowData_tensor``.
    """
    first_std = np.sqrt(2.0 / 25)
    # Per-frame feature extraction: frame -1, frame 0 (enhanced), frame +1.
    high1_feat = _conv_layer(inputHigh1Data_tensor, "input_before", [5, 5, 1, 64], first_std)
    low_feat = _conv_layer(inputLowData_tensor, "input_current", [5, 5, 1, 64], first_std)
    high2_feat = _conv_layer(inputHigh2Data_tensor, "input_after", [5, 5, 1, 64], first_std)

    fused = tf.concat([high1_feat, low_feat, high2_feat], axis=3)
    # 1x1 conv 192 -> 64 channels to keep the following layers small.
    tensor = _conv_layer(fused, "input_1x1", [1, 1, 192, 64], np.sqrt(2.0 / 192))

    body_std = np.sqrt(2.0 / 9 / 64)
    for conv_id in range(18):
        tensor = _conv_layer(tensor, "conv_%02d" % conv_id, [3, 3, 64, 64], body_std, collect=True)
    # Final linear projection to one channel (conv_18), no ReLU.
    tensor = _conv_layer(tensor, "conv_%02d" % 18, [3, 3, 64, 1], body_std, relu=False, collect=True)

    return tf.add(tensor, inputLowData_tensor)
# 66 | -------------------------------------------------------------------------------- /train/train_MF/LMVE_ra_train.py: --------------------------------------------------------------------------------
import argparse
import time
from LMVE_ra_model import model_double
from UTILS_MF_ra import *

tf.logging.set_verbosity(tf.logging.WARN)
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

EXP_DATA = 'MF_qp37_ra_02052_yangNet_double'  # checkpoints path
LOW_DATA_PATH = r"E:\MF\trainSet\qp37\low_data"  # low frames
HIGH1_DATA_PATH = r"E:\MF\trainSet\qp37\high1_wraped_Y"  # high frames1
HIGH2_DATA_PATH = r"E:\MF\trainSet\qp37\high2_wraped_Y"  # high frames2
LABEL_PATH = r"E:\MF\trainSet\qp37\label_s"  # label frames
LOG_PATH = "./logs/%s/" % (EXP_DATA)
CKPT_PATH = "./checkpoints/%s/" % (EXP_DATA)
SAMPLE_PATH = "./samples/%s/" % (EXP_DATA)
PATCH_SIZE = (64, 64)
BATCH_SIZE = 64
BASE_LR = 3e-4
LR_DECAY_RATE = 0.2
LR_DECAY_STEP = 20  # epochs between learning-rate decays
STEPS_PER_EPOCH = 1000  # training iterations per epoch
MAX_EPOCH = 2000


if __name__ == '__main__':
    # Parse arguments only when run as a script. UTILS_MF_ra imports this
    # module for BATCH_SIZE/PATCH_SIZE, and module-level parsing would re-run then.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path")
    model_path = parser.parse_args().model_path

    # Entries look like [[high1Data, lowData, high2Data], label], with full paths.
    train_list = get_train_list(load_file_list(HIGH1_DATA_PATH), load_file_list(LOW_DATA_PATH),
                                load_file_list(HIGH2_DATA_PATH), load_file_list(LABEL_PATH))

    batch_shape = (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)
    with tf.name_scope('input_scope'):
        train_hight1Data = tf.placeholder('float32', shape=batch_shape)
        train_lowData = tf.placeholder('float32', shape=batch_shape)
        train_hight2Data = tf.placeholder('float32', shape=batch_shape)
        train_gt = tf.placeholder('float32', shape=batch_shape)

    shared_model = tf.make_template('shared_model', model_double)
    train_output = shared_model(train_hight1Data, train_lowData, train_hight2Data)
    train_output = tf.clip_by_value(train_output, 0., 1.)
    with tf.name_scope('loss_scope'):
        # Sum of squared errors plus a 1e-4 L2 penalty on every kernel the
        # model registered in tf.GraphKeys.WEIGHTS.
        loss2 = tf.reduce_sum(tf.square(tf.subtract(train_output, train_gt)))
        for w in tf.get_collection(tf.GraphKeys.WEIGHTS):
            loss2 += tf.nn.l2_loss(w) * 1e-4

    avg_loss = tf.placeholder('float32')
    tf.summary.scalar("avg_loss", avg_loss)

    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(BASE_LR, global_step, LR_DECAY_STEP * STEPS_PER_EPOCH,
                                               LR_DECAY_RATE, staircase=True)
    tf.summary.scalar("learning rate", learning_rate)

    optimizer = tf.train.AdamOptimizer(learning_rate, 0.9)
    opt = optimizer.minimize(loss2, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=0)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        for directory in (LOG_PATH, os.path.dirname(CKPT_PATH), SAMPLE_PATH):
            os.makedirs(directory, exist_ok=True)

        merged = tf.summary.merge_all()
        file_writer = tf.summary.FileWriter(LOG_PATH, sess.graph)

        sess.run(tf.global_variables_initializer())

        if model_path:
            print("restore model...")
            saver.restore(sess, model_path)
            print("Done")

        for epoch in range(MAX_EPOCH):
            total_g_loss, n_iter = 0, 0
            epoch_time = time.time()

            for step in range(STEPS_PER_EPOCH):
                input_high1Data, input_lowData, input_high2Data, gt_data = prepare_nn_data(train_list)
                feed_dict = {train_hight1Data: input_high1Data, train_lowData: input_lowData,
                             train_hight2Data: input_high2Data, train_gt: gt_data}
                _, step_loss = sess.run([opt, loss2], feed_dict=feed_dict)
                total_g_loss += step_loss
                n_iter += 1

            mean_loss = total_g_loss / n_iter
            lr, summary = sess.run([learning_rate, merged], {avg_loss: mean_loss})
            file_writer.add_summary(summary, epoch)
            message = "Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f" % (
                epoch, MAX_EPOCH, time.time() - epoch_time, mean_loss, lr)
            tf.logging.warning(message)
            print(message)
            saver.save(sess, os.path.join(CKPT_PATH, "%s_%03d.ckpt" % (EXP_DATA, epoch)))
# 101 | -------------------------------------------------------------------------------- /train/train_MF/UTILS_MF_ra.py: --------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import math, os, random, re
from PIL import Image
# NOTE(review): circular import. LMVE_ra_train does `from UTILS_MF_ra import *`
# and uses `tf` at import time; this resolves only because numpy/tensorflow
# are bound above these two lines.
from LMVE_ra_train import BATCH_SIZE
from LMVE_ra_train import PATCH_SIZE

# due to a batch
# trainingSet come from one picture. I design a algorithm to make the TrainingSet more diversity.
def normalize(x):
    """Scale 8-bit sample values to [0, 1], clipping anything outside."""
    return truncate(x / 255., 0., 1.)


def denormalize(x):
    """Map [0, 1] values back to [0, 255], clipping anything outside."""
    return truncate(x * 255., 0., 255.)


def truncate(input, min, max):
    """Clip ``input`` element-wise to ``[min, max]``.

    Built from two ``np.where`` calls, so NaN entries become ``min``.
    """
    input = np.where(input > min, input, min)
    input = np.where(input < max, input, max)
    return input


def remap(input):
    """Linearly map [0, 255] onto the limited range [16, 235]."""
    return truncate(16 + 219 / 255 * input, 16.0, 235.0)


def deremap(input):
    """Inverse of ``remap``: map [16, 235] back onto [0, 255]."""
    return truncate((input - 16) * 255 / 219, 0.0, 255.0)


def _natural_key(text):
    """Sort key that compares digit runs numerically ('f_2' before 'f_10')."""
    # re.split with a capturing group alternates text / digit pieces, so
    # every odd position holds a run of digits.
    return [int(piece) if pos % 2 else piece
            for pos, piece in enumerate(re.split(r'(\d+)', text))]


def _frame_index(path):
    """Frame number stored in the last '_' field of the file name ('s_416x240_7.yuv' -> 7)."""
    return int(os.path.basename(path).split('.')[0].split('_')[-1])


def load_file_list(directory):
    """Return the full paths of the regular files directly inside ``directory``.

    Names are sorted with embedded numbers compared numerically.
    ``os.listdir`` order is platform dependent, and ``get_train_list`` pairs
    lists from parallel folders by position, so a stable order is required.
    """
    names = sorted(os.listdir(directory), key=_natural_key)
    paths = [os.path.join(directory, name) for name in names]
    return [path for path in paths if os.path.isfile(path)]


def searchHighData(currentLowDataIndex, highDataList, highIndexList):
    """Find the two high-quality frames around a low-quality frame.

    Frame indices in ``highIndexList`` within 3 frames of
    ``currentLowDataIndex`` are selected (exactly two must exist). Paths in
    ``highDataList`` carrying those indices are returned, ordered by frame
    index (preceding frame first).
    """
    searchOffset = 3
    window = range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1)
    searchedHighDataIndexList = [i for i in window if i in highIndexList]
    assert len(searchedHighDataIndexList) == 2, 'search method have error!'
    # Membership test; the old chained `idx == a == b` could never be true.
    matches = [path for path in highDataList if _frame_index(path) in searchedHighDataIndexList]
    return sorted(matches, key=_frame_index)


def get_test_list2(highDataList, lowDataList, labelList):
    """Pair each low-quality frame with its two neighbouring high-quality frames.

    High frames are taken to be every 4th frame in 0..48.

    Returns:
        A list of ``[[high_before, low, high_after], label]`` entries. The label
        is the ``labelList`` item at the same position as ``low``.
    """
    assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList))
    # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]
    highIndexList = [q for q in range(49) if q % 4 == 0]
    test_list = []
    for position, lowPath in enumerate(lowDataList):
        found = searchHighData(_frame_index(lowPath), highDataList, highIndexList)
        test_list.append([[found[0], lowPath, found[1]], labelList[position]])
    return test_list


def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList):
    """Zip parallel lists into ``[[high1, low, high2], label]`` entries, one per low frame."""
    # Each entry is a freshly built list; the old version shared one list
    # across all iterations.
    return [[[high1DataList[i], lowDataList[i], high2DataList[i]], labelList[i]]
            for i in range(len(lowDataList))]


def get_test_list(HIGHDATA_Parent_PATH, lowDataList, labelList):
    """Split test frames into multi-frame and single-frame work lists.

    * Frames whose index is not a multiple of 4, and that have a
      sub-directory named after their index under ``HIGHDATA_Parent_PATH``,
      go to the double list as ``[[high paths], low, label]``.
    * Frames whose index is a multiple of 4 go to the single list as
      ``[low, label]``, provided the label's frame index matches.

    Labels are the ``labelList`` items at the same position as the frame.

    Returns:
        ``(doubleTest_list, singleTest_list)``.
    """
    doubleTest_list = []
    singleTest_list = []
    # Sub-directory names are frame indices.
    HighDirList = set(map(int, os.listdir(HIGHDATA_Parent_PATH)))

    for position, lowdata in enumerate(lowDataList):
        lowdataIndex = _frame_index(lowdata)
        if lowdataIndex % 4 != 0:
            if lowdataIndex in HighDirList:
                High_current_path = HIGHDATA_Parent_PATH + '/' + str(lowdataIndex)
                # Sort so the lower-numbered (preceding) frame comes first on every platform.
                Two_HighData = [os.path.join(High_current_path, T)
                                for T in sorted(os.listdir(High_current_path), key=_natural_key)]
                doubleTest_list.append([Two_HighData, lowdata, labelList[position]])
        else:
            label = labelList[position]
            if lowdataIndex == _frame_index(label):
                singleTest_list.append([lowdata, label])

    return doubleTest_list, singleTest_list


def get_train_list(high1DataList, lowDataList, high2DataList, labelList):
    """Combine four parallel file lists into training entries.

    At every position, high1 must be frame ``n - 2``, low frame ``n`` and
    high2 frame ``n + 2``. The frame number is the last '_' field of the file
    name.

    Returns:
        ``[[high1, low, high2], label]`` entries.

    Raises:
        AssertionError: if the lists differ in length.
        ValueError: if the frame numbers at some position do not line up.
    """
    assert len(lowDataList) == len(high1DataList) == len(labelList) == len(high2DataList), \
        "low:%d, high1:%d, label:%d, high2:%d"%(len(lowDataList), len(high1DataList), len(labelList), len(high2DataList))

    def index_of(path):
        return int(os.path.basename(path).split('_')[-1].split('.')[0])

    train_list = []
    for position, (high1, low, high2, label) in enumerate(
            zip(high1DataList, lowDataList, high2DataList, labelList)):
        if not index_of(high1) + 4 == index_of(low) + 2 == index_of(high2):
            raise ValueError('frame indices do not line up at position %d: %s, %s, %s'
                             % (position, high1, low, high2))
        train_list.append([[high1, low, high2], label])
    return train_list


def prepare_nn_data(train_list):
    """Sample one training batch of random, co-located patches.

    Eight distinct entries of ``train_list`` are chosen, and 8 random
    ``PATCH_SIZE`` patches are cut from each, normalised to [0, 1].

    Returns:
        ``(high1, low, high2, gt)`` arrays of shape
        ``(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)``.
    """
    num_images = 8
    patches_per_image = 8
    buckets = ([], [], [], [])  # high1, low, high2, gt
    for i in random.sample(range(0, len(train_list)), num_images):
        (high1_path, low_path, high2_path), gt_path = train_list[i]
        frames = [c_getYdata(p) for p in (high1_path, low_path, high2_path, gt_path)]
        for _ in range(patches_per_image):
            patches = crop(*frames, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray")
            for bucket, patch in zip(buckets, patches):
                bucket.append(normalize(patch))

    shape = (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)
    # np.resize (not reshape) repeats/truncates if the patch count differs from BATCH_SIZE.
    return tuple(np.resize(bucket, shape) for bucket in buckets)


def getWH(yuvfileName):
    """Parse ``(width, height)`` from a file name such as ``Seq_416x240_7.yuv``.

    The ``WxH`` token is taken from the second-to-last ``_`` field of the name
    (extension removed), or else from the second field.

    Raises:
        Exception: if no field containing ``'x'`` is found.
    """
    # Take the basename first so dots in directory names ('../', 'data.v2')
    # cannot truncate the name.
    fields = os.path.basename(yuvfileName).split('.')[0].split('_')
    if len(fields) >= 2:
        for pos in (-2, 1):
            if 'x' in fields[pos]:
                w, h = fields[pos].split('x')
                return int(w), int(h)
    raise Exception('do not find wxh')


def getYdata(path, size):
    """Read the luma (Y) plane of a raw planar YUV file.

    Only the first ``width * height`` bytes (the Y samples) are read.

    Args:
        path: path of the raw ``.yuv`` file.
        size: ``(width, height)`` of the frame.

    Returns:
        float32 ndarray of shape ``(height, width)``.
    """
    w, h = size[0], size[1]
    with open(path, 'rb') as fp:
        y_plane = fp.read(w * h)
    return np.asarray(Image.frombytes('L', [w, h], y_plane), dtype='float32')


def c_getYdata(path):
    """Read the Y plane of ``path``, taking the frame size from its file name."""
    return getYdata(path, getWH(path))


def c_getCbCr(path):
    """Read the two chroma planes that follow the Y plane in a raw YUV file.

    The frame size comes from the file name. Each plane is read as up to
    ``w * h // 4`` bytes.

    Returns:
        ``[u, v]`` as lists of ints, or ``''`` if the file is empty.
    """
    w, h = getWH(path)
    # Read-only access is enough (the former 'rb+' failed on read-only files).
    with open(path, 'rb') as file:
        y = file.read(h * w)
        if y == b'':
            return ''
        u = file.read(h * w // 4)
        v = file.read(h * w // 4)
    # Iterating bytes yields ints, so list() gives the int sample values.
    return [list(u), list(v)]


def img2y(input_img):
    """Split an image into its Y channel and stacked Cb/Cr channels.

    Args:
        input_img: a 3-channel PIL image (converted to YCbCr internally), or a
            single-channel image / 2-D array, which is returned unchanged.

    Returns:
        ``(Y, CbCr)``. For 3-channel input, Y is float32 ``(H, W)`` and CbCr
        float32 ``(H, W, 2)``. For single-channel input, Y is the input and
        CbCr is ``None``.

    Raises:
        ValueError: if the image has neither 1 nor 3 channels.
    """
    arr = np.asarray(input_img)
    # Grayscale images become 2-D arrays, so take the channel count from ndim
    # before indexing shape[2].
    if arr.ndim == 2:
        channels = 1
    elif arr.ndim == 3:
        channels = arr.shape[2]
    else:
        channels = None

    if channels == 3:
        y_chan, cb_chan, cr_chan = input_img.convert('YCbCr').split()
        input_imgY = np.asarray(y_chan, dtype='float32')
        # Cb and Cr are always used as a pair, so stack them on a new last axis.
        input_imgCbCr = np.stack((np.asarray(cb_chan, dtype='float32'),
                                  np.asarray(cr_chan, dtype='float32')), axis=2)
    elif channels == 1:
        print("This image has one channal only.")
        input_imgY = input_img
        input_imgCbCr = None
    else:
        raise ValueError("The num of channal is neither 3 nor 1.")
    return input_imgY, input_imgCbCr


def crop(high1Data_image, lowData_image, high2Data_image, gt_image, patch_width, patch_height, img_type):
    """Cut the same random patch out of four aligned frames.

    Args:
        high1Data_image, lowData_image, high2Data_image, gt_image: arrays of
            the same type, cropped on their first two axes.
        patch_width: patch extent along axis 0 (rows).
        patch_height: patch extent along axis 1 (columns).
        img_type: ``"ndarray"`` to crop. Any other value (``"Image"`` was
            never implemented) returns four empty lists.

    Returns:
        Tuple of the four patches, in argument order.
    """
    assert type(high1Data_image) == type(gt_image) == type(lowData_image) == type(high2Data_image), "types are different."
    if img_type != "ndarray":
        return [], [], [], []
    # Offsets are drawn from the first frame's size; all frames are assumed to share it.
    row = random.randint(0, high1Data_image.shape[0] - patch_width)
    col = random.randint(0, high1Data_image.shape[1] - patch_height)
    window = (slice(row, row + patch_width), slice(col, col + patch_height))
    return high1Data_image[window], lowData_image[window], high2Data_image[window], gt_image[window]


def save_images(inputY, inputCbCr, size, image_path):
    """Tile a batch of YCbCr images into one grid image and save it as RGB.

    Args:
        inputY: batch of Y planes, ``(batch, h, w, cy)``.
        inputCbCr: batch of chroma planes, ``(batch, h, w, cc)``. Concatenated
            with ``inputY`` on the last axis, the result must have 3 channels.
        size: ``[rows, cols]`` of the grid; requires ``batch <= rows * cols``.
        image_path: output image path.
    """
    def merge(images, grid):
        h, w = images.shape[1], images.shape[2]
        canvas = np.zeros((h * grid[0], w * grid[1], 3))
        for idx, image in enumerate(images):
            row, col = divmod(idx, grid[1])
            canvas[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
        return canvas

    output_concat = np.concatenate((inputY.astype('uint8'), inputCbCr.astype('uint8')), axis=3)
    assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat))
    grid_image = merge(output_concat, size).astype('uint8')
    Image.fromarray(grid_image, mode='YCbCr').convert('RGB').save(image_path)


def get_image_batch(train_list, offset, batch_size):
    """Legacy loader for ``[input_path, gt_path]`` pairs.

    NOTE(review): ``crop`` takes seven arguments but is called below with five,
    so this raises TypeError on the first pair; it appears unused by
    LMVE_ra_train. Left unchanged.
    """
    target_list = train_list[offset:offset + batch_size]
    input_list = []
    gt_list = []
    inputcbcr_list = []
    for pair in target_list:
        input_img = Image.open(pair[0])
        gt_img = Image.open(pair[1])

        input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image")

        # Only the Y channel is fed to the network.
        input_imgY, input_imgCbCr = img2y(input_img)
        gt_imgY, gt_imgCbCr = img2y(gt_img)
        input_list.append(input_imgY)
        gt_list.append(gt_imgY)
        inputcbcr_list.append(input_imgCbCr)

    input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
    gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
    return input_list, gt_list, inputcbcr_list


def save_test_img(inputY, inputCbCr, path):
    """Save one frame (Y plus chroma) as an RGB image.

    Args:
        inputY: ``(1, h, w, c)`` array; the leading batch axis is dropped.
        inputCbCr: chroma array concatenated with Y on the last axis
            (presumably ``(h, w, 2)`` so the result is YCbCr — TODO confirm).
        path: output image path.
    """
    # Wrap the shape in a 1-tuple: "%s" % shape would treat each dimension as
    # a separate format argument and raise TypeError instead of the assert.
    assert len(inputY.shape) == 4, "the tensor Y's shape is %s" % (inputY.shape,)
    assert inputY.shape[0] == 1, "the fitst component must be 1, has not been completed otherwise.{}".format(inputY.shape)

    y_plane = np.squeeze(inputY, axis=0).astype('uint8')
    output_concat = np.concatenate((y_plane, inputCbCr.astype('uint8')), axis=2)
    Image.fromarray(output_concat, mode='YCbCr').convert('RGB').save(path)


def psnr(hr_image, sr_image, max_value=255.0):
    """Peak signal-to-noise ratio (dB) between a reference and a test image.

    numpy arrays and lists are evaluated eagerly and give a Python float. Any
    other input is treated as a 4-D TensorFlow tensor, and a tensor is returned.
    The MSE is floored at 1e-10 so identical images give a finite value.
    """
    eps = 1e-10
    if isinstance(hr_image, (np.ndarray, list)):
        diff = np.asarray(sr_image, 'float32') - np.asarray(hr_image, 'float32')
        mse = np.maximum(eps, np.mean(diff * diff))
        return float(10 * math.log10(max_value * max_value / mse))
    assert len(hr_image.shape) == 4 and len(sr_image.shape) == 4
    diff = hr_image - sr_image
    mse = tf.maximum(tf.reduce_mean(tf.square(diff)), eps)
    return 10 * tf.log(max_value * max_value / mse) / math.log(10)

# 366 | def getBeforeNNBlockDict(img, w, h): 367 | # print(img[:1500, : 2000]) 368 |
blockSize = 1000 369 | padding = 32 370 | yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1) 371 | xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1) 372 | tempImg = {} 373 | i = 0 374 | for yBlock in range(yBlockNum): 375 | for xBlock in range(xBlockNum): 376 | if yBlock == 0: 377 | if xBlock == 0: 378 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding] 379 | elif xBlock == xBlockNum - 1: 380 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w] 381 | else: 382 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 383 | elif yBlock == yBlockNum - 1: 384 | if xBlock == 0: 385 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding] 386 | elif xBlock == xBlockNum - 1: 387 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w] 388 | else: 389 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 390 | elif xBlock == 0: 391 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding] 392 | elif xBlock == xBlockNum - 1: 393 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w] 394 | else: 395 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 396 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 397 | i += i 398 | l = tempImg[i].astype('uint8') 399 | l = Image.fromarray(l) 400 | l.show() 401 | -------------------------------------------------------------------------------- /train/train_Single/LSVE_Single_ra_train.py: -------------------------------------------------------------------------------- 1 | import argparse, time 2 | from LSVE_ra_model import model_single 3 | from UTILS_single_ra import * 4 | 5 | tf.logging.set_verbosity(tf.logging.WARN) 6 | 7 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1" 8 | EXP_DATA = 'SF_qp22_ra_01243_VDSRSingle' # 
checkpoints path 9 | HIGH_DATA_PATH = r"E:\MF\trainSet\qp22\single\high_data" # high frames 10 | LABEL_PATH = r"E:\MF\trainSet\qp22\single\label_s" # lable frames 11 | LOG_PATH = "./logs/%s/"%(EXP_DATA) 12 | CKPT_PATH = "./checkpoints/%s/"%(EXP_DATA) 13 | SAMPLE_PATH = "./samples/%s/"%(EXP_DATA) 14 | PATCH_SIZE = (64, 64) 15 | BATCH_SIZE = 64 16 | BASE_LR = 3e-4 17 | LR_DECAY_RATE = 0.5 18 | LR_DECAY_STEP = 20 19 | MAX_EPOCH = 500 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--model_path") 24 | args = parser.parse_args() 25 | model_path = args.model_path 26 | 27 | if __name__ == '__main__': 28 | 29 | # train_list return like this"[[hdata1, label1],[hdata2, label2]]" with the whole path. 30 | train_list = get_train_list(load_file_list(HIGH_DATA_PATH), load_file_list(LABEL_PATH)) 31 | with tf.name_scope('input_scope'): 32 | train_hightData = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 33 | train_gt = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 34 | 35 | shared_model = tf.make_template('shared_model', model_single) 36 | train_output = shared_model(train_hightData) 37 | train_output = tf.clip_by_value(train_output, 0., 1.) 
38 | with tf.name_scope('loss_scope'): 39 | loss2 = tf.reduce_mean(tf.square(tf.subtract(train_output, train_gt))) 40 | loss1 = tf.reduce_sum(tf.abs(tf.subtract(train_output, train_gt))) 41 | W = tf.get_collection(tf.GraphKeys.WEIGHTS) 42 | for w in W: 43 | loss2 += tf.nn.l2_loss(w)*1e-4 44 | 45 | avg_loss = tf.placeholder('float32') 46 | tf.summary.scalar("avg_loss", avg_loss) 47 | 48 | global_step = tf.Variable(0, trainable=False) # len(train_list) 49 | learning_rate = tf.train.exponential_decay(BASE_LR, global_step, LR_DECAY_STEP*1111, LR_DECAY_RATE, staircase=True) 50 | tf.summary.scalar("learning rate", learning_rate) 51 | 52 | # org --------------------------------------------------------------------------------- 53 | optimizer = tf.train.AdamOptimizer(learning_rate, 0.9) 54 | opt = optimizer.minimize(loss2, global_step=global_step) 55 | saver = tf.train.Saver(max_to_keep=0) 56 | # org end------------------------------------------------------------------------------ 57 | config = tf.ConfigProto() 58 | config.gpu_options.allow_growth = True 59 | 60 | with tf.Session(config=config) as sess: 61 | if not os.path.exists(LOG_PATH): 62 | os.makedirs(LOG_PATH) 63 | if not os.path.exists(os.path.dirname(CKPT_PATH)): 64 | os.makedirs(os.path.dirname(CKPT_PATH)) 65 | if not os.path.exists(SAMPLE_PATH): 66 | os.makedirs(SAMPLE_PATH) 67 | 68 | merged = tf.summary.merge_all() 69 | file_writer = tf.summary.FileWriter(LOG_PATH, sess.graph) 70 | 71 | sess.run(tf.global_variables_initializer()) 72 | 73 | if model_path: 74 | print("restore model...") 75 | saver.restore(sess, model_path) 76 | print("Done") 77 | 78 | #for epoch in range(400, MAX_EPOCH): 79 | for epoch in range(MAX_EPOCH): 80 | total_g_loss, n_iter = 0, 0 81 | idxOfImgs = np.random.permutation(len(train_list)) 82 | 83 | epoch_time = time.time() 84 | 85 | # for idx in range(len(idxOfImgs)): 86 | for idx in range(1111): 87 | input_highData, gt_data = prepare_nn_data(train_list) 88 | feed_dict = {train_hightData: 
input_highData, train_gt: gt_data} 89 | 90 | _, l, output, g_step = sess.run([opt, loss2, train_output, global_step], feed_dict=feed_dict) 91 | total_g_loss += l 92 | n_iter += 1 93 | 94 | del input_highData, gt_data, output 95 | lr, summary = sess.run([learning_rate, merged], {avg_loss:total_g_loss/n_iter}) 96 | file_writer.add_summary(summary, epoch) 97 | tf.logging.warning("Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f"%(epoch, MAX_EPOCH, time.time()-epoch_time, total_g_loss/n_iter, lr)) 98 | print("Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f"%(epoch, MAX_EPOCH, time.time()-epoch_time, total_g_loss/n_iter, lr)) 99 | saver.save(sess, os.path.join(CKPT_PATH, "%s_%03d.ckpt"%(EXP_DATA, epoch))) 100 | -------------------------------------------------------------------------------- /train/train_Single/LSVE_ra_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def model_single(inputLowData_tensor): 6 | # with tf.device("/gpu:0"): 7 | input_current = inputLowData_tensor # lowData 8 | 9 | # due to don't have training_Set at right now, so let it be annotation. 
10 | tensor = None 11 | 12 | # ----------------------------------------------Frame 0-------------------------------------------------------- 13 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64], 14 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25))) 15 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0)) 16 | input_low_tensor = tf.nn.relu( 17 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'), 18 | input_current_b)) 19 | 20 | # ----------------------------------1x1 conv, for reduce number of parameters---------------- 21 | input_3x3_w = tf.get_variable("input_1x1_w", [3, 3, 64, 64], 22 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 64 / 9))) 23 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0)) 24 | input_3x3_tensor = tf.nn.relu( 25 | tf.nn.bias_add(tf.nn.conv2d(input_low_tensor, input_3x3_w, strides=[1, 1, 1, 1], padding='SAME'), 26 | input_1x1_b)) 27 | tensor = input_3x3_tensor 28 | 29 | # --------------------------------------start iteration for last layers---------------------------------- 30 | convId = 0 31 | for i in range(18): 32 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64], 33 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 34 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 35 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0)) 36 | convId += 1 37 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)) 38 | 39 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1], 40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64))) 41 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0)) 42 | convId += 1 43 | 
tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w) 44 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b) 45 | tensor = tf.add(tensor, input_current) 46 | return tensor 47 | -------------------------------------------------------------------------------- /train/train_Single/UTILS_single_ra.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math, os, random, re 4 | from PIL import Image 5 | from LSVE_Single_ra_train import BATCH_SIZE 6 | from LSVE_Single_ra_train import PATCH_SIZE 7 | 8 | # due to a batch trainingSet come from one picture. I design a algorithm to make the TrainingSet more diversity. 9 | 10 | def normalize(x): 11 | x = x / 255. 12 | return truncate(x, 0., 1.) 13 | 14 | def denormalize(x): 15 | x = x * 255. 16 | return truncate(x, 0., 255.) 17 | 18 | def truncate(input, min, max): 19 | input = np.where(input > min, input, min) 20 | input = np.where(input < max, input, max) 21 | return input 22 | 23 | def remap(input): 24 | input = 16+219/255*input 25 | return truncate(input, 16.0, 235.0) 26 | 27 | def deremap(input): 28 | input = (input-16)*255/219 29 | return truncate(input, 0.0, 255.0) 30 | 31 | # return the whole absolute path. 32 | def load_file_list(directory): 33 | list = [] 34 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]: 35 | list.append(os.path.join(directory,filename)) 36 | return list 37 | 38 | def searchHighData(currentLowDataIndex, highDataList, highIndexList): 39 | searchOffset = 3 40 | searchedHighDataIndexList = [] 41 | searchedHighData = [] 42 | for i in range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1): 43 | if i in highIndexList: 44 | searchedHighDataIndexList.append(i) 45 | assert len(searchedHighDataIndexList) == 2, 'search method have error!' 
46 | for tempData in highDataList: 47 | if int(os.path.basename(tempData).split('.')[0].split('_')[-1]) \ 48 | == searchedHighDataIndexList[0] == searchedHighDataIndexList[1]: 49 | 50 | searchedHighData.append(tempData) 51 | return searchedHighData 52 | 53 | 54 | # return like this"[[[high1Data, lowData], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole path. 55 | def get_test_list2(highDataList, lowDataList, labelList): 56 | assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList)) 57 | 58 | # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48] 59 | highIndexList = [q for q in range(49) if q % 4 == 0] 60 | test_list = [] 61 | for tempDataPath in lowDataList: 62 | tempData = [] 63 | temp = [] 64 | # this place should changed on the different situation. 65 | currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1]) 66 | searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList) 67 | tempData.append(searchedHighData[0]) 68 | tempData.append(tempDataPath) 69 | tempData.append(searchedHighData[1]) 70 | 71 | i = list(lowDataList).index(tempDataPath) 72 | 73 | temp.append(tempData) 74 | temp.append(labelList[i]) 75 | 76 | test_list.append(temp) 77 | return test_list 78 | 79 | def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList): 80 | tempData = [] 81 | temp = [] 82 | test_list = [] 83 | for i in range(len(lowDataList)): 84 | tempData.append(high1DataList[i]) 85 | tempData.append(lowDataList[i]) 86 | tempData.append(high2DataList[i]) 87 | temp.append(tempData) 88 | temp.append(labelList[i]) 89 | 90 | test_list.append(temp) 91 | return test_list 92 | 93 | # [[high, low1, label1], [[h21,h22], low2, label2]] 94 | def get_test_list(HIGHDATA_Parent_PATH, lowDataList, labelList): 95 | doubleTest_list = [] 96 | singleTest_list = [] 97 | HighDirList = os.listdir(HIGHDATA_Parent_PATH) 98 | # convert string-list to int-list. 
99 | HighDirList = list(map(int, HighDirList)) 100 | 101 | for lowdata in lowDataList: 102 | 103 | tempData = [] 104 | lowdataIndex = int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) 105 | if lowdataIndex % 4 != 0: 106 | if lowdataIndex in HighDirList: 107 | High_current_path = HIGHDATA_Parent_PATH + '/' + str(lowdataIndex) 108 | TWO_HighData = os.listdir(High_current_path) 109 | Two_HighData = [os.path.join(High_current_path, T) for T in TWO_HighData] 110 | tempData.append(Two_HighData) 111 | tempData.append(lowdata) 112 | labelIndex = list(lowDataList).index(lowdata) 113 | tempData.append(labelList[labelIndex]) 114 | doubleTest_list.append(tempData) 115 | else: 116 | tempData.append(lowdata) 117 | labelIndex = list(lowDataList).index(lowdata) 118 | if int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) == \ 119 | int(os.path.basename(labelList[labelIndex]).split('.')[0].split('_')[-1]): 120 | 121 | tempData.append(labelList[labelIndex]) 122 | singleTest_list.append(tempData) 123 | 124 | return doubleTest_list, singleTest_list 125 | 126 | # return like this"[[hdata1, label1],[hdata2, label2]]" with the whole path. 127 | def get_train_list(highDataList, labelList): 128 | assert len(highDataList) == len(labelList), \ 129 | "high:%d, label:%d" % (len(highDataList), len(labelList)) 130 | 131 | train_list = [] 132 | for i in range(len(labelList)): 133 | tempData = [] 134 | # this place should changed on the different situation. 135 | if int(os.path.basename(highDataList[i]).split('_')[-1].split('.')[0]) == \ 136 | int(os.path.basename(labelList[i]).split('_')[-1].split('.')[0]): 137 | tempData.append(highDataList[i]) 138 | tempData.append(labelList[i]) 139 | 140 | else: 141 | raise Exception('len(lowData) not equal with len(highData)...') 142 | train_list.append(tempData) 143 | return train_list 144 | 145 | # train_list like this"[[hdata1, label1],[hdata2, label2]]" with the whole path. 
146 | def prepare_nn_data(train_list): 147 | batchSizeRandomList = random.sample(range(0,len(train_list)), 8) 148 | gt_list = [] 149 | highData_list = [] 150 | for i in batchSizeRandomList: 151 | highData_image = c_getYdata(train_list[i][0]) 152 | gt_image = c_getYdata(train_list[i][1]) 153 | for j in range(0, 8): 154 | #crop images to the disired size. 155 | highData_imgY, gt_imgY = \ 156 | crop(highData_image, gt_image, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray") 157 | #normalize 158 | highData_imgY = normalize(highData_imgY) 159 | gt_imgY = normalize(gt_imgY) 160 | 161 | highData_list.append(highData_imgY) 162 | gt_list.append(gt_imgY) 163 | 164 | high1Data_list = np.resize(highData_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 165 | gt_list = np.resize(gt_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 166 | return high1Data_list, gt_list 167 | 168 | def getWH(yuvfileName): 169 | deyuv=re.compile(r'(.+?)\.') 170 | deyuvFilename=deyuv.findall(yuvfileName)[0] #去yuv后缀的文件名 171 | if 'x' in os.path.basename(deyuvFilename).split('_')[-2]: 172 | wxh = os.path.basename(deyuvFilename).split('_')[-2] 173 | elif 'x' in os.path.basename(deyuvFilename).split('_')[1]: 174 | wxh = os.path.basename(deyuvFilename).split('_')[1] 175 | else: 176 | raise Exception('do not find wxh') 177 | w, h = wxh.split('x') 178 | return int(w), int(h) 179 | 180 | def getYdata(path, size): 181 | w = size[0] 182 | h = size[1] 183 | with open(path, 'rb') as fp: 184 | fp.seek(0, 0) 185 | Yt = fp.read() 186 | tem = Image.frombytes('L', [w, h], Yt) 187 | Yt = np.asarray(tem, dtype='float32') 188 | return Yt 189 | 190 | def c_getYdata(path): 191 | return getYdata(path, getWH(path)) 192 | 193 | def img2y(input_img): 194 | if np.asarray(input_img).shape[2] == 3: 195 | input_imgY = input_img.convert('YCbCr').split()[0] 196 | input_imgCb, input_imgCr = input_img.convert('YCbCr').split()[1:3] 197 | 198 | input_imgY = np.asarray(input_imgY, dtype='float32') 199 | input_imgCb = 
np.asarray(input_imgCb, dtype='float32') 200 | input_imgCr = np.asarray(input_imgCr, dtype='float32') 201 | 202 | 203 | #Concatenate Cb, Cr components for easy, they are used in pair anyway. 204 | input_imgCb = np.expand_dims(input_imgCb,2) 205 | input_imgCr = np.expand_dims(input_imgCr,2) 206 | input_imgCbCr = np.concatenate((input_imgCb, input_imgCr), axis=2) 207 | 208 | elif np.asarray(input_img).shape[2] == 1: 209 | print("This image has one channal only.") 210 | #If the num of channal is 1, remain. 211 | input_imgY = input_img 212 | input_imgCbCr = None 213 | else: 214 | print("The num of channal is neither 3 nor 1.") 215 | exit() 216 | return input_imgY, input_imgCbCr 217 | 218 | # def crop(input_image, gt_image, patch_width, patch_height, img_type): 219 | def crop(highData_image, gt_image, patch_width, patch_height, img_type): 220 | assert type(highData_image) == type(gt_image), "types are different." 221 | highData_cropped = [] 222 | gt_cropped = [] 223 | 224 | # return a ndarray object 225 | if img_type == "ndarray": 226 | in_row_ind = random.randint(0,highData_image.shape[0]-patch_width) 227 | in_col_ind = random.randint(0,highData_image.shape[1]-patch_height) 228 | highData_cropped = highData_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height] 229 | gt_cropped = gt_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height] 230 | 231 | #return an "Image" object 232 | elif img_type == "Image": 233 | pass 234 | return highData_cropped, gt_cropped 235 | 236 | def save_images(inputY, inputCbCr, size, image_path): 237 | """Save mutiple images into one single image. 238 | 239 | # Parameters 240 | # ----------- 241 | # images : numpy array [batch, w, h, c] 242 | # size : list of two int, row and column number. 243 | # number of images should be equal or less than size[0] * size[1] 244 | # image_path : string. 
245 | # 246 | # Examples 247 | # --------- 248 | # # >>> images = np.random.rand(64, 100, 100, 3) 249 | # # >>> tl.visualize.save_images(images, [8, 8], 'temp.png') 250 | """ 251 | def merge(images, size): 252 | h, w = images.shape[1], images.shape[2] 253 | img = np.zeros((h * size[0], w * size[1], 3)) 254 | for idx, image in enumerate(images): 255 | i = idx % size[1] 256 | j = idx // size[1] 257 | img[j*h:j*h+h, i*w:i*w+w, :] = image 258 | return img 259 | 260 | inputY = inputY.astype('uint8') 261 | inputCbCr = inputCbCr.astype('uint8') 262 | output_concat = np.concatenate((inputY, inputCbCr), axis=3) 263 | 264 | assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat)) 265 | 266 | new_output = merge(output_concat, size) 267 | new_output = new_output.astype('uint8') 268 | img = Image.fromarray(new_output, mode='YCbCr') 269 | img = img.convert('RGB') 270 | img.save(image_path) 271 | 272 | def get_image_batch(train_list,offset,batch_size): 273 | target_list = train_list[offset:offset+batch_size] 274 | input_list = [] 275 | gt_list = [] 276 | inputcbcr_list = [] 277 | for pair in target_list: 278 | input_img = Image.open(pair[0]) 279 | gt_img = Image.open(pair[1]) 280 | 281 | #crop images to the disired size. 
282 | input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image") 283 | 284 | #focus on Y channal only 285 | input_imgY, input_imgCbCr = img2y(input_img) 286 | gt_imgY, gt_imgCbCr = img2y(gt_img) 287 | input_list.append(input_imgY) 288 | gt_list.append(gt_imgY) 289 | inputcbcr_list.append(input_imgCbCr) 290 | 291 | input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 292 | gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1)) 293 | 294 | return input_list, gt_list, inputcbcr_list 295 | 296 | def save_test_img(inputY, inputCbCr, path): 297 | assert len(inputY.shape) == 4, "the tensor Y's shape is %s"%inputY.shape 298 | assert inputY.shape[0] == 1, "the fitst component must be 1, has not been completed otherwise.{}".format(inputY.shape) 299 | 300 | inputY = np.squeeze(inputY, axis=0) 301 | inputY = inputY.astype('uint8') 302 | inputCbCr = inputCbCr.astype('uint8') 303 | output_concat = np.concatenate((inputY, inputCbCr), axis=2) 304 | img = Image.fromarray(output_concat, mode='YCbCr') 305 | img = img.convert('RGB') 306 | img.save(path) 307 | 308 | def psnr(hr_image, sr_image, max_value=255.0): 309 | eps = 1e-10 310 | if((type(hr_image)==type(np.array([]))) or (type(hr_image)==type([]))): 311 | hr_image_data = np.asarray(hr_image, 'float32') 312 | sr_image_data = np.asarray(sr_image, 'float32') 313 | 314 | diff = sr_image_data - hr_image_data 315 | mse = np.mean(diff*diff) 316 | mse = np.maximum(eps, mse) 317 | return float(10*math.log10(max_value*max_value/mse)) 318 | else: 319 | assert len(hr_image.shape)==4 and len(sr_image.shape)==4 320 | diff = hr_image - sr_image 321 | mse = tf.reduce_mean(tf.square(diff)) 322 | mse = tf.maximum(mse, eps) 323 | return 10*tf.log(max_value*max_value/mse)/math.log(10) 324 | 325 | def getBeforeNNBlockDict(img, w, h): 326 | blockSize = 1000 327 | padding = 32 328 | yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1) 329 | xBlockNum = 
(w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1) 330 | tempImg = {} 331 | i = 0 332 | for yBlock in range(yBlockNum): 333 | for xBlock in range(xBlockNum): 334 | if yBlock == 0: 335 | if xBlock == 0: 336 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding] 337 | elif xBlock == xBlockNum - 1: 338 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w] 339 | else: 340 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 341 | elif yBlock == yBlockNum - 1: 342 | if xBlock == 0: 343 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding] 344 | elif xBlock == xBlockNum - 1: 345 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w] 346 | else: 347 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 348 | elif xBlock == 0: 349 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding] 350 | elif xBlock == xBlockNum - 1: 351 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w] 352 | else: 353 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 354 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding] 355 | i += i 356 | l = tempImg[i].astype('uint8') 357 | l = Image.fromarray(l) 358 | l.show() 359 | --------------------------------------------------------------------------------