├── .gitignore
├── Figure.png
├── LICENSE
├── MultiCPA
│   ├── __init__.py
│   ├── api.py
│   ├── data.py
│   ├── helper.py
│   ├── model.py
│   ├── plotting.py
│   ├── seml_sweep_icb.py
│   └── train.py
├── README.md
└── environment_multicpa.yml

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | *.DS_Store
131 | .idea*
132 | 

--------------------------------------------------------------------------------
/Figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theislab/multicpa/a08ace3429523c73e1b966e793cb0dba2686db6a/Figure.png

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) 2022, Theis Lab
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 |    list of conditions and the following disclaimer.
11 | 
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 |    this list of conditions and the following disclaimer in the documentation
14 |    and/or other materials provided with the distribution.
15 | 
16 | 3. Neither the name of the copyright holder nor the names of its
17 |    contributors may be used to endorse or promote products derived from
18 |    this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | 

--------------------------------------------------------------------------------
/MultiCPA/__init__.py:
--------------------------------------------------------------------------------
1 | # Author: Kemal Inecik
2 | # Email: k.inecik@gmail.com
3 | 

--------------------------------------------------------------------------------
/MultiCPA/api.py:
--------------------------------------------------------------------------------
1 | # Author: Kemal Inecik
2 | # Email: k.inecik@gmail.com
3 | 
4 | import numpy as np
5 | import sys
6 | import torch
7 | import scanpy as sc
8 | import pandas as pd
9 | import re
10 | import itertools
11 | from sklearn.metrics import r2_score
12 | from sklearn.metrics.pairwise import cosine_distances, euclidean_distances
13 | from MultiCPA.data import SubDataset
14 | import copy
15 | 
16 | class ComPertAPI:
17 |     """
18 |     API for ComPert model to make it compatible with scanpy.
19 |     """
20 |     def __init__(self, datasets, model):
21 |         """
22 |         Parameters
23 |         ----------
24 |         datasets : dict
25 |             Dictionary of dataset splits; the 'training' split configures the API.
26 |         model : ComPertModel
27 |             Pre-trained ComPert model.
28 |         """
29 |         dataset = datasets['training']
30 |         self.perturbation_key = dataset.perturbation_key
31 |         self.dose_key = dataset.dose_key
32 |         self.covars_key = dataset.covars_key
33 |         self.min_dose = dataset.drugs[dataset.drugs > 0].min().item()
34 |         self.max_dose = dataset.drugs[dataset.drugs > 0].max().item()
35 | 
36 |         self.model = model
37 |         self.var_names = dataset.var_names
38 | 
39 |         self.unique_perts = list(dataset.perts_dict.keys())
40 |         self.unique_сovars = list(dataset.covars_dict.keys())
41 |         self.num_drugs = dataset.num_drugs
42 | 
43 |         self.perts_dict = dataset.perts_dict
44 |         self.covars_dict = dataset.covars_dict
45 | 
46 |         self.drug_ohe = torch.Tensor(list(dataset.perts_dict.values()))
47 |         self.covars_ohe = torch.LongTensor(list(dataset.covars_dict.values()))
48 | 
49 |         self.emb_covars = None
50 |         self.emb_perts = None
51 |         self.seen_covars_perts = None
52 |         self.comb_emb = None
53 |         self.control_cat = None
54 | 
55 |         self.seen_covars_perts = {}
56 |         for k in datasets.keys():
57 |             self.seen_covars_perts[k] = np.unique(datasets[k].pert_categories)
58 | 
59 |         self.measured_points = {}
60 |         self.num_measured_points = {}
61 |         for k in datasets.keys():
62 |             self.measured_points[k] = {}
63 |             self.num_measured_points[k] = {}
64 |             for pert in np.unique(datasets[k].pert_categories):
65 |                 num_points = len(np.where(datasets[k].pert_categories == pert)[0])
66 |                 self.num_measured_points[k][pert] = num_points
67 | 
68 |                 cov, drug, dose = pert.split('_')
69 |                 if not('+' in dose):
70 |                     dose = float(dose)
71 |                 if cov in self.measured_points[k].keys():
72 |                     if drug in self.measured_points[k][cov].keys():
73 |                         self.measured_points[k][cov][drug].append(dose)
74 |                     else:
75 |                         self.measured_points[k][cov][drug] = [dose]
76 |                 else:
77 |                     self.measured_points[k][cov] = {drug: [dose]}
78 | 
79 |         self.measured_points['all'] = copy.deepcopy(self.measured_points['training'])
80 |         for cov in self.measured_points['ood'].keys():
81 |             for pert in self.measured_points['ood'][cov].keys():
82 |                 if pert in self.measured_points['training'][cov].keys():
83 |                     self.measured_points['all'][cov][pert] =\
84 |                         self.measured_points['training'][cov][pert].copy()+\
85 |                         self.measured_points['ood'][cov][pert].copy()
86 |                 else:
87 |                     self.measured_points['all'][cov][pert] =\
88 |                         self.measured_points['ood'][cov][pert].copy()
89 | 
90 | 
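    # Usage sketch (added for illustration; names are hypothetical). Assuming
    # `datasets` comes from MultiCPA.data.load_dataset_splits and `autoencoder`
    # is a trained model (e.g. restored via helper.model_importer):
    #
    #     from MultiCPA.api import ComPertAPI
    #     api = ComPertAPI(datasets, autoencoder)
    #     drug_emb = api.get_drug_embeddings(dose=1.0)   # AnnData of drug latents
    #     covar_emb = api.get_covars_embeddings()        # AnnData of covariate latents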
91 |     def get_drug_embeddings(self, dose=1.0, return_anndata=True):
92 |         """
93 |         Parameters
94 |         ----------
95 |         dose : int (default: 1.0)
96 |             Dose at which to evaluate latent embedding vector.
97 |         return_anndata : bool, optional (default: True)
98 |             Return embedding wrapped into anndata object.
99 | 
100 |         Returns
101 |         -------
102 |         If return_anndata is True, returns anndata object. Otherwise, doesn't
103 |         return anything. Always saves embedding in self.emb_perts.
104 |         """
105 |         self.emb_perts = self.model.compute_drug_embeddings_(dose*\
106 |             self.drug_ohe.to(self.model.device)).cpu().clone().detach().numpy()
107 |         if return_anndata:
108 |             adata = sc.AnnData(self.emb_perts)
109 |             adata.obs[self.perturbation_key] = self.unique_perts
110 |             return adata
111 | 
112 |     def get_covars_embeddings(self, return_anndata=True):
113 |         """
114 |         Parameters
115 |         ----------
116 |         return_anndata : bool, optional (default: True)
117 |             Return embedding wrapped into anndata object.
118 | 
119 |         Returns
120 |         -------
121 |         If return_anndata is True, returns anndata object. Otherwise, doesn't
122 |         return anything. Always saves embedding in self.emb_covars.
123 |         """
124 |         self.emb_covars = self.model.cell_type_embeddings(
125 |             self.covars_ohe.to(self.model.device).argmax(1)
126 |         ).cpu().clone().detach().numpy()
127 | 
128 |         if return_anndata:
129 |             adata = sc.AnnData(self.emb_covars)
130 |             adata.obs[self.covars_key] = self.unique_сovars
131 |             return adata
132 | 
133 |     def get_drug_encoding_(self, drugs, doses=None):
134 |         """
135 |         Parameters
136 |         ----------
137 |         drugs : str
138 |             Drugs combination as a string, where individual drugs are separated
139 |             with a plus.
140 |         doses : str, optional (default: None)
141 |             Doses corresponding to the drugs combination as a string. Individual
142 |             drugs are separated with a plus.
143 | 
144 |         Returns
145 |         -------
146 |         One hot encoding for a mixture of drugs.
147 |         """
148 | 
149 |         drug_mix = np.zeros([1, self.num_drugs])
150 |         atomic_drugs = drugs.split('+')
151 |         # NOTE: `doses` may legitimately be None; convert it only in the else branch.
152 | 
153 |         if doses is None:
154 |             doses_list = [1.0]*len(atomic_drugs)
155 |         else:
156 |             doses_list = [float(d) for d in str(doses).split('+')]
157 |         for j, drug in enumerate(atomic_drugs):
158 |             drug_mix += doses_list[j]*self.perts_dict[drug]
159 | 
160 |         return drug_mix
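    # Example (hypothetical drug names; `api` is a ComPertAPI instance):
    #
    #     ohe = api.get_drug_encoding_('drugA+drugB', doses='0.1+1.0')
    #     # `ohe` has shape (1, num_drugs) and equals
    #     # 0.1 * onehot(drugA) + 1.0 * onehot(drugB)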
161 | 
162 |     def mix_drugs(self, drugs_list, doses_list=None, return_anndata=True):
163 |         """
164 |         Gets a list of drugs combinations to mix, e.g. ['A+B', 'B+C'] and
165 |         corresponding doses.
166 | 
167 |         Parameters
168 |         ----------
169 |         drugs_list : list
170 |             List of drug combinations, where each drug combination is a string.
171 |             Individual drugs in the combination are separated with a plus.
172 |         doses_list : str, optional (default: None)
173 |             List of corresponding doses, where each dose combination is a string.
174 |             Individual doses in the combination are separated with a plus.
175 |         return_anndata : bool, optional (default: True)
176 |             Return embedding wrapped into anndata object.
177 | 
178 |         Returns
179 |         -------
180 |         If return_anndata is True, returns anndata structure of the combinations,
181 |         otherwise returns a np.array of corresponding embeddings.
182 |         """
183 | 
184 |         drug_mix = np.zeros([len(drugs_list), self.num_drugs])
185 |         for i, drug_combo in enumerate(drugs_list):
186 |             drug_mix[i] = self.get_drug_encoding_(drug_combo, doses=doses_list[i])
187 | 
188 |         emb = self.model.compute_drug_embeddings_(torch.Tensor(drug_mix).to(
189 |             self.model.device)).cpu().clone().detach().numpy()
190 | 
191 |         if return_anndata:
192 |             adata = sc.AnnData(emb)
193 |             adata.obs[self.perturbation_key] = drugs_list
194 |             adata.obs[self.dose_key] = doses_list
195 |             return adata
196 |         else:
197 |             return emb
198 | 
199 |     def latent_dose_response(self, perturbations=None, dose=None,
200 |         contvar_min=0, contvar_max=1, n_points=100):
201 |         """
202 |         Parameters
203 |         ----------
204 |         perturbations : list, optional (default: None)
205 |             List of atomic drugs for which to return the latent dose response.
206 |             If None, all known perturbations are used.
207 |         dose : np.array (default: None)
208 |             Dose values. If None, default values will be generated on a grid:
209 |             n_points in range [contvar_min, contvar_max].
210 |         contvar_min : float (default: 0)
211 |             Minimum dose value to generate for default option.
212 |         contvar_max : float (default: 1)
213 |             Maximum dose value to generate for default option.
214 |         n_points : int (default: 100)
215 |             Number of dose points to generate for default option.
216 |         Returns
217 |         -------
218 |         pd.DataFrame
219 |         """
220 |         # dosers work only for atomic drugs. TODO add drug combinations
221 |         self.model.eval()
222 | 
223 |         if perturbations is None:
224 |             perturbations = self.unique_perts
225 | 
226 |         if dose is None:
227 |             dose = np.linspace(contvar_min, contvar_max, n_points)
228 |         n_points = len(dose)
229 | 
230 |         df = pd.DataFrame(columns=[self.perturbation_key, self.dose_key,\
231 |             'response'])
232 |         for drug in perturbations:
233 |             d = np.where(self.perts_dict[drug] == 1)[0][0]
234 |             this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
235 |             if self.model.doser_type == 'mlp':
236 |                 response = (self.model.dosers[d](this_drug).sigmoid() *\
237 |                     this_drug.gt(0)).cpu().clone().detach().numpy().reshape(-1)
238 |             else:
239 |                 response = self.model.dosers.one_drug(this_drug.view(-1),\
240 |                     d).cpu().clone().detach().numpy().reshape(-1)
241 | 
242 |             df_drug = pd.DataFrame(list(zip([drug]*n_points, dose, list(response))),
243 |                 columns=[self.perturbation_key, self.dose_key, 'response'])
244 |             df = pd.concat([df, df_drug])
245 | 
246 |         return df
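    # Example (hypothetical names): latent response curves for two drugs on a
    # 100-point dose grid in [0, 1], returned as a tidy DataFrame.
    #
    #     curves = api.latent_dose_response(perturbations=['drugA', 'drugB'])
    #     curves.head()  # columns: perturbation, dose, response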
247 | 
248 |     def latent_dose_response2D(self, perturbations, dose=None,
249 |         contvar_min=0, contvar_max=1, n_points=100,):
250 |         """
251 |         Parameters
252 |         ----------
253 |         perturbations : list
254 |             List of two atomic drugs for which to return the latent dose
255 |             response. Currently drug combinations are not supported.
256 |         dose : np.array (default: None)
257 |             Dose values. If None, default values will be generated on a grid:
258 |             n_points in range [contvar_min, contvar_max].
259 |         contvar_min : float (default: 0)
260 |             Minimum dose value to generate for default option.
261 |         contvar_max : float (default: 1)
262 |             Maximum dose value to generate for default option.
263 |         n_points : int (default: 100)
264 |             Number of dose points to generate for default option.
265 |         Returns
266 |         -------
267 |         pd.DataFrame
268 |         """
269 |         # dosers work only for atomic drugs. TODO add drug combinations
270 | 
271 |         assert len(perturbations) == 2, "You should provide a list of 2 perturbations."
272 | 
273 |         self.model.eval()
274 | 
275 |         if dose is None:
276 |             dose = np.linspace(contvar_min, contvar_max, n_points)
277 |         n_points = len(dose)
278 | 
279 |         df = pd.DataFrame(columns=perturbations + ['response'])
280 |         response = {}
281 | 
282 |         for drug in perturbations:
283 |             d = np.where(self.perts_dict[drug] == 1)[0][0]
284 |             this_drug = torch.Tensor(dose).to(self.model.device).view(-1, 1)
285 |             if self.model.doser_type == 'mlp':
286 |                 response[drug] = (self.model.dosers[d](this_drug).sigmoid() *\
287 |                     this_drug.gt(0)).cpu().clone().detach().numpy().reshape(-1)
288 |             else:
289 |                 response[drug] = self.model.dosers.one_drug(this_drug.view(-1),\
290 |                     d).cpu().clone().detach().numpy().reshape(-1)
291 | 
292 |         l = 0
293 |         for i in range(len(dose)):
294 |             for j in range(len(dose)):
295 |                 df.loc[l] = [dose[i], dose[j], response[perturbations[0]][i]+\
296 |                     response[perturbations[1]][j]]
297 |                 l += 1
298 | 
299 |         return df
300 | 
301 |     def compute_comb_emb(self, thrh=30):
302 |         """
303 |         Generates an AnnData object containing all the latent vectors of the
304 |         cov+dose*pert combinations seen during training.
305 |         Called in api.compute_uncertainty(), stores the AnnData in self.comb_emb.
306 | 
307 |         Parameters
308 |         ----------
309 |         thrh : int (default: 30)
310 |             Only conditions with more than `thrh` measured cells are included.
311 |         """
312 |         if self.seen_covars_perts['training'] is None:
313 |             raise ValueError('Need to run parse_training_conditions() first!')
314 | 
315 |         emb_covars = self.get_covars_embeddings(return_anndata=True)
316 | 
317 |         # Generate adata with all cov+pert latent vect combinations
318 |         tmp_ad_list = []
319 |         for cov_pert in self.seen_covars_perts['training']:
320 |             if self.num_measured_points['training'][cov_pert] > thrh:
321 |                 cov_loop, pert_loop, dose_loop = cov_pert.split('_')
322 |                 emb_perts_loop = []
323 |                 if '+' in pert_loop:
324 |                     pert_loop_list = pert_loop.split('+')
325 |                     dose_loop_list = dose_loop.split('+')
326 |                     for _dose in pd.Series(dose_loop_list).unique():
327 |                         tmp_ad = self.get_drug_embeddings(dose=float(_dose))
328 |                         tmp_ad.obs['pert_dose'] = tmp_ad.obs.condition + '_' + _dose
329 |                         emb_perts_loop.append(tmp_ad)
330 | 
331 |                     emb_perts_loop = emb_perts_loop[0].concatenate(emb_perts_loop[1:])
332 |                     X = (
333 |                         emb_covars.X[emb_covars.obs.cell_type == cov_loop]
334 |                         + np.expand_dims(
335 |                             emb_perts_loop.X[
336 |                                 emb_perts_loop.obs.pert_dose.isin(
337 |                                     [
338 |                                         pert_loop_list[i] + '_' + dose_loop_list[i]
339 |                                         for i in range(len(pert_loop_list))
340 |                                     ]
341 |                                 )
342 |                             ].sum(axis=0),
343 |                             axis=0
344 |                         )
345 |                     )
346 |                     if X.shape[0] > 1:
347 |                         raise ValueError('Error with comb computation')
348 |                 else:
349 |                     emb_perts = self.get_drug_embeddings(dose=float(dose_loop))
350 |                     X = (
351 |                         emb_covars.X[emb_covars.obs.cell_type == cov_loop]
352 |                         + emb_perts.X[emb_perts.obs.condition == pert_loop]
353 |                     )
354 |                 tmp_ad = sc.AnnData(
355 |                     X=X
356 |                 )
357 |                 tmp_ad.obs['cov_pert'] = '_'.join([cov_loop, pert_loop, dose_loop])
358 |                 tmp_ad_list.append(tmp_ad)
359 | 
360 |         self.comb_emb = tmp_ad_list[0].concatenate(tmp_ad_list[1:])
361 | 
362 |     def compute_uncertainty(
363 |         self,
364 |         cov,
365 |         pert,
366 |         dose,
367 |         thrh=30
368 |     ):
369 |         """
370 |         Compute uncertainties for the queried covariate+perturbation combination.
371 |         The distance from the closest condition in the training set is used as a
372 |         proxy for uncertainty.
373 | 
374 |         Parameters
375 |         ----------
376 |         cov: string
377 |             Covariate (eg. cell_type) for the queried uncertainty
378 |         pert: string
379 |             Perturbation for the queried uncertainty. In case of combinations the
380 |             format has to be 'pertA+pertB'
381 |         dose: string
382 |             String which contains the dose of the perturbation queried. In case
383 |             of combinations the format has to be 'doseA+doseB'
384 | 
385 |         Returns
386 |         -------
387 |         min_cos_dist: float
388 |             Minimum cosine distance with the training set.
389 |         min_eucl_dist: float
390 |             Minimum euclidean distance with the training set.
391 |         closest_cond_cos: string
392 |             Closest training condition wrt cosine distances.
393 |         closest_cond_eucl: string
394 |             Closest training condition wrt euclidean distances.
395 |         """
396 | 
397 |         if self.comb_emb is None:
398 |             self.compute_comb_emb(thrh=thrh)
399 | 
400 |         covar_ohe = torch.Tensor(
401 |             self.covars_dict[cov]
402 |         ).to(self.model.device)
403 | 
404 |         drug_ohe = torch.Tensor(
405 |             self.get_drug_encoding_(
406 |                 pert,
407 |                 doses=dose
408 |             )
409 |         ).to(self.model.device)
410 | 
411 |         cov = covar_ohe.expand([1, self.covars_ohe.shape[1]])
412 |         pert = drug_ohe.expand([1, self.drug_ohe.shape[1]])
413 | 
414 |         drug_emb = self.model.compute_drug_embeddings_(pert).detach().cpu().numpy()
415 |         cell_emb = self.model.cell_type_embeddings(cov.argmax(1)).detach().cpu().numpy()
416 |         cond_emb = drug_emb + cell_emb
417 | 
418 |         cos_dist = cosine_distances(cond_emb, self.comb_emb.X)[0]
419 |         min_cos_dist = np.min(cos_dist)
420 |         cos_idx = np.argmin(cos_dist)
421 |         closest_cond_cos = self.comb_emb.obs.cov_pert[cos_idx]
422 | 
423 |         eucl_dist = euclidean_distances(cond_emb, self.comb_emb.X)[0]
424 |         min_eucl_dist = np.min(eucl_dist)
425 |         eucl_idx = np.argmin(eucl_dist)
426 |         closest_cond_eucl = self.comb_emb.obs.cov_pert[eucl_idx]
427 | 
428 |         return min_cos_dist, min_eucl_dist, closest_cond_cos, closest_cond_eucl
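    # Example (hypothetical condition names): distance-based uncertainty for a
    # queried covariate/perturbation/dose triple.
    #
    #     cos_d, eucl_d, near_cos, near_eucl = api.compute_uncertainty(
    #         cov='cell_typeA', pert='drugA+drugB', dose='0.1+1.0')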
429 | 
430 |     def predict(
431 |         self,
432 |         genes,
433 |         df,
434 |         uncertainty=True,
435 |         return_anndata=True,
436 |         sample=False,
437 |         n_samples=10
438 |     ):
439 |         """Predict values of control 'genes' for the conditions specified in df.
440 | 
441 |         Parameters
442 |         ----------
443 |         genes : np.array
444 |             Control cells.
445 |         df : pd.DataFrame
446 |             Values for perturbations and covariates to generate.
447 |         uncertainty: bool (default: True)
448 |             Compute uncertainties for the generated cells.
449 |         return_anndata : bool, optional (default: True)
450 |             Return embedding wrapped into anndata object.
451 |         sample : bool (default: False)
452 |             If sample is True, returns samples from gaussian distribution with
453 |             mean and variance estimated by the model. Otherwise, returns just
454 |             means and variances estimated by the model.
455 |         n_samples : int (default: 10)
456 |             Number of samples to sample if sampling is True.
457 |         Returns
458 |         -------
459 |         If return_anndata is True, returns anndata structure. Otherwise, returns
460 |         np.arrays for gene_means, gene_vars and a data frame for the corresponding
461 |         conditions df_obs.
462 | 
463 |         """
464 |         self.model.eval()
465 |         num = genes.shape[0]
466 |         dim = genes.shape[1]
467 |         genes = torch.Tensor(genes).to(self.model.device)
468 |         if sample:
469 |             print('Careful! These are sampled values! Better use means and \
470 |                 variances for downstream tasks!')
471 | 
472 |         gene_means_list = []
473 |         gene_vars_list = []
474 |         df_list = []
475 | 
476 |         for i in range(len(df)):
477 |             comb_name = df.loc[i][self.perturbation_key]
478 |             dose_name = df.loc[i][self.dose_key]
479 |             covar_name = df.loc[i][self.covars_key]
480 | 
481 |             covar_ohe = torch.Tensor(
482 |                 self.covars_dict[covar_name]
483 |             ).to(self.model.device)
484 | 
485 |             drug_ohe = torch.Tensor(
486 |                 self.get_drug_encoding_(
487 |                     comb_name,
488 |                     doses=dose_name
489 |                 )
490 |             ).to(self.model.device)
491 | 
492 |             drugs = drug_ohe.expand([num, self.drug_ohe.shape[1]])
493 |             covars = covar_ohe.expand([num, self.covars_ohe.shape[1]])
494 | 
495 |             gene_reconstructions = self.model.predict(
496 |                 genes,
497 |                 drugs,
498 |                 covars
499 |             ).cpu().clone().detach().numpy()
500 | 
501 |             if sample:
502 |                 df_list.append(
503 |                     pd.DataFrame(
504 |                         [df.loc[i].values]*num*n_samples,
505 |                         columns=df.columns
506 |                     )
507 |                 )
508 |                 dist = torch.distributions.normal.Normal(
509 |                     torch.Tensor(gene_reconstructions[:, :dim]),
510 |                     torch.Tensor(gene_reconstructions[:, dim:]),
511 |                 )
512 |                 gene_means_list.append(
513 |                     dist
514 |                     .sample(torch.Size([n_samples]))
515 |                     .cpu()
516 |                     .detach()
517 |                     .numpy()
518 |                     .reshape(-1, dim)
519 |                 )
520 |             else:
521 |                 df_list.append(
522 |                     pd.DataFrame(
523 |                         [df.loc[i].values]*num,
524 |                         columns=df.columns
525 |                     )
526 |                 )
527 | 
528 |                 gene_means_list.append(
529 |                     gene_reconstructions[:, :dim]
530 |                 )
531 | 
532 |             if uncertainty:
533 |                 cos_dist, eucl_dist, closest_cond_cos, closest_cond_eucl =\
534 |                     self.compute_uncertainty(
535 |                         cov=covar_name,
536 |                         pert=comb_name,
537 |                         dose=dose_name
538 |                     )
539 |                 df_list[-1] = df_list[-1].assign(
540 |                     uncertainty_cosine=cos_dist,
541 |                     uncertainty_euclidean=eucl_dist,
542 |                     closest_cond_cosine=closest_cond_cos,
543 |                     closest_cond_euclidean=closest_cond_eucl
544 |                 )
545 |             gene_vars_list.append(
546 |                 gene_reconstructions[:, dim:]
547 |             )
548 | 
549 |         gene_means = np.concatenate(gene_means_list)
550 |         gene_vars = np.concatenate(gene_vars_list)
551 |         df_obs = pd.concat(df_list)
552 |         del df_list, gene_means_list, gene_vars_list
553 | 
554 |         if return_anndata:
555 |             adata = sc.AnnData(gene_means)
556 |             adata.var_names = self.var_names
557 |             adata.obs = df_obs
558 |             if not sample:
559 |                 adata.layers["variance"] = gene_vars
560 | 
561 |             adata.obs.index = adata.obs.index.astype(str)  # type fix
562 |             return adata
563 |         else:
564 |             return gene_means, gene_vars, df_obs
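    # Example (hypothetical names): counterfactual prediction for a batch of
    # control cells under one condition; `control_genes_np` is a (cells, genes)
    # numpy array of control expression.
    #
    #     cond = pd.DataFrame({api.perturbation_key: ['drugA'],
    #                          api.dose_key: ['1.0'],
    #                          api.covars_key: ['cell_typeA']})
    #     pred = api.predict(control_genes_np, cond)
    #     # AnnData: means in pred.X, per-gene variances in pred.layers['variance']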
565 | 
566 |     def get_response(
567 |         self,
568 |         datasets,
569 |         doses=None,
570 |         contvar_min=None,
571 |         contvar_max=None,
572 |         n_points=50,
573 |         ncells_max=100,
574 |         perturbations=None,
575 |         control_name='test_control'
576 |     ):
577 |         """Decoded dose response data frame.
578 | 
579 |         Parameters
580 |         ----------
581 |         datasets : dict
582 |             Dictionary of dataset splits (see data.load_dataset_splits).
583 |         doses : np.array (default: None)
584 |             Doses values. If None, default values will be generated on a grid:
585 |             n_points in range [contvar_min, contvar_max].
586 |         contvar_min : float (default: None)
587 |             Minimum dose value to generate for default option.
588 |         contvar_max : float (default: None)
589 |             Maximum dose value to generate for default option.
590 |         n_points : int (default: 50)
591 |             Number of dose points to generate for default option.
592 |         perturbations : list (default: None)
593 |             List of perturbations for dose response
594 | 
595 |         Returns
596 |         -------
597 |         pd.DataFrame
598 |             of decoded response values of genes and average response.
599 |         """
600 | 
601 |         if contvar_min is None:
602 |             contvar_min = self.min_dose
603 |         if contvar_max is None:
604 |             contvar_max = self.max_dose
605 | 
606 |         self.model.eval()
607 |         # doses = torch.Tensor(np.linspace(contvar_min, contvar_max, n_points))
608 |         if doses is None:
609 |             doses = np.linspace(contvar_min, contvar_max, n_points)
610 | 
611 |         if perturbations is None:
612 |             perturbations = self.unique_perts
613 | 
614 |         response = pd.DataFrame(columns=[self.covars_key,
615 |             self.perturbation_key,
616 |             self.dose_key,
617 |             'response'] + list(self.var_names))
618 | 
619 |         i = 0
620 |         for ict, ct in enumerate(self.unique_сovars):
621 |             # genes_control = dataset.genes[dataset.indices['control']]
622 |             genes_control =\
623 |                 datasets[control_name].genes[datasets[control_name].cell_types_names ==\
624 |                 ct].clone().detach()
625 |             if len(genes_control) < 1:
626 |                 print('Warning! Not enough control cells for this covariate.\
627 |                     Taking control cells from all covariates.')
628 |                 genes_control = datasets[control_name].genes
629 | 
630 |             if ncells_max < len(genes_control):
631 |                 ncells_max = min(ncells_max, len(genes_control))
632 |                 idx = torch.LongTensor(np.random.choice(range(len(genes_control)),\
633 |                     ncells_max, replace=False))
634 |                 genes_control = genes_control[idx]
635 | 
636 |             num, dim = genes_control.size(0), genes_control.size(1)
637 |             control_avg = genes_control.mean(dim=0).cpu().clone().detach().numpy().reshape(-1)
638 | 
639 |             for idr, drug in enumerate(perturbations):
640 |                 if not (drug in datasets[control_name].ctrl_name):
641 |                     for dose in doses:
642 |                         df = pd.DataFrame(data={self.covars_key: [ct],
643 |                             self.perturbation_key: [drug], self.dose_key: [dose]})
644 | 
645 |                         gene_means, _, _ =\
646 |                             self.predict(genes_control.cpu().detach().numpy(),\
647 |                             df, return_anndata=False)
648 |                         predicted_data = np.mean(gene_means, axis=0).reshape(-1)
649 | 
650 |                         response.loc[i] = [ct, drug, dose,
651 |                             np.linalg.norm(predicted_data-control_avg)] +\
652 |                             list(predicted_data - control_avg)
653 |                         i += 1
654 |         return response
655 | 
656 |     def get_response_reference(
657 |         self,
658 |         datasets,
659 |         perturbations=None
660 |     ):
661 | 
662 |         """Computes reference values of the response.
663 | 
664 |         Parameters
665 |         ----------
666 |         datasets : dict
667 |             Dictionary of dataset splits (see data.load_dataset_splits).
668 |         perturbations : list (default: None)
669 |             List of perturbations for dose response
670 | 
671 |         Returns
672 |         -------
673 |         pd.DataFrame
674 |             of decoded response values of genes and average response.
675 |         """
676 |         if perturbations is None:
677 |             perturbations = self.unique_perts
678 | 
679 |         reference_response_curve = pd.DataFrame(columns=[self.covars_key,
680 |             self.perturbation_key,
681 |             self.dose_key,
682 |             'split',
683 |             'num_cells',
684 |             'response'] +\
685 |             list(self.var_names))
686 | 
687 |         dataset_ctr = datasets['training_control']
688 | 
689 |         i = 0
690 |         for split in ['training_treated', 'ood']:
691 |             dataset = datasets[split]
692 |             for pert in self.seen_covars_perts[split]:
693 |                 ct, drug, dose_val = pert.split('_')
694 |                 if drug in perturbations:
695 |                     if not ('+' in dose_val):
696 |                         dose = float(dose_val)
697 |                     else:
698 |                         dose = dose_val
699 | 
700 |                     genes_control = dataset_ctr.genes[
701 |                         (dataset_ctr.cell_types_names == ct)].clone().detach()
702 |                     if len(genes_control) < 1:
703 |                         print('Warning! Not enough control cells for this covariate. \
704 |                             Taking control cells from all covariates.')
705 |                         genes_control = dataset_ctr.genes.clone().detach()
706 | 
707 |                     num, dim = genes_control.size(0), genes_control.size(1)
708 |                     control_avg =\
709 |                         genes_control.mean(dim=0).cpu().clone().detach().numpy().reshape(-1)
710 | 
711 |                     idx = np.where((dataset.pert_categories == pert))[0]
712 | 
713 |                     if len(idx):
714 |                         y_true = dataset.genes[idx, :].numpy().mean(axis=0)
715 |                         reference_response_curve.loc[i] = [ct, drug,
716 |                             dose, split, len(idx), np.linalg.norm(y_true - control_avg)] +\
717 |                             list(y_true - control_avg)
718 | 
719 |                         i += 1
720 | 
721 |         return reference_response_curve
722 | 
723 |     def get_response2D(
724 |         self,
725 |         datasets,
726 |         perturbations,
727 |         covar,
728 |         doses=None,
729 |         contvar_min=None,
730 |         contvar_max=None,
731 |         n_points=10,
732 |         ncells_max=100,
733 |         fixed_drugs='',
734 |         fixed_doses=''
735 |     ):
736 |         """Decoded dose response data frame.
737 | 
738 |         Parameters
739 |         ----------
740 |         datasets : dict
741 |             Dictionary of dataset splits (see data.load_dataset_splits).
742 |         perturbations : list
743 |             List of length 2 of perturbations for dose response.
744 |         covar : str
745 |             Name of a covariate for which to compute dose-response.
746 |         doses : np.array (default: None)
747 |             Doses values. If None, default values will be generated on a grid:
748 |             n_points in range [contvar_min, contvar_max].
749 |         contvar_min : float (default: None)
750 |             Minimum dose value to generate for default option.
751 |         contvar_max : float (default: None)
752 |             Maximum dose value to generate for default option.
753 |         n_points : int (default: 10)
754 |             Number of dose points to generate for default option.
755 | 
756 |         Returns
757 |         -------
758 |         pd.DataFrame
759 |             of decoded response values of genes and average response.
760 |         """
761 | 
762 |         assert len(perturbations) == 2, "You should provide a list of 2 perturbations."
763 | 
764 |         if contvar_min is None:
765 |             contvar_min = self.min_dose
766 | 
767 |         if contvar_max is None:
768 |             contvar_max = self.max_dose
769 | 
770 |         self.model.eval()
771 |         # doses = torch.Tensor(np.linspace(contvar_min, contvar_max, n_points))
772 |         if doses is None:
773 |             doses = np.linspace(contvar_min, contvar_max, n_points)
774 | 
775 |         # genes_control = dataset.genes[dataset.indices['control']]
776 |         genes_control =\
777 |             datasets['test_control'].genes[datasets['test_control'].cell_types_names ==\
778 |             covar].clone().detach()
779 |         if len(genes_control) < 1:
780 |             print('Warning! Not enough control cells for this covariate. \
781 |                 Taking control cells from all covariates.')
782 |             genes_control = datasets['test_control'].genes
783 | 
784 |         ncells_max = min(ncells_max, len(genes_control))
785 |         idx = torch.LongTensor(np.random.choice(range(len(genes_control)), ncells_max, replace=False))
786 |         genes_control = genes_control[idx]
787 | 
788 |         num, dim = genes_control.size(0), genes_control.size(1)
789 |         control_avg = genes_control.mean(dim=0).cpu().clone().detach().numpy().reshape(-1)
790 | 
791 |         response = pd.DataFrame(columns=perturbations + ['response'] +\
792 |             list(self.var_names))
793 | 
794 |         drug = perturbations[0] + '+' + perturbations[1]
795 | 
796 |         dose_vals = [f"{d[0]}+{d[1]}" for d in itertools.product(*[doses, doses])]
797 |         dose_comb = [list(d) for d in itertools.product(*[doses, doses])]
798 | 
799 |         i = 0
800 |         if not (drug in ['Vehicle', 'EGF', 'unst', 'control', 'ctrl']):
801 |             for dose in dose_vals:
802 |                 df = pd.DataFrame(data={self.covars_key: [covar],
803 |                     self.perturbation_key: [drug+fixed_drugs],\
804 |                     self.dose_key: [dose+fixed_doses]})
805 | 
806 |                 gene_means, _, _ = self.predict(
807 |                     genes_control.cpu().detach().numpy(), df,
808 |                     return_anndata=False)
809 | 
810 |                 predicted_data = np.mean(gene_means, axis=0).reshape(-1)
811 | 
812 |                 response.loc[i] = dose_comb[i] +\
813 |                     [np.linalg.norm(control_avg - predicted_data)] +\
814 |                     list(predicted_data - control_avg)
815 |                 i += 1
816 | 
817 |         return response
818 | 
819 |     def get_cycle_uncertainty(
820 |         self,
821 |         genes_from,
822 |         df_from,
823 |         df_to,
824 |         ncells_max=100,
825 |         direction='forward'
826 |     ):
827 | 
828 |         """Uncertainty for a single condition.
829 | 
830 |         Parameters
831 |         ----------
832 |         genes_from: torch.Tensor
833 |             Genes for comparison.
834 |         df_from: pd.DataFrame
835 |             Full description of the condition.
836 |         df_to: pd.DataFrame
837 |             Full description of the control condition.
838 |         ncells_max: int, optional (default: 100)
839 |             Max number of cells to use.
840 |         Returns
841 |         -------
842 |         tuple
843 |             with uncertainty estimations: (MSE, 1-R2).
844 |         """
845 |         self.model.eval()
846 |         genes_control = genes_from.clone().detach()
847 | 
848 |         if ncells_max < len(genes_control):
849 |             idx = torch.LongTensor(np.random.choice(range(len(genes_control)),\
850 |                 ncells_max, replace=False))
851 |             genes_control = genes_control[idx]
852 | 
853 |         gene_condition, _, _ = self.predict(genes_control, df_to,\
854 |             return_anndata=False, sample=False)
855 |         gene_condition = torch.Tensor(gene_condition).clone().detach()
856 |         gene_return, _, _ = self.predict(gene_condition, df_from,\
857 |             return_anndata=False, sample=False)
858 | 
859 |         if direction == 'forward':
860 |             # control -> condition -> control'
861 |             genes_control = genes_control.numpy()
862 |             ctr = np.mean(genes_control, axis=0)
863 |             ret = np.mean(gene_return, axis=0)
864 |             return np.mean((genes_control - gene_return)**2), 1-r2_score(ctr, ret)
865 |         else:
866 |             # control -> condition -> control' -> condition'
867 |             gene_return = torch.Tensor(gene_return).clone().detach()
868 |             gene_condition_return, _, _ = self.predict(gene_return, df_to,\
869 |                 return_anndata=False, sample=False)
870 |             gene_condition = gene_condition.numpy()
871 |             ctr = np.mean(gene_condition, axis=0)
872 |             ret = np.mean(gene_condition_return, axis=0)
873 |             return np.mean((gene_condition - gene_condition_return)**2),\
874 |                 1-r2_score(ctr, ret)
875 | 
876 |     def print_complete_cycle_uncertainty(
877 |         self,
878 |         datasets,
879 |         datasets_ctr,
880 |         ncells_max=1000,
881 |         split_list=['test', 'ood'],
882 |         direction='forward'
883 |     ):
884 |         uncert = pd.DataFrame(columns=[self.covars_key,
885 |             self.perturbation_key,
886 |             self.dose_key, 'split', 'MSE', '1-R2'])
887 | 
888 |         ctr_covar, ctrl_name, ctr_dose = datasets_ctr.pert_categories[0].split('_')
889 |         df_ctrl = pd.DataFrame({self.perturbation_key: [ctrl_name],
890 |             self.dose_key: [ctr_dose],
891 |             self.covars_key: [ctr_covar]})
892 | 
893 |         i = 0
894 |         for split in split_list:
895 |             dataset = datasets[split]
896 |             print(split)
897 |             for pert_cat in np.unique(dataset.pert_categories):
898 |                 idx = np.where(dataset.pert_categories == pert_cat)[0]
899 |                 genes = dataset.genes[idx, :]
900 | 
901 |                 covar, pert, dose = pert_cat.split('_')
902 |                 df_cond = pd.DataFrame({self.perturbation_key: [pert],
903 |                     self.dose_key: [dose],
904 |                     self.covars_key: [covar]})
905 | 
906 |                 if direction == 'back':
907 |                     # condition -> control -> condition
908 |                     uncert.loc[i] = [covar, pert, dose, split] +\
909 |                         list(self.get_cycle_uncertainty(genes, df_cond,\
910 |                         df_ctrl, ncells_max=ncells_max))
911 |                 else:
912 |                     # control -> condition -> control
913 |                     uncert.loc[i] = [covar, pert, dose, split] +\
914 |                         list(self.get_cycle_uncertainty(datasets_ctr.genes,\
915 |                         df_ctrl, df_cond, ncells_max=ncells_max,\
916 |                         direction=direction))
917 | 
918 |                 i += 1
919 | 
920 |         return uncert
921 | 
922 |     def evaluate_r2(
923 |         self,
924 |         dataset,
925 |         genes_control
926 |     ):
927 |         """
928 |         Measures different quality metrics of a ComPert `autoencoder` when
929 |         tasked to translate some `genes_control` into each of the drug/cell_type
930 |         combinations described in `dataset`.
931 | 
932 |         Considered metrics are the R2 scores of the predicted means and
933 |         variances over all genes, as well as over the differentially expressed
934 |         (_de) genes.
935 |         """
936 |         self.model.eval()
937 |         scores = pd.DataFrame(columns=[self.covars_key,
938 |             self.perturbation_key,
939 |             self.dose_key,
940 |             'R2_mean', 'R2_mean_DE', 'R2_var',
941 |             'R2_var_DE', 'num_cells'])
942 | 
943 |         num, dim = genes_control.size(0), genes_control.size(1)
944 | 
945 |         total_cells = len(dataset)
946 | 
947 |         icond = 0
948 |         for pert_category in np.unique(dataset.pert_categories):
949 |             # pert_category category contains: 'celltype_perturbation_dose' info
950 |             de_idx = np.where(
951 |                 dataset.var_names.isin(
952 |                     np.array(dataset.de_genes[pert_category])))[0]
953 | 
954 |             idx = np.where(dataset.pert_categories == pert_category)[0]
955 | 
956 |             if len(idx) > 0:
957 |                 emb_drugs = dataset.drugs[idx][0].view(
958 |                     1, -1).repeat(num, 1).clone()
959 |                 emb_cts = dataset.cell_types[idx][0].view(
960 |                     1, -1).repeat(num, 1).clone()
961 | 
962 |                 genes_predict = self.model.predict(
963 |                     genes_control, emb_drugs, emb_cts).detach().cpu()
964 | 
965 |                 mean_predict = genes_predict[:, :dim]
966 |                 var_predict = genes_predict[:, dim:]
967 | 
968 |                 # estimate metrics only for reasonably-sized drug/cell-type combos
969 | 
970 |                 y_true = dataset.genes[idx, :].numpy()
971 | 
972 |                 # true means and variances
973 |                 yt_m = y_true.mean(axis=0)
974 |                 yt_v = y_true.var(axis=0)
975 |                 # predicted means and variances
976 |                 yp_m = mean_predict.mean(0)
977 |                 yp_v = var_predict.mean(0)
978 | 
979 |                 mean_score = r2_score(yt_m, yp_m)
980 |                 var_score = r2_score(yt_v, yp_v)
981 | 
982 |                 mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
983 |                 var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])
984 |                 scores.loc[icond] = pert_category.split('_') +\
985 |                     [mean_score, mean_score_de, var_score, var_score_de, len(idx)]
986 |                 icond += 1
987 | 
988 |         return scores
989 | 
990 | 
991 | 
992 | def get_reference_from_combo(
993 |     perturbations_list,
994 |     datasets,
995 |     splits=['training', 'ood']
996 | ):
997 |     """
998 |     A simple function that produces a pd.DataFrame of individual
999 |     drugs-doses combinations used among the splits (for a fixed covariate).
1000 |     """
1001 |     df_list = []
1002 |     for split_name in splits:
1003 |         full_dataset = datasets[split_name]
1004 |         ref = {'num_cells': []}
1005 |         for pp in perturbations_list:
1006 |             ref[pp] = []
1007 | 
1008 |         ndrugs = len(perturbations_list)
1009 |         for pert_cat in np.unique(full_dataset.pert_categories):
1010 |             _, pert, dose = pert_cat.split('_')
1011 |             pert_list = pert.split('+')
1012 |             if set(pert_list) == set(perturbations_list):
1013 |                 dose_list = dose.split('+')
1014 |                 ncells = len(full_dataset.pert_categories[
1015 |                     full_dataset.pert_categories == pert_cat])
1016 |                 for j in range(ndrugs):
1017 |                     ref[pert_list[j]].append(float(dose_list[j]))
1018 |                 ref['num_cells'].append(ncells)
1019 |                 print(pert, dose, ncells)
1020 |         df = pd.DataFrame.from_dict(ref)
1021 |         df['split'] = split_name
1022 |         df_list.append(df)
1023 | 
1024 |     return pd.concat(df_list)
1025 | 
1026 | 
1027 | def linear_interp(y1, y2, x1, x2, x):
1028 |     a = (y1 - y2)/(x1 - x2)
1029 |     b = y1 - a*x1
1030 |     y = a*x + b
1031 |     return y
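# linear_interp fits the line through (x1, y1) and (x2, y2) and evaluates it
# at x. Quick check of the arithmetic:
#
#     linear_interp(y1=0.0, y2=1.0, x1=0.0, x2=2.0, x=1.0)  # -> 0.5
#
# It is used below to interpolate single-drug means/variances to an
# intermediate dose as a baseline for combination predictions.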
1032 | 
1033 | def evaluate_r2_benchmark(
1034 |     compert_api,
1035 |     datasets,
1036 |     pert_category,
1037 |     pert_category_list
1038 | ):
1039 |     scores = pd.DataFrame(columns=[compert_api.covars_key,
1040 |         compert_api.perturbation_key,
1041 |         compert_api.dose_key,
1042 |         'R2_mean', 'R2_mean_DE',
1043 |         'R2_var', 'R2_var_DE',
1044 |         'num_cells', 'benchmark', 'method'])
1045 | 
1046 |     de_idx = np.where(
1047 |         datasets['ood'].var_names.isin(
1048 |             np.array(datasets['ood'].de_genes[pert_category])))[0]
1049 |     idx = np.where(datasets['ood'].pert_categories == pert_category)[0]
1050 |     y_true = datasets['ood'].genes[idx, :].numpy()
1051 |     # true means and variances
1052 |     yt_m = y_true.mean(axis=0)
1053 |     yt_v = y_true.var(axis=0)
1054 | 
1055 |     icond = 0
1056 |     if len(idx) > 0:
1057 |         for pert_category_predict in pert_category_list:
1058 |             if '+' in pert_category_predict:
1059 |                 pert1, pert2 = pert_category_predict.split('+')
1060 |                 idx_pred1 = np.where(datasets['training'].pert_categories ==\
1061 |                     pert1)[0]
1062 |                 idx_pred2 = np.where(datasets['training'].pert_categories ==\
1063 |                     pert2)[0]
1064 | 
1065 |                 y_pred1 = datasets['training'].genes[idx_pred1, :].numpy()
1066 |                 y_pred2 = datasets['training'].genes[idx_pred2, :].numpy()
1067 | 
1068 |                 x1 = float(pert1.split('_')[2])
1069 |                 x2 = float(pert2.split('_')[2])
1070 |                 x = float(pert_category.split('_')[2])
1071 |                 yp_m1 = y_pred1.mean(axis=0)
1072 |                 yp_m2 = y_pred2.mean(axis=0)
1073 |                 yp_v1 = y_pred1.var(axis=0)
1074 |                 yp_v2 = y_pred2.var(axis=0)
1075 | 
1076 |                 yp_m = linear_interp(yp_m1, yp_m2, x1, x2, x)
1077 |                 yp_v = linear_interp(yp_v1, yp_v2, x1, x2, x)
1078 | 
1079 |                 # yp_m = (y_pred1.mean(axis=0) + y_pred2.mean(axis=0))/2
1080 |                 # yp_v = (y_pred1.var(axis=0) + y_pred2.var(axis=0))/2
1081 | 
1082 |             else:
1083 |                 idx_pred = np.where(datasets['training'].pert_categories ==\
1084 |                     pert_category_predict)[0]
1085 |                 print(pert_category_predict, len(idx_pred))
1086 |                 y_pred = datasets['training'].genes[idx_pred, :].numpy()
1087 |                 # predicted means and variances
1088 |                 yp_m = y_pred.mean(axis=0)
1089 |                 yp_v = y_pred.var(axis=0)
1090 | 
1091 |             mean_score = r2_score(yt_m, yp_m)
1092 |             var_score = r2_score(yt_v, yp_v)
1093 | 
1094 |             mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
1095 |             var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])
1096 |             scores.loc[icond] = pert_category.split('_') +\
1097 |                 [mean_score, mean_score_de, var_score, var_score_de,\
1098 |                 len(idx), pert_category_predict, 'benchmark']
1099 |             icond += 1
1100 | 
1101 |     return scores
1102 | 
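# End-to-end sketch (illustrative only; the file path and the trained model
# are hypothetical):
#
#     from MultiCPA.data import load_dataset_splits
#     from MultiCPA.api import ComPertAPI
#
#     datasets = load_dataset_splits('data.h5ad', 'condition', 'dose',
#                                    'cell_type', 'split')
#     api = ComPertAPI(datasets, trained_model)
#     scores = api.evaluate_r2(datasets['ood'],
#                              datasets['test_control'].genes)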

--------------------------------------------------------------------------------
/MultiCPA/data.py:
--------------------------------------------------------------------------------
1 | # Author: Kemal Inecik
2 | # Email: k.inecik@gmail.com
3 | 
4 | import warnings
5 | import torch
6 | 
7 | import numpy as np
8 | 
9 | warnings.simplefilter(action='ignore', category=FutureWarning)
10 | import scanpy as sc
11 | import pandas as pd
12 | 
13 | from sklearn.preprocessing import OneHotEncoder
14 | 
15 | def ranks_to_df(data, key='rank_genes_groups'):
16 |     """Converts an `sc.tl.rank_genes_groups` result into a MultiIndex dataframe.
17 | 
18 |     You can access various levels of the MultiIndex with `df.loc[[category]]`.
19 | 
20 |     Params
21 |     ------
22 |     data : `AnnData`
23 |     key : str (default: 'rank_genes_groups')
24 |         Field in `.uns` of data where `sc.tl.rank_genes_groups` result is
25 |         stored.
26 |     """
27 |     d = data.uns[key]
28 |     dfs = []
29 |     for k in d.keys():
30 |         if k == 'params':
31 |             continue
32 |         series = pd.DataFrame.from_records(d[k]).unstack()
33 |         series.name = k
34 |         dfs.append(series)
35 | 
36 |     return pd.concat(dfs, axis=1)
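# Example (hypothetical AnnData `adata` and group name):
#
#     sc.tl.rank_genes_groups(adata, groupby='condition')
#     df = ranks_to_df(adata)
#     df.loc[['treatedA']]   # names/scores/pvals for one group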
37 | 
38 | 
39 | class Dataset:
40 |     def __init__(self,
41 |                  fname,
42 |                  perturbation_key,
43 |                  dose_key,
44 |                  cell_type_key,
45 |                  split_key='split',
46 |                  counts_key=None,
47 |                  proteins_key=None,
48 |                  raw_proteins_key=None):
49 | 
50 |         data = sc.read(fname)
51 | 
52 |         self.perturbation_key = perturbation_key
53 |         self.dose_key = dose_key
54 |         self.cell_type_key = cell_type_key
55 |         self.genes = torch.Tensor(data.X.A)
56 |         self.counts_key = counts_key
57 |         self.proteins_key = proteins_key
58 |         self.raw_proteins_key = raw_proteins_key
59 | 
60 |         if counts_key is not None:
61 |             self.raw_genes = torch.Tensor(data.layers[counts_key].A)
62 | 
63 |         if proteins_key is not None:
64 |             self.proteins = torch.Tensor(data.obsm[proteins_key])
65 |             if raw_proteins_key is not None:
66 |                 self.raw_proteins = torch.Tensor(data.obsm[raw_proteins_key])
67 |         else:
68 |             self.proteins = None
69 | 
70 |         self.var_names = data.var_names
71 | 
72 |         self.pert_categories = np.array(data.obs['cov_drug_dose_name'].values)
73 | 
74 |         self.de_genes = data.uns['rank_genes_groups_cov']
75 |         self.ctrl = data.obs['control'].values
76 |         self.ctrl_name = list(np.unique(data[data.obs['control'] == 1].obs[self.perturbation_key]))
77 | 
78 |         self.drugs_names = np.array(data.obs[perturbation_key].values)
79 |         self.dose_names = np.array(data.obs[dose_key].values)
80 | 
81 |         # get unique drugs
82 |         drugs_names_unique = set()
83 |         for d in self.drugs_names:
84 |             [drugs_names_unique.add(i) for i in d.split("+")]
85 |         self.drugs_names_unique = np.array(list(drugs_names_unique))
86 | 
87 |         # save encoder for a comparison with Mo's model
88 |         # later we need to remove this part
89 |         encoder_drug = OneHotEncoder(sparse=False)
90 |         encoder_drug.fit(self.drugs_names_unique.reshape(-1, 1))
91 | 
92 |         self.atomic_drugs_dict = dict(zip(self.drugs_names_unique, encoder_drug.transform(
93 |             self.drugs_names_unique.reshape(-1, 1))))
94 | 
95 |         # get drug combinations
96 |         drugs = []
97 |         for i, comb in enumerate(self.drugs_names):
98 |             drugs_combos = encoder_drug.transform(
99 |                 np.array(comb.split("+")).reshape(-1, 1))
100 |             dose_combos = str(data.obs[dose_key].values[i]).split("+")
101 |             for j, d in enumerate(dose_combos):
102 |                 if j == 0:
103 |                     drug_ohe = float(d) * drugs_combos[j]
104 |                 else:
105 |                     drug_ohe += float(d) * drugs_combos[j]
106 |             drugs.append(drug_ohe)
107 |         self.drugs = torch.Tensor(drugs)
108 | 
109 |         self.cell_types_names = np.array(data.obs[cell_type_key].values)
110 |         self.cell_types_names_unique = np.unique(self.cell_types_names)
111 | 
112 |         encoder_ct = OneHotEncoder(sparse=False)
113 |         encoder_ct.fit(self.cell_types_names_unique.reshape(-1, 1))
114 | 
115 |         self.atomic_сovars_dict = dict(zip(list(self.cell_types_names_unique), encoder_ct.transform(
116 |             self.cell_types_names_unique.reshape(-1, 1))))
117 | 
118 |         self.cell_types = torch.Tensor(encoder_ct.transform(
119 |             self.cell_types_names.reshape(-1, 1))).float()
120 | 
121 |         self.num_cell_types = len(self.cell_types_names_unique)
122 |         self.num_genes = self.genes.shape[1]
123 |         self.num_drugs = len(self.drugs_names_unique)
124 | 
125 |         if self.proteins is not None:
126 |             self.num_proteins = self.proteins.shape[1]
127 |         else:
128 |             self.num_proteins = 0
129 | 
130 |         self.indices = {
131 |             "all": list(range(len(self.genes))),
132 |             "control": np.where(data.obs['control'] == 1)[0].tolist(),
133 |             "treated": np.where(data.obs['control'] != 1)[0].tolist(),
134 |             "train": np.where(data.obs[split_key] == 'train')[0].tolist(),
135 |             "test": np.where(data.obs[split_key] == 'test')[0].tolist(),
136 |             "ood": np.where(data.obs[split_key] == 'ood')[0].tolist()
137 |         }
138 | 
139 |         atomic_ohe = encoder_drug.transform(
140 |             self.drugs_names_unique.reshape(-1, 1))
141 | 
142 |         self.drug_dict = {}
143 |         for idrug, drug in enumerate(self.drugs_names_unique):
144 |             i = np.where(atomic_ohe[idrug] == 1)[0][0]
145 |             self.drug_dict[i] = drug
146 | 
147 | 
148 | 
149 |     def subset(self, split, condition="all"):
150 |         idx = list(set(self.indices[split]) & set(self.indices[condition]))
151 |         return SubDataset(self, idx)
152 | 
153 |     def __getitem__(self, i):
154 | 
155 |         raw_genes_ = None
156 |         if self.counts_key is not None:
157 |             raw_genes_ = self.raw_genes[i]
158 | 
159 |         proteins_ = None
160 |         if self.proteins_key is not None:
161 |             proteins_ = self.proteins[i]
162 | 
163 |         raw_proteins_ = None
164 |         if self.raw_proteins_key is not None:
165 |             raw_proteins_ = self.raw_proteins[i]
166 | 
167 |         return self.genes[i], self.drugs[i], self.cell_types[i], proteins_, raw_genes_, raw_proteins_
168 | 
169 |     def __len__(self):
170 |         return len(self.genes)
171 | 
172 | 
173 | class SubDataset:
174 |     """
175 |     Subsets a `Dataset` by selecting the examples given by `indices`.
176 |     """
177 | 
178 |     def __init__(self, dataset, indices):
179 |         self.perturbation_key = dataset.perturbation_key
180 |         self.dose_key = dataset.dose_key
181 |         self.covars_key = dataset.cell_type_key
182 |         self.counts_key = dataset.counts_key
183 |         self.proteins_key = dataset.proteins_key
184 |         self.raw_proteins_key = dataset.raw_proteins_key
185 | 
186 |         self.perts_dict = dataset.atomic_drugs_dict
187 |         self.covars_dict = dataset.atomic_сovars_dict
188 | 
189 |         self.genes = dataset.genes[indices]
190 |         self.drugs = dataset.drugs[indices]
191 |         self.cell_types = dataset.cell_types[indices]
192 | 
193 |         if dataset.proteins_key is not None:
194 |             self.proteins = dataset.proteins[indices]
195 | 
196 |         if dataset.raw_proteins_key is not None:
197 |             self.raw_proteins = dataset.raw_proteins[indices]
198 | 
199 |         if dataset.counts_key is not None:
200 |             self.raw_genes = dataset.raw_genes[indices]
201 | 
202 |         self.drugs_names = dataset.drugs_names[indices]
203 |         self.pert_categories = dataset.pert_categories[indices]
204 |         self.cell_types_names = dataset.cell_types_names[indices]
205 | 
206 |         self.var_names = dataset.var_names
207 |         self.de_genes = dataset.de_genes
208 |         self.ctrl_name = dataset.ctrl_name[0]
209 | 
210 |         self.num_cell_types = dataset.num_cell_types
211 |         self.num_genes = dataset.num_genes
212 |         self.num_drugs = dataset.num_drugs
213 |         self.num_proteins = dataset.num_proteins
214 | 
215 |     def __getitem__(self, i):
216 | 
217 |         raw_genes_ = None
218 |         if self.counts_key is not None:
219 |             raw_genes_ = self.raw_genes[i]
220 | 
221 |         proteins_ = None
222 |         if self.proteins_key is not None:
223 |             proteins_ = self.proteins[i]
224 | 
225 |         raw_proteins_ = None
226 |         if self.raw_proteins_key is not None:
227 |             raw_proteins_ = self.raw_proteins[i]
228 | 
229 |         return self.genes[i], self.drugs[i], self.cell_types[i], proteins_, raw_genes_, raw_proteins_
230 | 
231 |     def __len__(self):
232 |         return len(self.genes)
233 | 
234 | 
235 | def load_dataset_splits(
236 |         dataset_path,
237 |         perturbation_key,
238 |         dose_key,
239 |         cell_type_key,
240 |         split_key,
241 |         counts_key=None,
242 |         proteins_key=None,
243 |         raw_proteins_key=None,
244 |         return_dataset=False):
245 | 
246 |     dataset = Dataset(dataset_path,
247 |                       perturbation_key,
248 |                       dose_key,
249 |                       cell_type_key,
250 |                       split_key,
251 |                       counts_key,
252 |                       proteins_key,
253 |                       raw_proteins_key)
254 | 
255 |     splits = {
256 |         "training": dataset.subset("train", "all"),
257 |         "training_control": dataset.subset("train", "control"),
258 |         "training_treated": dataset.subset("train", "treated"),
259 |         "test": dataset.subset("test", "all"),
260 |         "test_control": dataset.subset("test", "control"),
261 |         "test_treated": dataset.subset("test", "treated"),
262 |         "ood": dataset.subset("ood", "all")
263 |     }
264 | 
265 |     if return_dataset:
266 |         return splits, dataset
267 |     else:
268 |         return splits
269 | 
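# Example (hypothetical file name and obs keys): build the named splits used
# throughout the package.
#
#     splits = load_dataset_splits('dataset.h5ad', perturbation_key='condition',
#                                  dose_key='dose', cell_type_key='cell_type',
#                                  split_key='split')
#     splits['training'][0]  # (genes, drugs, cell_types, proteins, raw_genes, raw_proteins)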

--------------------------------------------------------------------------------
/MultiCPA/helper.py:
--------------------------------------------------------------------------------
1 | # Author: Kemal Inecik
2 | # Email: k.inecik@gmail.com
3 | 
4 | import scanpy as sc
5 | import pandas as pd
6 | 
7 | 
8 | def model_importer(_id, df, model_dir, dataset_relative_to, model_created='TotalComPert'):
9 |     import os
10 |     import torch
11 |     from MultiCPA.train import prepare_compert
12 | 
13 |     args = {
14 |         'dataset_path': "config.dataset.dataset_args.dataset_path",  # full path to the anndata dataset
15 |         'cell_type_key': 'config.dataset.dataset_args.cell_type_key',  # necessary field for cell types. Fill it with a dummy variable if no cell types are present.
16 |         'split_key': 'config.dataset.dataset_args.split_key',  # necessary field for train, test, ood splits.
17 |         'perturbation_key': 'config.dataset.dataset_args.perturbation_key',  # necessary field for perturbations
18 |         'dose_key': 'config.dataset.dataset_args.dose_key',  # necessary field for dose. Fill in with a dummy variable if dose is the same.
19 |         'checkpoint_freq': 'config.training.checkpoint_freq',  # checkpoint frequency to save intermediate results
20 |         'max_epochs': 'config.training.num_epochs',  # maximum epochs for training
21 |         'max_minutes': 'config.training.max_minutes',  # maximum computation time
22 |         'patience': 'config.model.model_args.patience',  # patience for early stopping
23 |         'loss_ae': 'config.model.model_args.loss_ae',  # loss (currently only gaussian loss is supported)
24 |         'doser_type': 'config.model.model_args.doser_type',  # non-linearity for doser function
25 |         'save_dir': 'config.training.save_dir',  # directory to save the model
26 |         'decoder_activation': 'config.model.model_args.decoder_activation',  # last layer of the decoder
27 |         'seed': 'config.seed',  # random seed
28 |         'raw_counts_key': 'config.dataset.dataset_args.counts_key',  # necessary field for nb loss. Name of the layer storing raw gene counts.
29 |         'is_vae': 'config.model.model_args.is_vae',  # using a vae or ae model
30 |         'protein_key': 'config.dataset.dataset_args.proteins_key',  # name of the field storing the protein data in adata.obsm[proteins_key]
31 |         'raw_protein_key': 'config.dataset.dataset_args.raw_proteins_key',  # necessary field for nb loss. Name of the field storing the raw protein data in adata.obsm[protein_expression_raw]
32 |     }  # 'hparams': "",  # autoencoder architecture
33 | 
34 |     model_experiment = df.loc[_id]
35 |     if model_experiment["status"] == 0:
36 |         raise NotImplementedError
37 |     exp_id = model_experiment['result.exp_id']
38 |     model_name = os.path.join(model_dir, f"{exp_id}_last.pt")
39 | 
40 |     for i in args:
41 |         args[i] = model_experiment[args[i]]
42 |     state, hypers, history = torch.load(model_name, map_location=torch.device('cpu'))
43 |     args['hparams'] = hypers
44 |     args['dataset_path'] = os.path.join(dataset_relative_to, args['dataset_path'])
45 |     autoencoder, datasets = prepare_compert(args, state_dict=state, model=model_created)
46 |     for p in autoencoder.parameters():  # reset requires_grad
47 |         p.requires_grad = False
48 |     autoencoder.eval()
49 |     return autoencoder, datasets, state, history, args
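# Example (hypothetical ids/paths; `results_df` is a seml results DataFrame
# indexed by experiment id):
#
#     autoencoder, datasets, state, history, args = model_importer(
#         _id=42, df=results_df, model_dir='checkpoints/',
#         dataset_relative_to='datasets/')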
50 | 
51 | 
52 | def rank_genes_groups_by_cov(
53 |         adata,
54 |         groupby,
55 |         control_group,
56 |         covariate,
57 |         pool_doses=False,
58 |         n_genes=50,
59 |         rankby_abs=True,
60 |         key_added='rank_genes_groups_cov',
61 |         return_dict=False,
62 | ):
63 | 
64 |     """
65 |     Function that generates a list of differentially expressed genes computed
66 |     separately for each covariate category, and using the respective control
67 |     cells as reference.
68 | 
69 |     Usage example:
70 | 
71 |     rank_genes_groups_by_cov(
72 |         adata,
73 |         groupby='cov_product_dose',
74 |         covariate='cell_type',
75 |         control_group='Vehicle_0'
76 |     )
77 | 
78 |     Parameters
79 |     ----------
80 |     adata : AnnData
81 |         AnnData dataset
82 |     groupby : str
83 |         Obs column that defines the groups, should be
84 |         cartesian product of covariate_perturbation_cont_var,
85 |         it is important that this format is followed.
86 |     control_group : str
87 |         String that defines the control group in the groupby obs
88 |     covariate : str
89 |         Obs column that defines the main covariate by which we
90 |         want to separate DEG computation (eg. cell type, species, etc.)
91 |     n_genes : int (default: 50)
92 |         Number of DEGs to include in the lists
93 |     rankby_abs : bool (default: True)
94 |         If True, rank genes by absolute values of the score, thus including
95 |         top downregulated genes in the top N genes. If False, the ranking will
96 |         have only upregulated genes at the top.
97 |     key_added : str (default: 'rank_genes_groups_cov')
98 |         Key used when adding the dictionary to adata.uns
99 |     return_dict : bool (default: False)
100 |         Signals whether to return the dictionary or not
101 | 
102 |     Returns
103 |     -------
104 |     Adds the DEG dictionary to adata.uns
105 | 
106 |     If return_dict is True returns:
107 |     gene_dict : dict
108 |         Dictionary where groups are stored as keys, and the list of DEGs
109 |         are the corresponding values
110 | 
111 |     """
112 | 
113 |     gene_dict = {}
114 |     cov_categories = adata.obs[covariate].unique()
115 |     for cov_cat in cov_categories:
116 |         print(cov_cat)
117 |         # name of the control group in the groupby obs column
118 |         control_group_cov = '_'.join([cov_cat, control_group])
119 | 
120 |         # subset adata to cells belonging to a covariate category
121 |         adata_cov = adata[adata.obs[covariate]==cov_cat]
122 | 
123 |         # compute DEGs
124 |         sc.tl.rank_genes_groups(
125 |             adata_cov,
126 |             groupby=groupby,
127 |             reference=control_group_cov,
128 |             rankby_abs=rankby_abs,
129 |             n_genes=n_genes
130 |         )
131 | 
132 |         # add entries to dictionary of gene sets
133 |         de_genes = pd.DataFrame(adata_cov.uns['rank_genes_groups']['names'])
134 |         for group in de_genes:
135 |             gene_dict[group] = de_genes[group].tolist()
136 | 
137 |     adata.uns[key_added] = gene_dict
138 | 
139 |     if return_dict:
140 |         return gene_dict
141 | 
142 | 
143 | 
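# Example (hypothetical control group label): populate
# adata.uns['rank_genes_groups_cov'], which Dataset in MultiCPA/data.py
# expects to find.
#
#     rank_genes_groups_by_cov(adata, groupby='cov_drug_dose_name',
#                              covariate='cell_type',
#                              control_group='control_1.0')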

--------------------------------------------------------------------------------
/MultiCPA/plotting.py:
--------------------------------------------------------------------------------
1 | # Author: Kemal Inecik
2 | # Email: k.inecik@gmail.com
3 | 
4 | import numpy as np
5 | import sys
6 | import pprint
7 | 
8 | import torch
9 | import scanpy as sc
10 | 
11 | from collections import defaultdict
12 | from sklearn.metrics import r2_score
13 | from sklearn.metrics.pairwise import cosine_similarity
14 | from sklearn.decomposition import KernelPCA
15 | import seaborn as sns
16 | import pandas as pd
17 | import matplotlib.pyplot as plt
18 | import re
19 | 
20 | from adjustText import adjust_text
21 | import matplotlib.font_manager
22 | from MultiCPA.api import ComPertAPI, get_reference_from_combo
23 | 
24 | 
25 | FONT_SIZE = 13
26 | font = {'size': FONT_SIZE}
27 | 
28 | matplotlib.rc('font', **font)
29 | matplotlib.rc('ytick', labelsize=FONT_SIZE)
30 | matplotlib.rc('xtick', labelsize=FONT_SIZE)
31 | 
32 | 
33 | class CompertVisuals:
34 |     """
35 |     A wrapper for automatically plotting ComPert latent embeddings and
36 |     dose-response curves. Sets up a prefix for all files and default
37 |     dictionaries for atomic perturbations and cell types.
38 |     """
39 |     def __init__(self,
40 |                  compert,
41 |                  fileprefix=None,
42 |                  perts_palette=None,
43 |                  сovars_palette=None,
44 |                  plot_params={'fontsize': None}
45 |                  ):
46 |         """
47 |         Parameters
48 |         ----------
49 |         compert : ComPertAPI
50 |             Fitted ComPertAPI instance.
51 |         fileprefix : str, optional (default: None)
52 |             Prefix (with path) to the filename to save all embeddings in a
53 |             standardized manner. If None, embeddings are not saved to file.
54 |         perts_palette : dict (default: None)
55 |             Dictionary of colors for the embeddings of perturbations. Keys
56 |             correspond to perturbations and values to their colors. If None, a
57 |             default dictionary will be set up.
58 |         сovars_palette : dict (default: None)
59 |             Dictionary of colors for the embeddings of covariates. Keys
60 |             correspond to covariates and values to their colors. If None, a
61 |             default dictionary will be set up.
62 |         """
63 | 
64 |         self.fileprefix = fileprefix
65 | 
66 |         self.perturbation_key = compert.perturbation_key
67 |         self.dose_key = compert.dose_key
68 |         self.covars_key = compert.covars_key
69 |         self.measured_points = compert.measured_points
70 | 
71 |         self.unique_perts = compert.unique_perts
72 |         self.unique_сovars = compert.unique_сovars
73 | 
74 |         if perts_palette is None:
75 |             self.perts_palette = dict(zip(self.unique_perts,
76 |                                           get_palette(len(self.unique_perts))))
77 |         else:
78 |             self.perts_palette = perts_palette
79 | 
80 |         if сovars_palette is None:
81 |             self.сovars_palette = dict(zip(self.unique_сovars,
82 |                                            get_palette(len(self.unique_сovars), palette_name='tab10')))
83 |         else:
84 |             self.сovars_palette = сovars_palette
85 | 
86 |         if plot_params['fontsize'] is None:
87 |             self.fontsize = FONT_SIZE
88 |         else:
89 |             self.fontsize = plot_params['fontsize']
90 | 
91 |     def plot_latent_embeddings(self,
92 |                                emb,
93 |                                titlename='Example',
94 |                                kind='perturbations',
95 |                                palette=None,
96 |                                labels=None,
97 |                                dimred='KernelPCA',
98 |                                filename=None,
99 |                                show_text=True
100 |                                ):
101 |         """
102 |         Parameters
103 |         ----------
104 |         emb : np.array
105 |             Multi-dimensional embedding of perturbations or covariates.
106 |         titlename : str, optional (default: 'Example')
107 |             Title.
108 |         kind : str, optional (default: 'perturbations')
109 |             Specify if this is an embedding of perturbations, covariates or some
110 |             other kind. If it is perturbations or covariates, it will use the
111 |             default saved dictionaries for colors.
112 |         palette : dict, optional (default: None)
113 |             If the embedding is neither perturbations nor covariates, the user can
114 |             specify a color dictionary for the embedding.
115 |         labels : list, optional (default: None)
116 |             Labels for the embeddings.
117 |         dimred : str, optional (default: 'KernelPCA')
118 |             Dimensionality reduction method for plotting low dimensional
119 |             representations. Currently 'KernelPCA' and None are supported;
120 |             if None, uses the first 2 dimensions of the embedding.
121 |         filename : str (default: None)
122 |             Name of the file to save the plot. If None, will automatically
123 |             generate a name from the file prefix.
124 |         """
125 |         if filename is None:
126 |             if self.fileprefix is None:
127 |                 filename = None
128 |                 file_name_similarity = None
129 |             else:
130 |                 filename = f'{self.fileprefix}_embedding.png'
131 |                 file_name_similarity = f'{self.fileprefix}_embedding_similarity.png'
132 |         else:
133 |             file_name_similarity = filename.split('.')[0] + '_similarity.png'
134 | 
135 |         if labels is None:
136 |             if kind == 'perturbations':
137 |                 palette = self.perts_palette
138 |                 labels = self.unique_perts
139 |             elif kind == 'covars':
140 |                 palette = self.сovars_palette
141 |                 labels = self.unique_сovars
142 | 
143 |         if len(emb) < 2:
144 |             print(f'Embedding contains only {len(emb)} vectors. Not enough to plot.')
145 |         else:
146 |             plot_embedding(
147 |                 fast_dimred(emb, method=dimred),
148 |                 labels,
149 |                 show_lines=True,
150 |                 show_text=show_text,
151 |                 col_dict=palette,
152 |                 title=titlename,
153 |                 file_name=filename,
154 |                 fontsize=self.fontsize
155 |             )
156 | 
157 |             plot_similarity(
158 |                 emb,
159 |                 labels,
160 |                 col_dict=palette,
161 |                 fontsize=self.fontsize,
162 |                 file_name=file_name_similarity
163 |             )
164 | 
165 |     def plot_contvar_response2D(self,
166 |                                 df_response2D,
167 |                                 df_ref=None,
168 |                                 levels=15,
169 |                                 figsize=(4, 4),
170 |                                 xlims=(0, 1.03),
171 |                                 ylims=(0, 1.03),
172 |                                 palette="coolwarm",
173 |                                 response_name='response',
174 |                                 title_name=None,
175 |                                 fontsize=None,
176 |                                 postfix='',
177 |                                 filename=None,
178 |                                 alpha=0.4,
179 |                                 sizes=(40, 160),
180 |                                 logdose=False,
181 |                                 file_format='png'):
182 | 
183 |         """
184 |         Parameters
185 |         ----------
186 |         df_response2D : pd.DataFrame
187 |             Data frame with responses of combinations with columns=(dose1, dose2,
188 |             response).
189 |         levels: int, optional (default: 15)
190 |             Number of levels for the contour plot.
191 |         response_name : str (default: 'response')
192 |             Name of the column in df_response2D to plot as response.
193 |         alpha: float (default: 0.4)
194 |             Transparency of the background contour.
195 |         figsize: tuple (default: (4, 4))
196 |             Size of the figure in inches.
197 |         palette : str or Colormap (default: 'coolwarm')
198 |             Colormap for the contour plots.
199 |         title_name : str, optional (default: None)
200 |             Title for the plot.
201 |         postfix : str, optional (default: '')
202 |             Postfix to add to the output file name used to save the plot.
203 |         filename : str, optional (default: None)
204 |             Name of the file to save the plot. If None, will automatically
205 |             generate a name from the file prefix.
206 |         logdose: bool (default: False)
207 |             If True, dose values will be log10-transformed. Zero doses are mapped
208 |             to one unit below the log10 of the smallest non-zero dose, e.g.
209 |             if the smallest non-zero dose was 0.001, 0 will be mapped to -4.
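        A minimal usage sketch (hypothetical names; assumes `cv` is a
        CompertVisuals instance and `df2d` was computed upstream with the
        columns described above):

            cv.plot_contvar_response2D(df2d, levels=20, logdose=True)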
210 |         """
211 |         sns.set_style("white")
212 | 
213 |         if (filename is None) and not (self.fileprefix is None):
214 |             filename = f'{self.fileprefix}_{postfix}response2D.png'
215 |         if fontsize is None:
216 |             fontsize = self.fontsize
217 | 
218 |         x_name, y_name = df_response2D.columns[:2]
219 | 
220 |         x = df_response2D[x_name].values
221 |         y = df_response2D[y_name].values
222 | 
223 |         if logdose:
224 |             x = log10_with0(x)
225 |             y = log10_with0(y)
226 | 
227 |         z = df_response2D[response_name].values
228 | 
229 |         n = int(np.sqrt(len(x)))
230 | 
231 |         X = x.reshape(n, n)
232 |         Y = y.reshape(n, n)
233 |         Z = z.reshape(n, n)
234 | 
235 |         fig, ax = plt.subplots(figsize=figsize)
236 | 
237 |         CS = ax.contourf(X, Y, Z, cmap=palette, levels=levels, alpha=alpha)
238 |         CS = ax.contour(X, Y, Z, levels=levels, cmap=palette)  # use the levels argument instead of a hard-coded value
239 |         ax.clabel(CS, inline=1, fontsize=fontsize)
240 |         ax.set(xlim=(0, 1), ylim=(0, 1))
241 |         ax.axis("equal")
242 |         ax.axis("square")
243 |         ax.yaxis.set_tick_params(labelsize=fontsize)
244 |         ax.xaxis.set_tick_params(labelsize=fontsize)
245 |         ax.set_xlabel(x_name, fontsize=fontsize, fontweight="bold")
246 |         ax.set_ylabel(y_name, fontsize=fontsize, fontweight="bold")
247 |         ax.set_xlim(xlims)
248 |         ax.set_ylim(ylims)
249 | 
250 |         # sns.despine(left=False, bottom=False, right=True)
251 |         sns.despine()
252 | 
253 |         if not (df_ref is None):
254 |             sns.scatterplot(
255 |                 x=x_name,
256 |                 y=y_name,
257 |                 hue='split',
258 |                 size='num_cells',
259 |                 sizes=sizes,
260 |                 alpha=1.,
261 |                 palette={'train': '#000000', 'training': '#000000', 'ood': '#e41a1c'},
262 |                 data=df_ref, ax=ax)
263 |             ax.legend_.remove()
264 | 
265 |         ax.set_title(title_name, fontweight="bold", fontsize=fontsize)
266 |         plt.tight_layout()
267 | 
268 |         if filename:
269 |             save_to_file(fig, filename)
270 | 
271 | 
272 |     def plot_contvar_response(self,
273 |                               df_response,
274 |                               response_name='response',
275 |                               var_name=None,
276 |                               df_ref=None,
277 |                               palette=None,
278 |                               title_name=None,
279 |                               postfix='',
280 |                               xlabelname=None,
281 |                               filename=None,
282 |                               logdose=False,
283 |                               fontsize=None,
284 |                               measured_points=None,
285 |                               bbox=(1.35, 1.),
286 |                               figsize=(7., 4.)
287 |                               ):
288 |         """
289 |         Parameters
290 |         ----------
291 |         df_response : pd.DataFrame
292 |             Data frame of responses.
293 |         response_name : str (default: 'response')
294 |             Name of the column in df_response to plot as response.
295 |         var_name : str, optional (default: None)
296 |             Name of conditioning variable, e.g. could correspond to covariates.
297 |         df_ref : pd.DataFrame, optional (default: None)
298 |             Reference values. Fields for plotting should correspond to
299 |             df_response.
300 |         palette : dict, optional (default: None)
301 |             Colors dictionary for perturbations to plot.
302 |         title_name : str, optional (default: None)
303 |             Title for the plot.
304 |         postfix : str, optional (default: '')
305 |             Postfix to add to the output file name used to save the plot.
306 |         filename : str, optional (default: None)
307 |             Name of the file to save the plot. If None, will automatically
308 |             generate a name from the file prefix.
309 |         logdose: bool (default: False)
310 |             If True, dose values will be log10-transformed. Zero doses are mapped
311 |             to one unit below the log10 of the smallest non-zero dose, e.g.
312 |             if the smallest non-zero dose was 0.001, 0 will be mapped to -4.
313 |         figsize: tuple (default: (7., 4.))
314 |             Size of the output figure.
315 |         """
316 |         if (filename is None) and not (self.fileprefix is None):
317 |             filename = f'{self.fileprefix}_{postfix}response.png'
318 | 
319 |         if fontsize is None:
320 |             fontsize = self.fontsize
321 | 
322 |         if logdose:
323 |             dose_name = f'log10-{self.dose_key}'
324 |             df_response[dose_name] = log10_with0(df_response[self.dose_key].values)
325 |             if not (df_ref is None):
326 |                 df_ref[dose_name] = log10_with0(df_ref[self.dose_key].values)
327 |         else:
328 |             dose_name = self.dose_key
329 | 
330 |         if var_name is None:
331 |             if len(self.unique_сovars) > 1:
332 |                 var_name = self.covars_key
333 |             else:
334 |                 var_name = self.perturbation_key
335 | 
336 |         if palette is None:
337 |             if var_name == self.perturbation_key:
338 |                 palette = self.perts_palette
339 |             elif var_name == self.covars_key:
340 |                 palette = self.сovars_palette
341 | 
342 | 
343 |         plot_dose_response(df_response,
344 |             dose_name,
345 |             var_name,
346 |             xlabelname=xlabelname,
347 |             df_ref=df_ref,
348 |             response_name=response_name,
349 |             title_name=title_name,
350 |             use_ref_response=(not (df_ref is None)),
351 |             col_dict=palette,
352 |             plot_vertical=False,
353 |             f1=figsize[0],
354 |             f2=figsize[1],
355 |             fname=filename,
356 |             logscale=measured_points,
357 |             measured_points=measured_points,
358 |             bbox=bbox,
359 |             fontsize=fontsize,
360 |             format='png')
361 | 
362 |     def plot_scatter(
363 |         self,
364 |         df,
365 |         x_axis,
366 |         y_axis,
367 |         hue=None,
368 |         size=None,
369 |         style=None,
370 |         figsize=(4.5, 4.5),
371 |         title=None,
372 |         palette=None,
373 |         filename=None,
374 |         alpha=.75,
375 |         sizes=(30, 90),
376 |         text_dict=None,
377 |         postfix='',
378 |         fontsize=None):  # falls back to self.fontsize below
379 | 
380 |         sns.set_style("white")
381 | 
382 |         if (filename is None) and not (self.fileprefix is None):
383 |             filename = f'{self.fileprefix}_scatter{postfix}.png'
384 | 
385 |         if fontsize is None:
386 |             fontsize = self.fontsize
387 | 
388 |         fig = plt.figure(figsize=figsize)
389 |         ax = plt.gca()
390 |         sns.scatterplot(
391 |             x=x_axis,
392 |             y=y_axis,
393 |             hue=hue,
394 |             style=style,
395 |             size=size,
396 |             sizes=sizes,
397 |             alpha=alpha,
398 |             palette=palette,
399 |             data=df)
400 | 
401 |         ax.legend_.remove()
402 |         ax.set_xlabel(x_axis, fontsize=fontsize)
403 |         ax.set_ylabel(y_axis, fontsize=fontsize)
404 |         ax.xaxis.set_tick_params(labelsize=fontsize)
405 |         ax.yaxis.set_tick_params(labelsize=fontsize)
406 |         ax.set_title(title)
407 |         if not (text_dict is None):
408 |             texts = []
409 |             for label in text_dict.keys():
410 |                 texts.append(
411 |                     ax.text(
412 |                         text_dict[label][0],
413 |                         text_dict[label][1],
414 |                         label,
415 |                         fontsize=fontsize
416 |                     )
417 |                 )
418 | 
419 |             adjust_text(
420 |                 texts,
421 |                 arrowprops=dict(arrowstyle='-', color='black', lw=0.1),
422 |                 ax=ax
423 |             )
424 | 
425 |         plt.tight_layout()
426 | 
427 |         if filename:
428 |             save_to_file(fig, filename)
429 | 
430 | 
431 | def log10_with0(x):
432 |     x = np.array(x, dtype=float)  # copy, so the caller's array is not modified in place
433 |     mx = np.min(x[x > 0])
434 |     x[x == 0] = mx / 10
435 |     return np.log10(x)
436 | 
437 | def get_palette(
438 |     n_colors,
439 |     palette_name='Set1'
440 | ):
441 | 
442 |     try:
443 |         palette = sns.color_palette(palette_name)
444 |     except ValueError:
445 |         print('Palette not found. Using default palette tab10')
446 |         palette = sns.color_palette()
447 |     while len(palette) < n_colors:
448 |         palette += palette
449 | 
450 |     return palette
451 | 
452 | 
453 | def fast_dimred(emb, method='KernelPCA'):
454 |     """
455 |     Takes high-dimensional embeddings and produces a 2-dimensional representation
456 |     for plotting.
456 |     emb: np.array
457 |         Embeddings matrix.
458 |     method: str (default: 'KernelPCA')
459 |         Method for dimensionality reduction; currently only 'KernelPCA' is implemented.
460 |         If None, return the first 2 dimensions of the embedding vector.
461 |     """
462 |     if method is None:
463 |         return emb[:, :2]
464 |     elif method == 'KernelPCA':
465 |         similarity_matrix = cosine_similarity(emb)
466 |         np.fill_diagonal(similarity_matrix, 1.0)
467 |         X = KernelPCA(n_components=2, kernel="precomputed")\
468 |             .fit_transform(similarity_matrix)
469 |     else:
470 |         raise NotImplementedError
471 | 
472 |     return X
473 | 
474 | 
475 | def plot_dose_response(df,
476 |                        contvar_key,
477 |                        perturbation_key,
478 |                        df_ref=None,
479 |                        response_name='response',
480 |                        use_ref_response=False,
481 |                        palette=None,
482 |                        col_dict=None,
483 |                        fontsize=8,
484 |                        measured_points=None,
485 |                        interpolate=True,
486 |                        f1=7,
487 |                        f2=3.,
488 |                        bbox=(1.35, 1.),
489 |                        ref_name='origin',
490 |                        title_name='None',
491 |                        plot_vertical=True,
492 |                        fname=None,
493 |                        logscale=None,
494 |                        xlabelname=None,
495 |                        format='png'):
496 | 
497 |     """Plotting decoding of the response with respect to dose.
498 | 
499 |     Params
500 |     ------
501 |     df : `DataFrame`
502 |         Table with columns=[perturbation_key, contvar_key, response_name].
503 |         The last column is always "response".
504 |     contvar_key : str
505 |         Name of the column in df for values to use for the x axis.
506 |     perturbation_key : str
507 |         Name of the column in df for the perturbation or covariate to plot.
508 |     response_name: str (default: 'response')
509 |         Name of the column in df for values to use for the y axis.
510 |     df_ref : `DataFrame` (default: None)
511 |         Table with the same columns as in df to plot ground truth or another
512 |         condition for comparison. Could
513 |         also be used to just extract reference values for the x-axis.
514 |     use_ref_response : bool (default: False)
515 |         A flag indicating whether to use values for the y axis from df_ref (True)
516 |         or just to extract reference values for the x-axis (False).
517 |     col_dict : dictionary (default: None)
518 |         Dictionary with colors for each value in perturbation_key.
519 |     bbox : tuple (default: (1.35, 1.))
520 |         Coordinates to adjust the legend.
521 |     plot_vertical : boolean (default: True)
522 |         Whether to plot vertical reference lines at the x values from the df_ref dataframe.
523 |     f1 : float (default: 7.0)
524 |         Width in inches for the plot.
525 |     f2 : float (default: 3.0)
526 |         Height in inches for the plot.
527 |     fname : str (default: None)
528 |         Name of the file to export the plot. The name comes without format
529 |         extension.
530 |     format : str (default: png)
531 |         Format for the file to export the plot.
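    A minimal usage sketch (hypothetical names; assumes `df_resp` holds model
    predictions and `df_obs` the matching observations, both with the columns
    described above):

        plot_dose_response(df_resp, 'dose', 'condition',
                           df_ref=df_obs, use_ref_response=True)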
532 |     """
533 |     sns.set_style("white")
534 |     if use_ref_response and not (df_ref is None):
535 |         df[ref_name] = 'predictions'
536 |         df_ref[ref_name] = 'observations'
537 |         if interpolate:
538 |             df_plt = pd.concat([df, df_ref])
539 |         else:
540 |             df_plt = df
541 |     else:
542 |         df_plt = df
543 | 
544 |     atomic_drugs = np.unique(df[perturbation_key].values)
545 | 
546 |     if palette is None:  # keep the name defined so col_dict can be built from a user-supplied palette too
547 |         palette = get_palette(len(list(atomic_drugs)))
548 | 
549 |     if col_dict is None:
550 |         col_dict = dict(
551 |             zip(
552 |                 list(atomic_drugs),
553 |                 palette
554 |             )
555 |         )
556 | 
557 |     fig = plt.figure(figsize=(f1, f2))
558 |     ax = plt.gca()
559 | 
560 |     if use_ref_response:
561 |         sns.lineplot(
562 |             x=contvar_key,
563 |             y=response_name,
564 |             palette=col_dict,
565 |             hue=perturbation_key,
566 |             style=ref_name,
567 |             dashes=[(1, 0), (2, 1)],
568 |             legend='full',
569 |             style_order=['predictions', 'observations'],
570 |             data=df_plt, ax=ax)
571 | 
572 |         df_ref = df_ref.replace('training_treated', 'train')
573 |         sns.scatterplot(
574 |             x=contvar_key,
575 |             y=response_name,
576 |             hue='split',
577 |             size='num_cells',
578 |             sizes=(10, 100),
579 |             alpha=1.,
580 |             palette={'train': '#000000', 'training': '#000000', 'ood': '#e41a1c'},
581 |             data=df_ref, ax=ax)
582 | 
583 |         ax.legend_.remove()
584 |     else:
585 |         sns.lineplot(x=contvar_key, y=response_name,
586 |                      palette=col_dict,
587 |                      hue=perturbation_key,
588 |                      data=df_plt, ax=ax)
589 |         ax.legend(
590 |             loc='upper right',
591 |             bbox_to_anchor=bbox,
592 |             fontsize=fontsize)
593 | 
594 |     if not (title_name is None):
595 |         ax.set_title(title_name, fontsize=fontsize, fontweight='bold')
596 |     ax.grid(False)
597 | 
598 |     if xlabelname is None:
599 |         ax.set_xlabel(contvar_key, fontsize=fontsize)
600 |     else:
601 |         ax.set_xlabel(xlabelname, fontsize=fontsize)
602 | 
603 |     ax.set_ylabel(f"{response_name}", fontsize=fontsize)
604 | 
605 |     ax.xaxis.set_tick_params(labelsize=fontsize)
606 |     ax.yaxis.set_tick_params(labelsize=fontsize)
607 | 
608 |     if not (logscale is None):
609 |         ax.set_xticks(np.log10(logscale))
610 |         ax.set_xticklabels(logscale, rotation=90)
611 | 
612 |     if not (df_ref is None):
613 |         atomic_drugs = np.unique(df_ref[perturbation_key].values)
614 |         for drug in atomic_drugs:
615 |             x = df_ref[df_ref[perturbation_key] == drug][contvar_key].values
616 |             m1 = np.min(df[df[perturbation_key] == drug][response_name].values)
617 |             m2 = np.max(df[df[perturbation_key] == drug][response_name].values)
618 | 
619 |             if plot_vertical:
620 |                 for x_dot in x:
621 |                     ax.plot([x_dot, x_dot], [m1, m2], ':', color='black',
622 |                             linewidth=.5, alpha=0.5)
623 | 
624 |     fig.tight_layout()
625 |     if fname:
626 |         plt.savefig(f'{fname}.{format}', format=format)
627 | 
628 |     return fig
629 | 
630 | def plot_uncertainty_comb_dose(
631 |     compert_api,
632 |     cov,
633 |     pert,
634 |     N=11,
635 |     metric='cosine',
636 |     measured_points=None,
637 |     cond_key='condition',
638 |     vmin=None,
639 |     vmax=None,
640 |     sizes=(40, 160),
641 |     df_ref=None,
642 |     xlims=(0, 1.03),
643 |     ylims=(0, 1.03),
644 |     fixed_drugs='',
645 |     fixed_doses='',
646 |     title=True,
647 |     filename=None
648 | ):
649 |     """Plotting uncertainty for a single perturbation at a dose range for a
650 |     particular covariate.
651 | 
652 |     Params
653 |     ------
654 |     compert_api
655 |         API object for the model class.
656 |     cov : str
657 |         Name of covariate.
658 |     pert : str
659 |         Name of the perturbation.
660 |     N : int
661 |         Number of dose values.
662 |     metric: str (default: 'cosine')
663 |         Metric to evaluate uncertainty.
664 | measured_points : dict (default: None) 665 | A dicitionary of dictionaries. Per each covariate a dictionary with 666 | observed doses per perturbation, e.g. {'covar1': {'pert1': 667 | [0.1, 0.5, 1.0], 'pert2': [0.3]} 668 | cond_key : str (default: 'condition') 669 | Name of the variable to use for plotting. 670 | filename : str (default: None) 671 | Full path to the file to export the plot. File extension should be 672 | included. 673 | 674 | Returns 675 | ------- 676 | pd.DataFrame of uncertainty estimations. 677 | """ 678 | 679 | df_list = [] 680 | for i in np.round(np.linspace(0, 1, N), decimals=2): 681 | for j in np.round(np.linspace(0, 1, N), decimals=2): 682 | df_list.append( 683 | { 684 | 'cell_type' : cov, 685 | 'condition' : pert+fixed_drugs, 686 | 'dose_val' : str(i) + '+' + str(j)+fixed_doses, 687 | } 688 | ) 689 | df_pred = pd.DataFrame(df_list) 690 | uncert_cos = [] 691 | uncert_eucl = [] 692 | closest_cond_cos = [] 693 | closest_cond_eucl = [] 694 | for i in range(df_pred.shape[0]): 695 | uncert_cos_, uncert_eucl_, closest_cond_cos_, closest_cond_eucl_ = ( 696 | compert_api.compute_uncertainty( 697 | cov=df_pred.iloc[i]['cell_type'], 698 | pert=df_pred.iloc[i]['condition'], 699 | dose=df_pred.iloc[i]['dose_val'] 700 | ) 701 | ) 702 | uncert_cos.append(uncert_cos_) 703 | uncert_eucl.append(uncert_eucl_) 704 | closest_cond_cos.append(closest_cond_cos_) 705 | closest_cond_eucl.append(closest_cond_eucl_) 706 | 707 | df_pred['uncertainty_cosine'] = uncert_cos 708 | df_pred['uncertainty_eucl'] = uncert_eucl 709 | df_pred['closest_cond_cos'] = closest_cond_cos 710 | df_pred['closest_cond_eucl'] = closest_cond_eucl 711 | doses = df_pred.dose_val.apply(lambda x: x.split('+')) 712 | X = np.array( 713 | doses 714 | .apply(lambda x: x[0]) 715 | .astype(float) 716 | ).reshape(N, N) 717 | Y = np.array( 718 | doses 719 | .apply(lambda x: x[1]) 720 | .astype(float) 721 | ).reshape(N, N) 722 | Z = np.array( 723 | df_pred[f'uncertainty_{metric}'] 724 | .values 725 | .astype(float) 726 | ).reshape(N, N) 727 | 728 | fig, ax = plt.subplots(1, 1) 729 | CS = ax.contourf(X, Y, Z, cmap='coolwarm', levels=20, 730 | alpha=1, vmin=vmin, vmax=vmax) 731 | 732 | ax.set_xlabel(pert.split('+')[0], fontweight="bold") 733 | ax.set_ylabel(pert.split('+')[1], fontweight="bold") 734 | if title: 735 | ax.set_title(cov) 736 | 737 | if not (df_ref is None): 738 | sns.scatterplot( 739 | x=pert.split('+')[0], 740 | y=pert.split('+')[1], 741 | hue='split', 742 | size='num_cells', 743 | sizes=sizes, 744 | alpha=1., 745 | palette={'train': '#000000', 'training': '#000000', 'ood': '#e41a1c'}, 746 | data=df_ref, 747 | ax=ax) 748 | ax.legend_.remove() 749 | 750 | if measured_points: 751 | ticks = measured_points[cov][pert] 752 | xticks = [float(x.split('+')[0]) for x in ticks] 753 | yticks = [float(x.split('+')[1]) for x in ticks] 754 | ax.set_xticks(xticks) 755 | ax.set_xticklabels(xticks, rotation=90) 756 | ax.set_yticks(yticks) 757 | fig.colorbar(CS) 758 | sns.despine() 759 | ax.axis("equal") 760 | ax.axis("square") 761 | ax.set_xlim(xlims) 762 | ax.set_ylim(ylims) 763 | 764 | plt.tight_layout() 765 | 766 | if filename: 767 | plt.savefig(filename) 768 | 769 | return df_pred 770 | 771 | def plot_uncertainty_dose( 772 | compert_api, 773 | cov, 774 | pert, 775 | N=11, 776 | metric='cosine', 777 | measured_points=None, 778 | cond_key='condition', 779 | log=False, 780 | min_dose=None, 781 | filename=None 782 | ): 783 | """Plotting uncertainty for a single perturbation at a dose range for a 784 | particular covariate. 
785 | 786 | Params 787 | ------ 788 | compert_api 789 | Api object for the model class. 790 | cov : str 791 | Name of covariate. 792 | pert : str 793 | Name of the perturbation. 794 | N : int 795 | Number of dose values. 796 | metric: str (default: 'cosine') 797 | Metric to evaluate uncertainty. 798 | measured_points : dict (default: None) 799 | A dicitionary of dictionaries. Per each covariate a dictionary with 800 | observed doses per perturbation, e.g. {'covar1': {'pert1': 801 | [0.1, 0.5, 1.0], 'pert2': [0.3]} 802 | cond_key : str (default: 'condition') 803 | Name of the variable to use for plotting. 804 | log : boolean (default: False) 805 | A flag if to plot on a log scale. 806 | min_dose : float (default: None) 807 | Minimum dose for the uncertainty estimate. 808 | filename : str (default: None) 809 | Full path to the file to export the plot. File extension should be included. 810 | 811 | Returns 812 | ------- 813 | pd.DataFrame of uncertainty estimations. 814 | """ 815 | 816 | df_list = [] 817 | if log: 818 | if min_dose is None: 819 | min_dose = 1e-3 820 | N_val = np.round(np.logspace(np.log10(min_dose), np.log10(1), N), decimals=10) 821 | else: 822 | if min_dose is None: 823 | min_dose = 0 824 | N_val = np.round(np.linspace(min_dose, 1., N), decimals=3) 825 | 826 | for i in N_val: 827 | df_list.append( 828 | { 829 | 'cell_type' : cov, 830 | 'condition' : pert, 831 | 'dose_val' : repr(i), 832 | } 833 | ) 834 | df_pred = pd.DataFrame(df_list) 835 | uncert_cos = [] 836 | uncert_eucl = [] 837 | closest_cond_cos = [] 838 | closest_cond_eucl = [] 839 | for i in range(df_pred.shape[0]): 840 | uncert_cos_, uncert_eucl_, closest_cond_cos_, closest_cond_eucl_ = ( 841 | compert_api.compute_uncertainty( 842 | cov=df_pred.iloc[i]['cell_type'], 843 | pert=df_pred.iloc[i]['condition'], 844 | dose=df_pred.iloc[i]['dose_val'] 845 | ) 846 | ) 847 | uncert_cos.append(uncert_cos_) 848 | uncert_eucl.append(uncert_eucl_) 849 | closest_cond_cos.append(closest_cond_cos_) 850 | closest_cond_eucl.append(closest_cond_eucl_) 851 | 852 | df_pred['uncertainty_cosine'] = uncert_cos 853 | df_pred['uncertainty_eucl'] = uncert_eucl 854 | df_pred['closest_cond_cos'] = closest_cond_cos 855 | df_pred['closest_cond_eucl'] = closest_cond_eucl 856 | 857 | x = df_pred.dose_val.values.astype(float) 858 | y = df_pred[f'uncertainty_{metric}'].values.astype(float) 859 | fig, ax = plt.subplots(1, 1) 860 | ax.plot(x, y) 861 | ax.set_xlabel(pert) 862 | ax.set_ylabel('Uncertainty') 863 | ax.set_title(cov) 864 | if log: 865 | ax.set_xscale('log') 866 | if measured_points: 867 | ticks = measured_points[cov][pert] 868 | ax.set_xticks(ticks) 869 | ax.set_xticklabels(ticks, rotation=90) 870 | else: 871 | plt.draw() 872 | ax.set_xticklabels(ax.get_xticklabels(), rotation=90) 873 | 874 | sns.despine() 875 | plt.tight_layout() 876 | 877 | if filename: 878 | plt.savefig(filename) 879 | 880 | return df_pred 881 | 882 | 883 | 884 | def save_to_file(fig, file_name, file_format=None): 885 | if file_format is None: 886 | if file_name.split(".")[-1] in ['png', 'pdf']: 887 | file_format = file_name.split(".")[-1] 888 | savename = file_name 889 | else: 890 | file_format = 'pdf' 891 | savename = f'{file_name}.{file_format}' 892 | else: 893 | savename = file_name 894 | 895 | fig.savefig(savename, format=file_format) 896 | print(f"Saved file to: {savename}") 897 | 898 | 899 | def plot_embedding( 900 | emb, 901 | labels=None, 902 | col_dict=None, 903 | title=None, 904 | show_lines=False, 905 | show_text=False, 906 | show_legend=True, 907 | 
axis_equal=True, 908 | circle_size=40, 909 | circe_transparency=1.0, 910 | line_transparency=0.8, 911 | line_width=1.0, 912 | fontsize=9, 913 | fig_width=4, 914 | fig_height=4, 915 | file_name=None, 916 | file_format=None, 917 | labels_name=None, 918 | width_ratios=[7, 1], 919 | bbox=(1.3, 0.7), 920 | show=True 921 | ): 922 | sns.set_style("white") 923 | 924 | # create data structure suitable for embedding 925 | df = pd.DataFrame(emb, columns=['dim1', 'dim2']) 926 | if not (labels is None): 927 | if labels_name is None: 928 | labels_name = 'labels' 929 | df[labels_name] = labels 930 | 931 | 932 | fig = plt.figure(figsize=(fig_width, fig_height)) 933 | ax = plt.gca() 934 | 935 | sns.despine(left=False, bottom=False, right=True) 936 | 937 | if (col_dict is None) and not (labels is None): 938 | col_dict = get_colors(labels) 939 | 940 | sns.scatterplot( 941 | x="dim1", 942 | y="dim2", 943 | hue=labels_name, 944 | palette=col_dict, 945 | alpha=circe_transparency, 946 | edgecolor="none", 947 | s=circle_size, 948 | data=df, 949 | ax=ax) 950 | 951 | try: 952 | ax.legend_.remove() 953 | except: 954 | pass 955 | 956 | if show_lines: 957 | for i in range(len(emb)): 958 | if col_dict is None: 959 | ax.plot( 960 | [0, emb[i, 0]], 961 | [0, emb[i, 1]], 962 | alpha=line_transparency, 963 | linewidth=line_width, 964 | c=None 965 | ) 966 | else: 967 | ax.plot( 968 | [0, emb[i, 0]], 969 | [0, emb[i, 1]], 970 | alpha=line_transparency, 971 | linewidth=line_width, 972 | c=col_dict[labels[i]] 973 | ) 974 | 975 | if show_text and not (labels is None): 976 | texts = [] 977 | labels = np.array(labels) 978 | unique_labels = np.unique(labels) 979 | for label in unique_labels: 980 | idx_label = np.where(labels == label)[0] 981 | texts.append( 982 | ax.text( 983 | np.mean(emb[idx_label, 0]), 984 | np.mean(emb[idx_label, 1]), 985 | label, 986 | fontsize=fontsize 987 | ) 988 | ) 989 | 990 | adjust_text( 991 | texts, 992 | arrowprops=dict(arrowstyle='-', color='black', lw=0.1), 993 | ax=ax 994 | ) 995 | 996 | if axis_equal: 997 | ax.axis('equal') 998 | ax.axis('square') 999 | 1000 | 1001 | if title: 1002 | ax.set_title(title, fontsize=fontsize, fontweight="bold") 1003 | 1004 | ax.set_xlabel('dim1', fontsize=fontsize) 1005 | ax.set_ylabel('dim2', fontsize=fontsize) 1006 | ax.xaxis.set_tick_params(labelsize=fontsize) 1007 | ax.yaxis.set_tick_params(labelsize=fontsize) 1008 | 1009 | plt.tight_layout() 1010 | 1011 | if file_name: 1012 | save_to_file(fig, file_name, file_format) 1013 | 1014 | if show: 1015 | plt.show() 1016 | plt.close() 1017 | 1018 | return plt 1019 | 1020 | 1021 | def get_colors( 1022 | labels, 1023 | palette=None, 1024 | palette_name=None 1025 | ): 1026 | n_colors = len(labels) 1027 | if palette is None: 1028 | palette = get_palette(n_colors, palette_name) 1029 | col_dict = dict(zip(labels, palette[:n_colors])) 1030 | return col_dict 1031 | 1032 | 1033 | def plot_similarity( 1034 | emb, 1035 | labels=None, 1036 | col_dict=None, 1037 | fig_width=4, 1038 | fig_height=4, 1039 | cmap='coolwarm', 1040 | fmt='png', 1041 | fontsize=7, 1042 | file_format=None, 1043 | file_name=None, 1044 | show=True 1045 | ): 1046 | 1047 | # first we take construct similarity matrix 1048 | # add another similarity 1049 | similarity_matrix = cosine_similarity(emb) 1050 | 1051 | df = pd.DataFrame( 1052 | similarity_matrix, 1053 | columns=labels, 1054 | index=labels, 1055 | ) 1056 | 1057 | if col_dict is None: 1058 | col_dict = get_colors(labels) 1059 | 1060 | network_colors = pd.Series(df.columns, index=df.columns).map(col_dict) 
1061 | 
1062 |     sns_plot = sns.clustermap(
1063 |         df,
1064 |         cmap=cmap,
1065 |         center=0,
1066 |         row_colors=network_colors,
1067 |         col_colors=network_colors,
1068 |         mask=False,
1069 |         metric='euclidean',
1070 |         figsize=(fig_height, fig_width),
1071 |         vmin=-1,
1072 |         vmax=1
1073 |     )
1074 | 
1075 |     sns_plot.ax_heatmap.xaxis.set_tick_params(labelsize=fontsize)
1076 |     sns_plot.ax_heatmap.yaxis.set_tick_params(labelsize=fontsize)
1077 |     sns_plot.ax_heatmap.axis('equal')
1078 |     sns_plot.cax.yaxis.set_tick_params(labelsize=fontsize)
1079 | 
1080 |     if file_name:
1081 |         save_to_file(sns_plot, file_name, file_format)
1082 | 
1083 |     if show:
1084 |         plt.show()
1085 |         plt.close()
1086 | 
1087 | 
1088 | from scipy import stats, sparse
1089 | from sklearn.metrics import r2_score
1090 | 
1091 | def mean_plot(
1092 |     adata,
1093 |     pred,
1094 |     condition_key,
1095 |     exp_key,
1096 |     path_to_save="./reg_mean.pdf",
1097 |     gene_list=None,
1098 |     deg_list=None,
1099 |     show=False,
1100 |     title=None,
1101 |     verbose=False,
1102 |     x_coeff=0.30,
1103 |     y_coeff=0.8,
1104 |     fontsize=11,
1105 |     R2_type="R2",
1106 |     figsize=(3.5, 3.5),
1107 |     **kwargs
1108 | ):
1109 |     """
1110 |     Plots mean matching.
1111 | 
1112 |     # Parameters
1113 |     adata: `~anndata.AnnData`
1114 |         Contains real values.
1115 |     pred: `~anndata.AnnData`
1116 |         Contains predicted values.
1117 |     condition_key: str
1118 |         adata.obs key to look for x-axis and y-axis condition
1119 |     exp_key: str
1120 |         Condition in adata.obs[condition_key] to be plotted
1121 |     path_to_save: str
1122 |         Path to save the plot.
1123 |     gene_list: list
1124 |         List of gene names to be plotted.
1125 |     deg_list: list
1126 |         List of DEGs to compute R2
1127 |     show: bool
1128 |         If `True`, show the plot after saving it.
1129 |     verbose: boolean
1130 |         If true prints the value
1131 |     title: str
1132 |         Title of the plot
1133 |     x_coeff: float
1134 |         Shifts R2 text horizontally by x_coeff
1135 |     y_coeff: float
1136 |         Shifts R2 text vertically by y_coeff
1139 |     fontsize: int
1140 |         Font size for R2 texts
1141 |     R2_type: str
1142 |         How to compute the R2 value; should be either Pearson R2 or R2 (sklearn)
1143 | 
1144 |     Returns:
1145 |         Calculated R2 values
1146 |     """
1147 | 
1148 |     r2_types = ['R2', 'Pearson R2']
1149 |     if R2_type not in r2_types:
1150 |         raise ValueError("R2 calculation should be one of " + str(r2_types))
1151 |     if sparse.issparse(adata.X):
1152 |         adata.X = adata.X.A
1153 |     if sparse.issparse(pred.X):
1154 |         pred.X = pred.X.A
1155 |     diff_genes = deg_list
1156 |     real = adata[adata.obs[condition_key] == exp_key]
1157 |     pred = pred[pred.obs[condition_key] == exp_key]
1158 |     if diff_genes is not None:
1159 |         if hasattr(diff_genes, "tolist"):
1160 |             diff_genes = diff_genes.tolist()
1161 |         real_diff = adata[:, diff_genes][adata.obs[condition_key] == exp_key]
1162 |         pred_diff = pred[:, diff_genes][pred.obs[condition_key] == exp_key]
1163 |         x_diff = np.average(pred_diff.X, axis=0)
1164 |         y_diff = np.average(real_diff.X, axis=0)
1165 |         if R2_type == "R2":
1166 |             r2_diff = r2_score(y_diff, x_diff)
1167 |         if R2_type == "Pearson R2":
1168 |             m, b, pearson_r_diff, p_value_diff, std_err_diff =\
1169 |                 stats.linregress(y_diff, x_diff)
1170 |             r2_diff = pearson_r_diff**2
1171 |         if verbose:
1172 |             print(f'Top {len(diff_genes)} DEGs mean R2: ', r2_diff)
1173 |     x = np.average(pred.X, axis=0)
1174 |     y = np.average(real.X, axis=0)
1175 |     if R2_type == "R2":
1176 |         r2 = r2_score(y, x)
1177 |     if R2_type == "Pearson R2":
1178 |         m, b, pearson_r, p_value, std_err = stats.linregress(y, x)
1179 |         r2 = pearson_r**2
1180 |     if verbose:
1181 |         print('All genes mean R2: ', r2)
1182 |     df = pd.DataFrame({f'{exp_key}_true': y, f'{exp_key}_pred': x})  # y holds the observed means, x the predicted means
1183 | 
1184 |     plt.figure(figsize=figsize)
1185 |     ax = sns.regplot(x=f'{exp_key}_true', y=f'{exp_key}_pred', data=df)
1186 |     ax.tick_params(labelsize=fontsize)
1187 |     if "range" in kwargs:
1188 |         start, stop, step = kwargs.get("range")
1189 |         ax.set_xticks(np.arange(start, stop, step))
1190 |         ax.set_yticks(np.arange(start, stop, step))
1191 |     ax.set_xlabel('true', fontsize=fontsize)
1192 |     ax.set_ylabel('pred', fontsize=fontsize)
1193 |     if gene_list is not None:
1194 |         for i in gene_list:
1195 |             j = adata.var_names.tolist().index(i)
1196 |             x_bar = y[j]  # observed mean on the x-axis ('true')
1197 |             y_bar = x[j]  # predicted mean on the y-axis ('pred')
1198 |             plt.text(x_bar, y_bar, i, fontsize=fontsize, color="black")
1199 |             plt.plot(x_bar, y_bar, 'o', color="red", markersize=5)
1200 |     if title is None:
1201 |         plt.title("", fontsize=fontsize, fontweight="bold")
1202 |     else:
1203 |         plt.title(title, fontsize=fontsize, fontweight="bold")
1204 |     ax.text(max(x) - max(x) * x_coeff, max(y) - y_coeff * max(y),
1205 |             r'$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= ' + f"{r2:.2f}",
1206 |             fontsize=fontsize)
1207 |     if diff_genes is not None:
1208 |         ax.text(max(x) - max(x) * x_coeff, max(y) - (y_coeff + 0.15) * max(y),
1209 |                 r'$\mathrm{R^2_{\mathrm{\mathsf{DEGs}}}}$= ' + f"{r2_diff:.2f}",
1210 |                 fontsize=fontsize)
1211 |     plt.savefig(f"{path_to_save}", bbox_inches='tight', dpi=100)
1212 |     if show:
1213 |         plt.show()
1214 |     plt.close()
1215 |     if diff_genes is not None:
1216 |         return r2, r2_diff
1217 |     else:
1218 |         return r2
1219 | 
1220 | 
1221 | def plot_r2_matrix(pred, adata, de_genes=None, **kwds):
1222 |     """Plots a pairwise R2 heatmap between predicted and control conditions.
1223 | 
1224 |     Params
1225 |     ------
1226 |     pred : `AnnData`
1227 |         Must have the field `cov_drug_dose_name`
1228 |     adata : `AnnData`
1229 |         Original gene expression data, with the field `cov_drug_dose_name`.
1230 |     de_genes : `dict`
1231 |         Dictionary of de_genes, where the keys
1232 |         match the categories in `cov_drug_dose_name`
1233 |     """
1234 |     r2s_mean = defaultdict(list)
1235 |     r2s_var = defaultdict(list)
1236 |     conditions = pred.obs['cov_drug_dose_name'].cat.categories
1237 |     for cond in conditions:
1238 |         if de_genes:
1239 |             degs = de_genes[cond]
1240 |             y_pred = pred[:, degs][pred.obs['cov_drug_dose_name'] == cond].X
1241 |             y_true_adata = adata[:, degs]
1242 |         else:
1243 |             y_pred = pred[pred.obs['cov_drug_dose_name'] == cond].X
1244 |             y_true_adata = adata
1245 | 
1246 |         # calculate R2 pairwise against each real condition
1247 |         for cond_real in conditions:
1248 |             y_true = y_true_adata[y_true_adata.obs['cov_drug_dose_name'] ==\
1249 |                 cond_real].X.toarray()
1250 |             r2s_mean[cond_real].append(r2_score(y_true.mean(axis=0),\
1251 |                 y_pred.mean(axis=0)))
1252 |             r2s_var[cond_real].append(r2_score(y_true.var(axis=0),\
1253 |                 y_pred.var(axis=0)))
1254 | 
1255 |     for r2_dict in [r2s_mean, r2s_var]:
1256 |         r2_df = pd.DataFrame.from_dict(r2_dict, orient='index')
1257 |         r2_df.columns = conditions
1258 | 
1259 |         plt.figure(figsize=(5, 5))
1260 |         sns.heatmap(data=r2_df, vmin=max(r2_df.min(0).min(), 0),
1261 |                     cmap='Blues', cbar=False,
1262 |                     annot=True, fmt='.2f', annot_kws={'fontsize': 5}, **kwds)
1263 |         plt.xticks(fontsize=6)
1264 |         plt.yticks(fontsize=6)
1265 |         plt.xlabel('y_pred')  # columns hold the predicted conditions
1266 |         plt.ylabel('y_true')  # rows hold the real conditions
1267 |         plt.show()
1268 | 
1269 | 
1270 | def arrange_history(history):
1271 |     # placeholder helper: inspect which keys are available in the history dict
1272 |     print(history.keys())
1273 | 
1274 | 
1275 | class ComPertHistory:
1276 |     """
1277 |     A wrapper for automatically plotting the training history of a ComPert model.
1278 |     """
1279 |     def __init__(self,
1280 |                  history,
1281 |                  fileprefix=None
1282 |                  ):
1283 |         """
1284 |         Parameters
1285 |         ----------
1286 |         history : dict
1287 |             Dictionary of ComPert history.
1288 |         fileprefix : str, optional (default: None)
1289 |             Prefix (with path) to the filename to save all figures in a
1290 |             standardized manner. If None, figures are not saved to file.
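        A minimal usage sketch (assuming `history` is the history dict of a
        trained model, e.g. the third element returned by torch.load on a
        saved checkpoint):

            hist = ComPertHistory(history, fileprefix='figures/run1')
            hist.print_time()
            hist.plot_losses()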
1291 | """ 1292 | 1293 | self.time = history['elapsed_time_min'] 1294 | self.losses_list = ['loss_reconstruction', 'loss_reconstruction_genes', 1295 | 'loss_reconstruction_proteins', 'loss_adv_drugs', 'loss_adv_cell_types'] 1296 | self.losses_names_list = ['recon_loss', 'recon_loss_genes', 1297 | 'recon_loss_proteins', 'loss_adv_drugs', 'loss_adv_cell_types'] 1298 | self.penalties_list = ['penalty_adv_drugs', 'penalty_adv_cell_types'] 1299 | 1300 | subset_keys = ['epoch'] + self.losses_list + self.penalties_list 1301 | 1302 | self.losses = pd.DataFrame(dict((k, history[k]) for k in\ 1303 | subset_keys if k in history)) 1304 | 1305 | self.header = ['mean', 'var', 'mean_DE', 'var_DE', 'mean_proteins', 'var_proteins'] 1306 | 1307 | self.metrics = pd.DataFrame(columns=['epoch', 'split'] + self.header) 1308 | for split in ['training', 'test', 'ood']: 1309 | df_split = pd.DataFrame(np.array(history[split]), columns=self.header) 1310 | df_split['split'] = split 1311 | df_split['epoch'] = history['stats_epoch'] 1312 | self.metrics = pd.concat([self.metrics, df_split]) 1313 | 1314 | self.disent = pd.DataFrame(dict((k, history[k])\ 1315 | for k in ['perturbation disentanglement',\ 1316 | 'covariate disentanglement'] if k in\ 1317 | history)) 1318 | self.disent['epoch'] = history['stats_epoch'] 1319 | 1320 | self.fileprefix = fileprefix 1321 | 1322 | def print_time(self): 1323 | print(f"Computation time: {self.time:.0f} min") 1324 | 1325 | def plot_losses(self, filename=None, show=True): 1326 | """ 1327 | Parameters 1328 | ---------- 1329 | filename : str (default: None) 1330 | Name of the file to save the plot. If None, will automatically 1331 | generate name from prefix file. 1332 | """ 1333 | if filename is None: 1334 | if self.fileprefix is None: 1335 | filename = None 1336 | else: 1337 | filename = f'{self.fileprefix}_history_losses.png' 1338 | 1339 | fig, ax = plt.subplots(1, 6, sharex=True, sharey=False, figsize=(15, 2.)) 1340 | 1341 | i = 0 1342 | for i in range(6): 1343 | if i < 5: 1344 | ax[i].plot(self.losses['epoch'].values,\ 1345 | self.losses[self.losses_list[i]].values) 1346 | ax[i].set_title(self.losses_names_list[i], fontweight="bold") 1347 | else: 1348 | ax[i].plot(self.losses['epoch'].values,\ 1349 | self.losses[self.penalties_list].values) 1350 | ax[i].set_title('Penalties', fontweight="bold") 1351 | plt.tight_layout() 1352 | 1353 | if filename: 1354 | save_to_file(fig, filename) 1355 | 1356 | if show: 1357 | plt.show() 1358 | plt.close() 1359 | 1360 | def plot_metrics(self, epoch_min=0, filename=None, show=True): 1361 | """ 1362 | Parameters 1363 | ---------- 1364 | epoch_min : int (default: 0) 1365 | Epoch from which to show metrics history plot. Done for readability. 1366 | 1367 | filename : str (default: None) 1368 | Name of the file to save the plot. If None, will automatically 1369 | generate name from prefix file. 
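        A minimal usage sketch (assuming `hist` is a ComPertHistory
        instance): `hist.plot_metrics(epoch_min=10)` skips the noisy first
        epochs for readability.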
1370 | """ 1371 | if filename is None: 1372 | if self.fileprefix is None: 1373 | filename = None 1374 | else: 1375 | filename = f'{self.fileprefix}_history_metrics.png' 1376 | 1377 | df = self.metrics.melt(id_vars=["epoch", "split"]) 1378 | col_dict = dict(zip(['training', 'test', 'ood'],\ 1379 | ['#377eb8', '#4daf4a', '#e41a1c'])) 1380 | fig, axs = plt.subplots(4, 2, sharex=True, sharey=False, figsize=(8., 8.)) 1381 | ax = plt.gca() 1382 | i = 0 1383 | for i1 in range(3): 1384 | for i2 in range(2): 1385 | sns.lineplot( 1386 | data=df[(df['variable'] == self.header[i]) &\ 1387 | (df['epoch'] > epoch_min)], 1388 | x="epoch", 1389 | y="value", 1390 | palette=col_dict, 1391 | hue="split", 1392 | ax=axs[i1, i2] 1393 | ) 1394 | axs[i1, i2].set_title(self.header[i], fontweight="bold") 1395 | i += 1 1396 | 1397 | sns.lineplot( 1398 | data=self.disent[self.disent['epoch'] > epoch_min], 1399 | x="epoch", 1400 | y="perturbation disentanglement", 1401 | legend=False, 1402 | ax=axs[3, 0] 1403 | ) 1404 | axs[3, 0].set_title("perturbation disentanglement", fontweight="bold") 1405 | 1406 | sns.lineplot( 1407 | data=self.disent[self.disent['epoch'] > epoch_min], 1408 | x="epoch", 1409 | y="covariate disentanglement", 1410 | legend=False, 1411 | ax=axs[3, 1] 1412 | ) 1413 | axs[3, 1].set_title("covariate disentanglement", fontweight="bold") 1414 | 1415 | plt.tight_layout() 1416 | 1417 | if filename: 1418 | save_to_file(fig, filename) 1419 | 1420 | if show: 1421 | plt.show() 1422 | plt.close() 1423 | -------------------------------------------------------------------------------- /MultiCPA/seml_sweep_icb.py: -------------------------------------------------------------------------------- 1 | # Author: Kemal Inecik 2 | # Email: k.inecik@gmail.com 3 | 4 | import sys 5 | from sacred import Experiment 6 | from collections import defaultdict 7 | import json 8 | import seml 9 | import torch 10 | import os 11 | import time 12 | import numpy as np 13 | 14 | from data import load_dataset_splits 15 | from model import ComPert, TotalComPert, PoEComPert, TotalPoEComPert 16 | from train import evaluate 17 | 18 | ex = Experiment() 19 | seml.setup_logger(ex) 20 | 21 | 22 | @ex.post_run_hook 23 | def collect_stats(_run): 24 | seml.collect_exp_stats(_run) 25 | 26 | 27 | @ex.config 28 | def config(): 29 | overwrite = None 30 | db_collection = None 31 | if db_collection is not None: 32 | ex.observers.append( 33 | seml.create_mongodb_observer(db_collection, overwrite=overwrite) 34 | ) 35 | 36 | 37 | def pjson(s): 38 | """ 39 | Prints a string in JSON format and flushes stdout 40 | """ 41 | print(json.dumps(s), flush=True) 42 | 43 | 44 | class ExperimentWrapper: 45 | """ 46 | A simple wrapper around a sacred experiment, making use of sacred's captured functions with prefixes. 47 | This allows a modular design of the configuration, where certain sub-dictionaries (e.g., "data") are parsed by 48 | specific method. This avoids having one large "main" function which takes all parameters as input. 49 | """ 50 | 51 | def __init__(self, init_all=True): 52 | 53 | 54 | if init_all: 55 | self.init_all() 56 | 57 | # With the prefix option we can "filter" the configuration for the sub-dictionary under "dataset". 58 | @ex.capture(prefix="dataset") 59 | def init_dataset(self, dataset_args: dict): 60 | """ 61 | Perform dataset loading, preprocessing etc. 62 | Since we set prefix="dataset ", this method only gets passed the respective sub-dictionary, enabling a modular 63 | experiment design. 
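        An illustrative (hypothetical) "dataset" sub-config consumed here,
        with keys mirroring the arguments of load_dataset_splits:

            dataset:
              dataset_args:
                dataset_path: /path/to/adata.h5ad
                perturbation_key: condition
                dose_key: dose_val
                cell_type_key: cell_type
                split_key: split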
64 |         """
65 |         self.datasets, self.dataset = load_dataset_splits(
66 |             **dataset_args,
67 |             return_dataset=True
68 |         )
69 | 
70 |     @ex.capture(prefix="model")
71 |     def init_model(self, model_type: str, model_args: dict):
72 | 
73 |         device = "cuda" if torch.cuda.is_available() else "cpu"
74 |         print(f'# DEVICE: {device}')
75 | 
76 |         # all four architectures share the same constructor signature, so
77 |         # dispatch through a dict instead of four duplicated branches and
78 |         # fail loudly on an unknown model type
79 |         model_classes = {
80 |             'ComPert': ComPert,
81 |             'TotalComPert': TotalComPert,
82 |             'PoEComPert': PoEComPert,
83 |             'TotalPoEComPert': TotalPoEComPert,
84 |         }
85 |         if model_type not in model_classes:
86 |             raise NotImplementedError(
87 |                 "The model architecture {} is not implemented!".format(model_type))
88 |         self.autoencoder = model_classes[model_type](
89 |             self.datasets["training"].num_genes,
90 |             self.datasets["training"].num_drugs,
91 |             self.datasets["training"].num_cell_types,
92 |             num_proteins=self.datasets["training"].num_proteins,
93 |             device=device,
94 |             seed=self.seed,
95 |             loss_ae=model_args["loss_ae"],
96 |             doser_type=model_args["doser_type"],
97 |             patience=model_args["patience"],
98 |             hparams=model_args["hparams"],
99 |             decoder_activation=model_args["decoder_activation"],
100 |             is_vae=model_args["is_vae"],
101 |         )
102 | 
137 |     def update_datasets(self):
138 | 
139 |         self.datasets.update({
140 |             "loader_tr": torch.utils.data.DataLoader(
141 |                 self.datasets["training"],
142 |                 batch_size=self.autoencoder.hparams["batch_size"],
143 |                 shuffle=True)
144 |         })
145 |         pjson({"autoencoder_params": self.autoencoder.hparams})
146 | 
148 |     @ex.capture
149 |     def init_all(self, seed):
150 |         """
151 |         Sequentially run the sub-initializers of the experiment.
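        The order matters: init_dataset() provides the datasets that
        init_model() needs, and update_datasets() reads the batch size from
        the initialized model's hyperparameters.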
152 |         """
153 | 
154 |         self.seed = seed
155 |         self.init_dataset()
156 |         self.init_model()
157 |         self.update_datasets()
158 | 
159 |     @ex.capture(prefix="training")
160 |     def train(
161 |         self,
162 |         num_epochs: int,
163 |         max_minutes: int,
164 |         checkpoint_freq: int,
165 |         ignore_evaluation: bool,
166 |         save_checkpoints: bool,
167 |         save_dir: str,
168 |         save_last: bool,
169 |     ):
170 | 
171 |         print(f"CWD: {os.getcwd()}")
172 |         print(f"Save dir: {save_dir}")
173 |         print(f"Is path?: {os.path.exists(save_dir)}")
174 | 
175 |         exp_id = ''.join(map(str, list(np.random.randint(0, 10, 30))))
176 | 
177 |         start_time = time.time()
178 |         for epoch in range(num_epochs):
179 |             epoch_training_stats = defaultdict(float)
180 | 
181 |             for genes, drugs, cell_types, proteins, raw_genes, raw_proteins in self.datasets["loader_tr"]:
182 |                 minibatch_training_stats = self.autoencoder.update(
183 |                     genes, drugs, cell_types, proteins, raw_genes, raw_proteins, epoch, num_epochs)
184 | 
185 |                 for key, val in minibatch_training_stats.items():
186 |                     epoch_training_stats[key] += val
187 | 
188 |             for key, val in epoch_training_stats.items():
189 |                 epoch_training_stats[key] = val / len(self.datasets["loader_tr"])
190 |                 if not (key in self.autoencoder.history.keys()):
191 |                     self.autoencoder.history[key] = []
192 |                 self.autoencoder.history[key].append(epoch_training_stats[key])
193 |             self.autoencoder.history['epoch'].append(epoch)
194 | 
195 |             elapsed_minutes = (time.time() - start_time) / 60
196 |             self.autoencoder.history['elapsed_time_min'] = elapsed_minutes
197 | 
198 |             # check stopping condition: time ran out OR max epochs reached
199 |             # (early-stopping patience is checked further below)
200 |             stop = elapsed_minutes > max_minutes or \
201 |                 (epoch == num_epochs - 1)
202 | 
203 |             if (epoch % checkpoint_freq) == 0 or stop:
204 |                 evaluation_stats = {}
205 |                 if not ignore_evaluation:
206 |                     evaluation_stats = evaluate(self.autoencoder, self.datasets)
207 |                     for key, val in evaluation_stats.items():
208 |                         if not (key in self.autoencoder.history.keys()):
209 |                             self.autoencoder.history[key] = []
210 |                         self.autoencoder.history[key].append(val)
211 |                     self.autoencoder.history['stats_epoch'].append(epoch)
212 | 
213 |                 pjson({
214 |                     "epoch": epoch,
215 |                     "training_stats": epoch_training_stats,
216 |                     "evaluation_stats": evaluation_stats,
217 |                     "elapsed_minutes": elapsed_minutes
218 |                 })
219 | 
220 |                 if save_checkpoints:
221 |                     if save_dir is None or not os.path.exists(save_dir):
222 |                         raise ValueError(
223 |                             "Please provide a valid directory path in the 'save_dir' argument."
224 |                         )
225 |                     fn = os.path.join(save_dir, f"{exp_id}_{epoch}.pt")
226 |                     torch.save(
227 |                         (self.autoencoder.state_dict(), self.autoencoder.hparams, self.autoencoder.history),
228 |                         fn)
229 |                     print(f"Model saved: {fn}")
230 | 
231 |                 # early stopping relies on the test metrics, so only check it when evaluation ran
232 |                 if not ignore_evaluation:
233 |                     stop = stop or self.autoencoder.early_stopping(
234 |                         np.mean(evaluation_stats["test"]))
235 |                 if stop:
236 |                     pjson({"early_stop": epoch})
237 |                     if save_last:
238 |                         if save_dir is None or not os.path.exists(save_dir):
239 |                             raise ValueError(
240 |                                 "Please provide a valid directory path in the 'save_dir' argument."
245 | ) 246 | fn = os.path.join(save_dir, f"{exp_id}_last.pt") 247 | torch.save( 248 | (self.autoencoder.state_dict(), self.autoencoder.hparams, self.autoencoder.history), 249 | fn) 250 | print(f"Model saved: {fn}") 251 | 252 | break 253 | 254 | results = self.autoencoder.history 255 | # results = pd.DataFrame.from_dict(results) # not same length! 256 | results["total_epochs"] = epoch 257 | results["exp_id"] = exp_id 258 | return results 259 | 260 | 261 | # We can call this command, e.g., from a Jupyter notebook with init_all=False to get an "empty" experiment wrapper, 262 | # where we can then for instance load a pretrained model to inspect the performance. 263 | @ex.command(unobserved=True) 264 | def get_experiment(init_all=False): 265 | print("get_experiment") 266 | experiment = ExperimentWrapper(init_all=init_all) 267 | return experiment 268 | 269 | 270 | # This function will be called by default. Note that we could in principle manually pass an experiment instance, 271 | # e.g., obtained by loading a model from the database or by calling this from a Jupyter notebook. 272 | @ex.automain 273 | def train(experiment=None): 274 | if experiment is None: 275 | experiment = ExperimentWrapper() 276 | return experiment.train() 277 | -------------------------------------------------------------------------------- /MultiCPA/train.py: -------------------------------------------------------------------------------- 1 | # Author: Kemal Inecik 2 | # Email: k.inecik@gmail.com 3 | 4 | import os 5 | import json 6 | import argparse 7 | 8 | import torch 9 | import numpy as np 10 | from collections import defaultdict 11 | 12 | try: 13 | from data import load_dataset_splits 14 | from model import ComPert, TotalComPert, PoEComPert, TotalPoEComPert 15 | except (ModuleNotFoundError, ImportError): 16 | from MultiCPA.data import load_dataset_splits 17 | from MultiCPA.model import ComPert, TotalComPert, PoEComPert, TotalPoEComPert 18 | 19 | from sklearn.metrics import r2_score, balanced_accuracy_score, make_scorer 20 | from sklearn.linear_model import LogisticRegression 21 | from sklearn.model_selection import cross_val_score 22 | from sklearn.neighbors import KNeighborsClassifier 23 | from sklearn.preprocessing import StandardScaler 24 | 25 | import time 26 | 27 | def pjson(s): 28 | """ 29 | Prints a string in JSON format and flushes stdout 30 | """ 31 | print(json.dumps(s), flush=True) 32 | 33 | 34 | def evaluate_disentanglement(autoencoder, dataset, nonlinear=False): 35 | """ 36 | Given a ComPert model, this function measures the correlation between 37 | its latent space and 1) a dataset's drug vectors 2) a datasets covariate 38 | vectors. 
39 | 40 | """ 41 | if autoencoder.loss_ae == 'gauss': 42 | latent_basal = autoencoder.get_latent( 43 | dataset.genes, 44 | dataset.drugs, 45 | dataset.cell_types, 46 | proteins=dataset.proteins, 47 | return_latent_treated=False) 48 | elif autoencoder.loss_ae == 'nb': 49 | latent_basal = autoencoder.get_latent( 50 | dataset.raw_genes, 51 | dataset.drugs, 52 | dataset.cell_types, 53 | proteins=dataset.raw_proteins, 54 | return_latent_treated=False) 55 | else: 56 | raise ValueError("Autoencoder loss must be either 'nb' or 'gauss'.") 57 | 58 | latent_basal = latent_basal.detach().cpu().numpy() 59 | 60 | if nonlinear: 61 | clf = KNeighborsClassifier( 62 | n_neighbors=int(np.sqrt(len(latent_basal)))) 63 | else: 64 | clf = LogisticRegression(solver="liblinear", 65 | multi_class="auto", 66 | max_iter=10000) 67 | 68 | pert_scores = cross_val_score( 69 | clf, 70 | StandardScaler().fit_transform(latent_basal), dataset.drugs_names, 71 | scoring=make_scorer(balanced_accuracy_score), cv=5, n_jobs=-1) 72 | 73 | if len(np.unique(dataset.cell_types_names)) > 1: 74 | cov_scores = cross_val_score( 75 | clf, 76 | StandardScaler().fit_transform(latent_basal), dataset.cell_types_names, 77 | scoring=make_scorer(balanced_accuracy_score), cv=5, n_jobs=-1) 78 | return np.mean(pert_scores), np.mean(cov_scores) 79 | else: 80 | return np.mean(pert_scores), 0 81 | 82 | 83 | def evaluate_r2(autoencoder, dataset, genes_control, proteins_control, sample=False): 84 | """ 85 | Measures different quality metrics about an ComPert `autoencoder`, when 86 | tasked to translate some `genes_control` into each of the drug/cell_type 87 | combinations described in `dataset`. 88 | 89 | Considered metrics are R2 score about means and variances for all genes, as 90 | well as R2 score about means and variances about differentially expressed 91 | (_de) genes. 92 | 93 | If protein data is available, the R2 score for the protein data is computed separately. 94 | 95 | For computing the R2 score, one can take the predicted mean or sample from the decoder distribution. 
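    For the 'nb' loss, the variance is recovered from the negative binomial
    mean and dispersion (theta) as var = mu + mu**2 / theta, matching the
    computation in the body below.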
96 | """ 97 | 98 | mean_score_genes, var_score_genes, mean_score_genes_de, var_score_genes_de, \ 99 | mean_score_proteins, var_score_proteins = [], [], [], [], [], [] 100 | num, dim_genes = genes_control.size(0), genes_control.size(1) 101 | 102 | if autoencoder.num_proteins is not None: 103 | dim_proteins = proteins_control.size(1) 104 | 105 | for pert_category in np.unique(dataset.pert_categories): 106 | # pert_category category contains: 'celltype_perturbation_dose' info 107 | de_idx = np.where( 108 | dataset.var_names.isin( 109 | np.array(dataset.de_genes[pert_category])))[0] 110 | 111 | idx = np.where(dataset.pert_categories == pert_category)[0] 112 | 113 | if len(idx) > 30: 114 | emb_drugs = dataset.drugs[idx][0].view( 115 | 1, -1).repeat(num, 1).clone() 116 | emb_cts = dataset.cell_types[idx][0].view( 117 | 1, -1).repeat(num, 1).clone() 118 | 119 | if sample: 120 | # sample from the decoder distribution 121 | gene_predictions, protein_predictions = autoencoder.sample( 122 | genes_control, emb_drugs, emb_cts, proteins_control) 123 | 124 | mean_predict_genes = gene_predictions[:, :dim_genes] 125 | var_predict_genes = gene_predictions[:, dim_genes:] 126 | 127 | if autoencoder.num_proteins is not None: 128 | mean_predict_proteins = protein_predictions[:, :dim_proteins] 129 | var_predict_proteins = protein_predictions[:, dim_proteins:] 130 | else: 131 | # take the predicted means instead of sampling from the decoder distribution 132 | gene_predictions, protein_predictions = autoencoder.predict( 133 | genes_control, emb_drugs, emb_cts, proteins_control) 134 | 135 | if isinstance(gene_predictions, list): 136 | gene_predictions = gene_predictions[-1] 137 | 138 | gene_predictions = gene_predictions.detach().cpu() 139 | 140 | mean_predict_genes = gene_predictions[:, :dim_genes] 141 | if autoencoder.loss_ae == 'nb': 142 | # compute variance based on dispersion 143 | var_predict_genes = mean_predict_genes + (mean_predict_genes ** 2) / \ 144 | gene_predictions[:, dim_genes:] 145 | else: 146 | # take the predicted variance estimates 147 | var_predict_genes = gene_predictions[:, dim_genes:] 148 | 149 | if autoencoder.num_proteins is not None: 150 | if isinstance(protein_predictions, list): 151 | protein_predictions = protein_predictions[-1] 152 | 153 | protein_predictions = protein_predictions.detach().cpu() 154 | 155 | mean_predict_proteins = protein_predictions[:, :dim_proteins] 156 | if autoencoder.loss_ae == 'nb': 157 | # compute variance based on dispersion 158 | var_predict_proteins = mean_predict_proteins + (mean_predict_proteins ** 2) / \ 159 | protein_predictions[:, dim_proteins:] 160 | else: 161 | # take the predicted variance estimates 162 | var_predict_proteins = protein_predictions[:, dim_proteins:] 163 | 164 | # estimate metrics only for reasonably-sized drug/cell-type combos 165 | if autoencoder.loss_ae == 'gauss': 166 | y_true_genes = dataset.genes[idx, :].numpy() 167 | elif autoencoder.loss_ae == 'nb': 168 | y_true_genes = dataset.raw_genes[idx, :].numpy() 169 | else: 170 | raise ValueError("Autoencoder loss must be either 'nb' or 'gauss'.") 171 | 172 | # true means and variances 173 | yt_m_genes = y_true_genes.mean(axis=0) 174 | yt_v_genes = y_true_genes.var(axis=0) 175 | # predicted means and variances 176 | if sample: 177 | yp_m_genes = mean_predict_genes.mean(0) 178 | yp_v_genes = var_predict_genes.var(0) 179 | else: 180 | yp_m_genes = mean_predict_genes.mean(0) 181 | yp_v_genes = var_predict_genes.mean(0) 182 | 183 | 184 | mean_score_genes.append(r2_score(yt_m_genes, 
yp_m_genes)) 185 | var_score_genes.append(r2_score(yt_v_genes, yp_v_genes)) 186 | 187 | mean_score_genes_de.append(r2_score(yt_m_genes[de_idx], yp_m_genes[de_idx])) 188 | var_score_genes_de.append(r2_score(yt_v_genes[de_idx], yp_v_genes[de_idx])) 189 | 190 | if autoencoder.num_proteins is not None: 191 | # estimate metrics only for reasonably-sized drug/cell-type combos 192 | if autoencoder.loss_ae == 'gauss': 193 | y_true_proteins = dataset.proteins[idx, :].numpy() 194 | elif autoencoder.loss_ae == 'nb': 195 | y_true_proteins = dataset.raw_proteins[idx, :].numpy() 196 | 197 | # true means and variances 198 | yt_m_proteins = y_true_proteins.mean(axis=0) 199 | yt_v_proteins = y_true_proteins.var(axis=0) 200 | # predicted means and variances 201 | if sample: 202 | yp_m_proteins = mean_predict_proteins.mean(0) 203 | yp_v_proteins = var_predict_proteins.var(0) 204 | else: 205 | yp_m_proteins = mean_predict_proteins.mean(0) 206 | yp_v_proteins = var_predict_proteins.mean(0) 207 | 208 | if len(yt_m_proteins) > 0: 209 | mean_score_proteins.append(r2_score(yt_m_proteins, yp_m_proteins)) 210 | var_score_proteins.append(r2_score(yt_v_proteins, yp_v_proteins)) 211 | else: 212 | mean_score_proteins.append(0) 213 | var_score_proteins.append(0) 214 | 215 | 216 | return [np.mean(s) if len(s) else -1 217 | for s in [mean_score_genes, var_score_genes, mean_score_genes_de, 218 | var_score_genes_de, mean_score_proteins, var_score_proteins]] 219 | 220 | 221 | def evaluate(autoencoder, datasets): 222 | """ 223 | Measure quality metrics using `evaluate_r2()` on the training, test, and 224 | out-of-distribution (ood) splits. 225 | """ 226 | 227 | autoencoder.eval() 228 | if autoencoder.loss_ae == 'gauss': 229 | # use the normalized counts to evaluate the model 230 | with torch.no_grad(): 231 | stats_test = evaluate_r2( 232 | autoencoder, 233 | datasets["test_treated"], 234 | datasets["test_control"].genes, 235 | datasets["test_control"].proteins) 236 | 237 | stats_disent_pert, stats_disent_cov = evaluate_disentanglement( 238 | autoencoder, datasets["test"]) 239 | 240 | evaluation_stats = { 241 | "training": evaluate_r2( 242 | autoencoder, 243 | datasets["training_treated"], 244 | datasets["training_control"].genes, 245 | datasets["training_control"].proteins), 246 | "test": stats_test, 247 | "ood": evaluate_r2( 248 | autoencoder, 249 | datasets["ood"], 250 | datasets["test_control"].genes, 251 | datasets["test_control"].proteins), 252 | "perturbation disentanglement": stats_disent_pert, 253 | "optimal for perturbations": 1 / datasets['test'].num_drugs, 254 | "covariate disentanglement": stats_disent_cov, 255 | "optimal for covariates": 1 / datasets['test'].num_cell_types, 256 | } 257 | elif autoencoder.loss_ae == 'nb': 258 | # use the raw counts to evaluate the model 259 | with torch.no_grad(): 260 | stats_test = evaluate_r2( 261 | autoencoder, 262 | datasets["test_treated"], 263 | datasets["test_control"].raw_genes, 264 | datasets["test_control"].raw_proteins) 265 | 266 | stats_disent_pert, stats_disent_cov = evaluate_disentanglement( 267 | autoencoder, datasets["test"]) 268 | 269 | evaluation_stats = { 270 | "training": evaluate_r2( 271 | autoencoder, 272 | datasets["training_treated"], 273 | datasets["training_control"].raw_genes, 274 | datasets["training_control"].raw_proteins), 275 | "test": stats_test, 276 | "ood": evaluate_r2( 277 | autoencoder, 278 | datasets["ood"], 279 | datasets["test_control"].raw_genes, 280 | datasets["test_control"].raw_proteins), 281 | "perturbation disentanglement":
stats_disent_pert, 282 | "optimal for perturbations": 1 / datasets['test'].num_drugs, 283 | "covariate disentanglement": stats_disent_cov, 284 | "optimal for covariates": 1 / datasets['test'].num_cell_types, 285 | } 286 | else: 287 | raise ValueError("Autoencoder loss must be either 'nb' or 'gauss'.") 288 | 289 | autoencoder.train() 290 | return evaluation_stats 291 | 292 | 293 | def prepare_compert(args, model='ComPert', state_dict=None): 294 | """ 295 | Instantiates autoencoder and dataset to run an experiment. 296 | """ 297 | 298 | device = "cuda" if torch.cuda.is_available() else "cpu" 299 | 300 | datasets = load_dataset_splits( 301 | args["dataset_path"], 302 | args["perturbation_key"], 303 | args["dose_key"], 304 | args["cell_type_key"], 305 | args["split_key"], 306 | args["raw_counts_key"], 307 | args["protein_key"], 308 | args["raw_protein_key"], 309 | ) 310 | 311 | if model == 'ComPert': 312 | autoencoder = ComPert( 313 | datasets["training"].num_genes, 314 | datasets["training"].num_drugs, 315 | datasets["training"].num_cell_types, 316 | num_proteins=datasets["training"].num_proteins, 317 | device=device, 318 | seed=args["seed"], 319 | loss_ae=args["loss_ae"], 320 | doser_type=args["doser_type"], 321 | patience=args["patience"], 322 | hparams=args["hparams"], 323 | decoder_activation=args["decoder_activation"], 324 | is_vae=args["is_vae"], 325 | ) 326 | elif model == 'TotalComPert': 327 | autoencoder = TotalComPert( 328 | datasets["training"].num_genes, 329 | datasets["training"].num_drugs, 330 | datasets["training"].num_cell_types, 331 | num_proteins=datasets["training"].num_proteins, 332 | device=device, 333 | seed=args["seed"], 334 | loss_ae=args["loss_ae"], 335 | doser_type=args["doser_type"], 336 | patience=args["patience"], 337 | hparams=args["hparams"], 338 | decoder_activation=args["decoder_activation"], 339 | is_vae=args["is_vae"], 340 | ) 341 | elif model == 'PoEComPert': 342 | autoencoder = PoEComPert( 343 | datasets["training"].num_genes, 344 | datasets["training"].num_drugs, 345 | datasets["training"].num_cell_types, 346 | num_proteins=datasets["training"].num_proteins, 347 | device=device, 348 | seed=args["seed"], 349 | loss_ae=args["loss_ae"], 350 | doser_type=args["doser_type"], 351 | patience=args["patience"], 352 | hparams=args["hparams"], 353 | decoder_activation=args["decoder_activation"], 354 | is_vae=args["is_vae"], 355 | ) 356 | elif model == 'TotalPoEComPert': 357 | autoencoder = TotalPoEComPert( 358 | datasets["training"].num_genes, 359 | datasets["training"].num_drugs, 360 | datasets["training"].num_cell_types, 361 | num_proteins=datasets["training"].num_proteins, 362 | device=device, 363 | seed=args["seed"], 364 | loss_ae=args["loss_ae"], 365 | doser_type=args["doser_type"], 366 | patience=args["patience"], 367 | hparams=args["hparams"], 368 | decoder_activation=args["decoder_activation"], 369 | is_vae=args["is_vae"], 370 | ) 371 | else: 372 | raise NotImplementedError("The model architecture {} is not implemented!".format(model)) 373 | 374 | if state_dict is not None: 375 | autoencoder.load_state_dict(state_dict) 376 | 377 | return autoencoder, datasets 378 | 379 | 380 | def train_compert(args, model='ComPert', return_model=False): 381 | """ 382 | Trains a ComPert autoencoder 383 | """ 384 | 385 | autoencoder, datasets = prepare_compert(args, model) 386 | 387 | datasets.update({ 388 | "loader_tr": torch.utils.data.DataLoader( 389 | datasets["training"], 390 | batch_size=autoencoder.hparams["batch_size"], 391 | shuffle=True) 392 | }) 393 | 394 | 
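    # The loop below iterates over minibatches of (genes, drugs, cell_types,
    # proteins, raw_genes, raw_proteins), averages the per-minibatch training
    # statistics over each epoch, and every `checkpoint_freq` epochs (or once a
    # stopping condition is met) runs the full evaluation and saves a model
    # checkpoint to `save_dir`. Training stops when the time budget
    # (`max_minutes`) or epoch budget (`max_epochs`) is exhausted, or when
    # early stopping on the mean test score triggers.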
pjson({"training_args": args}) 395 | pjson({"autoencoder_params": autoencoder.hparams}) 396 | 397 | start_time = time.time() 398 | for epoch in range(args["max_epochs"]): 399 | epoch_training_stats = defaultdict(float) 400 | 401 | for genes, drugs, cell_types, proteins, raw_genes, raw_proteins in datasets["loader_tr"]: 402 | minibatch_training_stats = autoencoder.update( 403 | genes, drugs, cell_types, proteins, raw_genes, raw_proteins, epoch, args["max_epochs"]) 404 | 405 | for key, val in minibatch_training_stats.items(): 406 | epoch_training_stats[key] += val 407 | 408 | for key, val in epoch_training_stats.items(): 409 | epoch_training_stats[key] = val / len(datasets["loader_tr"]) 410 | if not (key in autoencoder.history.keys()): 411 | autoencoder.history[key] = [] 412 | autoencoder.history[key].append(epoch_training_stats[key]) 413 | autoencoder.history['epoch'].append(epoch) 414 | 415 | ellapsed_minutes = (time.time() - start_time) / 60 416 | autoencoder.history['elapsed_time_min'] = ellapsed_minutes 417 | 418 | # decay learning rate if necessary 419 | # also check stopping condition: patience ran out OR 420 | # time ran out OR max epochs achieved 421 | stop = ellapsed_minutes > args["max_minutes"] or \ 422 | (epoch == args["max_epochs"] - 1) 423 | 424 | if (epoch % args["checkpoint_freq"]) == 0 or stop: 425 | evaluation_stats = evaluate(autoencoder, datasets) 426 | for key, val in evaluation_stats.items(): 427 | if not (key in autoencoder.history.keys()): 428 | autoencoder.history[key] = [] 429 | autoencoder.history[key].append(val) 430 | autoencoder.history['stats_epoch'].append(epoch) 431 | 432 | pjson({ 433 | "epoch": epoch, 434 | "training_stats": epoch_training_stats, 435 | "evaluation_stats": evaluation_stats, 436 | "ellapsed_minutes": ellapsed_minutes 437 | }) 438 | 439 | torch.save( 440 | (autoencoder.state_dict(), args, autoencoder.history), 441 | os.path.join( 442 | args["save_dir"], 443 | "model_seed={}_epoch={}.pt".format(args["seed"], epoch))) 444 | 445 | pjson({"model_saved": "model_seed={}_epoch={}.pt\n".format( 446 | args["seed"], epoch)}) 447 | stop = stop or autoencoder.early_stopping( 448 | np.mean(evaluation_stats["test"])) #or autoencoder.specific_threshold( 449 | #evaluation_stats["test"][0], epoch, epoch_thr=361, score_thr=0.15) # 0->gene 450 | if stop: 451 | pjson({"early_stop": epoch}) 452 | break 453 | 454 | if return_model: 455 | return autoencoder, datasets 456 | 457 | 458 | def parse_arguments(): 459 | """ 460 | Read arguments if this script is called from a terminal. 
461 | """ 462 | 463 | parser = argparse.ArgumentParser(description='Drug combinations.') 464 | # dataset arguments 465 | parser.add_argument('--dataset_path', type=str, required=True) 466 | parser.add_argument('--perturbation_key', type=str, default="condition") 467 | parser.add_argument('--dose_key', type=str, default="dose_val") 468 | parser.add_argument('--cell_type_key', type=str, default="cell_type") 469 | parser.add_argument('--split_key', type=str, default="split") 470 | parser.add_argument('--loss_ae', type=str, default='gauss') 471 | parser.add_argument('--doser_type', type=str, default='sigm') 472 | parser.add_argument('--decoder_activation', type=str, default='linear') 473 | 474 | # ComPert arguments (see set_hparams_() in MultiCPA.model.ComPert) 475 | parser.add_argument('--seed', type=int, default=0) 476 | parser.add_argument('--hparams', type=str, default="") 477 | 478 | # training arguments 479 | parser.add_argument('--max_epochs', type=int, default=2000) 480 | parser.add_argument('--max_minutes', type=int, default=300) 481 | parser.add_argument('--patience', type=int, default=20) 482 | parser.add_argument('--checkpoint_freq', type=int, default=20) 483 | 484 | # output folder 485 | parser.add_argument('--save_dir', type=str, required=True) 486 | # number of trials when executing MultiCPA.sweep 487 | parser.add_argument('--sweep_seeds', type=int, default=200) 488 | return dict(vars(parser.parse_args())) 489 | 490 | 491 | if __name__ == "__main__": 492 | train_compert(parse_arguments()) 493 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiCPA 2 | 3 | `MultiCPA` is a research project from a computatiobal biology group of Prof. Fabian 4 | Theis (https://github.com/theislab) from Helmholtz Zentrum München. 5 | 6 | ## What is MultiCPA? 7 | ![Screenshot](Figure.png) 8 | 9 | `MultiCPA` is a framework to learn effects of perturbations at the single-cell level for multiple modalities: `proteins` and `mRNAs`. 10 | MultiCPA encodes and learns phenotypic drug response across different cell types, doses and drug combinations. MultiCPA allows: 11 | 12 | * Out-of-distribution predicitons of unseen drug combinations at various doses and among different cell types. 13 | * Learn interpretable drug and cell type latent spaces. 14 | * Estimate dose response curve for each perturbation and their combinations. 15 | * Access the uncertainty of the estimations of the model. 16 | 17 | ## Package Structure 18 | 19 | The repository is centered around the `MultiCPA` module: 20 | 21 | * [`MultiCPA.train`](MultiCPA/train.py) contains scripts to train the model. 22 | * [`MultiCPA.api`](MultiCPA/api.py) contains user friendly scripts to interact with the model via scanpy. 23 | * [`MultiCPA.plotting`](MultiCPA/plotting.py) contains scripts to plotting functions. 24 | * [`MultiCPA.model`](MultiCPA/model.py) contains modules of compert model. 25 | * [`MultiCPA.data`](MultiCPA/data.py) contains data loader, which transforms anndata structure to a class compatible with compert model. 26 | 27 | Additional files and folders for reproducibility are found in another repository: [multicpa-reproducibility](https://github.com/theislab/multicpa-reproducibility) 28 | 29 | * [`datasets`](datasets/) contains both versions of the data: raw and pre-processed. 30 | * [`preprocessing`](preprocessing/) contains notebooks to reproduce the datasets pre-processing from raw data. 
31 | * [`notebooks`](notebooks/) contains notebooks to reproduce plots from the paper and detailed analysis of each of the datasets. 32 | * [`figures`](figures/) contains figures generated by running the notebooks. 33 | 34 | Note that the codebase was built on top of the `CPA` model. 35 | 36 | ## Usage 37 | 38 | To learn how to use this repository, check [example_training.ipynb](https://github.com/theislab/multicpa-reproducibility/blob/main/notebooks/example_training.ipynb). 39 | Note that the hyperparameters in the demo are not the defaults and will not work for new datasets. Please make 40 | sure to run `seml` sweeps for your new dataset to find the best hyperparameters. Using the provided Conda environment is strongly recommended. 41 | 42 | ## Examples and Reproducibility 43 | All the examples and the reproducibility notebooks for the plots in the paper can be found in the [multicpa-reproducibility](https://github.com/theislab/multicpa-reproducibility) repository. 44 | 45 | ## Documentation 46 | 47 | Currently, you can access the documentation via the `help` function in IPython. For example: 48 | 49 | ```python 50 | from MultiCPA.api import ComPertAPI 51 | 52 | help(ComPertAPI) 53 | 54 | from MultiCPA.plotting import CompertVisuals 55 | 56 | help(CompertVisuals) 57 | 58 | ``` 59 | 60 | A separate page with the documentation is coming soon. 61 | 62 | ## Support and contribute 63 | 64 | If you have a question or notice a problem, you can open an [`issue`](https://github.com/theislab/multicpa/). 65 | 66 | ## License 67 | 68 | This source code is released under the BSD 3-Clause License, included [here](LICENSE). 69 | -------------------------------------------------------------------------------- /environment_multicpa.yml: -------------------------------------------------------------------------------- 1 | name: multicpa_env 2 | channels: 3 | - plotly 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_kmp_llvm 10 | - alabaster=0.7.12=py_0 11 | - anndata=0.8.0=py38h578d9bd_0 12 | - anyio=3.6.1=py38h578d9bd_0 13 | - argon2-cffi=21.3.0=pyhd8ed1ab_0 14 | - argon2-cffi-bindings=21.2.0=py38h0a891b7_2 15 | - arpack=3.7.0=hc6cf775_2 16 | - asttokens=2.0.5=pyhd8ed1ab_0 17 | - attrs=21.4.0=pyhd8ed1ab_0 18 | - babel=2.10.3=pyhd8ed1ab_0 19 | - backcall=0.2.0=pyh9f0ad1d_0 20 | - backports=1.0=py_2 21 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 22 | - beautifulsoup4=4.11.1=pyha770c72_0 23 | - blas=1.0=mkl 24 | - bleach=5.0.1=pyhd8ed1ab_0 25 | - bottleneck=1.3.5=py38h7deecbd_0 26 | - brotli=1.0.9=he6710b0_2 27 | - brotlipy=0.7.0=py38h0a891b7_1004 28 | - bzip2=1.0.8=h7b6447c_0 29 | - c-ares=1.18.1=h7f98852_0 30 | - ca-certificates=2022.6.15=ha878542_0 31 | - cached-property=1.5.2=hd8ed1ab_1 32 | - cached_property=1.5.2=pyha770c72_1 33 | - certifi=2022.6.15=py38h06a4308_0 34 | - cffi=1.15.0=py38hd667e15_1 35 | - charset-normalizer=2.1.0=pyhd8ed1ab_0 36 | - colorama=0.4.5=pyhd8ed1ab_0 37 | - cryptography=37.0.4=py38h2b5fc30_0 38 | - cudatoolkit=11.3.1=h2bc3f7f_2 39 | - cycler=0.11.0=pyhd3eb1b0_0 40 | - dbus=1.13.18=hb2f20db_0 41 | - debugpy=1.6.0=py38hfa26641_0 42 | - decorator=5.1.1=pyhd8ed1ab_0 43 | - defusedxml=0.7.1=pyhd8ed1ab_0 44 | - docutils=0.16=py38h578d9bd_3 45 | - entrypoints=0.4=pyhd8ed1ab_0 46 | - executing=0.8.3=pyhd8ed1ab_0 47 | - expat=2.4.4=h295c915_0 48 | - ffmpeg=4.3=hf484d3e_0 49 | - flit-core=3.7.1=pyhd8ed1ab_0 50 | - fontconfig=2.13.1=h6c09931_0 51 | - fonttools=4.25.0=pyhd3eb1b0_0 52 | - freetype=2.11.0=h70c0345_0 53 | -
giflib=5.2.1=h7b6447c_0 54 | - glib=2.69.1=h4ff587b_1 55 | - glpk=4.65=h9202a9a_1004 56 | - gmp=6.2.1=h58526e2_0 57 | - gnutls=3.6.15=he1e5248_0 58 | - gst-plugins-base=1.14.0=h8213a91_2 59 | - gstreamer=1.14.0=h28cd5cc_2 60 | - h5py=3.2.1=nompi_py38h9915d05_100 61 | - hdf5=1.10.6=nompi_h7c3c948_1111 62 | - icu=58.2=he6710b0_3 63 | - idna=3.3=pyhd8ed1ab_0 64 | - igraph=0.9.9=h026ac8f_0 65 | - imagesize=1.4.1=pyhd8ed1ab_0 66 | - importlib-metadata=4.11.4=py38h578d9bd_0 67 | - importlib_metadata=4.11.4=hd8ed1ab_0 68 | - importlib_resources=5.8.0=pyhd8ed1ab_0 69 | - intel-openmp=2021.4.0=h06a4308_3561 70 | - ipykernel=6.15.0=pyh210e3f2_0 71 | - ipython=8.4.0=py38h578d9bd_0 72 | - ipython_genutils=0.2.0=py_1 73 | - jedi=0.18.1=py38h578d9bd_1 74 | - jinja2=3.1.2=pyhd8ed1ab_1 75 | - joblib=1.1.0=pyhd3eb1b0_0 76 | - jpeg=9e=h7f8727e_0 77 | - json5=0.9.5=pyh9f0ad1d_0 78 | - jsonschema=4.6.1=pyhd8ed1ab_0 79 | - jupyter_client=7.3.4=pyhd8ed1ab_0 80 | - jupyter_core=4.10.0=py38h578d9bd_0 81 | - jupyter_server=1.18.1=pyhd8ed1ab_0 82 | - jupyterlab=3.4.3=pyhd8ed1ab_0 83 | - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0 84 | - jupyterlab_server=2.15.0=pyhd8ed1ab_0 85 | - keyutils=1.6.1=h166bdaf_0 86 | - kiwisolver=1.4.2=py38h295c915_0 87 | - krb5=1.19.3=h3790be6_0 88 | - lame=3.100=h7b6447c_0 89 | - lcms2=2.12=h3be6417_0 90 | - ld_impl_linux-64=2.38=h1181459_1 91 | - leidenalg=0.8.10=py38hfa26641_0 92 | - libblas=3.9.0=12_linux64_mkl 93 | - libcblas=3.9.0=12_linux64_mkl 94 | - libcurl=7.83.1=h7bff187_0 95 | - libedit=3.1.20191231=he28a2e2_2 96 | - libev=4.33=h516909a_1 97 | - libffi=3.3=he6710b0_2 98 | - libgcc-ng=12.1.0=h8d9b700_16 99 | - libgfortran-ng=7.5.0=ha8ba4b0_17 100 | - libgfortran4=7.5.0=ha8ba4b0_17 101 | - libiconv=1.16=h7f8727e_2 102 | - libidn2=2.3.2=h7f8727e_0 103 | - liblapack=3.9.0=12_linux64_mkl 104 | - libllvm11=11.1.0=hf817b99_3 105 | - libnghttp2=1.47.0=h727a467_0 106 | - libpng=1.6.37=hbc83047_0 107 | - libsodium=1.0.18=h36c2ea0_1 108 | - libssh2=1.10.0=ha56f1ee_2 109 | - libstdcxx-ng=12.1.0=ha89aaad_16 110 | - libtasn1=4.16.0=h27cfd23_0 111 | - libtiff=4.2.0=h2818925_1 112 | - libunistring=0.9.10=h27cfd23_0 113 | - libuuid=1.0.3=h7f8727e_2 114 | - libwebp=1.2.2=h55f646e_0 115 | - libwebp-base=1.2.2=h7f8727e_0 116 | - libxcb=1.15=h7f8727e_0 117 | - libxml2=2.9.14=h74e7548_0 118 | - libzlib=1.2.12=h166bdaf_1 119 | - llvm-openmp=14.0.4=he0ac6c6_0 120 | - llvmlite=0.38.1=py38h38d86a4_0 121 | - lz4-c=1.9.3=h295c915_1 122 | - markupsafe=2.1.1=py38h0a891b7_1 123 | - matplotlib=3.5.1=py38h06a4308_1 124 | - matplotlib-base=3.5.1=py38ha18d171_1 125 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 126 | - metis=5.1.0=h58526e2_1006 127 | - mistune=0.8.4=py38h497a2fe_1005 128 | - mkl=2021.4.0=h06a4308_640 129 | - mkl-service=2.4.0=py38h7f8727e_0 130 | - mkl_fft=1.3.1=py38hd3c417c_0 131 | - mkl_random=1.2.2=py38h51133e4_0 132 | - mpfr=4.1.0=h9202a9a_1 133 | - munkres=1.1.4=py_0 134 | - natsort=8.1.0=pyhd8ed1ab_0 135 | - nbclassic=0.3.7=pyhd8ed1ab_0 136 | - nbclient=0.6.6=pyhd8ed1ab_0 137 | - nbconvert=6.5.0=pyhd8ed1ab_0 138 | - nbconvert-core=6.5.0=pyhd8ed1ab_0 139 | - nbconvert-pandoc=6.5.0=pyhd8ed1ab_0 140 | - nbformat=5.4.0=pyhd8ed1ab_0 141 | - ncurses=6.3=h5eee18b_3 142 | - nest-asyncio=1.5.5=pyhd8ed1ab_0 143 | - nettle=3.7.3=hbbd107a_1 144 | - networkx=2.8.4=pyhd8ed1ab_0 145 | - notebook=6.4.12=pyha770c72_0 146 | - notebook-shim=0.1.0=pyhd8ed1ab_0 147 | - numba=0.55.2=py38hdc3674a_0 148 | - numexpr=2.8.3=py38h807cd23_0 149 | - numpy=1.22.3=py38he7a7128_0 150 | - numpy-base=1.22.3=py38hf524024_0 151 | - 
openh264=2.1.1=h4ff587b_0 152 | - openssl=1.1.1q=h166bdaf_0 153 | - packaging=21.3=pyhd3eb1b0_0 154 | - pandas=1.4.2=py38h295c915_0 155 | - pandoc=2.18=ha770c72_0 156 | - pandocfilters=1.5.0=pyhd8ed1ab_0 157 | - parso=0.8.3=pyhd8ed1ab_0 158 | - patsy=0.5.2=pyhd8ed1ab_0 159 | - pcre=8.45=h295c915_0 160 | - pexpect=4.8.0=pyh9f0ad1d_2 161 | - pickleshare=0.7.5=py_1003 162 | - pillow=9.0.1=py38h22f2fdc_0 163 | - pip=21.2.4=py38h06a4308_0 164 | - plotly=5.9.0=py_0 165 | - prometheus_client=0.14.1=pyhd8ed1ab_0 166 | - prompt-toolkit=3.0.30=pyha770c72_0 167 | - psutil=5.9.1=py38h0a891b7_0 168 | - ptyprocess=0.7.0=pyhd3deb0d_0 169 | - pure_eval=0.2.2=pyhd8ed1ab_0 170 | - pycparser=2.21=pyhd8ed1ab_0 171 | - pygments=2.12.0=pyhd8ed1ab_0 172 | - pynndescent=0.5.7=pyh6c4a22f_0 173 | - pyopenssl=22.0.0=pyhd8ed1ab_0 174 | - pyparsing=3.0.4=pyhd3eb1b0_0 175 | - pyqt=5.9.2=py38h05f1152_4 176 | - pyrsistent=0.18.1=py38h0a891b7_1 177 | - pysocks=1.7.1=py38h578d9bd_5 178 | - python=3.8.13=h12debd9_0 179 | - python-dateutil=2.8.2=pyhd3eb1b0_0 180 | - python-fastjsonschema=2.15.3=pyhd8ed1ab_0 181 | - python-igraph=0.9.11=py38hd0e5696_0 182 | - python_abi=3.8=2_cp38 183 | - pytorch=1.12.0=py3.8_cuda11.3_cudnn8.3.2_0 184 | - pytorch-mutex=1.0=cuda 185 | - pytz=2022.1=py38h06a4308_0 186 | - pyzmq=23.2.0=py38hfc09fa9_0 187 | - qt=5.9.7=h5867ecd_1 188 | - readline=8.1.2=h7f8727e_1 189 | - requests=2.28.1=pyhd8ed1ab_0 190 | - scanpy=1.9.1=pyhd8ed1ab_0 191 | - scikit-learn=1.0.2=py38h51133e4_1 192 | - scipy=1.7.3=py38hc147768_0 193 | - seaborn=0.11.2=pyhd3eb1b0_0 194 | - send2trash=1.8.0=pyhd8ed1ab_0 195 | - session-info=1.0.0=pyhd8ed1ab_0 196 | - sip=4.19.13=py38h295c915_0 197 | - six=1.16.0=pyhd3eb1b0_1 198 | - sniffio=1.2.0=py38h578d9bd_3 199 | - snowballstemmer=2.2.0=pyhd8ed1ab_0 200 | - soupsieve=2.3.1=pyhd8ed1ab_0 201 | - sphinx=5.0.2=pyh6c4a22f_0 202 | - sphinxcontrib-applehelp=1.0.2=py_0 203 | - sphinxcontrib-devhelp=1.0.2=py_0 204 | - sphinxcontrib-htmlhelp=2.0.0=pyhd8ed1ab_0 205 | - sphinxcontrib-jsmath=1.0.1=py_0 206 | - sphinxcontrib-qthelp=1.0.3=py_0 207 | - sphinxcontrib-serializinghtml=1.1.5=pyhd8ed1ab_2 208 | - sqlite=3.38.5=hc218d9a_0 209 | - stack_data=0.3.0=pyhd8ed1ab_0 210 | - statsmodels=0.13.2=py38h71d37f0_0 211 | - stdlib-list=0.7.0=py_2 212 | - suitesparse=5.10.1=h9e50725_1 213 | - tbb=2021.5.0=h924138e_1 214 | - tenacity=8.0.1=py38h06a4308_0 215 | - terminado=0.15.0=py38h578d9bd_0 216 | - texttable=1.6.4=pyhd8ed1ab_0 217 | - threadpoolctl=2.2.0=pyh0d69192_0 218 | - tinycss2=1.1.1=pyhd8ed1ab_0 219 | - tk=8.6.12=h1ccaba5_0 220 | - torchaudio=0.12.0=py38_cu113 221 | - torchvision=0.13.0=py38_cu113 222 | - tornado=6.1=py38h27cfd23_0 223 | - tqdm=4.64.0=pyhd8ed1ab_0 224 | - traitlets=5.3.0=pyhd8ed1ab_0 225 | - typing_extensions=4.1.1=pyh06a4308_0 226 | - umap-learn=0.5.3=py38h578d9bd_0 227 | - urllib3=1.26.9=pyhd8ed1ab_0 228 | - wcwidth=0.2.5=pyh9f0ad1d_2 229 | - webencodings=0.5.1=py_1 230 | - websocket-client=1.3.3=pyhd8ed1ab_0 231 | - wheel=0.37.1=pyhd3eb1b0_0 232 | - xz=5.2.5=h7f8727e_1 233 | - zeromq=4.3.4=h9c3ff4c_1 234 | - zipp=3.8.0=pyhd8ed1ab_0 235 | - zlib=1.2.12=h166bdaf_1 236 | - zstd=1.5.2=ha4553b6_0 237 | - pip: 238 | - absl-py==1.1.0 239 | - adjusttext==0.7.3 240 | - aiohttp==3.8.1 241 | - aiosignal==1.2.0 242 | - async-timeout==4.0.2 243 | - cachetools==5.2.0 244 | - chex==0.1.3 245 | - commonmark==0.9.1 246 | - dm-tree==0.1.7 247 | - docopt==0.6.2 248 | - docrep==0.3.2 249 | - et-xmlfile==1.1.0 250 | - etils==0.6.0 251 | - flatbuffers==2.0 252 | - flax==0.5.2 253 | - 
frozenlist==1.3.0 254 | - fsspec==2022.5.0 255 | - future==0.18.2 256 | - gitdb==4.0.9 257 | - gitpython==3.1.27 258 | - google-auth==2.9.0 259 | - google-auth-oauthlib==0.4.6 260 | - grpcio==1.47.0 261 | - ipywidgets==7.7.1 262 | - jax==0.3.14 263 | - jaxlib==0.3.14 264 | - jsonpickle==1.5.2 265 | - jupyterlab-widgets==1.1.1 266 | - markdown==3.3.7 267 | - msgpack==1.0.4 268 | - multidict==6.0.2 269 | - multipledispatch==0.6.0 270 | - munch==2.5.0 271 | - numpyro==0.10.0 272 | - oauthlib==3.2.0 273 | - openpyxl==3.0.10 274 | - opt-einsum==3.3.0 275 | - optax==0.1.2 276 | - protobuf==3.19.4 277 | - py-cpuinfo==8.0.0 278 | - pyasn1==0.4.8 279 | - pyasn1-modules==0.2.8 280 | - pydeprecate==0.3.1 281 | - pymongo==4.1.1 282 | - pyro-api==0.1.2 283 | - pyro-ppl==1.8.1 284 | - pytorch-lightning==1.5.10 285 | - pyyaml==6.0 286 | - requests-oauthlib==1.3.1 287 | - rich==11.2.0 288 | - rsa==4.8 289 | - sacred==0.8.2 290 | - scvi-tools==0.16.4 291 | - seml==0.3.6 292 | - setuptools==59.5.0 293 | - smmap==5.0.0 294 | - tensorboard==2.9.1 295 | - tensorboard-data-server==0.6.1 296 | - tensorboard-plugin-wit==1.8.1 297 | - toolz==0.11.2 298 | - torchmetrics==0.9.2 299 | - werkzeug==2.1.2 300 | - widgetsnbextension==3.6.1 301 | - wrapt==1.14.1 302 | - yarl==1.7.2 303 | prefix: /home/icb/kemal.inecik/miniconda3/envs/multicpa_env 304 | --------------------------------------------------------------------------------