├── README.md ├── SIMLR_PY ├── LICENSE.txt ├── MANIFEST ├── Makefile ├── README.md ├── SIMLR.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── SIMLR │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ └── __init__.cpython-39.pyc │ ├── core.py │ ├── core.pyc │ ├── helper.py │ └── helper.pyc ├── build │ ├── lib.linux-x86_64-2.7 │ │ └── SIMLR │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── helper.py │ │ │ └── helpers.py │ └── lib │ │ └── SIMLR │ │ ├── __init__.py │ │ ├── core.py │ │ └── helper.py ├── dist │ ├── SIMLR-0.0.0.tar.gz │ ├── SIMLR-0.1.0.tar.gz │ ├── SIMLR-0.1.1.tar.gz │ └── SIMLR-0.1.3.tar.gz ├── requirements.txt ├── setup.py └── tests │ ├── Zeisel.mat │ └── test_largescale.py ├── centrality.py ├── data_loader.py ├── fig1.png ├── fig2.png ├── main.py ├── model.py ├── plot.py ├── prediction.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # MultiGraphGAN 2 | MultiGraphGAN for jointly predicting multiple brain graphs from a single brain graph, coded up in Python by Alaa Bessadok. Please contact alaa.bessadok@gmail.Com for further inquiries. Thanks. 3 | 4 | This repository provides the official PyTorch implementation of the following paper: 5 | 6 |

7 | 8 |

9 | 10 | 11 | > **Topology-Aware Generative Adversarial Network for Joint Prediction of Multiple Brain Graphs from a Single Brain Graph** 12 | > [Alaa Bessadok](https://github.com/AlaaBessadok)1,2, [Mohamed Ali Mahjoub]2, [Islem Rekik](https://basira-lab.com/)1 13 | > 1BASIRA Lab, Faculty of Computer and Informatics, Istanbul Technical University, Istanbul, Turkey 14 | > 2University of Sousse, Higher Institute of Informatics and Communication Technologies, Sousse, Tunisia 15 | > 16 | > **Abstract:** *Multimodal medical datasets with incomplete observations present a barrier to large-scale neuroscience studies. Several works based on Generative Adversarial Networks (GAN) have been recently proposed to predict a set of medical images from a single modality (e.g, FLAIR 17 | MRI from T1 MRI). However, such frameworks are primarily designed to operate on images, limiting their generalizability to non-Euclidean geometric data such as brain graphs. While a growing number of connectomic studies has demonstrated the promise of including brain graphs for diagnosing neurological disorders, no geometric deep learning work was designed for multiple target brain graphs prediction from a source brain graph. Despite the momentum the field of graph generation has gained 18 | in the last two years, existing works have two critical drawbacks. First, the bulk of such works aims to learn one model for each target domain to generate from a source domain. Thus, they have a limited scalability in jointly predicting multiple target domains. Second, they merely consider the global topological scale of a graph (i.e., graph connectivity structure) and overlook the local topology at the node scale of a graph (e.g., how central a node is in the graph). To meet these challenges, we introduce MultiGraphGAN architecture, which not only predicts multiple brain graphs from a single brain graph but also preserves the topological structure of each target graph to predict. Its three core contributions lie in: (i) designing a graph adversarial auto-encoder for jointly predicting brain graphs from a single one, (ii) handling the mode collapse problem of GAN by clustering the encoded source graphs and proposing a cluster-specific decoder, (iii) introducing a topological loss to force the reconstruction of topologically sound target brain graphs. Our MultiGraphGAN significantly outperformed its variants thereby showing its great potential in multi-view brain graph generation from a single graph. Our code is available at https://github.com/basiralab/MultiGraphGAN.* 19 | 20 | This work is published in MICCAI 2020, Lima, Peru. MultiGraphGAN is a geometric deep learning framework for jointly predicting multiple brain graphs from a single graph. Using an end-to-end learning fashion, it preserves the topological structure of each target graph. Our MultiGraphGAN framework comprises two key steps (1) source graphs embedding and clustering and, (2) cluster-specific multi-target graph prediction. We have evaluated our method on ABIDE dataset. Detailed information can be found in the original paper and the video in the BASIRA Lab YouTube channel. In this repository, we release the code for training and testing MultiGraphGAN on a simulated dataset. 21 | 22 | # Installation 23 | 24 | The code has been tested with Python 3, PyTorch 1.3.1 on Ubuntu 16.04. GPU is required to run the code. 
You also need other dependencies (e.g., numpy, yaml, networkx, SIMLR) which can be installed via: 25 | 26 | ```bash 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | # Training and testing MultiGraphGAN 31 | 32 | We provide a demo code for the usage of MultiGraphGAN for multiple target graphs prediction from a source graph. In main.py we train MultiGraphGAN on a simulated dataset with 280 subjects and test it on 30 subjects. Each sample has 6 brain graphs (one source graph and five target graphs). In this example, we used three input arguments (i.e., num_domains, nb_clusters and mode), you can add hyper-parameters (e.g., lambda_topology, lambda_rec) and vary their default values. 33 | 34 | You can train the program with the following command: 35 | 36 | ```bash 37 | python main.py --num_domains=6 --nb_clusters=2 --mode='train' 38 | ``` 39 | 40 | In this example, we simulated a training dataset with 280 samples and a testing set with 30 samples. If you want to test the code using the hyperparameters described in the paper, type in the terminal the following commande: 41 | 42 | ```bash 43 | python main.py --num_domains=6 --nb_clusters=2 --mode='test' 44 | ``` 45 | 46 | # Input data 47 | 48 | In order to use our framework, you need to provide: 49 | 50 | * a source_target_domains list where each element is a matrix of size (n * f). We denote n the total number of subjects in the dataset and f the number of features. Any element of the list can be considered as the source domain and the rest are the target domains. You need to include your data in the file main.py. So, just remove our simulated training and testing dataset and replace it with yours. 51 | 52 | # Output Data 53 | 54 | If you set the number of source and target domains to 3 using this argument --num_domains=3 , and keep the same size of our simulated data, the execution of main.py will produce saved csv files of the source and target data. Then, you can plot the brain graphs of any subject from the saved csv files. To do so, run the plot.py to get the following outputs especially when running the demo with default parameter setting: 55 | 56 |
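To illustrate what those saved outputs contain, here is a minimal sketch, assuming a 35-region parcellation (the same convention as the `to_2d` helper in `centrality.py`) and a hypothetical csv file name; `plot.py` in this repository is the reference script for producing the actual figures:

```python
import numpy as np
import matplotlib.pyplot as plt

def to_2d(vector, size=35):
    # Fold a vectorized brain graph (size*(size-1)/2 off-diagonal values)
    # back into a symmetric size x size matrix, as in centrality.py.
    x = np.zeros((size, size))
    c = 0
    for i in range(1, size):
        for j in range(i):
            x[i, j] = x[j, i] = vector[c]
            c += 1
    return x

# Hypothetical file name: adapt it to the csv actually written by main.py.
predicted = np.loadtxt('predicted_target_domain1.csv', delimiter=',')

plt.imshow(to_2d(predicted[0]), cmap='viridis')  # brain graph of the first subject
plt.colorbar()
plt.title('Predicted brain graph (subject 0)')
plt.show()
```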

57 | 58 |
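Similarly, the Input data section above only fixes the shape of the expected data. A minimal sketch of a `source_target_domains`-style list, using random matrices in place of real connectivity features (the numbers below mirror the simulated demo and are assumptions, not ABIDE statistics), looks like this:

```python
import numpy as np

n_subjects = 280    # training subjects in the simulated demo
n_features = 595    # 35*34/2 connections for a 35-region brain graph (assumption)
num_domains = 6     # one source domain and five target domains

# Each element is an (n x f) matrix; any element of the list can be taken as
# the source domain and the remaining ones as the target domains to predict.
source_target_domains = [np.random.normal(0.0, 0.5, (n_subjects, n_features))
                         for _ in range(num_domains)]
```

Replace this list with your own matrices inside `main.py`, as described in the Input data section.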

59 | 60 | # YouTube videos to install and run the code and understand how MultiGraphGAN works 61 | 62 | To install and run MultiGraphGAN, check the following YouTube video: 63 | 64 | https://youtu.be/JvT5XtAgbUk 65 | 66 | To learn about how MultiGraphGAN works, check the following YouTube videos: 67 | 68 | Short version (10mn): https://youtu.be/vEnzMQqbdHc 69 | 70 | Long version (20mn): https://youtu.be/yNx7H9NLzlE 71 | 72 | # Related references 73 | 74 | Multi-Marginal Wasserstein GAN (MWGAN): 75 | Cao, J., Mo, L., Zhang, Y., Jia, K., Shen, C., Tan, M.: Multi-marginal wasserstein gan. [https://arxiv.org/pdf/1911.00888.pdf] (2019) [https://github.com/caojiezhang/MWGAN]. 76 | 77 | Single‐cell Interpretation via Multi‐kernel LeaRning (SIMLR): 78 | Wang, B., Ramazzotti, D., De Sano, L., Zhu, J., Pierson, E., Batzoglou, S.: SIMLR: a tool for large-scale single-cell analysis by multi-kernel learning. [https://www.biorxiv.org/content/10.1101/052225v3] (2017) [https://github.com/bowang87/SIMLR_PY]. 79 | 80 | 81 | # Citation 82 | 83 | If our code is useful for your work please cite our paper: 84 | 85 | ```latex 86 | @inproceedings{bessadok2020, 87 | title={Topology-Aware Generative Adversarial Network for Joint Prediction of Multiple Brain Graphs from a Single Brain Graph}, 88 | author={Bessadok, Alaa and Mahjoub, Mohamed Ali and Rekik, Islem}, 89 | booktitle={ International Conference on Medical Image Computing and Computer Assisted Intervention}, 90 | year={2020}, 91 | organization={Springer} 92 | } 93 | ``` 94 | 95 | # MultiGraphGAN on arXiv 96 | 97 | https://arxiv.org/abs/2009.11058 98 | 99 | # Acknowledgement 100 | 101 | This project has been funded by the 2232 International Fellowship for Outstanding Researchers Program of TUBITAK (Project No:118C288, http://basira-lab.com/reprime/) supporting Dr. Islem Rekik. However, all scientific contributions made in this project are owned and approved solely by the authors. 102 | 103 | # License 104 | Our code is released under MIT License (see LICENSE file for details). 105 | 106 | 107 | -------------------------------------------------------------------------------- /SIMLR_PY/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright <2017> 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /SIMLR_PY/MANIFEST: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | setup.py 3 | SIMLR/__init__.py 4 | SIMLR/core.py 5 | SIMLR/helper.py 6 | -------------------------------------------------------------------------------- /SIMLR_PY/Makefile: -------------------------------------------------------------------------------- 1 | init: 2 | pip install -r requirements.txt 3 | test: 4 | py.test tests 5 | .PHONY: init test 6 | -------------------------------------------------------------------------------- /SIMLR_PY/README.md: -------------------------------------------------------------------------------- 1 | SIMLR 2 | ============================================================ 3 | This is a python implementation of the paper published in Nature Methods titled as "Visualization and analysis of single-cell RNA-seq data by kernel-based similarity learning". 4 | 5 | 6 | OVERVIEW 7 | ============================================================ 8 | 9 | Single-cell RNA-seq technologies enable high throughput gene expression measurement of individual cells, and allow the discovery of heterogeneity within cell populations. Measurement of cell-to-cell gene expression similarity is critical to identification, visualization and analysis of cell populations. However, single-cell data introduce challenges to conventional measures of gene expression similarity because of the high level of noise, outliers and dropouts. We develop 10 | a novel similarity-learning framework, SIMLR (Single-cell Interpretation via Multi-kernel LeaRning), which learns an appropriate distance metric from the data for dimension reduction, clustering and visualization. SIMLR is capable of separating known subpopulations more accurately in single-cell data sets than do existing dimension reduction methods. Additionally, SIMLR demonstrates high sensitivity and accuracy on high-throughput peripheral blood mononuclear cells 11 | (PBMC) data sets generated by the GemCode single-cell technology from 10x Genomics. 12 | 13 | IMPLEMENTATIONS 14 | ============================================================ 15 | We provide implementations of SIMLR for large scale single-cell RNA-seq data. With small dataset (e.g, dataset with less than 3,000 cells), we recommend the user to use the matlab package or R package from https://github.com/BatzoglouLabSU/SIMLR. For Large dataset (with more than 3,000 cells), we recommend the user to use the python function called "SIMLR_LARGE". 16 | 17 | This large-scale implementation uses approximate version of SIMLR to address the computational issue. 18 | 19 | DEMO 20 | ============================================================ 21 | We provide two demos for the usage of SIMLR in large scale. In test_largescale.py we run SIMLR on Zeisel dataset with 3005 cells in our paper. 22 | 23 | DEBUG 24 | ============================================================ 25 | Please feel free to send us emails if you have touble running our SIMLR. 
The correspondence email is bowang87@stanford.edu 26 | 27 | Requirements 28 | ============================================================ 29 | 30 | - `numpy>=1.8` 31 | - `scipy>=0.13.2` 32 | - `annoy>=1.8` 33 | - `sklearn>=0.17` 34 | - `fbpca>=1.0` 35 | 36 | Installation 37 | ============================================================ 38 | python setup.py install 39 | or 40 | pip install SIMLR 41 | 42 | 43 | Tutorial 44 | ============================================================ 45 | see tests/test_largescale.py 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: SIMLR 3 | Version: 0.1.0 4 | Summary: Visualization and analysis of single-cell RNA-seq data by kernel-based similarity learning 5 | Home-page: https://github.com/bowang87/SIMLR-PY 6 | Author: Bo Wang 7 | Author-email: bowang87@stanford.edu 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | SIMLR/__init__.py 3 | SIMLR/core.py 4 | SIMLR/helper.py 5 | SIMLR.egg-info/PKG-INFO 6 | SIMLR.egg-info/SOURCES.txt 7 | SIMLR.egg-info/dependency_links.txt 8 | SIMLR.egg-info/top_level.txt -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | SIMLR 2 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import SIMLR_LARGE 2 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/SIMLR/__init__.pyc -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/SIMLR/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for large scale SIMLR and accuracy checks 3 | 4 | --------------------------------------------------------------------- 5 | 6 | This module contains the following functions: 7 | 8 | save_sparse_csr 9 | save a sparse csr format input of single-cell RNA-seq data 10 | load_sparse_csr 11 | load a sparse csr format input of single-cell RNA-seq data 12 | nearest_neighbor_search 13 | Approximate Nearset Neighbor search for every cell 14 | NE_dn 15 | Row-normalization of a matrix 16 | mex_L2_distance 
17 | A fast way to calculate L2 distance 18 | Cal_distance_memory 19 | Calculate Kernels in a memory-saving mode 20 | mex_multipleK 21 | A fast way to calculate kernels 22 | Hbeta 23 | A simple LP method to solve linear weight 24 | euclidean_proj_simplex 25 | A fast way to calculate simplex projection 26 | fast_pca 27 | A fast randomized pca with sparse input 28 | fast_minibatch_kmeans 29 | A fast mini-batch version of k-means 30 | SIMLR_Large 31 | A large-scale implementation of our SIMLR 32 | --------------------------------------------------------------------- 33 | 34 | Copyright 2016 Bo Wang, Stanford University. 35 | All rights reserved. 36 | """ 37 | from __future__ import absolute_import 38 | from __future__ import division 39 | from __future__ import print_function 40 | from __future__ import unicode_literals 41 | from . import helper 42 | import numpy as np 43 | import sys 44 | import os 45 | from annoy import AnnoyIndex 46 | import scipy.io as sio 47 | from scipy.sparse import csr_matrix, csc_matrix, linalg 48 | from fbpca import svd, pca 49 | import time 50 | from sklearn.decomposition import TruncatedSVD 51 | from sklearn.cluster import MiniBatchKMeans, KMeans 52 | 53 | class SIMLR_LARGE(object): 54 | """A class for large-scale SIMLR. 55 | 56 | Attributes: 57 | num_of_rank: The rank hyper-parameter in SIMLR usually set to number of clusters. 58 | num_of_neighbors: the number of neighbors kept for each cell to approximate full cell similarities 59 | mode_of_memory: an indicator to open the memory-saving mode. This is helpful for datasets of millions of cells. It will sacrify a bit speed though. 60 | """ 61 | def __init__(self, num_of_rank, num_of_neighbor=30, mode_of_memory = False, max_iter = 5): 62 | self.num_of_rank = int(num_of_rank) 63 | self.num_of_neighbor = int(num_of_neighbor) 64 | self.mode_of_memory = mode_of_memory 65 | self.max_iter = int(max_iter) 66 | 67 | def nearest_neighbor_search(self, GE_csc): 68 | K = self.num_of_neighbor * 2 69 | n,d = GE_csc.shape 70 | t = AnnoyIndex(d) 71 | for i in range(n): 72 | t.add_item(i,GE_csc[i,:]) 73 | t.build(100) 74 | t.save('test.ann') 75 | u = AnnoyIndex(d) 76 | u.load('test.ann') 77 | os.remove('test.ann') 78 | val = np.zeros((n,K)) 79 | ind = np.zeros((n,K)) 80 | for i in range(n): 81 | tmp, tmp1 = u.get_nns_by_item(i,K, include_distances=True) 82 | ind[i,:] = tmp 83 | val[i,:] = tmp1 84 | return ind.astype('int'), val 85 | 86 | 87 | 88 | 89 | def mex_L2_distance(self, F, ind): 90 | m,n = ind.shape 91 | I = np.tile(np.arange(m), n) 92 | if self.mode_of_memory: 93 | temp = np.zeros((m,n)) 94 | for i in range(n): 95 | temptemp = np.take(F, np.arange(m), axis = 0) - np.take(F, ind[:,i],axis=0) 96 | temp[:,i] = (temptemp*temptemp).sum(axis=1) 97 | return temp 98 | else: 99 | temp = np.take(F, I, axis = 0) - np.take(F, ind.ravel(order = 'F'),axis=0) 100 | temp = (temp*temp).sum(axis=1) 101 | return temp.reshape((m,n),order = 'F') 102 | 103 | 104 | 105 | def Cal_distance_memory(self, S, alpha): 106 | NT = len(alpha) 107 | DD = alpha.copy() 108 | for i in range(NT): 109 | temp = np.load('Kernel_' + str(i)+'.npy') 110 | if i == 0: 111 | distX = alpha[0]*temp 112 | else: 113 | distX += alpha[i]*temp 114 | DD[i] = ((temp*S).sum(axis = 0)/(S.shape[0]+0.0)).mean(axis = 0) 115 | alphaK0 = helper.umkl_bo(DD, 1.0/len(DD)); 116 | alphaK0 = alphaK0/np.sum(alphaK0) 117 | return distX, alphaK0 118 | 119 | 120 | def mex_multipleK(self, val, ind): 121 | #val *= val 122 | KK = self.num_of_neighbor 123 | ismemory = self.mode_of_memory 124 | 
m,n=val.shape 125 | sigma = np.arange(1,2.1,.25) 126 | allK = (np.arange(np.ceil(KK/2.0), min(n,np.ceil(KK*1.5))+1, np.ceil(KK/10.0))).astype('int') 127 | if ismemory: 128 | D_kernels = [] 129 | alphaK = np.ones(len(allK)*len(sigma))/(0.0 + len(allK)*len(sigma)) 130 | else: 131 | D_kernels = np.zeros((m,n,len(allK)*len(sigma))) 132 | alphaK = np.ones(D_kernels.shape[2])/(0.0 + D_kernels.shape[2]) 133 | t = 0; 134 | for k in allK: 135 | temp = val[:,np.arange(k)].sum(axis=1)/(k+0.0) 136 | temp0 = .5*(temp[:,np.newaxis] + np.take(temp,ind)) 137 | temp = val/temp0 138 | temp*=temp 139 | for s in sigma: 140 | temp1 = np.exp(-temp/2.0/s/s)/np.sqrt(2*np.pi)/s/temp0 141 | temptemp = temp1[:, 0] 142 | temp1[:] = .5*(temptemp[:,np.newaxis] + temptemp[ind]) - temp1 143 | if ismemory: 144 | np.save('Kernel_' + str(t), temp1 - temp1.min()) 145 | else: 146 | D_kernels[:,:,t] = temp1 - temp1.min() 147 | t = t+1 148 | 149 | return D_kernels, alphaK 150 | 151 | 152 | def fast_eigens(self, val, ind): 153 | n,d = val.shape 154 | rows = np.tile(np.arange(n), d) 155 | cols = ind.ravel(order='F') 156 | A = csr_matrix((val.ravel(order='F'),(rows,cols)),shape = (n, n)) + csr_matrix((val.ravel(order='F'),(cols,rows)),shape = (n, n)) 157 | (d,V) = linalg.eigsh(A,self.num_of_rank,which='LM') 158 | d = -np.sort(-np.real(d)) 159 | return np.real(V),d/np.max(abs(d)) 160 | def fast_minibatch_kmeans(self, X,C): 161 | batchsize = int(min(1000, np.round(X.shape[0]/C/C))) 162 | cls = MiniBatchKMeans(init='k-means++',n_clusters=C, batch_size = batchsize, n_init = 100, max_iter = 100) 163 | return cls.fit_predict(X) 164 | 165 | def fit(self, X, beta = 0.8): 166 | K = self.num_of_neighbor 167 | is_memory = self.mode_of_memory 168 | c = self.num_of_rank 169 | NITER = self.max_iter 170 | n,d = X.shape 171 | if d > 500: 172 | print('SIMLR highly recommends you to perform PCA first on the data\n'); 173 | print('Please use the in-line function fast_pca on your input\n'); 174 | ind, val = self.nearest_neighbor_search(X) 175 | del X 176 | D_Kernels, alphaK = self.mex_multipleK(val, ind) 177 | del val 178 | if is_memory: 179 | distX,alphaK0 = self.Cal_distance_memory(np.ones((ind.shape[0], ind.shape[1])), alphaK) 180 | else: 181 | distX = D_Kernels.dot(alphaK) 182 | rr = (.5*(K*distX[:,K+2] - distX[:,np.arange(1,K+1)].sum(axis = 1))).mean() 183 | lambdar = rr 184 | S0 = distX.max() - distX 185 | S0[:] = helper.NE_dn(S0) 186 | F, evalues = self.fast_eigens(S0.copy(), ind.copy()) 187 | F = helper.NE_dn(F) 188 | F *= (1-beta)*d/(1-beta*d*d); 189 | F0 = F.copy() 190 | for iter in range(NITER): 191 | FF = self.mex_L2_distance(F, ind) 192 | FF[:] = (distX + lambdar*FF)/2.0/rr 193 | FF[:] = helper.euclidean_proj_simplex(-FF) 194 | S0[:] = (1-beta)*S0 + beta*FF 195 | 196 | F[:], evalues = self.fast_eigens(S0, ind) 197 | F *= (1-beta)*d/(1-beta*d*d); 198 | F[:] = helper.NE_dn(F) 199 | F[:] = (1-beta)*F0 + beta*F 200 | F0 = F.copy() 201 | lambdar = lambdar * 1.5 202 | rr = rr / 1.05 203 | if is_memory: 204 | distX, alphaK0 = self.Cal_distance_memory(S0, alphaK) 205 | alphaK = (1-beta)*alphaK + beta*alphaK0 206 | alphaK = alphaK/np.sum(alphaK) 207 | 208 | else: 209 | DD = ((D_Kernels*S0[:,:,np.newaxis]).sum(axis = 0)/(D_Kernels.shape[0]+0.0)).mean(axis = 0) 210 | alphaK0 = helper.umkl_bo(DD, 1.0/len(DD)); 211 | alphaK0 = alphaK0/np.sum(alphaK0) 212 | alphaK = (1-beta)*alphaK + beta*alphaK0 213 | alphaK = alphaK/np.sum(alphaK) 214 | distX = D_Kernels.dot(alphaK) 215 | 216 | if is_memory: 217 | for i in range(len(alphaK)): 218 | 
os.remove('Kernel_' + str(i) + '.npy') 219 | rows = np.tile(np.arange(n), S0.shape[1]) 220 | cols = ind.ravel(order='F') 221 | val = S0 222 | S0 = csr_matrix((S0.ravel(order='F'),(rows,cols)),shape = (n, n)) + csr_matrix((S0.ravel(order='F'),(cols,rows)), shape = (n, n)) 223 | return S0, F, val, ind 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/core.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/SIMLR/core.pyc -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | from scipy.sparse import csr_matrix, csc_matrix, linalg 8 | from sklearn.decomposition import TruncatedSVD 9 | from fbpca import pca 10 | import time 11 | 12 | 13 | def save_sparse_csr(filename,array, label=[]): 14 | np.savez(filename,data = array.data ,indices=array.indices, 15 | indptr =array.indptr, shape=array.shape, label = label ) 16 | 17 | def load_sparse_csr(filename): 18 | loader = np.load(filename) 19 | if 'label' in loader.keys(): 20 | label = loader['label'] 21 | else: 22 | label = [] 23 | return csr_matrix(( np.log10(1.0+loader['data']), loader['indices'], loader['indptr']), 24 | shape = loader['shape']), label 25 | 26 | def NE_dn(A, type='ave'): 27 | m , n = A.shape 28 | diags = np.ones(m) 29 | diags[:] = abs(A).sum(axis=1).flatten() 30 | if type == 'ave': 31 | D = 1/(diags+np.finfo(float).eps) 32 | return A*D[:,np.newaxis] 33 | elif type == 'gph': 34 | D = 1/np.sqrt(diags+np.finfo(float).eps) 35 | return (A*D[:,np.newaxis])*D 36 | 37 | def Hbeta(D,beta): 38 | D = (D-D.min())/(D.max() - D.min() + np.finfo(float).eps) 39 | P = np.exp(-D*beta) 40 | sumP = P.sum() 41 | H = np.log(sumP) + beta*sum(D*P)/sumP 42 | P = P / sumP 43 | return H, P 44 | 45 | 46 | 47 | def umkl_bo(D, beta): 48 | tol = 1e-4 49 | u = 20 50 | logU = np.log(u) 51 | H, P = Hbeta(D,beta) 52 | betamin = -np.inf 53 | betamax = np.inf 54 | Hdiff = H - logU 55 | tries = 0 56 | while(abs(Hdiff)>tol)&(tries < 30): 57 | if Hdiff>0: 58 | betamin = beta 59 | if np.isinf(betamax): 60 | beta *= 2.0 61 | else: 62 | beta = .5*(beta + betamax) 63 | else: 64 | betamax = beta 65 | if np.isinf(betamin): 66 | beta /= 2.0 67 | else: 68 | beta = .5*(beta + betamin) 69 | H, P = Hbeta(D,beta) 70 | Hdiff = H - logU 71 | tries +=1 72 | return P 73 | 74 | def euclidean_proj_simplex(v, s=1): 75 | """ Compute the Euclidean projection on a positive simplex 76 | Solves the optimisation problem (using the algorithm from [1]): 77 | min_w 0.5 * || w - v ||_2^2 , s.t. 
\sum_i w_i = s, w_i >= 0 78 | Parameters 79 | ---------- 80 | v: (n,) numpy array, 81 | n-dimensional vector to project 82 | s: int, optional, default: 1, 83 | radius of the simplex 84 | Returns 85 | ------- 86 | w: (n,) numpy array, 87 | Euclidean projection of v on the simplex 88 | """ 89 | assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s 90 | n,d = v.shape # will raise ValueError if v is not 1-D 91 | # get the array of cumulative sums of a sorted (decreasing) copy of v 92 | 93 | v -= (v.mean(axis = 1)[:,np.newaxis]-1.0/d) 94 | u = -np.sort(-v) 95 | cssv = np.cumsum(u,axis = 1) 96 | # get the number of > 0 components of the optimal solution 97 | temp = u * np.arange(1,d+1) - cssv +s 98 | temp[temp<0] = 'nan' 99 | rho = np.nanargmin(temp,axis = 1) 100 | #rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1] 101 | # compute the Lagrange multiplier associated to the simplex constraint 102 | theta = (cssv[np.arange(n), rho] - s) / (rho + 1.0) 103 | # compute the projection by thresholding v using theta 104 | w = (v - theta[:,np.newaxis]).clip(min=0) 105 | return w 106 | 107 | 108 | def fast_pca(in_X, no_dim): 109 | (U, s, Va) = pca(in_X, no_dim, False, 8) 110 | del Va 111 | U[:] = U*np.sqrt(np.abs(s)) 112 | D = 1/(np.sqrt(np.sum(U*U,axis = 1)+np.finfo(float).eps)+np.finfo(float).eps) 113 | return U*D[:,np.newaxis] 114 | 115 | 116 | -------------------------------------------------------------------------------- /SIMLR_PY/SIMLR/helper.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/SIMLR/helper.pyc -------------------------------------------------------------------------------- /SIMLR_PY/build/lib.linux-x86_64-2.7/SIMLR/__init__.py: -------------------------------------------------------------------------------- 1 | from core import SIMLR_LARGE 2 | -------------------------------------------------------------------------------- /SIMLR_PY/build/lib.linux-x86_64-2.7/SIMLR/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for large scale SIMLR and accuracy checks 3 | 4 | --------------------------------------------------------------------- 5 | 6 | This module contains the following functions: 7 | 8 | save_sparse_csr 9 | save a sparse csr format input of single-cell RNA-seq data 10 | load_sparse_csr 11 | load a sparse csr format input of single-cell RNA-seq data 12 | nearest_neighbor_search 13 | Approximate Nearset Neighbor search for every cell 14 | NE_dn 15 | Row-normalization of a matrix 16 | mex_L2_distance 17 | A fast way to calculate L2 distance 18 | Cal_distance_memory 19 | Calculate Kernels in a memory-saving mode 20 | mex_multipleK 21 | A fast way to calculate kernels 22 | Hbeta 23 | A simple LP method to solve linear weight 24 | euclidean_proj_simplex 25 | A fast way to calculate simplex projection 26 | fast_pca 27 | A fast randomized pca with sparse input 28 | fast_minibatch_kmeans 29 | A fast mini-batch version of k-means 30 | SIMLR_Large 31 | A large-scale implementation of our SIMLR 32 | --------------------------------------------------------------------- 33 | 34 | Copyright 2016 Bo Wang, Stanford University. 35 | All rights reserved. 36 | """ 37 | from __future__ import absolute_import 38 | from __future__ import division 39 | from __future__ import print_function 40 | from __future__ import unicode_literals 41 | from . 
import helper 42 | import numpy as np 43 | import sys 44 | import os 45 | from annoy import AnnoyIndex 46 | import scipy.io as sio 47 | from scipy.sparse import csr_matrix, csc_matrix, linalg 48 | from fbpca import svd, pca 49 | import time 50 | from sklearn.decomposition import TruncatedSVD 51 | from sklearn.cluster import MiniBatchKMeans, KMeans 52 | 53 | class SIMLR_LARGE(object): 54 | """A class for large-scale SIMLR. 55 | 56 | Attributes: 57 | num_of_rank: The rank hyper-parameter in SIMLR usually set to number of clusters. 58 | num_of_neighbors: the number of neighbors kept for each cell to approximate full cell similarities 59 | mode_of_memory: an indicator to open the memory-saving mode. This is helpful for datasets of millions of cells. It will sacrify a bit speed though. 60 | """ 61 | def __init__(self, num_of_rank, num_of_neighbor=30, mode_of_memory = False, max_iter = 5): 62 | self.num_of_rank = int(num_of_rank) 63 | self.num_of_neighbor = int(num_of_neighbor) 64 | self.mode_of_memory = mode_of_memory 65 | self.max_iter = int(max_iter) 66 | 67 | def nearest_neighbor_search(self, GE_csc): 68 | K = self.num_of_neighbor * 2 69 | n,d = GE_csc.shape 70 | t = AnnoyIndex(d) 71 | for i in xrange(n): 72 | t.add_item(i,GE_csc[i,:]) 73 | t.build(100) 74 | t.save('test.ann') 75 | u = AnnoyIndex(d) 76 | u.load('test.ann') 77 | os.remove('test.ann') 78 | val = np.zeros((n,K)) 79 | ind = np.zeros((n,K)) 80 | for i in xrange(n): 81 | tmp, tmp1 = u.get_nns_by_item(i,K, include_distances=True) 82 | ind[i,:] = tmp 83 | val[i,:] = tmp1 84 | return ind.astype('int'), val 85 | 86 | 87 | 88 | 89 | def mex_L2_distance(self, F, ind): 90 | m,n = ind.shape 91 | I = np.tile(np.arange(m), n) 92 | if self.mode_of_memory: 93 | temp = np.zeros((m,n)) 94 | for i in range(n): 95 | temptemp = np.take(F, np.arange(m), axis = 0) - np.take(F, ind[:,i],axis=0) 96 | temp[:,i] = (temptemp*temptemp).sum(axis=1) 97 | return temp 98 | else: 99 | temp = np.take(F, I, axis = 0) - np.take(F, ind.ravel(order = 'F'),axis=0) 100 | temp = (temp*temp).sum(axis=1) 101 | return temp.reshape((m,n),order = 'F') 102 | 103 | 104 | 105 | def Cal_distance_memory(self, S, alpha): 106 | NT = len(alpha) 107 | DD = alpha.copy() 108 | for i in xrange(NT): 109 | temp = np.load('Kernel_' + str(i)+'.npy') 110 | if i == 0: 111 | distX = alpha[0]*temp 112 | else: 113 | distX += alpha[i]*temp 114 | DD[i] = ((temp*S).sum(axis = 0)/(S.shape[0]+0.0)).mean(axis = 0) 115 | alphaK0 = helper.umkl_bo(DD, 1.0/len(DD)); 116 | alphaK0 = alphaK0/np.sum(alphaK0) 117 | return distX, alphaK0 118 | 119 | 120 | def mex_multipleK(self, val, ind): 121 | #val *= val 122 | KK = self.num_of_neighbor 123 | ismemory = self.mode_of_memory 124 | m,n=val.shape 125 | sigma = np.arange(1,2.1,.25) 126 | allK = (np.arange(np.ceil(KK/2.0), min(n,np.ceil(KK*1.5))+1, np.ceil(KK/10.0))).astype('int') 127 | if ismemory: 128 | D_kernels = [] 129 | alphaK = np.ones(len(allK)*len(sigma))/(0.0 + len(allK)*len(sigma)) 130 | else: 131 | D_kernels = np.zeros((m,n,len(allK)*len(sigma))) 132 | alphaK = np.ones(D_kernels.shape[2])/(0.0 + D_kernels.shape[2]) 133 | t = 0; 134 | for k in allK: 135 | temp = val[:,np.arange(k)].sum(axis=1)/(k+0.0) 136 | temp0 = .5*(temp[:,np.newaxis] + np.take(temp,ind)) 137 | temp = val/temp0 138 | temp*=temp 139 | for s in sigma: 140 | temp1 = np.exp(-temp/2.0/s/s)/np.sqrt(2*np.pi)/s/temp0 141 | temptemp = temp1[:, 0] 142 | temp1[:] = .5*(temptemp[:,np.newaxis] + temptemp[ind]) - temp1 143 | if ismemory: 144 | np.save('Kernel_' + str(t), temp1 - temp1.min()) 145 
| else: 146 | D_kernels[:,:,t] = temp1 - temp1.min() 147 | t = t+1 148 | 149 | return D_kernels, alphaK 150 | 151 | 152 | def fast_eigens(self, val, ind): 153 | n,d = val.shape 154 | rows = np.tile(np.arange(n), d) 155 | cols = ind.ravel(order='F') 156 | A = csr_matrix((val.ravel(order='F'),(rows,cols)),shape = (n, n)) + csr_matrix((val.ravel(order='F'),(cols,rows)),shape = (n, n)) 157 | (d,V) = linalg.eigsh(A,self.num_of_rank,which='LM') 158 | d = -np.sort(-np.real(d)) 159 | return np.real(V),d/np.max(abs(d)) 160 | def fast_minibatch_kmeans(self, X,C): 161 | batchsize = int(min(1000, np.round(X.shape[0]/C/C))) 162 | cls = MiniBatchKMeans(init='k-means++',n_clusters=C, batch_size = batchsize, n_init = 100, max_iter = 100) 163 | return cls.fit_predict(X) 164 | 165 | def fit(self, X, beta = 0.8): 166 | K = self.num_of_neighbor 167 | is_memory = self.mode_of_memory 168 | c = self.num_of_rank 169 | NITER = self.max_iter 170 | n,d = X.shape 171 | if d > 500: 172 | print('SIMLR highly recommends you to perform PCA first on the data\n'); 173 | print('Please use the in-line function fast_pca on your input\n'); 174 | ind, val = self.nearest_neighbor_search(X) 175 | del X 176 | D_Kernels, alphaK = self.mex_multipleK(val, ind) 177 | del val 178 | if is_memory: 179 | distX,alphaK0 = self.Cal_distance_memory(np.ones((ind.shape[0], ind.shape[1])), alphaK) 180 | else: 181 | distX = D_Kernels.dot(alphaK) 182 | rr = (.5*(K*distX[:,K+2] - distX[:,np.arange(1,K+1)].sum(axis = 1))).mean() 183 | lambdar = rr 184 | S0 = distX.max() - distX 185 | S0[:] = helper.NE_dn(S0) 186 | F, evalues = self.fast_eigens(S0.copy(), ind.copy()) 187 | F = helper.NE_dn(F) 188 | F *= (1-beta)*d/(1-beta*d*d); 189 | F0 = F.copy() 190 | for iter in range(NITER): 191 | FF = self.mex_L2_distance(F, ind) 192 | FF[:] = (distX + lambdar*FF)/2.0/rr 193 | FF[:] = helper.euclidean_proj_simplex(-FF) 194 | S0[:] = (1-beta)*S0 + beta*FF 195 | 196 | F[:], evalues = self.fast_eigens(S0, ind) 197 | F *= (1-beta)*d/(1-beta*d*d); 198 | F[:] = helper.NE_dn(F) 199 | F[:] = (1-beta)*F0 + beta*F 200 | F0 = F.copy() 201 | lambdar = lambdar * 1.5 202 | rr = rr / 1.05 203 | if is_memory: 204 | distX, alphaK0 = self.Cal_distance_memory(S0, alphaK) 205 | alphaK = (1-beta)*alphaK + beta*alphaK0 206 | alphaK = alphaK/np.sum(alphaK) 207 | 208 | else: 209 | DD = ((D_Kernels*S0[:,:,np.newaxis]).sum(axis = 0)/(D_Kernels.shape[0]+0.0)).mean(axis = 0) 210 | alphaK0 = helper.umkl_bo(DD, 1.0/len(DD)); 211 | alphaK0 = alphaK0/np.sum(alphaK0) 212 | alphaK = (1-beta)*alphaK + beta*alphaK0 213 | alphaK = alphaK/np.sum(alphaK) 214 | distX = D_Kernels.dot(alphaK) 215 | 216 | if is_memory: 217 | for i in xrange(len(alphaK)): 218 | os.remove('Kernel_' + str(i) + '.npy') 219 | rows = np.tile(np.arange(n), S0.shape[1]) 220 | cols = ind.ravel(order='F') 221 | val = S0 222 | S0 = csr_matrix((S0.ravel(order='F'),(rows,cols)),shape = (n, n)) + csr_matrix((S0.ravel(order='F'),(cols,rows)), shape = (n, n)) 223 | return S0, F, val, ind 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /SIMLR_PY/build/lib.linux-x86_64-2.7/SIMLR/helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | from scipy.sparse import csr_matrix, csc_matrix, linalg 8 | from sklearn.decomposition import TruncatedSVD 9 | from fbpca 
import pca 10 | import time 11 | 12 | 13 | def save_sparse_csr(filename,array, label=[]): 14 | np.savez(filename,data = array.data ,indices=array.indices, 15 | indptr =array.indptr, shape=array.shape, label = label ) 16 | 17 | def load_sparse_csr(filename): 18 | loader = np.load(filename) 19 | if 'label' in loader.keys(): 20 | label = loader['label'] 21 | else: 22 | label = [] 23 | return csr_matrix(( np.log10(1.0+loader['data']), loader['indices'], loader['indptr']), 24 | shape = loader['shape']), label 25 | 26 | def NE_dn(A, type='ave'): 27 | m , n = A.shape 28 | diags = np.ones(m) 29 | diags[:] = abs(A).sum(axis=1).flatten() 30 | if type == 'ave': 31 | D = 1/(diags+np.finfo(float).eps) 32 | return A*D[:,np.newaxis] 33 | elif type == 'gph': 34 | D = 1/np.sqrt(diags+np.finfo(float).eps) 35 | return (A*D[:,np.newaxis])*D 36 | 37 | def Hbeta(D,beta): 38 | D = (D-D.min())/(D.max() - D.min() + np.finfo(float).eps) 39 | P = np.exp(-D*beta) 40 | sumP = P.sum() 41 | H = np.log(sumP) + beta*sum(D*P)/sumP 42 | P = P / sumP 43 | return H, P 44 | 45 | 46 | 47 | def umkl_bo(D, beta): 48 | tol = 1e-4 49 | u = 20 50 | logU = np.log(u) 51 | H, P = Hbeta(D,beta) 52 | betamin = -np.inf 53 | betamax = np.inf 54 | Hdiff = H - logU 55 | tries = 0 56 | while(abs(Hdiff)>tol)&(tries < 30): 57 | if Hdiff>0: 58 | betamin = beta 59 | if np.isinf(betamax): 60 | beta *= 2.0 61 | else: 62 | beta = .5*(beta + betamax) 63 | else: 64 | betamax = beta 65 | if np.isinf(betamin): 66 | beta /= 2.0 67 | else: 68 | beta = .5*(beta + betamin) 69 | H, P = Hbeta(D,beta) 70 | Hdiff = H - logU 71 | tries +=1 72 | return P 73 | 74 | def euclidean_proj_simplex(v, s=1): 75 | """ Compute the Euclidean projection on a positive simplex 76 | Solves the optimisation problem (using the algorithm from [1]): 77 | min_w 0.5 * || w - v ||_2^2 , s.t. 
\sum_i w_i = s, w_i >= 0 78 | Parameters 79 | ---------- 80 | v: (n,) numpy array, 81 | n-dimensional vector to project 82 | s: int, optional, default: 1, 83 | radius of the simplex 84 | Returns 85 | ------- 86 | w: (n,) numpy array, 87 | Euclidean projection of v on the simplex 88 | """ 89 | assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s 90 | n,d = v.shape # will raise ValueError if v is not 1-D 91 | # get the array of cumulative sums of a sorted (decreasing) copy of v 92 | 93 | v -= (v.mean(axis = 1)[:,np.newaxis]-1.0/d) 94 | u = -np.sort(-v) 95 | cssv = np.cumsum(u,axis = 1) 96 | # get the number of > 0 components of the optimal solution 97 | temp = u * np.arange(1,d+1) - cssv +s 98 | temp[temp<0] = 'nan' 99 | rho = np.nanargmin(temp,axis = 1) 100 | #rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1] 101 | # compute the Lagrange multiplier associated to the simplex constraint 102 | theta = (cssv[np.arange(n), rho] - s) / (rho + 1.0) 103 | # compute the projection by thresholding v using theta 104 | w = (v - theta[:,np.newaxis]).clip(min=0) 105 | return w 106 | 107 | 108 | def fast_pca(in_X, no_dim): 109 | (U, s, Va) = pca(csc_matrix(in_X), no_dim, False, 8) 110 | del Va 111 | U[:] = U*np.sqrt(np.abs(s)) 112 | D = 1/(np.sqrt(np.sum(U*U,axis = 1)+np.finfo(float).eps)+np.finfo(float).eps) 113 | return U*D[:,np.newaxis] 114 | 115 | 116 | -------------------------------------------------------------------------------- /SIMLR_PY/build/lib.linux-x86_64-2.7/SIMLR/helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | from scipy.sparse import csr_matrix, csc_matrix, linalg 8 | from fbpca import svd, pca 9 | import time 10 | 11 | 12 | def save_sparse_csr(filename,array, label=[]): 13 | np.savez(filename,data = array.data ,indices=array.indices, 14 | indptr =array.indptr, shape=array.shape, label = label ) 15 | 16 | def load_sparse_csr(filename): 17 | loader = np.load(filename) 18 | if 'label' in loader.keys(): 19 | label = loader['label'] 20 | else: 21 | label = [] 22 | return csr_matrix(( np.log10(1.0+loader['data']), loader['indices'], loader['indptr']), 23 | shape = loader['shape']), label 24 | 25 | def NE_dn(A, type='ave'): 26 | m , n = A.shape 27 | diags = np.ones(m) 28 | diags[:] = abs(A).sum(axis=1).flatten() 29 | if type == 'ave': 30 | D = 1/(diags+np.finfo(float).eps) 31 | return A*D[:,np.newaxis] 32 | elif type == 'gph': 33 | D = 1/np.sqrt(diags+np.finfo(float).eps) 34 | return (A*D[:,np.newaxis])*D 35 | 36 | def Hbeta(D,beta): 37 | D = (D-D.min())/(D.max() - D.min() + np.finfo(float).eps) 38 | P = np.exp(-D*beta) 39 | sumP = P.sum() 40 | H = np.log(sumP) + beta*sum(D*P)/sumP 41 | P = P / sumP 42 | return H, P 43 | 44 | 45 | 46 | def umkl_bo(D, beta): 47 | tol = 1e-4 48 | u = 20 49 | logU = np.log(u) 50 | H, P = Hbeta(D,beta) 51 | betamin = -np.inf 52 | betamax = np.inf 53 | Hdiff = H - logU 54 | tries = 0 55 | while(abs(Hdiff)>tol)&(tries < 30): 56 | if Hdiff>0: 57 | betamin = beta 58 | if np.isinf(betamax): 59 | beta *= 2.0 60 | else: 61 | beta = .5*(beta + betamax) 62 | else: 63 | betamax = beta 64 | if np.isinf(betamin): 65 | beta /= 2.0 66 | else: 67 | beta = .5*(beta + betamin) 68 | H, P = Hbeta(D,beta) 69 | Hdiff = H - logU 70 | tries +=1 71 | return P 72 | 73 | def euclidean_proj_simplex(v, s=1): 74 | """ Compute the 
Euclidean projection on a positive simplex 75 | Solves the optimisation problem (using the algorithm from [1]): 76 | min_w 0.5 * || w - v ||_2^2 , s.t. \sum_i w_i = s, w_i >= 0 77 | Parameters 78 | ---------- 79 | v: (n,) numpy array, 80 | n-dimensional vector to project 81 | s: int, optional, default: 1, 82 | radius of the simplex 83 | Returns 84 | ------- 85 | w: (n,) numpy array, 86 | Euclidean projection of v on the simplex 87 | """ 88 | assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s 89 | n,d = v.shape # will raise ValueError if v is not 1-D 90 | # get the array of cumulative sums of a sorted (decreasing) copy of v 91 | 92 | v -= (v.mean(axis = 1)[:,np.newaxis]-1.0/d) 93 | u = -np.sort(-v) 94 | cssv = np.cumsum(u,axis = 1) 95 | # get the number of > 0 components of the optimal solution 96 | temp = u * np.arange(1,d+1) - cssv +s 97 | temp[temp<0] = 'nan' 98 | rho = np.nanargmin(temp,axis = 1) 99 | #rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1] 100 | # compute the Lagrange multiplier associated to the simplex constraint 101 | theta = (cssv[np.arange(n), rho] - s) / (rho + 1.0) 102 | # compute the projection by thresholding v using theta 103 | w = (v - theta[:,np.newaxis]).clip(min=0) 104 | return w 105 | 106 | 107 | def fast_pca(in_X, no_dim): 108 | (U, s, Va) = pca(csc_matrix(in_X), no_dim, True, 5) 109 | del Va 110 | U[:] = U*np.sqrt(np.abs(s)) 111 | D = 1/(np.sqrt(np.sum(U*U,axis = 1)+np.finfo(float).eps)+np.finfo(float).eps) 112 | return U*D[:,np.newaxis] 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /SIMLR_PY/build/lib/SIMLR/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import SIMLR_LARGE 2 | -------------------------------------------------------------------------------- /SIMLR_PY/build/lib/SIMLR/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for large scale SIMLR and accuracy checks 3 | 4 | --------------------------------------------------------------------- 5 | 6 | This module contains the following functions: 7 | 8 | save_sparse_csr 9 | save a sparse csr format input of single-cell RNA-seq data 10 | load_sparse_csr 11 | load a sparse csr format input of single-cell RNA-seq data 12 | nearest_neighbor_search 13 | Approximate Nearset Neighbor search for every cell 14 | NE_dn 15 | Row-normalization of a matrix 16 | mex_L2_distance 17 | A fast way to calculate L2 distance 18 | Cal_distance_memory 19 | Calculate Kernels in a memory-saving mode 20 | mex_multipleK 21 | A fast way to calculate kernels 22 | Hbeta 23 | A simple LP method to solve linear weight 24 | euclidean_proj_simplex 25 | A fast way to calculate simplex projection 26 | fast_pca 27 | A fast randomized pca with sparse input 28 | fast_minibatch_kmeans 29 | A fast mini-batch version of k-means 30 | SIMLR_Large 31 | A large-scale implementation of our SIMLR 32 | --------------------------------------------------------------------- 33 | 34 | Copyright 2016 Bo Wang, Stanford University. 35 | All rights reserved. 36 | """ 37 | from __future__ import absolute_import 38 | from __future__ import division 39 | from __future__ import print_function 40 | from __future__ import unicode_literals 41 | from . 
import helper 42 | import numpy as np 43 | import sys 44 | import os 45 | from annoy import AnnoyIndex 46 | import scipy.io as sio 47 | from scipy.sparse import csr_matrix, csc_matrix, linalg 48 | from fbpca import svd, pca 49 | import time 50 | from sklearn.decomposition import TruncatedSVD 51 | from sklearn.cluster import MiniBatchKMeans, KMeans 52 | 53 | class SIMLR_LARGE(object): 54 | """A class for large-scale SIMLR. 55 | 56 | Attributes: 57 | num_of_rank: The rank hyper-parameter in SIMLR usually set to number of clusters. 58 | num_of_neighbors: the number of neighbors kept for each cell to approximate full cell similarities 59 | mode_of_memory: an indicator to open the memory-saving mode. This is helpful for datasets of millions of cells. It will sacrify a bit speed though. 60 | """ 61 | def __init__(self, num_of_rank, num_of_neighbor=30, mode_of_memory = False, max_iter = 5): 62 | self.num_of_rank = int(num_of_rank) 63 | self.num_of_neighbor = int(num_of_neighbor) 64 | self.mode_of_memory = mode_of_memory 65 | self.max_iter = int(max_iter) 66 | 67 | def nearest_neighbor_search(self, GE_csc): 68 | K = self.num_of_neighbor * 2 69 | n,d = GE_csc.shape 70 | t = AnnoyIndex(d) 71 | for i in range(n): 72 | t.add_item(i,GE_csc[i,:]) 73 | t.build(100) 74 | t.save('test.ann') 75 | u = AnnoyIndex(d) 76 | u.load('test.ann') 77 | os.remove('test.ann') 78 | val = np.zeros((n,K)) 79 | ind = np.zeros((n,K)) 80 | for i in range(n): 81 | tmp, tmp1 = u.get_nns_by_item(i,K, include_distances=True) 82 | ind[i,:] = tmp 83 | val[i,:] = tmp1 84 | return ind.astype('int'), val 85 | 86 | 87 | 88 | 89 | def mex_L2_distance(self, F, ind): 90 | m,n = ind.shape 91 | I = np.tile(np.arange(m), n) 92 | if self.mode_of_memory: 93 | temp = np.zeros((m,n)) 94 | for i in range(n): 95 | temptemp = np.take(F, np.arange(m), axis = 0) - np.take(F, ind[:,i],axis=0) 96 | temp[:,i] = (temptemp*temptemp).sum(axis=1) 97 | return temp 98 | else: 99 | temp = np.take(F, I, axis = 0) - np.take(F, ind.ravel(order = 'F'),axis=0) 100 | temp = (temp*temp).sum(axis=1) 101 | return temp.reshape((m,n),order = 'F') 102 | 103 | 104 | 105 | def Cal_distance_memory(self, S, alpha): 106 | NT = len(alpha) 107 | DD = alpha.copy() 108 | for i in range(NT): 109 | temp = np.load('Kernel_' + str(i)+'.npy') 110 | if i == 0: 111 | distX = alpha[0]*temp 112 | else: 113 | distX += alpha[i]*temp 114 | DD[i] = ((temp*S).sum(axis = 0)/(S.shape[0]+0.0)).mean(axis = 0) 115 | alphaK0 = helper.umkl_bo(DD, 1.0/len(DD)); 116 | alphaK0 = alphaK0/np.sum(alphaK0) 117 | return distX, alphaK0 118 | 119 | 120 | def mex_multipleK(self, val, ind): 121 | #val *= val 122 | KK = self.num_of_neighbor 123 | ismemory = self.mode_of_memory 124 | m,n=val.shape 125 | sigma = np.arange(1,2.1,.25) 126 | allK = (np.arange(np.ceil(KK/2.0), min(n,np.ceil(KK*1.5))+1, np.ceil(KK/10.0))).astype('int') 127 | if ismemory: 128 | D_kernels = [] 129 | alphaK = np.ones(len(allK)*len(sigma))/(0.0 + len(allK)*len(sigma)) 130 | else: 131 | D_kernels = np.zeros((m,n,len(allK)*len(sigma))) 132 | alphaK = np.ones(D_kernels.shape[2])/(0.0 + D_kernels.shape[2]) 133 | t = 0; 134 | for k in allK: 135 | temp = val[:,np.arange(k)].sum(axis=1)/(k+0.0) 136 | temp0 = .5*(temp[:,np.newaxis] + np.take(temp,ind)) 137 | temp = val/temp0 138 | temp*=temp 139 | for s in sigma: 140 | temp1 = np.exp(-temp/2.0/s/s)/np.sqrt(2*np.pi)/s/temp0 141 | temptemp = temp1[:, 0] 142 | temp1[:] = .5*(temptemp[:,np.newaxis] + temptemp[ind]) - temp1 143 | if ismemory: 144 | np.save('Kernel_' + str(t), temp1 - temp1.min()) 145 | 
else: 146 | D_kernels[:,:,t] = temp1 - temp1.min() 147 | t = t+1 148 | 149 | return D_kernels, alphaK 150 | 151 | 152 | def fast_eigens(self, val, ind): 153 | n,d = val.shape 154 | rows = np.tile(np.arange(n), d) 155 | cols = ind.ravel(order='F') 156 | A = csr_matrix((val.ravel(order='F'),(rows,cols)),shape = (n, n)) + csr_matrix((val.ravel(order='F'),(cols,rows)),shape = (n, n)) 157 | (d,V) = linalg.eigsh(A,self.num_of_rank,which='LM') 158 | d = -np.sort(-np.real(d)) 159 | return np.real(V),d/np.max(abs(d)) 160 | def fast_minibatch_kmeans(self, X,C): 161 | batchsize = int(min(1000, np.round(X.shape[0]/C/C))) 162 | cls = MiniBatchKMeans(init='k-means++',n_clusters=C, batch_size = batchsize, n_init = 100, max_iter = 100) 163 | return cls.fit_predict(X) 164 | 165 | def fit(self, X, beta = 0.8): 166 | K = self.num_of_neighbor 167 | is_memory = self.mode_of_memory 168 | c = self.num_of_rank 169 | NITER = self.max_iter 170 | n,d = X.shape 171 | if d > 500: 172 | print('SIMLR highly recommends you to perform PCA first on the data\n'); 173 | print('Please use the in-line function fast_pca on your input\n'); 174 | ind, val = self.nearest_neighbor_search(X) 175 | del X 176 | D_Kernels, alphaK = self.mex_multipleK(val, ind) 177 | del val 178 | if is_memory: 179 | distX,alphaK0 = self.Cal_distance_memory(np.ones((ind.shape[0], ind.shape[1])), alphaK) 180 | else: 181 | distX = D_Kernels.dot(alphaK) 182 | rr = (.5*(K*distX[:,K+2] - distX[:,np.arange(1,K+1)].sum(axis = 1))).mean() 183 | lambdar = rr 184 | S0 = distX.max() - distX 185 | S0[:] = helper.NE_dn(S0) 186 | F, evalues = self.fast_eigens(S0.copy(), ind.copy()) 187 | F = helper.NE_dn(F) 188 | F *= (1-beta)*d/(1-beta*d*d); 189 | F0 = F.copy() 190 | for iter in range(NITER): 191 | FF = self.mex_L2_distance(F, ind) 192 | FF[:] = (distX + lambdar*FF)/2.0/rr 193 | FF[:] = helper.euclidean_proj_simplex(-FF) 194 | S0[:] = (1-beta)*S0 + beta*FF 195 | 196 | F[:], evalues = self.fast_eigens(S0, ind) 197 | F *= (1-beta)*d/(1-beta*d*d); 198 | F[:] = helper.NE_dn(F) 199 | F[:] = (1-beta)*F0 + beta*F 200 | F0 = F.copy() 201 | lambdar = lambdar * 1.5 202 | rr = rr / 1.05 203 | if is_memory: 204 | distX, alphaK0 = self.Cal_distance_memory(S0, alphaK) 205 | alphaK = (1-beta)*alphaK + beta*alphaK0 206 | alphaK = alphaK/np.sum(alphaK) 207 | 208 | else: 209 | DD = ((D_Kernels*S0[:,:,np.newaxis]).sum(axis = 0)/(D_Kernels.shape[0]+0.0)).mean(axis = 0) 210 | alphaK0 = helper.umkl_bo(DD, 1.0/len(DD)); 211 | alphaK0 = alphaK0/np.sum(alphaK0) 212 | alphaK = (1-beta)*alphaK + beta*alphaK0 213 | alphaK = alphaK/np.sum(alphaK) 214 | distX = D_Kernels.dot(alphaK) 215 | 216 | if is_memory: 217 | for i in range(len(alphaK)): 218 | os.remove('Kernel_' + str(i) + '.npy') 219 | rows = np.tile(np.arange(n), S0.shape[1]) 220 | cols = ind.ravel(order='F') 221 | val = S0 222 | S0 = csr_matrix((S0.ravel(order='F'),(rows,cols)),shape = (n, n)) + csr_matrix((S0.ravel(order='F'),(cols,rows)), shape = (n, n)) 223 | return S0, F, val, ind 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /SIMLR_PY/build/lib/SIMLR/helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import numpy as np 7 | from scipy.sparse import csr_matrix, csc_matrix, linalg 8 | from sklearn.decomposition import TruncatedSVD 9 | from fbpca import pca 10 | 
import time 11 | 12 | 13 | def save_sparse_csr(filename,array, label=[]): 14 | np.savez(filename,data = array.data ,indices=array.indices, 15 | indptr =array.indptr, shape=array.shape, label = label ) 16 | 17 | def load_sparse_csr(filename): 18 | loader = np.load(filename) 19 | if 'label' in loader.keys(): 20 | label = loader['label'] 21 | else: 22 | label = [] 23 | return csr_matrix(( np.log10(1.0+loader['data']), loader['indices'], loader['indptr']), 24 | shape = loader['shape']), label 25 | 26 | def NE_dn(A, type='ave'): 27 | m , n = A.shape 28 | diags = np.ones(m) 29 | diags[:] = abs(A).sum(axis=1).flatten() 30 | if type == 'ave': 31 | D = 1/(diags+np.finfo(float).eps) 32 | return A*D[:,np.newaxis] 33 | elif type == 'gph': 34 | D = 1/np.sqrt(diags+np.finfo(float).eps) 35 | return (A*D[:,np.newaxis])*D 36 | 37 | def Hbeta(D,beta): 38 | D = (D-D.min())/(D.max() - D.min() + np.finfo(float).eps) 39 | P = np.exp(-D*beta) 40 | sumP = P.sum() 41 | H = np.log(sumP) + beta*sum(D*P)/sumP 42 | P = P / sumP 43 | return H, P 44 | 45 | 46 | 47 | def umkl_bo(D, beta): 48 | tol = 1e-4 49 | u = 20 50 | logU = np.log(u) 51 | H, P = Hbeta(D,beta) 52 | betamin = -np.inf 53 | betamax = np.inf 54 | Hdiff = H - logU 55 | tries = 0 56 | while(abs(Hdiff)>tol)&(tries < 30): 57 | if Hdiff>0: 58 | betamin = beta 59 | if np.isinf(betamax): 60 | beta *= 2.0 61 | else: 62 | beta = .5*(beta + betamax) 63 | else: 64 | betamax = beta 65 | if np.isinf(betamin): 66 | beta /= 2.0 67 | else: 68 | beta = .5*(beta + betamin) 69 | H, P = Hbeta(D,beta) 70 | Hdiff = H - logU 71 | tries +=1 72 | return P 73 | 74 | def euclidean_proj_simplex(v, s=1): 75 | """ Compute the Euclidean projection on a positive simplex 76 | Solves the optimisation problem (using the algorithm from [1]): 77 | min_w 0.5 * || w - v ||_2^2 , s.t. 
\sum_i w_i = s, w_i >= 0 78 | Parameters 79 | ---------- 80 | v: (n,) numpy array, 81 | n-dimensional vector to project 82 | s: int, optional, default: 1, 83 | radius of the simplex 84 | Returns 85 | ------- 86 | w: (n,) numpy array, 87 | Euclidean projection of v on the simplex 88 | """ 89 | assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s 90 | n,d = v.shape # will raise ValueError if v is not 1-D 91 | # get the array of cumulative sums of a sorted (decreasing) copy of v 92 | 93 | v -= (v.mean(axis = 1)[:,np.newaxis]-1.0/d) 94 | u = -np.sort(-v) 95 | cssv = np.cumsum(u,axis = 1) 96 | # get the number of > 0 components of the optimal solution 97 | temp = u * np.arange(1,d+1) - cssv +s 98 | temp[temp<0] = 'nan' 99 | rho = np.nanargmin(temp,axis = 1) 100 | #rho = np.nonzero(u * np.arange(1, n+1) > (cssv - s))[0][-1] 101 | # compute the Lagrange multiplier associated to the simplex constraint 102 | theta = (cssv[np.arange(n), rho] - s) / (rho + 1.0) 103 | # compute the projection by thresholding v using theta 104 | w = (v - theta[:,np.newaxis]).clip(min=0) 105 | return w 106 | 107 | 108 | def fast_pca(in_X, no_dim): 109 | (U, s, Va) = pca(in_X, no_dim, False, 8) 110 | del Va 111 | U[:] = U*np.sqrt(np.abs(s)) 112 | D = 1/(np.sqrt(np.sum(U*U,axis = 1)+np.finfo(float).eps)+np.finfo(float).eps) 113 | return U*D[:,np.newaxis] 114 | 115 | 116 | -------------------------------------------------------------------------------- /SIMLR_PY/dist/SIMLR-0.0.0.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/dist/SIMLR-0.0.0.tar.gz -------------------------------------------------------------------------------- /SIMLR_PY/dist/SIMLR-0.1.0.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/dist/SIMLR-0.1.0.tar.gz -------------------------------------------------------------------------------- /SIMLR_PY/dist/SIMLR-0.1.1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/dist/SIMLR-0.1.1.tar.gz -------------------------------------------------------------------------------- /SIMLR_PY/dist/SIMLR-0.1.3.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/dist/SIMLR-0.1.3.tar.gz -------------------------------------------------------------------------------- /SIMLR_PY/requirements.txt: -------------------------------------------------------------------------------- 1 | fbpca==1.0 2 | numpy==1.8.0 3 | scikit-learn==0.17.1 4 | scipy==0.13.2 5 | annoy==1.8.0 6 | -------------------------------------------------------------------------------- /SIMLR_PY/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2017, Stanford University. 3 | All rights reserved. 
4 | 5 | This source code is a Python implementation of SIMLR for the following paper published in Nature Methods: 6 | Visualization and analysis of single-cell RNA-seq data by kernel-based similarity learning 7 | """ 8 | from distutils.core import setup 9 | 10 | setup( 11 | name='SIMLR', 12 | version='0.1.3', 13 | author='Bo Wang', 14 | author_email='bowang87@stanford.edu', 15 | url='https://github.com/bowang87/SIMLR-PY', 16 | description='Visualization and analysis of single-cell RNA-seq data by kernel-based similarity learning', 17 | packages=['SIMLR'], 18 | install_requires=[ 19 | 'fbpca>=1.0', 20 | 'numpy>=1.8.0', 21 | 'scipy>=0.13.2', 22 | 'annoy>=1.8.0', 23 | 'scikit-learn>=0.17.1', 24 | ], 25 | classifiers=[]) 26 | -------------------------------------------------------------------------------- /SIMLR_PY/tests/Zeisel.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/SIMLR_PY/tests/Zeisel.mat -------------------------------------------------------------------------------- /SIMLR_PY/tests/test_largescale.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import scipy.io as sio 4 | sys.path.insert(0,os.path.abspath('..')) 5 | import time 6 | import numpy as np 7 | import SIMLR 8 | from sklearn.metrics.cluster import normalized_mutual_info_score as nmi 9 | from sklearn.metrics.cluster import adjusted_rand_score as ari 10 | from scipy.sparse import csr_matrix 11 | 12 | 13 | filename = 'Zeisel.mat' 14 | X = csr_matrix(sio.loadmat(filename)['X']) # load the single-cell RNA-seq data 15 | X.data = np.log10(1+X.data) ## take the log transform of the gene counts. This is very important since it makes the data more Gaussian 16 | label = sio.loadmat(filename)['true_labs'] # ground-truth labels used for validation 17 | c = label.max() # number of clusters 18 | ### if the number of genes is larger than 500, we recommend performing PCA first! 19 | print('Start to Run PCA on the RNA-seq data!\n') 20 | start_main = time.time() 21 | if X.shape[1]>500: 22 | X = SIMLR.helper.fast_pca(X,500) 23 | else: 24 | X = X.todense() 25 | print('Successfully Run PCA! PCA took %f seconds in total\n' % (time.time() - start_main)) 26 | print('Start to Run SIMLR!\n') 27 | start_main = time.time() 28 | simlr = SIMLR.SIMLR_LARGE(c, 30, 0) ### This is how we initialize a SIMLR object: the first input is the number of clusters (rank), the second is the number of neighbors, and the third is a binary flag for memory-saving mode. Turn it on when the number of cells is extremely large to save memory, at the cost of efficiency. 29 | S, F,val, ind = simlr.fit(X) 30 | print('Successfully Run SIMLR! 
SIMLR took %f seconds in total\n' % (time.time() - start_main)) 31 | y_pred = simlr.fast_minibatch_kmeans(F,c) 32 | print('NMI value is %f \n' % nmi(y_pred.flatten(),label.flatten())) 33 | print('ARI value is %f \n' % ari(y_pred.flatten(),label.flatten())) 34 | 35 | -------------------------------------------------------------------------------- /centrality.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | 4 | # put the feature vector back into a 2D symmetric array 5 | def to_2d(vector): 6 | size = 35 7 | x = np.zeros((size,size)) 8 | c = 0 9 | for i in range(1,size): 10 | for j in range(0,i): 11 | x[i][j] = vector[c] 12 | x[j][i] = vector[c] 13 | c = c + 1 14 | return x 15 | 16 | def topological_measures(data): 17 | # ROI is the number of brain regions (i.e., 35 in our case) 18 | ROI = to_2d(data[0]).shape[0] 19 | CC = np.empty((0,ROI), int) 20 | BC = np.empty((0,ROI), int) 21 | EC = np.empty((0,ROI), int) 22 | topology = [] 23 | for i in range(data.shape[0]): 24 | A = to_2d(data[i]) 25 | np.fill_diagonal(A, 0) 26 | 27 | # create a graph from the similarity matrix 28 | # G = nx.from_numpy_matrix(A) 29 | G = nx.DiGraph(np.array(A)) 30 | U = G.to_undirected() 31 | 32 | # Centrality # 33 | # compute closeness centrality and transform the output to a vector 34 | cc = nx.closeness_centrality(U) 35 | closeness_centrality = np.array([cc[g] for g in U]) 36 | # compute betweenness centrality and transform the output to a vector 37 | bc = nx.betweenness_centrality(U) 38 | betweenness_centrality = np.array([bc[g] for g in U]) 39 | # compute eigenvector centrality and transform the output to a vector 40 | ec = nx.eigenvector_centrality(U) 41 | eigenvector_centrality = np.array([ec[g] for g in U]) 42 | 43 | # stack the centralities of all subjects into matrices 44 | CC = np.vstack((CC, closeness_centrality)) 45 | BC = np.vstack((BC, betweenness_centrality)) 46 | EC = np.vstack((EC, eigenvector_centrality)) 47 | 48 | topology.append(CC)#0 49 | topology.append(BC)#1 50 | topology.append(EC)#2 51 | 52 | return topology -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data_utils 2 | import numpy as np 3 | import torch 4 | import SIMLR 5 | import os 6 | 7 | def get_loader(features, batch_size, num_workers=1): 8 | """ 9 | Build and return a data loader. 
10 | """ 11 | dataset = data_utils.TensorDataset(torch.Tensor(features)) 12 | loader = data_utils.DataLoader(dataset, 13 | batch_size=batch_size, 14 | shuffle = True, #set to True in case of training and False when testing the model 15 | num_workers=num_workers 16 | ) 17 | 18 | return loader 19 | 20 | def learn_adj(x): 21 | y = [] 22 | for t in x: 23 | b = t.cpu().numpy() 24 | y.append(b) 25 | 26 | x = np.array(y) 27 | batchsize = x.shape[0] 28 | simlr = SIMLR.SIMLR_LARGE(1, batchsize/3, 0) 29 | adj, _,_, _ = simlr.fit(x) 30 | array = adj.toarray() 31 | tensor = torch.Tensor(array).cuda() 32 | 33 | return tensor 34 | 35 | def to_tensor(x): 36 | y = [] 37 | for t in x: 38 | b = t.numpy() 39 | y.append(b) 40 | 41 | x = np.array(y) 42 | x = x[0] 43 | tensor = torch.Tensor(x) 44 | 45 | return tensor 46 | 47 | def create_dirs_if_not_exist(dir_list): 48 | if isinstance(dir_list, list): 49 | for dir in dir_list: 50 | if not os.path.exists(dir): 51 | os.makedirs(dir) 52 | else: 53 | if not os.path.exists(dir_list): 54 | os.makedirs(dir_list) -------------------------------------------------------------------------------- /fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/fig1.png -------------------------------------------------------------------------------- /fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiralab/MultiGraphGAN/7b6d2160daaf07c2462b11a7a621e5ead04cef0b/fig2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | """ 4 | Main function of MultiGraphGAN framework 5 | for jointly predicting multiple target brain graphs from a single source graph. 6 | 7 | Details can be found in: 8 | (1) the original paper https://link.springer.com/ 9 | Alaa Bessadok, Mohamed Ali Mahjoub, and Islem Rekik. "Topology-Aware Generative Adversarial Network for Joint Prediction of 10 | Multiple Brain Graphs from a Single Brain Graph", MICCAI 2020, Lima, Peru. 11 | (2) the youtube channel of BASIRA Lab: https://www.youtube.com/watch?v=OJOtLy9Xd34 12 | --------------------------------------------------------------------- 13 | 14 | This file contains the implementation of two main steps of our MultiGraphGAN framework: 15 | (1) source graphs embedding and clustering, and 16 | (2) cluster-specific multi-target graph prediction. 17 | 18 | MultiGraphGAN(src_loader, tgt_loaders, nb_clusters, opts) 19 | Inputs: 20 | src_loader: a PyTorch dataloader returning elements from source dataset batch by batch 21 | tgt_loaders: a PyTorch dataloader returning elements from target dataset batch by batch 22 | nb_clusters: number of clusters used to cluster the source graph embeddings 23 | opts: a python object (parser) storing all arguments needed to run the code such as hyper-parameters 24 | Output: 25 | model: our MultiGraphGAN model 26 | 27 | To evaluate our framework we used 90% of the dataset as training set and 10% for testing. 
28 | 29 | Sample use for training: 30 | model = MultiGraphGAN(src_loader, tgt_loaders, opts.nb_clusters, opts) 31 | model.train() 32 | 33 | Sample use for testing: 34 | model = MultiGraphGAN(src_loader, tgt_loaders, opts.nb_clusters, opts) 35 | predicted_target_graphs, source_graphs = model.test() 36 | Output: 37 | predicted_target_graphs : a list of size num_domains-1, where num_domains is the total number of source and target domains. 38 | Each element is an (n × f) matrix stacking the predicted target graph feature vectors (of size f) of the n testing subjects 39 | source_graphs : a matrix of size (n × f) stacking the source graph feature vectors (of size f) of the n testing subjects 40 | --------------------------------------------------------------------- 41 | Copyright 2020 Alaa Bessadok, Sousse University. 42 | Please cite the above paper if you use this code. 43 | All rights reserved. 44 | """ 45 | import argparse 46 | import random 47 | import yaml 48 | import numpy as np 49 | from torch.backends import cudnn 50 | from prediction import MultiGraphGAN 51 | from data_loader import * 52 | 53 | parser = argparse.ArgumentParser() 54 | # initialisation 55 | # Basic opts. 56 | parser.add_argument('--num_domains', type=int, default=6, help='number of domains (including the source domain)') 57 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 58 | parser.add_argument('--log_dir', type=str, default='logs/') 59 | parser.add_argument('--checkpoint_dir', type=str, default='models/') 60 | parser.add_argument('--sample_dir', type=str, default='samples/') 61 | parser.add_argument('--result_dir', type=str, default='results/') 62 | parser.add_argument('--result_root', type=str, default='result_MultiGraphGAN/') 63 | 64 | # GCN model opts 65 | parser.add_argument('--hidden1', type=int, default=32) 66 | parser.add_argument('--hidden2', type=int, default=16) 67 | parser.add_argument('--dropout', type=float, default=0.5) 68 | parser.add_argument('--in_feature', type=int, default=595) 69 | 70 | # Discriminator model opts. 71 | parser.add_argument('--cls_loss', type=str, default='BCE', choices=['LS','BCE'], help='least squares loss or binary cross-entropy loss') 72 | parser.add_argument('--lambda_cls', type=float, default=1, help='hyper-parameter for domain classification loss') 73 | parser.add_argument('--Lf', type=float, default=5, help='a constant with respect to the inter-domain constraint') 74 | parser.add_argument('--lambda_reg', type=float, default=0.1, help='a constant with respect to the gradient penalty') 75 | 76 | # Generator model opts. 77 | parser.add_argument('--lambda_idt', type=float, default=10, help='hyper-parameter for identity loss') 78 | parser.add_argument('--lambda_info', type=float, default=1, help='hyper-parameter for information maximization loss') 79 | parser.add_argument('--lambda_topology', type=float, default=0.1, help='hyper-parameter for topological constraint') 80 | parser.add_argument('--lambda_rec', type=float, default=0.01, help='hyper-parameter for graph reconstruction loss') 81 | parser.add_argument('--nb_clusters', type=int, default=2, help='number of clusters for MKML clustering') 82 | 83 | # Training opts. 
84 | parser.add_argument('--batch_size', type=int, default=70, help='mini-batch size') 85 | parser.add_argument('--num_iters', type=int, default=10, help='number of total iterations for training D') 86 | parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G') 87 | parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D') 88 | parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update') 89 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') 90 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') 91 | parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') 92 | parser.add_argument('--num_workers', type=int, default=1, help='num_workers to load data.') 93 | parser.add_argument('--log_step', type=int, default=5) 94 | parser.add_argument('--sample_step', type=int, default=5) 95 | parser.add_argument('--model_save_step', type=int, default=10) 96 | 97 | # Test opts. 98 | parser.add_argument('--test_iters', type=int, default=10, help='test model from this step') 99 | 100 | opts = parser.parse_args() 101 | opts.log_dir = os.path.join(opts.result_root, opts.log_dir) 102 | opts.checkpoint_dir = os.path.join(opts.result_root, opts.checkpoint_dir) 103 | opts.sample_dir = os.path.join(opts.result_root, opts.sample_dir) 104 | opts.result_dir = os.path.join(opts.result_root, opts.result_dir) 105 | 106 | 107 | if __name__ == '__main__': 108 | import pandas as pd 109 | # For fast training. 110 | cudnn.benchmark = True 111 | 112 | if opts.mode == 'train': 113 | """ 114 | Training MultiGraphGAN 115 | """ 116 | # Create directories if not exist. 117 | create_dirs_if_not_exist([opts.log_dir, opts.checkpoint_dir, opts.sample_dir, opts.result_dir]) 118 | 119 | # log opts. 120 | with open(os.path.join(opts.result_root, 'opts.yaml'), 'w') as f: 121 | f.write(yaml.dump(vars(opts))) 122 | 123 | # Simulate graph data for easy test the code 124 | source_target_domains = [] 125 | for i in range(opts.num_domains): 126 | source_target_domains.append(np.random.normal(random.random(), random.random(), (280,595))) 127 | 128 | # Choose the source domain to be translated 129 | src_domain = 0 130 | 131 | # Load source and target TRAIN datasets 132 | tgt_loaders = [] 133 | for domain in range(0, opts.num_domains): 134 | if domain == src_domain: 135 | source_feature = source_target_domains[domain] 136 | src_loader = get_loader(source_feature, opts.batch_size, opts.num_workers) 137 | else: 138 | target_feature = source_target_domains[domain] 139 | tgt_loader = get_loader(target_feature, opts.batch_size, opts.num_workers) 140 | tgt_loaders.append(tgt_loader) 141 | 142 | # Train MultiGraphGAN 143 | model = MultiGraphGAN(src_loader, tgt_loaders, opts.nb_clusters, opts) 144 | model.train() 145 | 146 | elif opts.mode == 'test': 147 | """ 148 | Testing MultiGraphGAN 149 | """ 150 | # Create directories if not exist. 
151 | create_dirs_if_not_exist([opts.result_dir]) 152 | 153 | # Simulate graph data to easily test the code 154 | source_target_domains = [] 155 | for i in range(opts.num_domains): 156 | source_target_domains.append(np.random.normal(random.random(), random.random(), (30,595))) 157 | 158 | # Choose the source domain to be translated 159 | src_domain = 0 160 | 161 | # Load source and target TEST datasets 162 | tgt_loaders = [] 163 | for domain in range(0, opts.num_domains): 164 | if domain == src_domain: 165 | source_feature = source_target_domains[domain] 166 | src_loader = get_loader(source_feature, opts.batch_size, opts.num_workers) 167 | else: 168 | target_feature = source_target_domains[domain] 169 | tgt_loader = get_loader(target_feature, opts.batch_size, opts.num_workers) 170 | tgt_loaders.append(tgt_loader) 171 | 172 | # Test MultiGraphGAN 173 | model = MultiGraphGAN(src_loader, tgt_loaders, opts.nb_clusters, opts) 174 | predicted_target_graphs, source_graphs = model.test() 175 | 176 | # Save data into csv files 177 | print("saving source graphs into csv file...") 178 | f = source_graphs.cpu().numpy() 179 | dataframe = pd.DataFrame(data=f.astype(float)) 180 | dataframe.to_csv('source_graphs.csv', sep=' ', header=True, float_format='%.6f', index=False) 181 | 182 | print("saving predicted target graphs into csv files...") 183 | for idx in range(len(predicted_target_graphs)): 184 | f = predicted_target_graphs[idx].numpy() 185 | dataframe = pd.DataFrame(data=f.astype(float)) 186 | dataframe.to_csv('predicted_graphs_%d.csv'%(idx+1), sep=' ', header=True, float_format='%.6f', index=False) 187 | 188 | 189 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import torch.nn as nn 5 | from torch.nn.parameter import Parameter 6 | from torch.nn.modules.module import Module 7 | from torch.nn import functional as F 8 | 9 | class GCN(Module): 10 | """ 11 | Graph Convolutional Network 12 | """ 13 | def __init__(self, in_features, out_features, bias=True): 14 | super(GCN, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 18 | if bias: 19 | self.bias = Parameter(torch.FloatTensor(out_features)) 20 | else: 21 | self.register_parameter('bias', None) 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | stdv = 1. / math.sqrt(self.weight.size(1)) 26 | self.weight.data.uniform_(-stdv, stdv) 27 | if self.bias is not None: 28 | self.bias.data.uniform_(-stdv, stdv) 29 | 30 | def forward(self, input, adj): 31 | support = torch.mm(input, self.weight) 32 | output = torch.spmm(adj, support) 33 | if self.bias is not None: 34 | return output + self.bias 35 | else: 36 | return output 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ + ' (' \ 40 | + str(self.in_features) + ' -> ' \ 41 | + str(self.out_features) + ')' 42 | 43 | 44 | 45 | class GCNencoder(nn.Module): 46 | """ 47 | Encoder network. 
48 | """ 49 | def __init__(self, nfeat, nhid, nout, dropout): 50 | super(GCNencoder, self).__init__() 51 | 52 | self.gc1 = GCN(nfeat, nhid) 53 | self.gc2 = GCN(nhid, nout) 54 | self.dropout = dropout 55 | 56 | def forward(self, x, adj): 57 | x = F.relu(self.gc1(x, adj)) 58 | x = F.dropout(x, self.dropout, training=self.training) 59 | x = F.relu(self.gc2(x, adj)) 60 | return x 61 | 62 | 63 | class GCNdecoder(nn.Module): 64 | """ 65 | Decoder network. 66 | """ 67 | def __init__(self, nfeat, nhid, nout, dropout): 68 | super(GCNdecoder, self).__init__() 69 | 70 | self.gc1 = GCN(nfeat, nhid) 71 | self.gc2 = GCN(nhid, nout) 72 | self.dropout = dropout 73 | 74 | def forward(self, x, adj): 75 | x = F.relu(self.gc1(x, adj)) 76 | x = F.dropout(x, self.dropout, training=self.training) 77 | x = F.relu(self.gc2(x, adj)) 78 | return x 79 | 80 | class Discriminator(nn.Module): 81 | """ 82 | Discriminator network with GCN. 83 | """ 84 | def __init__(self, input_size, output_size, dropout): 85 | super(Discriminator, self).__init__() 86 | 87 | self.gc1 = GCN(input_size, 32) 88 | self.gc2 = GCN(32, 16) 89 | self.gc3 = GCN(16, 1) 90 | self.dropout = dropout 91 | 92 | def forward(self, x, adj): 93 | x = F.relu(self.gc1(x, adj)) 94 | x = F.dropout(x, self.dropout, training=self.training) 95 | x = F.relu(self.gc2(x, adj)) 96 | x = F.dropout(x, self.dropout, training=self.training) 97 | a = self.gc3(x, adj) 98 | x = a.view(a.shape[0]) 99 | return F.sigmoid(x), F.softmax(x, dim=0) 100 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import ttest_rel 3 | from sklearn.metrics import mean_absolute_error 4 | from scipy.io import loadmat 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | 8 | # read csv file 9 | def readcsv(filename): 10 | data = pd.read_csv(filename) 11 | c = [] 12 | data = np.array(data) 13 | for i in range(0,data.shape[0]): 14 | a = data[i][0] 15 | b = np.array(list(a.split(" "))) 16 | c.append(b) 17 | 18 | return(np.array(c)) 19 | 20 | # Plot the connectomes 21 | def show_mtrx(m): 22 | fig, ax = plt.subplots(figsize = (20, 10)) 23 | 24 | min_val = round((m.min()), 6) 25 | max_val = round((m.max()), 6) 26 | 27 | cax = ax.matshow(m, cmap=plt.cm.Spectral) 28 | cbar = fig.colorbar(cax, ticks=[min_val, float((min_val + max_val)/2), max_val]) 29 | cbar.ax.set_yticklabels(['< %.2f'%(min_val), '%.2f'%(float((min_val + max_val)/2)), '> %.2f'%(max_val)]) 30 | plt.title(label="Source graph") 31 | plt.show() 32 | 33 | 34 | # put it back into a 2D symmetric array 35 | def to_2d(vector): 36 | size = 35 37 | x = np.zeros((size,size)) 38 | c = 0 39 | for i in range(1,size): 40 | for j in range(0,i): 41 | x[i][j] = vector[c] 42 | x[j][i] = vector[c] 43 | c = c + 1 44 | return x 45 | 46 | # Display the source matrix of the first subject 47 | pred = readcsv("source_graphs.csv") 48 | SG = to_2d(pred[0]) 49 | show_mtrx(SG) 50 | 51 | # Display the target graph in the domain 1 of the first subject 52 | pred = readcsv("predicted_graphs_1.csv") 53 | TG1 = to_2d(pred[0]) 54 | show_mtrx(TG1) 55 | 56 | # Display the target graph in the domain 2 of the first subject 57 | pred = readcsv("predicted_graphs_2.csv") 58 | TG2 = to_2d(pred[0]) 59 | show_mtrx(TG2) 60 | 61 | -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | import time 3 | import datetime 4 | import itertools 5 | import torch 6 | import SIMLR 7 | import torch.nn.functional as F 8 | from sklearn.metrics import mean_absolute_error 9 | from model import GCNencoder, GCNdecoder 10 | from model import Discriminator 11 | from data_loader import * 12 | from centrality import * 13 | import numpy as np 14 | 15 | class MultiGraphGAN(object): 16 | """ 17 | Build MultiGraphGAN model for training and testing. 18 | """ 19 | def __init__(self, src_loader, tgt_loaders, nb_clusters, opts): 20 | self.src_loader = src_loader 21 | self.tgt_loaders = tgt_loaders 22 | self.opts = opts 23 | # device 24 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | # criterion function 26 | self.criterionIdt = torch.nn.L1Loss() 27 | # build models 28 | self.build_model() 29 | self.build_generators(nb_clusters) 30 | self.nb_clusters = nb_clusters 31 | 32 | def build_model(self): 33 | """ 34 | Build encoder and discriminator models and initialize optimizers. 35 | """ 36 | # build shared encoder 37 | self.E = GCNencoder(self.opts.in_feature, self.opts.hidden1, self.opts.hidden2, self.opts.dropout).to(self.device) 38 | 39 | # build discriminator( combined with the auxiliary classifier ) 40 | self.D = Discriminator(self.opts.in_feature, 1, self.opts.dropout).to(self.device) 41 | 42 | self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.opts.d_lr, [self.opts.beta1, self.opts.beta2]) 43 | 44 | def build_generators(self,nb_clusters): 45 | """ 46 | Build cluster-specific generators models and initialize optimizers. 47 | """ 48 | self.Gs = [] 49 | param = [] 50 | for i in range(self.opts.num_domains - 1): 51 | inside_list=[] 52 | for i in range (nb_clusters): 53 | G_i = GCNdecoder(self.opts.hidden2, self.opts.hidden1, self.opts.in_feature, self.opts.dropout).to(self.device) 54 | inside_list.append(G_i) 55 | param.append(G_i) 56 | self.Gs.append(inside_list) 57 | 58 | # build optimizers 59 | param_list = [self.E.parameters()] + [G.parameters() for G in param] 60 | self.g_optimizer = torch.optim.Adam(itertools.chain(*param_list), 61 | self.opts.g_lr, [self.opts.beta1, self.opts.beta2]) 62 | 63 | def restore_model(self, resume_iters, nb_clusters): 64 | """ 65 | Restore the trained generators and discriminator. 66 | """ 67 | print('Loading the trained models from step {}...'.format(resume_iters)) 68 | 69 | E_path = os.path.join(self.opts.checkpoint_dir, '{}-E.ckpt'.format(resume_iters)) 70 | self.E.load_state_dict(torch.load(E_path, map_location=lambda storage, loc: storage)) 71 | 72 | for c in range(nb_clusters): 73 | for i in range(self.opts.num_domains - 1): 74 | G_i_path = os.path.join(self.opts.checkpoint_dir, '{}-G{}-{}.ckpt'.format(resume_iters, i+1, c)) 75 | print(G_i_path ) 76 | self.Gs[i][c].load_state_dict(torch.load(G_i_path, map_location=lambda storage, loc: storage)) 77 | 78 | D_path = os.path.join(self.opts.checkpoint_dir, '{}-D.ckpt'.format(resume_iters)) 79 | if os.path.exists(D_path): 80 | self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage)) 81 | 82 | 83 | def reset_grad(self): 84 | """ 85 | Reset the gradient buffers. 86 | """ 87 | self.g_optimizer.zero_grad() 88 | self.d_optimizer.zero_grad() 89 | 90 | 91 | def gradient_penalty(self, y, x, Lf): 92 | """ 93 | Compute gradient penalty. 
94 | """ 95 | weight = torch.ones(y.size()).to(self.device) 96 | dydx = torch.autograd.grad(outputs=y, 97 | inputs=x, 98 | grad_outputs=weight, 99 | retain_graph=True, 100 | create_graph=True, 101 | only_inputs=True)[0] 102 | 103 | dydx = dydx.view(dydx.size(0), -1) 104 | dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1)) 105 | 106 | ZERO = torch.zeros_like(dydx_l2norm).to(self.device) 107 | penalty = torch.max(dydx_l2norm - Lf, ZERO) 108 | return torch.mean(penalty) ** 2 109 | 110 | 111 | def classification_loss(self, logit, target, type='LS'): 112 | """ 113 | Compute classification loss. 114 | """ 115 | print(type) 116 | if type == 'BCE': 117 | return F.binary_cross_entropy_with_logits(logit, target) 118 | elif type == 'LS': 119 | return F.mse_loss(logit, target) 120 | else: 121 | assert False, '[*] classification loss not implemented.' 122 | 123 | 124 | def train(self): 125 | """ 126 | Train MultiGraphGAN 127 | """ 128 | nb_clusters = self.nb_clusters 129 | 130 | #fixed data for evaluating: generate samples. 131 | src_iter = iter(self.src_loader) 132 | 133 | x_src_fixed= next(src_iter) 134 | x_src_fixed = x_src_fixed[0].to(self.device) 135 | d = next(iter(self.src_loader)) 136 | 137 | tgt_iters = [] 138 | for loader in self.tgt_loaders: 139 | tgt_iters.append(iter(loader)) 140 | 141 | # label 142 | label_pos = torch.FloatTensor([1] * d[0].shape[0]).to(self.device) 143 | label_neg = torch.FloatTensor([0] * d[0].shape[0]).to(self.device) 144 | 145 | # Start training from scratch or resume training. 146 | start_iters = 0 147 | if self.opts.resume_iters: 148 | start_iters = self.opts.resume_iters 149 | self.restore_model(self.opts.resume_iters) 150 | 151 | # Start training. 152 | print('Start training MultiGraphGAN...') 153 | start_time = time.time() 154 | 155 | for i in range(start_iters, self.opts.num_iters): 156 | print("iteration",i) 157 | # =================================================================================== # 158 | # 1. Preprocess input data # 159 | # =================================================================================== # 160 | try: 161 | x_src = next(src_iter) 162 | except: 163 | src_iter = iter(self.src_loader) 164 | x_src = next(src_iter) 165 | 166 | x_src = x_src[0].to(self.device) 167 | 168 | x_tgts = [] 169 | for tgt_idx in range(len(tgt_iters)): 170 | try: 171 | x_tgt_i= next(tgt_iters[tgt_idx]) 172 | x_tgts.append(x_tgt_i) 173 | except: 174 | tgt_iters[tgt_idx] = iter(self.tgt_loaders[tgt_idx]) 175 | x_tgt_i= next(tgt_iters[tgt_idx]) 176 | x_tgts.append(x_tgt_i) 177 | 178 | for tgt_idx in range(len(x_tgts)): 179 | x_tgts[tgt_idx] = x_tgts[tgt_idx][0].to(self.device) 180 | print("x_tgts",x_tgts[tgt_idx].shape) 181 | 182 | # =================================================================================== # 183 | # 2. 
Train the discriminator # 184 | # =================================================================================== # 185 | 186 | embedding = self.E(x_src,learn_adj(x_src)).detach() 187 | ## Cluster the source graph embeddings using SIMLR 188 | simlr = SIMLR.SIMLR_LARGE(nb_clusters, embedding.shape[0]//2, 0) 189 | S, ff, val, ind = simlr.fit(embedding) 190 | y_pred = simlr.fast_minibatch_kmeans(ff,nb_clusters) 191 | y_pred = y_pred.tolist() 192 | get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y] 193 | 194 | x_fake_list = [] 195 | x_src_list = [] 196 | d_loss_cls = 0 197 | d_loss_fake = 0 198 | d_loss = 0 199 | 200 | print("Train the discriminator") 201 | for par in range(nb_clusters): 202 | print("================") 203 | print("cluster",par) 204 | print("================") 205 | cluster_index_list = get_indexes(par,y_pred) 206 | print(cluster_index_list) 207 | for idx in range(len(self.Gs)): 208 | x_fake_i = self.Gs[idx][par](embedding[cluster_index_list],learn_adj(x_tgts[idx][cluster_index_list])).detach() 209 | x_fake_list.append(x_fake_i) 210 | x_src_list.append(x_src[cluster_index_list]) 211 | 212 | out_fake_i, out_cls_fake_i = self.D(x_fake_i,learn_adj(x_fake_i)) 213 | _, out_cls_real_i = self.D(x_tgts[idx][cluster_index_list],learn_adj(x_tgts[idx][cluster_index_list])) 214 | 215 | ### Graph domain classification loss 216 | d_loss_cls_i = self.classification_loss(out_cls_real_i, label_pos[cluster_index_list], type=self.opts.cls_loss) \ 217 | + self.classification_loss(out_cls_fake_i, label_neg[cluster_index_list], type=self.opts.cls_loss) 218 | d_loss_cls += d_loss_cls_i 219 | 220 | # Part of adversarial loss 221 | d_loss_fake += torch.mean(out_fake_i) 222 | 223 | out_src, out_cls_src = self.D(x_src[cluster_index_list],learn_adj(x_src[cluster_index_list])) 224 | ### Adversarial loss 225 | d_loss_adv = torch.mean(out_src) - d_loss_fake / (self.opts.num_domains - 1) 226 | 227 | ### Gradient penalty loss 228 | x_fake_cat = torch.cat(x_fake_list) 229 | x_src_cat = torch.cat(x_src_list) 230 | 231 | alpha = torch.rand(x_src_cat.size(0), 1).to(self.device) 232 | x_hat = (alpha * x_src_cat.data + (1 - alpha) * x_fake_cat.data).requires_grad_(True) 233 | 234 | out_hat, _ = self.D(x_hat,learn_adj(x_hat.detach())) 235 | d_loss_reg = self.gradient_penalty(out_hat, x_hat, self.opts.Lf) 236 | 237 | # Cluster-based loss to update the discriminator 238 | d_loss_cluster = -1 * d_loss_adv + self.opts.lambda_cls * d_loss_cls + self.opts.lambda_reg * d_loss_reg 239 | 240 | ### Discriminator loss 241 | d_loss += d_loss_cluster 242 | 243 | 244 | print("d_loss",d_loss) 245 | self.reset_grad() 246 | d_loss.backward() 247 | self.d_optimizer.step() 248 | 249 | # Logging. 250 | loss = {} 251 | loss['D/loss_adv'] = d_loss_adv.item() 252 | loss['D/loss_cls'] = d_loss_cls.item() 253 | loss['D/loss_reg'] = d_loss_reg.item() 254 | 255 | # =================================================================================== # 256 | # 3. 
Train the cluster-specific generators # 257 | # =================================================================================== # 258 | print("Train the generators") 259 | if (i + 1) % self.opts.n_critic == 0: 260 | g_loss_info = 0 261 | g_loss_adv = 0 262 | g_loss_idt = 0 263 | g_loss_topo = 0 264 | g_loss_rec = 0 265 | g_loss = 0 266 | 267 | for par in range(nb_clusters): 268 | print("cluster",par) 269 | for idx in range(len(self.Gs)): 270 | # ========================= # 271 | # =====source-to-target==== # 272 | # ========================= # 273 | x_fake_i = self.Gs[idx][par](embedding[cluster_index_list],learn_adj(x_tgts[idx][cluster_index_list])) 274 | 275 | # Global topology loss 276 | global_topology = self.criterionIdt(x_fake_i, x_tgts[idx][cluster_index_list]) 277 | 278 | # Local topology loss 279 | real_topology = topological_measures(x_tgts[idx][cluster_index_list]) 280 | fake_topology = topological_measures(x_fake_i.detach()) 281 | # 0:closeness centrality 1:betweeness centrality 2:eginvector centrality 282 | local_topology = mean_absolute_error(fake_topology[0],real_topology[0]) 283 | 284 | ### Topology loss 285 | g_loss_topo += (local_topology + global_topology) 286 | 287 | if self.opts.lambda_idt > 0: 288 | x_fake_i_idt = self.Gs[idx][par](self.E(x_tgts[idx][cluster_index_list],learn_adj(x_tgts[idx][cluster_index_list])),learn_adj(x_tgts[idx][cluster_index_list])) 289 | g_loss_idt += self.criterionIdt(x_fake_i_idt, x_tgts[idx][cluster_index_list]) 290 | 291 | out_fake_i, out_cls_fake_i = self.D(x_fake_i,learn_adj(x_fake_i.detach())) 292 | 293 | ### Information maximization loss 294 | g_loss_info_i = F.binary_cross_entropy_with_logits(out_cls_fake_i, label_pos[cluster_index_list]) 295 | g_loss_info += g_loss_info_i 296 | 297 | ### Adversarial loss 298 | g_loss_adv -= torch.mean(out_fake_i) # opposed sign 299 | 300 | # ========================= # 301 | # =====target-to-source==== # 302 | # ========================= # 303 | x_reconst = self.Gs[idx][par](self.E(x_fake_i,learn_adj(x_fake_i.detach())),learn_adj(x_fake_i.detach())) 304 | 305 | # Reconstructed global topology loss 306 | reconstructed_global_topology = self.criterionIdt(x_src[cluster_index_list], x_reconst) 307 | 308 | # Reconstructed local topology loss 309 | real_topology = topological_measures(x_src[cluster_index_list]) 310 | fake_topology = topological_measures(x_reconst.detach()) 311 | # 0:closeness centrality 1:betweeness centrality 2:eginvector centrality 312 | reconstructed_local_topology = mean_absolute_error(fake_topology[0],real_topology[0]) 313 | 314 | ### Graph reconstruction loss 315 | g_loss_rec += (reconstructed_local_topology + reconstructed_global_topology) 316 | 317 | # Cluster-based loss to update the generators 318 | g_loss_cluster = g_loss_adv / (self.opts.num_domains - 1) + self.opts.lambda_info * g_loss_info + self.opts.lambda_idt * g_loss_idt + self.opts.lambda_topology * g_loss_topo + self.opts.lambda_rec * g_loss_rec 319 | 320 | ### Generator loss 321 | g_loss += g_loss_cluster 322 | 323 | print("g_loss",g_loss) 324 | self.reset_grad() 325 | g_loss.backward() 326 | self.g_optimizer.step() 327 | 328 | # Logging. 329 | loss['G/loss_adv'] = g_loss_adv.item() 330 | loss['G/loss_rec'] = g_loss_rec.item() 331 | loss['G/loss_cls'] = g_loss_info.item() 332 | if self.opts.lambda_idt > 0: 333 | loss['G/loss_idt'] = g_loss_idt.item() 334 | 335 | # =================================================================================== # 336 | # 4. 
Miscellaneous # 337 | # =================================================================================== # 338 | # print out training information. 339 | if (i + 1) % self.opts.log_step == 0: 340 | et = time.time() - start_time 341 | et = str(datetime.timedelta(seconds=et))[:-7] 342 | log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.opts.num_iters) 343 | for tag, value in loss.items(): 344 | log += ", {}: {:.4f}".format(tag, value) 345 | print(log) 346 | 347 | 348 | # save model checkpoints. 349 | if (i + 1) % self.opts.model_save_step == 0: 350 | E_path = os.path.join(self.opts.checkpoint_dir, '{}-E.ckpt'.format(i+1)) 351 | torch.save(self.E.state_dict(), E_path) 352 | 353 | D_path = os.path.join(self.opts.checkpoint_dir, '{}-D.ckpt'.format(i+1)) 354 | torch.save(self.D.state_dict(), D_path) 355 | 356 | for par in range(nb_clusters): 357 | for idx in range(len(self.Gs)): 358 | G_i_path = os.path.join(self.opts.checkpoint_dir, '{}-G{}-{}.ckpt'.format(i+1, idx+1, par)) 359 | print(G_i_path) 360 | torch.save(self.Gs[idx][par].state_dict(), G_i_path) 361 | 362 | print('Saved model checkpoints into {}...'.format(self.opts.checkpoint_dir)) 363 | 364 | print('=============================') 365 | print("End of Training") 366 | print('=============================') 367 | 368 | # =================================================================================== # 369 | # 5. Test with a new dataset # 370 | # =================================================================================== # 371 | def test(self): 372 | """ 373 | Test the trained MultiGraphGAN. 374 | """ 375 | self.restore_model(self.opts.test_iters,self.opts.nb_clusters) 376 | 377 | # Set data loader. 378 | src_loader = self.src_loader 379 | x_src = next(iter(self.src_loader)) 380 | x_src = x_src[0].to(self.device) 381 | 382 | tgt_iters = [] 383 | for loader in self.tgt_loaders: 384 | tgt_iters.append(iter(loader)) 385 | 386 | x_tgts = [] 387 | for tgt_idx in range(len(tgt_iters)): 388 | try: 389 | x_tgt_i= next(tgt_iters[tgt_idx]) 390 | x_tgts.append(x_tgt_i) 391 | except: 392 | tgt_iters[tgt_idx] = iter(self.tgt_loaders[tgt_idx]) 393 | x_tgt_i= next(tgt_iters[tgt_idx]) 394 | x_tgts.append(x_tgt_i) 395 | 396 | for tgt_idx in range(len(x_tgts)): 397 | x_tgts[tgt_idx] = x_tgts[tgt_idx][0].to(self.device) 398 | 399 | # return model.eval() 400 | for par in range(self.opts.nb_clusters): 401 | for idx in range(len(self.Gs)): 402 | self.Gs[idx][par].eval() 403 | 404 | with torch.no_grad(): 405 | embedding = self.E(x_src,learn_adj(x_src)) 406 | predicted_target_graphs = [] 407 | for idx in range(len(self.Gs)): 408 | sum_cluster_pred_graph = 0 409 | for par in range(self.opts.nb_clusters): 410 | x_fake_i = self.Gs[idx][par](embedding,learn_adj(x_src)) 411 | sum_cluster_pred_graph = np.add(sum_cluster_pred_graph,x_fake_i.cpu()) 412 | 413 | average_predicted_target_graph = sum_cluster_pred_graph / float(self.opts.nb_clusters) 414 | predicted_target_graphs.append(average_predicted_target_graph) 415 | 416 | 417 | print('=============================') 418 | print("End of Testing") 419 | print('=============================') 420 | 421 | return predicted_target_graphs, x_src -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | networkx 3 | annoy 4 | fbpca 5 | SIMLR 6 | yaml --------------------------------------------------------------------------------
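A quick numerical sanity check for the Euclidean simplex projection implemented in SIMLR/helper.py (`euclidean_proj_simplex`): every row of the returned array should be non-negative and sum to the radius `s`. This is a minimal sketch, not part of the repository; it assumes the bundled SIMLR_PY package has been installed (e.g. via its setup.py) so that `SIMLR.helper` is importable, as tests/test_largescale.py already assumes.

```python
import numpy as np
from SIMLR.helper import euclidean_proj_simplex  # helper module shipped in SIMLR_PY

# 5 vectors of dimension 10; the function expects a 2-D (n, d) float array
v = np.random.randn(5, 10)
# pass a copy: the function shifts its input in place before projecting
w = euclidean_proj_simplex(v.copy(), s=1)

assert np.all(w >= 0)                   # projection lies in the positive orthant
assert np.allclose(w.sum(axis=1), 1.0)  # every row sums to the simplex radius s
```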
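Throughout the repository each brain graph is handled as a 595-dimensional feature vector (`--in_feature 595` in main.py), which is exactly the number of off-diagonal connections of a 35-ROI symmetric connectome: 35 × 34 / 2 = 595. The sketch below reuses the `to_2d` logic from centrality.py/plot.py and adds a hypothetical inverse helper, `to_1d`, which is not in the repository, to show that the vector-to-matrix mapping is a lossless round trip.

```python
import numpy as np

SIZE = 35                        # number of ROIs (brain regions)
N_FEAT = SIZE * (SIZE - 1) // 2  # = 595, matching --in_feature in main.py

def to_2d(vector):
    # same logic as centrality.py / plot.py: fill the strict lower triangle row by row
    x = np.zeros((SIZE, SIZE))
    c = 0
    for i in range(1, SIZE):
        for j in range(0, i):
            x[i][j] = vector[c]
            x[j][i] = vector[c]
            c += 1
    return x

def to_1d(matrix):
    # hypothetical inverse helper (not in the repository): read the lower
    # triangle back in the order used by to_2d
    return np.array([matrix[i][j] for i in range(1, SIZE) for j in range(0, i)])

v = np.random.rand(N_FEAT)              # one simulated subject, as in main.py
assert np.allclose(to_1d(to_2d(v)), v)  # the mapping is a bijection
print(N_FEAT)                           # 595
```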
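The local topology term used in prediction.py can also be reproduced in isolation: the feature vectors are rebuilt into 35 × 35 matrices, node centralities are extracted with `centrality.topological_measures`, and the mean absolute error between the closeness-centrality matrices of the real and generated batches gives the loss. The sketch below is an assumption-level illustration with random stand-ins for `x_tgts[idx][cluster_index_list]` and `x_fake_i.detach()` (NumPy arrays here, whereas the training loop passes tensors); run it from the repository root so centrality.py is importable.

```python
import numpy as np
from sklearn.metrics import mean_absolute_error
from centrality import topological_measures  # from this repository

n_subjects, n_feat = 4, 595
real_batch = np.random.rand(n_subjects, n_feat)  # stand-in for the real target graphs
fake_batch = np.random.rand(n_subjects, n_feat)  # stand-in for the generated graphs

real_topology = topological_measures(real_batch)  # [CC, BC, EC], each of shape (n_subjects, 35)
fake_topology = topological_measures(fake_batch)

# index 0 = closeness centrality, the measure used for the local topology loss in train()
local_topology = mean_absolute_error(fake_topology[0], real_topology[0])
print(local_topology)
```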