├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── data ├── MDT_data_prepare.py └── data_preprocessing │ ├── cao_organogenesis │ ├── README.md │ ├── environment.yaml │ ├── process2anndata.py │ └── scripts │ │ ├── alevin_fry_data_dl.sh │ │ ├── cao_et_al_get_quant.R │ │ └── cao_et_al_update_anno.R │ └── hindbrain_dev │ ├── README.md │ ├── Snakefile │ ├── config.json │ ├── environment.yaml │ ├── kallisto_readme │ └── README.md │ └── scripts │ ├── loom_metadata.py │ ├── loom_subset.py │ ├── loompy_combine.py │ ├── manifest.json │ ├── metadata.py │ ├── mouse_build.py │ ├── mouse_download.sh │ └── mouse_generate_fragments.py ├── deepvelo ├── __init__.py ├── base │ ├── __init__.py │ ├── base_data_loader.py │ ├── base_model.py │ └── base_trainer.py ├── data_loader │ └── data_loaders.py ├── logger │ ├── __init__.py │ ├── logger.py │ ├── logger_config.json │ └── visualization.py ├── model │ ├── layers.py │ ├── loss.py │ ├── metric.py │ └── model.py ├── parse_config.py ├── pipeline │ ├── __init__.py │ └── eval.py ├── plot │ ├── __init__.py │ ├── plot.py │ └── scatter.py ├── tool │ ├── __init__.py │ ├── driver_gene.py │ ├── kinetic_rates.py │ ├── stats.py │ └── velocity.py ├── train.py ├── trainer │ ├── __init__.py │ └── trainer.py └── utils │ ├── __init__.py │ ├── confidence.py │ ├── map_velocity_expression.py │ ├── optimization.py │ ├── plot.py │ ├── preprocess.py │ ├── scatter.py │ ├── temporal.py │ └── util.py ├── docs ├── Makefile ├── conf.py ├── index.rst └── make.bat ├── examples ├── README.md ├── computation_time_plot.ipynb ├── figure2.ipynb ├── figure3(d-h).ipynb ├── figure3.ipynb ├── figure4_hindbrain.ipynb ├── incorporate_cellrank.ipynb ├── la_manno_hippocampus.ipynb ├── minimal_example.ipynb ├── mouse_gastrulation.py ├── multifacet_check.py ├── organogenesis_chondrocyte.ipynb ├── r_analysis │ ├── README.md │ ├── env.yaml │ ├── helpers │ │ └── scIB_knit_table.R │ └── scripts │ │ ├── 00_extra_package_installs.R │ │ ├── 01_deepvelo_marker_analysis.R │ │ ├── 02_deepvelo_scvelo_pathway_analysis.R │ │ ├── 03_deepvelo_scvelo_activepathways_results_analysis.R │ │ ├── 04_deepvelo_tf_driver_analysis.R │ │ ├── 05_supplementary_table_formatting.R │ │ └── 06_full_scores_dotplot.R ├── robustness.ipynb ├── sweep_robustness.py └── sweep_robustness.yaml ├── poetry.lock ├── pyproject.toml └── tests ├── scvelo_test.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # input data, saved log, checkpoints 104 | data/ 105 | cache/ 106 | input/ 107 | saved/ 108 | out/ 109 | outs/ 110 | datasets/ 111 | figures/ 112 | wandb/ 113 | result_files 114 | temporary_files 115 | *.png 116 | *.loom 117 | *.txt 118 | *.pdf 119 | *.Rhistory 120 | 121 | # Downloaded deepvelo data 122 | deepvelo_data/ 123 | deepvelo_data.tar.gz 124 | 125 | # editor, os cache directory 126 | .vscode/ 127 | .idea/ 128 | __MACOSX/ 129 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details, https://readthedocs.org/projects/deepvelo/ 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the "docs/" directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | # Optionally build your docs in additional formats such as PDF and ePub 22 | # formats: 23 | # - pdf 24 | # - epub 25 | 26 | # Optional but recommended, declare the Python requirements required 27 | # to build your documentation 28 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 29 | # python: 30 | # install: 31 | # - requirements: docs/requirements.txt 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Victor Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepVelo - A Deep Learning-based velocity estimation tool with cell-specific kinetic rates 2 | 3 | [![PyPI version](https://badge.fury.io/py/deepvelo.svg)](https://badge.fury.io/py/deepvelo) 4 | [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) 5 | 6 | This is the official implementation of the [DeepVelo](https://genomebiology.biomedcentral.com/articles/10.1186/s13059-023-03148-9) method. 7 | DeepVelo employs cell-specific kinetic rates and provides more accurate RNA velocity estimates for complex differentiation and lineage decision events in heterogeneous scRNA-seq data. Please check out the paper for more details. 8 | 9 | ![alt text](https://user-images.githubusercontent.com/11674033/171066682-a899377f-fae1-452a-8b67-8bc8c244b641.png) 10 | 11 | ## Installation 12 | 13 | Please note that using the pip version is currently recommended. The currently supported Python versions are `3.7`, `3.8`, and `3.9`. 14 | 15 | ```bash 16 | pip install deepvelo 17 | ``` 18 | 19 | ### Using GPU 20 | 21 | The CPU version of `dgl` is installed by default. For GPU acceleration, please install a [dgl GPU](https://www.dgl.ai/pages/start.html) build compatible with your CUDA environment. 22 | 23 | ```bash 24 | pip uninstall dgl # remove the CPU version 25 | # replace cu101 with your desired CUDA version and run the following 26 | pip install "dgl-cu101>=0.4.3,<0.7" 27 | 28 | ``` 29 | 30 | ### Install the development version 31 | 32 | We use Poetry to manage dependencies. 33 | 34 | ```bash 35 | poetry install 36 | ``` 37 | 38 | This will install the exact versions pinned in the provided [poetry.lock](poetry.lock) file. If you want to install the latest versions of all dependencies instead, use the following command. 39 | 40 | ```bash 41 | poetry update 42 | ``` 43 | 44 | ## Minimal example 45 | 46 | We provide a number of notebooks in the [examples](examples) folder to help you get started. This folder contains analyses from the paper, as well as a minimal [Python notebook](examples/minimal_example.ipynb). 47 | 48 | DeepVelo fully integrates with [scanpy](https://scanpy.readthedocs.io/en/latest/) and [scVelo](https://scvelo.readthedocs.io/). The basic usage is as follows: 49 | 50 | ```python 51 | import anndata as ann 52 | import deepvelo as dv 53 | import scvelo as scv 54 | 55 | adata = ann.read_h5ad("..") # load your data in AnnData here - modify the path accordingly 56 | 57 | # preprocess the data 58 | scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000) 59 | scv.pp.moments(adata, n_neighbors=30, n_pcs=30) 60 | 61 | # run DeepVelo using the default configs 62 | trainer = dv.train(adata, dv.Constants.default_configs) 63 | # this will train the model and predict the velocity vectors. The result is stored in adata.layers['velocity'].
You can use trainer.model to access the model. 64 | 65 | # Plot the velocity results 66 | scv.tl.velocity_graph(adata, n_jobs=4) 67 | scv.pl.velocity_embedding_stream( 68 | adata, 69 | basis="umap", 70 | color="clusters", 71 | legend_fontsize=9, 72 | dpi=150 73 | ) 74 | ``` 75 | 76 | ### Fitting a large number of cells 77 | 78 | If you cannot fit a large dataset into (GPU) memory using the default configs, try setting a smaller `inner_batch_size` in the configs; this reduces memory usage while maintaining the same performance. 79 | 80 | Currently, training operates on the whole cell graph; we plan to release a more flexible version based on graph node sampling in the near future. 81 | -------------------------------------------------------------------------------- /data/data_preprocessing/cao_organogenesis/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing of the Cao et al. organogenesis data 2 | 3 | 4 | ### This directory contains scripts for preprocessing the mouse organogenesis data from Cao et al. for the subsequent velocity analysis with scVelo and DeepVelo 5 | 6 | ## Instructions: 7 | 8 | ### Install the R processing conda environment 9 | ``` 10 | conda env create -f environment.yaml 11 | ``` 12 | 13 | ### Downloading the Cao et al. organogenesis data 14 | Download the SRR files corresponding to the fastqs from the SRA archive - https://www.ncbi.nlm.nih.gov/Traces/study/?acc=PRJNA490754&o=acc_s%3Aa. 15 | Create a `data` directory and save the SRR files in that directory. 16 | 17 | Convert the SRA files to fastq using `scripts/sra_to_fastq.sh`. 18 | 19 | ### Downloading the Alevin-fry quantification data 20 | The data and scripts used to reprocess the Cao et al. data are taken from the Alevin-fry quantification pipeline (https://combine-lab.github.io/alevin-fry-tutorials/2021/sci-rna-seq3/), with two exceptions: 21 | 22 | 1) All of the samples are processed. 23 | 2) The spliced and unspliced reads are kept separate for RNA velocity analysis. 24 | 25 | To download the necessary data: 26 | ``` 27 | sh scripts/alevin_fry_data_dl.sh 28 | ``` 29 | 30 | ### Running the Alevin-fry quantification pipeline 31 | To quantify the fastqs and obtain `.rds` files with the spliced and unspliced counts, with matched barcodes, run: 32 | ``` 33 | Rscript cao_et_al_get_quant.R 34 | Rscript cao_et_al_update_anno.R 35 | ``` 36 | 37 | ### Converting the counts to h5ad and running the RNA velocity analysis 38 | 39 | Set the `data_folder` path in `process2anndata.py` to the output data from the previous step, then run: 40 | ``` 41 | python process2anndata.py 42 | ``` 43 | The h5ad files will be stored in the `output_folder`.
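As an optional sanity check before starting the velocity analysis, you can inspect the exported AnnData. The snippet below is a minimal sketch assuming the default output name `cao_organogenesis.h5ad` written by `process2anndata.py`; the path is a placeholder, so adjust it to your `output_folder`.

```python
# Minimal sanity check of the exported AnnData (path is a placeholder)
import scanpy as sc

adata = sc.read_h5ad("output_folder/cao_organogenesis.h5ad")
print(adata)  # overview of cells, genes, layers, and saved embeddings

# scVelo/DeepVelo expect spliced and unspliced count layers
assert "spliced" in adata.layers and "unspliced" in adata.layers

# rough check that unspliced counts are present and non-trivial
print(
    "unspliced/spliced total counts:",
    adata.layers["unspliced"].sum() / adata.layers["spliced"].sum(),
)
```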
44 | -------------------------------------------------------------------------------- /data/data_preprocessing/cao_organogenesis/process2anndata.py: -------------------------------------------------------------------------------- 1 | # This file process the rds and csv files to generate the anndata object for the 2 | # cao et al organogenesis dataset 3 | # %% 4 | from pathlib import Path 5 | 6 | import scanpy as sc 7 | import pandas as pd 8 | import numpy as np 9 | import pyreadr 10 | import rpy2.robjects as robjects 11 | import anndata2ri 12 | from scipy import sparse 13 | 14 | data_folder = Path("/cluster/projects/bwanggroup/cao_et_al/alevin/runs") 15 | cell_annotation_file = data_folder / "cell_annotations_updated_aligned.rds" 16 | gene_annotation_file = data_folder / "gene_col_idx.rds" 17 | matrix_file = data_folder / "quants.rds" 18 | quants_col_names_file = data_folder / "quants_col_names.rds" 19 | quants_row_names_file = data_folder / "quants_row_names.rds" 20 | 21 | output_folder = Path("/cluster/home/haotianc/process_cao_organogenesis") 22 | 23 | #%% Read in the rds file 24 | cell_annotation = pyreadr.read_r(cell_annotation_file)[None] 25 | gene_annotation = pyreadr.read_r(gene_annotation_file)[None] 26 | 27 | gene_names = pyreadr.read_r(quants_col_names_file)[None] 28 | gene_names = gene_names.iloc[:, 0].to_list() 29 | assert len(gene_names) == len(gene_annotation) * 3 30 | gene_names = gene_names[: len(gene_annotation)] 31 | cell_barcodes = pyreadr.read_r(quants_row_names_file)[None] 32 | cell_barcodes = cell_barcodes.iloc[:, 0].to_list() 33 | assert len(cell_barcodes) == len(cell_annotation) 34 | assert all( 35 | [a == b for a, b in zip(cell_barcodes, cell_annotation["sample_name"])] 36 | ), "cell barcodes need to be aligned" 37 | 38 | # # align the cell annotation and cell barcodes 39 | # barcode_dict = cell_annotation.set_index("cb")["id"].to_dict() 40 | # assert all([x in barcode_dict for x in cell_barcodes]) 41 | # # reorder the cell annotation 42 | # cell_annotation.set_index("cb", inplace=True) 43 | 44 | readRDS = robjects.r["readRDS"] 45 | df = readRDS(str(matrix_file)) # dgCMatrix mapped to RS4 46 | # summary = robjects.r["summary"] 47 | # as_dataframe = robjects.r["as.data.frame"] 48 | # as_matrix = robjects.r["as.matrix"] 49 | # matrix_ = as_matrix(df) 50 | anndata2ri.activate() 51 | matrix = anndata2ri.scipy2ri.rpy2py(df) # anndata2ri.ri2py 52 | matrix = matrix.tocsr() 53 | assert isinstance(matrix, sparse.csr_matrix) 54 | 55 | #%% Wrap into an anndata object 56 | # meta info 57 | num_genes = len(gene_annotation) 58 | num_cells = len(cell_annotation) 59 | 60 | assert (gene_annotation["spliced"].to_numpy() - np.arange(num_genes) - 1).any() == 0 61 | spliced = matrix[:, :num_genes] 62 | unspliced = matrix[:, num_genes : 2 * num_genes] 63 | ambiguous = matrix[:, 2 * num_genes : 3 * num_genes] 64 | 65 | #%% Create an anndata object 66 | adata = sc.AnnData( 67 | X=spliced, 68 | obs=cell_annotation.set_index("sample_name"), 69 | var=pd.DataFrame(index=gene_names), 70 | layers={"spliced": spliced, "unspliced": unspliced, "ambiguous": ambiguous}, 71 | ) 72 | 73 | # add tsne and other coordinates 74 | adata.obsm["X_tsne"] = cell_annotation[["tsne_1", "tsne_2"]].to_numpy() 75 | adata.obs.drop(columns=["tsne_1", "tsne_2"], inplace=True) 76 | adata.obsm["sub_tsne"] = cell_annotation[["sub_tsne_1", "sub_tsne_2"]].to_numpy() 77 | adata.obs.drop(columns=["sub_tsne_1", "sub_tsne_2"], inplace=True) 78 | adata.obsm["Main_cluster_tsne"] = cell_annotation[ 79 | ["Main_cluster_tsne_1", 
"Main_cluster_tsne_2"] 80 | ].to_numpy() 81 | adata.obs.drop(columns=["Main_cluster_tsne_1", "Main_cluster_tsne_2"], inplace=True) 82 | adata.obsm["Sub_cluster_tsne"] = cell_annotation[ 83 | ["Sub_cluster_tsne_1", "Sub_cluster_tsne_2"] 84 | ].to_numpy() 85 | adata.obs.drop(columns=["Sub_cluster_tsne_1", "Sub_cluster_tsne_2"], inplace=True) 86 | adata.obsm["Main_trajectory_umap"] = cell_annotation[ 87 | ["Main_trajectory_umap_1", "Main_trajectory_umap_2", "Main_trajectory_umap_3"] 88 | ].to_numpy() 89 | adata.obs.drop( 90 | columns=[ 91 | "Main_trajectory_umap_1", 92 | "Main_trajectory_umap_2", 93 | "Main_trajectory_umap_3", 94 | ], 95 | inplace=True, 96 | ) 97 | adata.obsm["Main_trajectory_refined_umap"] = cell_annotation[ 98 | [ 99 | "Main_trajectory_refined_umap_1", 100 | "Main_trajectory_refined_umap_2", 101 | "Main_trajectory_refined_umap_3", 102 | ] 103 | ].to_numpy() 104 | adata.obs.drop( 105 | columns=[ 106 | "Main_trajectory_refined_umap_1", 107 | "Main_trajectory_refined_umap_2", 108 | "Main_trajectory_refined_umap_3", 109 | ], 110 | inplace=True, 111 | ) 112 | adata.obsm["Sub_trajectory_umap"] = cell_annotation[ 113 | ["Sub_trajectory_umap_1", "Sub_trajectory_umap_2"] 114 | ].to_numpy() 115 | adata.obs.drop( 116 | columns=[ 117 | "Sub_trajectory_umap_1", 118 | "Sub_trajectory_umap_2", 119 | ], 120 | inplace=True, 121 | ) 122 | # remove other unused columns 123 | adata.obs.drop( 124 | columns=[ 125 | "id.x", 126 | "sample", 127 | "id.y", 128 | ], 129 | inplace=True, 130 | ) 131 | 132 | # Save 133 | # convert obs columns to string if contains nan 134 | for col in adata.obs.columns: 135 | if adata.obs[col].isnull().any(): 136 | adata.obs[col] = adata.obs[col].astype(str) 137 | # save the anndata object 138 | adata.write(output_folder / "cao_organogenesis.h5ad", compression="gzip") 139 | 140 | # %% subset the data to only include Chondrocyte trajectory 141 | chondrocyte_idx = adata.obs["Sub_trajectory_name"] == "Chondrocyte trajectory" 142 | adata_chondrocyte = adata[chondrocyte_idx, :] 143 | 144 | print(adata_chondrocyte.obs["Sub_trajectory_name"].unique()) 145 | print(adata_chondrocyte.obs["Main_cell_type"].value_counts()) 146 | 147 | # Save 148 | # convert obs columns to string if contains nan 149 | for col in adata_chondrocyte.obs.columns: 150 | if adata_chondrocyte.obs[col].isnull().any(): 151 | adata_chondrocyte.obs[col] = adata_chondrocyte.obs[col].astype(str) 152 | # save the anndata object 153 | adata_chondrocyte.write( 154 | output_folder / "cao_organogenesis_chondrocyte.h5ad", compression="gzip" 155 | ) 156 | 157 | # %% subset the data to only include Mesenchymal trajectory 158 | mesenchymal_idx = ( 159 | adata.obs["Main_trajectory_refined_by_cluster"] == "Mesenchymal trajectory" 160 | ) 161 | adata_mesenchymal = adata[mesenchymal_idx, :] 162 | 163 | print(adata_mesenchymal.obs["Main_trajectory_refined_by_cluster"].unique()) 164 | print(adata_mesenchymal.obs["Main_cell_type"].value_counts()) 165 | 166 | # Save 167 | # convert obs columns to string if contains nan 168 | for col in adata_mesenchymal.obs.columns: 169 | if adata_mesenchymal.obs[col].isnull().any(): 170 | adata_mesenchymal.obs[col] = adata_mesenchymal.obs[col].astype(str) 171 | # save the anndata object 172 | adata_mesenchymal.write( 173 | output_folder / "cao_organogenesis_mesenchymal.h5ad", compression="gzip" 174 | ) 175 | 176 | # filter in the celltypes as in the figure 4 of cao et al. 
paper 177 | celltypes = [ 178 | "Chondrocytes & osteoblasts", 179 | "Connective tissue progenitors", 180 | "Intermediate Mesoderm", 181 | "Early mesenchyme", 182 | "Myocytes", 183 | "Chondroctye progenitors", 184 | "Limb mesenchyme", 185 | ] 186 | adata_mesenchymal_filtered = adata_mesenchymal[ 187 | adata_mesenchymal.obs["Main_cell_type"].isin(celltypes), : 188 | ] 189 | print(adata_mesenchymal_filtered.obs["Main_cell_type"].value_counts()) 190 | 191 | # Save 192 | adata_mesenchymal_filtered.write( 193 | output_folder / "cao_organogenesis_mesenchymal_filtered.h5ad", 194 | compression="gzip", 195 | ) 196 | 197 | # %% draw tsnes and umaps using the saved coordinates 198 | def plots(adata, dir_name=None): 199 | if dir_name is not None: 200 | default_figdir = sc.settings.figdir 201 | sc.settings.figdir = Path(dir_name) 202 | 203 | sc.pl.scatter( 204 | adata, 205 | basis="tsne", 206 | color=["Main_cell_type", "development_stage"], 207 | # color="Main_cell_type", 208 | legend_loc="on data", 209 | legend_fontsize=8, 210 | legend_fontoutline=1, 211 | show=False, 212 | save="_cao_organogenesis_tsne.png", 213 | ) 214 | 215 | # sub_tsne 216 | adata.obsm["X_sub_tsne"] = adata.obsm["sub_tsne"] 217 | sc.pl.scatter( 218 | adata, 219 | basis="sub_tsne", 220 | color=["Main_cell_type", "development_stage"], 221 | legend_fontsize=8, 222 | legend_fontoutline=1, 223 | show=False, 224 | save="_cao_organogenesis.png", 225 | ) 226 | 227 | # Main_cluster_tsne 228 | adata.obsm["X_Main_cluster_tsne"] = adata.obsm["Main_cluster_tsne"] 229 | sc.pl.scatter( 230 | adata, 231 | basis="Main_cluster_tsne", 232 | color=["Main_cell_type", "development_stage"], 233 | legend_fontsize=8, 234 | legend_fontoutline=1, 235 | show=False, 236 | save="_cao_organogenesis.png", 237 | ) 238 | 239 | # Sub_cluster_tsne 240 | adata.obsm["X_Sub_cluster_tsne"] = adata.obsm["Sub_cluster_tsne"] 241 | sc.pl.scatter( 242 | adata, 243 | basis="Sub_cluster_tsne", 244 | color=["Main_cell_type", "development_stage"], 245 | legend_fontsize=8, 246 | legend_fontoutline=1, 247 | show=False, 248 | save="_cao_organogenesis.png", 249 | ) 250 | 251 | # Main_trajectory_umap 252 | adata.obsm["X_Main_trajectory_umap"] = adata.obsm["Main_trajectory_umap"] 253 | sc.pl.scatter( 254 | adata, 255 | basis="Main_trajectory_umap", 256 | color=["Main_cell_type", "development_stage"], 257 | legend_fontsize=8, 258 | legend_fontoutline=1, 259 | show=False, 260 | save="_cao_organogenesis.png", 261 | ) 262 | 263 | # NOTE: this one is good, Main_trajectory_refined_umap 264 | adata.obsm["X_Main_trajectory_refined_umap"] = adata.obsm[ 265 | "Main_trajectory_refined_umap" 266 | ] 267 | sc.pl.scatter( 268 | adata, 269 | basis="Main_trajectory_refined_umap", 270 | color=["Main_cell_type", "development_stage"], 271 | legend_fontsize=8, 272 | legend_fontoutline=1, 273 | show=False, 274 | save="_cao_organogenesis.png", 275 | ) 276 | 277 | # Sub_trajectory_umap 278 | adata.obsm["X_Sub_trajectory_umap"] = adata.obsm["Sub_trajectory_umap"] 279 | sc.pl.scatter( 280 | adata, 281 | basis="Sub_trajectory_umap", 282 | color=["Main_cell_type", "development_stage"], 283 | legend_fontsize=8, 284 | legend_fontoutline=1, 285 | show=False, 286 | save="_cao_organogenesis.png", 287 | ) 288 | 289 | if dir_name is not None: 290 | sc.settings.figdir = default_figdir 291 | 292 | 293 | plots(adata_chondrocyte, dir_name=output_folder / "chondrocyte") 294 | plots(adata_mesenchymal, dir_name=output_folder / "mesenchymal") 295 | plots(adata_mesenchymal_filtered, dir_name=output_folder / 
"mesenchymal_filtered") 296 | -------------------------------------------------------------------------------- /data/data_preprocessing/cao_organogenesis/scripts/alevin_fry_data_dl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p data/cao_organogenesis 4 | cd data/cao_organogenesis 5 | 6 | wget http://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M25/gencode.vM25.annotation.gtf.gz 7 | wget http://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M25/GRCm38.primary_assembly.genome.fa.gz 8 | wget https://zenodo.org/record/5676291/files/cell_annotations.txt 9 | wget https://zenodo.org/record/5676291/files/transcriptome_splici_fl52_t2g_3col.tsv 10 | wget https://zenodo.org/record/5676291/files/cell_barcodes.txt 11 | wget https://zenodo.org/record/5676291/files/gene_annotation_table.txt 12 | wget https://zenodo.org/record/5676291/files/splici_idx_mouse_gencodeM25.tar.gz 13 | -------------------------------------------------------------------------------- /data/data_preprocessing/cao_organogenesis/scripts/cao_et_al_get_quant.R: -------------------------------------------------------------------------------- 1 | library(Matrix) 2 | library(data.table) 3 | 4 | # Change to working directory of alevin runs 5 | setwd("/cluster/projects/bwanggroup/cao_et_al/alevin/runs") 6 | 7 | # Get all of the directories that start with "res" 8 | dirs <- list.files(pattern = "res") 9 | 10 | # Extract the run name from the directory name 11 | srr_names <- lapply(dirs, function(fp) { 12 | strsplit(fp, "_")[[1]][2] 13 | }) 14 | 15 | # Load quant files from each directory 16 | quantFiles <- lapply(dirs, function(fp) { 17 | readMM(file.path(fp, "alevin/quants_mat.mtx")) 18 | }) 19 | names(quantFiles) <- srr_names 20 | 21 | # Load the cell files from each directory 22 | cellFiles <- lapply(dirs, function(fp) { 23 | fread( 24 | file.path(fp, "alevin/quants_mat_rows.txt"), 25 | header = F, 26 | col.names = "cellbarcode" 27 | ) 28 | }) 29 | names(cellFiles) <- srr_names 30 | 31 | # Load one file for the gene names 32 | genes <- fread( 33 | file.path(dirs[1], "alevin/quants_mat_cols.txt"), 34 | header = F, 35 | col.names = "geneId" 36 | ) 37 | 38 | # Change to top level dir (alevin) 39 | setwd("..") 40 | 41 | # Load cell annotations 42 | cell_anno <- fread("cell_annotations.txt") 43 | cell_anno <- cell_anno[run %in% names(quantFiles)] 44 | cellb <- sapply(cell_anno$cb,function(cb){ 45 | if(nchar(cb)==20){ 46 | return(paste0(cb,"A")) 47 | } else { 48 | return(paste0(cb,"AC")) 49 | } 50 | }) 51 | cell_anno[,cb:=cellb] 52 | 53 | # Check if the dimensions of quantFiles and cellFiles match up 54 | for (i in 1:length(quantFiles)){ 55 | if (dim(quantFiles[[i]])[1] != dim(cellFiles[[i]])[1]){ 56 | print(paste0("Mismatch in dimensions for ", names(quantFiles)[i])) 57 | } 58 | } 59 | 60 | # subset the quantFiles to include cells used in the study 61 | quantFiles <- lapply(1:length(quantFiles), function(i){ 62 | quantfile_sub <- quantFiles[[i]] 63 | cellfile_sub <- cellFiles[[i]] 64 | rownames(quantfile_sub) <- cellfile_sub$cellbarcode 65 | colnames(quantfile_sub) <- genes$geneId 66 | cbi <- cell_anno[run == names(quantFiles)[i], cb] 67 | rws <- cellfile_sub$cellbarcode %in% cbi 68 | quantfile_sub <- quantfile_sub[rws,] 69 | return(quantfile_sub) 70 | }) 71 | names(quantFiles) <- srr_names 72 | 73 | # get matching annotations for each quantFile 74 | cell_annotations <- lapply(1:length(quantFiles), function(i){ 75 | c_anno <- cell_anno[run == 
names(quantFiles)[i],] 76 | cb_matches <- match(rownames(quantFiles[[i]]), c_anno$cb) 77 | cb_matches <- cb_matches[!is.na(cb_matches)] 78 | return(c_anno[cb_matches]) 79 | }) 80 | 81 | gene_annotations <- fread("gene_annotation_table.txt") 82 | g_matches <- match(colnames(quantFiles[[1]]),gene_annotations$Geneid) 83 | gene_annotations <- gene_annotations[g_matches] 84 | 85 | # Don't add the spliced, unspliced and ambiguous counts, 86 | # but keep an index of which columns they correspond to 87 | gene_col_idx = data.frame( 88 | "spliced" = 1:55401, 89 | "unspliced" = 55402:110802, 90 | "ambiguous" = 110803:166203 91 | ) 92 | 93 | # For each quantfile add the full correct rowname based on the 94 | # cell barcode and the gene id 95 | quantFiles <- lapply(1:length(quantFiles), function(i){ 96 | quantfile_sub <- quantFiles[[i]] 97 | quantfile_name <- names(quantFiles)[i] 98 | quantfile_rownames <- rownames(quantfile_sub) 99 | rownames(quantfile_sub) <- paste0( 100 | "sci3-me-", quantfile_name, ".", quantfile_rownames 101 | ) 102 | return(quantfile_sub) 103 | }) 104 | names(quantFiles) <- srr_names 105 | 106 | # Similarly for the cell annotations add a column based on this procedure 107 | cell_annotations <- lapply(1:length(quantFiles), function(i) { 108 | quantfile_name <- names(quantFiles)[i] 109 | cell_annotations[[i]]$sample_name <- paste0( 110 | "sci3-me-", quantfile_name, ".", cell_annotations[[i]]$cb 111 | ) 112 | return(cell_annotations[[i]]) 113 | }) 114 | 115 | # Concatenate the quant files and cell annotations 116 | quants <- do.call(rbind, quantFiles) 117 | cell_annotations <- do.call(rbind, cell_annotations) 118 | 119 | # Save quants row and column names 120 | saveRDS(rownames(quants), "runs/quants_row_names.rds") 121 | saveRDS(colnames(quants), "runs/quants_col_names.rds") 122 | 123 | # Save the quants, cell annotations, and gene indices 124 | saveRDS(quants, "runs/quants.rds") 125 | saveRDS(cell_annotations, "runs/cell_annotations.rds") 126 | saveRDS(gene_col_idx, "runs/gene_col_idx.rds") 127 | -------------------------------------------------------------------------------- /data/data_preprocessing/cao_organogenesis/scripts/cao_et_al_update_anno.R: -------------------------------------------------------------------------------- 1 | library(data.table) 2 | 3 | # Change to working directory of alevin runs 4 | setwd("/cluster/projects/bwanggroup/cao_et_al/alevin/runs") 5 | 6 | # Load the cell annotations, with and without the cell-types 7 | cell_annos <- readRDS("cell_annotations.rds") 8 | cell_annos_with_type <- fread( 9 | "cell_annotations_with_labels.csv", 10 | sep = "," 11 | ) 12 | 13 | # Merge the two together (left merge on previously saved cell annotations) 14 | cell_annos_merged_with_type <- merge( 15 | cell_annos, 16 | cell_annos_with_type, 17 | by = "sample", 18 | all.x = TRUE 19 | ) 20 | 21 | # Reload the rows corresponding to the data cell barcodes 22 | quants_rows <- readRDS("quants_row_names.rds") 23 | 24 | # Create a dataframe of the quants rows and merge with the cell annotations 25 | quants_rows_df <- data.frame( 26 | sample_name = quants_rows, 27 | stringsAsFactors = FALSE 28 | ) 29 | quants_rows_df$id <- 1:nrow(quants_rows_df) 30 | 31 | # Merge the quants rows with the cell annotations 32 | quants_rows_cell_annos_merged <- merge( 33 | quants_rows_df, 34 | cell_annos_merged_with_type, 35 | by = "sample_name", 36 | all.x = TRUE 37 | ) 38 | 39 | # Get index for reordering based on sample name 40 | reorder_idx <- match( 41 | quants_rows_df$sample_name, 
quants_rows_cell_annos_merged$sample_name 42 | ) 43 | 44 | # Order the cell annotations by the order of the quants rows 45 | quants_rows_cell_annos_merged <- quants_rows_cell_annos_merged[reorder_idx,] 46 | 47 | # Ensure that ids from quants rows cell annos merged are in same order as quants rows 48 | if (all(quants_rows_df$sample_name == quants_rows_cell_annos_merged$sample_name)){ 49 | print("Cell annotations are in same order as quants rows") 50 | } else { 51 | stop("Cell annotations are not in same order as quants rows") 52 | } 53 | 54 | # Re-save the merged cell annotations 55 | saveRDS( 56 | quants_rows_cell_annos_merged, 57 | "cell_annotations_updated_aligned.rds" 58 | ) -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing of the mouse hindbrain development data for analysis 2 | 3 | ### This directory contains scripts for preprocessing mouse hindbrain developmental data from Vladoiu et al. for the subsequent Velocity analysis with scVelo and DeepVelo 4 | 5 | ## Instructions: 6 | 7 | ### Downloading data and kallisto files 8 | To be added after larger data repository prepared. 9 | 10 | ### Running Snakemake pipeline 11 | 12 | 1) Assuming conda is installed, resolve and install environment 13 | ``` 14 | conda env create -f environment.yaml 15 | ``` 16 | 17 | 2) Ensure files paths are correct in `config.json` by setting script and working directory 18 | 19 | 3) Test Snakemake (dry run) to ensure validity of DAG and files 20 | ``` 21 | conda activate hindbrain_velocity 22 | snakemake -np 23 | ``` 24 | 25 | 4) Run snakemake pipeline either locally (a) or using a custom configuration on HPC 26 | 27 | (Note for this step, creating a cluster.json or profile will be necessary. 
See Snakemake documentation for details - https://snakemake.readthedocs.io/en/stable/snakefiles/configuration.html) 28 | ``` 29 | a) snakemake --cores=8 30 | ``` 31 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/Snakefile: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | configfile: "config.json" 4 | scriptdir = os.path.join(os.getcwd(), "scripts") 5 | workdir: config["workdir"] 6 | 7 | # mm10 genome and associated files downloaded using the 8 | # 'mouse_download.sh' script from 9 | # https://github.com/linnarsson-lab/loompy/tree/master/kallisto 10 | 11 | # All directive - run all jobs beginning to end 12 | rule all: 13 | input: 14 | "loom_subset/loom_concat_subset.h5ad" 15 | 16 | # Process metadata file for loompy workflow 17 | rule metadata_file: 18 | params: 19 | script_path = os.path.join(scriptdir, "metadata.py") 20 | output: 21 | "hindbrain_metadata.tab" 22 | log: 23 | "logs/metadata_file/hindbrain_metadata.log" 24 | threads: 16 25 | shell: 26 | "python {params.script_path} {output} &> {log}" 27 | 28 | # Index fragments file using kallisto 29 | rule kallisto_idx: 30 | input: 31 | os.path.join( 32 | config["workdir"], 33 | "kallisto/inputs/gencode.vM23.fragments.fa" 34 | ) 35 | params: 36 | workdir = config["workdir"], 37 | build_script = os.path.join( 38 | scriptdir, 39 | "mouse_build.py" 40 | ), 41 | fragments_script = os.path.join( 42 | scriptdir, 43 | "mouse_generate_fragments.py" 44 | ), 45 | kmers = config["kallisto_idx_kmers"] 46 | output: 47 | os.path.join( 48 | config["workdir"], 49 | "kallisto/gencode.vM23.fragments.idx" 50 | ) 51 | threads: 16 52 | log: 53 | "logs/kallisto_idx/kallisto_idx.log" 54 | shell: 55 | """ 56 | cd {params.workdir}/kallisto 57 | python -v {params.build_script} 58 | python -v {params.fragments_script} 59 | kallisto index -i {output} -k {params.kmers} {input} 60 | """ 61 | 62 | # Create individual loompy files from fastqs for each timepoint 63 | rule fastq_to_loom: 64 | input: 65 | metadata = "{metaname}_metadata.tab", 66 | index_file = os.path.join( 67 | config["workdir"], 68 | "kallisto/gencode.{gencode_ver}.fragments.idx" 69 | ), 70 | read_l1_r1 = "fastq/{sample}/{run}_L001_R1_001.fastq.gz", 71 | read_l1_r2 = "fastq/{sample}/{run}_L001_R2_001.fastq.gz", 72 | read_l2_r1 = "fastq/{sample}/{run}_L002_R1_001.fastq.gz", 73 | read_l2_r2 = "fastq/{sample}/{run}_L002_R2_001.fastq.gz" 74 | params: 75 | index_path = os.path.join( 76 | config["workdir"], 77 | "kallisto" 78 | ) 79 | threads: 16 80 | output: 81 | "loom/{sample}/{run}_{metaname}_{gencode_ver}.loom" 82 | log: 83 | "logs/fastq_to_loom/{sample}/{run}_{metaname}_{gencode_ver}.log" 84 | shell: 85 | """ 86 | loompy fromfq {output} {wildcards.sample} {params.index_path} {input.metadata} \ 87 | {input.read_l1_r1} {input.read_l1_r2} {input.read_l2_r1} \ 88 | {input.read_l2_r2} \ 89 | &> {log} 90 | """ 91 | 92 | # Combine loompy files into one 93 | rule concat_looms: 94 | input: 95 | [ 96 | "loom/{sample}/{run}_hindbrain_vM23.loom".format(sample = i, run = j) 97 | for i,j in zip(config["samples"], config["runs"]) 98 | ] 99 | output: 100 | "loom_concat/{loomfile}.loom" 101 | params: 102 | script_path = os.path.join(scriptdir, "loompy_combine.py") 103 | threads: 16 104 | log: 105 | "logs/concat_looms/{loomfile}_looms.log" 106 | shell: 107 | """ 108 | python {params.script_path} \ 109 | --outfile {output} \ 110 | --loomfiles {input} \ 111 | &> {log} 112 | """ 113 | 114 | # Match 
metadata in loomfile with metadata file 115 | rule loom_metadata: 116 | input: 117 | "loom_concat/{loomfile}.loom" 118 | output: 119 | "loom_metadata/{loomfile}_loom_metadata.tsv", 120 | params: 121 | script_path = os.path.join(scriptdir, "loom_metadata.py"), 122 | barcode_meta_file = "cluster_annotations/barcode_cluster.csv", 123 | cluster_meta_file = "cluster_annotations/cluster_annotations.csv" 124 | threads: 16 125 | log: 126 | "logs/loom_metadata/{loomfile}_loom_metadata.log" 127 | shell: 128 | """ 129 | python {params.script_path} \ 130 | --outfile {output} \ 131 | --loomfile {input} \ 132 | --barcodefile {params.barcode_meta_file} \ 133 | --clusterfile {params.cluster_meta_file} \ 134 | &> {log} 135 | """ 136 | 137 | # Subset loomfile and convert to anndata 138 | # for only data (celltypes) that is utilized in the analysis 139 | rule loom_subset_to_adata: 140 | input: 141 | full_loomfile = "loom_concat/{loomfile}.loom", 142 | loom_metadata = "loom_metadata/{loomfile}_loom_metadata.tsv" 143 | output: 144 | "loom_subset/{loomfile}_subset.h5ad" 145 | params: 146 | script_path = os.path.join(scriptdir, "loom_subset.py") 147 | threads: 16 148 | log: 149 | "logs/loom_subset/{loomfile}_subset.log" 150 | shell: 151 | """ 152 | python {params.script_path} \ 153 | --outfile {output} \ 154 | --loomfile {input.full_loomfile} \ 155 | --metadata {input.loom_metadata} \ 156 | &> {log} 157 | """ -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "workdir": "[ENTER WORKING DIRECTORY]", 3 | "scriptdir": "[ENTER SCRIPT DIRECTORY]", 4 | "samples": ["E10", "E12", "E14", "E16", "E18", "P0", "P5", "P7", "P14"], 5 | "runs": [ 6 | "SingleCell_h9_S1", 7 | "SingleCell_h8_S2", 8 | "SingleCell_h7_S1", 9 | "SingleCell_h6_S2", 10 | "SingleCell_h12_S1", 11 | "SingleCell_f4_S2", 12 | "SingleCell_d11_S1", 13 | "SingleCell_f7_S2", 14 | "SingleCell_g9_S2" 15 | ], 16 | "kallisto_idx_kmers": 31, 17 | "subset_celltypes": [ 18 | "Differentiating GABA interneurons", 19 | "GABA interneurons", 20 | "Gliogenic progenitors", 21 | "Neural stem cells", 22 | "Proliferating VZ progenitors", 23 | "VZ progenitors" 24 | ] 25 | } -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/environment.yaml: -------------------------------------------------------------------------------- 1 | name: hindbrain_velocity 2 | channels: 3 | - r 4 | - conda-forge 5 | - bioconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - aioeasywebdav=2.4.0=py38h32f6830_1001 11 | - aiohttp=3.7.3=py38h497a2fe_1 12 | - alabaster=0.7.12=pyhd3eb1b0_0 13 | - amply=0.1.4=py_0 14 | - anndata=0.7.8=py38h578d9bd_1 15 | - appdirs=1.4.4=pyh9f0ad1d_0 16 | - arpack=3.7.0=hdefa2d7_2 17 | - async-timeout=3.0.1=py_1000 18 | - attmap=0.12.11=py_0 19 | - attrs=20.3.0=pyhd3deb0d_0 20 | - babel=2.9.1=pyhd3eb1b0_0 21 | - backports=1.0=py_2 22 | - backports.functools_lru_cache=1.6.1=py_0 23 | - bcrypt=3.2.0=py38h497a2fe_1 24 | - bedtools=2.30.0=hc088bd4_0 25 | - biopython=1.78=py38h497a2fe_1 26 | - blas=1.1=openblas 27 | - blosc=1.21.0=h9c3ff4c_0 28 | - boto3=1.17.9=pyhd8ed1ab_0 29 | - botocore=1.20.9=pyhd8ed1ab_0 30 | - brotlipy=0.7.0=py38h497a2fe_1001 31 | - bzip2=1.0.8=h7f98852_4 32 | - c-ares=1.17.1=h36c2ea0_0 33 | - ca-certificates=2020.12.5=ha878542_0 34 | - cachetools=4.2.1=pyhd8ed1ab_0 
35 | - cairo=1.16.0=hcf35c78_1003 36 | - certifi=2020.12.5=py38h578d9bd_1 37 | - cffi=1.14.4=py38ha312104_0 38 | - chardet=3.0.4=py38h924ce5b_1008 39 | - click=7.1.2=pyh9f0ad1d_0 40 | - coincbc=2.10.5=hcee13e7_1 41 | - colorama=0.4.4=pyhd3eb1b0_0 42 | - configargparse=1.3=pyhd8ed1ab_0 43 | - cryptography=3.4.4=py38h3e25421_0 44 | - curl=7.78.0=hea6ffbf_0 45 | - cycler=0.10.0=py_2 46 | - cython=0.29.21=py38h709712a_2 47 | - datrie=0.8.2=py38h1e0a361_1 48 | - dbus=1.13.18=hb2f20db_0 49 | - decorator=4.4.2=py_0 50 | - docutils=0.16=py38h578d9bd_3 51 | - dropbox=10.9.0=pyhd3deb0d_0 52 | - dunamai=1.8.0=pyhd8ed1ab_0 53 | - expat=2.2.10=h9c3ff4c_0 54 | - fftw=3.3.9=nompi_h74d3f13_100 55 | - filechunkio=1.8=py_2 56 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 57 | - font-ttf-inconsolata=2.001=hab24e00_0 58 | - font-ttf-source-code-pro=2.030=hab24e00_0 59 | - font-ttf-ubuntu=0.83=hab24e00_0 60 | - fontconfig=2.13.1=hba837de_1004 61 | - fonts-conda-forge=1=0 62 | - freetype=2.10.4=h0708190_1 63 | - fribidi=1.0.10=h36c2ea0_0 64 | - ftputil=4.0.0=py_0 65 | - gdk-pixbuf=2.38.2=h3f25603_6 66 | - get_version=3.5.4=pyhd8ed1ab_0 67 | - gettext=0.19.8.1=hf34092f_1004 68 | - ghostscript=9.53.3=h58526e2_2 69 | - giflib=5.2.1=h36c2ea0_2 70 | - git=2.32.0=pl5262hc120c5b_1 71 | - gitdb=4.0.5=pyhd8ed1ab_1 72 | - gitpython=3.1.13=pyhd8ed1ab_0 73 | - glib=2.58.3=py38h73cb85d_1004 74 | - glpk=4.65=h9202a9a_1004 75 | - gmp=6.2.1=h58526e2_0 76 | - gobject-introspection=1.66.1=py38h03d966d_1 77 | - google-api-core=1.25.1=pyh44b312d_0 78 | - google-api-python-client=1.12.8=pyhd3deb0d_0 79 | - google-auth=1.24.0=pyhd3deb0d_0 80 | - google-auth-httplib2=0.0.4=pyh9f0ad1d_0 81 | - google-cloud-core=1.5.0=pyhd3deb0d_0 82 | - google-cloud-storage=1.35.0=pyhd3deb0d_0 83 | - google-crc32c=1.1.2=py38h8838a9a_0 84 | - google-resumable-media=1.2.0=pyhd3deb0d_0 85 | - googleapis-common-protos=1.52.0=py38h578d9bd_1 86 | - graphite2=1.3.13=h58526e2_1001 87 | - graphviz=2.42.3=h0511662_0 88 | - grpcio=1.35.0=py38hdd6454d_0 89 | - gst-plugins-base=1.14.5=h0935bb2_2 90 | - gstreamer=1.14.5=h36ae1b5_2 91 | - h5py=2.10.0=nompi_py38h513d04c_102 92 | - harfbuzz=2.4.0=h9f30f68_3 93 | - hdf5=1.10.5=nompi_h5b725eb_1114 94 | - httplib2=0.19.0=pyhd8ed1ab_0 95 | - icu=64.2=he1b5a44_1 96 | - idna=2.10=pyh9f0ad1d_0 97 | - imagemagick=7.0.10_28=pl5262hf04efa9_1 98 | - imagesize=1.3.0=pyhd8ed1ab_0 99 | - importlib-metadata=3.4.0=py38h578d9bd_0 100 | - importlib_metadata=3.4.0=hd8ed1ab_0 101 | - iniconfig=1.1.1=pyh9f0ad1d_0 102 | - ipython_genutils=0.2.0=py_1 103 | - jbig=2.1=h516909a_2002 104 | - jinja2=2.11.3=pyh44b312d_0 105 | - jmespath=0.10.0=pyh9f0ad1d_0 106 | - joblib=1.0.1=pyhd8ed1ab_0 107 | - jpeg=9d=h36c2ea0_0 108 | - jsonschema=3.2.0=py_2 109 | - jupyter_core=4.7.1=py38h578d9bd_0 110 | - kallisto=0.46.2=h4f7b962_1 111 | - kiwisolver=1.3.1=py38h1fd1430_1 112 | - krb5=1.19.2=hcc1bbae_0 113 | - lcms2=2.12=hddcbb42_0 114 | - ld_impl_linux-64=2.35.1=hea4e1c9_2 115 | - legacy-api-wrap=1.2=py_0 116 | - leidenalg=0.8.3=py38h82c7cc0_0 117 | - libblas=3.9.0=8_openblas 118 | - libcblas=3.9.0=8_openblas 119 | - libcrc32c=1.1.1=h9c3ff4c_2 120 | - libcurl=7.78.0=h2574ce0_0 121 | - libedit=3.1.20210714=h7f8727e_0 122 | - libev=4.33=h516909a_1 123 | - libffi=3.2.1=he1b5a44_1007 124 | - libgcc-ng=9.3.0=h2828fa1_18 125 | - libgfortran-ng=9.3.0=hff62375_18 126 | - libgfortran5=9.3.0=hff62375_18 127 | - libgomp=9.3.0=h2828fa1_18 128 | - libiconv=1.16=h516909a_0 129 | - liblapack=3.9.0=8_openblas 130 | - libllvm10=10.0.1=he513fc3_3 131 | - 
libnghttp2=1.43.0=h812cca2_0 132 | - libopenblas=0.3.12=pthreads_h4812303_1 133 | - libpng=1.6.37=h21135ba_2 134 | - libprotobuf=3.14.0=h780b84a_0 135 | - librsvg=2.50.2=h1f8de02_0 136 | - libsodium=1.0.18=h36c2ea0_1 137 | - libssh2=1.9.0=ha56f1ee_6 138 | - libstdcxx-ng=9.3.0=h6de172a_18 139 | - libtiff=4.2.0=hdc55705_0 140 | - libtool=2.4.6=h58526e2_1007 141 | - libuuid=2.32.1=h7f98852_1000 142 | - libwebp=1.2.0=h3452ae3_0 143 | - libwebp-base=1.2.0=h7f98852_0 144 | - libxcb=1.13=h7f98852_1003 145 | - libxml2=2.9.10=hee79883_0 146 | - llvmlite=0.35.0=py38h4630a5e_1 147 | - logmuse=0.2.6=pyh8c360ce_0 148 | - loompy=3.0.6=py_0 149 | - lz4-c=1.9.3=h9c3ff4c_0 150 | - lzo=2.10=h516909a_1000 151 | - markupsafe=1.1.1=py38h497a2fe_3 152 | - matplotlib=3.3.4=py38h578d9bd_0 153 | - matplotlib-base=3.3.4=py38h0efea84_0 154 | - metis=5.1.0=h58526e2_1006 155 | - mock=4.0.3=py38h578d9bd_2 156 | - more-itertools=8.7.0=pyhd8ed1ab_0 157 | - multidict=5.1.0=py38h497a2fe_1 158 | - natsort=8.1.0=pyhd8ed1ab_0 159 | - nbformat=5.1.2=pyhd8ed1ab_1 160 | - ncurses=6.2=h58526e2_4 161 | - networkx=2.5=py_0 162 | - numba=0.52.0=py38h51da96c_0 163 | - numexpr=2.8.1=py38hecfb737_0 164 | - numpy=1.20.1=py38h18fd61f_0 165 | - numpy_groupies=0.9.13=pyh9f0ad1d_1 166 | - oauth2client=4.1.3=py_0 167 | - olefile=0.46=pyh9f0ad1d_1 168 | - openblas=0.3.12=pthreads_h04b7a96_1 169 | - openjpeg=2.3.1=hf7af979_3 170 | - openssl=1.1.1k=h7f98852_0 171 | - packaging=20.9=pyh44b312d_0 172 | - pandas=1.2.2=py38h51da96c_0 173 | - pango=1.42.4=h7062337_4 174 | - paramiko=2.7.2=pyh9f0ad1d_0 175 | - patsy=0.5.1=py_0 176 | - pcre=8.44=he1b5a44_0 177 | - pcre2=10.35=h032f7d1_2 178 | - peppy=0.31.0=pyh9f0ad1d_0 179 | - perl=5.26.2=h36c2ea0_1008 180 | - pillow=8.1.0=py38ha0e1e83_2 181 | - pip=21.0.1=pyhd8ed1ab_0 182 | - pixman=0.38.0=h516909a_1003 183 | - pkg-config=0.29.2=h36c2ea0_1008 184 | - pluggy=0.13.1=py38h578d9bd_4 185 | - prettytable=2.0.0=pyhd8ed1ab_0 186 | - protobuf=3.14.0=py38h709712a_1 187 | - psutil=5.8.0=py38h497a2fe_1 188 | - pthread-stubs=0.4=h36c2ea0_1001 189 | - pulp=2.3.1=py38h32f6830_0 190 | - py=1.10.0=pyhd3deb0d_0 191 | - pyasn1=0.4.8=py_0 192 | - pyasn1-modules=0.2.7=py_0 193 | - pycparser=2.20=pyh9f0ad1d_2 194 | - pygments=2.8.0=pyhd8ed1ab_0 195 | - pygraphviz=1.7=py38h0d738da_0 196 | - pynacl=1.4.0=py38h497a2fe_2 197 | - pynndescent=0.5.2=pyh44b312d_0 198 | - pyopenssl=20.0.1=pyhd8ed1ab_0 199 | - pyparsing=2.4.7=pyh9f0ad1d_0 200 | - pyqt=5.9.2=py38h05f1152_4 201 | - pyrsistent=0.17.3=py38h497a2fe_2 202 | - pysftp=0.2.9=py_1 203 | - pysocks=1.7.1=py38h578d9bd_3 204 | - pytables=3.6.1=py38h9f153d1_1 205 | - pytest=6.2.2=py38h578d9bd_0 206 | - python=3.8.6=h852b56e_0_cpython 207 | - python-dateutil=2.8.1=py_0 208 | - python-igraph=0.9.1=py38h2af5540_0 209 | - python-irodsclient=0.8.2=py_0 210 | - python_abi=3.8=1_cp38 211 | - pytz=2021.1=pyhd8ed1ab_0 212 | - pyyaml=5.4.1=py38h497a2fe_0 213 | - qt=5.9.7=h0c104cb_3 214 | - ratelimiter=1.2.0=py_1002 215 | - readline=8.0=he28a2e2_2 216 | - requests=2.25.1=pyhd3deb0d_0 217 | - rsa=4.7.1=pyh44b312d_0 218 | - s3transfer=0.3.4=pyhd8ed1ab_0 219 | - scanpy=1.8.2=pyhd8ed1ab_0 220 | - scikit-learn=0.24.1=py38h658cfdd_0 221 | - scipy=1.6.0=py38hb2138dd_0 222 | - seaborn=0.11.1=hd8ed1ab_1 223 | - seaborn-base=0.11.1=pyhd8ed1ab_1 224 | - setuptools=49.6.0=py38h578d9bd_3 225 | - simplejson=3.17.2=py38h497a2fe_2 226 | - sinfo=0.3.1=py_0 227 | - sip=4.19.13=py38he6710b0_0 228 | - six=1.15.0=pyh9f0ad1d_0 229 | - slacker=0.14.0=py_0 230 | - smmap=3.0.5=pyh44b312d_0 231 | - 
snakemake=5.32.2=0 232 | - snakemake-minimal=5.32.2=py_0 233 | - snowballstemmer=2.2.0=pyhd8ed1ab_0 234 | - sphinx=4.3.2=pyh6c4a22f_0 235 | - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0 236 | - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0 237 | - sphinxcontrib-htmlhelp=2.0.0=pyhd8ed1ab_0 238 | - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0 239 | - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0 240 | - sphinxcontrib-serializinghtml=1.1.5=pyhd8ed1ab_1 241 | - sqlite=3.34.0=h74cdb3f_0 242 | - statsmodels=0.12.2=py38h5c078b8_0 243 | - stdlib-list=0.7.0=py38h32f6830_1 244 | - suitesparse=5.7.2=h7a0d4b7_0 245 | - tbb=2020.2=h4bd325d_4 246 | - texttable=1.6.3=pyh9f0ad1d_0 247 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 248 | - tk=8.6.10=h21135ba_1 249 | - toml=0.10.2=pyhd8ed1ab_0 250 | - toposort=1.6=pyhd8ed1ab_0 251 | - tornado=6.1=py38h497a2fe_1 252 | - tqdm=4.62.3=pyhd3eb1b0_1 253 | - traitlets=5.0.5=py_0 254 | - typing-extensions=3.7.4.3=0 255 | - typing_extensions=3.7.4.3=py_0 256 | - ubiquerg=0.6.1=pyh9f0ad1d_0 257 | - umap-learn=0.5.1=py38h578d9bd_0 258 | - uritemplate=3.0.1=py_0 259 | - urllib3=1.26.3=pyhd8ed1ab_0 260 | - veracitools=0.1.3=py_0 261 | - wcwidth=0.2.5=pyh9f0ad1d_2 262 | - wheel=0.36.2=pyhd3deb0d_0 263 | - wrapt=1.12.1=py38h497a2fe_3 264 | - xmlrunner=1.7.7=py_0 265 | - xorg-kbproto=1.0.7=h7f98852_1002 266 | - xorg-libice=1.0.10=h516909a_0 267 | - xorg-libsm=1.2.3=h84519dc_1000 268 | - xorg-libx11=1.6.12=h516909a_0 269 | - xorg-libxau=1.0.9=h7f98852_0 270 | - xorg-libxdmcp=1.1.3=h7f98852_0 271 | - xorg-libxext=1.3.4=h516909a_0 272 | - xorg-libxpm=3.5.13=h516909a_0 273 | - xorg-libxrender=0.9.10=h516909a_1002 274 | - xorg-libxt=1.1.5=h516909a_1003 275 | - xorg-renderproto=0.11.1=h14c3975_1002 276 | - xorg-xextproto=7.3.0=h7f98852_1002 277 | - xorg-xproto=7.0.31=h7f98852_1007 278 | - xz=5.2.5=h516909a_1 279 | - yaml=0.2.5=h516909a_0 280 | - yarl=1.5.1=py38h1e0a361_0 281 | - zipp=3.4.0=py_0 282 | - zlib=1.2.11=h516909a_1010 283 | - zstd=1.4.8=ha95c52a_1 284 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/kallisto_readme/README.md: -------------------------------------------------------------------------------- 1 | # Building an annotated kallisto index 2 | 3 | For the human genome see the notebook subdirectory. 4 | 5 | ## Building the mouse kallisto index 6 | 7 | These instructions work on Linux (tested on CentOS7). 8 | 9 | 1. Make sure packages bedtools and kallisto are installed on the system. 10 | - bedtools from https://bedtools.readthedocs.io/en/latest/content/installation.html 11 | - kallisto from https://pachterlab.github.io/kallisto/download.html 12 | 13 | 2. Create your working directory, 'cd' there, and put the files above there. 14 | 15 | 3. Download and preprocess input files: 16 | 17 | `bash mouse_download.sh` 18 | 19 | This will create a directory "inputs" and put some files there as well as in the current directory. 20 | 21 | 4. Download "BrowseTF TcoF-DB.xlsx" from https://tools.sschmeier.com/tcof/browse/?type=tcof&species=mouse&class=all# by clicking the "Excel" button. (Main page is https://tools.sschmeier.com/tcof/home/). 22 | Open the file in Excel and save tab-separated as "inputs/TcoF-DB.tsv". 23 | 24 | 5. 
You need to download some annotations for Mouse GRCm38 from BioMart (https://m.ensembl.org/biomart) Open this link in a new browser tab: 25 | 26 | http://www.ensembl.org/biomart/martview/7c9b283e3eca26cb81449ec518f4fc14?VIRTUALSCHEMANAME=default&ATTRIBUTES=mmusculus_gene_ensembl.default.feature_page.ensembl_gene_id|mmusculus_gene_ensembl.default.feature_page.ensembl_gene_id_version|mmusculus_gene_ensembl.default.feature_page.ensembl_transcript_id|mmusculus_gene_ensembl.default.feature_page.ensembl_transcript_id_version|mmusculus_gene_ensembl.default.feature_page.ucsc|mmusculus_gene_ensembl.default.feature_page.vega_translation|mmusculus_gene_ensembl.default.feature_page.ccds&FILTERS=&VISIBLEPANEL=resultspanel 27 | 28 | On this BioMart page, click the "Go" button, and save the downloaded "mart_export.txt" file as "inputs/mart_export.txt". 29 | The file should contain the following columns in the header: 30 | Gene stable ID Gene stable ID version Transcript stable ID Transcript stable ID version UCSC Stable ID Vega translation ID CCDS ID 31 | 32 | If the link fails, you need to manually select the proper dataset and columns from the https://m.ensembl.org/biomart webpage and download: 33 | * Select Dataset "Ensembl Genes 101"/"Mouse genes GRCm38". 34 | * Select Attributes as in columns above: First 4 should be auto-selected. Select the following 3 from the "EXTERNAL" section, clicking the 3 boxes in the order above. 35 | * Click "Results", export using "Go" button, and save to "inputs/mart_export.txt". 36 | 37 | 6. Run the annotation assembly script: 38 | 39 | `python mouse_build.py` 40 | 41 | 7. Create the "manifest.json" file or use the one supplied above. It should contain: 42 | ``` 43 | { 44 | "species": "Mus musculus", 45 | "index_file": "gencode.vM23.fragments.idx", 46 | "gene_metadata_file": "gencode.vM23.metadata.tab", 47 | "gene_metadata_key": "AccessionVersion", 48 | "fragments_to_genes_file": "fragments2genes.txt", 49 | "layers": { 50 | "unspliced": "unspliced_fragments.txt", 51 | "spliced": "spliced_fragments.txt" 52 | } 53 | } 54 | ``` 55 | 56 | 8. Run the fragment generator script: 57 | 58 | `python mouse_generate_fragments.py` 59 | 60 | 9. Build the kallisto index: 61 | 62 | `kallisto index -i gencode.vM23.fragments.idx -k 31 inputs/gencode.vM23.fragments.fa` 63 | 64 | 10. Refer to the notebook for human for more info on the output. 
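Once the index is built, the contents of the `kallisto` directory (the index, `manifest.json`, `gencode.vM23.metadata.tab`, and the fragment-to-gene maps) are what the Snakemake pipeline passes to `loompy fromfq` in its `fastq_to_loom` rule. The command below is an illustrative sketch for a single timepoint, using the E10 sample and run names from `config.json`; adjust the paths to your working directory.

```
loompy fromfq loom/E10/SingleCell_h9_S1_hindbrain_vM23.loom E10 \
    /path/to/workdir/kallisto hindbrain_metadata.tab \
    fastq/E10/SingleCell_h9_S1_L001_R1_001.fastq.gz fastq/E10/SingleCell_h9_S1_L001_R2_001.fastq.gz \
    fastq/E10/SingleCell_h9_S1_L002_R1_001.fastq.gz fastq/E10/SingleCell_h9_S1_L002_R2_001.fastq.gz
```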
65 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/loom_metadata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import loompy 4 | import numpy as np 5 | import pandas as pd 6 | 7 | # Function to get barcodes from loom file and subset 8 | # barcode metadata from vladoiu et al and return 9 | def loom_metadata(loom_file, barcode_file, cluster_file, out_file): 10 | # Connect to loom file 11 | ds = loompy.connect(loom_file) 12 | 13 | # Load barcode and cluster files 14 | barcode_meta = pd.read_csv(barcode_file) 15 | cluster_meta = pd.read_csv(cluster_file, header = None) 16 | 17 | # Format data for both 18 | barcode_meta.columns = ["CellID", "Cluster"] 19 | cluster_meta.columns = ["Cluster", "Celltype", "Lineage"] 20 | 21 | # Create dataframe of CellID and UMI data from loom file 22 | loom_cellid = ds.ca["CellID"] 23 | loom_umi = ds.ca["TotalUMIs"] 24 | loom_meta = pd.DataFrame({ 25 | "CellID": loom_cellid, 26 | "TotalUMIs": loom_umi 27 | }) 28 | 29 | # Consecutively merge three datasets 30 | # (loom meta -> barcode meta -> cluster meta) 31 | loom_barcode_merge = loom_meta.merge( 32 | barcode_meta, 33 | how = "inner", 34 | on = "CellID" 35 | ) 36 | loom_barcode_cluster_merge = loom_barcode_merge.merge( 37 | cluster_meta, 38 | how = "inner", 39 | on = "Cluster" 40 | ) 41 | 42 | # Add timepoint information 43 | barcode_list = list(loom_barcode_cluster_merge["CellID"].values) 44 | timepoints = [i.split("_", 1)[0] for i in barcode_list] 45 | loom_barcode_cluster_merge["Timepoint"] = timepoints 46 | 47 | # Save dataframe to tsv 48 | loom_barcode_cluster_merge.to_csv( 49 | out_file, 50 | sep = "\t", 51 | header = True, 52 | index = False 53 | ) 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser( 57 | description = "Output file and input files " + 58 | "for concatenation of loom data" 59 | ) 60 | parser.add_argument( 61 | "--outfile", 62 | type = str, 63 | help = "Path of output for merged loom metadata" 64 | ) 65 | parser.add_argument( 66 | "--loomfile", 67 | type = str, 68 | help = "Path of concatenated loom file for developmental data" 69 | ) 70 | parser.add_argument( 71 | "--barcodefile", 72 | type = str, 73 | help = "Path of barcode-cluster metadata file" 74 | ) 75 | parser.add_argument( 76 | "--clusterfile", 77 | type = str, 78 | help = "Path of cluster-celltype-lineage metadata file" 79 | ) 80 | args = parser.parse_args() 81 | loom_metadata( 82 | loom_file = args.loomfile, 83 | barcode_file = args.barcodefile, 84 | cluster_file = args.clusterfile, 85 | out_file = args.outfile 86 | ) -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/loom_subset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import scanpy as sc 6 | import anndata as ann 7 | import loompy as lp 8 | 9 | # Define main function to subset loom file (after importing) 10 | # as anndata for the relevant celltypes/lineages used in analysis 11 | def loom_subset(full_loomfile, loom_metadata, out_file): 12 | # Read in loom file using scanpy 13 | full_data = sc.read_loom( 14 | full_loomfile, 15 | sparse = True 16 | ) 17 | 18 | # Rename index 19 | full_data.obs.index.rename("id", inplace = True) 20 | 21 | # Add new column to obs using CellIDs from index 22 | full_data.obs["CellID"] = 
full_data.obs.index 23 | 24 | # Read in loom metadata file 25 | loom_metadata = pd.read_csv( 26 | loom_metadata, 27 | sep = "\t" 28 | ) 29 | 30 | # Pull out obs columns from full data 31 | full_data_obs = full_data.obs.copy() 32 | 33 | # Merge anndata obs and loom metadata 34 | full_data_obs_meta = full_data_obs.merge( 35 | loom_metadata, 36 | how = "inner", 37 | on = "CellID" 38 | ) 39 | 40 | # Subset original anndata object for only matched cellids and combine information 41 | full_data = full_data[full_data.obs["CellID"].isin(full_data_obs_meta["CellID"])] 42 | 43 | # Subset metadata for only matched cellids 44 | loom_metadata_sub = loom_metadata[loom_metadata["CellID"].isin(full_data_obs_meta["CellID"])] 45 | 46 | # Append subset loom metadata to original anndata object and reset index 47 | full_data.obs = full_data.obs.merge( 48 | loom_metadata_sub, 49 | how = "inner", 50 | on = "CellID" 51 | ) 52 | full_data.obs.set_index("CellID", inplace = True) 53 | 54 | # Subset data to only include celltypes/lineages used in analysis 55 | full_data_sub = full_data[ 56 | full_data.obs["Celltype"].isin([ 57 | "Differentiating GABA interneurons", 58 | "GABA interneurons", 59 | "Gliogenic progenitors", 60 | "Neural stem cells", 61 | "Proliferating VZ progenitors", 62 | "VZ progenitors" 63 | ]) 64 | ] 65 | 66 | # Return subsetted data 67 | full_data_sub.write( 68 | out_file 69 | ) 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser( 73 | description = "Output file and input files " + 74 | "for subsetting of loom data" 75 | ) 76 | parser.add_argument( 77 | "--outfile", 78 | type = str, 79 | help = "Path of output for subset data as anndata object" 80 | ) 81 | parser.add_argument( 82 | "--loomfile", 83 | type = str, 84 | help = "Path of concatenated loom file for developmental data" 85 | ) 86 | parser.add_argument( 87 | "--metadata", 88 | type = str, 89 | help = "Path of inner joined metadata for loom file" 90 | ) 91 | args = parser.parse_args() 92 | loom_subset( 93 | full_loomfile = args.loomfile, 94 | loom_metadata = args.metadata, 95 | out_file = args.outfile 96 | ) 97 | 98 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/loompy_combine.py: -------------------------------------------------------------------------------- 1 | import loompy 2 | import argparse 3 | 4 | # Function to take in arbitrary number of loompy 5 | # files, return and save concatenated 6 | def loompy_combine(out_file, loom_file_list): 7 | # Define params 8 | files = loom_file_list 9 | output_file = out_file 10 | key = "Accession" 11 | 12 | # Concat and save loom file 13 | loompy.combine( 14 | files = files, 15 | output_file = output_file, 16 | key = key 17 | ) 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser( 21 | description = "Output file and input files " + 22 | "for concatenation of loom data" 23 | ) 24 | parser.add_argument( 25 | "--outfile", 26 | type = str, 27 | help = "Path of output for concatenated loom file" 28 | ) 29 | parser.add_argument( 30 | "--loomfiles", 31 | type = str, 32 | nargs = "*", 33 | help = "Any number of loom file paths" 34 | ) 35 | args = parser.parse_args() 36 | loompy_combine( 37 | out_file = args.outfile, 38 | loom_file_list = args.loomfiles 39 | ) -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/manifest.json: -------------------------------------------------------------------------------- 1 | 
{ 2 | "species": "Mus musculus", 3 | "index_file": "gencode.vM23.fragments.idx", 4 | "gene_metadata_file": "gencode.vM23.metadata.tab", 5 | "gene_metadata_key": "AccessionVersion", 6 | "fragments_to_genes_file": "fragments2genes.txt", 7 | "layers": { 8 | "unspliced": "unspliced_fragments.txt", 9 | "spliced": "spliced_fragments.txt" 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/metadata.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | 6 | def write_metadata(save_loc): 7 | # Define subsets, tech, and target cells 8 | # Target cell number defined from 9 | # https://www.nature.com/articles/s41586-019-1158-7 10 | # rounded up to the nearest thousand 11 | samples = [ 12 | "E10", 13 | "E12", 14 | "E14", 15 | "E16", 16 | "E18", 17 | "P0", 18 | "P5", 19 | "P7", 20 | "P14" 21 | ] 22 | technology = "10xv2" 23 | target_cells = [ 24 | "8000", 25 | "8000", 26 | "7000", 27 | "8000", 28 | "6000", 29 | "5000", 30 | "12000", 31 | "8000", 32 | "5000" 33 | ] 34 | 35 | # Create and save df as tsv 36 | meta_df = pd.DataFrame({ 37 | "name" : samples, 38 | "technology" : technology, 39 | "targetnumcells": target_cells 40 | }) 41 | meta_df.to_csv( 42 | save_loc, 43 | sep = "\t", 44 | index = False 45 | ) 46 | 47 | if __name__ == '__main__': 48 | write_metadata(sys.argv[1]) -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/mouse_build.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | d = "./" 4 | 5 | mgiID2MRK_ENSE = {} 6 | enseID2mgiID = {} 7 | with open(d + "inputs/MRK_ENSEMBL.rpt") as f: 8 | for line in f: 9 | items = line[:-1].split("\t") 10 | mgiID = items[0] 11 | mgiID2MRK_ENSE[mgiID] = items 12 | enseID = items[5] 13 | enseID2mgiID[enseID] = mgiID 14 | 15 | enseID2MGI_GMC = {} 16 | with open(d + "inputs/MGI_Gene_Model_Coord.rpt") as f: 17 | MGI_GMC_headers = f.readline()[:-1].split("\t") 18 | for line in f: 19 | items = line[:-1].split("\t") 20 | enseID = items[10] 21 | enseID2MGI_GMC[enseID] = items 22 | 23 | mgiID2MRK_Seq = {} 24 | with open(d + "inputs/MRK_Sequence.rpt") as f: 25 | MRK_Seq_headers = f.readline()[:-1].split("\t") 26 | for line in f: 27 | items = line[:-1].split("\t") 28 | mgiID = items[0] 29 | mgiID2MRK_Seq[mgiID] = items 30 | 31 | enseID2mart = {} 32 | with open(d + "inputs/mart_export.txt") as f: 33 | mart_headers = f.readline()[:-1].split("\t") 34 | for line in f: 35 | items = line[:-1].split("\t") 36 | enseID = items[0] 37 | enseID2mart[enseID] = items 38 | 39 | geneSymbol2TF = {} 40 | with open(d + "inputs/TF_TcoF-DB.tsv") as f: 41 | TF_headers = f.readline()[:-1].split("\t") 42 | for line in f: 43 | items = line[:-1].split("\t") 44 | geneSymbol = items[0] 45 | geneSymbol2TF[geneSymbol] = items 46 | 47 | geneSymbol2Regulated = defaultdict(list) 48 | with open(d + "inputs/trrust_rawdata.mouse.tsv") as f: 49 | for line in f: 50 | items = line[:-1].split("\t") 51 | TFSymbol = items[0] 52 | geneSymbol2Regulated[TFSymbol].append(items[1]) 53 | 54 | with open(d + "gencode.vM23.metadata.tab", "w") as fout: 55 | fout.write("\t".join([ 56 | "Accession", 57 | "AccessionVersion", 58 | "Gene", 59 | "FullName", 60 | "GeneType", 61 | "HgncID", 62 | "Chromosome", 63 | "Strand", 64 | "ChromosomeStart", 65 | "ChromosomeEnd", 66 | 
"LocusGroup", 67 | "LocusType", 68 | "Location", 69 | "LocationSortable", 70 | "Aliases", 71 | "VegaID", 72 | "UcscID", 73 | "RefseqID", 74 | "CcdsID", 75 | "UniprotID", 76 | "PubmedID", 77 | "MgdID", 78 | "RgdID", 79 | "CosmicID", 80 | "OmimID", 81 | "MirBaseID", 82 | "IsTFi (TcoF-DB)", 83 | "DnaBindingDomain", 84 | "Regulates (TRRUST)" 85 | ])) 86 | fout.write("\n") 87 | with open(d + "inputs/gencode.vM23.primary_assembly.annotation.gtf") as f: 88 | for line in f: 89 | if line.startswith("##"): 90 | continue 91 | items = line[:-1].split("\t") 92 | if items[2] != "gene": 93 | continue 94 | extra = {x.strip().split(" ")[0]: x.strip().split(" ")[1].strip('"') for x in items[8].split(";")[:-1]} 95 | enseID = extra["gene_id"].split(".")[0] 96 | geneSymbol = extra.get("gene_name", "") 97 | fout.write("\t".join([ 98 | enseID, 99 | extra["gene_id"], 100 | geneSymbol, 101 | enseID2MGI_GMC[enseID][3] if enseID in enseID2MGI_GMC else "", # full name 102 | extra["gene_type"], # gene type from gencode 103 | "", # HGNC id 104 | items[0], # Chromosome 105 | items[6], 106 | items[3], # Start 107 | items[4], # End 108 | "", # Locus group 109 | mgiID2MRK_ENSE[mgiID][8], # Locus type 110 | "", # Location 111 | "", # Location, sortable 112 | "", # Aliases 113 | enseID2mart[enseID][5] if enseID in enseID2mart else "", # VEGA id 114 | enseID2mart[enseID][4] if enseID in enseID2mart else "", # UCSC id 115 | mgiID2MRK_Seq[mgiID][12], # Refseq id 116 | enseID2mart[enseID][6] if enseID in enseID2mart else "", # CCDS id 117 | mgiID2MRK_Seq[mgiID][14], # Uniprot id 118 | "", # Pubmed id 119 | "", # MGD id 120 | "", # RGD id 121 | "", # COSMIC id 122 | "", # OMIM id 123 | "", # MIRbase id 124 | "True" if (geneSymbol in geneSymbol2TF) else "False", # IsTF? 125 | "", # DBD 126 | ",".join(geneSymbol2Regulated[geneSymbol]) # TF regulated genes 127 | ])) 128 | fout.write("\n") 129 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/mouse_download.sh: -------------------------------------------------------------------------------- 1 | mkdir inputs 2 | 3 | # Make sure these manual steps have been done: 4 | # Download "BrowseTF TcoF-DB.xlsx" from https://tools.sschmeier.com/tcof/browse/?type=tcof&species=mouse&class=all# (a button at https://tools.sschmeier.com/tcof/home/) 5 | # Open the file in Excel and save tab-separated as "inputs/TcoF-Db.tsv" 6 | # 7 | # You need to import data from BioMart using this link: 8 | # http://www.ensembl.org/biomart/martview/7c9b283e3eca26cb81449ec518f4fc14?VIRTUALSCHEMANAME=default&ATTRIBUTES=mmusculus_gene_ensembl.default.feature_page.ensembl_gene_id|mmusculus_gene_ensembl.default.feature_page.ensembl_gene_id_version|mmusculus_gene_ensembl.default.feature_page.ensembl_transcript_id|mmusculus_gene_ensembl.default.feature_page.ensembl_transcript_id_version|mmusculus_gene_ensembl.default.feature_page.ucsc|mmusculus_gene_ensembl.default.feature_page.vega_translation|mmusculus_gene_ensembl.default.feature_page.ccds&FILTERS=&VISIBLEPANEL=resultspanel 9 | # by clicking "Go" button, and saving the downloaded "mart_export.txt" file in "inputs/mart_export.txt". 
10 | # The file should contain the following columns: 11 | # Gene stable ID Gene stable ID version Transcript stable ID Transcript stable ID version UCSC Stable ID Vega translation ID CCDS ID 12 | 13 | wget ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M23/GRCm38.primary_assembly.genome.fa.gz 14 | wget ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M23/gencode.vM23.primary_assembly.annotation.gtf.gz 15 | wget https://github.com/10XGenomics/cellranger/raw/master/lib/python/cellranger/barcodes/3M-february-2018.txt.gz 16 | wget https://github.com/10XGenomics/cellranger/raw/master/lib/python/cellranger/barcodes/737K-april-2014_rc.txt 17 | wget https://github.com/10XGenomics/cellranger/raw/master/lib/python/cellranger/barcodes/737K-august-2016.txt 18 | 19 | zcat gencode.vM23.primary_assembly.annotation.gtf.gz | gawk 'OFS="\t" {if ($3=="gene") {print $1,$4-1,$5,$10,0,$7}}' | tr -d '";' > gencode.vM23.primary_assembly.annotation.bed 20 | bedtools sort -i gencode.vM23.primary_assembly.annotation.bed > gencode.vM23.primary_assembly.annotation.sorted.bed 21 | bedtools merge -i gencode.vM23.primary_assembly.annotation.sorted.bed -s -c 4 -o collapse > gencode.vM23.primary_assembly.annotation.merged.bed 22 | gunzip GRCm38.primary_assembly.genome.fa.gz 23 | bedtools getfasta -name -fi GRCm38.primary_assembly.genome.fa -bed gencode.vM23.primary_assembly.annotation.sorted.bed | sed 's/::.*//' > gencode.vM23.unspliced.fa 24 | 25 | mv 737K-april-2014_rc.txt 10xv1_whitelist.txt 26 | mv 737K-august-2016.txt 10xv2_whitelist.txt 27 | gunzip 3M-february-2018.txt.gz 28 | mv 3M-february-2018.txt 10xv3_whitelist.txt 29 | 30 | mv GRCm38.primary_assembly.genome.fa* inputs/ 31 | mv gencode.vM23.unspliced.fa inputs/ 32 | mv gencode.vM23.primary_assembly.annotation.bed inputs/ 33 | mv gencode.vM23.primary_assembly.annotation.sorted.bed inputs/ 34 | gunzip gencode.vM23.primary_assembly.annotation.gtf.gz 35 | mv gencode.vM23.primary_assembly.annotation.gtf inputs/ 36 | 37 | cd inputs 38 | wget ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_mouse/release_M23/gencode.vM23.transcripts.fa.gz 39 | wget https://www.grnpedia.org/trrust/data/trrust_rawdata.mouse.tsv 40 | gunzip gencode.vM23.transcripts.fa.gz 41 | wget http://www.informatics.jax.org/downloads/reports/MGI_Gene_Model_Coord.rpt 42 | wget http://www.informatics.jax.org/downloads/reports/MRK_ENSEMBL.rpt 43 | wget http://www.informatics.jax.org/downloads/reports/MRK_Sequence.rpt 44 | cd .. 
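# A scripted alternative to the manual Excel conversion above (a sketch, not part of the original
# pipeline): assuming pandas and openpyxl are available, the TcoF-DB spreadsheet can be converted
# to the tab-separated file that mouse_build.py reads with something like
#   python -c "import pandas as pd; pd.read_excel('BrowseTF TcoF-DB.xlsx').to_csv('inputs/TF_TcoF-DB.tsv', sep='\t', index=False)"
#
# Quick sanity check (optional): after the steps above, inputs/ should contain at least
# GRCm38.primary_assembly.genome.fa, gencode.vM23.primary_assembly.annotation.gtf,
# gencode.vM23.unspliced.fa, gencode.vM23.transcripts.fa, mart_export.txt, the three MGI reports
# (MGI_Gene_Model_Coord.rpt, MRK_ENSEMBL.rpt, MRK_Sequence.rpt), trrust_rawdata.mouse.tsv and
# the manually exported TcoF-DB table; inspect with: ls inputs/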
45 | 46 | -------------------------------------------------------------------------------- /data/data_preprocessing/hindbrain_dev/scripts/mouse_generate_fragments.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | extent = 600 # how many bases away from polya to include 4 | min_len = 90 # how many non-repeat bases required to make a transcript 5 | 6 | from typing import * 7 | from Bio import SeqIO 8 | from Bio.SeqRecord import SeqRecord 9 | from Bio.Seq import Seq 10 | 11 | def find_polys(seq: SeqRecord, c: str = "A", n: int = 15) -> List[Tuple[int, int]]: 12 | found = [] 13 | count = seq[:n].count(c) # Count occurences in the first k-mer 14 | if count >= n - 1: # We have a match 15 | found.append(0) 16 | ix = 0 17 | while ix < len(seq) - n - 1: 18 | if seq[ix] == c: # Outgoing base 19 | count -= 1 20 | if seq[ix + n] == c: # Incoming base 21 | count += 1 22 | ix += 1 23 | if count >= n - 1: # We have a match 24 | found.append(ix) 25 | 26 | sorted_by_lower_bound = [(f, f + n) for f in found] 27 | # merge intervals (https://codereview.stackexchange.com/questions/69242/merging-overlapping-intervals) 28 | merged = [] 29 | for higher in sorted_by_lower_bound: 30 | if not merged: 31 | merged.append(higher) 32 | else: 33 | lower = merged[-1] 34 | # test for intersection between lower and higher: 35 | # we know via sorting that lower[0] <= higher[0] 36 | if higher[0] <= lower[1]: 37 | upper_bound = max(lower[1], higher[1]) 38 | merged[-1] = (lower[0], upper_bound) # replace by merged interval 39 | else: 40 | merged.append(higher) 41 | return merged 42 | 43 | polyAs = {} 44 | polyTs = {} 45 | for fasta in SeqIO.parse(open("inputs/gencode.vM23.unspliced.fa"),'fasta'): 46 | gene_id = fasta.id 47 | intervals = find_polys(fasta.seq, c="A", n=14) 48 | if len(intervals) > 0: 49 | polyAs[gene_id] = intervals 50 | # Collect fragments on the opposite strand, downstream of poly-Ts (not sure if such reads really happen?) 
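    # A poly-T stretch on the sense strand corresponds to a poly-A signal on the antisense strand,
    # so the fragment of interest lies downstream of the interval; it is sliced from interval[1]
    # onward and reverse-complemented when the fragments are written out below.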
51 | intervals = find_polys(fasta.seq, c="T", n=14) 52 | if len(intervals) > 0: 53 | polyTs[gene_id] = intervals 54 | 55 | tr2g = {} 56 | with open("inputs/gencode.vM23.primary_assembly.annotation.gtf") as f: 57 | for line in f: 58 | if "\ttranscript\t" in line: 59 | items = line.split("; ") 60 | chrom, _, _, start, end, _, strand, _, gid = items[0].split("\t") 61 | gene_id = gid.split('"')[1] 62 | transcript_id = items[1].split('"')[1] 63 | gene_type = items[2].split('"')[1] 64 | gene_name = items[3].split('"')[1] 65 | tr2g[transcript_id] = (chrom, start, end, strand, gene_id, gene_type, gene_name) 66 | 67 | count = 0 68 | with open("fragments2genes.txt", "w") as ftr2g: 69 | with open("inputs/gencode.vM23.fragments.fa", "w") as fout: 70 | # Write the nascent fragments, with one partial transcript per internal poly-A/T site 71 | with open("unspliced_fragments.txt", "w") as fucapture: 72 | for fasta in SeqIO.parse(open("inputs/gencode.vM23.unspliced.fa"),'fasta'): # Note we're in the masked file now 73 | gene_id = fasta.id 74 | if gene_id in polyAs: 75 | for interval in polyAs[gene_id]: 76 | seq = str(fasta.seq[max(0, interval[0] - extent):interval[0]]) 77 | #seq = seq.translate(tr).strip("N") 78 | if len(seq) >= min_len: 79 | count += 1 80 | transcript_id = f"{gene_id}.A{interval[0]}" 81 | trseq = SeqRecord(Seq(seq), transcript_id, '', '') 82 | fout.write(trseq.format("fasta")) 83 | ftr2g.write(f"{transcript_id}\t{gene_id}\n") 84 | fucapture.write(f"{transcript_id}\n") 85 | if gene_id in polyTs: 86 | for interval in polyTs[gene_id]: 87 | seq = str(fasta.seq[interval[1]:interval[1] + extent].reverse_complement()) 88 | #seq = seq.translate(tr).strip("N") 89 | if len(seq) >= min_len: 90 | count += 1 91 | transcript_id = f"{gene_id}.T{interval[0]}" 92 | trseq = SeqRecord(Seq(seq), transcript_id, '', '') 93 | fout.write(trseq.format("fasta")) 94 | ftr2g.write(f"{transcript_id}\t{gene_id}\n") 95 | fucapture.write(f"{transcript_id}\n") 96 | # Write the mature fragments, covering the 3' end of each mature transcript 97 | with open("spliced_fragments.txt", "w") as fscapture: 98 | for fasta in SeqIO.parse(open("inputs/gencode.vM23.transcripts.fa"),'fasta'): # Note we're in the masked file now 99 | transcript_id = fasta.id.split("|")[0] 100 | gene_id = fasta.id.split("|")[1] 101 | attrs = tr2g[transcript_id] 102 | seq = str(fasta.seq[-extent:]) 103 | if len(seq) >= min_len: 104 | count += 1 105 | trseq = SeqRecord(Seq(seq), f"{transcript_id}.{count} gene_id:{attrs[4]} gene_name:{attrs[6]}", '', '') 106 | fout.write(trseq.format("fasta")) 107 | ftr2g.write(f"{transcript_id}.{count}\t{attrs[4]}\n") 108 | fscapture.write(f"{transcript_id}.{count}\n") 109 | 110 | -------------------------------------------------------------------------------- /deepvelo/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import * 2 | 3 | from . import tool as tl 4 | from . import plot as pl 5 | from . 
import pipeline as pipe 6 | -------------------------------------------------------------------------------- /deepvelo/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_loader import * 2 | from .base_model import * 3 | from .base_trainer import * 4 | -------------------------------------------------------------------------------- /deepvelo/base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | 12 | def __init__( 13 | self, 14 | dataset, 15 | batch_size, 16 | shuffle, 17 | validation_split, 18 | num_workers, 19 | collate_fn=default_collate, 20 | ): 21 | self.validation_split = validation_split 22 | self.shuffle = shuffle 23 | 24 | self.batch_idx = 0 25 | self.n_samples = len(dataset) 26 | 27 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 28 | 29 | self.init_kwargs = { 30 | "dataset": dataset, 31 | "batch_size": batch_size, 32 | "shuffle": self.shuffle, 33 | "collate_fn": collate_fn, 34 | "num_workers": num_workers, 35 | } 36 | super().__init__(sampler=self.sampler, **self.init_kwargs) 37 | 38 | def _split_sampler(self, split): 39 | if split == 0.0: 40 | return None, None 41 | 42 | idx_full = np.arange(self.n_samples) 43 | 44 | np.random.seed(0) 45 | np.random.shuffle(idx_full) 46 | 47 | if isinstance(split, int): 48 | assert split > 0 49 | assert ( 50 | split < self.n_samples 51 | ), "validation set size is configured to be larger than entire dataset." 
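            # an integer `split` is treated as an absolute number of validation samples,
            # while a float `split` (the else branch below) is treated as a fraction of the dataset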
52 | len_valid = split 53 | else: 54 | len_valid = int(self.n_samples * split) 55 | 56 | valid_idx = idx_full[0:len_valid] 57 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 58 | 59 | train_sampler = SubsetRandomSampler(train_idx) 60 | valid_sampler = SubsetRandomSampler(valid_idx) 61 | 62 | # turn off shuffle option which is mutually exclusive with sampler 63 | self.shuffle = False 64 | self.n_samples = len(train_idx) 65 | 66 | return train_sampler, valid_sampler 67 | 68 | def split_validation(self): 69 | if self.valid_sampler is None: 70 | return None 71 | else: 72 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 73 | -------------------------------------------------------------------------------- /deepvelo/base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | 11 | @abstractmethod 12 | def forward(self, *inputs): 13 | """ 14 | Forward pass logic 15 | 16 | :return: Model output 17 | """ 18 | raise NotImplementedError 19 | 20 | def __str__(self): 21 | """ 22 | Model prints with number of trainable parameters 23 | """ 24 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 25 | params = sum([np.prod(p.size()) for p in model_parameters]) 26 | return super().__str__() + "\nTrainable parameters: {}".format(params) 27 | -------------------------------------------------------------------------------- /deepvelo/base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from abc import abstractmethod 4 | from numpy import inf 5 | from deepvelo.logger import TensorboardWriter 6 | 7 | 8 | class BaseTrainer: 9 | """ 10 | Base class for all trainers 11 | """ 12 | 13 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 14 | self.config = config 15 | self.logger = config.get_logger("trainer", config["trainer"]["verbosity"]) 16 | 17 | # setup GPU device if available, move model into configured device 18 | self.device, device_ids = self._prepare_device(config["n_gpu"]) 19 | self.model = model.to(self.device) 20 | if len(device_ids) > 1: 21 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 22 | 23 | self.criterion = criterion 24 | self.metric_ftns = metric_ftns 25 | self.optimizer = optimizer 26 | 27 | cfg_trainer = config["trainer"] 28 | self.epochs = cfg_trainer["epochs"] 29 | self.save_period = cfg_trainer["save_period"] 30 | self.monitor = cfg_trainer.get("monitor", "off") 31 | 32 | # configuration to monitor model performance and save best 33 | if self.monitor == "off": 34 | self.mnt_mode = "off" 35 | self.mnt_best = 0 36 | else: 37 | self.mnt_mode, self.mnt_metric = self.monitor.split() 38 | assert self.mnt_mode in ["min", "max"] 39 | 40 | self.mnt_best = inf if self.mnt_mode == "min" else -inf 41 | self.early_stop = cfg_trainer.get("early_stop", inf) 42 | 43 | self.start_epoch = 1 44 | 45 | self.checkpoint_dir = config.save_dir 46 | 47 | # setup visualization writer instance 48 | self.writer = TensorboardWriter( 49 | config.log_dir, self.logger, cfg_trainer["tensorboard"] 50 | ) 51 | 52 | if config.resume is not None: 53 | self._resume_checkpoint(config.resume) 54 | 55 | @abstractmethod 56 | def _train_epoch(self, epoch): 57 | """ 58 | Training logic for an epoch 59 | 60 | :param epoch: Current epoch number 61 | """ 
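        # Implemented by concrete trainers (see deepvelo/trainer); train() below expects the
        # override to return a dict of metric names to values, which is merged into the epoch log.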
62 | raise NotImplementedError 63 | 64 | def train(self, callback=None, callback_freq=1): 65 | """ 66 | Full training logic 67 | """ 68 | not_improved_count = 0 69 | tik = time.time() 70 | if "mle" in self.config["loss"]["type"]: 71 | if self.config["arch"]["args"]["pred_unspliced"]: 72 | self.candidate_states = torch.cat( 73 | [ 74 | self.data_loader.dataset.Sx_sz, 75 | self.data_loader.dataset.Ux_sz, 76 | ], 77 | dim=1, 78 | ).to(self.device) 79 | else: 80 | self.candidate_states = self.data_loader.dataset.Sx_sz.to(self.device) 81 | for epoch in range(self.start_epoch, self.epochs + 1): 82 | result = self._train_epoch(epoch) 83 | 84 | # save logged informations into log dict 85 | log = {"epoch": epoch, "time:": time.time() - tik} 86 | log.update(result) 87 | tik = time.time() 88 | 89 | # print logged informations to the screen 90 | for key, value in log.items(): 91 | self.logger.info(" {:15s}: {}".format(str(key), value)) 92 | 93 | if callback is not None: 94 | if epoch % callback_freq == 0: 95 | callback(epoch) 96 | 97 | # evaluate model performance according to configured metric, save best checkpoint as model_best 98 | best = False 99 | if self.mnt_mode != "off": 100 | try: 101 | # check whether model performance improved or not, according to specified metric(mnt_metric) 102 | improved = ( 103 | self.mnt_mode == "min" and log[self.mnt_metric] <= self.mnt_best 104 | ) or ( 105 | self.mnt_mode == "max" and log[self.mnt_metric] >= self.mnt_best 106 | ) 107 | except KeyError: 108 | self.logger.warning( 109 | "Warning: Metric '{}' is not found. " 110 | "Model performance monitoring is disabled.".format( 111 | self.mnt_metric 112 | ) 113 | ) 114 | self.mnt_mode = "off" 115 | improved = False 116 | 117 | if improved: 118 | self.mnt_best = log[self.mnt_metric] 119 | not_improved_count = 0 120 | best = True 121 | else: 122 | not_improved_count += 1 123 | 124 | if not_improved_count > self.early_stop: 125 | self.logger.info( 126 | "Validation performance didn't improve for {} epochs. " 127 | "Training stops.".format(self.early_stop) 128 | ) 129 | break 130 | 131 | if epoch % self.save_period == 0: 132 | self._save_checkpoint(epoch, save_best=best) 133 | 134 | def train_with_epoch_callback(self, callback, freq): 135 | self.train(callback, freq) 136 | 137 | def _prepare_device(self, n_gpu_use): 138 | """ 139 | setup GPU device if available, move model into configured device 140 | """ 141 | n_gpu = torch.cuda.device_count() 142 | if n_gpu_use > 0 and n_gpu == 0: 143 | self.logger.warning( 144 | "Warning: There's no GPU available on this machine," 145 | "training will be performed on CPU." 
146 | ) 147 | n_gpu_use = 0 148 | if n_gpu_use > n_gpu: 149 | self.logger.warning( 150 | "Warning: The number of GPU's configured to use is {}, but only {} are available " 151 | "on this machine.".format(n_gpu_use, n_gpu) 152 | ) 153 | n_gpu_use = n_gpu 154 | device = torch.device("cuda:0" if n_gpu_use > 0 else "cpu") 155 | list_ids = list(range(n_gpu_use)) 156 | return device, list_ids 157 | 158 | def _save_checkpoint(self, epoch, save_best=False): 159 | """ 160 | Saving checkpoints 161 | 162 | :param epoch: current epoch number 163 | :param log: logging information of the epoch 164 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 165 | """ 166 | arch = type(self.model).__name__ 167 | state = { 168 | "arch": arch, 169 | "epoch": epoch, 170 | "state_dict": self.model.state_dict(), 171 | "optimizer": self.optimizer.state_dict(), 172 | "monitor_best": self.mnt_best, 173 | "config": self.config, 174 | } 175 | filename = str(self.checkpoint_dir / "checkpoint-epoch{}.pth".format(epoch)) 176 | torch.save(state, filename) 177 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 178 | if save_best: 179 | best_path = str(self.checkpoint_dir / "model_best.pth") 180 | torch.save(state, best_path) 181 | self.logger.info("Saving current best: model_best.pth ...") 182 | 183 | def _resume_checkpoint(self, resume_path): 184 | """ 185 | Resume from saved checkpoints 186 | 187 | :param resume_path: Checkpoint path to be resumed 188 | """ 189 | resume_path = str(resume_path) 190 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 191 | checkpoint = torch.load(resume_path) 192 | self.start_epoch = checkpoint["epoch"] + 1 193 | self.mnt_best = checkpoint["monitor_best"] 194 | 195 | # load architecture params from checkpoint. 196 | if checkpoint["config"]["arch"] != self.config["arch"]: 197 | self.logger.warning( 198 | "Warning: Architecture configuration given in config file is different from that of " 199 | "checkpoint. This may yield an exception while state_dict is being loaded." 200 | ) 201 | self.model.load_state_dict(checkpoint["state_dict"]) 202 | 203 | # load optimizer state from checkpoint only when optimizer type is not changed. 204 | if ( 205 | checkpoint["config"]["optimizer"]["type"] 206 | != self.config["optimizer"]["type"] 207 | ): 208 | self.logger.warning( 209 | "Warning: Optimizer type given in config file is different from that of checkpoint. " 210 | "Optimizer parameters not being resumed." 211 | ) 212 | else: 213 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 214 | 215 | self.logger.info( 216 | "Checkpoint loaded. 
Resume training from epoch {}".format(self.start_epoch) 217 | ) 218 | -------------------------------------------------------------------------------- /deepvelo/data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | # from torchvision import datasets, transforms 5 | import numpy as np 6 | import torch 7 | import dgl 8 | import hnswlib 9 | from dgl.contrib.sampling import NeighborSampler 10 | from torch.utils.data import Dataset 11 | from sklearn.metrics import pairwise_distances 12 | from sklearn.decomposition import PCA 13 | from anndata import AnnData 14 | 15 | from deepvelo.base import BaseDataLoader 16 | 17 | 18 | class VeloDataset(Dataset): 19 | def __init__( 20 | self, 21 | data_source, 22 | train=True, 23 | type="average", 24 | topC=30, 25 | topG=20, 26 | velocity_genes=False, 27 | use_scaled_u=False, 28 | ): 29 | # check if data_source is a file path or inmemory data 30 | if isinstance(data_source, str): 31 | data_source = Path(data_source) 32 | with open(data_source, "rb") as f: 33 | adata = pickle.load(f) 34 | elif isinstance(data_source, AnnData): 35 | adata = data_source 36 | else: 37 | raise ValueError("data_source must be a file path or anndata object") 38 | self.Ux_sz = adata.layers["Mu"] 39 | self.Sx_sz = adata.layers["Ms"] 40 | if velocity_genes: 41 | self.Ux_sz = self.Ux_sz[:, adata.var["velocity_genes"]] 42 | self.Sx_sz = self.Sx_sz[:, adata.var["velocity_genes"]] 43 | if use_scaled_u: 44 | scaling = np.std(self.Ux_sz, axis=0) / np.std(self.Sx_sz, axis=0) 45 | self.Ux_sz = self.Ux_sz / scaling 46 | self.connectivities = adata.obsp["connectivities"] # shape (cells, features) 47 | 48 | self.topG = topG 49 | N_cell, N_gene = self.Sx_sz.shape 50 | 51 | # build the knn graph in the original space 52 | # TODO: try using the original connectivities to build the graph 53 | if "pca" in type: 54 | n_pcas = 30 55 | pca_ = PCA( 56 | n_components=n_pcas, 57 | svd_solver="randomized", 58 | ) 59 | Sx_sz_pca = pca_.fit_transform(self.Sx_sz) 60 | if N_cell < 3000: 61 | ori_dist = pairwise_distances(Sx_sz_pca, Sx_sz_pca) 62 | self.ori_idx = np.argsort(ori_dist, axis=1)[:, :topG] # (1720, 20) 63 | self.nn_t_idx = np.argsort(ori_dist, axis=1)[:, 1:topC] 64 | else: 65 | p = hnswlib.Index(space="l2", dim=n_pcas) 66 | p.init_index(max_elements=N_cell, ef_construction=200, M=30) 67 | p.add_items(Sx_sz_pca) 68 | p.set_ef(max(topC, topG) + 10) 69 | self.ori_idx = p.knn_query(Sx_sz_pca, k=topG)[0].astype(int) 70 | self.nn_t_idx = p.knn_query(Sx_sz_pca, k=topC)[0][:, 1:].astype(int) 71 | else: 72 | raise NotImplementedError( 73 | "the argument type of VeloDataset has to include original " 74 | "distance method 'pca' or 'raw'" 75 | ) 76 | 77 | self.g = self.build_graph(self.ori_idx) 78 | self.nn_idx = self.nn_t_idx 79 | self.neighbor_time = 0 80 | 81 | # update the velocity target vectors for spliced and unspliced counts 82 | self.velo = np.zeros(self.Sx_sz.shape, dtype=np.float32) 83 | self.velo_u = np.zeros(self.Ux_sz.shape, dtype=np.float32) 84 | for i in range(N_cell): 85 | self.velo[i] = np.mean(self.Sx_sz[self.nn_idx[i]], axis=0) - self.Sx_sz[i] 86 | self.velo_u[i] = np.mean(self.Ux_sz[self.nn_idx[i]], axis=0) - self.Ux_sz[i] 87 | 88 | # build masks 89 | mask = np.ones([N_cell, N_gene]) 90 | mask[self.Ux_sz == 0] = 0 91 | mask[self.Sx_sz == 0] = 0 92 | 93 | self.Ux_sz = torch.tensor(self.Ux_sz, dtype=torch.float32) 94 | self.Sx_sz = torch.tensor(self.Sx_sz, dtype=torch.float32) 95 | 
self.velo = torch.tensor(self.velo, dtype=torch.float32) 96 | self.velo_u = torch.tensor(self.velo_u, dtype=torch.float32) 97 | self.mask = torch.tensor(mask, dtype=torch.float32) 98 | print("velo data shape:", self.velo.shape) 99 | 100 | def large_batch(self, device): 101 | """ 102 | build the large batch for training 103 | """ 104 | # check if self._large_batch is already built 105 | if hasattr(self, "_large_batch"): 106 | return self._large_batch 107 | self._large_batch = [ 108 | { 109 | "Ux_sz": self.Ux_sz.to(device), 110 | "Sx_sz": self.Sx_sz.to(device), 111 | "velo": self.velo.to(device), 112 | "velo_u": self.velo_u.to(device), 113 | "mask": self.mask.to(device), 114 | "t+1 neighbor idx": torch.tensor( 115 | self.nn_t_idx, 116 | dtype=torch.long, 117 | ).to(device), 118 | } 119 | ] 120 | return self._large_batch 121 | 122 | def __len__(self): 123 | return len(self.Ux_sz) # 1720 124 | 125 | def __getitem__(self, i): 126 | data_dict = { 127 | "Ux_sz": self.Ux_sz[i], 128 | "Sx_sz": self.Sx_sz[i], 129 | "velo": self.velo[i], 130 | "velo_u": self.velo_u[i], 131 | "mask": self.mask[i], 132 | "t+1 neighbor idx": self.nn_t_idx[i], 133 | } 134 | return data_dict 135 | 136 | def gen_neighbor_batch(self, size): 137 | indices = np.random.random_integers(0, high=len(self) - 1, size=size) 138 | # self.neighbors_per_gene is the neighbor indices for all cells, shape 139 | # (N_cells, topG, genes) 140 | 141 | # TODO(Haotian): try the per gene version 142 | # Here since the per gene version encounters the 0 gene count bug, we first 143 | # use the per cell version, which is using self.ind 144 | return self.ind[indices, : self.topG].flatten() 145 | 146 | def build_graph(self, ind): 147 | """ind (N,k) contains neighbor index""" 148 | print("building graph") 149 | g = dgl.DGLGraph() 150 | g.add_nodes(len(self.Ux_sz)) 151 | edge_list = [] 152 | for i in range(ind.shape[0]): 153 | for j in range(ind.shape[1]): 154 | edge_list.append((i, ind[i, j])) 155 | # add edges two lists of nodes: src and dst 156 | src, dst = tuple(zip(*edge_list)) 157 | g.add_edges(src, dst) 158 | # edges are directional in DGL; make them bi-directional 159 | g.add_edges(dst, src) 160 | return g 161 | 162 | 163 | class VeloDataLoader(BaseDataLoader): 164 | """ 165 | MNIST data loading demo using BaseDataLoader 166 | """ 167 | 168 | def __init__( 169 | self, 170 | data_source, 171 | batch_size, 172 | shuffle=True, 173 | validation_split=0.0, 174 | num_workers=1, 175 | training=True, 176 | type="average", 177 | topC=30, 178 | topG=16, 179 | velocity_genes=False, 180 | use_scaled_u=False, 181 | ): 182 | self.data_source = data_source 183 | self.dataset = VeloDataset( 184 | data_source, 185 | train=training, 186 | type=type, 187 | topC=topC, 188 | topG=topG, 189 | velocity_genes=velocity_genes, 190 | use_scaled_u=use_scaled_u, 191 | ) 192 | self.shuffle = shuffle 193 | self.is_large_batch = batch_size == len(self.dataset) 194 | super().__init__( 195 | self.dataset, batch_size, shuffle, validation_split, num_workers 196 | ) 197 | 198 | 199 | class VeloNeighborSampler(NeighborSampler, BaseDataLoader): 200 | """ 201 | minibatch neighbor sampler using DGL NeighborSampler 202 | """ 203 | 204 | def __init__( 205 | self, 206 | data_dir, 207 | batch_size, 208 | num_neighbors, 209 | num_hops, 210 | shuffle=True, 211 | validation_split=0.0, 212 | num_workers=32, 213 | training=True, 214 | ): 215 | self.data_dir = data_dir 216 | self.dataset = VeloDataset(self.data_dir, train=training) 217 | # FIXME: the split_validation here is not working as in 
the BaseDataLoader 218 | # BaseDataLoader.__init__(self, self.dataset, batch_size, shuffle, 219 | # validation_split, num_workers) 220 | 221 | g = self.dataset.g 222 | norm = 1.0 / g.in_degrees().float().unsqueeze(1) 223 | g.ndata["Ux_sz"] = self.dataset.Ux_sz 224 | g.ndata["Sx_sz"] = self.dataset.Sx_sz 225 | g.ndata["velo"] = self.dataset.velo 226 | g.ndata["norm"] = norm 227 | # need to set to readonly for nodeflow 228 | g.readonly() 229 | 230 | NeighborSampler.__init__( 231 | self, 232 | g, 233 | batch_size, 234 | num_neighbors, 235 | neighbor_type="in", 236 | shuffle=shuffle, 237 | num_workers=num_workers, 238 | num_hops=num_hops, 239 | # seed_nodes=train_nid 240 | ) 241 | 242 | # FIXME: the split_validation here is not working as in the BaseDataLoader 243 | def split_validation(self): 244 | return None 245 | 246 | def __len__(self): 247 | return self.dataset.__len__() 248 | 249 | 250 | if __name__ == "__main__": 251 | VeloDataset("./data/DG_norm_genes.npz") 252 | VeloNeighborSampler("./data/DG_norm_genes.npz", 32, 15, 4) 253 | -------------------------------------------------------------------------------- /deepvelo/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | from .visualization import * 3 | -------------------------------------------------------------------------------- /deepvelo/logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from deepvelo.utils import read_json 5 | 6 | 7 | def setup_logging( 8 | save_dir, log_config="logger/logger_config.json", default_level=logging.INFO 9 | ): 10 | """ 11 | Setup logging configuration 12 | """ 13 | log_config = Path(log_config) 14 | if log_config.is_file(): 15 | config = read_json(log_config) 16 | # modify logging paths based on run config 17 | for _, handler in config["handlers"].items(): 18 | if "filename" in handler: 19 | handler["filename"] = str(save_dir / handler["filename"]) 20 | 21 | logging.config.dictConfig(config) 22 | else: 23 | print( 24 | "Warning: logging configuration file is not found in {}.".format(log_config) 25 | ) 26 | logging.basicConfig(level=default_level) 27 | -------------------------------------------------------------------------------- /deepvelo/logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /deepvelo/logger/visualization.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter: 6 | def __init__(self, log_dir, 
logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = ( 26 | "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " 27 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " 28 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 29 | ) 30 | logger.warning(message) 31 | 32 | self.step = 0 33 | self.mode = "" 34 | 35 | self.tb_writer_ftns = { 36 | "add_scalar", 37 | "add_scalars", 38 | "add_image", 39 | "add_images", 40 | "add_audio", 41 | "add_text", 42 | "add_histogram", 43 | "add_pr_curve", 44 | "add_embedding", 45 | } 46 | self.tag_mode_exceptions = {"add_histogram", "add_embedding"} 47 | self.timer = datetime.now() 48 | 49 | def set_step(self, step, mode="train"): 50 | self.mode = mode 51 | self.step = step 52 | if step == 0: 53 | self.timer = datetime.now() 54 | else: 55 | duration = datetime.now() - self.timer 56 | self.add_scalar("steps_per_sec", 1 / duration.total_seconds()) 57 | self.timer = datetime.now() 58 | 59 | def __getattr__(self, name): 60 | """ 61 | If visualization is configured to use: 62 | return add_data() methods of tensorboard with additional information (step, tag) added. 63 | Otherwise: 64 | return a blank function handle that does nothing 65 | """ 66 | if name in self.tb_writer_ftns: 67 | add_data = getattr(self.writer, name, None) 68 | 69 | def wrapper(tag, data, *args, **kwargs): 70 | if add_data is not None: 71 | # add mode(train/valid) tag 72 | if name not in self.tag_mode_exceptions: 73 | tag = "{}/{}".format(tag, self.mode) 74 | add_data(tag, data, self.step, *args, **kwargs) 75 | 76 | return wrapper 77 | else: 78 | # default action for returning methods defined in this class, set_step() for instance. 79 | try: 80 | attr = object.__getattr__(name) 81 | except AttributeError: 82 | raise AttributeError( 83 | "type object '{}' has no attribute '{}'".format( 84 | self.selected_module, name 85 | ) 86 | ) 87 | return attr 88 | -------------------------------------------------------------------------------- /deepvelo/model/layers.py: -------------------------------------------------------------------------------- 1 | # the GIN layer implementation adopted from Dwivedi, Vijay Prakash, et al. 2 | # "Benchmarking graph neural networks." arXiv preprint arXiv:2003.00982 (2020). 3 | # Github implementation, https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/layers/gin_layer.py 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from deepvelo.base import BaseModel 8 | import dgl.function as fn 9 | 10 | """ 11 | GIN: Graph Isomorphism Networks 12 | HOW POWERFUL ARE GRAPH NEURAL NETWORKS? (Keyulu Xu, Weihua Hu, Jure Leskovec and Stefanie Jegelka, ICLR 2019) 13 | https://arxiv.org/pdf/1810.00826.pdf 14 | """ 15 | 16 | 17 | class GINLayer(BaseModel): 18 | """ 19 | [!] 
code adapted from dgl implementation of GINConv 20 | Parameters 21 | ---------- 22 | apply_func : callable activation function/layer or None 23 | If not None, apply this function to the updated node feature, 24 | the :math:`f_\Theta` in the formula. 25 | aggr_type : 26 | Aggregator type to use (``sum``, ``max`` or ``mean``). 27 | out_dim : 28 | Rquired for batch norm layer; should match out_dim of apply_func if not None. 29 | dropout : 30 | Required for dropout of output features. 31 | batch_norm : 32 | boolean flag for batch_norm layer. 33 | residual : 34 | boolean flag for using residual connection. 35 | init_eps : optional 36 | Initial :math:`\epsilon` value, default: ``0``. 37 | learn_eps : bool, optional 38 | If True, :math:`\epsilon` will be a learnable parameter. 39 | 40 | """ 41 | 42 | def __init__( 43 | self, 44 | apply_func, 45 | aggr_type, 46 | dropout, 47 | batch_norm, 48 | residual=False, 49 | init_eps=0, 50 | learn_eps=False, 51 | ): 52 | super().__init__() 53 | self.apply_func = apply_func 54 | 55 | if aggr_type == "sum": 56 | self._reducer = fn.sum 57 | elif aggr_type == "max": 58 | self._reducer = fn.max 59 | elif aggr_type == "mean": 60 | self._reducer = fn.mean 61 | else: 62 | raise KeyError("Aggregator type {} not recognized.".format(aggr_type)) 63 | 64 | self.batch_norm = batch_norm 65 | self.residual = residual 66 | self.dropout = dropout 67 | 68 | in_dim = apply_func.mlp.input_dim 69 | out_dim = apply_func.mlp.output_dim 70 | 71 | if in_dim != out_dim: 72 | self.residual = False 73 | 74 | # to specify whether eps is trainable or not. 75 | if learn_eps: 76 | self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps])) 77 | else: 78 | self.register_buffer("eps", torch.FloatTensor([init_eps])) 79 | 80 | self.bn_node_h = nn.BatchNorm1d(out_dim) 81 | 82 | def forward(self, g, h): 83 | h_in = h # for residual connection 84 | 85 | g = g.local_var() 86 | g.ndata["h"] = h 87 | g.update_all(fn.copy_u("h", "m"), self._reducer("m", "neigh")) 88 | h = (1 + self.eps) * h + g.ndata["neigh"] 89 | if self.apply_func is not None: 90 | h = self.apply_func(h) 91 | 92 | if self.batch_norm: 93 | h = self.bn_node_h(h) # batch normalization 94 | 95 | h = F.relu(h) # non-linear activation 96 | 97 | if self.residual: 98 | h = h_in + h # residual connection 99 | 100 | h = F.dropout(h, self.dropout, training=self.training) 101 | 102 | return h 103 | 104 | 105 | class ApplyNodeFunc(BaseModel): 106 | """ 107 | This class is used in class GINNet 108 | Update the node feature hv with MLP 109 | """ 110 | 111 | def __init__(self, mlp): 112 | super().__init__() 113 | self.mlp = mlp 114 | 115 | def forward(self, h): 116 | h = self.mlp(h) 117 | return h 118 | 119 | 120 | class MLP(BaseModel): 121 | """MLP with linear output""" 122 | 123 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 124 | 125 | super().__init__() 126 | self.linear_or_not = True # default is linear model 127 | self.num_layers = num_layers 128 | self.output_dim = output_dim 129 | self.input_dim = input_dim 130 | 131 | if num_layers < 1: 132 | raise ValueError("number of layers should be positive!") 133 | elif num_layers == 1: 134 | # Linear model 135 | self.linear = nn.Linear(input_dim, output_dim) 136 | else: 137 | # Multi-layer model 138 | self.linear_or_not = False 139 | self.linears = torch.nn.ModuleList() 140 | self.batch_norms = torch.nn.ModuleList() 141 | 142 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 143 | for layer in range(num_layers - 2): 144 | self.linears.append(nn.Linear(hidden_dim, 
hidden_dim)) 145 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 146 | 147 | for layer in range(num_layers - 1): 148 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 149 | 150 | def forward(self, x): 151 | if self.linear_or_not: 152 | # If linear model 153 | return self.linear(x) 154 | else: 155 | # If MLP 156 | h = x 157 | for i in range(self.num_layers - 1): 158 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 159 | return self.linears[-1](h) 160 | -------------------------------------------------------------------------------- /deepvelo/model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def accuracy(output, target): 5 | with torch.no_grad(): 6 | pred = torch.argmax(output, dim=1) 7 | assert pred.shape[0] == len(target) 8 | correct = 0 9 | correct += torch.sum(pred == target).item() 10 | return correct / len(target) 11 | 12 | 13 | def top_k_acc(output, target, k=3): 14 | with torch.no_grad(): 15 | pred = torch.topk(output, k, dim=1)[1] 16 | assert pred.shape[0] == len(target) 17 | correct = 0 18 | for i in range(k): 19 | correct += torch.sum(pred[:, i] == target).item() 20 | return correct / len(target) 21 | 22 | 23 | def mse(output, target): 24 | with torch.no_grad(): 25 | se = torch.mean((target - output) ** 2) 26 | return se 27 | 28 | 29 | def min_mse(output, target): 30 | """ 31 | selects one closest cell and computes the loss 32 | 33 | the target is the set of velocity target candidates, 34 | find the closest in them. 35 | 36 | output: torch.tensor e.g. (128, 2000) 37 | target: torch.tensor e.g. (128, 30, 2000) 38 | """ 39 | with torch.no_grad(): 40 | distance = torch.pow( 41 | target - torch.unsqueeze(output, 1), exponent=2 42 | ) # (128, 30, 2000) 43 | distance = torch.sum(distance, dim=2) # (128, 30) 44 | min_distance = torch.min(distance, dim=1)[0] # (128,) 45 | 46 | # loss = torch.mean(torch.max(torch.tensor(alpha).float(), min_distance)) 47 | se = torch.mean(min_distance) 48 | 49 | return se 50 | -------------------------------------------------------------------------------- /deepvelo/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from typing import Callable 8 | from deepvelo.logger import setup_logging 9 | from deepvelo.utils import read_json, write_json, validate_config 10 | 11 | 12 | class ConfigParser: 13 | def __init__( 14 | self, 15 | config, 16 | resume=None, 17 | modification=None, 18 | run_id=None, 19 | validator: Callable = validate_config, 20 | ): 21 | """ 22 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 23 | and logging module. 24 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 25 | :param resume: String, path to the checkpoint being loaded. 26 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 27 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. 
Timestamp is being used as default 28 | """ 29 | # load config file and apply modification 30 | self._config = _update_config(config, modification) 31 | self.resume = resume 32 | 33 | if validator: 34 | self._config = validator(self._config) 35 | 36 | # set save_dir where trained model and log will be saved. 37 | save_dir = Path(self.config["trainer"]["save_dir"]) 38 | 39 | exper_name = self.config["name"] 40 | if run_id is None: # use timestamp as default run-id 41 | run_id = datetime.now().strftime(r"%m%d_%H%M%S") 42 | self._save_dir = save_dir / "models" / exper_name / run_id 43 | self._log_dir = save_dir / "log" / exper_name / run_id 44 | 45 | # make directory for saving checkpoints and log. 46 | exist_ok = run_id == "" 47 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 48 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 49 | 50 | # save updated config file to the checkpoint dir 51 | write_json(self.config, self.save_dir / "config.json") 52 | 53 | # configure logging module 54 | setup_logging(self.log_dir) 55 | self.log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} 56 | 57 | @classmethod 58 | def from_args(cls, args, options=""): 59 | """ 60 | Initialize this class from some cli arguments. Used in train, test. 61 | """ 62 | for opt in options: 63 | args.add_argument(*opt.flags, default=None, type=opt.type) 64 | if not isinstance(args, tuple): 65 | args = args.parse_args() 66 | 67 | if args.device is not None: 68 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 69 | if args.resume is not None: 70 | resume = Path(args.resume) 71 | cfg_fname = resume.parent / "config.json" 72 | else: 73 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 74 | assert args.config is not None, msg_no_cfg 75 | resume = None 76 | cfg_fname = Path(args.config) 77 | 78 | config = read_json(cfg_fname) 79 | if args.config and resume: 80 | # update new config for fine-tuning 81 | config.update(read_json(args.config)) 82 | 83 | # parse custom cli options into dictionary 84 | modification = { 85 | opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options 86 | } 87 | return cls(config, resume, modification) 88 | 89 | def init_obj(self, name, module, *args, **kwargs): 90 | """ 91 | Finds a function handle with the name given as 'type' in config, and returns the 92 | instance initialized with corresponding arguments given. 93 | 94 | `object = config.init_obj('name', module, a, b=1)` 95 | is equivalent to 96 | `object = module.name(a, b=1)` 97 | """ 98 | module_name = self[name]["type"] 99 | module_args = dict(self[name]["args"]) 100 | assert all( 101 | [k not in module_args for k in kwargs] 102 | ), "Overwriting kwargs given in config file is not allowed" 103 | module_args.update(kwargs) 104 | return getattr(module, module_name)(*args, **module_args) 105 | 106 | def init_ftn(self, name, module, *args, **kwargs): 107 | """ 108 | Finds a function handle with the name given as 'type' in config, and returns the 109 | function with given arguments fixed with functools.partial. 110 | 111 | `function = config.init_ftn('name', module, a, b=1)` 112 | is equivalent to 113 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 
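        For instance (hypothetical usage, assuming the config contains a "loss" entry and the
        loss module is imported as `module_loss`): `criterion = config.init_ftn('loss', module_loss)`
        returns the configured loss function with its config arguments pre-bound.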
114 | """ 115 | module_name = self[name]["type"] 116 | module_args = dict(self[name]["args"]) 117 | assert all( 118 | [k not in module_args for k in kwargs] 119 | ), "Overwriting kwargs given in config file is not allowed" 120 | module_args.update(kwargs) 121 | return partial(getattr(module, module_name), *args, **module_args) 122 | 123 | def __getitem__(self, name): 124 | """Access items like ordinary dict.""" 125 | return self.config[name] 126 | 127 | def get_logger(self, name, verbosity=2): 128 | msg_verbosity = "verbosity option {} is invalid. Valid options are {}.".format( 129 | verbosity, self.log_levels.keys() 130 | ) 131 | assert verbosity in self.log_levels, msg_verbosity 132 | logger = logging.getLogger(name) 133 | logger.setLevel(self.log_levels[verbosity]) 134 | return logger 135 | 136 | # setting read-only attributes 137 | @property 138 | def config(self): 139 | return self._config 140 | 141 | @property 142 | def save_dir(self): 143 | return self._save_dir 144 | 145 | @property 146 | def log_dir(self): 147 | return self._log_dir 148 | 149 | 150 | # helper functions to update config dict with custom cli options 151 | def _update_config(config, modification): 152 | if modification is None: 153 | return config 154 | 155 | for k, v in modification.items(): 156 | if v is not None: 157 | _set_by_path(config, k, v) 158 | return config 159 | 160 | 161 | def _get_opt_name(flags): 162 | for flg in flags: 163 | if flg.startswith("--"): 164 | return flg.replace("--", "") 165 | return flags[0].replace("--", "") 166 | 167 | 168 | def _set_by_path(tree, keys, value): 169 | """Set a value in a nested object in tree by sequence of keys.""" 170 | keys = keys.split(";") 171 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 172 | 173 | 174 | def _get_by_path(tree, keys): 175 | """Access a nested object in tree by sequence of keys.""" 176 | return reduce(getitem, keys, tree) 177 | -------------------------------------------------------------------------------- /deepvelo/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import evaluate 2 | -------------------------------------------------------------------------------- /deepvelo/pipeline/eval.py: -------------------------------------------------------------------------------- 1 | # this is the evaluation pipeline 2 | import json 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | import warnings 6 | 7 | from anndata import AnnData 8 | 9 | from ..utils import cross_boundary_correctness, velocity_confidence 10 | from ..plot import compare_plot 11 | from ..tool import stats_test 12 | 13 | PathLike = Union[str, Path] 14 | 15 | 16 | def evaluate( 17 | result_adatas: Dict[str, AnnData], 18 | metrics: List[str] = [ 19 | "direction_score", 20 | "overall_consistency", 21 | "celltype_consistency", 22 | ], 23 | cluster_edges: Optional[List[Tuple[str]]] = None, 24 | cluster_key: str = "clusters", 25 | vkey: str = "velocity", 26 | basis: str = "umap", 27 | save_dir: Optional[PathLike] = None, 28 | ) -> Dict[str, Any]: 29 | """ 30 | Evaluate metrics and compare on provided adatas that contain the results. 31 | 32 | Args: 33 | result_adatas (dict): dictionary of adatas that contain the results. 34 | metrics (list): list of metrics to evaluate. Available metrics: `direction_score`, 35 | `overall_consistency`, `celltype_consistency`. Default: all metrics. 
36 | cluster_edges (list): list of tuples of cluster names that are considered as 37 | boundary clusters. Required if `direction_score` is in `metrics`. 38 | cluster_key (str): key of cluster annotation in `adata.obs`, default: `clusters`. 39 | vkey (str): key of velocity in `adata.layers`, default: `velocity`. 40 | save_dir (str or Path): directory to save the plots. If None, do not save plots. 41 | 42 | Returns: 43 | dict: dictionary of evaluation results. 44 | """ 45 | 46 | if "direction_score" in metrics: 47 | assert cluster_edges is not None 48 | 49 | use_methods = list(result_adatas.keys()) 50 | adata_list = [result_adatas[method] for method in use_methods] 51 | eval_results = {} 52 | 53 | if "direction_score" in metrics: 54 | eval_results["direction_score"] = {} 55 | for method in use_methods: 56 | cbcs, avg_cbc = cross_boundary_correctness( 57 | result_adatas[method], 58 | cluster_key, 59 | vkey, 60 | cluster_edges, 61 | x_emb_key=basis, # or Ms 62 | ) 63 | print( 64 | f"Average cross-boundary correctness of {method}: {avg_cbc:.2f}\n", cbcs 65 | ) 66 | all_cbcs_ = result_adatas[method].uns["raw_direction_scores"] 67 | eval_results["direction_score"][method] = { 68 | "mean": all_cbcs_.mean(), 69 | "std": all_cbcs_.std(), 70 | } 71 | 72 | ax_hist, ax_stat = compare_plot( 73 | *adata_list, 74 | labels=list(result_adatas.keys()), 75 | data=[adata.uns["raw_direction_scores"] for adata in adata_list], 76 | ylabel="Direction scores", 77 | ) 78 | if save_dir is not None: 79 | ax_hist.get_figure().savefig(save_dir / "direction_score_hist.png", dpi=300) 80 | ax_stat.get_figure().savefig(save_dir / "direction_score_comp.png", dpi=300) 81 | _, pval = stats_test(*(ad.uns["raw_direction_scores"] for ad in adata_list)) 82 | eval_results["direction_score"]["pval"] = pval 83 | 84 | # # recompute on basis 85 | # eval_results[f"{basis}_direction_score"] = {} 86 | # for method in use_methods: 87 | # cbcs, avg_cbc = cross_boundary_correctness( 88 | # result_adatas[method], 89 | # cluster_key, 90 | # vkey, 91 | # cluster_edges, 92 | # x_emb_key=basis, 93 | # output_key_prefix=f"{basis}_", 94 | # ) 95 | # print( 96 | # f"Average cross-boundary correctness of {method} on {basis}: {avg_cbc:.2f}\n", 97 | # cbcs, 98 | # ) 99 | # eval_results[f"{basis}_direction_score"][method] = avg_cbc 100 | 101 | # ax_hist, ax_stat = compare_plot( 102 | # *adata_list, 103 | # labels=list(result_adatas.keys()), 104 | # data=[adata.uns[f"{basis}_raw_direction_scores"] for adata in adata_list], 105 | # ylabel=f"{basis.upper()} direction scores", 106 | # ) 107 | # if save_dir is not None: 108 | # ax_hist.get_figure().savefig( 109 | # save_dir / f"{basis}_direction_score_hist.png", dpi=300 110 | # ) 111 | # ax_stat.get_figure().savefig( 112 | # save_dir / f"{basis}_direction_score_comp.png", dpi=300 113 | # ) 114 | # _, pval = stats_test( 115 | # *(ad.uns[f"{basis}_raw_direction_scores"] for ad in adata_list) 116 | # ) 117 | # eval_results[f"{basis}_direction_score"]["pval"] = pval 118 | 119 | if "overall_consistency" in metrics: 120 | # Compare consistency score 121 | eval_results["overall_consistency"] = {} 122 | for method in use_methods: 123 | velocity_confidence(result_adatas[method], vkey=vkey, method="cosine") 124 | mean_cosine = result_adatas[method].obs[f"{vkey}_confidence_cosine"].mean() 125 | std_cosine = result_adatas[method].obs[f"{vkey}_confidence_cosine"].std() 126 | eval_results["overall_consistency"][method] = { 127 | "mean": mean_cosine, 128 | "std": std_cosine, 129 | } 130 | ax_hist, ax_stat = 
compare_plot(*adata_list, labels=list(result_adatas.keys())) 131 | if save_dir is not None: 132 | ax_hist.get_figure().savefig( 133 | save_dir / "overall_consistency_hist.png", dpi=300 134 | ) 135 | ax_stat.get_figure().savefig( 136 | save_dir / "overall_consistency_comp.png", dpi=300 137 | ) 138 | _, pval = stats_test( 139 | *(ad.obs[f"{vkey}_confidence_cosine"] for ad in adata_list) 140 | ) 141 | eval_results["overall_consistency"]["pval"] = pval 142 | 143 | if "celltype_consistency" in metrics: 144 | eval_results["celltype_consistency"] = {} 145 | # cosine similarity, compute within Celltype 146 | for method in use_methods: 147 | velocity_confidence( 148 | result_adatas[method], vkey=vkey, method="cosine", scope_key=cluster_key 149 | ) 150 | res_cosine = result_adatas[method].obs[f"{vkey}_confidence_cosine"] 151 | if res_cosine.isna().sum() > 0: 152 | warnings.warn( 153 | f"NaN values found in adata.obs[{vkey}_confidence_cosine]. " 154 | "NaN values will be removed for calculating the average." 155 | ) 156 | res_cosine = res_cosine.dropna() 157 | eval_results["celltype_consistency"][method] = { 158 | "mean": res_cosine.mean(), 159 | "std": res_cosine.std(), 160 | } 161 | ax_hist, ax_stat = compare_plot( 162 | *adata_list, 163 | labels=list(result_adatas.keys()), 164 | ylabel="Celltype-wise consistency", 165 | ) 166 | if save_dir is not None: 167 | ax_hist.get_figure().savefig( 168 | save_dir / "celltype_consistency_hist.png", dpi=300 169 | ) 170 | ax_stat.get_figure().savefig( 171 | save_dir / "celltype_consistency_comp.png", dpi=300 172 | ) 173 | _, pval = stats_test( 174 | *(ad.obs[f"{vkey}_confidence_cosine"].dropna() for ad in adata_list) 175 | ) 176 | eval_results["celltype_consistency"]["pval"] = pval 177 | 178 | if save_dir is not None: 179 | with open(save_dir / "eval_results.json", "w") as f: 180 | json.dump(eval_results, f, indent=4) 181 | 182 | return eval_results 183 | -------------------------------------------------------------------------------- /deepvelo/plot/__init__.py: -------------------------------------------------------------------------------- 1 | from .plot import ( 2 | parula_map, 3 | parula_map_r, 4 | statplot, 5 | compare_plot, 6 | dist_plot, 7 | gene_scatter, 8 | draw_var_dist, 9 | ) 10 | from .scatter import scatter 11 | -------------------------------------------------------------------------------- /deepvelo/tool/__init__.py: -------------------------------------------------------------------------------- 1 | from .driver_gene import driver_gene 2 | from .stats import stats_test 3 | from .velocity import velocity, get_velo_genes 4 | from .kinetic_rates import process_kinetic_rates 5 | -------------------------------------------------------------------------------- /deepvelo/tool/driver_gene.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from anndata import AnnData 4 | import numpy as np 5 | from scipy.stats import pearsonr 6 | from scvelo import logging as logg 7 | from sklearn.linear_model import LassoCV 8 | 9 | 10 | def driver_gene( 11 | adata: AnnData, 12 | method: str = "lasso", # "corr", "lasso" 13 | random_seed: int = 0, 14 | expr_key: str = "Ms", 15 | do_log: bool = False, 16 | norm_along: str = "genes", # "cells", "genes" 17 | ) -> Optional[LassoCV]: 18 | 19 | if method not in ["corr", "lasso"]: 20 | raise ValueError("method must be either corr or lasso") 21 | 22 | if "velocity_pseudotime" not in adata.obs.keys(): 23 | raise ValueError( 24 | "velocity_pseudotime is not 
in adata.obs. " 25 | "Please run scv.tl.velocity_pseudotime(adata) first", 26 | ) 27 | 28 | assert expr_key in adata.layers.keys(), f"{expr_key} is not in adata.layers" 29 | if do_log: 30 | if f"log_{expr_key}" in adata.layers.keys(): 31 | logg.info(f"reuse existing 'log_{expr_key}' (adata.layers)") 32 | else: 33 | adata.layers[f"log_{expr_key}"] = np.log1p(adata.layers[expr_key]) 34 | logg.info(f"added 'log_{expr_key}' (adata.layers)") 35 | expr_key = f"log_{expr_key}" 36 | 37 | if norm_along == "cells": 38 | norm_axis = 0 39 | elif norm_along == "genes": 40 | norm_axis = 1 41 | else: 42 | raise ValueError("norm_along must be either cells or genes") 43 | 44 | if method == "corr": 45 | time_corr = [] 46 | time_corr_pval = [] 47 | for gene in adata.var_names: 48 | tcorr, pval = pearsonr( 49 | adata.obs["velocity_pseudotime"], 50 | adata[:, gene].layers[expr_key].flatten(), 51 | ) 52 | time_corr.append(tcorr) 53 | time_corr_pval.append(pval) 54 | 55 | adata.var["time_corr"] = time_corr 56 | adata.var["time_corr_abs"] = np.abs(time_corr) 57 | adata.var["time_corr_pval"] = time_corr_pval 58 | 59 | logg.info("added 'time_corr' (adata.var)") 60 | logg.info("added 'time_corr_abs' (adata.var)") 61 | logg.info("added 'time_corr_pval' (adata.var)") 62 | 63 | elif method == "lasso": 64 | exprs = adata.layers[expr_key] 65 | # normalize gene expression to standard Z-score, each row is a sample 66 | exprs = (exprs - exprs.mean(axis=norm_axis, keepdims=True)) / ( 67 | exprs.std(axis=norm_axis, keepdims=True) + 1e-8 68 | ) 69 | 70 | target_time = np.array(adata.obs["velocity_pseudotime"]).reshape(-1, 1) 71 | # make target zero mean 72 | target_time = target_time - target_time.mean() 73 | 74 | try: 75 | lasso = LassoCV(cv=5, random_state=random_seed) 76 | lasso.fit(exprs, target_time) 77 | except ValueError: 78 | # if Gram matrix error, https://github.com/scikit-learn/scikit-learn/pull/22059) 79 | lasso = LassoCV(cv=5, random_state=random_seed, precompute=False) 80 | lasso.fit(exprs, target_time) 81 | lasso_score = lasso.score(exprs, target_time) 82 | lasso_coef = lasso.coef_ 83 | lasso_coef_abs = np.abs(lasso_coef) 84 | 85 | adata.var["lasso_coef"] = lasso_coef 86 | adata.var["lasso_coef_abs"] = lasso_coef_abs 87 | 88 | logg.info("added 'lasso_coef' (adata.var)") 89 | logg.info("added 'lasso_coef_abs' (adata.var)") 90 | 91 | return lasso 92 | -------------------------------------------------------------------------------- /deepvelo/tool/kinetic_rates.py: -------------------------------------------------------------------------------- 1 | from anndata import AnnData 2 | from umap import UMAP 3 | from sklearn.decomposition import PCA 4 | import numpy as np 5 | from scvelo import logging as logg 6 | 7 | 8 | def process_kinetic_rates(adata, mode=["total_map"], seed=0): 9 | if "cell_specific_alpha" in adata.layers: 10 | all_rates = np.concatenate( 11 | [ 12 | adata.layers["cell_specific_beta"], 13 | adata.layers["cell_specific_gamma"], 14 | adata.layers["cell_specific_alpha"], 15 | ], 16 | axis=1, 17 | ) 18 | else: 19 | all_rates = np.concatenate( 20 | [ 21 | adata.layers["cell_specific_beta"], 22 | adata.layers["cell_specific_gamma"], 23 | ], 24 | axis=1, 25 | ) 26 | 27 | # pca and umap of all rates 28 | if "total_map" in mode and "X_rates_umap" not in adata.obsm: 29 | rates_pca = PCA(n_components=30, random_state=seed).fit_transform(all_rates) 30 | adata.obsm["X_rates_pca"] = rates_pca 31 | logg.info("Added `X_rates_pca` (adata.obsm)") 32 | 33 | rates_umap = UMAP( 34 | n_neighbors=60, 35 | min_dist=0.6, 36 | 
spread=0.9, 37 | random_state=seed, 38 | ).fit_transform(rates_pca) 39 | adata.obsm["X_rates_umap"] = rates_umap 40 | 41 | logg.info("Added `X_rates_umap` (adata.obsm)") 42 | 43 | # # pca and umap of gene-wise rates 44 | # if "gene_wise_map" in mode: 45 | # rates_pca_gene_wise = PCA(n_components=30, random_state=seed).fit_transform( 46 | # adata.layers["Ms"].T 47 | # ) 48 | # adata.varm["rates_pca"] = rates_pca_gene_wise 49 | 50 | # rates_umap_gene_wise = UMAP( 51 | # n_neighbors=60, 52 | # min_dist=0.6, 53 | # spread=0.9, 54 | # random_state=seed, 55 | # ).fit_transform(rates_pca_gene_wise) 56 | # adata.varm["rates_umap"] = rates_umap_gene_wise 57 | -------------------------------------------------------------------------------- /deepvelo/tool/stats.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from scipy.stats import mannwhitneyu, kruskal 4 | 5 | 6 | def stats_test( 7 | *data: List, 8 | threshold: float = 0.05, 9 | verbose: bool = True, 10 | **kwargs, 11 | ) -> Tuple[float, float]: 12 | """ 13 | Run statistical test on data. 14 | 15 | Args: 16 | data (list): list of data to compare. If len(data) == 2, use Mann-Whitney U test. 17 | Otherwise, use Kruskal-Wallis test. 18 | threshold (float): threshold for p-value. 19 | verbose (bool): whether to print the result. 20 | **kwargs: keyword arguments for `scipy.stats.mannwhitneyu` or `scipy.stats.kruskal`. 21 | 22 | Returns: 23 | tuple: (statistic, p-value) 24 | """ 25 | if len(data) == 2: 26 | stat, pval = mannwhitneyu(*data, **kwargs) 27 | else: 28 | stat, pval = kruskal(*data, **kwargs) 29 | if verbose: 30 | print(f"statistic: {stat}, p-value: {pval}") 31 | if pval < threshold: 32 | print("Significant difference. Reject null hypothesis.") 33 | else: 34 | print("Insignificant difference. 
Accept null hypothesis.") 35 | 36 | return stat, pval 37 | -------------------------------------------------------------------------------- /deepvelo/train.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Callable, Mapping 3 | import numpy as np 4 | 5 | import torch 6 | from anndata import AnnData 7 | from scvelo import logging as logg 8 | 9 | from deepvelo.trainer import Trainer 10 | import deepvelo.data_loader.data_loaders as module_data 11 | import deepvelo.model.loss as module_loss 12 | import deepvelo.model.metric as module_metric 13 | import deepvelo.model.model as module_arch 14 | from deepvelo.parse_config import ConfigParser 15 | 16 | 17 | # a hack to make constants, see https://stackoverflow.com/questions/3203286 18 | class MetaConstants(type): 19 | @property 20 | def default_configs(cls): 21 | return deepcopy(cls._default_configs) 22 | 23 | 24 | class Constants(object, metaclass=MetaConstants): 25 | _default_configs = { 26 | "name": "DeepVelo_Base", 27 | "n_gpu": 1, # whether to use GPU 28 | "arch": { 29 | "type": "VeloGCN", 30 | "args": { 31 | "layers": [64, 64], 32 | "dropout": 0.2, 33 | "fc_layer": False, 34 | "pred_unspliced": False, 35 | }, 36 | }, 37 | "data_loader": { 38 | "type": "VeloDataLoader", 39 | "args": { 40 | "shuffle": False, 41 | "validation_split": 0.0, 42 | "num_workers": 2, 43 | "type": "pca, t", 44 | "topC": 30, 45 | "topG": 20, 46 | "velocity_genes": False, 47 | "use_scaled_u": False, 48 | }, 49 | }, 50 | "optimizer": { 51 | "type": "Adam", 52 | "args": {"lr": 0.001, "weight_decay": 0, "amsgrad": True}, 53 | }, 54 | "loss": { 55 | "type": "mle_plus_direction", 56 | "args": { 57 | "pearson_scale": 18.0, 58 | "coeff_u": 1.0, 59 | "coeff_s": 1.0, 60 | "inner_batch_size": None, # if None, will autoset the size. 
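                # coeff_u and coeff_s scale the unspliced- and spliced-read terms of the
                # correlation objective; a data-dependent value for coeff_s can be
                # suggested with deepvelo.utils.preprocess.autoset_coeff_s(adata) and
                # merged into these defaults via
                # deepvelo.utils.update_dict(Constants.default_configs, {...})
                # before calling train(adata, configs).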
61 | "stop_pearson_after": 1000, # by default, set this large to avoid stopping 62 | }, 63 | }, 64 | "constraint_loss": False, 65 | "mask_zeros": False, 66 | "metrics": ["mse"], 67 | "lr_scheduler": {"type": "StepLR", "args": {"step_size": 1, "gamma": 0.97}}, 68 | "trainer": { 69 | "epochs": 100, 70 | "save_dir": "saved/", 71 | "save_period": 1000, 72 | "verbosity": 1, 73 | "monitor": "min mse", 74 | "early_stop": 1000, 75 | "tensorboard": True, 76 | "grad_clip": False, 77 | }, 78 | } 79 | 80 | 81 | def train( 82 | adata: AnnData, 83 | configs: Mapping, 84 | verbose: bool = False, 85 | return_kinetic_rates: bool = True, 86 | callback: Callable = None, 87 | **kwargs, 88 | ): 89 | batch_size, n_genes = adata.layers["Ms"].shape 90 | if configs["data_loader"]["args"]["velocity_genes"]: 91 | n_genes = int(np.sum(adata.var["velocity_genes"])) 92 | configs["arch"]["args"]["n_genes"] = n_genes 93 | configs["data_loader"]["args"]["batch_size"] = batch_size 94 | config = ConfigParser(configs) 95 | logger = config.get_logger("train") 96 | 97 | # setup data_loader instances, use adata as the data_source to load inmemory data 98 | data_loader = config.init_obj("data_loader", module_data, data_source=adata) 99 | valid_data_loader = data_loader.split_validation() 100 | 101 | # build model architecture, then print to console 102 | if config["arch"]["type"] in ["VeloGCN", "VeloGIN"]: 103 | model = config.init_obj("arch", module_arch, g=data_loader.dataset.g) 104 | else: 105 | model = config.init_obj("arch", module_arch) 106 | logger.info(f"Beginning training of {configs['name']} ...") 107 | if verbose: 108 | logger.info(configs) 109 | logger.info(model) 110 | 111 | # get function handles of loss and metrics 112 | criterion = getattr(module_loss, configs["loss"]["type"]) 113 | metrics = [getattr(module_metric, met) for met in configs["metrics"]] 114 | 115 | # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler 116 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 117 | optimizer = config.init_obj("optimizer", torch.optim, trainable_params) 118 | lr_scheduler = config.init_obj("lr_scheduler", torch.optim.lr_scheduler, optimizer) 119 | 120 | trainer = Trainer( 121 | model, 122 | criterion, 123 | metrics, 124 | optimizer, 125 | config=config, 126 | data_loader=data_loader, 127 | valid_data_loader=valid_data_loader, 128 | lr_scheduler=lr_scheduler, 129 | ) 130 | 131 | def callback_wrapper(epoch): 132 | # evaluate all and return the velocity matrix (cells, features) 133 | config_copy = configs["data_loader"]["args"].copy() 134 | config_copy.update(shuffle=False, training=False, data_source=adata) 135 | eval_loader = getattr(module_data, configs["data_loader"]["type"])( 136 | **config_copy 137 | ) 138 | velo_mat, velo_mat_u, kinetic_rates = trainer.eval( 139 | eval_loader, return_kinetic_rates=return_kinetic_rates 140 | ) 141 | 142 | if callback is not None: 143 | callback(adata, velo_mat, velo_mat_u, kinetic_rates, epoch) 144 | else: 145 | logg.warn( 146 | "Set verbose to True but no callback function provided. 
A possible " 147 | "callback function accepts at least two arguments: adata, velo_mat " 148 | ) 149 | 150 | if verbose: 151 | trainer.train_with_epoch_callback( 152 | callback=callback_wrapper, 153 | freq=kwargs.get("freq", 30), 154 | ) 155 | else: 156 | trainer.train() 157 | 158 | if configs["data_loader"]["args"]["shuffle"] == False: 159 | eval_loader = data_loader 160 | else: 161 | config_copy = configs["data_loader"]["args"].copy() 162 | config_copy.update(shuffle=False, training=False, data_source=adata) 163 | eval_loader = getattr(module_data, configs["data_loader"]["type"])( 164 | **config_copy 165 | ) 166 | velo_mat, velo_mat_u, kinetic_rates = trainer.eval( 167 | eval_loader, return_kinetic_rates=return_kinetic_rates 168 | ) 169 | 170 | print("velo_mat shape:", velo_mat.shape) 171 | # add velocity 172 | if configs["data_loader"]["args"]["velocity_genes"]: 173 | # the predictions only contain the velocity genes 174 | velocity_ = np.full(adata.shape, np.nan, dtype=velo_mat.dtype) 175 | idx = adata.var["velocity_genes"].values 176 | velocity_[:, idx] = velo_mat 177 | if len(velo_mat_u) > 0: 178 | velocity_u = np.full(adata.shape, np.nan, dtype=velo_mat.dtype) 179 | velocity_u[:, idx] = velo_mat_u 180 | else: 181 | velocity_ = velo_mat 182 | velocity_u = velo_mat_u 183 | 184 | assert adata.layers["Ms"].shape == velocity_.shape 185 | adata.layers["velocity"] = velocity_ # (cells, genes) 186 | if len(velo_mat_u) > 0: 187 | adata.layers["velocity_unspliced"] = velocity_u 188 | 189 | logg.hint(f"added 'velocity' (adata.layers)") 190 | logg.hint(f"added 'velocity_unspliced' (adata.layers)") 191 | 192 | if return_kinetic_rates: 193 | for k, v in kinetic_rates.items(): 194 | if v is not None: 195 | if configs["data_loader"]["args"]["velocity_genes"]: 196 | v_ = np.zeros(adata.shape, dtype=v.dtype) 197 | v_[:, adata.var["velocity_genes"].values] = v 198 | v = v_ 199 | adata.layers["cell_specific_" + k] = v 200 | logg.hint(f"added 'cell_specific_{k}' (adata.layers)") 201 | return trainer 202 | -------------------------------------------------------------------------------- /deepvelo/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | -------------------------------------------------------------------------------- /deepvelo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import ( 2 | ensure_dir, 3 | read_json, 4 | write_json, 5 | save_model_and_config, 6 | inf_loop, 7 | validate_config, 8 | update_dict, 9 | get_indices, 10 | get_indices_from_csr, 11 | make_dense, 12 | get_weight, 13 | R2, 14 | MetricTracker, 15 | ) 16 | from .confidence import * 17 | from .temporal import * 18 | 19 | # deprecated import velocity 20 | from ..tool.velocity import * 21 | -------------------------------------------------------------------------------- /deepvelo/utils/map_velocity_expression.py: -------------------------------------------------------------------------------- 1 | # to plot the velocity onto the tsne of gene expressions 2 | # from .. 
import settings 3 | from scvelo.preprocessing.moments import second_order_moments 4 | from scvelo.tools.rank_velocity_genes import rank_velocity_genes 5 | from .scatter import scatter 6 | from scvelo.plotting.utils import ( 7 | savefig_or_show, 8 | default_basis, 9 | default_size, 10 | get_basis, 11 | get_figure_params, 12 | ) 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import matplotlib.pyplot as pl 17 | from matplotlib import rcParams 18 | from scipy.sparse import issparse 19 | 20 | 21 | def velocity_map( 22 | adata, 23 | var_names=None, 24 | basis=None, 25 | vkey="velocity", 26 | mode=None, 27 | fits=None, 28 | layers="all", 29 | color=None, 30 | color_map=None, 31 | colorbar=True, 32 | perc=[2, 98], 33 | alpha=0.5, 34 | size=None, 35 | groupby=None, 36 | groups=None, 37 | legend_loc="none", 38 | legend_fontsize=8, 39 | use_raw=False, 40 | fontsize=None, 41 | figsize=None, 42 | dpi=None, 43 | show=None, 44 | save=None, 45 | ax=None, 46 | ncols=None, 47 | **kwargs, 48 | ): 49 | """Phase and velocity plot for set of genes. 50 | 51 | The phase plot shows spliced against unspliced expressions with steady-state fit. 52 | Further the embedding is shown colored by velocity and expression. 53 | 54 | Arguments 55 | --------- 56 | adata: :class:`~anndata.AnnData` 57 | Annotated data matrix. 58 | var_names: `str` or list of `str` (default: `None`) 59 | Which variables to show. 60 | basis: `str` (default: `'umap'`) 61 | Key for embedding coordinates. 62 | mode: `'stochastic'` or `None` (default: `None`) 63 | Whether to show show covariability phase portrait. 64 | fits: `str` or list of `str` (default: `['velocity', 'dynamics']`) 65 | Which steady-state estimates to show. 66 | layers: `str` or list of `str` (default: `'all'`) 67 | Which layers to show. 68 | color: `str`, list of `str` or `None` (default: `None`) 69 | Key for annotations of observations/cells or variables/genes 70 | color_map: `str` or tuple (default: `['RdYlGn', 'gnuplot_r']`) 71 | String denoting matplotlib color map. If tuple is given, first and latter 72 | color map correspond to velocity and expression, respectively. 73 | perc: tuple, e.g. [2,98] (default: `[2,98]`) 74 | Specify percentile for continuous coloring. 75 | groups: `str`, `list` (default: `None`) 76 | Subset of groups, e.g. [‘g1’, ‘g2’], to which the plot shall be restricted. 77 | groupby: `str`, `list` or `np.ndarray` (default: `None`) 78 | Key of observations grouping to consider. 79 | legend_loc: str (default: 'none') 80 | Location of legend, either 'on data', 'right margin' 81 | or valid keywords for matplotlib.legend. 82 | size: `float` (default: 5) 83 | Point size. 84 | alpha: `float` (default: 1) 85 | Set blending - 0 transparent to 1 opaque. 86 | fontsize: `float` (default: `None`) 87 | Label font size. 88 | figsize: tuple (default: `(7,5)`) 89 | Figure size. 90 | dpi: `int` (default: 80) 91 | Figure dpi. 92 | show: `bool`, optional (default: `None`) 93 | Show the plot, do not return axis. 94 | save: `bool` or `str`, optional (default: `None`) 95 | If `True` or a `str`, save the figure. A string is appended to the default 96 | filename. Infer the filetype if ending on {'.pdf', '.png', '.svg'}. 97 | ax: `matplotlib.Axes`, optional (default: `None`) 98 | A matplotlib axes object. Only works if plotting a single component. 99 | ncols: `int` or `None` (default: `None`) 100 | Number of columns to arange multiplots into. 
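    Examples
    --------
    A minimal sketch; velocities, the chosen embedding, and the listed genes are
    assumed to already be present in `adata`, and the gene names / cluster key
    below are placeholders:

    >>> velocity_map(adata, var_names=["Tmsb10", "Ppp3ca"], basis="tsne")
    >>> velocity_map(adata, groupby="clusters", basis="tsne", ncols=2)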
101 | 102 | """ 103 | basis = default_basis(adata) if basis is None else get_basis(adata, basis) 104 | color, color_map = kwargs.pop("c", color), kwargs.pop("cmap", color_map) 105 | if fits is None: 106 | fits = ["velocity", "dynamics"] 107 | if color_map is None: 108 | color_map = ["RdYlGn", "gnuplot_r"] 109 | 110 | if isinstance(groupby, str) and groupby in adata.obs.keys(): 111 | if ( 112 | "rank_velocity_genes" not in adata.uns.keys() 113 | or adata.uns["rank_velocity_genes"]["params"]["groupby"] != groupby 114 | ): 115 | rank_velocity_genes(adata, vkey=vkey, n_genes=10, groupby=groupby) 116 | names = np.array(adata.uns["rank_velocity_genes"]["names"].tolist()) 117 | if groups is None: 118 | var_names = names[:, 0] 119 | else: 120 | groups = [groups] if isinstance(groups, str) else groups 121 | categories = adata.obs[groupby].cat.categories 122 | idx = np.array([any([g in group for g in groups]) for group in categories]) 123 | var_names = np.hstack(names[idx, : int(10 / idx.sum())]) 124 | elif var_names is not None: 125 | if isinstance(var_names, str): 126 | var_names = [var_names] 127 | else: 128 | var_names = [var for var in var_names if var in adata.var_names] 129 | else: 130 | raise ValueError("No var_names or groups specified.") 131 | var_names = pd.unique(var_names) 132 | 133 | if use_raw or "Ms" not in adata.layers.keys(): 134 | skey, ukey = "spliced", "unspliced" 135 | else: 136 | skey, ukey = "Ms", "Mu" 137 | layers = [vkey, skey] if layers == "all" else layers 138 | layers = [layer for layer in layers if layer in adata.layers.keys() or layer == "X"] 139 | 140 | fits = list(adata.layers.keys()) if fits == "all" else fits 141 | fits = [fit for fit in fits if f"{fit}_gamma" in adata.var.keys()] + ["dynamics"] 142 | stochastic_fits = [fit for fit in fits if f"variance_{fit}" in adata.layers.keys()] 143 | 144 | nplts = 1 + len(layers) + (mode == "stochastic") * 2 145 | ncols = 1 if ncols is None else ncols 146 | nrows = int(np.ceil(len(var_names) / ncols)) 147 | ncols = int(ncols * nplts) 148 | figsize = rcParams["figure.figsize"] if figsize is None else figsize 149 | figsize, dpi = get_figure_params(figsize, dpi, ncols / 2) 150 | if ax is None: 151 | gs_figsize = (figsize[0] * ncols / 2, figsize[1] * nrows / 2) 152 | ax = pl.figure(figsize=gs_figsize, dpi=dpi) 153 | gs = pl.GridSpec(nrows, ncols, wspace=0.5, hspace=0.8) 154 | 155 | # half size, since fontsize is halved in width and height 156 | size = default_size(adata) / 2 if size is None else size 157 | fontsize = rcParams["font.size"] * 0.8 if fontsize is None else fontsize 158 | 159 | scatter_kwargs = dict(colorbar=colorbar, perc=perc, size=size, use_raw=use_raw) 160 | scatter_kwargs.update(dict(fontsize=fontsize, legend_fontsize=legend_fontsize)) 161 | 162 | for v, var in enumerate(var_names): 163 | _adata = adata[:, var] 164 | s, u = _adata.layers[skey], _adata.layers[ukey] 165 | if issparse(s): 166 | s, u = s.A, u.A 167 | 168 | # velocity and expression plots 169 | for l, layer in enumerate(layers): 170 | ax = pl.subplot(gs[v * nplts + l + 1]) 171 | title = "expression" if layer in ["X", skey] else layer 172 | # _kwargs = {} if title == 'expression' else kwargs 173 | cmap = color_map 174 | if isinstance(color_map, (list, tuple)): 175 | cmap = color_map[-1] if layer in ["X", skey] else color_map[0] 176 | scatter( 177 | adata, 178 | basis=basis, 179 | color=var, 180 | layer=layer, 181 | title=title, 182 | color_map=cmap, 183 | alpha=0.1, # alpha, 184 | frameon=False, 185 | show=False, 186 | ax=ax, 187 | save=False, 188 | 
**scatter_kwargs, 189 | **kwargs, 190 | ) 191 | 192 | savefig_or_show(dpi=dpi, save=save, show=show) 193 | if show is False: 194 | return ax 195 | -------------------------------------------------------------------------------- /deepvelo/utils/optimization.py: -------------------------------------------------------------------------------- 1 | from scvelo.tools.utils import sum_obs, prod_sum_obs, make_dense 2 | from scipy.optimize import minimize 3 | from scipy.sparse import csr_matrix, issparse 4 | import numpy as np 5 | import warnings 6 | 7 | 8 | def get_weight(x, y=None, perc=95): 9 | xy_norm = np.array(x.A if issparse(x) else x) 10 | if y is not None: 11 | if issparse(y): 12 | y = y.A 13 | xy_norm = xy_norm / np.clip(np.max(xy_norm, axis=0), 1e-3, None) 14 | xy_norm += y / np.clip(np.max(y, axis=0), 1e-3, None) 15 | if isinstance(perc, int): 16 | weights = xy_norm >= np.percentile(xy_norm, perc, axis=0) 17 | else: 18 | lb, ub = np.percentile(xy_norm, perc, axis=0) 19 | weights = (xy_norm <= lb) | (xy_norm >= ub) 20 | return weights 21 | 22 | 23 | def leastsq_NxN( 24 | x, y, fit_offset=False, perc=None, constraint_positive_offset=True, mask_zero=False 25 | ): 26 | """Solves least squares X*b=Y for b.""" 27 | if perc is not None: 28 | if not fit_offset and isinstance(perc, (list, tuple)): 29 | perc = perc[1] 30 | weights = csr_matrix(get_weight(x, y, perc=perc)).astype(bool) 31 | x, y = weights.multiply(x).tocsr(), weights.multiply(y).tocsr() 32 | else: 33 | weights = None 34 | 35 | # mask zero 36 | if mask_zero: 37 | x[y == 0] = 0 38 | 39 | with warnings.catch_warnings(): 40 | warnings.simplefilter("ignore") 41 | xx_ = prod_sum_obs(x, x) 42 | xy_ = prod_sum_obs(x, y) 43 | 44 | if fit_offset: 45 | n_obs = x.shape[0] if weights is None else sum_obs(weights) 46 | x_ = sum_obs(x) / n_obs 47 | y_ = sum_obs(y) / n_obs 48 | gamma = (xy_ / n_obs - x_ * y_) / (xx_ / n_obs - x_**2) 49 | offset = y_ - gamma * x_ 50 | 51 | # fix negative offsets: 52 | if constraint_positive_offset: 53 | idx = offset < 0 54 | if gamma.ndim > 0: 55 | gamma[idx] = xy_[idx] / xx_[idx] 56 | else: 57 | gamma = xy_ / xx_ 58 | offset = np.clip(offset, 0, None) 59 | else: 60 | gamma = xy_ / xx_ 61 | offset = np.zeros(x.shape[1]) if x.ndim > 1 else 0 62 | nans_offset, nans_gamma = np.isnan(offset), np.isnan(gamma) 63 | if np.any([nans_offset, nans_gamma]): 64 | offset[np.isnan(offset)], gamma[np.isnan(gamma)] = 0, 0 65 | return offset, gamma 66 | 67 | 68 | leastsq = leastsq_NxN 69 | 70 | 71 | def optimize_NxN(x, y, fit_offset=False, perc=None): 72 | """Just to compare with closed-form solution""" 73 | if perc is not None: 74 | if not fit_offset and isinstance(perc, (list, tuple)): 75 | perc = perc[1] 76 | weights = get_weight(x, y, perc).astype(bool) 77 | if issparse(weights): 78 | weights = weights.A 79 | else: 80 | weights = None 81 | 82 | x, y = x.astype(np.float64), y.astype(np.float64) 83 | 84 | n_vars = x.shape[1] 85 | offset, gamma = np.zeros(n_vars), np.zeros(n_vars) 86 | 87 | for i in range(n_vars): 88 | xi = x[:, i] if weights is None else x[:, i][weights[:, i]] 89 | yi = y[:, i] if weights is None else y[:, i][weights[:, i]] 90 | 91 | if fit_offset: 92 | offset[i], gamma[i] = minimize( 93 | lambda m: np.sum((-yi + xi * m[1] + m[0]) ** 2), 94 | method="L-BFGS-B", 95 | x0=(0, 0.1), 96 | bounds=[(0, None), (None, None)], 97 | ).x 98 | else: 99 | gamma[i] = minimize( 100 | lambda m: np.sum((-yi + xi * m) ** 2), x0=0.1, method="L-BFGS-B" 101 | ).x 102 | offset[np.isnan(offset)], gamma[np.isnan(gamma)] = 0, 0 103 | 
return offset, gamma 104 | 105 | 106 | def leastsq_generalized( 107 | x, 108 | y, 109 | x2, 110 | y2, 111 | res_std=None, 112 | res2_std=None, 113 | fit_offset=False, 114 | fit_offset2=False, 115 | perc=None, 116 | ): 117 | """Solution to the 2-dim generalized least squares: gamma = inv(X'QX)X'QY""" 118 | if perc is not None: 119 | if not fit_offset and isinstance(perc, (list, tuple)): 120 | perc = perc[1] 121 | weights = csr_matrix( 122 | get_weight(x, y, perc=perc) | get_weight(x, perc=perc) 123 | ).astype(bool) 124 | x, y = weights.multiply(x).tocsr(), weights.multiply(y).tocsr() 125 | # x2, y2 = weights.multiply(x2).tocsr(), weights.multiply(y2).tocsr() 126 | 127 | n_obs, n_var = x.shape 128 | offset, offset_ss = ( 129 | np.zeros(n_var, dtype="float32"), 130 | np.zeros(n_var, dtype="float32"), 131 | ) 132 | gamma = np.ones(n_var, dtype="float32") 133 | 134 | if (res_std is None) or (res2_std is None): 135 | res_std, res2_std = np.ones(n_var), np.ones(n_var) 136 | ones, zeros = np.ones(n_obs), np.zeros(n_obs) 137 | 138 | with warnings.catch_warnings(): 139 | warnings.simplefilter("ignore") 140 | x, y = ( 141 | np.vstack((make_dense(x) / res_std, x2 / res2_std)), 142 | np.vstack((make_dense(y) / res_std, y2 / res2_std)), 143 | ) 144 | 145 | if fit_offset and fit_offset2: 146 | for i in range(n_var): 147 | A = np.c_[ 148 | np.vstack( 149 | (np.c_[ones / res_std[i], zeros], np.c_[zeros, ones / res2_std[i]]) 150 | ), 151 | x[:, i], 152 | ] 153 | offset[i], offset_ss[i], gamma[i] = np.linalg.pinv(A.T.dot(A)).dot( 154 | A.T.dot(y[:, i]) 155 | ) 156 | elif fit_offset: 157 | for i in range(n_var): 158 | A = np.c_[np.hstack((ones / res_std[i], zeros)), x[:, i]] 159 | offset[i], gamma[i] = np.linalg.pinv(A.T.dot(A)).dot(A.T.dot(y[:, i])) 160 | elif fit_offset2: 161 | for i in range(n_var): 162 | A = np.c_[np.hstack((zeros, ones / res2_std[i])), x[:, i]] 163 | offset_ss[i], gamma[i] = np.linalg.pinv(A.T.dot(A)).dot(A.T.dot(y[:, i])) 164 | else: 165 | for i in range(n_var): 166 | A = np.c_[x[:, i]] 167 | gamma[i] = np.linalg.pinv(A.T.dot(A)).dot(A.T.dot(y[:, i])) 168 | 169 | offset[np.isnan(offset)] = 0 170 | offset_ss[np.isnan(offset_ss)] = 0 171 | gamma[np.isnan(gamma)] = 0 172 | 173 | return offset, offset_ss, gamma 174 | 175 | 176 | def maximum_likelihood(Ms, Mu, Mus, Mss, fit_offset=False, fit_offset2=False): 177 | """Maximizing the log likelihood using weights according to empirical bayes""" 178 | n_obs, n_var = Ms.shape 179 | offset = np.zeros(n_var, dtype="float32") 180 | offset_ss = np.zeros(n_var, dtype="float32") 181 | gamma = np.ones(n_var, dtype="float32") 182 | 183 | def sse(A, data, b): 184 | sigma = (A.dot(data) - b).std(1) 185 | return np.log(sigma).sum() 186 | 187 | if fit_offset and fit_offset2: 188 | for i in range(n_var): 189 | data = np.vstack((Mu[:, i], Ms[:, i], Mus[:, i], Mss[:, i])) 190 | offset[i], offset_ss[i], gamma[i] = minimize( 191 | lambda m: sse( 192 | np.array([[1, -m[2], 0, 0], [1, m[2], 2, -2 * m[2]]]), 193 | data, 194 | b=np.array(m[0], m[1]), 195 | ), 196 | x0=(1e-4, 1e-4, 1), 197 | method="L-BFGS-B", 198 | ).x 199 | elif fit_offset: 200 | for i in range(n_var): 201 | data = np.vstack((Mu[:, i], Ms[:, i], Mus[:, i], Mss[:, i])) 202 | offset[i], gamma[i] = minimize( 203 | lambda m: sse( 204 | np.array([[1, -m[1], 0, 0], [1, m[1], 2, -2 * m[1]]]), 205 | data, 206 | b=np.array(m[0], 0), 207 | ), 208 | x0=(1e-4, 1), 209 | method="L-BFGS-B", 210 | ).x 211 | elif fit_offset2: 212 | for i in range(n_var): 213 | data = np.vstack((Mu[:, i], Ms[:, i], Mus[:, i], Mss[:, 
i])) 214 | offset_ss[i], gamma[i] = minimize( 215 | lambda m: sse( 216 | np.array([[1, -m[1], 0, 0], [1, m[1], 2, -2 * m[1]]]), 217 | data, 218 | b=np.array(0, m[0]), 219 | ), 220 | x0=(1e-4, 1), 221 | method="L-BFGS-B", 222 | ).x 223 | else: 224 | for i in range(n_var): 225 | data = np.vstack((Mu[:, i], Ms[:, i], Mus[:, i], Mss[:, i])) 226 | gamma[i] = minimize( 227 | lambda m: sse(np.array([[1, -m, 0, 0], [1, m, 2, -2 * m]]), data, b=0), 228 | x0=gamma[i], 229 | method="L-BFGS-B", 230 | ).x 231 | return offset, offset_ss, gamma 232 | -------------------------------------------------------------------------------- /deepvelo/utils/plot.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | with warnings.catch_warnings(): 4 | warnings.simplefilter("always") 5 | warnings.warn( 6 | "deepvelo.utils.plot is deprecated. Please use deepvelo.plot.plot instead.", 7 | DeprecationWarning, 8 | ) 9 | 10 | from ..plot.plot import * 11 | -------------------------------------------------------------------------------- /deepvelo/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import numpy as np 4 | import scvelo as scv 5 | from scvelo import logging as logg 6 | from scvelo.core import sum as sum_ 7 | from anndata import AnnData 8 | 9 | from deepvelo.utils.plot import dist_plot 10 | 11 | 12 | def clip_and_norm_Ms_Mu( 13 | adata, 14 | do_clip: bool = True, 15 | do_norm: bool = True, 16 | target_mean: float = 0.4, 17 | replace: bool = False, 18 | save_fig: Optional[str] = None, 19 | plot: bool = True, 20 | print_summary: bool = True, 21 | ) -> Tuple[float, float]: 22 | """ 23 | Normalize using the mean and standard deviation of the gene expression matrix. 24 | 25 | Args: 26 | adata (Anndata): Anndata object. 27 | target_mean (float): target mean. 28 | replace (bool): replace the original data. 29 | save_fig (str): directory to save figures. 30 | plot (bool): plot the distribution of the normalized data. 31 | print_summary (bool): print the summary of the normalized data. 32 | 33 | Returns: 34 | Tupel[float, float]: scale factor for Ms and Mu. 
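    Example:
        A minimal sketch, assuming `adata` already carries the "Ms"/"Mu" moment
        layers (e.g. computed with scv.pp.moments):

            scale_ms, scale_mu = clip_and_norm_Ms_Mu(
                adata,
                do_clip=True,    # clip to the 99.5% quantile of non-zero values
                do_norm=True,    # rescale the clipped layers to mean `target_mean`
                target_mean=0.4,
                replace=False,   # keep results in "NMs"/"NMu" rather than overwriting "Ms"/"Mu"
                plot=False,
            )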
35 | """ 36 | non_zero_Ms = adata.layers["Ms"][adata.layers["Ms"] > 0] 37 | non_zero_Mu = adata.layers["Mu"][adata.layers["Mu"] > 0] 38 | if print_summary: 39 | print( 40 | f"Raw Ms: mean {adata.layers['Ms'].mean():.2f}," 41 | f" max {adata.layers['Ms'].max():.2f}," 42 | f" std {adata.layers['Ms'].std():.2f}," 43 | f" 99.5% quantile {np.percentile(adata.layers['Ms'], 99.5):.2f}" 44 | f" 99.5% of non-zero: {np.percentile(non_zero_Ms, 99.5):.2f}" 45 | ) 46 | print( 47 | f"Raw Mu: mean {adata.layers['Mu'].mean():.2f}," 48 | f" max {adata.layers['Mu'].max():.2f}," 49 | f" std {adata.layers['Mu'].std():.2f}," 50 | f" 99.5% quantile {np.percentile(adata.layers['Mu'], 99.5):.2f}" 51 | f" 99.5% of non-zero: {np.percentile(non_zero_Mu, 99.5):.2f}" 52 | ) 53 | 54 | if do_clip: 55 | # clip the max value to 99.5% quantile 56 | adata.layers["NMs"] = np.clip( 57 | adata.layers["Ms"], None, np.percentile(non_zero_Ms, 99.5) 58 | ) 59 | adata.layers["NMu"] = np.clip( 60 | adata.layers["Mu"], None, np.percentile(non_zero_Mu, 99.5) 61 | ) 62 | else: 63 | adata.layers["NMs"] = adata.layers["Ms"] 64 | adata.layers["NMu"] = adata.layers["Mu"] 65 | logg.hint(f"added 'NMs' (adata.layers)") 66 | logg.hint(f"added 'NMu' (adata.layers)") 67 | 68 | if plot: 69 | dist_plot( 70 | adata.layers["NMs"].flatten(), 71 | adata.layers["NMu"].flatten(), 72 | bins=20, 73 | labels=["NMs", "NMu"], 74 | title="Distribution of Ms and Mu", 75 | save=f"{save_fig}/hist-Ms-Mu.png" if save_fig is not None else None, 76 | ) 77 | 78 | scale_Ms, scale_Mu = 1.0, 1.0 79 | if do_norm: 80 | scale_Ms = adata.layers["NMs"].mean() / target_mean 81 | scale_Mu = adata.layers["NMu"].mean() / target_mean 82 | adata.layers["NMs"] = adata.layers["NMs"] / scale_Ms 83 | adata.layers["NMu"] = adata.layers["NMu"] / scale_Mu 84 | print(f"Normalized Ms and Mu to mean of {target_mean}") 85 | if plot: 86 | ax = scv.pl.hist( 87 | [adata.layers["NMs"].flatten(), adata.layers["NMu"].flatten()], 88 | labels=["NMs", "NMu"], 89 | kde=False, 90 | normed=False, 91 | bins=20, 92 | # xlim=[0, 1], 93 | fontsize=18, 94 | legend_fontsize=16, 95 | show=False, 96 | ) 97 | if save_fig is not None: 98 | ax.get_figure().savefig(f"{save_fig}/hist-normed-Ms-Mu.png") 99 | 100 | if print_summary: 101 | print( 102 | f"New Ms: mean {adata.layers['NMs'].mean():.2f}," 103 | f" max {adata.layers['NMs'].max():.2f}," 104 | f" std {adata.layers['NMs'].std():.2f}," 105 | f" 99.5% quantile {np.percentile(adata.layers['NMs'], 99.5):.2f}" 106 | ) 107 | print( 108 | f"New Mu: mean {adata.layers['NMu'].mean():.2f}," 109 | f" max {adata.layers['NMu'].max():.2f}," 110 | f" std {adata.layers['NMu'].std():.2f}," 111 | f" 99.5% quantile {np.percentile(adata.layers['NMu'], 99.5):.2f}" 112 | ) 113 | 114 | if replace: 115 | adata.layers["Ms"] = adata.layers["NMs"] 116 | adata.layers["Mu"] = adata.layers["NMu"] 117 | logg.hint(f"replaced 'Ms' (adata.layers) with 'NMs'") 118 | logg.hint(f"replaced 'Mu' (adata.layers) with 'NMu'") 119 | 120 | return scale_Ms, scale_Mu 121 | 122 | 123 | def autoset_coeff_s(adata: AnnData, use_raw: bool = True) -> float: 124 | """ 125 | Automatically set the weighting for objective term of the spliced 126 | read correlation. Modified from the scv.pl.proportions function. 127 | 128 | Args: 129 | adata (Anndata): Anndata object. 130 | use_raw (bool): use raw data or processed data. 
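        Note: the suggested value follows the step rule implemented below, based
            on the fraction of spliced reads: below 70% -> 0.5, between 70% and
            85% -> 0.75, above 85% -> 1.0.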
131 | 132 | Returns: 133 | float: weighting coefficient for objective term of the unpliced read 134 | """ 135 | layers = ["spliced", "unspliced", "ambigious"] 136 | layers_keys = [key for key in layers if key in adata.layers.keys()] 137 | counts_layers = [sum_(adata.layers[key], axis=1) for key in layers_keys] 138 | 139 | if use_raw: 140 | ikey, obs = "initial_size_", adata.obs 141 | counts_layers = [ 142 | obs[ikey + layer_key] if ikey + layer_key in obs.keys() else c 143 | for layer_key, c in zip(layers_keys, counts_layers) 144 | ] 145 | counts_total = np.sum(counts_layers, 0) 146 | counts_total += counts_total == 0 147 | counts_layers = np.array([counts / counts_total for counts in counts_layers]) 148 | counts_layers = np.mean(counts_layers, axis=1) 149 | 150 | spliced_counts = counts_layers[layers_keys.index("spliced")] 151 | ratio = spliced_counts / counts_layers.sum() 152 | 153 | if ratio < 0.7: 154 | coeff_s = 0.5 155 | print( 156 | f"The ratio of spliced reads is {ratio*100:.1f}% (less than 70%). " 157 | f"Suggest using coeff_s {coeff_s}." 158 | ) 159 | elif ratio < 0.85: 160 | coeff_s = 0.75 161 | print( 162 | f"The ratio of spliced reads is {ratio*100:.1f}% (between 70% and 85%). " 163 | f"Suggest using coeff_s {coeff_s}." 164 | ) 165 | else: 166 | coeff_s = 1.0 167 | print( 168 | f"The ratio of spliced reads is {ratio*100:.1f}% (more than 85%). " 169 | f"Suggest using coeff_s {coeff_s}." 170 | ) 171 | 172 | return coeff_s 173 | -------------------------------------------------------------------------------- /deepvelo/utils/scatter.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | with warnings.catch_warnings(): 4 | warnings.simplefilter("always") 5 | warnings.warn( 6 | "deepvelo.utils.scatter is deprecated. Please use deepvelo.plot.scatter instead.", 7 | DeprecationWarning, 8 | ) 9 | 10 | from ..plot.scatter import * 11 | -------------------------------------------------------------------------------- /deepvelo/utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import torch 7 | from pathlib import Path 8 | from itertools import repeat 9 | from collections import OrderedDict 10 | from collections.abc import Mapping 11 | from scvelo.core import l2_norm, prod_sum, sum 12 | 13 | 14 | def ensure_dir(dirname): 15 | dirname = Path(dirname) 16 | if not dirname.is_dir(): 17 | dirname.mkdir(parents=True, exist_ok=False) 18 | return dirname 19 | 20 | 21 | def read_json(fname): 22 | fname = Path(fname) 23 | with fname.open("rt") as handle: 24 | return json.load(handle, object_hook=OrderedDict) 25 | 26 | 27 | def write_json(content, fname): 28 | fname = Path(fname) 29 | with fname.open("wt") as handle: 30 | json.dump(content, handle, indent=4, sort_keys=False) 31 | 32 | 33 | def save_model_and_config(model, config, directory): 34 | """saves model and config to directory.""" 35 | directory = ensure_dir(directory) 36 | torch.save(model.state_dict(), directory / "model.pt") 37 | write_json(config, directory / "config.json") 38 | 39 | 40 | def inf_loop(data_loader): 41 | """wrapper function for endless data loader.""" 42 | for loader in repeat(data_loader): 43 | yield from loader 44 | 45 | 46 | def validate_config(config: Mapping) -> Mapping: 47 | """ 48 | Return config if it is valid, otherwise raise an error. 
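    A minimal usage sketch (assuming the packaged default configuration):

        from deepvelo import Constants
        from deepvelo.utils import validate_config

        configs = validate_config(Constants.default_configs)
        # if n_gpu > 0 but the CUDA build of DGL is unavailable, n_gpu is reset to 0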
49 | """ 50 | 51 | # check if the gpu verion of dgl is installed 52 | if config["n_gpu"] > 0: 53 | import dgl 54 | 55 | try: 56 | dgl.graph([]).to("cuda") 57 | except dgl.DGLError: 58 | print( 59 | "Config Warning: Set to use GPU, but GPU version of DGL is not " 60 | "installed. Reset to use CPU instead." 61 | ) 62 | config["n_gpu"] = 0 63 | 64 | return config 65 | 66 | 67 | def update_dict(d: Dict, u: Mapping, copy=False): 68 | """recursively updates nested dict with values from u.""" 69 | if copy: 70 | d = d.copy() 71 | for k, v in u.items(): 72 | if isinstance(v, Mapping): 73 | r = update_dict(d.get(k, {}), v, copy) 74 | d[k] = r 75 | else: 76 | d[k] = u[k] 77 | return d 78 | 79 | 80 | def get_indices(dist, n_neighbors=None, mode_neighbors="distances"): 81 | from scvelo.preprocessing.neighbors import compute_connectivities_umap 82 | 83 | D = dist.copy() 84 | D.data += 1e-6 85 | 86 | n_counts = sum(D > 0, axis=1) 87 | n_neighbors = ( 88 | n_counts.min() if n_neighbors is None else min(n_counts.min(), n_neighbors) 89 | ) 90 | rows = np.where(n_counts > n_neighbors)[0] 91 | cumsum_neighs = np.insert(n_counts.cumsum(), 0, 0) 92 | dat = D.data 93 | 94 | for row in rows: 95 | n0, n1 = cumsum_neighs[row], cumsum_neighs[row + 1] 96 | rm_idx = n0 + dat[n0:n1].argsort()[n_neighbors:] 97 | dat[rm_idx] = 0 98 | D.eliminate_zeros() 99 | 100 | D.data -= 1e-6 101 | if mode_neighbors == "distances": 102 | indices = D.indices.reshape((-1, n_neighbors)) 103 | elif mode_neighbors == "connectivities": 104 | knn_indices = D.indices.reshape((-1, n_neighbors)) 105 | knn_distances = D.data.reshape((-1, n_neighbors)) 106 | _, conn = compute_connectivities_umap( 107 | knn_indices, knn_distances, D.shape[0], n_neighbors 108 | ) 109 | indices = get_indices_from_csr(conn) 110 | return indices, D 111 | 112 | 113 | def get_indices_from_csr(conn): 114 | # extracts indices from connectivity matrix, pads with nans 115 | ixs = np.ones((conn.shape[0], np.max((conn > 0).sum(1)))) * np.nan 116 | for i in range(ixs.shape[0]): 117 | cell_indices = conn[i, :].indices 118 | ixs[i, : len(cell_indices)] = cell_indices 119 | return ixs 120 | 121 | 122 | def make_dense(X): 123 | from scipy.sparse import issparse 124 | 125 | XA = X.A if issparse(X) and X.ndim == 2 else X.A1 if issparse(X) else X 126 | if XA.ndim == 2: 127 | XA = XA[0] if XA.shape[0] == 1 else XA[:, 0] if XA.shape[1] == 1 else XA 128 | return np.array(XA) 129 | 130 | 131 | def get_weight(x, y=None, perc=95): 132 | from scipy.sparse import issparse 133 | 134 | xy_norm = np.array(x.A if issparse(x) else x) 135 | if y is not None: 136 | if issparse(y): 137 | y = y.A 138 | xy_norm = xy_norm / np.clip(np.max(xy_norm, axis=0), 1e-3, None) 139 | xy_norm += y / np.clip(np.max(y, axis=0), 1e-3, None) 140 | 141 | if isinstance(perc, int): 142 | weights = xy_norm >= np.percentile(xy_norm, perc, axis=0) 143 | else: 144 | lb, ub = np.percentile(xy_norm, perc, axis=0) 145 | weights = (xy_norm <= lb) | (xy_norm >= ub) 146 | 147 | return weights 148 | 149 | 150 | def R2(residual, total): 151 | r2 = np.ones(residual.shape[1]) - np.sum(residual * residual, axis=0) / np.sum( 152 | total * total, axis=0 153 | ) 154 | r2[np.isnan(r2)] = 0 155 | return r2 156 | 157 | 158 | class MetricTracker: 159 | def __init__(self, *keys, writer=None): 160 | self.writer = writer 161 | self._data = pd.DataFrame(index=keys, columns=["total", "counts", "average"]) 162 | self.reset() 163 | 164 | def reset(self): 165 | for col in self._data.columns: 166 | self._data[col].values[:] = 0 167 | 168 | def 
update(self, key, value, n=1): 169 | if self.writer is not None: 170 | self.writer.add_scalar(key, value) 171 | self._data.total[key] += value * n 172 | self._data.counts[key] += n 173 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 174 | 175 | def avg(self, key): 176 | return self._data.average[key] 177 | 178 | def result(self): 179 | return dict(self._data.average) 180 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = "DeepVelo" 10 | copyright = "2024, bowang-lab" 11 | author = "bowang-lab" 12 | release = "0.2.8" 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [ 18 | "sphinx.ext.autodoc", 19 | "sphinx.ext.autosummary", 20 | "myst_parser", 21 | ] 22 | 23 | templates_path = ["_templates"] 24 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 25 | 26 | 27 | # -- Options for HTML output ------------------------------------------------- 28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 29 | 30 | html_theme = "sphinx_rtd_theme" 31 | html_static_path = ["_static"] 32 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. DeepVelo documentation master file, created by 2 | sphinx-quickstart on Sat Jan 20 02:43:30 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to DeepVelo's documentation! 7 | ==================================== 8 | 9 | .. 
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Analysis scripts for DeepVelo 2 | 3 | This folder contains the code necessary to reproduce the analysis for the DeepVelo manuscript 4 | 5 | ### Instructions: 6 | 7 | Most notebooks can run directly and download necessary data automatically. For running python notebooks using the hindbrain data or running the R scripts: 8 | 9 | 1. Please download the necessary files from https://doi.org/10.6084/m9.figshare.24716592: 10 | 11 | ``` 12 | wget -O deepvelo_data.tar.gz https://figshare.com/ndownloader/files/43428348 13 | ``` 14 | 15 | 2. Untar and extract the main directory 16 | 17 | ``` 18 | tar -xzvf deepvelo_data.tar.gz 19 | ``` 20 | 21 | **For R scripts** - please change directories into the `r_analysis` subfolder and follow the instructions 22 | to download and install the conda environment for generating the R figures and analyses. 23 | -------------------------------------------------------------------------------- /examples/minimal_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Let's start by loading the necessary libraries, and setting some configuration options for reproducibility and visualization." 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "/h/hmaan/miniconda3/envs/dvelo_py_38/lib/python3.8/site-packages/deepvelo/utils/plot.py:5: DeprecationWarning: deepvelo.utils.plot is deprecated. 
Please use deepvelo.plot.plot instead.\n", 20 | " warnings.warn(\n" 21 | ] 22 | } 23 | ], 24 | "source": [ 25 | "# %%\n", 26 | "#Import top level libraries, including the deepvelo package\n", 27 | "import numpy as np\n", 28 | "import scvelo as scv\n", 29 | "import torch\n", 30 | "\n", 31 | "from deepvelo.utils import velocity, update_dict\n", 32 | "from deepvelo.utils.preprocess import autoset_coeff_s\n", 33 | "from deepvelo.utils.plot import statplot, compare_plot\n", 34 | "from deepvelo import train, Constants\n", 35 | "\n", 36 | "# fix random seeds for reproducibility\n", 37 | "SEED = 123\n", 38 | "torch.manual_seed(SEED)\n", 39 | "torch.backends.cudnn.deterministic = True\n", 40 | "torch.backends.cudnn.benchmark = False\n", 41 | "np.random.seed(SEED)\n", 42 | "\n", 43 | "# set options for for visualization and verbosity\n", 44 | "scv.settings.verbosity = 3 # show errors(0), warnings(1), info(2), hints(3)\n", 45 | "scv.settings.set_figure_params(\n", 46 | " \"scvelo\", transparent=False\n", 47 | ") # for beautified visualization\n", 48 | "\n", 49 | "%load_ext autoreload\n", 50 | "%autoreload 2\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "We're going to be using the Dentate Gyrus neurogenesis data from [La Manno et al. (2018)](https://doi.org/10.1038/s41586-018-0414-6) in this example. Start by loading and preprocessing the data." 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Filtered out 18710 genes that are detected 20 counts (shared).\n", 70 | "Normalized count data: X, spliced, unspliced.\n", 71 | "Extracted 2000 highly variable genes.\n", 72 | "Logarithmized X.\n", 73 | "computing neighbors\n", 74 | " finished (0:00:26) --> added \n", 75 | " 'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)\n", 76 | "computing moments based on connectivities\n", 77 | " finished (0:00:03) --> added \n", 78 | " 'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "adata = scv.datasets.dentategyrus_lamanno()\n", 84 | "scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)\n", 85 | "scv.pp.moments(adata, n_neighbors=30, n_pcs=30)\n" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "Now we're going to configure the DeepVelo model and name the experiment - we'll just call it DeepVelo for now. We're also going to empirically set the spliced correlation objective based on the data - this is recommended for best performance." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "name": "stdout", 102 | "output_type": "stream", 103 | "text": [ 104 | "The ratio of spliced reads is 75.4% (between 70% and 85%). Suggest using coeff_s 0.75.\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "configs = {\n", 110 | " \"name\": \"DeepVelo\", # name of the experiment\n", 111 | " \"loss\": {\"args\": {\"coeff_s\": autoset_coeff_s(adata)}} # Automatic setting of the spliced correlation objective\n", 112 | "}\n", 113 | "configs = update_dict(Constants.default_configs, configs)\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "Now we can call the velocity and train methods to fit the model to the data. 
" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "computing velocities\n", 133 | " finished (0:00:09) --> added \n", 134 | " 'velocity', velocity vectors for each individual cell (adata.layers)\n", 135 | "Config Warning: Set to use GPU, but GPU version of DGL is not installed. Reset to use CPU instead.\n", 136 | "Warning: logging configuration file is not found in logger/logger_config.json.\n", 137 | "building graph\n" 138 | ] 139 | }, 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "INFO:train:Beginning training of DeepVelo ...\n", 145 | "WARNING:trainer:Warning: visualization (Tensorboard) is configured to use, but currently not installed on this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file.\n" 146 | ] 147 | }, 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "velo data shape: torch.Size([18213, 2000])\n" 153 | ] 154 | }, 155 | { 156 | "name": "stderr", 157 | "output_type": "stream", 158 | "text": [ 159 | "INFO:trainer: epoch : 1\n", 160 | "INFO:trainer: time: : 21.26166820526123\n", 161 | "INFO:trainer: loss : 80684.3046875\n", 162 | "INFO:trainer: mse : 0.942495584487915\n", 163 | "INFO:trainer: epoch : 2\n", 164 | "INFO:trainer: time: : 21.131349325180054\n", 165 | "INFO:trainer: loss : 16099.67578125\n", 166 | "INFO:trainer: mse : 0.6484971046447754\n", 167 | "INFO:trainer: epoch : 3\n", 168 | "INFO:trainer: time: : 21.7674458026886\n", 169 | "INFO:trainer: loss : 9191.4521484375\n", 170 | "INFO:trainer: mse : 0.6198839545249939\n", 171 | "INFO:trainer: epoch : 4\n", 172 | "INFO:trainer: time: : 22.867656230926514\n", 173 | "INFO:trainer: loss : 6672.14892578125\n", 174 | "INFO:trainer: mse : 0.6180376410484314\n", 175 | "INFO:trainer: epoch : 5\n", 176 | "INFO:trainer: time: : 23.165233373641968\n", 177 | "INFO:trainer: loss : 5399.71337890625\n", 178 | "INFO:trainer: mse : 0.6397435665130615\n", 179 | "INFO:trainer: epoch : 6\n", 180 | "INFO:trainer: time: : 23.299899339675903\n", 181 | "INFO:trainer: loss : 4641.85400390625\n", 182 | "INFO:trainer: mse : 0.655648410320282\n", 183 | "INFO:trainer: epoch : 7\n", 184 | "INFO:trainer: time: : 23.070889711380005\n", 185 | "INFO:trainer: loss : 4119.5205078125\n", 186 | "INFO:trainer: mse : 0.6404055953025818\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "# initial velocity\n", 192 | "velocity(adata, mask_zero=False)\n", 193 | "trainer = train(adata, configs)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "Now that the velocity calculation is complete, we can visualize the results. We'll start by visualizing the velocity field in the embedding space." 
201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "# velocity plot\n", 210 | "scv.tl.velocity_graph(adata, n_jobs=8)\n", 211 | "scv.pl.velocity_embedding_stream(\n", 212 | " adata,\n", 213 | " basis=\"tsne\",\n", 214 | " color=\"clusters\",\n", 215 | " legend_fontsize=9,\n", 216 | " dpi=150, # increase dpi for higher resolution\n", 217 | ")\n" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "metadata": {}, 223 | "source": [ 224 | "We can further visualize the pseudotime estimated based on the velocity field, and plot this in the embedding space." 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "scv.tl.velocity_pseudotime(adata)\n", 234 | "scv.pl.scatter(\n", 235 | " adata,\n", 236 | " color=\"velocity_pseudotime\",\n", 237 | " cmap=\"gnuplot\",\n", 238 | " dpi=150,\n", 239 | ")" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "There are a number of other visualizations and analyses that can be performed - please see the rest of the examples from the paper.\n" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [] 255 | } 256 | ], 257 | "metadata": { 258 | "interpreter": { 259 | "hash": "c4a96b5b33e9918a79a56c93882e3be4fe6c52e21891fd42273b478ddcde2cd9" 260 | }, 261 | "kernelspec": { 262 | "display_name": "deepvelo dep fixes", 263 | "language": "python", 264 | "name": "deepvelo_dep_fixes" 265 | }, 266 | "language_info": { 267 | "codemirror_mode": { 268 | "name": "ipython", 269 | "version": 3 270 | }, 271 | "file_extension": ".py", 272 | "mimetype": "text/x-python", 273 | "name": "python", 274 | "nbconvert_exporter": "python", 275 | "pygments_lexer": "ipython3", 276 | "version": "3.8.18" 277 | } 278 | }, 279 | "nbformat": 4, 280 | "nbformat_minor": 2 281 | } 282 | -------------------------------------------------------------------------------- /examples/multifacet_check.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | 4 | import numpy as np 5 | from anndata import AnnData 6 | import scvelo as scv 7 | 8 | datasets = [ 9 | "Dentate Gyrus", 10 | "Pancreas", 11 | "Hindbrain", 12 | "Hippocampus", 13 | "Chondrocyte Organogenesis", 14 | "Gastrulation Erythroid", 15 | ] 16 | 17 | # %% 18 | results = {} 19 | for data in datasets: 20 | if data == "Dentate Gyrus": 21 | adata = scv.datasets.dentategyrus() 22 | groupby = "clusters" 23 | if data == "Pancreas": 24 | adata = scv.datasets.pancreas() 25 | groupby = "clusters" 26 | if data == "Hindbrain": 27 | adata = scv.read( 28 | "deepvelo_data/h5ad_files/Hindbrain_GABA_Glio.h5ad", cache=True 29 | ) 30 | groupby = "Celltype" 31 | if data == "Hippocampus": 32 | adata = scv.datasets.dentategyrus_lamanno() 33 | groupby = "clusters" 34 | if data == "Chondrocyte Organogenesis": 35 | adata = scv.read("../data/cao_organogenesis_chondrocyte.h5ad", cache=True) 36 | adata = adata[np.random.choice(adata.obs_names, 30000, replace=False)] 37 | adata.X = ( 38 | adata.layers["spliced"] 39 | + adata.layers["unspliced"] 40 | + adata.layers["ambiguous"] 41 | ) 42 | groupby = "Main_cell_type" 43 | if data == "Mouse Gastrulation": 44 | adata = scv.datasets.gastrulation_erythroid() 45 | groupby = "celltype" 46 | 47 | scv.pp.filter_and_normalize(adata, 
min_shared_counts=30, n_top_genes=2000) 48 | scv.pp.moments(adata, n_pcs=30, n_neighbors=30) 49 | scv.tl.velocity(adata) 50 | 51 | var_names = adata[:, adata.var["velocity_genes"]].var_names 52 | # var_names = adata[:, adata.var["velocity_genes"]].var_names[:10] 53 | scv.tl.recover_dynamics(adata, n_jobs=8) 54 | scv.tl.differential_kinetic_test(adata, var_names=var_names, groupby=groupby) 55 | 56 | res_df = scv.get_df( 57 | adata[:, var_names], ["fit_diff_kinetics", "fit_pval_kinetics"], precision=2 58 | ) 59 | 60 | num_total_genes = len(var_names) 61 | num_multifacet_genes = len(res_df) 62 | ratio = num_multifacet_genes / num_total_genes 63 | 64 | results[data] = { 65 | "num_total_genes": num_total_genes, 66 | "num_multifacet_genes": num_multifacet_genes, 67 | "ratio": ratio, 68 | } 69 | 70 | # %% save results and visualize in sns barplot 71 | import pandas as pd 72 | import seaborn as sns 73 | from matplotlib import pyplot as plt 74 | 75 | df = pd.DataFrame(results).T 76 | df = df.reset_index() 77 | df.to_csv("saved/multifacet_check.csv", index=False) 78 | 79 | # set default font size 80 | sns.set(font_scale=1.3) 81 | 82 | fig = plt.figure(figsize=(12, 9)) 83 | # plot with larger font size and make sure x-axis labels are well separated 84 | sns.barplot(x="index", y="ratio", data=df, palette="Blues_d") 85 | plt.xticks(rotation=45, ha="right") 86 | plt.xlabel("") 87 | plt.ylabel("") 88 | plt.title(f"Ratio of multifaceted genes per dataset, avg {df['ratio'].mean():.2f}") 89 | plt.tight_layout() 90 | plt.savefig("saved/multifacet_check.png", dpi=300) 91 | 92 | # %% 93 | -------------------------------------------------------------------------------- /examples/r_analysis/README.md: -------------------------------------------------------------------------------- 1 | # Driver-gene analysis of mouse hindbrain developmental data 2 | 3 | This folder contains the necessary scripts and environment to the run the R-based analysis for the mouse hindbrain developmental data, and reproduce the 4 | figures related to this data in the main manuscript. 5 | 6 | ## Instructions 7 | 8 | 1. Install the conda environment that contains the necessary R-build and libraries for running the analysis: 9 | ``` 10 | # This will create an env named 'deepvelo_r_analysis' 11 | conda env create -f env.yaml 12 | ``` 13 | 14 | 2. Create outs and data folders, for output of scripts and necessary data, respectively 15 | ``` 16 | mkdir data 17 | mkdir -p outs/figures 18 | ``` 19 | 20 | 3. Copy the necessary files to run the scripts from the DeepVelo resource folder into the data dir 21 | ``` 22 | cp ../deepvelo_data/deepvelo_outputs/MDT_driver_genes[DYNAMICAL].csv data # scVelo driver genes 23 | cp ../deepvelo_data/deepvelo_outputs/MDT_driver_genes.csv data # DeepVelo driver genes 24 | cp ../deepvelo_data/metadata_files/41586_2019_1158_MOESM4_ESM.csv data # Vladiou et al. marker genes 25 | cp ../deepvelo_data/metadata_files/HGNC_AllianceHomology.rpt data # Full list of mouse genes 26 | cp ../deepvelo_data/metadata_files/mm_go_mf_bp_reac_feb_2022.gmt data # Pathway GMT files for pathway analysis 27 | cp ../deepvelo_data/metadata_files/02_human_mouse_tfs_matched.tsv data # List of matched human-mouse TFs 28 | ``` 29 | 4. 
Run the R-rscripts in order to reproduce analysis 30 | ``` 31 | conda activate deepvelo_r_analysis 32 | cd scripts 33 | Rscript --verbose 00_extra_package_installs.R 34 | Rscript --verbose 01_deepvelo_marker_analysis.R 35 | Rscript --verbose 02_deepvelo_scvelo_pathway_analysis.R 36 | Rscript --verbose 03_deepvelo_scvelo_activepathways_results_analysis.R 37 | Rscript --verbose 04_deepvelo_tf_driver_analysis.R 38 | Rscript --verbose 05_supplementary_table_formatting.R 39 | ``` 40 | -------------------------------------------------------------------------------- /examples/r_analysis/scripts/00_extra_package_installs.R: -------------------------------------------------------------------------------- 1 | require(devtools) 2 | 3 | # Install R packages not available in the Conda repository 4 | install_version( 5 | "ActivePathways", 6 | version = "1.1.0", 7 | repos = "http://cran.us.r-project.org" 8 | ) 9 | install.packages( 10 | "ggvenn", 11 | repos = "http://cran.us.r-project.org" 12 | ) -------------------------------------------------------------------------------- /examples/r_analysis/scripts/02_deepvelo_scvelo_pathway_analysis.R: -------------------------------------------------------------------------------- 1 | library(data.table) 2 | library(tidyr) 3 | library(dplyr) 4 | library(ggplot2) 5 | library(ggthemes) 6 | library(reshape2) 7 | library(ggpubr) 8 | library(ActivePathways) 9 | 10 | # Define not in 11 | `%ni%` <- Negate(`%in%`) 12 | 13 | ######################## DeepVelo pathway analysis ######################## 14 | 15 | # Deepvelo genes and associated p-values 16 | dvelo <- fread("../data/MDT_driver_genes.csv") 17 | 18 | # Get drivers ordered by gliogenic and gabaergic lineages 19 | gaba_dvelo <- dvelo[ 20 | order(dvelo$`GABA interneurons_pval`, decreasing = FALSE), 21 | ] 22 | glio_dvelo <- dvelo[ 23 | order(dvelo$`Gliogenic progenitors_pval`, decreasing = FALSE), 24 | ] 25 | colnames(gaba_dvelo)[1] <- "gene" 26 | colnames(glio_dvelo)[1] <- "gene" 27 | 28 | # Perform more strict bonferroni correction on p-values 29 | gaba_dvelo$gaba_bonfer_p <- p.adjust( 30 | gaba_dvelo$`GABA interneurons_pval`, 31 | method = "bonferroni" 32 | ) 33 | glio_dvelo$glio_bonfer_p <- p.adjust( 34 | glio_dvelo$`Gliogenic progenitors_pval`, 35 | method = "bonferroni" 36 | ) 37 | 38 | # Create dataframe for gaba and glio significant values and convert to numeric 39 | # matrix 40 | gaba_glio_signif <- merge(gaba_dvelo, glio_dvelo) 41 | gaba_glio_mat_signif <- as.matrix( 42 | gaba_glio_signif[, c("gaba_bonfer_p", "glio_bonfer_p")] 43 | ) 44 | rownames(gaba_glio_mat_signif) <- gaba_glio_signif$gene 45 | diff_exp_mat <- gaba_glio_mat_signif 46 | colnames(diff_exp_mat) <- c("GABAergic", "Gliogenic") 47 | 48 | # Run activepathways 49 | dir.create("../outs/02_activepathways_dvelo_gaba_glio_out", recursive = TRUE) 50 | gaba_glio_ap <- ActivePathways( 51 | diff_exp_mat, 52 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 53 | geneset.filter = c(5, 2000), 54 | significant = 0.05, 55 | cytoscape.file.tag = "../outs/02_activepathways_dvelo_gaba_glio_out" 56 | ) 57 | 58 | # Run activepathways per subset independantly 59 | gaba_mat_signif <- as.matrix(gaba_glio_signif[, c("gaba_bonfer_p")]) 60 | glio_mat_signif <- as.matrix(gaba_glio_signif[, c("glio_bonfer_p")]) 61 | rownames(gaba_mat_signif) <- gaba_glio_signif$gene 62 | rownames(glio_mat_signif) <- gaba_glio_signif$gene 63 | 64 | dir.create("../outs/02_activepathways_dvelo_gaba_out", recursive = TRUE) 65 | gaba_ap <- ActivePathways( 66 | gaba_mat_signif, 67 | 
"../data/mm_go_mf_bp_reac_feb_2022.gmt", 68 | geneset.filter = c(5, 2000), 69 | significant = 0.05, 70 | cytoscape.file.tag = "../outs/02_activepathways_dvelo_gaba_out" 71 | ) 72 | 73 | dir.create("../outs/02_activepathways_dvelo_glio_out", recursive = TRUE) 74 | glio_ap <- ActivePathways( 75 | glio_mat_signif, 76 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 77 | geneset.filter = c(5, 2000), 78 | significant = 0.05, 79 | cytoscape.file.tag = "../outs/02_activepathways_dvelo_glio_out" 80 | ) 81 | 82 | # Run ActivePathway on top 100, 250, 500, and full list subsets for both gaba and glio 83 | # independantly 84 | subset_list <- list(100, 250, 500) 85 | gaba_subset_mats <- lapply(subset_list, function(x) { 86 | subset <- gaba_glio_signif[order(gaba_glio_signif$gaba_bonfer_p, decreasing = FALSE), c("gene", "gaba_bonfer_p")] 87 | subset[(x+1):nrow(subset), c("gaba_bonfer_p")] <- 1 88 | colnames(subset)[2] <- paste0("GABAergic ", x) 89 | return(subset) 90 | }) 91 | glio_subset_mats <- lapply(subset_list, function(x) { 92 | subset <- gaba_glio_signif[order(gaba_glio_signif$glio_bonfer_p, decreasing = FALSE), c("gene", "glio_bonfer_p")] 93 | subset[(x+1):nrow(subset), c("glio_bonfer_p")] <- 1 94 | colnames(subset)[2] <- paste0("Gliogenic ", x) 95 | return(subset) 96 | }) 97 | glio_mat_subset <- as.data.frame(glio_mat_signif) 98 | gaba_mat_subset <- as.data.frame(gaba_mat_signif) 99 | glio_mat_subset$gene <- rownames(glio_mat_subset) 100 | gaba_mat_subset$gene <- rownames(gaba_mat_subset) 101 | colnames(glio_mat_subset)[1] <- "Gliogenic all" 102 | colnames(gaba_mat_subset)[1] <- "GABAergic all" 103 | 104 | # Merge all together 105 | glio_subsets_merged <- Reduce(merge, glio_subset_mats) 106 | gaba_subsets_merged <- Reduce(merge, gaba_subset_mats) 107 | subsets_list <- list( 108 | glio_subsets_merged, 109 | gaba_subsets_merged, 110 | glio_mat_subset, 111 | gaba_mat_subset 112 | ) 113 | subsets_merged <- Reduce(merge, subsets_list) 114 | subsets_merged_copy <- subsets_merged 115 | 116 | # Remove gene column and perform ActivePathways enrichment 117 | subsets_merged <- subsets_merged[, -c("gene")] 118 | subsets_merged_mat <- as.matrix(subsets_merged) 119 | rownames(subsets_merged_mat) <- subsets_merged_copy$gene 120 | 121 | # Perform ActivePathways enrichment analysis 122 | dir.create("../outs/02_activepathways_dvelo_glio_gaba_top_100_250_500_all_out", recursive = TRUE) 123 | glio_ap <- ActivePathways( 124 | subsets_merged_mat, 125 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 126 | geneset.filter = c(5, 2000), 127 | significant = 0.05, 128 | cytoscape.file.tag = "../outs/02_activepathways_dvelo_glio_gaba_top_100_250_500_all_out" 129 | ) 130 | 131 | # Run activepathways on only the top 100 - based on the correlation values instead of 132 | # the p-values 133 | glio_top_100 <- gaba_glio_signif[ 134 | order(gaba_glio_signif$`Gliogenic progenitors_corr`, decreasing = TRUE) 135 | ][0:100] 136 | glio_top_100 <- glio_top_100[, c("gene", "glio_bonfer_p")] 137 | glio_top_100_mat <- as.matrix(glio_top_100[, -1]) 138 | rownames(glio_top_100_mat) <- glio_top_100$gene 139 | colnames(glio_top_100_mat) <- "Gliogenic" 140 | 141 | dir.create("../outs/02_activepathways_dvelo_glio_top_100_corr_out", recursive = TRUE) 142 | glio_ap <- ActivePathways( 143 | glio_top_100_mat, 144 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 145 | geneset.filter = c(5, 2000), 146 | significant = 0.05, 147 | cytoscape.file.tag = "../outs/02_activepathways_dvelo_glio_top_100_corr_out" 148 | ) 149 | 150 | gaba_top_100 <- gaba_glio_signif[ 151 | 
order(gaba_glio_signif$`GABA interneurons_corr`, decreasing = TRUE) 152 | ][0:100] 153 | gaba_top_100 <- gaba_top_100[, c("gene", "gaba_bonfer_p")] 154 | gaba_top_100_mat <- as.matrix(gaba_top_100[, -1]) 155 | rownames(gaba_top_100_mat) <- gaba_top_100$gene 156 | colnames(gaba_top_100_mat) <- "GABAergic" 157 | 158 | dir.create("../outs/02_activepathways_dvelo_gaba_top_100_corr_out", recursive = TRUE) 159 | gaba_ap <- ActivePathways( 160 | gaba_top_100_mat, 161 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 162 | geneset.filter = c(5, 2000), 163 | significant = 0.05, 164 | cytoscape.file.tag = "../outs/02_activepathways_dvelo_gaba_top_100_corr_out" 165 | ) 166 | 167 | ######################## scVelo pathway analysis ######################## 168 | 169 | # scVelo genes and associated p-values 170 | scvelo <- fread("../data/MDT_driver_genes[DYNAMICAL].csv") 171 | 172 | # Get drivers ordered by gliogenic and gabaergic lineages 173 | gaba_scvelo <- scvelo[ 174 | order(scvelo$`GABA interneurons_pval`, decreasing = FALSE), 175 | ] 176 | glio_scvelo <- scvelo[ 177 | order(scvelo$`Gliogenic progenitors_pval`, decreasing = FALSE), 178 | ] 179 | colnames(gaba_scvelo)[1] <- "gene" 180 | colnames(glio_scvelo)[1] <- "gene" 181 | 182 | # Perform more strict bonferroni correction on p-values 183 | gaba_scvelo$gaba_bonfer_p <- p.adjust( 184 | gaba_scvelo$`GABA interneurons_pval`, 185 | method = "bonferroni" 186 | ) 187 | glio_scvelo$glio_bonfer_p <- p.adjust( 188 | glio_scvelo$`Gliogenic progenitors_pval`, 189 | method = "bonferroni" 190 | ) 191 | 192 | # Create dataframe for gaba and glio significant values and convert to numeric 193 | # matrix 194 | gaba_glio_signif <- merge(gaba_scvelo, glio_scvelo) 195 | gaba_glio_mat_signif <- as.matrix( 196 | gaba_glio_signif[, c("gaba_bonfer_p", "glio_bonfer_p")] 197 | ) 198 | rownames(gaba_glio_mat_signif) <- gaba_glio_signif$gene 199 | diff_exp_mat <- gaba_glio_mat_signif 200 | colnames(diff_exp_mat) <- c("GABAergic", "Gliogenic") 201 | 202 | # Run activepathways on only the top 100 - based on the correlation values instead of 203 | # the p-values 204 | glio_top_100 <- gaba_glio_signif[ 205 | order(gaba_glio_signif$`Gliogenic progenitors_corr`, decreasing = TRUE) 206 | ][0:100] 207 | glio_top_100 <- glio_top_100[, c("gene", "glio_bonfer_p")] 208 | glio_top_100_mat <- as.matrix(glio_top_100[, -1]) 209 | rownames(glio_top_100_mat) <- glio_top_100$gene 210 | colnames(glio_top_100_mat) <- "Gliogenic" 211 | 212 | dir.create("../outs/02_activepathways_scvelo_glio_top_100_corr_out", recursive = TRUE) 213 | glio_ap <- ActivePathways( 214 | glio_top_100_mat, 215 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 216 | geneset.filter = c(5, 2000), 217 | significant = 0.05, 218 | cytoscape.file.tag = "../outs/02_activepathways_scvelo_glio_top_100_corr_out" 219 | ) 220 | 221 | gaba_top_100 <- gaba_glio_signif[ 222 | order(gaba_glio_signif$`GABA interneurons_corr`, decreasing = TRUE) 223 | ][0:100] 224 | gaba_top_100 <- gaba_top_100[, c("gene", "gaba_bonfer_p")] 225 | gaba_top_100_mat <- as.matrix(gaba_top_100[, -1]) 226 | rownames(gaba_top_100_mat) <- gaba_top_100$gene 227 | colnames(gaba_top_100_mat) <- "GABAergic" 228 | 229 | dir.create("../outs/02_activepathways_scvelo_gaba_top_100_corr_out", recursive = TRUE) 230 | gaba_ap <- ActivePathways( 231 | gaba_top_100_mat, 232 | "../data/mm_go_mf_bp_reac_feb_2022.gmt", 233 | geneset.filter = c(5, 2000), 234 | significant = 0.05, 235 | cytoscape.file.tag = "../outs/02_activepathways_scvelo_gaba_top_100_corr_out" 236 | ) 237 | 
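238 | # Optional sanity check (a minimal sketch, not part of the original pipeline): report how 239 | # many scVelo driver genes remain significant per lineage after the Bonferroni correction 240 | # applied above, before interpreting the ActivePathways results. 241 | message("GABAergic genes with Bonferroni p < 0.05: ", sum(gaba_glio_signif$gaba_bonfer_p < 0.05, na.rm = TRUE)) 242 | message("Gliogenic genes with Bonferroni p < 0.05: ", sum(gaba_glio_signif$glio_bonfer_p < 0.05, na.rm = TRUE)) 243 |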
-------------------------------------------------------------------------------- /examples/r_analysis/scripts/04_deepvelo_tf_driver_analysis.R: -------------------------------------------------------------------------------- 1 | library(data.table) 2 | library(tidyr) 3 | library(dplyr) 4 | library(ggplot2) 5 | library(ggthemes) 6 | library(reshape2) 7 | library(ggpubr) 8 | library(ggvenn) 9 | 10 | # Define not in 11 | `%ni%` <- Negate(`%in%`) 12 | 13 | # Dynamical genes 14 | scvelo <- fread("../data/MDT_driver_genes[DYNAMICAL].csv") 15 | 16 | # Deepvelo genes 17 | dvelo <- fread("../data/MDT_driver_genes.csv") 18 | 19 | colnames(scvelo)[1] <- "gene" 20 | colnames(dvelo)[1] <- "gene" 21 | 22 | # Load transcription factor dataframe 23 | tf_df <- fread("../data/02_human_mouse_tfs_matched.tsv") 24 | 25 | # Subset for top 100 driver genes for GABAergic and Gliogenic for 26 | # both deepvelo and scvelo 27 | scvelo_top_100_gaba <- scvelo[ 28 | order(scvelo$`GABA interneurons_corr`, decreasing = TRUE) 29 | ][0:100] 30 | scvelo_top_100_glio <- scvelo[ 31 | order(scvelo$`Gliogenic progenitors_corr`, decreasing = TRUE) 32 | ][0:100] 33 | 34 | dvelo_top_100_gaba <- dvelo[ 35 | order(dvelo$`GABA interneurons_corr`, decreasing = TRUE) 36 | ][0:100] 37 | dvelo_top_100_glio <- dvelo[ 38 | order(dvelo$`Gliogenic progenitors_corr`, decreasing = TRUE) 39 | ][0:100] 40 | 41 | scvelo_top_100_gaba <- scvelo_top_100_gaba[,c("gene", "GABA interneurons_corr")] 42 | scvelo_top_100_glio <- scvelo_top_100_glio[,c("gene", "Gliogenic progenitors_corr")] 43 | dvelo_top_100_gaba <- dvelo_top_100_gaba[,c("gene", "GABA interneurons_corr")] 44 | dvelo_top_100_glio <- dvelo_top_100_glio[,c("gene", "Gliogenic progenitors_corr")] 45 | 46 | # Append TF info 47 | scvelo_top_100_gaba$tf <- ifelse( 48 | scvelo_top_100_gaba$gene %in% tf_df$mm_gene_name, 49 | "Yes", 50 | "No" 51 | ) 52 | scvelo_top_100_glio$tf <- ifelse( 53 | scvelo_top_100_glio$gene %in% tf_df$mm_gene_name, 54 | "Yes", 55 | "No" 56 | ) 57 | dvelo_top_100_gaba$tf <- ifelse( 58 | dvelo_top_100_gaba$gene %in% tf_df$mm_gene_name, 59 | "Yes", 60 | "No" 61 | ) 62 | dvelo_top_100_glio$tf <- ifelse( 63 | dvelo_top_100_glio$gene %in% tf_df$mm_gene_name, 64 | "Yes", 65 | "No" 66 | ) 67 | 68 | # Compare TF lengths across subsets 69 | tf_overlaps <- list( 70 | "scVelo \nGABAergic" = scvelo_top_100_gaba$gene[which(scvelo_top_100_gaba$tf %in% "Yes")], 71 | "DeepVelo \nGABAergic" = dvelo_top_100_gaba$gene[which(dvelo_top_100_gaba$tf %in% "Yes")], 72 | "scVelo \nGliogenic" = scvelo_top_100_glio$gene[which(scvelo_top_100_glio$tf %in% "Yes")], 73 | "DeepVelo \nGliogenic" = dvelo_top_100_glio$gene[which(dvelo_top_100_glio$tf %in% "Yes")] 74 | ) 75 | 76 | # Get VennDiagram of comparison 77 | ggvenn( 78 | tf_overlaps, 79 | set_name_size = 5 80 | ) 81 | ggsave( 82 | "../outs/figures/06_dvelo_scvelo_top_100_gaba_glio_tf_venn.pdf", 83 | height = 6, 84 | width = 12 85 | ) 86 | 87 | # Pick out transcription factors found unique to deepvelo for both lineages 88 | dvelo_top_100_gaba$gene_unique <- ifelse( 89 | dvelo_top_100_gaba$gene %in% scvelo_top_100_gaba$gene, 90 | "No", 91 | "Yes" 92 | ) 93 | dvelo_top_100_glio$gene_unique <- ifelse( 94 | dvelo_top_100_glio$gene %in% scvelo_top_100_glio$gene, 95 | "No", 96 | "Yes" 97 | ) 98 | 99 | # Create barplot of TF overlap between two methods for driver genes 100 | tf_overlap_deepvelo <- data.frame( 101 | "GABAergic" = length(which(dvelo_top_100_gaba$tf == "Yes")), 102 | "Gliogenic" = length(which(dvelo_top_100_glio$tf == "Yes")), 103 | "Method" = 
"DeepVelo" 104 | ) 105 | tf_overlap_scvelo <- data.table( 106 | "GABAergic" = length(which(scvelo_top_100_gaba$tf == "Yes")), 107 | "Gliogenic" = length(which(scvelo_top_100_glio$tf == "Yes")), 108 | "Method" = "scVelo" 109 | ) 110 | tf_overlap <- rbind(tf_overlap_deepvelo, tf_overlap_scvelo) 111 | tf_overlap_melted <- reshape2::melt(tf_overlap) 112 | colnames(tf_overlap_melted) <- c("Method", "Lineage", "TF_Overlap") 113 | 114 | ggplot(data = tf_overlap_melted, aes( 115 | x = Lineage, 116 | y = TF_Overlap, 117 | fill = factor(Method, levels = c("scVelo", "DeepVelo")) 118 | )) + 119 | geom_bar(stat = "identity", position = "dodge") + 120 | scale_fill_manual(values = c("#FFC20A", "#0C7BDC")) + 121 | theme_few() + 122 | labs( 123 | fill = "Method", 124 | x = "Lineage", 125 | y = "TF overlap in top 100 driver genes" 126 | ) + 127 | geom_text( 128 | aes(Lineage, label = TF_Overlap), 129 | position = position_dodge(width = 1), 130 | vjust = -0.2, 131 | size = 5.5 132 | ) + 133 | theme(axis.text.y = element_text(size = 20)) + 134 | theme(axis.text.x = element_text(size = 20)) + 135 | theme(axis.title.x = element_text(size = 20, face = "bold")) + 136 | theme(axis.title.y = element_text(size = 20, face = "bold")) + 137 | theme(legend.title = element_text(size = 20, face = "bold")) + 138 | theme(legend.text = element_text(size = 20)) + 139 | theme(aspect.ratio = 1) 140 | ggsave( 141 | "../outs/figures/04_deepvelo_glio_gaba_tf_overlap_comparison.pdf", 142 | height = 7, 143 | width = 7 144 | ) -------------------------------------------------------------------------------- /examples/r_analysis/scripts/05_supplementary_table_formatting.R: -------------------------------------------------------------------------------- 1 | library(data.table) 2 | library(tidyr) 3 | library(dplyr) 4 | library(ggplot2) 5 | library(ggthemes) 6 | library(reshape2) 7 | library(ggpubr) 8 | 9 | # Load marker list from Vladiou et al. 
10 | vla_markers <- fread("../data/41586_2019_1158_MOESM4_ESM.csv") 11 | vla_markers <- vla_markers[, -1] 12 | 13 | # Fill down values 14 | vla_markers$Annotations <- na_if(vla_markers$Annotations, "") 15 | vla_markers <- vla_markers %>% tidyr::fill(Annotations, .direction = "down") 16 | 17 | # Subset and save for GABAergic and gliogenic markers 18 | gaba_glio_markers <- vla_markers[vla_markers$Annotations %in% c( 19 | "GABA interneurons", 20 | "Differentiating GABA interneurons", 21 | "Gliogenic progenitors" 22 | )] 23 | fwrite( 24 | gaba_glio_markers, 25 | "../outs/05_marker_gaba_glio_subset.tsv", 26 | sep = "\t", 27 | quote = FALSE, 28 | row.names = FALSE, 29 | col.names = TRUE 30 | ) 31 | 32 | -------------------------------------------------------------------------------- /examples/r_analysis/scripts/06_full_scores_dotplot.R: -------------------------------------------------------------------------------- 1 | library(ggplot2) 2 | library(ggthemes) 3 | library(data.table) 4 | library(rjson) 5 | library(plyr) 6 | library(stringr) 7 | library(tibble) 8 | library(RColorBrewer) 9 | library(dplyr) 10 | library(scales) 11 | library(ggimage) 12 | library(cowplot) 13 | 14 | # Change to data directory 15 | setwd("../data/") 16 | 17 | # Load all of the json files 18 | json_files <- list.files(pattern = ".json") 19 | 20 | # Map the json file names to dataset names 21 | json_names <- stringr::str_split_fixed( 22 | json_files, 23 | pattern = "_eval_results", 24 | n = 3 25 | )[,1] 26 | dataset_names <- plyr::mapvalues( 27 | json_names, 28 | from = c( 29 | "figure2_dentategyrus", 30 | "figure3_pancreas", 31 | "figure4_hindbrain", 32 | "organogenesis_chondrocyte", 33 | "la_manno_hippocampus" 34 | ), 35 | to = c( 36 | "Dentate gyrus", 37 | "Pancreas", 38 | "Hindbrain", 39 | "Chondrocyte", 40 | "Hippocampus" 41 | ) 42 | ) 43 | 44 | # Load all of the files into a dataframe object 45 | jsons_loaded <- lapply(json_files, function(x) { 46 | json <- fromJSON(file = x) 47 | return(json) 48 | }) 49 | 50 | dfs_per_dataset <- lapply(jsons_loaded, function(x) { 51 | dir_scores <- unlist(x$direction_score[1:4]) 52 | oc_scores <- unlist(x$overall_consistency[1:4]) 53 | csc_scores <- unlist(x$celltype_consistency[1:4]) 54 | methods <- c("DeepVelo", "scVelo (dynamical)", "Velocyto (steady state)", "CellDancer") 55 | results_df <- data.frame( 56 | "Method" = methods, 57 | "Direction score" = dir_scores, 58 | "Overall consistency" = oc_scores, 59 | "Cell-type wise consistency" = csc_scores 60 | ) 61 | return(results_df) 62 | }) 63 | 64 | # Add dataset names to the dataframes 65 | dfs_per_dataset <- mapply( 66 | FUN = function(x, y) { 67 | x$Dataset <- y 68 | return(x) 69 | }, 70 | x = dfs_per_dataset, 71 | y = dataset_names, 72 | SIMPLIFY = FALSE 73 | ) 74 | 75 | # Combine all of the dataframes into one 76 | dfs_combined <- Reduce(rbind, dfs_per_dataset) 77 | 78 | # Save the processed dataframe 79 | fwrite( 80 | dfs_combined, 81 | file = "full_scores_processed.csv", 82 | sep = ",", 83 | row.names = FALSE, 84 | col.names = TRUE, 85 | quote = FALSE 86 | ) 87 | 88 | # Source the scib knit table function 89 | source("../helpers/scIB_knit_table.R") 90 | 91 | # Scale the scores between 0 and 1 92 | relevant_columns = c("Direction score", "Overall consistency", "Cell-type wise consistency") 93 | colnames(dfs_combined) <- c( 94 | "Method", 95 | "Direction score", 96 | "Overall consistency", 97 | "Cell-type wise consistency", 98 | "Dataset" 99 | ) 100 | dfs_combined[, relevant_columns] <- lapply( 101 | dfs_combined[, 
relevant_columns], 102 | function(x) { 103 | return((x - min(x)) / (max(x) - min(x))) 104 | } 105 | ) 106 | 107 | # Create an overall score column that uses the formula - 0.5*direction_score + 0.25*overall_consistency + 0.25*celltype_consistency 108 | dfs_combined$`Overall score` <- 109 | 0.5*dfs_combined$`Direction score` + 110 | 0.25*dfs_combined$`Overall consistency` + 111 | 0.25*dfs_combined$`Cell-type wise consistency` 112 | 113 | # Prepare the data, column info, row info, and palettes 114 | data <- dfs_combined 115 | data$Embedding <- rep("graph", nrow(data)) 116 | data <- data[, c("Method", "Dataset", "Overall score", "Overall consistency", "Cell-type wise consistency", "Direction score")] 117 | row_info <- data.frame("id" = data$Method, "group" = NA) 118 | column_info <- data.frame( 119 | "id" = colnames(data), 120 | "group" = c("Text", "Text", "S0", "S1", "S2", "S3"), 121 | "geom" = c("text", "text", "bar", "bar", "bar", "bar"), 122 | "width" = c(6, 2.5, 2, 2, 2, 2), 123 | overlay = FALSE 124 | ) 125 | palettes <- list( 126 | "S0" = "YlOrRd", 127 | "S1" = "YlGnBu", 128 | "S2" = "BuPu", 129 | "S3" = "RdPu" 130 | ) 131 | g <- scIB_knit_table(data = data, column_info = column_info, row_info = row_info, palettes = palettes, usability = F) 132 | now <- Sys.time() 133 | outdir <- "../outs/figures/" 134 | ggsave(paste0(outdir, "/", format(now, "%Y%m%d_%H%M%S_"), "velocity_summary_metrics.pdf"), g, device = cairo_pdf, width = 297, height = 420, units = "mm") 135 | ggsave(paste0(outdir, "/", format(now, "%Y%m%d_%H%M%S_"), "velocity_summary_metrics.tiff"), g, device = "tiff", dpi = "retina", width = 297, height = 420, units = "mm") 136 | ggsave(paste0(outdir, "/", format(now, "%Y%m%d_%H%M%S_"), "velocity_summary_metrics.png"), g, device = "png", dpi = "retina", width = 297, height = 420, units = "mm") -------------------------------------------------------------------------------- /examples/sweep_robustness.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import numpy as np 3 | import scvelo as scv 4 | import torch 5 | from umap import UMAP 6 | from sklearn.decomposition import PCA 7 | from scipy.stats import mannwhitneyu 8 | import wandb 9 | 10 | from deepvelo.utils import velocity, velocity_confidence, update_dict 11 | from deepvelo.utils.preprocess import autoset_coeff_s 12 | from deepvelo.utils.plot import statplot, compare_plot 13 | from deepvelo import train, Constants 14 | 15 | hyperparameter_defaults = dict( 16 | seed=123, 17 | layers=[64, 64], 18 | topC=30, 19 | topG=20, 20 | lr=0.001, 21 | pearson_scale=18.0, 22 | pp_hvg=2000, 23 | pp_neighbors=30, 24 | pp_pcs=30 25 | # NOTE: add any hyperparameters you want to sweep here 26 | ) 27 | run = wandb.init(config=hyperparameter_defaults, project="scFormer", reinit=True) 28 | wargs = wandb.config 29 | 30 | 31 | # fix random seeds for reproducibility 32 | SEED = wargs.seed 33 | torch.manual_seed(SEED) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | np.random.seed(SEED) 37 | 38 | scv.settings.verbosity = 3 # show errors(0), warnings(1), info(2), hints(3) 39 | scv.settings.set_figure_params( 40 | "scvelo", transparent=False 41 | ) # for beautified visualization 42 | 43 | # %% [markdown] 44 | # # Load DG data and preprocess 45 | 46 | # %% 47 | adata = scv.datasets.dentategyrus() 48 | scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=wargs.pp_hvg) 49 | scv.pp.moments(adata, n_neighbors=wargs.pp_neighbors, n_pcs=wargs.pp_pcs) 50 
| 51 | # %% [markdown] 52 | # # DeepVelo 53 | 54 | # %% 55 | # specific configs to overide the default configs, #NOTE: see train.py for complete args 56 | configs = { 57 | "name": "DeepVelo", # name of the experiment 58 | "arch": { 59 | "args": { 60 | "layers": wargs.layers, 61 | }, 62 | }, 63 | "data_loader": { 64 | "args": { 65 | "topC": wargs.topC, 66 | "topG": wargs.topG, 67 | }, 68 | }, 69 | "optimizer": { 70 | "args": { 71 | "lr": wargs.lr, 72 | }, 73 | }, 74 | "loss": { 75 | "args": { 76 | "pearson_scale": wargs.pearson_scale, 77 | "coeff_s": autoset_coeff_s(adata), 78 | }, 79 | }, 80 | "trainer": {"verbosity": 0}, # increase verbosity to show training progress 81 | } 82 | configs = update_dict(Constants.default_configs, configs) 83 | 84 | # %% 85 | # initial velocity 86 | velocity(adata, mask_zero=False) 87 | trainer = train(adata, configs) 88 | 89 | 90 | # %% 91 | scv.tl.velocity_graph(adata, n_jobs=8) 92 | 93 | # %% 94 | # velocity plot 95 | scv.pl.velocity_embedding_stream( 96 | adata, 97 | basis="umap", 98 | color="clusters", 99 | legend_fontsize=9, 100 | dpi=150, # increase dpi for higher resolution 101 | show=False, 102 | ) 103 | # NOTE: may log the plot to wandb using wandb.log({"velocity_embedding_stream": wandb.Image(plt)}) 104 | 105 | 106 | # %% 107 | scv.pl.velocity_embedding( 108 | adata, 109 | basis="umap", 110 | arrow_length=6, 111 | arrow_size=1.2, 112 | dpi=150, 113 | show=False, 114 | ) 115 | 116 | 117 | # %% 118 | scv.pl.velocity_embedding_grid( 119 | adata, 120 | basis="umap", 121 | arrow_length=4, 122 | # alpha=0.1, 123 | arrow_size=2, 124 | arrow_color="tab:blue", 125 | dpi=150, 126 | show=False, 127 | ) 128 | 129 | 130 | # %% 131 | # get kinetic_rates 132 | if "cell_specific_alpha" in adata.layers: 133 | all_rates = np.concatenate( 134 | [ 135 | adata.layers["cell_specific_beta"], 136 | adata.layers["cell_specific_gamma"], 137 | adata.layers["cell_specific_alpha"], 138 | ], 139 | axis=1, 140 | ) 141 | else: 142 | all_rates = np.concatenate( 143 | [ 144 | adata.layers["cell_specific_beta"], 145 | adata.layers["cell_specific_gamma"], 146 | ], 147 | axis=1, 148 | ) 149 | # pca and umap of all rates 150 | rates_pca = PCA(n_components=30, random_state=SEED).fit_transform(all_rates) 151 | adata.obsm["X_rates_pca"] = rates_pca 152 | 153 | rates_umap = UMAP( 154 | n_neighbors=60, 155 | min_dist=0.6, 156 | spread=0.9, 157 | random_state=SEED, 158 | ).fit_transform(rates_pca) 159 | adata.obsm["X_rates_umap"] = rates_umap 160 | 161 | 162 | # %% 163 | # plot kinetic rates umap 164 | scv.pl.scatter( 165 | adata, 166 | basis="rates_umap", 167 | # omit_velocity_fit=True, 168 | add_outline="Granule mature, Granule immature, Neuroblast", 169 | outline_width=(0.15, 0.3), 170 | title="umap of cell-specific kinetic rates", 171 | legend_loc="none", 172 | dpi=150, 173 | show=False, 174 | ) 175 | 176 | # %% 177 | # plot genes 178 | scv.pl.velocity( 179 | adata, 180 | var_names=["Tmsb10", "Ppp3ca"], 181 | basis="umap", 182 | show=False, 183 | ) 184 | 185 | # %% 186 | # save adata for next steps 187 | deepvelo_adata = adata.copy() 188 | 189 | 190 | # %% [markdown] 191 | # # Compare consistency score 192 | 193 | # %% 194 | vkey = "velocity" 195 | method = "cosine" 196 | velocity_confidence(deepvelo_adata, vkey=vkey, method=method) 197 | deepvelo_adata.obs["overall_consistency"] = deepvelo_adata.obs[ 198 | f"{vkey}_confidence_{method}" 199 | ].copy() 200 | 201 | # %% 202 | vkey = "velocity" 203 | method = "cosine" 204 | scope_key = "clusters" 205 | # 3. 
cosine similarity, compute within Celltype 206 | velocity_confidence(deepvelo_adata, vkey=vkey, method=method, scope_key=scope_key) 207 | deepvelo_adata.obs["celltype_consistency"] = deepvelo_adata.obs[ 208 | f"{vkey}_confidence_{method}" 209 | ].copy() 210 | 211 | 212 | # NOTE: example of logging metrics 213 | wandb.log( 214 | { 215 | "celltype_consistency": deepvelo_adata.obs["celltype_consistency"].mean(), 216 | "overall_consistency": deepvelo_adata.obs["overall_consistency"].mean(), 217 | } 218 | ) 219 | -------------------------------------------------------------------------------- /examples/sweep_robustness.yaml: -------------------------------------------------------------------------------- 1 | program: sweep_robustness.py 2 | method: random 3 | name: sweep_deepvelo 4 | project: deepvelo 5 | metric: 6 | name: celltype_consistency 7 | goal: maximize 8 | parameters: 9 | lr: 10 | distribution: uniform 11 | min: 0.0001 12 | max: 0.1 13 | # optimizer: 14 | # values: ["Adam", "SGD", "RMSprop"] 15 | topC: 16 | values: [5, 10, 20, 30, 40, 50] 17 | topG: 18 | values: [5, 10, 20, 30, 40, 50] 19 | layers: 20 | values: 21 | [[32, 32], [64, 64], [128, 128], [256, 256], [512, 512], [64, 64, 64]] 22 | pearson_scale: 23 | distribution: int_uniform 24 | min: 1 25 | max: 100 26 | pp_hvg: 27 | values: [500, 1000, 2000, 2500, 5000] 28 | pp_neighbors: 29 | values: [5, 10, 15, 20, 30, 40, 50] 30 | pp_pcs: 31 | values: [10, 20, 30, 40, 50] 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "deepvelo" 3 | version = "0.2.8" 4 | description = "Deep Velocity" 5 | authors = ["subercui "] 6 | readme = "README.md" 7 | license = "MIT" 8 | homepage = "https://github.com/bowang-lab/DeepVelo" 9 | repository = "https://github.com/bowang-lab/DeepVelo" 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.7.1,<3.10" 13 | scvelo = "^0.2.4" 14 | torch = ">=1.2,<1.13" 15 | umap-learn = ">=0.5.2,<=0.5.4" 16 | seaborn = "^0.11.2" 17 | adjustText = "^0.7.3" 18 | scanpy = "^1.8.2" 19 | numpy = "^1.21.1" 20 | tqdm = "^4.62.3" 21 | matplotlib = ">=3.3,<3.6" 22 | dgl = { version = ">=0.4,!=0.8.0.post1,<0.9", markers="extra!='gpu'" } 23 | dgl-cu102 = { version = ">=0.4,!=0.8.0.post1,<0.9", optional = true } 24 | igraph = "^0.9.10" 25 | hnswlib = "^0.6.2" 26 | 27 | [tool.poetry.dev-dependencies] 28 | ipykernel = "^6.7.0" 29 | cellrank = { version = "1.5.0", optional = true } 30 | 31 | [tool.poetry.extras] 32 | gpu = ["dgl-cu102"] 33 | cellrank = ["cellrank>=1.5.0"] 34 | 35 | [build-system] 36 | requires = ["poetry-core>=1.0.0"] 37 | build-backend = "poetry.core.masonry.api" 38 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import deepvelo.data_loader.data_loaders as module_data 5 | import deepvelo.model.loss as module_loss 6 | import deepvelo.model.metric as module_metric 7 | import deepvelo.model.model as module_arch 8 | from deepvelo.parse_config import ConfigParser 9 | 10 | 11 | def main(config): 12 | logger = config.get_logger("test") 13 | 14 | # setup data_loader instances 15 | data_loader = getattr(module_data, config["data_loader"]["type"])( 16 | config["data_loader"]["args"]["data_dir"], 17 | batch_size=512, 18 | shuffle=False,
19 | validation_split=0.0, 20 | training=False, 21 | num_workers=2, 22 | ) 23 | 24 | # build model architecture 25 | model = config.init_obj("arch", module_arch) 26 | logger.info(model) 27 | 28 | # get function handles of loss and metrics 29 | loss_fn = getattr(module_loss, config["loss"]) 30 | metric_fns = [getattr(module_metric, met) for met in config["metrics"]] 31 | 32 | logger.info("Loading checkpoint: {} ...".format(config.resume)) 33 | checkpoint = torch.load(config.resume) 34 | state_dict = checkpoint["state_dict"] 35 | if config["n_gpu"] > 1: 36 | model = torch.nn.DataParallel(model) 37 | model.load_state_dict(state_dict) 38 | 39 | # prepare model for testing 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | model = model.to(device) 42 | model.eval() 43 | 44 | total_loss = 0.0 45 | total_metrics = torch.zeros(len(metric_fns)) 46 | 47 | with torch.no_grad(): 48 | for i, (data, target) in enumerate(tqdm(data_loader)): 49 | data, target = data.to(device), target.to(device) 50 | output = model(data) 51 | 52 | # 53 | # save sample images, or do something with output here 54 | # 55 | 56 | # computing loss, metrics on test set 57 | loss = loss_fn(output, target) 58 | batch_size = data.shape[0] 59 | total_loss += loss.item() * batch_size 60 | for i, metric in enumerate(metric_fns): 61 | total_metrics[i] += metric(output, target) * batch_size 62 | 63 | n_samples = len(data_loader.sampler) 64 | log = {"loss": total_loss / n_samples} 65 | log.update( 66 | { 67 | met.__name__: total_metrics[i].item() / n_samples 68 | for i, met in enumerate(metric_fns) 69 | } 70 | ) 71 | logger.info(log) 72 | 73 | 74 | if __name__ == "__main__": 75 | args = argparse.ArgumentParser(description="PyTorch Template") 76 | args.add_argument( 77 | "-c", 78 | "--config", 79 | default=None, 80 | type=str, 81 | help="config file path (default: None)", 82 | ) 83 | args.add_argument( 84 | "-r", 85 | "--resume", 86 | default=None, 87 | type=str, 88 | help="path to latest checkpoint (default: None)", 89 | ) 90 | args.add_argument( 91 | "-d", 92 | "--device", 93 | default=None, 94 | type=str, 95 | help="indices of GPUs to enable (default: all)", 96 | ) 97 | 98 | config = ConfigParser.from_args(args) 99 | main(config) 100 | --------------------------------------------------------------------------------
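
For quick reference, the following is a minimal sketch of the high-level DeepVelo workflow exercised by the example scripts above (e.g. examples/sweep_robustness.py). The dataset and the configuration overrides are illustrative; any AnnData object with spliced and unspliced layers can be substituted.

```
import scvelo as scv

from deepvelo import train, Constants
from deepvelo.utils import velocity, update_dict
from deepvelo.utils.preprocess import autoset_coeff_s

# preprocessing, as in the example scripts
adata = scv.datasets.dentategyrus()
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)

# override only the fields of interest; Constants.default_configs holds the rest
configs = update_dict(
    Constants.default_configs,
    {
        "name": "DeepVelo",
        "loss": {"args": {"coeff_s": autoset_coeff_s(adata)}},
        "trainer": {"verbosity": 0},
    },
)

velocity(adata, mask_zero=False)  # initial steady-state velocity estimate
trainer = train(adata, configs)   # DeepVelo velocities are written back into adata

# downstream scVelo visualization on the updated velocities
scv.tl.velocity_graph(adata, n_jobs=8)
scv.pl.velocity_embedding_stream(adata, basis="umap", color="clusters", dpi=150)
```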