├── .gitignore
├── LICENSE.txt
├── README.md
├── data
│   └── README.md
├── env.yml
├── jupyter_nb
│   ├── .gitignore
│   ├── Process_Breast.ipynb
│   └── scGraph-minimal.ipynb
├── meta
│   ├── COVID
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── brain
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── breast
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── eye
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── gut_fetal
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── heart
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── lung
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── lung_fetal_donor
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── lung_fetal_organoid
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── pancreas
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   └── skin
│       ├── batch2cat.json
│       └── cell2cat.json
├── res
│   └── scGraph
│       └── lung.csv
├── scripts
│   ├── Lung_Mixup.log
│   ├── Lung_scGraph.log
│   ├── _Islander_MixUp.sh
│   ├── _Islander_SCL.sh
│   ├── _Islander_Triplet.sh
│   ├── _download_data.sh
│   ├── _scBenchmark.sh
│   ├── _scGraph.sh
│   ├── _scIB_Geneformer.sh
│   └── _scIB_Islander.sh
├── src
│   ├── ArgParser.py
│   ├── Data_Handler.py
│   ├── Utils_Handler.py
│   ├── Vis_Handler.py
│   ├── __init__.py
│   ├── scBenchmarker.py
│   ├── scDataset.py
│   ├── scFinetuner.py
│   ├── scGraph.py
│   ├── scLoss.py
│   ├── scModel.py
│   └── scTrain.py
└── teaser.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode
.DS_Store

data
models
src/__pycache__/
wandb
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Genentech, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Islander
This repository is the official implementation for the paper **Metric Mirages in Cell Embeddings**.

Please contact wang.hanchen@gene.com or hanchenw@cs.stanford.edu if you have any questions.
![teaser](teaser.png)

### Citation

```bibtex
@article {Islander,
    author = {Hanchen Wang and Jure Leskovec and Aviv Regev},
    title = {Metric Mirages in Cell Embeddings},
    doi = {10.1101/2024.04.02.587824},
    publisher = {Cold Spring Harbor Laboratory},
    URL = {https://www.biorxiv.org/content/early/2024/04/02/2024.04.02.587824},
    journal = {bioRxiv},
    year = {2024},
}
```

---

### Usage

We include scripts and logs to reproduce the results in the [scripts](scripts) folder. You can also follow the step-by-step instructions below:

**Step 0**: Set up the environment.

```bash
conda env create -f env.yml
```

**NOTE**: The default setup uses GPU-compiled packages (for PyTorch, JAXlib, *etc*.). Please adjust them according to your local CUDA version, *or* switch to the CPU versions as needed. The calculation of scGraph scores does not require GPU access.

**Step 1**: Preprocessing. Data can be downloaded from:

| Brain | Breast | COVID | Eye | FetalGut | FetalLung | Heart | Lung | Pancreas | Skin |
| ----- | ------ | ----- | --- | -------- | --------- | ----- | ---- | -------- | ---- |
| [Paper](https://www.science.org/doi/10.1126/science.add7046) | [Paper](https://www.nature.com/articles/s41586-023-06252-9) | [Paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7402042/) | [Paper](https://www.sciencedirect.com/science/article/pii/S2666979X22001069?via%3Dihub) | [Paper](https://www.sciencedirect.com/science/article/pii/S1534580720308868?via%3Dihub) | [Paper](https://linkinghub.elsevier.com/retrieve/pii/S0092867422014155) | [Paper](https://www.nature.com/articles/s44161-022-00183-w) | [Paper](https://www.nature.com/articles/s41591-023-02327-2) | [Paper](https://www.nature.com/articles/s41592-021-01336-8) | [Paper](https://www.nature.com/articles/s42003-020-0922-4) |
| [Data](https://cellxgene.cziscience.com/collections/283d65eb-dd53-496d-adb7-7570c7caa443) | [Data](https://cellxgene.cziscience.com/collections/4195ab4c-20bd-4cd3-8b3d-65601277e731) | [Data](https://atlas.fredhutch.org/fredhutch/covid/) | [Data](https://cellxgene.cziscience.com/collections/348da6dc-5bf6-435d-adc5-37747b9ae38a) | [Data](https://cellxgene.cziscience.com/collections/17481d16-ee44-49e5-bcf0-28c0780d8c4a) | [Data](https://cellxgene.cziscience.com/collections/2d2e2acd-dade-489f-a2da-6c11aa654028) | [Data](https://cellxgene.cziscience.com/collections/43b45a20-a969-49ac-a8e8-8c84b211bd01) | [Data](https://cellxgene.cziscience.com/collections/6f6d381a-7701-4781-935c-db10d30de293) | [Data](https://figshare.com/articles/dataset/Benchmarking_atlas-level_data_integration_in_single-cell_genomics_-_integration_task_datasets_Immune_and_pancreas_/12420968?file=24539828) | [Data](https://cellxgene.cziscience.com/collections/c353707f-09a4-4f12-92a0-cb741e57e5f0) |
We applied quality control to each dataset by filtering out cell profiles with fewer than 1,000 reads or fewer than 500 detected genes. Genes present in fewer than five cells were also excluded. Normalization was performed using Scanpy, where each cell’s read counts were scaled to a total of 10,000, followed by a log1p transformation:

```python
# download via "wget -O data/breast/local.h5ad https://datasets.cellxgene.cziscience.com/b8b5be07-061b-4390-af0a-f9ced877a068.h5ad"
adata = sc.read_h5ad(dh.DATA_RAW_["breast"])
adata.X = adata.raw.X
adata.layers["raw_counts"] = adata.raw.X
del adata.raw
uh.preprocess(adata)

[Output]
filtered out 9954 cells that have less than 1000 counts
filtered out 865 cells that have less than 500 genes expressed
filtered out 3803 genes that are detected in less than 5 cells
=============================================================================
29431 genes x 703512 cells after quality control.
=============================================================================
normalizing by total count per cell
    finished (0:00:06): normalized adata.X and added 'n_counts', counts per cell before normalization (adata.obs)
```

The top 1000 highly variable genes are selected through:

```python
sc.pp.highly_variable_genes(adata, subset=True, flavor="seurat_v3", n_top_genes=1000)
```

Then metadata is saved as JSON files. See the minimal example: [Process_Breast.ipynb](jupyter_nb/Process_Breast.ipynb).

**Step 2**: Run Islander and benchmark with scIB

```bash
cd ${HOME}/Islander/src

export LR=0.001
export EPOCH=10
export MODE="mixup"
export LEAKAGE=16
export MLPSIZE="128 128"
export DATASET_List=("lung" "lung_fetal_donor" "lung_fetal_organoid" \
    "brain" "breast" "heart" "eye" "gut_fetal" "skin" "COVID" "pancreas")

for DATASET in "${DATASET_List[@]}"; do
    export PROJECT="_${DATASET}_"
    export SavePrefix="${HOME}/Islander/models/${PROJECT}"
    export RUNNAME="MODE-${MODE}-ONLY_LEAK-${LEAKAGE}_MLP-${MLPSIZE}"
    echo "DATASET-${DATASET}_${RUNNAME}"
    mkdir -p $SavePrefix

    # === Training ===
    python scTrain.py \
        --gpu 3 \
        --lr ${LR} \
        --mode ${MODE} \
        --epoch ${EPOCH} \
        --dataset ${DATASET} \
        --leakage ${LEAKAGE} \
        --project ${PROJECT} \
        --mlp_size ${MLPSIZE} \
        --runname "${RUNNAME}" \
        --savename "${SavePrefix}/${RUNNAME}";

    # === Benchmarking ===
    python scBenchmarker.py \
        --islander \
        --saveadata \
        --dataset "${DATASET}" \
        --save_path "${SavePrefix}/${RUNNAME}";
done
```

We have also provided variants of Islander that use different forms of semi-supervised learning losses (triplet and supervised contrastive loss). See [scripts/_Islander_SCL.sh](scripts/_Islander_SCL.sh) and [scripts/_Islander_Triplet.sh](scripts/_Islander_Triplet.sh) for details; a schematic of the contrastive objective is sketched below.
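For intuition, here is a minimal sketch of a supervised contrastive objective of the kind these variants use. It is an illustration only, not the exact Islander loss (see [src/scLoss.py](src/scLoss.py) for the real implementations); the function name, temperature, and toy tensors below are ours.

```python
# Minimal sketch of a supervised contrastive loss over cell embeddings.
# Illustration only -- the losses Islander actually trains with live in src/scLoss.py.
import torch
import torch.nn.functional as F


def supcon_loss(emb: torch.Tensor, labels: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """emb: (N, D) cell embeddings; labels: (N,) integer cell-type ids."""
    z = F.normalize(emb, dim=1)                      # compare directions, not magnitudes
    sim = z @ z.T / temperature                      # (N, N) similarity logits
    self_mask = torch.eye(len(z), dtype=torch.bool, device=z.device)
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    sim = sim.masked_fill(self_mask, float("-inf"))  # never contrast a cell with itself
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    pos_counts = pos_mask.sum(1)
    valid = pos_counts > 0                           # anchors with at least one positive
    pos_log_prob = torch.where(pos_mask, log_prob, torch.zeros_like(log_prob)).sum(1)
    return (-pos_log_prob[valid] / pos_counts[valid]).mean()


emb, labels = torch.randn(64, 16), torch.randint(0, 5, (64,))  # e.g., 16-dim "leakage" embeddings
print(supcon_loss(emb, labels))
```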
**Step 3**: Run integration methods and benchmark with scIB

```bash
cd ${HOME}/Islander/src

export DATASET="lung_fetal_donor"
export CUDA_VISIBLE_DEVICES=2
echo -e "\n\n"

# === full gene set ===
echo "DATASET-${DATASET}_FULL"
python scBenchmarker.py \
    --all \
    --saveadata \
    --dataset "${DATASET}" \
    --savecsv "${DATASET}_FULL" \
    --save_path "${HOME}/Islander/models/_${DATASET}_/MODE-mixup-ONLY_LEAK-16_MLP-128 128";

# === highly variable genes ===
echo "DATASET-${DATASET}_HVG"
python scBenchmarker.py \
    --all \
    --highvar \
    --saveadata \
    --dataset "${DATASET}" \
    --savecsv "${DATASET}_HVG" \
    --save_path "${HOME}/Islander/models/_${DATASET}_/MODE-mixup-ONLY_LEAK-16_MLP-128 128";
```

**Step 4**: Run and benchmark foundation models

Please refer to the authors' original tutorials ([scGPT](https://github.com/bowang-lab/scGPT/tree/main/tutorials/zero-shot), [Geneformer](https://huggingface.co/ctheodoris/Geneformer/tree/main/examples), [scFoundation](https://github.com/biomap-research/scFoundation/tree/main/model), [UCE](https://github.com/snap-stanford/UCE)) for extracting zero-shot and fine-tuned cell embeddings. We provide a minimal example notebook [jupyter_nb/Geneformer_Skin.ipynb](jupyter_nb/Geneformer_Skin.ipynb) to extract zero-shot cell embeddings for the skin dataset using pre-trained Geneformer. To evaluate such embeddings with scIB:

```bash
cd ${HOME}/Islander/src

export DATASET="skin"
echo -e "\n\n"
echo "DATASET-${DATASET}_Geneformer"
python scBenchmarker.py \
    --obsm_keys Geneformer \
    --dataset "${DATASET}" \
    --savecsv "${DATASET}_Geneformer" \
    --save_path "${HOME}/Islander/models/_${DATASET}_/MODE-mixup-ONLY_LEAK-16_MLP-128 128";
```

**Step 5**: Benchmark with scGraph (can also be run on a customized AnnData file)

```bash
cd ${HOME}/Islander/src

python scGraph.py \
    --adata_path ${HOME}/Islander/data/lung/emb.h5ad \
    --batch_key sample \
    --label_key cell_type \
    --savename ${HOME}/lung_scGraph;
```

The **Corr-Weights** column of the output file is what we report as the scGraph score in the paper. It is a weighted rank correlation, where the weights are inversely proportional to the inter-cluster centroid distances. **Corr-PCA** is the rank correlation with equal weights, and **Rank-PCA** measures the rank differences.

|             | Rank-PCA | Corr-PCA | Corr-Weights |
| :---------- | -------: | -------: | -----------: |
| Geneformer  |    0.610 |    0.799 |        0.498 |
| Harmony     |    0.670 |    0.924 |        0.678 |
| Harmony_hvg |    0.709 |    0.941 |        0.724 |
| Islander    |    0.292 |    0.847 |        0.160 |

**Parameter Settings in scGraph**:

scGraph uses PCA centroids of each cell type within each batch to represent cluster-cluster relationships. Batches with fewer than 100 cells, or cell types with fewer than 10 cells, are excluded. PCA is calculated on the 1,000 highly variable genes, after removing 10% of the cells (5% from each extreme).

All the numerical values mentioned above are adjustable. For further details, please refer to `scGraph.py`.
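For intuition, the snippet below sketches the computation just described: trimmed-mean centroids per cell type, pairwise centroid distances, and a rank correlation weighted inversely by the reference (PCA-space) distances. It is schematic rather than exact; the function names and the precise weighting are our simplifications, and `scGraph.py` (or the standalone package described below) remains the reference implementation.

```python
# Schematic of the scGraph idea: per cell type, compare how its neighbors rank by
# centroid distance in an embedding vs. in a PCA reference, weighting close
# neighbors more heavily. Trimming/threshold details follow src/scGraph.py.
import numpy as np
from scipy.spatial.distance import cdist
from scipy.stats import rankdata, trim_mean


def centroid_distances(X: np.ndarray, labels: np.ndarray, trim_rate: float = 0.05) -> np.ndarray:
    """Pairwise distances between trimmed-mean (outlier-robust) centroids of each label."""
    cats = np.unique(labels)
    centroids = np.stack([trim_mean(X[labels == c], trim_rate, axis=0) for c in cats])
    return cdist(centroids, centroids)


def weighted_corr(x: np.ndarray, y: np.ndarray, w: np.ndarray) -> float:
    """Pearson correlation with per-pair weights w."""
    mx, my = np.average(x, weights=w), np.average(y, weights=w)
    cov = np.average((x - mx) * (y - my), weights=w)
    return cov / np.sqrt(np.average((x - mx) ** 2, weights=w) * np.average((y - my) ** 2, weights=w))


def scgraph_like_score(dist_ref: np.ndarray, dist_emb: np.ndarray) -> float:
    """Average, over anchor cell types, of the weighted rank correlation between
    the two distance profiles; weights are inversely proportional to the
    reference inter-cluster centroid distances."""
    scores = []
    for i in range(dist_ref.shape[0]):
        others = np.arange(dist_ref.shape[0]) != i
        r_ref, r_emb = rankdata(dist_ref[i, others]), rankdata(dist_emb[i, others])
        scores.append(weighted_corr(r_ref, r_emb, w=1.0 / dist_ref[i, others]))
    return float(np.mean(scores))


rng = np.random.default_rng(0)
X_pca, X_emb = rng.normal(size=(300, 50)), rng.normal(size=(300, 16))  # stand-ins for real matrices
labels = rng.integers(0, 6, size=300)
print(scgraph_like_score(centroid_distances(X_pca, labels), centroid_distances(X_emb, labels)))
```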
---

**Case study**: scGraph vs scIB on fibroblast cells from the human fetal lung

Please see [Fibroblast_Case.ipynb](jupyter_nb/Fibroblast_Case.ipynb) to reproduce the results reported in the paper.

### File Organization

```
├── LICENSE.txt
├── README.md
├── data
├── env.yml                    # for GPU environments
├── jupyter_nb
│   ├── Fibroblast_Case.ipynb
│   ├── Geneformer_Skin.ipynb
│   └── Process_Breast.ipynb
├── meta
│   ├── COVID
│   │   ├── batch2cat.json
│   │   └── cell2cat.json
│   ├── ...
├── res
│   ├── scGraph
│   └── scIB
├── scripts
│   ├── Lung_Mixup.log
│   ├── Lung_scGraph.log
│   ├── _Islander_MixUp.sh
│   ├── _Islander_SCL.sh
│   ├── _Islander_Triplet.sh
│   ├── _download_data.sh
│   ├── _scBenchmark.sh
│   ├── _scGraph.sh
│   ├── _scIB_Geneformer.sh
│   └── _scIB_Islander.sh      # scIB benchmark on cell-islands embeddings
├── src
│   ├── ArgParser.py
│   ├── Data_Handler.py
│   ├── Utils_Handler.py
│   ├── Vis_Handler.py
│   ├── __init__.py
│   ├── scBenchmarker.py
│   ├── scDataset.py
│   ├── scFinetuner.py
│   ├── scGraph.py
│   ├── scLoss.py
│   ├── scModel.py
│   └── scTrain.py
└── teaser.png
```

### Stand-alone scGraph package for evaluations

We also provide a standalone [scgraph](https://pypi.org/project/scgraph-eval/) Python package, which can be installed via:

```
# conda create -n scgraph python=3.10  # to create another conda environment if necessary
pip install scgraph-eval
```

#### Python

```python3
from scgraph import scGraph

# Initialize the graph analyzer
scgraph = scGraph(
    adata_path="path/to/your/data.h5ad",  # Path to AnnData object
    batch_key="batch",                    # Column name for batch information
    label_key="cell_type",                # Column name for cell type labels
    trim_rate=0.05,                       # Trim rate for robust mean calculation
    thres_batch=100,                      # Minimum number of cells per batch
    thres_celltype=10,                    # Minimum number of cells per cell type
    only_umap=True,                       # Only evaluate 2D embeddings (mostly UMAPs)
)

# Run the analysis; returns a pandas DataFrame
results = scgraph.main()

# Save the results
results.to_csv("embedding_evaluation_results.csv")
```

A short snippet for ranking embeddings from this results file is given after the Contributing section below.

#### Command line

```bash
scgraph --adata_path path/to/data.h5ad --batch_key batch --label_key cell_type --savename results
```

#### Notebook

We provide a notebook showing how to install scgraph on a laptop and how to reproduce the results of using scgraph to evaluate cell embeddings of the fibroblast family in the human fetal lung.

- [Data](https://drive.google.com/file/d/1a2UF4V_INGMKayCoMErZG-_kq_KuZmjA/view?usp=drive_link)

- [Notebook](jupyter_nb/scGraph-minimal.ipynb)

#### Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
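As referenced above, the saved results table is indexed by embedding key, with one column per score; see [res/scGraph/lung.csv](res/scGraph/lung.csv) for an example with columns `Rank-PCA`, `Corr-PCA`, and `Corr-Weighted`. Ranking candidate embeddings is then a one-liner. The file name below assumes the Python example above; adjust the column name if your output differs.

```python
# Rank embeddings by the weighted rank-correlation score
# (column names as in res/scGraph/lung.csv).
import pandas as pd

res = pd.read_csv("embedding_evaluation_results.csv", index_col=0)
print(res.sort_values("Corr-Weighted", ascending=False))
```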
318 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Placeholder for the `data/` repo -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: scIntegrater 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.12=h7f8727e_0 15 | - pip=23.3=py39h06a4308_0 16 | - python=3.9.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py39h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==2.0.0 26 | - adjusttext==0.8 27 | - aiohttp==3.9.0 28 | - aiosignal==1.3.1 29 | - anndata==0.10.3 30 | - annotated-types==0.6.0 31 | - annoy==1.17.3 32 | - anyio==3.7.1 33 | - appdirs==1.4.4 34 | - argon2-cffi==23.1.0 35 | - argon2-cffi-bindings==21.2.0 36 | - array-api-compat==1.4 37 | - arrow==1.3.0 38 | - asttokens==2.4.1 39 | - async-lru==2.0.4 40 | - async-timeout==4.0.3 41 | - attrs==23.1.0 42 | - babel==2.13.1 43 | - backoff==2.2.1 44 | - bbknn==1.6.0 45 | - beautifulsoup4==4.12.2 46 | - black==23.11.0 47 | - bleach==6.1.0 48 | - blessed==1.20.0 49 | - boto3==1.29.3 50 | - botocore==1.32.3 51 | - certifi==2023.11.17 52 | - cffi==1.16.0 53 | - charset-normalizer==3.3.2 54 | - chex==0.1.7 55 | - click==8.1.7 56 | - comm==0.2.0 57 | - contextlib2==21.6.0 58 | - contourpy==1.2.0 59 | - croniter==1.4.1 60 | - cycler==0.12.1 61 | - cython==3.0.5 62 | - dateutils==0.6.12 63 | - debugpy==1.8.0 64 | - decorator==5.1.1 65 | - deepdiff==6.7.1 66 | - defusedxml==0.7.1 67 | - dm-tree==0.1.8 68 | - docker-pycreds==0.4.0 69 | - docrep==0.3.2 70 | - etils==1.5.2 71 | - exceptiongroup==1.1.3 72 | - executing==2.0.1 73 | - fastapi==0.104.1 74 | - fastjsonschema==2.19.0 75 | - fbpca==1.0 76 | - filelock==3.13.1 77 | - flax==0.7.5 78 | - fonttools==4.44.3 79 | - fqdn==1.5.1 80 | - frozenlist==1.4.0 81 | - fsspec==2023.10.0 82 | - gdown==4.7.1 83 | - geosketch==1.2 84 | - get-annotations==0.1.2 85 | - gitdb==4.0.11 86 | - gitpython==3.1.40 87 | - h11==0.14.0 88 | - h5py==3.10.0 89 | - harmony-pytorch==0.1.7 90 | - idna==3.4 91 | - igraph==0.10.8 92 | - importlib-metadata==6.8.0 93 | - importlib-resources==6.1.1 94 | - inquirer==3.1.3 95 | - intervaltree==3.1.0 96 | - ipykernel==6.26.0 97 | - ipython==8.17.2 98 | - ipywidgets==8.1.1 99 | - isoduration==20.11.0 100 | - itsdangerous==2.1.2 101 | - jax==0.4.20 102 | - jaxlib==0.4.20+cuda11.cudnn86 103 | - jedi==0.19.1 104 | - jinja2==3.1.2 105 | - jmespath==1.0.1 106 | - joblib==1.3.2 107 | - json5==0.9.14 108 | - jsonpointer==2.4 109 | - jsonschema==4.20.0 110 | - jsonschema-specifications==2023.11.1 111 | - jupyter==1.0.0 112 | - jupyter-client==8.6.0 113 | - jupyter-console==6.6.3 114 | - jupyter-core==5.5.0 115 | - jupyter-events==0.9.0 116 | - jupyter-lsp==2.2.0 117 | - jupyter-server==2.10.1 118 | - jupyter-server-terminals==0.4.4 119 | - jupyterlab==4.0.8 120 | - jupyterlab-pygments==0.2.2 121 | - jupyterlab-server==2.25.1 122 | - 
jupyterlab-widgets==3.0.9 123 | - kiwisolver==1.4.5 124 | - leidenalg==0.10.1 125 | - lightning==2.0.9.post0 126 | - lightning-cloud==0.5.54 127 | - lightning-utilities==0.10.0 128 | - llvmlite==0.41.1 129 | - markdown-it-py==3.0.0 130 | - markupsafe==2.1.3 131 | - matplotlib==3.8.2 132 | - matplotlib-inline==0.1.6 133 | - mdurl==0.1.2 134 | - mistune==3.0.2 135 | - ml-collections==0.1.1 136 | - ml-dtypes==0.3.1 137 | - mnnpy==0.1.9.5 138 | - mpmath==1.3.0 139 | - msgpack==1.0.7 140 | - mudata==0.2.3 141 | - multidict==6.0.4 142 | - multipledispatch==1.0.0 143 | - muon==0.1.5 144 | - mypy-extensions==1.0.0 145 | - natsort==8.4.0 146 | - nbclient==0.9.0 147 | - nbconvert==7.11.0 148 | - nbformat==5.9.2 149 | - nest-asyncio==1.5.8 150 | - networkx==3.2.1 151 | - newick==1.0.0 152 | - notebook==7.0.6 153 | - notebook-shim==0.2.3 154 | - numba==0.58.1 155 | - numpy==1.26.2 156 | - numpyro==0.13.2 157 | - nvidia-cublas-cu11==11.11.3.6 158 | - nvidia-cublas-cu12==12.1.3.1 159 | - nvidia-cuda-cupti-cu11==11.8.87 160 | - nvidia-cuda-cupti-cu12==12.1.105 161 | - nvidia-cuda-nvcc-cu11==11.8.89 162 | - nvidia-cuda-nvcc-cu12==12.3.103 163 | - nvidia-cuda-nvrtc-cu11==11.8.89 164 | - nvidia-cuda-nvrtc-cu12==12.1.105 165 | - nvidia-cuda-runtime-cu11==11.8.89 166 | - nvidia-cuda-runtime-cu12==12.1.105 167 | - nvidia-cudnn-cu11==8.9.6.50 168 | - nvidia-cudnn-cu12==8.9.2.26 169 | - nvidia-cufft-cu11==10.9.0.58 170 | - nvidia-cufft-cu12==11.0.2.54 171 | - nvidia-curand-cu12==10.3.2.106 172 | - nvidia-cusolver-cu11==11.4.1.48 173 | - nvidia-cusolver-cu12==11.4.5.107 174 | - nvidia-cusparse-cu11==11.7.5.86 175 | - nvidia-cusparse-cu12==12.1.0.106 176 | - nvidia-nccl-cu11==2.19.3 177 | - nvidia-nccl-cu12==2.18.1 178 | - nvidia-nvjitlink-cu12==12.3.101 179 | - nvidia-nvtx-cu12==12.1.105 180 | - opt-einsum==3.3.0 181 | - optax==0.1.7 182 | - orbax-checkpoint==0.4.3 183 | - ordered-set==4.1.0 184 | - overrides==7.4.0 185 | - packaging==23.2 186 | - pandas==2.1.3 187 | - pandocfilters==1.5.0 188 | - parso==0.8.3 189 | - pathspec==0.11.2 190 | - patsy==0.5.3 191 | - pexpect==4.8.0 192 | - pillow==10.1.0 193 | - platformdirs==4.0.0 194 | - plottable==0.1.5 195 | - prometheus-client==0.18.0 196 | - prompt-toolkit==3.0.41 197 | - protobuf==4.25.1 198 | - psutil==5.9.6 199 | - ptyprocess==0.7.0 200 | - pure-eval==0.2.2 201 | - pycparser==2.21 202 | - pydantic==2.1.1 203 | - pydantic-core==2.4.0 204 | - pygments==2.16.1 205 | - pyjwt==2.8.0 206 | - pynndescent==0.5.10 207 | - pyparsing==3.1.1 208 | - pyro-api==0.1.2 209 | - pyro-ppl==1.8.6 210 | - pysocks==1.7.1 211 | - python-dateutil==2.8.2 212 | - python-editor==1.0.4 213 | - python-json-logger==2.0.7 214 | - python-multipart==0.0.6 215 | - pytorch-lightning==2.1.2 216 | - pytz==2023.3.post1 217 | - pyyaml==6.0.1 218 | - pyzmq==25.1.1 219 | - qtconsole==5.5.1 220 | - qtpy==2.4.1 221 | - readchar==4.0.5 222 | - referencing==0.31.0 223 | - requests==2.31.0 224 | - rfc3339-validator==0.1.4 225 | - rfc3986-validator==0.1.1 226 | - rich==13.7.0 227 | - rpds-py==0.13.0 228 | - s3transfer==0.7.0 229 | - scanorama==1.7.4 230 | - scanpy==1.9.6 231 | - scarches==0.5.9 232 | - scgen==2.1.1 233 | - schpl==1.0.3 234 | - scib-metrics==0.4.1 235 | - scikit-learn==1.3.2 236 | - scipy==1.11.3 237 | - scvi-tools==1.0.4 238 | - seaborn==0.12.2 239 | - send2trash==1.8.2 240 | - sentry-sdk==1.35.0 241 | - session-info==1.0.0 242 | - setproctitle==1.3.3 243 | - six==1.16.0 244 | - smmap==5.0.1 245 | - sniffio==1.3.0 246 | - sortedcontainers==2.4.0 247 | - soupsieve==2.5 248 | - 
sparse==0.14.0 249 | - stack-data==0.6.3 250 | - starlette==0.27.0 251 | - starsessions==1.3.0 252 | - statsmodels==0.14.0 253 | - stdlib-list==0.10.0 254 | - sympy==1.12 255 | - tensorstore==0.1.50 256 | - terminado==0.18.0 257 | - texttable==1.7.0 258 | - threadpoolctl==3.2.0 259 | - tinycss2==1.2.1 260 | - tomli==2.0.1 261 | - toolz==0.12.0 262 | - torch==2.1.1 263 | - torchmetrics==1.2.0 264 | - tornado==6.3.3 265 | - tqdm==4.66.1 266 | - traitlets==5.13.0 267 | - triton==2.1.0 268 | - types-python-dateutil==2.8.19.14 269 | - typing-extensions==4.8.0 270 | - tzdata==2023.3 271 | - umap-learn==0.5.5 272 | - uri-template==1.3.0 273 | - urllib3==1.26.18 274 | - uvicorn==0.24.0.post1 275 | - wandb==0.16.0 276 | - wcwidth==0.2.10 277 | - webcolors==1.13 278 | - webencodings==0.5.1 279 | - websocket-client==1.6.4 280 | - websockets==12.0 281 | - widgetsnbextension==4.0.9 282 | - xarray==2023.11.0 283 | - yarl==1.9.2 284 | - zipp==3.17.0 285 | prefix: /home/wangh256/anaconda3/envs/scIntegrater 286 | -------------------------------------------------------------------------------- /jupyter_nb/.gitignore: -------------------------------------------------------------------------------- 1 | \[Tutorial\]*.ipynb 2 | _deprecated 3 | .ipynb_checkpoints/ 4 | _temp* 5 | -------------------------------------------------------------------------------- /meta/COVID/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4, "6": 5, "7": 6, "8": 7, "9": 8, "10": 9} -------------------------------------------------------------------------------- /meta/COVID/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"CD8 TCM": 0, "CD8 Naive": 1, "CD16 Mono": 2, "CD14 Mono": 3, "CD4 Naive": 4, "CD4 TCM": 5, "NK": 6, "B naive": 7, "MAIT": 8, "B intermediate": 9, "CD4 TEM": 10, "NK_CD56bright": 11, "CD8 TEM": 12, "gdT": 13, "pDC": 14, "Treg": 15, "B memory": 16, "cDC2": 17, "HSPC": 18, "Platelet": 19, "NK Proliferating": 20, "CD4 CTL": 21, "ILC": 22, "dnT": 23, "cDC1": 24, "Plasmablast": 25, "ASDC": 26, "CD8 Proliferating": 27, "CD4 Proliferating": 28, "Eryth": 29, "Doublet": 30} -------------------------------------------------------------------------------- /meta/brain/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"H19.30.002": 0, "H19.30.001": 1, "H18.30.002": 2, "H18.30.001": 3} -------------------------------------------------------------------------------- /meta/brain/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"oligodendrocyte": 0, "astrocyte": 1, "Bergmann glial cell": 2, "oligodendrocyte precursor cell": 3, "ependymal cell": 4, "choroid plexus epithelial cell": 5, "fibroblast": 6, "pericyte": 7, "vascular associated smooth muscle cell": 8, "endothelial cell": 9, "central nervous system macrophage": 10} -------------------------------------------------------------------------------- /meta/breast/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"P01": 0, "P02": 1, "P03": 2, "P04": 3, "P05": 4, "P06": 5, "P07": 6, "P08": 7, "P09": 8, "P10": 9, "P11": 10, "P20": 11, "P22": 12, "P23": 13, "P37": 14, "P27": 15, "P42": 16, "P43": 17, "P12": 18, "P13": 19, "P14": 20, "P15": 21, "P16": 22, "P17": 23, "P18": 24, "P19": 25, "P21": 26, "P24": 27, "P25": 28, "P26": 29, "P29": 30, "P30": 31, "P31": 32, "P32": 33, "P33": 34, "P34": 
35, "P35": 36, "P36": 37, "P44": 38, "P28": 39, "P56": 40, "P91": 41, "P95": 42, "P45": 43, "P50": 44, "P51": 45, "P52": 46, "P53": 47, "P54": 48, "P55": 49, "P57": 50, "P38": 51, "P39": 52, "P40": 53, "P58": 54, "P59": 55, "P60": 56, "P63": 57, "P66": 58, "P67": 59, "P70": 60, "P71": 61, "P73": 62, "P75": 63, "P76": 64, "P77": 65, "P79": 66, "P82": 67, "P83": 68, "P87": 69, "P88": 70, "P90": 71, "P94": 72, "P97": 73, "P98": 74, "P99": 75, "P101": 76, "P103": 77, "P105": 78, "P106": 79, "P107": 80, "P108": 81, "P110": 82, "P111": 83, "P112": 84, "P113": 85, "P114": 86, "P115": 87, "P116": 88, "P117": 89, "P118": 90, "P119": 91, "P120": 92, "P121": 93, "P122": 94, "P123": 95, "P124": 96, "P125": 97, "P126": 98, "P46": 99, "P47": 100, "P48": 101, "P62": 102, "P64": 103, "P65": 104, "P68": 105, "P69": 106, "P72": 107, "P74": 108, "P78": 109, "P80": 110, "P81": 111, "P84": 112, "P85": 113, "P86": 114, "P89": 115, "P92": 116, "P93": 117, "P96": 118, "P100": 119, "P102": 120, "P104": 121, "P49": 122, "P41": 123, "P61": 124, "P109": 125} -------------------------------------------------------------------------------- /meta/breast/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"luminal epithelial cell of mammary gland": 0, "mammary gland epithelial cell": 1, "basal cell": 2, "naive B cell": 3, "class switched memory B cell": 4, "IgA plasma cell": 5, "IgG plasma cell": 6, "unswitched memory B cell": 7, "mature NK T cell": 8, "effector memory CD8-positive, alpha-beta T cell": 9, "CD4-positive, alpha-beta T cell": 10, "gamma-delta T cell": 11, "CD4-positive helper T cell": 12, "regulatory T cell": 13, "CD8-positive, alpha-beta memory T cell": 14, "natural killer cell": 15, "lymphocyte": 16, "activated CD4-positive, alpha-beta T cell": 17, "effector memory CD4-positive, alpha-beta T cell": 18, "activated CD8-positive, alpha-beta T cell": 19, "T cell": 20, "macrophage": 21, "classical monocyte": 22, "myeloid dendritic cell": 23, "non-classical monocyte": 24, "conventional dendritic cell": 25, "alternatively activated macrophage": 26, "inflammatory macrophage": 27, "myeloid cell": 28, "plasmacytoid dendritic cell": 29, "neutrophil": 30, "fibroblast": 31, "endothelial cell of lymphatic vessel": 32, "capillary endothelial cell": 33, "vein endothelial cell": 34, "endothelial cell of artery": 35, "vascular associated smooth muscle cell": 36, "pericyte": 37, "mast cell": 38} -------------------------------------------------------------------------------- /meta/eye/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"LVG1_retina_OD": 0, "LVG1_retina_OS": 1, "LGS1_retina_OD": 2, "LGS1_retina_OS": 3, "LGS2_retina_OD": 4, "LGS2_retina_OS": 5, "LGS3_retina_OD": 6, "LGS3_retina_OS": 7} -------------------------------------------------------------------------------- /meta/eye/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"retinal ganglion cell": 0, "retinal cone cell": 1, "amacrine cell": 2, "astrocyte": 3, "microglial cell": 4, "retinal rod cell": 5, "retina horizontal cell": 6, "Mueller cell": 7, "OFF-bipolar cell": 8, "rod bipolar cell": 9, "ON-bipolar cell": 10} -------------------------------------------------------------------------------- /meta/gut_fetal/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"2029": 0, "2026": 1, "2043": 2, "2046": 3, "2049": 4, "2121": 5, "2119": 6, "2133": 7, "2134": 8} 
-------------------------------------------------------------------------------- /meta/gut_fetal/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"mesodermal cell": 0, "fibroblast": 1, "hematopoietic cell": 2, "enteric smooth muscle cell": 3, "neural crest cell": 4, "enteric neuron": 5, "progenitor cell": 6, "mesothelial cell": 7, "pericyte": 8, "vein endothelial cell": 9, "endothelial cell of artery": 10, "endothelial cell of lymphatic vessel": 11, "myofibroblast cell": 12, "erythroblast": 13, "enterocyte": 14, "colon epithelial cell": 15, "interstitial cell of Cajal": 16, "enteroendocrine cell": 17, "intestine goblet cell": 18, "stem cell": 19, "epithelial cell": 20} -------------------------------------------------------------------------------- /meta/heart/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"H5": 0, "H6": 1, "H3": 2, "H2": 3, "H7": 4, "H4": 5, "D1": 6, "D2": 7, "D3": 8, "D4": 9, "D5": 10, "D6": 11, "D7": 12, "D11": 13} -------------------------------------------------------------------------------- /meta/heart/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"regular ventricular cardiac myocyte": 0, "pericyte": 1, "native cell": 2, "endothelial cell": 3, "vein endothelial cell": 4, "fibroblast": 5, "capillary endothelial cell": 6, "epicardial adipocyte": 7, "endothelial cell of artery": 8, "smooth muscle cell": 9, "neural cell": 10, "macrophage": 11, "CD8-positive, alpha-beta cytotoxic T cell": 12, "endothelial cell of lymphatic vessel": 13, "activated CD8-positive, alpha-beta T cell": 14, "mast cell": 15, "natural killer cell": 16, "CD4-positive, alpha-beta cytotoxic T cell": 17, "mature NK T cell": 18, "dendritic cell": 19, "CD14-positive, CD16-positive monocyte": 20, "B cell": 21, "regular atrial cardiac myocyte": 22, "monocyte": 23, "CD14-positive monocyte": 24, "activated CD4-positive, alpha-beta T cell": 25, "mesothelial cell": 26} -------------------------------------------------------------------------------- /meta/lung/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"SC22": 0, "distal 2": 1, "SC27": 2, "medial 2": 3, "T164": 4, "proximal 3": 5, "distal 1a": 6, "F02617": 7, "GRO-03_biopsy": 8, "GRO-10_biopsy": 9, "SC18": 10, "VUHD68": 11, "SC142": 12, "F02526": 13, "T165": 14, "VUHD67": 15, "SC24": 16, "T121": 17, "D363_Brus_Dis1": 18, "GRO-08_biopsy": 19, "7185212": 20, "SC14": 21, "F01394": 22, "7239217": 23, "SC141": 24, "F01874": 25, "SC156": 26, "D372_Biop_Int2": 27, "distal 3": 28, "368C_12h": 29, "GRO-06_biopsy": 30, "GRO-01_nasal_brush": 31, "GRO-02_biopsy": 32, "T154": 33, "SC84": 34, "SC85": 35, "GRO-10_nasal_brush": 36, "D326_Biop_Pro1": 37, "GRO-09_biopsy": 38, "F02524": 39, "390C_12h": 40, "F01641": 41, "SC182": 42, "VUHD70": 43, "T90": 44, "F01365": 45, "T166": 46, "T89": 47, "GRO-09_nasal_brush": 48, "D353_Brus_Nas1": 49, "D353_Brus_Dis1": 50, "D337_Brus_Dis1": 51, "D354_Biop_Pro1": 52, "D344_Biop_Nas1": 53, "7135920": 54, "356C_72h": 55, "D354_Brus_Dis1": 56, "SC20": 57, "F01367": 58, "SC45": 59, "F02528": 60, "SC173": 61, "D339_Biop_Nas1": 62, "SC56": 63, "SC29": 64, "D367_Biop_Pro1": 65, "F02522": 66, "SC07": 67, "356C_0h": 68, "T137": 69, "F02992": 70, "D363_Brus_Nas1": 71, "SC89": 72, "F02509": 73, "SC184": 74, "390C_0h": 75, "SC86": 76, "D372_Biop_Pro1": 77, "D353_Biop_Pro1": 78, "D372_Brus_Nas1": 79, "SC174_SC173": 80, "356C_12h": 81, 
"SC174_SC172": 82, "SC143": 83, "D344_Brus_Dis1": 84, "F02611": 85, "F01607": 86, "D367_Brus_Nas1": 87, "T120": 88, "SC144": 89, "F01639": 90, "SC172": 91, "D339_Biop_Int1": 92, "D322_Biop_Pro1": 93, "F02607": 94, "SC10": 95, "F01851": 96, "VUHD101": 97, "D326_Biop_Int1": 98, "T167": 99, "F03045": 100, "D354_Biop_Int2": 101, "T85": 102, "356C_24h": 103, "GRO-03_nasal_brush": 104, "F01157": 105, "F01403": 106, "SC155": 107, "D353_Biop_Int2": 108, "VUHD66": 109, "368C_72h": 110, "GRO-04_nasal_brush": 111, "SC87": 112, "T101": 113, "F02730": 114, "SC88": 115, "T126": 116, "D363_Biop_Pro1": 117, "GRO-04_biopsy": 118, "D367_Brus_Dis1": 119, "SC183": 120, "GRO-11_biopsy": 121, "F01513": 122, "7239213": 123, "T153": 124, "D372_Brus_Dis1": 125, "7185213": 126, "GRO-07_biopsy": 127, "390C_72h": 128, "D339_Biop_Pro1": 129, "7135919": 130, "368C_24h": 131, "SC59": 132, "D367_Biop_Int1": 133, "390C_24h": 134, "368C_0h": 135, "D344_Biop_Pro1": 136, "SC31D": 137, "SC181": 138, "SC185": 139, "F02092": 140, "F02609": 141, "F01366": 142, "D322_Biop_Nas1": 143, "D372_Biop_Int1": 144, "D326_Brus_Dis1": 145, "7239219": 146, "7239220": 147, "D344_Biop_Int1": 148, "D363_Biop_Int2": 149, "F01692": 150, "F02613": 151, "F01494": 152, "F01508": 153, "F00409": 154, "F01483": 155, "7239218": 156, "F01853": 157, "7119452": 158, "7119453": 159, "SC31": 160, "GRO-01_biopsy": 161, "F01495": 162, "F01511": 163, "F01506": 164, "F01869": 165} -------------------------------------------------------------------------------- /meta/lung/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"alveolar macrophage": 0, "natural killer cell": 1, "type II pneumocyte": 2, "respiratory basal cell": 3, "vein endothelial cell": 4, "CD8-positive, alpha-beta T cell": 5, "pulmonary artery endothelial cell": 6, "bronchus fibroblast of lung": 7, "CD4-positive, alpha-beta T cell": 8, "type I pneumocyte": 9, "ciliated columnar cell of tracheobronchial tree": 10, "plasma cell": 11, "respiratory hillock cell": 12, "nasal mucosa goblet cell": 13, "club cell": 14, "smooth muscle cell": 15, "classical monocyte": 16, "elicited macrophage": 17, "tracheobronchial serous cell": 18, "non-classical monocyte": 19, "capillary endothelial cell": 20, "alveolar type 2 fibroblast cell": 21, "endothelial cell of lymphatic vessel": 22, "epithelial cell of lower respiratory tract": 23, "tracheobronchial smooth muscle cell": 24, "alveolar type 1 fibroblast cell": 25, "multi-ciliated epithelial cell": 26, "bronchial goblet cell": 27, "lung neuroendocrine cell": 28, "CD1c-positive myeloid dendritic cell": 29, "conventional dendritic cell": 30, "myofibroblast cell": 31, "B cell": 32, "mast cell": 33, "lung macrophage": 34, "mucus secreting cell": 35, "tracheobronchial goblet cell": 36, "lung pericyte": 37, "epithelial cell of alveolus of lung": 38, "acinar cell": 39, "mesothelial cell": 40, "serous secreting cell": 41, "ionocyte": 42, "stromal cell": 43, "brush cell of trachebronchial tree": 44, "plasmacytoid dendritic cell": 45, "T cell": 46, "fibroblast": 47, "hematopoietic stem cell": 48, "dendritic cell": 49} -------------------------------------------------------------------------------- /meta/lung_fetal_donor/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"5891STDY8062349": 0, "5891STDY8062350": 1, "5891STDY8062351": 2, "5891STDY8062352": 3, "5891STDY8062353": 4, "5891STDY8062354": 5, "5891STDY8062355": 6, "5891STDY8062356": 7, "WSSS8012016": 8, "WSSS8011222": 
9, "WSSS_F_LNG8713176": 10, "WSSS_F_LNG8713177": 11, "WSSS_F_LNG8713178": 12, "WSSS_F_LNG8713179": 13, "WSSS_F_LNG8713180": 14, "WSSS_F_LNG8713181": 15, "WSSS_F_LNG8713184": 16, "WSSS_F_LNG8713185": 17, "WSSS_F_LNG8713186": 18, "WSSS_F_LNG8713187": 19, "WSSS_F_LNG8713188": 20, "WSSS_F_LNG8713189": 21, "WSSS_F_LNG8713190": 22, "WSSS_F_LNG8713191": 23, "5891STDY9030806": 24, "5891STDY9030807": 25, "5891STDY9030808": 26, "5891STDY9030809": 27, "5891STDY9030810": 28} -------------------------------------------------------------------------------- /meta/lung_fetal_donor/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"Alveolar fibro": 0, "Interm fibro": 1, "Mid fibro": 2, "Pericyte": 3, "Adventitial fibro": 4, "Vascular SMC 1": 5, "Myofibro 3": 6, "Vascular SMC 2": 7, "Airway fibro": 8, "Late mesothelial": 9, "Myofibro 2": 10, "Myofibro 1": 11, "Early fibro": 12, "Mid mesothelial": 13, "Late airway SMC": 14, "MYL4+ SMC": 15, "Early mesothelial": 16, "Resting chondrocyte": 17, "Interm chondrocyte": 18, "Mesenchymal 3": 19, "Mesenchymal 2": 20, "Mesenchymal 1": 21, "ACTC+ SMC": 22, "Early airway SMC 1": 23, "ASPN+ chondrocyte": 24, "Early airway SMC 2": 25, "Late tip": 26, "Club": 27, "Late airway progenitor": 28, "Squamous": 29, "AT1": 30, "AT2": 31, "Proximal secretory 2": 32, "Proximal secretory 3": 33, "Proximal secretory 1": 34, "GHRL+ neuroendocrine": 35, "Pulmonary neuroendocrine": 36, "SMG": 37, "Late stalk": 38, "Early tip": 39, "Early airway progenitor": 40, "Early stalk": 41, "Interm neuroendocrine": 42, "Pulmonary NE precursor": 43, "Mid airway progenitor": 44, "Mid tip": 45, "Proximal basal": 46, "Proximal secretory progenitors": 47, "Late basal": 48, "Mid basal": 49, "GHRL+ NE precursor": 50, "Mid stalk": 51, "SMG basal": 52, "MUC5AC+ ASCL1+ progenitor": 53, "SPP1+ M\u03a6": 54, "Cycling DC": 55, "S100A12-lo cla. mono.": 56, "Non-cla. mono.": 57, "DC2": 58, "S100A12-hi cla. 
mono.": 59, "Promonocyte-like": 60, "CX3CR1+ M\u03a6": 61, "DC1": 62, "HSC": 63, "DC3": 64, "Neutrophil": 65, "CXCL9+ M\u03a6": 66, "Megakaryocyte": 67, "Promyelocyte-like": 68, "pDC": 69, "Eosinophil": 70, "aDC 2": 71, "pre-pDC/DC5": 72, "Basophil": 73, "MEP": 74, "Platelet": 75, "aDC 1": 76, "HSC/ELP": 77, "GMP": 78, "Myelocyte-like": 79, "APOE+ M\u03a62": 80, "CMP": 81, "APOE+ M\u03a61": 82, "Late cap": 83, "Intermediate lymphatic endo": 84, "Venous endo": 85, "GRIA2+ arterial endo": 86, "Definitive reticulocyte": 87, "Lymphatic endo": 88, "Mid cap": 89, "Definitive erythrocyte": 90, "Cycling definitive erythroblast": 91, "Arterial endo": 92, "SCG3+ lymphatic endothelial": 93, "Aerocyte": 94, "OMD+ endo": 95, "Definitive erythroblast": 96, "Primitive erythrocyte": 97, "Early cap": 98, "HMOX1+ primitive erythroblast": 99, "CD4 T": 100, "CD16+ NK": 101, "CD56bright NK": 102, "Intermediate NK": 103, "Th17": 104, "Treg": 105, "CD8 T": 106, "Cycling T": 107, "NKT2": 108, "NKT1": 109, "ILC3": 110, "Cycling NK": 111, "ILC2": 112, "Activated NK": 113, "ILCP": 114, "T\u03b1\u03b2_Entry": 115, "\u03ba small pre-B": 116, "Large pre-B": 117, "CD5- Mature B": 118, "Immature B": 119, "Pro-B": 120, "\u03bb small pre-B": 121, "Late pro-B": 122, "Late pre-B": 123, "CD5+ CCL22- mature B": 124, "CD5+ CCL22+ mature B": 125, "Pro-B/Pre-B transition": 126, "Ciliated": 127, "MUC16+ ciliated": 128, "Deuterosomal": 129, "COL20A1+ Schwann": 130, "Late Schwann": 131, "Early Schwann": 132, "PCP4+ neuron": 133, "MFNG+ DBH+ neuron": 134, "Schwann precursor": 135, "TM4SF4+ CHODL+ neuron": 136, "SST+ neuron": 137, "KCNIP4+ neuron": 138, "TM4SF4+ PENK+ neuron": 139, "FGFBP2+ Neural progenitor": 140, "Mid Schwann": 141, "Proliferating Schwann": 142, "Mast": 143} -------------------------------------------------------------------------------- /meta/lung_fetal_organoid/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"WSSS_F_LNG10282020": 0, "WSSS_F_LNG10282021": 1, "WSSS_F_LNG10282022": 2, "WSSS_F_LNG10282023": 3, "WSSS_F_LNG10282024": 4, "WSSS_F_LNG10282025": 5, "WSSS_F_LNG10282026": 6, "WSSS_F_LNG10282027": 7, "5891STDY8062349": 8, "5891STDY8062350": 9, "5891STDY8062351": 10, "5891STDY8062352": 11, "5891STDY8062353": 12, "5891STDY8062354": 13, "5891STDY8062355": 14, "5891STDY8062356": 15, "WSSS8012016": 16, "WSSS8011222": 17, "WSSS_F_LNG8713176": 18, "WSSS_F_LNG8713177": 19, "WSSS_F_LNG8713178": 20, "WSSS_F_LNG8713179": 21, "WSSS_F_LNG8713180": 22, "WSSS_F_LNG8713181": 23, "WSSS_F_LNG8713184": 24, "WSSS_F_LNG8713185": 25, "WSSS_F_LNG8713186": 26, "WSSS_F_LNG8713187": 27, "WSSS_F_LNG8713188": 28, "WSSS_F_LNG8713189": 29, "WSSS_F_LNG8713190": 30, "WSSS_F_LNG8713191": 31, "5891STDY9030806": 32, "5891STDY9030807": 33, "5891STDY9030808": 34, "5891STDY9030809": 35, "5891STDY9030810": 36} -------------------------------------------------------------------------------- /meta/lung_fetal_organoid/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"Mid stalk": 0, "Mid tip": 1, "Late stalk": 2, "AT2": 3, "Late tip": 4, "Mid airway progenitor": 5, "Pulmonary NE precursor": 6, "GHRL+ NE precursor": 7, "Mid basal": 8, "Late basal": 9, "Pulmonary neuroendocrine": 10, "MUC5AC+ ASCL1+ progenitor": 11, "Early tip": 12, "Late airway progenitor": 13, "Early stalk": 14, "Early airway progenitor": 15, "Club": 16, "Squamous": 17, "AT1": 18, "Secretory 2": 19, "Secretory 3": 20, "Secretory 1": 21, "GHRL+ neuroendocrine": 22, "SMG": 23, 
"NEUROD1+ pulmonary neuroendocrine": 24, "Proximal basal": 25, "Secretory progenitors": 26, "SMG basal": 27} -------------------------------------------------------------------------------- /meta/pancreas/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"celseq": 0, "celseq2": 1, "fluidigmc1": 2, "smartseq2": 3, "inDrop1": 4, "inDrop2": 5, "inDrop3": 6, "inDrop4": 7, "smarter": 8} -------------------------------------------------------------------------------- /meta/pancreas/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"gamma": 0, "acinar": 1, "alpha": 2, "delta": 3, "beta": 4, "ductal": 5, "endothelial": 6, "activated_stellate": 7, "schwann": 8, "mast": 9, "macrophage": 10, "epsilon": 11, "quiescent_stellate": 12, "t_cell": 13} -------------------------------------------------------------------------------- /meta/skin/batch2cat.json: -------------------------------------------------------------------------------- 1 | {"S1": 0, "S2": 1, "S3": 2, "S4": 3, "S5": 4} -------------------------------------------------------------------------------- /meta/skin/cell2cat.json: -------------------------------------------------------------------------------- 1 | {"T cell": 0, "keratinocyte": 1, "macrophage": 2, "stem cell of epidermis": 3, "reticular cell": 4, "pericyte": 5, "inflammatory cell": 6, "secretory cell": 7, "mesenchymal cell": 8, "endothelial cell of vascular tree": 9, "melanocyte": 10, "endothelial cell of lymphatic vessel": 11, "erythrocyte": 12} -------------------------------------------------------------------------------- /res/scGraph/lung.csv: -------------------------------------------------------------------------------- 1 | ,Rank-PCA,Corr-PCA,Corr-Weighted 2 | BBKNN,0.5565723094819637,0.666521886386873,0.6898861006657622 3 | Harmony,0.6325068151005349,0.7455266231517271,0.7262389160236772 4 | Provided_x_scanvi_emb,0.3620533119599134,0.6403430764007565,0.5992361098936274 5 | Provided_x_umap,0.5427853152454126,0.6391963086732239,0.6652168796261054 6 | Scanorama,0.6556284391220193,0.7809783152141264,0.74945184218577 7 | X_pca,0.7260006008824645,0.8336719096239535,0.8054892646539925 8 | X_scANVI,0.4685784403950564,0.69158355079398,0.6645481977359752 9 | X_scVI,0.3781593976996603,0.6047324059110402,0.559186406004497 10 | X_tsne,0.3751132221169736,0.49510320066552665,0.5206096243986322 11 | X_umap,0.47479628251783396,0.5838155729775487,0.5932947438447949 12 | scGen,0.6303520398646303,0.7731980517172379,0.7201274464746682 13 | scPoli,0.6810088727142352,0.8149359497866032,0.8218137090027244 14 | -------------------------------------------------------------------------------- /scripts/Lung_Mixup.log: -------------------------------------------------------------------------------- 1 | DATASET-lung_MODE-mixup-ONLY_LEAK-16_MLP-128 128 2 | wandb: Currently logged in as: hanchen. Use `wandb login --relogin` to force relogin 3 | wandb: wandb version 0.18.1 is available! To upgrade, please run: 4 | wandb: $ pip install wandb --upgrade 5 | wandb: Tracking run with wandb version 0.17.0 6 | wandb: Run data is saved locally in /gpfs/scratchfs01/site/u/wangh256/Islander/src/wandb/run-20240923_122103-3bynd09h 7 | wandb: Run `wandb offline` to turn off syncing. 
wandb: Syncing run MODE-mixup-ONLY_LEAK-16_MLP-128 128
wandb: ⭐️ View project at https://genentech.wandb.io/hanchen/_lung_
wandb: 🚀 View run at https://genentech.wandb.io/hanchen/_lung_/runs/3bynd09h



=============================================================================
DATASET: lung
BATCH: sample, LABEL: cell_type
Load 28024 Genes & 542643 Cells.
=============================================================================


...Loading from Pre Saved, Training Split...
...Loading from Pre Saved, Testing Split...
# Genes: 28024
# Cells: 488192 (Training), 54528 (Testing)

============================ Model architecture ===========================
Leak dim: 16
Dropout: True
MLP size: [128, 16, 128]
Batchnorm: True , momentum: 0.1 , eps: 1e-05
=============================================================================
LR: 0.000976 Epoch: 0, Test Loss: 0.17, Train Loss: 1.31, Test Acc: 94.93, Time: 0.00 Mins
LR: 0.000905 Epoch: 1, Test Loss: 0.17, Train Loss: 1.04, Test Acc: 94.95, Time: 0.00 Mins
LR: 0.000794 Epoch: 2, Test Loss: 0.17, Train Loss: 0.99, Test Acc: 94.91, Time: 1.00 Mins
LR: 0.000655 Epoch: 3, Test Loss: 0.17, Train Loss: 0.95, Test Acc: 95.02, Time: 1.00 Mins
LR: 0.000500 Epoch: 4, Test Loss: 0.17, Train Loss: 0.92, Test Acc: 95.05, Time: 2.00 Mins
LR: 0.000345 Epoch: 5, Test Loss: 0.18, Train Loss: 0.89, Test Acc: 95.05, Time: 2.00 Mins
LR: 0.000206 Epoch: 6, Test Loss: 0.18, Train Loss: 0.89, Test Acc: 95.04, Time: 3.00 Mins
LR: 0.000095 Epoch: 7, Test Loss: 0.18, Train Loss: 0.89, Test Acc: 95.08, Time: 3.00 Mins
LR: 0.000024 Epoch: 8, Test Loss: 0.18, Train Loss: 0.89, Test Acc: 95.00, Time: 4.00 Mins
LR: 0.000000 Epoch: 9, Test Loss: 0.18, Train Loss: 0.88, Test Acc: 95.14, Time: 4.00 Mins
wandb: - 0.046 MB of 0.046 MB uploaded wandb: \ 0.046 MB of 0.067 MB uploaded wandb:
wandb: Run history:
wandb: LR █▇▇▆▅▃▂▂▁▁
wandb: Test Acc ▂▂▁▄▅▅▅▆▄█
wandb: Test Loss ▁▂▆▂▂▇▆██▇
wandb: Test Rec Loss ▂█▃▅▃▄▁▃▄▁
wandb: Train Loss █▄▃▂▂▁▁▁▁▁
wandb: Train Rec Loss ▁█▆▅▄▁▂▂▁▂
wandb:
wandb: Run summary:
wandb: LR 0.0
wandb: Test Acc 95.13747
wandb: Test Loss 0.17661
wandb: Test Rec Loss 0.20821
wandb: Train Loss 0.87541
wandb: Train Rec Loss 0.20818
wandb:
wandb: 🚀 View run MODE-mixup-ONLY_LEAK-16_MLP-128 128 at: https://genentech.wandb.io/hanchen/_lung_/runs/3bynd09h
wandb: ⭐️ View project at: https://genentech.wandb.io/hanchen/_lung_
wandb: Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
wandb: Find logs at: ./wandb/run-20240923_122103-3bynd09h/logs
--------------------------------------------------------------------------------
/scripts/Lung_scGraph.log:
--------------------------------------------------------------------------------
_lung_
Processing batches, calculate centroids and pairwise distances
  0%|          | 0/166 [00:00<?, ?it/s]
--------------------------------------------------------------------------------
/src/Utils_Handler.py:
--------------------------------------------------------------------------------
    non_negative = np.all(subset >= 0)
    integer_values = np.all(subset.astype(int) == subset)
    assert non_negative and integer_values

    # Quality Control
    sc.pp.filter_cells(adata, min_counts=min_counts)
    sc.pp.filter_cells(adata, min_genes=min_genes)
    sc.pp.filter_genes(adata, min_cells=min_cells)
    print("=" * 77)
    print(rf"{adata.n_vars} genes x {adata.n_obs} cells after quality control.")
    print("=" * 77)

    # Pre-Processing
    sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
    sc.pp.log1p(adata)
    return adata


def umaps_rawcounts(adata, n_neighbors=10, n_pcs=50):
    _list = list(adata.obsm.keys())
    for _obsm in _list:
        if "Provided_" in _obsm:
            continue
        adata.obsm["Provided_" + _obsm.lower()] = adata.obsm[_obsm]
        del adata.obsm[_obsm]

    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs)
    sc.tl.umap(adata)
    return adata


# import scipy.sparse
def compare_sparse_matrices(a, b):
    """
    Compares two sparse matrices for equality.
    Parameters:
        a (csr_matrix or csc_matrix): The first sparse matrix.
        b (csr_matrix or csc_matrix): The second sparse matrix.
    Returns:
        bool: True if the matrices are equal, False otherwise.
    """
    # Ensure the matrices have the same format for comparison
    a_csr = a.tocsr()
    b_csr = b.tocsr()
    # Compare data, indices, and indptr attributes
    data_equal = np.array_equal(a_csr.data, b_csr.data)
    indices_equal = np.array_equal(a_csr.indices, b_csr.indices)
    indptr_equal = np.array_equal(a_csr.indptr, b_csr.indptr)
    # Optionally, compare shapes
    shape_equal = a_csr.shape == b_csr.shape
    return data_equal and indices_equal and indptr_equal and shape_equal


def check_unique(adata, batch, metas):
    for meta in metas:
        print(meta, "=" * 77)
        for batch_id in tqdm(adata.obs[batch].unique()):
            adata_batch = adata[adata.obs[batch] == batch_id]
            if adata_batch.obs[meta].unique().__len__() > 1:
                print(batch_id, adata_batch.obs[meta].unique())
                break
        print("\n")


def benchmark_obsm(adata, batch_key, label_key, obsm_keys):
    biocons = BioConservation(nmi_ari_cluster_labels_leiden=True)
    batcorr = BatchCorrection()

    benchmarker = Benchmarker(
        adata,
        batch_key=batch_key,
        label_key=label_key,
        embedding_obsm_keys=obsm_keys,
        bio_conservation_metrics=biocons,
        batch_correction_metrics=batcorr,
        pre_integrated_embedding_obsm_key="X_pca",
        n_jobs=-1,
    )
    benchmarker.prepare(neighbor_computer=None)
    benchmarker.benchmark()

    return benchmarker.get_results(min_max_scale=False)


def calculate_centroids(X, labels):
    centroids = dict()
    for label in labels.unique():
        centroids[label] = np.mean(X[labels == label], axis=0)
    return centroids


def calculate_trimmed_means(X, labels, trim_proportion=0.2, ignore_=[]):
    centroids = dict()
    if isinstance(X, csr_matrix):
        X = X.toarray()
    for label in labels.unique():
        if label in ignore_:
            continue
        centroids[label] = trim_mean(X[labels == label], proportiontocut=trim_proportion, axis=0)
    return centroids


def compute_classwise_distances(centroids):
    centroid_vectors = np.array([centroids[key] for key in sorted(centroids.keys())])
    distances = cdist(centroid_vectors, centroid_vectors, "euclidean")
    return pd.DataFrame(distances, columns=sorted(centroids.keys()), index=sorted(centroids.keys()))
--------------------------------------------------------------------------------
/src/Vis_Handler.py:
--------------------------------------------------------------------------------
import Data_Handler as dh
import umap, scanpy as sc, numpy as np, pandas as pd, matplotlib.pyplot as plt
"legend_fontoutline": 2} 6 | 7 | 8 | def umap_inline_annotations(adata, obsm="X_umap", obs="cell_type", title=""): 9 | umap_coords = adata.obsm[obsm] 10 | 11 | df = pd.DataFrame(umap_coords, columns=["x", "y"], index=adata.obs.index) 12 | df["cluster"] = adata.obs[obs] 13 | centroids = df.groupby("cluster").mean() 14 | 15 | fig, ax = plt.subplots(figsize=(16, 16), dpi=300) 16 | with plt.rc_context({"figure.figsize": (16, 16), "figure.dpi": (300)}): 17 | sc.pl.umap( 18 | adata, 19 | color=obs, 20 | title=title, 21 | show=False, 22 | legend_loc=None, 23 | frameon=False, 24 | ax=ax, 25 | ) 26 | 27 | for cluster, centroid in centroids.iterrows(): 28 | plt.text(centroid["x"], centroid["y"], str(cluster), fontsize=14, ha="center") 29 | return fig 30 | 31 | 32 | def _dist_sweep(adata, obsm="GeneCBM", dist=[0.2, 0.3, 0.5], obs="cell_type"): 33 | print(dist) 34 | print("\n\n\n") 35 | for _dist in dist: 36 | reducer = umap.UMAP(min_dist=_dist) 37 | embedding = reducer.fit_transform(adata.obsm[obsm]) 38 | adata.obsm["CACHE_%s" % _dist] = embedding 39 | 40 | fig, ax = plt.subplots(figsize=(16, 16)) 41 | sc.pl.embedding( 42 | adata, 43 | basis="CACHE_%s" % _dist, 44 | color=[obs], 45 | legend_loc="on data", 46 | frameon=False, 47 | ncols=1, 48 | size=77, 49 | ax=ax, 50 | ) 51 | 52 | # The legend() function of the axes object is used to set the fontsize. 53 | ax.legend(fontsize=5) 54 | 55 | return fig 56 | 57 | 58 | def _obsm_sweep(dataset, dist=0.5, hvg=False, obsms=None): 59 | _suffix = "_hvg" if hvg else "" 60 | adata = sc.read_h5ad(dh.DATA_EMB_[dataset + _suffix]) 61 | obsms = list(adata.obsm.keys()).copy() if obsms is None else obsms 62 | 63 | for _obsm in obsms: 64 | print(_obsm) 65 | _, _dim = adata.obsm[_obsm].shape 66 | if _dim != 2: 67 | reducer = umap.UMAP(min_dist=dist) 68 | embedding = reducer.fit_transform(adata.obsm[_obsm]) 69 | adata.obsm["%s_UMAP" % _obsm] = embedding 70 | _plot = "%s_UMAP" % _obsm 71 | else: 72 | _plot = _obsm 73 | sc.pl.embedding( 74 | adata, 75 | basis=_plot, 76 | color=[dh.META_[dataset]["celltype"]], 77 | save="_%s%s.pdf" % (dataset, _suffix), 78 | **cfg, 79 | ) 80 | 81 | 82 | def hex_to_rgb(hex_color): 83 | """Convert a hex color string to RGB tuple.""" 84 | hex_color = hex_color.lstrip("#") 85 | return tuple(int(hex_color[i : i + 2], 16) / 255 for i in (0, 2, 4)) 86 | 87 | 88 | def rgb_to_hex(rgb_color): 89 | """Convert an RGB tuple to a hex color string.""" 90 | return "#" + "".join([f"{int(round(c * 255)):02x}" for c in rgb_color]) 91 | 92 | 93 | def interpolate_hex_colors(start_hex, end_hex, n): 94 | """Generate `n` evenly spaced colors between `start_hex` and `end_hex`.""" 95 | 96 | # Convert start and end colors to RGB 97 | start_rgb = hex_to_rgb(start_hex) 98 | end_rgb = hex_to_rgb(end_hex) 99 | 100 | # Generate the intermediate RGB values 101 | rgb_values = [np.linspace(start, end, n) for start, end in zip(start_rgb, end_rgb)] 102 | 103 | # Convert each RGB set to hex and return the list 104 | return [rgb_to_hex(rgb) for rgb in zip(*rgb_values)] 105 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Genentech/Islander/4fbf0d9336d4a3bb4667fec9aeb3bff551321622/src/__init__.py -------------------------------------------------------------------------------- /src/scBenchmarker.py: -------------------------------------------------------------------------------- 1 | # 
https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html 2 | # https://scib-metrics.readthedocs.io/en/stable/notebooks/lung_example.html 3 | 4 | import Data_Handler as dh 5 | import os, scvi, json, torch, argparse, numpy as np, scanpy as sc, pandas as pd 6 | from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection 7 | from scDataset import scDataset, collate_fn, set_seed 8 | from os.path import join, exists, getmtime, dirname 9 | from ArgParser import Parser_Benchmarker 10 | from torch.utils.data import DataLoader 11 | from scModel import Model_ZOO 12 | 13 | # NOTE: We downgrade the version of JAXLAB and JAX 14 | # https://github.com/google/jax/issues/15268#issuecomment-1487625083 15 | 16 | 17 | set_seed(dh.SEED) 18 | os.environ["CUDA_VISIBLE_DEVICES"] = dh.CUDA_DEVICE 19 | # torch.cuda.is_available() -> False 20 | sc.settings.verbosity = 3 21 | sc.logging.print_header() 22 | 23 | GPU_4_NEIGHBORS = False 24 | 25 | print("\n\n") 26 | print("=" * 44) 27 | print("use GPU for neighbors calculation: ", GPU_4_NEIGHBORS) 28 | print("GPU is available: ", torch.cuda.is_available()) 29 | print("=" * 44) 30 | print("\n") 31 | 32 | 33 | def faiss_hnsw_nn(X: np.ndarray, k: int): 34 | # """GPU HNSW nearest neighbor search using faiss. 35 | 36 | # See https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md 37 | # for index param details. 38 | # """ 39 | # X = np.ascontiguousarray(X, dtype=np.float32) 40 | # res = faiss.StandardGpuResources() 41 | # M = 32 42 | # index = faiss.IndexHNSWFlat(X.shape[1], M, faiss.METRIC_L2) 43 | # gpu_index = faiss.index_cpu_to_gpu(res, 0, index) 44 | # gpu_index.add(X) 45 | # distances, indices = gpu_index.search(X, k) 46 | # del index 47 | # del gpu_index 48 | # # distances are squared 49 | # return NeighborsOutput(indices=indices, distances=np.sqrt(distances)) 50 | raise NotImplementedError 51 | 52 | 53 | def faiss_brute_force_nn(X: np.ndarray, k: int): 54 | # """GPU brute force nearest neighbor search using faiss.""" 55 | # X = np.ascontiguousarray(X, dtype=np.float32) 56 | # res = faiss.StandardGpuResources() 57 | # index = faiss.IndexFlatL2(X.shape[1]) 58 | # gpu_index = faiss.index_cpu_to_gpu(res, 0, index) 59 | # gpu_index.add(X) 60 | # distances, indices = gpu_index.search(X, k) 61 | # del index 62 | # del gpu_index 63 | # # distances are squared 64 | # return NeighborsOutput(indices=indices, distances=np.sqrt(distances)) 65 | raise NotImplementedError 66 | 67 | 68 | class BasicHandler: 69 | def __init__(self, args, ext_adata=None): 70 | self.args = args 71 | self.parse_cfg() 72 | self.ext_adata = ext_adata 73 | self.device = torch.device("cuda:0") 74 | self.res_dir = rf"{dh.RES_DIR}/scCB" 75 | os.makedirs(self.res_dir, exist_ok=True) 76 | 77 | def _str_formatter(self, message): 78 | print(f"\n=== {message} ===\n") 79 | 80 | def _scDataloader(self, shuffle=True): 81 | # assert exists(join(dh.DATA_DIR, self.data_prefix)), "datapath should exist" 82 | 83 | if self.ext_adata: 84 | self.adata = self.ext_adata 85 | self.external_adata = True 86 | self._str_formatter("Using External ADATA") 87 | 88 | else: # HLCA 89 | self.adata = dh.DATA_RAW 90 | 91 | _batch_id = list(dh.BATCH2CAT.keys()) 92 | if self.args.batch_id: 93 | _batch_id = self.args.batch_id 94 | 95 | _dataset = scDataset( 96 | self.adata, 97 | inference=True, 98 | batch_id=_batch_id, 99 | prefix=self.data_prefix, 100 | n_cells=self.args.n_cells, 101 | ) 102 | self.n_gene = _dataset.n_vars 103 | 104 | if len(_batch_id) == 1: 105 | # Generating data counteraparts 
based on the given batches 106 | self._str_formatter("Single Batch: {}".format(_batch_id)) 107 | batch_ = _dataset[0] 108 | _df = pd.DataFrame(index=dh.CONCEPTS2CAT.keys(), columns=["digits", "names"]) 109 | _df["digits"] = np.concatenate(list(batch_["meta"].values())).reshape(dh.NUM_CONCEPTS, -1).astype(int)[:, 0] 110 | _df["names"] = dh.CONCEPT2NAMES(_df["digits"]) 111 | print(_df) 112 | else: 113 | self._str_formatter("Multi Batches: {}".format(_batch_id)) 114 | 115 | return DataLoader( 116 | _dataset, 117 | batch_size=1, 118 | num_workers=4, 119 | shuffle=shuffle, 120 | collate_fn=collate_fn, 121 | ) 122 | 123 | def _load_model(self): 124 | if self.args.batch1d == "Vanilla": 125 | bn_eps, bn_momentum = 1e-5, 0.1 126 | elif self.args.batch1d == "scVI": 127 | bn_eps, bn_momentum = 1e-3, 0.01 128 | else: 129 | raise ValueError("Unknown batch1d type") 130 | 131 | self.model = Model_ZOO[self.args.type]( 132 | n_gene=self.n_gene, 133 | leak_dim=self.args.leakage, 134 | mlp_size=self.args.mlp_size, 135 | bn_eps=bn_eps, 136 | bn_momentum=bn_momentum, 137 | batchnorm=self.args.batchnorm, 138 | dropout_layer=self.args.dropout, 139 | with_projector=False, 140 | ).to(self.device) 141 | 142 | self._load_part_of_state_dict(torch.load(self.ckpt_path)) 143 | self._str_formatter("Loaded weights from: %s" % self.ckpt_path) 144 | self.model.eval() 145 | 146 | def _load_part_of_state_dict(self, state_dict): 147 | model_state_dict = self.model.state_dict() 148 | common_keys = set(model_state_dict.keys()) & set(state_dict.keys()) 149 | extra_keys = set(state_dict.keys()) - set(model_state_dict.keys()) 150 | missing_keys = set(model_state_dict.keys()) - set(state_dict.keys()) 151 | 152 | for key in common_keys: 153 | model_state_dict[key] = state_dict[key] 154 | 155 | self.model.load_state_dict(model_state_dict) 156 | 157 | if extra_keys: 158 | print(f"Warning: The following keys present in the checkpoints are not found in the model: {extra_keys}") 159 | if missing_keys: 160 | print(f"Warning: The following keys present in the model are not found in the checkpoints: {missing_keys}") 161 | 162 | def parse_cfg(self): 163 | def _recent_ckpt(dirpath): 164 | a = [s for s in os.listdir(dirpath) if ".pth" in s] 165 | a.sort(key=lambda s: getmtime(join(dirpath, s))) 166 | return a 167 | 168 | self.cfg = None 169 | if ".pth" in self.args.save_path: 170 | self.ckpt_path = self.args.save_path 171 | else: 172 | if exists(join(self.args.save_path, "ckpt_best.pth")): 173 | self.ckpt_path = join(self.args.save_path, "ckpt_best.pth") 174 | else: 175 | self.ckpt_path = join(self.args.save_path, _recent_ckpt(self.args.save_path)[-1]) 176 | if exists(join(self.args.save_path, "cfg.json")): 177 | self.cfg = json.load(open(join(self.args.save_path, "cfg.json"))) 178 | elif exists(join(dirname(self.args.save_path), "cfg.json")): 179 | self.cfg = json.load(open(join(dirname(self.args.save_path), "cfg.json"))) 180 | 181 | if self.cfg: 182 | dict_ = vars(self.args) 183 | dict_.update(self.cfg) 184 | self.args = argparse.Namespace(**dict_) 185 | 186 | assert ("concept" in self.args.type) or ("cb" in self.args.type) 187 | 188 | 189 | class scIB(BasicHandler): 190 | """Integration Benchmarker""" 191 | 192 | # https://scib-metrics.readthedocs.io/en/stable/ 193 | # https://github.com/theislab/scib-reproducibility/ 194 | # https://github.com/theislab/scib-pipeline/ 195 | # https://github.com/theislab/scib/ 196 | 197 | def __init__(self, args): 198 | super().__init__(args=args) 199 | self._load_adata_() 200 | 201 | def _load_adata_(self): 
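        # Resolve the benchmarking inputs: for the bundled atlases, read the
        # embedding AnnData from dh.DATA_EMB_ (optionally the "_hvg" variant)
        # and look up the batch/label keys in dh.META_; for --customized_data,
        # treat args.dataset as a path and take both keys from the CLI.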
202 | 203 | if not self.args.customized_data: 204 | _suffix = "_hvg" if self.args.highvar else "" 205 | assert not self.args.use_raw, "use_raw is not supported" 206 | self.adata = sc.read(dh.DATA_EMB_[self.args.dataset + _suffix]) 207 | # 208 | self.batch_key = dh.META_[self.args.dataset]["batch"] 209 | self.label_key = dh.META_[self.args.dataset]["celltype"] 210 | else: 211 | self.adata = sc.read(self.args.dataset) 212 | self.batch_key = self.args.batch_key 213 | self.label_key = self.args.label_key 214 | return 215 | 216 | def _str_formatter(self, message): 217 | print(f"\n=== {message} ===\n") 218 | 219 | def _scDataloader(self): 220 | self._str_formatter("Dataloader") 221 | _verb = True 222 | _scDataset = scDataset( 223 | dataset=self.args.dataset, 224 | inference=True, 225 | rm_cache=False, 226 | verbose=_verb, 227 | ) 228 | _scDataLoader = DataLoader( 229 | _scDataset, 230 | batch_size=1, 231 | shuffle=False, 232 | num_workers=8, 233 | collate_fn=collate_fn, 234 | ) 235 | return _scDataset, _scDataLoader 236 | 237 | def _benchmark_(self, n_jobs=-1, scratch=False): 238 | # NOTE: for DEBUG 239 | # first_5_ = self.adata.obs[self.batch_key].unique()[:5].to_list() 240 | # self.adata = self.adata[[_it in first_5_ for _it in self.adata.obs[self.batch_key]]] 241 | 242 | # recompute the embeddings 243 | # self.scratch = True if self.args.highvar else scratch 244 | self.scratch = scratch 245 | self._pca_() 246 | if self.args.umap or self.args.all: 247 | self._umap_() 248 | if self.args.tsne or self.args.all: 249 | self._tsne_() 250 | if self.args.harmony or self.args.all: 251 | self._harmony_() 252 | if self.args.scanorama or self.args.all: 253 | self._scanorama_() 254 | if self.args.scvi or self.args.all: 255 | self._scvi_() 256 | if self.args.scanvi or self.args.all: 257 | self._scanvi_() 258 | if self.args.bbknn or self.args.all: 259 | self._bbknn_() 260 | if self.args.scgen or self.args.all: 261 | self._scgen_() 262 | if self.args.fastmnn or self.args.all: 263 | self._fastmnn_() 264 | if self.args.scpoli or self.args.all: 265 | self._scpoli_() 266 | 267 | if self.args.islander: 268 | self._islander_() 269 | 270 | if self.args.obsm_keys is None: 271 | obsm_keys = list(self.adata.obsm) 272 | else: 273 | obsm_keys = self.args.obsm_keys 274 | 275 | for embed in obsm_keys: 276 | print("%12s, %d" % (embed, self.adata.obsm[embed].shape[1])) 277 | if self.adata.obsm[embed].shape[0] != np.unique(self.adata.obsm[embed], axis=0).shape[0]: 278 | print("\nWarning: Embedding %s has duplications\n" % embed) 279 | obsm_keys.remove(embed) 280 | self._save_adata_() 281 | 282 | self._str_formatter(rf"scIB Benchmarking: {obsm_keys}") 283 | biocons = BioConservation(nmi_ari_cluster_labels_leiden=True) 284 | batcorr = BatchCorrection() 285 | 286 | """ === for DEBUG ===""" 287 | # biocons = BioConservation(isolated_labels=False) 288 | # biocons = BioConservation( 289 | # silhouette_label=False, 290 | # isolated_labels=False,) 291 | # batcorr = BatchCorrection( 292 | # silhouette_batch=False, 293 | # ilisi_knn=False, 294 | # kbet_per_label=False, 295 | # ) 296 | 297 | self.benchmarker = Benchmarker( 298 | self.adata, 299 | batch_key=self.batch_key, 300 | label_key=self.label_key, 301 | embedding_obsm_keys=obsm_keys, 302 | bio_conservation_metrics=biocons, 303 | batch_correction_metrics=batcorr, 304 | pre_integrated_embedding_obsm_key="X_pca", 305 | n_jobs=n_jobs, 306 | ) 307 | if torch.cuda.is_available() and GPU_4_NEIGHBORS: 308 | # self.benchmarker.prepare(neighbor_computer=faiss_brute_force_nn) 309 | 
self.benchmarker.prepare(neighbor_computer=faiss_hnsw_nn) 310 | else: 311 | # Calculate the Neighbors based on the CPUs 312 | self.benchmarker.prepare(neighbor_computer=None) 313 | self.benchmarker.benchmark() 314 | 315 | self._str_formatter("scIB Benchmarking Finished") 316 | df = self.benchmarker.get_results(min_max_scale=False) 317 | print(df.head(7)) 318 | 319 | os.makedirs(rf"{dh.RES_DIR}/scIB", exist_ok=True) 320 | if self.args.savecsv is not None: 321 | df.to_csv(rf"{dh.RES_DIR}/scIB/{self.args.savecsv}.csv") 322 | else: 323 | savecsv = self.args.save_path.split("/")[-1].replace(" ", "_") 324 | df.to_csv(rf"{dh.RES_DIR}/scIB/{savecsv}.csv") 325 | 326 | savefig = False 327 | if savefig: 328 | _suffix = "_hvg" if self.args.highvar else "" 329 | self.benchmarker.plot_results_table(min_max_scale=False, show=False, save_dir=rf"{dh.RES_DIR}/scIB/") 330 | os.rename( 331 | src=rf"{dh.RES_DIR}/scIB/scib_results.svg", 332 | dst=rf"{dh.RES_DIR}/figures/scib_{self.args.dataset}{_suffix}.svg", 333 | ) 334 | return 335 | 336 | def _save_adata_(self): 337 | if self.args.saveadata or self.args.all: 338 | if not self.args.customized_data: 339 | _suffix = "_hvg" if self.args.highvar else "" 340 | self._str_formatter("Saving %s" % (self.args.dataset + _suffix)) 341 | self.adata.write_h5ad(dh.DATA_EMB_[self.args.dataset + _suffix], compression="gzip") 342 | else: 343 | self._str_formatter("Saving %s" % self.args.dataset) 344 | self.adata.write_h5ad(self.args.dataset, compression="gzip") 345 | return 346 | 347 | def _pca_(self, n_comps=50): 348 | self._str_formatter("PCA") 349 | # if "X_pca" in self.adata.obsm and not self.scratch: 350 | # return 351 | sc.pp.pca(self.adata, n_comps=n_comps) 352 | self._save_adata_() 353 | return 354 | 355 | def _umap_(self, n_neighbors=10, n_pcs=50): 356 | self._str_formatter("UMAP") 357 | if "X_umap" in self.adata.obsm and not self.scratch: 358 | return 359 | sc.pp.neighbors(self.adata, n_neighbors=n_neighbors, n_pcs=n_pcs) 360 | sc.tl.umap(self.adata) 361 | self._save_adata_() 362 | return 363 | 364 | def _tsne_(self): 365 | if "X_tsne" in self.adata.obsm and not self.scratch: 366 | return 367 | self._str_formatter("TSNE") 368 | sc.tl.tsne(self.adata) 369 | self._save_adata_() 370 | return 371 | 372 | def _scvi_(self): 373 | self._str_formatter("scVI") 374 | if "X_scVI" in self.adata.obsm and not self.scratch: 375 | return 376 | adata = self.adata.copy() 377 | # adata.X = adata.layers["raw_counts"] 378 | scvi.model.SCVI.setup_anndata(adata, layer=None, batch_key=self.batch_key) 379 | self.vae = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb") 380 | self.vae.train() 381 | self.adata.obsm["X_scVI"] = self.vae.get_latent_representation() 382 | self._save_adata_() 383 | return 384 | 385 | def _scanvi_(self): 386 | self._str_formatter("X_scANVI") 387 | if "X_scANVI" in self.adata.obsm and not self.scratch: 388 | return 389 | lvae = scvi.model.SCANVI.from_scvi_model( 390 | self.vae, 391 | adata=self.adata, 392 | labels_key=self.label_key, 393 | unlabeled_category="Unknown", 394 | ) 395 | lvae.train() 396 | self.adata.obsm["X_scANVI"] = lvae.get_latent_representation(self.adata) 397 | # self.adata.obsm["mde_scanvi"] = mde(self.adata.obsm["X_scANVI"]) 398 | self._save_adata_() 399 | return 400 | 401 | def _bbknn_(self): 402 | # ref: https://bbknn.readthedocs.io/en/latest/# 403 | # tutorial: https://nbviewer.org/github/Teichlab/bbknn/blob/master/examples/mouse.ipynb 404 | 405 | import bbknn 406 | 407 | self._str_formatter("bbknn") 408 | if "X_bbknn" in 
self.adata.obsm and not self.scratch: 409 | return 410 | # if self.adata.n_obs < 1e5: 411 | # _temp_adata = bbknn.bbknn(self.adata, batch_key=self.batch_key, copy=True) 412 | # else: 413 | print(self.adata.obs[self.batch_key].value_counts().tail()) 414 | _smallest_n_neighbor = self.adata.obs[self.batch_key].value_counts().tail(1).values[0] 415 | _temp_adata = bbknn.bbknn( 416 | self.adata, 417 | batch_key=self.batch_key, 418 | neighbors_within_batch=min(10, _smallest_n_neighbor), 419 | copy=True, 420 | ) 421 | sc.tl.umap(_temp_adata) 422 | self.adata.obsm["X_bbknn"] = _temp_adata.obsm["X_umap"] 423 | self._save_adata_() 424 | return 425 | 426 | def _harmony_(self): 427 | # https://pypi.org/project/harmony-pytorch/ 428 | self._str_formatter("Harmony") 429 | if "Harmony" in self.adata.obsm and not self.scratch: 430 | return 431 | 432 | from harmony import harmonize 433 | 434 | self.adata.obsm["Harmony"] = harmonize( 435 | self.adata.obsm["X_pca"], 436 | self.adata.obs, 437 | batch_key=self.batch_key, 438 | use_gpu=True 439 | ) 440 | self._save_adata_() 441 | return 442 | 443 | def _scanorama_(self): 444 | # https://github.com/brianhie/scanorama 445 | self._str_formatter("Scanorama") 446 | if "Scanorama" in self.adata.obsm and not self.scratch: 447 | return 448 | 449 | import scanorama 450 | 451 | batch_cats = self.adata.obs[self.batch_key].cat.categories 452 | adata_list = [self.adata[self.adata.obs[self.batch_key] == b].copy() for b in batch_cats] 453 | scanorama.integrate_scanpy(adata_list) 454 | 455 | self.adata.obsm["Scanorama"] = np.zeros((self.adata.shape[0], adata_list[0].obsm["X_scanorama"].shape[1])) 456 | for i, b in enumerate(batch_cats): 457 | self.adata.obsm["Scanorama"][self.adata.obs[self.batch_key] == b] = adata_list[i].obsm["X_scanorama"] 458 | self._save_adata_() 459 | return 460 | 461 | def _scgen_(self): 462 | self._str_formatter("scGen") 463 | if "scgen_pca" in self.adata.obsm and not self.scratch: 464 | return 465 | 466 | # if not self.args.highvar: 467 | if self.adata.shape[0] > 1e5: 468 | return 469 | 470 | import scgen 471 | 472 | # ref: https://scgen.readthedocs.io/en/stable/tutorials/scgen_batch_removal.html 473 | # pip install git+https://github.com/theislab/scgen.git 474 | 475 | # ref: https://github.com/theislab/scib-reproducibility/blob/main/notebooks/integration/Test_scgen.ipynb 476 | # ref: https://github.com/theislab/scib/blob/main/scib/integration.py 477 | 478 | # ref: https://github.com/LungCellAtlas/HLCA_reproducibility/blob/main/notebooks/1_building_and_annotating_the_atlas_core/04_integration_benchmark_prep_and_scgen.ipynb 479 | 480 | scgen.SCGEN.setup_anndata(self.adata, batch_key=self.batch_key, labels_key=self.label_key) 481 | model = scgen.SCGEN(self.adata) 482 | model.train( 483 | max_epochs=100, 484 | batch_size=32, # 32 485 | early_stopping=True, 486 | early_stopping_patience=25, 487 | ) 488 | 489 | adata_scgen = model.batch_removal() 490 | sc.pp.pca(adata_scgen, svd_solver="arpack") 491 | sc.pp.neighbors(adata_scgen) 492 | sc.tl.umap(adata_scgen) 493 | 494 | self.adata.obsm["scgen_umap"] = adata_scgen.obsm["X_umap"] 495 | self.adata.obsm["scgen_pca"] = adata_scgen.obsm["X_pca"] 496 | self._save_adata_() 497 | return 498 | 499 | def _fastmnn_(self): 500 | # [deprecated]: https://github.com/chriscainx/mnnpy 501 | # ref: https://github.com/HelloWorldLTY/mnnpy 502 | # from: https://github.com/chriscainx/mnnpy/issues/42 503 | 504 | self._str_formatter("fastMNN") 505 | if "fastMNN_pca" in self.adata.obsm and not self.scratch: 506 | return 507 | 508 | 
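        # NOTE: as in _scgen_ above, fastMNN is skipped for atlases above 1e5
        # cells (guard below); mnnpy's pairwise mutual-nearest-neighbour
        # correction becomes impractically slow at that scale.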
# if not self.args.highvar: 509 | if self.adata.shape[0] > 1e5: 510 | return 511 | 512 | import mnnpy 513 | 514 | def split_batches(adata, batch_key, hvg=None, return_categories=False): 515 | """Split batches and preserve category information 516 | Ref: https://github.com/theislab/scib/blob/main/scib/utils.py#L32""" 517 | split = [] 518 | batch_categories = adata.obs[batch_key].cat.categories 519 | if hvg is not None: 520 | adata = adata[:, hvg] 521 | for i in batch_categories: 522 | split.append(adata[adata.obs[batch_key] == i].copy()) 523 | if return_categories: 524 | return split, batch_categories 525 | return split 526 | 527 | split, categories = split_batches(self.adata, batch_key=self.batch_key, return_categories=True) 528 | if self.args.dataset in [ 529 | "lung_fetal_organoid", 530 | "COVID", 531 | "heart", 532 | "brain", 533 | "breast", 534 | ]: 535 | k = 10 536 | else: 537 | k = 20 538 | corrected, _, _ = mnnpy.mnn_correct( 539 | *split, 540 | k=k, 541 | batch_key=self.batch_key, 542 | batch_categories=categories, 543 | index_unique=None, 544 | ) 545 | 546 | adata_fastmnn = corrected 547 | sc.pp.pca(adata_fastmnn, svd_solver="arpack") 548 | sc.pp.neighbors(adata_fastmnn) 549 | sc.tl.umap(adata_fastmnn) 550 | 551 | self.adata.obsm["fastMNN_umap"] = adata_fastmnn.obsm["X_umap"] 552 | self.adata.obsm["fastMNN_pca"] = adata_fastmnn.obsm["X_pca"] 553 | self._save_adata_() 554 | return 555 | 556 | def _scpoli_(self): 557 | self._str_formatter("scPoli") 558 | if "scPoli" in self.adata.obsm and not self.scratch: 559 | return 560 | if self.args.dataset in ["brain", "breast"]: 561 | return 562 | import warnings 563 | 564 | warnings.filterwarnings("ignore") 565 | from scarches.models.scpoli import scPoli 566 | 567 | self.adata.X = self.adata.X.astype(np.float32) 568 | early_stopping_kwargs = { 569 | "early_stopping_metric": "val_prototype_loss", 570 | "mode": "min", 571 | "threshold": 0, 572 | "patience": 10, 573 | "reduce_lr": True, 574 | "lr_patience": 13, 575 | "lr_factor": 0.1, 576 | } 577 | scpoli_model = scPoli( 578 | adata=self.adata, 579 | condition_keys=self.batch_key, 580 | cell_type_keys=self.label_key, 581 | embedding_dims=16, 582 | recon_loss="nb", 583 | ) 584 | scpoli_model.train( 585 | n_epochs=50, 586 | pretraining_epochs=40, 587 | early_stopping_kwargs=early_stopping_kwargs, 588 | eta=5, 589 | ) 590 | self.adata.obsm["scPoli"] = scpoli_model.get_latent(self.adata, mean=True) 591 | self._save_adata_() 592 | return 593 | 594 | def _islander_(self): 595 | from tqdm.auto import tqdm 596 | 597 | scDataset, self.scDataLoader = self._scDataloader() 598 | # self.cell2cat = scData_Train.CELL2CAT 599 | self.n_gene = scDataset.n_vars 600 | self._load_model() 601 | self.model.eval() 602 | emb_cells = [] 603 | 604 | for item in tqdm(self.scDataLoader): 605 | counts_ = item["counts"].to(self.device).squeeze() 606 | # for _idx in range(self.adata.n_obs): 607 | # if (self.adata.X[_idx, :] == counts_.cpu().numpy()[1]).all(): 608 | # print(_idx) 609 | emb_cell = self.model.extra_repr(counts_) 610 | emb_cells.append(emb_cell.detach().cpu().numpy()) 611 | 612 | emb_cells = np.concatenate(emb_cells, axis=0) 613 | self.adata.obsm["Islander"] = emb_cells 614 | self._save_adata_() 615 | return 616 | 617 | # TODO: add scGraph 618 | 619 | 620 | if __name__ == "__main__": 621 | # 622 | args = Parser_Benchmarker() 623 | benchmarker = scIB(args=args) 624 | benchmarker._benchmark_() 625 | 626 | # jaxlib is version 0.4.7, but this version of jax requires version >= 0.4.14. 
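**Note on the GPU neighbor stubs.** `faiss_hnsw_nn` and `faiss_brute_force_nn` above raise `NotImplementedError` and keep their original bodies only as comments. A minimal sketch of restoring the brute-force variant, assuming a working `faiss-gpu` install and that the `NeighborsOutput` container referenced in those comments is importable from `scib_metrics.nearest_neighbors` in the pinned version:

```python
import faiss, numpy as np
from scib_metrics.nearest_neighbors import NeighborsOutput  # name as referenced in the comments


def faiss_brute_force_nn(X: np.ndarray, k: int):
    """GPU brute-force k-NN, mirroring the commented-out body above."""
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(X.shape[1])  # exact L2 search
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)  # faiss returns squared L2
    del index, gpu_index
    return NeighborsOutput(indices=indices, distances=np.sqrt(distances))
```

With a working implementation, flipping `GPU_4_NEIGHBORS` to `True` routes `self.benchmarker.prepare(neighbor_computer=...)` through this path instead of the CPU default.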
627 | -------------------------------------------------------------------------------- /src/scDataset.py: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/theislab/dca/blob/master/dca/io.py 2 | from __future__ import division, print_function, absolute_import 3 | import os, json, time, torch, shutil, random, pickle, numpy as np, scanpy as sc, Data_Handler as dh 4 | from scipy.sparse import save_npz, load_npz, csr_matrix 5 | from torch.utils.data import Dataset, DataLoader 6 | from os.path import join, exists 7 | from typing import Union 8 | from tqdm import tqdm 9 | 10 | AnyRandom = Union[None, int, np.random.RandomState] 11 | NumCells = 256 # 4096 12 | 13 | 14 | def set_seed(seed_=0): 15 | random.seed(seed_) 16 | np.random.seed(seed_) 17 | torch.cuda.manual_seed(seed_) 18 | torch.cuda.manual_seed_all(seed_) 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | # Ref: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html 24 | class scDataset(Dataset): 25 | def __init__( 26 | self, 27 | random_state: AnyRandom = None, 28 | num_cells: int = NumCells, 29 | test_ratio: float = 0.1, 30 | inference: bool = False, 31 | rm_cache: bool = False, 32 | train_and_test=False, 33 | training: bool = True, 34 | batch_id: list = None, 35 | dataset: str = "lung", 36 | verbose: bool = True, 37 | n_genes: int = -1, 38 | ): 39 | self.train_and_test = train_and_test 40 | self.random_state = random_state 41 | self.test_ratio = test_ratio 42 | # self.full_train = full_train 43 | self.inference = inference 44 | self.training = training 45 | self.batch_id = batch_id 46 | self.n_cells = num_cells 47 | self.verbose = verbose 48 | self.dataset = dataset 49 | self.n_genes = n_genes 50 | # TODO: here can be optimized 51 | self._load_adata() 52 | 53 | if self.random_state: 54 | set_seed(self.random_state) 55 | 56 | if rm_cache: 57 | self._rmdirs() 58 | 59 | self.CELL2CAT = dh.CELL2CAT_[self.dataset] 60 | self.BATCH2CAT = dh.BATCH2CAT_[self.dataset] 61 | self.batch_key = dh.META_[self.dataset]["batch"] 62 | self.label_key = dh.META_[self.dataset]["celltype"] 63 | if self.verbose: 64 | print("\n" * 2) 65 | print("=" * 77) 66 | print("DATASET: %s" % (self.dataset)) 67 | print("BATCH: %s, LABEL: %s" % (self.batch_key, self.label_key)) 68 | print(f"Load {self.n_vars} Genes & {self.n_obs} Cells.") 69 | print("=" * 77 + "\n" * 2) 70 | if inference: 71 | self._outpath = join(dh.DATA_DIR, self.dataset, "Benchmark_") 72 | os.makedirs(self._outpath, exist_ok=True) 73 | if len(os.listdir(self._outpath)) < self.n_obs // self.n_cells: 74 | self._batchize_inference() 75 | 76 | elif self._check_presaved() and not rm_cache: 77 | _split = "Training" if self.training else "Testing" 78 | print("...Loading from Pre Saved, %s Split..." % _split) 79 | 80 | else: 81 | self._mkdirs() 82 | self._save_cfg() 83 | time_ = time.time() 84 | print("...Pre-Processing from Scratch...") 85 | self._count_quantified_batch() 86 | self._batchize() 87 | print("...Finished, Used %d Mins..." 
% ((time.time() - time_) // 60)) 88 | self._split() 89 | del self.adata 90 | 91 | def _str_formatter(self, message): 92 | if self.verbose: 93 | print(f"\n=== {message} ===\n") 94 | 95 | def _check_presaved(self) -> bool: 96 | self._outpath = join(dh.DATA_DIR, self.dataset, "Train") 97 | if not exists(self._outpath): 98 | return False 99 | if len(os.listdir(self._outpath)) < (self.n_obs // self.n_cells) * (1 - self.test_ratio - 0.05) * 2: 100 | return False 101 | return True 102 | 103 | def _save_cfg(self): 104 | _file = join(dh.DATA_DIR, self.dataset, "cfg.json") 105 | _cfgs = self.__dict__.copy() 106 | _cfgs["adata"] = None 107 | json.dump(_cfgs, open(_file, "w")) 108 | 109 | def _mkdirs(self): 110 | os.makedirs(join(dh.DATA_DIR, self.dataset), exist_ok=True) 111 | os.makedirs(join(dh.DATA_DIR, self.dataset, "Test"), exist_ok=True) 112 | os.makedirs(join(dh.DATA_DIR, self.dataset, "Train"), exist_ok=True) 113 | 114 | def _rmdirs(self): 115 | for _split in ["Train", "Test", "Benchmark_"]: 116 | if exists(join(dh.DATA_DIR, self.dataset, _split)): 117 | shutil.rmtree(join(dh.DATA_DIR, self.dataset, _split)) 118 | 119 | def _load_adata(self): 120 | # ref: https://chat.openai.com/share/8bc3b625-1e97-4954-bde0-69306df2c062 121 | # normalize then select highly variable genes 122 | _suffix = "_hvg" if self.n_genes != -1 else "" 123 | adata = sc.read(dh.DATA_EMB_[self.dataset + _suffix]) 124 | 125 | if self.n_genes != -1: 126 | sc.pp.highly_variable_genes( 127 | adata, 128 | subset=True, 129 | flavor="seurat_v3", 130 | n_top_genes=self.n_genes, 131 | ) 132 | self.n_vars = adata.n_vars 133 | self.n_obs = adata.n_obs 134 | self.adata = adata 135 | return 136 | 137 | def _count_quantified_batch(self): 138 | count_ = 0 139 | for _id in self.adata.obs[self.batch_key].unique().to_list(): 140 | _cells = (self.adata.obs[self.batch_key] == _id).sum() 141 | if _cells <= self.n_cells: 142 | continue 143 | count_ += 1 144 | self._str_formatter("Qualified batches: %d" % count_) 145 | 146 | @staticmethod 147 | def _isnan(value): 148 | return value != value 149 | 150 | def _conceptualize(self, adata, mask_tokens=["nan", "unknown"]): 151 | obj = dict() 152 | 153 | obj["cell_type"] = adata.obs[self.label_key].apply(lambda x: self.CELL2CAT[x]).values.to_list() 154 | obj["batch_id"] = adata.obs[self.batch_key].apply(lambda x: self.BATCH2CAT[x]).values.to_list() 155 | 156 | return obj 157 | 158 | def _split(self): 159 | if self.inference: 160 | _split = "Benchmark_" 161 | else: 162 | _split = "Train" if self.training else "Test" 163 | _outpath = join(dh.DATA_DIR, self.dataset) 164 | _path = join(_outpath, _split) 165 | self.file_list = [] 166 | 167 | for item in os.listdir(_path): 168 | self.file_list.append(join(_path, item.replace(".npz", "").replace(".pkl", ""))) 169 | # self.file_list = list(set(self.file_list)) 170 | # if self.inference: 171 | # self.file_list.sort() 172 | 173 | if self.train_and_test: 174 | _split = "Train" if _split == "Test" else "Test" 175 | _path = join(_outpath, _split) 176 | for item in os.listdir(_path): 177 | self.file_list.append(join(_path, item.replace(".npz", "").replace(".pkl", ""))) 178 | self.file_list = list(set(self.file_list)) 179 | if self.inference: 180 | self.file_list.sort() 181 | 182 | def _batchize(self): 183 | # ref: https://discuss.pytorch.org/t/best-way-to-load-a-lot-of-training-data/80847/2 184 | # Save files on disks instead of loading in MEM 185 | _path = join(dh.DATA_DIR, self.dataset) 186 | _num_cells = self.adata.__len__() 187 | _ids = np.arange(_num_cells) 188 
| np.random.shuffle(_ids) 189 | 190 | _num_batch = round(_num_cells / self.n_cells) 191 | train_ids = np.ones(_num_batch) 192 | zero_indices = np.random.choice(_num_batch - 1, int(_num_batch * self.test_ratio), replace=False) 193 | train_ids[zero_indices] = 0 194 | train_ids[-1] = 0 195 | train_ids = train_ids.astype(bool) 196 | 197 | for _idx in tqdm(range(_num_batch)): 198 | adata_subset = self.adata[_ids[_idx * self.n_cells : (_idx + 1) * self.n_cells]].copy() 199 | 200 | obj = self._conceptualize(adata_subset) 201 | _split = "Train" if train_ids[_idx] else "Test" 202 | _outfile = join(_path, _split, "_%s.pkl" % str(_idx).zfill(10)) 203 | if isinstance(adata_subset.X, np.ndarray): 204 | adata_subset.X = csr_matrix(adata_subset.X) 205 | save_npz(_outfile.replace(".pkl", ".npz"), adata_subset.X) 206 | with open(_outfile, "wb") as handle: 207 | pickle.dump(obj, handle) 208 | 209 | # if self.full_train: 210 | # for _f in os.listdir(join(_path, "Test")): 211 | # shutil.copy(join(_path, "Test", _f), join(_path, "Train", _f)) 212 | return 213 | 214 | def _batchize_inference(self): 215 | _num_cells = self.adata.__len__() 216 | _ids = np.arange(_num_cells) 217 | for _idx in tqdm(range(_num_cells // self.n_cells + 1)): 218 | adata_subset = self.adata[_ids[_idx * self.n_cells : (_idx + 1) * self.n_cells]].copy() 219 | obj = self._conceptualize(adata_subset) 220 | _outfile = join(self._outpath, "_%s.pkl" % str(_idx).zfill(10)) 221 | if isinstance(adata_subset.X, np.ndarray): 222 | adata_subset.X = csr_matrix(adata_subset.X) 223 | save_npz(_outfile.replace(".pkl", ".npz"), adata_subset.X) 224 | with open(_outfile, "wb") as handle: 225 | pickle.dump(obj, handle) 226 | return 227 | 228 | def __len__(self): 229 | return len(self.file_list) 230 | 231 | def __getitem__(self, idx): 232 | filename = self.file_list[idx] 233 | infos = pickle.load(open(filename + ".pkl", "rb")) 234 | item = {"counts": torch.Tensor(load_npz(filename + ".npz").toarray())} 235 | item["cell_type"] = torch.Tensor(infos["cell_type"]).long() 236 | item["batch_id"] = torch.Tensor(infos["batch_id"]) 237 | return item 238 | 239 | def sample_two_files(self): 240 | # Randomly sample two files from the list 241 | file_indices = random.sample(range(len(self.file_list)), 2) 242 | return [self.__getitem__(idx) for idx in file_indices] 243 | 244 | 245 | def collate_fn(batch): 246 | output = {} 247 | for sample in batch: 248 | if isinstance(sample, dict): 249 | for key, value in sample.items(): 250 | if isinstance(value, torch.Tensor) or isinstance(value, np.ndarray) or isinstance(value, list): 251 | if key not in output: 252 | output[key] = [] 253 | output[key].append(value) 254 | elif isinstance(value, dict): 255 | if key not in output: 256 | output[key] = {} 257 | for k, v in value.items(): 258 | if k not in output[key]: 259 | output[key][k] = [] 260 | output[key][k].append(v) 261 | for key, value in output.items(): 262 | if isinstance(value, list): 263 | output[key] = torch.concat(value) 264 | elif isinstance(value, dict): 265 | for k, v in value.items(): 266 | output[key][k] = torch.tensor(v) 267 | return output 268 | 269 | 270 | class TripletDataset(Dataset): 271 | def __init__(self, scDataset): 272 | self.file_list = scDataset.file_list 273 | 274 | self.labels = [] 275 | for _file in self.file_list: 276 | infos = pickle.load(open(_file + ".pkl", "rb")) 277 | self.labels.append(infos["cell_type"]) 278 | self.labels_set = list(set(self.labels.numpy())) 279 | self.label_to_indices = {label: torch.where(self.labels == label)[0] for label in 
self.labels_set} 280 | 281 | def __getitem__(self, index): 282 | anchor, label = self.file_list[index], self.labels[index] 283 | positive_index = index 284 | while positive_index == index: 285 | positive_index = np.random.choice(self.label_to_indices[label]) 286 | negative_label = np.random.choice(list(filter(lambda x: x != label, self.labels_set))) 287 | negative_index = np.random.choice(self.label_to_indices[negative_label]) 288 | positive, negative = ( 289 | self.file_list[positive_index], 290 | self.file_list[negative_index], 291 | ) 292 | 293 | return anchor, positive, negative 294 | 295 | def __len__(self): 296 | return len(self.file_list) 297 | 298 | 299 | class TripletDatasetDUAL(Dataset): 300 | def __init__(self, scDataset): 301 | self.file_list = scDataset.file_list 302 | self.CELL2CAT = scDataset.CELL2CAT 303 | self.n_vars = scDataset.n_vars 304 | # self.BATCH2CAT = dh.BATCH2CAT_[self.dataset] 305 | # self.batch_key = dh.META_[self.dataset]["batch"] 306 | # self.label_key = dh.META_[self.dataset]["celltype"] 307 | 308 | def __getitem__(self, index): 309 | _idx = self.file_list[index] 310 | anchor_counts = torch.Tensor(load_npz(_idx + ".npz").toarray()) 311 | anchor_labels = pickle.load(open(_idx + ".pkl", "rb"))["cell_type"] 312 | 313 | # Randomly select a different file 314 | r_idx = np.random.choice(self.file_list) 315 | r_counts = torch.Tensor(load_npz(r_idx + ".npz").toarray()) 316 | r_labels = pickle.load(open(r_idx + ".pkl", "rb"))["cell_type"] 317 | 318 | # Create union of counts and labels 319 | union_counts = torch.cat([anchor_counts, r_counts]) 320 | union_labels = anchor_labels + r_labels 321 | positives, negatives = [], [] 322 | 323 | # Iterate over each data point in the anchor file 324 | for i in range(len(anchor_labels)): 325 | # Find a positive example (different data point within the same file) 326 | pos_indices = [j for j in range(len(anchor_labels)) if anchor_labels[j] == anchor_labels[i] and j != i] 327 | if pos_indices: 328 | pos_index = np.random.choice(pos_indices) 329 | positive = union_counts[pos_index] 330 | else: 331 | # If no valid positive found in the same file, use the anchor as a last resort 332 | positive = anchor_counts[i] 333 | 334 | # Find a negative example (any data point from a different file with a different label) 335 | neg_indices = [k for k in range(len(anchor_labels), len(union_labels)) if union_labels[k] != anchor_labels[i]] 336 | if neg_indices: 337 | neg_index = np.random.choice(neg_indices) 338 | negative = union_counts[neg_index] 339 | else: 340 | # Fallback if no negative found (extremely rare case) 341 | negative = torch.rand(anchor_counts.shape[1]) # Random tensor as a last resort 342 | 343 | positives.append(positive.unsqueeze(0)) 344 | negatives.append(negative.unsqueeze(0)) 345 | 346 | return {"anchor": anchor_counts, "positive": torch.cat(positives), "negative": torch.cat(negatives)} 347 | 348 | def __len__(self): 349 | return len(self.file_list) 350 | 351 | 352 | class ContrastiveDataset(Dataset): 353 | def __init__(self, scDataset): 354 | self.file_list = scDataset.file_list 355 | 356 | self.labels = [] 357 | for _file in self.file_list: 358 | infos = pickle.load(open(_file + ".pkl", "rb")) 359 | self.labels.append(infos["cell_type"]) 360 | self.labels_set = list(set(self.labels.numpy())) 361 | self.label_to_indices = {label: torch.where(self.labels == label)[0] for label in self.labels_set} 362 | 363 | def __getitem__(self, index): 364 | data1, label1 = self.file_list[index], self.labels[index] 365 | # Generate a 
positive example 50% of the time 366 | should_get_same_class = np.random.randint(0, 2) 367 | if should_get_same_class: 368 | while True: 369 | # Keep looping until the same class image is found 370 | idx2 = np.random.choice(self.label_to_indices[label1]) 371 | if idx2 != index: 372 | break 373 | else: 374 | label2 = np.random.choice(list(filter(lambda x: x != label1, self.labels_set))) 375 | idx2 = np.random.choice(self.label_to_indices[label2]) 376 | 377 | data2 = self.file_list[idx2] 378 | return (data1, data2), should_get_same_class 379 | 380 | 381 | class ContrastiveDatasetDUAL(Dataset): 382 | # in case there are many cell types (e.g., > 100) 383 | def __init__(self, scDataset): 384 | self.file_list = scDataset.file_list 385 | self.CELL2CAT = scDataset.CELL2CAT 386 | self.n_vars = scDataset.n_vars 387 | 388 | def __getitem__(self, index): 389 | file1 = self.file_list[index] 390 | count1 = torch.Tensor(load_npz(file1 + ".npz").toarray()) 391 | label1 = torch.Tensor(pickle.load(open(file1 + ".pkl", "rb"))["cell_type"]) 392 | 393 | file2 = np.random.choice(self.file_list) 394 | count2 = torch.Tensor(load_npz(file2 + ".npz").toarray()) 395 | label2 = torch.Tensor(pickle.load(open(file2 + ".pkl", "rb"))["cell_type"]) 396 | 397 | return {"b1": count1, "l1": label1, "b2": count2, "l2": label2} 398 | 399 | def __len__(self): 400 | return len(self.file_list) 401 | 402 | 403 | if __name__ == "__main__": 404 | # 405 | for DATASET in [ 406 | # "lung", 407 | # "lung_fetal_donor", 408 | # "lung_fetal_organoid", 409 | # "brain", 410 | # "breast", 411 | # "heart", 412 | # "eye", 413 | # "gut_fetal", 414 | # "skin", 415 | # "COVID", 416 | "pancreas", 417 | ]: 418 | scData_Train = scDataset( 419 | dataset=DATASET, 420 | rm_cache=True, 421 | ) 422 | 423 | scData_TrainLoader = DataLoader( 424 | scData_Train, 425 | batch_size=1, 426 | shuffle=True, 427 | num_workers=4, 428 | collate_fn=collate_fn, 429 | ) 430 | for id_, batch_ in enumerate(tqdm(scData_TrainLoader)): 431 | pass 432 | # print( 433 | # id_, # Integer 434 | # batch_["norm_counts"].size(), # Tensor 435 | # batch_["sample_id"].__len__(), # List 436 | # batch_["cell_type"].__len__(), # Tensor 437 | # ) 438 | -------------------------------------------------------------------------------- /src/scFinetuner.py: -------------------------------------------------------------------------------- 1 | # https://docs.scvi-tools.org/en/stable/tutorials/notebooks/harmonization.html 2 | # https://scib-metrics.readthedocs.io/en/stable/notebooks/lung_example.html 3 | 4 | import os, torch, json, time, wandb, shutil, scanpy as sc, numpy as np, torch.nn as nn, torch.nn.functional as F, Utils_Handler as uh, Data_Handler as dh, scLoss as scL, scDataset as scD 5 | from torch.optim.lr_scheduler import CosineAnnealingLR 6 | from scLoss import HSICLoss, InfoNCELoss 7 | from torch.utils.data import DataLoader 8 | from ArgParser import Parser_Finetuner 9 | from os.path import join, exists 10 | from scBenchmarker import scIB 11 | from torch.optim import Adam 12 | from copy import deepcopy 13 | 14 | # set_seed(dh.SEED) 15 | 16 | PATIENCE = 10 17 | MIX_UP = False 18 | REC_GRADS = False 19 | # CONSTANT_LR = True 20 | CONSTANT_LR = False 21 | _n = dh.TOTAL_CONCEPTS 22 | 23 | 24 | # 25142 Genes, Limb Intersection 25 | 26 | 27 | class scFineTuner(scIB): 28 | """Basic Cross Entropy Fine Tuner""" 29 | 30 | def __init__(self, args): 31 | super().__init__(args=args) 32 | 33 | def _load_adata_(self): 34 | assert self.args.dataset != "hlca" 35 | 36 | if self.args.use_raw: 37 | adata = 
sc.read(dh.DATA_RAW_[self.args.dataset], first_column_names=True) 38 | if self.args.highvar: 39 | self._str_formatter("using top %d highly variable genes" % self.args.highvar_n) 40 | sc.pp.highly_variable_genes( 41 | adata, 42 | subset=True, 43 | flavor="seurat_v3", 44 | batch_key=self.batch_key, 45 | n_top_genes=self.args.highvar_n, 46 | ) 47 | else: 48 | _suffix = "_hvg" if self.args.highvar else "" 49 | adata = sc.read(dh.DATA_EMB_[self.args.dataset + _suffix]) 50 | 51 | if self.args.batch_id is not None: 52 | adata = adata[adata.obs[self.batch_key].isin(self.args.batch_id), :] 53 | self.adata = adata 54 | 55 | return 56 | 57 | def _prep_dataloader_(self): 58 | self._str_formatter("Dataloader, Fine-Tuning") 59 | _, gene_pad = self.args.dataset.split("_") 60 | self._dataset = scD.scDataset_Transfer( 61 | adata=self.adata, 62 | prefix="Finetune_" + self.args.dataset, 63 | n_cells=self.args.n_cells, 64 | dataset=self.args.dataset, 65 | gene_pad=gene_pad, 66 | with_meta=True, 67 | rm_cache=False, 68 | one_shot=False, 69 | few_shots=False, 70 | all_shots=True, 71 | shuffle=True, 72 | ) 73 | 74 | self.n_gene = self._dataset.n_vars 75 | self.CONCEPTS = self._dataset.CONCEPTS 76 | self.NUM_CELLTYPE = len(self._dataset.CELL2CAT) 77 | self.scDataLoader = DataLoader( 78 | self._dataset, 79 | batch_size=1, 80 | shuffle=False, 81 | num_workers=4, 82 | collate_fn=scD.collate_fn, 83 | ) 84 | return 85 | 86 | def _loss_cet(self, emb_b: torch.Tensor, cell_type: torch.Tensor): 87 | # cell type prediction 88 | emb_b = self.model.projector(emb_b) 89 | emb_b = emb_b.log_softmax(dim=-1) 90 | return F.nll_loss(emb_b, cell_type.to(emb_b.device)) 91 | 92 | def _loss_mixup(self, emb_b, cell_type, alpha=1.0): 93 | # alpha=1.0, 0.2, 0.4, 0.6 94 | """mixed inputs, pairs of targets, and lambda""" 95 | if alpha > 0: 96 | lam = np.random.beta(alpha, alpha) 97 | else: 98 | lam = 1 99 | 100 | batch_size = emb_b.size()[0] 101 | index = torch.randperm(batch_size) 102 | 103 | mixed_x = lam * emb_b + (1 - lam) * emb_b[index, :] 104 | y_a, y_b = cell_type, cell_type[index] 105 | 106 | return lam * self._loss_cet(mixed_x, y_a) + (1 - lam) * self._loss_cet(mixed_x, y_b) 107 | 108 | def _finetune_(self): 109 | model_cfg = self.args.save_path.split("/")[-1] 110 | dataset_cfg = self.args.save_path.split("/")[-2] 111 | self._str_formatter("Model: %s, Dataset: %s" % (model_cfg, dataset_cfg)) 112 | wandb.init( 113 | project="Cross Entropy FT %s" % dataset_cfg, 114 | name=model_cfg, 115 | config=self.args, 116 | ) 117 | self._prep_dataloader_() 118 | self._load_model() 119 | BEST_TEST = 1e4 120 | 121 | self.model.projector = torch.nn.Linear(self.args.leakage, self.NUM_CELLTYPE, bias=False).to(self.device) # without bias doesn't matter that much 122 | 123 | if self.args.savename_ft: 124 | SAVE_PATH = self.args.savename_ft 125 | else: 126 | SAVE_PATH = join(dh.MODEL_DIR, uh.DATETIME) 127 | self._str_formatter("SAVE_PATH: %s" % SAVE_PATH) 128 | 129 | if exists(SAVE_PATH): 130 | shutil.rmtree(SAVE_PATH) 131 | os.makedirs(SAVE_PATH, exist_ok=True) 132 | 133 | with open(join(SAVE_PATH, "cfg.json"), "w") as outfile: 134 | json.dump(vars(self.args), outfile) 135 | 136 | LR = self.args.lr_ft 137 | NUM_EPOCHS = self.args.epoch_ft 138 | Opt = Adam(self.model.parameters(), lr=LR) 139 | LR_SCHEDULER = CosineAnnealingLR(Opt, T_max=NUM_EPOCHS) 140 | 141 | start = time.time() 142 | for epoch in range(NUM_EPOCHS): 143 | total_err, total_rec, total_cet, total_hsi = 0, 0, 0, 0 144 | cell_acc_ = [] 145 | 146 | LossWeights = { 147 | "reconstruction": 
[1.0], 148 | # "infonce": [], 149 | "cet": [], 150 | } 151 | for batch_ in self.scDataLoader: 152 | Opt.zero_grad() 153 | counts_ = batch_["norm_counts"].squeeze().to(self.device) 154 | rec_ = self.model(counts_) 155 | emb_ = self.model.extra_repr(counts_) 156 | emb_b = emb_[:, dh.TOTAL_CONCEPTS :] 157 | 158 | loss_rec = args.w_rec_ft * self.model._loss_rec(counts_, rec_) 159 | loss_hsi = args.w_hsi_ft * self._loss_hsi(emb_b, batch_["meta"]) 160 | loss_inf = args.w_inf_ft * self._loss_inf(emb_b, batch_["meta"]) 161 | loss_cet = args.w_cet_ft * self._loss_cet(emb_b, batch_["cell_type"]) 162 | 163 | if MIX_UP: 164 | loss_cet += self._loss_mixup(emb_b, batch_["cell_type"]) 165 | 166 | loss = loss_rec + loss_hsi + loss_inf + loss_cet 167 | wandb.log( 168 | { 169 | "Loss Reconstruction": loss_rec.item(), 170 | "Loss Cell Type": loss_cet.item(), 171 | "Loss InfoNCE": loss_inf.item(), 172 | "Loss HSIC": loss_hsi.item(), 173 | } 174 | ) 175 | 176 | if REC_GRADS: 177 | LossWeights["cet"].append(uh.Weights_on_GradNorm(self.model, loss_cet, loss_rec)) 178 | 179 | else: 180 | loss.backward() 181 | Opt.step() 182 | 183 | total_err += loss.item() 184 | total_rec += loss_rec.item() 185 | total_hsi += loss_hsi.item() 186 | total_cet += loss_cet.item() 187 | 188 | cell_acc_.append(self.model.acc_celltype_(emb_b, batch_["cell_type"])) 189 | 190 | if not CONSTANT_LR: 191 | LR_SCHEDULER.step() 192 | lr = LR_SCHEDULER.get_last_lr()[0] 193 | train_err = total_err / len(self.scDataLoader) 194 | train_hsi = total_hsi / len(self.scDataLoader) 195 | train_rec = total_rec / len(self.scDataLoader) 196 | 197 | if uh._isnan(train_err): 198 | self._str_formatter("NaN Value Encountered, Quitting") 199 | quit() 200 | 201 | if epoch % 1 == 0: 202 | print( 203 | "LR: %.6f " % lr, 204 | "Epoch: %2d, " % epoch, 205 | "Total Loss: %.2f, " % train_err, 206 | "CellType Acc: %.2f, " % (100 * uh.mean_(cell_acc_)), 207 | "EST: %.1f Mins" % ((time.time() - start) // 60), 208 | ) 209 | if REC_GRADS: 210 | for key, value in LossWeights.items(): 211 | print("%20s, %.4f" % (key, uh.mean_(value))) 212 | print("\n") 213 | 214 | wandb.log( 215 | { 216 | "LR": lr, 217 | "Total Loss": train_err, 218 | "HSIC Loss": train_hsi, 219 | "Reconstruction Loss": train_rec, 220 | "CellType Accuracy": 100 * uh.mean_(cell_acc_), 221 | } 222 | ) 223 | 224 | if (epoch + 1) % 1 == 0: 225 | torch.save( 226 | self.model.state_dict(), 227 | join(SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)), 228 | ) 229 | 230 | if train_err < BEST_TEST: 231 | BEST_TEST = train_err 232 | torch.save(self.model.state_dict(), join(SAVE_PATH, "ckpt_best.pth")) 233 | 234 | return 235 | 236 | 237 | class scFineTuner_SCL(scFineTuner): 238 | """Supervised Contrastive Learning""" 239 | 240 | def __init__(self, args): 241 | super().__init__(args=args) 242 | 243 | def _prep_dataloader_(self): 244 | self._str_formatter("Prep Dataloader, GeneCBM, Fine-Tuning") 245 | _, gene_pad = self.args.dataset.split("_") 246 | self._dataset = scD.scDataset_Transfer( 247 | adata=self.adata, 248 | gene_pad=gene_pad, 249 | prefix="Finetune_" + self.args.dataset, 250 | dataset=self.args.dataset, 251 | n_cells=self.args.n_cells, 252 | with_meta=True, 253 | rm_cache=False, 254 | one_shot=False, 255 | few_shots=False, 256 | all_shots=True, 257 | shuffle=True, 258 | ) 259 | 260 | self.n_gene = self._dataset.n_vars 261 | self.CONCEPTS = self._dataset.CONCEPTS 262 | self.NUM_CELLTYPE = len(self._dataset.CELL2CAT) 263 | self.scDataLoader = DataLoader( 264 | self._dataset, 265 | batch_size=1, 266 | shuffle=True, 267 | 
num_workers=4, 268 | collate_fn=scD.collate_fn, 269 | ) 270 | 271 | self.ContrastiveDataset = scD.ContrastiveDatasetDUAL(self._dataset) 272 | self.ContrastiveDatasetloader = DataLoader( 273 | self.ContrastiveDataset, 274 | batch_size=1, 275 | shuffle=True, 276 | num_workers=4, 277 | ) 278 | 279 | # self.TripletDataset = TripletDataset(self._dataset) 280 | # self.ContrastiveDataset = ContrastiveDataset(self._dataset) 281 | 282 | # self.TripletDataloader = DataLoader( 283 | # self.TripletDataset, 284 | # batch_size=1, 285 | # shuffle=False, 286 | # num_workers=4, 287 | # collate_fn=collate_fn, 288 | # ) 289 | 290 | return 291 | 292 | def _finetune_(self): 293 | dataset_cfg = self.args.save_path.split("/")[-2] 294 | model_cfg = self.args.save_path.split("/")[-1] 295 | wandb.init(project="SCL FT %s" % dataset_cfg, name=model_cfg, config=self.args) 296 | self._prep_dataloader_() 297 | self._load_model() 298 | 299 | self.model.projector = nn.Linear(self.args.leakage, self.NUM_CELLTYPE).to(self.device) 300 | 301 | if self.args.savename_ft: 302 | SAVE_PATH = self.args.savename_ft 303 | else: 304 | SAVE_PATH = join(dh.MODEL_DIR, uh.DATETIME) 305 | self._str_formatter("SAVE_PATH: %s" % SAVE_PATH) 306 | os.makedirs(SAVE_PATH, exist_ok=True) 307 | 308 | with open(join(SAVE_PATH, "cfg.json"), "w") as outfile: 309 | json.dump(vars(self.args), outfile) 310 | 311 | BEST_TEST = 1e4 312 | LR = self.args.lr_ft 313 | NUM_EPOCHS = self.args.epoch_ft 314 | Opt = Adam(self.model.parameters(), lr=LR) 315 | LR_SCHEDULER = CosineAnnealingLR(Opt, T_max=NUM_EPOCHS) 316 | 317 | LOCAL_PATIENCE = 0 318 | start = time.time() 319 | # triplet_loss = scL.TripletLoss() 320 | # contrastive_loss = scL.ContrastiveLoss() 321 | scl_loss = scL.DualBatchSupervisedContrastiveLoss() 322 | 323 | for epoch in range(NUM_EPOCHS): 324 | total_err, total_rec, total_scl = 0, 0, 0 325 | 326 | # for anchor, positive, negative in self.TripletDataloader: 327 | # Opt.zero_grad() 328 | 329 | # counts_a = anchor["norm_counts"].squeeze().to(self.device) 330 | # emb_a = self.model.extra_repr(counts_a) 331 | 332 | # counts_p = positive["norm_counts"].squeeze().to(self.device) 333 | # emb_p = self.model.extra_repr(counts_p) 334 | 335 | # counts_n = negative["norm_counts"].squeeze().to(self.device) 336 | # emb_n = self.model.extra_repr(counts_n) 337 | 338 | # triplet_loss = triplet_loss(emb_a, emb_p, emb_n) 339 | 340 | # triplet_loss.backward() 341 | # Opt.step() 342 | # total_error += triplet_loss.item() 343 | 344 | # for batch_ in self.scDataLoader: 345 | # Opt.zero_grad() 346 | # counts_ = batch_["norm_counts"].squeeze().to(self.device) 347 | # emb_ = self.model.extra_repr(counts_) 348 | # _permute = torch.randperm(emb_.size()[0]) 349 | # emb_permute = emb_[_permute].to(self.device) 350 | # label_ = batch_["cell_type"].to(self.device) 351 | # label_permute = label_[_permute] 352 | 353 | # _loss = contrastive_loss( 354 | # emb_, emb_permute, label_, label_permute) 355 | # _loss.backward() 356 | # Opt.step() 357 | # total_error += _loss.item() 358 | 359 | for b1, l1, b2, l2 in self.ContrastiveDatasetloader: 360 | Opt.zero_grad() 361 | 362 | """ Intra Batch """ 363 | counts_1 = b1.squeeze().to(self.device) 364 | emb_1 = self.model.extra_repr(counts_1) 365 | l1 = torch.concat(l1).to(self.device) 366 | 367 | _permute = torch.randperm(emb_1.size()[0]) 368 | emb_, l1_ = emb_1[_permute], l1[_permute] 369 | _loss = scl_loss(emb_1[:, _n:], emb_[:, _n:], l1, l1_) 370 | 371 | """ Inter Batch """ 372 | counts_2 = b2.squeeze().to(self.device) 373 | emb_2 = 
self.model.extra_repr(counts_2) 374 | if len(emb_1) == len(emb_2): 375 | l2 = torch.concat(l2).to(self.device) 376 | _permute = torch.randperm(emb_1.size()[0]) 377 | emb_2, l2 = emb_2[_permute], l2[_permute] 378 | _loss += scl_loss(emb_1[:, _n:], emb_2[:, _n:], l1, l2) 379 | 380 | _loss.backward() 381 | Opt.step() 382 | train_err += args.w_scl_ft * _loss.item() 383 | 384 | LR_SCHEDULER.step() 385 | train_err = train_err / len(self.scDataLoader) 386 | # train_recon_loss = total_recon / len(self.scDataLoader) 387 | 388 | lr = LR_SCHEDULER.get_last_lr()[0] 389 | 390 | if uh._isnan(train_err): 391 | self._str_formatter("NaN Value Encountered, Quitting") 392 | quit() 393 | 394 | if epoch % 1 == 0: 395 | print( 396 | "LR: %.6f " % lr, 397 | "Epoch: %2d, " % epoch, 398 | "Total Loss: %.2f, " % train_err, 399 | "EST: %.1f Mins" % ((time.time() - start) // 60), 400 | ) 401 | 402 | wandb.log( 403 | { 404 | "LR": lr, 405 | "Total Loss": total_err, 406 | } 407 | ) 408 | 409 | if (epoch + 1) % 1 == 0: 410 | torch.save( 411 | self.model.state_dict(), 412 | join(SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)), 413 | ) 414 | 415 | LOCAL_PATIENCE += 1 416 | if total_err < BEST_TEST: 417 | LOCAL_PATIENCE = 0 418 | BEST_TEST = total_err 419 | torch.save(self.model.state_dict(), join(SAVE_PATH, "ckpt_best.pth")) 420 | 421 | if LOCAL_PATIENCE > PATIENCE: 422 | torch.save( 423 | self.model.state_dict(), 424 | join(SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)), 425 | ) 426 | print("Patience (%d Epochs) Reached, Quitting" % PATIENCE) 427 | quit() 428 | 429 | return 430 | 431 | 432 | MOMERY_SIZE = 65536 # The queue size in MoCo 433 | # MOMERY_SIZE = 2048 # Smaller size does not help (overfit to positive pairs) 434 | MOMENTUM = 0.999 # The momentum in MoCo 435 | # TEMP = 0.07 # The temperature in MoCo 436 | 437 | 438 | class MemoryBank(nn.Module): 439 | def __init__(self, input_dim, device, memory_size=MOMERY_SIZE): 440 | super(MemoryBank, self).__init__() 441 | 442 | self.memory_size = memory_size 443 | self.embed_bank = torch.randn(memory_size, input_dim).to(device) 444 | # self.embed_bank = F.normalize(self.embed_bank, dim=0).to(device) 445 | self.label_bank = torch.randint(0, 20, (memory_size,)).to(device) 446 | 447 | @torch.no_grad() 448 | def fetch(self, embeds_ema, labels): 449 | self.embed_bank = torch.concat([embeds_ema, self.embed_bank]) 450 | self.label_bank = torch.concat([labels, self.label_bank]) 451 | 452 | self.embed_bank = self.embed_bank[: self.memory_size, :] 453 | self.label_bank = self.label_bank[: self.memory_size] 454 | 455 | return self.embed_bank, self.label_bank 456 | 457 | def forward(self): 458 | return 459 | 460 | 461 | class scFineTuner_PSL(scFineTuner): 462 | def __init__(self, args): 463 | super().__init__(args=args) 464 | 465 | def _finetune_(self): 466 | model_cfg = self.args.save_path.split("/")[-1] 467 | dataset_cfg = self.args.save_path.split("/")[-2] 468 | wandb.init( 469 | project="CFT_MB%s%s" % (dataset_cfg, self.args.trainer_ft), 470 | name=model_cfg, 471 | config=self.args, 472 | ) 473 | 474 | self._prep_dataloader_() 475 | self._load_model() 476 | 477 | if self.args.savename_ft: 478 | SAVE_PATH = self.args.savename_ft 479 | else: 480 | SAVE_PATH = join(dh.MODEL_DIR, uh.DATETIME) 481 | self._str_formatter("SAVE_PATH: %s" % SAVE_PATH) 482 | 483 | if exists(SAVE_PATH): 484 | shutil.rmtree(SAVE_PATH) 485 | os.makedirs(SAVE_PATH, exist_ok=True) 486 | 487 | with open(join(SAVE_PATH, "cfg.json"), "w") as outfile: 488 | json.dump(vars(self.args), outfile) 489 | 490 | """ === Projector === """ 
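        # Linear head over the leaked embedding block (dimension args.leakage),
        # producing cell-type logits consumed by _loss_cet (log_softmax + NLL).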
491 | self.model.projector = nn.Linear(self.args.leakage, self.NUM_CELLTYPE, bias=True).to(self.device) 492 | # bias=False, doesn't matter that much 493 | 494 | """ === MoCo with Memory Bank === """ 495 | # encoder_q and encoder_k as the query encoder and key encoder 496 | encoder_q = self.model.encoder 497 | encoder_k = deepcopy(encoder_q).to(self.device) 498 | # print(encoder_q is self.model.encoder) -> True 499 | # print(encoder_k is self.model.encoder) -> False 500 | 501 | encoder_k.load_state_dict(encoder_q.state_dict()) 502 | for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()): 503 | param_k.data.copy_(param_q.data) # initialize 504 | param_k.requires_grad = False # not update by gradient 505 | 506 | BEST_TEST = 1e4 507 | LR = self.args.lr_ft 508 | NUM_EPOCHS = self.args.epoch_ft 509 | Opt = Adam(self.model.parameters(), lr=LR) 510 | LR_SCHEDULER = CosineAnnealingLR(Opt, T_max=NUM_EPOCHS) 511 | if self.args.trainer_ft == "scl": 512 | TRAINER = scL.SCL(MemoryBank(self.model.leak_dim, self.device)) 513 | elif self.args.trainer_ft == "simple": 514 | TRAINER = scL.SimPLE(MemoryBank(self.model.leak_dim, self.device)) 515 | else: 516 | raise NotImplementedError 517 | 518 | LOCAL_PATIENCE = 0 519 | start = time.time() 520 | 521 | for epoch in range(NUM_EPOCHS): 522 | total_err, total_psl, total_rec = 0, 0, 0 523 | cell_acc_ = [] 524 | LossWeights = { 525 | "reconstruction": [1.0], 526 | "psl": [], 527 | } 528 | for item in self.scDataLoader: 529 | Opt.zero_grad() 530 | 531 | counts_ = item["norm_counts"].to(self.device) 532 | labels_ = item["cell_type"].to(self.device) 533 | _permute = torch.randperm(counts_.size()[0]) 534 | counts_ = counts_[_permute] 535 | labels_ = labels_[_permute] 536 | 537 | """ === MoCo with Memory Bank === """ 538 | with torch.no_grad(): 539 | # update the key encoder 540 | for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()): 541 | param_k.data = param_k.data * MOMENTUM + param_q.data * (1.0 - MOMENTUM) 542 | 543 | # compute key features 544 | k = encoder_k(counts_)[:, dh.TOTAL_CONCEPTS :] 545 | # k = F.normalize(k, dim=1) 546 | 547 | # compute query features 548 | q = encoder_q(counts_)[:, dh.TOTAL_CONCEPTS :] 549 | loss_pls = args.w_psl_ft * TRAINER(q, labels_, k) 550 | 551 | rec_ = self.model(counts_) 552 | loss_rec = args.w_rec_ft * nn.MSELoss()(counts_, rec_) 553 | loss_cet = args.w_cet_ft * self._loss_cet(q, labels_) 554 | loss = loss_rec + loss_pls + loss_cet 555 | 556 | # backprop 557 | total_psl += loss_pls.item() 558 | total_rec += loss_rec.item() 559 | total_err += loss.item() 560 | 561 | cell_acc_.append(self.model.acc_celltype_(q, labels_)) 562 | 563 | # update / record gradients 564 | if REC_GRADS and (epoch + 1) % 2 == 0: 565 | LossWeights["psl"].append(uh.Weights_on_GradNorm(self.model, loss_pls, loss_rec)) 566 | else: 567 | loss.backward() 568 | Opt.step() 569 | 570 | # lr update 571 | if not CONSTANT_LR: 572 | LR_SCHEDULER.step() 573 | lr = LR_SCHEDULER.get_last_lr()[0] 574 | 575 | if uh._isnan(total_err): 576 | self._str_formatter("NaN Value Encountered, Quitting") 577 | quit() 578 | 579 | if epoch % 1 == 0: 580 | print( 581 | "LR: %.6f, " % lr, 582 | "Epoch: %2d, " % epoch, 583 | "Total Loss: %.4f, " % (total_err / len(self.scDataLoader)), 584 | "CellType Acc: %.2f, " % (100 * uh.mean_(cell_acc_)), 585 | "EST: %.1f Mins" % ((time.time() - start) // 60), 586 | ) 587 | 588 | wandb.log( 589 | { 590 | "LR": lr, 591 | "Total Loss": total_err / len(self.scDataLoader), 592 | "Total PSL Loss": total_psl / 
len(self.scDataLoader), 593 | "Total REC Loss": total_rec / len(self.scDataLoader), 594 | } 595 | ) 596 | 597 | if (epoch + 1) % 1 == 0: 598 | torch.save( 599 | self.model.state_dict(), 600 | join(SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)), 601 | ) 602 | if REC_GRADS: 603 | raise ValueError 604 | # for key, value in LossWeights.items(): 605 | # print("%20s, %.4f" % (key, uh.mean_(value))) 606 | # print("\n") 607 | 608 | LOCAL_PATIENCE += 1 609 | if total_err < BEST_TEST: 610 | LOCAL_PATIENCE = 0 611 | BEST_TEST = total_err 612 | torch.save(self.model.state_dict(), join(SAVE_PATH, "ckpt_best.pth")) 613 | 614 | if LOCAL_PATIENCE > PATIENCE: 615 | torch.save( 616 | self.model.state_dict(), 617 | join(SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)), 618 | ) 619 | print("Patience (%d Epochs) Reached, Quitting" % PATIENCE) 620 | quit() 621 | return 622 | 623 | 624 | if __name__ == "__main__": 625 | args = Parser_Finetuner() 626 | 627 | if args.mode == "ce": 628 | ft = scFineTuner(args) 629 | elif args.mode == "scl": 630 | ft = scFineTuner_SCL(args) 631 | elif args.mode == "psl": 632 | ft = scFineTuner_PSL(args) 633 | ft._finetune_() 634 | -------------------------------------------------------------------------------- /src/scGraph.py: -------------------------------------------------------------------------------- 1 | import warnings, numpy as np, scanpy as sc, pandas as pd, Utils_Handler as uh, Data_Handler as dh 2 | warnings.filterwarnings("ignore") 3 | from tqdm.auto import tqdm 4 | 5 | class scGraph: 6 | def __init__(self, adata_path, batch_key, label_key, hvg=False, trim_rate=0.05, thres_batch=100, thres_celltype=10): 7 | 8 | self.trim_rate = trim_rate 9 | self.thres_batch = thres_batch 10 | self.thres_celltype = thres_celltype 11 | self.ignore_celltype = [] # List of cell types to be ignored 12 | self._collect_pca_ = dict() 13 | self.concensus_df_pca = None 14 | 15 | self.adata = sc.read(adata_path, first_column_names=True) 16 | self.batch_key = batch_key 17 | self.label_key = label_key 18 | 19 | # self.adata = sc.read(dh.DATA_EMB_[adata_path], first_column_names=True) 20 | # self.batch_key = dh.META_[adata_path]["batch"] 21 | # self.label_key = dh.META_[adata_path]["celltype"] 22 | 23 | # if hvg: 24 | # # bdata = sc.read(adata_path.replace(".h5ad", "_hvg.h5ad"), first_column_names=True) 25 | # bdata = sc.read(dh.DATA_EMB_[adata_path + "_hvg"], first_column_names=True) 26 | # for _obsm in bdata.obsm: 27 | # self.adata.obsm[_obsm + "_hvg"] = bdata.obsm[_obsm] 28 | 29 | def preprocess(self): 30 | for celltype in self.adata.obs[self.label_key].unique(): 31 | if self.adata.obs[self.label_key].value_counts()[celltype] < self.thres_celltype: 32 | print(f"Skipped cell type {celltype}, due to < {self.thres_celltype} cells") 33 | self.ignore_celltype.append(celltype) 34 | 35 | def process_batches(self): 36 | print("Processing batches, calcualte centroids and pairwise distances") 37 | for BATCH_ in tqdm(self.adata.obs[self.batch_key].unique()): 38 | adata_batch = self.adata[self.adata.obs[self.batch_key] == BATCH_].copy() 39 | 40 | if len(adata_batch) < self.thres_batch: 41 | print(f"Skipped batch {BATCH_}, due to < {self.thres_batch} cells") 42 | continue 43 | 44 | # NOTE: make sure the adata.X is log1p transformed, otherwise do it here 45 | # sc.pp.normalize_per_cell(adata_batch, counts_per_cell_after=1e4) 46 | # sc.pp.log1p(adata_batch) 47 | 48 | sc.pp.highly_variable_genes(adata_batch, n_top_genes=1000) 49 | sc.pp.pca(adata_batch, n_comps=10, use_highly_variable=True) 50 | 51 | # NOTE: make sure 52 | 
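        # Per cell type, take a trimmed-mean centroid in this batch's own
        # 10-component PCA space; trim_rate (default 0.05) is cut from each
        # tail, so a few outlier cells cannot drag the centroid.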
centroids_pca = uh.calculate_trimmed_means( 53 | adata_batch.obsm["X_pca"], 54 | adata_batch.obs[self.label_key], 55 | trim_proportion=self.trim_rate, 56 | ignore_=self.ignore_celltype, 57 | ) 58 | pca_pairdist = uh.compute_classwise_distances(centroids_pca) 59 | self._collect_pca_[BATCH_] = pca_pairdist.div(pca_pairdist.max(axis=0), axis=1) 60 | 61 | def calculate_consensus(self): 62 | df_combined = pd.concat(self._collect_pca_.values(), axis=0, sort=False) 63 | self.concensus_df_pca = df_combined.groupby(df_combined.index).mean() 64 | self.concensus_df_pca = self.concensus_df_pca.loc[self.concensus_df_pca.columns, :] 65 | self.concensus_df_pca = self.concensus_df_pca / self.concensus_df_pca.max(axis=0) 66 | 67 | @staticmethod 68 | def rank_diff(df1, df2): 69 | spearman_corr = {} 70 | for col in df1.columns: 71 | paired_non_nan = pd.concat([df1[col], df2[col]], axis=1).dropna() 72 | spearman_corr[col] = paired_non_nan.iloc[:, 0].corr(paired_non_nan.iloc[:, 1], method="spearman") 73 | return pd.DataFrame.from_dict(spearman_corr, orient="index", columns=["Spearman Correlation"]) 74 | 75 | @staticmethod 76 | def corr_diff(df1, df2): 77 | pearson_corr = {} 78 | for col in df1.columns: 79 | paired_non_nan = pd.concat([df1[col], df2[col]], axis=1).dropna() 80 | pearson_corr[col] = paired_non_nan.iloc[:, 0].corr(paired_non_nan.iloc[:, 1], method="pearson") 81 | return pd.DataFrame.from_dict(pearson_corr, orient="index", columns=["Pearson Correlation"]) 82 | 83 | @staticmethod 84 | def corrw_diff(df1, df2): 85 | pearson_corr = {} 86 | for col in df1.columns: 87 | paired_non_nan = pd.concat([df1[col], df2[col]], axis=1).dropna() 88 | pearson_corr[col] = scGraph.weighted_pearson( 89 | paired_non_nan.iloc[:, 0], paired_non_nan.iloc[:, 1], paired_non_nan.iloc[:, 1]) 90 | return pd.DataFrame.from_dict(pearson_corr, orient="index", columns=["Pearson Correlation"]) 91 | 92 | @staticmethod 93 | def weighted_pearson(x, y, distances): 94 | with np.errstate(divide="ignore", invalid="ignore"): 95 | weights = 1 / distances 96 | weights[distances == 0] = 0 97 | weights /= np.sum(weights) 98 | weighted_mean_x = np.average(x, weights=weights) 99 | weighted_mean_y = np.average(y, weights=weights) 100 | covariance = np.sum(weights * (x - weighted_mean_x) * (y - weighted_mean_y)) 101 | variance_x = np.sum(weights * (x - weighted_mean_x) ** 2) 102 | variance_y = np.sum(weights * (y - weighted_mean_y) ** 2) 103 | weighted_pearson_corr = covariance / np.sqrt(variance_x * variance_y) 104 | return weighted_pearson_corr 105 | 106 | def adata_concensus(self, obsm): 107 | _centroid = uh.calculate_trimmed_means( 108 | np.array(self.adata.obsm[obsm]), 109 | self.adata.obs[self.label_key], 110 | trim_proportion=self.trim_rate, 111 | ignore_=self.ignore_celltype, 112 | ) 113 | _pairdist = uh.compute_classwise_distances(_centroid) 114 | return _pairdist.div(_pairdist.max(axis=0), axis=1) 115 | 116 | def main(self, _obsm_list=None): 117 | self.preprocess() 118 | self.process_batches() 119 | self.calculate_consensus() 120 | 121 | res_df = pd.DataFrame(columns=["Rank-PCA", "Corr-PCA", "Corr-Weighted"]) 122 | if _obsm_list is None: 123 | _obsm_list = sorted(list(self.adata.obsm)) 124 | 125 | # self.concensus_df_pca.to_csv("concensus_df_pca_%s.csv"%self.trim_rate) 126 | # exit() 127 | for _obsm in _obsm_list: 128 | adata_df = self.adata_concensus(_obsm) 129 | _row_df = pd.DataFrame( 130 | { 131 | "Rank-PCA": self.rank_diff(adata_df, self.concensus_df_pca).mean().values, 132 | "Corr-PCA": self.corr_diff(adata_df, 
self.concensus_df_pca).mean().values,
133 |                 "Corr-Weighted": self.corrw_diff(adata_df, self.concensus_df_pca).mean().values,
134 |             },
135 |             index=[_obsm],
136 |         )
137 |         res_df = pd.concat([res_df, _row_df], axis=0, sort=False)
138 |         return res_df
139 | 
140 | def main():
141 |     import argparse
142 |     parser = argparse.ArgumentParser(description="scGraph")
143 |     parser.add_argument("--adata_path", type=str, default="/home/wangh256/Islander/data/breast/emb.h5ad",
144 |                         help="Path to the dataset to be used for analysis")
145 |     parser.add_argument("--batch_key", type=str, default="donor_id",
146 |                         help="Batch key in adata.obs")
147 |     parser.add_argument("--hvg", action="store_true",
148 |                         help="Whether to include the hvg subset")
149 |     parser.add_argument("--label_key", type=str, default="cell_type",
150 |                         help="Label key in adata.obs")
151 |     parser.add_argument("--trim_rate", type=float, default=0.05,
152 |                         help="Trim rate, applied on both sides")
153 |     parser.add_argument("--thres_batch", type=int, default=100,
154 |                         help="Minimum batch size to be considered")
155 |     parser.add_argument("--thres_celltype", type=int, default=10,
156 |                         help="Minimum number of cells in each cell type")
157 |     parser.add_argument("--savename", type=str, default="scGraph",
158 |                         help="File name to save the results")
159 |     args = parser.parse_args()
160 | 
161 |     scgraph = scGraph(
162 |         adata_path=args.adata_path,
163 |         batch_key=args.batch_key,
164 |         label_key=args.label_key,
165 |         hvg=args.hvg,
166 |         trim_rate=args.trim_rate,
167 |         thres_batch=args.thres_batch,
168 |         thres_celltype=args.thres_celltype
169 |     )
170 |     results = scgraph.main()
171 |     results.to_csv(f"{args.savename}.csv")
172 |     print(results.head())
173 | 
174 | if __name__ == "__main__":
175 |     main()
--------------------------------------------------------------------------------
/src/scLoss.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn as nn, torch.nn.functional as F, Utils_Handler as uh
2 | 
3 | MEMORY_SIZE = 65536  # The queue size in MoCo
4 | MOMENTUM = 0.999  # The momentum in MoCo
5 | 
6 | 
7 | class MemoryBank(nn.Module):
8 |     def __init__(self, input_dim, device, memory_size=MEMORY_SIZE):
9 |         super(MemoryBank, self).__init__()
10 | 
11 |         self.memory_size = memory_size
12 |         self.embed_bank = torch.randn(memory_size, input_dim).to(device)
13 |         self.label_bank = torch.randint(0, 20, (memory_size,)).to(device)
14 | 
15 |     @torch.no_grad()
16 |     def fetch(self, embeds_ema, labels):
17 |         self.embed_bank = torch.concat([embeds_ema, self.embed_bank])
18 |         self.label_bank = torch.concat([labels, self.label_bank])
19 | 
20 |         self.embed_bank = self.embed_bank[: self.memory_size, :]
21 |         self.label_bank = self.label_bank[: self.memory_size]
22 | 
23 |         return self.embed_bank, self.label_bank
24 | 
25 |     def forward(self):
26 |         return
27 | 
28 | 
29 | class InfoNCELoss(nn.Module):
30 |     """Info Noise Contrastive Estimation Loss"""
31 | 
32 |     def __init__(self, queries, keys, temperature=0.1):
33 |         super(InfoNCELoss, self).__init__()
34 |         query_dim = queries.shape[1]
35 |         key_dim = keys.shape[1]
36 |         device = keys.device
37 |         if query_dim != key_dim:
38 |             self.proj = nn.Linear(query_dim, key_dim).to(device)
39 |         else:
40 |             self.proj = None
41 |         self.temperature = temperature
42 |         self.queries = queries
43 |         self.keys = keys
44 | 
45 |     def forward(self):
46 |         queries = self.queries
47 |         keys = self.keys
48 |         if self.proj:
49 |             queries = self.proj(queries)
50 |         batch_size = queries.size(0)
51 | 
52 |         logits = 
torch.matmul(queries, keys.t()) / self.temperature 53 | labels = torch.arange(batch_size).to(logits.device) 54 | return nn.CrossEntropyLoss()(logits, labels) 55 | 56 | 57 | class TripletLoss(nn.Module): 58 | def __init__(self, margin=1.0): 59 | super(TripletLoss, self).__init__() 60 | self.margin = margin 61 | 62 | def forward(self, anchor, positive, negative): 63 | distance_positive = (anchor - positive).pow(2).sum(1) # .pow(.5) 64 | distance_negative = (anchor - negative).pow(2).sum(1) # .pow(.5) 65 | losses = F.relu(distance_positive - distance_negative + self.margin) 66 | return losses.mean() 67 | 68 | 69 | class ContrastiveLoss(nn.Module): 70 | def __init__(self, margin=1.0): 71 | super(ContrastiveLoss, self).__init__() 72 | self.margin = margin 73 | 74 | def forward(self, emb1, emb2, label): 75 | l2 = (emb1 - emb2).pow(2).sum(1) 76 | loss = torch.mean((1 - label) * torch.pow(l2, 2) + (label) * torch.pow(torch.clamp(self.margin - l2, min=0.0), 2)) 77 | return loss 78 | 79 | 80 | # class SupervisedContrastiveLoss(nn.Module): 81 | # def __init__(self, temperature=0.1): 82 | # super(SupervisedContrastiveLoss, self).__init__() 83 | # self.temperature = temperature 84 | 85 | # def forward(self, embeds, labels): 86 | # # normalize the embeddings along the embedding dimension 87 | # embeds = F.normalize(embeds, dim=-1) 88 | 89 | # # Compute similarity matrix 90 | # sim_matrix = torch.matmul(embeds, embeds.t()) 91 | 92 | # # Create mask for positive samples 93 | # labels = labels.unsqueeze(1) 94 | # positive_mask = torch.eq(labels, labels.t()).float() 95 | 96 | # # Extract positive and negative similarities 97 | # positive_sim = sim_matrix * positive_mask 98 | # negative_sim = sim_matrix * (1 - positive_mask) 99 | 100 | # # Compute logit (numerator) and sum of exp(logits) (denominator) for the loss function 101 | # numerator = torch.sum(F.exp(positive_sim / self.temperature), dim=-1) 102 | # denominator = torch.sum(F.exp(negative_sim / self.temperature), dim=-1) 103 | 104 | # # Compute the loss 105 | # loss = -torch.log(numerator / (denominator + 1e-7)) 106 | # return torch.mean(loss) 107 | 108 | 109 | class SupervisedContrastiveLoss(nn.Module): 110 | def __init__(self, temperature=0.1): 111 | super(SupervisedContrastiveLoss, self).__init__() 112 | self.temperature = temperature 113 | 114 | def forward(self, emb_1, emb_2, labels1, labels2): 115 | """ 116 | Args: 117 | emb_1, emb_2: Tensors of shape (batch_size, embedding_dim) 118 | labels1, labels2: Tensors of shape (batch_size) 119 | """ 120 | 121 | # Normalize the embeddings 122 | emb_1 = F.normalize(emb_1, p=2, dim=-1) 123 | emb_2 = F.normalize(emb_2, p=2, dim=-1) 124 | 125 | # Concatenate the embeddings and labels 126 | embs = torch.cat([emb_1, emb_2], dim=0) 127 | labels = torch.cat([labels1, labels2], dim=0) 128 | 129 | # Compute the dot products 130 | dot_products = torch.mm(embs, embs.t()) / self.temperature 131 | 132 | # Compute the similarity scores 133 | labels_matrix = labels.unsqueeze(1) == labels.unsqueeze(0) 134 | labels_matrix = labels_matrix.float() 135 | 136 | # Subtract the similarity scores for positive pairs 137 | positive_pairs = dot_products * labels_matrix 138 | 139 | # Subtract the similarity scores for negative pairs 140 | negative_pairs = dot_products * (1 - labels_matrix) 141 | 142 | # Subtract the maximum similarity score for stability 143 | max_positive_pairs = positive_pairs.max(dim=1, keepdim=True)[0] 144 | max_negative_pairs = negative_pairs.max(dim=1, keepdim=True)[0] 145 | 146 | # Compute the logits 147 | 
logits = positive_pairs - max_positive_pairs - max_negative_pairs 148 | 149 | # Compute the log-sum-exp for the denominator 150 | logsumexp = torch.logsumexp(negative_pairs - max_negative_pairs, dim=1, keepdim=True) 151 | 152 | # Compute the loss 153 | loss = -logits + logsumexp 154 | 155 | return loss.mean() 156 | 157 | 158 | class SCL(nn.Module): 159 | """Supervised Contrastive Learning, with Memory Bank""" 160 | 161 | def __init__(self, memory_bank, temp=1, alpha=1e-4): 162 | super().__init__() 163 | self.temp = temp 164 | self.alpha = alpha 165 | self.memory_bank = memory_bank 166 | 167 | def forward(self, embeds, labels, embeds_ema): 168 | embeds_bank, labels_bank = self.memory_bank.fetch(embeds_ema, labels) 169 | 170 | # mask 171 | mask_p = labels.view(-1, 1).eq(labels_bank.view(1, -1)) 172 | mask_n = mask_p.logical_not() 173 | mask_p[:, : embeds.size(0)].fill_diagonal_(False) 174 | 175 | # logit 176 | x_norm = F.normalize(embeds, dim=1) 177 | f_norm = F.normalize(embeds_bank, dim=1) 178 | logits = x_norm.mm(f_norm.t()) / self.temp 179 | 180 | logits_p = torch.masked_select(logits, mask_p) 181 | logits_n = torch.masked_select(logits, mask_n) 182 | 183 | loss_p = F.binary_cross_entropy_with_logits(logits_p, torch.ones_like(logits_p)) 184 | loss_n = F.binary_cross_entropy_with_logits(logits_n, torch.zeros_like(logits_n)) 185 | loss = self.alpha * loss_p + (1.0 - self.alpha) * loss_n 186 | 187 | return loss 188 | 189 | 190 | # class SimPLE(nn.Module): 191 | # def __init__( 192 | # self, 193 | # memory_bank, 194 | # r=1.0, 195 | # m=0.0, 196 | # b_cos=0.2, 197 | # lw=1000.0, 198 | # alpha=1e-4, # have to set this small 199 | # b_logit=-10.0, 200 | # ): 201 | # """Simple Pairwise Similarity Learning with Memory Bank""" 202 | 203 | # super().__init__() 204 | # self.memory_bank = memory_bank 205 | # self.rank = 0 206 | 207 | # # === hyperparam === 208 | # # alpha is the lambda in paper Eqn. 
5 209 | # self.alpha = alpha 210 | # self.b_cos = b_cos 211 | # self.lw = lw 212 | # self.r = r 213 | # self.m = m 214 | # # 215 | # # init bias 216 | # self.b_logit = nn.Parameter(b_logit + 0.0 * torch.Tensor(1)).to( 217 | # self.memory_bank.embed_bank.device 218 | # ) 219 | 220 | # # embeds, labels, embeds_ema, 221 | # def forward(self, x, y, x_ema): 222 | # # update bank 223 | # self.feat_bank, self.label_bank = self.memory_bank.fetch(x_ema, y) 224 | 225 | # # mask 226 | # mask_p = y.view(-1, 1).eq(self.label_bank.view(1, -1)) 227 | # mask_n = mask_p.logical_not() 228 | # pt = self.rank * x.size(0) 229 | # mask_p[:, pt : pt + x.size(0)].fill_diagonal_(False) 230 | 231 | # # logit 232 | # # x_mag = F.softplus(x[:, :1], beta=1) 233 | # # f_mag = F.softplus(self.feat_bank[:, :1], beta=1) 234 | # # x_dir = F.normalize(x[:, 1:], dim=1) 235 | # # f_dir = F.normalize(self.feat_bank[:, 1:], dim=1) 236 | # # logits = x_mag.mm(f_mag.t()) * (x_dir.mm(f_dir.t()) - self.b_cos) 237 | 238 | # x_norm = F.normalize(x, dim=1) 239 | # f_norm = F.normalize(self.feat_bank, dim=1) 240 | # logits = x_norm.mm(f_norm.t()) - self.b_cos 241 | 242 | # logits_p = torch.masked_select(logits, mask_p) 243 | # logits_p = (logits_p - self.m + self.b_logit) / self.r 244 | # logits_n = torch.masked_select(logits, mask_n) 245 | # logits_n = (logits_n + self.m + self.b_logit) * self.r 246 | 247 | # # loss 248 | # loss_p = F.binary_cross_entropy_with_logits( 249 | # logits_p, 250 | # torch.ones_like(logits_p), 251 | # ) 252 | # loss_n = F.binary_cross_entropy_with_logits( 253 | # logits_n, 254 | # torch.zeros_like(logits_n), 255 | # ) 256 | # loss = self.alpha * loss_p + (1.0 - self.alpha) * loss_n 257 | 258 | # return self.lw * loss 259 | 260 | 261 | if __name__ == "__main__": 262 | pass 263 | -------------------------------------------------------------------------------- /src/scModel.py: -------------------------------------------------------------------------------- 1 | import time, torch, torch.nn as nn, torch.nn.functional as F 2 | import Utils_Handler as uh 3 | 4 | # from scLoss import HSICLoss, DDCLoss, KCCALoss, InfoNCELoss 5 | 6 | 7 | class AutoEncoder(nn.Module): 8 | def __init__( 9 | self, 10 | n_gene: int, 11 | bn_eps=1e-05, 12 | batchnorm=True, 13 | bn_momentum=0.1, 14 | dropout_layer=False, 15 | mlp_size=[128, 128], 16 | reconstruct_loss=nn.MSELoss(), 17 | **kwargs, 18 | ): 19 | super().__init__() 20 | 21 | self.n_gene = n_gene 22 | self.bn_eps = bn_eps 23 | self.mlp_size = mlp_size 24 | self.batchnorm = batchnorm 25 | self.bn_momentum = bn_momentum 26 | self.dropout_layer = dropout_layer 27 | self.reconstruct_loss = reconstruct_loss 28 | 29 | self.encoder = self._encode() 30 | self.decoder = self._decode() 31 | 32 | def _batchnorm1d(self, size): 33 | # https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html 34 | # Input shape here is (N, C), (# Cell, # Gene) 35 | # default batch1d: BatchNorm1d(64, eps=1e-05, momentum=0.1) 36 | # scvi batch1d: BatchNorm1d(64, eps=1e-03, momentum=0.01) 37 | return nn.BatchNorm1d(size, momentum=self.bn_momentum, eps=self.bn_eps) 38 | 39 | def _encode(self): 40 | # on the order of functional layers 41 | # ref: https://stackoverflow.com/questions/39691902 42 | # Input -> FC -> BN -> ReLu (-> Dropout) -> FC -> Reconstruction 43 | 44 | enc = [] 45 | enc.append(nn.Linear(self.n_gene, self.mlp_size[0])) 46 | if self.batchnorm: 47 | enc.append(self._batchnorm1d(self.mlp_size[0])) 48 | enc.append(nn.ReLU()) 49 | if self.dropout_layer: 50 | enc.append(nn.Dropout(p=0.1)) 51 | 
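        # mid_idx counts the encoder's hidden Linear layers beyond the first one;
        # for an odd-length mlp_size (e.g. [64, bottleneck, 64], as AE_Concept
        # builds it) the encoder stops exactly at the bottleneck layer, and
        # _decode mirrors the remaining half of mlp_size.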
52 | mid_idx = len(self.mlp_size) // 2 - (len(self.mlp_size) + 1) % 2 53 | for i in range(mid_idx): 54 | enc.append(nn.Linear(self.mlp_size[i], self.mlp_size[i + 1])) 55 | if self.batchnorm: 56 | enc.append(self._batchnorm1d(self.mlp_size[i + 1])) 57 | enc.append(nn.ReLU()) 58 | if self.dropout_layer: 59 | enc.append(nn.Dropout(p=0.1)) 60 | 61 | # NOTE: make sure the last layer is linear 62 | enc.pop() 63 | if self.batchnorm: 64 | enc.pop() 65 | if self.dropout_layer: 66 | enc.pop() 67 | return nn.Sequential(*enc) 68 | 69 | def _decode(self): 70 | dec = [] 71 | mid_ = len(self.mlp_size) // 2 - (len(self.mlp_size) + 1) % 2 72 | 73 | for i in range(len(self.mlp_size) // 2): 74 | dec.append(nn.Linear(self.mlp_size[mid_ + i], self.mlp_size[mid_ + i + 1])) 75 | if self.batchnorm: 76 | dec.append(self._batchnorm1d(self.mlp_size[mid_ + i + 1])) 77 | dec.append(nn.ReLU()) 78 | dec.append(nn.Linear(self.mlp_size[-1], self.n_gene)) 79 | dec.append(nn.ReLU()) # norm counts are always non-negative 80 | return nn.Sequential(*dec) 81 | 82 | def forward(self, counts_: torch.Tensor, **kwargs) -> torch.Tensor: 83 | emb_ = self.encoder(counts_) # (# Cell, # Embedding) 84 | out_ = self.decoder(emb_) 85 | return out_ 86 | 87 | def extra_repr(self, x): 88 | return self.encoder(x) 89 | 90 | def recon_repr(self, x): 91 | return self.decoder(x) 92 | 93 | def loss_(self, rec_pred: torch.Tensor, rec_target: torch.Tensor, **kwargs): 94 | return self.reconstruct_loss(rec_pred, rec_target) 95 | 96 | def norm_inf(self, rec_): 97 | normed_counts = torch.exp(rec_) - 1 98 | size_factors = 1e4 / normed_counts.sum(axis=1) 99 | rec_new = torch.log(normed_counts * size_factors[:, None] + 1) 100 | return rec_new 101 | 102 | 103 | class AE_Concept(AutoEncoder): 104 | def __init__( 105 | self, 106 | n_gene: int, 107 | n_concept=0, 108 | leak_dim=16, 109 | cell2cat=None, 110 | bn_eps=1e-05, 111 | bn_momentum=0.1, 112 | batchnorm=True, 113 | mlp_size=[64, 64], 114 | dropout_layer=False, 115 | with_projector=True, 116 | **kwargs, 117 | ) -> None: 118 | super().__init__( 119 | n_gene=n_gene, 120 | bn_eps=bn_eps, 121 | mlp_size=mlp_size, 122 | batchnorm=batchnorm, 123 | bn_momentum=bn_momentum, 124 | dropout_layer=dropout_layer, 125 | ) 126 | self.leak_dim = leak_dim 127 | self.cell2cat = cell2cat 128 | self.n_concept = n_concept 129 | 130 | if len(self.mlp_size) % 2 == 0: 131 | self.mlp_size.insert(len(self.mlp_size) // 2, n_concept + leak_dim) 132 | 133 | print("=" * 28, " Model architecture ", "=" * 27) 134 | print("Leak dim:", self.leak_dim) 135 | print("Dropout:", self.dropout_layer) 136 | print("MLP size:", self.mlp_size) 137 | print("Batchnorm:", batchnorm, ", momentum:", bn_momentum, ", eps:", bn_eps) 138 | print("=" * 77) 139 | 140 | self.encoder = self._encode() 141 | self.decoder = self._decode() 142 | if with_projector: 143 | self.projector = self._proj() 144 | 145 | def _proj(self): 146 | NUM_CELLTYPE = self.cell2cat.__len__() 147 | return nn.Linear(self.leak_dim, NUM_CELLTYPE) 148 | # return nn.Linear(self.leak_dim, NUM_CELLTYPE, bias=False) 149 | 150 | def _loss_rec(self, pred: torch.Tensor, target: torch.Tensor): 151 | """Reconstruction Loss""" 152 | return self.reconstruct_loss(pred, target) 153 | 154 | def _loss_celltype(self, emb_b: torch.Tensor, cell_type: torch.Tensor): 155 | """Cell Type Cross Entropy Loss""" 156 | emb_b = self.projector(emb_b) 157 | emb_b = emb_b.log_softmax(dim=-1) 158 | return F.nll_loss(emb_b, cell_type.to(emb_b.device)) 159 | 160 | def forward(self, counts_: torch.Tensor, **kwargs) -> 
torch.Tensor: 161 | emb_ = self.encoder(counts_) 162 | out_ = self.decoder(emb_) 163 | return out_ 164 | 165 | def acc_celltype_(self, emb_b: torch.Tensor, cell_type: torch.Tensor): 166 | self.projector.eval() 167 | emb_b = self.projector(emb_b) 168 | emb_b = emb_b.softmax(dim=-1).max(-1)[1] 169 | return emb_b.eq(cell_type.to(emb_b.device)).mean(dtype=float).cpu().item() 170 | 171 | 172 | Model_ZOO = { 173 | "base-ae": AutoEncoder, 174 | "ae-concept": AE_Concept, 175 | } 176 | 177 | if __name__ == "__main__": 178 | # 179 | import argparse 180 | from scDataset import scDataset, collate_fn 181 | from torch.utils.data import DataLoader 182 | 183 | # 184 | LR = 0.001 185 | N_EPOCHS = 3 186 | MINCELLS = 256 187 | DEVICE = torch.device("cuda:3") 188 | 189 | for DATASET in [ 190 | "lung", 191 | "lung_fetal_donor", 192 | "lung_fetal_organoid", 193 | "brain", 194 | "breast", 195 | "heart", 196 | "eye", 197 | "gut_fetal", 198 | "skin", 199 | "COVID", 200 | "pancreas", 201 | ]: 202 | # scData_Train = scDataset(dataset=DATASET, training=True, rm_cache=True) 203 | scData_Train = scDataset(dataset=DATASET, training=True) 204 | scData_Test = scDataset(dataset=DATASET, verbose=False, training=False) 205 | scTrainLoader = DataLoader( 206 | scData_Train, 207 | batch_size=1, 208 | shuffle=True, 209 | num_workers=8, 210 | collate_fn=collate_fn, 211 | ) 212 | scTestLoader = DataLoader( 213 | scData_Test, 214 | batch_size=1, 215 | shuffle=True, 216 | num_workers=8, 217 | collate_fn=collate_fn, 218 | ) 219 | n_genes = scData_Train.n_vars 220 | 221 | def ArgParser(): 222 | parser = argparse.ArgumentParser() 223 | parser.add_argument("--leak_dim", type=int, default=16) 224 | parser.add_argument("--mlp_size", type=int, nargs="+", default=[128, 128]) # default=[128, 128, 128, 128] 225 | return parser.parse_args() 226 | 227 | args = ArgParser() 228 | MODEL = AE_Concept( 229 | n_gene=n_genes, 230 | dropout_layer=True, 231 | mlp_size=args.mlp_size, 232 | leak_dim=args.leak_dim, 233 | cell2cat=scData_Train.CELL2CAT, 234 | ).to(DEVICE) 235 | 236 | Opt = torch.optim.Adam(MODEL.parameters(), lr=LR) 237 | start = time.time() 238 | for epoch in range(N_EPOCHS): 239 | # LossWeights = { 240 | # "reconstruction": [1.0], 241 | # "celltype": [], 242 | # "concept": [], 243 | # } 244 | # print("Epoch: %2d" % epoch) 245 | for id_, batch_ in enumerate(scTrainLoader): 246 | # 247 | Opt.zero_grad() 248 | counts = batch_["counts"].squeeze().to(DEVICE) 249 | # 250 | rec_ = MODEL(counts) 251 | emb_ = MODEL.extra_repr(counts) 252 | loss_rec = MODEL._loss_rec(counts, rec_) 253 | loss_celltype = MODEL._loss_celltype(emb_, batch_["cell_type"]) 254 | 255 | # LossWeights["celltype"].append( 256 | # uh.Weights_on_GradNorm(AE_CBLEAKAGE, loss_rec, loss_celltype) 257 | # ) 258 | loss = loss_rec + 1e-2 * loss_celltype 259 | loss.backward() 260 | Opt.step() 261 | # for key, value in LossWeights.items(): 262 | # LossWeights[key] = uh.mean_(value) 263 | # print("%17s, %.6f" % (key, LossWeights[key])) 264 | with torch.no_grad(): 265 | acc_ = [] 266 | for _, batch_ in enumerate(scTestLoader): 267 | counts = batch_["counts"].to(DEVICE) 268 | acc_.append( 269 | MODEL.acc_celltype_( 270 | emb_b=MODEL.extra_repr(counts), 271 | cell_type=batch_["cell_type"], 272 | ) 273 | ) 274 | print("Epoch: %2d, Test Acc: %.2f, EST: %.1f Mins" % (epoch, 100 * uh.mean_(acc_), (time.time() - start) // 60)) 275 | -------------------------------------------------------------------------------- /src/scTrain.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import os, time, json, torch, wandb, shutil, Data_Handler as dh, numpy as np, torch.nn.functional as F 3 | from scDataset import scDataset, ContrastiveDatasetDUAL, TripletDatasetDUAL, collate_fn 4 | from scLoss import TripletLoss, SupervisedContrastiveLoss 5 | from torch.optim import SGD, RMSprop, Adam, ASGD, AdamW 6 | from torch.optim.lr_scheduler import CosineAnnealingLR 7 | from Utils_Handler import _isnan, mean_, DATETIME 8 | from torch.utils.data import DataLoader 9 | from ArgParser import Parser_Trainer 10 | from os.path import join, exists 11 | from scModel import Model_ZOO 12 | 13 | PATIENCE = 10 14 | _Optimizer = { 15 | "sgd": SGD, 16 | "asgd": ASGD, 17 | "adam": Adam, 18 | "adamw": AdamW, 19 | "rmsprop": RMSprop, 20 | } 21 | 22 | 23 | class scTrainer: 24 | def __init__(self, args, mixup=False, **kwargs) -> None: 25 | self.args = args 26 | self.mixup = mixup 27 | self.dataset = args.dataset 28 | self.device = torch.device("cuda:%s" % args.gpu) 29 | 30 | if args.savename: 31 | self.SAVE_PATH = args.savename 32 | else: 33 | self.SAVE_PATH = rf"{dh.MODEL_DIR}/{self.dataset}_{DATETIME}" 34 | if exists(self.SAVE_PATH): 35 | shutil.rmtree(self.SAVE_PATH) 36 | os.makedirs(self.SAVE_PATH, exist_ok=True) 37 | 38 | with open(join(self.SAVE_PATH, "cfg.json"), "w") as outfile: 39 | json.dump(vars(args), outfile) 40 | 41 | scData_Train, self.scData_TrainLoader = self._scDataloader(train_and_test=self.args.train_and_test) 42 | scData_Test, self.scData_TestLoader = self._scDataloader(training=False) 43 | self.n_Train, self.n_Test = len(scData_Train), len(scData_Test) 44 | self.cell2cat = scData_Train.CELL2CAT 45 | self.n_vars = scData_Train.n_vars 46 | print("# Genes: %d" % self.n_vars) 47 | print("# Cells: %d (Training), %d (Testing)\n" % (self.n_Train * 256, self.n_Test * 256)) 48 | 49 | self.MODEL = self._scModel() 50 | self._scTrain() 51 | 52 | def _scModel(self): 53 | if self.args.batch1d == "Vanilla": 54 | bn_eps, bn_momentum = 1e-5, 0.1 55 | elif self.args.batch1d == "scVI": 56 | bn_eps, bn_momentum = 1e-3, 0.01 57 | else: 58 | raise ValueError("Unknown batch1d type") 59 | 60 | return Model_ZOO[self.args.type]( 61 | bn_eps=bn_eps, 62 | bn_momentum=bn_momentum, 63 | leak_dim=self.args.leakage, 64 | batchnorm=self.args.batchnorm, 65 | dropout_layer=self.args.dropout, 66 | mlp_size=self.args.mlp_size, 67 | cell2cat=self.cell2cat, 68 | n_gene=self.n_vars, 69 | ).to(self.device) 70 | 71 | def _scDataloader(self, training=True, train_and_test=False): 72 | _verb = True if training else False 73 | _scDataset = scDataset( 74 | dataset=self.dataset, 75 | verbose=_verb, 76 | training=training, 77 | train_and_test=train_and_test, 78 | ) 79 | _scDataLoader = DataLoader(_scDataset, batch_size=1, shuffle=True, num_workers=8, collate_fn=collate_fn) 80 | return _scDataset, _scDataLoader 81 | 82 | def _loss_cet(self, emb_b, cell_type): 83 | emb_ = self.MODEL.projector(emb_b) 84 | logits_ = emb_.log_softmax(dim=-1) 85 | return F.nll_loss(logits_, cell_type.to(self.device)) 86 | 87 | def _loss_mixup(self, emb_b, cell_type, alpha=1.0): 88 | # alpha = 1.0, 0.2, 0.4, 0.6 89 | 90 | lam = np.random.beta(alpha, alpha) if alpha > 0 else 1 91 | 92 | batch_size = emb_b.size()[0] 93 | index = torch.randperm(batch_size) 94 | y_a, y_b = cell_type, cell_type[index] 95 | mixed_x = lam * emb_b + (1 - lam) * emb_b[index, :] 96 | 97 | return lam * self._loss_cet(mixed_x, 
y_a) + (1 - lam) * self._loss_cet(mixed_x, y_b)
98 | 
99 |     def _loss_rec(self, counts_, rec_):
100 |         return F.mse_loss(rec_, counts_)
101 | 
102 |     def _scTrain(self):
103 |         LR = self.args.lr
104 |         NUM_EPOCHS = self.args.epoch
105 | 
106 |         if self.args.optimiser == "adam":
107 |             Opt = _Optimizer["adam"](self.MODEL.parameters(), lr=LR)
108 |         elif self.args.optimiser == "sgd":
109 |             Opt = _Optimizer["sgd"](self.MODEL.parameters(), lr=LR, momentum=0.9)
110 |         else:
111 |             raise ValueError("Only adam and sgd optimisers are supported for now")
112 |         LR_SCHEDULER = CosineAnnealingLR(Opt, T_max=NUM_EPOCHS)
113 | 
114 |         LOCAL_PATIENCE = 0
115 |         start = time.time()
116 |         BEST_TEST_All = 1e4
117 |         # BEST_TEST_Rec = 1e4
118 |         # BEST_TEST_Celltype = 0
119 |         for epoch in range(NUM_EPOCHS):
120 |             """=== Training ==="""
121 |             total_err, total_rec, total_cet = 0, 0, 0
122 |             for _, batch_ in enumerate(self.scData_TrainLoader):
123 |                 Opt.zero_grad()
124 | 
125 |                 counts_ = batch_["counts"].squeeze().to(self.device)
126 |                 rec_ = self.MODEL(counts_)
127 |                 emb_ = self.MODEL.extra_repr(counts_)
128 | 
129 |                 loss_rec = self._loss_rec(counts_, rec_)
130 |                 loss_cet = self._loss_cet(emb_, batch_["cell_type"])
131 |                 if self.mixup:
132 |                     loss_cet += self._loss_mixup(emb_, batch_["cell_type"])
133 | 
134 |                 loss = self.args.w_rec * loss_rec + self.args.w_cet * loss_cet
135 |                 loss.backward()
136 |                 total_err += loss.item()
137 |                 total_rec += loss_rec.item()
138 |                 total_cet += loss_cet.item()
139 |                 Opt.step()
140 |             LR_SCHEDULER.step()
141 | 
142 |             train_err = total_err / self.n_Train
143 |             train_rec = total_rec / self.n_Train
144 | 
145 |             """ === Testing === """
146 |             total_err, total_rec, total_acc = 0, 0, []
147 |             for _, batch_ in enumerate(self.scData_TestLoader):
148 |                 with torch.no_grad():
149 |                     counts_ = batch_["counts"].squeeze().to(self.device)
150 |                     rec_ = self.MODEL(counts_)
151 |                     emb_ = self.MODEL.extra_repr(counts_)
152 | 
153 |                     loss_rec = self._loss_rec(counts_, rec_)
154 |                     loss_cet = self._loss_cet(emb_, batch_["cell_type"])
155 |                     loss = self.args.w_rec * loss_rec + self.args.w_cet * loss_cet
156 | 
157 |                     total_err += loss.item()
158 |                     total_rec += loss_rec.item()
159 |                     total_acc.append(self.MODEL.acc_celltype_(emb_, batch_["cell_type"]))
160 | 
161 |             test_err = total_err / self.n_Test
162 |             test_rec = total_rec / self.n_Test
163 | 
164 |             """ === Logging === """
165 |             lr = LR_SCHEDULER.get_last_lr()[0]
166 |             if epoch % 1 == 0:
167 |                 print(
168 |                     "LR: %.6f " % lr,
169 |                     "Epoch: %2d, " % epoch,
170 |                     "Test Loss: %.2f, " % test_err,
171 |                     "Train Loss: %.2f, " % train_err,
172 |                     "Test Acc: %.2f, " % (100 * mean_(total_acc)),
173 |                     "Time: %.2f Mins" % ((time.time() - start) // 60),
174 |                 )
175 | 
176 |             wandb.log(
177 |                 {
178 |                     "LR": lr,
179 |                     "Test Loss": test_err,
180 |                     "Train Loss": train_err,
181 |                     "Test Acc": 100 * mean_(total_acc),
182 |                     #
183 |                     "Train Rec Loss": train_rec,
184 |                     "Test Rec Loss": test_rec,
185 |                 }
186 |             )
187 |             LOCAL_PATIENCE += 1
188 | 
189 |             if (epoch + 1) % 10 == 0:
190 |                 torch.save(self.MODEL.state_dict(), join(self.SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)))
191 | 
192 |             if test_err < BEST_TEST_All:
193 |                 LOCAL_PATIENCE = 0
194 |                 BEST_TEST_All = test_err
195 |                 torch.save(self.MODEL.state_dict(), join(self.SAVE_PATH, "ckpt_best.pth"))
196 | 
197 |             if _isnan(train_err):
198 |                 print("NaN Value Encountered, Quitting")
199 |                 quit()
200 | 
201 |             if LOCAL_PATIENCE > PATIENCE:
202 |                 torch.save(self.MODEL.state_dict(), join(self.SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)))
203 |                 print("Patience (%d Epochs) Reached, Quitting" % PATIENCE)
204 |                 quit()
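    # NOTE: a condensed sketch of what _loss_mixup above computes (names are
    # illustrative; alpha defaults to 1.0):
    #     lam = np.random.beta(1.0, 1.0)            # mixing coefficient in [0, 1]
    #     idx = torch.randperm(emb.size(0))         # pair each cell with a random partner
    #     mixed = lam * emb + (1 - lam) * emb[idx]  # convex combination of embeddings
    #     loss = lam * ce(mixed, y) + (1 - lam) * ce(mixed, y[idx])
    # i.e. the cell-type head is trained on interpolated embeddings against
    # correspondingly interpolated label targets.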
205 | 
206 | 
207 | class scTrainer_Semi(scTrainer):
208 |     """Supervised Contrastive (and Triplet) Learning"""
209 | 
210 |     def __init__(self, args, mode="scl"):
211 |         self.mode = mode
212 |         super().__init__(args=args)
213 | 
214 |     def _scDataloader(self, training=True, train_and_test=False):
215 |         _verb = True if training else False
216 |         _scDataset = scDataset(
217 |             dataset=self.dataset,
218 |             verbose=_verb,
219 |             training=training,
220 |             train_and_test=train_and_test,
221 |         )
222 |         if self.mode == "scl":
223 |             _scDataset = ContrastiveDatasetDUAL(_scDataset)
224 | 
225 |         elif self.mode == "triplet":
226 |             _scDataset = TripletDatasetDUAL(_scDataset)
227 | 
228 |         _scDataLoader = DataLoader(_scDataset, batch_size=1, shuffle=True, num_workers=8, collate_fn=collate_fn)
229 |         return _scDataset, _scDataLoader
230 | 
231 |     def _scTrain(self):
232 |         LR = self.args.lr
233 |         NUM_EPOCHS = self.args.epoch
234 | 
235 |         if self.args.optimiser == "adam":
236 |             Opt = _Optimizer["adam"](self.MODEL.parameters(), lr=LR)
237 |         elif self.args.optimiser == "sgd":
238 |             Opt = _Optimizer["sgd"](self.MODEL.parameters(), lr=LR, momentum=0.9)
239 |         else:
240 |             raise ValueError("Only adam and sgd optimisers are supported for now")
241 | 
242 | 
243 |         LR_SCHEDULER = CosineAnnealingLR(Opt, T_max=NUM_EPOCHS)
244 | 
245 |         BEST_TEST = 1e4
246 |         LOCAL_PATIENCE = 0
247 |         start = time.time()
248 | 
249 |         triplet_loss = TripletLoss()
250 |         scl_loss = SupervisedContrastiveLoss()
251 | 
252 |         for epoch in range(NUM_EPOCHS):
253 |             train_err = 0
254 |             if self.mode == "triplet":
255 |                 for item in self.scData_TrainLoader:
256 |                     Opt.zero_grad()
257 | 
258 |                     anchor, positive, negative = item.values()
259 |                     counts_a = anchor.to(self.device)
260 |                     emb_a = self.MODEL.extra_repr(counts_a)
261 | 
262 |                     counts_p = positive.to(self.device)
263 |                     emb_p = self.MODEL.extra_repr(counts_p)
264 | 
265 |                     counts_n = negative.to(self.device)
266 |                     emb_n = self.MODEL.extra_repr(counts_n)
267 | 
268 |                     _loss = triplet_loss(emb_a, emb_p, emb_n)
269 |                     _loss.backward()
270 |                     Opt.step()
271 |                     train_err += _loss.item()
272 | 
273 |             elif self.mode == "scl":
274 |                 for item in self.scData_TrainLoader:
275 |                     b1, l1, b2, l2 = item.values()
276 |                     Opt.zero_grad()
277 | 
278 |                     """ Intra Batch """
279 |                     counts_1 = b1.squeeze().to(self.device)
280 |                     emb_1 = self.MODEL.extra_repr(counts_1)
281 |                     l1 = l1.to(self.device)
282 | 
283 |                     _permute = torch.randperm(emb_1.size()[0])
284 |                     emb_, l1_ = emb_1[_permute], l1[_permute]
285 |                     _loss = scl_loss(emb_1, emb_, l1, l1_)
286 | 
287 |                     """ Inter Batch """
288 |                     counts_2 = b2.squeeze().to(self.device)
289 |                     emb_2 = self.MODEL.extra_repr(counts_2)
290 |                     if len(emb_1) == len(emb_2):
291 |                         l2 = l2.to(self.device)
292 |                         _permute = torch.randperm(emb_1.size()[0])
293 |                         emb_2, l2 = emb_2[_permute], l2[_permute]
294 |                         _loss += scl_loss(emb_1, emb_2, l1, l2)
295 | 
296 |                     _loss.backward()
297 |                     Opt.step()
298 |                     train_err += _loss.item()
299 | 
300 |             """ === Shared === """
301 |             LR_SCHEDULER.step()
302 |             lr = LR_SCHEDULER.get_last_lr()[0]
303 |             train_err = train_err / len(self.scData_TrainLoader)
304 | 
305 |             if _isnan(train_err):
306 |                 print("NaN Value Encountered, Quitting")
307 |                 quit()
308 | 
309 |             if epoch % 1 == 0:
310 |                 print(
311 |                     "LR: %.6f " % lr,
312 |                     "Epoch: %2d, " % epoch,
313 |                     "Total Loss: %.2f, " % train_err,
314 |                     "EST: %.1f Mins" % ((time.time() - start) // 60),
315 |                 )
316 | 
317 |             if (epoch + 1) % 10 == 0:
318 |                 torch.save(self.MODEL.state_dict(), join(self.SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)))
319 | 
320 |             LOCAL_PATIENCE += 1
321 |             if train_err < BEST_TEST:
322 |                 LOCAL_PATIENCE = 0
323 |                 BEST_TEST = train_err
324 |                 torch.save(self.MODEL.state_dict(), join(self.SAVE_PATH, "ckpt_best.pth"))
325 | 
326 |             if LOCAL_PATIENCE > PATIENCE:
327 |                 torch.save(self.MODEL.state_dict(), join(self.SAVE_PATH, "ckpt_%d.pth" % (epoch + 1)))
328 |                 print("Patience (%d Epochs) Reached, Quitting" % PATIENCE)
329 |                 quit()
330 | 
331 |         return
332 | 
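# NOTE: scTrainer_Semi contrasts each batch twice per step: once against a permuted
# copy of itself (intra-batch) and once against a second batch drawn by
# ContrastiveDatasetDUAL (inter-batch, only when the two batches have equal size), e.g.
#     scl = SupervisedContrastiveLoss()             # temperature defaults to 0.1
#     loss = scl(emb_1, emb_1[perm], l1, l1[perm])  # intra-batch pairs
#     loss += scl(emb_1, emb_2, l1, l2)             # inter-batch pairs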
333 | 
334 | class scTrainer_PSL(scTrainer):
335 |     def __init__(self, args) -> None:
336 |         pass
337 | 
338 | 
339 | if __name__ == "__main__":
340 |     args = Parser_Trainer()
341 |     wandb.init(project=args.project, name=args.runname, config=args)
342 | 
343 |     if args.mode in ["vanilla", "mixup"]:
344 |         scTrainer(args, mixup=args.mode == "mixup")
345 |     elif args.mode in ["triplet", "scl"]:
346 |         scTrainer_Semi(args, mode=args.mode)
347 |     elif args.mode == "psl":
348 |         scTrainer_PSL(args)
349 |     else:
350 |         raise ValueError("Unknown mode")
351 | 
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Genentech/Islander/4fbf0d9336d4a3bb4667fec9aeb3bff551321622/teaser.png
--------------------------------------------------------------------------------
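A minimal end-to-end example of scoring embeddings with scGraph (the `.h5ad` path is illustrative; any AnnData file whose `.obsm` slots hold the embeddings to score and whose `.obs` carries batch and cell-type annotations works, provided `adata.X` is log1p-transformed):

```python
from scGraph import scGraph  # src/scGraph.py

scgraph = scGraph(
    adata_path="data/breast/emb.h5ad",  # embeddings stored in adata.obsm
    batch_key="donor_id",               # batch annotation in adata.obs
    label_key="cell_type",              # cell-type annotation in adata.obs
    trim_rate=0.05,                     # two-sided trim for the centroid means
    thres_batch=100,                    # skip batches with fewer cells
    thres_celltype=10,                  # skip cell types with fewer cells
)
res = scgraph.main()       # one row per embedding in adata.obsm
res.to_csv("scGraph.csv")  # columns: Rank-PCA, Corr-PCA, Corr-Weighted
```

Or, equivalently, through the command-line interface:

```bash
python src/scGraph.py --adata_path data/breast/emb.h5ad --batch_key donor_id --label_key cell_type --savename scGraph
```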