├── README.md
├── network.py
├── train.py
├── predict.py
└── preprocess.py

/README.md:
--------------------------------------------------------------------------------
# SRCNN-keras-hdf5

1. This repo is an SRCNN demo built on the Keras framework that can be trained directly on datasets larger than memory, by exploiting the HDF5 library's ability to read and write data that does not fit in RAM.

2. The demo focuses on the chunked HDF5 read/write code and the idea behind it, so that the pipeline applies unchanged to larger-than-memory data (a minimal sketch of this pattern appears at the end of this document).

3. The demo follows http://mmlab.ie.cuhk.edu.hk/projects/SRCNN.html . The training data can be downloaded from that page; after downloading, place it in /dataset/Train under the repository root.

4. I describe some methods and tricks for reading and writing larger-than-memory data with HDF5 in this blog post:

   http://www.cnblogs.com/nwpuxuezha/p/6537307.html

5. Comments and suggestions are welcome; feel free to share and learn from each other.

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
from keras.models import Sequential
from keras.layers import Convolution2D
from keras.optimizers import Adam


def srcnn(input_shape=(33, 33, 1)):
    # SRCNN 9-1-5 architecture (Keras 1 API): patch extraction,
    # non-linear mapping and reconstruction, all with 'valid' padding,
    # so a 33x33 input shrinks to a 21x21 output.
    model = Sequential()
    model.add(Convolution2D(64, 9, 9, border_mode='valid',
                            input_shape=input_shape, activation='relu'))
    model.add(Convolution2D(32, 1, 1, activation='relu'))
    model.add(Convolution2D(1, 5, 5))
    model.compile(Adam(lr=0.001), 'mse')
    return model

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import argparse

import h5py

import network


def train():
    model = network.srcnn()

    # Open the preprocessed HDF5 file and hand the datasets to Keras
    # directly; h5py reads slices on demand, so the training set never
    # has to fit in memory.
    h5f = h5py.File(args.data, 'r')
    X = h5f['input']
    y = h5f['label']

    n_epoch = args.n_epoch

    if not os.path.exists(args.save):
        os.mkdir(args.save)

    # Train in blocks of 5 epochs and checkpoint after each block.
    # shuffle='batch' shuffles whole batches rather than single
    # samples, which keeps the HDF5 reads sequential within a chunk.
    for epoch in range(0, n_epoch, 5):
        model.fit(X, y, batch_size=128, nb_epoch=5, shuffle='batch')
        if args.save:
            print("Saving model ", epoch + 5)
            model.save(os.path.join(args.save, 'model_%d.h5' % (epoch + 5)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-S', '--save',
                        default='./save',
                        dest='save',
                        type=str,
                        help="Path to save the checkpoints to")
    parser.add_argument('-D', '--data',
                        default='./data.h5',
                        dest='data',
                        type=str,
                        help="Preprocessed HDF5 training data file")
    parser.add_argument('-E', '--epoch',
                        default=50,
                        dest='n_epoch',
                        type=int,
                        help="Number of training epochs; must be a multiple of 5")
    args = parser.parse_args()
    print(args)
    train()

--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import argparse

import h5py
import numpy as np
from scipy import misc

import network

input_size = 33
label_size = 21
pad = (input_size - label_size) // 2


def ycbcr2rgb(im):
    # Standard YCbCr -> RGB conversion (helper, currently unused).
    xform = np.array([[1, 0, 1.402],
                      [1, -0.34414, -0.71414],
                      [1, 1.772, 0]])
    rgb = im.astype(np.float)
    rgb[:, :, [1, 2]] -= 128
    return rgb.dot(xform.T)


def predict():
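    # Rebuild the fully-convolutional SRCNN with the spatial dimensions
    # left as None so that one forward pass can process a whole image of
    # any size, then load the weights from the 'model_weights' group of
    # a checkpoint written by model.save() in train.py.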
    model = network.srcnn((None, None, 1))
    f = h5py.File(option.model, mode='r')
    model.load_weights_from_hdf5_group(f['model_weights'])
    f.close()

    X = misc.imread(option.input, mode='YCbCr')

    # Crop so both dimensions are divisible by the scale factor, then
    # copy the luminance channel into Cb and Cr: the network is trained
    # on luminance only, so every channel fed to it is the Y channel.
    w, h, c = X.shape
    w -= int(w % option.scale)
    h -= int(h % option.scale)
    X = X[0:w, 0:h, :]
    X[:, :, 1] = X[:, :, 0]
    X[:, :, 2] = X[:, :, 0]

    # Bicubic downscale followed by bicubic upscale produces the
    # blurred input that SRCNN is asked to restore.
    scaled = misc.imresize(X, 1.0 / option.scale, 'bicubic')
    scaled = misc.imresize(scaled, option.scale / 1.0, 'bicubic')
    newimg = np.zeros(scaled.shape)

    if option.baseline:
        misc.imsave(option.baseline,
                    scaled[pad: w - w % input_size, pad: h - h % input_size, :])

    # The valid convolutions trim `pad` pixels from every border, so the
    # prediction fills only the interior of newimg; misc.imsave rescales
    # the float result to the 0-255 range on save.
    newimg[pad:-pad, pad:-pad, 0, None] = model.predict(scaled[None, :, :, 0, None] / 255)
    newimg[pad:-pad, pad:-pad, 1, None] = model.predict(scaled[None, :, :, 1, None] / 255)
    newimg[pad:-pad, pad:-pad, 2, None] = model.predict(scaled[None, :, :, 2, None] / 255)
    misc.imsave(option.output, newimg)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-M', '--model',
                        default='./save/model_205.h5',
                        dest='model',
                        type=str,
                        help="The model checkpoint to use for prediction")
    parser.add_argument('-I', '--input-file',
                        default='./dataset/Test/Set5/baby_GT.bmp',
                        dest='input',
                        type=str,
                        help="Input image file path")
    parser.add_argument('-O', '--output-file',
                        default='./dataset/Test/Set5/baby_SRCNN.bmp',
                        dest='output',
                        type=str,
                        help="Output image file path")
    parser.add_argument('-B', '--baseline',
                        default='./dataset/Test/Set5/baby_bicubic.bmp',
                        dest='baseline',
                        type=str,
                        help="Baseline bicubic-interpolated image file path")
    parser.add_argument('-S', '--scale-factor',
                        default=3.0,
                        dest='scale',
                        type=float,
                        help="Scale factor")
    option = parser.parse_args()
    predict()

--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
from os import listdir
from os.path import isfile, join, exists
import os
import argparse

import h5py
from scipy import misc


def remove_if_exist(file_name):
    if exists(file_name):
        os.remove(file_name)


def preprocess_dataset(option, **kwargs):

    input_dir = option.input_dir
    output_file = option.output_file

    scale = kwargs.pop('scale', 3)
    input_size = kwargs.pop('input_size', 33)
    label_size = kwargs.pop('label_size', 21)
    channels = kwargs.pop('channels', 1)
    stride = kwargs.pop('stride', 14)
    chunks = kwargs.pop('chunks', 1024)

    pad = (input_size - label_size) // 2

    input_nums = 1024
    remove_if_exist(output_file)

    # Create resizable, chunked datasets: maxshape=(None, ...) lets the
    # first axis grow as patches are appended, and chunks of 128 samples
    # match the training batch size so one batch maps to one chunk.
    with h5py.File(output_file, 'w') as h5f:
        h5f.create_dataset("input", (input_nums, input_size, input_size, channels),
                           maxshape=(None, input_size, input_size, channels),
                           chunks=(128, input_size, input_size, channels),
                           dtype='float32')
        h5f.create_dataset("label", (input_nums, label_size, label_size, channels),
                           maxshape=(None, label_size, label_size, channels),
                           chunks=(128, label_size, label_size, channels),
                           dtype='float32')
        h5f.create_dataset("count", data=(0,))
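
    # Append loop: every patch is written straight into the on-disk
    # datasets, so memory use stays bounded by a single image plus one
    # patch no matter how large the training set is. Whenever the
    # datasets run out of room they are grown by another `chunks` rows.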
    count = 0
    h5f = h5py.File(output_file, 'a')
    for name in listdir(input_dir):
        path = join(input_dir, name)
        if not isfile(path):
            continue
        print(path)

        image = misc.imread(path, flatten=False, mode='YCbCr')

        # Crop to a multiple of the scale factor and keep only the
        # luminance channel.
        w, h, c = image.shape
        w -= int(w % scale)
        h -= int(h % scale)
        image = image[0:w, 0:h, 0]

        # Blurred input: bicubic downscale followed by bicubic upscale.
        scaled = misc.imresize(image, 1.0 / scale, 'bicubic')
        scaled = misc.imresize(scaled, scale / 1.0, 'bicubic')

        for i in range(0, h - input_size + 1, stride):
            for j in range(0, w - input_size + 1, stride):

                # Grow the datasets before the next write would overflow.
                if count == h5f['input'].shape[0]:
                    input_nums = count + chunks
                    h5f['input'].resize((input_nums, input_size, input_size, channels))
                    h5f['label'].resize((input_nums, label_size, label_size, channels))

                sub_img = scaled[j: j + input_size, i: i + input_size]
                sub_img = sub_img.reshape([1, input_size, input_size, 1])
                sub_img = sub_img / 255.0

                # The label is the central crop of the ground truth,
                # offset by `pad` to line up with the valid-convolution
                # output of the network.
                sub_img_label = image[j + pad: j + pad + label_size,
                                      i + pad: i + pad + label_size]
                sub_img_label = sub_img_label.reshape([1, label_size, label_size, 1])
                sub_img_label = sub_img_label / 255.0

                h5f['input'][count] = sub_img
                h5f['label'][count] = sub_img_label
                count += 1

    # Trim the datasets down to the number of patches actually written.
    h5f['input'].resize((count, input_size, input_size, channels))
    h5f['label'].resize((count, label_size, label_size, channels))
    h5f.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-I', '--input-dir',
                        default='./dataset/Train',
                        dest='input_dir',
                        type=str,
                        help="Data input directory")
    parser.add_argument('-O', '--output-file',
                        default='./data.h5',
                        dest='output_file',
                        type=str,
                        help="Data output file in HDF5 format")
    option = parser.parse_args()

    preprocess_dataset(option=option,
                       scale=3,
                       input_size=33,
                       label_size=21,
                       stride=14,
                       channels=1,
                       chunks=1024)

--------------------------------------------------------------------------------
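A minimal, self-contained sketch of the larger-than-memory HDF5 pattern the README describes (this sketch is not part of the original repo; the file name demo.h5 and the sizes are illustrative): create a chunked, resizable dataset, append to it in blocks, and read it back one batch at a time, so only a single batch is ever resident in memory.

import h5py
import numpy as np

batch = 128
with h5py.File('demo.h5', 'w') as f:
    # Resizable along axis 0, stored on disk in batch-sized chunks.
    dset = f.create_dataset('input', (0, 33, 33, 1),
                            maxshape=(None, 33, 33, 1),
                            chunks=(batch, 33, 33, 1), dtype='float32')
    for _ in range(10):  # append ten blocks without holding them all in RAM
        block = np.random.rand(batch, 33, 33, 1).astype('float32')
        dset.resize(dset.shape[0] + batch, axis=0)
        dset[-batch:] = block  # written straight to disk

with h5py.File('demo.h5', 'r') as f:
    dset = f['input']
    for start in range(0, dset.shape[0], batch):
        x = dset[start:start + batch]  # only one batch in memory
        # ... feed x to the model here, e.g. model.train_on_batch(...) ...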