├── .gitignore ├── doc └── error_rate_deepid_youtubeface.png ├── src ├── data_prepare │ ├── youtube_data_split.py │ ├── youtube_img_crop.py │ └── vectorize_img.py └── conv_net │ ├── load_data.py │ ├── deepid_generate.py │ ├── layers.py │ ├── sample_optimization.py │ └── deepid_class.py ├── README_ch.md └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | youtube_face 2 | *.pyc 3 | bak 4 | .DS_store 5 | -------------------------------------------------------------------------------- /doc/error_rate_deepid_youtubeface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stdcoutzyx/DeepID_FaceClassify/HEAD/doc/error_rate_deepid_youtubeface.png -------------------------------------------------------------------------------- /src/data_prepare/youtube_data_split.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | import os 5 | import sys 6 | import random 7 | 8 | def walk_through_folder_for_split(src_folder): 9 | test_set = [] 10 | train_set = [] 11 | 12 | label = 0 13 | for people_folder in os.listdir(src_folder): 14 | people_path = src_folder + people_folder + '/' 15 | video_folders = os.listdir(people_path) 16 | people_imgs = [] 17 | for video_folder in video_folders: 18 | video_path = people_path + video_folder + '/' 19 | img_files = os.listdir(video_path) 20 | for img_file in img_files: 21 | img_path = video_path + img_file 22 | people_imgs.append((img_path, label)) 23 | if len(people_imgs) < 25: 24 | continue 25 | random.shuffle(people_imgs) 26 | test_set += people_imgs[0:5] 27 | train_set += people_imgs[5:25] 28 | 29 | sys.stdout.write('\rdone: ' + str(label)) 30 | sys.stdout.flush() 31 | label += 1 32 | print '' 33 | print 'test set num: %d' % (len(test_set)) 34 | print 'train set num: %d' % (len(train_set)) 35 | return test_set, train_set 36 | 37 
| def set_to_csv_file(data_set, file_name): 38 | f = open(file_name, 'wb') 39 | for item in data_set: 40 | line = item[0] + ',' + str(item[1]) + '\n' 41 | f.write(line) 42 | f.close() 43 | 44 | if __name__ == '__main__': 45 | if len(sys.argv) != 4: 46 | print 'Usage: python %s src_folder test_set_file train_set_file' % (sys.argv[0]) 47 | sys.exit() 48 | src_folder = sys.argv[1] 49 | test_set_file = sys.argv[2] 50 | train_set_file = sys.argv[3] 51 | if not src_folder.endswith('/'): 52 | src_folder += '/' 53 | 54 | test_set, train_set = walk_through_folder_for_split(src_folder) 55 | set_to_csv_file(test_set, test_set_file) 56 | set_to_csv_file(train_set, train_set_file) 57 | -------------------------------------------------------------------------------- /src/data_prepare/youtube_img_crop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | import Image 5 | import sys 6 | import os 7 | 8 | def crop_img_by_half_center(src_file_path, dest_file_path): 9 | im = Image.open(src_file_path) 10 | x_size, y_size = im.size 11 | start_point_xy = x_size / 4 12 | end_point_xy = x_size / 4 + x_size / 2 13 | box = (start_point_xy, start_point_xy, end_point_xy, end_point_xy) 14 | new_im = im.crop(box) 15 | new_new_im = new_im.resize((47,55)) 16 | new_new_im.save(dest_file_path) 17 | 18 | def walk_through_the_folder_for_crop(aligned_db_folder, result_folder): 19 | if not os.path.exists(result_folder): 20 | os.mkdir(result_folder) 21 | 22 | i = 0 23 | img_count = 0 24 | for people_folder in os.listdir(aligned_db_folder): 25 | src_people_path = aligned_db_folder + people_folder + '/' 26 | dest_people_path = result_folder + people_folder + '/' 27 | if not os.path.exists(dest_people_path): 28 | os.mkdir(dest_people_path) 29 | for video_folder in os.listdir(src_people_path): 30 | src_video_path = src_people_path + video_folder + '/' 31 | dest_video_path = dest_people_path + video_folder + '/' 32 | 
if not os.path.exists(dest_video_path): 33 | os.mkdir(dest_video_path) 34 | for img_file in os.listdir(src_video_path): 35 | src_img_path = src_video_path + img_file 36 | dest_img_path = dest_video_path + img_file 37 | crop_img_by_half_center(src_img_path, dest_img_path) 38 | i += 1 39 | img_count += len(os.listdir(src_video_path)) 40 | sys.stdout.write('\rsub_folder: %d, imgs %d' % (i, img_count) ) 41 | sys.stdout.flush() 42 | print '' 43 | 44 | if __name__ == '__main__': 45 | if len(sys.argv) != 3: 46 | print 'Usage: python %s aligned_db_folder new_folder' % (sys.argv[0]) 47 | sys.exit() 48 | aligned_db_folder = sys.argv[1] 49 | result_folder = sys.argv[2] 50 | if not aligned_db_folder.endswith('/'): 51 | aligned_db_folder += '/' 52 | if not result_folder.endswith('/'): 53 | result_folder += '/' 54 | walk_through_the_folder_for_crop(aligned_db_folder, result_folder) 55 | 56 | -------------------------------------------------------------------------------- /src/conv_net/load_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | import cPickle 5 | import pickle 6 | import gzip 7 | import os 8 | import sys 9 | 10 | import theano 11 | import theano.tensor as T 12 | import numpy as np 13 | 14 | 15 | def shared_dataset(data_xy, borrow=True): 16 | data_x, data_y = data_xy 17 | shared_x = theano.shared( 18 | np.asarray(data_x, dtype=theano.config.floatX), 19 | borrow=borrow) 20 | shared_y = theano.shared( 21 | np.asarray(data_y, dtype=theano.config.floatX), 22 | borrow=borrow) 23 | return shared_x, T.cast(shared_y, 'int32') 24 | 25 | def load_data(dataset): 26 | f = gzip.open(dataset, 'rb') 27 | train_set, valid_set = cPickle.load(f)[0:2] 28 | f.close() 29 | train_set_x, train_set_y = shared_dataset(train_set) 30 | valid_set_x, valid_set_y = shared_dataset(valid_set) 31 | return [(train_set_x, train_set_y), (valid_set_x, valid_set_y)] 32 | 33 | def load_data(dataset): 34 | f = 
open(dataset, 'rb') 35 | train_set, valid_set = pickle.load(f)[0:2] 36 | f.close() 37 | train_set_x, train_set_y = shared_dataset(train_set) 38 | valid_set_x, valid_set_y = shared_dataset(valid_set) 39 | return [(train_set_x, train_set_y), (valid_set_x, valid_set_y)] 40 | 41 | def load_data_split_pickle(dataset): 42 | def get_files(vec_folder): 43 | file_names = os.listdir(vec_folder) 44 | file_names.sort() 45 | if not vec_folder.endswith('/'): 46 | vec_folder += '/' 47 | for i in range(len(file_names)): 48 | file_names[i] = vec_folder + file_names[i] 49 | return file_names 50 | 51 | def load_data_xy(file_names): 52 | datas = [] 53 | labels = [] 54 | for file_name in file_names: 55 | f = open(file_name, 'rb') 56 | x, y = pickle.load(f) 57 | datas.append(x) 58 | labels.append(y) 59 | combine_d = np.vstack(datas) 60 | combine_l = np.hstack(labels) 61 | return combine_d, combine_l 62 | 63 | valid_folder, train_folder = dataset 64 | valid_file_names = get_files(valid_folder) 65 | train_file_names = get_files(train_folder) 66 | valid_set = load_data_xy(valid_file_names) 67 | train_set = load_data_xy(train_file_names) 68 | 69 | train_set_x, train_set_y = shared_dataset(train_set) 70 | valid_set_x, valid_set_y = shared_dataset(valid_set) 71 | 72 | return [(train_set_x, train_set_y), (valid_set_x, valid_set_y)] 73 | 74 | -------------------------------------------------------------------------------- /src/data_prepare/vectorize_img.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | import os 5 | import sys 6 | import random 7 | import numpy as np 8 | import Image 9 | 10 | def read_csv_file(csv_file): 11 | path_and_labels = [] 12 | f = open(csv_file, 'rb') 13 | for line in f: 14 | line = line.strip('\r\n') 15 | path, label = line.split(',') 16 | label = int(label) 17 | path_and_labels.append((path, label)) 18 | f.close() 19 | random.shuffle(path_and_labels) 20 | return 
path_and_labels 21 | 22 | def vectorize_imgs(path_and_labels, image_size): 23 | image_vector_len = np.prod(image_size) 24 | 25 | arrs = [] 26 | labels = [] 27 | i = 0 28 | for path_and_label in path_and_labels: 29 | path, label = path_and_label 30 | img = Image.open(path) 31 | arr_img = np.asarray(img, dtype='float64') 32 | arr_img = arr_img.transpose(2,0,1).reshape((image_vector_len, )) 33 | 34 | labels.append(label) 35 | arrs.append(arr_img) 36 | 37 | i += 1 38 | if i % 100 == 0: 39 | sys.stdout.write('\rdone: ' + str(i)) 40 | sys.stdout.flush() 41 | print '' 42 | arrs = np.asarray(arrs, dtype='float64') 43 | labels = np.asarray(labels, dtype='int32') 44 | return (arrs, labels) 45 | 46 | def cPickle_output(vars, file_name): 47 | import cPickle 48 | f = open(file_name, 'wb') 49 | cPickle.dump(vars, f, protocol=cPickle.HIGHEST_PROTOCOL) 50 | f.close() 51 | 52 | def output_data(vector_vars, vector_folder, batch_size=1000): 53 | if not vector_folder.endswith('/'): 54 | vector_folder += '/' 55 | if not os.path.exists(vector_folder): 56 | os.mkdir(vector_folder) 57 | x, y = vector_vars 58 | n_batch = len(x) / batch_size 59 | for i in range(n_batch): 60 | file_name = vector_folder + str(i) + '.pkl' 61 | batch_x = x[ i*batch_size: (i+1)*batch_size] 62 | batch_y = y[ i*batch_size: (i+1)*batch_size] 63 | cPickle_output((batch_x, batch_y), file_name) 64 | if n_batch * batch_size < len(x): 65 | batch_x = x[n_batch*batch_size: ] 66 | batch_y = y[n_batch*batch_size: ] 67 | file_name = vector_folder + str(n_batch) + '.pkl' 68 | cPickle_output((batch_x, batch_y), file_name) 69 | 70 | if __name__ == '__main__': 71 | if len(sys.argv) != 5: 72 | print 'Usage: python %s test_set_file train_set_file test_vector_folder train_vector_folder' % (sys.argv[0]) 73 | sys.exit() 74 | test_set_file = sys.argv[1] 75 | train_set_file = sys.argv[2] 76 | test_vector_folder = sys.argv[3] 77 | train_vector_folder = sys.argv[4] 78 | 79 | test_path_and_labels = read_csv_file(test_set_file) 80 | 
train_path_and_labels = read_csv_file(train_set_file) 81 | 82 | print 'test img num: %d' % (len(test_path_and_labels)) 83 | print 'train img num: %d' % (len(train_path_and_labels)) 84 | 85 | img_size = (3, 55, 47) # channel, height, width 86 | test_vec = vectorize_imgs(test_path_and_labels, img_size) 87 | train_vec = vectorize_imgs(train_path_and_labels, img_size) 88 | 89 | output_data(test_vec, test_vector_folder) 90 | output_data(train_vec, train_vector_folder) 91 | 92 | -------------------------------------------------------------------------------- /src/conv_net/deepid_generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | from layers import * 5 | from deepid_class import * 6 | 7 | import cPickle 8 | import gzip 9 | import os 10 | import sys 11 | import time 12 | import numpy 13 | import theano 14 | 15 | class DeepIDGenerator: 16 | def __init__(self, exist_params): 17 | self.rng = numpy.random.RandomState(1234) 18 | self.exist_params = exist_params 19 | 20 | def layer_params(self, nkerns, batch_size): 21 | src_channel = 3 22 | self.layer1_image_shape = (batch_size, src_channel, 55, 47) 23 | self.layer1_filter_shape = (nkerns[0], src_channel, 4, 4) 24 | self.layer2_image_shape = (batch_size, nkerns[0], 26, 22) 25 | self.layer2_filter_shape = (nkerns[1], nkerns[0], 3, 3) 26 | self.layer3_image_shape = (batch_size, nkerns[1], 12, 10) 27 | self.layer3_filter_shape = (nkerns[2], nkerns[1], 3, 3) 28 | self.layer4_image_shape = (batch_size, nkerns[2], 5, 4) 29 | self.layer4_filter_shape = (nkerns[3], nkerns[2], 2, 2) 30 | self.result_image_shape = (batch_size, nkerns[3], 4, 3) 31 | 32 | def build_layer_architecture(self, n_hidden, acti_func=relu): 33 | ''' 34 | simple means the deepid layer input is only the layer4 output. 
35 | layer1: convpool layer 36 | layer2: convpool layer 37 | layer3: convpool layer 38 | layer4: conv layer 39 | deepid: hidden layer 40 | ''' 41 | x = T.matrix('x') 42 | 43 | print '\tbuilding the model ...' 44 | 45 | layer1_input = x.reshape(self.layer1_image_shape) 46 | self.layer1 = LeNetConvPoolLayer(self.rng, 47 | input = layer1_input, 48 | image_shape = self.layer1_image_shape, 49 | filter_shape = self.layer1_filter_shape, 50 | poolsize = (2,2), 51 | W = self.exist_params[5][0], 52 | b = self.exist_params[5][1], 53 | activation = acti_func) 54 | 55 | self.layer2 = LeNetConvPoolLayer(self.rng, 56 | input = self.layer1.output, 57 | image_shape = self.layer2_image_shape, 58 | filter_shape = self.layer2_filter_shape, 59 | poolsize = (2,2), 60 | W = self.exist_params[4][0], 61 | b = self.exist_params[4][1], 62 | activation = acti_func) 63 | 64 | self.layer3 = LeNetConvPoolLayer(self.rng, 65 | input = self.layer2.output, 66 | image_shape = self.layer3_image_shape, 67 | filter_shape = self.layer3_filter_shape, 68 | poolsize = (2,2), 69 | W = self.exist_params[3][0], 70 | b = self.exist_params[3][1], 71 | activation = acti_func) 72 | 73 | self.layer4 = LeNetConvLayer(self.rng, 74 | input = self.layer3.output, 75 | image_shape = self.layer4_image_shape, 76 | filter_shape = self.layer4_filter_shape, 77 | W = self.exist_params[2][0], 78 | b = self.exist_params[2][1], 79 | activation = acti_func) 80 | 81 | layer3_output_flatten = self.layer3.output.flatten(2) 82 | layer4_output_flatten = self.layer4.output.flatten(2) 83 | deepid_input = T.concatenate([layer3_output_flatten, layer4_output_flatten], axis=1) 84 | 85 | self.deepid_layer = HiddenLayer(self.rng, 86 | input = deepid_input, 87 | n_in = numpy.prod( self.result_image_shape[1:] ) + numpy.prod( self.layer4_image_shape[1:] ), 88 | n_out = n_hidden, 89 | W = self.exist_params[1][0], 90 | b = self.exist_params[1][1], 91 | activation = acti_func) 92 | 93 | self.generator = theano.function(inputs=[x], 94 | 
outputs=self.deepid_layer.output) 95 | 96 | def generate_deepid(self, x): 97 | print '\tgenerating ...' 98 | deepid_data = self.generator(x) 99 | return deepid_data 100 | 101 | def deepid_generating(dataset_folder, params_file, result_folder, nkerns, n_hidden, acti_func=relu): 102 | if not dataset_folder.endswith('/'): 103 | dataset_folder += '/' 104 | if not result_folder.endswith('/'): 105 | result_folder += '/' 106 | if not os.path.exists(result_folder): 107 | os.mkdir(result_folder) 108 | 109 | pd_helper = ParamDumpHelper(params_file) 110 | exist_params = pd_helper.get_params_from_file() 111 | if len(exist_params) != 0: 112 | exist_params = exist_params[-1] 113 | else: 114 | print 'error, no trained params' 115 | return 116 | 117 | dataset_files = os.listdir(dataset_folder) 118 | for dataset_file in dataset_files: 119 | dataset_path = dataset_folder + dataset_file 120 | result_path = result_folder + dataset_file 121 | x, y = load_data_xy(dataset_path) 122 | deepid = DeepIDGenerator(exist_params) 123 | deepid.layer_params(nkerns, x.shape[0]) 124 | deepid.build_layer_architecture(n_hidden, acti_func) 125 | new_x = deepid.generate_deepid(x) 126 | cPickle_output((new_x, y), result_path) 127 | 128 | def load_data_xy(dataset_path): 129 | print 'loading data of %s' % (dataset_path) 130 | f = open(dataset_path, 'rb') 131 | x, y = pickle.load(f) 132 | f.close() 133 | return x,y 134 | 135 | def cPickle_output(vars, file_name): 136 | print '\twriting data to %s' % (file_name) 137 | import cPickle 138 | f = open(file_name, 'wb') 139 | cPickle.dump(vars, f, protocol=cPickle.HIGHEST_PROTOCOL) 140 | f.close() 141 | 142 | if __name__ == '__main__': 143 | if len(sys.argv) != 4: 144 | print 'Usage: python %s dataset_folder params_file result_folder' % (sys.argv[0]) 145 | sys.exit() 146 | 147 | dataset_folder = sys.argv[1] 148 | params_file = sys.argv[2] 149 | result_folder = sys.argv[3] 150 | nkerns = [20,40,60,80] 151 | n_hidden = 160 152 | 153 | 
deepid_generating(dataset_folder, params_file, result_folder, nkerns, n_hidden, acti_func=relu) 154 | -------------------------------------------------------------------------------- /src/conv_net/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | import numpy 5 | import theano 6 | import theano.tensor as T 7 | from theano.tensor.signal import downsample 8 | from theano.tensor.nnet import conv 9 | 10 | def relu(x): 11 | return T.switch(x<0, 0, x) 12 | 13 | class LogisticRegression(object): 14 | def __init__(self, input, n_in, n_out, W=None, b=None): 15 | if W is None: 16 | W = theano.shared( 17 | value=numpy.zeros( (n_in, n_out), dtype=theano.config.floatX), 18 | name='W', 19 | borrow=True) 20 | if b is None: 21 | b = theano.shared( 22 | value=numpy.zeros( (n_out,), dtype=theano.config.floatX), 23 | name='b', 24 | borrow=True) 25 | self.W = W 26 | self.b = b 27 | 28 | self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) 29 | self.y_pred = T.argmax(self.p_y_given_x, axis=1) 30 | self.params = [self.W, self.b] 31 | 32 | def negative_log_likelihood(self, y): 33 | return -T.mean( T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) 34 | 35 | def errors(self, y): 36 | if y.ndim != self.y_pred.ndim: 37 | raise TypeError('y should hava the same shape as self.y_pred', 38 | ('y', target.type, 'y_pred', self.y_pred,type)) 39 | if y.dtype.startswith('int'): 40 | return T.mean(T.neq(self.y_pred, y)) 41 | else: 42 | raise NotImplementedError() 43 | 44 | class HiddenLayer(object): 45 | def __init__(self, rng, input, n_in, n_out, W=None, b=None, activation=T.tanh): 46 | self.input = input 47 | if W is None: 48 | W_values = numpy.asarray( rng.uniform( 49 | low = -numpy.sqrt(6. / (n_in + n_out)), 50 | high = numpy.sqrt(6. 
/ (n_in + n_out)), 51 | size = (n_in, n_out)), 52 | dtype = theano.config.floatX) 53 | if activation == theano.tensor.nnet.sigmoid: 54 | W_values *= 4 55 | W = theano.shared(value=W_values, name='W', borrow=True) 56 | 57 | if b is None: 58 | b_values = numpy.zeros( (n_out,), dtype=theano.config.floatX) 59 | b = theano.shared(value=b_values, name='b', borrow=True) 60 | 61 | self.W = W 62 | self.b = b 63 | lin_output = T.dot(input, self.W) + self.b 64 | self.output = (lin_output if activation is None else activation(lin_output)) 65 | 66 | self.params = [self.W, self.b] 67 | 68 | class MLP(object): 69 | def __init__(self, rng, input, n_in, n_hidden, n_out): 70 | self.hiddenLayer = HiddenLayer( rng=rng, input=input, 71 | n_in=n_in, n_out=n_hidden, 72 | activation=T.tanh) 73 | self.logRegressionLayer = LogisticRegression( 74 | input=self.hiddenLayer.output, 75 | n_in=n_hidden, 76 | n_out=n_out) 77 | self.L1 = abs(self.hiddenLayer.W).sum() + abs(self.logRegressionLayer.W).sum() 78 | self.L2_sqr = (self.hiddenLayer.W ** 2).sum() + (self.logRegressionLayer.W ** 2).sum() 79 | self.negative_log_likelihood = self.logRegressionLayer.negative_log_likelihood 80 | self.errors = self.logRegressionLayer.errors 81 | self.params = self.hiddenLayer.params + self.logRegressionLayer.params 82 | 83 | class LeNetConvLayer(object): 84 | def __init__(self, rng, input, filter_shape, image_shape, W=None, b=None, activation=T.tanh): 85 | assert image_shape[1] == filter_shape[1] 86 | self.input = input 87 | fan_in = numpy.prod(filter_shape[1:]) 88 | fan_out = filter_shape[0] * numpy.prod(filter_shape[2:]) 89 | 90 | if W is None: 91 | W_bound = numpy.sqrt(6. 
/(fan_in + fan_out)) 92 | W = theano.shared( 93 | numpy.asarray( 94 | rng.uniform(low=-W_bound, high=W_bound, size=filter_shape), 95 | dtype=theano.config.floatX), 96 | borrow=True) 97 | if b is None: 98 | b_values = numpy.zeros( (filter_shape[0],), dtype=theano.config.floatX) 99 | b = theano.shared(value=b_values, borrow=True) 100 | self.W = W 101 | self.b = b 102 | 103 | conv_out = conv.conv2d( 104 | input = input, 105 | filters = self.W, 106 | filter_shape = filter_shape, 107 | image_shape = image_shape) 108 | conv_out = (conv_out + self.b.dimshuffle('x', 0, 'x', 'x')) 109 | self.output = conv_out if activation is None else activation(conv_out) 110 | self.params = [self.W, self.b] 111 | 112 | class PoolLayer(object): 113 | def __init__(self, input, poolsize=(2,2)): 114 | pooled_out = downsample.max_pool_2d( 115 | input = input, 116 | ds = poolsize, 117 | ignore_border = True) 118 | self.output = pooled_out 119 | 120 | class LeNetConvPoolLayer(object): 121 | def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2,2), W=None, b=None, activation=T.tanh): 122 | assert image_shape[1] == filter_shape[1] 123 | self.input = input 124 | fan_in = numpy.prod(filter_shape[1:]) 125 | fan_out = filter_shape[0] * numpy.prod(filter_shape[2:]) / numpy.prod(poolsize) 126 | 127 | if W is None: 128 | W_bound = numpy.sqrt(6. 
/(fan_in + fan_out)) 129 | W = theano.shared( 130 | numpy.asarray( 131 | rng.uniform(low=-W_bound, high=W_bound, size=filter_shape), 132 | dtype=theano.config.floatX), 133 | borrow=True) 134 | if b is None: 135 | b_values = numpy.zeros( (filter_shape[0],), dtype=theano.config.floatX) 136 | b = theano.shared(value=b_values, borrow=True) 137 | self.W = W 138 | self.b = b 139 | 140 | conv_out = conv.conv2d( 141 | input = input, 142 | filters = self.W, 143 | filter_shape = filter_shape, 144 | image_shape = image_shape) 145 | conv_out = conv_out + self.b.dimshuffle('x', 0, 'x', 'x') 146 | if not activation is None: 147 | conv_out = activation(conv_out) 148 | 149 | pooled_out = downsample.max_pool_2d( 150 | input = conv_out, 151 | ds = poolsize, 152 | ignore_border=True) 153 | self.output = pooled_out 154 | self.params = [self.W, self.b] 155 | -------------------------------------------------------------------------------- /README_ch.md: -------------------------------------------------------------------------------- 1 | # DeepID实践 2 | 3 | 好久没有写博客了,I have failed my blog. 目前人脸验证算法可以说是DeepID最强,本文使用theano对DeepID进行实现。关于deepid的介绍,可以参见我这一片博文 [DeepID之三代](http://blog.csdn.net/stdcoutzyx/article/details/42091205)。 4 | 5 | 当然DeepID最强指的是DeepID和联合贝叶斯两个算法,本文中只实现了DeepID神经网络,并用它作为特征提取器来应用在其他任务上。 6 | 7 | 本文所用到的代码工程在github上:[DeepID_FaceClassify](https://github.com/stdcoutzyx/DeepID_FaceClassify) 8 | 9 | # 实践流程 10 | 11 | ## 环境配置 12 | 13 | 本工程使用theano库,所以在实验之前,theano环境是必须要配的,theano环境配置可以参见[theano document](http://deeplearning.net/software/theano/install.html#install)。文档已经较为全面,本文不再赘述,在下文中,均假设读者已经装好了theano。 14 | 15 | 16 | ## 代码概览 17 | 18 | 本文所用到的代码结构如下: 19 | 20 |
21 | src/ 22 | ├── conv_net 23 | │ ├── deepid_class.py 24 | │ ├── deepid_generate.py 25 | │ ├── layers.py 26 | │ ├── load_data.py 27 | │ └── sample_optimization.py 28 | └── data_prepare 29 | ├── vectorize_img.py 30 | ├── youtube_data_split.py 31 | └── youtube_img_crop.py 32 |33 | 34 | 正如文件名命名所指出的,代码共分为两个模块,即数据准备模块(`data_prepare`)和卷积神经网络模块(`conv_net`)。 35 | 36 | 37 | ## 数据准备 38 | 39 | 我觉得DeepID的强大得益于两个因素,卷积神经网络的结构和数据,数据对于DeepID或者说对任何的卷积神经网络都非常重要。 40 | 41 | 可惜的是,我去找过论文作者要过数据,可是被婉拒。所以在本文的实验中,我使用的数据并非论文中的数据。经过下面的描述你可以知道,如果你还有其他的数据,可以很轻松的用python将其处理为本文DeepID网络的输入数据。 42 | 43 | 以youtube face数据为例。它的文件夹结构如下所示,包含三级结构,第一是以人为单位,然后每个人有不同的视频,每个视频中采集出多张人脸图像。 44 | 45 |
46 | youtube_data/ 47 | ├── people_folderA 48 | │ ├── video_folderA 49 | │ │ ├── img1.jpg 50 | │ │ ├── img2.jpg 51 | │ │ └── imgN.jpg 52 | │ └── video_folderB 53 | └── people_folderB 54 |55 | 56 | 拿到youtube face数据以后,需要做如下两件事: 57 | 58 | - 对图像进行预处理,原来的youtube face图像中,人脸只占中间很小的一部分,我们对其进行裁剪,使人脸的比例变大。同时,将图像缩放为(47,55)大小。 59 | - 将数据集合划分为训练集和验证集。本文中划分训练集和验证集的方式如下: 60 | - 对于每一个人,将其不同视频下的图像混合在一起 61 | - 随机化 62 | - 选择前5张作为验证集,第6-25张作为训练集。 63 | 64 | 经过划分后,得到7975张验证集和31900训练集。显然,根据这两个数字你可以算出一共有1595个类(人)。 65 | 66 | ## 数据准备的代码使用 67 | 68 | **注意:** 数据准备模块中以youtube为前缀的程序是专门用来处理youtube数据的,因为其他数据可能图像属性和文件夹的结构不一样。如果你使用了其他数据,请阅读`youtube_img_crop.py`和`youtube_data_split.py`代码,然后重新写出适合自己数据的代码。数据预处理代码都很简单,相信在我代码的基础上,不需要改太多,就能适应另一种数据了。 69 | 70 | ### youtube_img_crop.py 71 | 72 | 被用来裁剪图片,youtube face数据中图像上人脸的比例都相当小,本程序用于将图像的边缘裁减掉,然后将图像缩放为47×55(DeepID的输入图像大小)。 73 | 74 | Usage: python youtube_img_crop.py aligned_db_folder new_folder 75 | 76 | - aligned_db_folder: 原始文件夹 77 | - new_folder: 结果文件夹,与原始文件夹的文件夹结构一样,只不过图像是被处理后的图像。 78 | 79 | ### youtube_data_split.py 80 | 81 | 用来切分数据,将数据分为训练集和验证集。 82 | 83 | Usage: python youtube_data_split.py src_folder test_set_file train_set_file 84 | 85 | - src_folder: 原始文件夹,此处应该为上一步得到的新文件夹 86 | - test_set_file: 验证集图片路径集合文件 87 | - train_set_file: 训练集图片路径集合文件 88 | 89 | `test_set_file`和`train_set_file`的格式如下,每一行分为两部分,第一部分是图像路径,第二部分是图像的类别标记。 90 | 91 | ``` 92 | youtube_47_55/Alan_Ball/2/aligned_detect_2.405.jpg,0 93 | youtube_47_55/Alan_Ball/2/aligned_detect_2.844.jpg,0 94 | youtube_47_55/Xiang_Liu/5/aligned_detect_5.1352.jpg,1 95 | youtube_47_55/Xiang_Liu/1/aligned_detect_1.482.jpg,1 96 | ``` 97 | 98 | ### vectorize_img.py 99 | 100 | 用来将图像向量化,每张图像都是47×55的,所以每张图片变成一个47×55×3的向量。 101 | 102 | 为了避免超大文件的出现,本程序自动将数据切分为小文件,每个小文件中只有1000张图片,即1000×(47×55×3)的矩阵。当然,最后一个小文件不一定是1000张。 103 | 104 | Usage: python vectorize_img.py test_set_file train_set_file test_vector_folder train_vector_folder 105 | 106 | - test_set_file: `*_data_split.py`生成的 107 | - train_set_file:
`*_data_split.py`生成的 108 | - test_vector_folder: 存储验证集向量文件的文件夹名称 109 | - train_vector_folder: 存储训练集向量文件的文件夹名称 110 | 111 | ## Conv_Net 112 | 113 | 走完了漫漫前路,终于可以直捣黄龙了。现在是DeepID时间。吼吼哈嘿。 114 | 115 | 在conv_net模块中,有五个程序文件 116 | 117 | - layers.py: 卷积神经网络相关的各种层次的定义,包括逻辑斯蒂回归层、隐含层、卷积层、max_pooling层等 118 | - load_data.py: 为DeepID载入数据。 119 | - sample_optimization.py: 针对各种层次的一些测试实验。 120 | - deepid_class.py: DeepID主程序 121 | - deepid_generate.py: 根据DeepID训练好的参数,来将隐含层抽取出来 122 | 123 | ## Conv_Net代码使用 124 | 125 | ### deepid_class.py 126 | 127 | Usage: python deepid_class.py vec_valid vec_train params_file 128 | 129 | - vec_valid: `vectorize_img.py`生成的 130 | - vec_train: `vectorize_img.py`生成的 131 | - params_file: 用来存储训练时每次迭代的参数,可以被用来断点续跑,由于CNN程序一般需要较长时间,万一遇到停电啥的,就可以用得上了。自然,更大的用途是保存参数后用来抽取特征。 132 | 133 | **注意:** 134 | 135 | DeepID训练过程有太多的参数需要调整,为了程序使用简便,我并没有把这些参数都使用命令行传参。如果你想要改变迭代次数、学习速率、批大小等参数,请在程序的最后一行调用函数里改。 136 | 137 | 138 | ### deepid_generate.py 139 | 140 | 可以使用下面的命令来抽取DeepID的隐含层,即160-d的那一层。 141 | 142 | Usage: python deepid_generate.py dataset_folder params_file result_folder 143 | 144 | - dataset_folder: 可以是训练集向量文件夹或者验证集向量文件夹。 145 | - params_file: `deepid_class.py`训练得到 146 | - result_folder: 结果文件夹,其下的文件与dataset_folder中文件的文件名一一对应,但是结果文件夹中的向量的长度变为160而不是原来的7755。 147 | 148 | # 效果展示 149 | 150 | ## DeepID 效果 151 | 152 | 跑完`deepid_class.py`以后,你可以得到输出如下。输出可以分为两部分,第一部分是每次迭代以及每个小batch的训练集误差,验证集误差等。第二部分是一个汇总,将`epoch train error valid error`.
按照统一格式打印了出来。 153 | 154 | ``` 155 | epoch 15, train_score 0.000444, valid_score 0.066000 156 | epoch 16, minibatch_index 62/63, error 0.000000 157 | epoch 16, train_score 0.000413, valid_score 0.065733 158 | epoch 17, minibatch_index 62/63, error 0.000000 159 | epoch 17, train_score 0.000508, valid_score 0.065333 160 | epoch 18, minibatch_index 62/63, error 0.000000 161 | epoch 18, train_score 0.000413, valid_score 0.070267 162 | epoch 19, minibatch_index 62/63, error 0.000000 163 | epoch 19, train_score 0.000413, valid_score 0.064533 164 | 165 | 0 0.974349206349 0.962933333333 166 | 1 0.890095238095 0.897466666667 167 | 2 0.70126984127 0.666666666667 168 | 3 0.392031746032 0.520133333333 169 | 4 0.187619047619 0.360666666667 170 | 5 0.20526984127 0.22 171 | 6 0.054380952381 0.171066666667 172 | 7 0.0154920634921 0.128 173 | 8 0.00650793650794 0.100133333333 174 | 9 0.00377777777778 0.0909333333333 175 | 10 0.00292063492063 0.086 176 | 11 0.0015873015873 0.0792 177 | 12 0.00133333333333 0.0754666666667 178 | 13 0.00111111111111 0.0714666666667 179 | 14 0.000761904761905 0.068 180 | 15 0.000444444444444 0.066 181 | 16 0.000412698412698 0.0657333333333 182 | 17 0.000507936507937 0.0653333333333 183 | 18 0.000412698412698 0.0702666666667 184 | 19 0.000412698412698 0.0645333333333 185 | ``` 186 | 187 | 上述数据画成折线图如下: 188 | 189 |  190 | 191 | ## 向量抽取效果展示 192 | 193 | 运行`deepid_generate.py`之后, 可以得到输出如下: 194 | 195 | ``` 196 | loading data of vec_test/0.pkl 197 | building the model ... 198 | generating ... 199 | writing data to deepid_test/0.pkl 200 | loading data of vec_test/3.pkl 201 | building the model ... 202 | generating ... 203 | writing data to deepid_test/3.pkl 204 | loading data of vec_test/1.pkl 205 | building the model ... 206 | generating ... 207 | writing data to deepid_test/1.pkl 208 | loading data of vec_test/7.pkl 209 | building the model ... 210 | generating ... 
211 | writing data to deepid_test/7.pkl 212 | ``` 213 | 214 | 程序会对向量化文件夹内的每一个文件进行抽取操作,得到对应的160-d向量化文件。 215 | 216 | 将隐含层抽取出来后,我们可以在一些其他领域上验证该特征的有效性,比如图像检索。可以使用我的另一个github工程进行测试,[这是链接](https://github.com/stdcoutzyx/FaceRetrieval)。使用验证集做查询集,训练集做被查询集,来看一下检索效果如何。 217 | 218 | 为了做对比,本文在youtube face数据上做了两个人脸检索实验。 219 | 220 | 221 | - PCA exp. 在 `vectorize_img.py`生成的数据上,使用PCA将特征降到160-d,然后进行人脸检索实验。 222 | - DeepID exp. 在 `deepid_generate.py`生成的160-d数据上直接进行人脸检索实验。 223 | 224 | **注意:** 在两个实验中,我都使用cosine相似度计算距离,之前做过很多实验,cosine距离比欧式距离要好。 225 | 226 | 人脸检索结果如下: 227 | 228 | |Precision| Top-1| Top-5| Top-10| 229 | |---------|------|------|-------| 230 | |PCA |95.20%|96.75%|97.22% | 231 | |DeepID |97.27%|97.93%|98.25% | 232 | 233 | |AP | Top-1| Top-5| Top-10| 234 | |---------|------|------|-------| 235 | |PCA |95.20%|84.19%|70.66% | 236 | |DeepID |97.27%|89.22%|76.64% | 237 | 238 | Precision意味着在top-N结果中只要出现相同类别的人,就算这次查询成功,否则失败。而AP则意味着,在top-N结果中需要统计与查询图片相同类别的图片有多少张,然后除以N,是这次查询的准确率,然后再求平均。 239 | 240 | 从结果中可以看到,在相同维度下,DeepID在信息的表达上还是要强于PCA的。 241 | 242 | # 参考文献 243 | 244 | [1]. Sun Y, Wang X, Tang X. Deep learning face representation from predicting 10,000 classes[C]//Computer Vision and Pattern Recognition (CVPR), 2014 IEEE Conference on. IEEE, 2014: 1891-1898. 245 | 246 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepID_FaceClassify 2 | 3 | Implementation of DeepID using theano. 4 | 5 | You can also read a Chinese version of this doc in [chinese version](https://github.com/stdcoutzyx/DeepID_FaceClassify/blob/master/README_ch.md), or in my [blog](http://blog.csdn.net/stdcoutzyx/article/details/45570221). 6 | 7 | # Usage 8 | 9 | ## Environment 10 | 11 | You have to install theano and related libs. There is enough information in the [theano document](http://deeplearning.net/software/theano/install.html#install).
So I will assume that all readers have installed theano correctly. 12 | 13 | ## Implemented Programmes 14 | 15 | The structure of my code looks like this: 16 | 17 |
18 | src/ 19 | ├── conv_net 20 | │ ├── deepid_class.py 21 | │ ├── deepid_generate.py 22 | │ ├── layers.py 23 | │ ├── load_data.py 24 | │ └── sample_optimization.py 25 | └── data_prepare 26 | ├── vectorize_img.py 27 | ├── youtube_data_split.py 28 | └── youtube_img_crop.py 29 |30 | 31 | As the folder names imply, my code consists of two independent modules that do not reference each other. The `data_prepare` module prepares the data, and the `conv_net` module is the implementation of DeepID. 32 | 33 | 34 | ## Data Preparation 35 | 36 | ### Details 37 | 38 | Two things are important and necessary for the impressive performance of DeepID: the structure of the Convolutional Neural Network and the data. 39 | 40 | I asked the authors for their data but received only a polite refusal, so my experiment uses other data instead. 41 | 42 | Take the youtube face data as an example. It has three levels of folders, as shown below: 43 | 44 |
45 | youtube_data/ 46 | ├── people_folderA 47 | │ ├── video_folderA 48 | │ │ ├── img1.jpg 49 | │ │ ├── img2.jpg 50 | │ │ └── imgN.jpg 51 | │ └── video_folderB 52 | └── people_folderB 53 |54 | 55 | Besides cropping the imgs (done by `youtube_img_crop.py`, described below), the data need to be separated into a train set and a validation set. I choose the two sets as follows: 56 | 57 | - Mix the imgs of the same person from different videos together. 58 | - Randomly shuffle them. 59 | - Choose the first 5 imgs as the validation set. 60 | - Choose the 6th to 25th imgs as the train set (people with fewer than 25 imgs are skipped). 61 | 62 | In the end, I get 7975 imgs as the validation set and 31900 imgs as the train set, which means there are 1595 classes (persons) in total. 63 | 64 | ### Usage of Code 65 | 66 | **Note:** the files prefixed with "youtube" are specific to the youtube data because of its folder structure and img properties. So if you want to use another dataset, please read the code of `*_img_crop.py` and `*_data_split.py` and re-implement them. I believe the code is readable and easy to understand. 67 | 68 | #### youtube_img_crop.py 69 | 70 | Used to get the face out of the img. Faces in the youtube data have already been aligned to the center of the img, so this programme increases the ratio of the face in the whole img and resizes the img to (47,55), which is the input size of DeepID. 71 | 72 | Usage: python youtube_img_crop.py aligned_db_folder new_folder 73 | 74 | - aligned_db_folder: source folder 75 | - new_folder: The programme will generate the same folder structure as the source folder, with all the imgs cropped and resized to the new size. 76 | 77 | #### youtube_data_split.py 78 | 79 | Used to split the data into two sets: one for training and one for validation. 80 | 81 | Usage: python youtube_data_split.py src_folder test_set_file train_set_file 82 | 83 | The format of test_set_file and train_set_file is shown below.
Each line has two comma-separated parts: the first is the path of the img, the second is the label of the img. 84 | 85 | ``` 86 | youtube_47_55/Alan_Ball/2/aligned_detect_2.405.jpg,0 87 | youtube_47_55/Alan_Ball/2/aligned_detect_2.844.jpg,0 88 | youtube_47_55/Xiang_Liu/5/aligned_detect_5.1352.jpg,1 89 | youtube_47_55/Xiang_Liu/1/aligned_detect_1.482.jpg,1 90 | ``` 91 | 92 | #### vectorize_img.py 93 | 94 | Used to vectorize the imgs: it turns thousands of imgs into a 2-D array of size (m,n), where m is the number of samples and n is 47×55×3 = 7755. 95 | 96 | To avoid producing a huge file, `vectorize_img.py` automatically splits the data into batches of 1000 samples each. 97 | 98 | Usage: python vectorize_img.py test_set_file train_set_file test_vector_folder train_vector_folder 99 | 100 | - test_set_file: generated by *_data_split.py 101 | - train_set_file: generated by *_data_split.py 102 | - test_vector_folder: the folder to store the vector files of the validation set 103 | - train_vector_folder: the folder to store the vector files of the train set 104 | 105 | 106 | ## Conv_Net 107 | 108 | ### Details 109 | 110 | Now it's the exciting time. 111 | 112 | In the conv_net module, there are five programme files. 113 | 114 | - layers.py: definitions of different types of layers, including LogisticRegression, HiddenLayer, LeNetConvLayer, PoolLayer and LeNetConvPoolLayer. 115 | - load_data.py: loads data for the executable programmes. 116 | - sample_optimization.py: some test functions to validate the correctness of the layers defined in `layers.py`. 117 | - deepid_class.py: DeepID main programme. 118 | - deepid_generate.py: extracts the hidden layer using the trained parameters.
119 | 120 | ### Usage of Code 121 | 122 | #### deepid_class.py 123 | 124 | Usage: python deepid_class.py vec_valid vec_train params_file 125 | 126 | - vec_valid: generated by `vectorize_img.py` 127 | - vec_train: generated by `vectorize_img.py` 128 | - params_file: stores the trained parameters after every epoch (only the most recent snapshots are kept). If the programme is restarted with an existing params_file, training resumes from the latest snapshot, which helps if your computer shuts down unexpectedly. It is also used to extract the hidden layer of the net. 129 | 130 | **Note:** 131 | DeepID has many parameters to tune; to keep the command line simple, I did not expose them as arguments. If you want to change the epoch num, learning rate, batch size and so on, please change them in the last line of the file. 132 | 133 | 134 | #### deepid_generate.py 135 | 136 | You can extract the 160-dimensional hidden layer with the command below: 137 | 138 | Usage: python deepid_generate.py dataset_folder params_file result_folder 139 | 140 | - dataset_folder: the folder of the train set or the validation set. 141 | - params_file: trained by `deepid_class.py` 142 | - result_folder: contains files with the same names as in dataset_folder, but the dimension of x in each file will be num_samples×160 instead of num_samples×7755. 143 | 144 | 145 | # Performance 146 | 147 | ## DeepID performance 148 | 149 | After running `deepid_class.py`, you will get output like the following. The first part is the train error and valid error of each epoch; the second part is a summary as `epoch, train error, valid error`.
150 | 151 | ``` 152 | epoch 15, train_score 0.000444, valid_score 0.066000 153 | epoch 16, minibatch_index 62/63, error 0.000000 154 | epoch 16, train_score 0.000413, valid_score 0.065733 155 | epoch 17, minibatch_index 62/63, error 0.000000 156 | epoch 17, train_score 0.000508, valid_score 0.065333 157 | epoch 18, minibatch_index 62/63, error 0.000000 158 | epoch 18, train_score 0.000413, valid_score 0.070267 159 | epoch 19, minibatch_index 62/63, error 0.000000 160 | epoch 19, train_score 0.000413, valid_score 0.064533 161 | 162 | 0 0.974349206349 0.962933333333 163 | 1 0.890095238095 0.897466666667 164 | 2 0.70126984127 0.666666666667 165 | 3 0.392031746032 0.520133333333 166 | 4 0.187619047619 0.360666666667 167 | 5 0.20526984127 0.22 168 | 6 0.054380952381 0.171066666667 169 | 7 0.0154920634921 0.128 170 | 8 0.00650793650794 0.100133333333 171 | 9 0.00377777777778 0.0909333333333 172 | 10 0.00292063492063 0.086 173 | 11 0.0015873015873 0.0792 174 | 12 0.00133333333333 0.0754666666667 175 | 13 0.00111111111111 0.0714666666667 176 | 14 0.000761904761905 0.068 177 | 15 0.000444444444444 0.066 178 | 16 0.000412698412698 0.0657333333333 179 | 17 0.000507936507937 0.0653333333333 180 | 18 0.000412698412698 0.0702666666667 181 | 19 0.000412698412698 0.0645333333333 182 | ``` 183 | 184 | You can also put the second part of the output into a figure with matplotlib. 185 | 186 |  187 | 188 | ## Generated Feature performance 189 | 190 | After running `deepid_generate.py`, you will get output like below: 191 | 192 | ``` 193 | loading data of vec_test/0.pkl 194 | building the model ... 195 | generating ... 196 | writing data to deepid_test/0.pkl 197 | loading data of vec_test/3.pkl 198 | building the model ... 199 | generating ... 200 | writing data to deepid_test/3.pkl 201 | loading data of vec_test/1.pkl 202 | building the model ... 203 | generating ... 204 | writing data to deepid_test/1.pkl 205 | loading data of vec_test/7.pkl 206 | building the model ... 
207 | generating ... 208 | writing data to deepid_test/7.pkl 209 | ``` 210 | 211 | The programme processes each sub-file of the vectorized data in turn. 212 | 213 | After extracting the hidden layer, we can run further experiments to demonstrate the effectiveness of the DeepID feature. For example, for face retrieval you can use another GitHub project of mine to test on the data generated by this project; here is the [link](https://github.com/stdcoutzyx/FaceRetrieval). 214 | 215 | For comparison, I ran two face-retrieval experiments on the YouTube Faces data. 216 | 217 | - PCA exp: reduce the features generated by `vectorize_img.py` to 160-d with PCA, and run face retrieval on them. 218 | - DeepID exp: run face retrieval directly on the data generated by `deepid_generate.py`. 219 | 220 | **Note:** In both experiments, I use the cosine distance to measure the similarity of two vectors. 221 | 222 | The results of face retrieval are below: 223 | 224 | |Precision| Top-1| Top-5| Top-10| 225 | |---------|------|------|-------| 226 | |PCA |95.20%|96.75%|97.22% | 227 | |DeepID |97.27%|97.93%|98.25% | 228 | 229 | |AP | Top-1| Top-5| Top-10| 230 | |---------|------|------|-------| 231 | |PCA |95.20%|84.19%|70.66% | 232 | |DeepID |97.27%|89.22%|76.64% | 233 | 234 | Precision: a query counts as correct if at least one of the top-N results shows the same person as the query image. 235 | AP instead measures how many of the top-N results show the same person as the query image. 236 | 237 | From the results, we can see that the DeepID feature is superior to PCA at the same dimension. 238 | 239 | # Reference 240 | 241 | [1]. Sun Y, Wang X, Tang X. Deep learning face representation from predicting 10,000 classes[C]//Computer Vision and Pattern Recognition (CVPR), 2014 IEEE Conference on. IEEE, 2014: 1891-1898.
242 | 243 | -------------------------------------------------------------------------------- /src/conv_net/sample_optimization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | import cPickle 5 | import gzip 6 | import os 7 | import sys 8 | import time 9 | import numpy 10 | import theano 11 | import theano.tensor as T 12 | 13 | from layers import * 14 | from load_data import * 15 | 16 | def sgd_optimization_mnist(learning_rate=0.13, 17 | n_epochs=1000, dataset='mnist.pkl.gz', batch_size=300): 18 | datasets = load_data(dataset) 19 | train_set_x, train_set_y = datasets[0] 20 | valid_set_x, valid_set_y = datasets[1] 21 | 22 | n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size 23 | n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size 24 | 25 | index = T.lscalar() 26 | x = T.matrix('x') 27 | y = T.ivector('y') 28 | step_rate = T.dscalar() 29 | 30 | classifier = LogisticRegression(input=x, n_in=28*28, n_out=10) 31 | cost = classifier.negative_log_likelihood(y) 32 | 33 | test_valid_model = theano.function(inputs=[index], 34 | outputs=classifier.errors(y), 35 | givens = { 36 | x: valid_set_x[index * batch_size : (index+1) * batch_size], 37 | y: valid_set_y[index * batch_size : (index+1) * batch_size]} 38 | ) 39 | 40 | test_train_model = theano.function(inputs=[index], 41 | outputs=classifier.errors(y), 42 | givens = { 43 | x: train_set_x[index * batch_size : (index+1) * batch_size], 44 | y: train_set_y[index * batch_size : (index+1) * batch_size]} 45 | ) 46 | 47 | g_W = T.grad(cost=cost, wrt=classifier.W) 48 | g_b = T.grad(cost=cost, wrt=classifier.b) 49 | updates = [(classifier.W, classifier.W - step_rate * g_W), 50 | (classifier.b, classifier.b - step_rate * g_b)] 51 | train_model = theano.function(inputs=[index, step_rate], 52 | outputs=cost, 53 | updates=updates, 54 | givens = { 55 | x: train_set_x[index * batch_size : (index+1) * 
batch_size], 56 | y: train_set_y[index * batch_size : (index+1) * batch_size]} 57 | ) 58 | print 'Train the model ...' 59 | train_sample_num = train_set_x.get_value(borrow=True).shape[0] 60 | valid_sample_num = valid_set_x.get_value(borrow=True).shape[0] 61 | 62 | epoch = 0 63 | while epoch < n_epochs: 64 | epoch += 1 65 | if epoch > 50: 66 | learning_rate = 0.1 67 | for minibatch_index in xrange(n_train_batches): 68 | minibatch_cost = train_model(minibatch_index, learning_rate) 69 | train_losses = [test_train_model(i) for i in xrange(n_train_batches)] 70 | valid_losses = [test_valid_model(i) for i in xrange(n_valid_batches)] 71 | 72 | ''' 73 | train_score = numpy.sum(train_losses) 74 | valid_score = numpy.sum(valid_losses) 75 | print 'epoch %i, train_score %f, valid_score %f' % (epoch, float(train_score) / train_sample_num, float(valid_score) / valid_sample_num) 76 | ''' 77 | train_score = numpy.mean(train_losses) 78 | valid_score = numpy.mean(valid_losses) 79 | print 'epoch %i, train_score %f, valid_score %f' % (epoch, train_score, valid_score) 80 | 81 | def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, 82 | n_epochs=1000, dataset='mnist.pkl.gz', batch_size=20, n_hidden=500): 83 | datasets = load_data(dataset) 84 | train_set_x, train_set_y = datasets[0] 85 | valid_set_x, valid_set_y = datasets[1] 86 | 87 | n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size 88 | n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size 89 | 90 | index = T.lscalar() 91 | x = T.matrix('x') 92 | y = T.ivector('y') 93 | 94 | rng = numpy.random.RandomState(1234) 95 | 96 | classifier = MLP(rng=rng, input=x, n_in=28*28, 97 | n_hidden=n_hidden, n_out=10) 98 | cost = classifier.negative_log_likelihood(y) + L1_reg * classifier.L1 + L2_reg * classifier.L2_sqr 99 | 100 | test_valid_model = theano.function(inputs=[index], 101 | outputs=classifier.errors(y), 102 | givens = { 103 | x: valid_set_x[index * batch_size : (index+1) * batch_size], 
104 | y: valid_set_y[index * batch_size : (index+1) * batch_size]} 105 | ) 106 | 107 | test_train_model = theano.function(inputs=[index], 108 | outputs=classifier.errors(y), 109 | givens = { 110 | x: train_set_x[index * batch_size : (index+1) * batch_size], 111 | y: train_set_y[index * batch_size : (index+1) * batch_size]} 112 | ) 113 | 114 | gparams = [] 115 | for param in classifier.params: 116 | gparam = T.grad(cost, param) 117 | gparams.append(gparam) 118 | updates = [] 119 | for param, gparam in zip(classifier.params, gparams): 120 | updates.append((param, param - learning_rate * gparam)) 121 | 122 | train_model = theano.function(inputs=[index], 123 | outputs=cost, 124 | updates=updates, 125 | givens = { 126 | x: train_set_x[index * batch_size : (index+1) * batch_size], 127 | y: train_set_y[index * batch_size : (index+1) * batch_size]} 128 | ) 129 | print 'Train the model ...' 130 | train_sample_num = train_set_x.get_value(borrow=True).shape[0] 131 | valid_sample_num = valid_set_x.get_value(borrow=True).shape[0] 132 | 133 | epoch = 0 134 | while epoch < n_epochs: 135 | epoch += 1 136 | for minibatch_index in xrange(n_train_batches): 137 | minibatch_cost = train_model(minibatch_index) 138 | train_losses = [test_train_model(i) for i in xrange(n_train_batches)] 139 | valid_losses = [test_valid_model(i) for i in xrange(n_valid_batches)] 140 | 141 | ''' 142 | train_score = numpy.sum(train_losses) 143 | valid_score = numpy.sum(valid_losses) 144 | print 'epoch %i, train_score %f, valid_score %f' % (epoch, float(train_score) / train_sample_num, float(valid_score) / valid_sample_num) 145 | ''' 146 | train_score = numpy.mean(train_losses) 147 | valid_score = numpy.mean(valid_losses) 148 | print 'epoch %i, train_score %f, valid_score %f' % (epoch, train_score, valid_score) 149 | 150 | 151 | def evaluate_lenet3(learning_rate=0.1, n_epochs=200, dataset='mnist.pkl.gz', nkerns=[20,50], batch_size=500): 152 | ''' 153 | layer0: convpool layer 154 | layer1: convpool layer 155 | 
layer1: hidden layer 156 | layer2: logistic layer 157 | ''' 158 | 159 | datasets = load_data(dataset) 160 | train_set_x, train_set_y = datasets[0] 161 | valid_set_x, valid_set_y = datasets[1] 162 | 163 | n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size 164 | n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size 165 | 166 | index = T.lscalar() 167 | x = T.matrix('x') 168 | y = T.ivector('y') 169 | 170 | image_shape = (batch_size, 1, 28, 28) 171 | rng = numpy.random.RandomState(1234) 172 | 173 | print 'building the model ...' 174 | 175 | layer0_input = x.reshape(image_shape) 176 | layer0 = LeNetConvPoolLayer(rng, 177 | input = layer0_input, 178 | image_shape = image_shape, 179 | filter_shape = (nkerns[0], 1, 5, 5), 180 | poolsize = (2, 2), 181 | activation = relu) 182 | layer1 = LeNetConvPoolLayer(rng, 183 | input = layer0.output, 184 | image_shape = (batch_size, nkerns[0], 12, 12), 185 | filter_shape = (nkerns[1], nkerns[0], 5, 5), 186 | poolsize = (2,2), 187 | activation = relu) 188 | 189 | layer2_input = layer1.output.flatten(2) 190 | layer2 = HiddenLayer(rng, 191 | input = layer2_input, 192 | n_in = nkerns[1] * 4 * 4, 193 | n_out = 500, 194 | activation = relu) 195 | layer3 = LogisticRegression( 196 | input = layer2.output, 197 | n_in = 500, 198 | n_out = 10) 199 | 200 | cost = layer3.negative_log_likelihood(y) 201 | 202 | test_valid_model = theano.function(inputs=[index], 203 | outputs=layer3.errors(y), 204 | givens = { 205 | x: valid_set_x[index * batch_size : (index+1) * batch_size], 206 | y: valid_set_y[index * batch_size : (index+1) * batch_size]} 207 | ) 208 | 209 | test_train_model = theano.function(inputs=[index], 210 | outputs=layer3.errors(y), 211 | givens = { 212 | x: train_set_x[index * batch_size : (index+1) * batch_size], 213 | y: train_set_y[index * batch_size : (index+1) * batch_size]} 214 | ) 215 | 216 | params = layer3.params + layer2.params + layer1.params + layer0.params 217 | gparams = [] 218 | 
for param in params: 219 | gparam = T.grad(cost, param) 220 | gparams.append(gparam) 221 | updates = [] 222 | for param, gparam in zip(params, gparams): 223 | updates.append((param, param - learning_rate * gparam)) 224 | 225 | train_model = theano.function(inputs=[index], 226 | outputs=cost, 227 | updates=updates, 228 | givens = { 229 | x: train_set_x[index * batch_size : (index+1) * batch_size], 230 | y: train_set_y[index * batch_size : (index+1) * batch_size]} 231 | ) 232 | print 'Train the model ...' 233 | train_sample_num = train_set_x.get_value(borrow=True).shape[0] 234 | valid_sample_num = valid_set_x.get_value(borrow=True).shape[0] 235 | 236 | epoch = 0 237 | while epoch < n_epochs: 238 | epoch += 1 239 | for minibatch_index in xrange(n_train_batches): 240 | minibatch_cost = train_model(minibatch_index) 241 | print '\tepoch %i, minibatch_index %i/%i, minibatch_cost %f' % (epoch, minibatch_index, n_train_batches, minibatch_cost) 242 | train_losses = [test_train_model(i) for i in xrange(n_train_batches)] 243 | valid_losses = [test_valid_model(i) for i in xrange(n_valid_batches)] 244 | 245 | ''' 246 | train_score = numpy.sum(train_losses) 247 | valid_score = numpy.sum(valid_losses) 248 | print 'epoch %i, train_score %f, valid_score %f' % (epoch, float(train_score) / train_sample_num, float(valid_score) / valid_sample_num) 249 | ''' 250 | train_score = numpy.mean(train_losses) 251 | valid_score = numpy.mean(valid_losses) 252 | print 'epoch %i, train_score %f, valid_score %f' % (epoch, train_score, valid_score) 253 | 254 | 255 | if __name__ == '__main__': 256 | if len(sys.argv) != 2: 257 | print 'Usage: python %s (dataset_file)' % (sys.argv[0]) 258 | sys.exit() 259 | sgd_optimization_mnist(learning_rate=0.2, n_epochs=1000, dataset=sys.argv[1], batch_size=600) 260 | # test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000, dataset=sys.argv[1], batch_size=20, n_hidden=500) 261 | # evaluate_lenet3(learning_rate=0.1, n_epochs=200, dataset=sys.argv[1], 
nkerns=[20, 50], batch_size=500) 262 | -------------------------------------------------------------------------------- /src/conv_net/deepid_class.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf8 -*- 3 | 4 | from layers import * 5 | from load_data import * 6 | 7 | import cPickle 8 | import gzip 9 | import os 10 | import sys 11 | import time 12 | import numpy 13 | import theano 14 | 15 | class ParamDumpHelper: 16 | def __init__(self, dump_file): 17 | self.dump_file = dump_file 18 | 19 | def dump(self, params): 20 | dumped_params = self.get_params_from_file() 21 | dumped_params.append(params) 22 | self.params_to_file(dumped_params) 23 | 24 | def params_to_file(self, params): 25 | f = gzip.open(self.dump_file, 'wb') 26 | if len(params) > 20: 27 | params = params[10:] 28 | pickle.dump(params, f) 29 | f.close() 30 | 31 | def get_params_from_file(self): 32 | if os.path.exists(self.dump_file): 33 | f = gzip.open(self.dump_file, 'rb') 34 | dumped_params = pickle.load(f) 35 | f.close() 36 | return dumped_params 37 | return [] 38 | 39 | class DeepID: 40 | def __init__(self, pd_helper): 41 | self.rng = numpy.random.RandomState(1234) 42 | self.pd_helper = pd_helper 43 | exist_params = pd_helper.get_params_from_file() 44 | if len(exist_params) != 0: 45 | self.exist_params = exist_params[-1] 46 | else: 47 | self.exist_params = [[None, None], 48 | [None, None], 49 | [None, None], 50 | [None, None], 51 | [None, None], 52 | [None, None], 53 | 1.0, 54 | 1.0, 55 | 0] 56 | 57 | def load_data_deepid(self, dataset_file, batch_size): 58 | print 'loading data ...' 
59 | datasets = load_data_split_pickle(dataset_file) 60 | self.train_set_x, self.train_set_y = datasets[0] 61 | self.valid_set_x, self.valid_set_y = datasets[1] 62 | 63 | self.n_train_batches = self.train_set_x.get_value(borrow=True).shape[0] / batch_size 64 | self.n_valid_batches = self.valid_set_x.get_value(borrow=True).shape[0] / batch_size 65 | self.batch_size = batch_size 66 | 67 | print 'train_x: ', self.train_set_x.get_value(borrow=True).shape 68 | print 'train_y: ', self.train_set_y.shape 69 | print 'valid_x: ', self.valid_set_x.get_value(borrow=True).shape 70 | print 'valid_y: ', self.valid_set_y.shape 71 | 72 | 73 | def layer_params(self, nkerns=[20,40,60,80]): 74 | src_channel = 3 75 | self.layer1_image_shape = (self.batch_size, src_channel, 55, 47) 76 | self.layer1_filter_shape = (nkerns[0], src_channel, 4, 4) 77 | self.layer2_image_shape = (self.batch_size, nkerns[0], 26, 22) 78 | self.layer2_filter_shape = (nkerns[1], nkerns[0], 3, 3) 79 | self.layer3_image_shape = (self.batch_size, nkerns[1], 12, 10) 80 | self.layer3_filter_shape = (nkerns[2], nkerns[1], 3, 3) 81 | self.layer4_image_shape = (self.batch_size, nkerns[2], 5, 4) 82 | self.layer4_filter_shape = (nkerns[3], nkerns[2], 2, 2) 83 | self.result_image_shape = (self.batch_size, nkerns[3], 4, 3) 84 | 85 | def build_layer_architecture(self, n_hidden, n_out, acti_func=relu): 86 | ''' 87 | simple means the deepid layer input is only the layer4 output. 88 | layer1: convpool layer 89 | layer2: convpool layer 90 | layer3: convpool layer 91 | layer4: conv layer 92 | deepid: hidden layer 93 | softmax: logistic layer 94 | ''' 95 | self.index = T.lscalar() 96 | self.step_rate = T.dscalar() 97 | self.x = T.matrix('x') 98 | self.y = T.ivector('y') 99 | 100 | print 'building the model ...' 
101 | 102 | layer1_input = self.x.reshape(self.layer1_image_shape) 103 | self.layer1 = LeNetConvPoolLayer(self.rng, 104 | input = layer1_input, 105 | image_shape = self.layer1_image_shape, 106 | filter_shape = self.layer1_filter_shape, 107 | poolsize = (2,2), 108 | W = self.exist_params[5][0], 109 | b = self.exist_params[5][1], 110 | activation = acti_func) 111 | 112 | self.layer2 = LeNetConvPoolLayer(self.rng, 113 | input = self.layer1.output, 114 | image_shape = self.layer2_image_shape, 115 | filter_shape = self.layer2_filter_shape, 116 | poolsize = (2,2), 117 | W = self.exist_params[4][0], 118 | b = self.exist_params[4][1], 119 | activation = acti_func) 120 | 121 | self.layer3 = LeNetConvPoolLayer(self.rng, 122 | input = self.layer2.output, 123 | image_shape = self.layer3_image_shape, 124 | filter_shape = self.layer3_filter_shape, 125 | poolsize = (2,2), 126 | W = self.exist_params[3][0], 127 | b = self.exist_params[3][1], 128 | activation = acti_func) 129 | 130 | self.layer4 = LeNetConvLayer(self.rng, 131 | input = self.layer3.output, 132 | image_shape = self.layer4_image_shape, 133 | filter_shape = self.layer4_filter_shape, 134 | W = self.exist_params[2][0], 135 | b = self.exist_params[2][1], 136 | activation = acti_func) 137 | 138 | # deepid_input = layer4.output.flatten(2) 139 | 140 | layer3_output_flatten = self.layer3.output.flatten(2) 141 | layer4_output_flatten = self.layer4.output.flatten(2) 142 | deepid_input = T.concatenate([layer3_output_flatten, layer4_output_flatten], axis=1) 143 | 144 | self.deepid_layer = HiddenLayer(self.rng, 145 | input = deepid_input, 146 | n_in = numpy.prod( self.result_image_shape[1:] ) + numpy.prod( self.layer4_image_shape[1:] ), 147 | # n_in = numpy.prod( self.result_image_shape[1:] ), 148 | n_out = n_hidden, 149 | W = self.exist_params[1][0], 150 | b = self.exist_params[1][1], 151 | activation = acti_func) 152 | self.softmax_layer = LogisticRegression( 153 | input = self.deepid_layer.output, 154 | n_in = n_hidden, 155 | 
n_out = n_out, 156 | W = self.exist_params[0][0], 157 | b = self.exist_params[0][1]) 158 | 159 | self.cost = self.softmax_layer.negative_log_likelihood(self.y) 160 | 161 | def build_test_valid_model(self): 162 | self.test_valid_model = theano.function(inputs=[self.index], 163 | outputs=self.softmax_layer.errors(self.y), 164 | givens = { 165 | self.x: self.valid_set_x[self.index * self.batch_size : (self.index+1) * self.batch_size], 166 | self.y: self.valid_set_y[self.index * self.batch_size : (self.index+1) * self.batch_size]} 167 | ) 168 | 169 | def build_test_train_model(self): 170 | self.test_train_model = theano.function(inputs=[self.index], 171 | outputs=self.softmax_layer.errors(self.y), 172 | givens = { 173 | self.x: self.train_set_x[self.index * self.batch_size : (self.index+1) * self.batch_size], 174 | self.y: self.train_set_y[self.index * self.batch_size : (self.index+1) * self.batch_size]} 175 | ) 176 | 177 | def build_train_model(self): 178 | self.params = self.softmax_layer.params + self.deepid_layer.params + self.layer4.params \ 179 | + self.layer3.params + self.layer2.params + self.layer1.params 180 | gparams = [] 181 | for param in self.params: 182 | gparam = T.grad(self.cost, param) 183 | gparams.append(gparam) 184 | updates = [] 185 | for param, gparam in zip(self.params, gparams): 186 | updates.append((param, param - self.step_rate * gparam)) 187 | 188 | self.train_model = theano.function(inputs=[self.index, self.step_rate], 189 | outputs=self.cost, 190 | updates=updates, 191 | givens = { 192 | self.x: self.train_set_x[self.index * self.batch_size : (self.index+1) * self.batch_size], 193 | self.y: self.train_set_y[self.index * self.batch_size : (self.index+1) * self.batch_size]} 194 | ) 195 | 196 | def train(self, n_epochs, learning_rate): 197 | print 'Training the model ...' 
198 | train_sample_num = self.train_set_x.get_value(borrow=True).shape[0] 199 | valid_sample_num = self.valid_set_x.get_value(borrow=True).shape[0] 200 | 201 | loss_records = [] 202 | 203 | epoch = self.exist_params[-1] 204 | while epoch < n_epochs: 205 | train_losses = [] 206 | for minibatch_index in xrange( self.n_train_batches ): 207 | minibatch_cost = self.train_model(minibatch_index, learning_rate) 208 | train_loss = self.test_train_model(minibatch_index) 209 | train_losses.append(train_loss) 210 | 211 | line = '\r\tepoch %i, minibatch_index %i/%i, error %f' % (epoch, minibatch_index, self.n_train_batches, train_loss) 212 | sys.stdout.write(line) 213 | sys.stdout.flush() 214 | 215 | valid_losses = [self.test_valid_model(i) for i in xrange( self.n_valid_batches) ] 216 | 217 | train_score = numpy.mean(train_losses) 218 | valid_score = numpy.mean(valid_losses) 219 | loss_records.append((epoch, train_score, valid_score)) 220 | print '\nepoch %i, train_score %f, valid_score %f' % (epoch, train_score, valid_score) 221 | 222 | params = [self.softmax_layer.params, 223 | self.deepid_layer.params, 224 | self.layer4.params, 225 | self.layer3.params, 226 | self.layer2.params, 227 | self.layer1.params, 228 | valid_score, 229 | train_score, 230 | epoch] 231 | self.pd_helper.dump(params) 232 | epoch += 1 233 | return loss_records 234 | 235 | 236 | def simple_deepid(learning_rate, n_epochs, dataset, params_file, 237 | nkerns, batch_size, n_hidden, n_out, acti_func): 238 | pd_helper = ParamDumpHelper(params_file) 239 | deepid = DeepID(pd_helper) 240 | deepid.load_data_deepid(dataset, batch_size) 241 | deepid.layer_params(nkerns) 242 | deepid.build_layer_architecture(n_hidden, n_out, acti_func) 243 | deepid.build_test_train_model() 244 | deepid.build_test_valid_model() 245 | deepid.build_train_model() 246 | loss_records = deepid.train(n_epochs, learning_rate) 247 | 248 | print '' 249 | for record in loss_records: 250 | print record[0], record[1], record[2] 251 | 252 | 253 | if 
__name__ == '__main__': 254 | if len(sys.argv) != 4: 255 | print 'Usage: python %s vec_valid vec_train params_file' % (sys.argv[0]) 256 | sys.exit() 257 | simple_deepid(learning_rate=0.01, n_epochs=20, dataset=(sys.argv[1], sys.argv[2]), params_file=sys.argv[3], nkerns=[20,40,60,80], batch_size=500, n_hidden=160, n_out=1595, acti_func=relu) 258 | --------------------------------------------------------------------------------