├── LICENSE ├── README.md ├── data_proc ├── Create New Species Files.ipynb ├── __pycache__ │ ├── data_utils.cpython-38.pyc │ └── gene_embeddings.cpython-38.pyc ├── data_utils.py ├── download_proc_czi_cxg.py ├── gene_embeddings.py ├── generate_reduced_chrom_files.py └── preproc_many_dataset.py ├── eval_data.py ├── eval_single_anndata.py ├── evaluate.py ├── examples ├── Benchmark Embeddings with scIB.ipynb └── Label Transfer Using Logistic Classifier.ipynb ├── model.py ├── model_files └── new_species_protein_embeddings.csv ├── requirements.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yanay Rosen, Yusuf Roohani, Jure Leskovec 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Universal Cell Embeddings 2 | 3 | This repo includes a PyTorch [HuggingFace Accelerator](https://huggingface.co/docs/accelerate/package_reference/accelerator) implementation of the UCE model, to be used to embed individual AnnData datasets. 4 | 5 | ## Installation 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ## Embedding a new dataset 12 | 13 | To generate an embedding for a new single-cell RNA sequencing dataset in the AnnData format, use the `eval_single_anndata.py` script. 14 | 15 | ``` 16 | python eval_single_anndata.py --adata_path {path_to_anndata} --dir {output_dir} --species {species} --model_loc {model_loc} --batch_size {batch_size} 17 | ``` 18 | 19 | where 20 | - `adata_path`: an h5ad file. The `.X` slot of the file should be scRNA-seq counts. The `.var_names` slot should correspond to gene names, *not Ensembl IDs*. 21 | - `dir`: the working directory in which intermediate and final output files will be saved, so that repeated processing of the same dataset can be skipped. 22 | - `species`: the species of the dataset you are embedding. 23 | - `model_loc`: the location of the model weights `.torch` file. 24 | - `batch_size`: the per-GPU batch size. For the 33 layer model, on an 80GB GPU, you should use 25. For a 4 layer model on the same GPU, you can use 100. 25 | 26 | For a sample output on the 10k PBMC dataset, run 27 | ``` 28 | python eval_single_anndata.py 29 | ``` 30 | All necessary model files will be downloaded automatically. 31 | 32 | 33 | **Note**: This script makes use of additional files, which are described in the code documentation. These are downloaded automatically unless already present in the working directory. The script defaults to the pretrained 4-layer model. 
For running the pretrained 33-layer model from the paper, please download using this [link](https://figshare.com/articles/dataset/Universal_Cell_Embedding_Model_Files/24320806?file=43423236) and set `--nlayers 33`. 34 | 35 | ## Output 36 | 37 | Final evaluated AnnData: `dir/{dataset_name}.h5ad`. This AnnData will be 38 | identical to the processed input AnnData, but have UCE embeddings added in the `.obsm["X_uce"]` slot. 39 | 40 | Please see documentation for information on additional output files. All 41 | outputs from `eval_single_anndata.py` are stored in the `dir` directory. 42 | 43 | ## Data 44 | 45 | You can download processed datasets used in the paper [here](https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7?usp=drive_link) 46 | 47 | **Note:** These datasets were embedded using the 33 layer model. Embeddings for the 33 layer model are not compatible with embeddings from the 4 layer model. 48 | 49 | ## Citing 50 | 51 | If you find our paper and code useful, please consider citing the [preprint](https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1): 52 | 53 | ``` 54 | @article{rosen2023universal, 55 | title={Universal Cell Embeddings: A Foundation Model for Cell Biology}, 56 | author={Rosen, Yanay and Roohani, Yusuf and Agrawal, Ayush and Samotorcan, Leon and Consortium, Tabula Sapiens and Quake, Stephen R and Leskovec, Jure}, 57 | journal={bioRxiv}, 58 | pages={2023--11}, 59 | year={2023}, 60 | publisher={Cold Spring Harbor Laboratory} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /data_proc/Create New Species Files.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0e4018ee", 6 | "metadata": {}, 7 | "source": [ 8 | "# Embedding Novel Species\n", 9 | "\n", 10 | "This notebook will create the files you need to embed a novel species that wasn't included in the training 
data.\n", 11 | "\n", 12 | "To start, you will need to download the ESM2 protein embeddings and the reference proteome for the species.\n", 13 | "\n", 14 | "You can find precalculated ESM2 protein embeddings for many species [here](https://drive.google.com/drive/folders/1_Dz7HS5N3GoOAG6MdhsXWY1nwLoN13DJ?usp=drive_link)\n", 15 | "\n", 16 | "For reference proteomes, you can download them from [here](https://useast.ensembl.org/info/about/species.html).\n", 17 | "\n", 18 | "If there is no protein embedding for the species you are interested in, you can request to have it made via Github or email, or you can create it yourself following instructions [here](https://github.com/snap-stanford/SATURN/tree/main/protein_embeddings)." 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 1, 24 | "id": "ab368d92", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import numpy as np\n", 29 | "import pickle as pkl\n", 30 | "import pandas as pd" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "id": "c9a306f3", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "SPECIES_NAME = \"chicken\" # short hand name for this species, will be used in arguments and files\n", 41 | "\n", 42 | "# Path to the species proteome\n", 43 | "SPECIES_PROTEIN_FASTA_PATH = \"../../../SATURN/protein_embeddings/data/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.fa\"\n", 44 | "\n", 45 | "# Path to the ESM2 Embeddings\n", 46 | "SPECIES_PROTEIN_EMBEDDINGS_PATH = \"../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt\"\n", 47 | "\n", 48 | "# primary_assembly name, this needs to be matched to the FASTA file\n", 49 | "ASSEMBLY_NAME = \"bGalGal1.mat.broiler.GRCg7b\"\n", 50 | "# NCBI Taxonomy ID, please set this so that if someone else also embeds the same species,\n", 51 | "# randomly generated chromosome tokens will be the same\n", 52 | "TAXONOMY_ID = 9031" 53 | ] 54 | }, 55 
| { 56 | "cell_type": "markdown", 57 | "id": "e5d37e52", 58 | "metadata": {}, 59 | "source": [ 60 | "You can view the FASTA format here, please confirm the primary_assembly name is correct." 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "id": "2ecf1464", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | ">ENSGALP00010000002.1 pep primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:2824:3798:1 gene:ENSGALG00010000007.1 transcript:ENSGALT00010000007.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND1 description:NADH dehydrogenase subunit 1 [Source:NCBI gene (formerly Entrezgene);Acc:63549479]\r\n", 74 | "MTLPTLTNLLIMTLSYILPILIAVAFLTLVERKILSYMQARKGPNIVGPFGLLQPVADGV\r\n", 75 | "KLFIKEPIRPSTSSPFLFIITPILALLLALTIWVPLPLPFPLADLNLGLLFLLAMSSLTV\r\n", 76 | "YSLLWSGWASNSKYALIGALRAVAQTISYEVTLAIILLSTIMLSGNYTLSTLAITQEPIY\r\n", 77 | "LIFSAWPLAMMWYISTLAETNRAPFDLTEGESELVSGFNVEYAAGPFAMFFLAEYANIML\r\n", 78 | "MNTLTTVLFLNPSFLNLPPELFPIALATKTLLLSSSFLWIRASYPRFRYDQLMHLLWKNF\r\n", 79 | "LPLTLALCLWHTSMPISYAGLPPI\r\n", 80 | ">ENSGALP00010000003.1 pep primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:4015:5053:1 gene:ENSGALG00010000011.1 transcript:ENSGALT00010000011.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND2 description:NADH dehydrogenase subunit 2 [Source:NCBI gene (formerly Entrezgene);Acc:63549482]\r\n", 81 | "MNPHAKLICTVSLIMGTSITISSNHWILAWTGLEINTLAIIPLISKSHHPRAIEATIKYF\r\n", 82 | "LTQSTASALILFSSMTNAWSTGQWDITQLNHPTSCLMLTMAIAIKLGLVPFHFWFPEVLQ\r\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "!head {SPECIES_PROTEIN_FASTA_PATH}" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "id": "90540d0b", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "species_to_paths = {\n", 98 | " SPECIES_NAME: SPECIES_PROTEIN_FASTA_PATH,\n", 99 | "}\n", 100 | "\n", 101 | "species_to_ids = {\n", 102 | " 
SPECIES_NAME: ASSEMBLY_NAME,\n", 103 | "}" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "id": "623b99cf", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "all_pos_def = []\n", 114 | "\n", 115 | "missing_genes = {}\n", 116 | "for species in species_to_ids.keys():\n", 117 | " missing_genes[species] = []\n", 118 | " proteome_path = species_to_paths[species]\n", 119 | " species_id = species_to_ids[species]\n", 120 | "\n", 121 | " with open(proteome_path) as f:\n", 122 | " proteome_lines = f.readlines()\n", 123 | "\n", 124 | " gene_symbol_to_location = {}\n", 125 | " gene_symbol_to_chrom = {}\n", 126 | "\n", 127 | " for line in proteome_lines:\n", 128 | " if line.startswith(\">\"):\n", 129 | " split_line = line.split()\n", 130 | " gene_symbol = [token for token in split_line if token.startswith(\"gene_symbol\")]\n", 131 | " if len(gene_symbol) > 0:\n", 132 | " gene_symbol = gene_symbol[0].split(\":\")\n", 133 | " \n", 134 | " if len(gene_symbol) == 2:\n", 135 | " gene_symbol = gene_symbol[1]\n", 136 | " elif len(gene_symbol) > 2:\n", 137 | " gene_symbol = \":\".join(gene_symbol[1:]) # fix for annoying zebrafish gene names with colons in them\n", 138 | " else:\n", 139 | " 1/0 # something weird happening, throw an error\n", 140 | " \n", 141 | " \n", 142 | " chrom = None\n", 143 | " \n", 144 | " chrom_arr = [token for token in split_line if token.startswith(\"chromosome:\")]\n", 145 | " if len(chrom_arr) > 0:\n", 146 | " chrom = chrom_arr[0].replace(\"chromosome:\", \"\")\n", 147 | " else:\n", 148 | " chrom_arr = [token for token in split_line if token.startswith(\"primary_assembly:\")]\n", 149 | " if len(chrom_arr) > 0:\n", 150 | " chrom = chrom_arr[0].replace(\"primary_assembly:\", \"\")\n", 151 | " else:\n", 152 | " chrom_arr = [token for token in split_line if token.startswith(\"scaffold:\")] \n", 153 | " if len(chrom_arr) > 0:\n", 154 | " chrom = chrom_arr[0].replace(\"scaffold:\", \"\")\n", 155 | " if chrom 
is not None:\n", 156 | " gene_symbol_to_location[gene_symbol] = chrom.split(\":\")[2]\n", 157 | " gene_symbol_to_chrom[gene_symbol] = chrom.split(\":\")[1]\n", 158 | " else:\n", 159 | " missing_genes[species].append(gene_symbol)\n", 160 | " \n", 161 | "\n", 162 | " positional_df = pd.DataFrame()\n", 163 | " positional_df[\"gene_symbol\"] = [gn.upper() for gn in list(gene_symbol_to_chrom.keys())]\n", 164 | " positional_df[\"chromosome\"] = list(gene_symbol_to_chrom.values())\n", 165 | " positional_df[\"start\"] = list(gene_symbol_to_location.values())\n", 166 | " positional_df = positional_df.sort_values([\"chromosome\", \"start\"])\n", 167 | " #positional_df = positional_df.set_index(\"gene_symbol\")\n", 168 | " positional_df[\"species\"] = species\n", 169 | " all_pos_def.append(positional_df)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 6, 175 | "id": "b72887b3", 176 | "metadata": {}, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/html": [ 181 | "
\n", 182 | "\n", 195 | "\n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | "
      gene_symbol chromosome      start  species
2327         GCC1          1    1006145  chicken
2502        NCAM2          1  100828671  chicken
3084        ENS-2          1  101147482  chicken
2331      DENND6B          1    1012031  chicken
3973       MRPL39          1  102578362  chicken
...           ...        ...        ...      ...
4722          CA9          Z    9779343  chicken
4738     ARHGEF39          Z    9835547  chicken
3885       MRPL17          Z    9850679  chicken
4172        CCBE1          Z    9852827  chicken
3293       PMAIP1          Z    9998272  chicken
\n", 285 | "

13271 rows × 4 columns

\n", 286 | "
" 287 | ], 288 | "text/plain": [ 289 | " gene_symbol chromosome start species\n", 290 | "2327 GCC1 1 1006145 chicken\n", 291 | "2502 NCAM2 1 100828671 chicken\n", 292 | "3084 ENS-2 1 101147482 chicken\n", 293 | "2331 DENND6B 1 1012031 chicken\n", 294 | "3973 MRPL39 1 102578362 chicken\n", 295 | "... ... ... ... ...\n", 296 | "4722 CA9 Z 9779343 chicken\n", 297 | "4738 ARHGEF39 Z 9835547 chicken\n", 298 | "3885 MRPL17 Z 9850679 chicken\n", 299 | "4172 CCBE1 Z 9852827 chicken\n", 300 | "3293 PMAIP1 Z 9998272 chicken\n", 301 | "\n", 302 | "[13271 rows x 4 columns]" 303 | ] 304 | }, 305 | "execution_count": 6, 306 | "metadata": {}, 307 | "output_type": "execute_result" 308 | } 309 | ], 310 | "source": [ 311 | "master_pos_def = pd.concat(all_pos_def)\n", 312 | "master_pos_def" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 7, 318 | "id": "6d9dac28", 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "data": { 323 | "text/plain": [ 324 | "chicken 13271\n", 325 | "Name: species, dtype: int64" 326 | ] 327 | }, 328 | "execution_count": 7, 329 | "metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "master_pos_def[\"species\"].value_counts() # double check how many genes are mapped" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 8, 340 | "id": "4a3d45c2", 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "chicken: 0\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "for k, v in missing_genes.items():\n", 353 | " print(f\"{k}: {len(v)}\") # are any genes missing?" 
354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 9, 359 | "id": "c59774b1", 360 | "metadata": { 361 | "scrolled": true 362 | }, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "*********\n", 369 | "chicken\n" 370 | ] 371 | }, 372 | { 373 | "data": { 374 | "text/plain": [ 375 | "1 1785\n", 376 | "2 1169\n", 377 | "3 1067\n", 378 | "4 953\n", 379 | "5 817\n", 380 | "Z 629\n", 381 | "6 458\n", 382 | "8 450\n", 383 | "7 442\n", 384 | "9 382\n", 385 | "10 366\n", 386 | "14 359\n", 387 | "11 327\n", 388 | "15 326\n", 389 | "13 306\n", 390 | "20 298\n", 391 | "12 293\n", 392 | "19 278\n", 393 | "18 274\n", 394 | "17 260\n", 395 | "26 237\n", 396 | "28 237\n", 397 | "27 235\n", 398 | "21 226\n", 399 | "23 214\n", 400 | "25 176\n", 401 | "34 155\n", 402 | "24 149\n", 403 | "22 142\n", 404 | "16 54\n", 405 | "30 52\n", 406 | "38 49\n", 407 | "31 14\n", 408 | "MT 13\n", 409 | "39 10\n", 410 | "JAENSK010000484.1 7\n", 411 | "35 6\n", 412 | "JAENSK010000592.1 6\n", 413 | "W 5\n", 414 | "MU179278.1 5\n", 415 | "MU179279.1 4\n", 416 | "36 3\n", 417 | "JAENSK010000483.1 3\n", 418 | "JAENSK010000585.1 3\n", 419 | "JAENSK010000593.1 2\n", 420 | "MU179258.1 2\n", 421 | "MU179272.1 2\n", 422 | "MU179273.1 2\n", 423 | "JAENSK010000584.1 2\n", 424 | "JAENSK010000656.1 1\n", 425 | "Name: chromosome, dtype: int64" 426 | ] 427 | }, 428 | "metadata": {}, 429 | "output_type": "display_data" 430 | }, 431 | { 432 | "name": "stdout", 433 | "output_type": "stream", 434 | "text": [ 435 | "*********\n" 436 | ] 437 | } 438 | ], 439 | "source": [ 440 | "# Count genes per chromosome\n", 441 | "for species in species_to_ids.keys():\n", 442 | " print(\"*********\")\n", 443 | " print(species)\n", 444 | " display(master_pos_def[master_pos_def[\"species\"] == species][\"chromosome\"].value_counts().head(50))\n", 445 | " print(\"*********\")" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 10, 
451 | "id": "541baded", 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "master_pos_def.to_csv(f\"{SPECIES_NAME}_to_chrom_pos.csv\", index=False) # Save the DF" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 11, 461 | "id": "eabd0e31", 462 | "metadata": {}, 463 | "outputs": [ 464 | { 465 | "name": "stdout", 466 | "output_type": "stream", 467 | "text": [ 468 | "chicken_to_chrom_pos.csv\n" 469 | ] 470 | } 471 | ], 472 | "source": [ 473 | "# The chromosome file path will be:\n", 474 | "print(f\"{SPECIES_NAME}_to_chrom_pos.csv\")" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 12, 480 | "id": "fe1345b1", 481 | "metadata": {}, 482 | "outputs": [ 483 | { 484 | "data": { 485 | "text/plain": [ 486 | "66" 487 | ] 488 | }, 489 | "execution_count": 12, 490 | "metadata": {}, 491 | "output_type": "execute_result" 492 | } 493 | ], 494 | "source": [ 495 | "N_UNIQ_CHROM = len(master_pos_def[master_pos_def[\"species\"] == species][\"chromosome\"].unique())\n", 496 | "N_UNIQ_CHROM" 497 | ] 498 | }, 499 | { 500 | "cell_type": "markdown", 501 | "id": "e37e277f", 502 | "metadata": {}, 503 | "source": [ 504 | "# Generate token file" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 13, 510 | "id": "d6904975", 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [ 514 | "import torch\n", 515 | "import pickle\n", 516 | "token_dim = 5120" 517 | ] 518 | }, 519 | { 520 | "cell_type": "markdown", 521 | "id": "a2798848", 522 | "metadata": {}, 523 | "source": [ 524 | "This will create the token file. Please note the offset value." 
525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 14, 530 | "id": "4355dabd", 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "name": "stdout", 535 | "output_type": "stream", 536 | "text": [ 537 | "CHROM_TOKEN_OFFSET: 13275\n", 538 | "Saved PE, offsets file\n" 539 | ] 540 | } 541 | ], 542 | "source": [ 543 | "species_to_offsets = {}\n", 544 | "\n", 545 | "all_pe = torch.load(\"../model_files/all_tokens.torch\")[0:4] # read in existing token file to make sure \n", 546 | "# that special vocab tokens are the same for different seeds\n", 547 | "\n", 548 | "offset = len(all_pe) # special tokens at the top!\n", 549 | "\n", 550 | "PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)\n", 551 | "\n", 552 | "pe_stacked = torch.stack(list(PE.values()))\n", 553 | "all_pe = torch.vstack((all_pe, pe_stacked))\n", 554 | "species_to_offsets[species] = offset\n", 555 | "\n", 556 | "print(\"CHROM_TOKEN_OFFSET:\", all_pe.shape[0])\n", 557 | "torch.manual_seed(TAXONOMY_ID)\n", 558 | "CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, 5120)) \n", 559 | "# N_UNIQ_CHROM is the total number of chromosome choices, it is hardcoded for now (for species in the training data)\n", 560 | "all_pe = torch.vstack(\n", 561 | " (all_pe, CHROM_TENSORS)) # Add the chrom tensors to the end\n", 562 | "all_pe.requires_grad = False\n", 563 | "\n", 564 | "\n", 565 | "torch.save(all_pe, f\"{SPECIES_NAME}_pe_tokens.torch\")\n", 566 | "\n", 567 | "with open(f\"{SPECIES_NAME}_offsets.pkl\", \"wb+\") as f:\n", 568 | " pickle.dump(species_to_offsets, f)\n", 569 | "print(\"Saved PE, offsets file\")" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": 15, 575 | "id": "c26fe491", 576 | "metadata": { 577 | "scrolled": true 578 | }, 579 | "outputs": [ 580 | { 581 | "data": { 582 | "text/plain": [ 583 | "torch.Size([13341, 5120])" 584 | ] 585 | }, 586 | "execution_count": 15, 587 | "metadata": {}, 588 | "output_type": "execute_result" 589 | } 590 | ], 
591 | "source": [ 592 | "all_pe.shape" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 16, 598 | "id": "21f937ea", 599 | "metadata": { 600 | "scrolled": true 601 | }, 602 | "outputs": [ 603 | { 604 | "data": { 605 | "text/plain": [ 606 | "torch.Size([13341, 5120])" 607 | ] 608 | }, 609 | "execution_count": 16, 610 | "metadata": {}, 611 | "output_type": "execute_result" 612 | } 613 | ], 614 | "source": [ 615 | "all_pe.shape" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": 17, 621 | "id": "5faadace", 622 | "metadata": {}, 623 | "outputs": [ 624 | { 625 | "name": "stdout", 626 | "output_type": "stream", 627 | "text": [ 628 | "chicken_offsets.pkl\n" 629 | ] 630 | } 631 | ], 632 | "source": [ 633 | "print(f\"{SPECIES_NAME}_offsets.pkl\")" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 18, 639 | "id": "6ceac20b", 640 | "metadata": {}, 641 | "outputs": [ 642 | { 643 | "data": { 644 | "text/plain": [ 645 | "'../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt'" 646 | ] 647 | }, 648 | "execution_count": 18, 649 | "metadata": {}, 650 | "output_type": "execute_result" 651 | } 652 | ], 653 | "source": [ 654 | "SPECIES_PROTEIN_EMBEDDINGS_PATH" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "id": "e4697330", 660 | "metadata": {}, 661 | "source": [ 662 | "# Example evaluation of new species" 663 | ] 664 | }, 665 | { 666 | "cell_type": "markdown", 667 | "id": "2b72667d", 668 | "metadata": {}, 669 | "source": [ 670 | "**Note: when you evaluate a new species, you need to change some arguments and modify some files:**\n", 671 | "\n", 672 | "You will need to modify the csv in `model_files/new_species_protein_embeddings.csv` to include the new protein embeddings file you downloaded.\n", 673 | "\n", 674 | "In the file add a row for the new species with the format:\n", 675 | "`species name,full path to protein embedding 
file`\n", 676 | "\n", 677 | "Please also add this line to the `embeddings_paths` dictionary defined in `get_species_to_pe` in the file `data_proc/data_utils.py`.\n", 678 | "\n", 679 | "When you want to embed this new species, you will need to specify these newly created files as arguments.\n", 680 | "- `CHROM_TOKEN_OFFSET`: This tells UCE where the rows corresponding to chromosome tokens start.\n", 681 | "- `spec_chrom_csv_path`: This is a new csv, created by this script, which maps genes to chromosomes and genomic positions.\n", 682 | "- `token_file`: This is a new token file that will work just for this species. The embeddings generated will still be universal though!\n", 683 | "- `offset_pkl_path`: This is another file that maps genes to tokens.\n", 684 | "\n", 685 | "\n", 686 | "```\n", 687 | "\n", 688 | "accelerate launch eval_single_anndata.py --adata_path=chicken_heart.h5ad --species=chicken --CHROM_TOKEN_OFFSET=13275 --spec_chrom_csv_path=data_proc/chicken_to_chrom_pos.csv --token_file=data_proc/chicken_pe_tokens.torch --offset_pkl_path=data_proc/chicken_offsets.pkl --dir=... 
--multi_gpu=True\n", 689 | "\n", 690 | "```" 691 | ] 692 | } 693 | ], 694 | "metadata": { 695 | "kernelspec": { 696 | "display_name": "Python 3 (ipykernel)", 697 | "language": "python", 698 | "name": "python3" 699 | }, 700 | "language_info": { 701 | "codemirror_mode": { 702 | "name": "ipython", 703 | "version": 3 704 | }, 705 | "file_extension": ".py", 706 | "mimetype": "text/x-python", 707 | "name": "python", 708 | "nbconvert_exporter": "python", 709 | "pygments_lexer": "ipython3", 710 | "version": "3.8.6" 711 | } 712 | }, 713 | "nbformat": 4, 714 | "nbformat_minor": 5 715 | } 716 | -------------------------------------------------------------------------------- /data_proc/__pycache__/data_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/UCE/8227a65cdd021b9186ef86671d2aef5c895c8e4b/data_proc/__pycache__/data_utils.cpython-38.pyc -------------------------------------------------------------------------------- /data_proc/__pycache__/gene_embeddings.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snap-stanford/UCE/8227a65cdd021b9186ef86671d2aef5c895c8e4b/data_proc/__pycache__/gene_embeddings.cpython-38.pyc -------------------------------------------------------------------------------- /data_proc/data_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | 4 | import scanpy as sc 5 | import torch 6 | 7 | from torch import nn, Tensor 8 | import torch.nn.functional as F 9 | import torch.utils.data as data 10 | import torch.optim as optim 11 | import numpy as np 12 | import pickle 13 | import os 14 | import argparse 15 | import logging 16 | import time 17 | import subprocess # used by process_raw_anndata for the optional scp copy 18 | from tqdm.auto import tqdm 19 | import pandas as pd 20 | 21 | import math 22 | import anndata 23 | from pathlib import Path 24 | 25 | 
26 | from torch.utils.data import dataset 27 | from torch.utils.data import DataLoader, TensorDataset, dataset 28 | from scipy.stats import binom 29 | from typing import Dict, List, Optional, Tuple 30 | from scanpy import AnnData 31 | 32 | 33 | from data_proc.gene_embeddings import load_gene_embeddings_adata 34 | 35 | def data_to_torch_X(X): 36 | if isinstance(X, sc.AnnData): 37 | X = X.X 38 | if not isinstance(X, np.ndarray): 39 | X = X.toarray() 40 | return torch.from_numpy(X).float() 41 | 42 | class SincleCellDataset(data.Dataset): 43 | def __init__(self, 44 | expression: torch.Tensor, # Subset to hv genes, count data! cells x genes 45 | protein_embeddings: torch.Tensor, # same order as expression, also subset genes x pe 46 | labels=None, # optional, tensor of labels 47 | covar_vals=None, # tensor of covar values or none 48 | ) -> None: 49 | super(SincleCellDataset, self).__init__() 50 | 51 | # Set expression 52 | self.expression = expression 53 | 54 | row_sums = self.expression.sum(1) # UMI counts per cell 55 | log_norm_count_adj = torch.log1p(self.expression / row_sums.unsqueeze(1) * 1000) 56 | 57 | # Set log norm and count adjusted expression, scaled by each gene's maximum 58 | max_vals, max_idx = torch.max(log_norm_count_adj, dim=0) 59 | self.expression_mod = log_norm_count_adj / max_vals 60 | 61 | # Calculate dropout likelihoods of each gene 62 | self.dropout_vec = (self.expression == 0).float().mean(0) # per-gene fraction of cells with zero counts 63 | 64 | # Set data info 65 | self.num_cells = self.expression.shape[0] 66 | self.num_genes = self.expression.shape[1] 67 | 68 | # Set optional label info, including categorical covariate index 69 | self.covar_vals = covar_vals 70 | self.labels = labels 71 | 72 | # Set protein embeddings 73 | self.protein_embeddings = protein_embeddings 74 | 75 | self.item_mode = "expression" 76 | if self.covar_vals is not None: 77 | self.item_mode = "expression+covar" 78 | 79 | 80 | def __getitem__(self, idx): 81 | if self.item_mode == "expression": 
82 | if isinstance(idx, int): 83 | if idx < self.num_cells: 84 | return self.expression[idx, :] 85 | else: 86 | raise IndexError 87 | else: 88 | raise NotImplementedError 89 | elif self.item_mode == "expression+covar": 90 | if isinstance(idx, int): 91 | if idx < self.num_cells: 92 | return self.expression[idx, :], self.covar_vals[idx] 93 | else: 94 | raise IndexError 95 | else: 96 | raise NotImplementedError 97 | 98 | 99 | def __len__(self) -> int: 100 | return self.num_cells 101 | 102 | def get_dim(self) -> int: 103 | return self.num_genes 104 | 105 | 106 | def data_to_torch_X(X): 107 | if isinstance(X, sc.AnnData): 108 | X = X.X 109 | if not isinstance(X, np.ndarray): 110 | X = X.toarray() 111 | return torch.from_numpy(X).float() 112 | 113 | 114 | def anndata_to_sc_dataset(adata:sc.AnnData, 115 | species:str="human", 116 | labels:list=[], 117 | covar_col:Optional[str]=None, 118 | hv_genes=None, 119 | embedding_model="ESM2", 120 | ) -> Tuple[SincleCellDataset, AnnData]: 121 | 122 | # Subset to just genes we have embeddings for 123 | adata, protein_embeddings = load_gene_embeddings_adata( 124 | adata=adata, 125 | species=[species], 126 | embedding_model=embedding_model 127 | ) 128 | 129 | if hv_genes is not None: 130 | sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hv_genes) # Expects Count Data 131 | 132 | hv_index = adata.var["highly_variable"] 133 | adata = adata[:, hv_index] # Subset to hv genes only 134 | 135 | protein_embeddings = protein_embeddings[species][hv_index] 136 | else: 137 | protein_embeddings = protein_embeddings[species] 138 | expression = data_to_torch_X(adata.X) 139 | 140 | covar_vals = None 141 | if len(labels) > 0: 142 | assert covar_col is None or covar_col in labels, "Covar needs to be in labels" # make sure you keep track of covar column! 
143 | labels = adata.obs.loc[:, labels].values 144 | 145 | if covar_col is not None: 146 | # we have a categorical label to use as covariate 147 | covar_vals = torch.tensor(pd.Categorical(adata.obs[covar_col]).codes) 148 | return SincleCellDataset( 149 | expression=expression, 150 | protein_embeddings=protein_embeddings, 151 | labels=labels, 152 | covar_vals=covar_vals 153 | ), adata 154 | 155 | def adata_path_to_prot_chrom_starts(adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset): 156 | """ 157 | Given an AnnData object, return the protein embedding row indices, chromosome codes, and start positions of its genes. 158 | """ 159 | pe_row_idxs = torch.tensor([spec_pe_genes.index(k.upper()) + offset for k in adata.var_names]).long() 160 | print(len(np.unique(pe_row_idxs))) 161 | 162 | spec_chrom = gene_to_chrom_pos[gene_to_chrom_pos["species"] == dataset_species].set_index("gene_symbol") 163 | 164 | gene_chrom = spec_chrom.loc[[k.upper() for k in adata.var_names]] 165 | 166 | dataset_chroms = gene_chrom["spec_chrom"].cat.codes # now this is correctly indexed by species and chromosome 167 | print("Max Code:", max(dataset_chroms)) 168 | dataset_pos = gene_chrom["start"].values 169 | return pe_row_idxs, dataset_chroms, dataset_pos 170 | 171 | 172 | 173 | def process_raw_anndata(row, h5_folder_path, npz_folder_path, scp, skip, 174 | additional_filter, root): 175 | path = row.path 176 | if not os.path.isfile(root + "/" + path): 177 | print( "**********************************") 178 | print(f"***********{root + '/' + path} File Missing****") 179 | print( "**********************************") 180 | print(path, root) 181 | return None, None, None # same 3-tuple shape as the other return paths 182 | 183 | name = path.replace(".h5ad", "") 184 | proc_path = path.replace(".h5ad", "_proc.h5ad") 185 | if skip: 186 | if os.path.isfile(h5_folder_path + proc_path): 187 | print(f"{name} already processed. 
Skipping") 188 | return None, None, None 189 | 190 | print(f"Processing {name}") 191 | 192 | species = row.species 193 | covar_col = row.covar_col 194 | 195 | ad = sc.read(root + "/" + path) 196 | labels = [] 197 | if "cell_type" in ad.obs.columns: 198 | labels.append("cell_type") 199 | 200 | 201 | if covar_col is None or pd.isna(covar_col): # pd.isna handles NaN and None without raising on string column names 202 | covar_col = None 203 | else: 204 | labels.append(covar_col) 205 | 206 | if additional_filter: 207 | sc.pp.filter_genes(ad, min_cells=10) 208 | sc.pp.filter_cells(ad, min_genes=25) 209 | 210 | 211 | dataset, adata = anndata_to_sc_dataset(ad, species=species, labels=labels, covar_col=covar_col, hv_genes=None) 212 | adata = adata.copy() 213 | 214 | if additional_filter: # NOTE: this filters `ad`, not the `adata` written below, so it does not change the saved output 215 | sc.pp.filter_genes(ad, min_cells=10) 216 | sc.pp.filter_cells(ad, min_genes=25) 217 | 218 | num_cells = adata.X.shape[0] 219 | num_genes = adata.X.shape[1] 220 | 221 | adata_path = h5_folder_path + proc_path 222 | adata.write(adata_path) 223 | 224 | arr = data_to_torch_X(adata.X).numpy() 225 | 226 | print(arr.max()) # this is a nice check to make sure it's counts 227 | filename = npz_folder_path + f"{name}_counts.npz" 228 | shape = arr.shape 229 | print(name, shape) 230 | fp = np.memmap(filename, dtype='int64', mode='w+', shape=shape) 231 | fp[:] = arr[:] 232 | fp.flush() 233 | 234 | if scp != "": 235 | subprocess.call(["scp", filename, f"{scp}:{filename}"]) 236 | subprocess.call(["scp", adata_path, f"{scp}:{adata_path}"]) 237 | 238 | return adata, num_cells, num_genes 239 | 240 | 241 | def get_species_to_pe(EMBEDDING_DIR): 242 | """ 243 | Given an embedding directory, return all embeddings as a dictionary keyed by species. 244 | Note: the directory must currently contain embedding files for every species listed below. 
245 | """ 246 | EMBEDDING_DIR = Path(EMBEDDING_DIR) 247 | 248 | embeddings_paths = { 249 | 'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt', 250 | 'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt', 251 | 'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt', 252 | 'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt', 253 | "mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt", 254 | "pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt', 255 | "macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt', 256 | "macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt', 257 | } 258 | extra_species = pd.read_csv("./model_files/new_species_protein_embeddings.csv").set_index("species").to_dict()["path"] 259 | embeddings_paths.update(extra_species) # adds new species 260 | 261 | 262 | 263 | species_to_pe = { 264 | species:torch.load(pe_dir) for species, pe_dir in embeddings_paths.items() 265 | } 266 | 267 | species_to_pe = {species:{k.upper(): v for k,v in pe.items()} for species, pe in species_to_pe.items()} 268 | return species_to_pe 269 | 270 | 271 | def get_spec_chrom_csv(path="/dfs/project/cross-species/yanay/code/all_to_chrom_pos.csv"): 272 | """ 273 | Get the species to chrom csv file 274 | """ 275 | gene_to_chrom_pos = pd.read_csv(path) 276 | gene_to_chrom_pos["spec_chrom"] = pd.Categorical(gene_to_chrom_pos["species"] + "_" + gene_to_chrom_pos["chromosome"]) # add the spec_chrom list 277 | return gene_to_chrom_pos -------------------------------------------------------------------------------- /data_proc/download_proc_czi_cxg.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OMP_NUM_THREADS"] = "20" # export 
OMP_NUM_THREADS=4 3 | os.environ["OPENBLAS_NUM_THREADS"] = "20" # export OPENBLAS_NUM_THREADS=4 4 | os.environ["MKL_NUM_THREADS"] = "20" # export MKL_NUM_THREADS=6 5 | os.environ["VECLIB_MAXIMUM_THREADS"] = "20" # export VECLIB_MAXIMUM_THREADS=4 6 | os.environ["NUMEXPR_NUM_THREADS"] = "20" 7 | 8 | 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | 12 | import cellxgene_census 13 | from tqdm import tqdm 14 | import scanpy as sc 15 | 16 | from collections import defaultdict 17 | from typing import Dict, List, Optional, Tuple 18 | 19 | import torch 20 | import torch.utils.data as data 21 | import torch 22 | import numpy as np 23 | import scanpy as sc 24 | from numpy import array 25 | import os 26 | import pickle as pkl 27 | import glob 28 | 29 | def data_to_torch_X(X): 30 | if isinstance(X, sc.AnnData): 31 | X = X.X 32 | if not isinstance(X, np.ndarray): 33 | X = X.toarray() 34 | return torch.from_numpy(X).float() 35 | 36 | import sys 37 | sys.path.append('../') 38 | 39 | from gene_embeddings import load_gene_embeddings_adata 40 | import pandas as pd 41 | import numpy as np 42 | from scanpy import AnnData 43 | from multiprocessing import Pool, Process, Manager 44 | 45 | import multiprocessing.pool as mpp 46 | # https://stackoverflow.com/questions/57354700/starmap-combined-with-tqdm 47 | def istarmap(self, func, iterable, chunksize=1): 48 | """starmap-version of imap 49 | """ 50 | if self._state != mpp.RUN: 51 | raise ValueError("Pool not running") 52 | 53 | if chunksize < 1: 54 | raise ValueError( 55 | "Chunksize must be 1+, not {0:n}".format( 56 | chunksize)) 57 | 58 | task_batches = mpp.Pool._get_tasks(func, iterable, chunksize) 59 | result = mpp.IMapIterator(self._cache) 60 | self._taskqueue.put( 61 | ( 62 | self._guarded_task_generation(result._job, 63 | mpp.starmapstar, 64 | task_batches), 65 | result._set_length 66 | )) 67 | return (item for chunk in result for item in chunk) 68 | 69 | 70 | mpp.Pool.istarmap = istarmap 71 | 72 | 73 | VERSION = 
"2023-04-25" 74 | N_TOP_GENES = 12000 75 | 76 | 77 | print(cellxgene_census.get_census_version_description(VERSION)) 78 | 79 | census = cellxgene_census.open_soma(census_version=VERSION) 80 | census_datasets = census["census_info"]["datasets"].read().concat().to_pandas() 81 | 82 | # for convenience, indexing on the soma_joinid which links this to other census data. 83 | census_datasets = census_datasets.set_index("soma_joinid") 84 | 85 | species_to_readable = { 86 | "Homo sapiens":"human", 87 | "Mus musculus":"mouse" 88 | } 89 | 90 | def process_row(row, num_genes, num_cells, paths, all_species, covar_cols, dataset_title, h5_root="/dfs/project/uce/cxg_data/anndatas/", npz_root="/dfs/project/uce/cxg_data/npzs/"): 91 | dataset_id = row[1].dataset_id 92 | #dataset_title = row[1].dataset_title.lower().replace(' ', '_').replace(",", "").replace("/", "") 93 | 94 | save_path = h5_root + f"{dataset_title}.h5ad" 95 | no_primary_path = save_path.replace(".h5ad", "_no_primary.h5ad") 96 | proc_path = save_path.replace(".h5ad", "_proc.h5ad") 97 | npz_path = npz_root + f"{dataset_title}_counts.npz" 98 | # Download the anndata 99 | 100 | if os.path.exists(no_primary_path): 101 | print("No Primary, skipping") 102 | return 103 | 104 | if not os.path.exists(save_path) and not os.path.exists(no_primary_path): 105 | cellxgene_census.download_source_h5ad( 106 | dataset_id, to_path=save_path 107 | ) 108 | if os.path.exists(proc_path) and os.path.exists(npz_path): 109 | print("Already Proc") 110 | try: 111 | ad = sc.read(proc_path) 112 | except: 113 | print() 114 | print() 115 | print("Error reading on:", dataset_title) 116 | print() 117 | print() 118 | return 119 | # Get organism 120 | if "organism" in ad.obs.columns: 121 | unique_organisms = list(ad.obs.organism.unique().categories) 122 | unique_organism_str = ", ".join(unique_organisms) 123 | else: 124 | unique_organism_str = "human" 125 | species = species_to_readable.get(unique_organism_str, "human") 126 | # don't need to do hv if 
already proc 127 | if "sample" in ad.obs.columns: 128 | covar_cols[dataset_title] = "sample" 129 | elif "batch" in ad.obs.columns: 130 | covar_cols[dataset_title] = "batch" 131 | else: 132 | covar_cols[dataset_title] = "" 133 | 134 | 135 | num_genes[dataset_title] = ad.X.shape[1] 136 | num_cells[dataset_title] = ad.X.shape[0] 137 | paths[dataset_title] = f"{dataset_title}.h5ad" 138 | all_species[dataset_title] = species 139 | 140 | return # Skip everything else 141 | # Read the raw AD 142 | ad = sc.read(save_path) 143 | 144 | # Change to counts 145 | if not sc._utils.check_nonnegative_integers(ad.X): 146 | # don't have counts yet, need raw 147 | if ad.raw is None: 148 | print("Skipped, no counts") 149 | return 150 | ad.X = ad.raw.X.toarray() 151 | if not sc._utils.check_nonnegative_integers(ad.X): 152 | print("Skipped, no counts") 153 | return 154 | 155 | # SUBSET TO primary data 156 | if len(np.unique(ad.obs["is_primary_data"])) >= 1: 157 | primary_data = ad.obs.is_primary_data.value_counts() 158 | ad = ad[ad.obs.is_primary_data] 159 | if ad.X.shape[0] == 0: 160 | print("no primary data") 161 | print(primary_data) 162 | os.rename(save_path, no_primary_path) 163 | return # No primary data 164 | print("has primary data") 165 | # Switch to gene symbols 166 | ad.var["feature_id_orig"] = list(ad.var.index) 167 | ad.var_names = list(ad.var.feature_name) 168 | 169 | # Get organism 170 | if "organism" in ad.obs.columns: 171 | unique_organisms = list(ad.obs.organism.unique().categories) 172 | unique_organism_str = ", ".join(unique_organisms) 173 | else: 174 | unique_organism_str = "human" 175 | species = species_to_readable.get(unique_organism_str, "human") 176 | # Filter to gene symbols with protein embeddings 177 | ad, _ = load_gene_embeddings_adata( 178 | adata=ad, 179 | species=[species], 180 | embedding_model="ESM2" 181 | ) 182 | 183 | ad = ad.copy() 184 | # Simple filtering by counts 185 | sc.pp.filter_cells(ad, min_genes=200) 186 | sc.pp.filter_genes(ad, 
min_cells=10) 187 | 188 | #print(ad) 189 | 190 | if "sample" in ad.obs.columns: 191 | try: 192 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="sample") 193 | except: 194 | try: 195 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="sample", span=1) 196 | except: 197 | print(f"can't hv gene subset {dataset_title}") 198 | covar_cols[dataset_title] = "sample" 199 | elif "batch" in ad.obs.columns: 200 | try: 201 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="batch") 202 | except: 203 | try: 204 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="batch", span=1) 205 | except: 206 | print(f"can't hv gene subset {dataset_title}") 207 | covar_cols[dataset_title] = "batch" 208 | else: 209 | try: 210 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True) 211 | except: 212 | try: 213 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, span=1) 214 | except: 215 | print(f"can't hv gene subset {dataset_title}") 216 | covar_cols[dataset_title] = "" 217 | 218 | num_genes[dataset_title] = ad.X.shape[1] 219 | num_cells[dataset_title] = ad.X.shape[0] 220 | paths[dataset_title] = f"{dataset_title}.h5ad" 221 | all_species[dataset_title] = species 222 | 223 | print("writing proc") 224 | ad.write(proc_path) 225 | 226 | arr = data_to_torch_X(ad.X).numpy() 227 | 228 | shape = arr.shape 229 | 230 | fp = np.memmap(npz_path, dtype='int64', mode='w+', shape=shape) 231 | fp[:] = arr[:] 232 | fp.flush() 233 | 234 | return 235 | 236 | if __name__ == '__main__': 237 | ''' 238 | manager = Manager() 239 | num_genes = manager.dict() 240 | num_cells = manager.dict() 241 | paths = manager.dict() 242 | all_species = manager.dict() 243 | covar_cols = manager.dict() 244 | ''' 245 | num_genes = {} 246 | num_cells = {} 
247 | paths = {} 248 | all_species = {} 249 | covar_cols = {} 250 | 251 | df = pd.DataFrame() 252 | # Shuffle the dataset 253 | census_datasets = census_datasets#.iloc[270:] 254 | iterrows = list(census_datasets.iterrows()) 255 | #p = Pool(8) 256 | #for row in tqdm(iterrows, total=len(census_datasets)): 257 | # p.apply_async(process_row, args=(row, num_genes, num_cells, paths, all_species, covar_cols)) 258 | #p.close() 259 | #p.join() 260 | ''' 261 | with Pool(1) as p: 262 | nrows = len(iterrows) 263 | inputs = zip(iterrows, [num_genes]*nrows, [num_cells]*nrows, [paths]*nrows, [all_species]*nrows, [covar_cols]*nrows) 264 | for _ in tqdm(p.istarmap(process_row, inputs), 265 | total=nrows): 266 | pass 267 | 268 | ''' 269 | 270 | if os.path.exists("dataset_rows_mouse_fixed.pkl"): 271 | dataset_rows = {} 272 | for path in glob.glob("dataset_rows_mouse_fixed*.pkl"): 273 | with open(path, "rb") as f: 274 | dataset_rows_path = pkl.load(f) 275 | dataset_rows.update(dataset_rows_path) 276 | 277 | print(f"{len(dataset_rows)} already counted") 278 | else: 279 | dataset_rows = {} 280 | 281 | 282 | pbar = tqdm(iterrows) 283 | all_errors = [] 284 | total_number_of_cells = 0 285 | 286 | duplicate_titles = ['Dissection: Body of hippocampus (HiB) - Rostral DG-CA4', 'Retina', 287 | 'Colon', 'Myeloid cells', 'Ileum', 'Airway'] 288 | duplicate_titles_2 = ['retina', 'airway', 'myeloid_cells', 'colon', 'ileum', 'immune_cells'] 289 | 290 | for row in pbar: 291 | dataset_title = row[1].dataset_title 292 | if dataset_title in duplicate_titles: 293 | dataset_title = row[1].collection_name + row[1].dataset_title 294 | 295 | dataset_title = dataset_title.lower().replace(' ', '_').replace(",", "").replace("/", "") 296 | 297 | if dataset_title in duplicate_titles_2: 298 | dataset_title = (row[1].collection_name + "_" + dataset_title).lower().replace(' ', '_').replace(",", "").replace("/", "") 299 | 300 | print(f"{total_number_of_cells} cells done") 301 | if dataset_title in dataset_rows: 302 | 
paths[dataset_title] = dataset_rows[dataset_title][0] 303 | all_species[dataset_title] = dataset_rows[dataset_title][1] 304 | covar_cols[dataset_title] = dataset_rows[dataset_title][2] 305 | num_cells[dataset_title] = dataset_rows[dataset_title][3] 306 | num_genes[dataset_title] = dataset_rows[dataset_title][4] 307 | #print("skipped read of proc") 308 | 309 | total_number_of_cells += dataset_rows[dataset_title][3] 310 | continue # Skip! 311 | else: 312 | pbar.set_description(f"{dataset_title} proc") 313 | try: 314 | process_row(row, num_genes, num_cells, paths, all_species, covar_cols, dataset_title=dataset_title) 315 | except: 316 | print(f"****{dataset_title} ERROR****") 317 | all_errors.append(dataset_title) 318 | 319 | 320 | pbar.set_description(f"{dataset_title} done") 321 | 322 | if dataset_title in paths: 323 | dataset_rows[dataset_title] = [paths[dataset_title], all_species[dataset_title], covar_cols[dataset_title], num_cells[dataset_title], num_genes[dataset_title], dataset_title] 324 | 325 | total_number_of_cells += dataset_rows[dataset_title][3] 326 | 327 | with open("dataset_rows_mouse_fixed.pkl", "wb") as f: 328 | pkl.dump(dataset_rows, f) 329 | print("wrote pkl") 330 | 331 | # path,species,covar_col,num_cells,names 332 | 333 | df["path"] = list(paths.values()) 334 | df["species"] = list(all_species.values()) 335 | df["covar_col"] = list(covar_cols.values()) 336 | df["num_cells"] = list(num_cells.values()) 337 | df["num_genes"] = list(num_genes.values()) 338 | df["names"] = list(paths.keys()) 339 | 340 | print(df.head(20)) 341 | print() 342 | print("Errors:") 343 | print(all_errors) 344 | df.to_csv("cxg_datasets.csv", index=False) 345 | -------------------------------------------------------------------------------- /data_proc/gene_embeddings.py: -------------------------------------------------------------------------------- 1 | """Helper functions for loading pretrained gene embeddings.""" 2 | from pathlib import Path 3 | from typing import Dict, 
Tuple 4 | 5 | import torch 6 | 7 | from scanpy import AnnData 8 | import numpy as np 9 | import pandas as pd 10 | 11 | 12 | EMBEDDING_DIR = Path('model_files/protein_embeddings') 13 | MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH = { 14 | 'ESM2': { 15 | 'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt', 16 | 'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt', 17 | 'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt', 18 | 'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt', 19 | "mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt", 20 | "pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt', 21 | "macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt', 22 | "macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt', 23 | } 24 | } 25 | 26 | extra_species = pd.read_csv("./model_files/new_species_protein_embeddings.csv").set_index("species").to_dict()["path"] 27 | MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH["ESM2"].update(extra_species) # adds new species 28 | 29 | 30 | def load_gene_embeddings_adata(adata: AnnData, species: list, embedding_model: str) -> Tuple[AnnData, Dict[str, torch.FloatTensor]]: 31 | """Loads gene embeddings for all the species/genes in the provided data. 32 | 33 | :param adata: An AnnData object containing gene expression data for cells. 34 | :param species: A list of species names corresponding to this adata. 35 | 36 | :param embedding_model: The gene embedding model whose embeddings will be loaded. 37 | :return: A tuple containing: 38 | - A subset of the data only containing the gene expression for genes with embeddings in all species. 39 | - A dictionary mapping species name to the corresponding gene embedding matrix (num_genes, embedding_dim). 
40 | """ 41 | # Get species names 42 | species_names = species 43 | species_names_set = set(species_names) 44 | 45 | # Get embedding paths for the model 46 | species_to_gene_embedding_path = MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH[embedding_model] 47 | available_species = set(species_to_gene_embedding_path) 48 | 49 | # Ensure embeddings are available for all species 50 | if not (species_names_set <= available_species): 51 | raise ValueError(f'The following species do not have gene embeddings: {species_names_set - available_species}') 52 | 53 | # Load gene embeddings for desired species (and convert gene symbols to lower case) 54 | species_to_gene_symbol_to_embedding = { 55 | species: { 56 | gene_symbol.lower(): gene_embedding 57 | for gene_symbol, gene_embedding in torch.load(species_to_gene_embedding_path[species]).items() 58 | } 59 | for species in species_names 60 | } 61 | 62 | # Determine which genes to include based on gene expression and embedding availability 63 | genes_with_embeddings = set.intersection(*[ 64 | set(gene_symbol_to_embedding) 65 | for gene_symbol_to_embedding in species_to_gene_symbol_to_embedding.values() 66 | ]) 67 | genes_to_use = {gene for gene in adata.var_names if gene.lower() in genes_with_embeddings} 68 | 69 | # Subset data to only use genes with embeddings 70 | adata = adata[:, adata.var_names.isin(genes_to_use)] 71 | 72 | # Set up dictionary mapping species to gene embedding matrix (num_genes, embedding_dim) 73 | species_to_gene_embeddings = { 74 | species_name: torch.stack([ 75 | species_to_gene_symbol_to_embedding[species_name][gene_symbol.lower()] 76 | for gene_symbol in adata.var_names 77 | ]) 78 | for species_name in species_names 79 | } 80 | 81 | return adata, species_to_gene_embeddings 82 | -------------------------------------------------------------------------------- /data_proc/generate_reduced_chrom_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
os.environ["OMP_NUM_THREADS"] = "4" # export OMP_NUM_THREADS=4 3 | os.environ["OPENBLAS_NUM_THREADS"] = "4" # export OPENBLAS_NUM_THREADS=4 4 | os.environ["MKL_NUM_THREADS"] = "4" # export MKL_NUM_THREADS=6 5 | os.environ["VECLIB_MAXIMUM_THREADS"] = "4" # export VECLIB_MAXIMUM_THREADS=4 6 | os.environ["NUMEXPR_NUM_THREADS"] = "4" 7 | 8 | 9 | import warnings 10 | warnings.filterwarnings("ignore") 11 | 12 | import scanpy as sc 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | import numpy as np 18 | import pickle 19 | import os 20 | import argparse 21 | import logging 22 | import time 23 | 24 | from tqdm.auto import tqdm 25 | import matplotlib.pyplot as plt 26 | import pandas as pd 27 | 28 | #sc._settings.ScanpyConfig.n_jobs = 6 29 | 30 | import math 31 | from typing import Tuple 32 | 33 | import torch 34 | from torch import nn, Tensor 35 | import torch.nn.functional as F 36 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 37 | from torch.utils.data import dataset 38 | 39 | 40 | from accelerate import Accelerator 41 | import anndata 42 | from data_utils import adata_path_to_prot_chrom_starts, get_spec_chrom_csv 43 | 44 | 45 | 46 | from torch.utils.data import dataset 47 | from torch.utils.data import DataLoader, TensorDataset 48 | from scipy.stats import binom 49 | 50 | 51 | 52 | 53 | def padding_tensor(sequences): 54 | """ 55 | :param sequences: list of tensors 56 | :return: 57 | """ 58 | num = len(sequences) 59 | max_len = max([s.size(0) for s in sequences]) 60 | out_dims = (num, max_len, 1280) 61 | 62 | 63 | out_tensor = sequences[0].data.new(*out_dims).fill_(0) 64 | out_dims2 = (num, max_len) 65 | 66 | mask = sequences[0].data.new(*out_dims2).fill_(float('-inf')) 67 | for i, tensor in enumerate(sequences): 68 | length = tensor.size(0) 69 | out_tensor[i, :length] = tensor 70 | mask[i, :length] = 1 71 | return out_tensor.permute(1, 0, 2), mask 72 | 73 | 74 | from pathlib import 
Path 75 | # ESM1b 76 | ''' 77 | EMBEDDING_DIR = Path('/dfs/project/cross-species/data/proteome/embeddings') 78 | human_pe_dir = EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM1b.pt' 79 | mouse_pe_dir = EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM1b.pt' 80 | lemur_pe_dir = Path("/dfs/project/cross-species/yanay/data/proteome/embeddings/") / 'Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM1b.pt' 81 | 82 | ''' 83 | 84 | # Upgrade to ESM2 85 | EMBEDDING_DIR = Path('/dfs/project/cross-species/data/proteome/embeddings') 86 | EMBEDDING_DIR = Path('/dfs/project/cross-species/yanay/data/proteome/embeddings') 87 | 88 | embeddings_paths = { 89 | 'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt', 90 | 'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt', 91 | 'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt', 92 | 'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt', 93 | "mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt", 94 | "pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt', 95 | "macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt', 96 | "macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt', 97 | } 98 | 99 | species_to_pe = { 100 | species:torch.load(pe_dir) for species, pe_dir in embeddings_paths.items() 101 | } 102 | 103 | species_to_pe = {species:{k.upper(): v for k,v in pe.items()} for species, pe in species_to_pe.items()} 104 | 105 | #species_to_keys = {species:list(pe.keys()) for species, pe in species_to_pe.items()} 106 | #species_to_keys = {species:dict(zip(keys, np.arange(len(keys)))) for species, keys in species_to_keys.items()} 107 | 108 | 109 | #datasets_df = 
pd.read_csv("/dfs/project/cross-species/yanay/code/UCE/data_proc/full_train_datasets.csv") 110 | datasets_df = pd.read_csv("tissue_datasets.csv") 111 | datasets_df = pd.read_csv("perturb_datasets.csv") 112 | datasets_df = pd.read_csv("../new_perturb_datasets.csv") 113 | 114 | 115 | #pd.concat((#pd.read_csv("new_datasets.csv"), 116 | #pd.read_csv("pbmcs_nohvg.csv"), 117 | #pd.read_csv("lung_nohvg.csv"), 118 | #pd.read_csv("new_tabula_datasets.csv"), 119 | #pd.read_csv("updated_datasets.csv"), 120 | # #pd.read_csv("sanger_heart_atlas_datasets.csv"), 121 | # pd.read_csv("tissue_datasets.csv") 122 | # )) 123 | 124 | 125 | 126 | 127 | #datasets_df = pd.read_csv("cell_cycle_datasets.csv") 128 | #datasets_df = pd.read_csv("spatial_datasets.csv") 129 | #datasets_df = pd.read_csv("perturb_datasets.csv") 130 | #datasets_df = pd.read_csv("ccle_datasets.csv") 131 | #datasets_df = pd.read_csv("pancreas_datasets.csv") 132 | 133 | 134 | 135 | sorted_dataset_names = sorted(datasets_df["names"]) 136 | with open("dataset_shapes.pkl", "rb") as f: 137 | shapes_dict = pickle.load(f) 138 | 139 | 140 | shapes_dict.update({ 141 | "madissoon_novel_lung":(190728, 8000), 142 | 'flores_cerebellum_human': (20232, 8000), 143 | 'osuch_gut_human': (272310, 8000), 144 | 'msk_ovarian_human': (929690, 8000), 145 | 'htan_vmuc_dis_epi_human': (65084, 8000), 146 | 'htan_vmuc_val_epi_human': (57564, 8000), 147 | 'htan_vmuc_non_epi_human': (9099, 8000), 148 | 'hao_pbmc_3p_human': (161764, 8000), 149 | 'hao_pbmc_5p_human': (49147, 8000), 150 | 'gao_tumors_human': (36111, 8000), 151 | 'swabrick_breast_human': (92427, 8000), 152 | 'wu_cryo_tumors_human': (105662, 8000), 153 | 'cell_line_het_human': (53513, 8000), 154 | 'bi_allen_metastasis_human': (27787, 8000), 155 | 'zheng68k_human': (68579, 8000), 156 | 'zheng68k_12k_human': (68579, 12000), 157 | 'mouse_embryo_ct': (153597, 12000), 158 | "regev_gtex_heart": (36574, 8000), 159 | "tabula_sapiens_heart": (11505, 8000), 160 | "10k_pbmcs":(11990, 12000), 161 
| "epo_ido":(35834,12000), 162 | 'tabula_sapiens_kidney': (9641, 8000), 163 | 'tabula_microcebus_kidney': (14592, 8000), 164 | 'tabula_muris_kidney': (2781, 8000), 165 | 'tabula_muris_senis_kidney': (19610, 8000), 166 | 'immune_human': (33506, 8000) 167 | }) 168 | 169 | for row in datasets_df.iterrows(): 170 | ngenes = row[1].num_genes 171 | ncells = row[1].num_cells 172 | name = row[1].names 173 | if not np.isnan(ngenes): 174 | shapes_dict[name] = (int(ncells), int(ngenes)) 175 | 176 | #with open("dataset_shapes.pkl", "wb") as f: 177 | # pickle.dump(shapes_dict, f) 178 | token_dim = 5120 179 | mmap_dict = {} 180 | 181 | root_dir = "/lfs/local/0/yanay/uce_h5s/" 182 | root_dir_census = "/lfs/local/0/yanay/cxg_h5s/" 183 | 184 | dataset_to_paths = {r[1]["names"]:root_dir + r[1]["path"].replace(".h5ad", "_proc.h5ad") for r in datasets_df.iterrows()} 185 | for row in datasets_df.iterrows(): 186 | name = row[1].names 187 | census = row[1].census 188 | 189 | if census == "yes": 190 | dataset_to_paths[name] = dataset_to_paths[name].replace(root_dir, root_dir_census) 191 | 192 | 193 | datasets_to_species = {r[1]["names"]:r[1]["species"] for r in datasets_df.iterrows()} 194 | 195 | #species_to_pe = {"mouse":mouse_pe, "human":human_pe, "mouse_lemur":lemur_pe} 196 | 197 | #dataset_to_protein_embeddings_all = {k:species_to_pe[v] for k, v in datasets_to_species.items()} 198 | 199 | dataset_to_protein_embeddings = {} 200 | 201 | 202 | #dataset_to_protein_embeddings_all["madissoon_novel_lung"] = species_to_pe["human"] 203 | datasets_to_species["madissoon_novel_lung"] = "human" 204 | #dataset_to_paths["madissoon_novel_lung"] = "/lfs/local/0/yanay/uce_h5s/madissoon_novel_lung_proc.h5ad" 205 | 206 | 207 | 208 | # New Chrom Based Code 209 | gene_to_chrom_pos = get_spec_chrom_csv() 210 | species_to_chrom_categories = {} 211 | 212 | for species in np.unique(gene_to_chrom_pos["species"]): 213 | species_to_chrom_categories[species] = 
pd.Categorical(gene_to_chrom_pos["chromosome"]).categories 214 | 215 | 216 | dataset_to_chroms = {} 217 | dataset_to_starts = {} 218 | 219 | sorted_species_names = sorted(species_to_pe.keys()) 220 | print(sorted_species_names) 221 | 222 | if os.path.exists(f"/dfs/project/uce/all_species_pe_tokens.torch"): 223 | all_pe = torch.load(f"/dfs/project/uce/all_species_pe_tokens.torch") 224 | with open("/dfs/project/uce/all_species_offsets.pkl", "rb") as f: 225 | species_to_offsets = pickle.load(f) 226 | print("Loaded PE", all_pe.shape) 227 | else: 228 | torch.manual_seed(8) 229 | MASK_TENSOR = torch.zeros((1, token_dim)) # this is the padding token 230 | CHROM_TENSOR_LEFT = torch.normal(mean=0, std=1, size=(1, token_dim)) 231 | CHROM_TENSOR_RIGHT = torch.normal(mean=0, std=1, size=(1, token_dim)) 232 | CLS_TENSOR = torch.normal(mean=0, std=1, size=(1, token_dim)) 233 | species_to_offsets = {} 234 | 235 | all_pe = [MASK_TENSOR, CHROM_TENSOR_LEFT, CHROM_TENSOR_RIGHT, CLS_TENSOR] 236 | offset = len(all_pe) # special tokens at the top! 237 | for species in sorted_species_names: 238 | pe_stacked = torch.stack(list(species_to_pe[species].values())) 239 | all_pe.append(pe_stacked) 240 | species_to_offsets[species] = offset 241 | offset += pe_stacked.shape[0] 242 | 243 | all_pe = torch.vstack(all_pe) 244 | print(all_pe.shape) 245 | torch.save(all_pe, f"/dfs/project/uce/all_species_pe_tokens.torch") 246 | with open("/dfs/project/uce/all_species_offsets.pkl", "wb+") as f: 247 | pickle.dump(species_to_offsets, f) 248 | print("Saved PE") 249 | 250 | # Load in already saved! 
251 | if os.path.exists(f"/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_{token_dim}_new.torch"): 252 | dataset_to_protein_embeddings = torch.load(f"/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_{token_dim}_new.torch") 253 | 254 | with open("/lfs/local/0/yanay/dataset_to_chroms_new.pkl", "rb") as f: 255 | dataset_to_chroms = pickle.load(f) 256 | with open("/lfs/local/0/yanay/dataset_to_starts_new.pkl", "rb") as f: 257 | dataset_to_starts = pickle.load(f) 258 | else: 259 | dataset_to_protein_embeddings = {} 260 | dataset_to_chroms = {} 261 | dataset_to_starts = {} 262 | 263 | 264 | # Add the new ones 265 | print("creating reduced size protein embeddings file") 266 | 267 | redo = True 268 | 269 | for dataset, path in tqdm(list(dataset_to_paths.items())): 270 | if dataset in dataset_to_protein_embeddings.keys() and not redo: 271 | continue # skip since already processed 272 | print(dataset) 273 | adata = sc.read(path) 274 | dataset_species = datasets_to_species[dataset] 275 | spec_pe_genes = list(species_to_pe[dataset_species].keys()) 276 | offset = species_to_offsets[dataset_species] 277 | 278 | # Get proper idxs 279 | pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset) 280 | # Add to dicts 281 | dataset_to_chroms[dataset] = dataset_chroms 282 | dataset_to_starts[dataset] = dataset_pos 283 | dataset_to_protein_embeddings[dataset] = pe_row_idxs 284 | 285 | del adata 286 | # Save dicts and idxs 287 | torch.save(dataset_to_protein_embeddings, f"/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_{token_dim}_new.torch") 288 | 289 | with open("/lfs/local/0/yanay/dataset_to_chroms_new.pkl", "wb+") as f: 290 | pickle.dump(dataset_to_chroms, f) 291 | with open("/lfs/local/0/yanay/dataset_to_starts_new.pkl", "wb+") as f: 292 | pickle.dump(dataset_to_starts, f) -------------------------------------------------------------------------------- /data_proc/preproc_many_dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | os.environ["OMP_NUM_THREADS"] = "10" # export OMP_NUM_THREADS=4 3 | os.environ["OPENBLAS_NUM_THREADS"] = "10" # export OPENBLAS_NUM_THREADS=4 4 | os.environ["MKL_NUM_THREADS"] = "10" # export MKL_NUM_THREADS=6 5 | os.environ["VECLIB_MAXIMUM_THREADS"] = "10" # export VECLIB_MAXIMUM_THREADS=4 6 | os.environ["NUMEXPR_NUM_THREADS"] = "10" 7 | 8 | 9 | 10 | from collections import defaultdict 11 | from typing import Dict, List, Optional, Tuple 12 | 13 | import torch 14 | import torch.utils.data as data 15 | import numpy as np 16 | import scanpy as sc 17 | from numpy import array 18 | import subprocess 19 | import os 20 | from tqdm import tqdm 21 | import warnings 22 | warnings.filterwarnings("ignore") 23 | 24 | 25 | from gene_embeddings import load_gene_embeddings_adata 26 | import pandas as pd 27 | import numpy as np 28 | from scanpy import AnnData 29 | from data_utils import process_raw_anndata 30 | 31 | def data_to_torch_X(X): 32 | if isinstance(X, sc.AnnData): 33 | X = X.X 34 | if not isinstance(X, np.ndarray): 35 | X = X.toarray() 36 | return torch.from_numpy(X).float() 37 | 38 | class SincleCellDataset(data.Dataset): 39 | def __init__(self, 40 | expression: torch.tensor, # Subset to hv genes, count data! 
cells x genes 41 | protein_embeddings: torch.tensor, # same order as expression, also subset genes x pe 42 | labels=None, # optional, tensor of labels 43 | covar_vals=None, # tensor of covar values or None 44 | ) -> None: 45 | super(SincleCellDataset, self).__init__() 46 | 47 | # Set expression 48 | self.expression = expression 49 | 50 | row_sums = self.expression.sum(1) # UMI Counts 51 | log_norm_count_adj = torch.log1p(self.expression / row_sums.unsqueeze(1) * torch.tensor(1000)) 52 | 53 | # Set log norm and count adjusted expression 54 | max_vals, max_idx = torch.max(log_norm_count_adj, dim=0) 55 | self.expression_mod = log_norm_count_adj / max_vals 56 | 57 | # Calculate dropout likelihoods of each gene 58 | self.dropout_vec = (self.expression == 0).float().mean(0) # per gene dropout fractions 59 | 60 | # Set data info 61 | self.num_cells = self.expression.shape[0] 62 | self.num_genes = self.expression.shape[1] 63 | 64 | # Set optional label info, including categorical covariate index 65 | self.covar_vals = covar_vals 66 | self.labels = labels 67 | 68 | # Set protein embeddings 69 | self.protein_embeddings = protein_embeddings 70 | 71 | self.item_mode = "expression" 72 | if self.covar_vals is not None: 73 | self.item_mode = "expression+covar" 74 | 75 | 76 | def __getitem__(self, idx): 77 | if self.item_mode == "expression": 78 | if isinstance(idx, int): 79 | if idx < self.num_cells: 80 | return self.expression[idx, :] 81 | else: 82 | raise IndexError 83 | else: 84 | raise NotImplementedError 85 | elif self.item_mode == "expression+covar": 86 | if isinstance(idx, int): 87 | if idx < self.num_cells: 88 | return self.expression[idx, :], self.covar_vals[idx] 89 | else: 90 | raise IndexError 91 | else: 92 | raise NotImplementedError 93 | 94 | 95 | def __len__(self) -> int: 96 | return self.num_cells 97 | 98 | def get_dim(self) -> int: 99 | return self.num_genes 100 | 101 | 102 | def data_to_torch_X(X): 103 | if isinstance(X,
sc.AnnData): 104 | X = X.X 105 | if not isinstance(X, np.ndarray): 106 | X = X.toarray() 107 | return torch.from_numpy(X).float() 108 | 109 | 110 | def anndata_to_sc_dataset(adata:sc.AnnData, 111 | species:str="human", 112 | labels:list=[], 113 | covar_col:str=None, 114 | hv_genes:int=12000, 115 | embedding_model="ESM1b", 116 | ) -> (SincleCellDataset, AnnData): 117 | 118 | # Subset to just genes we have embeddings for 119 | adata, protein_embeddings = load_gene_embeddings_adata( 120 | adata=adata, 121 | species=[species], 122 | embedding_model=embedding_model 123 | ) 124 | 125 | if DO_HVG: 126 | sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hv_genes) # Expects Count Data 127 | 128 | hv_index = adata.var["highly_variable"] 129 | adata = adata[:, hv_index] # Subset to hv genes only 130 | 131 | protein_embeddings = protein_embeddings[species][hv_index] 132 | else: 133 | protein_embeddings = protein_embeddings[species] 134 | expression = data_to_torch_X(adata.X) 135 | 136 | covar_vals = None 137 | if len(labels) > 0: 138 | assert covar_col is None or covar_col in labels, "Covar needs to be in labels" # make sure you keep track of covar column! 
139 | labels = adata.obs.loc[:, labels].values 140 | 141 | if covar_col is not None: 142 | # we have a categorical label to use as covariate 143 | covar_vals = torch.tensor(pd.Categorical(adata.obs[covar_col]).codes) 144 | return SincleCellDataset( 145 | expression=expression, 146 | protein_embeddings=protein_embeddings, 147 | labels=labels, 148 | covar_vals=covar_vals 149 | ), adata 150 | 151 | def proc(args): 152 | datasets_df = pd.read_csv(args.datasets_df) 153 | datasets_df["covar_col"] = np.nan 154 | skip = args.skip 155 | additional_filter = args.filter 156 | DO_HVG = args.DO_HVG 157 | 158 | num_genes = {} 159 | num_cells = {} 160 | 161 | ir = list(datasets_df.iterrows()) 162 | for i, row in tqdm(ir, total=len(datasets_df)): 163 | _, ncells, ngenes = process_raw_anndata(row, args.h5_folder_path, args.npz_folder_path, args.scp, skip, additional_filter, root=getattr(args, "file_root_path", "")) # falls back to "" if the parser defines no --file_root_path 164 | if (ncells is not None) and (ngenes is not None): 165 | num_genes[row["path"]] = ngenes # keyed by the csv "path" column, which is looked up below 166 | num_cells[row["path"]] = ncells 167 | 168 | if "num_cells" not in datasets_df.columns: 169 | datasets_df["num_cells"] = 0 170 | if "num_genes" not in datasets_df.columns: 171 | datasets_df["num_genes"] = 0 172 | for k in num_genes.keys(): 173 | ng = num_genes[k] 174 | nc = num_cells[k] 175 | datasets_df.loc[datasets_df["path"] == k, "num_cells"] = nc 176 | datasets_df.loc[datasets_df["path"] == k, "num_genes"] = ng 177 | # Write the cells and genes info back to the original path 178 | datasets_df.to_csv(args.datasets_df, index=False) 179 | if __name__=="__main__": 180 | # Parse command-line arguments 181 | 182 | parser = argparse.ArgumentParser(description='Preprocess h5ad datasets.') 183 | 184 | # Define command-line arguments 185 | parser.add_argument('--scp', type=str, default="", help='Name of a SNAP server to SCP the results to.
It should have the same folders as the script is already saving to.') 186 | parser.add_argument('--h5_folder_path', type=str, default="/lfs/local/0/yanay/uce_h5s/", help='Folder to save H5s to.') 187 | parser.add_argument('--npz_folder_path', type=str, default="/lfs/local/0/yanay/uce_proc/", help='Folder to save NPZs to.') 188 | 189 | parser.add_argument('--file_root_path', type=str, default="", help='Root folder prepended to each dataset path in the csv.') # assumed default; proc() reads args.file_root_path but the original parser never defined it 190 | parser.add_argument('--datasets_df', type=str, default="/dfs/project/uce/new_perturb_datasets.csv", help='Path to datasets csv. Will be overwritten to have the correct num cells and num genes for each dataset.') 191 | 192 | parser.add_argument('--filter', type=bool, default=True, help='Should you do an additional gene/cell filtering? This can be a good step since even if you have already done it, subsetting to protein embeddings can make some cells sparser.') 193 | parser.add_argument('--skip', type=bool, default=True, help='Should you skip datasets that appear to have already been created in the h5 folder?') 194 | 195 | parser.add_argument('--DO_HVG', type=bool, default=False, help='Should a HVG subset be done.') 196 | 197 | 198 | 199 | args = parser.parse_args() 200 | proc(args) 201 | -------------------------------------------------------------------------------- /eval_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloaders 3 | 4 | """ 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | import sys 9 | sys.path.append('../') 10 | from typing import Dict, List, Optional, Tuple, Any 11 | import torch 12 | import numpy as np 13 | import pickle 14 | import torch.utils.data as data 15 | 16 | 17 | class MultiDatasetSentences(data.Dataset): 18 | def __init__(self, sorted_dataset_names, shapes_dict, args, 19 | dataset_to_protein_embeddings_path= "/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_5120_new.torch", 20 | datasets_to_chroms_path="/lfs/local/0/yanay/dataset_to_chroms_new.pkl", 21 | 
datasets_to_starts_path="/lfs/local/0/yanay/dataset_to_starts_new.pkl", 22 | npzs_dir="/lfs/local/0/yanay/uce_proc/") -> None: 23 | super(MultiDatasetSentences, self).__init__() 24 | # self.xs = {} 25 | self.num_cells = {} 26 | self.num_genes = {} 27 | self.shapes_dict = shapes_dict 28 | self.args = args 29 | 30 | self.total_num_cells = 0 31 | for name in sorted_dataset_names: 32 | num_cells, num_genes = self.shapes_dict[name] 33 | # self.xs[name] = X 34 | self.num_cells[name] = num_cells 35 | self.num_genes[name] = num_genes 36 | 37 | self.total_num_cells += num_cells 38 | 39 | self.datasets = sorted_dataset_names 40 | 41 | # TODO: preferably not hard-coded here 42 | self.dataset_to_protein_embeddings = torch.load(dataset_to_protein_embeddings_path) 43 | with open(datasets_to_chroms_path, "rb") as f: 44 | self.dataset_to_chroms = pickle.load(f) 45 | with open(datasets_to_starts_path, "rb") as f: 46 | self.dataset_to_starts = pickle.load(f) 47 | 48 | self.npzs_dir = npzs_dir 49 | 50 | def __getitem__(self, idx): 51 | if isinstance(idx, int): 52 | for dataset in sorted(self.datasets): 53 | if idx < self.num_cells[dataset]: 54 | #cts = np.memmap(f"/lfs/local/0/yanay/cxg_npzs/" + f"{dataset}_counts.npz", 55 | # dtype='int64', mode='r', shape=self.shapes_dict[dataset]) 56 | cts = np.memmap(self.npzs_dir + f"{dataset}_counts.npz", dtype='int64', mode='r', shape=self.shapes_dict[dataset]) 57 | counts = cts[idx] 58 | counts = torch.tensor(counts).unsqueeze(0) 59 | weights = torch.log1p(counts) 60 | weights = (weights / torch.sum(weights)) 61 | batch_sentences, mask, seq_len, cell_sentences = \ 62 | sample_cell_sentences(counts, weights, dataset, self.args, 63 | dataset_to_protein_embeddings= self.dataset_to_protein_embeddings, 64 | dataset_to_chroms=self.dataset_to_chroms, 65 | dataset_to_starts=self.dataset_to_starts) 66 | return batch_sentences, mask, idx, seq_len, cell_sentences 67 | else: 68 | idx -= self.num_cells[dataset] 69 | raise IndexError 70 | else: 71 | raise 
NotImplementedError 72 | 73 | def __len__(self) -> int: 74 | return self.total_num_cells 75 | 76 | def get_dim(self) -> Dict[str, int]: 77 | return self.num_genes 78 | 79 | 80 | class MultiDatasetSentenceCollator(object): 81 | def __init__(self, args): 82 | self.pad_length = args.pad_length 83 | 84 | 85 | def __call__(self, batch): 86 | batch_size = len(batch) 87 | batch_sentences = torch.zeros((batch_size, self.pad_length)) 88 | mask = torch.zeros((batch_size, self.pad_length)) 89 | cell_sentences = torch.zeros((batch_size, self.pad_length)) 90 | 91 | idxs = torch.zeros(batch_size) 92 | 93 | i = 0 94 | max_len = 0 95 | for bs, msk, idx, seq_len, cs in batch: 96 | batch_sentences[i, :] = bs 97 | cell_sentences[i, :] = cs 98 | max_len = max(max_len, seq_len) 99 | mask[i, :] = msk 100 | idxs[i] = idx 101 | 102 | i += 1 103 | 104 | return batch_sentences[:, :max_len] , mask[:, :max_len], idxs, cell_sentences 105 | 106 | 107 | 108 | def sample_cell_sentences(counts, batch_weights, dataset, args, 109 | dataset_to_protein_embeddings, 110 | dataset_to_chroms, 111 | dataset_to_starts): 112 | 113 | dataset_idxs = dataset_to_protein_embeddings[dataset] # get the dataset specific protein embedding idxs 114 | cell_sentences = torch.zeros((counts.shape[0], args.pad_length)) # init the cell representation as 0s 115 | mask = torch.zeros((counts.shape[0], args.pad_length)) # start of masking the whole sequence 116 | chroms = dataset_to_chroms[dataset] # get the dataset specific chroms for each gene 117 | starts = dataset_to_starts[dataset] # get the dataset specific genomic start locations for each gene 118 | 119 | longest_seq_len = 0 # we need to keep track of this so we can subset the batch at the end 120 | 121 | for c, cell in enumerate(counts): 122 | weights = batch_weights[c].numpy() 123 | weights = weights / sum(weights) # RE NORM after mask 124 | 125 | # randomly choose the genes that will make up the sample, weighted by expression, with replacement 126 | choice_idx = 
np.random.choice(np.arange(len(weights)), 127 | size=args.sample_size, p=weights, 128 | replace=True) 129 | choosen_chrom = chroms[choice_idx] # get the sampled genes chromosomes 130 | # order the genes by chromosome 131 | chrom_sort = np.argsort(choosen_chrom) 132 | choice_idx = choice_idx[chrom_sort] 133 | 134 | # sort the genes by start 135 | new_chrom = chroms[choice_idx] 136 | choosen_starts = starts[choice_idx] 137 | 138 | ordered_choice_idx = np.full((args.pad_length), 139 | args.cls_token_idx) # start with cls 140 | # i= 0 first token is CLS 141 | i = 1 # continue on to the rest of the sequence with left bracket being assumed. 142 | # Shuffle the chroms now, there's no natural order to chromosomes 143 | uq_chroms = np.unique(new_chrom) 144 | np.random.shuffle(uq_chroms) # shuffle 145 | 146 | # This loop is actually just over one cell 147 | for chrom in uq_chroms: 148 | # Open Chrom token 149 | ordered_choice_idx[i] = int(chrom) + args.CHROM_TOKEN_OFFSET # token of this chromosome # i = 1 next token is a chrom open 150 | i += 1 151 | # now sort the genes by start order within the chroms 152 | loc = np.where(new_chrom == chrom)[0] 153 | sort_by_start = np.argsort( 154 | choosen_starts[loc]) # start locations for this chromsome 155 | 156 | to_add = choice_idx[loc[sort_by_start]] 157 | ordered_choice_idx[i:(i + len(to_add))] = dataset_idxs[to_add] 158 | i += len(to_add) 159 | ordered_choice_idx[i] = args.chrom_token_right_idx # add the chrom sep again 160 | i += 1 # add the closing token again 161 | 162 | longest_seq_len = max(longest_seq_len, i) 163 | remainder_len = (args.pad_length - i) 164 | 165 | cell_mask = torch.concat((torch.ones(i), 166 | # pay attention to all of these tokens, ignore the rest! 
167 | torch.zeros(remainder_len))) 168 | 169 | mask[c, :] = cell_mask 170 | 171 | ordered_choice_idx[i:] = args.pad_token_idx # the remainder of the sequence 172 | cell_sentences[c, :] = torch.from_numpy(ordered_choice_idx) 173 | 174 | cell_sentences_pe = cell_sentences.long() # token indices 175 | 176 | return cell_sentences_pe, mask, longest_seq_len, cell_sentences -------------------------------------------------------------------------------- /eval_single_anndata.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for Evaluating a Single AnnData 3 | 4 | Parameters: 5 | ---------- 6 | - `adata_path` (str): 7 | Full path to the AnnData you want to embed. 8 | - `dir` (str): 9 | Working folder where all files will be saved. 10 | - `species` (str): 11 | Species of the AnnData. 12 | - `filter` (bool): 13 | Additional gene/cell filtering on the AnnData. 14 | - `skip` (bool): 15 | Skip datasets that appear to have already been created. 16 | - `model_loc` (str): 17 | Location of pretrained UCE model's weights in a `.torch` file. 18 | - `batch_size` (int): 19 | Batch size for processing. 20 | - `CXG` (bool): 21 | Use CXG model. 22 | - `nlayers` (int): 23 | Number of transformer layers. 24 | - `output_dim` (int): 25 | Desired output dimension. 26 | - `d_hid` (int): 27 | Hidden dimension for processing. 28 | - `token_dim` (int): 29 | Token dimension. 30 | - `spec_chrom_csv_path` (str): 31 | CSV file mapping genes from each species to their respective chromosomes 32 | and genomic start positions. 33 | - `token_file` (str): 34 | `.torch` file containing token/protein embeddings for all tokens. 35 | - `protein_embeddings_dir` (str): 36 | Directory containing protein embedding `.pt` files for all species. 37 | - `offset_pkl_path` (str): 38 | `.pkl` file mapping between species and their gene's locations in the `token_file`. 39 | - `pad_length` (int): 40 | Length to pad the cell sentence to. 
41 | - `pad_token_idx` (int): 42 | Index of the padding token in the `token_file`. 43 | - `chrom_token_left_idx` (int): 44 | Left chromosome token index 45 | - `chrom_token_right_idx` (int): 46 | Right chromosome token index 47 | - `cls_token_idx` (int): 48 | CLS token index in the `token_file`. 49 | - `CHROM_TOKEN_OFFSET` (int): 50 | Offset index, tokens after this mark are chromosome identifiers. 51 | - `sample_size` (int): 52 | Number of genes sampled for cell sentence. 53 | - `multi_gpu` (bool): 54 | Run evaluation on multiple GPUs (using accelerator) 55 | 56 | Returns: 57 | ------- 58 | - `dir/{dataset_name}_proc.h5ad`: 59 | The processed AnnData. Processing involves subsetting it to genes which 60 | have protein embeddings and then refiltering the dataset by minimum counts. 61 | - `dir/{dataset_name}_chroms.pkl`: 62 | File mapping the genes in the dataset to their corresponding chromosome 63 | indices. 64 | - `dir/{dataset_name}_counts.npz`: 65 | File containing the counts of the AnnData in an easily accessible format. 66 | - `dir/{dataset_name}_shapes_dict.pkl`: 67 | File containing the shape (ncell x ngene) of the AnnData, used to read the 68 | `.npz` file. 69 | - `dir/{dataset_name}_pe_idx.torch`: 70 | File mapping between the genes in the dataset and their index in the tokens file. 71 | - `dir/{dataset_name}_starts.pkl`: 72 | File mapping between the genes in the dataset and their genomic start locations. 
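
Example:
-------
A minimal sketch, not part of this script, of how the counts file listed under Returns might be read back for inspection. `load_counts` is a hypothetical helper name. The `int64` dtype and the `{dataset_name: (num_cells, num_genes)}` layout of the shapes dict are taken from how `eval_data.py` and `evaluate.py` use these files; joining paths with `os.path.join` is an assumption made here for portability.

```python
import os
import pickle

import numpy as np


def load_counts(work_dir, dataset_name):
    # Hypothetical helper: lazily open the counts written to the working dir.
    shapes_path = os.path.join(work_dir, f"{dataset_name}_shapes_dict.pkl")
    with open(shapes_path, "rb") as f:
        shapes_dict = pickle.load(f)  # {dataset_name: (num_cells, num_genes)}

    counts_path = os.path.join(work_dir, f"{dataset_name}_counts.npz")
    # Despite the .npz suffix, eval_data.py reads this file as a raw int64
    # memmap, so the shape has to come from the shapes dict.
    return np.memmap(counts_path, dtype="int64", mode="r",
                     shape=shapes_dict[dataset_name])
```

The returned array is read-only and indexed as cells x genes, so `load_counts(work_dir, name)[i]` would give the gene counts of cell `i`.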
73 | 74 | """ 75 | 76 | 77 | import argparse 78 | from evaluate import AnndataProcessor 79 | from accelerate import Accelerator 80 | 81 | def main(args, accelerator): 82 | processor = AnndataProcessor(args, accelerator) 83 | processor.preprocess_anndata() 84 | processor.generate_idxs() 85 | processor.run_evaluation() 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser( 90 | description='Embed a single anndata using UCE.') 91 | 92 | # Anndata Processing Arguments 93 | parser.add_argument('--adata_path', type=str, 94 | default=None, 95 | help='Full path to the anndata you want to embed.') 96 | parser.add_argument('--dir', type=str, 97 | default="./", 98 | help='Working folder where all files will be saved.') 99 | parser.add_argument('--species', type=str, default="human", 100 | help='Species of the anndata.') 101 | parser.add_argument('--filter', type=bool, default=True, 102 | help='Additional gene/cell filtering on the anndata.') 103 | parser.add_argument('--skip', type=bool, default=True, 104 | help='Skip datasets that appear to have already been created.') 105 | 106 | # Model Arguments 107 | parser.add_argument('--model_loc', type=str, 108 | default=None, 109 | help='Location of the model.') 110 | parser.add_argument('--batch_size', type=int, default=25, 111 | help='Batch size.') 112 | parser.add_argument('--pad_length', type=int, default=1536, 113 | help='Length to pad the cell sentence to.') 114 | parser.add_argument("--pad_token_idx", type=int, default=0, 115 | help="PAD token index") 116 | parser.add_argument("--chrom_token_left_idx", type=int, default=1, 117 | help="Chrom token left index") 118 | parser.add_argument("--chrom_token_right_idx", type=int, default=2, 119 | help="Chrom token right index") 120 | parser.add_argument("--cls_token_idx", type=int, default=3, 121 | help="CLS token index") 122 | parser.add_argument("--CHROM_TOKEN_OFFSET", type=int, default=143574, 123 | help="Offset index, tokens after this mark are chromosome identifiers") 124 | 
parser.add_argument('--sample_size', type=int, default=1024, 125 | help='Number of genes sampled for cell sentence') 126 | parser.add_argument('--CXG', type=bool, default=True, 127 | help='Use CXG model.') 128 | parser.add_argument('--nlayers', type=int, default=4, 129 | help='Number of transformer layers.') 130 | parser.add_argument('--output_dim', type=int, default=1280, 131 | help='Output dimension.') 132 | parser.add_argument('--d_hid', type=int, default=5120, 133 | help='Hidden dimension.') 134 | parser.add_argument('--token_dim', type=int, default=5120, 135 | help='Token dimension.') 136 | parser.add_argument('--multi_gpu', type=bool, default=False, 137 | help='Use multiple GPUs') 138 | 139 | # Misc Arguments 140 | parser.add_argument("--spec_chrom_csv_path", 141 | default="./model_files/species_chrom.csv", type=str, 142 | help="CSV Path for species genes to chromosomes and start locations.") 143 | parser.add_argument("--token_file", 144 | default="./model_files/all_tokens.torch", type=str, 145 | help="Path for token embeddings.") 146 | parser.add_argument("--protein_embeddings_dir", 147 | default="./model_files/protein_embeddings/", type=str, 148 | help="Directory where protein embedding .pt files are stored.") 149 | parser.add_argument("--offset_pkl_path", 150 | default="./model_files/species_offsets.pkl", type=str, 151 | help="PKL file which contains offsets for each species.") 152 | 153 | args = parser.parse_args() 154 | accelerator = Accelerator(project_dir=args.dir) 155 | main(args, accelerator) 156 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # os.environ["NCCL_DEBUG"] = "INFO" 4 | os.environ["OMP_NUM_THREADS"] = "12" # export OMP_NUM_THREADS=4 5 | os.environ["OPENBLAS_NUM_THREADS"] = "12" # export OPENBLAS_NUM_THREADS=4 6 | os.environ["MKL_NUM_THREADS"] = "12" # export MKL_NUM_THREADS=6 7 | 
os.environ["VECLIB_MAXIMUM_THREADS"] = "12" # export VECLIB_MAXIMUM_THREADS=4 8 | os.environ["NUMEXPR_NUM_THREADS"] = "12" 9 | 10 | import warnings 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | import scanpy as sc 15 | from tqdm.auto import tqdm 16 | from torch import nn, Tensor 17 | 18 | from model import TransformerModel 19 | from eval_data import MultiDatasetSentences, MultiDatasetSentenceCollator 20 | from utils import figshare_download 21 | 22 | from torch.utils.data import DataLoader 23 | from data_proc.data_utils import adata_path_to_prot_chrom_starts, \ 24 | get_spec_chrom_csv, process_raw_anndata, get_species_to_pe 25 | 26 | import os 27 | import pickle 28 | import pandas as pd 29 | import numpy as np 30 | import torch 31 | 32 | 33 | class AnndataProcessor: 34 | def __init__(self, args, accelerator): 35 | self.args = args 36 | self.accelerator = accelerator 37 | self.h5_folder_path = self.args.dir 38 | self.npz_folder_path = self.args.dir 39 | self.scp = "" 40 | 41 | # Check if paths exist, if not, create them 42 | self.check_paths() 43 | 44 | # Set up the anndata 45 | self.adata_name = self.args.adata_path.split("/")[-1] 46 | self.adata_root_path = self.args.adata_path.replace(self.adata_name, "") 47 | self.name = self.adata_name.replace(".h5ad", "") 48 | self.proc_h5_path = self.h5_folder_path + f"{self.name}_proc.h5ad" 49 | self.adata = None 50 | 51 | # Set up the row 52 | row = pd.Series() 53 | row.path = self.adata_name 54 | row.covar_col = np.nan 55 | row.species = self.args.species 56 | self.row = row 57 | 58 | # Set paths once to be used throughout the class 59 | self.pe_idx_path = self.args.dir + f"{self.name}_pe_idx.torch" 60 | self.chroms_path = self.args.dir + f"{self.name}_chroms.pkl" 61 | self.starts_path = self.args.dir + f"{self.name}_starts.pkl" 62 | self.shapes_dict_path = self.args.dir + f"{self.name}_shapes_dict.pkl" 63 | 64 | def check_paths(self): 65 | """ 66 | Check if the paths exist, if not, create them 67 | """ 68 | 
figshare_download("https://figshare.com/ndownloader/files/42706558", 69 | self.args.spec_chrom_csv_path) 70 | figshare_download("https://figshare.com/ndownloader/files/42706555", 71 | self.args.offset_pkl_path) 72 | if not os.path.exists(self.args.protein_embeddings_dir): 73 | figshare_download("https://figshare.com/ndownloader/files/42715213", 74 | 'model_files/protein_embeddings.tar.gz') 75 | figshare_download("https://figshare.com/ndownloader/files/42706585", 76 | self.args.token_file) 77 | if self.args.adata_path is None: 78 | print("Using sample AnnData: 10k pbmcs dataset") 79 | self.args.adata_path = "./data/10k_pbmcs_proc.h5ad" 80 | figshare_download( 81 | "https://figshare.com/ndownloader/files/42706966", 82 | self.args.adata_path) 83 | if self.args.model_loc is None: 84 | print("Using sample 4 layer model") 85 | self.args.model_loc = "./model_files/4layer_model.torch" 86 | figshare_download( 87 | "https://figshare.com/ndownloader/files/42706576", 88 | self.args.model_loc) 89 | 90 | 91 | def preprocess_anndata(self): 92 | if self.accelerator.is_main_process: 93 | self.adata, num_cells, num_genes = \ 94 | process_raw_anndata(self.row, 95 | self.h5_folder_path, 96 | self.npz_folder_path, 97 | self.scp, 98 | self.args.skip, 99 | self.args.filter, 100 | root=self.adata_root_path) 101 | if (num_cells is not None) and (num_genes is not None): 102 | self.save_shapes_dict(self.name, num_cells, num_genes, 103 | self.shapes_dict_path) 104 | 105 | if self.adata is None: 106 | self.adata = sc.read(self.proc_h5_path) 107 | 108 | def save_shapes_dict(self, name, num_cells, num_genes, shapes_dict_path): 109 | shapes_dict = {name: (num_cells, num_genes)} 110 | with open(shapes_dict_path, "wb+") as f: 111 | pickle.dump(shapes_dict, f) 112 | print("Wrote Shapes Dict") 113 | 114 | def generate_idxs(self): 115 | if self.accelerator.is_main_process: 116 | if os.path.exists(self.pe_idx_path) and \ 117 | os.path.exists(self.chroms_path) and \ 118 | 
os.path.exists(self.starts_path): 119 | print("PE Idx, Chrom and Starts files already created") 120 | 121 | else: 122 | species_to_pe = get_species_to_pe(self.args.protein_embeddings_dir) 123 | with open(self.args.offset_pkl_path, "rb") as f: 124 | species_to_offsets = pickle.load(f) 125 | 126 | gene_to_chrom_pos = get_spec_chrom_csv( 127 | self.args.spec_chrom_csv_path) 128 | dataset_species = self.args.species 129 | spec_pe_genes = list(species_to_pe[dataset_species].keys()) 130 | offset = species_to_offsets[dataset_species] 131 | pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts( 132 | self.adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset) 133 | 134 | # Save to the temp dict 135 | torch.save({self.name: pe_row_idxs}, self.pe_idx_path) 136 | with open(self.chroms_path, "wb+") as f: 137 | pickle.dump({self.name: dataset_chroms}, f) 138 | with open(self.starts_path, "wb+") as f: 139 | pickle.dump({self.name: dataset_pos}, f) 140 | 141 | def run_evaluation(self): 142 | self.accelerator.wait_for_everyone() 143 | with open(self.shapes_dict_path, "rb") as f: 144 | shapes_dict = pickle.load(f) 145 | run_eval(self.adata, self.name, self.pe_idx_path, self.chroms_path, 146 | self.starts_path, shapes_dict, self.accelerator, self.args) 147 | 148 | 149 | def get_ESM2_embeddings(args): 150 | # Load in ESM2 embeddings and special tokens 151 | all_pe = torch.load(args.token_file) 152 | if all_pe.shape[0] == 143574: 153 | torch.manual_seed(23) 154 | CHROM_TENSORS = torch.normal(mean=0, std=1, size=(1895, args.token_dim)) 155 | # 1895 is the total number of chromosome choices, it is hardcoded for now 156 | all_pe = torch.vstack( 157 | (all_pe, CHROM_TENSORS)) # Add the chrom tensors to the end 158 | all_pe.requires_grad = False 159 | 160 | return all_pe 161 | 162 | 163 | def padding_tensor(sequences): 164 | """ 165 | :param sequences: list of tensors 166 | :return: 167 | """ 168 | num = len(sequences) 169 | max_len = max([s.size(0) for s in 
sequences]) 170 | out_dims = (num, max_len, 1280) 171 | 172 | out_tensor = sequences[0].data.new(*out_dims).fill_(0) 173 | out_dims2 = (num, max_len) 174 | 175 | mask = sequences[0].data.new(*out_dims2).fill_(float('-inf')) 176 | for i, tensor in enumerate(sequences): 177 | length = tensor.size(0) 178 | out_tensor[i, :length] = tensor 179 | mask[i, :length] = 1 180 | return out_tensor.permute(1, 0, 2), mask 181 | 182 | 183 | def run_eval(adata, name, pe_idx_path, chroms_path, starts_path, shapes_dict, 184 | accelerator, args): 185 | 186 | #### Set up the model #### 187 | token_dim = args.token_dim 188 | emsize = 1280 # embedding dimension 189 | d_hid = args.d_hid # dimension of the feedforward network model in nn.TransformerEncoder 190 | nlayers = args.nlayers # number of nn.TransformerEncoderLayer in nn.TransformerEncoder 191 | nhead = 20 # number of heads in nn.MultiheadAttention 192 | dropout = 0.05 # dropout probability 193 | model = TransformerModel(token_dim=token_dim, d_model=emsize, nhead=nhead, 194 | d_hid=d_hid, 195 | nlayers=nlayers, dropout=dropout, 196 | output_dim=args.output_dim) 197 | if args.model_loc is None: 198 | raise ValueError("Must provide a model location") 199 | # intialize as empty 200 | empty_pe = torch.zeros(145469, 5120) 201 | empty_pe.requires_grad = False 202 | model.pe_embedding = nn.Embedding.from_pretrained(empty_pe) 203 | model.load_state_dict(torch.load(args.model_loc, map_location="cpu"), 204 | strict=True) 205 | # Load in the real token embeddings 206 | all_pe = get_ESM2_embeddings(args) 207 | # This will make sure that you don't overwrite the tokens in case you're embedding species from the training data 208 | # We avoid doing that just in case the random seeds are different across different versions. 
209 | if all_pe.shape[0] != 145469: 210 | all_pe.requires_grad = False 211 | model.pe_embedding = nn.Embedding.from_pretrained(all_pe) 212 | print(f"Loaded model:\n{args.model_loc}") 213 | model = model.eval() 214 | model = accelerator.prepare(model) 215 | batch_size = args.batch_size 216 | 217 | #### Run the model #### 218 | # Dataloaders 219 | dataset = MultiDatasetSentences(sorted_dataset_names=[name], 220 | shapes_dict=shapes_dict, 221 | args=args, npzs_dir=args.dir, 222 | dataset_to_protein_embeddings_path=pe_idx_path, 223 | datasets_to_chroms_path=chroms_path, 224 | datasets_to_starts_path=starts_path 225 | ) 226 | multi_dataset_sentence_collator = MultiDatasetSentenceCollator(args) 227 | 228 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, 229 | collate_fn=multi_dataset_sentence_collator, 230 | num_workers=0) 231 | dataloader = accelerator.prepare(dataloader) 232 | pbar = tqdm(dataloader, disable=not accelerator.is_local_main_process) 233 | dataset_embeds = [] 234 | with torch.no_grad(): 235 | for batch in pbar: 236 | batch_sentences, mask, idxs = batch[0], batch[1], batch[2] 237 | batch_sentences = batch_sentences.permute(1, 0) 238 | if args.multi_gpu: 239 | batch_sentences = model.module.pe_embedding(batch_sentences.long()) 240 | else: 241 | batch_sentences = model.pe_embedding(batch_sentences.long()) 242 | batch_sentences = nn.functional.normalize(batch_sentences, 243 | dim=2) # Normalize token outputs now 244 | _, embedding = model.forward(batch_sentences, mask=mask) 245 | # Fix for duplicates in last batch 246 | accelerator.wait_for_everyone() 247 | embeddings = accelerator.gather_for_metrics((embedding)) 248 | if accelerator.is_main_process: 249 | dataset_embeds.append(embeddings.detach().cpu().numpy()) 250 | 251 | accelerator.wait_for_everyone() 252 | if accelerator.is_main_process: 253 | dataset_embeds = np.vstack(dataset_embeds) 254 | adata.obsm["X_uce"] = dataset_embeds 255 | write_path = args.dir + f"{name}_uce_adata.h5ad" 
256 | adata.write(write_path) 257 | 258 | print("*****Wrote Anndata to:*****") 259 | print(write_path) 260 | -------------------------------------------------------------------------------- /examples/Label Transfer Using Logistic Classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3f4f1b19-5369-4e4d-9366-b6f07f88b402", 6 | "metadata": {}, 7 | "source": [ 8 | "# Transferring Labels Using UCE\n", 9 | "\n", 10 | "This notebook walks through the example from Figure 4d,4e of transferring labels from mouse kidney norn cells to a human lung disease dataset.\n", 11 | "\n", 12 | "To transfer labels, we use a basic default implementation of sklearn's logistic classifier." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "5ca49083-fd91-473f-b60a-621b07d52de2", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "## Imports\n", 23 | "import scanpy as sc\n", 24 | "import numpy as np\n", 25 | "import random\n", 26 | "from sklearn.linear_model import LogisticRegression\n", 27 | "sc._settings.settings._vector_friendly=True\n", 28 | "import matplotlib\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "## Seed\n", 32 | "np.random.seed(0)\n", 33 | "random.seed(0)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "72536f9e-b010-44a7-b323-c32f05cf7d98", 39 | "metadata": {}, 40 | "source": [ 41 | "## Load in anndatas\n", 42 | "You can download the anndatas here: https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "id": "8e5d7a7e-2a86-4fce-82fe-17afcf83dec5", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "epo_uce = sc.read(\"mouse_kidney_norn.h5ad\")\n", 53 | "kam_20_uce = sc.read(\"human_lung_disease.h5ad\")" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": 
"dcbf764a-ad99-4622-a780-ecd62e471132", 59 | "metadata": {}, 60 | "source": [ 61 | "### Train Classifier on Mouse Kidney Cells\n", 62 | "\n", 63 | "We train a classifier to predict coarsened cell types, from the UCE embeddings" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "id": "9e56e9aa-bf4a-42d9-b8a6-e76833161083", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "epo_map = {\n", 74 | " \"Norn\":\"Norn\",\n", 75 | " \"Proximal tubule\":\"Proximal tubule\",\n", 76 | " \"Collecting duct principal\":\"Collecting duct\",\n", 77 | " \"Distal convoluted tubule\":\"Distal convoluted tubule\",\n", 78 | " \"Fibroblasts\":\"Fibroblast\",\n", 79 | " \"Endothelial\":\"Endothelial\",\n", 80 | " \"Collecting duct transient\":\"Collecting duct\",\n", 81 | " \"Other\":\"misc\",\n", 82 | " \"Pericyte Ren1+\":\"Pericyte\",\n", 83 | " \"Podocytes\":\"Podocyte\",\n", 84 | " \"Pericyte3\":\"Pericyte\",\n", 85 | " \"Pericyte1\":\"Pericyte\",\n", 86 | " \"Pericyte2\":\"Pericyte\",\n", 87 | " \"Collecting duct intercalated\":\"Collecting duct\",\n", 88 | " \"Loop of henle\":\"Loop of henle\",\n", 89 | " \"Proximal tubule2\":\"Proximal tubule\",\n", 90 | " \"Macrophages\":\"Macrophage\",\n", 91 | " \"Neutrophil\":\"Granulocyte\",\n", 92 | " \"T lymphocyte\":\"T cell\",\n", 93 | " \"Collecting duct\":\"Collecting duct\",\n", 94 | " \"Monocytes\":\"Monocyte\",\n", 95 | " \n", 96 | "} # coarse cell type map" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "id": "39b160cd-4462-437f-b89e-7363e20e8ebe", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "epo_uce_no_misc = epo_uce[epo_uce.obs.group != \"Other\"] # remove misc cells\n", 107 | "X = epo_uce_no_misc.obsm[\"X_uce\"] # input is UCE embeddings\n", 108 | "y = [epo_map[ct] for ct in epo_uce_no_misc.obs[\"group\"].values] # output is mapped cell types\n", 109 | "clf = LogisticRegression(random_state=0).fit(X, y) # fit classifier" 110 | ] 
111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "ae2fd8b6-e681-4f02-9f9c-71d71d95f925", 115 | "metadata": {}, 116 | "source": [ 117 | "### Predict norn-like cells using classifier" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 5, 123 | "id": "722bcc74-aaca-43f9-b5dd-2427584d7683", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "kam_20_uce.obs[\"pred\"] = clf.predict(kam_20_uce.obsm[\"X_uce\"]) # predict cell types for lung disease dataset" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 6, 133 | "id": "ae31d2c0-5987-46c1-9be0-d874ae6577b5", 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "pred\n", 140 | "Proximal tubule 119834\n", 141 | "T cell 93556\n", 142 | "Granulocyte 52485\n", 143 | "Collecting duct 15727\n", 144 | "Macrophage 11800\n", 145 | "Endothelial 7233\n", 146 | "Norn 6005\n", 147 | "Podocyte 4270\n", 148 | "Pericyte 1316\n", 149 | "Fibroblast 623\n", 150 | "Monocyte 56\n", 151 | "Loop of henle 23\n", 152 | "Name: count, dtype: int64" 153 | ] 154 | }, 155 | "execution_count": 6, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "kam_20_uce.obs[\"pred\"].value_counts()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "23a839f8-88ce-4446-b4a5-961a38a264b5", 167 | "metadata": {}, 168 | "source": [ 169 | "# Check Differential Expression" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 7, 175 | "id": "159cc172-1be5-4ad2-a4b5-2ebd2e52302e", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "# Preprocess Count Values\n", 180 | "sc.pp.highly_variable_genes(kam_20_uce, n_top_genes=8000, flavor=\"seurat_v3\", subset=True)\n", 181 | "sc.pp.normalize_per_cell(kam_20_uce)\n", 182 | "sc.pp.log1p(kam_20_uce)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 8, 188 | "id":
"57b9fb8a-2699-43f4-b74f-2a646e8610c8", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "# Subset to predicted Norn-like cells\n", 193 | "kam20_norn_ad = kam_20_uce[kam_20_uce.obs.pred == \"Norn\"].copy()" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 9, 199 | "id": "591d1aa7-c4ad-4ad9-9bef-2ed91ed19b57", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "all_de_dfs = {}\n", 204 | "ngenes = 4" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 10, 210 | "id": "bb5bc067-c588-4a34-a8f3-138824531828", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "sc.tl.rank_genes_groups(kam20_norn_ad, groupby=\"Disease_Identity\", use_raw=False, reference=\"Control\") # DE, diseases vs control" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 11, 220 | "id": "db176232-0dee-4c1a-bda2-38a19a7afcdd", 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"COPD\") # get COPD vs control results\n", 225 | "all_de_dfs[\"copd_vs_control\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n", 226 | "copd_control_genes = list(de_df.head(ngenes)[\"names\"].values)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 12, 232 | "id": "c4d25f7a-4056-4e03-bcd1-9ab9cb2705e6", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"IPF\") # get IPF vs control results\n", 237 | "all_de_dfs[\"ipf_vs_control\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n", 238 | "ipf_control_genes = list(de_df.head(ngenes)[\"names\"].values)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 13, 244 | "id": "4e72311d-44ee-4ca5-8abc-24e9079b574b", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 |
"sc.tl.rank_genes_groups(kam20_norn_ad, groupby=\"Disease_Identity\", use_raw=False, reference=\"IPF\") # DE, all vs IPF" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 14, 254 | "id": "69be495e-3b0f-41f3-95d8-47a526f00bbd", 255 | "metadata": {}, 256 | "outputs": [], 257 | "source": [ 258 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"COPD\") # COPD vs IPF\n", 259 | "all_de_dfs[\"copd_vs_ipf\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n", 260 | "copd_ipf_genes = list(de_df.head(ngenes)[\"names\"].values)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 15, 266 | "id": "db9e2358-7431-49df-a980-ff4869d824d4", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "sc.tl.rank_genes_groups(kam20_norn_ad, groupby=\"Disease_Identity\", use_raw=False, reference=\"COPD\") # DE, all vs COPD" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 16, 276 | "id": "753b9891-f940-46b7-9fa2-8972b6f67136", 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"IPF\") # IPF vs COPD\n", 281 | "all_de_dfs[\"ipf_vs_copd\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n", 282 | "ipf_copd_genes = list(de_df.head(ngenes)[\"names\"].values)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 17, 288 | "id": "647407f4-c9ef-405e-9742-342e31664497", 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/plain": [ 294 | "['POSTN',\n", 295 | " 'COL1A1',\n", 296 | " 'COL3A1',\n", 297 | " 'SPARC',\n", 298 | " 'LUM',\n", 299 | " 'MFAP4',\n", 300 | " 'PTGDS',\n", 301 | " 'PTPRG',\n", 302 | " 'GPX3',\n", 303 | " 'NAMPT',\n", 304 | " 'RPL41',\n", 305 | " 'CRISPLD2',\n", 306 | " 'SERPINH1',\n", 307 | " 'COL1A2']" 308 | ] 309 | }, 310 | "execution_count": 17, 311 | "metadata": {}, 312 |
"output_type": "execute_result" 313 | } 314 | ], 315 | "source": [ 316 | "gene_list = ipf_control_genes + copd_control_genes + copd_ipf_genes + ipf_copd_genes\n", 317 | "\n", 318 | "reduced_gene_list = []\n", 319 | "for g in gene_list:\n", 320 | "    if g in reduced_gene_list:\n", 321 | "        continue\n", 322 | "    else:\n", 323 | "        reduced_gene_list.append(g)\n", 324 | "reduced_gene_list" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "id": "c2abdd88-2f7b-4ac4-961a-ea417f85bb04", 330 | "metadata": {}, 331 | "source": [ 332 | "## Plot Results" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": 18, 338 | "id": "f11581f2-2895-419e-85da-0f53c07756ec", 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "data": { 343 | "image/png": "", 344 | "text/plain": [ 345 | "
" 346 | ] 347 | }, 348 | "metadata": {}, 349 | "output_type": "display_data" 350 | } 351 | ], 352 | "source": [ 353 | "fig, ax = plt.subplots(1,1, figsize=(5, 5))\n", 354 | "sc.pl.dotplot(kam20_norn_ad, groupby=\"Disease_Identity\", var_names=reduced_gene_list, show=True, swap_axes=True, ax=ax)" 355 | ] 356 | } 357 | ], 358 | "metadata": { 359 | "kernelspec": { 360 | "display_name": "Python 3 (ipykernel)", 361 | "language": "python", 362 | "name": "python3" 363 | }, 364 | "language_info": { 365 | "codemirror_mode": { 366 | "name": "ipython", 367 | "version": 3 368 | }, 369 | "file_extension": ".py", 370 | "mimetype": "text/x-python", 371 | "name": "python", 372 | "nbconvert_exporter": "python", 373 | "pygments_lexer": "ipython3", 374 | "version": "3.11.7" 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 5 379 | } 380 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model class 3 | 4 | """ 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | import math 9 | from torch import nn, Tensor 10 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 11 | 12 | import sys 13 | sys.path.append('../') 14 | from typing import Any 15 | import torch 16 | 17 | 18 | def full_block(in_features, out_features, p_drop=0.1): 19 | return nn.Sequential( 20 | nn.Linear(in_features, out_features, bias=True), 21 | nn.LayerNorm(out_features), 22 | nn.GELU(), 23 | nn.Dropout(p=p_drop), 24 | ) 25 | 26 | 27 | class PositionalEncoding(nn.Module): 28 | 29 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1536): 30 | super().__init__() 31 | self.dropout = nn.Dropout(p=dropout) 32 | 33 | position = torch.arange(max_len).unsqueeze(1) 34 | div_term = torch.exp \ 35 | (torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 36 | pe = torch.zeros(max_len, 1, d_model) 37 | pe[:, 0, 0::2] = 
torch.sin(position * div_term) 38 | pe[:, 0, 1::2] = torch.cos(position * div_term) 39 | self.register_buffer('pe', pe) 40 | 41 | def forward(self, x: Tensor) -> Tensor: 42 | """ 43 | Args: 44 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 45 | """ 46 | x = x + self.pe[:x.size(0)] 47 | return self.dropout(x) 48 | 49 | 50 | class TransformerModel(nn.Module): 51 | 52 | def __init__(self, token_dim: int, d_model: int, nhead: int, d_hid: int, 53 | nlayers: int, output_dim: int, dropout: float = 0.05): 54 | super().__init__() 55 | self.model_type = 'Transformer' 56 | self.pos_encoder = PositionalEncoding(d_model, dropout) 57 | self.d_model = d_model 58 | 59 | self.encoder = nn.Sequential(nn.Linear(token_dim, d_model), 60 | nn.GELU(), 61 | nn.LayerNorm(d_model)) 62 | 63 | 64 | 65 | encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout) 66 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 67 | 68 | 69 | self.d_model = d_model 70 | self.dropout = dropout 71 | 72 | 73 | self.decoder = nn.Sequential(full_block(d_model, 1024, self.dropout), 74 | full_block(1024, output_dim, self.dropout), 75 | full_block(output_dim, output_dim, self.dropout), 76 | nn.Linear(output_dim, output_dim) 77 | ) 78 | 79 | self.binary_decoder = nn.Sequential( 80 | full_block(output_dim + 1280, 2048, self.dropout), 81 | full_block(2048, 512, self.dropout), 82 | full_block(512, 128, self.dropout), 83 | nn.Linear(128, 1) 84 | ) 85 | 86 | self.gene_embedding_layer = nn.Sequential(nn.Linear(token_dim, d_model), 87 | nn.GELU(), 88 | nn.LayerNorm(d_model)) 89 | 90 | self.pe_embedding = None 91 | 92 | def forward(self, src: Tensor, mask: Tensor): 93 | """ 94 | Args: 95 | src: Tensor, shape [seq_len, batch_size, token_dim] 96 | Returns: 97 | (gene_output, embedding): Tensors of shape [seq_len, batch_size, output_dim] and [batch_size, output_dim] 98 | """ 99 | src = self.encoder(src) * math.sqrt(self.d_model) 100 | src = self.pos_encoder(src) 101 | output = self.transformer_encoder(src, src_key_padding_mask=(1 - mask)) 102 |
gene_output = self.decoder(output) # seq_len x batch x output_dim 103 | # embedding = torch.mul(gene_output, mask.t().unsqueeze(2)).sum(0) # average over non zero genes 104 | # In the new format, the CLS token, which sits at index 0 of the sequence, is the output. 105 | embedding = gene_output[0, :, :] # select only the CLS token. 106 | embedding = nn.functional.normalize(embedding, dim=1) # L2-normalize each cell embedding. 107 | return gene_output, embedding 108 | 109 | 110 | def predict(self, cell_embedding, gene_embeddings): 111 | gene_embeddings = self.gene_embedding_layer(gene_embeddings) 112 | dec = self.binary_decoder \ 113 | (torch.hstack((cell_embedding, gene_embeddings))) 114 | return dec 115 | 116 | -------------------------------------------------------------------------------- /model_files/new_species_protein_embeddings.csv: -------------------------------------------------------------------------------- 1 | species,path 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | scipy==1.14.1 3 | pandas==2.2.2 4 | tqdm==4.66.5 5 | torch==2.1.1 6 | scanpy==1.10.2 7 | accelerate==0.24.0 8 | requests==2.25.1 9 | urllib3==1.26.6 10 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils 3 | 4 | """ 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | import pandas as pd 9 | import numpy as np 10 | import os 11 | import requests 12 | from tqdm import tqdm 13 | import tarfile 14 | 15 | 16 | def get_shapes_dict(dataset_path): 17 | shapes_dict = {} 18 | datasets_df = pd.read_csv(dataset_path) 19 | sorted_dataset_names = sorted(datasets_df["names"]) 20 | 21 | for name in sorted_dataset_names: 22 | shapes_dict[name] = (int(datasets_df.set_index("names").loc[name]["num_cells"]), 8000) 23 | 24 |
shapes_dict["dev_immune_mouse"] = (443697, 4786) 25 | shapes_dict["dev_immune_human"] = (34009, 5566) 26 | shapes_dict["intestinal_tract_human"] = (69668, 5192) 27 | shapes_dict["gtex_human"] = (18511, 7109) 28 | shapes_dict["gut_endoderm_mouse"] = (113043, 6806) 29 | shapes_dict["luca"] = (249591, 7196) 30 | shapes_dict.update({ 31 | "madissoon_novel_lung":(190728, 8000), 32 | 'flores_cerebellum_human': (20232, 8000), 33 | 'osuch_gut_human': (272310, 8000), 34 | 'msk_ovarian_human': (929690, 8000), 35 | 'htan_vmuc_dis_epi_human': (65084, 8000), 36 | 'htan_vmuc_val_epi_human': (57564, 8000), 37 | 'htan_vmuc_non_epi_human': (9099, 8000), 38 | 'hao_pbmc_3p_human': (161764, 8000), 39 | 'hao_pbmc_5p_human': (49147, 8000), 40 | 'gao_tumors_human': (36111, 8000), 41 | 'swabrick_breast_human': (92427, 8000), 42 | 'wu_cryo_tumors_human': (105662, 8000), 43 | 'cell_line_het_human': (53513, 8000), 44 | 'bi_allen_metastasis_human': (27787, 8000), 45 | 'zheng68k_human': (68579, 8000), 46 | 'zheng68k_12k_human': (68579, 12000), 47 | 'mouse_embryo_ct': (153597, 12000), 48 | "regev_gtex_heart": (36574, 8000), 49 | "tabula_sapiens_heart": (11505, 8000), 50 | "10k_pbmcs":(11990, 12000), 51 | "epo_ido":(35834,12000), 52 | 'tabula_sapiens_kidney': (9641, 8000), 53 | 'tabula_microcebus_kidney': (14592, 8000), 54 | 'tabula_muris_kidney': (2781, 8000), 55 | 'tabula_muris_senis_kidney': (19610, 8000), 56 | 'immune_human': (33506, 8000) 57 | }) 58 | 59 | shapes_dict["zyl_sanes_glaucoma_pig"] = (5901, 6819) 60 | shapes_dict["parkinsons_macaF"] = (1062, 5103) 61 | 62 | for row in datasets_df.iterrows(): 63 | ngenes = row[1].num_genes 64 | ncells = row[1].num_cells 65 | name = row[1].names 66 | if not np.isnan(ngenes): 67 | shapes_dict[name] = (int(ncells), int(ngenes)) 68 | 69 | return shapes_dict 70 | 71 | 72 | def figshare_download(url, save_path): 73 | """ 74 | Figshare download helper with progress bar 75 | 76 | Args: 77 | url (str): the URL of the dataset 78 | save_path (str): the path to
save the dataset 79 | """ 80 | 81 | if os.path.exists(save_path): 82 | return 83 | else: 84 | # Create the parent directory if it does not exist 85 | if not os.path.exists(os.path.dirname(save_path)): 86 | os.makedirs(os.path.dirname(save_path)) 87 | print("Downloading " + save_path + " from " + url + " ..." + "\n") 88 | response = requests.get(url, stream=True) 89 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 90 | block_size = 1024 91 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', 92 | unit_scale=True) 93 | with open(save_path, 'wb') as file: 94 | for data in response.iter_content(block_size): 95 | progress_bar.update(len(data)) 96 | file.write(data) 97 | progress_bar.close() 98 | 99 | # If the downloaded filename ends in tar.gz then extract it 100 | if save_path.endswith(".tar.gz"): 101 | with tarfile.open(save_path) as tar: 102 | tar.extractall(path=os.path.dirname(save_path)) 103 | print("Done!") 104 | --------------------------------------------------------------------------------
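The extraction step at the end of `figshare_download` in `utils.py` can be tried without any network access. The following is a minimal sketch using only the standard library; `extract_if_tarball` and `demo` are hypothetical names introduced here for illustration and are not part of the repository. One deliberate difference from the original is labeled in the code: the sketch falls back to the current directory when `save_path` has no directory component, since `os.path.dirname` returns an empty string in that case.

```python
import os
import tarfile
import tempfile


def extract_if_tarball(save_path):
    """Hypothetical helper: unpack save_path next to itself if it ends in .tar.gz.

    Returns the directory that received the files, or None if nothing was extracted.
    """
    if not save_path.endswith(".tar.gz"):
        return None
    # Assumption (not in utils.py): dirname is "" for a bare filename,
    # so fall back to the current directory.
    target_dir = os.path.dirname(save_path) or "."
    with tarfile.open(save_path) as tar:
        tar.extractall(path=target_dir)
    return target_dir


def demo():
    """Build a tiny archive in a temporary directory and unpack it."""
    with tempfile.TemporaryDirectory() as tmp:
        # Create a placeholder file to stand in for downloaded model files.
        payload = os.path.join(tmp, "weights.txt")
        with open(payload, "w") as fh:
            fh.write("placeholder")

        # Pack it into a .tar.gz under a subdirectory name.
        archive = os.path.join(tmp, "model_files.tar.gz")
        with tarfile.open(archive, "w:gz") as tar:
            tar.add(payload, arcname="unpacked/weights.txt")

        # Unpack and report whether the member landed beside the archive.
        out_dir = extract_if_tarball(archive)
        return os.path.exists(os.path.join(out_dir, "unpacked", "weights.txt"))
```

Calling `demo()` exercises the same `tarfile.open(...)` / `extractall(path=...)` pattern used above; paths that do not end in `.tar.gz` are left untouched.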