├── .idea
│   ├── keras-SRCNN.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── workspace.xml
├── 3051crop_weight_200.h5
├── SRCNN.png
├── butterfly_GT.bmp
├── input.jpg
├── m_model_adam_new30.h5
├── main.py
├── pre_adam30.jpg
├── prepare_data.py
├── prepare_data.pyc
├── psnr.py
├── readme.md
└── result.png
/3051crop_weight_200.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/3051crop_weight_200.h5
--------------------------------------------------------------------------------
/SRCNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/SRCNN.png
--------------------------------------------------------------------------------
/butterfly_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/butterfly_GT.bmp
--------------------------------------------------------------------------------
/input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/input.jpg
--------------------------------------------------------------------------------
/m_model_adam_new30.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/m_model_adam_new30.h5
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from keras.models import Sequential
from keras.layers import Conv2D, Input, BatchNormalization
# from keras.layers.advanced_activations import LeakyReLU
from keras.callbacks import ModelCheckpoint
from keras.optimizers import SGD, Adam
import prepare_data as pd
import numpy
import math


def psnr(target, ref):
    # assume 8-bit image data (peak value 255)
    target_data = numpy.array(target, dtype=float)
    ref_data = numpy.array(ref, dtype=float)

    diff = ref_data - target_data
    diff = diff.flatten('C')

    rmse = math.sqrt(numpy.mean(diff ** 2.))
    if rmse == 0:
        # identical images: PSNR is unbounded
        return float('inf')

    return 20 * math.log10(255. / rmse)


def model():
    # lrelu = LeakyReLU(alpha=0.1)
    SRCNN = Sequential()
    # 9x9 'valid' convolution: 32x32 input patch -> 24x24 feature maps
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9), kernel_initializer='glorot_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(32, 32, 1)))
    # 3x3 'same' convolution: spatial size unchanged
    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3), kernel_initializer='glorot_uniform',
                     activation='relu', padding='same', use_bias=True))
    # SRCNN.add(BatchNormalization())
    # 5x5 'valid' convolution: 24x24 -> 20x20 output, matching the label size
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5), kernel_initializer='glorot_uniform',
                     activation='linear', padding='valid', use_bias=True))
    adam = Adam(lr=0.0003)
    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    return SRCNN


def predict_model():
    # Same layers as model(), but the input size is left open so that
    # whole images of any size can be fed to the network.
    # lrelu = LeakyReLU(alpha=0.1)
    SRCNN = Sequential()
    SRCNN.add(Conv2D(filters=128, kernel_size=(9, 9), kernel_initializer='glorot_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1)))
    SRCNN.add(Conv2D(filters=64, kernel_size=(3, 3), kernel_initializer='glorot_uniform',
                     activation='relu', padding='same', use_bias=True))
    # SRCNN.add(BatchNormalization())
    SRCNN.add(Conv2D(filters=1, kernel_size=(5, 5), kernel_initializer='glorot_uniform',
                     activation='linear', padding='valid', use_bias=True))
    adam = Adam(lr=0.0003)
    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    return SRCNN


def train():
    srcnn_model = model()
    srcnn_model.summary()  # summary() prints the table itself
    # crop_train.h5 and test.h5 are the files written by prepare_data.py
    data, label = pd.read_training_data("./crop_train.h5")
    val_data, val_label = pd.read_training_data("./test.h5")

    # keep only the checkpoint with the lowest validation loss
    checkpoint = ModelCheckpoint("SRCNN_check.h5", monitor='val_loss', verbose=1, save_best_only=True,
                                 save_weights_only=False, mode='min')
    callbacks_list = [checkpoint]

    srcnn_model.fit(data, label, batch_size=128, validation_data=(val_data, val_label),
                    callbacks=callbacks_list, shuffle=True, epochs=200, verbose=0)
    # srcnn_model.load_weights("m_model_adam.h5")


def predict():
    srcnn_model = predict_model()
    srcnn_model.load_weights("3051crop_weight_200.h5")
    IMG_NAME = "/home/mark/Engineer/SR/data/Set14/flowers.bmp"
    INPUT_NAME = "input2.jpg"
    OUTPUT_NAME = "pre2.jpg"

    import cv2
    img = cv2.imread(IMG_NAME, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
    shape = img.shape
    # Downscale then upscale the Y channel by 2 to simulate the low-resolution input.
    # dsize must be integers, and interpolation must be passed by keyword
    # (the third positional argument of cv2.resize is dst).
    Y_img = cv2.resize(img[:, :, 0], (shape[1] // 2, shape[0] // 2), interpolation=cv2.INTER_CUBIC)
    Y_img = cv2.resize(Y_img, (shape[1], shape[0]), interpolation=cv2.INTER_CUBIC)
    img[:, :, 0] = Y_img
    img = cv2.cvtColor(img, cv2.COLOR_YCrCb2BGR)
    cv2.imwrite(INPUT_NAME, img)

    Y = numpy.zeros((1, img.shape[0], img.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = Y_img.astype(float) / 255.
    pre = srcnn_model.predict(Y, batch_size=1) * 255.
    # clamp to the valid 8-bit range before converting to uint8
    pre = numpy.clip(pre, 0, 255).astype(numpy.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
    # the 'valid' convolutions trim 6 pixels from each border
    img[6: -6, 6: -6, 0] = pre[0, :, :, 0]
    img = cv2.cvtColor(img, cv2.COLOR_YCrCb2BGR)
    cv2.imwrite(OUTPUT_NAME, img)

    # psnr calculation:
    im1 = cv2.imread(IMG_NAME, cv2.IMREAD_COLOR)
    im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]
    im2 = cv2.imread(INPUT_NAME, cv2.IMREAD_COLOR)
    im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]
    im3 = cv2.imread(OUTPUT_NAME, cv2.IMREAD_COLOR)
    im3 = cv2.cvtColor(im3, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]

    print("bicubic:")
    print(cv2.PSNR(im1, im2))
    print("SRCNN:")
    print(cv2.PSNR(im1, im3))


if __name__ == "__main__":
    train()
    predict()
--------------------------------------------------------------------------------
/pre_adam30.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/pre_adam30.jpg
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
import cv2
import h5py
import numpy

DATA_PATH = "/home/mark/Engineer/SR/data/Train/"
TEST_PATH = "/home/mark/Engineer/SR/SRCNN_createData/Test/Set14/"
Random_Crop = 30   # random patches taken from each image
Patch_size = 32    # side of the low-resolution input patch
label_size = 20    # side of the high-resolution label patch
conv_side = 6      # border removed by the 'valid' convolutions
scale = 2          # upscaling factor


def prepare_data(_path):
    names = os.listdir(_path)
    names = sorted(names)
    nums = len(names)

    data = numpy.zeros((nums * Random_Crop, 1, Patch_size, Patch_size), dtype=numpy.double)
    label = numpy.zeros((nums * Random_Crop, 1, label_size, label_size), dtype=numpy.double)

    for i in range(nums):
        name = os.path.join(_path, names[i])
        hr_img = cv2.imread(name, cv2.IMREAD_COLOR)
        shape = hr_img.shape

        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2YCrCb)
        hr_img = hr_img[:, :, 0]

        # two resize operations to produce training data and labels
        # (integer division: cv2.resize needs an integer size)
        lr_img = cv2.resize(hr_img, (shape[1] // scale, shape[0] // scale))
        lr_img = cv2.resize(lr_img, (shape[1], shape[0]))

        # produce Random_Crop random coordinates to crop the training image
        Points_x = numpy.random.randint(0, min(shape[0], shape[1]) - Patch_size, Random_Crop)
        Points_y = numpy.random.randint(0, min(shape[0], shape[1]) - Patch_size, Random_Crop)

        for j in range(Random_Crop):
            lr_patch = lr_img[Points_x[j]: Points_x[j] + Patch_size, Points_y[j]: Points_y[j] + Patch_size]
            hr_patch = hr_img[Points_x[j]: Points_x[j] + Patch_size, Points_y[j]: Points_y[j] + Patch_size]

            lr_patch = lr_patch.astype(float) / 255.
            hr_patch = hr_patch.astype(float) / 255.

            data[i * Random_Crop + j, 0, :, :] = lr_patch
            label[i * Random_Crop + j, 0, :, :] = hr_patch[conv_side: -conv_side, conv_side: -conv_side]
            # cv2.imshow("lr", lr_patch)
            # cv2.imshow("hr", hr_patch)
            # cv2.waitKey(0)
    return data, label

# BORDER_CUT = 8
BLOCK_STEP = 16
BLOCK_SIZE = 32


def prepare_crop_data(_path):
    names = os.listdir(_path)
    names = sorted(names)
    nums = len(names)

    data = []
    label = []

    for i in range(nums):
        name = os.path.join(_path, names[i])
        hr_img = cv2.imread(name, cv2.IMREAD_COLOR)
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2YCrCb)
        hr_img = hr_img[:, :, 0]
        shape = hr_img.shape

        # two resize operations to produce training data and labels
        lr_img = cv2.resize(hr_img, (shape[1] // scale, shape[0] // scale))
        lr_img = cv2.resize(lr_img, (shape[1], shape[0]))

        # number of blocks along axis 0 (rows) and axis 1 (columns);
        # integer division so the counts can be passed to range()
        width_num = (shape[0] - (BLOCK_SIZE - BLOCK_STEP) * 2) // BLOCK_STEP
        height_num = (shape[1] - (BLOCK_SIZE - BLOCK_STEP) * 2) // BLOCK_STEP
        for k in range(width_num):
            for j in range(height_num):
                x = k * BLOCK_STEP
                y = j * BLOCK_STEP
                hr_patch = hr_img[x: x + BLOCK_SIZE, y: y + BLOCK_SIZE]
                lr_patch = lr_img[x: x + BLOCK_SIZE, y: y + BLOCK_SIZE]

                lr_patch = lr_patch.astype(float) / 255.
                hr_patch = hr_patch.astype(float) / 255.

                lr = numpy.zeros((1, Patch_size, Patch_size), dtype=numpy.double)
                hr = numpy.zeros((1, label_size, label_size), dtype=numpy.double)

                lr[0, :, :] = lr_patch
                hr[0, :, :] = hr_patch[conv_side: -conv_side, conv_side: -conv_side]

                data.append(lr)
                label.append(hr)

    data = numpy.array(data, dtype=float)
    label = numpy.array(label, dtype=float)
    return data, label


def write_hdf5(data, labels, output_filename):
    """
    Save image data and its label(s) to an HDF5 file.
    The output file contains two datasets: 'data' and 'label'.
    """

    x = data.astype(numpy.float32)
    y = labels.astype(numpy.float32)

    with h5py.File(output_filename, 'w') as h:
        h.create_dataset('data', data=x, shape=x.shape)
        h.create_dataset('label', data=y, shape=y.shape)
        # h.create_dataset()


def read_training_data(file):
    with h5py.File(file, 'r') as hf:
        data = numpy.array(hf.get('data'))
        label = numpy.array(hf.get('label'))
        # stored as (N, 1, H, W); Keras expects channels last: (N, H, W, 1)
        train_data = numpy.transpose(data, (0, 2, 3, 1))
        train_label = numpy.transpose(label, (0, 2, 3, 1))
        return train_data, train_label


if __name__ == "__main__":
    data, label = prepare_crop_data(DATA_PATH)
    write_hdf5(data, label, "crop_train.h5")
    data, label = prepare_data(TEST_PATH)
    write_hdf5(data, label, "test.h5")
    # _, _a = read_training_data("train.h5")
    # _, _a = read_training_data("test.h5")
--------------------------------------------------------------------------------
/prepare_data.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/prepare_data.pyc
--------------------------------------------------------------------------------
/psnr.py:
--------------------------------------------------------------------------------
import cv2
import math
import numpy


def psnr(target, ref):
    # assume 8-bit image data (peak value 255)
    target_data = numpy.array(target, dtype=float)
    ref_data = numpy.array(ref, dtype=float)

    diff = ref_data - target_data
    diff = diff.flatten('C')

    rmse = math.sqrt(numpy.mean(diff ** 2.))
    if rmse == 0:
        # identical images: PSNR is unbounded
        return float('inf')

    return 20 * math.log10(255. / rmse)


if __name__ == "__main__":
    # compare Y channels only, with the 6-pixel border removed
    im1 = cv2.imread("./input.jpg", cv2.IMREAD_COLOR)
    im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]
    im2 = cv2.imread("./butterfly_GT.bmp", cv2.IMREAD_COLOR)
    im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]
    im3 = cv2.imread("pre_adam2000.jpg", cv2.IMREAD_COLOR)
    im3 = cv2.cvtColor(im3, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]
    im4 = cv2.imread("./pre.jpg", cv2.IMREAD_COLOR)
    im4 = cv2.cvtColor(im4, cv2.COLOR_BGR2YCrCb)[6: -6, 6: -6, 0]

    print("adam:")
    print(psnr(im2, im3))
    print("bicubic:")
    print(psnr(im2, im1))
    print("SRCNN:")
    print(psnr(im2, im4))
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Keras implementation of SRCNN

The original paper is [Learning a Deep Convolutional Network for Image Super-Resolution](https://arxiv.org/abs/1501.00092)

My implementation differs from the original paper in the following ways:

* It uses the Adam algorithm for optimization, with a learning rate of 0.0003 for all layers.
* It uses the OpenCV library, not the Matlab library, to produce the training and test data. This difference may cause some deterioration in the final results.
* I did not set a different learning rate for each layer, but I found that the network still works.
* The YCrCb color space also differs between Matlab and OpenCV. If you want to compare your results with academic papers, you may want to use code written in Matlab.
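
The last point can be illustrated on the luma channel alone. The sketch below compares the full-range BT.601 luma that OpenCV documents for its RGB to YCrCb conversion of 8-bit images with the studio-range luma (16 to 235) that Matlab's `rgb2ycbcr` produces for 8-bit input. The helper names `luma_opencv` and `luma_matlab` are hypothetical and written only for this illustration; they are not part of this repository, and rounding done by either library is ignored.

```python
def luma_opencv(r, g, b):
    """Full-range BT.601 luma for 8-bit RGB (OpenCV-style, range 0..255)."""
    return 0.299 * r + 0.587 * g + 0.114 * b


def luma_matlab(r, g, b):
    """Studio-range luma for 8-bit RGB (Matlab rgb2ycbcr-style, range 16..235)."""
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


# Print both luma values for a few reference colors.
for name, rgb in [("black", (0, 0, 0)),
                  ("gray", (128, 128, 128)),
                  ("white", (255, 255, 255))]:
    print(name, round(luma_opencv(*rgb), 2), round(luma_matlab(*rgb), 2))
```

Because PSNR is measured on the Y channel, the same image yields different Y values under the two conventions, so scores computed with this code are not directly comparable with Matlab-based numbers.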

## Use:
### Create your own data
Open **prepare_data.py** and change the data path to your data.

Execute:
`python prepare_data.py`

### Training and test:
Execute:
`python main.py`
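
Training pairs 32x32 input patches with 20x20 labels (`Patch_size`, `label_size` and `conv_side` in prepare_data.py). The sketch below shows where those numbers come from, assuming the three stride-1 layers defined in main.py: a 9x9 'valid' convolution, a 3x3 'same' convolution and a 5x5 'valid' convolution. `conv_output_size` and `srcnn_output_size` are hypothetical helpers written for this illustration.

```python
def conv_output_size(size, kernel, padding):
    """Spatial output size of a stride-1 convolution."""
    if padding == "same":
        return size
    if padding == "valid":
        return size - kernel + 1
    raise ValueError("unknown padding: %r" % (padding,))


# (kernel size, padding) of the three layers, as assumed from main.py
LAYERS = [(9, "valid"), (3, "same"), (5, "valid")]


def srcnn_output_size(size):
    """Output side length of the network for a square input of the given side."""
    for kernel, padding in LAYERS:
        size = conv_output_size(size, kernel, padding)
    return size


patch_size = 32
label_size = srcnn_output_size(patch_size)
conv_side = (patch_size - label_size) // 2
print(label_size, conv_side)  # prints: 20 6
```

The same 6-pixel border is why `predict()` writes the network output into `img[6:-6, 6:-6]` and why the PSNR comparison crops every image by 6 pixels on each side.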

## Result (training for 200 epochs on 91 images, with upscaling factor 2):
Results on Set5 dataset:
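
The scripts in this repository score results with PSNR on the Y channel. A minimal pure-Python sketch of the formula used in psnr.py, PSNR = 20 * log10(255 / RMSE), is given below. The pixel lists are made-up values for illustration only and are not results from this model; the `peak` parameter and the infinite return value for identical inputs are assumptions of this sketch.

```python
import math


def psnr(target, ref, peak=255.0):
    """PSNR in dB between two equally sized sequences of pixel values."""
    if len(target) != len(ref):
        raise ValueError("inputs must have the same length")
    # mean squared error over all pixels
    mse = sum((float(t) - float(r)) ** 2 for t, r in zip(target, ref)) / len(ref)
    if mse == 0:
        return float("inf")  # identical inputs
    return 20 * math.log10(peak / math.sqrt(mse))


reference = [52, 55, 61, 59]  # hypothetical ground-truth Y values
estimate = [54, 55, 60, 58]   # hypothetical reconstructed Y values
print(round(psnr(estimate, reference), 2))
```

A higher value means the reconstruction is closer to the ground truth, so the SRCNN output is expected to score above the bicubic baseline.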
--------------------------------------------------------------------------------
/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaokeAI/SRCNN-keras/162431e0110b1c30444f3f8b3fc52dca56963b05/result.png
--------------------------------------------------------------------------------