├── LICENSE ├── README.md ├── Sampling_animation.gif ├── Simulation_animation.gif ├── TNN_schematic.jpg ├── reproducibilty └── README.md ├── requirements.txt ├── setup.py ├── tnn ├── __init__.py ├── tnn.py └── version.py └── umap_embedding.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 lkmklsmn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # insct ("Insight") 2 | **IN**tegration of millions of **S**ingle **C**ells using batch-aware **T**riplet networks 3 | 4 | `INSCT` is a deep learning algorithm which calculates an integrated embedding for scRNA-seq data. With `INSCT`, you can: 5 | 6 | * Integrate scRNA-seq datasets across batches with/without labels. 7 | * Generate a low-dimensional representation of the scRNA-seq data. 8 | * Integrate of millions of cells on personal computers. 9 | 10 | For more info check out our [manuscript](https://go.nature.com/2Uq73If). 11 | 12 | ## How does it work? 13 | ![tnn](https://github.com/lkmklsmn/insct/blob/master/TNN_schematic.jpg) 14 | 15 | 1. `INSCT` learns a data representation, which integrates cells across batches. The goal of the network is to minimize the distance between Anchor and Positive while maximizing the distance between Anchor and Negative. Anchor and Positive pairs consist of transcriptionally similar cells from different batches. The Negative is a transcriptomically dissimilar cell sampled from the same batch as the Anchor. 16 | 1. Principal components of three data points corresponding to Anchor, Positive and Negative are fed into three identical neural networks, which share weights. The triplet loss function is used to train the network weights and the two-dimensional embedding layer activations represent the integrated embedding. 17 | 18 | To learn an integrated embedding that overcomes batch effects, `INSCT` samples triplets in a batch-aware manner: 19 | 20 | ![tnn](https://github.com/lkmklsmn/insct/blob/master/Sampling_animation.gif) 21 | 22 | ## What does it do? 23 | 24 | For example, we simulated scRNAseq data, where batch effects dominate the embedding: 25 | 26 | ![tnn](https://github.com/lkmklsmn/insct/blob/master/umap_embedding.png) 27 | 28 | However, `INSCT` learns an integrated embedding where cells cluster by group instead of batch: 29 | 30 | ![tnn](https://github.com/lkmklsmn/insct/blob/master/Simulation_animation.gif) 31 | 32 | ## Check out our interactive tutorials! 33 | The following notebooks can be run within your web browser and allow you to interactively explore tnn. We have prepared the following analysis examples: 34 | 1. [Simulation dataset](https://colab.research.google.com/drive/1LEDnRwFH2v166T-pUaCYb6TZMgfViO-W?usp=sharing) 35 | 2. [Pancreas dataset](https://colab.research.google.com/drive/1v_B0pXVYMqHsV2polaoRHkxflrNcQGej?usp=sharing) 36 | 37 | Notebooks to reproduce the analyses described in our preprint can be found in the _reproducibility_ folder. 38 | 39 | ## Installation 40 | 41 | `INSCT` depends on the following Python packages. These need to be installed separately: 42 | ``` 43 | ivis==1.7.2 44 | scanpy 45 | hnswlib 46 | ``` 47 | 48 | To install `INSCT`, follow these instructions: 49 | 50 | ### Github 51 | 52 | Install directly from Github using pip: 53 | 54 | ```alias 55 | pip install git+https://github.com/lkmklsmn/insct.git 56 | ``` 57 | 58 | Download the package from Github and install it locally: 59 | 60 | ```alias 61 | git clone http://github.com/lkmklsmn/insct 62 | cd insct 63 | pip install . 64 | ``` 65 | 66 | ## Usage 67 | ### Unsupervised model 68 | Triplets sampled based on transcriptional similarity 69 | 1. AnnData object with PCs 70 | 2. Batch vector 71 | 72 | ```alias 73 | from insct.tnn import TNN 74 | model = TNN() 75 | model.fit(X = adata, batch_name='batch') 76 | ``` 77 | 78 | ### Supervised model 79 | Triplets sampled based on both transcriptional similarity and known labels 80 | 1. AnnData object with PCs 81 | 2. Batch vector 82 | 3. Celltype vector 83 | 84 | ```alias 85 | model = TNN() 86 | model.fit(X = adata, batch_name='batch', celltype_name='Celltypes') 87 | ``` 88 | 89 | ### Semi-supervised model 90 | Triplets sampled based on both transcriptional similarity and known labels 91 | 1. AnnData object with PCs 92 | 2. Batch vector 93 | 3. Celltype vector 94 | 4. Masking vector (which labels to ignore) 95 | 96 | ```alias 97 | model = TNN() 98 | model.fit(X = adata, batch_name='batch', celltype_name='Celltypes', mask_batch= batch_name) 99 | ``` 100 | 101 | ## Output 102 | 1. Coordinates for the integrated embedding 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /Sampling_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lkmklsmn/insct/c2eb12df8f69d330996c0f5c43ae39f292cb02ea/Sampling_animation.gif -------------------------------------------------------------------------------- /Simulation_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lkmklsmn/insct/c2eb12df8f69d330996c0f5c43ae39f292cb02ea/Simulation_animation.gif -------------------------------------------------------------------------------- /TNN_schematic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lkmklsmn/insct/c2eb12df8f69d330996c0f5c43ae39f292cb02ea/TNN_schematic.jpg -------------------------------------------------------------------------------- /reproducibilty/README.md: -------------------------------------------------------------------------------- 1 | 2 | Google Colab notebooks to reproduce the analysis described in our manuscript can be accessed here: 3 | 4 | 1. [Fig2 - Simulation analysis](https://colab.research.google.com/drive/1LEDnRwFH2v166T-pUaCYb6TZMgfViO-W?usp=sharing) 5 | 2. [Fig3 - Tabula Muris integration](https://colab.research.google.com/drive/1y2z0vQmA2SQNrj9XAfuhiFDm19XeubOn?usp=sharing) 6 | 3. [Fig3 - Tabula Muris evaluation](https://colab.research.google.com/drive/1bxekVOIfBaeScx5Rh1S4iEEAn1QsqOSH?usp=sharing) 7 | 4. [Fig4 - Mouse-human integration](https://colab.research.google.com/drive/1oy5b9HoKrPktOB3KqHvhtOOflFmrfPRr?usp=sharing) 8 | 5. [Fig4 - Mouse-human evaluation](https://colab.research.google.com/drive/12alsy9-ANbFVRQ7jmOYO2VoE6BLZuJqZ?usp=sharing) 9 | 6. [Fig5 - Pancreas semi-supervised integration](https://colab.research.google.com/drive/1Ulj0ghqdRMjedUQKHECyoQQyl8j4tdJF?usp=sharing) 10 | 7. [Fig6 - Mouse brain integration](https://colab.research.google.com/drive/1yB2cr9jPDysGKt8Hnze-kvxPlbZ0R-SQ#scrollTo=1yG583VZztc0) 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pip==21.1 2 | wheel==0.33.1 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="tnn", 5 | version="0.0.1", 6 | description="Deep triplet neural networks for integration of scRNAseq data", 7 | author='Lukas Simon, Yin-Ying Wang', 8 | author_email="lkmklsmn@gmail.com", 9 | packages=['tnn'], 10 | install_requires=['sklearn','scanpy','anndata','pandas','tensorflow','numpy', 'ivis', 'scipy', 'hnswlib'] 11 | ) 12 | -------------------------------------------------------------------------------- /tnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .tnn import TNN 2 | from .version import VERSION as __version__ 3 | -------------------------------------------------------------------------------- /tnn/tnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scanpy as sc 3 | import pandas as pd 4 | 5 | from annoy import AnnoyIndex 6 | 7 | from tensorflow import keras 8 | from tensorflow.keras.models import load_model, Model 9 | from tensorflow.keras.layers import Input, Dense, AlphaDropout, Lambda 10 | from tensorflow.keras.regularizers import l2, l1_l2 11 | from tensorflow.keras.callbacks import Callback, EarlyStopping 12 | from tensorflow.keras import backend as K 13 | from tensorflow.keras import regularizers 14 | from tensorflow.keras.utils import Sequence 15 | 16 | import random 17 | from ivis.nn.losses import * 18 | from ivis.nn.network import triplet_network, base_network 19 | from ivis.nn.callbacks import ModelCheckpoint 20 | 21 | from scipy.sparse import issparse 22 | from scipy.spatial import cKDTree 23 | 24 | from sklearn.base import BaseEstimator 25 | from sklearn.neighbors import NearestNeighbors 26 | import json 27 | import os 28 | import shutil 29 | import multiprocessing 30 | import platform 31 | 32 | from sklearn.metrics.pairwise import rbf_kernel, euclidean_distances 33 | from sklearn.neighbors import NearestNeighbors 34 | from intervaltree import IntervalTree 35 | import operator 36 | 37 | from scipy.sparse import issparse 38 | from annoy import AnnoyIndex 39 | from multiprocessing import Process, cpu_count, Queue 40 | from collections import namedtuple 41 | from operator import attrgetter 42 | from tqdm import tqdm 43 | import time 44 | import itertools 45 | import networkx as nx 46 | 47 | import hnswlib 48 | 49 | from sklearn.preprocessing import LabelEncoder 50 | le = LabelEncoder() 51 | from sklearn import metrics 52 | 53 | def base_network(input_shape): 54 | '''Base network to be shared (eq. to feature extraction). 55 | ''' 56 | inputs = Input(shape=input_shape) 57 | n_dim = round(0.75 * input_shape[0]) 58 | x = Dense(n_dim, activation='selu', 59 | kernel_initializer='lecun_normal')(inputs) 60 | x = AlphaDropout(0.25)(x) 61 | x = Dense(n_dim, activation='selu', 62 | kernel_initializer='lecun_normal')(x) 63 | x = AlphaDropout(0.25)(x) 64 | x = Dense(n_dim, activation='selu', kernel_initializer='lecun_normal')(x) 65 | return Model(inputs, x) 66 | 67 | 68 | def generator_from_index(adata, batch_name, celltype_name=None, mask_batch=None, Y = None, k = 20, label_ratio = 0.8, k_to_m_ratio = 0.75, batch_size = 32, search_k=-1, 69 | save_on_disk = True, approx = True, verbose=1): 70 | 71 | print('version 0.0.2. 09:00, 12/01/2020') 72 | 73 | # Calculate MNNs by pairwise comparison between batches 74 | 75 | cells = adata.obs_names 76 | 77 | if(verbose > 0): 78 | print("Calculating MNNs...") 79 | 80 | mnn_dict = create_dictionary_mnn(adata, batch_name=batch_name, k = k, save_on_disk = save_on_disk, approx = approx, verbose = verbose) 81 | 82 | if(verbose > 0): 83 | print(str(len(mnn_dict)) + " cells defined as MNNs") 84 | 85 | if celltype_name is None: 86 | label_dict=dict() 87 | else: 88 | 89 | if (verbose > 0): 90 | print ('Generating supervised positive pairs...') 91 | 92 | label_dict_original = create_dictionary_label(adata, celltype_name= celltype_name, batch_name=batch_name, mask_batch=mask_batch, k=k, verbose=verbose) 93 | num_label = round(label_ratio * len(label_dict_original)) 94 | 95 | cells_for_label = np.random.choice(list(label_dict_original.keys()), num_label, replace = False) 96 | 97 | label_dict = {key: value for key, value in label_dict_original.items() if key in cells_for_label} 98 | 99 | if(verbose > 0): 100 | print(str(len(label_dict.keys())) + " cells defined as supervision triplets") 101 | 102 | print (len(set(mnn_dict.keys())&set(label_dict.keys()))) 103 | 104 | if k_to_m_ratio == 0.0: 105 | knn_dict = dict() 106 | 107 | else: 108 | num_k = round(k_to_m_ratio * len(mnn_dict)) 109 | # Calculate KNNs for subset of residual cells 110 | cells_for_knn = list(set(cells) - (set(list(label_dict.keys()))| set(list(mnn_dict.keys())))) 111 | if(len(cells_for_knn) > num_k): 112 | cells_for_knn = np.random.choice(cells_for_knn, num_k, replace = False) 113 | 114 | if(verbose > 0): 115 | print("Calculating KNNs...") 116 | 117 | cdata = adata[cells_for_knn] 118 | knn_dict = create_dictionary_knn(cdata, cells_for_knn, k = k, save_on_disk = save_on_disk, approx = approx) 119 | if(verbose > 0): 120 | print(str(len(cells_for_knn)) + " cells defined as KNNs") 121 | 122 | final_dict = merge_dict(mnn_dict, label_dict) 123 | final_dict.update(knn_dict) 124 | 125 | 126 | cells_for_train = list(final_dict.keys()) 127 | print ('Total cells for training:'+ str(len(cells_for_train))) 128 | 129 | ddata = adata[cells_for_train] 130 | 131 | # Reorder triplet list according to cells 132 | if(verbose > 0): 133 | print("Reorder") 134 | names_as_dict = dict(zip(list(adata.obs_names), range(0, adata.shape[0]))) 135 | def get_indices2(name): 136 | return([names_as_dict[x] for x in final_dict[name]]) 137 | 138 | triplet_list = list(map(get_indices2, cells_for_train)) 139 | 140 | batch_list = ddata.obs[batch_name] 141 | batch_indices = [] 142 | for i in batch_list.unique(): 143 | batch_indices.append(list(np.where(batch_list == i)[0])) 144 | 145 | batch_as_dict = dict(zip(list(batch_list.unique()), range(0, len(batch_list.unique())))) 146 | tmp = map(lambda _: batch_as_dict[_], batch_list) 147 | batch_list = list(tmp) 148 | 149 | if Y is None: 150 | return KnnTripletGenerator(X = ddata.obsm["X_pca"], X1 = adata.obsm['X_pca'], dictionary = triplet_list, 151 | batch_list = batch_list, batch_indices = batch_indices, batch_size=batch_size) 152 | 153 | else: 154 | tmp = dict(zip(cells, Y)) 155 | Y_new = [tmp[x] for x in cells_for_train] 156 | Y_new = le.fit_transform(Y_new) 157 | return LabeledKnnTripletGenerator(X = ddata.obsm["X_pca"], X1 = adata.obsm['X_pca'], Y = Y_new, dictionary = triplet_list, 158 | batch_list = batch_list, batch_indices = batch_indices, batch_size = batch_size) 159 | 160 | 161 | def merge_dict(x,y): 162 | for k,v in x.items(): 163 | if k in y.keys(): 164 | y[k] += v 165 | else: 166 | y[k] = v 167 | return y 168 | 169 | 170 | class KnnTripletGenerator(Sequence): 171 | 172 | def __init__(self, X, X1, dictionary, batch_list, batch_indices, batch_size=32): 173 | self.X = X 174 | self.X1 = X1 175 | self.batch_list = batch_list 176 | self.batch_indices = batch_indices 177 | self.batch_size = batch_size 178 | self.dictionary = dictionary 179 | self.placeholder_labels = np.empty(batch_size, dtype=np.uint8) 180 | self.num_cells = len(self.dictionary) 181 | 182 | def __len__(self): 183 | return int(np.ceil(len(self.dictionary) / float(self.batch_size))) 184 | 185 | def __getitem__(self, idx): 186 | 187 | batch_indices = range(idx * self.batch_size, min((idx + 1) * self.batch_size, self.num_cells)) 188 | 189 | triplet_batch = [self.knn_triplet_from_dictionary(row_index = row_index, 190 | neighbour_list = self.dictionary[row_index], 191 | batch = self.batch_list[row_index], 192 | num_cells = self.num_cells) for row_index in batch_indices] 193 | 194 | if (issparse(self.X)): 195 | triplet_batch = [[e.toarray()[0] for e in t] for t in triplet_batch] 196 | 197 | triplet_batch = np.array(triplet_batch) 198 | placeholder_labels = self.placeholder_labels[:triplet_batch.shape[0]] 199 | 200 | return tuple([triplet_batch[:, 0], triplet_batch[:, 1], triplet_batch[:, 2]]), placeholder_labels 201 | 202 | def knn_triplet_from_dictionary(self, row_index, neighbour_list, batch, num_cells): 203 | """ A random (unweighted) positive example chosen. """ 204 | triplets = [] 205 | 206 | anchor = row_index 207 | positive = np.random.choice(neighbour_list) 208 | negative = self.batch_indices[batch][np.random.randint(len(self.batch_indices[batch]))] 209 | 210 | triplets += [self.X[anchor], self.X1[positive], 211 | self.X1[negative]] 212 | 213 | return triplets 214 | 215 | 216 | class LabeledKnnTripletGenerator(Sequence): 217 | def __init__(self, X, X1, Y, dictionary, batch_list, batch_indices, batch_size=32): 218 | self.X = X 219 | self.X1 = X1 220 | self.Y = Y 221 | self.batch_list = batch_list 222 | self.batch_indices = batch_indices 223 | self.batch_size = batch_size 224 | self.dictionary = dictionary 225 | self.num_cells = len(self.dictionary) 226 | 227 | def __len__(self): 228 | return int(np.ceil(len(self.dictionary) / float(self.batch_size))) 229 | 230 | def __getitem__(self, idx): 231 | 232 | batch_indices = range(idx * self.batch_size, min((idx + 1) * self.batch_size, self.num_cells)) 233 | 234 | triplet_batch = [self.knn_triplet_from_dictionary(row_index = row_index, 235 | neighbour_list = self.dictionary[row_index], 236 | batch = self.batch_list[row_index], 237 | num_cells = self.num_cells) for row_index in batch_indices] 238 | 239 | if (issparse(self.X)): 240 | triplet_batch = [[e.toarray()[0] for e in t] for t in triplet_batch] 241 | 242 | triplet_batch = np.array(triplet_batch) 243 | label_batch = self.Y[batch_indices] 244 | 245 | return tuple([triplet_batch[:, 0], triplet_batch[:, 1], triplet_batch[:, 2]]), tuple([np.array(label_batch), np.array(label_batch)]) 246 | 247 | def knn_triplet_from_dictionary(self, row_index, neighbour_list, batch, num_cells): 248 | """ A random (unweighted) positive example chosen. """ 249 | triplets = [] 250 | 251 | anchor = row_index 252 | 253 | positive = np.random.choice(neighbour_list) 254 | negative = self.batch_indices[batch][np.random.randint(len(self.batch_indices[batch]))] 255 | 256 | triplets += [self.X[anchor], self.X1[positive], 257 | self.X1[negative]] 258 | 259 | 260 | return triplets 261 | 262 | 263 | def create_dictionary_label(bdata, celltype_name, batch_name, mask_batch, k=50, verbose=1): 264 | 265 | #cell_names = adata.obs_names 266 | adata = bdata[bdata.obs[batch_name]!=mask_batch] 267 | batch_list = adata.obs[batch_name] 268 | cell_types = adata.obs[celltype_name] 269 | 270 | print (batch_list.unique()) 271 | 272 | types = [] 273 | for i in batch_list.unique(): 274 | types.append(cell_types[batch_list == i]) 275 | 276 | print (len(types)) 277 | 278 | labeled_dict = dict() 279 | 280 | for comb in list(itertools.permutations(range(len(types)), 2)): 281 | 282 | i = comb[0] 283 | j = comb[1] 284 | 285 | if(verbose > 0): 286 | print('Processing positive pairs {}'.format((i, j))) 287 | 288 | ref_types = types[i] 289 | new_types = types[j] 290 | common = set(ref_types) & set(new_types) 291 | 292 | for each in common: 293 | ref = list(ref_types[ref_types==each].index) 294 | new = list(new_types[new_types==each].index) 295 | 296 | num_k =min(int(k/10), 5,len(new)) 297 | 298 | for key in ref: 299 | new_cells = np.random.choice(new, num_k, replace = False) 300 | if key not in labeled_dict.keys(): 301 | 302 | labeled_dict[key] = list(new_cells) 303 | else: 304 | labeled_dict[key] += list(new_cells) 305 | 306 | return(labeled_dict) 307 | 308 | 309 | def create_dictionary_mnn(adata, batch_name, k = 50, save_on_disk = True, approx = True, verbose = 1): 310 | 311 | cell_names = adata.obs_names 312 | 313 | batch_list = adata.obs[batch_name] 314 | datasets = [] 315 | datasets_pcs = [] 316 | cells = [] 317 | for i in batch_list.unique(): 318 | datasets.append(adata[batch_list == i]) 319 | datasets_pcs.append(adata[batch_list == i].obsm["X_pca"]) 320 | cells.append(cell_names[batch_list == i]) 321 | 322 | mnns = dict() 323 | 324 | for comb in list(itertools.combinations(range(len(cells)), 2)): 325 | i = comb[0] 326 | j = comb[1] 327 | 328 | if(verbose > 0): 329 | print('Processing datasets {}'.format((i, j))) 330 | 331 | new = list(cells[j]) 332 | ref = list(cells[i]) 333 | 334 | ds1 = adata[new].obsm['X_pca'] 335 | ds2 = adata[ref].obsm['X_pca'] 336 | names1 = new 337 | names2 = ref 338 | match = mnn(ds1, ds2, names1, names2, knn=k, save_on_disk = save_on_disk, approx = approx) 339 | 340 | G = nx.Graph() 341 | G.add_edges_from(match) 342 | node_names = np.array(G.nodes) 343 | anchors = list(node_names) 344 | adj = nx.adjacency_matrix(G) 345 | 346 | tmp = np.split(adj.indices, adj.indptr[1:-1]) 347 | 348 | for i in range(0, len(anchors)): 349 | key = anchors[i] 350 | i = tmp[i] 351 | names = list(node_names[i]) 352 | mnns[key] = names 353 | 354 | return(mnns) 355 | 356 | 357 | def create_dictionary_knn(adata, cell_subset, k = 50, save_on_disk = True, approx = True): 358 | 359 | cell_names = adata.obs_names 360 | 361 | dataset = adata[cell_subset] 362 | pcs = dataset.obsm['X_pca'] 363 | 364 | def get_names(ind): 365 | return np.array(cell_subset)[ind] 366 | 367 | if approx: 368 | dim = pcs.shape[1] 369 | num_elements = pcs.shape[0] 370 | p = hnswlib.Index(space='l2', dim = dim) 371 | p.init_index(max_elements=num_elements, ef_construction=100, M=16) 372 | p.set_ef(10) 373 | p.add_items(pcs) 374 | ind, distances = p.knn_query(pcs, k=k) 375 | 376 | cell_subset = np.array(cell_subset) 377 | names = list(map(lambda x: cell_subset[x], ind)) 378 | knns = dict(zip(cell_subset, names)) 379 | 380 | else: 381 | nn_ = NearestNeighbors(n_neighbors = k, p = 2) 382 | nn_.fit(pcs) 383 | ind = nn_.kneighbors(pcs, return_distance=False) 384 | 385 | names = list(map(lambda x: cell_subset[x], ind)) 386 | knns = dict(zip(cell_subset, names)) 387 | 388 | return(knns) 389 | 390 | 391 | class TNN(BaseEstimator): 392 | 393 | def __init__(self, embedding_dims=2, k=150, distance='pn', batch_size=64, 394 | epochs=1000, n_epochs_without_progress=20, 395 | margin=1, ntrees=50, search_k=-1, 396 | precompute=True, save_on_disk=True, 397 | k_to_m_ratio = 0.75, 398 | label_ratio=0.75, 399 | supervision_metric='sparse_categorical_crossentropy', 400 | supervision_weight=0.3, annoy_index_path=None, 401 | approx = True, 402 | #sample_weight = None, 403 | callbacks=[], build_index_on_disk=None, verbose=1): 404 | 405 | self.embedding_dims = embedding_dims 406 | self.k = k 407 | self.distance = distance 408 | self.batch_size = batch_size 409 | self.epochs = epochs 410 | self.n_epochs_without_progress = n_epochs_without_progress 411 | self.margin = margin 412 | self.ntrees = ntrees 413 | self.search_k = search_k 414 | self.precompute = precompute 415 | self.model_def = "dummy" 416 | self.model_ = None 417 | self.encoder = None 418 | self.k_to_m_ratio = k_to_m_ratio 419 | self.label_ratio = label_ratio 420 | self.approx = approx 421 | self.supervision_metric = supervision_metric 422 | self.supervision_weight = supervision_weight 423 | self.supervised_model_ = None 424 | self.loss_history_ = [] 425 | self.annoy_index_path = annoy_index_path 426 | self.callbacks = callbacks 427 | #self.sample_weight = sample_weight 428 | self.save_on_disk = save_on_disk 429 | for callback in self.callbacks: 430 | if isinstance(callback, ModelCheckpoint): 431 | callback = callback.register_ivis_model(self) 432 | if build_index_on_disk is None: 433 | self.build_index_on_disk = True if platform.system() != 'Windows' else False 434 | else: 435 | self.build_index_on_disk = build_index_on_disk 436 | self.verbose = verbose 437 | 438 | def __getstate__(self): 439 | """ Return object serializable variable dict """ 440 | 441 | state = dict(self.__dict__) 442 | if 'model_' in state: 443 | state['model_'] = None 444 | if 'encoder' in state: 445 | state['encoder'] = None 446 | if 'supervised_model_' in state: 447 | state['supervised_model_'] = None 448 | if 'callbacks' in state: 449 | state['callbacks'] = [] 450 | if not isinstance(state['model_def'], str): 451 | state['model_def'] = None 452 | return state 453 | 454 | def _fit(self, X, batch_name, celltype_name=None, mask_batch=None, Y=None, shuffle_mode=True): 455 | 456 | datagen = generator_from_index(X, 457 | batch_name = batch_name, 458 | mask_batch=mask_batch, 459 | celltype_name = celltype_name, 460 | Y = Y, 461 | k_to_m_ratio = self.k_to_m_ratio, 462 | label_ratio = self.label_ratio, 463 | k=self.k, 464 | batch_size=self.batch_size, 465 | search_k=self.search_k, 466 | verbose = self.verbose, 467 | save_on_disk = self.save_on_disk, 468 | approx = self.approx) 469 | 470 | loss_monitor = 'loss' 471 | try: 472 | triplet_loss_func = triplet_loss(distance=self.distance, 473 | margin=self.margin) 474 | except KeyError: 475 | raise ValueError('Loss function `{}` not implemented.'.format(self.distance)) 476 | 477 | if self.model_ is None: 478 | if type(self.model_def) is str: 479 | input_size = (X.obsm['X_pca'].shape[-1],) 480 | self.model_, anchor_embedding, _, _ = \ 481 | triplet_network(base_network(input_size), 482 | embedding_dims=self.embedding_dims) 483 | else: 484 | self.model_, anchor_embedding, _, _ = \ 485 | triplet_network(self.model_def, 486 | embedding_dims=self.embedding_dims) 487 | 488 | if Y is None: 489 | 490 | self.model_.compile(optimizer='adam', loss=triplet_loss_func) 491 | else: 492 | Y = le.fit_transform(Y) 493 | if is_categorical(self.supervision_metric): 494 | if not is_multiclass(self.supervision_metric): 495 | if not is_hinge(self.supervision_metric): 496 | # Binary logistic classifier 497 | if len(Y.shape) > 1: 498 | self.n_classes = Y.shape[-1] 499 | else: 500 | self.n_classes = 1 501 | supervised_output = Dense(self.n_classes, activation='sigmoid', 502 | name='supervised')(anchor_embedding) 503 | else: 504 | # Binary Linear SVM output 505 | if len(Y.shape) > 1: 506 | self.n_classes = Y.shape[-1] 507 | else: 508 | self.n_classes = 1 509 | supervised_output = Dense(self.n_classes, activation='linear', 510 | name='supervised', 511 | kernel_regularizer=regularizers.l1(l1=0.01))(anchor_embedding) 512 | else: 513 | if not is_hinge(self.supervision_metric): 514 | validate_sparse_labels(Y) 515 | self.n_classes = len(np.unique(Y[Y != np.array(-1)])) 516 | # Softmax classifier 517 | supervised_output = Dense(self.n_classes, activation='softmax', 518 | name='supervised')(anchor_embedding) 519 | else: 520 | self.n_classes = len(np.unique(Y, axis=0)) 521 | # Multiclass Linear SVM output 522 | supervised_output = Dense(self.n_classes, activation='linear', 523 | name='supervised', 524 | kernel_regularizer=regularizers.l1(l1=0.01))(anchor_embedding) 525 | else: 526 | # Regression 527 | if len(Y.shape) > 1: 528 | self.n_classes = Y.shape[-1] 529 | else: 530 | self.n_classes = 1 531 | supervised_output = Dense(self.n_classes, activation='linear', 532 | name='supervised')(anchor_embedding) 533 | 534 | supervised_loss = keras.losses.get(self.supervision_metric) 535 | if self.supervision_metric == 'sparse_categorical_crossentropy': 536 | supervised_loss = semi_supervised_loss(supervised_loss) 537 | 538 | final_network = Model(inputs=self.model_.inputs, 539 | outputs=[self.model_.output, 540 | supervised_output]) 541 | self.model_ = final_network 542 | self.model_.compile( 543 | optimizer='adam', 544 | loss={ 545 | 'stacked_triplets': triplet_loss_func, 546 | 'supervised': supervised_loss 547 | }, 548 | loss_weights={ 549 | 'stacked_triplets': 1 - self.supervision_weight, 550 | 'supervised': self.supervision_weight}) 551 | 552 | # Store dedicated classification model 553 | supervised_model_input = Input(shape=(X.obsm['X_pca'].shape[-1],)) 554 | embedding = self.model_.layers[3](supervised_model_input) 555 | softmax_out = self.model_.layers[-1](embedding) 556 | 557 | self.supervised_model_ = Model(supervised_model_input, softmax_out) 558 | 559 | self.encoder = self.model_.layers[3] 560 | 561 | if self.verbose > 0: 562 | print('Training neural network') 563 | 564 | hist = self.model_.fit( 565 | datagen, 566 | epochs=self.epochs, 567 | callbacks=[callback for callback in self.callbacks] + 568 | [EarlyStopping(monitor=loss_monitor, 569 | patience=self.n_epochs_without_progress)], 570 | shuffle=shuffle_mode, 571 | workers = 10, 572 | verbose=self.verbose) 573 | 574 | self.loss_history_ += hist.history['loss'] 575 | 576 | def fit(self, X, batch_name, celltype_name=None, mask_batch=None, Y=None, shuffle_mode=True): 577 | """Fit model. 578 | Parameters 579 | ---------- 580 | X : Anndata object to be embedded. 581 | batch_name : name of column in Anndata.obs containing batch information 582 | Y : Optional array for supervised dimentionality reduction. 583 | Returns 584 | ------- 585 | returns an instance of self 586 | """ 587 | self._fit(X, batch_name, celltype_name, mask_batch, Y, shuffle_mode) 588 | return self 589 | 590 | def fit_transform(self, X, batch_name, celltype_name=None, mask_batch=None, Y=None, shuffle_mode=True): 591 | """Fit to data then transform 592 | Parameters 593 | ---------- 594 | X : Anndata object to be embedded. 595 | Y : Optional array for supervised dimentionality reduction. 596 | Returns 597 | ------- 598 | X_new : transformed array, shape (n_samples, embedding_dims) 599 | Embedding of the new data in low-dimensional space. 600 | """ 601 | self.fit(X, batch_name, celltype_name, mask_batch, Y, shuffle_mode) 602 | return self.transform(X) 603 | 604 | def transform(self, X): 605 | """Transform X into the existing embedded space and return that 606 | transformed output. 607 | Parameters 608 | ---------- 609 | X : Anndata object to be embedded. 610 | Returns 611 | ------- 612 | X_new : array, shape (n_samples, embedding_dims) 613 | Embedding of the new data in low-dimensional space. 614 | """ 615 | embedding = self.encoder.predict(X.obsm['X_pca'], verbose=self.verbose) 616 | return embedding 617 | 618 | def score_samples(self, X): 619 | """Passes X through classification network to obtain predicted 620 | supervised values. Only applicable when trained in 621 | supervised mode. 622 | Parameters 623 | ---------- 624 | X : array, shape (n_samples, n_features) 625 | Data to be passed through classification network. 626 | Returns 627 | ------- 628 | X_new : array, shape (n_samples, embedding_dims) 629 | Softmax class probabilities of the data. 630 | """ 631 | if self.supervised_model_ is None: 632 | raise Exception("Model was not trained in classification mode.") 633 | 634 | softmax_output = self.supervised_model_.predict(X, verbose=self.verbose) 635 | return softmax_output 636 | 637 | 638 | def semi_supervised_loss(loss_function): 639 | def new_loss_function(y_true, y_pred): 640 | mask = tf.cast(~tf.math.equal(y_true, -1), tf.float32) 641 | y_true_pos = tf.nn.relu(y_true) 642 | loss = loss_function(y_true_pos, y_pred) 643 | masked_loss = loss * mask 644 | return masked_loss 645 | new_func = new_loss_function 646 | new_func.__name__ = loss_function.__name__ 647 | return new_func 648 | 649 | 650 | def validate_sparse_labels(Y): 651 | if not zero_indexed(Y): 652 | raise ValueError('Ensure that your labels are zero-indexed') 653 | if not consecutive_indexed(Y): 654 | raise ValueError('Ensure that your labels are indexed consecutively') 655 | 656 | 657 | def zero_indexed(Y): 658 | if min(abs(Y)) != 0: 659 | return False 660 | return True 661 | 662 | 663 | def consecutive_indexed(Y): 664 | """ Assumes that Y is zero-indexed. """ 665 | n_classes = len(np.unique(Y[Y != np.array(-1)])) 666 | if max(Y) >= n_classes: 667 | return False 668 | return True 669 | 670 | 671 | def nn_approx(ds1, ds2, names1, names2, knn=50): 672 | dim = ds2.shape[1] 673 | num_elements = ds2.shape[0] 674 | p = hnswlib.Index(space='l2', dim=dim) 675 | p.init_index(max_elements=num_elements, ef_construction=100, M = 16) 676 | p.set_ef(10) 677 | p.add_items(ds2) 678 | ind, distances = p.knn_query(ds1, k=knn) 679 | 680 | match = set() 681 | for a, b in zip(range(ds1.shape[0]), ind): 682 | for b_i in b: 683 | match.add((names1[a], names2[b_i])) 684 | 685 | return match 686 | 687 | 688 | def nn(ds1, ds2, names1, names2, knn=50, metric_p=2): 689 | # Find nearest neighbors of first dataset. 690 | nn_ = NearestNeighbors(knn, p=metric_p) 691 | nn_.fit(ds2) 692 | ind = nn_.kneighbors(ds1, return_distance=False) 693 | 694 | match = set() 695 | for a, b in zip(range(ds1.shape[0]), ind): 696 | for b_i in b: 697 | match.add((names1[a], names2[b_i])) 698 | 699 | return match 700 | 701 | 702 | def nn_annoy(ds1, ds2, names1, names2, knn = 20, metric='euclidean', n_trees = 50, save_on_disk = True): 703 | """ Assumes that Y is zero-indexed. """ 704 | # Build index. 705 | a = AnnoyIndex(ds2.shape[1], metric=metric) 706 | if(save_on_disk): 707 | a.on_disk_build('annoy.index') 708 | for i in range(ds2.shape[0]): 709 | a.add_item(i, ds2[i, :]) 710 | a.build(n_trees) 711 | 712 | # Search index. 713 | ind = [] 714 | for i in range(ds1.shape[0]): 715 | ind.append(a.get_nns_by_vector(ds1[i, :], knn, search_k=-1)) 716 | ind = np.array(ind) 717 | 718 | # Match. 719 | match = set() 720 | for a, b in zip(range(ds1.shape[0]), ind): 721 | for b_i in b: 722 | match.add((names1[a], names2[b_i])) 723 | 724 | return match 725 | 726 | 727 | def mnn(ds1, ds2, names1, names2, knn = 20, save_on_disk = True, approx = True): 728 | # Find nearest neighbors in first direction. 729 | if approx: 730 | match1 = nn_approx(ds1, ds2, names1, names2, knn=knn)#, save_on_disk = save_on_disk) 731 | # Find nearest neighbors in second direction. 732 | match2 = nn_approx(ds2, ds1, names2, names1, knn=knn)#, save_on_disk = save_on_disk) 733 | else: 734 | match1 = nn(ds1, ds2, names1, names2, knn=knn) 735 | match2 = nn(ds2, ds1, names2, names1, knn=knn) 736 | # Compute mutual nearest neighbors. 737 | mutual = match1 & set([ (b, a) for a, b in match2 ]) 738 | 739 | return mutual 740 | -------------------------------------------------------------------------------- /tnn/version.py: -------------------------------------------------------------------------------- 1 | VERSION = '0.0.2' 2 | -------------------------------------------------------------------------------- /umap_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lkmklsmn/insct/c2eb12df8f69d330996c0f5c43ae39f292cb02ea/umap_embedding.png --------------------------------------------------------------------------------