├── LICENSE ├── LR_SGD.py ├── README.md └── tripletloss.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Noel Codella 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LR_SGD.py: -------------------------------------------------------------------------------- 1 | from keras.legacy import interfaces 2 | import keras.backend as K 3 | from keras.optimizers import Optimizer 4 | 5 | # Taken, with permission, from https://ksaluja15.github.io/Learning-Rate-Multipliers-in-Keras/ 6 | 7 | class LR_SGD(Optimizer): 8 | """Stochastic gradient descent optimizer. 9 | 10 | Includes support for momentum, 11 | learning rate decay, and Nesterov momentum. 12 | 13 | # Arguments 14 | lr: float >= 0. Learning rate. 15 | momentum: float >= 0. Parameter updates momentum. 16 | decay: float >= 0. Learning rate decay over each update. 17 | nesterov: boolean. Whether to apply Nesterov momentum. 18 | """ 19 | 20 | def __init__(self, lr=0.01, momentum=0., decay=0., 21 | nesterov=False,multipliers=None,**kwargs): 22 | super(LR_SGD, self).__init__(**kwargs) 23 | with K.name_scope(self.__class__.__name__): 24 | self.iterations = K.variable(0, dtype='int64', name='iterations') 25 | self.lr = K.variable(lr, name='lr') 26 | self.momentum = K.variable(momentum, name='momentum') 27 | self.decay = K.variable(decay, name='decay') 28 | self.initial_decay = decay 29 | self.nesterov = nesterov 30 | self.lr_multipliers = multipliers 31 | 32 | @interfaces.legacy_get_updates_support 33 | def get_updates(self, loss, params): 34 | grads = self.get_gradients(loss, params) 35 | self.updates = [K.update_add(self.iterations, 1)] 36 | 37 | lr = self.lr 38 | if self.initial_decay > 0: 39 | lr *= (1. / (1. + self.decay * K.cast(self.iterations, 40 | K.dtype(self.decay)))) 41 | # momentum 42 | shapes = [K.int_shape(p) for p in params] 43 | moments = [K.zeros(shape) for shape in shapes] 44 | self.weights = [self.iterations] + moments 45 | for p, g, m in zip(params, grads, moments): 46 | 47 | matched_layer = [x for x in self.lr_multipliers.keys() if x in p.name] 48 | if matched_layer: 49 | new_lr = lr * self.lr_multipliers[matched_layer[0]] 50 | else: 51 | new_lr = lr 52 | 53 | v = self.momentum * m - new_lr * g # velocity 54 | self.updates.append(K.update(m, v)) 55 | 56 | if self.nesterov: 57 | new_p = p + self.momentum * v - new_lr * g 58 | else: 59 | new_p = p + v 60 | 61 | # Apply constraints. 62 | if getattr(p, 'constraint', None) is not None: 63 | new_p = p.constraint(new_p) 64 | 65 | self.updates.append(K.update(p, new_p)) 66 | return self.updates 67 | 68 | def get_config(self): 69 | config = {'lr': float(K.get_value(self.lr)), 70 | 'momentum': float(K.get_value(self.momentum)), 71 | 'decay': float(K.get_value(self.decay)), 72 | 'nesterov': self.nesterov} 73 | base_config = super(LR_SGD, self).get_config() 74 | return dict(list(base_config.items()) + list(config.items())) 75 | 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tripletloss-keras-tensorflow 2 | There exist many code examples across the web for implementing triplet-loss objectives in Tensorflow using Keras. However, none of these examples place triplet-loss in the context of other commonly needed functionality, such as loading arbitrarily large non-standard datasets from text files, using data augmentation across the multiple inputs, using an optimizer with a variable learning rate across pre-trained and randomly initialized layers of a network, or loading models with custom optimizers and objective losses. 3 | 4 | Here, with the help of some of the other web sources, I have assembled an example of a triplet-loss objective with all of the common pieces one might need. Hope some find it helpful/useful. 5 | -------------------------------------------------------------------------------- /tripletloss.py: -------------------------------------------------------------------------------- 1 | # Noel C. F. Codella 2 | # Example Triplet Loss Code for Keras / TensorFlow 3 | 4 | # Implementing Improved Triplet Loss from: 5 | # Zhang et al. "Tracking Persons-of-Interest via Adaptive Discriminative Features" ECCV 2016 6 | 7 | # Got help from multiple web sources, including: 8 | # 1) https://stackoverflow.com/questions/47727679/triplet-model-for-image-retrieval-from-the-keras-pretrained-network 9 | # 2) https://ksaluja15.github.io/Learning-Rate-Multipliers-in-Keras/ 10 | # 3) https://keras.io/preprocessing/image/ 11 | # 4) https://github.com/keras-team/keras/issues/3386 12 | # 5) https://github.com/keras-team/keras/issues/8130 13 | 14 | 15 | # GLOBAL DEFINES 16 | T_G_WIDTH = 224 17 | T_G_HEIGHT = 224 18 | T_G_NUMCHANNELS = 3 19 | T_G_SEED = 1337 20 | 21 | # Misc. Necessities 22 | import sys 23 | import ssl # these two lines solved issues loading pretrained model 24 | ssl._create_default_https_context = ssl._create_unverified_context 25 | import numpy as np 26 | import matplotlib.pyplot as plt 27 | import cv2 28 | from scipy.misc import imresize 29 | np.random.seed(T_G_SEED) 30 | 31 | # TensorFlow Includes 32 | import tensorflow as tf 33 | #from tensorflow.contrib.losses import metric_learning 34 | tf.set_random_seed(T_G_SEED) 35 | 36 | # Keras Imports & Defines 37 | import keras 38 | import keras.applications 39 | from keras import backend as K 40 | from keras.models import Model 41 | from keras import optimizers 42 | import keras.layers as kl 43 | 44 | from keras.preprocessing.image import ImageDataGenerator 45 | 46 | # Generator object for data augmentation. 47 | # Can change values here to affect augmentation style. 48 | datagen = ImageDataGenerator( rotation_range=90, 49 | width_shift_range=0.05, 50 | height_shift_range=0.05, 51 | zoom_range=0.1, 52 | horizontal_flip=True, 53 | vertical_flip=True, 54 | ) 55 | 56 | # Local Imports 57 | from LR_SGD import LR_SGD 58 | 59 | # generator function for data augmentation 60 | def createDataGen(X1, X2, X3, Y, b): 61 | 62 | local_seed = T_G_SEED 63 | genX1 = datagen.flow(X1,Y, batch_size=b, seed=local_seed, shuffle=False) 64 | genX2 = datagen.flow(X2,Y, batch_size=b, seed=local_seed, shuffle=False) 65 | genX3 = datagen.flow(X3,Y, batch_size=b, seed=local_seed, shuffle=False) 66 | while True: 67 | X1i = genX1.next() 68 | X2i = genX2.next() 69 | X3i = genX3.next() 70 | 71 | yield [X1i[0], X2i[0], X3i[0]], X1i[1] 72 | 73 | 74 | def createModel(emb_size): 75 | 76 | # Initialize a ResNet50_ImageNet Model 77 | resnet_input = kl.Input(shape=(T_G_WIDTH,T_G_HEIGHT,T_G_NUMCHANNELS)) 78 | resnet_model = keras.applications.resnet50.ResNet50(weights='imagenet', include_top = False, input_tensor=resnet_input) 79 | 80 | # New Layers over ResNet50 81 | net = resnet_model.output 82 | #net = kl.Flatten(name='flatten')(net) 83 | net = kl.GlobalAveragePooling2D(name='gap')(net) 84 | #net = kl.Dropout(0.5)(net) 85 | net = kl.Dense(emb_size,activation='relu',name='t_emb_1')(net) 86 | net = kl.Lambda(lambda x: K.l2_normalize(x,axis=1), name='t_emb_1_l2norm')(net) 87 | 88 | # model creation 89 | base_model = Model(resnet_model.input, net, name="base_model") 90 | 91 | # triplet framework, shared weights 92 | input_shape=(T_G_WIDTH,T_G_HEIGHT,T_G_NUMCHANNELS) 93 | input_anchor = kl.Input(shape=input_shape, name='input_anchor') 94 | input_positive = kl.Input(shape=input_shape, name='input_pos') 95 | input_negative = kl.Input(shape=input_shape, name='input_neg') 96 | 97 | net_anchor = base_model(input_anchor) 98 | net_positive = base_model(input_positive) 99 | net_negative = base_model(input_negative) 100 | 101 | # The Lamda layer produces output using given function. Here its Euclidean distance. 102 | positive_dist = kl.Lambda(euclidean_distance, name='pos_dist')([net_anchor, net_positive]) 103 | negative_dist = kl.Lambda(euclidean_distance, name='neg_dist')([net_anchor, net_negative]) 104 | tertiary_dist = kl.Lambda(euclidean_distance, name='ter_dist')([net_positive, net_negative]) 105 | 106 | # This lambda layer simply stacks outputs so both distances are available to the objective 107 | stacked_dists = kl.Lambda(lambda vects: K.stack(vects, axis=1), name='stacked_dists')([positive_dist, negative_dist, tertiary_dist]) 108 | 109 | model = Model([input_anchor, input_positive, input_negative], stacked_dists, name='triple_siamese') 110 | 111 | # Setting up optimizer designed for variable learning rate 112 | 113 | # Variable Learning Rate per Layers 114 | lr_mult_dict = {} 115 | last_layer = '' 116 | for layer in resnet_model.layers: 117 | # comment this out to refine earlier layers 118 | # layer.trainable = False 119 | # print layer.name 120 | lr_mult_dict[layer.name] = 1 121 | # last_layer = layer.name 122 | lr_mult_dict['t_emb_1'] = 100 123 | 124 | base_lr = 0.0001 125 | momentum = 0.9 126 | v_optimizer = LR_SGD(lr=base_lr, momentum=momentum, decay=0.0, nesterov=False, multipliers = lr_mult_dict) 127 | 128 | model.compile(optimizer=v_optimizer, loss=triplet_loss, metrics=[accuracy]) 129 | 130 | return model 131 | 132 | 133 | def triplet_loss(y_true, y_pred): 134 | margin = K.constant(1) 135 | return K.mean(K.maximum(K.constant(0), K.square(y_pred[:,0,0]) - 0.5*(K.square(y_pred[:,1,0])+K.square(y_pred[:,2,0])) + margin)) 136 | 137 | def accuracy(y_true, y_pred): 138 | return K.mean(y_pred[:,0,0] < y_pred[:,1,0]) 139 | 140 | def l2Norm(x): 141 | return K.l2_normalize(x, axis=-1) 142 | 143 | def euclidean_distance(vects): 144 | x, y = vects 145 | return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon())) 146 | 147 | 148 | # loads an image and preprocesses 149 | def t_read_image(loc): 150 | t_image = cv2.imread(loc) 151 | t_image = cv2.resize(t_image, (T_G_HEIGHT,T_G_WIDTH)) 152 | t_image = t_image.astype("float32") 153 | t_image = keras.applications.resnet50.preprocess_input(t_image, data_format='channels_last') 154 | 155 | return t_image 156 | 157 | # loads a set of images from a text index file 158 | def t_read_image_list(flist, start, length): 159 | 160 | with open(flist) as f: 161 | content = f.readlines() 162 | content = [x.strip().split()[0] for x in content] 163 | 164 | datalen = length 165 | if (datalen < 0): 166 | datalen = len(content) 167 | 168 | if (start + datalen > len(content)): 169 | datalen = len(content) - start 170 | 171 | imgset = np.zeros((datalen, T_G_HEIGHT, T_G_WIDTH, T_G_NUMCHANNELS)) 172 | 173 | for i in range(start, start+datalen): 174 | if ((i-start) < len(content)): 175 | imgset[i-start] = t_read_image(content[i]) 176 | 177 | return imgset 178 | 179 | 180 | def file_numlines(fn): 181 | with open(fn) as f: 182 | return sum(1 for _ in f) 183 | 184 | 185 | def main(argv): 186 | 187 | if len(argv) < 2: 188 | print 'Usage: \n\t -learn \n\t -extract \n\t\tBuilds and scores a triplet-loss model ' 189 | return 190 | 191 | if 'learn' in argv[0]: 192 | learn(argv[1:]) 193 | elif 'extract' in argv[0]: 194 | extract(argv[1:]) 195 | 196 | return 197 | 198 | 199 | def extract(argv): 200 | 201 | if len(argv) < 3: 202 | print 'Usage: \n\t \n\t\tExtracts triplet-loss model' 203 | return 204 | 205 | modelpref = argv[0] 206 | imglist = argv[1] 207 | outfile = argv[2] 208 | 209 | with open(modelpref + '.json', "r") as json_file: 210 | model_json = json_file.read() 211 | 212 | loaded_model = keras.models.model_from_json(model_json) 213 | loaded_model.load_weights(modelpref + '.h5') 214 | 215 | base_model = loaded_model.get_layer('base_model') 216 | 217 | # create a new single input 218 | input_shape=(T_G_WIDTH,T_G_HEIGHT,T_G_NUMCHANNELS) 219 | input_single = kl.Input(shape=input_shape, name='input_single') 220 | 221 | # create a new model without the triple loss 222 | net_single = base_model(input_single) 223 | model = Model(input_single, net_single, name='embedding_net') 224 | 225 | chunksize = 1000 226 | total_img = file_numlines(imglist) 227 | total_img_ch = int(np.ceil(total_img / float(chunksize))) 228 | 229 | with open(outfile, 'w') as f_handle: 230 | 231 | for i in range(0, total_img_ch): 232 | imgs = t_read_image_list(imglist, i*chunksize, chunksize) 233 | 234 | vals = model.predict(imgs) 235 | 236 | np.savetxt(f_handle, vals) 237 | 238 | 239 | return 240 | 241 | 242 | 243 | def learn(argv): 244 | 245 | if len(argv) < 10: 246 | print 'Usage: \n\t \n\t\tLearns triplet-loss model' 247 | return 248 | 249 | in_t_a = argv[0] 250 | in_t_b = argv[1] 251 | in_t_c = argv[2] 252 | 253 | in_v_a = argv[3] 254 | in_v_b = argv[4] 255 | in_v_c = argv[5] 256 | 257 | emb_size = int(argv[6]) 258 | batch = int(argv[7]) 259 | numepochs = int(argv[8]) 260 | outpath = argv[9] 261 | 262 | # chunksize is the number of images we load from disk at a time 263 | chunksize = batch*100 264 | total_t = file_numlines(in_t_a) 265 | total_v = file_numlines(in_v_b) 266 | total_t_ch = int(np.ceil(total_t / float(chunksize))) 267 | total_v_ch = int(np.ceil(total_v / float(chunksize))) 268 | 269 | print 'Dataset has ' + str(total_t) + ' training triplets, and ' + str(total_v) + ' validation triplets.' 270 | 271 | print 'Creating a model ...' 272 | model = createModel(emb_size) 273 | 274 | print 'Training loop ...' 275 | 276 | # manual loop over epochs to support very large sets of triplets 277 | for e in range(0, numepochs): 278 | 279 | for t in range(0, total_t_ch): 280 | 281 | print 'Epoch ' + str(e) + ': train chunk ' + str(t+1) + '/ ' + str(total_t_ch) + ' ...' 282 | 283 | print 'Reading image lists ...' 284 | anchors_t = t_read_image_list(in_t_a, t*chunksize, chunksize) 285 | positives_t = t_read_image_list(in_t_b, t*chunksize, chunksize) 286 | negatives_t = t_read_image_list(in_t_c, t*chunksize, chunksize) 287 | Y_train = np.random.randint(2, size=(1,2,anchors_t.shape[0])).T 288 | 289 | print 'Starting to fit ...' 290 | # This method does NOT use data augmentation 291 | # model.fit([anchors_t, positives_t, negatives_t], Y_train, epochs=numepochs, batch_size=batch) 292 | 293 | # This method uses data augmentation 294 | model.fit_generator(generator=createDataGen(anchors_t,positives_t,negatives_t,Y_train,batch), steps_per_epoch=len(Y_train) / batch, epochs=1, shuffle=False, use_multiprocessing=True) 295 | 296 | # In case the validation images don't fit in memory, we load chunks from disk again. 297 | val_res = [0.0, 0.0] 298 | total_w = 0.0 299 | for v in range(0, total_v_ch): 300 | 301 | print 'Loading validation image lists ...' 302 | print 'Epoch ' + str(e) + ': val chunk ' + str(v+1) + '/ ' + str(total_v_ch) + ' ...' 303 | anchors_v = t_read_image_list(in_v_a, v*chunksize, chunksize) 304 | positives_v = t_read_image_list(in_v_b, v*chunksize, chunksize) 305 | negatives_v = t_read_image_list(in_v_c, v*chunksize, chunksize) 306 | Y_val = np.random.randint(2, size=(1,2,anchors_v.shape[0])).T 307 | 308 | # Weight of current validation measurement. 309 | # if loaded expected number of items, this will be 1.0, otherwise < 1.0, and > 0.0. 310 | w = float(anchors_v.shape[0]) / float(chunksize) 311 | total_w = total_w + w 312 | 313 | curval = model.evaluate([anchors_v, positives_v, negatives_v], Y_val, batch_size=batch) 314 | val_res[0] = val_res[0] + w*curval[0] 315 | val_res[1] = val_res[1] + w*curval[1] 316 | 317 | val_res = [x / total_w for x in val_res] 318 | 319 | print 'Validation Results: ' + str(val_res) 320 | 321 | print 'Saving model ...' 322 | 323 | # Save the model and weights 324 | model.save(outpath + '.h5') 325 | 326 | # Due to some remaining Keras bugs around loading custom optimizers 327 | # and objectives, we save the model architecture as well 328 | model_json = model.to_json() 329 | with open(outpath + '.json', "w") as json_file: 330 | json_file.write(model_json) 331 | 332 | return 333 | 334 | 335 | # Main Driver 336 | if __name__ == "__main__": 337 | main(sys.argv[1:]) 338 | --------------------------------------------------------------------------------