├── save └── .keep ├── .gitignore ├── img ├── cca_plot.png └── gcca_plot.png ├── gcca ├── __init__.py ├── bridged_cca.py ├── cca.py ├── hierarchical_cca.py └── gcca.py ├── setup.py └── README.md /save/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .idea 3 | -------------------------------------------------------------------------------- /img/cca_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rupy/GCCA/HEAD/img/cca_plot.png -------------------------------------------------------------------------------- /img/gcca_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rupy/GCCA/HEAD/img/gcca_plot.png -------------------------------------------------------------------------------- /gcca/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'rupy' 2 | 3 | from cca import CCA 4 | from gcca import GCCA 5 | from bridged_cca import BridgedCCA 6 | from hierarchical_cca import HierarchicalCCA -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='gcca', 5 | description='Generalized Canonical Correlation Analysis library', 6 | author='rupy', 7 | author_email='rupyapps@gmail.com', 8 | license='MIT License', 9 | version='0.1dev', 10 | url='https://github.com/rupy/GCCA', 11 | packages=find_packages(), 12 | install_requires=['numpy', 'scipy', 'matplotlib', 'h5py'], 13 | zip_safe=False, 14 | classifiers=[ 15 | 'Development Status :: 3 - Alpha', 16 | 'Intended Audience :: Science/Research', 17 | 'Intended Audience :: Developers', 18 | 'License :: OSI Approved :: MIT License', 19 | 'Programming Language :: Python', 20 | 'Programming Language :: Python :: 2.7', 21 | 'Topic :: Software Development', 22 | 'Topic :: Scientific/Engineering', 23 | 'Operating System :: Microsoft :: Windows', 24 | 'Operating System :: POSIX', 25 | 'Operating System :: Unix', 26 | 'Operating System :: MacOS', 27 | ], 28 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GCCA 2 | 3 | This repository is implementation of Generalized Canonical Correlation Analysis(GCCA). 4 | CCA can use only 2 data but GCCA can use more than 2 data. 5 | 6 | ## CCA 7 | 8 | CCA is the method to transform 2 data to one joint space. See example graph: 9 | 10 | ![CCA Plot Result](https://github.com/rupy/GCCA/blob/master/img/cca_plot.png) 11 | 12 | CCA inplementation contains PCCA (Probablistic Canonical Correlation Analysis) transformation that is assumed that there is latent space in 2 data. 13 | 14 | ## GCCA 15 | 16 | GCCA is the method to transform multiple data to one joint space. See example graph: 17 | 18 | ![GCCA Plot Result](https://github.com/rupy/GCCA/blob/master/img/gcca_plot.png) 19 | 20 | You can give GCCA any number of data. 21 | 22 | ## Installation 23 | 24 | You can use 'git clone' command to install 25 | 26 | ## Dependencies 27 | 28 | You have to install python dependent libraries in advance as follow: 29 | 30 | ``` 31 | numpy==1.9.1 32 | scipy==0.14.1 33 | matplotlib==1.4.2 34 | h5py==2.4.0 35 | ``` 36 | 37 | ## Usage of CCA 38 | 39 | ```python 40 | from cca import CCA 41 | import logging 42 | import numpy as np 43 | 44 | # set log level 45 | logging.root.setLevel(level=logging.INFO) 46 | 47 | # create data in advance 48 | a = np.random.rand(50, 50) 49 | b = np.random.rand(50, 60) 50 | 51 | # create instance of CCA 52 | cca = CCA() 53 | # calculate CCA 54 | cca.fit(a, b) 55 | # transform 56 | cca.transform(a, b) 57 | # transform by PCCA 58 | cca.ptransform(a, b) 59 | # save 60 | cca.save_params("save/cca.h5") 61 | # load 62 | cca.load_params("save/cca.h5") 63 | # plot 64 | cca.plot_pcca_result() 65 | ``` 66 | 67 | 68 | ## Usage of GCCA 69 | 70 | ```python 71 | from gcca import GCCA 72 | import logging 73 | import numpy as np 74 | 75 | # set log level 76 | logging.root.setLevel(level=logging.INFO) 77 | 78 | # create data in advance 79 | a = np.random.rand(50, 50) 80 | b = np.random.rand(50, 60) 81 | c = np.random.rand(50, 70) 82 | d = np.random.rand(50, 80) 83 | e = np.random.rand(50, 90) 84 | f = np.random.rand(50, 100) 85 | g = np.random.rand(50, 110) 86 | h = np.random.rand(50, 120) 87 | i = np.random.rand(50, 130) 88 | j = np.random.rand(50, 140) 89 | k = np.random.rand(50, 150) 90 | 91 | # create instance of GCCA 92 | gcca = GCCA() 93 | # calculate GCCA 94 | gcca.fit(a, b, c, d, e, f, g, h, i, j, k) 95 | # transform 96 | gcca.transform(a, b, c, d, e, f, g, h, i, j, k) 97 | # save 98 | gcca.save_params("save/gcca.h5") 99 | # load 100 | gcca.load_params("save/gcca.h5") 101 | # plot 102 | gcca.plot_gcca_result() 103 | ``` 104 | 105 | That's it! 106 | -------------------------------------------------------------------------------- /gcca/bridged_cca.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | __author__ = 'rupy' 5 | 6 | from gcca import GCCA 7 | import numpy as np 8 | import logging 9 | from sklearn.datasets import load_digits 10 | 11 | class BridgedCCA(GCCA): 12 | 13 | def __init__(self, n_components=2, reg_param=0.1): 14 | GCCA.__init__(self, n_components, reg_param) 15 | 16 | def fit(self, x0_pair0, x1_pair0, x1_pair1, x2_pair1): 17 | 18 | p0_list = [x0_pair0, x1_pair0] 19 | p1_list = [x1_pair1, x2_pair1] 20 | data_num = 3 21 | 22 | # data size check 23 | p0_num = len(p0_list) 24 | p1_num = len(p1_list) 25 | self.logger.info("pair0 data num is %d", p0_num) 26 | for i, x in enumerate(p0_list): 27 | self.logger.info("pair0 data shape x_%d: %s", i, x.shape) 28 | self.logger.info("pair1 data num is %d", p1_num) 29 | for i, x in enumerate(p1_list): 30 | self.logger.info("pair1 data shape x_%d: %s", i + 1, x.shape) 31 | 32 | self.logger.info("normalizing") 33 | p0_norm_list = [ self.normalize(x) for x in p0_list] 34 | p1_norm_list = [ self.normalize(x) for x in p1_list] 35 | 36 | p0_d_list = [0] + [sum([len(x.T) for x in p0_list][:i + 1]) for i in xrange(p0_num)] 37 | p1_d_list = [0] + [sum([len(x.T) for x in p1_list][:i + 1]) for i in xrange(p1_num)] 38 | 39 | p0_cov_mat = self.calc_cov_mat(p0_norm_list) 40 | p0_cov_mat = self.add_regularization_term(p0_cov_mat) 41 | p1_cov_mat = self.calc_cov_mat(p1_norm_list) 42 | p1_cov_mat = self.add_regularization_term(p1_cov_mat) 43 | self.logger.info("calc variance") 44 | x1_all = np.vstack([x1_pair0, x1_pair1]) 45 | x1_var = np.cov(x1_all.T) 46 | self.logger.info("adding regularization term") 47 | x1_var += self.reg_param * np.average(np.diag(x1_var)) * np.eye(x1_var.shape[0]) 48 | 49 | x_list = [x0_pair0, x1_all, x2_pair1] 50 | d_list = [0] + [sum([len(x.T) for x in x_list][:i + 1]) for i in xrange(data_num)] 51 | 52 | c00 = p0_cov_mat[0][0] 53 | c01 = p0_cov_mat[0][1] 54 | # c11 = p0_cov_mat[1][1] 55 | # c11 = p1_cov_mat[1 - 1][1 - 1] 56 | c11 = x1_var 57 | c12 = p1_cov_mat[1 - 1][2 - 1] 58 | c22 = p1_cov_mat[2 - 1][2 - 1] 59 | c02 = np.zeros((c00.shape[0], c22.shape[1])) 60 | 61 | cov_mat = [[np.array([]) for col in range(data_num)] for row in range(data_num)] 62 | cov_mat[0][0], cov_mat[0][1], cov_mat[0][2] = c00, c01, c02 63 | cov_mat[1][0], cov_mat[1][1], cov_mat[1][2] = c01.T, c11, c12 64 | cov_mat[2][0], cov_mat[2][1], cov_mat[2][2] = c02.T, c12.T, c22 65 | 66 | self.logger.info("calculating generalized eigenvalue problem ( A*u = (lambda)*B*u )") 67 | # left = A, right = B 68 | left = 0.5 * np.vstack([ 69 | np.hstack([np.zeros_like(c00), c01, c02]), 70 | np.hstack([c01.T, np.zeros_like(c11), c12]), 71 | np.hstack([c02.T, c12.T, np.zeros_like(c22)]) 72 | ]) 73 | right = np.vstack([ 74 | np.hstack([c00, np.zeros_like(c01), np.zeros_like(c02)]), 75 | np.hstack([np.zeros_like(c01.T), c11, np.zeros_like(c12)]), 76 | np.hstack([np.zeros_like(c02.T), np.zeros_like(c12.T), c22]) 77 | ]) 78 | 79 | # calc GEV 80 | self.logger.info("solving") 81 | eigvals, eigvecs = self.solve_eigprob(left, right) 82 | h_list = [eigvecs[start:end] for start, end in zip(d_list[0:-1], d_list[1:])] 83 | h_list_norm = [ self.eigvec_normalization(h, cov_mat[i][i]) for i, h in enumerate(h_list)] 84 | 85 | # substitute local variables for member variables 86 | self.data_num = data_num 87 | self.cov_mat = cov_mat 88 | self.h_list = h_list_norm 89 | self.eigvals = eigvals 90 | 91 | def main(): 92 | 93 | # set log level 94 | logging.root.setLevel(level=logging.INFO) 95 | 96 | # create data in advance 97 | digit = load_digits() 98 | a = digit.data[:150, 0::3] 99 | b = digit.data[:150, 1::3] 100 | c = digit.data[:150, 2::3] 101 | # a = np.random.rand(100, 50) 102 | # b = np.random.rand(100, 60) 103 | # c = np.random.rand(100, 70) 104 | 105 | # create instance of BridgedCCA 106 | bcca = BridgedCCA(reg_param=0.0001) 107 | # calculate BridgedCCA 108 | bcca.fit(a[:50], b[:50], b[50:100], c[50:100]) 109 | # transform 110 | bcca.transform(a[100:], b[100:], c[100:]) 111 | # save 112 | bcca.save_params("save/bcca.h5") 113 | # load 114 | bcca.load_params("save/bcca.h5") 115 | # calc correlations 116 | bcca.calc_correlations() 117 | # plot 118 | bcca.plot_result() 119 | 120 | if __name__=="__main__": 121 | 122 | main() 123 | -------------------------------------------------------------------------------- /gcca/cca.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | __author__ = 'rupy' 5 | 6 | from gcca import GCCA 7 | import numpy as np 8 | import logging 9 | import os 10 | import matplotlib.pyplot as plt 11 | from matplotlib import colors 12 | import h5py 13 | 14 | class CCA(GCCA): 15 | 16 | def __init__(self, n_components=2, reg_param=0.1): 17 | GCCA.__init__(self, n_components, reg_param) 18 | 19 | # log setting 20 | program = os.path.basename(__name__) 21 | self.logger = logging.getLogger(program) 22 | logging.basicConfig(format='%(asctime)s : %(name)s : %(levelname)s : %(message)s') 23 | 24 | self.z_p = np.array([]) 25 | 26 | def fit(self, x0, x1): 27 | 28 | x_list = [x0, x1] 29 | 30 | # data size check 31 | data_num = len(x_list) 32 | self.logger.info("data num is %d", data_num) 33 | for i, x in enumerate(x_list): 34 | self.logger.info("data shape x_%d: %s", i, x.shape) 35 | 36 | self.logger.info("normalizing") 37 | x_norm_list = [ self.normalize(x) for x in x_list] 38 | 39 | d_list = [0] + [sum([len(x.T) for x in x_list][:i + 1]) for i in xrange(data_num)] 40 | cov_mat = self.calc_cov_mat(x_norm_list) 41 | cov_mat = self.add_regularization_term(cov_mat) 42 | c_00 = cov_mat[0][0] 43 | c_01 = cov_mat[0][1] 44 | c_11 = cov_mat[1][1] 45 | 46 | self.logger.info("calculating generalized eigenvalue problem ( A*u = (lambda)*B*u )") 47 | 48 | # 1 49 | left_1 = np.dot(c_01, np.linalg.solve(c_11,c_01.T)) 50 | right_1 = c_00 51 | eigvals_1, eigvecs_1 = self.solve_eigprob(left_1, right_1) 52 | eigvecs_1_norm = self.eigvec_normalization(eigvecs_1, right_1) 53 | # 2 54 | right_2 = c_11 55 | eigvecs_2 = 1 / eigvals_1 * np.dot(np.linalg.solve(c_11, c_01.T), eigvecs_1_norm) 56 | eigvecs_2_norm = self.eigvec_normalization(eigvecs_2, right_2) 57 | 58 | # substitute local variables for member variables 59 | self.data_num = data_num 60 | self.cov_mat = cov_mat 61 | self.h_list = [eigvecs_1_norm, eigvecs_2_norm] 62 | self.eigvals = eigvals_1 63 | 64 | def ptransform(self, x0, x1, beta=0.5): 65 | 66 | x0_projected, x1_projected = self.transform(x0, x1) 67 | 68 | I = np.eye(len(self.eigvals)) 69 | lamb = np.diag(self.eigvals) 70 | mat1 = np.linalg.solve(I - np.diag(self.eigvals**2), I) 71 | mat2 = -np.dot(mat1, lamb) 72 | mat12 = np.vstack((mat1, mat2)) 73 | mat21 = np.vstack((mat2, mat1)) 74 | mat = np.hstack((mat12, mat21)) 75 | p = np.vstack((lamb**beta, lamb**(1-beta))) 76 | q = np.vstack((x0_projected.T, x1_projected.T)) 77 | z = np.dot(p.T, np.dot(mat, q)).T[:,:self.n_components] 78 | 79 | self.z_p = z 80 | 81 | return x0_projected, x1_projected, z 82 | 83 | def save_params(self, filepath): 84 | 85 | GCCA.save_params(self, filepath) 86 | if len(self.z_p) != 0: 87 | with h5py.File(filepath, 'a') as f: 88 | f.create_dataset("z_p", data=self.z_p) 89 | f.flush() 90 | 91 | def load_params(self, filepath): 92 | 93 | GCCA.load_params(self, filepath) 94 | 95 | with h5py.File(filepath, "r") as f: 96 | if "z_p" in f: 97 | self.z_p = f["z_p"].value 98 | f.flush() 99 | 100 | def plot_result(self): 101 | 102 | self.logger.info("plotting result") 103 | row_num = 2 104 | col_num = 2 105 | 106 | # begin plot 107 | plt.figure() 108 | 109 | color_list = colors.cnames.keys() 110 | plt.subplot(row_num, col_num, 1) 111 | plt.plot(self.z_list[0][:, 0], self.z_list[0][:, 1], c=color_list[0], marker='.', ls=' ') 112 | plt.title("Z_0(CCA)") 113 | plt.subplot(row_num, col_num, 2) 114 | plt.plot(self.z_list[1][:, 0], self.z_list[1][:, 1], c=color_list[1], marker='.', ls=' ') 115 | plt.title('Z_1(CCA)') 116 | 117 | plt.subplot(row_num, col_num, 3) 118 | plt.plot(self.z_list[0][:, 0], self.z_list[0][:, 1], c=color_list[0], marker='.', ls=' ') 119 | plt.plot(self.z_list[1][:, 0], self.z_list[1][:, 1], c=color_list[1], marker='.', ls=' ') 120 | plt.title('Z_ALL(CCA)') 121 | 122 | if len(self.z_p) != 0: 123 | plt.subplot(row_num, col_num, 4) 124 | plt.plot(self.z_p[:, 0], self.z_p[:, 1], c=color_list[2], marker='.', ls=' ') 125 | plt.title('Z(PCCA)') 126 | 127 | plt.show() 128 | 129 | def main(): 130 | 131 | # set log level 132 | logging.root.setLevel(level=logging.INFO) 133 | 134 | # create data in advance 135 | a = np.random.rand(50, 50) 136 | b = np.random.rand(50, 60) 137 | 138 | # create instance of CCA 139 | cca = CCA() 140 | # calculate CCA 141 | cca.fit(a, b) 142 | # transform 143 | cca.transform(a, b) 144 | # transform by PCCA 145 | cca.ptransform(a, b) 146 | # save 147 | cca.save_params("save/cca.h5") 148 | # load 149 | cca.load_params("save/cca.h5") 150 | # plot 151 | cca.plot_result() 152 | # calc correlations 153 | cca.calc_correlations() 154 | 155 | if __name__=="__main__": 156 | 157 | main() -------------------------------------------------------------------------------- /gcca/hierarchical_cca.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | __author__ = 'rupy' 5 | 6 | import numpy as np 7 | 8 | from gcca import GCCA 9 | from cca import CCA 10 | import logging 11 | import h5py 12 | 13 | class HierarchicalCCA(GCCA): 14 | 15 | def __init__(self, n_components=2, reg_param=0.1): 16 | GCCA.__init__(self, n_components, reg_param) 17 | 18 | self.cca1 = CCA(self.n_components, self.reg_param) 19 | self.cca2 = CCA(self.n_components, self.reg_param) 20 | 21 | self.z_list = [] 22 | 23 | def fit(self, x0, x1, x2): 24 | 25 | self.data_num = 3 26 | 27 | # 1 28 | self.cca1.fit(x0, x1) 29 | self.cca1.transform(x0, x1) 30 | 31 | # 2 32 | z0 = self.cca1.z_list[0] 33 | z1 = self.cca1.z_list[1] 34 | z_all = np.vstack([z0, z1]) 35 | x_dup = np.vstack([x2, x2]) 36 | self.cca2.fit(z_all, x_dup) 37 | 38 | 39 | def transform(self, x0, x1, x2): 40 | 41 | # 1 42 | self.cca1.transform(x0, x1) 43 | 44 | # 2 45 | z0 = self.cca1.z_list[0] 46 | z1 = self.cca1.z_list[1] 47 | z_all = np.vstack([z0, z1]) 48 | x_dup = np.vstack([x2, x2]) 49 | self.cca2.transform(z_all, x_dup) 50 | w_all, w2_dup = self.cca2.z_list 51 | w0 = w_all[:z0.shape[0]] 52 | w1 = w_all[z0.shape[0]:] 53 | w2 = w2_dup[:x2.shape[0]] 54 | 55 | self.z_list = [w0, w1, w2] 56 | 57 | def save_params(self, filepath): 58 | self.logger.info("saving hierarchical cca to %s", filepath) 59 | with h5py.File(filepath, 'w') as f: 60 | f.create_dataset("n_components", data=self.n_components) 61 | f.create_dataset("reg_param", data=self.reg_param) 62 | f.create_dataset("data_num_all", data=self.data_num) 63 | f.create_dataset("data_num1", data=self.cca1.data_num) 64 | f.create_dataset("data_num2", data=self.cca2.data_num) 65 | 66 | cov_grp1 = f.create_group("cov_mat1") 67 | for i, row in enumerate(self.cca1.cov_mat): 68 | for j, cov in enumerate(row): 69 | cov_grp1.create_dataset(str(i) + "_" + str(j), data=cov) 70 | 71 | cov_grp2 = f.create_group("cov_mat2") 72 | for i, row in enumerate(self.cca2.cov_mat): 73 | for j, cov in enumerate(row): 74 | cov_grp2.create_dataset(str(i) + "_" + str(j), data=cov) 75 | 76 | h_grp1 = f.create_group("h_list1") 77 | for i, h in enumerate(self.cca1.h_list): 78 | h_grp1.create_dataset(str(i), data=h) 79 | 80 | h_grp2 = f.create_group("h_list2") 81 | for i, h in enumerate(self.cca2.h_list): 82 | h_grp2.create_dataset(str(i), data=h) 83 | 84 | f.create_dataset("eig_vals1", data=self.cca1.eigvals) 85 | f.create_dataset("eig_vals2", data=self.cca2.eigvals) 86 | 87 | if len(self.cca1.z_list) != 0: 88 | z_grp1 = f.create_group("z_list1") 89 | for i, z in enumerate(self.z_list): 90 | z_grp1.create_dataset(str(i), data=z) 91 | 92 | if len(self.cca2.z_list) != 0: 93 | z_grp2 = f.create_group("z_list2") 94 | for i, z in enumerate(self.cca2.z_list): 95 | z_grp2.create_dataset(str(i), data=z) 96 | 97 | 98 | if len(self.z_list) != 0: 99 | z_grp3 = f.create_group("z_list_all") 100 | for i, z in enumerate(self.z_list): 101 | z_grp3.create_dataset(str(i), data=z) 102 | 103 | f.flush() 104 | 105 | def load_params(self, filepath): 106 | self.logger.info("loading hierarchical cca from %s", filepath) 107 | with h5py.File(filepath, "r") as f: 108 | self.n_components = f["n_components"].value 109 | self.reg_param = f["reg_param"].value 110 | self.cca1.n_components = self.n_components 111 | self.cca1.reg_param = self.reg_param 112 | self.data_num = f["data_num_all"].value 113 | self.cca1.data_num = f["data_num1"].value 114 | self.cca2.data_num = f["data_num2"].value 115 | 116 | self.cca1.cov_mat = [[np.array([]) for col in range(self.cca1.data_num)] for row in range(self.cca1.data_num)] 117 | self.cca2.cov_mat = [[np.array([]) for col in range(self.cca2.data_num)] for row in range(self.cca2.data_num)] 118 | 119 | for i in xrange(self.cca1.data_num): 120 | for j in xrange(self.cca1.data_num): 121 | self.cca1.cov_mat[i][j] = f["cov_mat1/" + str(i) + "_" + str(j)] 122 | 123 | for i in xrange(self.cca2.data_num): 124 | for j in xrange(self.cca2.data_num): 125 | self.cca2.cov_mat[i][j] = f["cov_mat2/" + str(i) + "_" + str(j)] 126 | 127 | self.cca1.h_list = [None] * self.data_num 128 | for i in xrange(self.cca1.data_num): 129 | self.cca1.h_list[i] = f["h_list1/" + str(i)].value 130 | self.cca2.h_list = [None] * self.data_num 131 | for i in xrange(self.cca2.data_num): 132 | self.cca2.h_list[i] = f["h_list2/" + str(i)].value 133 | self.cca1.eig_vals = f["eig_vals1"].value 134 | self.cca2.eig_vals = f["eig_vals2"].value 135 | 136 | if "z_list1" in f: 137 | self.cca1.z_list = [None] * self.cca2.data_num 138 | for i in xrange(self.cca1.data_num): 139 | self.cca1.z_list[i] = f["z_list1/" + str(i)].value 140 | 141 | if "z_list2" in f: 142 | self.cca2.z_list = [None] * self.cca2.data_num 143 | for i in xrange(self.cca2.data_num): 144 | self.cca2.z_list[i] = f["z_list2/" + str(i)].value 145 | 146 | if "z_list_all" in f: 147 | self.z_list = [None] * self.data_num 148 | for i in xrange(self.data_num): 149 | self.z_list[i] = f["z_list_all/" + str(i)].value 150 | 151 | f.flush() 152 | 153 | def main(): 154 | # set log level 155 | logging.root.setLevel(level=logging.INFO) 156 | 157 | # create data in advance 158 | a = np.random.rand(50, 50) 159 | b = np.random.rand(50, 60) 160 | c = np.random.rand(50, 70) 161 | d = np.random.rand(50, 80) 162 | e = np.random.rand(50, 90) 163 | f = np.random.rand(50, 100) 164 | g = np.random.rand(50, 110) 165 | h = np.random.rand(50, 120) 166 | i = np.random.rand(50, 130) 167 | j = np.random.rand(50, 140) 168 | k = np.random.rand(50, 150) 169 | 170 | # create instance of GCCA 171 | hcca = HierarchicalCCA(reg_param=0.01) 172 | # calculate GCCA 173 | hcca.fit(a, b, c) 174 | # transform 175 | hcca.transform(a, b, c) 176 | # save 177 | hcca.save_params("save/hcca.h5") 178 | # load 179 | hcca.load_params("save/hcca.h5") 180 | # plot 181 | hcca.plot_result() 182 | # calc correlations 183 | hcca.calc_correlations() 184 | 185 | if __name__=="__main__": 186 | 187 | main() -------------------------------------------------------------------------------- /gcca/gcca.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | __author__ = 'rupy' 5 | 6 | import numpy as np 7 | from scipy.linalg import eig 8 | import logging 9 | import os 10 | import matplotlib.pyplot as plt 11 | import math 12 | from matplotlib import colors 13 | import h5py 14 | 15 | class GCCA: 16 | 17 | def __init__(self, n_components=2, reg_param=0.1): 18 | 19 | # log setting 20 | program = os.path.basename(__name__) 21 | self.logger = logging.getLogger(program) 22 | logging.basicConfig(format='%(asctime)s : %(name)s : %(levelname)s : %(message)s') 23 | 24 | # GCCA params 25 | self.n_components = n_components 26 | self.reg_param = reg_param 27 | 28 | # result of fitting 29 | self.data_num = 0 30 | self.cov_mat = [[]] 31 | self.h_list = [] 32 | self.eigvals = np.array([]) 33 | 34 | # result of transformation 35 | self.z_list = [] 36 | 37 | def eigvec_normalization(self, eig_vecs, x_var): 38 | self.logger.info("normalization") 39 | z_var = np.dot(eig_vecs.T, np.dot(x_var, eig_vecs)) 40 | invvar = np.diag(np.reciprocal(np.sqrt(np.diag(z_var)))) 41 | eig_vecs = np.dot(eig_vecs, invvar) 42 | # print np.dot(eig_vecs.T, np.dot(x_var, eig_vecs)).round().astype(int) 43 | return eig_vecs 44 | 45 | 46 | def solve_eigprob(self, left, right): 47 | 48 | self.logger.info("calculating eigen dimension") 49 | eig_dim = min([np.linalg.matrix_rank(left), np.linalg.matrix_rank(right)]) 50 | 51 | self.logger.info("calculating eigenvalues & eigenvector") 52 | eig_vals, eig_vecs = eig(left, right) 53 | 54 | self.logger.info("sorting eigenvalues & eigenvector") 55 | sort_indices = np.argsort(eig_vals)[::-1] 56 | eig_vals = eig_vals[sort_indices][:eig_dim].real 57 | eig_vecs = eig_vecs[:,sort_indices][:,:eig_dim].real 58 | 59 | return eig_vals, eig_vecs 60 | 61 | def calc_cov_mat(self, x_list): 62 | 63 | data_num = len(x_list) 64 | 65 | self.logger.info("calc variance & covariance matrix") 66 | z = np.vstack([x.T for x in x_list]) 67 | cov = np.cov(z) 68 | d_list = [0] + [sum([len(x.T) for x in x_list][:i + 1]) for i in xrange(data_num)] 69 | cov_mat = [[np.array([]) for col in range(data_num)] for row in range(data_num)] 70 | for i in xrange(data_num): 71 | for j in xrange(data_num): 72 | i_start, i_end = d_list[i], d_list[i + 1] 73 | j_start, j_end = d_list[j], d_list[j + 1] 74 | cov_mat[i][j] = cov[i_start:i_end, j_start:j_end] 75 | 76 | return cov_mat 77 | 78 | def add_regularization_term(self, cov_mat): 79 | 80 | data_num = len(cov_mat) 81 | 82 | # regularization 83 | self.logger.info("adding regularization term") 84 | for i in xrange(data_num): 85 | cov_mat[i][i] += self.reg_param * np.average(np.diag(cov_mat[i][i])) * np.eye(cov_mat[i][i].shape[0]) 86 | 87 | return cov_mat 88 | 89 | def fit(self, *x_list): 90 | 91 | # data size check 92 | data_num = len(x_list) 93 | self.logger.info("data num is %d", data_num) 94 | for i, x in enumerate(x_list): 95 | self.logger.info("data shape x_%d: %s", i, x.shape) 96 | 97 | self.logger.info("normalizing") 98 | x_norm_list = [ self.normalize(x) for x in x_list] 99 | 100 | d_list = [0] + [sum([len(x.T) for x in x_list][:i + 1]) for i in xrange(data_num)] 101 | cov_mat = self.calc_cov_mat(x_norm_list) 102 | cov_mat = self.add_regularization_term(cov_mat) 103 | 104 | self.logger.info("calculating generalized eigenvalue problem ( A*u = (lambda)*B*u )") 105 | # left = A, right = B 106 | left = 0.5 * np.vstack( 107 | [ 108 | np.hstack([np.zeros_like(cov_mat[i][j]) if i == j else cov_mat[i][j] for j in xrange(data_num)]) 109 | for i in xrange(data_num) 110 | ] 111 | ) 112 | right = np.vstack( 113 | [ 114 | np.hstack([np.zeros_like(cov_mat[i][j]) if i != j else cov_mat[i][j] for j in xrange(data_num)]) 115 | for i in xrange(data_num) 116 | ] 117 | ) 118 | 119 | # calc GEV 120 | self.logger.info("solving") 121 | eigvals, eigvecs = self.solve_eigprob(left, right) 122 | 123 | h_list = [eigvecs[start:end] for start, end in zip(d_list[0:-1], d_list[1:])] 124 | h_list_norm = [self.eigvec_normalization(h, cov_mat[i][i]) for i, h in enumerate(h_list)] 125 | 126 | # substitute local variables for member variables 127 | self.data_num = data_num 128 | self.cov_mat = cov_mat 129 | self.h_list = h_list_norm 130 | self.eigvals = eigvals 131 | 132 | def transform(self, *x_list): 133 | 134 | # data size check 135 | data_num = len(x_list) 136 | self.logger.info("data num is %d", data_num) 137 | for i, x in enumerate(x_list): 138 | self.logger.info("data shape x_%d: %s", i, x.shape) 139 | 140 | if self.data_num != data_num: 141 | raise Exception('data num when fitting is different from data num to be transformed') 142 | 143 | self.logger.info("normalizing") 144 | x_norm_list = [ self.normalize(x) for x in x_list] 145 | 146 | self.logger.info("transform matrices by GCCA") 147 | z_list = [np.dot(x, h_vec) for x, h_vec in zip(x_norm_list, self.h_list)] 148 | 149 | self.z_list = z_list 150 | 151 | return z_list 152 | 153 | def fit_transform(self, *x_list): 154 | self.fit(x_list) 155 | self.transform(x_list) 156 | 157 | @staticmethod 158 | def normalize(mat): 159 | m = np.mean(mat, axis=0) 160 | mat = mat - m 161 | return mat 162 | 163 | def save_params(self, filepath): 164 | 165 | self.logger.info("saving to %s", filepath) 166 | with h5py.File(filepath, 'w') as f: 167 | f.create_dataset("n_components", data=self.n_components) 168 | f.create_dataset("reg_param", data=self.reg_param) 169 | f.create_dataset("data_num", data=self.data_num) 170 | 171 | cov_grp = f.create_group("cov_mat") 172 | for i, row in enumerate(self.cov_mat): 173 | for j, cov in enumerate(row): 174 | cov_grp.create_dataset(str(i) + "_" + str(j), data=cov) 175 | 176 | h_grp = f.create_group("h_list") 177 | for i, h in enumerate(self.h_list): 178 | h_grp.create_dataset(str(i), data=h) 179 | 180 | f.create_dataset("eig_vals", data=self.eigvals) 181 | 182 | if len(self.z_list) != 0: 183 | z_grp = f.create_group("z_list") 184 | for i, z in enumerate(self.z_list): 185 | z_grp.create_dataset(str(i), data=z) 186 | f.flush() 187 | 188 | def load_params(self, filepath): 189 | self.logger.info("loading from %s", filepath) 190 | with h5py.File(filepath, "r") as f: 191 | self.n_components = f["n_components"].value 192 | self.reg_param = f["reg_param"].value 193 | self.data_num = f["data_num"].value 194 | 195 | self.cov_mat = [[np.array([]) for col in range(self.data_num)] for row in range(self.data_num)] 196 | for i in xrange(self.data_num): 197 | for j in xrange(self.data_num): 198 | self.cov_mat[i][j] = f["cov_mat/" + str(i) + "_" + str(j)] 199 | self.h_list = [None] * self.data_num 200 | for i in xrange(self.data_num): 201 | self.h_list[i] = f["h_list/" + str(i)].value 202 | self.eig_vals = f["eig_vals"].value 203 | 204 | if "z_list" in f: 205 | self.z_list = [None] * self.data_num 206 | for i in xrange(self.data_num): 207 | self.z_list[i] = f["z_list/" + str(i)].value 208 | f.flush() 209 | 210 | def plot_result(self): 211 | 212 | self.logger.info("plotting result") 213 | col_num = int(math.ceil(math.sqrt(self.data_num + 1))) 214 | row_num = int((self.data_num + 1) / float(col_num)) 215 | if row_num != (self.data_num + 1) / float(col_num): 216 | row_num += 1 217 | 218 | # begin plot 219 | plt.figure() 220 | 221 | color_list = colors.cnames.keys() 222 | for i in xrange(self.data_num): 223 | 224 | plt.subplot(row_num, col_num, i + 1) 225 | plt.plot(self.z_list[i][:, 0], self.z_list[i][:, 1], c=color_list[i], marker='.', ls=' ') 226 | plt.title("Z_%d(GCCA)" % (i + 1)) 227 | 228 | plt.subplot(row_num, col_num, self.data_num + 1) 229 | for i in xrange(self.data_num): 230 | plt.plot(self.z_list[i][:, 0], self.z_list[i][:, 1], c=color_list[i], marker='.', ls=' ') 231 | plt.title("Z_ALL(GCCA)") 232 | 233 | plt.show() 234 | 235 | def calc_correlations(self): 236 | for i, z_i in enumerate(self.z_list): 237 | for j, z_j in enumerate(self.z_list): 238 | if i < j: 239 | print "(%d, %d): %f" % (i, j, np.corrcoef(z_i[:,0], z_j[:,0])[0, 1]) 240 | 241 | def main(): 242 | 243 | # set log level 244 | logging.root.setLevel(level=logging.INFO) 245 | 246 | # create data in advance 247 | a = np.random.rand(50, 50) 248 | b = np.random.rand(50, 60) 249 | c = np.random.rand(50, 70) 250 | d = np.random.rand(50, 80) 251 | e = np.random.rand(50, 90) 252 | f = np.random.rand(50, 100) 253 | g = np.random.rand(50, 110) 254 | h = np.random.rand(50, 120) 255 | i = np.random.rand(50, 130) 256 | j = np.random.rand(50, 140) 257 | k = np.random.rand(50, 150) 258 | 259 | # create instance of GCCA 260 | gcca = GCCA(reg_param=0.01) 261 | # calculate GCCA 262 | gcca.fit(a, b, c, d, e, f, g, h, i, j, k) 263 | # transform 264 | gcca.transform(a, b, c, d, e, f, g, h, i, j, k) 265 | # save 266 | gcca.save_params("save/gcca.h5") 267 | # load 268 | gcca.load_params("save/gcca.h5") 269 | # plot 270 | gcca.plot_result() 271 | # calc correlations 272 | gcca.calc_correlations() 273 | 274 | if __name__=="__main__": 275 | main() --------------------------------------------------------------------------------