├── .DS_Store
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.rst
├── docs
├── .DS_Store
├── 00_hello_sagenet.ipynb
├── 01_multiple_references.ipynb
├── Makefile
├── _static
│ ├── .DS_Store
│ ├── custom.css
│ └── img
│ │ └── logo.png
├── about.rst
├── api.rst
├── conf.py
├── index.rst
└── make.bat
├── figures
├── show_ad_r1_all.pdf
└── show_ad_r1_all_conf.pdf
├── notebooks
├── 00_hello_sagenet.ipynb
├── 01_multiple_references.ipynb
├── SageNet_developing_human_heart_analysis.ipynb
└── SageNet_mouse_gastrulation_analysis.ipynb
├── requirements.txt
├── sagenet
├── .DS_Store
├── DHH_data
│ ├── _DHH_data.py
│ └── __init__.py
├── DHH_data_
│ ├── _DHH_data_.py
│ └── __init__.py
├── MGA_analysis.py
├── MGA_data
│ ├── _MGA_data.py
│ └── __init__.py
├── __init__.py
├── classifier.py
├── datasets
│ ├── __init__.py
│ └── _datasets.py
├── model.py
├── sage.py
└── utils.py
└── setup.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | dist: xenial
3 | cache: pip
4 | python:
5 | - "3.7"
6 | - "3.8"
7 | - "3.9"
8 |
9 | install:
10 | - pip install -r requirements.txt
11 | - python setup.py install
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Marioni Laboratory
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | |PyPI| |Docs|
2 |
3 | SageNet: Spatial reconstruction of dissociated single-cell datasets using graph neural networks
4 | ====================================================================================================
5 | .. raw:: html
6 |
7 | **SageNet** is a robust and generalizable graph neural network approach that probabilistically maps dissociated single cells from an scRNAseq dataset to their hypothetical tissue of origin using one or more reference datasets acquired by spatially resolved transcriptomics techniques. It is compatible with both high-plex imaging (e.g., seqFISH, MERFISH, etc.) and spatial barcoding (e.g., 10X visium, Slide-seq, etc.) datasets as the spatial reference.
8 |
9 | .. raw:: html
10 |
11 |
12 |
13 |
15 |
16 |
17 |
18 |
19 |
20 | SageNet is implemented with `pytorch `_ and `pytorch-geometric `_ to be modular, fast, and scalable. Also, it uses ``anndata`` to be compatible with `scanpy `_ and `squidpy `_ for pre- and post-processing steps.
21 |
22 |
23 | .. raw:: html
24 |
25 |
26 |
27 |
29 |
30 |
31 |
32 |
33 |
34 | Installation
35 | ============
36 |
37 |
38 | .. note::
39 |
40 | **v1.0**
41 | The dependency ``torch-geometric`` should be installed separately, corresponding to the system's specifics; look at `this link `_ for instructions. We recommend using Miniconda.
42 |
43 |
44 | GitHub (currently recommended)
45 | ------------------------------
46 |
47 | First, clone the repository using ``git``::
48 |
49 | git clone https://github.com/MarioniLab/sagenet
50 |
51 | Then, ``cd`` to the sagenet folder and run the install command::
52 |
53 | cd sagenet
54 | python setup.py install #or pip install .
55 |
56 | PyPI
57 | --------
58 |
59 | The easiest way to get SageNet is through pip using the following command::
60 |
61 | pip install sagenet
62 |
63 |
64 |
65 |
66 | Usage
67 | ============
68 |
69 | ::
70 |
71 | import sagenet as sg
72 | import scanpy as sc
73 | import squidpy as sq
74 | import anndata as ad
75 | import random
76 | random.seed(10)
77 |
78 |
79 | Training phase:
80 | ---------------
81 |
82 |
83 | **Input:**
84 |
85 | - Expression matrix associated with the (spatial) reference dataset (an ``anndata`` object)
86 |
87 | ::
88 |
89 | adata_r = sg.MGA_data.seqFISH1()
90 |
91 |
92 | - gene-gene interaction network
93 |
94 |
95 | ::
96 |
97 | glasso(adata_r, [0.5, 0.75, 1])
98 |
99 |
100 |
101 |
102 | - one or more partitionings of the spatial reference into distinct connected neighborhoods of cells or spots
103 |
104 | ::
105 |
106 | adata_r.obsm['spatial'] = np.array(adata_r.obs[['x','y']])
107 | sq.gr.spatial_neighbors(adata_r, coord_type="generic")
108 | sc.tl.leiden(adata_r, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r.obsp["spatial_connectivities"])
109 | sc.tl.leiden(adata_r, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r.obsp["spatial_connectivities"])
110 | sc.tl.leiden(adata_r, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r.obsp["spatial_connectivities"])
111 | sc.tl.leiden(adata_r, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r.obsp["spatial_connectivities"])
112 | sc.tl.leiden(adata_r, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r.obsp["spatial_connectivities"])
113 |
114 |
115 |
116 | **Training:**
117 | ::
118 |
119 |
120 | sg_obj = sg.sage.sage(device=device)
121 | sg_obj.add_ref(adata_r, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref', epochs=20, verbose = False)
122 |
123 |
124 |
125 | **Output:**
126 |
127 | - A set of pre-trained models (one for each partitioning)
128 |
129 | ::
130 |
131 |
132 | !mkdir models
133 | !mkdir models/seqFISH_ref
134 | sg_obj.save_model_as_folder('models/seqFISH_ref')
135 |
136 |
137 | - A set of Spatially Informative Genes
138 |
139 | ::
140 |
141 |
142 | ind = np.where(adata_r.var['ST_all_importance'] <= 5)[0]
143 | SIGs = list(adata_r.var_names[ind])
144 | with rc_context({'figure.figsize': (4, 4)}):
145 | sc.pl.spatial(adata_r, color=SIGs, ncols=4, spot_size=0.03, legend_loc=None)
146 |
147 |
148 | .. raw:: html
149 |
150 |
151 |
152 |
154 |
155 |
156 |
157 |
158 |
159 |
160 | Mapping phase
161 | ---------------
162 |
163 | **Input:**
164 |
165 | - Expression matrix associated with the (dissociated) query dataset (an ``anndata`` object)
166 | ::
167 |
168 | adata_q = sg.MGA_data.scRNAseq()
169 |
170 |
171 | **Mapping:**
172 | ::
173 |
174 | sg_obj.map_query(adata_q)
175 |
176 |
177 | **Output:**
178 |
179 | - The reconstructed cell-cell spatial distance matrix
180 | ::
181 |
182 |
183 | adata_q.obsm['dist_map']
184 |
185 |
186 | - A consensus scoring of mappability (uncertainty of mapping) of each cell to the references
187 | ::
188 |
189 |
190 | adata_q.obs
191 |
192 |
193 | ::
194 |
195 |
196 | import anndata
197 | dist_adata = anndata.AnnData(adata_q.obsm['dist_map'], obs = adata_q.obs)
198 | knn_indices, knn_dists, forest = sc.neighbors.compute_neighbors_umap(dist_adata.X, n_neighbors=50, metric='precomputed')
199 | dist_adata.obsp['distances'], dist_adata.obsp['connectivities'] = sc.neighbors._compute_connectivities_umap(
200 | knn_indices,
201 | knn_dists,
202 | dist_adata.shape[0],
203 | 50, # change to neighbors you plan to use
204 | )
205 | sc.pp.neighbors(dist_adata, metric='precomputed', use_rep='X')
206 | sc.tl.umap(dist_adata)
207 | sc.pl.umap(dist_adata, color='cell_type', palette=celltype_colours)
208 |
209 |
210 | .. raw:: html
211 |
212 |
213 |
214 |
216 |
217 |
218 |
219 |
220 | Notebooks
221 | ============
222 | To see some examples of our pipeline's capability, look at the `notebooks `_ directory. The notebooks are also available on google colab:
223 |
224 | #. `Spatial reconstruction of the developing human heart `_
225 | #. `Spatial reconstruction of the mouse embryo `_
226 |
227 | Interactive examples
228 | ====================
229 | * `Spatial mapping of the mouse gastrulation atlas `_
230 |
231 |
232 | Support and contribute
233 | ======================
234 | If you have a question or new architecture or a model that could be integrated into our pipeline, you can
235 | post an `issue `__ or reach us by `email `_.
236 |
237 |
238 | Contributions
239 | =============
240 | This work is led by Elyas Heidari and Shila Ghazanfar as a joint effort between `MarioniLab@CRUK@EMBL-EBI `__ and `RobinsonLab@UZH `__.
241 |
242 | .. |Docs| image:: https://readthedocs.org/projects/sagenet/badge/?version=latest
243 | :target: https://sagenet.readthedocs.io
244 |
245 | .. |PyPI| image:: https://img.shields.io/pypi/v/sagenet.svg
246 | :target: https://pypi.org/project/sagenet
247 |
248 | .. |travis| image:: https://travis-ci.com/MarioniLab/sagenet.svg?branch=main
249 | :target: https://travis-ci.com/MarioniLab/sagenet
250 |
--------------------------------------------------------------------------------
/docs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/docs/.DS_Store
--------------------------------------------------------------------------------
/docs/01_multiple_references.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "# Multiple references\n",
7 | "\n",
8 | "In this notebook we show installation and basic usage of **SageNet**. "
9 | ],
10 | "metadata": {
11 | "id": "B8TiAQi7hwXH"
12 | }
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {
18 | "id": "5hXws4b_EPoR"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html \n",
23 | "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html \n",
24 | "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git "
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {
31 | "id": "i8t1adiyJ51v"
32 | },
33 | "outputs": [],
34 | "source": [
35 | "!pwd"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {
42 | "colab": {
43 | "base_uri": "https://localhost:8080/"
44 | },
45 | "id": "Qgc-2ZeGEQIo",
46 | "outputId": "de2de8e9-a1f9-4d1f-c807-3e659586042f"
47 | },
48 | "outputs": [
49 | {
50 | "name": "stdout",
51 | "output_type": "stream",
52 | "text": [
53 | "fatal: destination path 'sagenet' already exists and is not an empty directory.\n",
54 | "/content/sagenet/sagenet\n",
55 | "\u001b[31mERROR: Directory '.' is not installable. Neither 'setup.py' nor 'pyproject.toml' found.\u001b[0m\n"
56 | ]
57 | }
58 | ],
59 | "source": [
60 | "!git clone https://github.com/MarioniLab/sagenet\n",
61 | "%cd sagenet\n",
62 | "!pip install ."
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {
69 | "id": "w-znoEjrERTp"
70 | },
71 | "outputs": [],
72 | "source": [
73 | "import sagenet as sg\n",
74 | "import scanpy as sc\n",
75 | "import squidpy as sq\n",
76 | "import anndata as ad\n",
77 | "import random\n",
78 | "random.seed(10)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "id": "stz_3hVlGuaF"
86 | },
87 | "outputs": [],
88 | "source": [
89 | "celltype_colours = {\n",
90 | " \"Epiblast\" : \"#635547\",\n",
91 | " \"Primitive Streak\" : \"#DABE99\",\n",
92 | " \"Caudal epiblast\" : \"#9e6762\",\n",
93 | " \"PGC\" : \"#FACB12\",\n",
94 | " \"Anterior Primitive Streak\" : \"#c19f70\",\n",
95 | " \"Notochord\" : \"#0F4A9C\",\n",
96 | " \"Def. endoderm\" : \"#F397C0\",\n",
97 | " \"Definitive endoderm\" : \"#F397C0\",\n",
98 | " \"Gut\" : \"#EF5A9D\",\n",
99 | " \"Gut tube\" : \"#EF5A9D\",\n",
100 | " \"Nascent mesoderm\" : \"#C594BF\",\n",
101 | " \"Mixed mesoderm\" : \"#DFCDE4\",\n",
102 | " \"Intermediate mesoderm\" : \"#139992\",\n",
103 | " \"Caudal Mesoderm\" : \"#3F84AA\",\n",
104 | " \"Paraxial mesoderm\" : \"#8DB5CE\",\n",
105 | " \"Somitic mesoderm\" : \"#005579\",\n",
106 | " \"Pharyngeal mesoderm\" : \"#C9EBFB\",\n",
107 | " \"Splanchnic mesoderm\" : \"#C9EBFB\",\n",
108 | " \"Cardiomyocytes\" : \"#B51D8D\",\n",
109 | " \"Allantois\" : \"#532C8A\",\n",
110 | " \"ExE mesoderm\" : \"#8870ad\",\n",
111 | " \"Lateral plate mesoderm\" : \"#8870ad\",\n",
112 | " \"Mesenchyme\" : \"#cc7818\",\n",
113 | " \"Mixed mesenchymal mesoderm\" : \"#cc7818\",\n",
114 | " \"Haematoendothelial progenitors\" : \"#FBBE92\",\n",
115 | " \"Endothelium\" : \"#ff891c\",\n",
116 | " \"Blood progenitors 1\" : \"#f9decf\",\n",
117 | " \"Blood progenitors 2\" : \"#c9a997\",\n",
118 | " \"Erythroid1\" : \"#C72228\",\n",
119 | " \"Erythroid2\" : \"#f79083\",\n",
120 | " \"Erythroid3\" : \"#EF4E22\",\n",
121 | " \"Erythroid\" : \"#f79083\",\n",
122 | " \"Blood progenitors\" : \"#f9decf\",\n",
123 | " \"NMP\" : \"#8EC792\",\n",
124 | " \"Rostral neurectoderm\" : \"#65A83E\",\n",
125 | " \"Caudal neurectoderm\" : \"#354E23\",\n",
126 | " \"Neural crest\" : \"#C3C388\",\n",
127 | " \"Forebrain/Midbrain/Hindbrain\" : \"#647a4f\",\n",
128 | " \"Spinal cord\" : \"#CDE088\",\n",
129 | " \"Surface ectoderm\" : \"#f7f79e\",\n",
130 | " \"Visceral endoderm\" : \"#F6BFCB\",\n",
131 | " \"ExE endoderm\" : \"#7F6874\",\n",
132 | " \"ExE ectoderm\" : \"#989898\",\n",
133 | " \"Parietal endoderm\" : \"#1A1A1A\",\n",
134 | " \"Unknown\" : \"#FFFFFF\",\n",
135 | " \"Low quality\" : \"#e6e6e6\",\n",
136 | " # somitic and paraxial types\n",
137 | " # colour from T chimera paper Guibentif et al Developmental Cell 2021\n",
138 | " \"Cranial mesoderm\" : \"#77441B\",\n",
139 | " \"Anterior somitic tissues\" : \"#F90026\",\n",
140 | " \"Sclerotome\" : \"#A10037\",\n",
141 | " \"Dermomyotome\" : \"#DA5921\",\n",
142 | " \"Posterior somitic tissues\" : \"#E1C239\",\n",
143 | " \"Presomitic mesoderm\" : \"#9DD84A\"\n",
144 | "}"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {
151 | "id": "yUWR8-aPESU0"
152 | },
153 | "outputs": [],
154 | "source": [
155 | "from copy import copy\n",
156 | "adata_r1 = sg.datasets.seqFISH1()\n",
157 | "adata_r2 = sg.datasets.seqFISH2()\n",
158 | "adata_r3 = sg.datasets.seqFISH3()\n",
159 | "adata_q1 = copy(adata_r1)\n",
160 | "adata_q2 = copy(adata_r2)\n",
161 | "adata_q3 = copy(adata_r3)\n",
162 | "adata_q4 = sg.datasets.MGA()\n",
163 | "sc.pp.subsample(adata_q1, fraction=0.25)\n",
164 | "sc.pp.subsample(adata_q2, fraction=0.25)\n",
165 | "sc.pp.subsample(adata_q3, fraction=0.25)\n",
166 | "sc.pp.subsample(adata_q4, fraction=0.25)\n",
167 | "adata_q = ad.concat([adata_q1, adata_q2, adata_q3, adata_q4], join=\"inner\")\n",
168 | "del adata_q1 \n",
169 | "del adata_q2 \n",
170 | "del adata_q3 \n",
171 | "del adata_q4"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {
178 | "id": "aU5MmeY8GrYE"
179 | },
180 | "outputs": [],
181 | "source": [
182 | "from sagenet.utils import glasso\n",
183 | "import numpy as np\n",
184 | "glasso(adata_r1, [0.5, 0.75, 1])\n",
185 | "adata_r1.obsm['spatial'] = np.array(adata_r1.obs[['x','y']])\n",
186 | "sq.gr.spatial_neighbors(adata_r1, coord_type=\"generic\")\n",
187 | "sc.tl.leiden(adata_r1, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n",
188 | "sc.tl.leiden(adata_r1, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n",
189 | "sc.tl.leiden(adata_r1, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n",
190 | "sc.tl.leiden(adata_r1, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n",
191 | "sc.tl.leiden(adata_r1, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r1.obsp[\"spatial_connectivities\"])\n",
192 | "glasso(adata_r2, [0.5, 0.75, 1])\n",
193 | "adata_r2.obsm['spatial'] = np.array(adata_r2.obs[['x','y']])\n",
194 | "sq.gr.spatial_neighbors(adata_r2, coord_type=\"generic\")\n",
195 | "sc.tl.leiden(adata_r2, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n",
196 | "sc.tl.leiden(adata_r2, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n",
197 | "sc.tl.leiden(adata_r2, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n",
198 | "sc.tl.leiden(adata_r2, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n",
199 | "sc.tl.leiden(adata_r2, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r2.obsp[\"spatial_connectivities\"])\n",
200 | "glasso(adata_r3, [0.5, 0.75, 1])\n",
201 | "adata_r3.obsm['spatial'] = np.array(adata_r3.obs[['x','y']])\n",
202 | "sq.gr.spatial_neighbors(adata_r3, coord_type=\"generic\")\n",
203 | "sc.tl.leiden(adata_r3, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n",
204 | "sc.tl.leiden(adata_r3, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n",
205 | "sc.tl.leiden(adata_r3, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n",
206 | "sc.tl.leiden(adata_r3, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r3.obsp[\"spatial_connectivities\"])\n",
207 | "sc.tl.leiden(adata_r3, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r3.obsp[\"spatial_connectivities\"])"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "metadata": {
214 | "colab": {
215 | "base_uri": "https://localhost:8080/"
216 | },
217 | "id": "r_St4TwlIcAW",
218 | "outputId": "e289e9c1-a225-419b-842d-6bfc5b4bf3d3"
219 | },
220 | "outputs": [
221 | {
222 | "name": "stdout",
223 | "output_type": "stream",
224 | "text": [
225 | "cpu\n"
226 | ]
227 | }
228 | ],
229 | "source": [
230 | "import torch\n",
231 | "if torch.cuda.is_available(): \n",
232 | " dev = \"cuda:0\" \n",
233 | "else: \n",
234 | " dev = \"cpu\" \n",
235 | "device = torch.device(dev)\n",
236 | "print(device)"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {
243 | "id": "WRcnUb8wIduW"
244 | },
245 | "outputs": [],
246 | "source": [
247 | "sg_obj = sg.sage.sage(device=device)"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {
254 | "colab": {
255 | "background_save": true,
256 | "base_uri": "https://localhost:8080/"
257 | },
258 | "id": "BxX5nvMpISVZ",
259 | "outputId": "cd6f96d8-ea7c-4bbe-e48d-4a08c4d191e5"
260 | },
261 | "outputs": [
262 | {
263 | "name": "stderr",
264 | "output_type": "stream",
265 | "text": [
266 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
267 | " warnings.warn(out)\n",
268 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
269 | " warnings.warn(out)\n",
270 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
271 | " warnings.warn(out)\n",
272 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
273 | " warnings.warn(out)\n",
274 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
275 | " _warn_prf(average, modifier, msg_start, len(result))\n",
276 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
277 | " _warn_prf(average, modifier, msg_start, len(result))\n",
278 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
279 | " warnings.warn(out)\n",
280 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
281 | " _warn_prf(average, modifier, msg_start, len(result))\n",
282 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
283 | " _warn_prf(average, modifier, msg_start, len(result))\n",
284 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
285 | " _warn_prf(average, modifier, msg_start, len(result))\n"
286 | ]
287 | }
288 | ],
289 | "source": [
290 | "sg_obj.add_ref(adata_r1, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref1', epochs=15, verbose = False)"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {
297 | "colab": {
298 | "background_save": true
299 | },
300 | "id": "YZ4lN_0kIWwO",
301 | "outputId": "499b9ceb-8c69-4200-9468-985587f43f32"
302 | },
303 | "outputs": [
304 | {
305 | "name": "stderr",
306 | "output_type": "stream",
307 | "text": [
308 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
309 | " warnings.warn(out)\n",
310 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
311 | " warnings.warn(out)\n",
312 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
313 | " warnings.warn(out)\n",
314 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
315 | " warnings.warn(out)\n",
316 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
317 | " _warn_prf(average, modifier, msg_start, len(result))\n",
318 | "/usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:13: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead\n",
319 | " warnings.warn(out)\n",
320 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
321 | " _warn_prf(average, modifier, msg_start, len(result))\n",
322 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
323 | " _warn_prf(average, modifier, msg_start, len(result))\n",
324 | "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1308: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
325 | " _warn_prf(average, modifier, msg_start, len(result))\n"
326 | ]
327 | }
328 | ],
329 | "source": [
330 | "sg_obj.add_ref(adata_r2, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref2', epochs=15, verbose = False)"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "metadata": {
337 | "id": "sA9dF1vZIXLV"
338 | },
339 | "outputs": [],
340 | "source": [
341 | "sg_obj.add_ref(adata_r3, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref3', epochs=15, verbose = False)"
342 | ]
343 | },
344 | {
345 | "cell_type": "code",
346 | "execution_count": null,
347 | "metadata": {
348 | "id": "qQjIRmPnI4F2"
349 | },
350 | "outputs": [],
351 | "source": [
352 | "ind = np.argsort(-(adata_r.var['seqFISH_ref_entropy']+ adata_r.var['seqFISH_ref2_entropy'] + adata_r.var['seqFISH_ref3_entropy']))[0:12]\n",
353 | "with rc_context({'figure.figsize': (4, 4)}):\n",
354 | " sc.pl.spatial(adata_r, color=list(adata_r.var_names[ind]), ncols=4, spot_size=0.03, legend_loc=None)"
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": null,
360 | "metadata": {
361 | "id": "qOWVKZw9JDPu"
362 | },
363 | "outputs": [],
364 | "source": [
365 | "!mkdir models\n",
366 |     "!mkdir models/seqFISH_multiple_ref\n",
367 | "sg_obj.save_model_as_folder('models/seqFISH_multiple_ref')"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": null,
373 | "metadata": {
374 | "id": "s5rn2FxrJUNN"
375 | },
376 | "outputs": [],
377 | "source": [
378 | "sg_obj.map_query(adata_q)"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": null,
384 | "metadata": {
385 | "id": "WN1vAbpuJW74"
386 | },
387 | "outputs": [],
388 | "source": [
389 | "import anndata\n",
390 | "dist_adata = anndata.AnnData(adata_q.obsm['dist_map'], obs = adata_q.obs)\n",
391 | "knn_indices, knn_dists, forest = sc.neighbors.compute_neighbors_umap(dist_adata.X, n_neighbors=50, metric='precomputed')\n",
392 | "dist_adata.obsp['distances'], dist_adata.obsp['connectivities'] = sc.neighbors._compute_connectivities_umap(\n",
393 | " knn_indices,\n",
394 | " knn_dists,\n",
395 | " dist_adata.shape[0],\n",
396 | " 50, # change to neighbors you plan to use\n",
397 | ")\n",
398 | "sc.pp.neighbors(dist_adata, metric='precomputed', use_rep='X')\n",
399 | "sc.tl.umap(dist_adata)\n",
400 | "sc.pl.umap(dist_adata, color='cell_type', palette=celltype_colours, save='eli.pdf')"
401 | ]
402 | }
403 | ],
404 | "metadata": {
405 | "colab": {
406 | "name": "01_multiple_references.ipynb",
407 | "provenance": [],
408 | "collapsed_sections": []
409 | },
410 | "kernelspec": {
411 | "display_name": "Python 3",
412 | "name": "python3"
413 | },
414 | "language_info": {
415 | "name": "python"
416 | }
417 | },
418 | "nbformat": 4,
419 | "nbformat_minor": 0
420 | }
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/docs/_static/.DS_Store
--------------------------------------------------------------------------------
/docs/_static/custom.css:
--------------------------------------------------------------------------------
1 | /* ReadTheDocs theme colors */
2 | /* Copied from https://github.com/theislab/scvelo */
3 |
4 | .wy-nav-top {
5 | background-color: #404040
6 | }
7 |
8 | .wy-nav-content {
9 | max-width: 950px
10 | }
11 |
12 | .wy-side-nav-search {
13 | background-color: transparent
14 | }
15 |
16 | .wy-side-nav-search input[type="text"] {
17 | border-width: 0
18 | }
19 |
20 |
21 | /* Custom classes */
22 | .small {
23 | font-size: 40%
24 | }
25 |
26 | .smaller, .pr {
27 | font-size: 70%
28 | }
29 |
30 |
31 | /* Custom classes with bootstrap buttons */
32 |
33 | .tutorial,
34 | .tutorial:visited,
35 | .tutorial:hover {
36 | /* text-decoration: underline; */
37 | font-weight: bold;
38 | padding: 2px 5px;
39 | white-space: nowrap;
40 | max-width: 100%;
41 | background: #EF3270;
42 | border: solid 1px #EF3270;
43 | border-radius: .25rem;
44 | font-size: 75%;
45 | /* font-family: SFMono-Regular,Menlo,Monaco,Consolas,"Liberation Mono","Courier New",Courier,monospace; */
46 | color: #404040;
47 | overflow-x: auto;
48 | box-sizing: border-box;
49 | }
50 |
51 |
52 | /* Formatting of RTD markup: rubrics and sidebars and admonitions */
53 |
54 | /* rubric */
55 | .rst-content p.rubric {
56 | margin-bottom: 6px;
57 | font-weight: normal;
58 | }
59 |
60 | .rst-content p.rubric::after {
61 | content: ":"
62 | }
63 |
64 | /* sidebar */
65 | .rst-content .sidebar {
66 | /* margin: 0px 0px 0px 12px; */
67 | padding-bottom: 0px;
68 | }
69 |
70 | .rst-content .sidebar p {
71 | margin-bottom: 12px;
72 | }
73 |
74 | .rst-content .sidebar p,
75 | .rst-content .sidebar ul,
76 | .rst-content .sidebar dl {
77 | font-size: 13px;
78 | }
79 |
80 | /* less space after bullet lists in admonitions like warnings and notes */
81 | .rst-content .section .admonition ul {
82 | margin-bottom: 6px;
83 | }
84 |
85 |
86 | /* Code: literals and links */
87 |
88 | .rst-content tt.literal,
89 | .rst-content code.literal {
90 | color: #404040;
91 | }
92 |
/* slim font weight for non-link code.
   The declaration block was missing here: the selector list ended with a
   trailing comma and ran on into the next rule, so the `.annotation` box
   styling leaked onto every code/xref element. */
.rst-content tt:not(.xref),
.rst-content code:not(.xref),
.rst-content *:not(a) > tt.xref,
.rst-content *:not(a) > code.xref,
.rst-content a > tt.xref,
.rst-content a > code.xref,
.rst-content dl:not(.docutils) a > tt.xref,
.rst-content dl:not(.docutils) a > code.xref {
  font-weight: normal;
}
101 |
102 |
103 | /* Just one box for annotation code for a less noisy look */
104 |
105 | .rst-content .annotation {
106 | padding: 2px 5px;
107 | background-color: white;
108 | border: 1px solid #e1e4e5;
109 | }
110 |
111 | .rst-content .annotation tt,
112 | .rst-content .annotation code {
113 | padding: 0 0;
114 | background-color: transparent;
115 | border: 0 solid transparent;
116 | }
117 |
118 |
119 | /* Parameter lists */
120 |
121 | .rst-content dl:not(.docutils) dl dt {
122 | /* mimick numpydoc’s blockquote style */
123 | font-weight: normal;
124 | background: none transparent;
125 | border-left: none;
126 | margin: 0 0 12px;
127 | padding: 3px 0 0;
128 | font-size: 100%;
129 | }
130 |
131 | .rst-content dl:not(.docutils) dl dt code {
132 | font-size: 100%;
133 | font-weight: normal;
134 | background: none transparent;
135 | border: none;
136 | padding: 0 2px;
137 | }
138 |
139 | .rst-content dl:not(.docutils) dl dt a.reference > code {
140 | text-decoration: underline;
141 | }
142 |
143 | /* Mimick rubric style used for other headings */
144 | .rst-content dl:not(.docutils) dl > dt {
145 | font-weight: bold;
146 | background: none transparent;
147 | border-left: none;
148 | margin: 0 0 12px;
149 | padding: 3px 0 0;
150 | font-size: 100%;
151 | }
152 |
153 | /* Parameters contain parts and don’t need bold font */
154 | .rst-content dl.field-list dl > dt {
155 | font-weight: unset
156 | }
157 |
158 | /* Add colon between return tuple element name and type */
159 | .rst-content dl:not(.docutils) dl > dt .classifier::before {
160 | content: ' : '
161 | }
162 |
163 | /* Function headers */
164 |
165 | .rst-content dl:not(.docutils) dt {
166 | background: #edf0f2;
167 | color: #404040;
168 | border-top: solid 3px #343131;
169 | }
170 |
171 | .rst-content .section ul li p:last-child {
172 | margin-bottom: 0;
173 | margin-top: 0;
174 | }
175 |
176 | /* Copy buttons */
177 | a.copybtn {
178 | position: absolute;
179 | top: -6.5px;
180 | right: 0px;
181 | width: 1em;
182 | height: 1em;
183 | padding: .15em;
184 | opacity: .4;
185 | transition: opacity 0.5s;
186 | }
187 |
188 |
189 | /* Remove prompt line numbers */
190 | .nbinput > :first-child,
191 | .nboutput > :first-child {
192 | min-width: 0 !important;
193 | }
194 |
195 | /* Adjust width of navigation bar on mobile */
196 | @media screen and (max-width: 768px) {
197 | .header-bar {
198 | display: none;
199 | }
200 |
201 | .wy-nav-content-wrap {
202 | margin-left: 0px;
203 | }
204 |
205 | .wy-nav-side {
206 | width: 300px;
207 | }
208 |
209 | .wy-nav-side.shift {
210 | max-width: 320px;
211 | }
212 |
213 | /* Fix sidebar adjust */
214 | .rst-versions {
215 | width: 40%;
216 | max-width: 320px;
217 | }
218 | }
219 |
220 | /* Handle landscape */
221 | @media screen and (min-width: 377px) {
222 | .wy-nav-content-wrap.shift {
223 | left: 320px;
224 | }
225 | }
--------------------------------------------------------------------------------
/docs/_static/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/docs/_static/img/logo.png
--------------------------------------------------------------------------------
/docs/about.rst:
--------------------------------------------------------------------------------
1 |
2 |
3 | SageNet: Single-cell Spatial Locator
4 | =========================================================================
5 | .. raw:: html
6 |
 7 | **SageNet** is a robust and generalizable graph neural network approach that probabilistically maps dissociated single cells from an scRNAseq dataset to their hypothetical tissue of origin using one or more reference datasets acquired by spatially resolved transcriptomics techniques. It is compatible with both high-plex imaging (e.g., seqFISH, MERFISH, etc.) and spatial barcoding (e.g., 10X visium, Slide-seq, etc.) datasets as the spatial reference.
8 |
9 |
10 | .. raw:: html
11 |
12 |
13 |
14 |
16 |
17 |
18 |
19 | SageNet is implemented with `pytorch `_ and `pytorch-geometric `_ to be modular, fast, and scalable. Also, it uses ``anndata`` to be compatible with `scanpy `_ and `squidpy `_ for pre- and post-processing steps.
20 |
21 |
22 | Installation
23 | ============
24 |
25 |
26 | .. note::
27 |
28 | **0.1.0**
 29 | 	The dependency ``torch-geometric`` should be installed separately, corresponding to the system specificities; look at `this link `_ for instructions. We recommend using Miniconda.
30 |
31 | PyPI
32 | --------
33 |
34 | The easiest way to get SageNet is through pip using the following command::
35 |
36 | pip install sagenet
37 |
38 | Development
39 | ---------------
40 |
41 | First, clone the repository using ``git``::
42 |
43 | git clone https://github.com/MarioniLab/sagenet
44 |
45 | Then, ``cd`` to the sagenet folder and run the install command::
46 |
47 | cd sagenet
48 | python setup.py install #or pip install .
49 |
50 |
51 | Usage
52 | ============
53 | ::
54 |
55 | import sagenet as sg
56 | import scanpy as sc
57 | import squidpy as sq
58 | import anndata as ad
59 | import random
60 | random.seed(10)
61 |
62 |
63 | Training phase:
64 | ---------------
65 |
66 |
67 | **Input:**
68 |
69 | - Expression matrix associated with the (spatial) reference dataset (an ``anndata`` object)
70 |
71 | ::
72 |
73 | adata_r = sg.datasets.seqFISH1()
74 |
75 |
76 | - gene-gene interaction network
77 |
78 |
79 | ::
80 |
81 | glasso(adata_r, [0.5, 0.75, 1])
82 |
83 |
84 |
85 |
86 | - one or more partitionings of the spatial reference into distinct connected neighborhoods of cells or spots
87 |
88 | ::
89 |
90 | adata_r.obsm['spatial'] = np.array(adata_r.obs[['x','y']])
91 | sq.gr.spatial_neighbors(adata_r, coord_type="generic")
92 | sc.tl.leiden(adata_r, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r.obsp["spatial_connectivities"])
93 | sc.tl.leiden(adata_r, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r.obsp["spatial_connectivities"])
94 | sc.tl.leiden(adata_r, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r.obsp["spatial_connectivities"])
95 | sc.tl.leiden(adata_r, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r.obsp["spatial_connectivities"])
96 | sc.tl.leiden(adata_r, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r.obsp["spatial_connectivities"])
97 |
98 |
99 |
100 | **Training:**
101 | ::
102 |
103 |
104 | sg_obj = sg.sage.sage(device=device)
105 | sg_obj.add_ref(adata_r, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='seqFISH_ref', epochs=20, verbose = False)
106 |
107 |
108 |
109 | **Output:**
110 |
111 | - A set of pre-trained models (one for each partitioning)
112 |
113 | ::
114 |
115 |
116 | !mkdir models
117 | !mkdir models/seqFISH_ref
118 | sg_obj.save_model_as_folder('models/seqFISH_ref')
119 |
120 |
 121 | - A consensus scoring of the spatial informativeness of each gene
122 |
123 | ::
124 |
125 |
126 | ind = np.argsort(-adata_r.var['seqFISH_ref_entropy'])[0:12]
127 | with rc_context({'figure.figsize': (4, 4)}):
128 | sc.pl.spatial(adata_r, color=list(adata_r.var_names[ind]), ncols=4, spot_size=0.03, legend_loc=None)
129 |
130 |
131 | .. raw:: html
132 |
133 |
134 |
135 |
137 |
138 |
139 |
140 |
141 |
142 |
143 | Mapping phase
144 | ---------------
145 |
146 | **Input:**
147 |
148 | - Expression matrix associated with the (dissociated) query dataset (an ``anndata`` object)
149 | ::
150 |
151 | adata_q = sg.datasets.MGA()
152 |
153 |
154 | **Mapping:**
155 | ::
156 |
157 | sg_obj.map_query(adata_q)
158 |
159 |
160 | **Output:**
161 |
162 | - The reconstructed cell-cell spatial distance matrix
163 | ::
164 |
165 |
166 | adata_q.obsm['dist_map']
167 |
168 |
 169 | - A consensus scoring of mappability (uncertainty of mapping) of each cell to the references
170 | ::
171 |
172 |
173 | adata_q.obs
174 |
175 |
176 | ::
177 |
178 |
179 | import anndata
180 | dist_adata = anndata.AnnData(adata_q.obsm['dist_map'], obs = adata_q.obs)
181 | knn_indices, knn_dists, forest = sc.neighbors.compute_neighbors_umap(dist_adata.X, n_neighbors=50, metric='precomputed')
182 | dist_adata.obsp['distances'], dist_adata.obsp['connectivities'] = sc.neighbors._compute_connectivities_umap(
183 | knn_indices,
184 | knn_dists,
185 | dist_adata.shape[0],
186 | 50, # change to neighbors you plan to use
187 | )
188 | sc.pp.neighbors(dist_adata, metric='precomputed', use_rep='X')
189 | sc.tl.umap(dist_adata)
190 | sc.pl.umap(dist_adata, color='cell_type', palette=celltype_colours)
191 |
192 |
193 | .. raw:: html
194 |
195 |
196 |
197 |
199 |
200 |
201 |
202 |
203 | Notebooks
204 | ============
 205 | To see some examples of our pipeline's capability, look at the `notebooks `_ directory. The notebooks are also available on Google Colab:
206 |
207 | #. `Intro to SageNet `_
208 | #. `Using multiple references `_
209 |
210 | Interactive examples
 211 | ====================
212 | * `Spatial mapping of the mouse gastrulation atlas `_
213 |
214 |
215 | Support and contribute
 216 | ======================
217 | If you have a question or new architecture or a model that could be integrated into our pipeline, you can
218 | post an `issue `__ or reach us by `email `_.
219 |
220 |
221 | Contributions
 222 | =============
223 | This work is led by Elyas Heidari and Shila Ghazanfar as a joint effort between `MarioniLab@CRUK@EMBL-EBI `__ and `RobinsonLab@UZH `__.
224 |
225 | .. |Docs| image:: https://readthedocs.org/projects/sagenet/badge/?version=latest
226 | :target: https://sagenet.readthedocs.io
227 |
228 | .. |PyPI| image:: https://img.shields.io/pypi/v/sagenet.svg
229 | :target: https://pypi.org/project/sagenet
230 |
231 |
232 |
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | API
2 | ===
3 |
4 | The API reference contains detailed descriptions of the different end-user classes, functions, methods, etc.
5 |
6 |
7 | .. note::
8 |
9 | This API reference only contains end-user documentation.
 10 | If you are looking to hack away at sagenet's internals, you will find more detailed comments in the source code.
11 |
12 |
13 | * `sage`_
14 | * `classifier`_
15 | * `utils`_
16 |
17 | sage
18 | ----
19 |
20 | .. automodule:: sagenet.sage
21 | :members:
22 | :undoc-members:
23 | :show-inheritance:
24 |
25 |
26 | classifier
27 | ----------
28 |
29 | .. automodule:: sagenet.classifier
30 | :members:
31 | :undoc-members:
32 | :show-inheritance:
33 |
34 |
35 | utils
36 | -----
37 |
38 | .. automodule:: sagenet.utils
39 | :members:
40 | :undoc-members:
41 | :show-inheritance:
42 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import inspect
14 | import os
15 | import sys
16 | from datetime import datetime
17 |
18 | sys.path.insert(0, os.path.abspath('..'))
19 |
20 | # -- Readthedocs theme -------------------------------------------------------
21 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
22 |
23 | if not on_rtd: # only import and set the theme if we're building docs locally
24 | import sphinx_rtd_theme
25 |
26 | html_theme = 'sphinx_rtd_theme'
27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
28 |
29 | import sagenet
30 |
31 | # -- Retrieve notebooks ------------------------------------------------------
32 |
# Fetch the tutorial notebooks so nbsphinx can render them during the build.
from urllib.request import urlretrieve

# Use raw.githubusercontent.com: the GitHub "tree/" URL serves an HTML page,
# not the notebook file.  Note the trailing slash — the original code
# concatenated ".../notebooks" + "00_hello_sagenet.ipynb" without a separator.
notebooks_url = 'https://raw.githubusercontent.com/MarioniLab/sagenet/main/notebooks/'
notebooks = [
    '00_hello_sagenet.ipynb',
    '01_multiple_references.ipynb',
]

for nb in notebooks:
    try:
        urlretrieve(notebooks_url + nb, nb)
    except OSError:
        # Building docs offline should not fail; the notebook is simply skipped.
        pass
46 |
47 | # -- Project information -----------------------------------------------------
48 |
49 | project = 'sagenet'
50 | author = 'Elyas Heidari'
51 | copyright = f'{datetime.now():%Y}, ' + author
52 |
53 | pygments_style = 'sphinx'
54 | todo_include_todos = True
55 | html_theme_options = dict(navigation_depth=3, titles_only=False)
# Metadata used by the theme's "Edit on GitHub" link.
html_context = dict(
    display_github=True,
    github_user='MarioniLab',
    github_repo='sagenet',
    # Default branch of the repository; keep consistent with the
    # blob/main URLs used by linkcode_resolve and the notebook fetcher.
    github_version='main',
    conf_py_path='/docs/',
)
63 | html_static_path = ['_static']
64 | html_logo = '_static/img/logo.png'
65 |
def setup(app):
    """Sphinx extension hook: register the theme-override stylesheet."""
    app.add_css_file('custom.css')
68 |
69 | # -- General configuration ---------------------------------------------------
70 |
71 | # Add any Sphinx extension module names here, as strings. They can be
72 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
73 | # ones.
74 | extensions = [
75 | 'sphinx.ext.autodoc',
76 | 'sphinx.ext.doctest',
77 | 'nbsphinx',
78 | 'sphinx.ext.napoleon',
79 | 'sphinx.ext.todo',
80 | 'sphinx.ext.mathjax',
81 | 'sphinx.ext.graphviz',
82 | 'sphinx.ext.intersphinx',
83 | 'sphinx.ext.linkcode',
84 | 'sphinx_rtd_theme',
85 | 'numpydoc',
86 | ]
87 |
88 | add_module_names = True
89 | autosummary_generate = True
90 | numpydoc_show_class_members = True
91 |
# Cross-project reference targets for sphinx.ext.intersphinx.
intersphinx_mapping = {
    'python': ('https://docs.python.org/3', None),
    'anndata': ('https://anndata.readthedocs.io/en/latest/', None),
    # numpy's docs (and objects.inv) are hosted on numpy.org, not readthedocs.
    'numpy': ('https://numpy.org/doc/stable/', None),
    'scanpy': ('https://scanpy.readthedocs.io/en/latest/', None),
}
98 |
99 | # Add any paths that contain templates here, relative to this directory.
100 | templates_path = ['_templates']
101 |
102 | # List of patterns, relative to source directory, that match files and
103 | # directories to ignore when looking for source files.
104 | # This pattern also affects html_static_path and html_extra_path.
105 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
106 |
107 |
def linkcode_resolve(domain, info):
    """Return the GitHub URL for the Python object described by ``info``.

    Hook for ``sphinx.ext.linkcode``.  Returns ``None`` whenever the object
    cannot be resolved to a source file, which tells Sphinx to omit the
    source link.  The previous version used bare ``except:`` clauses, which
    also swallow ``SystemExit``/``KeyboardInterrupt``; they are narrowed to
    the exceptions the inspected calls actually raise.
    """
    if domain != 'py':
        return None

    modname = info['module']
    fullname = info['fullname']

    submod = sys.modules.get(modname)
    if submod is None:
        return None

    # Walk the dotted attribute path (e.g. "Class.method") from the module.
    obj = submod
    for part in fullname.split('.'):
        try:
            obj = getattr(obj, part)
        except AttributeError:
            return None

    try:
        fn = inspect.getsourcefile(obj)
    except TypeError:
        # Built-in / C-level objects have no Python source file.
        fn = None
    if not fn:
        return None

    try:
        source, lineno = inspect.findsource(obj)
    except (OSError, TypeError):
        lineno = None

    # findsource() is 0-based; GitHub anchors are 1-based.  (An object at the
    # very first line yields lineno == 0, which — as before — gets no anchor.)
    if lineno:
        linespec = "#L%d" % (lineno + 1)
    else:
        linespec = ""

    fn = os.path.relpath(fn, start=os.path.dirname(sagenet.__file__))

    return f"https://github.com/MarioniLab/sagenet/blob/main/sagenet/{fn}{linespec}"
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. sagenet documentation master file, created by
2 | sphinx-quickstart on Tue Dec 7 05:28:34 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 | SageNet: Spatial reconstruction of single-cell dissociated datasets using graph neural networks
6 | =========================================================================
7 | .. raw:: html
8 |
 9 | **SageNet** is a robust and generalizable graph neural network approach that probabilistically maps dissociated single cells from an scRNAseq dataset to their hypothetical tissue of origin using one or more reference datasets acquired by spatially resolved transcriptomics techniques. It is compatible with both high-plex imaging (e.g., seqFISH, MERFISH, etc.) and spatial barcoding (e.g., 10X visium, Slide-seq, etc.) datasets as the spatial reference.
10 |
11 |
12 | .. raw:: html
13 |
14 |
15 |
16 |
18 |
19 |
20 |
21 | SageNet is implemented with `pytorch `_ and `pytorch-geometric `_ to be modular, fast, and scalable. Also, it uses ``anndata`` to be compatible with `scanpy `_ and `squidpy `_ for pre- and post-processing steps.
22 |
23 | Installation
24 | -------------------------------
25 | You can get the latest development version of our toolkit from `Github `_ using the following steps:
26 |
27 | First, clone the repository using ``git``::
28 |
29 | git clone https://github.com/MarioniLab/sagenet
30 |
31 | Then, ``cd`` to the sagenet folder and run the install command::
32 |
33 | cd sagenet
 34 | 	python setup.py install #or pip install .
35 |
36 |
 37 | The dependency ``torch-geometric`` should be installed separately, corresponding to the system specificities; look at `this link `_ for instructions.
38 |
39 |
40 | .. raw:: html
41 |
42 |
43 |
44 |
46 |
47 |
48 |
49 |
50 | Notebooks
51 | -------------------------------
 52 | To see some examples of our pipeline's capability, look at the `notebooks `_ directory. The notebooks are also available on Google Colab:
53 |
54 | #. `Intro to SageNet `_
55 | #. `Using multiple references `_
56 |
57 | Interactive examples
58 | -------------------------------
59 | See `this `_
60 |
61 |
62 | Support and contribute
63 | -------------------------------
64 | If you have a question or new architecture or a model that could be integrated into our pipeline, you can
65 | post an `issue `__ or reach us by `email `_.
66 |
67 |
68 | Contributions
69 | -------------------------------
70 | This work is led by Elyas Heidari and Shila Ghazanfar as a joint effort between `MarioniLab@CRUK@EMBL-EBI `__ and `RobinsonLab@UZH `__.
71 |
72 |
73 |
74 |
75 | .. toctree::
76 | :maxdepth: 1
77 | :caption: Main
78 | :hidden:
79 |
80 | about
81 | installation
82 | api.rst
83 |
84 | .. toctree::
85 | :maxdepth: 1
86 | :caption: Examples
87 | :hidden:
88 |
89 | 01_multiple_references
90 | 00_hello_sagenet
91 |
92 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/figures/show_ad_r1_all.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/figures/show_ad_r1_all.pdf
--------------------------------------------------------------------------------
/figures/show_ad_r1_all_conf.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/figures/show_ad_r1_all_conf.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.8.0
2 | captum>=0.4
3 | tensorboard>=2.6
4 | numpydoc>=1.1
5 | anndata
6 | gglasso
7 | squidpy
8 |
9 |
10 |
--------------------------------------------------------------------------------
/sagenet/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarioniLab/sagenet/5dce6dd375cf28678d735f5e5d8083dfd3d86596/sagenet/.DS_Store
--------------------------------------------------------------------------------
/sagenet/DHH_data/_DHH_data.py:
--------------------------------------------------------------------------------
# Developing-human-heart (DHH) datasets, downloaded lazily from figshare
# through squidpy's AMetadata machinery.
from copy import copy
from squidpy.datasets._utils import AMetadata

# Spatial transcriptomics (ST) reference dataset.
ST = AMetadata(
    name="ST",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31796207",
)

# Dissociated scRNA-seq query dataset.
scRNAseq = AMetadata(
    name="scRNAseq",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31796219",
)




# Register a module-level loader function for each AMetadata defined above
# (e.g. sagenet.DHH_data.ST()); copy(locals()) avoids mutation during iteration.
for name, var in copy(locals()).items():
    if isinstance(var, AMetadata):
        var._create_function(name, globals())


__all__ = [  # noqa: F822
    "ST",
    "scRNAseq",

]
31 |
--------------------------------------------------------------------------------
/sagenet/DHH_data/__init__.py:
--------------------------------------------------------------------------------
1 | from sagenet.DHH_data._DHH_data import *
2 |
--------------------------------------------------------------------------------
/sagenet/DHH_data_/_DHH_data_.py:
--------------------------------------------------------------------------------
# NOTE(review): this module is a byte-for-byte duplicate of
# sagenet/DHH_data/_DHH_data.py — presumably a leftover copy; confirm whether
# the DHH_data_ package is still needed.
from copy import copy
from squidpy.datasets._utils import AMetadata

# Spatial transcriptomics (ST) reference dataset.
ST = AMetadata(
    name="ST",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31796207",
)

# Dissociated scRNA-seq query dataset.
scRNAseq = AMetadata(
    name="scRNAseq",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31796219",
)




# Register a module-level loader function for each AMetadata defined above.
for name, var in copy(locals()).items():
    if isinstance(var, AMetadata):
        var._create_function(name, globals())


__all__ = [  # noqa: F822
    "ST",
    "scRNAseq",

]
--------------------------------------------------------------------------------
/sagenet/DHH_data_/__init__.py:
--------------------------------------------------------------------------------
1 | from sagenet.DHH_data._DHH_data_ import *
2 |
--------------------------------------------------------------------------------
/sagenet/MGA_analysis.py:
--------------------------------------------------------------------------------
# Standard library
import random

# Third-party
import anndata as ad
import numpy as np
import scanpy as sc
import squidpy as sq
import torch

# Local
import sagenet as sg
from sagenet.utils import map2ref, glasso

# Fix the RNG seed for reproducibility.  `import random` was missing in the
# original, so this line raised NameError at startup.
random.seed(1996)
10 |
11 |
# Load the three seqFISH reference sections, the three held-out seqFISH
# sections, and the dissociated scRNA-seq query.
adata_r1 = sg.MGA_data.seqFISH1_1()
adata_r2 = sg.MGA_data.seqFISH2_1()
adata_r3 = sg.MGA_data.seqFISH3_1()
adata_q1 = sg.MGA_data.seqFISH1_2()
adata_q2 = sg.MGA_data.seqFISH2_2()
adata_q3 = sg.MGA_data.seqFISH3_2()
adata_q = sg.MGA_data.scRNAseq()


# Map everything to 1-1
# Infer the gene-gene interaction graph used by the GNN classifier.
glasso(adata_r1)
adata_r1.obsm['spatial'] = np.array(adata_r1.obs[['x','y']])
# Build a spatial neighbor graph and partition it at several Leiden
# resolutions; each partitioning becomes one classification target.
sq.gr.spatial_neighbors(adata_r1, coord_type="generic")
sc.tl.leiden(adata_r1, resolution=.01, random_state=0, key_added='leiden_0.01', adjacency=adata_r1.obsp["spatial_connectivities"])
sc.tl.leiden(adata_r1, resolution=.05, random_state=0, key_added='leiden_0.05', adjacency=adata_r1.obsp["spatial_connectivities"])
sc.tl.leiden(adata_r1, resolution=.1, random_state=0, key_added='leiden_0.1', adjacency=adata_r1.obsp["spatial_connectivities"])
sc.tl.leiden(adata_r1, resolution=.5, random_state=0, key_added='leiden_0.5', adjacency=adata_r1.obsp["spatial_connectivities"])
sc.tl.leiden(adata_r1, resolution=1, random_state=0, key_added='leiden_1', adjacency=adata_r1.obsp["spatial_connectivities"])

# Prefer GPU when available.
if torch.cuda.is_available():
    dev = "cuda:0"
else:
    dev = "cpu"


device = torch.device(dev)
print(device)

# Train one GraphSAGE model per Leiden partitioning on the reference section.
sg_obj = sg.sage.sage(device=device)
sg_obj.add_ref(adata_r1, comm_columns=['leiden_0.01', 'leiden_0.05', 'leiden_0.1', 'leiden_0.5', 'leiden_1'], tag='embryo1_2', epochs=20, verbose = True, classifier='GraphSAGE')
42 |
43 |
# Map the reference onto itself first.  Its predictions go under a separate
# 'spatial_pred' key so the true coordinates remain available while the
# other datasets are mapped against them below.
sg_obj.map_query(adata_r1, save_prob=True)
ind, conf = map2ref(adata_r1, adata_r1)
adata_r1.obsm['spatial_pred'] = adata_r1.obsm['spatial'][ind, :]
adata_r1.obs['conf'] = np.log(conf)

# Map every other dataset to adata_r1's coordinate system.  This replaces
# six copy-pasted stanzas that differed only in the query object.
for query in (adata_r2, adata_r3, adata_q1, adata_q2, adata_q3, adata_q):
    sg_obj.map_query(query, save_prob=True)
    ind, conf = map2ref(adata_r1, query)
    query.obsm['spatial'] = adata_r1.obsm['spatial'][ind, :]
    query.obs['conf'] = np.log(conf)

# Now that all queries are mapped, switch the reference to its own
# predicted coordinates for the joint plots below.
adata_r1.obsm['spatial'] = adata_r1.obsm['spatial_pred']
74 |
75 |
# Concatenate all mapped datasets for joint plotting; 'batch' records origin.
ad_concat = ad.concat([adata_r1, adata_r2, adata_r3, adata_q1, adata_q2, adata_q3, adata_q], label='batch')

# NOTE(review): `celltype_colours` is not defined anywhere in this script —
# it must be injected (e.g. from a notebook namespace) before running;
# confirm and import/define it explicitly.
sc.pl.spatial(
    ad_concat,
    color='cell_type',
    palette=celltype_colours,# Color cells based on 'cell_type'
    # color_map=cell_type_color_map, # Use the custom color map
    # library_id='r1_mapping', # Use 'r1_mapping' coordinates
    title='all to r1 map',
    save='_ad_r1_all.pdf',
    spot_size=.1
)

# Same layout, colored by (log) mapping confidence instead of cell type.
sc.pl.spatial(
    ad_concat,
    color='conf',
    # palette=celltype_colours,# Color cells based on 'cell_type'
    # color_map=cell_type_color_map, # Use the custom color map
    # library_id='r1_mapping', # Use 'r1_mapping' coordinates
    title='all to r1 map',
    save='_ad_r1_all_conf.pdf',
    spot_size=.1
)
--------------------------------------------------------------------------------
/sagenet/MGA_data/_MGA_data.py:
--------------------------------------------------------------------------------
1 |
# Mouse gastrulation atlas (MGA) datasets, downloaded lazily from figshare
# through squidpy's AMetadata machinery.
from copy import copy
from squidpy.datasets._utils import AMetadata

# Dissociated scRNA-seq query dataset.
_scRNAseq = AMetadata(
    name="scRNAseq",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31767704",
)

# seqFISH sections, named seqFISH<embryo>_<section>.
_seqFISH1_1 = AMetadata(
    name="seqFISH1_1",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31716029",
)

_seqFISH2_1 = AMetadata(
    name="seqFISH2_1",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31716041",
)


_seqFISH3_1 = AMetadata(
    name="seqFISH3_1",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31716089",
)


_seqFISH1_2 = AMetadata(
    name="seqFISH1_2",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31920353",
)

_seqFISH2_2 = AMetadata(
    name="seqFISH2_2",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31920644",
)


_seqFISH3_2 = AMetadata(
    name="seqFISH3_2",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31920890",
)


# Register a module-level loader function for each AMetadata defined above.
# NOTE(review): locals here are underscore-prefixed (_scRNAseq, ...) while
# __all__ lists the bare names; this presumably relies on squidpy's
# _create_function using the AMetadata `name` field rather than the local
# variable name — confirm against squidpy (cf. DHH_data, where the local
# names carry no underscore).
for name, var in copy(locals()).items():
    if isinstance(var, AMetadata):
        var._create_function(name, globals())


__all__ = [  # noqa: F822
    "scRNAseq",
    "seqFISH1_1",
    "seqFISH2_1",
    "seqFISH3_1",
    "seqFISH1_2",
    "seqFISH2_2",
    "seqFISH3_2",

]
73 |
--------------------------------------------------------------------------------
/sagenet/MGA_data/__init__.py:
--------------------------------------------------------------------------------
1 | from sagenet.MGA_data._MGA_data import *
2 |
--------------------------------------------------------------------------------
/sagenet/__init__.py:
--------------------------------------------------------------------------------
"""sagenet: spatial reconstruction of dissociated single-cell data with GNNs."""
__author__ = __maintainer__ = "Elyas Heidari"
__email__ = 'eheidari@student.ethz.ch'
__version__ = "0.1.1"

# Re-export the public submodules so `import sagenet as sg` exposes them all.
from . import classifier, model, sage, utils, MGA_data, DHH_data
6 |
7 |
--------------------------------------------------------------------------------
/sagenet/classifier.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torch.optim import lr_scheduler
3 | import torch.nn as nn
4 | from torch.utils.data import TensorDataset, DataLoader, random_split
5 | from torch.utils.tensorboard import SummaryWriter
6 | from sagenet.utils import compute_metrics
7 | from sagenet.model import *
8 | from captum import *
9 | from captum.attr import IntegratedGradients
10 | import numpy as np
11 | from sklearn.preprocessing import normalize
12 |
class Classifier():

    """
    A neural network classifier. A number of Graph Neural Networks (GNN),
    a 1-d convolutional network and an MLP are implemented.

    Parameters
    ----------
    n_features : int
        number of input features.
    n_classes : int
        number of classes.
    n_hidden_GNN : list, default=[]
        list of integers indicating sizes of GNN hidden layers.
    n_hidden_FC : list, default=[]
        list of integers indicating sizes of FC hidden layers. If a GNN is used, this indicates FC hidden layers after the GNN layers.
    K : integer, default=4
        Convolution layer filter size. Used only when `classifier == 'Chebnet'` or `classifier == 'Conv1d'`.
    pool_K : integer, default=4
        Max-pooling kernel size. Used only when `classifier == 'Conv1d'`.
    dropout_GNN : float, default=0
        dropout rate for GNN hidden layers.
    dropout_FC : float, default=0
        dropout rate for FC hidden layers.
    classifier : str, default='MLP'
        - 'MLP' --> multilayer perceptron
        - 'GraphSAGE'--> GraphSAGE Network
        - 'Chebnet'--> Chebyshev spectral Graph Convolutional Network
        - 'GATConv'--> Graph Attentional Neural Network
        - 'GENConv'--> GENeralized Graph Convolution Network
        - 'GINConv'--> Graph Isoform Network
        - 'GraphConv'--> Graph Convolutional Neural Network
        - 'MFConv'--> Convolutional Networks on Graphs for Learning Molecular Fingerprints
        - 'TransformerConv'--> Graph Transformer Neural Network
        - 'Conv1d'--> 1-d Convolutional Neural Network
    lr : float, default=0.001
        base learning rate for the SGD optimization algorithm.
    momentum : float, default=0.9
        base momentum for the SGD optimization algorithm.
    log_dir : str, default=None
        path to the log directory. Specifically, used for tensorboard logs.
    device : str, default='cpu'
        the processing unit.

    Raises
    ------
    ValueError
        if `classifier` is not one of the supported architectures.

    See also
    --------
    Classifier.fit : fits the classifier to data
    Classifier.eval : evaluates the classifier predictions
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[],
                 n_hidden_FC=[],
                 K=4,
                 pool_K=4,
                 dropout_GNN=0,
                 dropout_FC=0,
                 classifier='MLP',
                 lr=.001,
                 momentum=.9,
                 log_dir=None,
                 device='cpu'):
        if classifier == 'MLP':
            self.net = NN(n_features=n_features, n_classes=n_classes,
                          n_hidden_FC=n_hidden_FC, dropout_FC=dropout_FC)
        elif classifier == 'GraphSAGE':
            self.net = GraphSAGE(n_features=n_features, n_classes=n_classes,
                                 n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                                 dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == 'Chebnet':
            self.net = ChebNet(n_features=n_features, n_classes=n_classes,
                               n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                               dropout_FC=dropout_FC, dropout_GNN=dropout_GNN, K=K)
        elif classifier == 'GATConv':
            self.net = GATConvNet(n_features=n_features, n_classes=n_classes,
                                  n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                                  dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == 'GENConv':
            self.net = GENConvNet(n_features=n_features, n_classes=n_classes,
                                  n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                                  dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == "GINConv":
            self.net = GINConv(n_features=n_features, n_classes=n_classes,
                               n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                               dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == "GraphConv":
            self.net = GraphConv(n_features=n_features, n_classes=n_classes,
                                 n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                                 dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == "MFConv":
            self.net = MFConv(n_features=n_features, n_classes=n_classes,
                              n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                              dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == "TransformerConv":
            self.net = TransformerConv(n_features=n_features, n_classes=n_classes,
                                       n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                                       dropout_FC=dropout_FC, dropout_GNN=dropout_GNN)
        elif classifier == "Conv1d":
            self.net = ConvNet(n_features=n_features, n_classes=n_classes,
                               n_hidden_GNN=n_hidden_GNN, n_hidden_FC=n_hidden_FC,
                               dropout_FC=dropout_FC, dropout_GNN=dropout_GNN,
                               filter_K=K, pool_K=pool_K)
        else:
            # The original if-chain fell through silently on an unknown name
            # and crashed later with an AttributeError on `self.net`; fail
            # fast with a clear message instead.
            raise ValueError("Unknown classifier: %r" % (classifier,))
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.net.parameters(), lr=lr, momentum=momentum)
        self.logging = log_dir is not None
        self.device = device
        self.lr = lr
        if self.logging:
            self.writer = SummaryWriter(log_dir=log_dir, flush_secs=1)

    def fit(self, data_loader, epochs, test_dataloader=None, verbose=False):
        """
        fits the classifier to the input data.

        Parameters
        ----------
        data_loader : torch-geometric dataloader
            the training dataset.
        epochs : int
            number of epochs.
        test_dataloader : torch-geometric dataloader, default=None
            the test dataset on which the model is evaluated in each epoch.
        verbose : boolean, default=False
            whether to print out loss during training.
        """
        if self.logging:
            data = next(iter(data_loader))
            self.writer.add_graph(self.net, [data.x, data.edge_index])
        # Cosine-annealing schedule with warm restarts, stepped once per batch.
        self.scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=50, T_mult=1, eta_min=0.00005, last_epoch=-1)
        # Print roughly ten progress lines over the run; `max(1, ...)` avoids
        # the ZeroDivisionError the original `epochs//10` raised for epochs < 10.
        log_every = max(1, epochs // 10)
        for epoch in range(epochs):
            self.net.train()
            self.net.to(self.device)
            total_loss = 0

            for batch in data_loader:
                x = batch.x.to(self.device)
                edge_index = batch.edge_index.to(self.device)
                label = batch.y.to(self.device)
                self.optimizer.zero_grad()
                pred = self.net(x, edge_index)
                loss = self.criterion(pred, label)
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                total_loss += loss.item() * batch.num_graphs
            total_loss /= len(data_loader.dataset)
            if verbose and epoch % log_every == 0:
                print('[%d] loss: %.3f' % (epoch + 1, total_loss))

            if self.logging:
                # Save the training loss, the training accuracy and the test
                # accuracy for tensorboard visualization.
                self.writer.add_scalar("Training Loss", total_loss, epoch)
                accuracy_train = self.eval(data_loader, verbose=False)[0]
                self.writer.add_scalar("Accuracy on Training Dataset", accuracy_train, epoch)
                if test_dataloader is not None:
                    accuracy_test = self.eval(test_dataloader, verbose=False)[0]
                    self.writer.add_scalar("Accuracy on Test Dataset", accuracy_test, epoch)

    def eval(self, data_loader, verbose=False):
        """
        evaluates the model based on predictions.

        Parameters
        ----------
        data_loader : torch-geometric dataloader
            the dataset on which the model is evaluated.
        verbose : boolean, default=False
            whether to print out the computed metrics.

        Returns
        -------
        accuracy : float
            accuracy
        conf_mat : ndarray
            confusion matrix
        precision : float
            weighted precision score
        recall : float
            weighted recall score
        f1_score : float
            weighted f1 score
        """
        self.net.eval()
        y_true = []
        y_pred = []
        with torch.no_grad():
            for batch in data_loader:
                x, edge_index, label = batch.x.to(self.device), batch.edge_index.to(self.device), batch.y.to('cpu')
                y_true.extend(list(label))
                outputs = self.net(x, edge_index)
                _, predicted = torch.max(outputs.data, 1)
                predicted = predicted.to('cpu')
                y_pred.extend(list(predicted))
        accuracy, conf_mat, precision, recall, f1_score = compute_metrics(y_true, y_pred)
        if verbose:
            print('Accuracy: {:.3f}'.format(accuracy))
            print('Confusion Matrix:\n', conf_mat)
            print('Precision: {:.3f}'.format(precision))
            print('Recall: {:.3f}'.format(recall))
            print('f1_score: {:.3f}'.format(f1_score))
        return accuracy, conf_mat, precision, recall, f1_score

    def interpret(self, data_loader, n_features, n_classes):
        """
        interprets a trained model by computing importance scores assigned to
        each feature regarding each class. It uses the `IntegratedGradients`
        method from the package `captum` to compute class-wise feature
        attributions and accumulates them into a feature-by-class matrix.

        Parameters
        ----------
        data_loader : torch-geometric dataloader
            the dataset on which the model is interpreted.
        n_features : int
            number of features.
        n_classes : int
            number of classes.

        Returns
        -------
        importances : numpy ndarray, shape (n_features, n_classes)
            accumulated, L2-normalized attribution scores.
        """
        # All graphs in the loader share the same connectivity, so a single
        # edge_index (from the first batch) is reused for every attribution.
        batch = next(iter(data_loader))
        e = batch.edge_index.to(self.device).long()

        def model_forward(input):
            out = self.net(input, e)
            return out

        self.net.eval()
        importances = np.zeros((n_features, n_classes))
        for batch in data_loader:
            input = batch.x.to(self.device)
            target = batch.y.to(self.device)
            ig = IntegratedGradients(model_forward)
            attributions = ig.attribute(input, target=target)
            attributions = attributions.to('cpu').detach().numpy()
            # NOTE(review): assumes batch.x flattens to exactly
            # (n_features, len(target)) — confirm against the dataloader layout.
            attributions = attributions.reshape(n_features, len(target))
            attributions = normalize(attributions, axis=0, norm='l2')
            importances[:, target.to('cpu').numpy()] += attributions
        return importances
263 |
--------------------------------------------------------------------------------
/sagenet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from sagenet.datasets._datasets import *
--------------------------------------------------------------------------------
/sagenet/datasets/_datasets.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 | from squidpy.datasets._utils import AMetadata
3 |
# Lazily-downloaded datasets hosted on figshare: the mouse gastrulation atlas
# (MGA_*) and developing human heart (DHH_*) references. Each `AMetadata`
# entry wraps a download URL; the loop below turns every entry into a public
# module-level loader function.
_MGA_scRNAseq = AMetadata(
    name="MGA_scRNAseq",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31767704",
)

_MGA_seqFISH1 = AMetadata(
    name="MGA_seqFISH1",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31716029",
)

_MGA_seqFISH2 = AMetadata(
    name="MGA_seqFISH2",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31716041",
)


_MGA_seqFISH3 = AMetadata(
    name="MGA_seqFISH3",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31716089",
)

_DHH_visium_ = AMetadata(
    name="DHH_visium_",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31796207",
)

_DHH_scRNAseq = AMetadata(
    name="DHH_scRNAseq",
    doc_header="",
    # shape=(270876, 43),
    url="https://figshare.com/ndownloader/files/31796219",
)


# Register a public loader function (name without the leading underscore)
# in this module's namespace for every AMetadata entry defined above.
for name, var in copy(locals()).items():
    if isinstance(var, AMetadata):
        var._create_function(name, globals())


__all__ = [  # noqa: F822
    "MGA_scRNAseq",
    "MGA_seqFISH1",
    "MGA_seqFISH2",
    "MGA_seqFISH3",
    'DHH_visium_',
    'DHH_scRNAseq'
]
64 |
--------------------------------------------------------------------------------
/sagenet/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch_geometric.nn as pyg_nn
5 |
6 |
7 |
class NN(nn.Module):
    """
    Base feed-forward classifier (a plain MLP when used directly).

    Holds two layer stacks: `layers_GNN` (left empty here; populated by the
    graph/convolutional subclasses) and `layers_FC` (fully-connected hidden
    layers), followed by a final linear layer mapping to `n_classes`.

    Parameters
    ----------
    n_features : int
        number of input features per observation.
    n_classes : int
        number of output classes.
    n_hidden_GNN : list of int, default=[]
        sizes of GNN hidden layers (filled in by subclasses).
    n_hidden_FC : list of int, default=[10]
        sizes of fully-connected hidden layers.
    dropout_GNN : float, default=0
        dropout rate applied after each GNN/conv layer.
    dropout_FC : float, default=0
        dropout rate applied after each FC hidden layer.
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=None,
                 n_hidden_FC=None,
                 dropout_GNN=0,
                 dropout_FC=0):
        super(NN, self).__init__()
        # None-sentinels replace the previous mutable defaults ([], [10])
        # with identical behavior.
        n_hidden_GNN = [] if n_hidden_GNN is None else n_hidden_GNN
        n_hidden_FC = [10] if n_hidden_FC is None else n_hidden_FC
        self.FC = True
        self.n_features = n_features
        self.n_classes = n_classes
        self.layers_GNN = nn.ModuleList()
        self.layers_FC = nn.ModuleList()
        self.n_layers_GNN = len(n_hidden_GNN)
        self.n_layers_FC = len(n_hidden_FC)
        self.dropout_GNN = dropout_GNN
        self.dropout_FC = dropout_FC
        self.n_hidden_GNN = n_hidden_GNN
        self.n_hidden_FC = n_hidden_FC
        self.conv = False  # set to True by the ConvNet subclass (Conv1d path)
        if self.n_layers_GNN > 0:
            self.FC = False

        # Fully connected layers. They occur after the graph convolutions
        # (or at the start if there are no graph convolutions).
        if self.n_layers_FC > 0:
            if self.n_layers_GNN == 0:
                self.layers_FC.append(nn.Linear(n_features, n_hidden_FC[0]))
            else:
                # GNN output is flattened to n_features * channels before the FC head.
                self.layers_FC.append(nn.Linear(n_features * n_hidden_GNN[-1], n_hidden_FC[0]))
            if self.n_layers_FC > 1:
                for i in range(self.n_layers_FC - 1):
                    self.layers_FC.append(nn.Linear(n_hidden_FC[i], n_hidden_FC[i + 1]))

        # Last (output) layer.
        if self.n_layers_FC > 0:
            self.last_layer_FC = nn.Linear(n_hidden_FC[-1], n_classes)
        elif self.n_layers_GNN > 0:
            self.last_layer_FC = nn.Linear(n_features * n_hidden_GNN[-1], n_classes)
        else:
            self.last_layer_FC = nn.Linear(n_features, n_classes)

    def forward(self, x, edge_index):
        """Forward pass; `edge_index` is ignored on the pure-MLP path."""
        if self.FC:
            # Resize from (1, batch_size * n_features) to (batch_size, n_features).
            x = x.view(-1, self.n_features)
        if self.conv:
            x = x.view(-1, 1, self.n_features)
            for layer in self.layers_GNN:
                x = F.relu(layer(x))
                # `pool_K` is defined by the ConvNet subclass, the only one
                # that sets self.conv = True.
                x = F.max_pool1d(x, kernel_size=self.pool_K, stride=1, padding=self.pool_K // 2, dilation=1)
                x = F.dropout(x, p=self.dropout_GNN, training=self.training)
        else:
            for layer in self.layers_GNN:
                x = F.relu(layer(x, edge_index))
                x = F.dropout(x, p=self.dropout_GNN, training=self.training)
        if self.n_layers_GNN > 0:
            x = x.view(-1, self.n_features * self.n_hidden_GNN[-1])
        for layer in self.layers_FC:
            x = F.relu(layer(x))
            x = F.dropout(x, p=self.dropout_FC, training=self.training)
        x = self.last_layer_FC(x)
        return x
72 |
73 |
class GraphSAGE(NN):
    """GraphSAGE classifier: stacks `pyg_nn.SAGEConv` layers (one scalar
    feature per node) in front of the FC head inherited from `NN`."""

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Pass the dropout rates by keyword: NN.__init__ declares
        # `dropout_GNN` before `dropout_FC`, so the original positional call
        # (`..., dropout_FC, dropout_GNN`) silently swapped the two rates.
        super(GraphSAGE, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.SAGEConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.SAGEConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
90 |
91 |
class ChebNet(NN):
    """Chebyshev spectral graph convolutional classifier.

    Parameters are as in `NN`, plus `K` — the Chebyshev filter order passed
    to every `pyg_nn.ChebConv` layer.
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 K=4,
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(ChebNet, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.ChebConv(1, n_hidden_GNN[0], K))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.ChebConv(n_hidden_GNN[i], n_hidden_GNN[i + 1], K))
109 |
110 |
class NNConvNet(NN):
    """Edge-conditioned convolution (`pyg_nn.NNConv`) classifier.

    NOTE(review): `pyg_nn.NNConv` requires an `nn` module argument mapping
    edge features to weights; these constructor calls omit it and would fail
    at instantiation — confirm whether this class is actually used.
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(NNConvNet, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.NNConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.NNConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
127 |
class GATConvNet(NN):
    """Graph attentional network (GAT) classifier."""

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(GATConvNet, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.GATConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.GATConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
144 |
class GENConvNet(NN):
    """GENeralized graph convolution (GENConv) classifier."""

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(GENConvNet, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.GENConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.GENConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
161 |
162 |
class GINConv(NN):
    """Graph isomorphism network (GIN) classifier.

    Each GIN layer wraps a two-layer MLP over the node features.
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(GINConv, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        # NOTE(review): eps=0.2 is set only on the first layer; deeper layers
        # use the pyg default — presumably intentional, worth confirming.
        self.layers_GNN.append(pyg_nn.GINConv(nn.Sequential(
            nn.Linear(1, n_hidden_GNN[0]),
            nn.ReLU(),
            nn.Linear(n_hidden_GNN[0], n_hidden_GNN[0])), eps=0.2))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.GINConv(nn.Sequential(
                    nn.Linear(n_hidden_GNN[i], n_hidden_GNN[i + 1]),
                    nn.ReLU(),
                    nn.Linear(n_hidden_GNN[i + 1], n_hidden_GNN[i + 1]))))
181 |
class GraphConv(NN):
    """Graph convolutional network (`pyg_nn.GraphConv`) classifier."""

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(GraphConv, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.GraphConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.GraphConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
198 |
class MFConv(NN):
    """Molecular-fingerprint convolution (`pyg_nn.MFConv`) classifier."""

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(MFConv, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.MFConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.MFConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
215 |
class TransformerConv(NN):
    """Graph transformer (`pyg_nn.TransformerConv`) classifier."""

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(TransformerConv, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)

        self.layers_GNN.append(pyg_nn.TransformerConv(1, n_hidden_GNN[0]))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(pyg_nn.TransformerConv(n_hidden_GNN[i], n_hidden_GNN[i + 1]))
232 |
class ConvNet(NN):
    """1-d convolutional classifier over the feature axis.

    Sets `self.conv = True` so `NN.forward` takes the Conv1d + max-pool path;
    `filter_K` is the convolution kernel size and `pool_K` the pooling kernel
    size used by the inherited forward pass.
    """

    def __init__(self,
                 n_features,
                 n_classes,
                 n_hidden_GNN=[10],
                 n_hidden_FC=[],
                 filter_K=4,
                 pool_K=0,
                 dropout_GNN=0,
                 dropout_FC=0):
        # Keyword dropouts: the original positional call swapped
        # `dropout_FC` and `dropout_GNN` (NN declares dropout_GNN first).
        super(ConvNet, self).__init__(
            n_features, n_classes, n_hidden_GNN, n_hidden_FC,
            dropout_GNN=dropout_GNN, dropout_FC=dropout_FC)
        self.conv = True
        self.filter_K = filter_K
        self.pool_K = pool_K  # read by NN.forward's max_pool1d

        self.layers_GNN.append(nn.Conv1d(in_channels=1, out_channels=n_hidden_GNN[0],
                                         kernel_size=filter_K, padding=filter_K // 2,
                                         dilation=1, stride=1))
        if self.n_layers_GNN > 1:
            for i in range(self.n_layers_GNN - 1):
                self.layers_GNN.append(nn.Conv1d(in_channels=n_hidden_GNN[i],
                                                 out_channels=n_hidden_GNN[i + 1],
                                                 kernel_size=filter_K, padding=filter_K // 2,
                                                 dilation=1, stride=1))
254 |
--------------------------------------------------------------------------------
/sagenet/sage.py:
--------------------------------------------------------------------------------
1 | from sagenet.utils import *
2 | from sagenet.classifier import *
3 | from sagenet.model import *
4 | from os import listdir
5 | import numpy as np
6 | import anndata
7 | import re
8 |
class sage():
    """
    A `sagenet` object.

    Holds one trained classifier network (plus its gene-gene adjacency
    matrix) per reference partitioning, and uses them to map query cells
    to space.

    Parameters
    ----------
    device : str, default = 'cpu'
        the processing unit to be used in the classifiers (gpu or cpu).
    """

    def __init__(self, device='cpu'):
        self.models = {}  # tag -> trained torch network
        self.adjs = {}    # tag -> dense gene-gene adjacency matrix
        # NOTE(review): this local is never stored on the object; presumably
        # `self.inf_genes = None` was intended — confirm before relying on it.
        inf_genes = None
        self.num_refs = 0  # number of reference datasets added so far
        self.device = device

    def add_ref(self,
                adata,
                tag=None,
                comm_columns='class_',
                classifier='TransformerConv',
                num_workers=0,
                batch_size=32,
                epochs=10,
                n_genes=10,
                verbose=False):
        """Trains new classifiers on a reference dataset.

        Parameters
        ----------
        adata : `AnnData`
            The annotated data matrix of shape `n_obs × n_vars` to be used as the spatial reference. Rows correspond to cells (or spots) and columns to genes. Must carry a gene-gene adjacency matrix in `adata.varm['adj']`.
        tag : str, default = `None`
            The tag to be used for storing the trained models and the outputs in the `sagenet` object. Defaults to `'ref<k>'` where `k` counts the added references.
        classifier : str, default = `'TransformerConv'`
            The type of classifier to be passed to `sagenet.Classifier()`
        comm_columns : list of str, `'class_'`
            The columns in `adata.obs` to be used as spatial partitions.
        num_workers : int
            Non-negative. Number of workers to be passed to `torch_geometric.data.DataLoader`.
        batch_size : int, default = 32
            Training batch size.
        epochs : int
            number of epochs.
        n_genes : int, default = 10
            NOTE(review): currently unused in this method — confirm intent.
        verbose : boolean, default=False
            whether to print out loss during training.

        Return
        ------
        Returns nothing.

        Notes
        -----
        Trains the models and adds them to the `.models` dictionary of the `sagenet` object.
        Also adds a new key `{tag}_importance` to `.var` of `adata`, containing a per-gene importance score (lower rank = more important).
        """
        # Genes with no edges in the adjacency graph are assigned a sentinel
        # score below so they never rank as important.
        ind = np.where(np.sum(adata.varm['adj'], axis=1) == 0)[0]
        ents = np.ones(adata.var.shape[0]) * 1000000
        self.num_refs += 1

        if tag is None:
            tag = 'ref' + str(self.num_refs)

        for comm in comm_columns:
            data_loader = get_dataloader(
                graph=adata.varm['adj'].toarray(),
                X=adata.X, y=adata.obs[comm].values.astype('long'),
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers
            )

            clf = Classifier(
                n_features=adata.shape[1],
                # Labels are assumed to be 0-based integer codes, so the
                # class count is max + 1.
                n_classes=(np.max(adata.obs[comm].values.astype('long')) + 1),
                n_hidden_GNN=[8],
                dropout_FC=0.2,
                dropout_GNN=0.3,
                classifier=classifier,
                lr=0.001,
                momentum=0.9,
                device=self.device
            )

            clf.fit(data_loader, epochs=epochs, test_dataloader=None, verbose=verbose)
            # Rank genes per class by attribution magnitude; a gene's score
            # is its best (lowest) rank across classes and partitionings.
            imp = clf.interpret(data_loader, n_features=adata.shape[1], n_classes=(np.max(adata.obs[comm].values.astype('long')) + 1))
            idx = (-abs(imp)).argsort(axis=0)
            imp = np.min(idx, axis=1)
            # Disconnected genes get the sentinel rank.
            np.put(imp, ind, 1000000)
            ents = np.minimum(ents, imp)

            self.models['_'.join([tag, comm])] = clf.net
            self.adjs['_'.join([tag, comm])] = adata.varm['adj'].toarray()
        save_adata(adata, attr='var', key='_'.join([tag, 'importance']), data=ents)

    def map_query(self, adata_q, save_pred=True, save_ent=True, save_prob=False, save_dist=False):
        """Maps a query dataset to space using the trained models on the spatial reference(s).

        Parameters
        ----------
        adata_q : `AnnData`
            The annotated data matrix of shape `n_obs × n_vars` to be used as the query. Rows correspond to cells (or spots) and columns to genes.
        save_pred : bool, default=True
            store the predicted partition per model in `.obs`.
        save_ent : bool, default=True
            store the normalized prediction entropy (uncertainty) per model in `.obs`.
        save_prob : bool, default=False
            store the full class-probability matrix per model in `.obsm`.
        save_dist : bool, default=False
            store a reconstructed cell-to-cell distance matrix in `.obsm`.
            NOTE(review): this branch reads `.obs['ent_<tag>']`, so it
            requires `save_ent=True` as well — confirm.

        Return
        ------
        Returns nothing.

        Notes
        -----
        * Adds new key(s) `pred_{tag}_{partitioning_name}` to `.obs` from `adata_q` which contains the predicted partition for partitioning `{partitioning_name}`, trained by model `{tag}`.
        * Adds new key(s) `ent_{tag}_{partitioning_name}` to `.obs` from `adata_q` which contains the uncertainty in prediction for partitioning `{partitioning_name}`, trained by model `{tag}`.
        * Adds a new key `dist_map` to `.obsm` from `adata_q` containing the reconstructed cell-to-cell spatial distance matrix.
        """
        for tag in self.models.keys():
            self.models[tag].eval()
            i = 0
            # Dummy labels: the dataloader requires a `y`, but labels are
            # unused at inference time.
            adata_q.obs['class_'] = 0
            data_loader = get_dataloader(
                graph=self.adjs[tag],
                X=adata_q.X, y=adata_q.obs['class_'].values.astype('long'),  # TODO: fix this
                batch_size=1,
                shuffle=False,
                num_workers=0
            )
            with torch.no_grad():
                for batch in data_loader:
                    x, edge_index = batch.x.to(self.device), batch.edge_index.to(self.device)
                    outputs = self.models[tag](x, edge_index)
                    predicted = outputs.data.to('cpu').detach().numpy()
                    i += 1
                    if i == 1:
                        # Allocate the output matrix on the first batch, once
                        # the number of classes is known.
                        n_classes = predicted.shape[1]
                        y_pred = np.empty((0, n_classes))
                    y_pred = np.concatenate((y_pred, predicted), axis=0)
            if save_pred or save_prob or save_ent or save_dist:
                # Softmax over the raw outputs (treated as unnormalized
                # log-probabilities).
                y_pred = np.exp(y_pred)
                y_pred = (y_pred.T / y_pred.T.sum(0)).T
            if save_prob:
                save_adata(adata_q, attr='obsm', key='_'.join(['prob', tag]), data=y_pred)
            if save_pred:
                save_adata(adata_q, attr='obs', key='_'.join(['pred', tag]), data=np.argmax(y_pred, axis=1))
            if save_ent:
                # Shannon entropy of the class probabilities, normalized to [0, 1].
                temp = (-y_pred * np.log2(y_pred)).sum(axis=1)
                save_adata(adata_q, attr='obs', key='_'.join(['ent', tag]), data=(np.array(temp) / np.log2(n_classes)))
            if save_dist:
                dist_mat = np.zeros((adata_q.shape[0], adata_q.shape[0]))
                # Blend a sampled one-hot assignment with the soft
                # probabilities, weighted by per-cell entropy.
                y_pred_1 = (multinomial_rvs(1, y_pred).T * np.array(adata_q.obs['_'.join(['ent', tag])])).T
                y_pred_2 = (y_pred.T * (1 - np.array(adata_q.obs['_'.join(['ent', tag])]))).T
                y_pred_final = y_pred_1 + y_pred_2
                # Symmetrized, Frobenius-normalized KL divergence as distance.
                kl_d = kullback_leibler_divergence(y_pred_final)
                kl_d = kl_d + kl_d.T
                kl_d /= np.linalg.norm(kl_d, 'fro')
                dist_mat += kl_d
                save_adata(adata_q, attr='obsm', key='dist_map', data=dist_mat)

    def save_model(self, tag, dir='.'):
        """Saves a single trained model.

        Parameters
        ----------
        tag : str
            Name of the trained model to be saved.
        dir : str, default=`'.'`
            The saving directory.
        """
        path = os.path.join(dir, tag) + '.pickle'
        torch.save(self.models[tag], path)

    def load_model(self, tag, dir='.'):
        """Loads a single pre-trained model.

        Parameters
        ----------
        tag : str
            Name under which the trained model is stored in the `sagenet` object.
        dir : str, default=`'.'`
            The input directory.
        """
        path = os.path.join(dir, tag) + '.pickle'
        self.models[tag] = torch.load(path)

    def save_model_as_folder(self, dir='.'):
        """Saves all trained models stored in the `sagenet` object as a folder.

        Each model is written as `<tag>.pickle` and its adjacency matrix as
        `<tag>.h5ad`.

        Parameters
        ----------
        dir : str, default=`'.'`
            The saving directory.
        """
        for tag in self.models.keys():
            self.save_model(tag, dir)
            adj_path = os.path.join(dir, tag) + '.h5ad'
            adj_adata = anndata.AnnData(X=self.adjs[tag])
            adj_adata.write(filename=adj_path)

    def load_model_as_folder(self, dir='.'):
        """Loads pre-trained models (and their adjacency matrices) from a directory.

        Parameters
        ----------
        dir : str, default=`'.'`
            The input directory.
        """
        model_files = [f for f in listdir(dir) if re.search(r".pickle$", f)]
        for m in model_files:
            tag = re.sub(r'.pickle', '', m)
            model_path = os.path.join(dir, tag) + '.pickle'
            adj_path = os.path.join(dir, tag) + '.h5ad'
            self.models[tag] = torch.load(model_path)
            self.adjs[tag] = anndata.read_h5ad(adj_path).X
225 |
226 |
227 |
--------------------------------------------------------------------------------
/sagenet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import numpy as np
4 | from sklearn.metrics import *
5 | from gglasso.helper.data_generation import generate_precision_matrix, group_power_network, sample_covariance_matrix
6 | from gglasso.problem import glasso_problem
7 | from gglasso.helper.basic_linalg import adjacency_matrix
8 | import torch
9 | import torch_geometric.data as geo_dt
10 | from sklearn.utils.extmath import fast_logdet
11 | from scipy import sparse
12 | import warnings
13 | from torch.nn import Softmax
14 | from sklearn.preprocessing import StandardScaler
15 | from sklearn.covariance import empirical_covariance
16 | from sklearn.metrics import *
17 | from scipy import sparse
18 | from functools import reduce
19 | import warnings
20 | from torch.nn import Softmax
21 | from scipy.spatial import cKDTree
22 |
def glasso(adata, lambda_low=-10, lambda_high=-1, mode='cd'):
    """
    Reconstructs the gene-gene interaction network based on gene expressions
    in `.X` using a Gaussian graphical model estimated by graphical lasso.

    Parameters
    ----------
    adata : `AnnData`
        The annotated data matrix of shape `n_obs × n_vars`. Rows correspond to cells and columns to genes.
    lambda_low : float, default=-10
        Base-10 exponent of the smallest regularization strength in the model-selection grid.
    lambda_high : float, default=-1
        Base-10 exponent of the largest regularization strength in the model-selection grid.
    mode : str, default='cd'
        Currently unused; kept for backward compatibility.

    Returns
    -------
    Returns nothing; adds a `csr_matrix` under key `adj` to `.varm`.

    References
    ----------
    Friedman, J., Hastie, T., & Tibshirani, R. (2008).
    Sparse inverse covariance estimation with the graphical lasso.
    Biostatistics, 9(3), 432-441.
    """
    N = adata.shape[1]
    # Standardize genes before computing the empirical covariance.
    scaler = StandardScaler()
    data = scaler.fit_transform(adata.X)
    S = empirical_covariance(data)
    P = glasso_problem(S, N, latent=False, do_scaling=True)
    # Model selection over a log-spaced grid of regularization strengths.
    # The original hard-coded np.logspace(-10, -1, 10) and silently ignored
    # `lambda_low`/`lambda_high`; honoring them is backward-compatible since
    # the defaults reproduce the old grid exactly.
    lambda1_range = np.logspace(lambda_low, lambda_high, 10)
    modelselect_params = {'lambda1_range': lambda1_range}
    P.model_selection(modelselect_params=modelselect_params, method='eBIC', gamma=0.1, tol=1e-7)
    # Threshold tiny precision-matrix entries to define adjacency.
    P.solution.calc_adjacency(t=1e-4)
    save_adata(adata, attr='varm', key='adj', data=sparse.csr_matrix(P.solution.precision_))
58 |
59 |
60 |
def compute_metrics(y_true, y_pred):
    """
    Computes prediction quality metrics.

    Parameters
    ----------
    y_true : 1d array-like, or label indicator array / sparse matrix
        Ground truth (correct) labels.

    y_pred : 1d array-like, or label indicator array / sparse matrix
        Predicted labels, as returned by a classifier.

    Returns
    --------
    accuracy : float
        Fraction of correctly predicted labels.
    conf_mat : ndarray
        Confusion matrix.
    precision : float
        Weighted precision score.
    recall : float
        Weighted recall score.
    f1 : float
        Weighted f1 score.
    """
    # All class-imbalance-sensitive scores use weighted averaging.
    return (
        accuracy_score(y_true, y_pred),
        confusion_matrix(y_true, y_pred),
        precision_score(y_true, y_pred, average='weighted'),
        recall_score(y_true, y_pred, average='weighted'),
        f1_score(y_true, y_pred, average='weighted'),
    )
87 |
88 |
89 |
90 |
def get_dataloader(graph, X, y, batch_size=1, undirected=True, shuffle=True, num_workers=0):
    """
    Converts a graph and a dataset to a dataloader.

    Parameters
    ----------
    graph : numpy ndarray (adjacency matrix)
        The underlying graph to be fed to the graph neural networks;
        entries equal to 1 denote edges.

    X : numpy ndarray
        Input dataset with columns as features and rows as observations.

    y : numpy ndarray
        Class labels.

    batch_size : int, default=1
        The batch size.

    undirected : boolean
        If True, each edge (u, v) is mirrored as (v, u) so the resulting
        edge list is symmetric.

    shuffle : boolean, default = `True`
        Whether to shuffle the dataset to be passed to `torch_geometric.data.DataLoader`.

    num_workers : int, default = 0
        Non-negative. Number of workers to be passed to `torch_geometric.data.DataLoader`.

    Returns
    --------
    dataloader : a pytorch-geometric dataloader. All of the graphs will have
        the same connectivity (given by the input graph), but the node
        features will be the features from X.
    """
    n_obs, n_features = X.shape
    rows, cols = np.where(graph == 1)
    sources = []
    targets = []
    for u, v in zip(rows.tolist(), cols.tolist()):
        sources.append(u)
        targets.append(v)
        if undirected:
            # Fix: append the REVERSED edge (v, u). The original appended the
            # same (u, v) pair twice, duplicating edges instead of
            # symmetrizing the graph.
            sources.append(v)
            targets.append(u)
    edge_index = torch.tensor([sources, targets], dtype=torch.long)

    list_graphs = []
    y = y.tolist()
    for i in range(n_obs):
        y_tensor = torch.tensor(y[i])
        # Each observation becomes a graph with one scalar feature per node.
        X_tensor = torch.tensor(X[i, :]).view(n_features, 1).float()
        data = geo_dt.Data(x=X_tensor, edge_index=edge_index, y=y_tensor)
        # coalesce() sorts edge_index and removes any duplicate edges.
        list_graphs.append(data.coalesce())

    dataloader = geo_dt.DataLoader(list_graphs, batch_size=batch_size, shuffle=shuffle,
                                   num_workers=num_workers, pin_memory=False)
    return dataloader
148 |
149 |
150 |
def kullback_leibler_divergence(X):

    """Finds the pairwise Kullback-Leibler divergence
    matrix between all rows in X.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
        Array of probability data. Each row must sum to 1.

    Returns
    -------
    D : ndarray, shape (n_samples, n_samples)
        The Kullback-Leibler divergence matrix (in bits). D_{i, j} is the
        divergence KL(X_i || X_j); the diagonal is zeroed.

    Notes
    -----
    Based on code from Gordon J. Berman et al.
    (https://github.com/gordonberman/MotionMapper)

    References
    -----------
    Berman, G. J., Choi, D. M., Bialek, W., & Shaevitz, J. W. (2014).
    Mapping the stereotyped behaviour of freely moving fruit flies.
    Journal of The Royal Society Interface, 11(99), 20140672.
    """

    X_log = np.log(X)
    # Zero out -inf/nan from log(0): the matching X entry is 0, so the
    # products below honor the 0 * log(0) = 0 convention.
    X_log[np.isinf(X_log) | np.isnan(X_log)] = 0

    entropies = -np.sum(X * X_log, axis=1)

    # Cross-entropy term: D[i, j] = -sum_k X[i, k] * log(X[j, k]).
    D = np.matmul(-X, X_log.T)
    # Fix: KL(i || j) = cross_entropy(i, j) - entropy(i), so the entropy of
    # the ROW distribution must be subtracted. The previous `D - entropies`
    # broadcast along the last axis and subtracted entropy(j) instead
    # (the MATLAB original subtracts a column vector, i.e. per-row).
    D = D - entropies[:, np.newaxis]
    # Convert nats to bits.
    D = D / np.log(2)
    # Zero the diagonal (self-divergence).
    D *= (1 - np.eye(D.shape[0]))

    return D
190 |
def multinomial_rvs(n, p):
    """Sample from the multinomial distribution with multiple p vectors.

    Parameters
    ----------
    n : int
        must be a scalar >=1
    p : numpy ndarray
        must be an n-dimensional array;
        the last axis of p holds the sequence of probabilities for a
        multinomial distribution.

    Returns
    -------
    D : ndarray
        same shape as p
    """
    remaining = np.full(p.shape[:-1], n)
    sample = np.zeros(p.shape, dtype=int)
    cumulative = p.cumsum(axis=-1)
    # Conditional probability of each category given it was not assigned to
    # a later one; 0/0 at impossible categories is mapped to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        cond = p / cumulative
        cond[np.isnan(cond)] = 0.0
    # Peel off binomial draws from the last category down to the second;
    # whatever remains lands in category 0.
    for cat in range(p.shape[-1] - 1, 0, -1):
        draw = np.random.binomial(remaining, cond[..., cat])
        sample[..., cat] = draw
        remaining -= draw
    sample[..., 0] = remaining
    return sample
220 |
def save_adata(adata, attr, key, data):
    """Updates an attribute of an `AnnData` object.

    Parameters
    ----------
    adata : `AnnData`
        The annotated data matrix of shape `n_obs × n_vars`. Rows correspond to cells and columns to genes.
    attr : str
        must be an attribute of `adata`, e.g., `obs`, `var`, etc.
    key : str
        must be a key in the attr
    data : non-specific
        the data to be updated/placed

    """
    # Equivalent to adata.<attr>[key] = data for mapping-like attributes.
    getattr(adata, attr)[key] = data
238 |
239 |
def prob_con(adata, overwrite=False, inplace=True):
    """Concatenates all `.obsm` matrices whose key starts with "prob" into
    one matrix stored under `.obsm['prob_concatenated']`.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix; probability matrices are read from `.obsm`.
    overwrite : bool, default=False
        If `.obsm['prob_concatenated']` already exists, recompute it only
        when True; otherwise `adata` is returned unchanged.
    inplace : bool, default=True
        When True (and prob matrices exist), the concatenated matrix is
        additionally written to `.obsm['spatial']` and None is returned.
        NOTE(review): storing probabilities under the 'spatial' key looks
        deliberate for downstream mapping, but is surprising — confirm.

    Returns
    -------
    `adata` or None depending on the branch taken; see inline notes —
    callers should not rely on the return value.
    """
    if "prob_concatenated" in adata.obsm.keys():
        warnings.warn("obsm['prob_concatenated'] already exists!")
        if not overwrite:
            # Keep the existing matrix; hand the object back unchanged.
            return adata
        else:
            warnings.warn("overwriting obsm['prob_concatenated'].")
            del adata.obsm["prob_concatenated"]
    # Get a list of obsm matrices with names starting with "prob"
    prob_matrices = [matrix_name for matrix_name in adata.obsm.keys() if matrix_name.startswith("prob")]
    # Define a function to concatenate two matrices
    def concatenate_matrices(matrix1, matrix2):
        return np.concatenate((matrix1, matrix2), axis=1)
    # Use functools.reduce to concatenate all matrices in prob_matrices
    if prob_matrices:
        concatenated_matrix = reduce(concatenate_matrices, [adata.obsm[matrix] for matrix in prob_matrices])
        adata.obsm["prob_concatenated"] = concatenated_matrix
        if inplace:
            save_adata(adata, attr='obsm', key='spatial', data=concatenated_matrix)
            return None
        # NOTE(review): when prob matrices exist and inplace=False, control
        # falls through and the function implicitly returns None, not `adata`.
    else:
        warnings.warn("No 'prob' matrices found in the AnnData object.")
        if not inplace:
            return adata
264 |
265 |
def map2ref(adata_ref, adata_q, k=10):
    """Maps each query cell to a reference cell via k-NN in probability space.

    Builds a KD-tree over the reference's concatenated probability
    embeddings, queries the `k` nearest reference cells per query cell,
    and samples one of them with probability softmax(-distance), so closer
    neighbours are more likely to be picked.

    Parameters
    ----------
    adata_ref : AnnData
        Reference data; `.obsm['prob_concatenated']` is computed on the fly
        (via `prob_con`) if missing.
    adata_q : AnnData
        Query data; same requirement as `adata_ref`.
    k : int, default=10
        Number of nearest neighbours considered per query cell.

    Returns
    -------
    indices : ndarray of shape (n_query,)
        Index of the sampled reference cell for each query cell.
        NOTE(review): sampling makes the result stochastic across calls.
    conf : ndarray of shape (n_query,)
        Per-cell nearest distance divided by the global minimum distance
        (values >= 1; larger means a relatively worse match).
        NOTE(review): the name suggests "confidence" but the ratio grows
        with distance — confirm the intended orientation.
    """
    # if "spatial" not in adata_ref.obsm.keys():
    #     raise Exception("adata_ref.obsm['spatial'] does not exist. Necessary for spatial mapping.")
    if "prob_concatenated" not in adata_ref.obsm.keys():
        warnings.warn("obsm['prob_concatenated'] does not exsit for adata_ref. Calculating obsm['prob_concatenated'].")
        prob_con(adata_ref)
    if "prob_concatenated" not in adata_q.obsm.keys():
        warnings.warn("obsm['prob_concatenated'] does not exsit for adata_q. Calculating obsm['prob_concatenated'].")
        prob_con(adata_q)
    ref_embeddings = adata_ref.obsm['prob_concatenated']
    kdtree_r1 = cKDTree(ref_embeddings)
    target_embeddings = adata_q.obsm['prob_concatenated']
    distances, indices = kdtree_r1.query(target_embeddings, k=k)
    # Closer neighbours get higher probability: softmax over negative distances.
    m = Softmax(dim=1)
    probs = m(-torch.tensor(distances))
    # Sample one of the k neighbours per query cell.
    dist = torch.distributions.categorical.Categorical(probs=probs)
    idx = dist.sample().numpy()
    indices = indices[np.arange(len(indices)), idx]
    conf = distances.min(1) / distances.min()
    return indices, conf
286 | # adata_q.obsm['spatial'] = adata_ref.obsm['spatial'][indices]
287 | # if inplace:
288 | # save_adata(adata_q, attr='obsm', key='spatial', data= adata_ref.obsm['spatial'][indices])
289 | # else:
290 | # return adata_q
291 | # swap.obs['sink']
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from pathlib import Path

from setuptools import setup, find_packages

# Long description is reStructuredText (README.rst), not markdown.
long_description = Path('README.rst').read_text('utf-8')

import os
os.environ['NUMBA_CACHE_DIR'] = '/tmp/'

try:
    from sagenet import __author__, __email__
except ImportError:  # Deps not yet installed
    __author__ = __maintainer__ = 'Elyas Heidari'
    __email__ = 'eheidari@student.ethz.ch'
    __version__ = '0.1.2'

setup(name='sagenet',
      version="0.1.2",
      description='Spatial reconstruction of dissociated single-cell data',
      long_description=long_description,
      # Fix: README.rst is reST; "text/markdown" would render wrong on PyPI.
      long_description_content_type="text/x-rst",
      url='https://github.com/MarioniLab/sagenet',
      author=__author__,
      author_email=__email__,
      license='MIT',
      platforms=["Linux", "MacOSX"],
      packages=find_packages(),
      zip_safe=False,
      # download_url="https://github.com/MarioniLab/sagenet/archive/refs/tags/SageNet_v0.1.0.1.tar.gz",
      project_urls={
          "Documentation": "https://sagenet.readthedocs.io/en/latest",
          "Source Code": "https://github.com/MarioniLab/sagenet",
      },
      install_requires=[l.strip() for l in Path("requirements.txt").read_text("utf-8").splitlines()],
      classifiers=[
          "Development Status :: 5 - Production/Stable",
          "Intended Audience :: Science/Research",
          "Natural Language :: English",
          # Fix: classifier now matches the declared MIT license (was BSD).
          "License :: OSI Approved :: MIT License",
          "Operating System :: POSIX :: Linux",
          "Operating System :: MacOS :: MacOS X",
          "Typing :: Typed",
          # "Programming Language :: Python :: 3",
          # "Programming Language :: Python :: 3.6",
          "Programming Language :: Python :: 3.7",
          "Programming Language :: Python :: 3.8",
          "Programming Language :: Python :: 3.9",
          "Environment :: Console",
          "Framework :: Jupyter",
          "Intended Audience :: Science/Research",
          "Topic :: Scientific/Engineering :: Bio-Informatics",
          "Topic :: Scientific/Engineering :: Visualization",
      ],
      # Fix: `doc=` is not a valid setup() keyword; expose the docs
      # dependencies as an installable extra (`pip install sagenet[doc]`).
      extras_require={
          'doc': [
              'sphinx',
              'sphinx_rtd_theme',
              'sphinx_autodoc_typehints',
              'typing_extensions; python_version < "3.8"',
          ],
      },
      keywords=sorted(
          [
              "single-cell",
              "bio-informatics",
              "spatial transcriptomics",
              "spatial data analysis",
              "single-cell data analysis",
          ]
      ),
      )
70 |
--------------------------------------------------------------------------------