├── requirements.txt
├── README.md
├── customLossFunctions.py
├── 12bits_NUS.py
├── utils.py
└── optimizers.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.16.3
2 | tensorflow>=1.0.0
3 | keras==2.2.4
4 | opencv-python==4.1.0
5 | h5py==2.9.0
6 | scipy==0.18.1
7 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WDHT
2 | Implementation of Weakly Supervised Deep Image Hashing through Tag Embeddings
3 | 
4 | This repository implements the paper https://arxiv.org/abs/1806.05804. The current implementation uses ResNet50 from the Keras library instead of AlexNet as used in the paper, because no AlexNet model is available in Keras. Consequently, we achieve slightly higher accuracy than reported in the paper.
5 | 
6 | Installation:
7 | 
8 | 1. This code is built using the following setup:
9 |    - Ubuntu 14.04
10 |    - Python 2.7
11 |    - the libraries listed in `requirements.txt`
12 | 2. Replace the optimizers.py file in the Keras directory (/usr/local/lib/python2.7/dist-packages/keras) with the one in the current directory. The modified file contains a new class, `FineTuneSGD`, which 12bits_NUS.py uses as its optimizer.
13 | 3. Download the processed dataset from https://www.dropbox.com/s/f48rnct40mjluhl/nusWide.hdf5?dl=0 and place this file in a subfolder called 'data' inside the code folder, WDHT (the dataset is stored in 'BGR' channel order).
14 | 4. Download the pretrained weights from https://www.dropbox.com/s/tvqs7l6kqvfwgdi/weights_12bits_NUS.h5?dl=0 and place this file in a subfolder called 'weights' inside the code folder, WDHT.
15 | 
16 | Evaluation:
17 | Set the phase to 'Testing' in "12bits_NUS.py" and run it to obtain the mean average precision directly. The pretrained weights are already downloaded as part of the installation process above.
18 | 
19 | Training:
20 | Set the phase to 'Training' to train the network as described in the paper (see the snippet below).
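For quick reference, the sketch below shows the two modes. It assumes the script is run so that the relative paths used in `12bits_NUS.py` (`./../data/nusWide.hdf5` and `./../weights/weights_12bits_NUS.h5`) resolve correctly; the variable name `phase` is the one defined near the top of `12bits_NUS.py`.

```python
# Configuration block near the top of 12bits_NUS.py:
phase = 'Testing'     # loads weights_12bits_NUS.h5 and prints the mean average precision
# phase = 'Training'  # trains the network with the losses described in the paper

# Then run with Python 2.7:
#   python 12bits_NUS.py
```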
21 | 22 | -------------------------------------------------------------------------------- /customLossFunctions.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import keras.backend as K 3 | import numpy as np 4 | import pdb 5 | 6 | def catCrossEntr(l1=1): 7 | def loss(y_true, y_pred): 8 | classProb = y_pred 9 | return l1*K.categorical_crossentropy(y_true, classProb) 10 | return loss 11 | 12 | 13 | def quantizationLoss(l2=0.01, nbits = 12.0): 14 | def loss(y_true, y_pred): 15 | activs = y_pred 16 | curLoss = -(1./float(nbits))*l2*(K.sum((K.square(activs - 0.5)), axis=1)) 17 | return curLoss 18 | return loss 19 | 20 | def equiProbBits(l3=1): 21 | def loss(y_true, y_pred): 22 | activs = y_pred 23 | curLoss = l3*K.square(K.abs(K.mean(activs, axis =1)-0.5)) 24 | return curLoss 25 | return loss 26 | 27 | 28 | def dahLoss(): 29 | def loss(y_true, y_pred): 30 | epsilon = 0.000001 31 | D = y_pred - epsilon 32 | S = y_true 33 | beta = 0.9 34 | #curLoss = S - (1 - D) 35 | curLoss = -1*beta*S*K.log(D) - (1-beta)*(1-S)*K.log(1-D) 36 | #curLoss = K.binary_crossentropy(S, D) 37 | return curLoss 38 | return loss 39 | 40 | 41 | def contrastive(l2 = 1.0, m = 3.0): 42 | def loss(y_true, y_pred): 43 | #S = 1.0 - y_true 44 | S = y_true 45 | D = y_pred 46 | total = K.sum(K.sum(y_true)) + K.sum(K.sum(1.0-y_true)) 47 | beta = K.sum(K.sum(y_true))/total 48 | #pdb.set_trace() 49 | #print(beta) 50 | curLoss = l2*(S*(1-beta)*K.square(D) + (1-S)*beta*K.square(K.maximum(0, m - D))) 51 | return curLoss 52 | return loss 53 | 54 | def dummy(l2 = 1.0): 55 | def loss(y_true, y_pred): 56 | #S = 1.0 - y_true 57 | TrueDist = 1.0 - y_true 58 | D = y_pred 59 | return K.mean(K.square(TrueDist - D), axis=-1) 60 | # m = 3.0 61 | # beta = 0.1#0.455078125#K.sum(K.sum(y_true))/1024. 62 | # #pdb.set_trace() 63 | # #print(beta) 64 | # curLoss = l2*(S*(1-beta)*K.square(D) + (1-S)*beta*K.square(K.maximum(0, m - D))) 65 | # return curLoss 66 | return loss 67 | 68 | 69 | def dahLossDummy(y_true, y_pred): 70 | epsilon = 0.000001 71 | D = y_pred - epsilon 72 | S = y_true 73 | beta = 0.5 74 | #curLoss = S - (1 - D) 75 | #curLoss = -1*beta*K.dot(S,K.log(D)) - (1-beta)*K.dot((1-S),K.log(1-D)) 76 | curLoss = K.binary_crossentropy(S, D) 77 | return curLoss 78 | 79 | def vectorLoss(l2=1.0, m = 0.25): 80 | def loss(y_true, y_pred): 81 | multiplier = K.ones((50, 50))-K.eye(50) 82 | #pdb.set_trace() 83 | multiplier = K.expand_dims(multiplier, axis=0) 84 | multiplier = K.repeat_elements(multiplier, 50, 0) 85 | curLoss = K.maximum((m - y_pred)*multiplier, 0.) 
86 | return l2*curLoss 87 | return loss 88 | -------------------------------------------------------------------------------- /12bits_NUS.py: -------------------------------------------------------------------------------- 1 | from keras import optimizers 2 | import keras 3 | from keras.layers import Dense, Dot, Dropout, Activation, Input 4 | from keras.layers.merge import Subtract 5 | from keras.layers.core import Lambda, Reshape, RepeatVector 6 | from keras.models import Model 7 | from keras.callbacks import Callback 8 | from keras.preprocessing.image import ImageDataGenerator 9 | import keras.backend as K 10 | from keras.callbacks import ModelCheckpoint 11 | import pdb 12 | import sys 13 | import utils 14 | import numpy as np 15 | import datetime 16 | import json 17 | import utils 18 | import cv2 19 | import random 20 | import h5py 21 | from customLossFunctions import catCrossEntr, quantizationLoss, equiProbBits, dahLoss, contrastive, vectorLoss, dummy 22 | from scipy import io as sio 23 | from scipy.spatial.distance import cdist 24 | import ast 25 | 26 | from keras.applications import ResNet50 27 | from keras.layers import Dense, GlobalAveragePooling2D 28 | 29 | lambda1 = 10.0 30 | lambda2 = 1.0 31 | margin = 0.2 32 | MODEL_DIR = './../weights/weights_12bits_NUS.h5' 33 | IMAGE_WIDTH = IMAGE_HEIGHT = 227 34 | batch_size = 50 35 | phase = 'Testing' 36 | retainTop = False 37 | nEpochs = 50 38 | nBits = 12 39 | nClasses = 21 40 | totalTrainSamples = 100000 41 | totalTestSamples = 2000 42 | totalGallerySamples =100000 43 | 44 | f = h5py.File('./../data/nusWide.hdf5', 'r') 45 | 46 | trainDataImages = f['train_img'] 47 | trainDataLabels = f['train_label'] 48 | trainDataVectors = f['train_vector'] 49 | 50 | 51 | testDataImages = f['test_img'] 52 | testDataLabels = f['test_label'] 53 | testDataVectors = f['test_vector'] 54 | 55 | 56 | def test(): 57 | galleryHashes = np.zeros((int(totalGallerySamples/batch_size)*batch_size, nBits)) 58 | queryHashes = np.zeros((int(totalTestSamples/batch_size)*batch_size, nBits)) 59 | galleryCls = np.zeros((int(totalGallerySamples/batch_size)*batch_size, nClasses)) 60 | queryCls = np.zeros((int(totalTestSamples/batch_size)*batch_size, nClasses)) 61 | DG_Gal = data_generator(totalSamples = batch_size*int(totalGallerySamples/batch_size), batch_size = batch_size, dataset='G', phase ='Test', augmentation=False, shuffle=False) 62 | for j in range(int(totalGallerySamples/batch_size)): 63 | data, lab = next(DG_Gal) 64 | [dummy1, h, dummy2] = multiLab.predict(data, batch_size=batch_size) 65 | galleryHashes[j*batch_size:(j+1)*batch_size,:] = np.asarray(h > 0.5, dtype='int32') 66 | galleryCls[j*batch_size:(j+1)*batch_size] = lab 67 | if j%batch_size == 0: 68 | print("Generated batch: "+str(j)) 69 | 70 | DG_Que = data_generator(totalSamples = batch_size*int(totalTestSamples/batch_size), batch_size = batch_size, dataset = 'V', phase = 'Test', augmentation=False, shuffle=False) 71 | for j in range(int(totalTestSamples/batch_size)): 72 | data, lab = next(DG_Que) 73 | [dummy1, h, dummy2] = multiLab.predict(data, batch_size=batch_size) 74 | queryHashes[j*batch_size:(j+1)*batch_size,:] = np.asarray(h > 0.5, dtype='int32') 75 | queryCls[j*batch_size:(j+1)*batch_size] = lab 76 | if j%batch_size == 0: 77 | print("Generated batch: "+str(j)) 78 | MAP = utils.getMAP(queryLabels = queryCls, databaseLabels = galleryCls, queryHashes = queryHashes, databaseHashes= galleryHashes, curType='zeroOne', typeOfData='multiLabelled') 79 | print("MAP for top 5000 retrieved is: "+str(MAP)) 80 | 81 | 
class saveWeights(Callback): 82 | def __init__(self): 83 | self.count = 0 84 | 85 | def on_train_begin(self, logs={}): 86 | test() 87 | pass 88 | 89 | def on_batch_end(self, batch, logs={}): 90 | pass 91 | 92 | def on_epoch_end(self, epoch, logs={}): 93 | pass 94 | # multiLab.save_weights('./../pretrainedWeights/weights_12bits_NUS_epoch_'+str(self.count)+'.h5') 95 | # if self.count % 5 == 0 and self.count >= 0: 96 | # test() 97 | # self.count = self.count + 1 98 | 99 | def on_train_end(self, logs={}): 100 | pass 101 | 102 | params = { 103 | "rotation_range": 6, 104 | "width_shift_range": 0.1, 105 | "height_shift_range": 0.1, 106 | "shear_range": 0.2, 107 | "zoom_range": 0.1, 108 | "horizontal_flip": True, 109 | "fill_mode": 'reflect' 110 | } 111 | image_generator = ImageDataGenerator(**params) 112 | 113 | def data_generator(totalSamples, batch_size=batch_size, dataset = 'T', phase='Train', augmentation=False, shuffle=False): 114 | global image_generator 115 | batch_count = totalSamples// batch_size 116 | if dataset == 'T': 117 | images = trainDataImages 118 | vectors = trainDataVectors 119 | labels = trainDataLabels 120 | elif dataset == 'V': 121 | images = testDataImages 122 | vectors = testDataVectors 123 | labels = testDataLabels 124 | elif dataset == 'G': 125 | images = trainDataImages 126 | vectors = trainDataVectors 127 | labels = trainDataLabels 128 | while True: 129 | if shuffle: 130 | images, vectors = utils.shuffleInUnison(images, vectors) 131 | for i in range(batch_count): 132 | curBatchImages = images[i*batch_size:(i+1)*batch_size] 133 | curBatchVectorsTemp = vectors[i*batch_size:(i+1)*batch_size] 134 | curBatchLabels = labels[i*batch_size:(i+1)*batch_size] 135 | m_j = np.array([curBatchVectorsTemp,]*batch_size) 136 | m_n = np.transpose(m_j, (1, 0, 2)) 137 | curBatchVectors = m_n - m_j 138 | curBatchImages = utils.cropImages(curBatchImages, cropHeight=224, cropWidth=224) 139 | curBatchImages = np.transpose(curBatchImages, (0, 2, 3, 1)) 140 | # if model == 'Alexnet': 141 | # curBatchImages = curBatchImages[:,::-1,:,:] 142 | sim = cdist(curBatchVectorsTemp,curBatchVectorsTemp , 'cosine') 143 | if augmentation: 144 | seed = random.randint(1, 1e7) 145 | curBatchImages = next(image_generator.flow(curBatchImages, batch_size=batch_size, seed=seed, shuffle=False)) 146 | if phase == 'Train': 147 | yield [curBatchImages, curBatchVectors], [np.zeros((batch_size, batch_size, batch_size)), np.zeros((batch_size, nBits)), sim] 148 | elif phase == 'Test': 149 | yield [curBatchImages, curBatchVectors], curBatchLabels 150 | 151 | model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3), pooling='avg') 152 | firstLayer, lastLayer = model.input, model.output 153 | 154 | def normalize(x): 155 | x = K.l2_normalize(x, axis=-1) 156 | return x 157 | 158 | def permuteDims(x): 159 | x = K.permute_dimensions(x, (1, 0, 2)) 160 | return x 161 | 162 | def computeDistancesforContrastive(x): 163 | D = K.sum(K.square(x),axis=-1)/float(nBits) 164 | return D 165 | 166 | def dist(x): 167 | return K.sum(x, axis=-1)/float(batch_size) 168 | 169 | vectorInput = Input(shape=(batch_size, 300)) 170 | if not retainTop: 171 | dense_4 = Dense(256, name='dense_4')(lastLayer) 172 | activ_4 = Activation('relu', name='activ_4')(dense_4) 173 | dense_5 = Dense(nBits, name='dense_5')(activ_4) 174 | output_hash = Activation('sigmoid', name='sigmoid')(dense_5) 175 | activ_5_rep = RepeatVector(batch_size)(output_hash) 176 | perm_activ_5_rep = Lambda(permuteDims)(activ_5_rep) 177 | activ_6_rep = 
RepeatVector(batch_size)(output_hash) 178 | mergeLayer = Subtract()([perm_activ_5_rep, activ_6_rep]) 179 | distances = Lambda(computeDistancesforContrastive, output_shape=(batch_size,))(mergeLayer) 180 | dense_6 = Dense(300, name='dense_6')(activ_4) 181 | activ_7 = Activation('tanh', name='tanh')(dense_6) 182 | vectorsPred = RepeatVector(batch_size)(activ_7) 183 | cosine = Dot(axes=2)([vectorInput, vectorsPred]) 184 | multiLab = Model(inputs=[firstLayer, vectorInput], outputs=[cosine, output_hash, distances]) 185 | else: 186 | multiLab = model 187 | 188 | last_layer_variables = list() 189 | multiLab_len = len(multiLab.layers) 190 | model_len = len(model.layers) 191 | counter = 0 192 | for layer in multiLab.layers: 193 | counter = counter + 1 194 | if counter > model_len: 195 | last_layer_variables.extend(layer.weights) 196 | 197 | # FineTuneSGD implemented using https://github.com/fchollet/keras/issues/5920 198 | 199 | if phase == 'Testing': 200 | multiLab.load_weights(MODEL_DIR) 201 | test() 202 | multiLab.compile(loss=[vectorLoss(l2=lambda1, m=margin), quantizationLoss(l2=lambda2, nbits=nBits), 'mean_squared_error'], 203 | optimizer=optimizers.FineTuneSGD(exception_vars=last_layer_variables, lr=0.001, momentum=0.9, multiplier=0.1)) 204 | 205 | if phase == 'Training': 206 | saveweights = saveWeights() 207 | print("Learning") 208 | multiLab.fit_generator( 209 | data_generator(totalSamples = batch_size*int(totalTrainSamples/batch_size), batch_size = batch_size, dataset = 'T'), 210 | steps_per_epoch=int(totalTrainSamples/batch_size), 211 | epochs=nEpochs, 212 | verbose=1, 213 | validation_data=data_generator(totalSamples = batch_size*int(totalTestSamples/batch_size), batch_size = batch_size, dataset='V'), 214 | validation_steps=int(totalTestSamples/batch_size), callbacks=[saveweights]) 215 | multiLab.save_weights(MODEL_DIR) 216 | 217 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from scipy import io as sio 2 | from scipy.misc import imresize 3 | import numpy as np 4 | import pdb 5 | from scipy.spatial.distance import cdist 6 | 7 | def getData(dataset='CIFAR10', channels_last=True): 8 | if dataset == 'CIFAR10': 9 | #This matrix is made by the MATLAB/MatConvNet/DPSH_IJCAI_ version 1.0_beta23 code. 
As per the code, the data should be in RGB format(verified visually) 10 | data = sio.loadmat('./datasets/cifar-10.mat') 11 | trainData = data['train_data'] 12 | trainLabels = data['train_L'] 13 | queryData = data['test_data'] 14 | queryLabels = data['test_L'] 15 | galleryData = data['data_set'] 16 | galleryLabels = data['dataset_L'] 17 | if channels_last: 18 | trainData = np.transpose(trainData, (3, 0, 1, 2)) 19 | queryData = np.transpose(queryData, (3, 0, 1, 2)) 20 | galleryData = np.transpose(galleryData, (3, 0, 1, 2)) 21 | else: 22 | raise NotImplementedError 23 | return trainData, trainLabels, queryData, queryLabels, galleryData, galleryLabels 24 | 25 | def makeDataSet(imageIds, labels, labelType = 'oneHot', nImagesPerClassTrain=500, nImagesPerClassTest = 100): 26 | if labelType == 'oneHot': 27 | nClasses = labels.shape[1] 28 | nTrainImages = int(imageIds.shape[0]*0.7) 29 | trainImages = imageIds[0:nTrainImages] 30 | testImages = imageIds[nTrainImages:] 31 | trainLabels = labels[0:nTrainImages] 32 | testLabels = labels[nTrainImages:] 33 | trainSetImageIds = np.zeros((nClasses, nImagesPerClassTrain), dtype='uint32') 34 | trainSetLabels = np.zeros((nClasses, nImagesPerClassTrain, nClasses)) 35 | for i in range(nClasses): 36 | consider = trainLabels[:, i] == 1 37 | curImageIds = trainImages[consider] 38 | curLabels = trainLabels[consider,:] 39 | curImageIds, curLabels = shuffleInUnison(curImageIds, curLabels) 40 | trainSetImageIds[i,:] = np.reshape(curImageIds[0:nImagesPerClassTrain], (nImagesPerClassTrain,)) 41 | trainSetLabels[i, :, :] = curLabels[0:nImagesPerClassTrain] 42 | testSetImageIds = np.zeros((nClasses, nImagesPerClassTest), dtype='uint32') 43 | testSetLabels = np.zeros((nClasses, nImagesPerClassTest, nClasses)) 44 | for i in range(nClasses): 45 | consider = testLabels[:, i] == 1 46 | curImageIds = testImages[consider] 47 | curLabels = testLabels[consider,:] 48 | curImageIds, curLabels = shuffleInUnison(curImageIds, curLabels) 49 | testSetImageIds[i,:] = np.reshape(curImageIds[0:nImagesPerClassTest], (nImagesPerClassTest,)) 50 | testSetLabels[i, :, :] = curLabels[0:nImagesPerClassTest] 51 | trainSetImageIds = np.reshape(trainSetImageIds, (nClasses*nImagesPerClassTrain,)) 52 | trainSetLabels = np.reshape(trainSetLabels, (nClasses*nImagesPerClassTrain, nClasses)) 53 | testSetImageIds = np.reshape(testSetImageIds, (nClasses*nImagesPerClassTest,)) 54 | testSetLabels = np.reshape(testSetLabels, (nClasses*nImagesPerClassTest,nClasses)) 55 | return trainSetImageIds, trainSetLabels, testSetImageIds, testSetLabels 56 | 57 | 58 | def resizeImages(images, resizeHeight=256, resizeWidth = 256): 59 | resizedImages = np.zeros((images.shape[0], 3, resizeHeight, resizeWidth)) 60 | for i in range(resizedImages.shape[0]): 61 | resizedImages[i,:,:,:] = np.transpose(imresize(images[i], (resizeHeight, resizeWidth)), (2, 0, 1)) 62 | return resizedImages 63 | 64 | 65 | def cropImages(images, cropHeight=227, cropWidth=227): 66 | croppedImages = np.zeros((images.shape[0], 3, cropHeight, cropWidth)) 67 | for i in range(croppedImages.shape[0]): 68 | randX = np.random.randint(images.shape[2]-cropHeight) 69 | randY = np.random.randint(images.shape[3]-cropWidth) 70 | croppedImages[i,:,:,:] = images[i,:,randX:randX+cropHeight,randY:randY+cropWidth] 71 | return croppedImages 72 | 73 | 74 | def meanSubtract(images, sourceDataSet='IMAGENET', order='RGB'): 75 | if order == 'RGB': 76 | images[:, 0, :, :] -= 123.68 77 | images[:, 1, :, :] -= 116.779 78 | images[:, 2, :, :] -= 103.939 # values copied from 
https://github.com/heuritech/convnets-keras/blob/master/convnetskeras/convnets.py 79 | elif order == 'BGR': 80 | images[:, 0, :, :] -= 103.939 81 | images[:, 1, :, :] -= 116.779 82 | images[:, 2, :, :] -= 123.68 # values copied from https://github.com/heuritech/convnets-keras/blob/master/convnetskeras/convnets.py 83 | #this is the order in which the keras.applications models are trained. 84 | return images 85 | 86 | 87 | def shuffleInUnison(images, labels): 88 | perm = np.random.permutation(images.shape[0]) 89 | images = images[perm] 90 | labels = labels[perm] 91 | return images, labels 92 | 93 | 94 | def generatePairs(images, labels, batch_size): 95 | n_classes = np.unique(labels).shape[0] 96 | images_classwise = np.zeros((n_classes, images.shape[0]/n_classes, images.shape[1], images.shape[2], images.shape[3])) 97 | for i in range(n_classes): 98 | curClass = labels == i 99 | images_classwise[i,:,:,:,:] = images[curClass,:,:,:] 100 | randomLabels = np.random.randint(10, size=batch_size) 101 | simLabels = randomLabels[0:batch_size/2] 102 | dissimLabels = randomLabels[batch_size/2:batch_size] 103 | imagePairs = [] 104 | similarity = [] 105 | queryLabs = [] 106 | databaseLabs = [] 107 | for i in range(len(simLabels)): 108 | randomImgNums = np.random.randint(images.shape[0]/n_classes, size=2) 109 | imagePairs.append([images_classwise[simLabels[i], randomImgNums[0], :, :, :], images_classwise[simLabels[i], randomImgNums[1], :, :, :]]) 110 | similarity.append(1) 111 | queryLabs.append(simLabels[i]) 112 | databaseLabs.append(simLabels[i]) 113 | for i in range(len(dissimLabels)): 114 | randomImgNums = np.random.randint(images.shape[0]/n_classes, size=2) 115 | secondImageClass = dissimLabels[i] 116 | while(secondImageClass==dissimLabels[i]): 117 | secondImageClass = np.random.randint(n_classes) 118 | imagePairs.append([images_classwise[dissimLabels[i], randomImgNums[0], :, :, :], images_classwise[secondImageClass, randomImgNums[1], :, :, :]]) 119 | similarity.append(0) 120 | queryLabs.append(dissimLabels[i]) 121 | databaseLabs.append(secondImageClass) 122 | imagePairs = np.array(imagePairs) 123 | similarity = np.array(similarity) 124 | return imagePairs, similarity, np.asarray(queryLabs), np.asarray(databaseLabs) 125 | 126 | 127 | def prepareData(dataset='CIFAR10'): 128 | trainData, trainLabels, queryData, queryLabels, galleryData, galleryLabels = getData(dataset=dataset) 129 | return trainData, trainLabels, queryData, queryLabels, galleryData, galleryLabels 130 | 131 | def multiLabelGetVectors(data, dim=300, nClasses=81, nTags=1000, method='mean'): 132 | vecMat = np.zeros((len(data), dim)) 133 | labels = np.zeros((len(data), nClasses)) 134 | images = np.zeros((len(data), 1)) 135 | tags = np.zeros((len(data), nTags)) 136 | for i in range(len(data)): 137 | curRec = data[i] 138 | curVector = np.zeros((300, )) 139 | images[i] = curRec[0] 140 | labels[i] = curRec[1] 141 | if method == 'mean': 142 | for j in range(len(curRec[2])): 143 | curVector = curVector + curRec[2][j] 144 | tags[i][int(curRec[2][j][1])] = 1 145 | vecMat[i][:] = curVector/float(len(curRec[2])) 146 | elif method == 'idf': 147 | #pdb.set_trace() 148 | avg = 0 149 | for j in range(len(curRec[2])): 150 | curVector = curVector +[x* curRec[2][j][2] for x in curRec[2][j][0]] 151 | tags[i][int(curRec[2][j][1])] = 1 152 | avg = avg + curRec[2][j][2] 153 | vecMat[i][:] = curVector/float(avg) 154 | elif method == 'minFreq': 155 | minFreq = 100 156 | minFrqIndex = -100 157 | for j in range(len(curRec[2])): 158 | if curRec[2][j][2] < minFreq: 
159 | tags[i][int(curRec[2][j][1])] = 1 160 | minFreq = curRec[2][j][2] 161 | minFreqIndex = j 162 | vecMat[i][:] = curRec[2][minFreqIndex][0] 163 | elif method == 'cutFreq': 164 | avg = 0.00001 165 | for j in range(len(curRec[2])): 166 | if curRec[2][j][2] > 5.3 and curRec[2][j][2] < 8.2: 167 | curVector = curVector +[x* curRec[2][j][2] for x in curRec[2][j][0]] 168 | tags[i][int(curRec[2][j][1])] = 1 169 | avg = avg + curRec[2][j][2] 170 | vecMat[i][:] = curVector/float(avg) 171 | images = np.array(images, dtype='uint32') 172 | return (images, labels, vecMat, tags) 173 | 174 | def multiLabelGetVectorsNUS(data, dim=300, nClasses=81, nTags=1000, method='mean'): 175 | vecMat = np.zeros((len(data), dim)) 176 | labels = np.zeros((len(data), nClasses)) 177 | images = np.zeros((len(data), 1)) 178 | tags = [] 179 | for i in range(len(data)): 180 | curRec = data[i] 181 | curVector = np.zeros((300, )) 182 | images[i] = curRec[0] 183 | labels[i] = curRec[1] 184 | if method == 'mean': 185 | for j in range(len(curRec[2])): 186 | curVector = curVector + curRec[2][j] 187 | vecMat[i][:] = curVector/float(len(curRec[2])) 188 | tags.append(curRec[3]) 189 | elif method == 'idf': 190 | avg = 0 191 | for j in range(len(curRec[2])): 192 | curVector = curVector +[x* curRec[2][j][2] for x in curRec[2][j][0]] 193 | avg = avg + curRec[2][j][2] 194 | vecMat[i][:] = curVector/float(avg) 195 | elif method == 'minFreq': 196 | minFreq = 100 197 | minFrqIndex = -100 198 | for j in range(len(curRec[2])): 199 | if curRec[2][j][2] < minFreq: 200 | minFreq = curRec[2][j][2] 201 | minFreqIndex = j 202 | vecMat[i][:] = curRec[2][minFreqIndex][0] 203 | elif method == 'cutFreq': 204 | avg = 0.00001 205 | for j in range(len(curRec[2])): 206 | if curRec[2][j][2] > 5.3 and curRec[2][j][2] < 8.2: 207 | curVector = curVector +[x* curRec[2][j][2] for x in curRec[2][j][0]] 208 | avg = avg + curRec[2][j][2] 209 | vecMat[i][:] = curVector/float(avg) 210 | images = np.array(images, dtype='uint32') 211 | return (images, labels, vecMat, tags) 212 | 213 | def multiLabelGetVectorsDelete(data, dim=300, nClasses=81, nTags=1000, method='mean'): 214 | vecMat = np.zeros((len(data), dim)) 215 | labels = np.zeros((len(data), nClasses)) 216 | images = np.zeros((len(data), 1)) 217 | tags = np.zeros((len(data), nTags)) 218 | for i in range(len(data)): 219 | curRec = data[i] 220 | images[i] = curRec[0] 221 | labels[i] = curRec[1] 222 | images = np.array(images, dtype='uint32') 223 | return (images, labels) 224 | 225 | def getTotalWeights(weightsShape): 226 | totalWeights = 1 227 | for i in range(len(weightsShape)): 228 | totalWeights = totalWeights*weightsShape[i] 229 | return totalWeights 230 | 231 | 232 | def checkIfWeightsAreNotLost(model_1, model_2, layerList): 233 | for i in range(len(layerList)): 234 | sameWeights = False 235 | weights1 = model_1.layers[layerList[i]].get_weights()[0] 236 | weights2 = model_2.layers[layerList[i]].get_weights()[0] 237 | if weights1.shape != weights2.shape: 238 | print("Weights Shapes did not match") 239 | break 240 | else: 241 | totalNumberOfWeights = getTotalWeights(weights1.shape) 242 | if np.sum(weights1 == weights2) != totalNumberOfWeights: 243 | print("Weights are different") 244 | break 245 | else: 246 | sameWeights = True 247 | return sameWeights 248 | 249 | def preprocessLabels(labels): 250 | temp = np.sum(labels, axis =0) 251 | temp = np.argsort(temp) 252 | temp = temp[-21:] 253 | labels = labels[:,temp] 254 | temp = np.array(np.sum(labels, axis=-1) !=0, dtype='bool') 255 | labels = labels[temp] 256 | 
return labels 257 | 258 | def computeSimilarityMatrix(queryLabels, databaseLabels, typeOfData='singleLabelled', type='interOverUnion'): 259 | count = 0 260 | groundTruthSimilarityMatrix = np.zeros((queryLabels.shape[0], databaseLabels.shape[0])) 261 | if typeOfData=='singleLabelled': 262 | for i in range(queryLabels.shape[0]): 263 | groundTruthSimilarityMatrix[i,:] = queryLabels[i] == databaseLabels 264 | elif typeOfData=='multiLabelled': 265 | for i in range(queryLabels.shape[0]): 266 | curQue = queryLabels[i][:] 267 | if sum(curQue) != 0: 268 | threshold = 1 269 | sim = np.sum(np.logical_and(curQue, databaseLabels), axis=-1) 270 | den = np.sum(np.logical_or(curQue, databaseLabels), axis=-1) 271 | count = count + np.sum(np.logical_and(sum(curQue) > 1, sim == 1)) 272 | if type=='zeroOne': 273 | groundTruthSimilarityMatrix[i][np.where(sim >= threshold)[0]] = 1 274 | elif type=='interOverUnion': 275 | groundTruthSimilarityMatrix[i][:] = np.divide(np.array(sim,dtype='float32'),(np.array(den,dtype='float32')+0.00001)) 276 | # for j in range(databaseLabels.shape[0]): 277 | # curDb = databaseLabels[j][:] 278 | # sim = np.sum(np.logical_and(curQue, curDb), axis=-1) 279 | # den = np.sum(np.logical_or(curQue, curDb), axis=-1) 280 | # if type=='zeroOne': 281 | # #pdb.set_trace() 282 | # if sim >= threshold: 283 | # groundTruthSimilarityMatrix[i][j] = 1 284 | # elif type=='interOverUnion': 285 | # groundTruthSimilarityMatrix[i][j] = float(sim)/(float(den)+0.00001) 286 | if type=='zeroOne': 287 | groundTruthSimilarityMatrix = np.asarray(groundTruthSimilarityMatrix, dtype='float32') 288 | elif type=='interOverUnion': 289 | groundTruthSimilarityMatrix = groundTruthSimilarityMatrix > 0.25 290 | groundTruthSimilarityMatrix = np.asarray(groundTruthSimilarityMatrix, dtype='float32') 291 | return groundTruthSimilarityMatrix 292 | 293 | 294 | def calcHammingRank(queryHashes, databaseHashes, space='Hamming'): 295 | hammingDist = np.zeros((queryHashes.shape[0], databaseHashes.shape[0])) 296 | hammingRank = np.zeros((queryHashes.shape[0], databaseHashes.shape[0])) 297 | if space == 'Hamming': 298 | for i in range(queryHashes.shape[0]): 299 | hammingDist[i] = np.reshape(np.sum(np.abs(queryHashes[i] - databaseHashes), axis=1), (databaseHashes.shape[0], )) 300 | hammingRank[i] = np.argsort(hammingDist[i]) 301 | elif space == 'RealValued': 302 | for i in range(queryHashes.shape[0]): 303 | if i % 100 == 0: 304 | print(i) 305 | hammingDist[i] = cdist(np.reshape(queryHashes[i], (1, 300)),databaseHashes , 'cosine') 306 | hammingRank[i] = np.argsort(hammingDist[i]) 307 | return hammingDist, hammingRank 308 | 309 | 310 | def calcMAP(groundTruthSimilarityMatrix, hammingRank, hammingDist): 311 | [Q, N] = hammingRank.shape 312 | pos = np.arange(N)+1 313 | MAP = 0 314 | numSucc = 0 315 | for i in range(Q): 316 | ngb = groundTruthSimilarityMatrix[i, np.asarray(hammingRank[i,:], dtype='int32')] 317 | ngb = ngb[0:N] 318 | nRel = np.sum(ngb) 319 | if nRel > 0: 320 | prec = np.divide(np.cumsum(ngb), pos) 321 | prec = prec[0:5000] 322 | ngb = ngb[0:5000] 323 | ap = np.mean(prec[np.asarray(ngb, dtype='bool')]) 324 | rec = np.array(np.cumsum(ngb)/float(np.sum(groundTruthSimilarityMatrix[i])), dtype='float32') 325 | if i == 0: 326 | precisions = prec 327 | recalls = rec 328 | else: 329 | precisions = precisions + prec 330 | recalls = recalls + rec 331 | MAP = MAP + ap 332 | numSucc = numSucc + 1 333 | precisions = precisions/float(Q) 334 | recalls = recalls/float(Q) 335 | MAP = float(MAP)/numSucc 336 | precisions = [] 337 | recalls = 
[] 338 | for j in range(8): 339 | countOrNot = np.array(hammingDist <= j, dtype='int32') 340 | newSim = np.multiply(groundTruthSimilarityMatrix, countOrNot) 341 | countOrNot = countOrNot + 0.000001 342 | prec = np.mean(np.divide(np.sum(newSim, axis=-1), np.sum(countOrNot, axis=-1)))# float(np.sum(np.sum(newSim)))/float(np.sum()) 343 | rec = np.mean(np.divide(np.sum(newSim, axis=-1), np.sum(groundTruthSimilarityMatrix, axis=-1))) 344 | precisions.append(prec) 345 | recalls.append(rec) 346 | return MAP, precisions, recalls 347 | 348 | 349 | def getMAP(queryLabels, databaseLabels, queryHashes, databaseHashes, curType, typeOfData='singleLabelled', space='Hamming'): 350 | if typeOfData == 'singleLabelled': 351 | groundTruthSimilarityMatrix = computeSimilarityMatrix(queryLabels, databaseLabels) 352 | elif typeOfData == 'multiLabelled': 353 | groundTruthSimilarityMatrix = computeSimilarityMatrix(queryLabels, databaseLabels, typeOfData='multiLabelled', type = curType) 354 | hammingDist, hammingRank = calcHammingRank(queryHashes, databaseHashes, space) 355 | MAP, precisions, recalls = calcMAP(groundTruthSimilarityMatrix, hammingRank, hammingDist) 356 | precisions = [] 357 | recalls = [] 358 | countOrNot = np.array(hammingDist <= 2, dtype='int32') 359 | newSim = np.multiply(groundTruthSimilarityMatrix, countOrNot) 360 | #pdb.set_trace() 361 | countOrNot = countOrNot + 0.000001 362 | prec = np.mean(np.divide(np.sum(newSim, axis=-1), np.sum(countOrNot, axis=-1)))# float(np.sum(np.sum(newSim)))/float(np.sum()) 363 | rec = np.mean(np.divide(np.sum(newSim, axis=-1), np.sum(groundTruthSimilarityMatrix, axis=-1))) 364 | # for i in range(12): 365 | # countOrNot = np.array(hammingDist <= i, dtype='int32') 366 | # newSim = np.multiply(groundTruthSimilarityMatrix, countOrNot) 367 | # #pdb.set_trace() 368 | # countOrNot = countOrNot + 0.000001 369 | # prec = np.mean(np.divide(np.sum(newSim, axis=-1), np.sum(countOrNot, axis=-1)))# float(np.sum(np.sum(newSim)))/float(np.sum()) 370 | # rec = np.mean(np.divide(np.sum(newSim, axis=-1), np.sum(groundTruthSimilarityMatrix, axis=-1))) 371 | # precisions.append(prec) 372 | # recalls.append(rec) 373 | return (MAP, prec, rec) 374 | 375 | def computeMAPRealValuedSpace(queryLabels, databaseLabels, queryVectors, databaseVectors): 376 | pass 377 | 378 | def computeConfusion(simMat, dist): 379 | #temp = np.array(np.exp(-1*dist) > 0.5, dtype='int32') 380 | temp = np.array(3.0 - dist > 0, dtype='int32') 381 | #pdb.set_trace() 382 | tps = np.sum(np.logical_and(temp == 1, simMat == 1)) 383 | tns = np.sum(np.logical_and(temp == 0, simMat == 0)) 384 | fps = np.sum(np.logical_and(temp == 1, simMat == 0)) 385 | fns = np.sum(np.logical_and(temp == 0, simMat == 1)) 386 | return (tps, tns, fps, fns) 387 | 388 | 389 | 390 | def precisionAtK(queryLabels, databaseLabels, queryHashes, databaseHashes, k, curType, typeOfData='singleLabelled'): 391 | if typeOfData == 'singleLabelled': 392 | groundTruthSimilarityMatrix = computeSimilarityMatrix(queryLabels, databaseLabels) 393 | elif typeOfData == 'multiLabelled': 394 | groundTruthSimilarityMatrix = computeSimilarityMatrix(queryLabels, databaseLabels, typeOfData='multiLabelled', type = curType) 395 | hammingDist, hammingRank = calcHammingRank(queryHashes, databaseHashes) 396 | countOrNot = np.array(hammingDist == k, dtype='int32') 397 | newSim = np.multiply(groundTruthSimilarityMatrix, countOrNot) 398 | precAtK = float(np.sum(np.sum(newSim)))/float(np.sum(np.sum(countOrNot))) 399 | return precAtK 400 | 401 | def getWeightShapesFromModel(model, 
library='Keras'): 402 | """ 403 | Desc: 404 | 405 | Args: 406 | 407 | Returns: 408 | 409 | 410 | """ 411 | # pdb.set_trace() 412 | weightShapes=[] 413 | if library == 'Keras': 414 | nLayers = len(model.layers) 415 | for i in range(nLayers): 416 | nParamSets = len(model.layers[i].get_weights()) 417 | assert nParamSets%2 == 0 418 | for j in range(int(nParamSets/2)): 419 | weightShapes.append([model.layers[i].get_weights()[2*j].shape, model.layers[i].get_weights()[2*j+1].shape]) 420 | print(weightShapes[-1]) 421 | return weightShapes 422 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | """Built-in optimizer classes. 2 | """ 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import six 8 | import copy 9 | from six.moves import zip 10 | 11 | from . import backend as K 12 | from .utils.generic_utils import serialize_keras_object 13 | from .utils.generic_utils import deserialize_keras_object 14 | from .legacy import interfaces 15 | 16 | if K.backend() == 'tensorflow': 17 | import tensorflow as tf 18 | 19 | 20 | def clip_norm(g, c, n): 21 | """Clip the gradient `g` if the L2 norm `n` exceeds `c`. 22 | 23 | # Arguments 24 | g: Tensor, the gradient tensor 25 | c: float >= 0. Gradients will be clipped 26 | when their L2 norm exceeds this value. 27 | n: Tensor, actual norm of `g`. 28 | 29 | # Returns 30 | Tensor, the gradient clipped if required. 31 | """ 32 | if c <= 0: # if clipnorm == 0 no need to add ops to the graph 33 | return g 34 | 35 | # tf require using a special op to multiply IndexedSliced by scalar 36 | if K.backend() == 'tensorflow': 37 | condition = n >= c 38 | then_expression = tf.scalar_mul(c / n, g) 39 | else_expression = g 40 | 41 | # saving the shape to avoid converting sparse tensor to dense 42 | if isinstance(then_expression, tf.Tensor): 43 | g_shape = copy.copy(then_expression.get_shape()) 44 | elif isinstance(then_expression, tf.IndexedSlices): 45 | g_shape = copy.copy(then_expression.dense_shape) 46 | if condition.dtype != tf.bool: 47 | condition = tf.cast(condition, 'bool') 48 | g = tf.cond(condition, 49 | lambda: then_expression, 50 | lambda: else_expression) 51 | if isinstance(then_expression, tf.Tensor): 52 | g.set_shape(g_shape) 53 | elif isinstance(then_expression, tf.IndexedSlices): 54 | g._dense_shape = g_shape 55 | else: 56 | g = K.switch(K.greater_equal(n, c), g * c / n, g) 57 | return g 58 | 59 | 60 | class Optimizer(object): 61 | """Abstract optimizer base class. 62 | 63 | Note: this is the parent class of all optimizers, not an actual optimizer 64 | that can be used for training models. 65 | 66 | All Keras optimizers support the following keyword arguments: 67 | 68 | clipnorm: float >= 0. Gradients will be clipped 69 | when their L2 norm exceeds this value. 70 | clipvalue: float >= 0. Gradients will be clipped 71 | when their absolute value exceeds this value. 
72 | """ 73 | 74 | def __init__(self, **kwargs): 75 | allowed_kwargs = {'clipnorm', 'clipvalue'} 76 | for k in kwargs: 77 | if k not in allowed_kwargs: 78 | raise TypeError('Unexpected keyword argument ' 79 | 'passed to optimizer: ' + str(k)) 80 | self.__dict__.update(kwargs) 81 | self.updates = [] 82 | self.weights = [] 83 | 84 | @interfaces.legacy_get_updates_support 85 | def get_updates(self, loss, params): 86 | raise NotImplementedError 87 | 88 | def get_gradients(self, loss, params): 89 | grads = K.gradients(loss, params) 90 | if None in grads: 91 | raise ValueError('An operation has `None` for gradient. ' 92 | 'Please make sure that all of your ops have a ' 93 | 'gradient defined (i.e. are differentiable). ' 94 | 'Common ops without gradient: ' 95 | 'K.argmax, K.round, K.eval.') 96 | if hasattr(self, 'clipnorm') and self.clipnorm > 0: 97 | norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) 98 | grads = [clip_norm(g, self.clipnorm, norm) for g in grads] 99 | if hasattr(self, 'clipvalue') and self.clipvalue > 0: 100 | grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads] 101 | return grads 102 | 103 | def set_weights(self, weights): 104 | """Sets the weights of the optimizer, from Numpy arrays. 105 | 106 | Should only be called after computing the gradients 107 | (otherwise the optimizer has no weights). 108 | 109 | # Arguments 110 | weights: a list of Numpy arrays. The number 111 | of arrays and their shape must match 112 | number of the dimensions of the weights 113 | of the optimizer (i.e. it should match the 114 | output of `get_weights`). 115 | 116 | # Raises 117 | ValueError: in case of incompatible weight shapes. 118 | """ 119 | params = self.weights 120 | if len(params) != len(weights): 121 | raise ValueError('Length of the specified weight list (' + 122 | str(len(weights)) + 123 | ') does not match the number of weights ' + 124 | 'of the optimizer (' + str(len(params)) + ')') 125 | weight_value_tuples = [] 126 | param_values = K.batch_get_value(params) 127 | for pv, p, w in zip(param_values, params, weights): 128 | if pv.shape != w.shape: 129 | raise ValueError('Optimizer weight shape ' + 130 | str(pv.shape) + 131 | ' not compatible with ' 132 | 'provided weight shape ' + str(w.shape)) 133 | weight_value_tuples.append((p, w)) 134 | K.batch_set_value(weight_value_tuples) 135 | 136 | def get_weights(self): 137 | """Returns the current value of the weights of the optimizer. 138 | 139 | # Returns 140 | A list of numpy arrays. 141 | """ 142 | return K.batch_get_value(self.weights) 143 | 144 | def get_config(self): 145 | config = {} 146 | if hasattr(self, 'clipnorm'): 147 | config['clipnorm'] = self.clipnorm 148 | if hasattr(self, 'clipvalue'): 149 | config['clipvalue'] = self.clipvalue 150 | return config 151 | 152 | @classmethod 153 | def from_config(cls, config): 154 | return cls(**config) 155 | 156 | 157 | class SGD(Optimizer): 158 | """Stochastic gradient descent optimizer. 159 | 160 | Includes support for momentum, 161 | learning rate decay, and Nesterov momentum. 162 | 163 | # Arguments 164 | lr: float >= 0. Learning rate. 165 | momentum: float >= 0. Parameter that accelerates SGD 166 | in the relevant direction and dampens oscillations. 167 | decay: float >= 0. Learning rate decay over each update. 168 | nesterov: boolean. Whether to apply Nesterov momentum. 
169 | """ 170 | 171 | def __init__(self, lr=0.01, momentum=0., decay=0., 172 | nesterov=False, **kwargs): 173 | super(SGD, self).__init__(**kwargs) 174 | with K.name_scope(self.__class__.__name__): 175 | self.iterations = K.variable(0, dtype='int64', name='iterations') 176 | self.lr = K.variable(lr, name='lr') 177 | self.momentum = K.variable(momentum, name='momentum') 178 | self.decay = K.variable(decay, name='decay') 179 | self.initial_decay = decay 180 | self.nesterov = nesterov 181 | 182 | @interfaces.legacy_get_updates_support 183 | def get_updates(self, loss, params): 184 | grads = self.get_gradients(loss, params) 185 | self.updates = [K.update_add(self.iterations, 1)] 186 | 187 | lr = self.lr 188 | if self.initial_decay > 0: 189 | lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, 190 | K.dtype(self.decay)))) 191 | # momentum 192 | shapes = [K.int_shape(p) for p in params] 193 | moments = [K.zeros(shape) for shape in shapes] 194 | self.weights = [self.iterations] + moments 195 | for p, g, m in zip(params, grads, moments): 196 | v = self.momentum * m - lr * g # velocity 197 | self.updates.append(K.update(m, v)) 198 | 199 | if self.nesterov: 200 | new_p = p + self.momentum * v - lr * g 201 | else: 202 | new_p = p + v 203 | 204 | # Apply constraints. 205 | if getattr(p, 'constraint', None) is not None: 206 | new_p = p.constraint(new_p) 207 | 208 | self.updates.append(K.update(p, new_p)) 209 | return self.updates 210 | 211 | def get_config(self): 212 | config = {'lr': float(K.get_value(self.lr)), 213 | 'momentum': float(K.get_value(self.momentum)), 214 | 'decay': float(K.get_value(self.decay)), 215 | 'nesterov': self.nesterov} 216 | base_config = super(SGD, self).get_config() 217 | return dict(list(base_config.items()) + list(config.items())) 218 | 219 | class FineTuneSGD(Optimizer): 220 | """Stochastic gradient descent optimizer. 221 | 222 | Includes support for momentum, 223 | learning rate decay, and Nesterov momentum. 224 | 225 | # Arguments 226 | lr: float >= 0. Learning rate. 227 | momentum: float >= 0. Parameter updates momentum. 228 | decay: float >= 0. Learning rate decay over each update. 229 | nesterov: boolean. Whether to apply Nesterov momentum. 230 | """ 231 | 232 | def __init__(self, exception_vars, multiplier=0.1, lr=0.01, momentum=0., decay=0., 233 | nesterov=False, **kwargs): 234 | super(FineTuneSGD, self).__init__(**kwargs) 235 | with K.name_scope(self.__class__.__name__): 236 | self.iterations = K.variable(0, dtype='int64', name='iterations') 237 | self.lr = K.variable(lr, name='lr') 238 | self.momentum = K.variable(momentum, name='momentum') 239 | self.decay = K.variable(decay, name='decay') 240 | self.initial_decay = decay 241 | self.nesterov = nesterov 242 | self.exception_vars = exception_vars 243 | self.multiplier = multiplier 244 | 245 | @interfaces.legacy_get_updates_support 246 | def get_updates(self, loss, params): 247 | grads = self.get_gradients(loss, params) 248 | self.updates = [K.update_add(self.iterations, 1)] 249 | 250 | lr = self.lr 251 | if self.initial_decay > 0: 252 | lr = lr * (1. / (1. 
+ self.decay * K.cast(self.iterations, 253 | K.dtype(self.decay)))) 254 | # momentum 255 | shapes = [K.int_shape(p) for p in params] 256 | moments = [K.zeros(shape) for shape in shapes] 257 | self.weights = [self.iterations] + moments 258 | for p, g, m in zip(params, grads, moments): 259 | if p not in self.exception_vars: 260 | multiplied_lr = lr * self.multiplier 261 | else: 262 | multiplied_lr = lr 263 | v = self.momentum * m - multiplied_lr * g # velocity 264 | self.updates.append(K.update(m, v)) 265 | 266 | if self.nesterov: 267 | new_p = p + self.momentum * v - multiplied_lr * g 268 | else: 269 | new_p = p + v 270 | 271 | # Apply constraints. 272 | if getattr(p, 'constraint', None) is not None: 273 | new_p = p.constraint(new_p) 274 | 275 | self.updates.append(K.update(p, new_p)) 276 | return self.updates 277 | 278 | def get_config(self): 279 | config = {'lr': float(K.get_value(self.lr)), 280 | 'momentum': float(K.get_value(self.momentum)), 281 | 'decay': float(K.get_value(self.decay)), 282 | 'nesterov': self.nesterov} 283 | base_config = super(SGD, self).get_config() 284 | return dict(list(base_config.items()) + list(config.items())) 285 | 286 | 287 | class RMSprop(Optimizer): 288 | """RMSProp optimizer. 289 | 290 | It is recommended to leave the parameters of this optimizer 291 | at their default values 292 | (except the learning rate, which can be freely tuned). 293 | 294 | This optimizer is usually a good choice for recurrent 295 | neural networks. 296 | 297 | # Arguments 298 | lr: float >= 0. Learning rate. 299 | rho: float >= 0. 300 | epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 301 | decay: float >= 0. Learning rate decay over each update. 302 | 303 | # References 304 | - [rmsprop: Divide the gradient by a running average of its recent magnitude] 305 | (http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) 306 | """ 307 | 308 | def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., 309 | **kwargs): 310 | super(RMSprop, self).__init__(**kwargs) 311 | with K.name_scope(self.__class__.__name__): 312 | self.lr = K.variable(lr, name='lr') 313 | self.rho = K.variable(rho, name='rho') 314 | self.decay = K.variable(decay, name='decay') 315 | self.iterations = K.variable(0, dtype='int64', name='iterations') 316 | if epsilon is None: 317 | epsilon = K.epsilon() 318 | self.epsilon = epsilon 319 | self.initial_decay = decay 320 | 321 | @interfaces.legacy_get_updates_support 322 | def get_updates(self, loss, params): 323 | grads = self.get_gradients(loss, params) 324 | accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 325 | self.weights = accumulators 326 | self.updates = [K.update_add(self.iterations, 1)] 327 | 328 | lr = self.lr 329 | if self.initial_decay > 0: 330 | lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, 331 | K.dtype(self.decay)))) 332 | 333 | for p, g, a in zip(params, grads, accumulators): 334 | # update accumulator 335 | new_a = self.rho * a + (1. - self.rho) * K.square(g) 336 | self.updates.append(K.update(a, new_a)) 337 | new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) 338 | 339 | # Apply constraints. 
340 | if getattr(p, 'constraint', None) is not None: 341 | new_p = p.constraint(new_p) 342 | 343 | self.updates.append(K.update(p, new_p)) 344 | return self.updates 345 | 346 | def get_config(self): 347 | config = {'lr': float(K.get_value(self.lr)), 348 | 'rho': float(K.get_value(self.rho)), 349 | 'decay': float(K.get_value(self.decay)), 350 | 'epsilon': self.epsilon} 351 | base_config = super(RMSprop, self).get_config() 352 | return dict(list(base_config.items()) + list(config.items())) 353 | 354 | 355 | class Adagrad(Optimizer): 356 | """Adagrad optimizer. 357 | 358 | Adagrad is an optimizer with parameter-specific learning rates, 359 | which are adapted relative to how frequently a parameter gets 360 | updated during training. The more updates a parameter receives, 361 | the smaller the updates. 362 | 363 | It is recommended to leave the parameters of this optimizer 364 | at their default values. 365 | 366 | # Arguments 367 | lr: float >= 0. Initial learning rate. 368 | epsilon: float >= 0. If `None`, defaults to `K.epsilon()`. 369 | decay: float >= 0. Learning rate decay over each update. 370 | 371 | # References 372 | - [Adaptive Subgradient Methods for Online Learning and Stochastic 373 | Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) 374 | """ 375 | 376 | def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs): 377 | super(Adagrad, self).__init__(**kwargs) 378 | with K.name_scope(self.__class__.__name__): 379 | self.lr = K.variable(lr, name='lr') 380 | self.decay = K.variable(decay, name='decay') 381 | self.iterations = K.variable(0, dtype='int64', name='iterations') 382 | if epsilon is None: 383 | epsilon = K.epsilon() 384 | self.epsilon = epsilon 385 | self.initial_decay = decay 386 | 387 | @interfaces.legacy_get_updates_support 388 | def get_updates(self, loss, params): 389 | grads = self.get_gradients(loss, params) 390 | shapes = [K.int_shape(p) for p in params] 391 | accumulators = [K.zeros(shape) for shape in shapes] 392 | self.weights = accumulators 393 | self.updates = [K.update_add(self.iterations, 1)] 394 | 395 | lr = self.lr 396 | if self.initial_decay > 0: 397 | lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, 398 | K.dtype(self.decay)))) 399 | 400 | for p, g, a in zip(params, grads, accumulators): 401 | new_a = a + K.square(g) # update accumulator 402 | self.updates.append(K.update(a, new_a)) 403 | new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) 404 | 405 | # Apply constraints. 406 | if getattr(p, 'constraint', None) is not None: 407 | new_p = p.constraint(new_p) 408 | 409 | self.updates.append(K.update(p, new_p)) 410 | return self.updates 411 | 412 | def get_config(self): 413 | config = {'lr': float(K.get_value(self.lr)), 414 | 'decay': float(K.get_value(self.decay)), 415 | 'epsilon': self.epsilon} 416 | base_config = super(Adagrad, self).get_config() 417 | return dict(list(base_config.items()) + list(config.items())) 418 | 419 | 420 | class Adadelta(Optimizer): 421 | """Adadelta optimizer. 422 | 423 | Adadelta is a more robust extension of Adagrad 424 | that adapts learning rates based on a moving window of gradient updates, 425 | instead of accumulating all past gradients. This way, Adadelta continues 426 | learning even when many updates have been done. Compared to Adagrad, in the 427 | original version of Adadelta you don't have to set an initial learning 428 | rate. In this version, initial learning rate and decay factor can 429 | be set, as in most other Keras optimizers. 
430 | 431 | It is recommended to leave the parameters of this optimizer 432 | at their default values. 433 | 434 | # Arguments 435 | lr: float >= 0. Initial learning rate, defaults to 1. 436 | It is recommended to leave it at the default value. 437 | rho: float >= 0. Adadelta decay factor, corresponding to fraction of 438 | gradient to keep at each time step. 439 | epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 440 | decay: float >= 0. Initial learning rate decay. 441 | 442 | # References 443 | - [Adadelta - an adaptive learning rate method] 444 | (https://arxiv.org/abs/1212.5701) 445 | """ 446 | 447 | def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., 448 | **kwargs): 449 | super(Adadelta, self).__init__(**kwargs) 450 | with K.name_scope(self.__class__.__name__): 451 | self.lr = K.variable(lr, name='lr') 452 | self.decay = K.variable(decay, name='decay') 453 | self.iterations = K.variable(0, dtype='int64', name='iterations') 454 | if epsilon is None: 455 | epsilon = K.epsilon() 456 | self.rho = rho 457 | self.epsilon = epsilon 458 | self.initial_decay = decay 459 | 460 | @interfaces.legacy_get_updates_support 461 | def get_updates(self, loss, params): 462 | grads = self.get_gradients(loss, params) 463 | shapes = [K.int_shape(p) for p in params] 464 | accumulators = [K.zeros(shape) for shape in shapes] 465 | delta_accumulators = [K.zeros(shape) for shape in shapes] 466 | self.weights = accumulators + delta_accumulators 467 | self.updates = [K.update_add(self.iterations, 1)] 468 | 469 | lr = self.lr 470 | if self.initial_decay > 0: 471 | lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, 472 | K.dtype(self.decay)))) 473 | 474 | for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators): 475 | # update accumulator 476 | new_a = self.rho * a + (1. - self.rho) * K.square(g) 477 | self.updates.append(K.update(a, new_a)) 478 | 479 | # use the new accumulator and the *old* delta_accumulator 480 | update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon) 481 | new_p = p - lr * update 482 | 483 | # Apply constraints. 484 | if getattr(p, 'constraint', None) is not None: 485 | new_p = p.constraint(new_p) 486 | 487 | self.updates.append(K.update(p, new_p)) 488 | 489 | # update delta_accumulator 490 | new_d_a = self.rho * d_a + (1 - self.rho) * K.square(update) 491 | self.updates.append(K.update(d_a, new_d_a)) 492 | return self.updates 493 | 494 | def get_config(self): 495 | config = {'lr': float(K.get_value(self.lr)), 496 | 'rho': self.rho, 497 | 'decay': float(K.get_value(self.decay)), 498 | 'epsilon': self.epsilon} 499 | base_config = super(Adadelta, self).get_config() 500 | return dict(list(base_config.items()) + list(config.items())) 501 | 502 | 503 | class Adam(Optimizer): 504 | """Adam optimizer. 505 | 506 | Default parameters follow those provided in the original paper. 507 | 508 | # Arguments 509 | lr: float >= 0. Learning rate. 510 | beta_1: float, 0 < beta < 1. Generally close to 1. 511 | beta_2: float, 0 < beta < 1. Generally close to 1. 512 | epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 513 | decay: float >= 0. Learning rate decay over each update. 514 | amsgrad: boolean. Whether to apply the AMSGrad variant of this 515 | algorithm from the paper "On the Convergence of Adam and 516 | Beyond". 
517 | 518 | # References 519 | - [Adam - A Method for Stochastic Optimization] 520 | (https://arxiv.org/abs/1412.6980v8) 521 | - [On the Convergence of Adam and Beyond] 522 | (https://openreview.net/forum?id=ryQu7f-RZ) 523 | """ 524 | 525 | def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, 526 | epsilon=None, decay=0., amsgrad=False, **kwargs): 527 | super(Adam, self).__init__(**kwargs) 528 | with K.name_scope(self.__class__.__name__): 529 | self.iterations = K.variable(0, dtype='int64', name='iterations') 530 | self.lr = K.variable(lr, name='lr') 531 | self.beta_1 = K.variable(beta_1, name='beta_1') 532 | self.beta_2 = K.variable(beta_2, name='beta_2') 533 | self.decay = K.variable(decay, name='decay') 534 | if epsilon is None: 535 | epsilon = K.epsilon() 536 | self.epsilon = epsilon 537 | self.initial_decay = decay 538 | self.amsgrad = amsgrad 539 | 540 | @interfaces.legacy_get_updates_support 541 | def get_updates(self, loss, params): 542 | grads = self.get_gradients(loss, params) 543 | self.updates = [K.update_add(self.iterations, 1)] 544 | 545 | lr = self.lr 546 | if self.initial_decay > 0: 547 | lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, 548 | K.dtype(self.decay)))) 549 | 550 | t = K.cast(self.iterations, K.floatx()) + 1 551 | lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / 552 | (1. - K.pow(self.beta_1, t))) 553 | 554 | ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 555 | vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 556 | if self.amsgrad: 557 | vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 558 | else: 559 | vhats = [K.zeros(1) for _ in params] 560 | self.weights = [self.iterations] + ms + vs + vhats 561 | 562 | for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats): 563 | m_t = (self.beta_1 * m) + (1. - self.beta_1) * g 564 | v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) 565 | if self.amsgrad: 566 | vhat_t = K.maximum(vhat, v_t) 567 | p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon) 568 | self.updates.append(K.update(vhat, vhat_t)) 569 | else: 570 | p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) 571 | 572 | self.updates.append(K.update(m, m_t)) 573 | self.updates.append(K.update(v, v_t)) 574 | new_p = p_t 575 | 576 | # Apply constraints. 577 | if getattr(p, 'constraint', None) is not None: 578 | new_p = p.constraint(new_p) 579 | 580 | self.updates.append(K.update(p, new_p)) 581 | return self.updates 582 | 583 | def get_config(self): 584 | config = {'lr': float(K.get_value(self.lr)), 585 | 'beta_1': float(K.get_value(self.beta_1)), 586 | 'beta_2': float(K.get_value(self.beta_2)), 587 | 'decay': float(K.get_value(self.decay)), 588 | 'epsilon': self.epsilon, 589 | 'amsgrad': self.amsgrad} 590 | base_config = super(Adam, self).get_config() 591 | return dict(list(base_config.items()) + list(config.items())) 592 | 593 | 594 | class Adamax(Optimizer): 595 | """Adamax optimizer from Adam paper's Section 7. 596 | 597 | It is a variant of Adam based on the infinity norm. 598 | Default parameters follow those provided in the paper. 599 | 600 | # Arguments 601 | lr: float >= 0. Learning rate. 602 | beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. 603 | epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 604 | decay: float >= 0. Learning rate decay over each update. 
605 | 606 | # References 607 | - [Adam - A Method for Stochastic Optimization] 608 | (https://arxiv.org/abs/1412.6980v8) 609 | """ 610 | 611 | def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, 612 | epsilon=None, decay=0., **kwargs): 613 | super(Adamax, self).__init__(**kwargs) 614 | with K.name_scope(self.__class__.__name__): 615 | self.iterations = K.variable(0, dtype='int64', name='iterations') 616 | self.lr = K.variable(lr, name='lr') 617 | self.beta_1 = K.variable(beta_1, name='beta_1') 618 | self.beta_2 = K.variable(beta_2, name='beta_2') 619 | self.decay = K.variable(decay, name='decay') 620 | if epsilon is None: 621 | epsilon = K.epsilon() 622 | self.epsilon = epsilon 623 | self.initial_decay = decay 624 | 625 | @interfaces.legacy_get_updates_support 626 | def get_updates(self, loss, params): 627 | grads = self.get_gradients(loss, params) 628 | self.updates = [K.update_add(self.iterations, 1)] 629 | 630 | lr = self.lr 631 | if self.initial_decay > 0: 632 | lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, 633 | K.dtype(self.decay)))) 634 | 635 | t = K.cast(self.iterations, K.floatx()) + 1 636 | lr_t = lr / (1. - K.pow(self.beta_1, t)) 637 | 638 | shapes = [K.int_shape(p) for p in params] 639 | # zero init of 1st moment 640 | ms = [K.zeros(shape) for shape in shapes] 641 | # zero init of exponentially weighted infinity norm 642 | us = [K.zeros(shape) for shape in shapes] 643 | self.weights = [self.iterations] + ms + us 644 | 645 | for p, g, m, u in zip(params, grads, ms, us): 646 | 647 | m_t = (self.beta_1 * m) + (1. - self.beta_1) * g 648 | u_t = K.maximum(self.beta_2 * u, K.abs(g)) 649 | p_t = p - lr_t * m_t / (u_t + self.epsilon) 650 | 651 | self.updates.append(K.update(m, m_t)) 652 | self.updates.append(K.update(u, u_t)) 653 | new_p = p_t 654 | 655 | # Apply constraints. 656 | if getattr(p, 'constraint', None) is not None: 657 | new_p = p.constraint(new_p) 658 | 659 | self.updates.append(K.update(p, new_p)) 660 | return self.updates 661 | 662 | def get_config(self): 663 | config = {'lr': float(K.get_value(self.lr)), 664 | 'beta_1': float(K.get_value(self.beta_1)), 665 | 'beta_2': float(K.get_value(self.beta_2)), 666 | 'decay': float(K.get_value(self.decay)), 667 | 'epsilon': self.epsilon} 668 | base_config = super(Adamax, self).get_config() 669 | return dict(list(base_config.items()) + list(config.items())) 670 | 671 | 672 | class Nadam(Optimizer): 673 | """Nesterov Adam optimizer. 674 | 675 | Much like Adam is essentially RMSprop with momentum, 676 | Nadam is Adam RMSprop with Nesterov momentum. 677 | 678 | Default parameters follow those provided in the paper. 679 | It is recommended to leave the parameters of this optimizer 680 | at their default values. 681 | 682 | # Arguments 683 | lr: float >= 0. Learning rate. 684 | beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. 685 | epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 
686 | 687 | # References 688 | - [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf) 689 | - [On the importance of initialization and momentum in deep learning] 690 | (http://www.cs.toronto.edu/~fritz/absps/momentum.pdf) 691 | """ 692 | 693 | def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, 694 | epsilon=None, schedule_decay=0.004, **kwargs): 695 | super(Nadam, self).__init__(**kwargs) 696 | with K.name_scope(self.__class__.__name__): 697 | self.iterations = K.variable(0, dtype='int64', name='iterations') 698 | self.m_schedule = K.variable(1., name='m_schedule') 699 | self.lr = K.variable(lr, name='lr') 700 | self.beta_1 = K.variable(beta_1, name='beta_1') 701 | self.beta_2 = K.variable(beta_2, name='beta_2') 702 | if epsilon is None: 703 | epsilon = K.epsilon() 704 | self.epsilon = epsilon 705 | self.schedule_decay = schedule_decay 706 | 707 | @interfaces.legacy_get_updates_support 708 | def get_updates(self, loss, params): 709 | grads = self.get_gradients(loss, params) 710 | self.updates = [K.update_add(self.iterations, 1)] 711 | 712 | t = K.cast(self.iterations, K.floatx()) + 1 713 | 714 | # Due to the recommendations in [2], i.e. warming momentum schedule 715 | momentum_cache_t = self.beta_1 * (1. - 0.5 * ( 716 | K.pow(K.cast_to_floatx(0.96), t * self.schedule_decay))) 717 | momentum_cache_t_1 = self.beta_1 * (1. - 0.5 * ( 718 | K.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay))) 719 | m_schedule_new = self.m_schedule * momentum_cache_t 720 | m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1 721 | self.updates.append((self.m_schedule, m_schedule_new)) 722 | 723 | shapes = [K.int_shape(p) for p in params] 724 | ms = [K.zeros(shape) for shape in shapes] 725 | vs = [K.zeros(shape) for shape in shapes] 726 | 727 | self.weights = [self.iterations] + ms + vs 728 | 729 | for p, g, m, v in zip(params, grads, ms, vs): 730 | # the following equations given in [1] 731 | g_prime = g / (1. - m_schedule_new) 732 | m_t = self.beta_1 * m + (1. - self.beta_1) * g 733 | m_t_prime = m_t / (1. - m_schedule_next) 734 | v_t = self.beta_2 * v + (1. - self.beta_2) * K.square(g) 735 | v_t_prime = v_t / (1. - K.pow(self.beta_2, t)) 736 | m_t_bar = (1. - momentum_cache_t) * g_prime + ( 737 | momentum_cache_t_1 * m_t_prime) 738 | 739 | self.updates.append(K.update(m, m_t)) 740 | self.updates.append(K.update(v, v_t)) 741 | 742 | p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon) 743 | new_p = p_t 744 | 745 | # Apply constraints. 746 | if getattr(p, 'constraint', None) is not None: 747 | new_p = p.constraint(new_p) 748 | 749 | self.updates.append(K.update(p, new_p)) 750 | return self.updates 751 | 752 | def get_config(self): 753 | config = {'lr': float(K.get_value(self.lr)), 754 | 'beta_1': float(K.get_value(self.beta_1)), 755 | 'beta_2': float(K.get_value(self.beta_2)), 756 | 'epsilon': self.epsilon, 757 | 'schedule_decay': self.schedule_decay} 758 | base_config = super(Nadam, self).get_config() 759 | return dict(list(base_config.items()) + list(config.items())) 760 | 761 | 762 | class TFOptimizer(Optimizer): 763 | """Wrapper class for native TensorFlow optimizers. 
764 | """ 765 | 766 | def __init__(self, optimizer): 767 | self.optimizer = optimizer 768 | with K.name_scope(self.__class__.__name__): 769 | self.iterations = K.variable(0, dtype='int64', name='iterations') 770 | 771 | @interfaces.legacy_get_updates_support 772 | def get_updates(self, loss, params): 773 | grads = self.optimizer.compute_gradients(loss, params) 774 | self.updates = [K.update_add(self.iterations, 1)] 775 | opt_update = self.optimizer.apply_gradients( 776 | grads, global_step=self.iterations) 777 | self.updates.append(opt_update) 778 | return self.updates 779 | 780 | @property 781 | def weights(self): 782 | raise NotImplementedError 783 | 784 | def get_config(self): 785 | raise NotImplementedError 786 | 787 | def from_config(self, config): 788 | raise NotImplementedError 789 | 790 | 791 | # Aliases. 792 | 793 | sgd = SGD 794 | rmsprop = RMSprop 795 | adagrad = Adagrad 796 | adadelta = Adadelta 797 | adam = Adam 798 | adamax = Adamax 799 | nadam = Nadam 800 | 801 | 802 | def serialize(optimizer): 803 | return serialize_keras_object(optimizer) 804 | 805 | 806 | def deserialize(config, custom_objects=None): 807 | """Inverse of the `serialize` function. 808 | 809 | # Arguments 810 | config: Optimizer configuration dictionary. 811 | custom_objects: Optional dictionary mapping 812 | names (strings) to custom objects 813 | (classes and functions) 814 | to be considered during deserialization. 815 | 816 | # Returns 817 | A Keras Optimizer instance. 818 | """ 819 | all_classes = { 820 | 'sgd': SGD, 821 | 'rmsprop': RMSprop, 822 | 'adagrad': Adagrad, 823 | 'adadelta': Adadelta, 824 | 'adam': Adam, 825 | 'adamax': Adamax, 826 | 'nadam': Nadam, 827 | 'tfoptimizer': TFOptimizer, 828 | } 829 | # Make deserialization case-insensitive for built-in optimizers. 830 | if config['class_name'].lower() in all_classes: 831 | config['class_name'] = config['class_name'].lower() 832 | return deserialize_keras_object(config, 833 | module_objects=all_classes, 834 | custom_objects=custom_objects, 835 | printable_module_name='optimizer') 836 | 837 | 838 | def get(identifier): 839 | """Retrieves a Keras Optimizer instance. 840 | 841 | # Arguments 842 | identifier: Optimizer identifier, one of 843 | - String: name of an optimizer 844 | - Dictionary: configuration dictionary. 845 | - Keras Optimizer instance (it will be returned unchanged). 846 | - TensorFlow Optimizer instance 847 | (it will be wrapped as a Keras Optimizer). 848 | 849 | # Returns 850 | A Keras Optimizer instance. 851 | 852 | # Raises 853 | ValueError: If `identifier` cannot be interpreted. 854 | """ 855 | if K.backend() == 'tensorflow': 856 | # Wrap TF optimizer instances 857 | if isinstance(identifier, tf.train.Optimizer): 858 | return TFOptimizer(identifier) 859 | if isinstance(identifier, dict): 860 | return deserialize(identifier) 861 | elif isinstance(identifier, six.string_types): 862 | config = {'class_name': str(identifier), 'config': {}} 863 | return deserialize(config) 864 | if isinstance(identifier, Optimizer): 865 | return identifier 866 | else: 867 | raise ValueError('Could not interpret optimizer identifier: ' + 868 | str(identifier)) 869 | --------------------------------------------------------------------------------