├── LICENSE ├── README.md ├── analysis ├── kexin │ ├── README.md │ ├── __init__.py │ ├── analysis_on_norman_adamson.ipynb │ ├── analyze_fold_change.ipynb │ ├── construct train coexpresssion.ipynb │ ├── data.py │ ├── data_analysis_more.ipynb │ ├── data_splits.ipynb │ ├── error_analysis.ipynb │ ├── error_analysis_upreg.ipynb │ ├── evaluate_demo.ipynb │ ├── evaluate_saved_model-More_Metrics.ipynb │ ├── evaluate_saved_model.ipynb │ ├── examine_io.ipynb │ ├── gene-level-variation_analysis.ipynb │ ├── get_more_de.ipynb │ ├── get_signature.ipynb │ ├── gnn_generalization_test-new_model.ipynb │ ├── gnn_generalization_test.ipynb │ ├── go_graph-adamson.ipynb │ ├── go_graph-norman.ipynb │ ├── go_graph-norman_adamson.ipynb │ ├── helper.py │ ├── hyperopt_train.py │ ├── inference.py │ ├── make_CPA_data.ipynb │ ├── mlp mixer.ipynb │ ├── model.py │ ├── playground.ipynb │ ├── simulation_single_data_split_on_new_dataset.ipynb │ ├── split pert order.ipynb │ ├── train.ipynb │ ├── train.py │ ├── train_1.ipynb │ ├── train_2.ipynb │ ├── train_3.ipynb │ ├── train_4.ipynb │ ├── train_5.ipynb │ ├── train_large.ipynb │ ├── uncertainty-reg10.ipynb │ ├── uncertainty.ipynb │ ├── understand_split.ipynb │ └── utils.py └── plot_functions.py ├── data └── preprocessing │ ├── Adamson2016.ipynb │ ├── Norman19.ipynb │ └── Replogle_2022_preprocess.ipynb ├── gears ├── __init__.py ├── data_utils.py ├── gears.py ├── inference.py ├── make_GO.py ├── model.py ├── pertdata.py ├── utils.py └── version.py ├── legacy ├── GI_accuracy.py ├── __init__.py ├── data.py ├── evaluate.py ├── flow.py ├── genes_with_hi_mean.npy ├── inference.py ├── learn_weights.py ├── linear_pert_model.py ├── model.py ├── pertnet.py ├── plot_functions.py ├── train.py └── utils.py └── paper ├── CPA_reproduce ├── README.md ├── __init__.py ├── api.py ├── collect_results.py ├── cpa.sh ├── cpa_to_wandb.py ├── data.py ├── helper.py ├── model.py ├── plotting.py └── train.py ├── CellOracle ├── CellOracle.ipynb └── CellOracle_evaluation_wandb.ipynb ├── Ext_Fig_2.ipynb ├── Ext_Fig_3.ipynb ├── Ext_Fig_5.ipynb ├── Ext_Fig_6a.ipynb ├── Ext_Fig_6a_GI_genes_compute.ipynb ├── Ext_Fig_7.ipynb ├── Fig2b-Adamson2016.ipynb ├── Fig2b-Dixit2016.ipynb ├── Fig2b-gearsonly.ipynb ├── Fig2b-gearsonly_Adamson.ipynb ├── Fig2b-gearsonly_Dixit.ipynb ├── Fig2b-replogle-k562.ipynb ├── Fig2b-replogle-rpe1.ipynb ├── Fig2cd.ipynb ├── Fig2f-gearsonly_Norman.ipynb ├── Fig2f.ipynb ├── Fig2f_mean.ipynb ├── Fig2g.ipynb ├── Fig2h.ipynb ├── Fig3c+Ext_Fig_6.ipynb ├── Fig3c.ipynb ├── Fig3d.ipynb ├── Fig4.ipynb ├── Fig4_UMAP_predict.py ├── Fig4_UMAP_train.py ├── Fig5b.ipynb ├── GRN ├── GRN_model.py ├── README.md ├── SCENIC_norman.ipynb ├── graph_filtering.ipynb ├── learn_weights.py └── run_GRN_baseline.py ├── README.md ├── Random_heatmap.ipynb ├── SI_Fig_15.ipynb ├── SI_Fig_17.ipynb ├── SI_Fig_17_Horlbeck_Jurkat.ipynb ├── archive ├── fig2cd.ipynb ├── fig2f.ipynb └── fig2g.ipynb ├── data ├── GI_data.pkl └── TF_names_v_1.01.txt ├── fig2_train.py ├── fig_utils.py ├── predicting_GIs.py ├── reproduce_preprint_results.ipynb └── supp_train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yusuf Roohani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | 
copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GEARS (Miscellaneous) 2 | 3 | ### Miscellaneous code and analysis scripts related to [GEARS: Predicting transcriptional outcomes of novel multi-gene perturbations](https://github.com/snap-stanford/GEARS) 4 | 5 | 6 | This repo contains: 7 | - Code for reproducing figures and analysis from the paper 8 | - Legacy code from older versions of the model 9 | 10 | ---- 11 | 12 | **Official code repository**: [Link](https://github.com/snap-stanford/GEARS) 13 | 14 | Preprint: [Link](https://www.biorxiv.org/content/10.1101/2022.07.12.499735v2) 15 | -------------------------------------------------------------------------------- /analysis/kexin/README.md: -------------------------------------------------------------------------------- 1 | 2 | ``` 3 | python train.py --dataset Norman2019 \ 4 | --split single_only \ 5 | --device cuda:0 \ 6 | --batch_size 64 \ 7 | --model GNN_Disentangle_AE \ 8 | --node_hidden_size 8 \ 9 | --max_epochs 1 \ 10 | --model_backend GAT \ 11 | --gnn_num_layers 4 \ 12 | --loss_mode l2 \ 13 | --focal_gamma 2 14 | ``` -------------------------------------------------------------------------------- /analysis/kexin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhr91/GEARS_misc/f88211870dfa89c38a2eedbd69ca1abd28a25f3c/analysis/kexin/__init__.py -------------------------------------------------------------------------------- /analysis/kexin/data_analysis_more.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e15e75b5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import scanpy as sc\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import seaborn as sns\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import warnings\n", 16 | "warnings.filterwarnings(\"ignore\")\n", 17 | "\n", 18 | "adata = sc.read('/dfs/project/perturb-gnn/datasets/Norman2019_hvg+perts.h5ad')" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 27, 24 | "id": "1022b090", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "pert2pert_full_id = dict(adata.obs[['condition', 'cov_drug_dose_name']].values)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 54, 34 | "id": "2d7eae85", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "unique_perts = adata.obs.condition.unique()\n", 39 | "X = adata.X\n", 40 | "c = adata.obs.condition\n", 41 | "query_pert = 
unique_perts[12]\n", 42 | "expression = X[np.where(c == query_pert)[0]].toarray()" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 55, 48 | "id": "8f18e565", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "adata.uns['rank_genes_groups_cov'][pert2pert_full_id[query_pert]]\n", 53 | "\n", 54 | "de_idx = np.where(adata.var_names.isin(\n", 55 | " np.array(adata.uns['rank_genes_groups_cov'][pert2pert_full_id[query_pert]])))[0]" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 56, 61 | "id": "5b56563d", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "array([0.3637023 , 0.02996498, 0.61281943, 0.72060037, 0.5657906 ,\n", 68 | " 0.40530732, 1.0259172 , 1.0281726 , 0.40466785, 0.36748803,\n", 69 | " 1.4743949 , 0.8042997 , 0.5626899 , 0.12749983, 0.09244415,\n", 70 | " 0.5978718 , 0.2420775 , 0.00527949, 0. , 0.42057905],\n", 71 | " dtype=float32)" 72 | ] 73 | }, 74 | "execution_count": 56, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "np.std(expression[:, de_idx], axis = 0)**2" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 57, 86 | "id": "670f66c1", 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "array([0.5608998 , 0.04417451, 2.226867 , 1.2113272 , 1.7787977 ,\n", 93 | " 0.9248679 , 1.7555475 , 3.2955768 , 0.80007803, 0.64751977,\n", 94 | " 4.466268 , 1.225273 , 1.5727811 , 0.17758968, 0.13103218,\n", 95 | " 1.5126592 , 3.2560515 , 0.00896992, 0. , 1.2743382 ],\n", 96 | " dtype=float32)" 97 | ] 98 | }, 99 | "execution_count": 57, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "np.mean(expression[:, de_idx], axis = 0)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 58, 111 | "id": "58469795", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "data": { 116 | "text/plain": [ 117 | "array([0. , 0. , 0. , 0. , 0. , 0. ,\n", 118 | " 0. , 0. , 0. , 0. , 0. , 0. ,\n", 119 | " 0. , 0. , 0. , 0. , 1.2048866, 0. ,\n", 120 | " 0. , 0. ], dtype=float32)" 121 | ] 122 | }, 123 | "execution_count": 58, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "np.min(expression[:, de_idx], axis = 0)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 59, 135 | "id": "add26f61", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "array([2.5098755, 1.1482052, 4.037729 , 3.5382335, 3.6960921, 2.482056 ,\n", 142 | " 3.6960921, 5.471069 , 2.364613 , 2.3326802, 6.7137156, 3.494384 ,\n", 143 | " 3.9264498, 1.6070515, 1.4129144, 3.7742238, 4.397499 , 0.739601 ,\n", 144 | " 0. 
, 2.6451356], dtype=float32)" 145 | ] 146 | }, 147 | "execution_count": 59, 148 | "metadata": {}, 149 | "output_type": "execute_result" 150 | } 151 | ], 152 | "source": [ 153 | "np.max(expression[:, de_idx], axis = 0)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "0196df70", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "4b04d40c", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "fcae18cd", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 17, 183 | "id": "d1c675b7", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "unique_perturbs = []\n", 188 | "for i in adata.obs.condition.unique():\n", 189 | " if '+' in i:\n", 190 | " unique_perturbs.append(i.split('+')[0])\n", 191 | " unique_perturbs.append(i.split('+')[1])\n", 192 | " \n", 193 | "unique_perturbs = np.unique(unique_perturbs)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 36, 199 | "id": "40332009", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "all_tfs = pd.read_csv('TF_names_v_1.01.txt', header = None).values" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 38, 209 | "id": "dbf266fa", 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "data": { 214 | "text/plain": [ 215 | "array(['ARID1A', 'ARRDC3', 'ATL1', 'BAK1', 'BCL2L11', 'BCORL1', 'BPGM',\n", 216 | " 'C19orf26', 'C3orf72', 'CBFA2T3', 'CBL', 'CDKN1A', 'CDKN1B',\n", 217 | " 'CDKN1C', 'CELF2', 'CITED1', 'CKS1B', 'CLDN6', 'CNN1', 'CNNM4',\n", 218 | " 'COL1A1', 'COL2A1', 'DUSP9', 'ELMSAN1', 'GLB1L2', 'HK2', 'IER5L',\n", 219 | " 'IGDCC3', 'KIAA1804', 'KIF18B', 'KIF2C', 'MAML2', 'MAP2K3',\n", 220 | " 'MAP2K6', 'MAP4K3', 'MAP4K5', 'MAP7D1', 'MAPK1', 'MIDN', 'NCL',\n", 221 | " 'NIT1', 'PLK4', 'PRTG', 'PTPN1', 'PTPN12', 'PTPN13', 'PTPN9',\n", 222 | " 'RHOXF2BB', 'RUNX1T1', 'S1PR2', 'SAMD1', 'SET', 'SGK1', 'SLC38A2',\n", 223 | " 'SLC4A1', 'SLC6A9', 'STIL', 'TGFBR2', 'TMSB4X', 'UBASH3A',\n", 224 | " 'UBASH3B', 'ZC3HAV1', 'ctrl'], dtype=')" 353 | ] 354 | }, 355 | "execution_count": 28, 356 | "metadata": {}, 357 | "output_type": "execute_result" 358 | } 359 | ], 360 | "source": [ 361 | "mixer(base_emb)" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "id": "32cb00c6", 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [] 371 | } 372 | ], 373 | "metadata": { 374 | "kernelspec": { 375 | "display_name": "Python 3 (ipykernel)", 376 | "language": "python", 377 | "name": "python3" 378 | }, 379 | "language_info": { 380 | "codemirror_mode": { 381 | "name": "ipython", 382 | "version": 3 383 | }, 384 | "file_extension": ".py", 385 | "mimetype": "text/x-python", 386 | "name": "python", 387 | "nbconvert_exporter": "python", 388 | "pygments_lexer": "ipython3", 389 | "version": "3.8.11" 390 | } 391 | }, 392 | "nbformat": 4, 393 | "nbformat_minor": 5 394 | } 395 | -------------------------------------------------------------------------------- /analysis/kexin/train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 16, 6 | "id": "4279a2ac", 7 | "metadata": { 8 | "scrolled": false 9 | }, 10 | "outputs": [ 11 | { 12 | 
"name": "stdout", 13 | "output_type": "stream", 14 | "text": [ 15 | "---- Printing Arguments ----\n", 16 | "dataset: Norman2019\n", 17 | "split: single_only\n", 18 | "seed: 1\n", 19 | "test_set_fraction: 0.1\n", 20 | "perturbation_key: condition\n", 21 | "species: human\n", 22 | "binary_pert: True\n", 23 | "edge_attr: True\n", 24 | "ctrl_remove_train: False\n", 25 | "edge_weights: False\n", 26 | "pert_feats: True\n", 27 | "pert_delta: False\n", 28 | "edge_filter: False\n", 29 | "network_name: string\n", 30 | "top_edge_percent: 10.0\n", 31 | "device: cuda:0\n", 32 | "max_epochs: 1\n", 33 | "lr: 0.005\n", 34 | "lr_decay_step_size: 3\n", 35 | "lr_decay_factor: 0.5\n", 36 | "weight_decay: 0.0005\n", 37 | "batch_size: 64\n", 38 | "print_progress_steps: 50\n", 39 | "node_hidden_size: 8\n", 40 | "node_embed_size: 1\n", 41 | "ae_hidden_size: 512\n", 42 | "gnn_num_layers: 4\n", 43 | "ae_num_layers: 2\n", 44 | "model: GNN_Disentangle_AE\n", 45 | "model_backend: GAT\n", 46 | "shared_weights: False\n", 47 | "pert_loss_wt: 1\n", 48 | "loss_type: micro\n", 49 | "loss_mode: l2\n", 50 | "focal_gamma: 2\n", 51 | "wandb: True\n", 52 | "project_name: pert_gnn_v1\n", 53 | "entity_name: kexinhuang\n", 54 | "----------------------------\n", 55 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mkexinhuang\u001b[0m (use `wandb login --relogin` to force relogin)\n", 56 | "\u001b[34m\u001b[1mwandb\u001b[0m: wandb version 0.12.4 is available! To upgrade, please run:\n", 57 | "\u001b[34m\u001b[1mwandb\u001b[0m: $ pip install wandb --upgrade\n", 58 | "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.12.2\n", 59 | "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mGAT_string_8_4_l2_Norman2019\u001b[0m\n", 60 | "\u001b[34m\u001b[1mwandb\u001b[0m: View project at \u001b[34m\u001b[4mhttps://wandb.ai/kexinhuang/pert_gnn_v1\u001b[0m\n", 61 | "\u001b[34m\u001b[1mwandb\u001b[0m: View run at \u001b[34m\u001b[4mhttps://wandb.ai/kexinhuang/pert_gnn_v1/runs/x22do22y\u001b[0m\n", 62 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in /dfs/user/kexinh/perturb_GNN/kexin/wandb/run-20211011_134038-x22do22y\n", 63 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run `wandb offline` to turn off syncing.\n", 64 | "\n", 65 | "Training GAT_string_8_4_l2_Norman2019\n", 66 | "Building cell graph... \n", 67 | "There are 50506 edges in the PPI.\n", 68 | "Creating pyg object for each cell in the data...\n", 69 | "Local copy of pyg dataset is detected. Loading...\n", 70 | "Loading splits...\n", 71 | "Local copy of split is detected. Loading...\n", 72 | "Creating dataloaders....\n", 73 | "Dataloaders created\n", 74 | "Finished data setup, in total takes 0.8058381994565328 min\n", 75 | "Initializing model... \n", 76 | "Start Training...\n", 77 | "Epoch 1 Step 1 Train Loss: 0.3884\n", 78 | "Epoch 1 Step 51 Train Loss: 0.0392\n", 79 | "Epoch 1 Step 101 Train Loss: 0.0423\n", 80 | "Epoch 1 Step 151 Train Loss: 0.0507\n", 81 | "Epoch 1 Step 201 Train Loss: 0.0454\n", 82 | "Epoch 1 Step 251 Train Loss: 0.0477\n", 83 | "Epoch 1 Step 301 Train Loss: 0.0402\n", 84 | "Epoch 1 Step 351 Train Loss: 0.0451\n", 85 | "Epoch 1 Step 401 Train Loss: 0.0438\n", 86 | "Epoch 1 Step 451 Train Loss: 0.0366\n", 87 | "Epoch 1 Step 501 Train Loss: 0.0468\n", 88 | "Epoch 1 Step 551 Train Loss: 0.0437\n", 89 | "Epoch 1 Step 601 Train Loss: 0.0375\n", 90 | "Epoch 1 Step 651 Train Loss: 0.0427\n", 91 | "Epoch 1: Train: 0.0043, R2 0.9708 Validation: 0.0044. 
R2 0.9703 Loss: 0.0453\n", 92 | "DE_Train: 0.2855, R2 0.5529 DE_Validation: 0.2756. R2 0.6805 \n", 93 | "Start testing....\n", 94 | "Final best performing model: Test_DE: 0.2077, R2 0.3332 \n", 95 | "Saving model....\n", 96 | "Done!\n", 97 | "\n", 98 | "\u001b[34m\u001b[1mwandb\u001b[0m: Waiting for W&B process to finish, PID 44962\n", 99 | "\u001b[34m\u001b[1mwandb\u001b[0m: Program ended successfully.\n", 100 | "\u001b[34m\u001b[1mwandb\u001b[0m: \n", 101 | "\u001b[34m\u001b[1mwandb\u001b[0m: Find user logs for this run at: /dfs/user/kexinh/perturb_GNN/kexin/wandb/run-20211011_134038-x22do22y/logs/debug.log\n", 102 | "\u001b[34m\u001b[1mwandb\u001b[0m: Find internal logs for this run at: /dfs/user/kexinh/perturb_GNN/kexin/wandb/run-20211011_134038-x22do22y/logs/debug-internal.log\n", 103 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run summary:\n", 104 | "\u001b[34m\u001b[1mwandb\u001b[0m: Test_DE_MSE 0.20767\n", 105 | "\u001b[34m\u001b[1mwandb\u001b[0m: Test_R2 0.33322\n", 106 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_de_mse 0.28553\n", 107 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_de_r2 0.55286\n", 108 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_mse 0.00434\n", 109 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_r2 0.97082\n", 110 | "\u001b[34m\u001b[1mwandb\u001b[0m: training_loss 0.03998\n", 111 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_de_mse 0.27558\n", 112 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_de_r2 0.68051\n", 113 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_mse 0.00442\n", 114 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_r2 0.97027\n", 115 | "\u001b[34m\u001b[1mwandb\u001b[0m: Run history:\n", 116 | "\u001b[34m\u001b[1mwandb\u001b[0m: Test_DE_MSE ▁\n", 117 | "\u001b[34m\u001b[1mwandb\u001b[0m: Test_R2 ▁\n", 118 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_de_mse ▁\n", 119 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_de_r2 ▁\n", 120 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_mse ▁\n", 121 | "\u001b[34m\u001b[1mwandb\u001b[0m: train_r2 ▁\n", 122 | "\u001b[34m\u001b[1mwandb\u001b[0m: training_loss █▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▁▁▂▁▂▁▂▂▁▂▂▂▂▂▁▁▂▂▂▂▂▁▂▂\n", 123 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_de_mse ▁\n", 124 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_de_r2 ▁\n", 125 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_mse ▁\n", 126 | "\u001b[34m\u001b[1mwandb\u001b[0m: val_r2 ▁\n", 127 | "\u001b[34m\u001b[1mwandb\u001b[0m: Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)\n", 128 | "\u001b[34m\u001b[1mwandb\u001b[0m: \n", 129 | "\u001b[34m\u001b[1mwandb\u001b[0m: Synced \u001b[33mGAT_string_8_4_l2_Norman2019\u001b[0m: \u001b[34mhttps://wandb.ai/kexinhuang/pert_gnn_v1/runs/x22do22y\u001b[0m\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "!python train.py --dataset Norman2019 \\\n", 135 | " --split single_only \\\n", 136 | " --seed 1 \\\n", 137 | " --test_set_fraction 0.1 \\\n", 138 | " --network_name string \\\n", 139 | " --top_edge_percent 10 \\\n", 140 | " --max_epochs 1 \\\n", 141 | " --batch_size 64 \\\n", 142 | " --lr 5e-3 \\\n", 143 | " --lr_decay_step_size 3 \\\n", 144 | " --lr_decay_factor 0.5 \\\n", 145 | " --weight_decay 5e-4 \\\n", 146 | " --model GNN_Disentangle_AE \\\n", 147 | " --model_backend GAT \\\n", 148 | " --node_hidden_size 8 \\\n", 149 | " --gnn_num_layers 4 \\\n", 150 | " --ae_hidden_size 512 \\\n", 151 | " --ae_num_layers 2\\\n", 152 | " --loss_mode l2 \\\n", 153 | " --focal_gamma 2 \\\n", 154 | " --print_progress_steps 50 \\\n", 155 | " --device cuda:0 \\\n", 156 | " --wandb" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 
null, 162 | "id": "2b32b172", 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [] 166 | } 167 | ], 168 | "metadata": { 169 | "kernelspec": { 170 | "display_name": "Python 3 (ipykernel)", 171 | "language": "python", 172 | "name": "python3" 173 | }, 174 | "language_info": { 175 | "codemirror_mode": { 176 | "name": "ipython", 177 | "version": 3 178 | }, 179 | "file_extension": ".py", 180 | "mimetype": "text/x-python", 181 | "name": "python", 182 | "nbconvert_exporter": "python", 183 | "pygments_lexer": "ipython3", 184 | "version": "3.8.11" 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 5 189 | } 190 | -------------------------------------------------------------------------------- /analysis/kexin/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | ## helper function 5 | def parse_single_pert(i): 6 | a = i.split('+')[0] 7 | b = i.split('+')[1] 8 | if a == 'ctrl': 9 | pert = b 10 | else: 11 | pert = a 12 | return pert 13 | 14 | def parse_combo_pert(i): 15 | return i.split('+')[0], i.split('+')[1] 16 | 17 | def parse_any_pert(p): 18 | if ('ctrl' in p) and (p != 'ctrl'): 19 | return [parse_single_pert(p)] 20 | elif 'ctrl' not in p: 21 | out = parse_combo_pert(p) 22 | return [out[0], out[1]] 23 | 24 | def np_pearson_cor(x, y): 25 | xv = x - x.mean(axis=0) 26 | yv = y - y.mean(axis=0) 27 | xvss = (xv * xv).sum(axis=0) 28 | yvss = (yv * yv).sum(axis=0) 29 | result = np.matmul(xv.transpose(), yv) / np.sqrt(np.outer(xvss, yvss)) 30 | # bound the values to -1 to 1 in the event of precision issues 31 | return np.maximum(np.minimum(result, 1.0), -1.0) 32 | 33 | 34 | def get_coexpression_network_from_train(adata, pertdl, args, threshold = 0.4, k = 10): 35 | import os 36 | import pandas as pd 37 | 38 | fname = './saved_networks/' + args['dataset'] + '_' + args['split'] + '_' + str(args['seed']) + '_' + str(args['test_set_fraction']) + '_' + str(threshold) + '_' + str(k) + '_co_expression_network.csv' 39 | 40 | if os.path.exists(fname): 41 | return fname 42 | else: 43 | gene_list = [f for f in adata.var.gene_symbols.values] 44 | idx2gene = dict(zip(range(len(gene_list)), gene_list)) 45 | X = adata.X 46 | train_perts = pertdl.set2conditions['train'] 47 | X_tr = X[np.isin(adata.obs.condition, [i for i in train_perts if 'ctrl' in i])] 48 | gene_list = adata.var['gene_name'].values 49 | 50 | X_tr = X_tr.toarray() 51 | out = np_pearson_cor(X_tr, X_tr) 52 | out[np.isnan(out)] = 0 53 | out = np.abs(out) 54 | 55 | out_sort_idx = np.argsort(out)[:, -(k + 1):] 56 | out_sort_val = np.sort(out)[:, -(k + 1):] 57 | 58 | df_g = [] 59 | for i in range(out_sort_idx.shape[0]): 60 | target = idx2gene[i] 61 | for j in range(out_sort_idx.shape[1]): 62 | df_g.append((idx2gene[out_sort_idx[i, j]], target, out_sort_val[i, j])) 63 | 64 | df_g = [i for i in df_g if i[2] > threshold] 65 | df_co_expression = pd.DataFrame(df_g).rename(columns = {0: 'source', 1: 'target', 2: 'importance'}) 66 | df_co_expression.to_csv(fname, index = False) 67 | return fname 68 | 69 | def weighted_mse_loss(input, target, weight): 70 | """ 71 | Weighted MSE implementation 72 | """ 73 | sample_mean = torch.mean((input - target) ** 2, 1) 74 | return torch.mean(weight * sample_mean) 75 | 76 | 77 | def uncertainty_loss_fct(pred, logvar, y, perts, loss_mode = 'l2', gamma = 1, reg = 0.1, reg_core = 1): 78 | perts = np.array(perts) 79 | losses = torch.tensor(0.0, requires_grad=True).to(pred.device) 80 | for p in set(perts): 81 | pred_p = 
pred[np.where(perts==p)[0]] 82 | y_p = y[np.where(perts==p)[0]] 83 | logvar_p = logvar[np.where(perts==p)[0]] 84 | 85 | if loss_mode == 'l2': 86 | losses += torch.sum(0.5 * torch.exp(-logvar_p) * (pred_p - y_p)**2 + 0.5 * logvar_p)/pred_p.shape[0]/pred_p.shape[1] 87 | elif loss_mode == 'l3': 88 | #losses += torch.sum(0.5 * torch.exp(-logvar_p) * (pred_p - y_p)**(2 + gamma) + 0.01 * logvar_p)/pred_p.shape[0]/pred_p.shape[1] 89 | #losses += torch.sum((pred_p - y_p)**(2 + gamma) + 0.1 * torch.exp(-logvar_p) * (pred_p - y_p)**(2 + gamma) + 0.1 * logvar_p)/pred_p.shape[0]/pred_p.shape[1] 90 | losses += reg_core * torch.sum((pred_p - y_p)**(2 + gamma) + reg * torch.exp(-logvar_p) * (pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] 91 | 92 | 93 | return losses/(len(set(perts))) 94 | 95 | 96 | def loss_fct(pred, y, perts, weight=1, loss_type = 'macro', loss_mode = 'l2', gamma = 1): 97 | 98 | # Micro average MSE 99 | if loss_type == 'macro': 100 | mse_p = torch.nn.MSELoss() 101 | perts = np.array(perts) 102 | losses = torch.tensor(0.0, requires_grad=True).to(pred.device) 103 | for p in set(perts): 104 | pred_p = pred[np.where(perts==p)[0]] 105 | y_p = y[np.where(perts==p)[0]] 106 | if loss_mode == 'l2': 107 | losses += torch.sum((pred_p - y_p)**2)/pred_p.shape[0]/pred_p.shape[1] 108 | elif loss_mode == 'l3': 109 | losses += torch.sum((pred_p - y_p)**(2 + gamma))/pred_p.shape[0]/pred_p.shape[1] 110 | 111 | return losses/(len(set(perts))) 112 | 113 | else: 114 | # Weigh the loss for perturbations (unweighted by default) 115 | #weights = np.ones(len(pred)) 116 | #non_ctrl_idx = np.where([('ctrl' != p) for p in perts])[0] 117 | #weights[non_ctrl_idx] = weight 118 | #loss = weighted_mse_loss(pred, y, torch.Tensor(weights).to(pred.device)) 119 | if loss_mode == 'l2': 120 | loss = torch.sum((pred - y)**2)/pred.shape[0]/pred.shape[1] 121 | elif loss_mode == 'l3': 122 | loss = torch.sum((pred - y)**(2 + gamma))/pred.shape[0]/pred.shape[1] 123 | 124 | return loss -------------------------------------------------------------------------------- /analysis/plot_functions.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def get_de(res, query): 5 | """ 6 | Given a query perturbation and model output file, 7 | return predicted and true post-perturbation expression 8 | """ 9 | query_idx = np.where(res['pert_cat'] == query)[0] 10 | de = {"pred_de": res['pred_de'][query_idx], 11 | "truth_de": res['truth_de'][query_idx]} 12 | return de 13 | 14 | def get_de_ctrl(pert, adata): 15 | """ 16 | Get ctrl expression for DE genes for a given perturbation 17 | """ 18 | mean_ctrl_exp = adata[adata.obs['condition'] == 'ctrl'].to_df().mean() 19 | de_genes = get_covar_genes(pert, adata) 20 | return mean_ctrl_exp[de_genes] 21 | 22 | def get_covar_genes(p, adata): 23 | """ 24 | Get genes that are differentially expressed 25 | """ 26 | gene_name_dict = adata.var.loc[:,'gene_name'].to_dict() 27 | pert_name = 'A549_'+p+'_1+1' 28 | de_genes = adata.uns['rank_genes_groups_cov'][pert_name] 29 | return de_genes 30 | 31 | def create_boxplot(res, adata, query, genes=None): 32 | """ 33 | Create a boxplot showing true, predicted and control expression 34 | for a given perturbation 35 | """ 36 | 37 | plt.figure(figsize=[10,3]) 38 | plt.title(query) 39 | pert_de_res = get_de(res, query)['pred_de'] 40 | truth_de_res = get_de(res, query)['truth_de'] 41 | plt.boxplot(truth_de_res, showfliers=False, 42 | medianprops = dict(linewidth=0)) 43 
| ctrl_means = get_de_ctrl(query, adata).values 44 | 45 | for i in range(pert_de_res.shape[1]): 46 | _ = plt.scatter(i+1, np.mean(pert_de_res[:,i]), color='red') 47 | _ = plt.scatter(i+1, ctrl_means[i], color='forestgreen', marker='*') 48 | 49 | ax = plt.gca() 50 | if genes is not None: 51 | ax.xaxis.set_ticklabels(genes) 52 | else: 53 | ax.xaxis.set_ticklabels(['G1','G2','G3','G4','G5','G6','G7','G8','G9', 'G10', 54 | 'G11','G12','G13','G14','G15','G16','G17','G18','G19', 'G20']) 55 | -------------------------------------------------------------------------------- /gears/__init__.py: -------------------------------------------------------------------------------- 1 | from .gears import GEARS 2 | from .pertdata import PertData -------------------------------------------------------------------------------- /gears/make_GO.py: -------------------------------------------------------------------------------- 1 | ## Script for creating Gene Ontology graph from a custom set of genes 2 | 3 | import pickle, os 4 | import pandas as pd 5 | 6 | data_name = 'dixit' 7 | 8 | with open(os.path.join('./data/', 'gene2go_all.pkl'), 'rb') as f: 9 | gene2go = pickle.load(f) 10 | 11 | with open('./data/essential_' + data_name + '.pkl', 'rb') as f: 12 | essential_genes = pickle.load(f) 13 | 14 | gene2go = {i: gene2go[i] for i in essential_genes if i in gene2go} 15 | 16 | import tqdm 17 | from multiprocessing import Pool 18 | import numpy as np 19 | 20 | def get_edge_list(g1): 21 | edge_list = [] 22 | for g2 in gene2go.keys(): 23 | score = len(gene2go[g1].intersection(gene2go[g2]))/len(gene2go[g1].union(gene2go[g2])) 24 | if score > 0.1: 25 | edge_list.append((g1, g2, score)) 26 | return edge_list 27 | 28 | with Pool(40) as p: 29 | all_edge_list = list(tqdm.tqdm(p.imap(get_edge_list, list(gene2go.keys())), total=len(gene2go.keys()))) 30 | 31 | edge_list = [] 32 | for i in all_edge_list: 33 | edge_list = edge_list + i 34 | 35 | del all_edge_list 36 | 37 | df_edge_list = pd.DataFrame(edge_list).rename(columns = {0: 'gene1', 1: 'gene2', 2: 'score'}) 38 | 39 | df_edge_list = df_edge_list.rename(columns = {'gene1': 'source', 'gene2': 'target', 'score': 'importance'}) 40 | df_edge_list.to_csv('./data/go_essential_' + data_name + '.csv', index = False) 41 | -------------------------------------------------------------------------------- /gears/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Sequential, Linear, ReLU 5 | 6 | from torch_geometric.nn import SGConv 7 | 8 | class MLP(torch.nn.Module): 9 | 10 | def __init__(self, sizes, batch_norm=True, last_layer_act="linear"): 11 | super(MLP, self).__init__() 12 | layers = [] 13 | for s in range(len(sizes) - 1): 14 | layers = layers + [ 15 | torch.nn.Linear(sizes[s], sizes[s + 1]), 16 | torch.nn.BatchNorm1d(sizes[s + 1]) 17 | if batch_norm and s < len(sizes) - 1 else None, 18 | torch.nn.ReLU() 19 | ] 20 | 21 | layers = [l for l in layers if l is not None][:-1] 22 | self.activation = last_layer_act 23 | self.network = torch.nn.Sequential(*layers) 24 | self.relu = torch.nn.ReLU() 25 | def forward(self, x): 26 | return self.network(x) 27 | 28 | 29 | class GEARS_Model(torch.nn.Module): 30 | """ 31 | GEARS 32 | """ 33 | 34 | def __init__(self, args): 35 | super(GEARS_Model, self).__init__() 36 | self.args = args 37 | self.num_genes = args['num_genes'] 38 | self.num_perts = args['num_perts'] 39 | hidden_size = args['hidden_size'] 40 | 
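# Model configuration read from the args dict below: uncertainty head, number of GO-GNN and co-expression-GNN layers, decoder hidden size, and the no_perturb / cell_fitness_pred flags.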
self.uncertainty = args['uncertainty'] 41 | self.num_layers = args['num_go_gnn_layers'] 42 | self.indv_out_hidden_size = args['decoder_hidden_size'] 43 | self.num_layers_gene_pos = args['num_gene_gnn_layers'] 44 | self.no_perturb = args['no_perturb'] 45 | self.cell_fitness_pred = args['cell_fitness_pred'] 46 | self.pert_emb_lambda = 0.2 47 | 48 | # perturbation positional embedding added only to the perturbed genes 49 | self.pert_w = nn.Linear(1, hidden_size) 50 | 51 | # gene/globel perturbation embedding dictionary lookup 52 | self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True) 53 | self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True) 54 | 55 | # transformation layer 56 | self.emb_trans = nn.ReLU() 57 | self.pert_base_trans = nn.ReLU() 58 | self.transform = nn.ReLU() 59 | self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU') 60 | self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU') 61 | 62 | # gene co-expression GNN 63 | self.G_coexpress = args['G_coexpress'].to(args['device']) 64 | self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device']) 65 | 66 | self.emb_pos = nn.Embedding(self.num_genes, hidden_size, max_norm=True) 67 | self.layers_emb_pos = torch.nn.ModuleList() 68 | for i in range(1, self.num_layers_gene_pos + 1): 69 | self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1)) 70 | 71 | ### perturbation gene ontology GNN 72 | self.G_sim = args['G_go'].to(args['device']) 73 | self.G_sim_weight = args['G_go_weight'].to(args['device']) 74 | 75 | self.sim_layers = torch.nn.ModuleList() 76 | for i in range(1, self.num_layers + 1): 77 | self.sim_layers.append(SGConv(hidden_size, hidden_size, 1)) 78 | 79 | # decoder shared MLP 80 | self.recovery_w = MLP([hidden_size, hidden_size*2, hidden_size], last_layer_act='linear') 81 | 82 | # gene specific decoder 83 | self.indv_w1 = nn.Parameter(torch.rand(self.num_genes, 84 | hidden_size, 1)) 85 | self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1)) 86 | self.act = nn.ReLU() 87 | nn.init.xavier_normal_(self.indv_w1) 88 | nn.init.xavier_normal_(self.indv_b1) 89 | 90 | # Cross gene MLP 91 | self.cross_gene_state = MLP([self.num_genes, hidden_size, 92 | hidden_size]) 93 | # final gene specific decoder 94 | self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes, 95 | hidden_size+1)) 96 | self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes)) 97 | nn.init.xavier_normal_(self.indv_w2) 98 | nn.init.xavier_normal_(self.indv_b2) 99 | 100 | # batchnorms 101 | self.bn_emb = nn.BatchNorm1d(hidden_size) 102 | self.bn_pert_base = nn.BatchNorm1d(hidden_size) 103 | self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size) 104 | 105 | # uncertainty mode 106 | if self.uncertainty: 107 | self.uncertainty_w = MLP([hidden_size, hidden_size*2, hidden_size, 1], last_layer_act='linear') 108 | 109 | #if self.cell_fitness_pred: 110 | self.cell_fitness_mlp = MLP([self.num_genes, hidden_size*2, hidden_size, 1], last_layer_act='linear') 111 | 112 | def forward(self, data): 113 | x, pert_idx = data.x, data.pert_idx 114 | if self.no_perturb: 115 | out = x.reshape(-1,1) 116 | out = torch.split(torch.flatten(out), self.num_genes) 117 | return torch.stack(out) 118 | else: 119 | num_graphs = len(data.batch.unique()) 120 | 121 | ## get base gene embeddings 122 | emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) 123 | emb = self.bn_emb(emb) 124 | base_emb = self.emb_trans(emb) 125 | 126 | 
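# Gene co-expression GNN step: propagate the positional gene embeddings through the SGConv layers over the co-expression graph, then mix the result into the base gene embeddings with a fixed 0.2 weight.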
pos_emb = self.emb_pos(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device'])) 127 | for idx, layer in enumerate(self.layers_emb_pos): 128 | pos_emb = layer(pos_emb, self.G_coexpress, self.G_coexpress_weight) 129 | if idx < len(self.layers_emb_pos) - 1: 130 | pos_emb = pos_emb.relu() 131 | 132 | base_emb = base_emb + 0.2 * pos_emb 133 | base_emb = self.emb_trans_v2(base_emb) 134 | 135 | ## get perturbation index and embeddings 136 | #print(pert_idx) 137 | #pert = pert_idx.reshape(-1,1) 138 | #pert_index = torch.where(pert.reshape(num_graphs, int(pert_idx.shape[0]/num_graphs)) == 1) 139 | 140 | pert_index = [] 141 | for idx, i in enumerate(pert_idx): 142 | for j in i: 143 | if j != -1: 144 | pert_index.append([idx, j]) 145 | pert_index = torch.tensor(pert_index).T 146 | 147 | pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device'])) 148 | 149 | ## augment global perturbation embedding with GNN 150 | for idx, layer in enumerate(self.sim_layers): 151 | pert_global_emb = layer(pert_global_emb, self.G_sim, self.G_sim_weight) 152 | if idx < self.num_layers - 1: 153 | pert_global_emb = pert_global_emb.relu() 154 | 155 | ## add global perturbation embedding to each gene in each cell in the batch 156 | base_emb = base_emb.reshape(num_graphs, self.num_genes, -1) 157 | 158 | if pert_index.shape[0] != 0: 159 | ### in case all samples in the batch are controls, then there is no indexing for pert_index. 160 | pert_track = {} 161 | for i, j in enumerate(pert_index[0]): 162 | if j.item() in pert_track: 163 | pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]] 164 | else: 165 | pert_track[j.item()] = pert_global_emb[pert_index[1][i]] 166 | 167 | if len(list(pert_track.values())) > 0: 168 | if len(list(pert_track.values())) == 1: 169 | # circumvent when batch size = 1 with single perturbation and cannot feed into MLP 170 | emb_total = self.pert_fuse(torch.stack(list(pert_track.values()) * 2)) 171 | else: 172 | emb_total = self.pert_fuse(torch.stack(list(pert_track.values()))) 173 | 174 | for idx, j in enumerate(pert_track.keys()): 175 | base_emb[j] = base_emb[j] + emb_total[idx] 176 | 177 | base_emb = base_emb.reshape(num_graphs * self.num_genes, -1) 178 | 179 | ## add the perturbation positional embedding 180 | #pert_emb = self.pert_w(pert) 181 | #combined = pert_emb+base_emb 182 | #combined = self.bn_pert_base_trans(combined) 183 | #base_emb = self.pert_base_trans(combined) 184 | base_emb = self.bn_pert_base(base_emb) 185 | 186 | ## apply the first MLP 187 | base_emb = self.transform(base_emb) 188 | out = self.recovery_w(base_emb) 189 | out = out.reshape(num_graphs, self.num_genes, -1) 190 | out = out.unsqueeze(-1) * self.indv_w1 191 | w = torch.sum(out, axis = 2) 192 | out = w + self.indv_b1 193 | 194 | # Cross gene 195 | cross_gene_embed = self.cross_gene_state(out.reshape(num_graphs, self.num_genes, -1).squeeze(2)) 196 | cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes) 197 | 198 | cross_gene_embed = cross_gene_embed.reshape([num_graphs,self.num_genes, -1]) 199 | cross_gene_out = torch.cat([out, cross_gene_embed], 2) 200 | 201 | cross_gene_out = cross_gene_out * self.indv_w2 202 | cross_gene_out = torch.sum(cross_gene_out, axis=2) 203 | out = cross_gene_out + self.indv_b2 204 | out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1) 205 | out = torch.split(torch.flatten(out), self.num_genes) 206 | 207 | ## uncertainty head 208 | if self.uncertainty: 209 | out_logvar = 
self.uncertainty_w(base_emb) 210 | out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes) 211 | return torch.stack(out), torch.stack(out_logvar) 212 | 213 | if self.cell_fitness_pred: 214 | return torch.stack(out), self.cell_fitness_mlp(torch.stack(out)) 215 | 216 | return torch.stack(out) 217 | 218 | -------------------------------------------------------------------------------- /gears/version.py: -------------------------------------------------------------------------------- 1 | """GEARS version file 2 | """ 3 | # Based on NiLearn package 4 | # License: simplified BSD 5 | 6 | # PEP0440 compatible formatted version, see: 7 | # https://www.python.org/dev/peps/pep-0440/ 8 | # 9 | # Generic release markers: 10 | # X.Y 11 | # X.Y.Z # For bug fix releases 12 | # 13 | # Admissible pre-release markers: 14 | # X.YaN # Alpha release 15 | # X.YbN # Beta release 16 | # X.YrcN # Release Candidate 17 | # X.Y # Final release 18 | # 19 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 20 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev' 21 | # 22 | __version__ = '0.0.2' # pragma: no cover -------------------------------------------------------------------------------- /legacy/GI_accuracy.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import torch 4 | import scanpy as sc 5 | import pandas as pd 6 | import copy 7 | import sys 8 | import os 9 | from data import PertDataloader 10 | from inference import evaluate, compute_metrics 11 | from inference import GIs 12 | import matplotlib.patches as mpatches 13 | 14 | # Linear model fitting functions 15 | import statsmodels.api as sm 16 | from sklearn.linear_model import LinearRegression, TheilSenRegressor 17 | from dcor import distance_correlation, partial_distance_correlation 18 | from sklearn.metrics import r2_score 19 | 20 | import matplotlib.pyplot as plt 21 | %matplotlib inline 22 | 23 | device = 'cuda:1' 24 | home_dir = '/dfs/user/yhr/perturb_GNN/pertnet/' 25 | sys.path.append(home_dir) 26 | 27 | 28 | ## Read in output from leave one out models 29 | # Plot generation functions 30 | def get_t_p_seen1(metric): 31 | res_seen1 = {} 32 | res_p_seen1 = {} 33 | res_t_seen1 = {} 34 | 35 | # Set up output dictionaries 36 | for GI in GI_names: 37 | res_p_seen1[GI] = [] 38 | res_t_seen1[GI] = [] 39 | 40 | for GI_sel in GI_names: 41 | 42 | # For a given GI what are all the relevant perturbations 43 | all_perts_gi = GIs[GI_sel.upper()] 44 | 45 | # What are all the relevant single gene perturbations 46 | all_perts_gi = [v.split('+') for v in all_perts_gi] 47 | seen1_perts_gi = np.unique([item for sublist in all_perts_gi for item in sublist]) 48 | 49 | # Iterate over all models trained with these single genes held out 50 | for GI in seen1_perts_gi: 51 | for d in dict_names: 52 | if GI in d: 53 | res_seen1[GI] = np.load(d, allow_pickle=True).item() 54 | 55 | # Get all keys for single pert model predictions that are relevant 56 | keys_ = [k for k in res_seen1[GI].keys() if k in GIs[GI_sel.upper()]] 57 | 58 | p_vals = [res_seen1[GI][k]['pred'][metric] for k in keys_] 59 | t_vals = [res_seen1[GI][k]['truth'][metric] for k in keys_] 60 | 61 | res_p_seen1[GI_sel].extend(p_vals) 62 | res_t_seen1[GI_sel].extend(t_vals) 63 | 64 | return res_p_seen1, res_t_seen1 65 | 66 | def get_t_p_seen2(metric): 67 | 68 | # Seen 2 69 | res_p = {} 70 | res_t = {} 71 | 72 | for GI in GI_names: 73 | res_p[GI] = [] 74 | res_t[GI] = [] 75 | 76 | for GI in GI_names: 77 | for d in 
dict_names: 78 | if GI in d: 79 | loaded = list(np.load(d, allow_pickle=True).item().values())[0] 80 | res_p[GI].append(loaded['pred'][metric]) 81 | res_t[GI].append(loaded['truth'][metric]) 82 | 83 | return res_p, res_t 84 | 85 | 86 | ## Compute accuracy 87 | 88 | def synergy_similar_pheno(dict_): 89 | return np.sum(np.array(dict_['mag']['synergy_similar_pheno'])>1)/len(dict_['mag']['synergy_similar_pheno']) 90 | 91 | def synergy_dissimilar_pheno(dict_): 92 | return np.sum(np.array(dict_['mag']['synergy_dissimilar_pheno'])>1)/len(dict_['mag']['synergy_dissimilar_pheno']) 93 | 94 | def potentiation(dict_): 95 | cond1 = np.sum(np.array(dict_['mag']['potentiation'])>1)/len(dict_['mag']['potentiation']) 96 | return cond1 97 | #cond2 = 98 | 99 | # TODO check this condition 100 | def additive(dict_, thresh=0.3): 101 | cond = np.abs(np.array(dict_['mag']['additive'])-1)<=thresh 102 | return np.sum(cond)/len(dict_['mag']['additive']) 103 | 104 | def suppressor(dict_): 105 | return np.sum(np.array(dict_['mag']['suppressor'])<1)/len(dict_['mag']['suppressor']) 106 | 107 | def neomorphic(dict_): 108 | return np.sum(np.array(dict_['corr_fit']['neomorphic'])<0.85)/len(dict_['corr_fit']['neomorphic']) 109 | 110 | def redundant(dict_): 111 | return np.sum(np.array(dict_['dcor']['redundant'])>0.8)/len(dict_['dcor']['redundant']) 112 | 113 | def epistasis(dict_): 114 | return np.sum(np.array(dict_['dominance']['epistasis'])>0.25)/len(dict_['dominance']['epistasis']) 115 | 116 | 117 | res_p_seen1_dict = {} 118 | res_t_seen1_dict = {} 119 | 120 | res_p_seen2_dict = {} 121 | res_t_seen2_dict = {} 122 | 123 | 124 | # Set up data dictionaries 125 | for metric in ['mag', 'corr_fit', 'dcor', 'dominance', 'dcor_singles']: 126 | res_p_seen1_dict[metric], res_t_seen1_dict[metric] = get_t_p_seen1(metric) 127 | res_p_seen2_dict[metric], res_t_seen2_dict[metric] = get_t_p_seen2(metric) 128 | 129 | 130 | accuracy_seen2 = {} 131 | accuracy_seen1 = {} 132 | 133 | accuracy_seen2['synergy_similar_pheno'] = synergy_similar_pheno(res_p_seen2_dict) 134 | accuracy_seen2['synergy_dissimilar_pheno'] = synergy_dissimilar_pheno(res_p_seen2_dict) 135 | accuracy_seen2['potentiation'] = potentiation(res_p_seen2_dict) 136 | accuracy_seen2['additive'] = additive(res_p_seen2_dict) 137 | accuracy_seen2['suppressor'] = suppressor(res_p_seen2_dict) 138 | accuracy_seen2['neomorphic'] = neomorphic(res_p_seen2_dict) 139 | accuracy_seen2['redundant'] = redundant(res_p_seen2_dict) 140 | accuracy_seen2['epistasis'] = epistasis(res_p_seen2_dict) 141 | 142 | accuracy_seen1['synergy_similar_pheno'] = synergy_similar_pheno(res_p_seen1_dict) 143 | accuracy_seen1['synergy_dissimilar_pheno'] = synergy_dissimilar_pheno(res_p_seen1_dict) 144 | accuracy_seen1['potentiation'] = potentiation(res_p_seen1_dict) 145 | accuracy_seen1['additive'] = additive(res_p_seen1_dict) 146 | accuracy_seen1['suppressor'] = suppressor(res_p_seen1_dict) 147 | accuracy_seen1['neomorphic'] = neomorphic(res_p_seen1_dict) 148 | accuracy_seen1['redundant'] = redundant(res_p_seen1_dict) 149 | accuracy_seen1['epistasis'] = epistasis(res_p_seen1_dict) 150 | -------------------------------------------------------------------------------- /legacy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhr91/GEARS_misc/f88211870dfa89c38a2eedbd69ca1abd28a25f3c/legacy/__init__.py -------------------------------------------------------------------------------- /legacy/evaluate.py: 
-------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import argparse 3 | from time import time 4 | import sys 5 | 6 | import scanpy as sc 7 | import numpy as np 8 | 9 | import torch 10 | import torch.optim as optim 11 | import torch.nn as nn 12 | from torch.optim.lr_scheduler import StepLR 13 | 14 | from model import No_Perturb, PertNet 15 | from data import PertDataloader, GeneSimNetwork, GeneCoexpressNetwork 16 | from inference import evaluate, compute_metrics, deeper_analysis, GI_subgroup, non_dropout_analysis, non_zero_analysis 17 | from utils import loss_fct, uncertainty_loss_fct, parse_any_pert, get_coexpression_network_from_train, get_similarity_network 18 | 19 | torch.manual_seed(0) 20 | 21 | import warnings 22 | warnings.filterwarnings("ignore") 23 | 24 | 25 | 26 | def trainer(args): 27 | 28 | wandb_status = args['wandb'] 29 | device = args['device'] 30 | project_name = args['project_name'] 31 | entity_name = args['entity_name'] 32 | exp_name = args['exp_name'] 33 | 34 | model_name = args['model_name'] 35 | args = np.load('./saved_args/'+model_name+'.npy', allow_pickle = True).item() 36 | args['device'] = device 37 | 38 | ## set up wandb 39 | if wandb_status: 40 | import wandb 41 | wandb.init(project=project_name, entity=entity_name, name=exp_name) 42 | wandb.config.update(args) 43 | 44 | if args['dataset'] == 'Norman2019': 45 | data_path = '/dfs/project/perturb-gnn/datasets/Norman2019/Norman2019_hvg+perts_more_de.h5ad' 46 | elif args['dataset'] == 'Adamson2016': 47 | data_path = '/dfs/project/perturb-gnn/datasets/Adamson2016_hvg+perts_more_de_in_genes.h5ad' 48 | elif args['dataset'] == 'Dixit2016': 49 | data_path = '/dfs/project/perturb-gnn/datasets/Dixit2016_hvg+perts_more_de.h5ad' 50 | elif args['dataset'] == 'Norman2019_Adamson2016': 51 | data_path = '/dfs/project/perturb-gnn/datasets/trans_norman_adamson/norman2019.h5ad' 52 | 53 | s = time() 54 | adata = sc.read_h5ad(data_path) 55 | if 'gene_symbols' not in adata.var.columns.values: 56 | adata.var['gene_symbols'] = adata.var['gene_name'] 57 | gene_list = [f for f in adata.var.gene_symbols.values] 58 | 59 | # Pertrubation dataloader 60 | pertdl = PertDataloader(adata, args) 61 | 62 | model = torch.load('./saved_models/' + model_name) 63 | model.args = args 64 | 65 | if 'G_sim' in vars(model): 66 | if isinstance(model.G_sim, dict): 67 | for i,j in model.G_sim.items(): 68 | model.G_sim[i] = j.to(model.args['device']) 69 | 70 | for i,j in model.G_sim_weight.items(): 71 | model.G_sim_weight[i] = j.to(model.args['device']) 72 | else: 73 | model.G_sim = model.G_sim.to(model.args['device']) 74 | model.G_sim_weight = model.G_sim_weight.to(model.args['device']) 75 | 76 | best_model = model 77 | 78 | print('Start testing....') 79 | test_res = evaluate(pertdl.loaders['test_loader'],best_model, args) 80 | 81 | test_metrics, test_pert_res = compute_metrics(test_res) 82 | 83 | if wandb_status: 84 | metrics = ['mse', 'mae', 'spearman', 'pearson', 'r2'] 85 | for m in metrics: 86 | wandb.log({'test_' + m: test_metrics[m], 87 | 'test_de_'+m: test_metrics[m + '_de'] 88 | }) 89 | 90 | out = deeper_analysis(adata, test_res) 91 | out_non_dropout = non_dropout_analysis(adata, test_res) 92 | out_non_zero = non_zero_analysis(adata, test_res) 93 | GI_out = GI_subgroup(out) 94 | GI_out_non_dropout = GI_subgroup(out_non_dropout) 95 | GI_out_non_zero = GI_subgroup(out_non_zero) 96 | 97 | metrics = ['frac_in_range', 'frac_in_range_45_55', 'frac_in_range_40_60', 'frac_in_range_25_75', 
'mean_sigma', 'std_sigma', 'frac_sigma_below_1', 'frac_sigma_below_2', 'pearson_delta', 98 | 'pearson_delta_de', 'fold_change_gap_all', 'pearson_delta_top200_hvg', 'fold_change_gap_upreg_3', 99 | 'fold_change_gap_downreg_0.33', 'fold_change_gap_downreg_0.1', 'fold_change_gap_upreg_10', 100 | 'pearson_top200_hvg', 'pearson_top200_de', 'pearson_top20_de', 'pearson_delta_top200_de', 101 | 'pearson_top100_de', 'pearson_delta_top100_de', 'pearson_delta_top50_de', 'pearson_top50_de', 'pearson_delta_top20_de', 102 | 'mse_top200_hvg', 'mse_top100_de', 'mse_top200_de', 'mse_top50_de', 'mse_top20_de', 'frac_correct_direction_all', 'frac_correct_direction_20', 'frac_correct_direction_50', 'frac_correct_direction_100', 'frac_correct_direction_200', 'frac_correct_direction_20_nonzero'] 103 | 104 | metrics_non_dropout = ['frac_correct_direction_top20_non_dropout', 'frac_opposite_direction_top20_non_dropout', 'frac_0/1_direction_top20_non_dropout', 'frac_correct_direction_non_zero', 'frac_correct_direction_non_dropout', 'frac_in_range_non_dropout', 'frac_in_range_45_55_non_dropout', 'frac_in_range_40_60_non_dropout', 'frac_in_range_25_75_non_dropout', 'mean_sigma_non_dropout', 'std_sigma_non_dropout', 'frac_sigma_below_1_non_dropout', 'frac_sigma_below_2_non_dropout', 'pearson_delta_top20_de_non_dropout', 'pearson_top20_de_non_dropout', 'mse_top20_de_non_dropout', 'frac_opposite_direction_non_dropout', 'frac_0/1_direction_non_dropout', 'frac_opposite_direction_non_zero', 'frac_0/1_direction_non_zero'] 105 | 106 | 107 | metrics_non_zero = ['frac_correct_direction_top20_non_zero', 'frac_opposite_direction_top20_non_zero', 'frac_0/1_direction_top20_non_zero', 'frac_in_range_non_zero', 'frac_in_range_45_55_non_zero', 'frac_in_range_40_60_non_zero', 'frac_in_range_25_75_non_zero', 'mean_sigma_non_zero', 'std_sigma_non_zero', 'frac_sigma_below_1_non_zero', 'frac_sigma_below_2_non_zero', 'pearson_delta_top20_de_non_zero', 'pearson_top20_de_non_zero', 'mse_top20_de_non_zero'] 108 | 109 | if args['wandb']: 110 | for m in metrics: 111 | wandb.log({'test_' + m: np.mean([j[m] for i,j in out.items() if m in j])}) 112 | 113 | for m in metrics_non_dropout: 114 | wandb.log({'test_' + m: np.mean([j[m] for i,j in out_non_dropout.items() if m in j])}) 115 | 116 | for m in metrics_non_zero: 117 | wandb.log({'test_' + m: np.mean([j[m] for i,j in out_non_zero.items() if m in j])}) 118 | 119 | 120 | if args['split'] == 'simulation': 121 | subgroup = pertdl.subgroup 122 | subgroup_analysis = {} 123 | for name in subgroup['test_subgroup'].keys(): 124 | subgroup_analysis[name] = {} 125 | for m in list(list(test_pert_res.values())[0].keys()): 126 | subgroup_analysis[name][m] = [] 127 | 128 | for name, pert_list in subgroup['test_subgroup'].items(): 129 | for pert in pert_list: 130 | for m, res in test_pert_res[pert].items(): 131 | subgroup_analysis[name][m].append(res) 132 | 133 | for name, result in subgroup_analysis.items(): 134 | for m in result.keys(): 135 | subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m]) 136 | if args['wandb']: 137 | wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]}) 138 | 139 | print('test_' + name + '_' + m + ': ' + str(subgroup_analysis[name][m])) 140 | 141 | ## deeper analysis 142 | subgroup_analysis = {} 143 | for name in subgroup['test_subgroup'].keys(): 144 | subgroup_analysis[name] = {} 145 | for m in metrics: 146 | subgroup_analysis[name][m] = [] 147 | 148 | for m in metrics_non_dropout: 149 | subgroup_analysis[name][m] = [] 150 | 151 | for m in metrics_non_zero: 152 | 
subgroup_analysis[name][m] = [] 153 | 154 | for name, pert_list in subgroup['test_subgroup'].items(): 155 | for pert in pert_list: 156 | for m, res in out[pert].items(): 157 | subgroup_analysis[name][m].append(res) 158 | 159 | for m, res in out_non_dropout[pert].items(): 160 | subgroup_analysis[name][m].append(res) 161 | 162 | for m, res in out_non_zero[pert].items(): 163 | subgroup_analysis[name][m].append(res) 164 | 165 | 166 | for name, result in subgroup_analysis.items(): 167 | for m in result.keys(): 168 | subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m]) 169 | if args['wandb']: 170 | wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]}) 171 | 172 | print('test_' + name + '_' + m + ': ' + str(subgroup_analysis[name][m])) 173 | 174 | for i,j in GI_out.items(): 175 | for m in ['mean_sigma', 'frac_in_range_45_55', 'frac_in_range_40_60', 'frac_in_range_25_75', 176 | 'fold_change_gap_all', 'pearson_delta_top200_de', 'pearson_delta_top100_de', 'pearson_delta_top50_de', 177 | 'mse_top200_de', 'mse_top100_de', 'mse_top50_de', 'mse_top20_de', 'pearson_delta_top20_de']: 178 | if args['wandb']: 179 | wandb.log({'test_' + i + '_' + m: j[m]}) 180 | 181 | 182 | for i,j in GI_out_non_dropout.items(): 183 | for m in ['frac_correct_direction_top20_non_dropout', 'mse_top20_de_non_dropout', 'pearson_delta_top20_de_non_dropout', 'frac_in_range_25_75_non_dropout', 'frac_sigma_below_1_non_dropout']: 184 | if args['wandb']: 185 | wandb.log({'test_' + i + '_' + m: j[m]}) 186 | 187 | 188 | for i,j in GI_out_non_zero.items(): 189 | for m in ['frac_correct_direction_top20_non_zero', 'mse_top20_de_non_zero', 'pearson_delta_top20_de_non_zero', 'frac_in_range_25_75_non_zero', 'frac_sigma_below_1_non_zero']: 190 | if args['wandb']: 191 | wandb.log({'test_' + i + '_' + m: j[m]}) 192 | 193 | 194 | print('Done!') 195 | 196 | 197 | def parse_arguments(): 198 | """ 199 | Argument parser 200 | """ 201 | 202 | # dataset arguments 203 | parser = argparse.ArgumentParser(description='Perturbation response') 204 | 205 | # wandb related 206 | parser.add_argument('--wandb', default=False, action='store_true', 207 | help='Use wandb or not') 208 | parser.add_argument('--project_name', type=str, default='pert_gnn', 209 | help='project name') 210 | parser.add_argument('--entity_name', type=str, default='kexinhuang', 211 | help='entity name') 212 | parser.add_argument('--exp_name', type=str, default='N/A', 213 | help='entity name') 214 | 215 | # misc 216 | parser.add_argument('--model_name', type=str, default='pert_gnn') 217 | parser.add_argument('--device', type=str, default='cuda') 218 | 219 | return dict(vars(parser.parse_args())) 220 | 221 | 222 | if __name__ == "__main__": 223 | 224 | #python evaluate.py --project_name pert_gnn_simulation_norman2019 \ 225 | # --exp_name no_perturb \ 226 | # --model_name no_perturb \ 227 | # --device cuda:7 \ 228 | # --wandb 229 | 230 | trainer(parse_arguments()) 231 | -------------------------------------------------------------------------------- /legacy/flow.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import tqdm 4 | from multiprocessing import Pool 5 | from sklearn import linear_model 6 | 7 | # ------------------------- 8 | # Flow method 9 | # ------------------------- 10 | 11 | def get_expression_data(filename, graph_df, p_thresh, TF_only=False, 12 | zero_TFs=False, TFs=None): 13 | # Read in expression data 14 | expression_data = pd.read_csv(filename, index_col=0) 15 | if 
expression_data.shape[1] <= 1: 16 | expression_data = pd.read_csv(filename, index_col=0, delimiter='\t') 17 | 18 | expression_genes = np.unique(np.array(graph_df.iloc[:,:2].values).flatten()) 19 | expression = expression_data[expression_data['gene_name'].isin(expression_genes)] 20 | try: 21 | expression = expression[expression['p_val_adj'] < p_thresh] # Seurat 22 | col = 'avg_log2FC' 23 | except: 24 | expression = expression[expression['adj.P.Val'] < p_thresh] # limma 25 | col = 'logFC' 26 | 27 | if zero_TFs: 28 | missed_TFs = [t for t in TFs if t not in expression[ 29 | 'gene_name'].values] 30 | missed_TFs = pd.DataFrame({'gene_name': missed_TFs, col: 0}) 31 | expression = expression.append(missed_TFs, ignore_index=True) 32 | 33 | return expression 34 | 35 | 36 | def get_expression_lambda(modelname): 37 | try: 38 | df = pd.read_csv('../Data/'+modelname, index_col=0).drop(columns=[ 39 | 'index']) 40 | except: 41 | df = pd.read_csv('../Data/'+modelname, index_col=0) 42 | return df.rename(columns={'avg_log2FC':'logFC'}) 43 | 44 | 45 | def get_graph(name, TF_only=False, top=None): 46 | # Read in TF network 47 | if (TF_only): 48 | df = pd.read_csv('/dfs/user/yhr/cell_reprogram/Data/transcription_networks/TF_only_'+name, 49 | header=None) 50 | elif top is not None: 51 | df = pd.read_csv('/dfs/user/yhr/cell_reprogram/Data/transcription_networks/G_all_edges_top' 52 | +str(top)+'_'+name, header=None) 53 | else: 54 | df = pd.read_csv('/dfs/user/yhr/cell_reprogram/Data/transcription_networks/G_all_edges_'+name, 55 | header=None) 56 | return df 57 | 58 | 59 | def add_weight(G, u, v, weight): 60 | try: 61 | G.remove_edge(u, v) 62 | except: 63 | # If the edge doesn't exist don't add a weighted version 64 | # return 65 | pass 66 | G.add_edge(u, v, weight=weight) 67 | 68 | 69 | # Set diagonal elements to 1 only for TFs 70 | def get_TFs(species): 71 | if species =='mouse': 72 | TFs = pd.read_csv('/dfs/user/yhr/cell_reprogram/Data/TF_names/mouse_tf_gene_names.txt', 73 | delimiter='\t',header=None).iloc[:,0].values 74 | elif species=='human': 75 | TFs = pd.read_csv('/dfs/user/yhr/cell_reprogram/Data/TF_names/TF_names_v_1.01_human.txt', 76 | delimiter='\t',header=None).iloc[:,0].values 77 | return np.unique(TFs, return_counts=True)[0] 78 | 79 | 80 | def I_TF(A, expression, lamb, TFs): 81 | TF_idx = np.where(expression['gene_name'].isin(TFs).values) 82 | res = np.array(A).copy() 83 | 84 | # If A is a matrix 85 | if len(res.shape) > 1: 86 | if res.shape[0] == res.shape[1]: 87 | for i in TF_idx: 88 | res[i, i] += lamb 89 | return res 90 | 91 | # If A is a vetor 92 | for i in TF_idx: 93 | res[i] *= lamb 94 | return res 95 | 96 | 97 | def get_model(A, y, expression, lamb=1): 98 | n_col = A.shape[1] 99 | try: 100 | sol = np.linalg.lstsq(A + I_TF(A, expression, lamb), I_TF(y, expression, 2)) 101 | sol = np.array([float(i) for i in sol[0]]) 102 | except: 103 | sol = np.zeros(len(expression)) 104 | return sol 105 | 106 | def solve_parallel2(A,B, expression, lambdas, threads): 107 | pool = Pool() 108 | for l in lambdas: 109 | print('Lambda: ' + str(l)) 110 | iter = list([(A, B, expression, l)] * threads) 111 | 112 | results = pool.starmap(get_model, iter) 113 | 114 | for j in range(threads): 115 | expression[str(l)+'_'+str(j)] = results[j] 116 | return expression 117 | 118 | def Map(F, x, args, workers): 119 | """ 120 | wrapper for map() 121 | Spawn workers for parallel processing 122 | 123 | """ 124 | iter_ = ((xi, args) for xi in x) 125 | with Pool(workers) as pool: 126 | ret = pool.starmap(F, iter_) 127 | #ret = 
list(tqdm.tqdm(pool.starmap(F, iter_), total=len(x))) 128 | return ret 129 | 130 | 131 | def mapper(l, args): 132 | A, B, positive = args 133 | return solve_lasso(A, B, lamb=l, positive=positive) 134 | 135 | 136 | def solve_parallel(A,B, expression, lambdas, positive=False, 137 | workers=10): 138 | args = (A, B, positive) 139 | exp_df_list = Map(mapper, lambdas, args, workers=workers) 140 | 141 | dict_ = {l: exp for l, exp in zip(lambdas, exp_df_list)} 142 | exp_df =pd.DataFrame(dict_) 143 | for c in expression.columns: 144 | exp_df[c] = expression.reset_index()[c] 145 | 146 | return exp_df 147 | 148 | 149 | def solve_lasso(A,B,lamb, positive): 150 | print('Lambda: '+ str(lamb)) 151 | clf = linear_model.Lasso(alpha=lamb, positive=positive) 152 | clf.fit(A,B) 153 | return clf.coef_ 154 | 155 | def solve(A,B, expression, lambdas, positive=False): 156 | exp_df = expression 157 | for l in lambdas: 158 | exp_df[str(l)] = solve_lasso(A,B, lamb=l, positive=positive) 159 | return exp_df 160 | -------------------------------------------------------------------------------- /legacy/genes_with_hi_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhr91/GEARS_misc/f88211870dfa89c38a2eedbd69ca1abd28a25f3c/legacy/genes_with_hi_mean.npy -------------------------------------------------------------------------------- /legacy/learn_weights.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | import scanpy as sc 5 | import networkx as nx 6 | sys.path.append('../model/') 7 | from flow import get_graph 8 | from sklearn.model_selection import train_test_split 9 | from sklearn import linear_model 10 | from sklearn.neural_network import MLPRegressor 11 | from sklearn.metrics import r2_score 12 | import argparse 13 | 14 | no_model_count = 0 15 | 16 | def nonzero_idx(mat): 17 | mat=pd.DataFrame(mat) 18 | return mat[(mat > 0).sum(1) > 0].index.values 19 | 20 | def data_split(X, y, size=0.1): 21 | nnz = list(set(nonzero_idx(X)).intersection(set(nonzero_idx(y)))) 22 | 23 | if len(nnz) <= 5: 24 | global no_model_count 25 | no_model_count += 1 26 | 27 | return -1,-1 28 | 29 | train_split, val_split = train_test_split(nnz, test_size=size, 30 | random_state=42) 31 | return train_split, val_split 32 | 33 | def train_regressor(X, y, kind='linear'): 34 | 35 | if kind == 'linear' or kind == 'linear_TM': 36 | model = linear_model.LinearRegression() 37 | elif kind == 'lasso': 38 | model = linear_model.Lasso(alpha=10e4) 39 | elif kind == 'elasticnet' or kind == 'elasticnet_TM': 40 | model = linear_model.ElasticNet(alpha=10, l1_ratio=0.5, 41 | max_iter=1000) 42 | elif kind == 'ridge': 43 | model = linear_model.Ridge(alpha=10e4, max_iter=1000) 44 | elif kind == 'MLP' or kind == 'MLP_TM': 45 | model = MLPRegressor(hidden_layer_sizes=(20,10), max_iter=1000) 46 | 47 | reg = model.fit(X, y) 48 | loss = np.sqrt(np.mean((y - model.predict(X))**2)) 49 | return reg, loss, reg.score(X, y) 50 | 51 | 52 | def evaluate_regressor(model, X, y): 53 | y_cap = model.predict(X) 54 | loss = np.sqrt(np.mean((y - y_cap)**2)) 55 | 56 | return loss, y, y_cap 57 | 58 | def init_dict(): 59 | d = {} 60 | d['linear'] = [] 61 | d['linear_TM'] = [] 62 | d['ones'] = [] 63 | d['lasso'] = [] 64 | d['elasticnet'] = [] 65 | d['elasticnet_TM'] = [] 66 | d['ridge'] = [] 67 | d['MLP'] = [] 68 | d['MLP_TM'] = [] 69 | return d 70 | 71 | # Looks at the median of max expression across cells/not genes 
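# Added illustrative note (not part of the original file): max_median_norm divides
# the whole cells x genes frame by the median of the per-gene maxima, e.g. with a
# hypothetical two-cell, two-gene frame
#   max_median_norm(pd.DataFrame([[0., 2.], [4., 6.]]))   # per-gene maxima are [4., 6.],
#                                                         # so every entry is divided by 5.0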
72 | def max_median_norm(df): 73 | return df/df.max().median() 74 | 75 | def get_weights(adj_mat, exp_adata, nodelist, lim=50000): 76 | models = init_dict() 77 | val_loss = init_dict() 78 | train_loss = init_dict() 79 | train_score = init_dict() 80 | preds = init_dict() 81 | trues = init_dict() 82 | adj_list = {} 83 | test_splits = {} 84 | 85 | adj_list['TF'] = []; adj_list['target'] = []; adj_list['importance'] = []; 86 | #X, X_TM = set_up_TM_data(X) 87 | #TM_rows = [c for c in X_TM.index if '_TM' in c] 88 | 89 | adj_mat_idx = np.arange(len(adj_mat)) 90 | np.random.shuffle(adj_mat_idx) 91 | count = 0 92 | 93 | def trainer(kind, feats, y, train_split, val_split): 94 | model, train_loss_, train_score_ = train_regressor( 95 | feats[train_split,:], 96 | y[train_split], kind=kind) 97 | val_loss_, true, pred = evaluate_regressor(model, 98 | feats[val_split, :], 99 | y[val_split]) 100 | 101 | # Store results 102 | val_loss[kind].append(val_loss_) 103 | train_loss[kind].append(train_loss_) 104 | train_score[kind].append(train_score_) 105 | trues[kind].extend(true) 106 | preds[kind].extend(pred) 107 | try: models[kind].append(model.coef_); 108 | except: pass; 109 | 110 | print('Total genes: ', str(len(adj_mat_idx))) 111 | for itr in adj_mat_idx: 112 | i = adj_mat[itr] 113 | if i.sum() > 0: 114 | idx = np.where(i > 0)[1] 115 | TFs = np.array(nodelist)[idx] 116 | target = np.array(nodelist)[itr] 117 | 118 | try: 119 | feats = exp_adata[:, TFs].X.toarray() 120 | y = exp_adata[:, target].X.toarray() 121 | except: 122 | continue 123 | train_split, test_split = data_split(feats, y, size=0.1) 124 | if train_split==-1: continue; 125 | 126 | feats = feats[train_split,:] 127 | train_split, val_split = data_split(feats, y, size=0.1) 128 | if train_split==-1: continue; 129 | 130 | # Add data from TM 131 | #feats_TM = X_TM.loc[:, TFs] 132 | #y_TM = X_TM.loc[:, target] 133 | 134 | # Linear Regression models 135 | trainer('linear', feats, y, train_split, val_split) 136 | #trainer('linear_TM', feats_TM, y_TM, train_split+TM_rows, val_split) 137 | #trainer('ridge', feats, y, train_split, val_split) 138 | #trainer('MLP', feats, y, train_split, val_split) 139 | #trainer('MLP_TM', feats_TM, y_TM, train_split+TM_rows, val_split) 140 | #trainer('elasticnet', feats, y, train_split, val_split) 141 | #trainer('elasticnet_TM', feats_TM, y_TM, train_split+TM_rows, 142 | # val_split) 143 | 144 | # All edges are 1 145 | model = linear_model.LinearRegression() 146 | model.coef_ = np.ones(len(idx)) 147 | model.intercept_ = 0 148 | val_loss_, true, pred = evaluate_regressor(model, feats[ 149 | val_split,:], 150 | y[val_split]) 151 | val_loss['ones'].append(val_loss_) 152 | trues['ones'].extend(true) 153 | preds['ones'].extend(pred) 154 | 155 | # Add row to new weight matrix 156 | for j,k in enumerate(TFs): 157 | adj_list['TF'].append(k) 158 | adj_list['target'].append(target) 159 | adj_list['importance'].append(models['linear'][-1][0][j]) 160 | 161 | # Save the test split for use later 162 | test_splits[target] = test_split 163 | print(count) 164 | count += 1 165 | 166 | if count >= lim: 167 | break 168 | return models, adj_list, test_splits 169 | 170 | 171 | def main(args): 172 | exp_adata = sc.read_h5ad(args.exp_matrix) 173 | G = pd.read_csv(args.graph_name, header=None) 174 | G = nx.from_pandas_edgelist(G, source=0, 175 | target=1, create_using=nx.DiGraph()) 176 | adj_mat = nx.linalg.graphmatrix.adjacency_matrix(G).todense().T 177 | nodelist = [n for n in G.nodes()] 178 | 179 | # Remove self-edges 180 | 
np.fill_diagonal(adj_mat, 0) 181 | 182 | models, adj_list, test_splits = get_weights(adj_mat, exp_adata, nodelist, lim=1000) 183 | 184 | # Save final results 185 | #np.save('train_loss', specs[]) 186 | np.save('test_splits', test_splits) 187 | 188 | if args.out_name is None: 189 | args.out_name = args.graph_name 190 | pd.DataFrame(adj_list).to_csv(args.out_name + '_learntweights.csv') 191 | 192 | # Convert coefficients into new weight matrix 193 | print('Done') 194 | 195 | 196 | if __name__ == '__main__': 197 | parser = argparse.ArgumentParser( 198 | description='Set model hyperparametrs.', 199 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 200 | #torch.cuda.set_device(4) 201 | 202 | parser.add_argument('--exp_matrix', type=str, 203 | help='Expression matrix') 204 | parser.add_argument('--graph_name', type=str, 205 | help='Graph filename') 206 | parser.add_argument('--out_name', type=str, 207 | help='Output filename') 208 | 209 | 210 | parser.set_defaults( 211 | exp_matrix = './temp/Norman2019_split5.h5ad', 212 | graph_name='./temp/Norman2019_split5_pearson.txt', 213 | out_name=None) 214 | 215 | args = parser.parse_args() 216 | main(args) 217 | -------------------------------------------------------------------------------- /legacy/linear_pert_model.py: -------------------------------------------------------------------------------- 1 | from flow import get_graph, get_expression_data,\ 2 | add_weight, get_TFs, solve,\ 3 | solve_parallel, get_expression_lambda 4 | import networkx as nx 5 | import numpy as np 6 | import pandas as pd 7 | 8 | 9 | # Linear model for simulating linear perturbation effects 10 | class linear_model(): 11 | def __init__(self, graph_path, weights_path, gene_list, 12 | binary=False, pos_edges=False, hops=3, 13 | species='human'): 14 | self.TFs = get_TFs(species) 15 | self.gene_list = gene_list 16 | 17 | # Set up graph structure 18 | G_df = get_graph(name = graph_path, TF_only=False) 19 | print('Edges: '+str(len(G_df))) 20 | self.G = nx.from_pandas_edgelist(G_df, source=0, 21 | target=1, create_using=nx.DiGraph()) 22 | 23 | # Add edge weights 24 | self.read_weights = pd.read_csv(weights_path, index_col=0) 25 | try: 26 | self.read_weights = self.read_weights.set_index('TF') 27 | except: 28 | pass 29 | 30 | # Get adjacency matrix 31 | self.adj_mat = self.create_adj_mat() 32 | 33 | A = self.adj_mat.T 34 | if binary and pos_edges: 35 | A = np.array(A != 0).astype('float') 36 | 37 | # Set the diagonal elements to zero everywhere except the TFs 38 | np.fill_diagonal(A, 0) 39 | each_hop = A.copy() 40 | last_hop = A.copy() 41 | for k in range(hops-1): 42 | last_hop = last_hop @ each_hop 43 | if binary: 44 | A += last_hop/(k+2) 45 | else: 46 | A += last_hop 47 | self.A = A 48 | 49 | 50 | def create_adj_mat(self): 51 | # Create a df version of the graph for merging 52 | G_df = pd.DataFrame(self.G.edges(), columns=['TF', 'target']) 53 | 54 | # Merge it with the weights DF 55 | weighted_G_df = self.read_weights.merge(G_df, on=['TF', 'target']) 56 | for w in weighted_G_df.iterrows(): 57 | add_weight(self.G, w[1]['TF'], w[1]['target'], w[1]['importance']) 58 | 59 | # Get an adjacency matrix based on the gene ordering from the DE list 60 | return nx.linalg.graphmatrix.adjacency_matrix( 61 | self.G, nodelist=self.gene_list).todense() 62 | 63 | 64 | def simulate_pert(self, pert_genes, pert_mags=None): 65 | """ 66 | Returns predicted differential expression (delta) upon perturbing 67 | a list of genes 'pert_genes' 68 | """ 69 | 70 | # Create perturbation vector 71 | pert_idx 
= np.where([(g in pert_genes) for g in self.gene_list])[0] 72 | theta = np.zeros([len(self.gene_list),1]) 73 | 74 | # Set up the input vector 75 | if pert_mags is None: 76 | pert_mags = np.ones(len(pert_genes)) 77 | for idx, pert_mag in zip(pert_idx, pert_mags): 78 | theta[pert_idx] = pert_mag 79 | 80 | # Compute differential expression vector 81 | delta = np.dot(self.A, theta) 82 | delta = np.squeeze(np.array(delta)) 83 | 84 | # Add the perturbation magnitude directly for the TF 85 | delta = delta + np.squeeze(theta) 86 | 87 | return delta 88 | -------------------------------------------------------------------------------- /legacy/pertnet.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhr91/GEARS_misc/f88211870dfa89c38a2eedbd69ca1abd28a25f3c/legacy/pertnet.py -------------------------------------------------------------------------------- /legacy/plot_functions.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | def get_de(res, query): 5 | """ 6 | Given a query perturbation and model output file, 7 | return predicted and true post-perturbation expression 8 | """ 9 | query_idx = np.where(res['pert_cat'] == query)[0] 10 | de = {"pred_de": res['pred_de'][query_idx], 11 | "truth_de": res['truth_de'][query_idx]} 12 | return de 13 | 14 | def get_de_ctrl(pert, adata): 15 | """ 16 | Get ctrl expression for DE genes for a given perturbation 17 | """ 18 | mean_ctrl_exp = adata[adata.obs['condition'] == 'ctrl'].to_df().mean() 19 | de_genes = get_covar_genes(pert, adata) 20 | return mean_ctrl_exp[de_genes] 21 | 22 | def get_covar_genes(p, adata): 23 | """ 24 | Get genes that are differentially expressed 25 | """ 26 | gene_name_dict = adata.var.loc[:,'gene_name'].to_dict() 27 | pert_name = 'A549_'+p+'_1+1' 28 | de_genes = adata.uns['rank_genes_groups_cov'][pert_name] 29 | return de_genes 30 | 31 | def create_boxplot(res, adata, query, genes=None): 32 | """ 33 | Create a boxplot showing true, predicted and control expression 34 | for a given perturbation 35 | """ 36 | 37 | plt.figure(figsize=[10,3]) 38 | plt.title(query) 39 | pert_de_res = get_de(res, query)['pred_de'] 40 | truth_de_res = get_de(res, query)['truth_de'] 41 | plt.boxplot(truth_de_res, showfliers=False, 42 | medianprops = dict(linewidth=0)) 43 | ctrl_means = get_de_ctrl(query, adata).values 44 | 45 | for i in range(pert_de_res.shape[1]): 46 | _ = plt.scatter(i+1, np.mean(pert_de_res[:,i]), color='red') 47 | _ = plt.scatter(i+1, ctrl_means[i], color='forestgreen', marker='*') 48 | 49 | ax = plt.gca() 50 | if genes is not None: 51 | ax.xaxis.set_ticklabels(genes) 52 | else: 53 | ax.xaxis.set_ticklabels(['G1','G2','G3','G4','G5','G6','G7','G8','G9', 'G10', 54 | 'G11','G12','G13','G14','G15','G16','G17','G18','G19', 'G20']) 55 | -------------------------------------------------------------------------------- /paper/CPA_reproduce/README.md: -------------------------------------------------------------------------------- 1 | For downloading preprocessed file for Norman 2019 please use the following [link](https://dataverse.harvard.edu/api/access/datafile/6881912) 2 | 3 | - First run cpa.sh (eg: ~$ `bash cpa.sh ./data/Norman2019 2`) The second arugment here is the CUDA device number 4 | - Then run cpa_to_wandb.py to compute metrics 5 | -------------------------------------------------------------------------------- /paper/CPA_reproduce/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | -------------------------------------------------------------------------------- /paper/CPA_reproduce/collect_results.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import json 5 | import pprint 6 | import argparse 7 | import numpy as np 8 | from model_selection import * 9 | 10 | def run_collect_results(save_dir, one_line, metric='onlyDEmeans'): 11 | records = [] 12 | for fname in os.listdir(save_dir): 13 | if fname.endswith(".out"): 14 | full_dir = os.path.join(save_dir, fname) 15 | records_file = [] 16 | with open(full_dir, "r") as f: 17 | for line in f.readlines(): 18 | if line.startswith("{"): 19 | records_file.append(json.loads(line)) 20 | records.append(records_file) 21 | 22 | if len(records) == 0: 23 | return 24 | 25 | best_score = None 26 | best_record = None 27 | best_epoch = None 28 | for r, record in enumerate(records): 29 | for e, epoch in enumerate(record): 30 | if "evaluation_stats" in epoch: 31 | if epoch["evaluation_stats"]["optimal for covariates"] == 1: 32 | epoch["evaluation_stats"]["optimal for covariates"] = 0 33 | 34 | if metric == 'all': 35 | this_score = np.mean(epoch["evaluation_stats"]["test"]) 36 | elif metric == 'onlyDEmeans': 37 | this_score = epoch["evaluation_stats"]["test"][1] 38 | elif metric == 'onlyDE': 39 | this_score = epoch["evaluation_stats"]["test"][1]+\ 40 | epoch["evaluation_stats"]["test"][3] 41 | elif metric == 'woDE': 42 | this_score = (epoch["evaluation_stats"]["test"][0]+\ 43 | epoch["evaluation_stats"]["test"][2])/2 44 | else: 45 | raise NotImplementedError 46 | 47 | this_score -= abs(epoch["evaluation_stats"]["perturbation disentanglement"] -\ 48 | epoch["evaluation_stats"]["optimal for perturbations"])/2 +\ 49 | abs(epoch["evaluation_stats"]["covariate disentanglement"] -\ 50 | epoch["evaluation_stats"]["optimal for covariates"])/2 51 | 52 | if best_score is None or this_score > best_score: 53 | best_score = this_score 54 | best_record = r 55 | best_epoch = e 56 | 57 | best_stats = { 58 | "training_args": records[best_record][0]["training_args"], 59 | "autoencoder_params": records[best_record][1]["autoencoder_params"], 60 | "best_epoch": records[best_record][best_epoch]["epoch"], 61 | "best_stats": records[best_record][best_epoch]["evaluation_stats"] 62 | 63 | } 64 | 65 | best_stats.update({ 66 | "best_file": "{}/model_seed={}_epoch={}.pt".format( 67 | best_stats["training_args"]["save_dir"], 68 | best_stats["training_args"]["seed"], 69 | best_stats["best_epoch"])}) 70 | 71 | if "path" in best_stats["training_args"]: 72 | dataset_key = "path" 73 | else: 74 | dataset_key = "dataset_path" 75 | 76 | if one_line: 77 | print("{:>40}: [{:.3f}, {:.3f}, {:.3f}, {:.3f}] ({:.3f}, {:.3f})".format( 78 | best_stats["training_args"][dataset_key], 79 | *best_stats["best_stats"]["ood"], 80 | np.mean(best_stats["best_stats"]["ood"]), 81 | best_stats["best_stats"]["disentanglement"])) 82 | 83 | else: 84 | pprint.pprint(best_stats, indent=2) 85 | 86 | # get_best_plots(best_stats["best_file"]) 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = argparse.ArgumentParser(description='Collect results.') 91 | parser.add_argument('--save_dir', type=str, required=True) 92 | parser.add_argument('--one_line', action="store_true") 93 | parser.add_argument('--metric', type=str, 
default='onlyDEmeans') 94 | args = parser.parse_args() 95 | run_collect_results(args.save_dir, args.one_line, args.metric) 96 | -------------------------------------------------------------------------------- /paper/CPA_reproduce/cpa.sh: -------------------------------------------------------------------------------- 1 | for seed in 1 2 3 4 5 2 | do 3 | 4 | mkdir "$1_split${seed}" 5 | 6 | python train.py --dataset_path ./data/$1_simulation_cpa.h5ad \ 7 | --dataset $1 \ 8 | --split_key "split${seed}" \ 9 | --save_dir "$1_split${seed}" \ 10 | --cuda $2 \ 11 | #--emb "kg" 12 | done 13 | 14 | 15 | ## example script 16 | # bash cpa.sh jost2020_hvg 5 17 | # bash cpa.sh tian2019_ipsc_hvg 5 18 | # bash cpa.sh replogle2020_hvg 2 19 | # bash cpa.sh replogle_rpe1_gw_filtered_hvg 2 20 | # bash cpa.sh replogle_k562_essential_filtered_hvg 7 21 | 22 | ## to run CPA+KG, simply add a flag --emb "kg" 23 | -------------------------------------------------------------------------------- /paper/CPA_reproduce/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import warnings 4 | import torch 5 | 6 | import numpy as np 7 | 8 | warnings.simplefilter(action='ignore', category=FutureWarning) 9 | import scanpy as sc 10 | import pandas as pd 11 | 12 | from sklearn.preprocessing import OneHotEncoder 13 | 14 | def ranks_to_df(data, key='rank_genes_groups'): 15 | """Converts an `sc.tl.rank_genes_groups` result into a MultiIndex dataframe. 16 | 17 | You can access various levels of the MultiIndex with `df.loc[[category]]`. 18 | 19 | Params 20 | ------ 21 | data : `AnnData` 22 | key : str (default: 'rank_genes_groups') 23 | Field in `.uns` of data where `sc.tl.rank_genes_groups` result is 24 | stored. 
25 | """ 26 | d = data.uns[key] 27 | dfs = [] 28 | for k in d.keys(): 29 | if k == 'params': 30 | continue 31 | series = pd.DataFrame.from_records(d[k]).unstack() 32 | series.name = k 33 | dfs.append(series) 34 | 35 | return pd.concat(dfs, axis=1) 36 | 37 | 38 | class Dataset: 39 | def __init__(self, 40 | fname, 41 | perturbation_key, 42 | dose_key, 43 | cell_type_key, 44 | split_key='split'): 45 | print('---1') 46 | data = sc.read(fname) 47 | print('---2') 48 | self.perturbation_key = perturbation_key 49 | self.dose_key = dose_key 50 | self.cell_type_key = cell_type_key 51 | try: 52 | self.genes = torch.Tensor(data.X.A) 53 | except: 54 | self.genes = torch.Tensor(data.X) 55 | 56 | self.var_names = data.var_names 57 | print('---3') 58 | self.pert_categories = np.array(data.obs['cov_drug_dose_name'].values) 59 | self.de_genes = {i: j[:20] for i, j in data.uns['rank_genes_groups_cov_all'].items()} 60 | #self.de_genes = data.uns['rank_genes_groups_cov'] 61 | #self.de_genes = data.uns['top_non_dropout_de_20'] 62 | 63 | self.ctrl = data.obs['control'].values 64 | self.ctrl_name = list(np.unique(data[data.obs['control'] == 1].obs[self.perturbation_key])) 65 | 66 | self.drugs_names = np.array(data.obs[perturbation_key].values) 67 | self.dose_names = np.array(data.obs[dose_key].values) 68 | print('---4') 69 | # get unique drugs 70 | drugs_names_unique = set() 71 | for d in self.drugs_names: 72 | [drugs_names_unique.add(i) for i in d.split("+")] 73 | self.drugs_names_unique = np.array(list(drugs_names_unique)) 74 | 75 | # save encoder for a comparison with Mo's model 76 | # later we need to remove this part 77 | encoder_drug = OneHotEncoder(sparse=False) 78 | encoder_drug.fit(self.drugs_names_unique.reshape(-1, 1)) 79 | print('---5') 80 | self.atomic_drugs_dict = dict(zip(self.drugs_names_unique, encoder_drug.transform( 81 | self.drugs_names_unique.reshape(-1, 1)))) 82 | 83 | # get drug combinations 84 | drugs = [] 85 | for i, comb in enumerate(self.drugs_names): 86 | drugs_combos = encoder_drug.transform( 87 | np.array(comb.split("+")).reshape(-1, 1)) 88 | dose_combos = str(data.obs[dose_key].values[i]).split("+") 89 | for j, d in enumerate(dose_combos): 90 | if j == 0: 91 | drug_ohe = float(d) * drugs_combos[j] 92 | else: 93 | drug_ohe += float(d) * drugs_combos[j] 94 | drugs.append(drug_ohe) 95 | self.drugs = torch.Tensor(drugs) 96 | print('---6') 97 | self.cell_types_names = np.array(data.obs[cell_type_key].values) 98 | self.cell_types_names_unique = np.unique(self.cell_types_names) 99 | 100 | encoder_ct = OneHotEncoder(sparse=False) 101 | encoder_ct.fit(self.cell_types_names_unique.reshape(-1, 1)) 102 | 103 | self.atomic_сovars_dict = dict(zip(list(self.cell_types_names_unique), encoder_ct.transform( 104 | self.cell_types_names_unique.reshape(-1, 1)))) 105 | 106 | self.cell_types = torch.Tensor(encoder_ct.transform( 107 | self.cell_types_names.reshape(-1, 1))).float() 108 | 109 | self.num_cell_types = len(self.cell_types_names_unique) 110 | self.num_genes = self.genes.shape[1] 111 | self.num_drugs = len(self.drugs_names_unique) 112 | print('---7') 113 | self.indices = { 114 | "all": list(range(len(self.genes))), 115 | "control": np.where(data.obs['control'] == 1)[0].tolist(), 116 | "treated": np.where(data.obs['control'] != 1)[0].tolist(), 117 | "train": np.where(data.obs[split_key] == 'train')[0].tolist(), 118 | "test": np.where(data.obs[split_key] == 'test')[0].tolist(), 119 | "ood": np.where(data.obs[split_key] == 'ood')[0].tolist() 120 | } 121 | 122 | atomic_ohe = encoder_drug.transform( 
123 | self.drugs_names_unique.reshape(-1, 1)) 124 | 125 | self.drug_dict = {} 126 | for idrug, drug in enumerate(self.drugs_names_unique): 127 | i = np.where(atomic_ohe[idrug] == 1)[0][0] 128 | self.drug_dict[i] = drug 129 | 130 | 131 | 132 | def subset(self, split, condition="all"): 133 | idx = list(set(self.indices[split]) & set(self.indices[condition])) 134 | return SubDataset(self, idx) 135 | 136 | def __getitem__(self, i): 137 | return self.genes[i], self.drugs[i], self.cell_types[i] 138 | 139 | def __len__(self): 140 | return len(self.genes) 141 | 142 | 143 | class SubDataset: 144 | """ 145 | Subsets a `Dataset` by selecting the examples given by `indices`. 146 | """ 147 | 148 | def __init__(self, dataset, indices): 149 | self.perturbation_key = dataset.perturbation_key 150 | self.dose_key = dataset.dose_key 151 | self.covars_key = dataset.cell_type_key 152 | 153 | self.perts_dict = dataset.atomic_drugs_dict 154 | self.covars_dict = dataset.atomic_сovars_dict 155 | 156 | self.genes = dataset.genes[indices] 157 | self.drugs = dataset.drugs[indices] 158 | self.cell_types = dataset.cell_types[indices] 159 | 160 | self.drugs_names = dataset.drugs_names[indices] 161 | self.pert_categories = dataset.pert_categories[indices] 162 | self.cell_types_names = dataset.cell_types_names[indices] 163 | 164 | self.var_names = dataset.var_names 165 | self.de_genes = dataset.de_genes 166 | self.ctrl_name = dataset.ctrl_name[0] 167 | 168 | self.num_cell_types = dataset.num_cell_types 169 | self.num_genes = dataset.num_genes 170 | self.num_drugs = dataset.num_drugs 171 | 172 | def __getitem__(self, i): 173 | return self.genes[i], self.drugs[i], self.cell_types[i] 174 | 175 | def __len__(self): 176 | return len(self.genes) 177 | 178 | 179 | def load_dataset_splits( 180 | dataset_path, 181 | perturbation_key, 182 | dose_key, 183 | cell_type_key, 184 | split_key, 185 | return_dataset=False): 186 | print('--1') 187 | dataset = Dataset(dataset_path, 188 | perturbation_key, 189 | dose_key, 190 | cell_type_key, 191 | split_key) 192 | print('--2') 193 | splits = { 194 | "training": dataset.subset("train", "all"), 195 | "training_control": dataset.subset("train", "control"), 196 | "training_treated": dataset.subset("train", "treated"), 197 | "test": dataset.subset("test", "all"), 198 | "test_control": dataset.subset("test", "control"), 199 | "test_treated": dataset.subset("test", "treated"), 200 | "ood": dataset.subset("ood", "all") 201 | } 202 | 203 | if return_dataset: 204 | return splits, dataset 205 | else: 206 | return splits 207 | -------------------------------------------------------------------------------- /paper/CPA_reproduce/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import scanpy as sc 4 | import pandas as pd 5 | 6 | def rank_genes_groups_by_cov( 7 | adata, 8 | groupby, 9 | control_group, 10 | covariate, 11 | pool_doses=False, 12 | n_genes=50, 13 | rankby_abs=True, 14 | key_added='rank_genes_groups_cov', 15 | return_dict=False, 16 | ): 17 | 18 | """ 19 | Function that generates a list of differentially expressed genes computed 20 | separately for each covariate category, and using the respective control 21 | cells as reference. 
22 | 23 | Usage example: 24 | 25 | rank_genes_groups_by_cov( 26 | adata, 27 | groupby='cov_product_dose', 28 | covariate_key='cell_type', 29 | control_group='Vehicle_0' 30 | ) 31 | 32 | Parameters 33 | ---------- 34 | adata : AnnData 35 | AnnData dataset 36 | groupby : str 37 | Obs column that defines the groups, should be 38 | cartesian product of covariate_perturbation_cont_var, 39 | it is important that this format is followed. 40 | control_group : str 41 | String that defines the control group in the groupby obs 42 | covariate : str 43 | Obs column that defines the main covariate by which we 44 | want to separate DEG computation (eg. cell type, species, etc.) 45 | n_genes : int (default: 50) 46 | Number of DEGs to include in the lists 47 | rankby_abs : bool (default: True) 48 | If True, rank genes by absolute values of the score, thus including 49 | top downregulated genes in the top N genes. If False, the ranking will 50 | have only upregulated genes at the top. 51 | key_added : str (default: 'rank_genes_groups_cov') 52 | Key used when adding the dictionary to adata.uns 53 | return_dict : str (default: False) 54 | Signals whether to return the dictionary or not 55 | 56 | Returns 57 | ------- 58 | Adds the DEG dictionary to adata.uns 59 | 60 | If return_dict is True returns: 61 | gene_dict : dict 62 | Dictionary where groups are stored as keys, and the list of DEGs 63 | are the corresponding values 64 | 65 | """ 66 | 67 | gene_dict = {} 68 | cov_categories = adata.obs[covariate].unique() 69 | for cov_cat in cov_categories: 70 | print(cov_cat) 71 | #name of the control group in the groupby obs column 72 | control_group_cov = '_'.join([cov_cat, control_group]) 73 | 74 | #subset adata to cells belonging to a covariate category 75 | adata_cov = adata[adata.obs[covariate]==cov_cat] 76 | 77 | #compute DEGs 78 | sc.tl.rank_genes_groups( 79 | adata_cov, 80 | groupby=groupby, 81 | reference=control_group_cov, 82 | rankby_abs=rankby_abs, 83 | n_genes=n_genes 84 | ) 85 | 86 | #add entries to dictionary of gene sets 87 | de_genes = pd.DataFrame(adata_cov.uns['rank_genes_groups']['names']) 88 | for group in de_genes: 89 | gene_dict[group] = de_genes[group].tolist() 90 | 91 | adata.uns[key_added] = gene_dict 92 | 93 | if return_dict: 94 | return gene_dict 95 | 96 | -------------------------------------------------------------------------------- /paper/CellOracle/CellOracle.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import numpy as np\n", 14 | "import pandas as pd\n", 15 | "import scanpy as sc\n", 16 | "import seaborn as sns\n", 17 | "\n", 18 | "import celloracle as co" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "adata = sc.read_h5ad('/dfs/project/perturb-gnn/datasets/Norman2019/Norman2019_hvg+perts.h5ad')\n", 28 | "#ctrl_adata = adata[adata.obs['condition']=='ctrl']\n", 29 | "#ctrl_adata.var = ctrl_adata.var.set_index('gene_name')" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "TF_names = pd.read_csv('TF_names_v_1.01.txt', delimiter='\\t', header=None)\n", 39 | "TF_names = TF_names.rename(columns={0:'Gene'})\n", 40 | "\n", 41 | "all_conds = 
[c.split('+') for c in adata.obs['condition'].values ]\n", 42 | "all_conds = [item for sublist in all_conds for item in sublist]\n", 43 | "all_conds = set(all_conds)\n", 44 | "\n", 45 | "# treat all perturbations as TFs\n", 46 | "# aug_TF_names = list(TF_names['Gene'].values) + list(all_conds)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "#sc.pp.subsample(adata, n_obs=500)\n", 56 | "#sc.pp.pca(adata)\n", 57 | "adata.var = adata.var.set_index('gene_name')\n", 58 | "adata.obs['label']='0'" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "5045 genes were found in the adata. Note that Celloracle is intended to use around 1000-3000 genes, so the behavior with this number of genes may differ from what is expected.\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "oracle = co.Oracle()\n", 76 | "oracle.import_anndata_as_raw_count(adata=adata,\n", 77 | " cluster_column_name='condition',\n", 78 | " embedding_name='X_pca')\n", 79 | "\n", 80 | "oracle.perform_PCA()\n", 81 | "\n", 82 | "n_comps = np.where(np.diff(np.diff(np.cumsum(oracle.pca.explained_variance_ratio_))>0.002))[0][0]\n", 83 | "n_cell = oracle.adata.shape[0]\n", 84 | "k = int(0.025*n_cell)\n", 85 | "\n", 86 | "oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k*8,\n", 87 | " b_maxl=k*4, n_jobs=4)\n", 88 | "\n", 89 | "base_GRN = co.data.load_human_promoter_base_GRN()\n", 90 | "\n", 91 | "# You can load TF info dataframe with the following code.\n", 92 | "oracle.import_TF_data(TF_info_matrix=base_GRN)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "oracle.fit_GRN_for_simulation(GRN_unit='whole')" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 5, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "# Save cell oracle object\n", 111 | "#oracle.to_hdf5(\"Norman19.celloracle.oracle\")\n", 112 | "\n", 113 | "oracle = co.load_hdf5(\"Norman19.celloracle.oracle\")" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 6, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "Oracle object\n", 125 | "\n", 126 | "Meta data\n", 127 | " celloracle version used for instantiation: 0.10.12\n", 128 | " n_cells: 91205\n", 129 | " n_genes: 5045\n", 130 | " cluster_name: condition\n", 131 | " dimensional_reduction_name: X_pca\n", 132 | " n_target_genes_in_TFdict: 27150 genes\n", 133 | " n_regulatory_in_TFdict: 1094 genes\n", 134 | " n_regulatory_in_both_TFdict_and_scRNA-seq: 181 genes\n", 135 | " n_target_genes_both_TFdict_and_scRNA-seq: 3436 genes\n", 136 | " k_for_knn_imputation: 2280\n", 137 | "Status\n", 138 | " Gene expression matrix: Ready\n", 139 | " BaseGRN: Ready\n", 140 | " PCA calculation: Done\n", 141 | " Knn imputation: Done\n", 142 | " GRN calculation for simulation: Done" 143 | ] 144 | }, 145 | "execution_count": 6, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "oracle" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 7, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "def check_pert(x, pertable_genes):\n", 161 | " x1, x2 = x.split('+')\n", 162 | " if x1 not in pertable_genes and x1 != 
'ctrl':\n", 163 | " return False\n", 164 | " if x2 not in pertable_genes and x2 != 'ctrl':\n", 165 | " return False\n", 166 | " \n", 167 | " else:\n", 168 | " return True\n", 169 | " \n", 170 | "def get_pert_value(g):\n", 171 | " \n", 172 | " if g+'+ctrl' in adata.obs['condition']:\n", 173 | " pert_value = adata[adata.obs['condition'] == g+'+ctrl'][:,g].X.mean()\n", 174 | " \n", 175 | " else:\n", 176 | " pert_value = adata[adata.obs['condition'] == 'ctrl+'+g][:,g].X.mean()\n", 177 | " \n", 178 | " return pert_value\n", 179 | "\n", 180 | "def get_pert_effect(pert):\n", 181 | " \n", 182 | " g1,g2 = pert.split('+')\n", 183 | " pert_conditions = {}\n", 184 | " \n", 185 | " if g1 != 'ctrl':\n", 186 | " pert_value_g1 = get_pert_value(g1)\n", 187 | " if pert_value_g1 <0:\n", 188 | " pert_value_g1 = 1\n", 189 | " pert_conditions.update({g1:pert_value_g1})\n", 190 | " \n", 191 | " if g2 != 'ctrl':\n", 192 | " pert_value_g2 = get_pert_value(g2)\n", 193 | " if pert_value_g2 <0:\n", 194 | " pert_value_g2 = 1\n", 195 | " pert_conditions.update({g2:pert_value_g2})\n", 196 | " \n", 197 | " ctrl_idxs = np.where(oracle.adata.obs['condition']=='ctrl')[0]\n", 198 | " oracle.simulate_shift(perturb_condition=pert_conditions,\n", 199 | " ignore_warning=True,\n", 200 | " n_propagation=3)\n", 201 | " \n", 202 | " perturbed_expression = oracle.adata.layers['simulated_count'][ctrl_idxs,:]\n", 203 | " perturbed_expression = perturbed_expression.mean(0)\n", 204 | " \n", 205 | " #_ = [oracle.adata.layers.pop(k) for k in ['simulation_input', 'simulated_count']]\n", 206 | " \n", 207 | " return perturbed_expression" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "MEIS1+ctrl\n", 220 | "KLF1+FOXA1\n", 221 | "TBX3+TBX2\n", 222 | "CEBPE+KLF1\n", 223 | "ZNF318+FOXL2\n", 224 | "Failed: ZNF318+FOXL2\n", 225 | "JUN+CEBPA\n" 226 | ] 227 | }, 228 | { 229 | "name": "stderr", 230 | "output_type": "stream", 231 | "text": [ 232 | "/dfs/user/yhr/scenic_env/lib/python3.7/site-packages/scipy/sparse/base.py:581: RuntimeWarning: divide by zero encountered in true_divide\n", 233 | " return self.astype(np.float_)._mul_scalar(1./other)\n" 234 | ] 235 | }, 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "ctrl+MEIS1\n", 241 | "ETS2+CEBPE\n", 242 | "POU3F2+FOXL2\n", 243 | "AHR+KLF1\n", 244 | "CEBPB+CEBPA\n", 245 | "FOXL2+MEIS1\n", 246 | "FOXL2+ctrl\n", 247 | "FOSB+CEBPE\n", 248 | "FOSB+CEBPB\n", 249 | "FOXA3+HOXB9\n", 250 | "OSR2+ctrl\n", 251 | "ctrl+SPI1\n", 252 | "CEBPB+ctrl\n", 253 | "CEBPB+OSR2\n", 254 | "FEV+ISL2\n", 255 | "JUN+ctrl\n", 256 | "FOXA1+HOXB9\n", 257 | "ZBTB10+ctrl\n", 258 | "Failed: ZBTB10+ctrl\n", 259 | "CEBPE+SPI1\n", 260 | "FOXA1+FOXL2\n", 261 | "FOXF1+FOXL2\n", 262 | "LYL1+CEBPB\n", 263 | "ctrl+CEBPB\n", 264 | "PRDM1+ctrl\n", 265 | "FOSB+OSR2\n", 266 | "FOXL2+HOXB9\n", 267 | "ctrl+OSR2\n", 268 | "JUN+CEBPB\n", 269 | "ZBTB10+SNAI1\n", 270 | "Failed: ZBTB10+SNAI1\n", 271 | "ctrl+FOXL2\n", 272 | "CEBPE+CEBPB\n", 273 | "FOXA3+FOXL2\n", 274 | "SPI1+ctrl\n", 275 | "EGR1+ctrl\n", 276 | "ZBTB10+DLX2\n", 277 | "Failed: ZBTB10+DLX2\n", 278 | "SNAI1+DLX2\n", 279 | "ctrl+FOXA1\n", 280 | "FOXA3+FOXA1\n" 281 | ] 282 | }, 283 | { 284 | "name": "stderr", 285 | "output_type": "stream", 286 | "text": [ 287 | "/dfs/user/yhr/scenic_env/lib/python3.7/site-packages/scipy/sparse/base.py:581: RuntimeWarning: divide by zero encountered in 
true_divide\n", 288 | " return self.astype(np.float_)._mul_scalar(1./other)\n" 289 | ] 290 | }, 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "ctrl+ETS2\n", 296 | "KLF1+FOXA1\n", 297 | "HES7+ctrl\n", 298 | "ZNF318+FOXL2\n", 299 | "Failed: ZNF318+FOXL2\n", 300 | "FOXO4+ctrl\n", 301 | "JUN+CEBPA\n" 302 | ] 303 | } 304 | ], 305 | "source": [ 306 | "for split_num in range(1,6):\n", 307 | " split_file = '/dfs/project/perturb-gnn/datasets/data/norman/splits/norman_simulation_'+str(split_num)+'_0.75.pkl'\n", 308 | " split_perts = pd.read_pickle(split_file)\n", 309 | " test_perts = split_perts['test']\n", 310 | "\n", 311 | " ctrl_adata = adata[adata.obs['condition']=='ctrl']\n", 312 | " ctrl_mean = ctrl_adata.X.toarray().mean(0)\n", 313 | "\n", 314 | " unique_perts = set(np.hstack([x.split('+') for x in adata.obs['condition'].values]))\n", 315 | " pertable_genes = [x for x in unique_perts if x in TF_names.iloc[:,0].values]\n", 316 | " pertable_test_perts = [p for p in test_perts if check_pert(p, pertable_genes)]\n", 317 | "\n", 318 | " perturbed_expression = {}\n", 319 | "\n", 320 | " for pert in pertable_test_perts:\n", 321 | " ## Retry with repeated reloading\n", 322 | " oracle._clear_simulation_results()\n", 323 | " if pert not in perturbed_expression:\n", 324 | " print(pert)\n", 325 | " try:\n", 326 | " perturbed_expression[pert] = get_pert_effect(pert)\n", 327 | " except:\n", 328 | " print('Failed: '+pert)\n", 329 | " \n", 330 | " #np.save('CellOracle_preds_pert_exp_split_'+str(split_num), perturbed_expression)\n", 331 | " np.save('CellOracle_preds_pert_exp_split_retry_'+str(split_num), perturbed_expression)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "scenic_env", 345 | "language": "python", 346 | "name": "scenic_env" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.7.4" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 2 363 | } 364 | -------------------------------------------------------------------------------- /paper/Fig4_UMAP_predict.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../gears_misc/') 3 | from gears import PertData, GEARS 4 | 5 | import numpy as np 6 | from multiprocessing import Pool 7 | import tqdm 8 | import scanpy as sc 9 | 10 | data_name = 'norman_umi' 11 | model_name = 'gears_misc_umi_no_test' 12 | pert_data = PertData('/dfs/project/perturb-gnn/datasets/data') 13 | pert_data.load(data_path = '/dfs/project/perturb-gnn/datasets/data/'+data_name) 14 | pert_data.prepare_split(split = 'no_test', seed = 1) 15 | pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 16 | 17 | gears_model = GEARS(pert_data, device = 'cuda:7', 18 | weight_bias_track = False, 19 | proj_name = 'gears', 20 | exp_name = model_name) 21 | gears_model.load_pretrained('./model_ckpt/'+model_name) 22 | 23 | 24 | ## ---- GI Predictions 25 | 26 | def get_reverse(perts): 27 | return [t.split('+')[-1]+'+'+t.split('+')[0] for t in perts] 28 | 29 | def remove_reverse(perts): 30 | return list(set(perts).difference(set(get_reverse(perts)))) 31 | 32 | 
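# Added illustrative example (not part of the original script), using hypothetical
# inputs, for the combo-name helpers in this file:
#   get_reverse(['KLF1+FOXA1'])                        -> ['FOXA1+KLF1']
#   remove_duplicates_list([['A', 'B'], ['A', 'B']])   -> [['A', 'B']]
# remove_duplicates_list (defined next) sorts the list of gene pairs and keeps a
# single copy of each pair.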
def remove_duplicates_list(list_): 33 | import itertools 34 | list_.sort() 35 | return list(k for k,_ in itertools.groupby(list_)) 36 | 37 | # ---- SEEN 2 38 | 39 | norman_adata = sc.read_h5ad('/dfs/project/perturb-gnn/datasets/data/'+data_name+'/perturb_processed.h5ad') 40 | 41 | genes_of_interest = set([c.strip('+ctrl') for c in norman_adata.obs['condition'] 42 | if ('ctrl+' in c) or ('+ctrl' in c)]) 43 | genes_of_interest = [g for g in genes_of_interest if g in list(pert_data.pert_names)] 44 | 45 | 46 | all_possible_combos = [] 47 | 48 | for g1 in genes_of_interest: 49 | for g2 in genes_of_interest: 50 | if g1==g2: 51 | continue 52 | all_possible_combos.append(sorted([g1,g2])) 53 | 54 | all_possible_combos = remove_duplicates_list(all_possible_combos) 55 | 56 | ## First run inference on all combos using GPU 57 | 58 | # Predict all singles 59 | for c in genes_of_interest: 60 | print('Single prediction: ',c) 61 | gears_model.predict([[c]]) 62 | 63 | # Predict all combos 64 | for it, c in enumerate(all_possible_combos): 65 | print('Combo prediction: ',it) 66 | gears_model.predict([c]) 67 | 68 | # Then use a CPU-based model for computing GI scores parallely 69 | np.save(model_name+'_all_preds', gears_model.saved_pred) 70 | gears_model_cpu = GEARS(pert_data, device = 'cpu') 71 | gears_model_cpu.saved_pred = gears_model.saved_pred 72 | 73 | def Map(F, x, workers): 74 | """ 75 | wrapper for map() 76 | Spawn workers for parallel processing 77 | 78 | """ 79 | with Pool(workers) as pool: 80 | ret = list(tqdm.tqdm(pool.imap(F, x), total=len(x))) 81 | return ret 82 | 83 | def mapper(c): 84 | return gears_model_cpu.GI_predict(c) 85 | 86 | all_GIs = Map(mapper, all_possible_combos, workers=10) 87 | 88 | # Construct final dictionary and save 89 | all_GIs = {str(key):val for key, val in zip(all_possible_combos, all_GIs)} 90 | np.save(model_name+'_allGI', all_GIs) 91 | 92 | # If computing uncertainty 93 | # np.save(model_name+'_alluncs', gears_model.saved_logvar_sum) 94 | -------------------------------------------------------------------------------- /paper/Fig4_UMAP_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../gears_misc/') 3 | 4 | from gears import PertData, GEARS 5 | 6 | pert_data = PertData('/dfs/project/perturb-gnn/datasets/data') 7 | pert_data.load(data_path = '/dfs/project/perturb-gnn/datasets/data/norman_umi') 8 | pert_data.prepare_split(split = 'no_test', seed = 1) 9 | pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 10 | 11 | 12 | gears_model = GEARS(pert_data, device = 'cuda:6', 13 | weight_bias_track = True, 14 | proj_name = 'gears', 15 | exp_name = 'gears_misc_umi_no_test') 16 | gears_model.model_initialize(hidden_size = 64, 17 | uncertainty=False) 18 | 19 | gears_model.train(epochs = 20, lr = 1e-3) 20 | 21 | gears_model.save_model('gears_misc_umi_no_test') 22 | -------------------------------------------------------------------------------- /paper/GRN/GRN_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import networkx as nx 6 | import math 7 | import pdb 8 | import pandas as pd 9 | 10 | import sys 11 | from flow import get_graph, get_expression_data,\ 12 | add_weight, get_TFs, solve,\ 13 | solve_parallel, get_expression_lambda 14 | # Linear model for simulating linear perturbation effects 15 | class linear_model(): 16 | def __init__(self, 
graph_path, weights_path, gene_list, 17 | binary=False, pos_edges=False, hops=3, 18 | species='human'): 19 | self.TFs = get_TFs(species) 20 | self.gene_list = gene_list 21 | 22 | # Set up graph structure 23 | G_df = get_graph(name = graph_path, TF_only=False) 24 | print('Edges: '+str(len(G_df))) 25 | self.G = nx.from_pandas_edgelist(G_df, source=0, 26 | target=1, create_using=nx.DiGraph()) 27 | 28 | for n in self.gene_list: 29 | if n not in self.G.nodes(): 30 | self.G.add_node(n) 31 | 32 | # Add edge weights 33 | self.read_weights = pd.read_csv(weights_path, index_col=0) 34 | try: 35 | self.read_weights = self.read_weights.set_index('TF') 36 | except: 37 | pass 38 | 39 | # Get adjacency matrix 40 | self.adj_mat = self.create_adj_mat() 41 | 42 | A = self.adj_mat.T 43 | if binary and pos_edges: 44 | A = np.array(A != 0).astype('float') 45 | 46 | # Set the diagonal elements to zero everywhere except the TFs 47 | np.fill_diagonal(A, 0) 48 | each_hop = A.copy() 49 | last_hop = A.copy() 50 | for k in range(hops-1): 51 | last_hop = last_hop @ each_hop 52 | if binary: 53 | A += last_hop/(k+2) 54 | else: 55 | A += last_hop 56 | self.A = A 57 | 58 | 59 | def create_adj_mat(self): 60 | # Create a df version of the graph for merging 61 | G_df = pd.DataFrame(self.G.edges(), columns=['TF', 'target']) 62 | 63 | # Merge it with the weights DF 64 | weighted_G_df = self.read_weights.merge(G_df, on=['TF', 'target']) 65 | for w in weighted_G_df.iterrows(): 66 | add_weight(self.G, w[1]['TF'], w[1]['target'], w[1]['importance']) 67 | 68 | # Get an adjacency matrix based on the gene ordering from the DE list 69 | return nx.linalg.graphmatrix.adjacency_matrix( 70 | self.G, nodelist=self.gene_list).todense() 71 | 72 | 73 | def simulate_pert(self, pert_genes, pert_mags=None): 74 | """ 75 | Returns predicted differential expression (delta) upon perturbing 76 | a list of genes 'pert_genes' 77 | """ 78 | 79 | # Create perturbation vector 80 | pert_idx = np.where([(g in pert_genes) for g in self.gene_list])[0] 81 | theta = np.zeros([len(self.gene_list),1]) 82 | 83 | # Set up the input vector 84 | if pert_mags is None: 85 | pert_mags = np.ones(len(pert_genes)) 86 | for idx, pert_mag in zip(pert_idx, pert_mags): 87 | theta[pert_idx] = pert_mag 88 | 89 | # Compute differential expression vector 90 | delta = np.dot(self.A, theta) 91 | delta = np.squeeze(np.array(delta)) 92 | 93 | # Add the perturbation magnitude directly for the TF 94 | delta = delta + np.squeeze(theta) 95 | 96 | return delta 97 | 98 | class No_Perturb(torch.nn.Module): 99 | """ 100 | No Perturbation 101 | """ 102 | 103 | def __init__(self): 104 | super(No_Perturb, self).__init__() 105 | 106 | def forward(self, data): 107 | 108 | x = data.x 109 | x = x[:, 0].reshape(*data.y.shape) 110 | 111 | return x, None -------------------------------------------------------------------------------- /paper/GRN/README.md: -------------------------------------------------------------------------------- 1 | Procedure for generating GRN to be used in GRN baseline 2 | 3 | - First compute the SCENIC graph using `SCENIC_norman.ipynb` 4 | - Then Filter the graph using `graph_filtering.ipynb` 5 | - Finally run `python learn_weights.py` 6 | 7 | Procedure for running GRN baselines 8 | 9 | - Run `python run_GRN_baseline.py` 10 | -------------------------------------------------------------------------------- /paper/GRN/graph_filtering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | 
"metadata": {}, 6 | "source": [ 7 | "### Target filtering procedure" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pandas as pd\n", 17 | "import glob\n", 18 | "\n", 19 | "import scanpy as sc\n", 20 | "import numpy as np" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# Top 50 targets for each TF (targets with highest weight)\n", 30 | "def get_topk(df, tf, k):\n", 31 | " return df[df['TF'] == tf].sort_values('importance', ascending=False)[:k]\n", 32 | "\n", 33 | "def filter_topk(grnboost_out, k=50):\n", 34 | " \n", 35 | " tfs = grnboost_out['TF'].unique()\n", 36 | " tf_dfs = []\n", 37 | " for tf in tfs:\n", 38 | " tf_dfs.append(get_topk(grnboost_out, tf, k=k))\n", 39 | " \n", 40 | " return pd.concat(tf_dfs)\n", 41 | "\n", 42 | "# Targets with importance > the 95th percentile\n", 43 | "def get_pc(grnboost_out, pc=95):\n", 44 | " return grnboost_out.sort_values('importance', ascending=False)[:int(len(grnboost_out)*(1-0.01*pc))]\n", 45 | "\n", 46 | "# Get filtered adjacency lists\n", 47 | "def get_filtered_adj_list(grnboost_out):\n", 48 | " filters = {}\n", 49 | " filters['top50'] = filter_topk(grnboost_out, k=50)\n", 50 | " filters['95pc'] = get_pc(grnboost_out, pc=95)\n", 51 | "\n", 52 | " return filters" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# Generate filtered adjacency files for GRNboost graph\n", 62 | "\n", 63 | "#names = ['norman']\n", 64 | "#names = ['tian2019_neuron_hvg', 'tian2019_ipsc_hvg', 'jost2020_hvg', 'replogle2020_hvg']\n", 65 | "names = ['adamson']\n", 66 | "\n", 67 | "for name in names:\n", 68 | " for split in range(5,6):\n", 69 | " # Read GRNboost output\n", 70 | " grnboost_out = pd.read_csv('./adjacencies_'+name+'_'+str(split)+'_grnboost.csv', index_col =0)\n", 71 | " filtered = get_filtered_adj_list(grnboost_out)\n", 72 | "\n", 73 | " # Save filtered graphs\n", 74 | " filtered['top50'].to_csv('/dfs/project/perturb-gnn/graphs/linear/grnboost/'+name+'_'+str(split)+'_top50.csv', \n", 75 | " index=False, header=False)\n", 76 | " filtered['95pc'].to_csv('/dfs/project/perturb-gnn/graphs/linear/grnboost/'+name+'_'+str(split)+'_95pc.csv',\n", 77 | " index=False, header=False)" 78 | ] 79 | } 80 | ], 81 | "metadata": { 82 | "kernelspec": { 83 | "display_name": "deepsnap", 84 | "language": "python", 85 | "name": "deepsnap" 86 | }, 87 | "language_info": { 88 | "codemirror_mode": { 89 | "name": "ipython", 90 | "version": 3 91 | }, 92 | "file_extension": ".py", 93 | "mimetype": "text/x-python", 94 | "name": "python", 95 | "nbconvert_exporter": "python", 96 | "pygments_lexer": "ipython3", 97 | "version": "3.7.4" 98 | } 99 | }, 100 | "nbformat": 4, 101 | "nbformat_minor": 2 102 | } 103 | -------------------------------------------------------------------------------- /paper/GRN/learn_weights.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | import scanpy as sc 5 | import networkx as nx 6 | sys.path.append('../model/') 7 | from flow import get_graph 8 | from sklearn.model_selection import train_test_split 9 | from sklearn import linear_model 10 | from sklearn.neural_network import MLPRegressor 11 | from sklearn.metrics import r2_score 12 | import glob 13 | import argparse 14 | import pdb 15 | 16 | 
no_model_count = 0 17 | 18 | def nonzero_idx(mat): 19 | mat=pd.DataFrame(mat) 20 | return mat[(mat > 0).sum(1) > 0].index.values 21 | 22 | def get_split_adata(adata, split_dir, split_id): 23 | split_files = [f for f in glob.glob(split_dir + '/*') if 'subgroup' not in f] 24 | split_fname = [f for f in split_files if 'simulation_'+str(split_id) in f][0] 25 | split_dict = pd.read_pickle(split_fname) 26 | 27 | return adata[adata.obs['condition'].isin(split_dict['train'])] 28 | 29 | def data_split(X, y, size=0.1): 30 | nnz = list(set(nonzero_idx(X)).intersection(set(nonzero_idx(y)))) 31 | 32 | if len(nnz) <= 2: 33 | global no_model_count 34 | no_model_count += 1 35 | 36 | return -1,-1 37 | 38 | train_split, val_split = train_test_split(nnz, test_size=size, 39 | random_state=42) 40 | return train_split, val_split 41 | 42 | def train_regressor(X, y, kind, alpha=0): 43 | 44 | if kind == 'linear': 45 | model = linear_model.LinearRegression() 46 | elif kind == 'lasso': 47 | model = linear_model.Lasso(alpha=alpha) 48 | elif kind == 'elasticnet': 49 | model = linear_model.ElasticNet(alpha=alpha, l1_ratio=0.5, 50 | max_iter=1000) 51 | elif kind == 'ridge': 52 | model = linear_model.Ridge(alpha=alpha, max_iter=1000) 53 | elif kind == 'MLP': 54 | model = MLPRegressor(hidden_layer_sizes=(20,10), max_iter=1000) 55 | 56 | reg = model.fit(X, y) 57 | loss = np.sqrt(np.mean((y - model.predict(X))**2)) 58 | return reg, loss, reg.score(X, y) 59 | 60 | 61 | def evaluate_regressor(model, X, y): 62 | y_cap = model.predict(X) 63 | loss = np.sqrt(np.mean((y - y_cap)**2)) 64 | 65 | return loss, y, y_cap 66 | 67 | def init_dict(): 68 | d = {} 69 | d['linear'] = [] 70 | d['lasso'] = [] 71 | d['ridge'] = [] 72 | d['MLP'] = [] 73 | return d 74 | 75 | # Looks at the median of max expression across cells/not genes 76 | def max_median_norm(df): 77 | return df/df.max().median() 78 | 79 | def get_weights(adj_mat, exp_adata, nodelist, method='linear', lim=50000): 80 | models = init_dict() 81 | adj_list = {} 82 | 83 | adj_list['TF'] = []; adj_list['target'] = []; adj_list['importance'] = []; 84 | 85 | adj_mat_idx = np.arange(len(adj_mat)) 86 | np.random.shuffle(adj_mat_idx) 87 | count = 0 88 | 89 | 90 | def trainer(kind, feats, y): 91 | model, _, _ = train_regressor( 92 | feats, y, kind=kind) 93 | 94 | # Store results 95 | try: models[kind].append(model.coef_); 96 | except: pass; 97 | 98 | 99 | def trainer_split(kind, feats, y, train_split, val_split): 100 | models_ = [] 101 | val_losses_ = [] 102 | 103 | for alpha in [1e-6, 1e-4, 1e-2, 1e-1]: 104 | model, _, _ = train_regressor( 105 | feats[train_split,:], 106 | y[train_split], kind=kind, 107 | alpha=alpha) 108 | val_loss, _, _ = evaluate_regressor(model, 109 | feats[val_split, :], 110 | y[val_split]) 111 | 112 | models_.append(model) 113 | val_losses_.append(val_loss) 114 | 115 | best_model = models_[np.argmin(val_losses_)] 116 | 117 | # Store results 118 | try: models[kind].append(best_model.coef_); 119 | except: pass; 120 | 121 | 122 | print('T genes: ', str(len(adj_mat_idx))) 123 | for itr in adj_mat_idx: 124 | i = adj_mat[itr] 125 | if i.sum() > 0: 126 | idx = np.where(i > 0)[1] 127 | TFs = np.array(nodelist)[idx] 128 | target = np.array(nodelist)[itr] 129 | 130 | feats = exp_adata[:, TFs].X.toarray() 131 | y = exp_adata[:, target].X.toarray() 132 | 133 | if method=='linear': 134 | trainer('linear', feats, y) 135 | else: 136 | train_split, val_split = data_split(feats, y, size=0.1) 137 | if train_split==-1: continue; 138 | trainer_split(method, feats, y, train_split, 
val_split) 139 | 140 | # Add row to new weight matrix 141 | for j,k in enumerate(TFs): 142 | adj_list['TF'].append(k) 143 | adj_list['target'].append(target) 144 | try: 145 | adj_list['importance'].append(models[method][-1][0][j]) 146 | except: 147 | adj_list['importance'].append(models[method][-1][j]) 148 | 149 | print(count) 150 | count += 1 151 | 152 | if count >= lim: 153 | break 154 | return models, adj_list 155 | 156 | 157 | def main(args): 158 | try: 159 | split_id = int(args.graph_name.split('_')[-2]) 160 | except: 161 | split_id = int(args.graph_name.split('_')[-3]) 162 | 163 | adata = sc.read_h5ad(args.split_dir.split('splits')[0]+'perturb_processed.h5ad') 164 | 165 | # Remove genes with duplicaet names 166 | genes_to_keep = adata.var.drop_duplicates('gene_name').index 167 | adata = adata[:, genes_to_keep] 168 | 169 | exp_adata = get_split_adata(adata, args.split_dir, split_id) 170 | exp_adata.var = exp_adata.var.set_index('gene_name', drop=False) 171 | 172 | G = pd.read_csv(args.graph_name, header=None) 173 | G = nx.from_pandas_edgelist(G, source=0, 174 | target=1, create_using=nx.DiGraph()) 175 | adj_mat = nx.linalg.graphmatrix.adjacency_matrix(G).todense().T 176 | nodelist = [n for n in G.nodes()] 177 | 178 | # Remove self-edges 179 | np.fill_diagonal(adj_mat, 0) 180 | models, adj_list = get_weights(adj_mat, exp_adata, 181 | nodelist, method=args.method, lim=20000) 182 | 183 | out_name = args.graph_name.split('/')[-1].split('.')[0] 184 | pd.DataFrame(adj_list).to_csv(args.out_dir + out_name + 185 | '_' + args.method + '_learntweights.csv') 186 | 187 | # Convert coefficients into new weight matrix 188 | print('Done') 189 | 190 | 191 | if __name__ == '__main__': 192 | parser = argparse.ArgumentParser( 193 | description='Set model hyperparametrs.', 194 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 195 | #torch.cuda.set_device(4) 196 | 197 | parser.add_argument('--split_dir', type=str, 198 | help='Directory for splits') 199 | parser.add_argument('--graph_name', type=str, 200 | help='Graph filename') 201 | parser.add_argument('--out_dir', type=str, 202 | help='Output filename') 203 | parser.add_argument('--method', type=str, 204 | help='Regression method') 205 | 206 | 207 | parser.set_defaults( 208 | split_dir ='../data/norman2019/splits', 209 | graph_name='norman2019_1_top50.csv', 210 | method='linear', 211 | out_dir='./') 212 | 213 | args = parser.parse_args() 214 | main(args) 215 | -------------------------------------------------------------------------------- /paper/GRN/run_GRN_baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import scanpy as sc 4 | import pickle 5 | from GRN_model import linear_model 6 | from utils import parse_any_pert 7 | from tqdm import tqdm 8 | from data import PertDataloader 9 | from pertdata import PertData 10 | import torch 11 | 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | import argparse 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--seed', type=int, default=1) 17 | parser.add_argument('--dataset_name', type=str, default='norman', choices = ['norman', 'adamson2016', 'dixit2016', 'jost2020_hvg', 'tian2021_crispri_hvg', 'tian2021_crispra_hvg', 'replogle2020_hvg', 'replogle_rpe1_gw_hvg', 'replogle_k562_gw_hvg', 'replogle_k562_essential_hvg', 'tian2019_neuron_hvg', 'tian2019_ipsc_hvg', 'replogle_rpe1_gw_filtered_hvg', 'replogle_k562_essential_filtered_hvg']) 18 | parser.add_argument('--graph_type', type=str, 
default='grnboost', choices = ['coexpression', 'grnboost', 'go']) 19 | 20 | args = parser.parse_args() 21 | 22 | data_path = '/dfs/project/perturb-gnn/datasets/data/' 23 | split_num = args.seed 24 | graph_type = args.graph_type 25 | filter_ = 'top50' 26 | method = 'linear' 27 | dataset = args.dataset_name 28 | 29 | if dataset in ['tian2019_neuron_hvg', 'tian2019_ipsc_hvg', 'jost2020_hvg']: 30 | dataset += '_small_graph' 31 | 32 | if dataset == 'Norman2019': 33 | data_path = '/dfs/project/perturb-gnn/datasets/Norman2019/Norman2019_hvg+perts_more_de.h5ad' 34 | elif dataset == 'replogle_k562_essential_filtered_hvg': 35 | data_path = '/dfs/project/perturb-gnn/datasets/data/replogle_k562_essential_filtered_hvg/perturb_processed_linear.h5ad' 36 | else: 37 | data_path = '/dfs/project/perturb-gnn/datasets/data/' + dataset + '/perturb_processed.h5ad' 38 | 39 | adata = sc.read(data_path) 40 | if 'gene_symbols' not in adata.var.columns.values: 41 | adata.var['gene_symbols'] = adata.var['gene_name'] 42 | 43 | if dataset == 'Norman2019': 44 | split_path = '../data/Norman2019/splits/Norman2019_simulation_' + str(split_num) + '_0.1.pkl' 45 | else: 46 | split_path = '/dfs/project/perturb-gnn/datasets/data/' + dataset + '/splits/' + dataset + '_simulation_' + str(split_num) + '_0.75.pkl' 47 | 48 | split = pickle.load(open(split_path, 'rb')) 49 | 50 | condition2set = {} 51 | 52 | for i,j in split.items(): 53 | for k in j: 54 | condition2set[k] = i 55 | 56 | #adata.obs['split_status'] = [condition2set[i] for i in adata.obs.condition.values] 57 | #adata_train = adata[adata.obs['split_status'] == 'train'] 58 | 59 | if dataset == 'norman': 60 | 61 | if graph_type == 'go': 62 | graph_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/go_essential_norman_filter_' + str(split_num) + '_linear.csv' 63 | weights_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/go_essential_norman_filter_' + str(split_num) + '_linear_learntweights.csv' 64 | else: 65 | graph_path = '/dfs/project/perturb-gnn/graphs/linear/' + graph_type + '/norman_' + str(split_num) + '_' + str(filter_) + '.csv' 66 | weights_path = '/dfs/project/perturb-gnn/graphs/linear/' + graph_type + '/norman_' + str(split_num) + '_' + str(filter_) + '_' + method + '_learntweights.csv' 67 | 68 | 69 | elif dataset == 'replogle_rpe1_gw_filtered_hvg': 70 | graph_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/replogle_rpe1_gw_filtered_hvg_' + str(split_num) + '_subsample_top50.csv' 71 | weights_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/replogle_rpe1_gw_filtered_hvg_' + str(split_num) + '_subsample_top50_linear_learntweights.csv' 72 | elif dataset == 'replogle_k562_essential_filtered_hvg': 73 | graph_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/replogle_k562_essential_filtered_hvg_' + str(split_num) + '_subsample_top50.csv' 74 | weights_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/replogle_k562_essential_filtered_hvg_' + str(split_num) + '_subsample_top50_linear_learntweights.csv' 75 | else: 76 | graph_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/' + args.dataset_name + '_' + str(split_num) + '_' + str(filter_) + '.csv' 77 | weights_path = '/dfs/project/perturb-gnn/graphs/linear/grnboost/' + args.dataset_name + '_' + str(split_num) + '_' + str(filter_) + '_' + method + '_learntweights.csv' 78 | 79 | 80 | gene_list = adata.var.gene_name.values 81 | 82 | model = linear_model(graph_path=graph_path, 83 | weights_path=weights_path, 84 | gene_list = gene_list, 85 | binary=False, 86 | pos_edges=False, 87 | 
hops=1, 88 | species='human') 89 | 90 | #args = np.load('./saved_args/pertnet_uncertainty_ori.npy', allow_pickle = True).item() 91 | 92 | #args['dataset'] = dataset 93 | #args['seed'] = split_num 94 | #args['test_pert_genes'] = 'N/A' 95 | 96 | #pertdl = PertDataloader(adata, args) 97 | 98 | 99 | if args.dataset_name == 'tian2019_neuron_hvg': 100 | gene_path = '/dfs/user/kexinh/gears2/data/essential_all_data_pert_genes_tian2019_neuron.pkl' 101 | elif args.dataset_name == 'tian2019_ipsc_hvg': 102 | gene_path = '/dfs/user/kexinh/gears2/data/essential_all_data_pert_genes_tian2019_ipsc.pkl' 103 | elif args.dataset_name == 'jost2020_hvg': 104 | gene_path = '/dfs/user/kexinh/gears2/data/essential_all_data_pert_genes_jost2020.pkl' 105 | else: 106 | gene_path = None 107 | 108 | pert_data = PertData('/dfs/project/perturb-gnn/datasets/data', gene_path = gene_path) # specific saved folder 109 | pert_data.load(data_path = '/dfs/project/perturb-gnn/datasets/data/' + dataset) 110 | 111 | pert_data.prepare_split(split = 'simulation', seed = split_num) 112 | pert_data.get_dataloader(batch_size = 128, test_batch_size = 128) 113 | 114 | pred_delta = {pert: model.simulate_pert(parse_any_pert(pert)) for pert in split['test']} 115 | adata_ctrl = adata[adata.obs.condition == 'ctrl'] 116 | 117 | 118 | pert_cat = [] 119 | pred = [] 120 | truth = [] 121 | pred_de = [] 122 | truth_de = [] 123 | results = {} 124 | 125 | for batch in tqdm(pert_data.dataloader['test_loader']): 126 | 127 | pert_cat.extend(batch.pert) 128 | p = np.array([pred_delta[i]+adata_ctrl.X[np.random.randint(0, adata_ctrl.shape[0])].toarray().reshape(-1,) for i in batch.pert]) 129 | t = batch.y 130 | 131 | pred.extend(p) 132 | truth.extend(t.cpu()) 133 | 134 | # Differentially expressed genes 135 | for itr, de_idx in enumerate(batch.de_idx): 136 | pred_de.append(p[itr, de_idx]) 137 | truth_de.append(t[itr, de_idx]) 138 | 139 | # all genes 140 | results['pert_cat'] = np.array(pert_cat) 141 | 142 | pred = np.stack(pred) 143 | truth = torch.stack(truth) 144 | results['pred']= pred 145 | results['truth']= truth.detach().cpu().numpy() 146 | 147 | pred_de = np.stack(pred_de) 148 | truth_de = torch.stack(truth_de) 149 | results['pred_de']= pred_de 150 | results['truth_de']= truth_de.detach().cpu().numpy() 151 | 152 | 153 | from inference_new import evaluate, compute_metrics, deeper_analysis, GI_subgroup, non_dropout_analysis 154 | 155 | test_metrics, test_pert_res = compute_metrics(results) 156 | test_res = results 157 | 158 | import wandb 159 | 160 | wandb.init(project='linear_model', name= '_'.join([args.dataset_name, str(split_num)])) 161 | 162 | args = {} 163 | args['wandb'] = True 164 | 165 | out = deeper_analysis(adata, test_res) 166 | out_non_dropout = non_dropout_analysis(adata, test_res) 167 | 168 | metrics = ['pearson_delta'] 169 | metrics_non_dropout = ['frac_opposite_direction_top20_non_dropout', 'frac_sigma_below_1_non_dropout', 'mse_top20_de_non_dropout'] 170 | 171 | if args['wandb']: 172 | for m in metrics: 173 | wandb.log({'test_' + m: np.mean([j[m] for i,j in out.items() if m in j])}) 174 | 175 | for m in metrics_non_dropout: 176 | wandb.log({'test_' + m: np.mean([j[m] for i,j in out_non_dropout.items() if m in j])}) 177 | 178 | if dataset == 'Norman2019': 179 | subgroup_path = './splits/' + dataset_name + '_simulation_' + str(split_num) + '_0.1_subgroup.pkl' 180 | else: 181 | subgroup_path = '/dfs/project/perturb-gnn/datasets/data/' + dataset + '/splits/'+ dataset + '_simulation_' + str(split_num) + '_0.75_subgroup.pkl' 182 | 183 | 
subgroup = pickle.load(open(subgroup_path, "rb")) 184 | 185 | subgroup_analysis = {} 186 | for name in subgroup['test_subgroup'].keys(): 187 | subgroup_analysis[name] = {} 188 | for m in list(list(test_pert_res.values())[0].keys()): 189 | subgroup_analysis[name][m] = [] 190 | 191 | for name, pert_list in subgroup['test_subgroup'].items(): 192 | for pert in pert_list: 193 | for m, res in test_pert_res[pert].items(): 194 | subgroup_analysis[name][m].append(res) 195 | 196 | for name, result in subgroup_analysis.items(): 197 | for m in result.keys(): 198 | subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m]) 199 | wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]}) 200 | 201 | ## deeper analysis 202 | subgroup_analysis = {} 203 | for name in subgroup['test_subgroup'].keys(): 204 | subgroup_analysis[name] = {} 205 | for m in metrics: 206 | subgroup_analysis[name][m] = [] 207 | 208 | for m in metrics_non_dropout: 209 | subgroup_analysis[name][m] = [] 210 | 211 | for name, pert_list in subgroup['test_subgroup'].items(): 212 | for pert in pert_list: 213 | for m in metrics: 214 | subgroup_analysis[name][m].append(out[pert][m]) 215 | 216 | for m in metrics_non_dropout: 217 | subgroup_analysis[name][m].append(out_non_dropout[pert][m]) 218 | 219 | for name, result in subgroup_analysis.items(): 220 | for m in result.keys(): 221 | subgroup_analysis[name][m] = np.mean(subgroup_analysis[name][m]) 222 | wandb.log({'test_' + name + '_' + m: subgroup_analysis[name][m]}) 223 | 224 | -------------------------------------------------------------------------------- /paper/README.md: -------------------------------------------------------------------------------- 1 | To recreate Fig 2: 2 | 3 | - Run this [notebook](https://github.com/yhr91/GEARS_misc/blob/main/paper/reproduce_preprint_results.ipynb) 4 | 5 | To recreate Fig 4: 6 | - First train the model using this [script](https://github.com/yhr91/GEARS_misc/blob/main/paper/Fig4_UMAP_train.py) 7 | - Then run inference for all combinations using this [script](https://github.com/yhr91/GEARS_misc/blob/main/paper/Fig4_UMAP_predict.py) 8 | - After that you can produce the UMAP using this [notebook](https://github.com/yhr91/GEARS_misc/blob/main/paper/Fig4.ipynb) 9 | 10 | The code here will not install GEARS from the [main repository](https://github.com/snap-stanford/GEARS). 
It will use the local path to GEARS in this repository `../gears` 11 | 12 | For other baselines: 13 | - CPA: See `CPA_reproduce` 14 | - GRN: See `GRN` 15 | - CellOracle: See `CellOracle` 16 | 17 | Please raise an issue or email yhr@cs.stanford.edu in case of any problems/questions 18 | -------------------------------------------------------------------------------- /paper/data/GI_data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhr91/GEARS_misc/f88211870dfa89c38a2eedbd69ca1abd28a25f3c/paper/data/GI_data.pkl -------------------------------------------------------------------------------- /paper/fig2_train.py: -------------------------------------------------------------------------------- 1 | ## Tested with GEARS v 0.1.1 2 | 3 | import argparse 4 | import sys 5 | 6 | from gears import PertData, GEARS 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--seed', type=int, default=1) 9 | parser.add_argument('--device', type=int, default=3) 10 | parser.add_argument('--dataset', type=str, default='norman', choices = ['norman', 'adamson', 'dixit', 11 | 'replogle_k562_essential', 12 | 'replogle_rpe1_essential']) 13 | parser.add_argument('--model', type=str, default='gears', choices = ['gears', 'no_perturb']) 14 | parser.add_argument('--batch_size', type=int, default=32) 15 | 16 | args = parser.parse_args() 17 | seed = args.seed 18 | 19 | if args.model == 'no_perturb': 20 | epoch = 0 21 | no_perturb = True 22 | else: 23 | epoch = 15 24 | no_perturb = False 25 | 26 | ## Set this to local dataloader directory 27 | data_path = './data/' 28 | 29 | if args.dataset in ['norman', 'adamson', 'dixit']: 30 | pert_data = PertData('./data', default_pert_graph=False) 31 | else: 32 | pert_data = PertData('./data', default_pert_graph=True) 33 | 34 | pert_data.load(data_name = args.dataset) 35 | pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed 36 | pert_data.get_dataloader(batch_size = args.batch_size, test_batch_size = args.batch_size) 37 | 38 | gears_model = GEARS(pert_data, device = 'cuda:' + str(args.device), 39 | weight_bias_track = True, 40 | proj_name = args.dataset, 41 | exp_name = str(args.model) + '_seed' + str(seed)) 42 | 43 | gears_model.model_initialize(hidden_size = 64, no_perturb = no_perturb) 44 | gears_model.train(epochs = epoch) 45 | if args.model != 'no_perturb': 46 | gears_model.save_model('./model_ckpt/' + args.dataset + '_' + args.model + '_run' + str(seed)) 47 | -------------------------------------------------------------------------------- /paper/predicting_GIs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../gears_misc/') 3 | from gears import PertData, GEARS 4 | 5 | import numpy as np 6 | from multiprocessing import Pool 7 | import tqdm 8 | import scanpy as sc 9 | 10 | data_name = 'norman_umi_only' 11 | model_name = 'gears_misc_umi_only_no_test' 12 | pert_data = PertData('/dfs/project/perturb-gnn/datasets/data') 13 | pert_data.load(data_path = '/dfs/project/perturb-gnn/datasets/data/'+data_name) 14 | pert_data.prepare_split(split = 'no_test', seed = 1) 15 | pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) 16 | 17 | gears_model = GEARS(pert_data, device = 'cuda:8', 18 | weight_bias_track = False, 19 | proj_name = 'gears', 20 | exp_name = model_name) 21 | gears_model.load_pretrained('./model_ckpt/'+model_name) 22 | 23 | 24 | ## ---- GI Predictions 25 | 26 | def 
get_reverse(perts): 27 | return [t.split('+')[-1]+'+'+t.split('+')[0] for t in perts] 28 | 29 | def remove_reverse(perts): 30 | return list(set(perts).difference(set(get_reverse(perts)))) 31 | 32 | def remove_duplicates_list(list_): 33 | import itertools 34 | list_.sort() 35 | return list(k for k,_ in itertools.groupby(list_)) 36 | 37 | # ---- SEEN 2 38 | 39 | norman_adata = sc.read_h5ad('/dfs/project/perturb-gnn/datasets/data/'+data_name+'/perturb_processed.h5ad') 40 | 41 | genes_of_interest = set([c.strip('+ctrl') for c in norman_adata.obs['condition'] 42 | if ('ctrl+' in c) or ('+ctrl' in c)]) 43 | genes_of_interest = [g for g in genes_of_interest if g in list(pert_data.pert_names)] 44 | 45 | 46 | all_possible_combos = [] 47 | 48 | for g1 in genes_of_interest: 49 | for g2 in genes_of_interest: 50 | if g1==g2: 51 | continue 52 | all_possible_combos.append(sorted([g1,g2])) 53 | 54 | all_possible_combos = remove_duplicates_list(all_possible_combos) 55 | 56 | ## First run inference on all combos using GPU 57 | 58 | # Predict all singles 59 | for c in genes_of_interest: 60 | print('Single prediction: ',c) 61 | gears_model.predict([[c]]) 62 | 63 | # Predict all combos 64 | for it, c in enumerate(all_possible_combos): 65 | print('Combo prediction: ',it) 66 | gears_model.predict([c]) 67 | 68 | # Then use a CPU-based model for computing GI scores parallely 69 | np.save(model_name+'_all_preds', gears_model.saved_pred) 70 | gears_model_cpu = GEARS(pert_data, device = 'cpu') 71 | gears_model_cpu.saved_pred = gears_model.saved_pred 72 | 73 | def Map(F, x, workers): 74 | """ 75 | wrapper for map() 76 | Spawn workers for parallel processing 77 | 78 | """ 79 | with Pool(workers) as pool: 80 | ret = list(tqdm.tqdm(pool.imap(F, x), total=len(x))) 81 | return ret 82 | 83 | def mapper(c): 84 | return gears_model_cpu.GI_predict(c) 85 | 86 | all_GIs = Map(mapper, all_possible_combos, workers=10) 87 | 88 | # Construct final dictionary and save 89 | all_GIs = {str(key):val for key, val in zip(all_possible_combos, all_GIs)} 90 | np.save(model_name+'_allGI', all_GIs) 91 | np.save(model_name+'_alluncs', gears_model.saved_logvar_sum) 92 | -------------------------------------------------------------------------------- /paper/reproduce_preprint_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Reproduce results from the [preprint](https://www.biorxiv.org/content/10.1101/2022.07.12.499735v2)\n", 8 | "\n", 9 | "For Dixit, Adamson and Norman dataset\n", 10 | "\n", 11 | "Note: You might need to restart notebook each time you initialize and train a new model or no perturb baseline. 
It might be more efficient to move the training code to separate scripts" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "## Note this uses the pip version of GEARS (v0.0.4) - not the one in this repository\n", 21 | "import gears\n", 22 | "from gears import PertData, GEARS" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "seed = 1" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## Dataset: Dixit " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | "Found local copy...\n", 51 | "Downloading...\n", 52 | "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 168M/168M [00:12<00:00, 13.1MiB/s]\n", 53 | "Extracting zip file...\n", 54 | "Done!\n", 55 | "These perturbations are not in the GO graph and their perturbation can thus not be predicted\n", 56 | "[]\n", 57 | "Creating pyg object for each cell in the data...\n", 58 | "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:59<00:00, 2.97s/it]\n", 59 | "Saving new dataset pyg object at ./data/dixit/data_pyg/cell_graphs.pkl\n", 60 | "Done!\n", 61 | "Creating new splits....\n", 62 | "Saving new splits at ./data/dixit/splits/dixit_simulation_1_0.75.pkl\n", 63 | "Simulation split test composition:\n", 64 | "combo_seen0:0\n", 65 | "combo_seen1:0\n", 66 | "combo_seen2:0\n", 67 | "unseen_single:5\n", 68 | "Done!\n", 69 | "Creating dataloaders....\n", 70 | "Done!\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "dataset_name = 'dixit'\n", 76 | "\n", 77 | "pert_data = PertData('./data', default_pert_graph=False)\n", 78 | "pert_data.load(data_name = dataset_name)\n", 79 | "pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed\n", 80 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "gears_model = GEARS(pert_data, device = 'cuda:2', \n", 90 | " weight_bias_track = False, \n", 91 | " proj_name = dataset_name, \n", 92 | " exp_name = 'gears_seed' + str(seed))\n", 93 | "\n", 94 | "gears_model.model_initialize(hidden_size = 64)\n", 95 | "\n", 96 | "gears_model.train(epochs=15)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### Dixit: No perturb condition" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "dataset_name = 'dixit'\n", 113 | "\n", 114 | "pert_data = PertData('./data', default_pert_graph=False)\n", 115 | "pert_data.load(data_name = dataset_name)\n", 116 | "pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed\n", 117 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "gears_model = 
GEARS(pert_data, device = 'cuda:2', \n", 127 | " weight_bias_track = False, \n", 128 | " proj_name = dataset_name, \n", 129 | " exp_name = 'no_pert_seed' + str(seed))\n", 130 | "\n", 131 | "gears_model.model_initialize(hidden_size = 64,\n", 132 | " no_perturb=True)\n", 133 | "\n", 134 | "gears_model.train(epochs=0)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "## Dataset: Adamson " 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 51, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "dataset_name = 'adamson'\n", 151 | "\n", 152 | "pert_data = PertData('./data', default_pert_graph=False)\n", 153 | "pert_data.load(data_name = dataset_name)\n", 154 | "pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed\n", 155 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 46, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stderr", 165 | "output_type": "stream", 166 | "text": [ 167 | "Found local copy...\n", 168 | "Downloading...\n", 169 | "100%|███████████████████████████████████████| 559k/559k [00:00<00:00, 1.88MiB/s]\n", 170 | "Downloading...\n", 171 | "100%|███████████████████████████████████████| 141M/141M [00:10<00:00, 13.7MiB/s]\n", 172 | "Extracting zip file...\n", 173 | "Done!\n", 174 | "These perturbations are not in the GO graph and their perturbation can thus not be predicted\n", 175 | "['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']\n", 176 | "Creating pyg object for each cell in the data...\n", 177 | "100%|███████████████████████████████████████████| 82/82 [01:26<00:00, 1.05s/it]\n", 178 | "Saving new dataset pyg object at ./data2/adamson/data_pyg/cell_graphs.pkl\n", 179 | "Done!\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "gears_model = GEARS(pert_data, device = 'cuda:2', \n", 185 | " weight_bias_track = False, \n", 186 | " proj_name = dataset_name, \n", 187 | " exp_name = 'gears_seed' + str(seed))\n", 188 | "\n", 189 | "gears_model.model_initialize(hidden_size = 64)\n", 190 | "\n", 191 | "gears_model.train(epochs=1)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### Adamson: No perturb condition" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "dataset_name = 'adamson'\n", 208 | "\n", 209 | "pert_data = PertData('./data', default_pert_graph=False)\n", 210 | "pert_data.load(data_name = dataset_name)\n", 211 | "pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed\n", 212 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "gears_model = GEARS(pert_data, device = 'cuda:2', \n", 222 | " weight_bias_track = False, \n", 223 | " proj_name = dataset_name, \n", 224 | " exp_name = 'no_pert_seed' + str(seed))\n", 225 | "\n", 226 | "gears_model.model_initialize(hidden_size = 64,\n", 227 | " no_perturb=True)\n", 228 | "\n", 229 | "gears_model.train(epochs=0)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [] 238 | }, 239 | { 240 | 
"cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "### Dataset: Norman " 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "dataset_name = 'norman'\n", 253 | "\n", 254 | "pert_data = PertData('./data', default_pert_graph=False)\n", 255 | "pert_data.load(data_name = dataset_name)\n", 256 | "pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed\n", 257 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "gears_model = GEARS(pert_data, device = 'cuda:2', \n", 267 | " weight_bias_track = False, \n", 268 | " proj_name = dataset_name, \n", 269 | " exp_name = 'gears_seed' + str(seed))\n", 270 | "\n", 271 | "gears_model.model_initialize(hidden_size = 64)\n", 272 | "\n", 273 | "gears_model.train(epochs=1)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "metadata": {}, 279 | "source": [ 280 | "### Norman: No perturb condition" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "dataset_name = 'norman'\n", 290 | "\n", 291 | "pert_data = PertData('./data', default_pert_graph=False)\n", 292 | "pert_data.load(data_name = dataset_name)\n", 293 | "pert_data.prepare_split(split = 'simulation', seed = seed) # get data split with seed\n", 294 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128) # prepare data loader" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "gears_model = GEARS(pert_data, device = 'cuda:2', \n", 304 | " weight_bias_track = False, \n", 305 | " proj_name = dataset_name, \n", 306 | " exp_name = 'no_pert_seed' + str(seed))\n", 307 | "\n", 308 | "gears_model.model_initialize(hidden_size = 64,\n", 309 | " no_perturb=True)\n", 310 | "\n", 311 | "gears_model.train(epochs=0)" 312 | ] 313 | } 314 | ], 315 | "metadata": { 316 | "kernelspec": { 317 | "display_name": "deepamp", 318 | "language": "python", 319 | "name": "deepamp" 320 | }, 321 | "language_info": { 322 | "codemirror_mode": { 323 | "name": "ipython", 324 | "version": 3 325 | }, 326 | "file_extension": ".py", 327 | "mimetype": "text/x-python", 328 | "name": "python", 329 | "nbconvert_exporter": "python", 330 | "pygments_lexer": "ipython3", 331 | "version": "3.8.10" 332 | } 333 | }, 334 | "nbformat": 4, 335 | "nbformat_minor": 4 336 | } 337 | -------------------------------------------------------------------------------- /paper/supp_train.py: -------------------------------------------------------------------------------- 1 | ## Training script for supplementary datasets 2 | 3 | import argparse 4 | import sys 5 | sys.path.append('../') 6 | 7 | from gears import PertData, GEARS 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--seed', type=int, default=1) 10 | parser.add_argument('--device', type=int, default=0) 11 | parser.add_argument('--dataset', type=str, default='norman2019', choices = ['norman2019', 'jost2020_hvg', 'tian2021_crispri_hvg', 'tian2021_crispra_hvg', 'replogle2020_hvg', 'replogle_rpe1_gw_hvg', 'replogle_k562_gw_hvg', 'replogle_k562_essential_hvg', 'tian2019_neuron_hvg', 'tian2019_ipsc_hvg', 'replogle_rpe1_gw_filtered_hvg', 
'replogle_k562_essential_filtered_hvg', 'norman']) 12 | parser.add_argument('--model', type=str, default='gears', choices = ['gears', 'no_perturb']) 13 | parser.add_argument('--batch_size', type=int, default=32) 14 | 15 | args = parser.parse_args() 16 | seed = args.seed 17 | 18 | if args.model == 'no_perturb': 19 | epoch = 0 20 | no_perturb = True 21 | else: 22 | epoch = 15 23 | no_perturb = False 24 | 25 | ## Set this to local dataloader directory 26 | data_path = './data/' 27 | 28 | if args.dataset == 'tian2019_neuron_hvg': 29 | gene_path = './data/essential_all_data_pert_genes_tian2019_neuron.pkl' 30 | elif args.dataset == 'tian2019_ipsc_hvg': 31 | gene_path = './data/essential_all_data_pert_genes_tian2019_ipsc.pkl' 32 | elif args.dataset == 'jost2020_hvg': 33 | gene_path = './data/essential_all_data_pert_genes_jost2020.pkl' 34 | elif args.dataset == 'norman2019': 35 | gene_path = './data/essential_norman.pkl' 36 | else: 37 | gene_path = None 38 | 39 | if args.dataset in ['tian2019_neuron_hvg', 'tian2019_ipsc_hvg', 'jost2020_hvg']: 40 | add_small_graph = True 41 | else: 42 | add_small_graph = False 43 | 44 | pert_data = PertData(data_path[:-1], gene_path = gene_path) # specific saved folder 45 | pert_data.load(data_path = data_path + args.dataset) # load the processed data, the path is saved folder + dataset_name 46 | pert_data.prepare_split(split = 'simulation', seed = seed) 47 | pert_data.get_dataloader(batch_size = args.batch_size, test_batch_size = args.batch_size) 48 | from gears import GEARS 49 | gears_model = GEARS(pert_data, device = 'cuda:' + str(args.device), 50 | weight_bias_track = True, 51 | proj_name = args.dataset, 52 | exp_name = str(args.model) + '_seed' + str(seed)) 53 | 54 | if args.dataset == 'tian2019_neuron_hvg': 55 | go_path = './data/go_essential_tian2020_neuron.csv' 56 | elif args.dataset == 'tian2019_ipsc_hvg': 57 | go_path = './data/go_essential_tian2020_ipsc.csv' 58 | elif args.dataset == 'jost2020_hvg': 59 | go_path = './data/go_essential_jost2020.csv' 60 | elif args.dataset == 'norman2019': 61 | go_path = './data/go_essential_norman.csv' 62 | else: 63 | go_path = None 64 | 65 | gears_model.model_initialize(hidden_size = 64, no_perturb = no_perturb, go_path = go_path) 66 | 67 | gears_model.train(epochs = epoch) 68 | if args.model != 'no_perturb': 69 | gears_model.save_model('./model_ckpt/' + args.dataset + '_' + args.model + '_run' + str(seed)) 70 | --------------------------------------------------------------------------------
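
Note on reusing the saved checkpoints: `fig2_train.py` and `supp_train.py` write trained models to `./model_ckpt/`, and `predicting_GIs.py` reloads such a checkpoint with `load_pretrained` before calling `predict`/`GI_predict`. Below is a minimal sketch of that reload-and-predict flow, assuming the processed `norman` dataloader already exists under a local `./data` directory and that a checkpoint named `norman_gears_run1` was produced by `python fig2_train.py --dataset norman --seed 1` (default model `gears`); the device string and the gene symbols are placeholders, and minor API differences may apply depending on the installed GEARS version.

```python
# Minimal sketch (assumptions noted above): reload a GEARS checkpoint saved by
# fig2_train.py and predict expression for new single/combinatorial perturbations.
from gears import PertData, GEARS

# Rebuild the same dataloader object that was used at training time.
pert_data = PertData('./data', default_pert_graph=False)
pert_data.load(data_name='norman')                    # processed data must already be in ./data/norman
pert_data.prepare_split(split='simulation', seed=1)   # match the training split and seed
pert_data.get_dataloader(batch_size=32, test_batch_size=128)

# load_pretrained restores the saved model configuration, so model_initialize is
# not called here (mirroring the usage in predicting_GIs.py).
gears_model = GEARS(pert_data, device='cuda:0', weight_bias_track=False)
gears_model.load_pretrained('./model_ckpt/norman_gears_run1')

# Gene symbols below are placeholders; they must be present in pert_data.pert_names.
preds = gears_model.predict([['CBL'], ['CBL', 'CNN1']])
```

For genetic-interaction analysis, `GI_predict` can then be applied to gene pairs in the same way as in `predicting_GIs.py`.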