├── README.md
├── READYME_PNG
│   ├── LMVE_qpx4.png
│   ├── QE-subnet.png
│   ├── compareWithMFQE.png
│   ├── results_of_each_frame.png
│   ├── 主观图排版_LMVE.png
│   └── 主观图排版_wraped.png
├── experiment_tjc.xlsx
├── test
│   ├── test_MF
│   │   ├── LMVE_model.py
│   │   ├── LMVE_test.py
│   │   ├── LSVE_model.py
│   │   ├── UTILS_MF_ra.py
│   │   ├── __pycache__
│   │   │   ├── MFVDSingle_ra.cpython-35.pyc
│   │   │   ├── UTILS_MF_ra.cpython-35.pyc
│   │   │   └── yangNet.cpython-35.pyc
│   │   └── checkpoints
│   │       ├── MF_qp22_ra_01222_LMVE_double
│   │       │   ├── MF_qp22_ra_01222_LMVE_double_149.ckpt.data-00000-of-00001
│   │       │   ├── MF_qp22_ra_01222_LMVE_double_149.ckpt.index
│   │       │   ├── MF_qp22_ra_01222_LMVE_double_149.ckpt.meta
│   │       │   └── checkpoint
│   │       ├── MF_qp27_ra_01202_LMVE_double
│   │       │   ├── MF_qp27_ra_01202_LMVE_double_201.ckpt.data-00000-of-00001
│   │       │   ├── MF_qp27_ra_01202_LMVE_double_201.ckpt.index
│   │       │   ├── MF_qp27_ra_01202_LMVE_double_201.ckpt.meta
│   │       │   └── checkpoint
│   │       ├── MF_qp32_ra_01262_LMVE_double
│   │       │   ├── MF_qp32_ra_01262_LMVE_double_138.ckpt.data-00000-of-00001
│   │       │   ├── MF_qp32_ra_01262_LMVE_double_138.ckpt.index
│   │       │   ├── MF_qp32_ra_01262_LMVE_double_138.ckpt.meta
│   │       │   └── checkpoint
│   │       └── MF_qp37_ra_01121
│   │           ├── MF_qp37_ra_01121_402.ckpt.data-00000-of-00001
│   │           ├── MF_qp37_ra_01121_402.ckpt.index
│   │           ├── MF_qp37_ra_01121_402.ckpt.meta
│   │           ├── a.txt
│   │           └── checkpoint
│   └── test_Single
│       ├── LSVE_ra_model.py
│       ├── LSVE_test.py
│       ├── UTILS_MF_ra_ALL_Single.py
│       ├── __pycache__
│       │   ├── MFVDSingle_ra.cpython-35.pyc
│       │   └── UTILS_MF_ra_ALL_Single.cpython-35.pyc
│       ├── checkpoints
│       │   ├── SF_qp22_ra_01251_LSVE
│       │   │   ├── SF_qp22_ra_01251_LSVE_105.ckpt.data-00000-of-00001
│       │   │   ├── SF_qp22_ra_01251_LSVE_105.ckpt.index
│       │   │   ├── SF_qp22_ra_01251_LSVE_105.ckpt.meta
│       │   │   └── checkpoint
│       │   ├── SF_qp27_ra_01232_LSVE
│       │   │   ├── SF_qp27_ra_01232_LSVE_169.ckpt.data-00000-of-00001
│       │   │   ├── SF_qp27_ra_01232_LSVE_169.ckpt.index
│       │   │   ├── SF_qp27_ra_01232_LSVE_169.ckpt.meta
│       │   │   └── checkpoint
│       │   ├── SF_qp32_ra_01242_LSVE
│       │   │   ├── SF_qp32_ra_01242_LSVE_185.ckpt.data-00000-of-00001
│       │   │   ├── SF_qp32_ra_01242_LSVE_185.ckpt.index
│       │   │   ├── SF_qp32_ra_01242_LSVE_185.ckpt.meta
│       │   │   └── checkpoint
│       │   └── SF_qp37_ra_01174_LSVE
│       │       ├── MF_qp37_ra_01174_LSVE_291.ckpt.data-00000-of-00001
│       │       ├── MF_qp37_ra_01174_LSVE_291.ckpt.index
│       │       ├── MF_qp37_ra_01174_LSVE_291.ckpt.meta
│       │       └── checkpoint
│       └── outdata
│           └── SF_qp27_ra_01232_LSVE
│               ├── events.out.tfevents.1549718071.AIR-PC
│               └── events.out.tfevents.1549718367.AIR-PC
└── train
    ├── train_MF
    │   ├── LMVE_ra_model.py
    │   ├── LMVE_ra_train.py
    │   └── UTILS_MF_ra.py
    └── train_Single
        ├── LSVE_Single_ra_train.py
        ├── LSVE_ra_model.py
        └── UTILS_single_ra.py
/README.md:
--------------------------------------------------------------------------------
1 | ## LEARNING-BASED MULTI-FRAME VIDEO QUALITY ENHANCEMENT
2 | Junchao Tong*, Xilin Wu*, Dandan Ding*, Zheng Zhu**, Zoe Liu**
3 | \* Hangzhou Normal University
4 | ** Visionular Inc.
5 |
6 | ___
7 | * ### Example of the subjective quality of the HEVC-compressed video (the 1st and 2nd columns) and of the preprocessed frames (MF, HF).
8 | `Here, MF is the Moderate-quality Frame and HF is the High-quality Frame.`
9 |
10 | 
11 |
12 | ___
13 | * ### Subjective quality performance on *FourPeople* and *BasketballPass* at QP=37 under the RA configuration.
14 | 
15 |
16 | ___
17 | ## The experiment section
18 | ### 1. Performance of our proposed LMVE in comparison with LSVE and HEVC, in terms of PSNR (dB).
19 |
20 | 
21 |
22 |
23 | ### 2. Performance of LMVE compared with MFQE at QP=37.
24 |
25 | 
26 |
27 |
28 | * The architecture of the QE-subnet is shown in Figure 6, and the details of the convolutional layers are presented in Table 3.
29 | 
30 |
31 | ### 3. Per-frame PSNR (dB) results for the first 21 frames of the FourPeople and PeopleOnStreet sequences.
32 | 
33 | * `For more details on the experiments, download the "experiment_tjc.xlsx" file from this repository.`
34 |
35 | ___
36 | FlowNet is open-sourced at: https://github.com/lmb-freiburg/flownet2
37 |
38 | MFQE is open-sourced at: https://github.com/ryangBUAA/MFQE
39 | ___
40 | Thanks for reading!
41 |
42 | Best wishes,
43 | Junchao Tong.
44 |
--------------------------------------------------------------------------------
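For reference, the PSNR (dB) numbers quoted above follow the standard definition, the same one implemented by the `psnr` helper in `test/test_MF/UTILS_MF_ra.py`. A minimal NumPy sketch:

```python
import numpy as np

def psnr(ref, test, max_value=255.0, eps=1e-10):
    """Peak signal-to-noise ratio in dB between two 8-bit frames."""
    mse = np.mean((np.asarray(ref, 'float32') - np.asarray(test, 'float32')) ** 2)
    return float(10 * np.log10(max_value ** 2 / max(mse, eps)))
```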
/READYME_PNG/LMVE_qpx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/LMVE_qpx4.png
--------------------------------------------------------------------------------
/READYME_PNG/QE-subnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/QE-subnet.png
--------------------------------------------------------------------------------
/READYME_PNG/compareWithMFQE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/compareWithMFQE.png
--------------------------------------------------------------------------------
/READYME_PNG/results_of_each_frame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/results_of_each_frame.png
--------------------------------------------------------------------------------
/READYME_PNG/主观图排版_LMVE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/主观图排版_LMVE.png
--------------------------------------------------------------------------------
/READYME_PNG/主观图排版_wraped.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/READYME_PNG/主观图排版_wraped.png
--------------------------------------------------------------------------------
/experiment_tjc.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/experiment_tjc.xlsx
--------------------------------------------------------------------------------
/test/test_MF/LMVE_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def model_double(inputHigh1Data_tensor, inputLowData_tensor, inputHigh2Data_tensor):
6 | # with tf.device("/gpu:0"):
7 | input_before = inputHigh1Data_tensor # highData
8 | input_current = inputLowData_tensor # lowData
9 |
10 | input_after = inputHigh2Data_tensor
11 |
12 | # running feature map; assigned after the input convolutions below.
13 | tensor = None
14 |
15 | # ----------------------------------------------Frame -1--------------------------------------------------------
16 | input_before_w = tf.get_variable("input_before_w", [5, 5, 1, 64],
17 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
18 | input_before_b = tf.get_variable("input_before_b", [64], initializer=tf.constant_initializer(0))
19 | input_high1_tensor = tf.nn.relu(
20 | tf.nn.bias_add(tf.nn.conv2d(input_before, input_before_w, strides=[1, 1, 1, 1], padding='SAME'), input_before_b))
21 | # ----------------------------------------------Frame 0--------------------------------------------------------
22 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64],
23 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
24 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0))
25 | input_low_tensor = tf.nn.relu(
26 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'),
27 | input_current_b))
28 |
29 | # ----------------------------------------------Frame 1--------------------------------------------------------
30 | input_after_w = tf.get_variable("input_after_w", [5, 5, 1, 64],
31 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
32 | input_after_b = tf.get_variable("input_after_b", [64], initializer=tf.constant_initializer(0))
33 | input_high2_tensor = tf.nn.relu(
34 | tf.nn.bias_add(tf.nn.conv2d(input_after, input_after_w, strides=[1, 1, 1, 1], padding='SAME'),
35 | input_after_b))
36 | # ------------------------------------------Frame -1\0\1 concat------------------------------------------
37 | input_tensor_Concat = tf.concat([input_high1_tensor, input_low_tensor, input_high2_tensor], axis=3)
38 | # ----------------------------------1x1 conv to reduce the number of parameters----------------
39 | input_1x1_w = tf.get_variable("input_1x1_w", [1, 1, 192, 64],
40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 192)))
41 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0))
42 | input_1x1_tensor = tf.nn.relu(
43 | tf.nn.bias_add(tf.nn.conv2d(input_tensor_Concat, input_1x1_w, strides=[1, 1, 1, 1], padding='SAME'),
44 | input_1x1_b))
45 | tensor = input_1x1_tensor
46 |
47 | # --------------------------------------start iteration for last layers----------------------------------
48 | convId = 0
49 | for i in range(18):
50 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64],
51 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
52 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
53 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0))
54 | convId += 1
55 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b))
56 |
57 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1],
58 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
59 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0))
60 | convId += 1
61 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
62 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)
63 |
64 | tensor = tf.add(tensor, input_current)
65 | return tensor
66 |
--------------------------------------------------------------------------------
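A minimal usage sketch (TensorFlow 1.x), mirroring how `LMVE_test.py` drives `model_double`: three Y-channel inputs (previous high-quality frame, current low-quality frame, next high-quality frame) produce the enhanced current frame. The dummy 416x240 frames below are placeholders for real normalized luma data:

```python
import numpy as np
import tensorflow as tf
from LMVE_model import model_double

tf.reset_default_graph()
high1 = tf.placeholder(tf.float32, shape=(1, None, None, 1))
low = tf.placeholder(tf.float32, shape=(1, None, None, 1))
high2 = tf.placeholder(tf.float32, shape=(1, None, None, 1))
shared = tf.make_template('shared_model', model_double)
out = tf.clip_by_value(shared(high1, low, high2), 0., 1.)

frame = np.random.rand(1, 240, 416, 1).astype(np.float32)  # stand-in normalized Y frame
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    enhanced = sess.run(out, feed_dict={high1: frame, low: frame, high2: frame})
    print(enhanced.shape)  # (1, 240, 416, 1)
```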
/test/test_MF/LMVE_test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import itertools
3 | from LMVE_model import model_double
4 | from LSVE_model import model_single
5 | from UTILS_MF_ra import *
6 |
7 | tf.logging.set_verbosity(tf.logging.WARN)
8 |
9 | EXP_DATA1 = "MF_qp22_ra_01222_LMVE_double" # double checkpoints
10 | EXP_DATA2 = "MF_qp22_ra_01222_LMVE_double" # single-frame checkpoints (not used here)
11 | MODEL_DOUBLE_PATH1 = "./checkpoints/%s/"%(EXP_DATA1)
12 | MODEL_DOUBLE_PATH2 = "./checkpoints/%s/"%(EXP_DATA2)
13 | HIGHDATA_Parent_PATH = r"E:\MF\HEVC_TestSequenct\qp37_ldp\warped_yuv\FourPeople_qp22_ai_1280x720" # warped high frames
14 | QP_LOWDATA_PATH = r'E:\MF\HEVC_TestSequenct\rec\qp37\ldp\E\FourPeople_qp37_ldp_1280x720' # low frames
15 | GT_PATH = r"E:\MF\HEVC_TestSequenct\org\E\FourPeople_1280x720" # ground truth
16 | DL_path = r'E:\MF\HEVC_TestSequenct\DL_rec\double\qp37'
17 | OUT_DATA_PATH = "./outdata/%s/"%(EXP_DATA1)
18 |
19 |
20 | # fileOrDir = [high_frames_dir, low_frames_dir, ground_truth_dir].
21 |
22 | ## Cb/Cr components are not enhanced; they are copied from the input frames.
23 | def prepare_test_data(fileOrDir):
24 | doubleData_ycbcr = []
25 | doubleGT_y = []
26 | singleData_ycbcr = []
27 | singleGT_y = []
28 | fileName_list = []
29 | # The input is a list of three directories.
30 | if len(fileOrDir) == 3:
31 | # return the whole absolute path.
32 | fileName_list = load_file_list(fileOrDir[1])
33 | # double_list # [[high, low1, label1], [[h21,h22], low2, label2]]
34 | # single_list # [[low1, label1], [2,2] ....]
35 | double_list, single_list = get_test_list(HIGHDATA_Parent_PATH, load_file_list(fileOrDir[1]),
36 | load_file_list(fileOrDir[2]))
37 |
38 | for pair in double_list:
39 | high1Data_List = []
40 | lowData_List = []
41 | high2Data_List = []
42 |
43 | high1Data_imgY = c_getYdata(pair[0][0])
44 | high2Data_imgY = c_getYdata(pair[0][1])
45 | lowData_imgY = c_getYdata(pair[1])
46 | CbCr = c_getCbCr(pair[1])
47 | gt_imgY = c_getYdata(pair[2])
48 |
49 | #normalize
50 | high1Data_imgY = normalize(high1Data_imgY)
51 | lowData_imgY = normalize(lowData_imgY)
52 | high2Data_imgY = normalize(high2Data_imgY)
53 |
54 | high1Data_imgY = np.resize(high1Data_imgY, (1, high1Data_imgY.shape[0], high1Data_imgY.shape[1],1))
55 | lowData_imgY = np.resize(lowData_imgY, (1, lowData_imgY.shape[0], lowData_imgY.shape[1], 1))
56 | high2Data_imgY = np.resize(high2Data_imgY, (1, high2Data_imgY.shape[0], high2Data_imgY.shape[1], 1))
57 | gt_imgY = np.resize(gt_imgY, (1, gt_imgY.shape[0], gt_imgY.shape[1],1))
58 |
59 | ## act as a placeholder
60 |
61 | high1Data_List.append([high1Data_imgY, 0])
62 | lowData_List.append([lowData_imgY, CbCr])
63 | high2Data_List.append([high2Data_imgY, 0])
64 | doubleData_ycbcr.append([high1Data_List, lowData_List, high2Data_List])
65 | doubleGT_y.append(gt_imgY)
66 |
67 | # single_list # [[low1, label1], [2,2] ....]
68 | for pair in single_list:
69 | lowData_list = []
70 | lowData_imgY = c_getYdata(pair[0])
71 | CbCr = c_getCbCr(pair[0])
72 | gt_imgY = c_getYdata(pair[1])
73 |
74 | # normalize
75 | lowData_imgY = normalize(lowData_imgY)
76 |
77 | lowData_imgY = np.resize(lowData_imgY, (1, lowData_imgY.shape[0], lowData_imgY.shape[1], 1))
78 | gt_imgY = np.resize(gt_imgY, (1, gt_imgY.shape[0], gt_imgY.shape[1], 1))
79 |
80 | lowData_list.append([lowData_imgY, CbCr])
81 | singleData_ycbcr.append(lowData_list)
82 | singleGT_y.append(gt_imgY)
83 |
84 | else:
85 | print("Invalid Inputs...!tjc!")
86 | exit(0)
87 |
88 | return doubleData_ycbcr, doubleGT_y, singleData_ycbcr, singleGT_y, fileName_list
89 |
90 | def test_all_ckpt(modelPath1, modelPath2, fileOrDir):
91 | max = [0, 0]
92 | tem1 = [f for f in os.listdir(modelPath1) if 'data' in f]
93 | ckptFiles1 = sorted([r.split('.data')[0] for r in tem1])
94 | tem2 = [f for f in os.listdir(modelPath2) if 'data' in f]
95 | ckptFiles2 = sorted([r.split('.data')[0] for r in tem2])
96 | re_psnr = tf.placeholder('float32')
97 | tf.summary.scalar('re_psnr', re_psnr)
98 |
99 | doubleData_ycbcr, doubleGT_y, singleData_ycbcr, singleGT_y, fileName_list = prepare_test_data(fileOrDir)
100 | total_time, total_psnr = 0, 0
101 | total_imgs = len(fileName_list)
102 | count = 0
103 | for i in range(total_imgs):
104 | if i % 4 != 0:
105 | count += 1
106 | # awkward nesting: doubleData_ycbcr is [[[h1,0]],[[low,0]],[[h2,0]]] per multi-frame entry
107 | j = i - (i//4) - 1
108 | imgHigh1DataY = doubleData_ycbcr[j][0][0][0]
109 | imgLowDataY = doubleData_ycbcr[j][1][0][0]
110 | imgLowCbCr = doubleData_ycbcr[j][1][0][1]
111 | imgHigh2DataY = doubleData_ycbcr[j][2][0][0]
112 | gtY = doubleGT_y[j] if doubleGT_y else 0
113 | start_t = time.time()
114 | for ckpt1 in ckptFiles1:
115 | epoch = int(ckpt1.split('_')[-1].split('.')[0])
116 | if epoch != 149:
117 | continue
118 | # important: reset the default graph before rebuilding the model for each checkpoint
119 | tf.reset_default_graph()
120 | # Double section
121 | high1Data_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1))
122 | lowData_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1))
123 | high2Data_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1))
124 | shared_model1 = tf.make_template('shared_model', model_double)
125 | output_tensor1 = shared_model1(high1Data_tensor, lowData_tensor, high2Data_tensor)
126 | # output_tensor = shared_model(input_tensor)
127 | output_tensor1 = tf.clip_by_value(output_tensor1, 0., 1.)
128 | output_tensor1 = output_tensor1 * 255
129 | with tf.Session() as sess:
130 | saver = tf.train.Saver(tf.global_variables())
131 | sess.run(tf.global_variables_initializer())
132 |
133 | saver.restore(sess, os.path.join(modelPath1, ckpt1))
134 | out = sess.run(output_tensor1, feed_dict={high1Data_tensor: imgHigh1DataY,
135 | lowData_tensor:imgLowDataY, high2Data_tensor: imgHigh2DataY})
136 | hevc = psnr(imgLowDataY * 255.0, gtY)
137 | out = np.around(out)
138 | out = out.astype('int')
139 | out = np.reshape(out, [1, out.shape[1], out.shape[2], 1])
140 |
141 | Y = np.reshape(out, [out.shape[1], out.shape[2]])
142 | Y = np.array(list(itertools.chain.from_iterable(Y)))
143 | U = imgLowCbCr[0]
144 | V = imgLowCbCr[1]
145 | creatPath = os.path.join(DL_path, fileName_list[i].split('\\')[-2])
146 | if not os.path.exists(creatPath):
147 | os.mkdir(creatPath)
148 |
149 | if doubleGT_y:
150 | p = psnr(out, gtY)
151 |
152 | path = os.path.join(DL_path,
153 | fileName_list[i].split('\\')[-2],
154 | fileName_list[i].split('\\')[-1].split('.')[0]) + '_%.4f' % (p-hevc)+ '.yuv'
155 |
156 | YUV = np.concatenate((Y, U, V))
157 | YUV = YUV.astype('uint8')
158 | YUV.tofile(path)
159 |
160 | total_psnr += p
161 | print("qp37\tepoch:%d\t%s\t%.4f\n" % (epoch, fileName_list[i], p))
162 |
163 | duration_t = time.time() - start_t
164 | total_time += duration_t
165 | else: # single-frame (key-frame) branch, disabled here; see LSVE_test.py
166 | continue
167 | count += 1
168 | j = i // 4
169 | lowDataY = singleData_ycbcr[j][0][0]
170 | imgLowCbCr = singleData_ycbcr[j][0][1]
171 | gtY = singleGT_y[j] if singleGT_y else 0
172 | hevc = psnr(lowDataY * 255.0, gtY)
173 | start_t = time.time()
174 | for ckpt2 in ckptFiles2:
175 | epoch = int(ckpt2.split('_')[-1].split('.')[0])
176 | if epoch != 169:
177 | continue
178 |
179 | tf.reset_default_graph()
180 | # Single section
181 | lowSingleData_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1))
182 | shared_model2 = tf.make_template('shared_model', model_single)
183 | output_tensor2 = shared_model2(lowSingleData_tensor)
184 | output_tensor2 = tf.clip_by_value(output_tensor2, 0., 1.)
185 | output_tensor2 = output_tensor2 * 255
186 | with tf.Session() as sess:
187 | saver = tf.train.Saver(tf.global_variables())
188 | sess.run(tf.global_variables_initializer())
189 |
190 | saver.restore(sess, os.path.join(modelPath2, ckpt2))
191 | out = sess.run(output_tensor2, feed_dict={lowSingleData_tensor: lowDataY})
192 | out = np.around(out)
193 | out = out.astype('int')
194 | out = np.reshape(out, [1, out.shape[1], out.shape[2], 1])
195 | Y = np.reshape(out, [out.shape[1], out.shape[2]])
196 | Y = np.array(list(itertools.chain.from_iterable(Y)))
197 | U = imgLowCbCr[0]
198 | V = imgLowCbCr[1]
199 | creatPath = os.path.join(DL_path, fileName_list[i].split('\\')[-2])
200 | if not os.path.exists(creatPath):
201 | os.mkdir(creatPath)
202 |
203 | if singleGT_y:
204 | p = psnr(out, gtY)
205 | path = os.path.join(DL_path, fileName_list[i].split('\\')[-2],
206 | fileName_list[i].split('\\')[-1].split('.')[0]) + '_%.4f' % (p - hevc) + '.yuv'
207 | YUV = np.concatenate((Y, U, V))
208 | YUV = YUV.astype('uint8')
209 | YUV.tofile(path)
210 |
211 | total_psnr += p
212 | print("qp37\tepoch:%d\t%s\t%.4f\n" % (epoch, fileName_list[i], p))
213 |
214 | duration_t = time.time() - start_t
215 | total_time += duration_t
216 |
217 |
218 | print("AVG_DURATION:%.2f\tAVG_PSNR:%.4f"%(total_time/total_imgs, total_psnr / count))
219 | print('count:', count)
220 | avg_psnr = total_psnr / count
221 | if avg_psnr > max[0]:
222 | max[0] = avg_psnr
223 | max[1] = epoch
224 |
225 | if __name__ == '__main__':
226 | test_all_ckpt(MODEL_DOUBLE_PATH1, MODEL_DOUBLE_PATH2, [HIGHDATA_Parent_PATH, QP_LOWDATA_PATH, GT_PATH])
--------------------------------------------------------------------------------
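The `j = i - (i // 4) - 1` mapping in `test_all_ckpt` converts a global frame index into a position in `double_list`, which holds only the non-key frames (indices with `i % 4 != 0`; key frames are enhanced by LSVE). A quick sanity check of the mapping:

```python
# Non-key frames land on consecutive double_list positions.
for i in range(10):
    if i % 4 != 0:
        j = i - (i // 4) - 1
        print("frame %d -> double_list[%d]" % (i, j))
# frame 1 -> double_list[0], frame 2 -> [1], frame 3 -> [2], frame 5 -> [3], ...
```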
/test/test_MF/LSVE_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def model_single(inputLowData_tensor):
6 | # with tf.device("/gpu:0"):
7 | input_current = inputLowData_tensor # lowData
8 |
9 | # running feature map; assigned after the input convolutions below.
10 | tensor = None
11 |
12 | # ----------------------------------------------Frame 0--------------------------------------------------------
13 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64],
14 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
15 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0))
16 | input_low_tensor = tf.nn.relu(
17 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'),
18 | input_current_b))
19 |
20 | # ----------------------------------3x3 conv (the "1x1" variable names are historical)----------------
21 | input_3x3_w = tf.get_variable("input_1x1_w", [3, 3, 64, 64],
22 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 64 / 9)))
23 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0))
24 | input_3x3_tensor = tf.nn.relu(
25 | tf.nn.bias_add(tf.nn.conv2d(input_low_tensor, input_3x3_w, strides=[1, 1, 1, 1], padding='SAME'),
26 | input_1x1_b))
27 | tensor = input_3x3_tensor
28 |
29 | # --------------------------------------start iteration for last layers----------------------------------
30 | convId = 0
31 | for i in range(18):
32 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64],
33 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
34 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
35 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0))
36 | convId += 1
37 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b))
38 |
39 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1],
40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
41 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0))
42 | convId += 1
43 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
44 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)
45 | tensor = tf.add(tensor, input_current)
46 | return tensor
47 |
--------------------------------------------------------------------------------
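For a sense of model size: counting the weights defined above (a 5x5 input conv, nineteen 3x3 convs of width 64, and a final 3x3 reconstruction conv, biases included) gives roughly 0.7M parameters:

```python
first = 5 * 5 * 1 * 64 + 64         # 5x5 input conv (1 -> 64)
mid = (3 * 3 * 64 * 64 + 64) * 19   # the named 3x3 conv plus the 18-layer loop
last = 3 * 3 * 64 * 1 + 1           # final 3x3 reconstruction conv (64 -> 1)
print(first + mid + last)           # 703873, i.e. ~0.70M parameters
```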
/test/test_MF/UTILS_MF_ra.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import math, os, random, re
4 | from PIL import Image
5 | BATCH_SIZE = 64
6 | PATCH_SIZE = (64, 64)
7 |
8 | # Each training batch is drawn from a single picture, so the sampling below is designed to make the training set more diverse.
9 | def normalize(x):
10 | x = x / 255.
11 | return truncate(x, 0., 1.)
12 |
13 | def denormalize(x):
14 | x = x * 255.
15 | return truncate(x, 0., 255.)
16 |
17 | def truncate(input, min, max):
18 | input = np.where(input > min, input, min)
19 | input = np.where(input < max, input, max)
20 | return input
21 |
22 | def remap(input):
23 | input = 16+219/255*input
24 | return truncate(input, 16.0, 235.0)
25 |
26 | def deremap(input):
27 | input = (input-16)*255/219
28 | return truncate(input, 0.0, 255.0)
29 |
30 | # return the whole absolute path.
31 | def load_file_list(directory):
32 | list = []
33 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]:
34 | list.append(os.path.join(directory,filename))
35 | return list
36 |
37 | def searchHighData(currentLowDataIndex, highDataList, highIndexList):
38 | searchOffset = 3
39 | searchedHighDataIndexList = []
40 | searchedHighData = []
41 | for i in range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1):
42 | if i in highIndexList:
43 | searchedHighDataIndexList.append(i)
44 | assert len(searchedHighDataIndexList) == 2, 'search method has an error!'
45 | for tempData in highDataList:
46 | if int(os.path.basename(tempData).split('.')[0].split('_')[-1]) \
47 | in searchedHighDataIndexList:
48 | searchedHighData.append(tempData)
49 | return searchedHighData
50 |
51 |
52 | # return like this"[[[high1Data, lowData], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole path.
53 | def get_test_list2(highDataList, lowDataList, labelList):
54 | assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList))
55 | # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]
56 | highIndexList = [q for q in range(49) if q % 4 == 0]
57 | test_list = []
58 | for tempDataPath in lowDataList:
59 | tempData = []
60 | temp = []
61 | # NOTE: adapt this index parsing to the file-naming convention in use.
62 | currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1])
63 | searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList)
64 | tempData.append(searchedHighData[0])
65 | tempData.append(tempDataPath)
66 | tempData.append(searchedHighData[1])
67 |
68 | i = list(lowDataList).index(tempDataPath)
69 |
70 | temp.append(tempData)
71 | temp.append(labelList[i])
72 | test_list.append(temp)
73 | return test_list
74 |
75 | def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList):
76 | tempData = []
77 | temp = []
78 | test_list = []
79 | for i in range(len(lowDataList)):
80 | tempData.append(high1DataList[i])
81 | tempData.append(lowDataList[i])
82 | tempData.append(high2DataList[i])
83 | temp.append(tempData)
84 | temp.append(labelList[i])
85 |
86 | test_list.append(temp)
87 | return test_list
88 |
89 | # [[high, low1, label1], [[h21,h22], low2, label2]]
90 | def get_test_list(HIGHDATA_Parent_PATH, lowDataList, labelList):
91 | doubleTest_list = []
92 | singleTest_list = []
93 | HighDirList = os.listdir(HIGHDATA_Parent_PATH)
94 | # convert string-list to int-list.
95 | HighDirList = list(map(int, HighDirList))
96 |
97 | for lowdata in lowDataList:
98 |
99 | tempData = []
100 | lowdataIndex = int(os.path.basename(lowdata).split('.')[0].split('_')[-1])
101 | if lowdataIndex % 4 != 0:
102 | if lowdataIndex in HighDirList:
103 | High_current_path = HIGHDATA_Parent_PATH + '/' + str(lowdataIndex)
104 | TWO_HighData = os.listdir(High_current_path)
105 | Two_HighData = [os.path.join(High_current_path, T) for T in TWO_HighData]
106 | tempData.append(Two_HighData)
107 | tempData.append(lowdata)
108 | labelIndex = list(lowDataList).index(lowdata)
109 | tempData.append(labelList[labelIndex])
110 | doubleTest_list.append(tempData)
111 | else:
112 | tempData.append(lowdata)
113 | labelIndex = list(lowDataList).index(lowdata)
114 | if int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) == \
115 | int(os.path.basename(labelList[labelIndex]).split('.')[0].split('_')[-1]):
116 |
117 | tempData.append(labelList[labelIndex])
118 | singleTest_list.append(tempData)
119 |
120 | return doubleTest_list, singleTest_list
121 |
122 | # return like this"[[[high1Data, lowData, high2Data], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole path.
123 | def get_train_list(high1DataList, lowDataList, high2DataList, labelList):
124 | assert len(lowDataList) == len(high1DataList) == len(labelList) == len(high2DataList), \
125 | "low:%d, high1:%d, label:%d, high2:%d"%(len(lowDataList), len(high1DataList), len(labelList), len(high2DataList))
126 |
127 | train_list = []
128 | for i in range(len(labelList)):
129 | tempData = []
130 | temp = []
131 | # NOTE: adapt this index parsing to the file-naming convention in use.
132 | if int(os.path.basename(high1DataList[i]).split('_')[-1].split('.')[0]) + 4 == \
133 | int(os.path.basename(lowDataList[i]).split('_')[-1].split('.')[0]) + 2 == \
134 | int(os.path.basename(high2DataList[i]).split('_')[-1].split('.')[0]):
135 | tempData.append(high1DataList[i])
136 | tempData.append(lowDataList[i])
137 | tempData.append(high2DataList[i])
138 | temp.append(tempData)
139 | temp.append(labelList[i])
140 |
141 | else:
142 | raise Exception('high1/low/high2 frame indices are not aligned')
143 | train_list.append(temp)
144 | return train_list
145 |
146 | def prepare_nn_data(train_list):
147 | batchSizeRandomList = random.sample(range(0,len(train_list)), 8)
148 | gt_list = []
149 | high1Data_list = []
150 | lowData_list = []
151 | high2Data_list = []
152 | for i in batchSizeRandomList:
153 | high1Data_image = c_getYdata(train_list[i][0][0])
154 | lowData_image = c_getYdata(train_list[i][0][1])
155 | high2Data_image = c_getYdata(train_list[i][0][2])
156 | gt_image = c_getYdata(train_list[i][1])
157 | for j in range(0, 8):
158 | # crop images to the desired size.
159 | high1Data_imgY, lowData_imgY, high2Data_imgY, gt_imgY = \
160 | crop(high1Data_image, lowData_image, high2Data_image, gt_image, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray")
161 |
162 | #normalize
163 | high1Data_imgY = normalize(high1Data_imgY)
164 | lowData_imgY = normalize(lowData_imgY)
165 | high2Data_imgY = normalize(high2Data_imgY)
166 | gt_imgY = normalize(gt_imgY)
167 |
168 | high1Data_list.append(high1Data_imgY)
169 | lowData_list.append(lowData_imgY)
170 | high2Data_list.append(high2Data_imgY)
171 | gt_list.append(gt_imgY)
172 |
173 | high1Data_list = np.resize(high1Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
174 | lowData_list = np.resize(lowData_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
175 | high2Data_list = np.resize(high2Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
176 | gt_list = np.resize(gt_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
177 |
178 | return high1Data_list, lowData_list, high2Data_list, gt_list
179 |
180 | def getWH(yuvfileName):
181 | deyuv=re.compile(r'(.+?)\.')
182 | deyuvFilename=deyuv.findall(yuvfileName)[0] # file name with the .yuv suffix removed
183 | if 'x' in os.path.basename(deyuvFilename).split('_')[-2]:
184 | wxh = os.path.basename(deyuvFilename).split('_')[-2]
185 | elif 'x' in os.path.basename(deyuvFilename).split('_')[1]:
186 | wxh = os.path.basename(deyuvFilename).split('_')[1]
187 | else:
188 | # print(yuvfileName)
189 | raise Exception('do not find wxh')
190 | w, h = wxh.split('x')
191 | return int(w), int(h)
192 |
193 | def getYdata(path, size):
194 | w = size[0]
195 | h = size[1]
196 | with open(path, 'rb') as fp:
197 | fp.seek(0, 0)
198 | Yt = fp.read()
199 | tem = Image.frombytes('L', [w, h], Yt)
200 | Yt = np.asarray(tem, dtype='float32')
201 | return Yt
202 |
203 |
204 | def c_getYdata(path):
205 | return getYdata(path, getWH(path))
206 |
207 | def c_getCbCr(path):
208 | w, h = getWH(path)
209 | CbCr = []
210 | with open(path, 'rb+') as file:
211 | y = file.read(h * w)
212 | if y == b'':
213 | return ''
214 | u = file.read(h * w // 4)
215 | v = file.read(h * w // 4)
216 | # convert string-list to int-list.
217 | u = list(map(int, u))
218 | v = list(map(int, v))
219 | CbCr.append(u)
220 | CbCr.append(v)
221 | return CbCr
222 |
223 | def img2y(input_img):
224 | if np.asarray(input_img).shape[2] == 3:
225 | input_imgY = input_img.convert('YCbCr').split()[0]
226 | input_imgCb, input_imgCr = input_img.convert('YCbCr').split()[1:3]
227 | input_imgY = np.asarray(input_imgY, dtype='float32')
228 | input_imgCb = np.asarray(input_imgCb, dtype='float32')
229 | input_imgCr = np.asarray(input_imgCr, dtype='float32')
230 |
231 | #Concatenate Cb, Cr components for easy, they are used in pair anyway.
232 | input_imgCb = np.expand_dims(input_imgCb,2)
233 | input_imgCr = np.expand_dims(input_imgCr,2)
234 | input_imgCbCr = np.concatenate((input_imgCb, input_imgCr), axis=2)
235 |
236 | elif np.asarray(input_img).shape[2] == 1:
237 | print("This image has one channal only.")
238 | #If the num of channal is 1, remain.
239 | input_imgY = input_img
240 | input_imgCbCr = None
241 | else:
242 | print("The num of channal is neither 3 nor 1.")
243 | exit()
244 | return input_imgY, input_imgCbCr
245 |
246 | # def crop(input_image, gt_image, patch_width, patch_height, img_type):
247 | def crop(high1Data_image, lowData_image, high2Data_image, gt_image, patch_width, patch_height, img_type):
248 | assert type(high1Data_image) == type(gt_image) == type(lowData_image) == type(high2Data_image), "types are different."
249 | high1Data_cropped = []
250 | lowData_cropped = []
251 | high2Data_cropped = []
252 | gt_cropped = []
253 |
254 | # return a ndarray object
255 | if img_type == "ndarray":
256 | in_row_ind = random.randint(0,high1Data_image.shape[0]-patch_width)
257 | in_col_ind = random.randint(0,high1Data_image.shape[1]-patch_height)
258 |
259 | high1Data_cropped = high1Data_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height]
260 | lowData_cropped = lowData_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height]
261 | high2Data_cropped = high2Data_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height]
262 | gt_cropped = gt_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height]
263 |
264 | #return an "Image" object
265 | elif img_type == "Image":
266 | pass
267 | return high1Data_cropped, lowData_cropped, high2Data_cropped, gt_cropped
268 |
269 | def save_images(inputY, inputCbCr, size, image_path):
270 | """Save mutiple images into one single image.
271 |
272 | # Parameters
273 | # -----------
274 | # images : numpy array [batch, w, h, c]
275 | # size : list of two int, row and column number.
276 | # number of images should be equal or less than size[0] * size[1]
277 | # image_path : string.
278 | #
279 | # Examples
280 | # ---------
281 | # # >>> images = np.random.rand(64, 100, 100, 3)
282 | # # >>> tl.visualize.save_images(images, [8, 8], 'temp.png')
283 | """
284 | def merge(images, size):
285 | h, w = images.shape[1], images.shape[2]
286 | img = np.zeros((h * size[0], w * size[1], 3))
287 | for idx, image in enumerate(images):
288 | i = idx % size[1]
289 | j = idx // size[1]
290 | img[j*h:j*h+h, i*w:i*w+w, :] = image
291 | return img
292 |
293 | inputY = inputY.astype('uint8')
294 | inputCbCr = inputCbCr.astype('uint8')
295 | output_concat = np.concatenate((inputY, inputCbCr), axis=3)
296 |
297 | assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat))
298 |
299 | new_output = merge(output_concat, size)
300 |
301 | new_output = new_output.astype('uint8')
302 |
303 | img = Image.fromarray(new_output, mode='YCbCr')
304 | img = img.convert('RGB')
305 | img.save(image_path)
306 |
307 | def get_image_batch(train_list,offset,batch_size):
308 | target_list = train_list[offset:offset+batch_size]
309 | input_list = []
310 | gt_list = []
311 | inputcbcr_list = []
312 | for pair in target_list:
313 | input_img = Image.open(pair[0])
314 | gt_img = Image.open(pair[1])
315 |
316 | # crop images to the desired size.
317 | input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image")
318 |
319 | # focus on the Y channel only
320 | input_imgY, input_imgCbCr = img2y(input_img)
321 | gt_imgY, gt_imgCbCr = img2y(gt_img)
322 |
323 | #input_imgY = normalize(input_imgY)
324 | #gt_imgY = normalize(gt_imgY)
325 |
326 | input_list.append(input_imgY)
327 | gt_list.append(gt_imgY)
328 | inputcbcr_list.append(input_imgCbCr)
329 |
330 | input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
331 | gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
332 |
333 | return input_list, gt_list, inputcbcr_list
334 |
335 | def save_test_img(inputY, inputCbCr, path):
336 | assert len(inputY.shape) == 4, "the tensor Y's shape is %s"%inputY.shape
337 | assert inputY.shape[0] == 1, "the first (batch) component must be 1; other cases are not implemented. {}".format(inputY.shape)
338 |
339 | inputY = np.squeeze(inputY, axis=0)
340 | inputY = inputY.astype('uint8')
341 |
342 | inputCbCr = inputCbCr.astype('uint8')
343 |
344 | output_concat = np.concatenate((inputY, inputCbCr), axis=2)
345 | img = Image.fromarray(output_concat, mode='YCbCr')
346 | img = img.convert('RGB')
347 | img.save(path)
348 |
349 | def psnr(hr_image, sr_image, max_value=255.0):
350 | eps = 1e-10
351 | if((type(hr_image)==type(np.array([]))) or (type(hr_image)==type([]))):
352 | hr_image_data = np.asarray(hr_image, 'float32')
353 | sr_image_data = np.asarray(sr_image, 'float32')
354 |
355 | diff = sr_image_data - hr_image_data
356 | mse = np.mean(diff*diff)
357 | mse = np.maximum(eps, mse)
358 | return float(10*math.log10(max_value*max_value/mse))
359 | else:
360 | assert len(hr_image.shape)==4 and len(sr_image.shape)==4
361 | diff = hr_image - sr_image
362 | mse = tf.reduce_mean(tf.square(diff))
363 | mse = tf.maximum(mse, eps)
364 | return 10*tf.log(max_value*max_value/mse)/math.log(10)
365 |
366 | def getBeforeNNBlockDict(img, w, h):
367 | # print(img[:1500, : 2000])
368 | blockSize = 1000
369 | padding = 32
370 | yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1)
371 | xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1)
372 | tempImg = {}
373 | i = 0
374 | for yBlock in range(yBlockNum):
375 | for xBlock in range(xBlockNum):
376 | if yBlock == 0:
377 | if xBlock == 0:
378 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding]
379 | elif xBlock == xBlockNum - 1:
380 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w]
381 | else:
382 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
383 | elif yBlock == yBlockNum - 1:
384 | if xBlock == 0:
385 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding]
386 | elif xBlock == xBlockNum - 1:
387 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w]
388 | else:
389 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
390 | elif xBlock == 0:
391 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding]
392 | elif xBlock == xBlockNum - 1:
393 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w]
394 | else:
395 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding,
396 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
397 | i += 1
398 | l = tempImg[0].astype('uint8')
399 | l = Image.fromarray(l)
400 | l.show()
401 |
--------------------------------------------------------------------------------
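`getWH` recovers the frame size from the `WxH` token embedded in the YUV file names this repo uses. Two hypothetical names exercising both parsing branches:

```python
from UTILS_MF_ra import getWH

print(getWH("FourPeople_1280x720_0.yuv"))  # token at split('_')[-2] -> (1280, 720)
print(getWH("BQMall_832x480.yuv"))         # token at split('_')[1]  -> (832, 480)
```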
/test/test_MF/__pycache__/MFVDSingle_ra.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/__pycache__/MFVDSingle_ra.cpython-35.pyc
--------------------------------------------------------------------------------
/test/test_MF/__pycache__/UTILS_MF_ra.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/__pycache__/UTILS_MF_ra.cpython-35.pyc
--------------------------------------------------------------------------------
/test/test_MF/__pycache__/yangNet.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/__pycache__/yangNet.cpython-35.pyc
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.index
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/MF_qp22_ra_01222_LMVE_double_149.ckpt.meta
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp22_ra_01222_LMVE_double/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "MF_qp22_ra_01222_MFVD_double_149.ckpt"
2 | all_model_checkpoint_paths: "MF_qp22_ra_01222_MFVD_double_149.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.index
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/MF_qp27_ra_01202_LMVE_double_201.ckpt.meta
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp27_ra_01202_LMVE_double/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "MF_qp27_ra_01202_MFVD_double_201.ckpt"
2 | all_model_checkpoint_paths: "MF_qp27_ra_01202_MFVD_double_201.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.index
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/MF_qp32_ra_01262_LMVE_double_138.ckpt.meta
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp32_ra_01262_LMVE_double/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "MF_qp32_ra_01262_MFVD_double_138.ckpt"
2 | all_model_checkpoint_paths: "MF_qp32_ra_01262_MFVD_double_138.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.index
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/MF_qp37_ra_01121_402.ckpt.meta
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp37_ra_01121/a.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_MF/checkpoints/MF_qp37_ra_01121/a.txt
--------------------------------------------------------------------------------
/test/test_MF/checkpoints/MF_qp37_ra_01121/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "MF_qp37_ra_01121_002.ckpt"
2 | all_model_checkpoint_paths: "MF_qp37_ra_01121_002.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_Single/LSVE_ra_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def model_single(inputLowData_tensor):
6 | # with tf.device("/gpu:0"):
7 | input_current = inputLowData_tensor # lowData
8 |
9 | # running feature map; assigned after the input convolutions below.
10 | tensor = None
11 |
12 | # ----------------------------------------------Frame 0--------------------------------------------------------
13 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64],
14 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
15 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0))
16 | input_low_tensor = tf.nn.relu(
17 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'),
18 | input_current_b))
19 |
20 | # ----------------------------------3x3 conv (the "1x1" variable names are historical)----------------
21 | input_3x3_w = tf.get_variable("input_1x1_w", [3, 3, 64, 64],
22 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 64 / 9)))
23 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0))
24 | input_3x3_tensor = tf.nn.relu(
25 | tf.nn.bias_add(tf.nn.conv2d(input_low_tensor, input_3x3_w, strides=[1, 1, 1, 1], padding='SAME'),
26 | input_1x1_b))
27 | tensor = input_3x3_tensor
28 |
29 | # --------------------------------------start iteration for last layers----------------------------------
30 | convId = 0
31 | for i in range(18):
32 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64],
33 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
34 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
35 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0))
36 | convId += 1
37 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b))
38 |
39 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1],
40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
41 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0))
42 | convId += 1
43 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
44 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)
45 | tensor = tf.add(tensor, input_current)
46 | return tensor
47 |
--------------------------------------------------------------------------------
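Because the graph is always built through `tf.make_template('shared_model', ...)`, every variable is saved under the `shared_model/` prefix. A quick way to verify what one of the shipped checkpoints contains (path assumed relative to `test_Single/`):

```python
import tensorflow as tf

ckpt = './checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt'
reader = tf.train.NewCheckpointReader(ckpt)
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)  # e.g. shared_model/conv_00_w [3, 3, 64, 64]
```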
/test/test_Single/LSVE_test.py:
--------------------------------------------------------------------------------
1 | import time
2 | from LSVE_ra_model import model_single
3 | from UTILS_MF_ra_ALL_Single import *
4 | import itertools
5 |
6 | tf.logging.set_verbosity(tf.logging.WARN)
7 |
8 | EXP_DATA = "SF_qp27_ra_01232_LSVE" # single checkpoints
9 | TESTOUT_PATH = "./testout/%s/"%(EXP_DATA)
10 | MODEL_PATH = "./checkpoints/%s/"%(EXP_DATA)
11 | QP_LOWDATA_PATH = r'E:\MF\HEVC_TestSequenct\rec\qp27\ra\D\BasketballPass_qp22_ai_416x240' # low frames
12 | GT_PATH = r"E:\MF\HEVC_TestSequenct\org\D\BasketballPass_416x240" # ground truth
13 | DL_path = r'E:\MF\HEVC_TestSequenct\DL_rec\Single\qp27'
14 | OUT_DATA_PATH = "./outdata/%s/"%(EXP_DATA)
15 | NOFILTER = {'q22':42.2758, 'q27':38.9788, 'qp32':35.8667, 'q37':32.8257,'qp37':32.8257}
16 |
17 | # Ground truth images dir should be the 2nd component of 'fileOrDir' if 2 components are given.
18 |
19 | ## Cb/Cr components are not enhanced; they are copied from the input frames.
20 | def prepare_test_data(fileOrDir):
21 | if not os.path.exists(TESTOUT_PATH):
22 | os.mkdir(TESTOUT_PATH)
23 |
24 | singleData_ycbcr = []
25 | singleGT_y = []
26 | fileName_list = []
27 |
28 | # The input is a pair of directories: [low_frames_dir, ground_truth_dir].
29 | if len(fileOrDir) == 2:
30 | # return the whole absolute path.
31 | fileName_list = load_file_list(fileOrDir[0])
32 | # single_list layout (whole paths):
33 | # [[low1, label1], [low2, label2], ....]
34 | single_list = get_test_list(load_file_list(fileOrDir[0]), load_file_list(fileOrDir[1]))
35 |
36 | # single_list # [[low1, label1], [2,2] ....]
37 | for pair in single_list:
38 | lowData_list = []
39 | lowData_imgY = c_getYdata(pair[0])
40 | CbCr = c_getCbCr(pair[0])
41 | gt_imgY = c_getYdata(pair[1])
42 |
43 | # normalize
44 | lowData_imgY = normalize(lowData_imgY)
45 |
46 | lowData_imgY = np.resize(lowData_imgY, (1, lowData_imgY.shape[0], lowData_imgY.shape[1], 1))
47 | gt_imgY = np.resize(gt_imgY, (1, gt_imgY.shape[0], gt_imgY.shape[1], 1))
48 |
49 | lowData_list.append([lowData_imgY, CbCr])
50 | singleData_ycbcr.append(lowData_list)
51 | singleGT_y.append(gt_imgY)
52 |
53 | else:
54 | print("Invalid Inputs...!tjc!")
55 | exit(0)
56 |
57 | return singleData_ycbcr, singleGT_y, fileName_list
58 |
59 | def test_all_ckpt(modelPath, fileOrDir):
60 | max = [0, 0]
61 | tem = [f for f in os.listdir(modelPath) if 'data' in f]
62 | ckptFiles = sorted([r.split('.data')[0] for r in tem])
63 | re_psnr = tf.placeholder('float32')
64 | tf.summary.scalar('re_psnr', re_psnr)
65 |
66 | with tf.Session() as sess:
67 | lowData_tensor = tf.placeholder(tf.float32, shape=(1, None, None, 1))
68 | shared_model = tf.make_template('shared_model', model_single)
69 | output_tensor = shared_model(lowData_tensor)
70 | output_tensor = tf.clip_by_value(output_tensor, 0., 1.)
71 | output_tensor = output_tensor * 255
72 |
73 | merged = tf.summary.merge_all()
74 | file_writer = tf.summary.FileWriter(OUT_DATA_PATH, sess.graph)
75 | saver = tf.train.Saver(tf.global_variables())
76 | sess.run(tf.global_variables_initializer())
77 |
78 | singleData_ycbcr, singleGT_y, fileName_list = prepare_test_data(fileOrDir)
79 |
80 | for ckpt in ckptFiles:
81 | epoch = int(ckpt.split('_')[-1].split('.')[0])
82 | if epoch != 169:
83 | continue
84 |
85 | saver.restore(sess, os.path.join(modelPath,ckpt))
86 | total_time, total_psnr = 0, 0
87 | total_imgs = len(fileName_list)
88 | count = 0
89 | for i in range(total_imgs):
90 | if i%4==0:
91 | continue
92 | count += 1
93 |
94 | # singleData_ycbcr[i] layout: [[lowDataY, CbCr]]
95 | lowDataY = singleData_ycbcr[i][0][0]
96 | imgLowCbCr = singleData_ycbcr[i][0][1]
97 | gtY = singleGT_y[i] if singleGT_y else 0
98 |
99 | #### split the frame into blocks here if the GPU runs out of memory. ####
100 |
101 | start_t = time.time()
102 | out = sess.run(output_tensor, feed_dict={lowData_tensor: lowDataY})
103 | out = np.around(out)
104 | out = out.astype('int')
105 | out = np.reshape(out, [1, out.shape[1], out.shape[2], 1])
106 | hevc = psnr(lowDataY * 255.0, gtY)
107 | duration_t = time.time() - start_t
108 | total_time += duration_t
109 | Y = np.reshape(out, [out.shape[1], out.shape[2]])
110 | Y = np.array(list(itertools.chain.from_iterable(Y)))
111 | U = imgLowCbCr[0]
112 | V = imgLowCbCr[1]
113 | creatPath = os.path.join(DL_path, fileName_list[i].split('\\')[-2])
114 | if not os.path.exists(creatPath):
115 | os.mkdir(creatPath)
116 |
117 | if singleGT_y:
118 | p = psnr(out, gtY)
119 | path = os.path.join(DL_path,
120 | fileName_list[i].split('\\')[-2],
121 | fileName_list[i].split('\\')[-1].split('.')[0]) + '_%.4f' % (p - hevc) + '.yuv'
122 |
123 | YUV = np.concatenate((Y, U, V))
124 | YUV = YUV.astype('uint8')
125 | YUV.tofile(path)
126 |
127 | total_psnr += p
128 | print("qp??\tepoch:%d\t%s\t%.4f\n" % (epoch, fileName_list[i], p))
129 | #print("took:%.2fs\t psnr:%.2f name:%s"%(duration_t, p, save_path))
130 |
131 |
132 |
133 | print("AVG_DURATION:%.2f\tAVG_PSNR:%.4f"%(total_time/total_imgs, total_psnr / count))
134 | print('count:', count)
135 | # avg_psnr = total_psnr/total_imgs
136 | avg_psnr = total_psnr / count
137 | avg_duration = (total_time/total_imgs)
138 | if avg_psnr > max[0]:
139 | max[0] = avg_psnr
140 | max[1] = epoch
141 |
142 | if __name__ == '__main__':
143 | test_all_ckpt(MODEL_PATH, [QP_LOWDATA_PATH, GT_PATH])
--------------------------------------------------------------------------------
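Both test scripts serialize each enhanced frame back as planar YUV 4:2:0: the enhanced Y plane followed by the untouched U and V planes copied from the decoded input. A self-contained sketch with dummy data (dimensions chosen to match *BasketballPass*):

```python
import numpy as np

w, h = 416, 240                                  # BasketballPass frame size
Y = np.random.randint(0, 256, size=h * w)        # enhanced luma, flattened row-major
U = np.random.randint(0, 256, size=h * w // 4)   # chroma planes copied from the input
V = np.random.randint(0, 256, size=h * w // 4)
yuv = np.concatenate((Y, U, V)).astype('uint8')
yuv.tofile('frame_416x240.yuv')                  # planar Y, then U, then V
```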
/test/test_Single/UTILS_MF_ra_ALL_Single.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import math, os, random, re
4 | from PIL import Image
5 | BATCH_SIZE = 64
6 | PATCH_SIZE = (64, 64)
7 |
8 | # Each training batch is drawn from a single picture, so the sampling below is designed to make the training set more diverse.
9 | def normalize(x):
10 | x = x / 255.
11 | return truncate(x, 0., 1.)
12 |
13 | def denormalize(x):
14 | x = x * 255.
15 | return truncate(x, 0., 255.)
16 |
17 | def truncate(input, min, max):
18 | input = np.where(input > min, input, min)
19 | input = np.where(input < max, input, max)
20 | return input
21 |
22 | def remap(input):
23 | input = 16+219/255*input
24 | return truncate(input, 16.0, 235.0)
25 |
26 | def deremap(input):
27 | input = (input-16)*255/219
28 | return truncate(input, 0.0, 255.0)
29 |
30 | # return the whole absolute path.
31 | def load_file_list(directory):
32 | list = []
33 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]:
34 | list.append(os.path.join(directory,filename))
35 | return list
36 |
37 | def searchHighData(currentLowDataIndex, highDataList, highIndexList):
38 | searchOffset = 3
39 | searchedHighDataIndexList = []
40 | searchedHighData = []
41 | for i in range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1):
42 | if i in highIndexList:
43 | searchedHighDataIndexList.append(i)
44 | assert len(searchedHighDataIndexList) == 2, 'search method has an error!'
45 | for tempData in highDataList:
46 | if int(os.path.basename(tempData).split('.')[0].split('_')[-1]) \
47 | in searchedHighDataIndexList:
48 | searchedHighData.append(tempData)
49 | return searchedHighData
50 |
51 |
52 | # return like this"[[[high1Data, lowData], label], [[2, 7, 8], 22], [[3, 8, 9], 33]]" with the whole path.
53 | def get_test_list2(highDataList, lowDataList, labelList):
54 | assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList))
55 |
56 | # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]
57 | highIndexList = [q for q in range(49) if q % 4 == 0]
58 | test_list = []
59 | for tempDataPath in lowDataList:
60 | tempData = []
61 | temp = []
62 |         # NOTE: this indexing may need to change for a different dataset layout.
63 | currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1])
64 | searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList)
65 | tempData.append(searchedHighData[0])
66 | tempData.append(tempDataPath)
67 | tempData.append(searchedHighData[1])
68 |
69 | i = list(lowDataList).index(tempDataPath)
70 |
71 | temp.append(tempData)
72 | temp.append(labelList[i])
73 |
74 | test_list.append(temp)
75 | return test_list
76 |
77 | def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList):
78 | tempData = []
79 | temp = []
80 | test_list = []
81 | for i in range(len(lowDataList)):
82 | tempData.append(high1DataList[i])
83 | tempData.append(lowDataList[i])
84 | tempData.append(high2DataList[i])
85 | temp.append(tempData)
86 | temp.append(labelList[i])
87 |
88 | test_list.append(temp)
89 | return test_list
90 |
91 | # returns pairs like [lowData, label] for frames whose index matches the label's.
92 | def get_test_list(lowDataList, labelList):
93 | singleTest_list = []
94 |
95 | for lowdata in lowDataList:
96 | tempData = []
97 |
98 | tempData.append(lowdata)
99 | labelIndex = list(lowDataList).index(lowdata)
100 | if int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) == \
101 | int(os.path.basename(labelList[labelIndex]).split('.')[0].split('_')[-1]):
102 | tempData.append(labelList[labelIndex])
103 | singleTest_list.append(tempData)
104 |
105 | return singleTest_list
106 |
107 | # returns entries like [[high1Data, lowData, high2Data], label] with full paths, e.g. [[2, 7, 8], 22], [[3, 8, 9], 33].
108 | def get_train_list(high1DataList, lowDataList, high2DataList, labelList):
109 | assert len(lowDataList) == len(high1DataList) == len(labelList) == len(high2DataList), \
110 | "low:%d, high1:%d, label:%d, high2:%d"%(len(lowDataList), len(high1DataList), len(labelList), len(high2DataList))
111 |
112 | train_list = []
113 | for i in range(len(labelList)):
114 | tempData = []
115 | temp = []
116 |         # NOTE: this indexing may need to change for a different dataset layout.
117 | if int(os.path.basename(high1DataList[i]).split('_')[-1].split('.')[0]) + 4 == \
118 | int(os.path.basename(lowDataList[i]).split('_')[-1].split('.')[0]) + 2 == \
119 | int(os.path.basename(high2DataList[i]).split('_')[-1].split('.')[0]):
120 | tempData.append(high1DataList[i])
121 | tempData.append(lowDataList[i])
122 | tempData.append(high2DataList[i])
123 | temp.append(tempData)
124 | temp.append(labelList[i])
125 |
126 | else:
127 |             raise Exception('frame indices of the high1/low/high2/label files do not match')
128 | train_list.append(temp)
129 | return train_list
130 |
131 | def prepare_nn_data(train_list):
132 | batchSizeRandomList = random.sample(range(0,len(train_list)), 8)
133 | gt_list = []
134 | high1Data_list = []
135 | lowData_list = []
136 | high2Data_list = []
137 | for i in batchSizeRandomList:
138 | high1Data_image = c_getYdata(train_list[i][0][0])
139 | lowData_image = c_getYdata(train_list[i][0][1])
140 | high2Data_image = c_getYdata(train_list[i][0][2])
141 | gt_image = c_getYdata(train_list[i][1])
142 | for j in range(0, 8):
143 |             # crop images to the desired size.
144 | high1Data_imgY, lowData_imgY, high2Data_imgY, gt_imgY = \
145 | crop(high1Data_image, lowData_image, high2Data_image, gt_image, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray")
146 | #normalize
147 | high1Data_imgY = normalize(high1Data_imgY)
148 | lowData_imgY = normalize(lowData_imgY)
149 | high2Data_imgY = normalize(high2Data_imgY)
150 | gt_imgY = normalize(gt_imgY)
151 |
152 | high1Data_list.append(high1Data_imgY)
153 | lowData_list.append(lowData_imgY)
154 | high2Data_list.append(high2Data_imgY)
155 | gt_list.append(gt_imgY)
156 |
157 | high1Data_list = np.resize(high1Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
158 | lowData_list = np.resize(lowData_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
159 | high2Data_list = np.resize(high2Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
160 | gt_list = np.resize(gt_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
161 |
162 | return high1Data_list, lowData_list, high2Data_list, gt_list
163 |
164 | def getWH(yuvfileName):
165 | deyuv=re.compile(r'(.+?)\.')
166 |     deyuvFilename = deyuv.findall(yuvfileName)[0]  # filename with the .yuv suffix stripped
167 | if 'x' in os.path.basename(deyuvFilename).split('_')[-2]:
168 | wxh = os.path.basename(deyuvFilename).split('_')[-2]
169 | elif 'x' in os.path.basename(deyuvFilename).split('_')[1]:
170 | wxh = os.path.basename(deyuvFilename).split('_')[1]
171 | else:
172 | raise Exception('do not find wxh')
173 | w, h = wxh.split('x')
174 | return int(w), int(h)
175 |
176 | def c_getCbCr(path):
177 | w, h = getWH(path)
178 | CbCr = []
179 | with open(path, 'rb+') as file:
180 | y = file.read(h * w)
181 | if y == b'':
182 | return ''
183 | u = file.read(h * w // 4)
184 | v = file.read(h * w // 4)
185 | # convert string-list to int-list.
186 | u = list(map(int, u))
187 | v = list(map(int, v))
188 | CbCr.append(u)
189 | CbCr.append(v)
190 | return CbCr
191 |
192 |
193 | def getYdata(path, size):
194 | w = size[0]
195 | h = size[1]
196 | with open(path, 'rb') as fp:
197 | fp.seek(0, 0)
198 | Yt = fp.read()
199 | tem = Image.frombytes('L', [w, h], Yt)
200 |
201 | Yt = np.asarray(tem, dtype='float32')
202 | return Yt
203 |
204 |
205 | def c_getYdata(path):
206 | return getYdata(path, getWH(path))
207 |
208 | def img2y(input_img):
209 | if np.asarray(input_img).shape[2] == 3:
210 | input_imgY = input_img.convert('YCbCr').split()[0]
211 | input_imgCb, input_imgCr = input_img.convert('YCbCr').split()[1:3]
212 |
213 | input_imgY = np.asarray(input_imgY, dtype='float32')
214 | input_imgCb = np.asarray(input_imgCb, dtype='float32')
215 | input_imgCr = np.asarray(input_imgCr, dtype='float32')
216 |
217 |
218 |         # Concatenate the Cb and Cr components for convenience; they are always used as a pair.
219 | input_imgCb = np.expand_dims(input_imgCb,2)
220 | input_imgCr = np.expand_dims(input_imgCr,2)
221 | input_imgCbCr = np.concatenate((input_imgCb, input_imgCr), axis=2)
222 |
223 | elif np.asarray(input_img).shape[2] == 1:
224 | print("This image has one channal only.")
225 | #If the num of channal is 1, remain.
226 | input_imgY = input_img
227 | input_imgCbCr = None
228 | else:
229 | print("The num of channal is neither 3 nor 1.")
230 | exit()
231 | return input_imgY, input_imgCbCr
232 |
233 | # def crop(input_image, gt_image, patch_width, patch_height, img_type):
234 | def crop(high1Data_image, lowData_image, high2Data_image, gt_image, patch_width, patch_height, img_type):
235 | assert type(high1Data_image) == type(gt_image) == type(lowData_image) == type(high2Data_image), "types are different."
236 | high1Data_cropped = []
237 | lowData_cropped = []
238 | high2Data_cropped = []
239 | gt_cropped = []
240 |
241 | # return a ndarray object
242 | if img_type == "ndarray":
243 | in_row_ind = random.randint(0,high1Data_image.shape[0]-patch_width)
244 | in_col_ind = random.randint(0,high1Data_image.shape[1]-patch_height)
245 |
246 | high1Data_cropped = high1Data_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height]
247 | lowData_cropped = lowData_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height]
248 | high2Data_cropped = high2Data_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height]
249 | gt_cropped = gt_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height]
250 |
251 | #return an "Image" object
252 | elif img_type == "Image":
253 | pass
254 | return high1Data_cropped, lowData_cropped, high2Data_cropped, gt_cropped
255 |
256 | def save_images(inputY, inputCbCr, size, image_path):
257 | """Save mutiple images into one single image.
258 |
259 | # Parameters
260 | # -----------
261 | # images : numpy array [batch, w, h, c]
262 | # size : list of two int, row and column number.
263 | # number of images should be equal or less than size[0] * size[1]
264 | # image_path : string.
265 | #
266 | # Examples
267 | # ---------
268 | # # >>> images = np.random.rand(64, 100, 100, 3)
269 | # # >>> tl.visualize.save_images(images, [8, 8], 'temp.png')
270 | """
271 | def merge(images, size):
272 | h, w = images.shape[1], images.shape[2]
273 | img = np.zeros((h * size[0], w * size[1], 3))
274 | for idx, image in enumerate(images):
275 | i = idx % size[1]
276 | j = idx // size[1]
277 | img[j*h:j*h+h, i*w:i*w+w, :] = image
278 | return img
279 |
280 | inputY = inputY.astype('uint8')
281 | inputCbCr = inputCbCr.astype('uint8')
282 | output_concat = np.concatenate((inputY, inputCbCr), axis=3)
283 |
284 | assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat))
285 |
286 | new_output = merge(output_concat, size)
287 |
288 | new_output = new_output.astype('uint8')
289 |
290 | img = Image.fromarray(new_output, mode='YCbCr')
291 | img = img.convert('RGB')
292 | img.save(image_path)
293 |
294 | def get_image_batch(train_list,offset,batch_size):
295 | target_list = train_list[offset:offset+batch_size]
296 | input_list = []
297 | gt_list = []
298 | inputcbcr_list = []
299 | for pair in target_list:
300 | input_img = Image.open(pair[0])
301 | gt_img = Image.open(pair[1])
302 |
303 |         # crop images to the desired size.
304 |         input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image")  # legacy call: expects the old 5-argument crop() commented out above
305 | 
306 |         # focus on the Y channel only
307 | input_imgY, input_imgCbCr = img2y(input_img)
308 | gt_imgY, gt_imgCbCr = img2y(gt_img)
309 | input_list.append(input_imgY)
310 | gt_list.append(gt_imgY)
311 | inputcbcr_list.append(input_imgCbCr)
312 |
313 | input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
314 | gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
315 |
316 | return input_list, gt_list, inputcbcr_list
317 |
318 | def save_test_img(inputY, inputCbCr, path):
319 | assert len(inputY.shape) == 4, "the tensor Y's shape is %s"%inputY.shape
320 |     assert inputY.shape[0] == 1, "the first dimension must be 1; other batch sizes are not supported. {}".format(inputY.shape)
321 |
322 | inputY = np.squeeze(inputY, axis=0)
323 | inputY = inputY.astype('uint8')
324 |
325 | inputCbCr = inputCbCr.astype('uint8')
326 |
327 | output_concat = np.concatenate((inputY, inputCbCr), axis=2)
328 | img = Image.fromarray(output_concat, mode='YCbCr')
329 | img = img.convert('RGB')
330 | img.save(path)
331 |
332 | def psnr(hr_image, sr_image, max_value=255.0):
333 | eps = 1e-10
334 | if((type(hr_image)==type(np.array([]))) or (type(hr_image)==type([]))):
335 | hr_image_data = np.asarray(hr_image, 'float32')
336 | sr_image_data = np.asarray(sr_image, 'float32')
337 |
338 | diff = sr_image_data - hr_image_data
339 | mse = np.mean(diff*diff)
340 | mse = np.maximum(eps, mse)
341 | return float(10*math.log10(max_value*max_value/mse))
342 | else:
343 | assert len(hr_image.shape)==4 and len(sr_image.shape)==4
344 | diff = hr_image - sr_image
345 | mse = tf.reduce_mean(tf.square(diff))
346 | mse = tf.maximum(mse, eps)
347 | return 10*tf.log(max_value*max_value/mse)/math.log(10)
348 |
349 | def getBeforeNNBlockDict(img, w, h):
350 | blockSize = 1000
351 | padding = 32
352 | yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1)
353 | xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1)
354 | tempImg = {}
355 | i = 0
356 | for yBlock in range(yBlockNum):
357 | for xBlock in range(xBlockNum):
358 | if yBlock == 0:
359 | if xBlock == 0:
360 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding]
361 | elif xBlock == xBlockNum - 1:
362 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w]
363 | else:
364 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
365 | elif yBlock == yBlockNum - 1:
366 | if xBlock == 0:
367 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding]
368 | elif xBlock == xBlockNum - 1:
369 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w]
370 | else:
371 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
372 | elif xBlock == 0:
373 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding]
374 | elif xBlock == xBlockNum - 1:
375 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w]
376 | else:
377 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding,
378 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
379 |             i += 1  # fix: was `i += i`, which left the block index stuck at 0
380 |     l = tempImg[i - 1].astype('uint8')  # preview the last block as a quick sanity check
381 |     l = Image.fromarray(l)
382 |     l.show()
383 |     return tempImg  # the name implies the dict is returned; the original fell through with None
--------------------------------------------------------------------------------
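The `psnr` helper above dispatches on argument type: numpy arrays (or lists) take the numpy branch, while 4-D tensors take the TensorFlow branch. A minimal numpy-branch sanity check, with illustrative 8-bit luma values:

```python
import numpy as np

ref = np.full((64, 64), 128.0, dtype='float32')      # ground-truth luma patch
noisy = ref + np.random.normal(0.0, 2.0, ref.shape)  # simulated coding noise, sigma = 2
print('PSNR: %.2f dB' % psnr(ref, noisy))            # about 42 dB for this noise level
```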
/test/test_Single/__pycache__/MFVDSingle_ra.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/__pycache__/MFVDSingle_ra.cpython-35.pyc
--------------------------------------------------------------------------------
/test/test_Single/__pycache__/UTILS_MF_ra_ALL_Single.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/__pycache__/UTILS_MF_ra_ALL_Single.cpython-35.pyc
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.index
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/SF_qp22_ra_01251_LSVE_105.ckpt.meta
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp22_ra_01251_LSVE/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "SF_qp22_ra_01251_VDSRSingle_105.ckpt"
2 | all_model_checkpoint_paths: "SF_qp22_ra_01251_VDSRSingle_105.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.index
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/SF_qp27_ra_01232_LSVE_169.ckpt.meta
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp27_ra_01232_LSVE/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "SF_qp27_ra_01232_VDSRSingle_169.ckpt"
2 | all_model_checkpoint_paths: "SF_qp27_ra_01232_VDSRSingle_169.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.index
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/SF_qp32_ra_01242_LSVE_185.ckpt.meta
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp32_ra_01242_LSVE/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "SF_qp32_ra_01242_VDSRSingle_185.ckpt"
2 | all_model_checkpoint_paths: "SF_qp32_ra_01242_VDSRSingle_185.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.index
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/MF_qp37_ra_01174_LSVE_291.ckpt.meta
--------------------------------------------------------------------------------
/test/test_Single/checkpoints/SF_qp37_ra_01174_LSVE/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "MF_qp37_ra_01174_VDSRSingle_291.ckpt"
2 | all_model_checkpoint_paths: "MF_qp37_ra_01174_VDSRSingle_291.ckpt"
3 |
--------------------------------------------------------------------------------
/test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718071.AIR-PC:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718071.AIR-PC
--------------------------------------------------------------------------------
/test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718367.AIR-PC:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IVC-Projects/LMVE/1b33d2c6734d534468761443046c3218615423ef/test/test_Single/outdata/SF_qp27_ra_01232_LSVE/events.out.tfevents.1549718367.AIR-PC
--------------------------------------------------------------------------------
/train/train_MF/LMVE_ra_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def model_double(inputHigh1Data_tensor, inputLowData_tensor, inputHigh2Data_tensor):
6 | # with tf.device("/gpu:0"):
7 | input_before = inputHigh1Data_tensor # highData
8 | input_current = inputLowData_tensor # lowData
9 |
10 | input_after = inputHigh2Data_tensor
11 |
12 |     # placeholder; `tensor` is assigned once the fusion layers below are built.
13 | tensor = None
14 |
15 | # ----------------------------------------------Frame -1--------------------------------------------------------
16 | input_before_w = tf.get_variable("input_before_w", [5, 5, 1, 64],
17 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
18 | input_before_b = tf.get_variable("input_before_b", [64], initializer=tf.constant_initializer(0))
19 | input_high1_tensor = tf.nn.relu(
20 | tf.nn.bias_add(tf.nn.conv2d(input_before, input_before_w, strides=[1, 1, 1, 1], padding='SAME'), input_before_b))
21 | # ----------------------------------------------Frame 0--------------------------------------------------------
22 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64],
23 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
24 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0))
25 | input_low_tensor = tf.nn.relu(
26 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'),
27 | input_current_b))
28 |
29 | # ----------------------------------------------Frame 1--------------------------------------------------------
30 | input_after_w = tf.get_variable("input_after_w", [5, 5, 1, 64],
31 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
32 | input_after_b = tf.get_variable("input_after_b", [64], initializer=tf.constant_initializer(0))
33 | input_high2_tensor = tf.nn.relu(
34 | tf.nn.bias_add(tf.nn.conv2d(input_after, input_after_w, strides=[1, 1, 1, 1], padding='SAME'),
35 | input_after_b))
36 | # ------------------------------------------Frame -1\0\1 concat------------------------------------------
37 | input_tensor_Concat = tf.concat([input_high1_tensor, input_low_tensor, input_high2_tensor], axis=3)
38 |     # ---------------------------- 1x1 conv to fuse the three streams and reduce the number of parameters ----------------------------
39 | input_1x1_w = tf.get_variable("input_1x1_w", [1, 1, 192, 64],
40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 192)))
41 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0))
42 | input_1x1_tensor = tf.nn.relu(
43 | tf.nn.bias_add(tf.nn.conv2d(input_tensor_Concat, input_1x1_w, strides=[1, 1, 1, 1], padding='SAME'),
44 | input_1x1_b))
45 | tensor = input_1x1_tensor
46 |
47 |     # -------------------------------- stack the remaining 3x3 conv layers --------------------------------
48 | convId = 0
49 | for i in range(18):
50 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64],
51 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
52 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
53 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0))
54 | convId += 1
55 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b))
56 |
57 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1],
58 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
59 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0))
60 | convId += 1
61 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
62 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)
63 |
64 | tensor = tf.add(tensor, input_current)
65 | return tensor
66 |
--------------------------------------------------------------------------------
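A minimal sketch of how `model_double` is instantiated at graph-construction time; the training script below does the same via `tf.make_template`, which makes repeated calls share one set of variables. The placeholder shapes here are illustrative:

```python
import tensorflow as tf
from LMVE_ra_model import model_double

high1 = tf.placeholder('float32', (None, 64, 64, 1))  # preceding high-quality frame
low = tf.placeholder('float32', (None, 64, 64, 1))    # current moderate-quality frame
high2 = tf.placeholder('float32', (None, 64, 64, 1))  # following high-quality frame

shared_model = tf.make_template('shared_model', model_double)
enhanced = shared_model(high1, low, high2)  # residual-corrected Y plane, same shape as `low`
```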
/train/train_MF/LMVE_ra_train.py:
--------------------------------------------------------------------------------
1 | import argparse, time
2 | from LMVE_ra_model import model_double
3 | from UTILS_MF_ra import *
4 |
5 | tf.logging.set_verbosity(tf.logging.WARN)
6 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
7 |
8 | EXP_DATA = 'MF_qp37_ra_02052_yangNet_double' # experiment tag; names the log/checkpoint/sample directories below
9 | LOW_DATA_PATH = r"E:\MF\trainSet\qp37\low_data" # low frames
10 | HIGH1_DATA_PATH = r"E:\MF\trainSet\qp37\high1_wraped_Y" # high frames1
11 | HIGH2_DATA_PATH = r"E:\MF\trainSet\qp37\high2_wraped_Y" # high frames2
12 | LABEL_PATH = r"E:\MF\trainSet\qp37\label_s" # label (ground-truth) frames
13 | LOG_PATH = "./logs/%s/"%(EXP_DATA)
14 | CKPT_PATH = "./checkpoints/%s/"%(EXP_DATA)
15 | SAMPLE_PATH = "./samples/%s/"%(EXP_DATA)
16 | PATCH_SIZE = (64, 64)
17 | BATCH_SIZE = 64
18 | BASE_LR = 3e-4
19 | LR_DECAY_RATE = 0.2
20 | LR_DECAY_STEP = 20
21 | MAX_EPOCH = 2000
22 |
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--model_path")
26 | args = parser.parse_args()
27 | model_path = args.model_path
28 | if __name__ == '__main__':
29 |
30 |     # train_list holds entries like [[high1Data, lowData, high2Data], label] with full paths, e.g. [[3, 8, 9], 33].
31 | train_list = get_train_list(load_file_list(HIGH1_DATA_PATH), load_file_list(LOW_DATA_PATH),
32 | load_file_list(HIGH2_DATA_PATH), load_file_list(LABEL_PATH))
33 |
34 | with tf.name_scope('input_scope'):
35 | train_hight1Data = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
36 | train_lowData = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
37 | train_hight2Data = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
38 | train_gt = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
39 |
40 | shared_model = tf.make_template('shared_model', model_double)
41 | train_output = shared_model(train_hight1Data, train_lowData, train_hight2Data)
42 | train_output = tf.clip_by_value(train_output, 0., 1.)
43 | with tf.name_scope('loss_scope'):
44 | loss2 = tf.reduce_sum(tf.square(tf.subtract(train_output, train_gt)))
45 | loss1 = tf.reduce_sum(tf.abs(tf.subtract(train_output, train_gt)))
46 | W = tf.get_collection(tf.GraphKeys.WEIGHTS)
47 | for w in W:
48 | loss2 += tf.nn.l2_loss(w)*1e-4
49 |
50 | avg_loss = tf.placeholder('float32')
51 | tf.summary.scalar("avg_loss", avg_loss)
52 |
53 | global_step = tf.Variable(0, trainable=False) # len(train_list)
54 | learning_rate = tf.train.exponential_decay(BASE_LR, global_step, LR_DECAY_STEP*1000, LR_DECAY_RATE, staircase=True)
55 | tf.summary.scalar("learning rate", learning_rate)
56 |
57 | # org ---------------------------------------------------------------------------------
58 | optimizer = tf.train.AdamOptimizer(learning_rate, 0.9)
59 | opt = optimizer.minimize(loss2, global_step=global_step)
60 | saver = tf.train.Saver(max_to_keep=0)
61 |
62 | # org end------------------------------------------------------------------------------
63 | config = tf.ConfigProto()
64 | config.gpu_options.allow_growth = True
65 | with tf.Session(config=config) as sess:
66 | if not os.path.exists(LOG_PATH):
67 | os.makedirs(LOG_PATH)
68 | if not os.path.exists(os.path.dirname(CKPT_PATH)):
69 | os.makedirs(os.path.dirname(CKPT_PATH))
70 | if not os.path.exists(SAMPLE_PATH):
71 | os.makedirs(SAMPLE_PATH)
72 |
73 | merged = tf.summary.merge_all()
74 | file_writer = tf.summary.FileWriter(LOG_PATH, sess.graph)
75 |
76 | sess.run(tf.global_variables_initializer())
77 |
78 | if model_path:
79 | print("restore model...")
80 | saver.restore(sess, model_path)
81 | print("Done")
82 | for epoch in range(MAX_EPOCH):
83 | total_g_loss, n_iter = 0, 0
84 | idxOfImgs = np.random.permutation(len(train_list))
85 | epoch_time = time.time()
86 |
87 | for idx in range(1000):
88 | input_high1Data, input_lowData, input_high2Data, gt_data = prepare_nn_data(train_list)
89 | feed_dict = {train_hight1Data: input_high1Data, train_lowData: input_lowData,
90 | train_hight2Data: input_high2Data, train_gt: gt_data}
91 |
92 | _, l, output, g_step = sess.run([opt, loss2, train_output, global_step], feed_dict=feed_dict)
93 | total_g_loss += l
94 | n_iter += 1
95 | del input_high1Data, input_lowData, input_high2Data, gt_data, output
96 | lr, summary = sess.run([learning_rate, merged], {avg_loss:total_g_loss/n_iter})
97 | file_writer.add_summary(summary, epoch)
98 | tf.logging.warning("Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f"%(epoch, MAX_EPOCH, time.time()-epoch_time, total_g_loss/n_iter, lr))
99 | print("Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f"%(epoch, MAX_EPOCH, time.time()-epoch_time, total_g_loss/n_iter, lr))
100 | saver.save(sess, os.path.join(CKPT_PATH, "%s_%03d.ckpt"%(EXP_DATA, epoch)))
101 |
--------------------------------------------------------------------------------
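Training starts from random initialization by default; passing `--model_path` restores an existing checkpoint before the epoch loop. For example (the checkpoint path is illustrative):

```
python LMVE_ra_train.py
python LMVE_ra_train.py --model_path ./checkpoints/MF_qp37_ra_02052_yangNet_double/MF_qp37_ra_02052_yangNet_double_099.ckpt
```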
/train/train_MF/UTILS_MF_ra.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import math, os, random, re
4 | from PIL import Image
5 | from LMVE_ra_train import BATCH_SIZE
6 | from LMVE_ra_train import PATCH_SIZE
7 |
8 | # Because a naive batch would come from a single picture, this sampling scheme draws patches from several pictures to make the training set more diverse.
9 | def normalize(x):
10 | x = x / 255.
11 | return truncate(x, 0., 1.)
12 |
13 | def denormalize(x):
14 | x = x * 255.
15 | return truncate(x, 0., 255.)
16 |
17 | def truncate(input, min, max):
18 | input = np.where(input > min, input, min)
19 | input = np.where(input < max, input, max)
20 | return input
21 |
22 | def remap(input):
23 | input = 16+219/255*input
24 | return truncate(input, 16.0, 235.0)
25 |
26 | def deremap(input):
27 | input = (input-16)*255/219
28 | return truncate(input, 0.0, 255.0)
29 |
30 | # return the whole absolute path.
31 | def load_file_list(directory):
32 | list = []
33 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]:
34 | list.append(os.path.join(directory,filename))
35 | return list
36 |
37 | def searchHighData(currentLowDataIndex, highDataList, highIndexList):
38 | searchOffset = 3
39 | searchedHighDataIndexList = []
40 | searchedHighData = []
41 | for i in range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1):
42 | if i in highIndexList:
43 | searchedHighDataIndexList.append(i)
44 |     assert len(searchedHighDataIndexList) == 2, 'search method has an error!'
45 |     for tempData in highDataList:
46 |         if int(os.path.basename(tempData).split('.')[0].split('_')[-1]) \
47 |                 in searchedHighDataIndexList:  # membership test; the original chained == could never match two distinct indices
48 | searchedHighData.append(tempData)
49 | return searchedHighData
50 |
51 |
52 | # returns entries like [[high1Data, lowData, high2Data], label] with full paths, e.g. [[2, 7, 8], 22], [[3, 8, 9], 33].
53 | def get_test_list2(highDataList, lowDataList, labelList):
54 | assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList))
55 | # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]
56 | highIndexList = [q for q in range(49) if q % 4 == 0]
57 | test_list = []
58 | for tempDataPath in lowDataList:
59 | tempData = []
60 | temp = []
61 |         # NOTE: this indexing may need to change for a different dataset layout.
62 | currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1])
63 | searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList)
64 | tempData.append(searchedHighData[0])
65 | tempData.append(tempDataPath)
66 | tempData.append(searchedHighData[1])
67 |
68 | i = list(lowDataList).index(tempDataPath)
69 |
70 | temp.append(tempData)
71 | temp.append(labelList[i])
72 | test_list.append(temp)
73 | return test_list
74 |
75 | def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList):
76 | tempData = []
77 | temp = []
78 | test_list = []
79 | for i in range(len(lowDataList)):
80 | tempData.append(high1DataList[i])
81 | tempData.append(lowDataList[i])
82 | tempData.append(high2DataList[i])
83 | temp.append(tempData)
84 | temp.append(labelList[i])
85 |
86 | test_list.append(temp)
87 | return test_list
88 |
89 | # returns (doubleTest_list, singleTest_list): double entries look like [[h1, h2], low, label], single entries like [low, label].
90 | def get_test_list(HIGHDATA_Parent_PATH, lowDataList, labelList):
91 | doubleTest_list = []
92 | singleTest_list = []
93 | HighDirList = os.listdir(HIGHDATA_Parent_PATH)
94 | # convert string-list to int-list.
95 | HighDirList = list(map(int, HighDirList))
96 |
97 | for lowdata in lowDataList:
98 |
99 | tempData = []
100 | lowdataIndex = int(os.path.basename(lowdata).split('.')[0].split('_')[-1])
101 | if lowdataIndex % 4 != 0:
102 | if lowdataIndex in HighDirList:
103 | High_current_path = HIGHDATA_Parent_PATH + '/' + str(lowdataIndex)
104 | TWO_HighData = os.listdir(High_current_path)
105 | Two_HighData = [os.path.join(High_current_path, T) for T in TWO_HighData]
106 | tempData.append(Two_HighData)
107 | tempData.append(lowdata)
108 | labelIndex = list(lowDataList).index(lowdata)
109 | tempData.append(labelList[labelIndex])
110 | doubleTest_list.append(tempData)
111 | else:
112 | tempData.append(lowdata)
113 | labelIndex = list(lowDataList).index(lowdata)
114 | if int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) == \
115 | int(os.path.basename(labelList[labelIndex]).split('.')[0].split('_')[-1]):
116 |
117 | tempData.append(labelList[labelIndex])
118 | singleTest_list.append(tempData)
119 |
120 | return doubleTest_list, singleTest_list
121 |
122 | # returns entries like [[high1Data, lowData, high2Data], label] with full paths, e.g. [[2, 7, 8], 22], [[3, 8, 9], 33].
123 | def get_train_list(high1DataList, lowDataList, high2DataList, labelList):
124 | assert len(lowDataList) == len(high1DataList) == len(labelList) == len(high2DataList), \
125 | "low:%d, high1:%d, label:%d, high2:%d"%(len(lowDataList), len(high1DataList), len(labelList), len(high2DataList))
126 |
127 | train_list = []
128 | for i in range(len(labelList)):
129 | tempData = []
130 | temp = []
131 |         # NOTE: this indexing may need to change for a different dataset layout.
132 | if int(os.path.basename(high1DataList[i]).split('_')[-1].split('.')[0]) + 4 == \
133 | int(os.path.basename(lowDataList[i]).split('_')[-1].split('.')[0]) + 2 == \
134 | int(os.path.basename(high2DataList[i]).split('_')[-1].split('.')[0]):
135 | tempData.append(high1DataList[i])
136 | tempData.append(lowDataList[i])
137 | tempData.append(high2DataList[i])
138 | temp.append(tempData)
139 | temp.append(labelList[i])
140 |
141 | else:
142 |             raise Exception('frame indices of the high1/low/high2/label files do not match')
143 | train_list.append(temp)
144 | return train_list
145 |
146 | def prepare_nn_data(train_list):
147 | batchSizeRandomList = random.sample(range(0,len(train_list)), 8)
148 | gt_list = []
149 | high1Data_list = []
150 | lowData_list = []
151 | high2Data_list = []
152 | for i in batchSizeRandomList:
153 | high1Data_image = c_getYdata(train_list[i][0][0])
154 | lowData_image = c_getYdata(train_list[i][0][1])
155 | high2Data_image = c_getYdata(train_list[i][0][2])
156 | gt_image = c_getYdata(train_list[i][1])
157 | for j in range(0, 8):
158 |             # crop images to the desired size.
159 | high1Data_imgY, lowData_imgY, high2Data_imgY, gt_imgY = \
160 | crop(high1Data_image, lowData_image, high2Data_image, gt_image, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray")
161 |
162 | #normalize
163 | high1Data_imgY = normalize(high1Data_imgY)
164 | lowData_imgY = normalize(lowData_imgY)
165 | high2Data_imgY = normalize(high2Data_imgY)
166 | gt_imgY = normalize(gt_imgY)
167 |
168 | high1Data_list.append(high1Data_imgY)
169 | lowData_list.append(lowData_imgY)
170 | high2Data_list.append(high2Data_imgY)
171 | gt_list.append(gt_imgY)
172 |
173 | high1Data_list = np.resize(high1Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
174 | lowData_list = np.resize(lowData_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
175 | high2Data_list = np.resize(high2Data_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
176 | gt_list = np.resize(gt_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
177 |
178 | return high1Data_list, lowData_list, high2Data_list, gt_list
179 |
180 | def getWH(yuvfileName):
181 | deyuv=re.compile(r'(.+?)\.')
182 |     deyuvFilename = deyuv.findall(yuvfileName)[0]  # filename with the .yuv suffix stripped
183 | if 'x' in os.path.basename(deyuvFilename).split('_')[-2]:
184 | wxh = os.path.basename(deyuvFilename).split('_')[-2]
185 | elif 'x' in os.path.basename(deyuvFilename).split('_')[1]:
186 | wxh = os.path.basename(deyuvFilename).split('_')[1]
187 | else:
188 | # print(yuvfileName)
189 | raise Exception('do not find wxh')
190 | w, h = wxh.split('x')
191 | return int(w), int(h)
192 |
193 | def getYdata(path, size):
194 | w = size[0]
195 | h = size[1]
196 | with open(path, 'rb') as fp:
197 | fp.seek(0, 0)
198 | Yt = fp.read()
199 | tem = Image.frombytes('L', [w, h], Yt)
200 | Yt = np.asarray(tem, dtype='float32')
201 | return Yt
202 |
203 |
204 | def c_getYdata(path):
205 | return getYdata(path, getWH(path))
206 |
207 | def c_getCbCr(path):
208 | w, h = getWH(path)
209 | CbCr = []
210 | with open(path, 'rb+') as file:
211 | y = file.read(h * w)
212 | if y == b'':
213 | return ''
214 | u = file.read(h * w // 4)
215 | v = file.read(h * w // 4)
216 | # convert string-list to int-list.
217 | u = list(map(int, u))
218 | v = list(map(int, v))
219 | CbCr.append(u)
220 | CbCr.append(v)
221 | return CbCr
222 |
223 | def img2y(input_img):
224 | if np.asarray(input_img).shape[2] == 3:
225 | input_imgY = input_img.convert('YCbCr').split()[0]
226 | input_imgCb, input_imgCr = input_img.convert('YCbCr').split()[1:3]
227 | input_imgY = np.asarray(input_imgY, dtype='float32')
228 | input_imgCb = np.asarray(input_imgCb, dtype='float32')
229 | input_imgCr = np.asarray(input_imgCr, dtype='float32')
230 |
231 |         # Concatenate the Cb and Cr components for convenience; they are always used as a pair.
232 | input_imgCb = np.expand_dims(input_imgCb,2)
233 | input_imgCr = np.expand_dims(input_imgCr,2)
234 | input_imgCbCr = np.concatenate((input_imgCb, input_imgCr), axis=2)
235 |
236 | elif np.asarray(input_img).shape[2] == 1:
237 | print("This image has one channal only.")
238 | #If the num of channal is 1, remain.
239 | input_imgY = input_img
240 | input_imgCbCr = None
241 | else:
242 | print("The num of channal is neither 3 nor 1.")
243 | exit()
244 | return input_imgY, input_imgCbCr
245 |
246 | # def crop(input_image, gt_image, patch_width, patch_height, img_type):
247 | def crop(high1Data_image, lowData_image, high2Data_image, gt_image, patch_width, patch_height, img_type):
248 | assert type(high1Data_image) == type(gt_image) == type(lowData_image) == type(high2Data_image), "types are different."
249 | high1Data_cropped = []
250 | lowData_cropped = []
251 | high2Data_cropped = []
252 | gt_cropped = []
253 |
254 | # return a ndarray object
255 | if img_type == "ndarray":
256 | in_row_ind = random.randint(0,high1Data_image.shape[0]-patch_width)
257 | in_col_ind = random.randint(0,high1Data_image.shape[1]-patch_height)
258 |
259 | high1Data_cropped = high1Data_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height]
260 | lowData_cropped = lowData_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height]
261 | high2Data_cropped = high2Data_image[in_row_ind:in_row_ind + patch_width, in_col_ind:in_col_ind + patch_height]
262 | gt_cropped = gt_image[in_row_ind:in_row_ind+patch_width, in_col_ind:in_col_ind+patch_height]
263 |
264 | #return an "Image" object
265 | elif img_type == "Image":
266 | pass
267 | return high1Data_cropped, lowData_cropped, high2Data_cropped, gt_cropped
268 |
269 | def save_images(inputY, inputCbCr, size, image_path):
270 | """Save mutiple images into one single image.
271 |
272 | # Parameters
273 | # -----------
274 | # images : numpy array [batch, w, h, c]
275 | # size : list of two int, row and column number.
276 | # number of images should be equal or less than size[0] * size[1]
277 | # image_path : string.
278 | #
279 | # Examples
280 | # ---------
281 | # # >>> images = np.random.rand(64, 100, 100, 3)
282 | # # >>> tl.visualize.save_images(images, [8, 8], 'temp.png')
283 | """
284 | def merge(images, size):
285 | h, w = images.shape[1], images.shape[2]
286 | img = np.zeros((h * size[0], w * size[1], 3))
287 | for idx, image in enumerate(images):
288 | i = idx % size[1]
289 | j = idx // size[1]
290 | img[j*h:j*h+h, i*w:i*w+w, :] = image
291 | return img
292 |
293 | inputY = inputY.astype('uint8')
294 | inputCbCr = inputCbCr.astype('uint8')
295 | output_concat = np.concatenate((inputY, inputCbCr), axis=3)
296 |
297 | assert len(output_concat) <= size[0] * size[1], "number of images should be equal or less than size[0] * size[1] {}".format(len(output_concat))
298 |
299 | new_output = merge(output_concat, size)
300 |
301 | new_output = new_output.astype('uint8')
302 |
303 | img = Image.fromarray(new_output, mode='YCbCr')
304 | img = img.convert('RGB')
305 | img.save(image_path)
306 |
307 | def get_image_batch(train_list,offset,batch_size):
308 | target_list = train_list[offset:offset+batch_size]
309 | input_list = []
310 | gt_list = []
311 | inputcbcr_list = []
312 | for pair in target_list:
313 | input_img = Image.open(pair[0])
314 | gt_img = Image.open(pair[1])
315 |
316 |         # crop images to the desired size.
317 |         input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image")  # legacy call: expects the old 5-argument crop() commented out above
318 | 
319 |         # focus on the Y channel only
320 | input_imgY, input_imgCbCr = img2y(input_img)
321 | gt_imgY, gt_imgCbCr = img2y(gt_img)
322 |
323 | #input_imgY = normalize(input_imgY)
324 | #gt_imgY = normalize(gt_imgY)
325 |
326 | input_list.append(input_imgY)
327 | gt_list.append(gt_imgY)
328 | inputcbcr_list.append(input_imgCbCr)
329 |
330 | input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
331 | gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
332 |
333 | return input_list, gt_list, inputcbcr_list
334 |
335 | def save_test_img(inputY, inputCbCr, path):
336 | assert len(inputY.shape) == 4, "the tensor Y's shape is %s"%inputY.shape
337 |     assert inputY.shape[0] == 1, "the first dimension must be 1; other batch sizes are not supported. {}".format(inputY.shape)
338 |
339 | inputY = np.squeeze(inputY, axis=0)
340 | inputY = inputY.astype('uint8')
341 |
342 | inputCbCr = inputCbCr.astype('uint8')
343 |
344 | output_concat = np.concatenate((inputY, inputCbCr), axis=2)
345 | img = Image.fromarray(output_concat, mode='YCbCr')
346 | img = img.convert('RGB')
347 | img.save(path)
348 |
349 | def psnr(hr_image, sr_image, max_value=255.0):
350 | eps = 1e-10
351 | if((type(hr_image)==type(np.array([]))) or (type(hr_image)==type([]))):
352 | hr_image_data = np.asarray(hr_image, 'float32')
353 | sr_image_data = np.asarray(sr_image, 'float32')
354 |
355 | diff = sr_image_data - hr_image_data
356 | mse = np.mean(diff*diff)
357 | mse = np.maximum(eps, mse)
358 | return float(10*math.log10(max_value*max_value/mse))
359 | else:
360 | assert len(hr_image.shape)==4 and len(sr_image.shape)==4
361 | diff = hr_image - sr_image
362 | mse = tf.reduce_mean(tf.square(diff))
363 | mse = tf.maximum(mse, eps)
364 | return 10*tf.log(max_value*max_value/mse)/math.log(10)
365 |
366 | def getBeforeNNBlockDict(img, w, h):
367 | # print(img[:1500, : 2000])
368 | blockSize = 1000
369 | padding = 32
370 | yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1)
371 | xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1)
372 | tempImg = {}
373 | i = 0
374 | for yBlock in range(yBlockNum):
375 | for xBlock in range(xBlockNum):
376 | if yBlock == 0:
377 | if xBlock == 0:
378 | tempImg[i] = img[0: blockSize+padding, 0: blockSize+padding]
379 | elif xBlock == xBlockNum - 1:
380 | tempImg[i] = img[0: blockSize+padding, xBlock*blockSize-padding: w]
381 | else:
382 | tempImg[i] = img[0: blockSize+padding, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
383 | elif yBlock == yBlockNum - 1:
384 | if xBlock == 0:
385 | tempImg[i] = img[blockSize*yBlock-padding: h, 0: blockSize+padding]
386 | elif xBlock == xBlockNum - 1:
387 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: w]
388 | else:
389 | tempImg[i] = img[blockSize*yBlock-padding: h, blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
390 | elif xBlock == 0:
391 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, 0: blockSize+padding]
392 | elif xBlock == xBlockNum - 1:
393 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding, blockSize*xBlock-padding: w]
394 | else:
395 | tempImg[i] = img[blockSize*yBlock-padding: blockSize*(yBlock+1)+padding,
396 | blockSize*xBlock-padding: blockSize*(xBlock+1)+padding]
397 |             i += 1  # fix: was `i += i`, which left the block index stuck at 0
398 |     l = tempImg[i - 1].astype('uint8')  # preview the last block as a quick sanity check
399 |     l = Image.fromarray(l)
400 |     l.show()
401 |     return tempImg  # the name implies the dict is returned; the original fell through with None
--------------------------------------------------------------------------------
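The index arithmetic in `get_train_list` above encodes the random-access coding structure: high-quality frames occur every fourth frame, and a training triple is accepted only when the low-quality frame sits exactly midway between its two high-quality neighbours. A minimal sketch of that convention, with illustrative indices:

```python
# high-quality frames sit at indices 0, 4, 8, ..., 48 (see highIndexList above)
high_index_list = [q for q in range(49) if q % 4 == 0]

low = 6                               # a moderate-quality frame between two HFs
high1, high2 = low - 2, low + 2       # its high-quality neighbours: frames 4 and 8
assert high1 + 4 == low + 2 == high2  # the consistency check used in get_train_list
```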
/train/train_Single/LSVE_Single_ra_train.py:
--------------------------------------------------------------------------------
1 | import argparse, time
2 | from LSVE_ra_model import model_single
3 | from UTILS_single_ra import *
4 |
5 | tf.logging.set_verbosity(tf.logging.WARN)
6 |
7 | #os.environ["CUDA_VISIBLE_DEVICES"] = "1"
8 | EXP_DATA = 'SF_qp22_ra_01243_VDSRSingle' # experiment tag; names the log/checkpoint/sample directories below
9 | HIGH_DATA_PATH = r"E:\MF\trainSet\qp22\single\high_data" # high frames
10 | LABEL_PATH = r"E:\MF\trainSet\qp22\single\label_s" # label (ground-truth) frames
11 | LOG_PATH = "./logs/%s/"%(EXP_DATA)
12 | CKPT_PATH = "./checkpoints/%s/"%(EXP_DATA)
13 | SAMPLE_PATH = "./samples/%s/"%(EXP_DATA)
14 | PATCH_SIZE = (64, 64)
15 | BATCH_SIZE = 64
16 | BASE_LR = 3e-4
17 | LR_DECAY_RATE = 0.5
18 | LR_DECAY_STEP = 20
19 | MAX_EPOCH = 500
20 |
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--model_path")
24 | args = parser.parse_args()
25 | model_path = args.model_path
26 |
27 | if __name__ == '__main__':
28 |
29 |     # train_list holds pairs like [[hdata1, label1], [hdata2, label2]] with full paths.
30 | train_list = get_train_list(load_file_list(HIGH_DATA_PATH), load_file_list(LABEL_PATH))
31 | with tf.name_scope('input_scope'):
32 | train_hightData = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
33 | train_gt = tf.placeholder('float32', shape=(BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
34 |
35 | shared_model = tf.make_template('shared_model', model_single)
36 | train_output = shared_model(train_hightData)
37 | train_output = tf.clip_by_value(train_output, 0., 1.)
38 | with tf.name_scope('loss_scope'):
39 | loss2 = tf.reduce_mean(tf.square(tf.subtract(train_output, train_gt)))
40 | loss1 = tf.reduce_sum(tf.abs(tf.subtract(train_output, train_gt)))
41 | W = tf.get_collection(tf.GraphKeys.WEIGHTS)
42 | for w in W:
43 | loss2 += tf.nn.l2_loss(w)*1e-4
44 |
45 | avg_loss = tf.placeholder('float32')
46 | tf.summary.scalar("avg_loss", avg_loss)
47 |
48 | global_step = tf.Variable(0, trainable=False) # len(train_list)
49 | learning_rate = tf.train.exponential_decay(BASE_LR, global_step, LR_DECAY_STEP*1111, LR_DECAY_RATE, staircase=True)
50 | tf.summary.scalar("learning rate", learning_rate)
51 |
52 | # org ---------------------------------------------------------------------------------
53 | optimizer = tf.train.AdamOptimizer(learning_rate, 0.9)
54 | opt = optimizer.minimize(loss2, global_step=global_step)
55 | saver = tf.train.Saver(max_to_keep=0)
56 | # org end------------------------------------------------------------------------------
57 | config = tf.ConfigProto()
58 | config.gpu_options.allow_growth = True
59 |
60 | with tf.Session(config=config) as sess:
61 | if not os.path.exists(LOG_PATH):
62 | os.makedirs(LOG_PATH)
63 | if not os.path.exists(os.path.dirname(CKPT_PATH)):
64 | os.makedirs(os.path.dirname(CKPT_PATH))
65 | if not os.path.exists(SAMPLE_PATH):
66 | os.makedirs(SAMPLE_PATH)
67 |
68 | merged = tf.summary.merge_all()
69 | file_writer = tf.summary.FileWriter(LOG_PATH, sess.graph)
70 |
71 | sess.run(tf.global_variables_initializer())
72 |
73 | if model_path:
74 | print("restore model...")
75 | saver.restore(sess, model_path)
76 | print("Done")
77 |
78 | #for epoch in range(400, MAX_EPOCH):
79 | for epoch in range(MAX_EPOCH):
80 | total_g_loss, n_iter = 0, 0
81 | idxOfImgs = np.random.permutation(len(train_list))
82 |
83 | epoch_time = time.time()
84 |
85 | # for idx in range(len(idxOfImgs)):
86 | for idx in range(1111):
87 | input_highData, gt_data = prepare_nn_data(train_list)
88 | feed_dict = {train_hightData: input_highData, train_gt: gt_data}
89 |
90 | _, l, output, g_step = sess.run([opt, loss2, train_output, global_step], feed_dict=feed_dict)
91 | total_g_loss += l
92 | n_iter += 1
93 |
94 | del input_highData, gt_data, output
95 | lr, summary = sess.run([learning_rate, merged], {avg_loss:total_g_loss/n_iter})
96 | file_writer.add_summary(summary, epoch)
97 | tf.logging.warning("Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f"%(epoch, MAX_EPOCH, time.time()-epoch_time, total_g_loss/n_iter, lr))
98 | print("Epoch: [%4d/%4d] time: %4.4f\tloss: %.8f\tlr: %.8f"%(epoch, MAX_EPOCH, time.time()-epoch_time, total_g_loss/n_iter, lr))
99 | saver.save(sess, os.path.join(CKPT_PATH, "%s_%03d.ckpt"%(EXP_DATA, epoch)))
100 |
--------------------------------------------------------------------------------
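With `staircase=True`, the schedule above holds the learning rate constant for `LR_DECAY_STEP * 1111` steps (20 epochs of 1111 iterations each) and then multiplies it by `LR_DECAY_RATE`. An equivalent closed form, as a minimal sketch:

```python
def staircase_lr(global_step, base_lr=3e-4, decay_rate=0.5, decay_steps=20 * 1111):
    # mirrors tf.train.exponential_decay(..., staircase=True)
    return base_lr * decay_rate ** (global_step // decay_steps)

print(staircase_lr(0))      # 3.0e-04 for epochs 0-19
print(staircase_lr(22220))  # 1.5e-04 from epoch 20 onward
```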
/train/train_Single/LSVE_ra_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def model_single(inputLowData_tensor):
6 | # with tf.device("/gpu:0"):
7 | input_current = inputLowData_tensor # lowData
8 |
9 |     # placeholder; `tensor` is assigned once the input layers below are built.
10 | tensor = None
11 |
12 | # ----------------------------------------------Frame 0--------------------------------------------------------
13 | input_current_w = tf.get_variable("input_current_w", [5, 5, 1, 64],
14 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 25)))
15 | input_current_b = tf.get_variable("input_current_b", [64], initializer=tf.constant_initializer(0))
16 | input_low_tensor = tf.nn.relu(
17 | tf.nn.bias_add(tf.nn.conv2d(input_current, input_current_w, strides=[1, 1, 1, 1], padding='SAME'),
18 | input_current_b))
19 |
20 |     # ---------------------------- 3x3 conv (note: the kernel is 3x3 although the variable names below still say "1x1") ----------------------------
21 | input_3x3_w = tf.get_variable("input_1x1_w", [3, 3, 64, 64],
22 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 64 / 9)))
23 | input_1x1_b = tf.get_variable("input_1x1_b", [64], initializer=tf.constant_initializer(0))
24 | input_3x3_tensor = tf.nn.relu(
25 | tf.nn.bias_add(tf.nn.conv2d(input_low_tensor, input_3x3_w, strides=[1, 1, 1, 1], padding='SAME'),
26 | input_1x1_b))
27 | tensor = input_3x3_tensor
28 |
29 |     # -------------------------------- stack the remaining 3x3 conv layers --------------------------------
30 | convId = 0
31 | for i in range(18):
32 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 64],
33 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
34 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
35 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [64], initializer=tf.constant_initializer(0))
36 | convId += 1
37 | tensor = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b))
38 |
39 | conv_w = tf.get_variable("conv_%02d_w" % (convId), [3, 3, 64, 1],
40 | initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / 9 / 64)))
41 | conv_b = tf.get_variable("conv_%02d_b" % (convId), [1], initializer=tf.constant_initializer(0))
42 | convId += 1
43 | tf.add_to_collection(tf.GraphKeys.WEIGHTS, conv_w)
44 | tensor = tf.nn.bias_add(tf.nn.conv2d(tensor, conv_w, strides=[1, 1, 1, 1], padding='SAME'), conv_b)
45 | tensor = tf.add(tensor, input_current)
46 | return tensor
47 |
--------------------------------------------------------------------------------
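`model_single` is a VDSR-style residual network (the checkpoint files above are even named `VDSRSingle`): one 5x5 input convolution, twenty 3x3 convolutions, and a global skip connection. A minimal sketch of its parameter count, derived from the layer shapes above:

```python
params = 5 * 5 * 1 * 64 + 64           # 5x5 input conv + bias
params += 3 * 3 * 64 * 64 + 64         # the 3x3 conv after the input layer
params += 18 * (3 * 3 * 64 * 64 + 64)  # eighteen stacked 3x3x64x64 layers
params += 3 * 3 * 64 * 1 + 1           # 3x3 reconstruction layer back to one channel
print(params)                          # 703873 trainable weights and biases
```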
/train/train_Single/UTILS_single_ra.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import math, os, random, re
4 | from PIL import Image
5 | from LSVE_Single_ra_train import BATCH_SIZE
6 | from LSVE_Single_ra_train import PATCH_SIZE
7 |
8 | # Because a naive batch would come from a single picture, this sampling scheme draws patches from several pictures to make the training set more diverse.
9 |
10 | def normalize(x):
11 | x = x / 255.
12 | return truncate(x, 0., 1.)
13 |
14 | def denormalize(x):
15 | x = x * 255.
16 | return truncate(x, 0., 255.)
17 |
18 | def truncate(input, min, max):
19 | input = np.where(input > min, input, min)
20 | input = np.where(input < max, input, max)
21 | return input
22 |
23 | def remap(input):
24 | input = 16+219/255*input
25 | return truncate(input, 16.0, 235.0)
26 |
27 | def deremap(input):
28 | input = (input-16)*255/219
29 | return truncate(input, 0.0, 255.0)
30 |
31 | # return the whole absolute path.
32 | def load_file_list(directory):
33 | list = []
34 | for filename in [y for y in os.listdir(directory) if os.path.isfile(os.path.join(directory,y))]:
35 | list.append(os.path.join(directory,filename))
36 | return list
37 |
38 | def searchHighData(currentLowDataIndex, highDataList, highIndexList):
39 | searchOffset = 3
40 | searchedHighDataIndexList = []
41 | searchedHighData = []
42 | for i in range(currentLowDataIndex - searchOffset, currentLowDataIndex + searchOffset + 1):
43 | if i in highIndexList:
44 | searchedHighDataIndexList.append(i)
45 |     assert len(searchedHighDataIndexList) == 2, 'search method has an error!'
46 |     for tempData in highDataList:
47 |         if int(os.path.basename(tempData).split('.')[0].split('_')[-1]) \
48 |                 in searchedHighDataIndexList:  # membership test; the original chained == could never match two distinct indices
49 |
50 | searchedHighData.append(tempData)
51 | return searchedHighData
52 |
53 |
54 | # returns entries like [[high1Data, lowData, high2Data], label] with full paths, e.g. [[2, 7, 8], 22], [[3, 8, 9], 33].
55 | def get_test_list2(highDataList, lowDataList, labelList):
56 | assert len(lowDataList) == len(labelList), "low:%d, label:%d,"%(len(lowDataList) , len(labelList))
57 |
58 | # [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48]
59 | highIndexList = [q for q in range(49) if q % 4 == 0]
60 | test_list = []
61 | for tempDataPath in lowDataList:
62 | tempData = []
63 | temp = []
64 |         # NOTE: this indexing may need to change for a different dataset layout.
65 | currentLowDataIndex = int(os.path.basename(tempDataPath).split('.')[0].split('_')[-1])
66 | searchedHighData = searchHighData(currentLowDataIndex, highDataList, highIndexList)
67 | tempData.append(searchedHighData[0])
68 | tempData.append(tempDataPath)
69 | tempData.append(searchedHighData[1])
70 |
71 | i = list(lowDataList).index(tempDataPath)
72 |
73 | temp.append(tempData)
74 | temp.append(labelList[i])
75 |
76 | test_list.append(temp)
77 | return test_list
78 |
79 | def get_temptest_list(high1DataList, lowDataList, high2DataList, labelList):
80 | tempData = []
81 | temp = []
82 | test_list = []
83 | for i in range(len(lowDataList)):
84 | tempData.append(high1DataList[i])
85 | tempData.append(lowDataList[i])
86 | tempData.append(high2DataList[i])
87 | temp.append(tempData)
88 | temp.append(labelList[i])
89 |
90 | test_list.append(temp)
91 | return test_list
92 |
93 | # returns (doubleTest_list, singleTest_list): double entries look like [[h1, h2], low, label], single entries like [low, label].
94 | def get_test_list(HIGHDATA_Parent_PATH, lowDataList, labelList):
95 | doubleTest_list = []
96 | singleTest_list = []
97 | HighDirList = os.listdir(HIGHDATA_Parent_PATH)
98 | # convert string-list to int-list.
99 | HighDirList = list(map(int, HighDirList))
100 |
101 | for lowdata in lowDataList:
102 |
103 | tempData = []
104 | lowdataIndex = int(os.path.basename(lowdata).split('.')[0].split('_')[-1])
105 | if lowdataIndex % 4 != 0:
106 | if lowdataIndex in HighDirList:
107 | High_current_path = HIGHDATA_Parent_PATH + '/' + str(lowdataIndex)
108 | TWO_HighData = os.listdir(High_current_path)
109 | Two_HighData = [os.path.join(High_current_path, T) for T in TWO_HighData]
110 | tempData.append(Two_HighData)
111 | tempData.append(lowdata)
112 | labelIndex = list(lowDataList).index(lowdata)
113 | tempData.append(labelList[labelIndex])
114 | doubleTest_list.append(tempData)
115 | else:
116 | tempData.append(lowdata)
117 | labelIndex = list(lowDataList).index(lowdata)
118 | if int(os.path.basename(lowdata).split('.')[0].split('_')[-1]) == \
119 | int(os.path.basename(labelList[labelIndex]).split('.')[0].split('_')[-1]):
120 |
121 | tempData.append(labelList[labelIndex])
122 | singleTest_list.append(tempData)
123 |
124 | return doubleTest_list, singleTest_list
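# Assumed on-disk layout (illustrative names, not taken from the repo):
# every non-anchor frame index that has two usable high-quality neighbours
# gets a sub-directory under HIGHDATA_Parent_PATH holding those two frames:
#   HIGHDATA_Parent_PATH/6/high_4.yuv
#   HIGHDATA_Parent_PATH/6/high_8.yuv
# Frames whose index is a multiple of 4 are the high-quality anchors
# themselves and are skipped here.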
125 |
# Return entries of the form [highData, label], i.e. [[hdata1, label1], [hdata2, label2], ...], with full paths.
def get_train_list(highDataList, labelList):
    assert len(highDataList) == len(labelList), \
        "high:%d, label:%d" % (len(highDataList), len(labelList))

    train_list = []
    for i in range(len(labelList)):
        tempData = []
        # NOTE: adapt this index parsing to the file-naming convention in use.
        if int(os.path.basename(highDataList[i]).split('_')[-1].split('.')[0]) == \
                int(os.path.basename(labelList[i]).split('_')[-1].split('.')[0]):
            tempData.append(highDataList[i])
            tempData.append(labelList[i])
        else:
            raise Exception('frame index of highData does not match that of its label')
        train_list.append(tempData)
    return train_list
144 |
# train_list has the form [[hdata1, label1], [hdata2, label2], ...] with full paths.
def prepare_nn_data(train_list):
    # Sample 8 source frames and cut 8 random patches from each (64 patches);
    # np.resize below then fits the pool to BATCH_SIZE, truncating or
    # repeating patches if BATCH_SIZE != 64.
    batchSizeRandomList = random.sample(range(0, len(train_list)), 8)
    gt_list = []
    highData_list = []
    for i in batchSizeRandomList:
        highData_image = c_getYdata(train_list[i][0])
        gt_image = c_getYdata(train_list[i][1])
        for j in range(0, 8):
            # Crop both images to the desired patch size at the same position.
            highData_imgY, gt_imgY = \
                crop(highData_image, gt_image, PATCH_SIZE[0], PATCH_SIZE[1], "ndarray")
            # Normalize to [0, 1].
            highData_imgY = normalize(highData_imgY)
            gt_imgY = normalize(gt_imgY)

            highData_list.append(highData_imgY)
            gt_list.append(gt_imgY)

    highData_list = np.resize(highData_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
    gt_list = np.resize(gt_list, (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1))
    return highData_list, gt_list
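# Typical use during a training step (shapes follow the BATCH_SIZE and
# PATCH_SIZE imported above):
#   high_batch, gt_batch = prepare_nn_data(train_list)
#   # high_batch.shape == gt_batch.shape == (BATCH_SIZE, PATCH_SIZE[0], PATCH_SIZE[1], 1)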
167 |
def getWH(yuvfileName):
    deyuv = re.compile(r'(.+?)\.')
    # File name with the .yuv suffix stripped.
    deyuvFilename = deyuv.findall(yuvfileName)[0]
    if 'x' in os.path.basename(deyuvFilename).split('_')[-2]:
        wxh = os.path.basename(deyuvFilename).split('_')[-2]
    elif 'x' in os.path.basename(deyuvFilename).split('_')[1]:
        wxh = os.path.basename(deyuvFilename).split('_')[1]
    else:
        raise Exception('cannot find a WxH field in the file name')
    w, h = wxh.split('x')
    return int(w), int(h)
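# Example with a standard test-sequence style name (illustrative):
#   getWH('BasketballPass_416x240_6.yuv')  ->  (416, 240)
# The WxH token is looked up at the second-to-last, then the second,
# underscore-separated field of the file name.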
179 |
def getYdata(path, size):
    w = size[0]
    h = size[1]
    with open(path, 'rb') as fp:
        fp.seek(0, 0)
        # Read exactly one w*h luma plane so any trailing bytes in the
        # file do not upset Image.frombytes.
        Yt = fp.read(w * h)
        tem = Image.frombytes('L', [w, h], Yt)
        Yt = np.asarray(tem, dtype='float32')
    return Yt
189 |
190 | def c_getYdata(path):
191 | return getYdata(path, getWH(path))
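# Example (illustrative): read one stored Y plane back as float32,
#   Y = c_getYdata('BasketballPass_416x240_6.yuv')   # ndarray of shape (240, 416)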
192 |
def img2y(input_img):
    arr = np.asarray(input_img)
    # A single-channel PIL image yields a 2-D array, so check ndim before shape[2].
    if arr.ndim == 3 and arr.shape[2] == 3:
        input_imgY = input_img.convert('YCbCr').split()[0]
        input_imgCb, input_imgCr = input_img.convert('YCbCr').split()[1:3]

        input_imgY = np.asarray(input_imgY, dtype='float32')
        input_imgCb = np.asarray(input_imgCb, dtype='float32')
        input_imgCr = np.asarray(input_imgCr, dtype='float32')

        # Concatenate the Cb and Cr components for convenience; they are
        # always used as a pair anyway.
        input_imgCb = np.expand_dims(input_imgCb, 2)
        input_imgCr = np.expand_dims(input_imgCr, 2)
        input_imgCbCr = np.concatenate((input_imgCb, input_imgCr), axis=2)

    elif arr.ndim == 2 or arr.shape[2] == 1:
        print("This image has one channel only.")
        # If there is a single channel, keep the image as-is.
        input_imgY = input_img
        input_imgCbCr = None
    else:
        print("The number of channels is neither 3 nor 1.")
        exit()
    return input_imgY, input_imgCbCr
217 |
def crop(highData_image, gt_image, patch_rows, patch_cols, img_type):
    # NOTE: the original parameters were named patch_width/patch_height but
    # were applied to rows/columns respectively; they are renamed here to
    # match how callers (PATCH_SIZE[0], PATCH_SIZE[1]) actually use them.
    assert type(highData_image) == type(gt_image), "types are different."

    # ndarray input: pick one random top-left corner and cut the same
    # window from both images so the pair stays aligned.
    if img_type == "ndarray":
        in_row_ind = random.randint(0, highData_image.shape[0] - patch_rows)
        in_col_ind = random.randint(0, highData_image.shape[1] - patch_cols)
        highData_cropped = highData_image[in_row_ind:in_row_ind + patch_rows,
                                          in_col_ind:in_col_ind + patch_cols]
        gt_cropped = gt_image[in_row_ind:in_row_ind + patch_rows,
                              in_col_ind:in_col_ind + patch_cols]

    # PIL "Image" input: the original branch was an unimplemented `pass`,
    # which returned empty lists; crop the same random box from both images.
    elif img_type == "Image":
        left = random.randint(0, highData_image.size[0] - patch_cols)
        upper = random.randint(0, highData_image.size[1] - patch_rows)
        box = (left, upper, left + patch_cols, upper + patch_rows)
        highData_cropped = highData_image.crop(box)
        gt_cropped = gt_image.crop(box)
    return highData_cropped, gt_cropped
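# Example (hypothetical arrays): both patches are cut at the same random
# position, so a degraded frame and its ground truth stay aligned:
#   deg = np.zeros((240, 416), dtype='float32')
#   gt = np.ones((240, 416), dtype='float32')
#   deg_patch, gt_patch = crop(deg, gt, 64, 64, "ndarray")   # both (64, 64)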
235 |
def save_images(inputY, inputCbCr, size, image_path):
    """Save multiple images tiled into one single image.

    Parameters
    ----------
    inputY, inputCbCr : numpy arrays, [batch, h, w, c]
        Y and CbCr planes of the images to save.
    size : list of two ints, [rows, cols]
        The number of images must be less than or equal to size[0] * size[1].
    image_path : string
        Output file path.

    Example
    -------
    >>> Y = np.random.rand(64, 100, 100, 1) * 255
    >>> CbCr = np.random.rand(64, 100, 100, 2) * 255
    >>> save_images(Y, CbCr, [8, 8], 'temp.png')
    """
251 | def merge(images, size):
252 | h, w = images.shape[1], images.shape[2]
253 | img = np.zeros((h * size[0], w * size[1], 3))
254 | for idx, image in enumerate(images):
255 | i = idx % size[1]
256 | j = idx // size[1]
257 | img[j*h:j*h+h, i*w:i*w+w, :] = image
258 | return img
259 |
260 | inputY = inputY.astype('uint8')
261 | inputCbCr = inputCbCr.astype('uint8')
262 | output_concat = np.concatenate((inputY, inputCbCr), axis=3)
263 |
    assert len(output_concat) <= size[0] * size[1], \
        "the number of images must be less than or equal to size[0] * size[1]; got {}".format(len(output_concat))
265 |
266 | new_output = merge(output_concat, size)
267 | new_output = new_output.astype('uint8')
268 | img = Image.fromarray(new_output, mode='YCbCr')
269 | img = img.convert('RGB')
270 | img.save(image_path)
271 |
272 | def get_image_batch(train_list,offset,batch_size):
273 | target_list = train_list[offset:offset+batch_size]
274 | input_list = []
275 | gt_list = []
276 | inputcbcr_list = []
277 | for pair in target_list:
278 | input_img = Image.open(pair[0])
279 | gt_img = Image.open(pair[1])
280 |
        # Crop images to the desired size.
282 | input_img, gt_img = crop(input_img, gt_img, PATCH_SIZE[0], PATCH_SIZE[1], "Image")
283 |
        # Focus on the Y channel only.
285 | input_imgY, input_imgCbCr = img2y(input_img)
286 | gt_imgY, gt_imgCbCr = img2y(gt_img)
287 | input_list.append(input_imgY)
288 | gt_list.append(gt_imgY)
289 | inputcbcr_list.append(input_imgCbCr)
290 |
291 | input_list = np.resize(input_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
292 | gt_list = np.resize(gt_list, (batch_size, PATCH_SIZE[0], PATCH_SIZE[1], 1))
293 |
294 | return input_list, gt_list, inputcbcr_list
295 |
296 | def save_test_img(inputY, inputCbCr, path):
297 | assert len(inputY.shape) == 4, "the tensor Y's shape is %s"%inputY.shape
    assert inputY.shape[0] == 1, "the first dimension (batch) must be 1; other cases are not handled. {}".format(inputY.shape)
299 |
300 | inputY = np.squeeze(inputY, axis=0)
301 | inputY = inputY.astype('uint8')
302 | inputCbCr = inputCbCr.astype('uint8')
303 | output_concat = np.concatenate((inputY, inputCbCr), axis=2)
304 | img = Image.fromarray(output_concat, mode='YCbCr')
305 | img = img.convert('RGB')
306 | img.save(path)
307 |
def psnr(hr_image, sr_image, max_value=255.0):
    eps = 1e-10
    if isinstance(hr_image, (np.ndarray, list)):
        hr_image_data = np.asarray(hr_image, 'float32')
        sr_image_data = np.asarray(sr_image, 'float32')

        diff = sr_image_data - hr_image_data
        mse = np.mean(diff * diff)
        mse = np.maximum(eps, mse)
        return float(10 * math.log10(max_value * max_value / mse))
    else:
        assert len(hr_image.shape) == 4 and len(sr_image.shape) == 4
        diff = hr_image - sr_image
        mse = tf.reduce_mean(tf.square(diff))
        mse = tf.maximum(mse, eps)
        # TensorFlow 1.x: tf.log is the natural log, so divide by ln(10).
        return 10 * tf.log(max_value * max_value / mse) / math.log(10)
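# Worked example for the numpy branch: with 8-bit data (max_value = 255),
# an MSE of 1.0 gives psnr = 10 * log10(255**2 / 1.0) ≈ 48.13 dB, and
# identical inputs hit the eps floor instead of dividing by zero.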
324 |
def getBeforeNNBlockDict(img, w, h):
    # Tile a large frame into blocks of up to 1000x1000 pixels, each extended
    # by a 32-pixel overlap so that block borders can be blended later.
    blockSize = 1000
    padding = 32
    yBlockNum = (h // blockSize) if (h % blockSize == 0) else (h // blockSize + 1)
    xBlockNum = (w // blockSize) if (w % blockSize == 0) else (w // blockSize + 1)
    tempImg = {}
    i = 0
    for yBlock in range(yBlockNum):
        for xBlock in range(xBlockNum):
            if yBlock == 0:
                if xBlock == 0:
                    tempImg[i] = img[0: blockSize + padding, 0: blockSize + padding]
                elif xBlock == xBlockNum - 1:
                    tempImg[i] = img[0: blockSize + padding, xBlock * blockSize - padding: w]
                else:
                    tempImg[i] = img[0: blockSize + padding, blockSize * xBlock - padding: blockSize * (xBlock + 1) + padding]
            elif yBlock == yBlockNum - 1:
                if xBlock == 0:
                    tempImg[i] = img[blockSize * yBlock - padding: h, 0: blockSize + padding]
                elif xBlock == xBlockNum - 1:
                    tempImg[i] = img[blockSize * yBlock - padding: h, blockSize * xBlock - padding: w]
                else:
                    tempImg[i] = img[blockSize * yBlock - padding: h, blockSize * xBlock - padding: blockSize * (xBlock + 1) + padding]
            elif xBlock == 0:
                tempImg[i] = img[blockSize * yBlock - padding: blockSize * (yBlock + 1) + padding, 0: blockSize + padding]
            elif xBlock == xBlockNum - 1:
                tempImg[i] = img[blockSize * yBlock - padding: blockSize * (yBlock + 1) + padding, blockSize * xBlock - padding: w]
            else:
                tempImg[i] = img[blockSize * yBlock - padding: blockSize * (yBlock + 1) + padding,
                                 blockSize * xBlock - padding: blockSize * (xBlock + 1) + padding]
            # Bug fix: the original `i += i` left the counter at 0 forever,
            # so every block overwrote tempImg[0].
            i += 1
    # Optional debug view of the last block:
    # Image.fromarray(tempImg[i - 1].astype('uint8')).show()
    return tempImg
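# Worked example: a 1920x1080 frame gives xBlockNum = yBlockNum = 2, i.e.
# four blocks. Block 0 covers rows 0:1032 and cols 0:1032 (1000 px plus the
# 32-px overlap); its right-hand neighbour covers cols 968:1920, so adjacent
# blocks share a 64-px strip that can be blended after enhancement.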
359 |
--------------------------------------------------------------------------------