├── .DS_Store ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.rst ├── docs ├── .DS_Store ├── 00_hello_sagenet.ipynb ├── 01_multiple_references.ipynb ├── Makefile ├── _static │ ├── .DS_Store │ ├── custom.css │ └── img │ │ └── logo.png ├── about.rst ├── api.rst ├── conf.py ├── index.rst └── make.bat ├── figures ├── show_ad_r1_all.pdf └── show_ad_r1_all_conf.pdf ├── notebooks ├── 00_hello_sagenet.ipynb ├── 01_multiple_references.ipynb ├── SageNet_developing_human_heart_analysis.ipynb └── SageNet_mouse_gastrulation_analysis.ipynb ├── requirements.txt ├── sagenet ├── .DS_Store ├── DHH_data │ ├── _DHH_data.py │ └── __init__.py ├── DHH_data_ │ ├── _DHH_data_.py │ └── __init__.py ├── MGA_analysis.py ├── MGA_data │ ├── _MGA_data.py │ └── __init__.py ├── __init__.py ├── classifier.py ├── datasets │ ├── __init__.py │ └── _datasets.py ├── model.py ├── sage.py └── utils.py └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
-------------------------------------------------------------------------------- /.travis.yml: --------------------------------------------------------------------------------
language: python
dist: xenial
cache: pip
python:
  - "3.7"
  - "3.8"
  - "3.9"

install:
  - pip install -r requirements.txt
  - python setup.py install
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Marioni Laboratory

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------------- /README.rst: --------------------------------------------------------------------------------
|PyPI| |Docs|

SageNet: Spatial reconstruction of dissociated single-cell datasets using graph neural networks
================================================================================================
.. raw:: html

**SageNet** is a robust and generalizable graph neural network approach that probabilistically maps dissociated single cells from an scRNAseq dataset to their hypothetical tissue of origin, using one or more reference datasets acquired by spatially resolved transcriptomics techniques. It is compatible with both high-plex imaging (e.g., seqFISH, MERFISH) and spatial barcoding (e.g., 10x Visium, Slide-seq) datasets as the spatial reference.

.. raw:: html

   <!-- [image: sagenet logo] -->

SageNet is implemented with `pytorch <https://pytorch.org>`_ and `pytorch-geometric <https://pytorch-geometric.readthedocs.io>`_ to be modular, fast, and scalable. Also, it uses ``anndata`` to be compatible with `scanpy <https://scanpy.readthedocs.io>`_ and `squidpy <https://squidpy.readthedocs.io>`_ for pre- and post-processing steps.
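
Because the reference and query datasets are plain ``anndata`` objects, standard scanpy operations can be chained directly before and after SageNet. A minimal sketch (``sg.datasets.seqFISH1`` is one of the example datasets shipped with the package)::

    import sagenet as sg
    import scanpy as sc

    adata = sg.datasets.seqFISH1()   # the spatial reference is an ordinary AnnData object
    sc.pp.normalize_total(adata)     # any scanpy preprocessing applies as-is
    sc.pp.log1p(adata)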

.. raw:: html

   <!-- [image: sagenet workflow] -->




Installation
============


.. note::

   **v1.0**
   The dependency ``torch-geometric`` should be installed separately, to match the system specificities; see `this link `_ for instructions. We recommend using Miniconda.


GitHub (currently recommended)
------------------------------

First, clone the repository using ``git``::

    git clone https://github.com/MarioniLab/sagenet

Then, ``cd`` to the sagenet folder and run the install command::

    cd sagenet
    python setup.py install #or pip install .

PyPI
--------

The easiest way to get SageNet is through pip using the following command::

    pip install sagenet


Usage
============

::

    import sagenet as sg
    import scanpy as sc
    import squidpy as sq
    import anndata as ad
    import numpy as np
    import torch
    import random
    from sagenet.utils import glasso
    random.seed(10)


Training phase:
---------------


**Input:**

- Expression matrix associated with the (spatial) reference dataset (an ``anndata`` object)

::

    adata_r = sg.datasets.seqFISH1()


- gene-gene interaction network, estimated here with the graphical lasso

::

    glasso(adata_r, [0.5, 0.75, 1])


- one or more partitionings of the spatial reference into distinct connected neighborhoods of cells or spots

::

    adata_r.obsm['spatial'] = np.array(adata_r.obs[['x','y']])
    sq.gr.spatial_neighbors(adata_r, coord_type="generic")
    sc.tl.leiden(adata_r, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r.obsp["spatial_connectivities"])


**Training:**
::

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sg_obj = sg.sage.sage(device=device)
    sg_obj.add_ref(adata_r, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref', epochs=20, verbose=False)


**Output:**

- A set of pre-trained models (one for each partitioning)

::

    !mkdir models
    !mkdir models/seqFISH_ref
    sg_obj.save_model_as_folder('models/seqFISH_ref')


- A set of Spatially Informative Genes (SIGs)

::

    from matplotlib.pyplot import rc_context
    ind = np.argsort(-adata_r.var['seqFISH_ref_entropy'])[0:12]
    SIGs = list(adata_r.var_names[ind])
    with rc_context({'figure.figsize': (4, 4)}):
        sc.pl.spatial(adata_r, color=SIGs, ncols=4, spot_size=0.03, legend_loc=None)


.. raw:: html

   <!-- [image: spatial markers] -->




Mapping phase
---------------

**Input:**

- Expression matrix associated with the (dissociated) query dataset (an ``anndata`` object)

::

    adata_q = sg.MGA_data.scRNAseq()


**Mapping:**
::

    sg_obj.map_query(adata_q)


**Output:**

- The reconstructed cell-cell spatial distance matrix

::

    adata_q.obsm['dist_map']


- A consensus scoring of mappability (uncertainty of mapping) of each cell to the references

::

    adata_q.obs


::

    import anndata
    # celltype_colours is a dict mapping each cell type to a colour; see the tutorial notebooks for the full definition
    dist_adata = anndata.AnnData(adata_q.obsm['dist_map'], obs = adata_q.obs)
    knn_indices, knn_dists, forest = sc.neighbors.compute_neighbors_umap(dist_adata.X, n_neighbors=50, metric='precomputed')
    dist_adata.obsp['distances'], dist_adata.obsp['connectivities'] = sc.neighbors._compute_connectivities_umap(
        knn_indices,
        knn_dists,
        dist_adata.shape[0],
        50, # change to the number of neighbors you plan to use
    )
    sc.pp.neighbors(dist_adata, metric='precomputed', use_rep='X')
    sc.tl.umap(dist_adata)
    sc.pl.umap(dist_adata, color='cell_type', palette=celltype_colours)


.. raw:: html

   <!-- [image: reconstructed space] -->
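
As in ``sagenet/MGA_analysis.py``, the reconstructed distances can also be turned into explicit coordinates by borrowing the position of the best-matching reference cell for each query cell. A minimal sketch using ``sagenet.utils.map2ref`` (it assumes the query was mapped with ``sg_obj.map_query(adata_q, save_prob=True)`` and that ``adata_r`` carries ``obsm['spatial']``)::

    from sagenet.utils import map2ref

    ind, conf = map2ref(adata_r, adata_q)                       # best reference match per query cell
    adata_q.obsm['spatial'] = adata_r.obsm['spatial'][ind, :]   # borrow its spatial coordinates
    adata_q.obs['conf'] = np.log(conf)                          # log-confidence of the assignment
    sc.pl.spatial(adata_q, color='cell_type', spot_size=.1)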

218 | 219 | 220 | Notebooks 221 | ============ 222 | To see some examples of our pipeline's capability, look at the `notebooks `_ directory. The notebooks are also available on google colab: 223 | 224 | #. `Spatial reconstruction of the developing human heart `_ 225 | #. `Spatial reconstruction of the mouse embryo `_ 226 | 227 | Interactive examples 228 | ============ 229 | * `Spatial mapping of the mouse gastrulation atlas `_ 230 | 231 | 232 | Support and contribute 233 | ============ 234 | If you have a question or new architecture or a model that could be integrated into our pipeline, you can 235 | post an `issue `__ or reach us by `email `_. 236 | 237 | 238 | Contributions 239 | ============ 240 | This work is led by Elyas Heidari and Shila Ghazanfar as a joint effort between `MarioniLab@CRUK@EMBL-EBI `__ and `RobinsonLab@UZH `__. 241 | 242 | .. |Docs| image:: https://readthedocs.org/projects/sagenet/badge/?version=latest 243 | :target: https://sagenet.readthedocs.io 244 | 245 | .. |PyPI| image:: https://img.shields.io/pypi/v/sagenet.svg 246 | :target: https://pypi.org/project/sagenet 247 | 248 | .. |travis| image:: https://travis-ci.com/MarioniLab/sagenet.svg?branch=main 249 | :target: https://travis-ci.com/MarioniLab/sagenet 250 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/docs/.DS_Store -------------------------------------------------------------------------------- /docs/01_multiple_references.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Multiple references\n", 7 | "\n", 8 | "In this notebook we show installation and basic usage of **SageNet**. " 9 | ], 10 | "metadata": { 11 | "id": "B8TiAQi7hwXH" 12 | } 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "id": "5hXws4b_EPoR" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html \n", 23 | "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html \n", 24 | "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git " 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "id": "i8t1adiyJ51v" 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "!pwd" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "colab": { 43 | "base_uri": "https://localhost:8080/" 44 | }, 45 | "id": "Qgc-2ZeGEQIo", 46 | "outputId": "de2de8e9-a1f9-4d1f-c807-3e659586042f" 47 | }, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "fatal: destination path 'sagenet' already exists and is not an empty directory.\n", 54 | "/content/sagenet/sagenet\n", 55 | "\u001b[31mERROR: Directory '.' is not installable. Neither 'setup.py' nor 'pyproject.toml' found.\u001b[0m\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "!git clone https://github.com/MarioniLab/sagenet\n", 61 | "%cd sagenet\n", 62 | "!pip install ." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "id": "w-znoEjrERTp" 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "import sagenet as sg\n", 74 | "import scanpy as sc\n", 75 | "import squidpy as sq\n", 76 | "import anndata as ad\n", 77 | "import random\n", 78 | "random.seed(10)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "id": "stz_3hVlGuaF" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "celltype_colours = {\n", 90 | " \"Epiblast\" : \"#635547\",\n", 91 | " \"Primitive Streak\" : \"#DABE99\",\n", 92 | " \"Caudal epiblast\" : \"#9e6762\",\n", 93 | " \"PGC\" : \"#FACB12\",\n", 94 | " \"Anterior Primitive Streak\" : \"#c19f70\",\n", 95 | " \"Notochord\" : \"#0F4A9C\",\n", 96 | " \"Def. endoderm\" : \"#F397C0\",\n", 97 | " \"Definitive endoderm\" : \"#F397C0\",\n", 98 | " \"Gut\" : \"#EF5A9D\",\n", 99 | " \"Gut tube\" : \"#EF5A9D\",\n", 100 | " \"Nascent mesoderm\" : \"#C594BF\",\n", 101 | " \"Mixed mesoderm\" : \"#DFCDE4\",\n", 102 | " \"Intermediate mesoderm\" : \"#139992\",\n", 103 | " \"Caudal Mesoderm\" : \"#3F84AA\",\n", 104 | " \"Paraxial mesoderm\" : \"#8DB5CE\",\n", 105 | " \"Somitic mesoderm\" : \"#005579\",\n", 106 | " \"Pharyngeal mesoderm\" : \"#C9EBFB\",\n", 107 | " \"Splanchnic mesoderm\" : \"#C9EBFB\",\n", 108 | " \"Cardiomyocytes\" : \"#B51D8D\",\n", 109 | " \"Allantois\" : \"#532C8A\",\n", 110 | " \"ExE mesoderm\" : \"#8870ad\",\n", 111 | " \"Lateral plate mesoderm\" : \"#8870ad\",\n", 112 | " \"Mesenchyme\" : \"#cc7818\",\n", 113 | " \"Mixed mesenchymal mesoderm\" : \"#cc7818\",\n", 114 | " \"Haematoendothelial progenitors\" : \"#FBBE92\",\n", 115 | " \"Endothelium\" : \"#ff891c\",\n", 116 | " \"Blood progenitors 1\" : \"#f9decf\",\n", 117 | " \"Blood progenitors 2\" : \"#c9a997\",\n", 118 | " \"Erythroid1\" : \"#C72228\",\n", 119 | " \"Erythroid2\" : \"#f79083\",\n", 120 | " \"Erythroid3\" : \"#EF4E22\",\n", 121 | " \"Erythroid\" : \"#f79083\",\n", 122 | " \"Blood progenitors\" : \"#f9decf\",\n", 123 | " \"NMP\" : \"#8EC792\",\n", 124 | " \"Rostral neurectoderm\" : \"#65A83E\",\n", 125 | " \"Caudal neurectoderm\" : \"#354E23\",\n", 126 | " \"Neural crest\" : \"#C3C388\",\n", 127 | " \"Forebrain/Midbrain/Hindbrain\" : \"#647a4f\",\n", 128 | " \"Spinal cord\" : \"#CDE088\",\n", 129 | " \"Surface ectoderm\" : \"#f7f79e\",\n", 130 | " \"Visceral endoderm\" : \"#F6BFCB\",\n", 131 | " \"ExE endoderm\" : \"#7F6874\",\n", 132 | " \"ExE ectoderm\" : \"#989898\",\n", 133 | " \"Parietal endoderm\" : \"#1A1A1A\",\n", 134 | " \"Unknown\" : \"#FFFFFF\",\n", 135 | " \"Low quality\" : \"#e6e6e6\",\n", 136 | " # somitic and paraxial types\n", 137 | " # colour from T chimera paper Guibentif et al Developmental Cell 2021\n", 138 | " \"Cranial mesoderm\" : \"#77441B\",\n", 139 | " \"Anterior somitic tissues\" : \"#F90026\",\n", 140 | " \"Sclerotome\" : \"#A10037\",\n", 141 | " \"Dermomyotome\" : \"#DA5921\",\n", 142 | " \"Posterior somitic tissues\" : \"#E1C239\",\n", 143 | " \"Presomitic mesoderm\" : \"#9DD84A\"\n", 144 | "}" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": { 151 | "id": "yUWR8-aPESU0" 152 | }, 153 | "outputs": [], 154 | "source": [ 155 | "from copy import copy\n", 156 | "adata_r1 = sg.datasets.seqFISH1()\n", 157 | "adata_r2 = sg.datasets.seqFISH2()\n", 158 | "adata_r3 = sg.datasets.seqFISH3()\n", 159 | "adata_q1 = copy(adata_r1)\n", 160 | "adata_q2 = copy(adata_r2)\n", 161 | "adata_q3 = copy(adata_r3)\n", 
162 | "adata_q4 = sg.datasets.MGA()\n", 163 | "sc.pp.subsample(adata_q1, fraction=0.25)\n", 164 | "sc.pp.subsample(adata_q2, fraction=0.25)\n", 165 | "sc.pp.subsample(adata_q3, fraction=0.25)\n", 166 | "sc.pp.subsample(adata_q4, fraction=0.25)\n", 167 | "adata_q = ad.concat([adata_q1, adata_q2, adata_q3, adata_q4], join=\"inner\")\n", 168 | "del adata_q1 \n", 169 | "del adata_q2 \n", 170 | "del adata_q3 \n", 171 | "del adata_q4" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": { 178 | "id": "aU5MmeY8GrYE" 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "from sagenet.utils import glasso\n", 183 | "import numpy as np\n", 184 | "glasso(adata_r1, [0.5, 0.75, 1])\n", 185 | "adata_r1.obsm['spatial'] = np.array(adata_r1.obs[['x','y']])\n", 186 | "sq.gr.spatial_neighbors(adata_r1, coord_type=\"generic\")\n", 187 | "sc.tl.leiden(adata_r1, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n", 188 | "sc.tl.leiden(adata_r1, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n", 189 | "sc.tl.leiden(adata_r1, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n", 190 | "sc.tl.leiden(adata_r1, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n", 191 | "sc.tl.leiden(adata_r1, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n", 192 | "glasso(adata_r2, [0.5, 0.75, 1])\n", 193 | "adata_r2.obsm['spatial'] = np.array(adata_r2.obs[['x','y']])\n", 194 | "sq.gr.spatial_neighbors(adata_r2, coord_type=\"generic\")\n", 195 | "sc.tl.leiden(adata_r2, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n", 196 | "sc.tl.leiden(adata_r2, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n", 197 | "sc.tl.leiden(adata_r2, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n", 198 | "sc.tl.leiden(adata_r2, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n", 199 | "sc.tl.leiden(adata_r2, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n", 200 | "glasso(adata_r3, [0.5, 0.75, 1])\n", 201 | "adata_r3.obsm['spatial'] = np.array(adata_r3.obs[['x','y']])\n", 202 | "sq.gr.spatial_neighbors(adata_r3, coord_type=\"generic\")\n", 203 | "sc.tl.leiden(adata_r3, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n", 204 | "sc.tl.leiden(adata_r3, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n", 205 | "sc.tl.leiden(adata_r3, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n", 206 | "sc.tl.leiden(adata_r3, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n", 207 | "sc.tl.leiden(adata_r3, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r3.obsp[\"spatial_connectivities\"])" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": { 214 | "colab": { 215 | "base_uri": "https://localhost:8080/" 216 | }, 217 | "id": 
"r_St4TwlIcAW", 218 | "outputId": "e289e9c1-a225-419b-842d-6bfc5b4bf3d3" 219 | }, 220 | "outputs": [ 221 | { 222 | "name": "stdout", 223 | "output_type": "stream", 224 | "text": [ 225 | "cpu\n" 226 | ] 227 | } 228 | ], 229 | "source": [ 230 | "import torch\n", 231 | "if torch.cuda.is_available(): \n", 232 | " dev = \"cuda:0\" \n", 233 | "else: \n", 234 | " dev = \"cpu\" \n", 235 | "device = torch.device(dev)\n", 236 | "print(device)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": { 243 | "id": "WRcnUb8wIduW" 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "sg_obj = sg.sage.sage(device=device)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": { 254 | "colab": { 255 | "background_save": true, 256 | "base_uri": "https://localhost:8080/" 257 | }, 258 | "id": "BxX5nvMpISVZ", 259 | "outputId": "cd6f96d8-ea7c-4bbe-e48d-4a08c4d191e5" 260 | }, 261 | "outputs": [ 262 | { 263 | "name": "stderr", 264 | "output_type": "stream", 265 | "text": [ 266 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 267 | " warnings.warn(out)\n", 268 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 269 | " warnings.warn(out)\n", 270 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 271 | " warnings.warn(out)\n", 272 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 273 | " warnings.warn(out)\n", 274 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 275 | " _warn_prf(average, modifier, msg_start, len(result))\n", 276 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 277 | " _warn_prf(average, modifier, msg_start, len(result))\n", 278 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 279 | " warnings.warn(out)\n", 280 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 281 | " _warn_prf(average, modifier, msg_start, len(result))\n", 282 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 283 | " _warn_prf(average, modifier, msg_start, len(result))\n", 284 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. 
Use `zero_division` parameter to control this behavior.\n", 285 | " _warn_prf(average, modifier, msg_start, len(result))\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "sg_obj.add_ref(adata_r1, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref1', epochs=15, verbose = False)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": { 297 | "colab": { 298 | "background_save": true 299 | }, 300 | "id": "YZ4lN_0kIWwO", 301 | "outputId": "499b9ceb-8c69-4200-9468-985587f43f32" 302 | }, 303 | "outputs": [ 304 | { 305 | "name": "stderr", 306 | "output_type": "stream", 307 | "text": [ 308 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 309 | " warnings.warn(out)\n", 310 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 311 | " warnings.warn(out)\n", 312 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 313 | " warnings.warn(out)\n", 314 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 315 | " warnings.warn(out)\n", 316 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 317 | " _warn_prf(average, modifier, msg_start, len(result))\n", 318 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n", 319 | " warnings.warn(out)\n", 320 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 321 | " _warn_prf(average, modifier, msg_start, len(result))\n", 322 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", 323 | " _warn_prf(average, modifier, msg_start, len(result))\n", 324 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. 
Use `zero_division` parameter to control this behavior.\n", 325 | " _warn_prf(average, modifier, msg_start, len(result))\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "sg_obj.add_ref(adata_r2, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref2', epochs=15, verbose = False)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": { 337 | "id": "sA9dF1vZIXLV" 338 | }, 339 | "outputs": [], 340 | "source": [ 341 | "sg_obj.add_ref(adata_r3, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref3', epochs=15, verbose = False)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "metadata": { 348 | "id": "qQjIRmPnI4F2" 349 | }, 350 | "outputs": [], 351 | "source": [ 352 | "ind = np.argsort(-(adata_r.var['seqFISH_ref_entropy']+ adata_r.var['seqFISH_ref2_entropy'] + adata_r.var['seqFISH_ref3_entropy']))[0:12]\n", 353 | "with rc_context({'figure.figsize': (4, 4)}):\n", 354 | " sc.pl.spatial(adata_r, color=list(adata_r.var_names[ind]), ncols=4, spot_size=0.03, legend_loc=None)" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": { 361 | "id": "qOWVKZw9JDPu" 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "!mkdir models\n", 366 | "!mkdir models/seqFISH_ref\n", 367 | "sg_obj.save_model_as_folder('models/seqFISH_multiple_ref')" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "metadata": { 374 | "id": "s5rn2FxrJUNN" 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "sg_obj.map_query(adata_q)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "id": "WN1vAbpuJW74" 386 | }, 387 | "outputs": [], 388 | "source": [ 389 | "import anndata\n", 390 | "dist_adata = anndata.AnnData(adata_q.obsm['dist_map'], obs = adata_q.obs)\n", 391 | "knn_indices, knn_dists, forest = sc.neighbors.compute_neighbors_umap(dist_adata.X, n_neighbors=50, metric='precomputed')\n", 392 | "dist_adata.obsp['distances'], dist_adata.obsp['connectivities'] = sc.neighbors._compute_connectivities_umap(\n", 393 | " knn_indices,\n", 394 | " knn_dists,\n", 395 | " dist_adata.shape[0],\n", 396 | " 50, # change to neighbors you plan to use\n", 397 | ")\n", 398 | "sc.pp.neighbors(dist_adata, metric='precomputed', use_rep='X')\n", 399 | "sc.tl.umap(dist_adata)\n", 400 | "sc.pl.umap(dist_adata, color='cell_type', palette=celltype_colours, save='eli.pdf')" 401 | ] 402 | } 403 | ], 404 | "metadata": { 405 | "colab": { 406 | "name": "01_multiple_references.ipynb", 407 | "provenance": [], 408 | "collapsed_sections": [] 409 | }, 410 | "kernelspec": { 411 | "display_name": "Python 3", 412 | "name": "python3" 413 | }, 414 | "language_info": { 415 | "name": "python" 416 | } 417 | }, 418 | "nbformat": 4, 419 | "nbformat_minor": 0 420 | } -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/docs/_static/.DS_Store -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* ReadTheDocs theme colors */ 2 | /* Copied from https://github.com/theislab/scvelo */ 3 | 4 | .wy-nav-top { 5 | background-color: #404040 6 | } 7 | 8 | .wy-nav-content { 9 | max-width: 950px 10 | } 11 | 12 | .wy-side-nav-search { 13 | background-color: transparent 14 | } 15 | 16 | .wy-side-nav-search input[type="text"] { 17 | border-width: 0 18 | } 19 | 20 | 21 | /* Custom classes */ 22 | .small { 23 | font-size: 40% 24 | } 25 | 26 | .smaller, .pr { 27 | font-size: 70% 28 | } 29 | 30 | 31 | /* Custom classes with bootstrap buttons */ 32 | 33 | .tutorial, 34 | .tutorial:visited, 35 | .tutorial:hover { 36 | /* text-decoration: underline; */ 37 | font-weight: bold; 38 | padding: 2px 5px; 39 | white-space: nowrap; 40 | max-width: 100%; 41 | background: #EF3270; 42 | border: solid 1px #EF3270; 43 | border-radius: .25rem; 44 | font-size: 75%; 45 | /* font-family: SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace; */ 46 | color: #404040; 47 | overflow-x: auto; 48 | box-sizing: border-box; 49 | } 50 | 51 | 52 | /* Formatting of RTD markup: rubrics and sidebars and admonitions */ 53 | 54 | /* rubric */ 55 | .rst-content p.rubric { 56 | margin-bottom: 6px; 57 | font-weight: normal; 58 | } 59 | 60 | .rst-content p.rubric::after { 61 | content: ":" 62 | } 63 | 64 | /* sidebar */ 65 | .rst-content .sidebar { 66 | /* margin: 0px 0px 0px 12px; */ 67 | padding-bottom: 0px; 68 | } 69 | 70 | .rst-content .sidebar p { 71 | margin-bottom: 12px; 72 | } 73 | 74 | .rst-content .sidebar p, 75 | .rst-content .sidebar ul, 76 | .rst-content .sidebar dl { 77 | font-size: 13px; 78 | } 79 | 80 | /* less space after bullet lists in admonitions like warnings and notes */ 81 | .rst-content .section .admonition ul { 82 | margin-bottom: 6px; 83 | } 84 | 85 | 86 | /* Code: literals and links */ 87 | 88 | .rst-content tt.literal, 89 | .rst-content code.literal { 90 | color: #404040; 91 | } 92 | 93 | /* slim font weight for non-link code */ 94 | .rst-content tt:not(.xref), 95 | .rst-content code:not(.xref), 96 | .rst-content *:not(a) > tt.xref, 97 | .rst-content *:not(a) > code.xref, 98 | .rst-content a > tt.xref, 99 | .rst-content a > code.xref, 100 | .rst-content dl:not(.docutils) a > tt.xref, 101 | 102 | 103 | /* Just one box for annotation code for a less noisy look */ 104 | 105 | .rst-content .annotation { 106 | padding: 2px 5px; 107 | background-color: white; 108 | border: 1px solid #e1e4e5; 109 | } 110 | 111 | .rst-content .annotation tt, 112 | .rst-content .annotation code { 113 | padding: 0 0; 114 | background-color: transparent; 115 | border: 0 solid transparent; 116 | } 117 | 118 | 119 | /* Parameter lists */ 120 | 121 | 
.rst-content dl:not(.docutils) dl dt { 122 | /* mimick numpydoc’s blockquote style */ 123 | font-weight: normal; 124 | background: none transparent; 125 | border-left: none; 126 | margin: 0 0 12px; 127 | padding: 3px 0 0; 128 | font-size: 100%; 129 | } 130 | 131 | .rst-content dl:not(.docutils) dl dt code { 132 | font-size: 100%; 133 | font-weight: normal; 134 | background: none transparent; 135 | border: none; 136 | padding: 0 2px; 137 | } 138 | 139 | .rst-content dl:not(.docutils) dl dt a.reference > code { 140 | text-decoration: underline; 141 | } 142 | 143 | /* Mimick rubric style used for other headings */ 144 | .rst-content dl:not(.docutils) dl > dt { 145 | font-weight: bold; 146 | background: none transparent; 147 | border-left: none; 148 | margin: 0 0 12px; 149 | padding: 3px 0 0; 150 | font-size: 100%; 151 | } 152 | 153 | /* Parameters contain parts and don’t need bold font */ 154 | .rst-content dl.field-list dl > dt { 155 | font-weight: unset 156 | } 157 | 158 | /* Add colon between return tuple element name and type */ 159 | .rst-content dl:not(.docutils) dl > dt .classifier::before { 160 | content: ' : ' 161 | } 162 | 163 | /* Function headers */ 164 | 165 | .rst-content dl:not(.docutils) dt { 166 | background: #edf0f2; 167 | color: #404040; 168 | border-top: solid 3px #343131; 169 | } 170 | 171 | .rst-content .section ul li p:last-child { 172 | margin-bottom: 0; 173 | margin-top: 0; 174 | } 175 | 176 | /* Copy buttons */ 177 | a.copybtn { 178 | position: absolute; 179 | top: -6.5px; 180 | right: 0px; 181 | width: 1em; 182 | height: 1em; 183 | padding: .15em; 184 | opacity: .4; 185 | transition: opacity 0.5s; 186 | } 187 | 188 | 189 | /* Remove prompt line numbers */ 190 | .nbinput > :first-child, 191 | .nboutput > :first-child { 192 | min-width: 0 !important; 193 | } 194 | 195 | /* Adjust width of navigation bar on mobile */ 196 | @media screen and (max-width: 768px) { 197 | .header-bar { 198 | display: none; 199 | } 200 | 201 | .wy-nav-content-wrap { 202 | margin-left: 0px; 203 | } 204 | 205 | .wy-nav-side { 206 | width: 300px; 207 | } 208 | 209 | .wy-nav-side.shift { 210 | max-width: 320px; 211 | } 212 | 213 | /* Fix sidebar adjust */ 214 | .rst-versions { 215 | width: 40%; 216 | max-width: 320px; 217 | } 218 | } 219 | 220 | /* Handle landscape */ 221 | @media screen and (min-width: 377px) { 222 | .wy-nav-content-wrap.shift { 223 | left: 320px; 224 | } 225 | } -------------------------------------------------------------------------------- /docs/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/docs/_static/img/logo.png -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | 2 | 3 | SageNet: Single-cell Spatial Locator 4 | ========================================================================= 5 | .. raw:: html 6 | 7 | **SageNet** is a robust and generalizable graph neural network approach that probabilistically maps dissociated single cells from an scRNAseq dataset to their hypothetical tissue of origin using one or more reference datasets aquired by spatially resolved transcriptomics techniques. It is compatible with both high-plex imaging (e.g., seqFISH, MERFISH, etc.) and spatial barcoding (e.g., 10X visium, Slide-seq, etc.) datasets as the spatial reference. 8 | 9 | 10 | .. raw:: html 11 | 12 |

   <!-- [image: sagenet logo] -->

SageNet is implemented with `pytorch <https://pytorch.org>`_ and `pytorch-geometric <https://pytorch-geometric.readthedocs.io>`_ to be modular, fast, and scalable. Also, it uses ``anndata`` to be compatible with `scanpy <https://scanpy.readthedocs.io>`_ and `squidpy <https://squidpy.readthedocs.io>`_ for pre- and post-processing steps.


Installation
============


.. note::

   **0.1.0**
   The dependency ``torch-geometric`` should be installed separately, to match the system specificities; see `this link `_ for instructions. We recommend using Miniconda.

PyPI
--------

The easiest way to get SageNet is through pip using the following command::

    pip install sagenet

Development
---------------

First, clone the repository using ``git``::

    git clone https://github.com/MarioniLab/sagenet

Then, ``cd`` to the sagenet folder and run the install command::

    cd sagenet
    python setup.py install #or pip install .


Usage
============
::

    import sagenet as sg
    import scanpy as sc
    import squidpy as sq
    import anndata as ad
    import numpy as np
    import torch
    import random
    from sagenet.utils import glasso
    random.seed(10)


Training phase:
---------------


**Input:**

- Expression matrix associated with the (spatial) reference dataset (an ``anndata`` object)

::

    adata_r = sg.datasets.seqFISH1()


- gene-gene interaction network

::

    glasso(adata_r, [0.5, 0.75, 1])


- one or more partitionings of the spatial reference into distinct connected neighborhoods of cells or spots

::

    adata_r.obsm['spatial'] = np.array(adata_r.obs[['x','y']])
    sq.gr.spatial_neighbors(adata_r, coord_type="generic")
    sc.tl.leiden(adata_r, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r.obsp["spatial_connectivities"])
    sc.tl.leiden(adata_r, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r.obsp["spatial_connectivities"])


**Training:**
::

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sg_obj = sg.sage.sage(device=device)
    sg_obj.add_ref(adata_r, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref', epochs=20, verbose=False)


**Output:**

- A set of pre-trained models (one for each partitioning)

::

    !mkdir models
    !mkdir models/seqFISH_ref
    sg_obj.save_model_as_folder('models/seqFISH_ref')


- A consensus scoring of the spatial informativeness of each gene

::

    from matplotlib.pyplot import rc_context
    ind = np.argsort(-adata_r.var['seqFISH_ref_entropy'])[0:12]
    with rc_context({'figure.figsize': (4, 4)}):
        sc.pl.spatial(adata_r, color=list(adata_r.var_names[ind]), ncols=4, spot_size=0.03, legend_loc=None)


.. raw:: html

   <!-- [image: spatial markers] -->




Mapping phase
---------------

**Input:**

- Expression matrix associated with the (dissociated) query dataset (an ``anndata`` object)

::

    adata_q = sg.datasets.MGA()


**Mapping:**
::

    sg_obj.map_query(adata_q)


**Output:**

- The reconstructed cell-cell spatial distance matrix

::

    adata_q.obsm['dist_map']


- A consensus scoring of mappability (uncertainty of mapping) of each cell to the references

::

    adata_q.obs


::

    import anndata
    # celltype_colours is a dict mapping each cell type to a colour; see the tutorial notebooks for the full definition
    dist_adata = anndata.AnnData(adata_q.obsm['dist_map'], obs = adata_q.obs)
    knn_indices, knn_dists, forest = sc.neighbors.compute_neighbors_umap(dist_adata.X, n_neighbors=50, metric='precomputed')
    dist_adata.obsp['distances'], dist_adata.obsp['connectivities'] = sc.neighbors._compute_connectivities_umap(
        knn_indices,
        knn_dists,
        dist_adata.shape[0],
        50, # change to the number of neighbors you plan to use
    )
    sc.pp.neighbors(dist_adata, metric='precomputed', use_rep='X')
    sc.tl.umap(dist_adata)
    sc.pl.umap(dist_adata, color='cell_type', palette=celltype_colours)


.. raw:: html

   <!-- [image: reconstructed space] -->

201 | 202 | 203 | Notebooks 204 | ============ 205 | To see some examples of our pipeline's capability, look at the `notebooks `_ directory. The notebooks are also avaialble on google colab: 206 | 207 | #. `Intro to SageNet `_ 208 | #. `Using multiple references `_ 209 | 210 | Interactive examples 211 | ============ 212 | * `Spatial mapping of the mouse gastrulation atlas `_ 213 | 214 | 215 | Support and contribute 216 | ============ 217 | If you have a question or new architecture or a model that could be integrated into our pipeline, you can 218 | post an `issue `__ or reach us by `email `_. 219 | 220 | 221 | Contributions 222 | ============ 223 | This work is led by Elyas Heidari and Shila Ghazanfar as a joint effort between `MarioniLab@CRUK@EMBL-EBI `__ and `RobinsonLab@UZH `__. 224 | 225 | .. |Docs| image:: https://readthedocs.org/projects/sagenet/badge/?version=latest 226 | :target: https://sagenet.readthedocs.io 227 | 228 | .. |PyPI| image:: https://img.shields.io/pypi/v/sagenet.svg 229 | :target: https://pypi.org/project/sagenet 230 | 231 | 232 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | === 3 | 4 | The API reference contains detailed descriptions of the different end-user classes, functions, methods, etc. 5 | 6 | 7 | .. note:: 8 | 9 | This API reference only contains end-user documentation. 10 | If you are looking to hack away at sagenet' internals, you will find more detailed comments in the source code. 11 | 12 | 13 | * `sage`_ 14 | * `classifier`_ 15 | * `utils`_ 16 | 17 | sage 18 | ---- 19 | 20 | .. automodule:: sagenet.sage 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | 26 | classifier 27 | ---------- 28 | 29 | .. automodule:: sagenet.classifier 30 | :members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | 35 | utils 36 | ----- 37 | 38 | .. automodule:: sagenet.utils 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import inspect 14 | import os 15 | import sys 16 | from datetime import datetime 17 | 18 | sys.path.insert(0, os.path.abspath('..')) 19 | 20 | # -- Readthedocs theme ------------------------------------------------------- 21 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 22 | 23 | if not on_rtd: # only import and set the theme if we're building docs locally 24 | import sphinx_rtd_theme 25 | 26 | html_theme = 'sphinx_rtd_theme' 27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 28 | 29 | import sagenet 30 | 31 | # -- Retrieve notebooks ------------------------------------------------------ 32 | 33 | from urllib.request import urlretrieve 34 | 35 | notebooks_url = 'https://github.com/MarioniLab/sagenet/tree/main/notebooks' 36 | notebooks = [ 37 | '00_hello_sagenet.ipynb', 38 | '01_multiple_references.ipynb' 39 | ] 40 | 41 | for nb in notebooks: 42 | try: 43 | urlretrieve(notebooks_url + nb, nb) 44 | except: 45 | pass 46 | 47 | # -- Project information ----------------------------------------------------- 48 | 49 | project = 'sagenet' 50 | author = 'Elyas Heidari' 51 | copyright = f'{datetime.now():%Y}, ' + author 52 | 53 | pygments_style = 'sphinx' 54 | todo_include_todos = True 55 | html_theme_options = dict(navigation_depth=3, titles_only=False) 56 | html_context = dict( 57 | display_github=True, 58 | github_user='MarioniLab', 59 | github_repo='sagenet', 60 | github_version='master', 61 | conf_py_path='/docs/', 62 | ) 63 | html_static_path = ['_static'] 64 | html_logo = '_static/img/logo.png' 65 | 66 | def setup(app): 67 | app.add_css_file('custom.css') 68 | 69 | # -- General configuration --------------------------------------------------- 70 | 71 | # Add any Sphinx extension module names here, as strings. They can be 72 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 73 | # ones. 74 | extensions = [ 75 | 'sphinx.ext.autodoc', 76 | 'sphinx.ext.doctest', 77 | 'nbsphinx', 78 | 'sphinx.ext.napoleon', 79 | 'sphinx.ext.todo', 80 | 'sphinx.ext.mathjax', 81 | 'sphinx.ext.graphviz', 82 | 'sphinx.ext.intersphinx', 83 | 'sphinx.ext.linkcode', 84 | 'sphinx_rtd_theme', 85 | 'numpydoc', 86 | ] 87 | 88 | add_module_names = True 89 | autosummary_generate = True 90 | numpydoc_show_class_members = True 91 | 92 | intersphinx_mapping = { 93 | 'python': ('https://docs.python.org/3', None), 94 | 'anndata': ('https://anndata.readthedocs.io/en/latest/', None), 95 | 'numpy': ('https://numpy.readthedocs.io/en/latest/', None), 96 | 'scanpy': ('https://scanpy.readthedocs.io/en/latest/', None), 97 | } 98 | 99 | # Add any paths that contain templates here, relative to this directory. 100 | templates_path = ['_templates'] 101 | 102 | # List of patterns, relative to source directory, that match files and 103 | # directories to ignore when looking for source files. 104 | # This pattern also affects html_static_path and html_extra_path. 
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']


def linkcode_resolve(domain, info):
    """
    Determine the URL corresponding to a Python object.
    """
    if domain != 'py':
        return None

    modname = info['module']
    fullname = info['fullname']

    submod = sys.modules.get(modname)
    if submod is None:
        return None

    obj = submod
    for part in fullname.split('.'):
        try:
            obj = getattr(obj, part)
        except Exception:
            return None

    try:
        fn = inspect.getsourcefile(obj)
    except Exception:
        fn = None
    if not fn:
        return None

    try:
        source, lineno = inspect.findsource(obj)
    except Exception:
        lineno = None

    if lineno:
        linespec = "#L%d" % (lineno + 1)
    else:
        linespec = ""

    fn = os.path.relpath(fn, start=os.path.dirname(sagenet.__file__))

    github = f"https://github.com/MarioniLab/sagenet/blob/main/sagenet/{fn}{linespec}"
    return github
-------------------------------------------------------------------------------- /docs/index.rst: --------------------------------------------------------------------------------
.. sagenet documentation master file, created by
   sphinx-quickstart on Tue Dec 7 05:28:34 2021.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

SageNet: Spatial reconstruction of single-cell dissociated datasets using graph neural networks
================================================================================================
.. raw:: html

**SageNet** is a robust and generalizable graph neural network approach that probabilistically maps dissociated single cells from an scRNAseq dataset to their hypothetical tissue of origin, using one or more reference datasets acquired by spatially resolved transcriptomics techniques. It is compatible with both high-plex imaging (e.g., seqFISH, MERFISH) and spatial barcoding (e.g., 10x Visium, Slide-seq) datasets as the spatial reference.


.. raw:: html

   <!-- [image: sagenet logo] -->

SageNet is implemented with `pytorch <https://pytorch.org>`_ and `pytorch-geometric <https://pytorch-geometric.readthedocs.io>`_ to be modular, fast, and scalable. Also, it uses ``anndata`` to be compatible with `scanpy <https://scanpy.readthedocs.io>`_ and `squidpy <https://squidpy.readthedocs.io>`_ for pre- and post-processing steps.

Installation
-------------------------------
You can get the latest development version of our toolkit from `GitHub `_ using the following steps:

First, clone the repository using ``git``::

    git clone https://github.com/MarioniLab/sagenet

Then, ``cd`` to the sagenet folder and run the install command::

    cd sagenet
    python setup.py install #or pip install .


The dependency ``torch-geometric`` should be installed separately, to match the system specificities; see `this link `_ for instructions.
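
For example, the tutorial notebooks in this repository install version-matched wheels on a CUDA 11.1 Colab runtime as follows (adjust the torch/CUDA tags to your own system)::

    pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
    pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
    pip install git+https://github.com/pyg-team/pytorch_geometric.git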

.. raw:: html

   <!-- [image: activations logo] -->
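
After installing, a quick sanity check is to import the package and print its version (``__version__`` is defined in ``sagenet/__init__.py``)::

    python -c "import sagenet; print(sagenet.__version__)"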


Notebooks
-------------------------------
To see some examples of our pipeline's capability, look at the `notebooks `_ directory. The notebooks are also available on Google Colab:

#. `Intro to SageNet `_
#. `Using multiple references `_

Interactive examples
-------------------------------
See `this `_


Support and contribute
-------------------------------
If you have a question, or a new architecture or model that could be integrated into our pipeline, you can
post an `issue `__ or reach us by `email `_.


Contributions
-------------------------------
This work is led by Elyas Heidari and Shila Ghazanfar as a joint effort between `MarioniLab@CRUK@EMBL-EBI `__ and `RobinsonLab@UZH `__.




.. toctree::
   :maxdepth: 1
   :caption: Main
   :hidden:

   about
   installation
   api.rst

.. toctree::
   :maxdepth: 1
   :caption: Examples
   :hidden:

   01_multiple_references
   00_hello_sagenet

-------------------------------------------------------------------------------- /docs/make.bat: --------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /figures/show_ad_r1_all.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/figures/show_ad_r1_all.pdf -------------------------------------------------------------------------------- /figures/show_ad_r1_all_conf.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/figures/show_ad_r1_all_conf.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8.0 2 | captum>=0.4 3 | tensorboard>=2.6 4 | numpydoc>=1.1 5 | anndata 6 | gglasso 7 | squidpy 8 | 9 | 10 | -------------------------------------------------------------------------------- /sagenet/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/sagenet/.DS_Store -------------------------------------------------------------------------------- /sagenet/DHH_data/_DHH_data.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from squidpy.datasets._utils import AMetadata 3 | 4 | ST = AMetadata( 5 | name="ST", 6 | doc_header="", 7 | # shape=(270876, 43), 8 | url="https://figshare.com/ndownloader/files/31796207", 9 | ) 10 | 11 | scRNAseq = AMetadata( 12 | name="scRNAseq", 13 | doc_header="", 14 | # shape=(270876, 43), 15 | url="https://figshare.com/ndownloader/files/31796219", 16 | ) 17 | 18 | 19 | 20 | 21 | for name, var in copy(locals()).items(): 22 | if isinstance(var, AMetadata): 23 | var._create_function(name, globals()) 24 | 25 | 26 | __all__ = [ # noqa: F822 27 | "ST", 28 | "scRNAseq", 29 | 30 | ] 31 | -------------------------------------------------------------------------------- /sagenet/DHH_data/__init__.py: -------------------------------------------------------------------------------- 1 | from sagenet.DHH_data._DHH_data import * 2 | -------------------------------------------------------------------------------- /sagenet/DHH_data_/_DHH_data_.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from squidpy.datasets._utils import AMetadata 3 | 4 | ST = AMetadata( 5 | name="ST", 6 | doc_header="", 7 | # shape=(270876, 43), 8 | url="https://figshare.com/ndownloader/files/31796207", 9 | ) 10 | 11 | scRNAseq = AMetadata( 12 | name="scRNAseq", 13 | doc_header="", 14 | # shape=(270876, 43), 15 | url="https://figshare.com/ndownloader/files/31796219", 16 | ) 17 | 18 | 19 | 20 | 21 | for name, var in copy(locals()).items(): 22 | if isinstance(var, AMetadata): 23 | var._create_function(name, globals()) 24 | 25 | 26 | __all__ = [ # noqa: F822 27 | "ST", 28 | "scRNAseq", 29 | 30 | ] 31 | -------------------------------------------------------------------------------- /sagenet/DHH_data_/__init__.py: 
--------------------------------------------------------------------------------
 1 | from sagenet.DHH_data_._DHH_data_ import *
 2 | 
--------------------------------------------------------------------------------
/sagenet/MGA_analysis.py:
--------------------------------------------------------------------------------
 1 | import sagenet as sg
 2 | from sagenet.utils import map2ref, glasso
 3 | import numpy as np
 4 | import squidpy as sq
 5 | import scanpy as sc
 6 | import torch
 7 | import anndata as ad
 8 | import random  # needed for random.seed below
 9 | random.seed(1996)
10 | 
11 | 
12 | adata_r1 = sg.MGA_data.seqFISH1_1()
13 | adata_r2 = sg.MGA_data.seqFISH2_1()
14 | adata_r3 = sg.MGA_data.seqFISH3_1()
15 | adata_q1 = sg.MGA_data.seqFISH1_2()
16 | adata_q2 = sg.MGA_data.seqFISH2_2()
17 | adata_q3 = sg.MGA_data.seqFISH3_2()
18 | adata_q = sg.MGA_data.scRNAseq()
19 | 
20 | 
21 | # Map all datasets onto the first reference (seqFISH1_1)
22 | glasso(adata_r1)
23 | adata_r1.obsm['spatial'] = np.array(adata_r1.obs[['x','y']])
24 | sq.gr.spatial_neighbors(adata_r1, coord_type="generic")
25 | sc.tl.leiden(adata_r1, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r1.obsp["spatial_connectivities"])
26 | sc.tl.leiden(adata_r1, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r1.obsp["spatial_connectivities"])
27 | sc.tl.leiden(adata_r1, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r1.obsp["spatial_connectivities"])
28 | sc.tl.leiden(adata_r1, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r1.obsp["spatial_connectivities"])
29 | sc.tl.leiden(adata_r1, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r1.obsp["spatial_connectivities"])
30 | 
31 | if torch.cuda.is_available():
32 |     dev = "cuda:0"
33 | else:
34 |     dev = "cpu"
35 | 
36 | 
37 | device = torch.device(dev)
38 | print(device)
39 | 
40 | sg_obj = sg.sage.sage(device=device)
41 | sg_obj.add_ref(adata_r1, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='embryo1_2', epochs=20, verbose=True, classifier='GraphSAGE')
42 | 
43 | 
44 | sg_obj.map_query(adata_r1, save_prob=True)
45 | ind, conf = map2ref(adata_r1, adata_r1)
46 | adata_r1.obsm['spatial_pred'] = adata_r1.obsm['spatial'][ind,:]
47 | adata_r1.obs['conf'] = np.log(conf)
48 | sg_obj.map_query(adata_r2, save_prob=True)
49 | ind, conf = map2ref(adata_r1, adata_r2)
50 | adata_r2.obsm['spatial'] = adata_r1.obsm['spatial'][ind,:]
51 | adata_r2.obs['conf'] = np.log(conf)
52 | sg_obj.map_query(adata_r3, save_prob=True)
53 | ind, conf = map2ref(adata_r1, adata_r3)
54 | adata_r3.obsm['spatial'] = adata_r1.obsm['spatial'][ind,:]
55 | adata_r3.obs['conf'] = np.log(conf)
56 | sg_obj.map_query(adata_q1, save_prob=True)
57 | ind, conf = map2ref(adata_r1, adata_q1)
58 | adata_q1.obsm['spatial'] = adata_r1.obsm['spatial'][ind,:]
59 | adata_q1.obs['conf'] = np.log(conf)
60 | sg_obj.map_query(adata_q2, save_prob=True)
61 | ind, conf = map2ref(adata_r1, adata_q2)
62 | adata_q2.obsm['spatial'] = adata_r1.obsm['spatial'][ind,:]
63 | adata_q2.obs['conf'] = np.log(conf)
64 | sg_obj.map_query(adata_q3, save_prob=True)
65 | ind, conf = map2ref(adata_r1, adata_q3)
66 | adata_q3.obsm['spatial'] = adata_r1.obsm['spatial'][ind,:]
67 | adata_q3.obs['conf'] = np.log(conf)
68 | sg_obj.map_query(adata_q, save_prob=True)
69 | ind, conf = map2ref(adata_r1, adata_q)
70 | adata_q.obsm['spatial'] = adata_r1.obsm['spatial'][ind,:]
71 | adata_q.obs['conf'] = np.log(conf)
72 | 
73 | adata_r1.obsm['spatial'] = adata_r1.obsm['spatial_pred']
74 | 
75 | celltype_colours = None  # placeholder; replace with a {cell_type: colour} mapping for custom colours
76 | ad_concat = ad.concat([adata_r1, adata_r2, adata_r3, adata_q1, adata_q2, adata_q3, adata_q], label='batch')
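# Editorial note: the eight map-and-place blocks above all follow one pattern;
# a hypothetical helper capturing it (the name and signature are illustrative,
# not part of the library) could read:
#
#   def map_and_place(sg_obj, adata_ref, adata, key='spatial'):
#       sg_obj.map_query(adata, save_prob=True)
#       ind, conf = map2ref(adata_ref, adata)
#       adata.obsm[key] = adata_ref.obsm['spatial'][ind, :]
#       adata.obs['conf'] = np.log(conf)
#
# e.g. map_and_place(sg_obj, adata_r1, adata_r1, key='spatial_pred') for the
# reference itself, and key='spatial' for the six query datasets.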
77 | 
78 | sc.pl.spatial(
79 |     ad_concat,
80 |     color='cell_type',
81 |     palette=celltype_colours,  # Color cells based on 'cell_type'
82 |     # color_map=cell_type_color_map, # Use the custom color map
83 |     # library_id='r1_mapping', # Use 'r1_mapping' coordinates
84 |     title='all to r1 map',
85 |     save='_ad_r1_all.pdf',
86 |     spot_size=.1
87 | )
88 | 
89 | sc.pl.spatial(
90 |     ad_concat,
91 |     color='conf',
92 |     # palette=celltype_colours,# Color cells based on 'cell_type'
93 |     # color_map=cell_type_color_map, # Use the custom color map
94 |     # library_id='r1_mapping', # Use 'r1_mapping' coordinates
95 |     title='all to r1 map (log confidence)',
96 |     save='_ad_r1_all_conf.pdf',
97 |     spot_size=.1
98 | )
--------------------------------------------------------------------------------
/sagenet/MGA_data/_MGA_data.py:
--------------------------------------------------------------------------------
 1 | 
 2 | from copy import copy
 3 | from squidpy.datasets._utils import AMetadata
 4 | 
 5 | _scRNAseq = AMetadata(
 6 |     name="scRNAseq",
 7 |     doc_header="",
 8 |     # shape=(270876, 43),
 9 |     url="https://figshare.com/ndownloader/files/31767704",
10 | )
11 | 
12 | _seqFISH1_1 = AMetadata(
13 |     name="seqFISH1_1",
14 |     doc_header="",
15 |     # shape=(270876, 43),
16 |     url="https://figshare.com/ndownloader/files/31716029",
17 | )
18 | 
19 | _seqFISH2_1 = AMetadata(
20 |     name="seqFISH2_1",
21 |     doc_header="",
22 |     # shape=(270876, 43),
23 |     url="https://figshare.com/ndownloader/files/31716041",
24 | )
25 | 
26 | 
27 | _seqFISH3_1 = AMetadata(
28 |     name="seqFISH3_1",
29 |     doc_header="",
30 |     # shape=(270876, 43),
31 |     url="https://figshare.com/ndownloader/files/31716089",
32 | )
33 | 
34 | 
35 | _seqFISH1_2 = AMetadata(
36 |     name="seqFISH1_2",
37 |     doc_header="",
38 |     # shape=(270876, 43),
39 |     url="https://figshare.com/ndownloader/files/31920353",
40 | )
41 | 
42 | _seqFISH2_2 = AMetadata(
43 |     name="seqFISH2_2",
44 |     doc_header="",
45 |     # shape=(270876, 43),
46 |     url="https://figshare.com/ndownloader/files/31920644",
47 | )
48 | 
49 | 
50 | _seqFISH3_2 = AMetadata(
51 |     name="seqFISH3_2",
52 |     doc_header="",
53 |     # shape=(270876, 43),
54 |     url="https://figshare.com/ndownloader/files/31920890",
55 | )
56 | 
57 | 
58 | for name, var in copy(locals()).items():
59 |     if isinstance(var, AMetadata):
60 |         var._create_function(name, globals())
61 | 
62 | 
63 | __all__ = [  # noqa: F822
64 |     "scRNAseq",
65 |     "seqFISH1_1",
66 |     "seqFISH2_1",
67 |     "seqFISH3_1",
68 |     "seqFISH1_2",
69 |     "seqFISH2_2",
70 |     "seqFISH3_2",
71 | 
72 | ]
73 | 
--------------------------------------------------------------------------------
/sagenet/MGA_data/__init__.py:
--------------------------------------------------------------------------------
 1 | from sagenet.MGA_data._MGA_data import *
 2 | 
--------------------------------------------------------------------------------
/sagenet/__init__.py:
--------------------------------------------------------------------------------
 1 | __author__ = __maintainer__ = "Elyas Heidari"
 2 | __email__ = 'eheidari@student.ethz.ch'
 3 | __version__ = "0.1.2"
 4 | 
 5 | from . import classifier, model, sage, utils, MGA_data, DHH_data
 6 | 
 7 | 
--------------------------------------------------------------------------------
/sagenet/classifier.py:
--------------------------------------------------------------------------------
 1 | import torch.optim as optim
 2 | from torch.optim import lr_scheduler
 3 | import torch.nn as nn
 4 | from torch.utils.data import TensorDataset, DataLoader, random_split
 5 | from torch.utils.tensorboard import SummaryWriter
 6 | from sagenet.utils import compute_metrics
 7 | from sagenet.model import *
 8 | from captum import *
 9 | from captum.attr import IntegratedGradients
10 | import numpy as np
11 | from sklearn.preprocessing import normalize
12 | 
13 | class Classifier():
14 | 
15 |     """
16 |     A neural network classifier. A number of Graph Neural Networks (GNNs) and an MLP are implemented.
17 | 
18 |     Parameters
19 |     ----------
20 |     n_features : int
21 |         number of input features.
22 |     n_classes : int
23 |         number of classes.
24 |     n_hidden_GNN : list, default=[]
25 |         list of integers indicating sizes of GNN hidden layers.
26 |     n_hidden_FC : list, default=[]
27 |         list of integers indicating sizes of FC hidden layers. If a GNN is used, this indicates FC hidden layers after the GNN layers.
28 |     K : integer, default=4
29 |         convolution filter size; used when `classifier` is `'Chebnet'` (Chebyshev filter size) or `'Conv1d'` (kernel size; `pool_K`, default=4, sets the max-pooling size).
30 |     dropout_GNN : float, default=0
31 |         dropout rate for GNN hidden layers.
32 |     dropout_FC : float, default=0
33 |         dropout rate for FC hidden layers.
34 |     classifier : str, default='MLP'
35 |         - 'MLP' --> multilayer perceptron
36 |         - 'GraphSAGE'--> GraphSAGE Network
37 |         - 'Chebnet'--> Chebyshev spectral Graph Convolutional Network
38 |         - 'GATConv'--> Graph Attentional Neural Network
39 |         - 'GENConv'--> GENeralized Graph Convolution Network
40 |         - 'GINConv'--> Graph Isomorphism Network
41 |         - 'GraphConv'--> Graph Convolutional Neural Network
42 |         - 'MFConv'--> Convolutional Networks on Graphs for Learning Molecular Fingerprints
43 |         - 'TransformerConv'--> Graph Transformer Neural Network
44 |     lr : float, default=0.001
45 |         base learning rate for the SGD optimization algorithm.
46 |     momentum : float, default=0.9
47 |         base momentum for the SGD optimization algorithm.
48 |     log_dir : str, default=None
49 |         path to the log directory. Specifically, used for tensorboard logs.
50 |     device : str, default='cpu'
51 |         the processing unit ('cpu' or a CUDA device such as 'cuda:0').
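    Examples
    --------
    A minimal sketch (sizes are hypothetical; `train_loader` is assumed to be a
    torch-geometric dataloader, e.g. built with `sagenet.utils.get_dataloader`)::

        clf = Classifier(n_features=350, n_classes=12, n_hidden_GNN=[8],
                         dropout_GNN=0.3, classifier='GraphSAGE', device='cpu')
        clf.fit(train_loader, epochs=20, verbose=True)
        accuracy, conf_mat, precision, recall, f1 = clf.eval(train_loader)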
52 | 53 | 54 | 55 | See also 56 | -------- 57 | Classifier.fit : fits the classifier to data 58 | Classifier.eval : evaluates the classifier predictions 59 | """ 60 | def __init__(self, 61 | n_features, 62 | n_classes, 63 | n_hidden_GNN=[], 64 | n_hidden_FC=[], 65 | K=4, 66 | pool_K=4, 67 | dropout_GNN=0, 68 | dropout_FC=0, 69 | classifier='MLP', 70 | lr=.001, 71 | momentum=.9, 72 | log_dir=None, 73 | device='cpu'): 74 | if classifier == 'MLP': 75 | self.net = NN(n_features=n_features, n_classes=n_classes,\ 76 | n_hidden_FC=n_hidden_FC, dropout_FC=dropout_FC) 77 | if classifier == 'GraphSAGE': 78 | self.net = GraphSAGE(n_features=n_features, n_classes=n_classes,\ 79 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 80 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 81 | if classifier == 'Chebnet': 82 | self.net = ChebNet(n_features=n_features, n_classes=n_classes,\ 83 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 84 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN, K=K) 85 | if classifier == 'GATConv': 86 | self.net = GATConvNet(n_features=n_features, n_classes=n_classes,\ 87 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 88 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 89 | if classifier == 'GENConv': 90 | self.net = GENConvNet(n_features=n_features, n_classes=n_classes,\ 91 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 92 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 93 | if classifier =="GINConv": 94 | self.net = GINConv(n_features=n_features, n_classes=n_classes,\ 95 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 96 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 97 | if classifier =="GraphConv": 98 | self.net = GraphConv(n_features=n_features, n_classes=n_classes,\ 99 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 100 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 101 | if classifier =="MFConv": 102 | self.net = MFConv(n_features=n_features, n_classes=n_classes,\ 103 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 104 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 105 | if classifier =="TransformerConv": 106 | self.net = TransformerConv(n_features=n_features, n_classes=n_classes,\ 107 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 108 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN) 109 | if classifier =="Conv1d": 110 | self.net = ConvNet(n_features=n_features, n_classes=n_classes,\ 111 | n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC, \ 112 | dropout_FC=dropout_FC, dropout_GNN=dropout_GNN, filter_K=K, pool_K=pool_K) 113 | self.criterion = nn.CrossEntropyLoss() 114 | self.optimizer = optim.SGD(self.net.parameters(), lr=lr, momentum=momentum) 115 | self.logging = log_dir is not None 116 | self.device = device 117 | self.lr = lr 118 | if self.logging: 119 | self.writer = SummaryWriter(log_dir=log_dir,flush_secs=1) 120 | 121 | def fit(self,data_loader,epochs,test_dataloader=None,verbose=False): 122 | """ 123 | fits the classifier to the input data. 124 | 125 | Parameters 126 | ---------- 127 | data_loader : torch-geometric dataloader 128 | the training dataset. 129 | epochs : int 130 | number of epochs. 131 | test_dataloader : torch-geometric dataloader, default=None 132 | the test dataset on which the model is evaluated in each epoch. 133 | verbose : boolean, default=False 134 | whether to print out loss during training. 
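        Examples
        --------
        A hedged sketch (assumes `train_loader` and `test_loader` were built with
        `sagenet.utils.get_dataloader`)::

            clf.fit(train_loader, epochs=20, test_dataloader=test_loader, verbose=True)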
135 |         """
136 |         if self.logging:
137 |             data = next(iter(data_loader))
138 |             self.writer.add_graph(self.net,[data.x,data.edge_index])
139 |         # self.scheduler = lr_scheduler.CyclicLR(self.optimizer, base_lr=self.lr, max_lr=0.01,step_size_up=5,mode="triangular2")
140 | 
141 |         self.scheduler = lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=50, T_mult=1, eta_min=0.00005, last_epoch=-1)
142 |         for epoch in range(epochs):
143 |             self.net.train()
144 |             self.net.to(self.device)
145 |             total_loss = 0
146 | 
147 |             for batch in data_loader:
148 |                 x, edge_index, label = batch.x.to(self.device), batch.edge_index.to(self.device), batch.y.to(self.device)
149 |                 self.optimizer.zero_grad()
150 |                 pred = self.net(x, edge_index)
151 |                 loss = self.criterion(pred,label)
152 |                 loss.backward()
153 |                 self.optimizer.step()
154 |                 self.scheduler.step()
155 |                 total_loss += loss.item() * batch.num_graphs
156 |             total_loss /= len(data_loader.dataset)
157 |             if verbose and epoch % max(1, epochs // 10) == 0:  # guard against ZeroDivisionError when epochs < 10
158 |                 print('[%d] loss: %.3f' % (epoch + 1,total_loss))
159 | 
160 |             if self.logging:
161 |                 # Save the training loss, the training accuracy and the test accuracy for tensorboard visualization
162 |                 self.writer.add_scalar("Training Loss",total_loss,epoch)
163 |                 accuracy_train = self.eval(data_loader,verbose=False)[0]
164 |                 self.writer.add_scalar("Accuracy on Training Dataset",accuracy_train,epoch)
165 |                 if test_dataloader is not None:
166 |                     accuracy_test = self.eval(test_dataloader,verbose=False)[0]
167 |                     self.writer.add_scalar("Accuracy on Test Dataset",accuracy_test,epoch)
168 | 
169 | 
170 | 
171 | 
172 | 
173 | 
174 |     def eval(self,data_loader,verbose=False):
175 |         """
176 |         evaluates the model's predictions on a dataset.
177 | 
178 |         Parameters
179 |         ----------
180 |         data_loader : torch-geometric dataloader
181 |             the dataset on which the model is evaluated.
182 |         verbose : boolean, default=False
183 |             whether to print out the computed metrics.
184 |         Returns
185 |         ----------
186 |         accuracy : float
187 |             accuracy
188 |         conf_mat : ndarray
189 |             confusion matrix
190 |         precision : float
191 |             weighted precision score
192 |         recall : float
193 |             weighted recall score
194 |         f1_score : float
195 |             weighted f1 score
196 |         """
197 |         self.net.eval()
198 |         correct = 0
199 |         total = 0
200 |         y_true = []
201 |         y_pred = []
202 |         with torch.no_grad():
203 |             for batch in data_loader:
204 |                 x, edge_index, label = batch.x.to(self.device), batch.edge_index.to(self.device), batch.y.to('cpu')
205 |                 y_true.extend(list(label))
206 |                 outputs = self.net(x, edge_index)
207 |                 _, predicted = torch.max(outputs.data, 1)
208 |                 predicted = predicted.to('cpu')
209 |                 y_pred.extend(list(predicted))
210 |         accuracy, conf_mat, precision, recall, f1_score = compute_metrics(y_true, y_pred)
211 |         if verbose:
212 |             print('Accuracy: {:.3f}'.format(accuracy))
213 |             print('Confusion Matrix:\n', conf_mat)
214 |             print('Precision: {:.3f}'.format(precision))
215 |             print('Recall: {:.3f}'.format(recall))
216 |             print('f1_score: {:.3f}'.format(f1_score))
217 |         return accuracy, conf_mat, precision, recall, f1_score
218 | 
219 |     def interpret(self, data_loader, n_features, n_classes):
220 |         """
221 |         interprets a trained model by assigning an importance score to each feature for each class.
222 |         It uses the `IntegratedGradients` method from the `captum` package to compute class-wise
223 |         feature importances; a global importance measure per feature can then be derived from these.
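        As an illustration (a sketch mirroring the commented-out lines at the end of
        this method, not a separate API), one global score per feature is the entropy
        of its normalized class-wise importances::

            imp = np.exp(importances)
            imp = (imp.T / imp.sum(axis=1)).T  # normalize rows to probabilities
            ent = -(imp * np.log2(imp)).sum(axis=1) / np.log2(n_classes)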
224 | 
225 |         Parameters
226 |         ----------
227 |         data_loader : torch-geometric dataloader
228 |             the dataset on which the model is evaluated.
229 |         n_features : int
230 |             number of features.
231 |         n_classes : int
232 |             number of classes.
233 | 
234 |         Returns
235 |         ----------
236 |         importances : numpy ndarray, shape (n_features, n_classes)
237 |             class-wise feature importance scores.
238 |         """
239 |         batch = next(iter(data_loader))
240 |         e = batch.edge_index.to(self.device).long()
241 |         def model_forward(input):
242 |             out = self.net(input, e)
243 |             return out
244 |         self.net.eval()
245 |         importances = np.zeros((n_features, n_classes))
246 |         for batch in data_loader:
247 |             input = batch.x.to(self.device)
248 |             target = batch.y.to(self.device)
249 |             ig = IntegratedGradients(model_forward)
250 |             attributions = ig.attribute(input, target=target)
251 |             attributions = attributions.to('cpu').detach().numpy()
252 |             attributions = attributions.reshape(n_features, len(target))
253 |             attributions = normalize(attributions, axis=0, norm='l2')
254 |             # attributions /= np.linalg.norm(attributions)
255 |             importances[:, target.to('cpu').numpy()] += attributions
256 |         # importances = np.e**importances
257 |         # importances = importances / importances.max(axis=0)
258 |         # imp = (importances.T / np.sum(importances, axis = 1)).T
259 |         # ent = (-imp * np.log2(imp)).sum(axis = 1) / np.log2(n_classes)
260 |         # idx = (-importances).argsort(axis=0)
261 |         # ent = np.min(idx, axis=1)
262 |         return importances
263 | 
--------------------------------------------------------------------------------
/sagenet/datasets/__init__.py:
--------------------------------------------------------------------------------
 1 | from sagenet.datasets._datasets import *
--------------------------------------------------------------------------------
/sagenet/datasets/_datasets.py:
--------------------------------------------------------------------------------
 1 | from copy import copy
 2 | from squidpy.datasets._utils import AMetadata
 3 | 
 4 | _MGA_scRNAseq = AMetadata(
 5 |     name="MGA_scRNAseq",
 6 |     doc_header="",
 7 |     # shape=(270876, 43),
 8 |     url="https://figshare.com/ndownloader/files/31767704",
 9 | )
10 | 
11 | _MGA_seqFISH1 = AMetadata(
12 |     name="MGA_seqFISH1",
13 |     doc_header="",
14 |     # shape=(270876, 43),
15 |     url="https://figshare.com/ndownloader/files/31716029",
16 | )
17 | 
18 | _MGA_seqFISH2 = AMetadata(
19 |     name="MGA_seqFISH2",
20 |     doc_header="",
21 |     # shape=(270876, 43),
22 |     url="https://figshare.com/ndownloader/files/31716041",
23 | )
24 | 
25 | 
26 | _MGA_seqFISH3 = AMetadata(
27 |     name="MGA_seqFISH3",
28 |     doc_header="",
29 |     # shape=(270876, 43),
30 |     url="https://figshare.com/ndownloader/files/31716089",
31 | )
32 | 
33 | _DHH_visium_ = AMetadata(
34 |     name="DHH_visium_",
35 |     doc_header="",
36 |     # shape=(270876, 43),
37 |     url="https://figshare.com/ndownloader/files/31796207",
38 | )
39 | 
40 | _DHH_scRNAseq = AMetadata(
41 |     name="DHH_scRNAseq",
42 |     doc_header="",
43 |     # shape=(270876, 43),
44 |     url="https://figshare.com/ndownloader/files/31796219",
45 | )
46 | 
47 | 
48 | 
49 | 
50 | for name, var in copy(locals()).items():
51 |     if isinstance(var, AMetadata):
52 |         var._create_function(name, globals())
53 | 
54 | 
55 | __all__ = [  # noqa: F822
56 |     "MGA_scRNAseq",
57 |     "MGA_seqFISH1",
58 |     "MGA_seqFISH2",
59 |     "MGA_seqFISH3",
60 |     'DHH_visium_',
61 |     'DHH_scRNAseq'
62 | 
63 | ]
64 | 
--------------------------------------------------------------------------------
/sagenet/model.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import torch_geometric.nn as pyg_nn
 5 | 
 6 | 
 7 | 
 8 | class NN(nn.Module):
 9 |     def __init__(self, \
10 |                  n_features, \
11 |                  n_classes, \
12 |                  n_hidden_GNN=[], \
13 |                  n_hidden_FC=[10], \
14 |                  dropout_FC=0, \
15 |                  dropout_GNN=0):  # NOTE: dropout_FC precedes dropout_GNN here; subclasses pass these positionally in this order
16 |         super(NN, self).__init__()
17 |         self.FC = True
18 |         self.n_features = n_features
19 |         self.n_classes = n_classes
20 |         self.layers_GNN = nn.ModuleList()
21 |         self.layers_FC = nn.ModuleList()
22 |         self.n_layers_GNN = len(n_hidden_GNN)
23 |         self.n_layers_FC = len(n_hidden_FC)
24 |         self.dropout_GNN = dropout_GNN
25 |         self.dropout_FC = dropout_FC
26 |         self.n_hidden_GNN = n_hidden_GNN
27 |         self.n_hidden_FC = n_hidden_FC
28 |         self.conv = False
29 |         if self.n_layers_GNN > 0:
30 |             self.FC = False
31 | 
32 |         # Fully connected layers. They occur after the graph convolutions (or at the start if there are no graph convolutions)
33 |         if self.n_layers_FC > 0:
34 |             if self.n_layers_GNN==0:
35 |                 self.layers_FC.append(nn.Linear(n_features, n_hidden_FC[0]))
36 |             else:
37 |                 self.layers_FC.append(nn.Linear(n_features*n_hidden_GNN[-1], n_hidden_FC[0]))
38 |             if self.n_layers_FC > 1:
39 |                 for i in range(self.n_layers_FC-1):
40 |                     self.layers_FC.append(nn.Linear(n_hidden_FC[i], n_hidden_FC[(i+1)]))
41 | 
42 |         # Last layer
43 |         if self.n_layers_FC>0:
44 |             self.last_layer_FC = nn.Linear(n_hidden_FC[-1], n_classes)
45 |         elif self.n_layers_GNN>0:
46 |             self.last_layer_FC = nn.Linear(n_features*n_hidden_GNN[-1], n_classes)
47 |         else:
48 |             self.last_layer_FC = nn.Linear(n_features, n_classes)
49 | 
50 |     def forward(self,x,edge_index):
51 |         if self.FC:
52 |             # Resize from (1,batch_size * n_features) to (batch_size, n_features)
53 |             x = x.view(-1,self.n_features)
54 |         if self.conv:
55 |             x = x.view(-1,1,self.n_features)
56 |             for layer in self.layers_GNN:
57 |                 x = F.relu(layer(x))
58 |                 x = F.max_pool1d(x, kernel_size=self.pool_K, stride=1, padding=self.pool_K//2, dilation=1)
59 |                 x = F.dropout(x, p=self.dropout_GNN, training=self.training)
60 |                 # x = F.max_pool1d(x)
61 |         else:
62 |             for layer in self.layers_GNN:
63 |                 x = F.relu(layer(x, edge_index))
64 |                 x = F.dropout(x, p=self.dropout_GNN, training=self.training)
65 |         if self.n_layers_GNN > 0:
66 |             x = x.view(-1, self.n_features*self.n_hidden_GNN[-1])
67 |         for layer in self.layers_FC:
68 |             x = F.relu(layer(x))
69 |             x = F.dropout(x, p=self.dropout_FC, training=self.training)
70 |         x = self.last_layer_FC(x)
71 |         return x
72 | 
73 | 
74 | class GraphSAGE(NN):
75 |     def __init__(self, \
76 |                  n_features, \
77 |                  n_classes, \
78 |                  n_hidden_GNN=[10], \
79 |                  n_hidden_FC=[], \
80 |                  dropout_GNN=0, \
81 |                  dropout_FC=0):
82 |         super(GraphSAGE, self).__init__(\
83 |             n_features, n_classes, n_hidden_GNN,\
84 |             n_hidden_FC, dropout_FC, dropout_GNN)
85 | 
86 |         self.layers_GNN.append(pyg_nn.SAGEConv(1, n_hidden_GNN[0]))
87 |         if self.n_layers_GNN > 1:
88 |             for i in range(self.n_layers_GNN-1):
89 |                 self.layers_GNN.append(pyg_nn.SAGEConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
90 | 
91 | 
92 | class ChebNet(NN):
93 |     def __init__(self,
94 |                  n_features,
95 |                  n_classes,
96 |                  n_hidden_GNN=[10],
97 |                  n_hidden_FC=[],
98 |                  K=4,
99 |                  dropout_GNN=0,
100 |                  dropout_FC=0):
101 |         super(ChebNet, self).__init__(\
102 |             n_features, n_classes, n_hidden_GNN,\
103 |             n_hidden_FC, dropout_FC, dropout_GNN)
104 | 
105 |         self.layers_GNN.append(pyg_nn.ChebConv(1, n_hidden_GNN[0], K))
106 |         if self.n_layers_GNN > 1:
107 |             for i in range(self.n_layers_GNN-1):
108 |                 self.layers_GNN.append(pyg_nn.ChebConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)], K))
109 | 
110 | 
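# Editorial sketch (not part of the library): each subclass below wires a
# different torch-geometric operator into the same NN backbone. Every cell is
# one graph whose nodes are genes carrying a single scalar feature, hence the
# `1` passed as `in_channels` to the first layer. A hypothetical subclass for
# a new operator, here torch_geometric's SGConv, would follow the same template:
#
#   class SGConvNet(NN):
#       def __init__(self, n_features, n_classes, n_hidden_GNN=[10],
#                    n_hidden_FC=[], dropout_GNN=0, dropout_FC=0):
#           super(SGConvNet, self).__init__(n_features, n_classes, n_hidden_GNN,
#                                           n_hidden_FC, dropout_FC, dropout_GNN)
#           self.layers_GNN.append(pyg_nn.SGConv(1, n_hidden_GNN[0]))
#           for i in range(self.n_layers_GNN - 1):
#               self.layers_GNN.append(pyg_nn.SGConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))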
111 | class NNConvNet(NN):
112 |     def __init__(self, \
113 |                  n_features, \
114 |                  n_classes, \
115 |                  n_hidden_GNN=[10], \
116 |                  n_hidden_FC=[], \
117 |                  dropout_GNN=0, \
118 |                  dropout_FC=0):
119 |         super(NNConvNet, self).__init__(\
120 |             n_features, n_classes, n_hidden_GNN,\
121 |             n_hidden_FC, dropout_FC, dropout_GNN)
122 | 
123 |         self.layers_GNN.append(pyg_nn.NNConv(1, n_hidden_GNN[0]))  # NOTE: pyg_nn.NNConv also expects an `nn` module mapping edge features to filter weights; this class is unused by `Classifier` and would need that argument to run
124 |         if self.n_layers_GNN > 1:
125 |             for i in range(self.n_layers_GNN-1):
126 |                 self.layers_GNN.append(pyg_nn.NNConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
127 | 
128 | class GATConvNet(NN):
129 |     def __init__(self, \
130 |                  n_features, \
131 |                  n_classes, \
132 |                  n_hidden_GNN=[10], \
133 |                  n_hidden_FC=[], \
134 |                  dropout_GNN=0, \
135 |                  dropout_FC=0):
136 |         super(GATConvNet, self).__init__(\
137 |             n_features, n_classes, n_hidden_GNN,\
138 |             n_hidden_FC, dropout_FC, dropout_GNN)
139 | 
140 |         self.layers_GNN.append(pyg_nn.GATConv(1, n_hidden_GNN[0]))
141 |         if self.n_layers_GNN > 1:
142 |             for i in range(self.n_layers_GNN-1):
143 |                 self.layers_GNN.append(pyg_nn.GATConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
144 | 
145 | class GENConvNet(NN):
146 |     def __init__(self, \
147 |                  n_features, \
148 |                  n_classes, \
149 |                  n_hidden_GNN=[10], \
150 |                  n_hidden_FC=[], \
151 |                  dropout_GNN=0, \
152 |                  dropout_FC=0):
153 |         super(GENConvNet, self).__init__(\
154 |             n_features, n_classes, n_hidden_GNN,\
155 |             n_hidden_FC, dropout_FC, dropout_GNN)
156 | 
157 |         self.layers_GNN.append(pyg_nn.GENConv(1, n_hidden_GNN[0]))
158 |         if self.n_layers_GNN > 1:
159 |             for i in range(self.n_layers_GNN-1):
160 |                 self.layers_GNN.append(pyg_nn.GENConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
161 | 
162 | 
163 | class GINConv(NN):
164 |     def __init__(self, \
165 |                  n_features, \
166 |                  n_classes, \
167 |                  n_hidden_GNN=[10], \
168 |                  n_hidden_FC=[], \
169 |                  dropout_GNN=0, \
170 |                  dropout_FC=0):
171 |         super(GINConv, self).__init__(\
172 |             n_features, n_classes, n_hidden_GNN,\
173 |             n_hidden_FC, dropout_FC, dropout_GNN)
174 | 
175 |         self.layers_GNN.append(pyg_nn.GINConv(nn.Sequential(nn.Linear(1, n_hidden_GNN[0]),
176 |             nn.ReLU(), nn.Linear(n_hidden_GNN[0],n_hidden_GNN[0])),eps=0.2))
177 |         if self.n_layers_GNN > 1:
178 |             for i in range(self.n_layers_GNN-1):
179 |                 self.layers_GNN.append(pyg_nn.GINConv(nn.Sequential(nn.Linear(n_hidden_GNN[i], n_hidden_GNN[(i+1)]),
180 |                     nn.ReLU(), nn.Linear(n_hidden_GNN[(i+1)],n_hidden_GNN[(i+1)])))) 
181 | 
182 | class GraphConv(NN):
183 |     def __init__(self, \
184 |                  n_features, \
185 |                  n_classes, \
186 |                  n_hidden_GNN=[10], \
187 |                  n_hidden_FC=[], \
188 |                  dropout_GNN=0, \
189 |                  dropout_FC=0):
190 |         super(GraphConv, self).__init__(\
191 |             n_features, n_classes, n_hidden_GNN,\
192 |             n_hidden_FC, dropout_FC, dropout_GNN)
193 | 
194 |         self.layers_GNN.append(pyg_nn.GraphConv(1, n_hidden_GNN[0]))
195 |         if self.n_layers_GNN > 1:
196 |             for i in range(self.n_layers_GNN-1):
197 |                 self.layers_GNN.append(pyg_nn.GraphConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
198 | 
199 | class MFConv(NN):
200 |     def __init__(self, \
201 |                  n_features, \
202 |                  n_classes, \
203 |                  n_hidden_GNN=[10], \
204 |                  n_hidden_FC=[], \
205 |                  dropout_GNN=0, \
206 |                  dropout_FC=0):
207 |         super(MFConv, self).__init__(\
208 |             n_features, n_classes, n_hidden_GNN,\
209 |             n_hidden_FC, dropout_FC, dropout_GNN)
210 | 
211 |         self.layers_GNN.append(pyg_nn.MFConv(1, n_hidden_GNN[0]))
212 |         if self.n_layers_GNN > 1:
213 |             for i in range(self.n_layers_GNN-1):
214 |                 self.layers_GNN.append(pyg_nn.MFConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
215 | 
216 | class TransformerConv(NN):
217 |     def __init__(self, \
218 |                  n_features, \
219 |                  n_classes, \
220 |                  n_hidden_GNN=[10], \
221 |                  n_hidden_FC=[], \
222 |                  dropout_GNN=0, \
223 |                  dropout_FC=0):
224 |         super(TransformerConv, self).__init__(\
225 |             n_features, n_classes, n_hidden_GNN,\
226 |             n_hidden_FC, dropout_FC, dropout_GNN)
227 | 
228 |         self.layers_GNN.append(pyg_nn.TransformerConv(1, n_hidden_GNN[0]))
229 |         if self.n_layers_GNN > 1:
230 |             for i in range(self.n_layers_GNN-1):
231 |                 self.layers_GNN.append(pyg_nn.TransformerConv(n_hidden_GNN[i], n_hidden_GNN[(i+1)]))
232 | 
233 | class ConvNet(NN):
234 |     def __init__(self,
235 |                  n_features,
236 |                  n_classes,
237 |                  n_hidden_GNN=[10],
238 |                  n_hidden_FC=[],
239 |                  filter_K=4,
240 |                  pool_K=0,
241 |                  dropout_GNN=0,
242 |                  dropout_FC=0):
243 |         super(ConvNet, self).__init__(\
244 |             n_features, n_classes, n_hidden_GNN,\
245 |             n_hidden_FC, dropout_FC, dropout_GNN)
246 |         self.conv = True
247 |         self.filter_K = filter_K
248 |         self.pool_K = pool_K
249 | 
250 |         self.layers_GNN.append(nn.Conv1d(in_channels=1, out_channels=n_hidden_GNN[0], kernel_size=filter_K, padding=filter_K//2, dilation=1, stride=1))
251 |         if self.n_layers_GNN > 1:
252 |             for i in range(self.n_layers_GNN-1):
253 |                 self.layers_GNN.append(nn.Conv1d(in_channels=n_hidden_GNN[i], out_channels=n_hidden_GNN[(i+1)], kernel_size=filter_K, padding=filter_K//2, dilation=1, stride=1))
--------------------------------------------------------------------------------
/sagenet/sage.py:
--------------------------------------------------------------------------------
 1 | from sagenet.utils import *
 2 | from sagenet.classifier import *
 3 | from sagenet.model import *
 4 | from os import listdir
 5 | import numpy as np
 6 | import anndata
 7 | import re
 8 | 
 9 | class sage():
10 |     """
11 |     A `sagenet` object.
12 | 
13 |     Parameters
14 |     ----------
15 |     device : str, default = 'cpu'
16 |         the processing unit to be used in the classifiers (gpu or cpu).
17 |     """
18 | 
19 |     def __init__(self, device='cpu'):
20 |         self.models = {}
21 |         self.adjs = {}
22 |         self.inf_genes = None  # reserved for informative-gene bookkeeping
23 |         self.num_refs = 0
24 |         self.device = device
25 | 
26 |     def add_ref(self,
27 |                 adata,
28 |                 tag = None,
29 |                 comm_columns = ['class_'],
30 |                 classifier = 'TransformerConv',
31 |                 num_workers = 0,
32 |                 batch_size = 32,
33 |                 epochs = 10,
34 |                 n_genes = 10,
35 |                 verbose = False):
36 |         """Trains new classifiers on a reference dataset.
37 | 
38 |         Parameters
39 |         ----------
40 |         adata : `AnnData`
41 |             The annotated data matrix of shape `n_obs × n_vars` to be used as the spatial reference. Rows correspond to cells (or spots) and columns to genes.
42 |         tag : str, default = `None`
43 |             The tag to be used for storing the trained models and the outputs in the `sagenet` object.
44 |         classifier : str, default = `'TransformerConv'`
45 |             The type of classifier to be passed to `sagenet.Classifier()`
46 |         comm_columns : list of str, default = `['class_']`
47 |             The columns in `adata.obs` to be used as spatial partitions.
48 |         num_workers : int
49 |             Non-negative. Number of workers to be passed to `torch_geometric.data.DataLoader`.
50 |         epochs : int
51 |             number of epochs.
52 |         verbose : boolean, default=False
53 |             whether to print out loss during training.
54 | 
55 |         Return
56 |         ------
57 |         Returns nothing.
58 | 
59 |         Notes
60 |         -----
61 |         Trains the models and adds them to the `.models` dictionary of the `sagenet` object.
62 |         Also adds a new key `{tag}_importance` to `.var` of `adata`, which contains a rank-based importance score for each gene.
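        Examples
        --------
        A minimal sketch (assumes `adata_r` already carries a gene-gene adjacency in
        `.varm['adj']`, e.g. from `sagenet.utils.glasso`, and a Leiden partitioning
        in `.obs['leiden_0.1']`)::

            sg_obj = sage(device='cpu')
            sg_obj.add_ref(adata_r, comm_columns=['leiden_0.1'], tag='ref1',
                           classifier='GraphSAGE', epochs=20, verbose=True)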
 63 |         """
 64 |         ind = np.where(np.sum(adata.varm['adj'], axis=1) == 0)[0]  # genes with no edges in the gene-gene graph
 65 |         ents = np.ones(adata.var.shape[0]) * 1000000
 66 |         # ents = np.zeros(adata.var.shape[0])
 67 |         self.num_refs += 1
 68 | 
 69 |         if tag is None:
 70 |             tag = 'ref' + str(self.num_refs)
 71 | 
 72 |         for comm in comm_columns:
 73 |             data_loader = get_dataloader(
 74 |                 graph = adata.varm['adj'].toarray(),
 75 |                 X = adata.X, y = adata.obs[comm].values.astype('long'),
 76 |                 batch_size = batch_size,
 77 |                 shuffle = True,
 78 |                 num_workers = num_workers
 79 |             )
 80 | 
 81 |             clf = Classifier(
 82 |                 n_features = adata.shape[1],
 83 |                 n_classes = (np.max(adata.obs[comm].values.astype('long'))+1),
 84 |                 n_hidden_GNN = [8],
 85 |                 dropout_FC = 0.2,
 86 |                 dropout_GNN = 0.3,
 87 |                 classifier = classifier,
 88 |                 lr = 0.001,
 89 |                 momentum = 0.9,
 90 |                 device = self.device
 91 |             )
 92 | 
 93 |             clf.fit(data_loader, epochs = epochs, test_dataloader=None,verbose=verbose)
 94 |             imp = clf.interpret(data_loader, n_features=adata.shape[1], n_classes=(np.max(adata.obs[comm].values.astype('long'))+1))
 95 |             idx = (-abs(imp)).argsort(axis=0)  # rank genes per class by |importance|
 96 |             imp = np.min(idx, axis=1)  # a gene's score is its best (lowest) rank across classes
 97 |             # imp += imp
 98 |             np.put(imp, ind, 1000000)
 99 |             ents = np.minimum(ents, imp)
100 | 
101 |             # imp = np.min(idx, axis=1)
102 |             # ents = np.minimum(ents, imp)
103 |             self.models['_'.join([tag, comm])] = clf.net
104 |             self.adjs['_'.join([tag, comm])] = adata.varm['adj'].toarray()
105 |         save_adata(adata, attr='var', key='_'.join([tag, 'importance']), data=ents)
106 |         # return ents
107 | 
108 | 
109 |     def map_query(self, adata_q, save_pred=True, save_ent=True, save_prob=False, save_dist=False):
110 |         """Maps a query dataset to space using the trained models on the spatial reference(s).
111 | 
112 |         Parameters
113 |         ----------
114 |         adata_q : `AnnData`
115 |             The annotated data matrix of shape `n_obs × n_vars` to be used as the query. Rows correspond to cells (or spots) and columns to genes.
116 | 
117 |         Return
118 |         ------
119 |         Returns nothing.
120 | 
121 |         Notes
122 |         -----
123 |         * Adds new key(s) `pred_{tag}_{partitioning_name}` to `.obs` of `adata_q`, containing the predicted partition for partitioning `{partitioning_name}` under model `{tag}` (when `save_pred=True`).
124 |         * Adds new key(s) `ent_{tag}_{partitioning_name}` to `.obs` of `adata_q`, containing the uncertainty of the prediction as a normalized entropy (when `save_ent=True`); `prob_{tag}_{partitioning_name}` in `.obsm` holds the full probability matrix (when `save_prob=True`).
125 |         * Adds a new key `dist_map` to `.obsm` of `adata_q`, an `n_obs × n_obs` matrix containing the reconstructed cell-to-cell spatial distances (when `save_dist=True`).
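        Examples
        --------
        Continuing the `add_ref` sketch above (names are illustrative)::

            sg_obj.map_query(adata_q, save_prob=True)
            # adata_q.obs now holds 'pred_ref1_leiden_0.1' and 'ent_ref1_leiden_0.1';
            # adata_q.obsm holds 'prob_ref1_leiden_0.1'.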
126 |         """
127 |         for tag in self.models.keys():
128 |             self.models[tag].eval()
129 |             i = 0
130 |             adata_q.obs['class_'] = 0
131 |             data_loader = get_dataloader(
132 |                 graph = self.adjs[tag],
133 |                 X = adata_q.X, y = adata_q.obs['class_'].values.astype('long'), #TODO: fix this
134 |                 batch_size = 1,
135 |                 shuffle = False,
136 |                 num_workers = 0
137 |             )
138 |             with torch.no_grad():
139 |                 for batch in data_loader:
140 |                     x, edge_index = batch.x.to(self.device), batch.edge_index.to(self.device)
141 |                     outputs = self.models[tag](x, edge_index)
142 |                     predicted = outputs.data.to('cpu').detach().numpy()
143 |                     i += 1
144 |                     if i == 1:
145 |                         n_classes = predicted.shape[1]
146 |                         y_pred = np.empty((0, n_classes))
147 |                     y_pred = np.concatenate((y_pred, predicted), axis=0)
148 |             if save_pred or save_prob or save_ent or save_dist:
149 |                 y_pred = np.exp(y_pred)
150 |                 y_pred = (y_pred.T / y_pred.T.sum(0)).T  # softmax over the logits
151 |             if save_prob:
152 |                 save_adata(adata_q, attr='obsm', key='_'.join(['prob', tag]), data = y_pred)
153 |             if save_pred:
154 |                 save_adata(adata_q, attr='obs', key='_'.join(['pred', tag]), data = np.argmax(y_pred, axis=1))
155 |             if save_ent:
156 |                 temp = (-y_pred * np.log2(y_pred)).sum(axis = 1)
157 |                 # adata_q.obs['_'.join(['ent', tag])] = np.array(temp) / np.log2(n_classes)
158 |                 save_adata(adata_q, attr='obs', key='_'.join(['ent', tag]), data = (np.array(temp) / np.log2(n_classes)))
159 |             if save_dist:
160 |                 dist_mat = np.zeros((adata_q.shape[0], adata_q.shape[0]))
161 |                 y_pred_1 = (multinomial_rvs(1, y_pred).T * np.array(adata_q.obs['_'.join(['ent', tag])])).T
162 |                 y_pred_2 = (y_pred.T * (1-np.array(adata_q.obs['_'.join(['ent', tag])]))).T
163 |                 y_pred_final = y_pred_1 + y_pred_2
164 |                 kl_d = kullback_leibler_divergence(y_pred_final)
165 |                 kl_d = kl_d + kl_d.T
166 |                 kl_d /= np.linalg.norm(kl_d, 'fro')
167 |                 dist_mat += kl_d
168 |                 save_adata(adata_q, attr='obsm', key='dist_map', data=dist_mat)
169 | 
170 |     def save_model(self, tag, dir='.'):
171 |         """Saves a single trained model.
172 | 
173 |         Parameters
174 |         ----------
175 |         tag : str
176 |             Name of the trained model to be saved.
177 |         dir : str, default=`'.'`
178 |             The saving directory.
179 |         """
180 |         path = os.path.join(dir, tag) + '.pickle'
181 |         torch.save(self.models[tag], path)
182 | 
183 |     def load_model(self, tag, dir='.'):
184 |         """Loads a single pre-trained model.
185 | 
186 |         Parameters
187 |         ----------
188 |         tag : str
189 |             Name of the trained model to be stored in the `sagenet` object.
190 |         dir : str, default=`'.'`
191 |             The input directory.
192 |         """
193 |         path = os.path.join(dir, tag) + '.pickle'
194 |         self.models[tag] = torch.load(path)
195 | 
196 |     def save_model_as_folder(self, dir='.'):
197 |         """Saves all trained models stored in the `sagenet` object to a folder.
198 | 
199 |         Parameters
200 |         ----------
201 |         dir : str, default=`'.'`
202 |             The saving directory.
203 |         """
204 |         for tag in self.models.keys():
205 |             self.save_model(tag, dir)
206 |             adj_path = os.path.join(dir, tag) + '.h5ad'
207 |             adj_adata = anndata.AnnData(X = self.adjs[tag])
208 |             adj_adata.write(filename=adj_path)
209 | 
210 |     def load_model_as_folder(self, dir='.'):
211 |         """Loads pre-trained models from a directory.
212 | 
213 |         Parameters
214 |         ----------
215 |         dir : str, default=`'.'`
216 |             The input directory.
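        Examples
        --------
        Round-trip sketch (the `models` directory is hypothetical)::

            sg_obj.save_model_as_folder('models')
            sg_new = sage(device='cpu')
            sg_new.load_model_as_folder('models')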
217 |         """
218 |         model_files = [f for f in listdir(dir) if re.search(r"\.pickle$", f)]
219 |         for m in model_files:
220 |             tag = re.sub(r'\.pickle$', '', m)
221 |             model_path = os.path.join(dir, tag) + '.pickle'
222 |             adj_path = os.path.join(dir, tag) + '.h5ad'
223 |             self.models[tag] = torch.load(model_path)
224 |             self.adjs[tag] = anndata.read_h5ad(adj_path).X
225 | 
226 | 
227 | 
--------------------------------------------------------------------------------
/sagenet/utils.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import csv
 3 | import numpy as np
 4 | from sklearn.metrics import *
 5 | from gglasso.helper.data_generation import generate_precision_matrix, group_power_network, sample_covariance_matrix
 6 | from gglasso.problem import glasso_problem
 7 | from gglasso.helper.basic_linalg import adjacency_matrix
 8 | import torch
 9 | import torch_geometric.data as geo_dt
10 | from sklearn.utils.extmath import fast_logdet
11 | from scipy import sparse
12 | import warnings
13 | from torch.nn import Softmax
14 | from sklearn.preprocessing import StandardScaler
15 | from sklearn.covariance import empirical_covariance
16 | 
17 | 
18 | from functools import reduce
19 | 
20 | 
21 | from scipy.spatial import cKDTree
22 | 
23 | def glasso(adata, lambda_low=-10, lambda_high=-1, mode='cd'):
24 |     """
25 |     Reconstructs the gene-gene interaction network based on gene expressions in `.X` using a Gaussian graphical model estimated by `glasso`.
26 | 
27 |     Parameters
28 |     ----------
29 |     adata: `AnnData`
30 |         The annotated data matrix of shape `n_obs × n_vars`. Rows correspond to cells and columns to genes.
31 |     lambda_low: float, default=`-10`
32 |         log10 of the lower bound of the regularization-strength grid.
33 |     lambda_high: float, default=`-1`
34 |         log10 of the upper bound of the grid. (`mode` is currently unused.)
35 | 
36 |     Returns
37 |     -------
38 |     adds a `csr_matrix` under key `'adj'` to `.varm`.
39 | 
40 |     References
41 |     -----------
42 |     Friedman, J., Hastie, T., & Tibshirani, R. (2008).
43 |     Sparse inverse covariance estimation with the graphical lasso.
44 |     Biostatistics, 9(3), 432-441.
45 |     """
46 |     N = adata.shape[1]
47 |     scaler = StandardScaler()
48 |     data = scaler.fit_transform(adata.X)
49 |     S = empirical_covariance(data)
50 |     P = glasso_problem(S, N, latent = False, do_scaling = True)
51 |     # lambda1_range = np.logspace(-0.1, -1, 10)
52 |     lambda1_range = np.logspace(lambda_low, lambda_high, 10)
53 |     modelselect_params = {'lambda1_range': lambda1_range}
54 |     P.model_selection(modelselect_params = modelselect_params, method = 'eBIC', gamma = 0.1, tol=1e-7)
55 |     sol = P.solution.precision_
56 |     P.solution.calc_adjacency(t = 1e-4)
57 |     save_adata(adata, attr='varm', key='adj', data=sparse.csr_matrix(P.solution.adjacency_))  # store the 0/1 adjacency computed above; downstream code expects binary entries
58 | 
59 | 
60 | 
61 | def compute_metrics(y_true, y_pred):
62 |     """
63 |     Computes prediction quality metrics.
64 | 
65 |     Parameters
66 |     ----------
67 |     y_true : 1d array-like, or label indicator array / sparse matrix
68 |         Ground truth (correct) labels.
69 | 
70 |     y_pred : 1d array-like, or label indicator array / sparse matrix
71 |         Predicted labels, as returned by a classifier.
 72 | 
 73 |     Returns
 74 |     --------
 75 |     accuracy : accuracy
 76 |     conf_mat : confusion matrix
 77 |     precision : weighted precision score
 78 |     recall : weighted recall score
 79 |     f1 : weighted f1 score
 80 |     """
 81 |     accuracy = accuracy_score(y_true, y_pred)
 82 |     conf_mat = confusion_matrix(y_true, y_pred)
 83 |     precision = precision_score(y_true, y_pred, average='weighted')
 84 |     recall = recall_score(y_true, y_pred, average='weighted')
 85 |     f1 = f1_score(y_true, y_pred, average='weighted')
 86 |     return accuracy, conf_mat, precision, recall, f1
 87 | 
 88 | 
 89 | 
 90 | 
 91 | def get_dataloader(graph, X, y, batch_size=1, undirected=True, shuffle=True, num_workers=0):
 92 |     """
 93 |     Converts a graph and a dataset to a dataloader.
 94 | 
 95 |     Parameters
 96 |     ----------
 97 |     graph : numpy ndarray
 98 |         The adjacency matrix of the underlying graph to be fed to the graph neural networks.
 99 | 
100 |     X : numpy ndarray
101 |         Input dataset with columns as features and rows as observations.
102 | 
103 |     y : numpy ndarray
104 |         Class labels.
105 | 
106 |     batch_size: int, default=1
107 |         The batch size.
108 | 
109 |     undirected: boolean
110 |         whether the input graph is undirected (symmetric adjacency matrix).
111 | 
112 |     shuffle: boolean, default = `True`
113 |         Whether to shuffle the dataset to be passed to `torch_geometric.data.DataLoader`.
114 | 
115 |     num_workers: int, default = 0
116 |         Non-negative. Number of workers to be passed to `torch_geometric.data.DataLoader`.
117 | 
118 | 
119 |     Returns
120 |     --------
121 |     dataloader : a pytorch-geometric dataloader. All of the graphs will have the same connectivity (given by the input graph),
122 |         but the node features will be the features from X.
123 |     """
124 |     n_obs, n_features = X.shape
125 |     rows, cols = np.where(graph == 1)
126 |     edges = zip(rows.tolist(), cols.tolist())
127 |     sources = []
128 |     targets = []
129 |     for edge in edges:
130 |         sources.append(edge[0])
131 |         targets.append(edge[1])
132 |         if undirected:
133 |             sources.append(edge[1])  # add the reverse edge as well
134 |             targets.append(edge[0])
135 |     edge_index = torch.tensor([sources,targets],dtype=torch.long)
136 | 
137 |     list_graphs = []
138 |     y = y.tolist()
139 |     # print(y)
140 |     for i in range(n_obs):
141 |         y_tensor = torch.tensor(y[i])
142 |         X_tensor = torch.tensor(X[i,:]).view(X.shape[1], 1).float()
143 |         data = geo_dt.Data(x=X_tensor, edge_index=edge_index, y=y_tensor)
144 |         list_graphs.append(data.coalesce())
145 | 
146 |     dataloader = geo_dt.DataLoader(list_graphs, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=False)
147 |     return dataloader
148 | 
149 | 
150 | 
151 | def kullback_leibler_divergence(X):
152 | 
153 |     """Finds the pairwise Kullback-Leibler divergence
154 |     matrix between all rows in X.
155 | 
156 |     Parameters
157 |     ----------
158 |     X : array_like, shape (n_samples, n_features)
159 |         Array of probability data. Each row must sum to 1.
160 | 
161 |     Returns
162 |     -------
163 |     D : ndarray, shape (n_samples, n_samples)
164 |         The Kullback-Leibler divergence matrix. A pairwise matrix D such that D_{i, j}
165 |         is the divergence between the ith and jth vectors of the given matrix X.
166 | 
167 |     Notes
168 |     -----
169 |     Based on code from Gordon J. Berman et al.
170 |     (https://github.com/gordonberman/MotionMapper)
171 | 
172 |     References
173 |     -----------
174 |     Berman, G. J., Choi, D. M., Bialek, W., & Shaevitz, J. W. (2014).
175 |     Mapping the stereotyped behaviour of freely moving fruit flies.
176 |     Journal of The Royal Society Interface, 11(99), 20140672.
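    Examples
    --------
    ::

        >>> X = np.array([[0.5, 0.5], [0.9, 0.1]])
        >>> D = kullback_leibler_divergence(X)
        >>> D.shape
        (2, 2)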
177 |     """
178 | 
179 |     X_log = np.log(X)
180 |     X_log[np.isinf(X_log) | np.isnan(X_log)] = 0
181 | 
182 |     entropies = -np.sum(X * X_log, axis=1)
183 | 
184 |     D = np.matmul(-X, X_log.T)
185 |     D = D - entropies
186 |     D = D / np.log(2)
187 |     D *= (1 - np.eye(D.shape[0]))
188 | 
189 |     return D
190 | 
191 | def multinomial_rvs(n, p):
192 |     """Sample from the multinomial distribution with multiple p vectors.
193 | 
194 |     Parameters
195 |     ----------
196 |     n : int
197 |         must be a scalar >= 1
198 |     p : numpy ndarray
199 |         must be an n-dimensional array.
200 |         The last axis of `p` holds the sequence of probabilities for a multinomial distribution.
201 | 
202 |     Returns
203 |     -------
204 |     out : ndarray
205 |         same shape as `p`; integer counts drawn from the distribution.
206 |     """
207 |     count = np.full(p.shape[:-1], n)
208 |     out = np.zeros(p.shape, dtype=int)
209 |     ps = p.cumsum(axis=-1)
210 |     # Conditional probabilities
211 |     with np.errstate(divide='ignore', invalid='ignore'):
212 |         condp = p / ps
213 |     condp[np.isnan(condp)] = 0.0
214 |     for i in range(p.shape[-1]-1, 0, -1):
215 |         binsample = np.random.binomial(count, condp[..., i])
216 |         out[..., i] = binsample
217 |         count -= binsample
218 |     out[..., 0] = count
219 |     return out
220 | 
221 | def save_adata(adata, attr, key, data):
222 |     """updates an attribute of an `AnnData` object
223 | 
224 |     Parameters
225 |     ----------
226 |     adata : `AnnData`
227 |         The annotated data matrix of shape `n_obs × n_vars`. Rows correspond to cells and columns to genes.
228 |     attr : str
229 |         must be an attribute of `adata`, e.g., `obs`, `var`, etc.
230 |     key : str
231 |         must be a key in the attr
232 |     data : non-specific
233 |         the data to be updated/placed
234 | 
235 |     """
236 |     obj = getattr(adata, attr)
237 |     obj[key] = data
238 | 
239 | def prob_con(adata, overwrite=False, inplace=True):
240 |     """Concatenates all `.obsm` matrices whose names start with `"prob"` into `.obsm['prob_concatenated']`."""
241 |     if "prob_concatenated" in adata.obsm.keys():
242 |         warnings.warn("obsm['prob_concatenated'] already exists!")
243 |         if not overwrite:
244 |             return adata
245 |         else:
246 |             warnings.warn("overwriting obsm['prob_concatenated'].")
247 |             del adata.obsm["prob_concatenated"]
248 |     # Get a list of obsm matrices with names starting with "prob"
249 |     prob_matrices = [matrix_name for matrix_name in adata.obsm.keys() if matrix_name.startswith("prob")]
250 |     # Define a function to concatenate two matrices
251 |     def concatenate_matrices(matrix1, matrix2):
252 |         return np.concatenate((matrix1, matrix2), axis=1)
253 |     # Use functools.reduce to concatenate all matrices in prob_matrices
254 |     if prob_matrices:
255 |         concatenated_matrix = reduce(concatenate_matrices, [adata.obsm[matrix] for matrix in prob_matrices])
256 |         adata.obsm["prob_concatenated"] = concatenated_matrix
257 |         if inplace:
258 |             save_adata(adata, attr='obsm', key='spatial', data=concatenated_matrix)
259 |             return None
260 |     else:
261 |         warnings.warn("No 'prob' matrices found in the AnnData object.")
262 |     if not inplace:
263 |         return adata
264 | 
265 | 
266 | def map2ref(adata_ref, adata_q, k=10):
267 |     # if "spatial" not in adata_ref.obsm.keys():
268 |     #     raise Exception("adata_ref.obsm['spatial'] does not exist. Necessary for spatial mapping.")
269 |     if "prob_concatenated" not in adata_ref.obsm.keys():
270 |         warnings.warn("obsm['prob_concatenated'] does not exist for adata_ref. Calculating obsm['prob_concatenated'].")
271 |         prob_con(adata_ref)
272 |     if "prob_concatenated" not in adata_q.obsm.keys():
273 |         warnings.warn("obsm['prob_concatenated'] does not exist for adata_q. Calculating obsm['prob_concatenated'].")
274 |         prob_con(adata_q)
275 |     ref_embeddings = adata_ref.obsm['prob_concatenated']
276 |     kdtree_r1 = cKDTree(ref_embeddings)
277 |     target_embeddings = adata_q.obsm['prob_concatenated']
278 |     distances, indices = kdtree_r1.query(target_embeddings, k=k)
279 |     m = Softmax(dim=1)
280 |     probs = m(-torch.tensor(distances))  # closer reference cells get higher probability
281 |     dist = torch.distributions.categorical.Categorical(probs=probs)
282 |     idx = dist.sample().numpy()  # sample one of the k nearest reference cells per query cell
283 |     indices = indices[np.arange(len(indices)), idx]
284 |     conf = distances.min(1) / distances.min()  # confidence of each match relative to the best match overall
285 |     return indices, conf
286 |     # adata_q.obsm['spatial'] = adata_ref.obsm['spatial'][indices]
287 |     # if inplace:
288 |     #     save_adata(adata_q, attr='obsm', key='spatial', data= adata_ref.obsm['spatial'][indices])
289 |     # else:
290 |     #     return adata_q
291 |     # swap.obs['sink']
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | 
 3 | from setuptools import setup, find_packages
 4 | 
 5 | long_description = Path('README.rst').read_text('utf-8')
 6 | 
 7 | import os
 8 | os.environ['NUMBA_CACHE_DIR'] = '/tmp/'
 9 | 
10 | try:
11 |     from sagenet import __author__, __email__
12 | except ImportError:  # Deps not yet installed
13 |     __author__ = __maintainer__ = 'Elyas Heidari'
14 |     __email__ = 'eheidari@student.ethz.ch'
15 |     __version__ = '0.1.2'
16 | 
17 | setup(name='sagenet',
18 |       version = "0.1.2",
19 |       description='Spatial reconstruction of dissociated single-cell data',
20 |       long_description=long_description,
21 |       long_description_content_type="text/x-rst",
22 |       url='https://github.com/MarioniLab/sagenet',
23 |       author=__author__,
24 |       author_email=__email__,
25 |       license='MIT',
26 |       platforms=["Linux", "MacOSX"],
27 |       packages=find_packages(),
28 |       zip_safe=False,
29 |       # download_url="https://github.com/MarioniLab/sagenet/archive/refs/tags/SageNet_v0.1.0.1.tar.gz",
30 |       project_urls={
31 |           "Documentation": "https://sagenet.readthedocs.io/en/latest",
32 |           "Source Code": "https://github.com/MarioniLab/sagenet",
33 |       },
34 |       install_requires=[l.strip() for l in Path("requirements.txt").read_text("utf-8").splitlines() if l.strip()],
35 |       classifiers=[
36 |           "Development Status :: 5 - Production/Stable",
37 |           "Intended Audience :: Science/Research",
38 |           "Natural Language :: English",
39 |           "License :: OSI Approved :: MIT License",
40 |           "Operating System :: POSIX :: Linux",
41 |           "Operating System :: MacOS :: MacOS X",
42 |           "Typing :: Typed",
43 |           # "Programming Language :: Python :: 3",
44 |           # "Programming Language :: Python :: 3.6",
45 |           "Programming Language :: Python :: 3.7",
46 |           "Programming Language :: Python :: 3.8",
47 |           "Programming Language :: Python :: 3.9",
48 |           "Environment :: Console",
49 |           "Framework :: Jupyter",
50 | 
51 |           "Topic :: Scientific/Engineering :: Bio-Informatics",
52 |           "Topic :: Scientific/Engineering :: Visualization",
53 |       ],
54 |       extras_require={"doc": [
55 |           'sphinx',
56 |           'sphinx_rtd_theme',
57 |           'sphinx_autodoc_typehints',
58 |           'typing_extensions; python_version < "3.8"',
59 |       ]},
60 |       keywords=sorted(
61 |           [
62 |               "single-cell",
63 |               "bio-informatics",
64 |               "spatial transcriptomics",
65 |               "spatial data analysis",
66 |               "single-cell data analysis",
67 |           ]
68 |       ),
69 |       )
70 | 
--------------------------------------------------------------------------------