├── DeepCCA.py
├── LICENSE
├── README.md
├── linear_cca.py
├── models.py
├── objectives.py
└── utils.py

/DeepCCA.py:
--------------------------------------------------------------------------------
try:
    import cPickle as thepickle
except ImportError:
    import _pickle as thepickle

import gzip
import numpy as np

from keras.callbacks import ModelCheckpoint
from utils import load_data, svm_classify
from linear_cca import linear_cca
from models import create_model


def train_model(model, data1, data2, epoch_num, batch_size):
    """
    Trains the model.
    # Arguments
        data1 and data2: the train, validation, and test data for view 1 and view 2
            respectively. Each one should be packed like
            ((X for train, Y for train), (X for validation, Y for validation), (X for test, Y for test))
        epoch_num: number of epochs to train the model
        batch_size: the size of the mini-batches
    # Returns
        the trained model
    """

    # Unpack the data
    train_set_x1, train_set_y1 = data1[0]
    valid_set_x1, valid_set_y1 = data1[1]
    test_set_x1, test_set_y1 = data1[2]

    train_set_x2, train_set_y2 = data2[0]
    valid_set_x2, valid_set_y2 = data2[1]
    test_set_x2, test_set_y2 = data2[2]

    # The best weights are saved to "temp_weights.h5" during training so that
    # the model with the lowest validation loss can be restored afterwards.
    checkpointer = ModelCheckpoint(filepath="temp_weights.h5", verbose=1, save_best_only=True, save_weights_only=True)

    # A dummy Y is used because the labels are not used by the loss function.
    model.fit([train_set_x1, train_set_x2], np.zeros(len(train_set_x1)),
              batch_size=batch_size, epochs=epoch_num, shuffle=True,
              validation_data=([valid_set_x1, valid_set_x2], np.zeros(len(valid_set_x1))),
              callbacks=[checkpointer])

    model.load_weights("temp_weights.h5")

    results = model.evaluate([test_set_x1, test_set_x2], np.zeros(len(test_set_x1)), batch_size=batch_size, verbose=1)
    print('loss on test data: ', results)

    results = model.evaluate([valid_set_x1, valid_set_x2], np.zeros(len(valid_set_x1)), batch_size=batch_size, verbose=1)
    print('loss on validation data: ', results)
    return model

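# Illustrative sketch (comments only, not executed): the packing that
# train_model and test_model expect for each view. The shapes below are an
# assumption based on the noisy MNIST setup used in this script:
#     data1 = [(train_x1, train_y1),   # e.g. train_x1: (50000, 784), train_y1: (50000,)
#              (valid_x1, valid_y1),
#              (test_x1, test_y1)]
#     data2 packs view 2 the same way; load_data in utils.py returns
#     exactly this structure.
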
def test_model(model, data1, data2, outdim_size, apply_linear_cca):
    """
    Produces the new features by using the trained model.
    # Arguments
        model: the trained model
        data1 and data2: the train, validation, and test data for view 1 and view 2
            respectively. Each one should be packed like
            ((X for train, Y for train), (X for validation, Y for validation), (X for test, Y for test))
        outdim_size: dimension of the new features
        apply_linear_cca: whether to apply linear CCA on the new features
    # Returns
        the new features packed like
        ((new X for train - view 1, new X for train - view 2, Y for train),
         (new X for validation - view 1, new X for validation - view 2, Y for validation),
         (new X for test - view 1, new X for test - view 2, Y for test))
    """

    # Produce the new features: the model outputs both views concatenated,
    # so the prediction is split in half to recover each view.
    new_data = []
    for k in range(3):
        pred_out = model.predict([data1[k][0], data2[k][0]])
        r = int(pred_out.shape[1] / 2)
        new_data.append([pred_out[:, :r], pred_out[:, r:], data1[k][1]])

    # Based on the DCCA paper, a linear CCA should be applied to the outputs of the
    # networks, because the loss function estimates the correlation obtained when a
    # linear CCA is applied to those outputs. In practice, however, it does not
    # improve the performance significantly.
    if apply_linear_cca:
        w = [None, None]
        m = [None, None]
        print("Linear CCA started!")
        w[0], w[1], m[0], m[1] = linear_cca(new_data[0][0], new_data[0][1], outdim_size)
        print("Linear CCA ended!")

        # A sign flip done in the original MATLAB implementation of DCCA; the exact
        # reason is unclear, and it did not affect the performance significantly on
        # the noisy MNIST dataset.
        # s = np.sign(w[0][0, :])
        # s = s.reshape([1, -1]).repeat(w[0].shape[0], axis=0)
        # w[0] = w[0] * s
        # w[1] = w[1] * s

        for k in range(3):
            data_num = len(new_data[k][0])
            for v in range(2):
                new_data[k][v] -= m[v].reshape([1, -1]).repeat(data_num, axis=0)
                new_data[k][v] = np.dot(new_data[k][v], w[v])

    return new_data

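# Illustrative sketch (comments only, not executed): the structure returned by
# test_model, assuming outdim_size = 10 and the noisy MNIST training split:
#     new_data[0] -> [train_x1_new, train_x2_new, train_y]
#                    # e.g. (50000, 10), (50000, 10), (50000,)
#     new_data[1] and new_data[2] hold the validation and test triples in the
#     same form.
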
if __name__ == '__main__':
    ############
    # Parameters Section

    # the path for saving the final learned features
    save_to = './new_features.gz'

    # the size of the new space learned by the model (number of new features)
    outdim_size = 10

    # the size of the input for view 1 and view 2
    input_shape1 = 784
    input_shape2 = 784

    # the number of layers and the number of nodes in each one
    layer_sizes1 = [1024, 1024, 1024, outdim_size]
    layer_sizes2 = [1024, 1024, 1024, outdim_size]

    # the parameters for training the network
    learning_rate = 1e-3
    epoch_num = 100
    batch_size = 800

    # the regularization parameter of the network; it seems necessary to avoid
    # exploding gradients, especially when non-saturating activations are used
    reg_par = 1e-5

    # specifies whether all the singular values should be used to calculate the
    # correlation, or just the top outdim_size ones; if one option does not work
    # for a network or dataset, try the other one
    use_all_singular_values = False

    # whether a linear CCA should be applied to the learned features extracted from
    # the networks; it does not affect the performance on noisy MNIST significantly
    apply_linear_cca = True

    # end of Parameters Section
    ############

    # Each view is stored separately in a gzip file. They are downloaded the first
    # time the code is executed and stored under the datasets folder of the user's
    # Keras folder, normally [Home Folder]/.keras/datasets/
    data1 = load_data('noisymnist_view1.gz', 'https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view1.gz')
    data2 = load_data('noisymnist_view2.gz', 'https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view2.gz')

    # Building, training, and producing the new features by DCCA
    model = create_model(layer_sizes1, layer_sizes2, input_shape1, input_shape2,
                         learning_rate, reg_par, outdim_size, use_all_singular_values)
    model.summary()
    model = train_model(model, data1, data2, epoch_num, batch_size)
    new_data = test_model(model, data1, data2, outdim_size, apply_linear_cca)

    # Training and testing an SVM with a linear kernel on view 1 with the new features
    [test_acc, valid_acc] = svm_classify(new_data, C=0.01)
    print("Accuracy on view 1 (validation data) is:", valid_acc * 100.0)
    print("Accuracy on view 1 (test data) is:", test_acc * 100.0)

    # Saving the new features in a gzip pickled file specified by save_to
    print('saving new features ...')
    f1 = gzip.open(save_to, 'wb')
    thepickle.dump(new_data, f1)
    f1.close()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2016 Vahid Noroozi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DCCA: Deep Canonical Correlation Analysis

This is an implementation of Deep Canonical Correlation Analysis (DCCA or Deep CCA) in Python. It requires the Theano and Keras libraries to be installed.

DCCA is a non-linear version of CCA that uses neural networks as the mapping functions instead of linear transformations. DCCA was originally proposed in the following paper:

Galen Andrew, Raman Arora, Jeff Bilmes, Karen Livescu, "[Deep Canonical Correlation Analysis.](http://www.jmlr.org/proceedings/papers/v28/andrew13.pdf)", ICML, 2013.

It uses the Keras library with the Theano backend and does not work with the TensorFlow backend, because the loss function of the network is written with Theano tensor operations. The base modeling network can easily be substituted with a more efficient and powerful network, such as a CNN.
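The whole pipeline lives in `DeepCCA.py`; a condensed sketch of it (mirroring the default parameter values in that script):

```python
from models import create_model
from DeepCCA import train_model, test_model
from utils import load_data, svm_classify

data1 = load_data('noisymnist_view1.gz', 'https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view1.gz')
data2 = load_data('noisymnist_view2.gz', 'https://www2.cs.uic.edu/~vnoroozi/noisy-mnist/noisymnist_view2.gz')

model = create_model([1024, 1024, 1024, 10], [1024, 1024, 1024, 10], 784, 784,
                     learning_rate=1e-3, reg_par=1e-5, outdim_size=10,
                     use_all_singular_values=False)
model = train_model(model, data1, data2, epoch_num=100, batch_size=800)
new_data = test_model(model, data1, data2, outdim_size=10, apply_linear_cca=True)
test_acc, valid_acc = svm_classify(new_data, C=0.01)
```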
Most of the configuration and parameters are set based on the following paper:

Weiran Wang, Raman Arora, Karen Livescu, and Jeff Bilmes. "[On Deep Multi-View Representation Learning.](http://proceedings.mlr.press/v37/wangb15.pdf)", ICML, 2015.

### Dataset
The model is evaluated on a noisy version of the MNIST dataset. I built the dataset exactly as it is introduced in the paper. The train/validation/test split is the original split of MNIST.

The dataset is too large to upload to GitHub, so it is hosted on another server. The first time the code is executed, the dataset is downloaded automatically and saved under the datasets folder of the user's Keras folder (normally [Home Folder]/.keras/datasets/).

### Differences from the original paper
The differences between my implementation and the original paper are small:

* I used RMSProp (an adaptive version of gradient descent) instead of GD with momentum, as it converged much faster.
* Instead of a non-saturating version of the sigmoid, I used the standard sigmoid as the activation function, as the MATLAB implementation does. This should not affect the performance significantly; if needed, it can be replaced with another non-saturating activation function such as ReLU.
* Pre-training is not done in this implementation. However, it is not clear how useful it would be.

### Other Implementations
The following implementations of DCCA in MATLAB and C++, written by the authors of the original paper, were consulted for this implementation:

* [C++ implementation](https://homes.cs.washington.edu/~galen/files/dcca.tgz) from Galen Andrew's website (https://homes.cs.washington.edu/~galen/)

* [MATLAB implementation](http://ttic.uchicago.edu/~wwang5/papers/dccae.tgz) from Weiran Wang's website (http://ttic.uchicago.edu/~wwang5/dccae.html)
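### Loading the saved features
To reload the features that `DeepCCA.py` saves to `new_features.gz`, a minimal sketch:

```python
import gzip

try:
    import cPickle as thepickle  # Python 2
except ImportError:
    import _pickle as thepickle  # Python 3

with gzip.open('./new_features.gz', 'rb') as f:
    new_data = thepickle.load(f)

# new_data[k] is [view-1 features, view-2 features, labels]
# for k = 0 (train), 1 (validation), 2 (test)
train_x1, train_x2, train_y = new_data[0]
```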
--------------------------------------------------------------------------------
/linear_cca.py:
--------------------------------------------------------------------------------
import numpy


def linear_cca(H1, H2, outdim_size):
    """
    An implementation of linear CCA.
    # Arguments
        H1 and H2: the matrices containing the data for view 1 and view 2. Each row is a sample.
        outdim_size: specifies the number of new features
    # Returns
        A and B: the linear transformation matrices
        mean1 and mean2: the means of the data for both views
    """
    r1 = 1e-4
    r2 = 1e-4

    m = H1.shape[0]
    o = H1.shape[1]

    mean1 = numpy.mean(H1, axis=0)
    mean2 = numpy.mean(H2, axis=0)
    H1bar = H1 - numpy.tile(mean1, (m, 1))
    H2bar = H2 - numpy.tile(mean2, (m, 1))

    SigmaHat12 = (1.0 / (m - 1)) * numpy.dot(H1bar.T, H2bar)
    SigmaHat11 = (1.0 / (m - 1)) * numpy.dot(H1bar.T, H1bar) + r1 * numpy.identity(o)
    SigmaHat22 = (1.0 / (m - 1)) * numpy.dot(H2bar.T, H2bar) + r2 * numpy.identity(o)

    # root inverses of the covariance matrices via eigendecomposition
    [D1, V1] = numpy.linalg.eigh(SigmaHat11)
    [D2, V2] = numpy.linalg.eigh(SigmaHat22)
    SigmaHat11RootInv = numpy.dot(numpy.dot(V1, numpy.diag(D1 ** -0.5)), V1.T)
    SigmaHat22RootInv = numpy.dot(numpy.dot(V2, numpy.diag(D2 ** -0.5)), V2.T)

    Tval = numpy.dot(numpy.dot(SigmaHat11RootInv, SigmaHat12), SigmaHat22RootInv)

    [U, D, V] = numpy.linalg.svd(Tval)
    V = V.T
    A = numpy.dot(SigmaHat11RootInv, U[:, 0:outdim_size])
    B = numpy.dot(SigmaHat22RootInv, V[:, 0:outdim_size])
    D = D[0:outdim_size]  # the canonical correlations (not returned)

    return A, B, mean1, mean2
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
from keras.layers import Dense, Input, concatenate
from keras.models import Model, Sequential
from keras.optimizers import RMSprop
from keras.regularizers import l2
from objectives import cca_loss


def create_model(layer_sizes1, layer_sizes2, input_size1, input_size2,
                 learning_rate, reg_par, outdim_size, use_all_singular_values):
    """
    Builds the whole model.
    The structure of each sub-network is defined in build_mlp_net,
    and it can easily be substituted with a more efficient and powerful network such as a CNN.
    """
    view1_model = build_mlp_net(layer_sizes1, input_size1, reg_par)
    view2_model = build_mlp_net(layer_sizes2, input_size2, reg_par)

    # concatenate the outputs of the two sub-networks into a single output
    # so that the CCA loss sees both views at once
    input1 = Input(shape=(input_size1,))
    input2 = Input(shape=(input_size2,))
    output = concatenate([view1_model(input1), view2_model(input2)])
    model = Model(inputs=[input1, input2], outputs=output)

    model_optimizer = RMSprop(lr=learning_rate)
    model.compile(loss=cca_loss(outdim_size, use_all_singular_values), optimizer=model_optimizer)

    return model


def build_mlp_net(layer_sizes, input_size, reg_par):
    model = Sequential()
    for l_id, ls in enumerate(layer_sizes):
        # the last layer is linear; all hidden layers use sigmoid activations
        if l_id == len(layer_sizes) - 1:
            activation = 'linear'
        else:
            activation = 'sigmoid'

        if l_id == 0:
            # only the first layer needs an explicit input dimension
            model.add(Dense(ls, input_dim=input_size,
                            activation=activation,
                            kernel_regularizer=l2(reg_par)))
        else:
            model.add(Dense(ls, activation=activation,
                            kernel_regularizer=l2(reg_par)))
    return model
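
# A minimal sketch of swapping in a CNN sub-network, as mentioned in the
# docstring of create_model. This is illustrative only: it assumes the flat
# 784-dim input is a 28x28 grayscale image (as in the noisy MNIST setup),
# assumes the Keras image_data_format is 'channels_last', and uses arbitrary
# (untuned) filter counts.
def build_cnn_net(input_size, outdim_size, reg_par, img_shape=(28, 28, 1)):
    from keras.layers import Conv2D, Flatten, MaxPooling2D, Reshape
    model = Sequential()
    # un-flatten the input vector into an image
    model.add(Reshape(img_shape, input_shape=(input_size,)))
    model.add(Conv2D(32, (3, 3), activation='relu', kernel_regularizer=l2(reg_par)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu', kernel_regularizer=l2(reg_par)))
    model.add(Flatten())
    # keep the final layer linear, matching build_mlp_net
    model.add(Dense(outdim_size, activation='linear', kernel_regularizer=l2(reg_par)))
    return model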
--------------------------------------------------------------------------------
/objectives.py:
--------------------------------------------------------------------------------
import theano.tensor as T


def cca_loss(outdim_size, use_all_singular_values):
    """
    The main loss function (inner_cca_objective) is wrapped in this function due to
    the constraints imposed by Keras on objective functions.
    """
    def inner_cca_objective(y_true, y_pred):
        """
        The CCA loss as introduced in the original paper; other formulations exist.
        It is implemented with Theano tensor operations and does not work on the TensorFlow backend.
        y_true is ignored.
        """
        r1 = 1e-4
        r2 = 1e-4
        eps = 1e-12
        o1 = o2 = y_pred.shape[1] // 2

        # unpack (separate) the output of the networks for view 1 and view 2
        H1 = y_pred[:, 0:o1].T
        H2 = y_pred[:, o1:o1 + o2].T

        m = H1.shape[1]

        H1bar = H1 - (1.0 / m) * T.dot(H1, T.ones([m, m]))
        H2bar = H2 - (1.0 / m) * T.dot(H2, T.ones([m, m]))

        SigmaHat12 = (1.0 / (m - 1)) * T.dot(H1bar, H2bar.T)
        SigmaHat11 = (1.0 / (m - 1)) * T.dot(H1bar, H1bar.T) + r1 * T.eye(o1)
        SigmaHat22 = (1.0 / (m - 1)) * T.dot(H2bar, H2bar.T) + r2 * T.eye(o2)

        # calculating the root inverse of the covariance matrices by using eigendecomposition
        [D1, V1] = T.nlinalg.eigh(SigmaHat11)
        [D2, V2] = T.nlinalg.eigh(SigmaHat22)

        # added to increase stability: keep only directions with eigenvalues above eps
        posInd1 = T.gt(D1, eps).nonzero()[0]
        D1 = D1[posInd1]
        V1 = V1[:, posInd1]
        posInd2 = T.gt(D2, eps).nonzero()[0]
        D2 = D2[posInd2]
        V2 = V2[:, posInd2]

        SigmaHat11RootInv = T.dot(T.dot(V1, T.nlinalg.diag(D1 ** -0.5)), V1.T)
        SigmaHat22RootInv = T.dot(T.dot(V2, T.nlinalg.diag(D2 ** -0.5)), V2.T)

        Tval = T.dot(T.dot(SigmaHat11RootInv, SigmaHat12), SigmaHat22RootInv)

        if use_all_singular_values:
            # all singular values are used to calculate the correlation
            corr = T.sqrt(T.nlinalg.trace(T.dot(Tval.T, Tval)))
        else:
            # just the top outdim_size singular values are used; the eigenvalues of
            # Tval.T Tval are the squared singular values of Tval, and sort() orders
            # them ascending, so the top ones sit at the end of the array
            [U, V] = T.nlinalg.eigh(T.dot(Tval.T, Tval))
            U = U[T.gt(U, eps).nonzero()[0]]
            U = U.sort()
            corr = T.sum(T.sqrt(U[-outdim_size:]))

        return -corr

    return inner_cca_objective
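
# A numpy cross-check of the same objective (a debugging sketch, not used by
# Keras). Given two small arrays with samples in rows, it should roughly match
# -inner_cca_objective evaluated on their concatenation, up to the eps filtering.
def cca_loss_numpy(H1, H2, outdim_size, r=1e-4):
    import numpy as np
    m, o = H1.shape
    H1bar = H1 - H1.mean(axis=0)
    H2bar = H2 - H2.mean(axis=0)
    SigmaHat12 = H1bar.T.dot(H2bar) / (m - 1)
    SigmaHat11 = H1bar.T.dot(H1bar) / (m - 1) + r * np.identity(o)
    SigmaHat22 = H2bar.T.dot(H2bar) / (m - 1) + r * np.identity(H2.shape[1])
    D1, V1 = np.linalg.eigh(SigmaHat11)
    D2, V2 = np.linalg.eigh(SigmaHat22)
    SigmaHat11RootInv = V1.dot(np.diag(D1 ** -0.5)).dot(V1.T)
    SigmaHat22RootInv = V2.dot(np.diag(D2 ** -0.5)).dot(V2.T)
    Tval = SigmaHat11RootInv.dot(SigmaHat12).dot(SigmaHat22RootInv)
    singular_values = np.linalg.svd(Tval, compute_uv=False)  # sorted descending
    return -np.sum(singular_values[:outdim_size])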
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import gzip
from sklearn import svm
from sklearn.metrics import accuracy_score
import numpy as np
import theano
from keras.utils.data_utils import get_file


def load_data(data_file, url):
    """Loads the data from the gzip pickled files and converts it to numpy arrays."""
    print('loading data ...')
    path = get_file(data_file, origin=url)
    f = gzip.open(path, 'rb')
    train_set, valid_set, test_set = load_pickle(f)
    f.close()

    train_set_x, train_set_y = make_numpy_array(train_set)
    valid_set_x, valid_set_y = make_numpy_array(valid_set)
    test_set_x, test_set_y = make_numpy_array(test_set)

    return [(train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)]


def make_numpy_array(data_xy):
    """Converts the input to numpy arrays."""
    data_x, data_y = data_xy
    data_x = np.asarray(data_x, dtype=theano.config.floatX)
    data_y = np.asarray(data_y, dtype='int32')
    return data_x, data_y


def svm_classify(data, C):
    """
    Trains a linear SVM on the view-1 features.
    The input C specifies the penalty factor of the SVM.
    """
    train_data, _, train_label = data[0]
    valid_data, _, valid_label = data[1]
    test_data, _, test_label = data[2]

    print('training SVM...')
    clf = svm.LinearSVC(C=C, dual=False)
    clf.fit(train_data, train_label.ravel())

    p = clf.predict(test_data)
    test_acc = accuracy_score(test_label, p)
    p = clf.predict(valid_data)
    valid_acc = accuracy_score(valid_label, p)

    return [test_acc, valid_acc]


def load_pickle(f):
    """
    Loads and returns the content of a pickled file.
    It handles the inconsistencies between the pickle packages available in Python 2 and 3.
    """
    try:
        import cPickle as thepickle
    except ImportError:
        import _pickle as thepickle

    try:
        ret = thepickle.load(f, encoding='latin1')
    except TypeError:
        ret = thepickle.load(f)

    return ret
--------------------------------------------------------------------------------