├── LICENSE
├── README.md
├── data_proc
│   ├── Create New Species Files.ipynb
│   ├── __pycache__
│   │   ├── data_utils.cpython-38.pyc
│   │   └── gene_embeddings.cpython-38.pyc
│   ├── data_utils.py
│   ├── download_proc_czi_cxg.py
│   ├── gene_embeddings.py
│   ├── generate_reduced_chrom_files.py
│   └── preproc_many_dataset.py
├── eval_data.py
├── eval_single_anndata.py
├── evaluate.py
├── examples
│   ├── Benchmark Embeddings with scIB.ipynb
│   └── Label Transfer Using Logistic Classifier.ipynb
├── model.py
├── model_files
│   └── new_species_protein_embeddings.csv
├── requirements.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yanay Rosen, Yusuf Roohani, Jure Leskovec
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Universal Cell Embeddings
2 |
3 | This repo includes a PyTorch [HuggingFace Accelerator](https://huggingface.co/docs/accelerate/package_reference/accelerator) implementation of the UCE model, to be used to embed individual anndata datasets.
4 |
5 | ## Installation
6 |
7 | ```
8 | pip install -r requirements.txt
9 | ```
10 |
11 | ## Embedding a new dataset
12 |
13 | To generate an embedding for a new single-cell RNA sequencing dataset in the AnnData format, use the `eval_single_anndata.py` script.
14 |
15 | ```
16 | python eval_single_anndata.py --adata_path {path_to_anndata} --dir {output_dir} --species {species} --model_loc {model_loc} --batch_size {batch_size}
17 | ```
18 |
19 | where
20 | - `adata_path`: an h5ad file. The `.X` slot of the file should contain scRNA-seq counts. The `.var_names` slot should contain gene names (symbols), *not Ensembl IDs*.
21 | - `dir`: the working directory where intermediate and final output files are saved. Existing files are reused, which skips repeated processing of the same dataset.
22 | - `species`: the species of the dataset you are embedding.
23 | - `model_loc`: the location of the model weights `.torch` file.
24 | - `batch_size`: the per-GPU batch size. For the 33-layer model on an 80GB GPU, use 25. For the 4-layer model on the same GPU, you can use 100.
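
Because `.var_names` must hold gene names rather than Ensembl IDs, a quick check before embedding can catch a mislabeled dataset. The sketch below is not part of this repo: the helper name `find_ensembl_like_names` and the regular expression are assumptions, and the pattern is only a heuristic for the usual Ensembl gene ID layout (`ENS`, an optional species code, `G`, 11 digits, an optional version suffix).

```python
import re

# Heuristic pattern for Ensembl gene IDs such as ENSG00000141510 or
# ENSGALG00010000007.1 (assumption: "ENS" + optional species code + "G"
# + 11 digits + optional ".version").
ENSEMBL_GENE_ID = re.compile(r"^ENS[A-Z]{0,6}G\d{11}(\.\d+)?$")


def find_ensembl_like_names(var_names):
    """Return the entries of var_names that look like Ensembl gene IDs."""
    return [name for name in var_names if ENSEMBL_GENE_ID.match(name)]


# Hypothetical var_names; in practice pass list(adata.var_names).
example_var_names = ["ND1", "NCAM2", "ENS-2", "ENSGALG00010000007.1", "ENSG00000141510"]
suspicious = find_ensembl_like_names(example_var_names)
print(suspicious)
```

If the returned list is non-empty, the dataset probably needs its IDs mapped to gene symbols before running `eval_single_anndata.py`. Note that a real gene symbol such as `ENS-2` is not flagged, since it does not follow the ID layout.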
25 |
26 | For a sample output on the 10k PBMC dataset, run:
27 | ```
28 | python eval_single_anndata.py
29 | ```
30 | All necessary model files will be downloaded automatically.
31 |
32 |
33 | **Note**: This script makes use of additional files, which are described in the code documentation. These are downloaded automatically unless already present in the working directory. The script defaults to the pretrained 4-layer model. For running the pretrained 33-layer model from the paper, please download using this [link](https://figshare.com/articles/dataset/Universal_Cell_Embedding_Model_Files/24320806?file=43423236) and set `--nlayers 33`.
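
When scripting several runs, it can help to assemble the command programmatically. The following is a minimal sketch, assuming only the flags documented in this README (`--adata_path`, `--dir`, `--species`, `--model_loc`, `--batch_size`, `--nlayers`). The helper `build_uce_command` and all file paths are hypothetical and not part of the repo.

```python
import shlex


def build_uce_command(adata_path, out_dir, species, model_loc, batch_size, nlayers=None):
    """Assemble the eval_single_anndata.py invocation as a list of arguments."""
    cmd = [
        "python", "eval_single_anndata.py",
        "--adata_path", adata_path,
        "--dir", out_dir,
        "--species", species,
        "--model_loc", model_loc,
        "--batch_size", str(batch_size),
    ]
    # Only pass --nlayers when overriding the script's default (the 4-layer model).
    if nlayers is not None:
        cmd += ["--nlayers", str(nlayers)]
    return cmd


# Hypothetical paths; batch size 25 follows the 33-layer guidance for an 80GB GPU.
cmd = build_uce_command("data/my_cells.h5ad", "out/", "human",
                        "model_files/33l.torch", 25, nlayers=33)
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)`; keeping it as a list avoids shell-quoting problems with paths that contain spaces.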
34 |
35 | ## Output
36 |
37 | Final evaluated AnnData: `dir/{dataset_name}.h5ad`. This AnnData will be
38 | identical to the processed input AnnData, but with UCE embeddings added in the `.obsm["X_uce"]` slot.
39 |
40 | Please see documentation for information on additional output files. All
41 | outputs from `eval_single_anndata.py` are stored in the `dir` directory.
42 |
43 | ## Data
44 |
45 | You can download the processed datasets used in the paper [here](https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7?usp=drive_link).
46 |
47 | **Note:** These datasets were embedded using the 33 layer model. Embeddings for the 33 layer model are not compatible with embeddings from the 4 layer model.
48 |
49 | ## Citing
50 |
51 | If you find our paper and code useful, please consider citing the [preprint](https://www.biorxiv.org/content/10.1101/2023.11.28.568918v1):
52 |
53 | ```
54 | @article{rosen2023universal,
55 | title={Universal Cell Embeddings: A Foundation Model for Cell Biology},
56 | author={Rosen, Yanay and Roohani, Yusuf and Agrawal, Ayush and Samotorcan, Leon and Consortium, Tabula Sapiens and Quake, Stephen R and Leskovec, Jure},
57 | journal={bioRxiv},
58 | pages={2023--11},
59 | year={2023},
60 | publisher={Cold Spring Harbor Laboratory}
61 | }
62 | ```
63 |
--------------------------------------------------------------------------------
/data_proc/Create New Species Files.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "0e4018ee",
6 | "metadata": {},
7 | "source": [
8 | "# Embedding Novel Species\n",
9 | "\n",
10 | "This notebook will create the files you need to embed a novel species that wasn't included in the training data.\n",
11 | "\n",
12 | "To start, you will need to download the ESM2 protein embeddings and the reference proteome for the species.\n",
13 | "\n",
14 | "You can find precalculated ESM2 protein embeddings for many species [here](https://drive.google.com/drive/folders/1_Dz7HS5N3GoOAG6MdhsXWY1nwLoN13DJ?usp=drive_link)\n",
15 | "\n",
16 | "For reference proteomes, you can download them from [here](https://useast.ensembl.org/info/about/species.html).\n",
17 | "\n",
18 | "If there is no protein embedding for the species you are interested in, you can request to have it made via Github or email, or you can create it yourself following instructions [here](https://github.com/snap-stanford/SATURN/tree/main/protein_embeddings)."
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 1,
24 | "id": "ab368d92",
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "import numpy as np\n",
29 | "import pickle as pkl\n",
30 | "import pandas as pd"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 2,
36 | "id": "c9a306f3",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "SPECIES_NAME = \"chicken\" # short hand name for this species, will be used in arguments and files\n",
41 | "\n",
42 | "# Path to the species proteome\n",
43 | "SPECIES_PROTEIN_FASTA_PATH = \"../../../SATURN/protein_embeddings/data/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.fa\"\n",
44 | "\n",
45 | "# Path to the ESM2 Embeddings\n",
46 | "SPECIES_PROTEIN_EMBEDDINGS_PATH = \"../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt\"\n",
47 | "\n",
48 | "# primary_assembly name, this needs to be matched to the FASTA file\n",
49 | "ASSEMBLY_NAME = \"bGalGal1.mat.broiler.GRCg7b\"\n",
50 | "# NCBI Taxonomy ID, please set this so that if someone else also embeds the same species,\n",
51 | "# randomly generated chromosome tokens will be the same\n",
52 | "TAXONOMY_ID = 9031"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "id": "e5d37e52",
58 | "metadata": {},
59 | "source": [
60 | "You can view the FASTA format here. Please confirm that the primary_assembly name is correct."
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 3,
66 | "id": "2ecf1464",
67 | "metadata": {},
68 | "outputs": [
69 | {
70 | "name": "stdout",
71 | "output_type": "stream",
72 | "text": [
73 | ">ENSGALP00010000002.1 pep primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:2824:3798:1 gene:ENSGALG00010000007.1 transcript:ENSGALT00010000007.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND1 description:NADH dehydrogenase subunit 1 [Source:NCBI gene (formerly Entrezgene);Acc:63549479]\r\n",
74 | "MTLPTLTNLLIMTLSYILPILIAVAFLTLVERKILSYMQARKGPNIVGPFGLLQPVADGV\r\n",
75 | "KLFIKEPIRPSTSSPFLFIITPILALLLALTIWVPLPLPFPLADLNLGLLFLLAMSSLTV\r\n",
76 | "YSLLWSGWASNSKYALIGALRAVAQTISYEVTLAIILLSTIMLSGNYTLSTLAITQEPIY\r\n",
77 | "LIFSAWPLAMMWYISTLAETNRAPFDLTEGESELVSGFNVEYAAGPFAMFFLAEYANIML\r\n",
78 | "MNTLTTVLFLNPSFLNLPPELFPIALATKTLLLSSSFLWIRASYPRFRYDQLMHLLWKNF\r\n",
79 | "LPLTLALCLWHTSMPISYAGLPPI\r\n",
80 | ">ENSGALP00010000003.1 pep primary_assembly:bGalGal1.mat.broiler.GRCg7b:MT:4015:5053:1 gene:ENSGALG00010000011.1 transcript:ENSGALT00010000011.1 gene_biotype:protein_coding transcript_biotype:protein_coding gene_symbol:ND2 description:NADH dehydrogenase subunit 2 [Source:NCBI gene (formerly Entrezgene);Acc:63549482]\r\n",
81 | "MNPHAKLICTVSLIMGTSITISSNHWILAWTGLEINTLAIIPLISKSHHPRAIEATIKYF\r\n",
82 | "LTQSTASALILFSSMTNAWSTGQWDITQLNHPTSCLMLTMAIAIKLGLVPFHFWFPEVLQ\r\n"
83 | ]
84 | }
85 | ],
86 | "source": [
87 | "!head {SPECIES_PROTEIN_FASTA_PATH}"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 4,
93 | "id": "90540d0b",
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "species_to_paths = {\n",
98 | "    SPECIES_NAME: SPECIES_PROTEIN_FASTA_PATH,\n",
99 | "}\n",
100 | "\n",
101 | "species_to_ids = {\n",
102 | "    SPECIES_NAME: ASSEMBLY_NAME,\n",
103 | "}"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 5,
109 | "id": "623b99cf",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "all_pos_def = []\n",
114 | "\n",
115 | "missing_genes = {}\n",
116 | "for species in species_to_ids.keys():\n",
117 | "    missing_genes[species] = []\n",
118 | "    proteome_path = species_to_paths[species]\n",
119 | "    species_id = species_to_ids[species]\n",
120 | "\n",
121 | "    with open(proteome_path) as f:\n",
122 | "        proteome_lines = f.readlines()\n",
123 | "\n",
124 | "    gene_symbol_to_location = {}\n",
125 | "    gene_symbol_to_chrom = {}\n",
126 | "\n",
127 | "    for line in proteome_lines:\n",
128 | "        if line.startswith(\">\"):\n",
129 | "            split_line = line.split()\n",
130 | "            gene_symbol = [token for token in split_line if token.startswith(\"gene_symbol\")]\n",
131 | "            if len(gene_symbol) > 0:\n",
132 | "                gene_symbol = gene_symbol[0].split(\":\")\n",
133 | "\n",
134 | "                if len(gene_symbol) == 2:\n",
135 | "                    gene_symbol = gene_symbol[1]\n",
136 | "                elif len(gene_symbol) > 2:\n",
137 | "                    gene_symbol = \":\".join(gene_symbol[1:]) # fix for zebrafish gene names with colons in them\n",
138 | "                else:\n",
139 | "                    raise ValueError(f\"Malformed gene_symbol field in header: {line.strip()}\") # unexpected header format\n",
140 | "\n",
141 | "\n",
142 | "                chrom = None\n",
143 | "\n",
144 | "                chrom_arr = [token for token in split_line if token.startswith(\"chromosome:\")]\n",
145 | "                if len(chrom_arr) > 0:\n",
146 | "                    chrom = chrom_arr[0].replace(\"chromosome:\", \"\")\n",
147 | "                else:\n",
148 | "                    chrom_arr = [token for token in split_line if token.startswith(\"primary_assembly:\")]\n",
149 | "                    if len(chrom_arr) > 0:\n",
150 | "                        chrom = chrom_arr[0].replace(\"primary_assembly:\", \"\")\n",
151 | "                    else:\n",
152 | "                        chrom_arr = [token for token in split_line if token.startswith(\"scaffold:\")]\n",
153 | "                        if len(chrom_arr) > 0:\n",
154 | "                            chrom = chrom_arr[0].replace(\"scaffold:\", \"\")\n",
155 | "                if chrom is not None:\n",
156 | "                    gene_symbol_to_location[gene_symbol] = chrom.split(\":\")[2] # start position\n",
157 | "                    gene_symbol_to_chrom[gene_symbol] = chrom.split(\":\")[1] # chromosome or scaffold name\n",
158 | "                else:\n",
159 | "                    missing_genes[species].append(gene_symbol)\n",
160 | "\n",
161 | "\n",
162 | "    positional_df = pd.DataFrame()\n",
163 | "    positional_df[\"gene_symbol\"] = [gn.upper() for gn in list(gene_symbol_to_chrom.keys())]\n",
164 | "    positional_df[\"chromosome\"] = list(gene_symbol_to_chrom.values())\n",
165 | "    positional_df[\"start\"] = list(gene_symbol_to_location.values())\n",
166 | "    positional_df = positional_df.sort_values([\"chromosome\", \"start\"])\n",
167 | "    #positional_df = positional_df.set_index(\"gene_symbol\")\n",
168 | "    positional_df[\"species\"] = species\n",
169 | "    all_pos_def.append(positional_df)"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 6,
175 | "id": "b72887b3",
176 | "metadata": {},
177 | "outputs": [
178 | {
179 | "data": {
180 | "text/html": [
181 | "<div>\n",
182 | "<table border=\"1\" class=\"dataframe\">\n",
183 | "  <thead>\n",
184 | "    <tr><th></th><th>gene_symbol</th><th>chromosome</th><th>start</th><th>species</th></tr>\n",
185 | "  </thead>\n",
186 | "  <tbody>\n",
187 | "    <tr><th>2327</th><td>GCC1</td><td>1</td><td>1006145</td><td>chicken</td></tr>\n",
188 | "    <tr><th>2502</th><td>NCAM2</td><td>1</td><td>100828671</td><td>chicken</td></tr>\n",
189 | "    <tr><th>3084</th><td>ENS-2</td><td>1</td><td>101147482</td><td>chicken</td></tr>\n",
190 | "    <tr><th>2331</th><td>DENND6B</td><td>1</td><td>1012031</td><td>chicken</td></tr>\n",
191 | "    <tr><th>3973</th><td>MRPL39</td><td>1</td><td>102578362</td><td>chicken</td></tr>\n",
192 | "    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
193 | "    <tr><th>4722</th><td>CA9</td><td>Z</td><td>9779343</td><td>chicken</td></tr>\n",
194 | "    <tr><th>4738</th><td>ARHGEF39</td><td>Z</td><td>9835547</td><td>chicken</td></tr>\n",
195 | "    <tr><th>3885</th><td>MRPL17</td><td>Z</td><td>9850679</td><td>chicken</td></tr>\n",
196 | "    <tr><th>4172</th><td>CCBE1</td><td>Z</td><td>9852827</td><td>chicken</td></tr>\n",
197 | "    <tr><th>3293</th><td>PMAIP1</td><td>Z</td><td>9998272</td><td>chicken</td></tr>\n",
198 | "  </tbody>\n",
199 | "</table>\n",
200 | "<p>13271 rows × 4 columns</p>\n",
201 | "</div>"
287 | ],
288 | "text/plain": [
289 | " gene_symbol chromosome start species\n",
290 | "2327 GCC1 1 1006145 chicken\n",
291 | "2502 NCAM2 1 100828671 chicken\n",
292 | "3084 ENS-2 1 101147482 chicken\n",
293 | "2331 DENND6B 1 1012031 chicken\n",
294 | "3973 MRPL39 1 102578362 chicken\n",
295 | "... ... ... ... ...\n",
296 | "4722 CA9 Z 9779343 chicken\n",
297 | "4738 ARHGEF39 Z 9835547 chicken\n",
298 | "3885 MRPL17 Z 9850679 chicken\n",
299 | "4172 CCBE1 Z 9852827 chicken\n",
300 | "3293 PMAIP1 Z 9998272 chicken\n",
301 | "\n",
302 | "[13271 rows x 4 columns]"
303 | ]
304 | },
305 | "execution_count": 6,
306 | "metadata": {},
307 | "output_type": "execute_result"
308 | }
309 | ],
310 | "source": [
311 | "master_pos_def = pd.concat(all_pos_def)\n",
312 | "master_pos_def"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 7,
318 | "id": "6d9dac28",
319 | "metadata": {},
320 | "outputs": [
321 | {
322 | "data": {
323 | "text/plain": [
324 | "chicken 13271\n",
325 | "Name: species, dtype: int64"
326 | ]
327 | },
328 | "execution_count": 7,
329 | "metadata": {},
330 | "output_type": "execute_result"
331 | }
332 | ],
333 | "source": [
334 | "master_pos_def[\"species\"].value_counts() # double check how many genes are mapped"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": 8,
340 | "id": "4a3d45c2",
341 | "metadata": {},
342 | "outputs": [
343 | {
344 | "name": "stdout",
345 | "output_type": "stream",
346 | "text": [
347 | "chicken: 0\n"
348 | ]
349 | }
350 | ],
351 | "source": [
352 | "for k, v in missing_genes.items():\n",
353 | "    print(f\"{k}: {len(v)}\") # are any genes missing?"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 9,
359 | "id": "c59774b1",
360 | "metadata": {
361 | "scrolled": true
362 | },
363 | "outputs": [
364 | {
365 | "name": "stdout",
366 | "output_type": "stream",
367 | "text": [
368 | "*********\n",
369 | "chicken\n"
370 | ]
371 | },
372 | {
373 | "data": {
374 | "text/plain": [
375 | "1 1785\n",
376 | "2 1169\n",
377 | "3 1067\n",
378 | "4 953\n",
379 | "5 817\n",
380 | "Z 629\n",
381 | "6 458\n",
382 | "8 450\n",
383 | "7 442\n",
384 | "9 382\n",
385 | "10 366\n",
386 | "14 359\n",
387 | "11 327\n",
388 | "15 326\n",
389 | "13 306\n",
390 | "20 298\n",
391 | "12 293\n",
392 | "19 278\n",
393 | "18 274\n",
394 | "17 260\n",
395 | "26 237\n",
396 | "28 237\n",
397 | "27 235\n",
398 | "21 226\n",
399 | "23 214\n",
400 | "25 176\n",
401 | "34 155\n",
402 | "24 149\n",
403 | "22 142\n",
404 | "16 54\n",
405 | "30 52\n",
406 | "38 49\n",
407 | "31 14\n",
408 | "MT 13\n",
409 | "39 10\n",
410 | "JAENSK010000484.1 7\n",
411 | "35 6\n",
412 | "JAENSK010000592.1 6\n",
413 | "W 5\n",
414 | "MU179278.1 5\n",
415 | "MU179279.1 4\n",
416 | "36 3\n",
417 | "JAENSK010000483.1 3\n",
418 | "JAENSK010000585.1 3\n",
419 | "JAENSK010000593.1 2\n",
420 | "MU179258.1 2\n",
421 | "MU179272.1 2\n",
422 | "MU179273.1 2\n",
423 | "JAENSK010000584.1 2\n",
424 | "JAENSK010000656.1 1\n",
425 | "Name: chromosome, dtype: int64"
426 | ]
427 | },
428 | "metadata": {},
429 | "output_type": "display_data"
430 | },
431 | {
432 | "name": "stdout",
433 | "output_type": "stream",
434 | "text": [
435 | "*********\n"
436 | ]
437 | }
438 | ],
439 | "source": [
440 | "# Count genes per chromosome\n",
441 | "for species in species_to_ids.keys():\n",
442 | "    print(\"*********\")\n",
443 | "    print(species)\n",
444 | "    display(master_pos_def[master_pos_def[\"species\"] == species][\"chromosome\"].value_counts().head(50))\n",
445 | "    print(\"*********\")"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 10,
451 | "id": "541baded",
452 | "metadata": {},
453 | "outputs": [],
454 | "source": [
455 | "master_pos_def.to_csv(f\"{SPECIES_NAME}_to_chrom_pos.csv\", index=False) # Save the DF"
456 | ]
457 | },
458 | {
459 | "cell_type": "code",
460 | "execution_count": 11,
461 | "id": "eabd0e31",
462 | "metadata": {},
463 | "outputs": [
464 | {
465 | "name": "stdout",
466 | "output_type": "stream",
467 | "text": [
468 | "chicken_to_chrom_pos.csv\n"
469 | ]
470 | }
471 | ],
472 | "source": [
473 | "# The chromosome file path will be:\n",
474 | "print(f\"{SPECIES_NAME}_to_chrom_pos.csv\")"
475 | ]
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": 12,
480 | "id": "fe1345b1",
481 | "metadata": {},
482 | "outputs": [
483 | {
484 | "data": {
485 | "text/plain": [
486 | "66"
487 | ]
488 | },
489 | "execution_count": 12,
490 | "metadata": {},
491 | "output_type": "execute_result"
492 | }
493 | ],
494 | "source": [
495 | "N_UNIQ_CHROM = len(master_pos_def[master_pos_def[\"species\"] == species][\"chromosome\"].unique())\n",
496 | "N_UNIQ_CHROM"
497 | ]
498 | },
499 | {
500 | "cell_type": "markdown",
501 | "id": "e37e277f",
502 | "metadata": {},
503 | "source": [
504 | "# Generate token file"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": 13,
510 | "id": "d6904975",
511 | "metadata": {},
512 | "outputs": [],
513 | "source": [
514 | "import torch\n",
515 | "import pickle\n",
516 | "token_dim = 5120"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "id": "a2798848",
522 | "metadata": {},
523 | "source": [
524 | "This will create the token file. Please note the offset value."
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": 14,
530 | "id": "4355dabd",
531 | "metadata": {},
532 | "outputs": [
533 | {
534 | "name": "stdout",
535 | "output_type": "stream",
536 | "text": [
537 | "CHROM_TOKEN_OFFSET: 13275\n",
538 | "Saved PE, offsets file\n"
539 | ]
540 | }
541 | ],
542 | "source": [
543 | "species_to_offsets = {}\n",
544 | "\n",
545 | "all_pe = torch.load(\"../model_files/all_tokens.torch\")[0:4] # read in existing token file to make sure \n",
546 | "# that special vocab tokens are the same for different seeds\n",
547 | "\n",
548 | "offset = len(all_pe) # special tokens at the top!\n",
549 | "\n",
550 | "PE = torch.load(SPECIES_PROTEIN_EMBEDDINGS_PATH)\n",
551 | "\n",
552 | "pe_stacked = torch.stack(list(PE.values()))\n",
553 | "all_pe = torch.vstack((all_pe, pe_stacked))\n",
554 | "species_to_offsets[species] = offset\n",
555 | "\n",
556 | "print(\"CHROM_TOKEN_OFFSET:\", all_pe.shape[0])\n",
557 | "torch.manual_seed(TAXONOMY_ID)\n",
558 | "CHROM_TENSORS = torch.normal(mean=0, std=1, size=(N_UNIQ_CHROM, 5120)) \n",
559 | "# N_UNIQ_CHROM is the total number of chromosome choices, it is hardcoded for now (for species in the training data)\n",
560 | "all_pe = torch.vstack(\n",
561 | " (all_pe, CHROM_TENSORS)) # Add the chrom tensors to the end\n",
562 | "all_pe.requires_grad = False\n",
563 | "\n",
564 | "\n",
565 | "torch.save(all_pe, f\"{SPECIES_NAME}_pe_tokens.torch\")\n",
566 | "\n",
567 | "with open(f\"{SPECIES_NAME}_offsets.pkl\", \"wb+\") as f:\n",
568 | "    pickle.dump(species_to_offsets, f)\n",
569 | "print(\"Saved PE, offsets file\")"
570 | ]
571 | },
572 | {
573 | "cell_type": "code",
574 | "execution_count": 15,
575 | "id": "c26fe491",
576 | "metadata": {
577 | "scrolled": true
578 | },
579 | "outputs": [
580 | {
581 | "data": {
582 | "text/plain": [
583 | "torch.Size([13341, 5120])"
584 | ]
585 | },
586 | "execution_count": 15,
587 | "metadata": {},
588 | "output_type": "execute_result"
589 | }
590 | ],
591 | "source": [
592 | "all_pe.shape"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": 16,
598 | "id": "21f937ea",
599 | "metadata": {
600 | "scrolled": true
601 | },
602 | "outputs": [
603 | {
604 | "data": {
605 | "text/plain": [
606 | "torch.Size([13341, 5120])"
607 | ]
608 | },
609 | "execution_count": 16,
610 | "metadata": {},
611 | "output_type": "execute_result"
612 | }
613 | ],
614 | "source": [
615 | "all_pe.shape"
616 | ]
617 | },
618 | {
619 | "cell_type": "code",
620 | "execution_count": 17,
621 | "id": "5faadace",
622 | "metadata": {},
623 | "outputs": [
624 | {
625 | "name": "stdout",
626 | "output_type": "stream",
627 | "text": [
628 | "chicken_offsets.pkl\n"
629 | ]
630 | }
631 | ],
632 | "source": [
633 | "print(f\"{SPECIES_NAME}_offsets.pkl\")"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": 18,
639 | "id": "6ceac20b",
640 | "metadata": {},
641 | "outputs": [
642 | {
643 | "data": {
644 | "text/plain": [
645 | "'../model_files/protein_embeddings/Gallus_gallus.bGalGal1.mat.broiler.GRCg7b.pep.all.gene_symbol_to_embedding_ESM2.pt'"
646 | ]
647 | },
648 | "execution_count": 18,
649 | "metadata": {},
650 | "output_type": "execute_result"
651 | }
652 | ],
653 | "source": [
654 | "SPECIES_PROTEIN_EMBEDDINGS_PATH"
655 | ]
656 | },
657 | {
658 | "cell_type": "markdown",
659 | "id": "e4697330",
660 | "metadata": {},
661 | "source": [
662 | "# Example evaluation of new species"
663 | ]
664 | },
665 | {
666 | "cell_type": "markdown",
667 | "id": "2b72667d",
668 | "metadata": {},
669 | "source": [
670 | "**Note: when you evaluate a new species, you need to change some arguments and modify some files:**\n",
671 | "\n",
672 | "You will need to modify the csv in `model_files/new_species_protein_embeddings.csv` to include the new protein embeddings file you downloaded.\n",
673 | "\n",
674 | "In the file add a row for the new species with the format:\n",
675 | "`species name,full path to protein embedding file`\n",
676 | "\n",
677 | "Please also add this line to the dictionary created on line 247 in the file `data_proc/data_utils.py`.\n",
678 | "\n",
679 | "When you want to embed this new species, you will need to specify these newly created files as arguments.\n",
680 | "- `CHROM_TOKEN_OFFSET`: This tells UCE where the rows corresponding to chromosome tokens start in the token file.\n",
681 | "- `spec_chrom_csv_path`: This is a new csv, created by this notebook, which maps genes to chromosomes and genomic positions.\n",
682 | "- `token_file`: This is a new token file that will work just for this species. The embeddings generated will still be universal though!\n",
683 | "- `offset_pkl_path`: This is a new pickle file that stores, for each species, the row offset at which its protein embedding tokens start. It is used to map genes to tokens.\n",
684 | "\n",
685 | "\n",
686 | "```\n",
687 | "\n",
688 | "accelerate launch eval_single_anndata.py --adata_path=chicken_heart.h5ad --species=chicken --CHROM_TOKEN_OFFSET=13275 --spec_chrom_csv_path=data_proc/chicken_to_chrom_pos.csv --token_file=data_proc/chicken_pe_tokens.torch --offset_pkl_path=data_proc/chicken_offsets.pkl --dir=... --multi_gpu=True\n",
689 | "\n",
690 | "```"
691 | ]
692 | }
693 | ],
694 | "metadata": {
695 | "kernelspec": {
696 | "display_name": "Python 3 (ipykernel)",
697 | "language": "python",
698 | "name": "python3"
699 | },
700 | "language_info": {
701 | "codemirror_mode": {
702 | "name": "ipython",
703 | "version": 3
704 | },
705 | "file_extension": ".py",
706 | "mimetype": "text/x-python",
707 | "name": "python",
708 | "nbconvert_exporter": "python",
709 | "pygments_lexer": "ipython3",
710 | "version": "3.8.6"
711 | }
712 | },
713 | "nbformat": 4,
714 | "nbformat_minor": 5
715 | }
716 |
--------------------------------------------------------------------------------
/data_proc/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/UCE/8227a65cdd021b9186ef86671d2aef5c895c8e4b/data_proc/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/data_proc/__pycache__/gene_embeddings.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snap-stanford/UCE/8227a65cdd021b9186ef86671d2aef5c895c8e4b/data_proc/__pycache__/gene_embeddings.cpython-38.pyc
--------------------------------------------------------------------------------
/data_proc/data_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore")
3 |
4 | import scanpy as sc
5 | import torch
6 |
7 | from torch import nn, Tensor
8 | import torch.nn.functional as F
9 | import torch.utils.data as data
10 | import torch.optim as optim
11 | import numpy as np
12 | import pickle
13 | import os
14 | import argparse
15 | import logging
16 | import time
17 |
18 | from tqdm.auto import tqdm
19 | import pandas as pd
20 |
21 | import math
22 | import anndata
23 | from pathlib import Path
24 |
25 |
26 | import subprocess  # needed by process_raw_anndata for the optional scp copy
27 | from torch.utils.data import DataLoader, TensorDataset, dataset
28 | from scipy.stats import binom
29 | from typing import Dict, List, Optional, Tuple
30 | from scanpy import AnnData
31 |
32 |
33 | from data_proc.gene_embeddings import load_gene_embeddings_adata
34 |
35 | def data_to_torch_X(X):
36 |     if isinstance(X, sc.AnnData):
37 |         X = X.X
38 |     if not isinstance(X, np.ndarray):
39 |         X = X.toarray()  # densify sparse matrices
40 |     return torch.from_numpy(X).float()
41 |
42 | class SincleCellDataset(data.Dataset):
43 |     def __init__(self,
44 |                  expression: torch.Tensor,  # Subset to hv genes, count data! cells x genes
45 |                  protein_embeddings: torch.Tensor,  # same order as expression, also subset genes x pe
46 |                  labels=None,  # optional, tensor of labels
47 |                  covar_vals=None,  # tensor of covar values or None
48 |                  ) -> None:
49 |         super(SincleCellDataset, self).__init__()
50 |
51 |         # Set expression
52 |         self.expression = expression
53 |
54 |         row_sums = self.expression.sum(1)  # UMI counts per cell
55 |         log_norm_count_adj = torch.log1p(self.expression / row_sums.unsqueeze(1) * torch.tensor(1000))
56 |
57 |         # Set log norm and count adjusted expression
58 |         max_vals, max_idx = torch.max(log_norm_count_adj, dim=0)
59 |         self.expression_mod = log_norm_count_adj / max_vals
60 |
61 |         # Calculate dropout likelihoods of each gene
62 |         self.dropout_vec = (self.expression == 0).float().mean(0)  # per gene dropout fractions
63 |
64 |         # Set data info
65 |         self.num_cells = self.expression.shape[0]
66 |         self.num_genes = self.expression.shape[1]
67 |
68 |         # Set optional label info, including categorical covariate index
69 |         self.covar_vals = covar_vals
70 |         self.labels = labels
71 |
72 |         # Set protein embeddings
73 |         self.protein_embeddings = protein_embeddings
74 |
75 |         self.item_mode = "expression"
76 |         if self.covar_vals is not None:
77 |             self.item_mode = "expression+covar"
78 |
79 |
80 |     def __getitem__(self, idx):
81 |         if self.item_mode == "expression":
82 |             if isinstance(idx, int):
83 |                 if idx < self.num_cells:
84 |                     return self.expression[idx, :]
85 |                 else:
86 |                     raise IndexError
87 |             else:
88 |                 raise NotImplementedError
89 |         elif self.item_mode == "expression+covar":
90 |             if isinstance(idx, int):
91 |                 if idx < self.num_cells:
92 |                     return self.expression[idx, :], self.covar_vals[idx]
93 |                 else:
94 |                     raise IndexError
95 |             else:
96 |                 raise NotImplementedError
97 |
98 |
99 |     def __len__(self) -> int:
100 |         return self.num_cells
101 |
102 |     def get_dim(self) -> int:
103 |         return self.num_genes
104 |
105 |
114 | def anndata_to_sc_dataset(adata:sc.AnnData,
115 |                           species:str="human",
116 |                           labels:Optional[list]=None,
117 |                           covar_col:str=None,
118 |                           hv_genes=None,
119 |                           embedding_model="ESM2",
120 |                           ) -> Tuple[SincleCellDataset, AnnData]:
121 |     labels = [] if labels is None else labels  # avoid a shared mutable default argument
122 |     # Subset to just genes we have embeddings for
123 |     adata, protein_embeddings = load_gene_embeddings_adata(
124 |         adata=adata,
125 |         species=[species],
126 |         embedding_model=embedding_model
127 |     )
128 |
129 |     if hv_genes is not None:
130 |         sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hv_genes)  # Expects Count Data
131 |
132 |         hv_index = adata.var["highly_variable"]
133 |         adata = adata[:, hv_index]  # Subset to hv genes only
134 |
135 |         protein_embeddings = protein_embeddings[species][hv_index]
136 |     else:
137 |         protein_embeddings = protein_embeddings[species]
138 |     expression = data_to_torch_X(adata.X)
139 |
140 |     covar_vals = None
141 |     if len(labels) > 0:
142 |         assert covar_col is None or covar_col in labels, "Covar needs to be in labels"  # make sure you keep track of covar column!
143 |         labels = adata.obs.loc[:, labels].values
144 |
145 |         if covar_col is not None:
146 |             # we have a categorical label to use as covariate
147 |             covar_vals = torch.tensor(pd.Categorical(adata.obs[covar_col]).codes)
148 |     return SincleCellDataset(
149 |         expression=expression,
150 |         protein_embeddings=protein_embeddings,
151 |         labels=labels,
152 |         covar_vals=covar_vals
153 |     ), adata
154 |
155 | def adata_path_to_prot_chrom_starts(adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset):
156 |     """
157 |     Given an AnnData object, return the token row indices, chromosome codes and start positions of its genes.
158 |     """
159 |     pe_row_idxs = torch.tensor([spec_pe_genes.index(k.upper()) + offset for k in adata.var_names]).long()
160 |     print(len(np.unique(pe_row_idxs)))
161 |
162 |     spec_chrom = gene_to_chrom_pos[gene_to_chrom_pos["species"] == dataset_species].set_index("gene_symbol")
163 |
164 |     gene_chrom = spec_chrom.loc[[k.upper() for k in adata.var_names]]
165 |
166 |     dataset_chroms = gene_chrom["spec_chrom"].cat.codes  # now this is correctly indexed by species and chromosome
167 |     print("Max Code:", max(dataset_chroms))
168 |     dataset_pos = gene_chrom["start"].values
169 |     return pe_row_idxs, dataset_chroms, dataset_pos
170 |
171 |
172 |
173 | def process_raw_anndata(row, h5_folder_path, npz_folder_path, scp, skip,
174 | additional_filter, root):
175 | path = row.path
176 | if not os.path.isfile(root + "/" + path):
177 | print( "**********************************")
178 | print(f"***********{root + '/' + path} File Missing****")
179 | print( "**********************************")
180 | print(path, root)
181 | return None, None, None
182 |
183 | name = path.replace(".h5ad", "")
184 | proc_path = path.replace(".h5ad", "_proc.h5ad")
185 | if skip:
186 | if os.path.isfile(h5_folder_path + proc_path):
187 | print(f"{name} already processed. Skipping")
188 | return None, None, None
189 |
190 | print(f"Processing {name}")
191 |
192 | species = row.species
193 | covar_col = row.covar_col
194 |
195 | ad = sc.read(root + "/" + path)
196 | labels = []
197 | if "cell_type" in ad.obs.columns:
198 | labels.append("cell_type")
199 |
200 |
201 | if pd.isna(covar_col): # missing covariate; np.isnan would raise a TypeError on a string column name
202 | covar_col = None
203 | else:
204 | labels.append(covar_col)
205 |
206 | if additional_filter:
207 | sc.pp.filter_genes(ad, min_cells=10)
208 | sc.pp.filter_cells(ad, min_genes=25)
209 |
210 |
211 | dataset, adata = anndata_to_sc_dataset(ad, species=species, labels=labels, covar_col=covar_col, hv_genes=None)
212 | adata = adata.copy()
213 |
214 | if additional_filter:
215 | sc.pp.filter_genes(ad, min_cells=10)
216 | sc.pp.filter_cells(ad, min_genes=25)
217 |
218 | num_cells = adata.X.shape[0]
219 | num_genes = adata.X.shape[1]
220 |
221 | adata_path = h5_folder_path + proc_path
222 | adata.write(adata_path)
223 |
224 | arr = data_to_torch_X(adata.X).numpy()
225 |
226 | print(arr.max()) # this is a nice check to make sure it's counts
227 | filename = npz_folder_path + f"{name}_counts.npz"
228 | shape = arr.shape
229 | print(name, shape)
230 | fp = np.memmap(filename, dtype='int64', mode='w+', shape=shape)
231 | fp[:] = arr[:]
232 | fp.flush()
233 |
234 | if scp != "":
235 | subprocess.call(["scp", filename, f"{scp}:{filename}"])
236 | subprocess.call(["scp", adata_path, f"{scp}:{adata_path}"])
237 |
238 | return adata, num_cells, num_genes
239 |
240 |
241 | def get_species_to_pe(EMBEDDING_DIR):
242 | """
243 | Given an embedding directory, return all embeddings as a dictionary coded by species.
244 | Note: In the current form, this function is written such that the directory needs all of the following species embeddings.
245 | """
246 | EMBEDDING_DIR = Path(EMBEDDING_DIR)
247 |
248 | embeddings_paths = {
249 | 'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt',
250 | 'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt',
251 | 'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt',
252 | 'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt',
253 | "mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt",
254 | "pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt',
255 | "macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt',
256 | "macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt',
257 | }
258 | extra_species = pd.read_csv("./model_files/new_species_protein_embeddings.csv").set_index("species").to_dict()["path"]
259 | embeddings_paths.update(extra_species) # adds new species
260 |
261 |
262 |
263 | species_to_pe = {
264 | species:torch.load(pe_dir) for species, pe_dir in embeddings_paths.items()
265 | }
266 |
267 | species_to_pe = {species:{k.upper(): v for k,v in pe.items()} for species, pe in species_to_pe.items()}
268 | return species_to_pe
269 |
270 |
271 | def get_spec_chrom_csv(path="/dfs/project/cross-species/yanay/code/all_to_chrom_pos.csv"):
272 | """
273 | Get the species to chrom csv file
274 | """
275 | gene_to_chrom_pos = pd.read_csv(path)
276 | gene_to_chrom_pos["spec_chrom"] = pd.Categorical(gene_to_chrom_pos["species"] + "_" + gene_to_chrom_pos["chromosome"]) # add the spec_chrom list
277 | return gene_to_chrom_pos
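
The `spec_chrom` column created by `get_spec_chrom_csv` can be illustrated with a minimal sketch on a made-up table (the gene symbols, chromosomes, and positions below are hypothetical toy values, not taken from the real CSV). It shows that concatenating species and chromosome into one categorical yields codes that are unique per species and chromosome pair, and that these codes survive the per-species subsetting done in `adata_path_to_prot_chrom_starts`.

```python
import pandas as pd

# Toy stand-in for the species-to-chromosome CSV; all values are hypothetical.
gene_to_chrom_pos = pd.DataFrame({
    "species": ["human", "human", "mouse", "mouse"],
    "gene_symbol": ["GENEA", "GENEB", "GENEA", "GENEC"],
    "chromosome": ["1", "X", "1", "2"],
    "start": [100, 200, 150, 300],
})

# Same construction as get_spec_chrom_csv: one category per species/chromosome pair.
gene_to_chrom_pos["spec_chrom"] = pd.Categorical(
    gene_to_chrom_pos["species"] + "_" + gene_to_chrom_pos["chromosome"]
)

# Same lookup pattern as adata_path_to_prot_chrom_starts, for a mouse dataset.
spec_chrom = gene_to_chrom_pos[gene_to_chrom_pos["species"] == "mouse"].set_index("gene_symbol")
gene_chrom = spec_chrom.loc[["GENEC", "GENEA"]]  # order follows the dataset's genes

codes = gene_chrom["spec_chrom"].cat.codes  # global codes, not renumbered per species
starts = gene_chrom["start"].values
```

The categories are defined on the full table, so the mouse codes do not restart at zero after subsetting; chromosome 1 in human and chromosome 1 in mouse therefore map to different codes.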
--------------------------------------------------------------------------------
/data_proc/download_proc_czi_cxg.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["OMP_NUM_THREADS"] = "20" # export OMP_NUM_THREADS=20
3 | os.environ["OPENBLAS_NUM_THREADS"] = "20" # export OPENBLAS_NUM_THREADS=20
4 | os.environ["MKL_NUM_THREADS"] = "20" # export MKL_NUM_THREADS=20
5 | os.environ["VECLIB_MAXIMUM_THREADS"] = "20" # export VECLIB_MAXIMUM_THREADS=20
6 | os.environ["NUMEXPR_NUM_THREADS"] = "20"
7 |
8 |
9 | import warnings
10 | warnings.filterwarnings('ignore')
11 |
12 | import cellxgene_census
13 | from tqdm import tqdm
14 | import scanpy as sc
15 |
16 | from collections import defaultdict
17 | from typing import Dict, List, Optional, Tuple
18 |
19 | import torch
20 | import torch.utils.data as data
21 | import torch
22 | import numpy as np
23 | import scanpy as sc
24 | from numpy import array
25 | import os
26 | import pickle as pkl
27 | import glob
28 |
29 | def data_to_torch_X(X):
30 | if isinstance(X, sc.AnnData):
31 | X = X.X
32 | if not isinstance(X, np.ndarray):
33 | X = X.toarray()
34 | return torch.from_numpy(X).float()
35 |
36 | import sys
37 | sys.path.append('../')
38 |
39 | from gene_embeddings import load_gene_embeddings_adata
40 | import pandas as pd
41 | import numpy as np
42 | from scanpy import AnnData
43 | from multiprocessing import Pool, Process, Manager
44 |
45 | import multiprocessing.pool as mpp
46 | # https://stackoverflow.com/questions/57354700/starmap-combined-with-tqdm
47 | def istarmap(self, func, iterable, chunksize=1):
48 | """starmap-version of imap
49 | """
50 | if self._state != mpp.RUN:
51 | raise ValueError("Pool not running")
52 |
53 | if chunksize < 1:
54 | raise ValueError(
55 | "Chunksize must be 1+, not {0:n}".format(
56 | chunksize))
57 |
58 | task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
59 | result = mpp.IMapIterator(self._cache)
60 | self._taskqueue.put(
61 | (
62 | self._guarded_task_generation(result._job,
63 | mpp.starmapstar,
64 | task_batches),
65 | result._set_length
66 | ))
67 | return (item for chunk in result for item in chunk)
68 |
69 |
70 | mpp.Pool.istarmap = istarmap
71 |
72 |
73 | VERSION = "2023-04-25"
74 | N_TOP_GENES = 12000
75 |
76 |
77 | print(cellxgene_census.get_census_version_description(VERSION))
78 |
79 | census = cellxgene_census.open_soma(census_version=VERSION)
80 | census_datasets = census["census_info"]["datasets"].read().concat().to_pandas()
81 |
82 | # for convenience, indexing on the soma_joinid which links this to other census data.
83 | census_datasets = census_datasets.set_index("soma_joinid")
84 |
85 | species_to_readable = {
86 | "Homo sapiens":"human",
87 | "Mus musculus":"mouse"
88 | }
89 |
90 | def process_row(row, num_genes, num_cells, paths, all_species, covar_cols, dataset_title, h5_root="/dfs/project/uce/cxg_data/anndatas/", npz_root="/dfs/project/uce/cxg_data/npzs/"):
91 | dataset_id = row[1].dataset_id
92 | #dataset_title = row[1].dataset_title.lower().replace(' ', '_').replace(",", "").replace("/", "")
93 |
94 | save_path = h5_root + f"{dataset_title}.h5ad"
95 | no_primary_path = save_path.replace(".h5ad", "_no_primary.h5ad")
96 | proc_path = save_path.replace(".h5ad", "_proc.h5ad")
97 | npz_path = npz_root + f"{dataset_title}_counts.npz"
98 | # Download the anndata
99 |
100 | if os.path.exists(no_primary_path):
101 | print("No Primary, skipping")
102 | return
103 |
104 | if not os.path.exists(save_path) and not os.path.exists(no_primary_path):
105 | cellxgene_census.download_source_h5ad(
106 | dataset_id, to_path=save_path
107 | )
108 | if os.path.exists(proc_path) and os.path.exists(npz_path):
109 | print("Already Proc")
110 | try:
111 | ad = sc.read(proc_path)
112 | except Exception:
113 | print()
114 | print()
115 | print("Error reading on:", dataset_title)
116 | print()
117 | print()
118 | return
119 | # Get organism
120 | if "organism" in ad.obs.columns:
121 | unique_organisms = list(ad.obs.organism.unique().categories)
122 | unique_organism_str = ", ".join(unique_organisms)
123 | else:
124 | unique_organism_str = "human"
125 | species = species_to_readable.get(unique_organism_str, "human")
126 | # don't need to do hv if already proc
127 | if "sample" in ad.obs.columns:
128 | covar_cols[dataset_title] = "sample"
129 | elif "batch" in ad.obs.columns:
130 | covar_cols[dataset_title] = "batch"
131 | else:
132 | covar_cols[dataset_title] = ""
133 |
134 |
135 | num_genes[dataset_title] = ad.X.shape[1]
136 | num_cells[dataset_title] = ad.X.shape[0]
137 | paths[dataset_title] = f"{dataset_title}.h5ad"
138 | all_species[dataset_title] = species
139 |
140 | return # Skip everything else
141 | # Read the raw AD
142 | ad = sc.read(save_path)
143 |
144 | # Change to counts
145 | if not sc._utils.check_nonnegative_integers(ad.X):
146 | # don't have counts yet, need raw
147 | if ad.raw is None:
148 | print("Skipped, no counts")
149 | return
150 | ad.X = ad.raw.X.toarray()
151 | if not sc._utils.check_nonnegative_integers(ad.X):
152 | print("Skipped, no counts")
153 | return
154 |
155 | # SUBSET TO primary data
156 | if len(np.unique(ad.obs["is_primary_data"])) >= 1:
157 | primary_data = ad.obs.is_primary_data.value_counts()
158 | ad = ad[ad.obs.is_primary_data]
159 | if ad.X.shape[0] == 0:
160 | print("no primary data")
161 | print(primary_data)
162 | os.rename(save_path, no_primary_path)
163 | return # No primary data
164 | print("has primary data")
165 | # Switch to gene symbols
166 | ad.var["feature_id_orig"] = list(ad.var.index)
167 | ad.var_names = list(ad.var.feature_name)
168 |
169 | # Get organism
170 | if "organism" in ad.obs.columns:
171 | unique_organisms = list(ad.obs.organism.unique().categories)
172 | unique_organism_str = ", ".join(unique_organisms)
173 | else:
174 | unique_organism_str = "human"
175 | species = species_to_readable.get(unique_organism_str, "human")
176 | # Filter to gene symbols with protein embeddings
177 | ad, _ = load_gene_embeddings_adata(
178 | adata=ad,
179 | species=[species],
180 | embedding_model="ESM2"
181 | )
182 |
183 | ad = ad.copy()
184 | # Simple filtering by counts
185 | sc.pp.filter_cells(ad, min_genes=200)
186 | sc.pp.filter_genes(ad, min_cells=10)
187 |
188 | #print(ad)
189 |
190 | if "sample" in ad.obs.columns:
191 | try:
192 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="sample")
193 | except:
194 | try:
195 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="sample", span=1)
196 | except:
197 | print(f"can't hv gene subset {dataset_title}")
198 | covar_cols[dataset_title] = "sample"
199 | elif "batch" in ad.obs.columns:
200 | try:
201 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="batch")
202 | except:
203 | try:
204 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, batch_key="batch", span=1)
205 | except:
206 | print(f"can't hv gene subset {dataset_title}")
207 | covar_cols[dataset_title] = "batch"
208 | else:
209 | try:
210 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True)
211 | except:
212 | try:
213 | sc.pp.highly_variable_genes(ad, flavor="seurat_v3", n_top_genes=N_TOP_GENES, subset=True, span=1)
214 | except:
215 | print(f"can't hv gene subset {dataset_title}")
216 | covar_cols[dataset_title] = ""
217 |
218 | num_genes[dataset_title] = ad.X.shape[1]
219 | num_cells[dataset_title] = ad.X.shape[0]
220 | paths[dataset_title] = f"{dataset_title}.h5ad"
221 | all_species[dataset_title] = species
222 |
223 | print("writing proc")
224 | ad.write(proc_path)
225 |
226 | arr = data_to_torch_X(ad.X).numpy()
227 |
228 | shape = arr.shape
229 |
230 | fp = np.memmap(npz_path, dtype='int64', mode='w+', shape=shape)
231 | fp[:] = arr[:]
232 | fp.flush()
233 |
234 | return
235 |
236 | if __name__ == '__main__':
237 | '''
238 | manager = Manager()
239 | num_genes = manager.dict()
240 | num_cells = manager.dict()
241 | paths = manager.dict()
242 | all_species = manager.dict()
243 | covar_cols = manager.dict()
244 | '''
245 | num_genes = {}
246 | num_cells = {}
247 | paths = {}
248 | all_species = {}
249 | covar_cols = {}
250 |
251 | df = pd.DataFrame()
252 | # Optionally subset the datasets here (no shuffling is performed)
253 | census_datasets = census_datasets#.iloc[270:]
254 | iterrows = list(census_datasets.iterrows())
255 | #p = Pool(8)
256 | #for row in tqdm(iterrows, total=len(census_datasets)):
257 | # p.apply_async(process_row, args=(row, num_genes, num_cells, paths, all_species, covar_cols))
258 | #p.close()
259 | #p.join()
260 | '''
261 | with Pool(1) as p:
262 | nrows = len(iterrows)
263 | inputs = zip(iterrows, [num_genes]*nrows, [num_cells]*nrows, [paths]*nrows, [all_species]*nrows, [covar_cols]*nrows)
264 | for _ in tqdm(p.istarmap(process_row, inputs),
265 | total=nrows):
266 | pass
267 |
268 | '''
269 |
270 | if os.path.exists("dataset_rows_mouse_fixed.pkl"):
271 | dataset_rows = {}
272 | for path in glob.glob("dataset_rows_mouse_fixed*.pkl"):
273 | with open(path, "rb") as f:
274 | dataset_rows_path = pkl.load(f)
275 | dataset_rows.update(dataset_rows_path)
276 |
277 | print(f"{len(dataset_rows)} already counted")
278 | else:
279 | dataset_rows = {}
280 |
281 |
282 | pbar = tqdm(iterrows)
283 | all_errors = []
284 | total_number_of_cells = 0
285 |
286 | duplicate_titles = ['Dissection: Body of hippocampus (HiB) - Rostral DG-CA4', 'Retina',
287 | 'Colon', 'Myeloid cells', 'Ileum', 'Airway']
288 | duplicate_titles_2 = ['retina', 'airway', 'myeloid_cells', 'colon', 'ileum', 'immune_cells']
289 |
290 | for row in pbar:
291 | dataset_title = row[1].dataset_title
292 | if dataset_title in duplicate_titles:
293 | dataset_title = row[1].collection_name + row[1].dataset_title
294 |
295 | dataset_title = dataset_title.lower().replace(' ', '_').replace(",", "").replace("/", "")
296 |
297 | if dataset_title in duplicate_titles_2:
298 | dataset_title = (row[1].collection_name + "_" + dataset_title).lower().replace(' ', '_').replace(",", "").replace("/", "")
299 |
300 | print(f"{total_number_of_cells} cells done")
301 | if dataset_title in dataset_rows:
302 | paths[dataset_title] = dataset_rows[dataset_title][0]
303 | all_species[dataset_title] = dataset_rows[dataset_title][1]
304 | covar_cols[dataset_title] = dataset_rows[dataset_title][2]
305 | num_cells[dataset_title] = dataset_rows[dataset_title][3]
306 | num_genes[dataset_title] = dataset_rows[dataset_title][4]
307 | #print("skipped read of proc")
308 |
309 | total_number_of_cells += dataset_rows[dataset_title][3]
310 | continue # Skip!
311 | else:
312 | pbar.set_description(f"{dataset_title} proc")
313 | try:
314 | process_row(row, num_genes, num_cells, paths, all_species, covar_cols, dataset_title=dataset_title)
315 | except Exception: # do not swallow KeyboardInterrupt/SystemExit
316 | print(f"****{dataset_title} ERROR****")
317 | all_errors.append(dataset_title)
318 |
319 |
320 | pbar.set_description(f"{dataset_title} done")
321 |
322 | if dataset_title in paths:
323 | dataset_rows[dataset_title] = [paths[dataset_title], all_species[dataset_title], covar_cols[dataset_title], num_cells[dataset_title], num_genes[dataset_title], dataset_title]
324 |
325 | total_number_of_cells += dataset_rows[dataset_title][3]
326 |
327 | with open("dataset_rows_mouse_fixed.pkl", "wb") as f:
328 | pkl.dump(dataset_rows, f)
329 | print("wrote pkl")
330 |
331 | # path,species,covar_col,num_cells,names
332 |
333 | df["path"] = list(paths.values())
334 | df["species"] = list(all_species.values())
335 | df["covar_col"] = list(covar_cols.values())
336 | df["num_cells"] = list(num_cells.values())
337 | df["num_genes"] = list(num_genes.values())
338 | df["names"] = list(paths.keys())
339 |
340 | print(df.head(20))
341 | print()
342 | print("Errors:")
343 | print(all_errors)
344 | df.to_csv("cxg_datasets.csv", index=False)
345 |
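
The `*_counts.npz` files written by the script above are raw `np.memmap` buffers, not NumPy `.npz` archives, so they cannot be opened with `np.load`. The following is a minimal sketch of the write and read round trip, assuming a toy count matrix and a hypothetical temporary file name. The reader must supply the same `dtype` and `shape` that were used when writing, which is why the shapes are recorded separately (for example in the `num_cells` and `num_genes` columns).

```python
import os
import tempfile

import numpy as np

# Toy counts (cells x genes); the values are made up for illustration.
arr = np.array([[0.0, 3.0, 1.0], [2.0, 0.0, 5.0]], dtype=np.float32)

# Hypothetical output path that mimics the "<name>_counts.npz" naming.
path = os.path.join(tempfile.mkdtemp(), "toy_counts.npz")

# Write: same pattern as the script (int64 memmap, float values are cast).
fp = np.memmap(path, dtype="int64", mode="w+", shape=arr.shape)
fp[:] = arr[:]
fp.flush()
del fp  # release the file handle before reopening

# Read: dtype and shape are not stored in the file and must be passed again.
counts = np.memmap(path, dtype="int64", mode="r", shape=arr.shape)
restored = np.asarray(counts)
```

Because the file holds only the raw buffer, its size on disk is exactly `num_cells * num_genes * 8` bytes, which gives a quick sanity check against a recorded shape.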
--------------------------------------------------------------------------------
/data_proc/gene_embeddings.py:
--------------------------------------------------------------------------------
1 | """Helper functions for loading pretrained gene embeddings."""
2 | from pathlib import Path
3 | from typing import Dict, Tuple
4 |
5 | import torch
6 |
7 | from scanpy import AnnData
8 | import numpy as np
9 | import pandas as pd
10 |
11 |
12 | EMBEDDING_DIR = Path('model_files/protein_embeddings')
13 | MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH = {
14 | 'ESM2': {
15 | 'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt',
16 | 'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt',
17 | 'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt',
18 | 'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt',
19 | "mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt",
20 | "pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt',
21 | "macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt',
22 | "macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt',
23 | }
24 | }
25 |
26 | extra_species = pd.read_csv("./model_files/new_species_protein_embeddings.csv").set_index("species").to_dict()["path"]
27 | MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH["ESM2"].update(extra_species) # adds new species
28 |
29 |
30 | def load_gene_embeddings_adata(adata: AnnData, species: list, embedding_model: str) -> Tuple[AnnData, Dict[str, torch.FloatTensor]]:
31 | """Loads gene embeddings for all the species/genes in the provided data.
32 |
33 | :param adata: An AnnData object containing gene expression data for cells.
34 | :param species: Species corresponding to this adata
35 |
36 | :param embedding_model: The gene embedding model whose embeddings will be loaded.
37 | :return: A tuple containing:
38 | - A subset of the data only containing the gene expression for genes with embeddings in all species.
39 | - A dictionary mapping species name to the corresponding gene embedding matrix (num_genes, embedding_dim).
40 | """
41 | # Get species names
42 | species_names = species
43 | species_names_set = set(species_names)
44 |
45 | # Get embedding paths for the model
46 | species_to_gene_embedding_path = MODEL_TO_SPECIES_TO_GENE_EMBEDDING_PATH[embedding_model]
47 | available_species = set(species_to_gene_embedding_path)
48 |
49 | # Ensure embeddings are available for all species
50 | if not (species_names_set <= available_species):
51 | raise ValueError(f'The following species do not have gene embeddings: {species_names_set - available_species}')
52 |
53 | # Load gene embeddings for desired species (and convert gene symbols to lower case)
54 | species_to_gene_symbol_to_embedding = {
55 | species: {
56 | gene_symbol.lower(): gene_embedding
57 | for gene_symbol, gene_embedding in torch.load(species_to_gene_embedding_path[species]).items()
58 | }
59 | for species in species_names
60 | }
61 |
62 | # Determine which genes to include based on gene expression and embedding availability
63 | genes_with_embeddings = set.intersection(*[
64 | set(gene_symbol_to_embedding)
65 | for gene_symbol_to_embedding in species_to_gene_symbol_to_embedding.values()
66 | ])
67 | genes_to_use = {gene for gene in adata.var_names if gene.lower() in genes_with_embeddings}
68 |
69 | # Subset data to only use genes with embeddings
70 | adata = adata[:, adata.var_names.isin(genes_to_use)]
71 |
72 | # Set up dictionary mapping species to gene embedding matrix (num_genes, embedding_dim)
73 | species_to_gene_embeddings = {
74 | species_name: torch.stack([
75 | species_to_gene_symbol_to_embedding[species_name][gene_symbol.lower()]
76 | for gene_symbol in adata.var_names
77 | ])
78 | for species_name in species_names
79 | }
80 |
81 | return adata, species_to_gene_embeddings
82 |
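
The gene selection in `load_gene_embeddings_adata` can be illustrated with a minimal sketch on toy data. Everything here is an assumption for illustration: the gene symbols and vectors are made up, a plain list stands in for `adata.var_names`, and NumPy arrays stand in for the torch tensors that the real function loads and stacks. The sketch shows the three steps: lower-casing the symbols, intersecting across species, and stacking one embedding row per kept gene in the data's gene order.

```python
import numpy as np

# Hypothetical per-species dictionaries (gene symbol -> embedding vector).
species_to_gene_symbol_to_embedding = {
    "human": {"GeneA": np.ones(4), "GENEB": np.zeros(4)},
    "mouse": {"genea": np.full(4, 2.0), "Genec": np.full(4, 3.0)},
}

# Step 1: normalise symbols to lower case so lookups are case-insensitive.
lowered = {
    species: {symbol.lower(): emb for symbol, emb in table.items()}
    for species, table in species_to_gene_symbol_to_embedding.items()
}

# Step 2: keep only genes that have an embedding in every requested species.
genes_with_embeddings = set.intersection(*[set(table) for table in lowered.values()])

# Stand-in for adata.var_names.
var_names = ["GENEA", "GENEB", "GENEC"]
kept = [gene for gene in var_names if gene.lower() in genes_with_embeddings]

# Step 3: one (num_genes, embedding_dim) matrix per species, in var_names order.
species_to_gene_embeddings = {
    species: np.stack([lowered[species][gene.lower()] for gene in kept])
    for species in lowered
}
```

With a single species in the list, as in the calls elsewhere in this repository, the intersection reduces to that species' own gene set.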
--------------------------------------------------------------------------------
/data_proc/generate_reduced_chrom_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["OMP_NUM_THREADS"] = "4" # export OMP_NUM_THREADS=4
3 | os.environ["OPENBLAS_NUM_THREADS"] = "4" # export OPENBLAS_NUM_THREADS=4
4 | os.environ["MKL_NUM_THREADS"] = "4" # export MKL_NUM_THREADS=4
5 | os.environ["VECLIB_MAXIMUM_THREADS"] = "4" # export VECLIB_MAXIMUM_THREADS=4
6 | os.environ["NUMEXPR_NUM_THREADS"] = "4"
7 |
8 |
9 | import warnings
10 | warnings.filterwarnings("ignore")
11 |
12 | import scanpy as sc
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | import numpy as np
18 | import pickle
19 | import os
20 | import argparse
21 | import logging
22 | import time
23 |
24 | from tqdm.auto import tqdm
25 | import matplotlib.pyplot as plt
26 | import pandas as pd
27 |
28 | #sc._settings.ScanpyConfig.n_jobs = 6
29 |
30 | import math
31 | from typing import Tuple
32 |
33 | import torch
34 | from torch import nn, Tensor
35 | import torch.nn.functional as F
36 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
37 | from torch.utils.data import dataset
38 |
39 |
40 | from accelerate import Accelerator
41 | import anndata
42 | from data_utils import adata_path_to_prot_chrom_starts, get_spec_chrom_csv
43 |
44 |
45 |
46 | from torch.utils.data import dataset
47 | from torch.utils.data import DataLoader, TensorDataset
48 | from scipy.stats import binom
49 |
50 |
51 |
52 |
53 | def padding_tensor(sequences):
54 | """
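55 | :param sequences: list of tensors, each of shape (length, 1280)
56 | :return: padded tensor of shape (max_len, num, 1280) and a mask of shape (num, max_len) set to 1 at valid positions and -inf at padding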
57 | """
58 | num = len(sequences)
59 | max_len = max([s.size(0) for s in sequences])
60 | out_dims = (num, max_len, 1280)
61 |
62 |
63 | out_tensor = sequences[0].data.new(*out_dims).fill_(0)
64 | out_dims2 = (num, max_len)
65 |
66 | mask = sequences[0].data.new(*out_dims2).fill_(float('-inf'))
67 | for i, tensor in enumerate(sequences):
68 | length = tensor.size(0)
69 | out_tensor[i, :length] = tensor
70 | mask[i, :length] = 1
71 | return out_tensor.permute(1, 0, 2), mask
72 |
73 |
74 | from pathlib import Path
75 | # ESM1b
76 | '''
77 | EMBEDDING_DIR = Path('/dfs/project/cross-species/data/proteome/embeddings')
78 | human_pe_dir = EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM1b.pt'
79 | mouse_pe_dir = EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM1b.pt'
80 | lemur_pe_dir = Path("/dfs/project/cross-species/yanay/data/proteome/embeddings/") / 'Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM1b.pt'
81 |
82 | '''
83 |
84 | # Upgrade to ESM2
85 | EMBEDDING_DIR = Path('/dfs/project/cross-species/data/proteome/embeddings')
86 | EMBEDDING_DIR = Path('/dfs/project/cross-species/yanay/data/proteome/embeddings')
87 |
88 | embeddings_paths = {
89 | 'human': EMBEDDING_DIR / 'Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt',
90 | 'mouse': EMBEDDING_DIR / 'Mus_musculus.GRCm39.gene_symbol_to_embedding_ESM2.pt',
91 | 'frog': EMBEDDING_DIR / 'Xenopus_tropicalis.Xenopus_tropicalis_v9.1.gene_symbol_to_embedding_ESM2.pt',
92 | 'zebrafish': EMBEDDING_DIR / 'Danio_rerio.GRCz11.gene_symbol_to_embedding_ESM2.pt',
93 | "mouse_lemur": EMBEDDING_DIR / "Microcebus_murinus.Mmur_3.0.gene_symbol_to_embedding_ESM2.pt",
94 | "pig": EMBEDDING_DIR / 'Sus_scrofa.Sscrofa11.1.gene_symbol_to_embedding_ESM2.pt',
95 | "macaca_fascicularis": EMBEDDING_DIR / 'Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt',
96 | "macaca_mulatta": EMBEDDING_DIR / 'Macaca_mulatta.Mmul_10.gene_symbol_to_embedding_ESM2.pt',
97 | }
98 |
99 | species_to_pe = {
100 | species:torch.load(pe_dir) for species, pe_dir in embeddings_paths.items()
101 | }
102 |
103 | species_to_pe = {species:{k.upper(): v for k,v in pe.items()} for species, pe in species_to_pe.items()}
104 |
105 | #species_to_keys = {species:list(pe.keys()) for species, pe in species_to_pe.items()}
106 | #species_to_keys = {species:dict(zip(keys, np.arange(len(keys)))) for species, keys in species_to_keys.items()}
107 |
108 |
109 | #datasets_df = pd.read_csv("/dfs/project/cross-species/yanay/code/UCE/data_proc/full_train_datasets.csv")
110 | datasets_df = pd.read_csv("tissue_datasets.csv")
111 | datasets_df = pd.read_csv("perturb_datasets.csv")
112 | datasets_df = pd.read_csv("../new_perturb_datasets.csv")
113 |
114 |
115 | #pd.concat((#pd.read_csv("new_datasets.csv"),
116 | #pd.read_csv("pbmcs_nohvg.csv"),
117 | #pd.read_csv("lung_nohvg.csv"),
118 | #pd.read_csv("new_tabula_datasets.csv"),
119 | #pd.read_csv("updated_datasets.csv"),
120 | # #pd.read_csv("sanger_heart_atlas_datasets.csv"),
121 | # pd.read_csv("tissue_datasets.csv")
122 | # ))
123 |
124 |
125 |
126 |
127 | #datasets_df = pd.read_csv("cell_cycle_datasets.csv")
128 | #datasets_df = pd.read_csv("spatial_datasets.csv")
129 | #datasets_df = pd.read_csv("perturb_datasets.csv")
130 | #datasets_df = pd.read_csv("ccle_datasets.csv")
131 | #datasets_df = pd.read_csv("pancreas_datasets.csv")
132 |
133 |
134 |
135 | sorted_dataset_names = sorted(datasets_df["names"])
136 | with open("dataset_shapes.pkl", "rb") as f:
137 | shapes_dict = pickle.load(f)
138 |
139 |
140 | shapes_dict.update({
141 | "madissoon_novel_lung":(190728, 8000),
142 | 'flores_cerebellum_human': (20232, 8000),
143 | 'osuch_gut_human': (272310, 8000),
144 | 'msk_ovarian_human': (929690, 8000),
145 | 'htan_vmuc_dis_epi_human': (65084, 8000),
146 | 'htan_vmuc_val_epi_human': (57564, 8000),
147 | 'htan_vmuc_non_epi_human': (9099, 8000),
148 | 'hao_pbmc_3p_human': (161764, 8000),
149 | 'hao_pbmc_5p_human': (49147, 8000),
150 | 'gao_tumors_human': (36111, 8000),
151 | 'swabrick_breast_human': (92427, 8000),
152 | 'wu_cryo_tumors_human': (105662, 8000),
153 | 'cell_line_het_human': (53513, 8000),
154 | 'bi_allen_metastasis_human': (27787, 8000),
155 | 'zheng68k_human': (68579, 8000),
156 | 'zheng68k_12k_human': (68579, 12000),
157 | 'mouse_embryo_ct': (153597, 12000),
158 | "regev_gtex_heart": (36574, 8000),
159 | "tabula_sapiens_heart": (11505, 8000),
160 | "10k_pbmcs":(11990, 12000),
161 | "epo_ido":(35834,12000),
162 | 'tabula_sapiens_kidney': (9641, 8000),
163 | 'tabula_microcebus_kidney': (14592, 8000),
164 | 'tabula_muris_kidney': (2781, 8000),
165 | 'tabula_muris_senis_kidney': (19610, 8000),
166 | 'immune_human': (33506, 8000)
167 | })
168 |
169 | for row in datasets_df.iterrows():
170 | ngenes = row[1].num_genes
171 | ncells = row[1].num_cells
172 | name = row[1].names
173 | if not np.isnan(ngenes):
174 | shapes_dict[name] = (int(ncells), int(ngenes))
175 |
176 | #with open("dataset_shapes.pkl", "wb") as f:
177 | # pickle.dump(shapes_dict, f)
178 | token_dim = 5120
179 | mmap_dict = {}
180 |
181 | root_dir = "/lfs/local/0/yanay/uce_h5s/"
182 | root_dir_census = "/lfs/local/0/yanay/cxg_h5s/"
183 |
184 | dataset_to_paths = {r[1]["names"]:root_dir + r[1]["path"].replace(".h5ad", "_proc.h5ad") for r in datasets_df.iterrows()}
185 | for row in datasets_df.iterrows():
186 | name = row[1].names
187 | census = row[1].census
188 |
189 | if census == "yes":
190 | dataset_to_paths[name] = dataset_to_paths[name].replace(root_dir, root_dir_census)
191 |
192 |
193 | datasets_to_species = {r[1]["names"]:r[1]["species"] for r in datasets_df.iterrows()}
194 |
195 | #species_to_pe = {"mouse":mouse_pe, "human":human_pe, "mouse_lemur":lemur_pe}
196 |
197 | #dataset_to_protein_embeddings_all = {k:species_to_pe[v] for k, v in datasets_to_species.items()}
198 |
199 | dataset_to_protein_embeddings = {}
200 |
201 |
202 | #dataset_to_protein_embeddings_all["madissoon_novel_lung"] = species_to_pe["human"]
203 | datasets_to_species["madissoon_novel_lung"] = "human"
204 | #dataset_to_paths["madissoon_novel_lung"] = "/lfs/local/0/yanay/uce_h5s/madissoon_novel_lung_proc.h5ad"
205 |
206 |
207 |
208 | # New Chrom Based Code
209 | gene_to_chrom_pos = get_spec_chrom_csv()
210 | species_to_chrom_categories = {}
211 |
212 | for species in np.unique(gene_to_chrom_pos["species"]):
213 | species_to_chrom_categories[species] = pd.Categorical(gene_to_chrom_pos["chromosome"]).categories
214 |
215 |
216 | dataset_to_chroms = {}
217 | dataset_to_starts = {}
218 |
219 | sorted_species_names = sorted(species_to_pe.keys())
220 | print(sorted_species_names)
221 |
222 | if os.path.exists(f"/dfs/project/uce/all_species_pe_tokens.torch"):
223 | all_pe = torch.load(f"/dfs/project/uce/all_species_pe_tokens.torch")
224 | with open("/dfs/project/uce/all_species_offsets.pkl", "rb") as f:
225 | species_to_offsets = pickle.load(f)
226 | print("Loaded PE", all_pe.shape)
227 | else:
228 | torch.manual_seed(8)
229 | MASK_TENSOR = torch.zeros((1, token_dim)) # this is the padding token
230 | CHROM_TENSOR_LEFT = torch.normal(mean=0, std=1, size=(1, token_dim))
231 | CHROM_TENSOR_RIGHT = torch.normal(mean=0, std=1, size=(1, token_dim))
232 | CLS_TENSOR = torch.normal(mean=0, std=1, size=(1, token_dim))
233 | species_to_offsets = {}
234 |
235 | all_pe = [MASK_TENSOR, CHROM_TENSOR_LEFT, CHROM_TENSOR_RIGHT, CLS_TENSOR]
236 | offset = len(all_pe) # special tokens at the top!
237 | for species in sorted_species_names:
238 | pe_stacked = torch.stack(list(species_to_pe[species].values()))
239 | all_pe.append(pe_stacked)
240 | species_to_offsets[species] = offset
241 | offset += pe_stacked.shape[0]
242 |
243 | all_pe = torch.vstack(all_pe)
244 | print(all_pe.shape)
245 | torch.save(all_pe, f"/dfs/project/uce/all_species_pe_tokens.torch")
246 | with open("/dfs/project/uce/all_species_offsets.pkl", "wb+") as f:
247 | pickle.dump(species_to_offsets, f)
248 | print("Saved PE")
249 |
250 | # Load in already saved!
251 | if os.path.exists(f"/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_{token_dim}_new.torch"):
252 | dataset_to_protein_embeddings = torch.load(f"/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_{token_dim}_new.torch")
253 |
254 | with open("/lfs/local/0/yanay/dataset_to_chroms_new.pkl", "rb") as f:
255 | dataset_to_chroms = pickle.load(f)
256 | with open("/lfs/local/0/yanay/dataset_to_starts_new.pkl", "rb") as f:
257 | dataset_to_starts = pickle.load(f)
258 | else:
259 | dataset_to_protein_embeddings = {}
260 | dataset_to_chroms = {}
261 | dataset_to_starts = {}
262 |
263 |
264 | # Add the new ones
265 | print("creating reduced size protein embeddings file")
266 |
267 | redo = True
268 |
269 | for dataset, path in tqdm(list(dataset_to_paths.items())):
270 | if dataset in dataset_to_protein_embeddings.keys() and not redo:
271 | continue # skip since already processed
272 | print(dataset)
273 | adata = sc.read(path)
274 | dataset_species = datasets_to_species[dataset]
275 | spec_pe_genes = list(species_to_pe[dataset_species].keys())
276 | offset = species_to_offsets[dataset_species]
277 |
278 | # Get proper idxs
279 | pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset)
280 | # Add to dicts
281 | dataset_to_chroms[dataset] = dataset_chroms
282 | dataset_to_starts[dataset] = dataset_pos
283 | dataset_to_protein_embeddings[dataset] = pe_row_idxs
284 |
285 | del adata
286 | # save Dicts and idxs
287 | torch.save(dataset_to_protein_embeddings, f"/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_{token_dim}_new.torch")
288 |
289 | with open("/lfs/local/0/yanay/dataset_to_chroms_new.pkl", "wb+") as f:
290 | pickle.dump(dataset_to_chroms, f)
291 | with open("/lfs/local/0/yanay/dataset_to_starts_new.pkl", "wb+") as f:
292 | pickle.dump(dataset_to_starts, f)
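
The token table and per-species offsets built by the script above can be sketched on toy data. This is a minimal illustration under stated assumptions: `token_dim` is shrunk to 4 (the script uses 5120), the gene symbols and vectors are made up, NumPy stands in for torch, and `gene_rows` is a hypothetical helper that mirrors the index computation in `adata_path_to_prot_chrom_starts`.

```python
import numpy as np

token_dim = 4  # toy size; the script above uses 5120
rng = np.random.default_rng(8)

# Four special tokens at the top: padding (zeros), chrom-left, chrom-right, CLS.
special_tokens = [np.zeros((1, token_dim))] + [
    rng.normal(size=(1, token_dim)) for _ in range(3)
]

# Hypothetical per-species embedding dictionaries with upper-cased symbols.
species_to_pe = {
    "mouse": {"GENEA": np.full(token_dim, 1.0), "GENEC": np.full(token_dim, 2.0)},
    "human": {
        "GENEA": np.full(token_dim, 3.0),
        "GENEB": np.full(token_dim, 4.0),
        "GENED": np.full(token_dim, 5.0),
    },
}

all_pe = list(special_tokens)
offset = len(all_pe)  # species rows start after the special tokens
species_to_offsets = {}
for species in sorted(species_to_pe):  # fixed order keeps offsets reproducible
    stacked = np.stack(list(species_to_pe[species].values()))
    all_pe.append(stacked)
    species_to_offsets[species] = offset
    offset += stacked.shape[0]
all_pe = np.vstack(all_pe)


def gene_rows(species, genes):
    """Hypothetical helper: global row index of each gene in the stacked table."""
    names = list(species_to_pe[species].keys())
    return [names.index(g.upper()) + species_to_offsets[species] for g in genes]


rows = gene_rows("mouse", ["genec"])
```

Since the offsets depend on the sorted species order and on each dictionary's insertion order, the saved token table and the saved offsets only stay valid together; adding a species requires regenerating both.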
--------------------------------------------------------------------------------
/data_proc/preproc_many_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["OMP_NUM_THREADS"] = "10" # export OMP_NUM_THREADS=10
3 | os.environ["OPENBLAS_NUM_THREADS"] = "10" # export OPENBLAS_NUM_THREADS=10
4 | os.environ["MKL_NUM_THREADS"] = "10" # export MKL_NUM_THREADS=10
5 | os.environ["VECLIB_MAXIMUM_THREADS"] = "10" # export VECLIB_MAXIMUM_THREADS=10
6 | os.environ["NUMEXPR_NUM_THREADS"] = "10"
7 |
8 |
9 |
10 | from collections import defaultdict
11 | from typing import Dict, List, Optional, Tuple
12 |
13 | import torch
14 | import torch.utils.data as data
15 | import numpy as np
16 | import scanpy as sc
17 | from numpy import array
18 | import subprocess
19 | import os
20 | from tqdm import tqdm
21 | import warnings
22 | warnings.filterwarnings("ignore")
23 |
24 |
25 | from gene_embeddings import load_gene_embeddings_adata
26 | import pandas as pd
27 | import numpy as np
28 | from scanpy import AnnData
29 | from data_utils import process_raw_anndata
30 |
31 | def data_to_torch_X(X):
32 | if isinstance(X, sc.AnnData):
33 | X = X.X
34 | if not isinstance(X, np.ndarray):
35 | X = X.toarray()
36 | return torch.from_numpy(X).float()
37 |
38 | class SincleCellDataset(data.Dataset):
39 | def __init__(self,
40 | expression: torch.tensor, # Subset to hv genes, count data! cells x genes
41 | protein_embeddings: torch.tensor, # same order as expression, also subset genes x pe
42 | labels: None, # optional, tensor of labels
43 | covar_vals: None, # tensor of covar values or none
44 | ) -> None:
45 | super(SincleCellDataset, self).__init__()
46 |
47 | # Set expression
48 | self.expression = expression
49 |
50 |         row_sums = self.expression.sum(1) # UMI counts per cell
51 |         log_norm_count_adj = torch.log1p(self.expression / row_sums.unsqueeze(1) * 1000)
52 |
53 | # Set log norm and count adjusted expression
54 | max_vals, max_idx = torch.max(log_norm_count_adj, dim=0)
55 | self.expression_mod = log_norm_count_adj / max_vals
56 |
57 |         # Calculate dropout likelihoods of each gene
58 | self.dropout_vec = (self.expression == 0).float().mean(0) # per gene dropout percentages
59 |
60 | # Set data info
61 | self.num_cells = self.expression.shape[0]
62 | self.num_genes = self.expression.shape[1]
63 |
64 | # Set optional label info, including categorical covariate index
65 | self.covar_vals = covar_vals
66 | self.labels = labels
67 |
68 | # Set protein embeddings
69 | self.protein_embeddings = protein_embeddings
70 |
71 | self.item_mode = "expression"
72 | if self.covar_vals is not None:
73 | self.item_mode = "expression+covar"
74 |
75 |
76 | def __getitem__(self, idx):
77 | if self.item_mode == "expression":
78 | if isinstance(idx, int):
79 | if idx < self.num_cells:
80 | return self.expression[idx, :]
81 | else:
82 | raise IndexError
83 | else:
84 | raise NotImplementedError
85 | elif self.item_mode == "expression+covar":
86 | if isinstance(idx, int):
87 | if idx < self.num_cells:
88 | return self.expression[idx, :], self.covar_vals[idx]
89 | else:
90 | raise IndexError
91 | else:
92 | raise NotImplementedError
93 |
94 |
95 | def __len__(self) -> int:
96 | return self.num_cells
97 |
98 |     def get_dim(self) -> int:
99 | return self.num_genes
100 |
101 |
102 | def data_to_torch_X(X):
103 | if isinstance(X, sc.AnnData):
104 | X = X.X
105 | if not isinstance(X, np.ndarray):
106 | X = X.toarray()
107 | return torch.from_numpy(X).float()
108 |
109 |
110 | def anndata_to_sc_dataset(adata:sc.AnnData,
111 | species:str="human",
112 | labels:list=[],
113 | covar_col:str=None,
114 | hv_genes:int=12000,
115 |                           embedding_model="ESM1b", DO_HVG:bool=False, # DO_HVG was an undefined global here
116 |                          ) -> Tuple[SincleCellDataset, AnnData]:
117 |
118 | # Subset to just genes we have embeddings for
119 | adata, protein_embeddings = load_gene_embeddings_adata(
120 | adata=adata,
121 | species=[species],
122 | embedding_model=embedding_model
123 | )
124 |
125 | if DO_HVG:
126 | sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hv_genes) # Expects Count Data
127 |
128 | hv_index = adata.var["highly_variable"]
129 | adata = adata[:, hv_index] # Subset to hv genes only
130 |
131 | protein_embeddings = protein_embeddings[species][hv_index]
132 | else:
133 | protein_embeddings = protein_embeddings[species]
134 | expression = data_to_torch_X(adata.X)
135 |
136 | covar_vals = None
137 | if len(labels) > 0:
138 | assert covar_col is None or covar_col in labels, "Covar needs to be in labels" # make sure you keep track of covar column!
139 | labels = adata.obs.loc[:, labels].values
140 |
141 | if covar_col is not None:
142 | # we have a categorical label to use as covariate
143 | covar_vals = torch.tensor(pd.Categorical(adata.obs[covar_col]).codes)
144 | return SincleCellDataset(
145 | expression=expression,
146 | protein_embeddings=protein_embeddings,
147 | labels=labels,
148 | covar_vals=covar_vals
149 | ), adata
150 |
151 | def proc(args):
152 | datasets_df = pd.read_csv(args.datasets_df)
153 | datasets_df["covar_col"] = np.nan
154 | skip = args.skip
155 | additional_filter = args.filter
156 | DO_HVG = args.DO_HVG
157 |
158 | num_genes = {}
159 | num_cells = {}
160 |
161 | ir = list(datasets_df.iterrows())
162 | for i, row in tqdm(ir, total=len(datasets_df)):
163 |         _, ncells, ngenes = process_raw_anndata(row, args.h5_folder_path, args.npz_folder_path, args.scp, skip, additional_filter, root=args.file_root_path)
164 |         if (ncells is not None) and (ngenes is not None):
165 |             num_genes[row["path"]] = ngenes
166 |             num_cells[row["path"]] = ncells
167 |
168 | if "num_cells" not in datasets_df.columns:
169 | datasets_df["num_cells"] = 0
170 | if "num_genes" not in datasets_df.columns:
171 | datasets_df["num_genes"] = 0
172 | for k in num_genes.keys():
173 | ng = num_genes[k]
174 | nc = num_cells[k]
175 | datasets_df.loc[datasets_df["path"] == k, "num_cells"] = nc
176 | datasets_df.loc[datasets_df["path"] == k, "num_genes"] = ng
177 | # Write with the cells and genes info back to the original path
178 | datasets_df.to_csv(args.datasets_df, index=False)
179 | if __name__=="__main__":
180 | # Parse command-line arguments
181 |
182 | parser = argparse.ArgumentParser(description='Preproc datasets h5ad datasets.')
183 |
184 | # Define command-line arguments
185 | parser.add_argument('--scp', type=str, default="", help='Name of a SNAP server to SCP the results to. It should have the same folders as the script is already saving to.')
186 | parser.add_argument('--h5_folder_path', type=str, default="/lfs/local/0/yanay/uce_h5s/", help='Folder to save H5s to.')
187 | parser.add_argument('--npz_folder_path', type=str, default="/lfs/local/0/yanay/uce_proc/", help='Folder to save NPZs to.')
188 |
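189 |     parser.add_argument('--file_root_path', type=str, default="", help='Root folder prepended to each dataset path; proc() reads args.file_root_path. The empty default is an assumption.')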
190 | parser.add_argument('--datasets_df', type=str, default="/dfs/project/uce/new_perturb_datasets.csv", help='Path to datasets csv. Will be overwritten to have the correct num cells and num genes for each dataset.')
191 |
192 | parser.add_argument('--filter', type=bool, default=True, help='Should you do an additional gene/cell filtering? This can be a good step since even if you have already done it, subsetting to protein embeddings can make some cells sparser.')
193 | parser.add_argument('--skip', type=bool, default=True, help='Should you skip datasets that appear to have already been created in the h5 folder?')
194 |
195 | parser.add_argument('--DO_HVG', type=bool, default=False, help='Should a HVG subset be done.')
196 |
197 |
198 |
199 |     args = parser.parse_args()
200 |     proc(args)
201 |
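
The `--filter`, `--skip` and `--DO_HVG` flags above are declared with `type=bool`. With argparse this calls `bool()` on the raw string, and any non-empty string (including `"False"`) is truthy, so these flags cannot be switched off from the command line. A minimal sketch of one possible workaround follows; `str2bool`, `build_parser` and `--skip_naive` are hypothetical names introduced here for illustration and are not part of the repository.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical converter: map common true/false spellings to a bool."""
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Boolean flag parsing sketch.")
    # Pitfall: bool("False") is True, so type=bool cannot turn a flag off.
    parser.add_argument("--skip_naive", type=bool, default=True)
    # The converter parses the text instead of testing for emptiness.
    parser.add_argument("--skip", type=str2bool, default=True)
    return parser


args = build_parser().parse_args(["--skip_naive", "False", "--skip", "False"])
print(args.skip_naive, args.skip)  # prints: True False
```

On Python 3.9 and later, `action=argparse.BooleanOptionalAction` is another option; it generates paired `--skip` / `--no-skip` flags instead of taking a value.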
--------------------------------------------------------------------------------
/eval_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataloaders
3 |
4 | """
5 |
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 | import sys
9 | sys.path.append('../')
10 | from typing import Dict, List, Optional, Tuple, Any
11 | import torch
12 | import numpy as np
13 | import pickle
14 | import torch.utils.data as data
15 |
16 |
17 | class MultiDatasetSentences(data.Dataset):
18 | def __init__(self, sorted_dataset_names, shapes_dict, args,
19 | dataset_to_protein_embeddings_path= "/lfs/local/0/yanay/reduced_datasets_to_pe_chrom_5120_new.torch",
20 | datasets_to_chroms_path="/lfs/local/0/yanay/dataset_to_chroms_new.pkl",
21 | datasets_to_starts_path="/lfs/local/0/yanay/dataset_to_starts_new.pkl",
22 | npzs_dir="/lfs/local/0/yanay/uce_proc/") -> None:
23 | super(MultiDatasetSentences, self).__init__()
24 | # self.xs = {}
25 | self.num_cells = {}
26 | self.num_genes = {}
27 | self.shapes_dict = shapes_dict
28 | self.args = args
29 |
30 | self.total_num_cells = 0
31 | for name in sorted_dataset_names:
32 | num_cells, num_genes = self.shapes_dict[name]
33 | # self.xs[name] = X
34 | self.num_cells[name] = num_cells
35 | self.num_genes[name] = num_genes
36 |
37 | self.total_num_cells += num_cells
38 |
39 | self.datasets = sorted_dataset_names
40 |
41 | # TODO: preferably not hard-coded here
42 | self.dataset_to_protein_embeddings = torch.load(dataset_to_protein_embeddings_path)
43 | with open(datasets_to_chroms_path, "rb") as f:
44 | self.dataset_to_chroms = pickle.load(f)
45 | with open(datasets_to_starts_path, "rb") as f:
46 | self.dataset_to_starts = pickle.load(f)
47 |
48 | self.npzs_dir = npzs_dir
49 |
50 | def __getitem__(self, idx):
51 | if isinstance(idx, int):
52 | for dataset in sorted(self.datasets):
53 | if idx < self.num_cells[dataset]:
54 | #cts = np.memmap(f"/lfs/local/0/yanay/cxg_npzs/" + f"{dataset}_counts.npz",
55 | # dtype='int64', mode='r', shape=self.shapes_dict[dataset])
56 | cts = np.memmap(self.npzs_dir + f"{dataset}_counts.npz", dtype='int64', mode='r', shape=self.shapes_dict[dataset])
57 | counts = cts[idx]
58 | counts = torch.tensor(counts).unsqueeze(0)
59 | weights = torch.log1p(counts)
60 | weights = (weights / torch.sum(weights))
61 | batch_sentences, mask, seq_len, cell_sentences = \
62 | sample_cell_sentences(counts, weights, dataset, self.args,
63 | dataset_to_protein_embeddings= self.dataset_to_protein_embeddings,
64 | dataset_to_chroms=self.dataset_to_chroms,
65 | dataset_to_starts=self.dataset_to_starts)
66 | return batch_sentences, mask, idx, seq_len, cell_sentences
67 | else:
68 | idx -= self.num_cells[dataset]
69 | raise IndexError
70 | else:
71 | raise NotImplementedError
72 |
73 | def __len__(self) -> int:
74 | return self.total_num_cells
75 |
76 | def get_dim(self) -> Dict[str, int]:
77 | return self.num_genes
78 |
79 |
80 | class MultiDatasetSentenceCollator(object):
81 | def __init__(self, args):
82 | self.pad_length = args.pad_length
83 |
84 |
85 | def __call__(self, batch):
86 | batch_size = len(batch)
87 | batch_sentences = torch.zeros((batch_size, self.pad_length))
88 | mask = torch.zeros((batch_size, self.pad_length))
89 | cell_sentences = torch.zeros((batch_size, self.pad_length))
90 |
91 | idxs = torch.zeros(batch_size)
92 |
93 | i = 0
94 | max_len = 0
95 | for bs, msk, idx, seq_len, cs in batch:
96 | batch_sentences[i, :] = bs
97 | cell_sentences[i, :] = cs
98 | max_len = max(max_len, seq_len)
99 | mask[i, :] = msk
100 | idxs[i] = idx
101 |
102 | i += 1
103 |
104 | return batch_sentences[:, :max_len] , mask[:, :max_len], idxs, cell_sentences
105 |
106 |
107 |
108 | def sample_cell_sentences(counts, batch_weights, dataset, args,
109 | dataset_to_protein_embeddings,
110 | dataset_to_chroms,
111 | dataset_to_starts):
112 |
113 | dataset_idxs = dataset_to_protein_embeddings[dataset] # get the dataset specific protein embedding idxs
114 | cell_sentences = torch.zeros((counts.shape[0], args.pad_length)) # init the cell representation as 0s
115 | mask = torch.zeros((counts.shape[0], args.pad_length)) # start of masking the whole sequence
116 | chroms = dataset_to_chroms[dataset] # get the dataset specific chroms for each gene
117 | starts = dataset_to_starts[dataset] # get the dataset specific genomic start locations for each gene
118 |
119 | longest_seq_len = 0 # we need to keep track of this so we can subset the batch at the end
120 |
121 | for c, cell in enumerate(counts):
122 | weights = batch_weights[c].numpy()
123 | weights = weights / sum(weights) # RE NORM after mask
124 |
125 | # randomly choose the genes that will make up the sample, weighted by expression, with replacement
126 | choice_idx = np.random.choice(np.arange(len(weights)),
127 | size=args.sample_size, p=weights,
128 | replace=True)
129 | choosen_chrom = chroms[choice_idx] # get the sampled genes chromosomes
130 | # order the genes by chromosome
131 | chrom_sort = np.argsort(choosen_chrom)
132 | choice_idx = choice_idx[chrom_sort]
133 |
134 | # sort the genes by start
135 | new_chrom = chroms[choice_idx]
136 | choosen_starts = starts[choice_idx]
137 |
138 | ordered_choice_idx = np.full((args.pad_length),
139 | args.cls_token_idx) # start with cls
140 | # i= 0 first token is CLS
141 | i = 1 # continue on to the rest of the sequence with left bracket being assumed.
142 | # Shuffle the chroms now, there's no natural order to chromosomes
143 | uq_chroms = np.unique(new_chrom)
144 | np.random.shuffle(uq_chroms) # shuffle
145 |
146 | # This loop is actually just over one cell
147 | for chrom in uq_chroms:
148 | # Open Chrom token
149 | ordered_choice_idx[i] = int(chrom) + args.CHROM_TOKEN_OFFSET # token of this chromosome # i = 1 next token is a chrom open
150 | i += 1
151 | # now sort the genes by start order within the chroms
152 | loc = np.where(new_chrom == chrom)[0]
153 | sort_by_start = np.argsort(
154 |                 choosen_starts[loc]) # start locations for this chromosome
155 |
156 | to_add = choice_idx[loc[sort_by_start]]
157 | ordered_choice_idx[i:(i + len(to_add))] = dataset_idxs[to_add]
158 | i += len(to_add)
159 | ordered_choice_idx[i] = args.chrom_token_right_idx # add the chrom sep again
160 | i += 1 # add the closing token again
161 |
162 | longest_seq_len = max(longest_seq_len, i)
163 | remainder_len = (args.pad_length - i)
164 |
165 | cell_mask = torch.concat((torch.ones(i),
166 | # pay attention to all of these tokens, ignore the rest!
167 | torch.zeros(remainder_len)))
168 |
169 | mask[c, :] = cell_mask
170 |
171 | ordered_choice_idx[i:] = args.pad_token_idx # the remainder of the sequence
172 | cell_sentences[c, :] = torch.from_numpy(ordered_choice_idx)
173 |
174 | cell_sentences_pe = cell_sentences.long() # token indices
175 |
176 | return cell_sentences_pe, mask, longest_seq_len, cell_sentences
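
The ordering logic in `sample_cell_sentences` can be hard to follow through the index arithmetic. The sketch below restates it for a single cell using only the standard library, assuming the same steps as above: sample genes weighted by `log1p(count)`, group them by chromosome, shuffle the chromosome order, and sort genes by start position inside each chromosome block. `build_cell_sentence` is a hypothetical helper, the special-token indices copy the defaults in `eval_single_anndata.py`, and the `CHROM_TOKEN_OFFSET` value here is illustrative only.

```python
import math
import random
from collections import defaultdict

# Special-token indices follow the defaults in eval_single_anndata.py.
PAD_IDX, CHROM_RIGHT_IDX, CLS_IDX = 0, 2, 3
CHROM_TOKEN_OFFSET = 1000  # illustrative value only; the script's default is 143574


def build_cell_sentence(counts, chroms, starts, token_ids,
                        sample_size, pad_length, rng):
    """Return (sentence, mask, n_real) for one cell."""
    # Weight each gene by log1p(count), as MultiDatasetSentences.__getitem__ does.
    weights = [math.log1p(c) for c in counts]
    # Sample gene positions with replacement, proportional to the weights.
    picks = rng.choices(range(len(counts)), weights=weights, k=sample_size)

    # Group the sampled genes by chromosome.
    by_chrom = defaultdict(list)
    for gene in picks:
        by_chrom[chroms[gene]].append(gene)

    # Chromosomes have no natural order, so visit them in random order.
    chrom_order = sorted(by_chrom)
    rng.shuffle(chrom_order)

    sentence = [CLS_IDX]
    for chrom in chrom_order:
        sentence.append(chrom + CHROM_TOKEN_OFFSET)  # chromosome-open token
        genes = sorted(by_chrom[chrom], key=lambda g: starts[g])
        sentence.extend(token_ids[g] for g in genes)  # genes in start order
        sentence.append(CHROM_RIGHT_IDX)  # chromosome-close token

    n_real = len(sentence)
    if n_real > pad_length:
        raise ValueError("pad_length is too small for this sample")
    # Attend to the real tokens (1) and ignore the padding (0).
    mask = [1] * n_real + [0] * (pad_length - n_real)
    sentence = sentence + [PAD_IDX] * (pad_length - n_real)
    return sentence, mask, n_real


rng = random.Random(0)
counts = [5, 0, 3, 1, 7, 2]          # toy counts for six genes
chroms = [0, 0, 1, 1, 2, 2]          # chromosome index per gene
starts = [100, 50, 300, 10, 5, 900]  # genomic start per gene
token_ids = [10, 11, 12, 13, 14, 15]
sentence, mask, n_real = build_cell_sentence(
    counts, chroms, starts, token_ids, sample_size=8, pad_length=20, rng=rng)
```

The worst-case sentence length is `1 + sample_size + 2 * n_chromosomes`, which is why `pad_length` (default 1536) has to exceed `sample_size` (default 1024) by a comfortable margin.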
--------------------------------------------------------------------------------
/eval_single_anndata.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for Evaluating a Single AnnData
3 |
4 | Parameters:
5 | ----------
6 | - `adata_path` (str):
7 | Full path to the AnnData you want to embed.
8 | - `dir` (str):
9 | Working folder where all files will be saved.
10 | - `species` (str):
11 | Species of the AnnData.
12 | - `filter` (bool):
13 | Additional gene/cell filtering on the AnnData.
14 | - `skip` (bool):
15 | Skip datasets that appear to have already been created.
16 | - `model_loc` (str):
17 | Location of pretrained UCE model's weights in a `.torch` file.
18 | - `batch_size` (int):
19 | Batch size for processing.
20 | - `CXG` (bool):
21 | Use CXG model.
22 | - `nlayers` (int):
23 | Number of transformer layers.
24 | - `output_dim` (int):
25 | Desired output dimension.
26 | - `d_hid` (int):
27 | Hidden dimension for processing.
28 | - `token_dim` (int):
29 | Token dimension.
30 | - `spec_chrom_csv_path` (str):
31 | CSV file mapping genes from each species to their respective chromosomes
32 | and genomic start positions.
33 | - `token_file` (str):
34 | `.torch` file containing token/protein embeddings for all tokens.
35 | - `protein_embeddings_dir` (str):
36 | Directory containing protein embedding `.pt` files for all species.
37 | - `offset_pkl_path` (str):
38 | `.pkl` file mapping between species and their gene's locations in the `token_file`.
39 | - `pad_length` (int):
40 | Length to pad the cell sentence to.
41 | - `pad_token_idx` (int):
42 | Index of the padding token in the `token_file`.
43 | - `chrom_token_left_idx` (int):
44 | Left chromosome token index
45 | - `chrom_token_right_idx` (int):
46 | Right chromosome token index
47 | - `cls_token_idx` (int):
48 | CLS token index in the `token_file`.
49 | - `CHROM_TOKEN_OFFSET` (int):
50 | Offset index, tokens after this mark are chromosome identifiers.
51 | - `sample_size` (int):
52 | Number of genes sampled for cell sentence.
53 | - `multi_gpu` (bool):
54 | Run evaluation on multiple GPUs (using accelerator)
55 |
56 | Returns:
57 | -------
58 | - `dir/{dataset_name}_proc.h5ad`:
59 | The processed AnnData. Processing involves subsetting it to genes which
60 | have protein embeddings and then refiltering the dataset by minimum counts.
61 | - `dir/{dataset_name}_chroms.pkl`:
62 | File mapping the genes in the dataset to their corresponding chromosome
63 | indices.
64 | - `dir/{dataset_name}_counts.npz`:
65 | File containing the counts of the AnnData in an easily accessible format.
66 | - `dir/{dataset_name}_shapes_dict.pkl`:
67 | File containing the shape (ncell x ngene) of the AnnData, used to read the
68 | `.npz` file.
69 | - `dir/{dataset_name}_pe_idx.torch`:
70 | File mapping between the genes in the dataset and their index in the tokens file.
71 | - `dir/{dataset_name}_starts.pkl`:
72 | File mapping between the genes in the dataset and their genomic start locations.
73 |
74 | """
75 |
76 |
77 | import argparse
78 | from evaluate import AnndataProcessor
79 | from accelerate import Accelerator
80 |
81 | def main(args, accelerator):
82 | processor = AnndataProcessor(args, accelerator)
83 | processor.preprocess_anndata()
84 | processor.generate_idxs()
85 | processor.run_evaluation()
86 |
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser(
90 | description='Embed a single anndata using UCE.')
91 |
92 | # Anndata Processing Arguments
93 | parser.add_argument('--adata_path', type=str,
94 | default=None,
95 | help='Full path to the anndata you want to embed.')
96 | parser.add_argument('--dir', type=str,
97 | default="./",
98 | help='Working folder where all files will be saved.')
99 | parser.add_argument('--species', type=str, default="human",
100 | help='Species of the anndata.')
101 | parser.add_argument('--filter', type=bool, default=True,
102 | help='Additional gene/cell filtering on the anndata.')
103 | parser.add_argument('--skip', type=bool, default=True,
104 | help='Skip datasets that appear to have already been created.')
105 |
106 | # Model Arguments
107 | parser.add_argument('--model_loc', type=str,
108 | default=None,
109 | help='Location of the model.')
110 | parser.add_argument('--batch_size', type=int, default=25,
111 | help='Batch size.')
112 | parser.add_argument('--pad_length', type=int, default=1536,
113 |                         help='Length to pad the cell sentence to.')
114 | parser.add_argument("--pad_token_idx", type=int, default=0,
115 | help="PAD token index")
116 | parser.add_argument("--chrom_token_left_idx", type=int, default=1,
117 | help="Chrom token left index")
118 | parser.add_argument("--chrom_token_right_idx", type=int, default=2,
119 | help="Chrom token right index")
120 | parser.add_argument("--cls_token_idx", type=int, default=3,
121 | help="CLS token index")
122 | parser.add_argument("--CHROM_TOKEN_OFFSET", type=int, default=143574,
123 | help="Offset index, tokens after this mark are chromosome identifiers")
124 | parser.add_argument('--sample_size', type=int, default=1024,
125 | help='Number of genes sampled for cell sentence')
126 | parser.add_argument('--CXG', type=bool, default=True,
127 | help='Use CXG model.')
128 | parser.add_argument('--nlayers', type=int, default=4,
129 | help='Number of transformer layers.')
130 | parser.add_argument('--output_dim', type=int, default=1280,
131 | help='Output dimension.')
132 | parser.add_argument('--d_hid', type=int, default=5120,
133 | help='Hidden dimension.')
134 | parser.add_argument('--token_dim', type=int, default=5120,
135 | help='Token dimension.')
136 | parser.add_argument('--multi_gpu', type=bool, default=False,
137 | help='Use multiple GPUs')
138 |
139 | # Misc Arguments
140 | parser.add_argument("--spec_chrom_csv_path",
141 | default="./model_files/species_chrom.csv", type=str,
142 | help="CSV Path for species genes to chromosomes and start locations.")
143 | parser.add_argument("--token_file",
144 | default="./model_files/all_tokens.torch", type=str,
145 | help="Path for token embeddings.")
146 | parser.add_argument("--protein_embeddings_dir",
147 | default="./model_files/protein_embeddings/", type=str,
148 | help="Directory where protein embedding .pt files are stored.")
149 | parser.add_argument("--offset_pkl_path",
150 | default="./model_files/species_offsets.pkl", type=str,
151 | help="PKL file which contains offsets for each species.")
152 |
153 | args = parser.parse_args()
154 | accelerator = Accelerator(project_dir=args.dir)
155 | main(args, accelerator)
156 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # os.environ["NCCL_DEBUG"] = "INFO"
4 | os.environ["OMP_NUM_THREADS"] = "12" # same as: export OMP_NUM_THREADS=12
5 | os.environ["OPENBLAS_NUM_THREADS"] = "12" # same as: export OPENBLAS_NUM_THREADS=12
6 | os.environ["MKL_NUM_THREADS"] = "12" # same as: export MKL_NUM_THREADS=12
7 | os.environ["VECLIB_MAXIMUM_THREADS"] = "12" # same as: export VECLIB_MAXIMUM_THREADS=12
8 | os.environ["NUMEXPR_NUM_THREADS"] = "12"
9 |
10 | import warnings
11 |
12 | warnings.filterwarnings("ignore")
13 |
14 | import scanpy as sc
15 | from tqdm.auto import tqdm
16 | from torch import nn, Tensor
17 |
18 | from model import TransformerModel
19 | from eval_data import MultiDatasetSentences, MultiDatasetSentenceCollator
20 | from utils import figshare_download
21 |
22 | from torch.utils.data import DataLoader
23 | from data_proc.data_utils import adata_path_to_prot_chrom_starts, \
24 | get_spec_chrom_csv, process_raw_anndata, get_species_to_pe
25 |
26 | import os
27 | import pickle
28 | import pandas as pd
29 | import numpy as np
30 | import torch
31 |
32 |
33 | class AnndataProcessor:
34 | def __init__(self, args, accelerator):
35 | self.args = args
36 | self.accelerator = accelerator
37 | self.h5_folder_path = self.args.dir
38 | self.npz_folder_path = self.args.dir
39 | self.scp = ""
40 |
41 | # Check if paths exist, if not, create them
42 | self.check_paths()
43 |
44 | # Set up the anndata
45 | self.adata_name = self.args.adata_path.split("/")[-1]
46 | self.adata_root_path = self.args.adata_path.replace(self.adata_name, "")
47 | self.name = self.adata_name.replace(".h5ad", "")
48 | self.proc_h5_path = self.h5_folder_path + f"{self.name}_proc.h5ad"
49 | self.adata = None
50 |
51 | # Set up the row
52 | row = pd.Series()
53 | row.path = self.adata_name
54 | row.covar_col = np.nan
55 | row.species = self.args.species
56 | self.row = row
57 |
58 | # Set paths once to be used throughout the class
59 | self.pe_idx_path = self.args.dir + f"{self.name}_pe_idx.torch"
60 | self.chroms_path = self.args.dir + f"{self.name}_chroms.pkl"
61 | self.starts_path = self.args.dir + f"{self.name}_starts.pkl"
62 | self.shapes_dict_path = self.args.dir + f"{self.name}_shapes_dict.pkl"
63 |
64 | def check_paths(self):
65 | """
66 | Check if the paths exist, if not, create them
67 | """
68 | figshare_download("https://figshare.com/ndownloader/files/42706558",
69 | self.args.spec_chrom_csv_path)
70 | figshare_download("https://figshare.com/ndownloader/files/42706555",
71 | self.args.offset_pkl_path)
72 | if not os.path.exists(self.args.protein_embeddings_dir):
73 | figshare_download("https://figshare.com/ndownloader/files/42715213",
74 | 'model_files/protein_embeddings.tar.gz')
75 | figshare_download("https://figshare.com/ndownloader/files/42706585",
76 | self.args.token_file)
77 | if self.args.adata_path is None:
78 | print("Using sample AnnData: 10k pbmcs dataset")
79 | self.args.adata_path = "./data/10k_pbmcs_proc.h5ad"
80 | figshare_download(
81 | "https://figshare.com/ndownloader/files/42706966",
82 | self.args.adata_path)
83 | if self.args.model_loc is None:
84 | print("Using sample 4 layer model")
85 | self.args.model_loc = "./model_files/4layer_model.torch"
86 | figshare_download(
87 | "https://figshare.com/ndownloader/files/42706576",
88 | self.args.model_loc)
89 |
90 |
91 | def preprocess_anndata(self):
92 | if self.accelerator.is_main_process:
93 | self.adata, num_cells, num_genes = \
94 | process_raw_anndata(self.row,
95 | self.h5_folder_path,
96 | self.npz_folder_path,
97 | self.scp,
98 | self.args.skip,
99 | self.args.filter,
100 | root=self.adata_root_path)
101 | if (num_cells is not None) and (num_genes is not None):
102 | self.save_shapes_dict(self.name, num_cells, num_genes,
103 | self.shapes_dict_path)
104 |
105 | if self.adata is None:
106 | self.adata = sc.read(self.proc_h5_path)
107 |
108 | def save_shapes_dict(self, name, num_cells, num_genes, shapes_dict_path):
109 | shapes_dict = {name: (num_cells, num_genes)}
110 | with open(shapes_dict_path, "wb+") as f:
111 | pickle.dump(shapes_dict, f)
112 | print("Wrote Shapes Dict")
113 |
114 | def generate_idxs(self):
115 | if self.accelerator.is_main_process:
116 | if os.path.exists(self.pe_idx_path) and \
117 | os.path.exists(self.chroms_path) and \
118 | os.path.exists(self.starts_path):
119 | print("PE Idx, Chrom and Starts files already created")
120 |
121 | else:
122 | species_to_pe = get_species_to_pe(self.args.protein_embeddings_dir)
123 | with open(self.args.offset_pkl_path, "rb") as f:
124 | species_to_offsets = pickle.load(f)
125 |
126 | gene_to_chrom_pos = get_spec_chrom_csv(
127 | self.args.spec_chrom_csv_path)
128 | dataset_species = self.args.species
129 | spec_pe_genes = list(species_to_pe[dataset_species].keys())
130 | offset = species_to_offsets[dataset_species]
131 | pe_row_idxs, dataset_chroms, dataset_pos = adata_path_to_prot_chrom_starts(
132 | self.adata, dataset_species, spec_pe_genes, gene_to_chrom_pos, offset)
133 |
134 | # Save to the temp dict
135 | torch.save({self.name: pe_row_idxs}, self.pe_idx_path)
136 | with open(self.chroms_path, "wb+") as f:
137 | pickle.dump({self.name: dataset_chroms}, f)
138 | with open(self.starts_path, "wb+") as f:
139 | pickle.dump({self.name: dataset_pos}, f)
140 |
141 | def run_evaluation(self):
142 | self.accelerator.wait_for_everyone()
143 | with open(self.shapes_dict_path, "rb") as f:
144 | shapes_dict = pickle.load(f)
145 | run_eval(self.adata, self.name, self.pe_idx_path, self.chroms_path,
146 | self.starts_path, shapes_dict, self.accelerator, self.args)
147 |
148 |
149 | def get_ESM2_embeddings(args):
150 | # Load in ESM2 embeddings and special tokens
151 | all_pe = torch.load(args.token_file)
152 | if all_pe.shape[0] == 143574:
153 | torch.manual_seed(23)
154 | CHROM_TENSORS = torch.normal(mean=0, std=1, size=(1895, args.token_dim))
155 | # 1895 is the total number of chromosome choices, it is hardcoded for now
156 | all_pe = torch.vstack(
157 | (all_pe, CHROM_TENSORS)) # Add the chrom tensors to the end
158 | all_pe.requires_grad = False
159 |
160 | return all_pe
161 |
162 |
163 | def padding_tensor(sequences):
164 | """
165 | :param sequences: list of tensors
166 | :return:
167 | """
168 | num = len(sequences)
169 | max_len = max([s.size(0) for s in sequences])
170 | out_dims = (num, max_len, 1280)
171 |
172 | out_tensor = sequences[0].data.new(*out_dims).fill_(0)
173 | out_dims2 = (num, max_len)
174 |
175 | mask = sequences[0].data.new(*out_dims2).fill_(float('-inf'))
176 | for i, tensor in enumerate(sequences):
177 | length = tensor.size(0)
178 | out_tensor[i, :length] = tensor
179 | mask[i, :length] = 1
180 | return out_tensor.permute(1, 0, 2), mask
181 |
182 |
183 | def run_eval(adata, name, pe_idx_path, chroms_path, starts_path, shapes_dict,
184 | accelerator, args):
185 |
186 | #### Set up the model ####
187 | token_dim = args.token_dim
188 | emsize = 1280 # embedding dimension
189 | d_hid = args.d_hid # dimension of the feedforward network model in nn.TransformerEncoder
190 | nlayers = args.nlayers # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
191 | nhead = 20 # number of heads in nn.MultiheadAttention
192 | dropout = 0.05 # dropout probability
193 | model = TransformerModel(token_dim=token_dim, d_model=emsize, nhead=nhead,
194 | d_hid=d_hid,
195 | nlayers=nlayers, dropout=dropout,
196 | output_dim=args.output_dim)
197 | if args.model_loc is None:
198 | raise ValueError("Must provide a model location")
199 |     # initialize as empty
200 | empty_pe = torch.zeros(145469, 5120)
201 | empty_pe.requires_grad = False
202 | model.pe_embedding = nn.Embedding.from_pretrained(empty_pe)
203 | model.load_state_dict(torch.load(args.model_loc, map_location="cpu"),
204 | strict=True)
205 | # Load in the real token embeddings
206 | all_pe = get_ESM2_embeddings(args)
207 | # This will make sure that you don't overwrite the tokens in case you're embedding species from the training data
208 | # We avoid doing that just in case the random seeds are different across different versions.
209 | if all_pe.shape[0] != 145469:
210 | all_pe.requires_grad = False
211 | model.pe_embedding = nn.Embedding.from_pretrained(all_pe)
212 | print(f"Loaded model:\n{args.model_loc}")
213 | model = model.eval()
214 | model = accelerator.prepare(model)
215 | batch_size = args.batch_size
216 |
217 | #### Run the model ####
218 | # Dataloaders
219 | dataset = MultiDatasetSentences(sorted_dataset_names=[name],
220 | shapes_dict=shapes_dict,
221 | args=args, npzs_dir=args.dir,
222 | dataset_to_protein_embeddings_path=pe_idx_path,
223 | datasets_to_chroms_path=chroms_path,
224 | datasets_to_starts_path=starts_path
225 | )
226 | multi_dataset_sentence_collator = MultiDatasetSentenceCollator(args)
227 |
228 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
229 | collate_fn=multi_dataset_sentence_collator,
230 | num_workers=0)
231 | dataloader = accelerator.prepare(dataloader)
232 | pbar = tqdm(dataloader, disable=not accelerator.is_local_main_process)
233 | dataset_embeds = []
234 | with torch.no_grad():
235 | for batch in pbar:
236 | batch_sentences, mask, idxs = batch[0], batch[1], batch[2]
237 | batch_sentences = batch_sentences.permute(1, 0)
238 | if args.multi_gpu:
239 | batch_sentences = model.module.pe_embedding(batch_sentences.long())
240 | else:
241 | batch_sentences = model.pe_embedding(batch_sentences.long())
242 | batch_sentences = nn.functional.normalize(batch_sentences,
243 | dim=2) # Normalize token outputs now
244 | _, embedding = model.forward(batch_sentences, mask=mask)
245 | # Fix for duplicates in last batch
246 | accelerator.wait_for_everyone()
247 | embeddings = accelerator.gather_for_metrics((embedding))
248 | if accelerator.is_main_process:
249 | dataset_embeds.append(embeddings.detach().cpu().numpy())
250 |
251 | accelerator.wait_for_everyone()
252 | if accelerator.is_main_process:
253 | dataset_embeds = np.vstack(dataset_embeds)
254 | adata.obsm["X_uce"] = dataset_embeds
255 | write_path = args.dir + f"{name}_uce_adata.h5ad"
256 | adata.write(write_path)
257 |
258 | print("*****Wrote Anndata to:*****")
259 | print(write_path)
260 |
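
`AnndataProcessor` builds every output path by string concatenation (`self.args.dir + f"{self.name}_proc.h5ad"`), so a `--dir` value without a trailing slash silently produces paths such as `resultsmydata_proc.h5ad`. The sketch below shows one way the same paths could be derived with `os.path`; `derive_output_paths` is a hypothetical helper, not a function in this repository, and it covers only the suffixes that appear in `__init__` above.

```python
import os


def derive_output_paths(adata_path: str, out_dir: str) -> dict:
    """Hypothetical helper: build the per-dataset output paths."""
    file_name = os.path.basename(adata_path)   # e.g. "10k_pbmcs.h5ad"
    name, _ext = os.path.splitext(file_name)   # e.g. "10k_pbmcs"
    suffixes = {
        "proc_h5": "_proc.h5ad",
        "pe_idx": "_pe_idx.torch",
        "chroms": "_chroms.pkl",
        "starts": "_starts.pkl",
        "shapes_dict": "_shapes_dict.pkl",
    }
    # os.path.join inserts the separator only when out_dir lacks one.
    return {key: os.path.join(out_dir, name + suffix)
            for key, suffix in suffixes.items()}


paths = derive_output_paths("data/10k_pbmcs.h5ad", "results")
for key, value in paths.items():
    print(key, value)
```

With this approach `--dir results` and `--dir results/` resolve to the same files, which the concatenation-based version does not guarantee.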
--------------------------------------------------------------------------------
/examples/Label Transfer Using Logistic Classifier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "3f4f1b19-5369-4e4d-9366-b6f07f88b402",
6 | "metadata": {},
7 | "source": [
8 | "# Transferring Labels Using UCE\n",
9 | "\n",
10 | "This notebook walks through the example from Figure 4d,4e of transferring labels from mouse kidney norn cells to a human lung disease dataset.\n",
11 | "\n",
12 | "To transfer labels, we use a basic default implementation of sklearn's logistic classifier."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "5ca49083-fd91-473f-b60a-621b07d52de2",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "## Imports\n",
23 | "import scanpy as sc\n",
24 | "import numpy as np\n",
25 | "import random\n",
26 | "from sklearn.linear_model import LogisticRegression\n",
27 | "sc._settings.settings._vector_friendly=True\n",
28 | "import matplotlib\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "\n",
31 | "## Seed\n",
32 | "np.random.seed(0)\n",
33 | "random.seed(0)"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "id": "72536f9e-b010-44a7-b323-c32f05cf7d98",
39 | "metadata": {},
40 | "source": [
41 | "## Load in anndatas\n",
42 | "You can download the anndatas here: https://drive.google.com/drive/folders/1f63fh0ykgEhCrkd_EVvIootBw7LYDVI7"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 2,
48 | "id": "8e5d7a7e-2a86-4fce-82fe-17afcf83dec5",
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "epo_uce = sc.read(\"mouse_kidney_norn.h5ad\")\n",
53 | "kam_20_uce = sc.read(\"human_lung_disease.h5ad\")"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "id": "dcbf764a-ad99-4622-a780-ecd62e471132",
59 | "metadata": {},
60 | "source": [
61 | "### Train Classifier on Mouse Kidney Cells\n",
62 | "\n",
63 | "We train a classifier to predict coarsened cell types from the UCE embeddings."
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "id": "9e56e9aa-bf4a-42d9-b8a6-e76833161083",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "epo_map = {\n",
74 | " \"Norn\":\"Norn\",\n",
75 | " \"Proximal tubule\":\"Proximal tubule\",\n",
76 | " \"Collecting duct principal\":\"Collecting duct\",\n",
77 | " \"Distal convoluted tubule\":\"Distal convoluted tubule\",\n",
78 | " \"Fibroblasts\":\"Fibroblast\",\n",
79 | " \"Endothelial\":\"Endothelial\",\n",
80 | " \"Collecting duct transient\":\"Collecting duct\",\n",
81 | " \"Other\":\"misc\",\n",
82 | " \"Pericyte Ren1+\":\"Pericyte\",\n",
83 | " \"Podocytes\":\"Podocyte\",\n",
84 | " \"Pericyte3\":\"Pericyte\",\n",
85 | " \"Pericyte1\":\"Pericyte\",\n",
86 | " \"Pericyte2\":\"Pericyte\",\n",
87 | " \"Collecting duct intercalated\":\"Collecting duct\",\n",
88 | " \"Loop of henle\":\"Loop of henle\",\n",
89 | " \"Proximal tubule2\":\"Proximal tubule\",\n",
90 | " \"Macrophages\":\"Macrophage\",\n",
91 | " \"Neutrophil\":\"Granulocyte\",\n",
92 | " \"T lymphocyte\":\"T cell\",\n",
93 | " \"Collecting duct\":\"Collecting duct\",\n",
94 | " \"Monocytes\":\"Monocyte\",\n",
95 | " \n",
96 | "} # coarse cell type map"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 4,
102 | "id": "39b160cd-4462-437f-b89e-7363e20e8ebe",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "epo_uce_no_misc = epo_uce[epo_uce.obs.group != \"Other\"] # remove misc cells\n",
107 | "X = epo_uce_no_misc.obsm[\"X_uce\"] # input is UCE embeddings\n",
108 | "y = [epo_map[ct] for ct in epo_uce_no_misc.obs[\"group\"].values] # output is mapped cell types\n",
109 | "clf = LogisticRegression(random_state=0).fit(X, y) # fit classifier"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "id": "ae2fd8b6-e681-4f02-9f9c-71d71d95f925",
115 | "metadata": {},
116 | "source": [
117 | "### Predict norn-like cells using classifier"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 5,
123 | "id": "722bcc74-aaca-43f9-b5dd-2427584d7683",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "kam_20_uce.obs[\"pred\"] = clf.predict(kam_20_uce.obsm[\"X_uce\"]) # predict cell types for lung disease dataset"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 6,
133 | "id": "ae31d2c0-5987-46c1-9be0-d874ae6577b5",
134 | "metadata": {},
135 | "outputs": [
136 | {
137 | "data": {
138 | "text/plain": [
139 | "pred\n",
140 | "Proximal tubule 119834\n",
141 | "T cell 93556\n",
142 | "Granulocyte 52485\n",
143 | "Collecting duct 15727\n",
144 | "Macrophage 11800\n",
145 | "Endothelial 7233\n",
146 | "Norn 6005\n",
147 | "Podocyte 4270\n",
148 | "Pericyte 1316\n",
149 | "Fibroblast 623\n",
150 | "Monocyte 56\n",
151 | "Loop of henle 23\n",
152 | "Name: count, dtype: int64"
153 | ]
154 | },
155 | "execution_count": 6,
156 | "metadata": {},
157 | "output_type": "execute_result"
158 | }
159 | ],
160 | "source": [
161 | "kam_20_uce.obs[\"pred\"].value_counts()"
162 | ]
163 | },
164 | {
165 | "cell_type": "markdown",
166 | "id": "23a839f8-88ce-4446-b4a5-961a38a264b5",
167 | "metadata": {},
168 | "source": [
169 | "## Check Differential Expression"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 7,
175 | "id": "159cc172-1be5-4ad2-a4b5-2ebd2e52302e",
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "# Preprocess count values\n",
180 | "sc.pp.highly_variable_genes(kam_20_uce, n_top_genes=8000, flavor=\"seurat_v3\", subset=True)\n",
181 | "sc.pp.normalize_per_cell(kam_20_uce)\n",
182 | "sc.pp.log1p(kam_20_uce)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 8,
188 | "id": "57b9fb8a-2699-43f4-b74f-2a646e8610c8",
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "# Subset to predicted Norn-like cells\n",
193 | "kam20_norn_ad = kam_20_uce[kam_20_uce.obs.pred == \"Norn\"].copy()"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 9,
199 | "id": "591d1aa7-c4ad-4ad9-9bef-2ed91ed19b57",
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "all_de_dfs = {}\n",
204 | "ngenes = 4"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 10,
210 | "id": "bb5bc067-c588-4a34-a8f3-138824531828",
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "sc.tl.rank_genes_groups(kam20_norn_ad, groupby=\"Disease_Identity\", use_raw=False, reference=\"Control\") # DE, diseases vs control"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 11,
220 | "id": "db176232-0dee-4c1a-bda2-38a19a7afcdd",
221 | "metadata": {},
222 | "outputs": [],
223 | "source": [
224 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"COPD\") # get COPD vs control results\n",
225 | "all_de_dfs[\"copd_vs_control\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n",
226 | "copd_control_genes = list(de_df.head(ngenes)[\"names\"].values)"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 12,
232 | "id": "c4d25f7a-4056-4e03-bcd1-9ab9cb2705e6",
233 | "metadata": {},
234 | "outputs": [],
235 | "source": [
236 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"IPF\") # get IPF vs control results\n",
237 | "all_de_dfs[\"ipf_vs_control\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n",
238 | "ipf_control_genes = list(de_df.head(ngenes)[\"names\"].values)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 13,
244 | "id": "4e72311d-44ee-4ca5-8abc-24e9079b574b",
245 | "metadata": {},
246 | "outputs": [],
247 | "source": [
248 | "sc.tl.rank_genes_groups(kam20_norn_ad, groupby=\"Disease_Identity\", use_raw=False, reference=\"IPF\") # DE, all vs IPF"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 14,
254 | "id": "69be495e-3b0f-41f3-95d8-47a526f00bbd",
255 | "metadata": {},
256 | "outputs": [],
257 | "source": [
258 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"COPD\") # COPD vs IPF\n",
259 | "all_de_dfs[\"copd_vs_ipf\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n",
260 | "copd_ipf_genes = list(de_df.head(ngenes)[\"names\"].values)"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": 15,
266 | "id": "db9e2358-7431-49df-a980-ff4869d824d4",
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "sc.tl.rank_genes_groups(kam20_norn_ad, groupby=\"Disease_Identity\", use_raw=False, reference=\"COPD\") # DE, all vs COPD"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 16,
276 | "id": "753b9891-f940-46b7-9fa2-8972b6f67136",
277 | "metadata": {},
278 | "outputs": [],
279 | "source": [
280 | "de_df = sc.get.rank_genes_groups_df(kam20_norn_ad, group=\"IPF\") # IPF vs COPD\n",
281 | "all_de_dfs[\"ipf_vs_copd\"] = de_df[~de_df.index.isin(de_df.iloc[10:-10].index)] # top 10 and bottom 10 genes\n",
282 | "ipf_copd_genes = list(de_df.head(ngenes)[\"names\"].values)"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": 17,
288 | "id": "647407f4-c9ef-405e-9742-342e31664497",
289 | "metadata": {},
290 | "outputs": [
291 | {
292 | "data": {
293 | "text/plain": [
294 | "['POSTN',\n",
295 | " 'COL1A1',\n",
296 | " 'COL3A1',\n",
297 | " 'SPARC',\n",
298 | " 'LUM',\n",
299 | " 'MFAP4',\n",
300 | " 'PTGDS',\n",
301 | " 'PTPRG',\n",
302 | " 'GPX3',\n",
303 | " 'NAMPT',\n",
304 | " 'RPL41',\n",
305 | " 'CRISPLD2',\n",
306 | " 'SERPINH1',\n",
307 | " 'COL1A2']"
308 | ]
309 | },
310 | "execution_count": 17,
311 | "metadata": {},
312 | "output_type": "execute_result"
313 | }
314 | ],
315 | "source": [
316 | "gene_list = ipf_control_genes + copd_control_genes + copd_ipf_genes + ipf_copd_genes\n",
317 | "\n",
318 | "reduced_gene_list = []\n",
319 | "for g in gene_list:\n",
320 | " if g in reduced_gene_list:\n",
321 | " continue\n",
322 | " else:\n",
323 | " reduced_gene_list.append(g)\n",
324 | "reduced_gene_list"
325 | ]
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "id": "c2abdd88-2f7b-4ac4-961a-ea417f85bb04",
330 | "metadata": {},
331 | "source": [
332 | "## Plot Results"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": 18,
338 | "id": "f11581f2-2895-419e-85da-0f53c07756ec",
339 | "metadata": {},
340 | "outputs": [
341 | {
342 | "data": {
343 | "image/png": "",
344 | "text/plain": [
345 | ""
346 | ]
347 | },
348 | "metadata": {},
349 | "output_type": "display_data"
350 | }
351 | ],
352 | "source": [
353 | "fig, ax = plt.subplots(1,1, figsize=(5, 5))\n",
354 | "sc.pl.dotplot(kam20_norn_ad, groupby=\"Disease_Identity\", var_names=reduced_gene_list, show=True, swap_axes=True, ax=ax)"
355 | ]
356 | }
357 | ],
358 | "metadata": {
359 | "kernelspec": {
360 | "display_name": "Python 3 (ipykernel)",
361 | "language": "python",
362 | "name": "python3"
363 | },
364 | "language_info": {
365 | "codemirror_mode": {
366 | "name": "ipython",
367 | "version": 3
368 | },
369 | "file_extension": ".py",
370 | "mimetype": "text/x-python",
371 | "name": "python",
372 | "nbconvert_exporter": "python",
373 | "pygments_lexer": "ipython3",
374 | "version": "3.11.7"
375 | }
376 | },
377 | "nbformat": 4,
378 | "nbformat_minor": 5
379 | }
380 |
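The gene-list cell in the notebook above removes repeated genes while keeping the order in which they first appear. A minimal sketch of the same idea in plain Python, assuming the inputs are ordinary lists of gene-name strings. `dedupe_preserve_order` is a hypothetical helper name and is not part of this repository:

```python
def dedupe_preserve_order(items):
    """Return the items with duplicates removed, keeping first-seen order.

    Hypothetical helper for illustration only. It relies on dicts
    preserving insertion order (guaranteed since Python 3.7).
    """
    return list(dict.fromkeys(items))


# Toy input standing in for the concatenated per-comparison gene lists.
gene_list = ["POSTN", "COL1A1", "POSTN", "SPARC", "COL1A1"]
reduced_gene_list = dedupe_preserve_order(gene_list)
print(reduced_gene_list)
```

Compared with the explicit loop, this avoids a linear `in` lookup on a list for every gene, which only matters for much longer lists than the ones used here.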
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Model class
3 |
4 | """
5 |
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 | import math
9 | from torch import nn, Tensor
10 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
11 |
12 | import sys
13 | sys.path.append('../')
14 | from typing import Any
15 | import torch
16 |
17 |
18 | def full_block(in_features, out_features, p_drop=0.1):
19 | return nn.Sequential(
20 | nn.Linear(in_features, out_features, bias=True),
21 | nn.LayerNorm(out_features),
22 | nn.GELU(),
23 | nn.Dropout(p=p_drop),
24 | )
25 |
26 |
27 | class PositionalEncoding(nn.Module):
28 |
29 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1536):
30 | super().__init__()
31 | self.dropout = nn.Dropout(p=dropout)
32 |
33 | position = torch.arange(max_len).unsqueeze(1)
34 | div_term = torch.exp \
35 | (torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
36 | pe = torch.zeros(max_len, 1, d_model)
37 | pe[:, 0, 0::2] = torch.sin(position * div_term)
38 | pe[:, 0, 1::2] = torch.cos(position * div_term)
39 | self.register_buffer('pe', pe)
40 |
41 | def forward(self, x: Tensor) -> Tensor:
42 | """
43 | Args:
44 | x: Tensor, shape [seq_len, batch_size, embedding_dim]
45 | """
46 | x = x + self.pe[:x.size(0)]
47 | return self.dropout(x)
48 |
49 |
50 | class TransformerModel(nn.Module):
51 |
52 | def __init__(self, token_dim: int, d_model: int, nhead: int, d_hid: int,
53 | nlayers: int, output_dim:int, dropout: float = 0.05):
54 | super().__init__()
55 | self.model_type = 'Transformer'
56 | self.pos_encoder = PositionalEncoding(d_model, dropout)
57 | self.d_model = d_model
58 |
59 | self.encoder = nn.Sequential(nn.Linear(token_dim, d_model),
60 | nn.GELU(),
61 | nn.LayerNorm(d_model))
62 |
63 |
64 |
65 | encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
66 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
67 |
68 |
69 | self.d_model = d_model
70 | self.dropout = dropout
71 |
72 |
73 | self.decoder = nn.Sequential(full_block(d_model, 1024, self.dropout),
74 | full_block(1024, output_dim, self.dropout),
75 | full_block(output_dim, output_dim, self.dropout),
76 | nn.Linear(output_dim, output_dim)
77 | )
78 |
79 | self.binary_decoder = nn.Sequential(
80 | full_block(output_dim + 1280, 2048, self.dropout),
81 | full_block(2048, 512, self.dropout),
82 | full_block(512, 128, self.dropout),
83 | nn.Linear(128, 1)
84 | )
85 |
86 | self.gene_embedding_layer = nn.Sequential(nn.Linear(token_dim, d_model),
87 | nn.GELU(),
88 | nn.LayerNorm(d_model))
89 |
90 | self.pe_embedding = None
91 |
92 | def forward(self, src: Tensor, mask: Tensor):
93 | """
94 | Args:
95 | src: Tensor, shape [seq_len, batch_size, token_dim]
96 | Returns:
97 | (gene_output of shape [seq_len, batch_size, output_dim], embedding of shape [batch_size, output_dim])
98 | """
99 | src = self.encoder(src) * math.sqrt(self.d_model)
100 | src = self.pos_encoder(src)
101 | output = self.transformer_encoder(src, src_key_padding_mask=(1 - mask))
102 | gene_output = self.decoder(output) # seq_len x batch_size x output_dim
103 | # embedding = torch.mul(gene_output, mask.t().unsqueeze(2)).sum(0) # average over non zero genes
104 | # In the new format, the cls token, which is at the 0 index mark, is the output.
105 | embedding = gene_output[0, :, :] # select only the CLS token.
106 | embedding = nn.functional.normalize(embedding, dim=1) # Normalize.
107 | return gene_output, embedding
108 |
109 |
110 | def predict(self, cell_embedding, gene_embeddings):
111 | gene_embeddings = self.gene_embedding_layer(gene_embeddings)
112 | dec = self.binary_decoder \
113 | (torch.hstack((cell_embedding, gene_embeddings)))
114 | return dec
115 |
116 |
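The `PositionalEncoding` module above fills even feature indices with sines and odd indices with cosines of the position scaled by a geometric frequency schedule. A minimal sketch of that formula for a single position, written in plain Python so it runs without `torch`. `sinusoidal_position` is a hypothetical name used only for illustration:

```python
import math


def sinusoidal_position(pos, d_model):
    """Sinusoidal encoding of one position as a list of length d_model.

    Mirrors the formula in PositionalEncoding.__init__: for each even
    index i, the frequency is exp(i * -ln(10000) / d_model); index i gets
    the sine and index i + 1 gets the cosine.
    """
    row = [0.0] * d_model
    for i in range(0, d_model, 2):
        freq = math.exp(i * (-math.log(10000.0) / d_model))
        row[i] = math.sin(pos * freq)
        if i + 1 < d_model:  # guard for odd d_model in this sketch
            row[i + 1] = math.cos(pos * freq)
    return row


# Position 0 encodes as alternating sin(0) = 0 and cos(0) = 1.
print(sinusoidal_position(0, 4))
```

In the module itself the full table for `max_len` positions is precomputed once and stored with `register_buffer`, so it moves with the model across devices without being a trainable parameter.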
--------------------------------------------------------------------------------
/model_files/new_species_protein_embeddings.csv:
--------------------------------------------------------------------------------
1 | species,path
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.26.4
2 | scipy==1.14.1
3 | pandas==2.2.2
4 | tqdm==4.66.5
5 | torch==2.1.1
6 | scanpy==1.10.2
7 | accelerate==0.24.0
8 | requests==2.25.1
9 | urllib3==1.26.6
10 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils
3 |
4 | """
5 |
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 | import pandas as pd
9 | import numpy as np
10 | import os
11 | import requests
12 | from tqdm import tqdm
13 | import tarfile
14 |
15 |
16 | def get_shapes_dict(dataset_path):
17 | shapes_dict = {}
18 | datasets_df = pd.read_csv(dataset_path)
19 | sorted_dataset_names = sorted(datasets_df["names"])
20 |
21 | for name in sorted_dataset_names:
22 | shapes_dict[name] = (int(datasets_df.set_index("names").loc[name]["num_cells"]), 8000)
23 |
24 | shapes_dict["dev_immune_mouse"] = (443697, 4786)
25 | shapes_dict["dev_immune_human"] = (34009, 5566)
26 | shapes_dict["intestinal_tract_human"] = (69668, 5192)
27 | shapes_dict["gtex_human"] = (18511, 7109)
28 | shapes_dict["gut_endoderm_mouse"] = (113043, 6806)
29 | shapes_dict["luca"] = (249591, 7196)
30 | shapes_dict.update({
31 | "madissoon_novel_lung":(190728, 8000),
32 | 'flores_cerebellum_human': (20232, 8000),
33 | 'osuch_gut_human': (272310, 8000),
34 | 'msk_ovarian_human': (929690, 8000),
35 | 'htan_vmuc_dis_epi_human': (65084, 8000),
36 | 'htan_vmuc_val_epi_human': (57564, 8000),
37 | 'htan_vmuc_non_epi_human': (9099, 8000),
38 | 'hao_pbmc_3p_human': (161764, 8000),
39 | 'hao_pbmc_5p_human': (49147, 8000),
40 | 'gao_tumors_human': (36111, 8000),
41 | 'swabrick_breast_human': (92427, 8000),
42 | 'wu_cryo_tumors_human': (105662, 8000),
43 | 'cell_line_het_human': (53513, 8000),
44 | 'bi_allen_metastasis_human': (27787, 8000),
45 | 'zheng68k_human': (68579, 8000),
46 | 'zheng68k_12k_human': (68579, 12000),
47 | 'mouse_embryo_ct': (153597, 12000),
48 | "regev_gtex_heart": (36574, 8000),
49 | "tabula_sapiens_heart": (11505, 8000),
50 | "10k_pbmcs":(11990, 12000),
51 | "epo_ido":(35834,12000),
52 | 'tabula_sapiens_kidney': (9641, 8000),
53 | 'tabula_microcebus_kidney': (14592, 8000),
54 | 'tabula_muris_kidney': (2781, 8000),
55 | 'tabula_muris_senis_kidney': (19610, 8000),
56 | 'immune_human': (33506, 8000)
57 | })
58 |
59 | shapes_dict["zyl_sanes_glaucoma_pig"] = (5901, 6819)
60 | shapes_dict["parkinsons_macaF"] = (1062, 5103)
61 |
62 | for row in datasets_df.iterrows():
63 | ngenes = row[1].num_genes
64 | ncells = row[1].num_cells
65 | name = row[1].names
66 | if not np.isnan(ngenes):
67 | shapes_dict[name] = (int(ncells), int(ngenes))
68 |
69 | return shapes_dict
70 |
71 |
72 | def figshare_download(url, save_path):
73 | """
74 | Figshare download helper with progress bar
75 |
76 | Args:
77 | url (str): the url of the dataset
78 | save_path (str): the local path to save the dataset to
79 | """
80 |
81 | if os.path.exists(save_path):
82 | return
83 | else:
84 | # Check if directory exists
85 | if not os.path.exists(os.path.dirname(save_path)):
86 | os.makedirs(os.path.dirname(save_path))
87 | print("Downloading " + save_path + " from " + url + " ..." + "\n")
88 | response = requests.get(url, stream=True)
89 | total_size_in_bytes = int(response.headers.get('content-length', 0))
90 | block_size = 1024
91 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB',
92 | unit_scale=True)
93 | with open(save_path, 'wb') as file:
94 | for data in response.iter_content(block_size):
95 | progress_bar.update(len(data))
96 | file.write(data)
97 | progress_bar.close()
98 |
99 | # If the downloaded filename ends in .tar.gz then extract it
100 | if save_path.endswith(".tar.gz"):
101 | with tarfile.open(save_path) as tar:
102 | tar.extractall(path=os.path.dirname(save_path))
103 | print("Done!")
104 |
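The extraction step in `figshare_download` unpacks a downloaded `.tar.gz` next to the archive. A minimal sketch of a more defensive variant, assuming a Python build that ships tarfile extraction filters (3.12, or a patched release of an earlier version) and falling back to plain extraction otherwise. `extract_tar_gz` is a hypothetical helper and is not part of this repository:

```python
import os
import tarfile


def extract_tar_gz(archive_path, dest_dir):
    """Extract a .tar.gz archive into dest_dir and return dest_dir.

    Hypothetical helper for illustration. When the running Python
    provides tarfile.data_filter, the "data" filter is used so that
    members with absolute paths or paths escaping dest_dir are rejected.
    """
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest_dir, filter="data")
        else:
            # Older interpreters: same behavior as the original helper.
            tar.extractall(path=dest_dir)
    return dest_dir
```

A call such as `extract_tar_gz(save_path, os.path.dirname(save_path))` would mirror what the helper above does after a download finishes.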
--------------------------------------------------------------------------------